[jax2tf] First step to enable multi-platform native lowering

Enable experiments with jax2tf native serialization for
multiple platforms. This feature is not yet fully functional,
but we need this change to enable further testing.

Clean up some of the places that are specific to single-platform
serialization, e.g., `lowering_platform`, and generalize them to
multiple platforms (`lowering_platforms`).
George Necula 2023-10-13 10:30:11 -07:00
parent 110b8d7484
commit b65c1b293b
5 changed files with 114 additions and 32 deletions
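
For orientation, a minimal sketch of the generalized export API after this
change, assembled only from the calls visible in this diff (the function and
shapes are illustrative, not from the commit):

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.export import export

f = jax.jit(jnp.sin)

# Single platform, now spelled as a 1-tuple:
exp_single = export.export(f, lowering_platforms=("cpu",))(
    jax.ShapeDtypeStruct((4,), np.float32))
assert exp_single.lowering_platforms == ("cpu",)

# Multiple platforms: the lowered module carries all of the listed lowerings
# and dispatches on a platform-index argument at run time.
exp_multi = export.export(f, lowering_platforms=("cpu", "cuda", "tpu"))(
    jax.ShapeDtypeStruct((4,), np.float32))
assert exp_multi.lowering_platforms == ("cpu", "cuda", "tpu")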


@@ -229,7 +229,6 @@ class Exported:
   in_shardings: tuple[Sharding, ...]
   out_shardings: tuple[Sharding, ...]
-  lowering_platform: str  # For backwards compatibility
   lowering_platforms: tuple[str, ...]
   disabled_checks: Sequence[DisabledSafetyCheck]
@@ -357,6 +356,7 @@ def poly_specs(
 def export(fun_jax: Callable,
            *,
+           # TODO(necula): remove this kwarg
            lowering_platform: Optional[str] = None,
            lowering_platforms: Optional[Sequence[str]] = None,
            disabled_checks: Sequence[DisabledSafetyCheck] = (),
@@ -365,9 +365,9 @@ def export(fun_jax: Callable,
   Args:
     fun_jax: the function to lower and serialize.
-    lowering_platform: one of 'tpu', 'cpu', 'cuda', 'rocm'. If None, then use
-      the default JAX backend.
-    lowering_platforms: DO NOT USE (NOT YET FUNCTIONAL).
+    lowering_platform: DO NOT USE, FOR BACKWARDS COMPATIBILITY ONLY. Use
+      `lowering_platforms`.
+    lowering_platforms: Optional sequence containing a subset of 'tpu', 'cpu',
+      'cuda', 'rocm'. If more than one platform is specified, then
+      the lowered code takes an argument specifying the platform.
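
To make the docstring concrete, here is a runnable Python model (my sketch,
not code from this commit) of the dispatch that a multi-platform lowered
module performs; the added leading argument selects among the per-platform
lowerings, in the order given in lowering_platforms:

lowering_platforms = ("cpu", "cuda")

def _cpu_body(x):   # stand-ins for the per-platform lowerings
  return x + 2.0

def _cuda_body(x):
  return x + 4.0

_bodies = (_cpu_body, _cuda_body)  # same order as lowering_platforms

def lowered_module(platform_index, x):
  # The extra argument that exists only when len(lowering_platforms) > 1.
  return _bodies[platform_index](x)

print(lowered_module(lowering_platforms.index("cuda"), 0.42))  # 4.42
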
@@ -408,11 +408,10 @@ def export(fun_jax: Callable,
       wrapped_fun_jax = fun_jax  # type: ignore
       allow_non_replicated_sharding = True
-    nonlocal lowering_platforms
     if lowering_platforms is not None:
-      lowering_platforms = tuple(lowering_platforms)
+      actual_lowering_platforms = tuple(lowering_platforms)
     else:
-      lowering_platforms = (lowering_platform or default_lowering_platform(),)
+      actual_lowering_platforms = (lowering_platform or default_lowering_platform(),)
     # Do not include shape assertions if the version is < 7.
     enable_shape_assertions = (
@@ -424,7 +423,7 @@ def export(fun_jax: Callable,
     lowered = wrapped_fun_jax.lower(
         *args_specs, **kwargs_specs,
         _experimental_lowering_parameters=mlir.LoweringParameters(
-            platforms=lowering_platforms,
+            platforms=actual_lowering_platforms,
         ))
     lowering = lowered._lowering  # type: ignore
@@ -467,7 +466,7 @@ def export(fun_jax: Callable,
     if logging.vlog_is_on(3):
       mlir_module_text = mlir.module_to_string(mlir_module)
       logmsg = (f"version={version} "
-                f"lowering_platforms={lowering_platforms} "
+                f"lowering_platforms={actual_lowering_platforms} "
                 f"disabled_checks={disabled_checks}")
       logging.info("Lowered JAX module: %s\n", logmsg)
       for l in mlir_module_text.splitlines():
@@ -485,8 +484,7 @@ def export(fun_jax: Callable,
       out_avals=tuple(out_avals_flat),
       in_shardings=lowering.compile_args["in_shardings"],
       out_shardings=lowering.compile_args["out_shardings"],
-      lowering_platform=lowering_platforms[0],  # TODO: remove
-      lowering_platforms=lowering_platforms,
+      lowering_platforms=actual_lowering_platforms,
       disabled_checks=tuple(disabled_checks),
       mlir_module_serialized=mlir_module_serialized,
       module_kept_var_idx=module_kept_var_idx,
@@ -932,7 +930,7 @@ def _export_native_vjp(primal_fun, primal: Exported) -> Exported:
                          out_shardings=primal.out_shardings,
                          apply_jit=True)
   return export(fun_vjp_jax,
-                lowering_platform=primal.lowering_platform,
+                lowering_platforms=primal.lowering_platforms,
                 disabled_checks=primal.disabled_checks)(*vjp_in_avals)

 ### Importing


@@ -234,7 +234,7 @@ def convert(fun_jax: Callable,
             with_gradient: bool = True,
             enable_xla: bool = True,
             native_serialization: Union[bool, _DefaultNativeSerialization] = DEFAULT_NATIVE_SERIALIZATION,
-            native_serialization_platforms: Sequence[str] = (),
+            native_serialization_platforms: Optional[Sequence[str]] = None,
             native_serialization_disabled_checks: Sequence[DisabledSafetyCheck] = (),
             ) -> Callable:
   """Allows calling a JAX function from a TensorFlow program.
@@ -307,7 +307,7 @@ def convert(fun_jax: Callable,
       `native_serialization`, specify the platform(s)
       for which to lower the code. Must be a tuple of
       strings, including a subset of: 'cpu', 'cuda', 'rocm', 'tpu'.
-      The default (empty tuple), specifies the JAX default
+      The default (`None`), specifies the JAX default
       backend on the machine where the lowering is done.
     native_serialization_disabled_checks: In conjunction with
       `native_serialization`, disable the specified safety checks.
@@ -342,9 +342,6 @@ def convert(fun_jax: Callable,
           "containing a subset of {'cpu', 'cuda', 'rocm', 'tpu'}. "
           f"Got: {native_serialization_platforms}")
     native_serialization_platforms = tuple(native_serialization_platforms)
-    if len(native_serialization_platforms) > 1:
-      raise NotImplementedError(
-          "native_serialization_platforms is not yet implemented for multiple platforms")

   api.check_callable(fun_jax)
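
With the NotImplementedError above removed, convert() now accepts several
platforms. A hedged usage sketch, mirroring the test added later in this
commit:

import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf

f_tf = jax2tf.convert(
    jnp.sin,
    native_serialization=True,
    native_serialization_platforms=("cpu", "cuda", "tpu"))
# The resulting TF function dispatches to the lowering that matches the TF
# device it runs under (e.g., inside `with tf.device(...)`).
print(f_tf(np.float32(1.0)))
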
@@ -478,7 +475,7 @@ class SerializationImpl:
 class NativeSerializationImpl(SerializationImpl):
   def __init__(self, fun_jax, *,
                args_specs, kwargs_specs,
-               native_serialization_platforms: Sequence[str],
+               native_serialization_platforms: Optional[Sequence[str]],
                native_serialization_disabled_checks: Sequence[DisabledSafetyCheck]):
     self.convert_kwargs = dict(native_serialization=True,
                                native_serialization_platforms=native_serialization_platforms,
@@ -487,10 +484,7 @@ class NativeSerializationImpl(SerializationImpl):
     self.args_specs = args_specs
     self.kwargs_specs = kwargs_specs
     self.native_serialization_disabled_checks = native_serialization_disabled_checks
-    if native_serialization_platforms:
-      self.lowering_platform: Optional[str] = native_serialization_platforms[0]
-    else:
-      self.lowering_platform = None
+    self.native_serialization_platforms = native_serialization_platforms

   def before_conversion(self):
     _prev_func_list = _thread_local_state.call_tf_concrete_function_list
@@ -502,7 +496,7 @@ class NativeSerializationImpl(SerializationImpl):
     self._restore_context = _restore_context
     self.exported = export.export(
         self.fun_jax,
-        lowering_platform=self.lowering_platform,
+        lowering_platforms=self.native_serialization_platforms,
         disabled_checks=self.native_serialization_disabled_checks
     )(*self.args_specs, **self.kwargs_specs)
@@ -850,7 +844,7 @@ def _run_exported_as_tf(args_flat_tf: Sequence[TfVal],
       ] if _thread_local_state.call_tf_concrete_function_list is not None else [],
   )
-  call_module_attrs["platforms"] = (exported.lowering_platform.upper(),)
+  call_module_attrs["platforms"] = tuple(p.upper() for p in exported.lowering_platforms)
   if version >= 6:
     call_module_attrs["disabled_checks"] = tuple(
         str(dc)
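
For illustration (the platform tuple here is assumed, not from the diff): the
XlaCallModule `platforms` attribute now lists one upper-cased name per
lowering platform instead of a single entry:

exported_lowering_platforms = ("cpu", "cuda", "tpu")  # hypothetical Exported value
platforms_attr = tuple(p.upper() for p in exported_lowering_platforms)
assert platforms_attr == ("CPU", "CUDA", "TPU")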


@@ -284,7 +284,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
     args_specs = export.poly_specs(data.inputs, polymorphic_shapes)
     exported = export.export(
         jax.jit(func),
-        lowering_platform=self.default_jax_backend(),
+        lowering_platforms=(self.default_jax_backend(),),
         disabled_checks=tuple(
             export.DisabledSafetyCheck.custom_call(target)
             for target in allow_unstable_custom_call_targets)
@@ -320,7 +320,6 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
         out_avals=tuple(out_avals),
         in_shardings=(pxla.UNSPECIFIED,) * len(in_avals),
         out_shardings=(pxla.UNSPECIFIED,) * len(out_avals),
-        lowering_platform=data.platform,
         lowering_platforms=(data.platform,),
         disabled_checks=(),
         mlir_module_serialized=data.mlir_module_serialized,


@@ -15,10 +15,13 @@
 Specific JAX primitive conversion tests are in primitives_test."""

 import collections
+from collections.abc import Sequence
 import contextlib
+import functools
 import math
 import os
 import re
+from typing import Optional
 import unittest

 from absl import logging
@@ -35,6 +38,7 @@ from jax._src import core
 from jax._src import source_info_util
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
+from jax._src.interpreters import mlir
 from jax.experimental import jax2tf
 from jax.experimental.export import export
 from jax.experimental.jax2tf.tests import tf_test_util
@@ -54,6 +58,30 @@ config.parse_flags_with_absl()

 class Jax2TfTest(tf_test_util.JaxToTfTestCase):

+  @classmethod
+  def setUpClass(cls):
+    # Pick one device from each available platform
+    cls.jax_platforms = []
+    for backend in ["cpu", "gpu", "tpu"]:
+      try:
+        devices = jax.devices(backend)
+      except RuntimeError:
+        devices = []
+      if devices:
+        cls.jax_platforms.append(devices[0].platform)
+
+    # One TF device of each device_type
+    cls.tf_devices = []
+    for tf_device in (tf.config.list_logical_devices("TPU") +
+                      tf.config.list_logical_devices("GPU") +
+                      tf.config.list_logical_devices()):
+      if tf_device.device_type == "TPU_SYSTEM":
+        continue  # A virtual device
+      if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
+        cls.tf_devices.append(tf_device)
+
+    super(Jax2TfTest, cls).setUpClass()
+
   def test_empty(self):
     f_jax = lambda x, y: x
     self.ConvertAndCompare(f_jax, 0.7, 1)
@@ -1537,7 +1565,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
       _ = func_to_convert(*args)
       exported = export.export(
           func_to_convert,
-          lowering_platform='tpu'
+          lowering_platforms=("tpu",)
       )(*(core.ShapedArray(a.shape, a.dtype) for a in args))

     if transform1 == "shard_map":
@@ -1661,6 +1689,68 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
     f_tf_nested = jax2tf.convert(f_jax_nested, native_serialization=True)
     self.assertAllClose(res, f_tf_nested(inputs))

+  def test_multi_platform(self):
+    if config.enable_x64.value:
+      self.skipTest("TODO: enable when we can handle i64 platform_index_argument")
+    # Checks that we dispatch from TF to the proper JAX platform lowering.
+
+    # A primitive for testing multi-platform lowering. Takes one argument and
+    # adds a different value to it: cpu=2., tpu=3., cuda=4., rocm=5.
+    _testing_multi_platform_p = core.Primitive("testing_multi_platform")
+    _testing_multi_platform_to_add = dict(cpu=2., tpu=3., cuda=4., rocm=5.)
+
+    @_testing_multi_platform_p.def_abstract_eval
+    def _testing_multi_platform_abstract_eval(xaval: core.AbstractValue):
+      assert xaval.dtype == np.float32  # type: ignore
+      return xaval
+
+    @_testing_multi_platform_p.def_impl
+    def _testing_multi_platform_impl(x: jax.Array) -> jax.Array:
+      to_add = _testing_multi_platform_to_add[platform]
+      return x + to_add
+
+    def _testing_multi_platform_lowering(ctx: mlir.LoweringRuleContext,
+                                         x: mlir.Value,
+                                         *,
+                                         platform: str) -> Sequence[mlir.Value]:
+      to_add = _testing_multi_platform_to_add[platform]
+      to_add_value = mlir.broadcast_in_dim(ctx,
+                                           mlir.ir_constant(np.float32(to_add)),
+                                           ctx.avals_in[0],
+                                           broadcast_dimensions=())
+      return mlir.hlo.AddOp(x, to_add_value).results

+    # Register a default rule for cuda, to test the default-platform rule selection.
+    mlir.register_lowering(_testing_multi_platform_p,
+                           functools.partial(_testing_multi_platform_lowering,
+                                             platform="cuda"))
+    for platform in ["cpu", "tpu", "rocm"]:
+      mlir.register_lowering(_testing_multi_platform_p,
+                             functools.partial(_testing_multi_platform_lowering,
+                                               platform=platform),
+                             platform=platform)
+
+    def f_jax(x):
+      return _testing_multi_platform_p.bind(x)
+
+    x = np.float32(.42)
+    f_tf = jax2tf.convert(
+        f_jax,
+        native_serialization=True,
+        native_serialization_platforms=("cpu", "cuda", "tpu"))
+    for tf_device in self.__class__.tf_devices:
+      with tf.device(tf_device):
+        res = f_tf(x)
+      logging.info(f"tf_device = {tf_device} and device_type = {tf_device.device_type}")
+      tf_device_jax_platform = dict(
+          CPU="cpu", GPU="cuda", TPU="tpu"
+      )[tf_device.device_type]
+      self.assertAllClose(
+          res,
+          x + _testing_multi_platform_to_add[tf_device_jax_platform])

 @jtu.with_config(jax_enable_custom_prng=True)
 class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):


@@ -127,7 +127,8 @@ class JaxExportTest(jtu.JaxTestCase):
       return jnp.sin(x)

     exp = export.export(my_fun)(jax.ShapeDtypeStruct((4,), dtype=np.float32))
     self.assertEqual("my_fun", exp.fun_name)
-    self.assertEqual(export.default_lowering_platform(), exp.lowering_platform)
+    self.assertEqual((export.default_lowering_platform(),),
+                     exp.lowering_platforms)
     self.assertEqual(tree_util.tree_flatten(((1,), {}))[1], exp.in_tree)
     self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.in_avals)
     self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.out_avals)
@@ -138,10 +139,10 @@ class JaxExportTest(jtu.JaxTestCase):
     def f(a_b_pair, *, a, b):
       return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))

-    exp = export.export(f, lowering_platform="cpu")((a, b), a=a, b=b)
+    exp = export.export(f, lowering_platforms=("cpu",))((a, b), a=a, b=b)
     a_aval = core.ShapedArray(a.shape, a.dtype)
     b_aval = core.ShapedArray(b.shape, b.dtype)
-    self.assertEqual(exp.lowering_platform, "cpu")
+    self.assertEqual(exp.lowering_platforms, ("cpu",))
     args = ((a, b),)
     kwargs = dict(a=a, b=b)
     self.assertEqual(exp.in_tree, tree_util.tree_flatten((args, kwargs))[1])
@@ -258,7 +259,7 @@ class JaxExportTest(jtu.JaxTestCase):
   def test_error_wrong_platform(self, platform):
     a = np.arange(4, dtype=np.float32)

-    exp_f = export.export(jnp.sin, lowering_platform=platform)(a)
+    exp_f = export.export(jnp.sin, lowering_platforms=(platform,))(a)
     if xb.canonicalize_platform(jtu.device_under_test()) == platform:
       raise unittest.SkipTest("Uninteresting scenario")
@@ -268,7 +269,7 @@ class JaxExportTest(jtu.JaxTestCase):
     # Now try with the platform check disabled
     exp_f_no_platform_check = export.export(
-        jnp.sin, lowering_platform=platform,
+        jnp.sin, lowering_platforms=(platform,),
        disabled_checks=[export.DisabledSafetyCheck.platform()])(a)
     res = export.call_exported(exp_f_no_platform_check)(a)
     self.assertAllClose(res, jnp.sin(a))