[pallas] Removed support for the deprecated pl.BlockSpec argument order

PiperOrigin-RevId: 682036180
This commit is contained in:
Sergei Lebedev 2024-10-03 14:39:24 -07:00 committed by jax authors
parent e79d77aa47
commit 41791ac756
3 changed files with 3 additions and 35 deletions

View File

@ -18,6 +18,8 @@ Remember to align the itemized text with the first line of an item within a list
* {func}`jax.experimental.pallas.debug_print` no longer requires all arguments
to be scalars. The restrictions on the arguments are backend-specific:
Non-scalar arguments are currently only supported on GPU, when using Triton.
* {class}`jax.experimental.pallas.BlockSpec` no longer supports the previously
deprecated argument order, where `index_map` comes before `block_shape`.
* Deprecations

View File

@ -24,13 +24,11 @@ import functools
import itertools
import threading
from typing import Any, ClassVar, Hashable, Protocol, Union, runtime_checkable
import warnings
import jax
from jax._src import api_util
from jax._src import config
from jax._src import core as jax_core
from jax._src import deprecations
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
@ -378,38 +376,11 @@ class BlockSpec:
See :ref:`pallas_blockspec` for more details.
"""
# An internal canonicalized version is in BlockMapping.
block_shape: tuple[int | None, ...] | None = None
block_shape: Sequence[int | None] | None = None
index_map: Callable[..., Any] | None = None
memory_space: Any | None = dataclasses.field(kw_only=True, default=None)
indexing_mode: IndexingMode = dataclasses.field(kw_only=True, default=blocked)
def __init__(
self,
block_shape: Any | None = None,
index_map: Any | None = None,
*,
memory_space: Any | None = None,
indexing_mode: IndexingMode = blocked,
) -> None:
if callable(block_shape):
# TODO(slebedev): Remove this code path and update the signature of
# __init__ after October 1, 2024.
message = (
"BlockSpec now expects ``block_shape`` to be passed before"
" ``index_map``. Update your code by swapping the order of these"
" arguments. For example, ``pl.BlockSpace(lambda i: i, (42,))``"
" should be written as ``pl.BlockSpec((42,), lambda i: i)``."
)
if deprecations.is_accelerated("pallas-block-spec-order"):
raise TypeError(message)
warnings.warn(message, DeprecationWarning)
index_map, block_shape = block_shape, index_map
self.block_shape = block_shape
self.index_map = index_map
self.memory_space = memory_space
self.indexing_mode = indexing_mode
def to_block_mapping(
self,
origin: OriginStr,

View File

@ -18,7 +18,6 @@ See the Pallas documentation at
https://jax.readthedocs.io/en/latest/pallas.html.
"""
from jax._src.deprecations import register as _register_deprecation
from jax._src.pallas.core import Blocked
from jax._src.pallas.core import BlockSpec
from jax._src.pallas.core import CompilerParams
@ -59,7 +58,3 @@ from jax._src.state.indexing import Slice
from jax._src.state.primitives import broadcast_to
ANY = MemorySpace.ANY
_register_deprecation("pallas-block-spec-order")
del _register_deprecation