Merge pull request #13790 from gnecula:dim_as_value

PiperOrigin-RevId: 499943045
This commit is contained in:
jax authors 2023-01-05 11:05:09 -08:00
commit 33c1e5d540
3 changed files with 11 additions and 5 deletions

View File

@ -415,6 +415,8 @@ class Trace:
self.sublevel = sublevel
def full_raise(self, val) -> Tracer:
if hasattr(val, "dimension_as_value"): # Used for shape_poly._DimPolynomial
val = val.dimension_as_value()
if not isinstance(val, Tracer):
return self.pure(val)
val._assert_live()
@ -1880,9 +1882,10 @@ def dimension_as_value(d: DimSize):
Has the same abstract value as Python constants.
"""
if isinstance(d, Tracer): return d
handler, ds = _dim_handler_and_canonical(d)
return handler.as_value(*ds)
if isinstance(d, (int, Tracer, np.int32, np.int64)): return d
# For shape_poly._DimPolynomial
if hasattr(d, "dimension_as_value"): return d.dimension_as_value()
return operator.index(d)
def _canonicalize_dimension(dim: DimSize) -> DimSize:
if isinstance(dim, Tracer) and config.jax_dynamic_shapes:

View File

@ -451,6 +451,10 @@ class _DimPolynomial():
def get_aval(dim: "_DimPolynomial"):
return dim_as_value_abstract(dim)
def dimension_as_value(self):
"""Turns a dimension size into a Jax value that we can compute with."""
return _dim_as_value(self)
def __jax_array__(self):
# Used for implicit coercions of polynomials as JAX arrays
return _dim_as_value(self)

View File

@ -1363,8 +1363,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
(jnp.array, "array"),
(jnp.sin, "sin"),
(lambda x: x, "id"),
(core.dimension_as_value, "dimension_as_value"),
]
])
def test_poly_unary_op(self, *, op=jnp.array):
if config.jax_enable_x64:
@ -1425,7 +1425,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
def test_mean0(self):
def f_jax(x): # x: f32[b, 4]
return jnp.sum(x, axis=0) / x.shape[0]
check_shape_poly(self,
f_jax,
arg_descriptors=[RandArg((3, 4), _f32)],