Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 13:26:06 +00:00)

[jax2tf] Deprecate jax2tf with native_serialization=False or enable_xla=False.

Also disable many of the non-native-serialization jax2tf tests. In particular, I am disabling the thousands of primitives tests in graph serialization mode. I kept jax2tf_test running in both native and graph serialization mode.

PiperOrigin-RevId: 652749891

This commit is contained in:
parent 28ffa25496
commit d34a6e9ce2
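
For context, a minimal sketch of what the deprecation looks like from user code (assuming TensorFlow and `jax.experimental.jax2tf` are available; `f` below is a stand-in JAX function, not one from this commit):

```python
import warnings

import jax.numpy as jnp
from jax.experimental import jax2tf

def f(x):
  return jnp.sin(x)

# Native serialization (the default since JAX 0.4.16): no warning expected.
f_tf = jax2tf.convert(f)

# Graph serialization is what this change deprecates: convert() now warns.
with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")
  _ = jax2tf.convert(f, native_serialization=False)
print([str(w.message) for w in caught])
# Expected to include: "jax2tf.convert with native_serialization=False is deprecated."
```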
@@ -32,6 +32,10 @@ Remember to align the itemized text with the first line of an item within a list
 * HLO lowering rules should no longer wrap singleton ir.Values in tuples.
   Instead, return singleton ir.Values unwrapped. Support for wrapped values
   will be removed in a future version of JAX.
+* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
+  or `enable_xla=False` is now deprecated and this support will be removed in
+  a future version.
+  Native serialization has been the default since JAX 0.4.16 (September 2023).

 ## jaxlib 0.4.31

@@ -916,7 +916,8 @@ jax2tf_default_native_serialization = bool_state(
     help=(
         'Sets the default value of the native_serialization parameter to '
         'jax2tf.convert. Prefer using the parameter instead of the flag, '
-        'the flag may be removed in the future.'
+        'the flag may be removed in the future. '
+        'Starting with JAX 0.4.31 non-native serialization is deprecated.'
     )
 )

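A hedged illustration of how this configuration is consumed (the attribute read mirrors `config.jax2tf_default_native_serialization.value` as used elsewhere in this diff; updating it through `jax.config.update` is an assumption about the standard `bool_state` mechanism, not something shown in this commit):

```python
import jax

# Equivalent in spirit to passing --jax2tf_default_native_serialization=false.
jax.config.update("jax2tf_default_native_serialization", False)

# The jax2tf code reads the default like this (now a deprecated configuration):
print(jax.config.jax2tf_default_native_serialization)  # -> False
```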
@@ -59,7 +59,8 @@ For backwards compatibility purposes, and for special uses,
 the JAX-TensorFlow interoperation APIs can be used also
 in a **graph serialization** mode (the only mode available before version 0.4.7,
 and the default mode before JAX version 0.4.15),
-without going through StableHLO.
+without going through StableHLO. (Starting with JAX version 0.4.31 the
+graph serialization mode is deprecated. It will be removed in the near future).

 * For calling JAX functions from TensorFlow,
   it is possible to request that the JAX function be lowered with one TensorFlow
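As a hedged companion to the documentation change above, a minimal sketch of the recommended (native serialization) path for calling a JAX function from TensorFlow; `f_jax` is an illustrative function, not taken from the docs:

```python
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def f_jax(x):
  return jnp.sin(jnp.cos(x))

# With native serialization (the default), the JAX function is serialized to
# StableHLO and embedded in the TF graph, rather than re-expressed as TF ops.
f_tf = tf.function(jax2tf.convert(f_jax), autograph=False)
print(f_tf(tf.constant(1.0)))
```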
@@ -304,6 +304,7 @@ def convert(fun_jax: Callable,
       so the lowering tries harder to use non-XLA TF ops to lower the
       function and aborts if this is not possible. Cannot be set to `False`
       when using `native_serialization`.
+      Starting with JAX 0.4.31 support for `enable_xla=False` is deprecated.
     native_serialization: serialize the JAX function natively to
       StableHLO with compatibility guarantees. This makes it easier to have
       confidence that the code executed when calling this function from
@@ -312,6 +313,7 @@ def convert(fun_jax: Callable,
       is set to `False` or to the configuration flag
       `--jax2tf_default_native_serialization` otherwise.
       Native serialization cannot be used with `enable_xla=False`.
+      Starting with JAX 0.4.31 support for non-native serialization is deprecated.
     native_serialization_platforms: In conjunction with
       `native_serialization`, specify the platform(s)
       for which to lower the code. Must be a tuple of
@@ -327,12 +329,17 @@ def convert(fun_jax: Callable,
     tuple/lists/dicts thereof), and returns TfVals as outputs, and uses
     only TensorFlow ops and thus can be called from a TensorFlow program.
   """
+  if not enable_xla:
+    warnings.warn("jax2tf.convert with enable_xla=False is deprecated.")
   if native_serialization is DEFAULT_NATIVE_SERIALIZATION:
     if not enable_xla:
       native_serialization = False
     else:
       native_serialization = config.jax2tf_default_native_serialization.value

+  if not native_serialization:
+    warnings.warn(
+        "jax2tf.convert with native_serialization=False is deprecated.")
   if native_serialization and not enable_xla:
     raise ValueError(
         "native_serialization is not supported with enable_xla=False")
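A short hedged sketch of the behavior encoded in the hunk above: the deprecation is a warning, while combining `native_serialization=True` with `enable_xla=False` remains a hard error (message taken from the `raise` shown above):

```python
import jax.numpy as jnp
from jax.experimental import jax2tf

try:
  jax2tf.convert(jnp.sin, native_serialization=True, enable_xla=False)
except ValueError as e:
  print(e)  # "native_serialization is not supported with enable_xla=False"
```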
@@ -860,17 +860,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
       return ad_checkpoint.checkpoint_name(jnp.sin(x), "sin")
     jax2tf.convert(f_jax)(1.) # No error.

-  def test_convert_nullary_func(self):
-    # Even nullary functions are converted to TF (as opposed to constant-folded
-    # in JAX prior to conversion).
-    def f_jax():
-      return jnp.sin(1.)
-    f_tf = jax2tf.convert(f_jax)
-    # for native serialization the HLO we get from TF is constant-folded, so this
-    # test fails.
-    if not config.jax2tf_default_native_serialization.value:
-      self.assertIn("sine(", self.TfToHlo(f_tf))
-
   def test_convert_of_nested_independent_jit(self):
     def func(x):
       def inner1(y):
@@ -1132,31 +1121,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
                          np.full_like(x[1], fill_value=2.)),
                         (grad_tf[0].numpy(), grad_tf[1].numpy()))

-  @jtu.skip_on_flag("jax2tf_default_native_serialization", True)
-  def test_enable_xla(self):
-    # Tests that enable_xla flag is properly scoped to a conversion.
-    def fun(x):
-      # lax.reduce is unlikely to ever be convertible with enable_xla=False
-      return lax.reduce(x, np.float32(0), lambda v, acc: v + acc, dimensions=(0, 1))
-
-    tf_fun_with_xla = jax2tf.convert(fun, enable_xla=True)
-    tf_fun_without_xla = jax2tf.convert(fun, enable_xla=False)
-    x = np.ones((2, 3), dtype=np.float32)
-
-    self.assertAllClose(fun(x), tf_fun_with_xla(x))
-    with self.assertRaisesRegex(NotImplementedError,
-                                "Call to reduce cannot be converted with enable_xla=False"):
-      tf_fun_without_xla(x)
-
-    # Now in reverse order (we had bugs with the management of enable_xla global)
-    tf_fun2_without_xla = jax2tf.convert(lambda x: fun(x), enable_xla=False)
-    tf_fun2_with_xla = jax2tf.convert(lambda x: fun(x), enable_xla=True)
-
-    with self.assertRaisesRegex(NotImplementedError,
-                                "Call to reduce cannot be converted with enable_xla=False"):
-      tf_fun2_without_xla(x)
-    self.assertAllClose(fun(x), tf_fun2_with_xla(x))
-
   def test_device_array_arg(self):
     self.ConvertAndCompare(jnp.sin, jnp.zeros((2, 3), jnp.float32))

@@ -1717,35 +1681,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
         res,
         x + _testing_multi_platform_to_add[tf_device_jax_platform])

-  def test_cond_primitive(self):
-    def f_cond(x):
-      return lax.cond(x < 1.0, jnp.cos, jnp.sin, x)
-
-    self.ConvertAndCompare(f_cond, np.pi / 4, enable_xla=False)
-    self.ConvertAndCompare(f_cond, np.pi / 2, enable_xla=False)
-
-    f_cond_tf = jax2tf.convert(f_cond, enable_xla=False)
-    self.assertNotIn("switch_case", self.TfToHlo(f_cond_tf, np.pi))
-
-    def f_switch(x):
-      return lax.switch(jnp.int32(x), [jnp.cos, jnp.sin, lambda _: 42.0], x)
-
-    self.ConvertAndCompare(f_switch, np.pi / 4, enable_xla=False)
-    self.ConvertAndCompare(f_switch, np.pi / 2, enable_xla=False)
-    self.ConvertAndCompare(f_switch, 2 * np.pi, enable_xla=False)
-
-    f_switch_tf = jax2tf.convert(f_switch, enable_xla=False)
-    self.assertIn("switch_case", self.TfToHlo(f_switch_tf, np.pi))
-
-  @jtu.skip_on_flag("jax2tf_default_native_serialization", False)
-  def test_ragged_dot(self):
-    dtype = np.float32
-    m, k, n, num_groups = 5, 4, 3, 2
-    lhs = np.arange(m * k, dtype=dtype).reshape((m, k))
-    rhs = np.arange(num_groups * k * n, dtype=dtype).reshape((num_groups, k, n))
-    group_sizes = np.array([3, 2], dtype=np.int32)
-    self.ConvertAndCompare(jax.lax.ragged_dot, lhs, rhs, group_sizes)
-

 @jtu.with_config(jax_enable_custom_prng=True)
 class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
@@ -113,9 +113,6 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
                    dtype=harness.dtype), limitations))
     func_jax = harness.dyn_fun
     args = harness.dyn_args_maker(self.rng())
-    enable_xla = harness.params.get("enable_xla", True)
-    if config.jax2tf_default_native_serialization.value and not enable_xla:
-      raise unittest.SkipTest("native_serialization not supported with enable_xla=False")

     if ("eigh" == harness.group_name and
         np.complex64 == harness.dtype and
@@ -142,8 +139,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
     associative_scan_reductions = harness.params.get("associative_scan_reductions", False)
     try:
       with jax.jax2tf_associative_scan_reductions(associative_scan_reductions):
-        self.ConvertAndCompare(func_jax, *args, limitations=limitations,
-                               enable_xla=enable_xla)
+        self.ConvertAndCompare(func_jax, *args, limitations=limitations)
     except Exception as e:
       # TODO(b/264596006): custom calls are not registered properly with TF in OSS
       if (config.jax2tf_default_native_serialization.value and
@@ -85,7 +85,6 @@ class PolyHarness(Harness):
                polymorphic_constraints: Sequence[str] = (),
                input_signature: Sequence[tf.TensorSpec] | None = None,
                expected_output_signature: tf.TensorSpec | None = None,
-               enable_xla: bool = True,
                expect_error: tuple[Any | None, str | None] = (None, None),
                skip_jax_run: bool = False,
                check_result: bool = True,
@@ -105,7 +104,6 @@ class PolyHarness(Harness):
       input_signature: For `tf.function.get_concrete_function`. If missing,
         generated from `polymorphic_shapes`.
       expected_output_signature: the expected inferred output shape.
-      enable_xla: For `jax2tf.convert`.
      expect_error: a pair of an Exception type and a regular expression to
        match the expected exception string.
      skip_jax_run: If True, then neither the JAX nor the TF functions are
@@ -126,29 +124,11 @@ class PolyHarness(Harness):
     self.expected_output_signature = expected_output_signature
     self.skip_jax_run = skip_jax_run
     self.expect_error = expect_error
-    self.enable_xla = enable_xla
     self.tol = tol
     self.check_result = check_result
     self.limitations = limitations
     self.override_jax_config_flags = override_jax_config_flags
-
-  # Replicate the harness for both enable and disable xla
-  def both_enable_and_disable_xla(self) -> tuple[PolyHarness, PolyHarness]:
-    assert self.enable_xla
-    other = PolyHarness(self.group_name,
-                        f"{self.name}_enable_xla_False",
-                        self.fun,
-                        arg_descriptors=self.arg_descriptors,
-                        polymorphic_shapes=self.polymorphic_shapes,
-                        polymorphic_constraints=self.polymorphic_constraints,
-                        input_signature=self.input_signature,
-                        expected_output_signature=self.expected_output_signature,
-                        expect_error=self.expect_error,
-                        tol=self.tol,
-                        enable_xla=False)
-    self.name = f"{self.name}_enable_xla_True"
-    return (self, other)

   def run_test(self, tst: tf_test_util.JaxToTfTestCase) -> jax.Array | None:
     def log_message(extra: str):
       return f"[{tst._testMethodName}]: {extra}"
@@ -189,8 +169,7 @@ class PolyHarness(Harness):
       stack.enter_context(tst.assertRaisesRegex(expect_error_type, expect_error_regex))

     f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes,
-                          polymorphic_constraints=self.polymorphic_constraints,
-                          enable_xla=self.enable_xla)
+                          polymorphic_constraints=self.polymorphic_constraints)
     # Run in tf.Eager mode first, because it is friendlier to debuggers
     res_tf = f_tf(*args) if not self.skip_jax_run else None
     f_tf_func = tf.function(
@@ -1125,41 +1104,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
     x0 = np.array([], np.float32)
     self.assertEqual(jnp.array([0.], dtype=np.float32), f1_jax(x0))

-    # In graph serialization eager mode we catch the error
-    with self.assertRaisesRegex(
-        tf.errors.InvalidArgumentError,
-        re.escape("Expected value >= 1 for dimension variable 'b'. "
-                  "Using the following polymorphic shapes specifications: args[0].shape = (b,). "
-                  "Obtained dimension variables: 'b' = 0")):
-      jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
-                     native_serialization=False)(x0)
-
-    # In graph serialization graph mode we also catch it (except on TPU)
-    f1_tf = tf.function(
-        jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
-                       native_serialization=False),
-        autograph=False,
-    ).get_concrete_function(tf.TensorSpec([None], dtype=np.float32))
-    # In graph serialization graph mode we also catch it (except on TPU, where
-    # the behavior is as for jit_compile=1)
-    if jtu.test_device_matches(["tpu"]):
-      self.assertEqual(jnp.array([1.], dtype=np.float32), f1_tf(x0))
-    else:
-      with self.assertRaisesRegex(
-          tf.errors.InvalidArgumentError,
-          re.escape("Expected value >= 1 for dimension variable")):
-        _ = f1_tf(x0)
-
-    # In graph serialization with jit_compile=True we do not catch the error
-    # and we return the wrong result
-    f1_tf = tf.function(
-        jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
-                       native_serialization=False),
-        autograph=False,
-        jit_compile=True
-    )
-    self.assertEqual(jnp.array([1.], dtype=np.float32), f1_tf(x0))
-
     # We also catch the error with native serialization
     with self.assertRaisesRegex(
         tf.errors.InvalidArgumentError,
@@ -1167,8 +1111,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
         "Expected value >= 1 for dimension variable 'b'. "
         "Using the following polymorphic shapes specifications: args[0].shape = (b,). "
         "Obtained dimension variables: 'b' = 0")):
-      _ = jax2tf.convert(f1_jax, polymorphic_shapes=["b"],
-                         native_serialization=True)(x0)
+      _ = jax2tf.convert(f1_jax, polymorphic_shapes=["b"])(x0)

     # Checking that the actual dimensions denoted by the same
     # dimension variables have equal sizes.
@@ -1205,24 +1148,13 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
                                 r"Found inconsistency"):
       _ = f2_tf(x45)

-    # In graph serialization with jit_compile=True we do not catch the error
-    # and we return the wrong result
-    f2_tf = tf.function(
-        jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
-                       native_serialization=False),
-        autograph=False,
-        jit_compile=True
-    )
-    self.assertEqual(1. + jnp.sum(x45), f2_tf(x45))
-
     # We also catch the error with native serialization
     with self.assertRaisesRegex(
         tf.errors.InvalidArgumentError,
         re.escape(
             "Found inconsistency between dimension size args[0].shape[1] (= 5) "
             "and the specification 'b' (= 4)")):
-      _ = jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
-                         native_serialization=True)(x45)
+      _ = jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"])(x45)

     x = np.ones((5,), dtype=np.float32)
     with self.assertRaisesRegex(
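The error messages exercised above describe the dimension-variable checks that shape-polymorphic conversion installs. A hedged, self-contained sketch (the test's real `f1_jax` is defined outside this hunk; the function below is a hypothetical stand-in with one polymorphic dimension):

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf

def f_jax(x):  # x: f32[b]
  return jnp.concatenate([x, jnp.array([0.], dtype=np.float32)])

f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["b"])
x0 = np.array([], np.float32)  # b == 0 violates the implicit constraint b >= 1
# Per the test above, calling f_tf(x0) is expected to raise
# tf.errors.InvalidArgumentError mentioning
# "Expected value >= 1 for dimension variable 'b'".
```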
@@ -1572,12 +1504,12 @@ _POLY_SHAPE_TEST_HARNESSES = [
     PolyHarness("argmax", "0",
                 lambda op: lax.argmax(op, axis=0, index_dtype=np.int32),
                 arg_descriptors=[RandArg((3, 4, 5), _f32)],
-                polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ..."]),
     # Reduce the non-poly dimension
     PolyHarness("argmax", "1",
                 lambda op: lax.argmax(op, axis=1, index_dtype=np.int32),
                 arg_descriptors=[RandArg((3, 4, 5), _f32)],
-                polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ..."]),
     PolyHarness("jnp.argsort", "",
                 lambda op: jnp.argsort(op),
                 arg_descriptors=[RandArg((3, 4, 5), _f32)],
@@ -1659,7 +1591,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
                     rhs_spec=(2, 1, 0),
                     out_spec=(0, 2, 1))),
                 arg_descriptors=[RandArg((1, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
-                polymorphic_shapes=["_, b, _", None]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["_, b, _", None]),
     # The same example from above, but with stride=2.
     PolyHarness("conv_general_dilated", "1d_stride=2_even",
                 lambda lhs, rhs: lax.conv_general_dilated(
@@ -1671,7 +1603,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
                     rhs_spec=(2, 1, 0),
                     out_spec=(0, 2, 1))),
                 arg_descriptors=[RandArg((1, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
-                polymorphic_shapes=["_, b, _", None]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["_, b, _", None]),
     # The same example from above, but with stride=2 and odd input size.
     PolyHarness("conv_general_dilated", "1d_stride=2_odd",
                 lambda lhs, rhs: lax.conv_general_dilated(
@@ -1683,7 +1615,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
                     rhs_spec=(2, 1, 0),
                     out_spec=(0, 2, 1))),
                 arg_descriptors=[RandArg((1, 13, 16), _f32), RandArg((4, 16, 16), _f32)],
-                polymorphic_shapes=["_, b, _", None]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["_, b, _", None]),
     PolyHarness("conv_general_dilated", "1d_stride=2_zero_output",
                 lambda lhs, rhs: lax.conv_general_dilated(
                     lhs, rhs,
@@ -1697,7 +1629,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
                 arg_descriptors=[RandArg((1, 4, 16), _f32),
                                  RandArg((8, 16, 16), _f32)],
                 polymorphic_shapes=["_, b, _",
-                                    None]).both_enable_and_disable_xla(),
+                                    None]),
     # Issue #11402
     PolyHarness("conv_general_dilated", "1d_2",
                 lambda lhs, rhs: lax.conv_transpose(lhs, rhs,
@@ -1707,7 +1639,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
                     transpose_kernel=False),
                 arg_descriptors=[RandArg((5, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
                 polymorphic_shapes=["b, _, _", None],
-                tol=5e-5).both_enable_and_disable_xla(),
+                tol=5e-5),
     # Issue #11402
     PolyHarness("conv_general_dilated", "1d_3",
                 lambda lhs, rhs: lax.conv_transpose(lhs, rhs,
@@ -1717,7 +1649,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
                     transpose_kernel=False),
                 arg_descriptors=[RandArg((5, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
                 polymorphic_shapes=["_, b, _", None],
-                tol=5e-5).both_enable_and_disable_xla(),
+                tol=5e-5),
     PolyHarness("conv_general_dilated", "",
                 lambda lhs, rhs: lax.conv_general_dilated(
                     lhs, rhs,
@@ -1730,7 +1662,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
                     batch_group_count=1,
                     precision=None),
                 arg_descriptors=[RandArg((7, 3, 9, 10), _f32), RandArg((3, 3, 4, 5), _f32)],
-                polymorphic_shapes=["b, ...", None]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ...", None]),
     [
       [
         PolyHarness(cum_name, "reduce_axis_poly",
@@ -1763,47 +1695,47 @@ _POLY_SHAPE_TEST_HARNESSES = [
                 # x:shape: (b, 4)
                 lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)),
                 arg_descriptors=[RandArg((3, 4), _f32)],
-                polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ..."]),
     PolyHarness("dynamic_slice", "idx=tuple_arg",
                 # x:shape: (b, 4)
                 lambda x, i0: lax.dynamic_slice(x, (i0, np.int32(1)), (x.shape[0], 2)),
                 arg_descriptors=[RandArg((3, 4), _f32), np.array(-2, dtype=np.int32)],
-                polymorphic_shapes=["b, ...", None]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ...", None]),
     PolyHarness("dynamic_slice", "idx=array",
                 # x:shape: (b, 4)
                 lambda x, idx: lax.dynamic_slice(x, idx, (x.shape[0], 2)),
                 arg_descriptors=[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)],
-                polymorphic_shapes=["b, ...", None]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ...", None]),
     PolyHarness("dynamic_slice", "idx=tuple_int_start_oob_large",
                 # x:shape: (b, 4)
                 lambda x: lax.dynamic_slice(x, (1, 1), (x.shape[0], 2)),
                 arg_descriptors=[RandArg((3, 4), _f32)],
-                polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ..."]),
     PolyHarness("dynamic_slice", "idx=tuple_int_start_oob_small",
                 # x:shape: (b, 4)
                 lambda x: lax.dynamic_slice(x, (-1, 1), (x.shape[0] - 1, 2)),
                 arg_descriptors=[RandArg((3, 4), _f32)],
-                polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ..."]),
     PolyHarness("dynamic_slice_in_dim", "idx=0",
                 # x:shape: (b, 4)
                 lambda x: lax.dynamic_slice_in_dim(x, 0, x.shape[0], axis=0),
                 arg_descriptors=[RandArg((3, 4), _f32)],
-                polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ..."]),
     PolyHarness("dynamic_update_slice", "idx=tuple_int",
                 # x:shape: (b, 4)
                 lambda x: lax.dynamic_update_slice(x, x, (0, 0)),
                 arg_descriptors=[RandArg((3, 4), _f32)],
-                polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ..."]),
     PolyHarness("dynamic_update_slice", "idx=tuple_arg",
                 # x:shape: (b, 4)
                 lambda x, i0: lax.dynamic_update_slice(x, x, (i0, np.int32(0))),
                 arg_descriptors=[RandArg((3, 4), _f32), np.array(-2, dtype=np.int32)],
-                polymorphic_shapes=["b, ...", None]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ...", None]),
     PolyHarness("dynamic_update_slice", "idx=array",
                 # x:shape: (b, 4)
                 lambda x, idx: lax.dynamic_update_slice(x, x, idx),
                 arg_descriptors=[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)],
-                polymorphic_shapes=["b, _", None]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, _", None]),
     [
       PolyHarness("eig", f"shape={jtu.format_shape_dtype_string((3, 5, 5), dtype)}_poly={poly}_{left=}_{right=}",
                   lambda x, left, right: lax.linalg.eig(x, compute_left_eigenvectors=left, compute_right_eigenvectors=right),
@@ -1931,44 +1863,44 @@ _POLY_SHAPE_TEST_HARNESSES = [
     PolyHarness("getitem", "op=static_idx=poly",
                 lambda a, i: a[i],
                 arg_descriptors=[RandArg((3, 4), _f32), np.array([2, 2], np.int32)],
-                polymorphic_shapes=[None, "b0, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=[None, "b0, ..."]),
     # operand is poly, index is integer
     PolyHarness("getitem", "op=poly_idx=const",
                 lambda a: a[1],
                 arg_descriptors=[RandArg((3, 4), _f32)],
-                polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ..."]),
     # operand is poly, index is dim poly
     PolyHarness("getitem", "op=poly_idx=dim",
                 lambda a: a[jnp.array(a.shape[0] - 2)],
                 arg_descriptors=[RandArg((3, 4), _f32)],
-                polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ..."]),
     # Both the operand and the index are poly
     PolyHarness("getitem", "op=poly_idx=poly",
                 lambda a, i: a[i],
                 arg_descriptors=[RandArg((3, 4), _f32), np.array([1, 2, 0], np.int32)],
-                polymorphic_shapes=["b, ...", "b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ...", "b, ..."]),
     # op is poly and index is an entire slice
     PolyHarness("getitem", "op=poly_idx=slice-all",
                 lambda a: a[:],
                 arg_descriptors=[RandArg((3, 4), _f32)],
-                polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ..."]),
     # op is poly and index is a partial slice
     PolyHarness("getitem", "op=poly_idx=slice-ct-1",
                 lambda a: a[:2],
                 arg_descriptors=[RandArg((3, 4), _f32)],
-                polymorphic_shapes=["b + 2, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b + 2, ..."]),
     PolyHarness("getitem", "op=poly_idx=slice-ct-2",
                 lambda a: a[:, :2],
                 arg_descriptors=[RandArg((3, 4), _f32)],
-                polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ..."]),
     PolyHarness("getitem", "op=poly_idx=slice-None-1",
                 lambda a: a[:a.shape[0]],
                 arg_descriptors=[RandArg((3, 4), _f32)],
-                polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ..."]),
     PolyHarness("getitem", "op=poly_idx=slice-poly",
                 lambda a: a[:a.shape[0] - 1],
                 arg_descriptors=[RandArg((3, 4), _f32)],
-                polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ..."]),
     PolyHarness("image_resize", "linear_0",
                 lambda x: jax.image.resize(x, (x.shape[0], 2 * x.shape[1], 2 * x.shape[2], x.shape[3]),
                                            method="linear"),
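The `getitem` harnesses above index arrays whose leading dimension is symbolic. A hedged sketch of the same idea outside the harness machinery, using the `tf.function`/`get_concrete_function` pattern that `PolyHarness.run_test` uses (the function and shapes below are illustrative, not from the test file):

```python
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

def take_first_two(a):  # a: f32[b + 2, 4]
  return a[:2]

f_tf = tf.function(
    jax2tf.convert(take_first_two, polymorphic_shapes=["b + 2, ..."]),
    autograph=False)
# Trace once with an unknown leading dimension, then call with different sizes.
cf = f_tf.get_concrete_function(tf.TensorSpec([None, 4], tf.float32))
print(cf(np.ones((3, 4), np.float32)).shape)  # expected (2, 4)
print(cf(np.ones((5, 4), np.float32)).shape)  # expected (2, 4)
```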
@@ -2201,13 +2133,13 @@ _POLY_SHAPE_TEST_HARNESSES = [
                 lambda x: lax.reduce_window(x, np.array(1., _f32), lax.min,
                                             (2, 2), (1, 1), "VALID"),
                 arg_descriptors=[RandArg((3, 8), _f32)],
-                polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ..."]),
     PolyHarness("reduce_window", "min_window_size=dynamic",
                 # x: f32[b, 8]
                 lambda x: lax.reduce_window(x, np.array(1., _f32), lax.min,
                                             (2, x.shape[0]), (1, 1), "VALID"),
                 arg_descriptors=[RandArg((3, 8), _f32)],
-                polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ..."]),
     PolyHarness("reduce_window", "min_plus_max_window_size=static",
                 # x: f32[b, 8]
                 lambda x: (
@@ -2217,7 +2149,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
                     lax.reduce_window(x, np.array(1., _f32), lax.max,
                                       (2, 2), (1, 1), "VALID")),
                 arg_descriptors=[RandArg((3, 8), _f32)],
-                polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ..."]),
     PolyHarness("reduce_window", "min_plus_max_window_size=dynamic",
                 # x: f32[b, 8]
                 lambda x: (
@@ -2227,19 +2159,19 @@ _POLY_SHAPE_TEST_HARNESSES = [
                     lax.reduce_window(x, np.array(1., _f32), lax.max,
                                       (2, x.shape[0]), (1, 1), "VALID")),
                 arg_descriptors=[RandArg((3, 8), _f32)],
-                polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ..."]),
     PolyHarness("reduce_window", "add_monoid_base_window_size=static",
                 # x: f32[b, 8]
                 lambda x: lax.reduce_window(x, np.array(0., _f32), lax.add,
                                             (2, 2), (1, 1), "VALID"),
                 arg_descriptors=[RandArg((3, 8), _f32)],
-                polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ..."]),
     PolyHarness("reduce_window", "add_monoid_base_window_size=dynamic",
                 # x: f32[b, 8]
                 lambda x: lax.reduce_window(x, np.array(0., _f32), lax.add,
                                             (2, x.shape[0]), (1, 1), "VALID"),
                 arg_descriptors=[RandArg((3, 8), _f32)],
-                polymorphic_shapes=["b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ..."]),
     # https://github.com/google/jax/issues/11804
     # Use the reshape trick to simulate a polymorphic dimension of 16*b.
     # (See test "conv_general_dilated.1d_1" above for more details.)
@@ -2249,7 +2181,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
                     jnp.reshape(x, (1, -1, 1)),
                     np.array(0., _f32), lax.add, (1, 4, 1), (1, 2, 1), "SAME"),
                 arg_descriptors=[RandArg((1, 128, 16), _f32)],
-                polymorphic_shapes=["_, b1, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["_, b1, ..."]),
     PolyHarness("reduce_window", "add_generic_window_size=static",
                 # x: f32[1, 16*b, 1]
                 # Use an initial value of 1. to trigger the generic reduction path
@@ -2257,7 +2189,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
                     jnp.reshape(x, (1, -1, 1)),
                     np.array(1., _f32), lax.add, (1, 4, 1), (1, 2, 1), "SAME"),
                 arg_descriptors=[RandArg((1, 128, 16), _f32)],
-                polymorphic_shapes=["_, b1, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["_, b1, ..."]),
     PolyHarness("reduce_window", "variadic_generic_window_size=static",
                 # x: f32[b, 8] y: f32[b, 8]
                 lambda x, y: lax.reduce_window(
@@ -2266,7 +2198,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
                     lax.sub(xy0[1], xy1[1])),
                     (2, 2), (1, 1), "VALID"),
                 arg_descriptors=[RandArg((3, 8), _f32), RandArg((3, 8), _i32)],
-                polymorphic_shapes=["b, ...", "b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ...", "b, ..."]),
     PolyHarness("reduce_window", "variadic_generic_window_size=dynamic",
                 # x: f32[b, 8] y: f32[b, 8]
                 lambda x, y: lax.reduce_window(
@@ -2275,7 +2207,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
                     lax.sub(xy0[1], xy1[1])),
                     (2, x.shape[0]), (1, 1), "VALID"),
                 arg_descriptors=[RandArg((3, 8), _f32), RandArg((3, 8), _i32)],
-                polymorphic_shapes=["b, ...", "b, ..."]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ...", "b, ..."]),
     # TODO(necula): not yet supported, but also unlikely to come up.
     # PolyHarness("random_uniform", "odd",
     #             lambda key, a: jax.random.uniform(key, (2 * a.shape[0] + 1, a.shape[1]),
@@ -2493,7 +2425,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
     PolyHarness("take", "",
                 lambda a, i: jnp.take(a, i, axis=1),
                 arg_descriptors=[RandArg((3, 4, 5), _f32), np.array([1, 2], np.int32)],
-                polymorphic_shapes=["b, ...", None]).both_enable_and_disable_xla(),
+                polymorphic_shapes=["b, ...", None]),
     PolyHarness("take_along_axis", "0",
                 lambda x, y: jnp.take_along_axis(x, y, axis=0),
                 arg_descriptors=[RandArg((5, 2), _f32), RandArg((5, 1), np.int32)],
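The "window_size=dynamic" harnesses above build reduce windows whose size depends on a symbolic dimension. A hedged stand-alone sketch of that pattern (mirroring the `min_window_size=dynamic` lambda shown above; whether a given backend supports it is exactly what the harness checks):

```python
import numpy as np
from jax import lax
from jax.experimental import jax2tf

_f32 = np.float32

def min_pool_dynamic(x):  # x: f32[b, 8]; the window height below equals b
  return lax.reduce_window(x, np.array(1., _f32), lax.min,
                           (2, x.shape[0]), (1, 1), "VALID")

f_tf = jax2tf.convert(min_pool_dynamic, polymorphic_shapes=["b, ..."])
print(f_tf(np.ones((3, 8), _f32)).shape)  # expected (2, 6) for b == 3
```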
@@ -2706,106 +2638,51 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):

# Exclude some harnesses that are known to fail for native serialization
# FOR NATIVE SERIALIZATION
if config.jax2tf_default_native_serialization.value:
if not harness.enable_xla:
raise unittest.SkipTest("disabled for native_serialization and enable_xla=False")
# Set of harness.group_name:platform that are implemented with custom call
custom_call_harnesses = {
"householder_product:gpu",
"vmap_geqrf:gpu", # used for linalg.qr
"vmap_lu:gpu",
# custom_linear_solve works as long as lu works.
"vmap_custom_linear_solve:gpu",
"vmap_qr:gpu", "qr:gpu",
"vmap_svd:gpu",
}
if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses:
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778")

# Set of harness.group_name:platform that are implemented with custom call
custom_call_harnesses = {
"householder_product:gpu",
"vmap_geqrf:gpu", # used for linalg.qr
"vmap_lu:gpu",
# custom_linear_solve works as long as lu works.
"vmap_custom_linear_solve:gpu",
"vmap_qr:gpu", "qr:gpu",
"vmap_svd:gpu",
}
if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses:
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778")
if harness.group_name == "schur" and not jtu.test_device_matches(["cpu"]):
raise unittest.SkipTest("schur decomposition is only implemented on CPU.")

if harness.group_name == "schur" and not jtu.test_device_matches(["cpu"]):
raise unittest.SkipTest("schur decomposition is only implemented on CPU.")
if "fft_fft_type" in harness.fullname:
if "nr_fft_lengths_2" in harness.fullname:
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for fft with non-constant fft_lengths on GPU and TPU")

if "fft_fft_type" in harness.fullname:
if "nr_fft_lengths_2" in harness.fullname:
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for fft with non-constant fft_lengths on GPU and TPU")

if harness.group_name == "vmap_eigh" and jtu.test_device_matches(["gpu"]):
# For eigh on GPU with shape polymorphism under native serialization,
# we use a different lowering for small matrices. See README.md.
shape = harness.original_harness.params["shape"]
if 0 < shape[-1] <= 32:
harness.check_result = False

if harness.group_name == "vmap_tan":
# Tan (b/274462307) require support for custom call stablehlo.tan.
raise unittest.SkipTest(
"native lowering with shape polymorphism requires additional StableHLO feature support")

if (jtu.test_device_matches(["cpu", "gpu"]) and
harness.fullname in [
"cumsum_reduce_axis_poly", "cumprod_reduce_axis_poly",
"cummin_reduce_axis_poly", "cummax_reduce_axis_poly",
"cumlogsumexp_reduce_axis_poly",
"jnp_insert_insert_constant", "jnp_insert_insert_poly",
"jnp_nonzero_size_constant", "jnp_nonzero_size_poly"]):
# Need associative scan reductions on CPU and GPU. On TPU we use the
# reduce_window HLO, but on CPU and GPU (with axis size >= 32) we use
# a recursive associative scan that we cannot express with shape
# polymorphism.
raise unittest.SkipTest(
"native serialization with shape polymorphism not implemented for window_reductions on CPU and GPU")

# FOR GRAPH SERIALIZATION
if not config.jax2tf_default_native_serialization.value:
if ("random_gamma_threefry_non_partitionable" in harness.fullname and
jtu.test_device_matches(["cpu"])):
harness.tol = 1e-6

if harness.group_name == "vmap_cumsum":
# For cumsum we use a different implementation than JAX native
# See README.md for associative scan reductions
harness.tol = 1e-5

if "vmap_" in harness.group_name:
# For non-native serialization, it seems that we cannot just use
# the custom_asserts; we get too many errors.
if [l for l in harness.limitations if l.custom_assert]:
harness.check_result = False

if "vmap_integer_pow" in harness.group_name:
# For non-native serialization the overflow behavior is different.
if harness.group_name == "vmap_eigh" and jtu.test_device_matches(["gpu"]):
# For eigh on GPU with shape polymorphism under native serialization,
# we use a different lowering for small matrices. See README.md.
shape = harness.original_harness.params["shape"]
if 0 < shape[-1] <= 32:
harness.check_result = False

if "average_axis_None_weights_Some" in harness.fullname:
harness.tol = 1e-5
if harness.group_name == "vmap_tan":
# Tan (b/274462307) require support for custom call stablehlo.tan.
raise unittest.SkipTest(
"native lowering with shape polymorphism requires additional StableHLO feature support")

if harness.group_name == "schur":
raise unittest.SkipTest("jax2tf graph serialization does not support schur.")

if harness.group_name == "eig" and "left_True_right_True" in harness.fullname:
raise unittest.SkipTest("jax2tf graph serialization does not support both left and right.")

if harness.group_name == "vmap_eigh":
self.skipTest("b/312378994: error in TF eager execution for tf.linalg.eigh")

if "conv_general_dilated_1d_stride_2_zero_output_enable_xla_False" in harness.fullname:
raise unittest.SkipTest("incomplete support for conv_general_dilated in enable_xla=False")

if harness.group_name == "reduce_window" and "variadic" in harness.fullname:
raise unittest.SkipTest("jax2tf graph serialization does not support variadic reduce_window.")

if (harness.group_name == "reduce_window" and
not harness.enable_xla and
"window_size_dynamic" in harness.fullname and
any(n in harness.fullname
for n in ["min_plus_max", "add_monoid_base", "min_window"])):
raise unittest.SkipTest(
"jax2tf graph serialization with enable_xla=False does not support "
"dynamic tf.nn.pool")

if "reduce_window_add_generic" in harness.fullname and not harness.enable_xla:
raise unittest.SkipTest("TODO(b/287733072): wrong result for enable_xla_False")
if (jtu.test_device_matches(["cpu", "gpu"]) and
harness.fullname in [
"cumsum_reduce_axis_poly", "cumprod_reduce_axis_poly",
"cummin_reduce_axis_poly", "cummax_reduce_axis_poly",
"cumlogsumexp_reduce_axis_poly",
"jnp_insert_insert_constant", "jnp_insert_insert_poly",
"jnp_nonzero_size_constant", "jnp_nonzero_size_poly"]):
# Need associative scan reductions on CPU and GPU. On TPU we use the
# reduce_window HLO, but on CPU and GPU (with axis size >= 32) we use
# a recursive associative scan that we cannot express with shape
# polymorphism.
raise unittest.SkipTest(
"native serialization with shape polymorphism not implemented for window_reductions on CPU and GPU")

# FOR BOTH NATIVE AND GRAPH SERIALIZATION
if harness.group_name == "vmap_conv_general_dilated":