fix duck typing in jax.eval_shape (cf. #798)

This commit is contained in:
Matthew Johnson 2019-06-01 09:48:28 -07:00
parent 11c512a194
commit dda95df519
2 changed files with 30 additions and 9 deletions

View File

@ -1083,21 +1083,26 @@ def eval_shape(fun, *args, **kwargs):
Args:
*args: a positional argument tuple of arrays, scalars, or (nested) standard
Python containers (pytrees) of those types. Since only the ``shape`` and
``dtype`` attributes are accessed, only values that duck-type arrays are
required, rather than real ndarrays.
Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of
those types. Since only the ``shape`` and ``dtype`` attributes are
accessed, only values that duck-type arrays are required, rather than real
ndarrays. The duck-typed objects cannot be namedtuples because those are
treated as standard Python containers. See the example below.
**kwargs: a keyword argument dict of arrays, scalars, or (nested) standard
Python containers (pytrees) of those types. As in ``args``, array values
need only be duck-typed to have ``shape`` and ``dtype`` attributes.
For example:
>>> f = lambda A, b, x: np.tanh(np.dot(A, x) + b)
>>> MyArgArray = collections.namedtuple("MyArgArray", ["shape", "dtype"])
>>> f = lambda A, x: np.tanh(np.dot(A, x))
>>> class MyArgArray(object):
... def __init__(self, shape, dtype):
... self.shape = shape
... self.dtype = dtype
...
>>> A = MyArgArray((2000, 3000), np.float32)
>>> b = MyArgArray((2000,), np.float32)
>>> x = MyArgArray((3000, 1000), np.float32)
>>> out_shape = jax.eval_shape(f, A, b, x) # no FLOPs performed
>>> out_shape = jax.eval_shape(f, A, x) # no FLOPs performed
>>> print(out_shape)
(2000, 1000)
"""

View File

@ -635,12 +635,12 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(out_shape, (2,))
def test_eval_shape_output_dict(self):
def fun4(x, y):
def fun(x, y):
return {'hi': x[0] + x[1] + y}
x = (np.ones(2), np.ones(2))
y = 3.
out_shape = api.eval_shape(fun4, x, y)
out_shape = api.eval_shape(fun, x, y)
self.assertEqual(out_shape, {'hi': (2,)})
@ -653,6 +653,22 @@ class APITest(jtu.JaxTestCase):
self.assertRaises(TypeError, lambda: api.eval_shape(fun, x, y))
def test_eval_shape_duck_typing(self):
def fun(A, b, x):
return np.dot(A, x) + b
class MyArgArray(object):
def __init__(self, shape, dtype):
self.shape = shape
self.dtype = dtype
A = MyArgArray((3, 4), np.float32)
b = MyArgArray((5,), np.float32)
x = MyArgArray((4, 5), np.float32)
out_shape = api.eval_shape(fun, A, b, x)
self.assertEqual(out_shape, (3, 5))
if __name__ == '__main__':
absltest.main()