Merge pull request #18104 from gnecula:multi_jax2tf

PiperOrigin-RevId: 573951693
jax authors 2023-10-16 15:46:23 -07:00
commit 5919c1f33c
5 changed files with 114 additions and 32 deletions

View File

@@ -229,7 +229,6 @@ class Exported:
in_shardings: tuple[Sharding, ...]
out_shardings: tuple[Sharding, ...]
lowering_platform: str # For backwards compatibility
lowering_platforms: tuple[str, ...]
disabled_checks: Sequence[DisabledSafetyCheck]
@@ -357,6 +356,7 @@ def poly_specs(
def export(fun_jax: Callable,
*,
# TODO(necula): remove this kwarg
lowering_platform: Optional[str] = None,
lowering_platforms: Optional[Sequence[str]] = None,
disabled_checks: Sequence[DisabledSafetyCheck] = (),
@@ -365,9 +365,9 @@ def export(fun_jax: Callable,
Args:
fun_jax: the function to lower and serialize.
lowering_platform: one of 'tpu', 'cpu', 'cuda', 'rocm'. If None, then use
the default JAX backend.
lowering_platforms: DO NOT USE (NOT YET FUNCTIONAL).
lowering_platform: DO NOT USE, FOR BACKWARDS COMPATIBILITY ONLY. Use
`lowering_platforms`.
lowering_platforms:
Optional sequence containing a subset of 'tpu', 'cpu',
'cuda', 'rocm'. If more than one platform is specified, then
the lowered code takes an argument specifying the platform.
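
For context (not part of this diff): a minimal usage sketch of the new `lowering_platforms` argument, assuming the `jax.experimental.export` API as it appears in this change; the function, shapes, and platform choices are illustrative.

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.export import export

def f(x):
  return jnp.sin(x)

# Lower once for several platforms; the serialized module then takes an
# extra platform-index argument so the runtime can pick the right lowering.
exp = export.export(
    jax.jit(f),
    lowering_platforms=("cpu", "cuda"),
)(jax.ShapeDtypeStruct((4,), np.float32))
assert exp.lowering_platforms == ("cpu", "cuda")

Passing a single-element tuple reproduces the behavior of the now backwards-compatibility-only `lowering_platform=...` keyword.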
@@ -408,11 +408,10 @@ def export(fun_jax: Callable,
wrapped_fun_jax = fun_jax # type: ignore
allow_non_replicated_sharding = True
nonlocal lowering_platforms
if lowering_platforms is not None:
lowering_platforms = tuple(lowering_platforms)
actual_lowering_platforms = tuple(lowering_platforms)
else:
lowering_platforms = (lowering_platform or default_lowering_platform(),)
actual_lowering_platforms = (lowering_platform or default_lowering_platform(),)
# Do not include shape assertions if the version is < 7.
enable_shape_assertions = (
@@ -424,7 +423,7 @@ def export(fun_jax: Callable,
lowered = wrapped_fun_jax.lower(
*args_specs, **kwargs_specs,
_experimental_lowering_parameters=mlir.LoweringParameters(
platforms=lowering_platforms,
platforms=actual_lowering_platforms,
))
lowering = lowered._lowering # type: ignore
@@ -467,7 +466,7 @@ def export(fun_jax: Callable,
if logging.vlog_is_on(3):
mlir_module_text = mlir.module_to_string(mlir_module)
logmsg = (f"version={version} "
f"lowering_platforms={lowering_platforms} "
f"lowering_platforms={actual_lowering_platforms} "
f"disabled_checks={disabled_checks}")
logging.info("Lowered JAX module: %s\n", logmsg)
for l in mlir_module_text.splitlines():
@@ -485,8 +484,7 @@ def export(fun_jax: Callable,
out_avals=tuple(out_avals_flat),
in_shardings=lowering.compile_args["in_shardings"],
out_shardings=lowering.compile_args["out_shardings"],
lowering_platform=lowering_platforms[0], # TODO: remove
lowering_platforms=lowering_platforms,
lowering_platforms=actual_lowering_platforms,
disabled_checks=tuple(disabled_checks),
mlir_module_serialized=mlir_module_serialized,
module_kept_var_idx=module_kept_var_idx,
@@ -932,7 +930,7 @@ def _export_native_vjp(primal_fun, primal: Exported) -> Exported:
out_shardings=primal.out_shardings,
apply_jit=True)
return export(fun_vjp_jax,
lowering_platform=primal.lowering_platform,
lowering_platforms=primal.lowering_platforms,
disabled_checks=primal.disabled_checks)(*vjp_in_avals)
### Importing

View File

@@ -234,7 +234,7 @@ def convert(fun_jax: Callable,
with_gradient: bool = True,
enable_xla: bool = True,
native_serialization: Union[bool, _DefaultNativeSerialization] = DEFAULT_NATIVE_SERIALIZATION,
native_serialization_platforms: Sequence[str] = (),
native_serialization_platforms: Optional[Sequence[str]] = None,
native_serialization_disabled_checks: Sequence[DisabledSafetyCheck] = (),
) -> Callable:
"""Allows calling a JAX function from a TensorFlow program.
@@ -307,7 +307,7 @@ def convert(fun_jax: Callable,
`native_serialization`, specify the platform(s)
for which to lower the code. Must be a tuple of
strings, including a subset of: 'cpu', 'cuda', 'rocm', 'tpu'.
The default (empty tuple), specifies the JAX default
The default (`None`) specifies the JAX default
backend on the machine where the lowering is done.
native_serialization_disabled_checks: In conjunction with
`native_serialization`, disable the specified safety checks.
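
For context (not part of this diff): a hedged sketch of calling `jax2tf.convert` with more than one serialization platform, now that the single-platform restriction is removed below; the platform tuple and input value are illustrative.

import numpy as np
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def f_jax(x):
  return jnp.sin(x)

# The converted function embeds lowerings for every requested platform and
# dispatches to the one matching the device TF runs it on.
f_tf = jax2tf.convert(
    f_jax,
    native_serialization=True,
    native_serialization_platforms=("cpu", "cuda", "tpu"),
)
with tf.device("/CPU:0"):
  print(f_tf(np.float32(0.5)))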
@@ -342,9 +342,6 @@ def convert(fun_jax: Callable,
"containing a subset of {'cpu', 'cuda', 'rocm', 'tpu'}. "
f"Got: {native_serialization_platforms}")
native_serialization_platforms = tuple(native_serialization_platforms)
if len(native_serialization_platforms) > 1:
raise NotImplementedError(
"native_serialization_platforms is not yet implemented for multiple platforms")
api.check_callable(fun_jax)
@@ -478,7 +475,7 @@ class SerializationImpl:
class NativeSerializationImpl(SerializationImpl):
def __init__(self, fun_jax, *,
args_specs, kwargs_specs,
native_serialization_platforms: Sequence[str],
native_serialization_platforms: Optional[Sequence[str]],
native_serialization_disabled_checks: Sequence[DisabledSafetyCheck]):
self.convert_kwargs = dict(native_serialization=True,
native_serialization_platforms=native_serialization_platforms,
@@ -487,10 +484,7 @@ class NativeSerializationImpl(SerializationImpl):
self.args_specs = args_specs
self.kwargs_specs = kwargs_specs
self.native_serialization_disabled_checks = native_serialization_disabled_checks
if native_serialization_platforms:
self.lowering_platform: Optional[str] = native_serialization_platforms[0]
else:
self.lowering_platform = None
self.native_serialization_platforms = native_serialization_platforms
def before_conversion(self):
_prev_func_list = _thread_local_state.call_tf_concrete_function_list
@@ -502,7 +496,7 @@ class NativeSerializationImpl(SerializationImpl):
self._restore_context = _restore_context
self.exported = export.export(
self.fun_jax,
lowering_platform=self.lowering_platform,
lowering_platforms=self.native_serialization_platforms,
disabled_checks=self.native_serialization_disabled_checks
)(*self.args_specs, **self.kwargs_specs)
@@ -850,7 +844,7 @@ def _run_exported_as_tf(args_flat_tf: Sequence[TfVal],
] if _thread_local_state.call_tf_concrete_function_list is not None else [],
)
call_module_attrs["platforms"] = (exported.lowering_platform.upper(),)
call_module_attrs["platforms"] = tuple(p.upper() for p in exported.lowering_platforms)
if version >= 6:
call_module_attrs["disabled_checks"] = tuple(
str(dc)

View File

@@ -284,7 +284,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
args_specs = export.poly_specs(data.inputs, polymorphic_shapes)
exported = export.export(
jax.jit(func),
lowering_platform=self.default_jax_backend(),
lowering_platforms=(self.default_jax_backend(),),
disabled_checks=tuple(
export.DisabledSafetyCheck.custom_call(target)
for target in allow_unstable_custom_call_targets)
@@ -320,7 +320,6 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
out_avals=tuple(out_avals),
in_shardings=(pxla.UNSPECIFIED,) * len(in_avals),
out_shardings=(pxla.UNSPECIFIED,) * len(out_avals),
lowering_platform=data.platform,
lowering_platforms=(data.platform,),
disabled_checks=(),
mlir_module_serialized=data.mlir_module_serialized,

View File

@@ -15,10 +15,13 @@
Specific JAX primitive conversion tests are in primitives_test."""
import collections
from collections.abc import Sequence
import contextlib
import functools
import math
import os
import re
from typing import Optional
import unittest
from absl import logging
@@ -35,6 +38,7 @@ from jax._src import core
from jax._src import source_info_util
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.interpreters import mlir
from jax.experimental import jax2tf
from jax.experimental.export import export
from jax.experimental.jax2tf.tests import tf_test_util
@@ -54,6 +58,30 @@ config.parse_flags_with_absl()
class Jax2TfTest(tf_test_util.JaxToTfTestCase):
@classmethod
def setUpClass(cls):
# Pick one device from each available platform
cls.jax_platforms = []
for backend in ["cpu", "gpu", "tpu"]:
try:
devices = jax.devices(backend)
except RuntimeError:
devices = []
if devices:
cls.jax_platforms.append(devices[0].platform)
# One TF device of each device_type
cls.tf_devices = []
for tf_device in (tf.config.list_logical_devices("TPU") +
tf.config.list_logical_devices("GPU") +
tf.config.list_logical_devices()):
if tf_device.device_type == "TPU_SYSTEM":
continue # A virtual device
if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
cls.tf_devices.append(tf_device)
super(Jax2TfTest, cls).setUpClass()
def test_empty(self):
f_jax = lambda x, y: x
self.ConvertAndCompare(f_jax, 0.7, 1)
@@ -1537,7 +1565,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
_ = func_to_convert(*args)
exported = export.export(
func_to_convert,
lowering_platform='tpu'
lowering_platforms=("tpu",)
)(*(core.ShapedArray(a.shape, a.dtype) for a in args))
if transform1 == "shard_map":
@@ -1661,6 +1689,68 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
f_tf_nested = jax2tf.convert(f_jax_nested, native_serialization=True)
self.assertAllClose(res, f_tf_nested(inputs))
def test_multi_platform(self):
if config.enable_x64.value:
self.skipTest("TODO: enable when we can handle i64 platform_index_argument")
# Checks that we dispatch from TF to the proper JAX platform lowering.
# A primitive for testing multi-platform lowering. Takes one argument and
# adds a different value to it: cpu=2., tpu=3., cuda=4., rocm=5.
_testing_multi_platform_p = core.Primitive("testing_multi_platform")
_testing_multi_platform_to_add = dict(cpu=2., tpu=3., cuda=4., rocm=5.)
@_testing_multi_platform_p.def_abstract_eval
def _testing_multi_platform_abstract_eval(xaval: core.AbstractValue):
assert xaval.dtype == np.float32 # type: ignore
return xaval
@_testing_multi_platform_p.def_impl
def _testing_multi_platform_impl(x: jax.Array) -> jax.Array:
to_add = _testing_multi_platform_to_add[platform]
return x + to_add
def _testing_multi_platform_lowering(ctx: mlir.LoweringRuleContext,
x: mlir.Value,
*,
platform: str) -> Sequence[mlir.Value]:
to_add = _testing_multi_platform_to_add[platform]
to_add_value = mlir.broadcast_in_dim(ctx,
mlir.ir_constant(
np.float32(to_add)),
ctx.avals_in[0],
broadcast_dimensions=())
return mlir.hlo.AddOp(x, to_add_value).results
# Register a default rule for cuda, to test the default-platform rule selection.
mlir.register_lowering(_testing_multi_platform_p,
functools.partial(_testing_multi_platform_lowering,
platform="cuda"))
for platform in ["cpu", "tpu", "rocm"]:
mlir.register_lowering(_testing_multi_platform_p,
functools.partial(
_testing_multi_platform_lowering,
platform=platform),
platform=platform)
def f_jax(x):
return _testing_multi_platform_p.bind(x)
x = np.float32(.42)
f_tf = jax2tf.convert(
f_jax,
native_serialization=True,
native_serialization_platforms=("cpu", "cuda", "tpu"))
for tf_device in self.__class__.tf_devices:
with tf.device(tf_device):
res = f_tf(x)
logging.info(f"tf_device = {tf_device} and device_type = {tf_device.device_type}")
tf_device_jax_platform = dict(
CPU="cpu", GPU="cuda", TPU="tpu"
)[tf_device.device_type]
self.assertAllClose(
res,
x + _testing_multi_platform_to_add[tf_device_jax_platform])
@jtu.with_config(jax_enable_custom_prng=True)
class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):

View File

@@ -127,7 +127,8 @@ class JaxExportTest(jtu.JaxTestCase):
return jnp.sin(x)
exp = export.export(my_fun)(jax.ShapeDtypeStruct((4,), dtype=np.float32))
self.assertEqual("my_fun", exp.fun_name)
self.assertEqual(export.default_lowering_platform(), exp.lowering_platform)
self.assertEqual((export.default_lowering_platform(),),
exp.lowering_platforms)
self.assertEqual(tree_util.tree_flatten(((1,), {}))[1], exp.in_tree)
self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.in_avals)
self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.out_avals)
@@ -138,10 +139,10 @@ class JaxExportTest(jtu.JaxTestCase):
def f(a_b_pair, *, a, b):
return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))
exp = export.export(f, lowering_platform="cpu")((a, b), a=a, b=b)
exp = export.export(f, lowering_platforms=("cpu",))((a, b), a=a, b=b)
a_aval = core.ShapedArray(a.shape, a.dtype)
b_aval = core.ShapedArray(b.shape, b.dtype)
self.assertEqual(exp.lowering_platform, "cpu")
self.assertEqual(exp.lowering_platforms, ("cpu",))
args = ((a, b),)
kwargs = dict(a=a, b=b)
self.assertEqual(exp.in_tree, tree_util.tree_flatten((args, kwargs))[1])
@@ -258,7 +259,7 @@ class JaxExportTest(jtu.JaxTestCase):
def test_error_wrong_platform(self, platform):
a = np.arange(4, dtype=np.float32)
exp_f = export.export(jnp.sin, lowering_platform=platform)(a)
exp_f = export.export(jnp.sin, lowering_platforms=(platform,))(a)
if xb.canonicalize_platform(jtu.device_under_test()) == platform:
raise unittest.SkipTest("Uninteresting scenario")
@@ -268,7 +269,7 @@ class JaxExportTest(jtu.JaxTestCase):
# Now try with the platform check disabled
exp_f_no_platform_check = export.export(
jnp.sin, lowering_platform=platform,
jnp.sin, lowering_platforms=(platform,),
disabled_checks=[export.DisabledSafetyCheck.platform()])(a)
res = export.call_exported(exp_f_no_platform_check)(a)
self.assertAllClose(res, jnp.sin(a))