mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[shape_poly] Refactor shape_poly_test in preparation for moving out of jax2tf.
The shape polymotphism is now independent of jax2tf and the code is actually out of jax2tf. Here we refactor shape_poly_test to prepare for moving most of out of jax2tf. The main change is that we replace `jax2tf.convert(f_jax)(*args)` with a call to `check_shape_poly` which now still uses `jax2tf` but in the future will use JAX native mechanisms.
This commit is contained in:
parent
c8f3e2370d
commit
f9474b221c
@ -560,7 +560,7 @@ class PolyHarness(Harness):
|
||||
self.name = f"{self.name}_enable_xla_True"
|
||||
return (self, other)
|
||||
|
||||
def run_test(self, tst: tf_test_util.JaxToTfTestCase):
|
||||
def run_test(self, tst: tf_test_util.JaxToTfTestCase) -> Optional[jax.Array]:
|
||||
def log_message(extra: str):
|
||||
return f"[{tst._testMethodName}]: {extra}"
|
||||
|
||||
@ -609,7 +609,7 @@ class PolyHarness(Harness):
|
||||
concrete_f_tf = f_tf_func.get_concrete_function(*input_signature)
|
||||
|
||||
if expect_error_type is not None:
|
||||
return
|
||||
return None
|
||||
|
||||
if self.expected_output_signature:
|
||||
# Strangely, output_shapes can be a single shape for a function with a
|
||||
@ -649,6 +649,11 @@ class PolyHarness(Harness):
|
||||
f"to {custom_assert_lims[0]}"))
|
||||
custom_assert_lims[0].custom_assert(tst, res_jax, res_tf, args=args, # type: ignore
|
||||
tol=tol, err_msg=None)
|
||||
return res_tf
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def check_shape_poly(tst, f_jax: Callable, *,
|
||||
@ -657,7 +662,7 @@ def check_shape_poly(tst, f_jax: Callable, *,
|
||||
polymorphic_shapes: Sequence[Optional[str]] = (),
|
||||
input_signature: Optional[Sequence[tf.TensorSpec]] = None,
|
||||
expected_output_signature: Optional[tf.TensorSpec] = None,
|
||||
expect_error=(None, None)):
|
||||
expect_error=(None, None)) -> Optional[jax.Array]:
|
||||
# Makes and tests a harness. See PolyHarness documentation.
|
||||
h = PolyHarness("", "", f_jax,
|
||||
arg_descriptors=arg_descriptors,
|
||||
@ -666,7 +671,7 @@ def check_shape_poly(tst, f_jax: Callable, *,
|
||||
input_signature=input_signature,
|
||||
expected_output_signature=expected_output_signature,
|
||||
expect_error=expect_error)
|
||||
h.run_test(tst)
|
||||
return h.run_test(tst)
|
||||
|
||||
|
||||
class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
@ -730,43 +735,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
polymorphic_shapes=["h, h", "h, h"],
|
||||
expected_output_signature=tf.TensorSpec([None, None]))
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
# make_args invoked with op.shape[0]: start, stop, step, dtype
|
||||
# b == 6
|
||||
kwargs=[
|
||||
# Positive step
|
||||
dict(testcase_name="b", make_args=lambda b: (b, None, None, None)),
|
||||
dict(testcase_name="0_b+1", make_args=lambda b: (0, b + 1, None, None)),
|
||||
dict(testcase_name="0_5b_2", make_args=lambda b: (0, 5 * b, 2, None)),
|
||||
dict(testcase_name="0_5b+1_2", make_args=lambda b: (0, 5 * b + 1, 2, None)),
|
||||
dict(testcase_name="b_5b+2_2", make_args=lambda b: (b, 5 * b + 2, 2, None)),
|
||||
dict(testcase_name="0_b-1_2", make_args=lambda b: (0, b - 1, 2, None)),
|
||||
dict(testcase_name="0_b-2_2", make_args=lambda b: (0, b - 2, 2, None)),
|
||||
dict(testcase_name="0_-b_2", make_args=lambda b: (0, -b, 2, None)),
|
||||
dict(testcase_name="0_1-b_2", make_args=lambda b: (0, 1 - b, 2, None)),
|
||||
dict(testcase_name="0_b-3_2", make_args=lambda b: (0, b - 3, 2, None)), # Cannot tell if size >= 0
|
||||
# Negative step
|
||||
dict(testcase_name="b_0_-1", make_args=lambda b: (b, 0, -1, None)),
|
||||
dict(testcase_name="b_1_-2", make_args=lambda b: (b, 1, -2, None)),
|
||||
dict(testcase_name="b_-1_-1", make_args=lambda b: (b, -1, -1, None)),
|
||||
dict(testcase_name="5b+1_0_-2", make_args=lambda b: (5 * b + 1, 0, -2, None)),
|
||||
dict(testcase_name="5b+2_0_-2", make_args=lambda b: (5 * b + 2, 0, -2, None)),
|
||||
dict(testcase_name="b-3_0_-2", make_args=lambda b: (b - 3, 0, -2, None)), # Cannot tell if size >= 0
|
||||
# Symbolic step
|
||||
dict(testcase_name="0_10_b", make_args=lambda b: (0, 10, b)),
|
||||
dict(testcase_name="0_0_b", make_args=lambda b: (0, 0, b)),
|
||||
dict(testcase_name="10_0_-b", make_args=lambda b: (10, 0, -b)),
|
||||
dict(testcase_name="b_1_-b", make_args=lambda b: (b, 1, -b)),
|
||||
# Float return type
|
||||
dict(testcase_name="0_b_1_f32", make_args=lambda b: (0, b, 1, np.float32))
|
||||
])
|
||||
def test_arange(self, make_args):
|
||||
def f_jax(x): # x: i32[b]
|
||||
return x[0] + jnp.arange(*(make_args(x.shape[0])))
|
||||
x = np.ones((6,), dtype=np.int32)
|
||||
self.assertAllClose(jax2tf.convert(f_jax, polymorphic_shapes="b")(x),
|
||||
f_jax(x))
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
# make_args invoked with op.shape[0]: start, stop, step, dtype
|
||||
kwargs=[
|
||||
@ -792,14 +760,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
return x[0] + jnp.arange(*(make_args(x.shape[0])))
|
||||
x = np.ones((3,), dtype=np.int32)
|
||||
with self.assertRaisesRegex(expect_error, expect_msg):
|
||||
jax2tf.convert(f_jax, polymorphic_shapes="b")(x)
|
||||
check_shape_poly(self, f_jax, arg_descriptors=[x],
|
||||
polymorphic_shapes=["b"])
|
||||
|
||||
def test_argmax(self):
|
||||
def f_jax(x): # x: f32[b, 4, 5]
|
||||
return lax.argmax(x, axis=1, index_dtype=np.int32)
|
||||
x = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5))
|
||||
self.assertAllClose(jax2tf.convert(f_jax, polymorphic_shapes="(b, _, _)")(x),
|
||||
f_jax(x))
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
@ -996,11 +959,13 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
expected_shapeenv=dict(a=2, b=3, c=4))
|
||||
|
||||
def test_arg_avals_errors(self):
|
||||
"""Test error reporting for shape polymorpish."""
|
||||
"""Test error reporting for shape polymorphism."""
|
||||
def conv_and_run(*, arg_shape: core.Shape,
|
||||
polymorphic_shape: str):
|
||||
arg = np.arange(math.prod(arg_shape), dtype=np.float32).reshape(arg_shape)
|
||||
jax2tf.convert(lambda x: x, polymorphic_shapes=[polymorphic_shape])(arg)
|
||||
check_shape_poly(self, lambda x: x,
|
||||
arg_descriptors=[arg],
|
||||
polymorphic_shapes=[polymorphic_shape])
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
re.escape("polymorphic shape spec should be")):
|
||||
@ -1094,7 +1059,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
with contextlib.ExitStack() as stack:
|
||||
if expect_error is not None:
|
||||
stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error)))
|
||||
_ = jax2tf.convert(f_jax, polymorphic_shapes=[poly_spec])(x)
|
||||
_ = check_shape_poly(self, f_jax,
|
||||
arg_descriptors=[x],
|
||||
polymorphic_shapes=[poly_spec])
|
||||
|
||||
def test_pytree(self):
|
||||
"""Arguments and polymorphic_shapes are pytrees."""
|
||||
@ -1372,7 +1339,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
res_jax = f(x, y)
|
||||
self.assertAllClose(
|
||||
res_jax,
|
||||
jax2tf.convert(f, polymorphic_shapes=["(b, h)", "h"])(x, y))
|
||||
check_shape_poly(self, f, arg_descriptors=[x, y],
|
||||
polymorphic_shapes=["(b, h)", "h"]))
|
||||
|
||||
def test_while(self):
|
||||
def f(x):
|
||||
@ -1382,7 +1350,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
(x, 0))
|
||||
|
||||
x = np.ones((3,), dtype=np.float32)
|
||||
res_tf = jax2tf.convert(f, polymorphic_shapes=["(b,)"])(x)
|
||||
res_tf = check_shape_poly(self, f, arg_descriptors=[x],
|
||||
polymorphic_shapes=["(b,)"])
|
||||
self.assertAllClose(f(x), res_tf)
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
@ -1671,22 +1640,26 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
return jnp.sum(x, axis=0) * x.shape[0]
|
||||
|
||||
x = np.arange(3.)
|
||||
self.assertAllClose(9., jax2tf.convert(f, polymorphic_shapes=["(b,)"])(x))
|
||||
self.assertAllClose(9.,
|
||||
check_shape_poly(self, f,
|
||||
arg_descriptors=[x],
|
||||
polymorphic_shapes=["(b,)"]))
|
||||
self.assertAllClose(
|
||||
9.,
|
||||
jax2tf.convert(jax.jit(f), polymorphic_shapes=["(b,)"])(x))
|
||||
self.assertAllClose(
|
||||
9.,
|
||||
tf.function(jax2tf.convert(f, polymorphic_shapes=["(b,)"]))(x))
|
||||
check_shape_poly(self, jax.jit(f),
|
||||
arg_descriptors=[x], polymorphic_shapes=["(b,)"]))
|
||||
|
||||
res_primal, res_tangent = jax2tf.convert(
|
||||
res_primal, res_tangent = check_shape_poly(self,
|
||||
lambda x, xt: jax.jvp(f, (x,), (xt,)),
|
||||
polymorphic_shapes=["b", "b"])(x, np.array([0.1, 0.2, 0.3]))
|
||||
arg_descriptors=[x, np.array([0.1, 0.2, 0.3])],
|
||||
polymorphic_shapes=["b", "b"])
|
||||
self.assertAllClose((9., 1.8), (res_primal, res_tangent))
|
||||
|
||||
self.assertAllClose(
|
||||
np.array([3., 3., 3.]),
|
||||
jax2tf.convert(jax.grad(f), polymorphic_shapes=["b"])(x))
|
||||
check_shape_poly(self, jax.grad(f),
|
||||
arg_descriptors=[x],
|
||||
polymorphic_shapes=["b"]))
|
||||
|
||||
xv = np.arange(24.).reshape((2, 3, 4))
|
||||
res_vmap = jax.vmap(f, in_axes=1)(xv)
|
||||
@ -1694,9 +1667,10 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
res_iter = jnp.stack([f(xv[:, i, :]) for i in range(xv.shape[1])])
|
||||
self.assertAllClose(res_iter, res_vmap)
|
||||
|
||||
res_vmap_tf = jax2tf.convert(jax.vmap(f, in_axes=1),
|
||||
polymorphic_shapes=["b1, b2, ..."])(xv)
|
||||
self.assertAllClose(res_iter, res_vmap_tf.numpy())
|
||||
res_vmap_tf = check_shape_poly(self, jax.vmap(f, in_axes=1),
|
||||
arg_descriptors=[xv],
|
||||
polymorphic_shapes=["b1, b2, ..."])
|
||||
self.assertAllClose(res_iter, res_vmap_tf)
|
||||
|
||||
def test_with_hash_collision_vmap(self):
|
||||
# Batching caches based on Jaxpr, and Jaxpr include _DimExpr. If we have
|
||||
@ -1948,33 +1922,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
res = jax2tf.convert(f2, polymorphic_shapes=zw_polymorphic_shapes)(z, w)
|
||||
self.assertAllClose(f2(* f1(x, y)), res)
|
||||
|
||||
def test_gather_1d(self):
|
||||
operand = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], np.float32)
|
||||
rand_idxs = np.random.randint(0, high=max(operand.shape), size=(3, 1), dtype=np.int32)
|
||||
slice_x = np.zeros((10,), dtype=jnp.float32)
|
||||
dnums = lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)
|
||||
)
|
||||
|
||||
@jax.jit
|
||||
def f_jax(operand, start_indices, x):
|
||||
return lax.gather(
|
||||
operand,
|
||||
start_indices,
|
||||
dimension_numbers=dnums,
|
||||
slice_sizes=x.shape,
|
||||
mode="promise_in_bounds",
|
||||
)
|
||||
|
||||
res = f_jax(operand, rand_idxs, slice_x)
|
||||
f_tf = jax2tf.convert(
|
||||
f_jax,
|
||||
native_serialization=True,
|
||||
polymorphic_shapes=["(t, )", "(3, 1)", "(t)"],
|
||||
)
|
||||
res_tf = f_tf(operand, rand_idxs, slice_x)
|
||||
self.assertAllClose(res, res_tf)
|
||||
|
||||
|
||||
# List containing either harnesses, or lists of harnesses
|
||||
_POLY_SHAPE_TEST_HARNESSES = [
|
||||
@ -1986,6 +1933,45 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
jax.grad(lambda x: jnp.sum(jnp.sum(x, axis=0, keepdims=False) + jnp.sin(x))),
|
||||
arg_descriptors=[RandArg((3, 4), _f32)],
|
||||
polymorphic_shapes=["b, ..."]),
|
||||
[
|
||||
# make_args invoked with op.shape[0] and produces the arange args:
|
||||
# start, stop, step, dtype
|
||||
PolyHarness("arange", kwargs["testcase_name"], # type: ignore
|
||||
lambda x: jnp.arange(*(kwargs["make_args"](x.shape[0]))), # type: ignore
|
||||
arg_descriptors=[RandArg((6,), np.float32)],
|
||||
polymorphic_shapes=["b"])
|
||||
for kwargs in [
|
||||
# Positive step
|
||||
dict(testcase_name="b", make_args=lambda b: (b, None, None, None)),
|
||||
dict(testcase_name="0_b+1", make_args=lambda b: (0, b + 1, None, None)),
|
||||
dict(testcase_name="0_5b_2", make_args=lambda b: (0, 5 * b, 2, None)),
|
||||
dict(testcase_name="0_5b+1_2", make_args=lambda b: (0, 5 * b + 1, 2, None)),
|
||||
dict(testcase_name="b_5b+2_2", make_args=lambda b: (b, 5 * b + 2, 2, None)),
|
||||
dict(testcase_name="0_b-1_2", make_args=lambda b: (0, b - 1, 2, None)),
|
||||
dict(testcase_name="0_b-2_2", make_args=lambda b: (0, b - 2, 2, None)),
|
||||
dict(testcase_name="0_-b_2", make_args=lambda b: (0, -b, 2, None)),
|
||||
dict(testcase_name="0_1-b_2", make_args=lambda b: (0, 1 - b, 2, None)),
|
||||
dict(testcase_name="0_b-3_2", make_args=lambda b: (0, b - 3, 2, None)),
|
||||
# Cannot tell if size >= 0
|
||||
# Negative step
|
||||
dict(testcase_name="b_0_-1", make_args=lambda b: (b, 0, -1, None)),
|
||||
dict(testcase_name="b_1_-2", make_args=lambda b: (b, 1, -2, None)),
|
||||
dict(testcase_name="b_-1_-1", make_args=lambda b: (b, -1, -1, None)),
|
||||
dict(testcase_name="5b+1_0_-2",
|
||||
make_args=lambda b: (5 * b + 1, 0, -2, None)),
|
||||
dict(testcase_name="5b+2_0_-2",
|
||||
make_args=lambda b: (5 * b + 2, 0, -2, None)),
|
||||
dict(testcase_name="b-3_0_-2", make_args=lambda b: (b - 3, 0, -2, None)),
|
||||
# Cannot tell if size >= 0
|
||||
# Symbolic step
|
||||
dict(testcase_name="0_10_b", make_args=lambda b: (0, 10, b)),
|
||||
dict(testcase_name="0_0_b", make_args=lambda b: (0, 0, b)),
|
||||
dict(testcase_name="10_0_-b", make_args=lambda b: (10, 0, -b)),
|
||||
dict(testcase_name="b_1_-b", make_args=lambda b: (b, 1, -b)),
|
||||
# Float return type
|
||||
dict(testcase_name="0_b_1_f32", make_args=lambda b: (0, b, 1, np.float32))
|
||||
]
|
||||
],
|
||||
# Reduce the poly dimension
|
||||
PolyHarness("argmax", "0",
|
||||
lambda op: lax.argmax(op, axis=0, index_dtype=np.int32),
|
||||
@ -2328,6 +2314,23 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
lambda x: lax.full((x.shape[0], 2), 3.) + x,
|
||||
arg_descriptors=[RandArg((3, 1), _f32)],
|
||||
polymorphic_shapes=["b, ..."]),
|
||||
PolyHarness("gather", "1d",
|
||||
lambda operand, start_indices, x: lax.gather(
|
||||
operand,
|
||||
start_indices,
|
||||
dimension_numbers=lax.GatherDimensionNumbers(
|
||||
offset_dims=(1,),
|
||||
collapsed_slice_dims=(),
|
||||
start_index_map=(0,)),
|
||||
slice_sizes=x.shape,
|
||||
mode="promise_in_bounds"),
|
||||
arg_descriptors=[
|
||||
RandArg((10,), np.float32),
|
||||
np.random.randint(0, high=10, size=(3, 1),
|
||||
dtype=np.int32),
|
||||
np.zeros((10,), dtype=jnp.int32),
|
||||
],
|
||||
polymorphic_shapes=["(t, )", "(3, 1)", "(t)"]),
|
||||
# operand is non-poly, index is poly
|
||||
PolyHarness("getitem", "op=static_idx=poly",
|
||||
lambda a, i: a[i],
|
||||
|
Loading…
x
Reference in New Issue
Block a user