Mirror of https://github.com/ROCm/jax.git
Merge pull request #18516 from gnecula:poly_non_negative
PiperOrigin-RevId: 583303763
Commit: 05bba6e790
@@ -1126,7 +1126,7 @@ def _call_exported_abstract_eval(
   # `f32[c, d]` it is better to fail because `c == d` is inconclusive, than
   # succeed and add a compile-time check that `c == d`. In the latter case,
   # it would be ambiguous whether we should continue tracing with a result
-  # a type `f32[c]` or `f32[d]`.
+  # of type `f32[c]` or `f32[d]`.
   shape_constraints.check_statically(synthetic_eval)
   exported_dim_values = [synthetic_eval.evaluate(solution[var])
                          for var in exported_dim_vars]
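The comment above motivates failing early when a dimension constraint cannot be decided statically. A minimal sketch, not part of this commit, assuming the jax.experimental.export API used in the test further below: an Exported with signature f32[a, a] -> f32[a], re-traced with two unrelated dimension variables c and d, leaves the constraint `c == d` inconclusive, so the static check should reject the call rather than continue with an ambiguous result type.

# Sketch only; the exact error type raised by the static check is not pinned down here.
import numpy as np
import jax.numpy as jnp
from jax.experimental import export

def f(x):  # x: f32[a, a]
  return jnp.sum(x, axis=0)

# Export with a square polymorphic shape, so the Exported requires dim0 == dim1.
exp = export.export(f)(export.poly_spec((4, 4), np.float32, "a, a"))

# Re-tracing the Exported with unrelated symbolic dims c and d makes c == d inconclusive.
bad_spec = export.poly_spec((4, 4), np.float32, "c, d")
try:
  export.export(export.call_exported(exp))(bad_spec)
except Exception as e:
  print("shape constraint c == d could not be checked statically:", type(e).__name__)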
@@ -242,7 +242,14 @@ class _DimAtom:
       elif self.operation == _DimAtom.MOD:
         return divmod(*operand_values)[1]  # type: ignore
       elif self.operation == _DimAtom.NON_NEGATIVE:
-        return lax.max(operand_values[0], 0)
+        operand = operand_values[0]
+        if core.is_constant_dim(operand):
+          return max(operand, 0)
+        if core.is_symbolic_dim(operand):
+          return core.non_negative_dim(operand)
+        # In the context of `evaluate` dimension variables may be mapped to
+        # JAX Tracers.
+        return lax.max(operand, 0)
       else:
         assert False, self.operation

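To make the three NON_NEGATIVE branches concrete, here is a small illustrative sketch, not from the commit, assuming the `core.non_negative_dim` and `export.poly_spec` APIs that appear in this diff: a constant dimension is clamped with plain Python `max`, a symbolic dimension stays symbolic, and the `lax.max` fallback covers dimensions that are JAX Tracers during tracing.

import numpy as np
from jax import core
from jax.experimental import export

# A symbolic dimension "b", as produced by poly_spec in the test below.
b = export.poly_spec((4,), np.float32, "b").shape[0]

print(max(3 - 5, 0))                  # constant dimension: built-in max -> 0
print(core.non_negative_dim(b - 5))   # symbolic dimension: stays a symbolic expression
# Tracer case: during tracing, the fallback lax.max(operand, 0) is used instead.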
@@ -682,6 +682,28 @@ class JaxExportTest(jtu.JaxTestCase):
     res = export.call_exported(exp)(x)
     self.assertAllClose(f_jax(x), res)

+  def test_poly_expressions(self):
+    # Calling an Exported module whose output shape contains symbolic
+    # expressions
+    def output_shape(b):
+      return (b + b, b - b, b * b,
+              (b + 13) // b, (b + 13) % b,
+              core.non_negative_dim(b - 5))
+    def f(x):  # x: f32[b]
+      b = x.shape[0]
+      return jnp.ones(output_shape(b), dtype=x.dtype)
+
+    x = np.arange(5, dtype=np.float32)
+    exp = export.export(f)(export.poly_spec(x.shape, x.dtype, "b"))
+    # Call with static shapes
+    res = export.call_exported(exp)(x)
+    self.assertAllClose(res, f(x))
+
+    # Now re-export with shape polymorphism
+    x_spec = export.poly_spec(x.shape, x.dtype, "a")
+    exp2 = export.export(export.call_exported(exp))(x_spec)
+    a = x_spec.shape[0]
+    self.assertEqual(exp2.out_avals[0].shape, output_shape(a))
+
   def test_with_sharding(self):
     nr_devices = 2
     if len(jax.devices()) < nr_devices:
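As a reading aid, not part of the commit: in the static call in test_poly_expressions above, x has shape (5,), so b is 5 and the expected output shape can be worked out by hand, matching the shape of the array that f(x) returns.

b = 5
expected = (b + b, b - b, b * b, (b + 13) // b, (b + 13) % b, max(b - 5, 0))
print(expected)  # (10, 0, 25, 3, 3, 0)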