mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix duck typing in jax.eval_shape (cf. #798)
This commit is contained in:
parent
11c512a194
commit
dda95df519
19
jax/api.py
19
jax/api.py
@ -1083,21 +1083,26 @@ def eval_shape(fun, *args, **kwargs):
|
||||
|
||||
Args:
|
||||
*args: a positional argument tuple of arrays, scalars, or (nested) standard
|
||||
Python containers (pytrees) of those types. Since only the ``shape`` and
|
||||
``dtype`` attributes are accessed, only values that duck-type arrays are
|
||||
required, rather than real ndarrays.
|
||||
Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of
|
||||
those types. Since only the ``shape`` and ``dtype`` attributes are
|
||||
accessed, only values that duck-type arrays are required, rather than real
|
||||
ndarrays. The duck-typed objects cannot be namedtuples because those are
|
||||
treated as standard Python containers. See the example below.
|
||||
**kwargs: a keyword argument dict of arrays, scalars, or (nested) standard
|
||||
Python containers (pytrees) of those types. As in ``args``, array values
|
||||
need only be duck-typed to have ``shape`` and ``dtype`` attributes.
|
||||
|
||||
For example:
|
||||
|
||||
>>> f = lambda A, b, x: np.tanh(np.dot(A, x) + b)
|
||||
>>> MyArgArray = collections.namedtuple("MyArgArray", ["shape", "dtype"])
|
||||
>>> f = lambda A, x: np.tanh(np.dot(A, x))
|
||||
>>> class MyArgArray(object):
|
||||
... def __init__(self, shape, dtype):
|
||||
... self.shape = shape
|
||||
... self.dtype = dtype
|
||||
...
|
||||
>>> A = MyArgArray((2000, 3000), np.float32)
|
||||
>>> b = MyArgArray((2000,), np.float32)
|
||||
>>> x = MyArgArray((3000, 1000), np.float32)
|
||||
>>> out_shape = jax.eval_shape(f, A, b, x) # no FLOPs performed
|
||||
>>> out_shape = jax.eval_shape(f, A, x) # no FLOPs performed
|
||||
>>> print(out_shape)
|
||||
(2000, 1000)
|
||||
"""
|
||||
|
@ -635,12 +635,12 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertEqual(out_shape, (2,))
|
||||
|
||||
def test_eval_shape_output_dict(self):
|
||||
def fun4(x, y):
|
||||
def fun(x, y):
|
||||
return {'hi': x[0] + x[1] + y}
|
||||
|
||||
x = (np.ones(2), np.ones(2))
|
||||
y = 3.
|
||||
out_shape = api.eval_shape(fun4, x, y)
|
||||
out_shape = api.eval_shape(fun, x, y)
|
||||
|
||||
self.assertEqual(out_shape, {'hi': (2,)})
|
||||
|
||||
@ -653,6 +653,22 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
self.assertRaises(TypeError, lambda: api.eval_shape(fun, x, y))
|
||||
|
||||
def test_eval_shape_duck_typing(self):
|
||||
def fun(A, b, x):
|
||||
return np.dot(A, x) + b
|
||||
|
||||
class MyArgArray(object):
|
||||
def __init__(self, shape, dtype):
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
|
||||
A = MyArgArray((3, 4), np.float32)
|
||||
b = MyArgArray((5,), np.float32)
|
||||
x = MyArgArray((4, 5), np.float32)
|
||||
out_shape = api.eval_shape(fun, A, b, x)
|
||||
|
||||
self.assertEqual(out_shape, (3, 5))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user