mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #11540 from gnecula:ds_check_flag
PiperOrigin-RevId: 462061356
This commit is contained in:
commit
9ee6cacdc8
@ -334,6 +334,8 @@ def jit(
|
||||
>>> g(jnp.arange(4), 3)
|
||||
DeviceArray([ 0, 1, 256, 6561], dtype=int32)
|
||||
"""
|
||||
if abstracted_axes and not config.jax_dynamic_shapes:
|
||||
raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
|
||||
if FLAGS.experimental_cpp_jit and not config.jax_dynamic_shapes:
|
||||
return _jit(True, fun, static_argnums, static_argnames, device, backend,
|
||||
donate_argnums, inline, keep_unused)
|
||||
@ -642,6 +644,8 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
|
||||
flat_fun = lu.annotate(flat_fun, in_type)
|
||||
in_avals = [aval for aval, explicit in in_type if explicit]
|
||||
else:
|
||||
if abstracted_axes:
|
||||
raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
|
||||
in_avals, _ = unzip2(arg_specs_and_devices)
|
||||
computation = dispatch.lower_xla_callable(
|
||||
flat_fun, device, backend, flat_fun.__name__, donated_invars, True,
|
||||
|
Loading…
x
Reference in New Issue
Block a user