1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 13:26:06 +00:00

[jax2tf] Deprecate jax2tf with native_serialization=False or enable_xla=False.

Also disable many of the non-native-serialization jax2tf tests.
In particular, I am disabling the thousands of primitives tests in
graph serialization mode.
I kept jax2tf_test running in both native and graph serialization mode.

PiperOrigin-RevId: 652749891
This commit is contained in:
George Necula 2024-07-16 02:04:59 -07:00 committed by jax authors
parent 28ffa25496
commit d34a6e9ce2
7 changed files with 96 additions and 275 deletions

@ -32,6 +32,10 @@ Remember to align the itemized text with the first line of an item within a list
* HLO lowering rules should no longer wrap singleton ir.Values in tuples.
Instead, return singleton ir.Values unwrapped. Support for wrapped values
will be removed in a future version of JAX.
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
or `enable_xla=False` is now deprecated and this support will be removed in
a future version.
Native serialization has been the default since JAX 0.4.16 (September 2023).
## jaxlib 0.4.31

@ -916,7 +916,8 @@ jax2tf_default_native_serialization = bool_state(
help=(
'Sets the default value of the native_serialization parameter to '
'jax2tf.convert. Prefer using the parameter instead of the flag, '
'the flag may be removed in the future.'
'the flag may be removed in the future. '
'Starting with JAX 0.4.31 non-native serialization is deprecated.'
)
)

@ -59,7 +59,8 @@ For backwards compatibility purposes, and for special uses,
the JAX-TensorFlow interoperation APIs can be used also
in a **graph serialization** mode (the only mode available before version 0.4.7,
and the default mode before JAX version 0.4.15),
without going through StableHLO.
without going through StableHLO. (Starting with JAX version 0.4.31 the
graph serialization mode is deprecated. It will be removed in the near future).
* For calling JAX functions from TensorFlow,
it is possible to request that the JAX function be lowered with one TensorFlow

@ -304,6 +304,7 @@ def convert(fun_jax: Callable,
so the lowering tries harder to use non-XLA TF ops to lower the
function and aborts if this is not possible. Cannot be set to `False`
when using `native_serialization`.
Starting with JAX 0.4.31 support for `enable_xla=False` is deprecated.
native_serialization: serialize the JAX function natively to
StableHLO with compatibility guarantees. This makes it easier to have
confidence that the code executed when calling this function from
@ -312,6 +313,7 @@ def convert(fun_jax: Callable,
is set to `False` or to the configuration flag
`--jax2tf_default_native_serialization` otherwise.
Native serialization cannot be used with `enable_xla=False`.
Starting with JAX 0.4.31 support for non-native serialization is deprecated.
native_serialization_platforms: In conjunction with
`native_serialization`, specify the platform(s)
for which to lower the code. Must be a tuple of
@ -327,12 +329,17 @@ def convert(fun_jax: Callable,
tuple/lists/dicts thereof), and returns TfVals as outputs, and uses
only TensorFlow ops and thus can be called from a TensorFlow program.
"""
if not enable_xla:
warnings.warn("jax2tf.convert with enable_xla=False is deprecated.")
if native_serialization is DEFAULT_NATIVE_SERIALIZATION:
if not enable_xla:
native_serialization = False
else:
native_serialization = config.jax2tf_default_native_serialization.value
if not native_serialization:
warnings.warn(
"jax2tf.convert with native_serialization=False is deprecated.")
if native_serialization and not enable_xla:
raise ValueError(
"native_serialization is not supported with enable_xla=False")

@ -860,17 +860,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
return ad_checkpoint.checkpoint_name(jnp.sin(x), "sin")
jax2tf.convert(f_jax)(1.) # No error.
def test_convert_nullary_func(self):
# Even nullary functions are converted to TF (as opposed to constant-folded
# in JAX prior to conversion).
def f_jax():
return jnp.sin(1.)
f_tf = jax2tf.convert(f_jax)
# for native serialization the HLO we get from TF is constant-folded, so this
# test fails.
if not config.jax2tf_default_native_serialization.value:
self.assertIn("sine(", self.TfToHlo(f_tf))
def test_convert_of_nested_independent_jit(self):
def func(x):
def inner1(y):
@ -1132,31 +1121,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
np.full_like(x[1], fill_value=2.)),
(grad_tf[0].numpy(), grad_tf[1].numpy()))
@jtu.skip_on_flag("jax2tf_default_native_serialization", True)
def test_enable_xla(self):
# Tests that enable_xla flag is properly scoped to a conversion.
def fun(x):
# lax.reduce is unlikely to ever be convertible with enable_xla=False
return lax.reduce(x, np.float32(0), lambda v, acc: v + acc, dimensions=(0, 1))
tf_fun_with_xla = jax2tf.convert(fun, enable_xla=True)
tf_fun_without_xla = jax2tf.convert(fun, enable_xla=False)
x = np.ones((2, 3), dtype=np.float32)
self.assertAllClose(fun(x), tf_fun_with_xla(x))
with self.assertRaisesRegex(NotImplementedError,
"Call to reduce cannot be converted with enable_xla=False"):
tf_fun_without_xla(x)
# Now in reverse order (we had bugs with the management of enable_xla global)
tf_fun2_without_xla = jax2tf.convert(lambda x: fun(x), enable_xla=False)
tf_fun2_with_xla = jax2tf.convert(lambda x: fun(x), enable_xla=True)
with self.assertRaisesRegex(NotImplementedError,
"Call to reduce cannot be converted with enable_xla=False"):
tf_fun2_without_xla(x)
self.assertAllClose(fun(x), tf_fun2_with_xla(x))
def test_device_array_arg(self):
self.ConvertAndCompare(jnp.sin, jnp.zeros((2, 3), jnp.float32))
@ -1717,35 +1681,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
res,
x + _testing_multi_platform_to_add[tf_device_jax_platform])
def test_cond_primitive(self):
def f_cond(x):
return lax.cond(x < 1.0, jnp.cos, jnp.sin, x)
self.ConvertAndCompare(f_cond, np.pi / 4, enable_xla=False)
self.ConvertAndCompare(f_cond, np.pi / 2, enable_xla=False)
f_cond_tf = jax2tf.convert(f_cond, enable_xla=False)
self.assertNotIn("switch_case", self.TfToHlo(f_cond_tf, np.pi))
def f_switch(x):
return lax.switch(jnp.int32(x), [jnp.cos, jnp.sin, lambda _: 42.0], x)
self.ConvertAndCompare(f_switch, np.pi / 4, enable_xla=False)
self.ConvertAndCompare(f_switch, np.pi / 2, enable_xla=False)
self.ConvertAndCompare(f_switch, 2 * np.pi, enable_xla=False)
f_switch_tf = jax2tf.convert(f_switch, enable_xla=False)
self.assertIn("switch_case", self.TfToHlo(f_switch_tf, np.pi))
@jtu.skip_on_flag("jax2tf_default_native_serialization", False)
def test_ragged_dot(self):
dtype = np.float32
m, k, n, num_groups = 5, 4, 3, 2
lhs = np.arange(m * k, dtype=dtype).reshape((m, k))
rhs = np.arange(num_groups * k * n, dtype=dtype).reshape((num_groups, k, n))
group_sizes = np.array([3, 2], dtype=np.int32)
self.ConvertAndCompare(jax.lax.ragged_dot, lhs, rhs, group_sizes)
@jtu.with_config(jax_enable_custom_prng=True)
class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):

@ -113,9 +113,6 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
dtype=harness.dtype), limitations))
func_jax = harness.dyn_fun
args = harness.dyn_args_maker(self.rng())
enable_xla = harness.params.get("enable_xla", True)
if config.jax2tf_default_native_serialization.value and not enable_xla:
raise unittest.SkipTest("native_serialization not supported with enable_xla=False")
if ("eigh" == harness.group_name and
np.complex64 == harness.dtype and
@ -142,8 +139,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
associative_scan_reductions = harness.params.get("associative_scan_reductions", False)
try:
with jax.jax2tf_associative_scan_reductions(associative_scan_reductions):
self.ConvertAndCompare(func_jax, *args, limitations=limitations,
enable_xla=enable_xla)
self.ConvertAndCompare(func_jax, *args, limitations=limitations)
except Exception as e:
# TODO(b/264596006): custom calls are not registered properly with TF in OSS
if (config.jax2tf_default_native_serialization.value and

@ -85,7 +85,6 @@ class PolyHarness(Harness):
polymorphic_constraints: Sequence[str] = (),
input_signature: Sequence[tf.TensorSpec] | None = None,
expected_output_signature: tf.TensorSpec | None = None,
enable_xla: bool = True,
expect_error: tuple[Any | None, str | None] = (None, None),
skip_jax_run: bool = False,
check_result: bool = True,
@ -105,7 +104,6 @@ class PolyHarness(Harness):
input_signature: For `tf.function.get_concrete_function`. If missing,
generated from `polymorphic_shapes`.
expected_output_signature: the expected inferred output shape.
enable_xla: For `jax2tf.convert`.
expect_error: a pair of an Exception type and a regular expression to
match the expected exception string.
skip_jax_run: If True, then neither the JAX nor the TF functions are
@ -126,29 +124,11 @@ class PolyHarness(Harness):
self.expected_output_signature = expected_output_signature
self.skip_jax_run = skip_jax_run
self.expect_error = expect_error
self.enable_xla = enable_xla
self.tol = tol
self.check_result = check_result
self.limitations = limitations
self.override_jax_config_flags = override_jax_config_flags
# Replicate the harness for both enable and disable xla
def both_enable_and_disable_xla(self) -> tuple[PolyHarness, PolyHarness]:
assert self.enable_xla
other = PolyHarness(self.group_name,
f"{self.name}_enable_xla_False",
self.fun,
arg_descriptors=self.arg_descriptors,
polymorphic_shapes=self.polymorphic_shapes,
polymorphic_constraints=self.polymorphic_constraints,
input_signature=self.input_signature,
expected_output_signature=self.expected_output_signature,
expect_error=self.expect_error,
tol=self.tol,
enable_xla=False)
self.name = f"{self.name}_enable_xla_True"
return (self, other)
def run_test(self, tst: tf_test_util.JaxToTfTestCase) -> jax.Array | None:
def log_message(extra: str):
return f"[{tst._testMethodName}]: {extra}"
@ -189,8 +169,7 @@ class PolyHarness(Harness):
stack.enter_context(tst.assertRaisesRegex(expect_error_type, expect_error_regex))
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes,
polymorphic_constraints=self.polymorphic_constraints,
enable_xla=self.enable_xla)
polymorphic_constraints=self.polymorphic_constraints)
# Run in tf.Eager mode first, because it is friendlier to debuggers
res_tf = f_tf(*args) if not self.skip_jax_run else None
f_tf_func = tf.function(
@ -1125,41 +1104,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
x0 = np.array([], np.float32)
self.assertEqual(jnp.array([0.], dtype=np.float32), f1_jax(x0))
# In graph serialization eager mode we catch the error
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,
re.escape("Expected value >= 1 for dimension variable 'b'. "
"Using the following polymorphic shapes specifications: args[0].shape = (b,). "
"Obtained dimension variables: 'b' = 0")):
jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
native_serialization=False)(x0)
# In graph serialization graph mode we also catch it (except on TPU)
f1_tf = tf.function(
jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
native_serialization=False),
autograph=False,
).get_concrete_function(tf.TensorSpec([None], dtype=np.float32))
# In graph serialization graph mode we also catch it (except on TPU, where
# the behavior is as for jit_compile=1)
if jtu.test_device_matches(["tpu"]):
self.assertEqual(jnp.array([1.], dtype=np.float32), f1_tf(x0))
else:
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,
re.escape("Expected value >= 1 for dimension variable")):
_ = f1_tf(x0)
# In graph serialization with jit_compile=True we do not catch the error
# and we return the wrong result
f1_tf = tf.function(
jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
native_serialization=False),
autograph=False,
jit_compile=True
)
self.assertEqual(jnp.array([1.], dtype=np.float32), f1_tf(x0))
# We also catch the error with native serialization
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,
@ -1167,8 +1111,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
"Expected value >= 1 for dimension variable 'b'. "
"Using the following polymorphic shapes specifications: args[0].shape = (b,). "
"Obtained dimension variables: 'b' = 0")):
_ = jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
native_serialization=True)(x0)
_ = jax2tf.convert(f1_jax, polymorphic_shapes=["b"])(x0)
# Checking that the actual dimensions denoted by the same
# dimension variables have equal sizes.
@ -1205,24 +1148,13 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
r"Found inconsistency"):
_ = f2_tf(x45)
# In graph serialization with jit_compile=True we do not catch the error
# and we return the wrong result
f2_tf = tf.function(
jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
native_serialization=False),
autograph=False,
jit_compile=True
)
self.assertEqual(1. + jnp.sum(x45), f2_tf(x45))
# We also catch the error with native serialization
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,
re.escape(
"Found inconsistency between dimension size args[0].shape[1] (= 5) "
"and the specification 'b' (= 4)")):
_ = jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
native_serialization=True)(x45)
_ = jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"])(x45)
x = np.ones((5,), dtype=np.float32)
with self.assertRaisesRegex(
@ -1572,12 +1504,12 @@ _POLY_SHAPE_TEST_HARNESSES = [
PolyHarness("argmax", "0",
lambda op: lax.argmax(op, axis=0, index_dtype=np.int32),
arg_descriptors=[RandArg((3, 4, 5), _f32)],
polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ..."]),
# Reduce the non-poly dimension
PolyHarness("argmax", "1",
lambda op: lax.argmax(op, axis=1, index_dtype=np.int32),
arg_descriptors=[RandArg((3, 4, 5), _f32)],
polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ..."]),
PolyHarness("jnp.argsort", "",
lambda op: jnp.argsort(op),
arg_descriptors=[RandArg((3, 4, 5), _f32)],
@ -1659,7 +1591,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
rhs_spec=(2, 1, 0),
out_spec=(0, 2, 1))),
arg_descriptors=[RandArg((1, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
polymorphic_shapes=["_, b, _", None]).both_enable_and_disable_xla(),
polymorphic_shapes=["_, b, _", None]),
# The same example from above, but with stride=2.
PolyHarness("conv_general_dilated", "1d_stride=2_even",
lambda lhs, rhs: lax.conv_general_dilated(
@ -1671,7 +1603,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
rhs_spec=(2, 1, 0),
out_spec=(0, 2, 1))),
arg_descriptors=[RandArg((1, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
polymorphic_shapes=["_, b, _", None]).both_enable_and_disable_xla(),
polymorphic_shapes=["_, b, _", None]),
# The same example from above, but with stride=2 and odd input size.
PolyHarness("conv_general_dilated", "1d_stride=2_odd",
lambda lhs, rhs: lax.conv_general_dilated(
@ -1683,7 +1615,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
rhs_spec=(2, 1, 0),
out_spec=(0, 2, 1))),
arg_descriptors=[RandArg((1, 13, 16), _f32), RandArg((4, 16, 16), _f32)],
polymorphic_shapes=["_, b, _", None]).both_enable_and_disable_xla(),
polymorphic_shapes=["_, b, _", None]),
PolyHarness("conv_general_dilated", "1d_stride=2_zero_output",
lambda lhs, rhs: lax.conv_general_dilated(
lhs, rhs,
@ -1697,7 +1629,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
arg_descriptors=[RandArg((1, 4, 16), _f32),
RandArg((8, 16, 16), _f32)],
polymorphic_shapes=["_, b, _",
None]).both_enable_and_disable_xla(),
None]),
# Issue #11402
PolyHarness("conv_general_dilated", "1d_2",
lambda lhs, rhs: lax.conv_transpose(lhs, rhs,
@ -1707,7 +1639,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
transpose_kernel=False),
arg_descriptors=[RandArg((5, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
polymorphic_shapes=["b, _, _", None],
tol=5e-5).both_enable_and_disable_xla(),
tol=5e-5),
# Issue #11402
PolyHarness("conv_general_dilated", "1d_3",
lambda lhs, rhs: lax.conv_transpose(lhs, rhs,
@ -1717,7 +1649,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
transpose_kernel=False),
arg_descriptors=[RandArg((5, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
polymorphic_shapes=["_, b, _", None],
tol=5e-5).both_enable_and_disable_xla(),
tol=5e-5),
PolyHarness("conv_general_dilated", "",
lambda lhs, rhs: lax.conv_general_dilated(
lhs, rhs,
@ -1730,7 +1662,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
batch_group_count=1,
precision=None),
arg_descriptors=[RandArg((7, 3, 9, 10), _f32), RandArg((3, 3, 4, 5), _f32)],
polymorphic_shapes=["b, ...", None]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ...", None]),
[
[
PolyHarness(cum_name, "reduce_axis_poly",
@ -1763,47 +1695,47 @@ _POLY_SHAPE_TEST_HARNESSES = [
# x:shape: (b, 4)
lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)),
arg_descriptors=[RandArg((3, 4), _f32)],
polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ..."]),
PolyHarness("dynamic_slice", "idx=tuple_arg",
# x:shape: (b, 4)
lambda x, i0: lax.dynamic_slice(x, (i0, np.int32(1)), (x.shape[0], 2)),
arg_descriptors=[RandArg((3, 4), _f32), np.array(-2, dtype=np.int32)],
polymorphic_shapes=["b, ...", None]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ...", None]),
PolyHarness("dynamic_slice", "idx=array",
# x:shape: (b, 4)
lambda x, idx: lax.dynamic_slice(x, idx, (x.shape[0], 2)),
arg_descriptors=[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)],
polymorphic_shapes=["b, ...", None]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ...", None]),
PolyHarness("dynamic_slice", "idx=tuple_int_start_oob_large",
# x:shape: (b, 4)
lambda x: lax.dynamic_slice(x, (1, 1), (x.shape[0], 2)),
arg_descriptors=[RandArg((3, 4), _f32)],
polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ..."]),
PolyHarness("dynamic_slice", "idx=tuple_int_start_oob_small",
# x:shape: (b, 4)
lambda x: lax.dynamic_slice(x, (-1, 1), (x.shape[0] - 1, 2)),
arg_descriptors=[RandArg((3, 4), _f32)],
polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ..."]),
PolyHarness("dynamic_slice_in_dim", "idx=0",
# x:shape: (b, 4)
lambda x: lax.dynamic_slice_in_dim(x, 0, x.shape[0], axis=0),
arg_descriptors=[RandArg((3, 4), _f32)],
polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ..."]),
PolyHarness("dynamic_update_slice", "idx=tuple_int",
# x:shape: (b, 4)
lambda x: lax.dynamic_update_slice(x, x, (0, 0)),
arg_descriptors=[RandArg((3, 4), _f32)],
polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ..."]),
PolyHarness("dynamic_update_slice", "idx=tuple_arg",
# x:shape: (b, 4)
lambda x, i0: lax.dynamic_update_slice(x, x, (i0, np.int32(0))),
arg_descriptors=[RandArg((3, 4), _f32), np.array(-2, dtype=np.int32)],
polymorphic_shapes=["b, ...", None]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ...", None]),
PolyHarness("dynamic_update_slice", "idx=array",
# x:shape: (b, 4)
lambda x, idx: lax.dynamic_update_slice(x, x, idx),
arg_descriptors=[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)],
polymorphic_shapes=["b, _", None]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, _", None]),
[
PolyHarness("eig", f"shape={jtu.format_shape_dtype_string((3, 5, 5), dtype)}_poly={poly}_{left=}_{right=}",
lambda x, left, right: lax.linalg.eig(x, compute_left_eigenvectors=left, compute_right_eigenvectors=right),
@ -1931,44 +1863,44 @@ _POLY_SHAPE_TEST_HARNESSES = [
PolyHarness("getitem", "op=static_idx=poly",
lambda a, i: a[i],
arg_descriptors=[RandArg((3, 4), _f32), np.array([2, 2], np.int32)],
polymorphic_shapes=[None, "b0, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=[None, "b0, ..."]),
# operand is poly, index is integer
PolyHarness("getitem", "op=poly_idx=const",
lambda a: a[1],
arg_descriptors=[RandArg((3, 4), _f32)],
polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ..."]),
# operand is poly, index is dim poly
PolyHarness("getitem", "op=poly_idx=dim",
lambda a: a[jnp.array(a.shape[0] - 2)],
arg_descriptors=[RandArg((3, 4), _f32)],
polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ..."]),
# Both the operand and the index are poly
PolyHarness("getitem", "op=poly_idx=poly",
lambda a, i: a[i],
arg_descriptors=[RandArg((3, 4), _f32), np.array([1, 2, 0], np.int32)],
polymorphic_shapes=["b, ...", "b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ...", "b, ..."]),
# op is poly and index is an entire slice
PolyHarness("getitem", "op=poly_idx=slice-all",
lambda a: a[:],
arg_descriptors=[RandArg((3, 4), _f32)],
polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ..."]),
# op is poly and index is a partial slice
PolyHarness("getitem", "op=poly_idx=slice-ct-1",
lambda a: a[:2],
arg_descriptors=[RandArg((3, 4), _f32)],
polymorphic_shapes=["b + 2, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b + 2, ..."]),
PolyHarness("getitem", "op=poly_idx=slice-ct-2",
lambda a: a[:, :2],
arg_descriptors=[RandArg((3, 4), _f32)],
polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ..."]),
PolyHarness("getitem", "op=poly_idx=slice-None-1",
lambda a: a[:a.shape[0]],
arg_descriptors=[RandArg((3, 4), _f32)],
polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ..."]),
PolyHarness("getitem", "op=poly_idx=slice-poly",
lambda a: a[:a.shape[0] - 1],
arg_descriptors=[RandArg((3, 4), _f32)],
polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ..."]),
PolyHarness("image_resize", "linear_0",
lambda x: jax.image.resize(x, (x.shape[0], 2 * x.shape[1], 2 * x.shape[2], x.shape[3]),
method="linear"),
@ -2201,13 +2133,13 @@ _POLY_SHAPE_TEST_HARNESSES = [
lambda x: lax.reduce_window(x, np.array(1., _f32), lax.min,
(2, 2), (1, 1), "VALID"),
arg_descriptors=[RandArg((3, 8), _f32)],
polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ..."]),
PolyHarness("reduce_window", "min_window_size=dynamic",
# x: f32[b, 8]
lambda x: lax.reduce_window(x, np.array(1., _f32), lax.min,
(2, x.shape[0]), (1, 1), "VALID"),
arg_descriptors=[RandArg((3, 8), _f32)],
polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ..."]),
PolyHarness("reduce_window", "min_plus_max_window_size=static",
# x: f32[b, 8]
lambda x: (
@ -2217,7 +2149,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
lax.reduce_window(x, np.array(1., _f32), lax.max,
(2, 2), (1, 1), "VALID")),
arg_descriptors=[RandArg((3, 8), _f32)],
polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ..."]),
PolyHarness("reduce_window", "min_plus_max_window_size=dynamic",
# x: f32[b, 8]
lambda x: (
@ -2227,19 +2159,19 @@ _POLY_SHAPE_TEST_HARNESSES = [
lax.reduce_window(x, np.array(1., _f32), lax.max,
(2, x.shape[0]), (1, 1), "VALID")),
arg_descriptors=[RandArg((3, 8), _f32)],
polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ..."]),
PolyHarness("reduce_window", "add_monoid_base_window_size=static",
# x: f32[b, 8]
lambda x: lax.reduce_window(x, np.array(0., _f32), lax.add,
(2, 2), (1, 1), "VALID"),
arg_descriptors=[RandArg((3, 8), _f32)],
polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ..."]),
PolyHarness("reduce_window", "add_monoid_base_window_size=dynamic",
# x: f32[b, 8]
lambda x: lax.reduce_window(x, np.array(0., _f32), lax.add,
(2, x.shape[0]), (1, 1), "VALID"),
arg_descriptors=[RandArg((3, 8), _f32)],
polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ..."]),
# https://github.com/google/jax/issues/11804
# Use the reshape trick to simulate a polymorphic dimension of 16*b.
# (See test "conv_general_dilated.1d_1" above for more details.)
@ -2249,7 +2181,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
jnp.reshape(x, (1, -1, 1)),
np.array(0., _f32), lax.add, (1, 4, 1), (1, 2, 1), "SAME"),
arg_descriptors=[RandArg((1, 128, 16), _f32)],
polymorphic_shapes=["_, b1, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["_, b1, ..."]),
PolyHarness("reduce_window", "add_generic_window_size=static",
# x: f32[1, 16*b, 1]
# Use an initial value of 1. to trigger the generic reduction path
@ -2257,7 +2189,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
jnp.reshape(x, (1, -1, 1)),
np.array(1., _f32), lax.add, (1, 4, 1), (1, 2, 1), "SAME"),
arg_descriptors=[RandArg((1, 128, 16), _f32)],
polymorphic_shapes=["_, b1, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["_, b1, ..."]),
PolyHarness("reduce_window", "variadic_generic_window_size=static",
# x: f32[b, 8] y: f32[b, 8]
lambda x, y: lax.reduce_window(
@ -2266,7 +2198,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
lax.sub(xy0[1], xy1[1])),
(2, 2), (1, 1), "VALID"),
arg_descriptors=[RandArg((3, 8), _f32), RandArg((3, 8), _i32)],
polymorphic_shapes=["b, ...", "b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ...", "b, ..."]),
PolyHarness("reduce_window", "variadic_generic_window_size=dynamic",
# x: f32[b, 8] y: f32[b, 8]
lambda x, y: lax.reduce_window(
@ -2275,7 +2207,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
lax.sub(xy0[1], xy1[1])),
(2, x.shape[0]), (1, 1), "VALID"),
arg_descriptors=[RandArg((3, 8), _f32), RandArg((3, 8), _i32)],
polymorphic_shapes=["b, ...", "b, ..."]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ...", "b, ..."]),
# TODO(necula): not yet supported, but also unlikely to come up.
# PolyHarness("random_uniform", "odd",
# lambda key, a: jax.random.uniform(key, (2 * a.shape[0] + 1, a.shape[1]),
@ -2493,7 +2425,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
PolyHarness("take", "",
lambda a, i: jnp.take(a, i, axis=1),
arg_descriptors=[RandArg((3, 4, 5), _f32), np.array([1, 2], np.int32)],
polymorphic_shapes=["b, ...", None]).both_enable_and_disable_xla(),
polymorphic_shapes=["b, ...", None]),
PolyHarness("take_along_axis", "0",
lambda x, y: jnp.take_along_axis(x, y, axis=0),
arg_descriptors=[RandArg((5, 2), _f32), RandArg((5, 1), np.int32)],
@ -2706,106 +2638,51 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
# Exclude some harnesses that are known to fail for native serialization
# FOR NATIVE SERIALIZATION
if config.jax2tf_default_native_serialization.value:
if not harness.enable_xla:
raise unittest.SkipTest("disabled for native_serialization and enable_xla=False")
# Set of harness.group_name:platform that are implemented with custom call
custom_call_harnesses = {
"householder_product:gpu",
"vmap_geqrf:gpu", # used for linalg.qr
"vmap_lu:gpu",
# custom_linear_solve works as long as lu works.
"vmap_custom_linear_solve:gpu",
"vmap_qr:gpu", "qr:gpu",
"vmap_svd:gpu",
}
if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses:
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778")
# Set of harness.group_name:platform that are implemented with custom call
custom_call_harnesses = {
"householder_product:gpu",
"vmap_geqrf:gpu", # used for linalg.qr
"vmap_lu:gpu",
# custom_linear_solve works as long as lu works.
"vmap_custom_linear_solve:gpu",
"vmap_qr:gpu", "qr:gpu",
"vmap_svd:gpu",
}
if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses:
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778")
if harness.group_name == "schur" and not jtu.test_device_matches(["cpu"]):
raise unittest.SkipTest("schur decomposition is only implemented on CPU.")
if harness.group_name == "schur" and not jtu.test_device_matches(["cpu"]):
raise unittest.SkipTest("schur decomposition is only implemented on CPU.")
if "fft_fft_type" in harness.fullname:
if "nr_fft_lengths_2" in harness.fullname:
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for fft with non-constant fft_lengths on GPU and TPU")
if "fft_fft_type" in harness.fullname:
if "nr_fft_lengths_2" in harness.fullname:
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for fft with non-constant fft_lengths on GPU and TPU")
if harness.group_name == "vmap_eigh" and jtu.test_device_matches(["gpu"]):
# For eigh on GPU with shape polymorphism under native serialization,
# we use a different lowering for small matrices. See README.md.
shape = harness.original_harness.params["shape"]
if 0 < shape[-1] <= 32:
harness.check_result = False
if harness.group_name == "vmap_tan":
# Tan (b/274462307) require support for custom call stablehlo.tan.
raise unittest.SkipTest(
"native lowering with shape polymorphism requires additional StableHLO feature support")
if (jtu.test_device_matches(["cpu", "gpu"]) and
harness.fullname in [
"cumsum_reduce_axis_poly", "cumprod_reduce_axis_poly",
"cummin_reduce_axis_poly", "cummax_reduce_axis_poly",
"cumlogsumexp_reduce_axis_poly",
"jnp_insert_insert_constant", "jnp_insert_insert_poly",
"jnp_nonzero_size_constant", "jnp_nonzero_size_poly"]):
# Need associative scan reductions on CPU and GPU. On TPU we use the
# reduce_window HLO, but on CPU and GPU (with axis size >= 32) we use
# a recursive associative scan that we cannot express with shape
# polymorphism.
raise unittest.SkipTest(
"native serialization with shape polymorphism not implemented for window_reductions on CPU and GPU")
# FOR GRAPH SERIALIZATION
if not config.jax2tf_default_native_serialization.value:
if ("random_gamma_threefry_non_partitionable" in harness.fullname and
jtu.test_device_matches(["cpu"])):
harness.tol = 1e-6
if harness.group_name == "vmap_cumsum":
# For cumsum we use a different implementation than JAX native
# See README.md for associative scan reductions
harness.tol = 1e-5
if "vmap_" in harness.group_name:
# For non-native serialization, it seems that we cannot just use
# the custom_asserts; we get too many errors.
if [l for l in harness.limitations if l.custom_assert]:
harness.check_result = False
if "vmap_integer_pow" in harness.group_name:
# For non-native serialization the overflow behavior is different.
if harness.group_name == "vmap_eigh" and jtu.test_device_matches(["gpu"]):
# For eigh on GPU with shape polymorphism under native serialization,
# we use a different lowering for small matrices. See README.md.
shape = harness.original_harness.params["shape"]
if 0 < shape[-1] <= 32:
harness.check_result = False
if "average_axis_None_weights_Some" in harness.fullname:
harness.tol = 1e-5
if harness.group_name == "vmap_tan":
# Tan (b/274462307) require support for custom call stablehlo.tan.
raise unittest.SkipTest(
"native lowering with shape polymorphism requires additional StableHLO feature support")
if harness.group_name == "schur":
raise unittest.SkipTest("jax2tf graph serialization does not support schur.")
if harness.group_name == "eig" and "left_True_right_True" in harness.fullname:
raise unittest.SkipTest("jax2tf graph serialization does not support both left and right.")
if harness.group_name == "vmap_eigh":
self.skipTest("b/312378994: error in TF eager execution for tf.linalg.eigh")
if "conv_general_dilated_1d_stride_2_zero_output_enable_xla_False" in harness.fullname:
raise unittest.SkipTest("incomplete support for conv_general_dilated in enable_xla=False")
if harness.group_name == "reduce_window" and "variadic" in harness.fullname:
raise unittest.SkipTest("jax2tf graph serialization does not support variadic reduce_window.")
if (harness.group_name == "reduce_window" and
not harness.enable_xla and
"window_size_dynamic" in harness.fullname and
any(n in harness.fullname
for n in ["min_plus_max", "add_monoid_base", "min_window"])):
raise unittest.SkipTest(
"jax2tf graph serialization with enable_xla=False does not support "
"dynamic tf.nn.pool")
if "reduce_window_add_generic" in harness.fullname and not harness.enable_xla:
raise unittest.SkipTest("TODO(b/287733072): wrong result for enable_xla_False")
if (jtu.test_device_matches(["cpu", "gpu"]) and
harness.fullname in [
"cumsum_reduce_axis_poly", "cumprod_reduce_axis_poly",
"cummin_reduce_axis_poly", "cummax_reduce_axis_poly",
"cumlogsumexp_reduce_axis_poly",
"jnp_insert_insert_constant", "jnp_insert_insert_poly",
"jnp_nonzero_size_constant", "jnp_nonzero_size_poly"]):
# Need associative scan reductions on CPU and GPU. On TPU we use the
# reduce_window HLO, but on CPU and GPU (with axis size >= 32) we use
# a recursive associative scan that we cannot express with shape
# polymorphism.
raise unittest.SkipTest(
"native serialization with shape polymorphism not implemented for window_reductions on CPU and GPU")
# FOR BOTH NATIVE AND GRAPH SERIALIZATION
if harness.group_name == "vmap_conv_general_dilated":