add systematic pdot vjp tests

This commit is contained in:
Matthew Johnson 2021-01-26 17:03:58 -08:00
parent 1fd1faa06c
commit 537c3d5c84

View File

@ -650,6 +650,54 @@ class PDotTests(jtu.JaxTestCase):
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
self.assertAllClose(result, expected, check_dtypes=False,
atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_{next(test_counter)}",
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "pdot_spec": pdot_spec,
"axis_resources": axis_resources, "mesh_data": mesh_data}
for test_counter in [it.count()]
for lhs_shape, rhs_shape in product(
[(2,), (2, 4, 2, 1)],
repeat=2)
for pdot_spec in all_pdot_specs(lhs_shape, rhs_shape)
for axis_resources, mesh_data in schedules_from_pdot_spec(
pdot_spec, lhs_shape, rhs_shape)))
@ignore_xmap_warning()
def testPdotVJPSystematic(self, lhs_shape, rhs_shape, pdot_spec,
axis_resources, mesh_data):
rng = jtu.rand_default(self.rng())
lhs = rng(lhs_shape, np.float32)
rhs = rng(rhs_shape, np.float32)
expected_out, ref_vjp = jax.vjp(
lambda x, y: lax.dot_general(x, y, pdot_spec.dot_general_dim_nums),
lhs, rhs)
out_bar = rng(expected_out.shape, np.float32)
expected_lhs, expected_rhs = ref_vjp(out_bar)
def pdot_fun(x, y, out_bar):
pdot = partial(jax.lax.pdot,
axis_name=pdot_spec.contract_names,
pos_batch=pdot_spec.pos_batch_after_mapping,
pos_contract=pdot_spec.pos_contract_after_mapping)
_, pdot_vjp = jax.vjp(pdot, x, y)
return pdot_vjp(out_bar)
fun = xmap(pdot_fun,
in_axes=[pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes,
[*pdot_spec.batch_names, ...]],
out_axes=(pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes),
axis_resources=axis_resources)
with with_mesh(mesh_data):
lhs_bar, rhs_bar = fun(lhs, rhs, out_bar)
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
self.assertAllClose(lhs_bar, expected_lhs, check_dtypes=False,
atol=tol, rtol=tol)
self.assertAllClose(rhs_bar, expected_rhs, check_dtypes=False,
atol=tol, rtol=tol)
@ignore_xmap_warning()
def test_xeinsum_vector_dot(self):
rng = np.random.RandomState(0)