Pallas TPU now accepts compiler parameters only via mosaic=...

This mirrors a similar change in Pallas GPU, which uses triton=.

PiperOrigin-RevId: 624131518
This commit is contained in:
Sergei Lebedev 2024-04-12 04:33:33 -07:00 committed by jax authors
parent 3146c2a3f6
commit 386be2d307
2 changed files with 16 additions and 11 deletions

View File

@ -32,8 +32,8 @@ def encode_kernel_regeneration_metadata(
config.
Returns:
A dict that can be directly passed to pallas_call as a 'mosaic_params'
argument.
A dict that can be passed to pallas_call via
compiler_params=dict(mosaic=...)).
Raises:
TypeError: when the input metadata is not serializable in json format.

View File

@ -17,17 +17,18 @@
from __future__ import annotations
from typing import Any
import warnings
import jax
from jax import core as jax_core
from jax.experimental import mosaic
from jax.experimental.mosaic.dialects import tpu
from jax._src import sharding_impls
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.pallas import core
from jax._src.pallas.mosaic import lowering
from jax._src.pallas.pallas_call import pallas_call_p
from jax.experimental import mosaic
from jax.experimental.mosaic.dialects import tpu
def pallas_call_tpu_lowering_rule(
@ -54,11 +55,17 @@ def pallas_call_tpu_lowering_rule(
compiler_params=compiler_params)
if debug:
print(jaxpr)
if 'mosaic_params' in compiler_params:
assert 'mosaic' not in compiler_params
mosaic_params = compiler_params['mosaic_params']
elif 'mosaic' in compiler_params:
mosaic_params = compiler_params['mosaic']
if "mosaic_params" in compiler_params:
# TODO(slebedev): Remove this branch after July 12th 2024.
warnings.warn(
"Passing Mosaic parameters via compiler_params=dict(mosaic_params=...)"
" is deprecated. Use compiler_params=dict(mosaic=...) instead.",
DeprecationWarning,
)
assert "mosaic" not in compiler_params
mosaic_params = compiler_params["mosaic_params"]
elif "mosaic" in compiler_params:
mosaic_params = compiler_params["mosaic"]
else:
mosaic_params = {}
mesh = None
@ -70,8 +77,6 @@ def pallas_call_tpu_lowering_rule(
mlir_ctx.append_dialect_registry(mlir.upstream_dialects)
mlir_ctx.load_all_available_dialects()
tpu.register_dialect(mlir_ctx)
if mosaic_params is None:
mosaic_params = {}
dimension_semantics = mosaic_params.get("dimension_semantics", None)
kernel_regeneration_metadata = mosaic_params.get(
"kernel_regeneration_metadata"