2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2018 The JAX Authors.
|
2021-11-22 08:22:10 -08:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
# Primitive dispatch and jit dispatch.
|
2022-04-09 10:56:14 -07:00
|
|
|
from __future__ import annotations
|
2021-11-22 08:22:10 -08:00
|
|
|
|
2022-04-14 14:18:31 -07:00
|
|
|
import atexit
|
2024-10-03 20:47:21 -07:00
|
|
|
from collections.abc import Callable, Sequence
|
2021-12-13 21:51:08 -08:00
|
|
|
import contextlib
|
2024-06-17 10:16:38 -07:00
|
|
|
import dataclasses
|
2024-10-01 10:26:25 -07:00
|
|
|
import enum
|
2021-11-22 08:22:10 -08:00
|
|
|
from functools import partial
|
|
|
|
import itertools
|
2021-12-13 21:51:08 -08:00
|
|
|
import time
|
2024-06-26 14:44:52 -04:00
|
|
|
from typing import Any, NamedTuple
|
2022-10-13 17:06:22 +02:00
|
|
|
import logging
|
2022-04-14 14:18:31 -07:00
|
|
|
import threading
|
2021-11-22 08:22:10 -08:00
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
2023-09-18 02:49:53 -07:00
|
|
|
import jax
|
2023-08-08 10:51:38 -07:00
|
|
|
from jax._src import basearray
|
2023-10-09 07:28:18 -07:00
|
|
|
from jax._src import config
|
2022-12-16 20:59:41 -08:00
|
|
|
from jax._src import core
|
2023-11-29 18:06:36 -08:00
|
|
|
from jax._src import api
|
2024-10-01 10:26:25 -07:00
|
|
|
from jax._src import array
|
2021-11-22 08:22:10 -08:00
|
|
|
from jax._src import dtypes
|
2023-02-10 13:53:43 -08:00
|
|
|
from jax._src import source_info_util
|
2022-03-22 12:16:03 -07:00
|
|
|
from jax._src import traceback_util
|
2023-02-06 22:51:50 -08:00
|
|
|
from jax._src import util
|
|
|
|
from jax._src.interpreters import ad
|
2023-02-09 15:11:20 -08:00
|
|
|
from jax._src.interpreters import batching
|
2024-06-28 09:43:41 -07:00
|
|
|
from jax._src.abstract_arrays import array_types
|
2023-03-27 13:29:59 -07:00
|
|
|
from jax._src.interpreters import mlir
|
2023-02-09 15:11:20 -08:00
|
|
|
from jax._src.interpreters import xla
|
2023-03-09 16:18:31 -08:00
|
|
|
from jax._src.interpreters import pxla
|
2023-11-29 18:06:36 -08:00
|
|
|
from jax._src import lib
|
Introduce `jax.sharding.AbstractMesh(shape_tuple: tuple[tuple[str, int], ...])` and allow `with_sharding_constraint` and `shard_map` to accept an abstract mesh as input (`with_sharding_constraint` is via `NamedSharding(abstract_mesh, pspec)`).
**Semantics**
Inside jit, we don't need to talk about concrete devices ever so the semantics stay the same as today i.e. we can lower a NamedSharding with abstract mesh with only mesh axis names and sizes and PartitionSpec. The only restriction is that the number of devices need to be consistent throughout the program when we are tracing.
During compilation, the order of devices throughout the program needs to be consistent (same as before this change).
Outside jit i.e. eager mode, if a `shard_map` or `with_sharding_constraint` contains AbstractMesh, then the input to those primitives should contain a concrete Mesh with the same shape and names as the abstract mesh.
**Why do this?**
There are cases where you want to change the devices in the mesh but keep the mesh shape the same (axis names and axis sizes). But this leads to a device mismatch error if you have `with_sharding_constraint` or `shard_map` in your computation because they embed concrete devices in their signature.
So to fix the error, you need to change the mesh in `wsc` and `shmap` which will lead to a tracing cache miss (because function id is now different) and consequently a lowering to stableHLO cache miss. Explaining via an example:
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
@jax.jit
def f(x):
y = with_sharding_constraint(x, NamedSharding(mesh1, P('x')))
return y * 2
f(arr_mesh1)
f(arr_mesh2) # DEVICE MISMATCH ERROR!
```
The same problem exists for `shard_map` since it takes a mesh with concrete devices in its signature.
**Okay, so how do you fix this?**
As mentioned above, we need the above program to work and get tracing and lowering cache hits (**cache hits is the most important** part here)
The approach in this change, allows `with_sharding_constraint` to accept a `NamedSharding(abstract_mesh, pspec)` as input. This leads to no errors downstream and we get tracing and lowering cache hits since we don't encode the concrete devices anymore. Just the axis_names and axis_size of the mesh.
**The important part is that the concrete device information should only come from the arguments. Inside `jax.jit`, you should never reference concrete devices ever.**
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(mesh1.shape_tuple)
@jax.jit
def f(x):
y = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
return y * 2
f(arr_mesh1)
f(arr_mesh2) # tracing and lowering cache hit
```
**One caveat is that this only works with `jax.NamedSharding` but that's fine because `NamedSharding` is the most used `Sharding` in JAX.**
**What about `shard_map`?**
shard_map's signature will be: `shmap(f, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs)`.
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(arr_mesh1.shape_tuple)
@jax.jit
def f(x):
y = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), out_specs=P('x'))
return y * 2
f(arr_mesh1)
f(arr_mesh2) # tracing and lowering cache hit
```
This is a fully backwards-compatible change, so your current code will continue to work as is, but you can opt into this new behavior and get all the benefits!
PiperOrigin-RevId: 662670932
2024-08-13 15:17:30 -07:00
|
|
|
from jax._src.mesh import AbstractMesh
|
2021-11-22 08:22:10 -08:00
|
|
|
from jax._src.lib import xla_client as xc
|
2023-04-04 11:41:00 -07:00
|
|
|
from jax._src.monitoring import record_event_duration_secs
|
2023-04-06 11:42:45 -07:00
|
|
|
from jax._src.partition_spec import PartitionSpec
|
2023-03-13 08:49:39 -07:00
|
|
|
from jax._src.sharding import Sharding
|
|
|
|
from jax._src.sharding_impls import (
|
2024-06-05 09:06:36 -07:00
|
|
|
SingleDeviceSharding, NamedSharding,
|
2024-06-03 14:52:08 -07:00
|
|
|
GSPMDSharding, TransferToMemoryKind, is_single_device_sharding)
|
2024-04-03 16:12:43 -07:00
|
|
|
from jax._src.layout import Layout, DeviceLocalLayout
|
2023-02-06 22:51:50 -08:00
|
|
|
|
2021-11-22 08:22:10 -08:00
|
|
|
|
2022-11-15 12:41:08 -08:00
|
|
|
# Event-name strings for compile-phase durations. Presumably these are passed
# as the `event` argument of `log_elapsed_time` (and from there to
# `record_event_duration_secs`); the call sites are outside this section.
JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
JAXPR_TO_MLIR_MODULE_EVENT = "/jax/core/compile/jaxpr_to_mlir_module_duration"
BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"

# Register this file with traceback_util's exclusion list.
traceback_util.register_exclusion(__file__)

# Short alias for the `_xla` attribute of xla_client.
xe = xc._xla

# Convenience aliases for xla_client types.
Backend = xe.Client
Device = xc.Device

CompileOptions = xc.CompileOptions

# Rebind `map`/`zip` to util.safe_map/util.safe_zip for the rest of the module;
# the builtins stay reachable as `unsafe_map`/`unsafe_zip`.
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

logger = logging.getLogger(__name__)

# This flag is set on exit; no logging should be attempted
_on_exit = False
|
|
|
|
|
|
|
|
### op-by-op execution
|
|
|
|
|
|
|
|
def apply_primitive(prim, *args, **params):
  """Impl rule that compiles and runs a single primitive 'prim' using XLA.

  Args:
    prim: the primitive to execute.
    *args: positional operands forwarded to the compiled callable.
    **params: primitive parameters, forwarded to `xla_primitive_callable`.

  Returns:
    Whatever the callable returned by `xla_primitive_callable` returns.
  """
  compiled = xla_primitive_callable(prim, **params)
  # TODO(yashkatariya): Investigate adding is_primitive to jit and never
  # triggering the disable jit path instead of messing around with it here.
  # Swap the thread-local disable_jit state to False for the duration of the
  # call; the value handed back by the swap is restored afterwards, even if
  # the call raises.
  saved_state = lib.jax_jit.swap_thread_local_state_disable_jit(False)
  try:
    return compiled(*args)
  finally:
    lib.jax_jit.swap_thread_local_state_disable_jit(saved_state)
|
2023-11-27 18:00:22 -08:00
|
|
|
|
2023-08-31 15:17:57 -07:00
|
|
|
@util.cache()
def xla_primitive_callable(prim: core.Primitive, **params):
  """Return a jitted function that binds `prim` with the given `params`.

  The result is memoized through `util.cache()`.
  """
  def bind_with_params(*operands):
    return prim.bind(*operands, **params)

  # Give the wrapper the primitive's name instead of the local helper's name.
  bind_with_params.__name__ = bind_with_params.__qualname__ = prim.name
  return api.jit(bind_with_params)
|
2023-08-31 15:17:57 -07:00
|
|
|
|
|
|
|
|
2022-08-22 13:56:50 -07:00
|
|
|
def simple_impl(prim):
  """Install `apply_primitive`, partially applied to `prim`, as `prim`'s impl."""
  impl_rule = partial(apply_primitive, prim)
  prim.def_impl(impl_rule)
|
|
|
|
|
2022-04-14 14:18:31 -07:00
|
|
|
RuntimeToken = Any


class RuntimeTokenSet(threading.local):
  """See docstring for effects.py module for the calling convention for tokens."""

  # For each ordered effect, the token returned by the last dispatched
  # computation, sharded over the devices in that computation.
  current_tokens: dict[core.Effect, core.Token]

  # For each device, the runtime token returned by the last dispatched
  # computation on that device.
  output_runtime_tokens: dict[Device, RuntimeToken]

  def __init__(self):
    self.current_tokens, self.output_runtime_tokens = {}, {}

  def get_token_input(
      self, eff: core.Effect, devices: list[Device]
  ) -> core.Token:
    """Return the input token for effect `eff`, placed on `devices`.

    If a `core.Token` is already recorded for `eff`, it is put onto `devices`
    and returned; the recorded entry itself is left untouched. Otherwise a
    fresh token is created, recorded for `eff`, and returned.
    """
    existing = self.current_tokens.get(eff, np.zeros(0, np.bool_))

    if not isinstance(existing, core.Token):
      # We only use replicated sharding for the first time when the token for
      # the order effect hasn't been created.
      replicated = jax.sharding.GSPMDSharding.get_replicated(devices)
      fresh_token = core.Token(
          pxla.shard_args([replicated], [None], [existing])[0])
      self.current_tokens[eff] = fresh_token
      return fresh_token

    # The order of devices may change, so we need to reshard if necessary.
    # TODO(yueshengys): This might still be buggy in a multi-process SPMD
    # scenario. Revise the logic later. A distributed shutdown barrier inside
    # the XLA program may be needed.
    return jax.device_put(existing, jax.sharding.PositionalSharding(devices))

  def set_token_result(self, eff: core.Effect, token: core.Token):
    """Record `token` as the current token for effect `eff`."""
    self.current_tokens[eff] = token

  def set_output_runtime_token(self, device: Device, token: RuntimeToken):
    """Record `token` as the latest runtime token for `device`."""
    # We're free to clobber the previous output token because on each
    # device we have a total ordering of computations. Only the token
    # from the latest computation matters.
    self.output_runtime_tokens[device] = token

  def clear(self):
    """Drop every recorded token by rebinding both maps to empty dicts."""
    self.current_tokens = {}
    self.output_runtime_tokens = {}

  def block_until_ready(self):
    """Call `block_until_ready()` on every recorded token, then `clear()`."""
    for effect_token in self.current_tokens.values():
      effect_token.block_until_ready()
    for device_token in self.output_runtime_tokens.values():
      device_token.block_until_ready()
    self.clear()
|
2022-04-14 14:18:31 -07:00
|
|
|
|
|
|
|
# Module-level RuntimeTokenSet instance. RuntimeTokenSet derives from
# threading.local, so attribute state is kept per thread.
runtime_tokens: RuntimeTokenSet = RuntimeTokenSet()


@atexit.register
def wait_for_tokens():
  """atexit hook: block on the tokens recorded in `runtime_tokens`."""
  runtime_tokens.block_until_ready()
|
2021-11-22 08:22:10 -08:00
|
|
|
|
2023-08-08 06:33:35 -07:00
|
|
|
|
2021-12-13 21:51:08 -08:00
|
|
|
@contextlib.contextmanager
def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None):
  """Context manager that times the enclosed block and reports the duration.

  Args:
    fmt: format string; it is formatted with the keyword arguments `fun_name`
      and `elapsed_time` (seconds).
    fun_name: name substituted into `fmt`.
    event: if not None, the elapsed time is also passed to
      `record_event_duration_secs` under this event name.

  The message is logged at WARNING when `config.log_compiles` is set and at
  DEBUG otherwise. When the module-level `_on_exit` flag is set, nothing is
  timed, logged or recorded. There is no try/finally around the `yield`, so
  nothing is reported if the enclosed block raises.
  """
  if _on_exit:
    yield
    return

  level = logging.WARNING if config.log_compiles.value else logging.DEBUG
  began = time.time()
  yield
  elapsed_time = time.time() - began

  if logger.isEnabledFor(level):
    message = fmt.format(fun_name=fun_name, elapsed_time=elapsed_time)
    logger.log(level, message)
  if event is not None:
    record_event_duration_secs(event, elapsed_time)
|
2021-12-13 21:51:08 -08:00
|
|
|
|
|
|
|
|
2023-08-08 06:33:35 -07:00
|
|
|
def should_tuple_args(num_args: int, platform: str) -> bool:
  """Return whether `num_args` arguments should be tupled on `platform`.

  Only the "tpu" platform ever asks for tupling, and only when there are more
  than 2000 arguments.
  """
  # CPU and GPU do not need tuples as they use host-side data structures that
  # do not have small bounds.
  # TPU only needs a tuple for very long lists
  return platform == "tpu" and num_args > 2000
|
2022-08-19 04:57:07 -07:00
|
|
|
|
2023-08-08 06:33:35 -07:00
|
|
|
def jaxpr_has_primitive(jaxpr: core.Jaxpr, prim_name: str) -> bool:
  """Whether there is a primitive given by user anywhere inside a Jaxpr.

  An equation matches when `prim_name` is a substring of its primitive's name.
  The jaxpr's own equations are checked first, then the jaxprs yielded by
  `core.subjaxprs` are searched recursively.
  """
  if any(prim_name in eqn.primitive.name for eqn in jaxpr.eqns):
    return True
  return any(jaxpr_has_primitive(inner, prim_name)
             for inner in core.subjaxprs(jaxpr))
|
|
|
|
|
2022-10-10 22:08:06 -07:00
|
|
|
|
2023-12-08 16:31:11 -08:00
|
|
|
# Use this registry with caution. It will void the guarantee that lowering to
# stablehlo is oblivious of physical devices.
prim_requires_devices_during_lowering: set[core.Primitive] = set()


@util.weakref_lru_cache
def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr) -> bool:
  """Whether any equation in `jaxpr` uses a primitive from the registry above.

  The jaxpr's own equations are checked first; the jaxprs yielded by
  `core.subjaxprs` are then searched recursively. Results are memoized by
  `util.weakref_lru_cache`.
  """
  if any(eqn.primitive in prim_requires_devices_during_lowering
         for eqn in jaxpr.eqns):
    return True
  return any(jaxpr_has_prim_requiring_devices(inner)
             for inner in core.subjaxprs(jaxpr))
|
|
|
|
|
|
|
|
|
2023-02-10 15:36:04 -08:00
|
|
|
class SourceInfo(NamedTuple):
  """Pairs an equation's source information with a name for that equation.

  In this file it is built as `SourceInfo(eqn.source_info, eqn.primitive.name)`.
  """
  # The `source_info` attribute of the originating equation.
  source_info: source_info_util.SourceInfo
  # Name of the equation's primitive.
  eqn_name: str
|
|
|
|
|
|
|
|
|
2024-10-03 20:47:21 -07:00
|
|
|
@util.weakref_lru_cache
def get_intermediate_shardings(
    jaxpr: core.Jaxpr) -> Sequence[tuple[Sharding, SourceInfo]]:
  """Collect the shardings referenced by equations inside `jaxpr`.

  Each entry of the result is `(sharding, SourceInfo)`, where the SourceInfo
  carries the equation's source info and primitive name. Contributions:

  * sharding_constraint: its 'sharding' param, unless that is a NamedSharding
    over an AbstractMesh.
  * pjit: every entry of 'in_shardings', then of 'out_shardings'.
  * shard_map: a NamedSharding per entry of 'in_names' and 'out_names', unless
    the mesh's `_is_jax_device_mesh` is falsy.
  * device_put: entries of 'devices' that are Shardings with a memory kind set.

  Shardings found in the jaxprs yielded by `core.subjaxprs` are appended after
  those of this jaxpr's own equations. Results are memoized by
  `util.weakref_lru_cache`.
  """
  # Function-level imports, presumably to avoid an import cycle.
  from jax._src import pjit
  from jax.experimental import shard_map

  def _names_to_pspec(names):
    # `names` maps dimension index -> mesh axis names; dimensions missing from
    # the mapping become None entries in the PartitionSpec.
    ndmin = max(names) + 1 if names else 0
    return PartitionSpec(*(names.get(i) for i in range(ndmin)))

  collected = []
  for eqn in jaxpr.eqns:
    prim = eqn.primitive
    if prim is pjit.sharding_constraint_p:
      constraint = eqn.params['sharding']
      if (isinstance(constraint, NamedSharding)
          and isinstance(constraint.mesh, AbstractMesh)):
        continue
      collected.append((constraint, SourceInfo(eqn.source_info, prim.name)))
    elif prim is pjit.pjit_p:
      info = SourceInfo(eqn.source_info, prim.name)
      collected.extend(
          (sharding, info)
          for sharding in itertools.chain(eqn.params['in_shardings'],
                                          eqn.params['out_shardings']))
    elif prim is shard_map.shard_map_p:
      mesh = eqn.params['mesh']
      if not mesh._is_jax_device_mesh:
        continue
      info = SourceInfo(eqn.source_info, prim.name)
      all_names = [*eqn.params['in_names'], *eqn.params['out_names']]
      collected.extend(
          (NamedSharding(mesh, _names_to_pspec(names)), info)
          for names in all_names)
    elif prim is device_put_p:
      info = SourceInfo(eqn.source_info, prim.name)
      collected.extend(
          (target, info) for target in eqn.params['devices']
          if isinstance(target, Sharding) and target.memory_kind is not None)

  for inner in core.subjaxprs(jaxpr):
    collected.extend(get_intermediate_shardings(inner))
  return collected
|
2022-10-10 22:08:06 -07:00
|
|
|
|
|
|
|
|
2022-03-30 17:52:55 -07:00
|
|
|
def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:
  """Whether `jaxpr` involves `core.bint` in its inputs or in output shapes.

  True if an input variable with a `core.UnshapedArray` aval has a dtype whose
  type is `core.bint`, or if some equation output (in this jaxpr or in one
  yielded by `core.subjaxprs`) with a `core.DShapedArray` aval has a shape
  entry accepted by `_is_bint_axis_size`.
  """
  for invar in jaxpr.invars:
    if (isinstance(invar.aval, core.UnshapedArray)
        and type(invar.aval.dtype) is core.bint):
      return True

  for sub in itertools.chain([jaxpr], core.subjaxprs(jaxpr)):
    for eqn in sub.eqns:
      for outvar in eqn.outvars:
        if not isinstance(outvar.aval, core.DShapedArray):
          continue
        if any(_is_bint_axis_size(dim) for dim in outvar.aval.shape):
          return True
  return False
|
|
|
|
|
|
|
|
def _is_bint_axis_size(d: core.AxisSize) -> bool:
  """Whether axis size `d` is a `core.DArray` or `core.Var` with bint dtype.

  Any other kind of axis size yields False.
  """
  if isinstance(d, core.DArray):
    assert not d.shape
    return type(d.dtype) is core.bint
  if isinstance(d, core.Var):
    aval = d.aval
    return isinstance(aval, core.DShapedArray) and type(aval.dtype) is core.bint
  return False
|
2022-03-30 17:52:55 -07:00
|
|
|
|
2021-11-22 08:22:10 -08:00
|
|
|
|
|
|
|
# We can optionally set a Jaxpr rewriter that can be applied just before
# compilation. This mechanism is used for compiling id_tap, we can
# remove it once we bring the id_tap implementation into the core.
outfeed_rewriter: Callable[[core.Jaxpr], core.Jaxpr] | None = None


def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr:
  """Apply the module-level `outfeed_rewriter` to `jaxpr` if one is set.

  Returns `jaxpr` itself when `outfeed_rewriter` is None.
  """
  rewriter = outfeed_rewriter
  return jaxpr if rewriter is None else rewriter(jaxpr)
|
|
|
|
|
|
|
|
|
2023-08-08 06:33:35 -07:00
|
|
|
def check_arg(arg: Any):
  """Validate that `arg` is a `core.Tracer` or a valid JAX type.

  Raises:
    TypeError: if `arg` is neither.
  """
  if isinstance(arg, core.Tracer) or core.valid_jaxtype(arg):
    return
  raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid "
                  "JAX type.")
|
|
|
|
|
|
|
|
|
2023-08-08 10:51:38 -07:00
|
|
|
def jaxpr_replicas(jaxpr: core.Jaxpr) -> int:
  """The number of replicas needed for a jaxpr.

  For a eqn, multiply the `axis_size` with the `jaxpr_replicas` of the
  subjaxprs. For a list of eqns, take the maximum number of replicas.
  A jaxpr without equations needs 1 replica.
  """
  return max((_eqn_replicas(eqn) for eqn in jaxpr.eqns), default=1)
|
2021-11-22 08:22:10 -08:00
|
|
|
|
2022-08-30 14:59:34 -07:00
|
|
|
# TODO(mattjj): this function assumes that only pmap has a parameter named
# axis_size, and that it corresponds to cross-replica mapping
def _eqn_replicas(eqn: core.JaxprEqn) -> int:
  """The number of replicas needed for a single equation.

  * With a truthy "call_jaxpr" param: the 'axis_size' param (default 1) times
    the replicas of that jaxpr.
  * For primitives in `xla.initial_style_primitives`: delegated to
    `_initial_style_primitive_replicas`.
  * Otherwise 1.
  """
  call_jaxpr = eqn.params.get("call_jaxpr")
  if call_jaxpr:
    return eqn.params.get('axis_size', 1) * jaxpr_replicas(call_jaxpr)
  if eqn.primitive in xla.initial_style_primitives:
    return _initial_style_primitive_replicas(eqn.params)
  return 1
|
|
|
|
|
2023-08-08 10:51:38 -07:00
|
|
|
def _initial_style_primitive_replicas(params: dict[str, Any]) -> int:
  """Largest `jaxpr_replicas` result over `params`, or 1 if there is none.

  `core.traverse_jaxpr_params` is used to apply `jaxpr_replicas` across
  `params`; the maximum of the resulting values is returned.
  """
  replicas_by_param = core.traverse_jaxpr_params(jaxpr_replicas, params)
  return max(replicas_by_param.values(), default=1)
|
2021-11-22 08:22:10 -08:00
|
|
|
|
2023-08-08 10:51:38 -07:00
|
|
|
def needs_check_special() -> bool:
  """Whether either the `debug_infs` or the `debug_nans` config flag is set."""
  return config.debug_infs.value or config.debug_nans.value
|
2021-11-22 08:22:10 -08:00
|
|
|
|
2023-08-08 10:51:38 -07:00
|
|
|
def check_special(name: str, bufs: Sequence[basearray.Array]) -> None:
  """Run `_check_special` on every buffer when nan/inf debugging is enabled.

  Does nothing unless `needs_check_special()` is true.

  Args:
    name: label forwarded to `_check_special` for its error messages.
    bufs: buffers to inspect; each one's own `dtype` is used for the check.
  """
  if not needs_check_special():
    return
  for buf in bufs:
    _check_special(name, buf.dtype, buf)
|
2021-11-22 08:22:10 -08:00
|
|
|
|
2023-08-08 10:51:38 -07:00
|
|
|
def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
|
2022-09-23 11:40:01 -07:00
|
|
|
if dtypes.issubdtype(dtype, np.inexact):
|
2023-10-09 07:28:18 -07:00
|
|
|
if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))):
|
2021-11-22 08:22:10 -08:00
|
|
|
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
|
2023-10-09 07:28:18 -07:00
|
|
|
if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))):
|
2021-11-22 08:22:10 -08:00
|
|
|
raise FloatingPointError(f"invalid value (inf) encountered in {name}")
|
|
|
|
|
2024-10-01 10:26:25 -07:00
|
|
|
class CopySemantics(enum.Enum):
  """Enumerates the copy behaviours: ALIAS, COPY and DONATE.

  In this file, `_different_device_order_reshard` passes `donate_argnums=0` to
  jit when given DONATE and `None` otherwise.
  """
  # NOTE(review): presumably "the result may alias the input" — confirm against
  # the callers, which are outside this section.
  ALIAS = enum.auto()
  # NOTE(review): presumably "always produce a fresh copy" — confirm.
  COPY = enum.auto()
  # Selects input donation in `_different_device_order_reshard`.
  DONATE = enum.auto()
|
2023-02-23 15:37:13 -08:00
|
|
|
|
2023-05-09 09:57:33 -07:00
|
|
|
def _identity_fn(x):
  """Return `x` unchanged."""
  return x
|
|
|
|
|
2024-10-01 10:26:25 -07:00
|
|
|
def _different_device_order_reshard(x, target_sharding, copy: CopySemantics):
  """Move `x` to `target_sharding`, which may order the same devices differently.

  When both shardings list their devices in the same order, `x` is passed
  straight through a jitted identity with `out_shardings=target_sharding`.
  Otherwise the input's HLO sharding is rewritten so that its tile assignment
  indexes into the target's device order, a new array is assembled from the
  existing per-device arrays of `x` under that sharding, and that array is
  passed through the jitted identity instead.

  Args:
    x: array to move; must not be deleted.
    target_sharding: destination sharding over the same device set as `x`.
    copy: `CopySemantics.DONATE` donates the jitted identity's argument.

  Raises:
    ValueError: if the two shardings do not cover the same set of devices.
  """
  x._check_if_deleted()
  inp_sharding = x.sharding
  src_devices = inp_sharding._device_assignment
  dst_devices = target_sharding._device_assignment
  donate_argnums = 0 if copy == CopySemantics.DONATE else None

  def _put(arr):
    return api.jit(_identity_fn, out_shardings=target_sharding,
                   donate_argnums=donate_argnums)(arr)

  if src_devices == dst_devices:
    return _put(x)

  if inp_sharding.device_set != target_sharding.device_set:
    inp_ids = [d.id for d in src_devices]
    inp_plat = src_devices[0].platform.upper()
    target_ids = [d.id for d in dst_devices]
    target_plat = dst_devices[0].platform.upper()
    raise ValueError("Input and target sharding should have the same set of "
                     f"devices. Got input's device set ids: {inp_ids} on "
                     f"platform {inp_plat} and target sharding's device set "
                     f"ids: {target_ids} on platform {target_plat}")

  old_hlo_sharding = inp_sharding._to_xla_hlo_sharding(x.ndim)
  if old_hlo_sharding.is_replicated():
    # A replicated sharding does not depend on the device order.
    new_hlo_sharding = old_hlo_sharding
  else:
    # permute_order[i] is the position in the target's device order of the
    # i-th device of the input's device order.
    permute_order = np.vectorize(dst_devices.index, otypes=[int])(src_devices)
    old_tile_devices = old_hlo_sharding.tile_assignment_devices()
    # Unfortunately need to fall back to V1 sharding here.
    new_op_sharding = old_hlo_sharding.to_proto()
    new_op_sharding.iota_reshape_dims = []
    new_op_sharding.iota_transpose_perm = []
    new_op_sharding.tile_assignment_devices = np.take(
        permute_order, old_tile_devices)
    new_hlo_sharding = xc.HloSharding.from_proto(new_op_sharding)
    # TODO(yashkatariya): Enable this when HloSharding conversion is fixed in
    # XLA.
    # assert (new_op_sharding.tile_assignment_dimensions
    #         == new_hlo_sharding.tile_assignment_dimensions())
    # assert (new_op_sharding.tile_assignment_devices
    #         == new_hlo_sharding.tile_assignment_devices())
    # Both tile assignments must name the same physical devices.
    assert (list(np.take(src_devices, old_tile_devices))
            == list(np.take(dst_devices,
                            new_op_sharding.tile_assignment_devices)))

  relabeled_sharding = GSPMDSharding(
      dst_devices, new_hlo_sharding, memory_kind=target_sharding.memory_kind)
  new_x = array.make_array_from_single_device_arrays(
      x.shape, relabeled_sharding, x._arrays)
  return _put(new_x)
|
2023-05-09 09:57:33 -07:00
|
|
|
|
2023-03-09 16:18:31 -08:00
|
|
|
|
2024-06-17 10:16:38 -07:00
|
|
|
@dataclasses.dataclass(frozen=True)
class _DeferredShardArg:
  """Placeholder for a value whose `pxla.shard_args` call is postponed.

  The per-array implementations hand back one of these in place of a finished
  array. `_batched_device_put_impl` gathers every such placeholder and
  resolves all of them with one `shard_args` call.
  """

  # Value that still has to be transferred.
  x: Any
  # Sharding the value is destined for.
  s: Sharding
  # Abstract value passed to the result-handler factory.
  aval: core.AbstractValue
  # Passed to the result-handler factory as its `committed` argument.
  committed: bool

  @property
  def result_handler(self):
    """Handler applied to this entry's `shard_args` result."""
    make_handler = pxla.global_aval_to_result_handler
    return make_handler(self.aval, self.s, self.committed)
|
|
|
|
|
|
|
|
|
2024-10-01 10:26:25 -07:00
|
|
|
def _device_put_sharding_impl(x, aval, device, copy):
  """Place one value on `device`, which is a Sharding, a Device or None.

  Returns `x` itself when it already satisfies the request under
  `CopySemantics.ALIAS`, a finished array when the transfer is performed
  here, or a `_DeferredShardArg` that the caller resolves through a batched
  `pxla.shard_args` call.

  Raises:
    ValueError: if `device` is a Sharding that is not fully addressable and
      `x` is not a value that can be moved to it (see the final branch of
      the Sharding case), or if `device` is not a Sharding and `x` is an
      `ArrayImpl` that is not fully addressable.
  """
  from jax.experimental import multihost_utils

  if isinstance(device, Sharding):
    s = device
    already_placed = (getattr(x, 'sharding', None) == s
                      and getattr(x, '_committed', False))
    if already_placed and copy == CopySemantics.ALIAS:
      return x

    x_is_array = isinstance(x, array.ArrayImpl)
    s_addressable = s.is_fully_addressable

    # Neither side is fully addressable.
    if not s_addressable and x_is_array and not x.is_fully_addressable:
      assert isinstance(s, Sharding)
      return _different_device_order_reshard(x, s, copy)

    # Both sides are fully addressable and cover the same devices, but list
    # them in a different order.
    if (s_addressable and x_is_array and x.is_fully_addressable
        and s.num_devices > 1
        and s._internal_device_list != x.sharding._internal_device_list  # pytype: disable=attribute-error
        and s.device_set == x.sharding.device_set):
      assert isinstance(s, Sharding)
      return _different_device_order_reshard(x, s, copy)

    if s_addressable:
      return _DeferredShardArg(x, s, aval, True)

    # From here on `s` is not fully addressable.
    if (x_is_array and not x._committed) or type(x) in array_types:
      # TODO(yashkatariya): Move this check to `jit`.
      multihost_utils.assert_equal(
          x, fail_message=(
              f"{type(x)} passed to device_put is not the same on each"
              " process. Make sure you are passing the same value of"
              f" {type(x)} on each process."))
      donate_argnums = 0 if copy == CopySemantics.DONATE else None
      return api.jit(_identity_fn, out_shardings=s,
                     donate_argnums=donate_argnums)(x)
    # TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array.
    raise ValueError(
        "device_put's second argument must be a Device or a Sharding which"
        f" represents addressable devices, but got {s}. Please pass device or"
        " Sharding which represents addressable devices.")

  # Only `Device` exists below. `Sharding` instance is handled above.
  if isinstance(x, array.ArrayImpl):
    if not x.is_fully_addressable:
      raise ValueError(
          "device_put's first argument must be a fully addressable array, but "
          f"got value with devices {x.devices()}")
    if device is None and copy == CopySemantics.ALIAS:
      return x
    if is_single_device_sharding(x.sharding):
      if device is None:
        # No device requested: stay on the device the array already uses.
        device = x.sharding._device_assignment[0]
      return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x],
                                     [device])

  # The result is committed only when a device was requested explicitly.
  committed = device is not None
  target_device = device if committed else pxla._get_default_device()
  sh = SingleDeviceSharding(target_device)
  return _DeferredShardArg(x, sh, aval, committed)
|
|
|
|
|
2022-06-24 10:04:31 -07:00
|
|
|
|
2024-04-03 16:12:43 -07:00
|
|
|
def _device_put_impl(
    x, *, device: Device | Sharding | Layout | None,
    src: Device | Sharding | Layout | None, copy: CopySemantics):
  """Implement `device_put` for a single value.

  A `Layout` destination is handled here; every other destination is
  delegated to `_device_put_sharding_impl`.

  Raises:
    ValueError: if `device` or `src` is a `TransferToMemoryKind`, or if a
      `Layout` destination's sharding or device-local layout is not concrete.
    TypeError: if `x` cannot be abstractified.
  """
  if any(isinstance(d, TransferToMemoryKind) for d in (device, src)):
    raise ValueError(
        "TransferToMemoryKind argument to jax.device_put can only be used"
        " inside jax.jit. If you are using device_put outside jax.jit, then"
        " please provide a concrete Sharding with memory_kind.")

  try:
    aval = xla.abstractify(x)
  except TypeError as err:
    raise TypeError(
        f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err

  if not isinstance(device, Layout):
    return _device_put_sharding_impl(x, aval, device, copy)

  layout = device
  dll = layout.device_local_layout
  x_dll = x.layout.device_local_layout if hasattr(x, 'layout') else None

  # A fully unspecified layout is treated like a `None` destination.
  if dll is None and layout.sharding is None:
    return _device_put_sharding_impl(x, aval, layout.sharding, copy)

  sharding_is_concrete = isinstance(layout.sharding, Sharding)
  dll_is_concrete = isinstance(dll, (DeviceLocalLayout, type(None)))
  if not (sharding_is_concrete and dll_is_concrete):
    raise ValueError(
        "sharding and device_local_layout in `Layout` instance should be"
        f" concrete. Got layout: {layout} for input {aval.str_short()}")

  if (getattr(x, 'layout', None) == layout and getattr(x, '_committed', False)
      and copy == CopySemantics.ALIAS):
    return x

  # With no device-local layout on either side only the sharding matters.
  if x_dll is None and dll is None:
    return _device_put_sharding_impl(x, aval, layout.sharding, copy)

  donate_argnums = 0 if copy == CopySemantics.DONATE else None
  return api.jit(_identity_fn, out_shardings=layout,
                 donate_argnums=donate_argnums)(x)
|
2024-04-03 16:12:43 -07:00
|
|
|
|
2021-11-22 08:22:10 -08:00
|
|
|
|
2024-06-17 10:16:38 -07:00
|
|
|
def _batched_device_put_impl(
    *xs,
    devices: Sequence[Device | Sharding | Layout | None],
    srcs: Sequence[Device | Sharding | Layout | None],
    copy_semantics: Sequence[CopySemantics]):
  """Implementation rule for `device_put_p`.

  Each value goes through `_device_put_impl` with its matching entry of
  `devices`, `srcs` and `copy_semantics`. Values that come back as
  `_DeferredShardArg` are then resolved together by one `pxla.shard_args`
  call.

  Returns:
    A list with one result per input, in input order.
  """
  ys = [
      _device_put_impl(x, device=device, src=src, copy=cp)
      for x, device, src, cp in zip(xs, devices, srcs, copy_semantics)
  ]
  deferred_indices = [
      i for i, y in enumerate(ys) if isinstance(y, _DeferredShardArg)]

  if deferred_indices:
    deferred_xs = [ys[i].x for i in deferred_indices]
    deferred_shardings = [ys[i].s for i in deferred_indices]
    # Batch shard_arg calls. Helps improve efficiency for backends that support
    # efficient batch transfer.
    # device_put handles `Layout` via a different path, so just pass `None` as
    # the layout here.
    shard_arg_results = pxla.shard_args(
        deferred_shardings, [None] * len(deferred_xs), deferred_xs)
    for i, shard_arg_result in zip(deferred_indices, shard_arg_results):
      placeholder = ys[i]
      assert isinstance(placeholder, _DeferredShardArg)
      ys[i] = placeholder.result_handler(shard_arg_result)

  return ys
|
|
|
|
|
|
|
|
|
2021-11-22 08:22:10 -08:00
|
|
|
# `device_put` is a multiple-results primitive: it takes any number of values
# and its `devices`, `srcs` and `copy_semantics` parameters hold one entry
# per value.
device_put_p = core.Primitive('device_put')
device_put_p.multiple_results = True
device_put_p.def_impl(_batched_device_put_impl)


def _device_put_abstract_eval(*xs, devices, srcs, copy_semantics):
  """Abstract evaluation for `device_put_p`: outputs mirror the inputs."""
  return xs


device_put_p.def_abstract_eval(_device_put_abstract_eval)
|
2024-06-17 10:16:38 -07:00
|
|
|
|
2024-10-01 10:26:25 -07:00
|
|
|
def _device_put_transpose(cts, *_, devices, srcs, copy_semantics):
  """Transpose rule for `device_put_p`.

  Every cotangent that is not an `ad.Zero` is sent back through
  `device_put_p` with the roles of `devices` and `srcs` swapped and with
  `CopySemantics.COPY`, whether the forward op used ALIAS or COPY.

  Args:
    cts: output cotangents, one per primitive output.
    devices: destinations of the forward op, one per cotangent.
    srcs: sources of the forward op, one per cotangent.
    copy_semantics: copy semantics of the forward op, one per cotangent.

  Returns:
    A list as long as `cts`; positions whose cotangent is an `ad.Zero` hold
    `None`.

  Raises:
    ValueError: if a non-zero cotangent belongs to an input that was put with
      `CopySemantics.DONATE`.
  """
  results = [None] * len(cts)
  nonzero = [
      (i, ct, device, src, cp)
      for i, (ct, device, src, cp)
      in enumerate(zip(cts, devices, srcs, copy_semantics))
      if type(ct) is not ad.Zero
  ]
  if not nonzero:
    return results

  # Distinct names keep the filtered tuples apart from the keyword parameters.
  indices, args, ct_devices, ct_srcs, ct_copy_semantics = zip(*nonzero)
  for cp in ct_copy_semantics:
    if cp == CopySemantics.DONATE:
      raise ValueError(
          "donate=True is not allowed during transposition of device_put."
          " Please file an issue if you want this to be supported.")
    assert cp in (CopySemantics.ALIAS, CopySemantics.COPY), cp
  new_copy_semantics = [CopySemantics.COPY] * len(ct_copy_semantics)

  # The transpose moves data in the opposite direction: from the forward
  # destinations back to the forward sources.
  ys = device_put_p.bind(*args, devices=ct_srcs, srcs=ct_devices,
                         copy_semantics=new_copy_semantics)
  for i, y in zip(indices, ys):
    results[i] = y
  return results
ad.primitive_jvps[device_put_p] = partial(ad.linear_jvp, device_put_p)
ad.primitive_transposes[device_put_p] = _device_put_transpose
|
|
|
|
|
|
|
|
def _device_put_batcher(batched_args, batch_dims, **params):
  """Batching rule for `device_put_p`.

  Rebinds the primitive on the batched arguments with unchanged parameters
  and returns the incoming `batch_dims` as the output batch dimensions. All
  mapped arguments are asserted to share one batch dimension.
  """
  mapped_dims = [bd for bd in batch_dims if bd is not batching.not_mapped]
  if mapped_dims:
    first_dim, *other_dims = mapped_dims
    assert all(first_dim == bd for bd in other_dims), batch_dims
  outs = device_put_p.bind(*batched_args, **params)
  return outs, batch_dims
batching.primitive_batchers[device_put_p] = _device_put_batcher
|
|
|
|
|
2024-10-01 10:26:25 -07:00
|
|
|
def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics):
  """Lowering rule for `device_put_p`, registered for 'tpu' and 'gpu'.

  Operands pass through untouched when the module uses only the default
  memory kind. Otherwise an operand whose destination carries a memory kind
  is wrapped with that memory kind, preceded by a sharding op when the
  destination is a `Sharding`.
  """
  # TODO(yashkatariya): Maybe we should add the custom calls anyways if it's
  # being used inside jit? At least for now, this preserves the old behavior.
  if ctx.module_context.all_default_mem_kind:
    return xs

  def lower(x, device, aval, out_aval):
    has_memory_kind = (isinstance(device, (Sharding, TransferToMemoryKind))
                       and device.memory_kind is not None)
    if not has_memory_kind:
      return x
    if isinstance(device, Sharding):
      if config.use_shardy_partitioner.value:
        sharding = device._to_sdy_sharding(aval.ndim)
      else:
        sharding = device._to_xla_hlo_sharding(aval.ndim).to_proto()
      x = mlir.wrap_with_sharding_op(ctx, x, out_aval, sharding)
    return mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)

  return [
      lower(x, device, aval, out_aval)
      for x, device, aval, out_aval
      in zip(xs, devices, ctx.avals_in, ctx.avals_out)
  ]

mlir.register_lowering(
    device_put_p, _tpu_gpu_device_put_lowering, platform='tpu')
mlir.register_lowering(
    device_put_p, _tpu_gpu_device_put_lowering, platform='gpu')
|
2023-09-11 11:54:29 -07:00
|
|
|
|
|
|
|
|
2024-10-01 10:26:25 -07:00
|
|
|
def _common_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics):
  """Lowering rule registered without a platform: return operands unchanged."""
  return xs
mlir.register_lowering(device_put_p, _common_device_put_lowering)
|
2024-05-10 15:34:03 -07:00
|
|
|
|
2024-10-01 10:26:25 -07:00
|
|
|
def _propagate_mem_kind_dp(*xm, devices, srcs, copy_semantics):
  """Return the memory kind each `device_put` destination asks for.

  The result has one entry per element of `devices`: the destination's
  `memory_kind` when it is a `Sharding` or `TransferToMemoryKind`, else None.
  """
  return [
      device.memory_kind
      if isinstance(device, (Sharding, TransferToMemoryKind)) else None
      for device in devices
  ]
|
2024-05-10 15:34:03 -07:00
|
|
|
# Install `_propagate_mem_kind_dp` as the memory-kind propagation rule for
# `device_put_p`.
pxla.memory_kind_propagate_rule[device_put_p] = _propagate_mem_kind_dp
|