2022-12-16 13:06:38 -08:00
|
|
|
|
# Copyright 2021 The JAX Authors.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
2024-05-29 04:03:02 +00:00
|
|
|
|
from collections import defaultdict
|
2024-06-26 14:44:52 -04:00
|
|
|
|
from collections.abc import Callable, Sequence, Iterable
|
2022-12-16 13:06:38 -08:00
|
|
|
|
import dataclasses
|
2024-06-11 12:46:11 -07:00
|
|
|
|
from functools import partial
|
2024-02-15 13:48:49 -08:00
|
|
|
|
import inspect
|
2023-05-12 11:14:53 -07:00
|
|
|
|
import logging
|
2023-06-09 14:43:42 -07:00
|
|
|
|
import operator as op
|
2023-05-26 08:56:56 -07:00
|
|
|
|
import weakref
|
2024-06-26 14:44:52 -04:00
|
|
|
|
from typing import NamedTuple, Any, Union, cast
|
2022-12-16 13:06:38 -08:00
|
|
|
|
import threading
|
|
|
|
|
import warnings
|
|
|
|
|
|
2023-07-21 14:20:39 -04:00
|
|
|
|
import numpy as np
|
|
|
|
|
|
2024-01-22 13:44:34 -08:00
|
|
|
|
from jax._src import api
|
2024-04-04 14:33:06 -04:00
|
|
|
|
from jax._src import ad_util
|
2024-02-15 13:48:49 -08:00
|
|
|
|
from jax._src import api_util
|
2023-10-09 07:28:18 -07:00
|
|
|
|
from jax._src import config
|
2023-02-09 11:02:24 -08:00
|
|
|
|
from jax._src import core
|
2022-12-16 13:06:38 -08:00
|
|
|
|
from jax._src import dispatch
|
2024-02-29 15:30:19 -08:00
|
|
|
|
from jax._src import dtypes
|
2023-02-06 22:51:50 -08:00
|
|
|
|
from jax._src import linear_util as lu
|
2024-01-22 13:44:34 -08:00
|
|
|
|
from jax._src import mesh as mesh_lib
|
2023-04-06 08:31:47 -07:00
|
|
|
|
from jax._src import op_shardings
|
2024-05-29 01:49:06 -07:00
|
|
|
|
from jax._src import profiler
|
2023-04-10 10:15:08 -07:00
|
|
|
|
from jax._src import sharding_impls
|
2022-12-16 13:06:38 -08:00
|
|
|
|
from jax._src import source_info_util
|
2024-01-22 13:44:34 -08:00
|
|
|
|
from jax._src import stages
|
2023-01-17 18:42:21 -08:00
|
|
|
|
from jax._src import traceback_util
|
2024-01-22 13:44:34 -08:00
|
|
|
|
from jax._src import tree_util
|
|
|
|
|
from jax._src import util
|
2023-02-28 07:01:14 -08:00
|
|
|
|
from jax._src import xla_bridge as xb
|
2023-02-06 22:51:50 -08:00
|
|
|
|
from jax._src.api_util import (
|
|
|
|
|
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
|
2024-04-23 17:37:52 -07:00
|
|
|
|
donation_vector, shaped_abstractify, check_callable, resolve_argnums,
|
2024-02-13 16:45:27 -08:00
|
|
|
|
argnames_partial_except, debug_info, result_paths, jaxpr_debug_info,
|
2024-04-23 17:37:52 -07:00
|
|
|
|
hoist_obj_attrs)
|
2023-04-04 11:41:00 -07:00
|
|
|
|
from jax._src.interpreters import partial_eval as pe
|
2023-04-06 11:42:45 -07:00
|
|
|
|
from jax._src.partition_spec import PartitionSpec
|
2023-04-04 11:41:00 -07:00
|
|
|
|
from jax._src.interpreters import xla
|
2023-02-06 22:51:50 -08:00
|
|
|
|
from jax._src.interpreters import ad
|
2023-02-09 15:11:20 -08:00
|
|
|
|
from jax._src.interpreters import batching
|
|
|
|
|
from jax._src.interpreters import mlir
|
2023-02-07 11:16:01 -08:00
|
|
|
|
from jax._src.interpreters import pxla
|
2022-12-16 13:06:38 -08:00
|
|
|
|
from jax._src.lib.mlir import ir
|
|
|
|
|
from jax._src.lib.mlir.dialects import func as func_dialect
|
2024-06-21 13:52:19 -07:00
|
|
|
|
from jax._src.lib import jax_jit
|
2022-12-16 13:06:38 -08:00
|
|
|
|
from jax._src.lib import xla_client as xc
|
2024-09-17 16:10:41 -07:00
|
|
|
|
from jax._src.lib import xla_extension_version
|
2024-06-05 09:06:36 -07:00
|
|
|
|
from jax._src import sharding
|
Introduce `jax.sharding.AbstractMesh(shape_tuple: tuple[tuple[str, int], ...])` and allow `with_sharding_constraint` and `shard_map` to accept an abstract mesh as input (`with_sharding_constraint` is via `NamedSharding(abstract_mesh, pspec)`).
**Semantics**
Inside jit, we never need to talk about concrete devices, so the semantics stay the same as today, i.e., we can lower a NamedSharding with an abstract mesh using only the mesh axis names and sizes and the PartitionSpec. The only restriction is that the number of devices needs to be consistent throughout the program when we are tracing.
During compilation, the order of devices throughout the program needs to be consistent (same as before this change).
Outside jit i.e. eager mode, if a `shard_map` or `with_sharding_constraint` contains AbstractMesh, then the input to those primitives should contain a concrete Mesh with the same shape and names as the abstract mesh.
**Why do this?**
There are cases where you want to change the devices in the mesh but keep the mesh shape the same (axis names and axis sizes). But this leads to a device mismatch error if you have `with_sharding_constraint` or `shard_map` in your computation, because they embed concrete devices in their signature.
So to fix the error, you need to change the mesh in `wsc` and `shmap`, which will lead to a tracing cache miss (because the function id is now different) and consequently a lowering-to-StableHLO cache miss. Explaining via an example:
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
@jax.jit
def f(x):
  y = with_sharding_constraint(x, NamedSharding(mesh1, P('x')))
  return y * 2
f(arr_mesh1)
f(arr_mesh2) # DEVICE MISMATCH ERROR!
```
The same problem exists for `shard_map` since it takes a mesh with concrete devices in its signature.
**Okay, so how do you fix this?**
As mentioned above, we need the above program to work and get tracing and lowering cache hits (**cache hits are the most important** part here).
The approach in this change, allows `with_sharding_constraint` to accept a `NamedSharding(abstract_mesh, pspec)` as input. This leads to no errors downstream and we get tracing and lowering cache hits since we don't encode the concrete devices anymore. Just the axis_names and axis_size of the mesh.
**The important part is that the concrete device information should only come from the arguments. Inside `jax.jit`, you should never reference concrete devices ever.**
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(mesh1.shape_tuple)
@jax.jit
def f(x):
  y = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
  return y * 2
f(arr_mesh1)
f(arr_mesh2) # tracing and lowering cache hit
```
**One caveat is that this only works with `jax.NamedSharding` but that's fine because `NamedSharding` is the most used `Sharding` in JAX.**
**What about `shard_map`?**
shard_map's signature will be: `shmap(f, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs)`.
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(mesh1.shape_tuple)
@jax.jit
def f(x):
  y = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), out_specs=P('x'))(x)
  return y * 2
f(arr_mesh1)
f(arr_mesh2) # tracing and lowering cache hit
```
This is a fully backwards-compatible change. Your current code will continue to work as is, but you can opt into this new behavior and get all the benefits!
PiperOrigin-RevId: 662670932
2024-08-13 15:17:30 -07:00
|
|
|
|
from jax._src.mesh import AbstractMesh
|
2023-04-04 11:41:00 -07:00
|
|
|
|
from jax._src.sharding_impls import (
|
2024-06-18 11:31:09 -04:00
|
|
|
|
NamedSharding, GSPMDSharding,
|
2024-03-26 13:28:03 -07:00
|
|
|
|
SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
|
2024-07-24 12:39:42 -07:00
|
|
|
|
ParsedPartitionSpec, get_single_pspec, is_unspecified,
|
2023-08-04 16:26:31 -07:00
|
|
|
|
is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
|
2024-04-05 20:08:48 -07:00
|
|
|
|
from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout
|
2024-04-04 14:33:06 -04:00
|
|
|
|
from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef
|
2023-02-06 22:51:50 -08:00
|
|
|
|
from jax._src.traceback_util import api_boundary
|
2023-04-04 11:41:00 -07:00
|
|
|
|
from jax._src.tree_util import (
|
2024-05-22 23:30:55 -04:00
|
|
|
|
tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, tree_leaves,
|
2024-03-21 08:59:28 -07:00
|
|
|
|
treedef_children, broadcast_prefix, all_leaves, prefix_errors, keystr,
|
2024-04-05 20:08:48 -07:00
|
|
|
|
PyTreeDef, none_leaf_registry as none_lr)
|
2022-12-16 13:06:38 -08:00
|
|
|
|
from jax._src.util import (
|
2023-01-13 12:53:42 -08:00
|
|
|
|
HashableFunction, safe_map, safe_zip, wraps,
|
2023-04-10 10:15:08 -07:00
|
|
|
|
distributed_debug_log, split_list, weakref_lru_cache,
|
2024-07-03 16:38:18 -04:00
|
|
|
|
merge_lists, subs_list, fun_name, fun_qual_name)
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
2023-03-22 20:54:45 -07:00
|
|
|
|
# Within this module, `map`/`zip` refer to the `safe_*` variants from
# `jax._src.util`; the builtins remain reachable as `unsafe_map`/`unsafe_zip`.
unsafe_map, unsafe_zip = map, zip
map, zip = safe_map, safe_zip

# Register this file as excluded with `traceback_util` (JAX's traceback
# filtering machinery).
traceback_util.register_exclusion(__file__)

# Sharding types used by pjit internals. The `MinusUnspecified` aliases drop
# the `UnspecifiedValue` member from the corresponding union.
PjitSharding = Union[GSPMDSharding, UnspecifiedValue, AUTO]
PjitShardingMinusUnspecified = Union[GSPMDSharding, AUTO]
MeshSharding = Union[NamedSharding, UnspecifiedValue, AUTO]
MeshShardingMinusUnspecified = Union[NamedSharding, AUTO]

logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
2023-02-10 13:53:43 -08:00
|
|
|
|
def _find_arg_mismatch(arg_list, fails, fun_name):
|
|
|
|
|
mismatched_args_msg = []
|
2023-12-20 17:23:49 -08:00
|
|
|
|
def mismatch(err):
|
|
|
|
|
for name, inp_da, aval in arg_list:
|
|
|
|
|
if err.m_type == pxla.MismatchType.ARG_SHARDING and err.da == inp_da:
|
2023-02-10 13:53:43 -08:00
|
|
|
|
mismatched_args_msg.append(
|
2023-07-21 14:20:39 -04:00
|
|
|
|
f"argument {name} of {fun_name} with shape {aval.str_short()} and "
|
2023-12-20 17:23:49 -08:00
|
|
|
|
f"{err._dev_ids_plat_str}")
|
2023-02-10 13:53:43 -08:00
|
|
|
|
break
|
2023-12-20 17:23:49 -08:00
|
|
|
|
first_err, second_err = fails
|
|
|
|
|
mismatch(first_err)
|
|
|
|
|
mismatch(second_err)
|
2023-02-10 13:53:43 -08:00
|
|
|
|
return mismatched_args_msg
|
|
|
|
|
|
2023-03-21 08:39:46 -07:00
|
|
|
|
|
|
|
|
|
def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name,
                                      arg_names):
  """Builds the user-facing message for a device-assignment mismatch.

  Args:
    fun_name: name of the jitted function.
    fails: the two mismatch errors describing the conflict.
    args_flat: flattened arguments of the call.
    api_name: API name used in the message (e.g. ``'jit'``); the message
      appends ``"ted"`` to it.
    arg_names: names aligned with ``args_flat``, or None to use empty names.

  Returns:
    The error message string.
  """
  if arg_names is None:
    arg_names = [''] * len(args_flat)

  arg_list = []
  for arg, name in zip(args_flat, arg_names):
    # Arguments without a `sharding` attribute have no device assignment.
    sharding = getattr(arg, 'sharding', None)
    da = None if sharding is None else sharding._device_assignment
    arg_list.append((name, da, shaped_abstractify(arg)))

  arg_msgs = _find_arg_mismatch(arg_list, fails, fun_name)
  first, second = fails
  if len(arg_msgs) == 2:
    detail = f" Got {arg_msgs[0]} and {arg_msgs[1]}"
  elif len(arg_msgs) == 1:
    # Report the failure that is not already covered by ARG_SHARDING.
    other = second if first.m_type == pxla.MismatchType.ARG_SHARDING else first
    detail = f" Got {arg_msgs[0]} and{other._str(api_name)}"
  else:
    detail = f" Got{first._str(api_name)} and{second._str(api_name)}"
  return f"Received incompatible devices for {api_name}ted computation.{detail}"
|
|
|
|
|
|
|
|
|
|
|
2024-03-21 05:35:44 -07:00
|
|
|
|
class PjitInfo(NamedTuple):
  """Things that we know about a jit instance before it is called.

  In other words, this structure contains arguments to jit()/pjit(),
  preprocessed and validated.

  Instances are used as cache keys, so they hash and compare by identity
  (``__hash__``, ``__eq__`` and ``__ne__`` below) rather than by field values.
  """
  fun_sourceinfo: str | None
  fun_signature: inspect.Signature | None
  # Shardings, as specified by the user. These can either be UNSPECIFIED or they
  # can be a tree (prefix) of shardings or None. They are stored flattened as a
  # treedef plus a tuple of leaves.
  user_specified_in_shardings: bool
  in_shardings_treedef: PyTreeDef
  in_shardings_leaves: tuple[Any, ...]
  out_shardings_treedef: PyTreeDef
  out_shardings_leaves: tuple[Any, ...]
  # Layouts, stored flattened in the same treedef-plus-leaves form.
  in_layouts_treedef: PyTreeDef
  in_layouts_leaves: tuple[Any, ...]
  out_layouts_treedef: PyTreeDef
  out_layouts_leaves: tuple[Any, ...]
  static_argnums: tuple[int, ...]
  static_argnames: tuple[str, ...]
  donate_argnums: tuple[int, ...]
  donate_argnames: tuple[str, ...]
  device: xc.Device | None
  backend: str | None
  keep_unused: bool
  inline: bool
  abstracted_axes: Any | None
  use_resource_env: bool  # False for jit, True for pjit

  # Hash and compare PjitInfo by identity when used as a cache key.
  def __hash__(self):
    return id(self)

  def __eq__(self, other):
    return self is other

  def __ne__(self, other):
    # `tuple` defines its own element-wise `__ne__`, and subclasses inherit it
    # even when they override `__eq__`. Without this override, `a != b` would
    # compare fields while `a == b` compares identity, so two distinct but
    # field-equal instances would be neither equal nor unequal.
    return self is not other
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _python_pjit_helper(fun, jit_info, *args, **kwargs):
  """Runs one jit/pjit call through the Python path.

  Infers the call's params and validates every flattened argument. If any
  attributes are tracked, their current values are prepended to the inputs.
  The call then binds ``pjit_p`` and writes any updated attribute values back.

  Returns:
    ``(outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked)``.
    ``args_flat`` includes any prepended attribute states; ``out_flat`` has the
    final attribute states split off.

  Raises:
    ValueError: if binding reports a device-assignment mismatch.
    TypeError: if binding rejects an argument as an invalid input.
  """
  p, args_flat = _infer_params(fun, jit_info, args, kwargs)
  for flat_arg in args_flat:
    dispatch.check_arg(flat_arg)

  tracked = p.attrs_tracked
  if tracked:
    # Current attribute values become extra leading inputs of the call.
    args_flat = [*_get_states(tracked), *args_flat]

  try:
    out_flat = pjit_p.bind(*args_flat, **p.params)
  except pxla.DeviceAssignmentMismatchError as e:
    fails, = e.args
    api_name = 'pjit' if p.params['resource_env'] is not None else 'jit'
    fn_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
    raise ValueError(_device_assignment_mismatch_error(
        fn_name, fails, args_flat, api_name, p.arg_names)) from None
  except xla.InvalidInputException as e:
    names = p.arg_names if p.arg_names is not None else [''] * len(args_flat)
    # With consts in the jaxpr, re-raise the original message. Otherwise run
    # canonicalization again on each argument to figure out which one failed.
    if p.params['jaxpr'].consts:
      raise TypeError(e.args[0]) from e
    for arg, name, aval in zip(args_flat, names, p.in_avals):
      try:
        xla.canonicalize_dtype(arg)
      except xla.InvalidInputException:
        # Reraise as TypeError with the new message.
        raise TypeError(
            f"Argument '{name}' of shape {aval.str_short()} of type"
            f' {type(arg)} is not a valid JAX type.') from e
    raise AssertionError("Unreachable") from e

  if tracked:
    # The leading outputs are the final attribute states; write them back.
    n_state_outs = sum(end_tree.num_leaves for _, end_tree, _ in tracked)
    final_states, out_flat = split_list(out_flat, [n_state_outs])
    _set_states(tracked, final_states)

  return (tree_unflatten(p.out_tree, out_flat), out_flat, p.out_tree,
          args_flat, p.params['jaxpr'], tracked)
|
2024-01-25 22:20:36 -08:00
|
|
|
|
|
2024-03-21 17:45:44 -07:00
|
|
|
|
|
2024-01-25 22:20:36 -08:00
|
|
|
|
def _set_states(attrs_tracked, vals):
|
2024-05-17 09:46:36 +01:00
|
|
|
|
from jax.experimental.attrs import jax_setattr
|
2024-05-22 23:30:55 -04:00
|
|
|
|
valss = split_list(vals, [td.num_leaves for _, td, _ in attrs_tracked[:-1]])
|
|
|
|
|
for ((_, treedef, (obj, attr)), leaves) in zip(attrs_tracked, valss):
|
|
|
|
|
val = tree_unflatten(treedef, leaves)
|
2024-01-25 22:20:36 -08:00
|
|
|
|
jax_setattr(obj, attr, val)
|
|
|
|
|
|
|
|
|
|
def _get_states(attrs_tracked):
|
2024-05-17 09:46:36 +01:00
|
|
|
|
from jax.experimental.attrs import jax_getattr
|
2024-05-22 23:30:55 -04:00
|
|
|
|
vals = []
|
|
|
|
|
for treedef, _, (obj, attr) in attrs_tracked:
|
|
|
|
|
tree = jax_getattr(obj, attr)
|
|
|
|
|
leaves, treedef_ = tree_flatten(tree)
|
|
|
|
|
assert treedef == treedef_
|
|
|
|
|
vals.extend(leaves)
|
|
|
|
|
return vals
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
2024-05-29 01:49:06 -07:00
|
|
|
|
def _need_to_rebuild_with_fdo(pgle_profiler):
|
|
|
|
|
return (pgle_profiler is not None and pgle_profiler.is_enabled()
|
|
|
|
|
and not pgle_profiler.is_fdo_consumed())
|
2023-02-10 13:53:43 -08:00
|
|
|
|
|
[mutable-arrays] allow state effects in jit by building in run_state
with help from @sharadmv, @yashkatariya, @dougalm, and others
The basic strategy is to apply discharge_state when lowering a jaxpr with state
effects to HLO, and update the dispatch path accordingly. Specifically:
1. in tests only for now, introduce a MutableArray data type;
2. teach jit to abstract it to a Ref(ShapedArray) type, register an input
handler, etc;
3. call discharge_state in `lower_sharding_computation` to lower a jaxpr with
refs to a jaxpr (and then to an HLO) with extra outputs, and set up aliasing;
4. teach the output side of the dispatch path to drop those outputs.
As an alternative to (3), we could potentially lower away the effects at a
higher level, like in _pjit_lower_cached. They are similar because
_pjit_lower_cached is the only (non-xmap) caller of lower_sharding_computation.
I decided to do it in lower_sharding_computation mainly because that's closer
to where we set up aliases, and I wanted to make mutable arrays correspond to
aliased inputs/outputs on the XLA computation.
2024-02-26 14:46:05 -08:00
|
|
|
|
def _get_fastpath_data(
    executable, out_tree, args_flat, out_flat, attrs_tracked, effects,
    consts, abstracted_axes, pgle_profiler
) -> pxla.MeshExecutableFastpathData | None:
  """Returns ``pxla.MeshExecutableFastpathData`` for this call, or None.

  None is returned (fast path not usable) unless all of these hold:
    * ``executable`` is a ``pxla.MeshExecutable`` whose ``unsafe_call`` is a
      ``pxla.ExecuteReplicated``;
    * that call has no ordered/unordered effects and no host callbacks;
    * every (re-flattened) output is an ``xc.ArrayImpl``;
    * ``abstracted_axes`` is None and no attributes are tracked;
    * no effect in ``effects`` is a ``RefEffect``;
    * PRNG key-reuse debugging is off, or no PRNG key appears among the
      arguments, outputs or consts;
    * ``_need_to_rebuild_with_fdo(pgle_profiler)`` is false.
  """
  out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(
      out_tree, out_flat)

  def key_reuse_checking_needed():
    if not config.debug_key_reuse.value:
      return False
    return any(
        hasattr(arg, 'dtype') and dtypes.issubdtype(arg.dtype, dtypes.prng_key)
        for arg in (*args_flat, *out_flat, *consts))

  eligible = (
      # isinstance() is False for None, so no separate None check is needed.
      isinstance(executable, pxla.MeshExecutable)
      and isinstance(executable.unsafe_call, pxla.ExecuteReplicated)
      # No effects in computation.
      and not executable.unsafe_call.ordered_effects
      and not executable.unsafe_call.has_unordered_effects
      and not executable.unsafe_call.has_host_callbacks
      and all(isinstance(x, xc.ArrayImpl) for x in out_reflattened)
      and abstracted_axes is None
      # No attr state effects.
      and not attrs_tracked
      # No ref state effects.
      and not any(isinstance(e, RefEffect) for e in effects)
      # No PRNG reuse checking.
      and not key_reuse_checking_needed()
      and not _need_to_rebuild_with_fdo(pgle_profiler)
  )
  if not eligible:
    return None

  kept_idx = executable._kept_var_idx
  in_shardings = []
  for s, aval in zip(executable._in_shardings, executable.in_avals):
    # Extended-dtype inputs are dispatched with their physical sharding.
    needs_physical = (aval is not core.abstract_token
                      and dtypes.issubdtype(aval.dtype, dtypes.extended))
    in_shardings.append(
        sharding_impls.physical_sharding(aval, s) if needs_physical else s)

  return pxla.MeshExecutableFastpathData(
      executable.xla_executable, out_tree, in_shardings,
      executable._out_shardings,
      [o.aval for o in out_reflattened],
      [o._committed for o in out_reflattened],
      [i in kept_idx for i in range(len(args_flat))],
      executable._dispatch_in_layouts)
|
|
|
|
|
|
|
|
|
|
|
2022-12-16 13:06:38 -08:00
|
|
|
|
class _MostRecentPjitCallExecutable(threading.local):
|
|
|
|
|
def __init__(self):
|
2023-06-05 10:06:30 -07:00
|
|
|
|
self.weak_key_dict = weakref.WeakKeyDictionary()
|
2024-05-29 01:49:06 -07:00
|
|
|
|
self.weak_pgle_profiler_dict = weakref.WeakKeyDictionary()
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
|
|
|
|
_most_recent_pjit_call_executable = _MostRecentPjitCallExecutable()
|
|
|
|
|
|
2023-01-03 14:05:17 -08:00
|
|
|
|
|
2023-06-05 10:06:30 -07:00
|
|
|
|
def _read_most_recent_pjit_call_executable(jaxpr):
|
|
|
|
|
return _most_recent_pjit_call_executable.weak_key_dict.get(jaxpr, None)
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
|
|
|
|
|
2024-05-29 01:49:06 -07:00
|
|
|
|
def _read_pgle_profiler(jaxpr):
|
2024-08-19 15:10:00 -07:00
|
|
|
|
return _most_recent_pjit_call_executable.weak_pgle_profiler_dict.get(jaxpr, None)
|
2024-05-29 01:49:06 -07:00
|
|
|
|
|
2023-02-14 18:45:31 -08:00
|
|
|
|
def _cpp_pjit_evict_fn(self):
  """Clears the caches associated with a C++-pjitted function.

  This clears the object's own cache via ``_clear_cache()`` and evicts the
  wrapped ``self._fun`` from ``_create_pjit_jaxpr``. It then calls
  ``_infer_params_cached.cache_clear()``, which takes no key and so presumably
  empties that whole cache, not just this function's entries (confirm).
  """
  self._clear_cache()
  wrapped_fun = self._fun
  _create_pjit_jaxpr.evict_function(wrapped_fun)  # pytype: disable=attribute-error
  _infer_params_cached.cache_clear()
|
2023-02-14 18:45:31 -08:00
|
|
|
|
|
|
|
|
|
|
2023-05-26 08:56:56 -07:00
|
|
|
|
# The entries are doubled here from the default 4096 because _pjit_call_impl
|
|
|
|
|
# also has a cpp dispatch path and that would double the number of entries in
|
|
|
|
|
# the global shared cache.
|
2024-09-17 16:10:41 -07:00
|
|
|
|
# This cache is only used for jit's with only fun. For example: jax.jit(f)
|
|
|
|
|
_cpp_pjit_cache_fun_only = xc._xla.PjitFunctionCache(capacity=8192)
|
2023-02-06 20:34:51 -08:00
|
|
|
|
|
2024-09-17 16:10:41 -07:00
|
|
|
|
# This cache is used for jit where extra arguments are defined other than the
|
|
|
|
|
# fun. For example: jax.jit(f, donate_argnums=...) OR
|
|
|
|
|
# jax.jit(f, out_shardings=...), etc. We don't use the same cache because the
|
|
|
|
|
# capacity might get full very fast because of all the jitted function in JAX
|
|
|
|
|
# which might evict train_step for example.
|
|
|
|
|
_cpp_pjit_cache_explicit_attributes = xc._xla.PjitFunctionCache(capacity=8192)
|
2023-02-06 20:34:51 -08:00
|
|
|
|
|
2024-09-17 16:10:41 -07:00
|
|
|
|
|
|
|
|
|
if xla_extension_version < 286:
  def _get_cpp_global_cache(pjit_has_explicit_sharding):
    """Returns the C++ cache to use for a jit on older jaxlib versions.

    Jits without explicit sharding share `_cpp_pjit_cache_fun_only`; every
    call with explicit sharding gets a brand-new `PjitFunctionCache`.
    """
    if not pjit_has_explicit_sharding:
      return _cpp_pjit_cache_fun_only
    return xc._xla.PjitFunctionCache()

  def _pjit_explicit_sharding_and_layout(
      in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat,
      device, backend) -> bool:
    """True if a device, backend, sharding or layout was given explicitly."""
    if device is not None or backend is not None:
      return True
    if any(not is_unspecified(s)
           for group in (in_shardings_flat, out_shardings_flat)
           for s in group):
      return True
    return any(layout is not None
               for group in (in_layouts_flat, out_layouts_flat)
               for layout in group)
else:
  def _get_cpp_global_cache(contains_explicit_attributes: bool):  # type: ignore
    """Picks the shared C++ cache depending on explicit jit attributes."""
    return (_cpp_pjit_cache_explicit_attributes if contains_explicit_attributes
            else _cpp_pjit_cache_fun_only)
|
2023-05-30 19:51:06 -07:00
|
|
|
|
|
|
|
|
|
|
2024-06-21 13:52:19 -07:00
|
|
|
|
def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
  """Wraps ``fun`` in the C++ pjit dispatcher built by ``xc._xla.pjit``.

  The dispatcher is given ``cache_miss`` as its Python fallback.
  ``cache_miss`` runs ``_python_pjit_helper`` and returns a 3-tuple: the
  outputs, the fast-path data (or None) from ``_get_fastpath_data``, and
  whether ``_need_to_rebuild_with_fdo`` holds for the PGLE profiler.

  Returns:
    The dispatcher wrapped with ``wraps(fun)``, with ``_fun`` set to ``fun``
    and ``clear_cache`` installed on its type.
  """

  @api_boundary
  def cache_miss(*args, **kwargs):
    if config.no_tracing.value:
      raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for "
                         "`jit`, but 'no_tracing' is set")
    (outs, out_flat, out_tree, args_flat, jaxpr,
     attrs_tracked) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
    executable = _read_most_recent_pjit_call_executable(jaxpr)
    pgle_profiler = _read_pgle_profiler(jaxpr)
    fastpath_data = _get_fastpath_data(
        executable, out_tree, args_flat, out_flat, attrs_tracked,
        jaxpr.effects, jaxpr.consts, jit_info.abstracted_axes, pgle_profiler)
    return outs, fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)

  if xla_extension_version >= 286:
    # Newer jaxlib: every jit attribute goes into a global cache key.
    cache_key = pxla.JitGlobalCppCacheKeys(
        donate_argnums=jit_info.donate_argnums,
        donate_argnames=jit_info.donate_argnames,
        device=jit_info.device, backend=jit_info.backend,
        in_shardings_treedef=jit_info.in_shardings_treedef,
        in_shardings_leaves=jit_info.in_shardings_leaves,
        out_shardings_treedef=jit_info.out_shardings_treedef,
        out_shardings_leaves=jit_info.out_shardings_leaves,
        in_layouts_treedef=jit_info.in_layouts_treedef,
        in_layouts_leaves=jit_info.in_layouts_leaves,
        out_layouts_treedef=jit_info.out_layouts_treedef,
        out_layouts_leaves=jit_info.out_layouts_leaves,
        use_resource_env=jit_info.use_resource_env)
    cpp_pjit_f = xc._xla.pjit(
        fun_name(fun), fun, cache_miss, jit_info.static_argnums,
        jit_info.static_argnames, cache_key, tree_util.dispatch_registry,  # type: ignore
        pxla.cc_shard_arg,
        _get_cpp_global_cache(cache_key.contains_explicit_attributes))
  else:
    # Older jaxlib: the cache is chosen by whether anything was explicit.
    explicit = _pjit_explicit_sharding_and_layout(
        jit_info.in_shardings_leaves, jit_info.out_shardings_leaves,
        jit_info.in_layouts_leaves, jit_info.out_layouts_leaves,
        jit_info.device, jit_info.backend)
    cpp_pjit_f = xc._xla.pjit(
        fun_name(fun), fun, cache_miss, jit_info.static_argnums,
        jit_info.static_argnames, jit_info.donate_argnums,
        tree_util.dispatch_registry, pxla.cc_shard_arg,
        _get_cpp_global_cache(explicit))

  wrapped = wraps(fun)(cpp_pjit_f)
  wrapped._fun = fun
  # Installed on the type, so every object of that type shares `clear_cache`.
  type(wrapped).clear_cache = _cpp_pjit_evict_fn
  return wrapped
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
|
|
|
|
|
2024-04-05 20:08:48 -07:00
|
|
|
|
def _split_layout_and_sharding(entries):
  """Separates layout information from sharding information in a pytree.

  ``entries`` is flattened with ``None`` treated as a leaf. Each leaf is
  classified as follows:

  * a ``Layout`` contributes its ``device_local_layout`` to the layout tree
    and its ``sharding`` to the sharding tree;
  * a bare ``DeviceLocalLayout`` or ``AutoLayout`` is rejected;
  * anything else contributes ``None`` to the layout tree and itself to the
    sharding tree.

  Returns:
    A ``(layouts, shardings)`` pair of pytrees, both rebuilt with the tree
    structure of ``entries``.

  Raises:
    ValueError: if a leaf is a ``DeviceLocalLayout`` or ``AutoLayout`` that is
      not wrapped in a ``Layout``.
  """
  leaves, structure = tree_flatten(entries, is_leaf=lambda leaf: leaf is None)

  def split_leaf(leaf):
    # Returns the (layout, sharding) contribution of a single leaf.
    if isinstance(leaf, Layout):
      return leaf.device_local_layout, leaf.sharding
    if isinstance(leaf, (DeviceLocalLayout, AutoLayout)):
      # A device-local layout carries no sharding, so jit requires it to be
      # wrapped in a `Layout` first.
      raise ValueError(
          '`jax.jit` does not accept device-local layouts directly. Create '
          'a `Layout` instance wrapping this device-local layout and pass '
          f'that to `jit` instead. Got {leaf}')
    return None, leaf

  split_pairs = [split_leaf(leaf) for leaf in leaves]
  layout_leaves = [layout for layout, _ in split_pairs]
  sharding_leaves = [sharding for _, sharding in split_pairs]
  assert len(layout_leaves) == len(sharding_leaves)
  return (tree_unflatten(structure, layout_leaves),
          tree_unflatten(structure, sharding_leaves))
|
|
|
|
|
|
|
|
|
|
|
2024-03-21 05:35:44 -07:00
|
|
|
|
def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
                         donate_argnums: int | Sequence[int] | None,
                         donate_argnames: str | Iterable[str] | None,
                         static_argnums: int | Sequence[int] | None,
                         static_argnames: str | Iterable[str] | None,
                         device: xc.Device | None, backend: str | None,
                         abstracted_axes: Any | None, keep_unused: bool,
                         inline: bool, use_resource_env: bool) -> PjitInfo:
  """Validates and preprocesses the arguments given to jit/pjit.

  All checks and transformations that do not depend on the actual call
  arguments happen here, once, when the function is wrapped. The results are
  packaged into a ``PjitInfo``.

  Raises:
    ValueError: if ``abstracted_axes`` is used without dynamic shapes, if both
      ``device`` and ``backend`` are given, or if explicit ``in_shardings`` /
      ``out_shardings`` are combined with ``device`` or ``backend``.
  """

  def explicitly_given(shardings):
    # True when the caller supplied shardings, i.e. neither None nor the
    # UNSPECIFIED sentinel.
    return shardings is not None and not is_unspecified(shardings)

  if abstracted_axes and not config.dynamic_shapes.value:
    raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")

  check_callable(fun)

  if not (backend is None and device is None):
    warnings.warn(
        'backend and device argument on jit is deprecated. You can use'
        ' `jax.device_put(..., jax.local_devices("cpu")[0])` on the inputs to'
        ' the jitted function to get the same behavior.', DeprecationWarning)
    if device is not None and backend is not None:
      raise ValueError("can't specify both a device and a backend for jit, "
                       f"got {device=} and {backend=}")
    if explicitly_given(in_shardings):
      raise ValueError('If backend or device is specified on jit, then '
                       'in_shardings should not be specified.')
    if explicitly_given(out_shardings):
      raise ValueError('If backend or device is specified on jit, then '
                       'out_shardings should not be specified.')

  if isinstance(in_shardings, list):
    # Positional arguments arrive as a tuple, so a non-leaf in_shardings can
    # only be a tree prefix of them if it is a tuple. Users routinely treat
    # lists and tuples interchangeably, so a list is converted here instead of
    # raising. See https://github.com/jax-ml/jax/issues/2367
    in_shardings = tuple(in_shardings)

  in_layouts, in_shardings = _split_layout_and_sharding(in_shardings)
  out_layouts, out_shardings = _split_layout_and_sharding(out_shardings)

  in_shardings = prepare_axis_resources(in_shardings, 'in_shardings')
  out_shardings = prepare_axis_resources(out_shardings, 'out_shardings')

  user_specified_in_shardings = explicitly_given(in_shardings)

  flat_in_shardings, in_shardings_tree = none_lr.flatten(in_shardings)
  flat_out_shardings, out_shardings_tree = none_lr.flatten(out_shardings)
  flat_in_layouts, in_layouts_tree = none_lr.flatten(in_layouts)
  flat_out_layouts, out_layouts_tree = none_lr.flatten(out_layouts)

  fun_sourceinfo = api_util.fun_sourceinfo(fun)
  fun_signature = api_util.fun_signature(fun)

  (donate_argnums, donate_argnames,
   static_argnums, static_argnames) = resolve_argnums(
       fun, fun_signature, donate_argnums, donate_argnames, static_argnums,
       static_argnames)

  return PjitInfo(
      fun_sourceinfo=fun_sourceinfo,
      fun_signature=fun_signature,
      user_specified_in_shardings=user_specified_in_shardings,
      in_shardings_treedef=in_shardings_tree,
      in_shardings_leaves=tuple(flat_in_shardings),
      out_shardings_treedef=out_shardings_tree,
      out_shardings_leaves=tuple(flat_out_shardings),
      in_layouts_treedef=in_layouts_tree,
      in_layouts_leaves=tuple(flat_in_layouts),
      out_layouts_treedef=out_layouts_tree,
      out_layouts_leaves=tuple(flat_out_layouts),
      static_argnums=static_argnums,
      static_argnames=static_argnames,
      donate_argnums=donate_argnums,
      donate_argnames=donate_argnames,
      device=device,
      backend=backend,
      keep_unused=keep_unused,
      inline=inline,
      abstracted_axes=abstracted_axes,
      use_resource_env=use_resource_env)
|
|
|
|
|
|
|
|
|
|
|
2024-06-21 13:52:19 -07:00
|
|
|
|
def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):
  """Builds the jitted callable for ``fun`` and attaches its AOT entry points.

  The callable itself comes from ``_cpp_pjit``. Three extra attributes are
  set on it: ``trace``, ``lower`` and ``eval_shape``. Each one infers pjit
  parameters for its call arguments via ``_infer_params``.
  """

  @api_boundary
  def trace(*args, **kwargs) -> stages.Traced:
    pjit_p, args_flat = _infer_params(fun, jit_info, args, kwargs)
    donated_positions = tuple(
        idx for idx, is_donated in enumerate(pjit_p.donated_invars)
        if is_donated)
    args_info = stages.make_args_info(
        pjit_p.in_tree, pjit_p.in_avals, donated_positions)
    # The lowering step is packaged as a callable for stages.Traced rather
    # than run here.
    lower_callable = partial(_resolve_and_lower, args_flat, **pjit_p.params,
                             pgle_profiler=None)
    return stages.Traced(
        pjit_p.params['jaxpr'], args_info, pjit_p.params["name"],
        pjit_p.out_tree, lower_callable, args_flat, pjit_p.arg_names,
        pjit_p.num_consts)

  @api_boundary
  def lower(*args, **kwargs):
    return trace(*args, **kwargs).lower()

  @api_boundary
  def eval_shape(*args, **kwargs):
    pjit_p, _ = _infer_params(fun, jit_info, args, kwargs)
    # TODO(yashkatariya): Add `Layout` to SDS.
    shape_structs = []
    for aval, out_sharding in zip(pjit_p.params['jaxpr'].out_avals,
                                  pjit_p.params['out_shardings']):
      # Unspecified output shardings are reported as None.
      sharding = None if is_unspecified(out_sharding) else out_sharding
      shape_structs.append(
          api.ShapeDtypeStruct(aval.shape, aval.dtype, sharding=sharding,
                               weak_type=aval.weak_type))
    return tree_unflatten(pjit_p.out_tree, shape_structs)

  wrapped = _cpp_pjit(fun, jit_info)
  wrapped.lower, wrapped.eval_shape, wrapped.trace = lower, eval_shape, trace
  return wrapped
|
|
|
|
|
|
2024-04-05 20:08:48 -07:00
|
|
|
|
|
2024-03-21 05:35:44 -07:00
|
|
|
|
def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any,
             donate_argnums: int | Sequence[int] | None,
             donate_argnames: str | Iterable[str] | None,
             static_argnums: int | Sequence[int] | None,
             static_argnames: str | Iterable[str] | None,
             device: xc.Device | None, backend: str | None,
             abstracted_axes: Any | None, keep_unused: bool,
             inline: bool, use_resource_env: bool) -> Any:
  """Shared implementation behind both jit() and pjit().

  The arguments are validated and preprocessed by ``_parse_jit_arguments``.
  The resulting ``PjitInfo`` is turned into the jitted callable by
  ``_make_jit_wrapper``.
  """
  return _make_jit_wrapper(
      fun,
      _parse_jit_arguments(
          fun,
          in_shardings=in_shardings,
          out_shardings=out_shardings,
          donate_argnums=donate_argnums,
          donate_argnames=donate_argnames,
          static_argnums=static_argnums,
          static_argnames=static_argnames,
          device=device,
          backend=backend,
          abstracted_axes=abstracted_axes,
          keep_unused=keep_unused,
          inline=inline,
          use_resource_env=use_resource_env))
|
2024-03-21 05:35:44 -07:00
|
|
|
|
|
|
|
|
|
|
2024-06-20 09:57:41 -07:00
|
|
|
|
class PjitParams(NamedTuple):
  """Parameters inferred for one jit/pjit call.

  Built by ``_infer_params_impl``. ``_infer_params`` may store it in the
  ``_infer_params_cached`` cache for reuse.
  """
  consts: list[Any]  # Only jaxpr constants, we can't keep other arguments alive
  # Keyword parameters for the pjit primitive: jaxpr, in/out shardings,
  # in/out layouts, resource_env, donated_invars, name, keep_unused, inline.
  params: dict[str, Any]
  # Abstract values of the explicit (flattened, non-static) arguments.
  in_avals: tuple[core.AbstractValue, ...]
  # Tree structure of the flattened (dynamic args, dynamic kwargs) pair.
  in_tree: PyTreeDef
  # Tree structure of the traced function's outputs.
  out_tree: PyTreeDef
  # Per-input donation flags. Leading entries for implicit args, attribute
  # state inputs and consts are always False.
  donated_invars: tuple[bool, ...]
  # Argument names taken from the debug info, or None when there is none.
  arg_names: tuple[str, ...] | None
  # Number of jaxpr constants; equal to len(consts).
  num_consts: int
  # Attribute state tracked while building the jaxpr; each entry's first
  # element provides `num_leaves` for counting extra state inputs.
  attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]
|
|
|
|
|
|
2023-01-12 17:23:55 -08:00
|
|
|
|
|
2024-06-21 13:52:19 -07:00
|
|
|
|
def _infer_params_impl(
    fun: Callable,
    ji: PjitInfo,
    pjit_mesh: mesh_lib.Mesh | None,
    resource_env: mesh_lib.ResourceEnv | None,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
    in_avals: tuple[core.AbstractValue, ...] | None,
) -> tuple[PjitParams, list[Any]]:
  """Traces ``fun`` for one call and assembles the pjit primitive parameters.

  This is the uncached worker behind ``_infer_params``. It performs these
  steps:

  * splits off static arguments and flattens the dynamic ones;
  * resolves shardings and layouts;
  * traces ``fun`` to a jaxpr;
  * packs everything into a ``PjitParams``.

  Args:
    fun: the user function being jitted.
    ji: preprocessed jit configuration produced by ``_parse_jit_arguments``.
    pjit_mesh: mesh passed to ``_create_sharding_for_array``. A non-None value
      also makes the call report itself as ``'pjit'`` in its debug info.
    resource_env: stored unchanged in ``params['resource_env']``.
    args: positional call arguments, static ones included.
    kwargs: keyword call arguments, static ones included.
    in_avals: abstract values of the flattened dynamic arguments, if the
      caller already computed them; otherwise None. When dynamic shapes are
      enabled, this value is recomputed from ``in_type``.

  Returns:
    A pair ``(pjit_params, args_flat)``. ``args_flat`` holds any implicit
    (dynamic-shape) arguments followed by the flattened dynamic arguments.
    Jaxpr constants are not in ``args_flat``; they are returned in
    ``pjit_params.consts``.

  Raises:
    ValueError: if keyword arguments are passed while ``in_shardings`` was
      specified, or if a non-empty mesh context is combined with jit's
      ``backend``/``device`` arguments.
    OverflowError: if abstractifying a dynamic argument overflows.
  """
  have_kwargs = bool(kwargs)
  if have_kwargs and ji.user_specified_in_shardings:
    raise ValueError(
        "pjit does not support kwargs when in_shardings is specified.")

  if pjit_mesh is not None:
    jit_name = 'pjit'
    if (ji.backend or ji.device) and not pjit_mesh.empty:
      raise ValueError(
          "Mesh context manager should not be used with jit when backend or "
          "device is also specified as an argument to jit.")
  else:
    jit_name = 'jit'

  axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs)

  dbg = debug_info(jit_name, ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
                   ji.static_argnums, ji.static_argnames)
  f = lu.wrap_init(fun)
  f, res_paths = result_paths(f)
  f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)
  # From here on only the dynamic positional args (dyn_args) are used.
  del args

  f, dyn_kwargs = argnames_partial_except(f, ji.static_argnames, kwargs)
  explicit_args, in_tree = tree_flatten((dyn_args, dyn_kwargs))
  flat_fun, out_tree = flatten_fun(f, in_tree)
  # hoist_obj_attrs returns a (possibly rewritten) function and flat args.
  flat_fun, explicit_args = hoist_obj_attrs(flat_fun, explicit_args)

  # Nothing is donated while debug_nans is enabled.
  if (ji.donate_argnums or ji.donate_argnames) and not config.debug_nans.value:
    donated_invars = donation_vector(ji.donate_argnums, ji.donate_argnames, in_tree)
  else:
    donated_invars = (False,) * len(explicit_args)

  # If backend or device is set as an arg on jit, then resolve them to
  # in_shardings and out_shardings as if user passed in in_shardings
  # and out_shardings.
  device_or_backend_set = bool(ji.backend or ji.device)
  if device_or_backend_set:
    sharding = _create_sharding_with_device_backend(ji.device, ji.backend)
    leaves, treedef = tree_flatten(sharding)
    in_shardings_leaves = out_shardings_leaves = tuple(leaves)
    in_shardings_treedef = out_shardings_treedef = treedef
  else:
    in_shardings_leaves = tuple(
        _create_sharding_for_array(pjit_mesh, x, 'in_shardings', jit_name)
        for x in ji.in_shardings_leaves)
    in_shardings_treedef = ji.in_shardings_treedef
    out_shardings_leaves = tuple(
        _create_sharding_for_array(pjit_mesh, x, 'out_shardings', jit_name)
        for x in ji.out_shardings_leaves)
    out_shardings_treedef = ji.out_shardings_treedef

  # Both paths above must yield non-None sharding leaves.
  assert None not in in_shardings_leaves
  assert None not in out_shardings_leaves

  # Under dynamic shapes, in_type is a full core.InputType of
  # (aval, explicit) pairs covering implicit and explicit inputs.
  # Otherwise it is just the tuple of explicit avals.
  in_type: core.InputType | tuple[core.AbstractValue, ...]
  if config.dynamic_shapes.value:
    in_type = pe.infer_lambda_input_type(axes_specs, explicit_args)
    in_avals = tuple(a for a, e in in_type if e)
  elif in_avals is None:
    avals = []
    for i, a in enumerate(explicit_args):
      try:
        avals.append(shaped_abstractify(a))
      except OverflowError as e:
        # Re-raise with the offending argument identified.
        arg_path = (f"argument path is {dbg.arg_names[i]}" if dbg
                    else f"flattened argument number is {i}")
        raise OverflowError(
          "An overflow was encountered while parsing an argument to a jitted "
          f"computation, whose {arg_path}."
        ) from e
    in_type = in_avals = tuple(avals)
  else:
    in_type = in_avals

  in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
      in_shardings_treedef, in_shardings_leaves,
      ji.in_layouts_treedef, ji.in_layouts_leaves,
      in_avals, in_tree, dbg, device_or_backend_set, have_kwargs)

  attr_token = _attr_token(flat_fun, in_type)
  jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
      flat_fun, in_type, attr_token, dbg,
      HashableFunction(res_paths, closure=()),
      IgnoreKey(ji.inline))
  _attr_update(flat_fun, in_type, attr_token, attrs_tracked)

  out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
      out_shardings_treedef, out_shardings_leaves, ji.out_layouts_treedef,
      ji.out_layouts_leaves, HashableFunction(out_tree, closure=()),
      tuple(out_avals), jaxpr.jaxpr.debug_info, device_or_backend_set)

  assert len(explicit_args) == len(in_shardings_flat) == len(in_layouts_flat)

  if config.dynamic_shapes.value:
    implicit_args = _extract_implicit_args(
      cast(core.InputType, in_type), explicit_args)
  else:
    implicit_args = []
  args_flat = [*implicit_args, *explicit_args]

  # Implicit args, attribute-state inputs and consts are prepended to the
  # explicit inputs. They get UNSPECIFIED shardings and no layout, and they
  # are never donated.
  num_states_in = sum(init_tree.num_leaves for init_tree, _, _ in attrs_tracked)
  num_extra_args = len(implicit_args) + num_states_in + len(consts)
  in_shardings_flat = (UNSPECIFIED,) * num_extra_args + in_shardings_flat
  in_layouts_flat = (None,) * num_extra_args + in_layouts_flat
  donated_invars = (False,) * num_extra_args + donated_invars
  assert (len(in_shardings_flat) == len(in_layouts_flat) ==
          len(donated_invars) == num_states_in + len(consts) + len(args_flat))

  params = dict(
      jaxpr=jaxpr,
      in_shardings=in_shardings_flat,
      out_shardings=out_shardings_flat,
      in_layouts=in_layouts_flat,
      out_layouts=out_layouts_flat,
      resource_env=resource_env,
      donated_invars=donated_invars,
      name=fun_qual_name(flat_fun),
      keep_unused=ji.keep_unused,
      inline=ji.inline,
  )
  # out_tree is called here, after tracing has run, to obtain the treedef.
  return PjitParams(consts, params, in_avals, in_tree, out_tree(),
                    donated_invars, dbg.arg_names if dbg else None, len(consts),
                    attrs_tracked), args_flat
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InferParamsCacheEntry:
  """Mutable holder returned (and memoized) by ``_infer_params_cached``.

  A new entry starts out empty. Its ``pjit_params`` slot is filled in the
  first time ``_infer_params`` computes parameters for the entry's cache key.
  """
  __slots__ = ('pjit_params',)

  # Cached PjitParams for this key, or None while not yet populated.
  pjit_params: PjitParams | None

  def __init__(self):
    # Every entry begins unpopulated.
    self.pjit_params = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# The cache below is keyed only on the call signature: the function, its
# PjitInfo, the argument signature, the abstract values, and the mesh and
# resource env. Populating an entry with _infer_params_impl, however, needs
# the concrete args/kwargs. So the cache hands back an empty mutable entry,
# and the caller fills it in. Teaching _infer_params_impl to work from the
# signature alone would avoid this indirection, but would be a larger change.
@util.weakref_lru_cache
def _infer_params_cached(
    fun: Callable,
    jit_info: PjitInfo,
    signature: jax_jit.ArgumentSignature,
    in_avals: tuple[core.AbstractValue, ...],
    pjit_mesh: mesh_lib.Mesh | None,
    resource_env: mesh_lib.ResourceEnv | None,
) -> InferParamsCacheEntry:
  """Returns the memoized cache entry for this call signature.

  The arguments are not inspected here; they only form the cache key.
  """
  fresh_entry = InferParamsCacheEntry()
  return fresh_entry
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _infer_params(
    fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[PjitParams, list[Any]]:
  """Computes ``PjitParams`` for one call, reusing cached results when possible.

  The cached path is skipped, and ``_infer_params_impl`` runs directly, when
  either of these holds:

  * dynamic shapes are enabled;
  * abstractifying an argument raises ``OverflowError`` or ``TypeError``.

  Results with a non-empty ``attrs_tracked`` are never stored in the cache.

  Returns:
    A pair ``(pjit_params, flat_args)``, where ``flat_args`` is the jaxpr
    constants followed by the flattened dynamic arguments.
  """
  resource_env = pjit_mesh = None
  if ji.use_resource_env:
    # Meshes are dynamically scoped (entered with a context manager), so the
    # mesh must be read at call time, from inside the wrapped function.
    resource_env = mesh_lib.thread_resources.env
    pjit_mesh = resource_env.physical_mesh

  use_cache = not config.dynamic_shapes.value
  if use_cache:
    signature, dynargs = jax_jit.parse_arguments(
        args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
        ji.static_argnames, tree_util.default_registry)
    try:
      avals = tuple(shaped_abstractify(a) for a in dynargs)
    except (OverflowError, TypeError):
      # An argument we cannot abstractify here: take the slow path below.
      use_cache = False

  if not use_cache:
    slow_p, slow_args = _infer_params_impl(
        fun, ji, pjit_mesh, resource_env, args, kwargs, in_avals=None)
    return slow_p, slow_p.consts + slow_args

  entry = _infer_params_cached(
      fun, ji, signature, avals, pjit_mesh, resource_env)
  if entry.pjit_params is None:
    fresh_p, fresh_args = _infer_params_impl(
        fun, ji, pjit_mesh, resource_env, args, kwargs, in_avals=avals)
    if fresh_p.attrs_tracked:
      # Results that track attrs bypass the cache entirely.
      return fresh_p, fresh_p.consts + fresh_args
    entry.pjit_params = fresh_p

  cached_p = entry.pjit_params
  return cached_p, cached_p.consts + dynargs
|
2024-06-20 09:57:41 -07:00
|
|
|
|
|
2023-01-12 17:23:55 -08:00
|
|
|
|
|
2023-03-22 20:54:45 -07:00
|
|
|
|
def _extract_implicit_args(
|
2023-06-23 15:11:37 -07:00
|
|
|
|
in_type: Sequence[tuple[core.AbstractValue, bool]],
|
2023-03-22 20:54:45 -07:00
|
|
|
|
explicit_args: Sequence[Any]
|
|
|
|
|
) -> Sequence[core.Tracer]:
|
|
|
|
|
"""
|
|
|
|
|
Given an input type and explicitly-passed arguments (per the user-facing API
|
|
|
|
|
calling convention), extract implicit axis size arguments from shapes of
|
|
|
|
|
explicit arguments (for the trace-time / jaxpr-level calling convention).
|
|
|
|
|
"""
|
|
|
|
|
# First, using `in_type` construct a list to represent the full argument list,
|
|
|
|
|
# leaving the implicit arguments as None placeholders for now.
|
|
|
|
|
explicit_args_ = iter(explicit_args)
|
|
|
|
|
args = [next(explicit_args_) if expl else None for _, expl in in_type]
|
|
|
|
|
assert next(explicit_args_, None) is None
|
|
|
|
|
del explicit_args, explicit_args_
|
|
|
|
|
|
|
|
|
|
# Next, populate the implicit arguments using the DBIdxs in `in_type`.
|
|
|
|
|
for i, (aval, explicit) in enumerate(in_type):
|
|
|
|
|
if not explicit or not isinstance(aval, core.DShapedArray):
|
|
|
|
|
continue # can't populate an implicit argument
|
|
|
|
|
arg = args[i]
|
|
|
|
|
assert arg is not None
|
|
|
|
|
for d1, d2 in zip(aval.shape, arg.aval.shape):
|
|
|
|
|
if isinstance(d1, core.DBIdx):
|
|
|
|
|
if args[d1.val] is None:
|
|
|
|
|
args[d1.val] = d2
|
|
|
|
|
assert core.same_referent(args[d1.val], d2)
|
|
|
|
|
assert all(x is not None for x in args)
|
2024-05-22 06:35:38 -07:00
|
|
|
|
return [x for x, (_, e) in zip(args, in_type) if not e] # pytype: disable=bad-return-type
|
2023-03-22 20:54:45 -07:00
|
|
|
|
|
|
|
|
|
def _flat_axes_specs(abstracted_axes, *args, **kwargs
|
2023-12-11 13:59:29 +00:00
|
|
|
|
) -> list[pe.AbstractedAxesSpec] | None:
|
2023-03-22 20:54:45 -07:00
|
|
|
|
if abstracted_axes is None: return None
|
|
|
|
|
if kwargs: raise NotImplementedError
|
|
|
|
|
def ax_leaf(l):
|
|
|
|
|
return (isinstance(l, dict) and all_leaves(l.values()) or
|
|
|
|
|
isinstance(l, tuple) and all_leaves(l, lambda x: x is None))
|
|
|
|
|
return broadcast_prefix(abstracted_axes, args, ax_leaf)
|
|
|
|
|
|
2023-01-12 17:23:55 -08:00
|
|
|
|
|
2024-01-18 22:10:24 -08:00
|
|
|
|
class JitWrapped(stages.Wrapped):
  """Static type used to annotate the callable returned by ``pjit``.

  Extends ``stages.Wrapped`` with the ``trace`` and ``eval_shape`` entry
  points. The methods here are placeholders that only raise
  ``NotImplementedError``. ``_make_jit_wrapper`` attaches the working versions
  to the actual jitted function.
  """

  def trace(self, *args, **kwargs) -> stages.Traced:
    """Placeholder for tracing the wrapped function on the given arguments."""
    raise NotImplementedError

  def eval_shape(self, *args, **kwargs):
    """See ``jax.eval_shape``."""
    raise NotImplementedError
|
|
|
|
|
|
2024-01-18 22:10:24 -08:00
|
|
|
|
|
2023-02-11 15:29:38 -08:00
|
|
|
|
# in_shardings and out_shardings can't be None as the default value
|
2022-12-16 13:06:38 -08:00
|
|
|
|
# because `None` means that the input is fully replicated.
|
|
|
|
|
def pjit(
|
|
|
|
|
fun: Callable,
|
2023-04-10 10:15:08 -07:00
|
|
|
|
in_shardings=UNSPECIFIED,
|
|
|
|
|
out_shardings=UNSPECIFIED,
|
2023-12-11 13:59:29 +00:00
|
|
|
|
static_argnums: int | Sequence[int] | None = None,
|
|
|
|
|
static_argnames: str | Iterable[str] | None = None,
|
|
|
|
|
donate_argnums: int | Sequence[int] | None = None,
|
|
|
|
|
donate_argnames: str | Iterable[str] | None = None,
|
2022-12-16 13:06:38 -08:00
|
|
|
|
keep_unused: bool = False,
|
2023-12-11 13:59:29 +00:00
|
|
|
|
device: xc.Device | None = None,
|
|
|
|
|
backend: str | None = None,
|
2022-12-16 13:06:38 -08:00
|
|
|
|
inline: bool = False,
|
2023-12-11 13:59:29 +00:00
|
|
|
|
abstracted_axes: Any | None = None,
|
2024-01-18 22:10:24 -08:00
|
|
|
|
) -> JitWrapped:
|
2022-12-16 13:06:38 -08:00
|
|
|
|
"""Makes ``fun`` compiled and automatically partitioned across multiple devices.
|
|
|
|
|
|
2023-05-25 10:13:50 -07:00
|
|
|
|
NOTE: This function is now equivalent to jax.jit please use that instead.
|
2022-12-16 13:06:38 -08:00
|
|
|
|
The returned function has semantics equivalent to those of ``fun``, but is
|
|
|
|
|
compiled to an XLA computation that runs across multiple devices
|
|
|
|
|
(e.g. multiple GPUs or multiple TPU cores). This can be useful if the jitted
|
|
|
|
|
version of ``fun`` would not fit in a single device's memory, or to speed up
|
|
|
|
|
``fun`` by running each operation in parallel across multiple devices.
|
|
|
|
|
|
|
|
|
|
The partitioning over devices happens automatically based on the
|
2023-02-11 15:29:38 -08:00
|
|
|
|
propagation of the input partitioning specified in ``in_shardings`` and
|
|
|
|
|
the output partitioning specified in ``out_shardings``. The resources
|
2022-12-16 13:06:38 -08:00
|
|
|
|
specified in those two arguments must refer to mesh axes, as defined by
|
2023-02-03 14:28:07 -08:00
|
|
|
|
the :py:func:`jax.sharding.Mesh` context manager. Note that the mesh
|
2022-12-16 13:06:38 -08:00
|
|
|
|
definition at :func:`~pjit` application time is ignored, and the returned function
|
|
|
|
|
will use the mesh definition available at each call site.
|
|
|
|
|
|
|
|
|
|
Inputs to a :func:`~pjit`'d function will be automatically partitioned across devices
|
2023-02-11 15:29:38 -08:00
|
|
|
|
if they're not already correctly partitioned based on ``in_shardings``.
|
2022-12-16 13:06:38 -08:00
|
|
|
|
In some scenarios, ensuring that the inputs are already correctly pre-partitioned
|
|
|
|
|
can increase performance. For example, if passing the output of one
|
|
|
|
|
:func:`~pjit`'d function to another :func:`~pjit`’d function (or the same
|
|
|
|
|
:func:`~pjit`’d function in a loop), make sure the relevant
|
2023-02-11 15:29:38 -08:00
|
|
|
|
``out_shardings`` match the corresponding ``in_shardings``.
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
|
|
|
|
.. note::
|
|
|
|
|
**Multi-process platforms:** On multi-process platforms such as TPU pods,
|
|
|
|
|
:func:`~pjit` can be used to run computations across all available devices across
|
|
|
|
|
processes. To achieve this, :func:`~pjit` is designed to be used in SPMD Python
|
|
|
|
|
programs, where every process is running the same Python code such that all
|
|
|
|
|
processes run the same :func:`~pjit`'d function in the same order.
|
|
|
|
|
|
|
|
|
|
When running in this configuration, the mesh should contain devices across
|
2024-05-09 08:37:43 -07:00
|
|
|
|
all processes. All inputs arguments must be globally shaped.
|
|
|
|
|
``fun`` will still be executed across *all* devices in the mesh,
|
2022-12-16 13:06:38 -08:00
|
|
|
|
including those from other processes, and will be given a global view of the
|
2024-05-09 08:37:43 -07:00
|
|
|
|
data spread across multiple processes as a single array.
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
|
|
|
|
The SPMD model also requires that the same multi-process :func:`~pjit`'d
|
|
|
|
|
functions must be run in the same order on all processes, but they can be
|
|
|
|
|
interspersed with arbitrary operations running in a single process.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
fun: Function to be compiled. Should be a pure function, as side-effects may
|
|
|
|
|
only be executed once. Its arguments and return value should be arrays,
|
|
|
|
|
scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
|
|
|
|
|
Positional arguments indicated by ``static_argnums`` can be anything at
|
|
|
|
|
all, provided they are hashable and have an equality operation defined.
|
|
|
|
|
Static arguments are included as part of a compilation cache key, which is
|
|
|
|
|
why hash and equality operators must be defined.
|
2023-02-11 15:29:38 -08:00
|
|
|
|
in_shardings: Pytree of structure matching that of arguments to ``fun``,
|
2022-12-16 13:06:38 -08:00
|
|
|
|
with all actual arguments replaced by resource assignment specifications.
|
|
|
|
|
It is also valid to specify a pytree prefix (e.g. one value in place of a
|
|
|
|
|
whole subtree), in which case the leaves get broadcast to all values in
|
|
|
|
|
that subtree.
|
|
|
|
|
|
2023-04-06 10:49:57 -07:00
|
|
|
|
The ``in_shardings`` argument is optional. JAX will infer the shardings
|
|
|
|
|
from the input :py:class:`jax.Array`'s, and defaults to replicating the input
|
|
|
|
|
if the sharding cannot be inferred.
|
|
|
|
|
|
2022-12-16 13:06:38 -08:00
|
|
|
|
The valid resource assignment specifications are:
|
2023-04-07 09:35:51 -07:00
|
|
|
|
|
2024-06-05 09:06:36 -07:00
|
|
|
|
- :py:class:`Sharding`, which will decide how the value
|
2023-04-07 09:35:51 -07:00
|
|
|
|
will be partitioned. With this, using a mesh context manager is not
|
|
|
|
|
required.
|
2023-06-15 15:21:36 -07:00
|
|
|
|
- :py:obj:`None` is a special case whose semantics are:
|
2023-06-16 13:14:38 -07:00
|
|
|
|
- if the mesh context manager is *not* provided, JAX has the freedom to
|
|
|
|
|
choose whatever sharding it wants.
|
|
|
|
|
For in_shardings, JAX will mark is as replicated but this behavior
|
|
|
|
|
can change in the future.
|
|
|
|
|
For out_shardings, we will rely on the XLA GSPMD partitioner to
|
|
|
|
|
determine the output shardings.
|
|
|
|
|
- If the mesh context manager is provided, None will imply that the
|
|
|
|
|
value will be replicated on all devices of the mesh.
|
2023-04-07 09:35:51 -07:00
|
|
|
|
- For backwards compatibility, in_shardings still supports ingesting
|
2023-06-15 15:21:36 -07:00
|
|
|
|
:py:class:`PartitionSpec`. This option can *only* be used with the
|
|
|
|
|
mesh context manager.
|
2023-06-16 13:14:38 -07:00
|
|
|
|
|
2023-04-07 09:35:51 -07:00
|
|
|
|
- :py:class:`PartitionSpec`, a tuple of length at most equal to the rank
|
|
|
|
|
of the partitioned value. Each element can be a :py:obj:`None`, a mesh
|
|
|
|
|
axis or a tuple of mesh axes, and specifies the set of resources assigned
|
|
|
|
|
to partition the value's dimension matching its position in the spec.
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
|
|
|
|
The size of every dimension has to be a multiple of the total number of
|
|
|
|
|
resources assigned to it.
|
2023-02-11 15:29:38 -08:00
|
|
|
|
out_shardings: Like ``in_shardings``, but specifies resource
|
2022-12-16 13:06:38 -08:00
|
|
|
|
assignment for function outputs.
|
2023-04-07 09:35:51 -07:00
|
|
|
|
The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit`
|
|
|
|
|
will use GSPMD's sharding propagation to determine how to shard the outputs.
|
2022-12-16 13:06:38 -08:00
|
|
|
|
static_argnums: An optional int or collection of ints that specify which
|
|
|
|
|
positional arguments to treat as static (compile-time constant).
|
|
|
|
|
Operations that only depend on static arguments will be constant-folded in
|
|
|
|
|
Python (during tracing), and so the corresponding argument values can be
|
|
|
|
|
any Python object.
|
|
|
|
|
|
|
|
|
|
Static arguments should be hashable, meaning both ``__hash__`` and
|
|
|
|
|
``__eq__`` are implemented, and immutable. Calling the jitted function
|
|
|
|
|
with different values for these constants will trigger recompilation.
|
|
|
|
|
Arguments that are not arrays or containers thereof must be marked as
|
|
|
|
|
static.
|
|
|
|
|
|
|
|
|
|
If ``static_argnums`` is not provided, no arguments are treated as static.
|
|
|
|
|
static_argnames: An optional string or collection of strings specifying
|
|
|
|
|
which named arguments to treat as static (compile-time constant). See the
|
|
|
|
|
comment on ``static_argnums`` for details. If not
|
|
|
|
|
provided but ``static_argnums`` is set, the default is based on calling
|
|
|
|
|
``inspect.signature(fun)`` to find corresponding named arguments.
|
2023-07-14 14:27:29 -07:00
|
|
|
|
donate_argnums: Specify which positional argument buffers are "donated" to
|
|
|
|
|
the computation. It is safe to donate argument buffers if you no longer
|
|
|
|
|
need them once the computation has finished. In some cases XLA can make
|
|
|
|
|
use of donated buffers to reduce the amount of memory needed to perform a
|
|
|
|
|
computation, for example recycling one of your input buffers to store a
|
|
|
|
|
result. You should not reuse buffers that you donate to a computation, JAX
|
|
|
|
|
will raise an error if you try to. By default, no argument buffers are
|
|
|
|
|
donated.
|
|
|
|
|
|
|
|
|
|
If neither ``donate_argnums`` nor ``donate_argnames`` is provided, no
|
|
|
|
|
arguments are donated. If ``donate_argnums`` is not provided but
|
|
|
|
|
``donate_argnames`` is, or vice versa, JAX uses
|
|
|
|
|
:code:`inspect.signature(fun)` to find any positional arguments that
|
|
|
|
|
correspond to ``donate_argnames``
|
|
|
|
|
(or vice versa). If both ``donate_argnums`` and ``donate_argnames`` are
|
|
|
|
|
provided, ``inspect.signature`` is not used, and only actual
|
|
|
|
|
parameters listed in either ``donate_argnums`` or ``donate_argnames`` will
|
|
|
|
|
be donated.
|
|
|
|
|
|
|
|
|
|
For more details on buffer donation see the
|
|
|
|
|
`FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
|
2023-07-12 15:09:18 -07:00
|
|
|
|
donate_argnames: An optional string or collection of strings specifying
|
|
|
|
|
which named arguments are donated to the computation. See the
|
|
|
|
|
comment on ``donate_argnums`` for details. If not
|
|
|
|
|
provided but ``donate_argnums`` is set, the default is based on calling
|
|
|
|
|
``inspect.signature(fun)`` to find corresponding named arguments.
|
2022-12-16 13:06:38 -08:00
|
|
|
|
keep_unused: If `False` (the default), arguments that JAX determines to be
|
|
|
|
|
unused by `fun` *may* be dropped from resulting compiled XLA executables.
|
|
|
|
|
Such arguments will not be transferred to the device nor provided to the
|
|
|
|
|
underlying executable. If `True`, unused arguments will not be pruned.
|
|
|
|
|
device: This argument is deprecated. Please put your arguments on the
|
|
|
|
|
device you want before passing them to jit.
|
|
|
|
|
Optional, the Device the jitted function will run on. (Available devices
|
|
|
|
|
can be retrieved via :py:func:`jax.devices`.) The default is inherited
|
|
|
|
|
from XLA's DeviceAssignment logic and is usually to use
|
|
|
|
|
``jax.devices()[0]``.
|
|
|
|
|
backend: This argument is deprecated. Please put your arguments on the
|
|
|
|
|
backend you want before passing them to jit.
|
|
|
|
|
Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
|
|
|
|
|
``'tpu'``.
|
2023-04-07 09:35:51 -07:00
|
|
|
|
|
2022-12-16 13:06:38 -08:00
|
|
|
|
Returns:
|
|
|
|
|
A wrapped version of ``fun``, set up for just-in-time compilation and
|
|
|
|
|
automatically partitioned by the mesh available at each call site.
|
|
|
|
|
|
|
|
|
|
For example, a convolution operator can be automatically partitioned over
|
|
|
|
|
an arbitrary set of devices by a single :func:`~pjit` application:
|
|
|
|
|
|
|
|
|
|
>>> import jax
|
|
|
|
|
>>> import jax.numpy as jnp
|
|
|
|
|
>>> import numpy as np
|
2022-12-22 08:40:36 -08:00
|
|
|
|
>>> from jax.sharding import Mesh, PartitionSpec
|
|
|
|
|
>>> from jax.experimental.pjit import pjit
|
2022-12-16 13:06:38 -08:00
|
|
|
|
>>>
|
|
|
|
|
>>> x = jnp.arange(8, dtype=jnp.float32)
|
|
|
|
|
>>> f = pjit(lambda x: jax.numpy.convolve(x, jnp.asarray([0.5, 1.0, 0.5]), 'same'),
|
2023-02-11 15:29:38 -08:00
|
|
|
|
... in_shardings=None, out_shardings=PartitionSpec('devices'))
|
2022-12-16 13:06:38 -08:00
|
|
|
|
>>> with Mesh(np.array(jax.devices()), ('devices',)):
|
|
|
|
|
... print(f(x)) # doctest: +SKIP
|
|
|
|
|
[ 0.5 2. 4. 6. 8. 10. 12. 10. ]
|
|
|
|
|
"""
|
2024-03-21 05:35:44 -07:00
|
|
|
|
return make_jit(
|
2023-07-12 15:09:18 -07:00
|
|
|
|
fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
|
2024-03-21 05:35:44 -07:00
|
|
|
|
static_argnums, static_argnames, device, backend, abstracted_axes,
|
|
|
|
|
keep_unused, inline, use_resource_env=True)
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def hashable_pytree(pytree):
  """Wraps ``pytree`` in a hashable zero-argument callable that rebuilds it.

  The pytree is flattened once. The returned ``HashableFunction`` closes over
  the tree structure and a tuple of the leaves, and those two values are
  passed as its ``closure`` (presumably what ``HashableFunction`` uses for
  hashing/equality -- see its definition).

  Args:
    pytree: any pytree.

  Returns:
    A ``HashableFunction`` that, called with no arguments, returns a pytree
    rebuilt from the captured structure and leaves.
  """
  leaves, structure = tree_flatten(pytree)
  frozen_leaves = tuple(leaves)

  def rebuild():
    return tree_unflatten(structure, frozen_leaves)

  return HashableFunction(rebuild, closure=(structure, frozen_leaves))
|
|
|
|
|
|
|
|
|
|
|
2023-05-03 19:28:54 -07:00
|
|
|
|
def _create_sharding_for_array(mesh, x, name, api_name):
  """Converts a user-provided sharding specification into a sharding object.

  Args:
    mesh: the mesh available at the call site, or ``None``.
    x: ``None``, an existing ``sharding.Sharding``, an unspecified/AUTO marker,
      or a ``ParsedPartitionSpec``.
    name: name of the argument being processed (e.g. ``'in_shardings'``);
      used in error messages.
    api_name: name of the calling API; used in error messages.

  Returns:
    ``UNSPECIFIED`` when ``x`` is ``None`` and there is no usable mesh; ``x``
    itself when it is already a ``Sharding`` or unspecified/AUTO; otherwise
    the result of ``pxla.create_mesh_pspec_sharding`` for ``mesh`` and ``x``.

  Raises:
    RuntimeError: if a mesh is required to interpret ``x`` but ``mesh`` is
      ``None`` or empty.
  """
  if x is None and (mesh is None or mesh.empty):
    return UNSPECIFIED

  already_usable = isinstance(x, sharding.Sharding) or is_unspecified_or_auto(x)
  if already_usable:
    return x

  if mesh is None:
    # Tailor a hint to the argument being processed.
    if name == 'in_shardings':
      hint = (f'Note that {name} argument is optional. JAX will infer the shardings'
              " from the input jax.Array's and will default to replicating the"
              ' input if the sharding cannot be inferred.')
    elif name == 'out_shardings':
      hint = (f'Note that {name} is optional. If not specified, jax.jit will'
              " use GSPMD's sharding propagation to figure out what the sharding"
              ' of the output(s) should be.')
    else:
      hint = ''
    raise RuntimeError(
        'jax.jit only supports `Sharding`s being passed to'
        f' {name}. Looks like you are passing either `PartitionSpec` or `None`'
        f' which is not allowed in jax.jit.\n' + hint)

  if mesh.empty:
    raise RuntimeError(
        f'{api_name} requires a non-empty mesh if you are passing'
        f' `PartitionSpec`s or `None` to {name}! Is a mesh defined at the call'
        f' site? Alternatively, provide `Sharding`s to {name} and'
        ' then the mesh context manager is not required.')

  # A nice user error is raised in prepare_axis_resources.
  assert x is None or isinstance(x, ParsedPartitionSpec), x
  if x is None:
    return pxla.create_mesh_pspec_sharding(mesh, x)
  return pxla.create_mesh_pspec_sharding(mesh, x.get_partition_spec(), x)
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_sharding_with_device_backend(device, backend):
|
|
|
|
|
if device is not None:
|
|
|
|
|
assert backend is None
|
|
|
|
|
out = SingleDeviceSharding(device)
|
|
|
|
|
elif backend is not None:
|
|
|
|
|
assert device is None
|
2023-09-08 09:17:53 -07:00
|
|
|
|
out = SingleDeviceSharding(xb.get_backend(backend).local_devices()[0])
|
2024-03-26 13:28:03 -07:00
|
|
|
|
else:
|
|
|
|
|
raise AssertionError('Unreachable!')
|
|
|
|
|
out._device_backend = True
|
2022-12-16 13:06:38 -08:00
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def flatten_axis_resources(what, tree, shardings, tupled_args):
  """Flattens ``shardings`` against the treedef ``tree``.

  ``shardings`` must be a tree prefix of the pytree described by ``tree``.
  On success the flattened specification is returned as a tuple. Otherwise a
  ``ValueError`` describing the mismatch is raised.

  Args:
    what: human-readable name of the specification (e.g.
      ``"pjit in_shardings"``); used in error messages.
    tree: treedef that ``shardings`` is matched against.
    shardings: the (possibly prefix) specification tree to flatten.
    tupled_args: whether ``tree`` describes a tuple of positional arguments.
      In that case ``shardings`` that is not a tuple of matching length gets a
      more specific error message.

  Returns:
    A tuple holding the flattened specification produced by ``flatten_axes``.

  Raises:
    ValueError: if ``shardings`` is not a valid tree prefix of ``tree``.
    AssertionError: if no explanation for the prefix mismatch could be found
      (this indicates a bug).
  """
  try:
    return tuple(flatten_axes(what, tree, shardings, tupled_args=tupled_args))
  except ValueError:
    pass  # Raise a tree prefix error below

  # Tree leaves are always valid prefixes, so if there was a prefix error as
  # assumed here, axis_resources must not be a leaf.
  assert not treedef_is_leaf(tree_structure(shardings))

  # Check the type directly rather than using isinstance because of namedtuples.
  if tupled_args and (type(shardings) is not tuple or
                      len(shardings) != len(tree.children())):
    # We know axis_resources is meant to be a tuple corresponding to the args
    # tuple, but while it is a non-leaf pytree, either it wasn't a tuple or it
    # wasn't the right length.
    msg = (f"{what} specification must be a tree prefix of the positional "
           f"arguments tuple passed to the `pjit`-decorated function. In "
           f"particular, {what} must either be a None, a PartitionSpec, or "
           f"a tuple of length equal to the number of positional arguments.")
    # If `tree` represents an args tuple, then `axis_resources` must be a tuple.
    # TODO(mattjj,apaszke): disable implicit list casts, remove 'or list' below
    if type(shardings) is not tuple:
      msg += f" But {what} is not a tuple: got {type(shardings)} instead."
    elif len(shardings) != len(tree.children()):
      msg += (f" But {what} is the wrong length: got a tuple or list of length "
              f"{len(shardings)} for an args tuple of length "
              f"{len(tree.children())}.")

    # As an extra hint, let's check if the user just forgot to wrap
    # shardings in a singleton tuple.
    if len(tree.children()) == 1:
      try:
        flatten_axes(what, tree, (shardings,))
      except ValueError:
        pass  # That's not the issue.
      else:
        msg += (f" Given the corresponding argument being "
                f"passed, it looks like {what} might need to be wrapped in "
                f"a singleton tuple.")

    raise ValueError(msg)

  axis_tree = shardings

  # Because we only have the `tree` treedef and not the full pytree here,
  # we construct a dummy tree to compare against. Revise this in callers?
  dummy_tree = tree_unflatten(tree, [PytreeLeaf()] * tree.num_leaves)
  errors = prefix_errors(axis_tree, dummy_tree)
  if errors:
    e = errors[0]  # Only show information about the first disagreement found.
    raise e(what)

  # At this point we've failed to find a tree prefix error. This should be
  # unreachable. Raise explicitly instead of `assert False`: asserts are
  # stripped under `python -O`, which would make this function silently
  # return None.
  raise AssertionError("Please open a bug report!")
|
|
|
|
|
|
|
|
|
|
class PytreeLeaf:
  """Placeholder leaf used to build dummy pytrees for prefix-error messages."""

  def __repr__(self):
    return "pytree leaf"
|
|
|
|
|
|
|
|
|
|
|
2024-06-11 12:46:11 -07:00
|
|
|
|
@util.cache(max_size=4096, trace_context_in_key=False)
def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
                               in_layouts_treedef, in_layouts_leaves,
                               in_avals, in_tree, debug_info,
                               device_or_backend_set, kws):
  """Flattens and validates ``in_shardings`` / ``in_layouts`` for a jit call.

  Both specifications arrive pre-flattened as (treedef, leaves) pairs. This is
  presumably so the arguments are hashable for ``util.cache``. They are
  unflattened here and then matched against ``in_tree``.

  Args:
    in_shardings_treedef: treedef of the user's ``in_shardings``.
    in_shardings_leaves: leaves of the user's ``in_shardings``.
    in_layouts_treedef: treedef of the user's ``in_layouts``.
    in_layouts_leaves: leaves of the user's ``in_layouts``.
    in_avals: abstract values of the flattened inputs.
    in_tree: input treedef; when ``kws`` is falsy it is replaced by its first
      child before matching.
    debug_info: debug info or ``None``; its ``arg_names`` feed error messages.
    device_or_backend_set: not read in this body (NOTE(review): presumably
      only part of the cache key -- confirm).
    kws: see ``in_tree``.

  Returns:
    ``(in_shardings_flat, in_layouts_flat)``.
  """
  if not kws:
    # Keep only the first child of the input tree.
    in_tree, _ = treedef_children(in_tree)

  user_shardings = tree_unflatten(in_shardings_treedef, in_shardings_leaves)
  if not is_unspecified(user_shardings):
    # Concrete shardings and AUTO both go via flatten_axis_resources.
    in_shardings_flat = flatten_axis_resources(
        "pjit in_shardings", in_tree, user_shardings, tupled_args=True)
  else:
    # An unspecified sharding is broadcast to every input aval.
    in_shardings_flat = (user_shardings,) * len(in_avals)

  user_layouts = tree_unflatten(in_layouts_treedef, in_layouts_leaves)
  if user_layouts is not None:
    in_layouts_flat = flatten_axis_resources(
        "pjit in_layouts", in_tree, user_layouts, tupled_args=True)
  else:
    in_layouts_flat = (user_layouts,) * len(in_avals)

  # TODO(dougalm,mattjj): enable debug info with attrs_tracked
  attrs_tracked = debug_info and len(debug_info.arg_names) != len(in_avals)
  run_checks = not (config.dynamic_shapes.value or attrs_tracked)
  if run_checks:
    arg_names = None if debug_info is None else debug_info.arg_names
    pjit_check_aval_sharding(in_shardings_flat, in_avals, arg_names,
                             "pjit arguments", allow_uneven_sharding=False)
    check_aval_layout_compatibility(in_layouts_flat, in_avals, arg_names,
                                    "jit arguments")
  return in_shardings_flat, in_layouts_flat
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
2023-06-09 14:43:42 -07:00
|
|
|
|
# Summarized call sites at which explain_tracing_cache_miss has already
# reported a never-seen function. Seeing the same call site again is used as a
# hint that a function is being re-defined repeatedly.
callsites: set[str] = set()


def explain_tracing_cache_miss(
    f: Callable, unseen_f: bool, cache: dict, key: tuple):
  """Logs a WARNING explaining why tracing ``f`` missed the tracing cache.

  Nothing is logged when ``config.check_tracer_leaks`` is enabled or when the
  missed entry is marked ``inline``. Otherwise the new ``key`` is compared
  against the keys already in ``cache``. The first applicable reason is
  reported, checked in this order: never-seen function, args/kwargs structure
  change, tracing-context change, input pytree change, input type change.

  Args:
    f: the function being traced.
    unseen_f: whether ``f`` has never been seen by the cache before.
    cache: the tracing cache; only its keys are inspected.
    key: the cache key that missed; its layout is decoded by ``unpack`` below.

  Returns:
    None. The explanation is emitted via ``logger.log``.
  """
  if config.check_tracer_leaks.value: return

  def unpack(key):
    transforms, (), _, (in_type, _, debug_info, _, inline), *_, ctx = key
    # TODO(dougalm,mattjj): enable cache miss explanation with attrs
    _, (_, (in_tree,)), *_ = transforms
    return in_tree, in_type, debug_info, inline.val, ctx

  in_tree, in_type, debug_info, inline, ctx = unpack(key)
  if inline: return

  msg: list[str] = []
  p = msg.append
  done = lambda: logger.log(logging.WARNING, '\n'.join(msg))

  callsite = source_info_util.summarize(source_info_util.current())
  p(f"TRACING CACHE MISS at {callsite} because:")

  # have we seen this function before at all?
  fun_name = getattr(f, '__qualname__', f)
  if debug_info is not None and debug_info.func_src_info:
    _, _, *rest = debug_info.func_src_info.split(' ')
    src_info = " defined at " + ' '.join(rest)
  else:
    src_info = ''
  if unseen_f:
    p(f" never seen function:\n {fun_name} id={id(f)}{src_info}")
    if callsite in callsites:
      p(" but seen another function defined on the same line; maybe the function is\n"
        " being re-defined repeatedly, preventing caching?")
    callsites.add(callsite)
    return done()
  else:
    p(f" for {fun_name}{src_info}")

  # Materialize as a list: the seen keys are scanned several times below, and
  # a one-shot iterator would make every later scan come up empty.
  seen_keys = [unpack(k) for k in cache.keys()]

  # have we maybe switched some args to be kwargs or visa-versa?
  args_tree, kwargs_tree = treedef_children(in_tree)
  args_kwargs_trees = [treedef_children(k) for k, *_ in seen_keys]
  args_kwargs_match = [t for t in args_kwargs_trees
                       if t == [args_tree, kwargs_tree]]
  if not args_kwargs_match:
    num_args = len(treedef_children(args_tree))
    _, kwarg_keys = kwargs_tree.node_data()  # type: ignore
    p(f" never seen passing {num_args} positional args and {len(kwarg_keys)} "
      "keyword args with keys:\n"
      f" {', '.join(map(repr, kwarg_keys))}")
    dont_match = [set(t[1].node_data()[1]) for t in args_kwargs_trees  # type: ignore
                  if t != [args_tree, kwargs_tree]]
    # Rank candidates by how many keyword names differ. Using the
    # symmetric-difference *set* itself as the key would compare sets by the
    # subset partial order, which does not find the closest candidate.
    kwarg_key_set = set(kwarg_keys)
    close_kwargs = min(
        dont_match, key=lambda s: len(kwarg_key_set ^ s), default=None
    )
    if not close_kwargs:
      p(" closest seen is passing no keyword args")
    else:
      p(f" closest seen passes {len(close_kwargs)} keyword args with keys:\n"
        f" {', '.join(map(repr, close_kwargs))}")
    return done()

  # have we never seen this tracing context before?
  ctxs_match = [c for *_, c in seen_keys if c == ctx]
  if not ctxs_match:
    p(" tracing context doesn't match, e.g. due to config or context manager")
    dont_match = [c for *_, c in seen_keys if c != ctx]
    closest_ctx = min(dont_match, key=lambda c: sum(map(op.ne, c, ctx)))
    idxs = [i for i, (c1, c2) in enumerate(zip(ctx, closest_ctx)) if c1 != c2]
    p(" closest seen context tuple differs at positions:\n"
      f" {', '.join(map(str, idxs))}\n"
      " compare to tuple returned by config._trace_context() in jax/_src/config.py.")
    return done()

  # have we never seen this input pytree before?
  trees_match = [k for k in seen_keys if k[0] == in_tree]
  if not trees_match:
    in_tree_str = f':\n {in_tree}' if len(str(in_tree)) < 76 else ''
    p(f" never seen input pytree{in_tree_str}")
    dont_match = [t for t, *_ in seen_keys if t != in_tree]
    closest_tree = min(dont_match, key=lambda t: abs(t.num_leaves - in_tree.num_leaves))
    errs = list(tree_util.equality_errors_pytreedef(in_tree, closest_tree))  # type: ignore[arg-type]
    p(f" closest seen input pytree has {len(errs)} mismatches, including:")
    for path, thing1, thing2, explanation in errs:
      fst, *path = path  # type: ignore
      base = ['args', 'kwargs'][fst.idx]
      p(f" * at {base}{keystr(tuple(path))}, seen {thing2} but now given {thing1},"
        f" so {explanation}")
    return done()

  # have we never seen these input types (eg shapes, dtypes) before?
  types_match = [k for k in trees_match if k[1] == in_type]
  if not types_match:
    # debug_info may be None (it is guarded before reading func_src_info
    # above); fall back to positional labels instead of crashing while
    # producing a diagnostic.
    if debug_info is not None:
      arg_names = debug_info.arg_names
    else:
      arg_names = [f'<input {i}>' for i in range(len(in_type))]
    if len(in_type) < 5:
      in_type_str = ':\n {}'.format(', '.join(
          f'{n}: {ty.str_short(short_dtypes=True)}'
          for n, ty in zip(arg_names, in_type)))
    else:
      in_type_str = ''
    p(f" never seen input type signature{in_type_str}")
    dont_match = [t for _, t, *_ in trees_match if t != in_type]
    closest_ty = min(dont_match, key=lambda t: sum(map(op.ne, t, in_type)))
    num_mismatch = sum(map(op.ne, closest_ty, in_type))
    p(f" closest seen input type signature has {num_mismatch} mismatches, including:")
    add_weak_type_hint = False
    for name, ty1, ty2 in zip(arg_names, closest_ty, in_type):
      if ty1 != ty2:
        if type(ty1) == type(ty2) == core.ShapedArray:
          s1, s2 = ty1.str_short(True), ty2.str_short(True)
          if s1 == s2:  # weak types don't show up in str_short()
            assert ty1.weak_type ^ ty2.weak_type
            s1 += f'{{weak_type={ty1.weak_type}}}'
            s2 += f'{{weak_type={ty2.weak_type}}}'
            add_weak_type_hint = True
        else:
          s1, s2 = str(ty1), str(ty2)
        p(f" * at {name}, seen {s1}, but now given {s2}")
    if add_weak_type_hint:
      p('where weak_type=True often means a Python builtin numeric value, and ')
      p('weak_type=False means a jax.Array.')
      p('See https://jax.readthedocs.io/en/latest/type_promotion.html#weak-types')
    return done()

  # we think this is unreachable...
  p("explanation unavailable! please open an issue at https://github.com/jax-ml/jax")
  return done()
|
|
|
|
|
|
|
|
|
|
@partial(lu.cache, explain=explain_tracing_cache_miss)
|
2024-06-20 09:57:41 -07:00
|
|
|
|
def _create_pjit_jaxpr(
|
|
|
|
|
fun: lu.WrappedFun,
|
|
|
|
|
in_type: core.InputType | Sequence[core.AbstractValue],
|
|
|
|
|
attr_data: int,
|
|
|
|
|
debug_info: lu.TracingDebugInfo,
|
|
|
|
|
out_paths: Callable,
|
|
|
|
|
ignored_inline: IgnoreKey
|
|
|
|
|
) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
|
|
|
|
|
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
|
2023-06-09 14:43:42 -07:00
|
|
|
|
del ignored_inline # just for explain_cache_miss
|
2023-05-15 09:15:22 -07:00
|
|
|
|
with dispatch.log_elapsed_time(
|
2024-07-25 22:20:25 +03:00
|
|
|
|
"Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec",
|
2023-05-15 09:15:22 -07:00
|
|
|
|
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
|
make mlir arg and result names work with static_argnums/argnames
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc). That unification will come in a follow up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)
Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).
This commit includes the changes from PR #15079, so that PR should be merged first.
Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
handle static_argnums or static_argnames correctly. Instead it would fail,
resulting in debug info being dropped from the jaxpr and ultimately the MLIR
computation (but no Exception raised). We need to handle
static_argnums/argnames because while the corresponding parameters remain on
the Python callable signature, they are excluded from the args/kwargs
pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
when we still have the original args/kwargs in hand, i.e. much earlier than
the previous mechanism. We then just have to pass this debug info to the
right places. Indeed we often already had to work out some debug-related
information at these call sites (e.g. whether the function is being staged
out for jit, or scan, or whatever), so after this change we're working out
all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
unflatten user pytree defs with dummy objects (to reconstruct dummy
args/kwargs trees so that we can call inspect.signature(fun).bind), since we
just use the original args/kwargs instead. Since some user pytree node types
are not fully polymorphic in their element types (e.g. their __init__ methods
sometimes contained assertions about their elements' shapes, expecting them
to be arrays), that means the new mechanism is fundamentally more compatible
with custom pytree node types.
More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
which in addition to the more precise name has fields like
`arg_names: Tuple[Optional[str], ...]` and
`result_paths: Tuple[Optional[str], ...]`, rather than
`in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
actual debug info more eagerly than before and we don't need pytrees for
dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
debug info about inputs which we have available at tracing time; in a
follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
delete `partial_eval.DebugInfo` and its corresponding helper methods (not
done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
`core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
elements from the `arg_names` field), maintaining now-checked invariants like
a Jaxpr's `debug_info` should have the same number of argument names as the
jaxpr has invars (the jaxpr-processing functions updated here are enough for
top-level jit jaxprs to have debug info attached, handling the original
intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
debug info on their outputs);
* add some tests for static_argnums/static_argnames.
Phew! Can't wait to land those follow-ups too :P
2023-03-17 17:45:41 -07:00
|
|
|
|
pe_debug = debug_info and pe.debug_info_final(fun, debug_info.traced_for)
|
2023-10-09 07:28:18 -07:00
|
|
|
|
if config.dynamic_shapes.value:
|
2023-03-22 20:54:45 -07:00
|
|
|
|
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
|
2024-06-20 09:57:41 -07:00
|
|
|
|
lu.annotate(fun, cast(core.InputType, in_type)), debug_info=pe_debug)
|
2024-01-25 22:20:36 -08:00
|
|
|
|
attrs_tracked = []
|
2023-03-22 20:54:45 -07:00
|
|
|
|
else:
|
2024-01-25 22:20:36 -08:00
|
|
|
|
jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
|
2023-03-22 20:54:45 -07:00
|
|
|
|
fun, in_type, debug_info=pe_debug)
|
2024-05-22 23:30:55 -04:00
|
|
|
|
# assert attr_data is sentinel or attr_data matches attrs_tracked
|
make mlir arg and result names work with static_argnums/argnames
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc). That unification will come in a follow up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)
Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).
This commit includes the changes from PR #15079, so that PR should be merged first.
Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
handle static_argnums or static_argnames correctly. Instead it would fail,
resulting in debug info being dropped from the jaxpr and ultimately the MLIR
computation (but no Exception raised). We need to handle
static_argnums/argnames because while the corresponding parameters remain on
the Python callable signature, they are excluded from the args/kwargs
pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
when we still have the original args/kwargs in hand, i.e. much earlier than
the previous mechanism. We then just have to pass this debug info to the
right places. Indeed we often already had to work out some debug-related
information at these call sites (e.g. whether the function is being staged
out for jit, or scan, or whatever), so after this change we're working out
all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
unflatten user pytree defs with dummy objects (to reconstruct dummy
args/kwargs trees so that we can call inspect.signature(fun).bind), since we
just use the original args/kwargs instead. Since some user pytree node types
are not fully polymorphic in their element types (e.g. their __init__ methods
sometimes contained assertions about their elements' shapes, expecting them
to be arrays), that means the new mechanism is fundamentally more compatible
with custom pytree node types.
More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
which in addition to the more precise name has fields like
`arg_names: Tuple[Optional[str], ...]` and
`result_paths: Tuple[Optional[str], ...]`, rather than
`in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
actual debug info more eagerly than before and we don't need pytrees for
dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
debug info about inputs which we have available at tracing time; in a
follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
delete `partial_eval.DebugInfo` and its corresponding helper methods (not
done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
`core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
elements from the `arg_names` field), maintaining now-checked invariants like
a Jaxpr's `debug_info` should have the same number of argument names as the
jaxpr has invars (the jaxpr-processing functions updated here are enough for
top-level jit jaxprs to have debug info attached, handling the original
intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
debug info on their outputs);
* add some tests for static_argnums/static_argnames.
Phew! Can't wait to land those follow-ups too :P
2023-03-17 17:45:41 -07:00
|
|
|
|
|
2024-01-25 22:20:36 -08:00
|
|
|
|
# TODO(dougalm,mattjj): enable debug info with attrs_tracked
|
|
|
|
|
if not config.dynamic_shapes.value and not attrs_tracked:
|
2023-03-22 20:54:45 -07:00
|
|
|
|
jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths())
|
2023-01-08 10:37:40 -08:00
|
|
|
|
|
2024-03-21 10:47:16 -07:00
|
|
|
|
if config.debug_key_reuse.value:
|
2023-12-11 12:03:48 -08:00
|
|
|
|
# Import here to avoid circular imports
|
|
|
|
|
from jax.experimental.key_reuse._core import check_key_reuse_jaxpr
|
|
|
|
|
check_key_reuse_jaxpr(jaxpr)
|
|
|
|
|
|
2023-01-08 10:37:40 -08:00
|
|
|
|
if any(isinstance(c, core.Tracer) for c in consts):
|
make mlir arg and result names work with static_argnums/argnames
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc.). That unification will come in follow-up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)
Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).
This commit includes the changes from PR #15079, so that PR should be merged first.
Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
handle static_argnums or static_argnames correctly. Instead it would fail,
resulting in debug info being dropped from the jaxpr and ultimately the MLIR
computation (but no Exception raised). We need to handle
static_argnums/argnames because while the corresponding parameters remain on
the Python callable signature, they are excluded from the args/kwargs
pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
when we still have the original args/kwargs in hand, i.e. much earlier than
the previous mechanism. We then just have to pass this debug info to the
right places. Indeed we often already had to work out some debug-related
information at these call sites (e.g. whether the function is being staged
out for jit, or scan, or whatever), so after this change we're working out
all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
unflatten user pytree defs with dummy objects (to reconstruct dummy
args/kwargs trees so that we can call inspect.signature(fun).bind), since we
just use the original args/kwargs instead. Since some user pytree node types
are not fully polymorphic in their element types (e.g. their __init__ methods
sometimes contained assertions about their elements' shapes, expecting them
to be arrays), that means the new mechanism is fundamentally more compatible
with custom pytree node types.
More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
which in addition to the more precise name has fields like
`arg_names: Tuple[Optional[str], ...]` and
`result_paths: Tuple[Optional[str], ...]`, rather than
`in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
actual debug info more eagerly than before and we don't need pytrees for
dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
debug info about inputs which we have available at tracing time; in a
follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
delete `partial_eval.DebugInfo` and its corresponding helper methods (not
done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
`core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
elements from the `arg_names` field), maintaining now-checked invariants like
a Jaxpr's `debug_info` should have the same number of argument names as the
jaxpr has invars (the jaxpr-processing functions updated here are enough for
top-level jit jaxprs to have debug info attached, handling the original
intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
debug info on their outputs);
* add some tests for static_argnums/static_argnames.
Phew! Can't wait to land those follow-ups too :P
2023-03-17 17:45:41 -07:00
|
|
|
|
closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
|
2023-01-08 10:37:40 -08:00
|
|
|
|
final_consts = consts
|
|
|
|
|
else:
|
make mlir arg and result names work with static_argnums/argnames
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc.). That unification will come in follow-up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)
Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).
This commit includes the changes from PR #15079, so that PR should be merged first.
Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
handle static_argnums or static_argnames correctly. Instead it would fail,
resulting in debug info being dropped from the jaxpr and ultimately the MLIR
computation (but no Exception raised). We need to handle
static_argnums/argnames because while the corresponding parameters remain on
the Python callable signature, they are excluded from the args/kwargs
pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
when we still have the original args/kwargs in hand, i.e. much earlier than
the previous mechanism. We then just have to pass this debug info to the
right places. Indeed we often already had to work out some debug-related
information at these call sites (e.g. whether the function is being staged
out for jit, or scan, or whatever), so after this change we're working out
all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
unflatten user pytree defs with dummy objects (to reconstruct dummy
args/kwargs trees so that we can call inspect.signature(fun).bind), since we
just use the original args/kwargs instead. Since some user pytree node types
are not fully polymorphic in their element types (e.g. their __init__ methods
sometimes contained assertions about their elements' shapes, expecting them
to be arrays), that means the new mechanism is fundamentally more compatible
with custom pytree node types.
More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
which in addition to the more precise name has fields like
`arg_names: Tuple[Optional[str], ...]` and
`result_paths: Tuple[Optional[str], ...]`, rather than
`in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
actual debug info more eagerly than before and we don't need pytrees for
dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
debug info about inputs which we have available at tracing time; in a
follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
delete `partial_eval.DebugInfo` and its corresponding helper methods (not
done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
`core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
elements from the `arg_names` field), maintaining now-checked invariants like
a Jaxpr's `debug_info` should have the same number of argument names as the
jaxpr has invars (the jaxpr-processing functions updated here are enough for
top-level jit jaxprs to have debug info attached, handling the original
intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
debug info on their outputs);
* add some tests for static_argnums/static_argnames.
Phew! Can't wait to land those follow-ups too :P
2023-03-17 17:45:41 -07:00
|
|
|
|
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
2023-01-08 10:37:40 -08:00
|
|
|
|
final_consts = []
|
2024-01-25 22:20:36 -08:00
|
|
|
|
return closed_jaxpr, final_consts, global_out_avals, attrs_tracked
|
2023-02-17 12:01:50 -08:00
|
|
|
|
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
2024-06-11 12:46:11 -07:00
|
|
|
|
@util.cache(max_size=4096, trace_context_in_key=False)
def _check_and_canonicalize_out_shardings(
    out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
    out_layouts_leaves, out_tree, out_avals, debug_info, device_or_backend_set):
  """Flattens the output shardings/layouts to one entry per output and checks them.

  The sharding and layout trees arrive pre-flattened (treedef + leaves),
  presumably so that every argument can take part in the `util.cache` key --
  TODO confirm.

  Args:
    out_shardings_treedef, out_shardings_leaves: the `out_shardings` tree.
    out_layouts_treedef, out_layouts_leaves: the `out_layouts` tree.
    out_tree: thunk returning the output pytree structure; it is only called
      when a non-broadcast sharding/layout tree must be matched to the outputs.
    out_avals: flat sequence of output avals.
    debug_info: optional; its `result_paths` label outputs in error messages.
    device_or_backend_set: not read in the body; as an argument of this cached
      function it presumably only affects the cache key.

  Returns:
    `(out_shardings_flat, out_layouts_flat)`, each with one entry per output.
  """
  num_outs = len(out_avals)

  # A single Sharding (or an unspecified sharding) is broadcast to every
  # output; any other tree must line up with the output pytree.
  shardings_tree = tree_unflatten(out_shardings_treedef, out_shardings_leaves)
  broadcast_sharding = (is_unspecified(shardings_tree) or
                        isinstance(shardings_tree, sharding.Sharding))
  out_shardings_flat = (
      (shardings_tree,) * num_outs if broadcast_sharding else
      flatten_axis_resources("pjit out_shardings", out_tree(), shardings_tree,
                             tupled_args=False))

  # Same idea for layouts, with `None` meaning "no layout for any output".
  layouts_tree = tree_unflatten(out_layouts_treedef, out_layouts_leaves)
  out_layouts_flat = (
      (layouts_tree,) * num_outs if layouts_tree is None else
      flatten_axis_resources("pjit out_layouts", out_tree(), layouts_tree,
                             tupled_args=False))

  # Shape-based validation is skipped entirely under dynamic shapes.
  if config.dynamic_shapes.value:
    return out_shardings_flat, out_layouts_flat

  result_paths = None if debug_info is None else debug_info.result_paths
  pjit_check_aval_sharding(out_shardings_flat, out_avals, result_paths,
                           "pjit outputs", allow_uneven_sharding=False)
  check_aval_layout_compatibility(out_layouts_flat, out_avals, result_paths,
                                  "jit outputs")
  return out_shardings_flat, out_layouts_flat
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
2023-02-17 12:01:50 -08:00
|
|
|
|
|
2024-05-22 23:30:55 -04:00
|
|
|
|
# One tracked attribute: (object, attribute name, pytree structure, leaf avals).
AttrRecord = tuple[object, str, PyTreeDef, list[core.AbstractValue]]

# Weakly keyed by a wrapped function's underlying callable (`fun.f`). Each value
# is a `defaultdict(list)` keyed by `(fun.transforms, fun.params, in_type)`,
# holding the list of attr-record lists ("cases") seen for that key.
_seen_attrs = weakref.WeakKeyDictionary()  # type: ignore


def seen_attrs_get(
    fun: lu.WrappedFun,
    in_type: core.InputType | tuple[core.AbstractValue, ...]
) -> list:
  """Returns the mutable list of recorded attr cases for `fun` at `in_type`.

  Empty containers are created on first access, so callers may append to the
  returned list to record a new case.
  """
  per_fun_cases = _seen_attrs.get(fun.f)
  if per_fun_cases is None:
    per_fun_cases = _seen_attrs[fun.f] = defaultdict(list)
  # When the wrapped function already carries an input type it must agree.
  assert fun.in_type is None or fun.in_type == in_type
  lookup_key = (fun.transforms, fun.params, in_type)
  return per_fun_cases[lookup_key]
|
2024-05-22 23:30:55 -04:00
|
|
|
|
|
2024-06-20 09:57:41 -07:00
|
|
|
|
def _attr_token(
    fun: lu.WrappedFun,
    in_type: core.InputType | tuple[core.AbstractValue, ...]
) -> int:
  """Returns the index of the first recorded attr case that still matches.

  A case matches when, for every record `(obj, attr, treedef, avals)` in it,
  the current value of `obj.<attr>` flattens to the same treedef and to leaves
  with the same abstract values. If no case matches, `len(cases)` is returned,
  which is the index `_attr_update` appends a new case at.
  """
  # Local import, presumably to avoid an import cycle -- TODO confirm.
  from jax.experimental.attrs import jax_getattr

  def case_matches(records) -> bool:
    for obj, attr, treedef, avals in records:
      current_leaves, current_treedef = tree_flatten(jax_getattr(obj, attr))
      current_avals = map(shaped_abstractify, current_leaves)
      if treedef != current_treedef or avals != current_avals:
        return False
    return True

  cases = seen_attrs_get(fun, in_type)
  return next((idx for idx, records in enumerate(cases)
               if case_matches(records)),
              len(cases))
|
|
|
|
|
|
|
|
|
|
def _attr_update(fun, in_type, i, attrs_tracked):
  """Records the attribute state observed for case `i` of `fun` at `in_type`.

  One record `(obj, attr, init_tree, avals)` is built per tracked attribute,
  where `avals` are the abstract values of the attribute's current leaves. If
  `i` is one past the last known case, the records are appended as a new
  case; otherwise they must equal the case already stored at `i`.
  """
  # Local import, presumably to avoid an import cycle -- TODO confirm.
  from jax.experimental.attrs import jax_getattr

  records = []
  for init_tree, _, (obj, attr) in attrs_tracked:
    current_leaves = tree_leaves(jax_getattr(obj, attr))
    records.append(
        (obj, attr, init_tree, map(shaped_abstractify, current_leaves)))

  cases = seen_attrs_get(fun, in_type)
  if i != len(cases):
    assert i < len(cases) and cases[i] == records
    return
  cases.append(records)
|
|
|
|
|
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
2023-06-09 14:43:42 -07:00
|
|
|
|
@dataclasses.dataclass(frozen=True)
class IgnoreKey:
  """Frozen box whose payload never takes part in equality or hashing.

  Every `IgnoreKey` compares equal to every other `IgnoreKey`, and all of them
  share a single hash. The wrapped `val` is therefore carried along without
  influencing dict or cache lookups. Presumably that is why call sites wrap
  arguments of cached functions in it -- TODO confirm at call sites.
  """
  val: Any

  def __eq__(self, other):
    # Deliberately never looks at `val` on either side.
    return isinstance(other, IgnoreKey)

  def __hash__(self):
    # Hash of the class object itself, identical for every instance.
    return hash(type(self))
|
|
|
|
|
|
|
|
|
|
|
2022-12-16 13:06:38 -08:00
|
|
|
|
def pjit_check_aval_sharding(
    shardings, flat_avals, names: tuple[str, ...] | None,
    what_aval: str, allow_uneven_sharding: bool):
  """Checks that every sharding can be applied to the aval it is paired with.

  Args:
    shardings: one sharding per aval. Unspecified or AUTO entries are skipped.
    flat_avals: the avals being sharded; only their `.shape` is inspected.
    names: optional per-entry pytree key paths, used in error messages.
    what_aval: human-readable label for the values, e.g. "pjit outputs".
    allow_uneven_sharding: when False, each dimension's size must be divisible
      by the number of ways that dimension is sharded.

  Raises:
    ValueError: a sharding rejects the aval's shape, or (when uneven sharding
      is not allowed) a dimension is not evenly divisible.
  """
  if names is None:
    names = [''] * len(shardings)

  for aval, s, name in zip(flat_avals, shardings, names):
    if is_unspecified_or_auto(s):
      continue
    where = f' with pytree key path {name}' if name else ''
    shape = aval.shape
    ndim = len(shape)

    # Sharding interfaces can implement `check_compatible_aval` as an optional
    # method to raise a more meaningful error. Otherwise, converting to an HLO
    # sharding of this rank serves as the compatibility check.
    try:
      if hasattr(s, 'check_compatible_aval'):
        s.check_compatible_aval(shape)
      else:
        s._to_xla_hlo_sharding(ndim)
    except ValueError as e:
      raise ValueError(
          f'One of {what_aval}{where} is incompatible with its sharding '
          f'annotation {s}: {e}')

    # The HLO sharding tells how many ways each dimension of the aval is
    # sharded; this approach works across all Sharding implementations.
    hlo_sharding = s._to_xla_hlo_sharding(ndim)
    assert hlo_sharding is not None
    ways_per_dim, _ = op_shardings.get_num_ways_dim_sharded(hlo_sharding)
    if allow_uneven_sharding:
      continue
    for dim, ways in enumerate(ways_per_dim):
      if shape[dim] % ways != 0:
        raise ValueError(f"One of {what_aval}{where} was given the sharding "
                         f"of {s}, which implies that "
                         f"the global size of its dimension {dim} should be "
                         f"divisible by {ways}, but it is equal to {shape[dim]} "
                         f"(full shape: {shape})")
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
|
|
|
|
|
2024-07-12 08:09:54 -07:00
|
|
|
|
def check_aval_layout_compatibility(
    layouts, flat_avals, names: tuple[str, ...] | None, what_aval: str):
  """Checks that every concrete layout is compatible with its paired aval.

  Args:
    layouts: one entry per aval. `None` and `AutoLayout` entries are skipped.
    flat_avals: the avals; only their `.shape` is inspected.
    names: optional per-entry pytree key paths, used in error messages.
    what_aval: human-readable label for the values, e.g. "jit outputs".

  Raises:
    ValueError: a layout's `check_compatible_aval` rejects the aval's shape.
  """
  if names is None:
    names = [''] * len(layouts)

  for aval, layout, name in zip(flat_avals, layouts, names):
    if layout is None or isinstance(layout, AutoLayout):
      continue
    where = f' with pytree key path {name}' if name else ''
    shape = aval.shape
    try:
      layout.check_compatible_aval(shape)
    except ValueError as e:
      raise ValueError(
          f'One of {what_aval}{where} is incompatible with its layout '
          f'annotation {layout}: {e}')
|
|
|
|
|
|
|
|
|
|
|
2022-12-16 13:06:38 -08:00
|
|
|
|
# -------------------- pjit rules --------------------

# The "pjit" primitive. The rules registered for it (impl, lowering, ...) are
# presumably set up under this section header -- TODO confirm.
pjit_p = core.AxisPrimitive("pjit")
# Marks the primitive as producing a sequence of outputs rather than a single
# value (cf. `out_flat` returned by `_pjit_call_impl` below).
pjit_p.multiple_results = True
|
|
|
|
|
|
|
|
|
|
|
2024-04-05 20:08:48 -07:00
|
|
|
|
def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
  """Picks the device-local layout each argument should be passed with.

  Args:
    args: the call's flat arguments.
    jit_in_layouts: per-argument layouts requested on the jit; `None` where
      nothing was requested.
    resolved_in_shardings: per-argument shardings after resolution.
    in_avals: per-argument avals.

  Returns:
    A tuple with one entry per argument. `None` means no specific layout is
    needed; otherwise the entry is the layout to use.

  Raises:
    ValueError: a committed argument's concrete layout disagrees with the
      layout requested on the jit.
  """
  # If device or backend is set, return the default layout. This is because you
  # can pass arrays on cpu (with untiled layouts) to jit with backend='tpu'
  # which causes error checks to fail. Returning the default layout allows
  # this to exist. It's the same for handling shardings.
  if pxla.check_device_backend_on_shardings(resolved_in_shardings):
    return (None,) * len(jit_in_layouts)

  layouts_out = []
  for arg, requested, arg_sharding, aval in safe_zip(
      args, jit_in_layouts, resolved_in_shardings, in_avals):
    is_committed = getattr(arg, '_committed', True)
    # `actual` is the argument's real layout. It is only used for the mismatch
    # check below, so default layouts must stay visible there and cannot be
    # replaced with None. `dispatchable` swaps default layouts for `None` to
    # simplify dispatch and lowering logic downstream.
    if hasattr(arg, 'layout'):
      actual = arg.layout.device_local_layout
      dispatchable = (
          None if pxla.is_default_layout(actual, arg_sharding, aval)
          else actual)
    else:
      actual = dispatchable = None
    # Sharding can be unspecified when array is committed if it's a PmapSharding.
    pmap_like = (is_unspecified(arg_sharding) or
                 isinstance(getattr(arg, 'sharding', None), PmapSharding))

    if requested is None:
      # Nothing requested on the jit: only committed, non-pmap arguments carry
      # their own (non-default) layout forward.
      layouts_out.append(
          dispatchable if is_committed and not pmap_like else None)
      continue

    # `actual` can be None because some backends don't implement the required
    # layout methods. Hence `arr.layout` can return `Layout(None, sharding)`.
    mismatch = (is_committed and not pmap_like and actual is not None and
                not pxla.is_user_xla_layout_equal(requested, actual))
    if mismatch:
      extra_msg = (
          ' The layout given to `jax.jit` is `DeviceLocalLayout.AUTO` but'
          ' the corresponding argument passed is a `jax.Array` with a'
          ' concrete layout. Consider passing a `jax.ShapeDtypeStruct`'
          ' instead of `jax.Array` as an argument to the jitted function '
          ' when using `DeviceLocalLayout.AUTO`.'
      ) if isinstance(requested, AutoLayout) else ''
      raise ValueError('Layout passed to jit does not match the layout '
                       'on the respective arg. '
                       f'Got pjit layout: {requested},\n'
                       f'arg layout: {actual} for '
                       f'arg shape: {shaped_abstractify(arg).str_short()}.'
                       f'{extra_msg}')
    layouts_out.append(requested)
  return tuple(layouts_out)
|
|
|
|
|
|
|
|
|
|
|
2024-09-12 18:47:25 -07:00
|
|
|
|
def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
                          ) -> Sequence[PjitSharding]:
  """Picks the sharding each argument will actually be passed with.

  If a device or backend was set on the jit, `pjit_in_shardings` is returned
  unchanged. Otherwise each argument is handled as follows:

  * No sharding on the jit for this argument:
    - a committed argument keeps its own sharding, except a PmapSharding,
      which becomes UNSPECIFIED so that it is resharded;
    - an uncommitted single-device argument becomes UNSPECIFIED;
    - an uncommitted multi-device argument is rejected.
  * A sharding is given on the jit: it is used after validating it against the
    argument (multi-process numpy inputs, memory kind, and, for committed
    non-pmap arguments, equality of the HLO shardings).

  Raises:
    NotImplementedError: an uncommitted array is sharded on multiple devices.
    ValueError: an argument conflicts with the sharding given on the jit.
  """
  # If True, means that device or backend is set by the user on pjit and it
  # has the same semantics as device_put i.e. doesn't matter which device the
  # arg is on, reshard it to the device mentioned. So don't do any of the
  # checks and just return the pjit_in_shardings directly. `shard_args` will
  # handle the resharding.
  if pxla.check_device_backend_on_shardings(pjit_in_shardings):
    return pjit_in_shardings

  resolved_in_shardings = []
  for arg, pjit_in_s in zip(args, pjit_in_shardings):
    # arg sharding can be None in case of ShapeDtypeStruct. jax.Array does
    # not allow None as the sharding.
    arg_s = getattr(arg, 'sharding', None)
    if arg_s is None:
      arg_s, committed = UNSPECIFIED, False
    else:
      committed = getattr(arg, '_committed', True)

    if is_unspecified(pjit_in_s):
      if is_unspecified(arg_s):
        resolved_in_shardings.append(arg_s)
      elif committed:
        # If the arg has a PmapSharding, then reshard it unconditionally.
        resolved_in_shardings.append(
            UNSPECIFIED if isinstance(arg_s, PmapSharding) else arg_s)
      elif dispatch.is_single_device_sharding(arg_s):
        resolved_in_shardings.append(UNSPECIFIED)
      else:
        raise NotImplementedError('Having uncommitted Array sharded on '
                                  'multiple devices is not supported.')
      continue

    if (isinstance(arg, np.ndarray) and
        not pjit_in_s.is_fully_replicated and  # type: ignore
        xb.process_count() > 1):
      raise ValueError(
          'Passing non-trivial shardings for numpy '
          'inputs is not allowed. To fix this error, either specify a '
          'replicated sharding explicitly or use '
          '`jax.experimental.multihost_utils.host_local_array_to_global_array(...)` '
          'to convert your host local numpy inputs to a jax.Array which you '
          'can pass to pjit. '
          'If the numpy input is the same on each process, then you can use '
          '`jax.make_array_from_callback(...)` to create a `jax.Array` which '
          'you can pass to pjit. '
          'Please see the jax.Array migration guide for more information '
          'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. '
          f'Got arg shape: {arg.shape}, arg value: {arg}')
    if not is_unspecified(arg_s):
      # jax.jit does not allow resharding across different memory kinds even
      # if the argument is uncommitted. Use jax.device_put for those cases,
      # either outside or inside jax.jit.
      if pjit_in_s.memory_kind != arg_s.memory_kind:  # type: ignore
        raise ValueError(
            'Memory kinds passed to jax.jit does not match memory kind on the'
            f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, '  # type: ignore
            f'arg memory kind: {arg_s.memory_kind} for '  # pytype: disable=attribute-error
            f'arg shape: {shaped_abstractify(arg).str_short()}')
      if (committed and
          not isinstance(arg_s, PmapSharding) and
          not op_shardings.are_op_shardings_equal(
              pjit_in_s._to_xla_hlo_sharding(arg.ndim),  # type: ignore
              arg_s._to_xla_hlo_sharding(arg.ndim))):
        raise ValueError('Sharding passed to pjit does not match the sharding '
                         'on the respective arg. '
                         f'Got pjit sharding: {pjit_in_s},\n'
                         f'arg sharding: {arg_s} for '
                         f'arg shape: {shaped_abstractify(arg).str_short()}')
    resolved_in_shardings.append(pjit_in_s)

  return tuple(resolved_in_shardings)
|
|
|
|
|
|
|
|
|
|
|
2024-04-06 13:43:32 -07:00
|
|
|
|
def _resolve_and_lower(
    args, jaxpr, in_shardings, out_shardings, in_layouts,
    out_layouts, resource_env, donated_invars, name, keep_unused, inline,
    lowering_platforms, lowering_parameters, pgle_profiler):
  """Resolves input shardings/layouts against `args`, then lowers `jaxpr`.

  Shardings are resolved first because layout resolution consumes the
  resolved shardings. The result of `_pjit_lower` is returned unchanged.
  """
  resolved_shardings = _resolve_in_shardings(args, in_shardings)
  resolved_layouts = _resolve_in_layouts(
      args, in_layouts, resolved_shardings, jaxpr.in_avals)
  return _pjit_lower(
      jaxpr, resolved_shardings, out_shardings, resolved_layouts, out_layouts,
      resource_env, donated_invars, name, keep_unused, inline,
      lowering_platforms=lowering_platforms,
      lowering_parameters=lowering_parameters,
      pgle_profiler=pgle_profiler)
|
|
|
|
|
|
2023-05-26 08:56:56 -07:00
|
|
|
|
def _pjit_call_impl_python(
    *args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
    resource_env, donated_invars, name, keep_unused, inline):
  """Compiles `jaxpr` for `args` and runs it on the Python dispatch path.

  When PGLE is enabled (`config.enable_pgle` with a positive
  `config.pgle_profiling_runs`), a profiler is kept per jaxpr in
  `_most_recent_pjit_call_executable.weak_pgle_profiler_dict`. Once it yields
  an FDO profile, that profile is passed as a compile option.

  Returns:
    `(outputs, compiled)`, where `compiled` is the executable that produced
    `outputs`; it is also recorded in
    `_most_recent_pjit_call_executable.weak_key_dict[jaxpr]`.

  Raises:
    FloatingPointError: the compiled computation hit a NaN/Inf (only possible
      with debug_nans/debug_infs). The jaxpr is first re-run op-by-op, which
      may raise a more precise error itself.
  """
  compile_options = None
  pgle_profiler = None
  pgle_profiler_dict = _most_recent_pjit_call_executable.weak_pgle_profiler_dict
  if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
    if jaxpr not in pgle_profiler_dict:
      pgle_profiler_dict[jaxpr] = profiler.PGLEProfiler(
          config.pgle_profiling_runs.value,
          config.pgle_aggregation_percentile.value)

    pgle_profiler = pgle_profiler_dict[jaxpr]
    # The method below will return FDO profile when module was profiled
    # config.jax_pgle_profiling_runs amount of times, otherwise the result will
    # be None.
    fdo_profile = pgle_profiler.consume_fdo_profile()
    if fdo_profile is not None:
      compile_options = {'fdo_profile': fdo_profile}

  # TODO(patrios): Do not pass mutable profile session through cached lowering
  # chain. Instead we need to move profilers dictionary to pxla module and use
  # module as key. Right now we can't do that since there is no way to evict
  # _pjit_lower_cached cache in PGLE mode.
  compiled = _resolve_and_lower(
      args, jaxpr=jaxpr, in_shardings=in_shardings,
      out_shardings=out_shardings, in_layouts=in_layouts,
      out_layouts=out_layouts, resource_env=resource_env,
      donated_invars=donated_invars, name=name, keep_unused=keep_unused,
      inline=inline, lowering_platforms=None,
      lowering_parameters=mlir.LoweringParameters(),
      pgle_profiler=pgle_profiler
  ).compile(compile_options)

  _most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
  # This check is expensive so only do it if enable_checks is on.
  if compiled._auto_spmd_lowering and config.enable_checks.value:
    pxla.check_array_xla_sharding_layout_match(
        args, compiled._in_shardings, compiled._in_layouts,
        jaxpr.jaxpr.debug_info, compiled._kept_var_idx)
  if config.distributed_debug.value:
    # Defensively only perform fingerprint logic if debug logging is enabled
    # NOTE(skyewm): I didn't benchmark this
    fingerprint = getattr(compiled.runtime_executable(), "fingerprint", None)
    if fingerprint is not None:
      fingerprint = fingerprint.hex()
    distributed_debug_log(("Running pjit'd function", name),
                          ("in_shardings", in_shardings),
                          ("out_shardings", out_shardings),
                          ("in_layouts", in_layouts),
                          ("out_layouts", out_layouts),
                          ("abstract args", map(xla.abstractify, args)),
                          ("fingerprint", fingerprint))
  try:
    return compiled.unsafe_call(*args), compiled
  except FloatingPointError as e:
    assert config.debug_nans.value or config.debug_infs.value  # compiled_fun can only raise in this case

    # Re-run op-by-op to localize the invalid value; this call raises rather
    # than returns if it reproduces it. Single-equation jaxprs are not re-run
    # (presumably because op-by-op would just repeat the same computation).
    if len(jaxpr.eqns) > 1:
      _ = core.jaxpr_as_fun(jaxpr)(*args)  # may raise, not return

    # If control reaches this line, we got a NaN on the output of `compiled`
    # but not from `core.jaxpr_as_fun(jaxpr)` on the same arguments. Let's
    # tell the user.
    msg = (f"{str(e)}. Because "
           "config.jax_debug_nans and/or config.jax_debug_infs is set, the "
           "de-optimized function (i.e., the function as if the `jit` "
           "decorator were removed) was called in an attempt to get a more "
           "precise error message. However, the de-optimized function did not "
           "produce invalid values during its execution. This behavior can "
           "result from `jit` optimizations causing the invalid value to be "
           "produced. It may also arise from having nan/inf constants as "
           "outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. "
           "\n\n"
           "It may be possible to avoid the invalid value by removing the "
           "`jit` decorator, at the cost of losing optimizations. "
           "\n\n"
           "If you see this error, consider opening a bug report at "
           "https://github.com/jax-ml/jax.")
    raise FloatingPointError(msg)
|
|
|
|
|
|
2023-05-26 08:56:56 -07:00
|
|
|
|
|
|
|
|
|
@weakref_lru_cache
def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts,
                      out_layouts, resource_env, donated_invars, name,
                      keep_unused, inline):
  """Returns a callable that evaluates `jaxpr` without keeping it alive.

  Only `jaxpr` is used to build the callable; the other parameters are not
  read here (presumably they exist so they become part of the cache key).
  """
  # The result is memoized by `weakref_lru_cache`, keyed weakly on `jaxpr`.
  # Capturing `jaxpr` itself in the returned function would create a strong
  # reference from the cached value back to its key and defeat the weak
  # caching. Capture a weak reference instead and dereference it per call.
  jaxpr_ref = weakref.ref(jaxpr)

  def call_jaxpr(*args):
    return core.jaxpr_as_fun(jaxpr_ref())(*args)

  return call_jaxpr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _pjit_call_impl(*args, jaxpr,
                    in_shardings, out_shardings, in_layouts, out_layouts,
                    resource_env,
                    donated_invars, name, keep_unused, inline):
  """Eager implementation of `pjit_p`, dispatched through the C++ pjit object.

  A C++ `xc._xla.pjit` callable is built around `call_impl_cache_miss` (the
  Python slow path) and immediately invoked on `args`.
  """
  def call_impl_cache_miss(*args_, **kwargs_):
    # Slow path: run through `_pjit_call_impl_python`, then return the outputs
    # together with fast-path data for the C++ dispatcher and whether the
    # PGLE profiler requests a rebuild.
    # NOTE: `args_`/`kwargs_` are not read; the enclosing `args` are used.
    out_flat, compiled = _pjit_call_impl_python(
        *args, jaxpr=jaxpr, in_shardings=in_shardings,
        out_shardings=out_shardings, in_layouts=in_layouts,
        out_layouts=out_layouts, resource_env=resource_env,
        donated_invars=donated_invars, name=name, keep_unused=keep_unused,
        inline=inline)
    profiler_ = _read_pgle_profiler(jaxpr)
    fastpath = _get_fastpath_data(
        compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects,
        jaxpr.consts, None, profiler_)
    return out_flat, fastpath, _need_to_rebuild_with_fdo(profiler_)

  fun = _get_jaxpr_as_fun(
      jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
      resource_env, donated_invars, name, keep_unused, inline)
  donated_argnums = tuple(
      idx for idx, donated in enumerate(donated_invars) if donated)

  if xla_extension_version < 286:
    # Older jaxlib: the C++ pjit takes the donated argnums directly and the
    # global cache is selected by whether explicit shardings/layouts exist.
    explicit = _pjit_explicit_sharding_and_layout(
        in_shardings, out_shardings, in_layouts, out_layouts, None, None)
    cpp_pjit = xc._xla.pjit(
        name, fun, call_impl_cache_miss, [], [], donated_argnums,
        tree_util.dispatch_registry, pxla.cc_shard_arg,
        _get_cpp_global_cache(explicit))
    return cpp_pjit(*args)

  cache_key = pxla.JitGlobalCppCacheKeys(
      donate_argnums=donated_argnums, donate_argnames=None,
      device=None, backend=None,
      in_shardings_treedef=None, in_shardings_leaves=in_shardings,
      out_shardings_treedef=None, out_shardings_leaves=out_shardings,
      in_layouts_treedef=None, in_layouts_leaves=in_layouts,
      out_layouts_treedef=None, out_layouts_leaves=out_layouts,
      use_resource_env=resource_env is not None)
  cpp_pjit = xc._xla.pjit(
      name, fun, call_impl_cache_miss, [], [], cache_key,
      tree_util.dispatch_registry, pxla.cc_shard_arg,
      _get_cpp_global_cache(cache_key.contains_explicit_attributes))
  return cpp_pjit(*args)

pjit_p.def_impl(_pjit_call_impl)
|
|
|
|
|
|
|
|
|
|
|
2024-03-26 13:28:03 -07:00
|
|
|
|
def _pjit_lower(*args, **kwargs):
  """Lowers a pjit computation by forwarding everything to `_pjit_lower_cached`."""
  lowered = _pjit_lower_cached(*args, **kwargs)
  return lowered
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@weakref_lru_cache
def _pjit_lower_cached(
    jaxpr: core.ClosedJaxpr,
    in_shardings,
    out_shardings,
    in_layouts: pxla.MaybeLayout,
    out_layouts: pxla.MaybeLayout,
    resource_env,
    donated_invars,
    name: str,
    keep_unused: bool,
    inline: bool,
    *,
    lowering_platforms: tuple[str, ...] | None,
    lowering_parameters: mlir.LoweringParameters,
    pgle_profiler: profiler.PGLEProfiler | None):
  """Lowers `jaxpr` with `pxla.lower_sharding_computation` (memoized).

  With a `resource_env`, the computation is reported as 'pjit' and lowered
  with the env's physical mesh as context mesh. Otherwise it is reported as
  'jit' with no context mesh. `inline` is not used here.
  """
  if resource_env is None:
    context_mesh, api_name = None, 'jit'
  else:
    context_mesh, api_name = resource_env.physical_mesh, 'pjit'
  return pxla.lower_sharding_computation(
      jaxpr, api_name, name, in_shardings, out_shardings,
      in_layouts, out_layouts, tuple(donated_invars),
      keep_unused=keep_unused, context_mesh=context_mesh,
      lowering_platforms=lowering_platforms,
      lowering_parameters=lowering_parameters,
      pgle_profiler=pgle_profiler)
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pjit_staging_rule(trace, *args, **params):
  """Custom staging rule for `pjit_p`.

  Outputs that merely forward an unconstrained input are first pruned via
  `_pjit_forwarding` and reattached at the end as the corresponding input
  tracers. The remaining computation is then either inlined into `trace`,
  staged with hand-built dynamic-shape output tracers, staged with mutable
  constants hoisted into extra arguments, or staged as an ordinary equation.
  """
  jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding(
      params['jaxpr'], params['out_shardings'], params['out_layouts'])
  params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings,
                out_layouts=out_layouts)
  # Inline only when requested and no sharding/layout constraint would be
  # lost by dropping the pjit boundary.
  if (params["inline"] and
      all(is_unspecified(i) for i in params["in_shardings"]) and
      all(is_unspecified(o) for o in params["out_shardings"]) and
      all(i is None for i in params["in_layouts"]) and
      all(o is None for o in params["out_layouts"])):
    if config.dynamic_shapes.value:
      # Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic
      # shapes are enabled, use eval_jaxpr, which uses the tracing machinery,
      # but redundantly performs abstract evaluation again.
      out_tracers = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args,
                                    propagate_source_info=False)
    else:
      out_tracers = pe.inline_jaxpr_into_trace(
          trace, jaxpr.jaxpr, jaxpr.consts, *args)
  elif config.dynamic_shapes.value:
    # Build output tracers by hand: dynamic dims encoded as InDBIdx refer to
    # input tracers, OutDBIdx to previously created output tracers.
    source_info = source_info_util.current()
    out_tracers = []
    for aval in _out_type(jaxpr):
      if type(aval) is core.DShapedArray:
        shape = [args[d.val] if type(d) is core.InDBIdx else
                 out_tracers[d.val] if type(d) is core.OutDBIdx else
                 d for d in aval.shape]
        aval = aval.update(shape=tuple(core.get_referent(d) for d in shape))
      out_tracers.append(pe.DynamicJaxprTracer(trace, aval, source_info))
    eqn = core.new_jaxpr_eqn(
        map(trace.getvar, args), map(trace.makevar, out_tracers), pjit_p, params,
        jaxpr.effects, source_info)
    trace.frame.add_eqn(eqn)
  elif any(isinstance(c, core.MutableArray) for c in jaxpr.consts):
    # Mutable-array constants are moved out of the jaxpr and passed as extra
    # trailing arguments with unspecified sharding, no layout, no donation.
    jaxpr, consts = pxla._move_mutable_consts(jaxpr)
    consts = map(trace.instantiate_const, consts)
    in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts)
    in_layouts = (*params['in_layouts'],) + (None,) * len(consts)
    donated_invars = (*params['donated_invars'],) + (False,) * len(consts)
    new_params = dict(params, jaxpr=jaxpr, in_shardings=in_shardings,
                      in_layouts=in_layouts, donated_invars=donated_invars)
    out_tracers = trace.default_process_primitive(
        pjit_p, (*args, *consts), new_params)
  else:
    out_tracers = trace.default_process_primitive(pjit_p, args, params)

  # Re-insert forwarded outputs: an int entry of `in_fwd` means "this output
  # is that input"; every other output is taken from `out_tracers` in order.
  out_tracers_ = iter(out_tracers)
  out_tracers = [args[f] if type(f) is int else next(out_tracers_)
                 for f in in_fwd]
  assert next(out_tracers_, None) is None
  return out_tracers

pe.custom_staging_rules[pjit_p] = pjit_staging_rule
|
|
|
|
|
|
2024-04-05 20:08:48 -07:00
|
|
|
|
|
2024-05-24 01:14:16 +00:00
|
|
|
|
def _pjit_forwarding(jaxpr, out_shardings, out_layouts):
  """Prunes outputs of `jaxpr` that just forward an input.

  An output is pruned only when it has an unspecified sharding and no layout.

  Returns:
    (pruned_jaxpr, in_fwd, kept_out_shardings, kept_out_layouts).
    `in_fwd[i]` is the index of the input forwarded to output `i`, or None
    when output `i` is kept.
  """
  raw_fwd: list[int | None] = pe._jaxpr_forwarding(jaxpr.jaxpr)
  in_fwd: list[int | None] = []
  for fwd, out_s, out_l in zip(raw_fwd, out_shardings, out_layouts):
    unconstrained = is_unspecified(out_s) and out_l is None
    in_fwd.append(fwd if unconstrained else None)
  keep = [fwd is None for fwd in in_fwd]
  pruned = pe.prune_closed_jaxpr_outputs(jaxpr, keep)
  kept_shardings = [s for s, k in zip(out_shardings, keep) if k]
  kept_layouts = [lay for lay, k in zip(out_layouts, keep) if k]
  return pruned, in_fwd, kept_shardings, kept_layouts
|
|
|
|
|
|
|
|
|
|
def pjit_forwarding_rule(eqn):
  """Forwarding rule for pjit equations.

  Returns `(fwd_vars, new_eqn)`. `fwd_vars[i]` is the input var forwarded to
  output `i`, or None. `new_eqn` is `eqn` restricted to the outputs that are
  not forwarded.
  """
  params = eqn.params
  pruned_jaxpr, in_fwd, kept_shardings, kept_layouts = _pjit_forwarding(
      params['jaxpr'], params['out_shardings'], params['out_layouts'])
  kept_outvars = [v for v, fwd in zip(eqn.outvars, in_fwd) if fwd is None]
  new_params = dict(params, jaxpr=pruned_jaxpr,
                    out_shardings=tuple(kept_shardings),
                    out_layouts=tuple(kept_layouts))
  new_eqn = eqn.replace(params=new_params, outvars=kept_outvars)
  fwd_vars = [None if fwd is None else eqn.invars[fwd] for fwd in in_fwd]
  return fwd_vars, new_eqn

pe.forwarding_rules[pjit_p] = pjit_forwarding_rule
|
|
|
|
|
|
|
|
|
|
|
2023-03-22 20:54:45 -07:00
|
|
|
|
# TODO(mattjj): remove/trivialize this when jaxprs have type annotation on them,
|
|
|
|
|
# since it's actually not possible in general to infer the type from the term
|
2023-06-23 15:11:37 -07:00
|
|
|
|
def _out_type(jaxpr: core.ClosedJaxpr) -> list[core.AbstractValue]:
|
2023-03-22 20:54:45 -07:00
|
|
|
|
out = []
|
|
|
|
|
in_idx = {v: i for i, v in enumerate(jaxpr.jaxpr.invars)}
|
|
|
|
|
out_idx = {x: i for i, x in enumerate(jaxpr.jaxpr.invars)
|
|
|
|
|
if type(x) is core.Var}
|
|
|
|
|
for x in jaxpr.jaxpr.outvars:
|
|
|
|
|
aval = x.aval
|
|
|
|
|
if type(aval) is core.DShapedArray:
|
|
|
|
|
shape = [core.InDBIdx(in_idx[d]) if d in in_idx else
|
|
|
|
|
core.OutDBIdx(out_idx[d]) if d in out_idx else
|
|
|
|
|
d for d in x.aval.shape]
|
|
|
|
|
aval = aval.update(shape=tuple(shape))
|
|
|
|
|
out.append(aval)
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
2023-03-21 21:43:20 -07:00
|
|
|
|
def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params):
  """Typechecks a pjit equation as a call of its inner (open) jaxpr."""
  call_params = {**params, 'call_jaxpr': jaxpr.jaxpr}
  return core._check_call(ctx_factory, pjit_p, in_atoms, call_params)

core.custom_typechecks[pjit_p] = _pjit_typecheck
|
|
|
|
|
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
2024-04-05 20:08:48 -07:00
|
|
|
|
def _pjit_abstract_eval(*args, jaxpr, **_):
  """Abstract eval for `pjit_p`: the closed jaxpr's output avals and effects."""
  out_avals = jaxpr.out_avals
  effects = jaxpr.effects
  return out_avals, effects

pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)
|
|
|
|
|
|
|
|
|
|
|
2023-10-30 15:27:17 -07:00
|
|
|
|
def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
                                    out_shardings, in_layouts, out_layouts,
                                    api_name):
  """Returns the MLIR function for `jaxpr`, lowering it at most once per module.

  Results are memoized in `ctx.module_context.cached_primitive_lowerings`. The
  key includes the name, jaxpr, effects, and the device count implied by the
  axis context. It also includes semantically compared in/out shardings, the
  layouts, and `api_name`.
  """
  mod_ctx = ctx.module_context
  axis_ctx = mod_ctx.axis_context
  if isinstance(axis_ctx, sharding_impls.ShardingContext):
    num_devices = axis_ctx.num_devices
  elif isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
    num_devices = axis_ctx.mesh.size
  else:
    num_devices = None

  cache_key = (pjit_p, name, jaxpr, effects, num_devices,
               pxla.SemanticallyEqualShardings(in_shardings, jaxpr.in_avals),
               pxla.SemanticallyEqualShardings(out_shardings, jaxpr.out_avals),
               in_layouts, out_layouts, api_name)
  cached = mod_ctx.cached_primitive_lowerings.get(cache_key)
  if cached is not None:
    return cached

  # TODO(b/228598865): inlined calls cannot have shardings set directly on the
  # inputs or outputs because they are lost during MLIR->HLO conversion.
  # using_sharding_annotation=False means we add an identity operation instead.
  func = mlir.lower_jaxpr_to_fun(
      mod_ctx, name, jaxpr, effects, ctx.name_stack,
      arg_shardings=[None if is_unspecified(s) else s for s in in_shardings],
      result_shardings=[None if is_unspecified(s) else s
                        for s in out_shardings],
      use_sharding_annotations=False, api_name=api_name,
      arg_layouts=in_layouts, result_layouts=out_layouts)
  mod_ctx.cached_primitive_lowerings[cache_key] = func
  return func
|
|
|
|
|
|
|
|
|
|
|
2022-12-16 13:06:38 -08:00
|
|
|
|
def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
                   out_shardings, in_layouts, out_layouts, resource_env,
                   donated_invars, keep_unused, inline):
  """MLIR lowering of `pjit_p` as a `func.call` to a (cached) private function.

  One token per effect in `ctx.tokens_in` is threaded as leading results and
  (after the dimension-variable values) as leading call operands.
  """
  effects = list(ctx.tokens_in.effects())
  num_tokens = len(effects)
  value_types = map(mlir.aval_to_ir_type, ctx.avals_out)
  output_types = [mlir.token_type()] * num_tokens + value_types
  flat_output_types = mlir.flatten_ir_types(output_types)

  api_name = 'jit' if resource_env is None else 'pjit'
  func = _pjit_cached_lower_jaxpr_to_fun(
      ctx, name, jaxpr, tuple(effects), in_shardings,
      out_shardings, in_layouts, out_layouts,
      api_name=api_name)

  token_operands = [ctx.tokens_in.get(eff) for eff in effects]
  call_operands = (*ctx.dim_var_values, *token_operands, *args)
  call = func_dialect.CallOp(flat_output_types,
                             ir.FlatSymbolRefAttr.get(func.name.value),
                             mlir.flatten_ir_values(call_operands))
  mlir.wrap_compute_type_in_place(ctx, call)

  results = mlir.unflatten_ir_values_like_types(call.results, output_types)
  token_results, out_nodes = split_list(results, [num_tokens])
  ctx.set_tokens_out(ctx.tokens_in.update_tokens(
      mlir.TokenSet(zip(effects, token_results))))
  return out_nodes

mlir.register_lowering(pjit_p, _pjit_lowering)
|
|
|
|
|
|
|
|
|
|
|
2024-07-24 19:01:31 -07:00
|
|
|
|
def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type,
                  vals_in, dims_in, jaxpr, in_shardings, out_shardings,
                  in_layouts, out_layouts, resource_env, donated_invars, name,
                  keep_unused, inline):
  """Batching rule for `pjit_p`.

  Batches the inner jaxpr and adjusts each batched input/output sharding with
  `_pjit_batcher_for_sharding`. It then re-binds `pjit_p` on the batched
  jaxpr.

  Raises:
    NotImplementedError: if any input or output layout is not None.
  """
  segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in)
  new_jaxpr, axes_out = batching.batch_jaxpr2(
      jaxpr, axis_size, dims_in, axis_name=axis_name,
      spmd_axis_name=spmd_axis_name, main_type=main_type)

  mesh = None if resource_env is None else resource_env.physical_mesh

  def batch_shardings(shardings, axes, avals):
    # Only values that actually gained a batch axis need a new sharding.
    return tuple(
        s if ax is None else
        _pjit_batcher_for_sharding(s, ax, spmd_axis_name, mesh, av.ndim)
        for ax, s, av in zip(axes, shardings, avals))

  # TODO(axch): prepend with Nones (?) to account for new segment_lens inputs
  in_shardings = batch_shardings(in_shardings, dims_in, new_jaxpr.in_avals)
  out_shardings = batch_shardings(out_shardings, axes_out, new_jaxpr.out_avals)
  # TODO(yashkatariya): Figure out layouts should change under vmap.
  if any(lay is not None for lay in (*in_layouts, *out_layouts)):
    raise NotImplementedError(
        'Concrete layouts are not supported for vmap(jit).')

  vals_out = pjit_p.bind(
      *vals_in,
      jaxpr=new_jaxpr,
      in_shardings=in_shardings,
      out_shardings=out_shardings,
      in_layouts=in_layouts,
      out_layouts=out_layouts,
      resource_env=resource_env,
      donated_invars=donated_invars,
      name=name,
      keep_unused=keep_unused,
      inline=inline)
  return vals_out, batching.resolve_ragged_axes_against_inputs_outputs(
      vals_in, vals_out, axes_out)

batching.spmd_axis_primitive_batchers[pjit_p] = _pjit_batcher
batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, None)
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
|
|
|
|
def _pjit_batcher_for_sharding(
    s: sharding.Sharding | UnspecifiedValue,
    dim: int, spmd_axis_name: tuple[str, ...] | None, mesh, ndim: int):
  """Returns the sharding `s` adjusted for a batch axis inserted at `dim`.

  Args:
    s: sharding of the value before batching, or an unspecified sharding.
    dim: position at which the batch axis is inserted.
    spmd_axis_name: mesh axis name(s) to partition the new batch axis over,
      or None to leave the batch axis unpartitioned.
    mesh: fallback mesh for non-NamedSharding inputs. `_pjit_batcher` passes
      the physical mesh of its `resource_env`, or None.
    ndim: rank passed to `_to_xla_hlo_sharding`. `_pjit_batcher` passes the
      rank of the batched aval.

  Returns:
    `s` unchanged if it is unspecified, or (with `spmd_axis_name=None`) if it
    is fully replicated. Otherwise a new sharding with the extra axis.

  Raises:
    ValueError: if `spmd_axis_name` is given, `s` is not a NamedSharding, and
      no non-empty `mesh` is available.
  """
  if is_unspecified(s):
    return s
  hlo_s = s._to_xla_hlo_sharding(ndim)  # type: ignore
  if spmd_axis_name is None:
    if sharding_impls.is_op_sharding_replicated(hlo_s):
      return s
    if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
      parsed_pspec = s._parsed_pspec.insert_axis_partitions(dim, None)
      return NamedSharding._from_parsed_pspec(s.mesh, parsed_pspec)
    # Insert an unpartitioned (size-1) tile dimension for the batch axis.
    new_op = hlo_s.to_proto().clone()
    tad = list(new_op.tile_assignment_dimensions)
    tad.insert(dim, 1)
    new_op.tile_assignment_dimensions = tad
    new_gs = GSPMDSharding(
        s._device_assignment, new_op,  # type: ignore
        _device_list=getattr(s, '_internal_device_list', None))
    return pxla._get_out_sharding_from_orig_sharding([new_gs], [None], s, None)[0]
  else:
    if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
      parsed_pspec = s._parsed_pspec.insert_axis_partitions(dim, spmd_axis_name)
      return NamedSharding._from_parsed_pspec(s.mesh, parsed_pspec)
    if isinstance(s, NamedSharding):
      mesh = s.mesh
    if mesh is None or mesh.empty:
      raise ValueError(
          'If you are using spmd_axis_name parameter of jax.vmap,'
          ' please make sure to run your jitted function inside the mesh'
          ' context manager. Only `jax.lax.with_sharding_constraint` with'
          ' `jax.sharding.NamedSharding` as an input can be transformed with'
          ' spmd_axis_name batching rules outside of an explicit mesh context'
          f' manager scope. Got sharding: {s!r}')
    parsed_pspec = parse_flatten_op_sharding(hlo_s, mesh)[0]
    parsed_pspec = parsed_pspec.insert_axis_partitions(dim, spmd_axis_name)
    return NamedSharding._from_parsed_pspec(mesh, parsed_pspec)
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _pjit_jvp(primals_in, tangents_in,
              jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
              resource_env, donated_invars, name, keep_unused, inline):
  """JVP rule for `pjit_p`.

  Binds one pjit over the JVP jaxpr. Its inputs are the primals followed by
  the non-zero tangents, and its outputs are the primal outputs followed by
  the non-zero output tangents. Each tangent copies the sharding, layout and
  donation metadata of its primal. Symbolically zero output tangents are
  returned as `ad.Zero`.
  """
  if any(isinstance(c, core.MutableArray) for c in jaxpr.consts):
    # Hoist mutable-array constants into trailing arguments with zero
    # tangents, unspecified sharding, no layout and no donation.
    jaxpr, mut_primals = pxla._move_mutable_consts(jaxpr)
    n_mut = len(mut_primals)
    mut_tangents = map(ad_util.zeros_like_jaxval, mut_primals)
    primals_in = [*primals_in, *mut_primals]
    tangents_in = [*tangents_in, *mut_tangents]
    in_shardings = (*in_shardings,) + (UNSPECIFIED,) * n_mut
    in_layouts = (*in_layouts,) + (None,) * n_mut
    donated_invars = (*donated_invars,) + (False,) * n_mut

  # Inputs whose aval is an AbstractRef always get a zero tangent.
  tangents_in = [ad_util.zeros_like_aval(aval) if isinstance(aval, AbstractRef)
                 else t for t, aval in zip(tangents_in, jaxpr.in_avals)]

  nz_in = [type(t) is not ad.Zero for t in tangents_in]
  jaxpr_jvp, nz_out = ad.jvp_jaxpr(jaxpr, nz_in, instantiate=False)

  def nonzero(mask, seq):
    # Entries of `seq` whose corresponding tangent is not symbolically zero.
    return tuple(x for keep, x in zip(mask, seq) if keep)

  outputs = pjit_p.bind(
      *primals_in, *nonzero(nz_in, tangents_in),
      jaxpr=jaxpr_jvp,
      in_shardings=(*in_shardings, *nonzero(nz_in, in_shardings)),
      out_shardings=(*out_shardings, *nonzero(nz_out, out_shardings)),
      in_layouts=(*in_layouts, *nonzero(nz_in, in_layouts)),
      out_layouts=(*out_layouts, *nonzero(nz_out, out_layouts)),
      resource_env=resource_env,
      donated_invars=(*donated_invars, *nonzero(nz_in, donated_invars)),
      name=name,
      keep_unused=keep_unused,
      inline=inline)

  num_primals = len(jaxpr.jaxpr.outvars)
  primals_out, tangents_out = split_list(outputs, [num_primals])
  assert len(primals_out) == num_primals
  nz_tangents = iter(tangents_out)
  tangents = [next(nz_tangents) if nz else ad.Zero(aval)
              for nz, aval in zip(nz_out, jaxpr.out_avals)]
  return primals_out, tangents

ad.primitive_jvps[pjit_p] = _pjit_jvp
|
|
|
|
|
|
|
|
|
|
|
2023-01-20 18:03:24 -08:00
|
|
|
|
@weakref_lru_cache
|
|
|
|
|
def _known_jaxpr_fwd(known_jaxpr: core.ClosedJaxpr,
|
2024-02-23 10:23:31 -08:00
|
|
|
|
in_fwd: tuple[int | None, ...]) -> core.ClosedJaxpr:
|
2023-01-20 18:03:24 -08:00
|
|
|
|
updated_jaxpr = known_jaxpr.jaxpr.replace(
|
2023-10-12 16:00:08 -07:00
|
|
|
|
outvars=[x for x, i in zip(known_jaxpr.jaxpr.outvars, in_fwd)
|
2023-01-20 18:03:24 -08:00
|
|
|
|
if i is None])
|
|
|
|
|
return known_jaxpr.replace(jaxpr=updated_jaxpr)
|
|
|
|
|
|
|
|
|
|
|
2022-12-16 13:06:38 -08:00
|
|
|
|
def _pjit_partial_eval(trace, *in_tracers,
|
|
|
|
|
jaxpr, in_shardings, out_shardings,
|
2024-04-05 20:08:48 -07:00
|
|
|
|
in_layouts, out_layouts, resource_env, donated_invars,
|
|
|
|
|
name, keep_unused, inline):
|
2022-12-16 13:06:38 -08:00
|
|
|
|
in_pvals = [t.pval for t in in_tracers]
|
|
|
|
|
|
|
|
|
|
known_ins = tuple(pv.is_known() for pv in in_pvals)
|
|
|
|
|
unknown_ins = tuple(not k for k in known_ins)
|
2024-03-15 10:00:27 -07:00
|
|
|
|
known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
|
|
|
|
|
pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False)
|
2022-12-16 13:06:38 -08:00
|
|
|
|
unknown_outs = tuple(unknown_outs)
|
|
|
|
|
known_outs = tuple(not uk for uk in unknown_outs)
|
|
|
|
|
num_residuals = len(res_avals)
|
2023-10-12 16:00:08 -07:00
|
|
|
|
res_shardings = (UNSPECIFIED,) * num_residuals
|
2024-04-05 20:08:48 -07:00
|
|
|
|
res_layouts = (None,) * num_residuals
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
|
|
|
|
def keep_where(l, should_keep):
|
2023-10-12 16:00:08 -07:00
|
|
|
|
return tuple(x for x, keep in zip(l, should_keep) if keep)
|
|
|
|
|
|
2024-03-15 10:00:27 -07:00
|
|
|
|
known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings
|
2024-04-05 20:08:48 -07:00
|
|
|
|
known_out_layouts = keep_where(out_layouts, known_outs) + res_layouts
|
2024-03-15 10:14:57 -07:00
|
|
|
|
|
2024-03-15 12:09:21 -07:00
|
|
|
|
# Input-to-output forwarding: compute which outputs are just forwarded inputs.
|
|
|
|
|
num_out_primals = len(known_jaxpr.out_avals) - num_residuals
|
|
|
|
|
in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
|
|
|
|
|
# Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
|
|
|
|
|
in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
|
2024-04-05 20:08:48 -07:00
|
|
|
|
in_fwd = [
|
|
|
|
|
fwd if is_unspecified(os) and ol is None else None
|
|
|
|
|
for os, ol, fwd in zip(
|
|
|
|
|
keep_where(out_shardings, known_outs),
|
|
|
|
|
keep_where(out_layouts, known_outs), in_fwd_primal)
|
|
|
|
|
] + in_fwd_res
|
2024-03-15 12:09:21 -07:00
|
|
|
|
del in_fwd_primal, in_fwd_res
|
|
|
|
|
# Prune jaxpr outputs and out_shardings by removing the input-forwards.
|
|
|
|
|
keep = [f is None for f in in_fwd]
|
|
|
|
|
known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
|
|
|
|
|
known_out_shardings = keep_where(known_out_shardings, keep)
|
2024-04-05 20:08:48 -07:00
|
|
|
|
known_out_layouts = keep_where(known_out_layouts, keep)
|
2024-03-15 12:09:21 -07:00
|
|
|
|
# Update num_out_primals to reflect pruning.
|
|
|
|
|
kept_primals, kept_res = split_list(keep, [num_out_primals])
|
|
|
|
|
num_out_primals = sum(kept_primals)
|
|
|
|
|
del keep, kept_primals, kept_res
|
|
|
|
|
|
|
|
|
|
# Output-to-output forwarding: compute which residuals are just primal outputs
|
|
|
|
|
out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals])
|
|
|
|
|
idx_map = {id(v): i for i, v in enumerate(out_vars)}
|
|
|
|
|
out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars]
|
|
|
|
|
# Prune jaxpr outputs and out_shardings by removing forwarded residuals.
|
|
|
|
|
keep = [f is None for f in out_fwd]
|
|
|
|
|
known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
|
|
|
|
|
known_out_shardings = keep_where(known_out_shardings, keep)
|
2024-04-05 20:08:48 -07:00
|
|
|
|
known_out_layouts = keep_where(known_out_layouts, keep)
|
2024-03-15 12:09:21 -07:00
|
|
|
|
del keep
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
|
|
|
|
known_params = dict(
|
2023-10-12 16:00:08 -07:00
|
|
|
|
jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins),
|
2024-04-05 20:08:48 -07:00
|
|
|
|
out_shardings=known_out_shardings,
|
|
|
|
|
in_layouts=keep_where(in_layouts, known_ins),
|
|
|
|
|
out_layouts=known_out_layouts, resource_env=resource_env,
|
2022-12-16 13:06:38 -08:00
|
|
|
|
donated_invars=keep_where(donated_invars, known_ins),
|
2023-10-12 16:00:08 -07:00
|
|
|
|
name=name, keep_unused=keep_unused, inline=inline)
|
2023-01-20 18:03:24 -08:00
|
|
|
|
assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals)
|
2024-04-05 20:08:48 -07:00
|
|
|
|
assert len(known_params['out_layouts']) == len(known_params['jaxpr'].out_avals)
|
2023-01-20 18:03:24 -08:00
|
|
|
|
|
|
|
|
|
# Bind known things to pjit_p.
|
|
|
|
|
known_inputs = [pv.get_known() for pv in in_pvals if pv.is_known()]
|
|
|
|
|
all_known_outs = pjit_p.bind(*known_inputs, **known_params)
|
2024-03-15 12:09:21 -07:00
|
|
|
|
# Add back in the output fwds.
|
|
|
|
|
all_known_outs = subs_list(out_fwd, all_known_outs, all_known_outs)
|
|
|
|
|
# Add back in the input fwds.
|
|
|
|
|
all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs)
|
2023-01-20 18:03:24 -08:00
|
|
|
|
|
2023-10-12 16:00:08 -07:00
|
|
|
|
known_out_vals, residual_vals = \
|
|
|
|
|
split_list(all_known_outs, [len(all_known_outs) - num_residuals])
|
2023-10-19 00:38:19 -07:00
|
|
|
|
residual_tracers = map(trace.new_instantiated_const, residual_vals)
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
2024-03-15 10:00:27 -07:00
|
|
|
|
# The convention of partial_eval_jaxpr_nounits is to place residual binders at
|
|
|
|
|
# the front of the jaxpr produced, so we move them to the back since both the
|
|
|
|
|
# jaxpr equation built below and the pjit transpose rule assume a
|
2022-12-16 13:06:38 -08:00
|
|
|
|
# residual-inputs-last convention.
|
|
|
|
|
unknown_jaxpr = pe.move_binders_to_back(
|
|
|
|
|
unknown_jaxpr, [True] * num_residuals + [False] * sum(unknown_ins))
|
|
|
|
|
# Prepare unknown tracers
|
|
|
|
|
unknown_params = dict(
|
|
|
|
|
jaxpr=unknown_jaxpr,
|
2023-10-12 16:00:08 -07:00
|
|
|
|
in_shardings=(keep_where(in_shardings, unknown_ins) + res_shardings),
|
2022-12-16 13:06:38 -08:00
|
|
|
|
out_shardings=keep_where(out_shardings, unknown_outs),
|
2024-04-05 20:08:48 -07:00
|
|
|
|
in_layouts=(keep_where(in_layouts, unknown_ins) + res_layouts),
|
|
|
|
|
out_layouts=keep_where(out_layouts, unknown_outs),
|
2022-12-16 13:06:38 -08:00
|
|
|
|
resource_env=resource_env,
|
|
|
|
|
donated_invars=(keep_where(donated_invars, unknown_ins) +
|
|
|
|
|
(False,) * num_residuals),
|
|
|
|
|
name=name,
|
|
|
|
|
keep_unused=keep_unused,
|
|
|
|
|
inline=inline)
|
|
|
|
|
unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()]
|
2023-03-15 17:08:21 -07:00
|
|
|
|
unknown_out_avals = unknown_jaxpr.out_avals
|
2022-12-16 13:06:38 -08:00
|
|
|
|
unknown_tracers_out = [
|
|
|
|
|
pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
|
2023-01-12 17:23:55 -08:00
|
|
|
|
for aval in unknown_out_avals
|
2022-12-16 13:06:38 -08:00
|
|
|
|
]
|
|
|
|
|
eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers),
|
|
|
|
|
unknown_tracers_out,
|
|
|
|
|
pjit_p,
|
|
|
|
|
unknown_params,
|
|
|
|
|
unknown_jaxpr.effects,
|
2024-05-22 15:16:07 -07:00
|
|
|
|
source_info_util.current())
|
2022-12-16 13:06:38 -08:00
|
|
|
|
for t in unknown_tracers_out: t.recipe = eqn
|
|
|
|
|
return merge_lists(unknown_outs, known_out_vals, unknown_tracers_out)
|
2023-01-13 10:15:30 -08:00
|
|
|
|
|
2022-12-16 13:06:38 -08:00
|
|
|
|
pe.custom_partial_eval_rules[pjit_p] = _pjit_partial_eval
|
|
|
|
|
|
|
|
|
|
|
2023-01-13 10:15:30 -08:00
|
|
|
|
def _pjit_partial_eval_custom_params_updater(
|
|
|
|
|
unks_in: Sequence[bool], inst_in: Sequence[bool],
|
|
|
|
|
kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
|
2023-10-12 16:00:08 -07:00
|
|
|
|
num_res_out: int, num_res_in: int, params_known: dict, params_staged: dict
|
2023-06-23 15:11:37 -07:00
|
|
|
|
) -> tuple[dict, dict]:
|
2023-01-13 10:15:30 -08:00
|
|
|
|
# prune inputs to jaxpr_known according to unks_in
|
|
|
|
|
donated_invars_known, _ = pe.partition_list(unks_in, params_known['donated_invars'])
|
|
|
|
|
in_shardings_known, _ = pe.partition_list(unks_in, params_known['in_shardings'])
|
|
|
|
|
_, out_shardings_known = pe.partition_list(kept_outs_known, params_known['out_shardings'])
|
2024-04-05 20:08:48 -07:00
|
|
|
|
in_layouts_known, _ = pe.partition_list(unks_in, params_known['in_layouts'])
|
|
|
|
|
_, out_layouts_known = pe.partition_list(kept_outs_known, params_known['out_layouts'])
|
|
|
|
|
|
2023-01-13 10:15:30 -08:00
|
|
|
|
new_params_known = dict(params_known,
|
|
|
|
|
in_shardings=tuple(in_shardings_known),
|
2023-10-12 16:00:08 -07:00
|
|
|
|
out_shardings=(*out_shardings_known,
|
|
|
|
|
*[UNSPECIFIED] * num_res_out),
|
2024-04-05 20:08:48 -07:00
|
|
|
|
in_layouts=tuple(in_layouts_known),
|
|
|
|
|
out_layouts=(*out_layouts_known, *[None] * num_res_out),
|
2023-03-17 11:50:59 -07:00
|
|
|
|
donated_invars=tuple(donated_invars_known))
|
2023-01-13 10:15:30 -08:00
|
|
|
|
assert len(new_params_known['in_shardings']) == len(params_known['jaxpr'].in_avals)
|
|
|
|
|
assert len(new_params_known['out_shardings']) == len(params_known['jaxpr'].out_avals)
|
2024-04-05 20:08:48 -07:00
|
|
|
|
assert len(new_params_known['in_layouts']) == len(params_known['jaxpr'].in_avals)
|
|
|
|
|
assert len(new_params_known['out_layouts']) == len(params_known['jaxpr'].out_avals)
|
2023-01-13 10:15:30 -08:00
|
|
|
|
|
|
|
|
|
# added num_res new inputs to jaxpr_staged, and pruning according to inst_in
|
|
|
|
|
_, donated_invars_staged = pe.partition_list(inst_in, params_staged['donated_invars'])
|
2023-10-12 16:00:08 -07:00
|
|
|
|
donated_invars_staged = [False] * num_res_in + donated_invars_staged
|
2023-01-13 10:15:30 -08:00
|
|
|
|
_, in_shardings_staged = pe.partition_list(inst_in, params_staged['in_shardings'])
|
2023-10-12 16:00:08 -07:00
|
|
|
|
in_shardings_staged = [*[UNSPECIFIED] * num_res_in, *in_shardings_staged]
|
2023-01-13 10:15:30 -08:00
|
|
|
|
_, out_shardings_staged = pe.partition_list(kept_outs_staged, params_staged['out_shardings'])
|
2024-04-05 20:08:48 -07:00
|
|
|
|
_, in_layouts_staged = pe.partition_list(inst_in, params_staged['in_layouts'])
|
|
|
|
|
in_layouts_staged = [*[None] * num_res_in, *in_layouts_staged]
|
|
|
|
|
_, out_layouts_staged = pe.partition_list(kept_outs_staged, params_staged['out_layouts'])
|
2023-01-14 10:18:28 -08:00
|
|
|
|
|
2023-01-13 10:15:30 -08:00
|
|
|
|
new_params_staged = dict(params_staged,
|
|
|
|
|
in_shardings=tuple(in_shardings_staged),
|
|
|
|
|
out_shardings=tuple(out_shardings_staged),
|
2024-04-05 20:08:48 -07:00
|
|
|
|
in_layouts=tuple(in_layouts_staged),
|
|
|
|
|
out_layouts=tuple(out_layouts_staged),
|
2023-03-17 11:50:59 -07:00
|
|
|
|
donated_invars=tuple(donated_invars_staged))
|
2023-01-13 10:15:30 -08:00
|
|
|
|
assert len(new_params_staged['in_shardings']) == len(params_staged['jaxpr'].in_avals)
|
|
|
|
|
assert len(new_params_staged['out_shardings']) == len(params_staged['jaxpr'].out_avals)
|
2024-04-05 20:08:48 -07:00
|
|
|
|
assert len(new_params_staged['in_layouts']) == len(params_staged['jaxpr'].in_avals)
|
|
|
|
|
assert len(new_params_staged['out_layouts']) == len(params_staged['jaxpr'].out_avals)
|
2023-01-13 10:15:30 -08:00
|
|
|
|
return new_params_known, new_params_staged
|
|
|
|
|
|
|
|
|
|
pe.partial_eval_jaxpr_custom_rules[pjit_p] = \
|
|
|
|
|
partial(pe.closed_call_partial_eval_custom_rule, 'jaxpr',
|
|
|
|
|
_pjit_partial_eval_custom_params_updater)
|
|
|
|
|
|
|
|
|
|
|
2023-01-24 09:57:55 -08:00
|
|
|
|
@lu.cache
|
make mlir arg and result names work with static_argnums/argnames
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc). That unification will come in a follow up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)
Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).
This commit includes the changes from PR #15079, so that PR should be merged first.
Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
handle static_argnums or static_argnames correctly. Instead it would fail,
resulting in debug info being dropped from the jaxpr and ultimately the MLIR
computation (but no Exception raised). We need to handle
static_argnums/argnames because while the corresponding parameters remain on
the Python callable signature, they are excluded from the args/kwargs
pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
when we still have the original args/kwargs in hand, i.e. much earlier than
the previous mechanism. We then just have to pass this debug info to the
right places. Indeed we often already had to work out some debug-related
information at these call sites (e.g. whether the function is being staged
out for jit, or scan, or whatever), so after this change we're working out
all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
unflatten user pytree defs with dummy objects (to reconstruct dummy
args/kwargs trees so that we can call inspect.signature(fun).bind), since we
just use the original args/kwargs instead. Since some user pytree node types
are not fully polymorphic in their element types (e.g. their __init__ methods
sometimes contained assertions about their elements' shapes, expecting them
to be arrays), that means the new mechanism is fundamentally more compatible
with custom pytree node types.
More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
which in addition to the more precise name has fields like
`arg_names: Tuple[Optional[str], ...]` and
`result_paths: Tuple[Optional[str], ...]`, rather than
`in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
actual debug info more eagerly than before and we don't need pytrees for
dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
debug info about inputs which we have available at tracing time; in a
follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
delete `partial_eval.DebugInfo` and its corresponding helper methods (not
done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
`core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
elements from the `arg_names` field), maintaining now-checked invariants like
a Jaxpr's `debug_info` should have the same number of argument names as the
jaxpr has invars (the jaxpr-processing functions updated here are enough for
top-level jit jaxprs to have debug info attached, handling the original
intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
debug info on their outputs);
* add some tests for static_argnums/static_argnames.
Phew! Can't wait to land those follow-ups too :P
2023-03-17 17:45:41 -07:00
|
|
|
|
def _pjit_transpose_trace(fun, in_avals):
|
2024-01-29 15:18:33 -08:00
|
|
|
|
transpose_jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
|
|
|
|
|
fun, in_avals)
|
2023-01-24 09:57:55 -08:00
|
|
|
|
transpose_jaxpr = core.ClosedJaxpr(transpose_jaxpr, consts)
|
2024-01-29 15:18:33 -08:00
|
|
|
|
return transpose_jaxpr, attrs_tracked
|
2023-01-24 09:57:55 -08:00
|
|
|
|
|
|
|
|
|
|
2024-02-24 16:11:41 -08:00
|
|
|
|
def _pjit_transpose(cts_in, *primals_in,
|
2024-04-05 20:08:48 -07:00
|
|
|
|
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
|
2023-03-17 11:50:59 -07:00
|
|
|
|
resource_env, donated_invars, name, keep_unused, inline):
|
2022-12-16 13:06:38 -08:00
|
|
|
|
def prune_type(ty, xs, maybe_zeros):
|
|
|
|
|
return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty)
|
|
|
|
|
|
|
|
|
|
body = lu.wrap_init(ad.closed_backward_pass)
|
2024-02-24 16:11:41 -08:00
|
|
|
|
body = lu.hashable_partial(body, jaxpr, False)
|
2022-12-16 13:06:38 -08:00
|
|
|
|
primals_and_nz_cts_in, in_treedef = tree_flatten((primals_in, cts_in))
|
|
|
|
|
body, cts_out_treedef_thunk = flatten_fun_nokwargs(body, in_treedef)
|
|
|
|
|
|
|
|
|
|
transpose_in_shardings = (
|
|
|
|
|
*prune_type(ad.UndefinedPrimal, in_shardings, primals_in),
|
|
|
|
|
*prune_type(ad.Zero, out_shardings, cts_in)
|
|
|
|
|
)
|
2024-04-05 20:08:48 -07:00
|
|
|
|
transpose_in_layouts = (
|
|
|
|
|
*prune_type(ad.UndefinedPrimal, in_layouts, primals_in),
|
|
|
|
|
*prune_type(ad.Zero, out_layouts, cts_in)
|
|
|
|
|
)
|
2023-01-24 09:57:55 -08:00
|
|
|
|
global_cts_in_avals = tuple(core.raise_to_shaped(core.get_aval(ct))
|
|
|
|
|
for ct in primals_and_nz_cts_in)
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
2024-01-29 15:18:33 -08:00
|
|
|
|
transpose_jaxpr, attrs_tracked = _pjit_transpose_trace(
|
|
|
|
|
body, global_cts_in_avals)
|
2022-12-16 13:06:38 -08:00
|
|
|
|
cts_out_treedef = cts_out_treedef_thunk()
|
|
|
|
|
transpose_out_shardings = prune_type(
|
|
|
|
|
ad.Zero,
|
|
|
|
|
in_shardings,
|
|
|
|
|
tree_unflatten(cts_out_treedef, [object()] * cts_out_treedef.num_leaves))
|
2024-04-05 20:08:48 -07:00
|
|
|
|
transpose_out_layouts = prune_type(
|
|
|
|
|
ad.Zero,
|
|
|
|
|
in_layouts,
|
|
|
|
|
tree_unflatten(cts_out_treedef, [object()] * cts_out_treedef.num_leaves))
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
2024-01-29 15:18:33 -08:00
|
|
|
|
if attrs_tracked:
|
|
|
|
|
init_states = _get_states(attrs_tracked)
|
|
|
|
|
primals_and_nz_cts_in = [*init_states, *primals_and_nz_cts_in]
|
|
|
|
|
transpose_in_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_in_shardings
|
|
|
|
|
transpose_out_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_out_shardings
|
2024-04-05 20:08:48 -07:00
|
|
|
|
transpose_in_layouts = (None,) * len(attrs_tracked) + transpose_in_layouts
|
|
|
|
|
transpose_out_layouts = (None,) * len(attrs_tracked) + transpose_out_layouts
|
2024-01-29 15:18:33 -08:00
|
|
|
|
|
2022-12-16 13:06:38 -08:00
|
|
|
|
nz_cts_out = pjit_p.bind(
|
|
|
|
|
*primals_and_nz_cts_in,
|
|
|
|
|
jaxpr=transpose_jaxpr,
|
|
|
|
|
in_shardings=transpose_in_shardings,
|
|
|
|
|
out_shardings=transpose_out_shardings,
|
2024-04-05 20:08:48 -07:00
|
|
|
|
in_layouts=transpose_in_layouts,
|
|
|
|
|
out_layouts=transpose_out_layouts,
|
2022-12-16 13:06:38 -08:00
|
|
|
|
resource_env=resource_env,
|
|
|
|
|
donated_invars=(False,) * len(primals_and_nz_cts_in),
|
|
|
|
|
name=name,
|
|
|
|
|
keep_unused=keep_unused,
|
|
|
|
|
inline=inline)
|
2024-04-05 20:08:48 -07:00
|
|
|
|
|
2024-01-29 15:18:33 -08:00
|
|
|
|
if attrs_tracked:
|
|
|
|
|
final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)])
|
|
|
|
|
_set_states(attrs_tracked, final_states)
|
2024-04-05 20:08:48 -07:00
|
|
|
|
|
2022-12-16 13:06:38 -08:00
|
|
|
|
return tree_unflatten(cts_out_treedef, nz_cts_out)
|
|
|
|
|
ad.reducing_transposes[pjit_p] = _pjit_transpose
|
|
|
|
|
|
|
|
|
|
|
2023-01-23 17:31:33 -08:00
|
|
|
|
@weakref_lru_cache
|
|
|
|
|
def _dce_jaxpr_pjit(
|
2024-02-23 10:23:31 -08:00
|
|
|
|
jaxpr: core.ClosedJaxpr, used_outputs: tuple[bool, ...]
|
2023-06-23 15:11:37 -07:00
|
|
|
|
) -> tuple[core.ClosedJaxpr, list[bool]]:
|
2023-01-23 17:31:33 -08:00
|
|
|
|
new_jaxpr, used_inputs = pe.dce_jaxpr(jaxpr.jaxpr, used_outputs)
|
|
|
|
|
return core.ClosedJaxpr(new_jaxpr, jaxpr.consts), used_inputs
|
|
|
|
|
|
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
|
def dce_jaxpr_pjit_rule(used_outputs: list[bool], eqn: core.JaxprEqn
|
2023-12-11 13:59:29 +00:00
|
|
|
|
) -> tuple[list[bool], core.JaxprEqn | None]:
|
2023-01-23 17:31:33 -08:00
|
|
|
|
dced_jaxpr, used_inputs = _dce_jaxpr_pjit(
|
|
|
|
|
eqn.params['jaxpr'], tuple(used_outputs))
|
|
|
|
|
|
|
|
|
|
def keep_where(xs, keeps):
|
2023-03-22 20:54:45 -07:00
|
|
|
|
return tuple(x for x, keep in zip(xs, keeps) if keep)
|
2023-01-23 17:31:33 -08:00
|
|
|
|
|
|
|
|
|
eqn_params = eqn.params
|
|
|
|
|
new_params = dict(
|
|
|
|
|
eqn_params,
|
|
|
|
|
jaxpr=dced_jaxpr,
|
|
|
|
|
in_shardings=keep_where(eqn_params["in_shardings"], used_inputs),
|
|
|
|
|
out_shardings=keep_where(eqn_params["out_shardings"], used_outputs),
|
2024-04-05 20:08:48 -07:00
|
|
|
|
in_layouts=keep_where(eqn_params["in_layouts"], used_inputs),
|
|
|
|
|
out_layouts=keep_where(eqn_params["out_layouts"], used_outputs),
|
2023-01-23 17:31:33 -08:00
|
|
|
|
donated_invars=keep_where(eqn_params["donated_invars"], used_inputs),
|
|
|
|
|
)
|
|
|
|
|
if not any(used_inputs) and not any(used_outputs) and not dced_jaxpr.effects:
|
|
|
|
|
return used_inputs, None
|
|
|
|
|
else:
|
|
|
|
|
new_eqn = core.new_jaxpr_eqn(
|
|
|
|
|
[v for v, used in zip(eqn.invars, used_inputs) if used],
|
|
|
|
|
[v for v, used in zip(eqn.outvars, used_outputs) if used],
|
2024-05-17 15:58:25 -07:00
|
|
|
|
eqn.primitive, new_params, dced_jaxpr.effects, eqn.source_info, eqn.ctx)
|
2023-01-23 17:31:33 -08:00
|
|
|
|
return used_inputs, new_eqn
|
|
|
|
|
|
|
|
|
|
pe.dce_rules[pjit_p] = dce_jaxpr_pjit_rule
|
|
|
|
|
|
|
|
|
|
|
2023-02-09 11:02:24 -08:00
|
|
|
|
def _pjit_pp_rule(eqn, context, settings):
|
|
|
|
|
params = dict(eqn.params)
|
|
|
|
|
del params['inline']
|
|
|
|
|
if not any(params['donated_invars']):
|
|
|
|
|
del params['donated_invars']
|
2023-04-10 10:15:08 -07:00
|
|
|
|
if all(is_unspecified(s) for s in params['in_shardings']):
|
2023-02-09 11:02:24 -08:00
|
|
|
|
del params['in_shardings']
|
2023-04-10 10:15:08 -07:00
|
|
|
|
if all(is_unspecified(s) for s in params['out_shardings']):
|
2023-02-09 11:02:24 -08:00
|
|
|
|
del params['out_shardings']
|
2024-04-05 20:08:48 -07:00
|
|
|
|
if all(l is None for l in params['in_layouts']):
|
|
|
|
|
del params['in_layouts']
|
|
|
|
|
if all(l is None for l in params['out_layouts']):
|
|
|
|
|
del params['out_layouts']
|
2023-02-09 11:02:24 -08:00
|
|
|
|
if not params['keep_unused']:
|
|
|
|
|
del params['keep_unused']
|
|
|
|
|
if (params['resource_env'] is None or
|
|
|
|
|
params['resource_env'].physical_mesh.empty):
|
|
|
|
|
del params['resource_env']
|
2023-12-07 15:56:56 +00:00
|
|
|
|
|
|
|
|
|
# Move name= to the front to make the resulting equation easier to scan.
|
|
|
|
|
del params["name"]
|
|
|
|
|
return core._pp_eqn(eqn, context, settings, params=["name"] + sorted(params))
|
|
|
|
|
|
2023-02-09 11:02:24 -08:00
|
|
|
|
core.pp_eqn_rules[pjit_p] = _pjit_pp_rule
|
|
|
|
|
|
|
|
|
|
|
2023-10-04 12:57:17 -07:00
|
|
|
|
def _pjit_state_discharge_rule(
|
2024-04-05 20:08:48 -07:00
|
|
|
|
in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings,
|
|
|
|
|
in_layouts, out_layouts, **params):
|
2023-10-04 12:57:17 -07:00
|
|
|
|
if not (all(map(is_unspecified, in_shardings)) and
|
2024-04-05 20:08:48 -07:00
|
|
|
|
all(map(is_unspecified, out_shardings))):
|
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
if not (all(l is None for l in in_layouts) and
|
|
|
|
|
all(l is None for l in out_layouts)):
|
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
2023-10-04 12:57:17 -07:00
|
|
|
|
jaxpr, consts = jaxpr.jaxpr, jaxpr.consts
|
|
|
|
|
num_outs = len(jaxpr.outvars)
|
|
|
|
|
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, consts)
|
|
|
|
|
discharged_closed_jaxpr = core.ClosedJaxpr(discharged_jaxpr, discharged_consts)
|
|
|
|
|
new_in_shardings = (UnspecifiedValue(),) * len(discharged_jaxpr.invars)
|
|
|
|
|
new_out_shardings = (UnspecifiedValue(),) * len(discharged_jaxpr.outvars)
|
2024-04-05 20:08:48 -07:00
|
|
|
|
new_in_layouts = (None,) * len(discharged_jaxpr.invars)
|
|
|
|
|
new_out_layouts = (None,) * len(discharged_jaxpr.outvars)
|
2023-10-04 12:57:17 -07:00
|
|
|
|
out_and_ref_vals = pjit_p.bind(
|
|
|
|
|
*args, jaxpr=discharged_closed_jaxpr, in_shardings=new_in_shardings,
|
2024-04-05 20:08:48 -07:00
|
|
|
|
out_shardings=new_out_shardings, in_layouts=new_in_layouts,
|
|
|
|
|
out_layouts=new_out_layouts, **params)
|
2023-10-04 12:57:17 -07:00
|
|
|
|
out_vals, ref_vals = split_list(out_and_ref_vals, [num_outs])
|
|
|
|
|
ref_vals_iter = iter(ref_vals)
|
2024-04-04 14:33:06 -04:00
|
|
|
|
new_invals = tuple(next(ref_vals_iter) if isinstance(aval, AbstractRef)
|
2023-10-04 12:57:17 -07:00
|
|
|
|
else None for aval in in_avals)
|
|
|
|
|
sentinel = object()
|
|
|
|
|
assert next(ref_vals_iter, sentinel) is sentinel
|
|
|
|
|
return new_invals, out_vals
|
|
|
|
|
state_discharge.register_discharge_rule(pjit_p)(_pjit_state_discharge_rule)
|
|
|
|
|
|
|
|
|
|
|
2022-12-16 13:06:38 -08:00
|
|
|
|
# -------------------- with_sharding_constraint --------------------
|
|
|
|
|
|
2023-05-26 12:34:32 -07:00
|
|
|
|
def with_sharding_constraint(x, shardings):
|
2023-04-26 10:19:04 -07:00
|
|
|
|
"""Mechanism to constrain the sharding of an Array inside a jitted computation
|
|
|
|
|
|
|
|
|
|
This is a strict constraint for the GSPMD partitioner and not a hint. For examples
|
|
|
|
|
of how to use this function, see `Distributed arrays and automatic parallelization`_.
|
|
|
|
|
|
|
|
|
|
Args:
|
2023-09-22 14:54:31 -07:00
|
|
|
|
x: PyTree of jax.Arrays which will have their shardings constrained
|
2023-04-26 10:19:04 -07:00
|
|
|
|
shardings: PyTree of sharding specifications. Valid values are the same as for
|
|
|
|
|
the ``in_shardings`` argument of :func:`jax.experimental.pjit`.
|
|
|
|
|
Returns:
|
|
|
|
|
x_with_shardings: PyTree of jax.Arrays with specified sharding constraints.
|
|
|
|
|
|
|
|
|
|
.. _Distributed arrays and automatic parallelization: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
|
|
|
|
|
"""
|
2022-12-16 13:06:38 -08:00
|
|
|
|
x_flat, tree = tree_flatten(x)
|
2024-06-27 16:46:44 -07:00
|
|
|
|
|
|
|
|
|
layouts, shardings = _split_layout_and_sharding(shardings)
|
|
|
|
|
|
2024-04-03 22:38:45 -07:00
|
|
|
|
user_shardings = prepare_axis_resources(
|
2023-05-26 12:34:32 -07:00
|
|
|
|
shardings, "shardings", allow_unconstrained_dims=True)
|
|
|
|
|
del shardings
|
2023-02-13 10:53:21 -08:00
|
|
|
|
|
|
|
|
|
user_shardings_flat = tuple(
|
|
|
|
|
flatten_axes("with_sharding_constraint shardings", tree, user_shardings))
|
|
|
|
|
del user_shardings
|
|
|
|
|
|
2024-06-27 16:46:44 -07:00
|
|
|
|
user_layouts_flat = tuple(
|
|
|
|
|
flatten_axes("with_sharding_constraint layouts", tree, layouts))
|
|
|
|
|
del layouts
|
|
|
|
|
|
2023-04-04 11:41:00 -07:00
|
|
|
|
resource_env = mesh_lib.thread_resources.env
|
2022-12-16 13:06:38 -08:00
|
|
|
|
mesh = resource_env.physical_mesh
|
|
|
|
|
|
2023-05-03 19:28:54 -07:00
|
|
|
|
shardings_flat = [_create_sharding_for_array(mesh, a, 'shardings',
|
|
|
|
|
'with_sharding_constraint')
|
2023-03-15 17:08:21 -07:00
|
|
|
|
for a in user_shardings_flat]
|
2024-07-25 04:20:09 -07:00
|
|
|
|
# TODO(bartchr): remove `unconstrained_dims` after migrating to Shardy. It's
|
|
|
|
|
# already part of the shardings.
|
2023-03-15 17:08:21 -07:00
|
|
|
|
unconstrained_dims = [get_unconstrained_dims(s)
|
|
|
|
|
if isinstance(s, NamedSharding) else {}
|
|
|
|
|
for s in shardings_flat]
|
2023-02-13 10:53:21 -08:00
|
|
|
|
del user_shardings_flat
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
2023-05-04 21:49:28 -07:00
|
|
|
|
pjit_check_aval_sharding(
|
|
|
|
|
shardings_flat, x_flat, None, "with_sharding_constraint arguments",
|
|
|
|
|
allow_uneven_sharding=True)
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
2024-07-12 08:09:54 -07:00
|
|
|
|
check_aval_layout_compatibility(user_layouts_flat, x_flat, None,
|
|
|
|
|
"with_sharding_constraint arguments")
|
|
|
|
|
|
2024-06-27 16:46:44 -07:00
|
|
|
|
outs = [sharding_constraint_p.bind(xf, sharding=s, layout=l,
|
2022-12-16 13:06:38 -08:00
|
|
|
|
resource_env=resource_env,
|
|
|
|
|
unconstrained_dims=ud)
|
2024-06-27 16:46:44 -07:00
|
|
|
|
for xf, s, l, ud in zip(x_flat, shardings_flat, user_layouts_flat,
|
|
|
|
|
unconstrained_dims)]
|
2022-12-16 13:06:38 -08:00
|
|
|
|
return tree_unflatten(tree, outs)
|
|
|
|
|
|
2023-05-17 11:49:31 -07:00
|
|
|
|
def _identity_fn(x): return x
|
|
|
|
|
|
2024-06-27 16:46:44 -07:00
|
|
|
|
def _sharding_constraint_impl(x, sharding, layout, resource_env,
|
|
|
|
|
unconstrained_dims):
|
Introduce `jax.sharding.AbstractMesh(shape_tuple: tuple[tuple[str, int], ...])` and allow `with_sharding_constraint` and `shard_map` to accept an abstract mesh as input (`with_sharding_constraint` is via `NamedSharding(abstract_mesh, pspec)`).
**Semantics**
Inside jit, we don't need to talk about concrete devices ever so the semantics stay the same as today i.e. we can lower a NamedSharding with abstract mesh with only mesh axis names and sizes and PartitionSpec. The only restriction is that the number of devices need to be consistent throughout the program when we are tracing.
During compilation, the order of devices throughout the program needs to be consistent (same as before this change).
Outside jit i.e. eager mode, if a `shard_map` or `with_sharding_constraint` contains AbstractMesh, then the input to those primitives should contain a concrete Mesh with the same shape and names as the abstract mesh.
**Why do this?**
There are cases, where you want the change the devices in the mesh but keep the mesh shape the same (axis names and axis sizes). But this leads to a device mismatch error if you have `with_sharding_constraint` or `shard_map` in your computation because they embed concrete devices in their signature.
So to fix the error, you need to change the mesh in `wsc` and `shmap` which will lead to a tracing cache miss (because function id is now different) and consequently a lowering to stableHLO cache miss. Explaining via an example:
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
@jax.jit
def f(x):
y = with_sharding_constraint(x, NamedSharding(mesh1, P('x')))
return y * 2
f(arr_mesh1)
f(arr_mesh2) # DEVICE MISMATCH ERROR!
```
The same problem exists for `shard_map` since it takes a mesh with concrete devices in it's signature.
**Okay, so how do you fix this?**
As mentioned above, we need the above program to work and get tracing and lowering cache hits (**cache hits is the most important** part here)
The approach in this change, allows `with_sharding_constraint` to accept a `NamedSharding(abstract_mesh, pspec)` as input. This leads to no errors downstream and we get tracing and lowering cache hits since we don't encode the concrete devices anymore. Just the axis_names and axis_size of the mesh.
**The important part is that the concrete device information should only come from the arguments. Inside `jax.jit`, you should never reference concrete devices ever.**
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(arr_mesh1.shape_tuple)
@jax.jit
def f(x):
y = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
return y * 2
f(arr_mesh1)
f(arr_mesh2) # tracing and lowering cache hit
```
**One caveat is that this only works with `jax.NamedSharding` but that's fine because `NamedSharding` is the most used `Sharding` in JAX.**
**What about `shard_map`?**
shard_map's signature will be: `shmap(f, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs)`.
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(mesh1.shape_tuple)
@jax.jit
def f(x):
y = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), out_specs=P('x'))(x)
return y * 2
f(arr_mesh1)
f(arr_mesh2) # tracing and lowering cache hit
```
This is a fully backward-compatible change. So your current code will continue to work as is, but you can opt into this new behavior and get all the benefits!
PiperOrigin-RevId: 662670932
2024-08-13 15:17:30 -07:00
|
|
|
|
if (isinstance(sharding, NamedSharding) and
|
|
|
|
|
isinstance(sharding.mesh, AbstractMesh)):
|
2024-08-27 13:30:12 -07:00
|
|
|
|
aval = shaped_abstractify(x)
|
Introduce `jax.sharding.AbstractMesh(shape_tuple: tuple[tuple[str, int], ...])` and allow `with_sharding_constraint` and `shard_map` to accept an abstract mesh as input (`with_sharding_constraint` is via `NamedSharding(abstract_mesh, pspec)`).
**Semantics**
Inside jit, we never need to refer to concrete devices, so the semantics stay the same as today, i.e. we can lower a NamedSharding with an abstract mesh using only the mesh axis names and sizes and the PartitionSpec. The only restriction is that the number of devices needs to be consistent throughout the program when we are tracing.
During compilation, the order of devices throughout the program needs to be consistent (same as before this change).
Outside jit i.e. eager mode, if a `shard_map` or `with_sharding_constraint` contains AbstractMesh, then the input to those primitives should contain a concrete Mesh with the same shape and names as the abstract mesh.
**Why do this?**
There are cases where you want to change the devices in the mesh but keep the mesh shape the same (axis names and axis sizes). But this leads to a device mismatch error if you have `with_sharding_constraint` or `shard_map` in your computation because they embed concrete devices in their signature.
So to fix the error, you need to change the mesh in `wsc` and `shmap` which will lead to a tracing cache miss (because function id is now different) and consequently a lowering to stableHLO cache miss. Explaining via an example:
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
@jax.jit
def f(x):
y = with_sharding_constraint(x, NamedSharding(mesh1, P('x')))
return y * 2
f(arr_mesh1)
f(arr_mesh2) # DEVICE MISMATCH ERROR!
```
The same problem exists for `shard_map` since it takes a mesh with concrete devices in its signature.
**Okay, so how do you fix this?**
As mentioned above, we need the above program to work and get tracing and lowering cache hits (**cache hits are the most important** part here).
The approach in this change allows `with_sharding_constraint` to accept a `NamedSharding(abstract_mesh, pspec)` as input. This leads to no errors downstream, and we get tracing and lowering cache hits since we don't encode the concrete devices anymore, just the axis_names and axis_size of the mesh.
**The important part is that the concrete device information should only come from the arguments. Inside `jax.jit`, you should never reference concrete devices ever.**
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(mesh1.shape_tuple)
@jax.jit
def f(x):
y = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
return y * 2
f(arr_mesh1)
f(arr_mesh2) # tracing and lowering cache hit
```
**One caveat is that this only works with `jax.NamedSharding` but that's fine because `NamedSharding` is the most used `Sharding` in JAX.**
**What about `shard_map`?**
shard_map's signature will be: `shmap(f, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs)`.
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(mesh1.shape_tuple)
@jax.jit
def f(x):
y = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), out_specs=P('x'))(x)
return y * 2
f(arr_mesh1)
f(arr_mesh2) # tracing and lowering cache hit
```
This is a fully backward-compatible change. So your current code will continue to work as is, but you can opt into this new behavior and get all the benefits!
PiperOrigin-RevId: 662670932
2024-08-13 15:17:30 -07:00
|
|
|
|
if not hasattr(x, 'sharding'):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
'Target sharding contains a `jax.sharding.AbstractMesh` which'
|
|
|
|
|
' requires the input passed should be a `jax.Array`. Got'
|
|
|
|
|
f' {type(x)} with shape {aval.str_short()}')
|
|
|
|
|
if not isinstance(x.sharding, NamedSharding):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
'The sharding on the input must be a `NamedSharding` since the target'
|
|
|
|
|
' sharding has an `AbstractMesh` in it. Got sharding type'
|
2024-08-27 13:30:12 -07:00
|
|
|
|
f' {type(x.sharding)} for shape {aval.str_short()}')
|
Introduce `jax.sharding.AbstractMesh(shape_tuple: tuple[tuple[str, int], ...])` and allow `with_sharding_constraint` and `shard_map` to accept an abstract mesh as input (`with_sharding_constraint` is via `NamedSharding(abstract_mesh, pspec)`).
**Semantics**
Inside jit, we never need to refer to concrete devices, so the semantics stay the same as today, i.e. we can lower a NamedSharding with an abstract mesh using only the mesh axis names and sizes and the PartitionSpec. The only restriction is that the number of devices needs to be consistent throughout the program when we are tracing.
During compilation, the order of devices throughout the program needs to be consistent (same as before this change).
Outside jit i.e. eager mode, if a `shard_map` or `with_sharding_constraint` contains AbstractMesh, then the input to those primitives should contain a concrete Mesh with the same shape and names as the abstract mesh.
**Why do this?**
There are cases where you want to change the devices in the mesh but keep the mesh shape the same (axis names and axis sizes). But this leads to a device mismatch error if you have `with_sharding_constraint` or `shard_map` in your computation because they embed concrete devices in their signature.
So to fix the error, you need to change the mesh in `wsc` and `shmap` which will lead to a tracing cache miss (because function id is now different) and consequently a lowering to stableHLO cache miss. Explaining via an example:
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
@jax.jit
def f(x):
y = with_sharding_constraint(x, NamedSharding(mesh1, P('x')))
return y * 2
f(arr_mesh1)
f(arr_mesh2) # DEVICE MISMATCH ERROR!
```
The same problem exists for `shard_map` since it takes a mesh with concrete devices in its signature.
**Okay, so how do you fix this?**
As mentioned above, we need the above program to work and get tracing and lowering cache hits (**cache hits are the most important** part here).
The approach in this change allows `with_sharding_constraint` to accept a `NamedSharding(abstract_mesh, pspec)` as input. This leads to no errors downstream, and we get tracing and lowering cache hits since we don't encode the concrete devices anymore, just the axis_names and axis_size of the mesh.
**The important part is that the concrete device information should only come from the arguments. Inside `jax.jit`, you should never reference concrete devices ever.**
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(mesh1.shape_tuple)
@jax.jit
def f(x):
y = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
return y * 2
f(arr_mesh1)
f(arr_mesh2) # tracing and lowering cache hit
```
**One caveat is that this only works with `jax.NamedSharding` but that's fine because `NamedSharding` is the most used `Sharding` in JAX.**
**What about `shard_map`?**
shard_map's signature will be: `shmap(f, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs)`.
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(mesh1.shape_tuple)
@jax.jit
def f(x):
y = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), out_specs=P('x'))(x)
return y * 2
f(arr_mesh1)
f(arr_mesh2) # tracing and lowering cache hit
```
This is a fully backward-compatible change. So your current code will continue to work as is, but you can opt into this new behavior and get all the benefits!
PiperOrigin-RevId: 662670932
2024-08-13 15:17:30 -07:00
|
|
|
|
if x.sharding.mesh.shape_tuple != sharding.mesh.shape_tuple:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f'Mesh shape of the input {x.sharding.mesh.shape_tuple} does not'
|
|
|
|
|
' match the mesh shape of the target sharding'
|
2024-08-27 13:30:12 -07:00
|
|
|
|
f' {sharding.mesh.shape_tuple} for shape {aval.str_short()}')
|
Introduce `jax.sharding.AbstractMesh(shape_tuple: tuple[tuple[str, int], ...])` and allow `with_sharding_constraint` and `shard_map` to accept an abstract mesh as input (`with_sharding_constraint` is via `NamedSharding(abstract_mesh, pspec)`).
**Semantics**
Inside jit, we never need to refer to concrete devices, so the semantics stay the same as today, i.e. we can lower a NamedSharding with an abstract mesh using only the mesh axis names and sizes and the PartitionSpec. The only restriction is that the number of devices needs to be consistent throughout the program when we are tracing.
During compilation, the order of devices throughout the program needs to be consistent (same as before this change).
Outside jit i.e. eager mode, if a `shard_map` or `with_sharding_constraint` contains AbstractMesh, then the input to those primitives should contain a concrete Mesh with the same shape and names as the abstract mesh.
**Why do this?**
There are cases where you want to change the devices in the mesh but keep the mesh shape the same (axis names and axis sizes). But this leads to a device mismatch error if you have `with_sharding_constraint` or `shard_map` in your computation because they embed concrete devices in their signature.
So to fix the error, you need to change the mesh in `wsc` and `shmap` which will lead to a tracing cache miss (because function id is now different) and consequently a lowering to stableHLO cache miss. Explaining via an example:
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
@jax.jit
def f(x):
y = with_sharding_constraint(x, NamedSharding(mesh1, P('x')))
return y * 2
f(arr_mesh1)
f(arr_mesh2) # DEVICE MISMATCH ERROR!
```
The same problem exists for `shard_map` since it takes a mesh with concrete devices in its signature.
**Okay, so how do you fix this?**
As mentioned above, we need the above program to work and get tracing and lowering cache hits (**cache hits are the most important** part here).
The approach in this change allows `with_sharding_constraint` to accept a `NamedSharding(abstract_mesh, pspec)` as input. This leads to no errors downstream, and we get tracing and lowering cache hits since we don't encode the concrete devices anymore, just the axis_names and axis_size of the mesh.
**The important part is that the concrete device information should only come from the arguments. Inside `jax.jit`, you should never reference concrete devices ever.**
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(mesh1.shape_tuple)
@jax.jit
def f(x):
y = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
return y * 2
f(arr_mesh1)
f(arr_mesh2) # tracing and lowering cache hit
```
**One caveat is that this only works with `jax.NamedSharding` but that's fine because `NamedSharding` is the most used `Sharding` in JAX.**
**What about `shard_map`?**
shard_map's signature will be: `shmap(f, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs)`.
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(mesh1.shape_tuple)
@jax.jit
def f(x):
y = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), out_specs=P('x'))(x)
return y * 2
f(arr_mesh1)
f(arr_mesh2) # tracing and lowering cache hit
```
This is a fully backward-compatible change. So your current code will continue to work as is, but you can opt into this new behavior and get all the benefits!
PiperOrigin-RevId: 662670932
2024-08-13 15:17:30 -07:00
|
|
|
|
sharding = NamedSharding._from_parsed_pspec(
|
|
|
|
|
x.sharding.mesh, sharding._parsed_pspec)
|
|
|
|
|
|
2024-06-27 16:46:44 -07:00
|
|
|
|
if layout is None:
|
|
|
|
|
if hasattr(x, 'sharding') and x.sharding.is_equivalent_to(sharding, x.ndim):
|
|
|
|
|
return x
|
|
|
|
|
# Run a jit here to raise good errors when device assignment don't match.
|
|
|
|
|
return api.jit(_identity_fn, out_shardings=sharding)(x)
|
|
|
|
|
else:
|
|
|
|
|
if (hasattr(x, 'layout') and x.layout.device_local_layout == layout and
|
|
|
|
|
x.sharding.is_equivalent_to(sharding, x.ndim)):
|
|
|
|
|
return x
|
|
|
|
|
return api.jit(_identity_fn, out_shardings=Layout(layout, sharding))(x)
|
2023-05-17 11:49:31 -07:00
|
|
|
|
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
|
|
|
|
# Primitive bound by `with_sharding_constraint`. Eager execution goes through
# `_sharding_constraint_impl`; abstract evaluation passes the operand's aval
# through unchanged (the constraint never alters shape or dtype).
sharding_constraint_p = core.Primitive("sharding_constraint")
sharding_constraint_p.def_impl(_sharding_constraint_impl)
sharding_constraint_p.def_abstract_eval(lambda aval, **_params: aval)
# Registered as linear: the transpose re-applies the very same constraint
# (identical params) to the incoming cotangent.
ad.deflinear2(
    sharding_constraint_p,
    lambda cotangent, _operand, **params: (
        sharding_constraint_p.bind(cotangent, **params),))
|
|
|
|
|
|
2024-06-27 16:46:44 -07:00
|
|
|
|
def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, layout,
                                      resource_env, unconstrained_dims):
  """MLIR lowering rule for `sharding_constraint_p`.

  Wraps `x_node` in a sharding annotation op. The sharding is serialized as a
  Shardy sharding when `config.use_shardy_partitioner` is enabled, and as an
  XLA HloSharding proto otherwise. When `layout` is given, the result is
  additionally wrapped in a layout op. `resource_env` is part of the
  primitive's params but is not consulted here.

  Returns:
    A single-element list holding the lowered value.
  """
  in_aval, = ctx.avals_in
  out_aval, = ctx.avals_out
  axis_ctx = ctx.module_context.axis_context

  # In an SPMD context that has manual axes, let `mlir.add_manual_axes` adjust
  # the sharding for those axes before it is serialized.
  has_manual_axes = (isinstance(axis_ctx, sharding_impls.SPMDAxisContext)
                     and axis_ctx.manual_axes)
  if has_manual_axes:
    sharding = mlir.add_manual_axes(axis_ctx, sharding, in_aval.ndim)

  serialized = (sharding._to_sdy_sharding(in_aval.ndim)
                if config.use_shardy_partitioner.value else
                sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto())

  result = mlir.wrap_with_sharding_op(
      ctx, x_node, out_aval, serialized, unspecified_dims=unconstrained_dims)
  if layout is not None:
    result = mlir.wrap_with_layout_op(ctx, result, out_aval, layout, in_aval)
  return [result]


mlir.register_lowering(sharding_constraint_p,
                       _sharding_constraint_hlo_lowering)
|
|
|
|
|
|
|
|
|
|
|
2024-06-27 16:46:44 -07:00
|
|
|
|
def _sharding_constraint_batcher(
    spmd_axis_name, axis_size, axis_name, main_type, vals_in,
    dims_in, sharding, layout, resource_env, unconstrained_dims):
  """vmap rule for `sharding_constraint_p`.

  Rewrites the constraint so that it applies to the batched operand, whose
  batch dimension sits at `dims_in[0]`.

  Raises:
    ValueError: if `spmd_axis_name` is also referenced by the (named) target
      sharding's PartitionSpec.
    NotImplementedError: if a concrete `layout` was requested.
  """
  if spmd_axis_name is not None and isinstance(sharding, NamedSharding):
    # Collect every mesh axis name mentioned anywhere in the spec.
    names_in_spec = set()
    for entry in sharding.spec:
      names_in_spec.update(entry if isinstance(entry, tuple) else (entry,))
    if not names_in_spec.isdisjoint(spmd_axis_name):
      raise ValueError(f"vmap spmd_axis_name {spmd_axis_name} cannot appear in "
                       "with_sharding_constraint spec, but got spec "
                       f"{sharding.spec}")

  operand, = vals_in
  batch_dim, = dims_in

  # Inserting the batch dimension at `batch_dim` shifts every existing
  # dimension at or after it up by one.
  shifted_dims = {dim + 1 if dim >= batch_dim else dim
                  for dim in unconstrained_dims}
  # Without an spmd_axis_name the batch dimension itself is left unconstrained.
  if spmd_axis_name is None:
    shifted_dims.add(batch_dim)

  # Presumably builds the sharding for the batched operand (batch dim at
  # `batch_dim`, optionally sharded over `spmd_axis_name`) — see helper.
  batched_sharding = _pjit_batcher_for_sharding(
      sharding, batch_dim, spmd_axis_name, resource_env.physical_mesh,
      operand.ndim)
  if shifted_dims and isinstance(batched_sharding, NamedSharding):
    # Pad the spec out to full rank, then mark unconstrained dims explicitly.
    padded_spec = list(batched_sharding.spec)
    padded_spec.extend([None] * (operand.ndim - len(batched_sharding.spec)))
    for dim in shifted_dims:
      padded_spec[dim] = PartitionSpec.UNCONSTRAINED
    batched_sharding = NamedSharding(
        batched_sharding.mesh, PartitionSpec(*padded_spec))

  # TODO(yashkatariya): Figure out how layouts should change under vmap.
  if layout is not None:
    raise NotImplementedError(
        'Concrete layout is not supported for vmap(with_sharding_constraint). '
        f'Got layout {layout}')

  batched_out = sharding_constraint_p.bind(
      operand,
      unconstrained_dims=shifted_dims,
      resource_env=resource_env,
      layout=layout,
      sharding=batched_sharding)
  return batched_out, batch_dim


batching.spmd_axis_primitive_batchers[sharding_constraint_p] = (
    _sharding_constraint_batcher)
# Plain (non-spmd) vmap: bind `spmd_axis_name=None`.
batching.axis_primitive_batchers[sharding_constraint_p] = partial(
    _sharding_constraint_batcher, None)
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
|
|
|
|
# -------------------- helpers --------------------
|
|
|
|
|
|
|
|
|
|
def get_unconstrained_dims(sharding: NamedSharding):
  """Return the indices of dims whose parsed partition-spec entry is `None`.

  In the parsed spec, `None` presumably encodes `PartitionSpec.UNCONSTRAINED`
  (replicated dims parse to `()`) — verify against `ParsedPartitionSpec`.
  """
  parsed = sharding._parsed_pspec
  assert parsed is not None
  unconstrained = set()
  for dim, axes in enumerate(parsed):
    if axes is None:
      unconstrained.add(dim)
  return unconstrained
|
|
|
|
|
|
|
|
|
|
|
2023-11-17 12:18:46 -08:00
|
|
|
|
def _get_partition_spec(
    ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[PartitionSpec]:
  """Convert each parsed partition spec back to a user-facing PartitionSpec."""
  return list(map(get_single_pspec, ppspec))
|
2022-12-16 13:06:38 -08:00
|
|
|
|
|
|
|
|
|
|
2023-08-01 13:26:43 -07:00
|
|
|
|
def get_op_sharding_from_executable(
    executable) -> tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]:
  """Fetch the input and output OpShardings recorded on a compiled executable.

  Parameter shardings are queried before output shardings. Whenever the
  executable reports `None` for either side, an empty list is returned for it.

  Returns:
    `(in_op_shardings, out_op_shardings)`.
  """
  from_params = executable.get_parameter_shardings()
  from_outputs = executable.get_output_shardings()
  in_op_shardings = [] if from_params is None else from_params
  out_op_shardings = [] if from_outputs is None else from_outputs
  return in_op_shardings, out_op_shardings
|
|
|
|
|
|
|
|
|
|
|
2023-11-17 12:18:46 -08:00
|
|
|
|
def _get_ppspec_from_executable(
    executable, mesh
) -> tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]:
  """Parse the executable's input/output OpShardings against `mesh`.

  Each OpSharding is expanded with `parse_flatten_op_sharding`, and the
  resulting parsed specs are concatenated in order (inputs first, then
  outputs).

  Returns:
    `(in_ppspecs, out_ppspecs)` as lists.
  """
  in_op_shardings, out_op_shardings = get_op_sharding_from_executable(
      executable)

  def _flatten(op_shardings) -> list[ParsedPartitionSpec]:
    # Concatenate the parsed specs produced by every OpSharding.
    return [parsed for op_sharding in op_shardings
            for parsed in parse_flatten_op_sharding(op_sharding, mesh)]

  return _flatten(in_op_shardings), _flatten(out_op_shardings)
|
|
|
|
|
|
|
|
|
|
|
2023-08-01 13:26:43 -07:00
|
|
|
|
def get_pspec_from_executable(
    executable, mesh: pxla.Mesh
) -> tuple[tuple[PartitionSpec, ...], tuple[PartitionSpec, ...]]:
  """Recover user-facing PartitionSpecs for an executable's inputs and outputs.

  Returns:
    `(in_specs, out_specs)`, each a tuple of `PartitionSpec`.
  """
  parsed_in, parsed_out = _get_ppspec_from_executable(executable, mesh)
  out_specs = tuple(_get_partition_spec(parsed_out))
  in_specs = tuple(_get_partition_spec(parsed_in))
  return in_specs, out_specs
|