mirror of https://github.com/ROCm/jax.git
[export] Add and fix a test for exporting higher-order gradients with sharding
There was a test for export with gradients; we changed it to (a) also export the 2nd-order gradient, and (b) export both with a mesh context and without one (using NamedSharding). The test currently fails only in the case when we do NOT have a mesh context, as explained below.

When exporting gradient functions, we first export the primal functions and use the in/out-shardings to construct the shardings of the gradient function. Since Exported shardings now contain only HloSharding objects, and since lowering the gradient function requires `pjit(vjp(f)).lower()`, we construct GSPMDSharding objects from the current devices and the HloSharding objects stored in the Exported primal. However, these GSPMDSharding objects do not have the `_original_sharding` attribute, so later, in `pjit._resource_typing_pjit`, we attempt to `parse_flatten_op_sharding` using the mesh context (which is empty). This fails.

This PR contains one workaround: skip `parse_flatten_op_sharding` when the physical mesh of the `resource_env` is empty. Another, probably better, solution is to ensure that `resource_env` is `None` when there is no mesh context. That seemed reasonable, but currently the code returns an empty mesh from the `resource_env` when there is no mesh context, and changing this would affect more parts of the code, so I have not done it here. It may be worth doing separately.
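A minimal sketch of the failing scenario (my own reconstruction from the description above, not part of the patch; it assumes at least 2 devices and the `jax.experimental.export` API of this commit's era):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pjit
from jax.experimental.export import export
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def f(x):  # x: f32[10,20] -> f32[20,10]
  return jnp.sin(x.T)

mesh = Mesh(np.array(jax.devices()[:2]), ("d",))
# NamedSharding carries the mesh, so no `with mesh:` context is entered.
f_pjit = pjit.pjit(f, in_shardings=NamedSharding(mesh, P(None, "d")))

x = np.arange(200, dtype=np.float32).reshape((10, 20))
exp = export.export(f_pjit)(x)
# Before this fix, the next lines failed when no mesh context was active:
# the GSPMDShardings rebuilt from the stored HloShardings lack
# `_original_sharding`, and _resource_typing_pjit then attempted
# parse_flatten_op_sharding with an empty physical mesh.
exp_vjp = exp.vjp()       # 1st-order gradient
exp_vjp2 = exp_vjp.vjp()  # 2nd-order gradient
```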
parent 61e79cd4f4
commit 8a2d4a01f5
@@ -1816,7 +1816,7 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
         s._original_sharding, '_parsed_pspec'):
       parsed_pspec = s._original_sharding._parsed_pspec
     else:
-      if resource_env is not None:
+      if resource_env is not None and not resource_env.physical_mesh.empty:
         parsed_pspec = parse_flatten_op_sharding(
             s._hlo_sharding, resource_env.physical_mesh)[0]
       else:
@@ -1838,7 +1838,7 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
         s._original_sharding, '_parsed_pspec'):
       parsed_pspec = s._original_sharding._parsed_pspec
     else:
-      if resource_env is not None:
+      if resource_env is not None and not resource_env.physical_mesh.empty:
         parsed_pspec = parse_flatten_op_sharding(
             s._hlo_sharding, resource_env.physical_mesh)[0]
       else:
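For context on the guard: outside a `with mesh:` block pjit still sees a `resource_env`, but its physical mesh is empty, which is exactly the case the new condition skips. A small illustrative sketch (my own, not from the patch; it assumes at least 2 local devices, and `Mesh.empty` is the same emptiness predicate the guard uses):

```python
import jax
import numpy as np
from jax.sharding import Mesh

mesh = Mesh(np.array(jax.devices()[:2]), ("x",))
print(mesh.empty)  # False: a real 2-device mesh

with mesh:
  pass  # inside this block, the ambient (thread-local) mesh is `mesh`

# Outside any `with mesh:` block the ambient mesh is empty, so
# parse_flatten_op_sharding(s._hlo_sharding, resource_env.physical_mesh)
# has no mesh axis names to work with and fails; the added
# `not resource_env.physical_mesh.empty` check avoids calling it then.
```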
@@ -22,6 +22,7 @@ from collections.abc import Sequence
+import contextlib
 from functools import partial
 import logging
 import math
 import os
 import re
 from typing import Any
@@ -40,6 +41,7 @@ from jax.experimental import jax2tf
 from jax.experimental import pjit
 from jax.experimental.maps import xmap
 from jax.experimental.shard_map import shard_map
+from jax.sharding import NamedSharding
 from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P
 import jax.numpy as jnp
@@ -382,16 +384,25 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
       for in_shardings in ("missing", None, "P")
       for out_shardings in ("missing", None, "P")
   ])
-  @jtu.with_mesh([("x", 2)])
   def test_grad_pjit(self, in_shardings="P", out_shardings=None):
     if not config.jax2tf_default_native_serialization.value:
       self.skipTest("TODO: failure in non-native serialization")
+    local_devices = list(jax.local_devices())
+    size = 2
+    if len(local_devices) < size:
+      raise unittest.SkipTest(f"Test requires {size} local devices")
+    mesh_devices = np.array(local_devices[:size]).reshape((2,))
+    mesh = jax.sharding.Mesh(mesh_devices, ("x",))
     def f_jax(x):  # x: f32[10,20] -> f32[20,10]
       return jnp.sin(x.T)
 
     pjit_kwargs = {}
     if in_shardings != "missing":
-      pjit_kwargs["in_shardings"] = (P(None, "x") if in_shardings == "P" else None)
+      pjit_kwargs["in_shardings"] = (
+          NamedSharding(mesh, P(None, "x")) if in_shardings == "P" else None)
     if out_shardings != "missing":
-      pjit_kwargs["out_shardings"] = (P("x", None) if out_shardings == "P" else None)
+      pjit_kwargs["out_shardings"] = (
+          NamedSharding(mesh, P("x", None)) if out_shardings == "P" else None)
     f_jax = pjit.pjit(f_jax, **pjit_kwargs)
     x_shape = (10, 20)
     x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
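The substantive change in this hunk is replacing bare `P(...)` specs with `NamedSharding(mesh, P(...))`, so the test no longer depends on the removed `@jtu.with_mesh` context. A sketch of the difference (my own illustration, assuming at least 2 local devices):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pjit
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]), ("x",))

# A NamedSharding carries its mesh, so no ambient mesh context is needed:
f = pjit.pjit(jnp.sin, in_shardings=NamedSharding(mesh, P("x")))
y = f(np.arange(8, dtype=np.float32))

# A bare PartitionSpec is only resolvable under a mesh context:
g = pjit.pjit(jnp.sin, in_shardings=P("x"))
with mesh:
  z = g(np.arange(8, dtype=np.float32))
```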
@@ -399,8 +410,12 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
     def f_grad_tf(x_v, res_ct):
       with tf.GradientTape(persistent=True) as tape:
         tape.watch(x_v)
-        res_tf = jax2tf.convert(f_jax)(x_v)
-        return tape.gradient(res_tf, x_v, output_gradients=res_ct)
+        with tf.GradientTape() as tape2:
+          tape2.watch(x_v)
+          res_tf = jax2tf.convert(f_jax)(x_v)
+        dy_dx = tape.gradient(res_tf, x_v, output_gradients=res_ct)
+      d2y_dx2 = tape.gradient(dy_dx, x_v)
+      return d2y_dx2
 
     # Annotation count for the primal input and the grad output
     count_in_P = self.GEQ(2) if in_shardings == "P" else 0
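For reference, the nested-tape pattern used above is the standard TensorFlow recipe for second-order derivatives: the inner tape records the primal, and the outer (persistent) tape records the first gradient so it can be differentiated again. A self-contained sketch, independent of the test:

```python
import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as outer:
  outer.watch(x)
  with tf.GradientTape() as inner:
    inner.watch(x)
    y = x ** 3
  dy_dx = inner.gradient(y, x)        # 3 * x**2 = 27
d2y_dx2 = outer.gradient(dy_dx, x)    # 6 * x   = 18
print(float(dy_dx), float(d2y_dx2))
```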
@@ -28,6 +28,7 @@ from jax import numpy as jnp
 from jax import tree_util
 from jax.experimental.export import export
 from jax.experimental import pjit
+from jax.sharding import NamedSharding
 from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P
@@ -755,13 +756,16 @@ class JaxExportTest(jtu.JaxTestCase):
     )(a)
 
   @jtu.parameterized_filterable(
+    one_containing="in_shardings_None_out_shardings_P_with_mesh_False",
     kwargs=[
-      dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}",
-           in_shardings=in_shardings, out_shardings=out_shardings)
+      dict(in_shardings=in_shardings, out_shardings=out_shardings,
+           with_mesh=with_mesh)
       for in_shardings in ("missing", None, "P")
       for out_shardings in ("missing", None, "P")
+      for with_mesh in (True, False)
     ])
-  def test_grad_with_sharding(self, in_shardings="P", out_shardings=None):
+  def test_grad_with_sharding(self, in_shardings="P", out_shardings=None,
+                              with_mesh=False):
     if len(jax.devices()) < 2:
       self.skipTest("Test requires at least 2 devices")
     x_shape = (10, 20)
@@ -769,16 +773,33 @@ class JaxExportTest(jtu.JaxTestCase):
     def f_jax(x):  # x: f32[10,20] -> f32[20,10]
       return jnp.sin(x.T)
 
+    mesh = Mesh(jax.devices()[:2], "d")
     pjit_kwargs = {}
-    if in_shardings != "missing":
-      pjit_kwargs["in_shardings"] = (P(None, "x") if in_shardings == "P" else None)
-    if out_shardings != "missing":
-      pjit_kwargs["out_shardings"] = (P("x", None) if out_shardings == "P" else None)
-    f_jax = pjit.pjit(f_jax, **pjit_kwargs)
+    # Use NamedShardings if we don't have a mesh context
+    if with_mesh:
+      sharding_None_d = P(None, "d")
+      sharding_d_None = P("d", None)
+    else:
+      sharding_None_d = NamedSharding(mesh, P(None, "d"))
+      sharding_d_None = NamedSharding(mesh, P("d", None))
+
+    if in_shardings != "missing":
+      pjit_kwargs["in_shardings"] = (
+          sharding_None_d if in_shardings == "P" else None)
+    if out_shardings != "missing":
+      pjit_kwargs["out_shardings"] = (
+          sharding_d_None if out_shardings == "P" else None)
+    f_jax_pjit = pjit.pjit(f_jax, **pjit_kwargs)
 
-    with Mesh(jax.devices()[:2], "x"):
-      exp = export.export(f_jax)(x)
+    with contextlib.ExitStack() as stack:
+      if with_mesh:
+        stack.enter_context(mesh)
+      # Serialize higher-order gradients
+      exp = export.export(f_jax_pjit)(x)
       exp_vjp = exp.vjp()
+      # Try 2nd order grad as well
+      exp_vjp2 = exp_vjp.vjp()
 
     vjp_module_str = str(exp_vjp.mlir_module())
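The chain `exp.vjp().vjp()` is what exercises the bug: each call exports the VJP of the previous Exported, reconstructing its shardings from the stored HloShardings. A rough sketch of the pattern in isolation (my own, assuming the `jax.experimental.export` API of this era):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.export import export

def f(x):
  return jnp.sin(x.T)

x = np.arange(200, dtype=np.float32).reshape((10, 20))
exp = export.export(f)(x)    # export the primal
exp_vjp = exp.vjp()          # 1st-order gradient, as a new Exported
exp_vjp2 = exp_vjp.vjp()     # 2nd-order gradient, exported from the 1st
```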
@@ -812,13 +833,41 @@ class JaxExportTest(jtu.JaxTestCase):
 
     # Custom calls for the primal output shape all match primal_out_sharding
     primal_out_calls = re.findall(
-        r"custom_call @Sharding.* {mhlo.sharding = \"(.+)\".*:.*tensor<20x10xf32>",
+        r"custom_call @Sharding.*mhlo.sharding = \"(.+)\".*:.*tensor<20x10xf32>",
         vjp_module_str)
     self.assertTrue(
         all(s == primal_out_sharding for s in primal_out_calls),
         primal_in_calls
     )
 
+    # Call the exported gradient functions. In order to set the device context
+    # we replicate the inputs. If we don't use a mesh context and there are
+    # no shardings on inputs or outputs, then we have serialized for one
+    # device.
+    if in_shardings != "P" and out_shardings != "P" and not with_mesh:
+      self.assertEqual(exp_vjp.nr_devices, 1)
+      self.assertEqual(exp_vjp2.nr_devices, 1)
+      call_mesh = Mesh(jax.devices()[:1], "e")
+    else:
+      self.assertEqual(exp_vjp.nr_devices, 2)
+      self.assertEqual(exp_vjp2.nr_devices, 2)
+      call_mesh = Mesh(jax.devices()[:2], "e")
+
+    g1 = pjit.pjit(export.call_exported(exp_vjp),
+                   in_shardings=(NamedSharding(call_mesh, None),
+                                 NamedSharding(call_mesh, None)))(x, x.T)
+    _, f_jax_vjp = jax.vjp(f_jax, x)
+    xbar = f_jax_vjp(x.T)
+    self.assertAllClose(xbar, g1)
+
+    g2 = pjit.pjit(export.call_exported(exp_vjp2),
+                   in_shardings=(NamedSharding(call_mesh, None),
+                                 NamedSharding(call_mesh, None),
+                                 NamedSharding(call_mesh, None)))(x, x.T, x)
+    _, f_jax_vjp2 = jax.vjp(f_jax_vjp, x.T)
+    xbar2, = f_jax_vjp2((x,))
+    self.assertAllClose(xbar2, g2[1])
+
   def test_multi_platform(self):
     x = np.arange(8, dtype=np.float32)
     exp = export.export(_testing_multi_platform_func,
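The calling convention at the end of that hunk reflects a general recipe: wrap `export.call_exported` in `pjit` and replicate the arguments over a mesh to pin the device context. A minimal sketch under the same era's API (my own; it uses a single-device mesh because the primal below is exported without shardings, and `P()` replication stands in for the test's `None` spec):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pjit
from jax.experimental.export import export
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def f(x):
  return jnp.sin(x.T)

x = np.arange(200, dtype=np.float32).reshape((10, 20))
exp = export.export(f)(x)  # no shardings, no mesh: serialized for 1 device

# Replicate the argument over a 1-device mesh to set the device context.
call_mesh = Mesh(np.array(jax.devices()[:1]), ("e",))
y = pjit.pjit(export.call_exported(exp),
              in_shardings=(NamedSharding(call_mesh, P()),))(x)
np.testing.assert_allclose(y, np.sin(x.T), rtol=1e-6)
```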