mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[export] Create the jax.export module APIs.
The functionality comes from the jax.experimental.export module, which will be deprecated. The following APIs are introduced: ``` from jax import export def f(...): ... ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs) blob: bytearray = ex.serialize() rehydrated: export.Export = export.deserialize(blob) def caller(...): ... rehydrated.call(*args, **kwargs) ``` Module documentation will follow shortly. There are no changes for now in the jax.experimental.export APIs. Most of the changes in this PR are in tests due to some differences in the new jax.export APIs compared to jax.experimental.export: * Instead of `jax.experimental.export.call(exp)` we now write `exp.call` * The `jax.experimental.export.export` allowed the function argument to be any Python callable and it would wrap it with a `jax.jit`. This is not supported anymore by export, and instead the user must use `jax.jit`.
This commit is contained in:
parent
14d87d3bf7
commit
b33aca6b08
@ -18,7 +18,7 @@ import google_benchmark as benchmark
|
||||
import jax
|
||||
from jax import core
|
||||
from jax._src.numpy import lax_numpy
|
||||
from jax.experimental import export
|
||||
from jax import export
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
@ -333,7 +333,7 @@ class Exported:
|
||||
`jax.device_put`.
|
||||
|
||||
Example usage:
|
||||
>>> from jax.experimental import export
|
||||
>>> from jax import export
|
||||
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
|
||||
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
|
||||
... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
|
||||
@ -347,7 +347,7 @@ class Exported:
|
||||
# Put the args and kwargs on the appropriate devices
|
||||
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
|
||||
... exp.in_shardings_jax(run_mesh)[0])
|
||||
>>> res = export.call(exp)(run_arg)
|
||||
>>> res = exp.call(run_arg)
|
||||
>>> res.addressable_shards
|
||||
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
|
||||
Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
|
||||
@ -372,19 +372,53 @@ class Exported:
|
||||
for s in self.out_shardings_hlo)
|
||||
|
||||
def has_vjp(self) -> bool:
|
||||
"""Returns if this Exported supports VJP."""
|
||||
return self._get_vjp is not None
|
||||
|
||||
def vjp(self) -> Exported:
|
||||
"""Gets the exported VJP.
|
||||
|
||||
Returns None if not available, which can happen if the Exported has been
|
||||
loaded from an external format, without a VJP."""
|
||||
loaded from an external format without a VJP.
|
||||
"""
|
||||
if self._get_vjp is None:
|
||||
raise ValueError("No VJP is available")
|
||||
return self._get_vjp(self)
|
||||
|
||||
def serialize(self,
|
||||
vjp_order: int = 0) -> bytearray:
|
||||
"""Serializes an Exported.
|
||||
|
||||
Args:
|
||||
vjp_order: The maximum vjp order to include. E.g., the value 2 means that we
|
||||
serialize the primal functions and two orders of the `vjp` function. This
|
||||
should allow 2nd order reverse mode differentiation of the deserialized
|
||||
function. i.e., `jax.grad(jax.grad(f)).`
|
||||
"""
|
||||
# Lazy load the serialization module, since flatbuffers is an optional
|
||||
# dependency.
|
||||
from jax._src.export.serialization import serialize
|
||||
return serialize(self, vjp_order=vjp_order)
|
||||
|
||||
def call(self, *args, **kwargs):
|
||||
return call_exported(self)(*args, **kwargs)
|
||||
|
||||
|
||||
def deserialize(blob: bytearray) -> Exported:
|
||||
"""Deserializes an Exported.
|
||||
|
||||
Args:
|
||||
blob: a bytearray obtained from `Exported.serialize`.
|
||||
"""
|
||||
# Lazy load the serialization module, since flatbuffers is an optional
|
||||
# dependency.
|
||||
from jax._src.export.serialization import deserialize
|
||||
return deserialize(blob)
|
||||
|
||||
|
||||
def default_lowering_platform() -> str:
|
||||
"""Retrieves the default lowering platform for the exporting machine.
|
||||
"""
|
||||
# Canonicalize to turn 'gpu' into 'cuda' or 'rocm'
|
||||
return xb.canonicalize_platform(jax.default_backend())
|
||||
|
||||
@ -411,7 +445,9 @@ def args_specs(
|
||||
return shape_poly.symbolic_args_specs(args, polymorphic_shapes)
|
||||
|
||||
|
||||
def export(fun_jax: Callable,
|
||||
# TODO(necula): remove this once we remove jax.experimental.export.
|
||||
def export_back_compat(
|
||||
fun_jax: Callable,
|
||||
*,
|
||||
lowering_platforms: Sequence[str] | None = None,
|
||||
disabled_checks: Sequence[DisabledSafetyCheck] = (),
|
||||
@ -419,6 +455,10 @@ def export(fun_jax: Callable,
|
||||
) -> Callable[..., Exported]:
|
||||
"""Exports native serialization for a JAX function.
|
||||
|
||||
Note: this function exists only for internal usage by jax2tf and for
|
||||
backwards compatibility with jax.experimental.export. Use
|
||||
`jax.export` instead.
|
||||
|
||||
Args:
|
||||
fun_jax: the function to lower and serialize.
|
||||
lowering_platforms:
|
||||
@ -498,6 +538,85 @@ def export(fun_jax: Callable,
|
||||
_device_assignment_for_internal_jax2tf_use_only=_device_assignment_for_internal_jax2tf_use_only)
|
||||
return do_export
|
||||
|
||||
def export(
|
||||
fun_jit: stages.Wrapped,
|
||||
*,
|
||||
lowering_platforms: Sequence[str] | None = None,
|
||||
disabled_checks: Sequence[DisabledSafetyCheck] = (),
|
||||
) -> Callable[..., Exported]:
|
||||
"""Exports a JAX function for persistent serialization.
|
||||
|
||||
Args:
|
||||
fun_jit: the function to export. Should be the result of `jit`.
|
||||
lowering_platforms:
|
||||
Optional sequence containing a subset of 'tpu', 'cpu',
|
||||
'cuda', 'rocm'. If more than one platform is specified, then
|
||||
the lowered code takes an argument specifying the platform.
|
||||
If None, then use the default JAX backend.
|
||||
The calling convention for multiple platforms is explained in the
|
||||
`jax_export.Exported` docstring.
|
||||
disabled_checks: the safety checks to disable. See docstring
|
||||
of `DisabledSafetyCheck`.
|
||||
|
||||
Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
|
||||
or values with `.shape` and `.dtype` attributes, and returns an
|
||||
`Exported`.
|
||||
|
||||
Usage:
|
||||
>>> from jax import export
|
||||
>>> exported: export.Exported = export.export(jnp.sin)(
|
||||
... np.arange(4, dtype=np.float32))
|
||||
|
||||
# You can inspect the Exported object
|
||||
>>> exported.in_avals
|
||||
(ShapedArray(float32[4]),)
|
||||
>>> blob: bytearray = exported.serialize()
|
||||
|
||||
# The serialized bytes are safe to use in a separate process
|
||||
>>> rehydrated: export.Exported = export.deserialize(blob)
|
||||
>>> rehydrated.fun_name
|
||||
'sin'
|
||||
>>> rehydrated.call(np.array([.1, .2, .3, .4], dtype=np.float32))
|
||||
Array([0.09983342, 0.19866933, 0.29552022, 0.38941833], dtype=float32)
|
||||
"""
|
||||
if not isinstance(fun_jit, stages.Wrapped):
|
||||
raise ValueError(
|
||||
f"Function to be exported must be the result of `jit` but is: {fun_jit}")
|
||||
if lowering_platforms is not None:
|
||||
actual_lowering_platforms = tuple(lowering_platforms)
|
||||
else:
|
||||
actual_lowering_platforms = (default_lowering_platform(),)
|
||||
|
||||
def do_export(*args_specs, **kwargs_specs) -> Exported:
|
||||
# TODO: move to `lower`
|
||||
symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None # type: ignore[invalid-annotation,unused-ignore]
|
||||
for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]:
|
||||
# Static args may have no `shape` attribute.
|
||||
if not hasattr(aval, "shape"):
|
||||
continue
|
||||
for d in aval.shape:
|
||||
if shape_poly.is_symbolic_dim(d):
|
||||
if symbolic_scope is None:
|
||||
symbolic_scope = (d.scope, k_path)
|
||||
continue
|
||||
symbolic_scope[0]._check_same_scope(
|
||||
d, when=f"when exporting {util.fun_name(fun_jit)}",
|
||||
self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
|
||||
other_descr=shape_poly.args_kwargs_path_to_str(k_path))
|
||||
|
||||
traced = fun_jit.trace( # type: ignore
|
||||
*args_specs, **kwargs_specs,
|
||||
_experimental_lowering_parameters=mlir.LoweringParameters(
|
||||
platforms=actual_lowering_platforms,
|
||||
for_export=True,
|
||||
))
|
||||
jaxpr, fun_name = traced.jaxpr, traced.fun_name
|
||||
lowered = traced.lower()
|
||||
return _export_lowered(
|
||||
lowered, jaxpr, fun_name,
|
||||
disabled_checks=disabled_checks)
|
||||
return do_export
|
||||
|
||||
def _export_lowered(
|
||||
lowered: stages.Lowered,
|
||||
jaxpr: core.ClosedJaxpr, fun_name: str,
|
||||
@ -599,7 +718,7 @@ def _export_lowered(
|
||||
device_assignment=device_assignment,
|
||||
apply_jit=True,
|
||||
flat_primal_fun=True)
|
||||
return export(fun_vjp_jax,
|
||||
return export(fun_vjp_jax, # type: ignore[arg-type]
|
||||
lowering_platforms=exp_primal.lowering_platforms,
|
||||
disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals)
|
||||
|
||||
@ -816,7 +935,7 @@ def _wrap_main_func(
|
||||
|
||||
def _check_lowering(lowering) -> None:
|
||||
if not isinstance(lowering, pxla.MeshComputation):
|
||||
raise NotImplementedError(f"serialization is supported only for pjit. {lowering}")
|
||||
raise NotImplementedError(f"serialization is supported only for jit. {lowering}")
|
||||
|
||||
if lowering.compile_args["host_callbacks"] or lowering.compile_args["keepalive"]:
|
||||
raise NotImplementedError("serialization of host_callbacks is not yet implemented")
|
||||
|
@ -48,7 +48,7 @@ SerT = TypeVar("SerT")
|
||||
_SERIALIZATION_VERSION = 2
|
||||
|
||||
def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray:
|
||||
"""Serialize an Exported.
|
||||
"""Serializes an Exported.
|
||||
|
||||
Args:
|
||||
exp: the Exported to serialize.
|
||||
@ -64,7 +64,7 @@ def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray:
|
||||
|
||||
|
||||
def deserialize(ser: bytearray) -> _export.Exported:
|
||||
"""Deserialize an Exported."""
|
||||
"""Deserializes an Exported."""
|
||||
exp = ser_flatbuf.Exported.GetRootAsExported(ser)
|
||||
return _deserialize_exported(exp)
|
||||
|
||||
|
@ -86,7 +86,7 @@ from numpy import array, float32
|
||||
|
||||
import jax
|
||||
from jax import tree_util
|
||||
from jax.experimental import export
|
||||
from jax import export
|
||||
|
||||
from jax.experimental import pjit
|
||||
|
||||
@ -345,4 +345,4 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
|
||||
_get_vjp=_get_vjp)
|
||||
|
||||
# We use pjit in case there are shardings in the exported module.
|
||||
return pjit.pjit(export.call(exported))(*data.inputs)
|
||||
return pjit.pjit(exported.call)(*data.inputs)
|
||||
|
@ -32,7 +32,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, NamedTuple, Protocol, Union
|
||||
from typing import Any, NamedTuple, Protocol, Union, runtime_checkable
|
||||
import warnings
|
||||
|
||||
import jax
|
||||
@ -756,8 +756,9 @@ class Lowered(Stage):
|
||||
return None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Wrapped(Protocol):
|
||||
"""A function ready to be specialized, lowered, and compiled.
|
||||
"""A function ready to be traced, lowered, and compiled.
|
||||
|
||||
This protocol reflects the output of functions such as
|
||||
``jax.jit``. Calling it results in JIT (just-in-time) lowering,
|
||||
|
@ -17,12 +17,13 @@ from jax._src.export._export import (
|
||||
minimum_supported_serialization_version,
|
||||
maximum_supported_serialization_version,
|
||||
Exported,
|
||||
export,
|
||||
call_exported, # TODO: deprecate
|
||||
call,
|
||||
DisabledSafetyCheck,
|
||||
default_lowering_platform,
|
||||
default_lowering_platform, # TODO: deprecate
|
||||
)
|
||||
from jax._src.export._export import export_back_compat as export
|
||||
|
||||
from jax._src.export.shape_poly import (
|
||||
is_symbolic_dim,
|
||||
symbolic_shape,
|
||||
@ -33,4 +34,6 @@ from jax._src.export.serialization import (
|
||||
serialize,
|
||||
deserialize,
|
||||
)
|
||||
# Import only to set the shape poly decision procedure
|
||||
from jax._src.export import shape_poly_decision
|
||||
del shape_poly_decision
|
||||
|
@ -36,7 +36,7 @@ from jax import random
|
||||
from jax import numpy as jnp
|
||||
from jax import tree_util
|
||||
from jax import sharding
|
||||
from jax.experimental import export
|
||||
from jax import export
|
||||
from jax.experimental.jax2tf import impl_no_xla
|
||||
from jax.interpreters import xla
|
||||
|
||||
@ -515,7 +515,7 @@ class NativeSerializationImpl(SerializationImpl):
|
||||
|
||||
self._restore_context = _restore_context
|
||||
_exported_device_assignment = [None]
|
||||
self.exported = export.export(
|
||||
self.exported = _export.export_back_compat(
|
||||
self.fun_jax,
|
||||
lowering_platforms=self.native_serialization_platforms,
|
||||
disabled_checks=self.native_serialization_disabled_checks,
|
||||
|
@ -25,13 +25,13 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
from jax import dlpack
|
||||
from jax import dtypes
|
||||
from jax import export
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax.experimental import export
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
import numpy as np
|
||||
@ -778,7 +778,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
lowering_platforms = ("tpu", "cpu", "cuda")
|
||||
|
||||
exp = export.export(f_jax,
|
||||
exp = export.export(jax.jit(f_jax),
|
||||
lowering_platforms=lowering_platforms)(x)
|
||||
for jax_platform in jax_and_tf_platforms:
|
||||
with self.subTest(jax_platform):
|
||||
@ -787,7 +787,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
|
||||
logging.info("Running harness natively on %s", jax_device)
|
||||
native_res = f_jax(x_device)
|
||||
logging.info("Running exported harness on %s", jax_device)
|
||||
exported_res = export.call(exp)(x_device)
|
||||
exported_res = exp.call(x_device)
|
||||
self.assertAllClose(native_res, exported_res)
|
||||
|
||||
def test_multi_platform_call_tf_graph(self):
|
||||
|
@ -27,6 +27,7 @@ from absl.testing import absltest, parameterized
|
||||
import jax
|
||||
from jax import ad_checkpoint
|
||||
from jax import dtypes
|
||||
from jax import export
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import sharding
|
||||
@ -37,7 +38,6 @@ from jax._src import source_info_util
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental import export
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
from jax.experimental.shard_map import shard_map
|
||||
from jax.experimental import pjit
|
||||
@ -1559,7 +1559,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
# Run the JAX native version, to check it works, and to fill caches.
|
||||
_ = func_to_convert(*args)
|
||||
exported = export.export(
|
||||
func_to_convert,
|
||||
(jax.jit(func_to_convert) if not hasattr(func_to_convert, "trace") else func_to_convert),
|
||||
lowering_platforms=("tpu",)
|
||||
)(*(core.ShapedArray(a.shape, a.dtype) for a in args))
|
||||
|
||||
|
@ -32,8 +32,8 @@ import re
|
||||
|
||||
import jax
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental import export
|
||||
from jax.experimental import pjit
|
||||
from jax import export
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax import random
|
||||
|
@ -31,7 +31,7 @@ from jax._src import test_util as jtu
|
||||
from jax import tree_util
|
||||
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental import export
|
||||
from jax import export
|
||||
from jax._src import config
|
||||
from jax._src import xla_bridge
|
||||
import numpy as np
|
||||
|
34
jax/export.py
Normal file
34
jax/export.py
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
__all__ = ["DisabledSafetyCheck", "Exported", "export", "deserialize",
|
||||
"maximum_supported_serialization_version",
|
||||
"minimum_supported_serialization_version",
|
||||
"default_lowering_platform",
|
||||
"SymbolicScope", "is_symbolic_dim",
|
||||
"symbolic_shape", "symbolic_args_specs"]
|
||||
|
||||
from jax._src.export._export import DisabledSafetyCheck as DisabledSafetyCheck
|
||||
from jax._src.export._export import Exported as Exported
|
||||
from jax._src.export._export import export as export
|
||||
from jax._src.export._export import deserialize as deserialize
|
||||
from jax._src.export._export import maximum_supported_serialization_version as maximum_supported_serialization_version
|
||||
from jax._src.export._export import minimum_supported_serialization_version as minimum_supported_serialization_version
|
||||
from jax._src.export._export import default_lowering_platform as default_lowering_platform
|
||||
|
||||
from jax._src.export import shape_poly_decision # Import only to set the decision procedure
|
||||
del shape_poly_decision
|
||||
from jax._src.export.shape_poly import SymbolicScope as SymbolicScope
|
||||
from jax._src.export.shape_poly import is_symbolic_dim as is_symbolic_dim
|
||||
from jax._src.export.shape_poly import symbolic_shape as symbolic_shape
|
||||
from jax._src.export.shape_poly import symbolic_args_specs as symbolic_args_specs
|
@ -31,9 +31,9 @@ from absl.testing import absltest
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import export
|
||||
from jax import lax
|
||||
from jax._src import test_util as jtu
|
||||
from jax.experimental import export
|
||||
from jax._src.internal_test_util import test_harnesses
|
||||
|
||||
|
||||
@ -152,7 +152,8 @@ class PrimitiveTest(jtu.JaxTestCase):
|
||||
)
|
||||
|
||||
logging.info("Exporting harness for %s", lowering_platforms)
|
||||
exp = export.export(func_jax, lowering_platforms=lowering_platforms)(*args)
|
||||
exp = export.export(jax.jit(func_jax),
|
||||
lowering_platforms=lowering_platforms)(*args)
|
||||
|
||||
for device in devices:
|
||||
if device.platform in skip_run_on_platforms:
|
||||
@ -164,7 +165,7 @@ class PrimitiveTest(jtu.JaxTestCase):
|
||||
logging.info("Running harness natively on %s", device)
|
||||
native_res = func_jax(*device_args)
|
||||
logging.info("Running exported harness on %s", device)
|
||||
exported_res = export.call(exp)(*device_args)
|
||||
exported_res = exp.call(*device_args)
|
||||
if tol is not None:
|
||||
logging.info(f"Using non-standard tolerance {tol}")
|
||||
self.assertAllClose(native_res, exported_res, atol=tol, rtol=tol)
|
||||
|
@ -20,12 +20,13 @@ import logging
|
||||
import math
|
||||
import re
|
||||
import unittest
|
||||
from typing import Callable
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax.experimental import export
|
||||
from jax import export
|
||||
from jax.experimental import pjit
|
||||
from jax.experimental.shard_map import shard_map
|
||||
from jax.sharding import NamedSharding
|
||||
@ -137,12 +138,12 @@ def _testing_multi_platform_fun_expected(x,
|
||||
]
|
||||
|
||||
|
||||
def get_exported(fun, vjp_order=0,
|
||||
def get_exported(fun: Callable, vjp_order=0,
|
||||
**export_kwargs):
|
||||
"""Like export.export but with serialization + deserialization."""
|
||||
def serde_exported(*fun_args, **fun_kwargs):
|
||||
exp = export.export(fun, **export_kwargs)(*fun_args, **fun_kwargs)
|
||||
serialized = export.serialize(exp, vjp_order=vjp_order)
|
||||
serialized = exp.serialize(vjp_order=vjp_order)
|
||||
return export.deserialize(serialized)
|
||||
return serde_exported
|
||||
|
||||
@ -164,11 +165,13 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
super().setUpClass()
|
||||
|
||||
def test_basic_export_only(self):
|
||||
@jax.jit
|
||||
def my_fun(x):
|
||||
return jnp.sin(x)
|
||||
exp = get_exported(my_fun)(jax.ShapeDtypeStruct((4,), dtype=np.float32))
|
||||
self.assertEqual("my_fun", exp.fun_name)
|
||||
self.assertEqual((export.default_lowering_platform(),),
|
||||
expected_lowering_platform = xb.canonicalize_platform(jax.default_backend())
|
||||
self.assertEqual((expected_lowering_platform,),
|
||||
exp.lowering_platforms)
|
||||
self.assertEqual(jax.tree.flatten(((1,), {}))[1], exp.in_tree)
|
||||
self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.in_avals)
|
||||
@ -180,7 +183,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
def f(a_b_pair, *, a, b):
|
||||
return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))
|
||||
|
||||
exp = get_exported(f, lowering_platforms=("cpu",))((a, b), a=a, b=b)
|
||||
exp = get_exported(jax.jit(f), lowering_platforms=("cpu",))((a, b), a=a, b=b)
|
||||
a_aval = core.ShapedArray(a.shape, a.dtype)
|
||||
b_aval = core.ShapedArray(b.shape, b.dtype)
|
||||
self.assertEqual(exp.lowering_platforms, ("cpu",))
|
||||
@ -196,8 +199,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
x = np.arange(4, dtype=np.float32)
|
||||
exp_f = get_exported(f)(x)
|
||||
|
||||
f1 = export.call(exp_f)
|
||||
self.assertAllClose(f(x), f1(x))
|
||||
self.assertAllClose(f(x), exp_f.call(x))
|
||||
|
||||
def test_jit_static_arg(self):
|
||||
|
||||
@ -210,8 +212,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
x = np.arange(4, dtype=np.float32)
|
||||
exp_f = get_exported(f)(x, c=0.1)
|
||||
|
||||
f1 = export.call(exp_f)
|
||||
self.assertAllClose(f(x, c=0.1), f1(x))
|
||||
self.assertAllClose(f(x, c=0.1), exp_f.call(x))
|
||||
|
||||
with self.subTest("static_argnums"):
|
||||
|
||||
@ -222,16 +223,32 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
x = np.arange(4, dtype=np.float32)
|
||||
exp_g = get_exported(g)(x, 0.1)
|
||||
|
||||
g1 = export.call(exp_g)
|
||||
self.assertAllClose(g(x, 0.1), g1(x))
|
||||
self.assertAllClose(g(x, 0.1), exp_g.call(x))
|
||||
|
||||
def test_export_error_no_jit(self):
|
||||
# Can export a lambda, without jit
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Function to be exported must be the result of `jit`"):
|
||||
_ = export.export(lambda x: jnp.sin(x))
|
||||
|
||||
def test_export_experimental_back_compat(self):
|
||||
from jax.experimental import export
|
||||
# Can export a lambda, without jit
|
||||
exp = export.export(lambda x: jnp.sin(x))(.1)
|
||||
self.assertAllClose(exp.call(1.), np.sin(1.))
|
||||
|
||||
blob = export.serialize(exp, vjp_order=1)
|
||||
rehydrated = export.deserialize(blob)
|
||||
|
||||
self.assertAllClose(export.call(exp)(1.), np.sin(1.))
|
||||
self.assertAllClose(export.call_exported(exp)(1.), np.sin(1.))
|
||||
|
||||
def test_call_exported_lambda(self):
|
||||
# When we export a lambda, the exported.fun_name is not a valid MLIR function name
|
||||
f = lambda x: jnp.sin(x)
|
||||
f = jax.jit(lambda x: jnp.sin(x))
|
||||
x = np.arange(4, dtype=np.float32)
|
||||
exp_f = get_exported(f)(x)
|
||||
f1 = export.call(exp_f)
|
||||
self.assertAllClose(f(x), f1(x))
|
||||
self.assertAllClose(f(x), exp_f.call(x))
|
||||
|
||||
def test_call_name_conflict(self):
|
||||
@jax.jit
|
||||
@ -246,7 +263,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def outer(x):
|
||||
# There should be no conflict on _where
|
||||
x = export.call(exp_inner)(x)
|
||||
x = exp_inner.call(x)
|
||||
return inner(x)
|
||||
|
||||
export.export(outer)(x)
|
||||
@ -257,19 +274,18 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def f1(x):
|
||||
exp_f = get_exported(f)(x)
|
||||
return export.call(exp_f)(x) + export.call(exp_f)(x)
|
||||
exp_f = get_exported(jax.jit(f))(x)
|
||||
return exp_f.call(x) + exp_f.call(x)
|
||||
|
||||
self.assertAllClose(2. * f(x), f1(x))
|
||||
|
||||
def test_unused_args(self):
|
||||
f = lambda x, y: jnp.sin(x)
|
||||
f = jax.jit(lambda x, y: jnp.sin(x))
|
||||
x = np.arange(4, dtype=np.float32)
|
||||
y = np.arange(6, dtype=np.float32)
|
||||
exp_f = get_exported(f)(x, y)
|
||||
|
||||
f1 = export.call(exp_f)
|
||||
self.assertAllClose(f(x, y), f1(x, y))
|
||||
self.assertAllClose(f(x, y), exp_f.call(x, y))
|
||||
|
||||
def test_pytree(self):
|
||||
a = np.arange(4, dtype=np.float32)
|
||||
@ -277,43 +293,49 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
def f(a_b_pair, a, b):
|
||||
return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))
|
||||
|
||||
exp_f = get_exported(f)((a, b), a=a, b=b)
|
||||
f1 = export.call(exp_f)
|
||||
exp_f = get_exported(jax.jit(f))((a, b), a=a, b=b)
|
||||
self.assertAllClose(f((a, b), a=a, b=b),
|
||||
f1((a, b), a=a, b=b))
|
||||
exp_f.call((a, b), a=a, b=b))
|
||||
|
||||
def test_error_wrong_intree(self):
|
||||
def f(a_b_pair, *, c):
|
||||
return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c
|
||||
a = b = c = np.arange(4, dtype=np.float32)
|
||||
exp_f = get_exported(f)((a, b), c=c)
|
||||
exp_f = get_exported(jax.jit(f))((a, b), c=c)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"The invocation args and kwargs must have the same pytree structure"):
|
||||
export.call(exp_f)(a, b, c=(a, b))
|
||||
exp_f.call(a, b, c=(a, b))
|
||||
|
||||
def test_error_wrong_avals(self):
|
||||
def f(a, *, b): # a: f32[4] and b: f32[4]
|
||||
return jnp.sin(a) + jnp.cos(b)
|
||||
f32_4 = np.arange(4, dtype=np.float32)
|
||||
exp_f = get_exported(f)(f32_4, b=f32_4)
|
||||
exp_f = get_exported(jax.jit(f))(f32_4, b=f32_4)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"Shape mismatch for args\[0\].shape\[0\]"):
|
||||
export.call(exp_f)(np.arange(6, dtype=np.float32), b=f32_4)
|
||||
exp_f.call(np.arange(6, dtype=np.float32), b=f32_4)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"Shape mismatch for kwargs\['b'\].shape\[0\]"):
|
||||
export.call(exp_f)(f32_4, b=np.arange(6, dtype=np.float32))
|
||||
exp_f.call(f32_4, b=np.arange(6, dtype=np.float32))
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"Rank mismatch for args\[0\]"):
|
||||
export.call(exp_f)(f32_4.reshape((1, 4)), b=f32_4)
|
||||
exp_f.call(f32_4.reshape((1, 4)), b=f32_4)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"Dtype mismatch for args\[0\]"):
|
||||
export.call(exp_f)(f32_4.astype(np.float16), b=f32_4)
|
||||
exp_f.call(f32_4.astype(np.float16), b=f32_4)
|
||||
|
||||
def test_default_lowering_platform(self):
|
||||
test_platform = jtu.device_under_test()
|
||||
if test_platform == "gpu": test_platform = "cuda"
|
||||
self.assertEqual(export.default_lowering_platform(), test_platform)
|
||||
exp = export.export(jnp.sin)(1.)
|
||||
self.assertEqual(exp.lowering_platforms, (export.default_lowering_platform(),))
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
testcase_name=lambda kw: kw["platform"],
|
||||
@ -328,13 +350,13 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "The exported function .* was lowered for platform"):
|
||||
export.call(exp_f)(a)
|
||||
exp_f.call(a)
|
||||
|
||||
# Now try with the platform check disabled
|
||||
exp_f_no_platform_check = get_exported(
|
||||
jnp.sin, lowering_platforms=(platform,),
|
||||
disabled_checks=[export.DisabledSafetyCheck.platform()])(a)
|
||||
res = export.call(exp_f_no_platform_check)(a)
|
||||
res = exp_f_no_platform_check.call(a)
|
||||
self.assertAllClose(res, jnp.sin(a))
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
@ -357,12 +379,12 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Cannot serialize code with custom calls whose targets .*"):
|
||||
get_exported(
|
||||
lambda a: a + test_primitive.bind(a)
|
||||
jax.jit(lambda a: a + test_primitive.bind(a))
|
||||
)(a)
|
||||
|
||||
# Now try again with the safety check disabled
|
||||
exp = get_exported(
|
||||
lambda a: a + test_primitive.bind(a),
|
||||
jax.jit(lambda a: a + test_primitive.bind(a)),
|
||||
disabled_checks=[export.DisabledSafetyCheck.custom_call("disallowed_call_target")]
|
||||
)(a)
|
||||
self.assertIn("disallowed_call_target", exp.mlir_module())
|
||||
@ -387,22 +409,22 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Lowering for export not supported"):
|
||||
export.export(f)(a)
|
||||
export.export(jax.jit(f))(a)
|
||||
|
||||
def test_grad(self):
|
||||
f = lambda x: jnp.sum(jnp.sin(x))
|
||||
x = np.arange(4, dtype=np.float32)
|
||||
exp_f = get_exported(f, vjp_order=1)(x)
|
||||
exp_f = get_exported(jax.jit(f), vjp_order=1)(x)
|
||||
|
||||
f1 = export.call(exp_f)
|
||||
f1 = exp_f.call
|
||||
self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x))
|
||||
|
||||
def test_higher_order_grad(self):
|
||||
f = lambda x: x ** 3
|
||||
x = np.float32(4.)
|
||||
exp_f = get_exported(f, vjp_order=3)(x)
|
||||
exp_f = get_exported(jax.jit(f), vjp_order=3)(x)
|
||||
|
||||
f1 = export.call(exp_f)
|
||||
f1 = exp_f.call
|
||||
self.assertAllClose(jax.grad(f)(x),
|
||||
jax.grad(f1)(x))
|
||||
self.assertAllClose(jax.grad(jax.grad(f))(x),
|
||||
@ -428,8 +450,8 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(res, (xi_ct, xf_ct))
|
||||
(f_outi_ct2, f_outf_ct2), = f_vjp2((xi_ct, xf_ct))
|
||||
|
||||
exp = get_exported(f, vjp_order=2)(xi, xf)
|
||||
fr = export.call(exp)
|
||||
exp = get_exported(jax.jit(f), vjp_order=2)(xi, xf)
|
||||
fr = exp.call
|
||||
|
||||
res = fr(xi, xf)
|
||||
self.assertAllClose(res, (f_outi, f_outf))
|
||||
@ -456,14 +478,14 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
a = np.arange(4, dtype=np.float32)
|
||||
b = np.arange(6, dtype=np.float32)
|
||||
exp_f = get_exported(f, vjp_order=1)((a, b), a=a, b=b)
|
||||
exp_f = get_exported(jax.jit(f), vjp_order=1)((a, b), a=a, b=b)
|
||||
|
||||
out_ct = f((a, b), a=a, b=b) # The output has the right structure as the cotangent
|
||||
def f1_jax(a, b): # For VJP, make a function without kwargs
|
||||
res = f((a, b), a=a, b=b)
|
||||
return res
|
||||
def f1_exp(a, b): # For VJP, make a function without kwargs
|
||||
res = export.call(exp_f)((a, b), a=a, b=b)
|
||||
res = exp_f.call((a, b), a=a, b=b)
|
||||
return res
|
||||
jax_vjp = jax.vjp(f1_jax, a, b)[1](out_ct)
|
||||
exp_vjp = jax.vjp(f1_exp, a, b)[1](out_ct)
|
||||
@ -473,15 +495,15 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
def f1(x):
|
||||
return jnp.sin(x)
|
||||
a = np.arange(4, dtype=np.float32)
|
||||
exp_f1 = get_exported(f1)(a)
|
||||
exp_f1 = get_exported(jax.jit(f1))(a)
|
||||
def f2(x):
|
||||
res1 = export.call(exp_f1)(x)
|
||||
res2 = export.call(exp_f1)(res1)
|
||||
res1 = exp_f1.call(x)
|
||||
res2 = exp_f1.call(res1)
|
||||
return jnp.cos(res2)
|
||||
exp_f2 = get_exported(f2)(a)
|
||||
exp_f2 = get_exported(jax.jit(f2))(a)
|
||||
|
||||
self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))),
|
||||
export.call(exp_f2)(a))
|
||||
exp_f2.call(a))
|
||||
|
||||
def test_poly_export_only(self):
|
||||
a = np.arange(12, dtype=np.float32).reshape((3, 4))
|
||||
@ -489,7 +511,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
return jnp.concatenate([a, b], axis=0)
|
||||
|
||||
scope = export.SymbolicScope()
|
||||
exp = get_exported(f)(
|
||||
exp = get_exported(jax.jit(f))(
|
||||
jax.ShapeDtypeStruct(export.symbolic_shape("(2*w, h)", scope=scope), a.dtype),
|
||||
jax.ShapeDtypeStruct(export.symbolic_shape("(w, h)", scope=scope), a.dtype))
|
||||
self.assertEqual("(2*w, h)", str(exp.in_avals[0].shape))
|
||||
@ -528,7 +550,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
return jnp.concatenate([a0, a1, ak], axis=0)
|
||||
|
||||
a_poly_spec = jax.ShapeDtypeStruct(export.symbolic_shape("(w, h)"), a.dtype)
|
||||
exp = get_exported(f)(a_poly_spec, a_poly_spec, ak=a_poly_spec)
|
||||
exp = get_exported(jax.jit(f))(a_poly_spec, a_poly_spec, ak=a_poly_spec)
|
||||
self.assertEqual("(w, h)", str(exp.in_avals[0].shape))
|
||||
self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape))
|
||||
|
||||
@ -545,7 +567,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
"Invalid mixing of symbolic scopes when exporting f.*"
|
||||
r"Expected current \(from args\[0\]\) scope .*"
|
||||
r"and found for 'w' \(args\[1\]\) scope .*", re.DOTALL)):
|
||||
get_exported(f)(x_poly_spec, y_poly_spec)
|
||||
get_exported(jax.jit(f))(x_poly_spec, y_poly_spec)
|
||||
|
||||
def test_poly_export_callable_with_no_name(self):
|
||||
# This was reported by a user
|
||||
@ -566,7 +588,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
a, = export.symbolic_shape("a,")
|
||||
# No error
|
||||
_ = get_exported(MyCallable())(
|
||||
_ = get_exported(jax.jit(MyCallable()))(
|
||||
jax.ShapeDtypeStruct((a, a), dtype=np.float32)
|
||||
)
|
||||
|
||||
@ -590,7 +612,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
exp = get_exported(jnp.sin)(
|
||||
jax.ShapeDtypeStruct(export.symbolic_shape("w, h"), np.float32))
|
||||
x = np.arange(30, dtype=np.float32).reshape((5, 6))
|
||||
res = export.call(exp)(x)
|
||||
res = exp.call(x)
|
||||
self.assertAllClose(res, np.sin(x))
|
||||
|
||||
# A function is exported with f32[poly_spec] and is called with different arg
|
||||
@ -631,7 +653,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
return jnp.reshape(x, (-1, x.shape[1]))
|
||||
|
||||
disabled_checks = ()
|
||||
exp_f = get_exported(f, disabled_checks=disabled_checks)(
|
||||
exp_f = get_exported(jax.jit(f), disabled_checks=disabled_checks)(
|
||||
jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), np.float32))
|
||||
self.assertEqual(exp_f.uses_shape_polymorphism, poly_spec != "3,4,12")
|
||||
arg = np.arange(np.prod(arg_shape),
|
||||
@ -642,7 +664,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
stack.push(self.assertRaisesRegex(Exception, expect_error))
|
||||
|
||||
assert core.is_constant_shape(arg.shape)
|
||||
res = export.call(exp_f)(arg)
|
||||
res = exp_f.call(arg)
|
||||
|
||||
if not expect_error:
|
||||
self.assertAllClose(res, f(arg))
|
||||
@ -733,7 +755,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
arg = np.arange(np.prod(arg_shape),
|
||||
dtype=arg_dtype).reshape(arg_shape) # x : f32[3,4,12]
|
||||
inner_exp = get_exported(inner)(
|
||||
inner_exp = get_exported(jax.jit(inner))(
|
||||
jax.ShapeDtypeStruct(export.symbolic_shape(inner_poly_spec), np.float32))
|
||||
|
||||
self.assertEqual(inner_exp.uses_shape_polymorphism,
|
||||
@ -741,14 +763,14 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
def outer(x): # x: outer_poly_spec
|
||||
# Use an addition to test that the shapes are refined properly for the
|
||||
# result of the call_exported.
|
||||
return export.call(inner_exp)(x) + inner(x)
|
||||
return inner_exp.call(x) + inner(x)
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
if expect_error_outer_exp is not None:
|
||||
stack.push(self.assertRaisesRegex(ValueError, expect_error_outer_exp))
|
||||
|
||||
# Call it after exporting again, with polymorphic shapes
|
||||
outer_exp = get_exported(outer)(
|
||||
outer_exp = get_exported(jax.jit(outer))(
|
||||
jax.ShapeDtypeStruct(export.symbolic_shape(outer_poly_spec), arg.dtype))
|
||||
|
||||
if expect_error_outer_exp is not None:
|
||||
@ -761,7 +783,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
if expect_error_run is not None:
|
||||
stack.push(self.assertRaisesRegex(Exception, expect_error_run))
|
||||
|
||||
res = export.call(outer_exp)(arg)
|
||||
res = outer_exp.call(arg)
|
||||
|
||||
if expect_error_run is not None:
|
||||
return
|
||||
@ -823,20 +845,21 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
with contextlib.ExitStack() as stack:
|
||||
if expect_error is not None:
|
||||
stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error)))
|
||||
exp = get_exported(f_jax)(
|
||||
exp = get_exported(jax.jit(f_jax))(
|
||||
jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), x.dtype))
|
||||
export.call(exp)(x)
|
||||
exp.call(x)
|
||||
|
||||
def test_poly_booleans(self):
|
||||
# For booleans we use a special case ConvertOp to cast to and from
|
||||
# dynamic shapes arguments.
|
||||
@jax.jit
|
||||
def f_jax(x): # x: bool[b]
|
||||
return jnp.logical_not(x)
|
||||
|
||||
x = np.array([True, False, True, False], dtype=np.bool_)
|
||||
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
|
||||
x.dtype))
|
||||
res = export.call(exp)(x)
|
||||
res = exp.call(x)
|
||||
self.assertAllClose(f_jax(x), res)
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
@ -851,13 +874,14 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
"int4",
|
||||
"uint4"}:
|
||||
self.skipTest(f"TODO: serialization not supported for {str(dtype)}")
|
||||
@jax.jit
|
||||
def f_jax(x):
|
||||
return x + x
|
||||
|
||||
x = np.arange(6, dtype=dtype)
|
||||
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
|
||||
x.dtype))
|
||||
res = export.call(exp)(x)
|
||||
res = exp.call(x)
|
||||
self.assertAllClose(f_jax(x), res)
|
||||
|
||||
def test_poly_expressions(self):
|
||||
@ -867,6 +891,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
return (b + b, b - b, b * b,
|
||||
(b + 13) // b, (b + 13) % b,
|
||||
core.max_dim(b - 5, 0))
|
||||
@jax.jit
|
||||
def f(x): # x: f32[b]
|
||||
b = x.shape[0]
|
||||
return jnp.ones(output_shape(b), dtype=x.dtype)
|
||||
@ -874,12 +899,12 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
exp = get_exported(f)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
|
||||
x.dtype))
|
||||
# Call with static shapes
|
||||
res = export.call(exp)(x)
|
||||
res = exp.call(x)
|
||||
self.assertAllClose(res, f(x))
|
||||
|
||||
# Now re-export with shape polymorphism
|
||||
x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("a"), x.dtype)
|
||||
exp2 = get_exported(export.call(exp))(x_spec)
|
||||
exp2 = get_exported(jax.jit(exp.call))(x_spec)
|
||||
a = exp2.in_avals[0].shape[0]
|
||||
self.assertEqual(exp2.out_avals[0].shape, output_shape(a))
|
||||
|
||||
@ -890,9 +915,9 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
return x + jnp.arange(x.shape[0], dtype=x.dtype).reshape((x.shape[0], 1))
|
||||
|
||||
a, = export.symbolic_shape("a")
|
||||
exp = export.export(f)(
|
||||
exp = export.export(jax.jit(f))(
|
||||
jax.ShapeDtypeStruct((a, 4), np.float32))
|
||||
f_exp = export.call(exp)
|
||||
f_exp = exp.call
|
||||
x_jit = np.arange(12, dtype=np.float32).reshape((3, 4))
|
||||
res_jit = jax.jit(f_exp)(x_jit)
|
||||
self.assertAllClose(res_jit, f(x_jit))
|
||||
@ -931,24 +956,24 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
# We apply the out_shardings for f_jax
|
||||
r".*custom_call @Sharding\(%1\).*mhlo.sharding = \"{devices=\[1,2\]<=\[2\]}\"}.*",
|
||||
re.DOTALL)
|
||||
hlo = jax.jit(export.call(exp)).lower(a_device).as_text()
|
||||
hlo = jax.jit(exp.call).lower(a_device).as_text()
|
||||
self.assertRegex(hlo, expected_re)
|
||||
|
||||
res_exported = export.call(exp)(a_device)
|
||||
res_exported = exp.call(a_device)
|
||||
self.assertAllClose(res_native, res_exported)
|
||||
|
||||
# Test error reporting
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
"Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
|
||||
_ = export.call(exp)(a)
|
||||
_ = exp.call(a)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
"Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
|
||||
mesh1 = Mesh(jax.devices()[0:1], axis_names=("x",))
|
||||
_ = jax.jit(
|
||||
export.call(exp),
|
||||
exp.call,
|
||||
in_shardings=(jax.sharding.NamedSharding(mesh1, P("x", None)),)
|
||||
)(a)
|
||||
|
||||
@ -973,7 +998,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
run_input_shardings = exp.in_shardings_jax(run_mesh)
|
||||
a_run = jax.device_put(a, run_input_shardings[0])
|
||||
b_run = jax.device_put(a, run_input_shardings[1])
|
||||
res = export.call(exp)(a_run, b_run)
|
||||
res = exp.call(a_run, b_run)
|
||||
self.assertEqual(res.addressable_shards[0].device, run_devices[0])
|
||||
self.assertEqual(res.addressable_shards[1].device, run_devices[1])
|
||||
|
||||
@ -996,7 +1021,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
run_mesh = Mesh(run_devices, "i")
|
||||
b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i")))
|
||||
|
||||
res_exported = export.call(exp)(b)
|
||||
res_exported = exp.call(b)
|
||||
self.assertAllClose(res_native, res_exported)
|
||||
|
||||
def test_call_with_different_no_of_devices_error_has_in_shardings(self):
|
||||
@ -1024,7 +1049,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
"Exported module .* was lowered for 1 devices and is called in a "
|
||||
f"context with {jax.local_device_count()} devices.* module contains "
|
||||
"non-replicated sharding annotations"):
|
||||
export.call(exp)(b)
|
||||
exp.call(b)
|
||||
|
||||
def test_call_with_different_no_of_devices_pmap(self):
|
||||
if len(jax.devices()) < 2:
|
||||
@ -1042,7 +1067,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
b = jnp.arange(jax.device_count() * 100, dtype=jnp.float32).reshape(
|
||||
(-1, 1, 100)
|
||||
)
|
||||
res_exported = jax.pmap(export.call(exp))(b)
|
||||
res_exported = jax.pmap(exp.call)(b)
|
||||
self.assertAllClose(res_native, res_exported[0])
|
||||
|
||||
def test_call_with_different_no_of_devices_error_has_sharding_constraint(self):
|
||||
@ -1070,7 +1095,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
"Exported module .* was lowered for 1 devices and is called in a "
|
||||
f"context with {jax.local_device_count()} devices.* module contains "
|
||||
"non-replicated sharding annotations"):
|
||||
export.call(exp)(b)
|
||||
exp.call(b)
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
@ -1108,7 +1133,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
self.assertLen(res_jax.addressable_shards, len(devices))
|
||||
|
||||
# Test reloaded execution.
|
||||
f_r = export.call(exp)
|
||||
f_r = exp.call
|
||||
with self.assertRaisesRegex(
|
||||
Exception,
|
||||
"Exported module .* was lowered for 2 devices and is "
|
||||
@ -1241,14 +1266,14 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
self.assertEqual(exp_vjp2.nr_devices, 2)
|
||||
call_mesh = Mesh(jax.devices()[:2], "e")
|
||||
|
||||
g1 = pjit.pjit(export.call(exp_vjp),
|
||||
g1 = pjit.pjit(exp_vjp.call,
|
||||
in_shardings=(NamedSharding(call_mesh, None),
|
||||
NamedSharding(call_mesh, None)))(x, x.T)
|
||||
_, f_jax_vjp = jax.vjp(f_jax, x)
|
||||
xbar = f_jax_vjp(x.T)
|
||||
self.assertAllClose(xbar, g1)
|
||||
|
||||
g2 = pjit.pjit(export.call(exp_vjp2),
|
||||
g2 = pjit.pjit(exp_vjp2.call,
|
||||
in_shardings=(NamedSharding(call_mesh, None),
|
||||
NamedSharding(call_mesh, None),
|
||||
NamedSharding(call_mesh, None)))(x, x.T, x)
|
||||
@ -1274,16 +1299,16 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
exp = export.export(pjit.pjit(f, in_shardings=shardings))(input)
|
||||
exp_rev = export.export(pjit.pjit(f, in_shardings=shardings_rev))(input_no_shards)
|
||||
|
||||
_ = export.serialize(exp, vjp_order=1)
|
||||
_ = export.serialize(exp_rev, vjp_order=1)
|
||||
_ = exp.serialize(vjp_order=1)
|
||||
_ = exp_rev.serialize(vjp_order=1)
|
||||
|
||||
g = jax.grad(export.call(exp_rev))(input_rev)
|
||||
g_rev = jax.grad(export.call(exp))(input)
|
||||
g = jax.grad(exp_rev.call)(input_rev)
|
||||
g_rev = jax.grad(exp.call)(input)
|
||||
self.assertAllClose(g, g_rev)
|
||||
|
||||
def test_multi_platform(self):
|
||||
x = np.arange(8, dtype=np.float32)
|
||||
exp = get_exported(_testing_multi_platform_func,
|
||||
exp = get_exported(jax.jit(_testing_multi_platform_func),
|
||||
lowering_platforms=("tpu", "cpu", "cuda","rocm"))(x)
|
||||
self.assertEqual(exp.lowering_platforms, ("tpu", "cpu", "cuda", "rocm"))
|
||||
module_str = str(exp.mlir_module())
|
||||
@ -1299,21 +1324,21 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
# Call with argument placed on different plaforms
|
||||
for platform in self.__class__.platforms:
|
||||
x_device = jax.device_put(x, jax.devices(platform)[0])
|
||||
res_exp = export.call(exp)(x_device)
|
||||
res_exp = exp.call(x_device)
|
||||
self.assertAllClose(
|
||||
res_exp,
|
||||
_testing_multi_platform_fun_expected(x, platform=platform))
|
||||
|
||||
def test_multi_platform_nested(self):
|
||||
x = np.arange(5, dtype=np.float32)
|
||||
exp = get_exported(lambda x: _testing_multi_platform_func(jnp.sin(x)),
|
||||
exp = get_exported(jax.jit(lambda x: _testing_multi_platform_func(jnp.sin(x))),
|
||||
lowering_platforms=("cpu", "tpu", "cuda","rocm"))(x)
|
||||
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda","rocm"))
|
||||
|
||||
# Now serialize the call to the exported using a different sequence of
|
||||
# lowering platforms, but included in the lowering platforms for the
|
||||
# nested exported.
|
||||
exp2 = get_exported(export.call(exp),
|
||||
exp2 = get_exported(jax.jit(exp.call),
|
||||
lowering_platforms=("cpu", "cuda","rocm"))(x)
|
||||
|
||||
# Ensure that we do not have multiple lowerings of the exported function
|
||||
@ -1325,39 +1350,39 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
for platform in self.__class__.platforms:
|
||||
if platform == "tpu": continue
|
||||
x_device = jax.device_put(x, jax.devices(platform)[0])
|
||||
res_exp = export.call(exp2)(x_device)
|
||||
res_exp = exp2.call(x_device)
|
||||
self.assertAllClose(
|
||||
res_exp,
|
||||
_testing_multi_platform_fun_expected(np.sin(x), platform=platform))
|
||||
|
||||
def test_multi_platform_nested_inside_single_platform_export(self):
|
||||
x = np.arange(5, dtype=np.float32)
|
||||
exp = get_exported(_testing_multi_platform_func,
|
||||
exp = get_exported(jax.jit(_testing_multi_platform_func),
|
||||
lowering_platforms=("cpu", "tpu", "cuda","rocm"))(x)
|
||||
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda", "rocm"))
|
||||
|
||||
# Now serialize the call for the current platform.
|
||||
exp2 = get_exported(export.call(exp))(x)
|
||||
exp2 = get_exported(jax.jit(exp.call))(x)
|
||||
module_str = str(exp2.mlir_module())
|
||||
self.assertIn("jax.uses_shape_polymorphism = true",
|
||||
module_str)
|
||||
res2 = export.call(exp2)(x)
|
||||
res2 = exp2.call(x)
|
||||
self.assertAllClose(res2, _testing_multi_platform_fun_expected(x))
|
||||
|
||||
def test_multi_platform_and_poly(self):
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
# The export is not applicable to GPU
|
||||
raise unittest.SkipTest("Not intended for running on GPU")
|
||||
exp = get_exported(lambda x: jnp.reshape(_testing_multi_platform_func(x), (-1,)),
|
||||
exp = get_exported(jax.jit(lambda x: jnp.reshape(_testing_multi_platform_func(x), (-1,))),
|
||||
lowering_platforms=("cpu", "tpu"))(
|
||||
jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), np.float32)
|
||||
)
|
||||
x = np.arange(12, dtype=np.float32).reshape((3, 4))
|
||||
res = export.call(exp)(x)
|
||||
res = exp.call(x)
|
||||
self.assertAllClose(res, _testing_multi_platform_fun_expected(x).reshape((-1,)))
|
||||
# Now serialize the call to the exported
|
||||
exp2 = get_exported(export.call(exp))(x)
|
||||
res2 = export.call(exp2)(x)
|
||||
exp2 = get_exported(jax.jit(exp.call))(x)
|
||||
res2 = exp2.call(x)
|
||||
self.assertAllClose(res2, _testing_multi_platform_fun_expected(x).reshape((-1,)))
|
||||
|
||||
def test_multi_platform_and_sharding(self):
|
||||
@ -1382,7 +1407,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
continue
|
||||
run_mesh = Mesh(run_devices, ("x",))
|
||||
a_device = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, None))
|
||||
res_exp = export.call(exp)(a_device)
|
||||
res_exp = exp.call(a_device)
|
||||
self.assertArraysAllClose(res_native, res_exp)
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
@ -1409,7 +1434,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2")
|
||||
)
|
||||
|
||||
exp = get_exported(f_jax)(x)
|
||||
exp = get_exported(jax.jit(f_jax))(x)
|
||||
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
|
||||
sorted(str(e) for e in exp.ordered_effects))
|
||||
self.assertEqual(["ForTestingUnorderedEffect1()"],
|
||||
@ -1449,7 +1474,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
x, effect_class_name="ForTestingOrderedEffect2") +
|
||||
testing_primitive_with_effect_p.bind(
|
||||
x, effect_class_name="ForTestingUnorderedEffect1") +
|
||||
export.call(exp)(x))
|
||||
exp.call(x))
|
||||
|
||||
lowered_outer = jax.jit(f_outer).lower(x)
|
||||
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
|
||||
@ -1476,7 +1501,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
x = np.arange(12, dtype=np.float32).reshape((3, 4))
|
||||
def f_jax(x): # x: f32[b1, b2]
|
||||
return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1")
|
||||
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(
|
||||
exp = get_exported(jax.jit(f_jax))(jax.ShapeDtypeStruct(
|
||||
export.symbolic_shape("b2, b1"), x.dtype))
|
||||
mlir_module_str = str(exp.mlir_module())
|
||||
wrapped_main_expected_re = (
|
||||
@ -1497,7 +1522,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
|
||||
self.assertRegex(mlir_module_str, main_expected_re)
|
||||
|
||||
res = export.call(exp)(x)
|
||||
res = exp.call(x)
|
||||
self.assertAllClose(10. + 2. * x, res)
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
@ -1518,7 +1543,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
return 10. + _testing_multi_platform_func(x,
|
||||
effect_class_name="ForTestingOrderedEffect1")
|
||||
exp = get_exported(
|
||||
f_jax,
|
||||
jax.jit(f_jax),
|
||||
lowering_platforms=("cpu", "tpu")
|
||||
)(jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), x.dtype))
|
||||
mlir_module_str = str(exp.mlir_module())
|
||||
@ -1541,7 +1566,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
# Results
|
||||
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
|
||||
self.assertRegex(mlir_module_str, main_expected_re)
|
||||
res = export.call(exp)(x)
|
||||
res = exp.call(x)
|
||||
self.assertAllClose(10. + _testing_multi_platform_fun_expected(x),
|
||||
res)
|
||||
|
||||
@ -1586,7 +1611,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
x,
|
||||
effect_class_name="ForTestingOrderedEffect" + name)
|
||||
with self.assertRaisesRegex(Exception, expect_error):
|
||||
_ = get_exported(f_jax)(jax.ShapeDtypeStruct((3, 4), x.dtype))
|
||||
_ = get_exported(jax.jit(f_jax))(jax.ShapeDtypeStruct((3, 4), x.dtype))
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
@ -1603,12 +1628,12 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
rhs = np.arange(num_groups * k * n, dtype=dtype).reshape((num_groups, k, n))
|
||||
res_native = f_jax(lhs, rhs, group_sizes)
|
||||
|
||||
exp_f = get_exported(f_jax)(
|
||||
exp_f = get_exported(jax.jit(f_jax))(
|
||||
jax.ShapeDtypeStruct(lhs.shape, dtype=lhs.dtype),
|
||||
jax.ShapeDtypeStruct(rhs.shape, dtype=rhs.dtype),
|
||||
jax.ShapeDtypeStruct(group_sizes.shape, dtype=group_sizes.dtype),
|
||||
)
|
||||
res_exported = export.call(exp_f)(lhs, rhs, group_sizes)
|
||||
res_exported = exp_f.call(lhs, rhs, group_sizes)
|
||||
self.assertAllClose(res_native, res_exported)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -35,7 +35,7 @@ import operator as op
|
||||
import re
|
||||
|
||||
import jax
|
||||
from jax.experimental import export
|
||||
from jax import export
|
||||
from jax.experimental import pjit
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
@ -1267,7 +1267,7 @@ class PolyHarness(Harness):
|
||||
tst.assertEqual(getattr(jax.config, fname), fvalue, (
|
||||
f"Flag {fname} current value {getattr(jax.config, fname)} != {fvalue}"))
|
||||
|
||||
f_jax = self.dyn_fun
|
||||
f_jax = jax.jit(self.dyn_fun)
|
||||
args = self.dyn_args_maker(tst.rng())
|
||||
args = jax.tree.map(jnp.array, args)
|
||||
args_specs = export.symbolic_args_specs(args, self.polymorphic_shapes,
|
||||
@ -1283,7 +1283,7 @@ class PolyHarness(Harness):
|
||||
return None
|
||||
# Run the JAX natively and then the exported function and compare
|
||||
res_jax_native = f_jax(*args)
|
||||
res_jax_exported = export.call(exp)(*args)
|
||||
res_jax_exported = exp.call(*args)
|
||||
custom_assert_lims = [
|
||||
l for l in self.limitations if l.custom_assert is not None]
|
||||
assert len(custom_assert_lims) <= 1, custom_assert_lims
|
||||
@ -1315,7 +1315,7 @@ def check_shape_poly(tst, f_jax: Callable, *,
|
||||
symbolic_constraints: Sequence[str] = (),
|
||||
expect_error=None) -> jax.Array | None:
|
||||
# Builds a PolyHarness and runs the test. See PolyHarness documentation.
|
||||
h = PolyHarness("", "", f_jax,
|
||||
h = PolyHarness("", "", jax.jit(f_jax),
|
||||
arg_descriptors=arg_descriptors,
|
||||
polymorphic_shapes=polymorphic_shapes,
|
||||
symbolic_constraints=symbolic_constraints,
|
||||
@ -1408,11 +1408,10 @@ class ShapePolyTest(jtu.JaxTestCase):
|
||||
def f_jax(x, *, y):
|
||||
return x + jnp.sin(y)
|
||||
|
||||
f_exported = export.call(
|
||||
export.export(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
|
||||
x.dtype),
|
||||
y=jax.ShapeDtypeStruct(y.shape, y.dtype)))
|
||||
self.assertAllClose(f_jax(x, y=y), f_exported(x, y=y))
|
||||
exp = export.export(jax.jit(f_jax))(
|
||||
jax.ShapeDtypeStruct(export.symbolic_shape("b"), x.dtype),
|
||||
y=jax.ShapeDtypeStruct(y.shape, y.dtype))
|
||||
self.assertAllClose(f_jax(x, y=y), exp.call(x, y=y))
|
||||
|
||||
def test_arg_avals_errors(self):
|
||||
"""Test error reporting for shape polymorphism."""
|
||||
@ -1617,8 +1616,8 @@ class ShapePolyTest(jtu.JaxTestCase):
|
||||
acc += jnp.sum(slice, axis=0)
|
||||
return acc
|
||||
|
||||
_ = export.export(f)(jax.ShapeDtypeStruct(export.symbolic_shape("a, b"),
|
||||
np.int32))
|
||||
_ = export.export(jax.jit(f))(
|
||||
jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), np.int32))
|
||||
|
||||
|
||||
def test_constraints_compile_time_check(self):
|
||||
@ -1630,29 +1629,30 @@ class ShapePolyTest(jtu.JaxTestCase):
|
||||
x_spec = jax.ShapeDtypeStruct(
|
||||
export.symbolic_shape("a",
|
||||
constraints=["a >= 2", "a <= 4"]), np.int32)
|
||||
exp = export.export(f)(x_spec)
|
||||
exp = export.export(jax.jit(f))(x_spec)
|
||||
|
||||
x_2 = np.arange(2, dtype=np.int32)
|
||||
res_2 = export.call(exp)(x_2)
|
||||
res_2 = exp.call(x_2)
|
||||
self.assertAllClose(x_2[0:2], res_2)
|
||||
|
||||
x_4 = np.arange(4, dtype=np.int32)
|
||||
res_4 = export.call(exp)(x_4)
|
||||
res_4 = exp.call(x_4)
|
||||
self.assertAllClose(x_4[1:3], res_4)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
re.escape("Expected 'a - 2' to be greater or equal to 0, but found -1")):
|
||||
export.call(exp)(np.arange(1, dtype=np.int32))
|
||||
exp.call(np.arange(1, dtype=np.int32))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
re.escape("Expected '- a + 4' to be greater or equal to 0, but found -1")):
|
||||
export.call(exp)(np.arange(5, dtype=np.int32))
|
||||
exp.call(np.arange(5, dtype=np.int32))
|
||||
|
||||
def test_caching_with_scopes(self):
|
||||
f_tracing_count = 0
|
||||
expected_a_bounds = (1, np.inf)
|
||||
@jax.jit
|
||||
def f(x): # x: i32[a]
|
||||
nonlocal f_tracing_count
|
||||
f_tracing_count += 1
|
||||
|
Loading…
x
Reference in New Issue
Block a user