prototyping dynamic shapes

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
Matthew Johnson 2022-03-30 17:52:55 -07:00
parent fb6a143d4d
commit 4354f355a8
19 changed files with 935 additions and 267 deletions

View File

@ -32,7 +32,8 @@ import threading
import weakref
import types
from typing import (Any, Callable, Iterable, NamedTuple, Mapping, Optional,
Sequence, Tuple, TypeVar, Union, overload, Dict, Hashable)
Sequence, Tuple, TypeVar, Union, overload, Dict, Hashable,
List)
from warnings import warn
import numpy as np
@ -43,9 +44,9 @@ from jax import core
from jax import linear_util as lu
from jax import stages
from jax.core import eval_jaxpr
from jax.tree_util import (tree_map, tree_multimap, tree_flatten, tree_unflatten,
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
tree_structure, tree_transpose, tree_leaves,
tree_map, treedef_is_leaf, treedef_children,
tree_multimap, treedef_is_leaf, treedef_children,
Partial, PyTreeDef, all_leaves, treedef_tuple)
from jax._src import device_array
@ -191,7 +192,7 @@ def _infer_argnums_and_argnames(
fun: Callable,
argnums: Union[int, Iterable[int], None],
argnames: Union[str, Iterable[str], None],
) -> Tuple[Tuple[int, ...], Tuple[str, ...]]:
"""Infer missing argnums and argnames for a function with inspect."""
if argnums is None and argnames is None:
argnums = ()
@ -226,15 +227,16 @@ def _infer_argnums_and_argnames(
def jit(
fun: Callable,
*,
static_argnums: Union[int, Iterable[int], None] = None,
static_argnames: Union[str, Iterable[str], None] = None,
device: Optional[xc.Device] = None,
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
abstracted_axes: Optional[Any] = None,
) -> stages.Wrapped:
"""Sets up ``fun`` for just-in-time compilation with XLA.
Args:
@ -325,12 +327,12 @@ def jit(
>>> g(jnp.arange(4), 3)
DeviceArray([ 0, 1, 256, 6561], dtype=int32)
"""
if FLAGS.experimental_cpp_jit:
if FLAGS.experimental_cpp_jit and not config.jax_dynamic_shapes:
return _cpp_jit(fun, static_argnums, static_argnames, device, backend,
donate_argnums, inline)
else:
return _python_jit(fun, static_argnums, static_argnames, device, backend,
donate_argnums, inline)
donate_argnums, inline, abstracted_axes)
def _prepare_jit(fun, static_argnums, static_argnames, donate_argnums,
@ -352,6 +354,8 @@ def _prepare_jit(fun, static_argnums, static_argnames, donate_argnums,
return f, in_tree, args_flat, donated_invars
PytreeOfAbstractedAxesSpec = Any
def _python_jit(
fun: Callable,
static_argnums: Union[int, Iterable[int], None] = None,
@ -360,8 +364,8 @@ def _python_jit(
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
) -> stages.Wrapped:
# The Python implementation of `jax.jit`, being slowly replaced by _cpp_jit.
abstracted_axes: Optional[PytreeOfAbstractedAxesSpec] = None,
) -> stages.Wrapped:
_check_callable(fun)
static_argnums, static_argnames = _infer_argnums_and_argnames(
fun, static_argnums, static_argnames)
@ -376,9 +380,14 @@ def _python_jit(
return fun(*args, **kwargs)
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
for arg in args_flat:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
if config.jax_dynamic_shapes:
axes_specs = (None if abstracted_axes is None else
_flat_axes_specs(abstracted_axes, *args, **kwargs))
in_type = pe.infer_lambda_input_type(axes_specs, args_flat)
flat_fun = lu.annotate(flat_fun, in_type)
out_flat = xla.xla_call(
flat_fun, *args_flat,
device=device, backend=backend, name=flat_fun.__name__,
@ -389,6 +398,14 @@ def _python_jit(
backend, donate_argnums, inline)
return f_jitted
def _flat_axes_specs(abstracted_axes, *args, **kwargs
) -> List[pe.AbstractedAxesSpec]:
if kwargs: raise NotImplementedError
def ax_leaf(l):
return (isinstance(l, dict) and all_leaves(l.values()) or
isinstance(l, tuple) and all_leaves(l, lambda x: x is None))
return broadcast_prefix(abstracted_axes, args, ax_leaf)
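For orientation, a minimal usage sketch of the new abstracted_axes argument, assuming the spec format consumed by _flat_axes_specs above (per-argument dicts mapping axis index to a user-chosen name, or tuples of names/None); end-to-end execution is still prototype-quality, so treat this as illustrative only:

import jax
import jax.numpy as jnp

jax.config.update('jax_dynamic_shapes', True)  # flag wired up in config.py below

def f(x):
  return jnp.sum(x)

# Hypothetical call: abstract axis 0 of the first (and only) argument as 'n',
# so tracing sees a symbolic leading axis rather than the concrete size 3.
f_jit = jax.jit(f, abstracted_axes=({0: 'n'},))
f_jit(jnp.arange(3.))  # goes through the _python_jit dynamic-shapes path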
class _BackendAndDeviceInfo(NamedTuple):
default_device: xc.Device
@ -412,7 +429,7 @@ def _cpp_jit(
backend: Optional[str] = None,
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
) -> stages.Wrapped:
# An implementation of `jit` that tries to do as much as possible in C++.
# The goal of this function is to speed up the time it takes to process the
# arguments, find the correct C++ executable, start the transfer of arguments
@ -442,6 +459,9 @@ def _cpp_jit(
for arg in args_flat:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
if jax.config.jax_dynamic_shapes:
in_type = pe.infer_lambda_input_type(None, args_flat)
flat_fun = lu.annotate(flat_fun, in_type)
out_flat = xla.xla_call(
flat_fun, *args_flat,
device=device, backend=backend, name=flat_fun.__name__,
@ -462,10 +482,12 @@ def _cpp_jit(
execute is not None and
execute.func is dispatch._execute_compiled and # not trivial, not pmap
# Not supported: ShardedDeviceArray
all(device_array.type_is_device_array(x) for x in out_flat))
all(device_array.type_is_device_array(x) for x in out_flat) and
# Not supported: dynamic shapes
not jax.config.jax_dynamic_shapes)
### If we can use the fastpath, we return required info to the caller.
if use_fastpath:
_, xla_executable, _, result_handlers, kept_var_idx = execute.args
_, xla_executable, _, _, result_handlers, kept_var_idx = execute.args
sticky_device = None
avals = []
lazy_exprs = [None] * len(result_handlers)
@ -900,7 +922,7 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
has_aux: bool = False, holomorphic: bool = False,
allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()
) -> Callable[..., Tuple[Any, Any]]:
"""Create a function that evaluates both ``fun`` and the gradient of ``fun``.
Args:
@ -1517,18 +1539,18 @@ def _mapped_axis_size(tree, vals, dims, name, *, kws=False):
raise ValueError(msg.format(f"the tree of axis sizes is:\n{sizes}")) from None
def pmap(
fun: Callable,
axis_name: Optional[AxisName] = None,
*,
in_axes=0,
out_axes=0,
static_broadcasted_argnums: Union[int, Iterable[int]] = (),
devices: Optional[Sequence[xc.Device]] = None,
backend: Optional[str] = None,
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> Any:
"""Parallel map with support for collective operations.
The purpose of :py:func:`pmap` is to express single-program multiple-data
@ -1864,7 +1886,7 @@ def _get_f_mapped(
axis_size: Optional[int],
donate_tuple: Tuple[int],
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]],
):
def pmap_f(*args, **kwargs):
p = _prepare_pmap(
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
@ -1914,7 +1936,7 @@ def _python_pmap(
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> stages.Wrapped:
"""The Python only implementation."""
axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes,
@ -1973,7 +1995,7 @@ def _cpp_pmap(
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> Any:
axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes,
out_axes)
@ -2140,7 +2162,7 @@ def shapecheck(in_shapes, out_shape, fun: Callable):
def jvp(
fun: Callable, primals, tangents, has_aux: bool = False
) -> Tuple[Any, ...]:
"""Computes a (forward-mode) Jacobian-vector product of ``fun``.
Args:
@ -2364,7 +2386,7 @@ else:
def vjp( # type: ignore
fun: Callable, *primals, has_aux: bool = False, reduce_axes=()
) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
"""Compute a (reverse-mode) vector-Jacobian product of ``fun``.
:py:func:`grad` is implemented as a special case of :py:func:`vjp`.
@ -2596,24 +2618,10 @@ def make_jaxpr(fun: Callable,
if abstracted_axes is None:
return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
else:
if kwargs: raise NotImplementedError
ax_leaf = lambda l: (isinstance(l, dict) and all_leaves(l.values()) or
isinstance(l, tuple) and all_leaves(l))
axes_specs = broadcast_prefix(abstracted_axes, args, ax_leaf)
sizes: Dict[Hashable, int] = {}
env: Dict[Hashable, core.AbstractValue] = {}
def make_aval(arg, spec):
if isinstance(spec, tuple):
spec = dict(zip(range(len(arg.shape)), spec))
if not spec: return shaped_abstractify(arg)
assert all(arg.shape[i] == sizes.setdefault(name, arg.shape[i])
for i, name in spec.items())
shape = [env.setdefault(spec[i], ShapedArray((), dtypes.dtype('int32')))
if i in spec else d for i, d in enumerate(arg.shape)]
return core.DShapedArray(tuple(shape), arg.dtype, False)
in_avals = map(make_aval, flat_args, axes_specs)
keep_inputs = [False] * len(env) + [True] * len(flat_args)
return [*env.values(), *in_avals], in_tree, keep_inputs
axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
in_avals, keep_inputs = unzip2(in_type)
return in_avals, in_tree, keep_inputs
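A hedged sketch of the same machinery through make_jaxpr: the spec is broadcast against the arguments and turned into an input type whose implicit axis-size inputs come first (keep_inputs False), ahead of the explicit array inputs:

import jax
import jax.numpy as jnp

jax.config.update('jax_dynamic_shapes', True)

# Abstract the leading axis of the argument under the name 'n'; the printed
# jaxpr should take an extra implicit size input ahead of the array input
# (exact output depends on this prototype).
print(jax.make_jaxpr(lambda x: x + 1, abstracted_axes=({0: 'n'},))(jnp.ones(3)))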
@wraps(fun)
@api_boundary
@ -3078,7 +3086,7 @@ def named_call(
fun: Callable[..., Any],
*,
name: Optional[str] = None,
) -> Callable[..., Any]:
"""Adds a user specified name to a function when staging out JAX computations.
When staging out computations for just-in-time compilation to XLA (or other

View File

@ -350,7 +350,7 @@ class Config:
Values included in this set should also most likely be included in
the C++ JIT state, which is handled separately."""
return (self.x64_enabled, self.jax_numpy_rank_promotion,
self.jax_default_matmul_precision)
self.jax_default_matmul_precision, self.jax_dynamic_shapes)
class _StateContextManager:
def __init__(self, name, help, update_thread_local_hook,
@ -424,6 +424,7 @@ already_configured_with_absl = False
class GlobalJitState(NamedTuple):
numpy_rank_promotion: Optional[str] = None
default_matmul_precision: Optional[Any] = None
dynamic_shapes: bool = False
def update_global_jit_state(**kw):
@ -436,6 +437,7 @@ class ThreadLocalJitState(NamedTuple):
dynamic_trace_state: Optional[Any] = None
numpy_rank_promotion: Optional[str] = None
default_matmul_precision: Optional[Any] = None
dynamic_shapes: bool = False
def update_thread_local_jit_state(**kw):
@ -700,7 +702,11 @@ config.define_bool_state(
name='jax_dynamic_shapes',
default=False,
help=('Enables experimental features for staging out computations with '
'dynamic shapes.'))
'dynamic shapes.'),
update_global_hook=lambda val: \
update_global_jit_state(dynamic_shapes=val),
update_thread_local_hook=lambda val: \
update_thread_local_jit_state(dynamic_shapes=val))
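Since the flag is a standard define_bool_state, it can be toggled in the usual ways; the new hooks above then mirror the value into the global and thread-local jit state so the C++ jit cache key stays in sync. A small sketch:

import jax
jax.config.update('jax_dynamic_shapes', True)   # programmatic toggle

# Equivalently, before the process starts (standard bool-state conventions):
#   JAX_DYNAMIC_SHAPES=1 python my_script.py
#   python my_script.py --jax_dynamic_shapes=true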
config.define_bool_state(
name='jax_experimental_name_stack',

View File

@ -20,7 +20,7 @@ from functools import partial
import itertools
import time
from typing import (
Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union)
Any, Callable, Dict, Optional, Sequence, Set, Tuple, List, Type, Union)
from typing_extensions import Protocol
import os
import re
@ -29,6 +29,7 @@ import warnings
from absl import logging
import numpy as np
import jax
from jax import core
from jax import linear_util as lu
from jax.errors import UnexpectedTracerError
@ -142,8 +143,11 @@ def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[De
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
del inline # Only used at tracing time
arg_specs = unsafe_map(arg_spec, args)
if fun.in_type is not None:
arg_specs = [(None, *xs) for _, *xs in arg_specs]
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(arg_spec, args))
*arg_specs)
try:
out = compiled_fun(*args)
except FloatingPointError:
@ -159,7 +163,7 @@ def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
# intentional here, to avoid "Store occupied" errors we clone the WrappedFun
# with empty stores.
stores = [lu.Store() for _ in fun.stores]
clone = lu.WrappedFun(fun.f, fun.transforms, stores, fun.params)
clone = lu.WrappedFun(fun.f, fun.transforms, stores, fun.params, fun.in_type)
with core.new_sublevel():
_ = clone.call_wrapped(*args) # probably won't return
return out
@ -193,21 +197,27 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))
abstract_args, arg_devices = util.unzip2(arg_specs)
if fun.in_type is not None:
abstract_args, which_explicit = util.unzip2(fun.in_type)
else:
which_explicit = None
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
fun, abstract_args, pe.debug_info_final(fun, "jit"))
fun, abstract_args, pe.debug_info_final(fun, "jit"), which_explicit)
if any(isinstance(c, core.Tracer) for c in consts):
raise UnexpectedTracerError("Encountered an unexpected tracer.")
jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
pruned_arg_specs = (a for i, a in enumerate(arg_specs) if i in kept_var_idx)
abstract_args, arg_devices = util.unzip2(pruned_arg_specs)
donated_invars = [
x for i, x in enumerate(donated_invars) if i in kept_var_idx
]
# TODO(mattjj): handle argument pruning w/ dynamic shapes
if fun.in_type is None:
jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
abstract_args, arg_devices = util.unzip2(
[a for i, a in enumerate(arg_specs) if i in kept_var_idx])
donated_invars = [x for i, x in enumerate(donated_invars) if i in kept_var_idx]
del kept_const_idx
else:
kept_var_idx = set(range(len(abstract_args)))
map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))
jaxpr = apply_outfeed_rewriter(jaxpr)
@ -215,12 +225,16 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
device = _xla_callable_device(nreps, backend, device, arg_devices)
backend = xb.get_device_backend(device) if device else xb.get_backend(backend)
if (config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr) and
backend.platform != 'iree'):
jaxpr, consts = pe.pad_jaxpr(jaxpr, consts)
# Computations that only produce constants and/or only rearrange their inputs,
# which are often produced from partial evaluation, don't need compilation,
# and don't need to evaluate their arguments.
if not jaxpr.eqns:
return XlaComputation(
name, None, True, None, jaxpr=jaxpr, consts=consts, device=device,
name, None, True, None, None, jaxpr=jaxpr, consts=consts, device=device,
in_avals=abstract_args, out_avals=out_avals, kept_var_idx=kept_var_idx)
if not _on_exit:
@ -260,9 +274,9 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
module_name, closed_jaxpr, backend.platform,
mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
return XlaComputation(
name, module, False, donated_invars, nreps=nreps, device=device,
backend=backend, tuple_args=tuple_args, in_avals=abstract_args,
out_avals=out_avals, kept_var_idx=kept_var_idx)
name, module, False, donated_invars, which_explicit, nreps=nreps,
device=device, backend=backend, tuple_args=tuple_args,
in_avals=abstract_args, out_avals=out_avals, kept_var_idx=kept_var_idx)
def prefetch(x):
@ -291,6 +305,14 @@ def jaxpr_has_pmap(jaxpr):
return True
return False
def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:
return (any(type(d) is core.Var for v in jaxpr.invars
if type(v.aval) is core.DShapedArray for d in v.aval.shape) or
any(type(d) is core.Var
for j in itertools.chain([jaxpr], core.subjaxprs(jaxpr))
for e in j.eqns for v in itertools.chain(e.invars, e.outvars)
if type(v.aval) is core.DShapedArray for d in v.aval.shape))
def _prune_unused_inputs(
jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
used = {v for v in jaxpr.outvars if isinstance(v, core.Var)}
@ -384,6 +406,59 @@ num_buffers_handlers[core.ShapedArray] = lambda _: 1
num_buffers_handlers[core.ConcreteArray] = lambda _: 1
def _input_handler(which_explicit: Optional[Sequence[bool]],
in_avals: Sequence[core.AbstractValue]
) -> Optional[Callable]:
# Extract implicit inputs, and pad bounded-size inputs to their max size.
needs_implicit = which_explicit and not all(which_explicit)
needs_padding = any(type(in_avals[d.val]) is core.AbstractBInt # type: ignore
for a in in_avals if type(a) is core.DShapedArray
for d in a.shape if type(d) is pe.DBIdx)
if not needs_implicit and not needs_padding:
return None
assert config.jax_dynamic_shapes
# Precompute how to grab implicit inputs from explicit inputs' axis sizes.
which_explicit = which_explicit or [True] * len(in_avals)
implicit_idxs = {i for i, ex in enumerate(which_explicit) if not ex}
implicit_args_from_axes: List[Tuple[int, int, int]] = []
for arg_idx, aval in enumerate(in_avals):
if isinstance(aval, core.DShapedArray):
for axis_idx, d in enumerate(aval.shape):
if isinstance(d, pe.DBIdx) and d.val in implicit_idxs:
implicit_args_from_axes.append((d.val, arg_idx, axis_idx))
assert {i for i, _, _ in implicit_args_from_axes} == implicit_idxs
# Precompute how to pad bounded-size inputs to their max size.
def needs_pad(a: core.AbstractValue) -> bool:
return (type(a) is core.DShapedArray and
any(type(d) is pe.DBIdx for d in a.shape))
def padshape(a: core.DShapedArray) -> List[int]:
return [in_avals[d.val].bound if type(d) is pe.DBIdx and # type: ignore
type(in_avals[d.val]) is core.AbstractBInt else d for d in a.shape] # type: ignore
padders = [partial(jax.jit(_pad_arg, static_argnums=0), tuple(padshape(aval))) # type: ignore
if needs_pad(aval) else None for aval in in_avals]
def elaborate_and_pad(explicit_args):
explicit_args_ = iter(explicit_args)
args = [next(explicit_args_) if ex else None for ex in which_explicit]
assert next(explicit_args_, None) is None
for i, j, k in implicit_args_from_axes:
if args[i] is None:
args[i] = args[j].shape[k] # type: ignore
else:
if args[i] != args[j].shape[k]:
raise Exception("inconsistent argument axis sizes for type")
return tuple([pad(x) if pad else x for x, pad in zip(args, padders)])
return elaborate_and_pad
def _pad_arg(shape, x):
zeros = jax.lax.full(shape, 0, x.dtype)
return jax.lax.dynamic_update_slice(zeros, x, (0,) * len(shape))
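To make the padding step concrete, a sketch of what _pad_arg does (assuming the definitions above): a bounded-size argument with bound 5 but runtime size 3 is written into a zero-filled buffer of the maximal shape, and downstream padding rules then mask the tail.

import jax.numpy as jnp
from jax import lax

x = jnp.arange(1, 4)                        # logical size 3
zeros = lax.full((5,), 0, x.dtype)          # buffer at the bound, size 5
padded = lax.dynamic_update_slice(zeros, x, (0,))
# padded == [1, 2, 3, 0, 0]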
if MYPY:
ResultHandler = Any
else:
@ -405,6 +480,13 @@ def array_result_handler(sticky_device: Optional[Device],
return partial(device_array.make_device_array, core.raise_to_shaped(aval),
sticky_device)
def dynamic_array_result_handler(sticky_device: Optional[Device],
aval: core.DShapedArray):
if aval.dtype is dtypes.float0:
return lambda _: np.zeros(aval.shape, dtypes.float0) # type: ignore
else:
raise NotImplementedError
result_handlers: Dict[
Type[core.AbstractValue],
@ -412,6 +494,7 @@ result_handlers: Dict[
result_handlers[core.AbstractUnit] = lambda _, __: lambda _: core.unit
result_handlers[core.AbstractToken] = lambda _, __: lambda _: core.token
result_handlers[core.ShapedArray] = array_result_handler
result_handlers[core.DShapedArray] = dynamic_array_result_handler
result_handlers[core.ConcreteArray] = array_result_handler
@ -433,9 +516,11 @@ def _check_special(name, xla_shape, buf):
def _execute_compiled(name: str, compiled: XlaExecutable,
input_handler: Optional[Callable],
output_buffer_counts: Optional[Sequence[int]],
result_handlers, kept_var_idx, *args):
device, = compiled.local_devices()
args = input_handler(args) if input_handler else args
input_bufs_flat = flatten(device_put(x, device) for i, x in enumerate(args)
if i in kept_var_idx)
out_bufs_flat = compiled.execute(input_bufs_flat)
@ -447,8 +532,10 @@ def _execute_compiled(name: str, compiled: XlaExecutable,
def _execute_replicated(name: str, compiled: XlaExecutable,
input_handler: Optional[Callable],
output_buffer_counts: Optional[Sequence[int]],
result_handlers, kept_var_idx, *args):
if input_handler: raise NotImplementedError # TODO(mattjj, dougalm)
input_bufs = [flatten(device_put(x, device) for i, x in enumerate(args)
if i in kept_var_idx)
for device in compiled.local_devices()]
@ -477,16 +564,18 @@ def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
class XlaComputation(stages.Computation):
name: str
_is_trivial: bool
_executable: Optional['XlaCompiledComputation']
_executable: Optional[XlaCompiledComputation]
_donated_invars: Optional[Sequence[bool]]
def __init__(self, name: str, hlo, is_trivial: bool,
donated_invars: Optional[Sequence[bool]],
explicit_args: Optional[Sequence[bool]],
**compile_args):
self.name = name
self._hlo = hlo
self._is_trivial = is_trivial
self._donated_invars = donated_invars
self._explicit_args = explicit_args
self._executable = None
self.compile_args = compile_args
@ -511,14 +600,14 @@ class XlaComputation(stages.Computation):
return ir.Module.parse(module_str)
return self._hlo
def compile(self) -> 'XlaCompiledComputation':
def compile(self) -> XlaCompiledComputation:
if self._executable is None:
if self.is_trivial():
self._executable = XlaCompiledComputation.from_trivial_jaxpr(
**self.compile_args)
else:
self._executable = XlaCompiledComputation.from_xla_computation(
self.name, self._hlo, **self.compile_args)
self.name, self._hlo, self._explicit_args, **self.compile_args)
return self._executable
@ -586,6 +675,7 @@ class XlaCompiledComputation(stages.Executable):
def from_xla_computation(
name: str,
xla_computation: Optional[ir.Module],
explicit_args: Optional[Sequence[bool]],
nreps: int,
device: Optional[Device],
backend: Backend,
@ -594,6 +684,7 @@ class XlaCompiledComputation(stages.Executable):
out_avals: Sequence[core.AbstractValue],
kept_var_idx: Set[int]) -> XlaCompiledComputation:
sticky_device = device
input_handler = _input_handler(explicit_args, in_avals)
result_handlers = map(partial(aval_to_result_handler, sticky_device),
out_avals)
options = xb.get_compile_options(
@ -603,10 +694,10 @@ class XlaCompiledComputation(stages.Executable):
with log_elapsed_time(f"Finished XLA compilation of {name} "
"in {elapsed_time} sec"):
compiled = compile_or_get_cached(backend, xla_computation, options)
buffer_counts = (None if len(out_avals) == 1 else
[aval_to_num_buffers(aval) for aval in out_avals])
buffer_counts = (None if len(out_avals) == 1 and not config.jax_dynamic_shapes
else [aval_to_num_buffers(aval) for aval in out_avals])
execute = _execute_compiled if nreps == 1 else _execute_replicated
unsafe_call = partial(execute, name, compiled, buffer_counts,
unsafe_call = partial(execute, name, compiled, input_handler, buffer_counts,
result_handlers, kept_var_idx)
return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call)
@ -622,7 +713,7 @@ class XlaCompiledComputation(stages.Executable):
@staticmethod
def from_trivial_jaxpr(jaxpr, consts, device, in_avals, out_avals,
kept_var_idx) -> 'XlaCompiledComputation':
kept_var_idx) -> XlaCompiledComputation:
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
unsafe_call = partial(_execute_trivial, jaxpr, device, consts,
out_avals, result_handlers, kept_var_idx)

View File

@ -1719,12 +1719,7 @@ def _combine_leading(sz0, sz1, aval, x):
return lax.collapse(x, 0, 2)
def _prepend_dim_to_aval(sz, aval):
if aval is core.abstract_unit:
return aval
elif isinstance(aval, ShapedArray):
return aval.update(shape=(sz, *aval.shape), weak_type=False)
else:
raise TypeError(f'Prepending dim {sz} to aval {aval}')
return core.unmapped_aval(sz, core.no_axis_name, 0, aval)
def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
linear, unroll):
@ -2063,6 +2058,10 @@ def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
return _make_closed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals)
def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params):
padded_jaxpr = core.ClosedJaxpr(*pe.pad_jaxpr(jaxpr.jaxpr, jaxpr.consts))
return scan_p.bind(*args, jaxpr=padded_jaxpr, **params)
def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
jaxpr, linear, unroll):
tc = partial(_typecheck_param, 'scan')
@ -2133,6 +2132,7 @@ masking.masking_rules[scan_p] = _scan_masking_rule
core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
pe.partial_eval_jaxpr_custom_rules[scan_p] = \
partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'scan')
pe.padding_rules[scan_p] = _scan_padding_rule

View File

@ -1970,6 +1970,7 @@ integer_pow_p = standard_primitive(
batching.defvectorized(integer_pow_p)
masking.defvectorized(integer_pow_p)
ad.defjvp(integer_pow_p, _integer_pow_jvp)
pe.padding_rules[integer_pow_p] = lambda _, __, x, y: [integer_pow_p.bind(x, y=y)]
def _integer_pow(x, *, y):
# This should be kept in sync with the jax2tf translation rule.
@ -2061,6 +2062,7 @@ add_p: Primitive = standard_naryop([_num, _num], 'add')
ad.primitive_jvps[add_p] = _add_jvp
ad.primitive_transposes[add_p] = _add_transpose
mlir.register_lowering(add_p, partial(_nary_lower_mhlo, mhlo.AddOp))
pe.padding_rules[add_p] = lambda _, __, x, y: [add(x, y)]
def _sub_jvp(primals, tangents):
x, y = primals
@ -2090,6 +2092,7 @@ sub_p = standard_naryop([_num, _num], 'sub')
ad.primitive_jvps[sub_p] = _sub_jvp
ad.primitive_transposes[sub_p] = _sub_transpose
mlir.register_lowering(sub_p, partial(_nary_lower_mhlo, mhlo.SubOp))
pe.padding_rules[sub_p] = lambda _, __, x, y: [sub(x, y)]
def _mul_transpose(ct, x, y):
@ -2116,6 +2119,7 @@ ad.defjvp(mul_p,
lambda ydot, x, y: mul(x, ydot))
ad.primitive_transposes[mul_p] = _mul_transpose
mlir.register_lowering(mul_p, partial(_nary_lower_mhlo, mhlo.MulOp))
pe.padding_rules[mul_p] = lambda _, __, x, y: [mul(x, y)]
def _div_transpose_rule(cotangent, x, y):
assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y)
@ -2166,6 +2170,7 @@ ad.defjvp2(max_p,
lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
mlir.register_lowering(max_p, partial(_nary_lower_mhlo, mlir.max_mhlo))
pe.padding_rules[max_p] = lambda _, __, x, y: [max(x, y)]
min_p: core.Primitive = standard_naryop(
[_any, _any], 'min', translation_rule=partial(
@ -2281,7 +2286,8 @@ def _convert_elt_type_fwd_rule(eqn):
return [None], eqn
def _convert_elt_type_pp_rule(eqn, context, settings):
# don't print new_dtype because the output binder shows it
# don't print new_dtype because the output binder shows it, don't print
# weak_type when false
printed_params = {}
if eqn.params['weak_type']:
printed_params['weak_type'] = True
@ -2595,6 +2601,27 @@ def _dot_general_masking_rule(padded_vals, logical_shapes, *, dimension_numbers,
rhs, dimension_numbers, precision=precision,
preferred_element_type=preferred_element_type)
def _dot_general_padding_rule(in_avals, out_avals, lhs, rhs, *,
dimension_numbers, **params):
lhs_aval, _ = in_avals
(lhs_contract, _), _ = dimension_numbers
padded_axes = [(i, lhs_aval.shape[i].val) for i in lhs_contract
if isinstance(lhs_aval.shape[i], pe.BoundedAxisSize)]
lhs_ = _replace_masked_values(lhs, 0, padded_axes)
return [dot_general(lhs_, rhs, dimension_numbers=dimension_numbers, **params)]
def _dot_general_pp_rule(eqn, context, settings):
# suppress printing precision or preferred_element_type when None.
# print dimension_numbers as list-of-lists to be shorter.
printed_params = {k: v for k, v in eqn.params.items() if v is not None}
(lhs_cont, rhs_cont), (lhs_batch, rhs_batch) = eqn.params['dimension_numbers']
printed_params['dimension_numbers'] = (
(list(lhs_cont), list(rhs_cont)), (list(lhs_batch), list(rhs_batch)))
return [pp.text(eqn.primitive.name),
core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
pp.text(" ") + core.pp_vars(eqn.invars, context)]
dot_general_p = standard_primitive(_dot_general_shape_rule,
_dot_general_dtype_rule, 'dot_general',
_dot_general_translation_rule)
@ -2604,6 +2631,9 @@ batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule
masking.masking_rules[dot_general_p] = _dot_general_masking_rule
xla.register_translation(dot_general_p, _dot_general_cpu_translation_rule,
platform="cpu")
pe.padding_rules[dot_general_p] = _dot_general_padding_rule
# TODO(mattjj): un-comment the next line
# core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule
def precision_attr(precision: PrecisionType) -> ir.ArrayAttr:
if precision is None:
@ -2728,12 +2758,31 @@ def _broadcast_in_dim_staging_rule(
return out_tracer
def _broadcast_in_dim_padding_rule(in_avals, out_avals, x, *dyn_shape,
shape, broadcast_dimensions):
del in_avals, dyn_shape
out_aval, = out_avals
new_shape = []
new_dyn_shape = []
for d in out_aval.shape:
if type(d) is pe.BoundedAxisSize:
new_shape.append(d.bound)
elif type(d) is int:
new_shape.append(d)
else:
assert isinstance(d, core.Tracer)
new_shape.append(None)
new_dyn_shape.append(d)
return [broadcast_in_dim_p.bind(x, *new_dyn_shape, shape=new_shape,
broadcast_dimensions=broadcast_dimensions)]
broadcast_in_dim_p = standard_primitive(
_broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim')
ad.deflinear2(broadcast_in_dim_p, _broadcast_in_dim_transpose_rule)
batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
pe.forwarding_rules[broadcast_in_dim_p] = _broadcast_in_dim_fwd_rule
pe.custom_staging_rules[broadcast_in_dim_p] = _broadcast_in_dim_staging_rule
pe.padding_rules[broadcast_in_dim_p] = _broadcast_in_dim_padding_rule
def _broadcast_in_dim_lower(ctx, x, *, shape, broadcast_dimensions):
del shape
@ -2809,6 +2858,7 @@ ad.defjvp(clamp_p,
batching.primitive_batchers[clamp_p] = _clamp_batch_rule
mlir.register_lowering(
clamp_p, partial(_nary_lower_mhlo, mhlo.ClampOp, explicit_type=True))
pe.padding_rules[clamp_p] = lambda _, __, a, x, b: [clamp(a, x, b)]
def _concatenate_shape_rule(*operands, **kwargs):
dimension = kwargs.pop('dimension')
@ -3559,6 +3609,21 @@ def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
assert result.shape == input_shape
return [result]
def _reduce_sum_padding_rule(in_avals, out_avals, operand, *, axes):
del out_avals
aval, = in_avals
padded_axes = [(i, d.val) for i, d in enumerate(aval.shape)
if isinstance(d, pe.BoundedAxisSize)]
masked_operand = _replace_masked_values(operand, 0, padded_axes)
return [_reduce_sum(masked_operand, axes)]
def _replace_masked_values(x, val, padded_axes):
if not padded_axes: return x
masks = [broadcasted_iota(np.dtype('int32'), x.shape, i) < d
for i, d in padded_axes]
return select(_reduce(operator.and_, masks), x, full_like(x, val))
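A small sketch of the masking idea behind _replace_masked_values and the reduce_sum padding rule above: entries past the logical size are replaced with the reduction's identity before reducing, so the physical padding cannot change the result.

import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.array([1., 2., 3., 9., 9.])          # physical size 5, logical size 3
mask = lax.broadcasted_iota(np.dtype('int32'), x.shape, 0) < 3
masked = lax.select(mask, x, lax.full_like(x, 0.))
print(jnp.sum(masked))                        # 6.0; the padded tail is ignored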
reduce_sum_p = standard_primitive(
_reduce_sum_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'),
'reduce_sum', _reduce_sum_translation_rule)
@ -3566,6 +3631,7 @@ ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
batching.defreducer(reduce_sum_p)
_masking_defreducer(reduce_sum_p,
lambda shape, dtype: np.broadcast_to(np.array(0, dtype), shape))
pe.padding_rules[reduce_sum_p] = _reduce_sum_padding_rule
def _reduce_op_shape_rule(operand, *, axes, input_shape=None):
@ -4409,6 +4475,18 @@ def _iota_lower(ctx, *, dtype, shape, dimension):
mlir.register_lowering(iota_p, _iota_lower)
def make_bint(i, bd: int):
return bint_p.bind(i, bd=bd)
bint_p = core.Primitive('bint')
@bint_p.def_abstract_eval
def bint_abstract_eval(_, *, bd: int):
return core.AbstractBInt(bound=bd)
pe.padding_rules[bint_p] = lambda _, __, i, bd: [i]
### util
_ndim = np.ndim

View File

@ -22,10 +22,12 @@ import numpy as np
from jax import core
from jax._src import ad_util
from jax._src import dtypes
from jax._src import source_info_util
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import masking
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax._src.lax.utils import (
_argnum_weak_type,
@ -2113,3 +2115,30 @@ def _dynamic_slice_indices(operand, start_indices: Any):
lax.add(i, lax.convert_element_type(core.dimension_as_value(d), lax._dtype(i))),
i)
for i, d in zip(start_indices, operand.shape)]
def _getslice(x, lo, hi):
return getslice_p.bind(x, lo, hi)
getslice_p = core.Primitive('getslice')
@getslice_p.def_impl
def getslice_impl(x, lo, hi):
return x[lo:hi]
def _getslice_staging_rule(trace, x, lo, hi):
size = lax.make_bint(lax.clamp(0, hi - lo, x.shape[0]), x.shape[0])
aval = core.DShapedArray((size,), x.dtype, x.weak_type)
source_info = source_info_util.current()
out_tracer = pe.DynamicJaxprTracer(trace, aval, source_info)
invars = map(trace.getvar, [x, lo, hi])
eqn = pe.new_jaxpr_eqn(invars, [trace.makevar(out_tracer)],
getslice_p, {}, source_info)
trace.frame.eqns.append(eqn)
return out_tracer
pe.custom_staging_rules[getslice_p] = _getslice_staging_rule
def _getslice_padding_rule(in_avals, out_avals, x, lo, hi):
xx = lax.concatenate([x, x], 0)
return [dynamic_slice_in_dim(xx, lo, x.shape[0])]
pe.padding_rules[getslice_p] = _getslice_padding_rule

View File

@ -67,8 +67,10 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
dtype_rule(*avals, **kwargs), weak_type=weak_type,
named_shape=named_shape_rule(*avals, **kwargs))
elif least_specialized is core.DShapedArray:
return core.DShapedArray(shape_rule(*avals, **kwargs),
dtype_rule(*avals, **kwargs), weak_type)
shape = shape_rule(*avals, **kwargs)
ty = (core.ShapedArray if all(type(d) is int for d in shape)
else core.DShapedArray)
return ty(shape, dtype_rule(*avals, **kwargs), weak_type)
elif least_specialized is core.UnshapedArray:
return core.UnshapedArray(dtype_rule(*avals, **kwargs), weak_type=weak_type)
else:

View File

@ -48,8 +48,10 @@ from jax.tree_util import tree_leaves, tree_flatten, tree_map
from jax._src import device_array
from jax._src import dtypes
from jax._src.api_util import _ensure_index_tuple
from jax._src.lax.lax import _array_copy, _sort_lt_comparator, _sort_le_comparator
from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
_sort_le_comparator)
from jax._src.lax import lax as lax_internal
from jax._src.lax.slicing import _getslice
from jax._src.numpy.ndarray import ndarray
from jax._src.numpy.reductions import ( # noqa: F401
_ensure_optional_axes, _reduction_dims,
@ -3479,6 +3481,15 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
# All supported cases of indexing can be implemented as an XLA gather,
# followed by an optional reverse and broadcast_in_dim.
arr = asarray(arr)
# TODO(mattjj,dougalm): expand dynamic shape indexing support
if (jax.config.jax_dynamic_shapes and type(idx) is slice and idx.step is None
and (isinstance(idx.start, core.Tracer) or isinstance(idx.stop, core.Tracer))
and arr.shape):
start = 0 if idx.start is None else idx.start
stop = arr.shape[0] if idx.stop is None else idx.stop
return _getslice(arr, start, stop)
treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
unique_indices, mode, fill_value)
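A hedged sketch of the new indexing branch above: with jax_dynamic_shapes on, a basic slice whose stop is a tracer is routed through _getslice instead of the gather path, yielding an output whose leading axis is a bounded dynamic size (how far the rest of the pipeline carries it is still prototype territory).

import jax
import jax.numpy as jnp

jax.config.update('jax_dynamic_shapes', True)

def take_prefix(x, n):
  return x[:n]   # n is a Tracer here, so this hits the _getslice branch

# Expect a jaxpr containing a getslice equation (exact form may vary).
print(jax.make_jaxpr(take_prefix)(jnp.arange(5.), 3))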

View File

@ -89,7 +89,7 @@ def treedef_children(treedef):
def treedef_is_leaf(treedef):
return treedef.num_nodes == 1
def all_leaves(iterable):
def all_leaves(iterable, is_leaf: Optional[Callable[[Any], bool]] = None):
"""Tests whether all elements in the given iterable are all leaves.
>>> tree = {"a": [1, 2, 3]}
@ -106,7 +106,11 @@ def all_leaves(iterable):
Returns:
A boolean indicating if all elements in the input are leaves.
"""
return pytree.all_leaves(iterable)
if is_leaf is None:
return pytree.all_leaves(iterable)
else:
lst = list(iterable)
return lst == tree_leaves(lst, is_leaf)
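A quick illustration of the new is_leaf parameter, which the abstracted-axes spec handling relies on so that None entries in a tuple spec count as leaves:

from jax.tree_util import all_leaves

all_leaves(('n', None))                        # False: None is an empty pytree node
all_leaves(('n', None), lambda x: x is None)   # True: None is treated as a leaf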
_Children = TypeVar("_Children", bound=Iterable[Any])

View File

@ -1034,6 +1034,18 @@ class AbstractUnit(AbstractValue):
abstract_unit = AbstractUnit()
class AbstractBInt(AbstractValue):
__slots__ = ['bound']
bound: int
def __init__(self, bound):
self.bound = bound
def str_short(self, short_dtypes=False) -> str:
return f'bint{{{self.bound}}}[]'
def __eq__(self, other):
return type(other) is AbstractBInt and self.bound == other.bound
def __hash__(self) -> int:
return hash((type(self), self.bound))
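From the definition above (a sketch, assuming AbstractBInt is in scope): a bounded-int abstract value is identified purely by its bound, and prints with the bound in braces.

a = AbstractBInt(bound=5)
assert a == AbstractBInt(bound=5)
assert hash(a) == hash(AbstractBInt(bound=5))
assert a.str_short() == 'bint{5}[]'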
def lattice_join(x: Optional[AbstractValue],
y: Optional[AbstractValue]) -> AbstractValue:
if x is None:
@ -1212,7 +1224,7 @@ AxisSize = Union[AxisSizeForTracing, AxisSizeForJaxprType,
class DShapedArray(UnshapedArray):
__slots__ = ['shape']
shape: Tuple[AxisSize, ...] # noqa: F821
array_abstraction_level = 2
array_abstraction_level: int = 2
def __init__(self, shape, dtype, weak_type):
self.shape = shape
@ -1229,11 +1241,6 @@ class DShapedArray(UnshapedArray):
return f'{dtype}[{shape}]'
__str__ = __repr__ = str_short
def __eq__(self, other):
return (type(self) is type(other) and
self.dtype == other.dtype and self.shape == other.shape and
self.weak_type == other.weak_type)
def update(self, shape=None, dtype=None, weak_type=None):
if shape is None:
shape = self.shape
@ -1243,6 +1250,14 @@ class DShapedArray(UnshapedArray):
weak_type = self.weak_type
return DShapedArray(shape, dtype, weak_type)
def __eq__(self, other):
return (type(self) is type(other)
and self.dtype == other.dtype and self.shape == other.shape
and self.weak_type == other.weak_type)
def __hash__(self):
return hash((self.shape, self.dtype, self.weak_type))
del AxisSize, AxisSizeForTracing, AxisSizeForJaxprType, \
AxisSizeForJaxprTracingSpec
@ -1415,6 +1430,7 @@ def raise_to_shaped(aval: AbstractValue, weak_type=None):
raise_to_shaped_mappings : Dict[type, Callable] = {
AbstractUnit: lambda aval, _: aval,
AbstractBInt: lambda aval, _: aval,
AbstractToken: lambda aval, _: aval,
Bot: lambda aval, _: aval,
UnshapedArray: lambda aval, _: aval,
@ -1758,14 +1774,17 @@ class CallPrimitive(Primitive):
new_params = dict(params)
jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
if config.jax_dynamic_shapes:
subfun = lu.annotate(subfun, tuple((v.aval, True) for v in jaxpr.invars))
return [subfun], new_params
def call_bind(primitive: CallPrimitive, fun, *args, **params):
top_trace = find_top_trace(args)
fun, env_trace_todo = process_env_traces_call(
fun_, env_trace_todo = process_env_traces_call(
fun, primitive, top_trace and top_trace.level, tuple(params.items()))
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_call(primitive, fun, tracers, params)
fun_ = lu.annotate(fun_, fun.in_type)
outs = top_trace.process_call(primitive, fun_, tracers, params)
return map(full_lower, apply_todos(env_trace_todo(), outs))
@lu.transformation_with_aux
@ -1931,13 +1950,28 @@ def _unmap_shaped_array(size: int, axis_name, axis: Optional[int],
named_shape = dict(aval.named_shape)
# TODO: Make this mandatory
named_shape.pop(axis_name, None)
if axis is None: return aval.replace(named_shape=named_shape)
if axis is None: return aval.update(named_shape=named_shape)
return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
named_shape=named_shape, weak_type=aval.weak_type)
def _map_dshaped_array(size: Union[int, Tracer], axis: Optional[int],
aval: ShapedArray) -> ShapedArray:
assert False # TODO(mattjj, dougalm)
def _unmap_dshaped_array(
size: Union[int, Tracer], axis_name, axis: Optional[int],
aval: DShapedArray) -> DShapedArray:
if isinstance(size, int):
if axis is None: return aval
return DShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
weak_type=aval.weak_type)
else:
assert False # TODO(mattjj, dougalm)
AvalMapHandlerPair = Tuple[Callable, Callable]
aval_mapping_handlers: Dict[Type, AvalMapHandlerPair] = {
AbstractUnit: (_map_unit, _map_unit),
DShapedArray: (_map_dshaped_array, _unmap_dshaped_array),
ShapedArray: (_map_shaped_array, _unmap_shaped_array),
ConcreteArray: (_map_shaped_array, _unmap_shaped_array),
}

View File

@ -16,7 +16,7 @@ import contextlib
import functools
from functools import partial
import itertools as it
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, List, Tuple, Optional
import jax
from jax.interpreters import partial_eval as pe
@ -24,7 +24,7 @@ from jax.config import config
from jax import core
from jax._src.dtypes import dtype, float0
from jax.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
raise_to_shaped)
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
zeros_like_aval, zeros_like_p, Zero)
from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name,
@ -39,6 +39,16 @@ zip = safe_zip
map = safe_map
def identity(x): return x
def _update_annotation(
f: lu.WrappedFun,
orig_type: Optional[Tuple[Tuple[core.AbstractValue, bool], ...]],
nonzeros: List[bool]
) -> lu.WrappedFun:
if orig_type is None:
return f
tan_types = [(aval.at_least_vspace(), keep)
for nz, (aval, keep) in zip(nonzeros, orig_type) if nz]
return lu.annotate(f, (*orig_type, *tan_types))
def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True,
transform_stack=True) -> Any:
@ -217,7 +227,6 @@ def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack,
with ctx:
map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
for eqn in jaxpr.eqns[::-1]:
# FIXME: Some invars correspond to tangents
invals = map(read_primal, eqn.invars)
if eqn.primitive.multiple_results:
cts_in = map(read_cotangent, eqn.outvars)
@ -244,8 +253,10 @@ def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack,
cotangents_out = map(read_cotangent, jaxpr.invars)
return cotangents_out
def closed_backward_pass(jaxpr: core.ClosedJaxpr, reduce_axes, transform_stack, primals_in, cotangents_in):
return backward_pass(jaxpr.jaxpr, reduce_axes, transform_stack, jaxpr.consts, primals_in, cotangents_in)
def closed_backward_pass(jaxpr: core.ClosedJaxpr, reduce_axes, transform_stack,
primals_in, cotangents_in):
return backward_pass(jaxpr.jaxpr, reduce_axes, transform_stack, jaxpr.consts,
primals_in, cotangents_in)
class UndefinedPrimal:
@ -301,7 +312,7 @@ class JVPTrace(Trace):
else:
return JVPTracer(self, primal_out, tangent_out)
def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
def process_call(self, call_primitive, f, tracers, params):
assert call_primitive.multiple_results
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
nonzero_tangents, tangent_tree_def = tree_flatten(tangents)
@ -329,6 +340,7 @@ class JVPTrace(Trace):
f_jvp, out_tree_def = traceable(f_jvp, len(primals), tangent_tree_def)
update_params = call_param_updaters.get(call_primitive)
new_params = update_params(params, nz_tangents) if update_params else params
f_jvp = _update_annotation(f_jvp, f.in_type, nz_tangents)
result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
@ -590,15 +602,16 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
reduce_axes, False)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
if config.jax_experimental_name_stack:
new_params = params
else:
new_params = dict(params, name=wrap_name(params['name'], 'transpose'))
if not config.jax_experimental_name_stack:
params = dict(params, name=wrap_name(params['name'], 'transpose'))
update_params = call_transpose_param_updaters.get(primitive)
if update_params:
new_params = update_params(new_params, map(is_undefined_primal, args),
[type(x) is not Zero for x in ct])
out_flat = primitive.bind(fun, *all_args, **new_params)
params = update_params(params, map(is_undefined_primal, args),
[type(x) is not Zero for x in ct])
if config.jax_dynamic_shapes:
in_type = [(core.raise_to_shaped(core.get_aval(x)), True) for x in all_args]
fun = lu.annotate(fun, tuple(in_type))
out_flat = primitive.bind(fun, *all_args, **params)
return tree_unflatten(out_tree(), out_flat)
primitive_transposes[core.call_p] = partial(call_transpose, call_p)

View File

@ -14,7 +14,7 @@
from functools import partial
from typing import (Any, Callable, Dict, Set, Optional, Tuple, Union, Iterable,
Type)
Type, Sequence)
import numpy as np
@ -34,6 +34,17 @@ from jax.interpreters import partial_eval as pe
map = safe_map
def _update_annotation(
f: lu.WrappedFun,
orig_type: Optional[Tuple[Tuple[core.AbstractValue, bool], ...]],
axis_size: int, axis_name: core.AxisName, in_dims: Sequence[Optional[int]]
) -> lu.WrappedFun:
if orig_type is None:
return f
batched_in_type = [(core.unmapped_aval(axis_size, axis_name, dim, aval), keep)
for dim, (aval, keep) in zip(in_dims, orig_type)]
return lu.annotate(f, tuple(batched_in_type))
### vmappable typeclass
Vmappable = Any
@ -175,11 +186,9 @@ class BatchTrace(Trace):
def get_frame(self, vals, dims) -> core.AxisEnvFrame:
if self.axis_name is core.no_axis_name:
# If axis name is `no_axis_name` we can't find it via `core.axis_name` so we
# reconstruct it from the information we have available
axis_sizes = {x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}
assert len(axis_sizes) == 1
axis_size, = axis_sizes
# If axis name is `no_axis_name` we can't find it via `core.axis_name` so
# we reconstruct it from the information we have available
axis_size, = {x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}
return core.AxisEnvFrame(self.axis_name, axis_size, self.main)
return core.axis_frame(self.axis_name)
@ -204,7 +213,7 @@ class BatchTrace(Trace):
else:
return BatchTracer(self, val_out, dim_out, src)
def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
def process_call(self, call_primitive, f, tracers, params):
assert call_primitive.multiple_results
if config.jax_experimental_name_stack:
params = dict(params, name=params.get('name', f.__name__))
@ -214,8 +223,10 @@ class BatchTrace(Trace):
if all(bdim is not_mapped for bdim in dims):
return call_primitive.bind(f, *vals, **params)
else:
f, dims_out = batch_subtrace(f, self.main, dims)
vals_out = call_primitive.bind(f, *vals, **params)
f_, dims_out = batch_subtrace(f, self.main, dims)
ax_size, = {x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}
f_ = _update_annotation(f_, f.in_type, ax_size, self.axis_name, dims)
vals_out = call_primitive.bind(f_, *vals, **params)
src = source_info_util.current()
return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out())]

View File

@ -118,6 +118,13 @@ def dtype_to_ir_type(dtype: Union[np.dtype, np.generic]) -> ir.Type:
def _array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),)
def _dynamic_array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
shape = [d if type(d) is int else -1 for d in aval.shape]
return (ir.RankedTensorType.get(shape, dtype_to_ir_type(aval.dtype)),)
def _bint_ir_types(aval: core.AbstractBInt) -> Sequence[ir.Type]:
return (ir.RankedTensorType.get((), dtype_to_ir_type(dtypes.dtype('int32'))),)
ir_type_handlers: Dict[Type[core.AbstractValue],
Callable[[Any], Sequence[ir.Type]]] = {}
@ -132,9 +139,11 @@ def aval_to_ir_types(aval: core.AbstractValue) -> Sequence[ir.Type]:
raise TypeError(f"No ir_type_handler for aval type: {type(aval)}") from err
ir_type_handlers[core.AbstractUnit] = lambda _: ()
ir_type_handlers[core.AbstractBInt] = _bint_ir_types
ir_type_handlers[core.ShapedArray] = _array_ir_types
ir_type_handlers[core.ConcreteArray] = _array_ir_types
ir_type_handlers[core.AbstractToken] = lambda _: [mhlo.TokenType.get()]
ir_type_handlers[core.DShapedArray] = _dynamic_array_ir_types
def aval_to_ir_type(aval: core.AbstractValue) -> ir.Type:
"""Convenience wrapper around aval_to_ir_types for single types.

View File

@ -14,13 +14,14 @@
from collections import namedtuple
import contextlib
from dataclasses import dataclass
import functools
from functools import partial
import inspect
import itertools as it
import operator as op
from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple,
List, Union, Set, cast)
List, Union, Hashable, cast)
from weakref import ref
import numpy as np
@ -38,15 +39,24 @@ from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
as_hashable_function, weakref_lru_cache)
from jax.core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
unit, unitvar, abstract_unit, ClosedJaxpr, new_jaxpr_eqn,
ConcreteArray, raise_to_shaped, Var, Atom, JaxprEqn,
Primitive, DShapedArray, mapped_aval, unmapped_aval)
ConcreteArray, raise_to_shaped, Var, DropVar, Atom,
JaxprEqn, Primitive, ShapedArray, DShapedArray,
AbstractBInt, mapped_aval, unmapped_aval)
from jax._src import source_info_util
from jax.config import config
map = safe_map
zip = safe_zip
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
def identity(x): return x
def _update_annotation(
f: lu.WrappedFun,
orig_type: Optional[Tuple[Tuple[AbstractValue, bool], ...]],
in_knowns: List[bool]) -> lu.WrappedFun:
if orig_type is None:
return f
return lu.annotate(f, tuple([ty for k, ty in zip(in_knowns, orig_type) if k]))
class PartialVal(tuple):
"""Partial value: either a known value or an unknown (abstract) value.
@ -182,7 +192,7 @@ class JaxprTrace(Trace):
params, effects, source)
return out_tracer
def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
def process_call(self, primitive, f, tracers, params):
if primitive in call_partial_eval_rules:
return call_partial_eval_rules[primitive](self, primitive, f, tracers, params)
@ -197,13 +207,14 @@ class JaxprTrace(Trace):
# which were unknown to the first call (corresponding to in_avals).
# Wrap f to perform the partial evaluation and plumb out aux data.
f = trace_to_subjaxpr_nounits(f, self.main, False)
f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals))
f_ = trace_to_subjaxpr_nounits(f, self.main, False)
f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns), tuple(in_avals))
# Adjust parameters (e.g. donated_invars) for the call to be evaluated now.
const_params = update_params(params, in_knowns, 0)
# Run the call, getting known out vals and aux data used for staged-out call
out = primitive.bind(f, *in_consts, **const_params)
out = primitive.bind(_update_annotation(f_, f.in_type, in_knowns),
*in_consts, **const_params)
out_knowns, out_avals, jaxpr, env = aux()
# Split apart known outputs from the original call and residuals.
out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])
@ -679,7 +690,7 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
_, in_tracers, out_tracer_refs, primitive, params, effects, source_info = recipe
out_tracers = [t_ref() for t_ref in out_tracer_refs]
invars = [getvar(t) for t in in_tracers]
outvars = [core.DropVar(core.abstract_unit) if t is None
outvars = [DropVar(core.abstract_unit) if t is None
else cast(Var, getvar(t)) for t in out_tracers]
return new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info)
@ -1391,18 +1402,14 @@ forwarding_rules: Dict[Primitive, ForwardingRule] = {}
def _inline_literals(jaxpr, constvals):
# This function also ensures variables are labeled in a canonical ordering,
# prunes unused constants, and inserts `dropvar` symbols.
consts = dict(zip(jaxpr.constvars, constvals))
lits = {v: Literal(c, v.aval) for v, c in zip(jaxpr.constvars, constvals)
if type(c) in core.literalable_types and not np.shape(c)}
lit: Callable[[Var], Optional[Literal]] = lits.get
newname: Callable[[AbstractValue], Var] = core.gensym()
newvars: Dict[Var, Var] = {}
newvar = lambda aval: newname(_substitute_vars_in_type(newvars, aval))
newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval))
var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval))
def lit(v: Var) -> Optional[Literal]:
val = consts.get(v)
if type(val) in core.literalable_types and not np.shape(val):
return Literal(val, v.aval)
else:
return None
dropvar = lambda aval: DropVar(_substitute_vars_in_type(lits, newvars, aval))
def vars_in_shape(aval: AbstractValue) -> Sequence[Var]:
if isinstance(aval, DShapedArray):
@ -1420,8 +1427,7 @@ def _inline_literals(jaxpr, constvals):
new_eqns = []
for eqn in jaxpr.eqns:
invars = [lit(v) or var(v) for v in eqn.invars]
outvars = [var(v) if v in used else core.DropVar(v.aval)
for v in eqn.outvars]
outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars]
new_eqns.append(eqn.replace(invars=invars, outvars=outvars))
new_outvars = [lit(v) or var(v) for v in jaxpr.outvars]
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns,
@ -1443,6 +1449,7 @@ class DynamicJaxprTrace(core.Trace):
return tracer
def new_const(self, c):
# TODO(mattjj): for ints, or hashable consts, don't rely on id
tracer = self.frame.constid_to_tracer.get(id(c))
if tracer is None:
aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c))
@ -1515,30 +1522,33 @@ class DynamicJaxprTrace(core.Trace):
self.frame.add_eqn(eqn)
return out_tracers if primitive.multiple_results else out_tracers.pop()
def process_call(self, call_primitive, f, tracers, params):
dim_tracers = _get_tracers_only_in_shapes(tracers)
in_avals = _tracers_to_avals(dim_tracers + tracers)
keep_inputs = [False] * len(dim_tracers) + [True] * len(tracers)
def process_call(self, call_primitive, f, explicit_tracers, params):
if f.in_type is None:
in_avals = [core.raise_to_shaped(get_aval(x)) for x in explicit_tracers]
keep_inputs = [True] * len(explicit_tracers)
im_tracers = []
else:
im_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers)
in_avals, keep_inputs = unzip2(f.in_type)
with core.new_sublevel():
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, in_avals, keep_inputs=keep_inputs)
tracers = [*im_tracers, *explicit_tracers]
if params.get('inline', False):
return core.eval_jaxpr(jaxpr, consts, *dim_tracers, *tracers)
return core.eval_jaxpr(jaxpr, consts, *tracers)
env = {v: t for v, t in zip(jaxpr.constvars, consts) if isinstance(t, Tracer)}
env.update(zip(jaxpr.invars, tracers))
out_avals_ = [_substitute_tracers_in_type(env, a) for a in out_avals]
source_info = source_info_util.current()
env = {v: t for v, t in zip((*jaxpr.constvars, *jaxpr.invars),
(*consts, *dim_tracers, *tracers))
if isinstance(t, Tracer)}
subs = partial(_substitute_tracers_in_type, env)
out_tracers = [DynamicJaxprTracer(self, subs(a), source_info)
for a in out_avals]
invars = map(self.getvar, dim_tracers + tracers)
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals_]
invars = map(self.getvar, tracers)
constvars = map(self.getvar, map(self.instantiate_const, consts))
outvars = map(self.makevar, out_tracers)
new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
update_params = call_param_updaters.get(call_primitive)
if update_params:
new_params = update_params(new_params, [True] * len(tracers),
len(consts) + len(dim_tracers))
new_params = update_params(new_params, [True] * len(explicit_tracers),
len(consts) + len(im_tracers))
eqn = new_jaxpr_eqn([*constvars, *invars], outvars,
call_primitive, new_params,
new_params['call_jaxpr'].effects, source_info)
@ -1765,7 +1775,7 @@ def trace_to_jaxpr_dynamic(fun: lu.WrappedFun,
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
in_avals: Sequence[AbstractValue], *,
keep_inputs: Optional[List[bool]] = None):
keep_inputs: Optional[Sequence[bool]] = None):
# In general, the Tracers passed to the Python callable underlying `fun` may
# correspond to a subset of `in_avals` (i.e. a subset of the input binders in
# the jaxpr). For example:
@ -1793,7 +1803,7 @@ def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
frame = JaxprStackFrame()
with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
trace = DynamicJaxprTrace(main, core.cur_sublevel())
in_tracers = _avals_to_tracers(trace, in_avals)
in_tracers = _input_type_to_tracers(trace, in_avals)
in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
ans = fun.call_wrapped(*in_tracers_)
out_tracers = map(trace.full_raise, ans)
@@ -1817,12 +1827,14 @@ def extend_jaxpr_stack(main, frame):
@profiler.annotate_function
def trace_to_jaxpr_final(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None):
debug_info: Optional[DebugInfo] = None,
keep_inputs: Optional[Sequence[bool]] = None):
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
fun, main, in_avals, keep_inputs=keep_inputs)
del fun, main
return jaxpr, out_avals, consts
@@ -1835,65 +1847,179 @@ def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[Partial
return trace_to_jaxpr(fun, in_pvals)
def _avals_to_tracers(
AbstractedAxisName = Hashable
AbstractedAxesSpec = Union[Dict[int, AbstractedAxisName], Tuple[AbstractedAxisName, ...]]
class DBIdx(NamedTuple):
val: int
@dataclass(frozen=True)
class Bound:
name: AbstractedAxisName
bound: int
InputType = Tuple[Tuple[AbstractValue, bool], ...]
def infer_lambda_input_type(
axes_specs: Optional[Sequence[AbstractedAxesSpec]],
args: Sequence[Any]
) -> InputType:
partial_specs = _canonicalize_specs(map(np.ndim, args), axes_specs)
specs = _complete_specs(args, partial_specs)
idxs, implicit_names = _collect_implicit(args, specs)
implicit_inputs = [(_implicit_arg_type(n), False) for n in implicit_names]
explicit_inputs = [(_arg_type(idxs, x, s), True) for x, s in zip(args, specs)]
return (*implicit_inputs, *explicit_inputs)
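
For orientation: this is presumably the entry point that `jax.jit`'s new `abstracted_axes` path feeds its concrete arguments through to build an input type. A minimal usage sketch along the make_jaxpr-of-jit lines the tests later in this commit use (staging only, since compiling/executing dynamic shapes may not be fully supported yet):

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, abstracted_axes=('n',))
def f(x):
  return jnp.sum(x)

# Staging introduces one implicit i32[] binder for the abstracted axis 'n' plus
# the explicit f32[n] argument whose leading dimension refers back to it.
print(jax.make_jaxpr(f)(jnp.ones(3)))
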
def _canonicalize_specs(
ndims: Sequence[int], specs: Optional[Sequence[AbstractedAxesSpec]]
) -> List[Dict[int, AbstractedAxisName]]:
if specs is None:
return [{}] * len(ndims)
else:
return [{i: d for i, d in enumerate(s) if d is not None} if type(s) is tuple
else s for n, s in zip(ndims, specs)]
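
Both accepted spec forms (per-axis tuple or index-to-name dict) normalize to the dict form used internally. A standalone sketch of that normalization, using hypothetical names rather than the private helper above:

from typing import Dict, Hashable, List, Optional, Sequence, Tuple, Union

Name = Hashable
Spec = Union[Dict[int, Name], Tuple[Optional[Name], ...]]

def canonicalize(specs: Optional[Sequence[Spec]], nargs: int) -> List[Dict[int, Name]]:
  if specs is None:
    return [{} for _ in range(nargs)]   # default: no abstracted axes
  return [s if isinstance(s, dict)
          else {i: name for i, name in enumerate(s) if name is not None}
          for s in specs]

assert canonicalize([('n', None)], 1) == [{0: 'n'}]  # tuple form: one name per axis
assert canonicalize([{0: 'n'}], 1) == [{0: 'n'}]     # dict form passes through
assert canonicalize(None, 2) == [{}, {}]
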
def _complete_specs(
args: Sequence[Any], partial_specs: List[Dict[int, AbstractedAxisName]]
) -> List[Dict[int, AbstractedAxisName]]:
# Identify each user-supplied name in partial_specs with a size.
sizes: Dict[AbstractedAxisName, Union[int, DynamicJaxprTracer]] = {}
for x, spec in zip(args, partial_specs):
for i, name in spec.items():
d = sizes.setdefault(name, x.shape[i])
if d is not x.shape[i] and d != x.shape[i]: raise TypeError
# Introduce new names as needed for Tracers in shapes.
named_tracers: Dict[TracerId, AbstractedAxisName] = {
id(d): name for name, d in sizes.items() if isinstance(d, Tracer)}
specs: List[Dict[int, AbstractedAxisName]] = []
for x, spec in zip(args, partial_specs):
if isinstance(get_aval(x), DShapedArray):
spec = dict(spec)
for i, d in enumerate(x.shape):
if isinstance(d, Tracer):
spec[i] = named_tracers.get(id(d), TracerAsName(d))
specs.append(spec)
assert all(not spec or not any(isinstance(d, Tracer) and i not in spec
for i, d in enumerate(x.shape))
for x, spec in zip(args, specs))
return specs
def _collect_implicit(
args: Sequence[Any], specs: List[Dict[int, AbstractedAxisName]]
) -> Tuple[Dict[AbstractedAxisName, DBIdx], List[AbstractedAxisName]]:
idxs: Dict[AbstractedAxisName, DBIdx] = {}
explicit_tracers: Dict[TracerId, int] = {}
counter = (DBIdx(i) for i in it.count())
# Add implicit arguments to idxs.
for explicit_idx, (x, spec) in enumerate(zip(args, specs)):
for i, name in spec.items():
if name not in idxs and id(x.shape[i]) not in explicit_tracers:
idxs[name] = next(counter)
if isinstance(x, Tracer):
explicit_tracers[id(x)] = explicit_idx
implicit_names: List[AbstractedAxisName] = list(idxs)
# Now that we know the implicit args, add explicit args to idxs.
offset = len(implicit_names)
for x, spec in zip(args, specs):
for i, name in spec.items():
if id(x.shape[i]) in explicit_tracers:
idxs[name] = DBIdx(offset + explicit_tracers[id(x.shape[i])])
return idxs, implicit_names
def _implicit_arg_type(name: AbstractedAxisName) -> AbstractValue:
if type(name) is Bound:
return AbstractBInt(name.bound)
else:
return ShapedArray((), dtypes.dtype('int32'))
def _arg_type(
idxs: Dict[AbstractedAxisName, DBIdx], x: Any,
spec: Dict[int, AbstractedAxisName]
) -> AbstractValue:
aval = get_aval(x) # aval.shape could contain Tracers
if not spec: return core.raise_to_shaped(aval)
shape: List[Union[int, DBIdx]] = [idxs[spec[i]] if i in spec else d
for i, d in enumerate(aval.shape)]
assert not any(isinstance(d, Tracer) for d in shape)
return DShapedArray(tuple(shape), aval.dtype, False)
class TracerAsName:
tracer: DynamicJaxprTracer
def __init__(self, tracer):
trace = core.thread_local_state.trace_state.trace_stack.dynamic
self.tracer = trace.with_cur_sublevel().full_raise(tracer)
def __eq__(self, other):
return isinstance(other, TracerAsName) and self.tracer is other.tracer
def __hash__(self):
return id(self.tracer)
def _extract_implicit_args(
trace: DynamicJaxprTrace, in_type: Sequence[Tuple[AbstractValue, bool]],
explicit_tracers: Sequence[DynamicJaxprTracer]
) -> Sequence[DynamicJaxprTracer]:
# First, construct a list to represent the full argument list, leaving the
# implicit arguments as Nones for now.
explicit_tracers_ = iter(explicit_tracers)
tracers = [next(explicit_tracers_) if expl else None for _, expl in in_type]
assert next(explicit_tracers_, None) is None
del explicit_tracers_
# Next, populate the implicit arguments using DBIdxs in in_type.
for i, (aval, explicit) in enumerate(in_type):
if not explicit or not isinstance(aval, DShapedArray):
continue # can't populate an implicit argument
tracer = tracers[i]
assert tracer is not None
for d1, d2 in zip(aval.shape, tracer.aval.shape):
if isinstance(d1, DBIdx):
if tracers[d1.val] is None:
tracers[d1.val] = trace.instantiate_const(d2)
assert tracers[d1.val] is trace.instantiate_const(d2)
assert all(t is not None for t in tracers)
return [t for t, (_, e) in zip(tracers, in_type) if not e]
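
The DBIdx-driven lookup can be pictured with a toy, self-contained model (names here are illustrative only, not the trace machinery above): each implicit slot is filled by reading the referenced dimension off an explicit argument's concrete shape.

from typing import List, NamedTuple, Optional

class DBIdx(NamedTuple):
  val: int

# (shape, explicit?) pairs: an implicit scalar binder, then an explicit array
# whose leading dimension points back at binder 0.
in_type = (((), False), ((DBIdx(0), 4), True))
explicit_shapes = [(3, 4)]            # concrete shape of the one explicit argument

slots: List[Optional[object]] = [None] * len(in_type)
it = iter(explicit_shapes)
for i, (_, explicit) in enumerate(in_type):
  if explicit:
    slots[i] = next(it)
for i, (shape, explicit) in enumerate(in_type):
  if explicit:
    for dim, size in zip(shape, slots[i]):
      if isinstance(dim, DBIdx):
        slots[dim.val] = size         # fill the implicit slot from the shape
assert [s for s, (_, e) in zip(slots, in_type) if not e] == [3]
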
def _in_avals_from_tracers(
tracers: List[DynamicJaxprTracer]
) -> List[AbstractValue]:
# Returned AbstractValues contain DBIdx indices. Uses Tracer obj id as name.
dbidxs: Dict[TracerId, DBIdx] = {id(t): DBIdx(i) for i, t in enumerate(tracers)}
in_avals: List[AbstractValue] = []
for t in tracers:
a = t.aval
if isinstance(a, DShapedArray) and any(isinstance(d, Tracer) for d in a.shape):
shape = [dbidxs[id(d)] if isinstance(d, Tracer) else d for d in a.shape]
a = a.update(shape=tuple(shape))
in_avals.append(a)
return in_avals
def _input_type_to_tracers(
trace: DynamicJaxprTrace, in_avals: Sequence[AbstractValue]
) -> Sequence[Tracer]:
# Create input Tracers given input AbstractValues, each of which can contain
# other AbstractValues. That is, each element `a` of `in_avals` can have
# abstract values in its shape, which must occur to the left of `a`.
env: Dict[AvalId, Tracer] = {}
# DeBruijn indices which refer to positions in the input argument list. That
# is, each element `a` of `in_avals` can have DBIdx instances in its shape,
# which must refer to positions left of `a`'s.
in_tracers: List[Tracer] = []
for a in in_avals:
t = env[id(a)] = trace.new_arg(_substitute_tracers_in_aval(env, a))
in_tracers.append(t)
return in_tracers
def _substitute_tracers_in_aval(
env: Dict[AvalId, Tracer], a: AbstractValue
) -> AbstractValue:
# Substitute Tracers into a given AbstractValue using the given environment.
# That is, the input is an AbstractValue possibly containing AbstractValues,
# and the output is an AbstractValue possibly containing Tracers.
if (isinstance(a, DShapedArray) and
any(isinstance(d, AbstractValue) for d in a.shape)):
shape = [env[id(d)] if isinstance(d, AbstractValue) else d for d in a.shape]
return a.update(shape=tuple(shape))
return a
def _tracers_to_avals(tracers: Sequence[Tracer]) -> List[AbstractValue]:
# Replace Tracers with corresponding abstract values, handling Tracers in
# shapes and ensuring each Tracer object is mapped to a single AbstractValue.
env: Dict[TracerId, AbstractValue] = {}
avals: List[AbstractValue] = []
for t in tracers:
aval = env.get(id(t))
if aval is None:
aval = env[id(t)] = _substitute_avals_in_aval(env, t.aval)
avals.append(aval)
return avals
def _substitute_avals_in_aval(
env: Dict[TracerId, AbstractValue], a: AbstractValue
) -> AbstractValue:
# Substitute AbstractValues into given AbstractValue using given environment.
# That is, the input is an AbstractValue possibly containing Tracers and the
# output is an AbstractValue possibly containing AbstractValues.
if (isinstance(a, DShapedArray) and
any(isinstance(d, Tracer) for d in a.shape)):
shape = [env.setdefault(id(d), d.aval) if isinstance(d, Tracer) else d
for d in a.shape]
return a.update(shape=tuple(shape))
else:
def _substitute_tracers_in_aval(a: AbstractValue) -> AbstractValue:
if isinstance(a, DShapedArray) and any(type(d) is DBIdx for d in a.shape):
shape = [in_tracers[d.val] if type(d) is DBIdx else d for d in a.shape] # type: ignore
return a.update(shape=tuple(shape))
return a
for a in in_avals:
in_tracers.append(trace.new_arg(_substitute_tracers_in_aval(a)))
return in_tracers
def _substitute_vars_in_type(
env: Dict[Var, Var], a: AbstractValue
consts: Dict[Var, Literal], env: Dict[Var, Var], a: AbstractValue
) -> AbstractValue:
# Substitutes variables into a given AbstractValue using given environment.
# That is, the input is an AbstractValue possibly containing Vars, and the
# output is an aval possibly containing Vars.
if isinstance(a, DShapedArray) and any(isinstance(d, Var) for d in a.shape):
shape = [env[d] if isinstance(d, Var) else d for d in a.shape]
shape = [consts[d].val if d in consts else env[d] # type: ignore
if isinstance(d, Var) else d for d in a.shape]
return a.update(shape=tuple(shape))
else:
return a
@@ -1910,24 +2036,74 @@ def _substitute_tracers_in_type(
else:
return a
def _get_tracers_only_in_shapes(in_tracers: Sequence[Tracer]) -> List[Tracer]:
# In DynamicJaxprTrace.process_call (e.g. for handling jit-of-jit) we want to
# extract Tracers from the shapes of arguments so as to elaborate them into
# explicit inputs to the call that appears in the jaxpr. Some of these Tracers
# may already appear as explicit inputs, so we only need to get those present
# exclusively in shapes.
return _get_tracers_in_shapes({id(t) for t in in_tracers}, in_tracers)
Const = Any
Val = Any
def pad_jaxpr(jaxpr: Jaxpr, consts: Sequence[Const]
) -> Tuple[Jaxpr, List[Const]]:
bounds = {v: v.aval.bound for v in jaxpr.invars
if type(v.aval) is AbstractBInt}
idxs = {v: DBIdx(i) for i, v in enumerate(jaxpr.invars)}
def substitute(aval: AbstractValue) -> AbstractValue:
if isinstance(aval, AbstractBInt):
return ShapedArray((), np.dtype('int32'))
elif isinstance(aval, DShapedArray):
shape = [bounds.get(d, idxs.get(d, d)) for d in aval.shape] # type: ignore
typ = ShapedArray if all(type(d) is int for d in shape) else DShapedArray
return typ(tuple(shape), aval.dtype, aval.weak_type)
else:
return aval
in_avals = [substitute(v.aval) for v in jaxpr.invars]
eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts))
padded_jaxpr, _, padded_consts = trace_to_jaxpr_dynamic(eval_padded, in_avals)
return padded_jaxpr, padded_consts
class BoundedAxisSize(NamedTuple):
val: Union[int, DynamicJaxprTracer]
bound: int
def _eval_jaxpr_padded(
jaxpr: Jaxpr, consts: List[Const], *args: DynamicJaxprTracer
) -> List[Union[Const, DynamicJaxprTracer]]:
env: Dict[Var, Val] = {}
def read(x):
return x.val if type(x) is Literal else env[x]
def write(v, val) -> None:
env[v] = val
write(unitvar, unit)
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
rule = padding_rules[eqn.primitive]
in_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.invars]
out_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.outvars]
outs = rule(in_avals, out_avals, *map(read, eqn.invars), **eqn.params)
map(write, eqn.outvars, outs)
return map(read, jaxpr.outvars)
def _substitute_axis_sizes(env: Dict, aval: AbstractValue) -> AbstractValue:
if isinstance(aval, DShapedArray):
shp = [BoundedAxisSize(env[d], d.aval.bound) if type(d) is Var and
type(d.aval) is AbstractBInt else env.get(d, d) for d in aval.shape]
return DShapedArray(tuple(shp), aval.dtype, aval.weak_type)
else:
return aval
padding_rules: Dict[Primitive, Callable] = {}
def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params):
if call_jaxpr.constvars: raise NotImplementedError
padded_jaxpr, padded_consts = pad_jaxpr(call_jaxpr, ())
if padded_consts: raise NotImplementedError
new_params = dict(params, call_jaxpr=padded_jaxpr)
subfuns, bind_params = prim.get_bind_params(new_params)
return prim.bind(*subfuns, *args, **bind_params)
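
A primitive opts into padded evaluation by registering an entry in `padding_rules` with the `(in_avals, out_avals, *args, **params)` calling convention that `_eval_jaxpr_padded` above uses. A hedged sketch for a hypothetical elementwise primitive (the primitive itself is a placeholder; only the registry and signature shown above are assumed):

from jax import core
from jax.interpreters import partial_eval as pe

my_prim = core.Primitive('my_elementwise_prim')   # placeholder primitive

def _my_prim_padding_rule(in_avals, out_avals, *args, **params):
  # An elementwise op commutes with padding: applying it to the padded operands
  # already yields correctly padded outputs, so the bounded axis sizes carried
  # in the avals can be ignored here.
  return [my_prim.bind(*args, **params)]

pe.padding_rules[my_prim] = _my_prim_padding_rule
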
def _get_tracers_in_shapes(seen: Set[TracerId], in_tracers: Sequence[Tracer]
) -> List[Tracer]:
dim_tracers: List[Tracer] = []
for t in in_tracers:
if isinstance(t.aval, core.DShapedArray):
for d in t.aval.shape:
if isinstance(d, Tracer) and id(d) not in seen:
seen.add(id(d))
dim_tracers.append(d)
return dim_tracers
# TODO(mattjj): the following are deprecated; update callers to _nounits version
# See https://github.com/google/jax/pull/9498

View File

@@ -674,8 +674,7 @@ def jaxpr_collectives(jaxpr):
for eqn in jaxpr.eqns:
if eqn.primitive in _collective_primitives:
yield eqn.primitive
for subjaxpr in core.subjaxprs(jaxpr):
yield from jaxpr_collectives(subjaxpr)
for subjaxpr in core.subjaxprs(jaxpr): yield from jaxpr_collectives(subjaxpr)
### xla_call underlying jit
@@ -833,6 +832,8 @@ pe.partial_eval_jaxpr_custom_rules[xla_call_p] = \
_xla_call_partial_eval_custom_params_updater)
pe.dce_rules[xla_call_p] = pe.dce_jaxpr_call_rule
pe.padding_rules[xla_call_p] = partial(pe.call_padding_rule, xla_call_p)
def _pp_xla_call(eqn: core.JaxprEqn, context: core.JaxprPpContext,
settings: core.JaxprPpSettings,

View File

@@ -146,6 +146,7 @@ from jax._src.lax.lax import (
log_p as log_p,
lt as lt,
lt_p as lt_p,
make_bint as make_bint,
max as max,
max_p as max_p,
min as min,

View File

@@ -61,10 +61,11 @@ compare as equal only if they compute the same function. The static and the
dynamic positional arguments for the generators, and also the auxiliary output
data must be immutable, because it will be stored in function memoization tables.
"""
from __future__ import annotations
import threading
from functools import partial
from typing import Any, Tuple, Callable
from typing import Any, Tuple, Callable, Optional
import weakref
from jax import core
@@ -81,10 +82,10 @@ traceback_util.register_exclusion(__file__)
class StoreException(Exception): pass
class EmptyStoreValue(object): pass
class EmptyStoreValue: pass
_EMPTY_STORE_VALUE = EmptyStoreValue()
class Store(object):
class Store:
"""Storage for a value, with checks for overwriting or reading empty store."""
__slots__ = ("_val",)
@@ -112,27 +113,28 @@ class Store(object):
__bool__ = __nonzero__
class WrappedFun(object):
class WrappedFun:
"""Represents a function `f` to which `transforms` are to be applied.
Args:
f: the function to be transformed.
transforms: a list of `(gen, gen_static_args)` tuples representing
transformations to apply to `f.` Here `gen` is a generator function
and `gen_static_args` is a tuple of static arguments for the generator. See
transformations to apply to `f.` Here `gen` is a generator function and
`gen_static_args` is a tuple of static arguments for the generator. See
description at the start of this module for the expected behavior of the
generator.
stores: a list of out_store for the auxiliary output of the `transforms`.
params: extra parameters to pass as keyword arguments to `f`, along with the
transformed keyword arguments.
"""
__slots__ = ("f", "transforms", "stores", "params")
__slots__ = ("f", "transforms", "stores", "params", "in_type")
def __init__(self, f, transforms, stores, params):
def __init__(self, f, transforms, stores, params, in_type):
self.f = f
self.transforms = transforms
self.stores = stores
self.params = params
self.in_type = in_type
@property
def __name__(self):
@@ -141,7 +143,7 @@ class WrappedFun(object):
def wrap(self, gen, gen_static_args, out_store) -> 'WrappedFun':
"""Add another transform and its store."""
return WrappedFun(self.f, ((gen, gen_static_args),) + self.transforms,
(out_store,) + self.stores, self.params)
(out_store,) + self.stores, self.params, None)
def populate_stores(self, stores):
"""Copy the values from the `stores` into `self.stores`."""
@@ -200,11 +202,11 @@ class WrappedFun(object):
return "Wrapped function:\n" + '\n'.join(transformation_stack) + '\nCore: ' + fun_name(self.f) + '\n'
def __hash__(self):
return hash((self.f, self.transforms, self.params))
return hash((self.f, self.transforms, self.params, self.in_type))
def __eq__(self, other):
return (self.f == other.f and self.transforms == other.transforms and
self.params == other.params)
self.params == other.params and self.in_type == other.in_type)
@curry
def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
@@ -231,8 +233,19 @@ def fun_name(f):
def wrap_init(f, params=None) -> WrappedFun:
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
return WrappedFun(f, (), (),
() if params is None else tuple(sorted(params.items())))
params = () if params is None else tuple(sorted(params.items()))
return WrappedFun(f, (), (), params, None)
def annotate(f: WrappedFun,
in_type: Optional[Tuple[Tuple[core.AbstractValue, bool], ...]]
) -> WrappedFun:
assert f.in_type is None
if in_type is None:
return f
assert (type(in_type) is tuple and all(type(e) is tuple for e in in_type) and
all(isinstance(a, core.AbstractValue) and type(b) is bool
for a, b in in_type))
return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type)
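
A minimal sketch of how a caller might attach an input type to a wrapped function, reusing the aval constructions exercised by the dynamic-shape tests later in this commit (internal API, subject to change):

import numpy as np
from jax import core, linear_util as lu
from jax.interpreters import partial_eval as pe

f = lu.wrap_init(lambda n, x: x)
n_aval = core.ShapedArray((), np.dtype('int32'))
x_aval = core.DShapedArray((pe.DBIdx(0),), np.dtype('float32'), weak_type=False)
# (aval, explicit?) pairs: the axis-size binder is implicit, the array explicit.
f = lu.annotate(f, ((n_aval, False), (x_aval, True)))
assert f.in_type is not None
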
class _CacheLocalContext(threading.local):
@@ -259,11 +272,11 @@ def cache(call: Callable):
def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if config.jax_check_tracer_leaks:
key = (_copy_main_traces(fun.transforms), fun.params, args,
key = (_copy_main_traces(fun.transforms), fun.params, fun.in_type, args,
config.x64_enabled, config._trace_context())
else:
key = (fun.transforms, fun.params, args, config.x64_enabled,
config._trace_context())
key = (fun.transforms, fun.params, fun.in_type, args,
config.x64_enabled, config._trace_context())
result = cache.get(key, None)
if result is not None:
ans, stores = result

View File

@@ -7972,6 +7972,177 @@ class DynamicShapeTest(jtu.JaxTestCase):
self.assertIs(c, c_)
self.assertIs(d, d_)
def test_jit_abstracted_axes_staging(self):
# We just test make_jaxpr-of-jit because dynamic shape compilation/execution
# may not be supported.
@partial(jax.jit, abstracted_axes=('n',))
def f(x):
return jnp.sum(x)
jaxpr = jax.make_jaxpr(f)(jnp.ones(3))
# { lambda ; a:f32[3]. let
# b:f32[] = xla_call[
# call_jaxpr={ lambda ; c:i32[] d:f32[c]. let
# e:f32[] = reduce_sum[axes=(0,)] d
# in (e,) }
# name=f
# ] 3 a
# in (b,) }
a, = jaxpr.jaxpr.invars
e, = jaxpr.jaxpr.eqns
self.assertLen(e.invars, 2)
self.assertIsInstance(e.invars[0], core.Literal)
self.assertIs(e.invars[1], a)
b, = e.outvars
self.assertLen(b.aval.shape, 0)
subjaxpr = e.params['call_jaxpr']
c, d = subjaxpr.invars
self.assertLen(c.aval.shape, 0)
self.assertLen(d.aval.shape, 1)
self.assertIs(d.aval.shape[0], c)
def test_jit_abstracted_axes_staging2(self):
@partial(jax.jit, abstracted_axes=('n',))
def fun(x):
return jnp.sum(x)
jaxpr = jax.make_jaxpr(lambda n: fun(jnp.ones(n + n)))(3)
# { lambda ; a:i32[]. let
# b:i32[] = add a a
# c:f32[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 b
# d:f32[] = xla_call[
# call_jaxpr={ lambda ; e:i32[] f:f32[e]. let
# g:f32[] = reduce_sum[axes=(0,)] f
# in (g,) }
# name=f
# ] b c
# in (d,) }
a, = jaxpr.jaxpr.invars
e1, e2, e3 = jaxpr.jaxpr.eqns
b, = e1.outvars
c, = e2.outvars
b_, c_ = e3.invars
self.assertIs(b, b_)
self.assertIs(c, c_)
subjaxpr = e3.params['call_jaxpr']
e, f = subjaxpr.invars
self.assertLen(e.aval.shape, 0)
self.assertLen(f.aval.shape, 1)
self.assertIs(f.aval.shape[0], e)
def test_jit_abstracted_axes_staging3(self):
f = jax.jit(jnp.sum, abstracted_axes=('n',))
jaxpr = jax.make_jaxpr(f, abstracted_axes=('n',))(jnp.arange(3.))
# { lambda ; a:i32[] b:f32[a]. let
# c:f32[] = xla_call[
# call_jaxpr={ lambda ; d:i32[] e:f32[d]. let
# f:f32[] = reduce_sum[axes=(0,)] e
# in (f,) }
# name=sum
# ] a b
# in (c,) }
a, b = jaxpr.jaxpr.invars
e, = jaxpr.jaxpr.eqns
self.assertIs(e.invars[0], a)
self.assertIs(e.invars[1], b)
c, = e.outvars
self.assertLen(c.aval.shape, 0)
subjaxpr = e.params['call_jaxpr']
d, e = subjaxpr.invars
self.assertLen(d.aval.shape, 0)
self.assertLen(e.aval.shape, 1)
self.assertIs(e.aval.shape[0], d)
def test_jit_abstracted_axes_return_polymorphic_shape(self):
f = jax.jit(lambda x: x, abstracted_axes=('n',))
jaxpr = jax.make_jaxpr(f)(jnp.arange(3)) # doesn't crash
# { lambda ; a:i32[3]. let
# b:i32[3] = xla_call[
# call_jaxpr={ lambda ; c:i32[] d:i32[c]. let in (d,) }
# name=<lambda>
# ] 3 a
# in (b,) }
a, = jaxpr.jaxpr.invars
e, = jaxpr.jaxpr.eqns
three, a_ = e.invars
b, = e.outvars
self.assertIsInstance(three, core.Literal)
self.assertEqual(three.val, 3)
self.assertIs(a_, a)
self.assertLen(b.aval.shape, 1)
self.assertEqual(b.aval.shape[0], 3)
def test_jit_abstracted_axes_return_polymorphic_shape2(self):
f = jax.jit(lambda n: jnp.ones(n))
# TODO(mattjj,dougalm): support dynamic shapes in type checker
with jax.enable_checks(False):
jaxpr = jax.make_jaxpr(f)(3)
# { lambda ; a:i32[]. let
# b:f32[a] = xla_call[
# call_jaxpr={ lambda ; c:i32[]. let
# d:f32[c] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0
# c
# in (d,) }
# name=<lambda>
# ] a
# in (b,) }
a, = jaxpr.jaxpr.invars
e, = jaxpr.jaxpr.eqns
a_, = e.invars
self.assertIs(a, a_)
b, = e.outvars
a__, = b.aval.shape
self.assertIs(a, a__)
with jax.enable_checks(False):
jaxpr = jax.make_jaxpr(lambda: f(3))()
# { lambda ; . let
# a:f32[3] = xla_call[
# call_jaxpr={ lambda ; b:i32[]. let
# c:f32[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0
# b
# in (c,) }
# name=<lambda>
# ] 3
# in (a,) }
() = jaxpr.jaxpr.invars
e, = jaxpr.jaxpr.eqns
three, = e.invars
self.assertIsInstance(three, core.Literal)
self.assertEqual(three.val, 3)
b, = e.outvars
three_, = b.aval.shape
self.assertIsInstance(three_, int)
self.assertEqual(three_, 3)
def test_jit_basic_iree(self):
    if jtu.device_under_test() != 'iree':
raise unittest.SkipTest("test only works on IREE")
@jax.jit
def f(i):
return jnp.sum(jnp.ones(i, dtype='float32'))
self.assertAllClose(f(3), jnp.array(3., dtype='float32'), check_dtypes=True)
def test_slicing_basic(self):
f = jax.jit(lambda x, n: jnp.sum(x[:n]))
ans = f(jnp.arange(10), 3)
expected = jnp.sum(jnp.arange(10)[:3])
self.assertAllClose(ans, expected, check_dtypes=True)
def test_scan_basic(self):
def cumsum(x):
def body(i, _):
return i + 1, jnp.sum(x[:i+1])
_, ans = lax.scan(body, 0, None, length=len(x))
return ans
x = jnp.array([3, 1, 4, 1, 5, 9])
with jax.enable_checks(False):
ans = cumsum(x)
expected = jnp.cumsum(x)
self.assertAllClose(ans, expected, check_dtypes=False)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@@ -30,7 +30,8 @@ from jax import numpy as jnp
from jax import linear_util as lu
from jax import jvp, linearize, vjp, jit, make_jaxpr
from jax.core import UnshapedArray, ShapedArray
from jax.tree_util import tree_flatten, tree_unflatten, tree_map, tree_reduce, tree_leaves
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_reduce,
tree_leaves)
from jax.interpreters import partial_eval as pe
from jax._src import test_util as jtu
@@ -529,15 +530,15 @@ class DynamicShapesTest(jtu.JaxTestCase):
def test_staging_basic(self):
n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
a = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)
@lu.wrap_init
def f(x, y):
return x, y
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
keep_inputs=[False, True, True])
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
f, [n, a, b], keep_inputs=[False, True, True])
self.assertLen(jaxpr.invars, 3)
self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape)
@@ -549,8 +550,8 @@ class DynamicShapesTest(jtu.JaxTestCase):
def test_staging_nested(self):
n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
a = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)
@lu.wrap_init
def f(x, y):
@@ -559,8 +560,8 @@ class DynamicShapesTest(jtu.JaxTestCase):
return (x, w)
return g(x, y, x, y)
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
keep_inputs=[False, True, True])
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
f, [n, a, b], keep_inputs=[False, True, True])
self.assertLen(jaxpr.invars, 1 + 2) # one axis size var, two other inputs
self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape)
@@ -583,10 +584,9 @@ class DynamicShapesTest(jtu.JaxTestCase):
self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[4].aval.shape)
def test_staging_nested_including_shape_arg(self):
# This test covers the _get_tracers_only_in_shapes logic in partial_eval.py.
n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
a = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)
@lu.wrap_init
def f(x, y):
@@ -595,8 +595,18 @@ class DynamicShapesTest(jtu.JaxTestCase):
return (x, w)
return g(x.shape[0], x, y, x, y)
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
keep_inputs=[False, True, True])
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
f, [n, a, b], keep_inputs=[False, True, True])
print(jaxpr)
# { lambda ; a:i32[] b:f32[a] c:f32[a]. let
# d:f32[a] e:f32[a] = xla_call[
# call_jaxpr={ lambda ; f:i32[] g:i32[] h:f32[f] i:f32[f] j:f32[f] k:f32[f]. let
#
# in (h, k) }
# name=g
# ] a a b c b c
# in (d, e) }
self.assertLen(jaxpr.eqns, 1)
eqn = jaxpr.eqns[0]
@@ -612,8 +622,8 @@ class DynamicShapesTest(jtu.JaxTestCase):
def test_staging_primitive_applications(self):
n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
a = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)
@lu.wrap_init
def f(x, y):
@@ -622,8 +632,8 @@ class DynamicShapesTest(jtu.JaxTestCase):
u = lax_internal._reduce_sum(w, [0])
return (u,)
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
keep_inputs=[False, True, True])
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
f, [n, a, b], keep_inputs=[False, True, True])
self.assertLen(jaxpr.invars, 1 + 2) # one axis size var, two other inputs
self.assertLen(jaxpr.eqns, 3)