Mirror of https://github.com/ROCm/jax.git
Initial commit for the jax.experimental.compute_on API.

The currently supported values for compute_type are `device_host` and `device`; `device_sparse` will be allowed in a follow-up CL. Using `device_host` means that the device's PJRT client will orchestrate the execution of the computation on the host. `cpu` as a compute_type is reserved for pure CPU-only computations that run without a device's PJRT client orchestrating the computation.

PiperOrigin-RevId: 634909918
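As a usage sketch, modeled on the tests added in this diff (hedged: end-to-end execution currently only works on the TPU backend, though tracing and lowering work anywhere): `compute_on` is a context manager, so it can also be applied as a decorator.

import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on

# Annotate the inner jitted computation to run on the host; the device's
# PJRT client orchestrates the host execution.
@compute_on('device_host')
@jax.jit
def g(x):
  return x * 2

@jax.jit
def f(x):
  y = g(x)      # offloaded to the host
  return y * 3  # stays on the device

out = f(jnp.arange(8))  # == jnp.arange(8) * 6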
This commit is contained in:
parent 7a3fc7113b, commit 2d6d408b19
@@ -301,6 +301,7 @@ py_library_providing_imports_info(
         ":cloud_tpu_init",
         ":compilation_cache_internal",
         ":compiler",
+        ":compute_on",
         ":config",
         ":core",
         ":custom_api_util",
@@ -711,6 +712,7 @@ pytype_strict_library(
     deps = [
         ":ad_util",
         ":api_util",
+        ":compute_on",
         ":config",
         ":core",
         ":dtypes",
@@ -778,6 +780,12 @@ pytype_strict_library(
     ],
 )
 
+pytype_strict_library(
+    name = "compute_on",
+    srcs = ["_src/compute_on.py"],
+    deps = [],
+)
+
 pytype_strict_library(
     name = "layout",
     srcs = ["_src/layout.py"],
jax/_src/compute_on.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+import threading
+from contextlib import contextmanager
+
+
+class ComputeOnContext(threading.local):
+
+  def __init__(self):
+    self.stack = []
+
+compute_on_context = ComputeOnContext()
+
+
+@contextmanager
+def extend_compute_type(c_type: str):
+  compute_on_context.stack.append(c_type)
+  if len(set(compute_on_context.stack)) > 1:
+    raise NotImplementedError(
+        'Nesting `compute_on` with different compute types is not supported'
+        ' yet.')
+  try:
+    yield compute_on_context.stack[-1]
+  finally:
+    compute_on_context.stack.pop()
+
+def current_compute_type() -> str | None:
+  return compute_on_context.stack[-1] if compute_on_context.stack else None
+
+def _check_valid(c_type: str):
+  if c_type not in {'device_host', 'device'}:
+    raise ValueError('Invalid compute type received. Current supported values '
+                     f'are `device_host` and `device`. Got {c_type}')
+
+@contextmanager
+def compute_on(compute_type: str):
+  if not isinstance(compute_type, str):
+    raise TypeError("`compute_on`'s compute_type argument must be a string.")
+  _check_valid(compute_type)
+
+  with extend_compute_type(compute_type):
+    yield
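A small sketch of how the thread-local stack above behaves, exercising only the module defined in this hunk:

from jax._src import compute_on as co

assert co.current_compute_type() is None        # no active scope
with co.compute_on('device_host'):
  assert co.current_compute_type() == 'device_host'
  with co.compute_on('device_host'):            # same-type nesting is allowed
    assert co.current_compute_type() == 'device_host'
assert co.current_compute_type() is None        # stack fully popped

# Mixing compute types raises as soon as the inner scope is entered:
try:
  with co.compute_on('device_host'):
    with co.compute_on('device'):
      pass
except NotImplementedError as e:
  assert 'Nesting `compute_on`' in str(e)

(Note that the error is raised before `extend_compute_type` reaches its try/finally, so in this initial version a failed inner scope appears to leave a stray entry on the thread-local stack.)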
@@ -20,6 +20,7 @@ from collections.abc import (Collection, Generator, Hashable, Iterable,
                              MutableMapping)
 from contextlib import contextmanager
 from dataclasses import dataclass
+import dataclasses
 import functools
 from functools import partial, partialmethod, total_ordering
 import gc
@@ -259,6 +260,11 @@ def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):
   return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
 
 
+@dataclasses.dataclass(frozen=True)
+class JaxprEqnContext:
+  compute_type: str | None
+
+
 class JaxprEqn(NamedTuple):
   invars: list[Atom]
   outvars: list[Var]
@@ -266,6 +272,7 @@ class JaxprEqn(NamedTuple):
   params: dict[str, Any]
   effects: Effects
   source_info: source_info_util.SourceInfo
+  ctx: JaxprEqnContext
 
   def __repr__(self):
     return str(pp_eqn(self, JaxprPpContext(), JaxprPpSettings())).rstrip()
@@ -278,6 +285,7 @@ class JaxprEqn(NamedTuple):
       params: dict[str, Any] | None = None,
       effects: Effects | None = None,
       source_info: source_info_util.SourceInfo | None = None,
+      ctx: JaxprEqnContext | None = None
   ):
     # It is slightly faster to rebuild the tuple directly than to call _replace.
     return JaxprEqn(
@@ -287,16 +295,19 @@ class JaxprEqn(NamedTuple):
       self.params if params is None else params,
       self.effects if effects is None else effects,
       self.source_info if source_info is None else source_info,
+      self.ctx if ctx is None else ctx,
   )
 
 
 # TODO(mattjj): call typecheck rules here, so we don't form bad eqns
-def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None):
+def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None,
+                  ctx=None):
   source_info = source_info or source_info_util.new_source_info()
+  ctx = ctx or JaxprEqnContext(None)
   if config.enable_checks.value:
     assert all(isinstance(x, (Var, Literal)) for x in invars)
     assert all(isinstance(v, Var) for v in outvars)
-  return JaxprEqn(invars, outvars, primitive, params, effects, source_info)
+  return JaxprEqn(invars, outvars, primitive, params, effects, source_info, ctx)
 
 _var_counter = it.count()
 
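With this change every equation carries a JaxprEqnContext. A quick way to see the recorded compute type on a traced jaxpr, assuming this commit is applied (the tracing path that fills in the context is shown in the partial_eval.py hunks below):

import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on

def f(x):
  with compute_on('device_host'):
    y = x * 2          # traced while a compute type is active
  return y + 1         # traced with no compute type

jaxpr = jax.make_jaxpr(f)(jnp.arange(4))
for eqn in jaxpr.jaxpr.eqns:
  print(eqn.primitive, eqn.ctx.compute_type)
# expected roughly:
#   mul device_host
#   add None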
@@ -694,8 +694,10 @@ class LoweringRuleContext:
   tokens_in: TokenSet
   tokens_out: TokenSet | None  # Mutable store for output containers
   axis_size_env: dict[core.Var, ir.Value] | None = None  # Dynamic axis sizes
-  dim_var_values: Sequence[ir.Value] = ()  # The values for the dimension variables
-                                           # in same order as module_context.shape_poly_state.dim_vars
+  # The values for the dimension variables in same order as
+  # module_context.shape_poly_state.dim_vars
+  dim_var_values: Sequence[ir.Value] = ()
+  compute_type: str | None = None
 
   def set_tokens_out(self, tokens_out: TokenSet):
     assert self.tokens_out is None, 'Should only set `tokens_out` once.'
@@ -1565,12 +1567,14 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
       effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
       tokens_in = tokens.subset(effects)
      avals_in = map(aval, eqn.invars)
+      compute_type = eqn.ctx.compute_type if eqn.ctx is not None else None
       rule_ctx = LoweringRuleContext(
           module_context=ctx, primitive=eqn.primitive,
           name_stack=source_info.name_stack,
           avals_in=avals_in,
           avals_out=map(aval, eqn.outvars), tokens_in=tokens_in,
-          tokens_out=None, dim_var_values=dim_var_values)
+          tokens_out=None, dim_var_values=dim_var_values,
+          compute_type=compute_type)
       if config.dynamic_shapes.value:
         axis_size_env = {d: read(d)[0]
                          for a in avals_in if type(a) is core.DShapedArray
@@ -34,6 +34,7 @@ from jax._src import effects
 from jax._src import linear_util as lu
 from jax._src import profiler
 from jax._src import source_info_util
+from jax._src import compute_on
 from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs,
                                fun_sourceinfo)
 from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval,
@@ -41,7 +42,7 @@ from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval,
                            ConcreteArray, Var, DropVar, raise_to_shaped, Atom,
                            JaxprEqn, Primitive, ShapedArray, DShapedArray,
                            mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
-                           InputType, OutputType, get_referent)
+                           InputType, OutputType, get_referent, JaxprEqnContext)
 from jax._src.state.types import AbstractRef
 from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
                                 KeyPath, generate_key_paths, keystr)
@@ -1252,7 +1253,7 @@ def _partial_eval_jaxpr_custom_cached(
       offload_eqn = core.JaxprEqn(
           outvars_copy, resvars, device_put_p,
           dict(device=TransferToMemoryKind(policy.dst), src=None),
-          set(), source_info_util.new_source_info())
+          set(), source_info_util.new_source_info(), JaxprEqnContext(None))
       known_eqns.append(offload_eqn)
       # resvars are known and available in the backward jaxpr.
       map(partial(write, False, True), resvars)
@@ -1260,7 +1261,7 @@ def _partial_eval_jaxpr_custom_cached(
       reload_eqn = core.JaxprEqn(
           resvars, eqn.outvars, device_put_p,  # type: ignore
           dict(device=TransferToMemoryKind(policy.src), src=None),
-          set(), source_info_util.new_source_info())
+          set(), source_info_util.new_source_info(), JaxprEqnContext(None))
       staged_eqns.append(reload_eqn)
       # outvars are known and available in the backward jaxpr.
       map(partial(write, False, True), eqn.outvars)
@@ -1382,10 +1383,11 @@ def call_partial_eval_custom_rule(
   residuals = [newvar(res_aval(params_known, var.aval))
                for var in jaxpr_staged.invars[:num_res]]
   eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
-                            eqn.primitive, params_known, jaxpr_known.effects, eqn.source_info)
+                            eqn.primitive, params_known, jaxpr_known.effects,
+                            eqn.source_info, eqn.ctx)
   eqn_staged = new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged,
                              eqn.primitive, params_staged,
-                             jaxpr_staged.effects, eqn.source_info)
+                             jaxpr_staged.effects, eqn.source_info, eqn.ctx)
   assert len(eqn_staged.invars) == len(jaxpr_staged.invars)
   new_inst = [x for x, inst in zip(eqn.invars, inst_in)
               if type(x) is Var and not inst]
@@ -1425,11 +1427,11 @@ def closed_call_partial_eval_custom_rule(
   eqn_known = new_jaxpr_eqn([*ins_known, *res_ref_binders],
                             [*out_binders_known, *res_val_binders],
                             eqn.primitive, params_known, jaxpr_known.effects,
-                            eqn.source_info)
+                            eqn.source_info, eqn.ctx)
   eqn_staged = new_jaxpr_eqn([*res_val_vars, *res_ref_binders, *ins_staged],
                              out_binders_staged,
                              eqn.primitive, params_staged, jaxpr_staged.effects,
-                             eqn.source_info)
+                             eqn.source_info, eqn.ctx)
   assert len(eqn_staged.invars) == len(jaxpr_staged.in_avals)
   assert len(ins_known) + len(res_ref_binders) == len(jaxpr_known.jaxpr.invars)
   assert len(ins_staged) + len(res_ref_binders) + len(res_val_vars) == len(jaxpr_staged.jaxpr.invars)
@@ -1606,7 +1608,7 @@ def dce_jaxpr_call_rule(used_outputs: list[bool], eqn: JaxprEqn
     new_eqn = new_jaxpr_eqn(
         [v for v, used in zip(eqn.invars, used_inputs) if used],
        [v for v, used in zip(eqn.outvars, used_outputs) if used],
-        eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
+        eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info, eqn.ctx)
     return used_inputs, new_eqn
 dce_rules[core.call_p] = dce_jaxpr_call_rule
 
@@ -1627,7 +1629,7 @@ def dce_jaxpr_closed_call_rule(used_outputs: list[bool], eqn: JaxprEqn
     new_eqn = new_jaxpr_eqn(
         [v for v, used in zip(eqn.invars, used_inputs) if used],
         [v for v, used in zip(eqn.outvars, used_outputs) if used],
-        eqn.primitive, new_params, closed_jaxpr.effects, eqn.source_info)
+        eqn.primitive, new_params, closed_jaxpr.effects, eqn.source_info, eqn.ctx)
     return used_inputs, new_eqn
 dce_rules[core.closed_call_p] = dce_jaxpr_closed_call_rule
 
@@ -2026,10 +2028,12 @@ class DynamicJaxprTrace(core.Trace):
                " is true. Otherwise it shouldn't.")
     out_avals = [out_avals] if not primitive.multiple_results else out_avals
     source_info = source_info_util.current()
+    ctx = core.JaxprEqnContext(compute_on.current_compute_type())
     out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
     invars = map(self.getvar, tracers)
     outvars = map(self.makevar, out_tracers)
-    eqn = new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info)
+    eqn = new_jaxpr_eqn(invars, outvars, primitive, params, effects,
+                        source_info, ctx)
     self.frame.add_eqn(eqn)
     return out_tracers if primitive.multiple_results else out_tracers.pop()
 
@@ -2049,6 +2053,7 @@ class DynamicJaxprTrace(core.Trace):
       return core.eval_jaxpr(jaxpr, consts, *in_tracers,
                              propagate_source_info=False)
     source_info = source_info_util.current()
+    ctx = core.JaxprEqnContext(compute_on.current_compute_type())
     out_tracers = []
     for aval, _ in out_type:
       if type(aval) is DShapedArray:
@@ -2067,7 +2072,7 @@ class DynamicJaxprTrace(core.Trace):
                   len(consts) + len(implicit_tracers))
     eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive,
                         new_params, new_params['call_jaxpr'].effects,
-                        source_info)
+                        source_info, ctx)
     self.frame.add_eqn(eqn)
     return [t for t, (_, keep) in zip(out_tracers, out_type) if keep]
 
@@ -2094,6 +2099,7 @@ class DynamicJaxprTrace(core.Trace):
                  if out_axis is not None else a
                  for a, out_axis in zip(reduced_out_avals, out_axes)]
     source_info = source_info_util.current()
+    ctx = core.JaxprEqnContext(compute_on.current_compute_type())
     out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
     invars = map(self.getvar, tracers)
     constvars = map(self.getvar, map(self.instantiate_const, consts))
@@ -2107,7 +2113,7 @@ class DynamicJaxprTrace(core.Trace):
     new_params = update_params(new_params, [True] * len(tracers), len(consts))
     effs = core.filter_named_axis_effects(jaxpr.effects, {axis_name})
     eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive,
-                        new_params, effs, source_info)
+                        new_params, effs, source_info, ctx)
     self.frame.add_eqn(eqn)
     return out_tracers
 
@@ -2134,13 +2140,14 @@ class DynamicJaxprTrace(core.Trace):
     invars = map(self.getvar, tracers)
     constvars = map(self.getvar, map(self.instantiate_const, consts))
     outvars = map(self.makevar, out_tracers)
+    ctx = core.JaxprEqnContext(compute_on.current_compute_type())
     eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim,
                         dict(call_jaxpr=closed_fun_jaxpr,
                              jvp_jaxpr_thunk=jvp_jaxpr_thunk,
                              num_consts=len(consts),
                              symbolic_zeros=symbolic_zeros),
                         fun_jaxpr.effects,
-                        source_info_util.current())
+                        source_info_util.current(), ctx)
     self.frame.add_eqn(eqn)
     return out_tracers
 
@@ -2168,6 +2175,7 @@ class DynamicJaxprTrace(core.Trace):
     invars = map(self.getvar, tracers)
     constvars = map(self.getvar, map(self.instantiate_const, consts))
     outvars = map(self.makevar, out_tracers)
+    ctx = core.JaxprEqnContext(compute_on.current_compute_type())
     eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
                         dict(fun_jaxpr=closed_fun_jaxpr,
                              fwd_jaxpr_thunk=fwd_jaxpr_from_zeros,
@@ -2175,7 +2183,7 @@ class DynamicJaxprTrace(core.Trace):
                              bwd=bwd, out_trees=out_trees,
                              symbolic_zeros=symbolic_zeros),
                         fun_jaxpr.effects,
-                        source_info_util.current())
+                        source_info_util.current(), ctx)
     self.frame.add_eqn(eqn)
     return out_tracers
 
@@ -2212,13 +2220,14 @@ class DynamicJaxprTrace(core.Trace):
     invars = map(self.getvar, tracers)
     constvars = map(self.getvar, map(self.instantiate_const, call_consts))
     outvars = map(self.makevar, out_tracers)
+    ctx = core.JaxprEqnContext(compute_on.current_compute_type())
     eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim,
                         dict(call_jaxpr=closed_call_jaxpr,
                              transpose_jaxpr_thunk=transpose_jaxpr_thunk,
                              out_types=out_types, res_tree=res_tree,
                              lin_tree=lin_tree, out_tree=out_tree),
                         closed_call_jaxpr.effects,
-                        source_info_util.current())
+                        source_info_util.current(), ctx)
     self.frame.add_eqn(eqn)
     return out_tracers
 
@@ -2796,9 +2805,10 @@ def inline_jaxpr_into_trace(trace: DynamicJaxprTrace, jaxpr: Jaxpr, consts,
           name_stack=source_info.name_stack + eqn.source_info.name_stack)
     else:
       eqn_source_info = source_info
+    eqn_ctx = core.JaxprEqnContext(compute_on.current_compute_type())
 
     new_eqn = core.new_jaxpr_eqn(invars, outvars, eqn.primitive, eqn.params,
-                                 eqn.effects, eqn_source_info)
+                                 eqn.effects, eqn_source_info, eqn_ctx)
     trace.frame.add_eqn(new_eqn)
     map(write, eqn.outvars, out_tracers)
     core.clean_up_dead_vars(eqn, env, lu)
@@ -1817,7 +1817,7 @@ def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
                sharding=gspmd_sharding,
                unconstrained_dims=unconstrained_dims),
           set(),
-          eqn.source_info))
+          eqn.source_info, eqn.ctx))
   return jaxpr.replace(eqns=new_eqns)
 
 def _flatten_axes(what, tree, axes, tupled_args):
@@ -616,7 +616,6 @@ def _infer_params(jit_info, args, kwargs):
   assert (len(in_shardings_flat) == len(in_layouts_flat) ==
           len(donated_invars) == len(attrs_tracked) + len(consts) + len(args_flat))
 
-  # in_shardings and out_shardings here are all GSPMDSharding.
   params = dict(
       jaxpr=jaxpr,
       in_shardings=in_shardings_flat,
@@ -1681,6 +1680,14 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
   mod_ctx.cached_primitive_lowerings[key] = func
   return func
 
+def _map_compute_type(c_type):
+  if c_type == 'device_host':
+    return 'host'
+  elif c_type == 'device':
+    return 'dense'
+  raise ValueError('Invalid compute type received. Current supported values '
+                   'are `device_host` and `device`')
+
 
 def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
                    out_shardings, in_layouts, out_layouts, resource_env,
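This helper maps the user-facing compute_type names onto the values XLA expects in the `_xla_compute_type` frontend attribute. A quick check, assuming the helper lands in jax._src.pjit as in this hunk:

from jax._src.pjit import _map_compute_type

assert _map_compute_type('device_host') == 'host'
assert _map_compute_type('device') == 'dense'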
@@ -1700,6 +1707,10 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
   call = func_dialect.CallOp(flat_output_types,
                              ir.FlatSymbolRefAttr.get(func.name.value),
                              mlir.flatten_lowering_ir_args(args))
+  if ctx.compute_type is not None:
+    dict_attr = {"_xla_compute_type": ir.StringAttr.get(
+        _map_compute_type(ctx.compute_type))}
+    call.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr)
   out_nodes = unflatten(call.results, map(len, output_types))
   tokens, out_nodes = split_list(out_nodes, [len(effects)])
   tokens_out = ctx.tokens_in.update_tokens(mlir.TokenSet(zip(effects, tokens)))
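The attribute should be observable in the lowered module without running on a TPU, since it is attached at lowering time (this mirrors the `test_compute_on_basic` check added below):

import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on

@compute_on('device_host')
@jax.jit
def g(x):
  return x * 2

@jax.jit
def f(x):
  return g(x) * 3

# The inner call carries mhlo.frontend_attributes with _xla_compute_type = "host".
print('_xla_compute_type' in f.lower(jnp.arange(8)).as_text())  # True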
@@ -2117,7 +2128,7 @@ def dce_jaxpr_pjit_rule(used_outputs: list[bool], eqn: core.JaxprEqn
   new_eqn = core.new_jaxpr_eqn(
       [v for v, used in zip(eqn.invars, used_inputs) if used],
       [v for v, used in zip(eqn.outvars, used_outputs) if used],
-      eqn.primitive, new_params, dced_jaxpr.effects, eqn.source_info)
+      eqn.primitive, new_params, dced_jaxpr.effects, eqn.source_info, eqn.ctx)
   return used_inputs, new_eqn
 
 pe.dce_rules[pjit_p] = dce_jaxpr_pjit_rule
jax/experimental/compute_on.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+# Copyright 2023 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from jax._src.compute_on import (
+    compute_on as compute_on,
+)
@@ -234,6 +234,9 @@ jax_test(
     shard_count = {
         "tpu": 5,
     },
+    deps = [
+        "//jax:experimental",
+    ],
 )
 
 jax_test(
@@ -33,6 +33,7 @@ from jax._src.sharding_impls import (NamedSharding, PositionalSharding,
                                      SingleDeviceSharding, GSPMDSharding,
                                      TransferToMemoryKind,
                                      common_devices_indices_map)
+from jax.experimental.compute_on import compute_on
 import numpy as np
 
 config.parse_flags_with_absl()
@@ -1253,6 +1254,143 @@ class DevicePutTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding.memory_kind, 'pinned_host')
 
 
+class ComputeOffload(jtu.JaxTestCase):
+
+  def setUp(self):
+    if not jtu.test_device_matches(["tpu"]):
+      self.skipTest("Memories do not work on CPU and GPU backends yet.")
+    super().setUp()
+    self.orig_memories_flag = config.enable_memories.value
+    jax.config.update('jax_enable_memories', True)
+
+  def tearDown(self):
+    jax.config.update('jax_enable_memories', self.orig_memories_flag)
+    super().tearDown()
+
+  def test_compute_on_basic(self):
+    out_s = SingleDeviceSharding(jax.devices()[0], memory_kind='pinned_host')
+
+    @compute_on('device_host')
+    @jax.jit
+    def g(x):
+      return x * 2
+
+    @jax.jit
+    def f(x):
+      y = g(x)
+      return y * 3
+
+    inp = jnp.arange(8)
+    out = f(inp)
+    self.assertArraysEqual(out, inp * 6)
+
+    lowered_text = f.lower(jnp.arange(8)).as_text()
+    self.assertIn('_xla_compute_type', lowered_text)
+
+    @functools.partial(jax.jit, out_shardings=out_s)
+    def h(x):
+      y = g(x)
+      return y * 3
+
+    out2 = h(inp)
+    self.assertArraysEqual(out2, inp * 6)
+    self.assertEqual(out2.sharding.memory_kind, 'pinned_host')
+
+  def test_nested_compute_error(self):
+    @compute_on('device')
+    @jax.jit
+    def f0(x):
+      return x * 2
+
+    @compute_on('device_host')
+    @jax.jit
+    def f1(x):
+      return f0(x)
+
+    @jax.jit
+    def f2(x):
+      return f1(x)
+
+    with self.assertRaisesRegex(
+        NotImplementedError,
+        "Nesting `compute_on` with different compute types is not supported"
+        " yet."):
+      f2(jnp.arange(8))
+
+  # def test_compute_on_grad(self):
+  #   @compute_on('device_host')
+  #   @jax.jit
+  #   def g(x):
+  #     return x * 2
+
+  #   def f(x):
+  #     y = g(x)
+  #     return jnp.sum(y * 3)
+
+  #   inp = jnp.arange(8)
+  #   jf = jax.jit(jax.grad(f))
+  #   out = jf(inp)
+  #   print(jax.jit(jax.grad(f)).lower(inp).as_text())
+
+  # def test_sharded_compute_on_host(self):
+  #   mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+  #   s = NamedSharding(mesh, P('x', 'y'))
+  #   np_inp = np.arange(16).reshape(8, 2)
+  #   arr = jax.device_put(np_inp, s)
+
+  #   @compute_on('device_host')
+  #   @jax.jit
+  #   def g(x):
+  #     return x * 2
+
+  #   @jax.jit
+  #   def f(x):
+  #     x = x * 3
+  #     return g(x)
+
+  #   out = f(arr)
+  #   self.assertEqual(out.sharding, s)
+  #   self.assertArraysEqual(out, np_inp * 6)
+
+  # def test_nested_no_op_compute(self):
+  #   @compute_on('device_host')
+  #   @jax.jit
+  #   def f0(x):
+  #     return x * 2
+
+  #   @compute_on('device_host')
+  #   @jax.jit
+  #   def f1(x):
+  #     return f0(x)
+
+  #   @jax.jit
+  #   def f2(x):
+  #     return f1(x)
+
+  #   print(f2.lower(jnp.arange(8)).as_text('hlo'))
+  #   out = f2(jnp.arange(8))
+
+  # def test_eager_compute(self):
+  #   inp = jnp.arange(8)
+  #   with compute_on('device_host'):
+  #     a = inp * 2
+  #     print(a)
+
+  # def test_compute_only_host(self):
+  #   @compute_on('device_host')
+  #   @jax.jit
+  #   def f(x):
+  #     return x * 2
+  #   f(jnp.arange(8))
+
+  # def test_per_annotation_wrapper(self):
+  #   @jax.jit
+  #   @compute_on('device_host')
+  #   def f(x):
+  #     return x * 2
+  #   f(jnp.arange(8))
+
+
 class ActivationOffloadingTest(jtu.JaxTestCase):
 
   def setUp(self):