rocm_jax/tests/core_test.py

736 lines
23 KiB
Python
Raw Normal View History

# Copyright 2018 The JAX Authors.
2018-11-17 18:03:33 -08:00
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
2018-11-17 18:03:33 -08:00
from collections import namedtuple
from functools import partial
import gc
import operator
2018-11-17 18:03:33 -08:00
import numpy as np
2018-11-17 18:03:33 -08:00
from absl.testing import absltest
from absl.testing import parameterized
import jax
2020-06-08 16:13:30 -07:00
from jax import lax
from jax import numpy as jnp
from jax import jvp, linearize, vjp, jit, make_jaxpr
from jax.api_util import flatten_fun_nokwargs, debug_info
from jax._src import config
from jax._src import core
from jax._src import linear_util as lu
2022-05-05 15:18:57 -07:00
from jax._src import util
from jax._src import test_util as jtu
from jax._src.core import ShapedArray, DBIdx
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import control_flow as lax_control_flow
2018-11-17 18:03:33 -08:00
config.parse_flags_with_absl()
__ = pe.PartialVal.unknown(ShapedArray((), np.float32))
2018-11-17 18:03:33 -08:00
def call(f, *args):
return jit(f)(*args)
2022-05-05 15:18:57 -07:00
def core_call(f, *args):
args, in_tree = jax.tree.flatten(args)
dbg = debug_info("core_call_test", f, args, {})
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, debug_info=dbg), in_tree)
2022-05-05 15:18:57 -07:00
out = core.call_p.bind(f, *args)
return jax.tree.unflatten(out_tree(), out)
# call = core_call
core_call = util.curry(core_call)
2022-05-05 15:18:57 -07:00
2022-05-14 11:03:50 -07:00
@util.curry
def core_closed_call(f, *args):
args, in_tree = jax.tree.flatten(args)
dbg = debug_info("core_closed_call_test", f, args, {})
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, debug_info=dbg), in_tree)
2022-05-14 11:03:50 -07:00
out = core.closed_call_p.bind(f, *args)
return jax.tree.unflatten(out_tree(), out)
2022-05-14 11:03:50 -07:00
2018-11-17 18:03:33 -08:00
def simple_fun(x, y):
return jnp.sin(x * y)
2018-11-17 18:03:33 -08:00
def simple_fun_fanout(x, y):
return jnp.sin(x * y) * x
2018-11-17 18:03:33 -08:00
def fun_with_call(x):
return call(jnp.sin, x)
2018-11-17 18:03:33 -08:00
def fun_with_nested_calls(x):
def f(y):
y2 = jnp.sin(y) + 1.0 + (2.0 * x)
2018-11-17 18:03:33 -08:00
@jit
def g(z):
return y2 * z * x + (x * y)
return call(g, y)
return call(f, x)
def error(*args):
def f(*args):
assert False
return f
def fun_with_nested_calls_2(x):
def bar(y):
def baz(w):
q = call(lambda x: y, x)
q = q + call(lambda: y)
q = q + call(lambda y: w + y, y)
q = call(lambda w: call(jnp.sin, x) * y, 1.0) + q
return q
2018-11-17 18:03:33 -08:00
p, t = jvp(baz, (x + 1.0,), (y,))
return t + (x * p)
return call(bar, x)
def fun_call_jitted(x):
@jit
def g(z):
return x * z
return call(g, x)
def fun_with_two_calls(x):
return call(jnp.sin, x) + call(jnp.cos, x)
2018-11-17 18:03:33 -08:00
def fun_with_call_closure(x):
def foo(y, z):
return (x * x) * jnp.sin(y) * z
2018-11-17 18:03:33 -08:00
return call(foo, x, jnp.cos(x)) + x
2018-11-17 18:03:33 -08:00
def product_io_fun(x, y):
xa = x['a']
xb = x['b']
y1, (y2, y3) = y
return jnp.sin(xa + y2), [xb, (y1, y3)]
2018-11-17 18:03:33 -08:00
2020-05-11 12:09:54 -04:00
_rng = np.random.RandomState(42)
R = _rng.randn
CallSpec = namedtuple('CallSpec', ['fun', 'args'])
2018-11-17 18:03:33 -08:00
test_specs_base = [
CallSpec(simple_fun, (R(3, 2), R(3, 2))),
CallSpec(simple_fun_fanout, (R(3, 2), R(3, 2))),
CallSpec(product_io_fun, ({'a': R(2, 2), 'b': R(2, 2)},
2018-11-17 18:03:33 -08:00
(R(2, 2), (R(2, 2), R(2, 2))))),
CallSpec(fun_with_call, (R(3, 2),)),
CallSpec(fun_with_two_calls, (R(3, 2),)),
CallSpec(fun_with_call_closure, (R(3, 2),)),
CallSpec(fun_call_jitted, (R(1,),)),
CallSpec(fun_with_nested_calls, (R(),)),
CallSpec(fun_with_nested_calls, (R(3, 2),)),
CallSpec(fun_with_nested_calls_2, (R(1, 2),)),
2018-11-17 18:03:33 -08:00
]
def jvp_unlinearized(f, primals, tangents):
out, jvp = linearize(f, *primals)
return out, jvp(*tangents)
2018-11-17 18:03:33 -08:00
test_specs = []
for ts in test_specs_base:
test_specs.append(ts)
test_specs.append(CallSpec(partial(jvp, ts.fun), (ts.args, ts.args)))
test_specs.append(CallSpec(jit(ts.fun), ts.args))
test_specs.append(CallSpec(jit(jit(ts.fun)), ts.args))
2022-05-05 15:18:57 -07:00
test_specs.append(CallSpec(core_call(ts.fun), ts.args))
test_specs.append(CallSpec(core_call(jit(ts.fun)), ts.args))
test_specs.append(CallSpec(core_call(core_call(ts.fun)), ts.args))
2022-05-14 11:03:50 -07:00
test_specs.append(CallSpec(core_closed_call(ts.fun), ts.args))
test_specs.append(CallSpec(core_closed_call(jit(ts.fun)), ts.args))
test_specs.append(CallSpec(core_closed_call(core_closed_call(ts.fun)), ts.args))
test_specs.append(CallSpec(partial(jvp_unlinearized, ts.fun),
2018-11-17 18:03:33 -08:00
(ts.args, ts.args)))
def fwd_deriv(f):
def df(x):
return jvp(f, (x,), (1.0,))[1]
return df
class CoreTest(jtu.JaxTestCase):
2022-04-01 14:51:54 -07:00
def test_tree_map(self):
2018-11-17 18:03:33 -08:00
xs = ({'a': 1}, [2, 3])
ys = ({'a': 10}, [20, 30])
ys_bad = ({'a': 10, 'b': 10}, [20, 30])
zs = ({'a': 11}, [22, 33])
f = lambda x, y: x + y
assert jax.tree.map(f, xs, ys) == zs
2018-11-17 18:03:33 -08:00
try:
jax.tree.map(f, xs, ys_bad)
2018-11-17 18:03:33 -08:00
assert False
2019-08-01 17:20:27 -04:00
except (TypeError, ValueError):
2018-11-17 18:03:33 -08:00
pass
def test_tree_flatten(self):
flat, _ = jax.tree.flatten(({'a': 1}, [2, 3], 4))
2018-11-17 18:03:33 -08:00
assert flat == [1, 2, 3, 4]
def test_tree_unflatten(self):
tree = [(1, 2), {"roy": (3, [4, 5, ()])}]
flat, treedef = jax.tree.flatten(tree)
assert flat == [1, 2, 3, 4, 5]
tree2 = jax.tree.unflatten(treedef, flat)
nodes_equal = jax.tree.map(operator.eq, tree, tree2)
assert jax.tree.reduce(operator.and_, nodes_equal)
2023-09-22 13:46:09 -07:00
@jtu.sample_product(
dtype=[*jtu.dtypes.all, object, [('i', 'i4'), ('f', 'f4')]]
)
def test_is_valid_jaxtype(self, dtype):
arr = np.zeros(10, dtype=dtype)
if dtype in jtu.dtypes.all:
self.assertTrue(core.valid_jaxtype(arr))
else:
self.assertFalse(core.valid_jaxtype(arr))
@parameterized.named_parameters(
(str(i), *spec) for i, spec in enumerate(test_specs))
2018-11-17 18:03:33 -08:00
def test_jit(self, f, args):
jtu.check_close(jit(f)(*args), f(*args))
2018-11-17 18:03:33 -08:00
@parameterized.named_parameters(
(str(i), *spec) for i, spec in enumerate(test_specs))
2018-11-17 18:03:33 -08:00
def test_jvp(self, f, args):
jtu.check_jvp(f, partial(jvp, f), args, rtol={np.float32: 3e-2})
2018-11-17 18:03:33 -08:00
def test_jvp_zeros(self):
def foo(x):
def bar(y):
return jnp.sin(x * y)
2018-11-17 18:03:33 -08:00
return jvp(bar, (3 * x,), (2 * x,))
jtu.check_eq(jit(foo)(0.5), foo(0.5))
@parameterized.parameters(test_specs)
def test_jvp_linearized(self, f, args):
jtu.check_jvp(f, partial(jvp_unlinearized, f), args,
rtol={np.float32: 3e-2})
2018-11-17 18:03:33 -08:00
@parameterized.named_parameters(
(str(i), *spec) for i, spec in enumerate(test_specs))
2018-11-17 18:03:33 -08:00
def test_vjp(self, f, args):
jtu.check_vjp(f, partial(vjp, f), args,
rtol={np.float32: 3e-1, np.float64: 1e-5},
atol={np.float32: 1e-2, np.float64: 1e-5})
2018-11-17 18:03:33 -08:00
def test_jvp_closure(self):
def foo(x):
def bar(y):
return jnp.multiply(x, y)
2018-11-17 18:03:33 -08:00
return jvp(bar, (3.0,), (1.0,))[1]
ans = jvp(foo, (1.0,), (2.0,))
assert ans == (1.0, 2.0), ans
def test_jit_closure(self):
def foo(x):
@jit
def bar(y):
return x + y
return bar(0.0)
assert jvp(foo, (1.0,), (2.0,)) == (1.0, 2.0)
def test_simple_jit(self):
def foo(x):
if x.shape == ():
return x + 1.
else:
return x + 2.
foo2 = jit(foo)
foo3 = jit(foo2)
x1, y1 = np.array(1.0), np.array(2.0)
2018-11-17 18:03:33 -08:00
assert foo(x1) == y1
assert foo2(x1) == y1
assert foo3(x1) == y1
x2, y2 = np.array([1.0, 2.0]), np.array([3.0, 4.0])
assert np.all(foo(x2) == y2)
assert np.all(foo2(x2) == y2)
assert np.all(foo3(x2) == y2)
2018-11-17 18:03:33 -08:00
def test_product_jit(self):
def foo(x, tup):
y, z = tup
w = x + z
return (w, {'x': y}), z
foo2 = jit(foo)
foo3 = jit(foo2)
args = (1.0, (2.0, 3.0))
expected_output = ((4.0, {'x': 2.0}), 3.0)
assert foo(*args) == expected_output
assert foo2(*args) == expected_output
assert foo3(*args) == foo(*args)
def test_jvp_repeated_fwd(self):
d_sin = fwd_deriv(jnp.sin)
2018-11-17 18:03:33 -08:00
d2_sin = fwd_deriv(d_sin)
d3_sin = fwd_deriv(d2_sin)
assert d_sin(0.0) == 1.0
assert d2_sin(0.0) == 0.0
assert d3_sin(0.0) == -1.0
@jtu.thread_unsafe_test() # gc isn't predictable when threaded
def test_reference_cycles(self):
if jtu.TEST_NUM_THREADS.value > 1:
self.skipTest("Test does not work with multiple threads")
gc.collect()
def f(x):
return x.sum()
fn = partial(linearize, f)
params = jnp.zeros([])
debug = gc.get_debug()
try:
fn(params)
gc.set_debug(gc.DEBUG_SAVEALL)
self.assertEqual(gc.collect(), 0, msg=str(gc.garbage))
finally:
gc.set_debug(debug)
@jtu.thread_unsafe_test() # gc isn't predictable when threaded
def test_reference_cycles_jit(self):
if jtu.TEST_NUM_THREADS.value > 1:
self.skipTest("Test does not work with multiple threads")
gc.collect()
def f(x):
return x.sum()
fn = jit(f)
params = jnp.zeros([])
debug = gc.get_debug()
try:
fn(params).block_until_ready()
gc.set_debug(gc.DEBUG_SAVEALL)
self.assertEqual(gc.collect(), 0, msg=str(gc.garbage))
finally:
gc.set_debug(debug)
def test_invalid_shape_error_with_jit_tracer_passed(self):
@jax.jit
def g_jit(x):
return jnp.zeros(shape=(2, x))
@jax.vmap
def g_vmap(x):
return jnp.zeros(shape=(2, x))
with self.assertRaisesRegex(
TypeError,
'This concrete value was not available in'
+ ' Python because it depends on',
):
g_jit(1)
with self.assertRaisesRegex(TypeError,
'This BatchTracer with object id'):
g_vmap(jnp.ones((1, )))
def test_dropvar_avals(self):
def f(x):
def body(c, _):
return c, None
(x1, x2), _ = jax.lax.scan(body, (x, x), None, length=1)
return [x2]
aval = core.ShapedArray((), jnp.dtype('int32'))
pval = pe.PartialVal.unknown(aval)
jaxpr, _, _ = pe.trace_to_jaxpr_nounits(
lu.wrap_init(f,
debug_info=debug_info("test", f, (0,), {})),
[pval], False)
dropvar, b = jaxpr.eqns[0].outvars
self.assertEqual(dropvar.aval, aval)
def test_input_residual_forwarding(self):
# https://github.com/jax-ml/jax/pull/11151
x = jnp.arange(3 * 4.).reshape(3, 4)
y = jnp.arange(4 * 3.).reshape(4, 3)
g = jax.jit(jnp.dot)
def f(y):
z, g_lin = jax.linearize(lambda y: g(x, y), y)
zdot = g_lin(y)
return z, zdot
jaxpr = jax.make_jaxpr(f)(y)
e1, e2 = jaxpr.jaxpr.eqns
self.assertLen(e1.outvars, 1) # only primal out, no residuals
self.assertEqual(e1.outvars[0].aval.shape, (3, 3)) # only primal out shape
@jtu.with_config(jax_pprint_use_color=False)
class JaxprTypeChecks(jtu.JaxTestCase):
def setUp(self):
super().setUp()
lax_control_flow._initial_style_open_jaxpr.cache_clear()
lax_control_flow._initial_style_jaxpr.cache_clear()
lax_control_flow.common._pad_jaxpr_constvars.cache_clear()
2020-04-14 22:22:35 -07:00
def test_check_jaxpr_correct(self):
2020-05-21 13:12:59 -07:00
jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr
2020-04-14 22:22:35 -07:00
core.check_jaxpr(jaxpr)
def test_check_jaxpr_cond_correct(self):
jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(1.).jaxpr
core.check_jaxpr(jaxpr)
def test_check_jaxpr_jit_invalid(self):
jaxpr = make_jaxpr(jax.jit(lambda x, y: x + 1))(1., 2.).jaxpr
pjit_eqn, = jaxpr.eqns
jaxpr._eqns[0] = pjit_eqn.replace(invars=())
self.assertRaisesRegex(
core.JaxprTypeError,
'0 operands cannot call jaxpr with 2 inputs',
lambda: core.check_jaxpr(jaxpr))
def test_check_jaxpr_cond_invalid(self):
jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(1.).jaxpr
cond = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cond')
cond.params['branches'][0].jaxpr._invars = ()
self.assertRaisesRegex(
core.JaxprTypeError,
'cond branch 0 takes 0 inputs, branch 1 takes 1',
lambda: core.check_jaxpr(jaxpr))
def test_check_jaxpr_scan_correct(self):
def f(c, x):
b = jnp.cos(jnp.sum(jnp.sin(x)) + jnp.sum(jnp.cos(c)))
c = jnp.sin(c * b)
return c, b
xs = jnp.ones((5, 3))
c = jnp.ones(4)
jaxpr = make_jaxpr(partial(lax.scan, f))(c, xs).jaxpr
core.check_jaxpr(jaxpr)
2020-04-14 22:22:35 -07:00
def test_check_jaxpr_invalid_long(self):
# jaxprs can be large, and this tests that when large ones are printed for
# context in jaxpr typechecking errors, they're not printed entirely
def enlarge(f, n):
def g(x):
for _ in range(n):
x = x + x
x = f(x)
for _ in range(n):
x = x + x
return x
return g
jaxpr = make_jaxpr(enlarge(
lambda x: lax.switch(0, [jnp.sin, jnp.cos], x), 100))(1.).jaxpr
cond = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cond')
cond.params['branches'][0].jaxpr._invars = ()
msg = ''
try:
core.check_jaxpr(jaxpr)
except core.JaxprTypeError as e:
msg, = e.args
self.assertIn('cond branch 0 takes 0 inputs, branch 1 takes 1', msg)
self.assertIn('in equation:', msg)
self.assertIn('from source:', msg)
self.assertIn('while checking jaxpr:', msg)
self.assertLess(msg.count('\n'), 200)
2020-04-14 22:22:35 -07:00
def test_check_jaxpr_eqn_mismatch(self):
def f(x):
2020-05-21 13:12:59 -07:00
return jnp.sin(x) + jnp.cos(x)
2020-04-14 22:22:35 -07:00
2020-04-15 18:38:25 -07:00
def new_jaxpr():
return make_jaxpr(f)(jnp.float32(1.)).jaxpr
2020-04-15 18:38:25 -07:00
# jaxpr is:
#
# { lambda ; a.
# let b = sin a
# c = cos a
# d = add b c
# in (d,) }
#
# NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'
jaxpr = new_jaxpr()
# int, not float!
jaxpr.eqns[0].outvars[0].aval = core.ShapedArray((), jnp.dtype(jnp.int32))
2020-06-24 15:31:33 -07:00
self.assertRaisesRegex(
core.JaxprTypeError,
r"Value for variable 'b' inconsistently typed as f32\[\] "
r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin a",
2020-06-24 15:31:33 -07:00
lambda: core.check_jaxpr(jaxpr))
2020-04-14 22:22:35 -07:00
2020-04-15 18:38:25 -07:00
jaxpr = new_jaxpr()
jaxpr.eqns[0].outvars[0].aval = core.ShapedArray((2, 3),
jnp.dtype(jnp.float32))
2020-06-24 15:31:33 -07:00
self.assertRaisesRegex(
core.JaxprTypeError,
r"Value for variable 'b' inconsistently typed as f32\[\] "
r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin a",
2020-06-24 15:31:33 -07:00
lambda: core.check_jaxpr(jaxpr))
2020-04-14 22:22:35 -07:00
2020-06-08 16:13:30 -07:00
def test_jaxpr_dropvar_from_jit_call(self):
def inner(x):
return x + 1, x + 2
def f(x):
_, y = jit(inner)(x)
return y + 3
jaxpr = make_jaxpr(f)(1).jaxpr
assert isinstance(jaxpr.eqns[0].outvars[0], core.DropVar)
2020-06-08 16:13:30 -07:00
core.check_jaxpr(jaxpr)
def test_jaxpr_dropvar_from_loop(self):
def f(x):
_, y = lax.while_loop(lambda s: s[0] < 0.,
lambda s: (jnp.sin(s[0]), jnp.cos(s[1])),
(x, x))
return y + 1.
jaxpr = make_jaxpr(f)(1.).jaxpr
assert isinstance(jaxpr.eqns[0].outvars[0], core.DropVar)
2020-06-08 16:13:30 -07:00
core.check_jaxpr(jaxpr)
def test_jaxpr_dropvar_from_cond(self):
def f(x):
_, y = lax.cond(x < 0.,
lambda x: (jnp.sin(x), x + 1.),
lambda x: (jnp.cos(x), x + 2.),
x)
return y
jaxpr = make_jaxpr(f)(1.).jaxpr
assert isinstance(jaxpr.eqns[-1].outvars[0], core.DropVar)
2020-06-08 16:13:30 -07:00
core.check_jaxpr(jaxpr)
2020-04-14 22:22:35 -07:00
def test_jaxpr_undefined_eqn_invar(self):
jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr
cos = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cos')
Reland https://github.com/google/jax/pull/10573. The original PR was reverted because of downstream breakage. Originally we used the `Var.count` attribute to ensure `Var` instances were printed consistently regardless of context, even though only their object id was load-bearing. That is, `Var.count` was only used for pretty printing. (#1949 added a total_ordering on `Var` for reasons out of scope of JAX's core code. I'm going to figure out if that's still needed... Haiku tests all seem to pass without it.) But #8019 revised our pretty-printing so as not to use `Var.count`. Instead it chose how to pretty-print Var instances based on their order of appearance in a jaxpr. That meant `Var.count` really wasn't useful anymore. So this PR removes `Var.count`. Since we no longer have `Var.count`, we also don't need core.gensym to take an optional sequence of jaxprs, since that was just used to set the starting count index for new `Var`s. In fact, `Var.__repr__` and `JaxprEqn.__repr__` were made confusing after #8019, since they could print variable names totally different from the names that would appear when the same `JaxprEqn` or `Var` objects were printed as part of a jaxpr. That is, before this PR we might have a jaxpr which printed like: ``` import jax def f(x): for _ in range(3): x = jax.numpy.sin(x) return x jaxpr = jax.make_jaxpr(f)(3.) print(jaxpr) # { lambda ; a:f32[]. let # b:f32[] = sin a # c:f32[] = sin b # d:f32[] = sin c # in (d,) } _, eqn, _ = jaxpr.jaxpr.eqns print(eqn) # a:f32[] = sin b ``` Notice the variable names in the equation pretty-print don't correspond to any in the jaxpr pretty-print! So this PR changes `JaxprEqn.__repr__` and `Var.__repr__` to show `Var` object ids, and in general just do less formatting (which seems consistent with the spirit of `__repr__`): ``` JaxprEqn(invars=[Var(id=140202705341552):float32[]], outvars=[Var(id=140202705339584):float32[]], primitive=sin, params={}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f837c73d770>, name_stack=NameStack(stack=()))) ``` PiperOrigin-RevId: 607664497
2024-02-16 05:56:45 -08:00
cos.invars[0] = core.gensym(suffix='_test')(cos.invars[0].aval)
self.assertRaisesRegex(
core.JaxprTypeError,
r"Variable '.+_test' not defined\n\nin equation:",
lambda: core.check_jaxpr(jaxpr))
@jtu.with_config(jax_dynamic_shapes=True)
class DynamicShapesTest(jtu.JaxTestCase):
def test_staging_basic(self):
n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
def f(x, y):
return x, y
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f,
debug_info=debug_info("test", f, (1, 2), {})),
[n, a, b], keep_inputs=[False, True, True])
self.assertLen(jaxpr.invars, 3)
self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape)
self.assertEqual((jaxpr.invars[0],), jaxpr.invars[2].aval.shape)
self.assertLen(jaxpr.outvars, 2)
self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[0].aval.shape)
self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[1].aval.shape)
@unittest.skip('This test does not work with nested pjit and DShapedArray')
def test_staging_nested(self):
n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
def f(x, y):
@jax.jit
def g(x, y, z, w):
return (x, w)
return g(x, y, x, y)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f,
debug_info=debug_info("test", f, (0, 1), {})),
[n, a, b], keep_inputs=[False, True, True])
self.assertLen(jaxpr.invars, 1 + 2) # one axis size var, two other inputs
self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape)
self.assertEqual((jaxpr.invars[0],), jaxpr.invars[2].aval.shape)
self.assertLen(jaxpr.outvars, 2)
self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[0].aval.shape)
self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[1].aval.shape)
self.assertLen(jaxpr.eqns, 1)
eqn = jaxpr.eqns[0]
self.assertIsInstance(eqn.primitive, core.CallPrimitive)
inner_jaxpr = eqn.params['call_jaxpr']
self.assertIsInstance(inner_jaxpr, core.Jaxpr)
self.assertLen(inner_jaxpr.invars, 1 + 4) # one axis size var
self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[1].aval.shape)
self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[2].aval.shape)
self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[3].aval.shape)
self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[4].aval.shape)
@unittest.skip('This test does not work with nested pjit and DShapedArray')
def test_staging_nested_including_shape_arg(self):
n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
def f(x, y):
@jax.jit
def g(_, x, y, z, w):
return (x, w)
return g(x.shape[0], x, y, x, y)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f,
debug_info=debug_info("test", f, (1, 2), {})),
[n, a, b], keep_inputs=[False, True, True])
# { lambda ; a:i32[] b:f32[a] c:f32[a]. let
# d:f32[a] e:f32[a] = xla_call[
# call_jaxpr={ lambda ; f:i32[] g:i32[] h:f32[f] i:f32[f] j:f32[f] k:f32[f]. let
#
# in (h, k) }
# name=g
# ] a a b c b c
# in (d, e) }
self.assertLen(jaxpr.eqns, 1)
eqn = jaxpr.eqns[0]
self.assertIsInstance(eqn.primitive, core.CallPrimitive)
inner_jaxpr = eqn.params['call_jaxpr']
self.assertIsInstance(inner_jaxpr, core.Jaxpr)
self.assertLen(inner_jaxpr.invars, 1 + 4) # one axis size var
self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[1].aval.shape)
self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[2].aval.shape)
self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[3].aval.shape)
self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[4].aval.shape)
def test_staging_primitive_applications(self):
n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
def f(x, y):
z = lax.mul(x, y)
w = lax.sin(z)
u = lax.reduce_sum(w, [0])
return (u,)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f,
debug_info=debug_info("test", f, (1, 2), {})),
[n, a, b], keep_inputs=[False, True, True])
self.assertLen(jaxpr.invars, 1 + 2) # one axis size var, two other inputs
self.assertLen(jaxpr.eqns, 3)
self.assertLen(jaxpr.eqns[0].outvars, 1)
self.assertEqual(jaxpr.eqns[0].outvars[0].aval.shape,
jaxpr.invars[1].aval.shape)
self.assertLen(jaxpr.outvars, 1)
self.assertEqual(jaxpr.outvars[0].aval.shape, ())
@unittest.skip('This test does not work with nested pjit and DShapedArray')
def test_typecheck_staging_nested(self):
n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
m = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
b = core.DShapedArray((DBIdx(1),), jnp.dtype('float32'), weak_type=False)
def f(a, b):
@jax.jit
def g(x): return x
return g(a),
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f,
debug_info=debug_info("test", f, (1, 2), {})),
[n, m, a, b], keep_inputs=[False, False, True, True])
# { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
# e:f32[a] = xla_call[
# call_jaxpr={ lambda ; f:i32[] g:f32[f]. let in (g,) }
# name=g
# ] a c
# in (e,) }
core.check_jaxpr(jaxpr) # no problems here...
# Let's introduce a type error by applying the called jaxpr to arguments
# with types which aren't consistent with its input binders:
_, _, c, d = jaxpr.invars
jaxpr.eqns[0].invars[1] = d
# { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
# e:f32[a] = xla_call[
# call_jaxpr={ lambda ; f:i32[] g:f32[f]. let in (g,) }
# name=g
# ] a d !!! type error here !!!
# in (e,) }
with self.assertRaisesRegex(TypeError, "passes operand"):
core.check_jaxpr(jaxpr)
# Restore the original jaxpr:
jaxpr.eqns[0].invars[1] = c
core.check_jaxpr(jaxpr) # no problems here...
# Let's introduce another type error by setting the call result let binders
# to have the wrong type:
Reland https://github.com/google/jax/pull/10573. The original PR was reverted because of downstream breakage. Originally we used the `Var.count` attribute to ensure `Var` instances were printed consistently regardless of context, even though only their object id was load-bearing. That is, `Var.count` was only used for pretty printing. (#1949 added a total_ordering on `Var` for reasons out of scope of JAX's core code. I'm going to figure out if that's still needed... Haiku tests all seem to pass without it.) But #8019 revised our pretty-printing so as not to use `Var.count`. Instead it chose how to pretty-print Var instances based on their order of appearance in a jaxpr. That meant `Var.count` really wasn't useful anymore. So this PR removes `Var.count`. Since we no longer have `Var.count`, we also don't need core.gensym to take an optional sequence of jaxprs, since that was just used to set the starting count index for new `Var`s. In fact, `Var.__repr__` and `JaxprEqn.__repr__` were made confusing after #8019, since they could print variable names totally different from the names that would appear when the same `JaxprEqn` or `Var` objects were printed as part of a jaxpr. That is, before this PR we might have a jaxpr which printed like: ``` import jax def f(x): for _ in range(3): x = jax.numpy.sin(x) return x jaxpr = jax.make_jaxpr(f)(3.) print(jaxpr) # { lambda ; a:f32[]. let # b:f32[] = sin a # c:f32[] = sin b # d:f32[] = sin c # in (d,) } _, eqn, _ = jaxpr.jaxpr.eqns print(eqn) # a:f32[] = sin b ``` Notice the variable names in the equation pretty-print don't correspond to any in the jaxpr pretty-print! So this PR changes `JaxprEqn.__repr__` and `Var.__repr__` to show `Var` object ids, and in general just do less formatting (which seems consistent with the spirit of `__repr__`): ``` JaxprEqn(invars=[Var(id=140202705341552):float32[]], outvars=[Var(id=140202705339584):float32[]], primitive=sin, params={}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f837c73d770>, name_stack=NameStack(stack=()))) ``` PiperOrigin-RevId: 607664497
2024-02-16 05:56:45 -08:00
jaxpr.eqns[0].outvars[0] = core.Var('', d.aval)
# { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
# e:f32[b] = xla_call[ !!! type error here !!!
# call_jaxpr={ lambda ; f:i32[] g:f32[f]. let in (g,) }
# name=g
# ] a c
# in (h,) }
with self.assertRaisesRegex(TypeError, "inconsistently typed as"):
core.check_jaxpr(jaxpr)
def test_check_jaxpr_key_reuse(self):
with config.debug_key_reuse(True):
def f(seed):
key = jax.random.key(seed)
return jax.random.uniform(key) + jax.random.normal(key)
with jax.enable_checks(True):
with self.assertRaises(jax.errors.KeyReuseError):
jax.jit(f)(0)
2018-11-17 18:03:33 -08:00
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())