# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for calling pallas functions from JAX."""
from __future__ import annotations

from functools import partial
import itertools as it

from typing import Any, Callable, Dict, Sequence, Tuple

import jax
from jax import api_util
from jax import tree_util
from jax import lax
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax._src import ad_util
from jax._src import core as jax_core
from jax._src import linear_util as lu
from jax._src.state import discharge as state_discharge
from jax._src.util import (
    split_list, safe_map, safe_zip, weakref_lru_cache,
    tuple_insert, partition_list)
from jax._src.lax.control_flow import for_loop
import jax.numpy as jnp
import numpy as np

from jax._src.pallas import core as pallas_core

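# Shadow the Python builtins with the length-checking `safe_map`/`safe_zip`
# variants from `jax._src.util`.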
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

Grid = pallas_core.Grid
BlockSpec = pallas_core.BlockSpec
GridSpec = pallas_core.GridSpec
BlockMapping = pallas_core.BlockMapping
GridMapping = pallas_core.GridMapping

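# `pallas_call` is exposed as a JAX primitive so that the implementation,
# abstract evaluation, JVP, and batching rules registered below can be attached
# to it.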
pallas_call_p = jax_core.Primitive('pallas_call')
pallas_call_p.multiple_results = True

def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
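  # Slices a block of shape `block_shape` out of `value` at `start_idx`, then
  # squeezes away the dimensions flagged in `is_indexing` (block dims declared
  # as `pallas_core.mapped`, i.e. indexed out rather than blocked).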
  if start_idx is None:
    assert is_indexing is None
    return value
  assert is_indexing is not None
  output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape)
  squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing,
                                                            dtype=np.bool_)])
  return lax.squeeze(output, squeeze_dims)

def _maybe_dynamic_update_slice(start_idx, block_shape, value, update,
                                is_indexing):
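  # Inverse of `_maybe_dynamic_slice`: broadcasts `update` back up to
  # `block_shape` along the indexed-out dimensions and writes it into `value`
  # at `start_idx`.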
  if start_idx is None:
    assert is_indexing is None
    return update
  assert is_indexing is not None
  broadcast_dims = tuple(i for i, b in enumerate(is_indexing)
                         if not b)
  update = lax.broadcast_in_dim(update, block_shape, broadcast_dims)
  assert update.shape == block_shape
  return lax.dynamic_update_slice(value, update, start_idx)

def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
                      interpret, debug: bool,
                      in_shapes,
                      input_output_aliases: Tuple[Tuple[int, int], ...],
                      grid_mapping: GridMapping,
                      **compiler_params: Any):
  if interpret:
    # If we're in interpreter mode, we *scan* over the grid and eval the
    # discharged jaxpr. This should reproduce exactly what compiling to Triton
    # will do.
    grid = grid_mapping.grid
    discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
    if debug:
      print(discharged_jaxpr)
    loop_indices = jnp.array(list(it.product(*(range(g) for g in grid))))
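    # For example, with grid == (2, 3) this enumerates the six program ids
    # (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), the last grid axis
    # varying fastest.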
    oi_map = {v: k for k, v in input_output_aliases}
    out = []
    for i, out_shape in enumerate(out_shapes):
      if i in oi_map:
        out.append(args[oi_map[i]])
      else:
        out.append(jnp.zeros(out_shape.shape, out_shape.dtype))
    scalars, args = split_list(args, [grid_mapping.num_index_operands])  # type: ignore
    carry = [*args, *out]
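    # The while loop below carries (step index, *inputs, *outputs). Each step
    # slices the current blocks out of the carry, evaluates the discharged
    # kernel jaxpr on them, and writes the results back into the carry.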
    def cond(carry):
      return carry[0] < loop_indices.shape[0]
    def body(carry):
      i, *carry = carry
      loop_idx = loop_indices[i]
      start_indices = [
          None if bm is None else bm.compute_start_indices(loop_idx, *scalars)
          for bm in grid_mapping.block_mappings]
      block_shapes_without_mapped_dims = [
          None if block_mapping is None else block_mapping.block_shape
          for block_mapping in grid_mapping.block_mappings
      ]
      is_indexing_dim = [
          None if bm is None else tuple(b is pallas_core.mapped for b in bm)
          for bm in block_shapes_without_mapped_dims
      ]
      block_shapes = [
          None if bm is None else tuple(1 if i else b for i, b in zip(iid, bm))
          for iid, bm in zip(is_indexing_dim, block_shapes_without_mapped_dims)
      ]
      blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry,
                   is_indexing_dim)
      is_mapped_grid_dim = [
          i in grid_mapping.mapped_dims for i in range(len(grid_mapping.grid))]
      local_grid_env, _ = partition_list(is_mapped_grid_dim,
                                         zip(loop_idx, grid_mapping.grid))
      with pallas_core.grid_env(tuple(local_grid_env)):
        blocks = jax.core.eval_jaxpr(discharged_jaxpr, consts, *scalars,
                                     *blocks)
      blocks = blocks[grid_mapping.num_index_operands:]
      carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
                  carry, blocks, is_indexing_dim)
      return (i + 1, *carry)
    (_, *carry) = lax.while_loop(cond, body, (0, *carry))
    _, out = split_list(carry, [len(args)])
    return out
  return xla.apply_primitive(pallas_call_p, *args, jaxpr=jaxpr, name=name,
                             in_shapes=in_shapes,
                             out_shapes=out_shapes, which_linear=which_linear,
                             grid_mapping=grid_mapping, interpret=interpret,
                             debug=debug,
                             input_output_aliases=input_output_aliases,
                             **compiler_params)
pallas_call_p.def_impl(_pallas_call_impl)

def _pallas_call_abstract_eval(*avals, out_shapes, **_):
  return map(lambda x: jax_core.ShapedArray(x.shape, x.dtype), out_shapes)
pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)

def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
                          input_output_aliases: Tuple[Tuple[int, int], ...],
                          in_shapes, out_shapes, grid_mapping, debug, interpret,
                          **compiler_params: Any):
  if grid_mapping.num_index_operands:
    raise NotImplementedError
  if input_output_aliases:
    raise NotImplementedError("JVP with aliasing not supported.")
  nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
  tangents = [t for t in tangents if type(t) is not ad_util.Zero]
  nonzero_tangents_with_outputs = nonzero_tangents + [True] * len(out_shapes)
  closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ())
  jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, [])
  jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts  # TODO consts
  # `pallas_call` takes in inputs and returns outputs but its jaxpr *does not*.
  # `pallas_call` takes in a stateful jaxpr, meaning the jaxpr accepts input
  # `Ref`s that are read from followed by output `Ref`s that are written to.
  # This means that when we do `jvp_jaxpr` on the `jaxpr`, we get out a new
  # jaxpr that has tangents following primals. In order for this jaxpr to be
  # compatible w/ `pallas_call` (inputs then outputs), we need to shuffle around
  # the jaxpr's invars.
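  # Concretely, for one input and one output the jvp jaxpr's invars come back
  # as [x_ref, o_ref, dx_ref, do_ref]; `pallas_call` expects
  # [x_ref, dx_ref, o_ref, do_ref], so we regroup them below.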
  primal_refs, primal_out_refs, tangent_refs, tangent_out_refs = split_list(
      jvp_jaxpr.invars, [len(primals), len(out_shapes), len(tangents)]
  )
  invars = (*primal_refs, *tangent_refs, *primal_out_refs, *tangent_out_refs)
  # TODO(sharadmv): Fix state effect tracking after invar switch.
  jvp_jaxpr = jvp_jaxpr.replace(invars=invars)
  if debug:
    print(jvp_jaxpr)
  in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)])
  jvp_bms = (*in_bms, *in_bms, *out_bms, *out_bms)
  out_flat = pallas_call_p.bind(
      *primals,
      *tangents,
      jaxpr=jvp_jaxpr,
      name=f"{name}_jvp",
      in_shapes=(*in_shapes, *in_shapes),
      out_shapes=(*out_shapes, *out_shapes),
      grid_mapping=grid_mapping.replace(block_mappings=jvp_bms),
      which_linear=which_linear + (True,) * len(tangents),
      interpret=interpret,
      debug=debug,
      input_output_aliases=(),
      **compiler_params,
  )
  out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2])
  return out_primals, out_tangents
ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule

def _batch_block_mapping(grid: Tuple[int, ...], aval: jax_core.ShapedArray,
                         dim: int | batching.NotMapped,
                         block_mapping: BlockMapping | None) -> BlockMapping:
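  # Rewrites a block mapping for a batched operand: the index map gains the
  # batch grid index as a new leading argument and, if this operand is batched
  # along `dim`, that index (and a `mapped` entry in the block shape) is
  # inserted at position `dim`.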
  def _block_map_function(new_idx, *args):
    if block_mapping is None:
      indices = [0] * len(aval.shape)
    else:
      indices = jax_core.eval_jaxpr(block_mapping.index_map_jaxpr.jaxpr,
                                    block_mapping.index_map_jaxpr.consts,
                                    *args)
    if dim is not batching.not_mapped:
      indices.insert(dim, new_idx)
    return tuple(indices)
  i32_aval = jax_core.ShapedArray((), jnp.int32)
  if block_mapping is None:
    idx_avals = [i32_aval] * (len(grid) + 1)
  else:
    idx_avals = [i32_aval, *block_mapping.index_map_jaxpr.in_avals]
  block_mapping_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(_block_map_function), idx_avals)
  shape = aval.shape if block_mapping is None else block_mapping.block_shape
  if dim is batching.not_mapped:
    new_block_shape = shape
  else:
    new_block_shape = tuple_insert(shape, dim, pallas_core.mapped)
  jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts)
  if block_mapping is None:
    return BlockMapping(block_shape=new_block_shape, index_map_jaxpr=jaxpr)
  return block_mapping.replace(block_shape=new_block_shape,
                               index_map_jaxpr=jaxpr)

def _pallas_call_batching_rule(args, dims, *,
                               jaxpr: jax_core.Jaxpr,
                               name: str,
                               in_shapes: Tuple[jax.ShapeDtypeStruct, ...],
                               out_shapes: Tuple[jax.ShapeDtypeStruct, ...],
                               grid_mapping: GridMapping,
                               input_output_aliases: Tuple[Tuple[int, int], ...],
                               debug: bool,
                               interpret: bool,
                               which_linear: Tuple[bool, ...],
                               **compiler_params: Any):
  if grid_mapping.num_index_operands:
    scalar_batch_dims = dims[:grid_mapping.num_index_operands]
    if any(bdim is not batching.not_mapped for bdim in scalar_batch_dims):
      # TODO(sharadmv,apaszke): enable batching over prefetched scalar args
      raise NotImplementedError
  axis_size, = {x.shape[d] for x, d in zip(args, dims)
                if d is not batching.not_mapped}
  block_mappings = grid_mapping.block_mappings
  avals = [v.aval for v in jaxpr.invars]
  # How should we pick output dimensions? This actually matters because XLA
  # can't optimize our pallas kernels, so this layout impacts performance.
  # `vmap` doesn't really offer a way of inferring good output dimensions, so
  # for now we just map every output along dimension 0.
  # TODO(sharadmv): explore inferring better output dimensions via a heuristic
  # TODO(sharadmv): explore a long term solution to output dim inference

  # When we have input/output aliasing, since the output will be mapped, we need
  # to make sure to broadcast the input across that dimension if it is not
  # mapped.
  dims_ = list(dims)
  args_ = list(args)
  for input_index, _ in input_output_aliases:
    dim = dims_[input_index]
    if dim is batching.not_mapped:
      dims_[input_index] = 0
      args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0)
  args = tuple(args_)
  dims = tuple(dims_)

  all_dims = list(dims) + [0] * len(out_shapes)

  num_index_operands = grid_mapping.num_index_operands
  batched_block_mappings = map(
      partial(_batch_block_mapping, grid_mapping.grid),
      avals[num_index_operands:], all_dims[num_index_operands:], block_mappings)

  batched_in_shapes = tuple(
      jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else
                           tuple_insert(x.shape, dim, axis_size),
                           x.dtype)
      for x, dim in zip(in_shapes, dims))
  batched_out_shapes = tuple(
      jax.ShapeDtypeStruct(tuple_insert(x.shape, 0, axis_size), x.dtype)
      for x in out_shapes)

  batched_grid_mapping = grid_mapping.replace(
      grid=(axis_size, *grid_mapping.grid),
      block_mappings=tuple(batched_block_mappings),
      mapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.mapped_dims))
  out = pallas_call_p.bind(*args, jaxpr=jaxpr, name=f"batched_{name}",
                           in_shapes=batched_in_shapes,
                           out_shapes=batched_out_shapes,
                           which_linear=which_linear,
                           grid_mapping=batched_grid_mapping,
                           input_output_aliases=input_output_aliases,
                           debug=debug,
                           interpret=interpret,
                           **compiler_params)
  return out, (0,) * len(out)
batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule

@weakref_lru_cache
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
                              primitive_name: str | None = None):
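  # Traces `fun` to a jaxpr and hoists any closed-over constants into
  # additional `Ref` inputs, so `wrapped` below can pass them to the kernel as
  # ordinary operands ahead of the user's arguments.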
  wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
      lu.wrap_init(fun), in_tree)
  debug = pe.debug_info(fun, in_tree, out_tree_thunk, False,
                        primitive_name or "<unknown>")
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  jaxpr = for_loop._hoist_consts_to_refs(jaxpr)
  return jaxpr, consts, out_tree_thunk()

def _extract_function_name(f: Callable, name: str | None) -> str:
  if name is None:
    name = f.__name__ if hasattr(f, "__name__") and f.__name__ else "func"
  return name

def pallas_call(
    f: Callable[..., None], out_shape: Any, *,
    grid_spec: GridSpec | None = None,
    debug: bool = False,
    grid: Grid | None = None,
    in_specs: Sequence[BlockSpec | None] | None = None,
    out_specs: BlockSpec | Sequence[BlockSpec | None] | None = None,
    input_output_aliases: Dict[int, int] = {},
    interpret: bool = False,
    name: str | None = None,
    **compiler_params: Any):
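  """Wraps the kernel function `f` so it can be invoked like a JAX function.

  A minimal usage sketch (illustrative only): the kernel reads from input
  `Ref`s and writes its result into the output `Ref`, and `out_shape` declares
  the output's shape and dtype.

    def add_kernel(x_ref, y_ref, o_ref):
      o_ref[...] = x_ref[...] + y_ref[...]

    add = pallas_call(
        add_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32))
    z = add(x, y)  # x, y: float32[8]
  """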
  if grid_spec is None:
    grid_spec = GridSpec(grid, in_specs, out_specs)
  name = _extract_function_name(f, name)
  singleton = False
  if not isinstance(out_shape, (tuple, list)):
    out_shape = (out_shape,)
    singleton = True
  if not isinstance(out_shape, tuple):
    out_shape = tuple(out_shape)
  flat_out_shapes, out_tree = tree_util.tree_flatten(out_shape)
  flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype)
                     for x in flat_out_shapes]
  @jax.jit
  def wrapped(*args):
    flat_args, in_tree = tree_util.tree_flatten(args)
    flat_avals = [jax_core.raise_to_shaped(jax_core.get_aval(a))
                  for a in flat_args]
    avals, grid_mapping = grid_spec.get_grid_mapping(flat_avals, in_tree,
                                                     flat_out_shapes, out_tree)
    jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(avals)
    jaxpr, consts, _ = _initial_style_open_jaxpr(f, jaxpr_in_tree,
                                                 tuple(jaxpr_flat_avals),
                                                 primitive_name="pallas_call")
    which_linear = (False,) * len(flat_args)
    out_flat = pallas_call_p.bind(
        *consts, *flat_args, jaxpr=jaxpr, name=name, which_linear=which_linear,
        in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
                        for a in flat_args),
        out_shapes=tuple(flat_out_shapes), debug=debug,
        interpret=interpret,
        grid_mapping=grid_mapping,
        input_output_aliases=tuple(input_output_aliases.items()),
        **compiler_params)
    out = tree_util.tree_unflatten(out_tree, out_flat)
    if singleton:
      return out[0]
    return out
  return wrapped