mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #13790 from gnecula:dim_as_value
PiperOrigin-RevId: 499943045
This commit is contained in:
commit
33c1e5d540
@ -415,6 +415,8 @@ class Trace:
|
||||
self.sublevel = sublevel
|
||||
|
||||
def full_raise(self, val) -> Tracer:
|
||||
if hasattr(val, "dimension_as_value"): # Used for shape_poly._DimPolynomial
|
||||
val = val.dimension_as_value()
|
||||
if not isinstance(val, Tracer):
|
||||
return self.pure(val)
|
||||
val._assert_live()
|
||||
@ -1880,9 +1882,10 @@ def dimension_as_value(d: DimSize):
|
||||
|
||||
Has the same abstract value as Python constants.
|
||||
"""
|
||||
if isinstance(d, Tracer): return d
|
||||
handler, ds = _dim_handler_and_canonical(d)
|
||||
return handler.as_value(*ds)
|
||||
if isinstance(d, (int, Tracer, np.int32, np.int64)): return d
|
||||
# For shape_poly._DimPolynomial
|
||||
if hasattr(d, "dimension_as_value"): return d.dimension_as_value()
|
||||
return operator.index(d)
|
||||
|
||||
def _canonicalize_dimension(dim: DimSize) -> DimSize:
|
||||
if isinstance(dim, Tracer) and config.jax_dynamic_shapes:
|
||||
|
@ -451,6 +451,10 @@ class _DimPolynomial():
|
||||
def get_aval(dim: "_DimPolynomial"):
|
||||
return dim_as_value_abstract(dim)
|
||||
|
||||
def dimension_as_value(self):
|
||||
"""Turns a dimension size into a Jax value that we can compute with."""
|
||||
return _dim_as_value(self)
|
||||
|
||||
def __jax_array__(self):
|
||||
# Used for implicit coercions of polynomials as JAX arrays
|
||||
return _dim_as_value(self)
|
||||
|
@ -1363,8 +1363,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
(jnp.array, "array"),
|
||||
(jnp.sin, "sin"),
|
||||
(lambda x: x, "id"),
|
||||
(core.dimension_as_value, "dimension_as_value"),
|
||||
]
|
||||
|
||||
])
|
||||
def test_poly_unary_op(self, *, op=jnp.array):
|
||||
if config.jax_enable_x64:
|
||||
@ -1425,7 +1425,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
def test_mean0(self):
|
||||
def f_jax(x): # x: f32[b, 4]
|
||||
return jnp.sum(x, axis=0) / x.shape[0]
|
||||
|
||||
check_shape_poly(self,
|
||||
f_jax,
|
||||
arg_descriptors=[RandArg((3, 4), _f32)],
|
||||
|
Loading…
x
Reference in New Issue
Block a user