Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
Merge pull request #15129 from gnecula:tf_grad_pjit
PiperOrigin-RevId: 518519977
Commit: 95d5e153b6
@@ -784,7 +784,7 @@ def flatten_axis_resources(what, tree, shardings, tupled_args):
   axis_tree = shardings

-  # Because ecause we only have the `tree` treedef and not the full pytree here,
+  # Because we only have the `tree` treedef and not the full pytree here,
   # we construct a dummy tree to compare against. Revise this in callers?
   dummy_tree = tree_unflatten(tree, [PytreeLeaf()] * tree.num_leaves)
   errors = prefix_errors(axis_tree, dummy_tree)
@@ -430,16 +430,15 @@ def convert(fun_jax: Callable,
       lowering_platform = native_serialization_platforms[0]
     else:
       lowering_platform = None
-    exported: Exported = serialize_native(
+    exported: Optional[Exported] = serialize_native(
         fun_flat_jax, args_avals_flat,
         lowering_platform=lowering_platform,
         strict_checks=native_serialization_strict_checks)

     def run_fun_flat_as_tf(
         args_flat_tf: Sequence[TfVal]
     ) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]:
       outs_tf, out_avals = run_exported_as_tf(
-          args_avals_flat, args_flat_tf, exported,
+          args_avals_flat, args_flat_tf, exported,  # type: ignore
           native_serialization_strict_checks)
       return outs_tf, out_avals
   else:
@@ -448,6 +447,7 @@ def convert(fun_jax: Callable,
       dim_values, _ = _interpret_fun_jax(get_dim_values_jax, args_flat_tf,
                                          args_avals_flat, name_stack)
       shape_env = zip(dim_vars, dim_values)  # type: ignore
+    exported = None
     def run_fun_flat_as_tf(
         args_flat_tf: Sequence[TfVal]
     ) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]:
@@ -477,7 +477,7 @@ def convert(fun_jax: Callable,
     def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal:
       outs_tf, out_avals = run_fun_flat_as_tf(args_flat_tf)
       return (tuple(outs_tf),
-              make_custom_gradient_fn_tf(
+              _make_custom_gradient_fn_tf(
                   fun_flat_jax=fun_flat_jax,
                   args_flat_tf=args_flat_tf,
                   args_avals_flat=args_avals_flat,
@@ -485,7 +485,8 @@ def convert(fun_jax: Callable,
                   out_avals=out_avals,
                   native_serialization=native_serialization,
                   native_serialization_platforms=native_serialization_platforms,
-                  native_serialization_strict_checks=native_serialization_strict_checks))
+                  native_serialization_strict_checks=native_serialization_strict_checks,
+                  exported_primal=exported))

     out_flat_tf = converted_fun_flat_with_custom_gradient_tf(*args_flat_tf)
   else:
@@ -599,17 +600,19 @@ def preprocess_arg_tf(arg_idx: int,
   return arg_tf, arg_aval


 # Prepare the grad_fn for tf.custom_gradient.
-def make_custom_gradient_fn_tf(*,
-                               fun_flat_jax: Callable,
-                               args_flat_tf: Sequence[TfVal],
-                               polymorphic_shapes_flat: Sequence[str],
-                               args_avals_flat: Sequence[core.ShapedArray],
-                               out_avals: Sequence[core.ShapedArray],
-                               native_serialization: Union[str, bool],
-                               native_serialization_platforms: Sequence[str],
-                               native_serialization_strict_checks: bool
-                               ):
+def _make_custom_gradient_fn_tf(*,
+                                fun_flat_jax: Callable,
+                                args_flat_tf: Sequence[TfVal],
+                                polymorphic_shapes_flat: Sequence[str],
+                                args_avals_flat: Sequence[core.ShapedArray],
+                                out_avals: Sequence[core.ShapedArray],
+                                native_serialization: Union[str, bool],
+                                native_serialization_platforms: Sequence[str],
+                                native_serialization_strict_checks: bool,
+                                exported_primal: Optional["Exported"]):
   """Prepares the TF function to be used with tf.custom_gradient.

   """

   def grad_fn_tf(*out_cts_flat_tf: TfVal,
                  variables=None):
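An aside, not part of the diff: tf.custom_gradient is the TF API this helper targets. It expects the wrapped function to return a pair (outputs, grad_fn), and it passes captured tf.Variables to grad_fn through a `variables` keyword argument, which is why grad_fn_tf above declares `variables=None`. A minimal standalone sketch of that contract, using a hypothetical my_square function:

import tensorflow as tf

@tf.custom_gradient
def my_square(x):
  y = x * x
  def grad_fn(dy):
    # dy is the incoming cotangent; return the cotangent for x.
    return dy * 2.0 * x
  return y, grad_fn

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = my_square(x)
print(tape.gradient(y, x))  # tf.Tensor(6.0, shape=(), dtype=float32)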
@@ -659,6 +662,45 @@ def make_custom_gradient_fn_tf(*,
       in_cts_fixed_flat_jax = tuple(map(fix_in_ct, in_cts_flat_jax, args_avals_flat))
       return in_cts_fixed_flat_jax

+    if exported_primal is not None:
+      # Native lowering
+      all_in_shardings = [pxla._UNSPECIFIED] * len(exported_primal.in_avals)
+      for idx, in_s in zip(sorted(exported_primal.module_kept_var_idx),
+                           exported_primal.in_shardings):
+        all_in_shardings[idx] = in_s  # type: ignore
+      all_shardings = all_in_shardings + list(exported_primal.out_shardings)
+      # We cannot mix unspecified and specified shardings. Make the unspecified
+      # ones replicated
+      specified_shardings = [
+          s for s in all_shardings if not pxla._is_unspecified(s)]
+      if 0 < len(specified_shardings) < len(all_shardings):
+        # There are some specified, but not all
+        in_s = specified_shardings[0]  # pjit will enforce that all have same devices
+        assert isinstance(in_s, sharding.XLACompatibleSharding)
+        replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment)
+        all_shardings = [
+            s if not pxla._is_unspecified(s) else replicated_s
+            for s in all_shardings]
+      # Since fun_vjp_jax takes two tuples of arguments we must split the in_shardings
+      vjp_in_args_shardings, vjp_in_out_ct_shardings = util.split_list(all_shardings,
+                                                                       [len(exported_primal.in_avals)])
+      # pjit front-end does not like all-unspecified
+      if all(pxla._is_unspecified(s) for s in vjp_in_args_shardings):
+        vjp_in_args_shardings = pxla._UNSPECIFIED
+      else:
+        vjp_in_args_shardings = tuple(vjp_in_args_shardings)
+      if all(pxla._is_unspecified(s) for s in vjp_in_out_ct_shardings):
+        vjp_in_out_ct_shardings = pxla._UNSPECIFIED
+      else:
+        vjp_in_out_ct_shardings = tuple(vjp_in_out_ct_shardings)
+
+      if pxla._is_unspecified(vjp_in_args_shardings) and pxla._is_unspecified(vjp_in_args_shardings):
+        vjp_in_shardings = pxla._UNSPECIFIED
+      else:
+        vjp_in_shardings = (vjp_in_args_shardings, vjp_in_out_ct_shardings)
+      fun_vjp_jax = pjit.pjit(fun_vjp_jax,
+                              in_shardings=vjp_in_shardings,
+                              out_shardings=vjp_in_args_shardings)
     # TODO: enable higher-order gradients
     with tf.name_scope("jax2tf_vjp"):
       in_cts_flat = convert(
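Editorial note, not part of the diff: the sharding handling added above does three things in sequence: scatter the shardings of the kept module arguments back into the positions given by module_kept_var_idx, replace unspecified entries with a replicated sharding only when specified and unspecified shardings are mixed, and split the result into the two argument groups of the VJP function (primal args, output cotangents). A standalone sketch of the same idea, where UNSPECIFIED and the string "replicated" are placeholders for JAX's internal sentinel and for GSPMDSharding.get_replicated:

UNSPECIFIED = object()  # placeholder for pxla._UNSPECIFIED

def normalize_and_split(kept_idx, kept_in_shardings, out_shardings, num_in):
  # Scatter the shardings of the kept module arguments into the full argument list.
  all_in = [UNSPECIFIED] * num_in
  for idx, s in zip(sorted(kept_idx), kept_in_shardings):
    all_in[idx] = s
  all_shardings = all_in + list(out_shardings)
  # If specified and unspecified shardings are mixed, make the latter replicated.
  specified = [s for s in all_shardings if s is not UNSPECIFIED]
  if 0 < len(specified) < len(all_shardings):
    all_shardings = [s if s is not UNSPECIFIED else "replicated"
                     for s in all_shardings]
  # The VJP takes (primal args, output cotangents), so split at num_in.
  return all_shardings[:num_in], all_shardings[num_in:]

print(normalize_and_split({0, 2}, ["P('x')", "P('y')"], [UNSPECIFIED], num_in=3))
# (["P('x')", 'replicated', "P('y')"], ['replicated'])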
@@ -707,15 +749,16 @@ class Exported:
   """Represents a lowered and serialized module."""
   in_avals: Sequence[core.ShapedArray]
   out_avals: Sequence[core.ShapedArray]
-  in_shardings: Optional[Sequence[Any]]
-  out_shardings: Optional[Sequence[Any]]
+  # The in_shardings reflect only the module_ket_var_idx
+  in_shardings: Sequence[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]]
+  out_shardings: Sequence[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]]
   lowering_platform: str  # One of "tpu", "cpu", "cuda", "rocm"

   mlir_module: mlir.ir.Module
   mlir_module_serialized: bytes  # VHLO bytecode format
   xla_call_module_version: int  # Follows the versions of XlaCallModule
-  module_kept_var_idx: Sequence[bool]  # Specifies if an argument is kept in the
-                                       # lowering. As long as `out_avals`.
+  module_kept_var_idx: Sequence[int]  # Specifies if an argument is kept in the
+                                      # lowering. As long as `out_avals`.
   dim_args_spec: Sequence[str]

 def serialize_native(fun_jax: Callable,
@@ -767,7 +810,7 @@ def serialize_native(fun_jax: Callable,
     raise NotImplementedError("host_callbacks are not yet implemented for the jax2tf native lowering")

   if "kept_var_idx" in lowered.compile_args:
-    module_kept_var_idx = lowered.compile_args["kept_var_idx"]
+    module_kept_var_idx = tuple(sorted(lowered.compile_args["kept_var_idx"]))
   else:
     # For pmap
     module_kept_var_idx = tuple(range(len(args_avals)))
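Editorial note, not part of the diff: materializing the kept-variable indices as tuple(sorted(...)) gives a deterministic order, which the gradient code above relies on when it zips module_kept_var_idx positionally with the kept in_shardings (compile_args["kept_var_idx"] is, as far as I can tell, an unordered set of indices, so the illustration below treats it as one):

kept_var_idx = {3, 0, 2}                       # unordered collection of kept argument indices
module_kept_var_idx = tuple(sorted(kept_var_idx))
print(module_kept_var_idx)                     # (0, 2, 3)

# Positional pairing with the shardings of the kept arguments is now stable:
kept_in_shardings = ["s0", "s2", "s3"]
print(dict(zip(module_kept_var_idx, kept_in_shardings)))  # {0: 's0', 2: 's2', 3: 's3'}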
@@ -837,8 +880,8 @@ def serialize_native(fun_jax: Callable,
   return Exported(
       in_avals=args_avals,
       out_avals=out_avals,
-      in_shardings=lowered.compile_args.get("in_shardings"),
-      out_shardings=lowered.compile_args.get("out_shardings"),
+      in_shardings=lowered.compile_args["in_shardings"],
+      out_shardings=lowered.compile_args["out_shardings"],
       lowering_platform=lowering_platform or default_jax_backend(),
       mlir_module=mlir_module,
       mlir_module_serialized=mlir_module_serialized,
@@ -1101,7 +1101,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
         "Cannot serialize code with custom calls whose targets .*"):
       jax2tf.convert(
           lambda a, b: jax.lax.linalg.triangular_solve(a, b, left_side=True),
-          experimental_native_lowering=True)(a, b)
+          native_serialization=True)(a, b)

   def test_op_metadata_simple(self):
     self.skipTest("include_xla_op_metadata not yet enabled")
@@ -358,6 +358,58 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
         (r"custom_call_target.*Sharding", 2 + count_inner_sharding)
     ])

+
+  @parameterized.named_parameters(
+      dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}",
+           in_shardings=in_shardings, out_shardings=out_shardings)
+      for in_shardings in ("missing", None, "P")
+      for out_shardings in ("missing", None, "P")
+  )
+  @jtu.with_mesh([("x", 2)])
+  def test_grad_pjit(self, in_shardings="missing", out_shardings="None"):
+    def f_jax(x):  # x: f32[10,20] -> f32[20,10]
+      return jnp.sin(x.T)
+
+    pjit_kwargs = {}
+    if in_shardings != "missing":
+      pjit_kwargs["in_shardings"] = (P(None, "x") if in_shardings == "P" else None)
+    if out_shardings != "missing":
+      pjit_kwargs["out_shardings"] = (P("x", None) if out_shardings == "P" else None)
+    f_jax = pjit.pjit(f_jax, **pjit_kwargs)
+    x_shape = (10, 20)
+    x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
+
+    def f_grad_tf(x_v, res_ct):
+      with tf.GradientTape(persistent=True) as tape:
+        tape.watch(x_v)
+        res_tf = jax2tf.convert(f_jax)(x_v)
+        return tape.gradient(res_tf, x_v, output_gradients=res_ct)
+
+    # Annotation count for the primal input and the grad output
+    count_in_P = self.GEQ(2) if in_shardings == "P" else 0
+    if config.jax2tf_default_native_serialization:
+      # With native serialization even unspecified in_shardings turn into replicated
+      count_in_replicated = self.GEQ(2) if in_shardings in [None, "missing"] else 0
+    else:
+      count_in_replicated = self.GEQ(2) if in_shardings is None else 0
+    # Annotation count for the contangent input
+    count_out_P = self.GEQ(1) if out_shardings == "P" else 0
+    if config.jax2tf_default_native_serialization:
+      # With native serialization even unspecified in_shardings turn into replicated
+      count_out_replicated = self.GEQ(1) if out_shardings in [None, "missing"] else 0
+    else:
+      count_out_replicated = self.GEQ(1) if out_shardings is None else 0
+
+    self.check_sharding(f_grad_tf, [x, x.T],
+        checks=[
+            # The input primal argument, and the output grad
+            (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", count_in_P),
+            (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated", count_in_replicated),
+            # The primal result, and the input cotangent
+            (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", count_out_P),
+            (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated", count_out_replicated),
+        ])
+
   @parameterized.named_parameters(
       dict(testcase_name=f"_kind={kind}_in_shardings={in_shardings}_out_shardings={out_shardings}",
           kind=kind, in_shardings=in_shardings, out_shardings=out_shardings)
@@ -460,8 +512,8 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
     bshape = (2, 7)
     b = np.arange(np.prod(bshape), dtype=np.float32).reshape(bshape)

-    # f32[16,8,5], f32[2,7] -> f32[16,8,10], f32[2,28]
-    # f_jax: f32[5], f32[7] -> f32[10], f32[28]
+    # f_jax: f32[16,8,5], f32[2,7] -> f32[16,8,10], f32[2,28]
+    # lambda ...: f32[5], f32[7] -> f32[10], f32[28]
     f_jax = xmap(lambda a, b: (jnp.concatenate([a, a], axis=0) * 2.,
                                jnp.concatenate([b, b, b, b], axis=0) * 4.),
                  in_axes=({0: 'a', 1: 'b'}, ['c', ...]),
@@ -535,6 +587,34 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
         (r"f32\[8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", 1),
     ])

+  def test_grad_xmap(self):
+    devices = np.reshape(self.devices, (1, 2))
+    ashape = (16, 8, 5)
+    a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape)
+
+    # f_jax: f32[16,8,5]-> f32[16,8,10]
+    # lambda ...: f32[5]-> f32[10]
+    f_jax = xmap(lambda a: jnp.concatenate([a, a], axis=0) * 2.,
+                 in_axes=({0: 'a', 1: 'b'}),
+                 out_axes={0: 'a', 1: 'b'},
+                 axis_resources={'a': 'x', 'b': 'y'})
+
+    def f_grad_tf(a, res_ct):
+      with tf.GradientTape(persistent=True) as tape:
+        tape.watch(a)
+        res_tf = jax2tf.convert(f_jax, native_serialization=True)(a)
+        return tape.gradient(res_tf, a, output_gradients=res_ct)
+
+
+    with Mesh(devices, ('x', 'y')):
+      self.check_sharding(f_grad_tf, [a, np.concatenate([a, a], axis=2)],
+          checks=[
+              # Primal input and grad output
+              (r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", self.GEQ(2)),
+              # Input cotangent
+              (r"f32\[16,8,10\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", self.GEQ(1)),
+          ])
+
   @jtu.ignore_warning(category=UserWarning,
                       message="all_to_all .* are only implemented properly for TPUs and GPUs .*")
   def test_shmap_all_to_all(self):