commit 62f2b9680b
Author: jax authors
Date:   2023-01-10 05:58:28 -08:00

    Merge pull request #13917 from gnecula:tf_bug_gda

    PiperOrigin-RevId: 500971319

3 changed files with 71 additions and 68 deletions


@@ -624,12 +624,11 @@ def _lower_native_and_run(fun_jax: Callable,
for aval in args_avals
]
# TODO: specify the backend for experimental_native_lowering
backend = jax.default_backend()
if not hasattr(fun_jax, "lower") or abstracted_axes:
# We support convert(pjit(f_jax, ...)) and convert(jit(f_jax)) but also
# convert(f_jax), in which case a "jit" is implied. We also add a jit when
# we need to pass the abstracted axes.
fun_jax_lower = jax.jit(fun_jax, backend=backend,
fun_jax_lower = jax.jit(fun_jax,
abstracted_axes=abstracted_axes).lower
else:
fun_jax_lower = fun_jax.lower
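Note on the hunk above: `jax.jit(...).lower()` already lowers for the default backend, so the explicit `backend=` argument is redundant and is dropped. A minimal sketch of that lowering path (illustrative only, not part of this commit):

```python
# Minimal sketch: jax.jit(...).lower() picks the platform via the default
# backend, so jax2tf no longer needs to pass backend= explicitly.
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) + 1.0

lowered = jax.jit(f).lower(jnp.ones((3,), jnp.float32))
print(jax.default_backend())   # backend chosen implicitly during lowering
print(lowered.as_text())       # lowered module, for inspection
```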


@@ -15,7 +15,6 @@
Specific JAX primitive conversion tests are in primitives_test."""
import collections
from functools import partial
import os
from typing import Callable, Dict, Optional, Tuple
import unittest
@@ -33,12 +32,11 @@ from jax._src import test_util as jtu
import jax._src.lib.xla_bridge
from jax.config import config
from jax.experimental import jax2tf
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.jax2tf.tests import tf_test_util
from jax.experimental.pjit import FROM_GDA
from jax.experimental.pjit import pjit
from jax.experimental import pjit
from jax.interpreters import mlir
from jax.interpreters.pxla import PartitionSpec as P
import numpy as np
import tensorflow as tf # type: ignore[import]
# pylint: disable=g-direct-tensorflow-import
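Note on the import changes above: the test now imports the `pjit` module rather than the `pjit` function, so call sites read `pjit.pjit(...)`. An illustrative sketch of the new style (the lambda and axis names are placeholders):

```python
# Illustrative only: module-style import as used after this change.
from jax.experimental import pjit
from jax.interpreters.pxla import PartitionSpec as P  # import path as in this test file
import jax.numpy as jnp

add_sharded = pjit.pjit(lambda x, y: x + y,
                        in_axis_resources=(P("x"), None),
                        out_axis_resources=None)
# Calling add_sharded requires an active mesh, e.g. via jtu.with_mesh([("x", 1)]).
```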
@@ -1231,59 +1229,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
include_xla_op_metadata=False
)
@jtu.with_mesh([("x", 1)])
def test_pjit_simple(self):
@partial(pjit, in_axis_resources=(P("x"), None), out_axis_resources=None)
def func_jax(x, y):
return x + y
self.ConvertAndCompare(func_jax, jnp.ones((3, 4), dtype=np.float32),
jnp.ones((1, 1), dtype=np.float32))
@jtu.with_mesh([("x", 1)])
def test_pjit_closed_over_const(self):
const = jnp.full((3, 4), 7, dtype=np.float32)
@partial(pjit, in_axis_resources=(P("x"), None), out_axis_resources=None)
def func_jax(x, y):
return x + y * const
self.ConvertAndCompare(func_jax, jnp.ones((3, 4), dtype=np.float32),
jnp.ones((1, 1), dtype=np.float32))
# TODO(necula): figure out this failure
@jtu.skip_on_flag("jax2tf_default_experimental_native_lowering", True)
def test_global_device_array(self):
def create_gda(global_shape, global_mesh, mesh_axes, global_data=None):
if global_data is None:
global_data = np.arange(np.prod(global_shape)).reshape(global_shape)
return GlobalDeviceArray.from_callback(
global_shape, global_mesh, mesh_axes,
lambda idx: global_data[idx]), global_data
global_mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
mesh_axes = P(("x", "y"))
params, _ = create_gda((8, 2), global_mesh, mesh_axes)
input_data = np.arange(16).reshape(2, 8)
# Test 1: use GDA as constants
def jax_func(input_data):
handle = pjit(
jnp.matmul,
in_axis_resources=(P("y", "x"), FROM_GDA),
out_axis_resources=None)
return handle(input_data, params)
with global_mesh:
tf_func = tf.function(
jax2tf.convert(jax_func, enable_xla=True),
jit_compile=True, autograph=False
)
jax_out = jax_func(input_data=input_data)
tf_out = tf_func(input_data=input_data)
# TODO(b/243146552) We can switch to ConvertAndCompare after this bug fix.
np.array_equal(jax_out._value, np.array(tf_out))
def assertAllOperationStartWith(self, g: tf.Graph, scope_name: str):
"""Assert all operations name start with ```scope_name```.
@@ -1400,7 +1345,7 @@ def get_serialized_computation(
out_axis_resources = None) -> str:
if use_pjit:
assert not abstracted_axes
lowered = pjit(f_jax,
lowered = pjit.pjit(f_jax,
in_axis_resources=in_axis_resources,
out_axis_resources=out_axis_resources).lower(*args)
else:
@@ -1503,7 +1448,7 @@ class XlaCallModuleTest(tf_test_util.JaxToTfTestCase):
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
in_axis_resources = (P("x"), P("x"))
out_axis_resources = None
res_jax = pjit(
res_jax = pjit.pjit(
func_jax,
in_axis_resources=in_axis_resources,
out_axis_resources=out_axis_resources)(x, x)
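Note on the two hunks above: with the module-level import, `pjit.pjit(...).lower(*args)` is used to obtain the lowered computation. A hedged, self-contained sketch of that pattern (the single-device Mesh below is an assumption; the tests themselves rely on jtu.with_mesh):

```python
# Hedged sketch: lowering a pjit-wrapped function under a mesh, similar to
# what get_serialized_computation does above.
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pjit
from jax.sharding import Mesh
from jax.interpreters.pxla import PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:1]), ("x",))  # assumption: one device is enough here
x = np.arange(8, dtype=np.float32)
with mesh:
    lowered = pjit.pjit(jnp.sin,
                        in_axis_resources=(P("x"),),
                        out_axis_resources=None).lower(x)
print(lowered.as_text())
```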


@@ -13,7 +13,7 @@
# limitations under the License.
"""Tests for the jax2tf conversion of pjit."""
import functools
from functools import partial
import logging
import os
import re
@@ -28,7 +28,6 @@ from jax.config import config
from jax.experimental import jax2tf
from jax.experimental import pjit
from jax.experimental.jax2tf.tests import tf_test_util
from jax.interpreters.pxla import PartitionSpec as P
import jax.numpy as jnp
from jax._src.lib import xla_bridge
@@ -39,6 +38,14 @@ import tensorflow as tf # type: ignore[import]
config.parse_flags_with_absl()
# Must come after initializing the flags
from jax.experimental.jax2tf.tests import tf_test_util
from jax.experimental.jax2tf.tests.jax2tf_limitations import Jax2TfLimitation
skip_eager_for_partitioning = Jax2TfLimitation(
"pjit functions with partitioning must be under tf.function",
modes="eager", skip_tf_run=True)
prev_xla_flags = None
def setUpModule():
global prev_xla_flags
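Note on the new skip_eager_for_partitioning limitation above: converted pjit functions with partitioning are expected to run compiled under tf.function, so eager-mode runs are skipped. A hedged sketch of the intended usage (run_compiled is a hypothetical helper, not part of the tests):

```python
# Hedged sketch of the constraint behind skip_eager_for_partitioning: run the
# converted function compiled under tf.function rather than eagerly.
import tensorflow as tf
from jax.experimental import jax2tf

def run_compiled(func_jax, *args):  # hypothetical helper name
    tf_func = tf.function(jax2tf.convert(func_jax),
                          autograph=False, jit_compile=True)
    return tf_func(*args)
```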
@@ -169,7 +176,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
@jtu.with_mesh([("x", 2)])
def test_pjit_basic1D(self):
@functools.partial(pjit.pjit,
@partial(pjit.pjit,
in_axis_resources=(P("x"), P("x")),
out_axis_resources=None)
def jax_func(x, y):
@@ -195,7 +202,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
@jtu.with_mesh([("x", 2)])
def test_pjit_basic1D_variable(self):
# The first argument is a tf.Variable
@functools.partial(pjit.pjit,
@partial(pjit.pjit,
in_axis_resources=(P("x"), P("x")),
out_axis_resources=None)
def jax_func(x, y):
@@ -221,7 +228,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
@jtu.with_mesh([("x", 2), ("y", 2)])
def test_pjit_basic2D(self):
@functools.partial(pjit.pjit,
@partial(pjit.pjit,
in_axis_resources=(P(None, "x", "y"), P("y")),
out_axis_resources=P("x"))
def jax_func(x, y):
@@ -250,7 +257,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
@jtu.with_mesh([("x", 2), ("y", 2)])
def test_pjit_TwoMeshAxisSharding(self):
@functools.partial(pjit.pjit,
@partial(pjit.pjit,
in_axis_resources=P(("x", "y"),),
out_axis_resources=P(("x", "y"),))
def jax_func(x, y):
@@ -279,7 +286,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
@jtu.with_mesh([("x", 2), ("y", 1)])
def test_pjit_ShardingConstraint(self):
@functools.partial(pjit.pjit, in_axis_resources=None,
@partial(pjit.pjit, in_axis_resources=None,
out_axis_resources=None)
def jax_func(x): # x: f32[12, 8]
y = jnp.tile(x, (2, 1)) # y: f32[24, 8]
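Note: the body of test_pjit_ShardingConstraint is truncated in this hunk. A hedged reconstruction of the sharding-constraint pattern the test name refers to (the constrained PartitionSpec is an assumption, not taken from the test):

```python
# Hedged reconstruction of a pjit'd function that applies a sharding
# constraint to an intermediate value; the P("x", "y") spec is assumed.
from functools import partial
import jax.numpy as jnp
from jax.experimental import pjit
from jax.interpreters.pxla import PartitionSpec as P

@partial(pjit.pjit, in_axis_resources=None, out_axis_resources=None)
def jax_func(x):               # x: f32[12, 8]
    y = jnp.tile(x, (2, 1))    # y: f32[24, 8]
    return pjit.with_sharding_constraint(y, P("x", "y"))
```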
@@ -307,6 +314,13 @@ class PjitTest(tf_test_util.JaxToTfTestCase):
class PjitTest(tf_test_util.JaxToTfTestCase):
def create_test_mesh(self, *axis_names):
"""Creates a mesh with 2 axes"""
assert len(axis_names) == 2, axis_names
nr_devices = len(jax.devices())
mesh_shape = (2, 1) if nr_devices >= 2 else (1, 1)
return jtu.create_global_mesh(mesh_shape, axis_names)
@jtu.with_mesh([("axis", 2)])
def test_pjit_basic1D(self):
def func_jax(x):
@@ -341,6 +355,51 @@ class PjitTest(tf_test_util.JaxToTfTestCase):
out_axis_resources=None))(x)
self.assertAllClose(res_tf.numpy(), res_jax)
@jtu.with_mesh([("x", 1)])
def test_pjit_closed_over_const(self):
const = jnp.full((4, 3), 7, dtype=np.float32)
@partial(pjit.pjit, in_axis_resources=(P("x"), None), out_axis_resources=None)
def func_jax(x, y):
return x + y * const
with self.create_test_mesh("x", "y"):
self.ConvertAndCompare(func_jax, jnp.ones((4, 3), dtype=np.float32),
jnp.ones((1, 1), dtype=np.float32),
limitations=[skip_eager_for_partitioning])
def test_pjit_closed_over_global_device_array(self):
global_mesh = self.create_test_mesh("x", "y")
input1 = np.arange(16).reshape(2, 8)
input2_raw = np.arange(16).reshape(8, 2)
input2_array = jax.make_array_from_callback(input2_raw.shape,
jax.sharding.NamedSharding(global_mesh, P("x", "y")),
lambda idx: input2_raw[idx])
@partial(pjit.pjit,
in_axis_resources=(P("y", "x"),),
out_axis_resources=None)
def jax_func(input_data):
return jnp.matmul(input_data, input2_array)
with global_mesh:
self.ConvertAndCompare(jax_func, input1,
limitations=[skip_eager_for_partitioning])
def test_nested_pjit(self):
global_mesh = self.create_test_mesh("x", "y")
x = np.arange(16).reshape(2, 8)
def func_jax(x):
# We have a pjit nested inside the function to be converted
inner_func = pjit.pjit(
jnp.sin,
in_axis_resources=(P("y", "x"),),
out_axis_resources=None)
return inner_func(x)
with global_mesh:
self.ConvertAndCompare(func_jax, x,
limitations=[skip_eager_for_partitioning])
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
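Note on the new tests above: outside the ConvertAndCompare harness, the equivalent end-to-end usage looks roughly like the hedged sketch below, which mirrors test_nested_pjit; the (1, 1) mesh shape is an assumption so it can run on a single device.

```python
# Hedged end-to-end sketch: convert a JAX function that calls pjit internally
# and run it under tf.function within the same mesh context.
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
from jax._src import test_util as jtu
from jax.experimental import jax2tf
from jax.experimental import pjit
from jax.interpreters.pxla import PartitionSpec as P

global_mesh = jtu.create_global_mesh((1, 1), ("x", "y"))  # assumed single-device mesh
x = np.arange(16, dtype=np.float32).reshape(2, 8)

def func_jax(x):
    inner = pjit.pjit(jnp.sin, in_axis_resources=(P("y", "x"),),
                      out_axis_resources=None)
    return inner(x)

with global_mesh:
    tf_func = tf.function(jax2tf.convert(func_jax),
                          autograph=False, jit_compile=True)
    print(np.allclose(np.array(tf_func(x)), func_jax(x)))
```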