prototyping dynamic shapes
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
This commit is contained in:
parent fb6a143d4d
commit 4354f355a8
124 jax/_src/api.py
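The diff below adds an experimental abstracted_axes parameter to jax.jit and a jax_dynamic_shapes config flag. As an illustration only (not part of the commit), here is a minimal sketch of how the two pieces are meant to fit together, assuming the per-argument axis-spec format {axis_index: axis_name} used by _flat_axes_specs and make_jaxpr below; the prototype does not guarantee that every function lowers end to end:

import jax
import jax.numpy as jnp

jax.config.update('jax_dynamic_shapes', True)  # flag defined in this commit

def f(x):
  return x.sum()

# Abstract the leading axis of the first argument under the name 'n'.
f_dyn = jax.jit(f, abstracted_axes=({0: 'n'},))
print(f_dyn(jnp.arange(5.)))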
@@ -32,7 +32,8 @@ import threading
import weakref
import types
from typing import (Any, Callable, Iterable, NamedTuple, Mapping, Optional,
Sequence, Tuple, TypeVar, Union, overload, Dict, Hashable)
Sequence, Tuple, TypeVar, Union, overload, Dict, Hashable,
List)
from warnings import warn

import numpy as np
@@ -43,9 +44,9 @@ from jax import core
from jax import linear_util as lu
from jax import stages
from jax.core import eval_jaxpr
from jax.tree_util import (tree_map, tree_multimap, tree_flatten, tree_unflatten,
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
tree_structure, tree_transpose, tree_leaves,
tree_map, treedef_is_leaf, treedef_children,
tree_multimap, treedef_is_leaf, treedef_children,
Partial, PyTreeDef, all_leaves, treedef_tuple)

from jax._src import device_array
@@ -191,7 +192,7 @@ def _infer_argnums_and_argnames(
fun: Callable,
argnums: Union[int, Iterable[int], None],
argnames: Union[str, Iterable[str], None],
) -> Tuple[Tuple[int, ...], Tuple[str, ...]]:
) -> Tuple[Tuple[int, ...], Tuple[str, ...]]:
"""Infer missing argnums and argnames for a function with inspect."""
if argnums is None and argnames is None:
argnums = ()
@@ -226,15 +227,16 @@ def _infer_argnums_and_argnames(


def jit(
fun: Callable,
*,
static_argnums: Union[int, Iterable[int], None] = None,
static_argnames: Union[str, Iterable[str], None] = None,
device: Optional[xc.Device] = None,
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
) -> stages.Wrapped:
fun: Callable,
*,
static_argnums: Union[int, Iterable[int], None] = None,
static_argnames: Union[str, Iterable[str], None] = None,
device: Optional[xc.Device] = None,
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
abstracted_axes: Optional[Any] = None,
) -> stages.Wrapped:
"""Sets up ``fun`` for just-in-time compilation with XLA.

Args:
@@ -325,12 +327,12 @@ def jit(
>>> g(jnp.arange(4), 3)
DeviceArray([ 0, 1, 256, 6561], dtype=int32)
"""
if FLAGS.experimental_cpp_jit:
if FLAGS.experimental_cpp_jit and not config.jax_dynamic_shapes:
return _cpp_jit(fun, static_argnums, static_argnames, device, backend,
donate_argnums, inline)
else:
return _python_jit(fun, static_argnums, static_argnames, device, backend,
donate_argnums, inline)
donate_argnums, inline, abstracted_axes)


def _prepare_jit(fun, static_argnums, static_argnames, donate_argnums,
@@ -352,6 +354,8 @@ def _prepare_jit(fun, static_argnums, static_argnames, donate_argnums,
return f, in_tree, args_flat, donated_invars


PytreeOfAbstractedAxesSpec = Any

def _python_jit(
fun: Callable,
static_argnums: Union[int, Iterable[int], None] = None,
@@ -360,8 +364,8 @@ def _python_jit(
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
) -> stages.Wrapped:
# The Python implementation of `jax.jit`, being slowly replaced by _cpp_jit.
abstracted_axes: Optional[PytreeOfAbstractedAxesSpec] = None,
) -> stages.Wrapped:
_check_callable(fun)
static_argnums, static_argnames = _infer_argnums_and_argnames(
fun, static_argnums, static_argnames)
@@ -376,9 +380,14 @@ def _python_jit(
return fun(*args, **kwargs)
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
for arg in args_flat:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
if config.jax_dynamic_shapes:
axes_specs = (None if abstracted_axes is None else
_flat_axes_specs(abstracted_axes, *args, **kwargs))
in_type = pe.infer_lambda_input_type(axes_specs, args_flat)
flat_fun = lu.annotate(flat_fun, in_type)
out_flat = xla.xla_call(
flat_fun, *args_flat,
device=device, backend=backend, name=flat_fun.__name__,
@@ -389,6 +398,14 @@ def _python_jit(
backend, donate_argnums, inline)
return f_jitted

def _flat_axes_specs(abstracted_axes, *args, **kwargs
) -> List[pe.AbstractedAxesSpec]:
if kwargs: raise NotImplementedError
def ax_leaf(l):
return (isinstance(l, dict) and all_leaves(l.values()) or
isinstance(l, tuple) and all_leaves(l, lambda x: x is None))
return broadcast_prefix(abstracted_axes, args, ax_leaf)


class _BackendAndDeviceInfo(NamedTuple):
default_device: xc.Device
@@ -412,7 +429,7 @@ def _cpp_jit(
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
) -> stages.Wrapped:
) -> stages.Wrapped:
# An implementation of `jit` that tries to do as much as possible in C++.
# The goal of this function is to speed up the time it takes to process the
# arguments, find the correct C++ executable, start the transfer of arguments
@@ -442,6 +459,9 @@ def _cpp_jit(
for arg in args_flat:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
if jax.config.jax_dynamic_shapes:
in_type = pe.infer_lambda_input_type(None, args_flat)
flat_fun = lu.annotate(flat_fun, in_type)
out_flat = xla.xla_call(
flat_fun, *args_flat,
device=device, backend=backend, name=flat_fun.__name__,
@@ -462,10 +482,12 @@ def _cpp_jit(
execute is not None and
execute.func is dispatch._execute_compiled and # not trivial, not pmap
# Not supported: ShardedDeviceArray
all(device_array.type_is_device_array(x) for x in out_flat))
all(device_array.type_is_device_array(x) for x in out_flat) and
# Not supported: dynamic shapes
not jax.config.jax_dynamic_shapes)
### If we can use the fastpath, we return required info to the caller.
if use_fastpath:
_, xla_executable, _, result_handlers, kept_var_idx = execute.args
_, xla_executable, _, _, result_handlers, kept_var_idx = execute.args
sticky_device = None
avals = []
lazy_exprs = [None] * len(result_handlers)
@@ -900,7 +922,7 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
has_aux: bool = False, holomorphic: bool = False,
allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()
) -> Callable[..., Tuple[Any, Any]]:
) -> Callable[..., Tuple[Any, Any]]:
"""Create a function that evaluates both ``fun`` and the gradient of ``fun``.

Args:
@@ -1517,18 +1539,18 @@ def _mapped_axis_size(tree, vals, dims, name, *, kws=False):
raise ValueError(msg.format(f"the tree of axis sizes is:\n{sizes}")) from None

def pmap(
fun: Callable,
axis_name: Optional[AxisName] = None,
*,
in_axes=0,
out_axes=0,
static_broadcasted_argnums: Union[int, Iterable[int]] = (),
devices: Optional[Sequence[xc.Device]] = None,
backend: Optional[str] = None,
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> Any:
fun: Callable,
axis_name: Optional[AxisName] = None,
*,
in_axes=0,
out_axes=0,
static_broadcasted_argnums: Union[int, Iterable[int]] = (),
devices: Optional[Sequence[xc.Device]] = None,
backend: Optional[str] = None,
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> Any:
"""Parallel map with support for collective operations.

The purpose of :py:func:`pmap` is to express single-program multiple-data
@@ -1864,7 +1886,7 @@ def _get_f_mapped(
axis_size: Optional[int],
donate_tuple: Tuple[int],
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]],
):
):
def pmap_f(*args, **kwargs):
p = _prepare_pmap(
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
@@ -1914,7 +1936,7 @@ def _python_pmap(
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> stages.Wrapped:
) -> stages.Wrapped:
"""The Python only implementation."""
axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes,
@@ -1973,7 +1995,7 @@ def _cpp_pmap(
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> Any:
) -> Any:
axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes,
out_axes)
@@ -2140,7 +2162,7 @@ def shapecheck(in_shapes, out_shape, fun: Callable):

def jvp(
fun: Callable, primals, tangents, has_aux: bool = False
) -> Tuple[Any, ...]:
) -> Tuple[Any, ...]:
"""Computes a (forward-mode) Jacobian-vector product of ``fun``.

Args:
@@ -2364,7 +2386,7 @@ else:

def vjp( # type: ignore
fun: Callable, *primals, has_aux: bool = False, reduce_axes=()
) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
"""Compute a (reverse-mode) vector-Jacobian product of ``fun``.

:py:func:`grad` is implemented as a special case of :py:func:`vjp`.
@@ -2596,24 +2618,10 @@ def make_jaxpr(fun: Callable,
if abstracted_axes is None:
return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
else:
if kwargs: raise NotImplementedError
ax_leaf = lambda l: (isinstance(l, dict) and all_leaves(l.values()) or
isinstance(l, tuple) and all_leaves(l))
axes_specs = broadcast_prefix(abstracted_axes, args, ax_leaf)
sizes: Dict[Hashable, int] = {}
env: Dict[Hashable, core.AbstractValue] = {}
def make_aval(arg, spec):
if isinstance(spec, tuple):
spec = dict(zip(range(len(arg.shape)), spec))
if not spec: return shaped_abstractify(arg)
assert all(arg.shape[i] == sizes.setdefault(name, arg.shape[i])
for i, name in spec.items())
shape = [env.setdefault(spec[i], ShapedArray((), dtypes.dtype('int32')))
if i in spec else d for i, d in enumerate(arg.shape)]
return core.DShapedArray(tuple(shape), arg.dtype, False)
in_avals = map(make_aval, flat_args, axes_specs)
keep_inputs = [False] * len(env) + [True] * len(flat_args)
return [*env.values(), *in_avals], in_tree, keep_inputs
axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
in_avals, keep_inputs = unzip2(in_type)
return in_avals, in_tree, keep_inputs

@wraps(fun)
@api_boundary
@@ -3078,7 +3086,7 @@ def named_call(
fun: Callable[..., Any],
*,
name: Optional[str] = None,
) -> Callable[..., Any]:
) -> Callable[..., Any]:
"""Adds a user specified name to a function when staging out JAX computations.

When staging out computations for just-in-time compilation to XLA (or other
@@ -350,7 +350,7 @@ class Config:
Values included in this set should also most likely be included in
the C++ JIT state, which is handled separately."""
return (self.x64_enabled, self.jax_numpy_rank_promotion,
self.jax_default_matmul_precision)
self.jax_default_matmul_precision, self.jax_dynamic_shapes)

class _StateContextManager:
def __init__(self, name, help, update_thread_local_hook,
@@ -424,6 +424,7 @@ already_configured_with_absl = False
class GlobalJitState(NamedTuple):
numpy_rank_promotion: Optional[str] = None
default_matmul_precision: Optional[Any] = None
dynamic_shapes: bool = False


def update_global_jit_state(**kw):
@@ -436,6 +437,7 @@ class ThreadLocalJitState(NamedTuple):
dynamic_trace_state: Optional[Any] = None
numpy_rank_promotion: Optional[str] = None
default_matmul_precision: Optional[Any] = None
dynamic_shapes: bool = False


def update_thread_local_jit_state(**kw):
@@ -700,7 +702,11 @@ config.define_bool_state(
name='jax_dynamic_shapes',
default=False,
help=('Enables experimental features for staging out computations with '
'dynamic shapes.'))
'dynamic shapes.'),
update_global_hook=lambda val: \
update_global_jit_state(dynamic_shapes=val),
update_thread_local_hook=lambda val: \
update_thread_local_jit_state(dynamic_shapes=val))

config.define_bool_state(
name='jax_experimental_name_stack',
@@ -20,7 +20,7 @@ from functools import partial
import itertools
import time
from typing import (
Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union)
Any, Callable, Dict, Optional, Sequence, Set, Tuple, List, Type, Union)
from typing_extensions import Protocol
import os
import re
@@ -29,6 +29,7 @@ import warnings
from absl import logging
import numpy as np

import jax
from jax import core
from jax import linear_util as lu
from jax.errors import UnexpectedTracerError
@@ -142,8 +143,11 @@ def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[De
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
del inline # Only used at tracing time
arg_specs = unsafe_map(arg_spec, args)
if fun.in_type is not None:
arg_specs = [(None, *xs) for _, *xs in arg_specs]
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(arg_spec, args))
*arg_specs)
try:
out = compiled_fun(*args)
except FloatingPointError:
@@ -159,7 +163,7 @@ def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
# intentional here, to avoid "Store occupied" errors we clone the WrappedFun
# with empty stores.
stores = [lu.Store() for _ in fun.stores]
clone = lu.WrappedFun(fun.f, fun.transforms, stores, fun.params)
clone = lu.WrappedFun(fun.f, fun.transforms, stores, fun.params, fun.in_type)
with core.new_sublevel():
_ = clone.call_wrapped(*args) # probably won't return
return out
@@ -193,21 +197,27 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))

abstract_args, arg_devices = util.unzip2(arg_specs)
if fun.in_type is not None:
abstract_args, which_explicit = util.unzip2(fun.in_type)
else:
which_explicit = None
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
fun, abstract_args, pe.debug_info_final(fun, "jit"))
fun, abstract_args, pe.debug_info_final(fun, "jit"), which_explicit)
if any(isinstance(c, core.Tracer) for c in consts):
raise UnexpectedTracerError("Encountered an unexpected tracer.")
jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
pruned_arg_specs = (a for i, a in enumerate(arg_specs) if i in kept_var_idx)
abstract_args, arg_devices = util.unzip2(pruned_arg_specs)
donated_invars = [
x for i, x in enumerate(donated_invars) if i in kept_var_idx
]
# TODO(mattjj): handle argument pruning w/ dynamic shapes
if fun.in_type is None:
jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
abstract_args, arg_devices = util.unzip2(
[a for i, a in enumerate(arg_specs) if i in kept_var_idx])
donated_invars = [x for i, x in enumerate(donated_invars) if i in kept_var_idx]
del kept_const_idx
else:
kept_var_idx = set(range(len(abstract_args)))
map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))
jaxpr = apply_outfeed_rewriter(jaxpr)

@@ -215,12 +225,16 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
device = _xla_callable_device(nreps, backend, device, arg_devices)
backend = xb.get_device_backend(device) if device else xb.get_backend(backend)

if (config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr) and
backend.platform != 'iree'):
jaxpr, consts = pe.pad_jaxpr(jaxpr, consts)

# Computations that only produce constants and/or only rearrange their inputs,
# which are often produced from partial evaluation, don't need compilation,
# and don't need to evaluate their arguments.
if not jaxpr.eqns:
return XlaComputation(
name, None, True, None, jaxpr=jaxpr, consts=consts, device=device,
name, None, True, None, None, jaxpr=jaxpr, consts=consts, device=device,
in_avals=abstract_args, out_avals=out_avals, kept_var_idx=kept_var_idx)

if not _on_exit:
@@ -260,9 +274,9 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
module_name, closed_jaxpr, backend.platform,
mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
return XlaComputation(
name, module, False, donated_invars, nreps=nreps, device=device,
backend=backend, tuple_args=tuple_args, in_avals=abstract_args,
out_avals=out_avals, kept_var_idx=kept_var_idx)
name, module, False, donated_invars, which_explicit, nreps=nreps,
device=device, backend=backend, tuple_args=tuple_args,
in_avals=abstract_args, out_avals=out_avals, kept_var_idx=kept_var_idx)


def prefetch(x):
@@ -291,6 +305,14 @@ def jaxpr_has_pmap(jaxpr):
return True
return False

def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:
return (any(type(d) is core.Var for v in jaxpr.invars
if type(v.aval) is core.DShapedArray for d in v.aval.shape) or
any(type(d) is core.Var
for j in itertools.chain([jaxpr], core.subjaxprs(jaxpr))
for e in j.eqns for v in itertools.chain(e.invars, e.outvars)
if type(v.aval) is core.DShapedArray for d in v.aval.shape))

def _prune_unused_inputs(
jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
used = {v for v in jaxpr.outvars if isinstance(v, core.Var)}
@@ -384,6 +406,59 @@ num_buffers_handlers[core.ShapedArray] = lambda _: 1
num_buffers_handlers[core.ConcreteArray] = lambda _: 1


def _input_handler(which_explicit: Optional[Sequence[bool]],
in_avals: Sequence[core.AbstractValue]
) -> Optional[Callable]:
# Extract implicit inputs, and pad bounded-size inputs to their max size.
needs_implicit = which_explicit and not all(which_explicit)
needs_padding = any(type(in_avals[d.val]) is core.AbstractBInt # type: ignore
for a in in_avals if type(a) is core.DShapedArray
for d in a.shape if type(d) is pe.DBIdx)

if not needs_implicit and not needs_padding:
return None
assert config.jax_dynamic_shapes

# Precompute how to grab implicit inputs from explicit inputs' axis sizes.
which_explicit = which_explicit or [True] * len(in_avals)
implicit_idxs = {i for i, ex in enumerate(which_explicit) if not ex}
implicit_args_from_axes: List[Tuple[int, int, int]] = []
for arg_idx, aval in enumerate(in_avals):
if isinstance(aval, core.DShapedArray):
for axis_idx, d in enumerate(aval.shape):
if isinstance(d, pe.DBIdx) and d.val in implicit_idxs:
implicit_args_from_axes.append((d.val, arg_idx, axis_idx))
assert {i for i, _, _ in implicit_args_from_axes} == implicit_idxs

# Precompute how to pad bounded-size inputs to their max size.
def needs_pad(a: core.AbstractValue) -> bool:
return (type(a) is core.DShapedArray and
any(type(d) is pe.DBIdx for d in aval.shape))

def padshape(a: core.DShapedArray) -> List[int]:
return [in_avals[d.val].bound if type(d) is pe.DBIdx and # type: ignore
type(in_avals[d.val]) is core.AbstractBInt else d for d in a.shape] # type: ignore

padders = [partial(jax.jit(_pad_arg, static_argnums=0), tuple(padshape(aval))) # type: ignore
if needs_pad(aval) else None for aval in in_avals]

def elaborate_and_pad(explicit_args):
explicit_args_ = iter(explicit_args)
args = [next(explicit_args_) if ex else None for ex in which_explicit]
assert next(explicit_args_, None) is None
for i, j, k in implicit_args_from_axes:
if args[i] is None:
args[i] = args[j].shape[k] # type: ignore
else:
if args[i] != args[j].shape[k]:
raise Exception("inconsistent argument axis sizes for type")
return tuple([pad(x) if pad else x for x, pad in zip(args, padders)])
return elaborate_and_pad

def _pad_arg(shape, x):
zeros = jax.lax.full(shape, 0, x.dtype)
return jax.lax.dynamic_update_slice(zeros, x, (0,) * len(shape))

if MYPY:
ResultHandler = Any
else:
@@ -405,6 +480,13 @@ def array_result_handler(sticky_device: Optional[Device],
return partial(device_array.make_device_array, core.raise_to_shaped(aval),
sticky_device)

def dynamic_array_result_handler(sticky_device: Optional[Device],
aval: core.DShapedArray):
if aval.dtype is dtypes.float0:
return lambda _: np.zeros(aval.shape, dtypes.float0) # type: ignore
else:
raise NotImplementedError


result_handlers: Dict[
Type[core.AbstractValue],
@@ -412,6 +494,7 @@ result_handlers: Dict[
result_handlers[core.AbstractUnit] = lambda _, __: lambda _: core.unit
result_handlers[core.AbstractToken] = lambda _, __: lambda _: core.token
result_handlers[core.ShapedArray] = array_result_handler
result_handlers[core.DShapedArray] = dynamic_array_result_handler
result_handlers[core.ConcreteArray] = array_result_handler


@@ -433,9 +516,11 @@ def _check_special(name, xla_shape, buf):


def _execute_compiled(name: str, compiled: XlaExecutable,
input_handler: Optional[Callable],
output_buffer_counts: Optional[Sequence[int]],
result_handlers, kept_var_idx, *args):
device, = compiled.local_devices()
args = input_handler(args) if input_handler else args
input_bufs_flat = flatten(device_put(x, device) for i, x in enumerate(args)
if i in kept_var_idx)
out_bufs_flat = compiled.execute(input_bufs_flat)
@@ -447,8 +532,10 @@ def _execute_compiled(name: str, compiled: XlaExecutable,


def _execute_replicated(name: str, compiled: XlaExecutable,
input_handler: Optional[Callable],
output_buffer_counts: Optional[Sequence[int]],
result_handlers, kept_var_idx, *args):
if input_handler: raise NotImplementedError # TODO(mattjj, dougalm)
input_bufs = [flatten(device_put(x, device) for i, x in enumerate(args)
if i in kept_var_idx)
for device in compiled.local_devices()]
@@ -477,16 +564,18 @@ def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
class XlaComputation(stages.Computation):
name: str
_is_trivial: bool
_executable: Optional['XlaCompiledComputation']
_executable: Optional[XlaCompiledComputation]
_donated_invars: Optional[Sequence[bool]]

def __init__(self, name: str, hlo, is_trivial: bool,
donated_invars: Optional[Sequence[bool]],
explicit_args: Optional[Sequence[bool]],
**compile_args):
self.name = name
self._hlo = hlo
self._is_trivial = is_trivial
self._donated_invars = donated_invars
self._explicit_args = explicit_args
self._executable = None
self.compile_args = compile_args

@@ -511,14 +600,14 @@ class XlaComputation(stages.Computation):
return ir.Module.parse(module_str)
return self._hlo

def compile(self) -> 'XlaCompiledComputation':
def compile(self) -> XlaCompiledComputation:
if self._executable is None:
if self.is_trivial():
self._executable = XlaCompiledComputation.from_trivial_jaxpr(
**self.compile_args)
else:
self._executable = XlaCompiledComputation.from_xla_computation(
self.name, self._hlo, **self.compile_args)
self.name, self._hlo, self._explicit_args, **self.compile_args)

return self._executable

@@ -586,6 +675,7 @@ class XlaCompiledComputation(stages.Executable):
def from_xla_computation(
name: str,
xla_computation: Optional[ir.Module],
explicit_args: Optional[Sequence[bool]],
nreps: int,
device: Optional[Device],
backend: Backend,
@@ -594,6 +684,7 @@ class XlaCompiledComputation(stages.Executable):
out_avals: Sequence[core.AbstractValue],
kept_var_idx: Set[int]) -> XlaCompiledComputation:
sticky_device = device
input_handler = _input_handler(explicit_args, in_avals)
result_handlers = map(partial(aval_to_result_handler, sticky_device),
out_avals)
options = xb.get_compile_options(
@@ -603,10 +694,10 @@ class XlaCompiledComputation(stages.Executable):
with log_elapsed_time(f"Finished XLA compilation of {name} "
"in {elapsed_time} sec"):
compiled = compile_or_get_cached(backend, xla_computation, options)
buffer_counts = (None if len(out_avals) == 1 else
[aval_to_num_buffers(aval) for aval in out_avals])
buffer_counts = (None if len(out_avals) == 1 and not config.jax_dynamic_shapes
else [aval_to_num_buffers(aval) for aval in out_avals])
execute = _execute_compiled if nreps == 1 else _execute_replicated
unsafe_call = partial(execute, name, compiled, buffer_counts,
unsafe_call = partial(execute, name, compiled, input_handler, buffer_counts,
result_handlers, kept_var_idx)
return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call)

@@ -622,7 +713,7 @@ class XlaCompiledComputation(stages.Executable):

@staticmethod
def from_trivial_jaxpr(jaxpr, consts, device, in_avals, out_avals,
kept_var_idx) -> 'XlaCompiledComputation':
kept_var_idx) -> XlaCompiledComputation:
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
unsafe_call = partial(_execute_trivial, jaxpr, device, consts,
out_avals, result_handlers, kept_var_idx)
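The _input_handler and _pad_arg additions above pad bounded-size inputs up to their maximum (bound) size before handing buffers to the compiled executable. Purely as an illustration of that padding step, using ordinary JAX calls rather than code from this commit, a length-3 vector placed into a zero buffer whose size is the bound 5:

import jax
import jax.numpy as jnp

x = jnp.array([1., 2., 3.])                 # logical size 3
zeros = jax.lax.full((5,), 0, x.dtype)      # buffer at the bound size
padded = jax.lax.dynamic_update_slice(zeros, x, (0,))
# padded == [1., 2., 3., 0., 0.], mirroring what _pad_arg does per argument.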
@@ -1719,12 +1719,7 @@ def _combine_leading(sz0, sz1, aval, x):
return lax.collapse(x, 0, 2)

def _prepend_dim_to_aval(sz, aval):
if aval is core.abstract_unit:
return aval
elif isinstance(aval, ShapedArray):
return aval.update(shape=(sz, *aval.shape), weak_type=False)
else:
raise TypeError(f'Prepending dim {sz} to aval {aval}')
return core.unmapped_aval(sz, core.no_axis_name, 0, aval)

def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
linear, unroll):
@@ -2063,6 +2058,10 @@ def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
return _make_closed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals)

def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params):
padded_jaxpr = core.ClosedJaxpr(*pe.pad_jaxpr(jaxpr.jaxpr, jaxpr.consts))
return scan_p.bind(*args, jaxpr=padded_jaxpr, **params)

def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
jaxpr, linear, unroll):
tc = partial(_typecheck_param, 'scan')
@@ -2133,6 +2132,7 @@ masking.masking_rules[scan_p] = _scan_masking_rule
core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
pe.partial_eval_jaxpr_custom_rules[scan_p] = \
partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'scan')
pe.padding_rules[scan_p] = _scan_padding_rule
@@ -1970,6 +1970,7 @@ integer_pow_p = standard_primitive(
batching.defvectorized(integer_pow_p)
masking.defvectorized(integer_pow_p)
ad.defjvp(integer_pow_p, _integer_pow_jvp)
pe.padding_rules[integer_pow_p] = lambda _, __, x, y: [integer_pow_p.bind(x, y=y)]

def _integer_pow(x, *, y):
# This should be kept in sync with the jax2tf translation rule.
@@ -2061,6 +2062,7 @@ add_p: Primitive = standard_naryop([_num, _num], 'add')
ad.primitive_jvps[add_p] = _add_jvp
ad.primitive_transposes[add_p] = _add_transpose
mlir.register_lowering(add_p, partial(_nary_lower_mhlo, mhlo.AddOp))
pe.padding_rules[add_p] = lambda _, __, x, y: [add(x, y)]

def _sub_jvp(primals, tangents):
x, y = primals
@@ -2090,6 +2092,7 @@ sub_p = standard_naryop([_num, _num], 'sub')
ad.primitive_jvps[sub_p] = _sub_jvp
ad.primitive_transposes[sub_p] = _sub_transpose
mlir.register_lowering(sub_p, partial(_nary_lower_mhlo, mhlo.SubOp))
pe.padding_rules[sub_p] = lambda _, __, x, y: [sub(x, y)]


def _mul_transpose(ct, x, y):
@@ -2116,6 +2119,7 @@ ad.defjvp(mul_p,
lambda ydot, x, y: mul(x, ydot))
ad.primitive_transposes[mul_p] = _mul_transpose
mlir.register_lowering(mul_p, partial(_nary_lower_mhlo, mhlo.MulOp))
pe.padding_rules[mul_p] = lambda _, __, x, y: [mul(x, y)]

def _div_transpose_rule(cotangent, x, y):
assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y)
@@ -2166,6 +2170,7 @@ ad.defjvp2(max_p,
lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
mlir.register_lowering(max_p, partial(_nary_lower_mhlo, mlir.max_mhlo))
pe.padding_rules[max_p] = lambda _, __, x, y: [max(x, y)]

min_p: core.Primitive = standard_naryop(
[_any, _any], 'min', translation_rule=partial(
@@ -2281,7 +2286,8 @@ def _convert_elt_type_fwd_rule(eqn):
return [None], eqn

def _convert_elt_type_pp_rule(eqn, context, settings):
# don't print new_dtype because the output binder shows it
# don't print new_dtype because the output binder shows it, don't print
# weak_type when false
printed_params = {}
if eqn.params['weak_type']:
printed_params['weak_type'] = True
@@ -2595,6 +2601,27 @@ def _dot_general_masking_rule(padded_vals, logical_shapes, *, dimension_numbers,
rhs, dimension_numbers, precision=precision,
preferred_element_type=preferred_element_type)

def _dot_general_padding_rule(in_avals, out_avals, lhs, rhs, *,
dimension_numbers, **params):
lhs_aval, _ = in_avals
(lhs_contract, _), _ = dimension_numbers
padded_axes = [(i, lhs_aval.shape[i].val) for i in lhs_contract
if isinstance(lhs_aval.shape[i], pe.BoundedAxisSize)]
lhs_ = _replace_masked_values(lhs, 0, padded_axes)
return [dot_general(lhs_, rhs, dimension_numbers=dimension_numbers, **params)]

def _dot_general_pp_rule(eqn, context, settings):
# suppress printing precision or preferred_element_type when None.
# print dimension_numbers as list-of-lists to be shorter.
printed_params = {k: v for k, v in eqn.params.items() if v is not None}
(lhs_cont, rhs_cont), (lhs_batch, rhs_batch) = eqn.params['dimension_numbers']
printed_params['dimension_numbers'] = (
(list(lhs_cont), list(rhs_cont)), (list(lhs_batch), list(rhs_batch)))
return [pp.text(eqn.primitive.name),
core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
pp.text(" ") + core.pp_vars(eqn.invars, context)]


dot_general_p = standard_primitive(_dot_general_shape_rule,
_dot_general_dtype_rule, 'dot_general',
_dot_general_translation_rule)
@@ -2604,6 +2631,9 @@ batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule
masking.masking_rules[dot_general_p] = _dot_general_masking_rule
xla.register_translation(dot_general_p, _dot_general_cpu_translation_rule,
platform="cpu")
pe.padding_rules[dot_general_p] = _dot_general_padding_rule
# TODO(mattjj): un-comment the next line
# core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule

def precision_attr(precision: PrecisionType) -> ir.ArrayAttr:
if precision is None:
@@ -2728,12 +2758,31 @@ def _broadcast_in_dim_staging_rule(

return out_tracer

def _broadcast_in_dim_padding_rule(in_avals, out_avals, x, *dyn_shape,
shape, broadcast_dimensions):
del in_avals, dyn_shape
out_aval, = out_avals
new_shape = []
new_dyn_shape = []
for d in out_aval.shape:
if type(d) is pe.BoundedAxisSize:
new_shape.append(d.bound)
elif type(d) is int:
new_shape.append(d)
else:
assert isinstance(d, core.Tracer)
new_shape.append(None)
new_dyn_shape.append(d)
return [broadcast_in_dim_p.bind(x, *new_dyn_shape, shape=new_shape,
broadcast_dimensions=broadcast_dimensions)]

broadcast_in_dim_p = standard_primitive(
_broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim')
ad.deflinear2(broadcast_in_dim_p, _broadcast_in_dim_transpose_rule)
batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
pe.forwarding_rules[broadcast_in_dim_p] = _broadcast_in_dim_fwd_rule
pe.custom_staging_rules[broadcast_in_dim_p] = _broadcast_in_dim_staging_rule
pe.padding_rules[broadcast_in_dim_p] = _broadcast_in_dim_padding_rule

def _broadcast_in_dim_lower(ctx, x, *, shape, broadcast_dimensions):
del shape
@@ -2809,6 +2858,7 @@ ad.defjvp(clamp_p,
batching.primitive_batchers[clamp_p] = _clamp_batch_rule
mlir.register_lowering(
clamp_p, partial(_nary_lower_mhlo, mhlo.ClampOp, explicit_type=True))
pe.padding_rules[clamp_p] = lambda _, __, a, x, b: [clamp(a, x, b)]

def _concatenate_shape_rule(*operands, **kwargs):
dimension = kwargs.pop('dimension')
@@ -3559,6 +3609,21 @@ def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
assert result.shape == input_shape
return [result]

def _reduce_sum_padding_rule(in_avals, out_avals, operand, *, axes):
del out_avals
aval, = in_avals
padded_axes = [(i, d.val) for i, d in enumerate(aval.shape)
if isinstance(d, pe.BoundedAxisSize)]
masked_operand = _replace_masked_values(operand, 0, padded_axes)
return [_reduce_sum(masked_operand, axes)]

def _replace_masked_values(x, val, padded_axes):
if not padded_axes: return x
masks = [broadcasted_iota(np.dtype('int32'), x.shape, i) < d
for i, d in padded_axes]
return select(_reduce(operator.and_, masks), x, full_like(x, val))


reduce_sum_p = standard_primitive(
_reduce_sum_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'),
'reduce_sum', _reduce_sum_translation_rule)
@@ -3566,6 +3631,7 @@ ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
batching.defreducer(reduce_sum_p)
_masking_defreducer(reduce_sum_p,
lambda shape, dtype: np.broadcast_to(np.array(0, dtype), shape))
pe.padding_rules[reduce_sum_p] = _reduce_sum_padding_rule


def _reduce_op_shape_rule(operand, *, axes, input_shape=None):
@@ -4409,6 +4475,18 @@ def _iota_lower(ctx, *, dtype, shape, dimension):
mlir.register_lowering(iota_p, _iota_lower)


def make_bint(i, bd: int):
return bint_p.bind(i, bd=bd)

bint_p = core.Primitive('bint')

@bint_p.def_abstract_eval
def bint_abstract_eval(_, *, bd: int):
return core.AbstractBInt(bound=bd)

pe.padding_rules[bint_p] = lambda _, __, i, bd: [i]


### util

_ndim = np.ndim
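The lax changes above populate a new pe.padding_rules registry: each rule receives the input and output avals plus operands that may carry padding, and returns outputs computed so that padded entries do not corrupt the result (elementwise ops just rebind, while reductions and contractions such as reduce_sum and dot_general mask first). As a sketch of the same registration pattern for a hypothetical extra elementwise primitive, illustrative only and not part of the commit:

from jax.interpreters import partial_eval as pe
from jax._src.lax import lax

# neg is elementwise, so padded entries are simply carried along unchanged.
pe.padding_rules[lax.neg_p] = lambda in_avals, out_avals, x: [lax.neg(x)]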
@@ -22,10 +22,12 @@ import numpy as np
from jax import core
from jax._src import ad_util
from jax._src import dtypes
from jax._src import source_info_util
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import masking
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax._src.lax.utils import (
_argnum_weak_type,
@@ -2113,3 +2115,30 @@ def _dynamic_slice_indices(operand, start_indices: Any):
lax.add(i, lax.convert_element_type(core.dimension_as_value(d), lax._dtype(i))),
i)
for i, d in zip(start_indices, operand.shape)]


def _getslice(x, lo, hi):
return getslice_p.bind(x, lo, hi)

getslice_p = core.Primitive('getslice')

@getslice_p.def_impl
def getslice_impl(x, lo, hi):
return x[lo:hi]

def _getslice_staging_rule(trace, x, lo, hi):
size = lax.make_bint(lax.clamp(0, hi - lo, x.shape[0]), x.shape[0])
aval = core.DShapedArray((size,), x.dtype, x.weak_type)
source_info = source_info_util.current()
out_tracer = pe.DynamicJaxprTracer(trace, aval, source_info)
invars = map(trace.getvar, [x, lo, hi])
eqn = pe.new_jaxpr_eqn(invars, [trace.makevar(out_tracer)],
getslice_p, {}, source_info)
trace.frame.eqns.append(eqn)
return out_tracer
pe.custom_staging_rules[getslice_p] = _getslice_staging_rule

def _getslice_padding_rule(in_avals, out_avals, x, lo, hi):
xx = lax.concatenate([x, x], 0)
return [dynamic_slice_in_dim(xx, lo, x.shape[0])]
pe.padding_rules[getslice_p] = _getslice_padding_rule
@@ -67,8 +67,10 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
dtype_rule(*avals, **kwargs), weak_type=weak_type,
named_shape=named_shape_rule(*avals, **kwargs))
elif least_specialized is core.DShapedArray:
return core.DShapedArray(shape_rule(*avals, **kwargs),
dtype_rule(*avals, **kwargs), weak_type)
shape = shape_rule(*avals, **kwargs)
ty = (core.ShapedArray if all(type(d) is int for d in shape)
else core.DShapedArray)
return ty(shape, dtype_rule(*avals, **kwargs), weak_type)
elif least_specialized is core.UnshapedArray:
return core.UnshapedArray(dtype_rule(*avals, **kwargs), weak_type=weak_type)
else:
@@ -48,8 +48,10 @@ from jax.tree_util import tree_leaves, tree_flatten, tree_map
from jax._src import device_array
from jax._src import dtypes
from jax._src.api_util import _ensure_index_tuple
from jax._src.lax.lax import _array_copy, _sort_lt_comparator, _sort_le_comparator
from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
_sort_le_comparator)
from jax._src.lax import lax as lax_internal
from jax._src.lax.slicing import _getslice
from jax._src.numpy.ndarray import ndarray
from jax._src.numpy.reductions import ( # noqa: F401
_ensure_optional_axes, _reduction_dims,
@@ -3479,6 +3481,15 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
# All supported cases of indexing can be implemented as an XLA gather,
# followed by an optional reverse and broadcast_in_dim.
arr = asarray(arr)

# TODO(mattjj,dougalm): expand dynamic shape indexing support
if (jax.config.jax_dynamic_shapes and type(idx) is slice and idx.step is None
and (isinstance(idx.start, core.Tracer) or isinstance(idx.stop, core.Tracer))
and arr.shape):
start = 0 if idx.start is None else idx.start
stop = arr.shape[0] if idx.stop is None else idx.stop
return _getslice(arr, start, stop)

treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
unique_indices, mode, fill_value)
@@ -89,7 +89,7 @@ def treedef_children(treedef):
def treedef_is_leaf(treedef):
return treedef.num_nodes == 1

def all_leaves(iterable):
def all_leaves(iterable, is_leaf: Optional[Callable[[Any], bool]] = None):
"""Tests whether all elements in the given iterable are all leaves.

>>> tree = {"a": [1, 2, 3]}
@@ -106,7 +106,11 @@ def all_leaves(iterable):
Returns:
A boolean indicating if all elements in the input are leaves.
"""
return pytree.all_leaves(iterable)
if is_leaf is None:
return pytree.all_leaves(iterable)
else:
lst = list(iterable)
return lst == tree_leaves(lst, is_leaf)


_Children = TypeVar("_Children", bound=Iterable[Any])
52 jax/core.py
@@ -1034,6 +1034,18 @@ class AbstractUnit(AbstractValue):

abstract_unit = AbstractUnit()

class AbstractBInt(AbstractValue):
__slots__ = ['bound']
bound: int
def __init__(self, bound):
self.bound = bound
def str_short(self, short_dtypes=False) -> str:
return f'bint{{≤{self.bound}}}[]'
def __eq__(self, other):
return type(other) is AbstractBInt and self.bound == other.bound
def __hash__(self) -> int:
return hash((type(self), self.bound))

def lattice_join(x: Optional[AbstractValue],
y: Optional[AbstractValue]) -> AbstractValue:
if x is None:
@@ -1212,7 +1224,7 @@ AxisSize = Union[AxisSizeForTracing, AxisSizeForJaxprType,
class DShapedArray(UnshapedArray):
__slots__ = ['shape']
shape: Tuple[AxisSize, ...] # noqa: F821
array_abstraction_level = 2
array_abstraction_level: int = 2

def __init__(self, shape, dtype, weak_type):
self.shape = shape
@@ -1229,11 +1241,6 @@ class DShapedArray(UnshapedArray):
return f'{dtype}[{shape}]'
__str__ = __repr__ = str_short

def __eq__(self, other):
return (type(self) is type(other) and
self.dtype == other.dtype and self.shape == other.shape and
self.weak_type == other.weak_type)

def update(self, shape=None, dtype=None, weak_type=None):
if shape is None:
shape = self.shape
@@ -1243,6 +1250,14 @@ class DShapedArray(UnshapedArray):
weak_type = self.weak_type
return DShapedArray(shape, dtype, weak_type)

def __eq__(self, other):
return (type(self) is type(other)
and self.dtype == other.dtype and self.shape == other.shape
and self.weak_type == other.weak_type)

def __hash__(self):
return hash((self.shape, self.dtype, self.weak_type))

del AxisSize, AxisSizeForTracing, AxisSizeForJaxprType, \
AxisSizeForJaxprTracingSpec

@@ -1415,6 +1430,7 @@ def raise_to_shaped(aval: AbstractValue, weak_type=None):

raise_to_shaped_mappings : Dict[type, Callable] = {
AbstractUnit: lambda aval, _: aval,
AbstractBInt: lambda aval, _: aval,
AbstractToken: lambda aval, _: aval,
Bot: lambda aval, _: aval,
UnshapedArray: lambda aval, _: aval,
@@ -1758,14 +1774,17 @@ class CallPrimitive(Primitive):
new_params = dict(params)
jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
if config.jax_dynamic_shapes:
subfun = lu.annotate(subfun, tuple((v.aval, True) for v in jaxpr.invars))
return [subfun], new_params

def call_bind(primitive: CallPrimitive, fun, *args, **params):
top_trace = find_top_trace(args)
fun, env_trace_todo = process_env_traces_call(
fun_, env_trace_todo = process_env_traces_call(
fun, primitive, top_trace and top_trace.level, tuple(params.items()))
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_call(primitive, fun, tracers, params)
fun_ = lu.annotate(fun_, fun.in_type)
outs = top_trace.process_call(primitive, fun_, tracers, params)
return map(full_lower, apply_todos(env_trace_todo(), outs))

@lu.transformation_with_aux
@@ -1931,13 +1950,28 @@ def _unmap_shaped_array(size: int, axis_name, axis: Optional[int],
named_shape = dict(aval.named_shape)
# TODO: Make this mandatory
named_shape.pop(axis_name, None)
if axis is None: return aval.replace(named_shape=named_shape)
if axis is None: return aval.update(named_shape=named_shape)
return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
named_shape=named_shape, weak_type=aval.weak_type)

def _map_dshaped_array(size: Union[int, Tracer], axis: Optional[int],
aval: ShapedArray) -> ShapedArray:
assert False # TODO(mattjj, dougalm)

def _unmap_dshaped_array(
size: Union[int, Tracer], axis_name, axis: Optional[int],
aval: DShapedArray) -> DShapedArray:
if isinstance(size, int):
if axis is None: return aval
return DShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
weak_type=aval.weak_type)
else:
assert False # TODO(mattjj, dougalm)

AvalMapHandlerPair = Tuple[Callable, Callable]
aval_mapping_handlers: Dict[Type, AvalMapHandlerPair] = {
AbstractUnit: (_map_unit, _map_unit),
DShapedArray: (_map_dshaped_array, _unmap_dshaped_array),
ShapedArray: (_map_shaped_array, _unmap_shaped_array),
ConcreteArray: (_map_shaped_array, _unmap_shaped_array),
}
@@ -16,7 +16,7 @@ import contextlib
import functools
from functools import partial
import itertools as it
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, List, Tuple, Optional

import jax
from jax.interpreters import partial_eval as pe
@@ -24,7 +24,7 @@ from jax.config import config
from jax import core
from jax._src.dtypes import dtype, float0
from jax.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
raise_to_shaped)
raise_to_shaped)
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
zeros_like_aval, zeros_like_p, Zero)
from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name,
@@ -39,6 +39,16 @@ zip = safe_zip
map = safe_map
def identity(x): return x

def _update_annotation(
f: lu.WrappedFun,
orig_type: Optional[Tuple[Tuple[core.AbstractValue, bool], ...]],
nonzeros: List[bool]
) -> lu.WrappedFun:
if orig_type is None:
return f
tan_types = [(aval.at_least_vspace(), keep)
for nz, (aval, keep) in zip(nonzeros, orig_type) if nz]
return lu.annotate(f, (*orig_type, *tan_types))

def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True,
transform_stack=True) -> Any:
@@ -217,7 +227,6 @@ def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack,
with ctx:
map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
for eqn in jaxpr.eqns[::-1]:
# FIXME: Some invars correspond to tangents
invals = map(read_primal, eqn.invars)
if eqn.primitive.multiple_results:
cts_in = map(read_cotangent, eqn.outvars)
@@ -244,8 +253,10 @@ def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack,
cotangents_out = map(read_cotangent, jaxpr.invars)
return cotangents_out

def closed_backward_pass(jaxpr: core.ClosedJaxpr, reduce_axes, transform_stack, primals_in, cotangents_in):
return backward_pass(jaxpr.jaxpr, reduce_axes, transform_stack, jaxpr.consts, primals_in, cotangents_in)
def closed_backward_pass(jaxpr: core.ClosedJaxpr, reduce_axes, transform_stack,
primals_in, cotangents_in):
return backward_pass(jaxpr.jaxpr, reduce_axes, transform_stack, jaxpr.consts,
primals_in, cotangents_in)


class UndefinedPrimal:
@@ -301,7 +312,7 @@ class JVPTrace(Trace):
else:
return JVPTracer(self, primal_out, tangent_out)

def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
def process_call(self, call_primitive, f, tracers, params):
assert call_primitive.multiple_results
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
nonzero_tangents, tangent_tree_def = tree_flatten(tangents)
@@ -329,6 +340,7 @@ class JVPTrace(Trace):
f_jvp, out_tree_def = traceable(f_jvp, len(primals), tangent_tree_def)
update_params = call_param_updaters.get(call_primitive)
new_params = update_params(params, nz_tangents) if update_params else params
f_jvp = _update_annotation(f_jvp, f.in_type, nz_tangents)
result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
@@ -590,15 +602,16 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
reduce_axes, False)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
if config.jax_experimental_name_stack:
new_params = params
else:
new_params = dict(params, name=wrap_name(params['name'], 'transpose'))
if not config.jax_experimental_name_stack:
params = dict(params, name=wrap_name(params['name'], 'transpose'))
update_params = call_transpose_param_updaters.get(primitive)
if update_params:
new_params = update_params(new_params, map(is_undefined_primal, args),
[type(x) is not Zero for x in ct])
out_flat = primitive.bind(fun, *all_args, **new_params)
params = update_params(params, map(is_undefined_primal, args),
[type(x) is not Zero for x in ct])
if config.jax_dynamic_shapes:
in_type = [(core.raise_to_shaped(core.get_aval(x)), True) for x in all_args]
fun = lu.annotate(fun, tuple(in_type))
out_flat = primitive.bind(fun, *all_args, **params)
return tree_unflatten(out_tree(), out_flat)
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
@@ -14,7 +14,7 @@

from functools import partial
from typing import (Any, Callable, Dict, Set, Optional, Tuple, Union, Iterable,
Type)
Type, Sequence)

import numpy as np

@@ -34,6 +34,17 @@ from jax.interpreters import partial_eval as pe

map = safe_map

def _update_annotation(
f: lu.WrappedFun,
orig_type: Optional[Tuple[Tuple[core.AbstractValue, bool], ...]],
axis_size: int, axis_name: core.AxisName, in_dims: Sequence[Optional[int]]
) -> lu.WrappedFun:
if orig_type is None:
return f
batched_in_type = [(core.unmapped_aval(axis_size, axis_name, dim, aval), keep)
for dim, (aval, keep) in zip(in_dims, orig_type)]
return lu.annotate(f, tuple(batched_in_type))

### vmappable typeclass

Vmappable = Any
@@ -175,11 +186,9 @@ class BatchTrace(Trace):

def get_frame(self, vals, dims) -> core.AxisEnvFrame:
if self.axis_name is core.no_axis_name:
# If axis name is `no_axis_name` we can't find it via `core.axis_name` so we
# reconstruct it from the information we have available
axis_sizes = {x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}
assert len(axis_sizes) == 1
axis_size, = axis_sizes
# If axis name is `no_axis_name` we can't find it via `core.axis_name` so
# we reconstruct it from the information we have available
axis_size, = {x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}
return core.AxisEnvFrame(self.axis_name, axis_size, self.main)
return core.axis_frame(self.axis_name)

@@ -204,7 +213,7 @@ class BatchTrace(Trace):
else:
return BatchTracer(self, val_out, dim_out, src)

def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
def process_call(self, call_primitive, f, tracers, params):
assert call_primitive.multiple_results
if config.jax_experimental_name_stack:
params = dict(params, name=params.get('name', f.__name__))
@@ -214,8 +223,10 @@ class BatchTrace(Trace):
if all(bdim is not_mapped for bdim in dims):
return call_primitive.bind(f, *vals, **params)
else:
f, dims_out = batch_subtrace(f, self.main, dims)
vals_out = call_primitive.bind(f, *vals, **params)
f_, dims_out = batch_subtrace(f, self.main, dims)
ax_size, = {x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}
f_ = _update_annotation(f_, f.in_type, ax_size, self.axis_name, dims)
vals_out = call_primitive.bind(f_, *vals, **params)
src = source_info_util.current()
return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out())]
@@ -118,6 +118,13 @@ def dtype_to_ir_type(dtype: Union[np.dtype, np.generic]) -> ir.Type:
def _array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
  return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),)

def _dynamic_array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
  shape = [d if type(d) is int else -1 for d in aval.shape]
  return (ir.RankedTensorType.get(shape, dtype_to_ir_type(aval.dtype)),)

def _bint_ir_types(aval: core.AbstractBInt) -> Sequence[ir.Type]:
  return (ir.RankedTensorType.get((), dtype_to_ir_type(dtypes.dtype('int32'))),)

ir_type_handlers: Dict[Type[core.AbstractValue],
                       Callable[[Any], Sequence[ir.Type]]] = {}

@@ -132,9 +139,11 @@ def aval_to_ir_types(aval: core.AbstractValue) -> Sequence[ir.Type]:
    raise TypeError(f"No ir_type_handler for aval type: {type(aval)}") from err

ir_type_handlers[core.AbstractUnit] = lambda _: ()
ir_type_handlers[core.AbstractBInt] = _bint_ir_types
ir_type_handlers[core.ShapedArray] = _array_ir_types
ir_type_handlers[core.ConcreteArray] = _array_ir_types
ir_type_handlers[core.AbstractToken] = lambda _: [mhlo.TokenType.get()]
ir_type_handlers[core.DShapedArray] = _dynamic_array_ir_types
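As an aside (illustrative only, not part of the change): a DShapedArray whose shape mixes concrete ints with symbolic sizes lowers each non-int dimension to the MLIR dynamic size -1, so an f32[n, 4] aval becomes tensor<?x4xf32>.

    # Sketch of the dimension mapping used by _dynamic_array_ir_types above.
    shape = ['n', 4]                                # 'n' stands in for a symbolic size
    ir_dims = [d if type(d) is int else -1 for d in shape]
    assert ir_dims == [-1, 4]                       # rendered as tensor<?x4xf32>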
def aval_to_ir_type(aval: core.AbstractValue) -> ir.Type:
|
||||
"""Convenience wrapper around aval_to_ir_types for single types.
|
||||
|
@ -14,13 +14,14 @@
|
||||
|
||||
from collections import namedtuple
|
||||
import contextlib
|
||||
from dataclasses import dataclass
|
||||
import functools
|
||||
from functools import partial
|
||||
import inspect
|
||||
import itertools as it
|
||||
import operator as op
|
||||
from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple,
|
||||
List, Union, Set, cast)
|
||||
List, Union, Hashable, cast)
|
||||
from weakref import ref
|
||||
|
||||
import numpy as np
|
||||
@ -38,15 +39,24 @@ from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
|
||||
as_hashable_function, weakref_lru_cache)
|
||||
from jax.core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
|
||||
unit, unitvar, abstract_unit, ClosedJaxpr, new_jaxpr_eqn,
|
||||
ConcreteArray, raise_to_shaped, Var, Atom, JaxprEqn,
|
||||
Primitive, DShapedArray, mapped_aval, unmapped_aval)
|
||||
ConcreteArray, raise_to_shaped, Var, DropVar, Atom,
|
||||
JaxprEqn, Primitive, ShapedArray, DShapedArray,
|
||||
AbstractBInt, mapped_aval, unmapped_aval)
|
||||
from jax._src import source_info_util
|
||||
from jax.config import config
|
||||
|
||||
map = safe_map
|
||||
zip = safe_zip
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
def identity(x): return x
|
||||
|
||||
def _update_annotation(
|
||||
f: lu.WrappedFun,
|
||||
orig_type: Optional[Tuple[Tuple[AbstractValue, bool], ...]],
|
||||
in_knowns: List[bool]) -> lu.WrappedFun:
|
||||
if orig_type is None:
|
||||
return f
|
||||
return lu.annotate(f, tuple([ty for k, ty in zip(in_knowns, orig_type) if k]))
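In other words, the call that is evaluated now is annotated with only the entries of the original input type whose values are known; a toy illustration with placeholder entries:

    # Placeholder strings stand in for (aval, keep) pairs; illustrative only.
    orig_type = ('ty0', 'ty1', 'ty2')
    in_knowns = [True, False, True]
    known_type = tuple(ty for k, ty in zip(in_knowns, orig_type) if k)
    assert known_type == ('ty0', 'ty2')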
|
||||
|
||||
class PartialVal(tuple):
|
||||
"""Partial value: either a known value or an unknown (abstract) value.
|
||||
|
||||
@ -182,7 +192,7 @@ class JaxprTrace(Trace):
|
||||
params, effects, source)
|
||||
return out_tracer
|
||||
|
||||
def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
|
||||
def process_call(self, primitive, f, tracers, params):
|
||||
if primitive in call_partial_eval_rules:
|
||||
return call_partial_eval_rules[primitive](self, primitive, f, tracers, params)
|
||||
|
||||
@ -197,13 +207,14 @@ class JaxprTrace(Trace):
|
||||
# which were unknown to the first call (corresponding to in_avals).
|
||||
|
||||
# Wrap f to perform the partial evaluation and plumb out aux data.
|
||||
f = trace_to_subjaxpr_nounits(f, self.main, False)
|
||||
f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals))
|
||||
f_ = trace_to_subjaxpr_nounits(f, self.main, False)
|
||||
f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns), tuple(in_avals))
|
||||
# Adjust parameters (e.g. donated_invars) for the call to be evaluated now.
|
||||
const_params = update_params(params, in_knowns, 0)
|
||||
|
||||
# Run the call, getting known out vals and aux data used for staged-out call
|
||||
out = primitive.bind(f, *in_consts, **const_params)
|
||||
out = primitive.bind(_update_annotation(f_, f.in_type, in_knowns),
|
||||
*in_consts, **const_params)
|
||||
out_knowns, out_avals, jaxpr, env = aux()
|
||||
# Split apart known outputs from the original call and residuals.
|
||||
out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])
|
||||
@ -679,7 +690,7 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
|
||||
_, in_tracers, out_tracer_refs, primitive, params, effects, source_info = recipe
|
||||
out_tracers = [t_ref() for t_ref in out_tracer_refs]
|
||||
invars = [getvar(t) for t in in_tracers]
|
||||
outvars = [core.DropVar(core.abstract_unit) if t is None
|
||||
outvars = [DropVar(core.abstract_unit) if t is None
|
||||
else cast(Var, getvar(t)) for t in out_tracers]
|
||||
return new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info)
|
||||
|
||||
@ -1391,18 +1402,14 @@ forwarding_rules: Dict[Primitive, ForwardingRule] = {}
|
||||
def _inline_literals(jaxpr, constvals):
|
||||
# This function also ensures variables are labeled in a canonical ordering,
|
||||
# prunes unused constants, and inserts `dropvar` symbols.
|
||||
consts = dict(zip(jaxpr.constvars, constvals))
|
||||
lits = {v: Literal(c, v.aval) for v, c in zip(jaxpr.constvars, constvals)
|
||||
if type(c) in core.literalable_types and not np.shape(c)}
|
||||
lit: Callable[[Var], Optional[Literal]] = lits.get
|
||||
newname: Callable[[AbstractValue], Var] = core.gensym()
|
||||
newvars: Dict[Var, Var] = {}
|
||||
newvar = lambda aval: newname(_substitute_vars_in_type(newvars, aval))
|
||||
newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval))
|
||||
var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval))
|
||||
|
||||
def lit(v: Var) -> Optional[Literal]:
|
||||
val = consts.get(v)
|
||||
if type(val) in core.literalable_types and not np.shape(val):
|
||||
return Literal(val, v.aval)
|
||||
else:
|
||||
return None
|
||||
dropvar = lambda aval: DropVar(_substitute_vars_in_type(lits, newvars, aval))
|
||||
|
||||
def vars_in_shape(aval: AbstractValue) -> Sequence[Var]:
|
||||
if isinstance(aval, DShapedArray):
|
||||
@ -1420,8 +1427,7 @@ def _inline_literals(jaxpr, constvals):
|
||||
new_eqns = []
|
||||
for eqn in jaxpr.eqns:
|
||||
invars = [lit(v) or var(v) for v in eqn.invars]
|
||||
outvars = [var(v) if v in used else core.DropVar(v.aval)
|
||||
for v in eqn.outvars]
|
||||
outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars]
|
||||
new_eqns.append(eqn.replace(invars=invars, outvars=outvars))
|
||||
new_outvars = [lit(v) or var(v) for v in jaxpr.outvars]
|
||||
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns,
|
||||
@ -1443,6 +1449,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
return tracer
|
||||
|
||||
def new_const(self, c):
|
||||
# TODO(mattjj): for ints, or hashable consts, don't rely on id
|
||||
tracer = self.frame.constid_to_tracer.get(id(c))
|
||||
if tracer is None:
|
||||
aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c))
|
||||
@ -1515,30 +1522,33 @@ class DynamicJaxprTrace(core.Trace):
|
||||
self.frame.add_eqn(eqn)
|
||||
return out_tracers if primitive.multiple_results else out_tracers.pop()
|
||||
|
||||
def process_call(self, call_primitive, f, tracers, params):
|
||||
dim_tracers = _get_tracers_only_in_shapes(tracers)
|
||||
in_avals = _tracers_to_avals(dim_tracers + tracers)
|
||||
keep_inputs = [False] * len(dim_tracers) + [True] * len(tracers)
|
||||
def process_call(self, call_primitive, f, explicit_tracers, params):
|
||||
if f.in_type is None:
|
||||
in_avals = [core.raise_to_shaped(get_aval(x)) for x in explicit_tracers]
|
||||
keep_inputs = [True] * len(explicit_tracers)
|
||||
im_tracers = []
|
||||
else:
|
||||
im_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers)
|
||||
in_avals, keep_inputs = unzip2(f.in_type)
|
||||
with core.new_sublevel():
|
||||
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
|
||||
f, self.main, in_avals, keep_inputs=keep_inputs)
|
||||
tracers = [*im_tracers, *explicit_tracers]
|
||||
if params.get('inline', False):
|
||||
return core.eval_jaxpr(jaxpr, consts, *dim_tracers, *tracers)
|
||||
return core.eval_jaxpr(jaxpr, consts, *tracers)
|
||||
env = {v: t for v, t in zip(jaxpr.constvars, consts) if isinstance(t, Tracer)}
|
||||
env.update(zip(jaxpr.invars, tracers))
|
||||
out_avals_ = [_substitute_tracers_in_type(env, a) for a in out_avals]
|
||||
source_info = source_info_util.current()
|
||||
env = {v: t for v, t in zip((*jaxpr.constvars, *jaxpr.invars),
|
||||
(*consts, *dim_tracers, *tracers))
|
||||
if isinstance(t, Tracer)}
|
||||
subs = partial(_substitute_tracers_in_type, env)
|
||||
out_tracers = [DynamicJaxprTracer(self, subs(a), source_info)
|
||||
for a in out_avals]
|
||||
invars = map(self.getvar, dim_tracers + tracers)
|
||||
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals_]
|
||||
invars = map(self.getvar, tracers)
|
||||
constvars = map(self.getvar, map(self.instantiate_const, consts))
|
||||
outvars = map(self.makevar, out_tracers)
|
||||
new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
|
||||
update_params = call_param_updaters.get(call_primitive)
|
||||
if update_params:
|
||||
new_params = update_params(new_params, [True] * len(tracers),
|
||||
len(consts) + len(dim_tracers))
|
||||
new_params = update_params(new_params, [True] * len(explicit_tracers),
|
||||
len(consts) + len(im_tracers))
|
||||
eqn = new_jaxpr_eqn([*constvars, *invars], outvars,
|
||||
call_primitive, new_params,
|
||||
new_params['call_jaxpr'].effects, source_info)
|
||||
@ -1765,7 +1775,7 @@ def trace_to_jaxpr_dynamic(fun: lu.WrappedFun,
|
||||
|
||||
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
|
||||
in_avals: Sequence[AbstractValue], *,
|
||||
keep_inputs: Optional[List[bool]] = None):
|
||||
keep_inputs: Optional[Sequence[bool]] = None):
|
||||
# In general, the Tracers passed to the Python callable underlying `fun` may
|
||||
# correspond to a subset of `in_avals` (i.e. a subset of the input binders in
|
||||
# the jaxpr). For example:
|
||||
@ -1793,7 +1803,7 @@ def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
|
||||
frame = JaxprStackFrame()
|
||||
with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
|
||||
trace = DynamicJaxprTrace(main, core.cur_sublevel())
|
||||
in_tracers = _avals_to_tracers(trace, in_avals)
|
||||
in_tracers = _input_type_to_tracers(trace, in_avals)
|
||||
in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
|
||||
ans = fun.call_wrapped(*in_tracers_)
|
||||
out_tracers = map(trace.full_raise, ans)
|
||||
@ -1817,12 +1827,14 @@ def extend_jaxpr_stack(main, frame):
|
||||
@profiler.annotate_function
|
||||
def trace_to_jaxpr_final(fun: lu.WrappedFun,
|
||||
in_avals: Sequence[AbstractValue],
|
||||
debug_info: Optional[DebugInfo] = None):
|
||||
debug_info: Optional[DebugInfo] = None,
|
||||
keep_inputs: Optional[Sequence[bool]] = None):
|
||||
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
|
||||
main.debug_info = debug_info # type: ignore
|
||||
main.jaxpr_stack = () # type: ignore
|
||||
with core.new_sublevel():
|
||||
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
|
||||
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
|
||||
fun, main, in_avals, keep_inputs=keep_inputs)
|
||||
del fun, main
|
||||
return jaxpr, out_avals, consts
|
||||
|
||||
@ -1835,65 +1847,179 @@ def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[Partial
|
||||
return trace_to_jaxpr(fun, in_pvals)
|
||||
|
||||
|
||||
def _avals_to_tracers(
AbstractedAxisName = Hashable
AbstractedAxesSpec = Union[Dict[int, AbstractedAxisName], Tuple[AbstractedAxisName, ...]]

class DBIdx(NamedTuple):
  val: int

@dataclass(frozen=True)
class Bound:
  name: AbstractedAxisName
  bound: int

InputType = Tuple[Tuple[AbstractValue, bool], ...]

def infer_lambda_input_type(
    axes_specs: Optional[Sequence[AbstractedAxesSpec]],
    args: Sequence[Any]
  ) -> InputType:
  partial_specs = _canonicalize_specs(map(np.ndim, args), axes_specs)
  specs = _complete_specs(args, partial_specs)
  idxs, implicit_names = _collect_implicit(args, specs)
  implicit_inputs = [(_implicit_arg_type(n), False) for n in implicit_names]
  explicit_inputs = [(_arg_type(idxs, x, s), True) for x, s in zip(args, specs)]
  return (*implicit_inputs, *explicit_inputs)
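To make the contract concrete, a hedged sketch of the in_type produced for a single rank-1 argument abstracted over axis 'n' (assuming the helper functions defined below behave as their names suggest): one implicit i32[] size input followed by the explicit array input whose leading dimension is a de Bruijn reference back to it.

    # Illustrative only; evaluated in the context of this module.
    import numpy as np
    x = np.ones(3, np.float32)
    in_type = infer_lambda_input_type([('n',)], [x])
    # Roughly:
    #   ((ShapedArray((), int32), False),                    # implicit size for 'n'
    #    (DShapedArray((DBIdx(0),), float32, False), True))  # explicit f32[n] argument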
def _canonicalize_specs(
|
||||
ndims: Sequence[int], specs: Optional[Sequence[AbstractedAxesSpec]]
|
||||
) -> List[Dict[int, AbstractedAxisName]]:
|
||||
if specs is None:
|
||||
return [{}] * len(ndims)
|
||||
else:
|
||||
return [{i: d for i, d in enumerate(s) if d is not None} if type(s) is tuple
|
||||
else s for n, s in zip(ndims, specs)]
|
||||
|
||||
def _complete_specs(
|
||||
args: Sequence[Any], partial_specs: List[Dict[int, AbstractedAxisName]]
|
||||
) -> List[Dict[int, AbstractedAxisName]]:
|
||||
# Identify each user-supplied name in partial_specs with a size.
|
||||
sizes: Dict[AbstractedAxisName, Union[int, DynamicJaxprTracer]] = {}
|
||||
for x, spec in zip(args, partial_specs):
|
||||
for i, name in spec.items():
|
||||
d = sizes.setdefault(name, x.shape[i])
|
||||
if d is not x.shape[i] and d != x.shape[i]: raise TypeError
|
||||
# Introduce new names as needed for Tracers in shapes.
|
||||
named_tracers: Dict[TracerId, AbstractedAxisName] = {
|
||||
id(d): name for name, d in sizes.items() if isinstance(d, Tracer)}
|
||||
specs: List[Dict[int, AbstractedAxisName]] = []
|
||||
for x, spec in zip(args, partial_specs):
|
||||
if isinstance(get_aval(x), DShapedArray):
|
||||
spec = dict(spec)
|
||||
for i, d in enumerate(x.shape):
|
||||
if isinstance(d, Tracer):
|
||||
spec[i] = named_tracers.get(id(d), TracerAsName(d))
|
||||
specs.append(spec)
|
||||
assert all(not spec or not any(isinstance(d, Tracer) and i not in spec
|
||||
for i, d in enumerate(x.shape))
|
||||
for x, spec in zip(args, specs))
|
||||
return specs
|
||||
|
||||
def _collect_implicit(
|
||||
args: Sequence[Any], specs: List[Dict[int, AbstractedAxisName]]
|
||||
) -> Tuple[Dict[AbstractedAxisName, DBIdx], List[AbstractedAxisName]]:
|
||||
idxs: Dict[AbstractedAxisName, DBIdx] = {}
|
||||
explicit_tracers: Dict[TracerId, int] = {}
|
||||
counter = (DBIdx(i) for i in it.count())
|
||||
# Add implicit arguments to idxs.
|
||||
for explicit_idx, (x, spec) in enumerate(zip(args, specs)):
|
||||
for i, name in spec.items():
|
||||
if name not in idxs and id(x.shape[i]) not in explicit_tracers:
|
||||
idxs[name] = next(counter)
|
||||
if isinstance(x, Tracer):
|
||||
explicit_tracers[id(x)] = explicit_idx
|
||||
implicit_names: List[AbstractedAxisName] = list(idxs)
|
||||
|
||||
# Now that we know the implicit args, add explicit args to idxs.
|
||||
offset = len(implicit_names)
|
||||
for x, spec in zip(args, specs):
|
||||
for i, name in spec.items():
|
||||
if id(x.shape[i]) in explicit_tracers:
|
||||
idxs[name] = DBIdx(offset + explicit_tracers[id(x.shape[i])])
|
||||
|
||||
return idxs, implicit_names
|
||||
|
||||
def _implicit_arg_type(name: AbstractedAxisName) -> AbstractValue:
|
||||
if type(name) is Bound:
|
||||
return AbstractBInt(name.bound)
|
||||
else:
|
||||
return ShapedArray((), dtypes.dtype('int32'))
|
||||
|
||||
def _arg_type(
|
||||
idxs: Dict[AbstractedAxisName, DBIdx], x: Any,
|
||||
spec: Dict[int, AbstractedAxisName]
|
||||
) -> AbstractValue:
|
||||
aval = get_aval(x) # aval.shape could contain Tracers
|
||||
if not spec: return core.raise_to_shaped(aval)
|
||||
shape: List[Union[int, DBIdx]] = [idxs[spec[i]] if i in spec else d
|
||||
for i, d in enumerate(aval.shape)]
|
||||
assert not any(isinstance(d, Tracer) for d in shape)
|
||||
return DShapedArray(tuple(shape), aval.dtype, False)
|
||||
|
||||
class TracerAsName:
|
||||
tracer: DynamicJaxprTracer
|
||||
def __init__(self, tracer):
|
||||
trace = core.thread_local_state.trace_state.trace_stack.dynamic
|
||||
self.tracer = trace.with_cur_sublevel().full_raise(tracer)
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, TracerAsName) and self.tracer is other.tracer
|
||||
def __hash__(self):
|
||||
return id(self.tracer)
|
||||
|
||||
def _extract_implicit_args(
|
||||
trace: DynamicJaxprTrace, in_type: Sequence[Tuple[AbstractValue, bool]],
|
||||
explicit_tracers: Sequence[DynamicJaxprTracer]
|
||||
) -> Sequence[DynamicJaxprTracer]:
|
||||
# First, construct a list to represent the full argument list, leaving the
|
||||
# implicit arguments as Nones for now.
|
||||
explicit_tracers_ = iter(explicit_tracers)
|
||||
tracers = [next(explicit_tracers_) if expl else None for _, expl in in_type]
|
||||
assert next(explicit_tracers_, None) is None
|
||||
del explicit_tracers_
|
||||
|
||||
# Next, populate the implicit arguments using DBIdxs in in_type.
|
||||
for i, (aval, explicit) in enumerate(in_type):
|
||||
if not explicit or not isinstance(aval, DShapedArray):
|
||||
continue # can't populate an implicit argument
|
||||
tracer = tracers[i]
|
||||
assert tracer is not None
|
||||
for d1, d2 in zip(aval.shape, tracer.aval.shape):
|
||||
if isinstance(d1, DBIdx):
|
||||
if tracers[d1.val] is None:
|
||||
tracers[d1.val] = trace.instantiate_const(d2)
|
||||
assert tracers[d1.val] is trace.instantiate_const(d2)
|
||||
assert all(t is not None for t in tracers)
|
||||
return [t for t, (_, e) in zip(tracers, in_type) if not e]
|
||||
|
||||
def _in_avals_from_tracers(
|
||||
tracers: List[DynamicJaxprTracer]
|
||||
) -> List[AbstractValue]:
|
||||
# Returned AbstractValues contain DBIdx indices. Uses Tracer obj id as name.
|
||||
dbidxs: Dict[TracerId, DBIdx] = {id(t): DBIdx(i) for i, t in enumerate(tracers)}
|
||||
in_avals: List[AbstractValue] = []
|
||||
for t in tracers:
|
||||
a = t.aval
|
||||
if isinstance(a, DShapedArray) and any(isinstance(d, Tracer) for d in a.shape):
|
||||
shape = [dbidxs[id(d)] if isinstance(d, Tracer) else d for d in a.shape]
|
||||
a = a.update(shape=tuple(shape))
|
||||
in_avals.append(a)
|
||||
return in_avals
|
||||
|
||||
def _input_type_to_tracers(
|
||||
trace: DynamicJaxprTrace, in_avals: Sequence[AbstractValue]
|
||||
) -> Sequence[Tracer]:
|
||||
# Create input Tracers given input AbstractValues, each of which can contain
|
||||
# other AbstractValues. That is, each element `a` of `in_avals` can have
|
||||
# abstract values in its shape, which must occur to the left of `a`.
|
||||
env: Dict[AvalId, Tracer] = {}
|
||||
# DeBruijn indices which refer to positions in the input argument list. That
|
||||
# is, each element `a` of `in_avals` can have DBIdx instances in its shape,
|
||||
# which must refer to positions left of `a`'s.
|
||||
in_tracers: List[Tracer] = []
|
||||
for a in in_avals:
|
||||
t = env[id(a)] = trace.new_arg(_substitute_tracers_in_aval(env, a))
|
||||
in_tracers.append(t)
|
||||
return in_tracers
|
||||
|
||||
def _substitute_tracers_in_aval(
|
||||
env: Dict[AvalId, Tracer], a: AbstractValue
|
||||
) -> AbstractValue:
|
||||
# Substitute Tracers into a given AbstractValue using the given environment.
|
||||
# That is, the input is an AbstractValue possibly containing AbstractValues,
|
||||
# and the output is an AbstractValue possibly containing Tracers.
|
||||
if (isinstance(a, DShapedArray) and
|
||||
any(isinstance(d, AbstractValue) for d in a.shape)):
|
||||
shape = [env[id(d)] if isinstance(d, AbstractValue) else d for d in a.shape]
|
||||
return a.update(shape=tuple(shape))
|
||||
return a
|
||||
|
||||
def _tracers_to_avals(tracers: Sequence[Tracer]) -> List[AbstractValue]:
|
||||
# Replace Tracers with corresponding abstract values, handling Tracers in
|
||||
# shapes and ensuring each Tracer object is mapped to a single AbstractValue.
|
||||
env: Dict[TracerId, AbstractValue] = {}
|
||||
avals: List[AbstractValue] = []
|
||||
for t in tracers:
|
||||
aval = env.get(id(t))
|
||||
if aval is None:
|
||||
aval = env[id(t)] = _substitute_avals_in_aval(env, t.aval)
|
||||
avals.append(aval)
|
||||
return avals
|
||||
|
||||
def _substitute_avals_in_aval(
|
||||
env: Dict[TracerId, AbstractValue], a: AbstractValue
|
||||
) -> AbstractValue:
|
||||
# Substitute AbstractValues into given AbstractValue using given environment.
|
||||
# That is, the input is an AbstractValue possibly containing Tracers and the
|
||||
# output is an AbstractValue possibly containing AbstractValues.
|
||||
if (isinstance(a, DShapedArray) and
|
||||
any(isinstance(d, Tracer) for d in a.shape)):
|
||||
shape = [env.setdefault(id(d), d.aval) if isinstance(d, Tracer) else d
|
||||
for d in a.shape]
|
||||
return a.update(shape=tuple(shape))
|
||||
else:
|
||||
def _substitute_tracers_in_aval(a: AbstractValue) -> AbstractValue:
|
||||
if isinstance(a, DShapedArray) and any(type(d) is DBIdx for d in a.shape):
|
||||
shape = [in_tracers[d.val] if type(d) is DBIdx else d for d in a.shape] # type: ignore
|
||||
return a.update(shape=tuple(shape))
|
||||
return a
|
||||
|
||||
for a in in_avals:
|
||||
in_tracers.append(trace.new_arg(_substitute_tracers_in_aval(a)))
|
||||
return in_tracers
|
||||
|
||||
def _substitute_vars_in_type(
|
||||
env: Dict[Var, Var], a: AbstractValue
|
||||
consts: Dict[Var, Literal], env: Dict[Var, Var], a: AbstractValue
|
||||
) -> AbstractValue:
|
||||
# Substitutes variables into a given AbstractValue using given environment.
|
||||
# That is, the input is an AbstractValue possibly containing Vars, and the
|
||||
# output is an aval possibly containing Vars.
|
||||
if isinstance(a, DShapedArray) and any(isinstance(d, Var) for d in a.shape):
|
||||
shape = [env[d] if isinstance(d, Var) else d for d in a.shape]
|
||||
shape = [consts[d].val if d in consts else env[d] # type: ignore
|
||||
if isinstance(d, Var) else d for d in a.shape]
|
||||
return a.update(shape=tuple(shape))
|
||||
else:
|
||||
return a
|
||||
@ -1910,24 +2036,74 @@ def _substitute_tracers_in_type(
|
||||
else:
|
||||
return a
|
||||
|
||||
def _get_tracers_only_in_shapes(in_tracers: Sequence[Tracer]) -> List[Tracer]:
|
||||
# In DynamicJaxprTrace.process_call (e.g. for handling jit-of-jit) we want to
|
||||
# extract Tracers from the shapes of arguments so as to elaborate them into
|
||||
# explicit inputs to the call that appears in the jaxpr. Some of these Tracers
|
||||
# may already appear as explicit inputs, so we only need to get those present
|
||||
# exclusively in shapes.
|
||||
return _get_tracers_in_shapes({id(t) for t in in_tracers}, in_tracers)
|
||||
Const = Any
|
||||
Val = Any
|
||||
|
||||
def pad_jaxpr(jaxpr: Jaxpr, consts: Sequence[Const]
|
||||
) -> Tuple[Jaxpr, List[Const]]:
|
||||
bounds = {v: v.aval.bound for v in jaxpr.invars
|
||||
if type(v.aval) is AbstractBInt}
|
||||
idxs = {v: DBIdx(i) for i, v in enumerate(jaxpr.invars)}
|
||||
|
||||
def substitute(aval: AbstractValue) -> AbstractValue:
|
||||
if isinstance(aval, AbstractBInt):
|
||||
return ShapedArray((), np.dtype('int32'))
|
||||
elif isinstance(aval, DShapedArray):
|
||||
shape = [bounds.get(d, idxs.get(d, d)) for d in aval.shape] # type: ignore
|
||||
typ = ShapedArray if all(type(d) is int for d in shape) else DShapedArray
|
||||
return typ(tuple(shape), aval.dtype, aval.weak_type)
|
||||
else:
|
||||
return aval
|
||||
|
||||
in_avals = [substitute(v.aval) for v in jaxpr.invars]
|
||||
eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts))
|
||||
padded_jaxpr, _, padded_consts = trace_to_jaxpr_dynamic(eval_padded, in_avals)
|
||||
return padded_jaxpr, padded_consts
|
||||
|
||||
class BoundedAxisSize(NamedTuple):
|
||||
val: Union[int, DynamicJaxprTracer]
|
||||
bound: int
|
||||
|
||||
def _eval_jaxpr_padded(
|
||||
jaxpr: Jaxpr, consts: List[Const], *args: DynamicJaxprTracer
|
||||
) -> List[Union[Const, DynamicJaxprTracer]]:
|
||||
env: Dict[Var, Val] = {}
|
||||
|
||||
def read(x):
|
||||
return x.val if type(x) is Literal else env[x]
|
||||
|
||||
def write(v, val) -> None:
|
||||
env[v] = val
|
||||
|
||||
write(unitvar, unit)
|
||||
map(write, jaxpr.constvars, consts)
|
||||
map(write, jaxpr.invars, args)
|
||||
for eqn in jaxpr.eqns:
|
||||
rule = padding_rules[eqn.primitive]
|
||||
in_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.invars]
|
||||
out_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.outvars]
|
||||
outs = rule(in_avals, out_avals, *map(read, eqn.invars), **eqn.params)
|
||||
map(write, eqn.outvars, outs)
|
||||
return map(read, jaxpr.outvars)
|
||||
|
||||
def _substitute_axis_sizes(env: Dict, aval: AbstractValue) -> AbstractValue:
|
||||
if isinstance(aval, DShapedArray):
|
||||
shp = [BoundedAxisSize(env[d], d.aval.bound) if type(d) is Var and
|
||||
type(d.aval) is AbstractBInt else env.get(d, d) for d in aval.shape]
|
||||
return DShapedArray(tuple(shp), aval.dtype, aval.weak_type)
|
||||
else:
|
||||
return aval
|
||||
|
||||
padding_rules: Dict[Primitive, Callable] = {}
|
||||
|
||||
def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params):
|
||||
if call_jaxpr.constvars: raise NotImplementedError
|
||||
padded_jaxpr, padded_consts = pad_jaxpr(call_jaxpr, ())
|
||||
if padded_consts: raise NotImplementedError
|
||||
new_params = dict(params, call_jaxpr=padded_jaxpr)
|
||||
subfuns, bind_params = prim.get_bind_params(new_params)
|
||||
return prim.bind(*subfuns, *args, **bind_params)
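Padding rules are looked up per primitive when a jaxpr with bounded dynamic axes is re-traced at the static bounds; the xla.py hunk further down registers the call primitive through this helper. A hedged sketch of what registering a rule looks like, using a purely hypothetical primitive:

    # 'my_elementwise_p' is hypothetical; a real rule must return outputs
    # consistent with the padded in_avals/out_avals it is handed.
    def _my_elementwise_padding_rule(in_avals, out_avals, x, **params):
      return [my_elementwise_p.bind(x, **params)]

    padding_rules[my_elementwise_p] = _my_elementwise_padding_rule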
|
||||
|
||||
def _get_tracers_in_shapes(seen: Set[TracerId], in_tracers: Sequence[Tracer]
|
||||
) -> List[Tracer]:
|
||||
dim_tracers: List[Tracer] = []
|
||||
for t in in_tracers:
|
||||
if isinstance(t.aval, core.DShapedArray):
|
||||
for d in t.aval.shape:
|
||||
if isinstance(d, Tracer) and id(d) not in seen:
|
||||
seen.add(id(d))
|
||||
dim_tracers.append(d)
|
||||
return dim_tracers
|
||||
|
||||
# TODO(mattjj): the following are deprecated; update callers to _nounits version
|
||||
# See https://github.com/google/jax/pull/9498
|
||||
|
@ -674,8 +674,7 @@ def jaxpr_collectives(jaxpr):
|
||||
for eqn in jaxpr.eqns:
|
||||
if eqn.primitive in _collective_primitives:
|
||||
yield eqn.primitive
|
||||
for subjaxpr in core.subjaxprs(jaxpr):
|
||||
yield from jaxpr_collectives(subjaxpr)
|
||||
for subjaxpr in core.subjaxprs(jaxpr): yield from jaxpr_collectives(subjaxpr)
|
||||
|
||||
|
||||
### xla_call underlying jit
|
||||
@ -833,6 +832,8 @@ pe.partial_eval_jaxpr_custom_rules[xla_call_p] = \
|
||||
_xla_call_partial_eval_custom_params_updater)
|
||||
pe.dce_rules[xla_call_p] = pe.dce_jaxpr_call_rule
|
||||
|
||||
pe.padding_rules[xla_call_p] = partial(pe.call_padding_rule, xla_call_p)
|
||||
|
||||
|
||||
def _pp_xla_call(eqn: core.JaxprEqn, context: core.JaxprPpContext,
|
||||
settings: core.JaxprPpSettings,
|
||||
|
@ -146,6 +146,7 @@ from jax._src.lax.lax import (
|
||||
log_p as log_p,
|
||||
lt as lt,
|
||||
lt_p as lt_p,
|
||||
make_bint as make_bint,
|
||||
max as max,
|
||||
max_p as max_p,
|
||||
min as min,
|
||||
|
@ -61,10 +61,11 @@ compare as equal only if they compute the same function. The static and the
|
||||
dynamic positional arguments for the generators, and also the auxiliary output
|
||||
data must be immutable, because it will be stored in function memoization tables.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from functools import partial
|
||||
from typing import Any, Tuple, Callable
|
||||
from typing import Any, Tuple, Callable, Optional
|
||||
import weakref
|
||||
|
||||
from jax import core
|
||||
@ -81,10 +82,10 @@ traceback_util.register_exclusion(__file__)
|
||||
class StoreException(Exception): pass
|
||||
|
||||
|
||||
class EmptyStoreValue(object): pass
|
||||
class EmptyStoreValue: pass
|
||||
_EMPTY_STORE_VALUE = EmptyStoreValue()
|
||||
|
||||
class Store(object):
|
||||
class Store:
|
||||
"""Storage for a value, with checks for overwriting or reading empty store."""
|
||||
__slots__ = ("_val",)
|
||||
|
||||
@ -112,27 +113,28 @@ class Store(object):
|
||||
__bool__ = __nonzero__
|
||||
|
||||
|
||||
class WrappedFun(object):
|
||||
class WrappedFun:
|
||||
"""Represents a function `f` to which `transforms` are to be applied.
|
||||
|
||||
Args:
|
||||
f: the function to be transformed.
|
||||
transforms: a list of `(gen, gen_static_args)` tuples representing
|
||||
transformations to apply to `f.` Here `gen` is a generator function
|
||||
and `gen_static_args` is a tuple of static arguments for the generator. See
|
||||
transformations to apply to `f.` Here `gen` is a generator function and
|
||||
`gen_static_args` is a tuple of static arguments for the generator. See
|
||||
description at the start of this module for the expected behavior of the
|
||||
generator.
|
||||
stores: a list of out_store for the auxiliary output of the `transforms`.
|
||||
params: extra parameters to pass as keyword arguments to `f`, along with the
|
||||
transformed keyword arguments.
|
||||
"""
|
||||
__slots__ = ("f", "transforms", "stores", "params")
|
||||
__slots__ = ("f", "transforms", "stores", "params", "in_type")
|
||||
|
||||
def __init__(self, f, transforms, stores, params):
|
||||
def __init__(self, f, transforms, stores, params, in_type):
|
||||
self.f = f
|
||||
self.transforms = transforms
|
||||
self.stores = stores
|
||||
self.params = params
|
||||
self.in_type = in_type
|
||||
|
||||
@property
|
||||
def __name__(self):
|
||||
@ -141,7 +143,7 @@ class WrappedFun(object):
|
||||
def wrap(self, gen, gen_static_args, out_store) -> 'WrappedFun':
|
||||
"""Add another transform and its store."""
|
||||
return WrappedFun(self.f, ((gen, gen_static_args),) + self.transforms,
|
||||
(out_store,) + self.stores, self.params)
|
||||
(out_store,) + self.stores, self.params, None)
|
||||
|
||||
def populate_stores(self, stores):
|
||||
"""Copy the values from the `stores` into `self.stores`."""
|
||||
@ -200,11 +202,11 @@ class WrappedFun(object):
|
||||
return "Wrapped function:\n" + '\n'.join(transformation_stack) + '\nCore: ' + fun_name(self.f) + '\n'
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.f, self.transforms, self.params))
|
||||
return hash((self.f, self.transforms, self.params, self.in_type))
|
||||
|
||||
def __eq__(self, other):
|
||||
return (self.f == other.f and self.transforms == other.transforms and
|
||||
self.params == other.params)
|
||||
self.params == other.params and self.in_type == other.in_type)
|
||||
|
||||
@curry
|
||||
def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
|
||||
@ -231,8 +233,19 @@ def fun_name(f):
|
||||
|
||||
def wrap_init(f, params=None) -> WrappedFun:
|
||||
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
|
||||
return WrappedFun(f, (), (),
|
||||
() if params is None else tuple(sorted(params.items())))
|
||||
params = () if params is None else tuple(sorted(params.items()))
|
||||
return WrappedFun(f, (), (), params, None)
|
||||
|
||||
def annotate(f: WrappedFun,
|
||||
in_type: Optional[Tuple[Tuple[core.AbstractValue, bool], ...]]
|
||||
) -> WrappedFun:
|
||||
assert f.in_type is None
|
||||
if in_type is None:
|
||||
return f
|
||||
assert (type(in_type) is tuple and all(type(e) is tuple for e in in_type) and
|
||||
all(isinstance(a, core.AbstractValue) and type(b) is bool
|
||||
for a, b in in_type))
|
||||
return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type)
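A minimal usage sketch (assumes the DBIdx helper from the partial_eval changes in this commit; keep flags mark implicit axis-size inputs False and explicit array inputs True):

    from jax import core
    from jax import linear_util as lu
    from jax.interpreters import partial_eval as pe
    import jax.numpy as jnp

    f = lu.wrap_init(lambda x: x)
    n = core.ShapedArray((), jnp.dtype('int32'))
    a = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)
    f = lu.annotate(f, ((n, False), (a, True)))
    assert f.in_type == ((n, False), (a, True))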
|
||||
|
||||
|
||||
class _CacheLocalContext(threading.local):
|
||||
@ -259,11 +272,11 @@ def cache(call: Callable):
|
||||
def memoized_fun(fun: WrappedFun, *args):
|
||||
cache = fun_caches.setdefault(fun.f, {})
|
||||
if config.jax_check_tracer_leaks:
|
||||
key = (_copy_main_traces(fun.transforms), fun.params, args,
|
||||
key = (_copy_main_traces(fun.transforms), fun.params, fun.in_type, args,
|
||||
config.x64_enabled, config._trace_context())
|
||||
else:
|
||||
key = (fun.transforms, fun.params, args, config.x64_enabled,
|
||||
config._trace_context())
|
||||
key = (fun.transforms, fun.params, fun.in_type, args,
|
||||
config.x64_enabled, config._trace_context())
|
||||
result = cache.get(key, None)
|
||||
if result is not None:
|
||||
ans, stores = result
|
||||
|
@ -7972,6 +7972,177 @@ class DynamicShapeTest(jtu.JaxTestCase):
|
||||
self.assertIs(c, c_)
|
||||
self.assertIs(d, d_)
|
||||
|
||||
def test_jit_abstracted_axes_staging(self):
|
||||
# We just test make_jaxpr-of-jit because dynamic shape compilation/execution
|
||||
# may not be supported.
|
||||
@partial(jax.jit, abstracted_axes=('n',))
|
||||
def f(x):
|
||||
return jnp.sum(x)
|
||||
jaxpr = jax.make_jaxpr(f)(jnp.ones(3))
|
||||
# { lambda ; a:f32[3]. let
|
||||
# b:f32[] = xla_call[
|
||||
# call_jaxpr={ lambda ; c:i32[] d:f32[c]. let
|
||||
# e:f32[] = reduce_sum[axes=(0,)] d
|
||||
# in (e,) }
|
||||
# name=f
|
||||
# ] 3 a
|
||||
# in (b,) }
|
||||
a, = jaxpr.jaxpr.invars
|
||||
e, = jaxpr.jaxpr.eqns
|
||||
self.assertLen(e.invars, 2)
|
||||
self.assertIsInstance(e.invars[0], core.Literal)
|
||||
self.assertIs(e.invars[1], a)
|
||||
b, = e.outvars
|
||||
self.assertLen(b.aval.shape, 0)
|
||||
|
||||
subjaxpr = e.params['call_jaxpr']
|
||||
c, d = subjaxpr.invars
|
||||
self.assertLen(c.aval.shape, 0)
|
||||
self.assertLen(d.aval.shape, 1)
|
||||
self.assertIs(d.aval.shape[0], c)
|
||||
|
||||
def test_jit_abstracted_axes_staging2(self):
|
||||
@partial(jax.jit, abstracted_axes=('n',))
|
||||
def fun(x):
|
||||
return jnp.sum(x)
|
||||
jaxpr = jax.make_jaxpr(lambda n: fun(jnp.ones(n + n)))(3)
|
||||
# { lambda ; a:i32[]. let
|
||||
# b:i32[] = add a a
|
||||
# c:f32[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 b
|
||||
# d:f32[] = xla_call[
|
||||
# call_jaxpr={ lambda ; e:i32[] f:f32[e]. let
|
||||
# g:f32[] = reduce_sum[axes=(0,)] f
|
||||
# in (g,) }
|
||||
# name=f
|
||||
# ] b c
|
||||
# in (d,) }
|
||||
a, = jaxpr.jaxpr.invars
|
||||
e1, e2, e3 = jaxpr.jaxpr.eqns
|
||||
b, = e1.outvars
|
||||
c, = e2.outvars
|
||||
b_, c_ = e3.invars
|
||||
self.assertIs(b, b_)
|
||||
self.assertIs(c, c_)
|
||||
|
||||
subjaxpr = e3.params['call_jaxpr']
|
||||
e, f = subjaxpr.invars
|
||||
self.assertLen(e.aval.shape, 0)
|
||||
self.assertLen(f.aval.shape, 1)
|
||||
self.assertIs(f.aval.shape[0], e)
|
||||
|
||||
def test_jit_abstracted_axes_staging3(self):
|
||||
f = jax.jit(jnp.sum, abstracted_axes=('n',))
|
||||
jaxpr = jax.make_jaxpr(f, abstracted_axes=('n',))(jnp.arange(3.))
|
||||
# { lambda ; a:i32[] b:f32[a]. let
|
||||
# c:f32[] = xla_call[
|
||||
# call_jaxpr={ lambda ; d:i32[] e:f32[d]. let
|
||||
# f:f32[] = reduce_sum[axes=(0,)] e
|
||||
# in (f,) }
|
||||
# name=sum
|
||||
# ] a b
|
||||
# in (c,) }
|
||||
a, b = jaxpr.jaxpr.invars
|
||||
e, = jaxpr.jaxpr.eqns
|
||||
self.assertIs(e.invars[0], a)
|
||||
self.assertIs(e.invars[1], b)
|
||||
c, = e.outvars
|
||||
self.assertLen(c.aval.shape, 0)
|
||||
|
||||
subjaxpr = e.params['call_jaxpr']
|
||||
d, e = subjaxpr.invars
|
||||
self.assertLen(d.aval.shape, 0)
|
||||
self.assertLen(e.aval.shape, 1)
|
||||
self.assertIs(e.aval.shape[0], d)
|
||||
|
||||
def test_jit_abstracted_axes_return_polymorphic_shape(self):
|
||||
f = jax.jit(lambda x: x, abstracted_axes=('n',))
|
||||
jaxpr = jax.make_jaxpr(f)(jnp.arange(3)) # doesn't crash
|
||||
# { lambda ; a:i32[3]. let
|
||||
# b:i32[3] = xla_call[
|
||||
# call_jaxpr={ lambda ; c:i32[] d:i32[c]. let in (d,) }
|
||||
# name=<lambda>
|
||||
# ] 3 a
|
||||
# in (b,) }
|
||||
a, = jaxpr.jaxpr.invars
|
||||
e, = jaxpr.jaxpr.eqns
|
||||
three, a_ = e.invars
|
||||
b, = e.outvars
|
||||
self.assertIsInstance(three, core.Literal)
|
||||
self.assertEqual(three.val, 3)
|
||||
self.assertIs(a_, a)
|
||||
self.assertLen(b.aval.shape, 1)
|
||||
self.assertEqual(b.aval.shape[0], 3)
|
||||
|
||||
def test_jit_abstracted_axes_return_polymorphic_shape2(self):
|
||||
f = jax.jit(lambda n: jnp.ones(n))
|
||||
# TODO(mattjj,dougalm): support dynamic shapes in type checker
|
||||
with jax.enable_checks(False):
|
||||
jaxpr = jax.make_jaxpr(f)(3)
|
||||
# { lambda ; a:i32[]. let
|
||||
# b:f32[a] = xla_call[
|
||||
# call_jaxpr={ lambda ; c:i32[]. let
|
||||
# d:f32[c] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0
|
||||
# c
|
||||
# in (d,) }
|
||||
# name=<lambda>
|
||||
# ] a
|
||||
# in (b,) }
|
||||
a, = jaxpr.jaxpr.invars
|
||||
e, = jaxpr.jaxpr.eqns
|
||||
a_, = e.invars
|
||||
self.assertIs(a, a_)
|
||||
b, = e.outvars
|
||||
a__, = b.aval.shape
|
||||
self.assertIs(a, a__)
|
||||
|
||||
with jax.enable_checks(False):
|
||||
jaxpr = jax.make_jaxpr(lambda: f(3))()
|
||||
# { lambda ; . let
|
||||
# a:f32[3] = xla_call[
|
||||
# call_jaxpr={ lambda ; b:i32[]. let
|
||||
# c:f32[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0
|
||||
# b
|
||||
# in (c,) }
|
||||
# name=<lambda>
|
||||
# ] 3
|
||||
# in (a,) }
|
||||
() = jaxpr.jaxpr.invars
|
||||
e, = jaxpr.jaxpr.eqns
|
||||
three, = e.invars
|
||||
self.assertIsInstance(three, core.Literal)
|
||||
self.assertEqual(three.val, 3)
|
||||
b, = e.outvars
|
||||
three_, = b.aval.shape
|
||||
self.assertIsInstance(three_, int)
|
||||
self.assertEqual(three_, 3)
|
||||
|
||||
def test_jit_basic_iree(self):
|
||||
if not jtu.device_under_test() == 'iree':
|
||||
raise unittest.SkipTest("test only works on IREE")
|
||||
@jax.jit
|
||||
def f(i):
|
||||
return jnp.sum(jnp.ones(i, dtype='float32'))
|
||||
|
||||
self.assertAllClose(f(3), jnp.array(3., dtype='float32'), check_dtypes=True)
|
||||
|
||||
def test_slicing_basic(self):
|
||||
f = jax.jit(lambda x, n: jnp.sum(x[:n]))
|
||||
ans = f(jnp.arange(10), 3)
|
||||
expected = jnp.sum(jnp.arange(10)[:3])
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
|
||||
def test_scan_basic(self):
|
||||
def cumsum(x):
|
||||
def body(i, _):
|
||||
return i + 1, jnp.sum(x[:i+1])
|
||||
_, ans = lax.scan(body, 0, None, length=len(x))
|
||||
return ans
|
||||
x = jnp.array([3, 1, 4, 1, 5, 9])
|
||||
with jax.enable_checks(False):
|
||||
ans = cumsum(x)
|
||||
expected = jnp.cumsum(x)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -30,7 +30,8 @@ from jax import numpy as jnp
|
||||
from jax import linear_util as lu
|
||||
from jax import jvp, linearize, vjp, jit, make_jaxpr
|
||||
from jax.core import UnshapedArray, ShapedArray
|
||||
from jax.tree_util import tree_flatten, tree_unflatten, tree_map, tree_reduce, tree_leaves
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_reduce,
|
||||
tree_leaves)
|
||||
from jax.interpreters import partial_eval as pe
|
||||
|
||||
from jax._src import test_util as jtu
|
||||
@ -529,15 +530,15 @@ class DynamicShapesTest(jtu.JaxTestCase):
|
||||
|
||||
def test_staging_basic(self):
|
||||
n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
|
||||
a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
|
||||
b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
|
||||
a = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)
|
||||
b = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)
|
||||
|
||||
@lu.wrap_init
|
||||
def f(x, y):
|
||||
return x, y
|
||||
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
|
||||
keep_inputs=[False, True, True])
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
|
||||
f, [n, a, b], keep_inputs=[False, True, True])
|
||||
|
||||
self.assertLen(jaxpr.invars, 3)
|
||||
self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape)
|
||||
@ -549,8 +550,8 @@ class DynamicShapesTest(jtu.JaxTestCase):
|
||||
|
||||
def test_staging_nested(self):
|
||||
n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
|
||||
a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
|
||||
b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
|
||||
a = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)
|
||||
b = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)
|
||||
|
||||
@lu.wrap_init
|
||||
def f(x, y):
|
||||
@ -559,8 +560,8 @@ class DynamicShapesTest(jtu.JaxTestCase):
|
||||
return (x, w)
|
||||
return g(x, y, x, y)
|
||||
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
|
||||
keep_inputs=[False, True, True])
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
|
||||
f, [n, a, b], keep_inputs=[False, True, True])
|
||||
|
||||
self.assertLen(jaxpr.invars, 1 + 2) # one axis size var, two other inputs
|
||||
self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape)
|
||||
@ -583,10 +584,9 @@ class DynamicShapesTest(jtu.JaxTestCase):
|
||||
self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[4].aval.shape)
|
||||
|
||||
def test_staging_nested_including_shape_arg(self):
|
||||
# This test covers the _get_tracers_only_in_shapes logic in partial_eval.py.
|
||||
n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
|
||||
a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
|
||||
b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
|
||||
a = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)
|
||||
b = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)
|
||||
|
||||
@lu.wrap_init
|
||||
def f(x, y):
|
||||
@ -595,8 +595,18 @@ class DynamicShapesTest(jtu.JaxTestCase):
|
||||
return (x, w)
|
||||
return g(x.shape[0], x, y, x, y)
|
||||
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
|
||||
keep_inputs=[False, True, True])
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
|
||||
f, [n, a, b], keep_inputs=[False, True, True])
|
||||
print(jaxpr)
|
||||
|
||||
# { lambda ; a:i32[] b:f32[a] c:f32[a]. let
|
||||
# d:f32[a] e:f32[a] = xla_call[
|
||||
# call_jaxpr={ lambda ; f:i32[] g:i32[] h:f32[f] i:f32[f] j:f32[f] k:f32[f]. let
|
||||
#
|
||||
# in (h, k) }
|
||||
# name=g
|
||||
# ] a a b c b c
|
||||
# in (d, e) }
|
||||
|
||||
self.assertLen(jaxpr.eqns, 1)
|
||||
eqn = jaxpr.eqns[0]
|
||||
@ -612,8 +622,8 @@ class DynamicShapesTest(jtu.JaxTestCase):
|
||||
|
||||
def test_staging_primitive_applications(self):
|
||||
n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
|
||||
a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
|
||||
b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
|
||||
a = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)
|
||||
b = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)
|
||||
|
||||
@lu.wrap_init
|
||||
def f(x, y):
|
||||
@ -622,8 +632,8 @@ class DynamicShapesTest(jtu.JaxTestCase):
|
||||
u = lax_internal._reduce_sum(w, [0])
|
||||
return (u,)
|
||||
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
|
||||
keep_inputs=[False, True, True])
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
|
||||
f, [n, a, b], keep_inputs=[False, True, True])
|
||||
|
||||
self.assertLen(jaxpr.invars, 1 + 2) # one axis size var, two other inputs
|
||||
self.assertLen(jaxpr.eqns, 3)
|