make jax.eval_shape duck typing more robust

This commit is contained in:
Matthew Johnson 2021-02-09 11:19:09 -08:00
parent 5b0c47e856
commit 7394048782
3 changed files with 21 additions and 2 deletions

View File

@ -2318,8 +2318,14 @@ def eval_shape(fun: Callable, *args, **kwargs):
>>> print(out.dtype)
float32
"""
def dtype(x):
try:
return dtypes.result_type(x)
except ValueError:
return dtypes.result_type(getattr(x, 'dtype'))
def abstractify(x):
return ShapedArray(np.shape(x), dtypes.result_type(x))
return ShapedArray(np.shape(x), dtype(x))
args_flat, in_tree = tree_flatten((args, kwargs))
wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,

View File

@ -324,7 +324,9 @@ def dtype(x):
def result_type(*args):
"""Convenience function to apply Numpy argument dtype promotion."""
# TODO(jakevdp): propagate weak_type to the result.
if len(args) < 2:
if len(args) == 0:
raise ValueError("at least one array or dtype is required")
if len(args) == 1:
return canonicalize_dtype(dtype(args[0]))
# TODO(jakevdp): propagate weak_type to the result when necessary.
return canonicalize_dtype(_least_upper_bound(*{_jax_type(arg) for arg in args}))

View File

@ -1222,6 +1222,17 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(out_shape.shape, (3, 5))
def test_eval_shape_duck_typing2(self):
# https://github.com/google/jax/issues/5683
class EasyDict(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__dict__ = self
x = EasyDict(shape=(3,), dtype=np.dtype('float32'))
out_shape = api.eval_shape(lambda x: x, x) # doesn't crash
self.assertEqual(out_shape.shape, (3,))
def test_issue_871(self):
T = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
x = jnp.array([1, 2, 3])