commit 62f2b9680b

Merge pull request #13917 from gnecula:tf_bug_gda

PiperOrigin-RevId: 500971319
@@ -624,12 +624,11 @@ def _lower_native_and_run(fun_jax: Callable,
       for aval in args_avals
   ]
   # TODO: specify the backend for experimental_native_lowering
-  backend = jax.default_backend()
   if not hasattr(fun_jax, "lower") or abstracted_axes:
     # We support convert(pjit(f_jax, ...)) and convert(jit(f_jax)) but also
     # convert(f_jax), in which case a "jit" is implied. We also add a jit when
     # we need to pass the abstracted axes.
-    fun_jax_lower = jax.jit(fun_jax, backend=backend,
+    fun_jax_lower = jax.jit(fun_jax,
                             abstracted_axes=abstracted_axes).lower
   else:
     fun_jax_lower = fun_jax.lower
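The hunk above stops pinning `backend=` on the implied `jit`, which can conflict with arguments that are already committed to specific devices (such as sharded arrays). A minimal sketch, not part of the commit, of the two lowering paths the code distinguishes, using a made-up function `f` for illustration:

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * 2.0

# jit/pjit-wrapped callables already expose .lower; a plain Python function
# gets an implied jax.jit first. No backend= is pinned, so device-committed
# arguments keep their placement.
fun_jax = f
if not hasattr(fun_jax, "lower"):
  fun_jax_lower = jax.jit(fun_jax).lower
else:
  fun_jax_lower = fun_jax.lower

lowered = fun_jax_lower(jnp.ones((3,), jnp.float32))
print(lowered.as_text()[:120])  # inspect the lowered module
```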
@@ -15,7 +15,6 @@
 
 Specific JAX primitive conversion tests are in primitives_test."""
 import collections
 from functools import partial
 import os
 from typing import Callable, Dict, Optional, Tuple
 import unittest
@@ -33,12 +32,11 @@ from jax._src import test_util as jtu
 import jax._src.lib.xla_bridge
 from jax.config import config
 from jax.experimental import jax2tf
-from jax.experimental.global_device_array import GlobalDeviceArray
 from jax.experimental.jax2tf.tests import tf_test_util
-from jax.experimental.pjit import FROM_GDA
-from jax.experimental.pjit import pjit
+from jax.experimental import pjit
 from jax.interpreters import mlir
 from jax.interpreters.pxla import PartitionSpec as P
 
 import numpy as np
 import tensorflow as tf  # type: ignore[import]
 # pylint: disable=g-direct-tensorflow-import
@@ -1231,59 +1229,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
         include_xla_op_metadata=False
     )
 
-  @jtu.with_mesh([("x", 1)])
-  def test_pjit_simple(self):
-    @partial(pjit, in_axis_resources=(P("x"), None), out_axis_resources=None)
-    def func_jax(x, y):
-      return x + y
-
-    self.ConvertAndCompare(func_jax, jnp.ones((3, 4), dtype=np.float32),
-                           jnp.ones((1, 1), dtype=np.float32))
-
-  @jtu.with_mesh([("x", 1)])
-  def test_pjit_closed_over_const(self):
-    const = jnp.full((3, 4), 7, dtype=np.float32)
-    @partial(pjit, in_axis_resources=(P("x"), None), out_axis_resources=None)
-    def func_jax(x, y):
-      return x + y * const
-
-    self.ConvertAndCompare(func_jax, jnp.ones((3, 4), dtype=np.float32),
-                           jnp.ones((1, 1), dtype=np.float32))
-
-  # TODO(necula): figure out this failure
-  @jtu.skip_on_flag("jax2tf_default_experimental_native_lowering", True)
-  def test_global_device_array(self):
-
-    def create_gda(global_shape, global_mesh, mesh_axes, global_data=None):
-      if global_data is None:
-        global_data = np.arange(np.prod(global_shape)).reshape(global_shape)
-      return GlobalDeviceArray.from_callback(
-          global_shape, global_mesh, mesh_axes,
-          lambda idx: global_data[idx]), global_data
-
-    global_mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
-    mesh_axes = P(("x", "y"))
-    params, _ = create_gda((8, 2), global_mesh, mesh_axes)
-    input_data = np.arange(16).reshape(2, 8)
-
-    # Test 1: use GDA as constants
-    def jax_func(input_data):
-      handle = pjit(
-          jnp.matmul,
-          in_axis_resources=(P("y", "x"), FROM_GDA),
-          out_axis_resources=None)
-      return handle(input_data, params)
-
-    with global_mesh:
-      tf_func = tf.function(
-          jax2tf.convert(jax_func, enable_xla=True),
-          jit_compile=True, autograph=False
-      )
-      jax_out = jax_func(input_data=input_data)
-      tf_out = tf_func(input_data=input_data)
-      # TODO(b/243146552) We can switch to ConvertAndCompare after this bug fix.
-      np.array_equal(jax_out._value, np.array(tf_out))
-
   def assertAllOperationStartWith(self, g: tf.Graph, scope_name: str):
     """Assert all operations name start with ```scope_name```.
 
@@ -1400,7 +1345,7 @@ def get_serialized_computation(
                               out_axis_resources = None) -> str:
   if use_pjit:
     assert not abstracted_axes
-    lowered = pjit(f_jax,
+    lowered = pjit.pjit(f_jax,
                    in_axis_resources=in_axis_resources,
                    out_axis_resources=out_axis_resources).lower(*args)
   else:
@@ -1503,7 +1448,7 @@ class XlaCallModuleTest(tf_test_util.JaxToTfTestCase):
     x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
     in_axis_resources = (P("x"), P("x"))
     out_axis_resources = None
-    res_jax = pjit(
+    res_jax = pjit.pjit(
         func_jax,
         in_axis_resources=in_axis_resources,
         out_axis_resources=out_axis_resources)(x, x)
@@ -13,7 +13,7 @@
 # limitations under the License.
 """Tests for the jax2tf conversion of pjit."""
 
-import functools
+from functools import partial
 import logging
 import os
 import re
@@ -28,7 +28,6 @@ from jax.config import config
 
 from jax.experimental import jax2tf
 from jax.experimental import pjit
-from jax.experimental.jax2tf.tests import tf_test_util
 from jax.interpreters.pxla import PartitionSpec as P
 import jax.numpy as jnp
 from jax._src.lib import xla_bridge
@@ -39,6 +38,14 @@ import tensorflow as tf  # type: ignore[import]
 
 config.parse_flags_with_absl()
 
+# Must come after initializing the flags
+from jax.experimental.jax2tf.tests import tf_test_util
+from jax.experimental.jax2tf.tests.jax2tf_limitations import Jax2TfLimitation
+
+skip_eager_for_partitioning = Jax2TfLimitation(
+    "pjit functions with partitioning must be under tf.function",
+    modes="eager", skip_tf_run=True)
+
 prev_xla_flags = None
 def setUpModule():
   global prev_xla_flags
@@ -169,7 +176,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
   @jtu.with_mesh([("x", 2)])
   def test_pjit_basic1D(self):
 
-    @functools.partial(pjit.pjit,
+    @partial(pjit.pjit,
                        in_axis_resources=(P("x"), P("x")),
                        out_axis_resources=None)
     def jax_func(x, y):
@@ -195,7 +202,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
   @jtu.with_mesh([("x", 2)])
   def test_pjit_basic1D_variable(self):
     # The first argument is a tf.Variable
-    @functools.partial(pjit.pjit,
+    @partial(pjit.pjit,
                        in_axis_resources=(P("x"), P("x")),
                        out_axis_resources=None)
     def jax_func(x, y):
@@ -221,7 +228,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
 
   @jtu.with_mesh([("x", 2), ("y", 2)])
   def test_pjit_basic2D(self):
-    @functools.partial(pjit.pjit,
+    @partial(pjit.pjit,
                        in_axis_resources=(P(None, "x", "y"), P("y")),
                        out_axis_resources=P("x"))
     def jax_func(x, y):
@@ -250,7 +257,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
 
   @jtu.with_mesh([("x", 2), ("y", 2)])
   def test_pjit_TwoMeshAxisSharding(self):
-    @functools.partial(pjit.pjit,
+    @partial(pjit.pjit,
                        in_axis_resources=P(("x", "y"),),
                        out_axis_resources=P(("x", "y"),))
     def jax_func(x, y):
@@ -279,7 +286,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
 
   @jtu.with_mesh([("x", 2), ("y", 1)])
   def test_pjit_ShardingConstraint(self):
-    @functools.partial(pjit.pjit, in_axis_resources=None,
+    @partial(pjit.pjit, in_axis_resources=None,
                        out_axis_resources=None)
     def jax_func(x):  # x: f32[12, 8]
       y = jnp.tile(x, (2, 1))  # y: f32[24, 8]
@@ -307,6 +314,13 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
 
 class PjitTest(tf_test_util.JaxToTfTestCase):
 
+  def create_test_mesh(self, *axis_names):
+    """Creates a mesh with 2 axes"""
+    assert len(axis_names) == 2, axis_names
+    nr_devices = len(jax.devices())
+    mesh_shape = (2, 1) if nr_devices >= 2 else (1, 1)
+    return jtu.create_global_mesh(mesh_shape, axis_names)
+
   @jtu.with_mesh([("axis", 2)])
   def test_pjit_basic1D(self):
     def func_jax(x):
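The `create_test_mesh` helper added above builds a two-axis mesh that degrades to a trivial 1x1 grid on single-device hosts, so the new tests run everywhere. A standalone sketch of the same idea using the public `jax.sharding.Mesh` API (the helper name `make_two_axis_mesh` is ours, not the test suite's):

```python
import numpy as np
import jax
from jax.sharding import Mesh

def make_two_axis_mesh(ax0: str, ax1: str) -> Mesh:
  # Use a 2x1 device grid when at least two devices exist, else fall back
  # to 1x1 so the mesh can still be built on a single-device host.
  n_devices = len(jax.devices())
  shape = (2, 1) if n_devices >= 2 else (1, 1)
  devices = np.asarray(jax.devices()[:shape[0] * shape[1]]).reshape(shape)
  return Mesh(devices, (ax0, ax1))

# pjit calls inside this context resolve the axis names "x" and "y".
with make_two_axis_mesh("x", "y"):
  pass
```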
@@ -341,6 +355,51 @@ class PjitTest(tf_test_util.JaxToTfTestCase):
                                  out_axis_resources=None))(x)
     self.assertAllClose(res_tf.numpy(), res_jax)
 
+  @jtu.with_mesh([("x", 1)])
+  def test_pjit_closed_over_const(self):
+    const = jnp.full((4, 3), 7, dtype=np.float32)
+    @partial(pjit.pjit, in_axis_resources=(P("x"), None), out_axis_resources=None)
+    def func_jax(x, y):
+      return x + y * const
+
+    with self.create_test_mesh("x", "y"):
+      self.ConvertAndCompare(func_jax, jnp.ones((4, 3), dtype=np.float32),
+                             jnp.ones((1, 1), dtype=np.float32),
+                             limitations=[skip_eager_for_partitioning])
+
+  def test_pjit_closed_over_global_device_array(self):
+    global_mesh = self.create_test_mesh("x", "y")
+
+    input1 = np.arange(16).reshape(2, 8)
+    input2_raw = np.arange(16).reshape(8, 2)
+    input2_array = jax.make_array_from_callback(input2_raw.shape,
+        jax.sharding.NamedSharding(global_mesh, P("x", "y")),
+        lambda idx: input2_raw[idx])
+
+    @partial(pjit.pjit,
+             in_axis_resources=(P("y", "x"),),
+             out_axis_resources=None)
+    def jax_func(input_data):
+      return jnp.matmul(input_data, input2_array)
+
+    with global_mesh:
+      self.ConvertAndCompare(jax_func, input1,
+                             limitations=[skip_eager_for_partitioning])
+
+  def test_nested_pjit(self):
+    global_mesh = self.create_test_mesh("x", "y")
+    x = np.arange(16).reshape(2, 8)
+
+    def func_jax(x):
+      # We have a pjit nested inside the function to be converted
+      inner_func = pjit.pjit(
+          jnp.sin,
+          in_axis_resources=(P("y", "x"),),
+          out_axis_resources=None)
+      return inner_func(x)
+
+    with global_mesh:
+      self.ConvertAndCompare(func_jax, x,
+                             limitations=[skip_eager_for_partitioning])
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
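The new `test_pjit_closed_over_global_device_array` swaps `GlobalDeviceArray.from_callback` for `jax.make_array_from_callback`, which builds a sharded `jax.Array` one shard at a time. A minimal standalone sketch of that construction (a 1x1 mesh is used here so it runs on any host; shapes and axis names follow the test):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Trivial one-device mesh; the test itself uses up to two devices.
mesh = Mesh(np.asarray(jax.devices()[:1]).reshape(1, 1), ("x", "y"))

raw = np.arange(16).reshape(8, 2)
sharding = NamedSharding(mesh, P("x", "y"))

# The callback receives the index of each shard and returns the matching
# slab of host data; JAX assembles the sharded jax.Array from the pieces.
arr = jax.make_array_from_callback(raw.shape, sharding, lambda idx: raw[idx])
print(arr.shape)  # (8, 2)
```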