rocm_jax/tests/core_test.py
Matthew Johnson 7c2f842353 shard_map and other fixes to direct-linearize
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2025-03-07 21:02:40 +00:00

736 lines
23 KiB
Python

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from collections import namedtuple
from functools import partial
import gc
import operator
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax import numpy as jnp
from jax import jvp, linearize, vjp, jit, make_jaxpr
from jax.api_util import flatten_fun_nokwargs, debug_info
from jax._src import config
from jax._src import core
from jax._src import linear_util as lu
from jax._src import util
from jax._src import test_util as jtu
from jax._src.core import ShapedArray, DBIdx
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import control_flow as lax_control_flow
# Presumably syncs JAX's config with absl command-line flags before any test
# runs (inferred from the helper's name; see jax._src.config).
config.parse_flags_with_absl()
# An "unknown" partial value whose abstract value is a float32 scalar.
# NOTE(review): `__` is not referenced anywhere else in this file as shown;
# confirm it is unused before removing it.
__ = pe.PartialVal.unknown(ShapedArray((), np.float32))
def call(f, *args):
  """Apply ``f`` to ``args`` across a ``jit`` boundary.

  The test functions in this file route their inner calls through this
  wrapper, so every such call is staged as a jitted call.
  """
  jitted_f = jit(f)
  return jitted_f(*args)
def core_call(f, *args):
args, in_tree = jax.tree.flatten(args)
dbg = debug_info("core_call_test", f, args, {})
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, debug_info=dbg), in_tree)
out = core.call_p.bind(f, *args)
return jax.tree.unflatten(out_tree(), out)
# call = core_call
core_call = util.curry(core_call)
@util.curry
def core_closed_call(f, *args):
args, in_tree = jax.tree.flatten(args)
dbg = debug_info("core_closed_call_test", f, args, {})
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, debug_info=dbg), in_tree)
out = core.closed_call_p.bind(f, *args)
return jax.tree.unflatten(out_tree(), out)
def simple_fun(x, y):
  """Return ``sin(x * y)``."""
  product = x * y
  return jnp.sin(product)
def simple_fun_fanout(x, y):
  """Return ``sin(x * y) * x``; ``x`` is consumed twice (fan-out)."""
  sine = jnp.sin(x * y)
  return sine * x
def fun_with_call(x):
  """Compute ``sin(x)`` through the file-level ``call`` wrapper."""
  sine_fn = jnp.sin
  return call(sine_fn, x)
def fun_with_nested_calls(x):
  """Nest ``call`` and ``jit`` boundaries whose bodies close over outer values."""
  def outer(y):
    shifted = jnp.sin(y) + 1.0 + (2.0 * x)

    @jit
    def inner(z):
      # Captures `shifted` and `y` from `outer`, and `x` from the top level.
      return shifted * z * x + (x * y)

    return call(inner, y)

  return call(outer, x)
def error(*args):
  """Return a function that fails whenever it is called.

  The returned function raises ``AssertionError`` explicitly instead of using
  ``assert False``: a bare assert is stripped when Python runs with ``-O``,
  which would make the "failing" function silently return ``None``.

  Args:
    *args: Ignored; accepted so callers may pass arbitrary arguments.

  Returns:
    A callable accepting any positional arguments that always raises
    ``AssertionError``.
  """
  del args  # Unused.

  def f(*args):
    raise AssertionError("function returned by error() must not be called")

  return f
def fun_with_nested_calls_2(x):
  """Mix ``call`` with an inner ``jvp`` whose lambdas shadow outer names.

  The lambda parameter names deliberately shadow the enclosing ``x``, ``y``
  and ``w`` to exercise closure handling across call boundaries.
  """
  def bar(y):
    def baz(w):
      acc = call(lambda x: y, x)
      acc = acc + call(lambda: y)
      acc = acc + call(lambda y: w + y, y)
      return call(lambda w: call(jnp.sin, x) * y, 1.0) + acc

    primal_out, tangent_out = jvp(baz, (x + 1.0,), (y,))
    return tangent_out + (x * primal_out)

  return call(bar, x)
def fun_call_jitted(x):
  """Pass ``x`` to a jitted closure that also captures ``x``; returns ``x * x``."""
  @jit
  def scale_by_x(z):
    return x * z

  return call(scale_by_x, x)
def fun_with_two_calls(x):
  """Return ``sin(x) + cos(x)``, each term computed through ``call``."""
  sin_part = call(jnp.sin, x)
  cos_part = call(jnp.cos, x)
  return sin_part + cos_part
def fun_with_call_closure(x):
  """Call a closure over ``x`` with explicit arguments also derived from ``x``."""
  def body(y, z):
    return (x * x) * jnp.sin(y) * z

  cos_x = jnp.cos(x)
  return call(body, x, cos_x) + x
def product_io_fun(x, y):
  """Exercise pytree inputs and outputs.

  Args:
    x: Mapping with keys ``'a'`` and ``'b'``.
    y: Nested pair ``(y1, (y2, y3))``.

  Returns:
    ``(sin(x['a'] + y2), [x['b'], (y1, y3)])``.
  """
  a_val, b_val = x['a'], x['b']
  y1, (y2, y3) = y
  return jnp.sin(a_val + y2), [b_val, (y1, y3)]
# Fixed seed so the randomly generated test inputs are reproducible.
_rng = np.random.RandomState(42)
R = _rng.randn
# A test case: the function under test and the positional args to call it with.
CallSpec = namedtuple('CallSpec', ['fun', 'args'])
# Base cases; each is expanded into several transformed variants further down.
# NOTE: every R(...) call draws from the shared seeded RNG, so the order of
# these entries determines which values each case receives.
test_specs_base = [
    CallSpec(simple_fun, (R(3, 2), R(3, 2))),
    CallSpec(simple_fun_fanout, (R(3, 2), R(3, 2))),
    CallSpec(product_io_fun, ({'a': R(2, 2), 'b': R(2, 2)},
                              (R(2, 2), (R(2, 2), R(2, 2))))),
    CallSpec(fun_with_call, (R(3, 2),)),
    CallSpec(fun_with_two_calls, (R(3, 2),)),
    CallSpec(fun_with_call_closure, (R(3, 2),)),
    CallSpec(fun_call_jitted, (R(1,),)),
    # R() with no arguments yields a Python float scalar input.
    CallSpec(fun_with_nested_calls, (R(),)),
    CallSpec(fun_with_nested_calls, (R(3, 2),)),
    CallSpec(fun_with_nested_calls_2, (R(1, 2),)),
]
def jvp_unlinearized(f, primals, tangents):
  """Compute a JVP by linearizing ``f`` and then applying the linear map.

  Returns:
    ``(f(*primals), f_lin(*tangents))`` where ``f_lin`` is the linear function
    produced by ``jax.linearize`` at ``primals``.
  """
  primal_out, f_lin = linearize(f, *primals)
  tangent_out = f_lin(*tangents)
  return primal_out, tangent_out
# Expand every base case into: the case itself, jvp, jit, jit-of-jit,
# core_call / core_closed_call (alone, around jit, and nested), and the
# linearize-based jvp. The order is fixed per base case.
test_specs = []
for base in test_specs_base:
  test_specs.extend([
      base,
      CallSpec(partial(jvp, base.fun), (base.args, base.args)),
      CallSpec(jit(base.fun), base.args),
      CallSpec(jit(jit(base.fun)), base.args),
      CallSpec(core_call(base.fun), base.args),
      CallSpec(core_call(jit(base.fun)), base.args),
      CallSpec(core_call(core_call(base.fun)), base.args),
      CallSpec(core_closed_call(base.fun), base.args),
      CallSpec(core_closed_call(jit(base.fun)), base.args),
      CallSpec(core_closed_call(core_closed_call(base.fun)), base.args),
      CallSpec(partial(jvp_unlinearized, base.fun),
               (base.args, base.args)),
  ])
def fwd_deriv(f):
  """Return ``df`` computing ``f``'s derivative at a point via forward mode.

  ``df(x)`` pushes a unit tangent ``1.0`` through ``jvp``, so it is intended
  for functions of a single scalar argument.
  """
  def df(x):
    _, tangent_out = jvp(f, (x,), (1.0,))
    return tangent_out

  return df
class CoreTest(jtu.JaxTestCase):
  """Pytree utilities, jit/jvp/vjp/linearize over `test_specs`, and tracing."""

  def test_tree_map(self):
    xs = ({'a': 1}, [2, 3])
    ys = ({'a': 10}, [20, 30])
    ys_bad = ({'a': 10, 'b': 10}, [20, 30])
    zs = ({'a': 11}, [22, 33])
    f = lambda x, y: x + y
    self.assertEqual(jax.tree.map(f, xs, ys), zs)
    # Mismatched structures must be rejected. assertRaises keeps checking this
    # under `python -O`, where a try/`assert False` pattern would pass vacuously.
    with self.assertRaises((TypeError, ValueError)):
      jax.tree.map(f, xs, ys_bad)

  def test_tree_flatten(self):
    flat, _ = jax.tree.flatten(({'a': 1}, [2, 3], 4))
    self.assertEqual(flat, [1, 2, 3, 4])

  def test_tree_unflatten(self):
    tree = [(1, 2), {"roy": (3, [4, 5, ()])}]
    flat, treedef = jax.tree.flatten(tree)
    self.assertEqual(flat, [1, 2, 3, 4, 5])
    tree2 = jax.tree.unflatten(treedef, flat)
    nodes_equal = jax.tree.map(operator.eq, tree, tree2)
    self.assertTrue(jax.tree.reduce(operator.and_, nodes_equal))

  @jtu.sample_product(
      dtype=[*jtu.dtypes.all, object, [('i', 'i4'), ('f', 'f4')]]
  )
  def test_is_valid_jaxtype(self, dtype):
    arr = np.zeros(10, dtype=dtype)
    if dtype in jtu.dtypes.all:
      self.assertTrue(core.valid_jaxtype(arr))
    else:
      self.assertFalse(core.valid_jaxtype(arr))

  @parameterized.named_parameters(
      (str(i), *spec) for i, spec in enumerate(test_specs))
  def test_jit(self, f, args):
    jtu.check_close(jit(f)(*args), f(*args))

  @parameterized.named_parameters(
      (str(i), *spec) for i, spec in enumerate(test_specs))
  def test_jvp(self, f, args):
    jtu.check_jvp(f, partial(jvp, f), args, rtol={np.float32: 3e-2})

  def test_jvp_zeros(self):
    def foo(x):
      def bar(y):
        return jnp.sin(x * y)
      return jvp(bar, (3 * x,), (2 * x,))
    jtu.check_eq(jit(foo)(0.5), foo(0.5))

  # Same parameterization style as the sibling tests, so each case gets a
  # stable, index-based name.
  @parameterized.named_parameters(
      (str(i), *spec) for i, spec in enumerate(test_specs))
  def test_jvp_linearized(self, f, args):
    jtu.check_jvp(f, partial(jvp_unlinearized, f), args,
                  rtol={np.float32: 3e-2})

  @parameterized.named_parameters(
      (str(i), *spec) for i, spec in enumerate(test_specs))
  def test_vjp(self, f, args):
    jtu.check_vjp(f, partial(vjp, f), args,
                  rtol={np.float32: 3e-1, np.float64: 1e-5},
                  atol={np.float32: 1e-2, np.float64: 1e-5})

  def test_jvp_closure(self):
    def foo(x):
      def bar(y):
        return jnp.multiply(x, y)
      return jvp(bar, (3.0,), (1.0,))[1]
    ans = jvp(foo, (1.0,), (2.0,))
    self.assertEqual(ans, (1.0, 2.0))

  def test_jit_closure(self):
    def foo(x):
      @jit
      def bar(y):
        return x + y
      return bar(0.0)
    self.assertEqual(jvp(foo, (1.0,), (2.0,)), (1.0, 2.0))

  def test_simple_jit(self):
    # `foo` branches on a static property (the shape), which jit may inspect.
    def foo(x):
      if x.shape == ():
        return x + 1.
      else:
        return x + 2.
    foo2 = jit(foo)
    foo3 = jit(foo2)
    x1, y1 = np.array(1.0), np.array(2.0)
    self.assertEqual(foo(x1), y1)
    self.assertEqual(foo2(x1), y1)
    self.assertEqual(foo3(x1), y1)
    x2, y2 = np.array([1.0, 2.0]), np.array([3.0, 4.0])
    np.testing.assert_array_equal(foo(x2), y2)
    np.testing.assert_array_equal(foo2(x2), y2)
    np.testing.assert_array_equal(foo3(x2), y2)

  def test_product_jit(self):
    def foo(x, tup):
      y, z = tup
      w = x + z
      return (w, {'x': y}), z
    foo2 = jit(foo)
    foo3 = jit(foo2)
    args = (1.0, (2.0, 3.0))
    expected_output = ((4.0, {'x': 2.0}), 3.0)
    self.assertEqual(foo(*args), expected_output)
    self.assertEqual(foo2(*args), expected_output)
    self.assertEqual(foo3(*args), foo(*args))

  def test_jvp_repeated_fwd(self):
    d_sin = fwd_deriv(jnp.sin)
    d2_sin = fwd_deriv(d_sin)
    d3_sin = fwd_deriv(d2_sin)
    self.assertEqual(d_sin(0.0), 1.0)
    self.assertEqual(d2_sin(0.0), 0.0)
    self.assertEqual(d3_sin(0.0), -1.0)

  @jtu.thread_unsafe_test()  # gc isn't predictable when threaded
  def test_reference_cycles(self):
    if jtu.TEST_NUM_THREADS.value > 1:
      self.skipTest("Test does not work with multiple threads")
    gc.collect()
    def f(x):
      return x.sum()
    fn = partial(linearize, f)
    params = jnp.zeros([])
    debug = gc.get_debug()
    try:
      fn(params)
      gc.set_debug(gc.DEBUG_SAVEALL)
      self.assertEqual(gc.collect(), 0, msg=str(gc.garbage))
    finally:
      gc.set_debug(debug)
      # DEBUG_SAVEALL parks every collected object in gc.garbage instead of
      # freeing it; drop them so a failure here does not leak memory.
      gc.garbage.clear()

  @jtu.thread_unsafe_test()  # gc isn't predictable when threaded
  def test_reference_cycles_jit(self):
    if jtu.TEST_NUM_THREADS.value > 1:
      self.skipTest("Test does not work with multiple threads")
    gc.collect()
    def f(x):
      return x.sum()
    fn = jit(f)
    params = jnp.zeros([])
    debug = gc.get_debug()
    try:
      fn(params).block_until_ready()
      gc.set_debug(gc.DEBUG_SAVEALL)
      self.assertEqual(gc.collect(), 0, msg=str(gc.garbage))
    finally:
      gc.set_debug(debug)
      # See test_reference_cycles: release objects kept by DEBUG_SAVEALL.
      gc.garbage.clear()

  def test_invalid_shape_error_with_jit_tracer_passed(self):
    @jax.jit
    def g_jit(x):
      return jnp.zeros(shape=(2, x))

    @jax.vmap
    def g_vmap(x):
      return jnp.zeros(shape=(2, x))

    with self.assertRaisesRegex(
        TypeError,
        'This concrete value was not available in'
        + ' Python because it depends on',
    ):
      g_jit(1)
    with self.assertRaisesRegex(TypeError,
                                'This BatchTracer with object id'):
      g_vmap(jnp.ones((1, )))

  def test_dropvar_avals(self):
    def f(x):
      def body(c, _):
        return c, None
      # Only the second carry is returned, so the first scan output is dropped.
      (_, x2), _ = jax.lax.scan(body, (x, x), None, length=1)
      return [x2]
    aval = core.ShapedArray((), jnp.dtype('int32'))
    pval = pe.PartialVal.unknown(aval)
    jaxpr, _, _ = pe.trace_to_jaxpr_nounits(
        lu.wrap_init(f,
                     debug_info=debug_info("test", f, (0,), {})),
        [pval], False)
    dropvar, _ = jaxpr.eqns[0].outvars
    self.assertEqual(dropvar.aval, aval)

  def test_input_residual_forwarding(self):
    # https://github.com/jax-ml/jax/pull/11151
    x = jnp.arange(3 * 4.).reshape(3, 4)
    y = jnp.arange(4 * 3.).reshape(4, 3)
    g = jax.jit(jnp.dot)
    def f(y):
      z, g_lin = jax.linearize(lambda y: g(x, y), y)
      zdot = g_lin(y)
      return z, zdot
    jaxpr = jax.make_jaxpr(f)(y)
    e1, _ = jaxpr.jaxpr.eqns
    self.assertLen(e1.outvars, 1)  # only primal out, no residuals
    self.assertEqual(e1.outvars[0].aval.shape, (3, 3))  # only primal out shape
@jtu.with_config(jax_pprint_use_color=False)
class JaxprTypeChecks(jtu.JaxTestCase):
  """Tests for `core.check_jaxpr`, including deliberately corrupted jaxprs."""

  def setUp(self):
    super().setUp()
    # Several tests mutate traced jaxprs in place (e.g. cond branches).
    # Clearing these control-flow caches presumably keeps such a mutated jaxpr
    # from being reused by a later test.
    lax_control_flow._initial_style_open_jaxpr.cache_clear()
    lax_control_flow._initial_style_jaxpr.cache_clear()
    lax_control_flow.common._pad_jaxpr_constvars.cache_clear()

  def test_check_jaxpr_correct(self):
    jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr
    core.check_jaxpr(jaxpr)

  def test_check_jaxpr_cond_correct(self):
    jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(1.).jaxpr
    core.check_jaxpr(jaxpr)

  def test_check_jaxpr_jit_invalid(self):
    jaxpr = make_jaxpr(jax.jit(lambda x, y: x + 1))(1., 2.).jaxpr
    pjit_eqn, = jaxpr.eqns
    # Strip the call's operands so they no longer match the callee's inputs.
    jaxpr._eqns[0] = pjit_eqn.replace(invars=())
    with self.assertRaisesRegex(core.JaxprTypeError,
                                '0 operands cannot call jaxpr with 2 inputs'):
      core.check_jaxpr(jaxpr)

  def test_check_jaxpr_cond_invalid(self):
    jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(1.).jaxpr
    cond = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cond')
    # Make branch 0 take no inputs while branch 1 still takes one.
    cond.params['branches'][0].jaxpr._invars = ()
    with self.assertRaisesRegex(
        core.JaxprTypeError,
        'cond branch 0 takes 0 inputs, branch 1 takes 1'):
      core.check_jaxpr(jaxpr)

  def test_check_jaxpr_scan_correct(self):
    def f(c, x):
      b = jnp.cos(jnp.sum(jnp.sin(x)) + jnp.sum(jnp.cos(c)))
      c = jnp.sin(c * b)
      return c, b
    xs = jnp.ones((5, 3))
    c = jnp.ones(4)
    jaxpr = make_jaxpr(partial(lax.scan, f))(c, xs).jaxpr
    core.check_jaxpr(jaxpr)

  def test_check_jaxpr_invalid_long(self):
    # jaxprs can be large, and this tests that when large ones are printed for
    # context in jaxpr typechecking errors, they're not printed entirely
    def enlarge(f, n):
      def g(x):
        for _ in range(n):
          x = x + x
        x = f(x)
        for _ in range(n):
          x = x + x
        return x
      return g

    jaxpr = make_jaxpr(enlarge(
        lambda x: lax.switch(0, [jnp.sin, jnp.cos], x), 100))(1.).jaxpr
    cond = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cond')
    cond.params['branches'][0].jaxpr._invars = ()
    # assertRaises fails with a clear message if no error is raised, instead
    # of silently continuing with an empty message.
    with self.assertRaises(core.JaxprTypeError) as cm:
      core.check_jaxpr(jaxpr)
    msg, = cm.exception.args
    self.assertIn('cond branch 0 takes 0 inputs, branch 1 takes 1', msg)
    self.assertIn('in equation:', msg)
    self.assertIn('from source:', msg)
    self.assertIn('while checking jaxpr:', msg)
    self.assertLess(msg.count('\n'), 200)

  def test_check_jaxpr_eqn_mismatch(self):
    def f(x):
      return jnp.sin(x) + jnp.cos(x)

    def new_jaxpr():
      return make_jaxpr(f)(jnp.float32(1.)).jaxpr

    # jaxpr is:
    #
    # { lambda ; a.
    #   let b = sin a
    #       c = cos a
    #       d = add b c
    #   in (d,) }
    #
    # NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'

    jaxpr = new_jaxpr()
    # int, not float!
    jaxpr.eqns[0].outvars[0].aval = core.ShapedArray((), jnp.dtype(jnp.int32))
    with self.assertRaisesRegex(
        core.JaxprTypeError,
        r"Value for variable 'b' inconsistently typed as f32\[\] "
        r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin a"):
      core.check_jaxpr(jaxpr)

    jaxpr = new_jaxpr()
    jaxpr.eqns[0].outvars[0].aval = core.ShapedArray((2, 3),
                                                     jnp.dtype(jnp.float32))
    with self.assertRaisesRegex(
        core.JaxprTypeError,
        r"Value for variable 'b' inconsistently typed as f32\[\] "
        r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin a"):
      core.check_jaxpr(jaxpr)

  def test_jaxpr_dropvar_from_jit_call(self):
    def inner(x):
      return x + 1, x + 2

    def f(x):
      _, y = jit(inner)(x)
      return y + 3

    jaxpr = make_jaxpr(f)(1).jaxpr
    self.assertIsInstance(jaxpr.eqns[0].outvars[0], core.DropVar)
    core.check_jaxpr(jaxpr)

  def test_jaxpr_dropvar_from_loop(self):
    def f(x):
      _, y = lax.while_loop(lambda s: s[0] < 0.,
                            lambda s: (jnp.sin(s[0]), jnp.cos(s[1])),
                            (x, x))
      return y + 1.

    jaxpr = make_jaxpr(f)(1.).jaxpr
    self.assertIsInstance(jaxpr.eqns[0].outvars[0], core.DropVar)
    core.check_jaxpr(jaxpr)

  def test_jaxpr_dropvar_from_cond(self):
    def f(x):
      _, y = lax.cond(x < 0.,
                      lambda x: (jnp.sin(x), x + 1.),
                      lambda x: (jnp.cos(x), x + 2.),
                      x)
      return y

    jaxpr = make_jaxpr(f)(1.).jaxpr
    self.assertIsInstance(jaxpr.eqns[-1].outvars[0], core.DropVar)
    core.check_jaxpr(jaxpr)

  def test_jaxpr_undefined_eqn_invar(self):
    jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr
    cos = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cos')
    # Replace the operand with a fresh variable that no equation defines.
    cos.invars[0] = core.gensym(suffix='_test')(cos.invars[0].aval)
    with self.assertRaisesRegex(
        core.JaxprTypeError,
        r"Variable '.+_test' not defined\n\nin equation:"):
      core.check_jaxpr(jaxpr)
@jtu.with_config(jax_dynamic_shapes=True)
class DynamicShapesTest(jtu.JaxTestCase):
  """Tracing and type-checking tests run with `jax_dynamic_shapes=True`.

  Inputs are described by scalar int32 `ShapedArray`s followed by
  `DShapedArray`s whose shapes contain `DBIdx(i)`. The assertions below check
  that such an axis is tied to the i-th input binder of the staged jaxpr.
  """

  def test_staging_basic(self):
    """Stage an identity function over two arrays sharing one dynamic axis."""
    n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
    a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
    b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)

    def f(x, y):
      return x, y

    # `f` takes two arguments but three avals are given; presumably
    # keep_inputs=False means the axis size `n` is not passed to `f`, while
    # (per the assertions below) it is still jaxpr.invars[0]. The trailing
    # `()` pattern requires the fourth result to be empty.
    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(f,
                     debug_info=debug_info("test", f, (1, 2), {})),
        [n, a, b], keep_inputs=[False, True, True])

    self.assertLen(jaxpr.invars, 3)
    # Both inputs and both outputs have shape (invars[0],).
    self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape)
    self.assertEqual((jaxpr.invars[0],), jaxpr.invars[2].aval.shape)

    self.assertLen(jaxpr.outvars, 2)
    self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[0].aval.shape)
    self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[1].aval.shape)

  # NOTE(review): the assertions expect a `core.CallPrimitive` equation with a
  # `call_jaxpr` param; per the skip reason, nested jit (pjit) does not work
  # with DShapedArray, so these expectations may be stale.
  @unittest.skip('This test does not work with nested pjit and DShapedArray')
  def test_staging_nested(self):
    """Stage a function that calls a jitted function on dynamic-shape args."""
    n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
    a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
    b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)

    def f(x, y):
      @jax.jit
      def g(x, y, z, w):
        return (x, w)
      return g(x, y, x, y)

    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(f,
                     debug_info=debug_info("test", f, (0, 1), {})),
        [n, a, b], keep_inputs=[False, True, True])

    self.assertLen(jaxpr.invars, 1 + 2)  # one axis size var, two other inputs
    self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape)
    self.assertEqual((jaxpr.invars[0],), jaxpr.invars[2].aval.shape)

    self.assertLen(jaxpr.outvars, 2)
    self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[0].aval.shape)
    self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[1].aval.shape)

    self.assertLen(jaxpr.eqns, 1)
    eqn = jaxpr.eqns[0]
    self.assertIsInstance(eqn.primitive, core.CallPrimitive)
    inner_jaxpr = eqn.params['call_jaxpr']
    self.assertIsInstance(inner_jaxpr, core.Jaxpr)

    self.assertLen(inner_jaxpr.invars, 1 + 4)  # one axis size var
    self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[1].aval.shape)
    self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[2].aval.shape)
    self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[3].aval.shape)
    self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[4].aval.shape)

  @unittest.skip('This test does not work with nested pjit and DShapedArray')
  def test_staging_nested_including_shape_arg(self):
    """Like test_staging_nested, but also passes x.shape[0] explicitly."""
    n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
    a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
    b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)

    def f(x, y):
      @jax.jit
      def g(_, x, y, z, w):
        return (x, w)
      return g(x.shape[0], x, y, x, y)

    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(f,
                     debug_info=debug_info("test", f, (1, 2), {})),
        [n, a, b], keep_inputs=[False, True, True])

    # { lambda ; a:i32[] b:f32[a] c:f32[a]. let
    #     d:f32[a] e:f32[a] = xla_call[
    #       call_jaxpr={ lambda ; f:i32[] g:i32[] h:f32[f] i:f32[f] j:f32[f] k:f32[f]. let
    #
    #         in (h, k) }
    #       name=g
    #     ] a a b c b c
    #   in (d, e) }

    # NOTE(review): the jaxpr sketched above has six inner binders (f..k),
    # but the assertion below expects 1 + 4; one of the two looks stale.
    self.assertLen(jaxpr.eqns, 1)
    eqn = jaxpr.eqns[0]
    self.assertIsInstance(eqn.primitive, core.CallPrimitive)
    inner_jaxpr = eqn.params['call_jaxpr']
    self.assertIsInstance(inner_jaxpr, core.Jaxpr)

    self.assertLen(inner_jaxpr.invars, 1 + 4)  # one axis size var
    self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[1].aval.shape)
    self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[2].aval.shape)
    self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[3].aval.shape)
    self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[4].aval.shape)

  def test_staging_primitive_applications(self):
    """Stage mul, sin and a full reduction over a dynamic-length axis."""
    n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
    a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
    b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)

    def f(x, y):
      z = lax.mul(x, y)
      w = lax.sin(z)
      u = lax.reduce_sum(w, [0])
      return (u,)

    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(f,
                     debug_info=debug_info("test", f, (1, 2), {})),
        [n, a, b], keep_inputs=[False, True, True])

    self.assertLen(jaxpr.invars, 1 + 2)  # one axis size var, two other inputs
    # One equation each for mul, sin and reduce_sum.
    self.assertLen(jaxpr.eqns, 3)
    # The elementwise mul keeps the dynamic input shape.
    self.assertLen(jaxpr.eqns[0].outvars, 1)
    self.assertEqual(jaxpr.eqns[0].outvars[0].aval.shape,
                     jaxpr.invars[1].aval.shape)

    # Reducing over axis 0 yields a scalar.
    self.assertLen(jaxpr.outvars, 1)
    self.assertEqual(jaxpr.outvars[0].aval.shape, ())

  @unittest.skip('This test does not work with nested pjit and DShapedArray')
  def test_typecheck_staging_nested(self):
    """check_jaxpr should catch operand and let-binder type mismatches."""
    n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
    m = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
    a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
    b = core.DShapedArray((DBIdx(1),), jnp.dtype('float32'), weak_type=False)

    def f(a, b):
      @jax.jit
      def g(x): return x
      return g(a),

    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(f,
                     debug_info=debug_info("test", f, (1, 2), {})),
        [n, m, a, b], keep_inputs=[False, False, True, True])
    # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
    #     e:f32[a] = xla_call[
    #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
    #       name=g
    #     ] a c
    #   in (e,) }

    core.check_jaxpr(jaxpr)  # no problems here...

    # Let's introduce a type error by applying the called jaxpr to arguments
    # with types which aren't consistent with its input binders:
    _, _, c, d = jaxpr.invars
    jaxpr.eqns[0].invars[1] = d
    # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
    #     e:f32[a] = xla_call[
    #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
    #       name=g
    #     ] a d   !!! type error here !!!
    #   in (e,) }
    with self.assertRaisesRegex(TypeError, "passes operand"):
      core.check_jaxpr(jaxpr)

    # Restore the original jaxpr:
    jaxpr.eqns[0].invars[1] = c
    core.check_jaxpr(jaxpr)  # no problems here...

    # Let's introduce another type error by setting the call result let binders
    # to have the wrong type:
    jaxpr.eqns[0].outvars[0] = core.Var('', d.aval)
    # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
    #     e:f32[b] = xla_call[   !!! type error here !!!
    #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
    #       name=g
    #     ] a c
    #   in (h,) }
    with self.assertRaisesRegex(TypeError, "inconsistently typed as"):
      core.check_jaxpr(jaxpr)

  # NOTE(review): this test does not use dynamic shapes, yet it runs under the
  # class-level jax_dynamic_shapes=True config; it may belong in
  # JaxprTypeChecks instead.
  def test_check_jaxpr_key_reuse(self):
    """Reusing a PRNG key under enable_checks should raise KeyReuseError."""
    with config.debug_key_reuse(True):
      def f(seed):
        key = jax.random.key(seed)
        return jax.random.uniform(key) + jax.random.normal(key)
      with jax.enable_checks(True):
        with self.assertRaises(jax.errors.KeyReuseError):
          jax.jit(f)(0)
if __name__ == '__main__':
  # Run the suite with JAX's test loader when executed as a script.
  test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=test_loader)