diff --git a/benchmarks/shape_poly_benchmark.py b/benchmarks/shape_poly_benchmark.py index b1b6b625c..d26801d8d 100644 --- a/benchmarks/shape_poly_benchmark.py +++ b/benchmarks/shape_poly_benchmark.py @@ -18,7 +18,7 @@ import google_benchmark as benchmark import jax from jax import core from jax._src.numpy import lax_numpy -from jax.experimental import export +from jax import export jax.config.parse_flags_with_absl() diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index d0b45e20f..90cd81a1e 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -333,10 +333,10 @@ class Exported: `jax.device_put`. Example usage: - >>> from jax.experimental import export + >>> from jax import export >>> exp_mesh = sharding.Mesh(jax.devices(), ("a",)) >>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x), - ... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a"))) + ... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a"))) ... )(np.arange(jax.device_count())) >>> exp.in_shardings_hlo ({devices=[8]<=[8]},) @@ -347,7 +347,7 @@ class Exported: # Put the args and kwargs on the appropriate devices >>> run_arg = jax.device_put(np.arange(jax.device_count()), ... exp.in_shardings_jax(run_mesh)[0]) - >>> res = export.call(exp)(run_arg) + >>> res = exp.call(run_arg) >>> res.addressable_shards [Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]), Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]), @@ -372,19 +372,53 @@ class Exported: for s in self.out_shardings_hlo) def has_vjp(self) -> bool: + """Returns whether this Exported supports VJP.""" return self._get_vjp is not None def vjp(self) -> Exported: """Gets the exported VJP. Returns None if not available, which can happen if the Exported has been - loaded from an external format, without a VJP.""" + loaded from an external format without a VJP. + """ if self._get_vjp is None: raise ValueError("No VJP is available") return self._get_vjp(self) + def serialize(self, + vjp_order: int = 0) -> bytearray: + """Serializes an Exported. + + Args: + vjp_order: The maximum vjp order to include. E.g., the value 2 means that we + serialize the primal functions and two orders of the `vjp` function. This + should allow second-order reverse-mode differentiation of the deserialized + function, i.e., `jax.grad(jax.grad(f))`. + """ + # Lazy load the serialization module, since flatbuffers is an optional + # dependency. + from jax._src.export.serialization import serialize + return serialize(self, vjp_order=vjp_order) + + def call(self, *args, **kwargs): + return call_exported(self)(*args, **kwargs) + + +def deserialize(blob: bytearray) -> Exported: + """Deserializes an Exported. + + Args: + blob: a bytearray obtained from `Exported.serialize`. + """ + # Lazy load the serialization module, since flatbuffers is an optional + # dependency. + from jax._src.export.serialization import deserialize + return deserialize(blob) + def default_lowering_platform() -> str: + """Retrieves the default lowering platform for the exporting machine.
+ """ # Canonicalize to turn 'gpu' into 'cuda' or 'rocm' return xb.canonicalize_platform(jax.default_backend()) @@ -411,14 +445,20 @@ def args_specs( return shape_poly.symbolic_args_specs(args, polymorphic_shapes) -def export(fun_jax: Callable, - *, - lowering_platforms: Sequence[str] | None = None, - disabled_checks: Sequence[DisabledSafetyCheck] = (), - _device_assignment_for_internal_jax2tf_use_only = None, - ) -> Callable[..., Exported]: +# TODO(necula): remove this once we remove jax.experimental.export. +def export_back_compat( + fun_jax: Callable, + *, + lowering_platforms: Sequence[str] | None = None, + disabled_checks: Sequence[DisabledSafetyCheck] = (), + _device_assignment_for_internal_jax2tf_use_only = None, + ) -> Callable[..., Exported]: """Exports native serialization for a JAX function. + Note: this function exists only for internal usage by jax2tf and for + backwards compatibility with jax.experimental.export. Use + `jax.export` instead. + Args: fun_jax: the function to lower and serialize. lowering_platforms: @@ -498,6 +538,85 @@ def export(fun_jax: Callable, _device_assignment_for_internal_jax2tf_use_only=_device_assignment_for_internal_jax2tf_use_only) return do_export +def export( + fun_jit: stages.Wrapped, + *, + lowering_platforms: Sequence[str] | None = None, + disabled_checks: Sequence[DisabledSafetyCheck] = (), + ) -> Callable[..., Exported]: + """Exports a JAX function for persistent serialization. + + Args: + fun_jit: the function to export. Should be the result of `jit`. + lowering_platforms: + Optional sequence containing a subset of 'tpu', 'cpu', + 'cuda', 'rocm'. If more than one platform is specified, then + the lowered code takes an argument specifying the platform. + If None, then use the default JAX backend. + The calling convention for multiple platforms is explained in the + `jax_export.Exported` docstring. + disabled_checks: the safety checks to disable. See docstring + of `DisabledSafetyCheck`. + + Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct, + or values with `.shape` and `.dtype` attributes, and returns an + `Exported`. + + Usage: + >>> from jax import export + >>> exported: export.Exported = export.export(jnp.sin)( + ... np.arange(4, dtype=np.float32)) + + # You can inspect the Exported object + >>> exported.in_avals + (ShapedArray(float32[4]),) + >>> blob: bytearray = exported.serialize() + + # The serialized bytes are safe to use in a separate process + >>> rehydrated: export.Exported = export.deserialize(blob) + >>> rehydrated.fun_name + 'sin' + >>> rehydrated.call(np.array([.1, .2, .3, .4], dtype=np.float32)) + Array([0.09983342, 0.19866933, 0.29552022, 0.38941833], dtype=float32) + """ + if not isinstance(fun_jit, stages.Wrapped): + raise ValueError( + f"Function to be exported must be the result of `jit` but is: {fun_jit}") + if lowering_platforms is not None: + actual_lowering_platforms = tuple(lowering_platforms) + else: + actual_lowering_platforms = (default_lowering_platform(),) + + def do_export(*args_specs, **kwargs_specs) -> Exported: + # TODO: move to `lower` + symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None # type: ignore[invalid-annotation,unused-ignore] + for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]: + # Static args may have no `shape` attribute. 
+ if not hasattr(aval, "shape"): + continue + for d in aval.shape: + if shape_poly.is_symbolic_dim(d): + if symbolic_scope is None: + symbolic_scope = (d.scope, k_path) + continue + symbolic_scope[0]._check_same_scope( + d, when=f"when exporting {util.fun_name(fun_jit)}", + self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ", + other_descr=shape_poly.args_kwargs_path_to_str(k_path)) + + traced = fun_jit.trace( # type: ignore + *args_specs, **kwargs_specs, + _experimental_lowering_parameters=mlir.LoweringParameters( + platforms=actual_lowering_platforms, + for_export=True, + )) + jaxpr, fun_name = traced.jaxpr, traced.fun_name + lowered = traced.lower() + return _export_lowered( + lowered, jaxpr, fun_name, + disabled_checks=disabled_checks) + return do_export + def _export_lowered( lowered: stages.Lowered, jaxpr: core.ClosedJaxpr, fun_name: str, @@ -599,7 +718,7 @@ def _export_lowered( device_assignment=device_assignment, apply_jit=True, flat_primal_fun=True) - return export(fun_vjp_jax, + return export(fun_vjp_jax, # type: ignore[arg-type] lowering_platforms=exp_primal.lowering_platforms, disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals) @@ -816,7 +935,7 @@ def _wrap_main_func( def _check_lowering(lowering) -> None: if not isinstance(lowering, pxla.MeshComputation): - raise NotImplementedError(f"serialization is supported only for pjit. {lowering}") + raise NotImplementedError(f"serialization is supported only for jit. {lowering}") if lowering.compile_args["host_callbacks"] or lowering.compile_args["keepalive"]: raise NotImplementedError("serialization of host_callbacks is not yet implemented") diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index 3da4a32a8..dd181e4bd 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -48,7 +48,7 @@ SerT = TypeVar("SerT") _SERIALIZATION_VERSION = 2 def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray: - """Serialize an Exported. + """Serializes an Exported. Args: exp: the Exported to serialize. @@ -64,7 +64,7 @@ def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray: def deserialize(ser: bytearray) -> _export.Exported: - """Deserialize an Exported.""" + """Deserializes an Exported.""" exp = ser_flatbuf.Exported.GetRootAsExported(ser) return _deserialize_exported(exp) diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py index e7d5d1931..098245019 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_util.py +++ b/jax/_src/internal_test_util/export_back_compat_test_util.py @@ -86,7 +86,7 @@ from numpy import array, float32 import jax from jax import tree_util -from jax.experimental import export +from jax import export from jax.experimental import pjit @@ -345,4 +345,4 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict( _get_vjp=_get_vjp) # We use pjit in case there are shardings in the exported module. 
- return pjit.pjit(export.call(exported))(*data.inputs) + return pjit.pjit(exported.call)(*data.inputs) diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 028c612d4..3ed095138 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -32,7 +32,7 @@ from __future__ import annotations from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, NamedTuple, Protocol, Union +from typing import Any, NamedTuple, Protocol, Union, runtime_checkable import warnings import jax @@ -756,8 +756,9 @@ class Lowered(Stage): return None +@runtime_checkable class Wrapped(Protocol): - """A function ready to be specialized, lowered, and compiled. + """A function ready to be traced, lowered, and compiled. This protocol reflects the output of functions such as ``jax.jit``. Calling it results in JIT (just-in-time) lowering, diff --git a/jax/experimental/export/__init__.py b/jax/experimental/export/__init__.py index 1cdbf1673..77e7652db 100644 --- a/jax/experimental/export/__init__.py +++ b/jax/experimental/export/__init__.py @@ -17,12 +17,13 @@ from jax._src.export._export import ( minimum_supported_serialization_version, maximum_supported_serialization_version, Exported, - export, call_exported, # TODO: deprecate call, DisabledSafetyCheck, - default_lowering_platform, + default_lowering_platform, # TODO: deprecate ) +from jax._src.export._export import export_back_compat as export + from jax._src.export.shape_poly import ( is_symbolic_dim, symbolic_shape, @@ -33,4 +34,6 @@ from jax._src.export.serialization import ( serialize, deserialize, ) +# Import only to set the shape poly decision procedure from jax._src.export import shape_poly_decision +del shape_poly_decision diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 39b1bf484..39b6ffab9 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -36,7 +36,7 @@ from jax import random from jax import numpy as jnp from jax import tree_util from jax import sharding -from jax.experimental import export +from jax import export from jax.experimental.jax2tf import impl_no_xla from jax.interpreters import xla @@ -515,7 +515,7 @@ class NativeSerializationImpl(SerializationImpl): self._restore_context = _restore_context _exported_device_assignment = [None] - self.exported = export.export( + self.exported = _export.export_back_compat( self.fun_jax, lowering_platforms=self.native_serialization_platforms, disabled_checks=self.native_serialization_disabled_checks, diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index e7c997710..f22af3e64 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -25,13 +25,13 @@ from absl.testing import parameterized import jax from jax import dlpack from jax import dtypes +from jax import export from jax import lax from jax import numpy as jnp from jax._src import config from jax._src import test_util as jtu from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax.experimental import export from jax.experimental import jax2tf from jax.experimental.jax2tf.tests import tf_test_util import numpy as np @@ -778,7 +778,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase): lowering_platforms = ("tpu", "cpu", "cuda") - exp = export.export(f_jax, + exp = export.export(jax.jit(f_jax), lowering_platforms=lowering_platforms)(x) for jax_platform in jax_and_tf_platforms: with self.subTest(jax_platform): 
@@ -787,7 +787,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase): logging.info("Running harness natively on %s", jax_device) native_res = f_jax(x_device) logging.info("Running exported harness on %s", jax_device) - exported_res = export.call(exp)(x_device) + exported_res = exp.call(x_device) self.assertAllClose(native_res, exported_res) def test_multi_platform_call_tf_graph(self): diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 9d8abc8b3..cffe31abe 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -27,6 +27,7 @@ from absl.testing import absltest, parameterized import jax from jax import ad_checkpoint from jax import dtypes +from jax import export from jax import lax from jax import numpy as jnp from jax import sharding @@ -37,7 +38,6 @@ from jax._src import source_info_util from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax.experimental import jax2tf -from jax.experimental import export from jax.experimental.jax2tf.tests import tf_test_util from jax.experimental.shard_map import shard_map from jax.experimental import pjit @@ -1559,7 +1559,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): # Run the JAX native version, to check it works, and to fill caches. _ = func_to_convert(*args) exported = export.export( - func_to_convert, + (jax.jit(func_to_convert) if not hasattr(func_to_convert, "trace") else func_to_convert), lowering_platforms=("tpu",) )(*(core.ShapedArray(a.shape, a.dtype) for a in args)) diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 7f4ad1e45..66a3b7b62 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -32,8 +32,8 @@ import re import jax from jax.experimental import jax2tf -from jax.experimental import export from jax.experimental import pjit +from jax import export from jax import lax import jax.numpy as jnp from jax import random diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index 7ad8a90da..87c44c7d6 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -31,7 +31,7 @@ from jax._src import test_util as jtu from jax import tree_util from jax.experimental import jax2tf -from jax.experimental import export +from jax import export from jax._src import config from jax._src import xla_bridge import numpy as np diff --git a/jax/export.py b/jax/export.py new file mode 100644 index 000000000..90d9f7e86 --- /dev/null +++ b/jax/export.py @@ -0,0 +1,34 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+__all__ = ["DisabledSafetyCheck", "Exported", "export", "deserialize", + "maximum_supported_serialization_version", + "minimum_supported_serialization_version", + "default_lowering_platform", + "SymbolicScope", "is_symbolic_dim", + "symbolic_shape", "symbolic_args_specs"] + +from jax._src.export._export import DisabledSafetyCheck as DisabledSafetyCheck +from jax._src.export._export import Exported as Exported +from jax._src.export._export import export as export +from jax._src.export._export import deserialize as deserialize +from jax._src.export._export import maximum_supported_serialization_version as maximum_supported_serialization_version +from jax._src.export._export import minimum_supported_serialization_version as minimum_supported_serialization_version +from jax._src.export._export import default_lowering_platform as default_lowering_platform + +from jax._src.export import shape_poly_decision # Import only to set the decision procedure +del shape_poly_decision +from jax._src.export.shape_poly import SymbolicScope as SymbolicScope +from jax._src.export.shape_poly import is_symbolic_dim as is_symbolic_dim +from jax._src.export.shape_poly import symbolic_shape as symbolic_shape +from jax._src.export.shape_poly import symbolic_args_specs as symbolic_args_specs diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py index 89ecd16fe..5d737632a 100644 --- a/tests/export_harnesses_multi_platform_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -31,9 +31,9 @@ from absl.testing import absltest import numpy as np import jax +from jax import export from jax import lax from jax._src import test_util as jtu -from jax.experimental import export from jax._src.internal_test_util import test_harnesses @@ -152,7 +152,8 @@ class PrimitiveTest(jtu.JaxTestCase): ) logging.info("Exporting harness for %s", lowering_platforms) - exp = export.export(func_jax, lowering_platforms=lowering_platforms)(*args) + exp = export.export(jax.jit(func_jax), + lowering_platforms=lowering_platforms)(*args) for device in devices: if device.platform in skip_run_on_platforms: @@ -164,7 +165,7 @@ class PrimitiveTest(jtu.JaxTestCase): logging.info("Running harness natively on %s", device) native_res = func_jax(*device_args) logging.info("Running exported harness on %s", device) - exported_res = export.call(exp)(*device_args) + exported_res = exp.call(*device_args) if tol is not None: logging.info(f"Using non-standard tolerance {tol}") self.assertAllClose(native_res, exported_res, atol=tol, rtol=tol) diff --git a/tests/export_test.py b/tests/export_test.py index 556f9d363..f36bf1bd0 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -20,12 +20,13 @@ import logging import math import re import unittest +from typing import Callable from absl.testing import absltest import jax from jax import lax from jax import numpy as jnp -from jax.experimental import export +from jax import export from jax.experimental import pjit from jax.experimental.shard_map import shard_map from jax.sharding import NamedSharding @@ -137,12 +138,12 @@ def _testing_multi_platform_fun_expected(x, ] -def get_exported(fun, vjp_order=0, +def get_exported(fun: Callable, vjp_order=0, **export_kwargs): """Like export.export but with serialization + deserialization.""" def serde_exported(*fun_args, **fun_kwargs): exp = export.export(fun, **export_kwargs)(*fun_args, **fun_kwargs) - serialized = export.serialize(exp, vjp_order=vjp_order) + serialized = exp.serialize(vjp_order=vjp_order) 
return export.deserialize(serialized) return serde_exported @@ -164,11 +165,13 @@ class JaxExportTest(jtu.JaxTestCase): super().setUpClass() def test_basic_export_only(self): + @jax.jit def my_fun(x): return jnp.sin(x) exp = get_exported(my_fun)(jax.ShapeDtypeStruct((4,), dtype=np.float32)) self.assertEqual("my_fun", exp.fun_name) - self.assertEqual((export.default_lowering_platform(),), + expected_lowering_platform = xb.canonicalize_platform(jax.default_backend()) + self.assertEqual((expected_lowering_platform,), exp.lowering_platforms) self.assertEqual(jax.tree.flatten(((1,), {}))[1], exp.in_tree) self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.in_avals) @@ -180,7 +183,7 @@ def f(a_b_pair, *, a, b): return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b)) - exp = get_exported(f, lowering_platforms=("cpu",))((a, b), a=a, b=b) + exp = get_exported(jax.jit(f), lowering_platforms=("cpu",))((a, b), a=a, b=b) a_aval = core.ShapedArray(a.shape, a.dtype) b_aval = core.ShapedArray(b.shape, b.dtype) self.assertEqual(exp.lowering_platforms, ("cpu",)) @@ -196,8 +199,7 @@ x = np.arange(4, dtype=np.float32) exp_f = get_exported(f)(x) - f1 = export.call(exp_f) - self.assertAllClose(f(x), f1(x)) + self.assertAllClose(f(x), exp_f.call(x)) def test_jit_static_arg(self): @@ -210,8 +212,7 @@ x = np.arange(4, dtype=np.float32) exp_f = get_exported(f)(x, c=0.1) - f1 = export.call(exp_f) - self.assertAllClose(f(x, c=0.1), f1(x)) + self.assertAllClose(f(x, c=0.1), exp_f.call(x)) with self.subTest("static_argnums"): @@ -222,16 +223,32 @@ x = np.arange(4, dtype=np.float32) exp_g = get_exported(g)(x, 0.1) - g1 = export.call(exp_g) - self.assertAllClose(g(x, 0.1), g1(x)) + self.assertAllClose(g(x, 0.1), exp_g.call(x)) + + def test_export_error_no_jit(self): + # A bare lambda, not wrapped with `jit`, can no longer be exported + with self.assertRaisesRegex(ValueError, + "Function to be exported must be the result of `jit`"): + _ = export.export(lambda x: jnp.sin(x)) + + def test_export_experimental_back_compat(self): + from jax.experimental import export + # Can export a lambda, without jit + exp = export.export(lambda x: jnp.sin(x))(.1) + self.assertAllClose(exp.call(1.), np.sin(1.)) + + blob = export.serialize(exp, vjp_order=1) + rehydrated = export.deserialize(blob) + + self.assertAllClose(export.call(exp)(1.), np.sin(1.)) + self.assertAllClose(export.call_exported(exp)(1.), np.sin(1.)) def test_call_exported_lambda(self): # When we export a lambda, the exported.fun_name is not a valid MLIR function name - f = lambda x: jnp.sin(x) + f = jax.jit(lambda x: jnp.sin(x)) x = np.arange(4, dtype=np.float32) exp_f = get_exported(f)(x) - f1 = export.call(exp_f) - self.assertAllClose(f(x), f1(x)) + self.assertAllClose(f(x), exp_f.call(x)) def test_call_name_conflict(self): @jax.jit @@ -246,7 +263,7 @@ class JaxExportTest(jtu.JaxTestCase): @jax.jit def outer(x): # There should be no conflict on _where - x = export.call(exp_inner)(x) + x = exp_inner.call(x) return inner(x) export.export(outer)(x) @@ -257,19 +274,18 @@ @jax.jit def f1(x): - exp_f = get_exported(f)(x) - return export.call(exp_f)(x) + export.call(exp_f)(x) + exp_f = get_exported(jax.jit(f))(x) + return exp_f.call(x) + exp_f.call(x) self.assertAllClose(2.
* f(x), f1(x)) def test_unused_args(self): - f = lambda x, y: jnp.sin(x) + f = jax.jit(lambda x, y: jnp.sin(x)) x = np.arange(4, dtype=np.float32) y = np.arange(6, dtype=np.float32) exp_f = get_exported(f)(x, y) - f1 = export.call(exp_f) - self.assertAllClose(f(x, y), f1(x, y)) + self.assertAllClose(f(x, y), exp_f.call(x, y)) def test_pytree(self): a = np.arange(4, dtype=np.float32) @@ -277,43 +293,49 @@ class JaxExportTest(jtu.JaxTestCase): def f(a_b_pair, a, b): return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b)) - exp_f = get_exported(f)((a, b), a=a, b=b) - f1 = export.call(exp_f) + exp_f = get_exported(jax.jit(f))((a, b), a=a, b=b) self.assertAllClose(f((a, b), a=a, b=b), - f1((a, b), a=a, b=b)) + exp_f.call((a, b), a=a, b=b)) def test_error_wrong_intree(self): def f(a_b_pair, *, c): return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c a = b = c = np.arange(4, dtype=np.float32) - exp_f = get_exported(f)((a, b), c=c) + exp_f = get_exported(jax.jit(f))((a, b), c=c) with self.assertRaisesRegex( ValueError, "The invocation args and kwargs must have the same pytree structure"): - export.call(exp_f)(a, b, c=(a, b)) + exp_f.call(a, b, c=(a, b)) def test_error_wrong_avals(self): def f(a, *, b): # a: f32[4] and b: f32[4] return jnp.sin(a) + jnp.cos(b) f32_4 = np.arange(4, dtype=np.float32) - exp_f = get_exported(f)(f32_4, b=f32_4) + exp_f = get_exported(jax.jit(f))(f32_4, b=f32_4) with self.assertRaisesRegex(ValueError, r"Shape mismatch for args\[0\].shape\[0\]"): - export.call(exp_f)(np.arange(6, dtype=np.float32), b=f32_4) + exp_f.call(np.arange(6, dtype=np.float32), b=f32_4) with self.assertRaisesRegex(ValueError, r"Shape mismatch for kwargs\['b'\].shape\[0\]"): - export.call(exp_f)(f32_4, b=np.arange(6, dtype=np.float32)) + exp_f.call(f32_4, b=np.arange(6, dtype=np.float32)) with self.assertRaisesRegex(ValueError, r"Rank mismatch for args\[0\]"): - export.call(exp_f)(f32_4.reshape((1, 4)), b=f32_4) + exp_f.call(f32_4.reshape((1, 4)), b=f32_4) with self.assertRaisesRegex(ValueError, r"Dtype mismatch for args\[0\]"): - export.call(exp_f)(f32_4.astype(np.float16), b=f32_4) + exp_f.call(f32_4.astype(np.float16), b=f32_4) + + def test_default_lowering_platform(self): + test_platform = jtu.device_under_test() + if test_platform == "gpu": test_platform = "cuda" + self.assertEqual(export.default_lowering_platform(), test_platform) + exp = export.export(jnp.sin)(1.) 
+ self.assertEqual(exp.lowering_platforms, (export.default_lowering_platform(),)) @jtu.parameterized_filterable( testcase_name=lambda kw: kw["platform"], @@ -328,13 +350,13 @@ class JaxExportTest(jtu.JaxTestCase): with self.assertRaisesRegex( ValueError, "The exported function .* was lowered for platform"): - export.call(exp_f)(a) + exp_f.call(a) # Now try with the platform check disabled exp_f_no_platform_check = get_exported( jnp.sin, lowering_platforms=(platform,), disabled_checks=[export.DisabledSafetyCheck.platform()])(a) - res = export.call(exp_f_no_platform_check)(a) + res = exp_f_no_platform_check.call(a) self.assertAllClose(res, jnp.sin(a)) @jtu.parameterized_filterable( @@ -357,12 +379,12 @@ class JaxExportTest(jtu.JaxTestCase): with self.assertRaisesRegex(ValueError, "Cannot serialize code with custom calls whose targets .*"): get_exported( - lambda a: a + test_primitive.bind(a) + jax.jit(lambda a: a + test_primitive.bind(a)) )(a) # Now try again with the safety check disabled exp = get_exported( - lambda a: a + test_primitive.bind(a), + jax.jit(lambda a: a + test_primitive.bind(a)), disabled_checks=[export.DisabledSafetyCheck.custom_call("disallowed_call_target")] )(a) self.assertIn("disallowed_call_target", exp.mlir_module()) @@ -387,22 +409,22 @@ class JaxExportTest(jtu.JaxTestCase): with self.assertRaisesRegex(ValueError, "Lowering for export not supported"): - export.export(f)(a) + export.export(jax.jit(f))(a) def test_grad(self): f = lambda x: jnp.sum(jnp.sin(x)) x = np.arange(4, dtype=np.float32) - exp_f = get_exported(f, vjp_order=1)(x) + exp_f = get_exported(jax.jit(f), vjp_order=1)(x) - f1 = export.call(exp_f) + f1 = exp_f.call self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x)) def test_higher_order_grad(self): f = lambda x: x ** 3 x = np.float32(4.) 
- exp_f = get_exported(f, vjp_order=3)(x) + exp_f = get_exported(jax.jit(f), vjp_order=3)(x) - f1 = export.call(exp_f) + f1 = exp_f.call self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x)) self.assertAllClose(jax.grad(jax.grad(f))(x), @@ -428,8 +450,8 @@ class JaxExportTest(jtu.JaxTestCase): self.assertAllClose(res, (xi_ct, xf_ct)) (f_outi_ct2, f_outf_ct2), = f_vjp2((xi_ct, xf_ct)) - exp = get_exported(f, vjp_order=2)(xi, xf) - fr = export.call(exp) + exp = get_exported(jax.jit(f), vjp_order=2)(xi, xf) + fr = exp.call res = fr(xi, xf) self.assertAllClose(res, (f_outi, f_outf)) @@ -456,14 +478,14 @@ class JaxExportTest(jtu.JaxTestCase): a = np.arange(4, dtype=np.float32) b = np.arange(6, dtype=np.float32) - exp_f = get_exported(f, vjp_order=1)((a, b), a=a, b=b) + exp_f = get_exported(jax.jit(f), vjp_order=1)((a, b), a=a, b=b) out_ct = f((a, b), a=a, b=b) # The output has the right structure as the cotangent def f1_jax(a, b): # For VJP, make a function without kwargs res = f((a, b), a=a, b=b) return res def f1_exp(a, b): # For VJP, make a function without kwargs - res = export.call(exp_f)((a, b), a=a, b=b) + res = exp_f.call((a, b), a=a, b=b) return res jax_vjp = jax.vjp(f1_jax, a, b)[1](out_ct) exp_vjp = jax.vjp(f1_exp, a, b)[1](out_ct) @@ -473,15 +495,15 @@ class JaxExportTest(jtu.JaxTestCase): def f1(x): return jnp.sin(x) a = np.arange(4, dtype=np.float32) - exp_f1 = get_exported(f1)(a) + exp_f1 = get_exported(jax.jit(f1))(a) def f2(x): - res1 = export.call(exp_f1)(x) - res2 = export.call(exp_f1)(res1) + res1 = exp_f1.call(x) + res2 = exp_f1.call(res1) return jnp.cos(res2) - exp_f2 = get_exported(f2)(a) + exp_f2 = get_exported(jax.jit(f2))(a) self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))), - export.call(exp_f2)(a)) + exp_f2.call(a)) def test_poly_export_only(self): a = np.arange(12, dtype=np.float32).reshape((3, 4)) @@ -489,7 +511,7 @@ class JaxExportTest(jtu.JaxTestCase): return jnp.concatenate([a, b], axis=0) scope = export.SymbolicScope() - exp = get_exported(f)( + exp = get_exported(jax.jit(f))( jax.ShapeDtypeStruct(export.symbolic_shape("(2*w, h)", scope=scope), a.dtype), jax.ShapeDtypeStruct(export.symbolic_shape("(w, h)", scope=scope), a.dtype)) self.assertEqual("(2*w, h)", str(exp.in_avals[0].shape)) @@ -528,7 +550,7 @@ class JaxExportTest(jtu.JaxTestCase): return jnp.concatenate([a0, a1, ak], axis=0) a_poly_spec = jax.ShapeDtypeStruct(export.symbolic_shape("(w, h)"), a.dtype) - exp = get_exported(f)(a_poly_spec, a_poly_spec, ak=a_poly_spec) + exp = get_exported(jax.jit(f))(a_poly_spec, a_poly_spec, ak=a_poly_spec) self.assertEqual("(w, h)", str(exp.in_avals[0].shape)) self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape)) @@ -545,7 +567,7 @@ class JaxExportTest(jtu.JaxTestCase): "Invalid mixing of symbolic scopes when exporting f.*" r"Expected current \(from args\[0\]\) scope .*" r"and found for 'w' \(args\[1\]\) scope .*", re.DOTALL)): - get_exported(f)(x_poly_spec, y_poly_spec) + get_exported(jax.jit(f))(x_poly_spec, y_poly_spec) def test_poly_export_callable_with_no_name(self): # This was reported by a user @@ -566,7 +588,7 @@ class JaxExportTest(jtu.JaxTestCase): a, = export.symbolic_shape("a,") # No error - _ = get_exported(MyCallable())( + _ = get_exported(jax.jit(MyCallable()))( jax.ShapeDtypeStruct((a, a), dtype=np.float32) ) @@ -590,7 +612,7 @@ class JaxExportTest(jtu.JaxTestCase): exp = get_exported(jnp.sin)( jax.ShapeDtypeStruct(export.symbolic_shape("w, h"), np.float32)) x = np.arange(30, dtype=np.float32).reshape((5, 6)) - res = export.call(exp)(x) + res = 
exp.call(x) self.assertAllClose(res, np.sin(x)) # A function is exported with f32[poly_spec] and is called with different arg @@ -631,7 +653,7 @@ class JaxExportTest(jtu.JaxTestCase): return jnp.reshape(x, (-1, x.shape[1])) disabled_checks = () - exp_f = get_exported(f, disabled_checks=disabled_checks)( + exp_f = get_exported(jax.jit(f), disabled_checks=disabled_checks)( jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), np.float32)) self.assertEqual(exp_f.uses_shape_polymorphism, poly_spec != "3,4,12") arg = np.arange(np.prod(arg_shape), @@ -642,7 +664,7 @@ class JaxExportTest(jtu.JaxTestCase): stack.push(self.assertRaisesRegex(Exception, expect_error)) assert core.is_constant_shape(arg.shape) - res = export.call(exp_f)(arg) + res = exp_f.call(arg) if not expect_error: self.assertAllClose(res, f(arg)) @@ -733,7 +755,7 @@ class JaxExportTest(jtu.JaxTestCase): arg = np.arange(np.prod(arg_shape), dtype=arg_dtype).reshape(arg_shape) # x : f32[3,4,12] - inner_exp = get_exported(inner)( + inner_exp = get_exported(jax.jit(inner))( jax.ShapeDtypeStruct(export.symbolic_shape(inner_poly_spec), np.float32)) self.assertEqual(inner_exp.uses_shape_polymorphism, @@ -741,14 +763,14 @@ class JaxExportTest(jtu.JaxTestCase): def outer(x): # x: outer_poly_spec # Use an addition to test that the shapes are refined properly for the # result of the call_exported. - return export.call(inner_exp)(x) + inner(x) + return inner_exp.call(x) + inner(x) with contextlib.ExitStack() as stack: if expect_error_outer_exp is not None: stack.push(self.assertRaisesRegex(ValueError, expect_error_outer_exp)) # Call it after exporting again, with polymorphic shapes - outer_exp = get_exported(outer)( + outer_exp = get_exported(jax.jit(outer))( jax.ShapeDtypeStruct(export.symbolic_shape(outer_poly_spec), arg.dtype)) if expect_error_outer_exp is not None: @@ -761,7 +783,7 @@ class JaxExportTest(jtu.JaxTestCase): if expect_error_run is not None: stack.push(self.assertRaisesRegex(Exception, expect_error_run)) - res = export.call(outer_exp)(arg) + res = outer_exp.call(arg) if expect_error_run is not None: return @@ -823,20 +845,21 @@ class JaxExportTest(jtu.JaxTestCase): with contextlib.ExitStack() as stack: if expect_error is not None: stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error))) - exp = get_exported(f_jax)( + exp = get_exported(jax.jit(f_jax))( jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), x.dtype)) - export.call(exp)(x) + exp.call(x) def test_poly_booleans(self): # For booleans we use a special case ConvertOp to cast to and from # dynamic shapes arguments. 
+ @jax.jit def f_jax(x): # x: bool[b] return jnp.logical_not(x) x = np.array([True, False, True, False], dtype=np.bool_) exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), x.dtype)) - res = export.call(exp)(x) + res = exp.call(x) self.assertAllClose(f_jax(x), res) @jtu.parameterized_filterable( @@ -851,13 +874,14 @@ class JaxExportTest(jtu.JaxTestCase): "int4", "uint4"}: self.skipTest(f"TODO: serialization not supported for {str(dtype)}") + @jax.jit def f_jax(x): return x + x x = np.arange(6, dtype=dtype) exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), x.dtype)) - res = export.call(exp)(x) + res = exp.call(x) self.assertAllClose(f_jax(x), res) def test_poly_expressions(self): @@ -867,6 +891,7 @@ class JaxExportTest(jtu.JaxTestCase): return (b + b, b - b, b * b, (b + 13) // b, (b + 13) % b, core.max_dim(b - 5, 0)) + @jax.jit def f(x): # x: f32[b] b = x.shape[0] return jnp.ones(output_shape(b), dtype=x.dtype) @@ -874,12 +899,12 @@ class JaxExportTest(jtu.JaxTestCase): exp = get_exported(f)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), x.dtype)) # Call with static shapes - res = export.call(exp)(x) + res = exp.call(x) self.assertAllClose(res, f(x)) # Now re-export with shape polymorphism x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("a"), x.dtype) - exp2 = get_exported(export.call(exp))(x_spec) + exp2 = get_exported(jax.jit(exp.call))(x_spec) a = exp2.in_avals[0].shape[0] self.assertEqual(exp2.out_avals[0].shape, output_shape(a)) @@ -890,9 +915,9 @@ class JaxExportTest(jtu.JaxTestCase): return x + jnp.arange(x.shape[0], dtype=x.dtype).reshape((x.shape[0], 1)) a, = export.symbolic_shape("a") - exp = export.export(f)( + exp = export.export(jax.jit(f))( jax.ShapeDtypeStruct((a, 4), np.float32)) - f_exp = export.call(exp) + f_exp = exp.call x_jit = np.arange(12, dtype=np.float32).reshape((3, 4)) res_jit = jax.jit(f_exp)(x_jit) self.assertAllClose(res_jit, f(x_jit)) @@ -931,24 +956,24 @@ class JaxExportTest(jtu.JaxTestCase): # We apply the out_shardings for f_jax r".*custom_call @Sharding\(%1\).*mhlo.sharding = \"{devices=\[1,2\]<=\[2\]}\"}.*", re.DOTALL) - hlo = jax.jit(export.call(exp)).lower(a_device).as_text() + hlo = jax.jit(exp.call).lower(a_device).as_text() self.assertRegex(hlo, expected_re) - res_exported = export.call(exp)(a_device) + res_exported = exp.call(a_device) self.assertAllClose(res_native, res_exported) # Test error reporting with self.assertRaisesRegex( NotImplementedError, "Exported module .* was lowered for 2 devices and is called in a context with 1 device"): - _ = export.call(exp)(a) + _ = exp.call(a) with self.assertRaisesRegex( NotImplementedError, "Exported module .* was lowered for 2 devices and is called in a context with 1 device"): mesh1 = Mesh(jax.devices()[0:1], axis_names=("x",)) _ = jax.jit( - export.call(exp), + exp.call, in_shardings=(jax.sharding.NamedSharding(mesh1, P("x", None)),) )(a) @@ -973,7 +998,7 @@ class JaxExportTest(jtu.JaxTestCase): run_input_shardings = exp.in_shardings_jax(run_mesh) a_run = jax.device_put(a, run_input_shardings[0]) b_run = jax.device_put(a, run_input_shardings[1]) - res = export.call(exp)(a_run, b_run) + res = exp.call(a_run, b_run) self.assertEqual(res.addressable_shards[0].device, run_devices[0]) self.assertEqual(res.addressable_shards[1].device, run_devices[1]) @@ -996,7 +1021,7 @@ class JaxExportTest(jtu.JaxTestCase): run_mesh = Mesh(run_devices, "i") b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i"))) - res_exported = export.call(exp)(b) + 
res_exported = exp.call(b) self.assertAllClose(res_native, res_exported) def test_call_with_different_no_of_devices_error_has_in_shardings(self): @@ -1024,7 +1049,7 @@ class JaxExportTest(jtu.JaxTestCase): "Exported module .* was lowered for 1 devices and is called in a " f"context with {jax.local_device_count()} devices.* module contains " "non-replicated sharding annotations"): - export.call(exp)(b) + exp.call(b) def test_call_with_different_no_of_devices_pmap(self): if len(jax.devices()) < 2: @@ -1042,7 +1067,7 @@ class JaxExportTest(jtu.JaxTestCase): b = jnp.arange(jax.device_count() * 100, dtype=jnp.float32).reshape( (-1, 1, 100) ) - res_exported = jax.pmap(export.call(exp))(b) + res_exported = jax.pmap(exp.call)(b) self.assertAllClose(res_native, res_exported[0]) def test_call_with_different_no_of_devices_error_has_sharding_constraint(self): @@ -1070,7 +1095,7 @@ class JaxExportTest(jtu.JaxTestCase): "Exported module .* was lowered for 1 devices and is called in a " f"context with {jax.local_device_count()} devices.* module contains " "non-replicated sharding annotations"): - export.call(exp)(b) + exp.call(b) @jtu.parameterized_filterable( kwargs=[ @@ -1108,7 +1133,7 @@ class JaxExportTest(jtu.JaxTestCase): self.assertLen(res_jax.addressable_shards, len(devices)) # Test reloaded execution. - f_r = export.call(exp) + f_r = exp.call with self.assertRaisesRegex( Exception, "Exported module .* was lowered for 2 devices and is " @@ -1241,14 +1266,14 @@ class JaxExportTest(jtu.JaxTestCase): self.assertEqual(exp_vjp2.nr_devices, 2) call_mesh = Mesh(jax.devices()[:2], "e") - g1 = pjit.pjit(export.call(exp_vjp), + g1 = pjit.pjit(exp_vjp.call, in_shardings=(NamedSharding(call_mesh, None), NamedSharding(call_mesh, None)))(x, x.T) _, f_jax_vjp = jax.vjp(f_jax, x) xbar = f_jax_vjp(x.T) self.assertAllClose(xbar, g1) - g2 = pjit.pjit(export.call(exp_vjp2), + g2 = pjit.pjit(exp_vjp2.call, in_shardings=(NamedSharding(call_mesh, None), NamedSharding(call_mesh, None), NamedSharding(call_mesh, None)))(x, x.T, x) @@ -1274,17 +1299,17 @@ class JaxExportTest(jtu.JaxTestCase): exp = export.export(pjit.pjit(f, in_shardings=shardings))(input) exp_rev = export.export(pjit.pjit(f, in_shardings=shardings_rev))(input_no_shards) - _ = export.serialize(exp, vjp_order=1) - _ = export.serialize(exp_rev, vjp_order=1) + _ = exp.serialize(vjp_order=1) + _ = exp_rev.serialize(vjp_order=1) - g = jax.grad(export.call(exp_rev))(input_rev) - g_rev = jax.grad(export.call(exp))(input) + g = jax.grad(exp_rev.call)(input_rev) + g_rev = jax.grad(exp.call)(input) self.assertAllClose(g, g_rev) def test_multi_platform(self): x = np.arange(8, dtype=np.float32) - exp = get_exported(_testing_multi_platform_func, - lowering_platforms=("tpu", "cpu", "cuda","rocm"))(x) + exp = get_exported(jax.jit(_testing_multi_platform_func), + lowering_platforms=("tpu", "cpu", "cuda","rocm"))(x) self.assertEqual(exp.lowering_platforms, ("tpu", "cpu", "cuda", "rocm")) module_str = str(exp.mlir_module()) expected_main_re = ( @@ -1299,22 +1324,22 @@ class JaxExportTest(jtu.JaxTestCase): # Call with argument placed on different plaforms for platform in self.__class__.platforms: x_device = jax.device_put(x, jax.devices(platform)[0]) - res_exp = export.call(exp)(x_device) + res_exp = exp.call(x_device) self.assertAllClose( res_exp, _testing_multi_platform_fun_expected(x, platform=platform)) def test_multi_platform_nested(self): x = np.arange(5, dtype=np.float32) - exp = get_exported(lambda x: _testing_multi_platform_func(jnp.sin(x)), - 
lowering_platforms=("cpu", "tpu", "cuda","rocm"))(x) + exp = get_exported(jax.jit(lambda x: _testing_multi_platform_func(jnp.sin(x))), + lowering_platforms=("cpu", "tpu", "cuda","rocm"))(x) self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda","rocm")) # Now serialize the call to the exported using a different sequence of # lowering platforms, but included in the lowering platforms for the # nested exported. - exp2 = get_exported(export.call(exp), - lowering_platforms=("cpu", "cuda","rocm"))(x) + exp2 = get_exported(jax.jit(exp.call), + lowering_platforms=("cpu", "cuda","rocm"))(x) # Ensure that we do not have multiple lowerings of the exported function exp2_module_str = str(exp2.mlir_module()) @@ -1325,39 +1350,39 @@ class JaxExportTest(jtu.JaxTestCase): for platform in self.__class__.platforms: if platform == "tpu": continue x_device = jax.device_put(x, jax.devices(platform)[0]) - res_exp = export.call(exp2)(x_device) + res_exp = exp2.call(x_device) self.assertAllClose( res_exp, _testing_multi_platform_fun_expected(np.sin(x), platform=platform)) def test_multi_platform_nested_inside_single_platform_export(self): x = np.arange(5, dtype=np.float32) - exp = get_exported(_testing_multi_platform_func, - lowering_platforms=("cpu", "tpu", "cuda","rocm"))(x) + exp = get_exported(jax.jit(_testing_multi_platform_func), + lowering_platforms=("cpu", "tpu", "cuda","rocm"))(x) self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda", "rocm")) # Now serialize the call for the current platform. - exp2 = get_exported(export.call(exp))(x) + exp2 = get_exported(jax.jit(exp.call))(x) module_str = str(exp2.mlir_module()) self.assertIn("jax.uses_shape_polymorphism = true", module_str) - res2 = export.call(exp2)(x) + res2 = exp2.call(x) self.assertAllClose(res2, _testing_multi_platform_fun_expected(x)) def test_multi_platform_and_poly(self): if jtu.test_device_matches(["gpu"]): # The export is not applicable to GPU raise unittest.SkipTest("Not intended for running on GPU") - exp = get_exported(lambda x: jnp.reshape(_testing_multi_platform_func(x), (-1,)), - lowering_platforms=("cpu", "tpu"))( + exp = get_exported(jax.jit(lambda x: jnp.reshape(_testing_multi_platform_func(x), (-1,))), + lowering_platforms=("cpu", "tpu"))( jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), np.float32) ) x = np.arange(12, dtype=np.float32).reshape((3, 4)) - res = export.call(exp)(x) + res = exp.call(x) self.assertAllClose(res, _testing_multi_platform_fun_expected(x).reshape((-1,))) # Now serialize the call to the exported - exp2 = get_exported(export.call(exp))(x) - res2 = export.call(exp2)(x) + exp2 = get_exported(jax.jit(exp.call))(x) + res2 = exp2.call(x) self.assertAllClose(res2, _testing_multi_platform_fun_expected(x).reshape((-1,))) def test_multi_platform_and_sharding(self): @@ -1382,7 +1407,7 @@ class JaxExportTest(jtu.JaxTestCase): continue run_mesh = Mesh(run_devices, ("x",)) a_device = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, None)) - res_exp = export.call(exp)(a_device) + res_exp = exp.call(a_device) self.assertArraysAllClose(res_native, res_exp) @jtu.parameterized_filterable( @@ -1409,7 +1434,7 @@ class JaxExportTest(jtu.JaxTestCase): testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") ) - exp = get_exported(f_jax)(x) + exp = get_exported(jax.jit(f_jax))(x) self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"], sorted(str(e) for e in exp.ordered_effects)) self.assertEqual(["ForTestingUnorderedEffect1()"], @@ -1449,7 +1474,7 @@ 
class JaxExportTest(jtu.JaxTestCase): x, effect_class_name="ForTestingOrderedEffect2") + testing_primitive_with_effect_p.bind( x, effect_class_name="ForTestingUnorderedEffect1") + - export.call(exp)(x)) + exp.call(x)) lowered_outer = jax.jit(f_outer).lower(x) self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"], @@ -1476,7 +1501,7 @@ class JaxExportTest(jtu.JaxTestCase): x = np.arange(12, dtype=np.float32).reshape((3, 4)) def f_jax(x): # x: f32[b1, b2] return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") - exp = get_exported(f_jax)(jax.ShapeDtypeStruct( + exp = get_exported(jax.jit(f_jax))(jax.ShapeDtypeStruct( export.symbolic_shape("b2, b1"), x.dtype)) mlir_module_str = str(exp.mlir_module()) wrapped_main_expected_re = ( @@ -1497,7 +1522,7 @@ class JaxExportTest(jtu.JaxTestCase): r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)") self.assertRegex(mlir_module_str, main_expected_re) - res = export.call(exp)(x) + res = exp.call(x) self.assertAllClose(10. + 2. * x, res) @jtu.parameterized_filterable( @@ -1518,7 +1543,7 @@ class JaxExportTest(jtu.JaxTestCase): return 10. + _testing_multi_platform_func(x, effect_class_name="ForTestingOrderedEffect1") exp = get_exported( - f_jax, + jax.jit(f_jax), lowering_platforms=("cpu", "tpu") )(jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), x.dtype)) mlir_module_str = str(exp.mlir_module()) @@ -1541,7 +1566,7 @@ class JaxExportTest(jtu.JaxTestCase): # Results r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)") self.assertRegex(mlir_module_str, main_expected_re) - res = export.call(exp)(x) + res = exp.call(x) self.assertAllClose(10. + _testing_multi_platform_fun_expected(x), res) @@ -1586,7 +1611,7 @@ class JaxExportTest(jtu.JaxTestCase): x, effect_class_name="ForTestingOrderedEffect" + name) with self.assertRaisesRegex(Exception, expect_error): - _ = get_exported(f_jax)(jax.ShapeDtypeStruct((3, 4), x.dtype)) + _ = get_exported(jax.jit(f_jax))(jax.ShapeDtypeStruct((3, 4), x.dtype)) @jtu.parameterized_filterable( kwargs=[ @@ -1603,12 +1628,12 @@ class JaxExportTest(jtu.JaxTestCase): rhs = np.arange(num_groups * k * n, dtype=dtype).reshape((num_groups, k, n)) res_native = f_jax(lhs, rhs, group_sizes) - exp_f = get_exported(f_jax)( + exp_f = get_exported(jax.jit(f_jax))( jax.ShapeDtypeStruct(lhs.shape, dtype=lhs.dtype), jax.ShapeDtypeStruct(rhs.shape, dtype=rhs.dtype), jax.ShapeDtypeStruct(group_sizes.shape, dtype=group_sizes.dtype), ) - res_exported = export.call(exp_f)(lhs, rhs, group_sizes) + res_exported = exp_f.call(lhs, rhs, group_sizes) self.assertAllClose(res_native, res_exported) if __name__ == "__main__": diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 8e4917774..74064a783 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -35,7 +35,7 @@ import operator as op import re import jax -from jax.experimental import export +from jax import export from jax.experimental import pjit from jax import lax import jax.numpy as jnp @@ -1267,7 +1267,7 @@ class PolyHarness(Harness): tst.assertEqual(getattr(jax.config, fname), fvalue, ( f"Flag {fname} current value {getattr(jax.config, fname)} != {fvalue}")) - f_jax = self.dyn_fun + f_jax = jax.jit(self.dyn_fun) args = self.dyn_args_maker(tst.rng()) args = jax.tree.map(jnp.array, args) args_specs = export.symbolic_args_specs(args, self.polymorphic_shapes, @@ -1283,7 +1283,7 @@ class PolyHarness(Harness): return None # Run the JAX natively and then the exported function and 
compare res_jax_native = f_jax(*args) - res_jax_exported = export.call(exp)(*args) + res_jax_exported = exp.call(*args) custom_assert_lims = [ l for l in self.limitations if l.custom_assert is not None] assert len(custom_assert_lims) <= 1, custom_assert_lims @@ -1315,7 +1315,7 @@ def check_shape_poly(tst, f_jax: Callable, *, symbolic_constraints: Sequence[str] = (), expect_error=None) -> jax.Array | None: # Builds a PolyHarness and runs the test. See PolyHarness documentation. - h = PolyHarness("", "", f_jax, + h = PolyHarness("", "", jax.jit(f_jax), arg_descriptors=arg_descriptors, polymorphic_shapes=polymorphic_shapes, symbolic_constraints=symbolic_constraints, @@ -1408,11 +1408,10 @@ class ShapePolyTest(jtu.JaxTestCase): def f_jax(x, *, y): return x + jnp.sin(y) - f_exported = export.call( - export.export(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), - x.dtype), - y=jax.ShapeDtypeStruct(y.shape, y.dtype))) - self.assertAllClose(f_jax(x, y=y), f_exported(x, y=y)) + exp = export.export(jax.jit(f_jax))( + jax.ShapeDtypeStruct(export.symbolic_shape("b"), x.dtype), + y=jax.ShapeDtypeStruct(y.shape, y.dtype)) + self.assertAllClose(f_jax(x, y=y), exp.call(x, y=y)) def test_arg_avals_errors(self): """Test error reporting for shape polymorphism.""" @@ -1617,8 +1616,8 @@ class ShapePolyTest(jtu.JaxTestCase): acc += jnp.sum(slice, axis=0) return acc - _ = export.export(f)(jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), - np.int32)) + _ = export.export(jax.jit(f))( + jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), np.int32)) def test_constraints_compile_time_check(self): @@ -1630,29 +1629,30 @@ class ShapePolyTest(jtu.JaxTestCase): x_spec = jax.ShapeDtypeStruct( export.symbolic_shape("a", constraints=["a >= 2", "a <= 4"]), np.int32) - exp = export.export(f)(x_spec) + exp = export.export(jax.jit(f))(x_spec) x_2 = np.arange(2, dtype=np.int32) - res_2 = export.call(exp)(x_2) + res_2 = exp.call(x_2) self.assertAllClose(x_2[0:2], res_2) x_4 = np.arange(4, dtype=np.int32) - res_4 = export.call(exp)(x_4) + res_4 = exp.call(x_4) self.assertAllClose(x_4[1:3], res_4) with self.assertRaisesRegex( ValueError, re.escape("Expected 'a - 2' to be greater or equal to 0, but found -1")): - export.call(exp)(np.arange(1, dtype=np.int32)) + exp.call(np.arange(1, dtype=np.int32)) with self.assertRaisesRegex( ValueError, re.escape("Expected '- a + 4' to be greater or equal to 0, but found -1")): - export.call(exp)(np.arange(5, dtype=np.int32)) + exp.call(np.arange(5, dtype=np.int32)) def test_caching_with_scopes(self): f_tracing_count = 0 expected_a_bounds = (1, np.inf) + @jax.jit def f(x): # x: i32[a] nonlocal f_tracing_count f_tracing_count += 1
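Example usage of the `jax.export` API introduced by this change, as a minimal end-to-end sketch. The function `f`, the shapes, and the tolerances are illustrative assumptions; the API calls (`export.export` on a `jax.jit` result, `Exported.serialize`, `export.deserialize`, `Exported.call`, `export.symbolic_shape`) are the ones added or moved by this patch:

import numpy as np

import jax
from jax import export
from jax import numpy as jnp


def f(x):
  return jnp.sum(jnp.sin(x))

# The new export API requires the result of `jax.jit`.
exp: export.Exported = export.export(jax.jit(f))(
    jax.ShapeDtypeStruct((4,), np.float32))

# Serialize together with one order of VJP, so that the rehydrated
# function supports first-order reverse-mode differentiation.
blob: bytearray = exp.serialize(vjp_order=1)
rehydrated: export.Exported = export.deserialize(blob)

x = np.arange(4, dtype=np.float32)
np.testing.assert_allclose(rehydrated.call(x), f(x), rtol=1e-6)
np.testing.assert_allclose(jax.grad(rehydrated.call)(x), jax.grad(f)(x),
                           rtol=1e-6)

# Shape-polymorphic export: the leading dimension stays symbolic, so a
# single Exported artifact serves any batch size.
exp_poly = export.export(jax.jit(f))(
    jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), np.float32))
for batch in (1, 3):
  x_b = np.ones((batch, 4), dtype=np.float32)
  np.testing.assert_allclose(exp_poly.call(x_b), f(x_b), rtol=1e-6)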