mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
make jax.eval_shape duck typing more robust
This commit is contained in:
parent
5b0c47e856
commit
7394048782
@ -2318,8 +2318,14 @@ def eval_shape(fun: Callable, *args, **kwargs):
|
||||
>>> print(out.dtype)
|
||||
float32
|
||||
"""
|
||||
def dtype(x):
|
||||
try:
|
||||
return dtypes.result_type(x)
|
||||
except ValueError:
|
||||
return dtypes.result_type(getattr(x, 'dtype'))
|
||||
|
||||
def abstractify(x):
|
||||
return ShapedArray(np.shape(x), dtypes.result_type(x))
|
||||
return ShapedArray(np.shape(x), dtype(x))
|
||||
args_flat, in_tree = tree_flatten((args, kwargs))
|
||||
wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
|
||||
out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
|
||||
|
@ -324,7 +324,9 @@ def dtype(x):
|
||||
def result_type(*args):
|
||||
"""Convenience function to apply Numpy argument dtype promotion."""
|
||||
# TODO(jakevdp): propagate weak_type to the result.
|
||||
if len(args) < 2:
|
||||
if len(args) == 0:
|
||||
raise ValueError("at least one array or dtype is required")
|
||||
if len(args) == 1:
|
||||
return canonicalize_dtype(dtype(args[0]))
|
||||
# TODO(jakevdp): propagate weak_type to the result when necessary.
|
||||
return canonicalize_dtype(_least_upper_bound(*{_jax_type(arg) for arg in args}))
|
||||
|
@ -1222,6 +1222,17 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
self.assertEqual(out_shape.shape, (3, 5))
|
||||
|
||||
def test_eval_shape_duck_typing2(self):
|
||||
# https://github.com/google/jax/issues/5683
|
||||
class EasyDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
x = EasyDict(shape=(3,), dtype=np.dtype('float32'))
|
||||
out_shape = api.eval_shape(lambda x: x, x) # doesn't crash
|
||||
self.assertEqual(out_shape.shape, (3,))
|
||||
|
||||
def test_issue_871(self):
|
||||
T = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
|
||||
x = jnp.array([1, 2, 3])
|
||||
|
Loading…
x
Reference in New Issue
Block a user