
Initial commit for the `jax.experimental.compute_on` API.

The currently supported values for `compute_type` are `device_host` and `device`; `device_sparse` will be allowed in a follow-up CL. Using `device_host` means that the device's PJRT client will orchestrate the execution of the computation on the host.

`cpu` as a compute_type is reserved for pure CPU-only computations, without a device's PJRT client orchestrating the computation.
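
A minimal usage sketch (not part of this commit's diff; it mirrors the test added below and assumes the API is importable as `jax.experimental.compute_on.compute_on`):

```python
import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on

@compute_on('device_host')   # this sub-computation is offloaded to the host
@jax.jit
def g(x):
  return x * 2

@jax.jit
def f(x):
  return g(x) * 3            # the surrounding computation stays on the device

print(f(jnp.arange(8)))      # expected: [ 0  6 12 18 24 30 36 42]
```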

PiperOrigin-RevId: 634909918
Yash Katariya 2024-05-17 15:58:25 -07:00 committed by jax authors
parent 7a3fc7113b
commit 2d6d408b19
10 changed files with 281 additions and 24 deletions

@@ -301,6 +301,7 @@ py_library_providing_imports_info(
":cloud_tpu_init",
":compilation_cache_internal",
":compiler",
":compute_on",
":config",
":core",
":custom_api_util",
@@ -711,6 +712,7 @@ pytype_strict_library(
deps = [
":ad_util",
":api_util",
":compute_on",
":config",
":core",
":dtypes",
@@ -778,6 +780,12 @@ pytype_strict_library(
],
)
pytype_strict_library(
name = "compute_on",
srcs = ["_src/compute_on.py"],
deps = [],
)
pytype_strict_library(
name = "layout",
srcs = ["_src/layout.py"],

55
jax/_src/compute_on.py Normal file

@@ -0,0 +1,55 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import threading
from contextlib import contextmanager
class ComputeOnContext(threading.local):
def __init__(self):
self.stack = []
compute_on_context = ComputeOnContext()
@contextmanager
def extend_compute_type(c_type: str):
compute_on_context.stack.append(c_type)
if len(set(compute_on_context.stack)) > 1:
raise NotImplementedError(
'Nesting `compute_on` with different compute types is not supported'
' yet.')
try:
yield compute_on_context.stack[-1]
finally:
compute_on_context.stack.pop()
def current_compute_type() -> str | None:
return compute_on_context.stack[-1] if compute_on_context.stack else None
def _check_valid(c_type: str):
if c_type not in {'device_host', 'device'}:
raise ValueError('Invalid compute type received. Current supported values '
f'are `device_host` and `device`. Got {c_type}')
@contextmanager
def compute_on(compute_type: str):
if not isinstance(compute_type, str):
raise TypeError("`compute_on`'s compute_type argument must be a string.")
_check_valid(compute_type)
with extend_compute_type(compute_type):
yield
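
For reference, a small sketch (not part of the diff) of how the guards above are expected to behave: same-type nesting is allowed, an unknown compute type raises `ValueError`, and mixing compute types raises `NotImplementedError`:

```python
from jax.experimental.compute_on import compute_on

with compute_on('device_host'):
  with compute_on('device_host'):      # nesting the same type is fine
    pass

try:
  with compute_on('gpu'):              # not an accepted value
    pass
except ValueError as e:
  print(e)

try:
  with compute_on('device_host'):
    with compute_on('device'):         # mixed compute types are rejected
      pass
except NotImplementedError as e:
  print(e)
```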

@@ -20,6 +20,7 @@ from collections.abc import (Collection, Generator, Hashable, Iterable,
MutableMapping)
from contextlib import contextmanager
from dataclasses import dataclass
import dataclasses
import functools
from functools import partial, partialmethod, total_ordering
import gc
@@ -259,6 +260,11 @@ def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):
return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
@dataclasses.dataclass(frozen=True)
class JaxprEqnContext:
compute_type: str | None
class JaxprEqn(NamedTuple):
invars: list[Atom]
outvars: list[Var]
@@ -266,6 +272,7 @@ class JaxprEqn(NamedTuple):
params: dict[str, Any]
effects: Effects
source_info: source_info_util.SourceInfo
ctx: JaxprEqnContext
def __repr__(self):
return str(pp_eqn(self, JaxprPpContext(), JaxprPpSettings())).rstrip()
@@ -278,6 +285,7 @@ class JaxprEqn(NamedTuple):
params: dict[str, Any] | None = None,
effects: Effects | None = None,
source_info: source_info_util.SourceInfo | None = None,
ctx: JaxprEqnContext | None = None
):
# It is slightly faster to rebuild the tuple directly than to call _replace.
return JaxprEqn(
@@ -287,16 +295,19 @@ class JaxprEqn(NamedTuple):
self.params if params is None else params,
self.effects if effects is None else effects,
self.source_info if source_info is None else source_info,
self.ctx if ctx is None else ctx,
)
# TODO(mattjj): call typecheck rules here, so we don't form bad eqns
def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None):
def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None,
ctx=None):
source_info = source_info or source_info_util.new_source_info()
ctx = ctx or JaxprEqnContext(None)
if config.enable_checks.value:
assert all(isinstance(x, (Var, Literal)) for x in invars)
assert all(isinstance(v, Var) for v in outvars)
return JaxprEqn(invars, outvars, primitive, params, effects, source_info)
return JaxprEqn(invars, outvars, primitive, params, effects, source_info, ctx)
_var_counter = it.count()
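
An illustrative sketch (not part of the diff; names taken from the changes above) of how the new `JaxprEqnContext` surfaces when staging: equations created while a `compute_on` scope is active record the compute type on `eqn.ctx`:

```python
import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on

@compute_on('device_host')
@jax.jit
def g(x):
  return x + 1

closed = jax.make_jaxpr(g)(jnp.arange(4))
for eqn in closed.jaxpr.eqns:
  print(eqn.primitive, eqn.ctx.compute_type)   # expected: pjit device_host
```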

@@ -694,8 +694,10 @@ class LoweringRuleContext:
tokens_in: TokenSet
tokens_out: TokenSet | None # Mutable store for output containers
axis_size_env: dict[core.Var, ir.Value] | None = None # Dynamic axis sizes
dim_var_values: Sequence[ir.Value] = () # The values for the dimension variables
# in same order as module_context.shape_poly_state.dim_vars
# The values for the dimension variables in same order as
# module_context.shape_poly_state.dim_vars
dim_var_values: Sequence[ir.Value] = ()
compute_type: str | None = None
def set_tokens_out(self, tokens_out: TokenSet):
assert self.tokens_out is None, 'Should only set `tokens_out` once.'
@@ -1565,12 +1567,14 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
tokens_in = tokens.subset(effects)
avals_in = map(aval, eqn.invars)
compute_type = eqn.ctx.compute_type if eqn.ctx is not None else None
rule_ctx = LoweringRuleContext(
module_context=ctx, primitive=eqn.primitive,
name_stack=source_info.name_stack,
avals_in=avals_in,
avals_out=map(aval, eqn.outvars), tokens_in=tokens_in,
tokens_out=None, dim_var_values=dim_var_values)
tokens_out=None, dim_var_values=dim_var_values,
compute_type=compute_type)
if config.dynamic_shapes.value:
axis_size_env = {d: read(d)[0]
for a in avals_in if type(a) is core.DShapedArray

@@ -34,6 +34,7 @@ from jax._src import effects
from jax._src import linear_util as lu
from jax._src import profiler
from jax._src import source_info_util
from jax._src import compute_on
from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs,
fun_sourceinfo)
from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval,
@@ -41,7 +42,7 @@ from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval,
ConcreteArray, Var, DropVar, raise_to_shaped, Atom,
JaxprEqn, Primitive, ShapedArray, DShapedArray,
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
InputType, OutputType, get_referent)
InputType, OutputType, get_referent, JaxprEqnContext)
from jax._src.state.types import AbstractRef
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
KeyPath, generate_key_paths, keystr)
@@ -1252,7 +1253,7 @@ def _partial_eval_jaxpr_custom_cached(
offload_eqn = core.JaxprEqn(
outvars_copy, resvars, device_put_p,
dict(device=TransferToMemoryKind(policy.dst), src=None),
set(), source_info_util.new_source_info())
set(), source_info_util.new_source_info(), JaxprEqnContext(None))
known_eqns.append(offload_eqn)
# resvars are known and available in the backward jaxpr.
map(partial(write, False, True), resvars)
@@ -1260,7 +1261,7 @@ def _partial_eval_jaxpr_custom_cached(
reload_eqn = core.JaxprEqn(
resvars, eqn.outvars, device_put_p, # type: ignore
dict(device=TransferToMemoryKind(policy.src), src=None),
set(), source_info_util.new_source_info())
set(), source_info_util.new_source_info(), JaxprEqnContext(None))
staged_eqns.append(reload_eqn)
# outvars are known and available in the backward jaxpr.
map(partial(write, False, True), eqn.outvars)
@@ -1382,10 +1383,11 @@ def call_partial_eval_custom_rule(
residuals = [newvar(res_aval(params_known, var.aval))
for var in jaxpr_staged.invars[:num_res]]
eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
eqn.primitive, params_known, jaxpr_known.effects, eqn.source_info)
eqn.primitive, params_known, jaxpr_known.effects,
eqn.source_info, eqn.ctx)
eqn_staged = new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged,
eqn.primitive, params_staged,
jaxpr_staged.effects, eqn.source_info)
jaxpr_staged.effects, eqn.source_info, eqn.ctx)
assert len(eqn_staged.invars) == len(jaxpr_staged.invars)
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
if type(x) is Var and not inst]
@@ -1425,11 +1427,11 @@ def closed_call_partial_eval_custom_rule(
eqn_known = new_jaxpr_eqn([*ins_known, *res_ref_binders],
[*out_binders_known, *res_val_binders],
eqn.primitive, params_known, jaxpr_known.effects,
eqn.source_info)
eqn.source_info, eqn.ctx)
eqn_staged = new_jaxpr_eqn([*res_val_vars, *res_ref_binders, *ins_staged],
out_binders_staged,
eqn.primitive, params_staged, jaxpr_staged.effects,
eqn.source_info)
eqn.source_info, eqn.ctx)
assert len(eqn_staged.invars) == len(jaxpr_staged.in_avals)
assert len(ins_known) + len(res_ref_binders) == len(jaxpr_known.jaxpr.invars)
assert len(ins_staged) + len(res_ref_binders) + len(res_val_vars) == len(jaxpr_staged.jaxpr.invars)
@@ -1606,7 +1608,7 @@ def dce_jaxpr_call_rule(used_outputs: list[bool], eqn: JaxprEqn
new_eqn = new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, used_inputs) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info, eqn.ctx)
return used_inputs, new_eqn
dce_rules[core.call_p] = dce_jaxpr_call_rule
@@ -1627,7 +1629,7 @@ def dce_jaxpr_closed_call_rule(used_outputs: list[bool], eqn: JaxprEqn
new_eqn = new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, used_inputs) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, closed_jaxpr.effects, eqn.source_info)
eqn.primitive, new_params, closed_jaxpr.effects, eqn.source_info, eqn.ctx)
return used_inputs, new_eqn
dce_rules[core.closed_call_p] = dce_jaxpr_closed_call_rule
@@ -2026,10 +2028,12 @@ class DynamicJaxprTrace(core.Trace):
" is true. Otherwise it shouldn't.")
out_avals = [out_avals] if not primitive.multiple_results else out_avals
source_info = source_info_util.current()
ctx = core.JaxprEqnContext(compute_on.current_compute_type())
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
invars = map(self.getvar, tracers)
outvars = map(self.makevar, out_tracers)
eqn = new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info)
eqn = new_jaxpr_eqn(invars, outvars, primitive, params, effects,
source_info, ctx)
self.frame.add_eqn(eqn)
return out_tracers if primitive.multiple_results else out_tracers.pop()
@@ -2049,6 +2053,7 @@ class DynamicJaxprTrace(core.Trace):
return core.eval_jaxpr(jaxpr, consts, *in_tracers,
propagate_source_info=False)
source_info = source_info_util.current()
ctx = core.JaxprEqnContext(compute_on.current_compute_type())
out_tracers = []
for aval, _ in out_type:
if type(aval) is DShapedArray:
@@ -2067,7 +2072,7 @@ class DynamicJaxprTrace(core.Trace):
len(consts) + len(implicit_tracers))
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive,
new_params, new_params['call_jaxpr'].effects,
source_info)
source_info, ctx)
self.frame.add_eqn(eqn)
return [t for t, (_, keep) in zip(out_tracers, out_type) if keep]
@@ -2094,6 +2099,7 @@ class DynamicJaxprTrace(core.Trace):
if out_axis is not None else a
for a, out_axis in zip(reduced_out_avals, out_axes)]
source_info = source_info_util.current()
ctx = core.JaxprEqnContext(compute_on.current_compute_type())
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
invars = map(self.getvar, tracers)
constvars = map(self.getvar, map(self.instantiate_const, consts))
@@ -2107,7 +2113,7 @@ class DynamicJaxprTrace(core.Trace):
new_params = update_params(new_params, [True] * len(tracers), len(consts))
effs = core.filter_named_axis_effects(jaxpr.effects, {axis_name})
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive,
new_params, effs, source_info)
new_params, effs, source_info, ctx)
self.frame.add_eqn(eqn)
return out_tracers
@@ -2134,13 +2140,14 @@ class DynamicJaxprTrace(core.Trace):
invars = map(self.getvar, tracers)
constvars = map(self.getvar, map(self.instantiate_const, consts))
outvars = map(self.makevar, out_tracers)
ctx = core.JaxprEqnContext(compute_on.current_compute_type())
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim,
dict(call_jaxpr=closed_fun_jaxpr,
jvp_jaxpr_thunk=jvp_jaxpr_thunk,
num_consts=len(consts),
symbolic_zeros=symbolic_zeros),
fun_jaxpr.effects,
source_info_util.current())
source_info_util.current(), ctx)
self.frame.add_eqn(eqn)
return out_tracers
@@ -2168,6 +2175,7 @@ class DynamicJaxprTrace(core.Trace):
invars = map(self.getvar, tracers)
constvars = map(self.getvar, map(self.instantiate_const, consts))
outvars = map(self.makevar, out_tracers)
ctx = core.JaxprEqnContext(compute_on.current_compute_type())
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
dict(fun_jaxpr=closed_fun_jaxpr,
fwd_jaxpr_thunk=fwd_jaxpr_from_zeros,
@@ -2175,7 +2183,7 @@ class DynamicJaxprTrace(core.Trace):
bwd=bwd, out_trees=out_trees,
symbolic_zeros=symbolic_zeros),
fun_jaxpr.effects,
source_info_util.current())
source_info_util.current(), ctx)
self.frame.add_eqn(eqn)
return out_tracers
@@ -2212,13 +2220,14 @@ class DynamicJaxprTrace(core.Trace):
invars = map(self.getvar, tracers)
constvars = map(self.getvar, map(self.instantiate_const, call_consts))
outvars = map(self.makevar, out_tracers)
ctx = core.JaxprEqnContext(compute_on.current_compute_type())
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim,
dict(call_jaxpr=closed_call_jaxpr,
transpose_jaxpr_thunk=transpose_jaxpr_thunk,
out_types=out_types, res_tree=res_tree,
lin_tree=lin_tree, out_tree=out_tree),
closed_call_jaxpr.effects,
source_info_util.current())
source_info_util.current(), ctx)
self.frame.add_eqn(eqn)
return out_tracers
@@ -2796,9 +2805,10 @@ def inline_jaxpr_into_trace(trace: DynamicJaxprTrace, jaxpr: Jaxpr, consts,
name_stack=source_info.name_stack + eqn.source_info.name_stack)
else:
eqn_source_info = source_info
eqn_ctx = core.JaxprEqnContext(compute_on.current_compute_type())
new_eqn = core.new_jaxpr_eqn(invars, outvars, eqn.primitive, eqn.params,
eqn.effects, eqn_source_info)
eqn.effects, eqn_source_info, eqn_ctx)
trace.frame.add_eqn(new_eqn)
map(write, eqn.outvars, out_tracers)
core.clean_up_dead_vars(eqn, env, lu)

@@ -1817,7 +1817,7 @@ def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
sharding=gspmd_sharding,
unconstrained_dims=unconstrained_dims),
set(),
eqn.source_info))
eqn.source_info, eqn.ctx))
return jaxpr.replace(eqns=new_eqns)
def _flatten_axes(what, tree, axes, tupled_args):

@@ -616,7 +616,6 @@ def _infer_params(jit_info, args, kwargs):
assert (len(in_shardings_flat) == len(in_layouts_flat) ==
len(donated_invars) == len(attrs_tracked) + len(consts) + len(args_flat))
# in_shardings and out_shardings here are all GSPMDSharding.
params = dict(
jaxpr=jaxpr,
in_shardings=in_shardings_flat,
@@ -1681,6 +1680,14 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
mod_ctx.cached_primitive_lowerings[key] = func
return func
def _map_compute_type(c_type):
if c_type == 'device_host':
return 'host'
elif c_type == 'device':
return 'dense'
raise ValueError('Invalid compute type received. Current supported values '
'are `device_host` and `device`')
def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
out_shardings, in_layouts, out_layouts, resource_env,
@@ -1700,6 +1707,10 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
call = func_dialect.CallOp(flat_output_types,
ir.FlatSymbolRefAttr.get(func.name.value),
mlir.flatten_lowering_ir_args(args))
if ctx.compute_type is not None:
dict_attr = {"_xla_compute_type": ir.StringAttr.get(
_map_compute_type(ctx.compute_type))}
call.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr)
out_nodes = unflatten(call.results, map(len, output_types))
tokens, out_nodes = split_list(out_nodes, [len(effects)])
tokens_out = ctx.tokens_in.update_tokens(mlir.TokenSet(zip(effects, tokens)))
@@ -2117,7 +2128,7 @@ def dce_jaxpr_pjit_rule(used_outputs: list[bool], eqn: core.JaxprEqn
new_eqn = core.new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, used_inputs) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, dced_jaxpr.effects, eqn.source_info)
eqn.primitive, new_params, dced_jaxpr.effects, eqn.source_info, eqn.ctx)
return used_inputs, new_eqn
pe.dce_rules[pjit_p] = dce_jaxpr_pjit_rule
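
A hedged way to see the effect of the lowering change above (this mirrors the assertion in the new test rather than documenting exact attribute syntax): the lowered module for a host-offloaded call should mention `_xla_compute_type`:

```python
import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on

@compute_on('device_host')
@jax.jit
def g(x):
  return x * 2

@jax.jit
def f(x):
  return g(x) * 3

lowered_text = f.lower(jnp.arange(8)).as_text()
assert '_xla_compute_type' in lowered_text   # attached via mhlo.frontend_attributes
```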

@@ -0,0 +1,17 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.compute_on import (
compute_on as compute_on,
)

@@ -234,6 +234,9 @@ jax_test(
shard_count = {
"tpu": 5,
},
deps = [
"//jax:experimental",
],
)
jax_test(

@@ -33,6 +33,7 @@ from jax._src.sharding_impls import (NamedSharding, PositionalSharding,
SingleDeviceSharding, GSPMDSharding,
TransferToMemoryKind,
common_devices_indices_map)
from jax.experimental.compute_on import compute_on
import numpy as np
config.parse_flags_with_absl()
@@ -1253,6 +1254,143 @@ class DevicePutTest(jtu.JaxTestCase):
self.assertEqual(out.sharding.memory_kind, 'pinned_host')
class ComputeOffload(jtu.JaxTestCase):
def setUp(self):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Memories do not work on CPU and GPU backends yet.")
super().setUp()
self.orig_memories_flag = config.enable_memories.value
jax.config.update('jax_enable_memories', True)
def tearDown(self):
jax.config.update('jax_enable_memories', self.orig_memories_flag)
super().tearDown()
def test_compute_on_basic(self):
out_s = SingleDeviceSharding(jax.devices()[0], memory_kind='pinned_host')
@compute_on('device_host')
@jax.jit
def g(x):
return x * 2
@jax.jit
def f(x):
y = g(x)
return y * 3
inp = jnp.arange(8)
out = f(inp)
self.assertArraysEqual(out, inp * 6)
lowered_text = f.lower(jnp.arange(8)).as_text()
self.assertIn('_xla_compute_type', lowered_text)
@functools.partial(jax.jit, out_shardings=out_s)
def h(x):
y = g(x)
return y * 3
out2 = h(inp)
self.assertArraysEqual(out, inp * 6)
self.assertEqual(out2.sharding.memory_kind, 'pinned_host')
def test_nested_compute_error(self):
@compute_on('device')
@jax.jit
def f0(x):
return x * 2
@compute_on('device_host')
@jax.jit
def f1(x):
return f0(x)
@jax.jit
def f2(x):
return f1(x)
with self.assertRaisesRegex(
NotImplementedError,
"Nesting `compute_on` with different compute types is not supported"
" yet."):
f2(jnp.arange(8))
# def test_compute_on_grad(self):
# @compute_on('device_host')
# @jax.jit
# def g(x):
# return x * 2
# def f(x):
# y = g(x)
# return jnp.sum(y * 3)
# inp = jnp.arange(8)
# jf = jax.jit(jax.grad(f))
# out = jf(inp)
# print(jax.jit(jax.grad(f)).lower(inp).as_text())
# def test_sharded_compute_on_host(self):
# mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
# s = NamedSharding(mesh, P('x', 'y'))
# np_inp = np.arange(16).reshape(8, 2)
# arr = jax.device_put(np_inp, s)
# @compute_on('device_host')
# @jax.jit
# def g(x):
# return x * 2
# @jax.jit
# def f(x):
# x = x * 3
# return g(x)
# out = f(arr)
# self.assertEqual(out.sharding, s)
# self.assertArraysEqual(out, np_inp * 6)
# def test_nested_no_op_compute(self):
# @compute_on('device_host')
# @jax.jit
# def f0(x):
# return x * 2
# @compute_on('device_host')
# @jax.jit
# def f1(x):
# return f0(x)
# @jax.jit
# def f2(x):
# return f1(x)
# print(f2.lower(jnp.arange(8)).as_text('hlo'))
# out = f2(jnp.arange(8))
# def test_eager_compute(self):
# inp = jnp.arange(8)
# with compute_on('device_host'):
# a = inp * 2
# print(a)
# def test_compute_only_host(self):
# @compute_on('device_host')
# @jax.jit
# def f(x):
# return x * 2
# f(jnp.arange(8))
# def test_per_annotation_wrapper(self):
# @jax.jit
# @compute_on('device_host')
# def f(x):
# return x * 2
# f(jnp.arange(8))
class ActivationOffloadingTest(jtu.JaxTestCase):
def setUp(self):