Merge pull request #193 from ROCm/ci-upstream-sync-78_1

CI: 01/06/25 upstream sync
charleshofer 2025-01-06 11:22:23 -06:00 committed by GitHub
commit 4b11080f18
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
52 changed files with 1146 additions and 476 deletions

View File

@ -85,7 +85,7 @@ jobs:
&& steps.status.outcome == 'failure'
&& github.event_name == 'schedule'
&& github.repository == 'jax-ml/jax'
uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0
with:
name: output-${{ matrix.python-version }}-log.jsonl
path: output-${{ matrix.python-version }}-log.jsonl

View File

@ -45,7 +45,7 @@ jobs:
--bazel_options=--config=win_clang `
--verbose
- uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
- uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0
with:
name: wheels-${{ matrix.os }}-${{ matrix.pyver }}
path: ${{ github.workspace }}\dist\*.whl

View File

@ -54,7 +54,7 @@ jobs:
--bazel_options=--config=win_clang `
--verbose
- uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
- uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0
with:
name: wheels
path: ${{ github.workspace }}\jax\dist\*.whl

View File

@ -4,6 +4,10 @@ Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html).
For the changes specific to the experimental Pallas APIs,
see {ref}`pallas-changelog`.
JAX follows Effort-based versioning; for a discussion of this and JAX's API
compatibility policy, refer to {ref}`api-compatibility`. For the Python and
NumPy version support policy, refer to {ref}`version-support-policy`.
<!--
Remember to align the itemized text with the first line of an item within a list.
@ -12,6 +16,16 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
## Unreleased
* Changes:
* The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
supported version until June 2025.
* New Features
* {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,
{func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
transforms in more than 3 dimensions, which was previously the limit. See
{jax-issue}`#25606` for more details.
* Deprecations
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name
@ -20,21 +34,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Deletions
* `jax_enable_memories` flag has been deleted and the behavior of that flag
is on by default.
* Changes:
* The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
supported version until June 2025.
* Deprecations
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name
in {mod}`jax.core`.
* New Features
* {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,
{func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
transforms in more than 3 dimensions, which was previously the limit. See
{jax-issue}`#25606` for more details.
* From `jax.lib.xla_client`, the previously-deprecated `Device` and
`XlaRuntimeError` symbols have been removed; instead use `jax.Device`
and `jax.errors.JaxRuntimeError` respectively.
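To illustrate the n-dimensional FFT entry under New Features above, here is a minimal sketch (not part of the diff) showing a transform over more than three dimensions, which previously failed:

```python
import jax.numpy as jnp

x = jnp.ones((2, 3, 4, 5, 6))   # 5-D input
y = jnp.fft.fftn(x)             # transforms over all five axes (previously capped at 3)
print(y.shape)                  # (2, 3, 4, 5, 6)
```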
## jax 0.4.38 (Dec 17, 2024)

View File

@ -207,3 +207,4 @@ This will generate three wheels in the `dist/` directory:
### Simplified Build Script
For a streamlined process, consider using the `jax/build/rocm/dev_build_rocm.py` script.

View File

@ -93,6 +93,12 @@ plugins" error described {ref}`below <multiple_installs>`. See
<https://www.tensorflow.org/guide/profiler> for more information on installing
TensorBoard.
The nightly version of the TensorBoard profiler requires the nightly builds of TensorFlow and TensorBoard:
```shell
pip install tf-nightly tb-nightly tbp-nightly
```
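As a quick way to exercise the profiler stack installed above, a trace can be captured with the public `jax.profiler.trace` API and then opened in TensorBoard; a minimal sketch (the log directory is a placeholder):

```python
import jax
import jax.numpy as jnp

# Capture a trace of a small computation into a TensorBoard log directory.
with jax.profiler.trace("/tmp/jax-trace"):
    x = jnp.ones((512, 512))
    (x @ x).block_until_ready()
# Then view it with: tensorboard --logdir /tmp/jax-trace
```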
### Programmatic capture
You can instrument your code to capture a profiler trace via the

View File

@ -490,7 +490,7 @@ This section covers some of the most common patterns with JAX pytrees.
### Transposing pytrees with `jax.tree.map` and `jax.tree.transpose`
To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: {func} `jax.tree.map` (more basic) and {func}`jax.tree.transpose` (more flexible, complex and verbose).
To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: {func}`jax.tree.map` (more basic) and {func}`jax.tree.transpose` (more flexible, complex and verbose).
**Option 1:** Use {func}`jax.tree.map`. Here's an example:

View File

@ -49,7 +49,6 @@ def masked_array_error(*args, **kwargs):
"Use arr.filled() to convert the value to a standard numpy array.")
core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
core.shaped_abstractify_handlers[np.ma.MaskedArray] = masked_array_error
def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
@ -58,7 +57,6 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
core.shaped_abstractify_handlers[np.ndarray] = _make_shaped_array_for_numpy_array
def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
@ -68,7 +66,6 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
for t in numpy_scalar_types:
core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
core.shaped_abstractify_handlers[t] = _make_shaped_array_for_numpy_scalar
core.literalable_types.update(array_types)
@ -81,6 +78,5 @@ def _make_abstract_python_scalar(typ, val):
for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
core.shaped_abstractify_handlers[t] = partial(_make_abstract_python_scalar, t)
core.literalable_types.update(dtypes.python_scalar_dtypes.keys())

View File

@ -2564,7 +2564,6 @@ def _sds_aval_mapping(x):
x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
weak_type=x.weak_type)
core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping
core.shaped_abstractify_handlers[ShapeDtypeStruct] = _sds_aval_mapping
@api_boundary

View File

@ -1035,7 +1035,6 @@ def _get_aval_array(self):
else:
return self.aval
core.shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
core.pytype_aval_mappings[ArrayImpl] = _get_aval_array
# TODO(jakevdp) replace this with true inheritance at the C++ level.

View File

@ -192,6 +192,14 @@ def get_compile_options(
build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
build_options.memory_fitting_effort = config.memory_fitting_effort.value
# This is a temporary workaround to simplify the AutoPGLE usage.
# TODO(b/376647494): Remove once the bug is fixed.
if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
logger.debug("Explicitly disabling command buffer scheduling for AutoPGLE.")
if env_options_overrides is None:
env_options_overrides = {}
env_options_overrides['xla_gpu_enable_command_buffer'] = ''
if env_options_overrides is not None:
# Some overrides are passed directly on build_options.
overrides_on_build_options = [

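For context on the workaround above: AutoPGLE is enabled through JAX config flags, and with profiling runs > 0 the code in this hunk forces `xla_gpu_enable_command_buffer` to the empty string so command buffers stay off during profiling. A hedged sketch, assuming the documented flag names `jax_enable_pgle` and `jax_pgle_profiling_runs` backing `config.enable_pgle` / `config.pgle_profiling_runs`:

```python
import jax

# Enable automatic PGLE; get_compile_options (above) then disables
# command-buffer scheduling while profiling runs are active.
jax.config.update("jax_enable_pgle", True)
jax.config.update("jax_pgle_profiling_runs", 3)
```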
View File

@ -656,6 +656,13 @@ def check_bool_conversion(arr: Array):
" is ambiguous. Use a.any() or a.all()")
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
def _str_abstractify(x):
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
pytype_aval_mappings[str] = _str_abstractify
def _aval_property(name):
return property(lambda self: getattr(self.aval, name))
@ -918,6 +925,8 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
aval_property = namedtuple("aval_property", ["fget"])
aval_method = namedtuple("aval_method", ["fun"])
pytype_aval_mappings[Tracer] = lambda x: x.aval
def check_eval_args(args):
for arg in args:
if isinstance(arg, Tracer):
@ -1400,45 +1409,51 @@ def check_valid_jaxtype(x):
f"Value {x!r} of type {type(x)} is not a valid JAX type")
def _shaped_abstractify_slow(x):
try:
return x if isinstance(x, AbstractValue) else get_aval(x)
except TypeError:
pass
# We have three flavors of abstractification APIs here which each used to have
# their own separate implementation. Now they're effectively the same, with the
# following differences:
#
# - abstractify returns avals for non-traced array-like objects.
# - get_aval is like abstractify, but also accepts tracers.
# - shaped_abstractify is like get_aval, but also accepts duck-typed arrays.
#
# TODO(jakevdp): can these be unified further?
weak_type = getattr(x, 'weak_type', False)
if hasattr(x, 'dtype'):
dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
else:
raise TypeError(
f"Cannot interpret value of type {type(x)} as an abstract array; it "
"does not have a dtype attribute")
return ShapedArray(np.shape(x), dtype, weak_type=weak_type)
# TODO(jakevdp): deduplicate this with abstractify
def shaped_abstractify(x):
# This was originally api_util.shaped_abstractify; temporarily moved
# here in order to facilitate combining it with abstractify.
handler = shaped_abstractify_handlers.get(type(x), None)
return handler(x) if handler is not None else _shaped_abstractify_slow(x)
typ = type(x)
if (aval_fn := pytype_aval_mappings.get(typ)): # fast path
return aval_fn(x)
for t in typ.__mro__[1:]:
if (aval_fn := pytype_aval_mappings.get(t)):
return aval_fn(x)
if isinstance(x, AbstractValue):
return x
if hasattr(x, '__jax_array__'):
return shaped_abstractify(x.__jax_array__())
if hasattr(x, 'dtype'):
return ShapedArray(np.shape(x), x.dtype, weak_type=getattr(x, 'weak_type', False))
raise TypeError(
f"Cannot interpret value of type {typ} as an abstract array; it "
"does not have a dtype attribute")
def abstractify(x):
for typ in type(x).__mro__:
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
if hasattr(x, '__jax_array__'):
return abstractify(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
if isinstance(x, Tracer):
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
return get_aval(x)
def get_aval(x):
if isinstance(x, Tracer):
return x.aval
else:
return abstractify(x)
typ = type(x)
if (aval_fn := pytype_aval_mappings.get(typ)): # fast path
return aval_fn(x)
for t in typ.__mro__[1:]:
if (aval_fn := pytype_aval_mappings.get(t)):
return aval_fn(x)
if hasattr(x, '__jax_array__'):
return get_aval(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")
get_type = get_aval
def is_concrete(x):
return to_concrete_value(x) is not None
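The new `get_aval` above dispatches on the exact type, then walks the MRO, then falls back to `__jax_array__`. A self-contained sketch of that lookup pattern (hypothetical handler table and names, not JAX internals):

```python
from typing import Any, Callable

# Hypothetical registry, mirroring the role of pytype_aval_mappings above.
handlers: dict[type, Callable[[Any], str]] = {int: lambda x: f"aval(int, {x})"}

def lookup(x):
    typ = type(x)
    if (fn := handlers.get(typ)):          # fast path: exact type match
        return fn(x)
    for t in typ.__mro__[1:]:              # slow path: walk base classes
        if (fn := handlers.get(t)):
            return fn(x)
    if hasattr(x, "__jax_array__"):        # duck-typed array objects
        return lookup(x.__jax_array__())
    raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")

class MyInt(int):                          # resolved via the MRO walk
    pass

print(lookup(MyInt(3)))                    # -> aval(int, 3)
```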
@ -1831,13 +1846,6 @@ class DShapedArray(UnshapedArray):
return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
shaped_abstractify_handlers: dict[Any, Callable[[Any], AbstractValue]] = {}
def _str_abstractify(x):
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
pytype_aval_mappings[str] = _str_abstractify
shaped_abstractify_handlers[str] = _str_abstractify
class DArray:
_aval: DShapedArray
@ -1894,7 +1902,6 @@ def _darray_aval(x):
return DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)
pytype_aval_mappings[DArray] = _darray_aval
shaped_abstractify_handlers[DArray] = _darray_aval
@dataclass(frozen=True)
@ -1924,11 +1931,10 @@ class MutableArray:
aval = property(lambda self: self._aval)
shape = property(lambda self: self._aval.shape)
dtype = property(lambda self: self._aval.dtype)
def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
def __getitem__(self, idx): return self._aval._getitem(self, idx)
def __setitem__(self, idx, x): return self._aval._setitem(self, idx, x)
def __repr__(self) -> str: return 'Mutable' + repr(self[...])
pytype_aval_mappings[MutableArray] = lambda x: x._aval
shaped_abstractify_handlers[MutableArray] = lambda x: x._aval
def mutable_array(init_val):
return mutable_array_p.bind(init_val)
@ -1984,7 +1990,6 @@ class Token:
def block_until_ready(self):
self._buf.block_until_ready()
pytype_aval_mappings[Token] = lambda _: abstract_token
shaped_abstractify_handlers[Token] = lambda _: abstract_token
# TODO(dougalm): Deprecate these. They're just here for backwards compat.

View File

@ -348,11 +348,20 @@ def check_is_flash_attention(
)
else:
# Regular attention conditions
if not ((H <= 128 and H % 8 == 0) and
(not is_training or not has_bias or T % 2 == 0 and S % 2 == 0)):
raise NotImplementedError(
f"Unsupported sequence length Q {T}, KV {S} and head dim {H}."
)
# Check the head dim.
is_on_hopper = check_compute_capability("9.0")
H_max = 256 if cudnn_version >= 90500 and is_on_hopper else 128
if not (H <= H_max and H % 8 == 0):
raise NotImplementedError(
f"The head dim must be <= {H_max} and a multiple of 8, "
f"but got {H}."
)
# Check patterns with bias, seqlen should be divisible by 2
if (is_training and has_bias and (T % 2 != 0 or S % 2 != 0)):
raise NotImplementedError(
f"Unsupported sequence length Q {T}, KV {S}."
)
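A small standalone restatement of the head-dim rule introduced above (a hypothetical helper that mirrors the check rather than calling the real one):

```python
def head_dim_supported(H: int, cudnn_version: int, is_on_hopper: bool) -> bool:
    # Head dim must be a multiple of 8 and at most 128, or 256 on Hopper
    # with cuDNN >= 9.5 (version encoded as 90500).
    H_max = 256 if cudnn_version >= 90500 and is_on_hopper else 128
    return H <= H_max and H % 8 == 0

print(head_dim_supported(192, 90500, True))   # True: Hopper with cuDNN 9.5
print(head_dim_supported(192, 90400, True))   # False: cap is 128
print(head_dim_supported(96, 90400, False))   # True: <= 128 and a multiple of 8
```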
def check_cudnn_version():
# check if cuDNN is installed

View File

@ -300,8 +300,9 @@ class _DebugPrintFormatChecker(string.Formatter):
formatter = _DebugPrintFormatChecker()
def _format_print_callback(fmt: str, *args, **kwargs):
sys.stdout.write(fmt.format(*args, **kwargs) + "\n")
def _format_print_callback(fmt: str, np_printoptions, *args, **kwargs):
with np.printoptions(**np_printoptions):
sys.stdout.write(fmt.format(*args, **kwargs) + "\n")
def debug_print(fmt: str, *args, ordered: bool = False, **kwargs) -> None:
"""Prints values and works in staged out JAX functions.
@ -338,8 +339,8 @@ def debug_print(fmt: str, *args, ordered: bool = False, **kwargs) -> None:
# Check that we provide the correct arguments to be formatted.
formatter.format(fmt, *args, **kwargs)
debug_callback(functools.partial(_format_print_callback, fmt), *args,
**kwargs, ordered=ordered)
debug_callback(functools.partial(_format_print_callback, fmt, np.get_printoptions()),
*args, **kwargs, ordered=ordered)
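A hedged illustration of the behavior change above: the host callback now re-applies the NumPy print options that were active when `debug_print` was called, so array formatting follows the surrounding context manager.

```python
import numpy as np
import jax
import jax.numpy as jnp

x = jnp.linspace(0.0, 1.0, 5)
with np.printoptions(precision=2, suppress=True):
    jax.debug.print("x = {}", x)   # array is formatted with precision=2
```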
# Sharding visualization

View File

@ -115,7 +115,6 @@ def _earray_shard_arg_handler(xs, shardings, layouts, copy_semantics):
return pxla.shard_args(phys_shardings, layouts, copy_semantics, arrs)
pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler
core.shaped_abstractify_handlers[EArray] = lambda self: self.aval
core.pytype_aval_mappings[EArray] = lambda x: x.aval
xla.canonicalize_dtype_handlers[EArray] = lambda x: x
tree_util.dispatch_registry.register_node(

View File

@ -1205,7 +1205,6 @@ def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool:
f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}")
core.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
core.shaped_abstractify_handlers[_DimExpr] = _DimExpr._get_aval
dtypes._weak_types.append(_DimExpr)
def _convertible_to_int(p: DimSize) -> bool:

View File

@ -1569,10 +1569,7 @@ class DynamicJaxprTracer(core.Tracer):
val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self)))
return self if val is None else get_referent(val)
def _dynamic_jaxpr_tracer_shaped_abstractify(x):
return x.aval
core.shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify
core.pytype_aval_mappings[DynamicJaxprTracer] = lambda x: x.aval
def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
sentinel = object()

View File

@ -710,7 +710,6 @@ def one_hot(x: Any, num_classes: int, *,
'jax-nn-one-hot-float-input',
f"jax.nn.one_hot input should be integer-typed; got dtype={x_arr.dtype}",
stacklevel=1)
x_arr = x_arr.astype('int32')
return _one_hot(x_arr, num_classes, dtype=dtype, axis=axis)
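For reference on the lines above: non-integer inputs to `jax.nn.one_hot` now emit the deprecation warning and are cast to int32 before encoding. A minimal sketch:

```python
import jax.numpy as jnp
from jax import nn

print(nn.one_hot(jnp.array([0, 2, 1]), 3))     # integer input: no warning
print(nn.one_hot(jnp.array([0., 2., 1.]), 3))  # float input: warns, then cast to int32
```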

View File

@ -192,7 +192,6 @@ class _ScalarMeta(type):
def _abstractify_scalar_meta(x):
raise TypeError(f"JAX scalar type {x} cannot be interpreted as a JAX array.")
core.pytype_aval_mappings[_ScalarMeta] = _abstractify_scalar_meta
core.shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta
def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
meta = _ScalarMeta(np_scalar_type.__name__, (object,),

View File

@ -924,7 +924,7 @@ def pinv(a: ArrayLike, rtol: ArrayLike | None = None,
- :func:`jax.numpy.linalg.inv`: multiplicative inverse of a square matrix.
Notes:
:func:`jax.numpy.linalg.prng` differs from :func:`numpy.linalg.prng` in the
:func:`jax.numpy.linalg.pinv` differs from :func:`numpy.linalg.pinv` in the
default value of ``rcond``: in NumPy, the default is `1e-15`. In JAX, the
default is ``10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps``.

View File

@ -2765,6 +2765,12 @@ def log2(x: ArrayLike, /) -> Array:
Array([-2., -1., 0., 1., 2., 3.], dtype=float32)
"""
x, = promote_args_inexact("log2", x)
if dtypes.issubdtype(x.dtype, np.complexfloating):
r = lax.log(x)
re = lax.real(r)
im = lax.imag(r)
ln2 = lax.log(_constant_like(re, 2))
return lax.complex(lax.div(re, ln2), lax.div(im, ln2))
return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))
@ -2789,6 +2795,12 @@ def log10(x: ArrayLike, /) -> Array:
[-2. -1. 0. 1. 2. 3.]
"""
x, = promote_args_inexact("log10", x)
if dtypes.issubdtype(x.dtype, np.complexfloating):
r = lax.log(x)
re = lax.real(r)
im = lax.imag(r)
ln10 = lax.log(_constant_like(re, 10))
return lax.complex(lax.div(re, ln10), lax.div(im, ln10))
return lax.div(lax.log(x), lax.log(_constant_like(x, 10)))
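A quick check of the complex branches added above, which compute `log2(z) = log(z)/ln 2` and `log10(z) = log(z)/ln 10` on the real and imaginary parts separately (values shown are approximate):

```python
import numpy as np
import jax.numpy as jnp

z = jnp.asarray(1j)                 # complex input now takes the new branch
print(jnp.log2(z))                  # ~0+2.2662j, i.e. (i*pi/2) / ln 2
print(jnp.log10(z))                 # ~0+0.6822j, i.e. (i*pi/2) / ln 10
print(np.log2(1j), np.log10(1j))    # NumPy reference values
```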

View File

@ -84,6 +84,8 @@ def _get_memory_space_from_aval(
return None
case tpu_core.TPUMemorySpace.VMEM:
return tpu_custom_call.MemorySpace.VMEM
case tpu_core.TPUMemorySpace.SMEM:
return tpu_custom_call.MemorySpace.SMEM
case tpu_core.TPUMemorySpace.SEMAPHORE:
return tpu_custom_call.MemorySpace.SEMAPHORE_MEM
return None

View File

@ -797,25 +797,30 @@ def lower_jaxpr_to_module(
# Each range is 2 events, each event is 4 bytes.
prof_spec = mgpu_profiler.ProfilerSpec(prof_space * 2 * 4)
prof_ctx = ProfilerContext(params["profile_dir"], prof_spec)
module, out_structs_gmem, _ = mgpu_core._lower_as_gpu_kernel(
body,
grid=parallel_grid,
cluster=(),
block=block,
in_shapes=in_structs_gmem,
out_shape=out_structs_gmem,
smem_scratch_shape=(
(*in_structs_smem, *out_structs_smem),
*extra_smem_scratch,
(
mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps),
rs.barriers,
extra_barriers,
module, out_structs_gmem, _, launch_ctx, scratch_arr = (
mgpu_core._lower_as_gpu_kernel(
body,
grid=parallel_grid,
cluster=(),
block=block,
in_shapes=in_structs_gmem,
out_shape=out_structs_gmem,
smem_scratch_shape=(
(*in_structs_smem, *out_structs_smem),
*extra_smem_scratch,
(
mgpu.Barrier(
arrival_count=1, num_barriers=max_concurrent_steps
),
rs.barriers,
extra_barriers,
),
),
),
module_name=name_and_src_info.name,
prof_spec=prof_spec,
module_name=name_and_src_info.name,
prof_spec=prof_spec,
)
)
mgpu_core._initialize_scratch(launch_ctx, scratch_arr)
return LoweringResult(
module, parallel_grid, block, out_structs_gmem, prof_ctx
@ -1782,29 +1787,21 @@ def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray:
if isinstance(x, mgpu.FragmentedArray):
assert x.mlir_dtype == mgpu_utils.dtype_to_ir_type(dtype)
return x
elif isinstance(x, (np.number, np.ndarray, int, float)):
return mgpu.FragmentedArray.splat(
_ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype)),
(),
is_signed=mgpu_utils.is_signed(dtype),
)
elif isinstance(x, ir.Value):
if isinstance(x.type, (ir.IntegerType, ir.FloatType, ir.IndexType)):
assert x.type == mgpu_utils.dtype_to_ir_type(dtype)
return mgpu.FragmentedArray.splat(x, (), is_signed=mgpu_utils.is_signed(dtype))
raise NotImplementedError(f"Unsupported type: {type(x)}")
return mgpu.FragmentedArray.splat(
_ensure_ir_value(x, dtype), (), is_signed=mgpu_utils.is_signed(dtype)
)
def _ensure_ir_value(x: object, dtype: jnp.dtype) -> ir.Value:
if isinstance(x, ir.Value):
assert x.type == mgpu_utils.dtype_to_ir_type(dtype)
return x
elif isinstance(x, (np.number, np.ndarray, int, float)):
return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype))
elif isinstance(x, mgpu.FragmentedArray):
assert x.mlir_dtype == mgpu_utils.dtype_to_ir_type(dtype)
if isinstance(x.layout, mgpu.WGSplatFragLayout):
return x.registers.item()
raise NotImplementedError(f"Unsupported type: {type(x)}")
raise NotImplementedError(f"Unsupported layout: {x.layout}")
return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype))
def _ir_constant(v: object, t: ir.Type) -> ir.Value:

View File

@ -16,6 +16,7 @@
from __future__ import annotations
from collections.abc import Sequence
import enum
import math
from typing import Any, Literal
@ -25,7 +26,6 @@ from jax._src import core as jax_core
from jax._src import state
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith as arith_dialect
from jax._src.lib.mlir.dialects import llvm as llvm_dialect
@ -33,23 +33,42 @@ from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic_gpu import core as gpu_core
from jax._src.pallas.mosaic_gpu import lowering
from jax._src.pallas.mosaic_gpu.core import state_types
from jax._src.state import discharge
from jax._src.state import indexing
from jax._src.state import primitives as state_primitives
import jax.experimental.mosaic.gpu as mgpu
from jax.experimental.mosaic import gpu as mgpu
from jax.experimental.mosaic.gpu import utils as mgpu_utils
import jax.numpy as jnp
WARPGROUP_SIZE = 128
_Ref = pallas_core.AbstractMemoryRef | state_types.TransformedRef
def _check_ref(
aval: object, name: str, memory_space: gpu_core.GPUMemorySpace
) -> None:
if not isinstance(aval, state_types.AbstractRef):
raise TypeError(f"{name} must be a reference, got {aval}")
aval_memory_space = getattr(aval, "memory_space", None) or gpu_core.GMEM
if aval_memory_space is not memory_space:
raise ValueError(
f"{name} must be a {memory_space.name.upper()} reference, got {aval}"
)
copy_smem_to_gmem_p = jax_core.Primitive("copy_smem_to_gmem")
copy_smem_to_gmem_p.multiple_results = True
@copy_smem_to_gmem_p.def_effectful_abstract_eval
def _copy_smem_to_gmem_abstract_eval(*avals, **params):
del avals, params # Unused.
def _copy_smem_to_gmem_abstract_eval(src, dst, *args, **params):
_check_ref(src, "src", gpu_core.SMEM)
_check_ref(dst, "dst", gpu_core.GMEM)
del args, params # Unused.
return (), {state.ReadEffect(0), state.WriteEffect(1)}
@ -114,9 +133,7 @@ def _extract_smem_copy_params(transforms):
def copy_smem_to_gmem(
src: pallas_core.AbstractMemoryRef,
dst: pallas_core.AbstractMemoryRef,
predicate: jax.Array | None = None,
src: _Ref, dst: _Ref, predicate: jax.Array | None = None
) -> None:
"""Asynchronously copies a SMEM reference to a GMEM reference.
@ -130,10 +147,6 @@ def copy_smem_to_gmem(
:func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`
:func:`jax.experimental.mosaic.gpu.commit_smem`
"""
if src.memory_space is not gpu_core.SMEM:
raise TypeError(f"src must be a SMEM reference, got {src.memory_space}")
if getattr(dst, "memory_space", gpu_core.GMEM) is not gpu_core.GMEM:
raise ValueError(f"dst must be a GMEM reference, got {dst.memory_space}")
src, src_transforms = state_primitives.get_ref_and_transforms(
src, None, "copy_smem_to_gmem", force_trailing_indexer=False,
)
@ -164,8 +177,11 @@ copy_gmem_to_smem_p.multiple_results = True
@copy_gmem_to_smem_p.def_effectful_abstract_eval
def _copy_gmem_to_smem_abstract_eval(*avals, **params):
del avals, params # Unused.
def _copy_gmem_to_smem_abstract_eval(src, dst, barrier, *args, **params):
del args, params # Unused.
_check_ref(src, "src", gpu_core.GMEM)
_check_ref(dst, "dst", gpu_core.SMEM)
_check_ref(barrier, "barrier", gpu_core.SMEM)
return (), {state.ReadEffect(0), state.WriteEffect(1)}
@ -217,21 +233,13 @@ def _copy_gmem_to_smem_lowering(
return ()
def copy_gmem_to_smem(
src: pallas_core.AbstractMemoryRef,
dst: pallas_core.AbstractMemoryRef,
barrier: pallas_core.AbstractMemoryRef,
) -> None:
def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None:
"""Asynchronously copies a GMEM reference to a SMEM reference.
See also:
:func:`jax.experimental.mosaic.gpu.barrier_arrive`
:func:`jax.experimental.mosaic.gpu.barrier_wait`
"""
if getattr(src, "memory_space", gpu_core.GMEM) is not gpu_core.GMEM:
raise TypeError(f"src must be a GMEM reference, got {src.memory_space}")
if dst.memory_space is not gpu_core.SMEM:
raise ValueError(f"dst must be a SMEM reference, got {dst.memory_space}")
src, src_transforms = state_primitives.get_ref_and_transforms(
src, None, "copy_gmem_to_smem", force_trailing_indexer=False,
)
@ -291,8 +299,9 @@ barrier_arrive_p.multiple_results = True
@barrier_arrive_p.def_effectful_abstract_eval
def _barrier_arrive_abstract_eval(*avals, **params):
del avals, params # Unused.
def _barrier_arrive_abstract_eval(barrier, *args, **params):
del args, params # Unused.
_check_ref(barrier, "barrier", gpu_core.SMEM)
return (), {gpu_core._memory_effect}
@ -328,8 +337,9 @@ barrier_wait_p.multiple_results = True
@barrier_wait_p.def_effectful_abstract_eval
def _barrier_wait_abstract_eval(*avals, **params):
del avals, params # Unused.
def _barrier_wait_abstract_eval(barrier, *args, **params):
_check_ref(barrier, "barrier", gpu_core.SMEM)
del args, params # Unused.
return (), {gpu_core._memory_effect}
@ -703,38 +713,40 @@ def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout):
del layout, dimension
return jax_core.ShapedArray(shape, dtype)
@lowering.register_lowering_rule(broadcasted_iota_p)
def _broadcasted_iota_lowering(ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout):
del ctx
# Unsigned integers (as opposed to signless) cause MLIR verification
# errors so we only use signless like Mosaic GPU does.
#
# TODO(cperivol): use mgpu.utils.dtype_to_ir_type() instead.
mlir_dtype = (
ir.IntegerType.get_signless(dtype.itemsize * 8)
if jnp.issubdtype(dtype, jnp.integer)
else mlir.dtype_to_ir_type(dtype)
)
undef = llvm_dialect.mlir_undef(mlir_dtype)
is_signed = (
jnp.issubdtype(dtype, jnp.signedinteger)
if jnp.issubdtype(dtype, jnp.integer)
else None
)
i32 = ir.IntegerType.get_signless(32)
def _cast(x):
if ir.FloatType.isinstance(mlir_dtype):
x = arith_dialect.index_cast(i32, x)
return arith_dialect.uitofp(mlir_dtype, x)
else:
return arith_dialect.index_cast(mlir_dtype, x)
def _broadcasted_iota_lowering(
ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout
):
del ctx # Unused.
mlir_dtype = mgpu_utils.dtype_to_ir_type(dtype)
if ir.FloatType.isinstance(mlir_dtype):
i32 = ir.IntegerType.get_signless(32)
cast = lambda x: arith_dialect.uitofp(
mlir_dtype, arith_dialect.index_cast(i32, x)
)
else:
cast = lambda x: arith_dialect.index_cast(mlir_dtype, x)
is_signed = mgpu_utils.is_signed(dtype)
return mgpu.FragmentedArray.splat(
undef, shape, layout.value, is_signed=is_signed
llvm_dialect.mlir_undef(mlir_dtype),
shape,
layout.value,
is_signed=is_signed,
).foreach(
lambda _, idx: _cast(idx[dimension]), create_array=True, is_signed=is_signed
lambda _, idx: cast(idx[dimension]),
create_array=True,
is_signed=is_signed,
)
def broadcasted_iota(dtype, shape, dimension, *, layout: Layout | None = None):
return broadcasted_iota_p.bind(dtype=jnp.dtype(dtype), shape=shape, dimension=dimension, layout=layout)
def broadcasted_iota(
dtype: jax.typing.DTypeLike,
shape: Sequence[int],
dimension: int,
*,
layout: Layout | None = None,
) -> jax.Array:
return broadcasted_iota_p.bind(
dtype=jnp.dtype(dtype), shape=shape, dimension=dimension, layout=layout
)

View File

@ -461,8 +461,6 @@ class KeyTy(dtypes.ExtendedDType):
core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
core.shaped_abstractify_handlers[PRNGKeyArray] = op.attrgetter('aval')
xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x

View File

@ -249,8 +249,8 @@ class XlaExecutable(Executable):
else:
raise
# TODO(skyewm): this should return a single dict (I think returning a list
# was to support MPMD executables, which never fully landed)
# TODO(b/384741132): this should return a single dict (I think returning a list
# was to support MPMD executables, which never fully landed).
def cost_analysis(self) -> list[dict[str, float]]:
xla_ext_exe = self.xla_extension_executable()
@ -266,9 +266,19 @@ class XlaExecutable(Executable):
# Try client method if executable cost_analysis method is unimplemented
if hasattr(xla_ext_exe, "client"):
try:
# TODO(b/384741132): We expect that the executable has only one
# HloModule. We should be able to remove this check once we update the
# Executable class to have only a single HloModule (see bug).
hlo_modules = xla_ext_exe.hlo_modules()
assert len(hlo_modules) == 1, (
f"Executable should have only one HloModule, but {len(hlo_modules)}"
" were found."
)
return [
xla_extension.hlo_module_cost_analysis(xla_ext_exe.client, m)
for m in xla_ext_exe.hlo_modules()
xla_extension.hlo_module_cost_analysis(
xla_ext_exe.client, hlo_modules[0]
)
]
except xla_extension.XlaRuntimeError as e:
msg, *_ = e.args

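The cost-analysis path touched above is reachable through the public AOT API; a minimal usage sketch (the metric names, such as `'flops'`, come from XLA's HLO cost analysis):

```python
import jax
import jax.numpy as jnp

f = jax.jit(lambda x: (x @ x).sum())
compiled = f.lower(jnp.ones((8, 8))).compile()
print(compiled.cost_analysis())   # per-HLO-module metrics, e.g. 'flops'
```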
View File

@ -83,6 +83,7 @@ class MemorySpace(enum.Enum):
HBM = enum.auto()
VMEM = enum.auto()
SEMAPHORE_MEM = enum.auto()
SMEM = enum.auto()
@property
def color(self) -> int:
@ -92,6 +93,8 @@ class MemorySpace(enum.Enum):
return 1
elif self == MemorySpace.SEMAPHORE_MEM:
return 2
elif self == MemorySpace.SMEM:
return 4
else:
raise ValueError("invalid memory space: " + str(self))

View File

@ -128,7 +128,7 @@ _deprecations = {
_src_core.escaped_tracer_error),
"extend_axis_env_nd": ("jax.core.extend_axis_env_nd is deprecated.",
_src_core.extend_axis_env_nd),
"get_type": ("jax.core.get_type is deprecated.", _src_core.get_type),
"get_type": ("jax.core.get_type is deprecated.", _src_core.get_aval),
"get_referent": ("jax.core.get_referent is deprecated.", _src_core.get_referent),
"join_effects": ("jax.core.join_effects is deprecated.", _src_core.join_effects),
"leaked_tracer_error": ("jax.core.leaked_tracer_error is deprecated.",
@ -212,7 +212,7 @@ if typing.TYPE_CHECKING:
escaped_tracer_error = _src_core.escaped_tracer_error
extend_axis_env_nd = _src_core.extend_axis_env_nd
full_lower = _src_core.full_lower
get_type = _src_core.get_type
get_type = _src_core.get_aval
get_referent = _src_core.get_referent
jaxpr_as_fun = _src_core.jaxpr_as_fun
join_effects = _src_core.join_effects

View File

@ -724,7 +724,6 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
aval_in, aval_out, x):
if prod([size for n, size in mesh.shape.items() if n not in auto]) == 1:
return x
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
axes = {name: i for i, ns in names.items() for name in ns}
ns = _make_scoped_manual_sharding(ctx, mesh, axes)
if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
@ -734,6 +733,7 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
unspecified = set(range(aval_in.ndim)) if auto else set()
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto,
unspecified_dims=unspecified)
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, unspecified)
def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
@ -746,6 +746,8 @@ def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
ns = sharding_impls.physical_sharding(aval_out, ns)
aval_out = core.physical_aval(aval_out)
unspecified = set(range(aval_out.ndim)) if auto else set()
if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
aval_in = core.physical_aval(aval_in)
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=unspecified)
shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto()

View File

@ -36,17 +36,17 @@ _deprecations = {
"jax.lib.xla_client.bfloat16 was removed in JAX v0.4.38; use ml_dtypes.bfloat16.",
None,
),
# Added Sep 26 2024
# Finalized 2024-12-23; remove after 2025-03-23
"Device": (
"jax.lib.xla_client.Device is deprecated; use jax.Device instead.",
_xc.Device,
None,
),
"XlaRuntimeError": (
(
"jax.lib.xla_client.XlaRuntimeError is deprecated; use"
" jax.errors.JaxRuntimeError."
),
_xc.XlaRuntimeError,
None,
),
# Added Oct 10 2024
"FftType": (

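The removed `jax.lib.xla_client` aliases above have public replacements; a minimal sketch of the suggested spellings:

```python
import jax

dev = jax.devices()[0]
print(isinstance(dev, jax.Device))     # use jax.Device instead of xla_client.Device

try:
    jax.device_put(1.0, dev)
except jax.errors.JaxRuntimeError:     # use instead of xla_client.XlaRuntimeError
    pass
```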
View File

@ -70,8 +70,8 @@ py_extension(
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIGPUHeaders",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@pybind11",
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
"@nanobind",
],
)
@ -85,7 +85,8 @@ py_extension(
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIGPUHeaders",
"@pybind11",
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
"@nanobind",
],
)
@ -100,8 +101,8 @@ py_extension(
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:CAPINVGPUHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@pybind11",
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
"@nanobind",
],
)
@ -116,8 +117,8 @@ py_extension(
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:CAPILLVMHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@pybind11",
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
"@nanobind",
],
)
@ -131,8 +132,8 @@ py_extension(
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPISparseTensorHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@pybind11",
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
"@nanobind",
],
)
@ -146,7 +147,8 @@ py_extension(
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPISparseTensorHeaders",
"@pybind11",
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
"@nanobind",
],
)
@ -156,9 +158,10 @@ py_extension(
copts = COPTS,
linkopts = LINKOPTS,
deps = [
"//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi",
":jaxlib_mlir_capi_shared_library",
"//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_headers",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps",
"@nanobind",
],
)
@ -378,6 +381,7 @@ cc_library(
name = "jaxlib_mlir_capi_objects",
deps = [
"//jaxlib/mosaic:tpu_dialect_capi_objects",
"//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_objects",
"@llvm-project//mlir:CAPIArithObjects",
"@llvm-project//mlir:CAPIGPUObjects",
"@llvm-project//mlir:CAPIIRObjects",

View File

@ -28,6 +28,7 @@ package(
py_library(
name = "mosaic",
deps = [
"//jaxlib/mosaic/python:gpu_dialect",
"//jaxlib/mosaic/python:tpu_dialect",
],
)
@ -42,6 +43,7 @@ cc_library(
"dialect/tpu/tpu_dialect.cc",
"dialect/tpu/tpu_ops.cc",
"dialect/tpu/util.cc",
"dialect/tpu/vreg_util.cc",
":extension_srcs",
] + glob([
"dialect/tpu/transforms/*.cc",
@ -50,6 +52,7 @@ cc_library(
"dialect/tpu/layout.h",
"dialect/tpu/tpu_dialect.h",
"dialect/tpu/util.h",
"dialect/tpu/vreg_util.h",
] + glob([
"dialect/tpu/transforms/*.h",
]),
@ -231,6 +234,19 @@ cc_library(
alwayslink = True,
)
cc_test(
name = "vreg_util_test",
srcs = ["dialect/tpu/vreg_util_test.cc"],
deps = [
":tpu_dialect",
"//testing/base/public:gunit_main",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:VectorDialect",
],
)
filegroup(
name = "extension_srcs",
srcs = [

View File

@ -215,3 +215,26 @@ cc_library(
"@llvm-project//mlir:CAPIIR",
],
)
# Header-only target, used when using the C API from a separate shared library.
cc_library(
name = "gpu_dialect_capi_headers",
hdrs = DIALECT_CAPI_HEADERS,
deps = [
":mosaic_gpu_inc_gen",
"@llvm-project//mlir:CAPIIRHeaders",
],
)
# Alwayslink target, used when exporting the C API from a shared library.
cc_library(
name = "gpu_dialect_capi_objects",
srcs = DIALECT_CAPI_SOURCES,
hdrs = DIALECT_CAPI_HEADERS,
deps = [
":mosaic_gpu",
":mosaic_gpu_inc_gen",
"@llvm-project//mlir:CAPIIRObjects",
],
alwayslink = True,
)

View File

@ -19,8 +19,8 @@ limitations under the License.
#include <optional>
#include <string>
#include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"

View File

@ -29,7 +29,6 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
@ -52,6 +51,7 @@
#include "absl/status/status.h"
#include "absl/types/span.h"
#include "llvm/include/llvm/ADT/APInt.h"
#include "llvm/include/llvm/Support/LogicalResult.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
@ -64,6 +64,7 @@
#include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h"
#include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h"
#include "jaxlib/mosaic/dialect/tpu/util.h"
#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
#include "xla/array.h"
#include "xla/layout.h"
#include "xla/util.h"
@ -275,16 +276,6 @@ void updateSliceFromRange(xla::Array<T> &arr, Range data,
CHECK(data_it == data.end());
}
FailureOr<TypedAttr> getZeroIntOrFloatAttr(Type ty) {
if (isa<FloatType>(ty)) {
return TypedAttr(FloatAttr::get(ty, 0));
}
if (isa<IntegerType>(ty)) {
return TypedAttr(IntegerAttr::get(ty, 0));
}
return emitError(UnknownLoc::get(ty.getContext()), "Not implemented: ") << ty;
}
FailureOr<int64_t> getIntConst(Value v, bool silent = false) {
if (auto constant_op = v.getDefiningOp<arith::ConstantOp>()) {
if (auto integer_attr = dyn_cast<IntegerAttr>(constant_op.getValue())) {
@ -479,33 +470,6 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx, func::FuncOp func,
return argument;
}
VectorType getNativeVregOrVmaskTypeImpl(
Type elem_ty, const int8_t bitwidth,
const std::array<int64_t, 2> target_shape) {
if (bitwidth == 32) {
return VectorType::get(target_shape, elem_ty);
}
return VectorType::get({target_shape[0], target_shape[1], 32 / bitwidth},
elem_ty);
}
VectorType getNativeVregOrVmaskType(Type elem_ty, const int8_t layout_bitwidth,
const std::array<int64_t, 2> target_shape) {
int8_t bitwidth = elem_ty.getIntOrFloatBitWidth();
if (bitwidth == 1) {
bitwidth = layout_bitwidth;
} else {
CHECK_EQ(bitwidth, layout_bitwidth);
}
return getNativeVregOrVmaskTypeImpl(elem_ty, bitwidth, target_shape);
}
VectorType getNativeVregType(Type elem_ty,
const std::array<int64_t, 2> target_shape) {
return getNativeVregOrVmaskTypeImpl(elem_ty, elem_ty.getIntOrFloatBitWidth(),
target_shape);
}
// Masks all values outside of bounds.
//
// Arguments:
@ -518,7 +482,7 @@ VectorType getNativeVregType(Type elem_ty,
// Returns:
// An MLIR value of the same type as the value argument, with all entries
// outside of bounds replaced by neutral.
FailureOr<Value> maskOOB(RewriteContext &ctx, OpBuilder &builder,
FailureOr<Value> maskOOB(RewriteContext &ctx, ImplicitLocOpBuilder &builder,
TypedValue<VectorType> value,
const VRegDataBounds &bounds,
const Attribute neutral) {
@ -542,9 +506,7 @@ FailureOr<Value> maskOOB(RewriteContext &ctx, OpBuilder &builder,
value.getLoc(),
VectorType::get(native_vreg_ty.getShape(), builder.getI1Type()), mask);
}
auto neutral_vec = builder.create<arith::ConstantOp>(
value.getLoc(), native_vreg_ty,
DenseElementsAttr::get(native_vreg_ty, neutral));
Value neutral_vec = getFullVector(builder, native_vreg_ty, neutral);
return builder
.create<arith::SelectOp>(value.getLoc(), mask, value, neutral_vec)
.getResult();
@ -1863,126 +1825,28 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
TPU_ASSERT_EQ_OP(padded_lhs_rows, lhs_vregs.dim(0) * layout_lhs.tiling()[0]);
TPU_ASSERT_EQ_OP(padded_rhs_rows, rhs_vregs.dim(0) * layout_rhs.tiling()[0]);
const VectorType i32_vreg_ty =
getNativeVregType(builder.getI32Type(), ctx.target_shape);
auto getX32VmaskByPaddingEnd = [&](int64_t dim, int64_t padding) {
CHECK(dim == 0 || dim == 1);
CHECK(padding >= 0 && padding <= ctx.target_shape[dim]);
return cast<TypedValue<VectorType>>(
builder
.create<arith::CmpIOp>(
arith::CmpIPredicate::slt,
builder.create<tpu::IotaOp>(i32_vreg_ty,
builder.getI32IntegerAttr(dim)),
builder.create<arith::ConstantOp>(DenseElementsAttr::get(
i32_vreg_ty, builder.getI32IntegerAttr(
ctx.target_shape[dim] - padding))))
.getResult());
};
// We can also extend this helper function with padding_top and padding_left
// based on the offsets in vregs.
const Value i32_zeros_vreg = builder.create<arith::ConstantOp>(
op.getLoc(),
DenseElementsAttr::get(i32_vreg_ty, builder.getI32IntegerAttr(0)));
const Value i32_max_vreg = builder.create<arith::ConstantOp>(
op.getLoc(), DenseElementsAttr::get(
i32_vreg_ty, builder.getI32IntegerAttr(0xffffffff)));
auto maskVregs = [&](xla::Array<Value> &vregs, int64_t padding_bottom,
int64_t padding_right) {
auto vreg_ty = cast<VectorType>(vregs.begin()->getType());
int packing = vreg_ty.getRank() > 2 ? vreg_ty.getShape()[2] : 1;
// Mask out the bottom.
if (padding_bottom > 0) {
// We have limited the row size of LHS and RHS need to be a multiple of
// native tiling at the beginning of this rule. Therefore, it is safe to
// bitcast to x32 vreg for masking.
int sub_padding = padding_bottom % packing;
int x32_padding_bottom = padding_bottom / packing;
auto mask_bottom = getX32VmaskByPaddingEnd(0, x32_padding_bottom);
// Create an int32 vreg which contains subelement masking and then
// logical_and with target vreg to mask out the unaligned paddings.
// Eg. if padding_bottom = 5, packing = 2, and assume the vreg shape is
// [8, 128], then the mask will be:
//
// sublane 0: [0xffffffff, 0xffffffff, ..., 0xffffffff]
// sublane 1: [0xffffffff, 0xffffffff, ..., 0xffffffff]
// sublane 2: [0xffffffff, 0xffffffff, ..., 0xffffffff]
// sublane 3: [0xffffffff, 0xffffffff, ..., 0xffffffff]
// sublane 4: [0xffffffff, 0xffffffff, ..., 0xffffffff]
// sublane 5: [0x0000ffff, 0x0000ffff, ..., 0x0000ffff]
// sublane 6: [0 , 0 , ..., 0 ]
// sublane 7: [0 , 0 , ..., 0 ]
//
// Through this way, in order to mask sub-elements, each target vreg only
// needs to apply 1 op (logical_and) instead of 3 ops (unpacking + select
// + packing).
Value partial_sublane_mask = builder.create<arith::ConstantOp>(
op.getLoc(),
DenseElementsAttr::get(
i32_vreg_ty,
builder.getI32IntegerAttr(
0xffffffff >>
(sub_padding * vreg_ty.getElementTypeBitWidth()))));
// Insert 0xffffffff above the blended sublane.
Value sublane_mask = builder.create<arith::SelectOp>(
getX32VmaskByPaddingEnd(0, x32_padding_bottom + 1), i32_max_vreg,
partial_sublane_mask);
// Insert 0 below the blended sublane.
sublane_mask = builder.create<arith::SelectOp>(mask_bottom, sublane_mask,
i32_zeros_vreg);
for (int64_t i = 0; i < vregs.dim(1); ++i) {
Value &vreg = vregs({vregs.dim(0) - 1, i});
Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
if (sub_padding > 0) {
i32_vreg = builder.create<arith::AndIOp>(i32_vreg, sublane_mask);
} else {
i32_vreg = builder.create<arith::SelectOp>(mask_bottom, i32_vreg,
i32_zeros_vreg);
}
vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
}
}
// Mask out the right.
if (padding_right > 0) {
auto mask_right = getX32VmaskByPaddingEnd(1, padding_right);
for (int64_t i = 0; i < vregs.dim(0); ++i) {
Value &vreg = vregs({i, vregs.dim(1) - 1});
Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
i32_vreg = builder.create<arith::SelectOp>(mask_right, i32_vreg,
i32_zeros_vreg);
vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
}
}
};
// Create a vreg filled with zeros.
auto getZerosVergLike =
[&](const Value &vreg) -> FailureOr<TypedValue<VectorType>> {
const VectorType vreg_type = cast<VectorType>(vreg.getType());
FAILUREOR_ASSIGN_OR_RETURN(
const Attribute zero_attr,
getZeroIntOrFloatAttr(vreg_type.getElementType()));
return cast<TypedValue<VectorType>>(
builder
.create<arith::ConstantOp>(
op.getLoc(), DenseElementsAttr::get(vreg_type, zero_attr))
.getResult());
};
FAILUREOR_ASSIGN_OR_RETURN(auto lhs_zeros_vreg,
getZerosVergLike(*lhs_vregs.begin()));
FAILUREOR_ASSIGN_OR_RETURN(auto rhs_zeros_vreg,
getZerosVergLike(*rhs_vregs.begin()));
FAILUREOR_ASSIGN_OR_RETURN(auto acc_zeros_vreg,
getZerosVergLike(*acc_vregs.begin()));
auto lhs_zeros_vreg =
getZerosVector(builder, cast<VectorType>(lhs_vregs.begin()->getType()));
auto rhs_zeros_vreg =
getZerosVector(builder, cast<VectorType>(rhs_vregs.begin()->getType()));
auto acc_zeros_vreg =
getZerosVector(builder, cast<VectorType>(acc_vregs.begin()->getType()));
// Only mask out the paddings on contracting dim of LHS and RHS.
maskVregs(lhs_vregs, 0, padded_lhs_cols - lhs_shape[1]);
RETURN_IF_FAILED(
maskNativeTilingVregs(builder, lhs_vregs, ctx.target_shape,
/*padding_bottom=*/0,
/*padding_right=*/padded_lhs_cols - lhs_shape[1]));
if (transpose_rhs) {
maskVregs(rhs_vregs, 0, padded_rhs_cols - rhs_shape[1]);
RETURN_IF_FAILED(maskNativeTilingVregs(
builder, rhs_vregs, ctx.target_shape,
/*padding_bottom=*/0,
/*padding_right=*/padded_rhs_cols - rhs_shape[1]));
} else {
maskVregs(rhs_vregs, padded_rhs_rows - rhs_shape[0], 0);
RETURN_IF_FAILED(
maskNativeTilingVregs(builder, rhs_vregs, ctx.target_shape,
/*padding_bottom=*/padded_rhs_rows - rhs_shape[0],
/*padding_right=*/0));
}
// At this point, all paddings on vregs are masked out. For now, we
@ -2875,12 +2739,10 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
native_vreg_ty,
/*dimension =*/builder.getI32IntegerAttr(1));
for (int64_t i = 0; i < num_tiles; ++i) {
auto offset = builder.create<arith::ConstantOp>(
native_vreg_ty,
DenseElementsAttr::get(
native_vreg_ty,
IntegerAttr::get(vty.getElementType(),
i * *(native_vreg_ty.getShape().end() - 1))));
Value offset = getFullVector(
builder, native_vreg_ty,
IntegerAttr::get(vty.getElementType(),
i * *(native_vreg_ty.getShape().end() - 1)));
tiles[i] = builder.create<arith::AddIOp>(vreg_iota, offset);
}
xla::Array<Value> broadcasted_tiles(tile_array_shape);
@ -2902,12 +2764,10 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
native_vreg_ty,
/*dimension =*/builder.getI32IntegerAttr(0));
for (int64_t i = 0; i < num_tiles; ++i) {
auto offset = builder.create<arith::ConstantOp>(
native_vreg_ty,
DenseElementsAttr::get(
native_vreg_ty,
IntegerAttr::get(vty.getElementType(),
i * *(native_vreg_ty.getShape().end() - 2))));
Value offset = getFullVector(
builder, native_vreg_ty,
IntegerAttr::get(vty.getElementType(),
i * *(native_vreg_ty.getShape().end() - 2)));
tiles[i] = builder.create<arith::AddIOp>(vreg_iota, offset);
}
xla::Array<Value> broadcasted_tiles(tile_array_shape);
@ -2924,10 +2784,8 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
SmallVector<Value> tiles;
tiles.reserve(vty.getDimSize(*dimension));
for (int64_t i = 0; i < vty.getDimSize(*dimension); ++i) {
tiles.push_back(builder.create<arith::ConstantOp>(
native_vreg_ty,
DenseElementsAttr::get(native_vreg_ty,
IntegerAttr::get(vty.getElementType(), i))));
tiles.push_back(getFullVector(builder, native_vreg_ty,
IntegerAttr::get(vty.getElementType(), i)));
}
xla::Array<Value> out_tiles(tile_array_shape);
out_tiles.Each([&](absl::Span<const int64_t> idxs, Value *v) {
@ -3516,12 +3374,9 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
const int64_t offset = *offsets_in[1];
const int64_t lane_offset = offset % ctx.target_shape[1];
const int64_t tile_offset = offset / ctx.target_shape[1];
const auto idx_ty =
VectorType::get(ctx.target_shape, builder.getI32Type());
auto lane_offset_cst = builder.create<arith::ConstantOp>(
broadcast_op.getLoc(), idx_ty,
DenseElementsAttr::get(idx_ty,
builder.getI32IntegerAttr(lane_offset)));
Value lane_offset_cst = getFullVector(
builder, getNativeVregType(builder.getI32Type(), ctx.target_shape),
builder.getI32IntegerAttr(lane_offset));
DenseI32ArrayAttr sublane_pattern;
if (num_tiles != 1) {
SmallVector<int32_t> pattern;
@ -3581,10 +3436,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
getNativeVregType(src_i32.getType(), ctx.target_shape);
auto tile_i32 =
builder.create<vector::BroadcastOp>(native_vreg_ty, src_i32);
auto zeros = builder.create<arith::ConstantOp>(
broadcast_op.getLoc(), tile_i32.getType(),
DenseElementsAttr::get(tile_i32.getType(),
builder.getI32IntegerAttr(0)));
Value zeros = getZerosVector(builder, tile_i32.getType());
auto tile =
builder.create<arith::CmpIOp>(arith::CmpIPredicate::ne, tile_i32, zeros)
.getResult();
@ -5479,8 +5331,6 @@ FailureOr<xla::Array<Value>> doColumnShiftRelayout(
const std::array<int64_t, 2> vreg_slice = src.vregSlice(target_shape);
const int bitwidth = src.bitwidth();
const int packing = src.packing();
const VectorLayout dst(bitwidth, {src.offsets()[0], dst_col_offset}, tiling,
src.implicit_dim());
const int64_t col_diff = dst_col_offset - *src.offsets()[1];
if (tiling[0] % packing != 0 || tiling[1] != target_shape[1]) {
return emitError(loc,
@ -5823,7 +5673,8 @@ LogicalResult retileToLargeTileWithScratch(
RewriteContext &ctx, OpBuilder &builder, const Location loc,
xla::Array<Value> &dst_tiles, const std::array<int64_t, 2> &dst_tile,
const xla::Array<Value> &src_tiles, const std::array<int64_t, 2> &src_tile,
TypedValue<MemRefType> scratch_ref) {
TypedValue<MemRefType> scratch_ref, const int64_t store_vreg_delay,
const int64_t load_vreg_skips) {
if (dst_tile[0] % src_tile[0] != 0) {
return failure();
}
@ -5927,8 +5778,8 @@ LogicalResult retileToLargeTileWithScratch(
SmallVector<int64_t, 4> src_idx(rank);
dst_tiles.Each([&](absl::Span<const int64_t> dst_idx, Value *dst_vreg) {
int64_t dst_row_idx = *(dst_idx.end() - 2);
int64_t dst_col_idx = *(dst_idx.end() - 1);
int64_t vreg_idx_in_group = dst_col_idx % vregs_per_group;
int64_t dst_col_idx_with_skips = *(dst_idx.end() - 1) + load_vreg_skips;
int64_t vreg_idx_in_group = dst_col_idx_with_skips % vregs_per_group;
int64_t load_offset = sublanes_per_group * stored_group_cnt +
vreg_idx_in_group * sl_per_vreg * stride;
delayed_loads.push_back(
@ -5938,16 +5789,20 @@ LogicalResult retileToLargeTileWithScratch(
// the vregs from current group and now we need to store corresponding
// group of src vregs before actually emitting the loads.
if (vreg_idx_in_group == vregs_per_group - 1 ||
dst_col_idx == dst_tiles.dimensions().back() - 1) {
auto src_row_idx = dst_row_idx * vregs_per_group;
auto src_col_idx = dst_col_idx / vregs_per_group;
dst_idx.back() == dst_tiles.dimensions().back() - 1) {
auto base_src_row_idx = dst_row_idx * vregs_per_group - store_vreg_delay;
auto src_col_idx = dst_col_idx_with_skips / vregs_per_group;
std::copy(dst_idx.begin(), dst_idx.end(), src_idx.begin());
for (int vi = 0; vi < vregs_per_group; ++vi) {
if (src_row_idx + vi >= src_tiles.dim(rank - 2) ||
const int64_t src_row_idx = base_src_row_idx + vi;
if (src_row_idx < 0) {
continue;
}
if (src_row_idx >= src_tiles.dim(rank - 2) ||
src_col_idx >= src_tiles.dim(rank - 1)) {
break;
}
*(src_idx.end() - 2) = src_row_idx + vi;
*(src_idx.end() - 2) = src_row_idx;
*(src_idx.end() - 1) = src_col_idx;
Value src_vreg = src_tiles(src_idx);
src_vreg =
@ -5976,7 +5831,8 @@ LogicalResult retileToSmallTileWithScratch(
RewriteContext &ctx, OpBuilder &builder, const Location loc,
xla::Array<Value> &dst_tiles, const std::array<int64_t, 2> &dst_tile,
const xla::Array<Value> &src_tiles, const std::array<int64_t, 2> &src_tile,
TypedValue<MemRefType> scratch_ref) {
TypedValue<MemRefType> scratch_ref, const int64_t store_vreg_delay,
const int64_t load_vreg_skips) {
if (src_tile[0] % dst_tile[0] != 0) {
return failure();
}
@ -6103,8 +5959,8 @@ LogicalResult retileToSmallTileWithScratch(
SmallVector<int64_t, 4> dst_idx(rank);
src_tiles.Each([&](absl::Span<const int64_t> src_idx, Value src_vreg) {
int64_t src_row_idx = *(src_idx.end() - 2);
int64_t src_col_idx = *(src_idx.end() - 1);
int64_t vreg_idx_in_group = src_col_idx % vregs_per_group;
int64_t src_col_idx_with_delays = *(src_idx.end() - 1) + store_vreg_delay;
int64_t vreg_idx_in_group = src_col_idx_with_delays % vregs_per_group;
src_vreg = builder.create<tpu::BitcastVregOp>(loc, temp_vreg_ty, src_vreg);
if (use_shuffled_load) {
Value store_offset = mlirIndexConst(
@ -6126,16 +5982,20 @@ LogicalResult retileToSmallTileWithScratch(
// vregs' row, this indicates we have stored all the vregs needed to
// assemble a new group of dst vreg.
if (vreg_idx_in_group == vregs_per_group - 1 ||
src_col_idx == src_tiles.dimensions().back() - 1) {
auto dst_row_idx = src_row_idx * vregs_per_group;
auto dst_col_idx = src_col_idx / vregs_per_group;
src_idx.back() == src_tiles.dimensions().back() - 1) {
auto base_dst_row_idx = src_row_idx * vregs_per_group - load_vreg_skips;
auto dst_col_idx = src_col_idx_with_delays / vregs_per_group;
std::copy(src_idx.begin(), src_idx.end(), dst_idx.begin());
for (int vi = 0; vi < vregs_per_group; ++vi) {
if (dst_row_idx + vi >= dst_tiles.dim(rank - 2) ||
const int64_t dst_row_idx = base_dst_row_idx + vi;
if (dst_row_idx < 0) {
continue;
}
if (dst_row_idx >= dst_tiles.dim(rank - 2) ||
dst_col_idx >= dst_tiles.dim(rank - 1)) {
break;
}
*(dst_idx.end() - 2) = dst_row_idx + vi;
*(dst_idx.end() - 2) = dst_row_idx;
*(dst_idx.end() - 1) = dst_col_idx;
Value *dst_vreg = &dst_tiles(dst_idx);
int64_t load_offset =
@ -6160,18 +6020,70 @@ LogicalResult retileToSmallTileWithScratch(
// go/mosaic-retiling-in-scratch is the full internal documentation that
// includes more details about the TPU generations.
LogicalResult retileWithScratch(RewriteContext &ctx, OpBuilder &builder,
const Location loc,
xla::Array<Value> &dst_tiles,
const std::array<int64_t, 2> &dst_tiling,
const xla::Array<Value> &src_tiles,
const std::array<int64_t, 2> &src_tiling,
int packing) {
// Arguments:
// - shape: The non-implicit shape of the operand
// - dst_tiling: The desired result tiling
// - dst_offsets_hint: Hints for the result offsets. They may be used or
// ignored. See comments in the body of the function for
// more details.
// - src_vregs: The source vregs to retile.
// - src: The source layout
// Returns a pair holding the result layout (potentially using the hints) and
// the retiled vregs.
// TODO(tlongeri): Clean up the function parameters/signatures. We are passing
// in more information than strictly needed.
FailureOr<std::pair<VectorLayout, xla::Array<Value>>> retileWithScratch(
RewriteContext &ctx, OpBuilder &builder, const Location loc,
const ArrayRef<int64_t> shape, const std::array<int64_t, 2> dst_tiling,
const LayoutOffsets dst_offsets_hint, const xla::Array<Value> &src_vregs,
const VectorLayout &src) {
const int bitwidth = src.bitwidth();
const int packing = src.packing();
const std::array<int64_t, 2> src_tiling = src.tiling();
if (!(src_tiling[1] == ctx.target_shape[1] &&
dst_tiling[1] == ctx.target_shape[1] && src_tiling[0] % packing == 0 &&
dst_tiling[0] % packing == 0)) {
return failure();
}
const std::array<int64_t, 2> src_vreg_slice =
VectorLayout::vregSlice(ctx.target_shape, bitwidth, src_tiling);
const std::array<int64_t, 2> dst_vreg_slice =
VectorLayout::vregSlice(ctx.target_shape, bitwidth, dst_tiling);
// TODO(b/368088671): When sublane tiling changes, we should be able to
// preserve some replications from the source layout. But we need to
// make sure they are implemented efficiently and well-tested. For now, we
// just simply use 0 for the replicated offset after retiling.
const LayoutOffsets src_offsets = {src.offsets()[0].value_or(0),
src.offsets()[1].value_or(0)};
// The provided offset hints are used only if they align with the source
// offsets, else we default to the smallest possible aligned offsets.
LayoutOffsets dst_offsets = {*src_offsets[0] % dst_vreg_slice[0],
*src_offsets[1] % dst_vreg_slice[1]};
// On a given dimension, either the source vreg slice size divides the dest
// vreg slice size, or vice versa (depending on the dimension and whether it's
// small-to-large or large-to-small retiling). Offset changes are supported
// as long as they are aligned modulo the smaller of the two sizes.
const std::array<int64_t, 2> alignment = {
std::min(src_vreg_slice[0], dst_vreg_slice[0]),
std::min(src_vreg_slice[1], dst_vreg_slice[1])};
if (dst_offsets_hint[0].has_value() &&
(*dst_offsets_hint[0] - *src_offsets[0]) % alignment[0] == 0) {
CHECK_LT(*dst_offsets_hint[0], dst_vreg_slice[0]);
dst_offsets[0] = *dst_offsets_hint[0];
}
if (dst_offsets_hint[1].has_value() &&
(*dst_offsets_hint[1] - *src_offsets[1]) % alignment[1] == 0) {
CHECK_LT(*dst_offsets_hint[1], dst_vreg_slice[1]);
dst_offsets[1] = *dst_offsets_hint[1];
}
// The offsets of the source in units of the destination vreg slice:
const std::array<int64_t, 2> src_offsets_in_dst_vreg_slices = {
*src_offsets[0] / dst_vreg_slice[0], *src_offsets[1] / dst_vreg_slice[1]};
// The offsets of the destination in units of the source vreg slice:
const std::array<int64_t, 2> dst_offsets_in_src_vreg_slices = {
*dst_offsets[0] / src_vreg_slice[0], *dst_offsets[1] / src_vreg_slice[1]};
// Try to get i32 vector scratch space, since we will bitcast vregs to
// i32 vregs before using scratch for retiling. This way we can handle
// packed types as well.
@ -6186,24 +6098,57 @@ LogicalResult retileWithScratch(RewriteContext &ctx, OpBuilder &builder,
dst_tiling[1]};
std::array<int64_t, 2> vi32_src_tiling = {src_tiling[0] / packing,
src_tiling[1]};
const VectorLayout dst(bitwidth, dst_offsets, dst_tiling, src.implicit_dim());
TPU_ASSERT_LOC(loc, dst.isValid(ctx.target_shape));
xla::Array<Value> dst_vregs(
dst.tileArrayImplicitShape(shape, ctx.target_shape));
// When differences in offsets exist, the source vregs may be stored at an
// offset position in their group. For example, the 1st vreg in a row/column
// may be stored as if it were the 3rd, so that the parts corresponding to the
// 1st and 2nd in the destination are filled with padding. Likewise, loads to
// destination vregs may be skipped when they would load only padding.
// store_vreg_delay is the position offset for stores, and load_vreg_skips is
// the position offset for loads.
//
// For example, suppose we are going from 32-bit {0, 128}(2, 128) to
// {4, 0}(8, 128). We form groups of 4 vregs that represent an (8, 512) slice
// of the padded implicit shape. For the given offsets, for the first group,
// the data is in (4:8, 128:512). But the first and second sources (stored
// vregs) of the group form the slices of data (0:2, 0:512) and (2:4, 0:512),
// which should be all padding. Likewise, the first dest vreg slice (which we
// load from) holds the data from slice (0:8, 0:128), which is all padding.
// We never load or store to slices that should contain only padding.
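// For that example (a small-to-large retile), store_vreg_delay =
// dst_offsets_in_src_vreg_slices[0] = 4 / 2 = 2 and load_vreg_skips =
// src_offsets_in_dst_vreg_slices[1] = 128 / 128 = 1.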
if (src_tiling[0] > dst_tiling[0]) {
return retileToSmallTileWithScratch(ctx, builder, loc, dst_tiles,
vi32_dst_tiling, src_tiles,
vi32_src_tiling, ref);
DCHECK_EQ(src_offsets_in_dst_vreg_slices[1], 0);
DCHECK_EQ(dst_offsets_in_src_vreg_slices[0], 0);
const int64_t store_vreg_delay = dst_offsets_in_src_vreg_slices[1];
const int64_t load_vreg_skips = src_offsets_in_dst_vreg_slices[0];
if (failed(retileToSmallTileWithScratch(
ctx, builder, loc, dst_vregs, vi32_dst_tiling, src_vregs,
vi32_src_tiling, ref, store_vreg_delay, load_vreg_skips))) {
return failure();
}
}
if (src_tiling[0] < dst_tiling[0]) {
return retileToLargeTileWithScratch(ctx, builder, loc, dst_tiles,
vi32_dst_tiling, src_tiles,
vi32_src_tiling, ref);
DCHECK_EQ(src_offsets_in_dst_vreg_slices[0], 0);
DCHECK_EQ(dst_offsets_in_src_vreg_slices[1], 0);
const int64_t store_vreg_delay = dst_offsets_in_src_vreg_slices[0];
const int64_t load_vreg_skips = src_offsets_in_dst_vreg_slices[1];
if (failed(retileToLargeTileWithScratch(
ctx, builder, loc, dst_vregs, vi32_dst_tiling, src_vregs,
vi32_src_tiling, ref, store_vreg_delay, load_vreg_skips))) {
return failure();
}
}
dst_tiles = std::move(src_tiles);
return success();
return std::make_pair(dst, dst_vregs);
}
FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
RewriteContext &ctx, OpBuilder &builder, const Location loc, VectorType vty,
const VectorLayout src, xla::Array<Value> vregs,
const std::array<int64_t, 2> dst_tiling, bool try_replicate_rows) {
const std::array<int64_t, 2> dst_tiling,
const LayoutOffsets dst_offsets_hint) {
bool has_enough_scratch = ctx.max_sublanes_in_scratch >=
ctx.target_shape[0] * (ctx.target_shape[0] + 1);
const auto &target_shape = ctx.target_shape;
@ -6219,6 +6164,12 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
const int8_t bitwidth = src.bitwidth();
const std::array<int64_t, 2> dst_vreg_slice =
VectorLayout::vregSlice(ctx.target_shape, bitwidth, dst_tiling);
// TODO(tlongeri): Using canonical vs non-canonical offsets can change the
// value of try_replicate_rows, and it breaks some tests. It doesn't make
// sense that we have different behavior for equivalent layouts, though. We
// need better logic for picking the relayout strategy.
const bool try_replicate_rows =
src.offsets()[0].has_value() && !dst_offsets_hint[0].has_value();
// Fully replicated offsets are handled efficiently elsewhere (in relayout)
CHECK(src.offsets()[0].has_value() || src.offsets()[1].has_value());
@ -6290,15 +6241,10 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
});
return std::pair(dst, std::move(retiled));
}
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
src.implicit_dim());
if (!dst.isValid(target_shape)) {
return emitError(loc, "Not implemented: invalid offsets in tiling target");
}
auto dst_tiles_shape =
dst.tileArrayImplicitShape(vty.getShape(), target_shape);
// (8,128) -> (8 * packing,128) tiling change for packed type.
if (bitwidth < 32 && 32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
if (src_offsets[0].value_or(0) < dst_vreg_slice[0] &&
src_offsets[1].value_or(0) < dst_vreg_slice[1] && bitwidth < 32 &&
32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
dst_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * packing,
ctx.target_shape[1]}) {
// TODO(tlongeri): This relayout is just ext + trunc. Refactor.
@ -6308,8 +6254,10 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
// not, since it relies on the src vreg array shape to know how many tiles
// to pack in dst, and vreg array shapes with materialized offsets are
// unfortunately not equal to vreg array shapes with replicated offsets.
CHECK(dst.offsets() == src_offsets);
xla::Array<Value> retiled(dst_tiles_shape);
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
src.implicit_dim());
xla::Array<Value> retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
VectorType vreg_x32 =
vty.getElementType().isSignlessInteger()
? VectorType::get(target_shape, builder.getI32Type())
@ -6357,7 +6305,9 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
// interesting if the next step is a retile, since we can also
// match corresponding elements without shifting. It's just that
// the tiles are not adjacent (no contiguous vreg slice).
if (bitwidth < 32 && 32 % bitwidth == 0 &&
if (src_offsets[0].value_or(0) < dst_vreg_slice[0] &&
src_offsets[1].value_or(0) < dst_vreg_slice[1] && bitwidth < 32 &&
32 % bitwidth == 0 &&
src_tiling == std::array<int64_t, 2>{1, ctx.target_shape[1] * packing} &&
dst_tiling == std::array<int64_t, 2>{packing, ctx.target_shape[1]}) {
// TODO(tlongeri): This relayout is just ext + trunc. Refactor.
@ -6406,8 +6356,10 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
// not, since it relies on the src vreg array shape to know how many tiles
// to pack in dst, and vreg array shapes with materialized offsets are
// unfortunately not equal to vreg array shapes with replicated offsets.
CHECK(dst.offsets() == src.offsets());
xla::Array<Value> retiled(dst_tiles_shape);
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
src.implicit_dim());
xla::Array<Value> retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
const VectorType vreg_x32 =
vty.getElementType().isSignlessInteger()
? VectorType::get(target_shape, builder.getI32Type())
@ -6444,24 +6396,25 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
return std::pair(dst, std::move(retiled));
}
if (src_tiling[1] == target_shape[1] && dst_tiling[1] == target_shape[1]) {
// TODO(b/368088671): When sublane tiling changes, we should be able to
// preserve some replications from the source layout. But we need to
// make sure they are implemented efficiently and well-tested. For now, we
// just simply use 0 for the replicated offset after retiling.
dst = VectorLayout(
bitwidth, {src.offsets()[0].value_or(0), src.offsets()[1].value_or(0)},
dst_tiling, dst.implicit_dim());
// All clauses in the and expression are based on performance benchmarking.
bool use_alu = !has_enough_scratch ||
(ctx.hardware_generation >= 5 && src_tiling[0] != packing &&
dst_tiling[0] != packing);
if (use_alu) {
if (src_tiling[0] > dst_tiling[0]) {
return std::pair(
dst, retileToReducedSublanes(builder, vty.getShape(), src, vregs,
dst, target_shape));
if (src_tiling[0] > dst_tiling[0] &&
// retileToReducedSublanes does not support offset changes
src.offsets()[0].value_or(0) < dst_vreg_slice[0] &&
src.offsets()[1].value_or(0) < dst_vreg_slice[1]) {
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
src.implicit_dim());
return std::pair(dst, retileToReducedSublanes(
builder, vty.getShape(), src, vregs,
VectorLayout(bitwidth,
{src.offsets()[0].value_or(0),
src.offsets()[1].value_or(0)},
dst_tiling, dst.implicit_dim()),
target_shape));
} else if (!has_enough_scratch) {
// TODO(b/357538782): Implement retileToIncreasedSublanes with ALU ops.
return emitError(
@ -6469,15 +6422,12 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
"Not implemented: retiling to increase sublane tiling with ALU");
}
}
xla::Array<Value> retiled(dst_tiles_shape);
if (failed(retileWithScratch(ctx, builder, loc, retiled, dst_tiling, vregs,
src_tiling, packing))) {
return failure();
}
return std::pair(dst, std::move(retiled));
return retileWithScratch(ctx, builder, loc, vty.getShape(), dst_tiling,
dst_offsets_hint, vregs, src);
}
return emitError(loc, "Not implemented: Unsupported tiling change for ")
<< vty << ": from " << src << " to " << dst;
<< vty << ": from " << src << " to (" << dst_tiling[0] << ", "
<< dst_tiling[1] << ") tiling";
}
FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeImplicitDim(
@ -6737,9 +6687,7 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
FAILUREOR_ASSIGN_OR_RETURN(
std::tie(src, src_tiles),
changeTiling(ctx, builder, v.getLoc(), vty, src, std::move(src_tiles),
dst.tiling(),
dst.offsets()[0] == std::nullopt &&
src.offsets()[0] != std::nullopt));
dst.tiling(), dst.offsets()));
FAILUREOR_ASSIGN_OR_RETURN(
std::tie(src, src_tiles),

View File

@ -1647,10 +1647,17 @@ class VectorLayoutInferer {
Layout dst_layout;
if (layout.tiling() == nativeTiling(src_bitwidth)) {
// If the source is already in native tiling, we can unpack it directly.
src_layout = layout;
std::array<int64_t, 2> dst_native_tiling = nativeTiling(dst_bitwidth);
LayoutOffsets offsets = {layout.offsets()[0]
? *layout.offsets()[0] % dst_native_tiling[0]
: LayoutOffset(),
layout.offsets()[1]};
DCHECK_LT(offsets[1].value_or(0), dst_native_tiling[1]);
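// E.g., assuming an (8, 128) target, unpacking 16-bit (native tiling
// (16, 128)) to 32-bit (native tiling (8, 128)) reduces a sublane offset of
// 12 to 12 % 8 = 4, while the lane offset is kept as is.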
src_layout = VectorLayout(src_bitwidth, offsets, layout.tiling(),
layout.implicit_dim());
dst_layout =
VectorLayout(dst_bitwidth, layout.offsets(),
nativeTiling(dst_bitwidth), layout.implicit_dim());
VectorLayout(dst_bitwidth, offsets, dst_native_tiling,
layout.implicit_dim());
} else if (dst_bitwidth == 32 &&
default_tiling_[0] % layout.tiling()[0] == 0 &&
default_tiling_[1] == layout.tiling()[1]) {
@ -1659,13 +1666,17 @@ class VectorLayoutInferer {
// tiling through the op.
// TODO(jevinjiang): we can relax this for non-32bit as well.
src_layout = layout;
dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(),
layout.implicit_dim());
dst_layout = VectorLayout(dst_bitwidth, layout.offsets(),
src_layout->tiling(), layout.implicit_dim());
} else {
// TODO(b/335863273): we should also reduce offsets.
src_layout = VectorLayout(src_bitwidth, layout.offsets(), default_tiling_,
LayoutOffsets offsets = {
layout.offsets()[0] ? *layout.offsets()[0] % default_tiling_[0]
: LayoutOffset(),
layout.offsets()[1] ? *layout.offsets()[1] % default_tiling_[1]
: LayoutOffset()};
src_layout = VectorLayout(src_bitwidth, offsets, default_tiling_,
layout.implicit_dim());
dst_layout = VectorLayout(dst_bitwidth, layout.offsets(), default_tiling_,
dst_layout = VectorLayout(dst_bitwidth, offsets, default_tiling_,
layout.implicit_dim());
}
setLayout(op, src_layout, dst_layout);

View File

@ -0,0 +1,206 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
#include <array>
#include <cstdint>
#include "absl/log/check.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/Diagnostics.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/include/mlir/IR/Types.h"
#include "mlir/include/mlir/IR/Value.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "jaxlib/mosaic/dialect/tpu/util.h"
#include "xla/array.h"
namespace mlir::tpu {
namespace {
VectorType getNativeVregOrVmaskTypeImpl(
Type elem_ty, const int8_t bitwidth,
const std::array<int64_t, 2> target_shape) {
if (bitwidth == 32) {
return VectorType::get(target_shape, elem_ty);
}
return VectorType::get({target_shape[0], target_shape[1], 32 / bitwidth},
elem_ty);
}
} // namespace
VectorType getNativeVregOrVmaskType(Type elem_ty, const int8_t layout_bitwidth,
const std::array<int64_t, 2> target_shape) {
int8_t bitwidth = elem_ty.getIntOrFloatBitWidth();
if (bitwidth == 1) {
bitwidth = layout_bitwidth;
} else {
CHECK_EQ(bitwidth, layout_bitwidth);
}
return getNativeVregOrVmaskTypeImpl(elem_ty, bitwidth, target_shape);
}
VectorType getNativeVregType(Type elem_ty,
const std::array<int64_t, 2> target_shape) {
return getNativeVregOrVmaskTypeImpl(elem_ty, elem_ty.getIntOrFloatBitWidth(),
target_shape);
}
TypedValue<VectorType> getFullVector(ImplicitLocOpBuilder &builder,
VectorType vty, Attribute value) {
return cast<TypedValue<VectorType>>(
builder.create<arith::ConstantOp>(DenseElementsAttr::get(vty, value))
.getResult());
}
TypedValue<VectorType> getFullLikeVector(ImplicitLocOpBuilder &builder,
TypedValue<VectorType> vec,
Attribute value) {
return getFullVector(builder, vec.getType(), value);
}
TypedValue<VectorType> getZerosVector(ImplicitLocOpBuilder &builder,
VectorType vty) {
return getFullVector(builder, vty, builder.getZeroAttr(vty.getElementType()));
}
TypedValue<VectorType> getZerosLikeVector(ImplicitLocOpBuilder &builder,
TypedValue<VectorType> vec) {
return getZerosVector(builder, vec.getType());
}
FailureOr<TypedValue<VectorType>> getX32VmaskByPaddingEnd(
ImplicitLocOpBuilder &builder, int64_t padding,
const std::array<int64_t, 2> target_shape, int64_t dim) {
VectorType i32_vreg_ty =
getNativeVregType(builder.getI32Type(), target_shape);
if (dim != 0 && dim != 1) {
return builder.emitError()
<< "Expected a 2D vector for getX32VmaskByPaddingEnd";
}
if (padding < 0 || padding > target_shape[dim]) {
return builder.emitError()
<< "Padding must be in [0, target_shape[dim]). Padding: " << padding
<< ", target_shape[dim]: " << target_shape[dim];
}
Value padding_vreg =
getFullVector(builder, i32_vreg_ty,
builder.getI32IntegerAttr(target_shape[dim] - padding));
return cast<TypedValue<VectorType>>(
builder
.create<arith::CmpIOp>(
arith::CmpIPredicate::slt,
builder.create<tpu::IotaOp>(i32_vreg_ty,
builder.getI32IntegerAttr(dim)),
padding_vreg)
.getResult());
}
LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder,
xla::Array<Value> &vregs,
std::array<int64_t, 2> target_shape,
int64_t padding_bottom,
int64_t padding_right) {
auto vreg_ty = dyn_cast<VectorType>(vregs.begin()->getType());
if (!vreg_ty) {
return builder.emitError() << "Expected a vector type";
}
VectorType i32_vreg_ty =
getNativeVregType(builder.getI32Type(), target_shape);
Value i32_zeros_vreg = getZerosVector(builder, i32_vreg_ty);
Value i32_max_vreg = getFullVector(builder, i32_vreg_ty,
builder.getI32IntegerAttr(0xffffffff));
int packing = vreg_ty.getRank() > 2 ? vreg_ty.getShape()[2] : 1;
// Mask out the bottom.
if (padding_bottom > 0) {
// The function is only called when the vreg has native tiling. Therefore,
// it is safe to bitcast to x32 vreg for masking.
int sub_padding = padding_bottom % packing;
int x32_padding_bottom = padding_bottom / packing;
FAILUREOR_ASSIGN_OR_RETURN(
Value mask_top, getX32VmaskByPaddingEnd(builder, x32_padding_bottom + 1,
target_shape, /*dim=*/0));
FAILUREOR_ASSIGN_OR_RETURN(
Value mask_bottom,
getX32VmaskByPaddingEnd(builder, x32_padding_bottom, target_shape,
/*dim=*/0));
// Create an int32 vreg which contains the sub-element mask and then
// logical_and it with the target vreg to mask out the unaligned padding.
// E.g., if padding_bottom = 5, packing = 2, and the vreg shape is
// [8, 128], then the mask will be:
//
// sublane 0: [0xffffffff, 0xffffffff, ..., 0xffffffff]
// sublane 1: [0xffffffff, 0xffffffff, ..., 0xffffffff]
// sublane 2: [0xffffffff, 0xffffffff, ..., 0xffffffff]
// sublane 3: [0xffffffff, 0xffffffff, ..., 0xffffffff]
// sublane 4: [0xffffffff, 0xffffffff, ..., 0xffffffff]
// sublane 5: [0x0000ffff, 0x0000ffff, ..., 0x0000ffff]
// sublane 6: [0 , 0 , ..., 0 ]
// sublane 7: [0 , 0 , ..., 0 ]
//
// This way, masking sub-elements requires each target vreg to apply only
// 1 op (logical_and) instead of 3 ops (unpacking + select + packing).
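// For the example above (sub_padding = 1, 16-bit elements), the partial
// sublane mask below is 0xffffffff >> (1 * 16) = 0x0000ffff, matching
// sublane 5.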
Value partial_sublane_mask = getFullVector(
builder, i32_vreg_ty,
builder.getI32IntegerAttr(
0xffffffff >> (sub_padding * vreg_ty.getElementTypeBitWidth())));
// Insert 0xffffffff above the blended sublane.
Value sublane_mask = builder.create<arith::SelectOp>(mask_top, i32_max_vreg,
partial_sublane_mask);
// Insert 0 below the blended sublane.
sublane_mask = builder.create<arith::SelectOp>(mask_bottom, sublane_mask,
i32_zeros_vreg);
for (int64_t i = 0; i < vregs.dim(1); ++i) {
Value &vreg = vregs({vregs.dim(0) - 1, i});
Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
if (sub_padding > 0) {
i32_vreg = builder.create<arith::AndIOp>(i32_vreg, sublane_mask);
} else {
i32_vreg = builder.create<arith::SelectOp>(mask_bottom, i32_vreg,
i32_zeros_vreg);
}
vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
}
}
// Mask out the right.
if (padding_right > 0) {
FAILUREOR_ASSIGN_OR_RETURN(
Value mask_right, getX32VmaskByPaddingEnd(builder, padding_right,
target_shape, /*dim=*/1));
for (int64_t i = 0; i < vregs.dim(0); ++i) {
Value &vreg = vregs({i, vregs.dim(1) - 1});
Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
i32_vreg =
builder.create<arith::SelectOp>(mask_right, i32_vreg, i32_zeros_vreg);
vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
}
}
return success();
}
} // namespace mlir::tpu

View File

@ -0,0 +1,82 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_VREG_UTIL_H_
#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_VREG_UTIL_H_
#include <array>
#include <cstdint>
#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/Builders.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/include/mlir/IR/Types.h"
#include "mlir/include/mlir/IR/Value.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "xla/array.h"
namespace mlir::tpu {
// Returns the native vreg or vmask type for the given element type and target
// shape. The layout bitwidth is used for i1 (vmask) elements.
VectorType getNativeVregOrVmaskType(Type elem_ty, int8_t layout_bitwidth,
std::array<int64_t, 2> target_shape);
VectorType getNativeVregType(Type elem_ty, std::array<int64_t, 2> target_shape);
// Returns a zero constant of the same type as `vty`.
TypedValue<VectorType> getZerosVector(ImplicitLocOpBuilder &builder,
VectorType vty);
// Same as above, but takes a `vec` as input.
TypedValue<VectorType> getZerosLikeVector(ImplicitLocOpBuilder &builder,
TypedValue<VectorType> vec);
// Returns a constant of the same type as `vty` with the given `value`.
TypedValue<VectorType> getFullVector(ImplicitLocOpBuilder &builder,
VectorType vty, Attribute value);
// Same as above, but takes a `vec` as input.
TypedValue<VectorType> getFullLikeVector(ImplicitLocOpBuilder &builder,
TypedValue<VectorType> vec,
Attribute value);
// Creates a vmask that is false in the last `padding` positions along `dim`
// (bottom for dim = 0, right for dim = 1) and true in the first
// (dim_size - padding) positions.
//
// For example, assume vmask shape is (4, 8)
//
// getX32VmaskByPaddingEnd(padding=3, dim=1) creates:
// [T, T, T, T, T, F, F, F]
// [T, T, T, T, T, F, F, F]
// [T, T, T, T, T, F, F, F]
// [T, T, T, T, T, F, F, F]
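//
// getX32VmaskByPaddingEnd(padding=1, dim=0) creates:
// [T, T, T, T, T, T, T, T]
// [T, T, T, T, T, T, T, T]
// [T, T, T, T, T, T, T, T]
// [F, F, F, F, F, F, F, F]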
// TODO(b/385204135): Unify with getVmaskByPaddingEnd in tpu_rotate_rule, and
// improve the codegen.
FailureOr<TypedValue<VectorType>> getX32VmaskByPaddingEnd(
ImplicitLocOpBuilder &builder, int64_t padding,
std::array<int64_t, 2> target_shape, int64_t dim);
// Masks out the padding at the bottom and right of the vregs. The vregs are
// expected to have native tiling, and the masked vregs are mutated in place
// in `vregs`. `padding_bottom` and `padding_right` are the numbers of
// elements to pad at the bottom and on the right.
LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder,
xla::Array<Value> &vregs,
std::array<int64_t, 2> target_shape,
int64_t padding_bottom,
int64_t padding_right);
} // namespace mlir::tpu
#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_VREG_UTIL_H_

View File

@ -0,0 +1,228 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
#include <array>
#include <cstdint>
#include <memory>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/strings/str_join.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/include/mlir/IR/Builders.h"
#include "mlir/include/mlir/IR/BuiltinOps.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/include/mlir/IR/MLIRContext.h"
#include "mlir/include/mlir/IR/OwningOpRef.h"
#include "mlir/include/mlir/IR/Value.h"
#include "mlir/include/mlir/Support/DebugStringHelper.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
namespace mlir::tpu {
namespace {
using ::testing::Eq;
using ::testing::Optional;
MATCHER_P2(IsConstantOpWithSplatValue, type, splat_value, "") {
auto constant_op = dyn_cast<arith::ConstantOp>(arg.getDefiningOp());
if (constant_op == nullptr) {
*result_listener << "Expected a constant op, got " << debugString(arg);
return false;
}
auto dense_attr = dyn_cast<DenseElementsAttr>(constant_op.getValue());
if (dense_attr == nullptr) {
*result_listener << "Expected a dense elements attr, got "
<< debugString(arg);
return false;
}
if (dense_attr.getType() != type) {
*result_listener << "Expected a dense elements attr with type "
<< debugString(type) << ", got "
<< debugString(dense_attr.getType());
return false;
}
if (!dense_attr.isSplat()) {
*result_listener << "Expected a splat dense elements attr, got "
<< debugString(dense_attr);
return false;
}
if (auto s = dense_attr.template getSplatValue<decltype(splat_value)>();
s != splat_value) {
*result_listener << "Expected a splat dense elements attr with value "
<< splat_value << ", got " << s;
return false;
}
return true;
}
MATCHER_P2(IsVectorTypeWithShape, shape, elem_ty, "") {
auto vty = dyn_cast<VectorType>(arg);
if (vty == nullptr) {
*result_listener << "Expected a vector type, got " << debugString(arg);
return false;
}
if (vty.getShape() != ArrayRef<int64_t>(shape)) {
*result_listener << "Expected a vector type with shape "
<< absl::StrJoin(shape, ",") << ", got "
<< absl::StrJoin(vty.getShape(), ",");
return false;
}
if (vty.getElementType() != elem_ty) {
*result_listener << "Expected a vector type with element type "
<< debugString(elem_ty) << ", got "
<< debugString(vty.getElementType());
return false;
}
return true;
}
class VregUtilTest : public ::testing::Test {
protected:
void SetUp() override {
context_.loadDialect<arith::ArithDialect, vector::VectorDialect,
tpu::TPUDialect>();
mlir::Location loc = mlir::UnknownLoc::get(&context_);
mlir::OpBuilder b(&context_);
module_ = b.create<ModuleOp>(loc);
builder_ = std::make_unique<mlir::ImplicitLocOpBuilder>(
module_->getLoc(), module_->getBodyRegion());
}
void TearDown() override {
builder_.reset();
// Reset the module to prevent memory leaks.
module_ = nullptr;
}
mlir::ImplicitLocOpBuilder& Builder() { return *builder_; }
private:
MLIRContext context_;
std::unique_ptr<mlir::ImplicitLocOpBuilder> builder_;
OwningOpRef<ModuleOp> module_;
};
TEST_F(VregUtilTest, GetNativeVregOrVmaskTypeBitwidthMismatch) {
EXPECT_DEATH(getNativeVregOrVmaskType(Builder().getI16Type(),
/*layout_bitwidth=*/8, {2, 4}),
"");
}
TEST_F(VregUtilTest, GetNativeVregOrVmaskTypeI1) {
EXPECT_THAT(getNativeVregOrVmaskType(Builder().getI1Type(),
/*layout_bitwidth=*/8, {2, 4}),
IsVectorTypeWithShape(std::array<int64_t, 3>{2, 4, 4},
Builder().getI1Type()));
}
TEST_F(VregUtilTest, GetNativeVregF32) {
EXPECT_THAT(getNativeVregType(Builder().getF32Type(), {2, 4}),
IsVectorTypeWithShape(std::array<int64_t, 2>{2, 4},
Builder().getF32Type()));
}
TEST_F(VregUtilTest, GetNativeVregBf16) {
EXPECT_THAT(getNativeVregType(Builder().getBF16Type(), {2, 4}),
IsVectorTypeWithShape(std::array<int64_t, 3>{2, 4, 2},
Builder().getBF16Type()));
}
TEST_F(VregUtilTest, GetFullVector) {
VectorType vty = VectorType::get({2, 4}, Builder().getI32Type());
TypedValue<VectorType> vec =
getFullVector(Builder(), vty, Builder().getI32IntegerAttr(0x1));
EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, int32_t{0x1}));
}
TEST_F(VregUtilTest, GetFullLikeVector) {
VectorType vty = VectorType::get({2, 4}, Builder().getF32Type());
TypedValue<VectorType> in_vec = Builder().create<vector::SplatOp>(
vty, Builder().create<arith::ConstantOp>(
vty.getElementType(), Builder().getF32FloatAttr(1.0f)));
TypedValue<VectorType> vec =
getFullLikeVector(Builder(), in_vec, Builder().getF32FloatAttr(2.0f));
EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, float{2.0f}));
}
TEST_F(VregUtilTest, GetZerosVector) {
VectorType vty = VectorType::get({2, 4}, Builder().getI32Type());
TypedValue<VectorType> vec = getZerosVector(Builder(), vty);
EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, int32_t{0}));
}
TEST_F(VregUtilTest, GetZerosLikeVector) {
VectorType vty = VectorType::get({2, 4}, Builder().getF32Type());
TypedValue<VectorType> in_vec = Builder().create<vector::SplatOp>(
vty, Builder().create<arith::ConstantOp>(
vty.getElementType(), Builder().getF32FloatAttr(1.0f)));
TypedValue<VectorType> vec = getZerosLikeVector(Builder(), in_vec);
EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, float{0.0f}));
}
TEST_F(VregUtilTest, GetX32VmaskByPaddingEndDim0) {
constexpr std::array<int64_t, 2> kTargetShape = {4, 8};
FailureOr<TypedValue<VectorType>> vec = getX32VmaskByPaddingEnd(
Builder(), /*padding=*/1, /*target_shape=*/kTargetShape,
/*dim=*/0);
ASSERT_TRUE(succeeded(vec));
auto cmp_op = dyn_cast<arith::CmpIOp>(vec.value().getDefiningOp());
ASSERT_TRUE(cmp_op != nullptr);
EXPECT_EQ(cmp_op.getPredicate(), arith::CmpIPredicate::slt);
auto iota_op = dyn_cast<tpu::IotaOp>(cmp_op.getLhs().getDefiningOp());
ASSERT_TRUE(iota_op != nullptr);
EXPECT_THAT(iota_op.getDimension(), Optional(Eq(0)));
EXPECT_THAT(
cmp_op.getRhs(),
IsConstantOpWithSplatValue(
VectorType::get(kTargetShape, Builder().getI32Type()), int32_t{3}));
}
TEST_F(VregUtilTest, GetX32VmaskByPaddingEndDim1) {
constexpr std::array<int64_t, 2> kTargetShape = {4, 8};
FailureOr<TypedValue<VectorType>> vec = getX32VmaskByPaddingEnd(
Builder(), /*padding=*/3, /*target_shape=*/kTargetShape,
/*dim=*/1);
ASSERT_TRUE(succeeded(vec));
auto cmp_op = dyn_cast<arith::CmpIOp>(vec.value().getDefiningOp());
ASSERT_TRUE(cmp_op != nullptr);
EXPECT_EQ(cmp_op.getPredicate(), arith::CmpIPredicate::slt);
auto iota_op = dyn_cast<tpu::IotaOp>(cmp_op.getLhs().getDefiningOp());
ASSERT_TRUE(iota_op != nullptr);
EXPECT_THAT(iota_op.getDimension(), Optional(Eq(1)));
EXPECT_THAT(
cmp_op.getRhs(),
IsConstantOpWithSplatValue(
VectorType::get(kTargetShape, Builder().getI32Type()), int32_t{5}));
}
} // namespace
} // namespace mlir::tpu

View File

@ -33,4 +33,5 @@ except ImportError:
from mlir.dialects._ods_common import _cext # type: ignore[import-not-found]
_cext.globals.append_dialect_search_prefix("jax.jaxlib.mosaic.python")
# Add the parent module to the search prefix
_cext.globals.append_dialect_search_prefix(__name__[:__name__.rfind(".")])

View File

@ -83,6 +83,7 @@ setup(
'cuda/*',
'cuda/nvvm/libdevice/libdevice*',
'mosaic/*.py',
'mosaic/dialect/gpu/*.py',
'mosaic/gpu/*.so',
'mosaic/python/*.py',
'mosaic/python/*.so',

View File

@ -218,6 +218,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
dst_dir=mosaic_python_dir,
src_files=[
"__main__/jaxlib/mosaic/python/layout_defs.py",
"__main__/jaxlib/mosaic/python/mosaic_gpu.py",
"__main__/jaxlib/mosaic/python/tpu.py",
],
)
@ -225,6 +226,16 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
patch_copy_mlir_import(
"__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir
)
mosaic_gpu_dir = jaxlib_dir / "mosaic" / "dialect" / "gpu"
os.makedirs(mosaic_gpu_dir)
patch_copy_mlir_import(
"__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_ops.py",
dst_dir=mosaic_gpu_dir,
)
patch_copy_mlir_import(
"__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_enums.py",
dst_dir=mosaic_gpu_dir,
)
copy_runfiles(
dst_dir=jaxlib_dir / "mlir",
@ -316,6 +327,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_mosaic_gpu_ext.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_sdy.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}",

View File

@ -219,6 +219,29 @@ class DebugPrintTest(jtu.JaxTestCase):
[ 1 2 3 4 5 6 7 8 9 10 12 13 14]]
"""))
def test_debug_print_respects_numpy_printoptions(self):
def f(x):
with np.printoptions(precision=2, suppress=True):
jax.debug.print("{}", x)
x = np.array([1.2345, 2.3456, 1E-7])
# Default numpy print options:
with jtu.capture_stdout() as output:
jax.debug.print("{}", x)
self.assertEqual(output(), "[1.2345e+00 2.3456e+00 1.0000e-07]\n")
# Modified print options without JIT:
with jtu.capture_stdout() as output:
f(x)
jax.effects_barrier()
self.assertEqual(output(), "[1.23 2.35 0. ]\n")
# Modified print options with JIT:
with jtu.capture_stdout() as output:
jax.jit(f)(x)
jax.effects_barrier()
self.assertEqual(output(), "[1.23 2.35 0. ]\n")
class DebugPrintTransformationTest(jtu.JaxTestCase):

View File

@ -254,8 +254,6 @@ def sdpa_train_fp8(
class DotProductAttentionTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if jax.device_count() < 4:
self.skipTest("Requires more than 4 devices.")
try:
cudnn_version = check_cudnn_version()
except RuntimeError as e:
@ -366,6 +364,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):
@jtu.run_on_devices("cuda")
def test_sdpa_inference(self):
if jax.device_count() < 4:
self.skipTest("Requires more than 4 devices.")
k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
query = jax.random.normal(
k1, (4, 1024, 4, 64), dtype=jnp.bfloat16)
@ -407,6 +407,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):
@jtu.run_on_devices("cuda")
def test_sdpa_var_seq(self):
if jax.device_count() < 4:
self.skipTest("Requires more than 4 devices.")
self.skipTest("Skip before fixed.")
k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
query = jax.random.normal(
@ -438,6 +440,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):
@jtu.run_on_devices("cuda")
def test_sdpa_broadcast_bias_and_dbias(self):
if jax.device_count() < 4:
self.skipTest("Requires more than 4 devices.")
try:
cudnn_version = check_cudnn_version()
except RuntimeError as e:
@ -504,6 +508,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):
)
@jtu.run_on_devices("cuda")
def test_sdpa_dbias(self, batch_size: int):
if jax.device_count() < 4:
self.skipTest("Requires more than 4 devices.")
# cuDNN only supports dbias when batch size is 1. If the batch size is
# greater, dbias is silently set to all zeros. This test verifies this
# behavior for both vmap and regular use cases.
@ -540,6 +546,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):
@jtu.run_on_devices("cuda")
def test_sdpa_sliding_window_length(self):
if jax.device_count() < 4:
self.skipTest("Requires more than 4 devices.")
k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
query = jax.random.normal(
k1, (4, 1024, 4, 64), dtype=jnp.bfloat16)
@ -571,8 +579,43 @@ class DotProductAttentionTest(jtu.JaxTestCase):
self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5)
self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5)
@jtu.run_on_devices("cuda")
def test_sdpa_large_head_size(self):
try:
cudnn_version = check_cudnn_version()
except RuntimeError as e:
self.skipTest(str(e))
return
if cudnn_version < 90500:
self.skipTest("Requires >= cuDNN 9.5.0")
if not jtu.is_cuda_compute_capability_at_least("9.0"):
self.skipTest("Requires at least Hopper arch")
B, T, N, H = 2, 64, 2, 256
bf16 = jnp.bfloat16
keys = jax.random.split(jax.random.key(0), 4)
query = jax.random.normal(keys[0], (B, T, N, H), dtype=bf16)
key = jax.random.normal(keys[1], (B, T, N, H), dtype=bf16)
value = jax.random.normal(keys[2], (B, T, N, H), dtype=bf16)
grad = jax.random.normal(keys[3], (B, T, N, H), dtype=bf16)
sdpa_train_ans = jax.jit(partial(
sdpa_train, scale=1.0, mask_type=MaskType.CAUSAL, dropout_rate=0)
)
sdpa_train_rfc = jax.jit(partial(
sdpa_train_ref, scale=1.0, mask_type=MaskType.CAUSAL, dropout_rate=0)
)
out_ans, grads_ans = sdpa_train_ans(query, key, value, grad, None, None)
out_ref, grads_ref = sdpa_train_rfc(query, key, value, grad, None, None)
self.assertArraysAllClose(out_ref, out_ans)
self.assertArraysAllClose(grads_ref[0], grads_ans[0])
self.assertArraysAllClose(grads_ref[1], grads_ans[1])
self.assertArraysAllClose(grads_ref[2], grads_ans[2])
@jtu.run_on_devices("cuda")
def test_layouts(self):
if jax.device_count() < 4:
self.skipTest("Requires more than 4 devices.")
dtype = "bfloat16"
B, T, N, H = 4, 1024, 8, 128
S = T
@ -600,6 +643,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):
self.assertArraysAllClose(dv_ref, _cvt_back(dv))
def test_sdpa_utils(self):
if jax.device_count() < 4:
self.skipTest("Requires more than 4 devices.")
test_cases = [
(1, 257, 64, 8905, False, True, True),
(1, 1024, 64, 8905, False, False, True),

View File

@ -327,6 +327,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1E-6)
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-8)
@unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
@jtu.sample_product(
l_max=[1, 2, 3, 6],
shape=[(5,), (10,)],
@ -349,6 +350,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
atol=3e-3, check_dtypes=False)
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-5, atol=3e-3)
@unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
@jtu.sample_product(
l_max=[3, 4, 6, 32],
shape=[(2,), (3,), (4,), (64,)],
@ -381,6 +383,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
rtol=1e-5, atol=1e-5, check_dtypes=False)
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)
@unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
def testSphHarmAccuracy(self):
m = jnp.arange(-3, 3)[:, None]
@ -435,6 +438,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8)
@unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
@jtu.sample_product(
[dict(l_max=l_max, num_z=num_z)
for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])

View File

@ -4386,7 +4386,7 @@ class FunctionAccuracyTest(jtu.JaxTestCase):
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag')
elif name == 'log10':
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag', 'zero.imag')
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag')
elif name == 'exp':
regions_with_inaccuracies_keep('pos.imag', 'pinf.imag', 'mpos.imag')

View File

@ -39,8 +39,8 @@ jax_multiplatform_test(
"tpu",
],
enable_configs = [
"gpu_a100_x32",
"gpu_h100_x32",
"gpu_a100",
"gpu_h100",
],
shard_count = {
"cpu": 8,

View File

@ -1688,8 +1688,8 @@ class PallasControlFlowTest(PallasBaseTest):
def body(state):
i, s = state
sl = jax.lax.div(i, 128)
l = jax.lax.rem(i, 128)
sl = jax.lax.div(i, jnp.astype(128, i.dtype))
l = jax.lax.rem(i, jnp.astype(128, i.dtype))
v = pl.load(x_ref, (0, sl, l))
return i + 1, s + v

View File

@ -2555,9 +2555,7 @@ class MiscellaneousTest(PallasBaseTest):
np.testing.assert_array_equal(out, np.reshape(x, (8, 128)))
@only_passes_in_interpret()
def test_retiling2(self):
"""b/348040767"""
x = np.arange(1 * 8 * 1024, dtype=jnp.bfloat16).reshape(1, 8, 1024)
def kernel(x_ref, out_ref):

View File

@ -95,8 +95,8 @@ class PgleTest(jtu.JaxTestCase):
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
compiler_options={
'xla_gpu_enable_latency_hiding_scheduler': 'True',
# TODO(patrios): Remove this flag once b/376647494 is fixed.
'xla_gpu_graph_min_graph_size': '100000',
# TODO(b/37664749): Remove this flag once the bug is fixed.
'xla_gpu_enable_command_buffer': '',
},
)
def f(x):
@ -133,8 +133,6 @@ class PgleTest(jtu.JaxTestCase):
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
compiler_options={
'xla_gpu_enable_latency_hiding_scheduler': 'True',
# TODO(patrios): Remove this flag once b/376647494 is fixed.
'xla_gpu_graph_min_graph_size': '100000',
'xla_dump_to': dump_dir,
'xla_gpu_experimental_dump_fdo_profiles': 'True'
},
@ -217,8 +215,6 @@ class PgleTest(jtu.JaxTestCase):
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
compiler_options={
'xla_gpu_enable_latency_hiding_scheduler': 'True',
# TODO(patrios): Remove this flag once b/376647494 is fixed.
'xla_gpu_graph_min_graph_size': '100000',
'xla_dump_to': dump_dir,
'xla_gpu_experimental_dump_fdo_profiles': 'True'
},

View File

@ -2207,6 +2207,23 @@ class ShardMapTest(jtu.JaxTestCase):
#
# f(x) # don't crash
def test_partial_auto_of_random_keys(self):
if config.use_shardy_partitioner.value:
self.skipTest('Shardy does not support full-to-shard.')
mesh = jtu.create_mesh((4, 2), ('i', 'j'))
keys = jax.random.split(jax.random.key(0), 8)
@jax.jit
def f(x):
return shard_map(lambda k: k,
mesh, in_specs=P('i'), out_specs=P('i'),
check_rep=False, auto=frozenset({'j'}))(keys)
y = f(keys) # don't crash
self.assertAllClose(jax.random.key_data(y), jax.random.key_data(keys),
check_dtypes=False)
def test_vmap_grad_shmap_spmd_axis_name_residuals(self):
# https://github.com/jax-ml/jax/pull/21032
mesh = jtu.create_mesh((4, 2), ('i', 'j'))

View File

@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result.
XLA_COMMIT = "dc7aaf834a0bb5a543f6cf98626284783a4a921c"
XLA_SHA256 = "eda76cce64b33c00139120d6b4d4c2167d9f99dc957da54225a67ddb7ec7cb23"
XLA_COMMIT = "ac6e71fe0cf864eec152de5ba761b76d8bef3153"
XLA_SHA256 = "2b568ff365bc4b5c2b257002aa71f094a2b60357ceb1f2a1c6c33f4ad1a411bd"
def repo():
tf_http_archive(