Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)

Commit 5e3a692d36: Merge branch 'jax-ml:main' into activation-offloading-doc
.github/workflows/upstream-nightly.yml (2 changes)

@ -85,7 +85,7 @@ jobs:
   && steps.status.outcome == 'failure'
   && github.event_name == 'schedule'
   && github.repository == 'jax-ml/jax'
-  uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
+  uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0
   with:
     name: output-${{ matrix.python-version }}-log.jsonl
     path: output-${{ matrix.python-version }}-log.jsonl
.github/workflows/wheel_win_x64.yml (2 changes)

@ -45,7 +45,7 @@ jobs:
   --bazel_options=--config=win_clang `
   --verbose

-  - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
+  - uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0
   with:
     name: wheels-${{ matrix.os }}-${{ matrix.pyver }}
     path: ${{ github.workspace }}\dist\*.whl
.github/workflows/windows_ci.yml (2 changes)

@ -54,7 +54,7 @@ jobs:
   --bazel_options=--config=win_clang `
   --verbose

-  - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
+  - uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0
   with:
     name: wheels
     path: ${{ github.workspace }}\jax\dist\*.whl
CHANGELOG.md (32 changes)

@ -4,6 +4,10 @@ Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html).
For the changes specific to the experimental Pallas APIs,
see {ref}`pallas-changelog`.

JAX follows Effort-based versioning; for a discussion of this and JAX's API
compatibility policy, refer to {ref}`api-compatibility`. For the Python and
NumPy version support policy, refer to {ref}`version-support-policy`.

<!--
Remember to align the itemized text with the first line of an item within a list.

@ -12,6 +16,16 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.

## Unreleased

* Changes:
  * The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
    supported version until June 2025.

* New Features
  * {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,
    {func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
    transforms in more than 3 dimensions, which was previously the limit. See
    {jax-issue}`#25606` for more details.

* Deprecations
  * From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
    are now deprecated, having been replaced by symbols of the same name

@ -20,21 +34,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Deletions
  * `jax_enable_memories` flag has been deleted and the behavior of that flag
    is on by default.

* Changes:
  * The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
    supported version until June 2025.

* Deprecations
  * From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
    are now deprecated, having been replaced by symbols of the same name
    in {mod}`jax.core`.

* New Features
  * {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,
    {func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
    transforms in more than 3 dimensions, which was previously the limit. See
    {jax-issue}`#25606` for more details.
  * From `jax.lib.xla_client`, the previously-deprecated `Device` and
    `XlaRuntimeError` symbols have been removed; instead use `jax.Device`
    and `jax.errors.JaxRuntimeError` respectively.

## jax 0.4.38 (Dec 17, 2024)
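To illustrate the `jax.numpy.fft.fftn` change called out in the unreleased notes above, here is a minimal sketch; it assumes a JAX build that already includes this change:

```python
import jax.numpy as jnp

# A 4-D transform, which previously raised because only up to 3 FFT
# dimensions were supported.
x = jnp.ones((2, 3, 4, 5))
y = jnp.fft.fftn(x, axes=(0, 1, 2, 3))
print(y.shape)  # (2, 3, 4, 5)
```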
@ -580,11 +580,11 @@ async def main():
  if args.configure_only:
    logging.info("--configure_only is set so not running any Bazel commands.")
  else:
    output_path = args.output_path
    logger.debug("Artifacts output directory: %s", output_path)

    # Wheel build command execution
    for wheel in args.wheels.split(","):
      output_path = args.output_path
      logger.debug("Artifacts output directory: %s", output_path)

      # Allow CUDA/ROCm wheels without the "jax-" prefix.
      if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel:
        wheel = "jax-" + wheel
@ -8,8 +8,15 @@ JAX is constantly evolving, and we want to be able to make improvements to its
APIs. That said, we want to minimize churn for the JAX user community, and we
try to make breaking changes rarely.

JAX follows a 3 month deprecation policy. When an incompatible change is made
to an API, we will make our best effort to obey the following procedure:
## JAX Versioning
JAX uses [Effort-based versioning](https://jacobtomlinson.dev/effver/) (see
{ref}`jep-effver`), and is currently in the Zero version phase.
This means that for version `0.X.Y`, incrementing `Y` will introduce minor
breaking changes, and incrementing `X` will introduce major breaking changes.

For any breaking change, JAX currently follows a 3 month deprecation policy.
When an incompatible change is made to an API, we will make our best effort
to obey the following procedure:
* the change will be announced in `CHANGELOG.md` and in the doc string for the
  deprecated API, and the old API will issue a `DeprecationWarning`.
* three months after the `jax` release that deprecated an API, we may remove the

@ -47,16 +54,44 @@ prefixed with underscores, although we do not entirely comply with this yet.

## What is not covered?

* anything prefixed with an underscore.
* `jax._src`
* `jax.core`
* `jax.lib`
* `jax.interpreters`
### Explicitly private APIs
Any API or import path prefixed with an underscore is explicitly private,
and may change without warning between JAX releases. We are working to move
all private APIs into `jax._src` to make these expectations more clear.

### Legacy internal APIs
In addition, there are several legacy modules that currently expose some
private APIs without an underscore, including:

- `jax.core`
- `jax.interpreters`
- `jax.lib`
- `jax.util`

We are actively working on deprecating these modules and the APIs they contain.
In most cases, such deprecations will follow the 3 month deprecation period,
but this may not always be possible. If you use any such APIs, please expect
them to be deprecated soon, and seek alternatives.

### Experimental and example libraries
The following modules include code for experimental or demonstration purposes,
and their APIs may change between releases without warning:

* `jax.experimental`
* `jax.example_libraries`
* `jax.extend` (see [details](https://jax.readthedocs.io/en/latest/jax.extend.html))

This list is not exhaustive.
We understand that some users depend on `jax.experimental`, and so in most cases
we follow the 3 month deprecation period for changes, but this may not always be
possible.

### JAX extend
The {mod}`jax.extend` module includes semi-public JAX internal APIs that are
meant for use by downstream projects, but do not have the same stability
guarantees as the main JAX package. If you have code that uses `jax.extend`,
we would strongly recommend CI tests against JAX's nightly releases, so as to
catch potential changes before they are released.

For details on `jax.extend`, see the [`jax.extend` module documentation](https://jax.readthedocs.io/en/latest/jax.extend.html), or the design document, {ref}`jax-extend-jep`.

## Numerics and randomness
@ -1,3 +1,4 @@
+(jep-effver)=
 # JEP 25516: Effort-based versioning for JAX

 This document proposes that the JAX core library should explicitly adopt
@ -93,6 +93,12 @@ plugins" error described {ref}`below <multiple_installs>`. See
<https://www.tensorflow.org/guide/profiler> for more information on installing
TensorBoard.

The nightly version of the TensorBoard profiler requires nightly TensorFlow
and TensorBoard:
```shell
pip install tf-nightly tb-nightly tbp-nightly
```

### Programmatic capture

You can instrument your code to capture a profiler trace via the
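The programmatic-capture sentence above is cut off by the hunk boundary; as one illustration of the public API it is describing, the `jax.profiler.trace` context manager writes a trace that the TensorBoard profiler plugin can display (the output directory below is arbitrary):

```python
import jax
import jax.numpy as jnp

# Capture a profiler trace for a small computation; point TensorBoard at the
# same directory to view it.
with jax.profiler.trace("/tmp/jax-trace"):
  x = jnp.ones((1024, 1024))
  (x @ x).block_until_ready()
```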
@ -490,7 +490,7 @@ This section covers some of the most common patterns with JAX pytrees.

 ### Transposing pytrees with `jax.tree.map` and `jax.tree.transpose`

-To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: {func} `jax.tree.map` (more basic) and {func}`jax.tree.transpose` (more flexible, complex and verbose).
+To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: {func}`jax.tree.map` (more basic) and {func}`jax.tree.transpose` (more flexible, complex and verbose).

 **Option 1:** Use {func}`jax.tree.map`. Here's an example:
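The document's own example sits outside this hunk; a minimal sketch of the `jax.tree.map` transposition pattern it refers to (the data here is illustrative):

```python
import jax

# A list of identically structured trees...
episode_steps = [
    {"obs": 1, "reward": 0.0},
    {"obs": 2, "reward": 1.0},
]

# ...becomes a single tree of lists by zipping leaves across the outer list.
transposed = jax.tree.map(lambda *leaves: list(leaves), *episode_steps)
print(transposed)  # {'obs': [1, 2], 'reward': [0.0, 1.0]}
```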
@ -56,14 +56,7 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
|
||||
dtypes.check_valid_dtype(dtype)
|
||||
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
|
||||
|
||||
def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray:
|
||||
dtype = x.dtype
|
||||
dtypes.check_valid_dtype(dtype)
|
||||
return ShapedArray(x.shape,
|
||||
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
|
||||
|
||||
core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
|
||||
core.shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify
|
||||
|
||||
|
||||
def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
|
||||
@ -71,15 +64,8 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
|
||||
dtypes.check_valid_dtype(dtype)
|
||||
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
|
||||
|
||||
def _np_scalar_abstractify(x: np.generic) -> ShapedArray:
|
||||
dtype = np.dtype(x)
|
||||
dtypes.check_valid_dtype(dtype)
|
||||
return ShapedArray(np.shape(x),
|
||||
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
|
||||
|
||||
for t in numpy_scalar_types:
|
||||
core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
|
||||
core.shaped_abstractify_handlers[t] = _np_scalar_abstractify
|
||||
|
||||
core.literalable_types.update(array_types)
|
||||
|
||||
@ -90,13 +76,7 @@ def _make_abstract_python_scalar(typ, val):
|
||||
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
|
||||
weak_type=typ is not bool)
|
||||
|
||||
def _python_scalar_abstractify(x: int | float | complex | bool) -> ShapedArray:
|
||||
typ = type(x)
|
||||
dtype = dtypes._scalar_type_to_dtype(typ, x)
|
||||
return ShapedArray((), dtype, weak_type=typ in dtypes._weak_types)
|
||||
|
||||
for t in dtypes.python_scalar_dtypes:
|
||||
core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
|
||||
core.shaped_abstractify_handlers[t] = _python_scalar_abstractify
|
||||
|
||||
core.literalable_types.update(dtypes.python_scalar_dtypes.keys())
|
||||
|
@ -583,9 +583,6 @@ def _dtype(x):
|
||||
except ValueError:
|
||||
return dtypes.result_type(getattr(x, 'dtype'))
|
||||
|
||||
# TODO(jakevdp): fix downstream consumers and remove these aliases.
|
||||
shaped_abstractify = core.shaped_abstractify
|
||||
_shaped_abstractify_handlers = core.shaped_abstractify_handlers
|
||||
|
||||
# This decorator exists to make it easier to monkey-patch APIs in JAX.
|
||||
# By default it does nothing, but it can be monkey-patched to do other things.
|
||||
|
@ -1035,7 +1035,6 @@ def _get_aval_array(self):
|
||||
else:
|
||||
return self.aval
|
||||
|
||||
core.shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
|
||||
core.pytype_aval_mappings[ArrayImpl] = _get_aval_array
|
||||
|
||||
# TODO(jakevdp) replace this with true inheritance at the C++ level.
|
||||
|
@ -192,6 +192,14 @@ def get_compile_options(
|
||||
build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
|
||||
build_options.memory_fitting_effort = config.memory_fitting_effort.value
|
||||
|
||||
# This is a temporary workaround to simplify the AutoPGLE usage.
|
||||
# TODO(b/376647494): Remove once the bug is fixed.
|
||||
if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
|
||||
logger.debug("Explicitly disabling command buffer scheduling for AutoPGLE.")
|
||||
if env_options_overrides is None:
|
||||
env_options_overrides = {}
|
||||
env_options_overrides['xla_gpu_enable_command_buffer'] = ''
|
||||
|
||||
if env_options_overrides is not None:
|
||||
# Some overrides are passed directly on build_options.
|
||||
overrides_on_build_options = [
|
||||
|
@ -1658,7 +1658,7 @@ array_garbage_collection_guard = optional_enum_state(
|
||||
' do not log garbage collection of "jax.Array" objects.\n * "log":'
|
||||
' log an error when a "jax.Array" is garbage collected.\n * "fatal":'
|
||||
' fatal error if a "jax.Array" is garbage collected.\nDefault is'
|
||||
' "allow".'
|
||||
' "allow". Note that not all cycles may be detected.'
|
||||
),
|
||||
update_global_hook=lambda val: _update_garbage_collection_guard(
|
||||
guard_lib.global_state(), 'garbage_collect_array', val
|
||||
|
@ -656,6 +656,13 @@ def check_bool_conversion(arr: Array):
|
||||
" is ambiguous. Use a.any() or a.all()")
|
||||
|
||||
|
||||
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
|
||||
|
||||
def _str_abstractify(x):
|
||||
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
|
||||
pytype_aval_mappings[str] = _str_abstractify
|
||||
|
||||
|
||||
def _aval_property(name):
|
||||
return property(lambda self: getattr(self.aval, name))
|
||||
|
||||
@ -918,6 +925,8 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
|
||||
aval_property = namedtuple("aval_property", ["fget"])
|
||||
aval_method = namedtuple("aval_method", ["fun"])
|
||||
|
||||
pytype_aval_mappings[Tracer] = lambda x: x.aval
|
||||
|
||||
def check_eval_args(args):
|
||||
for arg in args:
|
||||
if isinstance(arg, Tracer):
|
||||
@ -1400,45 +1409,51 @@ def check_valid_jaxtype(x):
|
||||
f"Value {x!r} of type {type(x)} is not a valid JAX type")
|
||||
|
||||
|
||||
def _shaped_abstractify_slow(x):
|
||||
try:
|
||||
return x if isinstance(x, AbstractValue) else get_aval(x)
|
||||
except TypeError:
|
||||
pass
|
||||
# We have three flavors of abstractification APIs here which each used to have
|
||||
# their own separate implementation. Now they're effectively the same, with the
|
||||
# following differences:
|
||||
#
|
||||
# - abstractify returns avals for non-traced array-like objects.
|
||||
# - get_aval is like abstractify, but also accepts tracers.
|
||||
# - shaped_abstractify is like get_aval, but also accepts duck-typed arrays.
|
||||
#
|
||||
# TODO(jakevdp): can these be unified further?
|
||||
|
||||
weak_type = getattr(x, 'weak_type', False)
|
||||
if hasattr(x, 'dtype'):
|
||||
dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Cannot interpret value of type {type(x)} as an abstract array; it "
|
||||
"does not have a dtype attribute")
|
||||
return ShapedArray(np.shape(x), dtype, weak_type=weak_type)
|
||||
|
||||
# TODO(jakevdp): deduplicate this with abstractify
|
||||
def shaped_abstractify(x):
|
||||
# This was originally api_util.shaped_abstractify; temporarily moved
|
||||
# here in order to facilitate combining it with abstractify.
|
||||
handler = shaped_abstractify_handlers.get(type(x), None)
|
||||
return handler(x) if handler is not None else _shaped_abstractify_slow(x)
|
||||
typ = type(x)
|
||||
if (aval_fn := pytype_aval_mappings.get(typ)): # fast path
|
||||
return aval_fn(x)
|
||||
for t in typ.__mro__[1:]:
|
||||
if (aval_fn := pytype_aval_mappings.get(t)):
|
||||
return aval_fn(x)
|
||||
if isinstance(x, AbstractValue):
|
||||
return x
|
||||
if hasattr(x, '__jax_array__'):
|
||||
return shaped_abstractify(x.__jax_array__())
|
||||
if hasattr(x, 'dtype'):
|
||||
return ShapedArray(np.shape(x), x.dtype, weak_type=getattr(x, 'weak_type', False))
|
||||
raise TypeError(
|
||||
f"Cannot interpret value of type {typ} as an abstract array; it "
|
||||
"does not have a dtype attribute")
|
||||
|
||||
|
||||
def abstractify(x):
|
||||
for typ in type(x).__mro__:
|
||||
aval_fn = pytype_aval_mappings.get(typ)
|
||||
if aval_fn: return aval_fn(x)
|
||||
if hasattr(x, '__jax_array__'):
|
||||
return abstractify(x.__jax_array__())
|
||||
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
|
||||
if isinstance(x, Tracer):
|
||||
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
|
||||
return get_aval(x)
|
||||
|
||||
|
||||
def get_aval(x):
|
||||
if isinstance(x, Tracer):
|
||||
return x.aval
|
||||
else:
|
||||
return abstractify(x)
|
||||
typ = type(x)
|
||||
if (aval_fn := pytype_aval_mappings.get(typ)): # fast path
|
||||
return aval_fn(x)
|
||||
for t in typ.__mro__[1:]:
|
||||
if (aval_fn := pytype_aval_mappings.get(t)):
|
||||
return aval_fn(x)
|
||||
if hasattr(x, '__jax_array__'):
|
||||
return get_aval(x.__jax_array__())
|
||||
raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")
|
||||
|
||||
get_type = get_aval
|
||||
|
||||
def is_concrete(x):
|
||||
return to_concrete_value(x) is not None
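For orientation on the `get_aval` / `shaped_abstractify` unification in the hunk above, a hedged sketch of the registry-based lookup for a few common inputs. It relies on the internal `jax._src.core` module, so the import path is an assumption and may change:

```python
import numpy as np
from jax._src import core  # internal module; not a stable import path

print(core.get_aval(np.arange(4)))     # ShapedArray via the np.ndarray handler
print(core.get_aval(np.float32(1.0)))  # ShapedArray via the numpy-scalar handler
print(core.get_aval(3.0))              # weak-typed scalar aval via the Python-scalar handler
```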
|
||||
@ -1831,12 +1846,6 @@ class DShapedArray(UnshapedArray):
|
||||
return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
|
||||
self.weak_type)
|
||||
|
||||
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
|
||||
shaped_abstractify_handlers: dict[Any, Callable[[Any], ShapedArray]] = {}
|
||||
|
||||
def _str_abstractify(x):
|
||||
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
|
||||
shaped_abstractify_handlers[str] = _str_abstractify
|
||||
|
||||
class DArray:
|
||||
_aval: DShapedArray
|
||||
@ -1889,9 +1898,11 @@ class DArray:
|
||||
data = self._data[slices]
|
||||
return data
|
||||
|
||||
def _darray_aval(x):
|
||||
return DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)
|
||||
|
||||
pytype_aval_mappings[DArray] = _darray_aval
|
||||
|
||||
pytype_aval_mappings[DArray] = \
|
||||
lambda x: DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class bint(dtypes.ExtendedDType):
|
||||
@ -1920,8 +1931,8 @@ class MutableArray:
|
||||
aval = property(lambda self: self._aval)
|
||||
shape = property(lambda self: self._aval.shape)
|
||||
dtype = property(lambda self: self._aval.dtype)
|
||||
def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
|
||||
def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
|
||||
def __getitem__(self, idx): return self._aval._getitem(self, idx)
|
||||
def __setitem__(self, idx, x): return self._aval._setitem(self, idx, x)
|
||||
def __repr__(self) -> str: return 'Mutable' + repr(self[...])
|
||||
pytype_aval_mappings[MutableArray] = lambda x: x._aval
|
||||
|
||||
|
@ -348,11 +348,20 @@ def check_is_flash_attention(
|
||||
)
|
||||
else:
|
||||
# Regular attention conditions
|
||||
if not ((H <= 128 and H % 8 == 0) and
|
||||
(not is_training or not has_bias or T % 2 == 0 and S % 2 == 0)):
|
||||
raise NotImplementedError(
|
||||
f"Unsupported sequence length Q {T}, KV {S} and head dim {H}."
|
||||
)
|
||||
# Check the head dim.
|
||||
is_on_hopper = check_compute_capability("9.0")
|
||||
H_max = 256 if cudnn_version >= 90500 and is_on_hopper else 128
|
||||
if not (H <= H_max and H % 8 == 0):
|
||||
raise NotImplementedError(
|
||||
f"The head dim must be <= {H_max} and a mutiple of 8, "
|
||||
f"but got {H}."
|
||||
)
|
||||
|
||||
# Check patterns with bias, seqlen should be divisible by 2
|
||||
if (is_training and has_bias and (T % 2 != 0 or S % 2 != 0)):
|
||||
raise NotImplementedError(
|
||||
f"Unsupported sequence length Q {T}, KV {S}."
|
||||
)
|
||||
|
||||
def check_cudnn_version():
|
||||
# check if cuDNN is installed
|
||||
|
@ -30,8 +30,10 @@ from jax._src import traceback_util
|
||||
from jax._src.ad_util import (
|
||||
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
|
||||
from jax._src.api_util import (
|
||||
argnums_partial, flatten_fun_nokwargs, resolve_kwargs)
|
||||
argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature,
|
||||
_arg_names)
|
||||
from jax._src.errors import UnexpectedTracerError
|
||||
from jax._src.state.types import AbstractRef
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
@ -41,8 +43,8 @@ from jax._src.interpreters.batching import not_mapped
|
||||
from jax._src.lax import lax
|
||||
from jax._src.tree_util import (
|
||||
tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple,
|
||||
register_pytree_node_class, tree_leaves, tree_flatten_with_path, keystr,
|
||||
treedef_children)
|
||||
register_pytree_node_class, tree_leaves, tree_flatten_with_path,
|
||||
tree_leaves_with_path, keystr, treedef_children)
|
||||
from jax._src.util import (cache, safe_zip, safe_map, split_list, Unhashable,
|
||||
unzip2)
|
||||
|
||||
@ -608,9 +610,12 @@ class custom_vjp(Generic[ReturnValue]):
|
||||
fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd)
|
||||
args_flat, in_tree = tree_flatten(dyn_args)
|
||||
in_avals = [core.get_aval(x) for x in args_flat]
|
||||
if config.mutable_array_checks.value:
|
||||
f_ = _check_primal_refs(f_, self.nondiff_argnums)
|
||||
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
|
||||
flat_fwd, out_trees = _flatten_fwd(fwd_, self.symbolic_zeros, primal_name,
|
||||
fwd_name, in_tree, out_type)
|
||||
flat_fwd, out_trees = _flatten_fwd(
|
||||
fwd_, self.nondiff_argnums, self.symbolic_zeros, primal_name,
|
||||
fwd_name, in_tree, out_type)
|
||||
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
|
||||
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
|
||||
*args_flat, out_trees=out_trees,
|
||||
@ -618,6 +623,37 @@ class custom_vjp(Generic[ReturnValue]):
|
||||
_, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees)
|
||||
return tree_unflatten(out_tree, out_flat)
|
||||
|
||||
@lu.transformation2
|
||||
def _check_primal_refs(f, nondiff_argnums, *args):
|
||||
_check_for_aliased_refs(f, nondiff_argnums, args)
|
||||
out = f(*args)
|
||||
_check_for_returned_refs(f, out, 'primal')
|
||||
return out
|
||||
|
||||
def _check_for_aliased_refs(f, nondiff_argnums, args):
|
||||
leaves = tree_leaves(args)
|
||||
refs: dict[int, int] = {}
|
||||
for i, x in enumerate(leaves):
|
||||
if (isinstance((a := core.get_aval(x)), AbstractRef) and
|
||||
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
|
||||
arg_names = _arg_names(fun_signature(f), args, {}, nondiff_argnums, ())
|
||||
if arg_names is None:
|
||||
arg_names = [f'flat index {j}' for j in range(len(leaves))]
|
||||
raise ValueError(
|
||||
"only one reference to a mutable array may be passed as an argument "
|
||||
f"to a function, but custom_vjp function {f} got the same mutable "
|
||||
f"array reference of type {a.str_short()} at {arg_names[dup_idx]} and"
|
||||
f" {arg_names[i]}.")
|
||||
|
||||
def _check_for_returned_refs(f, out, kind):
|
||||
leaves = tree_leaves_with_path(out)
|
||||
for path, leaf in leaves:
|
||||
if isinstance((a := core.get_aval(leaf)), AbstractRef):
|
||||
loc = f' at output tree path {keystr(path)}' if path else ''
|
||||
raise ValueError(f"custom_vjp {kind} function {f} returned a mutable "
|
||||
f"a array reference of type {a.str_short()}{loc}, "
|
||||
"but mutable array references cannot be returned.")
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CustomVJPPrimal:
|
||||
"""Primal to a ``custom_vjp``'s forward rule when ``symbolic_zeros`` is set"""
|
||||
@ -655,14 +691,18 @@ def _check_for_tracers(x):
|
||||
raise UnexpectedTracerError(msg)
|
||||
|
||||
@partial(lu.transformation_with_aux2, use_eq_store=True)
|
||||
def _flatten_fwd(f, store, symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
|
||||
*args):
|
||||
def _flatten_fwd(f, store, nondiff_argnums, symbolic_zeros, primal_name,
|
||||
fwd_name, in_tree, maybe_out_type, *args):
|
||||
if symbolic_zeros:
|
||||
args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])]
|
||||
else:
|
||||
args = args[::2]
|
||||
py_args = tree_unflatten(in_tree, args)
|
||||
if config.mutable_array_checks.value:
|
||||
_check_for_aliased_refs(f, nondiff_argnums, py_args)
|
||||
pair_out = f(*py_args)
|
||||
if config.mutable_array_checks.value:
|
||||
_check_for_returned_refs(f, pair_out, 'fwd')
|
||||
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
|
||||
msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
|
||||
"must produce a pair (list or tuple of length two) where the first "
|
||||
@ -1393,8 +1433,8 @@ def optimize_remat_of_custom_vjp_fwd(
|
||||
fwd_ = lu.wrap_init(fwd)
|
||||
args_flat, in_tree = tree_flatten(dyn_args)
|
||||
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
|
||||
flat_fwd, out_trees = _flatten_fwd(fwd_, False, primal_name, fwd_name,
|
||||
in_tree, out_type)
|
||||
flat_fwd, out_trees = _flatten_fwd(fwd_, nondiff_argnums, False,
|
||||
primal_name, fwd_name, in_tree, out_type)
|
||||
flat_fwd = _fix_fwd_args(flat_fwd)
|
||||
|
||||
in_avals = [core.get_aval(x) for x in args_flat]
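The changes above thread `nondiff_argnums` and the mutable-array checks through `custom_vjp`'s forward-rule flattening. For context, a minimal, standard public-API `custom_vjp` definition, whose `fwd`/`bwd` pair is the kind of input `_flatten_fwd` and `_flatten_bwd` process:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
  return jnp.sin(x)

def f_fwd(x):
  # Returns (primal output, residuals saved for the backward pass).
  return jnp.sin(x), jnp.cos(x)

def f_bwd(residual, cotangent):
  # Returns a tuple of cotangents, one per primal argument.
  return (residual * cotangent,)

f.defvjp(f_fwd, f_bwd)
print(jax.grad(f)(0.5))  # matches jnp.cos(0.5)
```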
|
||||
|
@ -300,8 +300,9 @@ class _DebugPrintFormatChecker(string.Formatter):
|
||||
|
||||
formatter = _DebugPrintFormatChecker()
|
||||
|
||||
def _format_print_callback(fmt: str, *args, **kwargs):
|
||||
sys.stdout.write(fmt.format(*args, **kwargs) + "\n")
|
||||
def _format_print_callback(fmt: str, np_printoptions, *args, **kwargs):
|
||||
with np.printoptions(**np_printoptions):
|
||||
sys.stdout.write(fmt.format(*args, **kwargs) + "\n")
|
||||
|
||||
def debug_print(fmt: str, *args, ordered: bool = False, **kwargs) -> None:
|
||||
"""Prints values and works in staged out JAX functions.
|
||||
@ -338,8 +339,8 @@ def debug_print(fmt: str, *args, ordered: bool = False, **kwargs) -> None:
|
||||
# Check that we provide the correct arguments to be formatted.
|
||||
formatter.format(fmt, *args, **kwargs)
|
||||
|
||||
debug_callback(functools.partial(_format_print_callback, fmt), *args,
|
||||
**kwargs, ordered=ordered)
|
||||
debug_callback(functools.partial(_format_print_callback, fmt, np.get_printoptions()),
|
||||
*args, **kwargs, ordered=ordered)
|
||||
|
||||
|
||||
# Sharding visualization
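With the change above, `jax.debug.print` forwards the NumPy print options captured when the callback is set up to the host-side formatter. A small sketch of the user-facing effect (exact output depends on your NumPy settings):

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("x = {}", x)  # formatted on the host via a callback
  return x * 2

# Print options in effect while the callback is staged now shape the output.
with np.printoptions(precision=2):
  f(jnp.linspace(0.0, 1.0, 5)).block_until_ready()
```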
|
||||
|
@ -81,6 +81,15 @@ class State:
|
||||
raise ValueError('Number of processes must be defined.')
|
||||
if process_id is None:
|
||||
raise ValueError('The process id of the current process must be defined.')
|
||||
if not isinstance(process_id, int):
|
||||
raise TypeError("process_id must be a nonnegative int. "
|
||||
f"Got process_id={process_id} of type {type(process_id)}.")
|
||||
if not isinstance(num_processes, int):
|
||||
raise TypeError("num_processes must be a positive int. "
|
||||
f"Got num_processes={num_processes} of type {type(num_processes)}.")
|
||||
if not (0 <= process_id < num_processes):
|
||||
raise ValueError("process_id and num_processes must be nonnegative, with process_id < num_processes. "
|
||||
f"Got process_id={process_id}, num_processes={num_processes}.")
|
||||
|
||||
self.coordinator_address = coordinator_address
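The new checks above reject non-integer or out-of-range `process_id` / `num_processes` values. These arrive via the public `jax.distributed.initialize` call; the sketch below is illustrative only (the address is made up, and it needs a real multi-process setup to actually run):

```python
import jax

jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",  # hypothetical coordinator address
    num_processes=2,                       # must be a positive int
    process_id=0,                          # must satisfy 0 <= process_id < num_processes
)
```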
|
||||
|
||||
|
@ -205,6 +205,21 @@ _default_types: dict[str, type[Any]] = {
|
||||
'c': complex_,
|
||||
}
|
||||
|
||||
def bit_width(dtype: DTypeLike) -> int:
|
||||
"""Number of bits per element for the dtype."""
|
||||
# Note: we cannot use dtype.itemsize here because this is
|
||||
# incorrect for sub-byte integer types.
|
||||
if dtype == np.dtype(bool):
|
||||
return 8 # physical bit layout for boolean dtype
|
||||
elif issubdtype(dtype, np.integer):
|
||||
return iinfo(dtype).bits
|
||||
elif issubdtype(dtype, np.floating):
|
||||
return finfo(dtype).bits
|
||||
elif issubdtype(dtype, np.complexfloating):
|
||||
return 2 * finfo(dtype).bits
|
||||
else:
|
||||
raise ValueError(f"unexpected input: {dtype=}")
|
||||
|
||||
# Trivial vectorspace datatype needed for tangent values of int/bool primals
|
||||
float0: np.dtype = np.dtype([('float0', np.void, 0)])
|
||||
|
||||
|
@ -115,7 +115,6 @@ def _earray_shard_arg_handler(xs, shardings, layouts, copy_semantics):
|
||||
return pxla.shard_args(phys_shardings, layouts, copy_semantics, arrs)
|
||||
pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler
|
||||
|
||||
core.shaped_abstractify_handlers[EArray] = lambda self: self.aval
|
||||
core.pytype_aval_mappings[EArray] = lambda x: x.aval
|
||||
xla.canonicalize_dtype_handlers[EArray] = lambda x: x
|
||||
tree_util.dispatch_registry.register_node(
|
||||
|
@ -20,7 +20,7 @@
|
||||
// 3. Add back the licence comment at the start
|
||||
//
|
||||
|
||||
namespace jax.export.serialization;
|
||||
namespace jax_export.serialization;
|
||||
|
||||
enum PyTreeDefKind: byte {
|
||||
leaf = 0,
|
||||
|
@ -12,13 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src import dtypes
|
||||
from jax._src.tree_util import tree_flatten, tree_unflatten
|
||||
from jax._src.util import safe_zip, unzip2, HashablePartial
|
||||
@ -83,7 +82,8 @@ def _unravel_list(indices, shapes, from_dtypes, to_dtype, arr):
|
||||
raise TypeError(f"unravel function given array of dtype {arr_dtype}, "
|
||||
f"but expected dtype {to_dtype}")
|
||||
chunks = jnp.split(arr, indices[:-1])
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore") # ignore complex-to-real cast warning
|
||||
return [lax.convert_element_type(chunk.reshape(shape), dtype)
|
||||
for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)]
|
||||
return [
|
||||
lax_internal._convert_element_type(chunk.reshape(shape), dtype,
|
||||
warn_on_complex_to_real_cast=False)
|
||||
for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)
|
||||
]
|
||||
|
@ -539,6 +539,9 @@ class JVPTracer(Tracer):
|
||||
def to_concrete_value(self):
|
||||
return core.to_concrete_value(self.primal)
|
||||
|
||||
def get_referent(self):
|
||||
return core.get_referent(self.primal)
|
||||
|
||||
def _primal_tangent_shapes_match(primal, tangent):
|
||||
if type(tangent) is not Zero:
|
||||
primal_aval = get_aval(primal).strip_weak_type()
|
||||
|
@ -1569,10 +1569,7 @@ class DynamicJaxprTracer(core.Tracer):
|
||||
val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self)))
|
||||
return self if val is None else get_referent(val)
|
||||
|
||||
|
||||
def _dynamic_jaxpr_tracer_shaped_abstractify(x):
|
||||
return x.aval
|
||||
core.shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify
|
||||
core.pytype_aval_mappings[DynamicJaxprTracer] = lambda x: x.aval
|
||||
|
||||
def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
|
||||
sentinel = object()
|
||||
@ -2010,8 +2007,8 @@ class DynamicJaxprTrace(core.Trace):
|
||||
def fwd_jaxpr_from_zeros(*zeros):
|
||||
for store in fwd.stores: store and store.reset()
|
||||
fwd_ = _interleave_fun(fwd, zeros)
|
||||
jaxpr, _, consts, atr = trace_to_jaxpr_dynamic(fwd_, in_avals)
|
||||
if atr: raise NotImplementedError
|
||||
jaxpr, _, consts, attrs = trace_to_jaxpr_dynamic(fwd_, in_avals)
|
||||
if attrs: raise NotImplementedError
|
||||
return jaxpr, consts
|
||||
|
||||
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
|
||||
@ -2154,14 +2151,14 @@ def trace_to_jaxpr_dynamic(
|
||||
ans = fun.call_wrapped(*in_tracers)
|
||||
|
||||
out_tracers = map(trace.to_jaxpr_tracer, ans)
|
||||
_check_no_refs(debug_info, out_tracers)
|
||||
_check_no_returned_refs(debug_info, out_tracers)
|
||||
jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers)
|
||||
del trace, fun, in_tracers, out_tracers, ans
|
||||
|
||||
config.enable_checks.value and core.check_jaxpr(jaxpr)
|
||||
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked
|
||||
|
||||
def _check_no_refs(
|
||||
def _check_no_returned_refs(
|
||||
dbg: lu.TracingDebugInfo | None,
|
||||
out_tracers: Sequence[DynamicJaxprTracer]
|
||||
) -> None:
|
||||
|
@ -38,6 +38,7 @@ from jax._src.lax.control_flow.conditionals import (
|
||||
cond_p as cond_p,
|
||||
switch as switch,
|
||||
platform_dependent as platform_dependent,
|
||||
platform_index_p as platform_index_p,
|
||||
)
|
||||
from jax._src.lax.control_flow.solves import (
|
||||
custom_linear_solve as custom_linear_solve,
|
||||
|
@ -89,13 +89,6 @@ def _initial_style_jaxprs_with_common_consts(
|
||||
|
||||
jaxprs, all_consts, all_out_trees, all_attrs_tracked = zip(*jaxpr_data)
|
||||
all_const_avals = [map(core.get_aval, consts) for consts in all_consts]
|
||||
# If we get a `Ref` in the consts, we know it must come from an outer
|
||||
# `run_state`. We also know it shouldn't be boxed up in another tracer.
|
||||
# We assert that it is in fact a DynamicJaxprTracer
|
||||
for consts, consts_avals in zip(all_consts, all_const_avals):
|
||||
for c, aval in zip(consts, consts_avals):
|
||||
if isinstance(aval, state.AbstractRef):
|
||||
assert isinstance(c, pe.DynamicJaxprTracer)
|
||||
|
||||
# TODO(sharadmv,mattjj): we could dedup *all consts* instead of just the Refs.
|
||||
|
||||
|
@ -25,6 +25,8 @@ from typing import Any, TypeVar
|
||||
|
||||
from jax.tree_util import tree_flatten, tree_unflatten
|
||||
from jax._src import ad_util
|
||||
from jax._src.api_util import (
|
||||
_check_no_aliased_ref_args, _check_no_aliased_closed_over_refs)
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
@ -136,8 +138,14 @@ def switch(index, branches: Sequence[Callable], *operands,
|
||||
ops, ops_tree = tree_flatten(operands)
|
||||
ops_avals = tuple(map(core.get_aval, ops))
|
||||
|
||||
if config.mutable_array_checks.value:
|
||||
dbg = pe.debug_info(branches[0], ops_tree, None, False, 'switch')
|
||||
_check_no_aliased_ref_args(dbg, ops_avals, ops)
|
||||
|
||||
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
||||
branches, ops_tree, ops_avals, primitive_name='switch')
|
||||
if config.mutable_array_checks.value:
|
||||
_check_no_aliased_closed_over_refs(dbg, (*jaxprs[0].consts, *consts), ops)
|
||||
for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
|
||||
_check_tree_and_avals(f"branch 0 and {i + 1} outputs",
|
||||
out_trees[0], jaxprs[0].out_avals,
|
||||
@ -228,11 +236,14 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
||||
ops, ops_tree = tree_flatten(operands)
|
||||
ops_avals = tuple(map(core.get_aval, ops))
|
||||
|
||||
if config.mutable_array_checks.value:
|
||||
dbg = pe.debug_info(true_fun, ops_tree, None, False, 'cond')
|
||||
_check_no_aliased_ref_args(dbg, ops_avals, ops)
|
||||
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
||||
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
|
||||
if any(isinstance(op_aval, AbstractRef) for op_aval in ops_avals):
|
||||
raise ValueError("Cannot pass `Ref`s into `cond`.")
|
||||
true_jaxpr, false_jaxpr = jaxprs
|
||||
if config.mutable_array_checks.value:
|
||||
_check_no_aliased_closed_over_refs(dbg, (*true_jaxpr.consts, *consts), ops)
|
||||
|
||||
out_tree, false_out_tree = out_trees
|
||||
if any(isinstance(out_aval, AbstractRef) for out_aval in
|
||||
@ -934,6 +945,7 @@ def platform_dependent(*args: Any,
|
||||
platform_index = platform_index_p.bind(
|
||||
platforms=tuple(tuple(ps) for ps in platforms_lists),
|
||||
has_default=(default is not None))
|
||||
|
||||
if default is not None:
|
||||
branches = branches + (default,)
|
||||
# Use a switch, to get the proper transformation rules for free. Since
|
||||
@ -946,6 +958,8 @@ def platform_dependent(*args: Any,
|
||||
# recognized on the compilation platform. Detect eager mode and keep only the
|
||||
# needed branch.
|
||||
try:
|
||||
# Note/TODO(mvoz): This actually rarely seems to concretize - we could look into
|
||||
# core.ensure_compile_time_eval to get better single-branch selection.
|
||||
platform_index_concrete = core.concrete_or_error(operator.index, platform_index)
|
||||
except core.ConcretizationTypeError:
|
||||
return switch(platform_index, branches, *args)
|
||||
|
@ -546,7 +546,8 @@ def _convert_element_type(
|
||||
operand: ArrayLike,
|
||||
new_dtype: DTypeLike | dtypes.ExtendedDType | None = None,
|
||||
weak_type: bool = False,
|
||||
sharding: Sharding | None = None):
|
||||
sharding: Sharding | None = None,
|
||||
warn_on_complex_to_real_cast: bool = True):
|
||||
if hasattr(operand, '__jax_array__'):
|
||||
operand = operand.__jax_array__()
|
||||
|
||||
@ -585,7 +586,8 @@ def _convert_element_type(
|
||||
isinstance(operand, Array)):
|
||||
sharding = operand.sharding
|
||||
|
||||
if (dtypes.issubdtype(old_dtype, np.complexfloating) and
|
||||
if (warn_on_complex_to_real_cast and
|
||||
dtypes.issubdtype(old_dtype, np.complexfloating) and
|
||||
not dtypes.issubdtype(new_dtype, np.complexfloating)):
|
||||
msg = "Casting complex values to real discards the imaginary part"
|
||||
warnings.warn(msg, NumpyComplexWarning, stacklevel=2)
|
||||
@ -3197,12 +3199,15 @@ def _convert_elt_type_folding_rule(consts, eqn):
|
||||
# TODO(mattjj): allow constant-folding CPU-backed JAX arrays
|
||||
c, = consts
|
||||
o, = eqn.outvars
|
||||
new_dtype = eqn.params['new_dtype']
|
||||
if (type(c) in {np.ndarray, *dtypes.python_scalar_dtypes} and
|
||||
isinstance(o.aval, core.UnshapedArray) and not np.shape(c) and
|
||||
not dtypes.issubdtype(eqn.params['new_dtype'], dtypes.extended)):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', util.NumpyComplexWarning)
|
||||
out = np.array(c).astype(eqn.params['new_dtype'])
|
||||
not dtypes.issubdtype(new_dtype, dtypes.extended)):
|
||||
out = np.array(c)
|
||||
if (dtypes.issubdtype(out.dtype, np.complexfloating) and
|
||||
not dtypes.issubdtype(new_dtype, np.complexfloating)):
|
||||
out = out.real
|
||||
out = out.astype(new_dtype)
|
||||
if not o.aval.weak_type:
|
||||
return [out], None
|
||||
out = out.item()
|
||||
@ -3367,18 +3372,21 @@ def _bitcast_convert_type_shape_rule(operand, *, new_dtype):
|
||||
old_dtype = dtypes.canonicalize_dtype(operand.dtype)
|
||||
new_dtype = dtypes.canonicalize_dtype(new_dtype)
|
||||
|
||||
if old_dtype.itemsize == new_dtype.itemsize:
|
||||
old_nbits = dtypes.bit_width(old_dtype)
|
||||
new_nbits = dtypes.bit_width(new_dtype)
|
||||
|
||||
if old_nbits == new_nbits:
|
||||
return operand.shape
|
||||
elif old_dtype.itemsize > new_dtype.itemsize:
|
||||
return (*operand.shape, old_dtype.itemsize // new_dtype.itemsize)
|
||||
elif old_nbits > new_nbits:
|
||||
return (*operand.shape, old_nbits // new_nbits)
|
||||
else:
|
||||
dim_size = operand.shape[-1] if operand.shape else 1
|
||||
if dim_size * old_dtype.itemsize != new_dtype.itemsize:
|
||||
if dim_size * old_nbits != new_nbits:
|
||||
raise ValueError(
|
||||
f"Attempting to convert array of shape {operand.shape} "
|
||||
f"from {old_dtype} of size {old_dtype.itemsize} "
|
||||
f"to {new_dtype} of size {new_dtype.itemsize}, "
|
||||
f"but {dim_size} * {old_dtype.itemsize} != {new_dtype.itemsize}")
|
||||
f"from {old_dtype} of size {old_nbits} bits "
|
||||
f"to {new_dtype} of size {new_nbits}, bits "
|
||||
f"but {dim_size} * {old_nbits} != {new_nbits}")
|
||||
return operand.shape[:-1]
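The shape rule above now reasons in bit widths rather than byte `itemsize`. A small sketch of the user-visible `lax.bitcast_convert_type` behavior it implements:

```python
from jax import lax
import jax.numpy as jnp

x = jnp.arange(4, dtype=jnp.float32)
print(lax.bitcast_convert_type(x, jnp.int32).shape)  # (4,)   same bit width
print(lax.bitcast_convert_type(x, jnp.uint8).shape)  # (4, 4) narrower dtype adds a trailing axis
```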
|
||||
|
||||
def _bitcast_convert_type_dtype_rule(operand, *, new_dtype):
|
||||
|
@ -490,13 +490,45 @@ def ragged_all_to_all(operand, output, input_offsets, send_sizes, output_offsets
|
||||
# Index 'data' at 'offsets'[2], 'sizes'[2]'
|
||||
{m,n,o},{p,q,r},{s,t,u},{v,w,x}
|
||||
|
||||
|
||||
``output_offsets`` must be sharded in a way that each replica has offsets in
|
||||
the target replica output perspective.
|
||||
|
||||
For i-th output offset, the current replica will send
|
||||
`operand[input_offsets[i]:input_offsets[i]+input_sizes[i]]` update to `i`-th
|
||||
replica that will be written to
|
||||
`output_i[output_offsets[i]:output_offsets[i]+send_sizes[i]]` in `i`-th
|
||||
replica ``output``.
|
||||
|
||||
For example, if we have 2 replicas:
|
||||
|
||||
replica 0:
|
||||
operand: [1, 2, 2]
|
||||
output: [0, 0, 0, 0]
|
||||
input_offsets: [0, 1]
|
||||
send_sizes: [1, 2]
|
||||
output_offsets: [0, 0]
|
||||
recv_sizes: [1, 1]
|
||||
|
||||
replica 1:
|
||||
operand: [3, 4, 0]
|
||||
output: [0, 0, 0, 0]
|
||||
input_offsets: [0, 1]
|
||||
send_sizes: [1, 1]
|
||||
output_offsets: [1, 2]
|
||||
recv_sizes: [2, 1]
|
||||
|
||||
replica 0's result will be: [1, 3, 0, 0]
|
||||
replica 1's result will be: [2, 2, 4, 0]
|
||||
|
||||
Args:
|
||||
operand: array with ragged dimension along its outermost dimension.
|
||||
output: array of ragged input offsets.
|
||||
input_offsets: array of ragged input send sizes.
|
||||
send_sizes: array of ragged output data.
|
||||
output_offsets: array of ragged output offsets.
|
||||
output_offsets: array of ragged offsets in the target replica output.
|
||||
recv_sizes: array of ragged output receive sizes.
|
||||
|
||||
Returns:
|
||||
array with shape equal to ``output``.
|
||||
"""
|
||||
|
@ -173,8 +173,40 @@ lt = np.less
|
||||
def convert_element_type(operand, dtype):
|
||||
return np.asarray(operand, dtype=dtype)
|
||||
|
||||
def _bitcast_uint4_to_uint8(operand):
|
||||
# Note: assumes little-endian byte order.
|
||||
assert operand.dtype == 'uint4'
|
||||
operand = operand.astype('uint8')
|
||||
return operand[..., ::2] + (operand[..., 1::2] << 4)
|
||||
|
||||
def _bitcast_uint8_to_uint4(operand):
|
||||
# Note: assumes little-endian byte order.
|
||||
assert operand.dtype == 'uint8'
|
||||
result = np.zeros((*operand.shape[:-1], operand.shape[-1] * 2), dtype='uint4')
|
||||
result[..., ::2] = (operand & 0b00001111).astype('uint4')
|
||||
result[..., 1::2] = ((operand & 0b11110000) >> 4).astype('uint4')
|
||||
return result
|
||||
|
||||
def bitcast_convert_type(operand, dtype):
|
||||
return np.asarray(operand).view(dtype)
|
||||
operand = np.asarray(operand)
|
||||
nbits_in = dtypes.bit_width(operand.dtype)
|
||||
nbits_out = dtypes.bit_width(dtype)
|
||||
|
||||
if nbits_out > nbits_in:
|
||||
assert operand.shape[-1] == nbits_out // nbits_in
|
||||
out_shape = operand.shape[:-1]
|
||||
elif nbits_out == nbits_in:
|
||||
out_shape = operand.shape
|
||||
else:
|
||||
out_shape = (*operand.shape, nbits_in // nbits_out)
|
||||
|
||||
# Special handling for 4-bit integers.
|
||||
if nbits_in == 4:
|
||||
operand = _bitcast_uint4_to_uint8(operand.view('uint4'))
|
||||
if nbits_out == 4:
|
||||
operand = _bitcast_uint8_to_uint4(operand.view('uint8'))
|
||||
|
||||
return operand.view(dtype).reshape(out_shape)
|
||||
|
||||
def clamp(min, operand, max):
|
||||
return np.clip(operand, np.clip(min, None, max), max).astype(operand.dtype)
|
||||
|
@ -710,7 +710,6 @@ def one_hot(x: Any, num_classes: int, *,
|
||||
'jax-nn-one-hot-float-input',
|
||||
f"jax.nn.one_hot input should be integer-typed; got dtype={x_arr.dtype}",
|
||||
stacklevel=1)
|
||||
x_arr = x_arr.astype('int32')
|
||||
return _one_hot(x_arr, num_classes, dtype=dtype, axis=axis)
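The hunk above appears to drop the implicit `int32` cast of float inputs in `jax.nn.one_hot`, leaving only the deprecation-warning path shown. Basic integer-typed usage, for reference:

```python
import jax.numpy as jnp
from jax import nn

labels = jnp.array([0, 2, 1])  # integer-typed class indices
print(nn.one_hot(labels, num_classes=3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]
```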
|
||||
|
||||
|
||||
|
@ -509,12 +509,15 @@ def _view(self: Array, dtype: DTypeLike | None = None, type: None = None) -> Arr
|
||||
dtypes.check_user_dtype_supported(dtype, "view")
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
|
||||
nbits_in = dtypes.bit_width(self.dtype)
|
||||
nbits_out = dtypes.bit_width(dtype)
|
||||
|
||||
if self.ndim == 0:
|
||||
if self.dtype.itemsize != dtype.itemsize:
|
||||
if nbits_in != nbits_out:
|
||||
raise ValueError("view() of a 0d array is only supported if the itemsize is unchanged.")
|
||||
return _view(lax.expand_dims(self, (0,)), dtype).squeeze()
|
||||
|
||||
if (self.shape[-1] * self.dtype.itemsize) % dtype.itemsize != 0:
|
||||
if (self.shape[-1] * nbits_in) % nbits_out != 0:
|
||||
raise ValueError("When changing to a larger dtype, its size must be a divisor "
|
||||
"of the total size in bytes of the last axis of the array.")
|
||||
|
||||
@ -543,16 +546,15 @@ def _view(self: Array, dtype: DTypeLike | None = None, type: None = None) -> Arr
|
||||
|
||||
# lax.bitcast_convert_type adds or subtracts dimensions depending on the
|
||||
# relative bitwidths of the dtypes; we account for that with reshapes.
|
||||
if self.dtype.itemsize < dtype.itemsize:
|
||||
factor = dtype.itemsize // self.dtype.itemsize
|
||||
if nbits_in < nbits_out:
|
||||
factor = nbits_out // nbits_in
|
||||
out = self.reshape(*self.shape[:-1], self.shape[-1] // factor, factor)
|
||||
return lax.bitcast_convert_type(out, dtype)
|
||||
|
||||
if self.dtype.itemsize > dtype.itemsize:
|
||||
elif nbits_in > nbits_out:
|
||||
out = lax.bitcast_convert_type(self, dtype)
|
||||
return out.reshape(*out.shape[:-2], out.shape[-2] * out.shape[-1])
|
||||
|
||||
return lax.bitcast_convert_type(self, dtype)
|
||||
else:
|
||||
return lax.bitcast_convert_type(self, dtype)
|
||||
|
||||
|
||||
def _notimplemented_flat(self):
|
||||
|
@ -191,7 +191,7 @@ class _ScalarMeta(type):
|
||||
|
||||
def _abstractify_scalar_meta(x):
|
||||
raise TypeError(f"JAX scalar type {x} cannot be interpreted as a JAX array.")
|
||||
core.shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta
|
||||
core.pytype_aval_mappings[_ScalarMeta] = _abstractify_scalar_meta
|
||||
|
||||
def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
|
||||
meta = _ScalarMeta(np_scalar_type.__name__, (object,),
|
||||
@ -5731,10 +5731,9 @@ def astype(x: ArrayLike, dtype: DTypeLike | None,
|
||||
|
||||
# We offer a more specific warning than the usual ComplexWarning so we prefer
|
||||
# to issue our warning.
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", ComplexWarning)
|
||||
result = lax_internal._convert_element_type(
|
||||
x_arr, dtype, sharding=_normalize_to_sharding(device))
|
||||
result = lax_internal._convert_element_type(
|
||||
x_arr, dtype, sharding=_normalize_to_sharding(device),
|
||||
warn_on_complex_to_real_cast=False)
|
||||
return _array_copy(result) if copy else result
|
||||
|
||||
|
||||
@ -8341,18 +8340,41 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0,
|
||||
Array([4, 8], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("diagonal", a)
|
||||
a_shape = shape(a)
|
||||
|
||||
if ndim(a) < 2:
|
||||
raise ValueError("diagonal requires an array of at least two dimensions.")
|
||||
offset = core.concrete_or_error(operator.index, offset, "'offset' argument of jnp.diagonal()")
|
||||
|
||||
a = moveaxis(a, (axis1, axis2), (-2, -1))
|
||||
def _default_diag(a):
|
||||
a_shape = shape(a)
|
||||
|
||||
diag_size = max(0, min(a_shape[axis1] + min(offset, 0),
|
||||
a_shape[axis2] - max(offset, 0)))
|
||||
i = arange(diag_size)
|
||||
j = arange(abs(offset), abs(offset) + diag_size)
|
||||
return a[..., i, j] if offset >= 0 else a[..., j, i]
|
||||
a = moveaxis(a, (axis1, axis2), (-2, -1))
|
||||
|
||||
diag_size = max(
|
||||
0, min(a_shape[axis1] + min(offset, 0), a_shape[axis2] - max(offset, 0))
|
||||
)
|
||||
i = arange(diag_size)
|
||||
j = arange(abs(offset), abs(offset) + diag_size)
|
||||
return a[..., i, j] if offset >= 0 else a[..., j, i]
|
||||
|
||||
|
||||
# The mosaic lowering rule for diag is only defined for square arrays.
|
||||
# TODO(mvoz): Add support for offsets.
|
||||
if shape(a)[0] != shape(a)[1] or ndim(a) != 2 or offset != 0 or _dtype(a) == bool_:
|
||||
return _default_diag(a)
|
||||
else:
|
||||
a_shape_eye = eye(shape(a)[0], dtype=_dtype(a))
|
||||
|
||||
def _mosaic_diag(a):
|
||||
def _sum(x, axis):
|
||||
return lax.reduce(
|
||||
x,
|
||||
np.array(0, _dtype(x)),
|
||||
lax.add if _dtype(x) != bool_ else lax.bitwise_or,
|
||||
(axis,),
|
||||
)
|
||||
return _sum(lax.mul(a_shape_eye, a), axis=0)
|
||||
return lax.platform_dependent(a, default=_default_diag, mosaic=_mosaic_diag)
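The rewrite above routes `jnp.diagonal` through `lax.platform_dependent`, adding a Mosaic-friendly branch for square, zero-offset, non-bool 2-D inputs; the public behavior is intended to be unchanged. For reference:

```python
import jax.numpy as jnp

a = jnp.arange(9).reshape(3, 3)
print(jnp.diagonal(a))            # [0 4 8]
print(jnp.diagonal(a, offset=1))  # [1 5]
```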
|
||||
|
||||
|
||||
@export
|
||||
|
@ -924,7 +924,7 @@ def pinv(a: ArrayLike, rtol: ArrayLike | None = None,
|
||||
- :func:`jax.numpy.linalg.inv`: multiplicative inverse of a square matrix.
|
||||
|
||||
Notes:
|
||||
:func:`jax.numpy.linalg.prng` differs from :func:`numpy.linalg.prng` in the
|
||||
:func:`jax.numpy.linalg.pinv` differs from :func:`numpy.linalg.pinv` in the
|
||||
default value of ``rcond``: in NumPy, the default is `1e-15`. In JAX, the
|
||||
default is ``10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps``.
|
||||
|
||||
|
@ -20,7 +20,6 @@ from functools import partial
|
||||
import math
|
||||
import operator
|
||||
from typing import overload, Any, Literal, Protocol, Union
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -37,7 +36,7 @@ from jax._src.lax import lax as lax_internal
|
||||
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg
|
||||
from jax._src.util import (
|
||||
canonicalize_axis as _canonicalize_axis, maybe_named_axis,
|
||||
set_module, NumpyComplexWarning)
|
||||
set_module)
|
||||
|
||||
|
||||
export = set_module('jax.numpy')
|
||||
@ -100,7 +99,7 @@ ReductionOp = Callable[[Any, Any], Any]
|
||||
|
||||
def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike,
|
||||
*, has_identity: bool = True,
|
||||
preproc: Callable[[ArrayLike], ArrayLike] | None = None,
|
||||
preproc: Callable[[Array], Array] | None = None,
|
||||
bool_op: ReductionOp | None = None,
|
||||
upcast_f16_for_computation: bool = False,
|
||||
axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
|
||||
@ -201,16 +200,15 @@ def _reduction_init_val(a: ArrayLike, init_val: Any) -> np.ndarray:
|
||||
sign, info = np.sign(init_val), dtypes.iinfo(a_dtype)
|
||||
return np.array(info.min if sign < 0 else info.max, dtype=a_dtype)
|
||||
|
||||
def _cast_to_bool(operand: ArrayLike) -> Array:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=NumpyComplexWarning)
|
||||
return lax.convert_element_type(operand, np.bool_)
|
||||
def _cast_to_bool(operand: Array) -> Array:
|
||||
if dtypes.issubdtype(operand.dtype, np.complexfloating):
|
||||
operand = operand.real
|
||||
return lax.convert_element_type(operand, np.bool_)
|
||||
|
||||
def _cast_to_numeric(operand: ArrayLike) -> Array:
|
||||
def _cast_to_numeric(operand: Array) -> Array:
|
||||
return promote_dtypes_numeric(operand)[0]
|
||||
|
||||
def _require_integer(operand: ArrayLike) -> Array:
|
||||
arr = lax_internal.asarray(operand)
|
||||
def _require_integer(arr: Array) -> Array:
|
||||
if not dtypes.isdtype(arr, ("bool", "integral")):
|
||||
raise ValueError(f"integer argument required; got dtype={arr.dtype}")
|
||||
return arr
|
||||
|
@ -2765,6 +2765,12 @@ def log2(x: ArrayLike, /) -> Array:
|
||||
Array([-2., -1., 0., 1., 2., 3.], dtype=float32)
|
||||
"""
|
||||
x, = promote_args_inexact("log2", x)
|
||||
if dtypes.issubdtype(x.dtype, np.complexfloating):
|
||||
r = lax.log(x)
|
||||
re = lax.real(r)
|
||||
im = lax.imag(r)
|
||||
ln2 = lax.log(_constant_like(re, 2))
|
||||
return lax.complex(lax.div(re, ln2), lax.div(im, ln2))
|
||||
return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))
|
||||
|
||||
|
||||
@ -2789,6 +2795,12 @@ def log10(x: ArrayLike, /) -> Array:
|
||||
[-2. -1. 0. 1. 2. 3.]
|
||||
"""
|
||||
x, = promote_args_inexact("log10", x)
|
||||
if dtypes.issubdtype(x.dtype, np.complexfloating):
|
||||
r = lax.log(x)
|
||||
re = lax.real(r)
|
||||
im = lax.imag(r)
|
||||
ln10 = lax.log(_constant_like(re, 10))
|
||||
return lax.complex(lax.div(re, ln10), lax.div(im, ln10))
|
||||
return lax.div(lax.log(x), lax.log(_constant_like(x, 10)))
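The new complex branches above compute log2/log10 as the complex natural log divided by ln 2 or ln 10, applied to the real and imaginary parts separately. A quick numeric check:

```python
import jax.numpy as jnp

z = jnp.complex64(-8.0)
print(jnp.log2(z))   # about (3 + 4.53j): log|z|/ln 2 + i*arg(z)/ln 2
print(jnp.log10(z))  # about (0.90 + 1.36j)
```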
|
||||
|
||||
|
||||
|
@ -547,9 +547,13 @@ def lower_jaxpr_to_module(
|
||||
module_name = name_and_src_info.name
|
||||
attrs["sym_name"] = ir.StringAttr.get(module_name)
|
||||
sym_tab = ir.SymbolTable(m.operation)
|
||||
|
||||
func_op = lower_jaxpr_to_func(
|
||||
ctx, jaxpr, mosaic_grid_mapping=mosaic_grid_mapping,
|
||||
name="main", for_verification=for_verification,
|
||||
ctx,
|
||||
jaxpr,
|
||||
mosaic_grid_mapping=mosaic_grid_mapping,
|
||||
name="main",
|
||||
for_verification=for_verification,
|
||||
)
|
||||
m.body.append(func_op)
|
||||
sym_tab.insert(func_op)
|
||||
@ -568,6 +572,7 @@ def lower_jaxpr_to_module(
|
||||
# We checked above that the block does not require windowing.
|
||||
window_params.append(ir.DictAttr.get())
|
||||
continue
|
||||
|
||||
mlir_func = lower_jaxpr_to_transform_func(
|
||||
ctx,
|
||||
bm.index_map_jaxpr.jaxpr,
|
||||
@ -1990,6 +1995,36 @@ lowering_rules[ad_util.add_any_p] = _add_lowering_rule
|
||||
skip_mlir_conversions.add(ad_util.add_any_p)
|
||||
|
||||
|
||||
class FoldingError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _fold_and_get_constant_value(x):
|
||||
def _fold(x, fuel):
|
||||
if fuel <= 0:
|
||||
raise FoldingError("Folding depth exceeded")
|
||||
op_name = getattr(x.owner, "name", None)
|
||||
binop_folds = {
|
||||
"arith.maxsi": max,
|
||||
"arith.minsi": min,
|
||||
}
|
||||
if op_name == "arith.constant":
|
||||
if ir.IntegerType.isinstance(x.type):
|
||||
return ir.IntegerAttr(x.owner.attributes["value"]).value
|
||||
elif ir.FloatType.isinstance(x.type):
|
||||
return ir.FloatAttr(x.owner.attributes["value"]).value
|
||||
else:
|
||||
raise ValueError(f"Unsupported constant type: {x.type}")
|
||||
if op_name in binop_folds:
|
||||
return binop_folds[op_name](_fold(v, fuel - 1) for v in x.owner.operands)
|
||||
raise FoldingError(f"Folding not supported for {x.owner}")
|
||||
|
||||
try:
|
||||
return _fold(x, 10)
|
||||
except FoldingError:
|
||||
return None
|
||||
|
||||
|
||||
def _max_lowering_rule(ctx: LoweringRuleContext, x, y):
|
||||
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
||||
(aval_out,) = ctx.avals_out
|
||||
@ -2708,6 +2743,12 @@ lowering_rules[lax.while_p] = _while_lowering_rule
|
||||
|
||||
def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches):
|
||||
index, *args = args
|
||||
constant_index = _fold_and_get_constant_value(index)
|
||||
|
||||
if constant_index is not None:
|
||||
return jaxpr_subcomp(
|
||||
ctx.lowering_context.replace(block_shapes=ctx.block_shapes[1:]), branches[constant_index].jaxpr, *args
|
||||
)
|
||||
out_types = map(aval_to_ir_type, ctx.avals_out)
|
||||
pred = arith.cmpi(
|
||||
arith.CmpIPredicate.ne, index, ir_constant(0, index.type)
|
||||
@ -3375,3 +3416,25 @@ def _pad_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs):
|
||||
|
||||
|
||||
lowering_rules[lax.pad_p] = _pad_lowering_rule
|
||||
|
||||
|
||||
def _platform_index_lowering(
|
||||
ctx: mlir.LoweringRuleContext,
|
||||
*,
|
||||
platforms: Sequence[Sequence[str]],
|
||||
has_default: bool,
|
||||
):
|
||||
for i, ps in enumerate(platforms):
|
||||
# note - slightly odd structure here, as platforms is a seq[seq[str]]
|
||||
if "mosaic" in ps:
|
||||
return ir_constant(i)
|
||||
|
||||
if has_default:
|
||||
return ir_constant(len(platforms))
|
||||
|
||||
raise NotImplementedError(
|
||||
"No mosaic or default platform indexing rule found."
|
||||
)
|
||||
|
||||
|
||||
lowering_rules[jax._src.lax.control_flow.platform_index_p] = _platform_index_lowering
|
||||
|
@ -797,25 +797,30 @@ def lower_jaxpr_to_module(
|
||||
# Each range is 2 events, each event is 4 bytes.
|
||||
prof_spec = mgpu_profiler.ProfilerSpec(prof_space * 2 * 4)
|
||||
prof_ctx = ProfilerContext(params["profile_dir"], prof_spec)
|
||||
module, out_structs_gmem, _ = mgpu_core._lower_as_gpu_kernel(
|
||||
body,
|
||||
grid=parallel_grid,
|
||||
cluster=(),
|
||||
block=block,
|
||||
in_shapes=in_structs_gmem,
|
||||
out_shape=out_structs_gmem,
|
||||
smem_scratch_shape=(
|
||||
(*in_structs_smem, *out_structs_smem),
|
||||
*extra_smem_scratch,
|
||||
(
|
||||
mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps),
|
||||
rs.barriers,
|
||||
extra_barriers,
|
||||
module, out_structs_gmem, _, launch_ctx, scratch_arr = (
|
||||
mgpu_core._lower_as_gpu_kernel(
|
||||
body,
|
||||
grid=parallel_grid,
|
||||
cluster=(),
|
||||
block=block,
|
||||
in_shapes=in_structs_gmem,
|
||||
out_shape=out_structs_gmem,
|
||||
smem_scratch_shape=(
|
||||
(*in_structs_smem, *out_structs_smem),
|
||||
*extra_smem_scratch,
|
||||
(
|
||||
mgpu.Barrier(
|
||||
arrival_count=1, num_barriers=max_concurrent_steps
|
||||
),
|
||||
rs.barriers,
|
||||
extra_barriers,
|
||||
),
|
||||
),
|
||||
),
|
||||
module_name=name_and_src_info.name,
|
||||
prof_spec=prof_spec,
|
||||
module_name=name_and_src_info.name,
|
||||
prof_spec=prof_spec,
|
||||
)
|
||||
)
|
||||
mgpu_core._initialize_scratch(launch_ctx, scratch_arr)
|
||||
|
||||
return LoweringResult(
|
||||
module, parallel_grid, block, out_structs_gmem, prof_ctx
|
||||
@ -1572,8 +1577,11 @@ def _lower_while_via_fori(
|
||||
body_nconsts,
|
||||
):
|
||||
assert not fori_jaxpr.constvars
|
||||
# The pattern matcher looks for conditions with no constants.
|
||||
assert cond_nconsts == 0
|
||||
|
||||
# Reflect the changes of the pattern matcher to the context.
|
||||
lb_aval, ub_aval, *_ = ctx.avals_in[cond_nconsts + body_nconsts:]
|
||||
ctx = ctx.replace(
|
||||
avals_in=(
|
||||
*ctx.avals_in[cond_nconsts:body_nconsts],
|
||||
@ -1585,7 +1593,6 @@ def _lower_while_via_fori(
|
||||
_, consts, (lb, ub, *args) = util.split_list(
|
||||
args, [cond_nconsts, body_nconsts]
|
||||
)
|
||||
lb_aval, ub_aval, *_ = ctx.avals_in[body_nconsts:]
|
||||
lb = _ensure_ir_value(lb, lb_aval.dtype)
|
||||
ub = _ensure_ir_value(ub, ub_aval.dtype)
|
||||
for_out = _lower_jaxpr_to_for_loop(
|
||||
@ -1780,29 +1787,21 @@ def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray:
|
||||
if isinstance(x, mgpu.FragmentedArray):
|
||||
assert x.mlir_dtype == mgpu_utils.dtype_to_ir_type(dtype)
|
||||
return x
|
||||
elif isinstance(x, (np.number, np.ndarray, int, float)):
|
||||
return mgpu.FragmentedArray.splat(
|
||||
_ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype)),
|
||||
(),
|
||||
is_signed=mgpu_utils.is_signed(dtype),
|
||||
)
|
||||
elif isinstance(x, ir.Value):
|
||||
if isinstance(x.type, (ir.IntegerType, ir.FloatType, ir.IndexType)):
|
||||
assert x.type == mgpu_utils.dtype_to_ir_type(dtype)
|
||||
return mgpu.FragmentedArray.splat(x, (), is_signed=mgpu_utils.is_signed(dtype))
|
||||
raise NotImplementedError(f"Unsupported type: {type(x)}")
|
||||
return mgpu.FragmentedArray.splat(
|
||||
_ensure_ir_value(x, dtype), (), is_signed=mgpu_utils.is_signed(dtype)
|
||||
)
|
||||
|
||||
|
||||
def _ensure_ir_value(x: object, dtype: jnp.dtype) -> ir.Value:
|
||||
if isinstance(x, ir.Value):
|
||||
assert x.type == mgpu_utils.dtype_to_ir_type(dtype)
|
||||
return x
|
||||
elif isinstance(x, (np.number, np.ndarray, int, float)):
|
||||
return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype))
|
||||
elif isinstance(x, mgpu.FragmentedArray):
|
||||
assert x.mlir_dtype == mgpu_utils.dtype_to_ir_type(dtype)
|
||||
if isinstance(x.layout, mgpu.WGSplatFragLayout):
|
||||
return x.registers.item()
|
||||
raise NotImplementedError(f"Unsupported type: {type(x)}")
|
||||
raise NotImplementedError(f"Unsupported layout: {x.layout}")
|
||||
return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype))
|
||||
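The refactor above layers _ensure_fa on top of _ensure_ir_value: scalars are normalized once, and the array case is just a splat of the normalized scalar. Below is a self-contained NumPy analogue of that layering, where plain arrays stand in for MLIR values and fragmented arrays; it illustrates the structure only, not the real helpers.

import numpy as np

def ensure_scalar(x, dtype):
  # Mirrors _ensure_ir_value: accept an existing "value" or a Python/NumPy scalar.
  if isinstance(x, np.ndarray) and x.ndim == 0:
    assert x.dtype == dtype
    return x
  if isinstance(x, (int, float, np.number)):
    return np.asarray(x, dtype)
  raise NotImplementedError(f"Unsupported type: {type(x)}")

def ensure_array(x, dtype, shape):
  # Mirrors _ensure_fa: pass arrays through, otherwise splat the scalar form.
  if isinstance(x, np.ndarray) and x.ndim > 0:
    assert x.dtype == dtype
    return x
  return np.broadcast_to(ensure_scalar(x, dtype), shape)

print(ensure_array(2, np.dtype(np.float32), (4,)))  # [2. 2. 2. 2.]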
|
||||
|
||||
def _ir_constant(v: object, t: ir.Type) -> ir.Value:
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
import enum
|
||||
import math
|
||||
from typing import Any, Literal
|
||||
@ -25,7 +26,6 @@ from jax._src import core as jax_core
|
||||
from jax._src import state
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import arith as arith_dialect
|
||||
from jax._src.lib.mlir.dialects import llvm as llvm_dialect
|
||||
@ -33,23 +33,42 @@ from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas.mosaic_gpu import core as gpu_core
|
||||
from jax._src.pallas.mosaic_gpu import lowering
|
||||
from jax._src.pallas.mosaic_gpu.core import state_types
|
||||
from jax._src.state import discharge
|
||||
from jax._src.state import indexing
|
||||
from jax._src.state import primitives as state_primitives
|
||||
import jax.experimental.mosaic.gpu as mgpu
|
||||
from jax.experimental.mosaic import gpu as mgpu
|
||||
from jax.experimental.mosaic.gpu import utils as mgpu_utils
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
WARPGROUP_SIZE = 128
|
||||
|
||||
|
||||
_Ref = pallas_core.AbstractMemoryRef | state_types.TransformedRef
|
||||
|
||||
|
||||
def _check_ref(
    aval: object, name: str, memory_space: gpu_core.GPUMemorySpace
) -> None:
  if not isinstance(aval, state_types.AbstractRef):
    raise TypeError(f"{name} must be a reference, got {aval}")
  aval_memory_space = getattr(aval, "memory_space", None) or gpu_core.GMEM
  if aval_memory_space is not memory_space:
    raise ValueError(
        f"{name} must be a {memory_space.name.upper()} reference, got {aval}"
    )
|
||||
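A short, self-contained sketch of the contract _check_ref enforces during abstract evaluation, with toy stand-ins for the Pallas ref and memory-space types; MemSpace and Ref below are hypothetical classes, not the real ones.

import enum

class MemSpace(enum.Enum):
  GMEM = enum.auto()
  SMEM = enum.auto()

class Ref:
  def __init__(self, memory_space=None):
    self.memory_space = memory_space  # None means "unspecified", treated as GMEM

def check_ref(aval, name, memory_space):
  if not isinstance(aval, Ref):
    raise TypeError(f"{name} must be a reference, got {aval}")
  if (aval.memory_space or MemSpace.GMEM) is not memory_space:
    raise ValueError(f"{name} must be a {memory_space.name} reference")

check_ref(Ref(MemSpace.SMEM), "src", MemSpace.SMEM)  # ok
check_ref(Ref(), "dst", MemSpace.GMEM)               # ok: defaults to GMEM
try:
  check_ref(Ref(MemSpace.SMEM), "dst", MemSpace.GMEM)
except ValueError as e:
  print(e)  # dst must be a GMEM reference

Moving these checks into the abstract-eval rules means a mismatched memory space is reported as soon as the primitive is traced, not only in the Python wrappers.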
|
||||
|
||||
copy_smem_to_gmem_p = jax_core.Primitive("copy_smem_to_gmem")
|
||||
copy_smem_to_gmem_p.multiple_results = True
|
||||
|
||||
|
||||
@copy_smem_to_gmem_p.def_effectful_abstract_eval
|
||||
def _copy_smem_to_gmem_abstract_eval(*avals, **params):
|
||||
del avals, params # Unused.
|
||||
def _copy_smem_to_gmem_abstract_eval(src, dst, *args, **params):
|
||||
_check_ref(src, "src", gpu_core.SMEM)
|
||||
_check_ref(dst, "dst", gpu_core.GMEM)
|
||||
del args, params # Unused.
|
||||
return (), {state.ReadEffect(0), state.WriteEffect(1)}
|
||||
|
||||
|
||||
@ -114,9 +133,7 @@ def _extract_smem_copy_params(transforms):
|
||||
|
||||
|
||||
def copy_smem_to_gmem(
|
||||
src: pallas_core.AbstractMemoryRef,
|
||||
dst: pallas_core.AbstractMemoryRef,
|
||||
predicate: jax.Array | None = None,
|
||||
src: _Ref, dst: _Ref, predicate: jax.Array | None = None
|
||||
) -> None:
|
||||
"""Asynchronously copies a SMEM reference to a GMEM reference.
|
||||
|
||||
@ -130,10 +147,6 @@ def copy_smem_to_gmem(
|
||||
:func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`
|
||||
:func:`jax.experimental.mosaic.gpu.commit_smem`
|
||||
"""
|
||||
if src.memory_space is not gpu_core.SMEM:
|
||||
raise TypeError(f"src must be a SMEM reference, got {src.memory_space}")
|
||||
if getattr(dst, "memory_space", gpu_core.GMEM) is not gpu_core.GMEM:
|
||||
raise ValueError(f"dst must be a GMEM reference, got {dst.memory_space}")
|
||||
src, src_transforms = state_primitives.get_ref_and_transforms(
|
||||
src, None, "copy_smem_to_gmem", force_trailing_indexer=False,
|
||||
)
|
||||
@ -164,8 +177,11 @@ copy_gmem_to_smem_p.multiple_results = True
|
||||
|
||||
|
||||
@copy_gmem_to_smem_p.def_effectful_abstract_eval
|
||||
def _copy_gmem_to_smem_abstract_eval(*avals, **params):
|
||||
del avals, params # Unused.
|
||||
def _copy_gmem_to_smem_abstract_eval(src, dst, barrier, *args, **params):
|
||||
del args, params # Unused.
|
||||
_check_ref(src, "src", gpu_core.GMEM)
|
||||
_check_ref(dst, "dst", gpu_core.SMEM)
|
||||
_check_ref(barrier, "barrier", gpu_core.SMEM)
|
||||
return (), {state.ReadEffect(0), state.WriteEffect(1)}
|
||||
|
||||
|
||||
@ -217,21 +233,13 @@ def _copy_gmem_to_smem_lowering(
|
||||
return ()
|
||||
|
||||
|
||||
def copy_gmem_to_smem(
|
||||
src: pallas_core.AbstractMemoryRef,
|
||||
dst: pallas_core.AbstractMemoryRef,
|
||||
barrier: pallas_core.AbstractMemoryRef,
|
||||
) -> None:
|
||||
def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None:
|
||||
"""Asynchronously copies a GMEM reference to a SMEM reference.
|
||||
|
||||
See also:
|
||||
:func:`jax.experimental.mosaic.gpu.barrier_arrive`
|
||||
:func:`jax.experimental.mosaic.gpu.barrier_wait`
|
||||
"""
|
||||
if getattr(src, "memory_space", gpu_core.GMEM) is not gpu_core.GMEM:
|
||||
raise TypeError(f"src must be a GMEM reference, got {src.memory_space}")
|
||||
if dst.memory_space is not gpu_core.SMEM:
|
||||
raise ValueError(f"dst must be a SMEM reference, got {dst.memory_space}")
|
||||
src, src_transforms = state_primitives.get_ref_and_transforms(
|
||||
src, None, "copy_gmem_to_smem", force_trailing_indexer=False,
|
||||
)
|
||||
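For orientation, here is a hedged sketch of how these copy and barrier primitives are typically combined inside a Pallas Mosaic GPU kernel body. The plgpu alias is an assumption about the public re-export, and the pl.pallas_call wiring (grid, SMEM scratch shapes, the Barrier allocation) is omitted; none of it is part of this diff.

from jax.experimental.pallas import mosaic_gpu as plgpu  # assumed alias

def kernel(x_gmem, o_gmem, x_smem, o_smem, barrier):
  # Stage the input into SMEM and wait on the barrier for the copy to land.
  plgpu.copy_gmem_to_smem(x_gmem, x_smem, barrier)
  plgpu.barrier_wait(barrier)
  o_smem[...] = x_smem[...] * 2
  # Commit SMEM writes, then copy the result out and drain outstanding copies.
  plgpu.commit_smem()
  plgpu.copy_smem_to_gmem(o_smem, o_gmem)
  plgpu.wait_smem_to_gmem(0)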
@ -291,8 +299,9 @@ barrier_arrive_p.multiple_results = True
|
||||
|
||||
|
||||
@barrier_arrive_p.def_effectful_abstract_eval
|
||||
def _barrier_arrive_abstract_eval(*avals, **params):
|
||||
del avals, params # Unused.
|
||||
def _barrier_arrive_abstract_eval(barrier, *args, **params):
|
||||
del args, params # Unused.
|
||||
_check_ref(barrier, "barrier", gpu_core.SMEM)
|
||||
return (), {gpu_core._memory_effect}
|
||||
|
||||
|
||||
@ -328,8 +337,9 @@ barrier_wait_p.multiple_results = True
|
||||
|
||||
|
||||
@barrier_wait_p.def_effectful_abstract_eval
|
||||
def _barrier_wait_abstract_eval(*avals, **params):
|
||||
del avals, params # Unused.
|
||||
def _barrier_wait_abstract_eval(barrier, *args, **params):
|
||||
_check_ref(barrier, "barrier", gpu_core.SMEM)
|
||||
del args, params # Unused.
|
||||
return (), {gpu_core._memory_effect}
|
||||
|
||||
|
||||
@ -703,38 +713,40 @@ def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout):
|
||||
del layout, dimension
|
||||
return jax_core.ShapedArray(shape, dtype)
|
||||
|
||||
|
||||
@lowering.register_lowering_rule(broadcasted_iota_p)
|
||||
def _broadcasted_iota_lowering(ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout):
|
||||
del ctx
|
||||
# Unsigned integers (as opposed to signless) cause MLIR verification
|
||||
# errors so we only use signless like Mosaic GPU does.
|
||||
#
|
||||
# TODO(cperivol): use mgpu.utils.dtype_to_ir_type() instead.
|
||||
mlir_dtype = (
|
||||
ir.IntegerType.get_signless(dtype.itemsize * 8)
|
||||
if jnp.issubdtype(dtype, jnp.integer)
|
||||
else mlir.dtype_to_ir_type(dtype)
|
||||
)
|
||||
undef = llvm_dialect.mlir_undef(mlir_dtype)
|
||||
is_signed = (
|
||||
jnp.issubdtype(dtype, jnp.signedinteger)
|
||||
if jnp.issubdtype(dtype, jnp.integer)
|
||||
else None
|
||||
)
|
||||
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
def _cast(x):
|
||||
if ir.FloatType.isinstance(mlir_dtype):
|
||||
x = arith_dialect.index_cast(i32, x)
|
||||
return arith_dialect.uitofp(mlir_dtype, x)
|
||||
else:
|
||||
return arith_dialect.index_cast(mlir_dtype, x)
|
||||
def _broadcasted_iota_lowering(
|
||||
ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout
|
||||
):
|
||||
del ctx # Unused.
|
||||
mlir_dtype = mgpu_utils.dtype_to_ir_type(dtype)
|
||||
if ir.FloatType.isinstance(mlir_dtype):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
cast = lambda x: arith_dialect.uitofp(
|
||||
mlir_dtype, arith_dialect.index_cast(i32, x)
|
||||
)
|
||||
else:
|
||||
cast = lambda x: arith_dialect.index_cast(mlir_dtype, x)
|
||||
is_signed = mgpu_utils.is_signed(dtype)
|
||||
return mgpu.FragmentedArray.splat(
|
||||
undef, shape, layout.value, is_signed=is_signed
|
||||
llvm_dialect.mlir_undef(mlir_dtype),
|
||||
shape,
|
||||
layout.value,
|
||||
is_signed=is_signed,
|
||||
).foreach(
|
||||
lambda _, idx: _cast(idx[dimension]), create_array=True, is_signed=is_signed
|
||||
lambda _, idx: cast(idx[dimension]),
|
||||
create_array=True,
|
||||
is_signed=is_signed,
|
||||
)
|
||||
|
||||
|
||||
def broadcasted_iota(dtype, shape, dimension, *, layout: Layout | None = None):
|
||||
return broadcasted_iota_p.bind(dtype=jnp.dtype(dtype), shape=shape, dimension=dimension, layout=layout)
|
||||
def broadcasted_iota(
|
||||
dtype: jax.typing.DTypeLike,
|
||||
shape: Sequence[int],
|
||||
dimension: int,
|
||||
*,
|
||||
layout: Layout | None = None,
|
||||
) -> jax.Array:
|
||||
return broadcasted_iota_p.bind(
|
||||
dtype=jnp.dtype(dtype), shape=shape, dimension=dimension, layout=layout
|
||||
)
|
||||
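A hedged usage sketch for the now fully typed broadcasted_iota wrapper; the plgpu re-export spelling is an assumption, and inside this module the call is simply broadcasted_iota(...). With layout left as None the lowering picks a default register layout.

import jax.numpy as jnp
from jax.experimental.pallas import mosaic_gpu as plgpu  # assumed public re-export

def kernel(o_ref):
  # o_ref[i, j] == j: an int32 iota broadcast along dimension 1 of a (128, 128) array.
  o_ref[...] = plgpu.broadcasted_iota(jnp.int32, (128, 128), dimension=1)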
|
@ -461,8 +461,6 @@ class KeyTy(dtypes.ExtendedDType):
|
||||
|
||||
|
||||
core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
|
||||
core.shaped_abstractify_handlers[PRNGKeyArray] = op.attrgetter('aval')
|
||||
|
||||
xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
|
||||
|
||||
|
||||
|
@ -249,8 +249,8 @@ class XlaExecutable(Executable):
|
||||
else:
|
||||
raise
|
||||
|
||||
# TODO(skyewm): this should return a single dict (I think returning a list
|
||||
# was to support MPMD executables, which never fully landed)
|
||||
# TODO(b/384741132): this should return a single dict (I think returning a list
|
||||
# was to support MPMD executables, which never fully landed).
|
||||
def cost_analysis(self) -> list[dict[str, float]]:
|
||||
xla_ext_exe = self.xla_extension_executable()
|
||||
|
||||
@ -266,9 +266,19 @@ class XlaExecutable(Executable):
|
||||
# Try client method if executable cost_analysis method is unimplemented
|
||||
if hasattr(xla_ext_exe, "client"):
|
||||
try:
|
||||
# TODO(b/384741132): We expect that the executable has only one
|
||||
# HloModule. We should be able to remove this check once we update the
|
||||
# Executable class to have only a single HloModule (see bug).
|
||||
hlo_modules = xla_ext_exe.hlo_modules()
|
||||
assert len(hlo_modules) == 1, (
f"Executable should have only one HloModule ({len(hlo_modules)}"
" were found)."
|
||||
)
|
||||
|
||||
return [
|
||||
xla_extension.hlo_module_cost_analysis(xla_ext_exe.client, m)
|
||||
for m in xla_ext_exe.hlo_modules()
|
||||
xla_extension.hlo_module_cost_analysis(
|
||||
xla_ext_exe.client, hlo_modules[0]
|
||||
)
|
||||
]
|
||||
except xla_extension.XlaRuntimeError as e:
|
||||
msg, *_ = e.args
|
||||
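A hedged sketch of how this code path is reached from user level; depending on the JAX version the result is a dict or a single-element list of dicts of estimated costs (flops, bytes accessed, and so on).

import jax
import jax.numpy as jnp

f = jax.jit(lambda x: jnp.sin(x) @ x.T)
compiled = f.lower(jnp.ones((128, 128), jnp.float32)).compile()
print(compiled.cost_analysis())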
|
@ -382,6 +382,17 @@ def _lower_tpu_kernel(
|
||||
pipeline.run(module.operation)
|
||||
dump_mlir(module, "post-infer-vector-layout")
|
||||
|
||||
pipeline = [
|
||||
(
|
||||
"func.func(tpu-relayout-insertion{"
|
||||
f" sublane-count={sl_cnt} lane-count={l_cnt}"
|
||||
"})"
|
||||
),
|
||||
]
|
||||
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
|
||||
pipeline.run(module.operation)
|
||||
dump_mlir(module, "post-relayout-insertion")
|
||||
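Spelled out with the default 8x128 TPU target, the pipeline string built above parses to the following. This is a sketch: the PassManager import path is an assumption, and in the real file whatever import already exists is used.

from jax._src.lib.mlir.passmanager import PassManager  # assumed path; use the file's existing import

pipeline = PassManager.parse(
    "builtin.module("
    "func.func(tpu-relayout-insertion{ sublane-count=8 lane-count=128})"
    ")"
)
# pipeline.run(module.operation) then rewrites each function in place, sitting
# between the infer-vector-layout and apply-vector-layout steps shown here.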
|
||||
mxu_size = 128 if hardware_generation < 6 else 256
|
||||
pipeline = [
|
||||
"func.func(tpu-apply-vector-layout{"
|
||||
|
@ -128,7 +128,7 @@ _deprecations = {
|
||||
_src_core.escaped_tracer_error),
|
||||
"extend_axis_env_nd": ("jax.core.extend_axis_env_nd is deprecated.",
|
||||
_src_core.extend_axis_env_nd),
|
||||
"get_type": ("jax.core.get_type is deprecated.", _src_core.get_type),
|
||||
"get_type": ("jax.core.get_type is deprecated.", _src_core.get_aval),
|
||||
"get_referent": ("jax.core.get_referent is deprecated.", _src_core.get_referent),
|
||||
"join_effects": ("jax.core.join_effects is deprecated.", _src_core.join_effects),
|
||||
"leaked_tracer_error": ("jax.core.leaked_tracer_error is deprecated.",
|
||||
@ -212,7 +212,7 @@ if typing.TYPE_CHECKING:
|
||||
escaped_tracer_error = _src_core.escaped_tracer_error
|
||||
extend_axis_env_nd = _src_core.extend_axis_env_nd
|
||||
full_lower = _src_core.full_lower
|
||||
get_type = _src_core.get_type
|
||||
get_type = _src_core.get_aval
|
||||
get_referent = _src_core.get_referent
|
||||
jaxpr_as_fun = _src_core.jaxpr_as_fun
|
||||
join_effects = _src_core.join_effects
|
||||
|
@ -174,7 +174,6 @@ def _construct_smem_reftree(
|
||||
) -> RefTree:
|
||||
index = ir.IndexType.get()
|
||||
i8 = ir.IntegerType.get_signless(8)
|
||||
ptr = ir.Type.parse("!llvm.ptr")
|
||||
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
|
||||
flat_ref_tys, smem_buffer_tree = jax.tree.flatten(
|
||||
smem_buffers, is_leaf=lambda x: isinstance(x, Union)
|
||||
@ -183,11 +182,19 @@ def _construct_smem_reftree(
|
||||
for ref_ty in flat_ref_tys:
|
||||
def get_barrier_ptr(num_barriers: int) -> ir.Value:
|
||||
nonlocal dynamic_smem_offset
|
||||
smem_base_ptr = utils.memref_ptr(dynamic_smem, memory_space=3)
|
||||
barrier_base_ptr = llvm.getelementptr(
|
||||
ptr, smem_base_ptr, [], [dynamic_smem_offset], i8
|
||||
workgroup_nvptx_address_space = (
|
||||
dialect_lowering.gpu_address_space_to_nvptx(
|
||||
gpu.AddressSpace.Workgroup
|
||||
)
|
||||
)
|
||||
dynamic_smem_offset += num_barriers * MBARRIER_BYTES
|
||||
smem_base_ptr = utils.memref_ptr(
|
||||
dynamic_smem, memory_space=workgroup_nvptx_address_space
|
||||
)
|
||||
smem_ptr_ty = ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>")
|
||||
barrier_base_ptr = llvm.getelementptr(
|
||||
smem_ptr_ty, smem_base_ptr, [], [dynamic_smem_offset], i8
|
||||
)
|
||||
dynamic_smem_offset += num_barriers * utils.MBARRIER_BYTES
|
||||
return barrier_base_ptr
|
||||
match ref_ty:
|
||||
case Union(members):
|
||||
@ -227,9 +234,6 @@ def _construct_smem_reftree(
|
||||
return jax.tree.unflatten(smem_buffer_tree, smem_refs)
|
||||
|
||||
|
||||
MBARRIER_BYTES = 8
|
||||
|
||||
|
||||
def _smem_tree_size(smem_buffers: ShapeTree) -> int:
|
||||
leaves = jax.tree.leaves(
|
||||
smem_buffers, is_leaf=lambda x: isinstance(x, Union)
|
||||
@ -244,9 +248,9 @@ def _smem_tree_size(smem_buffers: ShapeTree) -> int:
|
||||
| ClusterBarrier(_, num_barriers=num_barriers)
|
||||
| Barrier(_, num_barriers=num_barriers)
|
||||
):
|
||||
if size % MBARRIER_BYTES:
|
||||
if size % utils.MBARRIER_BYTES:
|
||||
raise NotImplementedError("Misaligned barrier allocation")
|
||||
size += num_barriers * MBARRIER_BYTES
|
||||
size += num_barriers * utils.MBARRIER_BYTES
|
||||
case _:
|
||||
size += _count_buffer_bytes(l)
|
||||
return size
|
||||
@ -379,9 +383,11 @@ def _lower_as_gpu_kernel(
|
||||
attrs["sym_name"] = ir.StringAttr.get(module_name)
|
||||
if kernel_name is None:
|
||||
kernel_name = getattr(body, "__name__", "anonymous")
|
||||
|
||||
# These are needed as nonlocal below.
|
||||
launch_ctx, scratch_arr = None, None
|
||||
with ir.InsertionPoint(module.body):
|
||||
_declare_runtime_functions()
|
||||
gmem_scratch_bytes = 0
|
||||
global_scratch = llvm.GlobalOp(
|
||||
ir.Type.parse("!llvm.array<0 x i8>"), # We don't know the shape yet.
|
||||
"global_scratch",
|
||||
@ -390,7 +396,7 @@ def _lower_as_gpu_kernel(
|
||||
)
|
||||
@func.FuncOp.from_py_func(ptr_ty, ptr_ty, name=f"mosaic_gpu_{kernel_name}")
|
||||
def main(token_ptr, buffers):
|
||||
nonlocal gmem_scratch_bytes
|
||||
nonlocal launch_ctx, scratch_arr
|
||||
token = builtin.unrealized_conversion_cast([token_ty], [token_ptr])
|
||||
arg_refs = []
|
||||
for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]):
|
||||
@ -408,27 +414,40 @@ def _lower_as_gpu_kernel(
|
||||
with _launch(
|
||||
token, grid, cluster, block, scratch_arr, smem_scratch_shape,
|
||||
prof_spec, prof_buffer
|
||||
) as (launch_ctx, smem_refs):
|
||||
) as (_launch_ctx, smem_refs):
|
||||
nonlocal launch_ctx
|
||||
launch_ctx = _launch_ctx
|
||||
body(launch_ctx, *in_refs, *out_refs, smem_refs)
|
||||
gmem_scratch_bytes = launch_ctx.next_scratch_offset
|
||||
# Allocate and initialize the host buffer right before the launch.
|
||||
# Note that we couldn't do that before, because we had to run the body
|
||||
# to learn what the scratch contains.
|
||||
with ir.InsertionPoint(scratch_arr.owner):
|
||||
scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>")
|
||||
scratch_alloc.elem_type = ir.TypeAttr.get(scratch_arr_ty)
|
||||
scratch_arr.set_type(scratch_arr_ty)
|
||||
for init_callback in launch_ctx.host_scratch_init:
|
||||
init_callback(scratch_alloc.result)
|
||||
main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
|
||||
sym_tab = ir.SymbolTable(module.operation)
|
||||
sym_tab.insert(main.func_op)
|
||||
sym_tab.insert(global_scratch)
|
||||
module.operation.verify()
|
||||
|
||||
return module, out_shape, unwrap_output_tuple
|
||||
return module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr
|
||||
|
||||
|
||||
def _initialize_scratch(
|
||||
launch_ctx : launch_context.LaunchContext,
|
||||
scratch_arr: ir.Value,
|
||||
):
|
||||
"""
|
||||
Allocates and initializes the host buffer right before the launch. This needs
|
||||
to be done after all TMA descriptors have been recorded by the launch context.
|
||||
Only then do we know what the scratch contains.
|
||||
|
||||
When using the Mosaic GPU dialect, the necessary information is known only
|
||||
after the lowering passes have run.
|
||||
"""
|
||||
with ir.InsertionPoint(scratch_arr.owner):
|
||||
gmem_scratch_bytes = launch_ctx.next_scratch_offset
|
||||
scratch_alloc_op = scratch_arr.owner.opview.addr.owner.opview
|
||||
scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>")
|
||||
scratch_alloc_op.elem_type = ir.TypeAttr.get(scratch_arr_ty)
|
||||
scratch_arr.set_type(scratch_arr_ty)
|
||||
for init_callback in launch_ctx.host_scratch_init:
|
||||
init_callback(scratch_alloc_op.result)
|
||||
|
||||
def _declare_runtime_functions():
|
||||
"""Declares the runtime functions that can be used by the generated code."""
|
||||
ptr_ty = ir.Type.parse("!llvm.ptr")
|
||||
@ -462,7 +481,7 @@ def as_gpu_kernel(
|
||||
elif not isinstance(in_shape, tuple):
|
||||
in_shape = (in_shape,)
|
||||
|
||||
module, out_shape, unwrap_output_tuple = (
|
||||
module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr = (
|
||||
_lower_as_gpu_kernel(
|
||||
body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
|
||||
module_name, kernel_name, prof_spec
|
||||
@ -473,7 +492,10 @@ def as_gpu_kernel(
|
||||
# Run Python lowering passes. The remaining passes will be run in C++ in
|
||||
# jax/jaxlib/mosaic/gpu/custom_call.cc
|
||||
layout_inference.infer_layout(module) # pytype: disable=attribute-error
|
||||
dialect_lowering.lower_mgpu_dialect(module) # pytype: disable=attribute-error
|
||||
dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error
|
||||
|
||||
_initialize_scratch(launch_ctx, scratch_arr)
|
||||
module.operation.verify()
|
||||
|
||||
expected_arg_treedef = jax.tree.structure(in_shape)
|
||||
def _check_args(*args):
|
||||
@ -530,6 +552,7 @@ def as_torch_gpu_kernel(
|
||||
cluster: tuple[int, int, int] = (1, 1, 1),
|
||||
module_name: str = "unknown",
|
||||
kernel_name: str | None = None,
|
||||
thread_semantics: ThreadSemantics = ThreadSemantics.Lane,
|
||||
):
|
||||
try:
|
||||
import torch
|
||||
@ -545,13 +568,22 @@ def as_torch_gpu_kernel(
|
||||
flat_out_types, out_treedef = jax.tree.flatten(out_shape)
|
||||
expected_arg_treedef = jax.tree.structure(in_shape)
|
||||
|
||||
module, out_shape, unwrap_output_tuple = (
|
||||
module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr = (
|
||||
_lower_as_gpu_kernel(
|
||||
body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
|
||||
module_name, kernel_name, prof_spec
|
||||
)
|
||||
)
|
||||
|
||||
if thread_semantics == ThreadSemantics.Warpgroup and dialect is not None:
|
||||
# Run Python lowering passes. The remaining passes will be run in C++ in
|
||||
# jax/jaxlib/mosaic/gpu/custom_call.cc
|
||||
layout_inference.infer_layout(module) # pytype: disable=attribute-error
|
||||
dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error
|
||||
|
||||
_initialize_scratch(launch_ctx, scratch_arr)
|
||||
module.operation.verify()
|
||||
|
||||
# Get our hands on the compilation and unload functions
|
||||
try:
|
||||
import jax_plugins.xla_cuda12 as cuda_plugin
|
||||
|
@ -31,13 +31,14 @@ from jax._src.lib.mlir.dialects import vector
|
||||
import numpy as np
|
||||
|
||||
from .fragmented_array import FragmentedArray, WGStridedFragLayout
|
||||
from .launch_context import LaunchContext
|
||||
from .layouts import from_strided_fragmented_layout_attr, has_any_layout_set, is_strided_fragmented_layout, should_have_layout, to_strided_fragmented_layout_attr
|
||||
from .utils import c, ptr_as_memref, single_thread_predicate
|
||||
from .utils import BarrierRef, c, memref_ptr, ptr_as_memref, single_thread_predicate
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
MlirLoweringRule = Callable[[ir.Operation | ir.OpView], Sequence[ir.Value]]
|
||||
MlirLoweringRule = Callable[[LaunchContext, ir.Operation | ir.OpView], Sequence[ir.Value]]
|
||||
|
||||
|
||||
_lowerings: dict[str, MlirLoweringRule] = {}
|
||||
@ -88,6 +89,9 @@ def _fragmented_array_from_ir(
|
||||
[operand.type for operand in conversion_cast.operands],
|
||||
conversion_cast.results,
|
||||
)
|
||||
if not isinstance(converted_outputs, list):
|
||||
converted_outputs = [converted_outputs]
|
||||
|
||||
|
||||
reverse_conversion_cast = converted_outputs[0].owner.opview
|
||||
for attribute in conversion_cast.attributes:
|
||||
@ -138,6 +142,7 @@ def gpu_address_space_to_nvptx(address_space: gpu.AddressSpace) -> int:
|
||||
|
||||
@_register_lowering(InitializeBarrierOp)
|
||||
def _initialize_barrier_op_lowering_rule(
|
||||
_: LaunchContext,
|
||||
initialize_barrier_op: InitializeBarrierOp,
|
||||
) -> Sequence[ir.Value]:
|
||||
|
||||
@ -170,7 +175,7 @@ def _initialize_barrier_op_lowering_rule(
|
||||
|
||||
@_register_lowering(vector.LoadOp)
|
||||
def _vector_load_op_lowering_rule(
|
||||
vector_load_op: vector.LoadOp,
|
||||
_: LaunchContext, vector_load_op: vector.LoadOp
|
||||
) -> Sequence[ir.Value]:
|
||||
(out_layout_attr,) = cast(
|
||||
ir.ArrayAttr, vector_load_op.attributes["out_layouts"]
|
||||
@ -199,7 +204,7 @@ def _vector_load_op_lowering_rule(
|
||||
|
||||
@_register_lowering(vector.StoreOp)
|
||||
def _vector_store_op_lowering_rule(
|
||||
vector_store_op: vector.StoreOp,
|
||||
_: LaunchContext, vector_store_op: vector.StoreOp
|
||||
) -> Sequence[ir.Value]:
|
||||
|
||||
in_layout_attr, *_ = cast(
|
||||
@ -229,8 +234,44 @@ def _vector_store_op_lowering_rule(
|
||||
return []
|
||||
|
||||
|
||||
@_register_lowering(mgpu.AsyncLoadOp)
|
||||
def _mgpu_async_load_op_lowering_rule(
|
||||
launch_context: LaunchContext, load_op: mgpu.AsyncLoadOp
|
||||
) -> Sequence[ir.Value]:
|
||||
mem_space = gpu_address_space_to_nvptx(gpu.AddressSpace.Workgroup)
|
||||
|
||||
# TODO(dasenov): Add support for the remaining op properties.
|
||||
launch_context.async_copy(
|
||||
src_ref=load_op.source,
|
||||
dst_ref=load_op.destination,
|
||||
barrier=BarrierRef(
|
||||
base_address=memref_ptr(load_op.barrier, memory_space=mem_space),
|
||||
offset=c(0, ir.IntegerType.get_signless(64)),
|
||||
phases=None,
|
||||
num_barriers=1,
|
||||
),
|
||||
arrive=load_op.arrive,
|
||||
uniform=False,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
@_register_lowering(mgpu.AsyncStoreOp)
|
||||
def _mgpu_async_store_op_lowering_rule(
|
||||
launch_context: LaunchContext, store_op: mgpu.AsyncStoreOp
|
||||
) -> Sequence[ir.Value]:
|
||||
# TODO(dasenov): Add support for the remaining op properties.
|
||||
launch_context.async_copy(
|
||||
src_ref=store_op.source,
|
||||
dst_ref=store_op.destination,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
@_register_lowering(arith.AddFOp)
|
||||
def _arith_addf_op_lowering_rule(add: arith.AddFOp) -> Sequence[ir.Value]:
|
||||
def _arith_addf_op_lowering_rule(
|
||||
_: LaunchContext, add: arith.AddFOp
|
||||
) -> Sequence[ir.Value]:
|
||||
|
||||
fragmented_array_lhs = _fragmented_array_from_ir(add.lhs)
|
||||
fragmented_array_rhs = _fragmented_array_from_ir(add.rhs)
|
||||
@ -242,7 +283,7 @@ def _arith_addf_op_lowering_rule(add: arith.AddFOp) -> Sequence[ir.Value]:
|
||||
]
|
||||
|
||||
|
||||
def lower_mgpu_dialect(module: ir.Module):
|
||||
def lower_mgpu_dialect(module: ir.Module, launch_context: LaunchContext):
|
||||
module.context.append_dialect_registry(mlir_interpreter.upstream_dialects)
|
||||
module.context.load_all_available_dialects()
|
||||
|
||||
@ -257,7 +298,7 @@ def lower_mgpu_dialect(module: ir.Module):
|
||||
if should_have_layout(op) and not has_any_layout_set(op):
|
||||
raise ValueError(f"{op} is missing a layout and can not be lowered.")
|
||||
|
||||
new_results = lowering_rule(op)
|
||||
new_results = lowering_rule(launch_context, op)
|
||||
|
||||
for old, new in zip(op.results, new_results):
|
||||
old.replace_all_uses_with(new)
|
||||
|
@ -36,24 +36,28 @@ from jaxlib.mlir.dialects import scf
|
||||
from jaxlib.mlir.dialects import vector
|
||||
import numpy as np
|
||||
|
||||
from jax._src.lib import mosaic_gpu_dialect as dialect # noqa: F401
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
WARPGROUP_SIZE: int = 128
|
||||
DYNAMIC = -9223372036854775808
|
||||
DYNAMIC32 = -2147483648
|
||||
MBARRIER_BYTES = 8
|
||||
|
||||
# pylint: disable=line-too-long, wildcard-import, missing-function-docstring, bad-continuation, g-bad-todo, protected-access, g-explicit-length-test, missing-class-docstring, g-doc-return-or-yield, g-inconsistent-quotes
|
||||
|
||||
|
||||
def ptr_as_memref(ptr, memref_ty: ir.MemRefType):
|
||||
def ptr_as_memref(ptr, memref_ty: ir.MemRefType, ptr_memory_space: int | None = None):
|
||||
i64 = ir.IntegerType.get_signless(64)
|
||||
rank = len(memref_ty.shape)
|
||||
ptr_ty = "ptr" if ptr_memory_space is None else f"ptr<{ptr_memory_space}>"
|
||||
if rank > 0:
|
||||
desc_ty = ir.Type.parse(
|
||||
f"!llvm.struct<(ptr, ptr, i64, array<{rank} x i64>, array<{rank} x i64>)>"
|
||||
f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>, array<{rank} x i64>)>"
|
||||
)
|
||||
else:
|
||||
desc_ty = ir.Type.parse("!llvm.struct<(ptr, ptr, i64)>")
|
||||
desc_ty = ir.Type.parse(f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64)>")
|
||||
desc = llvm.UndefOp(desc_ty)
|
||||
desc = llvm.InsertValueOp(desc, ptr, [0]) # Allocation
|
||||
desc = llvm.InsertValueOp(desc, ptr, [1]) # Aligned Base
|
||||
@ -321,6 +325,8 @@ def bytewidth(ty: ir.Type):
|
||||
return ir.IntegerType(ty).width // 8
|
||||
if ir.FloatType.isinstance(ty):
|
||||
return ir.FloatType(ty).width // 8
|
||||
if dialect is not None and ty == ir.Type.parse("!mosaic_gpu.barrier"):
|
||||
return MBARRIER_BYTES
|
||||
raise NotImplementedError(ty)
|
||||
|
||||
|
||||
@ -743,6 +749,18 @@ class BarrierRef:
|
||||
ptr, self.base_address, [self.offset], [DYNAMIC32], i64
|
||||
)
|
||||
|
||||
def as_dialect_barrier(self) -> ir.Value:
|
||||
if self.num_barriers > 1:
|
||||
raise NotImplementedError(
f"Only BarrierRef with num_barriers=1 is supported in the MLIR "
f"Mosaic GPU dialect, but got num_barriers={self.num_barriers}"
|
||||
)
|
||||
return ptr_as_memref(
|
||||
self.base_address,
|
||||
ir.MemRefType.get((), ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
ptr_memory_space=3,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class CollectiveBarrierRef:
|
||||
@ -997,19 +1015,21 @@ def warp_tree_reduce(value, op, group_size):
|
||||
def memref_ptr(memref_arg, memory_space=None):
|
||||
i64 = ir.IntegerType.get_signless(64)
|
||||
memref_ty = ir.MemRefType(memref_arg.type)
|
||||
if len(memref_ty.shape) == 0:
|
||||
raise NotImplementedError
|
||||
elem_bytewidth = bytewidth(memref_ty.element_type)
|
||||
rank = len(memref_ty.shape)
|
||||
# TODO: Read out memory space from memref
|
||||
space = "" if memory_space is None else "<" + str(memory_space) + ">"
|
||||
ptr_ty = ir.Type.parse("!llvm.ptr" + space)
|
||||
desc_ty = ir.Type.parse(
|
||||
f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>,"
|
||||
f" array<{rank} x i64>)>"
|
||||
)
|
||||
if rank == 0:
|
||||
desc_ty = ir.Type.parse(f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64)>")
|
||||
else:
|
||||
desc_ty = ir.Type.parse(
|
||||
f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>,"
|
||||
f" array<{rank} x i64>)>"
|
||||
)
|
||||
desc = builtin.UnrealizedConversionCastOp([desc_ty], [memref_arg])
|
||||
aligned_ptr = llvm.extractvalue(ptr_ty, desc, [1])
|
||||
|
||||
elem_bytewidth = bytewidth(memref_ty.element_type)
|
||||
offset_elems = llvm.extractvalue(i64, desc, [2])
|
||||
offset_bytes = llvm.mul(
|
||||
offset_elems,
|
||||
|
@ -724,7 +724,6 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
|
||||
aval_in, aval_out, x):
|
||||
if prod([size for n, size in mesh.shape.items() if n not in auto]) == 1:
|
||||
return x
|
||||
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
|
||||
axes = {name: i for i, ns in names.items() for name in ns}
|
||||
ns = _make_scoped_manual_sharding(ctx, mesh, axes)
|
||||
if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
|
||||
@ -734,6 +733,7 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
|
||||
unspecified = set(range(aval_in.ndim)) if auto else set()
|
||||
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto,
|
||||
unspecified_dims=unspecified)
|
||||
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
|
||||
return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, unspecified)
|
||||
|
||||
def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
|
||||
@ -746,6 +746,8 @@ def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
|
||||
ns = sharding_impls.physical_sharding(aval_out, ns)
|
||||
aval_out = core.physical_aval(aval_out)
|
||||
unspecified = set(range(aval_out.ndim)) if auto else set()
|
||||
if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
|
||||
aval_in = core.physical_aval(aval_in)
|
||||
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
|
||||
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=unspecified)
|
||||
shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto()
|
||||
|
@ -36,17 +36,17 @@ _deprecations = {
|
||||
"jax.lib.xla_client.bfloat16 was removed in JAX v0.4.38; use ml_dtypes.bfloat16.",
|
||||
None,
|
||||
),
|
||||
# Added Sep 26 2024
|
||||
# Finalized 2024-12-23; remove after 2025-03-23
|
||||
"Device": (
|
||||
"jax.lib.xla_client.Device is deprecated; use jax.Device instead.",
|
||||
_xc.Device,
|
||||
None,
|
||||
),
|
||||
"XlaRuntimeError": (
|
||||
(
|
||||
"jax.lib.xla_client.XlaRuntimeError is deprecated; use"
|
||||
" jax.errors.JaxRuntimeError."
|
||||
),
|
||||
_xc.XlaRuntimeError,
|
||||
None,
|
||||
),
|
||||
# Added Oct 10 2024
|
||||
"FftType": (
|
||||
|
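For completeness, a small sketch of the replacements user code should already be using now that the Device and XlaRuntimeError aliases are finalized; both replacement names are public JAX API.

import jax

dev: jax.Device = jax.devices()[0]            # use jax.Device instead of xla_client.Device
JaxRuntimeError = jax.errors.JaxRuntimeError  # instead of xla_client.XlaRuntimeError
print(dev.platform, JaxRuntimeError.__name__)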
@ -54,7 +54,8 @@ py_extension(
|
||||
deps = [
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonCoreNoCAPI",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
|
||||
"@nanobind",
|
||||
],
|
||||
)
|
||||
|
||||
@ -69,8 +70,8 @@ py_extension(
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:CAPIGPUHeaders",
|
||||
"@llvm-project//mlir:CAPIIRHeaders",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
|
||||
"@pybind11",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
|
||||
"@nanobind",
|
||||
],
|
||||
)
|
||||
|
||||
@ -84,7 +85,8 @@ py_extension(
|
||||
deps = [
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:CAPIGPUHeaders",
|
||||
"@pybind11",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
|
||||
"@nanobind",
|
||||
],
|
||||
)
|
||||
|
||||
@ -99,8 +101,8 @@ py_extension(
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:CAPIIRHeaders",
|
||||
"@llvm-project//mlir:CAPINVGPUHeaders",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
|
||||
"@pybind11",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
|
||||
"@nanobind",
|
||||
],
|
||||
)
|
||||
|
||||
@ -115,8 +117,8 @@ py_extension(
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:CAPIIRHeaders",
|
||||
"@llvm-project//mlir:CAPILLVMHeaders",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
|
||||
"@pybind11",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
|
||||
"@nanobind",
|
||||
],
|
||||
)
|
||||
|
||||
@ -130,8 +132,8 @@ py_extension(
|
||||
deps = [
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:CAPISparseTensorHeaders",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
|
||||
"@pybind11",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
|
||||
"@nanobind",
|
||||
],
|
||||
)
|
||||
|
||||
@ -145,7 +147,8 @@ py_extension(
|
||||
deps = [
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:CAPISparseTensorHeaders",
|
||||
"@pybind11",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
|
||||
"@nanobind",
|
||||
],
|
||||
)
|
||||
|
||||
@ -155,9 +158,10 @@ py_extension(
|
||||
copts = COPTS,
|
||||
linkopts = LINKOPTS,
|
||||
deps = [
|
||||
"//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi",
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_headers",
|
||||
"@llvm-project//mlir:CAPIIRHeaders",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps",
|
||||
"@nanobind",
|
||||
],
|
||||
)
|
||||
@ -377,6 +381,7 @@ cc_library(
|
||||
name = "jaxlib_mlir_capi_objects",
|
||||
deps = [
|
||||
"//jaxlib/mosaic:tpu_dialect_capi_objects",
|
||||
"//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_objects",
|
||||
"@llvm-project//mlir:CAPIArithObjects",
|
||||
"@llvm-project//mlir:CAPIGPUObjects",
|
||||
"@llvm-project//mlir:CAPIIRObjects",
|
||||
|
@ -28,6 +28,7 @@ package(
|
||||
py_library(
|
||||
name = "mosaic",
|
||||
deps = [
|
||||
"//jaxlib/mosaic/python:gpu_dialect",
|
||||
"//jaxlib/mosaic/python:tpu_dialect",
|
||||
],
|
||||
)
|
||||
@ -42,6 +43,7 @@ cc_library(
|
||||
"dialect/tpu/tpu_dialect.cc",
|
||||
"dialect/tpu/tpu_ops.cc",
|
||||
"dialect/tpu/util.cc",
|
||||
"dialect/tpu/vreg_util.cc",
|
||||
":extension_srcs",
|
||||
] + glob([
|
||||
"dialect/tpu/transforms/*.cc",
|
||||
@ -50,6 +52,7 @@ cc_library(
|
||||
"dialect/tpu/layout.h",
|
||||
"dialect/tpu/tpu_dialect.h",
|
||||
"dialect/tpu/util.h",
|
||||
"dialect/tpu/vreg_util.h",
|
||||
] + glob([
|
||||
"dialect/tpu/transforms/*.h",
|
||||
]),
|
||||
@ -231,6 +234,19 @@ cc_library(
|
||||
alwayslink = True,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "vreg_util_test",
|
||||
srcs = ["dialect/tpu/vreg_util_test.cc"],
|
||||
deps = [
|
||||
":tpu_dialect",
|
||||
"//testing/base/public:gunit_main",
|
||||
"@llvm-project//mlir:ArithDialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:VectorDialect",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "extension_srcs",
|
||||
srcs = [
|
||||
|
@ -215,3 +215,26 @@ cc_library(
|
||||
"@llvm-project//mlir:CAPIIR",
|
||||
],
|
||||
)
|
||||
|
||||
# Header-only target, used when using the C API from a separate shared library.
|
||||
cc_library(
|
||||
name = "gpu_dialect_capi_headers",
|
||||
hdrs = DIALECT_CAPI_HEADERS,
|
||||
deps = [
|
||||
":mosaic_gpu_inc_gen",
|
||||
"@llvm-project//mlir:CAPIIRHeaders",
|
||||
],
|
||||
)
|
||||
|
||||
# Alwayslink target, used when exporting the C API from a shared library.
|
||||
cc_library(
|
||||
name = "gpu_dialect_capi_objects",
|
||||
srcs = DIALECT_CAPI_SOURCES,
|
||||
hdrs = DIALECT_CAPI_HEADERS,
|
||||
deps = [
|
||||
":mosaic_gpu",
|
||||
":mosaic_gpu_inc_gen",
|
||||
"@llvm-project//mlir:CAPIIRObjects",
|
||||
],
|
||||
alwayslink = True,
|
||||
)
|
||||
|
@ -19,8 +19,8 @@ limitations under the License.
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
#include "testing/base/public/gmock.h"
|
||||
#include "testing/base/public/gunit.h"
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
|
@ -852,6 +852,19 @@ def InferVectorLayoutPass : Pass<"tpu-infer-vector-layout", "::mlir::func::FuncO
|
||||
];
|
||||
}
|
||||
|
||||
def RelayoutInsertionPass : Pass<"tpu-relayout-insertion", "::mlir::func::FuncOp"> {
|
||||
let dependentDialects = [
|
||||
"::mlir::arith::ArithDialect",
|
||||
"::mlir::func::FuncDialect",
|
||||
"::mlir::tpu::TPUDialect",
|
||||
];
|
||||
let constructor = "::mlir::tpu::createRelayoutInsertionPass()";
|
||||
let options = [
|
||||
Option<"lane_count", "lane-count", "int", /*default=*/"128", "">,
|
||||
Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
|
||||
];
|
||||
}
|
||||
|
||||
def ApplyVectorLayoutPass : Pass<"tpu-apply-vector-layout", "::mlir::func::FuncOp"> {
|
||||
let dependentDialects = [
|
||||
"::mlir::arith::ArithDialect",
|
||||
|
@ -81,6 +81,9 @@ std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass(
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
|
||||
std::array<int64_t, 2> target_shape = {8, 128});
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createRelayoutInsertionPass(
|
||||
std::array<int64_t, 2> target_shape = {8, 128});
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
|
||||
const ApplyVectorLayoutContext &ctx = ApplyVectorLayoutContext{});
|
||||
|
||||
|
@ -29,7 +29,6 @@
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
@ -52,6 +51,7 @@
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "llvm/include/llvm/ADT/APInt.h"
|
||||
#include "llvm/include/llvm/Support/LogicalResult.h"
|
||||
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
@ -64,6 +64,7 @@
|
||||
#include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/util.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
|
||||
#include "xla/array.h"
|
||||
#include "xla/layout.h"
|
||||
#include "xla/util.h"
|
||||
@ -79,41 +80,6 @@ namespace mlir::tpu {
|
||||
#define GEN_PASS_DEF_APPLYVECTORLAYOUTPASS
|
||||
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
|
||||
|
||||
// TPU_ASSERT_* macros should be understood as an assert, i.e. use it to check
|
||||
// things that should never happen. We prefer returning failure over a CHECK
|
||||
// because it's easier to debug from Python (particularly from OSS where symbols
|
||||
// are removed)
|
||||
#define TPU_ASSERT_IMPL(stream, cond) \
|
||||
if (LLVM_UNLIKELY(!(cond))) { \
|
||||
(stream) << "Internal error: assert failed: " #cond; \
|
||||
}
|
||||
#define TPU_ASSERT_CMP_IMPL(stream, lhs, rhs, cmp) \
|
||||
if (LLVM_UNLIKELY(!((lhs)cmp(rhs)))) { \
|
||||
(stream) << "Internal error: assert failed: " #lhs " " #cmp " " #rhs " (" \
|
||||
<< (lhs) << " vs. " << (rhs) << ")"; \
|
||||
return failure(); \
|
||||
}
|
||||
#define TPU_ASSERT_OP(cond) TPU_ASSERT_IMPL(op.emitOpError(), cond)
|
||||
#define TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, cmp) \
|
||||
TPU_ASSERT_CMP_IMPL(op.emitOpError(), lhs, rhs, cmp)
|
||||
#define TPU_ASSERT_EQ_OP(lhs, rhs) TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, ==)
|
||||
#define TPU_ASSERT_GE_OP(lhs, rhs) TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, >=)
|
||||
#define TPU_ASSERT_GT_OP(lhs, rhs) TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, >)
|
||||
#define TPU_ASSERT_LE_OP(lhs, rhs) TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, <=)
|
||||
#define TPU_ASSERT_LT_OP(lhs, rhs) TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, <)
|
||||
#define TPU_ASSERT_LOC(loc, cond) TPU_ASSERT_IMPL(mlir::emitError(loc), cond)
|
||||
#define TPU_ASSERT_CMP_LOC_IMPL(loc, lhs, rhs, cmp) \
|
||||
TPU_ASSERT_CMP_IMPL(loc, lhs, rhs, cmp)
|
||||
#define TPU_ASSERT_EQ_LOC(loc, lhs, rhs) \
|
||||
TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, ==)
|
||||
#define TPU_ASSERT_GE_LOC(loc, lhs, rhs) \
|
||||
TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, >=)
|
||||
#define TPU_ASSERT_GT_LOC(loc, lhs, rhs) \
|
||||
TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, >)
|
||||
#define TPU_ASSERT_LT_LOC(loc, lhs, rhs) \
|
||||
TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, <)
|
||||
#define TPU_ASSERT_LE_LOC(loc, lhs, rhs) \
|
||||
TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, <=)
|
||||
|
||||
// The minimum bound required to rotate with scratch space. The bound refers to
|
||||
// the number of VREGs on rotation dim. This number was concluded from some cost
|
||||
@ -310,16 +276,6 @@ void updateSliceFromRange(xla::Array<T> &arr, Range data,
|
||||
CHECK(data_it == data.end());
|
||||
}
|
||||
|
||||
FailureOr<TypedAttr> getZeroIntOrFloatAttr(Type ty) {
|
||||
if (isa<FloatType>(ty)) {
|
||||
return TypedAttr(FloatAttr::get(ty, 0));
|
||||
}
|
||||
if (isa<IntegerType>(ty)) {
|
||||
return TypedAttr(IntegerAttr::get(ty, 0));
|
||||
}
|
||||
return emitError(UnknownLoc::get(ty.getContext()), "Not implemented: ") << ty;
|
||||
}
|
||||
|
||||
FailureOr<int64_t> getIntConst(Value v, bool silent = false) {
|
||||
if (auto constant_op = v.getDefiningOp<arith::ConstantOp>()) {
|
||||
if (auto integer_attr = dyn_cast<IntegerAttr>(constant_op.getValue())) {
|
||||
@ -514,33 +470,6 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx, func::FuncOp func,
|
||||
return argument;
|
||||
}
|
||||
|
||||
VectorType getNativeVregOrVmaskTypeImpl(
|
||||
Type elem_ty, const int8_t bitwidth,
|
||||
const std::array<int64_t, 2> target_shape) {
|
||||
if (bitwidth == 32) {
|
||||
return VectorType::get(target_shape, elem_ty);
|
||||
}
|
||||
return VectorType::get({target_shape[0], target_shape[1], 32 / bitwidth},
|
||||
elem_ty);
|
||||
}
|
||||
|
||||
VectorType getNativeVregOrVmaskType(Type elem_ty, const int8_t layout_bitwidth,
|
||||
const std::array<int64_t, 2> target_shape) {
|
||||
int8_t bitwidth = elem_ty.getIntOrFloatBitWidth();
|
||||
if (bitwidth == 1) {
|
||||
bitwidth = layout_bitwidth;
|
||||
} else {
|
||||
CHECK_EQ(bitwidth, layout_bitwidth);
|
||||
}
|
||||
return getNativeVregOrVmaskTypeImpl(elem_ty, bitwidth, target_shape);
|
||||
}
|
||||
|
||||
VectorType getNativeVregType(Type elem_ty,
|
||||
const std::array<int64_t, 2> target_shape) {
|
||||
return getNativeVregOrVmaskTypeImpl(elem_ty, elem_ty.getIntOrFloatBitWidth(),
|
||||
target_shape);
|
||||
}
|
||||
|
||||
// Masks all values outside of bounds.
|
||||
//
|
||||
// Arguments:
|
||||
@ -553,7 +482,7 @@ VectorType getNativeVregType(Type elem_ty,
|
||||
// Returns:
|
||||
// An MLIR value of the same type as the value argument, with all entries
|
||||
// outside of bounds replaced by neutral.
|
||||
FailureOr<Value> maskOOB(RewriteContext &ctx, OpBuilder &builder,
|
||||
FailureOr<Value> maskOOB(RewriteContext &ctx, ImplicitLocOpBuilder &builder,
|
||||
TypedValue<VectorType> value,
|
||||
const VRegDataBounds &bounds,
|
||||
const Attribute neutral) {
|
||||
@ -577,86 +506,12 @@ FailureOr<Value> maskOOB(RewriteContext &ctx, OpBuilder &builder,
|
||||
value.getLoc(),
|
||||
VectorType::get(native_vreg_ty.getShape(), builder.getI1Type()), mask);
|
||||
}
|
||||
auto neutral_vec = builder.create<arith::ConstantOp>(
|
||||
value.getLoc(), native_vreg_ty,
|
||||
DenseElementsAttr::get(native_vreg_ty, neutral));
|
||||
Value neutral_vec = getFullVector(builder, native_vreg_ty, neutral);
|
||||
return builder
|
||||
.create<arith::SelectOp>(value.getLoc(), mask, value, neutral_vec)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
// Returns empty vector on null attribute
|
||||
FailureOr<SmallVector<Layout>> getLayoutArrayFromAttr(const Attribute attr) {
|
||||
if (const auto array_attr = dyn_cast_if_present<ArrayAttr>(attr)) {
|
||||
SmallVector<Layout> out_layouts;
|
||||
out_layouts.reserve(array_attr.size());
|
||||
for (const Attribute a : array_attr) {
|
||||
if (auto layout_attr = dyn_cast_if_present<VectorLayoutAttr>(a)) {
|
||||
out_layouts.push_back(layout_attr.getLayout());
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
return out_layouts;
|
||||
}
|
||||
return SmallVector<Layout>{};
|
||||
}
|
||||
|
||||
bool layoutIsValidForValue(const Layout &l, const Value v,
|
||||
const std::array<int64_t, 2> target_shape) {
|
||||
// l must be non-null iff v is of vector type
|
||||
if (const auto vty = dyn_cast<VectorType>(v.getType())) {
|
||||
if (!l.has_value()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Vector type should have the same bitwidth as the layout, except for the
|
||||
// i1 special case, used for vmasks (see comment for VectorLayout class).
|
||||
if (!vty.getElementType().isIntOrFloat()) {
|
||||
return false;
|
||||
}
|
||||
const int8_t bitwidth = vty.getElementTypeBitWidth();
|
||||
if (bitwidth != l->bitwidth() && bitwidth != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return l->isValid(target_shape) && l->layout_rank() <= vty.getRank();
|
||||
}
|
||||
return !l.has_value();
|
||||
}
|
||||
|
||||
// TODO(tlongeri): Unify with infer_vector_layout.cc's getOutLayout.
|
||||
FailureOr<SmallVector<Layout>> getOutLayouts(
|
||||
Operation &op, const std::array<int64_t, 2> target_shape) {
|
||||
FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> out_layouts,
|
||||
getLayoutArrayFromAttr(op.getAttr("out_layout")));
|
||||
if (out_layouts.size() != op.getNumResults()) {
|
||||
return op.emitOpError("out_layout size does not match number of results");
|
||||
}
|
||||
for (const auto [l, res] : llvm::zip_equal(out_layouts, op.getResults())) {
|
||||
if (!layoutIsValidForValue(l, res, target_shape)) {
|
||||
return op.emitOpError("Invalid output layout");
|
||||
}
|
||||
}
|
||||
return out_layouts;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<Layout>> getInLayouts(
|
||||
Operation &op, const std::array<int64_t, 2> target_shape) {
|
||||
FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> in_layouts,
|
||||
getLayoutArrayFromAttr(op.getAttr("in_layout")));
|
||||
if (in_layouts.size() != op.getNumOperands()) {
|
||||
return op.emitOpError("in_layout size does not match number of operands");
|
||||
}
|
||||
for (const auto [l, operand] :
|
||||
llvm::zip_equal(in_layouts, op.getOperands())) {
|
||||
if (!layoutIsValidForValue(l, operand, target_shape)) {
|
||||
return op.emitOpError("Invalid input layout");
|
||||
}
|
||||
}
|
||||
return in_layouts;
|
||||
}
|
||||
|
||||
// Insert a minor dimension to the implicit shape. The original minor dimension
|
||||
// becomes the new second minor dimension, laid out across sublanes.
|
||||
//
|
||||
@ -1970,130 +1825,30 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
|
||||
TPU_ASSERT_EQ_OP(padded_lhs_rows, lhs_vregs.dim(0) * layout_lhs.tiling()[0]);
|
||||
TPU_ASSERT_EQ_OP(padded_rhs_rows, rhs_vregs.dim(0) * layout_rhs.tiling()[0]);
|
||||
|
||||
const VectorType i32_vreg_ty =
|
||||
getNativeVregType(builder.getI32Type(), ctx.target_shape);
|
||||
auto getX32VmaskByPaddingEnd = [&](int64_t dim, int64_t padding) {
|
||||
CHECK(dim == 0 || dim == 1);
|
||||
CHECK(padding >= 0 && padding <= ctx.target_shape[dim]);
|
||||
return cast<TypedValue<VectorType>>(
|
||||
builder
|
||||
.create<arith::CmpIOp>(
|
||||
arith::CmpIPredicate::slt,
|
||||
builder.create<tpu::IotaOp>(i32_vreg_ty,
|
||||
builder.getI32IntegerAttr(dim)),
|
||||
builder.create<arith::ConstantOp>(DenseElementsAttr::get(
|
||||
i32_vreg_ty, builder.getI32IntegerAttr(
|
||||
ctx.target_shape[dim] - padding))))
|
||||
.getResult());
|
||||
};
|
||||
|
||||
// We can also extend this helper function with padding_top and padding_left
|
||||
// based on the offsets in vregs.
|
||||
const Value i32_zeros_vreg = builder.create<arith::ConstantOp>(
|
||||
op.getLoc(),
|
||||
DenseElementsAttr::get(i32_vreg_ty, builder.getI32IntegerAttr(0)));
|
||||
const Value i32_max_vreg = builder.create<arith::ConstantOp>(
|
||||
op.getLoc(), DenseElementsAttr::get(
|
||||
i32_vreg_ty, builder.getI32IntegerAttr(0xffffffff)));
|
||||
auto maskVregs = [&](xla::Array<Value> &vregs, int64_t padding_bottom,
|
||||
int64_t padding_right) {
|
||||
auto vreg_ty = cast<VectorType>(vregs.begin()->getType());
|
||||
int packing = vreg_ty.getRank() > 2 ? vreg_ty.getShape()[2] : 1;
|
||||
// Mask out the bottom.
|
||||
if (padding_bottom > 0) {
|
||||
// We have limited the row sizes of LHS and RHS to be a multiple of the
|
||||
// native tiling at the beginning of this rule. Therefore, it is safe to
|
||||
// bitcast to x32 vreg for masking.
|
||||
int sub_padding = padding_bottom % packing;
|
||||
int x32_padding_bottom = padding_bottom / packing;
|
||||
auto mask_bottom = getX32VmaskByPaddingEnd(0, x32_padding_bottom);
|
||||
// Create an int32 vreg which contains subelement masking and then
|
||||
// logical_and with target vreg to mask out the unaligned paddings.
|
||||
// Eg. if padding_bottom = 5, packing = 2, and assume the vreg shape is
|
||||
// [8, 128], then the mask will be:
|
||||
//
|
||||
// sublane 0: [0xffffffff, 0xffffffff, ..., 0xffffffff]
|
||||
// sublane 1: [0xffffffff, 0xffffffff, ..., 0xffffffff]
|
||||
// sublane 2: [0xffffffff, 0xffffffff, ..., 0xffffffff]
|
||||
// sublane 3: [0xffffffff, 0xffffffff, ..., 0xffffffff]
|
||||
// sublane 4: [0xffffffff, 0xffffffff, ..., 0xffffffff]
|
||||
// sublane 5: [0x0000ffff, 0x0000ffff, ..., 0x0000ffff]
|
||||
// sublane 6: [0 , 0 , ..., 0 ]
|
||||
// sublane 7: [0 , 0 , ..., 0 ]
|
||||
//
|
||||
// Through this way, in order to mask sub-elements, each target vreg only
|
||||
// needs to apply 1 op (logical_and) instead of 3 ops (unpacking + select
|
||||
// + packing).
|
||||
Value partial_sublane_mask = builder.create<arith::ConstantOp>(
|
||||
op.getLoc(),
|
||||
DenseElementsAttr::get(
|
||||
i32_vreg_ty,
|
||||
builder.getI32IntegerAttr(
|
||||
0xffffffff >>
|
||||
(sub_padding * vreg_ty.getElementTypeBitWidth()))));
|
||||
// Insert 0xffffffff above the blended sublane.
|
||||
Value sublane_mask = builder.create<arith::SelectOp>(
|
||||
getX32VmaskByPaddingEnd(0, x32_padding_bottom + 1), i32_max_vreg,
|
||||
partial_sublane_mask);
|
||||
// Insert 0 below the blended sublane.
|
||||
sublane_mask = builder.create<arith::SelectOp>(mask_bottom, sublane_mask,
|
||||
i32_zeros_vreg);
|
||||
for (int64_t i = 0; i < vregs.dim(1); ++i) {
|
||||
Value &vreg = vregs({vregs.dim(0) - 1, i});
|
||||
Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
|
||||
if (sub_padding > 0) {
|
||||
i32_vreg = builder.create<arith::AndIOp>(i32_vreg, sublane_mask);
|
||||
} else {
|
||||
i32_vreg = builder.create<arith::SelectOp>(mask_bottom, i32_vreg,
|
||||
i32_zeros_vreg);
|
||||
}
|
||||
vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
|
||||
}
|
||||
}
|
||||
// Mask out the right.
|
||||
if (padding_right > 0) {
|
||||
auto mask_right = getX32VmaskByPaddingEnd(1, padding_right);
|
||||
for (int64_t i = 0; i < vregs.dim(0); ++i) {
|
||||
Value &vreg = vregs({i, vregs.dim(1) - 1});
|
||||
Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
|
||||
i32_vreg = builder.create<arith::SelectOp>(mask_right, i32_vreg,
|
||||
i32_zeros_vreg);
|
||||
vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
|
||||
}
|
||||
}
|
||||
};
|
||||
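The sub-element masking trick described in the removed comment above (the diff replaces these local helpers with maskNativeTilingVregs) can be checked with a small NumPy model. The example reproduces the documented padding_bottom=5, packing=2 case on an (8, 128) x32 vreg; it is an illustration only, not the TPU lowering code.

import numpy as np

sublanes, lanes, x32_bits = 8, 128, 32
packing, padding_bottom = 2, 5
elem_bits = x32_bits // packing                 # 16-bit elements packed in pairs
sub_padding = padding_bottom % packing          # 1 partially padded element row
x32_padding_bottom = padding_bottom // packing  # 2 fully padded sublanes

mask = np.full((sublanes, lanes), 0xFFFFFFFF, dtype=np.uint32)
mask[sublanes - x32_padding_bottom:] = 0                          # sublanes 6-7 fully masked
mask[sublanes - x32_padding_bottom - 1] = 0xFFFFFFFF >> (sub_padding * elem_bits)
# (When sub_padding == 0 the blended row above is skipped in the real code.)

print(hex(mask[5, 0]))  # 0xffff -> matches the 0x0000ffff row in the comment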
|
||||
// Create a vreg filled with zeros.
|
||||
auto getZerosVergLike =
|
||||
[&](const Value &vreg) -> FailureOr<TypedValue<VectorType>> {
|
||||
const VectorType vreg_type = cast<VectorType>(vreg.getType());
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
const Attribute zero_attr,
|
||||
getZeroIntOrFloatAttr(vreg_type.getElementType()));
|
||||
return cast<TypedValue<VectorType>>(
|
||||
builder
|
||||
.create<arith::ConstantOp>(
|
||||
op.getLoc(), DenseElementsAttr::get(vreg_type, zero_attr))
|
||||
.getResult());
|
||||
};
|
||||
|
||||
FAILUREOR_ASSIGN_OR_RETURN(auto lhs_zeros_vreg,
|
||||
getZerosVergLike(*lhs_vregs.begin()));
|
||||
FAILUREOR_ASSIGN_OR_RETURN(auto rhs_zeros_vreg,
|
||||
getZerosVergLike(*rhs_vregs.begin()));
|
||||
FAILUREOR_ASSIGN_OR_RETURN(auto acc_zeros_vreg,
|
||||
getZerosVergLike(*acc_vregs.begin()));
|
||||
auto lhs_zeros_vreg =
|
||||
getZerosVector(builder, cast<VectorType>(lhs_vregs.begin()->getType()));
|
||||
auto rhs_zeros_vreg =
|
||||
getZerosVector(builder, cast<VectorType>(rhs_vregs.begin()->getType()));
|
||||
auto acc_zeros_vreg =
|
||||
getZerosVector(builder, cast<VectorType>(acc_vregs.begin()->getType()));
|
||||
|
||||
// Only mask out the paddings on contracting dim of LHS and RHS.
|
||||
maskVregs(lhs_vregs, 0, padded_lhs_cols - lhs_shape[1]);
|
||||
RETURN_IF_FAILED(
|
||||
maskNativeTilingVregs(builder, lhs_vregs, ctx.target_shape,
|
||||
/*padding_bottom=*/0,
|
||||
/*padding_right=*/padded_lhs_cols - lhs_shape[1]));
|
||||
if (transpose_rhs) {
|
||||
maskVregs(rhs_vregs, 0, padded_rhs_cols - rhs_shape[1]);
|
||||
RETURN_IF_FAILED(maskNativeTilingVregs(
|
||||
builder, rhs_vregs, ctx.target_shape,
|
||||
/*padding_bottom=*/0,
|
||||
/*padding_right=*/padded_rhs_cols - rhs_shape[1]));
|
||||
} else {
|
||||
maskVregs(rhs_vregs, padded_rhs_rows - rhs_shape[0], 0);
|
||||
RETURN_IF_FAILED(
|
||||
maskNativeTilingVregs(builder, rhs_vregs, ctx.target_shape,
|
||||
/*padding_bottom=*/padded_rhs_rows - rhs_shape[0],
|
||||
/*padding_right=*/0));
|
||||
}
|
||||
|
||||
// TODO(b/328094640): use latch 3 for short dimensions.
|
||||
// TODO(b/328093587): Skip zeros vreg matmul
|
||||
// At this point, all paddings on vregs are masked out. For now, we append
// zero vregs so that LHS's second dim, both of RHS's dims, and ACC's second
// dim are each a multiple of mxu_size.
@ -2984,12 +2739,10 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
|
||||
native_vreg_ty,
|
||||
/*dimension =*/builder.getI32IntegerAttr(1));
|
||||
for (int64_t i = 0; i < num_tiles; ++i) {
|
||||
auto offset = builder.create<arith::ConstantOp>(
|
||||
native_vreg_ty,
|
||||
DenseElementsAttr::get(
|
||||
native_vreg_ty,
|
||||
IntegerAttr::get(vty.getElementType(),
|
||||
i * *(native_vreg_ty.getShape().end() - 1))));
|
||||
Value offset = getFullVector(
|
||||
builder, native_vreg_ty,
|
||||
IntegerAttr::get(vty.getElementType(),
|
||||
i * *(native_vreg_ty.getShape().end() - 1)));
|
||||
tiles[i] = builder.create<arith::AddIOp>(vreg_iota, offset);
|
||||
}
|
||||
xla::Array<Value> broadcasted_tiles(tile_array_shape);
|
||||
@ -3011,12 +2764,10 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
|
||||
native_vreg_ty,
|
||||
/*dimension =*/builder.getI32IntegerAttr(0));
|
||||
for (int64_t i = 0; i < num_tiles; ++i) {
|
||||
auto offset = builder.create<arith::ConstantOp>(
|
||||
native_vreg_ty,
|
||||
DenseElementsAttr::get(
|
||||
native_vreg_ty,
|
||||
IntegerAttr::get(vty.getElementType(),
|
||||
i * *(native_vreg_ty.getShape().end() - 2))));
|
||||
Value offset = getFullVector(
|
||||
builder, native_vreg_ty,
|
||||
IntegerAttr::get(vty.getElementType(),
|
||||
i * *(native_vreg_ty.getShape().end() - 2)));
|
||||
tiles[i] = builder.create<arith::AddIOp>(vreg_iota, offset);
|
||||
}
|
||||
xla::Array<Value> broadcasted_tiles(tile_array_shape);
|
||||
@ -3033,10 +2784,8 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
|
||||
SmallVector<Value> tiles;
|
||||
tiles.reserve(vty.getDimSize(*dimension));
|
||||
for (int64_t i = 0; i < vty.getDimSize(*dimension); ++i) {
|
||||
tiles.push_back(builder.create<arith::ConstantOp>(
|
||||
native_vreg_ty,
|
||||
DenseElementsAttr::get(native_vreg_ty,
|
||||
IntegerAttr::get(vty.getElementType(), i))));
|
||||
tiles.push_back(getFullVector(builder, native_vreg_ty,
|
||||
IntegerAttr::get(vty.getElementType(), i)));
|
||||
}
|
||||
xla::Array<Value> out_tiles(tile_array_shape);
|
||||
out_tiles.Each([&](absl::Span<const int64_t> idxs, Value *v) {
|
||||
@ -3625,12 +3374,9 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
|
||||
const int64_t offset = *offsets_in[1];
|
||||
const int64_t lane_offset = offset % ctx.target_shape[1];
|
||||
const int64_t tile_offset = offset / ctx.target_shape[1];
|
||||
const auto idx_ty =
|
||||
VectorType::get(ctx.target_shape, builder.getI32Type());
|
||||
auto lane_offset_cst = builder.create<arith::ConstantOp>(
|
||||
broadcast_op.getLoc(), idx_ty,
|
||||
DenseElementsAttr::get(idx_ty,
|
||||
builder.getI32IntegerAttr(lane_offset)));
|
||||
Value lane_offset_cst = getFullVector(
|
||||
builder, getNativeVregType(builder.getI32Type(), ctx.target_shape),
|
||||
builder.getI32IntegerAttr(lane_offset));
|
||||
DenseI32ArrayAttr sublane_pattern;
|
||||
if (num_tiles != 1) {
|
||||
SmallVector<int32_t> pattern;
|
||||
@ -3690,10 +3436,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
|
||||
getNativeVregType(src_i32.getType(), ctx.target_shape);
|
||||
auto tile_i32 =
|
||||
builder.create<vector::BroadcastOp>(native_vreg_ty, src_i32);
|
||||
auto zeros = builder.create<arith::ConstantOp>(
|
||||
broadcast_op.getLoc(), tile_i32.getType(),
|
||||
DenseElementsAttr::get(tile_i32.getType(),
|
||||
builder.getI32IntegerAttr(0)));
|
||||
Value zeros = getZerosVector(builder, tile_i32.getType());
|
||||
auto tile =
|
||||
builder.create<arith::CmpIOp>(arith::CmpIPredicate::ne, tile_i32, zeros)
|
||||
.getResult();
|
||||
@ -5588,8 +5331,6 @@ FailureOr<xla::Array<Value>> doColumnShiftRelayout(
|
||||
const std::array<int64_t, 2> vreg_slice = src.vregSlice(target_shape);
|
||||
const int bitwidth = src.bitwidth();
|
||||
const int packing = src.packing();
|
||||
const VectorLayout dst(bitwidth, {src.offsets()[0], dst_col_offset}, tiling,
|
||||
src.implicit_dim());
|
||||
const int64_t col_diff = dst_col_offset - *src.offsets()[1];
|
||||
if (tiling[0] % packing != 0 || tiling[1] != target_shape[1]) {
|
||||
return emitError(loc,
|
||||
@ -5932,7 +5673,8 @@ LogicalResult retileToLargeTileWithScratch(
|
||||
RewriteContext &ctx, OpBuilder &builder, const Location loc,
|
||||
xla::Array<Value> &dst_tiles, const std::array<int64_t, 2> &dst_tile,
|
||||
const xla::Array<Value> &src_tiles, const std::array<int64_t, 2> &src_tile,
|
||||
TypedValue<MemRefType> scratch_ref) {
|
||||
TypedValue<MemRefType> scratch_ref, const int64_t store_vreg_delay,
|
||||
const int64_t load_vreg_skips) {
|
||||
if (dst_tile[0] % src_tile[0] != 0) {
|
||||
return failure();
|
||||
}
|
||||
@ -6036,8 +5778,8 @@ LogicalResult retileToLargeTileWithScratch(
|
||||
SmallVector<int64_t, 4> src_idx(rank);
|
||||
dst_tiles.Each([&](absl::Span<const int64_t> dst_idx, Value *dst_vreg) {
|
||||
int64_t dst_row_idx = *(dst_idx.end() - 2);
|
||||
int64_t dst_col_idx = *(dst_idx.end() - 1);
|
||||
int64_t vreg_idx_in_group = dst_col_idx % vregs_per_group;
|
||||
int64_t dst_col_idx_with_skips = *(dst_idx.end() - 1) + load_vreg_skips;
|
||||
int64_t vreg_idx_in_group = dst_col_idx_with_skips % vregs_per_group;
|
||||
int64_t load_offset = sublanes_per_group * stored_group_cnt +
|
||||
vreg_idx_in_group * sl_per_vreg * stride;
|
||||
delayed_loads.push_back(
|
||||
@ -6047,16 +5789,20 @@ LogicalResult retileToLargeTileWithScratch(
|
||||
// the vregs from current group and now we need to store corresponding
|
||||
// group of src vregs before actually emitting the loads.
|
||||
if (vreg_idx_in_group == vregs_per_group - 1 ||
|
||||
dst_col_idx == dst_tiles.dimensions().back() - 1) {
|
||||
auto src_row_idx = dst_row_idx * vregs_per_group;
|
||||
auto src_col_idx = dst_col_idx / vregs_per_group;
|
||||
dst_idx.back() == dst_tiles.dimensions().back() - 1) {
|
||||
auto base_src_row_idx = dst_row_idx * vregs_per_group - store_vreg_delay;
|
||||
auto src_col_idx = dst_col_idx_with_skips / vregs_per_group;
|
||||
std::copy(dst_idx.begin(), dst_idx.end(), src_idx.begin());
|
||||
for (int vi = 0; vi < vregs_per_group; ++vi) {
|
||||
if (src_row_idx + vi >= src_tiles.dim(rank - 2) ||
|
||||
const int64_t src_row_idx = base_src_row_idx + vi;
|
||||
if (src_row_idx < 0) {
|
||||
continue;
|
||||
}
|
||||
if (src_row_idx >= src_tiles.dim(rank - 2) ||
|
||||
src_col_idx >= src_tiles.dim(rank - 1)) {
|
||||
break;
|
||||
}
|
||||
*(src_idx.end() - 2) = src_row_idx + vi;
|
||||
*(src_idx.end() - 2) = src_row_idx;
|
||||
*(src_idx.end() - 1) = src_col_idx;
|
||||
Value src_vreg = src_tiles(src_idx);
|
||||
src_vreg =
|
||||
@ -6085,7 +5831,8 @@ LogicalResult retileToSmallTileWithScratch(
|
||||
RewriteContext &ctx, OpBuilder &builder, const Location loc,
|
||||
xla::Array<Value> &dst_tiles, const std::array<int64_t, 2> &dst_tile,
|
||||
const xla::Array<Value> &src_tiles, const std::array<int64_t, 2> &src_tile,
|
||||
TypedValue<MemRefType> scratch_ref) {
|
||||
TypedValue<MemRefType> scratch_ref, const int64_t store_vreg_delay,
|
||||
const int64_t load_vreg_skips) {
|
||||
if (src_tile[0] % dst_tile[0] != 0) {
|
||||
return failure();
|
||||
}
|
||||
@ -6212,8 +5959,8 @@ LogicalResult retileToSmallTileWithScratch(
|
||||
SmallVector<int64_t, 4> dst_idx(rank);
|
||||
src_tiles.Each([&](absl::Span<const int64_t> src_idx, Value src_vreg) {
|
||||
int64_t src_row_idx = *(src_idx.end() - 2);
|
||||
int64_t src_col_idx = *(src_idx.end() - 1);
|
||||
int64_t vreg_idx_in_group = src_col_idx % vregs_per_group;
|
||||
int64_t src_col_idx_with_delays = *(src_idx.end() - 1) + store_vreg_delay;
|
||||
int64_t vreg_idx_in_group = src_col_idx_with_delays % vregs_per_group;
|
||||
src_vreg = builder.create<tpu::BitcastVregOp>(loc, temp_vreg_ty, src_vreg);
|
||||
if (use_shuffled_load) {
|
||||
Value store_offset = mlirIndexConst(
|
||||
@ -6235,16 +5982,20 @@ LogicalResult retileToSmallTileWithScratch(
|
||||
// vregs' row, this indicates we have stored all the vregs needed to
|
||||
// assemble a new group of dst vreg.
|
||||
if (vreg_idx_in_group == vregs_per_group - 1 ||
|
||||
src_col_idx == src_tiles.dimensions().back() - 1) {
|
||||
auto dst_row_idx = src_row_idx * vregs_per_group;
|
||||
auto dst_col_idx = src_col_idx / vregs_per_group;
|
||||
src_idx.back() == src_tiles.dimensions().back() - 1) {
|
||||
auto base_dst_row_idx = src_row_idx * vregs_per_group - load_vreg_skips;
|
||||
auto dst_col_idx = src_col_idx_with_delays / vregs_per_group;
|
||||
std::copy(src_idx.begin(), src_idx.end(), dst_idx.begin());
|
||||
for (int vi = 0; vi < vregs_per_group; ++vi) {
|
||||
if (dst_row_idx + vi >= dst_tiles.dim(rank - 2) ||
|
||||
const int64_t dst_row_idx = base_dst_row_idx + vi;
|
||||
if (dst_row_idx < 0) {
|
||||
continue;
|
||||
}
|
||||
if (dst_row_idx >= dst_tiles.dim(rank - 2) ||
|
||||
dst_col_idx >= dst_tiles.dim(rank - 1)) {
|
||||
break;
|
||||
}
|
||||
*(dst_idx.end() - 2) = dst_row_idx + vi;
|
||||
*(dst_idx.end() - 2) = dst_row_idx;
|
||||
*(dst_idx.end() - 1) = dst_col_idx;
|
||||
Value *dst_vreg = &dst_tiles(dst_idx);
|
||||
int64_t load_offset =
|
||||
@ -6269,18 +6020,70 @@ LogicalResult retileToSmallTileWithScratch(
|
||||
|
||||
// go/mosaic-retiling-in-scratch is the full internal documentation that
|
||||
// includes more details about the TPU generations.
|
||||
LogicalResult retileWithScratch(RewriteContext &ctx, OpBuilder &builder,
|
||||
const Location loc,
|
||||
xla::Array<Value> &dst_tiles,
|
||||
const std::array<int64_t, 2> &dst_tiling,
|
||||
const xla::Array<Value> &src_tiles,
|
||||
const std::array<int64_t, 2> &src_tiling,
|
||||
int packing) {
|
||||
// Arguments:
// - shape:            The non-implicit shape of the operand
// - dst_tiling:       The desired result tiling
// - dst_offsets_hint: Hints for the result offsets. They may be used or
//                     ignored. See comments in the body of the function for
//                     more details.
// - src_vregs:        The source vregs to retile.
// - src:              The source layout
// Returns a pair holding the result layout (potentially using the hints) and
// the retiled vregs.
// TODO(tlongeri): Clean up the function parameters/signatures. We are passing
// in more information than strictly needed.
FailureOr<std::pair<VectorLayout, xla::Array<Value>>> retileWithScratch(
|
||||
RewriteContext &ctx, OpBuilder &builder, const Location loc,
|
||||
const ArrayRef<int64_t> shape, const std::array<int64_t, 2> dst_tiling,
|
||||
const LayoutOffsets dst_offsets_hint, const xla::Array<Value> &src_vregs,
|
||||
const VectorLayout &src) {
|
||||
const int bitwidth = src.bitwidth();
|
||||
const int packing = src.packing();
|
||||
const std::array<int64_t, 2> src_tiling = src.tiling();
|
||||
if (!(src_tiling[1] == ctx.target_shape[1] &&
|
||||
dst_tiling[1] == ctx.target_shape[1] && src_tiling[0] % packing == 0 &&
|
||||
dst_tiling[0] % packing == 0)) {
|
||||
return failure();
|
||||
}
|
||||
const std::array<int64_t, 2> src_vreg_slice =
|
||||
VectorLayout::vregSlice(ctx.target_shape, bitwidth, src_tiling);
|
||||
const std::array<int64_t, 2> dst_vreg_slice =
|
||||
VectorLayout::vregSlice(ctx.target_shape, bitwidth, dst_tiling);
|
||||
|
||||
// TODO(b/368088671): When sublane tiling changes, we should be able to
|
||||
// preserve some replications from the source layout. But we need to
|
||||
// make sure they are implemented efficiently and well-tested. For now, we
|
||||
// just simply use 0 for the replicated offset after retiling.
|
||||
const LayoutOffsets src_offsets = {src.offsets()[0].value_or(0),
|
||||
src.offsets()[1].value_or(0)};
|
||||
// The provided offset hints are used only if they align with the source
|
||||
// offsets, else we default to the smallest possible aligned offsets.
|
||||
LayoutOffsets dst_offsets = {*src_offsets[0] % dst_vreg_slice[0],
|
||||
*src_offsets[1] % dst_vreg_slice[1]};
|
||||
// On a given dimension, either the source vreg slice size divides the dest
|
||||
// vreg slice size, or vice versa (depending on the dimension and whether it's
|
||||
// small-to-large or large-to-small retiling). Offset changes are supported
|
||||
// as long as they are aligned modulo the smaller of the two sizes.
|
||||
const std::array<int64_t, 2> alignment = {
|
||||
std::min(src_vreg_slice[0], dst_vreg_slice[0]),
|
||||
std::min(src_vreg_slice[1], dst_vreg_slice[1])};
|
||||
if (dst_offsets_hint[0].has_value() &&
|
||||
(*dst_offsets_hint[0] - *src_offsets[0]) % alignment[0] == 0) {
|
||||
CHECK_LT(*dst_offsets_hint[0], dst_vreg_slice[0]);
|
||||
dst_offsets[0] = *dst_offsets_hint[0];
|
||||
}
|
||||
if (dst_offsets_hint[1].has_value() &&
|
||||
(*dst_offsets_hint[1] - *src_offsets[1]) % alignment[1] == 0) {
|
||||
CHECK_LT(*dst_offsets_hint[1], dst_vreg_slice[1]);
|
||||
dst_offsets[1] = *dst_offsets_hint[1];
|
||||
}
|
||||
// The offsets of the source in units of the destination vreg slice:
|
||||
const std::array<int64_t, 2> src_offsets_in_dst_vreg_slices = {
|
||||
*src_offsets[0] / dst_vreg_slice[0], *src_offsets[1] / dst_vreg_slice[1]};
|
||||
// The offsets of the destination in units of the source vreg slice:
|
||||
const std::array<int64_t, 2> dst_offsets_in_src_vreg_slices = {
|
||||
*dst_offsets[0] / src_vreg_slice[0], *dst_offsets[1] / src_vreg_slice[1]};
|
||||
|
||||
// Try to get i32 vector scratch space. Because we will bitcast vregs to
|
||||
// i32 vregs before using scratch for retiling. Through this way we can
|
||||
// handle packed types as well.
|
||||
@ -6295,24 +6098,57 @@ LogicalResult retileWithScratch(RewriteContext &ctx, OpBuilder &builder,
|
||||
dst_tiling[1]};
|
||||
std::array<int64_t, 2> vi32_src_tiling = {src_tiling[0] / packing,
|
||||
src_tiling[1]};
|
||||
|
||||
const VectorLayout dst(bitwidth, dst_offsets, dst_tiling, src.implicit_dim());
|
||||
TPU_ASSERT_LOC(loc, dst.isValid(ctx.target_shape));
|
||||
xla::Array<Value> dst_vregs(
|
||||
dst.tileArrayImplicitShape(shape, ctx.target_shape));
|
||||
// When differences in offsets exist, the source vregs may be stored at an offset
// position in their group. For example, the 1st vreg in a row/column may be
// stored as if it was the 3rd, so that the parts corresponding to the 1st and
// 2nd in the destination are filled with padding. Likewise, loads to
// destination vregs may be skipped, when they would load only padding.
// store_vreg_delay is the position offset for stores, and load_vreg_skips is
// the position offset for loads.
//
// For example, suppose we are going from 32-bit {0, 128}(2, 128) to
// {4, 0}(8, 128). We form groups of 4 vregs that represent an (8, 512) slice
// of the padded implicit shape. For the given offsets, for the first group,
// the data is in (4:8, 128:512). But the first and second sources (stored
// vregs) of the group form the slices of data (0:2, 0:512) and (2:4, 0:512),
// which should be all padding. Likewise, the first dest vreg slice (which we
// load from) holds the data from slice (0:8, 0:128), which is all padding.
// We never load or store to slices that should contain only padding.
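// Editorial note (not part of this change): working the example above through
// the quantities computed earlier in this function:
//   src_vreg_slice = (2, 512), dst_vreg_slice = (8, 128)
//   src_offsets = {0, 128}, dst_offsets = {4, 0}
//   dst_offsets_in_src_vreg_slices = {4 / 2, 0 / 512}   = {2, 0}
//   src_offsets_in_dst_vreg_slices = {0 / 8, 128 / 128} = {0, 1}
// Since src_tiling[0] < dst_tiling[0], the large-tile branch below picks
//   store_vreg_delay = 2 (skip the first two stored source slots per group) and
//   load_vreg_skips  = 1 (skip loading the first destination vreg).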
if (src_tiling[0] > dst_tiling[0]) {
|
||||
return retileToSmallTileWithScratch(ctx, builder, loc, dst_tiles,
|
||||
vi32_dst_tiling, src_tiles,
|
||||
vi32_src_tiling, ref);
|
||||
DCHECK_EQ(src_offsets_in_dst_vreg_slices[1], 0);
|
||||
DCHECK_EQ(dst_offsets_in_src_vreg_slices[0], 0);
|
||||
const int64_t store_vreg_delay = dst_offsets_in_src_vreg_slices[1];
|
||||
const int64_t load_vreg_skips = src_offsets_in_dst_vreg_slices[0];
|
||||
if (failed(retileToSmallTileWithScratch(
|
||||
ctx, builder, loc, dst_vregs, vi32_dst_tiling, src_vregs,
|
||||
vi32_src_tiling, ref, store_vreg_delay, load_vreg_skips))) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
if (src_tiling[0] < dst_tiling[0]) {
|
||||
return retileToLargeTileWithScratch(ctx, builder, loc, dst_tiles,
|
||||
vi32_dst_tiling, src_tiles,
|
||||
vi32_src_tiling, ref);
|
||||
DCHECK_EQ(src_offsets_in_dst_vreg_slices[0], 0);
|
||||
DCHECK_EQ(dst_offsets_in_src_vreg_slices[1], 0);
|
||||
const int64_t store_vreg_delay = dst_offsets_in_src_vreg_slices[0];
|
||||
const int64_t load_vreg_skips = src_offsets_in_dst_vreg_slices[1];
|
||||
if (failed(retileToLargeTileWithScratch(
|
||||
ctx, builder, loc, dst_vregs, vi32_dst_tiling, src_vregs,
|
||||
vi32_src_tiling, ref, store_vreg_delay, load_vreg_skips))) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
dst_tiles = std::move(src_tiles);
|
||||
return success();
|
||||
return std::make_pair(dst, dst_vregs);
|
||||
}
|
||||
|
||||
FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
|
||||
RewriteContext &ctx, OpBuilder &builder, const Location loc, VectorType vty,
|
||||
const VectorLayout src, xla::Array<Value> vregs,
|
||||
const std::array<int64_t, 2> dst_tiling, bool try_replicate_rows) {
|
||||
const std::array<int64_t, 2> dst_tiling,
|
||||
const LayoutOffsets dst_offsets_hint) {
|
||||
bool has_enough_scratch = ctx.max_sublanes_in_scratch >=
|
||||
ctx.target_shape[0] * (ctx.target_shape[0] + 1);
|
||||
const auto &target_shape = ctx.target_shape;
|
||||
@ -6328,6 +6164,12 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
|
||||
const int8_t bitwidth = src.bitwidth();
|
||||
const std::array<int64_t, 2> dst_vreg_slice =
|
||||
VectorLayout::vregSlice(ctx.target_shape, bitwidth, dst_tiling);
|
||||
// TODO(tlongeri): Using canonical vs non-canonical offsets can change the
|
||||
// value of try_replicate rows, and it breaks some tests. It doesn't make
|
||||
// sense that we have different behavior for equivalent layouts, though. We
|
||||
// need better logic for picking the relayout strategy.
|
||||
const bool try_replicate_rows =
|
||||
src.offsets()[0].has_value() && !dst_offsets_hint[0].has_value();
|
||||
|
||||
// Fully replicated offsets are handled efficiently elsewhere (in relayout)
|
||||
CHECK(src.offsets()[0].has_value() || src.offsets()[1].has_value());
|
||||
@ -6399,15 +6241,10 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
|
||||
});
|
||||
return std::pair(dst, std::move(retiled));
|
||||
}
|
||||
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
|
||||
src.implicit_dim());
|
||||
if (!dst.isValid(target_shape)) {
|
||||
return emitError(loc, "Not implemented: invalid offsets in tiling target");
|
||||
}
|
||||
auto dst_tiles_shape =
|
||||
dst.tileArrayImplicitShape(vty.getShape(), target_shape);
|
||||
// (8,128) -> (8 * packing,128) tiling change for packed type.
|
||||
if (bitwidth < 32 && 32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
|
||||
if (src_offsets[0].value_or(0) < dst_vreg_slice[0] &&
|
||||
src_offsets[1].value_or(0) < dst_vreg_slice[1] && bitwidth < 32 &&
|
||||
32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
|
||||
dst_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * packing,
|
||||
ctx.target_shape[1]}) {
|
||||
// TODO(tlongeri): This relayout is just ext + trunc. Refactor.
|
||||
@ -6417,8 +6254,10 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
|
||||
// not, since it relies on the src vreg array shape to know how many tiles
|
||||
// to pack in dst, and vreg array shapes with materialized offsets are
|
||||
// unfortunately not equal to vreg array shapes with replicated offsets.
|
||||
CHECK(dst.offsets() == src_offsets);
|
||||
xla::Array<Value> retiled(dst_tiles_shape);
|
||||
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
|
||||
src.implicit_dim());
|
||||
xla::Array<Value> retiled(
|
||||
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
|
||||
VectorType vreg_x32 =
|
||||
vty.getElementType().isSignlessInteger()
|
||||
? VectorType::get(target_shape, builder.getI32Type())
|
||||
@ -6466,7 +6305,9 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
|
||||
// interesting if the next step is a retile, since we can also
|
||||
// match corresponding elements without shifting. It's just that
|
||||
// the tiles are not adjacent (no contiguous vreg slice).
|
||||
if (bitwidth < 32 && 32 % bitwidth == 0 &&
|
||||
if (src_offsets[0].value_or(0) < dst_vreg_slice[0] &&
|
||||
src_offsets[1].value_or(0) < dst_vreg_slice[1] && bitwidth < 32 &&
|
||||
32 % bitwidth == 0 &&
|
||||
src_tiling == std::array<int64_t, 2>{1, ctx.target_shape[1] * packing} &&
|
||||
dst_tiling == std::array<int64_t, 2>{packing, ctx.target_shape[1]}) {
|
||||
// TODO(tlongeri): This relayout is just ext + trunc. Refactor.
|
||||
@ -6515,8 +6356,10 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
|
||||
// not, since it relies on the src vreg array shape to know how many tiles
|
||||
// to pack in dst, and vreg array shapes with materialized offsets are
|
||||
// unfortunately not equal to vreg array shapes with replicated offsets.
|
||||
CHECK(dst.offsets() == src.offsets());
|
||||
xla::Array<Value> retiled(dst_tiles_shape);
|
||||
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
|
||||
src.implicit_dim());
|
||||
xla::Array<Value> retiled(
|
||||
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
|
||||
const VectorType vreg_x32 =
|
||||
vty.getElementType().isSignlessInteger()
|
||||
? VectorType::get(target_shape, builder.getI32Type())
|
||||
@ -6553,24 +6396,25 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
|
||||
return std::pair(dst, std::move(retiled));
|
||||
}
|
||||
if (src_tiling[1] == target_shape[1] && dst_tiling[1] == target_shape[1]) {
|
||||
// TODO(b/368088671): When sublane tiling changes, we should be able to
|
||||
// preserve some replications from the source layout. But we need to
|
||||
// make sure they are implemented efficiently and well-tested. For now, we
|
||||
// just simply use 0 for the replicated offset after retiling.
|
||||
dst = VectorLayout(
|
||||
bitwidth, {src.offsets()[0].value_or(0), src.offsets()[1].value_or(0)},
|
||||
dst_tiling, dst.implicit_dim());
|
||||
|
||||
// All clauses in the and expression are based on performance benchmarking.
|
||||
bool use_alu = !has_enough_scratch ||
|
||||
(ctx.hardware_generation >= 5 && src_tiling[0] != packing &&
|
||||
dst_tiling[0] != packing);
|
||||
|
||||
if (use_alu) {
|
||||
if (src_tiling[0] > dst_tiling[0]) {
|
||||
return std::pair(
|
||||
dst, retileToReducedSublanes(builder, vty.getShape(), src, vregs,
|
||||
dst, target_shape));
|
||||
if (src_tiling[0] > dst_tiling[0] &&
|
||||
// retileToReducedSublanes does not support offset changes
|
||||
src.offsets()[0].value_or(0) < dst_vreg_slice[0] &&
|
||||
src.offsets()[1].value_or(0) < dst_vreg_slice[1]) {
|
||||
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
|
||||
src.implicit_dim());
|
||||
return std::pair(dst, retileToReducedSublanes(
|
||||
builder, vty.getShape(), src, vregs,
|
||||
VectorLayout(bitwidth,
|
||||
{src.offsets()[0].value_or(0),
|
||||
src.offsets()[1].value_or(0)},
|
||||
dst_tiling, dst.implicit_dim()),
|
||||
target_shape));
|
||||
} else if (!has_enough_scratch) {
|
||||
// TODO(b/357538782): Implement retileToIncreasedSublanes with ALU ops.
|
||||
return emitError(
|
||||
@ -6578,15 +6422,12 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
|
||||
"Not implemented: retiling to increase sublane tiling with ALU");
|
||||
}
|
||||
}
|
||||
xla::Array<Value> retiled(dst_tiles_shape);
|
||||
if (failed(retileWithScratch(ctx, builder, loc, retiled, dst_tiling, vregs,
|
||||
src_tiling, packing))) {
|
||||
return failure();
|
||||
}
|
||||
return std::pair(dst, std::move(retiled));
|
||||
return retileWithScratch(ctx, builder, loc, vty.getShape(), dst_tiling,
|
||||
dst_offsets_hint, vregs, src);
|
||||
}
|
||||
return emitError(loc, "Not implemented: Unsupported tiling change for ")
|
||||
<< vty << ": from " << src << " to " << dst;
|
||||
<< vty << ": from " << src << " to (" << dst_tiling[0] << ", "
|
||||
<< dst_tiling[1] << ") tiling";
|
||||
}
|
||||
|
||||
FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeImplicitDim(
|
||||
@ -6846,9 +6687,7 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
std::tie(src, src_tiles),
|
||||
changeTiling(ctx, builder, v.getLoc(), vty, src, std::move(src_tiles),
|
||||
dst.tiling(),
|
||||
dst.offsets()[0] == std::nullopt &&
|
||||
src.offsets()[0] != std::nullopt));
|
||||
dst.tiling(), dst.offsets()));
|
||||
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
std::tie(src, src_tiles),
|
||||
@ -6894,10 +6733,11 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
|
||||
// arguments.
|
||||
auto op_result = dyn_cast<OpResult>(vector_operand);
|
||||
if (op_result == nullptr) {
|
||||
return op.emitError("Expected operand to be an operation result");
|
||||
return op.emitError(
|
||||
"Expected vector operand to be an operation result");
|
||||
}
|
||||
Operation *const def_op = op_result.getOwner();
|
||||
TPU_ASSERT_OP(def_op);
|
||||
DCHECK(def_op);
|
||||
const unsigned res_idx = op_result.getResultNumber();
|
||||
FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> def_layouts,
|
||||
getOutLayouts(*def_op, ctx.target_shape));
|
||||
|
@ -66,22 +66,6 @@ using ImplicitDim = VectorLayout::ImplicitDim;
|
||||
|
||||
static constexpr int kLayoutLog = 10;
|
||||
|
||||
class Print {
|
||||
public:
|
||||
explicit Print(Operation *t) : payload_(t) {}
|
||||
Operation *payload_;
|
||||
|
||||
private:
|
||||
friend std::ostream &operator<<(std::ostream &, Print);
|
||||
};
|
||||
|
||||
std::ostream &operator<<(std::ostream &os, Print p) {
|
||||
std::string s;
|
||||
llvm::raw_string_ostream tmp_os(s);
|
||||
p.payload_->print(tmp_os);
|
||||
os << tmp_os.str();
|
||||
return os;
|
||||
}
|
||||
|
||||
bool is_fully_replicated(const Layout &layout) {
|
||||
static LayoutOffsets replicated_offsets = {std::nullopt, std::nullopt};
|
||||
@ -142,8 +126,7 @@ class VectorLayoutInferer {
|
||||
// TODO: b/342235360 - This check is temporary while we increase and test
|
||||
// support for offsets outside of the first tile. When support is more
|
||||
// broad, any op without support should check it within their own rule.
|
||||
if (!isa<arith::TruncIOp, arith::TruncFOp, vector::BroadcastOp,
|
||||
vector::ExtractStridedSliceOp>(any_op)) {
|
||||
if (!isa<vector::BroadcastOp, vector::ExtractStridedSliceOp>(any_op)) {
|
||||
const SmallVector<Layout> layouts_in = getLayoutFromOperands(&any_op);
|
||||
for (const Layout &layout : layouts_in) {
|
||||
if (layout &&
|
||||
@ -1664,10 +1647,17 @@ class VectorLayoutInferer {
|
||||
Layout dst_layout;
|
||||
if (layout.tiling() == nativeTiling(src_bitwidth)) {
|
||||
// If the source is already in native tiling, we can unpack it directly.
|
||||
src_layout = layout;
|
||||
std::array<int64_t, 2> dst_native_tiling = nativeTiling(dst_bitwidth);
|
||||
LayoutOffsets offsets = {layout.offsets()[0]
|
||||
? *layout.offsets()[0] % dst_native_tiling[0]
|
||||
: LayoutOffset(),
|
||||
layout.offsets()[1]};
|
||||
DCHECK_LT(offsets[1].value_or(0), dst_native_tiling[1]);
|
||||
src_layout = VectorLayout(src_bitwidth, offsets, layout.tiling(),
|
||||
layout.implicit_dim());
|
||||
dst_layout =
|
||||
VectorLayout(dst_bitwidth, layout.offsets(),
|
||||
nativeTiling(dst_bitwidth), layout.implicit_dim());
|
||||
VectorLayout(dst_bitwidth, offsets, dst_native_tiling,
|
||||
layout.implicit_dim());
|
||||
} else if (dst_bitwidth == 32 &&
|
||||
default_tiling_[0] % layout.tiling()[0] == 0 &&
|
||||
default_tiling_[1] == layout.tiling()[1]) {
|
||||
@ -1676,13 +1666,17 @@ class VectorLayoutInferer {
|
||||
// tiling through the op.
|
||||
// TODO(jevinjiang): we can relax this for non-32bit as well.
|
||||
src_layout = layout;
|
||||
dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(),
|
||||
layout.implicit_dim());
|
||||
dst_layout = VectorLayout(dst_bitwidth, layout.offsets(),
|
||||
src_layout->tiling(), layout.implicit_dim());
|
||||
} else {
|
||||
// TODO(b/335863273): we should also reduce offsets.
|
||||
src_layout = VectorLayout(src_bitwidth, layout.offsets(), default_tiling_,
|
||||
LayoutOffsets offsets = {
|
||||
layout.offsets()[0] ? *layout.offsets()[0] % default_tiling_[0]
|
||||
: LayoutOffset(),
|
||||
layout.offsets()[1] ? *layout.offsets()[1] % default_tiling_[1]
|
||||
: LayoutOffset()};
|
||||
src_layout = VectorLayout(src_bitwidth, offsets, default_tiling_,
|
||||
layout.implicit_dim());
|
||||
dst_layout = VectorLayout(dst_bitwidth, layout.offsets(), default_tiling_,
|
||||
dst_layout = VectorLayout(dst_bitwidth, offsets, default_tiling_,
|
||||
layout.implicit_dim());
|
||||
}
|
||||
setLayout(op, src_layout, dst_layout);
|
||||
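// Editorial sketch (not part of this change), assuming an 8x128 target so that
// nativeTiling(8) == (32, 128) and nativeTiling(32) == (8, 128): for an
// 8-bit -> 32-bit extension whose source layout has offsets {9, 3} in native
// (32, 128) tiling, the native-tiling branch above keeps the source tiling but
// reduces the sublane offset modulo the destination's native tiling:
//   offsets = {9 % 8, 3} = {1, 3}
// so src_layout becomes (8-bit, {1, 3}, (32, 128)) and dst_layout becomes
// (32-bit, {1, 3}, (8, 128)).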
@ -1700,22 +1694,23 @@ class VectorLayoutInferer {
|
||||
auto dst_ty = cast<VectorType>(op->getResult(0).getType());
|
||||
auto some_layout = getLayout(op->getOperand(0));
|
||||
TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
|
||||
const unsigned src_bitwidth = src_ty.getElementTypeBitWidth();
|
||||
const unsigned dst_bitwidth = dst_ty.getElementTypeBitWidth();
|
||||
if (isa<arith::TruncFOp>(op)) {
|
||||
TPU_CHECK_OP(
|
||||
src_bitwidth == 32 && (dst_bitwidth == 16 || dst_bitwidth == 8),
|
||||
"Only 32-bit to 16-bit or 8-bit float truncation supported");
|
||||
if (dyn_cast<arith::TruncFOp>(op)) {
|
||||
TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32 &&
|
||||
(dst_ty.getElementTypeBitWidth() == 16 ||
|
||||
dst_ty.getElementTypeBitWidth() == 8),
|
||||
"Only 32-bit to 8-bit or 16-bit truncation supported");
|
||||
} else {
|
||||
TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32,
|
||||
"Only 32-bit truncation supported");
|
||||
}
|
||||
auto &layout = *some_layout;
|
||||
bool select_native = allUsersRequireNativeTiling(op->getResult(0));
|
||||
auto src_layout = VectorLayout(
|
||||
src_bitwidth, layout.offsets(),
|
||||
select_native ? nativeTiling(src_bitwidth) : layout.tiling(),
|
||||
layout.implicit_dim());
|
||||
auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_,
|
||||
layout.implicit_dim());
|
||||
auto dst_layout = VectorLayout(
|
||||
dst_bitwidth, layout.offsets(),
|
||||
select_native ? nativeTiling(dst_bitwidth) : layout.tiling(),
|
||||
dst_ty.getElementTypeBitWidth(), layout.offsets(),
|
||||
select_native ? nativeTiling(dst_ty.getElementTypeBitWidth())
|
||||
: default_tiling_,
|
||||
layout.implicit_dim());
|
||||
setLayout(op, src_layout, dst_layout);
|
||||
return success();
|
||||
@ -1928,46 +1923,6 @@ class VectorLayoutInferer {
|
||||
});
|
||||
}
|
||||
|
||||
void setInLayout(Operation *op, ArrayRef<Layout> in) {
|
||||
CHECK_EQ(in.size(), op->getNumOperands()) << Print(op);
|
||||
SmallVector<Attribute, 4> in_attrs;
|
||||
in_attrs.reserve(in.size());
|
||||
for (const Layout &p : in) {
|
||||
in_attrs.push_back(VectorLayoutAttr::get(op->getContext(), p));
|
||||
}
|
||||
op->setAttr("in_layout", ArrayAttr::get(op->getContext(), in_attrs));
|
||||
}
|
||||
|
||||
void setOutLayout(Operation *op, Layout out) {
|
||||
setOutLayout(op, ArrayRef<Layout>(out));
|
||||
}
|
||||
|
||||
void setOutLayout(Operation *op, ArrayRef<Layout> out) {
|
||||
SmallVector<Attribute, 4> out_attrs;
|
||||
out_attrs.reserve(out.size());
|
||||
for (const Layout &p : out) {
|
||||
out_attrs.push_back(VectorLayoutAttr::get(op->getContext(), p));
|
||||
}
|
||||
op->setAttr("out_layout", ArrayAttr::get(op->getContext(), out_attrs));
|
||||
}
|
||||
|
||||
void setLayout(Operation *op, Layout in, Layout out) {
|
||||
setLayout(op, ArrayRef<Layout>(in), ArrayRef<Layout>(out));
|
||||
}
|
||||
|
||||
void setLayout(Operation *op, ArrayRef<Layout> in, Layout out) {
|
||||
setLayout(op, in, ArrayRef<Layout>(out));
|
||||
}
|
||||
|
||||
void setLayout(Operation *op, Layout in, ArrayRef<Layout> out) {
|
||||
setLayout(op, ArrayRef<Layout>(in), out);
|
||||
}
|
||||
|
||||
void setLayout(Operation *op, ArrayRef<Layout> in, ArrayRef<Layout> out) {
|
||||
setInLayout(op, in);
|
||||
setOutLayout(op, out);
|
||||
}
|
||||
|
||||
Layout getLayout(Value v) {
|
||||
auto op = v.getDefiningOp();
|
||||
CHECK(op);
|
||||
166
jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc
Normal file
@ -0,0 +1,166 @@
#include <array>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/Visitors.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "llvm/include/llvm/ADT/STLExtras.h"
|
||||
#include "llvm/include/llvm/Support/MathExtras.h"
|
||||
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/include/mlir/IR/Builders.h"
|
||||
#include "mlir/include/mlir/IR/Diagnostics.h"
|
||||
#include "mlir/include/mlir/Support/LLVM.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/layout.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/util.h"
|
||||
|
||||
namespace mlir::tpu {
|
||||
|
||||
#define GEN_PASS_DECL_RELAYOUTINSERTIONPASS
|
||||
#define GEN_PASS_DEF_RELAYOUTINSERTIONPASS
|
||||
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
|
||||
|
||||
namespace {
|
||||
|
||||
FailureOr<TypedValue<VectorType>> relayout(
|
||||
OpBuilder &builder, TypedValue<VectorType> v, VectorLayout src,
|
||||
VectorLayout dst, const std::array<int64_t, 2> target_shape) {
|
||||
// change bitwidth
|
||||
if (v.getType().getElementType() == builder.getI1Type() &&
|
||||
// TODO(jevinjiang): for other relayout changes (tiling, offsets, implicit
|
||||
// dim), we currently rely on apply-vector-layout pass to do the relayout.
|
||||
src.bitwidth() != dst.bitwidth()) {
|
||||
CHECK(llvm::isPowerOf2_32(src.bitwidth()));
|
||||
CHECK(llvm::isPowerOf2_32(dst.bitwidth()));
|
||||
auto make_vty = [&](int bitwidth) {
|
||||
return VectorType::get(v.getType().getShape(),
|
||||
builder.getIntegerType(bitwidth));
|
||||
};
|
||||
auto make_constant = [&](int val, VectorLayout layout) {
|
||||
auto vty = make_vty(layout.bitwidth());
|
||||
auto constant_op = builder.create<arith::ConstantOp>(
|
||||
v.getLoc(),
|
||||
DenseElementsAttr::get(
|
||||
vty, builder.getIntegerAttr(vty.getElementType(), val)));
|
||||
setOutLayout(constant_op,
|
||||
VectorLayout(layout.bitwidth(), {std::nullopt, std::nullopt},
|
||||
layout.tiling(), layout.implicit_dim()));
|
||||
return constant_op;
|
||||
};
|
||||
auto src_int_vty = make_vty(src.bitwidth());
|
||||
auto dst_int_vty = make_vty(dst.bitwidth());
|
||||
auto vreg_slice = src.vregSlice(target_shape, dst.bitwidth(), src.tiling());
|
||||
auto dst_bitwidth_layout = VectorLayout(
|
||||
dst.bitwidth(),
|
||||
{
|
||||
src.offsets()[0].has_value() ? *src.offsets()[0] % vreg_slice[0]
|
||||
: LayoutOffset(),
|
||||
src.offsets()[1].has_value() ? *src.offsets()[1] % vreg_slice[1]
|
||||
: LayoutOffset(),
|
||||
},
|
||||
src.tiling(), src.implicit_dim());
|
||||
auto ext_op = builder.create<arith::ExtUIOp>(v.getLoc(), src_int_vty, v);
|
||||
setLayout(ext_op, src, src);
|
||||
|
||||
// TODO(jevinjiang): some conversion might not be supported in HW.
|
||||
Operation *cast_op =
|
||||
dst.bitwidth() > src.bitwidth()
|
||||
? builder.create<arith::ExtSIOp>(v.getLoc(), dst_int_vty, ext_op)
|
||||
// TODO(jevinjiang): HW may support pack vmask directly.
|
||||
: builder.create<arith::TruncIOp>(v.getLoc(), dst_int_vty, ext_op);
|
||||
setLayout(cast_op, src, dst_bitwidth_layout);
|
||||
|
||||
auto cmp_op = builder.create<arith::CmpIOp>(
|
||||
v.getLoc(), v.getType(), arith::CmpIPredicate::ne,
|
||||
cast_op->getResult(0), make_constant(0, dst_bitwidth_layout));
|
||||
setLayout(cmp_op, {dst_bitwidth_layout, dst_bitwidth_layout},
|
||||
dst_bitwidth_layout);
|
||||
return cast<TypedValue<VectorType>>(cmp_op.getResult());
|
||||
}
|
||||
return v;
|
||||
}
|
||||
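// Editorial sketch (not part of this change): for a vmask (i1 vector) produced
// under a 32-bit layout but consumed under a 16-bit layout, the function above
// conceptually emits
//   %a = arith.extui  %mask : i1  -> i32   (src layout)
//   %b = arith.trunci %a    : i32 -> i16   (dst-bitwidth layout)
//   %m = arith.cmpi ne, %b, %zeros         (back to an i1 vmask)
// while the opposite direction (16-bit -> 32-bit) uses arith.extsi in place of
// arith.trunci.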
|
||||
// TODO(jevinjiang): make relayout an op so we don't need to decide when to
// relayout in the apply-vector-layout pass.
LogicalResult insertRelayout(Operation &op,
|
||||
const std::array<int64_t, 2> target_shape) {
|
||||
FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> in_layouts,
|
||||
getInLayouts(op, target_shape));
|
||||
if (in_layouts.size() != op.getNumOperands()) {
|
||||
return op.emitError("Expected the same number of operands as in_layouts");
|
||||
}
|
||||
if (isa<tpu::AssumeLayoutOp>(op)) {
|
||||
return success();
|
||||
}
|
||||
// Relayout the operands, if their requested input layouts don't match the
|
||||
// layouts in which they were produced.
|
||||
for (auto [idx, tup] :
|
||||
llvm::enumerate(llvm::zip(op.getOperands(), in_layouts))) {
|
||||
auto [operand, li] = tup;
|
||||
auto vector_operand = dyn_cast<TypedValue<VectorType>>(operand);
|
||||
TPU_ASSERT_EQ_OP(vector_operand != nullptr, li.has_value());
|
||||
if (vector_operand == nullptr) {
|
||||
continue;
|
||||
}
|
||||
// The operand should always be an Operation (and not a BlockArgument)
|
||||
// since we expect the FuncOp to have only memrefs and semaphores as
|
||||
// arguments.
|
||||
auto op_result = dyn_cast<OpResult>(vector_operand);
|
||||
if (op_result == nullptr) {
|
||||
return op.emitError("Expected vector operand to be an operation result");
|
||||
}
|
||||
Operation *const def_op = op_result.getOwner();
|
||||
DCHECK(def_op);
|
||||
const unsigned res_idx = op_result.getResultNumber();
|
||||
FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> def_layouts,
|
||||
getOutLayouts(*def_op, target_shape));
|
||||
const Layout lo = def_layouts[res_idx];
|
||||
TPU_ASSERT_OP(lo.has_value());
|
||||
if (*lo == *li) {
|
||||
continue;
|
||||
}
|
||||
OpBuilder builder(&op);
|
||||
FAILUREOR_ASSIGN_OR_RETURN(Value new_v,
|
||||
relayout(builder, vector_operand, /*src=*/*lo,
|
||||
/*dst=*/*li, target_shape));
|
||||
op.setOperand(idx, new_v);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
struct RelayoutInsertionPass
|
||||
: public impl::RelayoutInsertionPassBase<RelayoutInsertionPass> {
|
||||
RelayoutInsertionPass(std::array<int64_t, 2> target_shape) {
|
||||
this->sublane_count = target_shape[0];
|
||||
this->lane_count = target_shape[1];
|
||||
}
|
||||
void runOnOperation() override {
|
||||
func::FuncOp func = getOperation();
|
||||
auto result = func.walk([&](Operation *op) {
|
||||
if (insertRelayout(*op, {sublane_count, lane_count}).failed()) {
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
if (result.wasInterrupted()) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createRelayoutInsertionPass(
|
||||
std::array<int64_t, 2> target_shape) {
|
||||
return std::make_unique<RelayoutInsertionPass>(target_shape);
|
||||
}
|
||||
|
||||
} // namespace mlir::tpu
|
@ -18,17 +18,33 @@ limitations under the License.
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "llvm/include/llvm/Support/raw_ostream.h"
|
||||
#include "mlir/include/mlir/IR/Attributes.h"
|
||||
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/include/mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/include/mlir/IR/Value.h"
|
||||
#include "mlir/include/mlir/IR/ValueRange.h"
|
||||
#include "mlir/include/mlir/Support/LLVM.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/layout.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
|
||||
|
||||
namespace mlir::tpu {
|
||||
|
||||
std::ostream &operator<<(std::ostream &os, Print p) {
|
||||
std::string s;
|
||||
llvm::raw_string_ostream tmp_os(s);
|
||||
p.payload_->print(tmp_os);
|
||||
os << tmp_os.str();
|
||||
return os;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> ComputeTileStrides(MemRefType memref_ty,
|
||||
absl::Span<const int64_t> tiling) {
|
||||
SmallVector<int64_t> tile_strides(memref_ty.getRank());
|
||||
@ -147,4 +163,115 @@ bool HasMemorySpace(MemRefType ty, tpu::MemorySpace space) {
|
||||
dyn_cast_or_null<tpu::MemorySpaceAttr>(ty.getMemorySpace());
|
||||
return memory_space && memory_space.getValue() == space;
|
||||
}
|
||||
|
||||
bool layoutIsValidForValue(const Layout &l, const Value v,
|
||||
const std::array<int64_t, 2> target_shape) {
|
||||
// l must be non-null iff v is of vector type
|
||||
if (const auto vty = dyn_cast<VectorType>(v.getType())) {
|
||||
if (!l.has_value()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Vector type should have the same bitwidth as the layout, except for the
|
||||
// i1 special case, used for vmasks (see comment for VectorLayout class).
|
||||
if (!vty.getElementType().isIntOrFloat()) {
|
||||
return false;
|
||||
}
|
||||
const int8_t bitwidth = vty.getElementTypeBitWidth();
|
||||
if (bitwidth != l->bitwidth() && bitwidth != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return l->isValid(target_shape) && l->layout_rank() <= vty.getRank();
|
||||
}
|
||||
return !l.has_value();
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<Layout>> getLayoutArrayFromAttr(const Attribute attr) {
|
||||
if (const auto array_attr = dyn_cast_if_present<ArrayAttr>(attr)) {
|
||||
SmallVector<Layout> out_layouts;
|
||||
out_layouts.reserve(array_attr.size());
|
||||
for (const Attribute a : array_attr) {
|
||||
if (auto layout_attr = dyn_cast_if_present<VectorLayoutAttr>(a)) {
|
||||
out_layouts.push_back(layout_attr.getLayout());
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
return out_layouts;
|
||||
}
|
||||
return SmallVector<Layout>{};
|
||||
}
|
||||
|
||||
// TODO(tlongeri, jevinjiang): Unify with infer_vector_layout.cc's getOutLayout.
|
||||
FailureOr<SmallVector<Layout>> getOutLayouts(
|
||||
Operation &op, const std::array<int64_t, 2> target_shape) {
|
||||
FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> out_layouts,
|
||||
getLayoutArrayFromAttr(op.getAttr("out_layout")));
|
||||
if (out_layouts.size() != op.getNumResults()) {
|
||||
return op.emitOpError("out_layout size does not match number of results");
|
||||
}
|
||||
for (const auto [l, res] : llvm::zip_equal(out_layouts, op.getResults())) {
|
||||
if (!layoutIsValidForValue(l, res, target_shape)) {
|
||||
return op.emitOpError("Invalid output layout");
|
||||
}
|
||||
}
|
||||
return out_layouts;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<Layout>> getInLayouts(
|
||||
Operation &op, const std::array<int64_t, 2> target_shape) {
|
||||
FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> in_layouts,
|
||||
getLayoutArrayFromAttr(op.getAttr("in_layout")));
|
||||
if (in_layouts.size() != op.getNumOperands()) {
|
||||
return op.emitOpError("in_layout size does not match number of operands");
|
||||
}
|
||||
for (const auto [l, operand] :
|
||||
llvm::zip_equal(in_layouts, op.getOperands())) {
|
||||
if (!layoutIsValidForValue(l, operand, target_shape)) {
|
||||
return op.emitOpError("Invalid input layout");
|
||||
}
|
||||
}
|
||||
return in_layouts;
|
||||
}
|
||||
|
||||
void setInLayout(Operation *op, ArrayRef<Layout> in) {
|
||||
CHECK_EQ(in.size(), op->getNumOperands()) << Print(op);
|
||||
SmallVector<Attribute, 4> in_attrs;
|
||||
in_attrs.reserve(in.size());
|
||||
for (const Layout &p : in) {
|
||||
in_attrs.push_back(VectorLayoutAttr::get(op->getContext(), p));
|
||||
}
|
||||
op->setAttr("in_layout", ArrayAttr::get(op->getContext(), in_attrs));
|
||||
}
|
||||
|
||||
void setOutLayout(Operation *op, Layout out) {
|
||||
setOutLayout(op, ArrayRef<Layout>(out));
|
||||
}
|
||||
|
||||
void setOutLayout(Operation *op, ArrayRef<Layout> out) {
|
||||
SmallVector<Attribute, 4> out_attrs;
|
||||
out_attrs.reserve(out.size());
|
||||
for (const Layout &p : out) {
|
||||
out_attrs.push_back(VectorLayoutAttr::get(op->getContext(), p));
|
||||
}
|
||||
op->setAttr("out_layout", ArrayAttr::get(op->getContext(), out_attrs));
|
||||
}
|
||||
|
||||
void setLayout(Operation *op, Layout in, Layout out) {
|
||||
setLayout(op, ArrayRef<Layout>(in), ArrayRef<Layout>(out));
|
||||
}
|
||||
|
||||
void setLayout(Operation *op, ArrayRef<Layout> in, Layout out) {
|
||||
setLayout(op, in, ArrayRef<Layout>(out));
|
||||
}
|
||||
|
||||
void setLayout(Operation *op, Layout in, ArrayRef<Layout> out) {
|
||||
setLayout(op, ArrayRef<Layout>(in), out);
|
||||
}
|
||||
|
||||
void setLayout(Operation *op, ArrayRef<Layout> in, ArrayRef<Layout> out) {
|
||||
setInLayout(op, in);
|
||||
setOutLayout(op, out);
|
||||
}
|
||||
} // namespace mlir::tpu
|
||||
|
@ -3,9 +3,12 @@
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
@ -15,8 +18,11 @@
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
|
||||
#include "mlir/include/mlir/IR/Attributes.h"
|
||||
#include "mlir/include/mlir/IR/Value.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/layout.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
|
||||
#include "tsl/platform/statusor.h"
|
||||
|
||||
// TODO: Instead of CHECK_EQs, can we do something like TF_RET_CHECK but with
|
||||
// MLIR diagnostics?
|
||||
@ -31,18 +37,86 @@
|
||||
// } \
|
||||
// } while (false)
|
||||
|
||||
#define FAILUREOR_ASSIGN_OR_RETURN_IMPL(failureor, lhs, rhs) \
|
||||
auto failureor = rhs; \
|
||||
if (failed(failureor)) { \
|
||||
return failure(); \
|
||||
} \
|
||||
lhs = std::move(failureor).value();
|
||||
// All the macros below here are to handle the case in
|
||||
// FAILUREOR_ASSIGN_OR_RETURN where the LHS is wrapped in parentheses. See a
|
||||
// more detailed discussion at https://stackoverflow.com/a/62984543
|
||||
#define FAILUREOR_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(X) \
|
||||
FAILUREOR_ASSIGN_OR_RETURN_ESCAPE(FAILUREOR_ASSIGN_OR_RETURN_EMPTY X)
|
||||
#define FAILUREOR_ASSIGN_OR_RETURN_EMPTY(...) \
|
||||
FAILUREOR_ASSIGN_OR_RETURN_EMPTY __VA_ARGS__
|
||||
#define FAILUREOR_ASSIGN_OR_RETURN_ESCAPE(...) \
|
||||
FAILUREOR_ASSIGN_OR_RETURN_ESCAPE_(__VA_ARGS__)
|
||||
#define FAILUREOR_ASSIGN_OR_RETURN_ESCAPE_(...) \
|
||||
FAILUREOR_ASSIGN_OR_RETURN_##__VA_ARGS__
|
||||
#define FAILUREOR_ASSIGN_OR_RETURN_FAILUREOR_ASSIGN_OR_RETURN_EMPTY
|
||||
|
||||
#define FAILUREOR_ASSIGN_OR_RETURN_IMPL(failureor, lhs, rhs) \
|
||||
auto failureor = rhs; \
|
||||
if (failed(failureor)) { \
|
||||
return failure(); \
|
||||
} \
|
||||
FAILUREOR_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(lhs) = \
|
||||
(std::move(failureor).value());
|
||||
#define FAILUREOR_ASSIGN_OR_RETURN(lhs, rhs) \
|
||||
FAILUREOR_ASSIGN_OR_RETURN_IMPL( \
|
||||
TF_STATUS_MACROS_CONCAT_NAME(failureor, __COUNTER__), lhs, rhs)
|
||||
|
||||
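// Editorial example (not part of this change): the escaping layer above exists
// so the LHS may contain top-level commas when it is wrapped in parentheses.
// The second helper name below is purely illustrative.
//   FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> layouts,
//                              getOutLayouts(op, target_shape));
//   FAILUREOR_ASSIGN_OR_RETURN((const std::pair<int64_t, int64_t> bounds),
//                              computeBounds(op));  // comma in the type, so
//                                                   // the LHS is parenthesized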
#define RETURN_IF_FAILED(...) \
|
||||
do { \
|
||||
if (failed(__VA_ARGS__)) { \
|
||||
return failure(); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
namespace mlir::tpu {
|
||||
|
||||
// TPU_ASSERT_* macros should be understood as an assert, i.e. use it to check
|
||||
// things that should never happen. We prefer returning failure over a CHECK
|
||||
// because it's easier to debug from Python (particularly from OSS where symbols
|
||||
// are removed)
|
||||
#define TPU_ASSERT_IMPL(stream, cond) \
|
||||
if (LLVM_UNLIKELY(!(cond))) { \
|
||||
(stream) << "Internal error: assert failed: " #cond; \
|
||||
}
|
||||
#define TPU_ASSERT_CMP_IMPL(stream, lhs, rhs, cmp) \
|
||||
if (LLVM_UNLIKELY(!((lhs)cmp(rhs)))) { \
|
||||
(stream) << "Internal error: assert failed: " #lhs " " #cmp " " #rhs " (" \
|
||||
<< (lhs) << " vs. " << (rhs) << ")"; \
|
||||
return failure(); \
|
||||
}
|
||||
#define TPU_ASSERT_OP(cond) TPU_ASSERT_IMPL(op.emitOpError(), cond)
|
||||
#define TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, cmp) \
|
||||
TPU_ASSERT_CMP_IMPL(op.emitOpError(), lhs, rhs, cmp)
|
||||
#define TPU_ASSERT_EQ_OP(lhs, rhs) TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, ==)
|
||||
#define TPU_ASSERT_GE_OP(lhs, rhs) TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, >=)
|
||||
#define TPU_ASSERT_GT_OP(lhs, rhs) TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, >)
|
||||
#define TPU_ASSERT_LE_OP(lhs, rhs) TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, <=)
|
||||
#define TPU_ASSERT_LT_OP(lhs, rhs) TPU_ASSERT_CMP_OP_IMPL(lhs, rhs, <)
|
||||
#define TPU_ASSERT_LOC(loc, cond) TPU_ASSERT_IMPL(mlir::emitError(loc), cond)
|
||||
#define TPU_ASSERT_CMP_LOC_IMPL(loc, lhs, rhs, cmp) \
|
||||
TPU_ASSERT_CMP_IMPL(loc, lhs, rhs, cmp)
|
||||
#define TPU_ASSERT_EQ_LOC(loc, lhs, rhs) \
|
||||
TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, ==)
|
||||
#define TPU_ASSERT_GE_LOC(loc, lhs, rhs) \
|
||||
TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, >=)
|
||||
#define TPU_ASSERT_GT_LOC(loc, lhs, rhs) \
|
||||
TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, >)
|
||||
#define TPU_ASSERT_LT_LOC(loc, lhs, rhs) \
|
||||
TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, <)
|
||||
#define TPU_ASSERT_LE_LOC(loc, lhs, rhs) \
|
||||
TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, <=)
|
||||
|
||||
class Print {
|
||||
public:
|
||||
explicit Print(Operation *t) : payload_(t) {}
|
||||
Operation *payload_;
|
||||
|
||||
private:
|
||||
friend std::ostream &operator<<(std::ostream &, Print);
|
||||
};
|
||||
|
||||
std::ostream &operator<<(std::ostream &os, Print p);
|
||||
|
||||
template <bool adjust_bool = false>
|
||||
FailureOr<int8_t> getTypeBitwidth(Type ty) {
|
||||
if (auto integer_ty = dyn_cast<IntegerType>(ty)) {
|
||||
@ -117,6 +191,26 @@ bool canReinterpretToUntiledMemref(TypedValue<MemRefType> tiled_memref,
|
||||
|
||||
// Determines whether the given MemRefType has the given memory space.
|
||||
bool HasMemorySpace(MemRefType ty, tpu::MemorySpace space);
|
||||
|
||||
bool layoutIsValidForValue(const Layout &l, const Value v,
|
||||
const std::array<int64_t, 2> target_shape);
|
||||
|
||||
// Returns empty vector on null attribute
|
||||
FailureOr<SmallVector<Layout>> getLayoutArrayFromAttr(const Attribute attr);
|
||||
|
||||
FailureOr<SmallVector<Layout>> getOutLayouts(
|
||||
Operation &op, const std::array<int64_t, 2> target_shape);
|
||||
|
||||
FailureOr<SmallVector<Layout>> getInLayouts(
|
||||
Operation &op, const std::array<int64_t, 2> target_shape);
|
||||
|
||||
void setInLayout(Operation *op, ArrayRef<Layout> in);
|
||||
void setOutLayout(Operation *op, Layout out);
|
||||
void setOutLayout(Operation *op, ArrayRef<Layout> out);
|
||||
void setLayout(Operation *op, Layout in, Layout out);
|
||||
void setLayout(Operation *op, ArrayRef<Layout> in, Layout out);
|
||||
void setLayout(Operation *op, Layout in, ArrayRef<Layout> out);
|
||||
void setLayout(Operation *op, ArrayRef<Layout> in, ArrayRef<Layout> out);
|
||||
} // namespace mlir::tpu
|
||||
|
||||
#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_
|
||||
206
jaxlib/mosaic/dialect/tpu/vreg_util.cc
Normal file
@ -0,0 +1,206 @@
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
|
||||
#include "absl/log/check.h"
|
||||
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/include/mlir/IR/Attributes.h"
|
||||
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/include/mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/include/mlir/IR/Diagnostics.h"
|
||||
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/include/mlir/IR/Types.h"
|
||||
#include "mlir/include/mlir/IR/Value.h"
|
||||
#include "mlir/include/mlir/Support/LLVM.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/util.h"
|
||||
#include "xla/array.h"
|
||||
|
||||
namespace mlir::tpu {
|
||||
|
||||
namespace {
|
||||
|
||||
VectorType getNativeVregOrVmaskTypeImpl(
|
||||
Type elem_ty, const int8_t bitwidth,
|
||||
const std::array<int64_t, 2> target_shape) {
|
||||
if (bitwidth == 32) {
|
||||
return VectorType::get(target_shape, elem_ty);
|
||||
}
|
||||
return VectorType::get({target_shape[0], target_shape[1], 32 / bitwidth},
|
||||
elem_ty);
|
||||
}
|
||||
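// Editorial note (not part of this change): for an (8, 128) target shape this
// yields, for example,
//   f32  -> vector<8x128xf32>
//   bf16 -> vector<8x128x2xbf16>   (32 / 16 = 2 elements per 32-bit slot)
//   i8   -> vector<8x128x4xi8>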
|
||||
} // namespace
|
||||
|
||||
VectorType getNativeVregOrVmaskType(Type elem_ty, const int8_t layout_bitwidth,
|
||||
const std::array<int64_t, 2> target_shape) {
|
||||
int8_t bitwidth = elem_ty.getIntOrFloatBitWidth();
|
||||
if (bitwidth == 1) {
|
||||
bitwidth = layout_bitwidth;
|
||||
} else {
|
||||
CHECK_EQ(bitwidth, layout_bitwidth);
|
||||
}
|
||||
return getNativeVregOrVmaskTypeImpl(elem_ty, bitwidth, target_shape);
|
||||
}
|
||||
|
||||
VectorType getNativeVregType(Type elem_ty,
|
||||
const std::array<int64_t, 2> target_shape) {
|
||||
return getNativeVregOrVmaskTypeImpl(elem_ty, elem_ty.getIntOrFloatBitWidth(),
|
||||
target_shape);
|
||||
}
|
||||
|
||||
TypedValue<VectorType> getFullVector(ImplicitLocOpBuilder &builder,
|
||||
VectorType vty, Attribute value) {
|
||||
return cast<TypedValue<VectorType>>(
|
||||
builder.create<arith::ConstantOp>(DenseElementsAttr::get(vty, value))
|
||||
.getResult());
|
||||
}
|
||||
|
||||
TypedValue<VectorType> getFullLikeVector(ImplicitLocOpBuilder &builder,
|
||||
TypedValue<VectorType> vec,
|
||||
Attribute value) {
|
||||
return getFullVector(builder, vec.getType(), value);
|
||||
}
|
||||
|
||||
TypedValue<VectorType> getZerosVector(ImplicitLocOpBuilder &builder,
|
||||
VectorType vty) {
|
||||
return getFullVector(builder, vty, builder.getZeroAttr(vty.getElementType()));
|
||||
}
|
||||
|
||||
TypedValue<VectorType> getZerosLikeVector(ImplicitLocOpBuilder &builder,
|
||||
TypedValue<VectorType> vec) {
|
||||
return getZerosVector(builder, vec.getType());
|
||||
}
|
||||
|
||||
FailureOr<TypedValue<VectorType>> getX32VmaskByPaddingEnd(
|
||||
ImplicitLocOpBuilder &builder, int64_t padding,
|
||||
const std::array<int64_t, 2> target_shape, int64_t dim) {
|
||||
VectorType i32_vreg_ty =
|
||||
getNativeVregType(builder.getI32Type(), target_shape);
|
||||
if (dim != 0 && dim != 1) {
|
||||
return builder.emitError()
|
||||
<< "Expected a 2D vector for getX32VmaskByPaddingEnd";
|
||||
}
|
||||
|
||||
if (padding < 0 || padding > target_shape[dim]) {
|
||||
return builder.emitError()
|
||||
<< "Padding must be in [0, target_shape[dim]). Padding: " << padding
|
||||
<< ", target_shape[dim]: " << target_shape[dim];
|
||||
}
|
||||
|
||||
Value padding_vreg =
|
||||
getFullVector(builder, i32_vreg_ty,
|
||||
builder.getI32IntegerAttr(target_shape[dim] - padding));
|
||||
|
||||
return cast<TypedValue<VectorType>>(
|
||||
builder
|
||||
.create<arith::CmpIOp>(
|
||||
arith::CmpIPredicate::slt,
|
||||
builder.create<tpu::IotaOp>(i32_vreg_ty,
|
||||
builder.getI32IntegerAttr(dim)),
|
||||
padding_vreg)
|
||||
.getResult());
|
||||
}
|
||||
|
||||
LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder,
|
||||
xla::Array<Value> &vregs,
|
||||
std::array<int64_t, 2> target_shape,
|
||||
int64_t padding_bottom,
|
||||
int64_t padding_right) {
|
||||
auto vreg_ty = dyn_cast<VectorType>(vregs.begin()->getType());
|
||||
if (!vreg_ty) {
|
||||
return builder.emitError() << "Expected a vector type";
|
||||
}
|
||||
|
||||
VectorType i32_vreg_ty =
|
||||
getNativeVregType(builder.getI32Type(), target_shape);
|
||||
Value i32_zeros_vreg = getZerosVector(builder, i32_vreg_ty);
|
||||
Value i32_max_vreg = getFullVector(builder, i32_vreg_ty,
|
||||
builder.getI32IntegerAttr(0xffffffff));
|
||||
|
||||
int packing = vreg_ty.getRank() > 2 ? vreg_ty.getShape()[2] : 1;
|
||||
// Mask out the bottom.
|
||||
if (padding_bottom > 0) {
|
||||
// The function is only called when the vreg has native tiling. Therefore,
|
||||
// it is safe to bitcast to x32 vreg for masking.
|
||||
int sub_padding = padding_bottom % packing;
|
||||
int x32_padding_bottom = padding_bottom / packing;
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
Value mask_top, getX32VmaskByPaddingEnd(builder, x32_padding_bottom + 1,
|
||||
target_shape, /*dim=*/0));
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
Value mask_bottom,
|
||||
getX32VmaskByPaddingEnd(builder, x32_padding_bottom, target_shape,
|
||||
/*dim=*/0));
|
||||
// Create an int32 vreg which contains subelement masking and then
|
||||
// logical_and with target vreg to mask out the unaligned paddings.
|
||||
// Eg. if padding_bottom = 5, packing = 2, and assume the vreg shape is
|
||||
// [8, 128], then the mask will be:
|
||||
//
|
||||
// sublane 0: [0xffffffff, 0xffffffff, ..., 0xffffffff]
|
||||
// sublane 1: [0xffffffff, 0xffffffff, ..., 0xffffffff]
|
||||
// sublane 2: [0xffffffff, 0xffffffff, ..., 0xffffffff]
|
||||
// sublane 3: [0xffffffff, 0xffffffff, ..., 0xffffffff]
|
||||
// sublane 4: [0xffffffff, 0xffffffff, ..., 0xffffffff]
|
||||
// sublane 5: [0x0000ffff, 0x0000ffff, ..., 0x0000ffff]
|
||||
// sublane 6: [0 , 0 , ..., 0 ]
|
||||
// sublane 7: [0 , 0 , ..., 0 ]
|
||||
//
|
||||
// Through this way, in order to mask sub-elements, each target vreg only
|
||||
// needs to apply 1 op (logical_and) instead of 3 ops (unpacking + select
|
||||
// + packing).
|
||||
Value partial_sublane_mask = getFullVector(
|
||||
builder, i32_vreg_ty,
|
||||
builder.getI32IntegerAttr(
|
||||
0xffffffff >> (sub_padding * vreg_ty.getElementTypeBitWidth())));
|
||||
// Insert 0xffffffff above the blended sublane.
|
||||
Value sublane_mask = builder.create<arith::SelectOp>(mask_top, i32_max_vreg,
|
||||
partial_sublane_mask);
|
||||
// Insert 0 below the blended sublane.
|
||||
sublane_mask = builder.create<arith::SelectOp>(mask_bottom, sublane_mask,
|
||||
i32_zeros_vreg);
|
||||
for (int64_t i = 0; i < vregs.dim(1); ++i) {
|
||||
Value &vreg = vregs({vregs.dim(0) - 1, i});
|
||||
Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
|
||||
if (sub_padding > 0) {
|
||||
i32_vreg = builder.create<arith::AndIOp>(i32_vreg, sublane_mask);
|
||||
} else {
|
||||
i32_vreg = builder.create<arith::SelectOp>(mask_bottom, i32_vreg,
|
||||
i32_zeros_vreg);
|
||||
}
|
||||
vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
|
||||
}
|
||||
}
|
||||
// Mask out the right.
|
||||
if (padding_right > 0) {
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
Value mask_right, getX32VmaskByPaddingEnd(builder, padding_right,
|
||||
target_shape, /*dim=*/1));
|
||||
for (int64_t i = 0; i < vregs.dim(0); ++i) {
|
||||
Value &vreg = vregs({i, vregs.dim(1) - 1});
|
||||
Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
|
||||
i32_vreg =
|
||||
builder.create<arith::SelectOp>(mask_right, i32_vreg, i32_zeros_vreg);
|
||||
vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace mlir::tpu
|
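A note on the sub-element masking above: for packed dtypes, padding_bottom splits into whole x32 sublanes (padding_bottom / packing) and a sub-element remainder (padding_bottom % packing), and the remainder is handled with a single bitwise AND against a partial-sublane constant. The Python sketch below only illustrates how that constant is derived; the helper name is hypothetical and it is not part of jaxlib.

def partial_sublane_mask(padding_bottom: int, packing: int, elem_bits: int) -> int:
    # Mirrors the constant built in maskNativeTilingVregs: ones are shifted out
    # of the top of a 32-bit lane, one element's worth per padded sub-element.
    sub_padding = padding_bottom % packing
    return (0xFFFFFFFF >> (sub_padding * elem_bits)) & 0xFFFFFFFF

# Example from the comment above: padding_bottom=5, packing=2, 16-bit elements.
assert partial_sublane_mask(5, 2, 16) == 0x0000FFFF
# The whole-sublane part (padding_bottom // packing == 2) is handled by the
# selects against mask_top / mask_bottom, not by this constant.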
82
jaxlib/mosaic/dialect/tpu/vreg_util.h
Normal file
@ -0,0 +1,82 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_VREG_UTIL_H_
|
||||
#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_VREG_UTIL_H_
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
|
||||
#include "mlir/include/mlir/IR/Attributes.h"
|
||||
#include "mlir/include/mlir/IR/Builders.h"
|
||||
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/include/mlir/IR/Types.h"
|
||||
#include "mlir/include/mlir/IR/Value.h"
|
||||
#include "mlir/include/mlir/Support/LLVM.h"
|
||||
#include "xla/array.h"
|
||||
|
||||
namespace mlir::tpu {
|
||||
|
||||
// Returns the native vreg or vmask type for the given element type and target
|
||||
// shape. The layout bitwidth is used for i1 (vmask) elements.
|
||||
VectorType getNativeVregOrVmaskType(Type elem_ty, int8_t layout_bitwidth,
|
||||
std::array<int64_t, 2> target_shape);
|
||||
VectorType getNativeVregType(Type elem_ty, std::array<int64_t, 2> target_shape);
|
||||
|
||||
// Returns a zero constant of the same type as `vty`.
|
||||
TypedValue<VectorType> getZerosVector(ImplicitLocOpBuilder &builder,
|
||||
VectorType vty);
|
||||
// Same as above, but takes a `vec` as input.
|
||||
TypedValue<VectorType> getZerosLikeVector(ImplicitLocOpBuilder &builder,
|
||||
TypedValue<VectorType> vec);
|
||||
|
||||
// Returns a constant of the same type as `vty` with the given `value`.
|
||||
TypedValue<VectorType> getFullVector(ImplicitLocOpBuilder &builder,
|
||||
VectorType vty, Attribute value);
|
||||
// Same as above, but takes a `vec` as input.
|
||||
TypedValue<VectorType> getFullLikeVector(ImplicitLocOpBuilder &builder,
|
||||
TypedValue<VectorType> vec,
|
||||
Attribute value);
|
||||
|
||||
// Creates a vmask that is true for the first (dim_size - padding) rows
// (dim = 0) or columns (dim = 1) and false for the trailing `padding` rows or
// columns.
|
||||
//
|
||||
// For example, assume vmask shape is (4, 8)
|
||||
//
|
||||
// getX32VmaskByPaddingEnd(padding=3, dim=1) creates:
|
||||
// [T, T, T, T, T, F, F, F]
|
||||
// [T, T, T, T, T, F, F, F]
|
||||
// [T, T, T, T, T, F, F, F]
|
||||
// [T, T, T, T, T, F, F, F]
|
||||
// TODO(b/385204135): Unify with getVmaskByPaddingEnd in tpu_rotate_rule, and
|
||||
// improve the codegen.
|
||||
FailureOr<TypedValue<VectorType>> getX32VmaskByPaddingEnd(
|
||||
ImplicitLocOpBuilder &builder, int64_t padding,
|
||||
std::array<int64_t, 2> target_shape, int64_t dim);
|
||||
|
||||
// Masks out the padding in the bottom and right of the vregs. vregs are
|
||||
// expected to have native tiling, and the masked vregs are mutated in
|
||||
// `vregs`. `padding_bottom` and `padding_right` give the number of elements to
|
||||
// pad in the bottom and right.
|
||||
LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder,
|
||||
xla::Array<Value> &vregs,
|
||||
std::array<int64_t, 2> target_shape,
|
||||
int64_t padding_bottom,
|
||||
int64_t padding_right);
|
||||
|
||||
} // namespace mlir::tpu
|
||||
|
||||
#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_VREG_UTIL_H_
|
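As a quick reference for the declarations above: 32-bit element types map to a vreg of exactly target_shape, narrower types gain a trailing packing dimension of 32 / bitwidth, and i1 masks use the layout bitwidth in place of their own. A minimal Python sketch of that shape rule (an illustration only, not part of this header):

def native_vreg_shape(bitwidth: int, target_shape: tuple[int, int]) -> tuple[int, ...]:
    # Mirrors getNativeVregOrVmaskTypeImpl: 32-bit elements use the 2D target
    # shape; packed elements add a trailing dimension of 32 // bitwidth.
    if bitwidth == 32:
        return target_shape
    return (*target_shape, 32 // bitwidth)

# Matches the unit tests in vreg_util_test.cc below: f32 -> (2, 4),
# bf16 -> (2, 4, 2), and i1 with layout_bitwidth=8 -> (2, 4, 4).
assert native_vreg_shape(32, (2, 4)) == (2, 4)
assert native_vreg_shape(16, (2, 4)) == (2, 4, 2)
assert native_vreg_shape(8, (2, 4)) == (2, 4, 4)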
228
jaxlib/mosaic/dialect/tpu/vreg_util_test.cc
Normal file
@ -0,0 +1,228 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/include/mlir/IR/Builders.h"
|
||||
#include "mlir/include/mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/include/mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "mlir/include/mlir/IR/OwningOpRef.h"
|
||||
#include "mlir/include/mlir/IR/Value.h"
|
||||
#include "mlir/include/mlir/Support/DebugStringHelper.h"
|
||||
#include "mlir/include/mlir/Support/LLVM.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
|
||||
|
||||
namespace mlir::tpu {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::testing::Eq;
|
||||
using ::testing::Optional;
|
||||
|
||||
MATCHER_P2(IsConstantOpWithSplatValue, type, splat_value, "") {
|
||||
auto constant_op = dyn_cast<arith::ConstantOp>(arg.getDefiningOp());
|
||||
if (constant_op == nullptr) {
|
||||
*result_listener << "Expected a constant op, got " << debugString(arg);
|
||||
return false;
|
||||
}
|
||||
auto dense_attr = dyn_cast<DenseElementsAttr>(constant_op.getValue());
|
||||
if (dense_attr == nullptr) {
|
||||
*result_listener << "Expected a dense elements attr, got "
|
||||
<< debugString(arg);
|
||||
return false;
|
||||
}
|
||||
if (dense_attr.getType() != type) {
|
||||
*result_listener << "Expected a dense elements attr with type "
|
||||
<< debugString(type) << ", got "
|
||||
<< debugString(dense_attr.getType());
|
||||
return false;
|
||||
}
|
||||
if (!dense_attr.isSplat()) {
|
||||
*result_listener << "Expected a splat dense elements attr, got "
|
||||
<< debugString(dense_attr);
|
||||
return false;
|
||||
}
|
||||
if (auto s = dense_attr.template getSplatValue<decltype(splat_value)>();
|
||||
s != splat_value) {
|
||||
*result_listener << "Expected a splat dense elements attr with value "
|
||||
<< splat_value << ", got " << s;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
MATCHER_P2(IsVectorTypeWithShape, shape, elem_ty, "") {
|
||||
auto vty = dyn_cast<VectorType>(arg);
|
||||
if (vty == nullptr) {
|
||||
*result_listener << "Expected a vector type, got " << debugString(arg);
|
||||
return false;
|
||||
}
|
||||
if (vty.getShape() != ArrayRef<int64_t>(shape)) {
|
||||
*result_listener << "Expected a vector type with shape "
|
||||
<< absl::StrJoin(shape, ",") << ", got "
|
||||
<< absl::StrJoin(vty.getShape(), ",");
|
||||
return false;
|
||||
}
|
||||
if (vty.getElementType() != elem_ty) {
|
||||
*result_listener << "Expected a vector type with element type "
|
||||
<< debugString(elem_ty) << ", got "
|
||||
<< debugString(vty.getElementType());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
class VregUtilTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
context_.loadDialect<arith::ArithDialect, vector::VectorDialect,
|
||||
tpu::TPUDialect>();
|
||||
mlir::Location loc = mlir::UnknownLoc::get(&context_);
|
||||
mlir::OpBuilder b(&context_);
|
||||
module_ = b.create<ModuleOp>(loc);
|
||||
builder_ = std::make_unique<mlir::ImplicitLocOpBuilder>(
|
||||
module_->getLoc(), module_->getBodyRegion());
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
builder_.reset();
|
||||
// Reset the module to prevent memory leaks.
|
||||
module_ = nullptr;
|
||||
}
|
||||
|
||||
mlir::ImplicitLocOpBuilder& Builder() { return *builder_; }
|
||||
|
||||
private:
|
||||
MLIRContext context_;
|
||||
std::unique_ptr<mlir::ImplicitLocOpBuilder> builder_;
|
||||
OwningOpRef<ModuleOp> module_;
|
||||
};
|
||||
|
||||
TEST_F(VregUtilTest, GetNativeVregOrVmaskTypeBitwidthMismatch) {
|
||||
EXPECT_DEATH(getNativeVregOrVmaskType(Builder().getI16Type(),
|
||||
/*layout_bitwidth=*/8, {2, 4}),
|
||||
"");
|
||||
}
|
||||
|
||||
TEST_F(VregUtilTest, GetNativeVregOrVmaskTypeI1) {
|
||||
EXPECT_THAT(getNativeVregOrVmaskType(Builder().getI1Type(),
|
||||
/*layout_bitwidth=*/8, {2, 4}),
|
||||
IsVectorTypeWithShape(std::array<int64_t, 3>{2, 4, 4},
|
||||
Builder().getI1Type()));
|
||||
}
|
||||
|
||||
TEST_F(VregUtilTest, GetNativeVregF32) {
|
||||
EXPECT_THAT(getNativeVregType(Builder().getF32Type(), {2, 4}),
|
||||
IsVectorTypeWithShape(std::array<int64_t, 2>{2, 4},
|
||||
Builder().getF32Type()));
|
||||
}
|
||||
|
||||
TEST_F(VregUtilTest, GetNativeVregBf16) {
|
||||
EXPECT_THAT(getNativeVregType(Builder().getBF16Type(), {2, 4}),
|
||||
IsVectorTypeWithShape(std::array<int64_t, 3>{2, 4, 2},
|
||||
Builder().getBF16Type()));
|
||||
}
|
||||
|
||||
TEST_F(VregUtilTest, GetFullVector) {
|
||||
VectorType vty = VectorType::get({2, 4}, Builder().getI32Type());
|
||||
TypedValue<VectorType> vec =
|
||||
getFullVector(Builder(), vty, Builder().getI32IntegerAttr(0x1));
|
||||
|
||||
EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, int32_t{0x1}));
|
||||
}
|
||||
|
||||
TEST_F(VregUtilTest, GetFullLikeVector) {
|
||||
VectorType vty = VectorType::get({2, 4}, Builder().getF32Type());
|
||||
TypedValue<VectorType> in_vec = Builder().create<vector::SplatOp>(
|
||||
vty, Builder().create<arith::ConstantOp>(
|
||||
vty.getElementType(), Builder().getF32FloatAttr(1.0f)));
|
||||
TypedValue<VectorType> vec =
|
||||
getFullLikeVector(Builder(), in_vec, Builder().getF32FloatAttr(2.0f));
|
||||
|
||||
EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, float{2.0f}));
|
||||
}
|
||||
|
||||
TEST_F(VregUtilTest, GetZerosVector) {
|
||||
VectorType vty = VectorType::get({2, 4}, Builder().getI32Type());
|
||||
TypedValue<VectorType> vec = getZerosVector(Builder(), vty);
|
||||
|
||||
EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, int32_t{0}));
|
||||
}
|
||||
|
||||
TEST_F(VregUtilTest, GetZerosLikeVector) {
|
||||
VectorType vty = VectorType::get({2, 4}, Builder().getF32Type());
|
||||
TypedValue<VectorType> in_vec = Builder().create<vector::SplatOp>(
|
||||
vty, Builder().create<arith::ConstantOp>(
|
||||
vty.getElementType(), Builder().getF32FloatAttr(1.0f)));
|
||||
TypedValue<VectorType> vec = getZerosLikeVector(Builder(), in_vec);
|
||||
|
||||
EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, float{0.0f}));
|
||||
}
|
||||
|
||||
TEST_F(VregUtilTest, GetX32VmaskByPaddingEndDim0) {
|
||||
constexpr std::array<int64_t, 2> kTargetShape = {4, 8};
|
||||
FailureOr<TypedValue<VectorType>> vec = getX32VmaskByPaddingEnd(
|
||||
Builder(), /*padding=*/1, /*target_shape=*/kTargetShape,
|
||||
/*dim=*/0);
|
||||
ASSERT_TRUE(succeeded(vec));
|
||||
|
||||
auto cmp_op = dyn_cast<arith::CmpIOp>(vec.value().getDefiningOp());
|
||||
ASSERT_TRUE(cmp_op != nullptr);
|
||||
EXPECT_EQ(cmp_op.getPredicate(), arith::CmpIPredicate::slt);
|
||||
|
||||
auto iota_op = dyn_cast<tpu::IotaOp>(cmp_op.getLhs().getDefiningOp());
|
||||
ASSERT_TRUE(iota_op != nullptr);
|
||||
EXPECT_THAT(iota_op.getDimension(), Optional(Eq(0)));
|
||||
|
||||
EXPECT_THAT(
|
||||
cmp_op.getRhs(),
|
||||
IsConstantOpWithSplatValue(
|
||||
VectorType::get(kTargetShape, Builder().getI32Type()), int32_t{3}));
|
||||
}
|
||||
|
||||
TEST_F(VregUtilTest, GetX32VmaskByPaddingEndDim1) {
|
||||
constexpr std::array<int64_t, 2> kTargetShape = {4, 8};
|
||||
FailureOr<TypedValue<VectorType>> vec = getX32VmaskByPaddingEnd(
|
||||
Builder(), /*padding=*/3, /*target_shape=*/kTargetShape,
|
||||
/*dim=*/1);
|
||||
ASSERT_TRUE(succeeded(vec));
|
||||
|
||||
auto cmp_op = dyn_cast<arith::CmpIOp>(vec.value().getDefiningOp());
|
||||
ASSERT_TRUE(cmp_op != nullptr);
|
||||
EXPECT_EQ(cmp_op.getPredicate(), arith::CmpIPredicate::slt);
|
||||
|
||||
auto iota_op = dyn_cast<tpu::IotaOp>(cmp_op.getLhs().getDefiningOp());
|
||||
ASSERT_TRUE(iota_op != nullptr);
|
||||
EXPECT_THAT(iota_op.getDimension(), Optional(Eq(1)));
|
||||
|
||||
EXPECT_THAT(
|
||||
cmp_op.getRhs(),
|
||||
IsConstantOpWithSplatValue(
|
||||
VectorType::get(kTargetShape, Builder().getI32Type()), int32_t{5}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mlir::tpu
|
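The two GetX32VmaskByPaddingEnd tests above only check the lowering structure (an iota along `dim` compared with a splat of target_shape[dim] - padding). A NumPy sketch of the mask those ops compute, for the same shapes and paddings as the tests (an illustration, not the lowering itself):

import numpy as np

def x32_vmask_by_padding_end(target_shape, padding, dim):
    # True where iota along `dim` < (dim size - padding); the trailing
    # `padding` rows/columns come out False.
    iota = np.indices(target_shape)[dim]
    return iota < (target_shape[dim] - padding)

# dim=0, padding=1 on a (4, 8) vreg: the splat RHS is 3, last sublane masked.
assert x32_vmask_by_padding_end((4, 8), 1, 0)[:, 0].tolist() == [True, True, True, False]
# dim=1, padding=3: the splat RHS is 5, last three lanes masked.
assert x32_vmask_by_padding_end((4, 8), 3, 1)[0].tolist() == [True] * 5 + [False] * 3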
@ -33,4 +33,5 @@ except ImportError:
|
||||
from mlir.dialects._ods_common import _cext # type: ignore[import-not-found]
|
||||
|
||||
|
||||
_cext.globals.append_dialect_search_prefix("jax.jaxlib.mosaic.python")
|
||||
# Add the parent module to the search prefix
|
||||
_cext.globals.append_dialect_search_prefix(__name__[:__name__.rfind(".")])
|
||||
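For context on the hunk above: `__name__[:__name__.rfind(".")]` is simply the parent package of the module doing the registration, so the dialect search prefix follows the package wherever it is installed instead of hard-coding "jax.jaxlib.mosaic.python". A small illustration; the module name here is only an example, not the real value of __name__:

name = "jaxlib.mosaic.python.tpu"  # hypothetical module name
parent = name[:name.rfind(".")]
assert parent == "jaxlib.mosaic.python"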
|
@ -83,6 +83,7 @@ setup(
|
||||
'cuda/*',
|
||||
'cuda/nvvm/libdevice/libdevice*',
|
||||
'mosaic/*.py',
|
||||
'mosaic/dialect/gpu/*.py',
|
||||
'mosaic/gpu/*.so',
|
||||
'mosaic/python/*.py',
|
||||
'mosaic/python/*.so',
|
||||
|
@ -218,6 +218,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
|
||||
dst_dir=mosaic_python_dir,
|
||||
src_files=[
|
||||
"__main__/jaxlib/mosaic/python/layout_defs.py",
|
||||
"__main__/jaxlib/mosaic/python/mosaic_gpu.py",
|
||||
"__main__/jaxlib/mosaic/python/tpu.py",
|
||||
],
|
||||
)
|
||||
@ -225,6 +226,16 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
|
||||
patch_copy_mlir_import(
|
||||
"__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir
|
||||
)
|
||||
mosaic_gpu_dir = jaxlib_dir / "mosaic" / "dialect" / "gpu"
|
||||
os.makedirs(mosaic_gpu_dir)
|
||||
patch_copy_mlir_import(
|
||||
"__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_ops.py",
|
||||
dst_dir=mosaic_gpu_dir,
|
||||
)
|
||||
patch_copy_mlir_import(
|
||||
"__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_enums.py",
|
||||
dst_dir=mosaic_gpu_dir,
|
||||
)
|
||||
|
||||
copy_runfiles(
|
||||
dst_dir=jaxlib_dir / "mlir",
|
||||
@ -316,6 +327,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_mosaic_gpu_ext.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_sdy.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}",
|
||||
|
@ -300,6 +300,48 @@ class ColocatedPythonTest(jtu.JaxTestCase):
|
||||
# around 15 seconds.
|
||||
self.assertLess(elapsed_time, 10)
|
||||
|
||||
def testInputsWithDifferentDeviceOrders(self):
|
||||
cpu_devices = _colocated_cpu_devices(jax.local_devices())[:2]
|
||||
if len(cpu_devices) < 2:
|
||||
self.skipTest("Not enough CPU devices")
|
||||
|
||||
@colocated_python.colocated_python
|
||||
def add(x: jax.Array, y: jax.Array) -> jax.Array:
|
||||
arrays = [
|
||||
x.addressable_shards[1].data + y.addressable_shards[0].data,
|
||||
x.addressable_shards[0].data + y.addressable_shards[1].data,
|
||||
]
|
||||
return jax.make_array_from_single_device_arrays(
|
||||
y.shape, y.sharding, arrays
|
||||
)
|
||||
|
||||
# The execution will use mixed device orders. We should specialize the
|
||||
# function with devices to avoid the argument-dependent device selection.
|
||||
add = add.specialize(devices=cpu_devices)
|
||||
|
||||
mesh1 = jax.sharding.Mesh([cpu_devices[0], cpu_devices[1]], "x")
|
||||
sharding1 = jax.sharding.NamedSharding(
|
||||
mesh1, jax.sharding.PartitionSpec("x")
|
||||
)
|
||||
mesh2 = jax.sharding.Mesh([cpu_devices[1], cpu_devices[0]], "x")
|
||||
sharding2 = jax.sharding.NamedSharding(
|
||||
mesh2, jax.sharding.PartitionSpec("x")
|
||||
)
|
||||
|
||||
x = np.array([0, 2])
|
||||
x = jax.device_put(x, sharding1)
|
||||
y = np.array([4, 8])
|
||||
y = jax.device_put(y, sharding2)
|
||||
|
||||
out = add(x, y)
|
||||
|
||||
self.assertEqual(out.sharding, sharding2)
|
||||
out_device_list = [shard.device for shard in out.addressable_shards]
|
||||
self.assertEqual(out_device_list, [cpu_devices[1], cpu_devices[0]])
|
||||
|
||||
out = jax.device_get(out)
|
||||
np.testing.assert_equal(out, np.array([2 + 4, 0 + 8]))
|
||||
|
||||
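To make the expected value in the test above concrete: the colocated function pairs x's shard 1 with y's shard 0 and x's shard 0 with y's shard 1, and the test expects the gathered result [2 + 4, 0 + 8]. A NumPy-only restatement of that index arithmetic (shard values are the ones implied by the expected output; device placement is elided):

import numpy as np

x_shards = [np.array([0]), np.array([2])]
y_shards = [np.array([4]), np.array([8])]
out = np.concatenate([x_shards[1] + y_shards[0], x_shards[0] + y_shards[1]])
np.testing.assert_equal(out, np.array([2 + 4, 0 + 8]))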
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -219,6 +219,29 @@ class DebugPrintTest(jtu.JaxTestCase):
|
||||
[ 1 2 3 4 5 6 7 8 9 10 12 13 14]]
|
||||
"""))
|
||||
|
||||
def test_debug_print_respects_numpy_printoptions(self):
|
||||
def f(x):
|
||||
with np.printoptions(precision=2, suppress=True):
|
||||
jax.debug.print("{}", x)
|
||||
x = np.array([1.2345, 2.3456, 1E-7])
|
||||
|
||||
# Default numpy print options:
|
||||
with jtu.capture_stdout() as output:
|
||||
jax.debug.print("{}", x)
|
||||
self.assertEqual(output(), "[1.2345e+00 2.3456e+00 1.0000e-07]\n")
|
||||
|
||||
# Modified print options without JIT:
|
||||
with jtu.capture_stdout() as output:
|
||||
f(x)
|
||||
jax.effects_barrier()
|
||||
self.assertEqual(output(), "[1.23 2.35 0. ]\n")
|
||||
|
||||
# Modified print options with JIT:
|
||||
with jtu.capture_stdout() as output:
|
||||
jax.jit(f)(x)
|
||||
jax.effects_barrier()
|
||||
self.assertEqual(output(), "[1.23 2.35 0. ]\n")
|
||||
|
||||
|
||||
class DebugPrintTransformationTest(jtu.JaxTestCase):
|
||||
|
||||
|
@ -30,6 +30,7 @@ from jax._src import abstract_arrays
|
||||
from jax._src import api
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import deprecations
|
||||
from jax._src import linear_util
|
||||
from jax._src import prng
|
||||
from jax._src import test_util as jtu
|
||||
@ -279,8 +280,13 @@ class FfiTest(jtu.JaxTestCase):
|
||||
def testBackwardCompatSyntax(self):
|
||||
def fun(x):
|
||||
return jex.ffi.ffi_call("test_ffi", x, x, param=0.5)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
jax.jit(fun).lower(jnp.ones(5))
|
||||
msg = "Calling ffi_call directly with input arguments is deprecated"
|
||||
if deprecations.is_accelerated("jax-ffi-call-args"):
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
jax.jit(fun).lower(jnp.ones(5))
|
||||
else:
|
||||
with self.assertWarnsRegex(DeprecationWarning, msg):
|
||||
jax.jit(fun).lower(jnp.ones(5))
|
||||
|
||||
def testInputOutputAliases(self):
|
||||
def fun(x):
|
||||
|
@ -158,12 +158,13 @@ class FftTest(jtu.JaxTestCase):
|
||||
# Numpy promotes to complex128 aggressively.
|
||||
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol={np.complex64: 2e-6})
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol={np.complex64: 2e-6},
|
||||
rtol={np.float32: 2e-6})
|
||||
# Test gradient for differentiable types.
|
||||
if (config.enable_x64.value and
|
||||
dtype in (float_dtypes if real and not inverse else inexact_dtypes)):
|
||||
# TODO(skye): can we be more precise?
|
||||
tol = 0.15
|
||||
tol = 0.16
|
||||
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
|
||||
|
||||
# check dtypes
|
||||
|
@ -254,8 +254,6 @@ def sdpa_train_fp8(
|
||||
class DotProductAttentionTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if jax.device_count() < 4:
|
||||
self.skipTest("Requires more than 4 devices.")
|
||||
try:
|
||||
cudnn_version = check_cudnn_version()
|
||||
except RuntimeError as e:
|
||||
@ -366,6 +364,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.run_on_devices("cuda")
|
||||
def test_sdpa_inference(self):
|
||||
if jax.device_count() < 4:
|
||||
self.skipTest("Requires more than 4 devices.")
|
||||
k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
|
||||
query = jax.random.normal(
|
||||
k1, (4, 1024, 4, 64), dtype=jnp.bfloat16)
|
||||
@ -407,6 +407,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.run_on_devices("cuda")
|
||||
def test_sdpa_var_seq(self):
|
||||
if jax.device_count() < 4:
|
||||
self.skipTest("Requires more than 4 devices.")
|
||||
self.skipTest("Skip before fixed.")
|
||||
k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
|
||||
query = jax.random.normal(
|
||||
@ -438,6 +440,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.run_on_devices("cuda")
|
||||
def test_sdpa_broadcast_bias_and_dbias(self):
|
||||
if jax.device_count() < 4:
|
||||
self.skipTest("Requires more than 4 devices.")
|
||||
try:
|
||||
cudnn_version = check_cudnn_version()
|
||||
except RuntimeError as e:
|
||||
@ -504,6 +508,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):
|
||||
)
|
||||
@jtu.run_on_devices("cuda")
|
||||
def test_sdpa_dbias(self, batch_size: int):
|
||||
if jax.device_count() < 4:
|
||||
self.skipTest("Requires more than 4 devices.")
|
||||
# cuDNN only supports dbias when batch size is 1. If the batch size is
|
||||
# greater, dbias is silently set to all zeros. This test verifies this
|
||||
# behavior for both vmap and regular use cases.
|
||||
@ -540,6 +546,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.run_on_devices("cuda")
|
||||
def test_sdpa_sliding_window_length(self):
|
||||
if jax.device_count() < 4:
|
||||
self.skipTest("Requires more than 4 devices.")
|
||||
k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
|
||||
query = jax.random.normal(
|
||||
k1, (4, 1024, 4, 64), dtype=jnp.bfloat16)
|
||||
@ -571,8 +579,43 @@ class DotProductAttentionTest(jtu.JaxTestCase):
|
||||
self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5)
|
||||
self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5)
|
||||
|
||||
@jtu.run_on_devices("cuda")
|
||||
def test_sdpa_large_head_size(self):
|
||||
try:
|
||||
cudnn_version = check_cudnn_version()
|
||||
except RuntimeError as e:
|
||||
self.skipTest(str(e))
|
||||
return
|
||||
if cudnn_version < 90500:
|
||||
self.skipTest("Requires >= cuDNN 9.5.0")
|
||||
if not jtu.is_cuda_compute_capability_at_least("9.0"):
|
||||
self.skipTest("Requires at least Hopper arch")
|
||||
|
||||
B, T, N, H = 2, 64, 2, 256
|
||||
bf16 = jnp.bfloat16
|
||||
keys = jax.random.split(jax.random.key(0), 4)
|
||||
query = jax.random.normal(keys[0], (B, T, N, H), dtype=bf16)
|
||||
key = jax.random.normal(keys[1], (B, T, N, H), dtype=bf16)
|
||||
value = jax.random.normal(keys[2], (B, T, N, H), dtype=bf16)
|
||||
grad = jax.random.normal(keys[3], (B, T, N, H), dtype=bf16)
|
||||
sdpa_train_ans = jax.jit(partial(
|
||||
sdpa_train, scale=1.0, mask_type=MaskType.CAUSAL, dropout_rate=0)
|
||||
)
|
||||
sdpa_train_rfc = jax.jit(partial(
|
||||
sdpa_train_ref, scale=1.0, mask_type=MaskType.CAUSAL, dropout_rate=0)
|
||||
)
|
||||
|
||||
out_ans, grads_ans = sdpa_train_ans(query, key, value, grad, None, None)
|
||||
out_ref, grads_ref = sdpa_train_rfc(query, key, value, grad, None, None)
|
||||
self.assertArraysAllClose(out_ref, out_ans)
|
||||
self.assertArraysAllClose(grads_ref[0], grads_ans[0])
|
||||
self.assertArraysAllClose(grads_ref[1], grads_ans[1])
|
||||
self.assertArraysAllClose(grads_ref[2], grads_ans[2])
|
||||
|
||||
@jtu.run_on_devices("cuda")
|
||||
def test_layouts(self):
|
||||
if jax.device_count() < 4:
|
||||
self.skipTest("Requires more than 4 devices.")
|
||||
dtype = "bfloat16"
|
||||
B, T, N, H = 4, 1024, 8, 128
|
||||
S = T
|
||||
@ -600,6 +643,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):
|
||||
self.assertArraysAllClose(dv_ref, _cvt_back(dv))
|
||||
|
||||
def test_sdpa_utils(self):
|
||||
if jax.device_count() < 4:
|
||||
self.skipTest("Requires more than 4 devices.")
|
||||
test_cases = [
|
||||
(1, 257, 64, 8905, False, True, True),
|
||||
(1, 1024, 64, 8905, False, False, True),
|
||||
|
@ -25,18 +25,10 @@ import jax.numpy as jnp
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
# Helper class used to create a reference cycle.
|
||||
class GarbageCollectionGuardTestNodeHelper:
|
||||
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
self.next = None
|
||||
|
||||
|
||||
def _create_array_cycle():
|
||||
"""Creates a reference cycle of two jax.Arrays."""
|
||||
n1 = GarbageCollectionGuardTestNodeHelper(jax.jit(lambda: jnp.ones( (2, 2)))())
|
||||
n2 = GarbageCollectionGuardTestNodeHelper(jax.jit(lambda: jnp.zeros((2, 2)))())
|
||||
n1 = jnp.ones((2, 2))
|
||||
n2 = jnp.zeros((2, 2))
|
||||
n1.next = n2
|
||||
n2.next = n1
|
||||
return weakref.ref(n1)
|
||||
|
@ -87,6 +87,32 @@ python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_]
|
||||
# uint64 is problematic because with any uint type it promotes to float:
|
||||
int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64]
|
||||
|
||||
def _bitcast_uint4_to_uint8(operand):
|
||||
# Note: assumes little-endian byte order.
|
||||
assert operand.dtype == 'uint4'
|
||||
operand = operand.astype('uint8')
|
||||
return operand[..., ::2] + (operand[..., 1::2] << 4)
|
||||
|
||||
def _bitcast_uint8_to_uint4(operand):
|
||||
# Note: assumes little-endian byte order.
|
||||
assert operand.dtype == 'uint8'
|
||||
result = np.zeros((*operand.shape[:-1], operand.shape[-1] * 2), dtype='uint4')
|
||||
result[..., ::2] = (operand & 0b00001111).astype('uint4')
|
||||
result[..., 1::2] = ((operand & 0b11110000) >> 4).astype('uint4')
|
||||
return result
|
||||
|
||||
def np_view(arr, dtype):
|
||||
# Implementation of np.ndarray.view() that works for int4/uint4
|
||||
dtype = np.dtype(dtype)
|
||||
nbits_in = dtypes.bit_width(arr.dtype)
|
||||
nbits_out = dtypes.bit_width(dtype)
|
||||
if nbits_in == 4:
|
||||
arr = _bitcast_uint4_to_uint8(arr.view('uint4'))
|
||||
if nbits_out == 4:
|
||||
arr = _bitcast_uint8_to_uint4(arr.view('uint8'))
|
||||
return arr.view(dtype)
|
||||
|
||||
|
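A quick sanity check of the uint4 packing helpers above, written with plain uint8 arithmetic so it runs without the int4/uint4 NumPy dtypes (in JAX's tests these come from ml_dtypes): four 4-bit values pack little-endian into two bytes, low nibble first, and unpack back unchanged.

import numpy as np

nibbles = np.array([1, 2, 3, 4], dtype=np.uint8)        # stand-ins for uint4 values
packed = nibbles[..., ::2] + (nibbles[..., 1::2] << 4)  # same formula as _bitcast_uint4_to_uint8
unpacked = np.zeros(4, dtype=np.uint8)
unpacked[..., ::2] = packed & 0x0F                      # same formula as _bitcast_uint8_to_uint4
unpacked[..., 1::2] = (packed & 0xF0) >> 4
assert packed.tolist() == [0x21, 0x43]
assert unpacked.tolist() == [1, 2, 3, 4]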
||||
def np_unique_backport(ar, return_index=False, return_inverse=False, return_counts=False,
|
||||
axis=None, **kwds):
|
||||
# Wrapper for np.unique, handling the change to inverse_indices in numpy 2.0
|
||||
@ -4244,9 +4270,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(
|
||||
# Final dimension must be a multiple of 16 to ensure compatibility of all dtype pairs.
|
||||
shape=[(0,), (32,), (2, 16)],
|
||||
a_dtype=all_dtypes,
|
||||
dtype=(*all_dtypes, None) if config.enable_x64.value else all_dtypes,
|
||||
shape=[(0,), (64,), (2, 32)],
|
||||
a_dtype=(jnp.int4, jnp.uint4, *all_dtypes),
|
||||
dtype=((jnp.int4, jnp.uint4, *all_dtypes, None)
|
||||
if config.enable_x64.value else (jnp.int4, jnp.uint4, *all_dtypes)),
|
||||
)
|
||||
def testView(self, shape, a_dtype, dtype):
|
||||
if jtu.test_device_matches(["tpu"]):
|
||||
@ -4259,7 +4286,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self.rng()
|
||||
)
|
||||
args_maker = lambda: [rng(shape, a_dtype)]
|
||||
np_op = lambda x: np.asarray(x).view(dtype)
|
||||
np_op = lambda x: np_view(x, dtype)
|
||||
jnp_op = lambda x: jnp.asarray(x).view(dtype)
|
||||
# Above may produce signaling nans; ignore warnings from invalid values.
|
||||
with np.errstate(invalid='ignore'):
|
||||
@ -4268,9 +4295,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product([
|
||||
{'a_dtype': a_dtype, 'dtype': dtype}
|
||||
for a_dtype in all_dtypes
|
||||
for dtype in all_dtypes
|
||||
if np.dtype(a_dtype).itemsize == np.dtype(dtype).itemsize
|
||||
for a_dtype in [jnp.int4, jnp.uint4, *all_dtypes]
|
||||
for dtype in [jnp.int4, jnp.uint4, *all_dtypes]
|
||||
if dtypes.bit_width(a_dtype) == dtypes.bit_width(dtype)
|
||||
])
|
||||
def testViewScalar(self, a_dtype, dtype):
|
||||
if jtu.test_device_matches(["tpu"]):
|
||||
|
@ -170,43 +170,49 @@ class LaxTest(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(numpy_op, op, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
from_dtype=jtu.dtypes.all_floating + jtu.dtypes.all_integer + jtu.dtypes.all_unsigned,
|
||||
to_dtype=jtu.dtypes.all_floating + jtu.dtypes.all_integer + jtu.dtypes.all_unsigned,
|
||||
from_dtype=['int4', 'uint4'] + jtu.dtypes.all_floating + jtu.dtypes.all_integer + jtu.dtypes.all_unsigned,
|
||||
to_dtype=['int4', 'uint4'] + jtu.dtypes.all_floating + jtu.dtypes.all_integer + jtu.dtypes.all_unsigned,
|
||||
shape = [(), (2,), (2, 3)]
|
||||
)
|
||||
def testBitcastConvertType(self, from_dtype, to_dtype, shape):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
itemsize_in = np.dtype(from_dtype).itemsize
|
||||
itemsize_out = np.dtype(to_dtype).itemsize
|
||||
if itemsize_in < itemsize_out:
|
||||
shape = (*shape, itemsize_out // itemsize_in)
|
||||
nbits_in = dtypes.bit_width(from_dtype)
|
||||
nbits_out = dtypes.bit_width(to_dtype)
|
||||
if nbits_in < nbits_out:
|
||||
shape = (*shape, nbits_out // nbits_in)
|
||||
args_maker = lambda: [rng(shape, from_dtype)]
|
||||
op = lambda x: lax.bitcast_convert_type(x, to_dtype)
|
||||
self._CompileAndCheck(op, args_maker)
|
||||
jnp_op = lambda x: lax.bitcast_convert_type(x, to_dtype)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
# Test the shape and dtype of the output. We avoid testing the values here
|
||||
# because the bitwise representation may vary from platform to platform.
|
||||
out = op(*args_maker())
|
||||
if itemsize_in == itemsize_out:
|
||||
out = jnp_op(*args_maker())
|
||||
if nbits_in == nbits_out:
|
||||
expected_shape = shape
|
||||
elif itemsize_in < itemsize_out:
|
||||
elif nbits_in < nbits_out:
|
||||
expected_shape = shape[:-1]
|
||||
else:
|
||||
expected_shape = (*shape, itemsize_in // itemsize_out)
|
||||
expected_shape = (*shape, nbits_in // nbits_out)
|
||||
self.assertEqual(out.dtype, to_dtype)
|
||||
self.assertEqual(out.shape, expected_shape)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(from_dtype=from_dtype, to_dtype=to_dtype)
|
||||
for from_dtype, to_dtype in itertools.product(
|
||||
[np.float32, np.int32, "float32", "int32"], repeat=2)],
|
||||
['int4', 'uint4', np.int8, np.uint8, np.int32, np.float16, np.float32],
|
||||
repeat=2)],
|
||||
shape=[(4,), (2, 4), (2, 3, 4)]
|
||||
)
|
||||
def testBitcastConvertTypeAgainstNumpy(self, from_dtype, to_dtype):
|
||||
def testBitcastConvertTypeAgainstNumpy(self, from_dtype, to_dtype, shape):
|
||||
nbits_in = dtypes.bit_width(from_dtype)
|
||||
nbits_out = dtypes.bit_width(to_dtype)
|
||||
if nbits_in < nbits_out:
|
||||
shape = (*shape, nbits_out // nbits_in)
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng((2, 3), from_dtype)]
|
||||
op = lambda x: lax.bitcast_convert_type(x, to_dtype)
|
||||
numpy_op = lambda x: lax_reference.bitcast_convert_type(x, to_dtype)
|
||||
self._CheckAgainstNumpy(numpy_op, op, args_maker)
|
||||
args_maker = lambda: [rng(shape, from_dtype)]
|
||||
jnp_op = lambda x: lax.bitcast_convert_type(x, to_dtype)
|
||||
np_op = lambda x: lax_reference.bitcast_convert_type(x, to_dtype)
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
||||
|
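The bit_width bookkeeping in the two tests above follows the shape rule of lax.bitcast_convert_type: equal bit widths keep the shape, casting to a narrower type appends a trailing dimension of nbits_in // nbits_out, and casting to a wider type consumes a trailing dimension of size nbits_out // nbits_in. A small runnable illustration (float32/uint8 only, staying off the int4 path):

import jax.numpy as jnp
from jax import lax

x = jnp.zeros((2, 3), dtype=jnp.float32)
narrow = lax.bitcast_convert_type(x, jnp.uint8)       # 32 -> 8 bits: shape (2, 3, 4)
assert narrow.shape == (2, 3, 4)
wide = lax.bitcast_convert_type(narrow, jnp.float32)  # 8 -> 32 bits: shape (2, 3)
assert wide.shape == (2, 3)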
||||
@jtu.sample_product(
|
||||
[dict(from_dtype=from_dtype, to_dtype=to_dtype)
|
||||
@ -1379,83 +1385,6 @@ class LaxTest(jtu.JaxTestCase):
|
||||
numpy_op = lambda x, y: lax_reference.dot_general(x, y, dimension_numbers)
|
||||
self._CheckAgainstNumpy(numpy_op, op, args_maker)
|
||||
|
||||
def testRaggedAllToAllErrors(self):
|
||||
operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
|
||||
output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
|
||||
input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
|
||||
send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
|
||||
output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
|
||||
recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input and output shapes must be equal, except for the outermost dimension."):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, jnp.array([[0.0], [0.0], [0.0], [0.0], [0.0]], dtype=jnp.float32), input_offsets, send_sizes, output_offsets, recv_sizes)
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input_offsets must be integer type."):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32), send_sizes, output_offsets, recv_sizes)
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all send_sizes must be integer type."):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), output_offsets, recv_sizes)
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all output_offsets must be integer type."):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32), recv_sizes)
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all recv_sizes must be integer type."):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32))
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input_offsets must be rank 1 with positive dimension size, but got shape (1, 3)"):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, jnp.array([[0, 1, 3]], dtype=jnp.int32), send_sizes, output_offsets, recv_sizes)
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input_offsets must be rank 1 with positive dimension size, but got shape (0,)"):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, jnp.array([], dtype=jnp.int32), send_sizes, output_offsets, recv_sizes)
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all send_sizes must be rank 1 with positive dimension size, but got shape (1, 3)"):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, jnp.array([[1, 2, 3]], dtype=jnp.int32), output_offsets, recv_sizes)
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all send_sizes must be rank 1 with positive dimension size, but got shape (0,)"):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, jnp.array([], dtype=jnp.int32), output_offsets, recv_sizes)
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all output_offsets must be rank 1 with positive dimension size, but got shape (1, 3)"):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, jnp.array([[0, 1, 3]], dtype=jnp.int32), recv_sizes)
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all output_offsets must be rank 1 with positive dimension size, but got shape (0,)"):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, jnp.array([], dtype=jnp.int32), recv_sizes)
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all recv_sizes must be rank 1 with positive dimension size, but got shape (1, 3)"):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, jnp.array([[1, 2, 3]], dtype=jnp.int32))
|
||||
with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all recv_sizes must be rank 1 with positive dimension size, but got shape (0,)"):
|
||||
jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, jnp.array([], dtype=jnp.int32))
|
||||
|
||||
def testRaggedAllToAll(self):
|
||||
operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
|
||||
output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
|
||||
input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
|
||||
send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
|
||||
output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
|
||||
recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
|
||||
mlir_module = jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes).as_text()
|
||||
self.assertIn("stablehlo.custom_call @ragged_all_to_all", mlir_module)
|
||||
self.assertIn(
|
||||
"backend_config = {replica_groups = dense<[[0, 1, 2]]> :"
|
||||
" tensor<1x3xi64>}}",
|
||||
mlir_module,
|
||||
)
|
||||
|
||||
@jtu.sample_product(
|
||||
[
|
||||
{'m': 5, 'k': 4, 'n': 3, 'num_groups': 1},
|
||||
{'m': 10, 'k': 9, 'n': 8, 'num_groups': 2},
|
||||
],
|
||||
dtype=jtu.dtypes.numeric,
|
||||
)
|
||||
def testRaggedDot(self, m, k, n, num_groups, dtype):
|
||||
"""Tests ragged_dot.
|
||||
|
||||
The ragged_dot is tested against numpy reference implementation, and by running JAX compilation.
|
||||
|
||||
Raises:
|
||||
SkipTest: in the case dtype is not supported.
|
||||
"""
|
||||
lhs_shape = (m, k)
|
||||
rhs_shape = (num_groups, k, n)
|
||||
def group_sizes(m, num_groups):
|
||||
ends_no_final = jnp.sort(self.rng().choice(m, size=num_groups - 1))
|
||||
ends = jnp.concatenate([ends_no_final, jnp.array([m], dtype=ends_no_final.dtype)])
|
||||
starts = jnp.concatenate([jnp.zeros(1, dtype=ends_no_final.dtype), ends_no_final])
|
||||
return ends - starts
|
||||
rng = jtu.rand_small(self.rng())
|
||||
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype), group_sizes(m, num_groups)]
|
||||
self._CompileAndCheck(lax.ragged_dot, args_maker)
|
||||
self._CheckAgainstNumpy(lax_reference.ragged_dot, lax.ragged_dot, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(), (2, 3)],
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
@ -4457,7 +4386,7 @@ class FunctionAccuracyTest(jtu.JaxTestCase):
|
||||
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag')
|
||||
|
||||
elif name == 'log10':
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag', 'zero.imag')
|
||||
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag')
|
||||
|
||||
elif name == 'exp':
|
||||
regions_with_inaccuracies_keep('pos.imag', 'pinf.imag', 'mpos.imag')
|
||||
@ -4751,5 +4680,181 @@ class CompositeTest(jtu.JaxTestCase):
|
||||
):
|
||||
grad(my_square)(1.0)
|
||||
|
||||
|
||||
class RaggedTest(jtu.JaxTestCase):
|
||||
|
||||
def testRaggedAllToAll(self):
|
||||
operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
|
||||
output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
|
||||
input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
|
||||
send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
|
||||
output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
|
||||
recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
|
||||
mlir_module = jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, input_offsets, send_sizes, output_offsets,
|
||||
recv_sizes).as_text()
|
||||
self.assertIn("stablehlo.custom_call @ragged_all_to_all", mlir_module)
|
||||
self.assertIn(
|
||||
"backend_config = {replica_groups = dense<[[0, 1, 2]]> :"
|
||||
" tensor<1x3xi64>}}",
|
||||
mlir_module,
|
||||
)
|
||||
|
||||
def testRaggedAllToAllErrors(self):
|
||||
operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
|
||||
output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
|
||||
input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
|
||||
send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
|
||||
output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
|
||||
recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"ragged_all_to_all input and output shapes must be equal, except for"
|
||||
" the outermost dimension.",
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand,
|
||||
jnp.array([[0.0], [0.0], [0.0], [0.0], [0.0]], dtype=jnp.float32),
|
||||
input_offsets, send_sizes, output_offsets, recv_sizes)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError, "ragged_all_to_all input_offsets must be integer type."
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32),
|
||||
send_sizes, output_offsets, recv_sizes)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError, "ragged_all_to_all send_sizes must be integer type."
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, input_offsets,
|
||||
jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), output_offsets,
|
||||
recv_sizes)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError, "ragged_all_to_all output_offsets must be integer type."
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, input_offsets, send_sizes,
|
||||
jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32), recv_sizes)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError, "ragged_all_to_all recv_sizes must be integer type."
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, input_offsets, send_sizes, output_offsets,
|
||||
jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32))
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"ragged_all_to_all input_offsets must be rank 1 with positive dimension"
|
||||
" size, but got shape (1, 3)",
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, jnp.array([[0, 1, 3]], dtype=jnp.int32), send_sizes,
|
||||
output_offsets, recv_sizes)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"ragged_all_to_all input_offsets must be rank 1 with positive dimension"
|
||||
" size, but got shape (0,)",
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, jnp.array([], dtype=jnp.int32), send_sizes,
|
||||
output_offsets, recv_sizes)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"ragged_all_to_all send_sizes must be rank 1 with positive dimension"
|
||||
" size, but got shape (1, 3)",
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, input_offsets,
|
||||
jnp.array([[1, 2, 3]], dtype=jnp.int32), output_offsets, recv_sizes)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"ragged_all_to_all send_sizes must be rank 1 with positive dimension"
|
||||
" size, but got shape (0,)",
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, input_offsets, jnp.array([], dtype=jnp.int32),
|
||||
output_offsets, recv_sizes)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"ragged_all_to_all output_offsets must be rank 1 with positive"
|
||||
" dimension size, but got shape (1, 3)",
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, input_offsets, send_sizes,
|
||||
jnp.array([[0, 1, 3]], dtype=jnp.int32), recv_sizes)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"ragged_all_to_all output_offsets must be rank 1 with positive"
|
||||
" dimension size, but got shape (0,)",
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, input_offsets, send_sizes,
|
||||
jnp.array([], dtype=jnp.int32), recv_sizes)
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"ragged_all_to_all recv_sizes must be rank 1 with positive dimension"
|
||||
" size, but got shape (1, 3)",
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, input_offsets, send_sizes, output_offsets,
|
||||
jnp.array([[1, 2, 3]], dtype=jnp.int32))
|
||||
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
"ragged_all_to_all recv_sizes must be rank 1 with positive dimension"
|
||||
" size, but got shape (0,)",
|
||||
):
|
||||
jax.jit(lax.ragged_all_to_all).lower(
|
||||
operand, output, input_offsets, send_sizes, output_offsets,
|
||||
jnp.array([], dtype=jnp.int32))
|
||||
|
||||
@jtu.sample_product(
|
||||
[
|
||||
{'m': 5, 'k': 4, 'n': 3, 'num_groups': 1},
|
||||
{'m': 10, 'k': 9, 'n': 8, 'num_groups': 2},
|
||||
],
|
||||
dtype=jtu.dtypes.numeric,
|
||||
)
|
||||
def testRaggedDot(self, m, k, n, num_groups, dtype):
|
||||
"""Tests ragged_dot.
|
||||
|
||||
The ragged_dot is tested against numpy reference implementation, and by
|
||||
running JAX compilation.
|
||||
|
||||
Raises:
|
||||
SkipTest: in the case dtype is not supported.
|
||||
"""
|
||||
lhs_shape = (m, k)
|
||||
rhs_shape = (num_groups, k, n)
|
||||
|
||||
def group_sizes(m, num_groups):
|
||||
ends_no_final = jnp.sort(self.rng().choice(m, size=num_groups - 1))
|
||||
ends = jnp.concatenate(
|
||||
[ends_no_final, jnp.array([m], dtype=ends_no_final.dtype)])
|
||||
starts = jnp.concatenate(
|
||||
[jnp.zeros(1, dtype=ends_no_final.dtype), ends_no_final])
|
||||
return ends - starts
|
||||
|
||||
rng = jtu.rand_small(self.rng())
|
||||
args_maker = lambda: [
|
||||
rng(lhs_shape, dtype),
|
||||
rng(rhs_shape, dtype),
|
||||
group_sizes(m, num_groups),
|
||||
]
|
||||
self._CompileAndCheck(lax.ragged_dot, args_maker)
|
||||
self._CheckAgainstNumpy(
|
||||
lax_reference.ragged_dot, lax.ragged_dot, args_maker)
|
||||
|
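For readers unfamiliar with ragged_dot: group_sizes splits the m rows of the lhs into num_groups contiguous blocks, and each block is multiplied by its own (k, n) slice of the rhs. A hedged NumPy reference of that contraction (equivalent in spirit to lax_reference.ragged_dot, not a copy of it):

import numpy as np

def ragged_dot_reference(lhs, rhs, group_sizes):
    # lhs: (m, k), rhs: (g, k, n), group_sizes: (g,) summing to m.
    out = np.zeros((lhs.shape[0], rhs.shape[-1]), dtype=lhs.dtype)
    start = 0
    for g, size in enumerate(group_sizes):
        out[start:start + size] = lhs[start:start + size] @ rhs[g]
        start += size
    return out

lhs = np.ones((5, 4), dtype=np.float32)
rhs = np.stack([np.eye(4, 3, dtype=np.float32), 2 * np.eye(4, 3, dtype=np.float32)])
# Rows 0-1 contract with rhs[0], rows 2-4 with rhs[1].
print(ragged_dot_reference(lhs, rhs, np.array([2, 3])))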
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -586,7 +586,7 @@ class DialectLoweringTest(MosaicGpuTest):
|
||||
llvm.UndefOp(workgroup_ptr_ty()),
|
||||
arrival_count=1,
|
||||
)
|
||||
mgpu.lower_mgpu_dialect(self.module)
|
||||
mgpu.lower_mgpu_dialect(self.module, None)
|
||||
|
||||
self.assertEmpty(
|
||||
list(filter(is_mosaic_gpu_op, self.module.body.operations))
|
||||
@ -604,7 +604,7 @@ class DialectLoweringTest(MosaicGpuTest):
|
||||
arrival_count=1,
|
||||
)
|
||||
scf.yield_([])
|
||||
mgpu.lower_mgpu_dialect(self.module)
|
||||
mgpu.lower_mgpu_dialect(self.module, None)
|
||||
|
||||
self.assertEmpty(
|
||||
list(filter(is_mosaic_gpu_op, if_op.then_block.operations))
|
||||
@ -626,7 +626,7 @@ class DialectLoweringTest(MosaicGpuTest):
|
||||
memref.copy(barriers_ref, barriers_ref)
|
||||
|
||||
self.assertTrue(self.module.operation.verify())
|
||||
mgpu.lower_mgpu_dialect(self.module)
|
||||
mgpu.lower_mgpu_dialect(self.module, None)
|
||||
self.assertTrue(self.module.operation.verify())
|
||||
|
||||
all_mbarrier_init_shared_ops = find_if(
|
||||
@ -654,7 +654,7 @@ class DialectLoweringTest(MosaicGpuTest):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "missing a layout and can not be lowered"
|
||||
):
|
||||
mgpu.lower_mgpu_dialect(self.module)
|
||||
mgpu.lower_mgpu_dialect(self.module, None)
|
||||
|
||||
def test_lowering_eliminates_layouts(self):
|
||||
shape = (4, 128)
|
||||
@ -670,7 +670,7 @@ class DialectLoweringTest(MosaicGpuTest):
|
||||
)
|
||||
])
|
||||
|
||||
mgpu.lower_mgpu_dialect(self.module)
|
||||
mgpu.lower_mgpu_dialect(self.module, None)
|
||||
|
||||
all_ops_with_layouts = find_if(
|
||||
self.module,
|
||||
@ -691,7 +691,7 @@ class DialectLoweringTest(MosaicGpuTest):
|
||||
vector.store(array, ref, [zero_index, zero_index])
|
||||
|
||||
mgpu.infer_layout(self.module)
|
||||
mgpu.lower_mgpu_dialect(self.module)
|
||||
mgpu.lower_mgpu_dialect(self.module, None)
|
||||
|
||||
all_loads = find_if(
|
||||
self.module,
|
||||
|
@ -47,6 +47,8 @@ except ImportError:
|
||||
z = 2
|
||||
else:
|
||||
import jax.experimental.mosaic.gpu as mgpu
|
||||
from jax.experimental.mosaic.gpu import core
|
||||
from jax.experimental.mosaic.gpu import launch_context
|
||||
from jax.experimental.mosaic.gpu import utils as utils
|
||||
from jax.experimental.mosaic.gpu import profiler
|
||||
from jax.experimental.mosaic.gpu.utils import * # noqa: F403
|
||||
@ -1937,6 +1939,103 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
|
||||
self.assertArraysEqual(jax.jit(kernel)(x, y), x + y)
|
||||
|
||||
def test_pointwise_kernel_with_tma(self):
|
||||
def add(
|
||||
ctx: launch_context.LaunchContext,
|
||||
a_gmem_ref: ir.Value,
|
||||
b_gmem_ref: ir.Value,
|
||||
result_gmem_ref: ir.Value,
|
||||
smem: list[ir.Value],
|
||||
):
|
||||
del ctx
|
||||
a_smem_ref, b_smem_ref, result_smem_ref = smem[:3]
|
||||
tma_barrier = smem[3]
|
||||
memref_type = ir.MemRefType(a_gmem_ref.type)
|
||||
shape = memref_type.shape
|
||||
elt_type = memref_type.element_type
|
||||
|
||||
zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0)
|
||||
|
||||
with utils.single_thread():
|
||||
memref_bytes = utils.bytewidth(elt_type) # Also correct if rank == 0
|
||||
for size in shape:
|
||||
memref_bytes *= size
|
||||
nvvm.mbarrier_arrive_expect_tx_shared(
|
||||
tma_barrier.get_ptr(),
|
||||
arith.constant(ir.IntegerType.get_signless(32), 2*memref_bytes),
|
||||
)
|
||||
|
||||
# GMEM -> SMEM
|
||||
mgpu_dialect.async_load(
|
||||
source=a_gmem_ref,
|
||||
destination=a_smem_ref,
|
||||
barrier=tma_barrier.as_dialect_barrier(),
|
||||
indices=[zero_i32, zero_i32],
|
||||
slice_lengths=shape,
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
arrive=False,
|
||||
)
|
||||
mgpu_dialect.async_load(
|
||||
source=b_gmem_ref,
|
||||
destination=b_smem_ref,
|
||||
barrier=tma_barrier.as_dialect_barrier(),
|
||||
indices=[zero_i32, zero_i32],
|
||||
slice_lengths=shape,
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
arrive=False,
|
||||
)
|
||||
|
||||
tma_barrier.wait()
|
||||
|
||||
zero_index = arith.constant(ir.IndexType.get(), 0)
|
||||
|
||||
# SMEM -> registers
|
||||
ab_type = ir.VectorType.get(shape, elt_type)
|
||||
a = vector.load(ab_type, a_smem_ref, [zero_index, zero_index])
|
||||
b = vector.load(ab_type, b_smem_ref, [zero_index, zero_index])
|
||||
|
||||
# Computation
|
||||
add = arith.addf(arith.addf(a, b), b)
|
||||
|
||||
# Registers -> SMEM
|
||||
vector.store(add, result_smem_ref, [zero_index, zero_index])
|
||||
|
||||
# SMEM -> GMEM
|
||||
mgpu_dialect.async_store(
|
||||
source=result_smem_ref,
|
||||
destination=result_gmem_ref,
|
||||
indices=[zero_i32, zero_i32],
|
||||
slice_lengths=shape,
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
)
|
||||
nvvm.cp_async_bulk_wait_group(0)
|
||||
utils.warpgroup_barrier()
|
||||
|
||||
dtype = jnp.bfloat16
|
||||
shape = (128, 128)
|
||||
jax_shape = jax.ShapeDtypeStruct(shape, dtype)
|
||||
kernel = mgpu.as_gpu_kernel(
|
||||
add,
|
||||
grid=(1, 1, 1),
|
||||
block=(128, 1, 1),
|
||||
in_shape=(jax_shape, jax_shape),
|
||||
out_shape=jax_shape,
|
||||
smem_scratch_shape=[
|
||||
jax_shape,
|
||||
jax_shape,
|
||||
jax_shape,
|
||||
core.TMABarrier(1),
|
||||
],
|
||||
thread_semantics=mgpu.ThreadSemantics.Warpgroup,
|
||||
)
|
||||
|
||||
x = self.prng.uniform(-1, 1, shape).astype(dtype)
|
||||
y = self.prng.uniform(-1, 1, shape).astype(dtype)
|
||||
|
||||
self.assertArraysEqual(jax.jit(kernel)(x, y), x + y + y)
|
||||
|
||||
|
||||
class UtilsTest(TestCase):
|
||||
@parameterized.parameters(
|
||||
|
@ -306,30 +306,88 @@ class MutableArrayErrorsTest(jtu.JaxTestCase):
        ValueError, "traced for cond returned a mutable array reference of type"):
      jax.lax.cond(True, lambda: core.mutable_array(1.0), lambda: core.mutable_array(2.0))

  # TODO test_argument_aliases_cond
  # TODO test_closure_and_argument_aliases_cond
  def test_argument_aliases_cond(self):
    x_ref = core.mutable_array(0.)
    with self.assertRaisesRegex(ValueError, r"for cond.*at both x1 and x2"):
      jax.lax.cond(True, lambda x1, x2: ..., lambda x1, x2: ..., x_ref, x_ref)

  # TODO test_return_from_custom_jvp/vjp
  # TODO test_argument_aliases_custom_jvp/vjp
  # TODO test_closure_and_argument_aliases_custom_jvp/vjp
  def test_closure_and_argument_aliases_cond(self):
    x_ref = core.mutable_array(0.)
    with self.assertRaisesRegex(
        ValueError, r"closed over and passed as the argument y_ref"):
      jax.lax.cond(True,
                   lambda y_ref: x_ref[...] + y_ref[...],
                   lambda y_ref: x_ref[...] + y_ref[...],
                   x_ref)

  # TODO(mattjj): enable when cond works with mutable arrays
  # @parameterized.parameters([False, True])
  # def test_cond_both_branches_close_over_same_mutable_array(self, jit):
  #   # see also test_cond_with_ref_reuse in state_test.py
  #   x_ref = core.mutable_array(0.)
  #   def f(pred):
  #     def true_fun():
  #       x_ref[()] = 1.
  #     def false_fun():
  #       x_ref[()] = 2.
  #     jax.lax.cond(pred, true_fun, false_fun)
  #   if jit:
  #     f = jax.jit(f)
  #   out_true = f(True)
  #   self.assertAllClose(x_ref[...], 1.)
  #   out_false = f(False)
  #   self.assertAllClose(x_ref[...], 2.)
  @parameterized.parameters([False, True])
  def test_return_from_custom_vjp_primal(self, jit):
    @jax.custom_vjp
    def f(ref):
      return ref
    f.defvjp(lambda ref: ..., lambda *_: ...)
    if jit:
      f = jax.jit(f)
    x_ref = core.mutable_array(0.)
    with self.assertRaisesRegex(
        ValueError, "custom_vjp primal function"):
      f(x_ref)

  @parameterized.parameters([False, True])
  def test_return_from_custom_vjp_fwd(self, jit):
    @jax.custom_vjp
    def f(x, ref):
      return x
    f.defvjp(lambda x, ref: (x, ref), lambda ref, g: g)
    if jit:
      f = jax.jit(f)
    x_ref = core.mutable_array(0.)
    with self.assertRaisesRegex(
        ValueError, "custom_vjp fwd function"):
      jax.vjp(f, 3., x_ref)

  @parameterized.parameters([False, True])
  def test_argument_aliases_custom_vjp_primal(self, jit):
    @jax.custom_vjp
    def f(x_ref, y_ref):
      ...
    f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None))
    if jit:
      f = jax.jit(f)
    x_ref = core.mutable_array(0.)
    with self.assertRaisesRegex(ValueError, "x_ref and y_ref"):
      f(x_ref, x_ref)

  @parameterized.parameters([False, True])
  def test_argument_aliases_custom_vjp_fwd(self, jit):
    @jax.custom_vjp
    def f(x_ref, y_ref):
      ...
    f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None))
    if jit:
      f = jax.jit(f)
    x_ref = core.mutable_array(0.)
    with self.assertRaisesRegex(ValueError, "x_ref and y_ref"):
      jax.vjp(f, x_ref, x_ref)

  # TODO(mattjj): add test test_closure_and_argument_aliases_custom_vjp

  @parameterized.parameters([False, True])
  def test_cond_both_branches_close_over_same_mutable_array(self, jit):
    # see also test_cond_with_ref_reuse in state_test.py
    x_ref = core.mutable_array(0.)
    def f(pred):
      def true_fun():
        x_ref[()] = 1.
      def false_fun():
        x_ref[()] = 2.
      jax.lax.cond(pred, true_fun, false_fun)
    if jit:
      f = jax.jit(f)
    out_true = f(True)
    self.assertAllClose(x_ref[...], 1.)
    out_false = f(False)
    self.assertAllClose(x_ref[...], 2.)
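
# Hedged aside (added): core.mutable_array creates a reference whose value is
# read with ref[...] and updated in place with ref[...] = value (or ref[()] for
# scalars), including from functions that close over it under jit or cond, as
# the tests above exercise. For example:
#   x_ref = core.mutable_array(0.)
#   x_ref[()] = 1.
#   x_ref[...]   # -> 1.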


if __name__ == '__main__':
@ -39,8 +39,8 @@ jax_multiplatform_test(
        "tpu",
    ],
    enable_configs = [
        "gpu_a100_x32",
        "gpu_h100_x32",
        "gpu_a100",
        "gpu_h100",
    ],
    shard_count = {
        "cpu": 8,
@ -39,6 +39,15 @@ except ImportError:
jax.config.parse_flags_with_absl()


def _fori_loop(force_while: bool, lb, ub, body, init):
  if force_while:
    # Using jnp.asarray makes the while/scan matcher treat the bounds as
    # dynamic, which forces the use of the while primitive.
    lb, ub = jnp.asarray(lb), jnp.asarray(ub)
  return jax.lax.fori_loop(lb, ub, body, init)
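
# Hedged aside (added, illustrative only): the helper above works because
# jax.lax.fori_loop picks its lowering from the bounds it sees: with Python
# ints the trip count is static and it can lower to scan, while traced bounds
# force the while primitive. Comparing jaxprs makes the difference visible.
def _show_fori_loop_lowerings():  # hypothetical helper, not in the test file
  body = lambda i, x: x + i
  # Static Python-int bounds: known trip count, scan-based lowering.
  print(jax.make_jaxpr(lambda x: jax.lax.fori_loop(2, 4, body, x))(0))
  # Traced bounds (as produced by jnp.asarray above): while-based lowering.
  print(jax.make_jaxpr(
      lambda lb, ub, x: jax.lax.fori_loop(lb, ub, body, x)
  )(jnp.asarray(2), jnp.asarray(4), 0))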


class PallasTest(jtu.JaxTestCase):

  def setUp(self):
@ -705,19 +714,21 @@ class PallasCallTest(PallasTest):
    x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128)
    np.testing.assert_array_equal(kernel(x), x)

  def test_fori_loop_array(self):
  @parameterized.parameters(False, True)
  def test_fori_loop_array(self, force_while):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
    )
    def kernel(x_ref, o_ref):
      # Equivalent to x_ref[...] + 2 + 3.
      o_ref[...] = jax.lax.fori_loop(2, 4, lambda i, x: x + i, x_ref[...])
      o_ref[...] = _fori_loop(force_while, 2, 4, lambda i, x: x + i, x_ref[...])

    x = jnp.arange(256).astype(jnp.int32)
    np.testing.assert_array_equal(kernel(x), x + 2 + 3)

  def test_fori_loop_scalar(self):
  @parameterized.parameters(False, True)
  def test_fori_loop_scalar(self, force_while):

    @functools.partial(
        pl.pallas_call,
@ -726,7 +737,7 @@ class PallasCallTest(PallasTest):
    def kernel(o_ref):
      # Equivalent to 2 + 3.
      o_ref[...] = jax.lax.broadcast(
          jax.lax.fori_loop(2, 4, lambda i, x: x + i, 0), o_ref.shape
          _fori_loop(force_while, 2, 4, lambda i, x: x + i, 0), o_ref.shape
      )

    np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32))
@ -747,7 +758,8 @@ class PallasCallTest(PallasTest):

    np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32))

  def test_fori_loop_tuple(self):
  @parameterized.parameters(False, True)
  def test_fori_loop_tuple(self, force_while):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
@ -761,14 +773,15 @@ class PallasCallTest(PallasTest):

      # Equivalent to 3 * (0 + 1).
      o_ref[...] = jax.lax.broadcast(
          sum(jax.lax.fori_loop(2, 4, body, (0, 0, 0))), o_ref.shape
          sum(_fori_loop(force_while, 2, 4, body, (0, 0, 0))), o_ref.shape
      )

    np.testing.assert_array_equal(
        kernel(), jnp.full([256], 3 * (0 + 1), dtype=jnp.int32)
    )

  def test_fori_loop_indexed_store(self):
  @parameterized.parameters(False, True)
  def test_fori_loop_indexed_store(self, force_while):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([4, 128], jnp.float32),
@ -778,7 +791,7 @@ class PallasCallTest(PallasTest):
        o_ref[idx] = x_ref[idx] + y_ref[idx]
        return ()

      jax.lax.fori_loop(0, 4, body, ())
      _fori_loop(force_while, 0, 4, body, ())

    x = jnp.arange(4 * 128).reshape(4, 128).astype(jnp.float32)
    y = x + 1
@ -2127,6 +2127,21 @@ class OpsTest(PallasBaseTest):
      )
      self.assertTrue(acceptable_errors, "Failed with error: " + str(e))

  @parameterized.parameters((128, 128), (256, 256))
  def test_jnp_diagonal_pallas(self, n, m):
    if jtu.test_device_matches(["gpu"]):
      # TODO(mvoz): platform_index_p on GPU
      self.skipTest("Not implemented on GPU")
    x = jnp.arange(n * m, dtype=jnp.float32).reshape((n, m))

    def kernel(x_ref, out_ref):
      out_ref[...] = jnp.diagonal(x_ref[...])

    out = self.pallas_call(
        kernel, out_shape=jax.ShapeDtypeStruct((n,), jnp.float32)
    )(x)
    np.testing.assert_array_equal(out, np.diagonal(x))


class OpsInterpretTest(OpsTest):
  INTERPRET = True
@ -1688,8 +1688,8 @@ class PallasControlFlowTest(PallasBaseTest):

      def body(state):
        i, s = state
        sl = jax.lax.div(i, 128)
        l = jax.lax.rem(i, 128)
        sl = jax.lax.div(i, jnp.astype(128, i.dtype))
        l = jax.lax.rem(i, jnp.astype(128, i.dtype))
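        # Hedged note (added): lax.div/lax.rem require both operands to have
        # the same dtype and do not promote the Python literal, so 128 is cast
        # to i's dtype explicitly here.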
        v = pl.load(x_ref, (0, sl, l))
        return i + 1, s + v

@ -312,6 +312,46 @@ class OpsTest(PallasBaseTest):
    expected = reduce_func(x, axis, keepdims=True)
    np.testing.assert_array_equal(result, expected)

  @parameterized.product(
      msk_dtype=[jnp.float32, jnp.bfloat16, jnp.int8],
      dtype=[jnp.float32, jnp.bfloat16],
  )
  def test_i1_relayout_with_bitwidth_change(self, msk_dtype, dtype):
    # TODO(jevinjiang): Remove after 12 weeks have passed.
    if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
      self.skipTest("Requires libtpu built after 2024-12-19")
    shape = (129, 129)
    msk_bitwidth = pallas_utils.dtype_bitwidth(msk_dtype)
    bitwidth = pallas_utils.dtype_bitwidth(dtype)
    if (
        (jtu.get_tpu_version() > 5 and msk_bitwidth < 8)
        or (jtu.get_tpu_version() == 5 and msk_bitwidth not in (8, 32))
        or (jtu.get_tpu_version() < 5 and msk_bitwidth < 32)
    ):
      self.skipTest(
          "Not implemented: cast vector to mask with bitwidth =="
          f" {msk_bitwidth}"
      )
    if jtu.get_tpu_version() <= 5 and bitwidth < 32:
      self.skipTest(f"Not implemented: comparison with bitwidth == {bitwidth}")

    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct(shape, dtype),
    )
    def kernel(x_ref, mask_ref, o_ref):
      zeros = jnp.zeros_like(x_ref)
      o_ref[...] = jnp.where(mask_ref[...], x_ref[...], zeros)

    mask = jax.random.bernoulli(jax.random.key(1234), 0.5, shape).astype(
        msk_dtype
    )
    x = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape) + 1

    out = kernel(x, mask)
    expected = jnp.where(mask, x, jnp.zeros_like(x))
    self.assertArraysEqual(out, expected)


class OpsInterpretTest(OpsTest):
  INTERPRET = True
@ -2555,9 +2555,7 @@ class MiscellaneousTest(PallasBaseTest):

    np.testing.assert_array_equal(out, np.reshape(x, (8, 128)))

  @only_passes_in_interpret()
  def test_retiling2(self):
    """b/348040767"""
    x = np.arange(1 * 8 * 1024, dtype=jnp.bfloat16).reshape(1, 8, 1024)

    def kernel(x_ref, out_ref):
@ -95,8 +95,8 @@ class PgleTest(jtu.JaxTestCase):
        out_shardings=NamedSharding(mesh, PartitionSpec('x')),
        compiler_options={
            'xla_gpu_enable_latency_hiding_scheduler': 'True',
            # TODO(patrios): Remove this flag once b/376647494 is fixed.
            'xla_gpu_graph_min_graph_size': '100000',
            # TODO(b/37664749): Remove this flag once the bug is fixed.
            'xla_gpu_enable_command_buffer': '',
        },
    )
    def f(x):
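
# Hedged note (added): compiler_options, as used in this test's jax.jit
# decorator, is passed through to XLA as per-compilation flag overrides, e.g.
#   jax.jit(f, compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'})
# The change above swaps the graph-size workaround for disabling command
# buffers ('xla_gpu_enable_command_buffer': '') until the referenced bug is
# fixed.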
@ -133,8 +133,6 @@ class PgleTest(jtu.JaxTestCase):
        out_shardings=NamedSharding(mesh, PartitionSpec('x')),
        compiler_options={
            'xla_gpu_enable_latency_hiding_scheduler': 'True',
            # TODO(patrios): Remove this flag once b/376647494 is fixed.
            'xla_gpu_graph_min_graph_size': '100000',
            'xla_dump_to': dump_dir,
            'xla_gpu_experimental_dump_fdo_profiles': 'True'
        },
@ -217,8 +215,6 @@ class PgleTest(jtu.JaxTestCase):
        out_shardings=NamedSharding(mesh, PartitionSpec('x')),
        compiler_options={
            'xla_gpu_enable_latency_hiding_scheduler': 'True',
            # TODO(patrios): Remove this flag once b/376647494 is fixed.
            'xla_gpu_graph_min_graph_size': '100000',
            'xla_dump_to': dump_dir,
            'xla_gpu_experimental_dump_fdo_profiles': 'True'
        },
@ -2207,6 +2207,23 @@ class ShardMapTest(jtu.JaxTestCase):
  #
  # f(x)  # don't crash

  def test_partial_auto_of_random_keys(self):
    if config.use_shardy_partitioner.value:
      self.skipTest('Shardy does not support full-to-shard.')

    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
    keys = jax.random.split(jax.random.key(0), 8)

    @jax.jit
    def f(x):
      return shard_map(lambda k: k,
                       mesh, in_specs=P('i'), out_specs=P('i'),
                       check_rep=False, auto=frozenset({'j'}))(keys)
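    # Hedged note (added): auto=frozenset({'j'}) leaves the 'j' mesh axis to
    # the compiler (partial auto sharding) while 'i' is mapped manually, and
    # check_rep=False turns off shard_map's replication checking for this call.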

    y = f(keys)  # don't crash
    self.assertAllClose(jax.random.key_data(y), jax.random.key_data(keys),
                        check_dtypes=False)

  def test_vmap_grad_shmap_spmd_axis_name_residuals(self):
    # https://github.com/jax-ml/jax/pull/21032
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
4
third_party/xla/workspace.bzl
vendored
4
third_party/xla/workspace.bzl
vendored
@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result.

XLA_COMMIT = "b44f55da3dac449f03466815ac431474f86fd73f"
XLA_SHA256 = "f3d37257b970fd2993cbbc9185c2271910775f752d7c3bdd1828b8f663df1ff1"
XLA_COMMIT = "c12c1148585e00985d5e1ccf2bc0768862b7df77"
XLA_SHA256 = "44396bdac8b8bc7cba958691ae8df040ba91ddb26513aed37656d6db479dd06c"

def repo():
    tf_http_archive(