Mirror of https://github.com/ROCm/jax.git, synced 2025-04-26 08:26:07 +00:00.
Commit: "Move remaining functions relating to building XLA HLO IR out of xla_bridge.py and into jax.interpreters.xla." (PiperOrigin-RevId: 413244450)
File stats: 1565 lines, 52 KiB, Python.
# Copyright 2021 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from functools import partial, reduce
import itertools as it
import operator as op
from typing import (Tuple, List, Sequence, Set, Dict, Any, Callable, Union,
                    Optional)

import numpy as np

from jax import core
from jax._src import dispatch
from jax._src import source_info_util
from jax.core import Var, Literal, Atom, Tracer
from jax._src import util
from jax._src.util import (safe_zip, safe_map, curry, unzip2, split_list,
                           tuple_delete)
import jax._src.pretty_printer as pp
|
|
|
|
map = safe_map
|
|
zip = safe_zip
|
|
def identity(x): return x
|
|
|
|
DType = Any
|
|
NDArray = Any
|
|
|
|
|
|
# Dynamic shape jaxprs
|
|
|
|
## Element types
|
|
|
|
class EltTy: pass
|
|
|
|
class BaseType(EltTy):
|
|
def __init__(self, dtype: DType):
|
|
self._dtype = np.dtype(dtype)
|
|
|
|
def __repr__(self):
|
|
return f'BaseType({self._dtype.name})'
|
|
|
|
def __hash__(self):
|
|
return hash(self._dtype)
|
|
|
|
def __eq__(self, other):
|
|
return isinstance(other, BaseType) and self._dtype == other._dtype
|
|
|
|
class BoundedIntTy(EltTy):
|
|
def __init__(self, bound: int):
|
|
assert isinstance(bound, int)
|
|
self._bound = bound
|
|
|
|
def __repr__(self):
|
|
return f'BIntTy{{≤{self._bound}}}'
|
|
|
|
def __eq__(self, other):
|
|
return isinstance(other, BoundedIntTy) and self._bound == other._bound
|
|
|
|
|
|
## Array types
|
|
|
|
class AbsArray(core.AbstractValue):
|
|
def __init__(self, shape, eltTy):
|
|
assert isinstance(shape, tuple)
|
|
assert isinstance(eltTy, EltTy)
|
|
self.shape = shape
|
|
self._eltTy = eltTy
|
|
|
|
def str_short(self, short_dtypes=False):
|
|
del short_dtypes # ignored
|
|
shape = f'[{",".join(str(d) for d in self.shape)}]' if self.shape else ''
|
|
if isinstance(self._eltTy, BoundedIntTy):
|
|
return f'BInt{{≤{self._eltTy._bound}}}{shape}'
|
|
elif isinstance(self._eltTy, BaseType):
|
|
dtype = self._eltTy._dtype.name
|
|
return f'{dtype}{shape}'
|
|
else:
|
|
return repr(self)
|
|
|
|
def __eq__(self, other):
|
|
if (isinstance(other, AbsArray) and self._eltTy == other._eltTy and
|
|
len(self.shape) == len(other.shape)):
|
|
for a, b in zip(self.shape, other.shape):
|
|
if type(a) is type(b) is int:
|
|
if a != b: return False
|
|
elif type(a) is type(b) is BoundedInt:
|
|
if a is not b: return False
|
|
elif type(a) is type(b) is Var:
|
|
if a is not b: return False
|
|
elif type(a) is type(b) is AbsArray:
|
|
if a != b: return False
|
|
elif type(a) is type(b) is DimIndexingExpr:
|
|
if a.name is not b.name or a.indices != b.indices: return False
|
|
else:
|
|
return False
|
|
else:
|
|
return True
|
|
return False
|
|
|
|
# this duck-typing is needed by eg ad.py using dtypes.py
|
|
@property
|
|
def dtype(self):
|
|
if isinstance(self._eltTy, BaseType):
|
|
return self._eltTy._dtype
|
|
else:
|
|
raise Exception
|
|
|
|
def at_least_vspace(self):
|
|
return AbsArray(self.shape, self._eltTy)
|
|
|
|
def join(self, other):
|
|
if self == other:
|
|
return self
|
|
raise NotImplementedError # TODO
|
|
|
|
class DimIndexingExpr:
|
|
def __init__(self, name, indices):
|
|
assert isinstance(name, (Var, Tracer))
|
|
assert (isinstance(indices, tuple) and
|
|
all(isinstance(i, int) for i in indices))
|
|
self.name = name
|
|
self.indices = indices
|
|
|
|
def __repr__(self):
|
|
indices = '.'.join(map(str, self.indices))
|
|
return f'{self.name}.{indices}'
|
|
|
|
|
|
## DJaxprs
|
|
|
|
class DJaxprTy:
|
|
in_dim_binders: List[Var]
|
|
in_types: List[core.AbstractValue]
|
|
out_dim_binders: List[Var]
|
|
out_types: List[core.AbstractValue]
|
|
|
|
def __init__(self, in_dim_binders, in_types, out_dim_binders, out_types):
|
|
self.in_dim_binders = in_dim_binders
|
|
self.in_types = in_types
|
|
self.out_dim_binders = out_dim_binders
|
|
self.out_types = out_types
|
|
|
|
def __repr__(self):
|
|
in_dim_binders = pp_vars(self.in_dim_binders)
|
|
in_types = ', '.join(aval.str_short() for aval in self.in_types)
|
|
out_dim_binders = pp_vars(self.out_dim_binders)
|
|
out_types = ', '.join(aval.str_short() for aval in self.out_types)
|
|
return f'[{in_dim_binders}] [{in_types}] -> [{out_dim_binders}] [{out_types}]'
|
|
|
|
class DJaxpr:
|
|
in_dim_binders: List[Var]
|
|
in_binders: List[Var]
|
|
out_dims: List[Atom]
|
|
outs: List[Atom]
|
|
eqns: List[core.JaxprEqn] # reusing existing eqns, helps reuse some tracing
|
|
|
|
def __init__(self, in_dim_binders, in_binders, out_dims, outs, eqns):
|
|
assert all(isinstance(v, Var) and isinstance(v.aval, AbsArray) and
|
|
isinstance(v.aval._eltTy, BoundedIntTy) for v in in_dim_binders)
|
|
assert all(isinstance(v, Var) for v in in_binders)
|
|
assert all(isinstance(x, (Var, Literal)) and isinstance(x.aval, AbsArray) and
|
|
isinstance(x.aval._eltTy, BoundedIntTy) for x in out_dims)
|
|
assert all(isinstance(x, (Var, Literal)) for x in outs)
|
|
assert all(isinstance(e, core.JaxprEqn) for e in eqns)
|
|
self.in_dim_binders = in_dim_binders
|
|
self.in_binders = in_binders
|
|
self.out_dims = out_dims
|
|
self.outs = outs
|
|
self.eqns = eqns
|
|
|
|
def __repr__(self):
|
|
return str(pp_djaxpr(self))
|
|
|
|
def pp_djaxpr(jaxpr: DJaxpr) -> pp.Doc:
|
|
eqns = map(pp_eqn, jaxpr.eqns)
|
|
in_dim_binders = pp_vars(jaxpr.in_dim_binders)
|
|
in_binders = pp_vars(jaxpr.in_binders)
|
|
out_dims = ', '.join(map(str, jaxpr.out_dims))
|
|
outs = ', '.join(map(str, jaxpr.outs))
|
|
out_dim_types = pp_vars(jaxpr.out_dims)
|
|
outs_type = ', '.join(v.aval.str_short() for v in jaxpr.outs)
|
|
return (pp.text(f'{{ lambda {in_dim_binders} ; {in_binders} .')
|
|
+ (pp.text('let ') + pp.nest(2, pp.brk() + pp.join(pp.brk(), eqns)) +
|
|
pp.text(f'in ( {out_dims} ; {outs} ) '
|
|
f': ( {out_dim_types} ; {outs_type} ) }}')))
|
|
|
|
def pp_vars(vs: Sequence[Atom]) -> str:
|
|
return ', '.join(f'{v}:{v.aval.str_short()}' for v in vs)
|
|
|
|
def pp_eqn(eqn: core.JaxprEqn) -> pp.Doc:
|
|
lhs = pp_vars(eqn.outvars)
|
|
pp_lhs = pp.text(f'{lhs} =')
|
|
pp_rhs = (pp.text(eqn.primitive.name) +
|
|
core.pp_kv_pairs(sorted(eqn.params.items()), core.JaxprPpContext())
|
|
+ pp.text(' ') + pp.text(' '.join(map(str, eqn.invars))))
|
|
return pp_lhs + pp.text(' ') + pp_rhs
|
|
|
|
# Typechecking DJaxprs
|
|
|
|
def typecheck_jaxpr(jaxpr: DJaxpr):
|
|
env: Set[Var] = set() # bound variables
|
|
|
|
for v in jaxpr.in_dim_binders:
|
|
if not (isinstance(v.aval, AbsArray) and
|
|
isinstance(v.aval._eltTy, BoundedIntTy)): raise TypeError
|
|
typecheck_type(env, v.aval)
|
|
env.add(v)
|
|
|
|
for v in jaxpr.in_binders:
|
|
typecheck_type(env, v.aval)
|
|
for v in jaxpr.in_binders:
|
|
env.add(v)
|
|
|
|
for eqn in jaxpr.eqns:
|
|
for x in eqn.invars:
|
|
typecheck_atom(env, x)
|
|
rule = typecheck_rules[eqn.primitive]
|
|
out_types = rule(*eqn.invars, **eqn.params)
|
|
subst: Dict[Var, Var] = {}
|
|
for v, t in zip(eqn.outvars, out_types):
|
|
if isinstance(t, Var):
|
|
aval = substitute(subst, t.aval)
|
|
if v.aval != aval: raise TypeError(f'{v.aval} != {aval}')
|
|
subst[t] = v
|
|
elif isinstance(t, core.AbstractValue):
|
|
aval = substitute(subst, t)
|
|
if v.aval.strip_weak_type() != aval:
|
|
raise TypeError(f'{v.aval} != {aval}')
|
|
else:
|
|
assert False # typecheck rule produced unexpected type
|
|
typecheck_type(env, v.aval)
|
|
env.add(v)
|
|
|
|
in_types = [v.aval for v in jaxpr.in_binders]
|
|
out_types = []
|
|
for x in jaxpr.outs:
|
|
aval = typecheck_atom(env, x)
|
|
out_types.append(aval)
|
|
|
|
return DJaxprTy(jaxpr.in_dim_binders, in_types, jaxpr.out_dims, out_types)
|
|
|
|
def typecheck_type(env, aval):
|
|
if isinstance(aval, (core.AbstractUnit, core.ShapedArray)):
|
|
return aval # all syntactic forms are valid
|
|
elif isinstance(aval, AbsArray):
|
|
for i, d in enumerate(aval.shape):
|
|
if isinstance(d, int):
|
|
continue
|
|
elif isinstance(d, Var):
|
|
if d not in env: raise TypeError('unbound dim size')
|
|
if not (isinstance(d.aval, AbsArray) and not d.aval.shape and
|
|
isinstance(d.aval._eltTy, BoundedIntTy)):
|
|
raise TypeError(f'dim var of unexpected type: {d.aval}')
|
|
elif isinstance(d, DimIndexingExpr):
|
|
if d.name not in env: raise TypeError('unbound dim size')
|
|
if not (isinstance(d.name.aval, AbsArray) and
|
|
isinstance(d.name.aval._eltTy, BoundedIntTy)):
|
|
raise TypeError(f'dim var of unexpected type: {d.name.aval}')
|
|
d_indices_set = set(d.indices)
|
|
if i in d_indices_set:
|
|
raise TypeError(f"circular dim indexing expression: {d}")
|
|
for j in d.indices:
|
|
d_j = aval.shape[j]
|
|
if (isinstance(d_j, DimIndexingExpr) and
|
|
not d_indices_set.issuperset(d_j.indices)):
|
|
raise TypeError(f"dim indexing not transitively closed: {d}")
|
|
expected_idx_array_shape = tuple(aval.shape[j] for j in d.indices)
|
|
if d.name.aval.shape != expected_idx_array_shape:
|
|
raise TypeError(f'incompatible shapes in dim indexing: {aval}')
|
|
else:
|
|
raise TypeError(f'unexpected type in shape: {type(d)}')
|
|
return aval
|
|
else:
|
|
raise TypeError(f'unknown type: {aval}')
|
|
|
|
def typecheck_atom(env, x):
|
|
if isinstance(x, Literal):
|
|
return core.raise_to_shaped(core.get_aval(x.val))
|
|
elif isinstance(x, Var):
|
|
return typecheck_type(env, x.aval)
|
|
else:
|
|
raise TypeError(f'atom of unexpected type {x}')
|
|
|
|
def substitute(subst, aval):
|
|
if isinstance(aval, AbsArray):
|
|
new_shape = []
|
|
for d in aval.shape:
|
|
if isinstance(d, Var):
|
|
new_d = subst.get(d, d)
|
|
elif isinstance(d, DimIndexingExpr):
|
|
new_d = DimIndexingExpr(subst.get(d.name, d.name), d.indices)
|
|
else:
|
|
new_d = d
|
|
new_shape.append(new_d)
|
|
return AbsArray(tuple(new_shape), aval._eltTy)
|
|
else:
|
|
return aval
|
|
|
|
typecheck_rules: Dict[core.Primitive, Callable] = {}
|
|
|
|
|
|
# Interpreting DJaxprs
|
|
|
|
def eval_jaxpr(jaxpr, dim_args, args):
|
|
env: Dict[Var, Any] = {}
|
|
|
|
def read(v):
|
|
if type(v) is core.Literal:
|
|
return v.val
|
|
else:
|
|
return env[v]
|
|
|
|
def write(v, val):
|
|
env[v] = val
|
|
|
|
write(core.unitvar, core.unit)
|
|
map(write, jaxpr.in_dim_binders, dim_args)
|
|
map(write, jaxpr.in_binders, args)
|
|
for eqn in jaxpr.eqns:
|
|
in_vals = map(read, eqn.invars)
|
|
ans = eqn.primitive.bind(*in_vals, **eqn.params)
|
|
if eqn.primitive.multiple_results:
|
|
map(write, eqn.outvars, ans)
|
|
elif len(eqn.outvars) > 1:
|
|
# TODO a jaxpr unpacks dependent tuples, while Python packages them up
|
|
map(write, eqn.outvars, eqn.primitive.unpack_result(ans))
|
|
else:
|
|
write(eqn.outvars[0], ans)
|
|
return map(read, jaxpr.out_dims), map(read, jaxpr.outs)
|
|
|
|
@curry
|
|
def jaxpr_as_fun(jaxpr, *args):
|
|
shapevars_to_vals: Dict[Var, Any] = dict(
|
|
(d, t) for v, x in zip(jaxpr.in_binders, args) if isinstance(v.aval, AbsArray)
|
|
for d, t in zip(v.aval.shape, x.shape) if isinstance(d, Var)
|
|
and x is not core.unit) # TODO partial eval assumes we can plug in units?
|
|
dim_args = [shapevars_to_vals[v] for v in jaxpr.in_dim_binders]
|
|
_, out = eval_jaxpr(jaxpr, dim_args, args)
|
|
return out
|
|
|
|
|
|
# Data representations
|
|
|
|
class BoundedInt:
|
|
val: Union[int, Tracer]
|
|
bound: int
|
|
|
|
def __init__(self, val: Union[int, Tracer], bound: int):
|
|
self._val = val
|
|
self._bound = bound
|
|
|
|
def __repr__(self):
|
|
return f'{self._val}{{≤{self._bound}}}'
|
|
|
|
def __eq__(self, other):
|
|
if isinstance(other, BoundedInt) and self._bound == other._bound:
|
|
return self._val is other._val or self._val == other._val
|
|
elif isinstance(other, int):
|
|
return self._val == other
|
|
else:
|
|
raise Exception
|
|
|
|
class DimIndexer:
|
|
data: NDArray
|
|
indices: Tuple[int, ...]
|
|
|
|
def __init__(self, data, indices):
|
|
self._data = data
|
|
self._indices = indices
|
|
|
|
def __repr__(self):
|
|
indices = '.'.join(map(str, self._indices))
|
|
data = f'{self._data._data}'
|
|
return f'{data}.{indices}'
|
|
|
|
# We want these to duck-type ndarrays when the element type is BaseType.
|
|
class Array:
|
|
def __init__(self,
|
|
shape: Tuple[Union[int, BoundedInt, DimIndexer], ...],
|
|
eltTy: EltTy,
|
|
data: NDArray):
|
|
self.shape = shape
|
|
self._eltTy = eltTy
|
|
self._data = data
|
|
|
|
@property
|
|
def dtype(self):
|
|
if isinstance(self._eltTy, BaseType):
|
|
return self._eltTy._dtype
|
|
else:
|
|
raise Exception
|
|
|
|
def __repr__(self):
|
|
dtypestr = (self._eltTy._dtype.name if isinstance(self._eltTy, BaseType)
|
|
else f'BInt{{≤{self._eltTy._bound}}}') # type: ignore
|
|
shapestr = ','.join(map(str, self.shape))
|
|
if any(isinstance(d, DimIndexer) for d in self.shape):
|
|
# find the last DimIndexer, as we'll treat chunks below that as
|
|
# rectangular
|
|
last = next(i for i, d in reversed(list(enumerate(self.shape)))
|
|
if isinstance(d, DimIndexer))
|
|
shape_prefix = tuple(d._val if type(d) is BoundedInt else d
|
|
for d in self.shape[:last])
|
|
outs = []
|
|
for idx in it.product(*map(range, shape_prefix)):
|
|
slices = [slice(d._data._data[tuple(idx[i] for i in d._indices)])
|
|
if isinstance(d, DimIndexer) else
|
|
slice(d._val) if isinstance(d, BoundedInt) else
|
|
slice(None) for d in self.shape[last:]]
|
|
full_index = (*idx, *slices)
|
|
data = self._data[full_index]
|
|
outs.append(f'{idx}:\n{data}')
|
|
return f'{dtypestr}[{shapestr}] with values:\n' + '\n\n'.join(outs)
|
|
else:
|
|
slices = tuple(slice(d._val) if type(d) is BoundedInt else slice(None)
|
|
for d in self.shape)
|
|
data = self._data[slices]
|
|
return f'{dtypestr}[{shapestr}] with value:\n{data}'
|
|
|
|
def __array__(self):
|
|
if any(isinstance(d, DimIndexer) for d in self.shape):
|
|
raise NotImplementedError # ragged ndarray
|
|
else:
|
|
slices = tuple(slice(d._val) if type(d) is BoundedInt else slice(None)
|
|
for d in self.shape)
|
|
return np.array(self._data[slices])
|
|
|
|
|
|
# Tracing to embed DJaxprs in Python
|
|
|
|
from jax import linear_util as lu
|
|
from jax.interpreters import partial_eval as pe
|
|
|
|
from jax._src.api_util import flatten_fun
|
|
from jax.tree_util import tree_flatten, tree_unflatten
|
|
|
|
def make_djaxpr(fun, *args, **kwargs):
|
|
args, in_tree = tree_flatten((args, kwargs))
|
|
f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
|
|
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
|
|
return trace_to_jaxpr_dynamic(f, in_avals)
|
|
|
|
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]):
|
|
with core.new_main(DJaxprTrace, dynamic=True) as main:
|
|
main.jaxpr_stack = () # type: ignore
|
|
outs = trace_to_subjaxpr_dynamic(fun, main, in_avals)
|
|
del main
|
|
return outs
|
|
|
|
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
|
|
in_avals: Sequence[core.AbstractValue]):
|
|
frame = DJaxprStackFrame()
|
|
with pe.extend_jaxpr_stack(main, frame):
|
|
trace = DJaxprTrace(main, core.cur_sublevel())
|
|
in_dim_tracers, in_avals = _place_in_dim_tracers_in_shapes(trace, in_avals)
|
|
in_tracers = map(trace.new_arg, in_avals)
|
|
ans = fun.call_wrapped(*in_tracers)
|
|
out_tracers = map(trace.full_raise, ans)
|
|
out_dim_tracers = _extract_out_dim_tracers_from_shapes(main, in_dim_tracers, out_tracers)
|
|
return frame.to_jaxpr(in_dim_tracers, in_tracers, out_dim_tracers, out_tracers)
|
|
|
|
def _place_in_dim_tracers_in_shapes(trace, in_avals):
|
|
dim_tracers = {}
|
|
new_in_avals = []
|
|
for aval in in_avals:
|
|
if not isinstance(aval, AbsArray):
|
|
new_in_avals.append(aval)
|
|
else:
|
|
new_shape = []
|
|
for d in aval.shape:
|
|
if isinstance(d, AbsArray):
|
|
assert d.shape == () and isinstance(d._eltTy, BoundedIntTy)
|
|
dim_tracer = dim_tracers.get(id(d))
|
|
if dim_tracer is None:
|
|
dim_tracer = dim_tracers[id(d)] = trace.new_arg(d)
|
|
new_shape.append(dim_tracer)
|
|
elif isinstance(d, (int, BoundedInt)):
|
|
new_shape.append(d)
|
|
else:
|
|
raise NotImplementedError(d) # TODO
|
|
new_aval = AbsArray(tuple(new_shape), aval._eltTy)
|
|
new_in_avals.append(new_aval)
|
|
return list(dim_tracers.values()), new_in_avals
|
|
|
|
def _extract_out_dim_tracers_from_shapes(main, in_dim_tracers, out_tracers):
|
|
seen = {id(d) for d in in_dim_tracers}
|
|
def take(d):
|
|
if isinstance(d, Tracer):
|
|
return d._trace.main is main and id(d) not in seen and not seen.add(id(d))
|
|
elif isinstance(d, DimIndexingExpr):
|
|
return take(d.name)
|
|
else:
|
|
return False
|
|
return [d.name if isinstance(d, DimIndexingExpr) else d
|
|
for t in out_tracers if isinstance(t.aval, AbsArray)
|
|
for d in t.aval.shape if take(d)]
|
|
|
|
class DJaxprTrace(pe.DynamicJaxprTrace):
|
|
def process_primitive(self, primitive, tracers, params):
|
|
rule = custom_staging_rules.get(primitive)
|
|
if rule:
|
|
return rule(self, tracers, params)
|
|
else:
|
|
# If there's no special staging rule, by default do regular Jaxpr staging
|
|
return super().process_primitive(primitive, tracers, params)
|
|
|
|
def get_const(self, tracer):
|
|
assert isinstance(tracer, Tracer)
|
|
return self.frame.constvar_to_val.get(self.frame.tracer_to_var.get(id(tracer)))
|
|
|
|
def new_const(self, val):
|
|
if isinstance(val, BoundedInt):
|
|
raise NotImplementedError # TODO
|
|
elif isinstance(val, Array) and val.shape:
|
|
raise NotImplementedError # TODO
|
|
else:
|
|
return super().new_const(val)
|
|
|
|
custom_staging_rules: Dict[core.Primitive, Callable] = {}
|
|
|
|
class DJaxprStackFrame(pe.JaxprStackFrame):
|
|
def to_jaxpr(self, in_dim_tracers, in_tracers, out_dim_tracers, out_tracers):
|
|
t2v = lambda t: self.tracer_to_var[id(t)]
|
|
in_dim_binders, in_binders = map(t2v, in_dim_tracers), map(t2v, in_tracers)
|
|
out_dims, outs = map(t2v, out_dim_tracers), map(t2v, out_tracers)
|
|
|
|
# only include constants that are used
|
|
used_vars = ({a for eqn in self.eqns for a in eqn.invars if isinstance(a, Var)} |
|
|
{a for grp in [out_dims, outs] for a in grp if isinstance(a, Var)})
|
|
constvars, constvals = unzip2(
|
|
(v, c) for v, c in self.constvar_to_val.items() if v in used_vars)
|
|
in_binders = [*constvars, *in_binders]
|
|
|
|
# promote some lambda binders to pi binders
|
|
used_shape_vars = ({d for eqn in self.eqns for v in eqn.outvars
|
|
if isinstance(v.aval, AbsArray)
|
|
for d in v.aval.shape if isinstance(d, Var)} |
|
|
{d.name for eqn in self.eqns for v in eqn.outvars
|
|
if isinstance(v.aval, AbsArray)
|
|
for d in v.aval.shape if isinstance(d, DimIndexingExpr)})
|
|
lambda_binders = [v not in used_shape_vars for v in in_binders]
|
|
converted_binders, in_binders = partition_list(lambda_binders, in_binders)
|
|
in_dim_binders = in_dim_binders + converted_binders
|
|
out_dims = [v for v in out_dims if v not in in_dim_binders] # TODO
|
|
|
|
jaxpr = DJaxpr(in_dim_binders, in_binders, out_dims, outs, self.eqns)
|
|
typecheck_jaxpr(jaxpr)
|
|
return jaxpr, constvals, lambda_binders
|
|
|
|
def newvar(self, aval):
|
|
if isinstance(aval, AbsArray) and aval.shape:
|
|
# replace any tracers in the shape with their corresponding variables
|
|
shape = []
|
|
for d in aval.shape:
|
|
if isinstance(d, Tracer):
|
|
shape.append(self.tracer_to_var[id(d)])
|
|
elif isinstance(d, DimIndexingExpr):
|
|
assert isinstance(d.name, Tracer)
|
|
shape.append(DimIndexingExpr(self.tracer_to_var[id(d.name)], d.indices))
|
|
else:
|
|
shape.append(d)
|
|
aval = AbsArray(tuple(shape), aval._eltTy)
|
|
return self.gensym(aval)
|
|
|
|
def partition_list(bs, lst):
|
|
lists = lst1, lst2 = [], []
|
|
for b, x in zip(bs, lst):
|
|
lists[b].append(x)
|
|
return lst1, lst2
|
|
|
|
def _raise_absarray_to_type_level(aval: AbsArray, weak_type: bool):
|
|
assert isinstance(aval, AbsArray)
|
|
unique_avals: Dict[int, AbsArray] = {}
|
|
shape = []
|
|
for d in aval.shape:
|
|
if isinstance(d, BoundedInt):
|
|
shape.append(unique_avals.setdefault(id(d), AbsArray((), BoundedIntTy(d._bound))))
|
|
elif isinstance(d, DimIndexer):
|
|
raise NotImplementedError # TODO
|
|
else:
|
|
shape.append(d)
|
|
return AbsArray(tuple(shape), aval._eltTy)
|
|
core.raise_to_shaped_mappings[AbsArray] = _raise_absarray_to_type_level
|
|
|
|
def _abstractify_array_for_ad(x: Array): # TODO misleading name, used in djit
|
|
return AbsArray(x.shape, x._eltTy)
|
|
core.pytype_aval_mappings[Array] = _abstractify_array_for_ad
|
|
|
|
def _abstractify_bdint(x: BoundedInt):
|
|
return AbsArray((), BoundedIntTy(x._bound))
|
|
core.pytype_aval_mappings[BoundedInt] = _abstractify_bdint
|
|
|
|
|
|
# XLA lowering
|
|
|
|
from jax.interpreters import xla
|
|
from jax._src.lib import xla_bridge as xb
|
|
from jax._src.lib import xla_client as xc
|
|
xe = xc._xla
|
|
xops = xc._xla.ops
|
|
|
|
def _abstractify_array_to_type_level(x: Array):
|
|
return core.raise_to_shaped(core.get_aval(x))
|
|
xla.pytype_aval_mappings[Array] = _abstractify_array_to_type_level
|
|
|
|
def _array_xla_shape(aval: AbsArray):
|
|
if isinstance(aval._eltTy, BaseType):
|
|
dtype = aval._eltTy._dtype
|
|
shape = [d._eltTy._bound if isinstance(d, AbsArray) and not d.shape
|
|
else d for d in aval.shape]
|
|
return (xla.xc.Shape.array_shape(dtype, shape),)
|
|
elif isinstance(aval._eltTy, BoundedIntTy):
|
|
shape = [d._bound if isinstance(d, BoundedInt) else d for d in aval.shape]
|
|
return (xla.xc.Shape.array_shape(np.dtype('int32'), shape),)
|
|
else:
|
|
raise NotImplementedError
|
|
xla.xla_shape_handlers[AbsArray] = _array_xla_shape
|
|
xla.canonicalize_dtype_handlers[Array] = identity
|
|
|
|
def _array_device_put(x, device):
|
|
return dispatch._device_put_array(x._data, device)
|
|
dispatch.device_put_handlers[Array] = _array_device_put
|
|
|
|
def _bdint_device_put(x, device):
|
|
return dispatch._device_put_scalar(x._val, device)
|
|
dispatch.device_put_handlers[BoundedInt] = _bdint_device_put
|
|
|
|
def _bdint_canoncalize_dtype(x):
|
|
return BoundedInt(xla.canonicalize_dtype(x._val), x._bound)
|
|
xla.canonicalize_dtype_handlers[BoundedInt] = _bdint_canoncalize_dtype
|
|
|
|
def _make_params(c, dim_in_avals, in_avals):
|
|
n = it.count()
|
|
make = lambda a: [xla.parameter(c, next(n), s) for s in xla.aval_to_xla_shapes(a)]
|
|
return map(make, dim_in_avals), map(make, in_avals)
|
|
|
|
def _xla_consts(c, consts):
|
|
unique_consts = {id(const): const for const in consts}
|
|
xla_consts = {
|
|
id_: [xla.pyval_to_ir_constant(c, const)]
|
|
for id_, const in unique_consts.items()}
|
|
return [xla_consts[id(const)] for const in consts]
|
|
|
|
def djaxpr_subcomp(c, jaxpr, dim_args, args):
|
|
env: Dict[Var, Sequence[xe.XlaOp]] = {}
|
|
|
|
def aval(v):
|
|
return xla.abstractify(v.val) if type(v) is core.Literal else v.aval
|
|
|
|
def read(v):
|
|
if type(v) is core.Literal:
|
|
return [xla.pyval_to_ir_constant(c, xla.canonicalize_dtype(v.val))]
|
|
else:
|
|
return env[v]
|
|
|
|
def write(v, nodes):
|
|
env[v] = nodes
|
|
|
|
write(core.unitvar, xla._make_unit_constant(c))
|
|
map(write, jaxpr.in_dim_binders, dim_args)
|
|
map(write, jaxpr.in_binders, args)
|
|
for eqn in jaxpr.eqns:
|
|
in_vals, in_avals = map(read, eqn.invars), map(aval, eqn.invars)
|
|
in_dims = {v:read(v) for a in in_avals if isinstance(a, AbsArray)
|
|
for v in a.shape if isinstance(v, Var)}
|
|
rule = translations[eqn.primitive]
|
|
out_vals = rule(c, in_dims, in_avals, in_vals, **eqn.params)
|
|
map(write, eqn.outvars, out_vals)
|
|
return map(read, jaxpr.out_dims), map(read, jaxpr.outs)
|
|
|
|
def execute_compiled(compiled, partitioner, handlers, dim_vals, args):
|
|
input_bufs = list(it.chain(
|
|
(buf for x in dim_vals for buf in dispatch.device_put(x, None)),
|
|
(buf for x in args for buf in dispatch.device_put(x, None))))
|
|
out_bufs = compiled.execute(input_bufs)
|
|
dims_dict, grouped_bufs = partitioner(out_bufs)
|
|
return [handler(dims_dict, bs) for handler, bs in zip(handlers, grouped_bufs)]
|
|
|
|
def result_partitioner(in_dim_binders, in_dim_vals, out_dims, out_bufcounts):
|
|
out_dimvars = [v for v in out_dims if isinstance(v, Var)]
|
|
split_sizes = [len(out_dimvars)] + out_bufcounts[:-1]
|
|
|
|
def dim_handler(v, buf):
|
|
if not v.aval.shape:
|
|
return BoundedInt(int(buf.to_py()), v.aval._eltTy._bound)
|
|
else:
|
|
return Array(v.aval.shape, v.aval._eltTy, buf.to_py())
|
|
|
|
def partitioner(bufs):
|
|
dim_bufs, *grouped_bufs = split_list(bufs, split_sizes)
|
|
dims_dict = dict(it.chain(
|
|
zip(in_dim_binders, in_dim_vals),
|
|
zip(out_dimvars, map(dim_handler, out_dimvars, dim_bufs))))
|
|
return dims_dict, grouped_bufs
|
|
return partitioner
|
|
|
|
def result_handler(aval):
|
|
if isinstance(aval, AbsArray):
|
|
return array_result_handler(aval)
|
|
else:
|
|
handler = dispatch.aval_to_result_handler(None, aval)
|
|
return lambda _, bufs: handler(*bufs)
|
|
|
|
def array_result_handler(aval):
|
|
if not isinstance(aval._eltTy, BaseType): raise NotImplementedError
|
|
padded_shape = []
|
|
for d in aval.shape:
|
|
if isinstance(d, int):
|
|
padded_shape.append(d)
|
|
elif isinstance(d, Var):
|
|
padded_shape.append(d.aval._eltTy._bound)
|
|
elif isinstance(d, DimIndexingExpr):
|
|
padded_shape.append(d.name.aval._eltTy._bound)
|
|
else:
|
|
raise NotImplementedError # TODO
|
|
padded_aval = core.ShapedArray(tuple(padded_shape), aval._eltTy._dtype)
|
|
array_handler = dispatch.array_result_handler(None, padded_aval)
|
|
def handler(dims_dict, bufs):
|
|
shape = tuple(dims_dict[d] if isinstance(d, Var) else
|
|
DimIndexer(dims_dict[d.name], d.indices) if isinstance(d, DimIndexingExpr) else
|
|
d for d in aval.shape)
|
|
return Array(shape, aval._eltTy, array_handler(*bufs))
|
|
return handler
|
|
|
|
def aval_to_num_buffers(aval):
|
|
if isinstance(aval, AbsArray):
|
|
return 1
|
|
else:
|
|
return len(xla.aval_to_xla_shapes(aval))
|
|
|
|
translations: Dict[core.Primitive, Callable] = {}
|
|
|
|
|
|
dynamic_xla_call_p = core.Primitive('dxla_call')
|
|
dynamic_xla_call_p.multiple_results = True
|
|
|
|
@dynamic_xla_call_p.def_impl
|
|
def _dynamic_xla_call_impl(*args, jaxpr, num_consts):
|
|
in_dim_vals, consts, args = split_list(args, [len(jaxpr.in_dim_binders), num_consts])
|
|
dim_in_avals = [v.aval for v in jaxpr.in_dim_binders]
|
|
c = xc.XlaBuilder("dxla_call")
|
|
dim_params, params = _make_params(c, dim_in_avals, map(xla.abstractify, args))
|
|
const_params = _xla_consts(c, consts)
|
|
dim_outs, outs = djaxpr_subcomp(c, jaxpr, dim_params, const_params + params)
|
|
out = xops.Tuple(c, [o for ops in dim_outs + outs for o in ops])
|
|
compiled = xb.get_backend(None).compile(c.build(out))
|
|
result_handlers = map(result_handler, [v.aval for v in jaxpr.outs])
|
|
out_bufcounts = [aval_to_num_buffers(v.aval) for v in jaxpr.outs]
|
|
partitioner = result_partitioner(jaxpr.in_dim_binders, in_dim_vals,
|
|
jaxpr.out_dims, out_bufcounts)
|
|
return execute_compiled(compiled, partitioner, result_handlers,
|
|
in_dim_vals, args)
|
|
|
|
def djit(fun):
|
|
def f_jitted(*args, **kwargs):
|
|
args, in_tree = tree_flatten((args, kwargs))
|
|
f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
|
|
# TODO we shouldn't dedup avals one array at a time; need to do it for the
|
|
# full argument list!
|
|
# unique_avals: Dict[int, core.AbstractValue] = {}
|
|
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
|
|
jaxpr, consts, unconverted_binders = trace_to_jaxpr_dynamic(f, in_avals)
|
|
num_consts = len(consts)
|
|
args = [*consts, *args]
|
|
dim_vals, args = _extract_dim_vals(jaxpr.in_dim_binders, jaxpr.in_binders,
|
|
unconverted_binders, args)
|
|
out_flat = dynamic_xla_call_p.bind(*dim_vals, *args, jaxpr=jaxpr,
|
|
num_consts=num_consts)
|
|
return tree_unflatten(out_tree(), out_flat)
|
|
return f_jitted
|
|
|
|
def _extract_dim_vals(in_dim_binders, in_binders, unconverted_binders, args):
|
|
converted_in_dim_vals, args = partition_list(unconverted_binders, args)
|
|
sizes = {var: size for binder, arg in zip(in_binders, args)
|
|
for var, size in zip(binder.aval.shape, np.shape(arg))
|
|
if isinstance(var, Var)}
|
|
num_binders = len(in_dim_binders) - len(converted_in_dim_vals)
|
|
in_dim_vals = [sizes[v] for v in in_dim_binders[:num_binders]] + converted_in_dim_vals
|
|
return in_dim_vals, args
|
|
|
|
|
|
def traceable_to_padded_translation(traceable):
|
|
def translation(c, dims, avals, operands, **params):
|
|
dim_avals = [core.ShapedArray((), np.int32) for _ in dims]
|
|
padded_avals = map(_replace_vars_with_bounds, avals)
|
|
|
|
@lu.wrap_init
|
|
def fun(*args):
|
|
dim_sizes, args = split_list(args, [len(dims)])
|
|
logical_sizes = dict(zip(dims, dim_sizes))
|
|
logical_shapes = [tuple([logical_sizes.get(d, d) for d in aval.shape])
|
|
for aval in avals] # TODO more cases
|
|
return traceable(logical_shapes, *args, **params)
|
|
|
|
in_avals = [*dim_avals, *padded_avals]
|
|
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
|
|
|
|
operands_ = it.chain.from_iterable([*dims.values(), *operands])
|
|
platform = "cpu" # TODO: don't hardwire in the CPU translation.
|
|
ctx = xla.TranslationContext(c, platform, xla.AxisEnv(1, (), ()), '')
|
|
outs = xla.jaxpr_subcomp(ctx, jaxpr, xla._xla_consts(c, consts), *operands_)
|
|
return util.unflatten(outs,
|
|
[aval_to_num_buffers(aval) for aval in out_avals])
|
|
return translation
|
|
|
|
def _replace_vars_with_bounds(aval):
|
|
if not isinstance(aval, AbsArray):
|
|
return aval
|
|
else:
|
|
new_shape = []
|
|
for d in aval.shape:
|
|
if isinstance(d, Var):
|
|
assert d.aval.shape == () and isinstance(d.aval._eltTy, BoundedIntTy)
|
|
new_shape.append(d.aval._eltTy._bound)
|
|
elif isinstance(d, int):
|
|
new_shape.append(d)
|
|
elif isinstance(d, BoundedInt):
|
|
new_shape.append(d._bound)
|
|
else:
|
|
raise NotImplementedError(d)
|
|
return core.ShapedArray(tuple(new_shape), aval._eltTy._dtype)
|
|
|
|
# AD
|
|
|
|
from jax.interpreters import ad
|
|
|
|
def _dynamic_xla_call_jvp(primals, tangents, *, jaxpr, num_consts):
|
|
del num_consts
|
|
in_dim_vals, primals = split_list(primals, [len(jaxpr.in_dim_binders)])
|
|
_, tangents = split_list(tangents, [len(jaxpr.in_dim_binders)])
|
|
new_jaxpr, consts = jvp_jaxpr(jaxpr)
|
|
outs = dynamic_xla_call_p.bind(*in_dim_vals, *consts, *primals, *tangents,
|
|
jaxpr=new_jaxpr, num_consts=len(consts))
|
|
primals_out, tangents_out = split_list(outs, [len(outs) // 2])
|
|
return primals_out, tangents_out
|
|
ad.primitive_jvps[dynamic_xla_call_p] = _dynamic_xla_call_jvp
|
|
|
|
def _dynamic_xla_call_transpose(cts_in, *args, jaxpr, num_consts):
|
|
# TODO make this a dynamic_xla_call_p bind
|
|
del num_consts
|
|
vars_to_vals = dict(
|
|
(d, t) for v, x in zip(jaxpr.in_binders, args)
|
|
if isinstance(v.aval, AbsArray) and not ad.is_undefined_primal(x)
|
|
for d, t in zip(v.aval.shape, x.shape) if isinstance(d, Var))
|
|
dim_args = [vars_to_vals[v] for v in jaxpr.in_dim_binders]
|
|
consts_bar, args_bar = backward_pass(jaxpr, dim_args, args, cts_in) # type: ignore
|
|
return [*consts_bar, *args_bar]
|
|
ad.primitive_transposes[dynamic_xla_call_p] = _dynamic_xla_call_transpose
|
|
|
|
def backward_pass(jaxpr, dim_args, args, cts_in):
  """Transpose-evaluate `jaxpr`: propagate cotangents `cts_in` backwards.

  NOTE: unfinished — sets up the primal/cotangent environments and then
  raises NotImplementedError before walking the equations in reverse.
  """
  primal_env = {}  # Var -> known primal value
  ct_env = {}      # Var -> accumulated cotangent

  def write_cotangent(v, ct):
    # Sum contributions when a variable receives more than one cotangent.
    ct_env[v] = ad.add_tangents(ct_env[v], ct) if v in ct_env else ct

  def read_cotangent(v):
    # Missing entries read as a symbolic zero of the right aval.
    return ct_env.get(v, ad.Zero(v.aval))

  def read_primal(v):
    if type(v) is core.Literal:
      raise NotImplementedError  # TODO
    else:
      # Unknown primals become UndefinedPrimal markers for transpose rules.
      return primal_env.get(v, ad.UndefinedPrimal(v.aval))

  def write_primal(v, val):
    # Only record primals that are actually known.
    if not ad.is_undefined_primal(val):
      primal_env[v] = val

  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.in_dim_binders, dim_args)
  map(write_primal, jaxpr.in_binders, args)

  map(write_cotangent, jaxpr.outs, cts_in)
  raise NotImplementedError  # TODO finish this
|
|
|
|
def jvp_jaxpr(jaxpr):
  """Build a jaxpr computing the JVP of `jaxpr` (inputs: primals ++ tangents)."""
  fun = lu.wrap_init(jaxpr_as_fun(jaxpr))
  dim_env = {v: v.aval for v in jaxpr.in_dim_binders}
  # Replace dimension variables in the input avals with their abstract values.
  primal_avals = [_replace_vars_with_avals(dim_env, v.aval)
                  for v in jaxpr.in_binders]
  new_jaxpr, consts, _ = trace_to_jaxpr_dynamic(jvp_traceable(ad.jvp(fun)),
                                                primal_avals * 2)
  return new_jaxpr, consts
|
|
|
|
def _replace_vars_with_avals(dimvars, aval):
  """Swap dimension Vars appearing in `aval.shape` for their avals."""
  if not isinstance(aval, AbsArray):
    return aval
  resolved = tuple(dimvars.get(d, d) for d in aval.shape)
  return AbsArray(resolved, aval._eltTy)
|
|
|
|
@lu.transformation
def jvp_traceable(*primals_and_tangents):
  """Adapt ad.jvp's ((primals, tangents) -> pair) API to a flat-args API."""
  half = len(primals_and_tangents) // 2
  primals, tangents = split_list(primals_and_tangents, [half])
  primals_out, tangents_out = yield (primals, tangents), {}
  yield (*primals_out, *tangents_out)
|
|
|
|
def _dynamic_xla_call_pe(trace, *tracers, jaxpr, num_consts):
  """Partial-eval rule for dynamic_xla_call_p.

  Splits `jaxpr` into a known part (executed now) and an unknown part
  (staged out as a residual equation on `trace`). All dimension arguments
  must be known.
  """
  in_dim_tracers, tracers = split_list(tracers, [len(jaxpr.in_dim_binders)])
  if any(not t.pval.is_known() for t in in_dim_tracers):
    raise NotImplementedError
  in_unknowns = [not t.pval.is_known() for t in tracers]
  jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)

  # Execute the known part now; it also produces `num_res` residuals at the
  # tail of its outputs.
  known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
  known_vals = [t.pval.get_known() for t in known_tracers]
  in_dim_vals = [t.pval.get_known() for t in in_dim_tracers]
  outs1_res = dynamic_xla_call_p.bind(*in_dim_vals, *known_vals, jaxpr=jaxpr1,
                                      num_consts=num_consts)
  outs1, res = split_list(outs1_res, [len(jaxpr1.outs) - num_res])

  # Stage out the unknown part as a single recipe shared by all its outputs.
  in_dim_tracers = map(trace.new_instantiated_const, in_dim_tracers)
  res_tracers = map(trace.new_instantiated_const, res)
  outs2 = [pe.JaxprTracer(trace, pe.PartialVal.unknown(v.aval), None)
           for v in jaxpr2.outs]
  eqn = pe.new_eqn_recipe(in_dim_tracers + res_tracers + unknown_tracers, outs2,
                          dynamic_xla_call_p, dict(jaxpr=jaxpr2, num_consts=0),
                          source_info_util.new_source_info())
  for t in outs2: t.recipe = eqn
  # Interleave known and unknown outputs back into the caller's order.
  outs1, outs2 = iter(outs1), iter(outs2)
  return [next(outs2) if uk else next(outs1) for uk in out_unknowns]
pe.custom_partial_eval_rules[dynamic_xla_call_p] = _dynamic_xla_call_pe
|
|
|
|
def partial_eval_jaxpr(jaxpr, in_unknowns):
  """Split a DJaxpr into known (`jaxpr1`) and unknown (`jaxpr2`) halves.

  An equation is unknown iff any of its inputs is unknown. Known values that
  flow into unknown equations become residuals: extra outputs of `jaxpr1`
  and extra inputs of `jaxpr2`.

  Returns (jaxpr1, jaxpr2, out_unknowns, num_res).
  """
  env: Dict[Var, bool] = {}  # Var -> is-unknown
  res = []                   # residual variables, in discovery order

  def read(v):
    if type(v) is core.Literal:
      raise NotImplementedError  # TODO
    else:
      return env[v]

  def write(unk, v):
    env[v] = unk

  def new_res(v):
    # Record a known variable needed by the unknown half.
    res.append(v)
    return v

  eqns1, eqns2 = [], []
  map(write, in_unknowns, jaxpr.in_binders)
  for eqn in jaxpr.eqns:
    unks = map(read, eqn.invars)
    if any(unks):
      # Unknown equation: known inputs become residuals.
      invars = [v if unk else new_res(v) for unk, v in zip(unks, eqn.invars)]
      eqns2.append(pe.new_jaxpr_eqn(invars, eqn.outvars, eqn.primitive,
                                    eqn.params,
                                    source_info_util.new_source_info()))
      map(partial(write, True), eqn.outvars)
    else:
      eqns1.append(eqn)
      map(partial(write, False), eqn.outvars)
  out_unknowns = map(read, jaxpr.outs)
  out_dim_unknowns = map(read, jaxpr.out_dims)  # when linearizing, all known

  invars1, invars2 = partition_list(in_unknowns, jaxpr.in_binders)
  outvars1, outvars2 = partition_list(out_unknowns, jaxpr.outs)
  out_dims1, out_dims2 = partition_list(out_dim_unknowns, jaxpr.out_dims)

  # Residuals are appended to jaxpr1's outputs and prepended to jaxpr2's
  # inputs, matching the bind order in _dynamic_xla_call_pe.
  outvars1 = outvars1 + res
  invars2 = res + invars2

  # TODO forward the correct residuals here (all dimvars used in types)
  in_dimvars2 = out_dims1 + jaxpr.in_dim_binders

  jaxpr1 = DJaxpr(jaxpr.in_dim_binders, invars1, out_dims1, outvars1, eqns1)
  jaxpr2 = DJaxpr(in_dimvars2, invars2, out_dims2, outvars2, eqns2)

  return jaxpr1, jaxpr2, out_unknowns, len(res)
|
|
|
|
|
|
# batching
|
|
|
|
from jax.interpreters import batching
|
|
|
|
def _dynamic_xla_call_vmap(args, in_dims, *, jaxpr, num_consts):
  """Batching rule for dynamic_xla_call_p (dimension args must be unbatched)."""
  del num_consts
  n_dims = len(jaxpr.in_dim_binders)
  dim_vals, arg_vals = split_list(args, [n_dims])
  dim_bdims, val_bdims = split_list(in_dims, [n_dims])
  assert all(bd is batching.not_mapped for bd in dim_bdims)
  # All mapped arguments must agree on the batch axis size.
  sizes = {x.shape[bd] for x, bd in zip(arg_vals, val_bdims)
           if bd is not batching.not_mapped}
  axis_size, = sizes
  batched_jaxpr, consts, out_bdims = batch_jaxpr(jaxpr, axis_size, val_bdims)
  outs = dynamic_xla_call_p.bind(*dim_vals, *consts, *arg_vals,
                                 jaxpr=batched_jaxpr, num_consts=len(consts))
  return outs, out_bdims
batching.primitive_batchers[dynamic_xla_call_p] = _dynamic_xla_call_vmap
|
|
|
|
def batch_jaxpr(jaxpr, axis_size, in_dims):
  """Build a batched version of `jaxpr` with batch axis size `axis_size`."""
  dim_env = {v: v.aval for v in jaxpr.in_dim_binders}
  avals = [_replace_vars_with_avals(dim_env, v.aval) for v in jaxpr.in_binders]

  # Prepend the batch axis to each mapped input's aval.
  batched_avals = []
  for d, aval in zip(in_dims, avals):
    if d is batching.not_mapped:
      batched_avals.append(aval)
    else:
      batched_avals.append(
          core.unmapped_aval(axis_size, core.no_axis_name, d, aval))

  fun, out_dims = batching.batch_subtrace(lu.wrap_init(jaxpr_as_fun(jaxpr)))
  batched_fun = _batch_fun(fun, in_dims)
  new_jaxpr, consts, _ = trace_to_jaxpr_dynamic(batched_fun, batched_avals)
  return new_jaxpr, consts, out_dims()
|
|
|
|
@lu.transformation
def _batch_fun(in_dims, *in_vals, **params):
  """Run the wrapped function under a fresh BatchTrace main trace."""
  with core.new_main(batching.BatchTrace, axis_name=core.no_axis_name) as main:
    outs = yield (main, in_dims, *in_vals), params
    del main  # drop the reference so the main can be finalized cleanly
  yield outs
|
|
|
|
def _map_array(size: int, axis: int, aval: AbsArray) -> AbsArray:
  """Map an AbsArray aval over `axis`: drop that axis from the shape."""
  del size
  return AbsArray(tuple_delete(aval.shape, axis), aval._eltTy)


def _unmap_array(size: int, axis: int, aval: AbsArray) -> AbsArray:
  """Inverse of _map_array; not yet supported for AbsArray."""
  raise NotImplementedError

core.aval_mapping_handlers[AbsArray] = _map_array, _unmap_array
|
|
|
|
|
|
# Primitives
|
|
|
|
import numpy as np
|
|
from jax._src.lax import lax
|
|
|
|
|
|
## sin
|
|
|
|
def sin(x: Any) -> Any:
  """Elementwise sine supporting dynamic-shape (AbsArray) inputs."""
  return sin_p.bind(x)
sin_p = core.Primitive('sin_p')

@sin_p.def_abstract_eval
def _sin_abstract_eval(x):
  # Shape and element type are preserved; defer to lax for static shapes.
  if not isinstance(x, AbsArray):
    return lax.sin_p.abstract_eval(x)
  return AbsArray(x.shape, x._eltTy)

def _sin_typecheck_rule(invar):
  # Output type equals input type.
  return [invar.aval]
typecheck_rules[sin_p] = _sin_typecheck_rule

def _sin_translation_rule(c, dims, avals, operands):
  (operand,), = operands
  return [[xops.Sin(operand)]]
translations[sin_p] = _sin_translation_rule

ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
|
|
|
|
|
|
## cos
|
|
|
|
def cos(x: Any) -> Any:
  """Elementwise cosine supporting dynamic-shape (AbsArray) inputs."""
  return cos_p.bind(x)
cos_p = core.Primitive('cos_p')

@cos_p.def_abstract_eval
def _cos_abstract_eval(x):
  # Shape and element type are preserved; defer to lax for static shapes.
  if not isinstance(x, AbsArray):
    return lax.cos_p.abstract_eval(x)
  return AbsArray(x.shape, x._eltTy)

def _cos_typecheck_rule(invar):
  # Output type equals input type.
  return [invar.aval]
typecheck_rules[cos_p] = _cos_typecheck_rule

def _cos_translation_rule(c, dims, avals, operands):
  (operand,), = operands
  return [[xops.Cos(operand)]]
translations[cos_p] = _cos_translation_rule
|
|
|
|
|
|
## reduce-sum
|
|
|
|
def reduce_sum(x: Any, axes: Optional[Sequence[int]] = None) -> Any:
  """Sum `x` over `axes`; sums over all axes when `axes` is None."""
  if axes is None:
    axes = tuple(range(len(x.shape)))
  return reduce_sum_p.bind(x, axes=axes)
reduce_sum_p = core.Primitive('reduce_sum')

@reduce_sum_p.def_abstract_eval
def _sum_abstract_eval(operand, *, axes):
  """Abstract eval: drop summed axes; collapse to ShapedArray when static."""
  if isinstance(operand, AbsArray):
    axes = set(axes)
    new_shape = [d for i, d in enumerate(operand.shape) if i not in axes]
    # If no dynamic dimensions remain and the element type is a plain dtype,
    # the result is an ordinary static-shape array.
    if (all(isinstance(d, int) for d in new_shape) and
        isinstance(operand._eltTy, BaseType)):
      return core.ShapedArray(tuple(new_shape), operand._eltTy._dtype)
    else:
      return AbsArray(tuple(new_shape), operand._eltTy)
  else:
    # Fixed: was `lax.reduce_sum_p.reduce_sum_abstract_eval(...)`, which is
    # not an attribute of the primitive. Every other rule in this file defers
    # via `<prim>.abstract_eval` (see sin/cos/lt/add/mul above).
    return lax.reduce_sum_p.abstract_eval(operand, axes=axes)

def _reduce_sum_typecheck_rule(x, *, axes):
  return [reduce_sum_p.abstract_eval(x.aval, axes=axes)]
typecheck_rules[reduce_sum_p] = _reduce_sum_typecheck_rule

def _reduce_sum_translation_traceable(logical_shapes, x, *, axes):
  # Zero out padding so it does not contribute to the sum.
  shape, = logical_shapes
  x = _replace_masked_values(shape, x, 0, axes=axes)
  return [lax._reduce_sum(x, axes=axes)]
translations[reduce_sum_p] = traceable_to_padded_translation(
    _reduce_sum_translation_traceable)

def _replace_masked_values(logical_shape, x, val, axes=None):
  """Replace padded (out-of-logical-bounds) entries of `x` with `val`.

  Only axes in `axes` (default: all) with a known dynamic logical size
  (non-None entries of `logical_shape`) are masked.
  """
  axes = axes or set(range(len(logical_shape)))
  masks = [lax.broadcasted_iota(np.int32, x.shape, i) < d
           for i, d in enumerate(logical_shape) if d is not None and i in axes]
  if masks:
    x = lax.select(reduce(op.and_, masks), x, lax.full_like(x, val))
  return x

def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
  raise NotImplementedError  # TODO
ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
|
|
|
|
|
|
### lt
|
|
|
|
def lt(x, y):
  """Elementwise less-than comparison supporting dynamic shapes."""
  return lt_p.bind(x, y)
lt_p = core.Primitive('lt')

@lt_p.def_abstract_eval
def _lt_abstract_eval(x, y):
  if isinstance(x, AbsArray) or isinstance(y, AbsArray):
    # TODO check dtypes match
    # Scalar operands broadcast against the other operand's shape.
    if not x.shape:
      return AbsArray(y.shape, BaseType(np.dtype('bool')))
    if not y.shape:
      return AbsArray(x.shape, BaseType(np.dtype('bool')))
    map(_dims_must_equal, x.shape, y.shape)
    return AbsArray(x.shape, BaseType(np.dtype('bool')))
  else:
    return lax.lt_p.abstract_eval(x, y)

def _lt_typecheck_rule(x, y):
  return [lt_p.abstract_eval(x.aval, y.aval)]
# Fixed: these rules were defined but never registered, unlike every other
# primitive in this file (sin/cos/add/mul/...), so lt could not be
# typechecked or lowered.
typecheck_rules[lt_p] = _lt_typecheck_rule

def _lt_translation_rule(c, dims, avals, operands):
  (x,), (y,) = operands
  return [[xops.Lt(x, y)]]
translations[lt_p] = _lt_translation_rule
|
|
|
|
|
|
### dot
|
|
|
|
def dot(x, y):
  """Matrix-matrix product of two rank-2 arrays."""
  assert len(x.shape) == len(y.shape) == 2
  return dot_general(x, y, ([1], [0]), ([], []))


Dims = Tuple[Sequence[int], Sequence[int]]

def dot_general(x: Any, y: Any, contract: Dims, batch: Dims) -> Any:
  """General dot product with contracting and batch dimension numbers."""
  return dot_general_p.bind(x, y, contract=contract, batch=batch)
dot_general_p = core.Primitive('dot_general')

@dot_general_p.def_abstract_eval
def _dot_general_abstract_eval(x, y, *, contract, batch):
  # Contracted and batched dimension pairs must be provably equal.
  lhs_contract, rhs_contract = contract
  for i, j in zip(lhs_contract, rhs_contract):
    _dims_must_equal(x.shape[i], y.shape[j])
  lhs_batch, rhs_batch = batch
  for i, j in zip(lhs_batch, rhs_batch):
    _dims_must_equal(x.shape[i], y.shape[j])
  out_shape = lax._dot_general_shape_computation(x.shape, y.shape,
                                                 (contract, batch))
  return AbsArray(out_shape, x._eltTy)

def _dot_general_typecheck_rule(x, y, *, contract, batch):
  out_aval = _dot_general_abstract_eval(x.aval, y.aval,
                                        contract=contract, batch=batch)
  return [out_aval]
typecheck_rules[dot_general_p] = _dot_general_typecheck_rule

def _dot_general_trans(logical_shapes, x, y, *, contract, batch):
  # Zero out lhs padding along contracted axes so it doesn't pollute the sum.
  x_shape, _ = logical_shapes
  lhs_contract, _ = contract
  x = _replace_masked_values(x_shape, x, 0, axes=lhs_contract)
  return [lax.dot_general(x, y, dimension_numbers=(contract, batch))]
translations[dot_general_p] = traceable_to_padded_translation(_dot_general_trans)

def _dot_general_transpose_rule(cotangent, x, y, *, contract, batch):
  assert False  # TODO
ad.primitive_transposes[dot_general_p] = _dot_general_transpose_rule
|
|
|
|
|
|
## add
|
|
|
|
def add(x: Any, y: Any) -> Any:
  """Elementwise addition supporting dynamic-shape (AbsArray) inputs."""
  return add_p.bind(x, y)
add_p = core.Primitive('add')

@add_p.def_abstract_eval
def _add_abstract_eval(x, y):
  if not (isinstance(x, AbsArray) and isinstance(y, AbsArray)):
    return lax.add_p.abstract_eval(x, y)
  map(_dims_must_equal, x.shape, y.shape)  # TODO broadcasting?
  return AbsArray(x.shape, x._eltTy)

def _dims_must_equal(d1, d2):
  """Return True/False if the two dims are provably (un)equal, else raise."""
  symbolic = isinstance(d1, (Tracer, Var)) and isinstance(d2, (Tracer, Var))
  if symbolic:
    if d1.aval is d2.aval:
      return True
  elif isinstance(d1, int) and isinstance(d2, int):
    return d1 == d2
  raise Exception("can't prove shapes equal (or unequal)!")

def _add_typecheck_rule(x, y):
  return [add_p.abstract_eval(x.aval, y.aval)]
typecheck_rules[add_p] = _add_typecheck_rule

def _add_translation_rule(c, dims, avals, operands):
  (lhs,), (rhs,) = operands
  return [[xops.Add(lhs, rhs)]]
translations[add_p] = _add_translation_rule
|
|
|
|
|
|
## mul
|
|
|
|
def mul(x: Any, y: Any) -> Any:
  """Elementwise multiplication supporting dynamic-shape (AbsArray) inputs."""
  return mul_p.bind(x, y)
mul_p = core.Primitive('mul')

@mul_p.def_abstract_eval
def _mul_abstract_eval(x, y):
  if not (isinstance(x, AbsArray) and isinstance(y, AbsArray)):
    return lax.mul_p.abstract_eval(x, y)
  map(_dims_must_equal, x.shape, y.shape)  # TODO broadcasting?
  return AbsArray(x.shape, x._eltTy)

def _mul_typecheck_rule(x, y):
  return [mul_p.abstract_eval(x.aval, y.aval)]
typecheck_rules[mul_p] = _mul_typecheck_rule

def _mul_translation_rule(c, dims, avals, operands):
  (lhs,), (rhs,) = operands
  return [[xops.Mul(lhs, rhs)]]
translations[mul_p] = _mul_translation_rule
|
|
|
|
|
|
## nonzero
|
|
|
|
def nonzero(x):
  """Indices of nonzero entries along the last axis (dynamic output size)."""
  return nonzero_p.bind(x)
nonzero_p = core.Primitive('nonzero')

def _nonzero_unpack_result(x):
  # The primitive has two outputs: the dynamic size (the last dim of the
  # value's shape) followed by the value itself.
  return [x.shape[-1], x]
nonzero_p.unpack_result = _nonzero_unpack_result  # type: ignore

def _nonzero_staging_rule(trace, tracers, params):
  """Custom staging: emit one eqn with two outvars (dim, value), return value.

  The output's last dimension is data-dependent, so its aval refers to a
  freshly created dimension tracer bounded by the input's last dim.
  """
  aval = tracers[0].aval
  if isinstance(aval, AbsArray) and not isinstance(aval._eltTy, BaseType):
    raise NotImplementedError
  # The bound on the output size is the (possibly bounded) input last dim.
  bound = aval.shape[-1]
  bound = bound if isinstance(bound, int) else bound._bound
  out_dim_aval = AbsArray(aval.shape[:-1], BoundedIntTy(bound))
  out_dim_tracer = pe.DynamicJaxprTracer(trace, out_dim_aval, None)
  if len(aval.shape) == 1:
    out_val_aval = AbsArray((out_dim_tracer,), BaseType(np.dtype('int32')))
  else:
    # Higher-rank input: the output size varies per leading index.
    indices = tuple(range(len(aval.shape[:-1])))
    expr = DimIndexingExpr(out_dim_tracer, indices)
    out_val_aval = AbsArray((*aval.shape[:-1], expr),
                            BaseType(np.dtype('int32')))
  out_val_tracer = pe.DynamicJaxprTracer(trace, out_val_aval, None)
  invars = map(trace.getvar, tracers)
  outvars = map(trace.makevar, [out_dim_tracer, out_val_tracer])
  eqn = pe.new_jaxpr_eqn(invars, outvars, nonzero_p, {},
                         source_info_util.new_source_info())
  trace.frame.eqns.append(eqn)
  # Only the value tracer is returned; the dim tracer lives in its aval.
  return out_val_tracer
custom_staging_rules[nonzero_p] = _nonzero_staging_rule

def _nonzero_typecheck_rule(invar):
  """Typecheck: returns (fresh output dim var, output value aval)."""
  bound = invar.aval.shape[-1]
  bound = bound if isinstance(bound, int) else bound._bound
  newvar = core.gensym()
  out_dim_var = newvar(AbsArray(invar.aval.shape[:-1], BoundedIntTy(bound)))
  if len(invar.aval.shape) == 1:
    out_val_aval = AbsArray((out_dim_var,), BaseType(np.dtype('int32')))
  else:
    indices = tuple(range(len(out_dim_var.aval.shape)))  # pytype: disable=attribute-error
    expr = DimIndexingExpr(out_dim_var, indices)
    out_val_aval = AbsArray((*invar.aval.shape[:-1], expr),
                            BaseType(np.dtype('int32')))
  return out_dim_var, out_val_aval
typecheck_rules[nonzero_p] = _nonzero_typecheck_rule

def _nonzero_translation_traceable(logical_shapes, x):
  """Padded lowering: per-row counts, plus indices stably sorted to the front."""
  shape, = logical_shapes
  assert shape
  x = _replace_masked_values(shape, x, 0)
  nonzero_indicators = x != 0
  last_axis = len(shape) - 1
  out_sizes = lax._reduce_sum(nonzero_indicators.astype(np.int32), [last_axis])
  iota = lax.broadcasted_iota(np.int32, x.shape, dimension=last_axis)
  # Sorting by ~indicator moves nonzero positions (stably) to the front.
  _, idx = lax.sort_key_val(~nonzero_indicators, iota, dimension=last_axis)
  return out_sizes, idx
translations[nonzero_p] = traceable_to_padded_translation(
    _nonzero_translation_traceable)

def _nonzero_vmap_rule(args, in_dims):
  # Only leading-axis batching is supported; the op maps over leading axes.
  (x,), (d,) = args, in_dims
  if d != 0: raise NotImplementedError
  return nonzero_p.bind(x), 0
batching.primitive_batchers[nonzero_p] = _nonzero_vmap_rule
|
|
|
|
|
|
## iota
|
|
|
|
def iota(n):
  """[0, 1, ..., n-1], where `n` may be a constant or a dynamic size."""
  return iota_p.bind(n)
iota_p = core.Primitive('iota')

def _iota_staging_rule(trace, tracers, params):
  """Custom staging: constant sizes give a static ShapedArray; tracer sizes
  give an AbsArray whose last dim refers back to the size tracer."""
  tracer, = tracers
  n = trace.get_const(tracer)
  if n is not None:
    # Static path: size is a Python int baked into the eqn params.
    if type(n) is not int: raise NotImplementedError  # TODO batched version?
    out_aval = core.ShapedArray((n,), np.dtype('int32'))
    out_tracer = pe.DynamicJaxprTracer(trace, out_aval, None)
    outvar = trace.makevar(out_tracer)
    eqn = pe.new_jaxpr_eqn([], [outvar], iota_p, dict(size=n),
                           source_info_util.new_source_info())
  else:
    # Dynamic path: the size tracer itself appears in the output shape.
    aval = tracer.aval
    if not isinstance(aval, AbsArray): raise TypeError
    if aval.shape:
      indices = tuple(range(len(aval.shape)))
      out_aval = AbsArray((*aval.shape, DimIndexingExpr(tracer, indices)),
                          BaseType(np.dtype('int32')))
    else:
      out_aval = AbsArray((tracer,), BaseType(np.dtype('int32')))
    out_tracer = pe.DynamicJaxprTracer(trace, out_aval, None)
    outvar = trace.makevar(out_tracer)
    invar = trace.getvar(tracer)
    eqn = pe.new_jaxpr_eqn([invar], [outvar], iota_p, {},
                           source_info_util.new_source_info())
  trace.frame.eqns.append(eqn)
  return out_tracer
custom_staging_rules[iota_p] = _iota_staging_rule

def _iota_typecheck_rule(*invars, size=None):
  """Typecheck mirrors the two staging paths (static `size` vs. dim var)."""
  if size is not None:
    if invars: raise TypeError
    return [core.ShapedArray((size,), np.dtype('int32'))]
  else:
    invar, = invars
    if not invar.aval.shape:
      return [AbsArray((invar,), BaseType(np.dtype('int32')))]
    else:
      indices = tuple(range(len(invar.aval.shape)))
      return [AbsArray((*invar.aval.shape, DimIndexingExpr(invar, indices)),
                       BaseType(np.dtype('int32')))]
typecheck_rules[iota_p] = _iota_typecheck_rule

def _iota_translation_rule(c, dims, avals, operands, *, size=None):
  """Padded lowering: emit an iota padded out to the bound of the size."""
  if size is None:
    # Dynamic path: pad to the dimension's static bound.
    aval, = avals
    size = aval._eltTy._bound
    shape = aval.shape
  else:
    shape = ()
  etype = xla.dtype_to_primitive_type(np.dtype('int32'))
  xla_shape = xc.Shape.array_shape(etype, (*shape, size))
  return [[xops.Iota(c, xla_shape, len(shape))]]
translations[iota_p] = _iota_translation_rule
|
|
|
|
|
|
## broadcast
|
|
|
|
def broadcast(x, d):
  """Prepend a new leading axis of (possibly dynamic) size `d` to `x`."""
  return broadcast_p.bind(x, d)
broadcast_p = core.Primitive('broadcast')

def _broadcast_staging_rule(trace, tracers, params):
  """Custom staging: only the dynamic-size path is implemented."""
  x, d = tracers
  d_const = trace.get_const(d)
  if d_const is not None:
    raise NotImplementedError  # TODO
  else:
    aval = x.aval
    # Element dtype comes from either an AbsArray's eltTy or a plain aval.
    dtype = aval._eltTy._dtype if isinstance(aval, AbsArray) else aval.dtype
    # The size tracer `d` itself becomes the new leading dimension.
    out_aval = AbsArray((d, *x.shape), BaseType(dtype))
    out_tracer = pe.DynamicJaxprTracer(trace, out_aval, None)
    eqn = pe.new_jaxpr_eqn([trace.getvar(x), trace.getvar(d)],
                           [trace.makevar(out_tracer)], broadcast_p, {},
                           source_info_util.new_source_info())
    trace.frame.eqns.append(eqn)
    return out_tracer
custom_staging_rules[broadcast_p] = _broadcast_staging_rule

def _broadcast_typecheck_rule(x, d):
  aval = x.aval
  dtype = aval._eltTy._dtype if isinstance(aval, AbsArray) else aval.dtype
  return [AbsArray((d, *x.aval.shape), BaseType(dtype))]
typecheck_rules[broadcast_p] = _broadcast_typecheck_rule

def _broadcast_translation_rule(c, dims, avals, operands, *, size=None):
  """Padded lowering: broadcast to the bound of the dynamic leading size."""
  (x,), (_,) = operands
  if size is None:
    _, aval = avals
    assert not aval.shape
    size = aval._eltTy._bound
  return [[xops.Broadcast(x, (size,))]]
translations[broadcast_p] = _broadcast_translation_rule
|
|
|
|
|
|
# Examples
|
|
|
|
import jax.numpy as jnp
|
|
|
|
def bbarray(bound_shape: Tuple[int, ...], x: NDArray):
  """Wrap a concrete array as a bounded-shape Array, padding data to bounds."""
  sizes: Dict[int, BoundedInt] = {}
  dims = []
  for d, bound in zip(x.shape, bound_shape):
    if d not in sizes:
      # Axes with equal logical size share a single BoundedInt object.
      sizes[d] = BoundedInt(d, bound)
    dims.append(sizes[d])
  # Write the logical data into the top-left corner of a ones-filled buffer.
  idx = tuple(slice(d) for d in x.shape)
  padded = jnp.ones(bound_shape, x.dtype).at[idx].set(x)
  return Array(tuple(dims), BaseType(x.dtype), padded)

def ones_like(x):
  """An all-ones array with the same (possibly bounded) shape as `x`."""
  if not isinstance(x, Array):  # doesn't work with tracers
    return jnp.ones_like(x)
  return Array(x.shape, x._eltTy, jnp.ones_like(x._data))
|
|
|
|
|
|
if __name__ == '__main__':
  # Smoke-test driver: exercises staging/typechecking, XLA execution, AD,
  # vmap, and dot on the dynamic-shape machinery defined above.
  import jax
  jax.config.update('jax_platform_name', 'cpu')
  def p(s): print('\n--- ' + str(s))

  ## Staging and typechecking

  p('typecheck identity')
  def f(x):
    return x
  x = jnp.array([0, 1])
  jaxpr, _, _ = make_djaxpr(f, x)
  print(jaxpr)
  print(typecheck_jaxpr(jaxpr))

  p('typecheck sin')
  def f(x):
    return sin(x)
  x = bbarray((5,), jnp.arange(3.))
  jaxpr, _, _ = make_djaxpr(f, x)
  print(jaxpr)
  print(typecheck_jaxpr(jaxpr))

  p('typecheck sin-and-add')
  def f(x):
    y = sin(x)
    z = sin(y)
    return add(y, z)
  x = bbarray((5,), jnp.arange(3.))
  jaxpr, _, _ = make_djaxpr(f, x)
  print(jaxpr)
  print(typecheck_jaxpr(jaxpr))

  p('typecheck iota(3)')
  def f():  # type: ignore
    return iota(3)
  jaxpr, _, _ = make_djaxpr(f)
  print(jaxpr)
  print(typecheck_jaxpr(jaxpr))

  p('typecheck nonzero')
  def f(x):
    return nonzero(x)
  x = jnp.array([1, 0, -2, 0, 3, 0])
  jaxpr, _, _ = make_djaxpr(f, x)
  print(jaxpr)
  print(typecheck_jaxpr(jaxpr))

  p('typecheck sum-of-nonzero')
  def f(x):
    return reduce_sum(nonzero(x), tuple(range(len(x.shape))))
  x = jnp.array([1, 0, -2, 0, 3, 0])
  jaxpr, _, _ = make_djaxpr(f, x)
  print(jaxpr)
  print(typecheck_jaxpr(jaxpr))

  ## XLA lowering and execution

  @djit
  def f(x):
    nonzero_idx = nonzero(x)
    return reduce_sum(nonzero_idx)
  p('execute sum of nonzero indices')
  x = jnp.array([0, 1, 0, 1, 0, 1])
  print(f(x))
  print('should be', np.sum(np.nonzero(x)[0]))

  @djit
  def f(x):
    return nonzero(x)
  p('execute nonzero')
  x = jnp.array([0, 1, 0, 1, 0, 1])
  print(f(x))
  print('should be', np.nonzero(x)[0])

  @djit
  def f(i):
    return iota(i)
  p('execute iota')
  print(f(BoundedInt(3, 5)))
  print('should be', np.arange(3))

  @djit
  def f(x, n):
    y = nonzero(x)
    return broadcast(y, n)
  p('execute broadcast')
  x = np.arange(3)
  n = BoundedInt(4, 5)
  print(f(x, n))  # type: ignore
  print(f'should be\n{np.broadcast_to(np.nonzero(x)[0], (4, 2))}')

  ## ad

  @djit
  def f(x):
    y = sin(x)
    return reduce_sum(y, axes=(0,))
  x = bbarray((5,), jnp.arange(2.))
  p('basic jvp')
  z, z_dot = jax.jvp(f, (x,), (ones_like(x),))
  print(z, z_dot)

  p('basic linearize')
  _, f_lin = jax.linearize(f, x)
  print(f_lin(ones_like(x)))

  ## vmap

  @djit
  def f(x):
    return nonzero(x)
  p('vmap of nonzero')
  xs = jnp.array([[0, 1, 0, 1, 0, 1],
                  [1, 1, 1, 1, 0, 1]])
  print(jax.vmap(f)(xs))

  ## dot

  @djit
  def f(x):
    return dot(x, x)
  p('dot(x, x)')
  x = bbarray((4, 4), np.arange(9., dtype=np.float32).reshape(3, 3))
  print(f(x))
  y = np.arange(9.).reshape(3, 3)
  print(f'should be\n{np.dot(y, y)}')
|