[jax2tf] First step to enable multi-platform native lowering
Enable experiments with jax2tf native serialization for multiple platforms. This feature is not yet fully functional, but we need this change to enable further testing. Clean up some of the places that are specific to single-platform serialization, e.g., `lowering_platform`, and generalize them to multiple platforms (`lowering_platforms`).
parent 110b8d7484
commit b65c1b293b
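For context, a minimal sketch of the usage this change works toward, based on the tests touched below (the function `f` and the shapes are illustrative, and the multi-platform path is still experimental, so the exact API surface may change):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.experimental import jax2tf
    from jax.experimental.export import export

    def f(x):
      return jnp.sin(x)

    # Export once for several platforms via the generalized `lowering_platforms` kwarg.
    exp = export.export(f, lowering_platforms=("cpu", "tpu"))(
        jax.ShapeDtypeStruct((4,), np.float32))

    # Or request multi-platform native serialization through jax2tf.
    f_tf = jax2tf.convert(f, native_serialization=True,
                          native_serialization_platforms=("cpu", "tpu"))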
@@ -229,7 +229,6 @@ class Exported:
in_shardings: tuple[Sharding, ...]
out_shardings: tuple[Sharding, ...]
- lowering_platform: str # For backwards compatibility
lowering_platforms: tuple[str, ...]
disabled_checks: Sequence[DisabledSafetyCheck]
@@ -357,6 +356,7 @@ def poly_specs(

def export(fun_jax: Callable,
*,
# TODO(necula): remove this kwarg
lowering_platform: Optional[str] = None,
lowering_platforms: Optional[Sequence[str]] = None,
disabled_checks: Sequence[DisabledSafetyCheck] = (),
@@ -365,9 +365,9 @@ def export(fun_jax: Callable,
Args:
fun_jax: the function to lower and serialize.
- lowering_platform: one of 'tpu', 'cpu', 'cuda', 'rocm'. If None, then use
- the default JAX backend.
- lowering_platforms: DO NOT USE (NOT YET FUNCTIONAL).
+ lowering_platform: DO NOT USE, FOR BACKWARDS COMPATIBILITY ONLY. Use
+ `lowering_platforms`.
+ lowering_platforms:
Optional sequence containing a subset of 'tpu', 'cpu',
'cuda', 'rocm'. If more than one platform is specified, then
the lowered code takes an argument specifying the platform.
@@ -408,11 +408,10 @@ def export(fun_jax: Callable,
wrapped_fun_jax = fun_jax # type: ignore
allow_non_replicated_sharding = True

- nonlocal lowering_platforms
if lowering_platforms is not None:
- lowering_platforms = tuple(lowering_platforms)
+ actual_lowering_platforms = tuple(lowering_platforms)
else:
- lowering_platforms = (lowering_platform or default_lowering_platform(),)
+ actual_lowering_platforms = (lowering_platform or default_lowering_platform(),)

# Do not include shape assertions if the version is < 7.
enable_shape_assertions = (
@@ -424,7 +423,7 @@ def export(fun_jax: Callable,
lowered = wrapped_fun_jax.lower(
*args_specs, **kwargs_specs,
_experimental_lowering_parameters=mlir.LoweringParameters(
- platforms=lowering_platforms,
+ platforms=actual_lowering_platforms,
))

lowering = lowered._lowering # type: ignore
@@ -467,7 +466,7 @@ def export(fun_jax: Callable,
if logging.vlog_is_on(3):
mlir_module_text = mlir.module_to_string(mlir_module)
logmsg = (f"version={version} "
- f"lowering_platforms={lowering_platforms} "
+ f"lowering_platforms={actual_lowering_platforms} "
f"disabled_checks={disabled_checks}")
logging.info("Lowered JAX module: %s\n", logmsg)
for l in mlir_module_text.splitlines():
@@ -485,8 +484,7 @@ def export(fun_jax: Callable,
out_avals=tuple(out_avals_flat),
in_shardings=lowering.compile_args["in_shardings"],
out_shardings=lowering.compile_args["out_shardings"],
- lowering_platform=lowering_platforms[0], # TODO: remove
- lowering_platforms=lowering_platforms,
+ lowering_platforms=actual_lowering_platforms,
disabled_checks=tuple(disabled_checks),
mlir_module_serialized=mlir_module_serialized,
module_kept_var_idx=module_kept_var_idx,
@@ -932,7 +930,7 @@ def _export_native_vjp(primal_fun, primal: Exported) -> Exported:
out_shardings=primal.out_shardings,
apply_jit=True)
return export(fun_vjp_jax,
- lowering_platform=primal.lowering_platform,
+ lowering_platforms=primal.lowering_platforms,
disabled_checks=primal.disabled_checks)(*vjp_in_avals)

### Importing
@@ -234,7 +234,7 @@ def convert(fun_jax: Callable,
with_gradient: bool = True,
enable_xla: bool = True,
native_serialization: Union[bool, _DefaultNativeSerialization] = DEFAULT_NATIVE_SERIALIZATION,
- native_serialization_platforms: Sequence[str] = (),
+ native_serialization_platforms: Optional[Sequence[str]] = None,
native_serialization_disabled_checks: Sequence[DisabledSafetyCheck] = (),
) -> Callable:
"""Allows calling a JAX function from a TensorFlow program.
@@ -307,7 +307,7 @@ def convert(fun_jax: Callable,
`native_serialization`, specify the platform(s)
for which to lower the code. Must be a tuple of
strings, including a subset of: 'cpu', 'cuda', 'rocm', 'tpu'.
- The default (empty tuple), specifies the JAX default
+ The default (`None``), specifies the JAX default
backend on the machine where the lowering is done.
native_serialization_disabled_checks: In conjunction with
`native_serialization`, disable the specified safety checks.
@@ -342,9 +342,6 @@ def convert(fun_jax: Callable,
"containing a subset of {'cpu', 'cuda', 'rocm', 'tpu'}. "
f"Got: {native_serialization_platforms}")
native_serialization_platforms = tuple(native_serialization_platforms)
- if len(native_serialization_platforms) > 1:
- raise NotImplementedError(
- "native_serialization_platforms is not yet implemented for multiple platforms")

api.check_callable(fun_jax)
@@ -478,7 +475,7 @@ class SerializationImpl:
class NativeSerializationImpl(SerializationImpl):
def __init__(self, fun_jax, *,
args_specs, kwargs_specs,
- native_serialization_platforms: Sequence[str],
+ native_serialization_platforms: Optional[Sequence[str]],
native_serialization_disabled_checks: Sequence[DisabledSafetyCheck]):
self.convert_kwargs = dict(native_serialization=True,
native_serialization_platforms=native_serialization_platforms,
@@ -487,10 +484,7 @@ class NativeSerializationImpl(SerializationImpl):
self.args_specs = args_specs
self.kwargs_specs = kwargs_specs
self.native_serialization_disabled_checks = native_serialization_disabled_checks
- if native_serialization_platforms:
- self.lowering_platform: Optional[str] = native_serialization_platforms[0]
- else:
- self.lowering_platform = None
+ self.native_serialization_platforms = native_serialization_platforms

def before_conversion(self):
_prev_func_list = _thread_local_state.call_tf_concrete_function_list
@@ -502,7 +496,7 @@ class NativeSerializationImpl(SerializationImpl):
self._restore_context = _restore_context
self.exported = export.export(
self.fun_jax,
- lowering_platform=self.lowering_platform,
+ lowering_platforms=self.native_serialization_platforms,
disabled_checks=self.native_serialization_disabled_checks
)(*self.args_specs, **self.kwargs_specs)
@@ -850,7 +844,7 @@ def _run_exported_as_tf(args_flat_tf: Sequence[TfVal],
] if _thread_local_state.call_tf_concrete_function_list is not None else [],
)

- call_module_attrs["platforms"] = (exported.lowering_platform.upper(),)
+ call_module_attrs["platforms"] = tuple(p.upper() for p in exported.lowering_platforms)
if version >= 6:
call_module_attrs["disabled_checks"] = tuple(
str(dc)
@@ -284,7 +284,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
args_specs = export.poly_specs(data.inputs, polymorphic_shapes)
exported = export.export(
jax.jit(func),
- lowering_platform=self.default_jax_backend(),
+ lowering_platforms=(self.default_jax_backend(),),
disabled_checks=tuple(
export.DisabledSafetyCheck.custom_call(target)
for target in allow_unstable_custom_call_targets)
@@ -320,7 +320,6 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
out_avals=tuple(out_avals),
in_shardings=(pxla.UNSPECIFIED,) * len(in_avals),
out_shardings=(pxla.UNSPECIFIED,) * len(out_avals),
- lowering_platform=data.platform,
+ lowering_platforms=(data.platform,),
disabled_checks=(),
mlir_module_serialized=data.mlir_module_serialized,
@@ -15,10 +15,13 @@

Specific JAX primitive conversion tests are in primitives_test."""
import collections
from collections.abc import Sequence
import contextlib
import functools
import math
import os
import re
from typing import Optional
import unittest

from absl import logging
@@ -35,6 +38,7 @@ from jax._src import core
from jax._src import source_info_util
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.interpreters import mlir
from jax.experimental import jax2tf
from jax.experimental.export import export
from jax.experimental.jax2tf.tests import tf_test_util
@@ -54,6 +58,30 @@ config.parse_flags_with_absl()

class Jax2TfTest(tf_test_util.JaxToTfTestCase):

+ @classmethod
+ def setUpClass(cls):
+ # Pick one device from each available platform
+ cls.jax_platforms = []
+ for backend in ["cpu", "gpu", "tpu"]:
+ try:
+ devices = jax.devices(backend)
+ except RuntimeError:
+ devices = []
+ if devices:
+ cls.jax_platforms.append(devices[0].platform)
+
+ # One TF device of each device_type
+ cls.tf_devices = []
+ for tf_device in (tf.config.list_logical_devices("TPU") +
+ tf.config.list_logical_devices("GPU") +
+ tf.config.list_logical_devices()):
+ if tf_device.device_type == "TPU_SYSTEM":
+ continue # A virtual device
+ if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
+ cls.tf_devices.append(tf_device)
+
+ super(Jax2TfTest, cls).setUpClass()
+
def test_empty(self):
f_jax = lambda x, y: x
self.ConvertAndCompare(f_jax, 0.7, 1)
@@ -1537,7 +1565,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
_ = func_to_convert(*args)
exported = export.export(
func_to_convert,
- lowering_platform='tpu'
+ lowering_platforms=("tpu",)
)(*(core.ShapedArray(a.shape, a.dtype) for a in args))

if transform1 == "shard_map":
@@ -1661,6 +1689,68 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
f_tf_nested = jax2tf.convert(f_jax_nested, native_serialization=True)
self.assertAllClose(res, f_tf_nested(inputs))

+ def test_multi_platform(self):
+ if config.enable_x64.value:
+ self.skipTest("TODO: enable when we can handle i64 platform_index_argument")
+ # Checks that we dispatch from TF to the proper JAX platform lowering.
+
+ # A primitive for testing multi-platform lowering. Takes one argument and
+ # adds a different value to it: cpu=2., tpu=3., cuda=4., rocm=5.
+ _testing_multi_platform_p = core.Primitive("testing_multi_platform")
+ _testing_multi_platform_to_add = dict(cpu=2., tpu=3., cuda=4., rocm=5.)
+
+ @_testing_multi_platform_p.def_abstract_eval
+ def _testing_multi_platform_abstract_eval(xaval: core.AbstractValue):
+ assert xaval.dtype == np.float32 # type: ignore
+ return xaval
+
+ @_testing_multi_platform_p.def_impl
+ def _testing_multi_platform_impl(x: jax.Array) -> jax.Array:
+ to_add = _testing_multi_platform_to_add[platform]
+ return x + to_add
+
+ def _testing_multi_platform_lowering(ctx: mlir.LoweringRuleContext,
+ x: mlir.Value,
+ *,
+ platform: str) -> Sequence[mlir.Value]:
+ to_add = _testing_multi_platform_to_add[platform]
+ to_add_value = mlir.broadcast_in_dim(ctx,
+ mlir.ir_constant(
+ np.float32(to_add)),
+ ctx.avals_in[0],
+ broadcast_dimensions=())
+ return mlir.hlo.AddOp(x, to_add_value).results
+
+ # Register a default rule for cuda, to test the default-platform rule selection.
+ mlir.register_lowering(_testing_multi_platform_p,
+ functools.partial(_testing_multi_platform_lowering,
+ platform="cuda"))
+ for platform in ["cpu", "tpu", "rocm"]:
+ mlir.register_lowering(_testing_multi_platform_p,
+ functools.partial(
+ _testing_multi_platform_lowering,
+ platform=platform),
+ platform=platform)
+
+ def f_jax(x):
+ return _testing_multi_platform_p.bind(x)
+
+ x = np.float32(.42)
+ f_tf = jax2tf.convert(
+ f_jax,
+ native_serialization=True,
+ native_serialization_platforms=("cpu", "cuda", "tpu"))
+ for tf_device in self.__class__.tf_devices:
+ with tf.device(tf_device):
+ res = f_tf(x)
+ logging.info(f"tf_device = {tf_device} and device_type = {tf_device.device_type}")
+ tf_device_jax_platform = dict(
+ CPU="cpu", GPU="cuda", TPU="tpu"
+ )[tf_device.device_type]
+ self.assertAllClose(
+ res,
+ x + _testing_multi_platform_to_add[tf_device_jax_platform])


@jtu.with_config(jax_enable_custom_prng=True)
class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
@@ -127,7 +127,8 @@ class JaxExportTest(jtu.JaxTestCase):
return jnp.sin(x)
exp = export.export(my_fun)(jax.ShapeDtypeStruct((4,), dtype=np.float32))
self.assertEqual("my_fun", exp.fun_name)
- self.assertEqual(export.default_lowering_platform(), exp.lowering_platform)
+ self.assertEqual((export.default_lowering_platform(),),
+ exp.lowering_platforms)
self.assertEqual(tree_util.tree_flatten(((1,), {}))[1], exp.in_tree)
self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.in_avals)
self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.out_avals)
@@ -138,10 +139,10 @@ class JaxExportTest(jtu.JaxTestCase):
def f(a_b_pair, *, a, b):
return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))

- exp = export.export(f, lowering_platform="cpu")((a, b), a=a, b=b)
+ exp = export.export(f, lowering_platforms=("cpu",))((a, b), a=a, b=b)
a_aval = core.ShapedArray(a.shape, a.dtype)
b_aval = core.ShapedArray(b.shape, b.dtype)
- self.assertEqual(exp.lowering_platform, "cpu")
+ self.assertEqual(exp.lowering_platforms, ("cpu",))
args = ((a, b),)
kwargs = dict(a=a, b=b)
self.assertEqual(exp.in_tree, tree_util.tree_flatten((args, kwargs))[1])
@@ -258,7 +259,7 @@ class JaxExportTest(jtu.JaxTestCase):
def test_error_wrong_platform(self, platform):
a = np.arange(4, dtype=np.float32)

- exp_f = export.export(jnp.sin, lowering_platform=platform)(a)
+ exp_f = export.export(jnp.sin, lowering_platforms=(platform,))(a)
if xb.canonicalize_platform(jtu.device_under_test()) == platform:
raise unittest.SkipTest("Uninteresting scenario")
@@ -268,7 +269,7 @@ class JaxExportTest(jtu.JaxTestCase):

# Now try with the platform check disabled
exp_f_no_platform_check = export.export(
- jnp.sin, lowering_platform=platform,
+ jnp.sin, lowering_platforms=(platform,),
disabled_checks=[export.DisabledSafetyCheck.platform()])(a)
res = export.call_exported(exp_f_no_platform_check)(a)
self.assertAllClose(res, jnp.sin(a))