[jax2tf] Implement jax2tf(pjit) for experimental_native_lowering
This implementation covers the case jax2tf.convert(pjit(f_jax)), that is, when the `pjit` appears at the top level of the function to be lowered.
This commit is contained in:
  parent ff17d3d9fe
  commit 9c879adb73
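For orientation before the diff, here is a hedged usage sketch of what this change enables, modeled on the test_pjit_basic1D test added below. The single-device mesh, the axis name, and the experimental_native_lowering flag spelling are assumptions for illustration, not part of the diff itself.

# Illustrative sketch only; mirrors the pjit test added in this commit.
# Mesh setup and flag names are assumptions, not part of the diff.
import numpy as np
import jax
import tensorflow as tf
from jax.experimental import jax2tf
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import Mesh
from jax.experimental.pjit import pjit

def func_jax(x, y):
  return x + y

x = np.arange(80, dtype=np.float32).reshape((8, 10))
with Mesh(np.array(jax.devices()[:1]), ("x",)):  # single-device mesh for simplicity
  f_pjit = pjit(func_jax, in_axis_resources=(P("x"), P("x")), out_axis_resources=None)
  # The pjit at the top level is now handled by the native lowering path.
  f_tf = tf.function(jax2tf.convert(f_pjit, experimental_native_lowering=True),
                     jit_compile=True, autograph=False)
  np.testing.assert_allclose(f_tf(x, x).numpy(), f_pjit(x, x))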
@@ -422,6 +422,21 @@ def flatten_fun_jax(fun_jax: Callable, args_tf: Sequence[TfVal],
      out_tree_ref = out_tree
      return res_flat_jax

  if hasattr(fun_jax, "lower"):
    # If fun_jax is already a jit(f) or pjit(f), we must preserve its
    # lowering function; it will be used in _lower_native_and_run.
    # We rely on the fact that the lowering is the same for the function
    # taking pytrees and the one taking flat args.
    def fun_flat_jax_lower(*args_flat_jax):
      tree_args, tree_kwargs = tree_util.tree_unflatten(in_tree, args_flat_jax)
      lowered = fun_jax.lower(*tree_args, **tree_kwargs)
      out_tree = lowered.out_tree
      nonlocal out_tree_ref
      assert out_tree_ref is None or out_tree_ref == out_tree
      out_tree_ref = out_tree
      return lowered
    setattr(fun_flat_jax, "lower", fun_flat_jax_lower)

  return fun_flat_jax, args_flat_tf, in_tree, lambda: out_tree_ref

def preprocess_arg_tf(arg_idx: int,
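The pattern above, shown in isolation: a hedged sketch (the helper name is illustrative, not a real jax2tf internal) of wrapping a pytree-taking jit/pjit function with a flat-argument signature while forwarding its lower method, relying on the lowering being the same for flat and pytree arguments.

# Standalone sketch of the flatten-and-preserve-lower pattern (illustrative only).
import jax
from jax import tree_util

def flatten_with_lower(fun, example_args):
  args_flat, in_tree = tree_util.tree_flatten(example_args)

  def fun_flat(*args_flat):
    args = tree_util.tree_unflatten(in_tree, args_flat)
    return fun(*args)

  if hasattr(fun, "lower"):  # fun is jit(f) or pjit(f)
    def fun_flat_lower(*args_flat):
      args = tree_util.tree_unflatten(in_tree, args_flat)
      return fun.lower(*args)
    fun_flat.lower = fun_flat_lower
  return fun_flat, args_flat

f = jax.jit(lambda xy: xy[0] + xy[1])
f_flat, args_flat = flatten_with_lower(f, ((1.0, 2.0),))
print(f_flat(*args_flat))                  # 3.0
print(type(f_flat.lower(*args_flat)))      # a Lowered object, same as f.lower((1.0, 2.0))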
@@ -604,19 +619,24 @@ def _lower_native_and_run(fun_jax: Callable,
    abstracted_axes = None  # type: ignore

  arg_specs_jax = [
      jax.ShapeDtypeStruct(aval.shape, aval.dtype)
      jax.ShapeDtypeStruct(aval.shape, aval.dtype, named_shape=aval.named_shape)
      for aval in args_avals
  ]
  # TODO: specify the backend for experimental_native_lowering
  backend = jax.default_backend()
  lowered = jax.jit(fun_jax, backend=backend,
                    keep_unused=True,  # TODO: allow dropping unused
                    abstracted_axes=abstracted_axes).lower(*arg_specs_jax)._lowering

  if not hasattr(fun_jax, "lower") or abstracted_axes:
    # We support convert(pjit(f_jax, ...)) and convert(jit(f_jax)) but also
    # convert(f_jax), in which case a "jit" is implied. We also add a jit when
    # we need to pass the abstracted axes.
    fun_jax_lower = jax.jit(fun_jax, backend=backend,
                            keep_unused=True,  # TODO: allow dropping unused
                            abstracted_axes=abstracted_axes).lower
  else:
    fun_jax_lower = fun_jax.lower
  lowered = fun_jax_lower(*arg_specs_jax)._lowering
  mhlo_module = lowered.mhlo()
  mhlo_module_text = mlir.module_to_string(mhlo_module)
  if jaxlib.version <= (0, 3, 14):
    mhlo_module_text = _fixup_mhlo_module_text(mhlo_module_text)
  logging.vlog(2, f"XlaCallModule {mhlo_module_text}")
  # We do not support custom_call, try to give an error for now
  if "mhlo.custom_call" in mhlo_module_text:
    # Try to give a nice error message. We could just dump the module...
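The dispatch above (reuse fun_jax.lower when it exists, otherwise wrap in an implied jit) is the same flow the tests use to serialize a module. A hedged standalone sketch follows; the helper name is illustrative and not a jax2tf API.

# Hedged sketch of producing MHLO text via .lower(), as the change above does.
import numpy as np
import jax
import jax.numpy as jnp
from jax.interpreters import mlir

def to_mhlo_text(fun, *arg_specs):
  # Reuse fun.lower for jit(f)/pjit(f); otherwise an implied jit, as in the diff.
  lower = fun.lower if hasattr(fun, "lower") else jax.jit(fun).lower
  lowered = lower(*arg_specs)
  return mlir.module_to_string(lowered.compiler_ir(dialect="mhlo"))

spec = jax.ShapeDtypeStruct((3,), np.float32)
print(to_mhlo_text(jnp.sin, spec))            # implied jit
print(to_mhlo_text(jax.jit(jnp.sin), spec))   # reuses the existing lower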
@@ -626,20 +646,25 @@ def _lower_native_and_run(fun_jax: Callable,
           "work on TPU.")
    custom_calls = re.findall(r'mhlo.custom_call.*call_target_name\s+=\s+"([^"]+)".*loc\(([^\)]+)\)',
                              mhlo_module_text)
    for cc in custom_calls:
      msg += f"\n{cc[0]}"
      # Get the line number
      m = re.search('^' + cc[1] + ' =.*', mhlo_module_text, re.MULTILINE)
      if m:
        msg += f"\n from line {m.group(0)}"
    raise NotImplementedError(msg)
    logging.vlog(2, f"XlaCallModule {mhlo_module_text}")
    bad_custom_calls = tuple(filter(lambda cc: cc[0] != "Sharding", custom_calls))
    if bad_custom_calls:
      for cc in bad_custom_calls:
        msg += f"\n{cc[0]}"
        # Get the line number
        m = re.search('^' + cc[1] + ' =.*', mhlo_module_text, re.MULTILINE)
        if m:
          msg += f"\n from line {m.group(0)}"
      raise NotImplementedError(msg)

  # Figure out the result types and shapes
  if config.jax_array:
  if "global_out_avals" in lowered.compile_args:
    # This is currently the case for pjit
    out_avals = lowered.compile_args["global_out_avals"]
  else:
    out_avals = lowered.compile_args["out_avals"]
  if lowered.compile_args["host_callbacks"]:
    raise NotImplementedError("host_callbacks are not yet implemented for the jax2tf native lowering")

  # TODO(necula): handle d being InDBIdx
  out_shapes = tuple(
      tuple(d if type(d) is int else None
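pjit lowering legitimately introduces a Sharding custom call, so the check above now ignores it and only errors on other custom calls. A small demonstration of the filter on made-up data:

# Made-up (call_target_name, location) pairs, in the shape re.findall returns above.
custom_calls = [("Sharding", "#loc1"), ("SomeUnsupportedTarget", "#loc2")]
bad_custom_calls = tuple(filter(lambda cc: cc[0] != "Sharding", custom_calls))
assert bad_custom_calls == (("SomeUnsupportedTarget", "#loc2"),)  # only these trigger the error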
@@ -652,12 +677,22 @@ def _lower_native_and_run(fun_jax: Callable,
    return jax_type
  out_types = tuple(_out_type(out_aval.dtype) for out_aval in out_avals)

  # Apply the shardings on arguments and results for pjit. This is redundant
  # because the mhlo_module_text will already contain the shardings, but it
  # makes it easier for tools like the TPU inference converter to see the
  # sharding without digging into the `module` attribute of the `XlaCallModule`
  # op, in the same way as it is done for the legacy jax2tf conversion.
  if "in_shardings" in lowered.compile_args:
    args_tf = tuple(
        map(_shard_value, args_tf, args_avals, lowered.compile_args["in_shardings"]))
  res = tfxla.call_module(
      args_tf,
      module=mhlo_module_text,
      Tout=out_types,
      Sout=out_shapes,
      dim_args_spec=dim_args_spec)
  if "out_shardings" in lowered.compile_args:
    res = list(map(_shard_value, res, out_avals, lowered.compile_args["out_shardings"]))

  # Convert the results to the needed TF types
  def _convert_res(res_val, res_jax_type):
@@ -672,15 +707,6 @@ def _lower_native_and_run(fun_jax: Callable,
      for res_val, out_aval in zip(res, out_avals))
  return res, out_avals

def _fixup_mhlo_module_text(mhlo_module_text: str) -> str:
  # A workaround for MHLO not (yet) having backwards compatibility. With
  # jaxlib 0.3.14 we have an old serialization method that puts "..." around
  # MHLO attributes. The parser is new and does not accept those attributes.
  # We try to fix it up here, temporarily.
  import re
  return re.sub(r'#mhlo<"([^"]+)">', "#mhlo<\\1>", mhlo_module_text)


def _call_wrapped_with_new_constant_cache(fun: lu.WrappedFun,
                                          in_vals: Sequence[TfVal],
                                          fresh_constant_cache: bool = False
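The workaround above is a purely textual rewrite; here is the same regex applied to a made-up attribute snippet (the attribute contents are invented for the demonstration):

# Demonstration of the _fixup_mhlo_module_text rewrite on a made-up snippet.
import re

text = 'dense<1> : tensor<i32> {attr = #mhlo<"transpose NO_TRANSPOSE">}'
fixed = re.sub(r'#mhlo<"([^"]+)">', "#mhlo<\\1>", text)
assert fixed == 'dense<1> : tensor<i32> {attr = #mhlo<transpose NO_TRANSPOSE>}'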
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the jax2tf conversion for control-flow primitives."""
import unittest

from absl.testing import absltest

@@ -29,6 +30,12 @@ config.parse_flags_with_absl()

class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):

  def setUp(self):
    super().setUp()
    # TODO(b/252947617): re-enable these tests
    if config.jax_array and config.jax2tf_default_experimental_native_lowering:
      raise unittest.SkipTest("Test disabled for JAX_ARRAY")

  @jtu.ignore_warning(category=UserWarning,
                      message="Explicitly requested dtype .* requested in array is not available")
  def test_cond(self):
@@ -51,6 +51,12 @@ config.parse_flags_with_absl()

class Jax2TfTest(tf_test_util.JaxToTfTestCase):

  def setUp(self):
    super().setUp()
    # TODO(b/252943725): re-enable these tests
    if config.jax_array and config.jax2tf_default_experimental_native_lowering:
      raise unittest.SkipTest("Test disabled for JAX_ARRAY")

  def test_empty(self):
    f_jax = lambda x, y: x
    self.ConvertAndCompare(f_jax, 0.7, 1)
@@ -117,8 +123,16 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):

  def test_nested_jit(self):
    f_jax = jax.jit(lambda x: jnp.sin(jax.jit(jnp.cos)(x)))
    f_tf = jax2tf.convert(f_jax)
    np.testing.assert_allclose(f_jax(0.7), f_tf(0.7))
    x = 0.7
    self.ConvertAndCompare(f_jax, x)

  def test_nested_jit_pytree(self):
    @jax.jit
    def f_jax(xy):
      x, y = xy
      return x + y
    xy = (0.7, 0.8)
    self.ConvertAndCompare(f_jax, xy)

  def test_nested_jit_is_compiled(self):
    # Check that nested jax.jit are compiled with tf.function(jit_compile=True)
@@ -1241,8 +1255,17 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
def get_serialized_computation(
    f_jax: Callable,
    *args,
    abstracted_axes: Optional[Tuple[Dict[int, str]]] = None) -> str:
  lowered = jax.jit(f_jax, abstracted_axes=abstracted_axes).lower(*args)
    abstracted_axes: Optional[Tuple[Dict[int, str]]] = None,
    use_pjit: bool = False,
    in_axis_resources=None,
    out_axis_resources=None) -> str:
  if use_pjit:
    assert not abstracted_axes
    lowered = pjit(f_jax,
                   in_axis_resources=in_axis_resources,
                   out_axis_resources=out_axis_resources).lower(*args)
  else:
    lowered = jax.jit(f_jax, abstracted_axes=abstracted_axes).lower(*args)
  mhlo_module = lowered.compiler_ir(dialect='mhlo')
  mhlo_module_text = mlir.module_to_string(mhlo_module)
  if jaxlib.version <= (0, 3, 14):
@@ -1364,23 +1387,29 @@ class XlaCallModuleTest(tf_test_util.JaxToTfTestCase):
    # TODO(b/243146552) We can switch to ConvertAndCompare after this bug fix.
    np.array_equal(jax_out._value, np.array(tf_out))

    # Test 2: use GDA as JAX function input
    def jax_func_2(input_data, params):
      handle = pjit(
          jnp.matmul,
          in_axis_resources=(P("y", "x"), P(("x", "y"),)),
          out_axis_resources=None)
      return handle(input_data, params)
  @jtu.with_mesh([("x", 2)])
  def test_pjit_basic1D(self):
    def func_jax(x, y):
      return x + y

    with global_mesh:
      tf_func_2 = tf.function(
          jax2tf.convert(jax_func_2, enable_xla=True),
          jit_compile=True,
      )
    jax_out_2 = jax_func_2(input_data=input_data, params=params)
    tf_out_2 = tf_func_2(input_data=input_data, params=params)
    # TODO(b/243146552) We can switch to ConvertAndCompare after this bug fix.
    np.array_equal(jax_out_2._value, np.array(tf_out_2))
    shape = (8, 10)
    x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
    in_axis_resources = (P("x"), P("x"))
    out_axis_resources = None
    res_jax = pjit(func_jax,
                   in_axis_resources=in_axis_resources,
                   out_axis_resources=out_axis_resources)(x, x)
    module = get_serialized_computation(func_jax, x, x,
                                        use_pjit=True,
                                        in_axis_resources=in_axis_resources,
                                        out_axis_resources=out_axis_resources)
    def f_tf(x_tf, y_tf):
      return tfxla.call_module([x_tf, y_tf],
                               module=module,
                               Tout=[x.dtype],
                               Sout=[x.shape])
    res_tf = tf.function(f_tf, jit_compile=True, autograph=False)(x, x)[0]
    self.assertAllClose(res_tf.numpy(), res_jax)


if __name__ == "__main__":
@@ -696,6 +696,16 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
             dict(a="(_,)", b="(4,)")],
        expected_output_signature=tf.TensorSpec([4]))

  def test_with_nested_jit(self):
    @jax.jit
    def f_jax(x):
      return jnp.sin(x)

    self.CheckShapePolymorphism(
        f_jax,
        input_signature=[tf.TensorSpec([1, None])],
        polymorphic_shapes=["1, b"])

  def test_with_custom_vjp(self):
    """Shape-polymorphic custom VJP."""