mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[export] Remove old deprecated APIs for jax.experimental.export.
See CHANGELOG.md. The deprecation period has passed. Also replace deprecated .call_exported with .call in tests. PiperOrigin-RevId: 641236222
This commit is contained in:
parent
5d6413cecc
commit
3914cb415d
@ -13,6 +13,11 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
supporting a monolithic CUDA jaxlib. Future releases will use the CUDA
|
||||
plugin jaxlib (e.g. `pip install jax[cuda12]`).
|
||||
* JAX now requires ml_dtypes version 0.4.0 or newer.
|
||||
* Removed backwards-compatibility support for old usage of the
|
||||
`jax.experimental.export` API. It is not possible anymore to use
|
||||
`from jax.experimental.export import export`, and instead you should use
|
||||
`from jax.experimental import export`.
|
||||
The removed functionality has been deprecated since 0.4.24.
|
||||
|
||||
* Deprecations
|
||||
* `jax.sharding.XLACompatibleSharding` is deprecated. Please use
|
||||
|
@ -345,4 +345,4 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
|
||||
_get_vjp=_get_vjp)
|
||||
|
||||
# We use pjit in case there are shardings in the exported module.
|
||||
return pjit.pjit(export.call_exported(exported))(*data.inputs)
|
||||
return pjit.pjit(export.call(exported))(*data.inputs)
|
||||
|
@ -22,8 +22,6 @@ from jax.experimental.export._export import (
|
||||
call,
|
||||
DisabledSafetyCheck,
|
||||
default_lowering_platform,
|
||||
|
||||
args_specs, # TODO: deprecate
|
||||
)
|
||||
from jax._src.export.shape_poly import (
|
||||
is_symbolic_dim,
|
||||
|
@ -1334,30 +1334,3 @@ def wrap_with_sharding(ctx: mlir.LoweringRuleContext,
|
||||
return x
|
||||
return mlir.wrap_with_sharding_op(
|
||||
ctx, x, x_aval, x_sharding.to_proto())
|
||||
|
||||
# TODO(necula): Previously, we had `from jax.experimental.export import export`
|
||||
# Now we want to simplify the usage, and export the public APIs directly
|
||||
# from `jax.experimental.export` and now `jax.experimental.export.export`
|
||||
# refers to the `export` function. Since there may still be users of the
|
||||
# old API in other packages, we add the old public API as attributes of the
|
||||
# exported function. We will clean this up after a deprecation period.
|
||||
def wrap_with_deprecation_warning(f):
|
||||
msg = (f"You are using function `{f.__name__}` from "
|
||||
"`jax.experimental.export.export`. You should instead use it directly "
|
||||
"from `jax.experimental.export`. Instead of "
|
||||
"`from jax.experimental.export import export` you should use "
|
||||
"`from jax.experimental import export`.")
|
||||
def wrapped_f(*args, **kwargs):
|
||||
warnings.warn(msg, DeprecationWarning, stacklevel=2)
|
||||
return f(*args, **kwargs)
|
||||
return wrapped_f
|
||||
|
||||
export.export = wrap_with_deprecation_warning(export)
|
||||
export.Exported = Exported
|
||||
export.call_exported = wrap_with_deprecation_warning(call_exported)
|
||||
export.DisabledSafetyCheck = DisabledSafetyCheck
|
||||
export.default_lowering_platform = wrap_with_deprecation_warning(default_lowering_platform)
|
||||
export.symbolic_shape = wrap_with_deprecation_warning(shape_poly.symbolic_shape)
|
||||
export.args_specs = wrap_with_deprecation_warning(args_specs)
|
||||
export.minimum_supported_serialization_version = minimum_supported_serialization_version
|
||||
export.maximum_supported_serialization_version = maximum_supported_serialization_version
|
||||
|
@ -787,7 +787,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
|
||||
logging.info("Running harness natively on %s", jax_device)
|
||||
native_res = f_jax(x_device)
|
||||
logging.info("Running exported harness on %s", jax_device)
|
||||
exported_res = export.call_exported(exp)(x_device)
|
||||
exported_res = export.call(exp)(x_device)
|
||||
self.assertAllClose(native_res, exported_res)
|
||||
|
||||
def test_multi_platform_call_tf_graph(self):
|
||||
|
@ -164,7 +164,7 @@ class PrimitiveTest(jtu.JaxTestCase):
|
||||
logging.info("Running harness natively on %s", device)
|
||||
native_res = func_jax(*device_args)
|
||||
logging.info("Running exported harness on %s", device)
|
||||
exported_res = export.call_exported(exp)(*device_args)
|
||||
exported_res = export.call(exp)(*device_args)
|
||||
if tol is not None:
|
||||
logging.info(f"Using non-standard tolerance {tol}")
|
||||
self.assertAllClose(native_res, exported_res, atol=tol, rtol=tol)
|
||||
|
@ -196,7 +196,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
x = np.arange(4, dtype=np.float32)
|
||||
exp_f = get_exported(f)(x)
|
||||
|
||||
f1 = export.call_exported(exp_f)
|
||||
f1 = export.call(exp_f)
|
||||
self.assertAllClose(f(x), f1(x))
|
||||
|
||||
def test_jit_static_arg(self):
|
||||
@ -210,7 +210,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
x = np.arange(4, dtype=np.float32)
|
||||
exp_f = get_exported(f)(x, c=0.1)
|
||||
|
||||
f1 = export.call_exported(exp_f)
|
||||
f1 = export.call(exp_f)
|
||||
self.assertAllClose(f(x, c=0.1), f1(x))
|
||||
|
||||
with self.subTest("static_argnums"):
|
||||
@ -222,7 +222,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
x = np.arange(4, dtype=np.float32)
|
||||
exp_g = get_exported(g)(x, 0.1)
|
||||
|
||||
g1 = export.call_exported(exp_g)
|
||||
g1 = export.call(exp_g)
|
||||
self.assertAllClose(g(x, 0.1), g1(x))
|
||||
|
||||
def test_call_exported_lambda(self):
|
||||
@ -230,7 +230,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
f = lambda x: jnp.sin(x)
|
||||
x = np.arange(4, dtype=np.float32)
|
||||
exp_f = get_exported(f)(x)
|
||||
f1 = export.call_exported(exp_f)
|
||||
f1 = export.call(exp_f)
|
||||
self.assertAllClose(f(x), f1(x))
|
||||
|
||||
def test_call_name_conflict(self):
|
||||
@ -258,7 +258,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f1(x):
|
||||
exp_f = get_exported(f)(x)
|
||||
return export.call_exported(exp_f)(x) + export.call_exported(exp_f)(x)
|
||||
return export.call(exp_f)(x) + export.call(exp_f)(x)
|
||||
|
||||
self.assertAllClose(2. * f(x), f1(x))
|
||||
|
||||
@ -268,7 +268,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
y = np.arange(6, dtype=np.float32)
|
||||
exp_f = get_exported(f)(x, y)
|
||||
|
||||
f1 = export.call_exported(exp_f)
|
||||
f1 = export.call(exp_f)
|
||||
self.assertAllClose(f(x, y), f1(x, y))
|
||||
|
||||
def test_pytree(self):
|
||||
@ -278,7 +278,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))
|
||||
|
||||
exp_f = get_exported(f)((a, b), a=a, b=b)
|
||||
f1 = export.call_exported(exp_f)
|
||||
f1 = export.call(exp_f)
|
||||
self.assertAllClose(f((a, b), a=a, b=b),
|
||||
f1((a, b), a=a, b=b))
|
||||
|
||||
@ -291,7 +291,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"The invocation args and kwargs must have the same pytree structure"):
|
||||
export.call_exported(exp_f)(a, b, c=(a, b))
|
||||
export.call(exp_f)(a, b, c=(a, b))
|
||||
|
||||
def test_error_wrong_avals(self):
|
||||
def f(a, *, b): # a: f32[4] and b: f32[4]
|
||||
@ -301,19 +301,19 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"Shape mismatch for args\[0\].shape\[0\]"):
|
||||
export.call_exported(exp_f)(np.arange(6, dtype=np.float32), b=f32_4)
|
||||
export.call(exp_f)(np.arange(6, dtype=np.float32), b=f32_4)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"Shape mismatch for kwargs\['b'\].shape\[0\]"):
|
||||
export.call_exported(exp_f)(f32_4, b=np.arange(6, dtype=np.float32))
|
||||
export.call(exp_f)(f32_4, b=np.arange(6, dtype=np.float32))
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"Rank mismatch for args\[0\]"):
|
||||
export.call_exported(exp_f)(f32_4.reshape((1, 4)), b=f32_4)
|
||||
export.call(exp_f)(f32_4.reshape((1, 4)), b=f32_4)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"Dtype mismatch for args\[0\]"):
|
||||
export.call_exported(exp_f)(f32_4.astype(np.float16), b=f32_4)
|
||||
export.call(exp_f)(f32_4.astype(np.float16), b=f32_4)
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
testcase_name=lambda kw: kw["platform"],
|
||||
@ -328,13 +328,13 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "The exported function .* was lowered for platform"):
|
||||
export.call_exported(exp_f)(a)
|
||||
export.call(exp_f)(a)
|
||||
|
||||
# Now try with the platform check disabled
|
||||
exp_f_no_platform_check = get_exported(
|
||||
jnp.sin, lowering_platforms=(platform,),
|
||||
disabled_checks=[export.DisabledSafetyCheck.platform()])(a)
|
||||
res = export.call_exported(exp_f_no_platform_check)(a)
|
||||
res = export.call(exp_f_no_platform_check)(a)
|
||||
self.assertAllClose(res, jnp.sin(a))
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
@ -394,7 +394,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
x = np.arange(4, dtype=np.float32)
|
||||
exp_f = get_exported(f, vjp_order=1)(x)
|
||||
|
||||
f1 = export.call_exported(exp_f)
|
||||
f1 = export.call(exp_f)
|
||||
self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x))
|
||||
|
||||
def test_higher_order_grad(self):
|
||||
@ -402,7 +402,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
x = np.float32(4.)
|
||||
exp_f = get_exported(f, vjp_order=3)(x)
|
||||
|
||||
f1 = export.call_exported(exp_f)
|
||||
f1 = export.call(exp_f)
|
||||
self.assertAllClose(jax.grad(f)(x),
|
||||
jax.grad(f1)(x))
|
||||
self.assertAllClose(jax.grad(jax.grad(f))(x),
|
||||
@ -429,7 +429,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
(f_outi_ct2, f_outf_ct2), = f_vjp2((xi_ct, xf_ct))
|
||||
|
||||
exp = get_exported(f, vjp_order=2)(xi, xf)
|
||||
fr = export.call_exported(exp)
|
||||
fr = export.call(exp)
|
||||
|
||||
res = fr(xi, xf)
|
||||
self.assertAllClose(res, (f_outi, f_outf))
|
||||
@ -463,7 +463,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
res = f((a, b), a=a, b=b)
|
||||
return res
|
||||
def f1_exp(a, b): # For VJP, make a function without kwargs
|
||||
res = export.call_exported(exp_f)((a, b), a=a, b=b)
|
||||
res = export.call(exp_f)((a, b), a=a, b=b)
|
||||
return res
|
||||
jax_vjp = jax.vjp(f1_jax, a, b)[1](out_ct)
|
||||
exp_vjp = jax.vjp(f1_exp, a, b)[1](out_ct)
|
||||
@ -475,13 +475,13 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
a = np.arange(4, dtype=np.float32)
|
||||
exp_f1 = get_exported(f1)(a)
|
||||
def f2(x):
|
||||
res1 = export.call_exported(exp_f1)(x)
|
||||
res2 = export.call_exported(exp_f1)(res1)
|
||||
res1 = export.call(exp_f1)(x)
|
||||
res2 = export.call(exp_f1)(res1)
|
||||
return jnp.cos(res2)
|
||||
exp_f2 = get_exported(f2)(a)
|
||||
|
||||
self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))),
|
||||
export.call_exported(exp_f2)(a))
|
||||
export.call(exp_f2)(a))
|
||||
|
||||
def test_poly_export_only(self):
|
||||
a = np.arange(12, dtype=np.float32).reshape((3, 4))
|
||||
@ -590,11 +590,11 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
exp = get_exported(jnp.sin)(
|
||||
jax.ShapeDtypeStruct(export.symbolic_shape("w, h"), np.float32))
|
||||
x = np.arange(30, dtype=np.float32).reshape((5, 6))
|
||||
res = export.call_exported(exp)(x)
|
||||
res = export.call(exp)(x)
|
||||
self.assertAllClose(res, np.sin(x))
|
||||
|
||||
# A function is exported with f32[poly_spec] and is called with different arg
|
||||
# shapes. We use export.call_exported and we also run the shape check
|
||||
# shapes. We use export.call and we also run the shape check
|
||||
# module.
|
||||
@jtu.parameterized_filterable(
|
||||
testcase_name=lambda kw:f"poly_spec={kw['poly_spec']}_arg_shape={kw['arg_shape']}", # type: ignore
|
||||
@ -642,7 +642,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
stack.push(self.assertRaisesRegex(Exception, expect_error))
|
||||
|
||||
assert core.is_constant_shape(arg.shape)
|
||||
res = export.call_exported(exp_f)(arg)
|
||||
res = export.call(exp_f)(arg)
|
||||
|
||||
if not expect_error:
|
||||
self.assertAllClose(res, f(arg))
|
||||
@ -741,7 +741,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
def outer(x): # x: outer_poly_spec
|
||||
# Use an addition to test that the shapes are refined properly for the
|
||||
# result of the call_exported.
|
||||
return export.call_exported(inner_exp)(x) + inner(x)
|
||||
return export.call(inner_exp)(x) + inner(x)
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
if expect_error_outer_exp is not None:
|
||||
@ -761,7 +761,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
if expect_error_run is not None:
|
||||
stack.push(self.assertRaisesRegex(Exception, expect_error_run))
|
||||
|
||||
res = export.call_exported(outer_exp)(arg)
|
||||
res = export.call(outer_exp)(arg)
|
||||
|
||||
if expect_error_run is not None:
|
||||
return
|
||||
@ -825,7 +825,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error)))
|
||||
exp = get_exported(f_jax)(
|
||||
jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), x.dtype))
|
||||
export.call_exported(exp)(x)
|
||||
export.call(exp)(x)
|
||||
|
||||
def test_poly_booleans(self):
|
||||
# For booleans we use a special case ConvertOp to cast to and from
|
||||
@ -836,7 +836,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
x = np.array([True, False, True, False], dtype=np.bool_)
|
||||
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
|
||||
x.dtype))
|
||||
res = export.call_exported(exp)(x)
|
||||
res = export.call(exp)(x)
|
||||
self.assertAllClose(f_jax(x), res)
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
@ -857,7 +857,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
x = np.arange(6, dtype=dtype)
|
||||
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
|
||||
x.dtype))
|
||||
res = export.call_exported(exp)(x)
|
||||
res = export.call(exp)(x)
|
||||
self.assertAllClose(f_jax(x), res)
|
||||
|
||||
def test_poly_expressions(self):
|
||||
@ -874,12 +874,12 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
exp = get_exported(f)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
|
||||
x.dtype))
|
||||
# Call with static shapes
|
||||
res = export.call_exported(exp)(x)
|
||||
res = export.call(exp)(x)
|
||||
self.assertAllClose(res, f(x))
|
||||
|
||||
# Now re-export with shape polymorphism
|
||||
x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("a"), x.dtype)
|
||||
exp2 = get_exported(export.call_exported(exp))(x_spec)
|
||||
exp2 = get_exported(export.call(exp))(x_spec)
|
||||
a = exp2.in_avals[0].shape[0]
|
||||
self.assertEqual(exp2.out_avals[0].shape, output_shape(a))
|
||||
|
||||
@ -892,7 +892,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
a, = export.symbolic_shape("a")
|
||||
exp = export.export(f)(
|
||||
jax.ShapeDtypeStruct((a, 4), np.float32))
|
||||
f_exp = export.call_exported(exp)
|
||||
f_exp = export.call(exp)
|
||||
x_jit = np.arange(12, dtype=np.float32).reshape((3, 4))
|
||||
res_jit = jax.jit(f_exp)(x_jit)
|
||||
self.assertAllClose(res_jit, f(x_jit))
|
||||
@ -931,24 +931,24 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
# We apply the out_shardings for f_jax
|
||||
r".*custom_call @Sharding\(%1\).*mhlo.sharding = \"{devices=\[1,2\]<=\[2\]}\"}.*",
|
||||
re.DOTALL)
|
||||
hlo = jax.jit(export.call_exported(exp)).lower(a_device).as_text()
|
||||
hlo = jax.jit(export.call(exp)).lower(a_device).as_text()
|
||||
self.assertRegex(hlo, expected_re)
|
||||
|
||||
res_exported = export.call_exported(exp)(a_device)
|
||||
res_exported = export.call(exp)(a_device)
|
||||
self.assertAllClose(res_native, res_exported)
|
||||
|
||||
# Test error reporting
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
"Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
|
||||
_ = export.call_exported(exp)(a)
|
||||
_ = export.call(exp)(a)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
"Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
|
||||
mesh1 = Mesh(jax.devices()[0:1], axis_names=("x",))
|
||||
_ = jax.jit(
|
||||
export.call_exported(exp),
|
||||
export.call(exp),
|
||||
in_shardings=(jax.sharding.NamedSharding(mesh1, P("x", None)),)
|
||||
)(a)
|
||||
|
||||
@ -996,7 +996,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
run_mesh = Mesh(run_devices, "i")
|
||||
b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i")))
|
||||
|
||||
res_exported = export.call_exported(exp)(b)
|
||||
res_exported = export.call(exp)(b)
|
||||
self.assertAllClose(res_native, res_exported)
|
||||
|
||||
def test_call_with_different_no_of_devices_error_has_in_shardings(self):
|
||||
@ -1024,7 +1024,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
"Exported module .* was lowered for 1 devices and is called in a "
|
||||
f"context with {jax.local_device_count()} devices.* module contains "
|
||||
"non-replicated sharding annotations"):
|
||||
export.call_exported(exp)(b)
|
||||
export.call(exp)(b)
|
||||
|
||||
def test_call_with_different_no_of_devices_pmap(self):
|
||||
if len(jax.devices()) < 2:
|
||||
@ -1042,7 +1042,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
b = jnp.arange(jax.device_count() * 100, dtype=jnp.float32).reshape(
|
||||
(-1, 1, 100)
|
||||
)
|
||||
res_exported = jax.pmap(export.call_exported(exp))(b)
|
||||
res_exported = jax.pmap(export.call(exp))(b)
|
||||
self.assertAllClose(res_native, res_exported[0])
|
||||
|
||||
def test_call_with_different_no_of_devices_error_has_sharding_constraint(self):
|
||||
@ -1070,7 +1070,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
"Exported module .* was lowered for 1 devices and is called in a "
|
||||
f"context with {jax.local_device_count()} devices.* module contains "
|
||||
"non-replicated sharding annotations"):
|
||||
export.call_exported(exp)(b)
|
||||
export.call(exp)(b)
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
@ -1108,7 +1108,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
self.assertLen(res_jax.addressable_shards, len(devices))
|
||||
|
||||
# Test reloaded execution.
|
||||
f_r = export.call_exported(exp)
|
||||
f_r = export.call(exp)
|
||||
with self.assertRaisesRegex(
|
||||
Exception,
|
||||
"Exported module .* was lowered for 2 devices and is "
|
||||
@ -1241,14 +1241,14 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
self.assertEqual(exp_vjp2.nr_devices, 2)
|
||||
call_mesh = Mesh(jax.devices()[:2], "e")
|
||||
|
||||
g1 = pjit.pjit(export.call_exported(exp_vjp),
|
||||
g1 = pjit.pjit(export.call(exp_vjp),
|
||||
in_shardings=(NamedSharding(call_mesh, None),
|
||||
NamedSharding(call_mesh, None)))(x, x.T)
|
||||
_, f_jax_vjp = jax.vjp(f_jax, x)
|
||||
xbar = f_jax_vjp(x.T)
|
||||
self.assertAllClose(xbar, g1)
|
||||
|
||||
g2 = pjit.pjit(export.call_exported(exp_vjp2),
|
||||
g2 = pjit.pjit(export.call(exp_vjp2),
|
||||
in_shardings=(NamedSharding(call_mesh, None),
|
||||
NamedSharding(call_mesh, None),
|
||||
NamedSharding(call_mesh, None)))(x, x.T, x)
|
||||
@ -1299,7 +1299,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
# Call with argument placed on different plaforms
|
||||
for platform in self.__class__.platforms:
|
||||
x_device = jax.device_put(x, jax.devices(platform)[0])
|
||||
res_exp = export.call_exported(exp)(x_device)
|
||||
res_exp = export.call(exp)(x_device)
|
||||
self.assertAllClose(
|
||||
res_exp,
|
||||
_testing_multi_platform_fun_expected(x, platform=platform))
|
||||
@ -1313,7 +1313,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
# Now serialize the call to the exported using a different sequence of
|
||||
# lowering platforms, but included in the lowering platforms for the
|
||||
# nested exported.
|
||||
exp2 = get_exported(export.call_exported(exp),
|
||||
exp2 = get_exported(export.call(exp),
|
||||
lowering_platforms=("cpu", "cuda","rocm"))(x)
|
||||
|
||||
# Ensure that we do not have multiple lowerings of the exported function
|
||||
@ -1325,7 +1325,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
for platform in self.__class__.platforms:
|
||||
if platform == "tpu": continue
|
||||
x_device = jax.device_put(x, jax.devices(platform)[0])
|
||||
res_exp = export.call_exported(exp2)(x_device)
|
||||
res_exp = export.call(exp2)(x_device)
|
||||
self.assertAllClose(
|
||||
res_exp,
|
||||
_testing_multi_platform_fun_expected(np.sin(x), platform=platform))
|
||||
@ -1337,11 +1337,11 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda", "rocm"))
|
||||
|
||||
# Now serialize the call for the current platform.
|
||||
exp2 = get_exported(export.call_exported(exp))(x)
|
||||
exp2 = get_exported(export.call(exp))(x)
|
||||
module_str = str(exp2.mlir_module())
|
||||
self.assertIn("jax.uses_shape_polymorphism = true",
|
||||
module_str)
|
||||
res2 = export.call_exported(exp2)(x)
|
||||
res2 = export.call(exp2)(x)
|
||||
self.assertAllClose(res2, _testing_multi_platform_fun_expected(x))
|
||||
|
||||
def test_multi_platform_and_poly(self):
|
||||
@ -1353,11 +1353,11 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), np.float32)
|
||||
)
|
||||
x = np.arange(12, dtype=np.float32).reshape((3, 4))
|
||||
res = export.call_exported(exp)(x)
|
||||
res = export.call(exp)(x)
|
||||
self.assertAllClose(res, _testing_multi_platform_fun_expected(x).reshape((-1,)))
|
||||
# Now serialize the call to the exported
|
||||
exp2 = get_exported(export.call_exported(exp))(x)
|
||||
res2 = export.call_exported(exp2)(x)
|
||||
exp2 = get_exported(export.call(exp))(x)
|
||||
res2 = export.call(exp2)(x)
|
||||
self.assertAllClose(res2, _testing_multi_platform_fun_expected(x).reshape((-1,)))
|
||||
|
||||
def test_multi_platform_and_sharding(self):
|
||||
@ -1382,7 +1382,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
continue
|
||||
run_mesh = Mesh(run_devices, ("x",))
|
||||
a_device = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, None))
|
||||
res_exp = export.call_exported(exp)(a_device)
|
||||
res_exp = export.call(exp)(a_device)
|
||||
self.assertArraysAllClose(res_native, res_exp)
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
@ -1449,7 +1449,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
x, effect_class_name="ForTestingOrderedEffect2") +
|
||||
testing_primitive_with_effect_p.bind(
|
||||
x, effect_class_name="ForTestingUnorderedEffect1") +
|
||||
export.call_exported(exp)(x))
|
||||
export.call(exp)(x))
|
||||
|
||||
lowered_outer = jax.jit(f_outer).lower(x)
|
||||
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
|
||||
@ -1497,7 +1497,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
|
||||
self.assertRegex(mlir_module_str, main_expected_re)
|
||||
|
||||
res = export.call_exported(exp)(x)
|
||||
res = export.call(exp)(x)
|
||||
self.assertAllClose(10. + 2. * x, res)
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
@ -1541,7 +1541,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
# Results
|
||||
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
|
||||
self.assertRegex(mlir_module_str, main_expected_re)
|
||||
res = export.call_exported(exp)(x)
|
||||
res = export.call(exp)(x)
|
||||
self.assertAllClose(10. + _testing_multi_platform_fun_expected(x),
|
||||
res)
|
||||
|
||||
@ -1608,7 +1608,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
jax.ShapeDtypeStruct(rhs.shape, dtype=rhs.dtype),
|
||||
jax.ShapeDtypeStruct(group_sizes.shape, dtype=group_sizes.dtype),
|
||||
)
|
||||
res_exported = export.call_exported(exp_f)(lhs, rhs, group_sizes)
|
||||
res_exported = export.call(exp_f)(lhs, rhs, group_sizes)
|
||||
self.assertAllClose(res_native, res_exported)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1283,7 +1283,7 @@ class PolyHarness(Harness):
|
||||
return None
|
||||
# Run the JAX natively and then the exported function and compare
|
||||
res_jax_native = f_jax(*args)
|
||||
res_jax_exported = export.call_exported(exp)(*args)
|
||||
res_jax_exported = export.call(exp)(*args)
|
||||
custom_assert_lims = [
|
||||
l for l in self.limitations if l.custom_assert is not None]
|
||||
assert len(custom_assert_lims) <= 1, custom_assert_lims
|
||||
@ -1408,7 +1408,7 @@ class ShapePolyTest(jtu.JaxTestCase):
|
||||
def f_jax(x, *, y):
|
||||
return x + jnp.sin(y)
|
||||
|
||||
f_exported = export.call_exported(
|
||||
f_exported = export.call(
|
||||
export.export(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
|
||||
x.dtype),
|
||||
y=jax.ShapeDtypeStruct(y.shape, y.dtype)))
|
||||
@ -1633,22 +1633,22 @@ class ShapePolyTest(jtu.JaxTestCase):
|
||||
exp = export.export(f)(x_spec)
|
||||
|
||||
x_2 = np.arange(2, dtype=np.int32)
|
||||
res_2 = export.call_exported(exp)(x_2)
|
||||
res_2 = export.call(exp)(x_2)
|
||||
self.assertAllClose(x_2[0:2], res_2)
|
||||
|
||||
x_4 = np.arange(4, dtype=np.int32)
|
||||
res_4 = export.call_exported(exp)(x_4)
|
||||
res_4 = export.call(exp)(x_4)
|
||||
self.assertAllClose(x_4[1:3], res_4)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
re.escape("Expected 'a - 2' to be greater or equal to 0, but found -1")):
|
||||
export.call_exported(exp)(np.arange(1, dtype=np.int32))
|
||||
export.call(exp)(np.arange(1, dtype=np.int32))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
re.escape("Expected '- a + 4' to be greater or equal to 0, but found -1")):
|
||||
export.call_exported(exp)(np.arange(5, dtype=np.int32))
|
||||
export.call(exp)(np.arange(5, dtype=np.int32))
|
||||
|
||||
def test_caching_with_scopes(self):
|
||||
f_tracing_count = 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user