mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
made polymorphic jaxprs, reshape fail
This commit is contained in:
parent
a609ae7071
commit
fbc85af54f
@ -580,7 +580,8 @@ def reshape(operand, new_sizes, dimensions=None):
|
||||
<https://www.tensorflow.org/xla/operation_semantics#reshape>`_
|
||||
operator.
|
||||
"""
|
||||
new_sizes = _canonicalize_shape(new_sizes)
|
||||
# new_sizes = _canonicalize_shape(new_sizes) # TODO
|
||||
new_sizes = tuple(new_sizes)
|
||||
same_shape = onp.shape(operand) == new_sizes
|
||||
same_dims = dimensions is None or tuple(dimensions) == tuple(range(onp.ndim(operand)))
|
||||
if onp.shape(operand) and same_shape and same_dims:
|
||||
|
239
mask.py
239
mask.py
@ -1,6 +1,6 @@
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import defaultdict, Counter
|
||||
from collections import defaultdict, Counter, namedtuple
|
||||
from functools import partial
|
||||
import itertools as it
|
||||
import operator as op
|
||||
@ -29,17 +29,17 @@ def prod(xs):
|
||||
|
||||
### main transformation functions
|
||||
|
||||
def mask_fun(fun, shape_env, in_vals, shape_exprs):
|
||||
def mask_fun(fun, shape_envs, in_vals, shape_exprs):
|
||||
with core.new_master(MaskTrace) as master:
|
||||
fun, out_shapes = mask_subtrace(fun, master, shape_env)
|
||||
fun, out_shapes = mask_subtrace(fun, master, shape_envs)
|
||||
out_vals = fun.call_wrapped(in_vals, shape_exprs)
|
||||
del master
|
||||
return out_vals, out_shapes()
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def mask_subtrace(master, shape_env, in_vals, shape_exprs):
|
||||
def mask_subtrace(master, shape_envs, in_vals, shape_exprs):
|
||||
trace = MaskTrace(master, core.cur_sublevel())
|
||||
in_tracers = map(partial(MaskTracer, trace, shape_env),
|
||||
in_tracers = map(partial(MaskTracer, trace, shape_envs),
|
||||
in_vals, shape_exprs)
|
||||
outs = yield in_tracers, {}
|
||||
out_tracers = map(trace.full_raise, outs)
|
||||
@ -72,6 +72,9 @@ class Poly(Counter): # type Poly = Map Mon Int -- monomials to coeffs
|
||||
def __add__(p1, p2):
|
||||
return Poly(Counter.__add__(p1, p2))
|
||||
|
||||
def __hash__(self):
|
||||
return hash(tuple(self.items()))
|
||||
|
||||
def __str__(self):
|
||||
return ' + '.join('{} {}'.format(v, k) if v != 1 else str(k)
|
||||
for k, v in sorted(self.items())).strip()
|
||||
@ -85,9 +88,9 @@ class Mon(Counter): # type Mon = Map Id Int -- ids to degrees
|
||||
for k, v in sorted(self.items()))
|
||||
|
||||
def eval_shape_expr(env, expr):
|
||||
return tuple(eval_poly(env, poly) for poly in expr)
|
||||
return tuple(eval_dim_expr(env, poly) for poly in expr)
|
||||
|
||||
def eval_poly(env, poly):
|
||||
def eval_dim_expr(env, poly):
|
||||
return sum(coeff * prod([env[id] ** deg for id, deg in mon.items()])
|
||||
for mon, coeff in poly.items())
|
||||
|
||||
@ -131,35 +134,35 @@ identifiers = frozenset(string.lowercase)
|
||||
def parse_id(name): return Poly({Mon({name: 1}): 1})
|
||||
def parse_lit(val_str): return Poly({Mon(): int(val_str)})
|
||||
|
||||
print(parse_spec('(m, n)')) # ShapeExpr(m, n)
|
||||
print(parse_spec('(m * n)')) # ShapeExpr(m n)
|
||||
print(parse_spec('(m * n,)')) # ShapeExpr(m n)
|
||||
print(parse_spec('(3, m)')) # ShapeExpr(3, m)
|
||||
print(parse_spec('(3 * m)')) # ShapeExpr(3 m)
|
||||
print(parse_spec('m')) # ShapeExpr(m)
|
||||
print(parse_spec('')) # ShapeExpr()
|
||||
Shape = parse_spec # convenience
|
||||
|
||||
# Tests:
|
||||
print(Shape('(m, n)')) # ShapeExpr(m, n)
|
||||
print(Shape('(m * n)')) # ShapeExpr(m n)
|
||||
print(Shape('m * n')) # ShapeExpr(m n)
|
||||
print(Shape('(m * n,)')) # ShapeExpr(m n)
|
||||
print(Shape('(3, m)')) # ShapeExpr(3, m)
|
||||
print(Shape('(3 * m)')) # ShapeExpr(3 m)
|
||||
print(Shape('m')) # ShapeExpr(m)
|
||||
print(Shape('')) # ShapeExpr()
|
||||
|
||||
Shape = parse_spec
|
||||
|
||||
### tracer machinery
|
||||
|
||||
class MaskTracer(Tracer):
|
||||
__slots__ = ["val", "shape_expr", "shape_env"]
|
||||
ShapeEnvs = namedtuple("ShapeEnvs", ["logical", "padded"])
|
||||
|
||||
def __init__(self, trace, shape_env, val, shape_expr):
|
||||
class MaskTracer(Tracer):
|
||||
__slots__ = ["val", "shape_expr", "shape_envs", "log_shape_env"]
|
||||
|
||||
def __init__(self, trace, shape_envs, val, shape_expr):
|
||||
self.trace = trace
|
||||
self.shape_env = shape_env
|
||||
self.shape_envs = shape_envs
|
||||
self.val = val
|
||||
self.shape_expr = shape_expr
|
||||
|
||||
@property
|
||||
def aval(self):
|
||||
# TODO can avoid some blowups, also improve error messages
|
||||
if self.shape_env is not None:
|
||||
shape = eval_shape_expr(self.shape_env, self.shape_expr)
|
||||
return ShapedArray(tuple(shape), self.val.dtype)
|
||||
else:
|
||||
return ShapedArray(self.val.shape, self.val.dtype)
|
||||
return ShapedArray(self.shape_expr, self.val.dtype)
|
||||
|
||||
def full_lower(self):
|
||||
if all(type(s) is int for s in self.shape_expr):
|
||||
@ -175,22 +178,27 @@ class MaskTrace(Trace):
|
||||
return MaskTracer(self, None, val, ShapeExpr(*onp.shape(val)))
|
||||
|
||||
def sublift(self, val):
|
||||
return MaskTracer(self, val.shape_env, val.val, val.shape_expr)
|
||||
return MaskTracer(self, val.shape_envs, val.val, val.shape_expr)
|
||||
|
||||
def process_primitive(self, primitive, tracers, params):
|
||||
shape_env = next(t.shape_env for t in tracers if t.shape_env is not None)
|
||||
shape_envs = next(t.shape_envs for t in tracers if t.shape_envs is not None)
|
||||
vals, shape_exprs = unzip2((t.val, t.shape_expr) for t in tracers)
|
||||
out_shape = shape_rules[primitive](shape_exprs, **params)
|
||||
logical_shapes = map(partial(eval_shape_expr, shape_env), shape_exprs)
|
||||
out = masking_rules[primitive](vals, logical_shapes, **params)
|
||||
if not primitive.multiple_results:
|
||||
return MaskTracer(self, shape_env, out, out_shape)
|
||||
if primitive in shape_parameterized_primitive_rules:
|
||||
rule = shape_parameterized_primitive_rules[primitive]
|
||||
out, out_shape = rule(shape_envs, vals, shape_exprs, **params)
|
||||
else:
|
||||
return map(partial(MaskTracer, self, shape_env), out, out_shape)
|
||||
out_shape = shape_rules[primitive](shape_exprs, **params)
|
||||
logical_shapes = map(partial(eval_shape_expr, shape_envs.logical), shape_exprs)
|
||||
out = masking_rules[primitive](vals, logical_shapes, **params)
|
||||
if not primitive.multiple_results:
|
||||
return MaskTracer(self, shape_envs, out, out_shape)
|
||||
else:
|
||||
return map(partial(MaskTracer, self, shape_envs), out, out_shape)
|
||||
|
||||
def process_call(self, call_primitive, f, tracers, params):
|
||||
raise NotImplementedError # TODO
|
||||
|
||||
shape_parameterized_primitive_rules = {}
|
||||
masking_rules = {}
|
||||
shape_rules = {}
|
||||
|
||||
@ -255,57 +263,6 @@ defvectorized(lax.log_p)
|
||||
defvectorized(lax.tanh_p)
|
||||
|
||||
|
||||
def scan_shape_rule(shape_exprs, forward, length, jaxpr, num_consts, num_carry,
|
||||
linear):
|
||||
const_shexprs, init_shexprs, xs_shexprs = split_list(shape_exprs, [num_consts, num_carry])
|
||||
if (any(any(type(d) is Id for d in shexpr) for shexpr in const_shexprs)
|
||||
or any(any(type(d) is Id for d in shexpr) for shexpr in init_shexprs)
|
||||
or any(any(type(d) is Id for d in shexpr[1:]) for shexpr in xs_shexprs)):
|
||||
raise NotImplementedError
|
||||
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
|
||||
ys_shapes = [ShapeExpr(length, *y_aval.shape) for y_aval in y_avals]
|
||||
return init_shexprs + ys_shapes
|
||||
shape_rules[lax.scan_p] = scan_shape_rule
|
||||
|
||||
def scan_masking_rule(padded_vals, logical_shapes, forward, length, jaxpr,
|
||||
num_consts, num_carry, linear):
|
||||
masked_jaxpr = _masked_scan_jaxpr(jaxpr, num_consts, num_carry)
|
||||
consts, init, xs = split_list(padded_vals, [num_consts, num_carry])
|
||||
max_length, = {x.shape[0] for x in xs}
|
||||
const_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
|
||||
out_vals = lax.scan_p.bind(
|
||||
*it.chain([length] + consts, [0], init, xs),
|
||||
forward=forward, length=max_length, jaxpr=masked_jaxpr,
|
||||
num_consts=1 + num_consts, num_carry=1 + num_carry,
|
||||
linear=[False] + const_linear + [False] + init_linear + xs_linear)
|
||||
return out_vals[1:]
|
||||
masking_rules[lax.scan_p] = scan_masking_rule
|
||||
|
||||
def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
|
||||
fun = core.jaxpr_as_fun(jaxpr)
|
||||
|
||||
@lu.wrap_init
|
||||
def masked(*args):
|
||||
[dynamic_length], consts, [i], carry, xs = split_list(
|
||||
args, [1, num_consts, 1, num_carry])
|
||||
out = fun(*(consts + carry + xs))
|
||||
new_carry, ys = split_list(out, [num_carry])
|
||||
new_carry = [lax.select(i < dynamic_length, new_c, c)
|
||||
for new_c, c in zip(new_carry, carry)]
|
||||
return [i + 1] + new_carry + ys
|
||||
|
||||
aval = ShapedArray((), onp.int32)
|
||||
const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
|
||||
return _make_typed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals)
|
||||
|
||||
def _make_typed_jaxpr(traceable, in_avals):
|
||||
pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
|
||||
jaxpr, pvals_out, consts = pe.trace_to_jaxpr(traceable, pvals, instantiate=True)
|
||||
assert not consts
|
||||
out_avals, _ = unzip2(pvals_out)
|
||||
return core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
|
||||
|
||||
|
||||
def dot_shape_rule(shape_exprs, precision):
|
||||
del precision # Unused.
|
||||
lhs_shape, rhs_shape = shape_exprs
|
||||
@ -347,24 +304,77 @@ def dot_masking_rule(padded_vals, logical_shapes, precision):
|
||||
masking_rules[lax.dot_p] = dot_masking_rule
|
||||
|
||||
|
||||
def scan_shape_rule(shape_exprs, forward, length, jaxpr, num_consts, num_carry,
|
||||
linear):
|
||||
const_shexprs, init_shexprs, xs_shexprs = split_list(shape_exprs, [num_consts, num_carry])
|
||||
if (any(any(type(d) is Id for d in shexpr) for shexpr in const_shexprs)
|
||||
or any(any(type(d) is Id for d in shexpr) for shexpr in init_shexprs)
|
||||
or any(any(type(d) is Id for d in shexpr[1:]) for shexpr in xs_shexprs)):
|
||||
raise NotImplementedError
|
||||
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
|
||||
ys_shapes = [ShapeExpr(length, *y_aval.shape) for y_aval in y_avals]
|
||||
return init_shexprs + ys_shapes
|
||||
|
||||
def scan_masking_rule(shape_envs, padded_vals, shape_exprs, forward, length,
|
||||
jaxpr, num_consts, num_carry, linear):
|
||||
out_shape = scan_shape_rule(shape_exprs, forward, length, jaxpr, num_consts,
|
||||
num_carry, linear)
|
||||
|
||||
dynamic_length = eval_dim_expr(shape_envs.logical, length)
|
||||
masked_jaxpr = _masked_scan_jaxpr(jaxpr, num_consts, num_carry)
|
||||
consts, init, xs = split_list(padded_vals, [num_consts, num_carry])
|
||||
max_length, = {x.shape[0] for x in xs}
|
||||
const_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
|
||||
out_vals = lax.scan_p.bind(
|
||||
*it.chain([dynamic_length] + consts, [0], init, xs),
|
||||
forward=forward, length=max_length, jaxpr=masked_jaxpr,
|
||||
num_consts=1 + num_consts, num_carry=1 + num_carry,
|
||||
linear=[False] + const_linear + [False] + init_linear + xs_linear)
|
||||
return out_vals[1:], out_shape
|
||||
shape_parameterized_primitive_rules[lax.scan_p] = scan_masking_rule
|
||||
|
||||
def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
|
||||
fun = core.jaxpr_as_fun(jaxpr)
|
||||
|
||||
@lu.wrap_init
|
||||
def masked(*args):
|
||||
[dynamic_length], consts, [i], carry, xs = split_list(
|
||||
args, [1, num_consts, 1, num_carry])
|
||||
out = fun(*(consts + carry + xs))
|
||||
new_carry, ys = split_list(out, [num_carry])
|
||||
new_carry = [lax.select(i < dynamic_length, new_c, c)
|
||||
for new_c, c in zip(new_carry, carry)]
|
||||
return [i + 1] + new_carry + ys
|
||||
|
||||
aval = ShapedArray((), onp.int32)
|
||||
const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
|
||||
return _make_typed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals)
|
||||
|
||||
def _make_typed_jaxpr(traceable, in_avals):
|
||||
pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
|
||||
jaxpr, pvals_out, consts = pe.trace_to_jaxpr(traceable, pvals, instantiate=True)
|
||||
assert not consts
|
||||
out_avals, _ = unzip2(pvals_out)
|
||||
return core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
|
||||
|
||||
|
||||
# TODO remove this
|
||||
def reshape_shape_rule(shape_exprs, new_sizes, dimensions, old_sizes):
|
||||
import ipdb; ipdb.set_trace()
|
||||
del old_sizes # Unused.
|
||||
if dimensions is not None: raise NotImplementedError
|
||||
shape_expr, = shape_exprs
|
||||
if prod(shape_expr) != prod(new_sizes): raise Exception
|
||||
return ShapeExpr(new_sizes)
|
||||
shape_rules[lax.reshape_p] = reshape_shape_rule
|
||||
if prod(shape_expr) != prod(new_sizes): raise ShapeError
|
||||
return new_sizes
|
||||
|
||||
def reshape_masking_rule(padded_vals, logical_shapes, new_sizes, dimensions,
|
||||
old_sizes):
|
||||
import ipdb; ipdb.set_trace()
|
||||
del new_sizes, old_sizes # Unused.
|
||||
def reshape_masking_rule(shape_envs, padded_vals, shape_exprs, new_sizes,
|
||||
dimensions, old_sizes):
|
||||
if dimensions is not None: raise NotImplementedError
|
||||
new_sizes = ShapeExpr(new_sizes) # tuplified
|
||||
out_shape = reshape_shape_rule(shape_exprs, new_sizes, dimensions, old_sizes)
|
||||
padded_operand, = padded_vals
|
||||
new_shape, = logical_shapes
|
||||
return lax.reshape(padded_operand) # TODO
|
||||
masking_rules[lax.reshape_p] = reshape_masking_rule
|
||||
padded_new_sizes = eval_shape_expr(shape_envs.padded, new_sizes)
|
||||
out = lax.reshape(padded_operand, padded_new_sizes)
|
||||
return out, out_shape
|
||||
shape_parameterized_primitive_rules[lax.reshape_p] = reshape_masking_rule
|
||||
|
||||
|
||||
###
|
||||
@ -373,32 +383,27 @@ def mask(fun, in_shapes, out_shape):
|
||||
in_shapes_flat, in_shapes_tree = tree_flatten(in_shapes)
|
||||
out_shapes_flat, out_shapes_tree = tree_flatten(out_shape)
|
||||
|
||||
def wrapped_fun(args, shape_env):
|
||||
def wrapped_fun(args, logical_shape_env):
|
||||
f = lu.wrap_init(fun)
|
||||
args_flat, in_tree = tree_flatten(args)
|
||||
assert in_tree == in_shapes_tree
|
||||
# padded_sizes = _check_shape_agreement(args_flat, in_shapes_flat) # TODO
|
||||
padded_shape_env = _bind_shapes(in_shapes_flat, [x.shape for x in args_flat])
|
||||
shape_envs = ShapeEnvs(logical_shape_env, padded_shape_env)
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
|
||||
outs, out_shapes_ = mask_fun(flat_fun, shape_env, args_flat, in_shapes_flat)
|
||||
outs, out_shapes_ = mask_fun(flat_fun, shape_envs, args_flat, in_shapes_flat)
|
||||
assert out_shapes_flat == list(out_shapes_)
|
||||
# _check_shape_agreement(outs, out_shapes_flat, padded_sizes)
|
||||
assert all(out.shape == eval_shape_expr(padded_shape_env, expr)
|
||||
for out, expr in zip(outs, out_shapes_flat))
|
||||
return tree_unflatten(out_tree(), outs)
|
||||
return wrapped_fun
|
||||
|
||||
def _check_shape_agreement(padded_args, shape_exprs, shape_values=None):
|
||||
shape_values = shape_values or defaultdict(set)
|
||||
for arg, shexpr in zip(padded_args, shape_exprs):
|
||||
for padded_size, size_expr in zip(arg.shape, shexpr):
|
||||
if type(size_expr) is Id:
|
||||
shape_values[size_expr].add(padded_size)
|
||||
elif type(size_expr) is int:
|
||||
if padded_size != size_expr: raise ShapeError
|
||||
else:
|
||||
raise TypeError(size_expr)
|
||||
for shape_var, sizes in shape_values.items():
|
||||
if len(sizes) != 1:
|
||||
raise ShapeError
|
||||
return shape_values
|
||||
def _bind_shapes(shape_exprs, shapes):
|
||||
env = {}
|
||||
for binders, shape in zip(shape_exprs, shapes):
|
||||
for poly, d in zip(binders, shape):
|
||||
(binder,), = poly
|
||||
if env.setdefault(binder, d) != d: raise ShapeError
|
||||
return env
|
||||
|
||||
###
|
||||
|
||||
@ -444,7 +449,8 @@ print(jit_cumsum([np.array([5, 2, 9, 1, 4])], dict(n=4)))
|
||||
|
||||
@partial(mask, in_shapes=[Shape('(m, n)')], out_shape=Shape('m * n'))
|
||||
def flatten(x):
|
||||
pass # TODO
|
||||
return lax.reshape(x, (x.shape[0] * x.shape[1],))
|
||||
print(flatten([np.arange(12).reshape(3, 4)], dict(m=2, n=3)))
|
||||
|
||||
|
||||
# @partial(mask, in_shapes=[Shape('m', 'k'), Shape('k', 'n')],
|
||||
@ -468,6 +474,9 @@ def flatten(x):
|
||||
# argument
|
||||
|
||||
|
||||
# TODO try to revert to the version without polymorphism in jaxprs, get rid of
|
||||
# reshape, make Tracer.aval return an aval with evaluated shapes not shexprs
|
||||
|
||||
# next steps:
|
||||
# 0. generic test setup
|
||||
# 1. clean up shape expression language (maybe handle reshape/conat)
|
||||
|
Loading…
x
Reference in New Issue
Block a user