[jax2tf] Fixed the conversion of a function that contains an inner pjit

In experimental_native_lowering, when we convert a function that is not
already a jit or pjit, we wrap it in an implicit jit. We used to specify
the backend when building this implicit jit, which bypassed the logic in
jit that handles the merging of the jit and pjit code paths. We now drop
the backend parameter from the implicit jit.
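For illustration only (this sketch is not part of the diff; the function and
shapes are arbitrary), the three supported entry points look like this:

    import jax
    import jax.numpy as jnp
    from jax.experimental import jax2tf

    def f(x):
      return jnp.sin(x) + 1.0

    f_tf_jit = jax2tf.convert(jax.jit(f))   # convert(jit(f_jax))
    f_tf_bare = jax2tf.convert(f)           # convert(f_jax): the jit is implied
    # convert(pjit(f_jax, ...)) is supported in the same way.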

We also moved some pjit tests from jax2tf_test to sharding_test and dropped
the old disabled test for GDAs, since GDAs are going away.
Author: George Necula
Date:   2023-01-08 10:09:48 +02:00
Parent: 7788f0cf6b
Commit: 21ebf9042d
3 changed files with 71 additions and 68 deletions

File: jax/experimental/jax2tf/jax2tf.py

@@ -624,12 +624,11 @@ def _lower_native_and_run(fun_jax: Callable,
for aval in args_avals
]
# TODO: specify the backend for experimental_native_lowering
backend = jax.default_backend()
if not hasattr(fun_jax, "lower") or abstracted_axes:
# We support convert(pjit(f_jax, ...)) and convert(jit(f_jax)) but also
# convert(f_jax), in which case a "jit" is implied. We also add a jit when
# we need to pass the abstracted axes.
fun_jax_lower = jax.jit(fun_jax, backend=backend,
fun_jax_lower = jax.jit(fun_jax,
abstracted_axes=abstracted_axes).lower
else:
fun_jax_lower = fun_jax.lower
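For context, a minimal sketch of the lowering path the implicit jit now takes
(assuming the public jax.jit(...).lower(...) API; the dialect of the printed
IR depends on the JAX version):

    import jax
    import jax.numpy as jnp

    def f(x):
      return jnp.sin(x) * 2.0

    # No backend= argument: jax.jit decides, and the jit/pjit merging logic
    # is not bypassed.
    lowered = jax.jit(f).lower(jnp.ones((3, 4), dtype=jnp.float32))
    print(lowered.as_text())  # module text that jax2tf embeds in an XlaCallModule op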

File: jax/experimental/jax2tf/tests/jax2tf_test.py

@@ -15,7 +15,6 @@
Specific JAX primitive conversion tests are in primitives_test."""
import collections
from functools import partial
import os
from typing import Callable, Dict, Optional, Tuple
import unittest
@@ -33,12 +32,11 @@ from jax._src import test_util as jtu
import jax._src.lib.xla_bridge
from jax.config import config
from jax.experimental import jax2tf
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.jax2tf.tests import tf_test_util
from jax.experimental.pjit import FROM_GDA
from jax.experimental.pjit import pjit
from jax.experimental import pjit
from jax.interpreters import mlir
from jax.interpreters.pxla import PartitionSpec as P
import numpy as np
import tensorflow as tf # type: ignore[import]
# pylint: disable=g-direct-tensorflow-import
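The import switch above (module-level "from jax.experimental import pjit"
replacing "from jax.experimental.pjit import pjit") is why the later hunks in
this file spell the call as pjit.pjit(...). A minimal sketch of the new style:

    import jax.numpy as jnp
    from jax.experimental import pjit  # module import; callers qualify the name

    # Previously: from jax.experimental.pjit import pjit; handle = pjit(...)
    handle = pjit.pjit(jnp.sin,
                       in_axis_resources=None,
                       out_axis_resources=None)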
@@ -1231,59 +1229,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
include_xla_op_metadata=False
)
@jtu.with_mesh([("x", 1)])
def test_pjit_simple(self):
@partial(pjit, in_axis_resources=(P("x"), None), out_axis_resources=None)
def func_jax(x, y):
return x + y
self.ConvertAndCompare(func_jax, jnp.ones((3, 4), dtype=np.float32),
jnp.ones((1, 1), dtype=np.float32))
@jtu.with_mesh([("x", 1)])
def test_pjit_closed_over_const(self):
const = jnp.full((3, 4), 7, dtype=np.float32)
@partial(pjit, in_axis_resources=(P("x"), None), out_axis_resources=None)
def func_jax(x, y):
return x + y * const
self.ConvertAndCompare(func_jax, jnp.ones((3, 4), dtype=np.float32),
jnp.ones((1, 1), dtype=np.float32))
# TODO(necula): figure out this failure
@jtu.skip_on_flag("jax2tf_default_experimental_native_lowering", True)
def test_global_device_array(self):
def create_gda(global_shape, global_mesh, mesh_axes, global_data=None):
if global_data is None:
global_data = np.arange(np.prod(global_shape)).reshape(global_shape)
return GlobalDeviceArray.from_callback(
global_shape, global_mesh, mesh_axes,
lambda idx: global_data[idx]), global_data
global_mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
mesh_axes = P(("x", "y"))
params, _ = create_gda((8, 2), global_mesh, mesh_axes)
input_data = np.arange(16).reshape(2, 8)
# Test 1: use GDA as constants
def jax_func(input_data):
handle = pjit(
jnp.matmul,
in_axis_resources=(P("y", "x"), FROM_GDA),
out_axis_resources=None)
return handle(input_data, params)
with global_mesh:
tf_func = tf.function(
jax2tf.convert(jax_func, enable_xla=True),
jit_compile=True, autograph=False
)
jax_out = jax_func(input_data=input_data)
tf_out = tf_func(input_data=input_data)
# TODO(b/243146552) We can switch to ConvertAndCompare after this bug fix.
np.array_equal(jax_out._value, np.array(tf_out))
def assertAllOperationStartWith(self, g: tf.Graph, scope_name: str):
"""Assert all operations name start with ```scope_name```.
@@ -1400,7 +1345,7 @@ def get_serialized_computation(
out_axis_resources = None) -> str:
if use_pjit:
assert not abstracted_axes
lowered = pjit(f_jax,
lowered = pjit.pjit(f_jax,
in_axis_resources=in_axis_resources,
out_axis_resources=out_axis_resources).lower(*args)
else:
@@ -1503,7 +1448,7 @@ class XlaCallModuleTest(tf_test_util.JaxToTfTestCase):
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
in_axis_resources = (P("x"), P("x"))
out_axis_resources = None
res_jax = pjit(
res_jax = pjit.pjit(
func_jax,
in_axis_resources=in_axis_resources,
out_axis_resources=out_axis_resources)(x, x)

File: jax/experimental/jax2tf/tests/sharding_test.py

@@ -13,7 +13,7 @@
# limitations under the License.
"""Tests for the jax2tf conversion of pjit."""
import functools
from functools import partial
import logging
import os
import re
@@ -28,7 +28,6 @@ from jax.config import config
from jax.experimental import jax2tf
from jax.experimental import pjit
from jax.experimental.jax2tf.tests import tf_test_util
from jax.interpreters.pxla import PartitionSpec as P
import jax.numpy as jnp
from jax._src.lib import xla_bridge
@@ -39,6 +38,14 @@ import tensorflow as tf # type: ignore[import]
config.parse_flags_with_absl()
# Must come after initializing the flags
from jax.experimental.jax2tf.tests import tf_test_util
from jax.experimental.jax2tf.tests.jax2tf_limitations import Jax2TfLimitation
skip_eager_for_partitioning = Jax2TfLimitation(
"pjit functions with partitioning must be under tf.function",
modes="eager", skip_tf_run=True)
prev_xla_flags = None
def setUpModule():
global prev_xla_flags
@@ -169,7 +176,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
@jtu.with_mesh([("x", 2)])
def test_pjit_basic1D(self):
@functools.partial(pjit.pjit,
@partial(pjit.pjit,
in_axis_resources=(P("x"), P("x")),
out_axis_resources=None)
def jax_func(x, y):
@@ -195,7 +202,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
@jtu.with_mesh([("x", 2)])
def test_pjit_basic1D_variable(self):
# The first argument is a tf.Variable
@functools.partial(pjit.pjit,
@partial(pjit.pjit,
in_axis_resources=(P("x"), P("x")),
out_axis_resources=None)
def jax_func(x, y):
@@ -221,7 +228,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
@jtu.with_mesh([("x", 2), ("y", 2)])
def test_pjit_basic2D(self):
@functools.partial(pjit.pjit,
@partial(pjit.pjit,
in_axis_resources=(P(None, "x", "y"), P("y")),
out_axis_resources=P("x"))
def jax_func(x, y):
@@ -250,7 +257,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
@jtu.with_mesh([("x", 2), ("y", 2)])
def test_pjit_TwoMeshAxisSharding(self):
@functools.partial(pjit.pjit,
@partial(pjit.pjit,
in_axis_resources=P(("x", "y"),),
out_axis_resources=P(("x", "y"),))
def jax_func(x, y):
@@ -279,7 +286,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
@jtu.with_mesh([("x", 2), ("y", 1)])
def test_pjit_ShardingConstraint(self):
@functools.partial(pjit.pjit, in_axis_resources=None,
@partial(pjit.pjit, in_axis_resources=None,
out_axis_resources=None)
def jax_func(x): # x: f32[12, 8]
y = jnp.tile(x, (2, 1)) # y: f32[24, 8]
@@ -307,6 +314,13 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
class PjitTest(tf_test_util.JaxToTfTestCase):
def create_test_mesh(self, *axis_names):
"""Creates a mesh with 2 axes"""
assert len(axis_names) == 2, axis_names
nr_devices = len(jax.devices())
mesh_shape = (2, 1) if nr_devices >= 2 else (1, 1)
return jtu.create_global_mesh(mesh_shape, axis_names)
@jtu.with_mesh([("axis", 2)])
def test_pjit_basic1D(self):
def func_jax(x):
@@ -341,6 +355,51 @@ class PjitTest(tf_test_util.JaxToTfTestCase):
out_axis_resources=None))(x)
self.assertAllClose(res_tf.numpy(), res_jax)
@jtu.with_mesh([("x", 1)])
def test_pjit_closed_over_const(self):
const = jnp.full((4, 3), 7, dtype=np.float32)
@partial(pjit.pjit, in_axis_resources=(P("x"), None), out_axis_resources=None)
def func_jax(x, y):
return x + y * const
with self.create_test_mesh("x", "y"):
self.ConvertAndCompare(func_jax, jnp.ones((4, 3), dtype=np.float32),
jnp.ones((1, 1), dtype=np.float32),
limitations=[skip_eager_for_partitioning])
def test_pjit_closed_over_global_device_array(self):
global_mesh = self.create_test_mesh("x", "y")
input1 = np.arange(16).reshape(2, 8)
input2_raw = np.arange(16).reshape(8, 2)
input2_array = jax.make_array_from_callback(input2_raw.shape,
jax.sharding.NamedSharding(global_mesh, P("x", "y")),
lambda idx: input2_raw[idx])
@partial(pjit.pjit,
in_axis_resources=(P("y", "x"),),
out_axis_resources=None)
def jax_func(input_data):
return jnp.matmul(input_data, input2_array)
with global_mesh:
self.ConvertAndCompare(jax_func, input1,
limitations=[skip_eager_for_partitioning])
def test_nested_pjit(self):
global_mesh = self.create_test_mesh("x", "y")
x = np.arange(16).reshape(2, 8)
def func_jax(x):
# We have a pjit nested inside the function to be converted
inner_func = pjit.pjit(
jnp.sin,
in_axis_resources=(P("y", "x"),),
out_axis_resources=None)
return inner_func(x)
with global_mesh:
self.ConvertAndCompare(func_jax, x,
limitations=[skip_eager_for_partitioning])
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
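
Outside the test harness, the case fixed by this commit can be sketched as
follows. This mirrors test_nested_pjit above and uses the same APIs as the
tests in this file; it is not part of the diff, and the 1x1 mesh is an
assumption so the sketch runs on a single device:

    import numpy as np
    import jax
    import jax.numpy as jnp
    import tensorflow as tf
    from jax.experimental import jax2tf, pjit
    from jax.interpreters.pxla import PartitionSpec as P

    # A trivial 1x1 mesh: both named axes have size 1.
    mesh = jax.sharding.Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ("x", "y"))

    def func_jax(x):
      # The pjit is nested inside the function being converted.
      inner = pjit.pjit(jnp.sin,
                        in_axis_resources=(P("y", "x"),),
                        out_axis_resources=None)
      return inner(x)

    x = np.arange(16, dtype=np.float32).reshape(2, 8)
    with mesh:
      f_tf = tf.function(jax2tf.convert(func_jax),
                         autograph=False, jit_compile=True)
      np.testing.assert_allclose(np.array(f_tf(x)), func_jax(x))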