made polymorphic jaxprs, reshape fail

Matthew Johnson 2019-08-30 16:06:43 -07:00
parent a609ae7071
commit fbc85af54f
2 changed files with 126 additions and 116 deletions


@@ -580,7 +580,8 @@ def reshape(operand, new_sizes, dimensions=None):
<https://www.tensorflow.org/xla/operation_semantics#reshape>`_
operator.
"""
new_sizes = _canonicalize_shape(new_sizes)
# new_sizes = _canonicalize_shape(new_sizes) # TODO
new_sizes = tuple(new_sizes)
same_shape = onp.shape(operand) == new_sizes
same_dims = dimensions is None or tuple(dimensions) == tuple(range(onp.ndim(operand)))
if onp.shape(operand) and same_shape and same_dims:

mask.py

@@ -1,6 +1,6 @@
from __future__ import print_function
from collections import defaultdict, Counter
from collections import defaultdict, Counter, namedtuple
from functools import partial
import itertools as it
import operator as op
@@ -29,17 +29,17 @@ def prod(xs):
### main transformation functions
def mask_fun(fun, shape_env, in_vals, shape_exprs):
def mask_fun(fun, shape_envs, in_vals, shape_exprs):
with core.new_master(MaskTrace) as master:
fun, out_shapes = mask_subtrace(fun, master, shape_env)
fun, out_shapes = mask_subtrace(fun, master, shape_envs)
out_vals = fun.call_wrapped(in_vals, shape_exprs)
del master
return out_vals, out_shapes()
@lu.transformation_with_aux
def mask_subtrace(master, shape_env, in_vals, shape_exprs):
def mask_subtrace(master, shape_envs, in_vals, shape_exprs):
trace = MaskTrace(master, core.cur_sublevel())
in_tracers = map(partial(MaskTracer, trace, shape_env),
in_tracers = map(partial(MaskTracer, trace, shape_envs),
in_vals, shape_exprs)
outs = yield in_tracers, {}
out_tracers = map(trace.full_raise, outs)
@@ -72,6 +72,9 @@ class Poly(Counter): # type Poly = Map Mon Int -- monomials to coeffs
def __add__(p1, p2):
return Poly(Counter.__add__(p1, p2))
def __hash__(self):
return hash(tuple(self.items()))
def __str__(self):
return ' + '.join('{} {}'.format(v, k) if v != 1 else str(k)
for k, v in sorted(self.items())).strip()
@@ -85,9 +88,9 @@ class Mon(Counter): # type Mon = Map Id Int -- ids to degrees
for k, v in sorted(self.items()))
def eval_shape_expr(env, expr):
return tuple(eval_poly(env, poly) for poly in expr)
return tuple(eval_dim_expr(env, poly) for poly in expr)
def eval_poly(env, poly):
def eval_dim_expr(env, poly):
return sum(coeff * prod([env[id] ** deg for id, deg in mon.items()])
for mon, coeff in poly.items())
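As a worked example of eval_dim_expr (environment values assumed for illustration): the dimension polynomial m*n + 3 is a Poly with two monomials, and it evaluates under an env of concrete sizes like so.

env = dict(m=2, n=5)
poly = Poly({Mon(dict(m=1, n=1)): 1, Mon(): 3})  # represents m*n + 3
assert eval_dim_expr(env, poly) == 13  # 1*(2**1 * 5**1) + 3*1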
@@ -131,35 +134,35 @@ identifiers = frozenset(string.lowercase)
def parse_id(name): return Poly({Mon({name: 1}): 1})
def parse_lit(val_str): return Poly({Mon(): int(val_str)})
print(parse_spec('(m, n)')) # ShapeExpr(m, n)
print(parse_spec('(m * n)')) # ShapeExpr(m n)
print(parse_spec('(m * n,)')) # ShapeExpr(m n)
print(parse_spec('(3, m)')) # ShapeExpr(3, m)
print(parse_spec('(3 * m)')) # ShapeExpr(3 m)
print(parse_spec('m')) # ShapeExpr(m)
print(parse_spec('')) # ShapeExpr()
Shape = parse_spec # convenience
# Tests:
print(Shape('(m, n)')) # ShapeExpr(m, n)
print(Shape('(m * n)')) # ShapeExpr(m n)
print(Shape('m * n')) # ShapeExpr(m n)
print(Shape('(m * n,)')) # ShapeExpr(m n)
print(Shape('(3, m)')) # ShapeExpr(3, m)
print(Shape('(3 * m)')) # ShapeExpr(3 m)
print(Shape('m')) # ShapeExpr(m)
print(Shape('')) # ShapeExpr()
Shape = parse_spec
### tracer machinery
class MaskTracer(Tracer):
__slots__ = ["val", "shape_expr", "shape_env"]
ShapeEnvs = namedtuple("ShapeEnvs", ["logical", "padded"])
def __init__(self, trace, shape_env, val, shape_expr):
class MaskTracer(Tracer):
__slots__ = ["val", "shape_expr", "shape_envs", "log_shape_env"]
def __init__(self, trace, shape_envs, val, shape_expr):
self.trace = trace
self.shape_env = shape_env
self.shape_envs = shape_envs
self.val = val
self.shape_expr = shape_expr
@property
def aval(self):
# TODO can avoid some blowups, also improve error messages
if self.shape_env is not None:
shape = eval_shape_expr(self.shape_env, self.shape_expr)
return ShapedArray(tuple(shape), self.val.dtype)
else:
return ShapedArray(self.val.shape, self.val.dtype)
return ShapedArray(self.shape_expr, self.val.dtype)
def full_lower(self):
if all(type(s) is int for s in self.shape_expr):
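A ShapeEnvs (the namedtuple introduced above) pairs the two environments a masked computation threads through the trace; a minimal sketch with assumed sizes:

envs = ShapeEnvs(logical=dict(n=3), padded=dict(n=5))  # true vs. physical sizes
assert envs.logical['n'] <= envs.padded['n']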
@@ -175,22 +178,27 @@ class MaskTrace(Trace):
return MaskTracer(self, None, val, ShapeExpr(*onp.shape(val)))
def sublift(self, val):
return MaskTracer(self, val.shape_env, val.val, val.shape_expr)
return MaskTracer(self, val.shape_envs, val.val, val.shape_expr)
def process_primitive(self, primitive, tracers, params):
shape_env = next(t.shape_env for t in tracers if t.shape_env is not None)
shape_envs = next(t.shape_envs for t in tracers if t.shape_envs is not None)
vals, shape_exprs = unzip2((t.val, t.shape_expr) for t in tracers)
out_shape = shape_rules[primitive](shape_exprs, **params)
logical_shapes = map(partial(eval_shape_expr, shape_env), shape_exprs)
out = masking_rules[primitive](vals, logical_shapes, **params)
if not primitive.multiple_results:
return MaskTracer(self, shape_env, out, out_shape)
if primitive in shape_parameterized_primitive_rules:
rule = shape_parameterized_primitive_rules[primitive]
out, out_shape = rule(shape_envs, vals, shape_exprs, **params)
else:
return map(partial(MaskTracer, self, shape_env), out, out_shape)
out_shape = shape_rules[primitive](shape_exprs, **params)
logical_shapes = map(partial(eval_shape_expr, shape_envs.logical), shape_exprs)
out = masking_rules[primitive](vals, logical_shapes, **params)
if not primitive.multiple_results:
return MaskTracer(self, shape_envs, out, out_shape)
else:
return map(partial(MaskTracer, self, shape_envs), out, out_shape)
def process_call(self, call_primitive, f, tracers, params):
raise NotImplementedError # TODO
shape_parameterized_primitive_rules = {}
masking_rules = {}
shape_rules = {}
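The defvectorized calls just below register entries in these tables for elementwise primitives; a hedged sketch (hypothetical helper name, not the file's actual definition) of what such a registration amounts to:

def defvectorized_sketch(prim):
  # Elementwise ops: the shape expression passes through unchanged, and
  # masking can simply apply the primitive to the padded values.
  shape_rules[prim] = lambda shape_exprs, **params: shape_exprs[0]
  masking_rules[prim] = lambda vals, logical_shapes, **params: prim.bind(vals[0], **params)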
@@ -255,57 +263,6 @@ defvectorized(lax.log_p)
defvectorized(lax.tanh_p)
def scan_shape_rule(shape_exprs, forward, length, jaxpr, num_consts, num_carry,
linear):
const_shexprs, init_shexprs, xs_shexprs = split_list(shape_exprs, [num_consts, num_carry])
if (any(any(type(d) is Id for d in shexpr) for shexpr in const_shexprs)
or any(any(type(d) is Id for d in shexpr) for shexpr in init_shexprs)
or any(any(type(d) is Id for d in shexpr[1:]) for shexpr in xs_shexprs)):
raise NotImplementedError
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
ys_shapes = [ShapeExpr(length, *y_aval.shape) for y_aval in y_avals]
return init_shexprs + ys_shapes
shape_rules[lax.scan_p] = scan_shape_rule
def scan_masking_rule(padded_vals, logical_shapes, forward, length, jaxpr,
num_consts, num_carry, linear):
masked_jaxpr = _masked_scan_jaxpr(jaxpr, num_consts, num_carry)
consts, init, xs = split_list(padded_vals, [num_consts, num_carry])
max_length, = {x.shape[0] for x in xs}
const_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
out_vals = lax.scan_p.bind(
*it.chain([length] + consts, [0], init, xs),
forward=forward, length=max_length, jaxpr=masked_jaxpr,
num_consts=1 + num_consts, num_carry=1 + num_carry,
linear=[False] + const_linear + [False] + init_linear + xs_linear)
return out_vals[1:]
masking_rules[lax.scan_p] = scan_masking_rule
def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
fun = core.jaxpr_as_fun(jaxpr)
@lu.wrap_init
def masked(*args):
[dynamic_length], consts, [i], carry, xs = split_list(
args, [1, num_consts, 1, num_carry])
out = fun(*(consts + carry + xs))
new_carry, ys = split_list(out, [num_carry])
new_carry = [lax.select(i < dynamic_length, new_c, c)
for new_c, c in zip(new_carry, carry)]
return [i + 1] + new_carry + ys
aval = ShapedArray((), onp.int32)
const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
return _make_typed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals)
def _make_typed_jaxpr(traceable, in_avals):
pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
jaxpr, pvals_out, consts = pe.trace_to_jaxpr(traceable, pvals, instantiate=True)
assert not consts
out_avals, _ = unzip2(pvals_out)
return core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
def dot_shape_rule(shape_exprs, precision):
del precision # Unused.
lhs_shape, rhs_shape = shape_exprs
@@ -347,24 +304,77 @@ def dot_masking_rule(padded_vals, logical_shapes, precision):
masking_rules[lax.dot_p] = dot_masking_rule
def scan_shape_rule(shape_exprs, forward, length, jaxpr, num_consts, num_carry,
linear):
const_shexprs, init_shexprs, xs_shexprs = split_list(shape_exprs, [num_consts, num_carry])
if (any(any(type(d) is Id for d in shexpr) for shexpr in const_shexprs)
or any(any(type(d) is Id for d in shexpr) for shexpr in init_shexprs)
or any(any(type(d) is Id for d in shexpr[1:]) for shexpr in xs_shexprs)):
raise NotImplementedError
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
ys_shapes = [ShapeExpr(length, *y_aval.shape) for y_aval in y_avals]
return init_shexprs + ys_shapes
def scan_masking_rule(shape_envs, padded_vals, shape_exprs, forward, length,
jaxpr, num_consts, num_carry, linear):
out_shape = scan_shape_rule(shape_exprs, forward, length, jaxpr, num_consts,
num_carry, linear)
dynamic_length = eval_dim_expr(shape_envs.logical, length)
masked_jaxpr = _masked_scan_jaxpr(jaxpr, num_consts, num_carry)
consts, init, xs = split_list(padded_vals, [num_consts, num_carry])
max_length, = {x.shape[0] for x in xs}
const_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
out_vals = lax.scan_p.bind(
*it.chain([dynamic_length] + consts, [0], init, xs),
forward=forward, length=max_length, jaxpr=masked_jaxpr,
num_consts=1 + num_consts, num_carry=1 + num_carry,
linear=[False] + const_linear + [False] + init_linear + xs_linear)
return out_vals[1:], out_shape
shape_parameterized_primitive_rules[lax.scan_p] = scan_masking_rule
def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
fun = core.jaxpr_as_fun(jaxpr)
@lu.wrap_init
def masked(*args):
[dynamic_length], consts, [i], carry, xs = split_list(
args, [1, num_consts, 1, num_carry])
out = fun(*(consts + carry + xs))
new_carry, ys = split_list(out, [num_carry])
new_carry = [lax.select(i < dynamic_length, new_c, c)
for new_c, c in zip(new_carry, carry)]
return [i + 1] + new_carry + ys
aval = ShapedArray((), onp.int32)
const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
return _make_typed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals)
def _make_typed_jaxpr(traceable, in_avals):
pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
jaxpr, pvals_out, consts = pe.trace_to_jaxpr(traceable, pvals, instantiate=True)
assert not consts
out_avals, _ = unzip2(pvals_out)
return core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
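For instance (a sketch under assumed avals), _make_typed_jaxpr stages a wrapped Python callable into a TypedJaxpr:

typed = _make_typed_jaxpr(lu.wrap_init(lambda x: [x + 1]),
                          [ShapedArray((), onp.int32)])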
# TODO remove this
def reshape_shape_rule(shape_exprs, new_sizes, dimensions, old_sizes):
import ipdb; ipdb.set_trace()
del old_sizes # Unused.
if dimensions is not None: raise NotImplementedError
shape_expr, = shape_exprs
if prod(shape_expr) != prod(new_sizes): raise Exception
return ShapeExpr(new_sizes)
shape_rules[lax.reshape_p] = reshape_shape_rule
if prod(shape_expr) != prod(new_sizes): raise ShapeError
return new_sizes
def reshape_masking_rule(padded_vals, logical_shapes, new_sizes, dimensions,
old_sizes):
import ipdb; ipdb.set_trace()
del new_sizes, old_sizes # Unused.
def reshape_masking_rule(shape_envs, padded_vals, shape_exprs, new_sizes,
dimensions, old_sizes):
if dimensions is not None: raise NotImplementedError
new_sizes = ShapeExpr(new_sizes) # tuplified
out_shape = reshape_shape_rule(shape_exprs, new_sizes, dimensions, old_sizes)
padded_operand, = padded_vals
new_shape, = logical_shapes
return lax.reshape(padded_operand) # TODO
masking_rules[lax.reshape_p] = reshape_masking_rule
padded_new_sizes = eval_shape_expr(shape_envs.padded, new_sizes)
out = lax.reshape(padded_operand, padded_new_sizes)
return out, out_shape
shape_parameterized_primitive_rules[lax.reshape_p] = reshape_masking_rule
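Reshape is the case the commit title flags as failing; a plain-NumPy illustration (sizes assumed) of why: reshaping the padded buffer scatters padding in among the logical elements.

padded = onp.zeros((3, 4), dtype=onp.int32)   # physical shape (3, 4)
padded[:2, :3] = onp.arange(6).reshape(2, 3)  # logical data in the (2, 3) corner
flat = padded.reshape(12)
# flat == [0 1 2 0 3 4 5 0 0 0 0 0]: the six logical elements are no longer
# a contiguous prefix, so masking the leading m*n entries is not enough.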
###
@@ -373,32 +383,27 @@ def mask(fun, in_shapes, out_shape):
in_shapes_flat, in_shapes_tree = tree_flatten(in_shapes)
out_shapes_flat, out_shapes_tree = tree_flatten(out_shape)
def wrapped_fun(args, shape_env):
def wrapped_fun(args, logical_shape_env):
f = lu.wrap_init(fun)
args_flat, in_tree = tree_flatten(args)
assert in_tree == in_shapes_tree
# padded_sizes = _check_shape_agreement(args_flat, in_shapes_flat) # TODO
padded_shape_env = _bind_shapes(in_shapes_flat, [x.shape for x in args_flat])
shape_envs = ShapeEnvs(logical_shape_env, padded_shape_env)
flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
outs, out_shapes_ = mask_fun(flat_fun, shape_env, args_flat, in_shapes_flat)
outs, out_shapes_ = mask_fun(flat_fun, shape_envs, args_flat, in_shapes_flat)
assert out_shapes_flat == list(out_shapes_)
# _check_shape_agreement(outs, out_shapes_flat, padded_sizes)
assert all(out.shape == eval_shape_expr(padded_shape_env, expr)
for out, expr in zip(outs, out_shapes_flat))
return tree_unflatten(out_tree(), outs)
return wrapped_fun
def _check_shape_agreement(padded_args, shape_exprs, shape_values=None):
shape_values = shape_values or defaultdict(set)
for arg, shexpr in zip(padded_args, shape_exprs):
for padded_size, size_expr in zip(arg.shape, shexpr):
if type(size_expr) is Id:
shape_values[size_expr].add(padded_size)
elif type(size_expr) is int:
if padded_size != size_expr: raise ShapeError
else:
raise TypeError(size_expr)
for shape_var, sizes in shape_values.items():
if len(sizes) != 1:
raise ShapeError
return shape_values
def _bind_shapes(shape_exprs, shapes):
env = {}
for binders, shape in zip(shape_exprs, shapes):
for poly, d in zip(binders, shape):
(binder,), = poly
if env.setdefault(binder, d) != d: raise ShapeError
return env
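A small example of the binding _bind_shapes produces (inputs assumed): each dimension spec must be a lone shape variable, which is bound to the corresponding padded size.

env = _bind_shapes([Shape('(m, n)')], [(3, 4)])  # -> {'m': 3, 'n': 4}
# Binding the same variable to two different sizes raises ShapeError.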
###
@@ -444,7 +449,8 @@ print(jit_cumsum([np.array([5, 2, 9, 1, 4])], dict(n=4)))
@partial(mask, in_shapes=[Shape('(m, n)')], out_shape=Shape('m * n'))
def flatten(x):
pass # TODO
return lax.reshape(x, (x.shape[0] * x.shape[1],))
print(flatten([np.arange(12).reshape(3, 4)], dict(m=2, n=3)))
# @partial(mask, in_shapes=[Shape('m', 'k'), Shape('k', 'n')],
@@ -468,6 +474,9 @@ def flatten(x):
# argument
# TODO try to revert to the version without polymorphism in jaxprs, get rid of
# reshape, make Tracer.aval return an aval with evaluated shapes not shexprs
# next steps:
# 0. generic test setup
# 1. clean up shape expression language (maybe handle reshape/concat)