Adds support for enable_xla=False for shape polymorphism tests and adds such tests for dynamic_slice.

It turned out that, in jax2tf._dynamic_slice, tf.constant doesn't work with polymorphic shapes, so I replaced it with a tf.cast.

PiperOrigin-RevId: 378392273
This commit is contained in:
parent 86d2da44c0
commit b749e78d2c
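As a usage sketch of what the new dynamic_slice shape-polymorphism harnesses exercise (illustrative only, not part of this change; the shape spec string and input values are chosen for the example):

import numpy as np
import tensorflow as tf
from jax import lax
from jax.experimental import jax2tf

# Same function shape as the new test harness: the slice size depends on the
# polymorphic batch dimension x.shape[0].
def f(x):
  return lax.dynamic_slice(x, (0, 1), (x.shape[0], 2))

# Convert with a symbolic batch dimension "b". enable_xla=False requests the
# pure-TF lowering, which is what the fixed _dynamic_slice path implements.
f_tf = jax2tf.convert(f, polymorphic_shapes=["(b, 4)"], enable_xla=False)
f_tf = tf.function(f_tf, autograph=False,
                   input_signature=[tf.TensorSpec([None, 4], tf.float32)])

x = np.ones((3, 4), np.float32)
print(f_tf(x).shape)  # (3, 2)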
@@ -2224,12 +2224,14 @@ def _dynamic_slice(operand, *start_indices, slice_sizes,
   # returned.
   # The code below manually clips the start indices so that the behavior is
   # the same as `lax.dynamic_slice_p`.
 
-  # clip_by_value fails if `start_indices` and `max_start` aren't of the same
-  # dtype. By explicitly casting to the right dtype here this doesn't happen.
   operand_shape = _eval_shape(_in_avals[0].shape)
-  shape = tf.constant(operand_shape, dtype=start_indices.dtype)
-  max_start = tf.subtract(shape, slice_sizes)
+  max_start = tf.subtract(operand_shape, slice_sizes)
+  # If `operand_shape` and `slice_sizes` are Python tuples of integers,
+  # `tf.subtract` returns a Tensor of dtype tf.int32, which may conflict with
+  # the dtype of `start_indices` if we run in x64 mode and throw an error when
+  # calling `tf.clip_by_value`. Therefore we cast to the right dtype here
+  # explicitly.
+  max_start = tf.cast(max_start, dtype=start_indices.dtype)
   start_indices = tf.clip_by_value(start_indices, 0, max_start)
   return tf.slice(operand, start_indices, size=slice_sizes)
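To make the new dtype comment concrete, a small standalone illustration (assumes standard TensorFlow defaults; not part of the commit):

import tensorflow as tf

# Subtracting two Python int tuples yields an int32 tensor by default.
max_start = tf.subtract((3, 4), (1, 2))              # dtype int32, value [2 2]
start_indices = tf.constant([0, 1], dtype=tf.int64)  # e.g. indices under x64 mode

# Calling tf.clip_by_value(start_indices, 0, max_start) directly would mix
# int64 and int32 operands and is expected to fail; casting first avoids it.
max_start = tf.cast(max_start, dtype=start_indices.dtype)
print(tf.clip_by_value(start_indices, 0, max_start).numpy())  # [0 1]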
@@ -904,12 +904,6 @@ _POLY_SHAPE_TEST_HARNESSES = [
                   [RandArg((3, 4, 4), _f32), RandArg((3, 4), _f32)],
                   poly_axes=[0, 0]),
 
-    _make_harness("dynamic_slice", "",
-                  # x:shape: (b, 4)
-                  lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)),
-                  [RandArg((3, 4), _f32)],
-                  poly_axes=[0]),
-
     _make_harness("dynamic_update_slice", "",
                   # x:shape: (b, 4)
                   lambda x: lax.dynamic_update_slice(x, x, (0, 0)),
@@ -1072,6 +1066,16 @@ _POLY_SHAPE_TEST_HARNESSES = [
                   poly_axes=[0]),
 ]
 
+for enable_xla in [False, True]:
+  _POLY_SHAPE_TEST_HARNESSES.append(
+      _make_harness(f"dynamic_slice_enablexla={enable_xla}", "",
+                    # x:shape: (b, 4)
+                    lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)),
+                    [RandArg((3, 4), _f32)],
+                    poly_axes=[0],
+                    enable_xla=enable_xla)
+  )
+
 for reduce_op in [jnp.all, jnp.any, jnp.max, jnp.min, jnp.prod, jnp.sum]:
   _POLY_SHAPE_TEST_HARNESSES.append(
       _make_harness("reduce", reduce_op.__name__,
@@ -1104,6 +1108,11 @@ def _add_vmap_primitive_harnesses():
     # And the jax2tf limitations that are known to result in TF error.
     if any(l.expect_tf_error for l in _get_jax2tf_limitations(device, h)):
       continue
+    # TODO(marcvanzee): We currently exclude tests with enable_xla=False because
+    # this doesn't work with vmap due to a call to lax.gather. We should include
+    # them once vmap works with enable_xla=False.
+    if not h.params.get("enable_xla", True):
+      continue
     harness_groups[h.group_name].append(h)
 
   selected_harnesses = []
@@ -1226,11 +1235,13 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
       input_signature.append(arg_tensorspec)
 
     res_jax = harness.dyn_fun(*args)
+    enable_xla = harness.params.get("enable_xla", True)
     f_tf = self.CheckShapePolymorphism(
         harness.dyn_fun,
         input_signature=input_signature,
         polymorphic_shapes=polymorphic_shapes,
-        expected_output_signature=None)
+        expected_output_signature=None,
+        enable_xla=enable_xla)
 
     if harness.params["check_result"]:
       tol = harness.params["tol"]
@@ -293,24 +293,26 @@ class JaxToTfTestCase(jtu.JaxTestCase):
         return self.ConvertAndCompare(grad_func, t_arg)
     assert False, transform
 
-  # TODO(marcvanzee): Add flag enable_xla here so we can also test shape
-  # polymorphism for enable_xla=False.
   def CheckShapePolymorphism(self, f_jax: Callable, *,
                              input_signature: Sequence[tf.TensorSpec],
                              polymorphic_shapes: Optional[Sequence[Any]],
-                             expected_output_signature: tf.TensorSpec):
-    """Convert a function using polymorphic shapes.
+                             expected_output_signature: Optional[tf.TensorSpec] = None,
+                             enable_xla: bool = True):
+    """Converts a function using polymorphic shapes.
 
     Args:
       f_jax: a JAX function of `n` arguments
       input_signature: used as the input signature for the tf.function.
-      in_shapes: if given, it must be a sequence of `n` shape specifications and
-        must match the `input_signature`. (see jax2tf.convert).
+      polymorphic_shapes: Specifies input shapes to be treated polymorphically
+        during conversion.
       expected_output_signature: if given, this function tests whether the
         actual output signature is equal to this one.
+      enable_xla: Whether to enable XLA conversion for jax2tf.convert.
     """
-    f_tf = tf.function(
-        jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes),
-        autograph=False,
-        input_signature=input_signature)
+    f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes,
+                          enable_xla=enable_xla)
+    f_tf = tf.function(f_tf, autograph=False, input_signature=input_signature)
     concrete_f_tf = f_tf.get_concrete_function(*input_signature)
     if expected_output_signature:
       # Strangely, output_shapes can be a single shape for a function with a
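For reference, a minimal sketch of inspecting the converted function's concrete output signature, in the spirit of the expected_output_signature check above (names are local to this example and the polymorphic_shapes string is illustrative):

import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def f(x):
  return jnp.sin(x)

spec = [tf.TensorSpec([None, 4], tf.float32)]
f_tf = jax2tf.convert(f, polymorphic_shapes=["(b, 4)"])
f_tf = tf.function(f_tf, autograph=False, input_signature=spec)

# The concrete function exposes the output signature; as the comment above
# notes, output_shapes can be a single TensorShape for a single-result function.
concrete_f_tf = f_tf.get_concrete_function(*spec)
print(concrete_f_tf.output_shapes)  # (None, 4)
print(concrete_f_tf.output_dtypes)  # <dtype: 'float32'>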