[jax2tf] Fixed the conversion of a function that contains an inner pjit
In experimental_native_lowering, when we convert a function that is not a jit or pjit, we wrap it in an implicit jit. We used to specify the backend for this implicit jit, which bypassed some of the logic in jit that handles the merging of the jit and pjit code paths. We now drop the backend parameter from the implicit jit. We also moved some pjit tests from jax2tf to sharding_test, and dropped the old disabled test for the GDAs, since GDAs are going away.
This commit is contained in:
parent 7788f0cf6b
commit 21ebf9042d
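For context, a minimal sketch (not part of the commit) of the case being fixed: jax2tf.convert applied to a plain Python function, neither jit nor pjit, whose body calls an inner pjit. It mirrors the test_nested_pjit added below; the 1x1 mesh, shapes, and names are illustrative assumptions.

    import jax
    import jax.numpy as jnp
    import numpy as np
    import tensorflow as tf

    from jax.experimental import jax2tf
    from jax.experimental import pjit
    from jax.experimental.maps import Mesh
    from jax.interpreters.pxla import PartitionSpec as P

    # Illustrative 1x1 mesh so the sketch runs on a single CPU device; the
    # tests below build a (2, 1) mesh when at least two devices are present.
    devices = np.array(jax.devices()[:1]).reshape((1, 1))
    mesh = Mesh(devices, ("x", "y"))

    def func_jax(x):
      # The inner pjit that used to break the conversion: the implicit
      # jax.jit wrapper pinned a backend=, bypassing the jit/pjit merging.
      inner = pjit.pjit(jnp.sin,
                        in_axis_resources=(P("y", "x"),),
                        out_axis_resources=None)
      return inner(x)

    x = np.arange(16, dtype=np.float32).reshape(2, 8)
    with mesh:
      # jax2tf wraps func_jax in an implicit jax.jit (now without backend=)
      # before lowering.
      tf_fn = tf.function(jax2tf.convert(func_jax),
                          jit_compile=True, autograph=False)
      print(tf_fn(x))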
@@ -624,12 +624,11 @@ def _lower_native_and_run(fun_jax: Callable,
       for aval in args_avals
   ]
-  # TODO: specify the backend for experimental_native_lowering
   backend = jax.default_backend()
   if not hasattr(fun_jax, "lower") or abstracted_axes:
     # We support convert(pjit(f_jax, ...)) and convert(jit(f_jax)) but also
     # convert(f_jax), in which case a "jit" is implied. We also add a jit when
     # we need to pass the abstracted axes.
-    fun_jax_lower = jax.jit(fun_jax, backend=backend,
+    fun_jax_lower = jax.jit(fun_jax,
                             abstracted_axes=abstracted_axes).lower
   else:
     fun_jax_lower = fun_jax.lower
@@ -15,7 +15,6 @@

 Specific JAX primitive conversion tests are in primitives_test."""
 import collections
 from functools import partial
 import os
 from typing import Callable, Dict, Optional, Tuple
 import unittest
@@ -33,12 +32,11 @@ from jax._src import test_util as jtu
 import jax._src.lib.xla_bridge
 from jax.config import config
 from jax.experimental import jax2tf
-from jax.experimental.global_device_array import GlobalDeviceArray
 from jax.experimental.jax2tf.tests import tf_test_util
-from jax.experimental.pjit import FROM_GDA
-from jax.experimental.pjit import pjit
+from jax.experimental import pjit
 from jax.interpreters import mlir
 from jax.interpreters.pxla import PartitionSpec as P

 import numpy as np
 import tensorflow as tf  # type: ignore[import]
 # pylint: disable=g-direct-tensorflow-import
@@ -1231,59 +1229,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
                            include_xla_op_metadata=False
                            )

-  @jtu.with_mesh([("x", 1)])
-  def test_pjit_simple(self):
-    @partial(pjit, in_axis_resources=(P("x"), None), out_axis_resources=None)
-    def func_jax(x, y):
-      return x + y
-
-    self.ConvertAndCompare(func_jax, jnp.ones((3, 4), dtype=np.float32),
-                           jnp.ones((1, 1), dtype=np.float32))
-
-  @jtu.with_mesh([("x", 1)])
-  def test_pjit_closed_over_const(self):
-    const = jnp.full((3, 4), 7, dtype=np.float32)
-    @partial(pjit, in_axis_resources=(P("x"), None), out_axis_resources=None)
-    def func_jax(x, y):
-      return x + y * const
-
-    self.ConvertAndCompare(func_jax, jnp.ones((3, 4), dtype=np.float32),
-                           jnp.ones((1, 1), dtype=np.float32))
-
-  # TODO(necula): figure out this failure
-  @jtu.skip_on_flag("jax2tf_default_experimental_native_lowering", True)
-  def test_global_device_array(self):
-
-    def create_gda(global_shape, global_mesh, mesh_axes, global_data=None):
-      if global_data is None:
-        global_data = np.arange(np.prod(global_shape)).reshape(global_shape)
-      return GlobalDeviceArray.from_callback(
-          global_shape, global_mesh, mesh_axes,
-          lambda idx: global_data[idx]), global_data
-
-    global_mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
-    mesh_axes = P(("x", "y"))
-    params, _ = create_gda((8, 2), global_mesh, mesh_axes)
-    input_data = np.arange(16).reshape(2, 8)
-
-    # Test 1: use GDA as constants
-    def jax_func(input_data):
-      handle = pjit(
-          jnp.matmul,
-          in_axis_resources=(P("y", "x"), FROM_GDA),
-          out_axis_resources=None)
-      return handle(input_data, params)
-
-    with global_mesh:
-      tf_func = tf.function(
-          jax2tf.convert(jax_func, enable_xla=True),
-          jit_compile=True, autograph=False
-      )
-      jax_out = jax_func(input_data=input_data)
-      tf_out = tf_func(input_data=input_data)
-      # TODO(b/243146552) We can switch to ConvertAndCompare after this bug fix.
-      np.array_equal(jax_out._value, np.array(tf_out))
-
   def assertAllOperationStartWith(self, g: tf.Graph, scope_name: str):
     """Assert all operations name start with ```scope_name```.
@@ -1400,7 +1345,7 @@ def get_serialized_computation(
     out_axis_resources = None) -> str:
   if use_pjit:
     assert not abstracted_axes
-    lowered = pjit(f_jax,
+    lowered = pjit.pjit(f_jax,
                    in_axis_resources=in_axis_resources,
                    out_axis_resources=out_axis_resources).lower(*args)
   else:
@@ -1503,7 +1448,7 @@ class XlaCallModuleTest(tf_test_util.JaxToTfTestCase):
     x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
     in_axis_resources = (P("x"), P("x"))
     out_axis_resources = None
-    res_jax = pjit(
+    res_jax = pjit.pjit(
         func_jax,
         in_axis_resources=in_axis_resources,
         out_axis_resources=out_axis_resources)(x, x)
@@ -13,7 +13,7 @@
 # limitations under the License.
 """Tests for the jax2tf conversion of pjit."""

-import functools
+from functools import partial
 import logging
 import os
 import re
@@ -28,7 +28,6 @@ from jax.config import config

 from jax.experimental import jax2tf
 from jax.experimental import pjit
-from jax.experimental.jax2tf.tests import tf_test_util
 from jax.interpreters.pxla import PartitionSpec as P
 import jax.numpy as jnp
 from jax._src.lib import xla_bridge
@@ -39,6 +38,14 @@ import tensorflow as tf  # type: ignore[import]

 config.parse_flags_with_absl()

+# Must come after initializing the flags
+from jax.experimental.jax2tf.tests import tf_test_util
+from jax.experimental.jax2tf.tests.jax2tf_limitations import Jax2TfLimitation
+
+skip_eager_for_partitioning = Jax2TfLimitation(
+    "pjit functions with partitioning must be under tf.function",
+    modes="eager", skip_tf_run=True)
+
 prev_xla_flags = None
 def setUpModule():
   global prev_xla_flags
@@ -169,7 +176,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
   @jtu.with_mesh([("x", 2)])
   def test_pjit_basic1D(self):

-    @functools.partial(pjit.pjit,
+    @partial(pjit.pjit,
                        in_axis_resources=(P("x"), P("x")),
                        out_axis_resources=None)
     def jax_func(x, y):
@@ -195,7 +202,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
   @jtu.with_mesh([("x", 2)])
   def test_pjit_basic1D_variable(self):
     # The first argument is a tf.Variable
-    @functools.partial(pjit.pjit,
+    @partial(pjit.pjit,
                        in_axis_resources=(P("x"), P("x")),
                        out_axis_resources=None)
     def jax_func(x, y):
@@ -221,7 +228,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):

   @jtu.with_mesh([("x", 2), ("y", 2)])
   def test_pjit_basic2D(self):
-    @functools.partial(pjit.pjit,
+    @partial(pjit.pjit,
                        in_axis_resources=(P(None, "x", "y"), P("y")),
                        out_axis_resources=P("x"))
     def jax_func(x, y):
@@ -250,7 +257,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):

   @jtu.with_mesh([("x", 2), ("y", 2)])
   def test_pjit_TwoMeshAxisSharding(self):
-    @functools.partial(pjit.pjit,
+    @partial(pjit.pjit,
                        in_axis_resources=P(("x", "y"),),
                        out_axis_resources=P(("x", "y"),))
     def jax_func(x, y):
@@ -279,7 +286,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):

   @jtu.with_mesh([("x", 2), ("y", 1)])
   def test_pjit_ShardingConstraint(self):
-    @functools.partial(pjit.pjit, in_axis_resources=None,
+    @partial(pjit.pjit, in_axis_resources=None,
                        out_axis_resources=None)
     def jax_func(x):  # x: f32[12, 8]
       y = jnp.tile(x, (2, 1))  # y: f32[24, 8]
@@ -307,6 +314,13 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):

 class PjitTest(tf_test_util.JaxToTfTestCase):

+  def create_test_mesh(self, *axis_names):
+    """Creates a mesh with 2 axes"""
+    assert len(axis_names) == 2, axis_names
+    nr_devices = len(jax.devices())
+    mesh_shape = (2, 1) if nr_devices >= 2 else (1, 1)
+    return jtu.create_global_mesh(mesh_shape, axis_names)
+
   @jtu.with_mesh([("axis", 2)])
   def test_pjit_basic1D(self):
     def func_jax(x):
@@ -341,6 +355,51 @@ class PjitTest(tf_test_util.JaxToTfTestCase):
                            out_axis_resources=None))(x)
     self.assertAllClose(res_tf.numpy(), res_jax)

+  @jtu.with_mesh([("x", 1)])
+  def test_pjit_closed_over_const(self):
+    const = jnp.full((4, 3), 7, dtype=np.float32)
+    @partial(pjit.pjit, in_axis_resources=(P("x"), None), out_axis_resources=None)
+    def func_jax(x, y):
+      return x + y * const
+
+    with self.create_test_mesh("x", "y"):
+      self.ConvertAndCompare(func_jax, jnp.ones((4, 3), dtype=np.float32),
+                             jnp.ones((1, 1), dtype=np.float32),
+                             limitations=[skip_eager_for_partitioning])
+
+  def test_pjit_closed_over_global_device_array(self):
+    global_mesh = self.create_test_mesh("x", "y")
+
+    input1 = np.arange(16).reshape(2, 8)
+    input2_raw = np.arange(16).reshape(8, 2)
+    input2_array = jax.make_array_from_callback(input2_raw.shape,
+        jax.sharding.NamedSharding(global_mesh, P("x", "y")),
+        lambda idx: input2_raw[idx])
+    @partial(pjit.pjit,
+             in_axis_resources=(P("y", "x"),),
+             out_axis_resources=None)
+    def jax_func(input_data):
+      return jnp.matmul(input_data, input2_array)
+
+    with global_mesh:
+      self.ConvertAndCompare(jax_func, input1,
+                             limitations=[skip_eager_for_partitioning])
+
+  def test_nested_pjit(self):
+    global_mesh = self.create_test_mesh("x", "y")
+    x = np.arange(16).reshape(2, 8)
+
+    def func_jax(x):
+      # We have a pjit nested inside the function to be converted
+      inner_func = pjit.pjit(
+          jnp.sin,
+          in_axis_resources=(P("y", "x"),),
+          out_axis_resources=None)
+      return inner_func(x)
+
+    with global_mesh:
+      self.ConvertAndCompare(func_jax, x,
+                             limitations=[skip_eager_for_partitioning])
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())