2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
import numpy as onp
|
2020-03-21 13:54:30 +01:00
|
|
|
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
from .. import core
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
from .. import dtypes
|
2019-07-26 16:48:17 -04:00
|
|
|
from ..core import Trace, Tracer, new_master
|
2020-03-09 20:42:08 +01:00
|
|
|
from ..abstract_arrays import ShapedArray, raise_to_shaped
|
2019-07-27 15:46:14 -07:00
|
|
|
from ..ad_util import add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_p
|
2020-01-05 04:35:34 +01:00
|
|
|
from .. import linear_util as lu
|
2020-03-28 14:15:46 -07:00
|
|
|
from ..util import unzip2, partial, safe_map, wrap_name, split_list
|
2019-04-24 21:31:15 -07:00
|
|
|
from . import xla
|
2019-05-15 07:25:03 -07:00
|
|
|
from . import partial_eval as pe
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# Shadow the builtin `map` with `safe_map` from ..util for the rest of this
# module. NOTE(review): presumably `safe_map` checks that all argument
# sequences have equal length and returns a list (results are concatenated
# with `+` further down, e.g. in `batch_custom_jvp_subtrace`) — confirm in
# jax/util.py.
map = safe_map
|
|
|
|
|
|
|
|
|
2020-01-15 15:00:38 -08:00
|
|
|
def batch(fun : lu.WrappedFun, in_vals, in_dims, out_dim_dests):
  """Evaluate a batched version of `fun` on `in_vals`.

  Args:
    fun: the function to batch.
    in_vals: the (possibly batched) positional arguments.
    in_dims: batch dimension of each argument (`not_mapped` if unbatched).
    out_dim_dests: where each output's batch dimension should end up.

  Returns:
    The outputs of the batched function.
  """
  return batch_fun(fun, in_dims, out_dim_dests).call_wrapped(*in_vals)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
def batch_subtrace(master, in_dims, *in_vals, **params):
  """Run the wrapped function on a fresh BatchTrace sub-trace of `master`.

  Each input whose entry in `in_dims` is not None is wrapped in a
  `BatchTracer` carrying that batch dim; other inputs are passed through
  unchanged. The raw output values are yielded as the result, and their
  batch dims are the auxiliary output.
  """
  sub_trace = BatchTrace(master, core.cur_sublevel())

  def wrap(val, dim):
    return val if dim is None else BatchTracer(sub_trace, val, dim)

  tracers_in = [wrap(val, dim) for val, dim in zip(in_vals, in_dims)]
  results = yield tracers_in, params
  raised = map(sub_trace.full_raise, results)
  vals_out, dims_out = unzip2((r.val, r.batch_dim) for r in raised)
  yield vals_out, dims_out
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-28 14:15:46 -07:00
|
|
|
|
2020-03-29 20:51:51 -07:00
|
|
|
def batch_fun(fun : lu.WrappedFun, in_dims, out_dim_dests, sum_match=False):
  """Return a transformed `fun` that runs batched when called.

  Unlike `batch`, this does not call the function; it only wraps it.

  Args:
    fun: the function to batch.
    in_dims: batch dim per input, or a thunk producing them.
    out_dim_dests: desired batch dim per output, or a thunk producing them.
    sum_match: forwarded to `_batch_fun`/`matchaxis`; lets mapped outputs
      whose destination is `not_mapped` be summed over their batch dim.
  """
  subtraced, out_dims_thunk = batch_subtrace(fun)
  return _batch_fun(subtraced, sum_match, in_dims, out_dims_thunk,
                    out_dim_dests)
|
2020-01-15 15:00:38 -08:00
|
|
|
|
|
|
|
@lu.transformation
def _batch_fun(sum_match, in_dims, out_dims_thunk, out_dim_dests, *in_vals, **params):
  """Run the wrapped function under a new BatchTrace and place its outputs.

  Args:
    sum_match: if True, a mapped output whose destination is not an int or
      `last` is not rejected here; `matchaxis` then sums it over its batch
      dim when the destination is `not_mapped`.
    in_dims: batch dim of each input (`not_mapped` for unbatched inputs), or
      a thunk returning them.
    out_dims_thunk: thunk returning the batch dim of each output, as recorded
      by `batch_subtrace`.
    out_dim_dests: desired batch dim of each output, or a thunk returning
      them.
    *in_vals, **params: the (possibly batched) arguments and call params.

  Yields:
    First the arguments for the inner (subtraced) function, then the outputs
    with their batch dims moved to, or broadcast onto, `out_dim_dests`.

  Raises:
    ValueError: if an output is mapped but its destination is neither an int
      nor `last` and `sum_match` is False.
  """
  in_dims = in_dims() if callable(in_dims) else in_dims
  # All mapped inputs must agree on the batch size; the set unpacking fails
  # otherwise.
  size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
  with new_master(BatchTrace) as master:
    out_vals = yield (master, in_dims,) + in_vals, params
    del master
  out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
  out_dims = out_dims_thunk()
  for od, od_dest in zip(out_dims, out_dim_dests):
    if (od is not not_mapped and not isinstance(od_dest, int)
        and od_dest is not last and not sum_match):
      msg = f"vmap has mapped output but out_axes is {od_dest}"
      raise ValueError(msg)
  out_vals = map(partial(matchaxis, size, sum_match=sum_match),
                 out_dims, out_dim_dests, out_vals)
  yield out_vals
|
|
|
|
|
|
|
|
def batch_fun2(fun : lu.WrappedFun, in_dims):
  """Like `batch_fun`, but reports output batch dims instead of moving them.

  Returns:
    A pair of the transformed function and a thunk that yields each output's
    batch dim once the function has been called.
  """
  subtraced, out_dims_thunk = batch_subtrace(fun)
  wrapped = _batch_fun2(subtraced, in_dims)
  return wrapped, out_dims_thunk
|
|
|
|
|
|
|
|
@lu.transformation
def _batch_fun2(in_dims, *in_vals, **params):
  """Transformation behind `batch_fun2`.

  Runs the subtraced function under a new BatchTrace master and yields its
  outputs unchanged.
  """
  with new_master(BatchTrace) as master:
    subtrace_args = (master, in_dims) + in_vals
    out_vals = yield subtrace_args, params
    del master
  yield out_vals
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
### tracer
|
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
# TODO(mattjj): use a special sentinel type rather than None
# `not_mapped` marks a value that carries no batch dimension; `NotMapped` is
# its type, used in BatchTracer's batch_dim type check.
not_mapped = None
NotMapped = type(not_mapped)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
class BatchTracer(Tracer):
  """Tracer that pairs a value with the axis along which it is batched.

  `batch_dim` is an int naming the batch axis of `val`, or `not_mapped`
  when `val` has no batch axis.
  """
  __slots__ = ['val', 'batch_dim']

  def __init__(self, trace, val, batch_dim: Optional[int]):
    assert core.skip_checks or type(batch_dim) in (int, NotMapped)  # type: ignore
    self._trace = trace
    self.val = val
    self.batch_dim = batch_dim

  @property
  def aval(self):
    """Abstract value of one element of the batch (batch axis removed)."""
    shaped = raise_to_shaped(core.get_aval(self.val))
    # Unbatched values and units are reported as-is.
    if self.batch_dim is not_mapped or shaped is core.abstract_unit:
      return shaped
    if type(shaped) is not ShapedArray:
      raise TypeError(shaped)
    assert 0 <= self.batch_dim < shaped.ndim
    per_example_shape = tuple(onp.delete(shaped.shape, self.batch_dim))
    return ShapedArray(per_example_shape, shaped.dtype)

  def full_lower(self):
    """Drop the tracer wrapper when the value has no batch axis."""
    if self.batch_dim is not not_mapped:
      return self
    return core.full_lower(self.val)
|
|
|
|
|
|
|
|
class BatchTrace(Trace):
  """Trace whose tracers carry an explicit batch dimension.

  Values are wrapped in `BatchTracer`s. A primitive with at least one batched
  operand is dispatched to its registered batching rule. Call, map, and
  custom-derivative primitives are handled by re-wrapping the callee with
  `batch_subtrace` at this trace's master.
  """

  def pure(self, val):
    return BatchTracer(self, val, not_mapped)

  def lift(self, val):
    return BatchTracer(self, val, not_mapped)

  def sublift(self, val):
    return BatchTracer(self, val.val, val.batch_dim)

  def process_primitive(self, primitive, tracers, params):
    """Apply `primitive`, using its batching rule if any operand is batched."""
    vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(bdim is not_mapped for bdim in dims_in):
      # Nothing is batched: bind the primitive directly on the raw values.
      return primitive.bind(*vals_in, **params)
    else:
      # TODO(mattjj,phawkins): if no rule implemented, could vmap-via-map here
      batched_primitive = get_primitive_batcher(primitive)
      val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
      if primitive.multiple_results:
        return map(partial(BatchTracer, self), val_out, dim_out)
      else:
        return BatchTracer(self, val_out, dim_out)

  def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
    """Batch a call primitive by batching its callee `f`.

    Map primitives (those in `pe.map_primitives`) are delegated to
    `process_map`. The call's `name` param is wrapped with 'vmap'.
    """
    assert call_primitive.multiple_results
    params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap'))
    if call_primitive in pe.map_primitives:
      return self.process_map(call_primitive, f, tracers, params)
    else:
      vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
      if all(bdim is not_mapped for bdim in dims):
        return call_primitive.bind(f, *vals, **params)
      else:
        f, dims_out = batch_subtrace(f, self.master, dims)
        vals_out = call_primitive.bind(f, *vals, **params)
        return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]

  def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
    """Batch a map primitive.

    Each mapped input has its batch axis moved to position 1, and the callee
    sees it at position 0. Output batch dims are shifted back up by one.
    NOTE(review): this assumes `map_primitive` maps over the leading axis 0 —
    confirm against the map primitives in partial_eval.
    """
    vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(dim is not_mapped for dim in dims):
      return map_primitive.bind(f, *vals, **params)
    else:
      size, = {x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}
      is_batched = tuple(d is not not_mapped for d in dims)
      vals = [moveaxis(x, d, 1) if d is not not_mapped and d != 1 else x
              for x, d in zip(vals, dims)]
      dims = tuple(not_mapped if d is not_mapped else 0 for d in dims)
      f, dims_out = batch_subtrace(f, self.master, dims)
      vals_out = map_primitive.bind(f, *vals, **params)
      dims_out = tuple(d + 1 if d is not not_mapped else d for d in dims_out())
      return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out)]

  def post_process_call(self, call_primitive, out_tracers, params):
    """Unwrap outputs escaping a call and return a callback to re-wrap them.

    The returned `todo` re-wraps values with their original batch dims in a
    BatchTrace built at the then-current sublevel.
    """
    vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
    master = self.master
    def todo(x):
      trace = BatchTrace(master, core.cur_sublevel())
      return map(partial(BatchTracer, trace), x, dims)
    return vals, todo

  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    """Batch a custom_jvp call by batching both `fun` and its `jvp` rule."""
    in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
    fun, out_dims1 = batch_subtrace(fun, self.master, in_dims)
    jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.master, in_dims)
    out_vals = prim.bind(fun, jvp, *in_vals)
    # `fst` says which of the two auxiliary outputs was populated (presumably
    # True when `fun` ran and False when `jvp` ran — see lu.merge_linear_aux).
    fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
    if not fst:
      # The jvp path reports (primal dims + tangent dims), which are equal.
      assert out_dims == out_dims[:len(out_dims) // 2] * 2
      out_dims = out_dims[:len(out_dims) // 2]
    return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)]

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees):
    """Batch a custom_vjp call by batching `fun`, `fwd` and `bwd`.

    `bwd` is batched with `sum_match=True`, so cotangents for inputs that were
    not batched are summed over the batch dimension.
    """
    in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
    fun, out_dims1 = batch_subtrace(fun, self.master, in_dims)
    fwd, out_dims2 = batch_subtrace(fwd, self.master, in_dims)
    bwd = batch_fun(bwd, out_dims2, in_dims, sum_match=True)
    out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
    fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
    if not fst:
      # The trailing len(out_vals) dims belong to the primal outputs
      # (presumably `fwd` emits residuals first — confirm in
      # custom_derivatives). Slice directly rather than via `% len(out_dims)`,
      # which raised ZeroDivisionError when `fwd` produced no outputs.
      out_dims = out_dims[max(len(out_dims) - len(out_vals), 0):]
    return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)]
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
### primitives
|
|
|
|
|
2020-03-18 17:06:05 -04:00
|
|
|
# A batching rule is called as rule(batched_args, batch_dims, **params) and
# returns the output value(s) together with the output batch dim(s).
BatchingRule = Callable[..., Tuple[Any, Union[int, Tuple[int, ...]]]]

# Registry looked up by `get_primitive_batcher`: primitive -> batching rule.
primitive_batchers : Dict[core.Primitive, BatchingRule] = dict()
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def get_primitive_batcher(p):
  """Return the batching rule registered for primitive `p`.

  Raises:
    NotImplementedError: if no rule is registered for `p`.
  """
  try:
    rule = primitive_batchers[p]
  except KeyError as err:
    raise NotImplementedError(f"Batching rule for '{p}' not implemented") from err
  return rule
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def defvectorized(prim):
  """Register `vectorized_batcher` as the batching rule for `prim`."""
  rule = partial(vectorized_batcher, prim)
  primitive_batchers[prim] = rule
|
|
|
|
|
|
|
|
def vectorized_batcher(prim, batched_args, batch_dims, **params):
  """Batching rule that binds `prim` unchanged on the batched operands.

  Requires every operand to share the same batch dim (asserted), which is
  then reported as the output's batch dim.
  """
  shared_dim = batch_dims[0]
  assert all(shared_dim == bd for bd in batch_dims[1:]), batch_dims
  return prim.bind(*batched_args, **params), shared_dim
|
|
|
|
|
|
|
|
def defbroadcasting(prim):
  """Register `broadcast_batcher` as the batching rule for `prim`."""
  rule = partial(broadcast_batcher, prim)
  primitive_batchers[prim] = rule
|
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
def broadcast_batcher(prim, args, dims, **params):
  """Process a primitive with built-in broadcasting.

  Args:
    args: the possibly-batched arguments
    dims: list or tuple of the same length as `args`, where each
      entry indicates the batching state of the corresponding entry to `args`:
      either an int indicating the batch dimension, or else `not_mapped`
      indicating no batching.

  Returns:
    The primitive's output(s) and the output batch dim(s). That is the shared
    input batch dim when all non-scalar operands agree, and 0 otherwise.
  """
  nonscalar = {(x.shape, d) for x, d in zip(args, dims) if onp.ndim(x)}
  if len(nonscalar) == 1:
    # All non-scalar operands agree on shape and batch dim (the rest are
    # scalars), so the primitive can be bound directly.
    out_dim = next(d for d in dims if d is not not_mapped)
  else:
    # Otherwise bring every batch dim to the front (broadcasting unbatched
    # operands), then pad lower-rank batched operands with trailing unit axes.
    size, = {shape[d] for shape, d in nonscalar if d is not not_mapped}
    moved = [bdim_at_front(x, d, size) for x, d in zip(args, dims)]
    rank = max(onp.ndim(x) for x in moved)  # special-case scalar broadcasting
    args = [_handle_scalar_broadcasting(rank, x, d) for x, d in zip(moved, dims)]
    out_dim = 0
  out = prim.bind(*args, **params)
  if prim.multiple_results:
    return out, (out_dim,) * len(out)
  return out, out_dim
|
2019-07-27 15:46:14 -07:00
|
|
|
|
|
|
|
def _handle_scalar_broadcasting(nd, x, d):
  """Pad a batched `x` with trailing size-1 axes up to rank `nd`.

  Unbatched values (`d` is `not_mapped`) and values already of rank `nd` are
  returned unchanged.
  """
  if d is not not_mapped and nd != onp.ndim(x):
    return x.reshape(x.shape + (1,) * (nd - onp.ndim(x)))
  return x
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def defreducer(prim):
  """Register `reducer_batcher` as the batching rule for `prim`."""
  rule = partial(reducer_batcher, prim)
  primitive_batchers[prim] = rule
|
|
|
|
|
2019-01-10 15:35:15 -08:00
|
|
|
def reducer_batcher(prim, batched_args, batch_dims, axes, **params):
  """Batching rule for a single-operand reduction over `axes`.

  Reduction axes at or after the batch dim are shifted up by one so that they
  skip over it. The output batch dim is the batch dim's position among the
  axes that survive the reduction. An `input_shape` param, when present, is
  updated to the batched operand's shape.
  """
  operand, = batched_args
  bdim, = batch_dims
  axes = tuple(onp.where(onp.less(axes, bdim), axes, onp.add(axes, 1)))
  surviving = list(onp.delete(onp.arange(operand.ndim), axes))
  bdim_out = int(surviving.index(bdim))
  if 'input_shape' in params:
    params = dict(params, input_shape=operand.shape)
  return prim.bind(operand, axes=axes, **params), bdim_out
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-07-02 12:18:47 -04:00
|
|
|
# sets up primitive batchers for ad_util and xla primitives

def add_batched(batched_args, batch_dims):
  """Batching rule for `add_jaxvals_p`.

  Aligns the two operands' batch dims before adding them. An unbatched
  operand is broadcast to the other's batch dim; mismatched dims are
  reconciled by moving x's batch axis to y's.
  """
  bdx, bdy = batch_dims
  x, y = batched_args
  if bdx == bdy or core.get_aval(x) == core.abstract_unit:
    return add_jaxvals(x, y), bdx
  if bdx is not_mapped:
    return add_jaxvals(broadcast(x, y.shape[bdy], bdy), y), bdy
  if bdy is not_mapped:
    return add_jaxvals(x, broadcast(y, x.shape[bdx], bdx)), bdx
  return add_jaxvals(moveaxis(x, bdx, bdy), y), bdy

primitive_batchers[add_jaxvals_p] = add_batched
|
|
|
|
|
2019-01-07 08:34:48 -08:00
|
|
|
def zeros_like_batched(batched_args, batch_dims):
  """Batching rule for `zeros_like_p`: the zeros keep the operand's batch dim."""
  (val,), (bdim,) = batched_args, batch_dims
  return zeros_like_jaxval(val), bdim

primitive_batchers[zeros_like_p] = zeros_like_batched

defvectorized(xla.device_put_p)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
### util
|
|
|
|
|
2018-12-21 08:11:36 -08:00
|
|
|
# These utilities depend on primitives for things like broadcasting, reshaping,
|
|
|
|
# and transposition on arrays. To avoid a circular import from depending on
|
|
|
|
# lax.py, these functions use method dispatch on their arguments, which could be
|
|
|
|
# DeviceArrays, numpy.ndarrays, or traced versions of those. This strategy
|
|
|
|
# almost works, except for broadcast, for which raw numpy.ndarrays don't have a
|
|
|
|
# method. To handle that case, the `broadcast` function uses a try/except.
|
|
|
|
|
2019-09-23 13:35:52 -07:00
|
|
|
class _Last(object): pass
|
|
|
|
last = _Last()
|
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
def broadcast(x, sz, axis):
  """Insert a new axis of size `sz` at position `axis` of `x`.

  `axis` may be `last` to append the new axis at the end. Units are returned
  as `core.unit`. Raw numpy arrays and scalars are broadcast with numpy.
  Other values use their own `broadcast_in_dim` method; see the note above
  about avoiding a circular import of lax.
  """
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  if axis is last:
    axis = onp.ndim(x)
  out_shape = list(onp.shape(x))
  out_shape.insert(axis, sz)
  if not (isinstance(x, onp.ndarray) or onp.isscalar(x)):
    kept_dims = tuple(onp.delete(onp.arange(len(out_shape)), axis))
    return x.broadcast_in_dim(out_shape, kept_dims)
  return onp.broadcast_to(dtypes.coerce_to_array(x), out_shape)
|
|
|
|
|
|
|
|
def moveaxis(x, src, dst):
  """Move axis `src` of `x` to position `dst`; negative indices are allowed.

  Units are returned as `core.unit`, and `x` is returned unchanged when
  `src == dst`. Otherwise the move is done via `x.transpose`.
  """
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  if src == dst:
    return x
  src, dst = src % x.ndim, dst % x.ndim
  perm = list(range(onp.ndim(x)))
  perm.remove(src)
  perm.insert(dst, src)
  return x.transpose(perm)
|
|
|
|
|
2020-03-29 20:51:51 -07:00
|
|
|
def matchaxis(sz, src, dst, x, sum_match=False):
  """Make `x`, batched along `src`, batched along `dst` instead.

  Args:
    sz: batch size, used when an unbatched `x` must be broadcast.
    src: current batch dim of `x` (int or `not_mapped`).
    dst: requested batch dim (int, `last`, or `not_mapped`).
    x: the value to adjust.
    sum_match: if True, a mapped `x` requested as `not_mapped` is summed over
      its batch dim instead of raising.

  Raises:
    ValueError: with `(src, dst)` when no adjustment applies.
  """
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  if src == dst:
    return x
  if type(src) == int:
    if type(dst) == int:
      return moveaxis(x, src, dst)
    if dst is last:
      return moveaxis(x, src, -1)
  if src is not_mapped and dst is not not_mapped:
    return broadcast(x, sz, dst)
  if dst is not_mapped and sum_match:
    return x.sum(src)
  raise ValueError((src, dst))
|
2019-05-15 07:25:03 -07:00
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
def bdim_at_front(x, bdim, size):
  """Return `x` with its batch dim at axis 0.

  An unbatched `x` is broadcast to a new leading axis of length `size`.
  Units are returned as `core.unit`.
  """
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  return broadcast(x, size, 0) if bdim is not_mapped else moveaxis(x, bdim, 0)
|
2019-05-15 07:25:03 -07:00
|
|
|
|
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
def _promote_aval_rank(sz, aval):
  """Prepend a batch axis of length `sz` to `aval`; units pass through."""
  if aval is not core.abstract_unit:
    return ShapedArray((sz,) + aval.shape, aval.dtype)
  return core.abstract_unit
|
2019-05-15 07:25:03 -07:00
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
def batch_jaxpr(jaxpr, size, batched, instantiate):
  """Return a batched version of the typed jaxpr `jaxpr`.

  Args:
    jaxpr: the jaxpr to batch; it must expose `in_avals`.
    size: the batch size.
    batched: per input, whether that input carries a leading batch axis.
    instantiate: a bool, or one bool per output. Unbatched outputs flagged
      True are broadcast to carry a leading batch axis.

  Returns:
    A pair of the batched `core.TypedJaxpr` and, per output, whether that
    output is batched.
  """
  wrapped = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  wrapped, out_batched_thunk = batched_traceable(wrapped, size, batched, instantiate)
  avals_in = []
  for aval, is_batched in zip(jaxpr.in_avals, batched):
    avals_in.append(_promote_aval_rank(size, aval) if is_batched else aval)
  in_pvals = [pe.PartialVal((aval, core.unit)) for aval in avals_in]
  traced, pvals_out, consts = pe.trace_to_jaxpr(wrapped, in_pvals, instantiate=True)
  avals_out, _ = unzip2(pvals_out)
  typed = core.TypedJaxpr(traced, consts, avals_in, avals_out)
  return typed, out_batched_thunk()
|
2019-05-15 07:25:03 -07:00
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
def batched_traceable(size, batched, instantiate, *vals):
  """Trace the wrapped function with leading-axis batched inputs.

  Inputs flagged in `batched` get batch dim 0. Every batched output has its
  batch dim moved to axis 0. Unbatched outputs are broadcast to a leading
  axis of length `size` when `instantiate` asks for it (a single bool applies
  to all outputs). The auxiliary output says, per output, whether it is
  batched.
  """
  in_dims = [0 if b else None for b in batched]
  with new_master(BatchTrace) as master:
    trace = BatchTrace(master, core.cur_sublevel())
    ans = yield map(partial(BatchTracer, trace), vals, in_dims), {}
    out_tracers = map(trace.full_raise, ans)
    out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
    del master, out_tracers
  if type(instantiate) is bool:
    instantiate = [instantiate] * len(out_vals)

  def to_front(x, d, inst):
    # Unbatched: broadcast only if instantiation was requested.
    if d is not_mapped:
      return broadcast(x, size, 0) if inst else x
    return x if d == 0 else moveaxis(x, d, 0)

  out_vals = [to_front(x, d, inst)
              for x, d, inst in zip(out_vals, out_dims, instantiate)]
  out_batched = [d is not not_mapped or inst
                 for d, inst in zip(out_dims, instantiate)]
  yield out_vals, out_batched
|
2020-03-28 14:15:46 -07:00
|
|
|
|
|
|
|
|
|
|
|
@lu.transformation_with_aux
def batch_custom_jvp_subtrace(master, in_dims, *in_vals):
  """Trace a custom JVP rule on a BatchTrace sub-trace of `master`.

  `in_vals` holds primals followed by tangents. `in_dims` gives the primal
  batch dims, which are reused for the tangents. For each output, the primal
  and tangent batch dims are merged with `_merge_bdims` and both halves are
  moved to the merged dim. The auxiliary output repeats the merged dims for
  primals and tangents.
  """
  # zip with in_dims (primal dims only) looks at the primal inputs.
  size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
  sub_trace = BatchTrace(master, core.cur_sublevel())
  tracers_in = [val if dim is None else BatchTracer(sub_trace, val, dim)
                for val, dim in zip(in_vals, in_dims * 2)]
  results = yield tracers_in, {}
  raised = map(sub_trace.full_raise, results)
  vals_out, dims_out = unzip2((r.val, r.batch_dim) for r in raised)
  half = len(vals_out) // 2
  primals, tangents = split_list(vals_out, [half])
  primal_bds, tangent_bds = split_list(dims_out, [half])
  merged_dims = map(_merge_bdims, primal_bds, tangent_bds)
  align = partial(matchaxis, size)
  primals = map(align, primal_bds, merged_dims, primals)
  tangents = map(align, tangent_bds, merged_dims, tangents)
  yield primals + tangents, merged_dims * 2
|
|
|
|
|
|
|
|
def _merge_bdims(x, y):
  """Choose a single batch dim for a primal/tangent pair.

  Prefers `x`. `y` is used only when `x` is `not_mapped`. When both are
  mapped but differ, the choice of `x` is arbitrary; callers move the other
  value to match.
  """
  return y if x is not_mapped else x
|