Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
Remove the helper jit functions from api.py
PiperOrigin-RevId: 517152277
parent 56267f08dd
commit f9468d3879
jax/_src/api.py (383 changed lines)
@@ -318,379 +318,6 @@ def jit(
       abstracted_axes, has_explicit_sharding)


-def _jit(
-    use_cpp_jit: bool,
-    fun: Callable,
-    static_argnums: Union[int, Iterable[int], None] = None,
-    static_argnames: Union[str, Iterable[str], None] = None,
-    device: Optional[xc.Device] = None,
-    backend: Optional[str] = None,
-    donate_argnums: Union[int, Iterable[int]] = (),
-    inline: bool = False,
-    keep_unused: bool = False,
-    abstracted_axes: Optional[Any] = None,
-) -> stages.Wrapped:
-  # Implements common logic between the C++ and Python backends.
-  check_callable(fun)
-
-  donate_argnums, static_argnums, static_argnames = resolve_argnums(
-      fun, donate_argnums, static_argnums, static_argnames)
-
-  if use_cpp_jit:
-    return _cpp_jit(
-        fun, static_argnums=static_argnums, static_argnames=static_argnames,
-        device=device, backend=backend, donate_argnums=donate_argnums,
-        inline=inline, keep_unused=keep_unused)
-
-  return _python_jit(
-      fun, static_argnums=static_argnums, static_argnames=static_argnames,
-      device=device, backend=backend, donate_argnums=donate_argnums,
-      inline=inline, keep_unused=keep_unused, abstracted_axes=abstracted_axes)
-
-
-def _prepare_jit(fun, static_argnums, static_argnames, donate_argnums,
-                 args, kwargs):
-  # Validate donate_argnums
-  if max(donate_argnums, default=-1) >= len(args):
-    raise ValueError(
-        f"jitted function has {donate_argnums=} but "
-        f"was called with only {len(args)} positional arguments.")
-
-  f = lu.wrap_init(fun)
-  f, args = argnums_partial_except(f, static_argnums, args, allow_invalid=True)
-  f, kwargs = argnames_partial_except(f, static_argnames, kwargs)
-  args_flat, in_tree = tree_flatten((args, kwargs))
-  # Argument donation is incompatible with jax_debug_nans because it re-uses
-  # donated buffers when rerunning the user's function.
-  if donate_argnums and not config.jax_debug_nans:
-    donated_invars = donation_vector(donate_argnums, args, kwargs)
-  else:
-    donated_invars = (False,) * len(args_flat)
-
-  return f, in_tree, args_flat, donated_invars
-
-
-PytreeOfAbstractedAxesSpec = Any
-
-def _python_jit(
-    fun: Callable,
-    *,
-    static_argnums: Tuple[int, ...],
-    static_argnames: Tuple[str, ...],
-    device: Optional[xc.Device],
-    backend: Optional[str],
-    donate_argnums: Tuple[int, ...],
-    inline: bool,
-    keep_unused: bool,
-    abstracted_axes: Optional[PytreeOfAbstractedAxesSpec],
-) -> stages.Wrapped:
-  @wraps(fun)
-  @api_boundary
-  def f_jitted(*args, **kwargs):
-    if config.jax_disable_jit:
-      return fun(*args, **kwargs)
-    closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
-        fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
-    flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
-    for arg in args_flat:
-      dispatch.check_arg(arg)
-    if jax.config.jax_dynamic_shapes:
-      axes_specs = (None if abstracted_axes is None else
-                    _flat_axes_specs(abstracted_axes, *args, **kwargs))
-      in_type = pe.infer_lambda_input_type(axes_specs, args_flat)
-      flat_fun = lu.annotate(flat_fun, in_type)
-    out_flat = xla.xla_call(
-        flat_fun, *args_flat,
-        device=device, backend=backend, name=flat_fun.__name__,
-        donated_invars=donated_invars, inline=inline,
-        keep_unused=keep_unused)
-    return tree_unflatten(out_tree(), out_flat)
-
-  f_jitted.lower = _jit_lower(fun, static_argnums, static_argnames, device,
-                              backend, donate_argnums, inline, keep_unused,
-                              abstracted_axes)
-
-  def clear_cache():
-    dispatch.xla_callable.evict_function(fun)
-  f_jitted.clear_cache = clear_cache
-
-  return cast(stages.Wrapped, f_jitted)
-
-
-def _flat_axes_specs(abstracted_axes, *args, **kwargs
-                     ) -> List[pe.AbstractedAxesSpec]:
-  if kwargs: raise NotImplementedError
-  def ax_leaf(l):
-    return (isinstance(l, dict) and all_leaves(l.values()) or
-            isinstance(l, tuple) and all_leaves(l, lambda x: x is None))
-  return broadcast_prefix(abstracted_axes, args, ax_leaf)
-
-
-class _BackendAndDeviceInfo(NamedTuple):
-  default_device: xc.Device
-  committed_to_device: bool
-
-
-class _FastpathData(NamedTuple):
-  xla_executable: xc.LoadedExecutable
-  out_pytree_def: Any
-  sticky_device: Optional[xc.Device]
-  avals: Iterable[Any]
-  lazy_exprs: Iterable[Any]
-  kept_var_bitvec: Iterable[bool]
-  shardings: Iterable[Any]
-  committed: Iterable[bool]
-
-
-_cpp_jit_cache = jax_jit.CompiledFunctionCache()
-
-
-def _cpp_jit_clear_cache(self):
-  self._clear_cache()
-  dispatch.xla_callable.evict_function(self._fun)
-
-
-def _jax_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat):
-  use_fastpath = (
-      # This is if we have already executed this code-path (most-recent entry
-      # has been reset to None). Thus, we do not support the fast-path.
-      execute is not None and
-      type(execute) is pxla.ExecuteReplicated and
-      len(execute._local_devices) == 1 and
-      # No effects in computation
-      not execute.ordered_effects and
-      not execute.has_unordered_effects and
-      not execute.has_host_callbacks and
-      all(isinstance(x, xc.ArrayImpl) for x in out_flat) and
-      # Not supported: dynamic shapes
-      not jax.config.jax_dynamic_shapes
-      # TODO(chky): Check sharding is SingleDeviceSharding
-  )
-
-  if use_fastpath:
-    sticky_device = None
-    lazy_exprs = [None] * len(out_flat)
-    kept_var_bitvec = [i in execute.kept_var_idx for i in range(len(args_flat))]
-    avals = [out.aval for out in out_flat]
-    shardings = [out.sharding for out in out_flat]
-    committed = [out._committed for out in out_flat]
-
-    return _FastpathData(execute.xla_executable, out_pytree_def, sticky_device,
-                         avals, lazy_exprs, kept_var_bitvec, shardings,
-                         committed)
-
-  return None
-
-
-def _device_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat):
-  # TODO(sharadmv): Clean up usage of `execute.args`
-  use_fastpath = (
-      # This is if we have already executed this code-path (most-recent entry
-      # has been reset to None). Thus, we do not support the fast-path.
-      execute is not None and
-      execute.func is dispatch._execute_compiled and  # not trivial, not pmap
-      # No effects in computation
-      not execute.args[5] and not execute.args[6] and
-      # Has no host callbacks
-      not execute.args[8] and
-      # impl rule must have been called, i.e. top trace is an EvalTrace
-      isinstance(core.find_top_trace(args_flat), core.EvalTrace) and
-      # Not supported: ShardedDeviceArray
-      all(device_array.type_is_device_array(x) for x in out_flat) and
-      # Not supported: dynamic shapes
-      not jax.config.jax_dynamic_shapes
-      and type(execute.args[4]) is dispatch.SimpleResultHandler)
-
-  ### If we can use the fastpath, we return the required info to the caller.
-  if use_fastpath:
-    (_, xla_executable, _, _, result_handlers, _, _, kept_var_idx,
-     _) = execute.args  # pytype: disable=attribute-error
-    sticky_device = None
-    avals = []
-    lazy_exprs = [None] * len(result_handlers)
-    for result_handler in result_handlers:
-      aval, sticky_device = result_handler.args
-      avals.append(aval)
-    assert len(avals) == len(out_flat)
-    kept_var_bitvec = [i in kept_var_idx for i in range(len(args_flat))]
-    shardings = []
-    committed = []
-
-    return _FastpathData(xla_executable, out_pytree_def, sticky_device, avals,
-                         lazy_exprs, kept_var_bitvec, shardings, committed)
-
-  return None
-
-
-def _cpp_jit(
-    fun: Callable,
-    *,
-    static_argnums: Tuple[int, ...],
-    static_argnames: Tuple[str, ...],
-    device: Optional[xc.Device],
-    backend: Optional[str],
-    donate_argnums: Tuple[int, ...],
-    inline: bool,
-    keep_unused: bool,
-) -> stages.Wrapped:
-  # An implementation of `jit` that tries to do as much as possible in C++.
-  # The goal of this function is to speed up the time it takes to process the
-  # arguments, find the correct C++ executable, start the transfer of arguments
-  # and schedule the computation.
-  # As long as it does not support all features of the Python implementation,
-  # the C++ code will fall back to `_python_jit` when it encounters an
-  # unsupported feature.
-  if device is not None and backend is not None:
-    raise ValueError("can't specify both a device and a backend for jit, "
-                     f"got {device=} and {backend=}.")
-
-  @api_boundary
-  def cache_miss(*args, **kwargs):
-    ### This first part is basically the same code as in _python_jit.
-    # An alternative would be for cache_miss to accept from C++ the arguments
-    # (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have
-    # work/code that is redundant between C++ and Python. We can try that later.
-    closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
-        fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
-    for arg in args_flat:
-      dispatch.check_arg(arg)
-    flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
-    if jax.config.jax_dynamic_shapes:
-      in_type = pe.infer_lambda_input_type(None, args_flat)
-      flat_fun = lu.annotate(flat_fun, in_type)
-
-    primitive = xla.xla_call_p
-    call_bind_continuation, top_trace, fun_, tracers, params = (
-        core.call_bind_with_continuation(primitive, flat_fun, *args_flat,
-                                         device=device,
-                                         backend=backend,
-                                         name=flat_fun.__name__,
-                                         donated_invars=donated_invars,
-                                         inline=inline,
-                                         keep_unused=keep_unused))
-    execute = None
-    try:
-      if isinstance(top_trace, core.EvalTrace) and not (
-          jax.config.jax_debug_nans or jax.config.jax_debug_infs):
-        execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
-        out_flat = call_bind_continuation(execute(*args_flat))
-      else:
-        out_flat = call_bind_continuation(
-            top_trace.process_call(primitive, fun_, tracers, params))
-    except pxla.DeviceAssignmentMismatchError as e:
-      fails, = e.args
-      msg = pjit._device_assignment_mismatch_error(
-          fun, fails, in_tree, args_flat, 'jit')
-      raise ValueError(msg) from None
-    out_pytree_def = out_tree()
-    out = tree_unflatten(out_pytree_def, out_flat)
-
-    ### Decide whether we can support the C++ fast path
-    # High level note: The Python tracing mechanism is complex; in particular,
-    # to know whether `jax.jit(f)(x)` will execute or trace, it's not enough to
-    # inspect the argument x; we actually do need to execute it and look at the
-    # outputs that could be tracers (if f is capturing `Tracer` by closure).
-
-    fastpath_data = None
-
-    # TODO(sharadmv): Enable fast path for effectful jaxprs
-    fastpath_data = _jax_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat)
-    return out, fastpath_data
-
-  def get_device_info():
-    """Backends do not exist before __main__ is being executed."""
-    committed_to_device = device is not None or backend is not None
-
-    if device is not None:
-      default_device = device
-    else:
-      backend_ = xb.get_backend(backend)
-      default_device = backend_.get_default_device_assignment(1)[0]
-
-    return _BackendAndDeviceInfo(default_device, committed_to_device)
-
-  jitted_f_kwargs = {}
-  jitted_f_kwargs["has_explicit_device"] = (
-      device is not None or backend is not None)
-  cpp_jitted_f = jax_jit.jit(
-      fun,
-      cache_miss,
-      get_device_info,
-      static_argnums=static_argnums,
-      static_argnames=static_argnames,
-      donate_argnums=donate_argnums,
-      cache=_cpp_jit_cache,
-      **jitted_f_kwargs)  # type: ignore
-  f_jitted = wraps(fun)(cpp_jitted_f)
-
-  f_jitted.lower = _jit_lower(fun, static_argnums, static_argnames, device,
-                              backend, donate_argnums, inline, keep_unused,
-                              None)
-  f_jitted._fun = fun
-  type(f_jitted).clear_cache = _cpp_jit_clear_cache
-
-  return cast(stages.Wrapped, f_jitted)
-
-
-def _jit_lower(fun, static_argnums, static_argnames, device, backend,
-               donate_argnums, inline, keep_unused: bool,
-               abstracted_axes: Optional[PytreeOfAbstractedAxesSpec]):
-  """Make a ``lower`` method for jitted functions."""
-  # If the function we returned from ``jit`` were a class instance,
-  # this might naturally be a method, with ``fun`` as a ``self`` and
-  # all the other arguments stored as attributes.
-
-  def arg_spec(x):
-    # like xla.arg_spec but duck-types on x.shape and x.dtype
-    aval = None if jax.config.jax_dynamic_shapes else shaped_abstractify(x)
-    if hasattr(x, 'sharding'):
-      if isinstance(x.sharding, PmapSharding):
-        return aval, None
-      # If `x` has a sharding attribute but no `_committed` attribute,
-      # assume that `x` is committed. This might happen when the input is
-      # a `ShapedDtypeStruct` or `types.SimpleNamespace`, etc. that might
-      # only have a `sharding` attribute on them.
-      return aval, (pjit.to_gspmd_sharding(x.sharding, x.ndim)
-                    if getattr(x, '_committed', True) else None)
-    else:
-      return aval, None
-
-  @api_boundary
-  def lower(*args, _experimental_lowering_platform: Optional[str] = None,
-            **kwargs) -> stages.Lowered:
-    """Lower this function for the given arguments.
-
-    A lowered function is staged out of Python and translated to a
-    compiler's input language, possibly in a backend-dependent
-    manner. It is ready for compilation but not yet compiled.
-
-    Returns:
-      A ``Lowered`` instance representing the lowering.
-    """
-    closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
-        fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
-    flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
-    arg_specs_and_devices = map(arg_spec, args_flat)
-    in_avals: Sequence[core.AbstractValue]
-    if jax.config.jax_dynamic_shapes:
-      axes_specs = (None if abstracted_axes is None else
-                    _flat_axes_specs(abstracted_axes, *args, **kwargs))
-      in_type = pe.infer_lambda_input_type(axes_specs, args_flat)
-      flat_fun = lu.annotate(flat_fun, in_type)
-      in_avals = [aval for aval, explicit in in_type if explicit]
-    else:
-      if abstracted_axes:
-        raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
-      in_avals, _ = unzip2(arg_specs_and_devices)
-    if any(not core.is_constant_shape(a.shape) for a in in_avals):
-      # TODO(b/262808613): Do not drop unused inputs when we have
-      # shape polymorphism, to ensure that we can always derive
-      # the dimension variables from the kept inputs.
-      nonlocal keep_unused
-      keep_unused = True
-    computation = dispatch.sharded_lowering(
-        flat_fun, device, backend, flat_fun.__name__, donated_invars, True,
-        keep_unused, lowering_platform=_experimental_lowering_platform,
-        *arg_specs_and_devices)
-    return stages.Lowered.from_flat_info(
-        computation, in_tree, in_avals, donate_argnums, out_tree())
-
-  return lower
-
-
 @contextmanager
 def disable_jit(disable: bool = True):
   """Context manager that disables :py:func:`jit` behavior under its dynamic context.
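Note: the hunk above deletes the whole Python/C++ `jit` split. The key contract it implemented is that `cache_miss` returns a pair `(out, fastpath_data)`, and the C++ dispatcher caches and reuses `fastpath_data` on later calls with the same argument signature. Below is a minimal standalone sketch of that contract; `FastpathData` and `make_cache_miss` are hypothetical stand-in names for illustration, not jaxlib APIs.

# Illustrative sketch only: a stand-in for the deleted cache_miss/fastpath
# contract. FastpathData and make_cache_miss are hypothetical names.
from typing import Any, Callable, NamedTuple, Optional, Tuple

class FastpathData(NamedTuple):
  executable: Callable     # what the fast path may call directly next time
  out_pytree_def: Any      # cached output structure, to skip re-flattening

def make_cache_miss(fun: Callable) -> Callable:
  def cache_miss(*args) -> Tuple[Any, Optional[FastpathData]]:
    out = fun(*args)       # slow path: trace, compile, execute
    # Only offer a fast path for "plain" calls; the real code also checked
    # for effects, dynamic shapes, and the types of the outputs.
    fastpath: Optional[FastpathData] = None
    if all(not callable(a) for a in args):
      fastpath = FastpathData(executable=fun, out_pytree_def=None)
    return out, fastpath   # the dispatcher caches fastpath per signature
  return cache_miss

out, fp = make_cache_miss(lambda x: x * 2)(21)
print(out, fp is not None)  # 42 True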
@@ -2749,6 +2376,15 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
   return Partial(transposed_fun, const)


+def _flat_axes_specs(abstracted_axes, *args, **kwargs
+                     ) -> List[pe.AbstractedAxesSpec]:
+  if kwargs: raise NotImplementedError
+  def ax_leaf(l):
+    return (isinstance(l, dict) and all_leaves(l.values()) or
+            isinstance(l, tuple) and all_leaves(l, lambda x: x is None))
+  return broadcast_prefix(abstracted_axes, args, ax_leaf)
+
+
 def make_jaxpr(fun: Callable,
                static_argnums: Union[int, Iterable[int]] = (),
                axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
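For reference, the relocated `_flat_axes_specs` treats a spec as a leaf when it is either a dict mapping axis positions to dimension names or a tuple with one entry (a name or None) per array dimension; the explicit `is_leaf` lambda is needed because pytrees treat a bare None as an empty container, not a leaf. A small sketch of the `ax_leaf` predicate using the public `jax.tree_util.all_leaves`:

# Sketch of the two leaf forms ax_leaf recognizes (same predicate as above).
from jax.tree_util import all_leaves

def ax_leaf(l):
  return (isinstance(l, dict) and all_leaves(l.values()) or
          isinstance(l, tuple) and all_leaves(l, lambda x: x is None))

print(ax_leaf({0: 'n'}))     # True: dict mapping axis index -> dimension name
print(ax_leaf(('n', None)))  # True: tuple with one entry per array dimension
print(ax_leaf([{0: 'n'}]))   # False: a list is a pytree of specs, not one spec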
@@ -3420,7 +3056,6 @@ def clear_backends():
   jax.lib.xla_bridge._backends = {}
   dispatch.xla_callable.cache_clear()  # type: ignore
   dispatch.xla_primitive_callable.cache_clear()
-  _cpp_jit_cache.clear()
   jax_jit.CompiledFunctionCache.clear_all()
   pjit._pjit_lower_cached.cache_clear()
   pjit._create_pjit_jaxpr.cache_clear()
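With the module-level `_cpp_jit_cache` gone, `clear_backends` no longer clears it; the user-facing call is unchanged:

import jax
jax.clear_backends()  # drops live backends and the remaining compilation caches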
@@ -61,7 +61,7 @@ from jax._src.lib import pmap_lib
 from jax._src.lib import xla_client as xc
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
-    PmapSharding, SingleDeviceSharding, GSPMDSharding, NamedSharding,
+    PmapSharding, SingleDeviceSharding, NamedSharding,
     PartitionSpec, XLACompatibleSharding)
 from jax._src.util import flatten, unflatten
@@ -190,16 +190,14 @@ def wait_for_tokens():

 @util.cache()
 def xla_primitive_callable(prim, *arg_specs: ArgSpec, **params):
-  _, arg_devices = util.unzip2(arg_specs)
   donated_invars = (False,) * len(arg_specs)
-  device = None  # This will be resolved in sharded_lowering.
   def prim_fun(*args):
     out = prim.bind(*args, **params)
     if prim.multiple_results:
       return out
     else:
       return out,
-  compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
+  compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), None, None,
                                     prim.name, donated_invars, False, *arg_specs)
   if not prim.multiple_results:
     return lambda *args, **kw: compiled(*args, **kw)[0]
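The `prim_fun` wrapper above makes every primitive look like it returns a tuple, so the compiled callable has a uniform calling convention; single-result primitives are unwrapped by the final lambda. A self-contained sketch of that pattern (`make_callable` is a hypothetical stand-in, not the dispatch API):

# Sketch of the result-normalization pattern in xla_primitive_callable:
# always compile a tuple-returning function, unwrap for single results.
def make_callable(fn, multiple_results: bool):
  def normalized(*args):
    out = fn(*args)
    return out if multiple_results else (out,)
  compiled = normalized  # stand-in for _xla_callable_uncached(...)
  if not multiple_results:
    return lambda *args: compiled(*args)[0]
  return compiled

add1 = make_callable(lambda x: x + 1, multiple_results=False)
print(add1(41))  # 42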
@@ -299,42 +297,10 @@ def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
 xla.xla_call_p.def_impl(_xla_call_impl)


-# TODO(yashkatariya,mattjj): Try to handle this in api.py via a device_put and
-# don't pass the device and backend argument to `_xla_callable_uncached`.
-def not_none_device_or_backend_on_jit(backend, device, num_ins):
-  """This is to support the backend and device argument on jit. It's a feature
-  that's deprecated but needs to be supported for feature parity and so that we
-  can delete the non-Array paths when Array is switched on.
-  """
-  # TODO(yashkatariya): Remove this entire function when backend and device are
-  # removed as arguments on jit.
-  if device is not None and backend is not None:
-    raise ValueError("can't specify both a device and a backend for jit, "
-                     "got device={} and backend={}".format(device, backend))
-
-  if backend is not None:
-    da = [xb.get_backend(backend).get_default_device_assignment(1)[0]]
-  else:
-    assert device is not None
-    da = [device]
-
-  assert len(da) == 1
-  # in_shardings will be marked as replicated regardless of whatever the input
-  # had. Given that only a single device is allowed above, this is correct.
-  in_shardings = [GSPMDSharding.get_replicated(da)] * num_ins
-  return da, in_shardings
-
-
 def sharded_lowering(fun, device, backend, name, donated_invars, always_lower,
                      keep_unused, *arg_specs,
                      lowering_platform: Optional[str]):
   in_avals, in_shardings = util.unzip2(arg_specs)

-  da = None
-  if backend is not None or device is not None:
-    da, in_shardings = not_none_device_or_backend_on_jit(
-        backend, device, len(in_shardings))
-
   in_shardings = [pxla._UNSPECIFIED if i is None else i for i in in_shardings]  # type: ignore

   # Pass in a singleton `_UNSPECIFIED` for out_shardings because we don't know
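The deleted helper existed only to keep the deprecated `device=`/`backend=` arguments of `jit` working: it pins the computation to a single device and marks every input sharding as replicated. A usage sketch of those arguments as they behaved at this commit (both were deprecated and later removed):

import jax

cpu0 = jax.devices("cpu")[0]
f = jax.jit(lambda x: x + 1, device=cpu0)    # pin to one explicit device
g = jax.jit(lambda x: x + 1, backend="cpu")  # use a backend's default device
print(f(1), g(1))                            # 2 2
# Passing both raises ValueError: "can't specify both a device and a backend
# for jit", exactly the check in the removed helper above.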
@@ -343,7 +309,7 @@ def sharded_lowering(fun, device, backend, name, donated_invars, always_lower,
   return pxla.lower_sharding_computation(
       fun, 'jit', name, in_shardings, pxla._UNSPECIFIED, donated_invars,
       in_avals, in_is_global=(True,) * len(arg_specs), keep_unused=keep_unused,
-      always_lower=always_lower, devices_from_context=da,
+      always_lower=always_lower, devices_from_context=None,
       lowering_platform=lowering_platform)
@@ -1022,18 +1022,13 @@ class JaxTestCase(parameterized.TestCase):
                       atol=atol or tol, rtol=rtol or tol,
                       canonicalize_dtypes=canonicalize_dtypes)

-_CPP_JIT_IMPLEMENTATION = functools.partial(api._jit, True)
-_CPP_JIT_IMPLEMENTATION._name = "cpp"
-_PYTHON_JIT_IMPLEMENTATION = functools.partial(api._jit, False)
-_PYTHON_JIT_IMPLEMENTATION._name = "python"
 _PJIT_IMPLEMENTATION = jax.jit
-_PJIT_IMPLEMENTATION._name = "pjit"
+_PJIT_IMPLEMENTATION._name = "jit"
 _NOOP_JIT_IMPLEMENTATION = lambda x, *args, **kwargs: x
 _NOOP_JIT_IMPLEMENTATION._name = "noop"

 JIT_IMPLEMENTATION = (
-    _CPP_JIT_IMPLEMENTATION,
-    _PYTHON_JIT_IMPLEMENTATION,
     _PJIT_IMPLEMENTATION,
     _NOOP_JIT_IMPLEMENTATION,
 )
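With the cpp/python variants gone, the harness keeps only the `jit` and `noop` implementations. A hedged sketch of how a suite can parameterize over the remaining entries (the class and test names here are illustrative, not from the diff):

import jax
from absl.testing import absltest, parameterized

_PJIT_IMPLEMENTATION = jax.jit
_PJIT_IMPLEMENTATION._name = "jit"
_NOOP_JIT_IMPLEMENTATION = lambda x, *args, **kwargs: x
_NOOP_JIT_IMPLEMENTATION._name = "noop"

JIT_IMPLEMENTATION = (_PJIT_IMPLEMENTATION, _NOOP_JIT_IMPLEMENTATION)

class JitVariantTest(parameterized.TestCase):

  @parameterized.named_parameters(
      (impl._name, impl) for impl in JIT_IMPLEMENTATION)
  def test_identity(self, jit):
    # Each variant must preserve values; jax.jit compiles, noop passes through.
    self.assertEqual(jit(lambda x: x)(3), 3)

if __name__ == "__main__":
  absltest.main()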
@@ -645,15 +645,14 @@ class CPPJitTest(jtu.BufferDonationTestCase):
       f(3)

   def test_jit_raises_on_first_invocation_on_non_hashable_static_argnum(self):
-    if self.jit != api._python_jit:
-      raise unittest.SkipTest("this test only applies to _python_jit")
     f = lambda x, y: x + 3
     jitted_f = self.jit(f, static_argnums=(1,))

-    msg = ("Non-hashable static arguments are not supported, as this can lead "
-           "to unexpected cache-misses. Static argument (index 1) of type "
-           "<class 'numpy.ndarray'> for function <lambda> is non-hashable.")
-    with self.assertRaisesRegex(ValueError, re.escape(msg)):
+    msg = ("Non-hashable static arguments are not supported. An error occurred "
+           ".*while trying to hash an object of type "
+           "<class 'numpy\\.ndarray'>, 1. The error was:\nTypeError: "
+           "unhashable type: 'numpy\\.ndarray'")
+    with self.assertRaisesRegex(ValueError, msg):
       jitted_f(1, np.asarray(1))

   def test_cpp_jit_raises_on_non_hashable_static_argnum(self):
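The updated regex matches the C++ dispatch path's error message, which wraps the underlying TypeError raised when hashing the static argument. A quick reproduction sketch:

import numpy as np
import jax

jitted_f = jax.jit(lambda x, y: x + 3, static_argnums=(1,))
try:
  jitted_f(1, np.asarray(1))  # ndarray is unhashable -> rejected as static
except ValueError as e:
  print(e)  # "Non-hashable static arguments are not supported. ..."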