mirror of
https://github.com/ROCm/jax.git
synced 2025-04-20 05:46:06 +00:00
Merge pull request #27903 from mattjj:pvary-errors
PiperOrigin-RevId: 746070501
This commit is contained in:
commit
9011d66a29
jax/_src/lax/control_flow
tests
@ -260,3 +260,24 @@ def _show_diff(array1, array2):
|
||||
def _avals_short(avals):
|
||||
to_str = lambda aval: getattr(aval, 'str_short', partial(str, aval))()
|
||||
return ' '.join(map(to_str, avals))
|
||||
|
||||
def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str:
|
||||
assert not core.typematch(a1, a2)
|
||||
if isinstance(a1, core.ShapedArray) and isinstance(a2, core.ShapedArray):
|
||||
mismatches = []
|
||||
if a1.dtype != a2.dtype:
|
||||
mismatches.append('the dtypes do not match')
|
||||
if a1.shape != a2.shape:
|
||||
mismatches.append('the shapes do not match')
|
||||
if a1.vma != a2.vma:
|
||||
mismatches.append('the varying manual axes do not match')
|
||||
# TODO(yashkatariya,mattjj): add check for sharding-in-types mismatch
|
||||
|
||||
if len(mismatches) == 0:
|
||||
return ''
|
||||
elif len(mismatches) == 1:
|
||||
return ', so ' + mismatches[0]
|
||||
else:
|
||||
return ', so ' + ', '.join(mismatches[:-1]) + ', and ' + mismatches[-1]
|
||||
return ''
|
||||
|
||||
|
@ -23,7 +23,9 @@ import itertools
|
||||
import operator
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from jax.tree_util import tree_flatten, tree_unflatten
|
||||
from jax._src.tree_util import (
|
||||
tree_flatten, tree_unflatten, tree_flatten_with_path, keystr,
|
||||
equality_errors_pytreedef)
|
||||
from jax._src import ad_util
|
||||
from jax._src import api_util
|
||||
from jax._src import config
|
||||
@ -44,19 +46,14 @@ from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.lax import lax
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.util import (safe_map, split_list, partition_list)
|
||||
from jax._src.util import safe_map, split_list, partition_list, unzip2
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
import numpy as np
|
||||
|
||||
from jax._src.lax.control_flow.common import (
|
||||
_avals_short,
|
||||
_check_tree_and_avals,
|
||||
_initial_style_jaxprs_with_common_consts,
|
||||
_make_closed_jaxpr,
|
||||
_prune_zeros,
|
||||
_typecheck_param,
|
||||
)
|
||||
_avals_short, _typecheck_param, _aval_mismatch_extra,
|
||||
_initial_style_jaxprs_with_common_consts, _make_closed_jaxpr, _prune_zeros)
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
|
||||
@ -147,10 +144,9 @@ def switch(index, branches: Sequence[Callable], *operands,
|
||||
if config.mutable_array_checks.value:
|
||||
api_util._check_no_aliased_closed_over_refs(dbgs[0], (*jaxprs[0].consts, *consts), ops)
|
||||
for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
|
||||
_check_tree_and_avals("branch 0 output",
|
||||
out_trees[0], jaxprs[0].out_avals,
|
||||
f"branch {i + 1} output",
|
||||
out_tree, jaxpr.out_avals)
|
||||
_check_branch_outputs(
|
||||
"switch", "branch 0", f"branch{i+1}", branches[0], branches[i+1],
|
||||
out_trees[0], out_tree, jaxprs[0].out_avals, jaxpr.out_avals)
|
||||
# prune passthrough outputs
|
||||
fwds = [pe._jaxpr_forwarding(jaxpr.jaxpr) for jaxpr in jaxprs]
|
||||
in_fwd = [xs[0] if len(set(xs)) == 1 else None for xs in zip(*fwds)]
|
||||
@ -270,10 +266,10 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
||||
true_jaxpr.out_avals + false_jaxpr.out_avals):
|
||||
raise ValueError("Cannot return `Ref`s from `cond`.")
|
||||
|
||||
_check_tree_and_avals("true_fun output",
|
||||
out_tree, true_jaxpr.out_avals,
|
||||
"false_fun output",
|
||||
false_out_tree, false_jaxpr.out_avals)
|
||||
_check_branch_outputs(
|
||||
'cond', 'true_fun', 'false_fun', true_fun, false_fun, out_tree,
|
||||
false_out_tree, true_jaxpr.out_avals, false_jaxpr.out_avals)
|
||||
|
||||
# prune passthrough outputs
|
||||
true_fwds = pe._jaxpr_forwarding(true_jaxpr.jaxpr)
|
||||
false_fwds = pe._jaxpr_forwarding(false_jaxpr.jaxpr)
|
||||
@ -303,6 +299,90 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
||||
assert next(out_, None) is None
|
||||
return tree_unflatten(out_tree, out)
|
||||
|
||||
def _check_branch_outputs(
|
||||
api_name, name1, name2, f1, f2, out_tree1, out_tree2, out_avals1,
|
||||
out_avals2) -> None:
|
||||
info1 = api_util.fun_sourceinfo(f1)
|
||||
info2 = api_util.fun_sourceinfo(f2)
|
||||
try:
|
||||
outs1 = tree_unflatten(out_tree1, out_avals1)
|
||||
except:
|
||||
paths = [None] * len(out_avals1)
|
||||
component = lambda _: ''
|
||||
else:
|
||||
leaves_and_paths, _ = tree_flatten_with_path(outs1)
|
||||
paths, _ = unzip2(leaves_and_paths)
|
||||
component = lambda p: f' at path {keystr(p)}' if p else ''
|
||||
|
||||
if out_tree1 != out_tree2:
|
||||
diffs = [f'{name1} output{component(p)} is a {thing1} but '
|
||||
f'{name2} output{component(p)} is a {thing2}, so {expl}'
|
||||
for p, thing1, thing2, expl
|
||||
in equality_errors_pytreedef(out_tree1, out_tree2)]
|
||||
|
||||
if len(diffs) == 0:
|
||||
return # the trees may have different aux data, but structures are same
|
||||
elif len(diffs) == 1:
|
||||
differences = f'{diffs[0]}.\n'
|
||||
else:
|
||||
differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1])
|
||||
+ f' * {diffs[-1]}.\n')
|
||||
|
||||
raise TypeError(
|
||||
f'{api_name} branch outputs must have the same pytree structure, but '
|
||||
'they differ:\n\n'
|
||||
f'{name1} is {info1}\n' + f'{name2} is {info2}\n\n'
|
||||
f'{differences}\n'
|
||||
f'Revise {name1} and/or {name2} so that they have the same pytree '
|
||||
'structure.')
|
||||
|
||||
if not all(map(core.typematch, out_avals1, out_avals2)):
|
||||
diffs = [f'the output of {name1}{component(p)} has type {a1.str_short()}'
|
||||
f' but the corresponding output of {name2} has type '
|
||||
f'{a2.str_short()}{_aval_mismatch_extra(a1, a2)}'
|
||||
for p, a1, a2 in zip(paths, out_avals1, out_avals2)
|
||||
if not core.typematch(a1, a2)]
|
||||
if len(diffs) == 0:
|
||||
return # seems unreachable but in any case we don't have a good error msg
|
||||
elif len(diffs) == 1:
|
||||
differences = f'{_capitalize(diffs[0])}.\n'
|
||||
else:
|
||||
differences = ('\n'.join(f' * {d};' for d in diffs[:-1])
|
||||
+ f'\n * {diffs[-1]}.\n')
|
||||
|
||||
pvary_applications = [
|
||||
f'applying `jax.lax.pvary(..., {tuple(a1.vma - a2.vma)})` '
|
||||
f'to the output of {n}{component(p)}'
|
||||
for p, aval1, aval2 in zip(paths, out_avals1, out_avals2)
|
||||
for n, a1, a2 in [(name1, aval2, aval1), (name2, aval1, aval2)]
|
||||
if not core.typematch(a1, a2) and
|
||||
isinstance(a1, core.ShapedArray) and isinstance(a2, core.ShapedArray)
|
||||
and a1.vma != a2.vma and a2.vma - a1.vma]
|
||||
|
||||
if not pvary_applications:
|
||||
pvary_msg = ''
|
||||
elif len(pvary_applications) == 1:
|
||||
pvary_msg = f'This might be fixed by {pvary_applications[0]}.\n'
|
||||
else:
|
||||
pvary_msg = ('This might be fixed by:\n' +
|
||||
'\n'.join(f' * {d};' for d in pvary_applications[:-1])
|
||||
+ f'\n * {pvary_applications[-1]}.\n')
|
||||
if pvary_msg:
|
||||
pvary_msg += ("See https://docs.jax.dev/en/latest/notebooks/shard_map.html#scan-vma "
|
||||
"for more information.\n\n")
|
||||
|
||||
raise TypeError(
|
||||
f'{api_name} branches must have equal output types but they differ.\n\n'
|
||||
f'{name1} is {info1}\n' + f'{name2} is {info2}\n\n'
|
||||
f'{differences}\n'
|
||||
f'{pvary_msg}'
|
||||
f'Revise {name1} and/or {name2} so that all output types match.')
|
||||
|
||||
|
||||
def _capitalize(s):
|
||||
# s.capitalize() converts s[1:] to lowercase which we don't want.
|
||||
return s[0].capitalize() + s[1:]
|
||||
|
||||
@api_boundary
|
||||
@functools.wraps(_cond)
|
||||
def cond(*args, **kwargs):
|
||||
|
@ -51,7 +51,7 @@ from jax._src.lax import windowed_reductions
|
||||
from jax._src.lax.control_flow.common import (
|
||||
_avals_short, _initial_style_jaxpr,
|
||||
_initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros,
|
||||
_typecheck_param)
|
||||
_typecheck_param, _aval_mismatch_extra)
|
||||
from jax._src.lax.other import logaddexp
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
@ -60,24 +60,12 @@ from jax._src.traceback_util import api_boundary
|
||||
from jax._src.tree_util import equality_errors
|
||||
from jax._src.typing import Array
|
||||
from jax._src.util import (
|
||||
merge_lists,
|
||||
partition_list,
|
||||
safe_map,
|
||||
safe_zip,
|
||||
split_list,
|
||||
split_list_checked,
|
||||
unzip2,
|
||||
weakref_lru_cache,
|
||||
)
|
||||
merge_lists, partition_list, safe_map, safe_zip, split_list,
|
||||
split_list_checked, unzip2, weakref_lru_cache,)
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax.tree_util import (
|
||||
keystr,
|
||||
tree_flatten,
|
||||
tree_flatten_with_path,
|
||||
tree_map,
|
||||
tree_unflatten,
|
||||
treedef_is_leaf,
|
||||
)
|
||||
keystr, tree_flatten, tree_flatten_with_path, tree_map, tree_unflatten,
|
||||
treedef_is_leaf)
|
||||
import numpy as np
|
||||
|
||||
_map = safe_map
|
||||
@ -428,9 +416,8 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
|
||||
for path, thing1, thing2, explanation
|
||||
in equality_errors(in_carry, out_carry)]
|
||||
if len(diffs) == 0:
|
||||
# The trees may have different aux data but structures are the same.
|
||||
return
|
||||
if len(diffs) == 1:
|
||||
return # the trees may have different aux data, but structures are same
|
||||
elif len(diffs) == 1:
|
||||
differences = f'{_capitalize(diffs[0])}.\n'
|
||||
else:
|
||||
differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1])
|
||||
@ -447,32 +434,42 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
|
||||
f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}'
|
||||
for path, in_aval, out_aval in zip(paths, in_avals, out_avals)
|
||||
if not core.typematch(in_aval, out_aval)]
|
||||
|
||||
if len(diffs) == 0:
|
||||
# The trees may have different aux data but structures are the same.
|
||||
return
|
||||
return # seems unreachable but in any case we don't have a good error msg
|
||||
if len(diffs) == 1:
|
||||
differences = f'{_capitalize(diffs[0])}.\n'
|
||||
else:
|
||||
differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1])
|
||||
+ f' * {diffs[-1]}.\n')
|
||||
|
||||
pvary_applications = [
|
||||
f'applying `jax.lax.pvary(..., {tuple(out_aval.vma - in_aval.vma)})` '
|
||||
f'to the initial carry value corresponding to {component(path)}'
|
||||
for path, in_aval, out_aval in zip(paths, in_avals, out_avals)
|
||||
if not core.typematch(in_aval, out_aval) and
|
||||
isinstance(in_aval, ShapedArray) and isinstance(out_aval, ShapedArray)
|
||||
and in_aval.vma != out_aval.vma and out_aval.vma - in_aval.vma]
|
||||
|
||||
if not pvary_applications:
|
||||
pvary_msg = ''
|
||||
elif len(pvary_applications) == 1:
|
||||
pvary_msg = f'This might be fixed by {pvary_applications[0]}.\n'
|
||||
else:
|
||||
pvary_msg = ('This might be fixed by:\n' +
|
||||
'\n'.join(f' * {d};\n' for d in pvary_applications[:-1])
|
||||
+ f' * {pvary_applications[-1]}.\n')
|
||||
if pvary_msg:
|
||||
pvary_msg += ("See https://docs.jax.dev/en/latest/notebooks/shard_map.html#scan-vma "
|
||||
"for more information.\n\n")
|
||||
|
||||
raise TypeError(
|
||||
f"{name} function carry input and carry output must have equal types "
|
||||
"(e.g. shapes and dtypes of arrays), "
|
||||
f"{name} function carry input and carry output must have equal types, "
|
||||
"but they differ:\n\n"
|
||||
f"{differences}\n"
|
||||
"Revise the function so that all output types (e.g. shapes "
|
||||
"and dtypes) match the corresponding input types.")
|
||||
|
||||
def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str:
|
||||
assert not core.typematch(a1, a2)
|
||||
if isinstance(a1, core.ShapedArray) and isinstance(a2, core.ShapedArray):
|
||||
dtype_mismatch = a1.dtype != a2.dtype
|
||||
shape_mismatch = a1.shape != a2.shape
|
||||
return (', so ' * (dtype_mismatch or shape_mismatch) +
|
||||
'the dtypes do not match' * dtype_mismatch +
|
||||
' and also ' * (dtype_mismatch and shape_mismatch) +
|
||||
'the shapes do not match' * shape_mismatch)
|
||||
return ''
|
||||
f"{pvary_msg}"
|
||||
"Revise the function so that all output types match the corresponding "
|
||||
"input types.")
|
||||
|
||||
# TODO(mattjj): re-land #19819 version? simpler, but caused ~1 perf regression.
|
||||
def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
|
||||
|
@ -588,7 +588,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
init = jnp.float32(10)
|
||||
self.assertEqual(fori_loop_with_static_upper_and_lower(init), init)
|
||||
|
||||
|
||||
def testForiLoopBatched(self):
|
||||
def body_fun(i, loop_carry):
|
||||
x, y = loop_carry
|
||||
@ -994,16 +993,24 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
re.escape("Pred must be a scalar, got (1.0, 1.0) of type <class 'tuple'>")):
|
||||
lax.cond((1., 1.), lambda top: 2., lambda fop: 3., 1.)
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
re.compile("true_fun output must have same type structure "
|
||||
"as false_fun output, but there are differences:.*"
|
||||
r"at output\['a'\], true_fun output has pytree leaf", re.DOTALL)):
|
||||
lax.cond(True, lambda top: dict(a=2.), lambda fop: dict(a=(3., 3.)), 1.)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"true_fun output and false_fun output must have identical types, got\n"
|
||||
r"DIFFERENT ShapedArray\(float32\[1\]\) vs. "
|
||||
r"ShapedArray\(float32\[\].*\)."):
|
||||
re.compile(
|
||||
r"cond branch outputs must have the same pytree structure, but they"
|
||||
r" differ:.*true_fun output at path \['a'\] is a pytree leaf but"
|
||||
r" false_fun output at path \['a'\] is a <class 'tuple'>",
|
||||
re.DOTALL)):
|
||||
lax.cond(True, lambda top: dict(a=2.), lambda fop: dict(a=(3., 3.)), 1.)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
re.compile(
|
||||
r"cond branches must have equal output types but they differ.*The"
|
||||
r" output of true_fun has type float32\[1\] but the corresponding"
|
||||
r" output of false_fun has type float32\[\], so the shapes do not"
|
||||
r" match",
|
||||
re.DOTALL)):
|
||||
lax.cond(True,
|
||||
lambda top: jnp.array([1.], jnp.float32),
|
||||
lambda fop: jnp.float32(1.),
|
||||
@ -1023,16 +1030,26 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
re.escape("Empty branch sequence")):
|
||||
lax.switch(0, [], 1.)
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
re.compile("branch 0 output must have same type structure "
|
||||
"as branch 1 output, but there are differences:.*"
|
||||
r"at output\['a'\], branch 0 output has pytree leaf", re.DOTALL)):
|
||||
lax.switch(1, [lambda _: dict(a=2.), lambda _: dict(a=(3., 3.))], 1.)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"branch 0 output and branch 1 output must have identical types, got\n"
|
||||
r"{'a': 'DIFFERENT ShapedArray\(float32\[1\]\) "
|
||||
r"vs. ShapedArray\(float32\[\].*\)'}."):
|
||||
re.compile(
|
||||
"switch branch outputs must have the same pytree structure, but"
|
||||
r" they differ.*branch 0 output at path \['a'\] is a pytree leaf"
|
||||
r" but branch1 output at path \['a'\] is a <class 'tuple'>, so"
|
||||
r" their"
|
||||
" Python types differ.",
|
||||
re.DOTALL)):
|
||||
lax.switch(1, [lambda _: dict(a=2.), lambda _: dict(a=(3., 3.))], 1.)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
re.compile(
|
||||
"switch branches must have equal output types but they differ.*The"
|
||||
r" output of branch 0 at path \['a'\] has type float32\[1\] but the"
|
||||
r" corresponding output of branch1 has type float32\[\], so the"
|
||||
" shapes do not match",
|
||||
re.DOTALL)):
|
||||
lax.switch(1, [lambda _: dict(a=jnp.array([1.], jnp.float32)),
|
||||
lambda _: dict(a=jnp.float32(1.))],
|
||||
1.)
|
||||
@ -1983,7 +2000,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
re.escape("function carry input and carry output must have equal "
|
||||
"types (e.g. shapes and dtypes of arrays), but they differ:\n\n"
|
||||
"types, but they differ:\n\n"
|
||||
"The input carry x has type int32[] but the corresponding "
|
||||
"output carry component has type float32[], so the dtypes do "
|
||||
"not match"
|
||||
@ -1994,7 +2011,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
re.escape("function carry input and carry output must have equal "
|
||||
"types (e.g. shapes and dtypes of arrays), but they differ:\n\n"
|
||||
"types, but they differ:\n\n"
|
||||
"The input carry component x[1] has type int32[] but the "
|
||||
"corresponding output carry component has type float32[], "
|
||||
"so the dtypes do not match"
|
||||
@ -2005,13 +2022,13 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
re.escape("function carry input and carry output must have equal "
|
||||
"types (e.g. shapes and dtypes of arrays), but they differ:\n\n"
|
||||
"types, but they differ:\n\n"
|
||||
" * the input carry component x[0] has type int32[] but the "
|
||||
"corresponding output carry component has type float32[], "
|
||||
"so the dtypes do not match;\n"
|
||||
" * the input carry component x[1] has type int32[] but the "
|
||||
"corresponding output carry component has type float32[1,1], "
|
||||
"so the dtypes do not match and also the shapes do not match."
|
||||
"so the dtypes do not match, and the shapes do not match."
|
||||
)):
|
||||
jax.lax.scan(lambda x, _: ((x[0].astype('float32'),
|
||||
x[1].astype('float32').reshape(1, 1),
|
||||
@ -2495,7 +2512,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
self.assertLess(len(scan_unrolled_hlo), len(scan_fully_unrolled_hlo))
|
||||
|
||||
# and the lowering should contain a while loop, unless the scan is fully
|
||||
# unrolled
|
||||
# unrolled
|
||||
self.assertIn("while(", scan_hlo)
|
||||
self.assertIn("while(", scan_unrolled_hlo)
|
||||
self.assertNotIn("while(", scan_fully_unrolled_hlo)
|
||||
@ -2786,7 +2803,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(deriv(my_pow)(3.0, 1), 1.0, check_dtypes=False)
|
||||
|
||||
|
||||
def test_while_loop_fixed_point_with_batched_pred_and_consts(self):
|
||||
def f(i, x):
|
||||
def cond(carry):
|
||||
@ -3076,7 +3092,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
def leak():
|
||||
data = jax.device_put(np.zeros((1024), dtype=np.float32) + 1)
|
||||
def g():
|
||||
return jax.lax.cond(
|
||||
return jax.lax.cond(
|
||||
True,
|
||||
lambda: data[0], # noqa: F821
|
||||
lambda: data[1], # noqa: F821
|
||||
|
@ -1084,11 +1084,9 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
def f(x, y):
|
||||
def true_fn(x, y):
|
||||
return x
|
||||
return lax.pvary(x, 'y')
|
||||
def false_fun(x, y):
|
||||
return y
|
||||
x = lax.pvary(x, 'y')
|
||||
y = lax.pvary(y, 'x')
|
||||
return lax.pvary(y, 'x')
|
||||
return jax.lax.cond(True, true_fn, false_fun, x, y)
|
||||
|
||||
shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x)
|
||||
@ -2809,10 +2807,10 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
xs = jax.lax.with_sharding_constraint(xs, NamedSharding(mesh, P('i')))
|
||||
|
||||
def fun(v, xs):
|
||||
# Commenting this single line below makes everything work
|
||||
v = jax.scipy.linalg.expm(v)
|
||||
v = v.sum()
|
||||
return v * xs.sum(axis=-1).astype(v.dtype)
|
||||
# Commenting this single line below makes everything work
|
||||
v = jax.scipy.linalg.expm(v)
|
||||
v = v.sum()
|
||||
return v * xs.sum(axis=-1).astype(v.dtype)
|
||||
|
||||
res = fun(variables, xs)
|
||||
fun_shard_map = shard_map(fun, mesh=mesh, in_specs=in_specs, out_specs=out_specs)
|
||||
@ -2831,25 +2829,120 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
mesh = jtu.create_mesh((4,), ('i',))
|
||||
@jax.custom_jvp
|
||||
def f(a: jax.Array, b: jax.Array) -> jax.Array:
|
||||
return a + b
|
||||
return a + b
|
||||
|
||||
@partial(f.defjvp, symbolic_zeros=True)
|
||||
def f_jvp(primals, tangents):
|
||||
a, b = primals
|
||||
a_dot, b_dot = tangents
|
||||
y = f(a, b)
|
||||
y_dot = jnp.zeros_like(y)
|
||||
if not isinstance(a_dot, SymbolicZero):
|
||||
y_dot += a_dot
|
||||
if not isinstance(b_dot, SymbolicZero):
|
||||
y_dot += b_dot
|
||||
return y, y_dot
|
||||
a, b = primals
|
||||
a_dot, b_dot = tangents
|
||||
y = f(a, b)
|
||||
y_dot = jnp.zeros_like(y)
|
||||
if not isinstance(a_dot, SymbolicZero):
|
||||
y_dot += a_dot
|
||||
if not isinstance(b_dot, SymbolicZero):
|
||||
y_dot += b_dot
|
||||
return y, y_dot
|
||||
x = jax.random.normal(jax.random.key(0), (jax.device_count(), 20))
|
||||
A = jax.random.normal(jax.random.key(1), (jax.device_count(), 20))
|
||||
|
||||
g = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))
|
||||
jax.jvp(lambda x: g(x, A), (x,), (x,)) # don't crash
|
||||
|
||||
def test_cond_pvary_errors(self):
|
||||
mesh = jtu.create_mesh((1, 1), ('x', 'y'))
|
||||
def f(x, y):
|
||||
def true_fn(x, y):
|
||||
return x
|
||||
def false_fun(x, y):
|
||||
return y
|
||||
return jax.lax.cond(True, true_fn, false_fun, x, y)
|
||||
x = jnp.arange(4.)
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
r"applying `jax.lax.pvary\(..., \('y',\)\)` to the output of true_fun"):
|
||||
shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x)
|
||||
|
||||
def test_cond_pvary_errors_pytree(self):
|
||||
mesh = jtu.create_mesh((1, 1), ('x', 'y'))
|
||||
|
||||
def f(x, y):
|
||||
def true_fn(x, y):
|
||||
return x, y
|
||||
def false_fun(x, y):
|
||||
return y, x
|
||||
return jax.lax.cond(True, true_fn, false_fun, x, y)
|
||||
x = jnp.arange(4.)
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
r"applying `jax.lax.pvary\(..., \('y',\)\)` to the output of true_fun"):
|
||||
shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x)
|
||||
|
||||
def test_scan_pvary_errors(self):
|
||||
mesh = jtu.create_mesh((1, 1), ('i', 'j'))
|
||||
x = jnp.arange(3.)
|
||||
y = jnp.arange(3.)
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))
|
||||
def f(x, y):
|
||||
def body(carry, _):
|
||||
c1, c2 = carry
|
||||
return (c2, c1), () # swap the carry
|
||||
(x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2)
|
||||
return x_, y_
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
r"This might be fixed by applying `jax.lax.pvary\(..., \('i',\)\)` to"
|
||||
r' the initial'):
|
||||
f(x, y)
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))
|
||||
def g(x, y):
|
||||
def body(carry, _):
|
||||
c1, c2 = carry
|
||||
return (c2, c1), ()
|
||||
y = jax.lax.pvary(y, 'i') # fix the issue
|
||||
(x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2)
|
||||
return x_, y_
|
||||
|
||||
g(x, y) # doesn't crash
|
||||
|
||||
def test_scan_pvary_errors2(self):
|
||||
mesh = jtu.create_mesh((1, 1), ('i', 'j'))
|
||||
x = jnp.arange(3.)
|
||||
y = jnp.arange(3.)
|
||||
z = jnp.arange(3.)
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=(P('i'), P(), P(('i', 'j'))), out_specs=P(('i', 'j')))
|
||||
def f(x, y, z):
|
||||
def body(carry, _):
|
||||
c1, c2, c3 = carry
|
||||
return (c3, c1, c2), () # swap the carry
|
||||
|
||||
# x = jax.lax.pvary(x, 'j')
|
||||
# y = jax.lax.pvary(y, ('i', 'j'))
|
||||
carry, _ = jax.lax.scan(body, (x, y, z), (), length=2)
|
||||
return carry
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
r'This might be fixed by:\n \* applying `jax.lax.pvary\(...,'
|
||||
r" \('j',\)\)`"):
|
||||
f(x, y, z)
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=(P('i'), P(), P(('i', 'j'))), out_specs=P(('i', 'j')))
|
||||
def g(x, y, z):
|
||||
def body(carry, _):
|
||||
c1, c2, c3 = carry
|
||||
return (c3, c1, c2), () # swap the carry
|
||||
|
||||
x = jax.lax.pvary(x, 'j') # fix the issue
|
||||
y = jax.lax.pvary(y, ('i', 'j'))
|
||||
carry, _ = jax.lax.scan(body, (x, y, z), (), length=2)
|
||||
return carry
|
||||
|
||||
g(x, y, z) # doesn't crash
|
||||
|
||||
|
||||
class FunSpec(NamedTuple):
|
||||
name: str
|
||||
|
Loading…
x
Reference in New Issue
Block a user