diff --git a/CHANGELOG.md b/CHANGELOG.md index ceb7b4912..ecad73d27 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,11 @@ Remember to align the itemized text with the first line of an item within a list supporting a monolithic CUDA jaxlib. Future releases will use the CUDA plugin jaxlib (e.g. `pip install jax[cuda12]`). * JAX now requires ml_dtypes version 0.4.0 or newer. + * Removed backwards-compatibility support for old usage of the + `jax.experimental.export` API. It is no longer possible to use + `from jax.experimental.export import export`; instead you should use + `from jax.experimental import export`. + The removed functionality has been deprecated since 0.4.24. * Deprecations * `jax.sharding.XLACompatibleSharding` is deprecated. Please use diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py index 7235c0ce2..e7d5d1931 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_util.py +++ b/jax/_src/internal_test_util/export_back_compat_test_util.py @@ -345,4 +345,4 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict( _get_vjp=_get_vjp) # We use pjit in case there are shardings in the exported module. 
- return pjit.pjit(export.call_exported(exported))(*data.inputs) + return pjit.pjit(export.call(exported))(*data.inputs) diff --git a/jax/experimental/export/__init__.py b/jax/experimental/export/__init__.py index e46a7ae4c..e6e298305 100644 --- a/jax/experimental/export/__init__.py +++ b/jax/experimental/export/__init__.py @@ -22,8 +22,6 @@ from jax.experimental.export._export import ( call, DisabledSafetyCheck, default_lowering_platform, - - args_specs, # TODO: deprecate ) from jax._src.export.shape_poly import ( is_symbolic_dim, diff --git a/jax/experimental/export/_export.py b/jax/experimental/export/_export.py index c0e325382..9e3a24b1d 100644 --- a/jax/experimental/export/_export.py +++ b/jax/experimental/export/_export.py @@ -1334,30 +1334,3 @@ def wrap_with_sharding(ctx: mlir.LoweringRuleContext, return x return mlir.wrap_with_sharding_op( ctx, x, x_aval, x_sharding.to_proto()) - -# TODO(necula): Previously, we had `from jax.experimental.export import export` -# Now we want to simplify the usage, and export the public APIs directly -# from `jax.experimental.export` and now `jax.experimental.export.export` -# refers to the `export` function. Since there may still be users of the -# old API in other packages, we add the old public API as attributes of the -# exported function. We will clean this up after a deprecation period. -def wrap_with_deprecation_warning(f): - msg = (f"You are using function `{f.__name__}` from " - "`jax.experimental.export.export`. You should instead use it directly " - "from `jax.experimental.export`. 
Instead of " - "`from jax.experimental.export import export` you should use " - "`from jax.experimental import export`.") - def wrapped_f(*args, **kwargs): - warnings.warn(msg, DeprecationWarning, stacklevel=2) - return f(*args, **kwargs) - return wrapped_f - -export.export = wrap_with_deprecation_warning(export) -export.Exported = Exported -export.call_exported = wrap_with_deprecation_warning(call_exported) -export.DisabledSafetyCheck = DisabledSafetyCheck -export.default_lowering_platform = wrap_with_deprecation_warning(default_lowering_platform) -export.symbolic_shape = wrap_with_deprecation_warning(shape_poly.symbolic_shape) -export.args_specs = wrap_with_deprecation_warning(args_specs) -export.minimum_supported_serialization_version = minimum_supported_serialization_version -export.maximum_supported_serialization_version = maximum_supported_serialization_version diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index 415d4a7c9..e7c997710 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -787,7 +787,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase): logging.info("Running harness natively on %s", jax_device) native_res = f_jax(x_device) logging.info("Running exported harness on %s", jax_device) - exported_res = export.call_exported(exp)(x_device) + exported_res = export.call(exp)(x_device) self.assertAllClose(native_res, exported_res) def test_multi_platform_call_tf_graph(self): diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py index 44a3070f8..89ecd16fe 100644 --- a/tests/export_harnesses_multi_platform_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -164,7 +164,7 @@ class PrimitiveTest(jtu.JaxTestCase): logging.info("Running harness natively on %s", device) native_res = func_jax(*device_args) logging.info("Running exported harness on %s", device) - exported_res = 
export.call_exported(exp)(*device_args) + exported_res = export.call(exp)(*device_args) if tol is not None: logging.info(f"Using non-standard tolerance {tol}") self.assertAllClose(native_res, exported_res, atol=tol, rtol=tol) diff --git a/tests/export_test.py b/tests/export_test.py index 73dee5a20..556f9d363 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -196,7 +196,7 @@ class JaxExportTest(jtu.JaxTestCase): x = np.arange(4, dtype=np.float32) exp_f = get_exported(f)(x) - f1 = export.call_exported(exp_f) + f1 = export.call(exp_f) self.assertAllClose(f(x), f1(x)) def test_jit_static_arg(self): @@ -210,7 +210,7 @@ class JaxExportTest(jtu.JaxTestCase): x = np.arange(4, dtype=np.float32) exp_f = get_exported(f)(x, c=0.1) - f1 = export.call_exported(exp_f) + f1 = export.call(exp_f) self.assertAllClose(f(x, c=0.1), f1(x)) with self.subTest("static_argnums"): @@ -222,7 +222,7 @@ class JaxExportTest(jtu.JaxTestCase): x = np.arange(4, dtype=np.float32) exp_g = get_exported(g)(x, 0.1) - g1 = export.call_exported(exp_g) + g1 = export.call(exp_g) self.assertAllClose(g(x, 0.1), g1(x)) def test_call_exported_lambda(self): @@ -230,7 +230,7 @@ class JaxExportTest(jtu.JaxTestCase): f = lambda x: jnp.sin(x) x = np.arange(4, dtype=np.float32) exp_f = get_exported(f)(x) - f1 = export.call_exported(exp_f) + f1 = export.call(exp_f) self.assertAllClose(f(x), f1(x)) def test_call_name_conflict(self): @@ -258,7 +258,7 @@ class JaxExportTest(jtu.JaxTestCase): @jax.jit def f1(x): exp_f = get_exported(f)(x) - return export.call_exported(exp_f)(x) + export.call_exported(exp_f)(x) + return export.call(exp_f)(x) + export.call(exp_f)(x) self.assertAllClose(2. 
* f(x), f1(x)) @@ -268,7 +268,7 @@ class JaxExportTest(jtu.JaxTestCase): y = np.arange(6, dtype=np.float32) exp_f = get_exported(f)(x, y) - f1 = export.call_exported(exp_f) + f1 = export.call(exp_f) self.assertAllClose(f(x, y), f1(x, y)) def test_pytree(self): @@ -278,7 +278,7 @@ class JaxExportTest(jtu.JaxTestCase): return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b)) exp_f = get_exported(f)((a, b), a=a, b=b) - f1 = export.call_exported(exp_f) + f1 = export.call(exp_f) self.assertAllClose(f((a, b), a=a, b=b), f1((a, b), a=a, b=b)) @@ -291,7 +291,7 @@ class JaxExportTest(jtu.JaxTestCase): with self.assertRaisesRegex( ValueError, "The invocation args and kwargs must have the same pytree structure"): - export.call_exported(exp_f)(a, b, c=(a, b)) + export.call(exp_f)(a, b, c=(a, b)) def test_error_wrong_avals(self): def f(a, *, b): # a: f32[4] and b: f32[4] @@ -301,19 +301,19 @@ class JaxExportTest(jtu.JaxTestCase): with self.assertRaisesRegex(ValueError, r"Shape mismatch for args\[0\].shape\[0\]"): - export.call_exported(exp_f)(np.arange(6, dtype=np.float32), b=f32_4) + export.call(exp_f)(np.arange(6, dtype=np.float32), b=f32_4) with self.assertRaisesRegex(ValueError, r"Shape mismatch for kwargs\['b'\].shape\[0\]"): - export.call_exported(exp_f)(f32_4, b=np.arange(6, dtype=np.float32)) + export.call(exp_f)(f32_4, b=np.arange(6, dtype=np.float32)) with self.assertRaisesRegex(ValueError, r"Rank mismatch for args\[0\]"): - export.call_exported(exp_f)(f32_4.reshape((1, 4)), b=f32_4) + export.call(exp_f)(f32_4.reshape((1, 4)), b=f32_4) with self.assertRaisesRegex(ValueError, r"Dtype mismatch for args\[0\]"): - export.call_exported(exp_f)(f32_4.astype(np.float16), b=f32_4) + export.call(exp_f)(f32_4.astype(np.float16), b=f32_4) @jtu.parameterized_filterable( testcase_name=lambda kw: kw["platform"], @@ -328,13 +328,13 @@ class JaxExportTest(jtu.JaxTestCase): with self.assertRaisesRegex( ValueError, "The exported function .* was lowered for platform"): - 
export.call_exported(exp_f)(a) + export.call(exp_f)(a) # Now try with the platform check disabled exp_f_no_platform_check = get_exported( jnp.sin, lowering_platforms=(platform,), disabled_checks=[export.DisabledSafetyCheck.platform()])(a) - res = export.call_exported(exp_f_no_platform_check)(a) + res = export.call(exp_f_no_platform_check)(a) self.assertAllClose(res, jnp.sin(a)) @jtu.parameterized_filterable( @@ -394,7 +394,7 @@ class JaxExportTest(jtu.JaxTestCase): x = np.arange(4, dtype=np.float32) exp_f = get_exported(f, vjp_order=1)(x) - f1 = export.call_exported(exp_f) + f1 = export.call(exp_f) self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x)) def test_higher_order_grad(self): @@ -402,7 +402,7 @@ class JaxExportTest(jtu.JaxTestCase): x = np.float32(4.) exp_f = get_exported(f, vjp_order=3)(x) - f1 = export.call_exported(exp_f) + f1 = export.call(exp_f) self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x)) self.assertAllClose(jax.grad(jax.grad(f))(x), @@ -429,7 +429,7 @@ class JaxExportTest(jtu.JaxTestCase): (f_outi_ct2, f_outf_ct2), = f_vjp2((xi_ct, xf_ct)) exp = get_exported(f, vjp_order=2)(xi, xf) - fr = export.call_exported(exp) + fr = export.call(exp) res = fr(xi, xf) self.assertAllClose(res, (f_outi, f_outf)) @@ -463,7 +463,7 @@ class JaxExportTest(jtu.JaxTestCase): res = f((a, b), a=a, b=b) return res def f1_exp(a, b): # For VJP, make a function without kwargs - res = export.call_exported(exp_f)((a, b), a=a, b=b) + res = export.call(exp_f)((a, b), a=a, b=b) return res jax_vjp = jax.vjp(f1_jax, a, b)[1](out_ct) exp_vjp = jax.vjp(f1_exp, a, b)[1](out_ct) @@ -475,13 +475,13 @@ class JaxExportTest(jtu.JaxTestCase): a = np.arange(4, dtype=np.float32) exp_f1 = get_exported(f1)(a) def f2(x): - res1 = export.call_exported(exp_f1)(x) - res2 = export.call_exported(exp_f1)(res1) + res1 = export.call(exp_f1)(x) + res2 = export.call(exp_f1)(res1) return jnp.cos(res2) exp_f2 = get_exported(f2)(a) self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))), - 
export.call_exported(exp_f2)(a)) + export.call(exp_f2)(a)) def test_poly_export_only(self): a = np.arange(12, dtype=np.float32).reshape((3, 4)) @@ -590,11 +590,11 @@ class JaxExportTest(jtu.JaxTestCase): exp = get_exported(jnp.sin)( jax.ShapeDtypeStruct(export.symbolic_shape("w, h"), np.float32)) x = np.arange(30, dtype=np.float32).reshape((5, 6)) - res = export.call_exported(exp)(x) + res = export.call(exp)(x) self.assertAllClose(res, np.sin(x)) # A function is exported with f32[poly_spec] and is called with different arg - # shapes. We use export.call_exported and we also run the shape check + # shapes. We use export.call and we also run the shape check # module. @jtu.parameterized_filterable( testcase_name=lambda kw:f"poly_spec={kw['poly_spec']}_arg_shape={kw['arg_shape']}", # type: ignore @@ -642,7 +642,7 @@ class JaxExportTest(jtu.JaxTestCase): stack.push(self.assertRaisesRegex(Exception, expect_error)) assert core.is_constant_shape(arg.shape) - res = export.call_exported(exp_f)(arg) + res = export.call(exp_f)(arg) if not expect_error: self.assertAllClose(res, f(arg)) @@ -741,7 +741,7 @@ class JaxExportTest(jtu.JaxTestCase): def outer(x): # x: outer_poly_spec # Use an addition to test that the shapes are refined properly for the # result of the call_exported. 
- return export.call_exported(inner_exp)(x) + inner(x) + return export.call(inner_exp)(x) + inner(x) with contextlib.ExitStack() as stack: if expect_error_outer_exp is not None: @@ -761,7 +761,7 @@ class JaxExportTest(jtu.JaxTestCase): if expect_error_run is not None: stack.push(self.assertRaisesRegex(Exception, expect_error_run)) - res = export.call_exported(outer_exp)(arg) + res = export.call(outer_exp)(arg) if expect_error_run is not None: return @@ -825,7 +825,7 @@ class JaxExportTest(jtu.JaxTestCase): stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error))) exp = get_exported(f_jax)( jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), x.dtype)) - export.call_exported(exp)(x) + export.call(exp)(x) def test_poly_booleans(self): # For booleans we use a special case ConvertOp to cast to and from @@ -836,7 +836,7 @@ class JaxExportTest(jtu.JaxTestCase): x = np.array([True, False, True, False], dtype=np.bool_) exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), x.dtype)) - res = export.call_exported(exp)(x) + res = export.call(exp)(x) self.assertAllClose(f_jax(x), res) @jtu.parameterized_filterable( @@ -857,7 +857,7 @@ class JaxExportTest(jtu.JaxTestCase): x = np.arange(6, dtype=dtype) exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), x.dtype)) - res = export.call_exported(exp)(x) + res = export.call(exp)(x) self.assertAllClose(f_jax(x), res) def test_poly_expressions(self): @@ -874,12 +874,12 @@ class JaxExportTest(jtu.JaxTestCase): exp = get_exported(f)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), x.dtype)) # Call with static shapes - res = export.call_exported(exp)(x) + res = export.call(exp)(x) self.assertAllClose(res, f(x)) # Now re-export with shape polymorphism x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("a"), x.dtype) - exp2 = get_exported(export.call_exported(exp))(x_spec) + exp2 = get_exported(export.call(exp))(x_spec) a = exp2.in_avals[0].shape[0] 
self.assertEqual(exp2.out_avals[0].shape, output_shape(a)) @@ -892,7 +892,7 @@ class JaxExportTest(jtu.JaxTestCase): a, = export.symbolic_shape("a") exp = export.export(f)( jax.ShapeDtypeStruct((a, 4), np.float32)) - f_exp = export.call_exported(exp) + f_exp = export.call(exp) x_jit = np.arange(12, dtype=np.float32).reshape((3, 4)) res_jit = jax.jit(f_exp)(x_jit) self.assertAllClose(res_jit, f(x_jit)) @@ -931,24 +931,24 @@ class JaxExportTest(jtu.JaxTestCase): # We apply the out_shardings for f_jax r".*custom_call @Sharding\(%1\).*mhlo.sharding = \"{devices=\[1,2\]<=\[2\]}\"}.*", re.DOTALL) - hlo = jax.jit(export.call_exported(exp)).lower(a_device).as_text() + hlo = jax.jit(export.call(exp)).lower(a_device).as_text() self.assertRegex(hlo, expected_re) - res_exported = export.call_exported(exp)(a_device) + res_exported = export.call(exp)(a_device) self.assertAllClose(res_native, res_exported) # Test error reporting with self.assertRaisesRegex( NotImplementedError, "Exported module .* was lowered for 2 devices and is called in a context with 1 device"): - _ = export.call_exported(exp)(a) + _ = export.call(exp)(a) with self.assertRaisesRegex( NotImplementedError, "Exported module .* was lowered for 2 devices and is called in a context with 1 device"): mesh1 = Mesh(jax.devices()[0:1], axis_names=("x",)) _ = jax.jit( - export.call_exported(exp), + export.call(exp), in_shardings=(jax.sharding.NamedSharding(mesh1, P("x", None)),) )(a) @@ -996,7 +996,7 @@ class JaxExportTest(jtu.JaxTestCase): run_mesh = Mesh(run_devices, "i") b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i"))) - res_exported = export.call_exported(exp)(b) + res_exported = export.call(exp)(b) self.assertAllClose(res_native, res_exported) def test_call_with_different_no_of_devices_error_has_in_shardings(self): @@ -1024,7 +1024,7 @@ class JaxExportTest(jtu.JaxTestCase): "Exported module .* was lowered for 1 devices and is called in a " f"context with {jax.local_device_count()} devices.* module 
contains " "non-replicated sharding annotations"): - export.call_exported(exp)(b) + export.call(exp)(b) def test_call_with_different_no_of_devices_pmap(self): if len(jax.devices()) < 2: @@ -1042,7 +1042,7 @@ class JaxExportTest(jtu.JaxTestCase): b = jnp.arange(jax.device_count() * 100, dtype=jnp.float32).reshape( (-1, 1, 100) ) - res_exported = jax.pmap(export.call_exported(exp))(b) + res_exported = jax.pmap(export.call(exp))(b) self.assertAllClose(res_native, res_exported[0]) def test_call_with_different_no_of_devices_error_has_sharding_constraint(self): @@ -1070,7 +1070,7 @@ class JaxExportTest(jtu.JaxTestCase): "Exported module .* was lowered for 1 devices and is called in a " f"context with {jax.local_device_count()} devices.* module contains " "non-replicated sharding annotations"): - export.call_exported(exp)(b) + export.call(exp)(b) @jtu.parameterized_filterable( kwargs=[ @@ -1108,7 +1108,7 @@ class JaxExportTest(jtu.JaxTestCase): self.assertLen(res_jax.addressable_shards, len(devices)) # Test reloaded execution. 
- f_r = export.call_exported(exp) + f_r = export.call(exp) with self.assertRaisesRegex( Exception, "Exported module .* was lowered for 2 devices and is " @@ -1241,14 +1241,14 @@ class JaxExportTest(jtu.JaxTestCase): self.assertEqual(exp_vjp2.nr_devices, 2) call_mesh = Mesh(jax.devices()[:2], "e") - g1 = pjit.pjit(export.call_exported(exp_vjp), + g1 = pjit.pjit(export.call(exp_vjp), in_shardings=(NamedSharding(call_mesh, None), NamedSharding(call_mesh, None)))(x, x.T) _, f_jax_vjp = jax.vjp(f_jax, x) xbar = f_jax_vjp(x.T) self.assertAllClose(xbar, g1) - g2 = pjit.pjit(export.call_exported(exp_vjp2), + g2 = pjit.pjit(export.call(exp_vjp2), in_shardings=(NamedSharding(call_mesh, None), NamedSharding(call_mesh, None), NamedSharding(call_mesh, None)))(x, x.T, x) @@ -1299,7 +1299,7 @@ class JaxExportTest(jtu.JaxTestCase): # Call with argument placed on different plaforms for platform in self.__class__.platforms: x_device = jax.device_put(x, jax.devices(platform)[0]) - res_exp = export.call_exported(exp)(x_device) + res_exp = export.call(exp)(x_device) self.assertAllClose( res_exp, _testing_multi_platform_fun_expected(x, platform=platform)) @@ -1313,7 +1313,7 @@ class JaxExportTest(jtu.JaxTestCase): # Now serialize the call to the exported using a different sequence of # lowering platforms, but included in the lowering platforms for the # nested exported. 
- exp2 = get_exported(export.call_exported(exp), + exp2 = get_exported(export.call(exp), lowering_platforms=("cpu", "cuda","rocm"))(x) # Ensure that we do not have multiple lowerings of the exported function @@ -1325,7 +1325,7 @@ class JaxExportTest(jtu.JaxTestCase): for platform in self.__class__.platforms: if platform == "tpu": continue x_device = jax.device_put(x, jax.devices(platform)[0]) - res_exp = export.call_exported(exp2)(x_device) + res_exp = export.call(exp2)(x_device) self.assertAllClose( res_exp, _testing_multi_platform_fun_expected(np.sin(x), platform=platform)) @@ -1337,11 +1337,11 @@ class JaxExportTest(jtu.JaxTestCase): self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda", "rocm")) # Now serialize the call for the current platform. - exp2 = get_exported(export.call_exported(exp))(x) + exp2 = get_exported(export.call(exp))(x) module_str = str(exp2.mlir_module()) self.assertIn("jax.uses_shape_polymorphism = true", module_str) - res2 = export.call_exported(exp2)(x) + res2 = export.call(exp2)(x) self.assertAllClose(res2, _testing_multi_platform_fun_expected(x)) def test_multi_platform_and_poly(self): @@ -1353,11 +1353,11 @@ class JaxExportTest(jtu.JaxTestCase): jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), np.float32) ) x = np.arange(12, dtype=np.float32).reshape((3, 4)) - res = export.call_exported(exp)(x) + res = export.call(exp)(x) self.assertAllClose(res, _testing_multi_platform_fun_expected(x).reshape((-1,))) # Now serialize the call to the exported - exp2 = get_exported(export.call_exported(exp))(x) - res2 = export.call_exported(exp2)(x) + exp2 = get_exported(export.call(exp))(x) + res2 = export.call(exp2)(x) self.assertAllClose(res2, _testing_multi_platform_fun_expected(x).reshape((-1,))) def test_multi_platform_and_sharding(self): @@ -1382,7 +1382,7 @@ class JaxExportTest(jtu.JaxTestCase): continue run_mesh = Mesh(run_devices, ("x",)) a_device = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, None)) - res_exp = 
export.call_exported(exp)(a_device) + res_exp = export.call(exp)(a_device) self.assertArraysAllClose(res_native, res_exp) @jtu.parameterized_filterable( @@ -1449,7 +1449,7 @@ class JaxExportTest(jtu.JaxTestCase): x, effect_class_name="ForTestingOrderedEffect2") + testing_primitive_with_effect_p.bind( x, effect_class_name="ForTestingUnorderedEffect1") + - export.call_exported(exp)(x)) + export.call(exp)(x)) lowered_outer = jax.jit(f_outer).lower(x) self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"], @@ -1497,7 +1497,7 @@ class JaxExportTest(jtu.JaxTestCase): r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)") self.assertRegex(mlir_module_str, main_expected_re) - res = export.call_exported(exp)(x) + res = export.call(exp)(x) self.assertAllClose(10. + 2. * x, res) @jtu.parameterized_filterable( @@ -1541,7 +1541,7 @@ class JaxExportTest(jtu.JaxTestCase): # Results r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)") self.assertRegex(mlir_module_str, main_expected_re) - res = export.call_exported(exp)(x) + res = export.call(exp)(x) self.assertAllClose(10. 
+ _testing_multi_platform_fun_expected(x), res) @@ -1608,7 +1608,7 @@ class JaxExportTest(jtu.JaxTestCase): jax.ShapeDtypeStruct(rhs.shape, dtype=rhs.dtype), jax.ShapeDtypeStruct(group_sizes.shape, dtype=group_sizes.dtype), ) - res_exported = export.call_exported(exp_f)(lhs, rhs, group_sizes) + res_exported = export.call(exp_f)(lhs, rhs, group_sizes) self.assertAllClose(res_native, res_exported) if __name__ == "__main__": diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 3fa5bfc26..8e4917774 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -1283,7 +1283,7 @@ class PolyHarness(Harness): return None # Run the JAX natively and then the exported function and compare res_jax_native = f_jax(*args) - res_jax_exported = export.call_exported(exp)(*args) + res_jax_exported = export.call(exp)(*args) custom_assert_lims = [ l for l in self.limitations if l.custom_assert is not None] assert len(custom_assert_lims) <= 1, custom_assert_lims @@ -1408,7 +1408,7 @@ class ShapePolyTest(jtu.JaxTestCase): def f_jax(x, *, y): return x + jnp.sin(y) - f_exported = export.call_exported( + f_exported = export.call( export.export(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), x.dtype), y=jax.ShapeDtypeStruct(y.shape, y.dtype))) @@ -1633,22 +1633,22 @@ class ShapePolyTest(jtu.JaxTestCase): exp = export.export(f)(x_spec) x_2 = np.arange(2, dtype=np.int32) - res_2 = export.call_exported(exp)(x_2) + res_2 = export.call(exp)(x_2) self.assertAllClose(x_2[0:2], res_2) x_4 = np.arange(4, dtype=np.int32) - res_4 = export.call_exported(exp)(x_4) + res_4 = export.call(exp)(x_4) self.assertAllClose(x_4[1:3], res_4) with self.assertRaisesRegex( ValueError, re.escape("Expected 'a - 2' to be greater or equal to 0, but found -1")): - export.call_exported(exp)(np.arange(1, dtype=np.int32)) + export.call(exp)(np.arange(1, dtype=np.int32)) with self.assertRaisesRegex( ValueError, re.escape("Expected '- a + 4' to be greater or equal to 0, but found -1")): - 
export.call_exported(exp)(np.arange(5, dtype=np.int32)) + export.call(exp)(np.arange(5, dtype=np.int32)) def test_caching_with_scopes(self): f_tracing_count = 0