Mirror of https://github.com/ROCm/jax.git (synced 2025-04-17 20:36:05 +00:00)

Commit 5919c1f33c: Merge pull request #18104 from gnecula:multi_jax2tf

PiperOrigin-RevId: 573951693

This change moves the export and jax2tf APIs from the single `lowering_platform`
to the tuple-valued `lowering_platforms`, so that one exported module can carry
lowerings for several platforms and dispatch at run time on a platform-index argument.
@@ -229,7 +229,6 @@ class Exported:
   in_shardings: tuple[Sharding, ...]
   out_shardings: tuple[Sharding, ...]
-  lowering_platform: str  # For backwards compatibility
   lowering_platforms: tuple[str, ...]
   disabled_checks: Sequence[DisabledSafetyCheck]
@@ -357,6 +356,7 @@ def poly_specs(

 def export(fun_jax: Callable,
            *,
+           # TODO(necula): remove this kwarg
            lowering_platform: Optional[str] = None,
            lowering_platforms: Optional[Sequence[str]] = None,
            disabled_checks: Sequence[DisabledSafetyCheck] = (),
@@ -365,9 +365,9 @@ def export(fun_jax: Callable,
   Args:
     fun_jax: the function to lower and serialize.
-    lowering_platform: one of 'tpu', 'cpu', 'cuda', 'rocm'. If None, then use
-      the default JAX backend.
-    lowering_platforms: DO NOT USE (NOT YET FUNCTIONAL).
+    lowering_platform: DO NOT USE, FOR BACKWARDS COMPATIBILITY ONLY. Use
+      `lowering_platforms`.
+    lowering_platforms:
+      Optional sequence containing a subset of 'tpu', 'cpu',
+      'cuda', 'rocm'. If more than one platform is specified, then
+      the lowered code takes an argument specifying the platform.
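For context, a minimal sketch of the calling convention this docstring describes (illustrative only; it assumes a JAX build where jax.experimental.export is available and where lowerings for the requested platforms can be produced on the local machine):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.experimental.export import export

    # Lower once for several platforms; the serialized module then takes a
    # platform-index argument that selects the matching lowering at run time.
    exp = export.export(
        jnp.sin,
        lowering_platforms=("cpu", "cuda"),
    )(jax.ShapeDtypeStruct((4,), np.float32))
    print(exp.lowering_platforms)  # ("cpu", "cuda")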
@@ -408,11 +408,10 @@ def export(fun_jax: Callable,
     wrapped_fun_jax = fun_jax  # type: ignore
     allow_non_replicated_sharding = True

-    nonlocal lowering_platforms
     if lowering_platforms is not None:
-      lowering_platforms = tuple(lowering_platforms)
+      actual_lowering_platforms = tuple(lowering_platforms)
     else:
-      lowering_platforms = (lowering_platform or default_lowering_platform(),)
+      actual_lowering_platforms = (lowering_platform or default_lowering_platform(),)

     # Do not include shape assertions if the version is < 7.
     enable_shape_assertions = (
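The branch above is easier to see in isolation; the following is an illustrative restatement of the resolution rule (plain Python, not the library code itself):

    from typing import Optional, Sequence

    def resolve_platforms(lowering_platforms: Optional[Sequence[str]],
                          lowering_platform: Optional[str],
                          default_platform: str) -> tuple:
        # An explicit multi-platform request wins; otherwise fall back to the
        # legacy singular kwarg, and finally to the default backend.
        if lowering_platforms is not None:
            return tuple(lowering_platforms)
        return (lowering_platform or default_platform,)

    assert resolve_platforms(None, None, "cpu") == ("cpu",)
    assert resolve_platforms(("cpu", "tpu"), None, "cpu") == ("cpu", "tpu")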
@@ -424,7 +423,7 @@ def export(fun_jax: Callable,
     lowered = wrapped_fun_jax.lower(
         *args_specs, **kwargs_specs,
         _experimental_lowering_parameters=mlir.LoweringParameters(
-          platforms=lowering_platforms,
+          platforms=actual_lowering_platforms,
         ))

     lowering = lowered._lowering  # type: ignore
@@ -467,7 +466,7 @@ def export(fun_jax: Callable,
     if logging.vlog_is_on(3):
       mlir_module_text = mlir.module_to_string(mlir_module)
       logmsg = (f"version={version} "
-                f"lowering_platforms={lowering_platforms} "
+                f"lowering_platforms={actual_lowering_platforms} "
                 f"disabled_checks={disabled_checks}")
       logging.info("Lowered JAX module: %s\n", logmsg)
       for l in mlir_module_text.splitlines():
@@ -485,8 +484,7 @@ def export(fun_jax: Callable,
         out_avals=tuple(out_avals_flat),
         in_shardings=lowering.compile_args["in_shardings"],
         out_shardings=lowering.compile_args["out_shardings"],
-        lowering_platform=lowering_platforms[0],  # TODO: remove
-        lowering_platforms=lowering_platforms,
+        lowering_platforms=actual_lowering_platforms,
         disabled_checks=tuple(disabled_checks),
         mlir_module_serialized=mlir_module_serialized,
         module_kept_var_idx=module_kept_var_idx,
@@ -932,7 +930,7 @@ def _export_native_vjp(primal_fun, primal: Exported) -> Exported:
                         out_shardings=primal.out_shardings,
                         apply_jit=True)
   return export(fun_vjp_jax,
-                lowering_platform=primal.lowering_platform,
+                lowering_platforms=primal.lowering_platforms,
                 disabled_checks=primal.disabled_checks)(*vjp_in_avals)

 ### Importing
@@ -234,7 +234,7 @@ def convert(fun_jax: Callable,
             with_gradient: bool = True,
             enable_xla: bool = True,
             native_serialization: Union[bool, _DefaultNativeSerialization] = DEFAULT_NATIVE_SERIALIZATION,
-            native_serialization_platforms: Sequence[str] = (),
+            native_serialization_platforms: Optional[Sequence[str]] = None,
             native_serialization_disabled_checks: Sequence[DisabledSafetyCheck] = (),
             ) -> Callable:
   """Allows calling a JAX function from a TensorFlow program.
@@ -307,7 +307,7 @@ def convert(fun_jax: Callable,
       `native_serialization`, specify the platform(s)
       for which to lower the code. Must be a tuple of
       strings, including a subset of: 'cpu', 'cuda', 'rocm', 'tpu'.
-      The default (empty tuple), specifies the JAX default
+      The default (`None`), specifies the JAX default
       backend on the machine where the lowering is done.
     native_serialization_disabled_checks: In conjunction with
       `native_serialization`, disable the specified safety checks.
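On the jax2tf side, the documented kwargs compose as in this hedged sketch (jnp.sin stands in for any JAX function; lowering for 'cuda' requires that lowering to be available on the machine doing the conversion):

    import tensorflow as tf
    import jax.numpy as jnp
    from jax.experimental import jax2tf

    # Request native serialization for two platforms; at run time the embedded
    # module dispatches to the lowering that matches the executing device.
    f_tf = jax2tf.convert(
        jnp.sin,
        native_serialization=True,
        native_serialization_platforms=("cpu", "cuda"),
    )
    print(f_tf(tf.constant(0.5, tf.float32)))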
@@ -342,9 +342,6 @@ def convert(fun_jax: Callable,
           "containing a subset of {'cpu', 'cuda', 'rocm', 'tpu'}. "
           f"Got: {native_serialization_platforms}")
     native_serialization_platforms = tuple(native_serialization_platforms)
-    if len(native_serialization_platforms) > 1:
-      raise NotImplementedError(
-        "native_serialization_platforms is not yet implemented for multiple platforms")

   api.check_callable(fun_jax)
@@ -478,7 +475,7 @@ class SerializationImpl:
 class NativeSerializationImpl(SerializationImpl):
   def __init__(self, fun_jax, *,
                args_specs, kwargs_specs,
-               native_serialization_platforms: Sequence[str],
+               native_serialization_platforms: Optional[Sequence[str]],
                native_serialization_disabled_checks: Sequence[DisabledSafetyCheck]):
     self.convert_kwargs = dict(native_serialization=True,
                                native_serialization_platforms=native_serialization_platforms,
@@ -487,10 +484,7 @@ class NativeSerializationImpl(SerializationImpl):
     self.args_specs = args_specs
     self.kwargs_specs = kwargs_specs
     self.native_serialization_disabled_checks = native_serialization_disabled_checks
-    if native_serialization_platforms:
-      self.lowering_platform: Optional[str] = native_serialization_platforms[0]
-    else:
-      self.lowering_platform = None
+    self.native_serialization_platforms = native_serialization_platforms

   def before_conversion(self):
     _prev_func_list = _thread_local_state.call_tf_concrete_function_list
@@ -502,7 +496,7 @@ class NativeSerializationImpl(SerializationImpl):
     self._restore_context = _restore_context
     self.exported = export.export(
         self.fun_jax,
-        lowering_platform=self.lowering_platform,
+        lowering_platforms=self.native_serialization_platforms,
         disabled_checks=self.native_serialization_disabled_checks
     )(*self.args_specs, **self.kwargs_specs)
@@ -850,7 +844,7 @@ def _run_exported_as_tf(args_flat_tf: Sequence[TfVal],
       ] if _thread_local_state.call_tf_concrete_function_list is not None else [],
   )

-  call_module_attrs["platforms"] = (exported.lowering_platform.upper(),)
+  call_module_attrs["platforms"] = tuple(p.upper() for p in exported.lowering_platforms)
   if version >= 6:
     call_module_attrs["disabled_checks"] = tuple(
         str(dc)
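The "platforms" attribute is just a tuple of upper-cased platform names derived from the exported module, as this small illustration shows (plain Python, not the TF op itself):

    lowering_platforms = ("cpu", "cuda", "tpu")
    platforms_attr = tuple(p.upper() for p in lowering_platforms)
    assert platforms_attr == ("CPU", "CUDA", "TPU")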
@@ -284,7 +284,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
     args_specs = export.poly_specs(data.inputs, polymorphic_shapes)
     exported = export.export(
         jax.jit(func),
-        lowering_platform=self.default_jax_backend(),
+        lowering_platforms=(self.default_jax_backend(),),
         disabled_checks=tuple(
             export.DisabledSafetyCheck.custom_call(target)
             for target in allow_unstable_custom_call_targets)
@@ -320,7 +320,6 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
         out_avals=tuple(out_avals),
         in_shardings=(pxla.UNSPECIFIED,) * len(in_avals),
         out_shardings=(pxla.UNSPECIFIED,) * len(out_avals),
-        lowering_platform=data.platform,
         lowering_platforms=(data.platform,),
         disabled_checks=(),
         mlir_module_serialized=data.mlir_module_serialized,
@@ -15,10 +15,13 @@

 Specific JAX primitive conversion tests are in primitives_test."""
 import collections
+from collections.abc import Sequence
 import contextlib
+import functools
 import math
 import os
 import re
 from typing import Optional
 import unittest

 from absl import logging
@@ -35,6 +38,7 @@ from jax._src import core
 from jax._src import source_info_util
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
+from jax._src.interpreters import mlir
 from jax.experimental import jax2tf
 from jax.experimental.export import export
 from jax.experimental.jax2tf.tests import tf_test_util
@@ -54,6 +58,30 @@ config.parse_flags_with_absl()

 class Jax2TfTest(tf_test_util.JaxToTfTestCase):

+  @classmethod
+  def setUpClass(cls):
+    # Pick one device from each available platform
+    cls.jax_platforms = []
+    for backend in ["cpu", "gpu", "tpu"]:
+      try:
+        devices = jax.devices(backend)
+      except RuntimeError:
+        devices = []
+      if devices:
+        cls.jax_platforms.append(devices[0].platform)
+
+    # One TF device of each device_type
+    cls.tf_devices = []
+    for tf_device in (tf.config.list_logical_devices("TPU") +
+                      tf.config.list_logical_devices("GPU") +
+                      tf.config.list_logical_devices()):
+      if tf_device.device_type == "TPU_SYSTEM":
+        continue  # A virtual device
+      if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
+        cls.tf_devices.append(tf_device)
+
+    super(Jax2TfTest, cls).setUpClass()
+
   def test_empty(self):
     f_jax = lambda x, y: x
     self.ConvertAndCompare(f_jax, 0.7, 1)
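The discovery pattern in setUpClass relies on jax.devices(backend) raising RuntimeError when a backend is absent or fails to initialize; a standalone sketch of the same idea:

    import jax

    jax_platforms = []
    for backend in ["cpu", "gpu", "tpu"]:
      try:
        devices = jax.devices(backend)
      except RuntimeError:
        devices = []  # backend not present on this machine
      if devices:
        jax_platforms.append(devices[0].platform)
    print(jax_platforms)  # e.g. ["cpu"] on a CPU-only machine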
@@ -1537,7 +1565,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
       _ = func_to_convert(*args)
       exported = export.export(
           func_to_convert,
-          lowering_platform='tpu'
+          lowering_platforms=("tpu",)
       )(*(core.ShapedArray(a.shape, a.dtype) for a in args))

       if transform1 == "shard_map":
@@ -1661,6 +1689,68 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
     f_tf_nested = jax2tf.convert(f_jax_nested, native_serialization=True)
     self.assertAllClose(res, f_tf_nested(inputs))

+  def test_multi_platform(self):
+    if config.enable_x64.value:
+      self.skipTest("TODO: enable when we can handle i64 platform_index_argument")
+    # Checks that we dispatch from TF to the proper JAX platform lowering.
+
+    # A primitive for testing multi-platform lowering. Takes one argument and
+    # adds a different value to it: cpu=2., tpu=3., cuda=4., rocm=5.
+    _testing_multi_platform_p = core.Primitive("testing_multi_platform")
+    _testing_multi_platform_to_add = dict(cpu=2., tpu=3., cuda=4., rocm=5.)
+
+    @_testing_multi_platform_p.def_abstract_eval
+    def _testing_multi_platform_abstract_eval(xaval: core.AbstractValue):
+      assert xaval.dtype == np.float32  # type: ignore
+      return xaval
+
+    @_testing_multi_platform_p.def_impl
+    def _testing_multi_platform_impl(x: jax.Array) -> jax.Array:
+      to_add = _testing_multi_platform_to_add[platform]
+      return x + to_add
+
+    def _testing_multi_platform_lowering(ctx: mlir.LoweringRuleContext,
+                                         x: mlir.Value,
+                                         *,
+                                         platform: str) -> Sequence[mlir.Value]:
+      to_add = _testing_multi_platform_to_add[platform]
+      to_add_value = mlir.broadcast_in_dim(ctx,
+                                           mlir.ir_constant(
+                                               np.float32(to_add)),
+                                           ctx.avals_in[0],
+                                           broadcast_dimensions=())
+      return mlir.hlo.AddOp(x, to_add_value).results
+
+    # Register a default rule for cuda, to test the default-platform rule selection.
+    mlir.register_lowering(_testing_multi_platform_p,
+                           functools.partial(_testing_multi_platform_lowering,
+                                             platform="cuda"))
+    for platform in ["cpu", "tpu", "rocm"]:
+      mlir.register_lowering(_testing_multi_platform_p,
+                             functools.partial(
+                                 _testing_multi_platform_lowering,
+                                 platform=platform),
+                             platform=platform)
+
+    def f_jax(x):
+      return _testing_multi_platform_p.bind(x)
+
+    x = np.float32(.42)
+    f_tf = jax2tf.convert(
+        f_jax,
+        native_serialization=True,
+        native_serialization_platforms=("cpu", "cuda", "tpu"))
+    for tf_device in self.__class__.tf_devices:
+      with tf.device(tf_device):
+        res = f_tf(x)
+      logging.info(f"tf_device = {tf_device} and device_type = {tf_device.device_type}")
+      tf_device_jax_platform = dict(
+          CPU="cpu", GPU="cuda", TPU="tpu"
+      )[tf_device.device_type]
+      self.assertAllClose(
+          res,
+          x + _testing_multi_platform_to_add[tf_device_jax_platform])
+

 @jtu.with_config(jax_enable_custom_prng=True)
 class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
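Conceptually, the test above exercises platform-index dispatch: the multi-platform module receives an index into the lowering platforms and applies that platform's rule. A plain-Python model of that behavior (no MLIR involved):

    platforms = ("cpu", "cuda", "tpu")
    to_add = dict(cpu=2., tpu=3., cuda=4., rocm=5.)

    def dispatch(platform_index, x):
      # The lowered module would select the branch for platforms[platform_index].
      return x + to_add[platforms[platform_index]]

    assert dispatch(platforms.index("cpu"), 1.0) == 3.0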
@@ -127,7 +127,8 @@ class JaxExportTest(jtu.JaxTestCase):
       return jnp.sin(x)
     exp = export.export(my_fun)(jax.ShapeDtypeStruct((4,), dtype=np.float32))
     self.assertEqual("my_fun", exp.fun_name)
-    self.assertEqual(export.default_lowering_platform(), exp.lowering_platform)
+    self.assertEqual((export.default_lowering_platform(),),
+                     exp.lowering_platforms)
     self.assertEqual(tree_util.tree_flatten(((1,), {}))[1], exp.in_tree)
     self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.in_avals)
     self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.out_avals)
@@ -138,10 +139,10 @@ class JaxExportTest(jtu.JaxTestCase):
     def f(a_b_pair, *, a, b):
       return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))

-    exp = export.export(f, lowering_platform="cpu")((a, b), a=a, b=b)
+    exp = export.export(f, lowering_platforms=("cpu",))((a, b), a=a, b=b)
     a_aval = core.ShapedArray(a.shape, a.dtype)
     b_aval = core.ShapedArray(b.shape, b.dtype)
-    self.assertEqual(exp.lowering_platform, "cpu")
+    self.assertEqual(exp.lowering_platforms, ("cpu",))
     args = ((a, b),)
     kwargs = dict(a=a, b=b)
     self.assertEqual(exp.in_tree, tree_util.tree_flatten((args, kwargs))[1])
@@ -258,7 +259,7 @@ class JaxExportTest(jtu.JaxTestCase):
   def test_error_wrong_platform(self, platform):
     a = np.arange(4, dtype=np.float32)

-    exp_f = export.export(jnp.sin, lowering_platform=platform)(a)
+    exp_f = export.export(jnp.sin, lowering_platforms=(platform,))(a)
     if xb.canonicalize_platform(jtu.device_under_test()) == platform:
       raise unittest.SkipTest("Uninteresting scenario")
@@ -268,7 +269,7 @@ class JaxExportTest(jtu.JaxTestCase):

     # Now try with the platform check disabled
     exp_f_no_platform_check = export.export(
-        jnp.sin, lowering_platform=platform,
+        jnp.sin, lowering_platforms=(platform,),
         disabled_checks=[export.DisabledSafetyCheck.platform()])(a)
     res = export.call_exported(exp_f_no_platform_check)(a)
     self.assertAllClose(res, jnp.sin(a))
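Finally, a hedged sketch of the pattern exercised in the test above: disabling the platform safety check lets a module lowered for one platform be invoked from another backend (useful mainly for testing):

    import numpy as np
    import jax.numpy as jnp
    from jax.experimental.export import export

    a = np.arange(4, dtype=np.float32)
    exp = export.export(
        jnp.sin,
        lowering_platforms=("cpu",),  # may differ from the local backend
        disabled_checks=[export.DisabledSafetyCheck.platform()],
    )(a)
    print(export.call_exported(exp)(a))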