Merge pull request #11540 from gnecula:ds_check_flag

PiperOrigin-RevId: 462061356
This commit is contained in:
jax authors 2022-07-19 23:07:14 -07:00
commit 9ee6cacdc8

View File

@ -334,6 +334,8 @@ def jit(
>>> g(jnp.arange(4), 3)
DeviceArray([ 0, 1, 256, 6561], dtype=int32)
"""
if abstracted_axes and not config.jax_dynamic_shapes:
raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
if FLAGS.experimental_cpp_jit and not config.jax_dynamic_shapes:
return _jit(True, fun, static_argnums, static_argnames, device, backend,
donate_argnums, inline, keep_unused)
@ -642,6 +644,8 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
flat_fun = lu.annotate(flat_fun, in_type)
in_avals = [aval for aval, explicit in in_type if explicit]
else:
if abstracted_axes:
raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
in_avals, _ = unzip2(arg_specs_and_devices)
computation = dispatch.lower_xla_callable(
flat_fun, device, backend, flat_fun.__name__, donated_invars, True,