Merge pull request #14053 from gnecula:tf_weak_dim_as_value

PiperOrigin-RevId: 502946430
This commit is contained in:
jax authors 2023-01-18 12:21:29 -08:00
commit 96b67bcb9f
2 changed files with 21 additions and 0 deletions

View File

@ -461,6 +461,7 @@ class _DimPolynomial():
core.pytype_aval_mappings[_DimPolynomial] = _DimPolynomial.get_aval
xla.pytype_aval_mappings[_DimPolynomial] = _DimPolynomial.get_aval
dtypes._weak_types.append(_DimPolynomial)
def _ensure_poly(p: DimSize,
operation_name: str) -> _DimPolynomial:

View File

@ -42,6 +42,7 @@ from jax.experimental.jax2tf.tests import tf_test_util
import tensorflow as tf # type: ignore[import]
from jax.config import config
from jax._src.config import numpy_dtype_promotion
config.parse_flags_with_absl()
@ -1444,6 +1445,25 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
arg_descriptors=[RandArg((3, 4), _f32)],
poly_axes=[0])
def test_dim_as_value_weak_type(self):
def f_jax(x): # x: f32[b]
d0 = jnp.array(x.shape[0]) # in JAX should have weak_type=True
if isinstance(d0, core.Tracer):
self.assertTrue(d0.aval.weak_type), d0
# And an implicit conversion to array
d1 = x.shape[0] + jnp.array(4)
if isinstance(d1, core.Tracer):
self.assertTrue(d1.aval.weak_type), d1
return d0 + np.array(5., dtype=np.float32) + d1
with numpy_dtype_promotion("strict"):
# strict type promotion is sensitive to weak_types
check_shape_poly(self,
f_jax,
arg_descriptors=[RandArg((3,), _f32)],
poly_axes=[0])
@unittest.skip('Failing at HEAD. Reenable after b/264913007 is fixed')
def test_vmap_while(self):
def cond_func(x): # x: f32[3]