[export] Remove old deprecated APIs for jax.experimental.export.

See CHANGELOG.md.
The deprecation period has passed.

Also replace deprecated .call_exported with .call in tests.

PiperOrigin-RevId: 641236222
This commit is contained in:
George Necula 2024-06-07 06:51:26 -07:00 committed by jax authors
parent 5d6413cecc
commit 3914cb415d
8 changed files with 70 additions and 94 deletions

View File

@ -13,6 +13,11 @@ Remember to align the itemized text with the first line of an item within a list
supporting a monolithic CUDA jaxlib. Future releases will use the CUDA
plugin jaxlib (e.g. `pip install jax[cuda12]`).
* JAX now requires ml_dtypes version 0.4.0 or newer.
* Removed backwards-compatibility support for old usage of the
`jax.experimental.export` API. It is not possible anymore to use
`from jax.experimental.export import export`, and instead you should use
`from jax.experimental import export`.
The removed functionality has been deprecated since 0.4.24.
* Deprecations
* `jax.sharding.XLACompatibleSharding` is deprecated. Please use

View File

@ -345,4 +345,4 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
_get_vjp=_get_vjp)
# We use pjit in case there are shardings in the exported module.
return pjit.pjit(export.call_exported(exported))(*data.inputs)
return pjit.pjit(export.call(exported))(*data.inputs)

View File

@ -22,8 +22,6 @@ from jax.experimental.export._export import (
call,
DisabledSafetyCheck,
default_lowering_platform,
args_specs, # TODO: deprecate
)
from jax._src.export.shape_poly import (
is_symbolic_dim,

View File

@ -1334,30 +1334,3 @@ def wrap_with_sharding(ctx: mlir.LoweringRuleContext,
return x
return mlir.wrap_with_sharding_op(
ctx, x, x_aval, x_sharding.to_proto())
# TODO(necula): Previously, we had `from jax.experimental.export import export`
# Now we want to simplify the usage, and export the public APIs directly
# from `jax.experimental.export` and now `jax.experimental.export.export`
# refers to the `export` function. Since there may still be users of the
# old API in other packages, we add the old public API as attributes of the
# exported function. We will clean this up after a deprecation period.
def wrap_with_deprecation_warning(f):
msg = (f"You are using function `{f.__name__}` from "
"`jax.experimental.export.export`. You should instead use it directly "
"from `jax.experimental.export`. Instead of "
"`from jax.experimental.export import export` you should use "
"`from jax.experimental import export`.")
def wrapped_f(*args, **kwargs):
warnings.warn(msg, DeprecationWarning, stacklevel=2)
return f(*args, **kwargs)
return wrapped_f
export.export = wrap_with_deprecation_warning(export)
export.Exported = Exported
export.call_exported = wrap_with_deprecation_warning(call_exported)
export.DisabledSafetyCheck = DisabledSafetyCheck
export.default_lowering_platform = wrap_with_deprecation_warning(default_lowering_platform)
export.symbolic_shape = wrap_with_deprecation_warning(shape_poly.symbolic_shape)
export.args_specs = wrap_with_deprecation_warning(args_specs)
export.minimum_supported_serialization_version = minimum_supported_serialization_version
export.maximum_supported_serialization_version = maximum_supported_serialization_version

View File

@ -787,7 +787,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
logging.info("Running harness natively on %s", jax_device)
native_res = f_jax(x_device)
logging.info("Running exported harness on %s", jax_device)
exported_res = export.call_exported(exp)(x_device)
exported_res = export.call(exp)(x_device)
self.assertAllClose(native_res, exported_res)
def test_multi_platform_call_tf_graph(self):

View File

@ -164,7 +164,7 @@ class PrimitiveTest(jtu.JaxTestCase):
logging.info("Running harness natively on %s", device)
native_res = func_jax(*device_args)
logging.info("Running exported harness on %s", device)
exported_res = export.call_exported(exp)(*device_args)
exported_res = export.call(exp)(*device_args)
if tol is not None:
logging.info(f"Using non-standard tolerance {tol}")
self.assertAllClose(native_res, exported_res, atol=tol, rtol=tol)

View File

@ -196,7 +196,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.arange(4, dtype=np.float32)
exp_f = get_exported(f)(x)
f1 = export.call_exported(exp_f)
f1 = export.call(exp_f)
self.assertAllClose(f(x), f1(x))
def test_jit_static_arg(self):
@ -210,7 +210,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.arange(4, dtype=np.float32)
exp_f = get_exported(f)(x, c=0.1)
f1 = export.call_exported(exp_f)
f1 = export.call(exp_f)
self.assertAllClose(f(x, c=0.1), f1(x))
with self.subTest("static_argnums"):
@ -222,7 +222,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.arange(4, dtype=np.float32)
exp_g = get_exported(g)(x, 0.1)
g1 = export.call_exported(exp_g)
g1 = export.call(exp_g)
self.assertAllClose(g(x, 0.1), g1(x))
def test_call_exported_lambda(self):
@ -230,7 +230,7 @@ class JaxExportTest(jtu.JaxTestCase):
f = lambda x: jnp.sin(x)
x = np.arange(4, dtype=np.float32)
exp_f = get_exported(f)(x)
f1 = export.call_exported(exp_f)
f1 = export.call(exp_f)
self.assertAllClose(f(x), f1(x))
def test_call_name_conflict(self):
@ -258,7 +258,7 @@ class JaxExportTest(jtu.JaxTestCase):
@jax.jit
def f1(x):
exp_f = get_exported(f)(x)
return export.call_exported(exp_f)(x) + export.call_exported(exp_f)(x)
return export.call(exp_f)(x) + export.call(exp_f)(x)
self.assertAllClose(2. * f(x), f1(x))
@ -268,7 +268,7 @@ class JaxExportTest(jtu.JaxTestCase):
y = np.arange(6, dtype=np.float32)
exp_f = get_exported(f)(x, y)
f1 = export.call_exported(exp_f)
f1 = export.call(exp_f)
self.assertAllClose(f(x, y), f1(x, y))
def test_pytree(self):
@ -278,7 +278,7 @@ class JaxExportTest(jtu.JaxTestCase):
return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))
exp_f = get_exported(f)((a, b), a=a, b=b)
f1 = export.call_exported(exp_f)
f1 = export.call(exp_f)
self.assertAllClose(f((a, b), a=a, b=b),
f1((a, b), a=a, b=b))
@ -291,7 +291,7 @@ class JaxExportTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
"The invocation args and kwargs must have the same pytree structure"):
export.call_exported(exp_f)(a, b, c=(a, b))
export.call(exp_f)(a, b, c=(a, b))
def test_error_wrong_avals(self):
def f(a, *, b): # a: f32[4] and b: f32[4]
@ -301,19 +301,19 @@ class JaxExportTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError,
r"Shape mismatch for args\[0\].shape\[0\]"):
export.call_exported(exp_f)(np.arange(6, dtype=np.float32), b=f32_4)
export.call(exp_f)(np.arange(6, dtype=np.float32), b=f32_4)
with self.assertRaisesRegex(ValueError,
r"Shape mismatch for kwargs\['b'\].shape\[0\]"):
export.call_exported(exp_f)(f32_4, b=np.arange(6, dtype=np.float32))
export.call(exp_f)(f32_4, b=np.arange(6, dtype=np.float32))
with self.assertRaisesRegex(ValueError,
r"Rank mismatch for args\[0\]"):
export.call_exported(exp_f)(f32_4.reshape((1, 4)), b=f32_4)
export.call(exp_f)(f32_4.reshape((1, 4)), b=f32_4)
with self.assertRaisesRegex(ValueError,
r"Dtype mismatch for args\[0\]"):
export.call_exported(exp_f)(f32_4.astype(np.float16), b=f32_4)
export.call(exp_f)(f32_4.astype(np.float16), b=f32_4)
@jtu.parameterized_filterable(
testcase_name=lambda kw: kw["platform"],
@ -328,13 +328,13 @@ class JaxExportTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError, "The exported function .* was lowered for platform"):
export.call_exported(exp_f)(a)
export.call(exp_f)(a)
# Now try with the platform check disabled
exp_f_no_platform_check = get_exported(
jnp.sin, lowering_platforms=(platform,),
disabled_checks=[export.DisabledSafetyCheck.platform()])(a)
res = export.call_exported(exp_f_no_platform_check)(a)
res = export.call(exp_f_no_platform_check)(a)
self.assertAllClose(res, jnp.sin(a))
@jtu.parameterized_filterable(
@ -394,7 +394,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.arange(4, dtype=np.float32)
exp_f = get_exported(f, vjp_order=1)(x)
f1 = export.call_exported(exp_f)
f1 = export.call(exp_f)
self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x))
def test_higher_order_grad(self):
@ -402,7 +402,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.float32(4.)
exp_f = get_exported(f, vjp_order=3)(x)
f1 = export.call_exported(exp_f)
f1 = export.call(exp_f)
self.assertAllClose(jax.grad(f)(x),
jax.grad(f1)(x))
self.assertAllClose(jax.grad(jax.grad(f))(x),
@ -429,7 +429,7 @@ class JaxExportTest(jtu.JaxTestCase):
(f_outi_ct2, f_outf_ct2), = f_vjp2((xi_ct, xf_ct))
exp = get_exported(f, vjp_order=2)(xi, xf)
fr = export.call_exported(exp)
fr = export.call(exp)
res = fr(xi, xf)
self.assertAllClose(res, (f_outi, f_outf))
@ -463,7 +463,7 @@ class JaxExportTest(jtu.JaxTestCase):
res = f((a, b), a=a, b=b)
return res
def f1_exp(a, b): # For VJP, make a function without kwargs
res = export.call_exported(exp_f)((a, b), a=a, b=b)
res = export.call(exp_f)((a, b), a=a, b=b)
return res
jax_vjp = jax.vjp(f1_jax, a, b)[1](out_ct)
exp_vjp = jax.vjp(f1_exp, a, b)[1](out_ct)
@ -475,13 +475,13 @@ class JaxExportTest(jtu.JaxTestCase):
a = np.arange(4, dtype=np.float32)
exp_f1 = get_exported(f1)(a)
def f2(x):
res1 = export.call_exported(exp_f1)(x)
res2 = export.call_exported(exp_f1)(res1)
res1 = export.call(exp_f1)(x)
res2 = export.call(exp_f1)(res1)
return jnp.cos(res2)
exp_f2 = get_exported(f2)(a)
self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))),
export.call_exported(exp_f2)(a))
export.call(exp_f2)(a))
def test_poly_export_only(self):
a = np.arange(12, dtype=np.float32).reshape((3, 4))
@ -590,11 +590,11 @@ class JaxExportTest(jtu.JaxTestCase):
exp = get_exported(jnp.sin)(
jax.ShapeDtypeStruct(export.symbolic_shape("w, h"), np.float32))
x = np.arange(30, dtype=np.float32).reshape((5, 6))
res = export.call_exported(exp)(x)
res = export.call(exp)(x)
self.assertAllClose(res, np.sin(x))
# A function is exported with f32[poly_spec] and is called with different arg
# shapes. We use export.call_exported and we also run the shape check
# shapes. We use export.call and we also run the shape check
# module.
@jtu.parameterized_filterable(
testcase_name=lambda kw:f"poly_spec={kw['poly_spec']}_arg_shape={kw['arg_shape']}", # type: ignore
@ -642,7 +642,7 @@ class JaxExportTest(jtu.JaxTestCase):
stack.push(self.assertRaisesRegex(Exception, expect_error))
assert core.is_constant_shape(arg.shape)
res = export.call_exported(exp_f)(arg)
res = export.call(exp_f)(arg)
if not expect_error:
self.assertAllClose(res, f(arg))
@ -741,7 +741,7 @@ class JaxExportTest(jtu.JaxTestCase):
def outer(x): # x: outer_poly_spec
# Use an addition to test that the shapes are refined properly for the
# result of the call_exported.
return export.call_exported(inner_exp)(x) + inner(x)
return export.call(inner_exp)(x) + inner(x)
with contextlib.ExitStack() as stack:
if expect_error_outer_exp is not None:
@ -761,7 +761,7 @@ class JaxExportTest(jtu.JaxTestCase):
if expect_error_run is not None:
stack.push(self.assertRaisesRegex(Exception, expect_error_run))
res = export.call_exported(outer_exp)(arg)
res = export.call(outer_exp)(arg)
if expect_error_run is not None:
return
@ -825,7 +825,7 @@ class JaxExportTest(jtu.JaxTestCase):
stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error)))
exp = get_exported(f_jax)(
jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), x.dtype))
export.call_exported(exp)(x)
export.call(exp)(x)
def test_poly_booleans(self):
# For booleans we use a special case ConvertOp to cast to and from
@ -836,7 +836,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.array([True, False, True, False], dtype=np.bool_)
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
x.dtype))
res = export.call_exported(exp)(x)
res = export.call(exp)(x)
self.assertAllClose(f_jax(x), res)
@jtu.parameterized_filterable(
@ -857,7 +857,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.arange(6, dtype=dtype)
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
x.dtype))
res = export.call_exported(exp)(x)
res = export.call(exp)(x)
self.assertAllClose(f_jax(x), res)
def test_poly_expressions(self):
@ -874,12 +874,12 @@ class JaxExportTest(jtu.JaxTestCase):
exp = get_exported(f)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
x.dtype))
# Call with static shapes
res = export.call_exported(exp)(x)
res = export.call(exp)(x)
self.assertAllClose(res, f(x))
# Now re-export with shape polymorphism
x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("a"), x.dtype)
exp2 = get_exported(export.call_exported(exp))(x_spec)
exp2 = get_exported(export.call(exp))(x_spec)
a = exp2.in_avals[0].shape[0]
self.assertEqual(exp2.out_avals[0].shape, output_shape(a))
@ -892,7 +892,7 @@ class JaxExportTest(jtu.JaxTestCase):
a, = export.symbolic_shape("a")
exp = export.export(f)(
jax.ShapeDtypeStruct((a, 4), np.float32))
f_exp = export.call_exported(exp)
f_exp = export.call(exp)
x_jit = np.arange(12, dtype=np.float32).reshape((3, 4))
res_jit = jax.jit(f_exp)(x_jit)
self.assertAllClose(res_jit, f(x_jit))
@ -931,24 +931,24 @@ class JaxExportTest(jtu.JaxTestCase):
# We apply the out_shardings for f_jax
r".*custom_call @Sharding\(%1\).*mhlo.sharding = \"{devices=\[1,2\]<=\[2\]}\"}.*",
re.DOTALL)
hlo = jax.jit(export.call_exported(exp)).lower(a_device).as_text()
hlo = jax.jit(export.call(exp)).lower(a_device).as_text()
self.assertRegex(hlo, expected_re)
res_exported = export.call_exported(exp)(a_device)
res_exported = export.call(exp)(a_device)
self.assertAllClose(res_native, res_exported)
# Test error reporting
with self.assertRaisesRegex(
NotImplementedError,
"Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
_ = export.call_exported(exp)(a)
_ = export.call(exp)(a)
with self.assertRaisesRegex(
NotImplementedError,
"Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
mesh1 = Mesh(jax.devices()[0:1], axis_names=("x",))
_ = jax.jit(
export.call_exported(exp),
export.call(exp),
in_shardings=(jax.sharding.NamedSharding(mesh1, P("x", None)),)
)(a)
@ -996,7 +996,7 @@ class JaxExportTest(jtu.JaxTestCase):
run_mesh = Mesh(run_devices, "i")
b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i")))
res_exported = export.call_exported(exp)(b)
res_exported = export.call(exp)(b)
self.assertAllClose(res_native, res_exported)
def test_call_with_different_no_of_devices_error_has_in_shardings(self):
@ -1024,7 +1024,7 @@ class JaxExportTest(jtu.JaxTestCase):
"Exported module .* was lowered for 1 devices and is called in a "
f"context with {jax.local_device_count()} devices.* module contains "
"non-replicated sharding annotations"):
export.call_exported(exp)(b)
export.call(exp)(b)
def test_call_with_different_no_of_devices_pmap(self):
if len(jax.devices()) < 2:
@ -1042,7 +1042,7 @@ class JaxExportTest(jtu.JaxTestCase):
b = jnp.arange(jax.device_count() * 100, dtype=jnp.float32).reshape(
(-1, 1, 100)
)
res_exported = jax.pmap(export.call_exported(exp))(b)
res_exported = jax.pmap(export.call(exp))(b)
self.assertAllClose(res_native, res_exported[0])
def test_call_with_different_no_of_devices_error_has_sharding_constraint(self):
@ -1070,7 +1070,7 @@ class JaxExportTest(jtu.JaxTestCase):
"Exported module .* was lowered for 1 devices and is called in a "
f"context with {jax.local_device_count()} devices.* module contains "
"non-replicated sharding annotations"):
export.call_exported(exp)(b)
export.call(exp)(b)
@jtu.parameterized_filterable(
kwargs=[
@ -1108,7 +1108,7 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertLen(res_jax.addressable_shards, len(devices))
# Test reloaded execution.
f_r = export.call_exported(exp)
f_r = export.call(exp)
with self.assertRaisesRegex(
Exception,
"Exported module .* was lowered for 2 devices and is "
@ -1241,14 +1241,14 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertEqual(exp_vjp2.nr_devices, 2)
call_mesh = Mesh(jax.devices()[:2], "e")
g1 = pjit.pjit(export.call_exported(exp_vjp),
g1 = pjit.pjit(export.call(exp_vjp),
in_shardings=(NamedSharding(call_mesh, None),
NamedSharding(call_mesh, None)))(x, x.T)
_, f_jax_vjp = jax.vjp(f_jax, x)
xbar = f_jax_vjp(x.T)
self.assertAllClose(xbar, g1)
g2 = pjit.pjit(export.call_exported(exp_vjp2),
g2 = pjit.pjit(export.call(exp_vjp2),
in_shardings=(NamedSharding(call_mesh, None),
NamedSharding(call_mesh, None),
NamedSharding(call_mesh, None)))(x, x.T, x)
@ -1299,7 +1299,7 @@ class JaxExportTest(jtu.JaxTestCase):
# Call with argument placed on different plaforms
for platform in self.__class__.platforms:
x_device = jax.device_put(x, jax.devices(platform)[0])
res_exp = export.call_exported(exp)(x_device)
res_exp = export.call(exp)(x_device)
self.assertAllClose(
res_exp,
_testing_multi_platform_fun_expected(x, platform=platform))
@ -1313,7 +1313,7 @@ class JaxExportTest(jtu.JaxTestCase):
# Now serialize the call to the exported using a different sequence of
# lowering platforms, but included in the lowering platforms for the
# nested exported.
exp2 = get_exported(export.call_exported(exp),
exp2 = get_exported(export.call(exp),
lowering_platforms=("cpu", "cuda","rocm"))(x)
# Ensure that we do not have multiple lowerings of the exported function
@ -1325,7 +1325,7 @@ class JaxExportTest(jtu.JaxTestCase):
for platform in self.__class__.platforms:
if platform == "tpu": continue
x_device = jax.device_put(x, jax.devices(platform)[0])
res_exp = export.call_exported(exp2)(x_device)
res_exp = export.call(exp2)(x_device)
self.assertAllClose(
res_exp,
_testing_multi_platform_fun_expected(np.sin(x), platform=platform))
@ -1337,11 +1337,11 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda", "rocm"))
# Now serialize the call for the current platform.
exp2 = get_exported(export.call_exported(exp))(x)
exp2 = get_exported(export.call(exp))(x)
module_str = str(exp2.mlir_module())
self.assertIn("jax.uses_shape_polymorphism = true",
module_str)
res2 = export.call_exported(exp2)(x)
res2 = export.call(exp2)(x)
self.assertAllClose(res2, _testing_multi_platform_fun_expected(x))
def test_multi_platform_and_poly(self):
@ -1353,11 +1353,11 @@ class JaxExportTest(jtu.JaxTestCase):
jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), np.float32)
)
x = np.arange(12, dtype=np.float32).reshape((3, 4))
res = export.call_exported(exp)(x)
res = export.call(exp)(x)
self.assertAllClose(res, _testing_multi_platform_fun_expected(x).reshape((-1,)))
# Now serialize the call to the exported
exp2 = get_exported(export.call_exported(exp))(x)
res2 = export.call_exported(exp2)(x)
exp2 = get_exported(export.call(exp))(x)
res2 = export.call(exp2)(x)
self.assertAllClose(res2, _testing_multi_platform_fun_expected(x).reshape((-1,)))
def test_multi_platform_and_sharding(self):
@ -1382,7 +1382,7 @@ class JaxExportTest(jtu.JaxTestCase):
continue
run_mesh = Mesh(run_devices, ("x",))
a_device = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, None))
res_exp = export.call_exported(exp)(a_device)
res_exp = export.call(exp)(a_device)
self.assertArraysAllClose(res_native, res_exp)
@jtu.parameterized_filterable(
@ -1449,7 +1449,7 @@ class JaxExportTest(jtu.JaxTestCase):
x, effect_class_name="ForTestingOrderedEffect2") +
testing_primitive_with_effect_p.bind(
x, effect_class_name="ForTestingUnorderedEffect1") +
export.call_exported(exp)(x))
export.call(exp)(x))
lowered_outer = jax.jit(f_outer).lower(x)
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
@ -1497,7 +1497,7 @@ class JaxExportTest(jtu.JaxTestCase):
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
self.assertRegex(mlir_module_str, main_expected_re)
res = export.call_exported(exp)(x)
res = export.call(exp)(x)
self.assertAllClose(10. + 2. * x, res)
@jtu.parameterized_filterable(
@ -1541,7 +1541,7 @@ class JaxExportTest(jtu.JaxTestCase):
# Results
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
self.assertRegex(mlir_module_str, main_expected_re)
res = export.call_exported(exp)(x)
res = export.call(exp)(x)
self.assertAllClose(10. + _testing_multi_platform_fun_expected(x),
res)
@ -1608,7 +1608,7 @@ class JaxExportTest(jtu.JaxTestCase):
jax.ShapeDtypeStruct(rhs.shape, dtype=rhs.dtype),
jax.ShapeDtypeStruct(group_sizes.shape, dtype=group_sizes.dtype),
)
res_exported = export.call_exported(exp_f)(lhs, rhs, group_sizes)
res_exported = export.call(exp_f)(lhs, rhs, group_sizes)
self.assertAllClose(res_native, res_exported)
if __name__ == "__main__":

View File

@ -1283,7 +1283,7 @@ class PolyHarness(Harness):
return None
# Run the JAX natively and then the exported function and compare
res_jax_native = f_jax(*args)
res_jax_exported = export.call_exported(exp)(*args)
res_jax_exported = export.call(exp)(*args)
custom_assert_lims = [
l for l in self.limitations if l.custom_assert is not None]
assert len(custom_assert_lims) <= 1, custom_assert_lims
@ -1408,7 +1408,7 @@ class ShapePolyTest(jtu.JaxTestCase):
def f_jax(x, *, y):
return x + jnp.sin(y)
f_exported = export.call_exported(
f_exported = export.call(
export.export(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
x.dtype),
y=jax.ShapeDtypeStruct(y.shape, y.dtype)))
@ -1633,22 +1633,22 @@ class ShapePolyTest(jtu.JaxTestCase):
exp = export.export(f)(x_spec)
x_2 = np.arange(2, dtype=np.int32)
res_2 = export.call_exported(exp)(x_2)
res_2 = export.call(exp)(x_2)
self.assertAllClose(x_2[0:2], res_2)
x_4 = np.arange(4, dtype=np.int32)
res_4 = export.call_exported(exp)(x_4)
res_4 = export.call(exp)(x_4)
self.assertAllClose(x_4[1:3], res_4)
with self.assertRaisesRegex(
ValueError,
re.escape("Expected 'a - 2' to be greater or equal to 0, but found -1")):
export.call_exported(exp)(np.arange(1, dtype=np.int32))
export.call(exp)(np.arange(1, dtype=np.int32))
with self.assertRaisesRegex(
ValueError,
re.escape("Expected '- a + 4' to be greater or equal to 0, but found -1")):
export.call_exported(exp)(np.arange(5, dtype=np.int32))
export.call(exp)(np.arange(5, dtype=np.int32))
def test_caching_with_scopes(self):
f_tracing_count = 0