Merge pull request #15129 from gnecula:tf_grad_pjit

PiperOrigin-RevId: 518519977
jax authors 2023-03-22 03:12:11 -07:00
commit 95d5e153b6
4 changed files with 150 additions and 27 deletions


@@ -784,7 +784,7 @@ def flatten_axis_resources(what, tree, shardings, tupled_args):
axis_tree = shardings
# Because ecause we only have the `tree` treedef and not the full pytree here,
# Because we only have the `tree` treedef and not the full pytree here,
# we construct a dummy tree to compare against. Revise this in callers?
dummy_tree = tree_unflatten(tree, [PytreeLeaf()] * tree.num_leaves)
errors = prefix_errors(axis_tree, dummy_tree)
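For illustration, a small standalone sketch of the dummy-tree technique used above: given only a treedef, unflatten placeholder leaves so the resulting pytree can be matched against a prefix tree of shardings. Names such as `_Leaf` and `toy_args` are hypothetical stand-ins for `PytreeLeaf` and the caller's arguments.

import jax

class _Leaf:  # placeholder leaf, standing in for PytreeLeaf
  pass

toy_args = ({"w": 1, "b": 2}, 3)                   # an example argument pytree
treedef = jax.tree_util.tree_structure(toy_args)   # only the structure is available
dummy_tree = jax.tree_util.tree_unflatten(
    treedef, [_Leaf()] * treedef.num_leaves)       # same structure, dummy leaves
# dummy_tree can now be compared against a sharding prefix tree, e.g. with
# prefix_errors in JAX internals or jax.tree_util.tree_map over the prefix.
assert jax.tree_util.tree_structure(dummy_tree) == treedef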


@@ -430,16 +430,15 @@ def convert(fun_jax: Callable,
lowering_platform = native_serialization_platforms[0]
else:
lowering_platform = None
exported: Exported = serialize_native(
exported: Optional[Exported] = serialize_native(
fun_flat_jax, args_avals_flat,
lowering_platform=lowering_platform,
strict_checks=native_serialization_strict_checks)
def run_fun_flat_as_tf(
args_flat_tf: Sequence[TfVal]
) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]:
outs_tf, out_avals = run_exported_as_tf(
args_avals_flat, args_flat_tf, exported,
args_avals_flat, args_flat_tf, exported, # type: ignore
native_serialization_strict_checks)
return outs_tf, out_avals
else:
@@ -448,6 +447,7 @@ def convert(fun_jax: Callable,
dim_values, _ = _interpret_fun_jax(get_dim_values_jax, args_flat_tf,
args_avals_flat, name_stack)
shape_env = zip(dim_vars, dim_values) # type: ignore
exported = None
def run_fun_flat_as_tf(
args_flat_tf: Sequence[TfVal]
) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]:
@@ -477,7 +477,7 @@ def convert(fun_jax: Callable,
def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal:
outs_tf, out_avals = run_fun_flat_as_tf(args_flat_tf)
return (tuple(outs_tf),
make_custom_gradient_fn_tf(
_make_custom_gradient_fn_tf(
fun_flat_jax=fun_flat_jax,
args_flat_tf=args_flat_tf,
args_avals_flat=args_avals_flat,
@@ -485,7 +485,8 @@ def convert(fun_jax: Callable,
out_avals=out_avals,
native_serialization=native_serialization,
native_serialization_platforms=native_serialization_platforms,
native_serialization_strict_checks=native_serialization_strict_checks))
native_serialization_strict_checks=native_serialization_strict_checks,
exported_primal=exported))
out_flat_tf = converted_fun_flat_with_custom_gradient_tf(*args_flat_tf)
else:
@@ -599,17 +600,19 @@ def preprocess_arg_tf(arg_idx: int,
return arg_tf, arg_aval
# Prepare the grad_fn for tf.custom_gradient.
def make_custom_gradient_fn_tf(*,
fun_flat_jax: Callable,
args_flat_tf: Sequence[TfVal],
polymorphic_shapes_flat: Sequence[str],
args_avals_flat: Sequence[core.ShapedArray],
out_avals: Sequence[core.ShapedArray],
native_serialization: Union[str, bool],
native_serialization_platforms: Sequence[str],
native_serialization_strict_checks: bool
):
def _make_custom_gradient_fn_tf(*,
fun_flat_jax: Callable,
args_flat_tf: Sequence[TfVal],
polymorphic_shapes_flat: Sequence[str],
args_avals_flat: Sequence[core.ShapedArray],
out_avals: Sequence[core.ShapedArray],
native_serialization: Union[str, bool],
native_serialization_platforms: Sequence[str],
native_serialization_strict_checks: bool,
exported_primal: Optional["Exported"]):
"""Prepares the TF function to be used with tf.custom_gradient.
"""
def grad_fn_tf(*out_cts_flat_tf: TfVal,
variables=None):
@@ -659,6 +662,45 @@ def make_custom_gradient_fn_tf(*,
in_cts_fixed_flat_jax = tuple(map(fix_in_ct, in_cts_flat_jax, args_avals_flat))
return in_cts_fixed_flat_jax
if exported_primal is not None:
# Native lowering
all_in_shardings = [pxla._UNSPECIFIED] * len(exported_primal.in_avals)
for idx, in_s in zip(sorted(exported_primal.module_kept_var_idx),
exported_primal.in_shardings):
all_in_shardings[idx] = in_s # type: ignore
all_shardings = all_in_shardings + list(exported_primal.out_shardings)
# We cannot mix unspecified and specified shardings. Make the unspecified
# ones replicated
specified_shardings = [
s for s in all_shardings if not pxla._is_unspecified(s)]
if 0 < len(specified_shardings) < len(all_shardings):
# There are some specified, but not all
in_s = specified_shardings[0] # pjit will enforce that all have same devices
assert isinstance(in_s, sharding.XLACompatibleSharding)
replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment)
all_shardings = [
s if not pxla._is_unspecified(s) else replicated_s
for s in all_shardings]
# Since fun_vjp_jax takes two tuples of arguments we must split the in_shardings
vjp_in_args_shardings, vjp_in_out_ct_shardings = util.split_list(all_shardings,
[len(exported_primal.in_avals)])
# pjit front-end does not like all-unspecified
if all(pxla._is_unspecified(s) for s in vjp_in_args_shardings):
vjp_in_args_shardings = pxla._UNSPECIFIED
else:
vjp_in_args_shardings = tuple(vjp_in_args_shardings)
if all(pxla._is_unspecified(s) for s in vjp_in_out_ct_shardings):
vjp_in_out_ct_shardings = pxla._UNSPECIFIED
else:
vjp_in_out_ct_shardings = tuple(vjp_in_out_ct_shardings)
if pxla._is_unspecified(vjp_in_args_shardings) and pxla._is_unspecified(vjp_in_out_ct_shardings):
vjp_in_shardings = pxla._UNSPECIFIED
else:
vjp_in_shardings = (vjp_in_args_shardings, vjp_in_out_ct_shardings)
fun_vjp_jax = pjit.pjit(fun_vjp_jax,
in_shardings=vjp_in_shardings,
out_shardings=vjp_in_args_shardings)
# TODO: enable higher-order gradients
with tf.name_scope("jax2tf_vjp"):
in_cts_flat = convert(
@@ -707,15 +749,16 @@ class Exported:
"""Represents a lowered and serialized module."""
in_avals: Sequence[core.ShapedArray]
out_avals: Sequence[core.ShapedArray]
in_shardings: Optional[Sequence[Any]]
out_shardings: Optional[Sequence[Any]]
# The in_shardings correspond only to the arguments in module_kept_var_idx
in_shardings: Sequence[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]]
out_shardings: Sequence[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]]
lowering_platform: str # One of "tpu", "cpu", "cuda", "rocm"
mlir_module: mlir.ir.Module
mlir_module_serialized: bytes # VHLO bytecode format
xla_call_module_version: int # Follows the versions of XlaCallModule
module_kept_var_idx: Sequence[bool] # Specifies if an argument is kept in the
# lowering. As long as `out_avals`.
module_kept_var_idx: Sequence[int] # The sorted indices of the arguments that
# are kept in the lowering.
dim_args_spec: Sequence[str]
def serialize_native(fun_jax: Callable,
@@ -767,7 +810,7 @@ def serialize_native(fun_jax: Callable,
raise NotImplementedError("host_callbacks are not yet implemented for the jax2tf native lowering")
if "kept_var_idx" in lowered.compile_args:
module_kept_var_idx = lowered.compile_args["kept_var_idx"]
module_kept_var_idx = tuple(sorted(lowered.compile_args["kept_var_idx"]))
else:
# For pmap
module_kept_var_idx = tuple(range(len(args_avals)))
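For context, a self-contained sketch (invented names and values, not code from this PR) of how a sorted `module_kept_var_idx` and the per-kept-argument `in_shardings` can be scattered back over the full argument list, which is what the custom-gradient path above does before building the VJP shardings:

from typing import Any, Sequence

UNSPECIFIED = object()  # stand-in for pxla._UNSPECIFIED

def scatter_kept_shardings(num_args: int,
                           module_kept_var_idx: Sequence[int],
                           in_shardings: Sequence[Any]) -> list:
  # Place each kept argument's sharding at its original position; the
  # arguments dropped by the lowering keep the "unspecified" marker.
  all_in_shardings = [UNSPECIFIED] * num_args
  for idx, s in zip(sorted(module_kept_var_idx), in_shardings):
    all_in_shardings[idx] = s
  return all_in_shardings

# For 4 primal arguments of which the lowering kept indices (0, 2),
# positions 1 and 3 remain UNSPECIFIED:
print(scatter_kept_shardings(4, (0, 2), ("sharding_a", "sharding_c")))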
@@ -837,8 +880,8 @@ def serialize_native(fun_jax: Callable,
return Exported(
in_avals=args_avals,
out_avals=out_avals,
in_shardings=lowered.compile_args.get("in_shardings"),
out_shardings=lowered.compile_args.get("out_shardings"),
in_shardings=lowered.compile_args["in_shardings"],
out_shardings=lowered.compile_args["out_shardings"],
lowering_platform=lowering_platform or default_jax_backend(),
mlir_module=mlir_module,
mlir_module_serialized=mlir_module_serialized,
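To illustrate the pjit pattern used above in `_make_custom_gradient_fn_tf`: because the VJP function takes two tuples (the primal arguments and the output cotangents), its `in_shardings` is either unspecified or a pair of pytree prefixes mirroring that structure, and its `out_shardings` reuses the primal-argument shardings. The following is a minimal standalone sketch on a one-device mesh; `toy_vjp` and the shapes are invented for illustration.

import jax
import jax.numpy as jnp
from jax.experimental import pjit
from jax.sharding import Mesh, PartitionSpec as P

def toy_vjp(args, out_cts):
  (x,), (ct,) = args, out_cts
  return (jnp.cos(x) * ct,)   # cotangent w.r.t. x for f(x) = sin(x)

mesh = Mesh(jax.devices()[:1], ("x",))
with mesh:
  toy_vjp_pjit = pjit.pjit(
      toy_vjp,
      # One entry per positional argument; each entry is a pytree prefix of
      # the corresponding tuple, mirroring (args, out_cts).
      in_shardings=((P("x"),), (P("x"),)),
      # The VJP returns cotangents for the primal arguments, so out_shardings
      # mirrors the primal-argument shardings.
      out_shardings=(P("x"),))
  x = jnp.arange(4.0)
  (x_ct,) = toy_vjp_pjit((x,), (jnp.ones_like(x),))
  print(x_ct)  # equals jnp.cos(x)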


@@ -1101,7 +1101,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
"Cannot serialize code with custom calls whose targets .*"):
jax2tf.convert(
lambda a, b: jax.lax.linalg.triangular_solve(a, b, left_side=True),
experimental_native_lowering=True)(a, b)
native_serialization=True)(a, b)
def test_op_metadata_simple(self):
self.skipTest("include_xla_op_metadata not yet enabled")


@@ -358,6 +358,58 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
(r"custom_call_target.*Sharding", 2 + count_inner_sharding)
])
@parameterized.named_parameters(
dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}",
in_shardings=in_shardings, out_shardings=out_shardings)
for in_shardings in ("missing", None, "P")
for out_shardings in ("missing", None, "P")
)
@jtu.with_mesh([("x", 2)])
def test_grad_pjit(self, in_shardings="missing", out_shardings=None):
def f_jax(x): # x: f32[10,20] -> f32[20,10]
return jnp.sin(x.T)
pjit_kwargs = {}
if in_shardings != "missing":
pjit_kwargs["in_shardings"] = (P(None, "x") if in_shardings == "P" else None)
if out_shardings != "missing":
pjit_kwargs["out_shardings"] = (P("x", None) if out_shardings == "P" else None)
f_jax = pjit.pjit(f_jax, **pjit_kwargs)
x_shape = (10, 20)
x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
def f_grad_tf(x_v, res_ct):
with tf.GradientTape(persistent=True) as tape:
tape.watch(x_v)
res_tf = jax2tf.convert(f_jax)(x_v)
return tape.gradient(res_tf, x_v, output_gradients=res_ct)
# Annotation count for the primal input and the grad output
count_in_P = self.GEQ(2) if in_shardings == "P" else 0
if config.jax2tf_default_native_serialization:
# With native serialization even unspecified in_shardings turn into replicated
count_in_replicated = self.GEQ(2) if in_shardings in [None, "missing"] else 0
else:
count_in_replicated = self.GEQ(2) if in_shardings is None else 0
# Annotation count for the cotangent input
count_out_P = self.GEQ(1) if out_shardings == "P" else 0
if config.jax2tf_default_native_serialization:
# With native serialization even unspecified out_shardings turn into replicated
count_out_replicated = self.GEQ(1) if out_shardings in [None, "missing"] else 0
else:
count_out_replicated = self.GEQ(1) if out_shardings is None else 0
self.check_sharding(f_grad_tf, [x, x.T],
checks=[
# The input primal argument, and the output grad
(r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", count_in_P),
(r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated", count_in_replicated),
# The primal result, and the input cotangent
(r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", count_out_P),
(r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated", count_out_replicated),
])
@parameterized.named_parameters(
dict(testcase_name=f"_kind={kind}_in_shardings={in_shardings}_out_shardings={out_shardings}",
kind=kind, in_shardings=in_shardings, out_shardings=out_shardings)
@@ -460,8 +512,8 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
bshape = (2, 7)
b = np.arange(np.prod(bshape), dtype=np.float32).reshape(bshape)
# f32[16,8,5], f32[2,7] -> f32[16,8,10], f32[2,28]
# f_jax: f32[5], f32[7] -> f32[10], f32[28]
# f_jax: f32[16,8,5], f32[2,7] -> f32[16,8,10], f32[2,28]
# lambda ...: f32[5], f32[7] -> f32[10], f32[28]
f_jax = xmap(lambda a, b: (jnp.concatenate([a, a], axis=0) * 2.,
jnp.concatenate([b, b, b, b], axis=0) * 4.),
in_axes=({0: 'a', 1: 'b'}, ['c', ...]),
@@ -535,6 +587,34 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
(r"f32\[8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", 1),
])
def test_grad_xmap(self):
devices = np.reshape(self.devices, (1, 2))
ashape = (16, 8, 5)
a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape)
# f_jax: f32[16,8,5] -> f32[16,8,10]
# lambda ...: f32[5] -> f32[10]
f_jax = xmap(lambda a: jnp.concatenate([a, a], axis=0) * 2.,
in_axes=({0: 'a', 1: 'b'}),
out_axes={0: 'a', 1: 'b'},
axis_resources={'a': 'x', 'b': 'y'})
def f_grad_tf(a, res_ct):
with tf.GradientTape(persistent=True) as tape:
tape.watch(a)
res_tf = jax2tf.convert(f_jax, native_serialization=True)(a)
return tape.gradient(res_tf, a, output_gradients=res_ct)
with Mesh(devices, ('x', 'y')):
self.check_sharding(f_grad_tf, [a, np.concatenate([a, a], axis=2)],
checks=[
# Primal input and grad output
(r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", self.GEQ(2)),
# Input cotangent
(r"f32\[16,8,10\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", self.GEQ(1)),
])
@jtu.ignore_warning(category=UserWarning,
message="all_to_all .* are only implemented properly for TPUs and GPUs .*")
def test_shmap_all_to_all(self):