[export] Create the jax.export module APIs.

The functionality comes from the jax.experimental.export
module, which will be deprecated.

The following APIs are introduced:

```
  from jax import export
  def f(...): ...
  ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs)

  blob: bytearray = ex.serialize()
  rehydrated: export.Export = export.deserialize(blob)

  def caller(...):
     ... rehydrated.call(*args, **kwargs)
```

Module documentation will follow shortly.
There are no changes for now in the jax.experimental.export
APIs.

Most of the changes in this PR are in tests due to some differences
in the new jax.export APIs compared to jax.experimental.export:

  * Instead of `jax.experimental.export.call(exp)` we now write
    `exp.call`
  * The `jax.experimental.export.export` allowed the function
    argument to be any Python callable and it would wrap it with
    a `jax.jit`. This is not supported anymore by export, and instead
    the user must use `jax.jit`.
This commit is contained in:
George Necula 2024-06-10 09:45:09 +02:00
parent 14d87d3bf7
commit b33aca6b08
15 changed files with 342 additions and 159 deletions

View File

@ -18,7 +18,7 @@ import google_benchmark as benchmark
import jax
from jax import core
from jax._src.numpy import lax_numpy
from jax.experimental import export
from jax import export
jax.config.parse_flags_with_absl()

View File

@ -333,7 +333,7 @@ class Exported:
`jax.device_put`.
Example usage:
>>> from jax.experimental import export
>>> from jax import export
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
@ -347,7 +347,7 @@ class Exported:
# Put the args and kwargs on the appropriate devices
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
... exp.in_shardings_jax(run_mesh)[0])
>>> res = export.call(exp)(run_arg)
>>> res = exp.call(run_arg)
>>> res.addressable_shards
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
@ -372,19 +372,53 @@ class Exported:
for s in self.out_shardings_hlo)
def has_vjp(self) -> bool:
"""Returns if this Exported supports VJP."""
return self._get_vjp is not None
def vjp(self) -> Exported:
"""Gets the exported VJP.
Returns None if not available, which can happen if the Exported has been
loaded from an external format, without a VJP."""
loaded from an external format without a VJP.
"""
if self._get_vjp is None:
raise ValueError("No VJP is available")
return self._get_vjp(self)
def serialize(self,
vjp_order: int = 0) -> bytearray:
"""Serializes an Exported.
Args:
vjp_order: The maximum vjp order to include. E.g., the value 2 means that we
serialize the primal functions and two orders of the `vjp` function. This
should allow 2nd order reverse mode differentiation of the deserialized
function. i.e., `jax.grad(jax.grad(f)).`
"""
# Lazy load the serialization module, since flatbuffers is an optional
# dependency.
from jax._src.export.serialization import serialize
return serialize(self, vjp_order=vjp_order)
def call(self, *args, **kwargs):
return call_exported(self)(*args, **kwargs)
def deserialize(blob: bytearray) -> Exported:
"""Deserializes an Exported.
Args:
blob: a bytearray obtained from `Exported.serialize`.
"""
# Lazy load the serialization module, since flatbuffers is an optional
# dependency.
from jax._src.export.serialization import deserialize
return deserialize(blob)
def default_lowering_platform() -> str:
"""Retrieves the default lowering platform for the exporting machine.
"""
# Canonicalize to turn 'gpu' into 'cuda' or 'rocm'
return xb.canonicalize_platform(jax.default_backend())
@ -411,7 +445,9 @@ def args_specs(
return shape_poly.symbolic_args_specs(args, polymorphic_shapes)
def export(fun_jax: Callable,
# TODO(necula): remove this once we remove jax.experimental.export.
def export_back_compat(
fun_jax: Callable,
*,
lowering_platforms: Sequence[str] | None = None,
disabled_checks: Sequence[DisabledSafetyCheck] = (),
@ -419,6 +455,10 @@ def export(fun_jax: Callable,
) -> Callable[..., Exported]:
"""Exports native serialization for a JAX function.
Note: this function exists only for internal usage by jax2tf and for
backwards compatibility with jax.experimental.export. Use
`jax.export` instead.
Args:
fun_jax: the function to lower and serialize.
lowering_platforms:
@ -498,6 +538,85 @@ def export(fun_jax: Callable,
_device_assignment_for_internal_jax2tf_use_only=_device_assignment_for_internal_jax2tf_use_only)
return do_export
def export(
fun_jit: stages.Wrapped,
*,
lowering_platforms: Sequence[str] | None = None,
disabled_checks: Sequence[DisabledSafetyCheck] = (),
) -> Callable[..., Exported]:
"""Exports a JAX function for persistent serialization.
Args:
fun_jit: the function to export. Should be the result of `jit`.
lowering_platforms:
Optional sequence containing a subset of 'tpu', 'cpu',
'cuda', 'rocm'. If more than one platform is specified, then
the lowered code takes an argument specifying the platform.
If None, then use the default JAX backend.
The calling convention for multiple platforms is explained in the
`jax_export.Exported` docstring.
disabled_checks: the safety checks to disable. See docstring
of `DisabledSafetyCheck`.
Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
or values with `.shape` and `.dtype` attributes, and returns an
`Exported`.
Usage:
>>> from jax import export
>>> exported: export.Exported = export.export(jnp.sin)(
... np.arange(4, dtype=np.float32))
# You can inspect the Exported object
>>> exported.in_avals
(ShapedArray(float32[4]),)
>>> blob: bytearray = exported.serialize()
# The serialized bytes are safe to use in a separate process
>>> rehydrated: export.Exported = export.deserialize(blob)
>>> rehydrated.fun_name
'sin'
>>> rehydrated.call(np.array([.1, .2, .3, .4], dtype=np.float32))
Array([0.09983342, 0.19866933, 0.29552022, 0.38941833], dtype=float32)
"""
if not isinstance(fun_jit, stages.Wrapped):
raise ValueError(
f"Function to be exported must be the result of `jit` but is: {fun_jit}")
if lowering_platforms is not None:
actual_lowering_platforms = tuple(lowering_platforms)
else:
actual_lowering_platforms = (default_lowering_platform(),)
def do_export(*args_specs, **kwargs_specs) -> Exported:
# TODO: move to `lower`
symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None # type: ignore[invalid-annotation,unused-ignore]
for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]:
# Static args may have no `shape` attribute.
if not hasattr(aval, "shape"):
continue
for d in aval.shape:
if shape_poly.is_symbolic_dim(d):
if symbolic_scope is None:
symbolic_scope = (d.scope, k_path)
continue
symbolic_scope[0]._check_same_scope(
d, when=f"when exporting {util.fun_name(fun_jit)}",
self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
other_descr=shape_poly.args_kwargs_path_to_str(k_path))
traced = fun_jit.trace( # type: ignore
*args_specs, **kwargs_specs,
_experimental_lowering_parameters=mlir.LoweringParameters(
platforms=actual_lowering_platforms,
for_export=True,
))
jaxpr, fun_name = traced.jaxpr, traced.fun_name
lowered = traced.lower()
return _export_lowered(
lowered, jaxpr, fun_name,
disabled_checks=disabled_checks)
return do_export
def _export_lowered(
lowered: stages.Lowered,
jaxpr: core.ClosedJaxpr, fun_name: str,
@ -599,7 +718,7 @@ def _export_lowered(
device_assignment=device_assignment,
apply_jit=True,
flat_primal_fun=True)
return export(fun_vjp_jax,
return export(fun_vjp_jax, # type: ignore[arg-type]
lowering_platforms=exp_primal.lowering_platforms,
disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals)
@ -816,7 +935,7 @@ def _wrap_main_func(
def _check_lowering(lowering) -> None:
if not isinstance(lowering, pxla.MeshComputation):
raise NotImplementedError(f"serialization is supported only for pjit. {lowering}")
raise NotImplementedError(f"serialization is supported only for jit. {lowering}")
if lowering.compile_args["host_callbacks"] or lowering.compile_args["keepalive"]:
raise NotImplementedError("serialization of host_callbacks is not yet implemented")

View File

@ -48,7 +48,7 @@ SerT = TypeVar("SerT")
_SERIALIZATION_VERSION = 2
def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray:
"""Serialize an Exported.
"""Serializes an Exported.
Args:
exp: the Exported to serialize.
@ -64,7 +64,7 @@ def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray:
def deserialize(ser: bytearray) -> _export.Exported:
"""Deserialize an Exported."""
"""Deserializes an Exported."""
exp = ser_flatbuf.Exported.GetRootAsExported(ser)
return _deserialize_exported(exp)

View File

@ -86,7 +86,7 @@ from numpy import array, float32
import jax
from jax import tree_util
from jax.experimental import export
from jax import export
from jax.experimental import pjit
@ -345,4 +345,4 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
_get_vjp=_get_vjp)
# We use pjit in case there are shardings in the exported module.
return pjit.pjit(export.call(exported))(*data.inputs)
return pjit.pjit(exported.call)(*data.inputs)

View File

@ -32,7 +32,7 @@ from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, NamedTuple, Protocol, Union
from typing import Any, NamedTuple, Protocol, Union, runtime_checkable
import warnings
import jax
@ -756,8 +756,9 @@ class Lowered(Stage):
return None
@runtime_checkable
class Wrapped(Protocol):
"""A function ready to be specialized, lowered, and compiled.
"""A function ready to be traced, lowered, and compiled.
This protocol reflects the output of functions such as
``jax.jit``. Calling it results in JIT (just-in-time) lowering,

View File

@ -17,12 +17,13 @@ from jax._src.export._export import (
minimum_supported_serialization_version,
maximum_supported_serialization_version,
Exported,
export,
call_exported, # TODO: deprecate
call,
DisabledSafetyCheck,
default_lowering_platform,
default_lowering_platform, # TODO: deprecate
)
from jax._src.export._export import export_back_compat as export
from jax._src.export.shape_poly import (
is_symbolic_dim,
symbolic_shape,
@ -33,4 +34,6 @@ from jax._src.export.serialization import (
serialize,
deserialize,
)
# Import only to set the shape poly decision procedure
from jax._src.export import shape_poly_decision
del shape_poly_decision

View File

@ -36,7 +36,7 @@ from jax import random
from jax import numpy as jnp
from jax import tree_util
from jax import sharding
from jax.experimental import export
from jax import export
from jax.experimental.jax2tf import impl_no_xla
from jax.interpreters import xla
@ -515,7 +515,7 @@ class NativeSerializationImpl(SerializationImpl):
self._restore_context = _restore_context
_exported_device_assignment = [None]
self.exported = export.export(
self.exported = _export.export_back_compat(
self.fun_jax,
lowering_platforms=self.native_serialization_platforms,
disabled_checks=self.native_serialization_disabled_checks,

View File

@ -25,13 +25,13 @@ from absl.testing import parameterized
import jax
from jax import dlpack
from jax import dtypes
from jax import export
from jax import lax
from jax import numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax.experimental import export
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util
import numpy as np
@ -778,7 +778,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
lowering_platforms = ("tpu", "cpu", "cuda")
exp = export.export(f_jax,
exp = export.export(jax.jit(f_jax),
lowering_platforms=lowering_platforms)(x)
for jax_platform in jax_and_tf_platforms:
with self.subTest(jax_platform):
@ -787,7 +787,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
logging.info("Running harness natively on %s", jax_device)
native_res = f_jax(x_device)
logging.info("Running exported harness on %s", jax_device)
exported_res = export.call(exp)(x_device)
exported_res = exp.call(x_device)
self.assertAllClose(native_res, exported_res)
def test_multi_platform_call_tf_graph(self):

View File

@ -27,6 +27,7 @@ from absl.testing import absltest, parameterized
import jax
from jax import ad_checkpoint
from jax import dtypes
from jax import export
from jax import lax
from jax import numpy as jnp
from jax import sharding
@ -37,7 +38,6 @@ from jax._src import source_info_util
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax.experimental import jax2tf
from jax.experimental import export
from jax.experimental.jax2tf.tests import tf_test_util
from jax.experimental.shard_map import shard_map
from jax.experimental import pjit
@ -1559,7 +1559,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
# Run the JAX native version, to check it works, and to fill caches.
_ = func_to_convert(*args)
exported = export.export(
func_to_convert,
(jax.jit(func_to_convert) if not hasattr(func_to_convert, "trace") else func_to_convert),
lowering_platforms=("tpu",)
)(*(core.ShapedArray(a.shape, a.dtype) for a in args))

View File

@ -32,8 +32,8 @@ import re
import jax
from jax.experimental import jax2tf
from jax.experimental import export
from jax.experimental import pjit
from jax import export
from jax import lax
import jax.numpy as jnp
from jax import random

View File

@ -31,7 +31,7 @@ from jax._src import test_util as jtu
from jax import tree_util
from jax.experimental import jax2tf
from jax.experimental import export
from jax import export
from jax._src import config
from jax._src import xla_bridge
import numpy as np

34
jax/export.py Normal file
View File

@ -0,0 +1,34 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["DisabledSafetyCheck", "Exported", "export", "deserialize",
"maximum_supported_serialization_version",
"minimum_supported_serialization_version",
"default_lowering_platform",
"SymbolicScope", "is_symbolic_dim",
"symbolic_shape", "symbolic_args_specs"]
from jax._src.export._export import DisabledSafetyCheck as DisabledSafetyCheck
from jax._src.export._export import Exported as Exported
from jax._src.export._export import export as export
from jax._src.export._export import deserialize as deserialize
from jax._src.export._export import maximum_supported_serialization_version as maximum_supported_serialization_version
from jax._src.export._export import minimum_supported_serialization_version as minimum_supported_serialization_version
from jax._src.export._export import default_lowering_platform as default_lowering_platform
from jax._src.export import shape_poly_decision # Import only to set the decision procedure
del shape_poly_decision
from jax._src.export.shape_poly import SymbolicScope as SymbolicScope
from jax._src.export.shape_poly import is_symbolic_dim as is_symbolic_dim
from jax._src.export.shape_poly import symbolic_shape as symbolic_shape
from jax._src.export.shape_poly import symbolic_args_specs as symbolic_args_specs

View File

@ -31,9 +31,9 @@ from absl.testing import absltest
import numpy as np
import jax
from jax import export
from jax import lax
from jax._src import test_util as jtu
from jax.experimental import export
from jax._src.internal_test_util import test_harnesses
@ -152,7 +152,8 @@ class PrimitiveTest(jtu.JaxTestCase):
)
logging.info("Exporting harness for %s", lowering_platforms)
exp = export.export(func_jax, lowering_platforms=lowering_platforms)(*args)
exp = export.export(jax.jit(func_jax),
lowering_platforms=lowering_platforms)(*args)
for device in devices:
if device.platform in skip_run_on_platforms:
@ -164,7 +165,7 @@ class PrimitiveTest(jtu.JaxTestCase):
logging.info("Running harness natively on %s", device)
native_res = func_jax(*device_args)
logging.info("Running exported harness on %s", device)
exported_res = export.call(exp)(*device_args)
exported_res = exp.call(*device_args)
if tol is not None:
logging.info(f"Using non-standard tolerance {tol}")
self.assertAllClose(native_res, exported_res, atol=tol, rtol=tol)

View File

@ -20,12 +20,13 @@ import logging
import math
import re
import unittest
from typing import Callable
from absl.testing import absltest
import jax
from jax import lax
from jax import numpy as jnp
from jax.experimental import export
from jax import export
from jax.experimental import pjit
from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding
@ -137,12 +138,12 @@ def _testing_multi_platform_fun_expected(x,
]
def get_exported(fun, vjp_order=0,
def get_exported(fun: Callable, vjp_order=0,
**export_kwargs):
"""Like export.export but with serialization + deserialization."""
def serde_exported(*fun_args, **fun_kwargs):
exp = export.export(fun, **export_kwargs)(*fun_args, **fun_kwargs)
serialized = export.serialize(exp, vjp_order=vjp_order)
serialized = exp.serialize(vjp_order=vjp_order)
return export.deserialize(serialized)
return serde_exported
@ -164,11 +165,13 @@ class JaxExportTest(jtu.JaxTestCase):
super().setUpClass()
def test_basic_export_only(self):
@jax.jit
def my_fun(x):
return jnp.sin(x)
exp = get_exported(my_fun)(jax.ShapeDtypeStruct((4,), dtype=np.float32))
self.assertEqual("my_fun", exp.fun_name)
self.assertEqual((export.default_lowering_platform(),),
expected_lowering_platform = xb.canonicalize_platform(jax.default_backend())
self.assertEqual((expected_lowering_platform,),
exp.lowering_platforms)
self.assertEqual(jax.tree.flatten(((1,), {}))[1], exp.in_tree)
self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.in_avals)
@ -180,7 +183,7 @@ class JaxExportTest(jtu.JaxTestCase):
def f(a_b_pair, *, a, b):
return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))
exp = get_exported(f, lowering_platforms=("cpu",))((a, b), a=a, b=b)
exp = get_exported(jax.jit(f), lowering_platforms=("cpu",))((a, b), a=a, b=b)
a_aval = core.ShapedArray(a.shape, a.dtype)
b_aval = core.ShapedArray(b.shape, b.dtype)
self.assertEqual(exp.lowering_platforms, ("cpu",))
@ -196,8 +199,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.arange(4, dtype=np.float32)
exp_f = get_exported(f)(x)
f1 = export.call(exp_f)
self.assertAllClose(f(x), f1(x))
self.assertAllClose(f(x), exp_f.call(x))
def test_jit_static_arg(self):
@ -210,8 +212,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.arange(4, dtype=np.float32)
exp_f = get_exported(f)(x, c=0.1)
f1 = export.call(exp_f)
self.assertAllClose(f(x, c=0.1), f1(x))
self.assertAllClose(f(x, c=0.1), exp_f.call(x))
with self.subTest("static_argnums"):
@ -222,16 +223,32 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.arange(4, dtype=np.float32)
exp_g = get_exported(g)(x, 0.1)
g1 = export.call(exp_g)
self.assertAllClose(g(x, 0.1), g1(x))
self.assertAllClose(g(x, 0.1), exp_g.call(x))
def test_export_error_no_jit(self):
# Can export a lambda, without jit
with self.assertRaisesRegex(ValueError,
"Function to be exported must be the result of `jit`"):
_ = export.export(lambda x: jnp.sin(x))
def test_export_experimental_back_compat(self):
from jax.experimental import export
# Can export a lambda, without jit
exp = export.export(lambda x: jnp.sin(x))(.1)
self.assertAllClose(exp.call(1.), np.sin(1.))
blob = export.serialize(exp, vjp_order=1)
rehydrated = export.deserialize(blob)
self.assertAllClose(export.call(exp)(1.), np.sin(1.))
self.assertAllClose(export.call_exported(exp)(1.), np.sin(1.))
def test_call_exported_lambda(self):
# When we export a lambda, the exported.fun_name is not a valid MLIR function name
f = lambda x: jnp.sin(x)
f = jax.jit(lambda x: jnp.sin(x))
x = np.arange(4, dtype=np.float32)
exp_f = get_exported(f)(x)
f1 = export.call(exp_f)
self.assertAllClose(f(x), f1(x))
self.assertAllClose(f(x), exp_f.call(x))
def test_call_name_conflict(self):
@jax.jit
@ -246,7 +263,7 @@ class JaxExportTest(jtu.JaxTestCase):
@jax.jit
def outer(x):
# There should be no conflict on _where
x = export.call(exp_inner)(x)
x = exp_inner.call(x)
return inner(x)
export.export(outer)(x)
@ -257,19 +274,18 @@ class JaxExportTest(jtu.JaxTestCase):
@jax.jit
def f1(x):
exp_f = get_exported(f)(x)
return export.call(exp_f)(x) + export.call(exp_f)(x)
exp_f = get_exported(jax.jit(f))(x)
return exp_f.call(x) + exp_f.call(x)
self.assertAllClose(2. * f(x), f1(x))
def test_unused_args(self):
f = lambda x, y: jnp.sin(x)
f = jax.jit(lambda x, y: jnp.sin(x))
x = np.arange(4, dtype=np.float32)
y = np.arange(6, dtype=np.float32)
exp_f = get_exported(f)(x, y)
f1 = export.call(exp_f)
self.assertAllClose(f(x, y), f1(x, y))
self.assertAllClose(f(x, y), exp_f.call(x, y))
def test_pytree(self):
a = np.arange(4, dtype=np.float32)
@ -277,43 +293,49 @@ class JaxExportTest(jtu.JaxTestCase):
def f(a_b_pair, a, b):
return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))
exp_f = get_exported(f)((a, b), a=a, b=b)
f1 = export.call(exp_f)
exp_f = get_exported(jax.jit(f))((a, b), a=a, b=b)
self.assertAllClose(f((a, b), a=a, b=b),
f1((a, b), a=a, b=b))
exp_f.call((a, b), a=a, b=b))
def test_error_wrong_intree(self):
def f(a_b_pair, *, c):
return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c
a = b = c = np.arange(4, dtype=np.float32)
exp_f = get_exported(f)((a, b), c=c)
exp_f = get_exported(jax.jit(f))((a, b), c=c)
with self.assertRaisesRegex(
ValueError,
"The invocation args and kwargs must have the same pytree structure"):
export.call(exp_f)(a, b, c=(a, b))
exp_f.call(a, b, c=(a, b))
def test_error_wrong_avals(self):
def f(a, *, b): # a: f32[4] and b: f32[4]
return jnp.sin(a) + jnp.cos(b)
f32_4 = np.arange(4, dtype=np.float32)
exp_f = get_exported(f)(f32_4, b=f32_4)
exp_f = get_exported(jax.jit(f))(f32_4, b=f32_4)
with self.assertRaisesRegex(ValueError,
r"Shape mismatch for args\[0\].shape\[0\]"):
export.call(exp_f)(np.arange(6, dtype=np.float32), b=f32_4)
exp_f.call(np.arange(6, dtype=np.float32), b=f32_4)
with self.assertRaisesRegex(ValueError,
r"Shape mismatch for kwargs\['b'\].shape\[0\]"):
export.call(exp_f)(f32_4, b=np.arange(6, dtype=np.float32))
exp_f.call(f32_4, b=np.arange(6, dtype=np.float32))
with self.assertRaisesRegex(ValueError,
r"Rank mismatch for args\[0\]"):
export.call(exp_f)(f32_4.reshape((1, 4)), b=f32_4)
exp_f.call(f32_4.reshape((1, 4)), b=f32_4)
with self.assertRaisesRegex(ValueError,
r"Dtype mismatch for args\[0\]"):
export.call(exp_f)(f32_4.astype(np.float16), b=f32_4)
exp_f.call(f32_4.astype(np.float16), b=f32_4)
def test_default_lowering_platform(self):
test_platform = jtu.device_under_test()
if test_platform == "gpu": test_platform = "cuda"
self.assertEqual(export.default_lowering_platform(), test_platform)
exp = export.export(jnp.sin)(1.)
self.assertEqual(exp.lowering_platforms, (export.default_lowering_platform(),))
@jtu.parameterized_filterable(
testcase_name=lambda kw: kw["platform"],
@ -328,13 +350,13 @@ class JaxExportTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError, "The exported function .* was lowered for platform"):
export.call(exp_f)(a)
exp_f.call(a)
# Now try with the platform check disabled
exp_f_no_platform_check = get_exported(
jnp.sin, lowering_platforms=(platform,),
disabled_checks=[export.DisabledSafetyCheck.platform()])(a)
res = export.call(exp_f_no_platform_check)(a)
res = exp_f_no_platform_check.call(a)
self.assertAllClose(res, jnp.sin(a))
@jtu.parameterized_filterable(
@ -357,12 +379,12 @@ class JaxExportTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError,
"Cannot serialize code with custom calls whose targets .*"):
get_exported(
lambda a: a + test_primitive.bind(a)
jax.jit(lambda a: a + test_primitive.bind(a))
)(a)
# Now try again with the safety check disabled
exp = get_exported(
lambda a: a + test_primitive.bind(a),
jax.jit(lambda a: a + test_primitive.bind(a)),
disabled_checks=[export.DisabledSafetyCheck.custom_call("disallowed_call_target")]
)(a)
self.assertIn("disallowed_call_target", exp.mlir_module())
@ -387,22 +409,22 @@ class JaxExportTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError,
"Lowering for export not supported"):
export.export(f)(a)
export.export(jax.jit(f))(a)
def test_grad(self):
f = lambda x: jnp.sum(jnp.sin(x))
x = np.arange(4, dtype=np.float32)
exp_f = get_exported(f, vjp_order=1)(x)
exp_f = get_exported(jax.jit(f), vjp_order=1)(x)
f1 = export.call(exp_f)
f1 = exp_f.call
self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x))
def test_higher_order_grad(self):
f = lambda x: x ** 3
x = np.float32(4.)
exp_f = get_exported(f, vjp_order=3)(x)
exp_f = get_exported(jax.jit(f), vjp_order=3)(x)
f1 = export.call(exp_f)
f1 = exp_f.call
self.assertAllClose(jax.grad(f)(x),
jax.grad(f1)(x))
self.assertAllClose(jax.grad(jax.grad(f))(x),
@ -428,8 +450,8 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertAllClose(res, (xi_ct, xf_ct))
(f_outi_ct2, f_outf_ct2), = f_vjp2((xi_ct, xf_ct))
exp = get_exported(f, vjp_order=2)(xi, xf)
fr = export.call(exp)
exp = get_exported(jax.jit(f), vjp_order=2)(xi, xf)
fr = exp.call
res = fr(xi, xf)
self.assertAllClose(res, (f_outi, f_outf))
@ -456,14 +478,14 @@ class JaxExportTest(jtu.JaxTestCase):
a = np.arange(4, dtype=np.float32)
b = np.arange(6, dtype=np.float32)
exp_f = get_exported(f, vjp_order=1)((a, b), a=a, b=b)
exp_f = get_exported(jax.jit(f), vjp_order=1)((a, b), a=a, b=b)
out_ct = f((a, b), a=a, b=b) # The output has the right structure as the cotangent
def f1_jax(a, b): # For VJP, make a function without kwargs
res = f((a, b), a=a, b=b)
return res
def f1_exp(a, b): # For VJP, make a function without kwargs
res = export.call(exp_f)((a, b), a=a, b=b)
res = exp_f.call((a, b), a=a, b=b)
return res
jax_vjp = jax.vjp(f1_jax, a, b)[1](out_ct)
exp_vjp = jax.vjp(f1_exp, a, b)[1](out_ct)
@ -473,15 +495,15 @@ class JaxExportTest(jtu.JaxTestCase):
def f1(x):
return jnp.sin(x)
a = np.arange(4, dtype=np.float32)
exp_f1 = get_exported(f1)(a)
exp_f1 = get_exported(jax.jit(f1))(a)
def f2(x):
res1 = export.call(exp_f1)(x)
res2 = export.call(exp_f1)(res1)
res1 = exp_f1.call(x)
res2 = exp_f1.call(res1)
return jnp.cos(res2)
exp_f2 = get_exported(f2)(a)
exp_f2 = get_exported(jax.jit(f2))(a)
self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))),
export.call(exp_f2)(a))
exp_f2.call(a))
def test_poly_export_only(self):
a = np.arange(12, dtype=np.float32).reshape((3, 4))
@ -489,7 +511,7 @@ class JaxExportTest(jtu.JaxTestCase):
return jnp.concatenate([a, b], axis=0)
scope = export.SymbolicScope()
exp = get_exported(f)(
exp = get_exported(jax.jit(f))(
jax.ShapeDtypeStruct(export.symbolic_shape("(2*w, h)", scope=scope), a.dtype),
jax.ShapeDtypeStruct(export.symbolic_shape("(w, h)", scope=scope), a.dtype))
self.assertEqual("(2*w, h)", str(exp.in_avals[0].shape))
@ -528,7 +550,7 @@ class JaxExportTest(jtu.JaxTestCase):
return jnp.concatenate([a0, a1, ak], axis=0)
a_poly_spec = jax.ShapeDtypeStruct(export.symbolic_shape("(w, h)"), a.dtype)
exp = get_exported(f)(a_poly_spec, a_poly_spec, ak=a_poly_spec)
exp = get_exported(jax.jit(f))(a_poly_spec, a_poly_spec, ak=a_poly_spec)
self.assertEqual("(w, h)", str(exp.in_avals[0].shape))
self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape))
@ -545,7 +567,7 @@ class JaxExportTest(jtu.JaxTestCase):
"Invalid mixing of symbolic scopes when exporting f.*"
r"Expected current \(from args\[0\]\) scope .*"
r"and found for 'w' \(args\[1\]\) scope .*", re.DOTALL)):
get_exported(f)(x_poly_spec, y_poly_spec)
get_exported(jax.jit(f))(x_poly_spec, y_poly_spec)
def test_poly_export_callable_with_no_name(self):
# This was reported by a user
@ -566,7 +588,7 @@ class JaxExportTest(jtu.JaxTestCase):
a, = export.symbolic_shape("a,")
# No error
_ = get_exported(MyCallable())(
_ = get_exported(jax.jit(MyCallable()))(
jax.ShapeDtypeStruct((a, a), dtype=np.float32)
)
@ -590,7 +612,7 @@ class JaxExportTest(jtu.JaxTestCase):
exp = get_exported(jnp.sin)(
jax.ShapeDtypeStruct(export.symbolic_shape("w, h"), np.float32))
x = np.arange(30, dtype=np.float32).reshape((5, 6))
res = export.call(exp)(x)
res = exp.call(x)
self.assertAllClose(res, np.sin(x))
# A function is exported with f32[poly_spec] and is called with different arg
@ -631,7 +653,7 @@ class JaxExportTest(jtu.JaxTestCase):
return jnp.reshape(x, (-1, x.shape[1]))
disabled_checks = ()
exp_f = get_exported(f, disabled_checks=disabled_checks)(
exp_f = get_exported(jax.jit(f), disabled_checks=disabled_checks)(
jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), np.float32))
self.assertEqual(exp_f.uses_shape_polymorphism, poly_spec != "3,4,12")
arg = np.arange(np.prod(arg_shape),
@ -642,7 +664,7 @@ class JaxExportTest(jtu.JaxTestCase):
stack.push(self.assertRaisesRegex(Exception, expect_error))
assert core.is_constant_shape(arg.shape)
res = export.call(exp_f)(arg)
res = exp_f.call(arg)
if not expect_error:
self.assertAllClose(res, f(arg))
@ -733,7 +755,7 @@ class JaxExportTest(jtu.JaxTestCase):
arg = np.arange(np.prod(arg_shape),
dtype=arg_dtype).reshape(arg_shape) # x : f32[3,4,12]
inner_exp = get_exported(inner)(
inner_exp = get_exported(jax.jit(inner))(
jax.ShapeDtypeStruct(export.symbolic_shape(inner_poly_spec), np.float32))
self.assertEqual(inner_exp.uses_shape_polymorphism,
@ -741,14 +763,14 @@ class JaxExportTest(jtu.JaxTestCase):
def outer(x): # x: outer_poly_spec
# Use an addition to test that the shapes are refined properly for the
# result of the call_exported.
return export.call(inner_exp)(x) + inner(x)
return inner_exp.call(x) + inner(x)
with contextlib.ExitStack() as stack:
if expect_error_outer_exp is not None:
stack.push(self.assertRaisesRegex(ValueError, expect_error_outer_exp))
# Call it after exporting again, with polymorphic shapes
outer_exp = get_exported(outer)(
outer_exp = get_exported(jax.jit(outer))(
jax.ShapeDtypeStruct(export.symbolic_shape(outer_poly_spec), arg.dtype))
if expect_error_outer_exp is not None:
@ -761,7 +783,7 @@ class JaxExportTest(jtu.JaxTestCase):
if expect_error_run is not None:
stack.push(self.assertRaisesRegex(Exception, expect_error_run))
res = export.call(outer_exp)(arg)
res = outer_exp.call(arg)
if expect_error_run is not None:
return
@ -823,20 +845,21 @@ class JaxExportTest(jtu.JaxTestCase):
with contextlib.ExitStack() as stack:
if expect_error is not None:
stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error)))
exp = get_exported(f_jax)(
exp = get_exported(jax.jit(f_jax))(
jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), x.dtype))
export.call(exp)(x)
exp.call(x)
def test_poly_booleans(self):
# For booleans we use a special case ConvertOp to cast to and from
# dynamic shapes arguments.
@jax.jit
def f_jax(x): # x: bool[b]
return jnp.logical_not(x)
x = np.array([True, False, True, False], dtype=np.bool_)
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
x.dtype))
res = export.call(exp)(x)
res = exp.call(x)
self.assertAllClose(f_jax(x), res)
@jtu.parameterized_filterable(
@ -851,13 +874,14 @@ class JaxExportTest(jtu.JaxTestCase):
"int4",
"uint4"}:
self.skipTest(f"TODO: serialization not supported for {str(dtype)}")
@jax.jit
def f_jax(x):
return x + x
x = np.arange(6, dtype=dtype)
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
x.dtype))
res = export.call(exp)(x)
res = exp.call(x)
self.assertAllClose(f_jax(x), res)
def test_poly_expressions(self):
@ -867,6 +891,7 @@ class JaxExportTest(jtu.JaxTestCase):
return (b + b, b - b, b * b,
(b + 13) // b, (b + 13) % b,
core.max_dim(b - 5, 0))
@jax.jit
def f(x): # x: f32[b]
b = x.shape[0]
return jnp.ones(output_shape(b), dtype=x.dtype)
@ -874,12 +899,12 @@ class JaxExportTest(jtu.JaxTestCase):
exp = get_exported(f)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
x.dtype))
# Call with static shapes
res = export.call(exp)(x)
res = exp.call(x)
self.assertAllClose(res, f(x))
# Now re-export with shape polymorphism
x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("a"), x.dtype)
exp2 = get_exported(export.call(exp))(x_spec)
exp2 = get_exported(jax.jit(exp.call))(x_spec)
a = exp2.in_avals[0].shape[0]
self.assertEqual(exp2.out_avals[0].shape, output_shape(a))
@ -890,9 +915,9 @@ class JaxExportTest(jtu.JaxTestCase):
return x + jnp.arange(x.shape[0], dtype=x.dtype).reshape((x.shape[0], 1))
a, = export.symbolic_shape("a")
exp = export.export(f)(
exp = export.export(jax.jit(f))(
jax.ShapeDtypeStruct((a, 4), np.float32))
f_exp = export.call(exp)
f_exp = exp.call
x_jit = np.arange(12, dtype=np.float32).reshape((3, 4))
res_jit = jax.jit(f_exp)(x_jit)
self.assertAllClose(res_jit, f(x_jit))
@ -931,24 +956,24 @@ class JaxExportTest(jtu.JaxTestCase):
# We apply the out_shardings for f_jax
r".*custom_call @Sharding\(%1\).*mhlo.sharding = \"{devices=\[1,2\]<=\[2\]}\"}.*",
re.DOTALL)
hlo = jax.jit(export.call(exp)).lower(a_device).as_text()
hlo = jax.jit(exp.call).lower(a_device).as_text()
self.assertRegex(hlo, expected_re)
res_exported = export.call(exp)(a_device)
res_exported = exp.call(a_device)
self.assertAllClose(res_native, res_exported)
# Test error reporting
with self.assertRaisesRegex(
NotImplementedError,
"Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
_ = export.call(exp)(a)
_ = exp.call(a)
with self.assertRaisesRegex(
NotImplementedError,
"Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
mesh1 = Mesh(jax.devices()[0:1], axis_names=("x",))
_ = jax.jit(
export.call(exp),
exp.call,
in_shardings=(jax.sharding.NamedSharding(mesh1, P("x", None)),)
)(a)
@ -973,7 +998,7 @@ class JaxExportTest(jtu.JaxTestCase):
run_input_shardings = exp.in_shardings_jax(run_mesh)
a_run = jax.device_put(a, run_input_shardings[0])
b_run = jax.device_put(a, run_input_shardings[1])
res = export.call(exp)(a_run, b_run)
res = exp.call(a_run, b_run)
self.assertEqual(res.addressable_shards[0].device, run_devices[0])
self.assertEqual(res.addressable_shards[1].device, run_devices[1])
@ -996,7 +1021,7 @@ class JaxExportTest(jtu.JaxTestCase):
run_mesh = Mesh(run_devices, "i")
b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i")))
res_exported = export.call(exp)(b)
res_exported = exp.call(b)
self.assertAllClose(res_native, res_exported)
def test_call_with_different_no_of_devices_error_has_in_shardings(self):
@ -1024,7 +1049,7 @@ class JaxExportTest(jtu.JaxTestCase):
"Exported module .* was lowered for 1 devices and is called in a "
f"context with {jax.local_device_count()} devices.* module contains "
"non-replicated sharding annotations"):
export.call(exp)(b)
exp.call(b)
def test_call_with_different_no_of_devices_pmap(self):
if len(jax.devices()) < 2:
@ -1042,7 +1067,7 @@ class JaxExportTest(jtu.JaxTestCase):
b = jnp.arange(jax.device_count() * 100, dtype=jnp.float32).reshape(
(-1, 1, 100)
)
res_exported = jax.pmap(export.call(exp))(b)
res_exported = jax.pmap(exp.call)(b)
self.assertAllClose(res_native, res_exported[0])
def test_call_with_different_no_of_devices_error_has_sharding_constraint(self):
@ -1070,7 +1095,7 @@ class JaxExportTest(jtu.JaxTestCase):
"Exported module .* was lowered for 1 devices and is called in a "
f"context with {jax.local_device_count()} devices.* module contains "
"non-replicated sharding annotations"):
export.call(exp)(b)
exp.call(b)
@jtu.parameterized_filterable(
kwargs=[
@ -1108,7 +1133,7 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertLen(res_jax.addressable_shards, len(devices))
# Test reloaded execution.
f_r = export.call(exp)
f_r = exp.call
with self.assertRaisesRegex(
Exception,
"Exported module .* was lowered for 2 devices and is "
@ -1241,14 +1266,14 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertEqual(exp_vjp2.nr_devices, 2)
call_mesh = Mesh(jax.devices()[:2], "e")
g1 = pjit.pjit(export.call(exp_vjp),
g1 = pjit.pjit(exp_vjp.call,
in_shardings=(NamedSharding(call_mesh, None),
NamedSharding(call_mesh, None)))(x, x.T)
_, f_jax_vjp = jax.vjp(f_jax, x)
xbar = f_jax_vjp(x.T)
self.assertAllClose(xbar, g1)
g2 = pjit.pjit(export.call(exp_vjp2),
g2 = pjit.pjit(exp_vjp2.call,
in_shardings=(NamedSharding(call_mesh, None),
NamedSharding(call_mesh, None),
NamedSharding(call_mesh, None)))(x, x.T, x)
@ -1274,16 +1299,16 @@ class JaxExportTest(jtu.JaxTestCase):
exp = export.export(pjit.pjit(f, in_shardings=shardings))(input)
exp_rev = export.export(pjit.pjit(f, in_shardings=shardings_rev))(input_no_shards)
_ = export.serialize(exp, vjp_order=1)
_ = export.serialize(exp_rev, vjp_order=1)
_ = exp.serialize(vjp_order=1)
_ = exp_rev.serialize(vjp_order=1)
g = jax.grad(export.call(exp_rev))(input_rev)
g_rev = jax.grad(export.call(exp))(input)
g = jax.grad(exp_rev.call)(input_rev)
g_rev = jax.grad(exp.call)(input)
self.assertAllClose(g, g_rev)
def test_multi_platform(self):
x = np.arange(8, dtype=np.float32)
exp = get_exported(_testing_multi_platform_func,
exp = get_exported(jax.jit(_testing_multi_platform_func),
lowering_platforms=("tpu", "cpu", "cuda","rocm"))(x)
self.assertEqual(exp.lowering_platforms, ("tpu", "cpu", "cuda", "rocm"))
module_str = str(exp.mlir_module())
@ -1299,21 +1324,21 @@ class JaxExportTest(jtu.JaxTestCase):
# Call with argument placed on different plaforms
for platform in self.__class__.platforms:
x_device = jax.device_put(x, jax.devices(platform)[0])
res_exp = export.call(exp)(x_device)
res_exp = exp.call(x_device)
self.assertAllClose(
res_exp,
_testing_multi_platform_fun_expected(x, platform=platform))
def test_multi_platform_nested(self):
x = np.arange(5, dtype=np.float32)
exp = get_exported(lambda x: _testing_multi_platform_func(jnp.sin(x)),
exp = get_exported(jax.jit(lambda x: _testing_multi_platform_func(jnp.sin(x))),
lowering_platforms=("cpu", "tpu", "cuda","rocm"))(x)
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda","rocm"))
# Now serialize the call to the exported using a different sequence of
# lowering platforms, but included in the lowering platforms for the
# nested exported.
exp2 = get_exported(export.call(exp),
exp2 = get_exported(jax.jit(exp.call),
lowering_platforms=("cpu", "cuda","rocm"))(x)
# Ensure that we do not have multiple lowerings of the exported function
@ -1325,39 +1350,39 @@ class JaxExportTest(jtu.JaxTestCase):
for platform in self.__class__.platforms:
if platform == "tpu": continue
x_device = jax.device_put(x, jax.devices(platform)[0])
res_exp = export.call(exp2)(x_device)
res_exp = exp2.call(x_device)
self.assertAllClose(
res_exp,
_testing_multi_platform_fun_expected(np.sin(x), platform=platform))
def test_multi_platform_nested_inside_single_platform_export(self):
x = np.arange(5, dtype=np.float32)
exp = get_exported(_testing_multi_platform_func,
exp = get_exported(jax.jit(_testing_multi_platform_func),
lowering_platforms=("cpu", "tpu", "cuda","rocm"))(x)
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda", "rocm"))
# Now serialize the call for the current platform.
exp2 = get_exported(export.call(exp))(x)
exp2 = get_exported(jax.jit(exp.call))(x)
module_str = str(exp2.mlir_module())
self.assertIn("jax.uses_shape_polymorphism = true",
module_str)
res2 = export.call(exp2)(x)
res2 = exp2.call(x)
self.assertAllClose(res2, _testing_multi_platform_fun_expected(x))
def test_multi_platform_and_poly(self):
if jtu.test_device_matches(["gpu"]):
# The export is not applicable to GPU
raise unittest.SkipTest("Not intended for running on GPU")
exp = get_exported(lambda x: jnp.reshape(_testing_multi_platform_func(x), (-1,)),
exp = get_exported(jax.jit(lambda x: jnp.reshape(_testing_multi_platform_func(x), (-1,))),
lowering_platforms=("cpu", "tpu"))(
jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), np.float32)
)
x = np.arange(12, dtype=np.float32).reshape((3, 4))
res = export.call(exp)(x)
res = exp.call(x)
self.assertAllClose(res, _testing_multi_platform_fun_expected(x).reshape((-1,)))
# Now serialize the call to the exported
exp2 = get_exported(export.call(exp))(x)
res2 = export.call(exp2)(x)
exp2 = get_exported(jax.jit(exp.call))(x)
res2 = exp2.call(x)
self.assertAllClose(res2, _testing_multi_platform_fun_expected(x).reshape((-1,)))
def test_multi_platform_and_sharding(self):
@ -1382,7 +1407,7 @@ class JaxExportTest(jtu.JaxTestCase):
continue
run_mesh = Mesh(run_devices, ("x",))
a_device = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, None))
res_exp = export.call(exp)(a_device)
res_exp = exp.call(a_device)
self.assertArraysAllClose(res_native, res_exp)
@jtu.parameterized_filterable(
@ -1409,7 +1434,7 @@ class JaxExportTest(jtu.JaxTestCase):
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2")
)
exp = get_exported(f_jax)(x)
exp = get_exported(jax.jit(f_jax))(x)
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
sorted(str(e) for e in exp.ordered_effects))
self.assertEqual(["ForTestingUnorderedEffect1()"],
@ -1449,7 +1474,7 @@ class JaxExportTest(jtu.JaxTestCase):
x, effect_class_name="ForTestingOrderedEffect2") +
testing_primitive_with_effect_p.bind(
x, effect_class_name="ForTestingUnorderedEffect1") +
export.call(exp)(x))
exp.call(x))
lowered_outer = jax.jit(f_outer).lower(x)
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
@ -1476,7 +1501,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.arange(12, dtype=np.float32).reshape((3, 4))
def f_jax(x): # x: f32[b1, b2]
return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1")
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(
exp = get_exported(jax.jit(f_jax))(jax.ShapeDtypeStruct(
export.symbolic_shape("b2, b1"), x.dtype))
mlir_module_str = str(exp.mlir_module())
wrapped_main_expected_re = (
@ -1497,7 +1522,7 @@ class JaxExportTest(jtu.JaxTestCase):
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
self.assertRegex(mlir_module_str, main_expected_re)
res = export.call(exp)(x)
res = exp.call(x)
self.assertAllClose(10. + 2. * x, res)
@jtu.parameterized_filterable(
@ -1518,7 +1543,7 @@ class JaxExportTest(jtu.JaxTestCase):
return 10. + _testing_multi_platform_func(x,
effect_class_name="ForTestingOrderedEffect1")
exp = get_exported(
f_jax,
jax.jit(f_jax),
lowering_platforms=("cpu", "tpu")
)(jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), x.dtype))
mlir_module_str = str(exp.mlir_module())
@ -1541,7 +1566,7 @@ class JaxExportTest(jtu.JaxTestCase):
# Results
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
self.assertRegex(mlir_module_str, main_expected_re)
res = export.call(exp)(x)
res = exp.call(x)
self.assertAllClose(10. + _testing_multi_platform_fun_expected(x),
res)
@ -1586,7 +1611,7 @@ class JaxExportTest(jtu.JaxTestCase):
x,
effect_class_name="ForTestingOrderedEffect" + name)
with self.assertRaisesRegex(Exception, expect_error):
_ = get_exported(f_jax)(jax.ShapeDtypeStruct((3, 4), x.dtype))
_ = get_exported(jax.jit(f_jax))(jax.ShapeDtypeStruct((3, 4), x.dtype))
@jtu.parameterized_filterable(
kwargs=[
@ -1603,12 +1628,12 @@ class JaxExportTest(jtu.JaxTestCase):
rhs = np.arange(num_groups * k * n, dtype=dtype).reshape((num_groups, k, n))
res_native = f_jax(lhs, rhs, group_sizes)
exp_f = get_exported(f_jax)(
exp_f = get_exported(jax.jit(f_jax))(
jax.ShapeDtypeStruct(lhs.shape, dtype=lhs.dtype),
jax.ShapeDtypeStruct(rhs.shape, dtype=rhs.dtype),
jax.ShapeDtypeStruct(group_sizes.shape, dtype=group_sizes.dtype),
)
res_exported = export.call(exp_f)(lhs, rhs, group_sizes)
res_exported = exp_f.call(lhs, rhs, group_sizes)
self.assertAllClose(res_native, res_exported)
if __name__ == "__main__":

View File

@ -35,7 +35,7 @@ import operator as op
import re
import jax
from jax.experimental import export
from jax import export
from jax.experimental import pjit
from jax import lax
import jax.numpy as jnp
@ -1267,7 +1267,7 @@ class PolyHarness(Harness):
tst.assertEqual(getattr(jax.config, fname), fvalue, (
f"Flag {fname} current value {getattr(jax.config, fname)} != {fvalue}"))
f_jax = self.dyn_fun
f_jax = jax.jit(self.dyn_fun)
args = self.dyn_args_maker(tst.rng())
args = jax.tree.map(jnp.array, args)
args_specs = export.symbolic_args_specs(args, self.polymorphic_shapes,
@ -1283,7 +1283,7 @@ class PolyHarness(Harness):
return None
# Run the JAX natively and then the exported function and compare
res_jax_native = f_jax(*args)
res_jax_exported = export.call(exp)(*args)
res_jax_exported = exp.call(*args)
custom_assert_lims = [
l for l in self.limitations if l.custom_assert is not None]
assert len(custom_assert_lims) <= 1, custom_assert_lims
@ -1315,7 +1315,7 @@ def check_shape_poly(tst, f_jax: Callable, *,
symbolic_constraints: Sequence[str] = (),
expect_error=None) -> jax.Array | None:
# Builds a PolyHarness and runs the test. See PolyHarness documentation.
h = PolyHarness("", "", f_jax,
h = PolyHarness("", "", jax.jit(f_jax),
arg_descriptors=arg_descriptors,
polymorphic_shapes=polymorphic_shapes,
symbolic_constraints=symbolic_constraints,
@ -1408,11 +1408,10 @@ class ShapePolyTest(jtu.JaxTestCase):
def f_jax(x, *, y):
return x + jnp.sin(y)
f_exported = export.call(
export.export(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
x.dtype),
y=jax.ShapeDtypeStruct(y.shape, y.dtype)))
self.assertAllClose(f_jax(x, y=y), f_exported(x, y=y))
exp = export.export(jax.jit(f_jax))(
jax.ShapeDtypeStruct(export.symbolic_shape("b"), x.dtype),
y=jax.ShapeDtypeStruct(y.shape, y.dtype))
self.assertAllClose(f_jax(x, y=y), exp.call(x, y=y))
def test_arg_avals_errors(self):
"""Test error reporting for shape polymorphism."""
@ -1617,8 +1616,8 @@ class ShapePolyTest(jtu.JaxTestCase):
acc += jnp.sum(slice, axis=0)
return acc
_ = export.export(f)(jax.ShapeDtypeStruct(export.symbolic_shape("a, b"),
np.int32))
_ = export.export(jax.jit(f))(
jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), np.int32))
def test_constraints_compile_time_check(self):
@ -1630,29 +1629,30 @@ class ShapePolyTest(jtu.JaxTestCase):
x_spec = jax.ShapeDtypeStruct(
export.symbolic_shape("a",
constraints=["a >= 2", "a <= 4"]), np.int32)
exp = export.export(f)(x_spec)
exp = export.export(jax.jit(f))(x_spec)
x_2 = np.arange(2, dtype=np.int32)
res_2 = export.call(exp)(x_2)
res_2 = exp.call(x_2)
self.assertAllClose(x_2[0:2], res_2)
x_4 = np.arange(4, dtype=np.int32)
res_4 = export.call(exp)(x_4)
res_4 = exp.call(x_4)
self.assertAllClose(x_4[1:3], res_4)
with self.assertRaisesRegex(
ValueError,
re.escape("Expected 'a - 2' to be greater or equal to 0, but found -1")):
export.call(exp)(np.arange(1, dtype=np.int32))
exp.call(np.arange(1, dtype=np.int32))
with self.assertRaisesRegex(
ValueError,
re.escape("Expected '- a + 4' to be greater or equal to 0, but found -1")):
export.call(exp)(np.arange(5, dtype=np.int32))
exp.call(np.arange(5, dtype=np.int32))
def test_caching_with_scopes(self):
f_tracing_count = 0
expected_a_bounds = (1, np.inf)
@jax.jit
def f(x): # x: i32[a]
nonlocal f_tracing_count
f_tracing_count += 1