Mirror of https://github.com/ROCm/jax.git, synced 2025-04-15 19:36:06 +00:00

Merge pull request #193 from ROCm/ci-upstream-sync-78_1

CI: 01/06/25 upstream sync

This commit is contained in commit 4b11080f18.
.github/workflows/upstream-nightly.yml (vendored, 2 lines changed)

@@ -85,7 +85,7 @@ jobs:
           && steps.status.outcome == 'failure'
           && github.event_name == 'schedule'
           && github.repository == 'jax-ml/jax'
-        uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
+        uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0
         with:
           name: output-${{ matrix.python-version }}-log.jsonl
           path: output-${{ matrix.python-version }}-log.jsonl
.github/workflows/wheel_win_x64.yml (vendored, 2 lines changed)

@@ -45,7 +45,7 @@ jobs:
           --bazel_options=--config=win_clang `
           --verbose

-      - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
+      - uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0
        with:
          name: wheels-${{ matrix.os }}-${{ matrix.pyver }}
          path: ${{ github.workspace }}\dist\*.whl
.github/workflows/windows_ci.yml (vendored, 2 lines changed)

@@ -54,7 +54,7 @@ jobs:
           --bazel_options=--config=win_clang `
           --verbose

-      - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
+      - uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0
        with:
          name: wheels
          path: ${{ github.workspace }}\jax\dist\*.whl
CHANGELOG.md (32 lines changed)

@@ -4,6 +4,10 @@ Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html).
 For the changes specific to the experimental Pallas APIs,
 see {ref}`pallas-changelog`.

+JAX follows Effort-based versioning; for a discussion of this and JAX's API
+compatibility policy, refer to {ref}`api-compatibility`. For the Python and
+NumPy version support policy, refer to {ref}`version-support-policy`.
+
 <!--
 Remember to align the itemized text with the first line of an item within a list.

@@ -12,6 +16,16 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.

 ## Unreleased

+* Changes:
+  * The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
+    supported version until June 2025.
+
+* New Features
+  * {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,
+    {func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
+    transforms in more than 3 dimensions, which was previously the limit. See
+    {jax-issue}`#25606` for more details.
+
 * Deprecations
   * From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
     are now deprecated, having been replaced by symbols of the same name

@@ -20,21 +34,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
 * Deletions
   * `jax_enable_memories` flag has been deleted and the behavior of that flag
     is on by default.

-* Changes:
-  * The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
-    supported version until June 2025.
-
-* Deprecations
-  * From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
-    are now deprecated, having been replaced by symbols of the same name
-    in {mod}`jax.core`.
-
-* New Features
-  * {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,
-    {func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
-    transforms in more than 3 dimensions, which was previously the limit. See
-    {jax-issue}`#25606` for more details.
   * From `jax.lib.xla_client`, the previously-deprecated `Device` and
     `XlaRuntimeError` symbols have been removed; instead use `jax.Device`
     and `jax.errors.JaxRuntimeError` respectively.

 ## jax 0.4.38 (Dec 17, 2024)
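To illustrate the N-D FFT support described in the changelog entry above (an illustrative sketch, not part of the commit):

```python
import jax.numpy as jnp

# A 4-D transform over all axes, which previously raised for ndim > 3.
x = jnp.ones((2, 3, 4, 5))
y = jnp.fft.fftn(x)
print(y.shape)  # (2, 3, 4, 5)
```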
@@ -207,3 +207,4 @@ This will generate three wheels in the `dist/` directory:

 ### Simplified Build Script

+For a streamlined process, consider using the `jax/build/rocm/dev_build_rocm.py` script.
@ -93,6 +93,12 @@ plugins" error described {ref}`below <multiple_installs>`. See
|
||||
<https://www.tensorflow.org/guide/profiler> for more information on installing
|
||||
TensorBoard.
|
||||
|
||||
Nightly version of TensorBoard profiler requires nightly tensorflow and
|
||||
tensorboard
|
||||
```shell
|
||||
pip install tf-nightly tb-nightly tbp-nightly
|
||||
```
|
||||
|
||||
### Programmatic capture
|
||||
|
||||
You can instrument your code to capture a profiler trace via the
|
||||
|
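The sentence above is cut off in the extraction; for reference, a minimal sketch of programmatic capture using the standard `jax.profiler.trace` API:

```python
import jax
import jax.numpy as jnp

# Capture a trace for a block of computation; inspect it by pointing
# TensorBoard (with the profiler plugin) at the log directory.
with jax.profiler.trace("/tmp/jax-trace"):
    x = jnp.ones((1024, 1024))
    (x @ x).block_until_ready()
```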
@@ -490,7 +490,7 @@ This section covers some of the most common patterns with JAX pytrees.

 ### Transposing pytrees with `jax.tree.map` and `jax.tree.transpose`

-To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: {func} `jax.tree.map` (more basic) and {func}`jax.tree.transpose` (more flexible, complex and verbose).
+To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: {func}`jax.tree.map` (more basic) and {func}`jax.tree.transpose` (more flexible, complex and verbose).

 **Option 1:** Use {func}`jax.tree.map`. Here's an example:
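The example itself is not included in the extracted hunk; a minimal sketch of the pattern, assuming a list of dicts with matching structure:

```python
import jax

# A list of trees (e.g., per-step records)...
steps = [{"obs": 1, "reward": 0.0}, {"obs": 2, "reward": 1.0}]

# ...becomes a tree of lists: jax.tree.map zips corresponding leaves.
transposed = jax.tree.map(lambda *leaves: list(leaves), *steps)
print(transposed)  # {'obs': [1, 2], 'reward': [0.0, 1.0]}
```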
@@ -49,7 +49,6 @@ def masked_array_error(*args, **kwargs):
     "Use arr.filled() to convert the value to a standard numpy array.")

 core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
-core.shaped_abstractify_handlers[np.ma.MaskedArray] = masked_array_error


 def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:

@@ -58,7 +57,6 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
   return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))

 core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
-core.shaped_abstractify_handlers[np.ndarray] = _make_shaped_array_for_numpy_array


 def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:

@@ -68,7 +66,6 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:

 for t in numpy_scalar_types:
   core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
-  core.shaped_abstractify_handlers[t] = _make_shaped_array_for_numpy_scalar

 core.literalable_types.update(array_types)

@@ -81,6 +78,5 @@ def _make_abstract_python_scalar(typ, val):

 for t in dtypes.python_scalar_dtypes:
   core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
-  core.shaped_abstractify_handlers[t] = partial(_make_abstract_python_scalar, t)

 core.literalable_types.update(dtypes.python_scalar_dtypes.keys())
@@ -2564,7 +2564,6 @@ def _sds_aval_mapping(x):
       x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
       weak_type=x.weak_type)
 core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping
-core.shaped_abstractify_handlers[ShapeDtypeStruct] = _sds_aval_mapping


 @api_boundary
@@ -1035,7 +1035,6 @@ def _get_aval_array(self):
   else:
     return self.aval

-core.shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
 core.pytype_aval_mappings[ArrayImpl] = _get_aval_array

 # TODO(jakevdp) replace this with true inheritance at the C++ level.
@@ -192,6 +192,14 @@ def get_compile_options(
   build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
   build_options.memory_fitting_effort = config.memory_fitting_effort.value

+  # This is a temporary workaround to simplify the AutoPGLE usage.
+  # TODO(b/376647494): Remove once the bug is fixed.
+  if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
+    logger.debug("Explicitly disabling command buffer scheduling for AutoPGLE.")
+    if env_options_overrides is None:
+      env_options_overrides = {}
+    env_options_overrides['xla_gpu_enable_command_buffer'] = ''
+
   if env_options_overrides is not None:
     # Some overrides are passed directly on build_options.
     overrides_on_build_options = [
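For context, an illustrative sketch (not part of the commit) of the configuration this workaround targets, assuming the `jax_enable_pgle` and `jax_pgle_profiling_runs` config names:

```python
import jax

# Enable automatic profile-guided latency estimation (AutoPGLE); with the
# workaround above, XLA command-buffer scheduling is disabled while the
# profiling runs are collected.
jax.config.update("jax_enable_pgle", True)
jax.config.update("jax_pgle_profiling_runs", 3)
```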
@@ -656,6 +656,13 @@ def check_bool_conversion(arr: Array):
           " is ambiguous. Use a.any() or a.all()")


+pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
+
+def _str_abstractify(x):
+  raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
+pytype_aval_mappings[str] = _str_abstractify
+
+
 def _aval_property(name):
   return property(lambda self: getattr(self.aval, name))

@@ -918,6 +925,8 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
 aval_property = namedtuple("aval_property", ["fget"])
 aval_method = namedtuple("aval_method", ["fun"])

+pytype_aval_mappings[Tracer] = lambda x: x.aval
+
 def check_eval_args(args):
   for arg in args:
     if isinstance(arg, Tracer):
@@ -1400,45 +1409,51 @@ def check_valid_jaxtype(x):
         f"Value {x!r} of type {type(x)} is not a valid JAX type")


-def _shaped_abstractify_slow(x):
-  try:
-    return x if isinstance(x, AbstractValue) else get_aval(x)
-  except TypeError:
-    pass
-
-  weak_type = getattr(x, 'weak_type', False)
-  if hasattr(x, 'dtype'):
-    dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
-  else:
-    raise TypeError(
-        f"Cannot interpret value of type {type(x)} as an abstract array; it "
-        "does not have a dtype attribute")
-  return ShapedArray(np.shape(x), dtype, weak_type=weak_type)
-
-# TODO(jakevdp): deduplicate this with abstractify
+# We have three flavors of abstractification APIs here which each used to have
+# their own separate implementation. Now they're effectively the same, with the
+# following differences:
+#
+# - abstractify returns avals for non-traced array-like objects.
+# - get_aval is like abstractify, but also accepts tracers.
+# - shaped_abstractify is like get_aval, but also accepts duck-typed arrays.
+#
+# TODO(jakevdp): can these be unified further?
+
 def shaped_abstractify(x):
-  # This was originally api_util.shaped_abstractify; temporarily moved
-  # here in order to facilitate combining it with abstractify.
-  handler = shaped_abstractify_handlers.get(type(x), None)
-  return handler(x) if handler is not None else _shaped_abstractify_slow(x)
+  typ = type(x)
+  if (aval_fn := pytype_aval_mappings.get(typ)):  # fast path
+    return aval_fn(x)
+  for t in typ.__mro__[1:]:
+    if (aval_fn := pytype_aval_mappings.get(t)):
+      return aval_fn(x)
+  if isinstance(x, AbstractValue):
+    return x
+  if hasattr(x, '__jax_array__'):
+    return shaped_abstractify(x.__jax_array__())
+  if hasattr(x, 'dtype'):
+    return ShapedArray(np.shape(x), x.dtype, weak_type=getattr(x, 'weak_type', False))
+  raise TypeError(
+      f"Cannot interpret value of type {typ} as an abstract array; it "
+      "does not have a dtype attribute")


 def abstractify(x):
-  for typ in type(x).__mro__:
-    aval_fn = pytype_aval_mappings.get(typ)
-    if aval_fn: return aval_fn(x)
-  if hasattr(x, '__jax_array__'):
-    return abstractify(x.__jax_array__())
-  raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
+  if isinstance(x, Tracer):
+    raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
+  return get_aval(x)


 def get_aval(x):
-  if isinstance(x, Tracer):
-    return x.aval
-  else:
-    return abstractify(x)
+  typ = type(x)
+  if (aval_fn := pytype_aval_mappings.get(typ)):  # fast path
+    return aval_fn(x)
+  for t in typ.__mro__[1:]:
+    if (aval_fn := pytype_aval_mappings.get(t)):
+      return aval_fn(x)
+  if hasattr(x, '__jax_array__'):
+    return get_aval(x.__jax_array__())
+  raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")

 get_type = get_aval

 def is_concrete(x):
   return to_concrete_value(x) is not None
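As a rough illustration of the consolidated registry lookup (a sketch, not from the commit; it uses the internal `jax._src.core` module purely for demonstration):

```python
import numpy as np
from jax._src import core  # internal API, shown here only for illustration

# get_aval maps a concrete value to its abstract value via the single
# pytype_aval_mappings registry; dtype canonicalization may apply.
print(core.get_aval(np.ones((2, 3), np.float32)))  # float32[2,3] aval
```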
@@ -1831,13 +1846,6 @@ class DShapedArray(UnshapedArray):
     return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
                         self.weak_type)

-pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
-shaped_abstractify_handlers: dict[Any, Callable[[Any], AbstractValue]] = {}
-
-def _str_abstractify(x):
-  raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
-pytype_aval_mappings[str] = _str_abstractify
-shaped_abstractify_handlers[str] = _str_abstractify

 class DArray:
   _aval: DShapedArray
@@ -1894,7 +1902,6 @@ def _darray_aval(x):
   return DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)

 pytype_aval_mappings[DArray] = _darray_aval
-shaped_abstractify_handlers[DArray] = _darray_aval


 @dataclass(frozen=True)
@@ -1924,11 +1931,10 @@ class MutableArray:
   aval = property(lambda self: self._aval)
   shape = property(lambda self: self._aval.shape)
   dtype = property(lambda self: self._aval.dtype)
-  def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
-  def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
+  def __getitem__(self, idx): return self._aval._getitem(self, idx)
+  def __setitem__(self, idx, x): return self._aval._setitem(self, idx, x)
   def __repr__(self) -> str: return 'Mutable' + repr(self[...])
 pytype_aval_mappings[MutableArray] = lambda x: x._aval
-shaped_abstractify_handlers[MutableArray] = lambda x: x._aval

 def mutable_array(init_val):
   return mutable_array_p.bind(init_val)
@@ -1984,7 +1990,6 @@ class Token:
   def block_until_ready(self):
     self._buf.block_until_ready()
 pytype_aval_mappings[Token] = lambda _: abstract_token
-shaped_abstractify_handlers[Token] = lambda _: abstract_token


 # TODO(dougalm): Deprecate these. They're just here for backwards compat.
@@ -348,11 +348,20 @@ def check_is_flash_attention(
     )
   else:
     # Regular attention conditions
-    if not ((H <= 128 and H % 8 == 0) and
-            (not is_training or not has_bias or T % 2 == 0 and S % 2 == 0)):
-      raise NotImplementedError(
-        f"Unsupported sequence length Q {T}, KV {S} and head dim {H}."
-      )
+    # Check the head dim.
+    is_on_hopper = check_compute_capability("9.0")
+    H_max = 256 if cudnn_version >= 90500 and is_on_hopper else 128
+    if not (H <= H_max and H % 8 == 0):
+      raise NotImplementedError(
+          f"The head dim must be <= {H_max} and a multiple of 8, "
+          f"but got {H}."
+      )
+
+    # Check patterns with bias, seqlen should be divisible by 2
+    if (is_training and has_bias and (T % 2 != 0 or S % 2 != 0)):
+      raise NotImplementedError(
+          f"Unsupported sequence length Q {T}, KV {S}."
+      )

 def check_cudnn_version():
   # check if cuDNN is installed
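In short, the new head-dim rule distilled into a sketch (illustrative only; `90500` encodes cuDNN 9.5.0):

```python
def max_flash_attention_head_dim(is_on_hopper: bool, cudnn_version: int) -> int:
    # Hopper GPUs with cuDNN >= 9.5.0 allow head dims up to 256, else 128;
    # in both cases the head dim must also be a multiple of 8.
    return 256 if cudnn_version >= 90500 and is_on_hopper else 128
```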
@@ -300,8 +300,9 @@ class _DebugPrintFormatChecker(string.Formatter):

 formatter = _DebugPrintFormatChecker()

-def _format_print_callback(fmt: str, *args, **kwargs):
-  sys.stdout.write(fmt.format(*args, **kwargs) + "\n")
+def _format_print_callback(fmt: str, np_printoptions, *args, **kwargs):
+  with np.printoptions(**np_printoptions):
+    sys.stdout.write(fmt.format(*args, **kwargs) + "\n")

 def debug_print(fmt: str, *args, ordered: bool = False, **kwargs) -> None:
   """Prints values and works in staged out JAX functions.

@@ -338,8 +339,8 @@ def debug_print(fmt: str, *args, ordered: bool = False, **kwargs) -> None:
   # Check that we provide the correct arguments to be formatted.
   formatter.format(fmt, *args, **kwargs)

-  debug_callback(functools.partial(_format_print_callback, fmt), *args,
-                 **kwargs, ordered=ordered)
+  debug_callback(functools.partial(_format_print_callback, fmt, np.get_printoptions()),
+                 *args, **kwargs, ordered=ordered)


 # Sharding visualization
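The effect of this change, sketched (illustrative, not from the commit): `jax.debug.print` now honors the NumPy print options in effect where it is traced.

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("x = {x}", x=x)
  return x

# The print options captured here are applied when the callback formats x.
with np.printoptions(precision=2, suppress=True):
  f(jnp.array([1.23456789, 2.3456789]))
```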
@@ -115,7 +115,6 @@ def _earray_shard_arg_handler(xs, shardings, layouts, copy_semantics):
   return pxla.shard_args(phys_shardings, layouts, copy_semantics, arrs)
 pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler

-core.shaped_abstractify_handlers[EArray] = lambda self: self.aval
 core.pytype_aval_mappings[EArray] = lambda x: x.aval
 xla.canonicalize_dtype_handlers[EArray] = lambda x: x
 tree_util.dispatch_registry.register_node(
@@ -1205,7 +1205,6 @@ def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool:
     f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}")

 core.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
-core.shaped_abstractify_handlers[_DimExpr] = _DimExpr._get_aval
 dtypes._weak_types.append(_DimExpr)

 def _convertible_to_int(p: DimSize) -> bool:
@@ -1569,10 +1569,7 @@ class DynamicJaxprTracer(core.Tracer):
     val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self)))
     return self if val is None else get_referent(val)


-def _dynamic_jaxpr_tracer_shaped_abstractify(x):
-  return x.aval
-core.shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify
 core.pytype_aval_mappings[DynamicJaxprTracer] = lambda x: x.aval

 def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
   sentinel = object()
@@ -710,7 +710,6 @@ def one_hot(x: Any, num_classes: int, *,
         'jax-nn-one-hot-float-input',
         f"jax.nn.one_hot input should be integer-typed; got dtype={x_arr.dtype}",
         stacklevel=1)
     x_arr = x_arr.astype('int32')
   return _one_hot(x_arr, num_classes, dtype=dtype, axis=axis)
@@ -192,7 +192,6 @@ class _ScalarMeta(type):
 def _abstractify_scalar_meta(x):
   raise TypeError(f"JAX scalar type {x} cannot be interpreted as a JAX array.")
 core.pytype_aval_mappings[_ScalarMeta] = _abstractify_scalar_meta
-core.shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta

 def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
   meta = _ScalarMeta(np_scalar_type.__name__, (object,),
@@ -924,7 +924,7 @@ def pinv(a: ArrayLike, rtol: ArrayLike | None = None,
     - :func:`jax.numpy.linalg.inv`: multiplicative inverse of a square matrix.

   Notes:
-    :func:`jax.numpy.linalg.prng` differs from :func:`numpy.linalg.prng` in the
+    :func:`jax.numpy.linalg.pinv` differs from :func:`numpy.linalg.pinv` in the
     default value of ``rcond``: in NumPy, the default is `1e-15`. In JAX, the
     default is ``10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps``.
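A quick check of the documented behavior (illustrative, not from the commit):

```python
import jax.numpy as jnp

a = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
a_pinv = jnp.linalg.pinv(a)  # uses the JAX default rcond described above
# The defining pseudoinverse property a @ pinv(a) @ a == a holds.
print(jnp.allclose(a @ a_pinv @ a, a, atol=1e-5))  # True
```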
@@ -2765,6 +2765,12 @@ def log2(x: ArrayLike, /) -> Array:
     Array([-2., -1.,  0.,  1.,  2.,  3.], dtype=float32)
   """
   x, = promote_args_inexact("log2", x)
+  if dtypes.issubdtype(x.dtype, np.complexfloating):
+    r = lax.log(x)
+    re = lax.real(r)
+    im = lax.imag(r)
+    ln2 = lax.log(_constant_like(re, 2))
+    return lax.complex(lax.div(re, ln2), lax.div(im, ln2))
   return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))


@@ -2789,6 +2795,12 @@ def log10(x: ArrayLike, /) -> Array:
     [-2. -1.  0.  1.  2.  3.]
   """
   x, = promote_args_inexact("log10", x)
+  if dtypes.issubdtype(x.dtype, np.complexfloating):
+    r = lax.log(x)
+    re = lax.real(r)
+    im = lax.imag(r)
+    ln10 = lax.log(_constant_like(re, 10))
+    return lax.complex(lax.div(re, ln10), lax.div(im, ln10))
   return lax.div(lax.log(x), lax.log(_constant_like(x, 10)))
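Both branches implement the same identity, log_b(z) = log(z) / log(b), applied separately to the real and imaginary parts of log(z). A quick sanity check (illustrative, not from the commit):

```python
import jax.numpy as jnp

z = jnp.array(1 + 1j, dtype=jnp.complex64)
# log(1+1j) = ln(sqrt(2)) + i*pi/4; dividing by ln(2) gives ~0.5 + 1.133j.
print(jnp.log2(z))
print(jnp.log(z) / jnp.log(jnp.array(2.0)))  # same value
```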
@@ -84,6 +84,8 @@ def _get_memory_space_from_aval(
       return None
     case tpu_core.TPUMemorySpace.VMEM:
       return tpu_custom_call.MemorySpace.VMEM
     case tpu_core.TPUMemorySpace.SMEM:
       return tpu_custom_call.MemorySpace.SMEM
+    case tpu_core.TPUMemorySpace.SEMAPHORE:
+      return tpu_custom_call.MemorySpace.SEMAPHORE_MEM
   return None
@@ -797,25 +797,30 @@ def lower_jaxpr_to_module(
     # Each range is 2 events, each event is 4 bytes.
     prof_spec = mgpu_profiler.ProfilerSpec(prof_space * 2 * 4)
     prof_ctx = ProfilerContext(params["profile_dir"], prof_spec)
-  module, out_structs_gmem, _ = mgpu_core._lower_as_gpu_kernel(
-      body,
-      grid=parallel_grid,
-      cluster=(),
-      block=block,
-      in_shapes=in_structs_gmem,
-      out_shape=out_structs_gmem,
-      smem_scratch_shape=(
-          (*in_structs_smem, *out_structs_smem),
-          *extra_smem_scratch,
-          (
-              mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps),
-              rs.barriers,
-              extra_barriers,
-          ),
-      ),
-      module_name=name_and_src_info.name,
-      prof_spec=prof_spec,
-  )
+  module, out_structs_gmem, _, launch_ctx, scratch_arr = (
+      mgpu_core._lower_as_gpu_kernel(
+          body,
+          grid=parallel_grid,
+          cluster=(),
+          block=block,
+          in_shapes=in_structs_gmem,
+          out_shape=out_structs_gmem,
+          smem_scratch_shape=(
+              (*in_structs_smem, *out_structs_smem),
+              *extra_smem_scratch,
+              (
+                  mgpu.Barrier(
+                      arrival_count=1, num_barriers=max_concurrent_steps
+                  ),
+                  rs.barriers,
+                  extra_barriers,
+              ),
+          ),
+          module_name=name_and_src_info.name,
+          prof_spec=prof_spec,
+      )
+  )
+  mgpu_core._initialize_scratch(launch_ctx, scratch_arr)

   return LoweringResult(
       module, parallel_grid, block, out_structs_gmem, prof_ctx
@@ -1782,29 +1787,21 @@ def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray:
   if isinstance(x, mgpu.FragmentedArray):
     assert x.mlir_dtype == mgpu_utils.dtype_to_ir_type(dtype)
     return x
-  elif isinstance(x, (np.number, np.ndarray, int, float)):
-    return mgpu.FragmentedArray.splat(
-        _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype)),
-        (),
-        is_signed=mgpu_utils.is_signed(dtype),
-    )
-  elif isinstance(x, ir.Value):
-    if isinstance(x.type, (ir.IntegerType, ir.FloatType, ir.IndexType)):
-      assert x.type == mgpu_utils.dtype_to_ir_type(dtype)
-      return mgpu.FragmentedArray.splat(x, (), is_signed=mgpu_utils.is_signed(dtype))
-  raise NotImplementedError(f"Unsupported type: {type(x)}")
+  return mgpu.FragmentedArray.splat(
+      _ensure_ir_value(x, dtype), (), is_signed=mgpu_utils.is_signed(dtype)
+  )


 def _ensure_ir_value(x: object, dtype: jnp.dtype) -> ir.Value:
   if isinstance(x, ir.Value):
     assert x.type == mgpu_utils.dtype_to_ir_type(dtype)
     return x
-  elif isinstance(x, (np.number, np.ndarray, int, float)):
-    return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype))
-  raise NotImplementedError(f"Unsupported type: {type(x)}")
+  elif isinstance(x, mgpu.FragmentedArray):
+    assert x.mlir_dtype == mgpu_utils.dtype_to_ir_type(dtype)
+    if isinstance(x.layout, mgpu.WGSplatFragLayout):
+      return x.registers.item()
+    raise NotImplementedError(f"Unsupported layout: {x.layout}")
+  return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype))


 def _ir_constant(v: object, t: ir.Type) -> ir.Value:
@@ -16,6 +16,7 @@

 from __future__ import annotations

+from collections.abc import Sequence
 import enum
 import math
 from typing import Any, Literal

@@ -25,7 +26,6 @@ from jax._src import core as jax_core
 from jax._src import state
 from jax._src import tree_util
 from jax._src import util
-from jax._src.interpreters import mlir
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import arith as arith_dialect
 from jax._src.lib.mlir.dialects import llvm as llvm_dialect
@@ -33,23 +33,42 @@ from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect
 from jax._src.pallas import core as pallas_core
 from jax._src.pallas.mosaic_gpu import core as gpu_core
 from jax._src.pallas.mosaic_gpu import lowering
+from jax._src.pallas.mosaic_gpu.core import state_types
 from jax._src.state import discharge
 from jax._src.state import indexing
 from jax._src.state import primitives as state_primitives
-import jax.experimental.mosaic.gpu as mgpu
+from jax.experimental.mosaic import gpu as mgpu
+from jax.experimental.mosaic.gpu import utils as mgpu_utils
 import jax.numpy as jnp


 WARPGROUP_SIZE = 128


+_Ref = pallas_core.AbstractMemoryRef | state_types.TransformedRef
+
+
+def _check_ref(
+    aval: object, name: str, memory_space: gpu_core.GPUMemorySpace
+) -> None:
+  if not isinstance(aval, state_types.AbstractRef):
+    raise TypeError(f"{name} must be a reference, got {aval}")
+  aval_memory_space = getattr(aval, "memory_space", None) or gpu_core.GMEM
+  if aval_memory_space is not memory_space:
+    raise ValueError(
+        f"{name} must be a {memory_space.name.upper()} reference, got {aval}"
+    )
+
+
 copy_smem_to_gmem_p = jax_core.Primitive("copy_smem_to_gmem")
 copy_smem_to_gmem_p.multiple_results = True


 @copy_smem_to_gmem_p.def_effectful_abstract_eval
-def _copy_smem_to_gmem_abstract_eval(*avals, **params):
-  del avals, params  # Unused.
+def _copy_smem_to_gmem_abstract_eval(src, dst, *args, **params):
+  _check_ref(src, "src", gpu_core.SMEM)
+  _check_ref(dst, "dst", gpu_core.GMEM)
+  del args, params  # Unused.
   return (), {state.ReadEffect(0), state.WriteEffect(1)}
@@ -114,9 +133,7 @@ def _extract_smem_copy_params(transforms):


 def copy_smem_to_gmem(
-    src: pallas_core.AbstractMemoryRef,
-    dst: pallas_core.AbstractMemoryRef,
-    predicate: jax.Array | None = None,
+    src: _Ref, dst: _Ref, predicate: jax.Array | None = None
 ) -> None:
   """Asynchronously copies a SMEM reference to a GMEM reference.

@@ -130,10 +147,6 @@ def copy_smem_to_gmem(
     :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`
     :func:`jax.experimental.mosaic.gpu.commit_smem`
   """
-  if src.memory_space is not gpu_core.SMEM:
-    raise TypeError(f"src must be a SMEM reference, got {src.memory_space}")
-  if getattr(dst, "memory_space", gpu_core.GMEM) is not gpu_core.GMEM:
-    raise ValueError(f"dst must be a GMEM reference, got {dst.memory_space}")
   src, src_transforms = state_primitives.get_ref_and_transforms(
       src, None, "copy_smem_to_gmem", force_trailing_indexer=False,
   )
@@ -164,8 +177,11 @@ copy_gmem_to_smem_p.multiple_results = True


 @copy_gmem_to_smem_p.def_effectful_abstract_eval
-def _copy_gmem_to_smem_abstract_eval(*avals, **params):
-  del avals, params  # Unused.
+def _copy_gmem_to_smem_abstract_eval(src, dst, barrier, *args, **params):
+  del args, params  # Unused.
+  _check_ref(src, "src", gpu_core.GMEM)
+  _check_ref(dst, "dst", gpu_core.SMEM)
+  _check_ref(barrier, "barrier", gpu_core.SMEM)
   return (), {state.ReadEffect(0), state.WriteEffect(1)}
@@ -217,21 +233,13 @@ def _copy_gmem_to_smem_lowering(
   return ()


-def copy_gmem_to_smem(
-    src: pallas_core.AbstractMemoryRef,
-    dst: pallas_core.AbstractMemoryRef,
-    barrier: pallas_core.AbstractMemoryRef,
-) -> None:
+def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None:
   """Asynchronously copies a GMEM reference to a SMEM reference.

   See also:
     :func:`jax.experimental.mosaic.gpu.barrier_arrive`
     :func:`jax.experimental.mosaic.gpu.barrier_wait`
   """
-  if getattr(src, "memory_space", gpu_core.GMEM) is not gpu_core.GMEM:
-    raise TypeError(f"src must be a GMEM reference, got {src.memory_space}")
-  if dst.memory_space is not gpu_core.SMEM:
-    raise ValueError(f"dst must be a SMEM reference, got {dst.memory_space}")
   src, src_transforms = state_primitives.get_ref_and_transforms(
       src, None, "copy_gmem_to_smem", force_trailing_indexer=False,
   )
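For orientation, a hedged sketch (not from the commit) of the copy/barrier pattern these primitives implement inside a kernel body, assuming they are exposed via a `plgpu` module alias:

```python
# Inside a Pallas Mosaic GPU kernel body (names assumed for illustration):
#   plgpu.copy_gmem_to_smem(x_gmem, x_smem, barrier)  # start the async copy
#   plgpu.barrier_wait(barrier)                       # wait until data lands
#   ... compute on x_smem ...
#   plgpu.copy_smem_to_gmem(o_smem, o_gmem)           # write results back
#   plgpu.wait_smem_to_gmem(0)                        # drain outstanding copies
```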
@@ -291,8 +299,9 @@ barrier_arrive_p.multiple_results = True


 @barrier_arrive_p.def_effectful_abstract_eval
-def _barrier_arrive_abstract_eval(*avals, **params):
-  del avals, params  # Unused.
+def _barrier_arrive_abstract_eval(barrier, *args, **params):
+  del args, params  # Unused.
+  _check_ref(barrier, "barrier", gpu_core.SMEM)
   return (), {gpu_core._memory_effect}
@@ -328,8 +337,9 @@ barrier_wait_p.multiple_results = True


 @barrier_wait_p.def_effectful_abstract_eval
-def _barrier_wait_abstract_eval(*avals, **params):
-  del avals, params  # Unused.
+def _barrier_wait_abstract_eval(barrier, *args, **params):
+  _check_ref(barrier, "barrier", gpu_core.SMEM)
+  del args, params  # Unused.
   return (), {gpu_core._memory_effect}
@@ -703,38 +713,40 @@ def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout):
   del layout, dimension
   return jax_core.ShapedArray(shape, dtype)


 @lowering.register_lowering_rule(broadcasted_iota_p)
-def _broadcasted_iota_lowering(ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout):
-  del ctx
-  # Unsigned integers (as opposed to signless) cause MLIR verification
-  # errors so we only use signless like Mosaic GPU does.
-  #
-  # TODO(cperivol): use mgpu.utils.dtype_to_ir_type() instead.
-  mlir_dtype = (
-      ir.IntegerType.get_signless(dtype.itemsize * 8)
-      if jnp.issubdtype(dtype, jnp.integer)
-      else mlir.dtype_to_ir_type(dtype)
-  )
-  undef = llvm_dialect.mlir_undef(mlir_dtype)
-  is_signed = (
-      jnp.issubdtype(dtype, jnp.signedinteger)
-      if jnp.issubdtype(dtype, jnp.integer)
-      else None
-  )
-
-  i32 = ir.IntegerType.get_signless(32)
-  def _cast(x):
-    if ir.FloatType.isinstance(mlir_dtype):
-      x = arith_dialect.index_cast(i32, x)
-      return arith_dialect.uitofp(mlir_dtype, x)
-    else:
-      return arith_dialect.index_cast(mlir_dtype, x)
+def _broadcasted_iota_lowering(
+    ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout
+):
+  del ctx  # Unused.
+  mlir_dtype = mgpu_utils.dtype_to_ir_type(dtype)
+  if ir.FloatType.isinstance(mlir_dtype):
+    i32 = ir.IntegerType.get_signless(32)
+    cast = lambda x: arith_dialect.uitofp(
+        mlir_dtype, arith_dialect.index_cast(i32, x)
+    )
+  else:
+    cast = lambda x: arith_dialect.index_cast(mlir_dtype, x)
+  is_signed = mgpu_utils.is_signed(dtype)
   return mgpu.FragmentedArray.splat(
-      undef, shape, layout.value, is_signed=is_signed
+      llvm_dialect.mlir_undef(mlir_dtype),
+      shape,
+      layout.value,
+      is_signed=is_signed,
   ).foreach(
-      lambda _, idx: _cast(idx[dimension]), create_array=True, is_signed=is_signed
+      lambda _, idx: cast(idx[dimension]),
+      create_array=True,
+      is_signed=is_signed,
   )


-def broadcasted_iota(dtype, shape, dimension, *, layout: Layout | None = None):
-  return broadcasted_iota_p.bind(dtype=jnp.dtype(dtype), shape=shape, dimension=dimension, layout=layout)
+def broadcasted_iota(
+    dtype: jax.typing.DTypeLike,
+    shape: Sequence[int],
+    dimension: int,
+    *,
+    layout: Layout | None = None,
+) -> jax.Array:
+  return broadcasted_iota_p.bind(
+      dtype=jnp.dtype(dtype), shape=shape, dimension=dimension, layout=layout
+  )
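For intuition (illustrative, not from the commit), the primitive mirrors `jax.lax.broadcasted_iota`, which fills an array with indices along one dimension:

```python
import jax
import jax.numpy as jnp

# Index along dimension 1, broadcast to shape (2, 3).
print(jax.lax.broadcasted_iota(jnp.int32, (2, 3), 1))
# [[0 1 2]
#  [0 1 2]]
```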
@@ -461,8 +461,6 @@ class KeyTy(dtypes.ExtendedDType):


 core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
-core.shaped_abstractify_handlers[PRNGKeyArray] = op.attrgetter('aval')

 xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
@@ -249,8 +249,8 @@ class XlaExecutable(Executable):
     else:
       raise

-  # TODO(skyewm): this should return a single dict (I think returning a list
-  # was to support MPMD executables, which never fully landed)
+  # TODO(b/384741132): this should return a single dict (I think returning a list
+  # was to support MPMD executables, which never fully landed).
   def cost_analysis(self) -> list[dict[str, float]]:
     xla_ext_exe = self.xla_extension_executable()

@@ -266,9 +266,19 @@ class XlaExecutable(Executable):
     # Try client method if executable cost_analysis method is unimplemented
     if hasattr(xla_ext_exe, "client"):
       try:
+        # TODO(b/384741132): We expect that the executable has only one
+        # HloModule. We should be able to remove this check once we update the
+        # Executable class to have only a single HloModule (see bug).
+        hlo_modules = xla_ext_exe.hlo_modules()
+        assert len(hlo_modules) == 1, (
+            f"Executable should have only one HloModule ({len(hlo_modules)})"
+            " were found)."
+        )
+
         return [
-            xla_extension.hlo_module_cost_analysis(xla_ext_exe.client, m)
-            for m in xla_ext_exe.hlo_modules()
+            xla_extension.hlo_module_cost_analysis(
+                xla_ext_exe.client, hlo_modules[0]
+            )
         ]
       except xla_extension.XlaRuntimeError as e:
         msg, *_ = e.args
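For reference (illustrative, not from the commit), this code path backs the public `cost_analysis` on compiled computations:

```python
import jax
import jax.numpy as jnp

compiled = jax.jit(lambda x: x @ x).lower(jnp.ones((8, 8))).compile()
# Returns XLA FLOP/byte estimates; per the TODO above, at this point it is
# a list holding a single dict (one HloModule per executable).
print(compiled.cost_analysis())
```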
@@ -83,6 +83,7 @@ class MemorySpace(enum.Enum):
   HBM = enum.auto()
   VMEM = enum.auto()
   SEMAPHORE_MEM = enum.auto()
+  SMEM = enum.auto()

   @property
   def color(self) -> int:

@@ -92,6 +93,8 @@ class MemorySpace(enum.Enum):
       return 1
     elif self == MemorySpace.SEMAPHORE_MEM:
       return 2
+    elif self == MemorySpace.SMEM:
+      return 4
     else:
       raise ValueError("invalid memory space: " + str(self))
@@ -128,7 +128,7 @@ _deprecations = {
                              _src_core.escaped_tracer_error),
   "extend_axis_env_nd": ("jax.core.extend_axis_env_nd is deprecated.",
                          _src_core.extend_axis_env_nd),
-  "get_type": ("jax.core.get_type is deprecated.", _src_core.get_type),
+  "get_type": ("jax.core.get_type is deprecated.", _src_core.get_aval),
   "get_referent": ("jax.core.get_referent is deprecated.", _src_core.get_referent),
   "join_effects": ("jax.core.join_effects is deprecated.", _src_core.join_effects),
   "leaked_tracer_error": ("jax.core.leaked_tracer_error is deprecated.",

@@ -212,7 +212,7 @@ if typing.TYPE_CHECKING:
   escaped_tracer_error = _src_core.escaped_tracer_error
   extend_axis_env_nd = _src_core.extend_axis_env_nd
   full_lower = _src_core.full_lower
-  get_type = _src_core.get_type
+  get_type = _src_core.get_aval
   get_referent = _src_core.get_referent
   jaxpr_as_fun = _src_core.jaxpr_as_fun
   join_effects = _src_core.join_effects
@@ -724,7 +724,6 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
                aval_in, aval_out, x):
   if prod([size for n, size in mesh.shape.items() if n not in auto]) == 1:
     return x
-  manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
   axes = {name: i for i, ns in names.items() for name in ns}
   ns = _make_scoped_manual_sharding(ctx, mesh, axes)
   if dtypes.issubdtype(aval_in.dtype, dtypes.extended):

@@ -734,6 +733,7 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
   unspecified = set(range(aval_in.ndim)) if auto else set()
   sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto,
                                   unspecified_dims=unspecified)
+  manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
   return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, unspecified)

 def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,

@@ -746,6 +746,8 @@ def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
   ns = sharding_impls.physical_sharding(aval_out, ns)
   aval_out = core.physical_aval(aval_out)
   unspecified = set(range(aval_out.ndim)) if auto else set()
+  if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
+    aval_in = core.physical_aval(aval_in)
   manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
   sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=unspecified)
   shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto()
@@ -36,17 +36,17 @@ _deprecations = {
     "jax.lib.xla_client.bfloat16 was removed in JAX v0.4.38; use ml_dtypes.bfloat16.",
     None,
   ),
   # Added Sep 26 2024
+  # Finalized 2024-12-23; remove after 2025-03-23
   "Device": (
     "jax.lib.xla_client.Device is deprecated; use jax.Device instead.",
-    _xc.Device,
+    None,
   ),
   "XlaRuntimeError": (
     (
       "jax.lib.xla_client.XlaRuntimeError is deprecated; use"
       " jax.errors.JaxRuntimeError."
     ),
-    _xc.XlaRuntimeError,
+    None,
   ),
   # Added Oct 10 2024
   "FftType": (
@@ -70,8 +70,8 @@ py_extension(
         ":jaxlib_mlir_capi_shared_library",
         "@llvm-project//mlir:CAPIGPUHeaders",
         "@llvm-project//mlir:CAPIIRHeaders",
-        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
-        "@pybind11",
+        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
+        "@nanobind",
     ],
 )

@@ -85,7 +85,8 @@ py_extension(
     deps = [
         ":jaxlib_mlir_capi_shared_library",
         "@llvm-project//mlir:CAPIGPUHeaders",
-        "@pybind11",
+        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
+        "@nanobind",
     ],
 )

@@ -100,8 +101,8 @@ py_extension(
         ":jaxlib_mlir_capi_shared_library",
         "@llvm-project//mlir:CAPIIRHeaders",
         "@llvm-project//mlir:CAPINVGPUHeaders",
-        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
-        "@pybind11",
+        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
+        "@nanobind",
     ],
 )

@@ -116,8 +117,8 @@ py_extension(
         ":jaxlib_mlir_capi_shared_library",
         "@llvm-project//mlir:CAPIIRHeaders",
         "@llvm-project//mlir:CAPILLVMHeaders",
-        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
-        "@pybind11",
+        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
+        "@nanobind",
     ],
 )

@@ -131,8 +132,8 @@ py_extension(
     deps = [
         ":jaxlib_mlir_capi_shared_library",
         "@llvm-project//mlir:CAPISparseTensorHeaders",
-        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
-        "@pybind11",
+        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
+        "@nanobind",
     ],
 )

@@ -146,7 +147,8 @@ py_extension(
     deps = [
         ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPISparseTensorHeaders",
-        "@pybind11",
+        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
+        "@nanobind",
     ],
 )

@@ -156,9 +158,10 @@ py_extension(
     copts = COPTS,
     linkopts = LINKOPTS,
     deps = [
-        "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi",
         ":jaxlib_mlir_capi_shared_library",
+        "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_headers",
         "@llvm-project//mlir:CAPIIRHeaders",
-        "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
+        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps",
         "@nanobind",
     ],
 )

@@ -378,6 +381,7 @@ cc_library(
     name = "jaxlib_mlir_capi_objects",
     deps = [
         "//jaxlib/mosaic:tpu_dialect_capi_objects",
+        "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_objects",
         "@llvm-project//mlir:CAPIArithObjects",
         "@llvm-project//mlir:CAPIGPUObjects",
         "@llvm-project//mlir:CAPIIRObjects",
@@ -28,6 +28,7 @@ package(
 py_library(
     name = "mosaic",
     deps = [
+        "//jaxlib/mosaic/python:gpu_dialect",
         "//jaxlib/mosaic/python:tpu_dialect",
     ],
 )

@@ -42,6 +43,7 @@ cc_library(
         "dialect/tpu/tpu_dialect.cc",
         "dialect/tpu/tpu_ops.cc",
         "dialect/tpu/util.cc",
+        "dialect/tpu/vreg_util.cc",
         ":extension_srcs",
     ] + glob([
         "dialect/tpu/transforms/*.cc",

@@ -50,6 +52,7 @@ cc_library(
         "dialect/tpu/layout.h",
         "dialect/tpu/tpu_dialect.h",
         "dialect/tpu/util.h",
+        "dialect/tpu/vreg_util.h",
     ] + glob([
         "dialect/tpu/transforms/*.h",
     ]),

@@ -231,6 +234,19 @@ cc_library(
     alwayslink = True,
 )

+cc_test(
+    name = "vreg_util_test",
+    srcs = ["dialect/tpu/vreg_util_test.cc"],
+    deps = [
+        ":tpu_dialect",
+        "//testing/base/public:gunit_main",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:VectorDialect",
+    ],
+)
+
 filegroup(
     name = "extension_srcs",
     srcs = [
@@ -215,3 +215,26 @@ cc_library(
         "@llvm-project//mlir:CAPIIR",
     ],
 )
+
+# Header-only target, used when using the C API from a separate shared library.
+cc_library(
+    name = "gpu_dialect_capi_headers",
+    hdrs = DIALECT_CAPI_HEADERS,
+    deps = [
+        ":mosaic_gpu_inc_gen",
+        "@llvm-project//mlir:CAPIIRHeaders",
+    ],
+)
+
+# Alwayslink target, used when exporting the C API from a shared library.
+cc_library(
+    name = "gpu_dialect_capi_objects",
+    srcs = DIALECT_CAPI_SOURCES,
+    hdrs = DIALECT_CAPI_HEADERS,
+    deps = [
+        ":mosaic_gpu",
+        ":mosaic_gpu_inc_gen",
+        "@llvm-project//mlir:CAPIIRObjects",
+    ],
+    alwayslink = True,
+)
@@ -19,8 +19,8 @@ limitations under the License.
 #include <optional>
 #include <string>

-#include "testing/base/public/gmock.h"
-#include "testing/base/public/gunit.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
 #include "absl/container/flat_hash_set.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
@@ -29,7 +29,6 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"

@@ -52,6 +51,7 @@
 #include "absl/status/status.h"
 #include "absl/types/span.h"
 #include "llvm/include/llvm/ADT/APInt.h"
+#include "llvm/include/llvm/Support/LogicalResult.h"
 #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"

@@ -64,6 +64,7 @@
 #include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h"
 #include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h"
 #include "jaxlib/mosaic/dialect/tpu/util.h"
+#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
 #include "xla/array.h"
 #include "xla/layout.h"
 #include "xla/util.h"
@@ -275,16 +276,6 @@ void updateSliceFromRange(xla::Array<T> &arr, Range data,
   CHECK(data_it == data.end());
 }

-FailureOr<TypedAttr> getZeroIntOrFloatAttr(Type ty) {
-  if (isa<FloatType>(ty)) {
-    return TypedAttr(FloatAttr::get(ty, 0));
-  }
-  if (isa<IntegerType>(ty)) {
-    return TypedAttr(IntegerAttr::get(ty, 0));
-  }
-  return emitError(UnknownLoc::get(ty.getContext()), "Not implemented: ") << ty;
-}
-
 FailureOr<int64_t> getIntConst(Value v, bool silent = false) {
   if (auto constant_op = v.getDefiningOp<arith::ConstantOp>()) {
     if (auto integer_attr = dyn_cast<IntegerAttr>(constant_op.getValue())) {

@@ -479,33 +470,6 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx, func::FuncOp func,
   return argument;
 }

-VectorType getNativeVregOrVmaskTypeImpl(
-    Type elem_ty, const int8_t bitwidth,
-    const std::array<int64_t, 2> target_shape) {
-  if (bitwidth == 32) {
-    return VectorType::get(target_shape, elem_ty);
-  }
-  return VectorType::get({target_shape[0], target_shape[1], 32 / bitwidth},
-                         elem_ty);
-}
-
-VectorType getNativeVregOrVmaskType(Type elem_ty, const int8_t layout_bitwidth,
-                                    const std::array<int64_t, 2> target_shape) {
-  int8_t bitwidth = elem_ty.getIntOrFloatBitWidth();
-  if (bitwidth == 1) {
-    bitwidth = layout_bitwidth;
-  } else {
-    CHECK_EQ(bitwidth, layout_bitwidth);
-  }
-  return getNativeVregOrVmaskTypeImpl(elem_ty, bitwidth, target_shape);
-}
-
-VectorType getNativeVregType(Type elem_ty,
-                             const std::array<int64_t, 2> target_shape) {
-  return getNativeVregOrVmaskTypeImpl(elem_ty, elem_ty.getIntOrFloatBitWidth(),
-                                      target_shape);
-}
-
 // Masks all values outside of bounds.
 //
 // Arguments:
@@ -518,7 +482,7 @@ VectorType getNativeVregType(Type elem_ty,
 // Returns:
 //   An MLIR value of the same type as the value argument, with all entries
 //   outside of bounds replaced by neutral.
-FailureOr<Value> maskOOB(RewriteContext &ctx, OpBuilder &builder,
+FailureOr<Value> maskOOB(RewriteContext &ctx, ImplicitLocOpBuilder &builder,
                          TypedValue<VectorType> value,
                          const VRegDataBounds &bounds,
                          const Attribute neutral) {

@@ -542,9 +506,7 @@ FailureOr<Value> maskOOB(RewriteContext &ctx, OpBuilder &builder,
         value.getLoc(),
         VectorType::get(native_vreg_ty.getShape(), builder.getI1Type()), mask);
   }
-  auto neutral_vec = builder.create<arith::ConstantOp>(
-      value.getLoc(), native_vreg_ty,
-      DenseElementsAttr::get(native_vreg_ty, neutral));
+  Value neutral_vec = getFullVector(builder, native_vreg_ty, neutral);
   return builder
       .create<arith::SelectOp>(value.getLoc(), mask, value, neutral_vec)
       .getResult();
@@ -1863,126 +1825,28 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
   TPU_ASSERT_EQ_OP(padded_lhs_rows, lhs_vregs.dim(0) * layout_lhs.tiling()[0]);
   TPU_ASSERT_EQ_OP(padded_rhs_rows, rhs_vregs.dim(0) * layout_rhs.tiling()[0]);

-  const VectorType i32_vreg_ty =
-      getNativeVregType(builder.getI32Type(), ctx.target_shape);
-  auto getX32VmaskByPaddingEnd = [&](int64_t dim, int64_t padding) {
-    CHECK(dim == 0 || dim == 1);
-    CHECK(padding >= 0 && padding <= ctx.target_shape[dim]);
-    return cast<TypedValue<VectorType>>(
-        builder
-            .create<arith::CmpIOp>(
-                arith::CmpIPredicate::slt,
-                builder.create<tpu::IotaOp>(i32_vreg_ty,
-                                            builder.getI32IntegerAttr(dim)),
-                builder.create<arith::ConstantOp>(DenseElementsAttr::get(
-                    i32_vreg_ty, builder.getI32IntegerAttr(
-                                     ctx.target_shape[dim] - padding))))
-            .getResult());
-  };
-
-  // We can also extend this helper function with padding_top and padding_left
-  // based on the offsets in vregs.
-  const Value i32_zeros_vreg = builder.create<arith::ConstantOp>(
-      op.getLoc(),
-      DenseElementsAttr::get(i32_vreg_ty, builder.getI32IntegerAttr(0)));
-  const Value i32_max_vreg = builder.create<arith::ConstantOp>(
-      op.getLoc(), DenseElementsAttr::get(
-                       i32_vreg_ty, builder.getI32IntegerAttr(0xffffffff)));
-  auto maskVregs = [&](xla::Array<Value> &vregs, int64_t padding_bottom,
-                       int64_t padding_right) {
-    auto vreg_ty = cast<VectorType>(vregs.begin()->getType());
-    int packing = vreg_ty.getRank() > 2 ? vreg_ty.getShape()[2] : 1;
-    // Mask out the bottom.
-    if (padding_bottom > 0) {
-      // We have limited the row size of LHS and RHS need to be a multiple of
-      // native tiling at the beginning of this rule. Therefore, it is safe to
-      // bitcast to x32 vreg for masking.
-      int sub_padding = padding_bottom % packing;
-      int x32_padding_bottom = padding_bottom / packing;
-      auto mask_bottom = getX32VmaskByPaddingEnd(0, x32_padding_bottom);
-      // Create an int32 vreg which contains subelement masking and then
-      // logical_and with target vreg to mask out the unaligned paddings.
-      // Eg. if padding_bottom = 5, packing = 2, and assume the vreg shape is
-      // [8, 128], then the mask will be:
-      //
-      // sublane 0: [0xffffffff, 0xffffffff, ..., 0xffffffff]
-      // sublane 1: [0xffffffff, 0xffffffff, ..., 0xffffffff]
-      // sublane 2: [0xffffffff, 0xffffffff, ..., 0xffffffff]
-      // sublane 3: [0xffffffff, 0xffffffff, ..., 0xffffffff]
-      // sublane 4: [0xffffffff, 0xffffffff, ..., 0xffffffff]
-      // sublane 5: [0x0000ffff, 0x0000ffff, ..., 0x0000ffff]
-      // sublane 6: [0         , 0         , ..., 0         ]
-      // sublane 7: [0         , 0         , ..., 0         ]
-      //
-      // Through this way, in order to mask sub-elements, each target vreg only
-      // needs to apply 1 op (logical_and) instead of 3 ops (unpacking + select
-      // + packing).
-      Value partial_sublane_mask = builder.create<arith::ConstantOp>(
-          op.getLoc(),
-          DenseElementsAttr::get(
-              i32_vreg_ty,
-              builder.getI32IntegerAttr(
-                  0xffffffff >>
-                  (sub_padding * vreg_ty.getElementTypeBitWidth()))));
-      // Insert 0xffffffff above the blended sublane.
-      Value sublane_mask = builder.create<arith::SelectOp>(
-          getX32VmaskByPaddingEnd(0, x32_padding_bottom + 1), i32_max_vreg,
-          partial_sublane_mask);
-      // Insert 0 below the blended sublane.
-      sublane_mask = builder.create<arith::SelectOp>(mask_bottom, sublane_mask,
-                                                     i32_zeros_vreg);
-      for (int64_t i = 0; i < vregs.dim(1); ++i) {
-        Value &vreg = vregs({vregs.dim(0) - 1, i});
-        Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
-        if (sub_padding > 0) {
-          i32_vreg = builder.create<arith::AndIOp>(i32_vreg, sublane_mask);
-        } else {
-          i32_vreg = builder.create<arith::SelectOp>(mask_bottom, i32_vreg,
-                                                     i32_zeros_vreg);
-        }
-        vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
-      }
-    }
-    // Mask out the right.
-    if (padding_right > 0) {
-      auto mask_right = getX32VmaskByPaddingEnd(1, padding_right);
-      for (int64_t i = 0; i < vregs.dim(0); ++i) {
-        Value &vreg = vregs({i, vregs.dim(1) - 1});
-        Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
-        i32_vreg = builder.create<arith::SelectOp>(mask_right, i32_vreg,
-                                                   i32_zeros_vreg);
-        vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
-      }
-    }
-  };
-
-  // Create a vreg filled with zeros.
-  auto getZerosVergLike =
-      [&](const Value &vreg) -> FailureOr<TypedValue<VectorType>> {
-    const VectorType vreg_type = cast<VectorType>(vreg.getType());
-    FAILUREOR_ASSIGN_OR_RETURN(
-        const Attribute zero_attr,
-        getZeroIntOrFloatAttr(vreg_type.getElementType()));
-    return cast<TypedValue<VectorType>>(
-        builder
-            .create<arith::ConstantOp>(
-                op.getLoc(), DenseElementsAttr::get(vreg_type, zero_attr))
-            .getResult());
-  };
-
-  FAILUREOR_ASSIGN_OR_RETURN(auto lhs_zeros_vreg,
-                             getZerosVergLike(*lhs_vregs.begin()));
-  FAILUREOR_ASSIGN_OR_RETURN(auto rhs_zeros_vreg,
-                             getZerosVergLike(*rhs_vregs.begin()));
-  FAILUREOR_ASSIGN_OR_RETURN(auto acc_zeros_vreg,
-                             getZerosVergLike(*acc_vregs.begin()));
+  auto lhs_zeros_vreg =
+      getZerosVector(builder, cast<VectorType>(lhs_vregs.begin()->getType()));
+  auto rhs_zeros_vreg =
+      getZerosVector(builder, cast<VectorType>(rhs_vregs.begin()->getType()));
+  auto acc_zeros_vreg =
+      getZerosVector(builder, cast<VectorType>(acc_vregs.begin()->getType()));

   // Only mask out the paddings on contracting dim of LHS and RHS.
-  maskVregs(lhs_vregs, 0, padded_lhs_cols - lhs_shape[1]);
+  RETURN_IF_FAILED(
+      maskNativeTilingVregs(builder, lhs_vregs, ctx.target_shape,
+                            /*padding_bottom=*/0,
+                            /*padding_right=*/padded_lhs_cols - lhs_shape[1]));
   if (transpose_rhs) {
-    maskVregs(rhs_vregs, 0, padded_rhs_cols - rhs_shape[1]);
+    RETURN_IF_FAILED(maskNativeTilingVregs(
+        builder, rhs_vregs, ctx.target_shape,
+        /*padding_bottom=*/0,
+        /*padding_right=*/padded_rhs_cols - rhs_shape[1]));
   } else {
-    maskVregs(rhs_vregs, padded_rhs_rows - rhs_shape[0], 0);
+    RETURN_IF_FAILED(
+        maskNativeTilingVregs(builder, rhs_vregs, ctx.target_shape,
+                              /*padding_bottom=*/padded_rhs_rows - rhs_shape[0],
+                              /*padding_right=*/0));
   }

   // At this point, all paddings on vregs are masked out. For now, we
@@ -2875,12 +2739,10 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
         native_vreg_ty,
         /*dimension =*/builder.getI32IntegerAttr(1));
     for (int64_t i = 0; i < num_tiles; ++i) {
-      auto offset = builder.create<arith::ConstantOp>(
-          native_vreg_ty,
-          DenseElementsAttr::get(
-              native_vreg_ty,
-              IntegerAttr::get(vty.getElementType(),
-                               i * *(native_vreg_ty.getShape().end() - 1))));
+      Value offset = getFullVector(
+          builder, native_vreg_ty,
+          IntegerAttr::get(vty.getElementType(),
+                           i * *(native_vreg_ty.getShape().end() - 1)));
       tiles[i] = builder.create<arith::AddIOp>(vreg_iota, offset);
     }
     xla::Array<Value> broadcasted_tiles(tile_array_shape);

@@ -2902,12 +2764,10 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
         native_vreg_ty,
         /*dimension =*/builder.getI32IntegerAttr(0));
     for (int64_t i = 0; i < num_tiles; ++i) {
-      auto offset = builder.create<arith::ConstantOp>(
-          native_vreg_ty,
-          DenseElementsAttr::get(
-              native_vreg_ty,
-              IntegerAttr::get(vty.getElementType(),
-                               i * *(native_vreg_ty.getShape().end() - 2))));
+      Value offset = getFullVector(
+          builder, native_vreg_ty,
+          IntegerAttr::get(vty.getElementType(),
+                           i * *(native_vreg_ty.getShape().end() - 2)));
       tiles[i] = builder.create<arith::AddIOp>(vreg_iota, offset);
     }
     xla::Array<Value> broadcasted_tiles(tile_array_shape);

@@ -2924,10 +2784,8 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
     SmallVector<Value> tiles;
     tiles.reserve(vty.getDimSize(*dimension));
     for (int64_t i = 0; i < vty.getDimSize(*dimension); ++i) {
-      tiles.push_back(builder.create<arith::ConstantOp>(
-          native_vreg_ty,
-          DenseElementsAttr::get(native_vreg_ty,
-                                 IntegerAttr::get(vty.getElementType(), i))));
+      tiles.push_back(getFullVector(builder, native_vreg_ty,
+                                    IntegerAttr::get(vty.getElementType(), i)));
     }
     xla::Array<Value> out_tiles(tile_array_shape);
     out_tiles.Each([&](absl::Span<const int64_t> idxs, Value *v) {
@@ -3516,12 +3374,9 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
           const int64_t offset = *offsets_in[1];
           const int64_t lane_offset = offset % ctx.target_shape[1];
           const int64_t tile_offset = offset / ctx.target_shape[1];
-          const auto idx_ty =
-              VectorType::get(ctx.target_shape, builder.getI32Type());
-          auto lane_offset_cst = builder.create<arith::ConstantOp>(
-              broadcast_op.getLoc(), idx_ty,
-              DenseElementsAttr::get(idx_ty,
-                                     builder.getI32IntegerAttr(lane_offset)));
+          Value lane_offset_cst = getFullVector(
+              builder, getNativeVregType(builder.getI32Type(), ctx.target_shape),
+              builder.getI32IntegerAttr(lane_offset));
           DenseI32ArrayAttr sublane_pattern;
           if (num_tiles != 1) {
             SmallVector<int32_t> pattern;

@@ -3581,10 +3436,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
           getNativeVregType(src_i32.getType(), ctx.target_shape);
       auto tile_i32 =
           builder.create<vector::BroadcastOp>(native_vreg_ty, src_i32);
-      auto zeros = builder.create<arith::ConstantOp>(
-          broadcast_op.getLoc(), tile_i32.getType(),
-          DenseElementsAttr::get(tile_i32.getType(),
-                                 builder.getI32IntegerAttr(0)));
+      Value zeros = getZerosVector(builder, tile_i32.getType());
       auto tile =
           builder.create<arith::CmpIOp>(arith::CmpIPredicate::ne, tile_i32, zeros)
               .getResult();
@ -5479,8 +5331,6 @@ FailureOr<xla::Array<Value>> doColumnShiftRelayout(
  const std::array<int64_t, 2> vreg_slice = src.vregSlice(target_shape);
  const int bitwidth = src.bitwidth();
  const int packing = src.packing();
  const VectorLayout dst(bitwidth, {src.offsets()[0], dst_col_offset}, tiling,
                         src.implicit_dim());
  const int64_t col_diff = dst_col_offset - *src.offsets()[1];
  if (tiling[0] % packing != 0 || tiling[1] != target_shape[1]) {
    return emitError(loc,
@ -5823,7 +5673,8 @@ LogicalResult retileToLargeTileWithScratch(
    RewriteContext &ctx, OpBuilder &builder, const Location loc,
    xla::Array<Value> &dst_tiles, const std::array<int64_t, 2> &dst_tile,
    const xla::Array<Value> &src_tiles, const std::array<int64_t, 2> &src_tile,
    TypedValue<MemRefType> scratch_ref) {
    TypedValue<MemRefType> scratch_ref, const int64_t store_vreg_delay,
    const int64_t load_vreg_skips) {
  if (dst_tile[0] % src_tile[0] != 0) {
    return failure();
  }
@ -5927,8 +5778,8 @@ LogicalResult retileToLargeTileWithScratch(
  SmallVector<int64_t, 4> src_idx(rank);
  dst_tiles.Each([&](absl::Span<const int64_t> dst_idx, Value *dst_vreg) {
    int64_t dst_row_idx = *(dst_idx.end() - 2);
    int64_t dst_col_idx = *(dst_idx.end() - 1);
    int64_t vreg_idx_in_group = dst_col_idx % vregs_per_group;
    int64_t dst_col_idx_with_skips = *(dst_idx.end() - 1) + load_vreg_skips;
    int64_t vreg_idx_in_group = dst_col_idx_with_skips % vregs_per_group;
    int64_t load_offset = sublanes_per_group * stored_group_cnt +
                          vreg_idx_in_group * sl_per_vreg * stride;
    delayed_loads.push_back(
@ -5938,16 +5789,20 @@ LogicalResult retileToLargeTileWithScratch(
    // the vregs from current group and now we need to store corresponding
    // group of src vregs before actually emitting the loads.
    if (vreg_idx_in_group == vregs_per_group - 1 ||
        dst_col_idx == dst_tiles.dimensions().back() - 1) {
      auto src_row_idx = dst_row_idx * vregs_per_group;
      auto src_col_idx = dst_col_idx / vregs_per_group;
        dst_idx.back() == dst_tiles.dimensions().back() - 1) {
      auto base_src_row_idx = dst_row_idx * vregs_per_group - store_vreg_delay;
      auto src_col_idx = dst_col_idx_with_skips / vregs_per_group;
      std::copy(dst_idx.begin(), dst_idx.end(), src_idx.begin());
      for (int vi = 0; vi < vregs_per_group; ++vi) {
        if (src_row_idx + vi >= src_tiles.dim(rank - 2) ||
        const int64_t src_row_idx = base_src_row_idx + vi;
        if (src_row_idx < 0) {
          continue;
        }
        if (src_row_idx >= src_tiles.dim(rank - 2) ||
            src_col_idx >= src_tiles.dim(rank - 1)) {
          break;
        }
        *(src_idx.end() - 2) = src_row_idx + vi;
        *(src_idx.end() - 2) = src_row_idx;
        *(src_idx.end() - 1) = src_col_idx;
        Value src_vreg = src_tiles(src_idx);
        src_vreg =
@ -5976,7 +5831,8 @@ LogicalResult retileToSmallTileWithScratch(
    RewriteContext &ctx, OpBuilder &builder, const Location loc,
    xla::Array<Value> &dst_tiles, const std::array<int64_t, 2> &dst_tile,
    const xla::Array<Value> &src_tiles, const std::array<int64_t, 2> &src_tile,
    TypedValue<MemRefType> scratch_ref) {
    TypedValue<MemRefType> scratch_ref, const int64_t store_vreg_delay,
    const int64_t load_vreg_skips) {
  if (src_tile[0] % dst_tile[0] != 0) {
    return failure();
  }
@ -6103,8 +5959,8 @@ LogicalResult retileToSmallTileWithScratch(
  SmallVector<int64_t, 4> dst_idx(rank);
  src_tiles.Each([&](absl::Span<const int64_t> src_idx, Value src_vreg) {
    int64_t src_row_idx = *(src_idx.end() - 2);
    int64_t src_col_idx = *(src_idx.end() - 1);
    int64_t vreg_idx_in_group = src_col_idx % vregs_per_group;
    int64_t src_col_idx_with_delays = *(src_idx.end() - 1) + store_vreg_delay;
    int64_t vreg_idx_in_group = src_col_idx_with_delays % vregs_per_group;
    src_vreg = builder.create<tpu::BitcastVregOp>(loc, temp_vreg_ty, src_vreg);
    if (use_shuffled_load) {
      Value store_offset = mlirIndexConst(
@ -6126,16 +5982,20 @@ LogicalResult retileToSmallTileWithScratch(
    // vregs' row, this indicates we have stored all the vregs needed to
    // assemble a new group of dst vreg.
    if (vreg_idx_in_group == vregs_per_group - 1 ||
        src_col_idx == src_tiles.dimensions().back() - 1) {
      auto dst_row_idx = src_row_idx * vregs_per_group;
      auto dst_col_idx = src_col_idx / vregs_per_group;
        src_idx.back() == src_tiles.dimensions().back() - 1) {
      auto base_dst_row_idx = src_row_idx * vregs_per_group - load_vreg_skips;
      auto dst_col_idx = src_col_idx_with_delays / vregs_per_group;
      std::copy(src_idx.begin(), src_idx.end(), dst_idx.begin());
      for (int vi = 0; vi < vregs_per_group; ++vi) {
        if (dst_row_idx + vi >= dst_tiles.dim(rank - 2) ||
        const int64_t dst_row_idx = base_dst_row_idx + vi;
        if (dst_row_idx < 0) {
          continue;
        }
        if (dst_row_idx >= dst_tiles.dim(rank - 2) ||
            dst_col_idx >= dst_tiles.dim(rank - 1)) {
          break;
        }
        *(dst_idx.end() - 2) = dst_row_idx + vi;
        *(dst_idx.end() - 2) = dst_row_idx;
        *(dst_idx.end() - 1) = dst_col_idx;
        Value *dst_vreg = &dst_tiles(dst_idx);
        int64_t load_offset =
@ -6160,18 +6020,70 @@ LogicalResult retileToSmallTileWithScratch(

// go/mosaic-retiling-in-scratch is the full internal documentation that
// includes more details about the TPU generations.
LogicalResult retileWithScratch(RewriteContext &ctx, OpBuilder &builder,
                                const Location loc,
                                xla::Array<Value> &dst_tiles,
                                const std::array<int64_t, 2> &dst_tiling,
                                const xla::Array<Value> &src_tiles,
                                const std::array<int64_t, 2> &src_tiling,
                                int packing) {
// Arguments:
// - shape: The non-implicit shape of the operand
// - dst_tiling: The desired result tiling
// - dst_offsets_hint: Hints for the result offsets. They may be used or
//                     ignored. See comments in the body of the function for
//                     more details.
// - src_vregs: The source vregs to retile.
// - src: The source layout
// Returns a pair holding the result layout (potentially using the hints) and
// the retiled vregs.
// TODO(tlongeri): Clean up the function parameters/signatures. We are passing
// in more information than strictly needed.
FailureOr<std::pair<VectorLayout, xla::Array<Value>>> retileWithScratch(
    RewriteContext &ctx, OpBuilder &builder, const Location loc,
    const ArrayRef<int64_t> shape, const std::array<int64_t, 2> dst_tiling,
    const LayoutOffsets dst_offsets_hint, const xla::Array<Value> &src_vregs,
    const VectorLayout &src) {
  const int bitwidth = src.bitwidth();
  const int packing = src.packing();
  const std::array<int64_t, 2> src_tiling = src.tiling();
  if (!(src_tiling[1] == ctx.target_shape[1] &&
        dst_tiling[1] == ctx.target_shape[1] && src_tiling[0] % packing == 0 &&
        dst_tiling[0] % packing == 0)) {
    return failure();
  }
  const std::array<int64_t, 2> src_vreg_slice =
      VectorLayout::vregSlice(ctx.target_shape, bitwidth, src_tiling);
  const std::array<int64_t, 2> dst_vreg_slice =
      VectorLayout::vregSlice(ctx.target_shape, bitwidth, dst_tiling);

  // TODO(b/368088671): When sublane tiling changes, we should be able to
  // preserve some replications from the source layout. But we need to
  // make sure they are implemented efficiently and well-tested. For now, we
  // just simply use 0 for the replicated offset after retiling.
  const LayoutOffsets src_offsets = {src.offsets()[0].value_or(0),
                                     src.offsets()[1].value_or(0)};
  // The provided offset hints are used only if they align with the source
  // offsets, else we default to the smallest possible aligned offsets.
  LayoutOffsets dst_offsets = {*src_offsets[0] % dst_vreg_slice[0],
                               *src_offsets[1] % dst_vreg_slice[1]};
  // On a given dimension, either the source vreg slice size divides the dest
  // vreg slice size, or vice versa (depending on the dimension and whether it's
  // small-to-large or large-to-small retiling). Offset changes are supported
  // as long as they are aligned modulo the smaller of the two sizes.
  const std::array<int64_t, 2> alignment = {
      std::min(src_vreg_slice[0], dst_vreg_slice[0]),
      std::min(src_vreg_slice[1], dst_vreg_slice[1])};
  if (dst_offsets_hint[0].has_value() &&
      (*dst_offsets_hint[0] - *src_offsets[0]) % alignment[0] == 0) {
    CHECK_LT(*dst_offsets_hint[0], dst_vreg_slice[0]);
    dst_offsets[0] = *dst_offsets_hint[0];
  }
  if (dst_offsets_hint[1].has_value() &&
      (*dst_offsets_hint[1] - *src_offsets[1]) % alignment[1] == 0) {
    CHECK_LT(*dst_offsets_hint[1], dst_vreg_slice[1]);
    dst_offsets[1] = *dst_offsets_hint[1];
  }
  // The offsets of the source in units of the destination vreg slice:
  const std::array<int64_t, 2> src_offsets_in_dst_vreg_slices = {
      *src_offsets[0] / dst_vreg_slice[0], *src_offsets[1] / dst_vreg_slice[1]};
  // The offsets of the destination in units of the source vreg slice:
  const std::array<int64_t, 2> dst_offsets_in_src_vreg_slices = {
      *dst_offsets[0] / src_vreg_slice[0], *dst_offsets[1] / src_vreg_slice[1]};
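As a quick illustration of the hint-acceptance rule just above, here is a standalone sketch (illustrative only, not part of the commit); the slice sizes and offsets are assumed values for a 32-bit (2, 128) -> (8, 128) retile on an (8, 128) target:

// Illustrative sketch only: a hint is taken iff it differs from the source
// offset by a multiple of the smaller vreg slice size on that dimension.
#include <algorithm>
#include <array>
#include <cstdint>
#include <iostream>

int main() {
  const std::array<int64_t, 2> src_vreg_slice = {2, 512};
  const std::array<int64_t, 2> dst_vreg_slice = {8, 128};
  const std::array<int64_t, 2> src_offsets = {0, 128};
  const std::array<int64_t, 2> dst_offsets_hint = {4, 0};
  for (int d = 0; d < 2; ++d) {
    const int64_t alignment =
        std::min(src_vreg_slice[d], dst_vreg_slice[d]);  // 2 and 128
    const bool used =
        (dst_offsets_hint[d] - src_offsets[d]) % alignment == 0;
    std::cout << "dim " << d << ": hint "
              << (used ? "used" : "ignored") << "\n";  // used on both dims
  }
}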

  // Try to get i32 vector scratch space, because we will bitcast vregs to
  // i32 vregs before using scratch for retiling. This way we can handle
  // packed types as well.
@ -6186,24 +6098,57 @@ LogicalResult retileWithScratch(RewriteContext &ctx, OpBuilder &builder,
                                           dst_tiling[1]};
  std::array<int64_t, 2> vi32_src_tiling = {src_tiling[0] / packing,
                                            src_tiling[1]};

  const VectorLayout dst(bitwidth, dst_offsets, dst_tiling, src.implicit_dim());
  TPU_ASSERT_LOC(loc, dst.isValid(ctx.target_shape));
  xla::Array<Value> dst_vregs(
      dst.tileArrayImplicitShape(shape, ctx.target_shape));
  // When differences in offsets exist, the source vregs may be stored at an
  // offset position in their group. For example, the 1st vreg in a row/column
  // may be stored as if it was the 3rd, so that the parts corresponding to the
  // 1st and 2nd in the destination are filled with padding. Likewise, loads to
  // destination vregs may be skipped, when they would load only padding.
  // store_vreg_delay is the position offset for stores, and load_vreg_skips is
  // the position offset for loads.
  //
  // For example, suppose we are going from 32-bit {0, 128}(2, 128) to
  // {4, 0}(8, 128). We form groups of 4 vregs that represent an (8, 512) slice
  // of the padded implicit shape. For the given offsets, for the first group,
  // the data is in (4:8, 128:512). But the first and second sources (stored
  // vregs) of the group form the slices of data (0:2, 0:512) and (2:4, 0:512),
  // which should be all padding. Likewise, the first dest vreg slice (which we
  // load from) holds the data from slice (0:8, 0:128), which is all padding.
  // We never load or store to slices that should contain only padding.
  if (src_tiling[0] > dst_tiling[0]) {
    return retileToSmallTileWithScratch(ctx, builder, loc, dst_tiles,
                                        vi32_dst_tiling, src_tiles,
                                        vi32_src_tiling, ref);
    DCHECK_EQ(src_offsets_in_dst_vreg_slices[1], 0);
    DCHECK_EQ(dst_offsets_in_src_vreg_slices[0], 0);
    const int64_t store_vreg_delay = dst_offsets_in_src_vreg_slices[1];
    const int64_t load_vreg_skips = src_offsets_in_dst_vreg_slices[0];
    if (failed(retileToSmallTileWithScratch(
            ctx, builder, loc, dst_vregs, vi32_dst_tiling, src_vregs,
            vi32_src_tiling, ref, store_vreg_delay, load_vreg_skips))) {
      return failure();
    }
  }
  if (src_tiling[0] < dst_tiling[0]) {
    return retileToLargeTileWithScratch(ctx, builder, loc, dst_tiles,
                                        vi32_dst_tiling, src_tiles,
                                        vi32_src_tiling, ref);
    DCHECK_EQ(src_offsets_in_dst_vreg_slices[0], 0);
    DCHECK_EQ(dst_offsets_in_src_vreg_slices[1], 0);
    const int64_t store_vreg_delay = dst_offsets_in_src_vreg_slices[0];
    const int64_t load_vreg_skips = src_offsets_in_dst_vreg_slices[1];
    if (failed(retileToLargeTileWithScratch(
            ctx, builder, loc, dst_vregs, vi32_dst_tiling, src_vregs,
            vi32_src_tiling, ref, store_vreg_delay, load_vreg_skips))) {
      return failure();
    }
  }
  dst_tiles = std::move(src_tiles);
  return success();
  return std::make_pair(dst, dst_vregs);
}
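The next snippet is an illustrative, standalone sketch (not part of the commit) that replays the {0, 128}(2, 128) -> {4, 0}(8, 128) example from the comment above, deriving store_vreg_delay and load_vreg_skips from the offsets; the vreg slice sizes are assumed for 32-bit data on an (8, 128) target:

// Illustrative sketch only; mirrors the small-to-large branch above.
#include <array>
#include <cstdint>
#include <iostream>

int main() {
  const std::array<int64_t, 2> src_vreg_slice = {2, 512};  // (2, 128) tiling
  const std::array<int64_t, 2> dst_vreg_slice = {8, 128};  // (8, 128) tiling
  const std::array<int64_t, 2> src_offsets = {0, 128};
  const std::array<int64_t, 2> dst_offsets = {4, 0};
  // Stores are delayed by the destination row offset measured in source vreg
  // slices; loads are skipped by the source lane offset measured in
  // destination vreg slices.
  const int64_t store_vreg_delay = dst_offsets[0] / src_vreg_slice[0];
  const int64_t load_vreg_skips = src_offsets[1] / dst_vreg_slice[1];
  std::cout << "store_vreg_delay=" << store_vreg_delay   // 2
            << " load_vreg_skips=" << load_vreg_skips    // 1
            << "\n";
}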

FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
    RewriteContext &ctx, OpBuilder &builder, const Location loc, VectorType vty,
    const VectorLayout src, xla::Array<Value> vregs,
    const std::array<int64_t, 2> dst_tiling, bool try_replicate_rows) {
    const std::array<int64_t, 2> dst_tiling,
    const LayoutOffsets dst_offsets_hint) {
  bool has_enough_scratch = ctx.max_sublanes_in_scratch >=
                            ctx.target_shape[0] * (ctx.target_shape[0] + 1);
  const auto &target_shape = ctx.target_shape;
@ -6219,6 +6164,12 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
  const int8_t bitwidth = src.bitwidth();
  const std::array<int64_t, 2> dst_vreg_slice =
      VectorLayout::vregSlice(ctx.target_shape, bitwidth, dst_tiling);
  // TODO(tlongeri): Using canonical vs non-canonical offsets can change the
  // value of try_replicate rows, and it breaks some tests. It doesn't make
  // sense that we have different behavior for equivalent layouts, though. We
  // need better logic for picking the relayout strategy.
  const bool try_replicate_rows =
      src.offsets()[0].has_value() && !dst_offsets_hint[0].has_value();

  // Fully replicated offsets are handled efficiently elsewhere (in relayout)
  CHECK(src.offsets()[0].has_value() || src.offsets()[1].has_value());
@ -6290,15 +6241,10 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
    });
    return std::pair(dst, std::move(retiled));
  }
  VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
                   src.implicit_dim());
  if (!dst.isValid(target_shape)) {
    return emitError(loc, "Not implemented: invalid offsets in tiling target");
  }
  auto dst_tiles_shape =
      dst.tileArrayImplicitShape(vty.getShape(), target_shape);
  // (8,128) -> (8 * packing,128) tiling change for packed type.
  if (bitwidth < 32 && 32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
  if (src_offsets[0].value_or(0) < dst_vreg_slice[0] &&
      src_offsets[1].value_or(0) < dst_vreg_slice[1] && bitwidth < 32 &&
      32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
      dst_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * packing,
                                           ctx.target_shape[1]}) {
    // TODO(tlongeri): This relayout is just ext + trunc. Refactor.
@ -6308,8 +6254,10 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
    // not, since it relies on the src vreg array shape to know how many tiles
    // to pack in dst, and vreg array shapes with materialized offsets are
    // unfortunately not equal to vreg array shapes with replicated offsets.
    CHECK(dst.offsets() == src_offsets);
    xla::Array<Value> retiled(dst_tiles_shape);
    VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
                     src.implicit_dim());
    xla::Array<Value> retiled(
        dst.tileArrayImplicitShape(vty.getShape(), target_shape));
    VectorType vreg_x32 =
        vty.getElementType().isSignlessInteger()
            ? VectorType::get(target_shape, builder.getI32Type())
@ -6357,7 +6305,9 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
  // interesting if the next step is a retile, since we can also
  // match corresponding elements without shifting. It's just that
  // the tiles are not adjacent (no contiguous vreg slice).
  if (bitwidth < 32 && 32 % bitwidth == 0 &&
  if (src_offsets[0].value_or(0) < dst_vreg_slice[0] &&
      src_offsets[1].value_or(0) < dst_vreg_slice[1] && bitwidth < 32 &&
      32 % bitwidth == 0 &&
      src_tiling == std::array<int64_t, 2>{1, ctx.target_shape[1] * packing} &&
      dst_tiling == std::array<int64_t, 2>{packing, ctx.target_shape[1]}) {
    // TODO(tlongeri): This relayout is just ext + trunc. Refactor.
@ -6406,8 +6356,10 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
    // not, since it relies on the src vreg array shape to know how many tiles
    // to pack in dst, and vreg array shapes with materialized offsets are
    // unfortunately not equal to vreg array shapes with replicated offsets.
    CHECK(dst.offsets() == src.offsets());
    xla::Array<Value> retiled(dst_tiles_shape);
    VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
                     src.implicit_dim());
    xla::Array<Value> retiled(
        dst.tileArrayImplicitShape(vty.getShape(), target_shape));
    const VectorType vreg_x32 =
        vty.getElementType().isSignlessInteger()
            ? VectorType::get(target_shape, builder.getI32Type())
@ -6444,24 +6396,25 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
    return std::pair(dst, std::move(retiled));
  }
  if (src_tiling[1] == target_shape[1] && dst_tiling[1] == target_shape[1]) {
    // TODO(b/368088671): When sublane tiling changes, we should be able to
    // preserve some replications from the source layout. But we need to
    // make sure they are implemented efficiently and well-tested. For now, we
    // just simply use 0 for the replicated offset after retiling.
    dst = VectorLayout(
        bitwidth, {src.offsets()[0].value_or(0), src.offsets()[1].value_or(0)},
        dst_tiling, dst.implicit_dim());

    // All clauses in the and expression are based on performance benchmarking.
    bool use_alu = !has_enough_scratch ||
                   (ctx.hardware_generation >= 5 && src_tiling[0] != packing &&
                    dst_tiling[0] != packing);

    if (use_alu) {
      if (src_tiling[0] > dst_tiling[0]) {
        return std::pair(
            dst, retileToReducedSublanes(builder, vty.getShape(), src, vregs,
                                         dst, target_shape));
      if (src_tiling[0] > dst_tiling[0] &&
          // retileToReducedSublanes does not support offset changes
          src.offsets()[0].value_or(0) < dst_vreg_slice[0] &&
          src.offsets()[1].value_or(0) < dst_vreg_slice[1]) {
        VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
                         src.implicit_dim());
        return std::pair(dst, retileToReducedSublanes(
                                  builder, vty.getShape(), src, vregs,
                                  VectorLayout(bitwidth,
                                               {src.offsets()[0].value_or(0),
                                                src.offsets()[1].value_or(0)},
                                               dst_tiling, dst.implicit_dim()),
                                  target_shape));
      } else if (!has_enough_scratch) {
        // TODO(b/357538782): Implement retileToIncreasedSublanes with ALU ops.
        return emitError(
@ -6469,15 +6422,12 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
            "Not implemented: retiling to increase sublane tiling with ALU");
      }
    }
    xla::Array<Value> retiled(dst_tiles_shape);
    if (failed(retileWithScratch(ctx, builder, loc, retiled, dst_tiling, vregs,
                                 src_tiling, packing))) {
      return failure();
    }
    return std::pair(dst, std::move(retiled));
    return retileWithScratch(ctx, builder, loc, vty.getShape(), dst_tiling,
                             dst_offsets_hint, vregs, src);
  }
  return emitError(loc, "Not implemented: Unsupported tiling change for ")
         << vty << ": from " << src << " to " << dst;
         << vty << ": from " << src << " to (" << dst_tiling[0] << ", "
         << dst_tiling[1] << ") tiling";
}

FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeImplicitDim(
@ -6737,9 +6687,7 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
  FAILUREOR_ASSIGN_OR_RETURN(
      std::tie(src, src_tiles),
      changeTiling(ctx, builder, v.getLoc(), vty, src, std::move(src_tiles),
                   dst.tiling(),
                   dst.offsets()[0] == std::nullopt &&
                       src.offsets()[0] != std::nullopt));
                   dst.tiling(), dst.offsets()));

  FAILUREOR_ASSIGN_OR_RETURN(
      std::tie(src, src_tiles),

@ -1647,10 +1647,17 @@ class VectorLayoutInferer {
    Layout dst_layout;
    if (layout.tiling() == nativeTiling(src_bitwidth)) {
      // If the source is already in native tiling, we can unpack it directly.
      src_layout = layout;
      std::array<int64_t, 2> dst_native_tiling = nativeTiling(dst_bitwidth);
      LayoutOffsets offsets = {layout.offsets()[0]
                                   ? *layout.offsets()[0] % dst_native_tiling[0]
                                   : LayoutOffset(),
                               layout.offsets()[1]};
      DCHECK_LT(offsets[1].value_or(0), dst_native_tiling[1]);
      src_layout = VectorLayout(src_bitwidth, offsets, layout.tiling(),
                                layout.implicit_dim());
      dst_layout =
          VectorLayout(dst_bitwidth, layout.offsets(),
                       nativeTiling(dst_bitwidth), layout.implicit_dim());
          VectorLayout(dst_bitwidth, offsets, dst_native_tiling,
                       layout.implicit_dim());
    } else if (dst_bitwidth == 32 &&
               default_tiling_[0] % layout.tiling()[0] == 0 &&
               default_tiling_[1] == layout.tiling()[1]) {
@ -1659,13 +1666,17 @@ class VectorLayoutInferer {
      // tiling through the op.
      // TODO(jevinjiang): we can relax this for non-32bit as well.
      src_layout = layout;
      dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(),
                                layout.implicit_dim());
      dst_layout = VectorLayout(dst_bitwidth, layout.offsets(),
                                src_layout->tiling(), layout.implicit_dim());
    } else {
      // TODO(b/335863273): we should also reduce offsets.
      src_layout = VectorLayout(src_bitwidth, layout.offsets(), default_tiling_,
      LayoutOffsets offsets = {
          layout.offsets()[0] ? *layout.offsets()[0] % default_tiling_[0]
                              : LayoutOffset(),
          layout.offsets()[1] ? *layout.offsets()[1] % default_tiling_[1]
                              : LayoutOffset()};
      src_layout = VectorLayout(src_bitwidth, offsets, default_tiling_,
                                layout.implicit_dim());
      dst_layout = VectorLayout(dst_bitwidth, layout.offsets(), default_tiling_,
      dst_layout = VectorLayout(dst_bitwidth, offsets, default_tiling_,
                                layout.implicit_dim());
    }
    setLayout(op, src_layout, dst_layout);
206
jaxlib/mosaic/dialect/tpu/vreg_util.cc
Normal file
@ -0,0 +1,206 @@
/* Copyright 2024 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"

#include <array>
#include <cstdint>

#include "absl/log/check.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/Diagnostics.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/include/mlir/IR/Types.h"
#include "mlir/include/mlir/IR/Value.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "jaxlib/mosaic/dialect/tpu/util.h"
#include "xla/array.h"

namespace mlir::tpu {

namespace {

VectorType getNativeVregOrVmaskTypeImpl(
    Type elem_ty, const int8_t bitwidth,
    const std::array<int64_t, 2> target_shape) {
  if (bitwidth == 32) {
    return VectorType::get(target_shape, elem_ty);
  }
  return VectorType::get({target_shape[0], target_shape[1], 32 / bitwidth},
                         elem_ty);
}

}  // namespace

VectorType getNativeVregOrVmaskType(Type elem_ty, const int8_t layout_bitwidth,
                                    const std::array<int64_t, 2> target_shape) {
  int8_t bitwidth = elem_ty.getIntOrFloatBitWidth();
  if (bitwidth == 1) {
    bitwidth = layout_bitwidth;
  } else {
    CHECK_EQ(bitwidth, layout_bitwidth);
  }
  return getNativeVregOrVmaskTypeImpl(elem_ty, bitwidth, target_shape);
}

VectorType getNativeVregType(Type elem_ty,
                             const std::array<int64_t, 2> target_shape) {
  return getNativeVregOrVmaskTypeImpl(elem_ty, elem_ty.getIntOrFloatBitWidth(),
                                      target_shape);
}

TypedValue<VectorType> getFullVector(ImplicitLocOpBuilder &builder,
                                     VectorType vty, Attribute value) {
  return cast<TypedValue<VectorType>>(
      builder.create<arith::ConstantOp>(DenseElementsAttr::get(vty, value))
          .getResult());
}

TypedValue<VectorType> getFullLikeVector(ImplicitLocOpBuilder &builder,
                                         TypedValue<VectorType> vec,
                                         Attribute value) {
  return getFullVector(builder, vec.getType(), value);
}

TypedValue<VectorType> getZerosVector(ImplicitLocOpBuilder &builder,
                                      VectorType vty) {
  return getFullVector(builder, vty, builder.getZeroAttr(vty.getElementType()));
}

TypedValue<VectorType> getZerosLikeVector(ImplicitLocOpBuilder &builder,
                                          TypedValue<VectorType> vec) {
  return getZerosVector(builder, vec.getType());
}

FailureOr<TypedValue<VectorType>> getX32VmaskByPaddingEnd(
    ImplicitLocOpBuilder &builder, int64_t padding,
    const std::array<int64_t, 2> target_shape, int64_t dim) {
  VectorType i32_vreg_ty =
      getNativeVregType(builder.getI32Type(), target_shape);
  if (dim != 0 && dim != 1) {
    return builder.emitError()
           << "Expected a 2D vector for getX32VmaskByPaddingEnd";
  }

  if (padding < 0 || padding > target_shape[dim]) {
    return builder.emitError()
           << "Padding must be in [0, target_shape[dim]]. Padding: " << padding
           << ", target_shape[dim]: " << target_shape[dim];
  }

  Value padding_vreg =
      getFullVector(builder, i32_vreg_ty,
                    builder.getI32IntegerAttr(target_shape[dim] - padding));

  return cast<TypedValue<VectorType>>(
      builder
          .create<arith::CmpIOp>(
              arith::CmpIPredicate::slt,
              builder.create<tpu::IotaOp>(i32_vreg_ty,
                                          builder.getI32IntegerAttr(dim)),
              padding_vreg)
          .getResult());
}

LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder,
                                    xla::Array<Value> &vregs,
                                    std::array<int64_t, 2> target_shape,
                                    int64_t padding_bottom,
                                    int64_t padding_right) {
  auto vreg_ty = dyn_cast<VectorType>(vregs.begin()->getType());
  if (!vreg_ty) {
    return builder.emitError() << "Expected a vector type";
  }

  VectorType i32_vreg_ty =
      getNativeVregType(builder.getI32Type(), target_shape);
  Value i32_zeros_vreg = getZerosVector(builder, i32_vreg_ty);
  Value i32_max_vreg = getFullVector(builder, i32_vreg_ty,
                                     builder.getI32IntegerAttr(0xffffffff));

  int packing = vreg_ty.getRank() > 2 ? vreg_ty.getShape()[2] : 1;
  // Mask out the bottom.
  if (padding_bottom > 0) {
    // The function is only called when the vreg has native tiling. Therefore,
    // it is safe to bitcast to x32 vreg for masking.
    int sub_padding = padding_bottom % packing;
    int x32_padding_bottom = padding_bottom / packing;
    FAILUREOR_ASSIGN_OR_RETURN(
        Value mask_top, getX32VmaskByPaddingEnd(builder, x32_padding_bottom + 1,
                                                target_shape, /*dim=*/0));
    FAILUREOR_ASSIGN_OR_RETURN(
        Value mask_bottom,
        getX32VmaskByPaddingEnd(builder, x32_padding_bottom, target_shape,
                                /*dim=*/0));
    // Create an int32 vreg which contains subelement masking and then
    // logical_and it with the target vreg to mask out the unaligned paddings.
    // E.g. if padding_bottom = 5, packing = 2, and assume the vreg shape is
    // [8, 128], then the mask will be:
    //
    // sublane 0: [0xffffffff, 0xffffffff, ..., 0xffffffff]
    // sublane 1: [0xffffffff, 0xffffffff, ..., 0xffffffff]
    // sublane 2: [0xffffffff, 0xffffffff, ..., 0xffffffff]
    // sublane 3: [0xffffffff, 0xffffffff, ..., 0xffffffff]
    // sublane 4: [0xffffffff, 0xffffffff, ..., 0xffffffff]
    // sublane 5: [0x0000ffff, 0x0000ffff, ..., 0x0000ffff]
    // sublane 6: [0         , 0         , ..., 0         ]
    // sublane 7: [0         , 0         , ..., 0         ]
    //
    // This way, masking sub-elements costs only 1 op (logical_and) per target
    // vreg instead of 3 ops (unpacking + select + packing).
    Value partial_sublane_mask = getFullVector(
        builder, i32_vreg_ty,
        builder.getI32IntegerAttr(
            0xffffffff >> (sub_padding * vreg_ty.getElementTypeBitWidth())));
    // Insert 0xffffffff above the blended sublane.
    Value sublane_mask = builder.create<arith::SelectOp>(mask_top, i32_max_vreg,
                                                         partial_sublane_mask);
    // Insert 0 below the blended sublane.
    sublane_mask = builder.create<arith::SelectOp>(mask_bottom, sublane_mask,
                                                   i32_zeros_vreg);
    for (int64_t i = 0; i < vregs.dim(1); ++i) {
      Value &vreg = vregs({vregs.dim(0) - 1, i});
      Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
      if (sub_padding > 0) {
        i32_vreg = builder.create<arith::AndIOp>(i32_vreg, sublane_mask);
      } else {
        i32_vreg = builder.create<arith::SelectOp>(mask_bottom, i32_vreg,
                                                   i32_zeros_vreg);
      }
      vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
    }
  }
  // Mask out the right.
  if (padding_right > 0) {
    FAILUREOR_ASSIGN_OR_RETURN(
        Value mask_right, getX32VmaskByPaddingEnd(builder, padding_right,
                                                  target_shape, /*dim=*/1));
    for (int64_t i = 0; i < vregs.dim(0); ++i) {
      Value &vreg = vregs({i, vregs.dim(1) - 1});
      Value i32_vreg = builder.create<tpu::BitcastVregOp>(i32_vreg_ty, vreg);
      i32_vreg =
          builder.create<arith::SelectOp>(mask_right, i32_vreg, i32_zeros_vreg);
      vreg = builder.create<tpu::BitcastVregOp>(vreg_ty, i32_vreg);
    }
  }
  return success();
}

}  // namespace mlir::tpu
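To see the blended sublane mask concretely, the following standalone sketch (illustrative only, not part of the commit) prints the per-sublane 32-bit mask for the padding_bottom = 5, packing = 2 case documented in the comment inside maskNativeTilingVregs above:

#include <cstdint>
#include <cstdio>

int main() {
  constexpr int kSublanes = 8;       // native (8, 128) vreg
  constexpr int kPaddingBottom = 5;
  constexpr int kPacking = 2;        // e.g. 16-bit elements
  constexpr int kElemBits = 32 / kPacking;
  const int sub_padding = kPaddingBottom % kPacking;         // 1
  const int x32_padding_bottom = kPaddingBottom / kPacking;  // 2
  const uint32_t partial = 0xffffffffu >> (sub_padding * kElemBits);
  for (int s = 0; s < kSublanes; ++s) {
    // mask_top keeps sublanes strictly above the blended one; mask_bottom
    // keeps the blended sublane as well.
    const bool above = s < kSublanes - (x32_padding_bottom + 1);
    const bool keep = s < kSublanes - x32_padding_bottom;
    const uint32_t mask = above ? 0xffffffffu : (keep ? partial : 0u);
    std::printf("sublane %d: 0x%08x\n", s, static_cast<unsigned>(mask));
  }
}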
82
jaxlib/mosaic/dialect/tpu/vreg_util.h
Normal file
@ -0,0 +1,82 @@
/* Copyright 2024 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_VREG_UTIL_H_
#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_VREG_UTIL_H_

#include <array>
#include <cstdint>

#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/Builders.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/include/mlir/IR/Types.h"
#include "mlir/include/mlir/IR/Value.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "xla/array.h"

namespace mlir::tpu {

// Returns the native vreg or vmask type for the given element type and target
// shape. The layout bitwidth is used for i1 (vmask) elements.
VectorType getNativeVregOrVmaskType(Type elem_ty, int8_t layout_bitwidth,
                                    std::array<int64_t, 2> target_shape);
VectorType getNativeVregType(Type elem_ty, std::array<int64_t, 2> target_shape);

// Returns a zero constant of the same type as `vty`.
TypedValue<VectorType> getZerosVector(ImplicitLocOpBuilder &builder,
                                      VectorType vty);
// Same as above, but takes a `vec` as input.
TypedValue<VectorType> getZerosLikeVector(ImplicitLocOpBuilder &builder,
                                          TypedValue<VectorType> vec);

// Returns a constant of the same type as `vty` with the given `value`.
TypedValue<VectorType> getFullVector(ImplicitLocOpBuilder &builder,
                                     VectorType vty, Attribute value);
// Same as above, but takes a `vec` as input.
TypedValue<VectorType> getFullLikeVector(ImplicitLocOpBuilder &builder,
                                         TypedValue<VectorType> vec,
                                         Attribute value);

// Creates a vmask with false flags at the bottom (dim = 0) or right (dim = 1),
// where the number of true flags along `dim` is (dim_size - padding).
//
// For example, assume vmask shape is (4, 8)
//
// getX32VmaskByPaddingEnd(padding=3, dim=1) creates:
// [T, T, T, T, T, F, F, F]
// [T, T, T, T, T, F, F, F]
// [T, T, T, T, T, F, F, F]
// [T, T, T, T, T, F, F, F]
// TODO(b/385204135): Unify with getVmaskByPaddingEnd in tpu_rotate_rule, and
// improve the codegen.
FailureOr<TypedValue<VectorType>> getX32VmaskByPaddingEnd(
    ImplicitLocOpBuilder &builder, int64_t padding,
    std::array<int64_t, 2> target_shape, int64_t dim);

// Masks out the padding in the bottom and right of the vregs. vregs are
// expected to have native tiling, and the masked vregs are mutated in
// `vregs`. `padding_bottom` and `padding_right` are the numbers of elements
// to mask out at the bottom and right.
LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder,
                                    xla::Array<Value> &vregs,
                                    std::array<int64_t, 2> target_shape,
                                    int64_t padding_bottom,
                                    int64_t padding_right);

} // namespace mlir::tpu

#endif  // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_VREG_UTIL_H_
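A scalar model of the mask documented above, as a standalone sketch (illustrative only, not part of the commit): getX32VmaskByPaddingEnd compares an iota along `dim` against (dim_size - padding), which this loop reproduces for the (4, 8), padding=3, dim=1 example:

#include <iostream>

int main() {
  constexpr int kRows = 4, kCols = 8;
  constexpr int kPadding = 3, kDim = 1;
  for (int r = 0; r < kRows; ++r) {
    for (int c = 0; c < kCols; ++c) {
      const int iota = kDim == 0 ? r : c;
      const int dim_size = kDim == 0 ? kRows : kCols;
      std::cout << (iota < dim_size - kPadding ? "T " : "F ");
    }
    std::cout << "\n";  // prints T T T T T F F F on every row
  }
}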
228
jaxlib/mosaic/dialect/tpu/vreg_util_test.cc
Normal file
@ -0,0 +1,228 @@
/* Copyright 2024 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"

#include <array>
#include <cstdint>
#include <memory>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/include/mlir/IR/Builders.h"
#include "mlir/include/mlir/IR/BuiltinOps.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/include/mlir/IR/MLIRContext.h"
#include "mlir/include/mlir/IR/OwningOpRef.h"
#include "mlir/include/mlir/IR/Value.h"
#include "mlir/include/mlir/Support/DebugStringHelper.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"

namespace mlir::tpu {

namespace {

using ::testing::Eq;
using ::testing::Optional;

MATCHER_P2(IsConstantOpWithSplatValue, type, splat_value, "") {
  auto constant_op = dyn_cast<arith::ConstantOp>(arg.getDefiningOp());
  if (constant_op == nullptr) {
    *result_listener << "Expected a constant op, got " << debugString(arg);
    return false;
  }
  auto dense_attr = dyn_cast<DenseElementsAttr>(constant_op.getValue());
  if (dense_attr == nullptr) {
    *result_listener << "Expected a dense elements attr, got "
                     << debugString(arg);
    return false;
  }
  if (dense_attr.getType() != type) {
    *result_listener << "Expected a dense elements attr with type "
                     << debugString(type) << ", got "
                     << debugString(dense_attr.getType());
    return false;
  }
  if (!dense_attr.isSplat()) {
    *result_listener << "Expected a splat dense elements attr, got "
                     << debugString(dense_attr);
    return false;
  }
  if (auto s = dense_attr.template getSplatValue<decltype(splat_value)>();
      s != splat_value) {
    *result_listener << "Expected a splat dense elements attr with value "
                     << splat_value << ", got " << s;
    return false;
  }
  return true;
}

MATCHER_P2(IsVectorTypeWithShape, shape, elem_ty, "") {
  auto vty = dyn_cast<VectorType>(arg);
  if (vty == nullptr) {
    *result_listener << "Expected a vector type, got " << debugString(arg);
    return false;
  }
  if (vty.getShape() != ArrayRef<int64_t>(shape)) {
    *result_listener << "Expected a vector type with shape "
                     << absl::StrJoin(shape, ",") << ", got "
                     << absl::StrJoin(vty.getShape(), ",");
    return false;
  }
  if (vty.getElementType() != elem_ty) {
    *result_listener << "Expected a vector type with element type "
                     << debugString(elem_ty) << ", got "
                     << debugString(vty.getElementType());
    return false;
  }
  return true;
}

class VregUtilTest : public ::testing::Test {
 protected:
  void SetUp() override {
    context_.loadDialect<arith::ArithDialect, vector::VectorDialect,
                         tpu::TPUDialect>();
    mlir::Location loc = mlir::UnknownLoc::get(&context_);
    mlir::OpBuilder b(&context_);
    module_ = b.create<ModuleOp>(loc);
    builder_ = std::make_unique<mlir::ImplicitLocOpBuilder>(
        module_->getLoc(), module_->getBodyRegion());
  }

  void TearDown() override {
    builder_.reset();
    // Reset the module to prevent memory leaks.
    module_ = nullptr;
  }

  mlir::ImplicitLocOpBuilder& Builder() { return *builder_; }

 private:
  MLIRContext context_;
  std::unique_ptr<mlir::ImplicitLocOpBuilder> builder_;
  OwningOpRef<ModuleOp> module_;
};

TEST_F(VregUtilTest, GetNativeVregOrVmaskTypeBitwidthMismatch) {
  EXPECT_DEATH(getNativeVregOrVmaskType(Builder().getI16Type(),
                                        /*layout_bitwidth=*/8, {2, 4}),
               "");
}

TEST_F(VregUtilTest, GetNativeVregOrVmaskTypeI1) {
  EXPECT_THAT(getNativeVregOrVmaskType(Builder().getI1Type(),
                                       /*layout_bitwidth=*/8, {2, 4}),
              IsVectorTypeWithShape(std::array<int64_t, 3>{2, 4, 4},
                                    Builder().getI1Type()));
}

TEST_F(VregUtilTest, GetNativeVregF32) {
  EXPECT_THAT(getNativeVregType(Builder().getF32Type(), {2, 4}),
              IsVectorTypeWithShape(std::array<int64_t, 2>{2, 4},
                                    Builder().getF32Type()));
}

TEST_F(VregUtilTest, GetNativeVregBf16) {
  EXPECT_THAT(getNativeVregType(Builder().getBF16Type(), {2, 4}),
              IsVectorTypeWithShape(std::array<int64_t, 3>{2, 4, 2},
                                    Builder().getBF16Type()));
}

TEST_F(VregUtilTest, GetFullVector) {
  VectorType vty = VectorType::get({2, 4}, Builder().getI32Type());
  TypedValue<VectorType> vec =
      getFullVector(Builder(), vty, Builder().getI32IntegerAttr(0x1));

  EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, int32_t{0x1}));
}

TEST_F(VregUtilTest, GetFullLikeVector) {
  VectorType vty = VectorType::get({2, 4}, Builder().getF32Type());
  TypedValue<VectorType> in_vec = Builder().create<vector::SplatOp>(
      vty, Builder().create<arith::ConstantOp>(
               vty.getElementType(), Builder().getF32FloatAttr(1.0f)));
  TypedValue<VectorType> vec =
      getFullLikeVector(Builder(), in_vec, Builder().getF32FloatAttr(2.0f));

  EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, float{2.0f}));
}

TEST_F(VregUtilTest, GetZerosVector) {
  VectorType vty = VectorType::get({2, 4}, Builder().getI32Type());
  TypedValue<VectorType> vec = getZerosVector(Builder(), vty);

  EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, int32_t{0}));
}

TEST_F(VregUtilTest, GetZerosLikeVector) {
  VectorType vty = VectorType::get({2, 4}, Builder().getF32Type());
  TypedValue<VectorType> in_vec = Builder().create<vector::SplatOp>(
      vty, Builder().create<arith::ConstantOp>(
               vty.getElementType(), Builder().getF32FloatAttr(1.0f)));
  TypedValue<VectorType> vec = getZerosLikeVector(Builder(), in_vec);

  EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, float{0.0f}));
}

TEST_F(VregUtilTest, GetX32VmaskByPaddingEndDim0) {
  constexpr std::array<int64_t, 2> kTargetShape = {4, 8};
  FailureOr<TypedValue<VectorType>> vec = getX32VmaskByPaddingEnd(
      Builder(), /*padding=*/1, /*target_shape=*/kTargetShape,
      /*dim=*/0);
  ASSERT_TRUE(succeeded(vec));

  auto cmp_op = dyn_cast<arith::CmpIOp>(vec.value().getDefiningOp());
  ASSERT_TRUE(cmp_op != nullptr);
  EXPECT_EQ(cmp_op.getPredicate(), arith::CmpIPredicate::slt);

  auto iota_op = dyn_cast<tpu::IotaOp>(cmp_op.getLhs().getDefiningOp());
  ASSERT_TRUE(iota_op != nullptr);
  EXPECT_THAT(iota_op.getDimension(), Optional(Eq(0)));

  EXPECT_THAT(
      cmp_op.getRhs(),
      IsConstantOpWithSplatValue(
          VectorType::get(kTargetShape, Builder().getI32Type()), int32_t{3}));
}

TEST_F(VregUtilTest, GetX32VmaskByPaddingEndDim1) {
  constexpr std::array<int64_t, 2> kTargetShape = {4, 8};
  FailureOr<TypedValue<VectorType>> vec = getX32VmaskByPaddingEnd(
      Builder(), /*padding=*/3, /*target_shape=*/kTargetShape,
      /*dim=*/1);
  ASSERT_TRUE(succeeded(vec));

  auto cmp_op = dyn_cast<arith::CmpIOp>(vec.value().getDefiningOp());
  ASSERT_TRUE(cmp_op != nullptr);
  EXPECT_EQ(cmp_op.getPredicate(), arith::CmpIPredicate::slt);

  auto iota_op = dyn_cast<tpu::IotaOp>(cmp_op.getLhs().getDefiningOp());
  ASSERT_TRUE(iota_op != nullptr);
  EXPECT_THAT(iota_op.getDimension(), Optional(Eq(1)));

  EXPECT_THAT(
      cmp_op.getRhs(),
      IsConstantOpWithSplatValue(
          VectorType::get(kTargetShape, Builder().getI32Type()), int32_t{5}));
}

}  // namespace

}  // namespace mlir::tpu
@ -33,4 +33,5 @@ except ImportError:
  from mlir.dialects._ods_common import _cext  # type: ignore[import-not-found]


_cext.globals.append_dialect_search_prefix("jax.jaxlib.mosaic.python")
# Add the parent module to the search prefix
_cext.globals.append_dialect_search_prefix(__name__[:__name__.rfind(".")])

@ -83,6 +83,7 @@ setup(
        'cuda/*',
        'cuda/nvvm/libdevice/libdevice*',
        'mosaic/*.py',
        'mosaic/dialect/gpu/*.py',
        'mosaic/gpu/*.so',
        'mosaic/python/*.py',
        'mosaic/python/*.so',

@ -218,6 +218,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
      dst_dir=mosaic_python_dir,
      src_files=[
          "__main__/jaxlib/mosaic/python/layout_defs.py",
          "__main__/jaxlib/mosaic/python/mosaic_gpu.py",
          "__main__/jaxlib/mosaic/python/tpu.py",
      ],
  )
@ -225,6 +226,16 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
  patch_copy_mlir_import(
      "__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir
  )
  mosaic_gpu_dir = jaxlib_dir / "mosaic" / "dialect" / "gpu"
  os.makedirs(mosaic_gpu_dir)
  patch_copy_mlir_import(
      "__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_ops.py",
      dst_dir=mosaic_gpu_dir,
  )
  patch_copy_mlir_import(
      "__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_enums.py",
      dst_dir=mosaic_gpu_dir,
  )

  copy_runfiles(
      dst_dir=jaxlib_dir / "mlir",
@ -316,6 +327,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
          f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}",
          f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}",
          f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}",
          f"__main__/jaxlib/mlir/_mlir_libs/_mosaic_gpu_ext.{pyext}",
          f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}",
          f"__main__/jaxlib/mlir/_mlir_libs/_sdy.{pyext}",
          f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}",
@ -219,6 +219,29 @@ class DebugPrintTest(jtu.JaxTestCase):
     [ 1 2 3 4 5 6 7 8 9 10 12 13 14]]
    """))

  def test_debug_print_respects_numpy_printoptions(self):
    def f(x):
      with np.printoptions(precision=2, suppress=True):
        jax.debug.print("{}", x)
    x = np.array([1.2345, 2.3456, 1E-7])

    # Default numpy print options:
    with jtu.capture_stdout() as output:
      jax.debug.print("{}", x)
    self.assertEqual(output(), "[1.2345e+00 2.3456e+00 1.0000e-07]\n")

    # Modified print options without JIT:
    with jtu.capture_stdout() as output:
      f(x)
      jax.effects_barrier()
    self.assertEqual(output(), "[1.23 2.35 0. ]\n")

    # Modified print options with JIT:
    with jtu.capture_stdout() as output:
      jax.jit(f)(x)
      jax.effects_barrier()
    self.assertEqual(output(), "[1.23 2.35 0. ]\n")


class DebugPrintTransformationTest(jtu.JaxTestCase):

@ -254,8 +254,6 @@ def sdpa_train_fp8(
class DotProductAttentionTest(jtu.JaxTestCase):
  def setUp(self):
    super().setUp()
    if jax.device_count() < 4:
      self.skipTest("Requires more than 4 devices.")
    try:
      cudnn_version = check_cudnn_version()
    except RuntimeError as e:
@ -366,6 +364,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):

  @jtu.run_on_devices("cuda")
  def test_sdpa_inference(self):
    if jax.device_count() < 4:
      self.skipTest("Requires more than 4 devices.")
    k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
    query = jax.random.normal(
        k1, (4, 1024, 4, 64), dtype=jnp.bfloat16)
@ -407,6 +407,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):

  @jtu.run_on_devices("cuda")
  def test_sdpa_var_seq(self):
    if jax.device_count() < 4:
      self.skipTest("Requires more than 4 devices.")
    self.skipTest("Skip before fixed.")
    k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
    query = jax.random.normal(
@ -438,6 +440,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):

  @jtu.run_on_devices("cuda")
  def test_sdpa_broadcast_bias_and_dbias(self):
    if jax.device_count() < 4:
      self.skipTest("Requires more than 4 devices.")
    try:
      cudnn_version = check_cudnn_version()
    except RuntimeError as e:
@ -504,6 +508,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):
  )
  @jtu.run_on_devices("cuda")
  def test_sdpa_dbias(self, batch_size: int):
    if jax.device_count() < 4:
      self.skipTest("Requires more than 4 devices.")
    # cuDNN only supports dbias when batch size is 1. If the batch size is
    # greater, dbias is silently set to all zeros. This test verifies this
    # behavior for both vmap and regular use cases.
@ -540,6 +546,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):

  @jtu.run_on_devices("cuda")
  def test_sdpa_sliding_window_length(self):
    if jax.device_count() < 4:
      self.skipTest("Requires more than 4 devices.")
    k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
    query = jax.random.normal(
        k1, (4, 1024, 4, 64), dtype=jnp.bfloat16)
@ -571,8 +579,43 @@ class DotProductAttentionTest(jtu.JaxTestCase):
    self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5)
    self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5)

  @jtu.run_on_devices("cuda")
  def test_sdpa_large_head_size(self):
    try:
      cudnn_version = check_cudnn_version()
    except RuntimeError as e:
      self.skipTest(str(e))
      return
    if cudnn_version < 90500:
      self.skipTest("Requires >= cuDNN 9.5.0")
    if not jtu.is_cuda_compute_capability_at_least("9.0"):
      self.skipTest("Requires at least Hopper arch")

    B, T, N, H = 2, 64, 2, 256
    bf16 = jnp.bfloat16
    keys = jax.random.split(jax.random.key(0), 4)
    query = jax.random.normal(keys[0], (B, T, N, H), dtype=bf16)
    key = jax.random.normal(keys[1], (B, T, N, H), dtype=bf16)
    value = jax.random.normal(keys[2], (B, T, N, H), dtype=bf16)
    grad = jax.random.normal(keys[3], (B, T, N, H), dtype=bf16)
    sdpa_train_ans = jax.jit(partial(
        sdpa_train, scale=1.0, mask_type=MaskType.CAUSAL, dropout_rate=0)
    )
    sdpa_train_rfc = jax.jit(partial(
        sdpa_train_ref, scale=1.0, mask_type=MaskType.CAUSAL, dropout_rate=0)
    )

    out_ans, grads_ans = sdpa_train_ans(query, key, value, grad, None, None)
    out_ref, grads_ref = sdpa_train_rfc(query, key, value, grad, None, None)
    self.assertArraysAllClose(out_ref, out_ans)
    self.assertArraysAllClose(grads_ref[0], grads_ans[0])
    self.assertArraysAllClose(grads_ref[1], grads_ans[1])
    self.assertArraysAllClose(grads_ref[2], grads_ans[2])

  @jtu.run_on_devices("cuda")
  def test_layouts(self):
    if jax.device_count() < 4:
      self.skipTest("Requires more than 4 devices.")
    dtype = "bfloat16"
    B, T, N, H = 4, 1024, 8, 128
    S = T
@ -600,6 +643,8 @@ class DotProductAttentionTest(jtu.JaxTestCase):
    self.assertArraysAllClose(dv_ref, _cvt_back(dv))

  def test_sdpa_utils(self):
    if jax.device_count() < 4:
      self.skipTest("Requires more than 4 devices.")
    test_cases = [
        (1, 257, 64, 8905, False, True, True),
        (1, 1024, 64, 8905, False, False, True),

@ -327,6 +327,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1E-6)
    self._CompileAndCheck(lax_fun, args_maker, rtol=1E-8)

  @unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
  @jtu.sample_product(
      l_max=[1, 2, 3, 6],
      shape=[(5,), (10,)],
@ -349,6 +350,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
                            atol=3e-3, check_dtypes=False)
    self._CompileAndCheck(lax_fun, args_maker, rtol=1E-5, atol=3e-3)

  @unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
  @jtu.sample_product(
      l_max=[3, 4, 6, 32],
      shape=[(2,), (3,), (4,), (64,)],
@ -381,6 +383,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
                        rtol=1e-5, atol=1e-5, check_dtypes=False)
    self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)

  @unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
  @jax.numpy_dtype_promotion('standard')  # This test explicitly exercises dtype promotion
  def testSphHarmAccuracy(self):
    m = jnp.arange(-3, 3)[:, None]
@ -435,6 +438,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):

    self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8)

  @unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
  @jtu.sample_product(
      [dict(l_max=l_max, num_z=num_z)
       for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])

@ -4386,7 +4386,7 @@ class FunctionAccuracyTest(jtu.JaxTestCase):
      regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag')

    elif name == 'log10':
      regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag', 'zero.imag')
      regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag')

    elif name == 'exp':
      regions_with_inaccuracies_keep('pos.imag', 'pinf.imag', 'mpos.imag')
@ -39,8 +39,8 @@ jax_multiplatform_test(
        "tpu",
    ],
    enable_configs = [
        "gpu_a100_x32",
        "gpu_h100_x32",
        "gpu_a100",
        "gpu_h100",
    ],
    shard_count = {
        "cpu": 8,

@ -1688,8 +1688,8 @@ class PallasControlFlowTest(PallasBaseTest):

    def body(state):
      i, s = state
      sl = jax.lax.div(i, 128)
      l = jax.lax.rem(i, 128)
      sl = jax.lax.div(i, jnp.astype(128, i.dtype))
      l = jax.lax.rem(i, jnp.astype(128, i.dtype))
      v = pl.load(x_ref, (0, sl, l))
      return i + 1, s + v

@ -2555,9 +2555,7 @@ class MiscellaneousTest(PallasBaseTest):

    np.testing.assert_array_equal(out, np.reshape(x, (8, 128)))

  @only_passes_in_interpret()
  def test_retiling2(self):
    """b/348040767"""
    x = np.arange(1 * 8 * 1024, dtype=jnp.bfloat16).reshape(1, 8, 1024)

    def kernel(x_ref, out_ref):
@ -95,8 +95,8 @@ class PgleTest(jtu.JaxTestCase):
        out_shardings=NamedSharding(mesh, PartitionSpec('x')),
        compiler_options={
            'xla_gpu_enable_latency_hiding_scheduler': 'True',
            # TODO(patrios): Remove this flag once b/376647494 is fixed.
            'xla_gpu_graph_min_graph_size': '100000',
            # TODO(b/37664749): Remove this flag once the bug is fixed.
            'xla_gpu_enable_command_buffer': '',
        },
    )
    def f(x):
@ -133,8 +133,6 @@ class PgleTest(jtu.JaxTestCase):
        out_shardings=NamedSharding(mesh, PartitionSpec('x')),
        compiler_options={
            'xla_gpu_enable_latency_hiding_scheduler': 'True',
            # TODO(patrios): Remove this flag once b/376647494 is fixed.
            'xla_gpu_graph_min_graph_size': '100000',
            'xla_dump_to': dump_dir,
            'xla_gpu_experimental_dump_fdo_profiles': 'True'
        },
@ -217,8 +215,6 @@ class PgleTest(jtu.JaxTestCase):
        out_shardings=NamedSharding(mesh, PartitionSpec('x')),
        compiler_options={
            'xla_gpu_enable_latency_hiding_scheduler': 'True',
            # TODO(patrios): Remove this flag once b/376647494 is fixed.
            'xla_gpu_graph_min_graph_size': '100000',
            'xla_dump_to': dump_dir,
            'xla_gpu_experimental_dump_fdo_profiles': 'True'
        },
@ -2207,6 +2207,23 @@ class ShardMapTest(jtu.JaxTestCase):
    #
    # f(x)  # don't crash

  def test_partial_auto_of_random_keys(self):
    if config.use_shardy_partitioner.value:
      self.skipTest('Shardy does not support full-to-shard.')

    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
    keys = jax.random.split(jax.random.key(0), 8)

    @jax.jit
    def f(x):
      return shard_map(lambda k: k,
                       mesh, in_specs=P('i'), out_specs=P('i'),
                       check_rep=False, auto=frozenset({'j'}))(keys)

    y = f(keys)  # don't crash
    self.assertAllClose(jax.random.key_data(y), jax.random.key_data(keys),
                        check_dtypes=False)

  def test_vmap_grad_shmap_spmd_axis_name_residuals(self):
    # https://github.com/jax-ml/jax/pull/21032
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
4
third_party/xla/workspace.bzl
vendored
@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result.

XLA_COMMIT = "dc7aaf834a0bb5a543f6cf98626284783a4a921c"
XLA_SHA256 = "eda76cce64b33c00139120d6b4d4c2167d9f99dc957da54225a67ddb7ec7cb23"
XLA_COMMIT = "ac6e71fe0cf864eec152de5ba761b76d8bef3153"
XLA_SHA256 = "2b568ff365bc4b5c2b257002aa71f094a2b60357ceb1f2a1c6c33f4ad1a411bd"

def repo():
    tf_http_archive(