[export] Remove old deprecated APIs for jax.experimental.export.

See CHANGELOG.md.
The deprecation period has passed.

Also replace deprecated .call_exported with .call in tests.

PiperOrigin-RevId: 641236222
This commit is contained in:
George Necula 2024-06-07 06:51:26 -07:00 committed by jax authors
parent 5d6413cecc
commit 3914cb415d
8 changed files with 70 additions and 94 deletions

View File

@ -13,6 +13,11 @@ Remember to align the itemized text with the first line of an item within a list
supporting a monolithic CUDA jaxlib. Future releases will use the CUDA supporting a monolithic CUDA jaxlib. Future releases will use the CUDA
plugin jaxlib (e.g. `pip install jax[cuda12]`). plugin jaxlib (e.g. `pip install jax[cuda12]`).
* JAX now requires ml_dtypes version 0.4.0 or newer. * JAX now requires ml_dtypes version 0.4.0 or newer.
* Removed backwards-compatibility support for old usage of the
`jax.experimental.export` API. It is not possible anymore to use
`from jax.experimental.export import export`, and instead you should use
`from jax.experimental import export`.
The removed functionality has been deprecated since 0.4.24.
* Deprecations * Deprecations
* `jax.sharding.XLACompatibleSharding` is deprecated. Please use * `jax.sharding.XLACompatibleSharding` is deprecated. Please use

View File

@ -345,4 +345,4 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
_get_vjp=_get_vjp) _get_vjp=_get_vjp)
# We use pjit in case there are shardings in the exported module. # We use pjit in case there are shardings in the exported module.
return pjit.pjit(export.call_exported(exported))(*data.inputs) return pjit.pjit(export.call(exported))(*data.inputs)

View File

@ -22,8 +22,6 @@ from jax.experimental.export._export import (
call, call,
DisabledSafetyCheck, DisabledSafetyCheck,
default_lowering_platform, default_lowering_platform,
args_specs, # TODO: deprecate
) )
from jax._src.export.shape_poly import ( from jax._src.export.shape_poly import (
is_symbolic_dim, is_symbolic_dim,

View File

@ -1334,30 +1334,3 @@ def wrap_with_sharding(ctx: mlir.LoweringRuleContext,
return x return x
return mlir.wrap_with_sharding_op( return mlir.wrap_with_sharding_op(
ctx, x, x_aval, x_sharding.to_proto()) ctx, x, x_aval, x_sharding.to_proto())
# TODO(necula): Previously, we had `from jax.experimental.export import export`
# Now we want to simplify the usage, and export the public APIs directly
# from `jax.experimental.export` and now `jax.experimental.export.export`
# refers to the `export` function. Since there may still be users of the
# old API in other packages, we add the old public API as attributes of the
# exported function. We will clean this up after a deprecation period.
def wrap_with_deprecation_warning(f):
msg = (f"You are using function `{f.__name__}` from "
"`jax.experimental.export.export`. You should instead use it directly "
"from `jax.experimental.export`. Instead of "
"`from jax.experimental.export import export` you should use "
"`from jax.experimental import export`.")
def wrapped_f(*args, **kwargs):
warnings.warn(msg, DeprecationWarning, stacklevel=2)
return f(*args, **kwargs)
return wrapped_f
export.export = wrap_with_deprecation_warning(export)
export.Exported = Exported
export.call_exported = wrap_with_deprecation_warning(call_exported)
export.DisabledSafetyCheck = DisabledSafetyCheck
export.default_lowering_platform = wrap_with_deprecation_warning(default_lowering_platform)
export.symbolic_shape = wrap_with_deprecation_warning(shape_poly.symbolic_shape)
export.args_specs = wrap_with_deprecation_warning(args_specs)
export.minimum_supported_serialization_version = minimum_supported_serialization_version
export.maximum_supported_serialization_version = maximum_supported_serialization_version

View File

@ -787,7 +787,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
logging.info("Running harness natively on %s", jax_device) logging.info("Running harness natively on %s", jax_device)
native_res = f_jax(x_device) native_res = f_jax(x_device)
logging.info("Running exported harness on %s", jax_device) logging.info("Running exported harness on %s", jax_device)
exported_res = export.call_exported(exp)(x_device) exported_res = export.call(exp)(x_device)
self.assertAllClose(native_res, exported_res) self.assertAllClose(native_res, exported_res)
def test_multi_platform_call_tf_graph(self): def test_multi_platform_call_tf_graph(self):

View File

@ -164,7 +164,7 @@ class PrimitiveTest(jtu.JaxTestCase):
logging.info("Running harness natively on %s", device) logging.info("Running harness natively on %s", device)
native_res = func_jax(*device_args) native_res = func_jax(*device_args)
logging.info("Running exported harness on %s", device) logging.info("Running exported harness on %s", device)
exported_res = export.call_exported(exp)(*device_args) exported_res = export.call(exp)(*device_args)
if tol is not None: if tol is not None:
logging.info(f"Using non-standard tolerance {tol}") logging.info(f"Using non-standard tolerance {tol}")
self.assertAllClose(native_res, exported_res, atol=tol, rtol=tol) self.assertAllClose(native_res, exported_res, atol=tol, rtol=tol)

View File

@ -196,7 +196,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.arange(4, dtype=np.float32) x = np.arange(4, dtype=np.float32)
exp_f = get_exported(f)(x) exp_f = get_exported(f)(x)
f1 = export.call_exported(exp_f) f1 = export.call(exp_f)
self.assertAllClose(f(x), f1(x)) self.assertAllClose(f(x), f1(x))
def test_jit_static_arg(self): def test_jit_static_arg(self):
@ -210,7 +210,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.arange(4, dtype=np.float32) x = np.arange(4, dtype=np.float32)
exp_f = get_exported(f)(x, c=0.1) exp_f = get_exported(f)(x, c=0.1)
f1 = export.call_exported(exp_f) f1 = export.call(exp_f)
self.assertAllClose(f(x, c=0.1), f1(x)) self.assertAllClose(f(x, c=0.1), f1(x))
with self.subTest("static_argnums"): with self.subTest("static_argnums"):
@ -222,7 +222,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.arange(4, dtype=np.float32) x = np.arange(4, dtype=np.float32)
exp_g = get_exported(g)(x, 0.1) exp_g = get_exported(g)(x, 0.1)
g1 = export.call_exported(exp_g) g1 = export.call(exp_g)
self.assertAllClose(g(x, 0.1), g1(x)) self.assertAllClose(g(x, 0.1), g1(x))
def test_call_exported_lambda(self): def test_call_exported_lambda(self):
@ -230,7 +230,7 @@ class JaxExportTest(jtu.JaxTestCase):
f = lambda x: jnp.sin(x) f = lambda x: jnp.sin(x)
x = np.arange(4, dtype=np.float32) x = np.arange(4, dtype=np.float32)
exp_f = get_exported(f)(x) exp_f = get_exported(f)(x)
f1 = export.call_exported(exp_f) f1 = export.call(exp_f)
self.assertAllClose(f(x), f1(x)) self.assertAllClose(f(x), f1(x))
def test_call_name_conflict(self): def test_call_name_conflict(self):
@ -258,7 +258,7 @@ class JaxExportTest(jtu.JaxTestCase):
@jax.jit @jax.jit
def f1(x): def f1(x):
exp_f = get_exported(f)(x) exp_f = get_exported(f)(x)
return export.call_exported(exp_f)(x) + export.call_exported(exp_f)(x) return export.call(exp_f)(x) + export.call(exp_f)(x)
self.assertAllClose(2. * f(x), f1(x)) self.assertAllClose(2. * f(x), f1(x))
@ -268,7 +268,7 @@ class JaxExportTest(jtu.JaxTestCase):
y = np.arange(6, dtype=np.float32) y = np.arange(6, dtype=np.float32)
exp_f = get_exported(f)(x, y) exp_f = get_exported(f)(x, y)
f1 = export.call_exported(exp_f) f1 = export.call(exp_f)
self.assertAllClose(f(x, y), f1(x, y)) self.assertAllClose(f(x, y), f1(x, y))
def test_pytree(self): def test_pytree(self):
@ -278,7 +278,7 @@ class JaxExportTest(jtu.JaxTestCase):
return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b)) return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))
exp_f = get_exported(f)((a, b), a=a, b=b) exp_f = get_exported(f)((a, b), a=a, b=b)
f1 = export.call_exported(exp_f) f1 = export.call(exp_f)
self.assertAllClose(f((a, b), a=a, b=b), self.assertAllClose(f((a, b), a=a, b=b),
f1((a, b), a=a, b=b)) f1((a, b), a=a, b=b))
@ -291,7 +291,7 @@ class JaxExportTest(jtu.JaxTestCase):
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, ValueError,
"The invocation args and kwargs must have the same pytree structure"): "The invocation args and kwargs must have the same pytree structure"):
export.call_exported(exp_f)(a, b, c=(a, b)) export.call(exp_f)(a, b, c=(a, b))
def test_error_wrong_avals(self): def test_error_wrong_avals(self):
def f(a, *, b): # a: f32[4] and b: f32[4] def f(a, *, b): # a: f32[4] and b: f32[4]
@ -301,19 +301,19 @@ class JaxExportTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(ValueError,
r"Shape mismatch for args\[0\].shape\[0\]"): r"Shape mismatch for args\[0\].shape\[0\]"):
export.call_exported(exp_f)(np.arange(6, dtype=np.float32), b=f32_4) export.call(exp_f)(np.arange(6, dtype=np.float32), b=f32_4)
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(ValueError,
r"Shape mismatch for kwargs\['b'\].shape\[0\]"): r"Shape mismatch for kwargs\['b'\].shape\[0\]"):
export.call_exported(exp_f)(f32_4, b=np.arange(6, dtype=np.float32)) export.call(exp_f)(f32_4, b=np.arange(6, dtype=np.float32))
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(ValueError,
r"Rank mismatch for args\[0\]"): r"Rank mismatch for args\[0\]"):
export.call_exported(exp_f)(f32_4.reshape((1, 4)), b=f32_4) export.call(exp_f)(f32_4.reshape((1, 4)), b=f32_4)
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(ValueError,
r"Dtype mismatch for args\[0\]"): r"Dtype mismatch for args\[0\]"):
export.call_exported(exp_f)(f32_4.astype(np.float16), b=f32_4) export.call(exp_f)(f32_4.astype(np.float16), b=f32_4)
@jtu.parameterized_filterable( @jtu.parameterized_filterable(
testcase_name=lambda kw: kw["platform"], testcase_name=lambda kw: kw["platform"],
@ -328,13 +328,13 @@ class JaxExportTest(jtu.JaxTestCase):
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, "The exported function .* was lowered for platform"): ValueError, "The exported function .* was lowered for platform"):
export.call_exported(exp_f)(a) export.call(exp_f)(a)
# Now try with the platform check disabled # Now try with the platform check disabled
exp_f_no_platform_check = get_exported( exp_f_no_platform_check = get_exported(
jnp.sin, lowering_platforms=(platform,), jnp.sin, lowering_platforms=(platform,),
disabled_checks=[export.DisabledSafetyCheck.platform()])(a) disabled_checks=[export.DisabledSafetyCheck.platform()])(a)
res = export.call_exported(exp_f_no_platform_check)(a) res = export.call(exp_f_no_platform_check)(a)
self.assertAllClose(res, jnp.sin(a)) self.assertAllClose(res, jnp.sin(a))
@jtu.parameterized_filterable( @jtu.parameterized_filterable(
@ -394,7 +394,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.arange(4, dtype=np.float32) x = np.arange(4, dtype=np.float32)
exp_f = get_exported(f, vjp_order=1)(x) exp_f = get_exported(f, vjp_order=1)(x)
f1 = export.call_exported(exp_f) f1 = export.call(exp_f)
self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x)) self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x))
def test_higher_order_grad(self): def test_higher_order_grad(self):
@ -402,7 +402,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.float32(4.) x = np.float32(4.)
exp_f = get_exported(f, vjp_order=3)(x) exp_f = get_exported(f, vjp_order=3)(x)
f1 = export.call_exported(exp_f) f1 = export.call(exp_f)
self.assertAllClose(jax.grad(f)(x), self.assertAllClose(jax.grad(f)(x),
jax.grad(f1)(x)) jax.grad(f1)(x))
self.assertAllClose(jax.grad(jax.grad(f))(x), self.assertAllClose(jax.grad(jax.grad(f))(x),
@ -429,7 +429,7 @@ class JaxExportTest(jtu.JaxTestCase):
(f_outi_ct2, f_outf_ct2), = f_vjp2((xi_ct, xf_ct)) (f_outi_ct2, f_outf_ct2), = f_vjp2((xi_ct, xf_ct))
exp = get_exported(f, vjp_order=2)(xi, xf) exp = get_exported(f, vjp_order=2)(xi, xf)
fr = export.call_exported(exp) fr = export.call(exp)
res = fr(xi, xf) res = fr(xi, xf)
self.assertAllClose(res, (f_outi, f_outf)) self.assertAllClose(res, (f_outi, f_outf))
@ -463,7 +463,7 @@ class JaxExportTest(jtu.JaxTestCase):
res = f((a, b), a=a, b=b) res = f((a, b), a=a, b=b)
return res return res
def f1_exp(a, b): # For VJP, make a function without kwargs def f1_exp(a, b): # For VJP, make a function without kwargs
res = export.call_exported(exp_f)((a, b), a=a, b=b) res = export.call(exp_f)((a, b), a=a, b=b)
return res return res
jax_vjp = jax.vjp(f1_jax, a, b)[1](out_ct) jax_vjp = jax.vjp(f1_jax, a, b)[1](out_ct)
exp_vjp = jax.vjp(f1_exp, a, b)[1](out_ct) exp_vjp = jax.vjp(f1_exp, a, b)[1](out_ct)
@ -475,13 +475,13 @@ class JaxExportTest(jtu.JaxTestCase):
a = np.arange(4, dtype=np.float32) a = np.arange(4, dtype=np.float32)
exp_f1 = get_exported(f1)(a) exp_f1 = get_exported(f1)(a)
def f2(x): def f2(x):
res1 = export.call_exported(exp_f1)(x) res1 = export.call(exp_f1)(x)
res2 = export.call_exported(exp_f1)(res1) res2 = export.call(exp_f1)(res1)
return jnp.cos(res2) return jnp.cos(res2)
exp_f2 = get_exported(f2)(a) exp_f2 = get_exported(f2)(a)
self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))), self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))),
export.call_exported(exp_f2)(a)) export.call(exp_f2)(a))
def test_poly_export_only(self): def test_poly_export_only(self):
a = np.arange(12, dtype=np.float32).reshape((3, 4)) a = np.arange(12, dtype=np.float32).reshape((3, 4))
@ -590,11 +590,11 @@ class JaxExportTest(jtu.JaxTestCase):
exp = get_exported(jnp.sin)( exp = get_exported(jnp.sin)(
jax.ShapeDtypeStruct(export.symbolic_shape("w, h"), np.float32)) jax.ShapeDtypeStruct(export.symbolic_shape("w, h"), np.float32))
x = np.arange(30, dtype=np.float32).reshape((5, 6)) x = np.arange(30, dtype=np.float32).reshape((5, 6))
res = export.call_exported(exp)(x) res = export.call(exp)(x)
self.assertAllClose(res, np.sin(x)) self.assertAllClose(res, np.sin(x))
# A function is exported with f32[poly_spec] and is called with different arg # A function is exported with f32[poly_spec] and is called with different arg
# shapes. We use export.call_exported and we also run the shape check # shapes. We use export.call and we also run the shape check
# module. # module.
@jtu.parameterized_filterable( @jtu.parameterized_filterable(
testcase_name=lambda kw:f"poly_spec={kw['poly_spec']}_arg_shape={kw['arg_shape']}", # type: ignore testcase_name=lambda kw:f"poly_spec={kw['poly_spec']}_arg_shape={kw['arg_shape']}", # type: ignore
@ -642,7 +642,7 @@ class JaxExportTest(jtu.JaxTestCase):
stack.push(self.assertRaisesRegex(Exception, expect_error)) stack.push(self.assertRaisesRegex(Exception, expect_error))
assert core.is_constant_shape(arg.shape) assert core.is_constant_shape(arg.shape)
res = export.call_exported(exp_f)(arg) res = export.call(exp_f)(arg)
if not expect_error: if not expect_error:
self.assertAllClose(res, f(arg)) self.assertAllClose(res, f(arg))
@ -741,7 +741,7 @@ class JaxExportTest(jtu.JaxTestCase):
def outer(x): # x: outer_poly_spec def outer(x): # x: outer_poly_spec
# Use an addition to test that the shapes are refined properly for the # Use an addition to test that the shapes are refined properly for the
# result of the call_exported. # result of the call_exported.
return export.call_exported(inner_exp)(x) + inner(x) return export.call(inner_exp)(x) + inner(x)
with contextlib.ExitStack() as stack: with contextlib.ExitStack() as stack:
if expect_error_outer_exp is not None: if expect_error_outer_exp is not None:
@ -761,7 +761,7 @@ class JaxExportTest(jtu.JaxTestCase):
if expect_error_run is not None: if expect_error_run is not None:
stack.push(self.assertRaisesRegex(Exception, expect_error_run)) stack.push(self.assertRaisesRegex(Exception, expect_error_run))
res = export.call_exported(outer_exp)(arg) res = export.call(outer_exp)(arg)
if expect_error_run is not None: if expect_error_run is not None:
return return
@ -825,7 +825,7 @@ class JaxExportTest(jtu.JaxTestCase):
stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error))) stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error)))
exp = get_exported(f_jax)( exp = get_exported(f_jax)(
jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), x.dtype)) jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), x.dtype))
export.call_exported(exp)(x) export.call(exp)(x)
def test_poly_booleans(self): def test_poly_booleans(self):
# For booleans we use a special case ConvertOp to cast to and from # For booleans we use a special case ConvertOp to cast to and from
@ -836,7 +836,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.array([True, False, True, False], dtype=np.bool_) x = np.array([True, False, True, False], dtype=np.bool_)
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
x.dtype)) x.dtype))
res = export.call_exported(exp)(x) res = export.call(exp)(x)
self.assertAllClose(f_jax(x), res) self.assertAllClose(f_jax(x), res)
@jtu.parameterized_filterable( @jtu.parameterized_filterable(
@ -857,7 +857,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.arange(6, dtype=dtype) x = np.arange(6, dtype=dtype)
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
x.dtype)) x.dtype))
res = export.call_exported(exp)(x) res = export.call(exp)(x)
self.assertAllClose(f_jax(x), res) self.assertAllClose(f_jax(x), res)
def test_poly_expressions(self): def test_poly_expressions(self):
@ -874,12 +874,12 @@ class JaxExportTest(jtu.JaxTestCase):
exp = get_exported(f)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), exp = get_exported(f)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
x.dtype)) x.dtype))
# Call with static shapes # Call with static shapes
res = export.call_exported(exp)(x) res = export.call(exp)(x)
self.assertAllClose(res, f(x)) self.assertAllClose(res, f(x))
# Now re-export with shape polymorphism # Now re-export with shape polymorphism
x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("a"), x.dtype) x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("a"), x.dtype)
exp2 = get_exported(export.call_exported(exp))(x_spec) exp2 = get_exported(export.call(exp))(x_spec)
a = exp2.in_avals[0].shape[0] a = exp2.in_avals[0].shape[0]
self.assertEqual(exp2.out_avals[0].shape, output_shape(a)) self.assertEqual(exp2.out_avals[0].shape, output_shape(a))
@ -892,7 +892,7 @@ class JaxExportTest(jtu.JaxTestCase):
a, = export.symbolic_shape("a") a, = export.symbolic_shape("a")
exp = export.export(f)( exp = export.export(f)(
jax.ShapeDtypeStruct((a, 4), np.float32)) jax.ShapeDtypeStruct((a, 4), np.float32))
f_exp = export.call_exported(exp) f_exp = export.call(exp)
x_jit = np.arange(12, dtype=np.float32).reshape((3, 4)) x_jit = np.arange(12, dtype=np.float32).reshape((3, 4))
res_jit = jax.jit(f_exp)(x_jit) res_jit = jax.jit(f_exp)(x_jit)
self.assertAllClose(res_jit, f(x_jit)) self.assertAllClose(res_jit, f(x_jit))
@ -931,24 +931,24 @@ class JaxExportTest(jtu.JaxTestCase):
# We apply the out_shardings for f_jax # We apply the out_shardings for f_jax
r".*custom_call @Sharding\(%1\).*mhlo.sharding = \"{devices=\[1,2\]<=\[2\]}\"}.*", r".*custom_call @Sharding\(%1\).*mhlo.sharding = \"{devices=\[1,2\]<=\[2\]}\"}.*",
re.DOTALL) re.DOTALL)
hlo = jax.jit(export.call_exported(exp)).lower(a_device).as_text() hlo = jax.jit(export.call(exp)).lower(a_device).as_text()
self.assertRegex(hlo, expected_re) self.assertRegex(hlo, expected_re)
res_exported = export.call_exported(exp)(a_device) res_exported = export.call(exp)(a_device)
self.assertAllClose(res_native, res_exported) self.assertAllClose(res_native, res_exported)
# Test error reporting # Test error reporting
with self.assertRaisesRegex( with self.assertRaisesRegex(
NotImplementedError, NotImplementedError,
"Exported module .* was lowered for 2 devices and is called in a context with 1 device"): "Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
_ = export.call_exported(exp)(a) _ = export.call(exp)(a)
with self.assertRaisesRegex( with self.assertRaisesRegex(
NotImplementedError, NotImplementedError,
"Exported module .* was lowered for 2 devices and is called in a context with 1 device"): "Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
mesh1 = Mesh(jax.devices()[0:1], axis_names=("x",)) mesh1 = Mesh(jax.devices()[0:1], axis_names=("x",))
_ = jax.jit( _ = jax.jit(
export.call_exported(exp), export.call(exp),
in_shardings=(jax.sharding.NamedSharding(mesh1, P("x", None)),) in_shardings=(jax.sharding.NamedSharding(mesh1, P("x", None)),)
)(a) )(a)
@ -996,7 +996,7 @@ class JaxExportTest(jtu.JaxTestCase):
run_mesh = Mesh(run_devices, "i") run_mesh = Mesh(run_devices, "i")
b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i"))) b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i")))
res_exported = export.call_exported(exp)(b) res_exported = export.call(exp)(b)
self.assertAllClose(res_native, res_exported) self.assertAllClose(res_native, res_exported)
def test_call_with_different_no_of_devices_error_has_in_shardings(self): def test_call_with_different_no_of_devices_error_has_in_shardings(self):
@ -1024,7 +1024,7 @@ class JaxExportTest(jtu.JaxTestCase):
"Exported module .* was lowered for 1 devices and is called in a " "Exported module .* was lowered for 1 devices and is called in a "
f"context with {jax.local_device_count()} devices.* module contains " f"context with {jax.local_device_count()} devices.* module contains "
"non-replicated sharding annotations"): "non-replicated sharding annotations"):
export.call_exported(exp)(b) export.call(exp)(b)
def test_call_with_different_no_of_devices_pmap(self): def test_call_with_different_no_of_devices_pmap(self):
if len(jax.devices()) < 2: if len(jax.devices()) < 2:
@ -1042,7 +1042,7 @@ class JaxExportTest(jtu.JaxTestCase):
b = jnp.arange(jax.device_count() * 100, dtype=jnp.float32).reshape( b = jnp.arange(jax.device_count() * 100, dtype=jnp.float32).reshape(
(-1, 1, 100) (-1, 1, 100)
) )
res_exported = jax.pmap(export.call_exported(exp))(b) res_exported = jax.pmap(export.call(exp))(b)
self.assertAllClose(res_native, res_exported[0]) self.assertAllClose(res_native, res_exported[0])
def test_call_with_different_no_of_devices_error_has_sharding_constraint(self): def test_call_with_different_no_of_devices_error_has_sharding_constraint(self):
@ -1070,7 +1070,7 @@ class JaxExportTest(jtu.JaxTestCase):
"Exported module .* was lowered for 1 devices and is called in a " "Exported module .* was lowered for 1 devices and is called in a "
f"context with {jax.local_device_count()} devices.* module contains " f"context with {jax.local_device_count()} devices.* module contains "
"non-replicated sharding annotations"): "non-replicated sharding annotations"):
export.call_exported(exp)(b) export.call(exp)(b)
@jtu.parameterized_filterable( @jtu.parameterized_filterable(
kwargs=[ kwargs=[
@ -1108,7 +1108,7 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertLen(res_jax.addressable_shards, len(devices)) self.assertLen(res_jax.addressable_shards, len(devices))
# Test reloaded execution. # Test reloaded execution.
f_r = export.call_exported(exp) f_r = export.call(exp)
with self.assertRaisesRegex( with self.assertRaisesRegex(
Exception, Exception,
"Exported module .* was lowered for 2 devices and is " "Exported module .* was lowered for 2 devices and is "
@ -1241,14 +1241,14 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertEqual(exp_vjp2.nr_devices, 2) self.assertEqual(exp_vjp2.nr_devices, 2)
call_mesh = Mesh(jax.devices()[:2], "e") call_mesh = Mesh(jax.devices()[:2], "e")
g1 = pjit.pjit(export.call_exported(exp_vjp), g1 = pjit.pjit(export.call(exp_vjp),
in_shardings=(NamedSharding(call_mesh, None), in_shardings=(NamedSharding(call_mesh, None),
NamedSharding(call_mesh, None)))(x, x.T) NamedSharding(call_mesh, None)))(x, x.T)
_, f_jax_vjp = jax.vjp(f_jax, x) _, f_jax_vjp = jax.vjp(f_jax, x)
xbar = f_jax_vjp(x.T) xbar = f_jax_vjp(x.T)
self.assertAllClose(xbar, g1) self.assertAllClose(xbar, g1)
g2 = pjit.pjit(export.call_exported(exp_vjp2), g2 = pjit.pjit(export.call(exp_vjp2),
in_shardings=(NamedSharding(call_mesh, None), in_shardings=(NamedSharding(call_mesh, None),
NamedSharding(call_mesh, None), NamedSharding(call_mesh, None),
NamedSharding(call_mesh, None)))(x, x.T, x) NamedSharding(call_mesh, None)))(x, x.T, x)
@ -1299,7 +1299,7 @@ class JaxExportTest(jtu.JaxTestCase):
# Call with argument placed on different plaforms # Call with argument placed on different plaforms
for platform in self.__class__.platforms: for platform in self.__class__.platforms:
x_device = jax.device_put(x, jax.devices(platform)[0]) x_device = jax.device_put(x, jax.devices(platform)[0])
res_exp = export.call_exported(exp)(x_device) res_exp = export.call(exp)(x_device)
self.assertAllClose( self.assertAllClose(
res_exp, res_exp,
_testing_multi_platform_fun_expected(x, platform=platform)) _testing_multi_platform_fun_expected(x, platform=platform))
@ -1313,7 +1313,7 @@ class JaxExportTest(jtu.JaxTestCase):
# Now serialize the call to the exported using a different sequence of # Now serialize the call to the exported using a different sequence of
# lowering platforms, but included in the lowering platforms for the # lowering platforms, but included in the lowering platforms for the
# nested exported. # nested exported.
exp2 = get_exported(export.call_exported(exp), exp2 = get_exported(export.call(exp),
lowering_platforms=("cpu", "cuda","rocm"))(x) lowering_platforms=("cpu", "cuda","rocm"))(x)
# Ensure that we do not have multiple lowerings of the exported function # Ensure that we do not have multiple lowerings of the exported function
@ -1325,7 +1325,7 @@ class JaxExportTest(jtu.JaxTestCase):
for platform in self.__class__.platforms: for platform in self.__class__.platforms:
if platform == "tpu": continue if platform == "tpu": continue
x_device = jax.device_put(x, jax.devices(platform)[0]) x_device = jax.device_put(x, jax.devices(platform)[0])
res_exp = export.call_exported(exp2)(x_device) res_exp = export.call(exp2)(x_device)
self.assertAllClose( self.assertAllClose(
res_exp, res_exp,
_testing_multi_platform_fun_expected(np.sin(x), platform=platform)) _testing_multi_platform_fun_expected(np.sin(x), platform=platform))
@ -1337,11 +1337,11 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda", "rocm")) self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda", "rocm"))
# Now serialize the call for the current platform. # Now serialize the call for the current platform.
exp2 = get_exported(export.call_exported(exp))(x) exp2 = get_exported(export.call(exp))(x)
module_str = str(exp2.mlir_module()) module_str = str(exp2.mlir_module())
self.assertIn("jax.uses_shape_polymorphism = true", self.assertIn("jax.uses_shape_polymorphism = true",
module_str) module_str)
res2 = export.call_exported(exp2)(x) res2 = export.call(exp2)(x)
self.assertAllClose(res2, _testing_multi_platform_fun_expected(x)) self.assertAllClose(res2, _testing_multi_platform_fun_expected(x))
def test_multi_platform_and_poly(self): def test_multi_platform_and_poly(self):
@ -1353,11 +1353,11 @@ class JaxExportTest(jtu.JaxTestCase):
jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), np.float32) jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), np.float32)
) )
x = np.arange(12, dtype=np.float32).reshape((3, 4)) x = np.arange(12, dtype=np.float32).reshape((3, 4))
res = export.call_exported(exp)(x) res = export.call(exp)(x)
self.assertAllClose(res, _testing_multi_platform_fun_expected(x).reshape((-1,))) self.assertAllClose(res, _testing_multi_platform_fun_expected(x).reshape((-1,)))
# Now serialize the call to the exported # Now serialize the call to the exported
exp2 = get_exported(export.call_exported(exp))(x) exp2 = get_exported(export.call(exp))(x)
res2 = export.call_exported(exp2)(x) res2 = export.call(exp2)(x)
self.assertAllClose(res2, _testing_multi_platform_fun_expected(x).reshape((-1,))) self.assertAllClose(res2, _testing_multi_platform_fun_expected(x).reshape((-1,)))
def test_multi_platform_and_sharding(self): def test_multi_platform_and_sharding(self):
@ -1382,7 +1382,7 @@ class JaxExportTest(jtu.JaxTestCase):
continue continue
run_mesh = Mesh(run_devices, ("x",)) run_mesh = Mesh(run_devices, ("x",))
a_device = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, None)) a_device = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, None))
res_exp = export.call_exported(exp)(a_device) res_exp = export.call(exp)(a_device)
self.assertArraysAllClose(res_native, res_exp) self.assertArraysAllClose(res_native, res_exp)
@jtu.parameterized_filterable( @jtu.parameterized_filterable(
@ -1449,7 +1449,7 @@ class JaxExportTest(jtu.JaxTestCase):
x, effect_class_name="ForTestingOrderedEffect2") + x, effect_class_name="ForTestingOrderedEffect2") +
testing_primitive_with_effect_p.bind( testing_primitive_with_effect_p.bind(
x, effect_class_name="ForTestingUnorderedEffect1") + x, effect_class_name="ForTestingUnorderedEffect1") +
export.call_exported(exp)(x)) export.call(exp)(x))
lowered_outer = jax.jit(f_outer).lower(x) lowered_outer = jax.jit(f_outer).lower(x)
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"], self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
@ -1497,7 +1497,7 @@ class JaxExportTest(jtu.JaxTestCase):
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)") r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
self.assertRegex(mlir_module_str, main_expected_re) self.assertRegex(mlir_module_str, main_expected_re)
res = export.call_exported(exp)(x) res = export.call(exp)(x)
self.assertAllClose(10. + 2. * x, res) self.assertAllClose(10. + 2. * x, res)
@jtu.parameterized_filterable( @jtu.parameterized_filterable(
@ -1541,7 +1541,7 @@ class JaxExportTest(jtu.JaxTestCase):
# Results # Results
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)") r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
self.assertRegex(mlir_module_str, main_expected_re) self.assertRegex(mlir_module_str, main_expected_re)
res = export.call_exported(exp)(x) res = export.call(exp)(x)
self.assertAllClose(10. + _testing_multi_platform_fun_expected(x), self.assertAllClose(10. + _testing_multi_platform_fun_expected(x),
res) res)
@ -1608,7 +1608,7 @@ class JaxExportTest(jtu.JaxTestCase):
jax.ShapeDtypeStruct(rhs.shape, dtype=rhs.dtype), jax.ShapeDtypeStruct(rhs.shape, dtype=rhs.dtype),
jax.ShapeDtypeStruct(group_sizes.shape, dtype=group_sizes.dtype), jax.ShapeDtypeStruct(group_sizes.shape, dtype=group_sizes.dtype),
) )
res_exported = export.call_exported(exp_f)(lhs, rhs, group_sizes) res_exported = export.call(exp_f)(lhs, rhs, group_sizes)
self.assertAllClose(res_native, res_exported) self.assertAllClose(res_native, res_exported)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1283,7 +1283,7 @@ class PolyHarness(Harness):
return None return None
# Run the JAX natively and then the exported function and compare # Run the JAX natively and then the exported function and compare
res_jax_native = f_jax(*args) res_jax_native = f_jax(*args)
res_jax_exported = export.call_exported(exp)(*args) res_jax_exported = export.call(exp)(*args)
custom_assert_lims = [ custom_assert_lims = [
l for l in self.limitations if l.custom_assert is not None] l for l in self.limitations if l.custom_assert is not None]
assert len(custom_assert_lims) <= 1, custom_assert_lims assert len(custom_assert_lims) <= 1, custom_assert_lims
@ -1408,7 +1408,7 @@ class ShapePolyTest(jtu.JaxTestCase):
def f_jax(x, *, y): def f_jax(x, *, y):
return x + jnp.sin(y) return x + jnp.sin(y)
f_exported = export.call_exported( f_exported = export.call(
export.export(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), export.export(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
x.dtype), x.dtype),
y=jax.ShapeDtypeStruct(y.shape, y.dtype))) y=jax.ShapeDtypeStruct(y.shape, y.dtype)))
@ -1633,22 +1633,22 @@ class ShapePolyTest(jtu.JaxTestCase):
exp = export.export(f)(x_spec) exp = export.export(f)(x_spec)
x_2 = np.arange(2, dtype=np.int32) x_2 = np.arange(2, dtype=np.int32)
res_2 = export.call_exported(exp)(x_2) res_2 = export.call(exp)(x_2)
self.assertAllClose(x_2[0:2], res_2) self.assertAllClose(x_2[0:2], res_2)
x_4 = np.arange(4, dtype=np.int32) x_4 = np.arange(4, dtype=np.int32)
res_4 = export.call_exported(exp)(x_4) res_4 = export.call(exp)(x_4)
self.assertAllClose(x_4[1:3], res_4) self.assertAllClose(x_4[1:3], res_4)
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, ValueError,
re.escape("Expected 'a - 2' to be greater or equal to 0, but found -1")): re.escape("Expected 'a - 2' to be greater or equal to 0, but found -1")):
export.call_exported(exp)(np.arange(1, dtype=np.int32)) export.call(exp)(np.arange(1, dtype=np.int32))
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, ValueError,
re.escape("Expected '- a + 4' to be greater or equal to 0, but found -1")): re.escape("Expected '- a + 4' to be greater or equal to 0, but found -1")):
export.call_exported(exp)(np.arange(5, dtype=np.int32)) export.call(exp)(np.arange(5, dtype=np.int32))
def test_caching_with_scopes(self): def test_caching_with_scopes(self):
f_tracing_count = 0 f_tracing_count = 0