mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Pallas TPU now accepts compiler parameters only via mosaic=...
This mirrors a similar change in Pallas GPU, which uses triton=. PiperOrigin-RevId: 624131518
This commit is contained in:
parent
3146c2a3f6
commit
386be2d307
@ -32,8 +32,8 @@ def encode_kernel_regeneration_metadata(
|
||||
config.
|
||||
|
||||
Returns:
|
||||
A dict that can be directly passed to pallas_call as a 'mosaic_params'
|
||||
argument.
|
||||
A dict that can be passed to pallas_call via
|
||||
compiler_params=dict(mosaic=...)).
|
||||
|
||||
Raises:
|
||||
TypeError: when the input metadata is not serializable in json format.
|
||||
|
@ -17,17 +17,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
import warnings
|
||||
|
||||
import jax
|
||||
from jax import core as jax_core
|
||||
from jax.experimental import mosaic
|
||||
from jax.experimental.mosaic.dialects import tpu
|
||||
from jax._src import sharding_impls
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.pallas import core
|
||||
from jax._src.pallas.mosaic import lowering
|
||||
from jax._src.pallas.pallas_call import pallas_call_p
|
||||
from jax.experimental import mosaic
|
||||
from jax.experimental.mosaic.dialects import tpu
|
||||
|
||||
|
||||
def pallas_call_tpu_lowering_rule(
|
||||
@ -54,11 +55,17 @@ def pallas_call_tpu_lowering_rule(
|
||||
compiler_params=compiler_params)
|
||||
if debug:
|
||||
print(jaxpr)
|
||||
if 'mosaic_params' in compiler_params:
|
||||
assert 'mosaic' not in compiler_params
|
||||
mosaic_params = compiler_params['mosaic_params']
|
||||
elif 'mosaic' in compiler_params:
|
||||
mosaic_params = compiler_params['mosaic']
|
||||
if "mosaic_params" in compiler_params:
|
||||
# TODO(slebedev): Remove this branch after July 12th 2024.
|
||||
warnings.warn(
|
||||
"Passing Mosaic parameters via compiler_params=dict(mosaic_params=...)"
|
||||
" is deprecated. Use compiler_params=dict(mosaic=...) instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
assert "mosaic" not in compiler_params
|
||||
mosaic_params = compiler_params["mosaic_params"]
|
||||
elif "mosaic" in compiler_params:
|
||||
mosaic_params = compiler_params["mosaic"]
|
||||
else:
|
||||
mosaic_params = {}
|
||||
mesh = None
|
||||
@ -70,8 +77,6 @@ def pallas_call_tpu_lowering_rule(
|
||||
mlir_ctx.append_dialect_registry(mlir.upstream_dialects)
|
||||
mlir_ctx.load_all_available_dialects()
|
||||
tpu.register_dialect(mlir_ctx)
|
||||
if mosaic_params is None:
|
||||
mosaic_params = {}
|
||||
dimension_semantics = mosaic_params.get("dimension_semantics", None)
|
||||
kernel_regeneration_metadata = mosaic_params.get(
|
||||
"kernel_regeneration_metadata"
|
||||
|
Loading…
x
Reference in New Issue
Block a user