# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import collections
from collections.abc import Callable, Sequence
import contextlib
import dataclasses
import functools
import logging
import json
import math
import re
import unittest

from absl.testing import absltest
import jax
from jax import lax
from jax import numpy as jnp
from jax import export
from jax.experimental import pjit
from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jax import tree_util

from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import effects
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.interpreters import mlir

from jax._src.lib.mlir.dialects import hlo

import numpy as np

# ruff: noqa: F401
try:
  import flatbuffers
  CAN_SERIALIZE = True
except (ModuleNotFoundError, ImportError):
  CAN_SERIALIZE = False

config.parse_flags_with_absl()
jtu.request_cpu_devices(8)


### Setup for testing lowering with effects
@dataclasses.dataclass(frozen=True)
class ForTestingOrderedEffect1(effects.Effect):
  pass

@dataclasses.dataclass(frozen=True)
class ForTestingOrderedEffect2(effects.Effect):
  pass

@dataclasses.dataclass(frozen=True)
class ForTestingUnorderedEffect1(effects.Effect):
  pass


class ForTestingOrderedEffect4NoNullary(effects.Effect):
  def __init__(self, _):
    pass

@dataclasses.dataclass(eq=False)
class ForTestingOrderedEffect5NoEq(effects.Effect):
  pass


_testing_effects = dict(
  ForTestingOrderedEffect1=ForTestingOrderedEffect1(),
  ForTestingOrderedEffect2=ForTestingOrderedEffect2(),
  ForTestingUnorderedEffect1=ForTestingUnorderedEffect1(),
  ForTestingOrderedEffect4NoNullary=ForTestingOrderedEffect4NoNullary(42),
  ForTestingOrderedEffect5NoEq=ForTestingOrderedEffect5NoEq(),
)
# Register the effects
for effect in _testing_effects.values():
  effect_class = effect.__class__
  effects.lowerable_effects.add_type(effect_class)
  effects.control_flow_allowed_effects.add_type(effect_class)
  effects.remat_allowed_effects.add_type(effect_class)
  effects.custom_derivatives_allowed_effects.add_type(effect_class)
  if "Ordered" in str(effect_class):
    effects.ordered_effects.add_type(effect_class)

# A primitive that takes an effect_class_name kwarg with the name of the effect
# class and just doubles its argument.
testing_primitive_with_effect_p = core.Primitive("testing_primitive_with_effect")
testing_primitive_with_effect_p.def_effectful_abstract_eval(
  lambda aval, *x, effect_class_name: (aval, {_testing_effects[effect_class_name]}))

def lowering_testing_primitive_with_effect(ctx, a, *, effect_class_name: str):
  if "Ordered" in effect_class_name:
    token_in = ctx.tokens_in.get(_testing_effects[effect_class_name])
    ctx.set_tokens_out(mlir.TokenSet({_testing_effects[effect_class_name]: token_in}))
  return [mlir.hlo.add(a, a)]

mlir.register_lowering(testing_primitive_with_effect_p,
                       lowering_testing_primitive_with_effect)

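# A minimal usage sketch (not executed here): binding the primitive doubles its
# argument, and an "Ordered" effect_class_name additionally threads a token
# through the lowered module, as done in the lowering rule above:
#   y = testing_primitive_with_effect_p.bind(
#       x, effect_class_name="ForTestingOrderedEffect1")  # y == 2 * x under trace
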
## Setup for multi-platform lowering
_testing_multi_platform_to_add = dict(cpu=2., tpu=3., cuda=4., rocm=5.)

def _testing_multi_platform_func(x, *,
                                 effect_class_name: str | None = None):
  # Behaves like x + 2 * _testing_multi_platform_to_add[platform]
  def for_platform(platform: str):
    if effect_class_name is None:
      return 2. * _testing_multi_platform_to_add[platform]
    else:
      return testing_primitive_with_effect_p.bind(
          _testing_multi_platform_to_add[platform],
          effect_class_name=effect_class_name)

  return x + lax.platform_dependent(
      tpu=lambda: for_platform("tpu"),
      cuda=lambda: for_platform("cuda"),
      rocm=lambda: for_platform("rocm"),
      default=lambda: for_platform("cpu"),
  )

def _testing_multi_platform_fun_expected(x,
                                         platform: str | None = None):
  return x + 2. * _testing_multi_platform_to_add[
      xb.canonicalize_platform(platform or jtu.device_under_test())
  ]

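# For example (a sketch): on CPU `_testing_multi_platform_func(x)` computes
# `x + 2 * 2.`, while on CUDA it computes `x + 2 * 4.`;
# `_testing_multi_platform_fun_expected` mirrors that arithmetic for the
# platform under test, so the two can be compared in the multi-platform tests.
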
def get_exported(fun: Callable, vjp_order=0,
                 **export_kwargs) -> Callable[[...], export.Exported]:
  """Like export.export but with serialization + deserialization."""
  def serde_exported(*fun_args, **fun_kwargs):
    exp = export.export(fun, **export_kwargs)(*fun_args, **fun_kwargs)
    if CAN_SERIALIZE:
      serialized = exp.serialize(vjp_order=vjp_order)
      return export.deserialize(serialized)
    else:
      return exp
  return serde_exported

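# Example usage (a sketch): export a jitted function, round-trip it through
# serialization when flatbuffers is available, then invoke it:
#   exp = get_exported(jax.jit(jnp.sin))(np.ones((4,), np.float32))
#   y = exp.call(np.ones((4,), np.float32))
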
# Run tests with the maximum supported version by default
@jtu.with_config(jax_export_calling_convention_version=export.maximum_supported_calling_convention_version)
class JaxExportTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    # Find the available platforms
    self.platforms = []
    for backend in ["cpu", "gpu", "tpu"]:
      try:
        jax.devices(backend)
      except RuntimeError:
        continue
      self.platforms.append(backend)

  def test_basic_export_only(self):
    @jax.jit
    def my_fun(x):
      return jnp.sin(x)
    exp = get_exported(my_fun)(jax.ShapeDtypeStruct((4,), dtype=np.float32))
    self.assertEqual("my_fun", exp.fun_name)
    expected_lowering_platform = xb.canonicalize_platform(jax.default_backend())
    self.assertEqual((expected_lowering_platform,),
                     exp.platforms)
    self.assertEqual(jax.tree.flatten(((1,), {}))[1], exp.in_tree)
    self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.in_avals)
    self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.out_avals)

  def test_pytree_export_only(self):
    a = np.arange(4, dtype=np.float32)
    b = np.arange(6, dtype=np.float32)
    def f(a_b_pair, *, a, b):
      return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))

    exp = get_exported(jax.jit(f), platforms=("cpu",))((a, b), a=a, b=b)
    a_aval = core.ShapedArray(a.shape, a.dtype)
    b_aval = core.ShapedArray(b.shape, b.dtype)
    self.assertEqual(exp.platforms, ("cpu",))
    args = ((a, b),)
    kwargs = dict(a=a, b=b)
    self.assertEqual(exp.in_tree, jax.tree.flatten((args, kwargs))[1])
    self.assertEqual(exp.in_avals, (a_aval, b_aval, a_aval, b_aval))
    self.assertEqual(exp.out_tree, jax.tree.flatten(f(*args, **kwargs))[1])
    self.assertEqual(exp.out_avals, (a_aval, b_aval, a_aval, b_aval, a_aval, b_aval))

  def test_basic(self):
    f = jnp.sin
    x = np.arange(4, dtype=np.float32)
    exp_f = get_exported(f)(x)

    self.assertAllClose(f(x), exp_f.call(x))

  def test_jit_static_arg(self):

    with self.subTest("static_argnames"):

      @functools.partial(jax.jit, static_argnames=["c"])
      def f(x, *, c):
        return c * jnp.sin(x)

      x = np.arange(4, dtype=np.float32)
      exp_f = get_exported(f)(x, c=0.1)

      self.assertAllClose(f(x, c=0.1), exp_f.call(x))

    with self.subTest("static_argnums"):

      @functools.partial(jax.jit, static_argnums=[1])
      def g(x, c):
        return c * jnp.sin(x)

      x = np.arange(4, dtype=np.float32)
      exp_g = get_exported(g)(x, 0.1)

      self.assertAllClose(g(x, 0.1), exp_g.call(x))

  def test_export_error_no_jit(self):
    # Exporting a function that is not the result of `jit` raises an error.
    with self.assertRaisesRegex(ValueError,
                                "Function to be exported must be the result of `jit`"):
      _ = export.export(lambda x: jnp.sin(x))

  def test_call_exported_lambda(self):
    # When we export a lambda, the exported.fun_name is not a valid MLIR function name
    f = jax.jit(lambda x: jnp.sin(x))
    x = np.arange(4, dtype=np.float32)
    exp_f = get_exported(f)(x)
    self.assertAllClose(f(x), exp_f.call(x))

  def test_call_name_conflict(self):
    @jax.jit
    def inner(x):
      # The lowering will contain a _where private function
      return jnp.where(x > 0, jnp.ones_like(x), jnp.zeros_like(x))

    x = jnp.arange(-20, 20, dtype=np.int32)
    exp_inner = export.export(inner)(x)
    self.assertIn("@_where(", str(exp_inner.mlir_module()))

    @jax.jit
    def outer(x):
      # There should be no conflict on _where
      x = exp_inner.call(x)
      return inner(x)

    export.export(outer)(x)

  def test_call_twice_exported(self):
    def f(x): return jnp.sin(x)
    x = np.arange(4, dtype=np.float32)

    @jax.jit
    def f1(x):
      exp_f = get_exported(jax.jit(f))(x)
      return exp_f.call(x) + exp_f.call(x)

    self.assertAllClose(2. * f(x), f1(x))

  def test_unused_args(self):
    f = jax.jit(lambda x, y: jnp.sin(x))
    x = np.arange(4, dtype=np.float32)
    y = np.arange(6, dtype=np.float32)
    exp_f = get_exported(f)(x, y)

    self.assertAllClose(f(x, y), exp_f.call(x, y))

  def test_pytree(self):
    a = np.arange(4, dtype=np.float32)
    b = np.arange(6, dtype=np.float32)
    def f(a_b_pair, a, b):
      return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))

    exp_f = get_exported(jax.jit(f))((a, b), a=a, b=b)
    self.assertAllClose(f((a, b), a=a, b=b),
                        exp_f.call((a, b), a=a, b=b))

  def test_pytree_namedtuple(self):
    if not CAN_SERIALIZE: raise unittest.SkipTest("test requires flatbuffers")
    T = collections.namedtuple("SomeType", ("a", "b", "c"))
    export.register_namedtuple_serialization(
        T,
        serialized_name="test_pytree_namedtuple.SomeType",
    )
    x = T(a=1, b=2, c=3)

    def f(x):
      return (x, x)  # return 2 copies, to check that types are shared

    exp = export.export(jax.jit(f))(x)
    res = exp.call(x)
    self.assertEqual(tree_util.tree_structure(res),
                     tree_util.tree_structure((x, x)))
    self.assertEqual(type(res[0]), type(x))
    self.assertEqual(type(res[1]), type(x))
    ser = exp.serialize()
    exp2 = export.deserialize(ser)
    self.assertEqual(exp2.in_tree, exp.in_tree)
    self.assertEqual(exp2.out_tree, exp.out_tree)
    res2 = exp2.call(x)
    self.assertEqual(tree_util.tree_structure(res2),
                     tree_util.tree_structure(res))

  def test_pytree_namedtuple_error(self):
    if not CAN_SERIALIZE: raise unittest.SkipTest("test requires flatbuffers")
    T = collections.namedtuple("SomeType", ("a", "b"))
    x = T(a=1, b=2)
    with self.assertRaisesRegex(
        ValueError,
        "Cannot serialize .* unregistered type .*SomeType"):
      export.export(jax.jit(lambda x: x))(x).serialize()

    with self.assertRaisesRegex(
        ValueError,
        "If `from_children` is not present.* must call.*register_pytree_node"
    ):
      export.register_pytree_node_serialization(
          T,
          serialized_name="test_pytree_namedtuple.SomeType_V2",
          serialize_auxdata=lambda x: b"",
          deserialize_auxdata=lambda b: None
      )

    with self.assertRaisesRegex(ValueError,
                                "Use .*register_pytree_node_serialization"):
      export.register_namedtuple_serialization(str, serialized_name="n/a")

    export.register_namedtuple_serialization(
        T,
        serialized_name="test_pytree_namedtuple_error.SomeType",
    )

    with self.assertRaisesRegex(
        ValueError,
        "Duplicate serialization registration .*test_pytree_namedtuple_error.SomeType"
    ):
      export.register_namedtuple_serialization(
          T,
          serialized_name="test_pytree_namedtuple_error.OtherType",
      )

    with self.assertRaisesRegex(
        ValueError,
        "Duplicate serialization registration for serialized_name.*test_pytree_namedtuple_error.SomeType"
    ):
      export.register_namedtuple_serialization(
          collections.namedtuple("SomeOtherType", ("a", "b")),
          serialized_name="test_pytree_namedtuple_error.SomeType",
      )

  def test_pytree_custom_types(self):
    if not CAN_SERIALIZE: raise unittest.SkipTest("test requires flatbuffers")
    x1 = collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)])

    @tree_util.register_pytree_node_class
    class CustomType:
      def __init__(self, a: int, b: CustomType | None, string: str):
        self.a = a
        self.b = b
        self.string = string

      def tree_flatten(self):
        return ((self.a, self.b), self.string)

      @classmethod
      def tree_unflatten(cls, aux_data, children):
        string = aux_data
        return cls(*children, string)

    export.register_pytree_node_serialization(
        CustomType,
        serialized_name="test_pytree_custom_types.CustomType",
        serialize_auxdata=lambda aux: aux.encode("utf-8"),
        deserialize_auxdata=lambda b: b.decode("utf-8")
    )
    x2 = CustomType(4, 5, "foo")

    def f(x1, x2):
      return (x1, x2, x1, x2)  # return 2 copies, to check that types are shared

    exp = export.export(jax.jit(f))(x1, x2)
    res = exp.call(x1, x2)
    self.assertEqual(tree_util.tree_structure(res),
                     tree_util.tree_structure(((x1, x2, x1, x2))))
    self.assertEqual(type(res[0]), type(x1))
    self.assertEqual(type(res[1]), type(x2))
    self.assertEqual(type(res[2]), type(x1))
    self.assertEqual(type(res[3]), type(x2))
    ser = exp.serialize()
    exp2 = export.deserialize(ser)
    self.assertEqual(exp2.in_tree, exp.in_tree)
    self.assertEqual(exp2.out_tree, exp.out_tree)
    res2 = exp2.call(x1, x2)
    self.assertEqual(tree_util.tree_structure(res2),
                     tree_util.tree_structure(res))

  def test_error_wrong_intree(self):
    def f(a_b_pair, *, c):
      return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c
    a = b = c = np.arange(4, dtype=np.float32)
    exp_f = get_exported(jax.jit(f))((a, b), c=c)

    with self.assertRaisesRegex(
        ValueError,
        "The invocation args and kwargs must have the same pytree structure"):
      exp_f.call(a, b, c=(a, b))

  def test_error_wrong_avals(self):
    def f(a, *, b):  # a: f32[4] and b: f32[4]
      return jnp.sin(a) + jnp.cos(b)
    f32_4 = np.arange(4, dtype=np.float32)
    exp_f = get_exported(jax.jit(f))(f32_4, b=f32_4)

    with self.assertRaisesRegex(ValueError,
        r"Shape mismatch for args\[0\].shape\[0\]"):
      exp_f.call(np.arange(6, dtype=np.float32), b=f32_4)

    with self.assertRaisesRegex(ValueError,
        r"Shape mismatch for kwargs\['b'\].shape\[0\]"):
      exp_f.call(f32_4, b=np.arange(6, dtype=np.float32))

    with self.assertRaisesRegex(ValueError,
        r"Rank mismatch for args\[0\]"):
      exp_f.call(f32_4.reshape((1, 4)), b=f32_4)

    with self.assertRaisesRegex(ValueError,
        r"Dtype mismatch for args\[0\]"):
      exp_f.call(f32_4.astype(np.float16), b=f32_4)

  def test_default_export_platform(self):
    test_platform = jtu.device_under_test()
    if test_platform == "gpu":
      test_platform = "rocm" if jtu.is_device_rocm() else "cuda"
    self.assertEqual(export.default_export_platform(), test_platform)
    exp = export.export(jnp.sin)(1.)
    self.assertEqual(exp.platforms, (export.default_export_platform(),))

  @jtu.parameterized_filterable(
    testcase_name=lambda kw: kw["platform"],
    kwargs=[dict(platform=p)
            for p in ("cpu", "cuda", "rocm", "tpu")])
  def test_error_wrong_platform(self, platform):
    a = np.arange(4, dtype=np.float32)

    exp_f = get_exported(jnp.sin, platforms=(platform,))(a)
    if xb.canonicalize_platform(jtu.device_under_test()) == platform:
      raise unittest.SkipTest("Uninteresting scenario")

    with self.assertRaisesRegex(
        ValueError, "Function .* was exported for platform"):
      exp_f.call(a)

    # Now try with the platform check disabled
    exp_f_no_platform_check = get_exported(
      jnp.sin, platforms=(platform,),
      disabled_checks=[export.DisabledSafetyCheck.platform()])(a)
    res = exp_f_no_platform_check.call(a)
    self.assertAllClose(res, jnp.sin(a))

  @jtu.parameterized_filterable(
    testcase_name=lambda kw: kw["dialect"],
    kwargs=[dict(dialect=dialect)
            for dialect in ("stablehlo",)]
  )
  def test_error_disallowed_custom_call(self, dialect):
    # If we use hlo.custom_call we detect invalid custom call targets.
    # Set up a primitive with custom lowering rules
    test_primitive = core.Primitive("_test_primitive_disallowed_custom_call")
    test_primitive.def_abstract_eval(lambda in_aval: in_aval)
    def test_primitive_lowering(ctx, arg):
      op = dict(stablehlo=hlo.CustomCallOp)[dialect]
      return op([arg.type], [arg], "disallowed_call_target").results
    mlir.register_lowering(test_primitive, test_primitive_lowering)
    self.addCleanup(lambda: mlir.register_lowering(test_primitive, None))

    a = np.arange(3, dtype=np.float32)
    with self.assertRaisesRegex(ValueError,
        "Cannot serialize code with custom calls whose targets .*"):
      get_exported(
        jax.jit(lambda a: a + test_primitive.bind(a))
      )(a)

    # Now try again with the safety check disabled
    exp = get_exported(
      jax.jit(lambda a: a + test_primitive.bind(a)),
      disabled_checks=[export.DisabledSafetyCheck.custom_call("disallowed_call_target")]
    )(a)
    self.assertIn("disallowed_call_target", exp.mlir_module())

  def test_lowering_parameters_for_export(self):
    # Test that we propagate properly the LoweringParameters.for_export
    test_primitive = core.Primitive("_test_primitive_for_export")
    test_primitive.def_abstract_eval(lambda in_aval: in_aval)
    # Store here the context for lowering
    context = {}
    def test_primitive_lowering(ctx, arg):
      context["for_export"] = ctx.module_context.lowering_parameters.for_export
      context["export_ignore_forward_compatibility"] = ctx.module_context.lowering_parameters.export_ignore_forward_compatibility
      return mlir.hlo.AddOp(arg, arg).results

    mlir.register_lowering(test_primitive, test_primitive_lowering)
    self.addCleanup(lambda: mlir.register_lowering(test_primitive, None))

    f = jax.jit(test_primitive.bind)
    a = np.arange(3, dtype=np.float32)
    context.clear()
    res = f(a)  # Works with JIT
    self.assertAllClose(res, a + a)
    self.assertEqual(context,
                     dict(for_export=False,
                          export_ignore_forward_compatibility=False))
    context.clear()
    f.lower(a)  # Works with most AOT
    # The above was cached
    self.assertEqual(context, {})
    _ = export.export(f)(a)
    self.assertEqual(context,
                     dict(for_export=True,
                          export_ignore_forward_compatibility=False))
    context.clear()
    with config.export_ignore_forward_compatibility(True):
      _ = export.export(f)(a)
      self.assertEqual(context,
                       dict(for_export=True,
                            export_ignore_forward_compatibility=True))

  def test_grad(self):
    f = lambda x: jnp.sum(jnp.sin(x))
    x = np.arange(4, dtype=np.float32)
    exp_f = get_exported(jax.jit(f), vjp_order=1)(x)

    f1 = exp_f.call
    self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x))

  def test_higher_order_grad(self):
    f = lambda x: x ** 3
    x = np.float32(4.)
    exp_f = get_exported(jax.jit(f), vjp_order=3)(x)

    f1 = exp_f.call
    self.assertAllClose(jax.grad(f)(x),
                        jax.grad(f1)(x))
    self.assertAllClose(jax.grad(jax.grad(f))(x),
                        jax.grad(jax.grad(f1))(x))
    self.assertAllClose(jax.grad(jax.grad(jax.grad(f)))(x),
                        jax.grad(jax.grad(jax.grad(f1)))(x))

  @jtu.parameterized_filterable(
    kwargs=[dict(poly_shape=True), dict(poly_shape=False)])
  def test_grad_int(self, poly_shape):
    def f(xi, xf):
      return (2 * xi.T, xf.T * xf.T)

    xi = np.arange(6, dtype=np.int32).reshape((2, 3))
    xf = np.arange(12, dtype=np.float32).reshape((3, 4))

    # Native JAX 1st order vjp
    (f_outi, f_outf), f_vjp = jax.vjp(f, xi, xf)
    f_outi_ct = np.ones(f_outi.shape,
                        dtype=core.primal_dtype_to_tangent_dtype(f_outi.dtype))
    f_outf_ct = np.ones(f_outf.shape, dtype=f_outf.dtype)
    xi_ct, xf_ct = f_vjp((f_outi_ct, f_outf_ct))

    # Native JAX 2nd order vjp
    res, f_vjp2 = jax.vjp(f_vjp, (f_outi_ct, f_outf_ct))
    self.assertAllClose(res, (xi_ct, xf_ct))
    (f_outi_ct2, f_outf_ct2), = f_vjp2((xi_ct, xf_ct))

    if poly_shape:
      args = export.symbolic_args_specs([xi, xf], shapes_specs=["2, a", "a, 4"])
    else:
      args = (xi, xf)
    exp = get_exported(jax.jit(f), vjp_order=2)(*args)
    fr = exp.call

    res = fr(xi, xf)
    self.assertAllClose(res, (f_outi, f_outf))

    # Reloaded 1st order vjp
    (fr_outi, fr_outf), fr_vjp = jax.vjp(fr, xi, xf)
    self.assertAllClose(fr_outi, f_outi)
    self.assertAllClose(fr_outf, f_outf)
    xri_ct, xrf_ct = fr_vjp((f_outi_ct, f_outf_ct))
    self.assertAllClose(xri_ct, xi_ct)
    self.assertAllClose(xrf_ct, xf_ct)

    # Reloaded 2nd order vjp
    res, f_vjp2 = jax.vjp(fr_vjp, (f_outi_ct, f_outf_ct))
    self.assertAllClose(res, (xi_ct, xf_ct))
    (fr_outi_ct2, fr_outf_ct2), = f_vjp2((xi_ct, xf_ct))
    self.assertAllClose(fr_outi_ct2, f_outi_ct2)
    self.assertAllClose(fr_outf_ct2, f_outf_ct2)

  def test_pytree_vjp(self):
    def f(a_b_pair, *, a, b):
      return (dict(res=a_b_pair, a=2. * a, b=3. * b),
              jnp.sin(4. * a))

    a = np.arange(4, dtype=np.float32)
    b = np.arange(6, dtype=np.float32)
    exp_f = get_exported(jax.jit(f), vjp_order=1)((a, b), a=a, b=b)

    out_ct = f((a, b), a=a, b=b)  # The output has the same structure as the cotangent
    def f1_jax(a, b):  # For VJP, make a function without kwargs
      res = f((a, b), a=a, b=b)
      return res
    def f1_exp(a, b):  # For VJP, make a function without kwargs
      res = exp_f.call((a, b), a=a, b=b)
      return res
    jax_vjp = jax.vjp(f1_jax, a, b)[1](out_ct)
    exp_vjp = jax.vjp(f1_exp, a, b)[1](out_ct)
    self.assertAllClose(jax_vjp, exp_vjp)

  def test_roundtrip(self):
    def f1(x):
      return jnp.sin(x)
    a = np.arange(4, dtype=np.float32)
    exp_f1 = get_exported(jax.jit(f1))(a)
    def f2(x):
      res1 = exp_f1.call(x)
      res2 = exp_f1.call(res1)
      return jnp.cos(res2)
    exp_f2 = get_exported(jax.jit(f2))(a)

    self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))),
                        exp_f2.call(a))

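  # A brief sketch of the symbolic-shape API exercised by the tests below:
  #   w, h = export.symbolic_shape("w, h")
  #   spec = jax.ShapeDtypeStruct((2 * w, h), np.float32)
  # Exporting with such a spec yields an Exported whose in_avals/out_avals
  # contain symbolic dimension expressions (e.g. "(3*w, h)").
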
  def test_poly_export_only(self):
    a = np.arange(12, dtype=np.float32).reshape((3, 4))
    def f(a, b):  # a: f32[2w,h]  b: f32[w,h]
      return jnp.concatenate([a, b], axis=0)

    scope = export.SymbolicScope()
    exp = get_exported(jax.jit(f))(
        jax.ShapeDtypeStruct(export.symbolic_shape("(2*w, h)", scope=scope), a.dtype),
        jax.ShapeDtypeStruct(export.symbolic_shape("(w, h)", scope=scope), a.dtype))
    self.assertEqual("(2*w, h)", str(exp.in_avals[0].shape))
    self.assertEqual("(w, h)", str(exp.in_avals[1].shape))
    self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape))

    # Peek at the module
    module_str = exp.mlir_module()
    self.assertEqual(config.jax_export_calling_convention_version.value >= 7,
                     "shape_assertion" in module_str)
    self.assertIn("jax.uses_shape_polymorphism = true", module_str)
    wrapped_main_expected_re = (
        r"@_wrapped_jax_export_main\("
        r"%arg0: tensor<i..> {jax.global_constant = \"h\".*"
        r"%arg1: tensor<i..> {jax.global_constant = \"w\".*"
        r"%arg2: tensor<\?x\?xf32>"
    )
    self.assertRegex(module_str, wrapped_main_expected_re)

    # Look for private inner functions that are generated to compute the
    # dimension variables and shape assertions. All those functions must
    # have jax.global_constant attributes on all the arguments.
    for func_name, func_args in re.findall(
        r"func.func private @([\w]+)\((.+)\) ->",
        module_str):
      if func_name == "_wrapped_jax_export_main":
        continue
      func_args_count = len(re.findall(r"%arg\d+", func_args))
      func_args_constant_attrs = len(re.findall(r"jax.global_constant = ",
                                                func_args))
      self.assertEqual(func_args_count, func_args_constant_attrs)

  def test_poly_pytree_export_only(self):
    a = np.arange(12, dtype=np.float32).reshape((3, 4))
    def f(a0, a1, *, ak):
      return jnp.concatenate([a0, a1, ak], axis=0)

    a_poly_spec = jax.ShapeDtypeStruct(export.symbolic_shape("(w, h)"), a.dtype)
    exp = get_exported(jax.jit(f))(a_poly_spec, a_poly_spec, ak=a_poly_spec)
    self.assertEqual("(w, h)", str(exp.in_avals[0].shape))
    self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape))

  def test_poly_export_error_symbolic_scope(self):
    a = np.arange(12, dtype=np.float32).reshape((3, 4))
    def f(x, y):
      return jnp.concatenate([x, y], axis=1)

    x_poly_spec = jax.ShapeDtypeStruct(export.symbolic_shape("(w, h1)"), a.dtype)
    y_poly_spec = jax.ShapeDtypeStruct(export.symbolic_shape("(w, h2)"), a.dtype)
    with self.assertRaisesRegex(
        ValueError,
        re.compile(
            "Invalid mixing of symbolic scopes when exporting f.*"
            r"Expected current \(from args\[0\]\) scope .*"
            r"and found for 'w' \(args\[1\]\) scope .*", re.DOTALL)):
      get_exported(jax.jit(f))(x_poly_spec, y_poly_spec)

  def test_poly_export_callable_with_no_name(self):
    # This was reported by a user
    class MyCallable:
      def __call__(self, x):
        return jnp.sin(x)

      # This makes it look like a jitted-function
      def lower(self, x, _experimental_lowering_parameters=None):
        return jax.jit(self.__call__).lower(
            x,
            _experimental_lowering_parameters=_experimental_lowering_parameters)

      def trace(self, x, _experimental_lowering_parameters=None):
        return jax.jit(self.__call__).trace(
            x,
            _experimental_lowering_parameters=_experimental_lowering_parameters)

    a, = export.symbolic_shape("a,")
    # No error
    _ = get_exported(jax.jit(MyCallable()))(
        jax.ShapeDtypeStruct((a, a), dtype=np.float32)
    )

  @jtu.parameterized_filterable(
    kwargs=[
      dict(v=v)
      for v in range(export.minimum_supported_calling_convention_version - 1,
                     export.maximum_supported_calling_convention_version + 2)])
  def test_poly_basic_versions(self, v: int):
    with config.jax_export_calling_convention_version(v):
      logging.info(
          "Using JAX calling convention version %s",
          config.jax_export_calling_convention_version.value)
      with contextlib.ExitStack() as e:
        if not (export.minimum_supported_calling_convention_version <= v
                <= export.maximum_supported_calling_convention_version):
          e.enter_context(self.assertRaisesRegex(
            ValueError,
            f"The requested export calling convention version {v} is outside the range of supported versions"))

        exp = get_exported(jnp.sin)(
            jax.ShapeDtypeStruct(export.symbolic_shape("w, h"), np.float32))
        x = np.arange(30, dtype=np.float32).reshape((5, 6))
        res = exp.call(x)
        self.assertAllClose(res, np.sin(x))

  # A function is exported with f32[poly_spec] and is called with different arg
  # shapes. We use export.call and we also run the shape check
  # module.
  @jtu.parameterized_filterable(
    testcase_name=lambda kw:f"poly_spec={kw['poly_spec']}_arg_shape={kw['arg_shape']}",  # type: ignore
    kwargs=[
      dict(poly_spec="3,4,12", arg_shape=(3, 4, 12)),
      dict(poly_spec="3,4,12", arg_shape=(3, 4, 13),
           # The shape check module does not test constant dimensions
           expect_error=re.escape(
               r"Shape mismatch for args[0].shape[2] (expected same constant)")),
      dict(poly_spec="3,4,6*a", arg_shape=(3, 4, 12)),
      dict(poly_spec="3,a,a+8", arg_shape=(3, 4, 12)),
      dict(poly_spec="3,4,a+1", arg_shape=(3, 4, 1),
           expect_error=re.escape(
               "Expected value >= 1 for dimension variable 'a'. "
               "Using the following polymorphic shapes specifications: args[0].shape = (3, 4, a + 1). "
               "Obtained dimension variables: 'a' = 0"
           )),
      dict(poly_spec="3,4,6*a", arg_shape=(3, 4, 13),
           expect_error=re.escape(
               "Division had remainder 1 when computing the value of 'a'"
           )),
      dict(poly_spec="3,a,a+8", arg_shape=(3, 4, 13),
           expect_error=re.escape(
               "Found inconsistency between dimension size "
               "args[0].shape[2] (= 13) and the specification 'a + 8' (= 12)"
           )),
  ])
  def test_poly_shape_checks(
      self, poly_spec="3,a,a+8",
      arg_shape=(3, 4, 12), arg_dtype=np.float32,
      expect_error=None):  # If given, error from running the exported module

    def f(x):  # x: f32[poly_spec]
      return jnp.reshape(x, (-1, x.shape[1]))

    disabled_checks = ()
    exp_f = get_exported(jax.jit(f), disabled_checks=disabled_checks)(
        jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), np.float32))
    self.assertEqual(exp_f.uses_global_constants, poly_spec != "3,4,12")
    arg = np.arange(np.prod(arg_shape),
                    dtype=arg_dtype).reshape(arg_shape)  # arg : f32[3,4,12]

    with contextlib.ExitStack() as stack:
      if expect_error is not None:
        stack.push(self.assertRaisesRegex(Exception, expect_error))

      assert core.is_constant_shape(arg.shape)
      res = exp_f.call(arg)

    if not expect_error:
      self.assertAllClose(res, f(arg))

  # An inner function is exported with polymorphic shapes inner_poly_spec, and
  # is called from an outer function, which is exported with outer_poly_spec.
  @jtu.parameterized_filterable(
    testcase_name=lambda kw:f"inner={kw['inner_poly_spec']}_outer={kw['outer_poly_spec']}",  # type: ignore
    #one_containing="",
    # By default arg_shape = (3, 4, 12) for both the outer function and the inner
    # The inner function is exported for f32.
    kwargs=[
      # Both inner and outer are static shapes
      dict(inner_poly_spec="3,4,12", outer_poly_spec="3,4,12"),
      # Inner has poly shapes but outer has static shapes. When we call inner
      # we do the shape constraint checking
      dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,12"),
      dict(inner_poly_spec="3,4,3*a", outer_poly_spec="3,4,12"),
      dict(inner_poly_spec="3,a,a", outer_poly_spec="3,4,12",
           expect_error_outer_exp=re.escape(
               "Found inconsistency between dimension size "
               "args[0].shape[2] (= 12) and the specification 'a' (= 4)")),
      dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,12",
           expect_error_outer_exp=re.escape(
               "Division had remainder 2 when computing the value of 'a'")),
      dict(inner_poly_spec="3,4,12+a", outer_poly_spec="3,4,12",
           expect_error_outer_exp=re.escape(
               "Expected value >= 1 for dimension variable 'a'. "
               "Using the following polymorphic shapes specifications: args[0].shape = (3, 4, a + 12). "
               "Obtained dimension variables: 'a' = 0 from specification "
               "'a + 12' for dimension args[0].shape[2] (= 12)")),
      # Both inner and outer have poly shapes.
      dict(inner_poly_spec="3,a,b", outer_poly_spec="3,4,c"),
      dict(inner_poly_spec="3,4,3*a", outer_poly_spec="3,4,6*c"),
      dict(inner_poly_spec="3,a,a+8", outer_poly_spec="3,c+2,c+10"),
      dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,c",
           expect_error_outer_exp=re.escape(
               "Expected value >= 1 for dimension variable 'b'. "
               "Using the following polymorphic shapes specifications: args[0].shape = (3, a, b + a). "
               "Obtained dimension variables: 'a' = 4 from specification "
               "'a' for dimension args[0].shape[1] (= 4), "
               "'b' = c - 4 from specification 'b + a' for dimension args[0].shape[2] (= c),")),
      dict(inner_poly_spec="3,a,a", outer_poly_spec="3,4,c",
           expect_error_outer_exp=re.escape(
               "Found inconsistency between dimension size "
               "args[0].shape[2] (= c) and the specification 'a' (= 4)")),
      dict(inner_poly_spec="3,a,a", arg_shape=(3, 4),
           outer_poly_spec="3,c",
           expect_error_outer_exp=r"Rank mismatch for args\[0\]"),
      dict(inner_poly_spec="3,a,a+b", arg_dtype=np.int32,
           outer_poly_spec="3,c,d",
           expect_error_outer_exp=r"Dtype mismatch for args\[0\]"),
      dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,c",
           expect_error_outer_exp=re.escape(
               "Division had remainder mod(c, 5) when computing the value of 'a'")),
      dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c,c",
           expect_error_outer_exp=re.escape(
               "Expected value >= 1 for dimension variable 'b'. "
               "Using the following polymorphic shapes specifications: args[0].shape = (3, a, b + a). "
               "Obtained dimension variables: 'a' = c from "
               "specification 'a' for dimension args[0].shape[1] (= c), "
               "'b' = 0 from specification 'b + a' for dimension args[0].shape[2] (= c)")),
      dict(inner_poly_spec="3,a,a+b", outer_poly_spec="c,4,12",
           expect_error_outer_exp=re.escape(
               "Shape mismatch for args[0].shape[0] (expected same constant)")),
      dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,25*c",
           expect_error_run=re.escape(
               "Division had remainder 12 when computing the value of 'c'")),
      dict(inner_poly_spec="3,a,b", outer_poly_spec="3,c+4,12",
           expect_error_run=re.escape(
               "Expected value >= 1 for dimension variable 'c'. "
               "Using the following polymorphic shapes specifications: args[0].shape = (3, c + 4, 12). "
               "Obtained dimension variables: 'c' = 0")),
      dict(inner_poly_spec="3,a,a", outer_poly_spec="3,a,a",
           expect_error_run=re.escape(
               "Found inconsistency between dimension size "
               "args[0].shape[2] (= 12) and the specification 'a' (= 4)")),
  ])
  def test_poly_shape_checks_nested(
      self, inner_poly_spec="3,4,5*a",
      arg_shape=(3, 4, 12), arg_dtype=np.float32,
      outer_poly_spec="3,4,25*c",
      expect_error_outer_exp=None,
      expect_error_run=None):
    # Polymorphic export called with static or polymorphic shapes
    def inner(x):  # x: inner_poly_spec
      return jnp.reshape(x, (-1, x.shape[1]))

    arg = np.arange(np.prod(arg_shape),
                    dtype=arg_dtype).reshape(arg_shape)  # x : f32[3,4,12]
    inner_exp = get_exported(jax.jit(inner))(
        jax.ShapeDtypeStruct(export.symbolic_shape(inner_poly_spec), np.float32))

    self.assertEqual(inner_exp.uses_global_constants,
                     (inner_poly_spec != "3,4,12"))
    def outer(x):  # x: outer_poly_spec
      # Use an addition to test that the shapes are refined properly for the
      # result of the call_exported.
      return inner_exp.call(x) + inner(x)

    with contextlib.ExitStack() as stack:
      if expect_error_outer_exp is not None:
        stack.push(self.assertRaisesRegex(ValueError, expect_error_outer_exp))

      # Call it after exporting again, with polymorphic shapes
      outer_exp = get_exported(jax.jit(outer))(
          jax.ShapeDtypeStruct(export.symbolic_shape(outer_poly_spec), arg.dtype))

    if expect_error_outer_exp is not None:
      return

    self.assertEqual(outer_exp.uses_global_constants,
                     (inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12"))

    with contextlib.ExitStack() as stack:
      if expect_error_run is not None:
        stack.push(self.assertRaisesRegex(Exception, expect_error_run))

      res = outer_exp.call(arg)

    if expect_error_run is not None:
      return
    self.assertAllClose(2. * inner(arg), res)

  # Tests details of the shape constraints errors
  # This test exists also in shape_poly_test.py. Here we test the
  # call_exported error reporting.
  @jtu.parameterized_filterable(
    testcase_name=lambda kw: kw["shape"],  # assume "shape" is unique
    kwargs=[
      dict(shape=(8, 2, 9),  # a = 2, b = 3, c = 4
           poly_spec="(a + 2*b, a, a + b + c)"),
      dict(shape=(2, 2, 6),  # a = 2, b = 0, c = 4
           poly_spec="(a + 2*b, a, a + b + c)",
           expect_error=(
               "Input shapes do not match the polymorphic shapes specification. "
               "Expected value >= 1 for dimension variable 'b'. "
               "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). "
               "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), "
               "'b' = 0 from specification '2*b + a' for dimension args[0].shape[0] (= 2), . "
               "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details."
           )),
      dict(shape=(3, 2, 6),  # a = 2, b = 0.5, c = 4 - b is not integer
           poly_spec="(a + 2*b, a, a + b + c)",
           expect_error=(
               "Input shapes do not match the polymorphic shapes specification. "
               "Division had remainder 1 when computing the value of 'b'. "
               "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). "
               "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), . "
               "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details."
           )),
      dict(shape=(8, 2, 6),  # a = 2, b = 3 - inconsistency
           poly_spec="(a + 2*b, a, a + b)",
           expect_error=(
               "Input shapes do not match the polymorphic shapes specification. "
               "Found inconsistency between dimension size args[0].shape[0] (= 8) and the specification '2*b + a' (= 10). "
               "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, b + a). "
               "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), "
               "'b' = 4 from specification 'b + a' for dimension args[0].shape[2] (= 6), . "
               "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details."
           )),
      dict(shape=(7, 2, 36),  # a = 2, b = 3, c = 6 - cannot solve c
           poly_spec="(2 * a + b, a, c * c)",
           expect_error=(
               "Cannot solve for values of dimension variables {'c'}. "
               "We can only solve linear uni-variate constraints. "
               "Using the following polymorphic shapes specifications: args[0].shape = (b + 2*a, a, c^2). "
               "Unprocessed specifications: 'c^2' for dimension size args[0].shape[2]. "
               "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details."
           )),
  ])
  def test_shape_constraints_errors(self, *,
      shape, poly_spec: str, expect_error: str | None = None):
    def f_jax(x):  # x: f32[a + 2*b, a, a + b + c]
      return 0.

    x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    with contextlib.ExitStack() as stack:
      if expect_error is not None:
        stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error)))
      exp = get_exported(jax.jit(f_jax))(
          jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), x.dtype))
      exp.call(x)

  def test_poly_booleans(self):
    # For booleans we use a special-case ConvertOp to cast to and from
    # dynamic-shape arguments.
    @jax.jit
    def f_jax(x):  # x: bool[b]
      return jnp.logical_not(x)

    x = np.array([True, False, True, False], dtype=np.bool_)
    exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
                                                   x.dtype))
    res = exp.call(x)
    self.assertAllClose(f_jax(x), res)

  @jtu.parameterized_filterable(
    kwargs=[
      dict(dtype=dtype)
      for dtype in dtypes._jax_types if dtype != np.dtype("bool")
  ])
  def test_poly_numeric_dtypes(self, dtype=np.int32):
    if str(dtype) in {"float8_e4m3b11fnuz",
                      "float8_e4m3fnuz",
                      "float8_e5m2fnuz",
                      "int2",
                      "int4",
                      "uint2",
                      "uint4"}:
      self.skipTest(f"TODO: serialization not supported for {str(dtype)}")
    if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']):
      self.skipTest("TPU does not support float8_e8m0fnu.")
    if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']):
      self.skipTest("TPU does not support float4_e2m1fn.")
    @jax.jit
    def f_jax(x):
      return x + x

    x = np.arange(6, dtype=dtype)
    exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
                                                   x.dtype))
    res = exp.call(x)
    self.assertAllClose(f_jax(x), res)

  def test_poly_expressions(self):
    # Calling an Exported module whose output shape contains symbolic
    # expressions
    def output_shape(b):
      return (b + b, b - b, b * b,
              (b + 13) // b, (b + 13) % b,
              core.max_dim(b - 5, 0))
    @jax.jit
    def f(x):  # x: f32[b]
      b = x.shape[0]
      return jnp.ones(output_shape(b), dtype=x.dtype)
    x = np.arange(5, dtype=np.float32)
    exp = get_exported(f)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
                                               x.dtype))
    # Call with static shapes
    res = exp.call(x)
    self.assertAllClose(res, f(x))

    # Now re-export with shape polymorphism
    x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("a"), x.dtype)
    exp2 = get_exported(jax.jit(exp.call))(x_spec)
    a = exp2.in_avals[0].shape[0]
    self.assertEqual(exp2.out_avals[0].shape, output_shape(a))

  def test_with_donation(self):
    f = jax.jit(jnp.sin, donate_argnums=(0,))
    x = np.arange(3, dtype=np.float32)
    exp = export.export(f)(x)

    def caller(x):
      y = exp.call(x)
      return x + y
    res = jax.jit(caller)(x)
    self.assertAllClose(res, x + np.sin(x))

  def test_poly_call_pmap(self):
    if len(jax.devices()) < 2:
      self.skipTest("Need at least 2 devices")
    def f(x):  # x: f32[a, 4]
      return x + jnp.arange(x.shape[0], dtype=x.dtype).reshape((x.shape[0], 1))

    a, = export.symbolic_shape("a")
    exp = export.export(jax.jit(f))(
        jax.ShapeDtypeStruct((a, 4), np.float32))
    f_exp = exp.call
    x_jit = np.arange(12, dtype=np.float32).reshape((3, 4))
    res_jit = jax.jit(f_exp)(x_jit)
    self.assertAllClose(res_jit, f(x_jit))
    x_pmap = np.arange(24, dtype=np.float32).reshape((2, 3, 4))
    res_pmap = jax.pmap(f_exp)(x_pmap)
    self.assertAllClose(res_pmap, jnp.stack([f(x) for x in x_pmap]))

  def test_with_sharding(self):
    nr_devices = 2
    if len(jax.devices()) < nr_devices:
      self.skipTest("Need at least 2 devices")
    export_devices = jax.devices()[0:nr_devices]
    export_mesh = Mesh(export_devices, axis_names=("x",))
    a = np.arange(16 * 4, dtype=np.float32).reshape((16, 4))
    @functools.partial(
        jax.jit,
        in_shardings=(jax.sharding.NamedSharding(export_mesh, P("x", None),),),
        out_shardings=jax.sharding.NamedSharding(export_mesh, P(None, "x")))
    def f_jax(b):  # b: f32[16 // DEVICES, 4]
      return b * 2.

    res_native = f_jax(a)
    exp = get_exported(f_jax)(a)

    self.assertEqual(exp.nr_devices, len(export_devices))
    run_devices = export_devices[::-1]  # We can use other devices
    run_mesh = Mesh(run_devices, "y")
    a_device = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P()))
    if config.use_shardy_partitioner.value:
      expected_re = re.compile(
          # The top-level input is replicated
          r"func.func .* @main\(%arg0: tensor<16x4xf32>.*sdy.sharding ="
          r" #sdy.sharding<@mesh_0, \[\{\}\, \{\}\]>}\).*"
          # We apply the in_shardings for f_jax
          r".*sdy.sharding_constraint %arg0 <@mesh, \[\{\"x\"\}\, \{\}\]>.*"
          r"%1 = .*call @call_exported_f_jax.*"
          # We apply the out_shardings for f_jax
          r".*sdy.sharding_constraint %1 <@mesh, \[\{\}, \{\"x\"\}\]>.*",
          re.DOTALL)
    else:
      expected_re = re.compile(
          # The top-level input is replicated
          r"func.func .* @main\(%arg0: tensor<16x4xf32>.*mhlo.sharding ="
          r" \"{replicated}\"}\).*"
          # We apply the in_shardings for f_jax
          r".*custom_call @Sharding\(%arg0\).*mhlo.sharding ="
          r" \"{devices=\[2,1\]<=\[2\]}\"}.*"
          r"%1 = .*call @call_exported_f_jax.*"
          # We apply the out_shardings for f_jax
          r".*custom_call @Sharding\(%1\).*mhlo.sharding ="
          r" \"{devices=\[1,2\]<=\[2\]}\"}.*",
          re.DOTALL)
    hlo = jax.jit(exp.call).lower(a_device).as_text()
    self.assertRegex(hlo, expected_re)

    res_exported = exp.call(a_device)
    self.assertAllClose(res_native, res_exported)

    # Test error reporting
    with self.assertRaisesRegex(
        ValueError,
        "Function .* was exported for 2 devices and is called in a context with 1 device"):
      _ = exp.call(a)

    with self.assertRaisesRegex(
        ValueError,
        "Function .* was exported for 2 devices and is called in a context with 1 device"):
      mesh1 = Mesh(jax.devices()[0:1], axis_names=("x",))
      _ = jax.jit(
          exp.call,
          in_shardings=(jax.sharding.NamedSharding(mesh1, P("x", None)),)
      )(a)

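  # Note (a sketch of the API used below): `exp.in_shardings_jax(mesh)` rebuilds
  # the exported input shardings as jax.sharding.NamedSharding objects on the
  # given run-time mesh, so the inputs can be device_put accordingly before
  # calling the Exported.
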
  def test_input_shardings_unused_args(self):
    nr_devices = 2
    if len(jax.devices()) < nr_devices:
      self.skipTest("Need at least 2 devices")
    devices = jax.devices()[0:nr_devices]
    export_mesh = Mesh(np.array(devices),
                       axis_names=("x",))
    a = np.arange(16 * 4, dtype=np.float32).reshape((16, 4))

    f = jax.jit(lambda x, y: jnp.sin(x),
                in_shardings=(jax.sharding.NamedSharding(export_mesh, P("x", None),),
                              None),
                out_shardings=(jax.sharding.NamedSharding(export_mesh, P("x", None),)))
    exp = get_exported(f)(a, a)

    # We can use other devices and other meshes for running
    run_devices = devices[::-1]
    run_mesh = Mesh(run_devices, "a")
    run_input_shardings = exp.in_shardings_jax(run_mesh)
    a_run = jax.device_put(a, run_input_shardings[0])
    b_run = jax.device_put(a, run_input_shardings[1])
    res = exp.call(a_run, b_run)
    self.assertEqual(res.addressable_shards[0].device, run_devices[0])
    self.assertEqual(res.addressable_shards[1].device, run_devices[1])

  def test_export_abstract_mesh(self):
    if jax.local_device_count() < 2:
      self.skipTest("Need at least 2 devices")

    abs_mesh = jax.sharding.AbstractMesh((2,), 'x')
    input_sharding = jax.sharding.NamedSharding(abs_mesh, P("x", None))
    output_sharding = jax.sharding.NamedSharding(abs_mesh, P(None, "x"))
    @jax.jit
    def f(a):
      b = a @ a.T
      return jax.lax.with_sharding_constraint(b, output_sharding)

    exp = get_exported(f)(
        jax.ShapeDtypeStruct((16, 16), dtype=np.float32,
                             sharding=input_sharding))
    # Call the Exported with a concrete Mesh
    devices = jax.local_devices()[:2]
    run_mesh = Mesh(devices, ("x",))
    a_sharding = jax.sharding.NamedSharding(run_mesh, P("x", None))
    a = jnp.arange(16 * 16, dtype=np.float32).reshape((16, 16))
    a = jax.device_put(a, a_sharding)

    res = exp.call(a)
    self.assertAllClose(res, f(a))
    self.assertLen(res.addressable_shards, 2)
    self.assertEqual(res.addressable_shards[0].index, (slice(None), slice(0, 8)))
    self.assertEqual(res.addressable_shards[1].index, (slice(None), slice(8, 16)))

  def test_call_single_device_export_with_different_no_of_devices(self):
    if jax.local_device_count() < 2:
      self.skipTest("Need at least 2 devices")

    @jax.jit
    def f_without_shardings(x):
      return jnp.sum(x ** 2, axis=0)

    a = jnp.arange(jax.local_device_count() * 10, dtype=np.float32).reshape(
        (jax.local_device_count(), 10)
    )
    res_native = f_without_shardings(a)
    exp = get_exported(f_without_shardings)(a)
    self.assertEqual(exp.nr_devices, 1)

    run_devices = jax.local_devices()
    run_mesh = Mesh(run_devices, "i")
    b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i")))

    res_exported = exp.call(b)
    self.assertAllClose(res_native, res_exported)

  def test_call_with_different_no_of_devices_error_has_in_shardings(self):
    if jax.local_device_count() < 2:
      self.skipTest("Need at least 2 devices")

    mesh_1 = Mesh(jax.local_devices()[:1], "i")
    @functools.partial(pjit.pjit,
                       in_shardings=NamedSharding(mesh_1, P("i")))
    def f_with_sharding(x):
      return jnp.sum(x ** 2, axis=0)

    a = jnp.arange(jax.device_count() * 10, dtype=np.float32).reshape(
        (jax.device_count(), 10)
    )
    exp = get_exported(f_with_sharding)(a)
    self.assertEqual(exp.nr_devices, 1)

    run_devices = jax.local_devices()
    run_mesh = Mesh(run_devices, "i")
    b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i")))

    with self.assertRaisesRegex(
        ValueError,
        "Function .* was exported for 1 devices and is called in a "
        f"context with {jax.local_device_count()} devices.* function contains "
        "non-replicated sharding annotations"):
      exp.call(b)

  def test_call_with_different_no_of_devices_pmap(self):
    if len(jax.devices()) < 2:
      self.skipTest("Need at least 2 devices")

    @jax.jit
    def f_jax(x):
      return jnp.sum(x ** 2, axis=0)

    a = jnp.arange(100, dtype=jnp.float32).reshape((1, 100))
    res_native = f_jax(a)
    exp = get_exported(f_jax)(a)
    self.assertEqual(exp.nr_devices, 1)

    b = jnp.arange(jax.device_count() * 100, dtype=jnp.float32).reshape(
        (-1, 1, 100)
    )
    res_exported = jax.pmap(exp.call)(b)
    self.assertAllClose(res_native, res_exported[0])

  def test_call_with_different_no_of_devices_error_has_sharding_constraint(self):
    if jax.device_count() < 2:
      self.skipTest("Need at least 2 devices")

    mesh_1 = Mesh(jax.local_devices()[:1], "i")
    @jax.jit
    def f_with_sharding(x):
      x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh_1, P("i")))
      return jnp.sum(x ** 2, axis=0)

    a = jnp.arange(jax.device_count() * 10, dtype=np.float32).reshape(
        (jax.device_count(), 10)
    )
    exp = get_exported(f_with_sharding)(a)
    self.assertEqual(exp.nr_devices, 1)

    run_devices = jax.local_devices()
    run_mesh = Mesh(run_devices, "i")
    b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i")))

    with self.assertRaisesRegex(
        ValueError,
        "Function .* was exported for 1 devices and is called in a "
        f"context with {jax.local_device_count()} devices.* function contains "
        "non-replicated sharding annotations"):
      exp.call(b)

  @jtu.parameterized_filterable(
    kwargs=[
      dict(testcase_name=f"_poly={poly}", poly=poly)
      for poly in (None, "2*b1,_", "_,b2", "2*b1,b2")
  ])
  def test_shard_map_collective_permute(self, poly=None):
    if len(jax.devices()) < 2:
      self.skipTest("Test requires at least 2 local devices")
    devices = np.array(jax.devices()[:2])  # use 2 devices
    mesh = Mesh(devices, axis_names=("x",))
    a = np.arange(4 * 4, dtype=np.float32).reshape((4, 4))

    @functools.partial(
        pjit.pjit,
        in_shardings=NamedSharding(mesh, P("x", None),),
        out_shardings=NamedSharding(mesh, P("x", None)))
    @functools.partial(
        shard_map, mesh=mesh,
        in_specs=(P("x", None),), out_specs=P("x", None))
    def f_jax(b):  # b: f32[2, 4]
      axis_size = lax.psum(1, "x")
      perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
      return lax.ppermute(b, "x", perm=perm)

    args_specs = export.symbolic_args_specs((a,), poly)
    exp = get_exported(f_jax)(*args_specs)

    # Test JAX native execution
    res_jax = f_jax(a)
    b0, b1 = np.split(a, 2, axis=0)  # The shard_map splits on axis 0
    b0, b1 = b1, b0
    expected = np.concatenate([b0, b1], axis=0)  # out_specs concatenates on axis 0
    self.assertAllClose(res_jax, expected)
    self.assertLen(res_jax.addressable_shards, len(devices))

    # Test reloaded execution.
    f_r = exp.call
    with self.assertRaisesRegex(
        Exception,
        "Function .* was exported for 2 devices and is "
        "called in a context with 1 devices"):
      _ = f_r(a)  # `a` is entirely on the default device

    # Replicate the input so that the execution knows
    # that we are using multiple devices
    a_replicated = jax.device_put(a, NamedSharding(mesh, P()))
    res_r = f_r(a_replicated)
    self.assertAllClose(res_r, expected)
    self.assertLen(res_r.addressable_shards, len(devices))
    for i in range(len(devices)):
      self.assertEqual(res_jax.addressable_shards[i].device,
                       res_r.addressable_shards[i].device)
      self.assertEqual(res_jax.addressable_shards[i].index,
                       res_r.addressable_shards[i].index)
      self.assertAllClose(res_jax.addressable_shards[i].data,
                          res_r.addressable_shards[i].data)

  @jtu.parameterized_filterable(
    kwargs=[
      dict(in_shardings=in_shardings, out_shardings=out_shardings,
           with_mesh_context=with_mesh_context)
      for in_shardings in ("missing", None, "P")
      for out_shardings in ("missing", None, "P")
      for with_mesh_context in (True, False)
  ])
  def test_grad_with_sharding(self, in_shardings="P", out_shardings=None,
                              with_mesh_context=False):
    if len(jax.devices()) < 2:
      self.skipTest("Test requires at least 2 devices")
    x_shape = (10, 20)
    x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
    # The input has shape f32[10,20] and output f32[20,10] in order to
    # distinguish them in the HLO.
    def f_jax(x):  # x: f32[10,20] -> f32[20,10]
      return jnp.sin(x.T)

    mesh = Mesh(jax.devices()[:2], "d")
    pjit_kwargs = {}
    # Use NamedShardings if we don't have a mesh_context
    if with_mesh_context:
      sharding_None_d = P(None, "d")
      sharding_d_None = P("d", None)
    else:
      sharding_None_d = NamedSharding(mesh, P(None, "d"))
      sharding_d_None = NamedSharding(mesh, P("d", None))

    if in_shardings != "missing":
      pjit_kwargs["in_shardings"] = (
        sharding_None_d if in_shardings == "P" else None)
    if out_shardings != "missing":
      pjit_kwargs["out_shardings"] = (
        sharding_d_None if out_shardings == "P" else None)
    f_jax_pjit = pjit.pjit(f_jax, **pjit_kwargs)

    with contextlib.ExitStack() as stack:
      if with_mesh_context:
        stack.enter_context(mesh)
      # Serialize higher-order gradients
      exp = get_exported(f_jax_pjit, vjp_order=2)(x)
      exp_vjp = exp.vjp()
      # Try 2nd order grad as well
      exp_vjp2 = exp_vjp.vjp()

    vjp_module_str = str(exp_vjp.mlir_module())

    # The MHLO attributes of the args and the result of the main function.
    # arg0 is the primal input, arg1 is the output cotangent, and the result
    # is the input cotangent.
    arg0_attrs, arg1_attrs, res_attrs = re.search(
        r"func.func public @main\(%arg0: tensor<10x20xf32> (.*)"
        r", %arg1: tensor<20x10xf32> (.*)"
        r"\) -> \(tensor<10x20xf32> (.*)",  # the result
        vjp_module_str).groups()

    if in_shardings == "P":
      self.assertRegex(arg0_attrs, re.escape("{devices=[1,2]<=[2]}"))
      self.assertRegex(res_attrs, re.escape("{devices=[1,2]<=[2]}"))
      primal_in_sharding = "{devices=[1,2]<=[2]}"
    else:
      primal_in_sharding = "{replicated}"
      if with_mesh_context:
        self.assertRegex(arg0_attrs, re.escape("replicated"))
        self.assertRegex(res_attrs, re.escape("replicated"))
      else:
        # If there is no mesh context, we have used NamedSharding(None)
        # and then the sharding is unspecified!
        self.assertNotIn("mhlo.sharding", arg0_attrs)
        self.assertNotIn("mhlo.sharding", res_attrs)

    if out_shardings == "P":
      self.assertRegex(arg1_attrs, re.escape("{devices=[2,1]<=[2]}"))
      primal_out_sharding = "{devices=[2,1]<=[2]}"
    else:
      primal_out_sharding = "{replicated}"
      if with_mesh_context:
        self.assertRegex(arg1_attrs, re.escape("replicated"))
      else:
        self.assertNotIn("mhlo.sharding", arg1_attrs)

    # Sharding custom calls for the primal input shape all match primal_in_sharding
    primal_in_sharding_calls = re.findall(
        r"custom_call @Sharding.*mhlo.sharding = \"(.+)\".*:.*tensor<10x20xf32>",
        vjp_module_str)
    self.assertTrue(
        all(s == primal_in_sharding for s in primal_in_sharding_calls),
        primal_in_sharding_calls
    )

    # Custom calls for the primal output shape all match primal_out_sharding
    primal_out_sharding_calls = re.findall(
        r"custom_call @Sharding.*mhlo.sharding = \"(.+)\".*:.*tensor<20x10xf32>",
        vjp_module_str)
    self.assertTrue(
        all(s == primal_out_sharding for s in primal_out_sharding_calls),
        primal_out_sharding_calls
    )

    # Call the exported gradient functions. In order to set the device context
    # we replicate the inputs. If we don't use a mesh context and there are
    # no shardings on inputs or outputs, then we have serialized for one
    # device.
    if in_shardings != "P" and out_shardings != "P" and not with_mesh_context:
      self.assertEqual(exp_vjp.nr_devices, 1)
      self.assertEqual(exp_vjp2.nr_devices, 1)
      call_mesh = Mesh(jax.devices()[:1], "e")
    else:
      self.assertEqual(exp_vjp.nr_devices, 2)
      self.assertEqual(exp_vjp2.nr_devices, 2)
      call_mesh = Mesh(jax.devices()[:2], "e")

    g1 = pjit.pjit(exp_vjp.call,
                   in_shardings=(NamedSharding(call_mesh, P()),
                                 NamedSharding(call_mesh, P())))(x, x.T)
    _, f_jax_vjp = jax.vjp(f_jax, x)
    xbar = f_jax_vjp(x.T)
    self.assertAllClose(xbar, g1)

    g2 = pjit.pjit(exp_vjp2.call,
                   in_shardings=(NamedSharding(call_mesh, P()),
                                 NamedSharding(call_mesh, P()),
                                 NamedSharding(call_mesh, P())))(x, x.T, x)
    _, f_jax_vjp2 = jax.vjp(f_jax_vjp, x.T)
    xbar2, = f_jax_vjp2((x,))
    self.assertAllClose(xbar2, g2[1])

  def test_grad_sharding_different_mesh(self):
    # Export and serialize with two similar meshes, the only difference being
    # the order of the devices. grad and serialization should not fail.
    # https://github.com/jax-ml/jax/issues/21314
    def f(x):
      return jnp.sum(x * 2.)

    mesh = Mesh(jax.local_devices(), "i")
    mesh_rev = Mesh(list(reversed(jax.local_devices())), "i")
    shardings = NamedSharding(mesh, jax.sharding.PartitionSpec(("i",)))
    shardings_rev = NamedSharding(mesh_rev, jax.sharding.PartitionSpec(("i",)))
    input_no_shards = jnp.ones(shape=(jax.local_device_count(),))
    input = jnp.ones(shape=(jax.local_device_count(),), device=shardings)
    input_rev = jax.device_put(input_no_shards, device=shardings_rev)

    exp = export.export(pjit.pjit(f, in_shardings=shardings))(input)
    exp_rev = export.export(pjit.pjit(f, in_shardings=shardings_rev))(input_no_shards)

    if CAN_SERIALIZE:
      _ = exp.serialize(vjp_order=1)
      _ = exp_rev.serialize(vjp_order=1)
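
    # The gradients must agree regardless of the device order used at export
    # time.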
    g = jax.grad(exp_rev.call)(input_rev)
    g_rev = jax.grad(exp.call)(input)
    self.assertAllClose(g, g_rev)

  def test_multi_platform(self):
    x = np.arange(8, dtype=np.float32)
    exp = get_exported(jax.jit(_testing_multi_platform_func),
                       platforms=("tpu", "cpu", "cuda", "rocm"))(x)
    self.assertEqual(exp.platforms, ("tpu", "cpu", "cuda", "rocm"))
    module_str = str(exp.mlir_module())
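
    # The main function takes an extra platform-index argument (checked by the
    # regex below) that selects among the per-platform lowerings at run time.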
    expected_main_re = (
        r"@main\("
        r"%arg0: tensor<i..>.*jax.global_constant = \"_platform_index\".*, "
        r"%arg1: tensor<8xf32>.*->")
    self.assertRegex(module_str, expected_main_re)

    self.assertIn("jax.uses_shape_polymorphism = true",
                  module_str)

    # Call with the argument placed on different platforms
    for platform in self.platforms:
      x_device = jax.device_put(x, jax.devices(platform)[0])
      res_exp = exp.call(x_device)
      self.assertAllClose(
          res_exp,
          _testing_multi_platform_fun_expected(x, platform=platform))

  def test_multi_platform_nested(self):
    x = np.arange(5, dtype=np.float32)
    exp = get_exported(jax.jit(lambda x: _testing_multi_platform_func(jnp.sin(x))),
                       platforms=("cpu", "tpu", "cuda", "rocm"))(x)
    self.assertEqual(exp.platforms, ("cpu", "tpu", "cuda", "rocm"))

    # Now serialize the call to the exported, using a different sequence of
    # lowering platforms, all of which are included in the lowering platforms
    # of the nested exported function.
    exp2 = get_exported(jax.jit(exp.call),
                        platforms=("cpu", "cuda", "rocm"))(x)

    # Ensure that we do not have multiple lowerings of the exported function
    exp2_module_str = str(exp2.mlir_module())
    count_sine = len(re.findall("stablehlo.sine", exp2_module_str))
    self.assertEqual(1, count_sine)

    # Call with the argument placed on different platforms
    for platform in self.platforms:
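      # exp2 was not lowered for TPU, so skip that platform here.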
      if platform == "tpu": continue
      x_device = jax.device_put(x, jax.devices(platform)[0])
      res_exp = exp2.call(x_device)
      self.assertAllClose(
          res_exp,
          _testing_multi_platform_fun_expected(np.sin(x), platform=platform))

  def test_multi_platform_nested_inside_single_platform_export(self):
    x = np.arange(5, dtype=np.float32)
    exp = get_exported(jax.jit(_testing_multi_platform_func),
                       platforms=("cpu", "tpu", "cuda", "rocm"))(x)
    self.assertEqual(exp.platforms, ("cpu", "tpu", "cuda", "rocm"))

    # Now serialize the call for the current platform.
    exp2 = get_exported(jax.jit(exp.call))(x)
    module_str = str(exp2.mlir_module())
    self.assertIn("jax.uses_shape_polymorphism = true",
                  module_str)
    res2 = exp2.call(x)
    self.assertAllClose(res2, _testing_multi_platform_fun_expected(x))

  def test_multi_platform_mlir_lower_fun_with_platform_specific_primitives(self):
    # A primitive with multiple lowering rules, which themselves involve
    # tracing primitives with per-platform rules, using mlir.lower_fun.
    # This situation arises for Pallas lowering.
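    # The shared lowering rule below computes n * x as a chain of n - 1
    # stablehlo AddOps.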
    def times_n_lowering(n: int, ctx: mlir.LoweringRuleContext,
                         x: mlir.ir.Value) -> Sequence[mlir.ir.Value]:
      # Lowering n * x
      res = x
      for i in range(n - 1):
        res = mlir.hlo.AddOp(res, x)
      return res.results

    times_2 = core.Primitive("__testing_times_2")  # x2 for cpu
    times_2.def_abstract_eval(lambda x: x)
    # Define lowering rules only for the relevant platforms, ensure there
    # is no error about missing lowering rules
    mlir.register_lowering(times_2, functools.partial(times_n_lowering, 2),
                           "cpu")

    times_3 = core.Primitive("__testing_times_3")  # x3 for cuda and rocm
    times_3.def_abstract_eval(lambda x: x)

    mlir.register_lowering(times_3, functools.partial(times_n_lowering, 3),
                           "rocm")
    mlir.register_lowering(times_3, functools.partial(times_n_lowering, 3),
                           "cuda")

    times_4 = core.Primitive("__testing_times_4")  # x4 for tpu
    times_4.def_abstract_eval(lambda x: x)
    mlir.register_lowering(times_4, functools.partial(times_n_lowering, 4),
                           "tpu")

    times_2_or_3 = core.Primitive("__testing_times_2_or_3")  # x2 for cpu, x3 for cuda and rocm
    times_2_or_3.def_abstract_eval(lambda x: x)
    mlir.register_lowering(times_2_or_3,
                           mlir.lower_fun(times_2.bind,
                                          multiple_results=False), "cpu")

    mlir.register_lowering(times_2_or_3,
                           mlir.lower_fun(times_3.bind,
                                          multiple_results=False), "rocm")
    mlir.register_lowering(times_2_or_3,
                           mlir.lower_fun(times_3.bind,
                                          multiple_results=False), "cuda")

    times_2_or_3_or_4 = core.Primitive("__testing_times_2_or_3_or_4")  # x2 for cpu, x3 for cuda and rocm, x4 for tpu
    times_2_or_3_or_4.def_abstract_eval(lambda x: x)
    times_2_or_3_or_4_lowering_cpu_gpu = mlir.lower_fun(times_2_or_3.bind,
                                                        multiple_results=False)

    for platform in ["cpu", "cuda", "rocm"]:
      mlir.register_lowering(times_2_or_3_or_4,
                             times_2_or_3_or_4_lowering_cpu_gpu,
                             platform)
    mlir.register_lowering(times_2_or_3_or_4, mlir.lower_fun(times_4.bind,
                                                             multiple_results=False),
                           "tpu")
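
    # Export for all platforms; the expected result depends on the platform
    # under test.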
    @jax.jit
    def f(x):
      return times_2_or_3_or_4.bind(x)
    x = np.float32(42.)
    exp = export.export(f, platforms=["cpu", "cuda", "rocm", "tpu"])(x)
    expected = x * np.float32(dict(cpu=2, gpu=3, tpu=4)[jtu.device_under_test()])
    self.assertAllClose(exp.call(x), expected)

  def test_multi_platform_unknown_platform(self):
    x = np.arange(8, dtype=np.float32)
    exp = get_exported(jax.jit(jnp.sin),
                       platforms=("tpu", "cpu", "cuda", "other"))(x)
    self.assertEqual(exp.platforms, ("tpu", "cpu", "cuda", "other"))

  def test_multi_platform_with_donation(self):
    f = jax.jit(jnp.sin, donate_argnums=(0,))
    x = np.arange(3, dtype=np.float32)
    exp = export.export(f, platforms=["cpu", "tpu"])(x)
    if jtu.device_under_test() not in ["cpu", "tpu"]:
      self.skipTest("other platform")

    def caller(x):
      y = exp.call(x)
      return x + y
    res = jax.jit(caller)(x)
    self.assertAllClose(res, x + np.sin(x))

    with self.assertRaisesRegex(
        NotImplementedError,
        "In multi-platform lowering either all or no lowering platforms should support donation"):
      export.export(f, platforms=["cpu", "tpu", "other"])(x)

  def test_multi_platform_and_poly(self):
    if jtu.test_device_matches(["gpu"]):
      # The export is not applicable to GPU
      raise unittest.SkipTest("Not intended for running on GPU")
    exp = get_exported(jax.jit(lambda x: jnp.reshape(_testing_multi_platform_func(x), (-1,))),
                       platforms=("cpu", "tpu"))(
        jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), np.float32)
    )
    x = np.arange(12, dtype=np.float32).reshape((3, 4))
    res = exp.call(x)
    self.assertAllClose(res, _testing_multi_platform_fun_expected(x).reshape((-1,)))
    # Now serialize the call to the exported
    exp2 = get_exported(jax.jit(exp.call))(x)
    res2 = exp2.call(x)
    self.assertAllClose(res2, _testing_multi_platform_fun_expected(x).reshape((-1,)))

  def test_multi_platform_and_sharding(self):
    export_devices = jax.devices()[0:2]
    export_mesh = Mesh(export_devices, axis_names=("x",))
    a = np.arange(16 * 4, dtype=np.float32).reshape((16, 4))
    @functools.partial(
        jax.jit,
        in_shardings=(jax.sharding.NamedSharding(export_mesh, P("x", None),),),
        out_shardings=jax.sharding.NamedSharding(export_mesh, P(None, "x")))
    def f_jax(b):  # b: f32[16 // DEVICES, 4]
      return b * 2.

    res_native = f_jax(a)
    exp = get_exported(f_jax, platforms=("cpu", "tpu", "cuda", "rocm"))(a)

    # Call with the argument placed on different platforms
    for platform in self.platforms:
      run_devices = jax.devices(platform)[0:len(export_devices)]
      if len(run_devices) != len(export_devices):
        continue
      run_mesh = Mesh(run_devices, ("x",))
      a_device = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P()))
      res_exp = exp.call(a_device)
      self.assertArraysAllClose(res_native, res_exp)

  @jtu.parameterized_filterable(
      kwargs=[
          dict(v=v)
          for v in range(export.minimum_supported_calling_convention_version,
                         export.maximum_supported_calling_convention_version + 1)])
  def test_ordered_effects_basic(self, *, v: int):
    with config.jax_export_calling_convention_version(v):
      logging.info(
          "Using JAX serialization version %s",
          config.jax_export_calling_convention_version.value)
      x = np.arange(3, dtype=np.float32)
      def f_jax(x):  # x: f32[3]
        # Test also the calling convention for inner functions
        def f_jax_inner(x):
          return (
              testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") +
              testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingUnorderedEffect1"))
        return (
            10. +
            jax.jit(f_jax_inner)(x) +
            testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") +
            testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2")
        )

      exp = get_exported(jax.jit(f_jax))(x)
      self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
                       sorted(str(e) for e in exp.ordered_effects))
      self.assertEqual(["ForTestingUnorderedEffect1()"],
                       [str(e) for e in exp.unordered_effects])
      mlir_module_str = str(exp.mlir_module())

      # Inner functions use stablehlo.token for all versions
      inner_fun_expected_re = (
          r"func.func private @f_jax_inner\("
          r"%arg0: !stablehlo.token .*jax.token = true.*"
          r"%arg1: tensor<3xf32>.*->.*"
          # Results
          r"!stablehlo.token .*jax.token = true.*"
          r"tensor<3xf32>"
      )
      self.assertRegex(mlir_module_str, inner_fun_expected_re)

      # The wrapped_main function takes tokens after version 9, and takes
      # i1[0] before version 9.
      wrapped_main_expected_re = (
          r"@_wrapped_jax_export_main\("
          r"%arg0: !stablehlo.token .*jax.token = true.*"
          r"%arg1: !stablehlo.token .*jax.token = true.*->.*"
          # Results
          r"!stablehlo.token .*jax.token = true.*"
          r"!stablehlo.token .*jax.token = true.*")
      self.assertRegex(mlir_module_str, wrapped_main_expected_re)

      # The main function takes tokens and has the same type as the wrapped main
      main_expected_re = wrapped_main_expected_re.replace("@_wrapped_jax_export_main", "@main")
      self.assertRegex(mlir_module_str, main_expected_re)

      # Now call the exported from a function that uses its own effects
      def f_outer(x):
        return (
            testing_primitive_with_effect_p.bind(
                x, effect_class_name="ForTestingOrderedEffect2") +
            testing_primitive_with_effect_p.bind(
                x, effect_class_name="ForTestingUnorderedEffect1") +
            exp.call(x))
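
      # The outer lowering must account for both its own effects and the
      # effects of the exported callee.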
      lowered_outer = jax.jit(f_outer).lower(x)
      self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
                       sorted(str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]))
      self.assertEqual(["ForTestingUnorderedEffect1()"],
                       sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]]))

      mlir_outer_module_str = str(lowered_outer.compiler_ir())
      self.assertRegex(mlir_outer_module_str, main_expected_re)

      res = jax.jit(f_outer)(x)
      self.assertAllClose(2. * 2. * x + 10. + 4. * 2. * x, res)

  @jtu.parameterized_filterable(
      kwargs=[
          dict(v=v)
          for v in range(export.minimum_supported_calling_convention_version,
                         export.maximum_supported_calling_convention_version + 1)])
  def test_ordered_effects_poly(self, *, v: int):
    with config.jax_export_calling_convention_version(v):
      logging.info(
          "Using JAX serialization version %s",
          config.jax_export_calling_convention_version.value)
      x = np.arange(12, dtype=np.float32).reshape((3, 4))
      def f_jax(x):  # x: f32[b1, b2]
        return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1")
      exp = get_exported(jax.jit(f_jax))(jax.ShapeDtypeStruct(
          export.symbolic_shape("b2, b1"), x.dtype))
      mlir_module_str = str(exp.mlir_module())
      wrapped_main_expected_re = (
          r"@_wrapped_jax_export_main\("
          r"%arg0: tensor<i..> {jax.global_constant = \"b1\".* "
          r"%arg1: tensor<i..> {jax.global_constant = \"b2\".* "
          r"%arg2: !stablehlo.token {jax.token = true.* "
          r"%arg3: tensor<\?x\?xf32>.*\) -> \("
          # Results
          r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
      self.assertRegex(mlir_module_str, wrapped_main_expected_re)
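
      # Unlike the wrapped main, the public main does not take the dimension
      # variables b1 and b2 as arguments.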
      main_expected_re = (
          r"@main\("
          r"%arg0: !stablehlo.token {jax.token = true.*, "
          r"%arg1: tensor<\?x\?xf32>.*\) -> \("
          # Results
          r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
      self.assertRegex(mlir_module_str, main_expected_re)

      res = exp.call(x)
      self.assertAllClose(10. + 2. * x, res)

  @jtu.parameterized_filterable(
      kwargs=[
          dict(v=v)
          for v in range(export.minimum_supported_calling_convention_version,
                         export.maximum_supported_calling_convention_version + 1)])
  def test_ordered_effects_multi_platform_and_poly(self, *, v: int):
    with config.jax_export_calling_convention_version(v):
      logging.info(
          "Using JAX serialization version %s",
          config.jax_export_calling_convention_version.value)
      if jtu.device_under_test() == "gpu":
        # The export is not applicable to GPU
        raise unittest.SkipTest("Not intended for running on GPU")
      x = np.ones((3, 4), dtype=np.float32)
      def f_jax(x):  # x: f32[b1, b2]
        return 10. + _testing_multi_platform_func(x,
                                                  effect_class_name="ForTestingOrderedEffect1")
      exp = get_exported(
          jax.jit(f_jax),
          platforms=("cpu", "tpu")
      )(jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), x.dtype))
      mlir_module_str = str(exp.mlir_module())
      wrapped_main_expected_re = (
          r"@_wrapped_jax_export_main\("
          r"%arg0: tensor<i..> {jax.global_constant = \"_platform_index\".*, "
          r"%arg1: tensor<i..> {jax.global_constant = \"b1\".*, "
          r"%arg2: tensor<i..> {jax.global_constant = \"b2\".*, "
          r"%arg3: !stablehlo.token {jax.token = true.*, "
          r"%arg4: tensor<\?x\?xf32>.*\) -> \("
          # Results
          r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
      self.assertRegex(mlir_module_str, wrapped_main_expected_re)

      main_expected_re = (
          r"@main\("
          r"%arg0: tensor<i..> {jax.global_constant = \"_platform_index\".*, "
          r"%arg1: !stablehlo.token {jax.token = true.*, "
          r"%arg2: tensor<\?x\?xf32>.*\) -> \("
          # Results
          r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
      self.assertRegex(mlir_module_str, main_expected_re)
      res = exp.call(x)
      self.assertAllClose(10. + _testing_multi_platform_fun_expected(x),
                          res)

  @jtu.parameterized_filterable(
      kwargs=[
          dict(v=v)
          for v in range(export.minimum_supported_calling_convention_version,
                         export.maximum_supported_calling_convention_version + 1)])
  def test_ordered_effects_with_donation(self, *, v: int):
    with config.jax_export_calling_convention_version(v):
      logging.info(
          "Using JAX serialization version %s",
          config.jax_export_calling_convention_version.value)

      x = np.arange(3, dtype=np.float32)

      def f_jax(x):
        return testing_primitive_with_effect_p.bind(
            x, effect_class_name="ForTestingOrderedEffect1"
        )

      f_jax = jax.jit(f_jax, donate_argnums=(0,))
      exp = export.export(f_jax)(x)
      mlir_module_str = str(exp.mlir_module())
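      # With an ordered effect, result 0 is the token, so the donated input
      # aliases output 1.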
self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 1")
|
|
self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1")
|
|
|
|
  @jtu.parameterized_filterable(
      kwargs=[
          dict(name=name, expect_error=expect_error)
          # name is the suffix for the effect name: ForTestingOrderedEffectxxx
          for name, expect_error in (
              ("4NoNullary", "must have a nullary constructor"),
              ("5NoEq", "must have a nullary class constructor that produces an "
                        "equal effect object"),
          )
      ])
  def test_ordered_effects_error(self, *, name: str, expect_error: str):
    if not CAN_SERIALIZE:
      # These errors arise during serialization
      self.skipTest("serialization is disabled")
    x = np.ones((3, 4), dtype=np.float32)
    def f_jax(x):
      return 10. + _testing_multi_platform_func(
          x,
          effect_class_name="ForTestingOrderedEffect" + name)
    with self.assertRaisesRegex(Exception, expect_error):
      _ = get_exported(jax.jit(f_jax))(jax.ShapeDtypeStruct((3, 4), x.dtype))

  @jtu.parameterized_filterable(
      kwargs=[
          {"m": 5, "k": 4, "n": 3, "group_sizes": [5]},
          {"m": 10, "k": 9, "n": 8, "group_sizes": [3, 7]},
      ])
  def test_ragged_dot(self, m, k, n, group_sizes):
    def f_jax(x, y, gs):
      return jax.lax.ragged_dot(x, y, gs)
    dtype = np.float32
    group_sizes = np.array(group_sizes, dtype=np.int32)
    lhs = np.arange(m * k, dtype=dtype).reshape((m, k))
    num_groups = group_sizes.shape[0]
    rhs = np.arange(num_groups * k * n, dtype=dtype).reshape((num_groups, k, n))
    res_native = f_jax(lhs, rhs, group_sizes)

    exp_f = get_exported(jax.jit(f_jax))(
        jax.ShapeDtypeStruct(lhs.shape, dtype=lhs.dtype),
        jax.ShapeDtypeStruct(rhs.shape, dtype=rhs.dtype),
        jax.ShapeDtypeStruct(group_sizes.shape, dtype=group_sizes.dtype),
    )
    res_exported = exp_f.call(lhs, rhs, group_sizes)
    self.assertAllClose(res_native, res_exported)

  def test_different_meshes(self):
    if config.use_shardy_partitioner.value:
      self.skipTest("TODO(b/394265659) can't create NamedShardings for Shardy "
                    "since there are multiple different meshes")
    # Make sure we can save with various meshes, and load with various meshes,
    # all different from one another.
    if jax.local_device_count() < 8:
      self.skipTest("Need at least 8 devices")

    old_mesh_0 = jtu.create_mesh((4, 2), ("old_a", "old_b"))
    old_mesh_1 = jtu.create_mesh((2, 4), ("old_x", "old_y"))
    @jax.jit
    def f(x, y):
      z = x + y
      return jax.lax.with_sharding_constraint(
          z, NamedSharding(old_mesh_0, P("old_b")))

    exp = get_exported(f)(
        jax.ShapeDtypeStruct(
            (32, 32), dtype=np.float32,
            sharding=NamedSharding(old_mesh_0, P(None, "old_a"))),
        jax.ShapeDtypeStruct(
            (32, 32), dtype=np.float32,
            sharding=NamedSharding(old_mesh_1, P("old_y"))))

    # Call the Exported with a concrete Mesh
    new_mesh_0 = jtu.create_mesh((2, 2, 2), ("new_1", "new_2", "new_3"))
    new_mesh_1 = jtu.create_mesh((1, 8), ("newA", "newB"))

    a = jnp.arange(32 * 32, dtype=np.float32).reshape((32, 32))
    a = jax.device_put(a, NamedSharding(new_mesh_0, P(None, "new_1")))
    b = jnp.arange(32 * 32, dtype=np.float32).reshape((32, 32))
    b = jax.device_put(b, NamedSharding(new_mesh_1, P("newB")))
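
    # The Exported is called here with inputs laid out on meshes whose axis
    # names differ from the ones used at export time.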
    r = jax.jit(exp.call, out_shardings=NamedSharding(old_mesh_0, P("old_b")))(a, b)
    self.assertAllClose(a + b, r)

if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())