mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add systematic pdot vjp tests
This commit is contained in:
parent
1fd1faa06c
commit
537c3d5c84
@ -650,6 +650,54 @@ class PDotTests(jtu.JaxTestCase):
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(result, expected, check_dtypes=False,
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_{next(test_counter)}",
|
||||
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "pdot_spec": pdot_spec,
|
||||
"axis_resources": axis_resources, "mesh_data": mesh_data}
|
||||
for test_counter in [it.count()]
|
||||
for lhs_shape, rhs_shape in product(
|
||||
[(2,), (2, 4, 2, 1)],
|
||||
repeat=2)
|
||||
for pdot_spec in all_pdot_specs(lhs_shape, rhs_shape)
|
||||
for axis_resources, mesh_data in schedules_from_pdot_spec(
|
||||
pdot_spec, lhs_shape, rhs_shape)))
|
||||
@ignore_xmap_warning()
|
||||
def testPdotVJPSystematic(self, lhs_shape, rhs_shape, pdot_spec,
|
||||
axis_resources, mesh_data):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
lhs = rng(lhs_shape, np.float32)
|
||||
rhs = rng(rhs_shape, np.float32)
|
||||
|
||||
expected_out, ref_vjp = jax.vjp(
|
||||
lambda x, y: lax.dot_general(x, y, pdot_spec.dot_general_dim_nums),
|
||||
lhs, rhs)
|
||||
out_bar = rng(expected_out.shape, np.float32)
|
||||
expected_lhs, expected_rhs = ref_vjp(out_bar)
|
||||
|
||||
def pdot_fun(x, y, out_bar):
|
||||
pdot = partial(jax.lax.pdot,
|
||||
axis_name=pdot_spec.contract_names,
|
||||
pos_batch=pdot_spec.pos_batch_after_mapping,
|
||||
pos_contract=pdot_spec.pos_contract_after_mapping)
|
||||
_, pdot_vjp = jax.vjp(pdot, x, y)
|
||||
return pdot_vjp(out_bar)
|
||||
|
||||
fun = xmap(pdot_fun,
|
||||
in_axes=[pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes,
|
||||
[*pdot_spec.batch_names, ...]],
|
||||
out_axes=(pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes),
|
||||
axis_resources=axis_resources)
|
||||
|
||||
with with_mesh(mesh_data):
|
||||
lhs_bar, rhs_bar = fun(lhs, rhs, out_bar)
|
||||
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(lhs_bar, expected_lhs, check_dtypes=False,
|
||||
atol=tol, rtol=tol)
|
||||
self.assertAllClose(rhs_bar, expected_rhs, check_dtypes=False,
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def test_xeinsum_vector_dot(self):
|
||||
rng = np.random.RandomState(0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user