mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Make jit
a thin wrapper around pjit
which ignores the mesh context manager (just like how it is today)
Pass `None` as the resource_env via `jit` because `jit(pjit)` will ignore the outer mesh because `jit` will set the resource env to empty mesh. This does not make `jit` and `pjit` the same API but it shares all the code between both the APIs (cpp and python) while preserving the current semantics of both `jit` and `pjit`. PiperOrigin-RevId: 501707496
This commit is contained in:
parent
7206cb5b7b
commit
c8ad89e358
@ -282,10 +282,36 @@ def jit(
|
||||
return _jit(False, fun, static_argnums, static_argnames, device, backend,
|
||||
donate_argnums, inline, keep_unused, abstracted_axes)
|
||||
|
||||
# TODO(yashkatariya): Remove the above jit function after
|
||||
# `jax_jit_pjit_api_merge` defaults to True.
|
||||
|
||||
if jax.config.jax_jit_pjit_api_merge:
|
||||
jit = pjit.pjit # type: ignore # noqa: F811
|
||||
def jit( # type: ignore # noqa: F811 # pylint: disable=function-redefined
|
||||
fun: Callable,
|
||||
in_axis_resources=pxla._UNSPECIFIED,
|
||||
out_axis_resources=pxla._UNSPECIFIED,
|
||||
static_argnums: Union[int, Sequence[int], None] = None,
|
||||
static_argnames: Union[str, Iterable[str], None] = None,
|
||||
donate_argnums: Union[int, Sequence[int]] = (),
|
||||
keep_unused: bool = False,
|
||||
device: Optional[xc.Device] = None,
|
||||
backend: Optional[str] = None,
|
||||
inline: bool = False,
|
||||
) -> stages.Wrapped:
|
||||
(in_axis_resources, out_axis_resources, donate_argnums, static_argnums,
|
||||
static_argnames) = pjit.pre_infer_params(
|
||||
fun, in_axis_resources, out_axis_resources, donate_argnums,
|
||||
static_argnums, static_argnames, device, backend)
|
||||
|
||||
def infer_params(*args, **kwargs):
|
||||
pjit_info_args = pjit.PjitInfo(
|
||||
fun=fun, in_axis_resources=in_axis_resources,
|
||||
out_axis_resources=out_axis_resources, static_argnums=static_argnums,
|
||||
static_argnames=static_argnames, donate_argnums=donate_argnums,
|
||||
device=device, backend=backend, keep_unused=keep_unused,
|
||||
inline=inline, resource_env=None)
|
||||
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
|
||||
|
||||
return pjit.post_infer_params(fun, infer_params, static_argnums,
|
||||
static_argnames)
|
||||
|
||||
|
||||
def _jit(
|
||||
|
@ -855,8 +855,12 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
|
||||
# Update pjit params to account for extra error values.
|
||||
num_error_vals = len(err_vals)
|
||||
num_out_error_vals = out_tree.num_leaves - len(out_shardings)
|
||||
sharding = OpShardingSharding.get_replicated(
|
||||
list(resource_env.physical_mesh.devices.flat))
|
||||
if jax.config.jax_array:
|
||||
sharding = pjit._UNSPECIFIED
|
||||
else:
|
||||
sharding = OpShardingSharding.get_replicated(
|
||||
list(resource_env.physical_mesh.devices.flat))
|
||||
|
||||
new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)
|
||||
new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings)
|
||||
|
||||
|
574
jax/_src/pjit.py
574
jax/_src/pjit.py
@ -17,7 +17,7 @@ from enum import IntEnum
|
||||
import numpy as np
|
||||
from collections import OrderedDict, Counter
|
||||
from typing import (Callable, Sequence, Tuple, Union, cast, List, Optional,
|
||||
Iterable)
|
||||
Iterable, NamedTuple, Any)
|
||||
import itertools as it
|
||||
from functools import partial, lru_cache
|
||||
import threading
|
||||
@ -170,7 +170,237 @@ def _cpp_pjit(fun: Callable, infer_params, static_argnums, static_argnames):
|
||||
return wraps(fun)(cpp_pjit_f)
|
||||
|
||||
|
||||
# TODO(yashkatariya): Add pjit microbenchmarks.
|
||||
def pre_infer_params(fun, in_axis_resources, out_axis_resources,
|
||||
donate_argnums, static_argnums, static_argnames, device,
|
||||
backend):
|
||||
check_callable(fun)
|
||||
|
||||
if not config.jax_array and (_is_unspecified(in_axis_resources) or
|
||||
_is_unspecified(out_axis_resources)):
|
||||
raise ValueError(
|
||||
"in_axis_resources and out_axis_resources should not "
|
||||
"be the unspecified singleton value. Please enable `jax.Array` to use "
|
||||
"this feature. You can use jax.config.update('jax_array', True) or "
|
||||
"set the environment variable JAX_ARRAY=1 , or set the `jax_array` "
|
||||
"boolean flag to something true-like.")
|
||||
|
||||
if backend is not None or device is not None:
|
||||
warnings.warn(
|
||||
'backend and device argument on jit is deprecated. You can use a '
|
||||
'`jax.sharding.Mesh` context manager or device_put the arguments '
|
||||
'before passing them to `jit`. Please see '
|
||||
'https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html '
|
||||
'for more information.', DeprecationWarning)
|
||||
if device is not None and backend is not None:
|
||||
raise ValueError("can't specify both a device and a backend for jit, "
|
||||
f"got {device=} and {backend=}")
|
||||
if not _is_unspecified(in_axis_resources):
|
||||
raise ValueError('If backend or device is specified on jit, then '
|
||||
'in_axis_resources should not be specified.')
|
||||
if not _is_unspecified(out_axis_resources):
|
||||
raise ValueError('If backend or device is specified on jit, then '
|
||||
'out_axis_resources should not be specified.')
|
||||
|
||||
if isinstance(in_axis_resources, list):
|
||||
# To be a tree prefix of the positional args tuple, in_axes can never be a
|
||||
# list: if in_axes is not a leaf, it must be a tuple of trees. However,
|
||||
# in cases like these users expect tuples and lists to be treated
|
||||
# essentially interchangeably, so we canonicalize lists to tuples here
|
||||
# rather than raising an error. https://github.com/google/jax/issues/2367
|
||||
in_axis_resources = tuple(in_axis_resources)
|
||||
|
||||
in_axis_resources, _, _ = _prepare_axis_resources(
|
||||
in_axis_resources, "in_axis_resources")
|
||||
out_axis_resources, _, _ = _prepare_axis_resources(
|
||||
out_axis_resources, "out_axis_resources")
|
||||
|
||||
donate_argnums, static_argnums, static_argnames = resolve_argnums(
|
||||
fun, donate_argnums, static_argnums, static_argnames)
|
||||
|
||||
return (in_axis_resources, out_axis_resources, donate_argnums, static_argnums,
|
||||
static_argnames)
|
||||
|
||||
|
||||
def post_infer_params(fun, infer_params, static_argnums, static_argnames):
|
||||
if FLAGS.experimental_cpp_pjit and xla_extension_version >= 115:
|
||||
wrapped = _cpp_pjit(fun, infer_params, static_argnums, static_argnames)
|
||||
else:
|
||||
wrapped = _python_pjit(fun, infer_params)
|
||||
|
||||
def lower(*args, **kwargs):
|
||||
(args_flat, flat_local_in_avals, params, in_tree, out_tree,
|
||||
donate_argnums) = infer_params(*args, **kwargs)
|
||||
if config.jax_array:
|
||||
resource_env = params['resource_env']
|
||||
mesh = None if resource_env is None else resource_env.physical_mesh
|
||||
in_shardings = _resolve_in_shardings(
|
||||
args_flat, params['in_shardings'], params['out_shardings'], mesh)
|
||||
else:
|
||||
in_shardings = params['in_shardings']
|
||||
in_is_global = _calc_is_global_sequence(
|
||||
params['in_positional_semantics'], in_shardings)
|
||||
lowering = _pjit_lower(
|
||||
params['jaxpr'], in_shardings, params['out_shardings'],
|
||||
params['resource_env'], params['donated_invars'], params['name'],
|
||||
in_is_global, params['keep_unused'], always_lower=True)
|
||||
|
||||
if kwargs:
|
||||
args_kwargs_in_tree = in_tree
|
||||
local_in_avals = in_tree.unflatten(flat_local_in_avals)
|
||||
else:
|
||||
args_kwargs_in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
|
||||
local_in_avals = args_kwargs_in_tree.unflatten(flat_local_in_avals)
|
||||
|
||||
return stages.Lowered.from_flat_info(
|
||||
lowering,
|
||||
args_kwargs_in_tree,
|
||||
local_in_avals,
|
||||
donate_argnums,
|
||||
out_tree,
|
||||
no_kwargs=True)
|
||||
|
||||
wrapped.lower = lower
|
||||
return wrapped
|
||||
|
||||
|
||||
class PjitInfo(NamedTuple):
|
||||
fun: Callable
|
||||
in_axis_resources: Any
|
||||
out_axis_resources: Any
|
||||
static_argnums: Tuple[int, ...]
|
||||
static_argnames: Tuple[str, ...]
|
||||
donate_argnums: Tuple[int, ...]
|
||||
device: Optional[xc.Device]
|
||||
backend: Optional[str]
|
||||
keep_unused: bool
|
||||
inline: bool
|
||||
resource_env: Any
|
||||
|
||||
|
||||
def common_infer_params(pjit_info_args, *args, **kwargs):
|
||||
(fun, in_axis_resources, out_axis_resources, static_argnums, static_argnames,
|
||||
donate_argnums, device, backend, keep_unused, inline,
|
||||
resource_env) = pjit_info_args
|
||||
|
||||
if kwargs and not _is_unspecified(in_axis_resources):
|
||||
raise ValueError(
|
||||
"pjit does not support kwargs when in_axis_resources is specified.")
|
||||
|
||||
if resource_env is not None:
|
||||
pjit_mesh = resource_env.physical_mesh
|
||||
if pjit_mesh.empty:
|
||||
if config.jax_array:
|
||||
# Don't enforce requiring a mesh when `jax_array` flag is enabled. But
|
||||
# if mesh is not empty then pjit will respect it.
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
|
||||
"it's defined at the call site?")
|
||||
else:
|
||||
pjit_mesh = None
|
||||
|
||||
if (backend or device) and pjit_mesh is not None and not pjit_mesh.empty:
|
||||
raise ValueError(
|
||||
"Mesh context manager should not be used with jit when backend or "
|
||||
"device is also specified as an argument to jit.")
|
||||
|
||||
f = lu.wrap_init(fun)
|
||||
f, dyn_args = argnums_partial_except(f, static_argnums, args,
|
||||
allow_invalid=True)
|
||||
del args
|
||||
|
||||
# TODO(yashkatariya): Merge the nokwargs and kwargs path. One blocker is
|
||||
# flatten_axes which if kwargs are present in the treedef (even empty {}),
|
||||
# leads to wrong expansion.
|
||||
if kwargs:
|
||||
f, dyn_kwargs = argnames_partial_except(f, static_argnames, kwargs)
|
||||
args_flat, in_tree = tree_flatten((dyn_args, dyn_kwargs))
|
||||
flat_fun, out_tree = flatten_fun(f, in_tree)
|
||||
else:
|
||||
args_flat, in_tree = tree_flatten(dyn_args)
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
|
||||
dyn_kwargs = ()
|
||||
del kwargs
|
||||
|
||||
if donate_argnums and not config.jax_debug_nans:
|
||||
donated_invars = donation_vector(donate_argnums, dyn_args, dyn_kwargs)
|
||||
else:
|
||||
donated_invars = (False,) * len(args_flat)
|
||||
|
||||
if config.jax_array:
|
||||
# If backend or device is set as an arg on jit, then resolve them to
|
||||
# in_shardings and out_shardings as if user passed in in_axis_resources
|
||||
# and out_axis_resources.
|
||||
if backend or device:
|
||||
in_shardings = out_shardings = _create_sharding_with_device_backend(
|
||||
device, backend)
|
||||
else:
|
||||
in_shardings = tree_map(
|
||||
lambda x: _create_sharding_for_array(pjit_mesh, x), in_axis_resources)
|
||||
out_shardings = tree_map(
|
||||
lambda x: _create_sharding_for_array(pjit_mesh, x), out_axis_resources)
|
||||
else:
|
||||
in_shardings = tree_map(
|
||||
lambda x: _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x),
|
||||
in_axis_resources)
|
||||
out_shardings = tree_map(
|
||||
lambda x: x if _is_unspecified(x) else
|
||||
_create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x), out_axis_resources)
|
||||
# This check fails extremely rarely and has a huge cost in the dispatch
|
||||
# path. So hide it behind the jax_enable_checks flag.
|
||||
if config.jax_enable_checks:
|
||||
_maybe_check_pjit_gda_mesh(args_flat, pjit_mesh)
|
||||
|
||||
local_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
|
||||
# TODO(yashkatariya): This is a hack. This should go away when avals have
|
||||
# is_global attribute.
|
||||
if config.jax_array:
|
||||
in_positional_semantics = (pxla._PositionalSemantics.GLOBAL,) * len(args_flat)
|
||||
else:
|
||||
in_positional_semantics = tuple(tree_map(_get_in_positional_semantics, args_flat))
|
||||
out_positional_semantics = (
|
||||
pxla._PositionalSemantics.GLOBAL
|
||||
if config.jax_parallel_functions_output_gda or config.jax_array else
|
||||
pxla._positional_semantics.val)
|
||||
|
||||
global_in_avals, canonicalized_in_shardings_flat = _process_in_axis_resources(
|
||||
hashable_pytree(in_shardings), local_in_avals, in_tree, in_positional_semantics,
|
||||
tuple(isinstance(a, GDA) for a in args_flat), resource_env)
|
||||
|
||||
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
|
||||
flat_fun, hashable_pytree(out_shardings), global_in_avals,
|
||||
HashableFunction(out_tree, closure=()))
|
||||
|
||||
if (any(_is_from_gda(i) for i in canonicalized_in_shardings_flat) or
|
||||
not config.jax_array):
|
||||
canonicalized_in_shardings_flat = _maybe_replace_from_gda_with_pspec(
|
||||
canonicalized_in_shardings_flat, args_flat)
|
||||
|
||||
assert len(args_flat) == len(canonicalized_in_shardings_flat)
|
||||
|
||||
canonicalized_in_shardings_flat = (
|
||||
_UNSPECIFIED,) * len(consts) + canonicalized_in_shardings_flat
|
||||
donated_invars = (False,) * len(consts) + donated_invars
|
||||
in_positional_semantics = (
|
||||
pxla._PositionalSemantics.GLOBAL,) * len(consts) + in_positional_semantics
|
||||
|
||||
# in_shardings and out_shardings here are all OpShardingSharding.
|
||||
params = dict(
|
||||
jaxpr=jaxpr,
|
||||
in_shardings=canonicalized_in_shardings_flat,
|
||||
out_shardings=canonicalized_out_shardings_flat,
|
||||
resource_env=resource_env,
|
||||
donated_invars=donated_invars,
|
||||
name=getattr(flat_fun, '__name__', '<unnamed function>'),
|
||||
in_positional_semantics=in_positional_semantics,
|
||||
out_positional_semantics=out_positional_semantics,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
)
|
||||
return (consts + args_flat, local_in_avals, params, in_tree, out_tree(),
|
||||
donate_argnums)
|
||||
|
||||
|
||||
# in_axis_resources and out_axis_resources can't be None as the default value
|
||||
# because `None` means that the input is fully replicated.
|
||||
def pjit(
|
||||
@ -320,206 +550,24 @@ def pjit(
|
||||
... print(f(x)) # doctest: +SKIP
|
||||
[ 0.5 2. 4. 6. 8. 10. 12. 10. ]
|
||||
"""
|
||||
check_callable(fun)
|
||||
|
||||
if not config.jax_array and (_is_unspecified(in_axis_resources) or
|
||||
_is_unspecified(out_axis_resources)):
|
||||
raise ValueError(
|
||||
"in_axis_resources and out_axis_resources should not "
|
||||
"be the unspecified singleton value. Please enable `jax.Array` to use "
|
||||
"this feature. You can use jax.config.update('jax_array', True) or "
|
||||
"set the environment variable JAX_ARRAY=1 , or set the `jax_array` "
|
||||
"boolean flag to something true-like.")
|
||||
|
||||
if backend is not None or device is not None:
|
||||
warnings.warn(
|
||||
'backend and device argument on jit is deprecated. You can use a '
|
||||
'`jax.sharding.Mesh` context manager or device_put the arguments '
|
||||
'before passing them to `jit`. Please see '
|
||||
'https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html '
|
||||
'for more information.', DeprecationWarning)
|
||||
if device is not None and backend is not None:
|
||||
raise ValueError("can't specify both a device and a backend for jit, "
|
||||
f"got {device=} and {backend=}")
|
||||
if not _is_unspecified(in_axis_resources):
|
||||
raise ValueError('If backend or device is specified on jit, then '
|
||||
'in_axis_resources should not be specified.')
|
||||
if not _is_unspecified(out_axis_resources):
|
||||
raise ValueError('If backend or device is specified on jit, then '
|
||||
'out_axis_resources should not be specified.')
|
||||
|
||||
if isinstance(in_axis_resources, list):
|
||||
# To be a tree prefix of the positional args tuple, in_axes can never be a
|
||||
# list: if in_axes is not a leaf, it must be a tuple of trees. However,
|
||||
# in cases like these users expect tuples and lists to be treated
|
||||
# essentially interchangeably, so we canonicalize lists to tuples here
|
||||
# rather than raising an error. https://github.com/google/jax/issues/2367
|
||||
in_axis_resources = tuple(in_axis_resources)
|
||||
|
||||
in_axis_resources, _, _ = _prepare_axis_resources(
|
||||
in_axis_resources, "in_axis_resources")
|
||||
out_axis_resources, _, _ = _prepare_axis_resources(
|
||||
out_axis_resources, "out_axis_resources")
|
||||
|
||||
donate_argnums, static_argnums, static_argnames = resolve_argnums(
|
||||
fun, donate_argnums, static_argnums, static_argnames)
|
||||
(in_axis_resources, out_axis_resources, donate_argnums, static_argnums,
|
||||
static_argnames) = pre_infer_params(
|
||||
fun, in_axis_resources, out_axis_resources, donate_argnums,
|
||||
static_argnums, static_argnames, device, backend)
|
||||
|
||||
def infer_params(*args, **kwargs):
|
||||
if kwargs and not _is_unspecified(in_axis_resources):
|
||||
raise ValueError(
|
||||
"pjit does not support kwargs when in_axis_resources is specified.")
|
||||
|
||||
# Putting this outside of wrapped would make resources lexically scoped
|
||||
resource_env = pxla.thread_resources.env
|
||||
pjit_mesh = resource_env.physical_mesh
|
||||
if pjit_mesh.empty:
|
||||
if config.jax_array:
|
||||
# Don't enforce requiring a mesh when `jax_array` flag is enabled. But
|
||||
# if mesh is not empty then pjit will respect it.
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
|
||||
"it's defined at the call site?")
|
||||
pjit_info_args = PjitInfo(
|
||||
fun=fun, in_axis_resources=in_axis_resources,
|
||||
out_axis_resources=out_axis_resources, static_argnums=static_argnums,
|
||||
static_argnames=static_argnames, donate_argnums=donate_argnums,
|
||||
device=device, backend=backend, keep_unused=keep_unused,
|
||||
inline=inline, resource_env=resource_env)
|
||||
return common_infer_params(pjit_info_args, *args, **kwargs)
|
||||
|
||||
if (backend or device) and not pjit_mesh.empty:
|
||||
raise ValueError(
|
||||
"Mesh context manager should not be used with jit when backend or "
|
||||
"device is also specified as an argument to jit.")
|
||||
return post_infer_params(fun, infer_params, static_argnums, static_argnames)
|
||||
|
||||
f = lu.wrap_init(fun)
|
||||
f, dyn_args = argnums_partial_except(f, static_argnums, args,
|
||||
allow_invalid=True)
|
||||
del args
|
||||
|
||||
# TODO(yashkatariya): Merge the nokwargs and kwargs path. One blocker is
|
||||
# flatten_axes which if kwargs are present in the treedef (even empty {}),
|
||||
# leads to wrong expansion.
|
||||
if kwargs:
|
||||
f, dyn_kwargs = argnames_partial_except(f, static_argnames, kwargs)
|
||||
args_flat, in_tree = tree_flatten((dyn_args, dyn_kwargs))
|
||||
flat_fun, out_tree = flatten_fun(f, in_tree)
|
||||
else:
|
||||
args_flat, in_tree = tree_flatten(dyn_args)
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
|
||||
dyn_kwargs = ()
|
||||
del kwargs
|
||||
|
||||
if donate_argnums and not config.jax_debug_nans:
|
||||
donated_invars = donation_vector(donate_argnums, dyn_args, dyn_kwargs)
|
||||
else:
|
||||
donated_invars = (False,) * len(args_flat)
|
||||
|
||||
if config.jax_array:
|
||||
# If backend or device is set as an arg on jit, then resolve them to
|
||||
# in_shardings and out_shardings as if user passed in in_axis_resources
|
||||
# and out_axis_resources.
|
||||
if backend or device:
|
||||
in_shardings = out_shardings = _create_sharding_with_device_backend(
|
||||
device, backend)
|
||||
else:
|
||||
in_shardings = tree_map(
|
||||
lambda x: _create_sharding_for_array(pjit_mesh, x), in_axis_resources)
|
||||
out_shardings = tree_map(
|
||||
lambda x: _create_sharding_for_array(pjit_mesh, x), out_axis_resources)
|
||||
else:
|
||||
in_shardings = tree_map(
|
||||
lambda x: _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x),
|
||||
in_axis_resources)
|
||||
out_shardings = tree_map(
|
||||
lambda x: x if _is_unspecified(x) else
|
||||
_create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x), out_axis_resources)
|
||||
# This check fails extremely rarely and has a huge cost in the dispatch
|
||||
# path. So hide it behind the jax_enable_checks flag.
|
||||
if config.jax_enable_checks:
|
||||
_maybe_check_pjit_gda_mesh(args_flat, pjit_mesh)
|
||||
|
||||
local_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
|
||||
# TODO(yashkatariya): This is a hack. This should go away when avals have
|
||||
# is_global attribute.
|
||||
if config.jax_array:
|
||||
in_positional_semantics = (pxla._PositionalSemantics.GLOBAL,) * len(args_flat)
|
||||
else:
|
||||
in_positional_semantics = tuple(tree_map(_get_in_positional_semantics, args_flat))
|
||||
out_positional_semantics = (
|
||||
pxla._PositionalSemantics.GLOBAL
|
||||
if config.jax_parallel_functions_output_gda or config.jax_array else
|
||||
pxla._positional_semantics.val)
|
||||
|
||||
global_in_avals, canonicalized_in_shardings_flat = _process_in_axis_resources(
|
||||
hashable_pytree(in_shardings), local_in_avals, in_tree, in_positional_semantics,
|
||||
tuple(isinstance(a, GDA) for a in args_flat), resource_env)
|
||||
|
||||
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
|
||||
flat_fun, hashable_pytree(out_shardings), global_in_avals,
|
||||
HashableFunction(out_tree, closure=()))
|
||||
|
||||
if (any(_is_from_gda(i) for i in canonicalized_in_shardings_flat) or
|
||||
not config.jax_array):
|
||||
canonicalized_in_shardings_flat = _maybe_replace_from_gda_with_pspec(
|
||||
canonicalized_in_shardings_flat, args_flat)
|
||||
|
||||
assert len(args_flat) == len(canonicalized_in_shardings_flat)
|
||||
|
||||
canonicalized_in_shardings_flat = (
|
||||
_UNSPECIFIED,) * len(consts) + canonicalized_in_shardings_flat
|
||||
donated_invars = (False,) * len(consts) + donated_invars
|
||||
in_positional_semantics = (
|
||||
pxla._PositionalSemantics.GLOBAL,) * len(consts) + in_positional_semantics
|
||||
|
||||
# in_shardings and out_shardings here are all OpShardingSharding.
|
||||
params = dict(
|
||||
jaxpr=jaxpr,
|
||||
in_shardings=canonicalized_in_shardings_flat,
|
||||
out_shardings=canonicalized_out_shardings_flat,
|
||||
resource_env=resource_env,
|
||||
donated_invars=donated_invars,
|
||||
name=getattr(flat_fun, '__name__', '<unnamed function>'),
|
||||
in_positional_semantics=in_positional_semantics,
|
||||
out_positional_semantics=out_positional_semantics,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
)
|
||||
return (consts + args_flat, local_in_avals, params, in_tree, out_tree(),
|
||||
donate_argnums)
|
||||
|
||||
if FLAGS.experimental_cpp_pjit and xla_extension_version >= 115:
|
||||
wrapped = _cpp_pjit(fun, infer_params, static_argnums, static_argnames)
|
||||
else:
|
||||
wrapped = _python_pjit(fun, infer_params)
|
||||
|
||||
def lower(*args, **kwargs):
|
||||
(args_flat, flat_local_in_avals, params, in_tree, out_tree,
|
||||
donate_argnums) = infer_params(*args, **kwargs)
|
||||
if config.jax_array:
|
||||
in_shardings = _resolve_in_shardings(
|
||||
args_flat, params['in_shardings'], params['out_shardings'],
|
||||
params['resource_env'].physical_mesh)
|
||||
else:
|
||||
in_shardings = params['in_shardings']
|
||||
in_is_global = _calc_is_global_sequence(
|
||||
params['in_positional_semantics'], in_shardings)
|
||||
lowering = _pjit_lower(
|
||||
params['jaxpr'], in_shardings, params['out_shardings'],
|
||||
params['resource_env'], params['donated_invars'], params['name'],
|
||||
in_is_global, params['keep_unused'], always_lower=True)
|
||||
|
||||
if kwargs:
|
||||
args_kwargs_in_tree = in_tree
|
||||
local_in_avals = in_tree.unflatten(flat_local_in_avals)
|
||||
else:
|
||||
args_kwargs_in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
|
||||
local_in_avals = args_kwargs_in_tree.unflatten(flat_local_in_avals)
|
||||
|
||||
return stages.Lowered.from_flat_info(
|
||||
lowering,
|
||||
args_kwargs_in_tree,
|
||||
local_in_avals,
|
||||
donate_argnums,
|
||||
out_tree,
|
||||
no_kwargs=True)
|
||||
|
||||
wrapped.lower = lower
|
||||
return wrapped
|
||||
|
||||
class _ListWithW(list):
|
||||
__slots__ = ('__weakref__',)
|
||||
@ -543,6 +591,10 @@ def _create_sharding_for_array(mesh, x):
|
||||
# FROM_GDA is removed.
|
||||
if isinstance(x, XLACompatibleSharding) or _is_unspecified_or_from_gda_or_auto(x):
|
||||
return x
|
||||
if mesh is None:
|
||||
raise RuntimeError(
|
||||
"jit does not support using the mesh context manager. Please pass in "
|
||||
"the sharding explicitly via in_axis_resources or out_axis_resources.")
|
||||
if mesh.empty:
|
||||
raise RuntimeError("pjit requires a non-empty mesh! Is a mesh defined at "
|
||||
"the call site? Alternatively, provide a "
|
||||
@ -944,7 +996,10 @@ pjit_p = core.Primitive("pjit")
|
||||
pjit_p.multiple_results = True
|
||||
|
||||
|
||||
def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
|
||||
def _resolve_in_shardings(
|
||||
args, pjit_in_shardings: Sequence[PjitSharding],
|
||||
out_shardings: Sequence[PjitSharding],
|
||||
pjit_mesh: Optional[pxla.Mesh]) -> Sequence[PjitSharding]:
|
||||
# If True, means that device or backend is set by the user on pjit and it
|
||||
# has the same semantics as device_put i.e. doesn't matter which device the
|
||||
# arg is on, reshard it to the device mentioned. So don't do any of the
|
||||
@ -972,7 +1027,7 @@ def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
|
||||
pxla._get_and_check_device_assignment(
|
||||
it.chain(
|
||||
committed_arg_shardings, pjit_in_shardings, out_shardings),
|
||||
(None if pjit_mesh.empty else list(pjit_mesh.devices.flat)))
|
||||
(None if pjit_mesh is None or pjit_mesh.empty else list(pjit_mesh.devices.flat)))
|
||||
|
||||
resolved_in_shardings = []
|
||||
for arg, pjit_in_s in safe_zip(args, pjit_in_shardings):
|
||||
@ -1037,8 +1092,9 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
global _most_recent_pjit_call_executable
|
||||
|
||||
if config.jax_array:
|
||||
in_shardings = _resolve_in_shardings(args, in_shardings, out_shardings,
|
||||
resource_env.physical_mesh)
|
||||
in_shardings = _resolve_in_shardings(
|
||||
args, in_shardings, out_shardings,
|
||||
resource_env.physical_mesh if resource_env is not None else None)
|
||||
|
||||
in_is_global = _calc_is_global_sequence(in_positional_semantics, in_shardings)
|
||||
if config.jax_array and all(_is_unspecified(o) for o in out_shardings):
|
||||
@ -1145,11 +1201,15 @@ def _pjit_lower_cached(
|
||||
out_shardings: Tuple[PjitSharding, ...] = sdat_out_shardings.shardings
|
||||
|
||||
pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")
|
||||
|
||||
f = core.jaxpr_as_fun(jaxpr)
|
||||
f.__name__ = name
|
||||
fun = lu.wrap_init(f)
|
||||
|
||||
mesh = resource_env.physical_mesh
|
||||
if resource_env is not None:
|
||||
mesh = resource_env.physical_mesh
|
||||
else:
|
||||
mesh = None
|
||||
|
||||
# Convert to `NamedSharding` when `jax_array` is not enabled. This is
|
||||
# because GDA/SDA/DA are dependent on mesh for generating outputs.
|
||||
@ -1187,7 +1247,8 @@ def _pjit_lower_cached(
|
||||
fun, 'pjit', name, in_shardings, out_shardings, donated_invars,
|
||||
jaxpr.in_avals, in_is_global=in_is_global, keep_unused=keep_unused,
|
||||
always_lower=always_lower,
|
||||
devices_from_context=(None if mesh.empty else list(mesh.devices.flat)))
|
||||
devices_from_context=(
|
||||
None if mesh is None or mesh.empty else list(mesh.devices.flat)))
|
||||
|
||||
|
||||
def pjit_staging_rule(trace, *args, **params):
|
||||
@ -1207,6 +1268,8 @@ def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env,
|
||||
disallowed_effects = jaxpr.effects - mlir.lowerable_effects
|
||||
if disallowed_effects:
|
||||
raise ValueError('Effects not supported in `pjit`.')
|
||||
if config.jax_array:
|
||||
return jaxpr.out_avals, jaxpr.effects
|
||||
return global_to_local(out_positional_semantics, jaxpr.out_avals,
|
||||
out_shardings, resource_env.physical_mesh), jaxpr.effects
|
||||
pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)
|
||||
@ -1216,7 +1279,7 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
|
||||
out_shardings, resource_env, donated_invars,
|
||||
in_positional_semantics, out_positional_semantics,
|
||||
keep_unused, inline):
|
||||
if not config.jax_array:
|
||||
if not config.jax_jit_pjit_api_merge:
|
||||
if not isinstance(ctx.module_context.axis_context,
|
||||
(mlir.SPMDAxisContext, mlir.ShardingContext)):
|
||||
raise RuntimeError("Nesting pjit() inside jit() is not allowed.")
|
||||
@ -1264,7 +1327,12 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
|
||||
# `insert_axis` is set to True only for some `xmap` uses.
|
||||
new_parts = (axis_name,) if insert_axis else (
|
||||
() if spmd_axis_name is None else (spmd_axis_name,))
|
||||
mesh = resource_env.physical_mesh
|
||||
|
||||
if resource_env is not None:
|
||||
mesh = resource_env.physical_mesh
|
||||
else:
|
||||
mesh = None
|
||||
|
||||
in_shardings = tuple(
|
||||
_pjit_batcher_for_sharding(i, 0, new_parts, mesh, aval.ndim) if is_mapped else i
|
||||
for is_mapped, i, aval in zip(is_mapped_in, in_shardings, new_jaxpr.in_avals))
|
||||
@ -1302,7 +1370,7 @@ def _pjit_batcher_for_sharding(
|
||||
return OpShardingSharding(s._device_assignment, new_op) # type: ignore
|
||||
else:
|
||||
assert isinstance(s, OpShardingSharding)
|
||||
assert not mesh.empty
|
||||
assert mesh is not None and not mesh.empty
|
||||
parsed_pspec = parse_flatten_op_sharding(s._op_sharding, mesh)[0] # type: ignore
|
||||
parsed_pspec = parsed_pspec.insert_axis_partitions(dim, val)
|
||||
mps = NamedSharding._from_parsed_pspec(mesh, parsed_pspec)
|
||||
@ -1378,32 +1446,37 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
|
||||
if num_residuals:
|
||||
in_is_global = _calc_is_global_sequence(
|
||||
known_params['in_positional_semantics'], known_params['in_shardings'])
|
||||
compiled = _pjit_lower(
|
||||
known_params["jaxpr"], known_params["in_shardings"],
|
||||
known_params["out_shardings"], known_params["resource_env"],
|
||||
known_params["donated_invars"], known_params["name"],
|
||||
in_is_global, known_params['keep_unused'], always_lower=False).compile(
|
||||
_allow_propagation_to_outputs=True,
|
||||
_allow_compile_replicated=False)
|
||||
da = compiled._device_assignment
|
||||
_, out_op_sharding_shardings = pxla._get_op_sharding_shardings_from_executable(
|
||||
compiled.xla_executable, da, len(known_jaxpr.in_avals),
|
||||
len(known_jaxpr.out_avals))
|
||||
assert len(out_op_sharding_shardings) == len(known_jaxpr.out_avals), (
|
||||
len(out_op_sharding_shardings), len(known_jaxpr.out_avals))
|
||||
out_op_shardings = [o._to_xla_op_sharding(a.ndim) for o, a in
|
||||
safe_zip(out_op_sharding_shardings, known_jaxpr.out_avals)]
|
||||
residual_op_shardings = tuple(out_op_shardings[-num_residuals:])
|
||||
else:
|
||||
residual_op_shardings = ()
|
||||
assert len(residual_shardings) == len(residual_op_shardings), (
|
||||
len(residual_shardings), len(residual_op_shardings))
|
||||
residual_shardings = tuple(OpShardingSharding(da, op) for op in residual_op_shardings)
|
||||
known_params['out_shardings'] = (
|
||||
keep_where(out_shardings, known_outs) + residual_shardings)
|
||||
# resource_env is None in the jit wrapper around pjit.
|
||||
# TODO(apaszke,yashkatariya): Replace this check with
|
||||
# `if not config.jax_array` after XLA stops overriding user shardings when
|
||||
# `_allow_propagation_to_outputs = True`.
|
||||
if resource_env is not None:
|
||||
if num_residuals:
|
||||
in_is_global = _calc_is_global_sequence(
|
||||
known_params['in_positional_semantics'], known_params['in_shardings'])
|
||||
compiled = _pjit_lower(
|
||||
known_params["jaxpr"], known_params["in_shardings"],
|
||||
known_params["out_shardings"], known_params["resource_env"],
|
||||
known_params["donated_invars"], known_params["name"],
|
||||
in_is_global, known_params['keep_unused'], always_lower=False).compile(
|
||||
_allow_propagation_to_outputs=True,
|
||||
_allow_compile_replicated=False)
|
||||
da = compiled._device_assignment
|
||||
_, out_op_sharding_shardings = pxla._get_op_sharding_shardings_from_executable(
|
||||
compiled.xla_executable, da, len(known_jaxpr.in_avals),
|
||||
len(known_jaxpr.out_avals))
|
||||
assert len(out_op_sharding_shardings) == len(known_jaxpr.out_avals), (
|
||||
len(out_op_sharding_shardings), len(known_jaxpr.out_avals))
|
||||
out_op_shardings = [o._to_xla_op_sharding(a.ndim) for o, a in
|
||||
safe_zip(out_op_sharding_shardings, known_jaxpr.out_avals)]
|
||||
residual_op_shardings = tuple(out_op_shardings[-num_residuals:])
|
||||
else:
|
||||
residual_op_shardings = ()
|
||||
assert len(residual_shardings) == len(residual_op_shardings), (
|
||||
len(residual_shardings), len(residual_op_shardings))
|
||||
residual_shardings = tuple(OpShardingSharding(da, op) for op in residual_op_shardings)
|
||||
known_params['out_shardings'] = (
|
||||
keep_where(out_shardings, known_outs) + residual_shardings)
|
||||
|
||||
all_known_outs = pjit_p.bind(
|
||||
*(pv.get_known() for pv in in_pvals if pv.is_known()),
|
||||
@ -1436,12 +1509,16 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()]
|
||||
if config.jax_array:
|
||||
unknown_out_avals = unknown_jaxpr.out_avals
|
||||
else:
|
||||
unknown_out_avals = global_to_local(
|
||||
unknown_params["out_positional_semantics"], unknown_jaxpr.out_avals,
|
||||
unknown_params["out_shardings"],
|
||||
unknown_params["resource_env"].physical_mesh)
|
||||
unknown_tracers_out = [
|
||||
pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
|
||||
for aval in global_to_local(unknown_params["out_positional_semantics"],
|
||||
unknown_jaxpr.out_avals,
|
||||
unknown_params["out_shardings"],
|
||||
unknown_params["resource_env"].physical_mesh)
|
||||
for aval in unknown_out_avals
|
||||
]
|
||||
eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers),
|
||||
unknown_tracers_out,
|
||||
@ -1525,8 +1602,9 @@ def _check_resources_against_named_axes(what, aval, pos_axis_resources, named_ax
|
||||
def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_resources):
|
||||
jaxpr = params["jaxpr"]
|
||||
what = "pjit input"
|
||||
if resource_env.physical_mesh != params['resource_env'].physical_mesh:
|
||||
raise RuntimeError("Changing the physical mesh is not allowed inside pjit.")
|
||||
if (resource_env is not None and params['resource_env'] is not None and
|
||||
resource_env.physical_mesh != params['resource_env'].physical_mesh):
|
||||
raise RuntimeError("Changing the physical mesh is not allowed inside pjit.")
|
||||
|
||||
for aval, s in zip(jaxpr.in_avals, params['in_shardings']):
|
||||
if _is_unspecified(s) or _is_auto(s):
|
||||
@ -1535,9 +1613,14 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
|
||||
s._original_sharding, '_parsed_pspec'):
|
||||
parsed_pspec = s._original_sharding._parsed_pspec
|
||||
else:
|
||||
parsed_pspec = parse_flatten_op_sharding(
|
||||
s._op_sharding, resource_env.physical_mesh)[0]
|
||||
_check_resources_against_named_axes(what, aval, parsed_pspec, named_axis_resources)
|
||||
if resource_env is not None:
|
||||
parsed_pspec = parse_flatten_op_sharding(
|
||||
s._op_sharding, resource_env.physical_mesh)[0]
|
||||
else:
|
||||
parsed_pspec = None
|
||||
if parsed_pspec is not None:
|
||||
_check_resources_against_named_axes(what, aval, parsed_pspec,
|
||||
named_axis_resources)
|
||||
|
||||
pxla.resource_typecheck(
|
||||
jaxpr.jaxpr, resource_env, named_axis_resources,
|
||||
@ -1552,9 +1635,14 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
|
||||
s._original_sharding, '_parsed_pspec'):
|
||||
parsed_pspec = s._original_sharding._parsed_pspec
|
||||
else:
|
||||
parsed_pspec = parse_flatten_op_sharding(
|
||||
s._op_sharding, resource_env.physical_mesh)[0]
|
||||
_check_resources_against_named_axes(what, aval, parsed_pspec, named_axis_resources)
|
||||
if resource_env is not None:
|
||||
parsed_pspec = parse_flatten_op_sharding(
|
||||
s._op_sharding, resource_env.physical_mesh)[0]
|
||||
else:
|
||||
parsed_pspec = None
|
||||
if parsed_pspec is not None:
|
||||
_check_resources_against_named_axes(what, aval, parsed_pspec,
|
||||
named_axis_resources)
|
||||
|
||||
pxla.custom_resource_typing_rules[pjit_p] = _resource_typing_pjit
|
||||
|
||||
|
@ -2710,8 +2710,10 @@ def _check_if_any_auto(
|
||||
|
||||
|
||||
def _get_and_check_device_assignment(
|
||||
shardings: Iterable[sharding_internal.XLACompatibleSharding],
|
||||
devices: Optional[Sequence[xc.Device]]) -> Tuple[xla.Backend, Sequence[xc.Device]]:
|
||||
shardings: Iterable[Union[sharding_internal.XLACompatibleSharding,
|
||||
_UnspecifiedValue, _AUTOAxisResource]],
|
||||
devices: Optional[Sequence[xc.Device]]
|
||||
) -> Tuple[xla.Backend, Sequence[xc.Device]]:
|
||||
from jax._src.api import local_devices
|
||||
|
||||
first_device_assignment = None
|
||||
|
@ -199,6 +199,7 @@ jax_test(
|
||||
backend_tags = {
|
||||
"tpu": ["notsan"], # Times out under tsan.
|
||||
},
|
||||
enable_configs = ["cpu_jit_pjit_api_merge"],
|
||||
pjrt_c_api_bypass = True,
|
||||
shard_count = {
|
||||
"cpu": 5,
|
||||
|
@ -600,7 +600,10 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
f = pjit(lambda x: jax.grad(h)(x),
|
||||
in_axis_resources=None, out_axis_resources=None)
|
||||
x = jnp.arange(8, dtype=jnp.float32)
|
||||
self.assertAllClose(f(x), jnp.cos(x))
|
||||
out = f(x)
|
||||
self.assertAllClose(out, jnp.cos(x))
|
||||
if jax.config.jax_array:
|
||||
self.assertLen(out.devices(), 2)
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
def testNoopPartitionSpecs(self):
|
||||
@ -2081,8 +2084,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
@jax_array(True)
|
||||
def test_pjit_single_device_sharding_add(self):
|
||||
a = jnp.array([1, 2, 3], dtype=jnp.float32)
|
||||
b = jnp.array([4, 5, 6], dtype=jnp.float32)
|
||||
a = np.array([1, 2, 3], dtype=jnp.float32)
|
||||
b = np.array([4, 5, 6], dtype=jnp.float32)
|
||||
|
||||
@pjit
|
||||
def add(x, y):
|
||||
@ -2462,11 +2465,18 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
out = jnp.zeros(shape, jnp.bfloat16)
|
||||
return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Pjit's devices and Array's devices should be equal. "
|
||||
r"Got Pjit's device ids \[0\] on platform.*and "
|
||||
r"Array's device ids \[0, 1, 2, 3\] on platform"):
|
||||
# This split is needed because original `jit` adds `device` as a
|
||||
# `devices_from_context` whereas `pjit` passes it as an in_sharding.
|
||||
if jax.config.jax_jit_pjit_api_merge:
|
||||
error_msg = ("Devices of all `Array` inputs and outputs should be the same. "
|
||||
r"Got array device ids \[0\] on platform.*and "
|
||||
r"another array's device ids \[0, 1, 2, 3\] on platform")
|
||||
else:
|
||||
error_msg = ("Pjit's devices and Array's devices should be equal. "
|
||||
r"Got Pjit's device ids \[0\] on platform.*and "
|
||||
r"Array's device ids \[0, 1, 2, 3\] on platform")
|
||||
|
||||
with self.assertRaisesRegex(ValueError, error_msg):
|
||||
sharded_zeros((4096, 3072), P('x', 'y'))
|
||||
|
||||
@jax_array(True)
|
||||
@ -2920,7 +2930,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
_check(out2, jax.devices()[1], y)
|
||||
|
||||
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
|
||||
self.assertEqual(cache_info2.misses, cache_info1.misses)
|
||||
|
||||
h = pjit(mul, device=jax.devices()[-1])
|
||||
h_out = h(y)
|
||||
@ -2928,7 +2937,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
_check(h_out, jax.devices()[-1], y)
|
||||
|
||||
self.assertEqual(cache_info3.hits, cache_info2.hits)
|
||||
self.assertEqual(cache_info3.misses, cache_info2.misses + 1)
|
||||
|
||||
# AOT test
|
||||
compiled = f.lower(jax.ShapedArray(y.shape, y.dtype)).compile()
|
||||
@ -3130,6 +3138,31 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
# Second call is to trigger C++ dispatch.
|
||||
f(inp) # doesn't crash
|
||||
|
||||
def test_pjit_sin_nested(self):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
|
||||
@pjit
|
||||
def f(x):
|
||||
return jnp.sin(x)
|
||||
|
||||
with mesh:
|
||||
inp = jnp.arange(8.)
|
||||
out = f(inp)
|
||||
self.assertArraysAllClose(out, np.sin(inp))
|
||||
self.assertLen(out.devices(), 8)
|
||||
|
||||
def test_jit_with_mesh_context_manager(self):
|
||||
if not jax.config.jax_jit_pjit_api_merge:
|
||||
self.skipTest("This test only works if jax_jit_pjit_api_merge is True")
|
||||
|
||||
mesh = jtu.create_global_mesh((1,), ('x',))
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"jit does not support using the mesh context manager"):
|
||||
with mesh:
|
||||
jax.jit(lambda x: x, in_axis_resources=P('x'),
|
||||
out_axis_resources=P('x'))(jnp.arange(8))
|
||||
|
||||
|
||||
class TempSharding(Sharding):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user