mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
parent
0ec82f4d62
commit
b05975b964
@ -23,7 +23,7 @@ import jax
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.tree_util import tree_flatten, tree_unflatten, keystr
|
||||
from jax.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
@ -357,8 +357,8 @@ def _dyn_args_fun(fun: Callable, static_argnums: FrozenSet[int],
|
||||
# remat-specific errors.
|
||||
@weakref_lru_cache
|
||||
def _trace_to_jaxpr(fun, in_tree, in_avals):
|
||||
debug = pe.debug_info(fun, in_tree, True, "checkpoint")
|
||||
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
|
||||
debug = pe.debug_info(fun, in_tree, out_tree, True, "checkpoint")
|
||||
try:
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
|
||||
except core.ConcretizationTypeError as e:
|
||||
@ -386,8 +386,12 @@ def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
|
||||
args, kwargs = tree_unflatten(in_tree, args)
|
||||
return f(*args, **kwargs)
|
||||
|
||||
jaxpr = jax.make_jaxpr(lambda *args: jax.linearize(f_, *args)[1])(
|
||||
*in_leaves).jaxpr
|
||||
out = jax.make_jaxpr(lambda *args: jax.linearize(f_, *args)[1],
|
||||
return_shape=True)(*in_leaves)
|
||||
assert isinstance(out, tuple)
|
||||
jaxpr_, out_shape = out
|
||||
jaxpr = jaxpr_.jaxpr
|
||||
out_tree = lambda: tree_structure(out_shape)
|
||||
res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
|
||||
res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)}
|
||||
|
||||
@ -401,7 +405,7 @@ def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
|
||||
results.append((v.aval, 'from a constant'))
|
||||
|
||||
assert len(jaxpr.invars) == len(in_leaves)
|
||||
dbg = pe.debug_info(f, in_tree, True, "saved_residuals")
|
||||
dbg = pe.debug_info(f, in_tree, out_tree, True, "saved_residuals")
|
||||
arg_info = pe.arg_info_all(dbg)
|
||||
for i, v in enumerate(jaxpr.invars):
|
||||
if v in res_vars:
|
||||
|
@ -2909,7 +2909,8 @@ def make_jaxpr(fun: Callable,
|
||||
axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
|
||||
return_shape: bool = False,
|
||||
abstracted_axes: Optional[Any] = None,
|
||||
) -> Callable[..., core.ClosedJaxpr]:
|
||||
) -> Callable[..., Union[core.ClosedJaxpr,
|
||||
Tuple[core.ClosedJaxpr, Any]]]:
|
||||
"""Creates a function that produces its jaxpr given example args.
|
||||
|
||||
Args:
|
||||
@ -2990,7 +2991,7 @@ def make_jaxpr(fun: Callable,
|
||||
in_type = tuple(zip(in_avals, keep_inputs))
|
||||
f, out_tree = flatten_fun(f, in_tree)
|
||||
f = lu.annotate(f, in_type)
|
||||
debug_info = pe.debug_info(fun, in_tree, True, 'make_jaxpr')
|
||||
debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
|
||||
with ExitStack() as stack:
|
||||
for axis_name, size in axis_env or []:
|
||||
stack.enter_context(core.extend_axis_env(axis_name, size, None))
|
||||
@ -3342,7 +3343,7 @@ def eval_shape(fun: Callable, *args, **kwargs):
|
||||
"""
|
||||
args_flat, in_tree = tree_flatten((args, kwargs))
|
||||
wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
|
||||
debug_info = pe.debug_info(fun, in_tree, True, "eval_shape")
|
||||
debug_info = pe.debug_info(fun, in_tree, out_tree, True, "eval_shape")
|
||||
out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
|
||||
*map(shaped_abstractify, args_flat),
|
||||
debug_info=debug_info)
|
||||
|
@ -109,7 +109,9 @@ def apply_flat_fun_nokwargs(fun, io_tree, py_args):
|
||||
ans = fun(*args)
|
||||
return tree_unflatten(out_tree, ans)
|
||||
|
||||
def flattened_fun_in_tree(fn: lu.WrappedFun) -> Optional[Tuple[PyTreeDef, bool]]:
|
||||
def flattened_fun_in_tree(
|
||||
fn: lu.WrappedFun
|
||||
) -> Optional[Tuple[PyTreeDef, Callable[[], PyTreeDef], bool]]:
|
||||
# This implementation relies on internal details of linear_util.py's
|
||||
# WrappedFun, but it's for the worthy cause of better user error messages.
|
||||
# It can fail (i.e. return None) if its WrappedFun argument is not transformed
|
||||
@ -119,14 +121,15 @@ def flattened_fun_in_tree(fn: lu.WrappedFun) -> Optional[Tuple[PyTreeDef, bool]]
|
||||
assert isinstance(flatten_fun, partial) and len(flatten_fun.args) == 1
|
||||
assert (isinstance(flatten_fun_nokwargs, partial) and
|
||||
len(flatten_fun_nokwargs.args) == 1)
|
||||
flat_xforms = {flatten_fun.args[0], flatten_fun_nokwargs.args[0]}
|
||||
flattens = {flatten_fun.args[0], flatten_fun_nokwargs.args[0]}
|
||||
try:
|
||||
(in_tree, has_kwargs), = ((args[0], f is flatten_fun.args[0])
|
||||
for f, args in fn.transforms if f in flat_xforms)
|
||||
((in_tree,), out_tree_store, has_kwargs), = (
|
||||
(args, store, f is flatten_fun.args[0])
|
||||
for (f, args), store in zip(fn.transforms, fn.stores) if f in flattens)
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
return in_tree, has_kwargs
|
||||
return in_tree, lambda: out_tree_store.val, has_kwargs
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def flatten_fun_nokwargs2(in_tree, *args_flat):
|
||||
|
@ -391,7 +391,7 @@ def initial_style_jaxpr(
|
||||
def _initial_style_jaxpr(fun, in_tree, in_avals):
|
||||
# like control_flow._initial_style_jaxpr, but use flatten_fun not _nokwargs
|
||||
fun_, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
|
||||
debug = pe.debug_info(fun, in_tree, False, 'checkify')
|
||||
debug = pe.debug_info(fun, in_tree, out_tree, False, 'checkify')
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
|
||||
return jaxpr, consts, out_tree()
|
||||
|
||||
|
@ -71,6 +71,7 @@ class DebugInfo(NamedTuple):
|
||||
func_src_info: Optional[str] # f'{fun.__name__} at {filename}:{lineno}'
|
||||
signature: Optional[inspect.Signature] # inspect.signature(fun)
|
||||
in_tree: Optional[PyTreeDef] # caller/constructor might not have this info
|
||||
out_tree: Optional[Callable[[], PyTreeDef]] # lazy, not avail at trace time
|
||||
has_kwargs: bool # whether in_tree corresponds to (args, kwargs) or args
|
||||
traced_for: str # "jit", "scan", "make_jaxpr", etc
|
||||
# TODO(mattjj): add input type signature
|
||||
|
@ -65,7 +65,7 @@ class custom_vmap:
|
||||
args_flat, in_tree = tree_flatten(args)
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
|
||||
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
|
||||
debug = pe.debug_info(self.fun, in_tree, False, "custom_vmap")
|
||||
debug = pe.debug_info(self.fun, in_tree, out_tree, False, "custom_vmap")
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
|
||||
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
|
||||
in_tree = treedef_tuple((tree_structure(consts), in_tree))
|
||||
|
@ -668,6 +668,7 @@ def lower_jaxpr_to_module(
|
||||
arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
|
||||
result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
|
||||
arg_names: Optional[Sequence[str]] = None,
|
||||
result_names: Optional[Sequence[str]] = None,
|
||||
) -> LoweringResult:
|
||||
"""Lowers a top-level jaxpr to an MLIR module.
|
||||
|
||||
@ -745,7 +746,7 @@ def lower_jaxpr_to_module(
|
||||
replicated_args=replicated_args,
|
||||
arg_shardings=arg_shardings, result_shardings=result_shardings,
|
||||
input_output_aliases=input_output_aliases,
|
||||
arg_names=arg_names)
|
||||
arg_names=arg_names, result_names=result_names)
|
||||
|
||||
if not ctx.module.operation.verify():
|
||||
module_string = module_to_string(ctx.module)
|
||||
@ -863,6 +864,7 @@ def lower_jaxpr_to_fun(
|
||||
num_output_tokens: int = 0,
|
||||
api_name: str = 'jit',
|
||||
arg_names: Optional[Sequence[str]] = None,
|
||||
result_names: Optional[Sequence[str]] = None,
|
||||
) -> func_dialect.FuncOp:
|
||||
"""Lowers jaxpr and its callees to an IR function.
|
||||
|
||||
@ -992,13 +994,29 @@ def lower_jaxpr_to_fun(
|
||||
func_op.arg_attrs = ir.ArrayAttr.get(
|
||||
[ir.DictAttr.get(attrs) for attrs in arg_attrs])
|
||||
|
||||
result_attrs: List[Dict[str, ir.Attribute]] = [
|
||||
{} for _ in range(len(flat_output_types))]
|
||||
|
||||
if config.jax_jit_pjit_api_merge and result_names:
|
||||
named_result_attrs = result_attrs[num_tokens:]
|
||||
if len(named_result_attrs) == len(result_names):
|
||||
for attrs, name_ in zip(named_result_attrs, result_names):
|
||||
attrs['jax.result_info'] = ir.StringAttr.get(name_)
|
||||
|
||||
if use_sharding_annotations and ir_result_shardings is not None:
|
||||
func_op.result_attrs = ir.ArrayAttr.get([
|
||||
ir.DictAttr.get(
|
||||
{} if sharding is None else
|
||||
{"mhlo.sharding": get_sharding_attr(sharding)}
|
||||
) for sharding in ir_result_shardings
|
||||
])
|
||||
for attrs, sharding in zip(result_attrs, ir_result_shardings):
|
||||
if sharding is not None:
|
||||
attrs['mhlo.sharding'] = get_sharding_attr(sharding)
|
||||
|
||||
# func_op.result_attrs = ir.ArrayAttr.get([
|
||||
# ir.DictAttr.get(
|
||||
# {} if sharding is None else
|
||||
# {"mhlo.sharding": get_sharding_attr(sharding)}
|
||||
# ) for sharding in ir_result_shardings
|
||||
# ])
|
||||
|
||||
func_op.result_attrs = ir.ArrayAttr.get(
|
||||
[ir.DictAttr.get(attrs) for attrs in result_attrs])
|
||||
|
||||
entry_block = func_op.add_entry_block()
|
||||
with ir.InsertionPoint(entry_block):
|
||||
|
@ -3063,10 +3063,7 @@ def lower_sharding_computation(
|
||||
unordered_effects = list(
|
||||
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
|
||||
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
|
||||
arg_info = jaxpr.debug_info and pe.arg_info_all(jaxpr.debug_info)
|
||||
arg_names = None if arg_info is None else [
|
||||
f'{name}{keystr(path)}' for i, (name, path) in enumerate(arg_info)
|
||||
if i in kept_var_idx]
|
||||
arg_names, result_names = _debug_names(jaxpr.debug_info, kept_var_idx)
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
@ -3081,7 +3078,7 @@ def lower_sharding_computation(
|
||||
replicated_args=replicated_args,
|
||||
arg_shardings=in_op_shardings,
|
||||
result_shardings=out_op_shardings,
|
||||
arg_names=arg_names)
|
||||
arg_names=arg_names, result_names=result_names)
|
||||
|
||||
module, keepalive, host_callbacks = (
|
||||
lowering_result.module, lowering_result.keepalive,
|
||||
@ -3254,9 +3251,7 @@ def lower_mesh_computation(
|
||||
closed_jaxpr.effects))
|
||||
ordered_effects = list(effects.ordered_effects.filter_in(
|
||||
closed_jaxpr.effects))
|
||||
arg_info = jaxpr.debug_info and pe.arg_info_all(jaxpr.debug_info)
|
||||
arg_names = None if arg_info is None else [
|
||||
f'{name}{keystr(path)}' for i, (name, path) in enumerate(arg_info)]
|
||||
arg_names, result_names = _debug_names(jaxpr.debug_info)
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
@ -3270,7 +3265,8 @@ def lower_mesh_computation(
|
||||
replicated_args=replicated_args,
|
||||
arg_shardings=in_partitions,
|
||||
result_shardings=out_partitions,
|
||||
arg_names=arg_names)
|
||||
arg_names=arg_names,
|
||||
result_names=result_names)
|
||||
module, keepalive, host_callbacks = (
|
||||
lowering_result.module, lowering_result.keepalive,
|
||||
lowering_result.host_callbacks)
|
||||
@ -3297,6 +3293,17 @@ def lower_mesh_computation(
|
||||
device_assignment=list(mesh.devices.flat),
|
||||
committed=True)
|
||||
|
||||
def _debug_names(
|
||||
dbg: Optional[core.DebugInfo], kept_var_idx: Optional[Set] = None
|
||||
) -> Tuple[Optional[List[str]], Optional[List[str]]]:
|
||||
if dbg is None: return (None, None)
|
||||
arg_info = pe.arg_info_all(dbg)
|
||||
arg_names = None if arg_info is None else [
|
||||
f'{name}{keystr(path)}' for i, (name, path) in enumerate(arg_info)
|
||||
if kept_var_idx is None or i in kept_var_idx]
|
||||
result_info = pe.result_info(dbg)
|
||||
result_names = None if result_info is None else map(keystr, result_info)
|
||||
return arg_names, result_names
|
||||
|
||||
class MeshComputation(stages.XlaLowering):
|
||||
_hlo: Optional[ir.Module]
|
||||
|
@ -49,7 +49,8 @@ def _typecheck_param(prim, param, name, msg_required, pred):
|
||||
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
|
||||
primitive_name: Optional[str] = None):
|
||||
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
debug = pe.debug_info(fun, in_tree, False, primitive_name or "<unknown>")
|
||||
debug = pe.debug_info(fun, in_tree, out_tree, False,
|
||||
primitive_name or "<unknown>")
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
|
||||
return jaxpr, consts, out_tree()
|
||||
|
||||
|
@ -346,7 +346,8 @@ class custom_partitioning:
|
||||
args_flat, in_tree = tree_util.tree_flatten(dyn_args)
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
|
||||
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
|
||||
debug = pe.debug_info(self.fun, in_tree, False, "custom_partitioning")
|
||||
debug = pe.debug_info(self.fun, in_tree, out_tree, False,
|
||||
"custom_partitioning")
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
|
||||
assert not len(consts)
|
||||
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
|
||||
|
@ -2311,7 +2311,7 @@ def bcoo_conv_general_dilated(lhs, rhs, *, window_strides, padding,
|
||||
precision=precision, preferred_element_type=preferred_element_type)
|
||||
jaxpr = jax.make_jaxpr(func)(jax.ShapeDtypeStruct(lhs.shape, lhs.dtype),
|
||||
jax.ShapeDtypeStruct(rhs.shape, rhs.dtype))
|
||||
assert len(jaxpr.eqns) == 1
|
||||
assert isinstance(jaxpr, core.ClosedJaxpr) and len(jaxpr.eqns) == 1
|
||||
params = jaxpr.eqns[0].params
|
||||
|
||||
if params['lhs_dilation'] != (1,) * (lhs.ndim - 2):
|
||||
|
@ -1976,17 +1976,19 @@ def _memoize(thunk):
|
||||
return memoized
|
||||
|
||||
|
||||
def debug_info(fn: Callable, in_tree: Optional[PyTreeDef], has_kwargs: bool,
|
||||
traced_for: str) -> DebugInfo:
|
||||
def debug_info(fn: Callable, in_tree: Optional[PyTreeDef],
|
||||
out_tree_thunk: Optional[Callable[[], PyTreeDef]],
|
||||
has_kwargs: bool, traced_for: str) -> DebugInfo:
|
||||
try: sig = inspect.signature(fn)
|
||||
except (ValueError, TypeError): sig = None
|
||||
src_info = fun_sourceinfo(fn)
|
||||
return DebugInfo(src_info, sig, in_tree, has_kwargs, traced_for)
|
||||
return DebugInfo(src_info, sig, in_tree, out_tree_thunk, has_kwargs,
|
||||
traced_for)
|
||||
|
||||
def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> DebugInfo:
|
||||
"Make a DebugInfo from data available to final-style primitives like pmap."
|
||||
in_tree, has_kwargs = flattened_fun_in_tree(fn) or (None, False)
|
||||
return debug_info(fn.f, in_tree, has_kwargs, traced_for)
|
||||
in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False)
|
||||
return debug_info(fn.f, in_tree, out_tree, has_kws, traced_for)
|
||||
|
||||
def arg_info_all(dbg: DebugInfo) -> Optional[List[Tuple[str, KeyPath]]]:
|
||||
ba = None if dbg.in_tree is None else sig_info(dbg)
|
||||
@ -2006,6 +2008,16 @@ def sig_info(dbg: DebugInfo) -> Optional[inspect.BoundArguments]:
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
def result_info(dbg: DebugInfo) -> Optional[List[KeyPath]]:
  """Return the key path of every output leaf described by ``dbg.out_tree``.

  Args:
    dbg: debug info whose ``out_tree`` field is either None or a zero-argument
      callable producing the output pytree structure.

  Returns:
    One key path per output leaf, in the order ``_generate_key_paths`` yields
    them, or None when ``dbg.out_tree`` is None or the output structure cannot
    be obtained or rebuilt.
  """
  if dbg.out_tree is None:
    return None
  try:
    # Force the thunk once and reuse the result; it may raise if the output
    # tree has not been recorded yet.
    out_tree = dbg.out_tree()
    # The leaf values are irrelevant: only the paths to them are wanted.
    dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves)
  except Exception:
    # Best-effort debug info: any ordinary failure means "no result names".
    # KeyboardInterrupt/SystemExit are deliberately allowed to propagate.
    return None
  return [path for path, _ in _generate_key_paths(dummy_result)]
|
||||
|
||||
@profiler.annotate_function
|
||||
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun,
|
||||
in_avals: Sequence[AbstractValue],
|
||||
|
@ -1161,6 +1161,18 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertIn("kwargs['z']", mhlo_str)
|
||||
self.assertIn("kwargs['w']", mhlo_str)
|
||||
|
||||
def test_jit_lower_result_info(self):
  # Each flattened output of a jitted function should be tagged in the lowered
  # module with a `jax.result_info` attribute holding its key path.
  merged_api = config.jax_array and jax.config.jax_jit_pjit_api_merge
  if not merged_api:
    raise unittest.SkipTest("test only applies after jit-pjit api merge")

  def f(x, y, z):
    # `z` is accepted but does not contribute to the output.
    return dict(a=x, b=[y])

  lowering = jax.jit(f).lower(1., (2,), [3])
  ir_text = str(lowering.compiler_ir('mhlo'))
  self.assertIn("jax.result_info = \"['a']\"", ir_text)
  self.assertIn("jax.result_info = \"['b'][0][0]\"", ir_text)
|
||||
|
||||
def test_jit_enum_as_dict_keys_fails(self):
|
||||
class E(enum.Enum):
|
||||
A = 0
|
||||
|
@ -3476,6 +3476,22 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
# self.assertNotIn("args[0]", mhlo_str)
|
||||
self.assertIn("args[1]", mhlo_str)
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 1)])
def test_jit_nested_xmap_lower_result_info(self):
  # Same check as the plain-jit result_info test, but the traced function also
  # runs an xmap internally and is lowered through pjit.
  merged_api = config.jax_array and jax.config.jax_jit_pjit_api_merge
  if not merged_api:
    raise unittest.SkipTest("test only applies after jit-pjit api merge")

  def f(x, y, z):
    doubled = xmap(lambda x: x * 2, in_axes=['i', ...], out_axes=['i', ...],
                   axis_resources={'i': 'y'})
    # The xmap result is discarded; only x and y reach the output.
    _ = doubled(jnp.arange(8.))
    return dict(a=x, b=[y])

  pjit_f = pjit(f, in_shardings=P(), out_shardings=P())
  lowering = pjit_f.lower(1., (2.,), [3.])
  ir_text = str(lowering.compiler_ir('mhlo'))
  self.assertIn("jax.result_info = \"['a']\"", ir_text)
  self.assertIn("jax.result_info = \"['b'][0][0]\"", ir_text)
|
||||
|
||||
def test_with_sharding_constraint_with_two_meshes(self):
|
||||
if jax.device_count() < 4:
|
||||
self.skipTest("Requires more than 4 devices.")
|
||||
|
Loading…
x
Reference in New Issue
Block a user