mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[pallas] Removed support for the deprecated pl.BlockSpec
argument order
PiperOrigin-RevId: 682036180
This commit is contained in:
parent
e79d77aa47
commit
41791ac756
@ -18,6 +18,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* {func}`jax.experimental.pallas.debug_print` no longer requires all arguments
|
||||
to be scalars. The restrictions on the arguments are backend-specific:
|
||||
Non-scalar arguments are currently only supported on GPU, when using Triton.
|
||||
* {class}`jax.experimental.pallas.BlockSpec` no longer supports the previously
|
||||
deprecated argument order, where `index_map` comes before `block_shape`.
|
||||
|
||||
* Deprecations
|
||||
|
||||
|
@ -24,13 +24,11 @@ import functools
|
||||
import itertools
|
||||
import threading
|
||||
from typing import Any, ClassVar, Hashable, Protocol, Union, runtime_checkable
|
||||
import warnings
|
||||
|
||||
import jax
|
||||
from jax._src import api_util
|
||||
from jax._src import config
|
||||
from jax._src import core as jax_core
|
||||
from jax._src import deprecations
|
||||
from jax._src import dtypes
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import mesh as mesh_lib
|
||||
@ -378,38 +376,11 @@ class BlockSpec:
|
||||
See :ref:`pallas_blockspec` for more details.
|
||||
"""
|
||||
# An internal canonicalized version is in BlockMapping.
|
||||
block_shape: tuple[int | None, ...] | None = None
|
||||
block_shape: Sequence[int | None] | None = None
|
||||
index_map: Callable[..., Any] | None = None
|
||||
memory_space: Any | None = dataclasses.field(kw_only=True, default=None)
|
||||
indexing_mode: IndexingMode = dataclasses.field(kw_only=True, default=blocked)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_shape: Any | None = None,
|
||||
index_map: Any | None = None,
|
||||
*,
|
||||
memory_space: Any | None = None,
|
||||
indexing_mode: IndexingMode = blocked,
|
||||
) -> None:
|
||||
if callable(block_shape):
|
||||
# TODO(slebedev): Remove this code path and update the signature of
|
||||
# __init__ after October 1, 2024.
|
||||
message = (
|
||||
"BlockSpec now expects ``block_shape`` to be passed before"
|
||||
" ``index_map``. Update your code by swapping the order of these"
|
||||
" arguments. For example, ``pl.BlockSpace(lambda i: i, (42,))``"
|
||||
" should be written as ``pl.BlockSpec((42,), lambda i: i)``."
|
||||
)
|
||||
if deprecations.is_accelerated("pallas-block-spec-order"):
|
||||
raise TypeError(message)
|
||||
warnings.warn(message, DeprecationWarning)
|
||||
index_map, block_shape = block_shape, index_map
|
||||
|
||||
self.block_shape = block_shape
|
||||
self.index_map = index_map
|
||||
self.memory_space = memory_space
|
||||
self.indexing_mode = indexing_mode
|
||||
|
||||
def to_block_mapping(
|
||||
self,
|
||||
origin: OriginStr,
|
||||
|
@ -18,7 +18,6 @@ See the Pallas documentation at
|
||||
https://jax.readthedocs.io/en/latest/pallas.html.
|
||||
"""
|
||||
|
||||
from jax._src.deprecations import register as _register_deprecation
|
||||
from jax._src.pallas.core import Blocked
|
||||
from jax._src.pallas.core import BlockSpec
|
||||
from jax._src.pallas.core import CompilerParams
|
||||
@ -59,7 +58,3 @@ from jax._src.state.indexing import Slice
|
||||
from jax._src.state.primitives import broadcast_to
|
||||
|
||||
ANY = MemorySpace.ANY
|
||||
|
||||
|
||||
_register_deprecation("pallas-block-spec-order")
|
||||
del _register_deprecation
|
||||
|
Loading…
x
Reference in New Issue
Block a user