mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
support both static- and dynamic-shaped arguments to djit
ed functions
Also clean up the signature of `interp` along the way. Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
parent
55d0f5ef8f
commit
9db8c6a0be
@ -39,6 +39,21 @@ def make_djaxpr(f, abstracted_axes, **make_jaxpr_kwargs):
|
||||
return jaxpr_maker(*args, **kwargs)
|
||||
return djaxpr_maker
|
||||
|
||||
@partial(jax.jit, static_argnums=(0,))
|
||||
def interp(djaxpr, slab, sizes, args):
|
||||
views = []
|
||||
in_types = [x.aval for x in djaxpr.invars]
|
||||
_, arg_types = util.split_list(in_types, [len(djaxpr.invars) - len(args)])
|
||||
for ty, x in zip(arg_types, args):
|
||||
if isinstance(ty, core.DShapedArray):
|
||||
resolved_shape = tuple(sizes.get(d, d) for d in ty.shape)
|
||||
# TODO(frostig,mattjj): reconstructing slab views seems off?
|
||||
views.append(sl.SlabView(x, resolved_shape, ty.dtype))
|
||||
else:
|
||||
views.append(x)
|
||||
slab, outs = eval_djaxpr(djaxpr, slab, *sizes.values(), *views)
|
||||
return slab, outs
|
||||
|
||||
def _check_axis_size_conflicts(all_axes, sizes):
|
||||
if len(all_axes) != len(set(all_axes)):
|
||||
d = collections.defaultdict(list)
|
||||
@ -48,32 +63,35 @@ def _check_axis_size_conflicts(all_axes, sizes):
|
||||
for name, sizes in d.items() if len(sizes) > 1])
|
||||
raise ValueError(f'abstracted axes resolve to conflicting sizes. {msg}')
|
||||
|
||||
@partial(jax.jit, static_argnums=(0, 1, 2, 3))
|
||||
def interp(djaxpr, abstracted_axes, dim_index, dtypes, slab, dims, addrs):
|
||||
# TODO(frostig,mattjj): reconstructing slab views seems less than ideal
|
||||
dim_index = dict(dim_index)
|
||||
views = []
|
||||
for addr, axes, dtype in zip(addrs, abstracted_axes, dtypes):
|
||||
resolved_shape = tuple(dims[dim_index[name]] for name in axes)
|
||||
views.append(sl.SlabView(addr, resolved_shape, dtype))
|
||||
slab, outs = eval_djaxpr(djaxpr, slab, *dims, *views)
|
||||
return slab, outs
|
||||
|
||||
def djit(f, abstracted_axes, **djit_kwargs):
|
||||
# TODO(frostig,mattjj): un/flatten f
|
||||
def f_wrapped(slab, *args): # TODO(frostig,mattjj): kw support
|
||||
djaxpr = make_djaxpr(f, abstracted_axes, **djit_kwargs)(*args).jaxpr
|
||||
slab, views = sl.chain(slab, sl.slab_upload, *args, unary=True)
|
||||
shapes = [x.shape for x in args]
|
||||
all_axes, sizes = util.unzip2(
|
||||
{(name, sz): None for axes, shape in zip(abstracted_axes, shapes)
|
||||
for name, sz in zip(axes, shape)})
|
||||
_check_axis_size_conflicts(all_axes, sizes)
|
||||
dim_index = {n: i for i, n in enumerate(all_axes)}
|
||||
in_types = [x.aval for x in djaxpr.invars]
|
||||
_, arg_types = util.split_list(in_types, [len(djaxpr.invars) - len(args)])
|
||||
|
||||
def upload(slab, ty, x):
|
||||
if isinstance(ty, core.DShapedArray):
|
||||
return sl.slab_upload(slab, x)
|
||||
elif isinstance(ty, core.ShapedArray):
|
||||
return slab, x
|
||||
else:
|
||||
assert False
|
||||
|
||||
slab, views = sl.chain(slab, upload, *zip(arg_types, args))
|
||||
|
||||
sizes: dict[core.Var, int] = {}
|
||||
for ty, x in zip(arg_types, args):
|
||||
for v, d in zip(ty.shape, x.shape):
|
||||
if isinstance(v, core.Var):
|
||||
d_ = sizes.setdefault(v, d)
|
||||
if d_ != d:
|
||||
raise ValueError(
|
||||
f'abstract dimension bound to unequal sizes: {d_} != {d}')
|
||||
|
||||
slab, out_views = interp(
|
||||
djaxpr, abstracted_axes, tuple(dim_index.items()),
|
||||
tuple(v.dtype for v in views), slab, sizes, [v.addr for v in views])
|
||||
djaxpr, slab, sizes,
|
||||
[v.addr if isinstance(v, sl.SlabView) else v for v in views])
|
||||
return slab, tuple(sl.slab_download(slab, v) for v in out_views)
|
||||
|
||||
return f_wrapped
|
||||
|
Loading…
x
Reference in New Issue
Block a user