[shape_poly] Refactor shape_poly_test in preparation for moving out of jax2tf.

Shape polymorphism is now independent of jax2tf and its code is already
outside of jax2tf. Here we refactor shape_poly_test to prepare for moving most
of it out of jax2tf as well.

The main change is that we replace `jax2tf.convert(f_jax)(*args)` with a call
to `check_shape_poly`, which for now still uses `jax2tf` but will in the
future use JAX-native mechanisms.
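For example, a minimal sketch of the pattern (shown with a placeholder
function `f_jax` and argument `x`; the exact options vary per test):

    # Before: convert with jax2tf and compare against JAX by hand.
    res_tf = jax2tf.convert(f_jax, polymorphic_shapes=["b, ..."])(x)
    self.assertAllClose(f_jax(x), res_tf)

    # After: check_shape_poly builds a PolyHarness that runs the conversion
    # and the JAX-vs-TF comparison, and returns the converted result.
    res_tf = check_shape_poly(self, f_jax,
                              arg_descriptors=[x],
                              polymorphic_shapes=["b, ..."])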
George Necula 2023-11-12 18:17:07 +01:00
parent c8f3e2370d
commit f9474b221c


@@ -560,7 +560,7 @@ class PolyHarness(Harness):
self.name = f"{self.name}_enable_xla_True"
return (self, other)
def run_test(self, tst: tf_test_util.JaxToTfTestCase):
def run_test(self, tst: tf_test_util.JaxToTfTestCase) -> Optional[jax.Array]:
def log_message(extra: str):
return f"[{tst._testMethodName}]: {extra}"
@@ -609,7 +609,7 @@ class PolyHarness(Harness):
concrete_f_tf = f_tf_func.get_concrete_function(*input_signature)
if expect_error_type is not None:
return
return None
if self.expected_output_signature:
# Strangely, output_shapes can be a single shape for a function with a
@@ -649,6 +649,11 @@ class PolyHarness(Harness):
f"to {custom_assert_lims[0]}"))
custom_assert_lims[0].custom_assert(tst, res_jax, res_tf, args=args, # type: ignore
tol=tol, err_msg=None)
return res_tf
else:
return None
else:
return None
def check_shape_poly(tst, f_jax: Callable, *,
@@ -657,7 +662,7 @@ def check_shape_poly(tst, f_jax: Callable, *,
polymorphic_shapes: Sequence[Optional[str]] = (),
input_signature: Optional[Sequence[tf.TensorSpec]] = None,
expected_output_signature: Optional[tf.TensorSpec] = None,
expect_error=(None, None)):
expect_error=(None, None)) -> Optional[jax.Array]:
# Makes and tests a harness. See PolyHarness documentation.
h = PolyHarness("", "", f_jax,
arg_descriptors=arg_descriptors,
@@ -666,7 +671,7 @@ def check_shape_poly(tst, f_jax: Callable, *,
input_signature=input_signature,
expected_output_signature=expected_output_signature,
expect_error=expect_error)
h.run_test(tst)
return h.run_test(tst)
class ShapePolyTest(tf_test_util.JaxToTfTestCase):
@@ -730,43 +735,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
polymorphic_shapes=["h, h", "h, h"],
expected_output_signature=tf.TensorSpec([None, None]))
@jtu.parameterized_filterable(
# make_args invoked with op.shape[0]: start, stop, step, dtype
# b == 6
kwargs=[
# Positive step
dict(testcase_name="b", make_args=lambda b: (b, None, None, None)),
dict(testcase_name="0_b+1", make_args=lambda b: (0, b + 1, None, None)),
dict(testcase_name="0_5b_2", make_args=lambda b: (0, 5 * b, 2, None)),
dict(testcase_name="0_5b+1_2", make_args=lambda b: (0, 5 * b + 1, 2, None)),
dict(testcase_name="b_5b+2_2", make_args=lambda b: (b, 5 * b + 2, 2, None)),
dict(testcase_name="0_b-1_2", make_args=lambda b: (0, b - 1, 2, None)),
dict(testcase_name="0_b-2_2", make_args=lambda b: (0, b - 2, 2, None)),
dict(testcase_name="0_-b_2", make_args=lambda b: (0, -b, 2, None)),
dict(testcase_name="0_1-b_2", make_args=lambda b: (0, 1 - b, 2, None)),
dict(testcase_name="0_b-3_2", make_args=lambda b: (0, b - 3, 2, None)), # Cannot tell if size >= 0
# Negative step
dict(testcase_name="b_0_-1", make_args=lambda b: (b, 0, -1, None)),
dict(testcase_name="b_1_-2", make_args=lambda b: (b, 1, -2, None)),
dict(testcase_name="b_-1_-1", make_args=lambda b: (b, -1, -1, None)),
dict(testcase_name="5b+1_0_-2", make_args=lambda b: (5 * b + 1, 0, -2, None)),
dict(testcase_name="5b+2_0_-2", make_args=lambda b: (5 * b + 2, 0, -2, None)),
dict(testcase_name="b-3_0_-2", make_args=lambda b: (b - 3, 0, -2, None)), # Cannot tell if size >= 0
# Symbolic step
dict(testcase_name="0_10_b", make_args=lambda b: (0, 10, b)),
dict(testcase_name="0_0_b", make_args=lambda b: (0, 0, b)),
dict(testcase_name="10_0_-b", make_args=lambda b: (10, 0, -b)),
dict(testcase_name="b_1_-b", make_args=lambda b: (b, 1, -b)),
# Float return type
dict(testcase_name="0_b_1_f32", make_args=lambda b: (0, b, 1, np.float32))
])
def test_arange(self, make_args):
def f_jax(x): # x: i32[b]
return x[0] + jnp.arange(*(make_args(x.shape[0])))
x = np.ones((6,), dtype=np.int32)
self.assertAllClose(jax2tf.convert(f_jax, polymorphic_shapes="b")(x),
f_jax(x))
@jtu.parameterized_filterable(
# make_args invoked with op.shape[0]: start, stop, step, dtype
kwargs=[
@@ -792,14 +760,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
return x[0] + jnp.arange(*(make_args(x.shape[0])))
x = np.ones((3,), dtype=np.int32)
with self.assertRaisesRegex(expect_error, expect_msg):
jax2tf.convert(f_jax, polymorphic_shapes="b")(x)
check_shape_poly(self, f_jax, arg_descriptors=[x],
polymorphic_shapes=["b"])
def test_argmax(self):
def f_jax(x): # x: f32[b, 4, 5]
return lax.argmax(x, axis=1, index_dtype=np.int32)
x = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5))
self.assertAllClose(jax2tf.convert(f_jax, polymorphic_shapes="(b, _, _)")(x),
f_jax(x))
@jtu.parameterized_filterable(
kwargs=[
@@ -996,11 +959,13 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
expected_shapeenv=dict(a=2, b=3, c=4))
def test_arg_avals_errors(self):
"""Test error reporting for shape polymorpish."""
"""Test error reporting for shape polymorphism."""
def conv_and_run(*, arg_shape: core.Shape,
polymorphic_shape: str):
arg = np.arange(math.prod(arg_shape), dtype=np.float32).reshape(arg_shape)
jax2tf.convert(lambda x: x, polymorphic_shapes=[polymorphic_shape])(arg)
check_shape_poly(self, lambda x: x,
arg_descriptors=[arg],
polymorphic_shapes=[polymorphic_shape])
with self.assertRaisesRegex(ValueError,
re.escape("polymorphic shape spec should be")):
@@ -1094,7 +1059,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
with contextlib.ExitStack() as stack:
if expect_error is not None:
stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error)))
_ = jax2tf.convert(f_jax, polymorphic_shapes=[poly_spec])(x)
_ = check_shape_poly(self, f_jax,
arg_descriptors=[x],
polymorphic_shapes=[poly_spec])
def test_pytree(self):
"""Arguments and polymorphic_shapes are pytrees."""
@@ -1372,7 +1339,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
res_jax = f(x, y)
self.assertAllClose(
res_jax,
jax2tf.convert(f, polymorphic_shapes=["(b, h)", "h"])(x, y))
check_shape_poly(self, f, arg_descriptors=[x, y],
polymorphic_shapes=["(b, h)", "h"]))
def test_while(self):
def f(x):
@@ -1382,7 +1350,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
(x, 0))
x = np.ones((3,), dtype=np.float32)
res_tf = jax2tf.convert(f, polymorphic_shapes=["(b,)"])(x)
res_tf = check_shape_poly(self, f, arg_descriptors=[x],
polymorphic_shapes=["(b,)"])
self.assertAllClose(f(x), res_tf)
@jtu.parameterized_filterable(
@@ -1671,22 +1640,26 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
return jnp.sum(x, axis=0) * x.shape[0]
x = np.arange(3.)
self.assertAllClose(9., jax2tf.convert(f, polymorphic_shapes=["(b,)"])(x))
self.assertAllClose(9.,
check_shape_poly(self, f,
arg_descriptors=[x],
polymorphic_shapes=["(b,)"]))
self.assertAllClose(
9.,
jax2tf.convert(jax.jit(f), polymorphic_shapes=["(b,)"])(x))
self.assertAllClose(
9.,
tf.function(jax2tf.convert(f, polymorphic_shapes=["(b,)"]))(x))
check_shape_poly(self, jax.jit(f),
arg_descriptors=[x], polymorphic_shapes=["(b,)"]))
res_primal, res_tangent = jax2tf.convert(
res_primal, res_tangent = check_shape_poly(self,
lambda x, xt: jax.jvp(f, (x,), (xt,)),
polymorphic_shapes=["b", "b"])(x, np.array([0.1, 0.2, 0.3]))
arg_descriptors=[x, np.array([0.1, 0.2, 0.3])],
polymorphic_shapes=["b", "b"])
self.assertAllClose((9., 1.8), (res_primal, res_tangent))
self.assertAllClose(
np.array([3., 3., 3.]),
jax2tf.convert(jax.grad(f), polymorphic_shapes=["b"])(x))
check_shape_poly(self, jax.grad(f),
arg_descriptors=[x],
polymorphic_shapes=["b"]))
xv = np.arange(24.).reshape((2, 3, 4))
res_vmap = jax.vmap(f, in_axes=1)(xv)
@@ -1694,9 +1667,10 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
res_iter = jnp.stack([f(xv[:, i, :]) for i in range(xv.shape[1])])
self.assertAllClose(res_iter, res_vmap)
res_vmap_tf = jax2tf.convert(jax.vmap(f, in_axes=1),
polymorphic_shapes=["b1, b2, ..."])(xv)
self.assertAllClose(res_iter, res_vmap_tf.numpy())
res_vmap_tf = check_shape_poly(self, jax.vmap(f, in_axes=1),
arg_descriptors=[xv],
polymorphic_shapes=["b1, b2, ..."])
self.assertAllClose(res_iter, res_vmap_tf)
def test_with_hash_collision_vmap(self):
# Batching caches based on Jaxpr, and Jaxpr include _DimExpr. If we have
@@ -1948,33 +1922,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
res = jax2tf.convert(f2, polymorphic_shapes=zw_polymorphic_shapes)(z, w)
self.assertAllClose(f2(* f1(x, y)), res)
def test_gather_1d(self):
operand = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], np.float32)
rand_idxs = np.random.randint(0, high=max(operand.shape), size=(3, 1), dtype=np.int32)
slice_x = np.zeros((10,), dtype=jnp.float32)
dnums = lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)
)
@jax.jit
def f_jax(operand, start_indices, x):
return lax.gather(
operand,
start_indices,
dimension_numbers=dnums,
slice_sizes=x.shape,
mode="promise_in_bounds",
)
res = f_jax(operand, rand_idxs, slice_x)
f_tf = jax2tf.convert(
f_jax,
native_serialization=True,
polymorphic_shapes=["(t, )", "(3, 1)", "(t)"],
)
res_tf = f_tf(operand, rand_idxs, slice_x)
self.assertAllClose(res, res_tf)
# List containing either harnesses, or lists of harnesses
_POLY_SHAPE_TEST_HARNESSES = [
@@ -1986,6 +1933,45 @@ _POLY_SHAPE_TEST_HARNESSES = [
jax.grad(lambda x: jnp.sum(jnp.sum(x, axis=0, keepdims=False) + jnp.sin(x))),
arg_descriptors=[RandArg((3, 4), _f32)],
polymorphic_shapes=["b, ..."]),
[
# make_args invoked with op.shape[0] and produces the arange args:
# start, stop, step, dtype
PolyHarness("arange", kwargs["testcase_name"], # type: ignore
lambda x: jnp.arange(*(kwargs["make_args"](x.shape[0]))), # type: ignore
arg_descriptors=[RandArg((6,), np.float32)],
polymorphic_shapes=["b"])
for kwargs in [
# Positive step
dict(testcase_name="b", make_args=lambda b: (b, None, None, None)),
dict(testcase_name="0_b+1", make_args=lambda b: (0, b + 1, None, None)),
dict(testcase_name="0_5b_2", make_args=lambda b: (0, 5 * b, 2, None)),
dict(testcase_name="0_5b+1_2", make_args=lambda b: (0, 5 * b + 1, 2, None)),
dict(testcase_name="b_5b+2_2", make_args=lambda b: (b, 5 * b + 2, 2, None)),
dict(testcase_name="0_b-1_2", make_args=lambda b: (0, b - 1, 2, None)),
dict(testcase_name="0_b-2_2", make_args=lambda b: (0, b - 2, 2, None)),
dict(testcase_name="0_-b_2", make_args=lambda b: (0, -b, 2, None)),
dict(testcase_name="0_1-b_2", make_args=lambda b: (0, 1 - b, 2, None)),
dict(testcase_name="0_b-3_2", make_args=lambda b: (0, b - 3, 2, None)),
# Cannot tell if size >= 0
# Negative step
dict(testcase_name="b_0_-1", make_args=lambda b: (b, 0, -1, None)),
dict(testcase_name="b_1_-2", make_args=lambda b: (b, 1, -2, None)),
dict(testcase_name="b_-1_-1", make_args=lambda b: (b, -1, -1, None)),
dict(testcase_name="5b+1_0_-2",
make_args=lambda b: (5 * b + 1, 0, -2, None)),
dict(testcase_name="5b+2_0_-2",
make_args=lambda b: (5 * b + 2, 0, -2, None)),
dict(testcase_name="b-3_0_-2", make_args=lambda b: (b - 3, 0, -2, None)),
# Cannot tell if size >= 0
# Symbolic step
dict(testcase_name="0_10_b", make_args=lambda b: (0, 10, b)),
dict(testcase_name="0_0_b", make_args=lambda b: (0, 0, b)),
dict(testcase_name="10_0_-b", make_args=lambda b: (10, 0, -b)),
dict(testcase_name="b_1_-b", make_args=lambda b: (b, 1, -b)),
# Float return type
dict(testcase_name="0_b_1_f32", make_args=lambda b: (0, b, 1, np.float32))
]
],
# Reduce the poly dimension
PolyHarness("argmax", "0",
lambda op: lax.argmax(op, axis=0, index_dtype=np.int32),
@@ -2328,6 +2314,23 @@ _POLY_SHAPE_TEST_HARNESSES = [
lambda x: lax.full((x.shape[0], 2), 3.) + x,
arg_descriptors=[RandArg((3, 1), _f32)],
polymorphic_shapes=["b, ..."]),
PolyHarness("gather", "1d",
lambda operand, start_indices, x: lax.gather(
operand,
start_indices,
dimension_numbers=lax.GatherDimensionNumbers(
offset_dims=(1,),
collapsed_slice_dims=(),
start_index_map=(0,)),
slice_sizes=x.shape,
mode="promise_in_bounds"),
arg_descriptors=[
RandArg((10,), np.float32),
np.random.randint(0, high=10, size=(3, 1),
dtype=np.int32),
np.zeros((10,), dtype=jnp.int32),
],
polymorphic_shapes=["(t, )", "(3, 1)", "(t)"]),
# operand is non-poly, index is poly
PolyHarness("getitem", "op=static_idx=poly",
lambda a, i: a[i],