Merge branch 'jax-ml:main' into activation-offloading-doc

Jane Liu 2025-01-04 16:36:27 -08:00 committed by GitHub
commit 5e3a692d36
86 changed files with 2627 additions and 965 deletions

View File

@ -85,7 +85,7 @@ jobs:
&& steps.status.outcome == 'failure'
&& github.event_name == 'schedule'
&& github.repository == 'jax-ml/jax'
uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0
with:
name: output-${{ matrix.python-version }}-log.jsonl
path: output-${{ matrix.python-version }}-log.jsonl

View File

@ -45,7 +45,7 @@ jobs:
--bazel_options=--config=win_clang `
--verbose
- uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
- uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0
with:
name: wheels-${{ matrix.os }}-${{ matrix.pyver }}
path: ${{ github.workspace }}\dist\*.whl

View File

@ -54,7 +54,7 @@ jobs:
--bazel_options=--config=win_clang `
--verbose
- uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
- uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0
with:
name: wheels
path: ${{ github.workspace }}\jax\dist\*.whl

View File

@ -4,6 +4,10 @@ Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html).
For the changes specific to the experimental Pallas APIs,
see {ref}`pallas-changelog`.
JAX follows Effort-based versioning; for a discussion of this and JAX's API
compatibility policy, refer to {ref}`api-compatibility`. For the Python and
NumPy version support policy, refer to {ref}`version-support-policy`.
<!--
Remember to align the itemized text with the first line of an item within a list.
@ -12,6 +16,16 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
## Unreleased
* Changes:
* The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
supported version until June 2025.
* New Features
* {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,
{func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
transforms in more than 3 dimensions, which was previously the limit. See
{jax-issue}`#25606` for more details.
* Deprecations
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name
@ -20,21 +34,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Deletions
* `jax_enable_memories` flag has been deleted and the behavior of that flag
is on by default.
* Changes:
* The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
supported version until June 2025.
* Deprecations
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name
in {mod}`jax.core`.
* New Features
* {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,
{func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
transforms in more than 3 dimensions, which was previously the limit. See
{jax-issue}`#25606` for more details.
* From `jax.lib.xla_client`, the previously-deprecated `Device` and
`XlaRuntimeError` symbols have been removed; instead use `jax.Device`
and `jax.errors.JaxRuntimeError` respectively.
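
As a hedged illustration of the FFT entry in the Unreleased notes above (not part of the original changelog), the sketch below transforms over four axes at once; the shapes are arbitrary:

```python
import jax.numpy as jnp

# A 4-D transform: previously fftn/ifftn could transform at most 3 axes.
x = jnp.ones((2, 3, 4, 5), dtype=jnp.complex64)
y = jnp.fft.fftn(x, axes=(0, 1, 2, 3))
x_back = jnp.fft.ifftn(y, axes=(0, 1, 2, 3))
print(y.shape, x_back.shape)  # (2, 3, 4, 5) (2, 3, 4, 5)
```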
## jax 0.4.38 (Dec 17, 2024)

View File

@ -580,11 +580,11 @@ async def main():
if args.configure_only:
logging.info("--configure_only is set so not running any Bazel commands.")
else:
output_path = args.output_path
logger.debug("Artifacts output directory: %s", output_path)
# Wheel build command execution
for wheel in args.wheels.split(","):
output_path = args.output_path
logger.debug("Artifacts output directory: %s", output_path)
# Allow CUDA/ROCm wheels without the "jax-" prefix.
if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel:
wheel = "jax-" + wheel

View File

@ -8,8 +8,15 @@ JAX is constantly evolving, and we want to be able to make improvements to its
APIs. That said, we want to minimize churn for the JAX user community, and we
try to make breaking changes rarely.
JAX follows a 3 month deprecation policy. When an incompatible change is made
to an API, we will make our best effort to obey the following procedure:
## JAX Versioning
JAX uses [Effort-based versioning](https://jacobtomlinson.dev/effver/) (see
{ref}`jep-effver`), and is currently in the Zero version phase.
This means that for version `0.X.Y`, incrementing `Y` will introduce minor
breaking changes, and incrementing `X` will introduce major breaking changes.
For any breaking change, JAX currently follows a 3 month deprecation policy.
When an incompatible change is made to an API, we will make our best effort
to obey the following procedure:
* the change will be announced in `CHANGELOG.md` and in the doc string for the
deprecated API, and the old API will issue a `DeprecationWarning`.
* three months after the `jax` release that deprecated an API, we may remove the
@ -47,16 +54,44 @@ prefixed with underscores, although we do not entirely comply with this yet.
## What is not covered?
* anything prefixed with an underscore.
* `jax._src`
* `jax.core`
* `jax.lib`
* `jax.interpreters`
### Explicitly private APIs
Any API or import path prefixed with an underscore is explicitly private,
and may change without warning between JAX releases. We are working to move
all private APIs into `jax._src` to make these expectations more clear.
### Legacy internal APIs
In addition, there are several legacy modules that currently expose some
private APIs without an underscore, including:
- `jax.core`
- `jax.interpreters`
- `jax.lib`
- `jax.util`
We are actively working on deprecating these modules and the APIs they contain.
In most cases, such deprecations will follow the 3 month deprecation period,
but this may not always be possible. If you use any such APIs, please expect
them to be deprecated soon, and seek alternatives.
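
As a hedged sketch (not part of the policy text above), one way to notice reliance on these legacy modules early is to promote deprecation warnings to errors in your test suite:

```python
import warnings

# Turn DeprecationWarning into an error so that any use of a deprecated
# JAX API fails the test run instead of passing silently.
warnings.filterwarnings("error", category=DeprecationWarning)
```

pytest users can get a similar effect with a `filterwarnings = error::DeprecationWarning` entry in their configuration.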
### Experimental and example libraries
The following modules include code for experimental or demonstration purposes,
and their APIs may change between releases without warning:
* `jax.experimental`
* `jax.example_libraries`
* `jax.extend` (see [details](https://jax.readthedocs.io/en/latest/jax.extend.html))
This list is not exhaustive.
We understand that some users depend on `jax.experimental`, and so in most cases
we follow the 3 month deprecation period for changes, but this may not always be
possible.
### JAX extend
The {mod}`jax.extend` module includes semi-public JAX internal APIs that are
meant for use by downstream projects, but do not have the same stability
guarantees of the main JAX package. If you have code that uses `jax.extend`,
we would strongly recommend CI tests against JAX's nightly releases, so as to
catch potential changes before they are released.
For details on `jax.extend`, see the [`jax.extend` module documentation](https://jax.readthedocs.io/en/latest/jax.extend.html), or the design document, {ref}`jax-extend-jep`.
## Numerics and randomness

View File

@ -1,3 +1,4 @@
(jep-effver)=
# JEP 25516: Effort-based versioning for JAX
This document proposes that the JAX core library should explicitly adopt

View File

@ -93,6 +93,12 @@ plugins" error described {ref}`below <multiple_installs>`. See
<https://www.tensorflow.org/guide/profiler> for more information on installing
TensorBoard.
The nightly version of the TensorBoard profiler requires nightly TensorFlow and
TensorBoard:
```shell
pip install tf-nightly tb-nightly tbp-nightly
```
### Programmatic capture
You can instrument your code to capture a profiler trace via the
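
The remainder of this paragraph is truncated by the hunk; as a hedged sketch of programmatic capture, the context-manager form looks roughly like this (the log directory is an arbitrary choice):

```python
import jax
import jax.numpy as jnp

# Capture a trace for a block of code; open the same log directory with the
# TensorBoard profiler plugin to inspect it.
with jax.profiler.trace("/tmp/jax-trace"):
    x = jnp.ones((1024, 1024))
    (x @ x).block_until_ready()
```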

View File

@ -490,7 +490,7 @@ This section covers some of the most common patterns with JAX pytrees.
### Transposing pytrees with `jax.tree.map` and `jax.tree.transpose`
To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: {func} `jax.tree.map` (more basic) and {func}`jax.tree.transpose` (more flexible, complex and verbose).
To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: {func}`jax.tree.map` (more basic) and {func}`jax.tree.transpose` (more flexible, complex and verbose).
**Option 1:** Use {func}`jax.tree.map`. Here's an example:
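
The example that follows this sentence in the documentation is elided from the hunk; a minimal sketch of the `jax.tree.map` approach, using made-up per-step data:

```python
import jax

# A list of trees (one dict per step) ...
episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]

# ... becomes a tree of lists: corresponding leaves are zipped together.
print(jax.tree.map(lambda *xs: list(xs), *episode_steps))
# {'obs': [3, 4], 't': [1, 2]}
```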

View File

@ -56,14 +56,7 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape,
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
core.shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify
def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
@ -71,15 +64,8 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
dtypes.check_valid_dtype(dtype)
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
def _np_scalar_abstractify(x: np.generic) -> ShapedArray:
dtype = np.dtype(x)
dtypes.check_valid_dtype(dtype)
return ShapedArray(np.shape(x),
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
for t in numpy_scalar_types:
core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
core.shaped_abstractify_handlers[t] = _np_scalar_abstractify
core.literalable_types.update(array_types)
@ -90,13 +76,7 @@ def _make_abstract_python_scalar(typ, val):
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
weak_type=typ is not bool)
def _python_scalar_abstractify(x: int | float | complex | bool) -> ShapedArray:
typ = type(x)
dtype = dtypes._scalar_type_to_dtype(typ, x)
return ShapedArray((), dtype, weak_type=typ in dtypes._weak_types)
for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
core.shaped_abstractify_handlers[t] = _python_scalar_abstractify
core.literalable_types.update(dtypes.python_scalar_dtypes.keys())

View File

@ -583,9 +583,6 @@ def _dtype(x):
except ValueError:
return dtypes.result_type(getattr(x, 'dtype'))
# TODO(jakevdp): fix downstream consumers and remove these aliases.
shaped_abstractify = core.shaped_abstractify
_shaped_abstractify_handlers = core.shaped_abstractify_handlers
# This decorator exists to make it easier to monkey-patch APIs in JAX.
# By default it does nothing, but it can be monkey-patched to do other things.

View File

@ -1035,7 +1035,6 @@ def _get_aval_array(self):
else:
return self.aval
core.shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
core.pytype_aval_mappings[ArrayImpl] = _get_aval_array
# TODO(jakevdp) replace this with true inheritance at the C++ level.

View File

@ -192,6 +192,14 @@ def get_compile_options(
build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
build_options.memory_fitting_effort = config.memory_fitting_effort.value
# This is a temporary workaround to simplify the AutoPGLE usage.
# TODO(b/376647494): Remove once the bug is fixed.
if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
logger.debug("Explicitly disabling command buffer scheduling for AutoPGLE.")
if env_options_overrides is None:
env_options_overrides = {}
env_options_overrides['xla_gpu_enable_command_buffer'] = ''
if env_options_overrides is not None:
# Some overrides are passed directly on build_options.
overrides_on_build_options = [

View File

@ -1658,7 +1658,7 @@ array_garbage_collection_guard = optional_enum_state(
' do not log garbage collection of "jax.Array" objects.\n * "log":'
' log an error when a "jax.Array" is garbage collected.\n * "fatal":'
' fatal error if a "jax.Array" is garbage collected.\nDefault is'
' "allow".'
' "allow". Note that not all cycles may be detected.'
),
update_global_hook=lambda val: _update_garbage_collection_guard(
guard_lib.global_state(), 'garbage_collect_array', val

View File

@ -656,6 +656,13 @@ def check_bool_conversion(arr: Array):
" is ambiguous. Use a.any() or a.all()")
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
def _str_abstractify(x):
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
pytype_aval_mappings[str] = _str_abstractify
def _aval_property(name):
return property(lambda self: getattr(self.aval, name))
@ -918,6 +925,8 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
aval_property = namedtuple("aval_property", ["fget"])
aval_method = namedtuple("aval_method", ["fun"])
pytype_aval_mappings[Tracer] = lambda x: x.aval
def check_eval_args(args):
for arg in args:
if isinstance(arg, Tracer):
@ -1400,45 +1409,51 @@ def check_valid_jaxtype(x):
f"Value {x!r} of type {type(x)} is not a valid JAX type")
def _shaped_abstractify_slow(x):
try:
return x if isinstance(x, AbstractValue) else get_aval(x)
except TypeError:
pass
# We have three flavors of abstractification APIs here which each used to have
# their own separate implementation. Now they're effectively the same, with the
# following differences:
#
# - abstractify returns avals for non-traced array-like objects.
# - get_aval is like abstractify, but also accepts tracers.
# - shaped_abstractify is like get_aval, but also accepts duck-typed arrays.
#
# TODO(jakevdp): can these be unified further?
weak_type = getattr(x, 'weak_type', False)
if hasattr(x, 'dtype'):
dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
else:
raise TypeError(
f"Cannot interpret value of type {type(x)} as an abstract array; it "
"does not have a dtype attribute")
return ShapedArray(np.shape(x), dtype, weak_type=weak_type)
# TODO(jakevdp): deduplicate this with abstractify
def shaped_abstractify(x):
# This was originally api_util.shaped_abstractify; temporarily moved
# here in order to facilitate combining it with abstractify.
handler = shaped_abstractify_handlers.get(type(x), None)
return handler(x) if handler is not None else _shaped_abstractify_slow(x)
typ = type(x)
if (aval_fn := pytype_aval_mappings.get(typ)): # fast path
return aval_fn(x)
for t in typ.__mro__[1:]:
if (aval_fn := pytype_aval_mappings.get(t)):
return aval_fn(x)
if isinstance(x, AbstractValue):
return x
if hasattr(x, '__jax_array__'):
return shaped_abstractify(x.__jax_array__())
if hasattr(x, 'dtype'):
return ShapedArray(np.shape(x), x.dtype, weak_type=getattr(x, 'weak_type', False))
raise TypeError(
f"Cannot interpret value of type {typ} as an abstract array; it "
"does not have a dtype attribute")
def abstractify(x):
for typ in type(x).__mro__:
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
if hasattr(x, '__jax_array__'):
return abstractify(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
if isinstance(x, Tracer):
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
return get_aval(x)
def get_aval(x):
if isinstance(x, Tracer):
return x.aval
else:
return abstractify(x)
typ = type(x)
if (aval_fn := pytype_aval_mappings.get(typ)): # fast path
return aval_fn(x)
for t in typ.__mro__[1:]:
if (aval_fn := pytype_aval_mappings.get(t)):
return aval_fn(x)
if hasattr(x, '__jax_array__'):
return get_aval(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")
get_type = get_aval
def is_concrete(x):
return to_concrete_value(x) is not None
@ -1831,12 +1846,6 @@ class DShapedArray(UnshapedArray):
return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
shaped_abstractify_handlers: dict[Any, Callable[[Any], ShapedArray]] = {}
def _str_abstractify(x):
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
shaped_abstractify_handlers[str] = _str_abstractify
class DArray:
_aval: DShapedArray
@ -1889,9 +1898,11 @@ class DArray:
data = self._data[slices]
return data
def _darray_aval(x):
return DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)
pytype_aval_mappings[DArray] = _darray_aval
pytype_aval_mappings[DArray] = \
lambda x: DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)
@dataclass(frozen=True)
class bint(dtypes.ExtendedDType):
@ -1920,8 +1931,8 @@ class MutableArray:
aval = property(lambda self: self._aval)
shape = property(lambda self: self._aval.shape)
dtype = property(lambda self: self._aval.dtype)
def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
def __getitem__(self, idx): return self._aval._getitem(self, idx)
def __setitem__(self, idx, x): return self._aval._setitem(self, idx, x)
def __repr__(self) -> str: return 'Mutable' + repr(self[...])
pytype_aval_mappings[MutableArray] = lambda x: x._aval

View File

@ -348,11 +348,20 @@ def check_is_flash_attention(
)
else:
# Regular attention conditions
if not ((H <= 128 and H % 8 == 0) and
(not is_training or not has_bias or T % 2 == 0 and S % 2 == 0)):
raise NotImplementedError(
f"Unsupported sequence length Q {T}, KV {S} and head dim {H}."
)
# Check the head dim.
is_on_hopper = check_compute_capability("9.0")
H_max = 256 if cudnn_version >= 90500 and is_on_hopper else 128
if not (H <= H_max and H % 8 == 0):
raise NotImplementedError(
f"The head dim must be <= {H_max} and a mutiple of 8, "
f"but got {H}."
)
# Check patterns with bias, seqlen should be divisible by 2
if (is_training and has_bias and (T % 2 != 0 or S % 2 != 0)):
raise NotImplementedError(
f"Unsupported sequence length Q {T}, KV {S}."
)
def check_cudnn_version():
# check if cuDNN is installed

View File

@ -30,8 +30,10 @@ from jax._src import traceback_util
from jax._src.ad_util import (
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
from jax._src.api_util import (
argnums_partial, flatten_fun_nokwargs, resolve_kwargs)
argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature,
_arg_names)
from jax._src.errors import UnexpectedTracerError
from jax._src.state.types import AbstractRef
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -41,8 +43,8 @@ from jax._src.interpreters.batching import not_mapped
from jax._src.lax import lax
from jax._src.tree_util import (
tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple,
register_pytree_node_class, tree_leaves, tree_flatten_with_path, keystr,
treedef_children)
register_pytree_node_class, tree_leaves, tree_flatten_with_path,
tree_leaves_with_path, keystr, treedef_children)
from jax._src.util import (cache, safe_zip, safe_map, split_list, Unhashable,
unzip2)
@ -608,9 +610,12 @@ class custom_vjp(Generic[ReturnValue]):
fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd)
args_flat, in_tree = tree_flatten(dyn_args)
in_avals = [core.get_aval(x) for x in args_flat]
if config.mutable_array_checks.value:
f_ = _check_primal_refs(f_, self.nondiff_argnums)
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd_, self.symbolic_zeros, primal_name,
fwd_name, in_tree, out_type)
flat_fwd, out_trees = _flatten_fwd(
fwd_, self.nondiff_argnums, self.symbolic_zeros, primal_name,
fwd_name, in_tree, out_type)
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
*args_flat, out_trees=out_trees,
@ -618,6 +623,37 @@ class custom_vjp(Generic[ReturnValue]):
_, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees)
return tree_unflatten(out_tree, out_flat)
@lu.transformation2
def _check_primal_refs(f, nondiff_argnums, *args):
_check_for_aliased_refs(f, nondiff_argnums, args)
out = f(*args)
_check_for_returned_refs(f, out, 'primal')
return out
def _check_for_aliased_refs(f, nondiff_argnums, args):
leaves = tree_leaves(args)
refs: dict[int, int] = {}
for i, x in enumerate(leaves):
if (isinstance((a := core.get_aval(x)), AbstractRef) and
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
arg_names = _arg_names(fun_signature(f), args, {}, nondiff_argnums, ())
if arg_names is None:
arg_names = [f'flat index {j}' for j in range(len(leaves))]
raise ValueError(
"only one reference to a mutable array may be passed as an argument "
f"to a function, but custom_vjp function {f} got the same mutable "
f"array reference of type {a.str_short()} at {arg_names[dup_idx]} and"
f" {arg_names[i]}.")
def _check_for_returned_refs(f, out, kind):
leaves = tree_leaves_with_path(out)
for path, leaf in leaves:
if isinstance((a := core.get_aval(leaf)), AbstractRef):
loc = f' at output tree path {keystr(path)}' if path else ''
raise ValueError(f"custom_vjp {kind} function {f} returned a mutable "
f"a array reference of type {a.str_short()}{loc}, "
"but mutable array references cannot be returned.")
@dataclasses.dataclass
class CustomVJPPrimal:
"""Primal to a ``custom_vjp``'s forward rule when ``symbolic_zeros`` is set"""
@ -655,14 +691,18 @@ def _check_for_tracers(x):
raise UnexpectedTracerError(msg)
@partial(lu.transformation_with_aux2, use_eq_store=True)
def _flatten_fwd(f, store, symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
*args):
def _flatten_fwd(f, store, nondiff_argnums, symbolic_zeros, primal_name,
fwd_name, in_tree, maybe_out_type, *args):
if symbolic_zeros:
args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])]
else:
args = args[::2]
py_args = tree_unflatten(in_tree, args)
if config.mutable_array_checks.value:
_check_for_aliased_refs(f, nondiff_argnums, py_args)
pair_out = f(*py_args)
if config.mutable_array_checks.value:
_check_for_returned_refs(f, pair_out, 'fwd')
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
"must produce a pair (list or tuple of length two) where the first "
@ -1393,8 +1433,8 @@ def optimize_remat_of_custom_vjp_fwd(
fwd_ = lu.wrap_init(fwd)
args_flat, in_tree = tree_flatten(dyn_args)
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd_, False, primal_name, fwd_name,
in_tree, out_type)
flat_fwd, out_trees = _flatten_fwd(fwd_, nondiff_argnums, False,
primal_name, fwd_name, in_tree, out_type)
flat_fwd = _fix_fwd_args(flat_fwd)
in_avals = [core.get_aval(x) for x in args_flat]

View File

@ -300,8 +300,9 @@ class _DebugPrintFormatChecker(string.Formatter):
formatter = _DebugPrintFormatChecker()
def _format_print_callback(fmt: str, *args, **kwargs):
sys.stdout.write(fmt.format(*args, **kwargs) + "\n")
def _format_print_callback(fmt: str, np_printoptions, *args, **kwargs):
with np.printoptions(**np_printoptions):
sys.stdout.write(fmt.format(*args, **kwargs) + "\n")
def debug_print(fmt: str, *args, ordered: bool = False, **kwargs) -> None:
"""Prints values and works in staged out JAX functions.
@ -338,8 +339,8 @@ def debug_print(fmt: str, *args, ordered: bool = False, **kwargs) -> None:
# Check that we provide the correct arguments to be formatted.
formatter.format(fmt, *args, **kwargs)
debug_callback(functools.partial(_format_print_callback, fmt), *args,
**kwargs, ordered=ordered)
debug_callback(functools.partial(_format_print_callback, fmt, np.get_printoptions()),
*args, **kwargs, ordered=ordered)
# Sharding visualization

View File

@ -81,6 +81,15 @@ class State:
raise ValueError('Number of processes must be defined.')
if process_id is None:
raise ValueError('The process id of the current process must be defined.')
if not isinstance(process_id, int):
raise TypeError("process_id must be a nonnegative int. "
f"Got process_id={process_id} of type {type(process_id)}.")
if not isinstance(num_processes, int):
raise TypeError("num_processes must be a positive int. "
f"Got num_processes={num_processes} of type {type(num_processes)}.")
if not (0 <= process_id < num_processes):
raise ValueError("process_id and num_processes must be nonnegative, with process_id < num_processes. "
f"Got process_id={process_id}, num_processes={num_processes}.")
self.coordinator_address = coordinator_address

View File

@ -205,6 +205,21 @@ _default_types: dict[str, type[Any]] = {
'c': complex_,
}
def bit_width(dtype: DTypeLike) -> int:
"""Number of bits per element for the dtype."""
# Note: we cannot use dtype.itemsize here because this is
# incorrect for sub-byte integer types.
if dtype == np.dtype(bool):
return 8 # physical bit layout for boolean dtype
elif issubdtype(dtype, np.integer):
return iinfo(dtype).bits
elif issubdtype(dtype, np.floating):
return finfo(dtype).bits
elif issubdtype(dtype, np.complexfloating):
return 2 * finfo(dtype).bits
else:
raise ValueError(f"unexpected input: {dtype=}")
# Trivial vectorspace datatype needed for tangent values of int/bool primals
float0: np.dtype = np.dtype([('float0', np.void, 0)])
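
As a hedged aside on the `itemsize` comment in `bit_width` above: sub-byte dtypes report a whole-byte `itemsize`, so the logical width has to come from `iinfo`. A small sketch via the public `jax.numpy` aliases (assuming `int4` support is present in your build):

```python
import numpy as np
import jax.numpy as jnp

# itemsize is rounded up to whole bytes, over-reporting sub-byte widths.
print(np.dtype(jnp.int4).itemsize)  # 1 (byte)
print(jnp.iinfo(jnp.int4).bits)     # 4, the logical element width
print(jnp.iinfo(jnp.int32).bits)    # 32
```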

View File

@ -115,7 +115,6 @@ def _earray_shard_arg_handler(xs, shardings, layouts, copy_semantics):
return pxla.shard_args(phys_shardings, layouts, copy_semantics, arrs)
pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler
core.shaped_abstractify_handlers[EArray] = lambda self: self.aval
core.pytype_aval_mappings[EArray] = lambda x: x.aval
xla.canonicalize_dtype_handlers[EArray] = lambda x: x
tree_util.dispatch_registry.register_node(

View File

@ -20,7 +20,7 @@
// 3. Add back the licence comment at the start
//
namespace jax.export.serialization;
namespace jax_export.serialization;
enum PyTreeDefKind: byte {
leaf = 0,

View File

@ -12,13 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
import numpy as np
from jax import lax
import jax.numpy as jnp
from jax._src.lax import lax as lax_internal
from jax._src import dtypes
from jax._src.tree_util import tree_flatten, tree_unflatten
from jax._src.util import safe_zip, unzip2, HashablePartial
@ -83,7 +82,8 @@ def _unravel_list(indices, shapes, from_dtypes, to_dtype, arr):
raise TypeError(f"unravel function given array of dtype {arr_dtype}, "
f"but expected dtype {to_dtype}")
chunks = jnp.split(arr, indices[:-1])
with warnings.catch_warnings():
warnings.simplefilter("ignore") # ignore complex-to-real cast warning
return [lax.convert_element_type(chunk.reshape(shape), dtype)
for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)]
return [
lax_internal._convert_element_type(chunk.reshape(shape), dtype,
warn_on_complex_to_real_cast=False)
for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)
]

View File

@ -539,6 +539,9 @@ class JVPTracer(Tracer):
def to_concrete_value(self):
return core.to_concrete_value(self.primal)
def get_referent(self):
return core.get_referent(self.primal)
def _primal_tangent_shapes_match(primal, tangent):
if type(tangent) is not Zero:
primal_aval = get_aval(primal).strip_weak_type()

View File

@ -1569,10 +1569,7 @@ class DynamicJaxprTracer(core.Tracer):
val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self)))
return self if val is None else get_referent(val)
def _dynamic_jaxpr_tracer_shaped_abstractify(x):
return x.aval
core.shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify
core.pytype_aval_mappings[DynamicJaxprTracer] = lambda x: x.aval
def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
sentinel = object()
@ -2010,8 +2007,8 @@ class DynamicJaxprTrace(core.Trace):
def fwd_jaxpr_from_zeros(*zeros):
for store in fwd.stores: store and store.reset()
fwd_ = _interleave_fun(fwd, zeros)
jaxpr, _, consts, atr = trace_to_jaxpr_dynamic(fwd_, in_avals)
if atr: raise NotImplementedError
jaxpr, _, consts, attrs = trace_to_jaxpr_dynamic(fwd_, in_avals)
if attrs: raise NotImplementedError
return jaxpr, consts
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
@ -2154,14 +2151,14 @@ def trace_to_jaxpr_dynamic(
ans = fun.call_wrapped(*in_tracers)
out_tracers = map(trace.to_jaxpr_tracer, ans)
_check_no_refs(debug_info, out_tracers)
_check_no_returned_refs(debug_info, out_tracers)
jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers)
del trace, fun, in_tracers, out_tracers, ans
config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked
def _check_no_refs(
def _check_no_returned_refs(
dbg: lu.TracingDebugInfo | None,
out_tracers: Sequence[DynamicJaxprTracer]
) -> None:

View File

@ -38,6 +38,7 @@ from jax._src.lax.control_flow.conditionals import (
cond_p as cond_p,
switch as switch,
platform_dependent as platform_dependent,
platform_index_p as platform_index_p,
)
from jax._src.lax.control_flow.solves import (
custom_linear_solve as custom_linear_solve,

View File

@ -89,13 +89,6 @@ def _initial_style_jaxprs_with_common_consts(
jaxprs, all_consts, all_out_trees, all_attrs_tracked = zip(*jaxpr_data)
all_const_avals = [map(core.get_aval, consts) for consts in all_consts]
# If we get a `Ref` in the consts, we know it must come from an outer
# `run_state`. We also know it shouldn't be boxed up in another tracer.
# We assert that it is in fact a DynamicJaxprTracer
for consts, consts_avals in zip(all_consts, all_const_avals):
for c, aval in zip(consts, consts_avals):
if isinstance(aval, state.AbstractRef):
assert isinstance(c, pe.DynamicJaxprTracer)
# TODO(sharadmv,mattjj): we could dedup *all consts* instead of just the Refs.

View File

@ -25,6 +25,8 @@ from typing import Any, TypeVar
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
from jax._src.api_util import (
_check_no_aliased_ref_args, _check_no_aliased_closed_over_refs)
from jax._src import config
from jax._src import core
from jax._src import dispatch
@ -136,8 +138,14 @@ def switch(index, branches: Sequence[Callable], *operands,
ops, ops_tree = tree_flatten(operands)
ops_avals = tuple(map(core.get_aval, ops))
if config.mutable_array_checks.value:
dbg = pe.debug_info(branches[0], ops_tree, None, False, 'switch')
_check_no_aliased_ref_args(dbg, ops_avals, ops)
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
branches, ops_tree, ops_avals, primitive_name='switch')
if config.mutable_array_checks.value:
_check_no_aliased_closed_over_refs(dbg, (*jaxprs[0].consts, *consts), ops)
for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
_check_tree_and_avals(f"branch 0 and {i + 1} outputs",
out_trees[0], jaxprs[0].out_avals,
@ -228,11 +236,14 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
ops, ops_tree = tree_flatten(operands)
ops_avals = tuple(map(core.get_aval, ops))
if config.mutable_array_checks.value:
dbg = pe.debug_info(true_fun, ops_tree, None, False, 'cond')
_check_no_aliased_ref_args(dbg, ops_avals, ops)
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
if any(isinstance(op_aval, AbstractRef) for op_aval in ops_avals):
raise ValueError("Cannot pass `Ref`s into `cond`.")
true_jaxpr, false_jaxpr = jaxprs
if config.mutable_array_checks.value:
_check_no_aliased_closed_over_refs(dbg, (*true_jaxpr.consts, *consts), ops)
out_tree, false_out_tree = out_trees
if any(isinstance(out_aval, AbstractRef) for out_aval in
@ -934,6 +945,7 @@ def platform_dependent(*args: Any,
platform_index = platform_index_p.bind(
platforms=tuple(tuple(ps) for ps in platforms_lists),
has_default=(default is not None))
if default is not None:
branches = branches + (default,)
# Use a switch, to get the proper transformation rules for free. Since
@ -946,6 +958,8 @@ def platform_dependent(*args: Any,
# recognized on the compilation platform. Detect eager mode and keep only the
# needed branch.
try:
# Note/TODO(mvoz): This actually rarely seems to concretize - we could look into
# core.ensure_compile_time_eval to get better single-branch selection.
platform_index_concrete = core.concrete_or_error(operator.index, platform_index)
except core.ConcretizationTypeError:
return switch(platform_index, branches, *args)

View File

@ -546,7 +546,8 @@ def _convert_element_type(
operand: ArrayLike,
new_dtype: DTypeLike | dtypes.ExtendedDType | None = None,
weak_type: bool = False,
sharding: Sharding | None = None):
sharding: Sharding | None = None,
warn_on_complex_to_real_cast: bool = True):
if hasattr(operand, '__jax_array__'):
operand = operand.__jax_array__()
@ -585,7 +586,8 @@ def _convert_element_type(
isinstance(operand, Array)):
sharding = operand.sharding
if (dtypes.issubdtype(old_dtype, np.complexfloating) and
if (warn_on_complex_to_real_cast and
dtypes.issubdtype(old_dtype, np.complexfloating) and
not dtypes.issubdtype(new_dtype, np.complexfloating)):
msg = "Casting complex values to real discards the imaginary part"
warnings.warn(msg, NumpyComplexWarning, stacklevel=2)
@ -3197,12 +3199,15 @@ def _convert_elt_type_folding_rule(consts, eqn):
# TODO(mattjj): allow constant-folding CPU-backed JAX arrays
c, = consts
o, = eqn.outvars
new_dtype = eqn.params['new_dtype']
if (type(c) in {np.ndarray, *dtypes.python_scalar_dtypes} and
isinstance(o.aval, core.UnshapedArray) and not np.shape(c) and
not dtypes.issubdtype(eqn.params['new_dtype'], dtypes.extended)):
with warnings.catch_warnings():
warnings.simplefilter('ignore', util.NumpyComplexWarning)
out = np.array(c).astype(eqn.params['new_dtype'])
not dtypes.issubdtype(new_dtype, dtypes.extended)):
out = np.array(c)
if (dtypes.issubdtype(out.dtype, np.complexfloating) and
not dtypes.issubdtype(new_dtype, np.complexfloating)):
out = out.real
out = out.astype(new_dtype)
if not o.aval.weak_type:
return [out], None
out = out.item()
@ -3367,18 +3372,21 @@ def _bitcast_convert_type_shape_rule(operand, *, new_dtype):
old_dtype = dtypes.canonicalize_dtype(operand.dtype)
new_dtype = dtypes.canonicalize_dtype(new_dtype)
if old_dtype.itemsize == new_dtype.itemsize:
old_nbits = dtypes.bit_width(old_dtype)
new_nbits = dtypes.bit_width(new_dtype)
if old_nbits == new_nbits:
return operand.shape
elif old_dtype.itemsize > new_dtype.itemsize:
return (*operand.shape, old_dtype.itemsize // new_dtype.itemsize)
elif old_nbits > new_nbits:
return (*operand.shape, old_nbits // new_nbits)
else:
dim_size = operand.shape[-1] if operand.shape else 1
if dim_size * old_dtype.itemsize != new_dtype.itemsize:
if dim_size * old_nbits != new_nbits:
raise ValueError(
f"Attempting to convert array of shape {operand.shape} "
f"from {old_dtype} of size {old_dtype.itemsize} "
f"to {new_dtype} of size {new_dtype.itemsize}, "
f"but {dim_size} * {old_dtype.itemsize} != {new_dtype.itemsize}")
f"from {old_dtype} of size {old_nbits} bits "
f"to {new_dtype} of size {new_nbits}, bits "
f"but {dim_size} * {old_nbits} != {new_nbits}")
return operand.shape[:-1]
def _bitcast_convert_type_dtype_rule(operand, *, new_dtype):

View File

@ -490,13 +490,45 @@ def ragged_all_to_all(operand, output, input_offsets, send_sizes, output_offsets
# Index 'data' at 'offsets'[2], 'sizes'[2]'
{m,n,o},{p,q,r},{s,t,u},{v,w,x}
``output_offsets`` must be sharded in a way that each replica has offsets
expressed from the target replica's output perspective.
For the i-th output offset, the current replica will send
`operand[input_offsets[i]:input_offsets[i]+send_sizes[i]]` to the `i`-th
replica, where it will be written to
`output_i[output_offsets[i]:output_offsets[i]+send_sizes[i]]` in that
replica's ``output``.
For example, if we have 2 replicas:
replica 0:
operand: [1, 2, 2]
output: [0, 0, 0, 0]
input_offsets: [0, 1]
send_sizes: [1, 2]
output_offsets: [0, 0]
recv_sizes: [1, 1]
replica 1:
operand: [3, 4, 0]
output: [0, 0, 0, 0]
input_offsets: [0, 1]
send_sizes: [1, 1]
output_offsets: [1, 2]
recv_sizes: [2, 1]
replica 0's result will be: [1, 3, 0, 0]
replica 1's result will be: [2, 2, 4, 0]
Args:
operand: array with ragged dimension along its outermost dimension.
output: array of ragged output data.
input_offsets: array of ragged input offsets.
send_sizes: array of ragged input send sizes.
output_offsets: array of ragged output offsets.
output_offsets: array of ragged offsets in the target replica output.
recv_sizes: array of ragged output receive sizes.
Returns:
array with shape equal to ``output``.
"""

View File

@ -173,8 +173,40 @@ lt = np.less
def convert_element_type(operand, dtype):
return np.asarray(operand, dtype=dtype)
def _bitcast_uint4_to_uint8(operand):
# Note: assumes little-endian byte order.
assert operand.dtype == 'uint4'
operand = operand.astype('uint8')
return operand[..., ::2] + (operand[..., 1::2] << 4)
def _bitcast_uint8_to_uint4(operand):
# Note: assumes little-endian byte order.
assert operand.dtype == 'uint8'
result = np.zeros((*operand.shape[:-1], operand.shape[-1] * 2), dtype='uint4')
result[..., ::2] = (operand & 0b00001111).astype('uint4')
result[..., 1::2] = ((operand & 0b11110000) >> 4).astype('uint4')
return result
def bitcast_convert_type(operand, dtype):
return np.asarray(operand).view(dtype)
operand = np.asarray(operand)
nbits_in = dtypes.bit_width(operand.dtype)
nbits_out = dtypes.bit_width(dtype)
if nbits_out > nbits_in:
assert operand.shape[-1] == nbits_out // nbits_in
out_shape = operand.shape[:-1]
elif nbits_out == nbits_in:
out_shape = operand.shape
else:
out_shape = (*operand.shape, nbits_in // nbits_out)
# Special handling for 4-bit integers.
if nbits_in == 4:
operand = _bitcast_uint4_to_uint8(operand.view('uint4'))
if nbits_out == 4:
operand = _bitcast_uint8_to_uint4(operand.view('uint8'))
return operand.view(dtype).reshape(out_shape)
def clamp(min, operand, max):
return np.clip(operand, np.clip(min, None, max), max).astype(operand.dtype)

View File

@ -710,7 +710,6 @@ def one_hot(x: Any, num_classes: int, *,
'jax-nn-one-hot-float-input',
f"jax.nn.one_hot input should be integer-typed; got dtype={x_arr.dtype}",
stacklevel=1)
x_arr = x_arr.astype('int32')
return _one_hot(x_arr, num_classes, dtype=dtype, axis=axis)

View File

@ -509,12 +509,15 @@ def _view(self: Array, dtype: DTypeLike | None = None, type: None = None) -> Arr
dtypes.check_user_dtype_supported(dtype, "view")
dtype = dtypes.canonicalize_dtype(dtype)
nbits_in = dtypes.bit_width(self.dtype)
nbits_out = dtypes.bit_width(dtype)
if self.ndim == 0:
if self.dtype.itemsize != dtype.itemsize:
if nbits_in != nbits_out:
raise ValueError("view() of a 0d array is only supported if the itemsize is unchanged.")
return _view(lax.expand_dims(self, (0,)), dtype).squeeze()
if (self.shape[-1] * self.dtype.itemsize) % dtype.itemsize != 0:
if (self.shape[-1] * nbits_in) % nbits_out != 0:
raise ValueError("When changing to a larger dtype, its size must be a divisor "
"of the total size in bytes of the last axis of the array.")
@ -543,16 +546,15 @@ def _view(self: Array, dtype: DTypeLike | None = None, type: None = None) -> Arr
# lax.bitcast_convert_type adds or subtracts dimensions depending on the
# relative bitwidths of the dtypes; we account for that with reshapes.
if self.dtype.itemsize < dtype.itemsize:
factor = dtype.itemsize // self.dtype.itemsize
if nbits_in < nbits_out:
factor = nbits_out // nbits_in
out = self.reshape(*self.shape[:-1], self.shape[-1] // factor, factor)
return lax.bitcast_convert_type(out, dtype)
if self.dtype.itemsize > dtype.itemsize:
elif nbits_in > nbits_out:
out = lax.bitcast_convert_type(self, dtype)
return out.reshape(*out.shape[:-2], out.shape[-2] * out.shape[-1])
return lax.bitcast_convert_type(self, dtype)
else:
return lax.bitcast_convert_type(self, dtype)
def _notimplemented_flat(self):

View File

@ -191,7 +191,7 @@ class _ScalarMeta(type):
def _abstractify_scalar_meta(x):
raise TypeError(f"JAX scalar type {x} cannot be interpreted as a JAX array.")
core.shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta
core.pytype_aval_mappings[_ScalarMeta] = _abstractify_scalar_meta
def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
meta = _ScalarMeta(np_scalar_type.__name__, (object,),
@ -5731,10 +5731,9 @@ def astype(x: ArrayLike, dtype: DTypeLike | None,
# We offer a more specific warning than the usual ComplexWarning so we prefer
# to issue our warning.
with warnings.catch_warnings():
warnings.simplefilter("ignore", ComplexWarning)
result = lax_internal._convert_element_type(
x_arr, dtype, sharding=_normalize_to_sharding(device))
result = lax_internal._convert_element_type(
x_arr, dtype, sharding=_normalize_to_sharding(device),
warn_on_complex_to_real_cast=False)
return _array_copy(result) if copy else result
@ -8341,18 +8340,41 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0,
Array([4, 8], dtype=int32)
"""
util.check_arraylike("diagonal", a)
a_shape = shape(a)
if ndim(a) < 2:
raise ValueError("diagonal requires an array of at least two dimensions.")
offset = core.concrete_or_error(operator.index, offset, "'offset' argument of jnp.diagonal()")
a = moveaxis(a, (axis1, axis2), (-2, -1))
def _default_diag(a):
a_shape = shape(a)
diag_size = max(0, min(a_shape[axis1] + min(offset, 0),
a_shape[axis2] - max(offset, 0)))
i = arange(diag_size)
j = arange(abs(offset), abs(offset) + diag_size)
return a[..., i, j] if offset >= 0 else a[..., j, i]
a = moveaxis(a, (axis1, axis2), (-2, -1))
diag_size = max(
0, min(a_shape[axis1] + min(offset, 0), a_shape[axis2] - max(offset, 0))
)
i = arange(diag_size)
j = arange(abs(offset), abs(offset) + diag_size)
return a[..., i, j] if offset >= 0 else a[..., j, i]
# The mosaic lowering rule for diag is only defined for square arrays.
# TODO(mvoz): Add support for offsets.
if shape(a)[0] != shape(a)[1] or ndim(a) != 2 or offset != 0 or _dtype(a) == bool_:
return _default_diag(a)
else:
a_shape_eye = eye(shape(a)[0], dtype=_dtype(a))
def _mosaic_diag(a):
def _sum(x, axis):
return lax.reduce(
x,
np.array(0, _dtype(x)),
lax.add if _dtype(x) != bool_ else lax.bitwise_or,
(axis,),
)
return _sum(lax.mul(a_shape_eye, a), axis=0)
return lax.platform_dependent(a, default=_default_diag, mosaic=_mosaic_diag)
@export

View File

@ -924,7 +924,7 @@ def pinv(a: ArrayLike, rtol: ArrayLike | None = None,
- :func:`jax.numpy.linalg.inv`: multiplicative inverse of a square matrix.
Notes:
:func:`jax.numpy.linalg.prng` differs from :func:`numpy.linalg.prng` in the
:func:`jax.numpy.linalg.pinv` differs from :func:`numpy.linalg.pinv` in the
default value of ``rcond``: in NumPy, the default is `1e-15`. In JAX, the
default is ``10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps``.

View File

@ -20,7 +20,6 @@ from functools import partial
import math
import operator
from typing import overload, Any, Literal, Protocol, Union
import warnings
import numpy as np
@ -37,7 +36,7 @@ from jax._src.lax import lax as lax_internal
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg
from jax._src.util import (
canonicalize_axis as _canonicalize_axis, maybe_named_axis,
set_module, NumpyComplexWarning)
set_module)
export = set_module('jax.numpy')
@ -100,7 +99,7 @@ ReductionOp = Callable[[Any, Any], Any]
def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike,
*, has_identity: bool = True,
preproc: Callable[[ArrayLike], ArrayLike] | None = None,
preproc: Callable[[Array], Array] | None = None,
bool_op: ReductionOp | None = None,
upcast_f16_for_computation: bool = False,
axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
@ -201,16 +200,15 @@ def _reduction_init_val(a: ArrayLike, init_val: Any) -> np.ndarray:
sign, info = np.sign(init_val), dtypes.iinfo(a_dtype)
return np.array(info.min if sign < 0 else info.max, dtype=a_dtype)
def _cast_to_bool(operand: ArrayLike) -> Array:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=NumpyComplexWarning)
return lax.convert_element_type(operand, np.bool_)
def _cast_to_bool(operand: Array) -> Array:
if dtypes.issubdtype(operand.dtype, np.complexfloating):
operand = operand.real
return lax.convert_element_type(operand, np.bool_)
def _cast_to_numeric(operand: ArrayLike) -> Array:
def _cast_to_numeric(operand: Array) -> Array:
return promote_dtypes_numeric(operand)[0]
def _require_integer(operand: ArrayLike) -> Array:
arr = lax_internal.asarray(operand)
def _require_integer(arr: Array) -> Array:
if not dtypes.isdtype(arr, ("bool", "integral")):
raise ValueError(f"integer argument required; got dtype={arr.dtype}")
return arr

View File

@ -2765,6 +2765,12 @@ def log2(x: ArrayLike, /) -> Array:
Array([-2., -1., 0., 1., 2., 3.], dtype=float32)
"""
x, = promote_args_inexact("log2", x)
if dtypes.issubdtype(x.dtype, np.complexfloating):
r = lax.log(x)
re = lax.real(r)
im = lax.imag(r)
ln2 = lax.log(_constant_like(re, 2))
return lax.complex(lax.div(re, ln2), lax.div(im, ln2))
return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))
@ -2789,6 +2795,12 @@ def log10(x: ArrayLike, /) -> Array:
[-2. -1. 0. 1. 2. 3.]
"""
x, = promote_args_inexact("log10", x)
if dtypes.issubdtype(x.dtype, np.complexfloating):
r = lax.log(x)
re = lax.real(r)
im = lax.imag(r)
ln10 = lax.log(_constant_like(re, 10))
return lax.complex(lax.div(re, ln10), lax.div(im, ln10))
return lax.div(lax.log(x), lax.log(_constant_like(x, 10)))

View File

@ -547,9 +547,13 @@ def lower_jaxpr_to_module(
module_name = name_and_src_info.name
attrs["sym_name"] = ir.StringAttr.get(module_name)
sym_tab = ir.SymbolTable(m.operation)
func_op = lower_jaxpr_to_func(
ctx, jaxpr, mosaic_grid_mapping=mosaic_grid_mapping,
name="main", for_verification=for_verification,
ctx,
jaxpr,
mosaic_grid_mapping=mosaic_grid_mapping,
name="main",
for_verification=for_verification,
)
m.body.append(func_op)
sym_tab.insert(func_op)
@ -568,6 +572,7 @@ def lower_jaxpr_to_module(
# We checked above that the block does not require windowing.
window_params.append(ir.DictAttr.get())
continue
mlir_func = lower_jaxpr_to_transform_func(
ctx,
bm.index_map_jaxpr.jaxpr,
@ -1990,6 +1995,36 @@ lowering_rules[ad_util.add_any_p] = _add_lowering_rule
skip_mlir_conversions.add(ad_util.add_any_p)
class FoldingError(Exception):
pass
def _fold_and_get_constant_value(x):
def _fold(x, fuel):
if fuel <= 0:
raise FoldingError("Folding depth exceeded")
op_name = getattr(x.owner, "name", None)
binop_folds = {
"arith.maxsi": max,
"arith.minsi": min,
}
if op_name == "arith.constant":
if ir.IntegerType.isinstance(x.type):
return ir.IntegerAttr(x.owner.attributes["value"]).value
elif ir.FloatType.isinstance(x.type):
return ir.FloatAttr(x.owner.attributes["value"]).value
else:
raise ValueError(f"Unsupported constant type: {x.type}")
if op_name in binop_folds:
return binop_folds[op_name](_fold(v, fuel - 1) for v in x.owner.operands)
raise FoldingError(f"Folding not supported for {x.owner}")
try:
return _fold(x, 10)
except FoldingError:
return None
def _max_lowering_rule(ctx: LoweringRuleContext, x, y):
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
(aval_out,) = ctx.avals_out
@ -2708,6 +2743,12 @@ lowering_rules[lax.while_p] = _while_lowering_rule
def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches):
index, *args = args
constant_index = _fold_and_get_constant_value(index)
if constant_index is not None:
return jaxpr_subcomp(
ctx.lowering_context.replace(block_shapes=ctx.block_shapes[1:]), branches[constant_index].jaxpr, *args
)
out_types = map(aval_to_ir_type, ctx.avals_out)
pred = arith.cmpi(
arith.CmpIPredicate.ne, index, ir_constant(0, index.type)
@ -3375,3 +3416,25 @@ def _pad_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs):
lowering_rules[lax.pad_p] = _pad_lowering_rule
def _platform_index_lowering(
ctx: mlir.LoweringRuleContext,
*,
platforms: Sequence[Sequence[str]],
has_default: bool,
):
for i, ps in enumerate(platforms):
# note - slightly odd structure here, as platforms is a seq[seq[str]]
if "mosaic" in ps:
return ir_constant(i)
if has_default:
return ir_constant(len(platforms))
raise NotImplementedError(
"No mosaic or default platform indexing rule found."
)
lowering_rules[jax._src.lax.control_flow.platform_index_p] = _platform_index_lowering

View File

@ -797,25 +797,30 @@ def lower_jaxpr_to_module(
# Each range is 2 events, each event is 4 bytes.
prof_spec = mgpu_profiler.ProfilerSpec(prof_space * 2 * 4)
prof_ctx = ProfilerContext(params["profile_dir"], prof_spec)
module, out_structs_gmem, _ = mgpu_core._lower_as_gpu_kernel(
body,
grid=parallel_grid,
cluster=(),
block=block,
in_shapes=in_structs_gmem,
out_shape=out_structs_gmem,
smem_scratch_shape=(
(*in_structs_smem, *out_structs_smem),
*extra_smem_scratch,
(
mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps),
rs.barriers,
extra_barriers,
module, out_structs_gmem, _, launch_ctx, scratch_arr = (
mgpu_core._lower_as_gpu_kernel(
body,
grid=parallel_grid,
cluster=(),
block=block,
in_shapes=in_structs_gmem,
out_shape=out_structs_gmem,
smem_scratch_shape=(
(*in_structs_smem, *out_structs_smem),
*extra_smem_scratch,
(
mgpu.Barrier(
arrival_count=1, num_barriers=max_concurrent_steps
),
rs.barriers,
extra_barriers,
),
),
),
module_name=name_and_src_info.name,
prof_spec=prof_spec,
module_name=name_and_src_info.name,
prof_spec=prof_spec,
)
)
mgpu_core._initialize_scratch(launch_ctx, scratch_arr)
return LoweringResult(
module, parallel_grid, block, out_structs_gmem, prof_ctx
@ -1572,8 +1577,11 @@ def _lower_while_via_fori(
body_nconsts,
):
assert not fori_jaxpr.constvars
# The pattern matcher looks for conditions with no constants.
assert cond_nconsts == 0
# Reflect the changes of the pattern matcher to the context.
lb_aval, ub_aval, *_ = ctx.avals_in[cond_nconsts + body_nconsts:]
ctx = ctx.replace(
avals_in=(
*ctx.avals_in[cond_nconsts:body_nconsts],
@ -1585,7 +1593,6 @@ def _lower_while_via_fori(
_, consts, (lb, ub, *args) = util.split_list(
args, [cond_nconsts, body_nconsts]
)
lb_aval, ub_aval, *_ = ctx.avals_in[body_nconsts:]
lb = _ensure_ir_value(lb, lb_aval.dtype)
ub = _ensure_ir_value(ub, ub_aval.dtype)
for_out = _lower_jaxpr_to_for_loop(
@ -1780,29 +1787,21 @@ def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray:
if isinstance(x, mgpu.FragmentedArray):
assert x.mlir_dtype == mgpu_utils.dtype_to_ir_type(dtype)
return x
elif isinstance(x, (np.number, np.ndarray, int, float)):
return mgpu.FragmentedArray.splat(
_ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype)),
(),
is_signed=mgpu_utils.is_signed(dtype),
)
elif isinstance(x, ir.Value):
if isinstance(x.type, (ir.IntegerType, ir.FloatType, ir.IndexType)):
assert x.type == mgpu_utils.dtype_to_ir_type(dtype)
return mgpu.FragmentedArray.splat(x, (), is_signed=mgpu_utils.is_signed(dtype))
raise NotImplementedError(f"Unsupported type: {type(x)}")
return mgpu.FragmentedArray.splat(
_ensure_ir_value(x, dtype), (), is_signed=mgpu_utils.is_signed(dtype)
)
def _ensure_ir_value(x: object, dtype: jnp.dtype) -> ir.Value:
if isinstance(x, ir.Value):
assert x.type == mgpu_utils.dtype_to_ir_type(dtype)
return x
elif isinstance(x, (np.number, np.ndarray, int, float)):
return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype))
elif isinstance(x, mgpu.FragmentedArray):
assert x.mlir_dtype == mgpu_utils.dtype_to_ir_type(dtype)
if isinstance(x.layout, mgpu.WGSplatFragLayout):
return x.registers.item()
raise NotImplementedError(f"Unsupported type: {type(x)}")
raise NotImplementedError(f"Unsupported layout: {x.layout}")
return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype))
def _ir_constant(v: object, t: ir.Type) -> ir.Value:

View File

@ -16,6 +16,7 @@
from __future__ import annotations
from collections.abc import Sequence
import enum
import math
from typing import Any, Literal
@ -25,7 +26,6 @@ from jax._src import core as jax_core
from jax._src import state
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith as arith_dialect
from jax._src.lib.mlir.dialects import llvm as llvm_dialect
@ -33,23 +33,42 @@ from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic_gpu import core as gpu_core
from jax._src.pallas.mosaic_gpu import lowering
from jax._src.pallas.mosaic_gpu.core import state_types
from jax._src.state import discharge
from jax._src.state import indexing
from jax._src.state import primitives as state_primitives
import jax.experimental.mosaic.gpu as mgpu
from jax.experimental.mosaic import gpu as mgpu
from jax.experimental.mosaic.gpu import utils as mgpu_utils
import jax.numpy as jnp
WARPGROUP_SIZE = 128
_Ref = pallas_core.AbstractMemoryRef | state_types.TransformedRef
def _check_ref(
aval: object, name: str, memory_space: gpu_core.GPUMemorySpace
) -> None:
if not isinstance(aval, state_types.AbstractRef):
raise TypeError(f"{name} must be a reference, got {aval}")
aval_memory_space = getattr(aval, "memory_space", None) or gpu_core.GMEM
if aval_memory_space is not memory_space:
raise ValueError(
f"{name} must be a {memory_space.name.upper()} reference, got {aval}"
)
copy_smem_to_gmem_p = jax_core.Primitive("copy_smem_to_gmem")
copy_smem_to_gmem_p.multiple_results = True
@copy_smem_to_gmem_p.def_effectful_abstract_eval
def _copy_smem_to_gmem_abstract_eval(*avals, **params):
del avals, params # Unused.
def _copy_smem_to_gmem_abstract_eval(src, dst, *args, **params):
_check_ref(src, "src", gpu_core.SMEM)
_check_ref(dst, "dst", gpu_core.GMEM)
del args, params # Unused.
return (), {state.ReadEffect(0), state.WriteEffect(1)}
@ -114,9 +133,7 @@ def _extract_smem_copy_params(transforms):
def copy_smem_to_gmem(
src: pallas_core.AbstractMemoryRef,
dst: pallas_core.AbstractMemoryRef,
predicate: jax.Array | None = None,
src: _Ref, dst: _Ref, predicate: jax.Array | None = None
) -> None:
"""Asynchronously copies a SMEM reference to a GMEM reference.
@ -130,10 +147,6 @@ def copy_smem_to_gmem(
:func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`
:func:`jax.experimental.mosaic.gpu.commit_smem`
"""
if src.memory_space is not gpu_core.SMEM:
raise TypeError(f"src must be a SMEM reference, got {src.memory_space}")
if getattr(dst, "memory_space", gpu_core.GMEM) is not gpu_core.GMEM:
raise ValueError(f"dst must be a GMEM reference, got {dst.memory_space}")
src, src_transforms = state_primitives.get_ref_and_transforms(
src, None, "copy_smem_to_gmem", force_trailing_indexer=False,
)
@ -164,8 +177,11 @@ copy_gmem_to_smem_p.multiple_results = True
@copy_gmem_to_smem_p.def_effectful_abstract_eval
def _copy_gmem_to_smem_abstract_eval(*avals, **params):
del avals, params # Unused.
def _copy_gmem_to_smem_abstract_eval(src, dst, barrier, *args, **params):
del args, params # Unused.
_check_ref(src, "src", gpu_core.GMEM)
_check_ref(dst, "dst", gpu_core.SMEM)
_check_ref(barrier, "barrier", gpu_core.SMEM)
return (), {state.ReadEffect(0), state.WriteEffect(1)}
@ -217,21 +233,13 @@ def _copy_gmem_to_smem_lowering(
return ()
def copy_gmem_to_smem(
src: pallas_core.AbstractMemoryRef,
dst: pallas_core.AbstractMemoryRef,
barrier: pallas_core.AbstractMemoryRef,
) -> None:
def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None:
"""Asynchronously copies a GMEM reference to a SMEM reference.
See also:
:func:`jax.experimental.mosaic.gpu.barrier_arrive`
:func:`jax.experimental.mosaic.gpu.barrier_wait`
"""
if getattr(src, "memory_space", gpu_core.GMEM) is not gpu_core.GMEM:
raise TypeError(f"src must be a GMEM reference, got {src.memory_space}")
if dst.memory_space is not gpu_core.SMEM:
raise ValueError(f"dst must be a SMEM reference, got {dst.memory_space}")
src, src_transforms = state_primitives.get_ref_and_transforms(
src, None, "copy_gmem_to_smem", force_trailing_indexer=False,
)
@ -291,8 +299,9 @@ barrier_arrive_p.multiple_results = True
@barrier_arrive_p.def_effectful_abstract_eval
def _barrier_arrive_abstract_eval(*avals, **params):
del avals, params # Unused.
def _barrier_arrive_abstract_eval(barrier, *args, **params):
del args, params # Unused.
_check_ref(barrier, "barrier", gpu_core.SMEM)
return (), {gpu_core._memory_effect}
@ -328,8 +337,9 @@ barrier_wait_p.multiple_results = True
@barrier_wait_p.def_effectful_abstract_eval
def _barrier_wait_abstract_eval(*avals, **params):
del avals, params # Unused.
def _barrier_wait_abstract_eval(barrier, *args, **params):
_check_ref(barrier, "barrier", gpu_core.SMEM)
del args, params # Unused.
return (), {gpu_core._memory_effect}
@ -703,38 +713,40 @@ def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout):
del layout, dimension
return jax_core.ShapedArray(shape, dtype)
@lowering.register_lowering_rule(broadcasted_iota_p)
def _broadcasted_iota_lowering(ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout):
del ctx
# Unsigned integers (as opposed to signless) cause MLIR verification
# errors so we only use signless like Mosaic GPU does.
#
# TODO(cperivol): use mgpu.utils.dtype_to_ir_type() instead.
mlir_dtype = (
ir.IntegerType.get_signless(dtype.itemsize * 8)
if jnp.issubdtype(dtype, jnp.integer)
else mlir.dtype_to_ir_type(dtype)
)
undef = llvm_dialect.mlir_undef(mlir_dtype)
is_signed = (
jnp.issubdtype(dtype, jnp.signedinteger)
if jnp.issubdtype(dtype, jnp.integer)
else None
)
i32 = ir.IntegerType.get_signless(32)
def _cast(x):
if ir.FloatType.isinstance(mlir_dtype):
x = arith_dialect.index_cast(i32, x)
return arith_dialect.uitofp(mlir_dtype, x)
else:
return arith_dialect.index_cast(mlir_dtype, x)
def _broadcasted_iota_lowering(
ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout
):
del ctx # Unused.
mlir_dtype = mgpu_utils.dtype_to_ir_type(dtype)
if ir.FloatType.isinstance(mlir_dtype):
i32 = ir.IntegerType.get_signless(32)
cast = lambda x: arith_dialect.uitofp(
mlir_dtype, arith_dialect.index_cast(i32, x)
)
else:
cast = lambda x: arith_dialect.index_cast(mlir_dtype, x)
is_signed = mgpu_utils.is_signed(dtype)
return mgpu.FragmentedArray.splat(
undef, shape, layout.value, is_signed=is_signed
llvm_dialect.mlir_undef(mlir_dtype),
shape,
layout.value,
is_signed=is_signed,
).foreach(
lambda _, idx: _cast(idx[dimension]), create_array=True, is_signed=is_signed
lambda _, idx: cast(idx[dimension]),
create_array=True,
is_signed=is_signed,
)
def broadcasted_iota(dtype, shape, dimension, *, layout: Layout | None = None):
return broadcasted_iota_p.bind(dtype=jnp.dtype(dtype), shape=shape, dimension=dimension, layout=layout)
def broadcasted_iota(
dtype: jax.typing.DTypeLike,
shape: Sequence[int],
dimension: int,
*,
layout: Layout | None = None,
) -> jax.Array:
return broadcasted_iota_p.bind(
dtype=jnp.dtype(dtype), shape=shape, dimension=dimension, layout=layout
)
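A minimal usage sketch of the primitives above inside a Pallas Mosaic GPU kernel. Only the copy/barrier calls are taken from this diff; the pallas_call wiring (the plgpu alias, plgpu.GMEM/plgpu.SMEM memory spaces, and the plgpu.Barrier scratch shape) is assumed from the surrounding Pallas API rather than defined here.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

def copy_kernel(x_gmem, o_gmem, scratch_smem, barrier):
  # GMEM -> SMEM, then wait until the async copy has arrived on the barrier.
  plgpu.copy_gmem_to_smem(x_gmem, scratch_smem, barrier)
  plgpu.barrier_wait(barrier)
  scratch_smem[...] = scratch_smem[...] + 1.0
  # Make the SMEM writes visible, then SMEM -> GMEM and wait for completion.
  plgpu.commit_smem()
  plgpu.copy_smem_to_gmem(scratch_smem, o_gmem)
  plgpu.wait_smem_to_gmem(0)

x = jnp.arange(256, dtype=jnp.float32)
y = pl.pallas_call(
    copy_kernel,
    in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
    out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    scratch_shapes=[plgpu.SMEM(x.shape, x.dtype), plgpu.Barrier(num_arrivals=1)],
)(x)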

View File

@ -461,8 +461,6 @@ class KeyTy(dtypes.ExtendedDType):
core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
core.shaped_abstractify_handlers[PRNGKeyArray] = op.attrgetter('aval')
xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x

View File

@ -249,8 +249,8 @@ class XlaExecutable(Executable):
else:
raise
# TODO(skyewm): this should return a single dict (I think returning a list
# was to support MPMD executables, which never fully landed)
# TODO(b/384741132): this should return a single dict (I think returning a list
# was to support MPMD executables, which never fully landed).
def cost_analysis(self) -> list[dict[str, float]]:
xla_ext_exe = self.xla_extension_executable()
@ -266,9 +266,19 @@ class XlaExecutable(Executable):
# Try client method if executable cost_analysis method is unimplemented
if hasattr(xla_ext_exe, "client"):
try:
# TODO(b/384741132): We expect that the executable has only one
# HloModule. We should be able to remove this check once we update the
# Executable class to have only a single HloModule (see bug).
hlo_modules = xla_ext_exe.hlo_modules()
assert len(hlo_modules) == 1, (
f"Exectuable should have only one HloModule ({len(hlo_modules)})"
" were found)."
)
return [
xla_extension.hlo_module_cost_analysis(xla_ext_exe.client, m)
for m in xla_ext_exe.hlo_modules()
xla_extension.hlo_module_cost_analysis(
xla_ext_exe.client, hlo_modules[0]
)
]
except xla_extension.XlaRuntimeError as e:
msg, *_ = e.args
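A short sketch of how user code reaches the cost_analysis() path above. Depending on the JAX version, the return value is either a single dict or (as the TODO notes) a one-element list of dicts, so the sketch normalizes before reading it.

import jax
import jax.numpy as jnp

compiled = jax.jit(lambda x: (x @ x.T).sum()).lower(jnp.ones((128, 128))).compile()
costs = compiled.cost_analysis()
if isinstance(costs, list):  # the older list-of-dicts convention described above
  costs = costs[0]
print(costs.get("flops"), costs.get("bytes accessed"))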

View File

@ -382,6 +382,17 @@ def _lower_tpu_kernel(
pipeline.run(module.operation)
dump_mlir(module, "post-infer-vector-layout")
pipeline = [
(
"func.func(tpu-relayout-insertion{"
f" sublane-count={sl_cnt} lane-count={l_cnt}"
"})"
),
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
dump_mlir(module, "post-relayout-insertion")
mxu_size = 128 if hardware_generation < 6 else 256
pipeline = [
"func.func(tpu-apply-vector-layout{"

View File

@ -128,7 +128,7 @@ _deprecations = {
_src_core.escaped_tracer_error),
"extend_axis_env_nd": ("jax.core.extend_axis_env_nd is deprecated.",
_src_core.extend_axis_env_nd),
"get_type": ("jax.core.get_type is deprecated.", _src_core.get_type),
"get_type": ("jax.core.get_type is deprecated.", _src_core.get_aval),
"get_referent": ("jax.core.get_referent is deprecated.", _src_core.get_referent),
"join_effects": ("jax.core.join_effects is deprecated.", _src_core.join_effects),
"leaked_tracer_error": ("jax.core.leaked_tracer_error is deprecated.",
@ -212,7 +212,7 @@ if typing.TYPE_CHECKING:
escaped_tracer_error = _src_core.escaped_tracer_error
extend_axis_env_nd = _src_core.extend_axis_env_nd
full_lower = _src_core.full_lower
get_type = _src_core.get_type
get_type = _src_core.get_aval
get_referent = _src_core.get_referent
jaxpr_as_fun = _src_core.jaxpr_as_fun
join_effects = _src_core.join_effects
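A simplified model (not JAX's actual helper) of how a _deprecations table like the one above is consumed: a module-level __getattr__ warns while a replacement object is still available and raises AttributeError once the entry's value is set to None.

import warnings

_deprecations = {
    # (message, replacement-or-None); the entries here are illustrative only.
    "old_name": ("old_name is deprecated; use new_name instead.", len),
    "removed_name": ("removed_name was removed in an earlier release.", None),
}

def __getattr__(name):
  if name in _deprecations:
    message, obj = _deprecations[name]
    if obj is None:
      raise AttributeError(message)
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return obj
  raise AttributeError(f"module {__name__!r} has no attribute {name!r}")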

View File

@ -174,7 +174,6 @@ def _construct_smem_reftree(
) -> RefTree:
index = ir.IndexType.get()
i8 = ir.IntegerType.get_signless(8)
ptr = ir.Type.parse("!llvm.ptr")
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
flat_ref_tys, smem_buffer_tree = jax.tree.flatten(
smem_buffers, is_leaf=lambda x: isinstance(x, Union)
@ -183,11 +182,19 @@ def _construct_smem_reftree(
for ref_ty in flat_ref_tys:
def get_barrier_ptr(num_barriers: int) -> ir.Value:
nonlocal dynamic_smem_offset
smem_base_ptr = utils.memref_ptr(dynamic_smem, memory_space=3)
barrier_base_ptr = llvm.getelementptr(
ptr, smem_base_ptr, [], [dynamic_smem_offset], i8
workgroup_nvptx_address_space = (
dialect_lowering.gpu_address_space_to_nvptx(
gpu.AddressSpace.Workgroup
)
)
dynamic_smem_offset += num_barriers * MBARRIER_BYTES
smem_base_ptr = utils.memref_ptr(
dynamic_smem, memory_space=workgroup_nvptx_address_space
)
smem_ptr_ty = ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>")
barrier_base_ptr = llvm.getelementptr(
smem_ptr_ty, smem_base_ptr, [], [dynamic_smem_offset], i8
)
dynamic_smem_offset += num_barriers * utils.MBARRIER_BYTES
return barrier_base_ptr
match ref_ty:
case Union(members):
@ -227,9 +234,6 @@ def _construct_smem_reftree(
return jax.tree.unflatten(smem_buffer_tree, smem_refs)
MBARRIER_BYTES = 8
def _smem_tree_size(smem_buffers: ShapeTree) -> int:
leaves = jax.tree.leaves(
smem_buffers, is_leaf=lambda x: isinstance(x, Union)
@ -244,9 +248,9 @@ def _smem_tree_size(smem_buffers: ShapeTree) -> int:
| ClusterBarrier(_, num_barriers=num_barriers)
| Barrier(_, num_barriers=num_barriers)
):
if size % MBARRIER_BYTES:
if size % utils.MBARRIER_BYTES:
raise NotImplementedError("Misaligned barrier allocation")
size += num_barriers * MBARRIER_BYTES
size += num_barriers * utils.MBARRIER_BYTES
case _:
size += _count_buffer_bytes(l)
return size
@ -379,9 +383,11 @@ def _lower_as_gpu_kernel(
attrs["sym_name"] = ir.StringAttr.get(module_name)
if kernel_name is None:
kernel_name = getattr(body, "__name__", "anonymous")
# These are needed as nonlocal below.
launch_ctx, scratch_arr = None, None
with ir.InsertionPoint(module.body):
_declare_runtime_functions()
gmem_scratch_bytes = 0
global_scratch = llvm.GlobalOp(
ir.Type.parse("!llvm.array<0 x i8>"), # We don't know the shape yet.
"global_scratch",
@ -390,7 +396,7 @@ def _lower_as_gpu_kernel(
)
@func.FuncOp.from_py_func(ptr_ty, ptr_ty, name=f"mosaic_gpu_{kernel_name}")
def main(token_ptr, buffers):
nonlocal gmem_scratch_bytes
nonlocal launch_ctx, scratch_arr
token = builtin.unrealized_conversion_cast([token_ty], [token_ptr])
arg_refs = []
for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]):
@ -408,27 +414,40 @@ def _lower_as_gpu_kernel(
with _launch(
token, grid, cluster, block, scratch_arr, smem_scratch_shape,
prof_spec, prof_buffer
) as (launch_ctx, smem_refs):
) as (_launch_ctx, smem_refs):
nonlocal launch_ctx
launch_ctx = _launch_ctx
body(launch_ctx, *in_refs, *out_refs, smem_refs)
gmem_scratch_bytes = launch_ctx.next_scratch_offset
# Allocate and initialize the host buffer right before the launch.
# Note that we couldn't do that before, because we had to run the body
# to learn what the scratch contains.
with ir.InsertionPoint(scratch_arr.owner):
scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>")
scratch_alloc.elem_type = ir.TypeAttr.get(scratch_arr_ty)
scratch_arr.set_type(scratch_arr_ty)
for init_callback in launch_ctx.host_scratch_init:
init_callback(scratch_alloc.result)
main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
sym_tab = ir.SymbolTable(module.operation)
sym_tab.insert(main.func_op)
sym_tab.insert(global_scratch)
module.operation.verify()
return module, out_shape, unwrap_output_tuple
return module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr
def _initialize_scratch(
launch_ctx : launch_context.LaunchContext,
scratch_arr: ir.Value,
):
"""
Allocates and initializes the host buffer right before the launch. This needs
to be done after all TMA descriptors have been recorded by the launch context.
Only then do we know what the scratch contains.
When using the Mosaic GPU dialect, the necessary information is known only
after the lowering passes have run.
"""
with ir.InsertionPoint(scratch_arr.owner):
gmem_scratch_bytes = launch_ctx.next_scratch_offset
scratch_alloc_op = scratch_arr.owner.opview.addr.owner.opview
scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>")
scratch_alloc_op.elem_type = ir.TypeAttr.get(scratch_arr_ty)
scratch_arr.set_type(scratch_arr_ty)
for init_callback in launch_ctx.host_scratch_init:
init_callback(scratch_alloc_op.result)
def _declare_runtime_functions():
"""Declares the runtime functions that can be used by the generated code."""
ptr_ty = ir.Type.parse("!llvm.ptr")
@ -462,7 +481,7 @@ def as_gpu_kernel(
elif not isinstance(in_shape, tuple):
in_shape = (in_shape,)
module, out_shape, unwrap_output_tuple = (
module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr = (
_lower_as_gpu_kernel(
body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
module_name, kernel_name, prof_spec
@ -473,7 +492,10 @@ def as_gpu_kernel(
# Run Python lowering passes. The remaining passes will be run in C++ in
# jax/jaxlib/mosaic/gpu/custom_call.cc
layout_inference.infer_layout(module) # pytype: disable=attribute-error
dialect_lowering.lower_mgpu_dialect(module) # pytype: disable=attribute-error
dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error
_initialize_scratch(launch_ctx, scratch_arr)
module.operation.verify()
expected_arg_treedef = jax.tree.structure(in_shape)
def _check_args(*args):
@ -530,6 +552,7 @@ def as_torch_gpu_kernel(
cluster: tuple[int, int, int] = (1, 1, 1),
module_name: str = "unknown",
kernel_name: str | None = None,
thread_semantics: ThreadSemantics = ThreadSemantics.Lane,
):
try:
import torch
@ -545,13 +568,22 @@ def as_torch_gpu_kernel(
flat_out_types, out_treedef = jax.tree.flatten(out_shape)
expected_arg_treedef = jax.tree.structure(in_shape)
module, out_shape, unwrap_output_tuple = (
module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr = (
_lower_as_gpu_kernel(
body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
module_name, kernel_name, prof_spec
)
)
if thread_semantics == ThreadSemantics.Warpgroup and dialect is not None:
# Run Python lowering passes. The remaining passes will be run in C++ in
# jax/jaxlib/mosaic/gpu/custom_call.cc
layout_inference.infer_layout(module) # pytype: disable=attribute-error
dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error
_initialize_scratch(launch_ctx, scratch_arr)
module.operation.verify()
# Get our hands on the compilation and unload functions
try:
import jax_plugins.xla_cuda12 as cuda_plugin

View File

@ -31,13 +31,14 @@ from jax._src.lib.mlir.dialects import vector
import numpy as np
from .fragmented_array import FragmentedArray, WGStridedFragLayout
from .launch_context import LaunchContext
from .layouts import from_strided_fragmented_layout_attr, has_any_layout_set, is_strided_fragmented_layout, should_have_layout, to_strided_fragmented_layout_attr
from .utils import c, ptr_as_memref, single_thread_predicate
from .utils import BarrierRef, c, memref_ptr, ptr_as_memref, single_thread_predicate
# mypy: ignore-errors
MlirLoweringRule = Callable[[ir.Operation | ir.OpView], Sequence[ir.Value]]
MlirLoweringRule = Callable[[LaunchContext, ir.Operation | ir.OpView], Sequence[ir.Value]]
_lowerings: dict[str, MlirLoweringRule] = {}
@ -88,6 +89,9 @@ def _fragmented_array_from_ir(
[operand.type for operand in conversion_cast.operands],
conversion_cast.results,
)
if not isinstance(converted_outputs, list):
converted_outputs = [converted_outputs]
reverse_conversion_cast = converted_outputs[0].owner.opview
for attribute in conversion_cast.attributes:
@ -138,6 +142,7 @@ def gpu_address_space_to_nvptx(address_space: gpu.AddressSpace) -> int:
@_register_lowering(InitializeBarrierOp)
def _initialize_barrier_op_lowering_rule(
_: LaunchContext,
initialize_barrier_op: InitializeBarrierOp,
) -> Sequence[ir.Value]:
@ -170,7 +175,7 @@ def _initialize_barrier_op_lowering_rule(
@_register_lowering(vector.LoadOp)
def _vector_load_op_lowering_rule(
vector_load_op: vector.LoadOp,
_: LaunchContext, vector_load_op: vector.LoadOp
) -> Sequence[ir.Value]:
(out_layout_attr,) = cast(
ir.ArrayAttr, vector_load_op.attributes["out_layouts"]
@ -199,7 +204,7 @@ def _vector_load_op_lowering_rule(
@_register_lowering(vector.StoreOp)
def _vector_store_op_lowering_rule(
vector_store_op: vector.StoreOp,
_: LaunchContext, vector_store_op: vector.StoreOp
) -> Sequence[ir.Value]:
in_layout_attr, *_ = cast(
@ -229,8 +234,44 @@ def _vector_store_op_lowering_rule(
return []
@_register_lowering(mgpu.AsyncLoadOp)
def _mgpu_async_load_op_lowering_rule(
launch_context: LaunchContext, load_op: mgpu.AsyncLoadOp
) -> Sequence[ir.Value]:
mem_space = gpu_address_space_to_nvptx(gpu.AddressSpace.Workgroup)
# TODO(dasenov): Add support for the remaining op properties.
launch_context.async_copy(
src_ref=load_op.source,
dst_ref=load_op.destination,
barrier=BarrierRef(
base_address=memref_ptr(load_op.barrier, memory_space=mem_space),
offset=c(0, ir.IntegerType.get_signless(64)),
phases=None,
num_barriers=1,
),
arrive=load_op.arrive,
uniform=False,
)
return []
@_register_lowering(mgpu.AsyncStoreOp)
def _mgpu_async_store_op_lowering_rule(
launch_context: LaunchContext, store_op: mgpu.AsyncStoreOp
) -> Sequence[ir.Value]:
# TODO(dasenov): Add support for the remaining op properties.
launch_context.async_copy(
src_ref=store_op.source,
dst_ref=store_op.destination,
)
return []
@_register_lowering(arith.AddFOp)
def _arith_addf_op_lowering_rule(add: arith.AddFOp) -> Sequence[ir.Value]:
def _arith_addf_op_lowering_rule(
_: LaunchContext, add: arith.AddFOp
) -> Sequence[ir.Value]:
fragmented_array_lhs = _fragmented_array_from_ir(add.lhs)
fragmented_array_rhs = _fragmented_array_from_ir(add.rhs)
@ -242,7 +283,7 @@ def _arith_addf_op_lowering_rule(add: arith.AddFOp) -> Sequence[ir.Value]:
]
def lower_mgpu_dialect(module: ir.Module):
def lower_mgpu_dialect(module: ir.Module, launch_context: LaunchContext):
module.context.append_dialect_registry(mlir_interpreter.upstream_dialects)
module.context.load_all_available_dialects()
@ -257,7 +298,7 @@ def lower_mgpu_dialect(module: ir.Module):
if should_have_layout(op) and not has_any_layout_set(op):
raise ValueError(f"{op} is missing a layout and can not be lowered.")
new_results = lowering_rule(op)
new_results = lowering_rule(launch_context, op)
for old, new in zip(op.results, new_results):
old.replace_all_uses_with(new)
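A toy model (plain Python, no MLIR dependencies) of the registration and dispatch pattern used in this file: rules are keyed by op name, and after this change every rule is invoked as rule(launch_context, op), whether or not it needs the context.

from collections.abc import Callable, Sequence
from typing import Any

LoweringRule = Callable[[Any, Any], Sequence[Any]]
_rules: dict[str, LoweringRule] = {}

def _register(op_name: str):
  def wrapper(rule: LoweringRule) -> LoweringRule:
    _rules[op_name] = rule
    return rule
  return wrapper

@_register("mosaic_gpu.async_store")
def _async_store_rule(launch_context, op) -> Sequence[Any]:
  # Mirrors _mgpu_async_store_op_lowering_rule above: the context issues the
  # copy and the op itself produces no results.
  launch_context.async_copy(src_ref=op.source, dst_ref=op.destination)
  return []

def lower(launch_context, ops):
  for op in ops:
    new_results = _rules[op.name](launch_context, op)
    # The real pass would now rewire op.results to new_results.
    del new_results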

View File

@ -36,24 +36,28 @@ from jaxlib.mlir.dialects import scf
from jaxlib.mlir.dialects import vector
import numpy as np
from jax._src.lib import mosaic_gpu_dialect as dialect # noqa: F401
# mypy: ignore-errors
WARPGROUP_SIZE: int = 128
DYNAMIC = -9223372036854775808
DYNAMIC32 = -2147483648
MBARRIER_BYTES = 8
# pylint: disable=line-too-long, wildcard-import, missing-function-docstring, bad-continuation, g-bad-todo, protected-access, g-explicit-length-test, missing-class-docstring, g-doc-return-or-yield, g-inconsistent-quotes
def ptr_as_memref(ptr, memref_ty: ir.MemRefType):
def ptr_as_memref(ptr, memref_ty: ir.MemRefType, ptr_memory_space: int | None = None):
i64 = ir.IntegerType.get_signless(64)
rank = len(memref_ty.shape)
ptr_ty = "ptr" if ptr_memory_space is None else f"ptr<{ptr_memory_space}>"
if rank > 0:
desc_ty = ir.Type.parse(
f"!llvm.struct<(ptr, ptr, i64, array<{rank} x i64>, array<{rank} x i64>)>"
f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>, array<{rank} x i64>)>"
)
else:
desc_ty = ir.Type.parse("!llvm.struct<(ptr, ptr, i64)>")
desc_ty = ir.Type.parse(f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64)>")
desc = llvm.UndefOp(desc_ty)
desc = llvm.InsertValueOp(desc, ptr, [0]) # Allocation
desc = llvm.InsertValueOp(desc, ptr, [1]) # Aligned Base
@ -321,6 +325,8 @@ def bytewidth(ty: ir.Type):
return ir.IntegerType(ty).width // 8
if ir.FloatType.isinstance(ty):
return ir.FloatType(ty).width // 8
if dialect is not None and ty == ir.Type.parse("!mosaic_gpu.barrier"):
return MBARRIER_BYTES
raise NotImplementedError(ty)
@ -743,6 +749,18 @@ class BarrierRef:
ptr, self.base_address, [self.offset], [DYNAMIC32], i64
)
def as_dialect_barrier(self) -> ir.Value:
if self.num_barriers > 1:
raise NotImplementedError(
f"Only BarrierRef with num_barriers=1 is suppored in the MLIR "
f"Mosaic GPU dialect, but got num_barriers={self.num_barriers}"
)
return ptr_as_memref(
self.base_address,
ir.MemRefType.get((), ir.Type.parse("!mosaic_gpu.barrier")),
ptr_memory_space=3,
)
@dataclasses.dataclass(frozen=True)
class CollectiveBarrierRef:
@ -997,19 +1015,21 @@ def warp_tree_reduce(value, op, group_size):
def memref_ptr(memref_arg, memory_space=None):
i64 = ir.IntegerType.get_signless(64)
memref_ty = ir.MemRefType(memref_arg.type)
if len(memref_ty.shape) == 0:
raise NotImplementedError
elem_bytewidth = bytewidth(memref_ty.element_type)
rank = len(memref_ty.shape)
# TODO: Read out memory space from memref
space = "" if memory_space is None else "<" + str(memory_space) + ">"
ptr_ty = ir.Type.parse("!llvm.ptr" + space)
desc_ty = ir.Type.parse(
f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>,"
f" array<{rank} x i64>)>"
)
if rank == 0:
desc_ty = ir.Type.parse(f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64)>")
else:
desc_ty = ir.Type.parse(
f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>,"
f" array<{rank} x i64>)>"
)
desc = builtin.UnrealizedConversionCastOp([desc_ty], [memref_arg])
aligned_ptr = llvm.extractvalue(ptr_ty, desc, [1])
elem_bytewidth = bytewidth(memref_ty.element_type)
offset_elems = llvm.extractvalue(i64, desc, [2])
offset_bytes = llvm.mul(
offset_elems,
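A sketch of the LLVM descriptor types that ptr_as_memref and memref_ptr above agree on, written as plain string assembly: two pointers (allocated and aligned base), an element offset, and, for ranked memrefs, shape and stride arrays. The optional address space (3 is the NVPTX workgroup/SMEM space used in this change) now flows into the pointer type.

def memref_descriptor_type(rank: int, ptr_memory_space: int | None = None) -> str:
  ptr_ty = "ptr" if ptr_memory_space is None else f"ptr<{ptr_memory_space}>"
  if rank == 0:
    return f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64)>"
  return (
      f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>,"
      f" array<{rank} x i64>)>"
  )

print(memref_descriptor_type(2, 3))
# !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>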

View File

@ -724,7 +724,6 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
aval_in, aval_out, x):
if prod([size for n, size in mesh.shape.items() if n not in auto]) == 1:
return x
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
axes = {name: i for i, ns in names.items() for name in ns}
ns = _make_scoped_manual_sharding(ctx, mesh, axes)
if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
@ -734,6 +733,7 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
unspecified = set(range(aval_in.ndim)) if auto else set()
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto,
unspecified_dims=unspecified)
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, unspecified)
def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
@ -746,6 +746,8 @@ def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
ns = sharding_impls.physical_sharding(aval_out, ns)
aval_out = core.physical_aval(aval_out)
unspecified = set(range(aval_out.ndim)) if auto else set()
if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
aval_in = core.physical_aval(aval_in)
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=unspecified)
shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto()

View File

@ -36,17 +36,17 @@ _deprecations = {
"jax.lib.xla_client.bfloat16 was removed in JAX v0.4.38; use ml_dtypes.bfloat16.",
None,
),
# Added Sep 26 2024
# Finalized 2024-12-23; remove after 2025-03-23
"Device": (
"jax.lib.xla_client.Device is deprecated; use jax.Device instead.",
_xc.Device,
None,
),
"XlaRuntimeError": (
(
"jax.lib.xla_client.XlaRuntimeError is deprecated; use"
" jax.errors.JaxRuntimeError."
),
_xc.XlaRuntimeError,
None,
),
# Added Oct 10 2024
"FftType": (

View File

@ -54,7 +54,8 @@ py_extension(
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:MLIRBindingsPythonCoreNoCAPI",
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
"@nanobind",
],
)
@ -69,8 +70,8 @@ py_extension(
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIGPUHeaders",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@pybind11",
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
"@nanobind",
],
)
@ -84,7 +85,8 @@ py_extension(
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIGPUHeaders",
"@pybind11",
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
"@nanobind",
],
)
@ -99,8 +101,8 @@ py_extension(
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:CAPINVGPUHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@pybind11",
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
"@nanobind",
],
)
@ -115,8 +117,8 @@ py_extension(
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:CAPILLVMHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@pybind11",
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
"@nanobind",
],
)
@ -130,8 +132,8 @@ py_extension(
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPISparseTensorHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@pybind11",
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
"@nanobind",
],
)
@ -145,7 +147,8 @@ py_extension(
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPISparseTensorHeaders",
"@pybind11",
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
"@nanobind",
],
)
@ -155,9 +158,10 @@ py_extension(
copts = COPTS,
linkopts = LINKOPTS,
deps = [
"//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi",
":jaxlib_mlir_capi_shared_library",
"//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_headers",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps",
"@nanobind",
],
)
@ -377,6 +381,7 @@ cc_library(
name = "jaxlib_mlir_capi_objects",
deps = [
"//jaxlib/mosaic:tpu_dialect_capi_objects",
"//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_objects",
"@llvm-project//mlir:CAPIArithObjects",
"@llvm-project//mlir:CAPIGPUObjects",
"@llvm-project//mlir:CAPIIRObjects",

View File

@ -28,6 +28,7 @@ package(
py_library(
name = "mosaic",
deps = [
"//jaxlib/mosaic/python:gpu_dialect",
"//jaxlib/mosaic/python:tpu_dialect",
],
)
@ -42,6 +43,7 @@ cc_library(
"dialect/tpu/tpu_dialect.cc",
"dialect/tpu/tpu_ops.cc",
"dialect/tpu/util.cc",
"dialect/tpu/vreg_util.cc",
":extension_srcs",
] + glob([
"dialect/tpu/transforms/*.cc",
@ -50,6 +52,7 @@ cc_library(
"dialect/tpu/layout.h",
"dialect/tpu/tpu_dialect.h",
"dialect/tpu/util.h",
"dialect/tpu/vreg_util.h",
] + glob([
"dialect/tpu/transforms/*.h",
]),
@ -231,6 +234,19 @@ cc_library(
alwayslink = True,
)
cc_test(
name = "vreg_util_test",
srcs = ["dialect/tpu/vreg_util_test.cc"],
deps = [
":tpu_dialect",
"//testing/base/public:gunit_main",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:VectorDialect",
],
)
filegroup(
name = "extension_srcs",
srcs = [

View File

@ -215,3 +215,26 @@ cc_library(
"@llvm-project//mlir:CAPIIR",
],
)
# Header-only target, used when using the C API from a separate shared library.
cc_library(
name = "gpu_dialect_capi_headers",
hdrs = DIALECT_CAPI_HEADERS,
deps = [
":mosaic_gpu_inc_gen",
"@llvm-project//mlir:CAPIIRHeaders",
],
)
# Alwayslink target, used when exporting the C API from a shared library.
cc_library(
name = "gpu_dialect_capi_objects",
srcs = DIALECT_CAPI_SOURCES,
hdrs = DIALECT_CAPI_HEADERS,
deps = [
":mosaic_gpu",
":mosaic_gpu_inc_gen",
"@llvm-project//mlir:CAPIIRObjects",
],
alwayslink = True,
)

View File

@ -19,8 +19,8 @@ limitations under the License.
#include <optional>
#include <string>
#include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"

View File

@ -852,6 +852,19 @@ def InferVectorLayoutPass : Pass<"tpu-infer-vector-layout", "::mlir::func::FuncO
];
}
def RelayoutInsertionPass : Pass<"tpu-relayout-insertion", "::mlir::func::FuncOp"> {
let dependentDialects = [
"::mlir::arith::ArithDialect",
"::mlir::func::FuncDialect",
"::mlir::tpu::TPUDialect",
];
let constructor = "::mlir::tpu::createRelayoutInsertionPass()";
let options = [
Option<"lane_count", "lane-count", "int", /*default=*/"128", "">,
Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
];
}
def ApplyVectorLayoutPass : Pass<"tpu-apply-vector-layout", "::mlir::func::FuncOp"> {
let dependentDialects = [
"::mlir::arith::ArithDialect",

View File

@ -81,6 +81,9 @@ std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass(
std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
std::array<int64_t, 2> target_shape = {8, 128});
std::unique_ptr<OperationPass<func::FuncOp>> createRelayoutInsertionPass(
std::array<int64_t, 2> target_shape = {8, 128});
std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
const ApplyVectorLayoutContext &ctx = ApplyVectorLayoutContext{});

View File

@ -29,7 +29,6 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
@ -52,6 +51,7 @@
#include "absl/status/status.h"
#include "absl/types/span.h"
#include "llvm/include/llvm/ADT/APInt.h"
#include "llvm/include/llvm/Support/LogicalResult.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
@ -64,6 +64,7 @@
#include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h"
#include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h"
#include "jaxlib/mosaic/dialect/tpu/util.h"
#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
#include "xla/array.h"
#include "xla/layout.h"
#include "xla/util.h"
@ -79,41 +80,6 @@ namespace mlir::tpu {
#define GEN_PASS_DEF_APPLYVECTORLAYOUTPASS
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
// TPU_ASSERT_* macros should be understood as an assert, i.e. use it to check
// things that should never happen. We prefer returning failure over a CHECK
// because it's easier to debug from Python (particularly from OSS where symbols
// are removed)
#define TPU_ASSERT_IMPL(stream, cond) \
if (LLVM_UNLIKELY(!(cond))) { \
(stream) << "Internal error: assert failed: " #cond; \
}
#define TPU_ASSERT_CMP_IMPL(stream, lhs, rhs, cmp) \
if (LLVM_UNLIKELY(!((lhs)cmp(rhs)))) { \
(stream) << "Internal error: assert failed: " #lhs " " #cmp " " #rhs " (" \
<< (lhs) << " vs. " << (rhs) << ")"; \
return failure(); \
}
#define TPU_ASSERT_OP(cond) TPU_ASSERT_IMPL(op.emitOpError(), cond)
#define TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, cmp) \
TPU_ASSERT_CMP_IMPL(op.emitOpError(), lhs, rhs, cmp)
#define TPU_ASSERT_EQ_OP(lhs, rhs) TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, ==)
#define TPU_ASSERT_GE_OP(lhs, rhs) TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, >=)
#define TPU_ASSERT_GT_OP(lhs, rhs) TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, >)
#define TPU_ASSERT_LE_OP(lhs, rhs) TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, <=)
#define TPU_ASSERT_LT_OP(lhs, rhs) TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, <)
#define TPU_ASSERT_LOC(loc, cond) TPU_ASSERT_IMPL(mlir::emitError(loc), cond)
#define TPU_ASSERT_CMP_LOC_IMPL(loc, lhs, rhs, cmp) \
TPU_ASSERT_CMP_IMPL(loc, lhs, rhs, cmp)
#define TPU_ASSERT_EQ_LOC(loc, lhs, rhs) \
TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, ==)
#define TPU_ASSERT_GE_LOC(loc, lhs, rhs) \
TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, >=)
#define TPU_ASSERT_GT_LOC(loc, lhs, rhs) \
TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, >)
#define TPU_ASSERT_LT_LOC(loc, lhs, rhs) \
TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, <)
#define TPU_ASSERT_LE_LOC(loc, lhs, rhs) \
TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, <=)
// The minimum bound required to rotate with scratch space. The bound refers to
// the number of VREGs on rotation dim. This number was concluded from some cost
@ -310,16 +276,6 @@ void updateSliceFromRange(xla::Array<T> &arr, Range data,
CHECK(data_it == data.end());
}
FailureOr<TypedAttr> getZeroIntOrFloatAttr(Type ty) {
if (isa<FloatType>(ty)) {
return TypedAttr(FloatAttr::get(ty, 0));
}
if (isa<IntegerType>(ty)) {
return TypedAttr(IntegerAttr::get(ty, 0));
}
return emitError(UnknownLoc::get(ty.getContext()), "Not implemented: ") << ty;
}
FailureOr<int64_t> getIntConst(Value v, bool silent = false) {
if (auto constant_op = v.getDefiningOp<arith::ConstantOp>()) {
if (auto integer_attr = dyn_cast<IntegerAttr>(constant_op.getValue())) {
@ -514,33 +470,6 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx, func::FuncOp func,
return argument;
}
VectorType getNativeVregOrVmaskTypeImpl(
Type elem_ty, const int8_t bitwidth,
const std::array<int64_t, 2> target_shape) {
if (bitwidth == 32) {
return VectorType::get(target_shape, elem_ty);
}
return VectorType::get({target_shape[0], target_shape[1], 32 / bitwidth},
elem_ty);
}
VectorType getNativeVregOrVmaskType(Type elem_ty, const int8_t layout_bitwidth,
const std::array<int64_t, 2> target_shape) {
int8_t bitwidth = elem_ty.getIntOrFloatBitWidth();
if (bitwidth == 1) {
bitwidth = layout_bitwidth;
} else {
CHECK_EQ(bitwidth, layout_bitwidth);
}
return getNativeVregOrVmaskTypeImpl(elem_ty, bitwidth, target_shape);
}
VectorType getNativeVregType(Type elem_ty,
const std::array<int64_t, 2> target_shape) {
return getNativeVregOrVmaskTypeImpl(elem_ty, elem_ty.getIntOrFloatBitWidth(),
target_shape);
}
// Masks all values outside of bounds.
//
// Arguments:
@ -553,7 +482,7 @@ VectorType getNativeVregType(Type elem_ty,
// Returns:
// An MLIR value of the same type as the value argument, with all entries
// outside of bounds replaced by neutral.
FailureOr<Value> maskOOB(RewriteContext &ctx, OpBuilder &builder,
FailureOr<Value> maskOOB(RewriteContext &ctx, ImplicitLocOpBuilder &builder,
TypedValue<VectorType> value,
const VRegDataBounds &bounds,
const Attribute neutral) {
@ -577,86 +506,12 @@ FailureOr<Value> maskOOB(RewriteContext &ctx, OpBuilder &builder,
value.getLoc(),
VectorType::get(native_vreg_ty.getShape(), builder.getI1Type()), mask);
}
auto neutral_vec = builder.create<arith::ConstantOp>(
value.getLoc(), native_vreg_ty,
DenseElementsAttr::get(native_vreg_ty, neutral));
Value neutral_vec = getFullVector(builder, native_vreg_ty, neutral);
return builder
.create<arith::SelectOp>(value.getLoc(), mask, value, neutral_vec)
.getResult();
}
// Returns empty vector on null attribute
FailureOr<SmallVector<Layout>> getLayoutArrayFromAttr(const Attribute attr) {
if (const auto array_attr = dyn_cast_if_present<ArrayAttr>(attr)) {
SmallVector<Layout> out_layouts;
out_layouts.reserve(array_attr.size());
for (const Attribute a : array_attr) {
if (auto layout_attr = dyn_cast_if_present<VectorLayoutAttr>(a)) {
out_layouts.push_back(layout_attr.getLayout());
} else {
return failure();
}
}
return out_layouts;
}
return SmallVector<Layout>{};
}
bool layoutIsValidForValue(const Layout &l, const Value v,
const std::array<int64_t, 2> target_shape) {
// l must be non-null iff v is of vector type
if (const auto vty = dyn_cast<VectorType>(v.getType())) {
if (!l.has_value()) {
return false;
}
// Vector type should have the same bitwidth as the layout, except for the
// i1 special case, used for vmasks (see comment for VectorLayout class).
if (!vty.getElementType().isIntOrFloat()) {
return false;
}
const int8_t bitwidth = vty.getElementTypeBitWidth();
if (bitwidth != l->bitwidth() && bitwidth != 1) {
return false;
}
return l->isValid(target_shape) && l->layout_rank() <= vty.getRank();
}
return !l.has_value();
}
// TODO(tlongeri): Unify with infer_vector_layout.cc's getOutLayout.
FailureOr<SmallVector<Layout>> getOutLayouts(
Operation &op, const std::array<int64_t, 2> target_shape) {
FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> out_layouts,
getLayoutArrayFromAttr(op.getAttr("out_layout")));
if (out_layouts.size() != op.getNumResults()) {
return op.emitOpError("out_layout size does not match number of results");
}
for (const auto [l, res] : llvm::zip_equal(out_layouts, op.getResults())) {
if (!layoutIsValidForValue(l, res, target_shape)) {
return op.emitOpError("Invalid output layout");
}
}
return out_layouts;
}
FailureOr<SmallVector<Layout>> getInLayouts(
Operation &op, const std::array<int64_t, 2> target_shape) {
FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> in_layouts,
getLayoutArrayFromAttr(op.getAttr("in_layout")));
if (in_layouts.size() != op.getNumOperands()) {
return op.emitOpError("in_layout size does not match number of operands");
}
for (const auto [l, operand] :
llvm::zip_equal(in_layouts, op.getOperands())) {
if (!layoutIsValidForValue(l, operand, target_shape)) {
return op.emitOpError("Invalid input layout");
}
}
return in_layouts;
}
// Insert a minor dimension to the implicit shape. The original minor dimension
// becomes the new second minor dimension, laid out across sublanes.
//
@ -1970,130 +1825,30 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
TPU_ASSERT_EQ_OP(padded_lhs_rows, lhs_vregs.dim(0) * layout_lhs.tiling()[0]);
TPU_ASSERT_EQ_OP(padded_rhs_rows, rhs_vregs.dim(0) * layout_rhs.tiling()[0]);
const VectorType i32_vreg_ty =
getNativeVregType(builder.getI32Type(), ctx.target_shape);
auto getX32VmaskByPaddingEnd = [&](int64_t dim, int64_t padding) {
CHECK(dim == 0 || dim == 1);
CHECK(padding >= 0 && padding <= ctx.target_shape[dim]);
return cast<TypedValue<VectorType>>(
builder
.create<arith::CmpIOp>(
arith::CmpIPredicate::slt,
builder.create<tpu::IotaOp>(i32_vreg_ty,
builder.getI32IntegerAttr(dim)),
builder.create<arith::ConstantOp>(DenseElementsAttr::get(
i32_vreg_ty, builder.getI32IntegerAttr(
ctx.target_shape[dim] - padding))))
.getResult());
};
// We can also extend this helper function with padding_top and padding_left
// based on the offsets in vregs.
const Value i32_zeros_vreg = builder.create<arith::ConstantOp>(
op.getLoc(),
DenseElementsAttr::get(i32_vreg_ty, builder.getI32IntegerAttr(0)));
const Value i32_max_vreg = builder.create<arith::ConstantOp>(
op.getLoc(), DenseElementsAttr::get(
i32_vreg_ty, builder.getI32IntegerAttr(0xffffffff)));
auto maskVregs = [&](xla::Array<Value> &vregs, int64_t padding_bottom,
int64_t padding_right) {
auto vreg_ty = cast<VectorType>(vregs.begin()->getType());
int packing = vreg_ty.getRank() > 2 ? vreg_ty.getShape()[2] : 1;
// Mask out the bottom.
if (padding_bottom > 0) {
// We have limited the row size of LHS and RHS need to be a multiple of
// native tiling at the beginning of this rule. Therefore, it is safe to
// bitcast to x32 vreg for masking.
int sub_padding = padding_bottom % packing;
int x32_padding_bottom = padding_bottom / packing;
auto mask_bottom = getX32VmaskByPaddingEnd(0, x32_padding_bottom);
// Create an int32 vreg which contains subelement masking and then
// logical_and with target vreg to mask out the unaligned paddings.
// Eg. if padding_bottom = 5, packing = 2, and assume the vreg shape is
// [8, 128], then the mask will be:
//
// sublane 0: [0xffffffff, 0xffffffff, ..., 0xffffffff]
// sublane 1: [0xffffffff, 0xffffffff, ..., 0xffffffff]
// sublane 2: [0xffffffff, 0xffffffff, ..., 0xffffffff]
// sublane 3: [0xffffffff, 0xffffffff, ..., 0xffffffff]
// sublane 4: [0xffffffff, 0xffffffff, ..., 0xffffffff]
// sublane 5: [0x0000ffff, 0x0000ffff, ..., 0x0000ffff]
// sublane 6: [0 , 0 , ..., 0 ]
// sublane 7: [0 , 0 , ..., 0 ]
//
// Through this way, in order to mask sub-elements, each target vreg only
// needs to apply 1 op (logical_and) instead of 3 ops (unpacking + select
// + packing).
Value partial_sublane_mask = builder.create<arith::ConstantOp>(
op.getLoc(),
DenseElementsAttr::get(
i32_vreg_ty,
builder.getI32IntegerAttr(
0xffffffff >>
(sub_padding * vreg_ty.getElementTypeBitWidth()))));
// Insert 0xffffffff above the blended sublane.
Value sublane_mask = builder.create<arith::SelectOp>(
getX32VmaskByPaddingEnd(0, x32_padding_bottom + 1), i32_max_vreg,
partial_sublane_mask);
// Insert 0 below the blended sublane.
sublane_mask = builder.create<arith::SelectOp>(mask_bottom, sublane_mask,
i32_zeros_vreg);
for (int64_t i = 0; i < vregs.dim(1); ++i) {
Value &vreg = vregs({vregs.dim(0) - 1, i});
Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
if (sub_padding > 0) {
i32_vreg = builder.create<arith::AndIOp>(i32_vreg, sublane_mask);
} else {
i32_vreg = builder.create<arith::SelectOp>(mask_bottom, i32_vreg,
i32_zeros_vreg);
}
vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
}
}
// Mask out the right.
if (padding_right > 0) {
auto mask_right = getX32VmaskByPaddingEnd(1, padding_right);
for (int64_t i = 0; i < vregs.dim(0); ++i) {
Value &vreg = vregs({i, vregs.dim(1) - 1});
Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
i32_vreg = builder.create<arith::SelectOp>(mask_right, i32_vreg,
i32_zeros_vreg);
vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
}
}
};
// Create a vreg filled with zeros.
auto getZerosVergLike =
[&](const Value &vreg) -> FailureOr<TypedValue<VectorType>> {
const VectorType vreg_type = cast<VectorType>(vreg.getType());
FAILUREOR_ASSIGN_OR_RETURN(
const Attribute zero_attr,
getZeroIntOrFloatAttr(vreg_type.getElementType()));
return cast<TypedValue<VectorType>>(
builder
.create<arith::ConstantOp>(
op.getLoc(), DenseElementsAttr::get(vreg_type, zero_attr))
.getResult());
};
FAILUREOR_ASSIGN_OR_RETURN(auto lhs_zeros_vreg,
getZerosVergLike(*lhs_vregs.begin()));
FAILUREOR_ASSIGN_OR_RETURN(auto rhs_zeros_vreg,
getZerosVergLike(*rhs_vregs.begin()));
FAILUREOR_ASSIGN_OR_RETURN(auto acc_zeros_vreg,
getZerosVergLike(*acc_vregs.begin()));
auto lhs_zeros_vreg =
getZerosVector(builder, cast<VectorType>(lhs_vregs.begin()->getType()));
auto rhs_zeros_vreg =
getZerosVector(builder, cast<VectorType>(rhs_vregs.begin()->getType()));
auto acc_zeros_vreg =
getZerosVector(builder, cast<VectorType>(acc_vregs.begin()->getType()));
// Only mask out the paddings on contracting dim of LHS and RHS.
maskVregs(lhs_vregs, 0, padded_lhs_cols - lhs_shape[1]);
RETURN_IF_FAILED(
maskNativeTilingVregs(builder, lhs_vregs, ctx.target_shape,
/*padding_bottom=*/0,
/*padding_right=*/padded_lhs_cols - lhs_shape[1]));
if (transpose_rhs) {
maskVregs(rhs_vregs, 0, padded_rhs_cols - rhs_shape[1]);
RETURN_IF_FAILED(maskNativeTilingVregs(
builder, rhs_vregs, ctx.target_shape,
/*padding_bottom=*/0,
/*padding_right=*/padded_rhs_cols - rhs_shape[1]));
} else {
maskVregs(rhs_vregs, padded_rhs_rows - rhs_shape[0], 0);
RETURN_IF_FAILED(
maskNativeTilingVregs(builder, rhs_vregs, ctx.target_shape,
/*padding_bottom=*/padded_rhs_rows - rhs_shape[0],
/*padding_right=*/0));
}
// TODO(b/328094640): use latch 3 for short dimensions.
// TODO(b/328093587): Skip zeros vreg matmul
// At this point, all paddings on vregs are masked out. For now, we
// append zero vregs to make LHS's second dim, both RHS's dims and ACC's
// second dim to be a multiple of mxu_size.
@ -2984,12 +2739,10 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
native_vreg_ty,
/*dimension =*/builder.getI32IntegerAttr(1));
for (int64_t i = 0; i < num_tiles; ++i) {
auto offset = builder.create<arith::ConstantOp>(
native_vreg_ty,
DenseElementsAttr::get(
native_vreg_ty,
IntegerAttr::get(vty.getElementType(),
i * *(native_vreg_ty.getShape().end() - 1))));
Value offset = getFullVector(
builder, native_vreg_ty,
IntegerAttr::get(vty.getElementType(),
i * *(native_vreg_ty.getShape().end() - 1)));
tiles[i] = builder.create<arith::AddIOp>(vreg_iota, offset);
}
xla::Array<Value> broadcasted_tiles(tile_array_shape);
@ -3011,12 +2764,10 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
native_vreg_ty,
/*dimension =*/builder.getI32IntegerAttr(0));
for (int64_t i = 0; i < num_tiles; ++i) {
auto offset = builder.create<arith::ConstantOp>(
native_vreg_ty,
DenseElementsAttr::get(
native_vreg_ty,
IntegerAttr::get(vty.getElementType(),
i * *(native_vreg_ty.getShape().end() - 2))));
Value offset = getFullVector(
builder, native_vreg_ty,
IntegerAttr::get(vty.getElementType(),
i * *(native_vreg_ty.getShape().end() - 2)));
tiles[i] = builder.create<arith::AddIOp>(vreg_iota, offset);
}
xla::Array<Value> broadcasted_tiles(tile_array_shape);
@ -3033,10 +2784,8 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
SmallVector<Value> tiles;
tiles.reserve(vty.getDimSize(*dimension));
for (int64_t i = 0; i < vty.getDimSize(*dimension); ++i) {
tiles.push_back(builder.create<arith::ConstantOp>(
native_vreg_ty,
DenseElementsAttr::get(native_vreg_ty,
IntegerAttr::get(vty.getElementType(), i))));
tiles.push_back(getFullVector(builder, native_vreg_ty,
IntegerAttr::get(vty.getElementType(), i)));
}
xla::Array<Value> out_tiles(tile_array_shape);
out_tiles.Each([&](absl::Span<const int64_t> idxs, Value *v) {
@ -3625,12 +3374,9 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
const int64_t offset = *offsets_in[1];
const int64_t lane_offset = offset % ctx.target_shape[1];
const int64_t tile_offset = offset / ctx.target_shape[1];
const auto idx_ty =
VectorType::get(ctx.target_shape, builder.getI32Type());
auto lane_offset_cst = builder.create<arith::ConstantOp>(
broadcast_op.getLoc(), idx_ty,
DenseElementsAttr::get(idx_ty,
builder.getI32IntegerAttr(lane_offset)));
Value lane_offset_cst = getFullVector(
builder, getNativeVregType(builder.getI32Type(), ctx.target_shape),
builder.getI32IntegerAttr(lane_offset));
DenseI32ArrayAttr sublane_pattern;
if (num_tiles != 1) {
SmallVector<int32_t> pattern;
@ -3690,10 +3436,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
getNativeVregType(src_i32.getType(), ctx.target_shape);
auto tile_i32 =
builder.create<vector::BroadcastOp>(native_vreg_ty, src_i32);
auto zeros = builder.create<arith::ConstantOp>(
broadcast_op.getLoc(), tile_i32.getType(),
DenseElementsAttr::get(tile_i32.getType(),
builder.getI32IntegerAttr(0)));
Value zeros = getZerosVector(builder, tile_i32.getType());
auto tile =
builder.create<arith::CmpIOp>(arith::CmpIPredicate::ne, tile_i32, zeros)
.getResult();
@ -5588,8 +5331,6 @@ FailureOr<xla::Array<Value>> doColumnShiftRelayout(
const std::array<int64_t, 2> vreg_slice = src.vregSlice(target_shape);
const int bitwidth = src.bitwidth();
const int packing = src.packing();
const VectorLayout dst(bitwidth, {src.offsets()[0], dst_col_offset}, tiling,
src.implicit_dim());
const int64_t col_diff = dst_col_offset - *src.offsets()[1];
if (tiling[0] % packing != 0 || tiling[1] != target_shape[1]) {
return emitError(loc,
@ -5932,7 +5673,8 @@ LogicalResult retileToLargeTileWithScratch(
RewriteContext &ctx, OpBuilder &builder, const Location loc,
xla::Array<Value> &dst_tiles, const std::array<int64_t, 2> &dst_tile,
const xla::Array<Value> &src_tiles, const std::array<int64_t, 2> &src_tile,
TypedValue<MemRefType> scratch_ref) {
TypedValue<MemRefType> scratch_ref, const int64_t store_vreg_delay,
const int64_t load_vreg_skips) {
if (dst_tile[0] % src_tile[0] != 0) {
return failure();
}
@ -6036,8 +5778,8 @@ LogicalResult retileToLargeTileWithScratch(
SmallVector<int64_t, 4> src_idx(rank);
dst_tiles.Each([&](absl::Span<const int64_t> dst_idx, Value *dst_vreg) {
int64_t dst_row_idx = *(dst_idx.end() - 2);
int64_t dst_col_idx = *(dst_idx.end() - 1);
int64_t vreg_idx_in_group = dst_col_idx % vregs_per_group;
int64_t dst_col_idx_with_skips = *(dst_idx.end() - 1) + load_vreg_skips;
int64_t vreg_idx_in_group = dst_col_idx_with_skips % vregs_per_group;
int64_t load_offset = sublanes_per_group * stored_group_cnt +
vreg_idx_in_group * sl_per_vreg * stride;
delayed_loads.push_back(
@ -6047,16 +5789,20 @@ LogicalResult retileToLargeTileWithScratch(
// the vregs from the current group, and now we need to store the
// corresponding group of src vregs before actually emitting the loads.
if (vreg_idx_in_group == vregs_per_group - 1 ||
dst_col_idx == dst_tiles.dimensions().back() - 1) {
auto src_row_idx = dst_row_idx * vregs_per_group;
auto src_col_idx = dst_col_idx / vregs_per_group;
dst_idx.back() == dst_tiles.dimensions().back() - 1) {
auto base_src_row_idx = dst_row_idx * vregs_per_group - store_vreg_delay;
auto src_col_idx = dst_col_idx_with_skips / vregs_per_group;
std::copy(dst_idx.begin(), dst_idx.end(), src_idx.begin());
for (int vi = 0; vi < vregs_per_group; ++vi) {
if (src_row_idx + vi >= src_tiles.dim(rank - 2) ||
const int64_t src_row_idx = base_src_row_idx + vi;
if (src_row_idx < 0) {
continue;
}
if (src_row_idx >= src_tiles.dim(rank - 2) ||
src_col_idx >= src_tiles.dim(rank - 1)) {
break;
}
*(src_idx.end() - 2) = src_row_idx + vi;
*(src_idx.end() - 2) = src_row_idx;
*(src_idx.end() - 1) = src_col_idx;
Value src_vreg = src_tiles(src_idx);
src_vreg =
@ -6085,7 +5831,8 @@ LogicalResult retileToSmallTileWithScratch(
RewriteContext &ctx, OpBuilder &builder, const Location loc,
xla::Array<Value> &dst_tiles, const std::array<int64_t, 2> &dst_tile,
const xla::Array<Value> &src_tiles, const std::array<int64_t, 2> &src_tile,
TypedValue<MemRefType> scratch_ref) {
TypedValue<MemRefType> scratch_ref, const int64_t store_vreg_delay,
const int64_t load_vreg_skips) {
if (src_tile[0] % dst_tile[0] != 0) {
return failure();
}
@ -6212,8 +5959,8 @@ LogicalResult retileToSmallTileWithScratch(
SmallVector<int64_t, 4> dst_idx(rank);
src_tiles.Each([&](absl::Span<const int64_t> src_idx, Value src_vreg) {
int64_t src_row_idx = *(src_idx.end() - 2);
int64_t src_col_idx = *(src_idx.end() - 1);
int64_t vreg_idx_in_group = src_col_idx % vregs_per_group;
int64_t src_col_idx_with_delays = *(src_idx.end() - 1) + store_vreg_delay;
int64_t vreg_idx_in_group = src_col_idx_with_delays % vregs_per_group;
src_vreg = builder.create<tpu::BitcastVregOp>(loc, temp_vreg_ty, src_vreg);
if (use_shuffled_load) {
Value store_offset = mlirIndexConst(
@ -6235,16 +5982,20 @@ LogicalResult retileToSmallTileWithScratch(
// vregs' row, this indicates we have stored all the vregs needed to
// assemble a new group of dst vreg.
if (vreg_idx_in_group == vregs_per_group - 1 ||
src_col_idx == src_tiles.dimensions().back() - 1) {
auto dst_row_idx = src_row_idx * vregs_per_group;
auto dst_col_idx = src_col_idx / vregs_per_group;
src_idx.back() == src_tiles.dimensions().back() - 1) {
auto base_dst_row_idx = src_row_idx * vregs_per_group - load_vreg_skips;
auto dst_col_idx = src_col_idx_with_delays / vregs_per_group;
std::copy(src_idx.begin(), src_idx.end(), dst_idx.begin());
for (int vi = 0; vi < vregs_per_group; ++vi) {
if (dst_row_idx + vi >= dst_tiles.dim(rank - 2) ||
const int64_t dst_row_idx = base_dst_row_idx + vi;
if (dst_row_idx < 0) {
continue;
}
if (dst_row_idx >= dst_tiles.dim(rank - 2) ||
dst_col_idx >= dst_tiles.dim(rank - 1)) {
break;
}
*(dst_idx.end() - 2) = dst_row_idx + vi;
*(dst_idx.end() - 2) = dst_row_idx;
*(dst_idx.end() - 1) = dst_col_idx;
Value *dst_vreg = &dst_tiles(dst_idx);
int64_t load_offset =
@ -6269,18 +6020,70 @@ LogicalResult retileToSmallTileWithScratch(
// go/mosaic-retiling-in-scratch is the full internal documentation that
// includes more details about the TPU generations.
LogicalResult retileWithScratch(RewriteContext &ctx, OpBuilder &builder,
const Location loc,
xla::Array<Value> &dst_tiles,
const std::array<int64_t, 2> &dst_tiling,
const xla::Array<Value> &src_tiles,
const std::array<int64_t, 2> &src_tiling,
int packing) {
// Arguments:
// - shape: The non-implicit shape of the operand
// - dst_tiling: The desired result tiling
// - dst_offsets_hint: Hints for the result offsets. They may be used or
// ignored. See comments in the body of the function for
// more details.
// - src_vregs: The source vregs to retile.
// - src: The source layout
// Returns a pair holding the result layout (potentially using the hints) and
// the retiled vregs.
// TODO(tlongeri): Clean up the function parameters/signatures. We are passing
// in more information than strictly needed.
FailureOr<std::pair<VectorLayout, xla::Array<Value>>> retileWithScratch(
RewriteContext &ctx, OpBuilder &builder, const Location loc,
const ArrayRef<int64_t> shape, const std::array<int64_t, 2> dst_tiling,
const LayoutOffsets dst_offsets_hint, const xla::Array<Value> &src_vregs,
const VectorLayout &src) {
const int bitwidth = src.bitwidth();
const int packing = src.packing();
const std::array<int64_t, 2> src_tiling = src.tiling();
if (!(src_tiling[1] == ctx.target_shape[1] &&
dst_tiling[1] == ctx.target_shape[1] && src_tiling[0] % packing == 0 &&
dst_tiling[0] % packing == 0)) {
return failure();
}
const std::array<int64_t, 2> src_vreg_slice =
VectorLayout::vregSlice(ctx.target_shape, bitwidth, src_tiling);
const std::array<int64_t, 2> dst_vreg_slice =
VectorLayout::vregSlice(ctx.target_shape, bitwidth, dst_tiling);
// TODO(b/368088671): When sublane tiling changes, we should be able to
// preserve some replications from the source layout. But we need to
// make sure they are implemented efficiently and well-tested. For now, we
// just simply use 0 for the replicated offset after retiling.
const LayoutOffsets src_offsets = {src.offsets()[0].value_or(0),
src.offsets()[1].value_or(0)};
// The provided offset hints are used only if they align with the source
// offsets, else we default to the smallest possible aligned offsets.
LayoutOffsets dst_offsets = {*src_offsets[0] % dst_vreg_slice[0],
*src_offsets[1] % dst_vreg_slice[1]};
// On a given dimension, either the source vreg slice size divides the dest
// vreg slice size, or vice versa (depending on the dimension and whether it's
// small-to-large or large-to-small retiling). Offset changes are supported
// as long as they are aligned modulo the smaller of the two sizes.
const std::array<int64_t, 2> alignment = {
std::min(src_vreg_slice[0], dst_vreg_slice[0]),
std::min(src_vreg_slice[1], dst_vreg_slice[1])};
if (dst_offsets_hint[0].has_value() &&
(*dst_offsets_hint[0] - *src_offsets[0]) % alignment[0] == 0) {
CHECK_LT(*dst_offsets_hint[0], dst_vreg_slice[0]);
dst_offsets[0] = *dst_offsets_hint[0];
}
if (dst_offsets_hint[1].has_value() &&
(*dst_offsets_hint[1] - *src_offsets[1]) % alignment[1] == 0) {
CHECK_LT(*dst_offsets_hint[1], dst_vreg_slice[1]);
dst_offsets[1] = *dst_offsets_hint[1];
}
// The offsets of the source in units of the destination vreg slice:
const std::array<int64_t, 2> src_offsets_in_dst_vreg_slices = {
*src_offsets[0] / dst_vreg_slice[0], *src_offsets[1] / dst_vreg_slice[1]};
// The offsets of the destination in units of the source vreg slice:
const std::array<int64_t, 2> dst_offsets_in_src_vreg_slices = {
*dst_offsets[0] / src_vreg_slice[0], *dst_offsets[1] / src_vreg_slice[1]};
// Try to get i32 vector scratch space. Because we will bitcast vregs to
// i32 vregs before using scratch for retiling. Through this way we can
// handle packed types as well.
@ -6295,24 +6098,57 @@ LogicalResult retileWithScratch(RewriteContext &ctx, OpBuilder &builder,
dst_tiling[1]};
std::array<int64_t, 2> vi32_src_tiling = {src_tiling[0] / packing,
src_tiling[1]};
const VectorLayout dst(bitwidth, dst_offsets, dst_tiling, src.implicit_dim());
TPU_ASSERT_LOC(loc, dst.isValid(ctx.target_shape));
xla::Array<Value> dst_vregs(
dst.tileArrayImplicitShape(shape, ctx.target_shape));
// When differences in offsets exist, the source vregs may be stored at an offset
// position in their group. For example, the 1st vreg in a row/column may be
// stored as if it was the 3rd, so that the parts corresponding to the 1st and
// 2nd in the destination are filled with padding. Likewise, loads to
// destination vregs may be skipped, when they would load only padding.
// store_vreg_delay is the position offset for stores, and load_vreg_skips is
// the position offset for loads.
//
// For example, suppose we are going from 32-bit {0, 128}(2, 128) to
// {4, 0}(8, 128). We form groups of 4 vregs that represent an (8, 512) slice
// of the padded implicit shape. For the given offsets, for the first group,
// the data is in (4:8, 128:512). But the first and second sources (stored
// vregs) of the group form the slices of data (0:2, 0:512) and (2:4, 0:512),
// which should be all padding. Likewise, the first dest vreg slice (which we
// load from) holds the data from slice (0:8, 0:128), which is all padding.
// We never load or store to slices that should contain only padding.
if (src_tiling[0] > dst_tiling[0]) {
return retileToSmallTileWithScratch(ctx, builder, loc, dst_tiles,
vi32_dst_tiling, src_tiles,
vi32_src_tiling, ref);
DCHECK_EQ(src_offsets_in_dst_vreg_slices[1], 0);
DCHECK_EQ(dst_offsets_in_src_vreg_slices[0], 0);
const int64_t store_vreg_delay = dst_offsets_in_src_vreg_slices[1];
const int64_t load_vreg_skips = src_offsets_in_dst_vreg_slices[0];
if (failed(retileToSmallTileWithScratch(
ctx, builder, loc, dst_vregs, vi32_dst_tiling, src_vregs,
vi32_src_tiling, ref, store_vreg_delay, load_vreg_skips))) {
return failure();
}
}
if (src_tiling[0] < dst_tiling[0]) {
return retileToLargeTileWithScratch(ctx, builder, loc, dst_tiles,
vi32_dst_tiling, src_tiles,
vi32_src_tiling, ref);
DCHECK_EQ(src_offsets_in_dst_vreg_slices[0], 0);
DCHECK_EQ(dst_offsets_in_src_vreg_slices[1], 0);
const int64_t store_vreg_delay = dst_offsets_in_src_vreg_slices[0];
const int64_t load_vreg_skips = src_offsets_in_dst_vreg_slices[1];
if (failed(retileToLargeTileWithScratch(
ctx, builder, loc, dst_vregs, vi32_dst_tiling, src_vregs,
vi32_src_tiling, ref, store_vreg_delay, load_vreg_skips))) {
return failure();
}
}
dst_tiles = std::move(src_tiles);
return success();
return std::make_pair(dst, dst_vregs);
}
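A quick arithmetic check (in Python, for brevity) of the worked example in the comment inside retileWithScratch above: 32-bit data going from offsets {0, 128} with (2, 128) tiling to offsets {4, 0} with (8, 128) tiling on an (8, 128) target. The corresponding vreg slices, per VectorLayout::vregSlice, are (2, 512) and (8, 128).

src_offsets, src_vreg_slice = (0, 128), (2, 512)
dst_offsets, dst_vreg_slice = (4, 0), (8, 128)
# Small-to-large retiling (source sublane tiling 2 < destination sublane
# tiling 8), so the branch above picks:
store_vreg_delay = dst_offsets[0] // src_vreg_slice[0]  # first 2 stores hold only padding
load_vreg_skips = src_offsets[1] // dst_vreg_slice[1]   # first destination load is skipped
assert (store_vreg_delay, load_vreg_skips) == (2, 1)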
FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
RewriteContext &ctx, OpBuilder &builder, const Location loc, VectorType vty,
const VectorLayout src, xla::Array<Value> vregs,
const std::array<int64_t, 2> dst_tiling, bool try_replicate_rows) {
const std::array<int64_t, 2> dst_tiling,
const LayoutOffsets dst_offsets_hint) {
bool has_enough_scratch = ctx.max_sublanes_in_scratch >=
ctx.target_shape[0] * (ctx.target_shape[0] + 1);
const auto &target_shape = ctx.target_shape;
@ -6328,6 +6164,12 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
const int8_t bitwidth = src.bitwidth();
const std::array<int64_t, 2> dst_vreg_slice =
VectorLayout::vregSlice(ctx.target_shape, bitwidth, dst_tiling);
// TODO(tlongeri): Using canonical vs non-canonical offsets can change the
// value of try_replicate_rows, and it breaks some tests. It doesn't make
// sense that we have different behavior for equivalent layouts, though. We
// need better logic for picking the relayout strategy.
const bool try_replicate_rows =
src.offsets()[0].has_value() && !dst_offsets_hint[0].has_value();
// Fully replicated offsets are handled efficiently elsewhere (in relayout)
CHECK(src.offsets()[0].has_value() || src.offsets()[1].has_value());
@ -6399,15 +6241,10 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
});
return std::pair(dst, std::move(retiled));
}
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
src.implicit_dim());
if (!dst.isValid(target_shape)) {
return emitError(loc, "Not implemented: invalid offsets in tiling target");
}
auto dst_tiles_shape =
dst.tileArrayImplicitShape(vty.getShape(), target_shape);
// (8,128) -> (8 * packing,128) tiling change for packed type.
if (bitwidth < 32 && 32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
if (src_offsets[0].value_or(0) < dst_vreg_slice[0] &&
src_offsets[1].value_or(0) < dst_vreg_slice[1] && bitwidth < 32 &&
32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
dst_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * packing,
ctx.target_shape[1]}) {
// TODO(tlongeri): This relayout is just ext + trunc. Refactor.
@ -6417,8 +6254,10 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
// not, since it relies on the src vreg array shape to know how many tiles
// to pack in dst, and vreg array shapes with materialized offsets are
// unfortunately not equal to vreg array shapes with replicated offsets.
CHECK(dst.offsets() == src_offsets);
xla::Array<Value> retiled(dst_tiles_shape);
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
src.implicit_dim());
xla::Array<Value> retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
VectorType vreg_x32 =
vty.getElementType().isSignlessInteger()
? VectorType::get(target_shape, builder.getI32Type())
@@ -6466,7 +6305,9 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
// interesting if the next step is a retile, since we can also
// match corresponding elements without shifting. It's just that
// the tiles are not adjacent (no contiguous vreg slice).
if (bitwidth < 32 && 32 % bitwidth == 0 &&
if (src_offsets[0].value_or(0) < dst_vreg_slice[0] &&
src_offsets[1].value_or(0) < dst_vreg_slice[1] && bitwidth < 32 &&
32 % bitwidth == 0 &&
src_tiling == std::array<int64_t, 2>{1, ctx.target_shape[1] * packing} &&
dst_tiling == std::array<int64_t, 2>{packing, ctx.target_shape[1]}) {
// TODO(tlongeri): This relayout is just ext + trunc. Refactor.
@@ -6515,8 +6356,10 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
// not, since it relies on the src vreg array shape to know how many tiles
// to pack in dst, and vreg array shapes with materialized offsets are
// unfortunately not equal to vreg array shapes with replicated offsets.
CHECK(dst.offsets() == src.offsets());
xla::Array<Value> retiled(dst_tiles_shape);
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
src.implicit_dim());
xla::Array<Value> retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
const VectorType vreg_x32 =
vty.getElementType().isSignlessInteger()
? VectorType::get(target_shape, builder.getI32Type())
@@ -6553,24 +6396,25 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
return std::pair(dst, std::move(retiled));
}
if (src_tiling[1] == target_shape[1] && dst_tiling[1] == target_shape[1]) {
// TODO(b/368088671): When sublane tiling changes, we should be able to
// preserve some replications from the source layout. But we need to
// make sure they are implemented efficiently and well-tested. For now, we
// just simply use 0 for the replicated offset after retiling.
dst = VectorLayout(
bitwidth, {src.offsets()[0].value_or(0), src.offsets()[1].value_or(0)},
dst_tiling, dst.implicit_dim());
// All clauses in the and expression are based on performance benchmarking.
bool use_alu = !has_enough_scratch ||
(ctx.hardware_generation >= 5 && src_tiling[0] != packing &&
dst_tiling[0] != packing);
if (use_alu) {
if (src_tiling[0] > dst_tiling[0]) {
return std::pair(
dst, retileToReducedSublanes(builder, vty.getShape(), src, vregs,
dst, target_shape));
if (src_tiling[0] > dst_tiling[0] &&
// retileToReducedSublanes does not support offset changes
src.offsets()[0].value_or(0) < dst_vreg_slice[0] &&
src.offsets()[1].value_or(0) < dst_vreg_slice[1]) {
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
src.implicit_dim());
return std::pair(dst, retileToReducedSublanes(
builder, vty.getShape(), src, vregs,
VectorLayout(bitwidth,
{src.offsets()[0].value_or(0),
src.offsets()[1].value_or(0)},
dst_tiling, dst.implicit_dim()),
target_shape));
} else if (!has_enough_scratch) {
// TODO(b/357538782): Implement retileToIncreasedSublanes with ALU ops.
return emitError(
@@ -6578,15 +6422,12 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
"Not implemented: retiling to increase sublane tiling with ALU");
}
}
xla::Array<Value> retiled(dst_tiles_shape);
if (failed(retileWithScratch(ctx, builder, loc, retiled, dst_tiling, vregs,
src_tiling, packing))) {
return failure();
}
return std::pair(dst, std::move(retiled));
return retileWithScratch(ctx, builder, loc, vty.getShape(), dst_tiling,
dst_offsets_hint, vregs, src);
}
return emitError(loc, "Not implemented: Unsupported tiling change for ")
<< vty << ": from " << src << " to " << dst;
<< vty << ": from " << src << " to (" << dst_tiling[0] << ", "
<< dst_tiling[1] << ") tiling";
}
FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeImplicitDim(
@@ -6846,9 +6687,7 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
FAILUREOR_ASSIGN_OR_RETURN(
std::tie(src, src_tiles),
changeTiling(ctx, builder, v.getLoc(), vty, src, std::move(src_tiles),
dst.tiling(),
dst.offsets()[0] == std::nullopt &&
src.offsets()[0] != std::nullopt));
dst.tiling(), dst.offsets()));
FAILUREOR_ASSIGN_OR_RETURN(
std::tie(src, src_tiles),
@@ -6894,10 +6733,11 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
// arguments.
auto op_result = dyn_cast<OpResult>(vector_operand);
if (op_result == nullptr) {
return op.emitError("Expected operand to be an operation result");
return op.emitError(
"Expected vector operand to be an operation result");
}
Operation *const def_op = op_result.getOwner();
TPU_ASSERT_OP(def_op);
DCHECK(def_op);
const unsigned res_idx = op_result.getResultNumber();
FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> def_layouts,
getOutLayouts(*def_op, ctx.target_shape));

View File

@@ -66,22 +66,6 @@ using ImplicitDim = VectorLayout::ImplicitDim;
static constexpr int kLayoutLog = 10;
class Print {
public:
explicit Print(Operation *t) : payload_(t) {}
Operation *payload_;
private:
friend std::ostream &operator<<(std::ostream &, Print);
};
std::ostream &operator<<(std::ostream &os, Print p) {
std::string s;
llvm::raw_string_ostream tmp_os(s);
p.payload_->print(tmp_os);
os << tmp_os.str();
return os;
}
bool is_fully_replicated(const Layout &layout) {
static LayoutOffsets replicated_offsets = {std::nullopt, std::nullopt};
@@ -142,8 +126,7 @@ class VectorLayoutInferer {
// TODO: b/342235360 - This check is temporary while we increase and test
// support for offsets outside of the first tile. When support is more
// broad, any op without support should check it within their own rule.
if (!isa<arith::TruncIOp, arith::TruncFOp, vector::BroadcastOp,
vector::ExtractStridedSliceOp>(any_op)) {
if (!isa<vector::BroadcastOp, vector::ExtractStridedSliceOp>(any_op)) {
const SmallVector<Layout> layouts_in = getLayoutFromOperands(&any_op);
for (const Layout &layout : layouts_in) {
if (layout &&
@@ -1664,10 +1647,17 @@ class VectorLayoutInferer {
Layout dst_layout;
if (layout.tiling() == nativeTiling(src_bitwidth)) {
// If the source is already in native tiling, we can unpack it directly.
src_layout = layout;
std::array<int64_t, 2> dst_native_tiling = nativeTiling(dst_bitwidth);
LayoutOffsets offsets = {layout.offsets()[0]
? *layout.offsets()[0] % dst_native_tiling[0]
: LayoutOffset(),
layout.offsets()[1]};
DCHECK_LT(offsets[1].value_or(0), dst_native_tiling[1]);
src_layout = VectorLayout(src_bitwidth, offsets, layout.tiling(),
layout.implicit_dim());
dst_layout =
VectorLayout(dst_bitwidth, layout.offsets(),
nativeTiling(dst_bitwidth), layout.implicit_dim());
VectorLayout(dst_bitwidth, offsets, dst_native_tiling,
layout.implicit_dim());
} else if (dst_bitwidth == 32 &&
default_tiling_[0] % layout.tiling()[0] == 0 &&
default_tiling_[1] == layout.tiling()[1]) {
@@ -1676,13 +1666,17 @@ class VectorLayoutInferer {
// tiling through the op.
// TODO(jevinjiang): we can relax this for non-32bit as well.
src_layout = layout;
dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(),
layout.implicit_dim());
dst_layout = VectorLayout(dst_bitwidth, layout.offsets(),
src_layout->tiling(), layout.implicit_dim());
} else {
// TODO(b/335863273): we should also reduce offsets.
src_layout = VectorLayout(src_bitwidth, layout.offsets(), default_tiling_,
LayoutOffsets offsets = {
layout.offsets()[0] ? *layout.offsets()[0] % default_tiling_[0]
: LayoutOffset(),
layout.offsets()[1] ? *layout.offsets()[1] % default_tiling_[1]
: LayoutOffset()};
src_layout = VectorLayout(src_bitwidth, offsets, default_tiling_,
layout.implicit_dim());
dst_layout = VectorLayout(dst_bitwidth, layout.offsets(), default_tiling_,
dst_layout = VectorLayout(dst_bitwidth, offsets, default_tiling_,
layout.implicit_dim());
}
setLayout(op, src_layout, dst_layout);
@@ -1700,22 +1694,23 @@ class VectorLayoutInferer {
auto dst_ty = cast<VectorType>(op->getResult(0).getType());
auto some_layout = getLayout(op->getOperand(0));
TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
const unsigned src_bitwidth = src_ty.getElementTypeBitWidth();
const unsigned dst_bitwidth = dst_ty.getElementTypeBitWidth();
if (isa<arith::TruncFOp>(op)) {
TPU_CHECK_OP(
src_bitwidth == 32 && (dst_bitwidth == 16 || dst_bitwidth == 8),
"Only 32-bit to 16-bit or 8-bit float truncation supported");
if (dyn_cast<arith::TruncFOp>(op)) {
TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32 &&
(dst_ty.getElementTypeBitWidth() == 16 ||
dst_ty.getElementTypeBitWidth() == 8),
"Only 32-bit to 8-bit or 16-bit truncation supported");
} else {
TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32,
"Only 32-bit truncation supported");
}
auto &layout = *some_layout;
bool select_native = allUsersRequireNativeTiling(op->getResult(0));
auto src_layout = VectorLayout(
src_bitwidth, layout.offsets(),
select_native ? nativeTiling(src_bitwidth) : layout.tiling(),
layout.implicit_dim());
auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_,
layout.implicit_dim());
auto dst_layout = VectorLayout(
dst_bitwidth, layout.offsets(),
select_native ? nativeTiling(dst_bitwidth) : layout.tiling(),
dst_ty.getElementTypeBitWidth(), layout.offsets(),
select_native ? nativeTiling(dst_ty.getElementTypeBitWidth())
: default_tiling_,
layout.implicit_dim());
setLayout(op, src_layout, dst_layout);
return success();
@@ -1928,46 +1923,6 @@ class VectorLayoutInferer {
});
}
void setInLayout(Operation *op, ArrayRef<Layout> in) {
CHECK_EQ(in.size(), op->getNumOperands()) << Print(op);
SmallVector<Attribute, 4> in_attrs;
in_attrs.reserve(in.size());
for (const Layout &p : in) {
in_attrs.push_back(VectorLayoutAttr::get(op->getContext(), p));
}
op->setAttr("in_layout", ArrayAttr::get(op->getContext(), in_attrs));
}
void setOutLayout(Operation *op, Layout out) {
setOutLayout(op, ArrayRef<Layout>(out));
}
void setOutLayout(Operation *op, ArrayRef<Layout> out) {
SmallVector<Attribute, 4> out_attrs;
out_attrs.reserve(out.size());
for (const Layout &p : out) {
out_attrs.push_back(VectorLayoutAttr::get(op->getContext(), p));
}
op->setAttr("out_layout", ArrayAttr::get(op->getContext(), out_attrs));
}
void setLayout(Operation *op, Layout in, Layout out) {
setLayout(op, ArrayRef<Layout>(in), ArrayRef<Layout>(out));
}
void setLayout(Operation *op, ArrayRef<Layout> in, Layout out) {
setLayout(op, in, ArrayRef<Layout>(out));
}
void setLayout(Operation *op, Layout in, ArrayRef<Layout> out) {
setLayout(op, ArrayRef<Layout>(in), out);
}
void setLayout(Operation *op, ArrayRef<Layout> in, ArrayRef<Layout> out) {
setInLayout(op, in);
setOutLayout(op, out);
}
Layout getLayout(Value v) {
auto op = v.getDefiningOp();
CHECK(op);

View File

@@ -0,0 +1,166 @@
#include <array>
#include <cstdint>
#include <memory>
#include <optional>
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "absl/log/check.h"
#include "llvm/include/llvm/ADT/STLExtras.h"
#include "llvm/include/llvm/Support/MathExtras.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/IR/Builders.h"
#include "mlir/include/mlir/IR/Diagnostics.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "jaxlib/mosaic/dialect/tpu/util.h"
namespace mlir::tpu {
#define GEN_PASS_DECL_RELAYOUTINSERTIONPASS
#define GEN_PASS_DEF_RELAYOUTINSERTIONPASS
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
namespace {
FailureOr<TypedValue<VectorType>> relayout(
OpBuilder &builder, TypedValue<VectorType> v, VectorLayout src,
VectorLayout dst, const std::array<int64_t, 2> target_shape) {
// change bitwidth
if (v.getType().getElementType() == builder.getI1Type() &&
// TODO(jevinjiang): for other relayout changes (tiling, offsets, implicit
// dim), we currently rely on apply-vector-layout pass to do the relayout.
src.bitwidth() != dst.bitwidth()) {
CHECK(llvm::isPowerOf2_32(src.bitwidth()));
CHECK(llvm::isPowerOf2_32(dst.bitwidth()));
auto make_vty = [&](int bitwidth) {
return VectorType::get(v.getType().getShape(),
builder.getIntegerType(bitwidth));
};
auto make_constant = [&](int val, VectorLayout layout) {
auto vty = make_vty(layout.bitwidth());
auto constant_op = builder.create<arith::ConstantOp>(
v.getLoc(),
DenseElementsAttr::get(
vty, builder.getIntegerAttr(vty.getElementType(), val)));
setOutLayout(constant_op,
VectorLayout(layout.bitwidth(), {std::nullopt, std::nullopt},
layout.tiling(), layout.implicit_dim()));
return constant_op;
};
auto src_int_vty = make_vty(src.bitwidth());
auto dst_int_vty = make_vty(dst.bitwidth());
auto vreg_slice = src.vregSlice(target_shape, dst.bitwidth(), src.tiling());
auto dst_bitwidth_layout = VectorLayout(
dst.bitwidth(),
{
src.offsets()[0].has_value() ? *src.offsets()[0] % vreg_slice[0]
: LayoutOffset(),
src.offsets()[1].has_value() ? *src.offsets()[1] % vreg_slice[1]
: LayoutOffset(),
},
src.tiling(), src.implicit_dim());
auto ext_op = builder.create<arith::ExtUIOp>(v.getLoc(), src_int_vty, v);
setLayout(ext_op, src, src);
// TODO(jevinjiang): some conversion might not be supported in HW.
Operation *cast_op =
dst.bitwidth() > src.bitwidth()
? builder.create<arith::ExtSIOp>(v.getLoc(), dst_int_vty, ext_op)
// TODO(jevinjiang): HW may support pack vmask directly.
: builder.create<arith::TruncIOp>(v.getLoc(), dst_int_vty, ext_op);
setLayout(cast_op, src, dst_bitwidth_layout);
auto cmp_op = builder.create<arith::CmpIOp>(
v.getLoc(), v.getType(), arith::CmpIPredicate::ne,
cast_op->getResult(0), make_constant(0, dst_bitwidth_layout));
setLayout(cmp_op, {dst_bitwidth_layout, dst_bitwidth_layout},
dst_bitwidth_layout);
return cast<TypedValue<VectorType>>(cmp_op.getResult());
}
return v;
}
// TODO(jevinjiang): make relayout an op so we don't need to decide when to
// relayout in the apply-vector-layout pass.
LogicalResult insertRelayout(Operation &op,
const std::array<int64_t, 2> target_shape) {
FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> in_layouts,
getInLayouts(op, target_shape));
if (in_layouts.size() != op.getNumOperands()) {
return op.emitError("Expected the same number of operands as in_layouts");
}
if (isa<tpu::AssumeLayoutOp>(op)) {
return success();
}
// Relayout the operands, if their requested input layouts don't match the
// layouts in which they were produced.
for (auto [idx, tup] :
llvm::enumerate(llvm::zip(op.getOperands(), in_layouts))) {
auto [operand, li] = tup;
auto vector_operand = dyn_cast<TypedValue<VectorType>>(operand);
TPU_ASSERT_EQ_OP(vector_operand != nullptr, li.has_value());
if (vector_operand == nullptr) {
continue;
}
// The operand should always be an Operation (and not a BlockArgument)
// since we expect the FuncOp to have only memrefs and semaphores as
// arguments.
auto op_result = dyn_cast<OpResult>(vector_operand);
if (op_result == nullptr) {
return op.emitError("Expected vector operand to be an operation result");
}
Operation *const def_op = op_result.getOwner();
DCHECK(def_op);
const unsigned res_idx = op_result.getResultNumber();
FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> def_layouts,
getOutLayouts(*def_op, target_shape));
const Layout lo = def_layouts[res_idx];
TPU_ASSERT_OP(lo.has_value());
if (*lo == *li) {
continue;
}
OpBuilder builder(&op);
FAILUREOR_ASSIGN_OR_RETURN(Value new_v,
relayout(builder, vector_operand, /*src=*/*lo,
/*dst=*/*li, target_shape));
op.setOperand(idx, new_v);
}
return success();
}
struct RelayoutInsertionPass
: public impl::RelayoutInsertionPassBase<RelayoutInsertionPass> {
RelayoutInsertionPass(std::array<int64_t, 2> target_shape) {
this->sublane_count = target_shape[0];
this->lane_count = target_shape[1];
}
void runOnOperation() override {
func::FuncOp func = getOperation();
auto result = func.walk([&](Operation *op) {
if (insertRelayout(*op, {sublane_count, lane_count}).failed()) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (result.wasInterrupted()) {
signalPassFailure();
return;
}
}
};
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>> createRelayoutInsertionPass(
std::array<int64_t, 2> target_shape) {
return std::make_unique<RelayoutInsertionPass>(target_shape);
}
} // namespace mlir::tpu
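The file above only defines the pass and its factory function; where it is scheduled in the Mosaic lowering pipeline is not visible in this diff. As a hedged sketch of how such a function-level pass is typically added to a pipeline (the wrapper name, the include list, and the header that declares the factory are assumptions, not part of this change):

#include <array>
#include <cstdint>

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"  // assumed to declare the factory

// Sketch only: schedules relayout insertion on every func.func, ahead of
// apply-vector-layout, for a TPU core with the given (sublane, lane) shape.
void addRelayoutInsertion(mlir::PassManager &pm,
                          const std::array<int64_t, 2> target_shape) {
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::tpu::createRelayoutInsertionPass(target_shape));
}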

View File

@@ -18,17 +18,33 @@ limitations under the License.
#include <array>
#include <cstdint>
#include <optional>
#include <ostream>
#include <string>
#include <utility>
#include "llvm/Support/MathExtras.h"
#include "absl/log/check.h"
#include "absl/types/span.h"
#include "llvm/include/llvm/Support/raw_ostream.h"
#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/Value.h"
#include "mlir/include/mlir/IR/ValueRange.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
namespace mlir::tpu {
std::ostream &operator<<(std::ostream &os, Print p) {
std::string s;
llvm::raw_string_ostream tmp_os(s);
p.payload_->print(tmp_os);
os << tmp_os.str();
return os;
}
SmallVector<int64_t> ComputeTileStrides(MemRefType memref_ty,
absl::Span<const int64_t> tiling) {
SmallVector<int64_t> tile_strides(memref_ty.getRank());
@@ -147,4 +163,115 @@ bool HasMemorySpace(MemRefType ty, tpu::MemorySpace space) {
dyn_cast_or_null<tpu::MemorySpaceAttr>(ty.getMemorySpace());
return memory_space && memory_space.getValue() == space;
}
bool layoutIsValidForValue(const Layout &l, const Value v,
const std::array<int64_t, 2> target_shape) {
// l must be non-null iff v is of vector type
if (const auto vty = dyn_cast<VectorType>(v.getType())) {
if (!l.has_value()) {
return false;
}
// Vector type should have the same bitwidth as the layout, except for the
// i1 special case, used for vmasks (see comment for VectorLayout class).
if (!vty.getElementType().isIntOrFloat()) {
return false;
}
const int8_t bitwidth = vty.getElementTypeBitWidth();
if (bitwidth != l->bitwidth() && bitwidth != 1) {
return false;
}
return l->isValid(target_shape) && l->layout_rank() <= vty.getRank();
}
return !l.has_value();
}
FailureOr<SmallVector<Layout>> getLayoutArrayFromAttr(const Attribute attr) {
if (const auto array_attr = dyn_cast_if_present<ArrayAttr>(attr)) {
SmallVector<Layout> out_layouts;
out_layouts.reserve(array_attr.size());
for (const Attribute a : array_attr) {
if (auto layout_attr = dyn_cast_if_present<VectorLayoutAttr>(a)) {
out_layouts.push_back(layout_attr.getLayout());
} else {
return failure();
}
}
return out_layouts;
}
return SmallVector<Layout>{};
}
// TODO(tlongeri, jevinjiang): Unify with infer_vector_layout.cc's getOutLayout.
FailureOr<SmallVector<Layout>> getOutLayouts(
Operation &op, const std::array<int64_t, 2> target_shape) {
FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> out_layouts,
getLayoutArrayFromAttr(op.getAttr("out_layout")));
if (out_layouts.size() != op.getNumResults()) {
return op.emitOpError("out_layout size does not match number of results");
}
for (const auto [l, res] : llvm::zip_equal(out_layouts, op.getResults())) {
if (!layoutIsValidForValue(l, res, target_shape)) {
return op.emitOpError("Invalid output layout");
}
}
return out_layouts;
}
FailureOr<SmallVector<Layout>> getInLayouts(
Operation &op, const std::array<int64_t, 2> target_shape) {
FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> in_layouts,
getLayoutArrayFromAttr(op.getAttr("in_layout")));
if (in_layouts.size() != op.getNumOperands()) {
return op.emitOpError("in_layout size does not match number of operands");
}
for (const auto [l, operand] :
llvm::zip_equal(in_layouts, op.getOperands())) {
if (!layoutIsValidForValue(l, operand, target_shape)) {
return op.emitOpError("Invalid input layout");
}
}
return in_layouts;
}
void setInLayout(Operation *op, ArrayRef<Layout> in) {
CHECK_EQ(in.size(), op->getNumOperands()) << Print(op);
SmallVector<Attribute, 4> in_attrs;
in_attrs.reserve(in.size());
for (const Layout &p : in) {
in_attrs.push_back(VectorLayoutAttr::get(op->getContext(), p));
}
op->setAttr("in_layout", ArrayAttr::get(op->getContext(), in_attrs));
}
void setOutLayout(Operation *op, Layout out) {
setOutLayout(op, ArrayRef<Layout>(out));
}
void setOutLayout(Operation *op, ArrayRef<Layout> out) {
SmallVector<Attribute, 4> out_attrs;
out_attrs.reserve(out.size());
for (const Layout &p : out) {
out_attrs.push_back(VectorLayoutAttr::get(op->getContext(), p));
}
op->setAttr("out_layout", ArrayAttr::get(op->getContext(), out_attrs));
}
void setLayout(Operation *op, Layout in, Layout out) {
setLayout(op, ArrayRef<Layout>(in), ArrayRef<Layout>(out));
}
void setLayout(Operation *op, ArrayRef<Layout> in, Layout out) {
setLayout(op, in, ArrayRef<Layout>(out));
}
void setLayout(Operation *op, Layout in, ArrayRef<Layout> out) {
setLayout(op, ArrayRef<Layout>(in), out);
}
void setLayout(Operation *op, ArrayRef<Layout> in, ArrayRef<Layout> out) {
setInLayout(op, in);
setOutLayout(op, out);
}
} // namespace mlir::tpu

View File

@@ -3,9 +3,12 @@
#include <array>
#include <cstdint>
#include <optional>
#include <ostream>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
@@ -15,8 +18,11 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "absl/types/span.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/Value.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "tsl/platform/statusor.h"
// TODO: Instead of CHECK_EQs, can we do something like TF_RET_CHECK but with
// MLIR diagnostics?
@@ -31,18 +37,86 @@
// } \
// } while (false)
#define FAILUREOR_ASSIGN_OR_RETURN_IMPL(failureor, lhs, rhs) \
auto failureor = rhs; \
if (failed(failureor)) { \
return failure(); \
} \
lhs = std::move(failureor).value();
// All the macros below here are to handle the case in
// FAILUREOR_ASSIGN_OR_RETURN where the LHS is wrapped in parentheses. See a
// more detailed discussion at https://stackoverflow.com/a/62984543
#define FAILUREOR_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(X) \
FAILUREOR_ASSIGN_OR_RETURN_ESCAPE(FAILUREOR_ASSIGN_OR_RETURN_EMPTY X)
#define FAILUREOR_ASSIGN_OR_RETURN_EMPTY(...) \
FAILUREOR_ASSIGN_OR_RETURN_EMPTY __VA_ARGS__
#define FAILUREOR_ASSIGN_OR_RETURN_ESCAPE(...) \
FAILUREOR_ASSIGN_OR_RETURN_ESCAPE_(__VA_ARGS__)
#define FAILUREOR_ASSIGN_OR_RETURN_ESCAPE_(...) \
FAILUREOR_ASSIGN_OR_RETURN_##__VA_ARGS__
#define FAILUREOR_ASSIGN_OR_RETURN_FAILUREOR_ASSIGN_OR_RETURN_EMPTY
#define FAILUREOR_ASSIGN_OR_RETURN_IMPL(failureor, lhs, rhs) \
auto failureor = rhs; \
if (failed(failureor)) { \
return failure(); \
} \
FAILUREOR_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(lhs) = \
(std::move(failureor).value());
#define FAILUREOR_ASSIGN_OR_RETURN(lhs, rhs) \
FAILUREOR_ASSIGN_OR_RETURN_IMPL( \
TF_STATUS_MACROS_CONCAT_NAME(failureor, __COUNTER__), lhs, rhs)
#define RETURN_IF_FAILED(...) \
do { \
if (failed(__VA_ARGS__)) { \
return failure(); \
} \
} while (false)
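A minimal usage sketch, with hypothetical helpers that are not part of this header, of the two left-hand-side forms FAILUREOR_ASSIGN_OR_RETURN accepts; the parenthesis-stripping machinery above exists for the second form, where the LHS itself contains a comma:

// Sketch only (hypothetical functions, for illustration).
namespace mlir::tpu {
inline FailureOr<std::pair<int, int>> makePair() { return std::make_pair(1, 2); }

inline LogicalResult useMakePair() {
  // Plain LHS: no parentheses needed.
  FAILUREOR_ASSIGN_OR_RETURN(const int x, FailureOr<int>(42));
  // An LHS containing a comma (here a structured binding) must be wrapped in
  // parentheses; the UNPARENTHESIZE_IF_PARENTHESIZED helpers strip them.
  FAILUREOR_ASSIGN_OR_RETURN((auto [a, b]), makePair());
  return success(a + b + x == 45);
}
}  // namespace mlir::tpu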
namespace mlir::tpu {
// TPU_ASSERT_* macros should be understood as an assert, i.e. use them to
// check things that should never happen. We prefer returning failure over a
// CHECK because it's easier to debug from Python (particularly from OSS,
// where symbols are removed).
#define TPU_ASSERT_IMPL(stream, cond)                     \
  if (LLVM_UNLIKELY(!(cond))) {                           \
    (stream) << "Internal error: assert failed: " #cond;  \
    return failure();                                     \
  }
#define TPU_ASSERT_CMP_IMPL(stream, lhs, rhs, cmp) \
if (LLVM_UNLIKELY(!((lhs)cmp(rhs)))) { \
(stream) << "Internal error: assert failed: " #lhs " " #cmp " " #rhs " (" \
<< (lhs) << " vs. " << (rhs) << ")"; \
return failure(); \
}
#define TPU_ASSERT_OP(cond) TPU_ASSERT_IMPL(op.emitOpError(), cond)
#define TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, cmp) \
TPU_ASSERT_CMP_IMPL(op.emitOpError(), lhs, rhs, cmp)
#define TPU_ASSERT_EQ_OP(lhs, rhs) TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, ==)
#define TPU_ASSERT_GE_OP(lhs, rhs) TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, >=)
#define TPU_ASSERT_GT_OP(lhs, rhs) TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, >)
#define TPU_ASSERT_LE_OP(lhs, rhs) TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, <=)
#define TPU_ASSERT_LT_OP(lhs, rhs) TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, <)
#define TPU_ASSERT_LOC(loc, cond) TPU_ASSERT_IMPL(mlir::emitError(loc), cond)
#define TPU_ASSERT_CMP_LOC_IMPL(loc, lhs, rhs, cmp) \
TPU_ASSERT_CMP_IMPL(loc, lhs, rhs, cmp)
#define TPU_ASSERT_EQ_LOC(loc, lhs, rhs) \
TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, ==)
#define TPU_ASSERT_GE_LOC(loc, lhs, rhs) \
TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, >=)
#define TPU_ASSERT_GT_LOC(loc, lhs, rhs) \
TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, >)
#define TPU_ASSERT_LT_LOC(loc, lhs, rhs) \
TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, <)
#define TPU_ASSERT_LE_LOC(loc, lhs, rhs) \
TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, <=)
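As a hedged usage sketch (the rule below is hypothetical and not part of this header), the *_OP variants expect a variable named `op` to be in scope and turn a violated invariant into an emitted op error plus failure() rather than a process-killing CHECK:

// Sketch only: a made-up sanity check inside an op rule.
inline LogicalResult checkUnaryVectorOp(Operation &op) {
  // On violation these emit "Internal error: assert failed: ..." on `op` and
  // return failure() from the enclosing function.
  TPU_ASSERT_EQ_OP(op.getNumOperands(), 1u);
  TPU_ASSERT_OP(isa<VectorType>(op.getOperand(0).getType()));
  return success();
}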
class Print {
public:
explicit Print(Operation *t) : payload_(t) {}
Operation *payload_;
private:
friend std::ostream &operator<<(std::ostream &, Print);
};
std::ostream &operator<<(std::ostream &os, Print p);
template <bool adjust_bool = false>
FailureOr<int8_t> getTypeBitwidth(Type ty) {
if (auto integer_ty = dyn_cast<IntegerType>(ty)) {
@@ -117,6 +191,26 @@ bool canReinterpretToUntiledMemref(TypedValue<MemRefType> tiled_memref,
// Determines whether the given MemRefType has the given memory space.
bool HasMemorySpace(MemRefType ty, tpu::MemorySpace space);
bool layoutIsValidForValue(const Layout &l, const Value v,
const std::array<int64_t, 2> target_shape);
// Returns empty vector on null attribute
FailureOr<SmallVector<Layout>> getLayoutArrayFromAttr(const Attribute attr);
FailureOr<SmallVector<Layout>> getOutLayouts(
Operation &op, const std::array<int64_t, 2> target_shape);
FailureOr<SmallVector<Layout>> getInLayouts(
Operation &op, const std::array<int64_t, 2> target_shape);
void setInLayout(Operation *op, ArrayRef<Layout> in);
void setOutLayout(Operation *op, Layout out);
void setOutLayout(Operation *op, ArrayRef<Layout> out);
void setLayout(Operation *op, Layout in, Layout out);
void setLayout(Operation *op, ArrayRef<Layout> in, Layout out);
void setLayout(Operation *op, Layout in, ArrayRef<Layout> out);
void setLayout(Operation *op, ArrayRef<Layout> in, ArrayRef<Layout> out);
} // namespace mlir::tpu
#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_

View File

@@ -0,0 +1,206 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
#include <array>
#include <cstdint>
#include "absl/log/check.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/Diagnostics.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/include/mlir/IR/Types.h"
#include "mlir/include/mlir/IR/Value.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "jaxlib/mosaic/dialect/tpu/util.h"
#include "xla/array.h"
namespace mlir::tpu {
namespace {
VectorType getNativeVregOrVmaskTypeImpl(
Type elem_ty, const int8_t bitwidth,
const std::array<int64_t, 2> target_shape) {
if (bitwidth == 32) {
return VectorType::get(target_shape, elem_ty);
}
return VectorType::get({target_shape[0], target_shape[1], 32 / bitwidth},
elem_ty);
}
} // namespace
VectorType getNativeVregOrVmaskType(Type elem_ty, const int8_t layout_bitwidth,
const std::array<int64_t, 2> target_shape) {
int8_t bitwidth = elem_ty.getIntOrFloatBitWidth();
if (bitwidth == 1) {
bitwidth = layout_bitwidth;
} else {
CHECK_EQ(bitwidth, layout_bitwidth);
}
return getNativeVregOrVmaskTypeImpl(elem_ty, bitwidth, target_shape);
}
VectorType getNativeVregType(Type elem_ty,
const std::array<int64_t, 2> target_shape) {
return getNativeVregOrVmaskTypeImpl(elem_ty, elem_ty.getIntOrFloatBitWidth(),
target_shape);
}
TypedValue<VectorType> getFullVector(ImplicitLocOpBuilder &builder,
VectorType vty, Attribute value) {
return cast<TypedValue<VectorType>>(
builder.create<arith::ConstantOp>(DenseElementsAttr::get(vty, value))
.getResult());
}
TypedValue<VectorType> getFullLikeVector(ImplicitLocOpBuilder &builder,
TypedValue<VectorType> vec,
Attribute value) {
return getFullVector(builder, vec.getType(), value);
}
TypedValue<VectorType> getZerosVector(ImplicitLocOpBuilder &builder,
VectorType vty) {
return getFullVector(builder, vty, builder.getZeroAttr(vty.getElementType()));
}
TypedValue<VectorType> getZerosLikeVector(ImplicitLocOpBuilder &builder,
TypedValue<VectorType> vec) {
return getZerosVector(builder, vec.getType());
}
FailureOr<TypedValue<VectorType>> getX32VmaskByPaddingEnd(
ImplicitLocOpBuilder &builder, int64_t padding,
const std::array<int64_t, 2> target_shape, int64_t dim) {
VectorType i32_vreg_ty =
getNativeVregType(builder.getI32Type(), target_shape);
if (dim != 0 && dim != 1) {
return builder.emitError()
<< "Expected a 2D vector for getX32VmaskByPaddingEnd";
}
if (padding < 0 || padding > target_shape[dim]) {
return builder.emitError()
<< "Padding must be in [0, target_shape[dim]). Padding: " << padding
<< ", target_shape[dim]: " << target_shape[dim];
}
Value padding_vreg =
getFullVector(builder, i32_vreg_ty,
builder.getI32IntegerAttr(target_shape[dim] - padding));
return cast<TypedValue<VectorType>>(
builder
.create<arith::CmpIOp>(
arith::CmpIPredicate::slt,
builder.create<tpu::IotaOp>(i32_vreg_ty,
builder.getI32IntegerAttr(dim)),
padding_vreg)
.getResult());
}
LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder,
xla::Array<Value> &vregs,
std::array<int64_t, 2> target_shape,
int64_t padding_bottom,
int64_t padding_right) {
auto vreg_ty = dyn_cast<VectorType>(vregs.begin()->getType());
if (!vreg_ty) {
return builder.emitError() << "Expected a vector type";
}
VectorType i32_vreg_ty =
getNativeVregType(builder.getI32Type(), target_shape);
Value i32_zeros_vreg = getZerosVector(builder, i32_vreg_ty);
Value i32_max_vreg = getFullVector(builder, i32_vreg_ty,
builder.getI32IntegerAttr(0xffffffff));
int packing = vreg_ty.getRank() > 2 ? vreg_ty.getShape()[2] : 1;
// Mask out the bottom.
if (padding_bottom > 0) {
// The function is only called when the vreg has native tiling. Therefore,
// it is safe to bitcast to x32 vreg for masking.
int sub_padding = padding_bottom % packing;
int x32_padding_bottom = padding_bottom / packing;
FAILUREOR_ASSIGN_OR_RETURN(
Value mask_top, getX32VmaskByPaddingEnd(builder, x32_padding_bottom + 1,
target_shape, /*dim=*/0));
FAILUREOR_ASSIGN_OR_RETURN(
Value mask_bottom,
getX32VmaskByPaddingEnd(builder, x32_padding_bottom, target_shape,
/*dim=*/0));
// Create an int32 vreg which contains subelement masking and then
// logical_and with target vreg to mask out the unaligned paddings.
// Eg. if padding_bottom = 5, packing = 2, and assume the vreg shape is
// [8, 128], then the mask will be:
//
// sublane 0: [0xffffffff, 0xffffffff, ..., 0xffffffff]
// sublane 1: [0xffffffff, 0xffffffff, ..., 0xffffffff]
// sublane 2: [0xffffffff, 0xffffffff, ..., 0xffffffff]
// sublane 3: [0xffffffff, 0xffffffff, ..., 0xffffffff]
// sublane 4: [0xffffffff, 0xffffffff, ..., 0xffffffff]
// sublane 5: [0x0000ffff, 0x0000ffff, ..., 0x0000ffff]
// sublane 6: [0 , 0 , ..., 0 ]
// sublane 7: [0 , 0 , ..., 0 ]
//
// This way, masking sub-elements takes only 1 op per target vreg
// (logical_and) instead of 3 ops (unpacking + select + packing).
Value partial_sublane_mask = getFullVector(
builder, i32_vreg_ty,
builder.getI32IntegerAttr(
0xffffffff >> (sub_padding * vreg_ty.getElementTypeBitWidth())));
// Insert 0xffffffff above the blended sublane.
Value sublane_mask = builder.create<arith::SelectOp>(mask_top, i32_max_vreg,
partial_sublane_mask);
// Insert 0 below the blended sublane.
sublane_mask = builder.create<arith::SelectOp>(mask_bottom, sublane_mask,
i32_zeros_vreg);
for (int64_t i = 0; i < vregs.dim(1); ++i) {
Value &vreg = vregs({vregs.dim(0) - 1, i});
Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
if (sub_padding > 0) {
i32_vreg = builder.create<arith::AndIOp>(i32_vreg, sublane_mask);
} else {
i32_vreg = builder.create<arith::SelectOp>(mask_bottom, i32_vreg,
i32_zeros_vreg);
}
vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
}
}
// Mask out the right.
if (padding_right > 0) {
FAILUREOR_ASSIGN_OR_RETURN(
Value mask_right, getX32VmaskByPaddingEnd(builder, padding_right,
target_shape, /*dim=*/1));
for (int64_t i = 0; i < vregs.dim(0); ++i) {
Value &vreg = vregs({i, vregs.dim(1) - 1});
Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
i32_vreg =
builder.create<arith::SelectOp>(mask_right, i32_vreg, i32_zeros_vreg);
vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
}
}
return success();
}
} // namespace mlir::tpu

View File

@@ -0,0 +1,82 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_VREG_UTIL_H_
#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_VREG_UTIL_H_
#include <array>
#include <cstdint>
#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/Builders.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/include/mlir/IR/Types.h"
#include "mlir/include/mlir/IR/Value.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "xla/array.h"
namespace mlir::tpu {
// Returns the native vreg or vmask type for the given element type and target
// shape. The layout bitwidth is used for i1 (vmask) elements.
VectorType getNativeVregOrVmaskType(Type elem_ty, int8_t layout_bitwidth,
std::array<int64_t, 2> target_shape);
VectorType getNativeVregType(Type elem_ty, std::array<int64_t, 2> target_shape);
// Returns a zero constant of the same type as `vty`.
TypedValue<VectorType> getZerosVector(ImplicitLocOpBuilder &builder,
VectorType vty);
// Same as above, but takes a `vec` as input.
TypedValue<VectorType> getZerosLikeVector(ImplicitLocOpBuilder &builder,
TypedValue<VectorType> vec);
// Returns a constant of the same type as `vty` with the given `value`.
TypedValue<VectorType> getFullVector(ImplicitLocOpBuilder &builder,
VectorType vty, Attribute value);
// Same as above, but takes a `vec` as input.
TypedValue<VectorType> getFullLikeVector(ImplicitLocOpBuilder &builder,
TypedValue<VectorType> vec,
Attribute value);
// Creates a vmask that is true for the first (dim_size - padding) elements
// along `dim` and false for the trailing `padding` elements, i.e. the padding
// at the bottom (dim = 0) or right (dim = 1) is masked off.
//
// For example, assume vmask shape is (4, 8)
//
// getX32VmaskByPaddingEnd(padding=3, dim=1) creates:
// [T, T, T, T, T, F, F, F]
// [T, T, T, T, T, F, F, F]
// [T, T, T, T, T, F, F, F]
// [T, T, T, T, T, F, F, F]
// TODO(b/385204135): Unify with getVmaskByPaddingEnd in tpu_rotate_rule, and
// improve the codegen.
FailureOr<TypedValue<VectorType>> getX32VmaskByPaddingEnd(
ImplicitLocOpBuilder &builder, int64_t padding,
std::array<int64_t, 2> target_shape, int64_t dim);
// Masks out the padding at the bottom and right of the vregs. The vregs are
// expected to have native tiling, and the masked vregs are mutated in place
// in `vregs`. `padding_bottom` and `padding_right` are the number of padding
// elements at the bottom and right, respectively.
LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder,
xla::Array<Value> &vregs,
std::array<int64_t, 2> target_shape,
int64_t padding_bottom,
int64_t padding_right);
} // namespace mlir::tpu
#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_VREG_UTIL_H_
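A minimal usage sketch of maskNativeTilingVregs, assuming the caller already holds a native-tiling vreg array for a value with trailing padding; the wrapper function and the concrete shape numbers below are illustrative assumptions, not part of this change:

#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"

// Sketch only: for a (20, 250) f32 value on an (8, 128) target, the vreg
// array is 3 x 2 and carries 4 rows of bottom padding and 6 lanes of right
// padding; this call zeroes that padding in place.
mlir::LogicalResult zeroTrailingPadding(mlir::ImplicitLocOpBuilder &builder,
                                        xla::Array<mlir::Value> &vregs) {
  constexpr std::array<int64_t, 2> kTargetShape = {8, 128};
  return mlir::tpu::maskNativeTilingVregs(builder, vregs, kTargetShape,
                                          /*padding_bottom=*/4,
                                          /*padding_right=*/6);
}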

View File

@@ -0,0 +1,228 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
#include <array>
#include <cstdint>
#include <memory>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/include/mlir/IR/Builders.h"
#include "mlir/include/mlir/IR/BuiltinOps.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/include/mlir/IR/MLIRContext.h"
#include "mlir/include/mlir/IR/OwningOpRef.h"
#include "mlir/include/mlir/IR/Value.h"
#include "mlir/include/mlir/Support/DebugStringHelper.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
namespace mlir::tpu {
namespace {
using ::testing::Eq;
using ::testing::Optional;
MATCHER_P2(IsConstantOpWithSplatValue, type, splat_value, "") {
auto constant_op = dyn_cast<arith::ConstantOp>(arg.getDefiningOp());
if (constant_op == nullptr) {
*result_listener << "Expected a constant op, got " << debugString(arg);
return false;
}
auto dense_attr = dyn_cast<DenseElementsAttr>(constant_op.getValue());
if (dense_attr == nullptr) {
*result_listener << "Expected a dense elements attr, got "
<< debugString(arg);
return false;
}
if (dense_attr.getType() != type) {
*result_listener << "Expected a dense elements attr with type "
<< debugString(type) << ", got "
<< debugString(dense_attr.getType());
return false;
}
if (!dense_attr.isSplat()) {
*result_listener << "Expected a splat dense elements attr, got "
<< debugString(dense_attr);
return false;
}
if (auto s = dense_attr.template getSplatValue<decltype(splat_value)>();
s != splat_value) {
*result_listener << "Expected a splat dense elements attr with value "
<< splat_value << ", got " << s;
return false;
}
return true;
}
MATCHER_P2(IsVectorTypeWithShape, shape, elem_ty, "") {
auto vty = dyn_cast<VectorType>(arg);
if (vty == nullptr) {
*result_listener << "Expected a vector type, got " << debugString(arg);
return false;
}
if (vty.getShape() != ArrayRef<int64_t>(shape)) {
*result_listener << "Expected a vector type with shape "
<< absl::StrJoin(shape, ",") << ", got "
<< absl::StrJoin(vty.getShape(), ",");
return false;
}
if (vty.getElementType() != elem_ty) {
*result_listener << "Expected a vector type with element type "
<< debugString(elem_ty) << ", got "
<< debugString(vty.getElementType());
return false;
}
return true;
}
class VregUtilTest : public ::testing::Test {
protected:
void SetUp() override {
context_.loadDialect<arith::ArithDialect, vector::VectorDialect,
tpu::TPUDialect>();
mlir::Location loc = mlir::UnknownLoc::get(&context_);
mlir::OpBuilder b(&context_);
module_ = b.create<ModuleOp>(loc);
builder_ = std::make_unique<mlir::ImplicitLocOpBuilder>(
module_->getLoc(), module_->getBodyRegion());
}
void TearDown() override {
builder_.reset();
// Reset the module to prevent memory leaks.
module_ = nullptr;
}
mlir::ImplicitLocOpBuilder& Builder() { return *builder_; }
private:
MLIRContext context_;
std::unique_ptr<mlir::ImplicitLocOpBuilder> builder_;
OwningOpRef<ModuleOp> module_;
};
TEST_F(VregUtilTest, GetNativeVregOrVmaskTypeBitwidthMismatch) {
EXPECT_DEATH(getNativeVregOrVmaskType(Builder().getI16Type(),
/*layout_bitwidth=*/8, {2, 4}),
"");
}
TEST_F(VregUtilTest, GetNativeVregOrVmaskTypeI1) {
EXPECT_THAT(getNativeVregOrVmaskType(Builder().getI1Type(),
/*layout_bitwidth=*/8, {2, 4}),
IsVectorTypeWithShape(std::array<int64_t, 3>{2, 4, 4},
Builder().getI1Type()));
}
TEST_F(VregUtilTest, GetNativeVregF32) {
EXPECT_THAT(getNativeVregType(Builder().getF32Type(), {2, 4}),
IsVectorTypeWithShape(std::array<int64_t, 2>{2, 4},
Builder().getF32Type()));
}
TEST_F(VregUtilTest, GetNativeVregBf16) {
EXPECT_THAT(getNativeVregType(Builder().getBF16Type(), {2, 4}),
IsVectorTypeWithShape(std::array<int64_t, 3>{2, 4, 2},
Builder().getBF16Type()));
}
TEST_F(VregUtilTest, GetFullVector) {
VectorType vty = VectorType::get({2, 4}, Builder().getI32Type());
TypedValue<VectorType> vec =
getFullVector(Builder(), vty, Builder().getI32IntegerAttr(0x1));
EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, int32_t{0x1}));
}
TEST_F(VregUtilTest, GetFullLikeVector) {
VectorType vty = VectorType::get({2, 4}, Builder().getF32Type());
TypedValue<VectorType> in_vec = Builder().create<vector::SplatOp>(
vty, Builder().create<arith::ConstantOp>(
vty.getElementType(), Builder().getF32FloatAttr(1.0f)));
TypedValue<VectorType> vec =
getFullLikeVector(Builder(), in_vec, Builder().getF32FloatAttr(2.0f));
EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, float{2.0f}));
}
TEST_F(VregUtilTest, GetZerosVector) {
VectorType vty = VectorType::get({2, 4}, Builder().getI32Type());
TypedValue<VectorType> vec = getZerosVector(Builder(), vty);
EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, int32_t{0}));
}
TEST_F(VregUtilTest, GetZerosLikeVector) {
VectorType vty = VectorType::get({2, 4}, Builder().getF32Type());
TypedValue<VectorType> in_vec = Builder().create<vector::SplatOp>(
vty, Builder().create<arith::ConstantOp>(
vty.getElementType(), Builder().getF32FloatAttr(1.0f)));
TypedValue<VectorType> vec = getZerosLikeVector(Builder(), in_vec);
EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, float{0.0f}));
}
TEST_F(VregUtilTest, GetX32VmaskByPaddingEndDim0) {
constexpr std::array<int64_t, 2> kTargetShape = {4, 8};
FailureOr<TypedValue<VectorType>> vec = getX32VmaskByPaddingEnd(
Builder(), /*padding=*/1, /*target_shape=*/kTargetShape,
/*dim=*/0);
ASSERT_TRUE(succeeded(vec));
auto cmp_op = dyn_cast<arith::CmpIOp>(vec.value().getDefiningOp());
ASSERT_TRUE(cmp_op != nullptr);
EXPECT_EQ(cmp_op.getPredicate(), arith::CmpIPredicate::slt);
auto iota_op = dyn_cast<tpu::IotaOp>(cmp_op.getLhs().getDefiningOp());
ASSERT_TRUE(iota_op != nullptr);
EXPECT_THAT(iota_op.getDimension(), Optional(Eq(0)));
EXPECT_THAT(
cmp_op.getRhs(),
IsConstantOpWithSplatValue(
VectorType::get(kTargetShape, Builder().getI32Type()), int32_t{3}));
}
TEST_F(VregUtilTest, GetX32VmaskByPaddingEndDim1) {
constexpr std::array<int64_t, 2> kTargetShape = {4, 8};
FailureOr<TypedValue<VectorType>> vec = getX32VmaskByPaddingEnd(
Builder(), /*padding=*/3, /*target_shape=*/kTargetShape,
/*dim=*/1);
ASSERT_TRUE(succeeded(vec));
auto cmp_op = dyn_cast<arith::CmpIOp>(vec.value().getDefiningOp());
ASSERT_TRUE(cmp_op != nullptr);
EXPECT_EQ(cmp_op.getPredicate(), arith::CmpIPredicate::slt);
auto iota_op = dyn_cast<tpu::IotaOp>(cmp_op.getLhs().getDefiningOp());
ASSERT_TRUE(iota_op != nullptr);
EXPECT_THAT(iota_op.getDimension(), Optional(Eq(1)));
EXPECT_THAT(
cmp_op.getRhs(),
IsConstantOpWithSplatValue(
VectorType::get(kTargetShape, Builder().getI32Type()), int32_t{5}));
}
} // namespace
} // namespace mlir::tpu

View File

@@ -33,4 +33,5 @@ except ImportError:
from mlir.dialects._ods_common import _cext # type: ignore[import-not-found]
_cext.globals.append_dialect_search_prefix("jax.jaxlib.mosaic.python")
# Add the parent module to the search prefix
_cext.globals.append_dialect_search_prefix(__name__[:__name__.rfind(".")])

View File

@@ -83,6 +83,7 @@ setup(
'cuda/*',
'cuda/nvvm/libdevice/libdevice*',
'mosaic/*.py',
'mosaic/dialect/gpu/*.py',
'mosaic/gpu/*.so',
'mosaic/python/*.py',
'mosaic/python/*.so',

View File

@@ -218,6 +218,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
dst_dir=mosaic_python_dir,
src_files=[
"__main__/jaxlib/mosaic/python/layout_defs.py",
"__main__/jaxlib/mosaic/python/mosaic_gpu.py",
"__main__/jaxlib/mosaic/python/tpu.py",
],
)
@@ -225,6 +226,16 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
patch_copy_mlir_import(
"__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir
)
mosaic_gpu_dir = jaxlib_dir / "mosaic" / "dialect" / "gpu"
os.makedirs(mosaic_gpu_dir)
patch_copy_mlir_import(
"__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_ops.py",
dst_dir=mosaic_gpu_dir,
)
patch_copy_mlir_import(
"__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_enums.py",
dst_dir=mosaic_gpu_dir,
)
copy_runfiles(
dst_dir=jaxlib_dir / "mlir",
@@ -316,6 +327,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_mosaic_gpu_ext.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_sdy.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}",

View File

@@ -300,6 +300,48 @@ class ColocatedPythonTest(jtu.JaxTestCase):
# around 15 seconds.
self.assertLess(elapsed_time, 10)
def testInputsWithDifferentDeviceOrders(self):
cpu_devices = _colocated_cpu_devices(jax.local_devices())[:2]
if len(cpu_devices) < 2:
self.skipTest("Not enough CPU devices")
@colocated_python.colocated_python
def add(x: jax.Array, y: jax.Array) -> jax.Array:
arrays = [
x.addressable_shards[1].data + y.addressable_shards[0].data,
x.addressable_shards[0].data + y.addressable_shards[1].data,
]
return jax.make_array_from_single_device_arrays(
y.shape, y.sharding, arrays
)
# The execution will use mixed device orders. We should specialize the
# function with devices to avoid the argument-dependent device selection.
add = add.specialize(devices=cpu_devices)
mesh1 = jax.sharding.Mesh([cpu_devices[0], cpu_devices[1]], "x")
sharding1 = jax.sharding.NamedSharding(
mesh1, jax.sharding.PartitionSpec("x")
)
mesh2 = jax.sharding.Mesh([cpu_devices[1], cpu_devices[0]], "x")
sharding2 = jax.sharding.NamedSharding(
mesh2, jax.sharding.PartitionSpec("x")
)
x = np.array([0, 2])
x = jax.device_put(x, sharding1)
y = np.array([4, 8])
y = jax.device_put(y, sharding2)
out = add(x, y)
self.assertEqual(out.sharding, sharding2)
out_device_list = [shard.device for shard in out.addressable_shards]
self.assertEqual(out_device_list, [cpu_devices[1], cpu_devices[0]])
out = jax.device_get(out)
np.testing.assert_equal(out, np.array([2 + 4, 0 + 8]))
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@@ -219,6 +219,29 @@ class DebugPrintTest(jtu.JaxTestCase):
[ 1 2 3 4 5 6 7 8 9 10 12 13 14]]
"""))
def test_debug_print_respects_numpy_printoptions(self):
def f(x):
with np.printoptions(precision=2, suppress=True):
jax.debug.print("{}", x)
x = np.array([1.2345, 2.3456, 1E-7])
# Default numpy print options:
with jtu.capture_stdout() as output:
jax.debug.print("{}", x)
self.assertEqual(output(), "[1.2345e+00 2.3456e+00 1.0000e-07]\n")
# Modified print options without JIT:
with jtu.capture_stdout() as output:
f(x)
jax.effects_barrier()
self.assertEqual(output(), "[1.23 2.35 0. ]\n")
# Modified print options with JIT:
with jtu.capture_stdout() as output:
jax.jit(f)(x)
jax.effects_barrier()
self.assertEqual(output(), "[1.23 2.35 0. ]\n")
class DebugPrintTransformationTest(jtu.JaxTestCase):

View File

@@ -30,6 +30,7 @@ from jax._src import abstract_arrays
from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import linear_util
from jax._src import prng
from jax._src import test_util as jtu
@@ -279,8 +280,13 @@ class FfiTest(jtu.JaxTestCase):
def testBackwardCompatSyntax(self):
def fun(x):
return jex.ffi.ffi_call("test_ffi", x, x, param=0.5)
with self.assertWarns(DeprecationWarning):
jax.jit(fun).lower(jnp.ones(5))
msg = "Calling ffi_call directly with input arguments is deprecated"
if deprecations.is_accelerated("jax-ffi-call-args"):
with self.assertRaisesRegex(ValueError, msg):
jax.jit(fun).lower(jnp.ones(5))
else:
with self.assertWarnsRegex(DeprecationWarning, msg):
jax.jit(fun).lower(jnp.ones(5))
def testInputOutputAliases(self):
def fun(x):

View File

@@ -158,12 +158,13 @@ class FftTest(jtu.JaxTestCase):
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol={np.complex64: 2e-6})
self._CompileAndCheck(jnp_fn, args_maker, atol={np.complex64: 2e-6},
rtol={np.float32: 2e-6})
# Test gradient for differentiable types.
if (config.enable_x64.value and
dtype in (float_dtypes if real and not inverse else inexact_dtypes)):
# TODO(skye): can we be more precise?
tol = 0.15
tol = 0.16
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
# check dtypes

View File

@@ -254,8 +254,6 @@ def sdpa_train_fp8(
class DotProductAttentionTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if jax.device_count() < 4:
self.skipTest("Requires more than 4 devices.")
try:
cudnn_version = check_cudnn_version()
except RuntimeError as e:
@@ -366,6 +364,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):
@jtu.run_on_devices("cuda")
def test_sdpa_inference(self):
if jax.device_count() < 4:
self.skipTest("Requires more than 4 devices.")
k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
query = jax.random.normal(
k1, (4, 1024, 4, 64), dtype=jnp.bfloat16)
@@ -407,6 +407,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):
@jtu.run_on_devices("cuda")
def test_sdpa_var_seq(self):
if jax.device_count() < 4:
self.skipTest("Requires more than 4 devices.")
self.skipTest("Skip before fixed.")
k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
query = jax.random.normal(
@@ -438,6 +440,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):
@jtu.run_on_devices("cuda")
def test_sdpa_broadcast_bias_and_dbias(self):
if jax.device_count() < 4:
self.skipTest("Requires more than 4 devices.")
try:
cudnn_version = check_cudnn_version()
except RuntimeError as e:
@@ -504,6 +508,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):
)
@jtu.run_on_devices("cuda")
def test_sdpa_dbias(self, batch_size: int):
if jax.device_count() < 4:
self.skipTest("Requires more than 4 devices.")
# cuDNN only supports dbias when batch size is 1. If the batch size is
# greater, dbias is silently set to all zeros. This test verifies this
# behavior for both vmap and regular use cases.
@@ -540,6 +546,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):
@jtu.run_on_devices("cuda")
def test_sdpa_sliding_window_length(self):
if jax.device_count() < 4:
self.skipTest("Requires more than 4 devices.")
k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
query = jax.random.normal(
k1, (4, 1024, 4, 64), dtype=jnp.bfloat16)
@@ -571,8 +579,43 @@ class DotProductAttentionTest(jtu.JaxTestCase):
self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5)
self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5)
@jtu.run_on_devices("cuda")
def test_sdpa_large_head_size(self):
try:
cudnn_version = check_cudnn_version()
except RuntimeError as e:
self.skipTest(str(e))
return
if cudnn_version < 90500:
self.skipTest("Requires >= cuDNN 9.5.0")
if not jtu.is_cuda_compute_capability_at_least("9.0"):
self.skipTest("Requires at least Hopper arch")
B, T, N, H = 2, 64, 2, 256
bf16 = jnp.bfloat16
keys = jax.random.split(jax.random.key(0), 4)
query = jax.random.normal(keys[0], (B, T, N, H), dtype=bf16)
key = jax.random.normal(keys[1], (B, T, N, H), dtype=bf16)
value = jax.random.normal(keys[2], (B, T, N, H), dtype=bf16)
grad = jax.random.normal(keys[3], (B, T, N, H), dtype=bf16)
sdpa_train_ans = jax.jit(partial(
sdpa_train, scale=1.0, mask_type=MaskType.CAUSAL, dropout_rate=0)
)
sdpa_train_rfc = jax.jit(partial(
sdpa_train_ref, scale=1.0, mask_type=MaskType.CAUSAL, dropout_rate=0)
)
out_ans, grads_ans = sdpa_train_ans(query, key, value, grad, None, None)
out_ref, grads_ref = sdpa_train_rfc(query, key, value, grad, None, None)
self.assertArraysAllClose(out_ref, out_ans)
self.assertArraysAllClose(grads_ref[0], grads_ans[0])
self.assertArraysAllClose(grads_ref[1], grads_ans[1])
self.assertArraysAllClose(grads_ref[2], grads_ans[2])
@jtu.run_on_devices("cuda")
def test_layouts(self):
if jax.device_count() < 4:
self.skipTest("Requires more than 4 devices.")
dtype = "bfloat16"
B, T, N, H = 4, 1024, 8, 128
S = T
@@ -600,6 +643,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):
self.assertArraysAllClose(dv_ref, _cvt_back(dv))
def test_sdpa_utils(self):
if jax.device_count() < 4:
self.skipTest("Requires more than 4 devices.")
test_cases = [
(1, 257, 64, 8905, False, True, True),
(1, 1024, 64, 8905, False, False, True),

View File

@@ -25,18 +25,10 @@ import jax.numpy as jnp
jax.config.parse_flags_with_absl()
# Helper class used to create a reference cycle.
class GarbageCollectionGuardTestNodeHelper:
def __init__(self, data):
self.data = data
self.next = None
def _create_array_cycle():
"""Creates a reference cycle of two jax.Arrays."""
n1 = GarbageCollectionGuardTestNodeHelper(jax.jit(lambda: jnp.ones( (2, 2)))())
n2 = GarbageCollectionGuardTestNodeHelper(jax.jit(lambda: jnp.zeros((2, 2)))())
n1 = jnp.ones((2, 2))
n2 = jnp.zeros((2, 2))
n1.next = n2
n2.next = n1
return weakref.ref(n1)

View File

@@ -87,6 +87,32 @@ python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_]
# uint64 is problematic because with any uint type it promotes to float:
int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64]
def _bitcast_uint4_to_uint8(operand):
# Note: assumes little-endian byte order.
assert operand.dtype == 'uint4'
operand = operand.astype('uint8')
return operand[..., ::2] + (operand[..., 1::2] << 4)
def _bitcast_uint8_to_uint4(operand):
# Note: assumes little-endian byte order.
assert operand.dtype == 'uint8'
result = np.zeros((*operand.shape[:-1], operand.shape[-1] * 2), dtype='uint4')
result[..., ::2] = (operand & 0b00001111).astype('uint4')
result[..., 1::2] = ((operand & 0b11110000) >> 4).astype('uint4')
return result
def np_view(arr, dtype):
# Implementation of np.ndarray.view() that works for int4/uint4
dtype = np.dtype(dtype)
nbits_in = dtypes.bit_width(arr.dtype)
nbits_out = dtypes.bit_width(dtype)
if nbits_in == 4:
arr = _bitcast_uint4_to_uint8(arr.view('uint4'))
if nbits_out == 4:
arr = _bitcast_uint8_to_uint4(arr.view('uint8'))
return arr.view(dtype)
def np_unique_backport(ar, return_index=False, return_inverse=False, return_counts=False,
axis=None, **kwds):
# Wrapper for np.unique, handling the change to inverse_indices in numpy 2.0
@@ -4244,9 +4270,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
@jtu.sample_product(
# Final dimension must be a multiple of 16 to ensure compatibility of all dtype pairs.
shape=[(0,), (32,), (2, 16)],
a_dtype=all_dtypes,
dtype=(*all_dtypes, None) if config.enable_x64.value else all_dtypes,
shape=[(0,), (64,), (2, 32)],
a_dtype=(jnp.int4, jnp.uint4, *all_dtypes),
dtype=((jnp.int4, jnp.uint4, *all_dtypes, None)
if config.enable_x64.value else (jnp.int4, jnp.uint4, *all_dtypes)),
)
def testView(self, shape, a_dtype, dtype):
if jtu.test_device_matches(["tpu"]):
@@ -4259,7 +4286,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.rng()
)
args_maker = lambda: [rng(shape, a_dtype)]
np_op = lambda x: np.asarray(x).view(dtype)
np_op = lambda x: np_view(x, dtype)
jnp_op = lambda x: jnp.asarray(x).view(dtype)
# Above may produce signaling nans; ignore warnings from invalid values.
with np.errstate(invalid='ignore'):
@@ -4268,9 +4295,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
@jtu.sample_product([
{'a_dtype': a_dtype, 'dtype': dtype}
for a_dtype in all_dtypes
for dtype in all_dtypes
if np.dtype(a_dtype).itemsize == np.dtype(dtype).itemsize
for a_dtype in [jnp.int4, jnp.uint4, *all_dtypes]
for dtype in [jnp.int4, jnp.uint4, *all_dtypes]
if dtypes.bit_width(a_dtype) == dtypes.bit_width(dtype)
])
def testViewScalar(self, a_dtype, dtype):
if jtu.test_device_matches(["tpu"]):

View File

@ -170,43 +170,49 @@ class LaxTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(numpy_op, op, args_maker)
@jtu.sample_product(
from_dtype=jtu.dtypes.all_floating + jtu.dtypes.all_integer + jtu.dtypes.all_unsigned,
to_dtype=jtu.dtypes.all_floating + jtu.dtypes.all_integer + jtu.dtypes.all_unsigned,
from_dtype=['int4', 'uint4'] + jtu.dtypes.all_floating + jtu.dtypes.all_integer + jtu.dtypes.all_unsigned,
to_dtype=['int4', 'uint4'] + jtu.dtypes.all_floating + jtu.dtypes.all_integer + jtu.dtypes.all_unsigned,
shape = [(), (2,), (2, 3)]
)
def testBitcastConvertType(self, from_dtype, to_dtype, shape):
rng = jtu.rand_default(self.rng())
itemsize_in = np.dtype(from_dtype).itemsize
itemsize_out = np.dtype(to_dtype).itemsize
if itemsize_in < itemsize_out:
shape = (*shape, itemsize_out // itemsize_in)
nbits_in = dtypes.bit_width(from_dtype)
nbits_out = dtypes.bit_width(to_dtype)
if nbits_in < nbits_out:
shape = (*shape, nbits_out // nbits_in)
args_maker = lambda: [rng(shape, from_dtype)]
op = lambda x: lax.bitcast_convert_type(x, to_dtype)
self._CompileAndCheck(op, args_maker)
jnp_op = lambda x: lax.bitcast_convert_type(x, to_dtype)
self._CompileAndCheck(jnp_op, args_maker)
# Test the shape and dtype of the output. We avoid testing the values here
# because the bitwise representation may vary from platform to platform.
out = op(*args_maker())
if itemsize_in == itemsize_out:
out = jnp_op(*args_maker())
if nbits_in == nbits_out:
expected_shape = shape
elif itemsize_in < itemsize_out:
elif nbits_in < nbits_out:
expected_shape = shape[:-1]
else:
expected_shape = (*shape, itemsize_in // itemsize_out)
expected_shape = (*shape, nbits_in // nbits_out)
self.assertEqual(out.dtype, to_dtype)
self.assertEqual(out.shape, expected_shape)
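To make the shape convention checked above concrete, a small illustrative sketch (not part of the test suite): bitcasting to a narrower type appends a trailing dimension whose size is the bit-width ratio, and bitcasting that trailing dimension back to a wider type folds it away.

import jax.numpy as jnp
from jax import lax

x = jnp.ones((2, 3), dtype=jnp.float32)
y = lax.bitcast_convert_type(x, jnp.uint8)    # 32 bits -> four 8-bit lanes
assert y.shape == (2, 3, 4) and y.dtype == jnp.uint8
z = lax.bitcast_convert_type(y, jnp.float32)  # trailing dim folded back
assert z.shape == (2, 3) and z.dtype == jnp.float32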
@jtu.sample_product(
[dict(from_dtype=from_dtype, to_dtype=to_dtype)
for from_dtype, to_dtype in itertools.product(
[np.float32, np.int32, "float32", "int32"], repeat=2)],
['int4', 'uint4', np.int8, np.uint8, np.int32, np.float16, np.float32],
repeat=2)],
shape=[(4,), (2, 4), (2, 3, 4)]
)
def testBitcastConvertTypeAgainstNumpy(self, from_dtype, to_dtype):
def testBitcastConvertTypeAgainstNumpy(self, from_dtype, to_dtype, shape):
nbits_in = dtypes.bit_width(from_dtype)
nbits_out = dtypes.bit_width(to_dtype)
if nbits_in < nbits_out:
shape = (*shape, nbits_out // nbits_in)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng((2, 3), from_dtype)]
op = lambda x: lax.bitcast_convert_type(x, to_dtype)
numpy_op = lambda x: lax_reference.bitcast_convert_type(x, to_dtype)
self._CheckAgainstNumpy(numpy_op, op, args_maker)
args_maker = lambda: [rng(shape, from_dtype)]
jnp_op = lambda x: lax.bitcast_convert_type(x, to_dtype)
np_op = lambda x: lax_reference.bitcast_convert_type(x, to_dtype)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
@jtu.sample_product(
[dict(from_dtype=from_dtype, to_dtype=to_dtype)
@ -1379,83 +1385,6 @@ class LaxTest(jtu.JaxTestCase):
numpy_op = lambda x, y: lax_reference.dot_general(x, y, dimension_numbers)
self._CheckAgainstNumpy(numpy_op, op, args_maker)
def testRaggedAllToAllErrors(self):
operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input and output shapes must be equal, except for the outermost dimension."):
jax.jit(lax.ragged_all_to_all).lower(operand, jnp.array([[0.0], [0.0], [0.0], [0.0], [0.0]], dtype=jnp.float32), input_offsets, send_sizes, output_offsets, recv_sizes)
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input_offsets must be integer type."):
jax.jit(lax.ragged_all_to_all).lower(operand, output, jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32), send_sizes, output_offsets, recv_sizes)
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all send_sizes must be integer type."):
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), output_offsets, recv_sizes)
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all output_offsets must be integer type."):
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32), recv_sizes)
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all recv_sizes must be integer type."):
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32))
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input_offsets must be rank 1 with positive dimension size, but got shape (1, 3)"):
jax.jit(lax.ragged_all_to_all).lower(operand, output, jnp.array([[0, 1, 3]], dtype=jnp.int32), send_sizes, output_offsets, recv_sizes)
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input_offsets must be rank 1 with positive dimension size, but got shape (0,)"):
jax.jit(lax.ragged_all_to_all).lower(operand, output, jnp.array([], dtype=jnp.int32), send_sizes, output_offsets, recv_sizes)
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all send_sizes must be rank 1 with positive dimension size, but got shape (1, 3)"):
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, jnp.array([[1, 2, 3]], dtype=jnp.int32), output_offsets, recv_sizes)
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all send_sizes must be rank 1 with positive dimension size, but got shape (0,)"):
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, jnp.array([], dtype=jnp.int32), output_offsets, recv_sizes)
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all output_offsets must be rank 1 with positive dimension size, but got shape (1, 3)"):
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, jnp.array([[0, 1, 3]], dtype=jnp.int32), recv_sizes)
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all output_offsets must be rank 1 with positive dimension size, but got shape (0,)"):
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, jnp.array([], dtype=jnp.int32), recv_sizes)
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all recv_sizes must be rank 1 with positive dimension size, but got shape (1, 3)"):
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, jnp.array([[1, 2, 3]], dtype=jnp.int32))
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all recv_sizes must be rank 1 with positive dimension size, but got shape (0,)"):
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, jnp.array([], dtype=jnp.int32))
def testRaggedAllToAll(self):
operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
mlir_module = jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes).as_text()
self.assertIn("stablehlo.custom_call @ragged_all_to_all", mlir_module)
self.assertIn(
"backend_config = {replica_groups = dense<[[0, 1, 2]]> :"
" tensor<1x3xi64>}}",
mlir_module,
)
@jtu.sample_product(
[
{'m': 5, 'k': 4, 'n': 3, 'num_groups': 1},
{'m': 10, 'k': 9, 'n': 8, 'num_groups': 2},
],
dtype=jtu.dtypes.numeric,
)
def testRaggedDot(self, m, k, n, num_groups, dtype):
"""Tests ragged_dot.
The ragged_dot is tested against numpy reference implementation, and by running JAX compilation.
Raises:
SkipTest: in the case dtype is not supported.
"""
lhs_shape = (m, k)
rhs_shape = (num_groups, k, n)
def group_sizes(m, num_groups):
ends_no_final = jnp.sort(self.rng().choice(m, size=num_groups - 1))
ends = jnp.concatenate([ends_no_final, jnp.array([m], dtype=ends_no_final.dtype)])
starts = jnp.concatenate([jnp.zeros(1, dtype=ends_no_final.dtype), ends_no_final])
return ends - starts
rng = jtu.rand_small(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype), group_sizes(m, num_groups)]
self._CompileAndCheck(lax.ragged_dot, args_maker)
self._CheckAgainstNumpy(lax_reference.ragged_dot, lax.ragged_dot, args_maker)
@jtu.sample_product(
shape=[(), (2, 3)],
dtype=lax_test_util.default_dtypes,
@ -4457,7 +4386,7 @@ class FunctionAccuracyTest(jtu.JaxTestCase):
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag')
elif name == 'log10':
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag', 'zero.imag')
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag')
elif name == 'exp':
regions_with_inaccuracies_keep('pos.imag', 'pinf.imag', 'mpos.imag')
@ -4751,5 +4680,181 @@ class CompositeTest(jtu.JaxTestCase):
):
grad(my_square)(1.0)
class RaggedTest(jtu.JaxTestCase):
def testRaggedAllToAll(self):
operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
mlir_module = jax.jit(lax.ragged_all_to_all).lower(
operand, output, input_offsets, send_sizes, output_offsets,
recv_sizes).as_text()
self.assertIn("stablehlo.custom_call @ragged_all_to_all", mlir_module)
self.assertIn(
"backend_config = {replica_groups = dense<[[0, 1, 2]]> :"
" tensor<1x3xi64>}}",
mlir_module,
)
def testRaggedAllToAllErrors(self):
operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
with self.assertRaisesWithLiteralMatch(
ValueError,
"ragged_all_to_all input and output shapes must be equal, except for"
" the outermost dimension.",
):
jax.jit(lax.ragged_all_to_all).lower(
operand,
jnp.array([[0.0], [0.0], [0.0], [0.0], [0.0]], dtype=jnp.float32),
input_offsets, send_sizes, output_offsets, recv_sizes)
with self.assertRaisesWithLiteralMatch(
ValueError, "ragged_all_to_all input_offsets must be integer type."
):
jax.jit(lax.ragged_all_to_all).lower(
operand, output, jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32),
send_sizes, output_offsets, recv_sizes)
with self.assertRaisesWithLiteralMatch(
ValueError, "ragged_all_to_all send_sizes must be integer type."
):
jax.jit(lax.ragged_all_to_all).lower(
operand, output, input_offsets,
jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), output_offsets,
recv_sizes)
with self.assertRaisesWithLiteralMatch(
ValueError, "ragged_all_to_all output_offsets must be integer type."
):
jax.jit(lax.ragged_all_to_all).lower(
operand, output, input_offsets, send_sizes,
jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32), recv_sizes)
with self.assertRaisesWithLiteralMatch(
ValueError, "ragged_all_to_all recv_sizes must be integer type."
):
jax.jit(lax.ragged_all_to_all).lower(
operand, output, input_offsets, send_sizes, output_offsets,
jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32))
with self.assertRaisesWithLiteralMatch(
ValueError,
"ragged_all_to_all input_offsets must be rank 1 with positive dimension"
" size, but got shape (1, 3)",
):
jax.jit(lax.ragged_all_to_all).lower(
operand, output, jnp.array([[0, 1, 3]], dtype=jnp.int32), send_sizes,
output_offsets, recv_sizes)
with self.assertRaisesWithLiteralMatch(
ValueError,
"ragged_all_to_all input_offsets must be rank 1 with positive dimension"
" size, but got shape (0,)",
):
jax.jit(lax.ragged_all_to_all).lower(
operand, output, jnp.array([], dtype=jnp.int32), send_sizes,
output_offsets, recv_sizes)
with self.assertRaisesWithLiteralMatch(
ValueError,
"ragged_all_to_all send_sizes must be rank 1 with positive dimension"
" size, but got shape (1, 3)",
):
jax.jit(lax.ragged_all_to_all).lower(
operand, output, input_offsets,
jnp.array([[1, 2, 3]], dtype=jnp.int32), output_offsets, recv_sizes)
with self.assertRaisesWithLiteralMatch(
ValueError,
"ragged_all_to_all send_sizes must be rank 1 with positive dimension"
" size, but got shape (0,)",
):
jax.jit(lax.ragged_all_to_all).lower(
operand, output, input_offsets, jnp.array([], dtype=jnp.int32),
output_offsets, recv_sizes)
with self.assertRaisesWithLiteralMatch(
ValueError,
"ragged_all_to_all output_offsets must be rank 1 with positive"
" dimension size, but got shape (1, 3)",
):
jax.jit(lax.ragged_all_to_all).lower(
operand, output, input_offsets, send_sizes,
jnp.array([[0, 1, 3]], dtype=jnp.int32), recv_sizes)
with self.assertRaisesWithLiteralMatch(
ValueError,
"ragged_all_to_all output_offsets must be rank 1 with positive"
" dimension size, but got shape (0,)",
):
jax.jit(lax.ragged_all_to_all).lower(
operand, output, input_offsets, send_sizes,
jnp.array([], dtype=jnp.int32), recv_sizes)
with self.assertRaisesWithLiteralMatch(
ValueError,
"ragged_all_to_all recv_sizes must be rank 1 with positive dimension"
" size, but got shape (1, 3)",
):
jax.jit(lax.ragged_all_to_all).lower(
operand, output, input_offsets, send_sizes, output_offsets,
jnp.array([[1, 2, 3]], dtype=jnp.int32))
with self.assertRaisesWithLiteralMatch(
ValueError,
"ragged_all_to_all recv_sizes must be rank 1 with positive dimension"
" size, but got shape (0,)",
):
jax.jit(lax.ragged_all_to_all).lower(
operand, output, input_offsets, send_sizes, output_offsets,
jnp.array([], dtype=jnp.int32))
@jtu.sample_product(
[
{'m': 5, 'k': 4, 'n': 3, 'num_groups': 1},
{'m': 10, 'k': 9, 'n': 8, 'num_groups': 2},
],
dtype=jtu.dtypes.numeric,
)
def testRaggedDot(self, m, k, n, num_groups, dtype):
"""Tests ragged_dot.
ragged_dot is checked against a numpy reference implementation and is also
run through JAX compilation.
Raises:
SkipTest: in the case dtype is not supported.
"""
lhs_shape = (m, k)
rhs_shape = (num_groups, k, n)
def group_sizes(m, num_groups):
ends_no_final = jnp.sort(self.rng().choice(m, size=num_groups - 1))
ends = jnp.concatenate(
[ends_no_final, jnp.array([m], dtype=ends_no_final.dtype)])
starts = jnp.concatenate(
[jnp.zeros(1, dtype=ends_no_final.dtype), ends_no_final])
return ends - starts
rng = jtu.rand_small(self.rng())
args_maker = lambda: [
rng(lhs_shape, dtype),
rng(rhs_shape, dtype),
group_sizes(m, num_groups),
]
self._CompileAndCheck(lax.ragged_dot, args_maker)
self._CheckAgainstNumpy(
lax_reference.ragged_dot, lax.ragged_dot, args_maker)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
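For readers unfamiliar with the op exercised by testRaggedDot above, here is a rough numpy sketch of its semantics under the shapes used in the test (an illustration only, not the lax_reference implementation): group g multiplies group_sizes[g] consecutive rows of lhs by rhs[g].

import numpy as np

def ragged_dot_sketch(lhs, rhs, group_sizes):
  # lhs: (m, k), rhs: (num_groups, k, n), group_sizes: (num_groups,) summing to m.
  out = np.zeros((lhs.shape[0], rhs.shape[-1]), dtype=lhs.dtype)
  start = 0
  for g, size in enumerate(group_sizes):
    out[start:start + size] = lhs[start:start + size] @ rhs[g]
    start += size
  return out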

View File

@ -586,7 +586,7 @@ class DialectLoweringTest(MosaicGpuTest):
llvm.UndefOp(workgroup_ptr_ty()),
arrival_count=1,
)
mgpu.lower_mgpu_dialect(self.module)
mgpu.lower_mgpu_dialect(self.module, None)
self.assertEmpty(
list(filter(is_mosaic_gpu_op, self.module.body.operations))
@ -604,7 +604,7 @@ class DialectLoweringTest(MosaicGpuTest):
arrival_count=1,
)
scf.yield_([])
mgpu.lower_mgpu_dialect(self.module)
mgpu.lower_mgpu_dialect(self.module, None)
self.assertEmpty(
list(filter(is_mosaic_gpu_op, if_op.then_block.operations))
@ -626,7 +626,7 @@ class DialectLoweringTest(MosaicGpuTest):
memref.copy(barriers_ref, barriers_ref)
self.assertTrue(self.module.operation.verify())
mgpu.lower_mgpu_dialect(self.module)
mgpu.lower_mgpu_dialect(self.module, None)
self.assertTrue(self.module.operation.verify())
all_mbarrier_init_shared_ops = find_if(
@ -654,7 +654,7 @@ class DialectLoweringTest(MosaicGpuTest):
with self.assertRaisesRegex(
ValueError, "missing a layout and can not be lowered"
):
mgpu.lower_mgpu_dialect(self.module)
mgpu.lower_mgpu_dialect(self.module, None)
def test_lowering_eliminates_layouts(self):
shape = (4, 128)
@ -670,7 +670,7 @@ class DialectLoweringTest(MosaicGpuTest):
)
])
mgpu.lower_mgpu_dialect(self.module)
mgpu.lower_mgpu_dialect(self.module, None)
all_ops_with_layouts = find_if(
self.module,
@ -691,7 +691,7 @@ class DialectLoweringTest(MosaicGpuTest):
vector.store(array, ref, [zero_index, zero_index])
mgpu.infer_layout(self.module)
mgpu.lower_mgpu_dialect(self.module)
mgpu.lower_mgpu_dialect(self.module, None)
all_loads = find_if(
self.module,

View File

@ -47,6 +47,8 @@ except ImportError:
z = 2
else:
import jax.experimental.mosaic.gpu as mgpu
from jax.experimental.mosaic.gpu import core
from jax.experimental.mosaic.gpu import launch_context
from jax.experimental.mosaic.gpu import utils as utils
from jax.experimental.mosaic.gpu import profiler
from jax.experimental.mosaic.gpu.utils import * # noqa: F403
@ -1937,6 +1939,103 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
self.assertArraysEqual(jax.jit(kernel)(x, y), x + y)
def test_pointwise_kernel_with_tma(self):
def add(
ctx: launch_context.LaunchContext,
a_gmem_ref: ir.Value,
b_gmem_ref: ir.Value,
result_gmem_ref: ir.Value,
smem: list[ir.Value],
):
del ctx
a_smem_ref, b_smem_ref, result_smem_ref = smem[:3]
tma_barrier = smem[3]
memref_type = ir.MemRefType(a_gmem_ref.type)
shape = memref_type.shape
elt_type = memref_type.element_type
zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0)
with utils.single_thread():
memref_bytes = utils.bytewidth(elt_type) # Also correct if rank == 0
for size in shape:
memref_bytes *= size
nvvm.mbarrier_arrive_expect_tx_shared(
tma_barrier.get_ptr(),
arith.constant(ir.IntegerType.get_signless(32), 2*memref_bytes),
)
# GMEM -> SMEM
mgpu_dialect.async_load(
source=a_gmem_ref,
destination=a_smem_ref,
barrier=tma_barrier.as_dialect_barrier(),
indices=[zero_i32, zero_i32],
slice_lengths=shape,
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
arrive=False,
)
mgpu_dialect.async_load(
source=b_gmem_ref,
destination=b_smem_ref,
barrier=tma_barrier.as_dialect_barrier(),
indices=[zero_i32, zero_i32],
slice_lengths=shape,
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
arrive=False,
)
tma_barrier.wait()
zero_index = arith.constant(ir.IndexType.get(), 0)
# SMEM -> registers
ab_type = ir.VectorType.get(shape, elt_type)
a = vector.load(ab_type, a_smem_ref, [zero_index, zero_index])
b = vector.load(ab_type, b_smem_ref, [zero_index, zero_index])
# Computation
add = arith.addf(arith.addf(a, b), b)
# Registers -> SMEM
vector.store(add, result_smem_ref, [zero_index, zero_index])
# SMEM -> GMEM
mgpu_dialect.async_store(
source=result_smem_ref,
destination=result_gmem_ref,
indices=[zero_i32, zero_i32],
slice_lengths=shape,
transforms=ir.ArrayAttr.get([]),
)
nvvm.cp_async_bulk_wait_group(0)
utils.warpgroup_barrier()
dtype = jnp.bfloat16
shape = (128, 128)
jax_shape = jax.ShapeDtypeStruct(shape, dtype)
kernel = mgpu.as_gpu_kernel(
add,
grid=(1, 1, 1),
block=(128, 1, 1),
in_shape=(jax_shape, jax_shape),
out_shape=jax_shape,
smem_scratch_shape=[
jax_shape,
jax_shape,
jax_shape,
core.TMABarrier(1),
],
thread_semantics=mgpu.ThreadSemantics.Warpgroup,
)
x = self.prng.uniform(-1, 1, shape).astype(dtype)
y = self.prng.uniform(-1, 1, shape).astype(dtype)
self.assertArraysEqual(jax.jit(kernel)(x, y), x + y + y)
class UtilsTest(TestCase):
@parameterized.parameters(

View File

@ -306,30 +306,88 @@ class MutableArrayErrorsTest(jtu.JaxTestCase):
ValueError, "traced for cond returned a mutable array reference of type"):
jax.lax.cond(True, lambda: core.mutable_array(1.0), lambda: core.mutable_array(2.0))
# TODO test_argument_aliases_cond
# TODO test_closure_and_argument_aliases_cond
def test_argument_aliases_cond(self):
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(ValueError, r"for cond.*at both x1 and x2"):
jax.lax.cond(True, lambda x1, x2: ..., lambda x1, x2: ..., x_ref, x_ref)
# TODO test_return_from_custom_jvp/vjp
# TODO test_argument_aliases_custom_jvp/vjp
# TODO test_closure_and_argument_aliases_custom_jvp/vjp
def test_closure_and_argument_aliases_cond(self):
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(
ValueError, r"closed over and passed as the argument y_ref"):
jax.lax.cond(True,
lambda y_ref: x_ref[...] + y_ref[...],
lambda y_ref: x_ref[...] + y_ref[...],
x_ref)
# TODO(mattjj): enable when cond works with mutable arrays
# @parameterized.parameters([False, True])
# def test_cond_both_branches_close_over_same_mutable_array(self, jit):
# # see also test_cond_with_ref_reuse in state_test.py
# x_ref = core.mutable_array(0.)
# def f(pred):
# def true_fun():
# x_ref[()] = 1.
# def false_fun():
# x_ref[()] = 2.
# jax.lax.cond(pred, true_fun, false_fun)
# if jit:
# f = jax.jit(f)
# out_true = f(True)
# self.assertAllClose(x_ref[...], 1.)
# out_false = f(False)
# self.assertAllClose(x_ref[...], 2.)
@parameterized.parameters([False, True])
def test_return_from_custom_vjp_primal(self, jit):
@jax.custom_vjp
def f(ref):
return ref
f.defvjp(lambda ref: ..., lambda *_: ...)
if jit:
f = jax.jit(f)
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(
ValueError, "custom_vjp primal function"):
f(x_ref)
@parameterized.parameters([False, True])
def test_return_from_custom_vjp_fwd(self, jit):
@jax.custom_vjp
def f(x, ref):
return x
f.defvjp(lambda x, ref: (x, ref), lambda ref, g: g)
if jit:
f = jax.jit(f)
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(
ValueError, "custom_vjp fwd function"):
jax.vjp(f, 3., x_ref)
@parameterized.parameters([False, True])
def test_argument_aliases_custom_vjp_primal(self, jit):
@jax.custom_vjp
def f(x_ref, y_ref):
...
f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None))
if jit:
f = jax.jit(f)
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(ValueError, "x_ref and y_ref"):
f(x_ref, x_ref)
@parameterized.parameters([False, True])
def test_argument_aliases_custom_vjp_fwd(self, jit):
@jax.custom_vjp
def f(x_ref, y_ref):
...
f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None))
if jit:
f = jax.jit(f)
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(ValueError, "x_ref and y_ref"):
jax.vjp(f, x_ref, x_ref)
# TODO(mattjj): add test test_closure_and_argument_aliases_custom_vjp
@parameterized.parameters([False, True])
def test_cond_both_branches_close_over_same_mutable_array(self, jit):
# see also test_cond_with_ref_reuse in state_test.py
x_ref = core.mutable_array(0.)
def f(pred):
def true_fun():
x_ref[()] = 1.
def false_fun():
x_ref[()] = 2.
jax.lax.cond(pred, true_fun, false_fun)
if jit:
f = jax.jit(f)
out_true = f(True)
self.assertAllClose(x_ref[...], 1.)
out_false = f(False)
self.assertAllClose(x_ref[...], 2.)
if __name__ == '__main__':

View File

@ -39,8 +39,8 @@ jax_multiplatform_test(
"tpu",
],
enable_configs = [
"gpu_a100_x32",
"gpu_h100_x32",
"gpu_a100",
"gpu_h100",
],
shard_count = {
"cpu": 8,

View File

@ -39,6 +39,15 @@ except ImportError:
jax.config.parse_flags_with_absl()
def _fori_loop(force_while: bool, lb, ub, body, init):
if force_while:
# Wrapping the bounds in jnp.asarray makes the while-or-scan matcher treat
# them as dynamic, which forces lowering to the while primitive.
lb, ub = jnp.asarray(lb), jnp.asarray(ub)
return jax.lax.fori_loop(lb, ub, body, init)
class PallasTest(jtu.JaxTestCase):
def setUp(self):
@ -705,19 +714,21 @@ class PallasCallTest(PallasTest):
x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128)
np.testing.assert_array_equal(kernel(x), x)
def test_fori_loop_array(self):
@parameterized.parameters(False, True)
def test_fori_loop_array(self, force_while):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
)
def kernel(x_ref, o_ref):
# Equivalent to x_ref[...] + 2 + 3.
o_ref[...] = jax.lax.fori_loop(2, 4, lambda i, x: x + i, x_ref[...])
o_ref[...] = _fori_loop(force_while, 2, 4, lambda i, x: x + i, x_ref[...])
x = jnp.arange(256).astype(jnp.int32)
np.testing.assert_array_equal(kernel(x), x + 2 + 3)
def test_fori_loop_scalar(self):
@parameterized.parameters(False, True)
def test_fori_loop_scalar(self, force_while):
@functools.partial(
pl.pallas_call,
@ -726,7 +737,7 @@ class PallasCallTest(PallasTest):
def kernel(o_ref):
# Equivalent to 2 + 3.
o_ref[...] = jax.lax.broadcast(
jax.lax.fori_loop(2, 4, lambda i, x: x + i, 0), o_ref.shape
_fori_loop(force_while, 2, 4, lambda i, x: x + i, 0), o_ref.shape
)
np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32))
@ -747,7 +758,8 @@ class PallasCallTest(PallasTest):
np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32))
def test_fori_loop_tuple(self):
@parameterized.parameters(False, True)
def test_fori_loop_tuple(self, force_while):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
@ -761,14 +773,15 @@ class PallasCallTest(PallasTest):
# Equivalent to 3 * (0 + 1).
o_ref[...] = jax.lax.broadcast(
sum(jax.lax.fori_loop(2, 4, body, (0, 0, 0))), o_ref.shape
sum(_fori_loop(force_while, 2, 4, body, (0, 0, 0))), o_ref.shape
)
np.testing.assert_array_equal(
kernel(), jnp.full([256], 3 * (0 + 1), dtype=jnp.int32)
)
def test_fori_loop_indexed_store(self):
@parameterized.parameters(False, True)
def test_fori_loop_indexed_store(self, force_while):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([4, 128], jnp.float32),
@ -778,7 +791,7 @@ class PallasCallTest(PallasTest):
o_ref[idx] = x_ref[idx] + y_ref[idx]
return ()
jax.lax.fori_loop(0, 4, body, ())
_fori_loop(force_while, 0, 4, body, ())
x = jnp.arange(4 * 128).reshape(4, 128).astype(jnp.float32)
y = x + 1

View File

@ -2127,6 +2127,21 @@ class OpsTest(PallasBaseTest):
)
self.assertTrue(acceptable_errors, "Failed with error: " + str(e))
@parameterized.parameters((128, 128), (256, 256))
def test_jnp_diagonal_pallas(self, n, m):
if jtu.test_device_matches(["gpu"]):
# TODO(mvoz): platform_index_p on GPU
self.skipTest("Not implemented on GPU")
x = jnp.arange(n * m, dtype=jnp.float32).reshape((n, m))
def kernel(x_ref, out_ref):
out_ref[...] = jnp.diagonal(x_ref[...])
out = self.pallas_call(
kernel, out_shape=jax.ShapeDtypeStruct((n,), jnp.float32)
)(x)
np.testing.assert_array_equal(out, np.diagonal(x))
class OpsInterpretTest(OpsTest):
INTERPRET = True

View File

@ -1688,8 +1688,8 @@ class PallasControlFlowTest(PallasBaseTest):
def body(state):
i, s = state
sl = jax.lax.div(i, 128)
l = jax.lax.rem(i, 128)
sl = jax.lax.div(i, jnp.astype(128, i.dtype))
l = jax.lax.rem(i, jnp.astype(128, i.dtype))
v = pl.load(x_ref, (0, sl, l))
return i + 1, s + v

View File

@ -312,6 +312,46 @@ class OpsTest(PallasBaseTest):
expected = reduce_func(x, axis, keepdims=True)
np.testing.assert_array_equal(result, expected)
@parameterized.product(
msk_dtype=[jnp.float32, jnp.bfloat16, jnp.int8],
dtype=[jnp.float32, jnp.bfloat16],
)
def test_i1_relayout_with_bitwidth_change(self, msk_dtype, dtype):
# TODO(jevinjiang): Remove after 12 weeks have passed.
if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
self.skipTest("Requires libtpu built after 2024-12-19")
shape = (129, 129)
msk_bitwidth = pallas_utils.dtype_bitwidth(msk_dtype)
bitwidth = pallas_utils.dtype_bitwidth(dtype)
if (
(jtu.get_tpu_version() > 5 and msk_bitwidth < 8)
or (jtu.get_tpu_version() == 5 and msk_bitwidth not in (8, 32))
or (jtu.get_tpu_version() < 5 and msk_bitwidth < 32)
):
self.skipTest(
"Not implemented: cast vector to mask with bitwidth =="
f" {msk_bitwidth}"
)
if jtu.get_tpu_version() <= 5 and bitwidth < 32:
self.skipTest(f"Not implemented: comparison with bitwidth == {bitwidth}")
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct(shape, dtype),
)
def kernel(x_ref, mask_ref, o_ref):
zeros = jnp.zeros_like(x_ref)
o_ref[...] = jnp.where(mask_ref[...], x_ref[...], zeros)
mask = jax.random.bernoulli(jax.random.key(1234), 0.5, shape).astype(
msk_dtype
)
x = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape) + 1
out = kernel(x, mask)
expected = jnp.where(mask, x, jnp.zeros_like(x))
self.assertArraysEqual(out, expected)
class OpsInterpretTest(OpsTest):
INTERPRET = True

View File

@ -2555,9 +2555,7 @@ class MiscellaneousTest(PallasBaseTest):
np.testing.assert_array_equal(out, np.reshape(x, (8, 128)))
@only_passes_in_interpret()
def test_retiling2(self):
"""b/348040767"""
x = np.arange(1 * 8 * 1024, dtype=jnp.bfloat16).reshape(1, 8, 1024)
def kernel(x_ref, out_ref):

View File

@ -95,8 +95,8 @@ class PgleTest(jtu.JaxTestCase):
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
compiler_options={
'xla_gpu_enable_latency_hiding_scheduler': 'True',
# TODO(patrios): Remove this flag once b/376647494 is fixed.
'xla_gpu_graph_min_graph_size': '100000',
# TODO(b/376647494): Remove this flag once the bug is fixed.
'xla_gpu_enable_command_buffer': '',
},
)
def f(x):
@ -133,8 +133,6 @@ class PgleTest(jtu.JaxTestCase):
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
compiler_options={
'xla_gpu_enable_latency_hiding_scheduler': 'True',
# TODO(patrios): Remove this flag once b/376647494 is fixed.
'xla_gpu_graph_min_graph_size': '100000',
'xla_dump_to': dump_dir,
'xla_gpu_experimental_dump_fdo_profiles': 'True'
},
@ -217,8 +215,6 @@ class PgleTest(jtu.JaxTestCase):
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
compiler_options={
'xla_gpu_enable_latency_hiding_scheduler': 'True',
# TODO(patrios): Remove this flag once b/376647494 is fixed.
'xla_gpu_graph_min_graph_size': '100000',
'xla_dump_to': dump_dir,
'xla_gpu_experimental_dump_fdo_profiles': 'True'
},

View File

@ -2207,6 +2207,23 @@ class ShardMapTest(jtu.JaxTestCase):
#
# f(x) # don't crash
def test_partial_auto_of_random_keys(self):
if config.use_shardy_partitioner.value:
self.skipTest('Shardy does not support full-to-shard.')
mesh = jtu.create_mesh((4, 2), ('i', 'j'))
keys = jax.random.split(jax.random.key(0), 8)
@jax.jit
def f(x):
return shard_map(lambda k: k,
mesh, in_specs=P('i'), out_specs=P('i'),
check_rep=False, auto=frozenset({'j'}))(keys)
y = f(keys) # don't crash
self.assertAllClose(jax.random.key_data(y), jax.random.key_data(keys),
check_dtypes=False)
def test_vmap_grad_shmap_spmd_axis_name_residuals(self):
# https://github.com/jax-ml/jax/pull/21032
mesh = jtu.create_mesh((4, 2), ('i', 'j'))

View File

@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result.
XLA_COMMIT = "b44f55da3dac449f03466815ac431474f86fd73f"
XLA_SHA256 = "f3d37257b970fd2993cbbc9185c2271910775f752d7c3bdd1828b8f663df1ff1"
XLA_COMMIT = "c12c1148585e00985d5e1ccf2bc0768862b7df77"
XLA_SHA256 = "44396bdac8b8bc7cba958691ae8df040ba91ddb26513aed37656d6db479dd06c"
def repo():
tf_http_archive(