mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #14053 from gnecula:tf_weak_dim_as_value
PiperOrigin-RevId: 502946430
This commit is contained in:
commit
96b67bcb9f
@ -461,6 +461,7 @@ class _DimPolynomial():
|
||||
|
||||
core.pytype_aval_mappings[_DimPolynomial] = _DimPolynomial.get_aval
|
||||
xla.pytype_aval_mappings[_DimPolynomial] = _DimPolynomial.get_aval
|
||||
dtypes._weak_types.append(_DimPolynomial)
|
||||
|
||||
def _ensure_poly(p: DimSize,
|
||||
operation_name: str) -> _DimPolynomial:
|
||||
|
@ -42,6 +42,7 @@ from jax.experimental.jax2tf.tests import tf_test_util
|
||||
import tensorflow as tf # type: ignore[import]
|
||||
|
||||
from jax.config import config
|
||||
from jax._src.config import numpy_dtype_promotion
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -1444,6 +1445,25 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
arg_descriptors=[RandArg((3, 4), _f32)],
|
||||
poly_axes=[0])
|
||||
|
||||
def test_dim_as_value_weak_type(self):
|
||||
def f_jax(x): # x: f32[b]
|
||||
d0 = jnp.array(x.shape[0]) # in JAX should have weak_type=True
|
||||
if isinstance(d0, core.Tracer):
|
||||
self.assertTrue(d0.aval.weak_type), d0
|
||||
|
||||
# And an implicit conversion to array
|
||||
d1 = x.shape[0] + jnp.array(4)
|
||||
if isinstance(d1, core.Tracer):
|
||||
self.assertTrue(d1.aval.weak_type), d1
|
||||
return d0 + np.array(5., dtype=np.float32) + d1
|
||||
|
||||
with numpy_dtype_promotion("strict"):
|
||||
# strict type promotion is sensitive to weak_types
|
||||
check_shape_poly(self,
|
||||
f_jax,
|
||||
arg_descriptors=[RandArg((3,), _f32)],
|
||||
poly_axes=[0])
|
||||
|
||||
@unittest.skip('Failing at HEAD. Reenable after b/264913007 is fixed')
|
||||
def test_vmap_while(self):
|
||||
def cond_func(x): # x: f32[3]
|
||||
|
Loading…
x
Reference in New Issue
Block a user