# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import re
import unittest

import numpy as np
from absl.testing import absltest

import jax
import jax.numpy as jnp
from jax import lax
from jax import config
from jax.interpreters import batching
import jax._src.lib
import jax._src.util
from jax._src import core
from jax._src import test_util as jtu

config.parse_flags_with_absl()
FLAGS = config.FLAGS


@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow")
class DynamicShapeStagingTest(jtu.JaxTestCase):

  def test_basic_staging(self):
    def f(x, _):
      return x

    x = jnp.arange(3)
    y = jnp.ones((3, 4))
    jaxpr = jax.make_jaxpr(f, abstracted_axes={0: 'n'})(x, y)

    # { lambda ; a:i32[] b:i32[a] c:f32[a,4]. let  in (b,) }
    self.assertLen(jaxpr.in_avals, 3)
    self.assertLen(jaxpr.in_avals[0].shape, 0)
    self.assertLen(jaxpr.in_avals[1].shape, 1)
    self.assertLen(jaxpr.in_avals[2].shape, 2)
    self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[1].shape[0])
    self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[2].shape[0])
    self.assertLen(jaxpr.out_avals, 1)
    self.assertLen(jaxpr.out_avals[0].shape, 1)
    self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.out_avals[0].shape[0])

  def test_basic_staging_repeated(self):
    def f(x, _):
      return x

    x = jnp.arange(3)
    y = jnp.ones((3, 3))
    jaxpr = jax.make_jaxpr(f, abstracted_axes=(('n',), ('n', 'n')))(x, y)

    # { lambda ; a:i32[] b:i32[a] c:f32[a,a]. let  in (b,) }
    self.assertLen(jaxpr.in_avals, 3)
    self.assertLen(jaxpr.in_avals[0].shape, 0)
    self.assertLen(jaxpr.in_avals[1].shape, 1)
    self.assertLen(jaxpr.in_avals[2].shape, 2)
    self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[1].shape[0])
    self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[2].shape[0])
    self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[2].shape[1])
    self.assertLen(jaxpr.out_avals, 1)
    self.assertLen(jaxpr.out_avals[0].shape, 1)
    self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.out_avals[0].shape[0])

  def test_basic_staging_multiple_shape_vars(self):
    def f(x, _):
      return x

    x = jnp.arange(3)
    y = jnp.ones((4, 3))
    jaxpr = jax.make_jaxpr(f, abstracted_axes=(('n',), ('m', 'n')))(x, y)

    # { lambda ; a:i32[] b:i32[] c:i32[a] d:f32[b,a]. let  in (c,) }
    self.assertLen(jaxpr.in_avals, 4)
    self.assertLen(jaxpr.in_avals[0].shape, 0)
    self.assertLen(jaxpr.in_avals[1].shape, 0)
    self.assertLen(jaxpr.in_avals[2].shape, 1)
    self.assertLen(jaxpr.in_avals[3].shape, 2)
    self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[2].shape[0])
    self.assertIs(jaxpr.jaxpr.invars[1], jaxpr.in_avals[3].shape[0])
    self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[3].shape[1])
    self.assertLen(jaxpr.out_avals, 1)
    self.assertLen(jaxpr.out_avals[0].shape, 1)
    self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.out_avals[0].shape[0])

  def test_basic_add(self):
    def f(x, y):
      return x + y

    x = jnp.arange(3)
    y = jnp.arange(1, 4)
    jaxpr = jax.make_jaxpr(f, abstracted_axes={0: 'n'})(x, y)

    # { lambda ; a:i32[] b:i32[a] c:i32[a]. let d:i32[a] = add b c in (d,) }
    self.assertLen(jaxpr.eqns, 1)
    self.assertLen(jaxpr.out_avals, 1)
    self.assertLen(jaxpr.out_avals[0].shape, 1)
    self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.out_avals[0].shape[0])

  def test_basic_jnp(self):
    def f(x):
      y = x + jnp.sin(x)
      return y.sum()

    x = jnp.ones((3, 4))
    jaxpr = jax.make_jaxpr(f, abstracted_axes={0: 'n'})(x)

    # { lambda ; a:i32[] b:f32[a,4]. let
    #     c:f32[a,4] = sin b
    #     d:f32[a,4] = add b c
    #     e:f32[] = reduce_sum[axes=(0, 1)] d
    #   in (e,) }
    self.assertLen(jaxpr.in_avals, 2)
    self.assertLen(jaxpr.eqns, 3)  # sin, add, and reduce_sum
    self.assertLen(jaxpr.out_avals, 1)
    self.assertLen(jaxpr.out_avals[0].shape, 0)

  def test_shape_errors_var_and_lit(self):
    def f(x, y):
      return jnp.sin(x) + y

    x = np.ones(3)
    y = np.ones(3)
    with self.assertRaisesRegex(
        Exception, '[Ii]ncompatible shapes for broadcasting'):
      _ = jax.make_jaxpr(f, abstracted_axes=({0: 'n'}, {}))(x, y)

  def test_shape_errors_distinct_vars(self):
    def f(x, y):
      return jnp.sin(x) + y

    x = np.ones(3)
    y = np.ones(3)
    with self.assertRaisesRegex(
        Exception, '[Ii]ncompatible shapes for broadcasting'):
      _ = jax.make_jaxpr(f, abstracted_axes=({0: 'n'}, {0: 'm'}))(x, y)

  def test_basic_dot(self):
    A = jnp.ones((3, 4))
    x = jnp.ones(4)
    jaxpr = jax.make_jaxpr(jnp.dot, abstracted_axes=(('m', 'n'), ('n',)))(A, x)

    # { lambda ; a:i32[] b:i32[] c:f32[a,b] d:f32[b]. let
    #     e:f32[a] = dot_general[dimension_numbers=(((1,), (0,)), ((), ()))] c d
    #   in (e,) }
    self.assertLen(jaxpr.in_avals, 4)
    self.assertLen(jaxpr.in_avals[0].shape, 0)  # two shape vars
    self.assertLen(jaxpr.in_avals[1].shape, 0)
    self.assertLen(jaxpr.in_avals[2].shape, 2)  # one matrix
    self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[2].shape[0])
    self.assertIs(jaxpr.jaxpr.invars[1], jaxpr.in_avals[2].shape[1])
    self.assertLen(jaxpr.in_avals[3].shape, 1)  # one vector
    self.assertIs(jaxpr.jaxpr.invars[1], jaxpr.in_avals[3].shape[0])
    self.assertLen(jaxpr.eqns, 1)
    self.assertLen(jaxpr.out_avals, 1)
    self.assertLen(jaxpr.out_avals[0].shape, 1)  # output vector
    self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.out_avals[0].shape[0])

  def test_basic_broadcast(self):
    def f(x, n):
      return lax.broadcast(x, (n,))

    jaxpr = jax.make_jaxpr(f)(jnp.ones(4), 3)

    # { lambda ; a:f32[4] b:i32[]. let
    #     c:f32[b,4] = broadcast_in_dim[bcast_dims=(1,) shape=(None, 4)] a b
    #   in (c,) }
    self.assertLen(jaxpr.in_avals, 2)
    self.assertLen(jaxpr.eqns, 1)
    self.assertLen(jaxpr.out_avals, 1)
    self.assertLen(jaxpr.out_avals[0].shape, 2)
    self.assertIs(jaxpr.jaxpr.invars[1], jaxpr.out_avals[0].shape[0])
    self.assertEqual(4, jaxpr.out_avals[0].shape[1])

  def test_basic_batchpoly_neuralnet(self):
    def predict(params, inputs):
      for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)
      return outputs

    def loss(params, batch):
      inputs, targets = batch
      preds = predict(params, inputs)
      return jnp.sum((preds - targets) ** 2)

    sizes = [784, 128, 128, 10]
    params = [(jnp.ones((input_dim, output_dim)), jnp.ones(output_dim))
              for input_dim, output_dim in zip(sizes[:-1], sizes[1:])]
    batch = (jnp.ones((32, 784)), jnp.ones((32, 10)))

    # Mainly we want to test that make_jaxpr doesn't crash here.
    jaxpr = jax.make_jaxpr(loss, abstracted_axes=({}, {0: 'n'}))(params, batch)
    self.assertLen(jaxpr.in_avals, 9)
    self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[-2].shape[0])
    self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[-1].shape[0])
    self.assertLen(jaxpr.out_avals, 1)
    self.assertLen(jaxpr.out_avals[0].shape, 0)

  def test_closing_over_polymorphic_shape(self):
    def f(n):
      x = jnp.zeros(n)
      return jax.jit(lambda: x)()

    jaxpr = jax.make_jaxpr(f)(3)

    # { lambda ; a:i32[]. let
    #     b:f32[a] = bcast[dims=() shape=(None,)] 0.0 a
    #     c:f32[a] = pjit[
    #       jaxpr={ lambda ; d:i32[] e:f32[d]. let  in (e,) }
    #       name=<lambda>
    #     ] a b
    #   in (c,) }
    a, = jaxpr.jaxpr.invars
    c, = jaxpr.jaxpr.outvars
    self.assertLen(c.aval.shape, 1)
    self.assertIs(a, c.aval.shape[0])

  def test_closing_over_dynamic_shape(self):
    def f(n):
      m = 2 * n
      x = jnp.zeros(m)
      return jax.jit(lambda: x)()

    # { lambda ; a:i32[]. let
    #     b:i32[] = mul a 2
    #     c:f32[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 0.0 b
    #     d:f32[b] = pjit[
    #       jaxpr={ lambda ; e:i32[] f:f32[e]. let  in (f,) }
    #       name=<lambda>
    #     ] b c
    #   in (d,) }
    jaxpr = jax.make_jaxpr(f)(3)
    b, = jaxpr.jaxpr.eqns[0].outvars
    c, = jaxpr.jaxpr.eqns[1].outvars
    d, = jaxpr.jaxpr.eqns[2].outvars
    self.assertLen(c.aval.shape, 1)
    self.assertIs(b, c.aval.shape[0])
    self.assertLen(d.aval.shape, 1)
    self.assertIs(b, d.aval.shape[0])

  def test_closing_over_polymorphic_shape_and_adding(self):
    def f(n):
      x = jnp.zeros(n)
      y = jnp.zeros(n)

      @jax.jit
      def g():
        return x + y

      return g()

    # { lambda ; a:i32[]. let
    #     b:f32[a] = bcast[broadcast_dimensions=() shape=(None,)] 0.0 a
    #     c:f32[a] = bcast[broadcast_dimensions=() shape=(None,)] 0.0 a
    #     d:f32[a] = pjit[
    #       jaxpr={ lambda ; e:i32[] f:f32[e] g:f32[e]. let
    #           h:f32[e] = add f g
    #         in (h,) }
    #       name=g
    #     ] a b c
    #   in (d,) }
    jaxpr = jax.make_jaxpr(f)(3)  # doesn't fail on the addition!
    a, = jaxpr.jaxpr.invars
    b, = jaxpr.jaxpr.eqns[0].outvars
    c, = jaxpr.jaxpr.eqns[1].outvars
    d, = jaxpr.jaxpr.eqns[2].outvars
    self.assertIs(a, b.aval.shape[0])
    self.assertIs(a, c.aval.shape[0])
    self.assertIs(a, d.aval.shape[0])

  def test_passing_in_equal_polymorphic_shapes_and_adding(self):
    def f(n):
      x = jnp.zeros(n)

      @jax.jit
      def g(x, y):
        return x + y

      return g(x, x)

    # { lambda ; a:i32[]. let
    #     b:f32[a] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 0.0 a
    #     c:f32[a] = pjit[
    #       jaxpr={ lambda ; d:i32[] e:f32[d] f:f32[d]. let
    #           g:f32[d] = add e f
    #         in (g,) }
    #       name=g
    #     ] a b b
    #   in (c,) }
    jaxpr = jax.make_jaxpr(f)(3)
    a, = jaxpr.jaxpr.invars
    c, = jaxpr.jaxpr.outvars
    self.assertLen(c.aval.shape, 1)
    self.assertIs(a, c.aval.shape[0])

  @unittest.skip("doesn't work yet: shape error b/c we don't notice x and y same")
  def test_closing_over_and_passing_arg_addition(self):
    # TODO(mattjj,dougalm): currently fails to notice equal shapes, fix!
    def f(n):
      x = jnp.zeros(n)

      @jax.jit
      def g(y):
        return x + y

      return g(x)

    _ = jax.make_jaxpr(f)(3)

  @unittest.skip("doesn't work yet: shape error b/c we don't notice x and jnp.zeros(m) same")
  def test_closing_over_and_passing_size_addition(self):
    # TODO(mattjj,dougalm): currently fails to notice equal shapes, fix!
    def f(n):
      x = jnp.zeros(n)

      @jax.jit
      def g(m):
        return jnp.zeros(m) + x

      return g(n)

    _ = jax.make_jaxpr(f)(3)

  def test_closing_over_and_broadcasting_polymorphic_shape(self):
    def f(n):
      x = jnp.zeros(n)

      @jax.jit
      def g():
        return jnp.zeros(n) + x

      return g()

    # { lambda ; a:i32[]. let
    #     b:f32[a] = bcast[broadcast_dimensions=() shape=(None,)] 0.0 a
    #     c:f32[a] = pjit[
    #       jaxpr={ lambda ; d:i32[] e:f32[d]. let
    #           f:f32[d] = bcast[broadcast_dimensions=() shape=(None,)] 0.0 d
    #           g:f32[d] = add f e
    #         in (g,) }
    #       name=g
    #     ] a b
    #   in (c,) }
    jaxpr = jax.make_jaxpr(f)(3)
    a, = jaxpr.jaxpr.invars
    c, = jaxpr.jaxpr.outvars
    self.assertLen(c.aval.shape, 1)
    self.assertIs(a, c.aval.shape[0])

  def test_closing_over_repeated_shapes(self):
    def zeros(shape):
      if not isinstance(shape, (tuple, list)):
        shape = shape,
      return lax.broadcast(0., shape)

    def f(n):
      m = 2 * n
      x = zeros((m, m))
      return jax.jit(lambda: x.sum(0))()

    # { lambda ; a:i32[]. let
    #     b:i32[] = mul a 2
    #     c:f32[b,b] = broadcast_in_dim[broadcast_dimensions=() shape=(None, None)] 0.0
    #       b b
    #     d:f32[b] = pjit[
    #       jaxpr={ lambda ; e:i32[] f:f32[e,e]. let
    #           g:f32[e] = reduce_sum[axes=(0,)] f
    #         in (g,) }
    #       name=<lambda>
    #     ] b c
    #   in (d,) }
    jaxpr = jax.make_jaxpr(f)(3)
    a, = jaxpr.jaxpr.invars
    b, = jaxpr.jaxpr.eqns[0].outvars
    c, = jaxpr.jaxpr.eqns[1].outvars
    d, = jaxpr.jaxpr.eqns[2].outvars
    b_, c_ = jaxpr.jaxpr.eqns[2].invars
    self.assertLen(c.aval.shape, 2)
    self.assertIs(c.aval.shape[0], b)
    self.assertIs(c.aval.shape[1], b)
    self.assertIs(b, b_)
    self.assertIs(c, c_)
    self.assertLen(d.aval.shape, 1)
    self.assertIs(d.aval.shape[0], b)

  def test_staging_repeated_nested(self):
    def zeros(shape):
      if not isinstance(shape, (tuple, list)):
        shape = shape,
      return lax.broadcast(jnp.float32(0.), shape)

    def f(n):
      m = 2 * n
      x = zeros((m, n))
      y = zeros(m)
      return jax.jit(lambda x, y: x.sum(1) + y)(x, y)

    # { lambda ; a:i32[]. let
    #     b:i32[] = mul a 2
    #     c:f32[b,a] = broadcast_in_dim[broadcast_dimensions=() shape=(None, None)] 0.0
    #       b a
    #     d:f32[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 b
    #     e:f32[b] = pjit[
    #       jaxpr={ lambda ; f:i32[] g:i32[] h:f32[f,g] i:f32[f]. let
    #           j:f32[f] = reduce_sum[axes=(1,)] h
    #           k:f32[f] = add j i
    #         in (k,) }
    #       name=<lambda>
    #     ] b a c d
    #   in (e,) }
    jaxpr = jax.make_jaxpr(f)(jnp.int32(3))
    a, = jaxpr.jaxpr.invars
    b, = jaxpr.jaxpr.eqns[0].outvars
    c, = jaxpr.jaxpr.eqns[1].outvars
    d, = jaxpr.jaxpr.eqns[2].outvars
    e, = jaxpr.jaxpr.eqns[3].outvars
    b_, a_, c_, d_ = jaxpr.jaxpr.eqns[3].invars
    self.assertLen(c.aval.shape, 2)
    self.assertIs(c.aval.shape[0], b)
    self.assertIs(c.aval.shape[1], a)
    self.assertLen(e.aval.shape, 1)
    self.assertIs(e.aval.shape[0], b)
    self.assertIs(a, a_)
    self.assertIs(b, b_)
    self.assertIs(c, c_)
    self.assertIs(d, d_)

  def test_jit_abstracted_axes_staging(self):
    # We just test make_jaxpr-of-jit because dynamic shape compilation/execution
    # may not be supported.
    @partial(jax.jit, abstracted_axes=('n',))
    def f(x):
      return jnp.sum(x)

    jaxpr = jax.make_jaxpr(f)(jnp.ones(3, jnp.dtype('float32')))

    # { lambda ; a:f32[3]. let
    #     b:f32[] = pjit[
    #       jaxpr={ lambda ; c:i32[] d:f32[c]. let
    #           e:f32[] = reduce_sum[axes=(0,)] d
    #         in (e,) }
    #       name=f
    #     ] 3 a
    #   in (b,) }
    a, = jaxpr.jaxpr.invars
    e, = jaxpr.jaxpr.eqns
    self.assertLen(e.invars, 2)
    self.assertIsInstance(e.invars[0], core.Literal)
    self.assertIs(e.invars[1], a)
    b, = e.outvars
    self.assertLen(b.aval.shape, 0)

    subjaxpr = e.params['jaxpr']
    c, d = subjaxpr.jaxpr.invars
    self.assertLen(c.aval.shape, 0)
    self.assertLen(d.aval.shape, 1)
    self.assertIs(d.aval.shape[0], c)

  def test_jit_abstracted_axes_staging2(self):
    @partial(jax.jit, abstracted_axes=('n',))
    def fun(x):
      return jnp.sum(x)

    jaxpr = jax.make_jaxpr(lambda n: fun(jnp.ones(n + n, jnp.dtype('float32')))
                           )(3)

    # { lambda ; a:i32[]. let
    #     b:i32[] = add a a
    #     c:f32[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 b
    #     d:f32[] = pjit[
    #       jaxpr={ lambda ; e:i32[] f:f32[e]. let
    #           g:f32[] = reduce_sum[axes=(0,)] f
    #         in (g,) }
    #       name=f
    #     ] b c
    #   in (d,) }
    a, = jaxpr.jaxpr.invars
    e1, e2, e3 = jaxpr.jaxpr.eqns
    b, = e1.outvars
    c, = e2.outvars
    b_, c_ = e3.invars
    self.assertIs(b, b_)
    self.assertIs(c, c_)

    subjaxpr = e3.params['jaxpr']
    e, f = subjaxpr.jaxpr.invars
    self.assertLen(e.aval.shape, 0)
    self.assertLen(f.aval.shape, 1)
    self.assertIs(f.aval.shape[0], e)

  def test_jit_abstracted_axes_staging3(self):
    f = jax.jit(jnp.sum, abstracted_axes=('n',))
    jaxpr = jax.make_jaxpr(f, abstracted_axes=('n',))(jnp.arange(3.))

    # { lambda ; a:i32[] b:f32[a]. let
    #     c:f32[] = pjit[
    #       jaxpr={ lambda ; d:i32[] e:f32[d]. let
    #           f:f32[] = reduce_sum[axes=(0,)] e
    #         in (f,) }
    #       name=sum
    #     ] a b
    #   in (c,) }
    a, b = jaxpr.jaxpr.invars
    e, = jaxpr.jaxpr.eqns
    self.assertIs(e.invars[0], a)
    self.assertIs(e.invars[1], b)
    c, = e.outvars
    self.assertLen(c.aval.shape, 0)

    subjaxpr = e.params['jaxpr']
    d, e = subjaxpr.jaxpr.invars
    self.assertLen(d.aval.shape, 0)
    self.assertLen(e.aval.shape, 1)
    self.assertIs(e.aval.shape[0], d)

  def test_jit_abstracted_axes_return_polymorphic_shape(self):
    f = jax.jit(lambda x: x, abstracted_axes=('n',))
    jaxpr = jax.make_jaxpr(f)(jnp.arange(3))  # doesn't crash

    # { lambda ; a:i32[3]. let
    #     b:i32[3] = pjit[
    #       jaxpr={ lambda ; c:i32[] d:i32[c]. let  in (d,) }
    #       name=<lambda>
    #     ] 3 a
    #   in (b,) }
    a, = jaxpr.jaxpr.invars
    e, = jaxpr.jaxpr.eqns
    three, a_ = e.invars
    b, = e.outvars
    self.assertIsInstance(three, core.Literal)
    self.assertEqual(three.val, 3)
    self.assertIs(a_, a)
    self.assertLen(b.aval.shape, 1)
    self.assertEqual(b.aval.shape[0], 3)

  def test_jit_abstracted_axes_return_polymorphic_shape2(self):
    f = jax.jit(lambda n: jnp.ones(n))

    # TODO(mattjj,dougalm): support dynamic shapes in type checker
    with jax.enable_checks(False):
      jaxpr = jax.make_jaxpr(f)(3)

    # { lambda ; a:i32[]. let
    #     b:f32[a] = pjit[
    #       jaxpr={ lambda ; c:i32[]. let
    #           d:f32[c] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0
    #             c
    #         in (d,) }
    #       name=<lambda>
    #     ] a
    #   in (b,) }
    a, = jaxpr.jaxpr.invars
    e, = jaxpr.jaxpr.eqns
    a_, = e.invars
    self.assertIs(a, a_)
    b, = e.outvars
    a__, = b.aval.shape
    self.assertIs(a, a__)

    with jax.enable_checks(False):
      jaxpr = jax.make_jaxpr(lambda: f(3))()

    # { lambda ; . let
    #     a:f32[3] = pjit[
    #       jaxpr={ lambda ; b:i32[]. let
    #           c:f32[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0
    #             b
    #         in (c,) }
    #       name=<lambda>
    #     ] 3
    #   in (a,) }
    () = jaxpr.jaxpr.invars
    e, = jaxpr.jaxpr.eqns
    three, = e.invars
    self.assertIsInstance(three, core.Literal)
    self.assertEqual(three.val, 3)
    b, = e.outvars
    three_, = b.aval.shape
    self.assertIsInstance(three_, int)
    self.assertEqual(three_, 3)

  def test_zero_size_checking(self):
    def f(x):
      if core.definitely_equal(x.size, 0):
        return x
      else:
        return -x

    x = jnp.zeros(1)
    jaxpr = jax.make_jaxpr(f, abstracted_axes={0: 'n'})(x)  # doesn't crash
    self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 1)

    y = jnp.zeros((2, 0))
    jaxpr = jax.make_jaxpr(f, abstracted_axes={0: 'n'})(y)  # doesn't crash
    self.assertLen(jaxpr.jaxpr.eqns, 0)

  def test_flattening_basic(self):
    x = jnp.zeros((2, 3, 4, 5))

    # don't need to divide or multiply any dynamic axis sizes
    jaxpr = jax.make_jaxpr(lambda x: x.reshape(x.shape[0], -1),
                           abstracted_axes={0: 'n'})(x)
    self.assertLen(jaxpr.jaxpr.eqns, 1)
    jaxpr = jax.make_jaxpr(lambda x: x.reshape(3, x.shape[0], -1),
                           abstracted_axes={0: 'n'})(x)
    self.assertLen(jaxpr.jaxpr.eqns, 1)
    jaxpr = jax.make_jaxpr(lambda x: x.reshape(-1, x.shape[0]),
                           abstracted_axes={0: 'n'})(x)
    self.assertLen(jaxpr.jaxpr.eqns, 1)

    # don't need to divide but do need a dynamic axis size in multiplication
    # (so to typecheck we'd need nontrivial reductions)
    jaxpr = jax.make_jaxpr(lambda x: x.reshape(-1),
                           abstracted_axes={0: 'n'})(x)
    self.assertLessEqual(len(jaxpr.jaxpr.eqns), 3)  # may have mul with 1
    self.assertEqual(str(jaxpr.jaxpr.eqns[-2].primitive), 'mul')
    self.assertEqual(str(jaxpr.jaxpr.eqns[-1].primitive), 'reshape')
    jaxpr = jax.make_jaxpr(lambda x: x.reshape(2, -1),
                           abstracted_axes={0: 'n'})(x)
    self.assertLessEqual(len(jaxpr.jaxpr.eqns), 3)
    jaxpr = jax.make_jaxpr(lambda x: x.reshape(-1, 12),
                           abstracted_axes={0: 'n'})(x)
    self.assertLessEqual(len(jaxpr.jaxpr.eqns), 3)

    # do need divide, also shouldn't typecheck
    _ = jax.make_jaxpr(lambda x: x.reshape(x.shape[0], x.shape[0], -1),
                       abstracted_axes={0: 'n'})(x)  # don't crash


@unittest.skip("Test does not work with jax.Array")
@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow")
class DynamicShapeAutodiffTest(jtu.JaxTestCase):

  def test_jvp_broadcast(self):
    @jax.jit
    def fn(n, x):
      return lax.broadcast_in_dim(x, (n,), ())

    outer_jaxpr = jax.make_jaxpr(
        lambda x, t: jax.jvp(lambda y: fn(3, y), (x,), (t,))
    )(3., 4.)

    # { lambda ; a:f32[] b:f32[]. let
    #     c:f32[3] d:f32[3] = pjit[
    #       jaxpr={ lambda ; e:i32[] f:f32[] g:f32[]. let
    #           h:f32[e] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] f e
    #           i:f32[e] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] g e
    #         in (h, i) }
    #       name=f
    #     ] 3 a b
    #   in (c, d) }
    self.assertLen(outer_jaxpr.jaxpr.eqns, 1)
    eqn, = outer_jaxpr.jaxpr.eqns
    self.assertIn('jaxpr', eqn.params)
    jaxpr = eqn.params['jaxpr'].jaxpr
    self.assertLen(jaxpr.invars, 3)
    e, f, g = jaxpr.invars
    self.assertEqual(e.aval.shape, ())
    self.assertEqual(f.aval.shape, ())
    self.assertEqual(g.aval.shape, ())
    self.assertLen(jaxpr.outvars, 2)
    h, i = jaxpr.outvars
    self.assertEqual(h.aval.shape, (e,))
    self.assertEqual(i.aval.shape, (e,))
    self.assertLen(eqn.outvars, 2)
    c, d = eqn.outvars
    self.assertEqual(c.aval.shape, (3,))
    self.assertEqual(d.aval.shape, (3,))

  def test_jvp_basic(self):
    @partial(jax.jit, abstracted_axes=('n',))
    def foo(x):
      return jnp.sin(x)

    x = t = jnp.arange(3.)
    outer_jaxpr = jax.make_jaxpr(lambda x, t: jax.jvp(foo, (x,), (t,)))(x, t)

    # { lambda ; a:f32[3] b:f32[3]. let
    #     c:f32[3] d:f32[3] = pjit[
    #       jaxpr={ lambda ; e:i32[] f:f32[e] g:f32[e]. let
    #           h:f32[e] = sin f
    #           i:f32[e] = cos f
    #           j:f32[e] = mul g i
    #         in (h, j) }
    #       name=f
    #     ] 3 a b
    #   in (c, d) }
    self.assertLen(outer_jaxpr.jaxpr.eqns, 1)
    eqn, = outer_jaxpr.eqns
    self.assertIn('jaxpr', eqn.params)
    jaxpr = eqn.params['jaxpr'].jaxpr
    self.assertLen(jaxpr.invars, 3)
    e, f, g = jaxpr.invars
    self.assertEqual(e.aval.shape, ())
    self.assertEqual(f.aval.shape, (e,))
    self.assertEqual(g.aval.shape, (e,))
    self.assertLen(jaxpr.outvars, 2)
    self.assertLen(eqn.outvars, 2)
    c, d = eqn.outvars
    self.assertEqual(c.aval.shape, (3,))
    self.assertEqual(d.aval.shape, (3,))

  def test_linearize_basic(self):
    @partial(jax.jit, abstracted_axes=('n',))
    def foo(x):
      return jax.lax.sin(x)

    x = jnp.arange(3.)

    # primal computation
    outer_jaxpr = jax.make_jaxpr(lambda x: jax.linearize(foo, x))(x)

    # { lambda ; a:f32[3]. let
    #     b:f32[3] c:f32[3] = pjit[
    #       jaxpr={ lambda ; d:i32[] e:f32[d]. let
    #           f:f32[d] = sin e
    #           g:f32[d] = cos e
    #         in (f, g) }
    #       name=foo
    #     ] 3 a
    #   in (b, c) }
    self.assertLen(outer_jaxpr.jaxpr.eqns, 1)
    eqn, = outer_jaxpr.jaxpr.eqns
    self.assertIn('jaxpr', eqn.params)
    jaxpr = eqn.params['jaxpr'].jaxpr
    self.assertLen(jaxpr.invars, 2)
    d, e = jaxpr.invars
    self.assertEqual(d.aval.shape, ())
    self.assertEqual(e.aval.shape, (d,))
    self.assertLen(jaxpr.eqns, 2)
    self.assertLen(jaxpr.outvars, 2)
    f, g = jaxpr.outvars
    self.assertEqual(jaxpr.eqns[0].outvars, [f])
    self.assertEqual(jaxpr.eqns[1].outvars, [g])
    self.assertLen(eqn.outvars, 2)
    b, c = eqn.outvars
    self.assertEqual(b.aval.shape, (3,))
    self.assertEqual(c.aval.shape, (3,))

    # primal and tangent computation
    outer_jaxpr = jax.make_jaxpr(
        lambda x, xdot: jax.linearize(foo, x)[1](xdot))(x, x)

    # { lambda ; a:f32[3] b:f32[3]. let
    #     _:f32[3] c:f32[3] = pjit[
    #       jaxpr={ lambda ; d:i32[] e:f32[d]. let
    #           f:f32[d] = sin e
    #           g:f32[d] = cos e
    #         in (f, g) }
    #       name=foo
    #     ] 3 a
    #     h:f32[3] = pjit[
    #       jaxpr={ lambda ; i:i32[] j:f32[i] k:f32[i]. let
    #           l:f32[i] = mul k j
    #         in (l,) }
    #       name=foo
    #     ] 3 c b
    #   in (h,) }
    self.assertLen(outer_jaxpr.jaxpr.eqns, 2)
    _, eqn = outer_jaxpr.jaxpr.eqns
    self.assertIn('jaxpr', eqn.params)
    jaxpr = eqn.params['jaxpr'].jaxpr
    self.assertLen(jaxpr.invars, 3)
    i, j, k = jaxpr.invars
    self.assertEqual(i.aval.shape, ())
    self.assertEqual(j.aval.shape, (i,))
    self.assertEqual(k.aval.shape, (i,))
    self.assertLen(eqn.outvars, 1)
    h, = eqn.outvars
    self.assertEqual(h.aval.shape, (3,))

  def test_linearize_basic2(self):
    @partial(jax.jit, abstracted_axes=('n',))
    def foo(x):
      return jax.jit(jax.lax.sin)(x)

    x = jnp.arange(3.)
    outer_jaxpr = jax.make_jaxpr(lambda x: jax.linearize(foo, x))(x)

    # { lambda ; a:f32[3]. let
    #     b:f32[3] c:f32[3] = pjit[
    #       jaxpr={ lambda ; d:i32[] e:f32[d]. let
    #           f:f32[d] g:f32[d] = pjit[
    #             jaxpr={ lambda ; h:i32[] i:f32[h]. let
    #                 j:f32[h] = sin i
    #                 k:f32[h] = cos i
    #               in (j, k) }
    #             name=sin
    #           ] d e
    #         in (f, g) }
    #       name=foo
    #     ] 3 a
    #   in (b, c) }
    self.assertLen(outer_jaxpr.jaxpr.eqns, 1)
    eqn, = outer_jaxpr.jaxpr.eqns
    self.assertLen(eqn.outvars, 2)
    b, c = eqn.outvars
    self.assertEqual(b.aval.shape, (3,))
    self.assertEqual(c.aval.shape, (3,))

  def test_grad_basic(self):
    @partial(jax.jit, abstracted_axes=('n',))
    def foo(x):
      y = jax.lax.sin(x)
      return y.sum()

    x = jnp.arange(3.)
    outer_jaxpr = jax.make_jaxpr(jax.grad(foo))(x)

    # { lambda ; a:f32[3]. let
    #     _:f32[] b:f32[3] = pjit[
    #       jaxpr={ lambda ; c:i32[] d:f32[c]. let
    #           e:f32[c] = sin d
    #           f:f32[c] = cos d
    #           g:f32[] = reduce_sum[axes=(0,)] e
    #         in (g, f) }
    #       name=foo
    #     ] 3 a
    #     h:f32[3] = pjit[
    #       jaxpr={ lambda ; i:i32[] j:f32[i] k:f32[]. let
    #           l:f32[i] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] k i
    #           m:f32[i] = mul l j
    #         in (m,) }
    #       name=foo
    #     ] 3 b 1.0
    #   in (h,) }
    self.assertLen(outer_jaxpr.jaxpr.eqns, 2)
    fwd_eqn, bwd_eqn = outer_jaxpr.jaxpr.eqns
    self.assertIn('jaxpr', fwd_eqn.params)
    fwd_jaxpr = fwd_eqn.params['jaxpr'].jaxpr
    self.assertLen(fwd_jaxpr.invars, 2)
    c, d = fwd_jaxpr.invars
    self.assertEqual(c.aval.shape, ())
    self.assertEqual(d.aval.shape, (c,))
    self.assertLen(fwd_jaxpr.outvars, 2)
    g, f = fwd_jaxpr.outvars
    self.assertEqual(g.aval.shape, ())
    self.assertEqual(f.aval.shape, (c,))
    self.assertLen(fwd_eqn.outvars, 2)
    _, b = fwd_eqn.outvars
    self.assertEqual(b.aval.shape, (3,))
    self.assertIn('jaxpr', bwd_eqn.params)
    bwd_jaxpr = bwd_eqn.params['jaxpr'].jaxpr
    self.assertLen(bwd_jaxpr.invars, 3)
    i, j, k = bwd_jaxpr.invars
    self.assertEqual(i.aval.shape, ())
    self.assertEqual(j.aval.shape, (i,))
    self.assertEqual(k.aval.shape, ())
    self.assertLen(bwd_jaxpr.outvars, 1)
    m, = bwd_jaxpr.outvars
    self.assertEqual(m.aval.shape, (i,))
    self.assertLen(bwd_eqn.outvars, 1)
    h, = bwd_eqn.outvars
    self.assertEqual(h.aval.shape, (3,))

  def test_mlp_autodiff_dynamic_batch_toplevel(self):
    def predict(params, inputs):
      for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.maximum(0, outputs)
      return outputs

    def loss(params, batch):
      inputs, targets = batch
      predictions = predict(params, inputs)
      return jnp.sum((predictions - targets) ** 2)

    batch = (inputs, targets) = (jnp.ones((128, 784)), jnp.ones((128, 10)))
    params = [(jnp.ones((784, 256)), jnp.ones(256)),
              (jnp.ones((256, 256)), jnp.ones(256)),
              (jnp.ones((256, 10)), jnp.ones(10))]

    # jvp
    def loss_jvp(params, batch):
      return jax.jvp(loss, (params, batch), (params, batch))

    jaxpr = jax.make_jaxpr(loss_jvp, abstracted_axes=({}, {0: 'n'}))(
        params, batch)
    core.check_jaxpr(jaxpr.jaxpr)

    # linearize
    def loss_lin(params, batch):
      y, f_lin = jax.linearize(loss, params, batch)
      y_dot = f_lin(params, batch)
      return y, y_dot

    jaxpr = jax.make_jaxpr(loss_lin, abstracted_axes=({}, {0: 'n'}))(
        params, batch)
    core.check_jaxpr(jaxpr.jaxpr)

    # grad
    jaxpr = jax.make_jaxpr(jax.grad(loss), abstracted_axes=({}, {0: 'n'}))(
        params, batch)
    core.check_jaxpr(jaxpr.jaxpr)

  def test_mlp_autodiff_dynamic_batch_inner(self):
    # This is like the above 'toplevel' test, but instead of introducing
    # abstracted axes on the make_jaxpr call, we do it on a jit.
    @partial(jax.jit, abstracted_axes=({}, {0: 'n'}))
    def predict(params, inputs):
      for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.maximum(0, outputs)
      return outputs

    def loss(params, batch):
      inputs, targets = batch
      predictions = predict(params, inputs)
      return jnp.sum((predictions - targets) ** 2)

    batch = (inputs, targets) = (jnp.ones((128, 784)), jnp.ones((128, 10)))
    params = [(jnp.ones((784, 256)), jnp.ones(256)),
              (jnp.ones((256, 256)), jnp.ones(256)),
              (jnp.ones((256, 10)), jnp.ones(10))]

    # jvp
    def loss_jvp(params, batch):
      return jax.jvp(loss, (params, batch), (params, batch))

    jaxpr = jax.make_jaxpr(loss_jvp)(params, batch)
    core.check_jaxpr(jaxpr.jaxpr)

    # linearize
    def loss_lin(params, batch):
      y, f_lin = jax.linearize(loss, params, batch)
      y_dot = f_lin(params, batch)
      return y, y_dot

    jaxpr = jax.make_jaxpr(loss_lin)(params, batch)
    core.check_jaxpr(jaxpr.jaxpr)

    # grad
    jaxpr = jax.make_jaxpr(jax.grad(loss))(params, batch)
    core.check_jaxpr(jaxpr.jaxpr)

  def test_bint_broadcast(self):
    d = lax.convert_element_type(3, core.bint(5))
    bint = lambda x, b: lax.convert_element_type(x, core.bint(b))

    x = lax.broadcast_in_dim(0, (d,), ())  # doesn't crash
    self.assertIsInstance(x, core.DArray)
    self.assertAllClose(x._data, np.zeros(5, dtype='int32'), check_dtypes=False)
    self.assertEqual(
        x._aval, core.DShapedArray((bint(3, 5),), x._data.dtype, True))

    def f(n):
      return jnp.zeros(n)

    x = jax.jit(f)(d)
    self.assertIsInstance(x, core.DArray)
    self.assertAllClose(x._data, np.zeros(5, dtype='int32'), check_dtypes=False)
    self.assertEqual(
        x._aval, core.DShapedArray((bint(3, 5),), x._data.dtype, False))

    jaxpr = jax.make_jaxpr(f)(d).jaxpr
    # { lambda ; a:bint{≤5}[]. let
    #     b:f32[a] = broadcast_in_dim[...] 0.0 a
    #   in (b,) }
    self.assertLen(jaxpr.invars, 1)
    a, = jaxpr.invars
    self.assertEqual(a.aval, core.DShapedArray((), core.bint(5)))
    self.assertLen(jaxpr.eqns, 1)
    eqn, = jaxpr.eqns
    self.assertLen(eqn.outvars, 1)
    b, = eqn.outvars
    self.assertEqual(b.aval.shape, (a,))

  def test_vmap_abstracted_axis(self):
    def foo(x, y):
      z = jax.vmap(jnp.sin)(x) * y
      return jax.vmap(jnp.add)(x, z)

    x = jnp.arange(3.)
    jaxpr = jax.make_jaxpr(foo, abstracted_axes=('n',))(x, x).jaxpr
    self.assertLen(jaxpr.invars, 3)
    a, b, c = jaxpr.invars
    self.assertEqual(a.aval.shape, ())
    self.assertEqual(b.aval.shape, (a,))
    self.assertEqual(c.aval.shape, (a,))
    self.assertLen(jaxpr.eqns, 3)
    self.assertLen(jaxpr.outvars, 1)
    f, = jaxpr.outvars
    self.assertEqual(f.aval.shape, (a,))

  def test_vmap_abstracted_axes_2d(self):
    def foo(x, y):
      z = jax.vmap(jax.vmap(jnp.sin))(x) * y
      return jax.vmap(jax.vmap(jnp.add))(x, z)

    x = jnp.arange(12.).reshape(3, 4)
    jaxpr = jax.make_jaxpr(foo, abstracted_axes=('n', 'm'))(x, x).jaxpr
    self.assertLen(jaxpr.invars, 4)
    a, b, c, d = jaxpr.invars
    self.assertEqual(a.aval.shape, ())
    self.assertEqual(b.aval.shape, ())
    self.assertEqual(c.aval.shape, (a, b))
    self.assertEqual(d.aval.shape, (a, b))
    self.assertLen(jaxpr.eqns, 3)
    self.assertLen(jaxpr.outvars, 1)
    f, = jaxpr.outvars
    self.assertEqual(f.aval.shape, (a, b))

  def test_vmap_of_indexing_basic(self):
    x = jnp.arange(3.)

    def f(idxs):
      return jax.vmap(lambda i: x[i])(idxs)

    idxs = jnp.arange(3)
    jaxpr = jax.make_jaxpr(f, abstracted_axes=('n',))(idxs).jaxpr

    # { lambda a:f32[3]; b:i32[] c:i32[b]. let
    #     d:bool[b] = lt c 0
    #     e:i32[b] = add c 3
    #     f:i32[b] = select_n d c e
    #     g:i32[b,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(None, 1)] f b
    #     h:f32[b,1] = gather[
    #       dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,))
    #       fill_value=None
    #       indices_are_sorted=False
    #       mode=GatherScatterMode.PROMISE_IN_BOUNDS
    #       slice_sizes=(1,)
    #       unique_indices=False
    #     ] a g
    #     i:f32[b] = squeeze[dimensions=(1,)] h
    #   in (i,) }
    b, _ = jaxpr.invars
    e, = (e for e in jaxpr.eqns if str(e.primitive) == 'gather')
    h, = e.outvars
    self.assertEqual(h.aval.shape, (b, 1))

  def test_einsum_basic(self):
    x = jnp.arange(20.).reshape(4, 5)

    def f(x):
      return jnp.einsum('ij,kj->ik', x, x)

    jaxpr = jax.make_jaxpr(f, abstracted_axes=('n', 'm'))(x).jaxpr

    # { lambda ; a:i32[] b:i32[] c:f32[a,b]. let
    #     d:f32[a,a] = pjit[
    #       jaxpr={ lambda ; e:i32[] f:i32[] g:f32[e,f] h:f32[e,f]. let
    #           i:f32[e,e] = dot_general[
    #             dimension_numbers=(((1,), (1,)), ((), ()))
    #             precision=None
    #             preferred_element_type=None
    #           ] g h
    #         in (i,) }
    #       name=_einsum
    #     ] a b c c
    #   in (d,) }
    self.assertLen(jaxpr.invars, 3)
    a, b, c = jaxpr.invars
    self.assertEqual(c.aval.shape[0], a)
    self.assertLen(jaxpr.eqns, 1)
    self.assertLen(jaxpr.eqns[0].outvars, 1)
    d, = jaxpr.eqns[0].outvars
    self.assertEqual(d.aval.shape, (a, a))

  def test_inferring_valid_subjaxpr_type_add(self):
    def f(x):
      return x + x.shape[0]

    jax.make_jaxpr(f, abstracted_axes=('n',))(jnp.arange(3))  # doesn't crash

  def test_slicing_basic_jaxpr(self):
    def f(x):
      return x[0]

    jaxpr = jax.make_jaxpr(f, abstracted_axes=(None, 'n'))(jnp.zeros((3, 4)))

    # { lambda ; a:i32[] b:f32[3,a]. let
    #     c:f32[1,a] = dynamic_slice[slice_sizes=(1, None)] b 0 0 a
    #     d:f32[a] = squeeze[dimensions=(0,)] c
    #   in (d,) }
    self.assertLen(jaxpr.jaxpr.invars, 2)
    a, _ = jaxpr.jaxpr.invars
    self.assertLen(jaxpr.jaxpr.outvars, 1)
    d, = jaxpr.jaxpr.outvars
    self.assertLen(d.aval.shape, 1)
    self.assertEqual(d.aval.shape, (a,))

  def test_shape_tuple_argument_to_zeros(self):
    @partial(jax.jit, abstracted_axes=(('n',), ('n',)))
    def f(x, y):
      zero = jnp.zeros(jnp.shape(x))
      return zero * y

    x = jnp.arange(3.0)
    y = jnp.arange(3.0) + 1
    jax.make_jaxpr(f)(x, y)  # doesn't crash


@unittest.skip("Test does not work with jax.Array")
@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow")
class DynamicShapeExecutionTest(jtu.JaxTestCase):

  @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
  def test_jit_basic_iree(self):
    @jax.jit
    def f(i):
      return jnp.sum(jnp.ones(i, dtype='float32'))

    self.assertAllClose(f(3), jnp.array(3., dtype='float32'), check_dtypes=True)

  @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
  def test_jit_basic_iree_2(self):
    count = 0

    @partial(jax.jit, abstracted_axes=('n',))
    def f(x):
      nonlocal count
      count += 1
      return jnp.sum(x)

    x = f(np.arange(3))
    y = f(np.arange(4))
    self.assertAllClose(x, 3., check_dtypes=False)
    self.assertAllClose(y, 6., check_dtypes=False)
    self.assertEqual(count, 1)

  @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
  def test_jit_polymorphic_output_iree(self):
    # like test_jit_basic_iree, but without the jnp.sum!
    count = 0

    @jax.jit
    def f(i):
      nonlocal count
      count += 1
      return jnp.ones(i, dtype='float32')

    self.assertAllClose(f(3), np.ones(3, dtype='float32'), check_dtypes=True)
    self.assertAllClose(f(4), np.ones(4, dtype='float32'), check_dtypes=True)
    self.assertEqual(count, 1)

  @unittest.skip('TODO: need typechecking rule for concatenate')
  def test_concatenate(self):
    @partial(jax.jit, abstracted_axes=({0: 'n'},))
    def f(x):  # x: f32[n, 4]
      return jnp.concatenate([x, x, x], axis=0)

    f(np.ones((5, 4), dtype=np.float32))
    # TODO: add assertions

  @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
  def test_reshape(self):
    @partial(jax.jit, abstracted_axes=({0: 'n'},))
    def f(x):  # x: f32[n, 4]
      return jnp.reshape(x, (2, -1))

    f(np.ones((5, 4), dtype=np.float32))
    # TODO: add assertions

  @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
  def test_nested(self):
    @jax.jit
    def nested_f(x):  # f32[h, v] -> f32[h, v]
      # A nested call that needs shape variables
      return jnp.sin(x)

    @partial(jax.jit, abstracted_axes=({0: 'h', 1: 'v'},))
    def f(x):  # f32[h, w] -> f32[h, w]
      return jnp.sin(x) + jax.jit(nested_f)(x)

    f(np.ones((3, 5), dtype=np.float32))
    # TODO: add assertions

  @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
  def test_nested_arange(self):
    def nested_f(x):  # f32[h, v] -> f32[h, v]
      # A nested call that needs to compute with shapes
      return jnp.arange(x.shape[0] * x.shape[1], dtype=x.dtype).reshape(x.shape)

    @partial(jax.jit, abstracted_axes=({0: 'h', 1: 'w'},))
    def f(x):  # f32[h, w] -> f32[h, w]
      return x + jax.jit(nested_f)(x)

    f(np.ones((3, 5), dtype=np.float32))
    # TODO: add assertions

  @unittest.skipIf(jtu.device_under_test() != 'iree', 'iree test')
  def test_transpose(self):
    # see also https://github.com/iree-org/iree-jax/issues/57
    @partial(jax.jit, abstracted_axes=({0: 'h', 1: 'w'},))
    def f(x):  # f32[h, w] -> f32[w, h]
      return x.T

    f(np.ones((3, 5), dtype=np.float32))  # doesn't crash
    # TODO: add assertions

  @unittest.skipIf(jtu.device_under_test() != 'iree', 'iree test')
  def test_matmul(self):
    @partial(jax.jit, abstracted_axes=({0: 'w', 1: 'w'},))
    def f(x):  # f32[w, w] -> f32[w, w]
      return jnp.matmul(x, x)

    f(np.ones((5, 5), dtype=np.float32))
    # TODO: add assertions

  @unittest.skipIf(jtu.device_under_test() != 'iree', 'iree test')
  def test_matmul_shape_error(self):
    @partial(jax.jit, abstracted_axes=({0: 'h', 1: 'w'},))
    def f(x):  # f32[h, w] -> error
      return jnp.matmul(x, x)

    # TODO(necula): improve error message, print actual shapes
    with self.assertRaisesRegex(
        TypeError,
        re.escape("dot_general requires contracting dimensions to have "
                  "the same shape, got")):
      f(np.ones((5, 5), dtype=np.float32))

  @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
  @unittest.skip("TODO: investigate failure")
  def test_cond(self):
    @partial(jax.jit, abstracted_axes=({0: 'w', 1: 'w'},))
    def f(x):  # f32[w, w] -> f32[w, w]
      return lax.cond(True, lambda x: jnp.sin(x), lambda x: jnp.matmul(x, x), x)

    f(np.ones((5, 5), dtype=np.float32))
    # TODO: add assertions

  @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
  def test_arange(self):
    @partial(jax.jit, abstracted_axes=({0: 'w'},))
    def f(x):  # f32[w] -> f32[w]
      return jnp.arange(x.shape[0], dtype=x.dtype) + x

    f(np.ones((5,), dtype=np.float32))
    # TODO: add assertions

  @unittest.skip('failing w/ iree error')
  @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
  def test_broadcast(self):
    @partial(jax.jit, abstracted_axes=({0: 'w'},))
    def f(x):  # f32[w] -> f32[w, w]
      return jnp.broadcast_to(x, (x.shape[0], x.shape[0]))

    f(np.ones((5,), dtype=np.float32))
    # TODO: add assertions

  @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
  def test_zeros(self):
    @partial(jax.jit, abstracted_axes=({0: 'w'},))
    def f(x):  # f32[w] -> f32[w]
      return jnp.zeros(x.shape[0], dtype=x.dtype) + x

    f(np.ones((5,), dtype=np.float32))
    # TODO: add assertions

  @unittest.skip('failing w/ iree error')
  @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
  def test_stack(self):
    @partial(jax.jit, abstracted_axes=({0: 'w'},))
    def f(x):
      return jnp.stack([jnp.sin(x), jnp.cos(x)])

    f(np.ones((5,), dtype=np.float32))
    # TODO: add assertions

  @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
  def test_jit_dependent_pair_output_iree(self):
    # Like the above 'polymorphic output' test, but now with a `2 * n`!
    count = 0

    @jax.jit
    def f(n):
      nonlocal count
      count += 1
      return jnp.arange(2 * n)

    x = f(3)
    y = f(4)
    self.assertAllClose(x, jnp.arange(2 * 3), check_dtypes=False)
    self.assertAllClose(y, jnp.arange(2 * 4), check_dtypes=False)
    self.assertEqual(count, 1)

  @unittest.skip("revising slicing logic")
  def test_slicing_basic(self):
    f = jax.jit(lambda x, n: jnp.sum(x[:n]))

    # TODO(mattjj): revise getslice, add typecheck rule for it, enable checks
    with jax.enable_checks(False):
      ans = f(jnp.arange(10), 3)
    expected = jnp.sum(jnp.arange(10)[:3])
    self.assertAllClose(ans, expected, check_dtypes=True)

  # TODO(mattjj,dougalm,phawkins): debug iree failure, "failed to legalize
  # operation 'while' that was explicitly marked illegal"
  @unittest.skip("revising slicing logic")
  def test_scan_basic(self):
    def cumsum(x):
      def body(i, _):
        return i + 1, jnp.sum(x[:i+1])
      _, ans = lax.scan(body, 0, None, length=len(x))
      return ans

    x = jnp.array([3, 1, 4, 1, 5, 9])
    with jax.enable_checks(False):
      ans = cumsum(x)
    expected = jnp.cumsum(x)
    self.assertAllClose(ans, expected, check_dtypes=False)

  @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
  def test_jit_of_broadcast(self):
    x = jax.jit(jnp.ones)(3)
    self.assertAllClose(x, jnp.ones(3))

  @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
  def test_jit_of_broadcast2(self):
    x = jax.jit(lambda n: jnp.ones(2 * n))(3)
    self.assertAllClose(x, jnp.ones(2 * 3))

  @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
  def test_mlp_autodiff_dynamic_batch_iree(self):
    count = 0

    def predict(params, inputs):
      for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.maximum(0, outputs)
      return outputs

    def loss_ref(params, batch):
      nonlocal count
      count += 1  # count retraces
      inputs, targets = batch
      predictions = predict(params, inputs)
      return jnp.sum((predictions - targets) ** 2)

    loss = jax.jit(loss_ref, abstracted_axes=({}, {0: 'n'}))

    params = [(jnp.ones((784, 256)), jnp.ones(256)),
              (jnp.ones((256, 10)), jnp.ones(10))]

    # two different size batches
    batch1 = (inputs, targets) = (jnp.ones((128, 784)), jnp.ones((128, 10)))
    batch2 = (inputs, targets) = (jnp.ones((32, 784)), jnp.ones((32, 10)))

    _ = loss(params, batch1)
    _ = loss(params, batch2)
    self.assertEqual(count, 1)

    _ = jax.grad(loss)(params, batch1)
    _ = jax.grad(loss)(params, batch2)
    self.assertEqual(count, 2)

    ans = loss(params, batch1)
    expected = loss_ref(params, batch1)
    self.assertAllClose(ans, expected)

    ans = jax.grad(loss)(params, batch1)
    expected = jax.grad(loss_ref)(params, batch1)
    self.assertAllClose(ans, expected)

  @jax.enable_checks(False)  # TODO(mattjj): upgrade typecompat to handle bints
  def test_mlp_autodiff_dynamic_batch_bint(self):
    count = 0

    def predict(params, inputs):
      for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.maximum(0, outputs)
      return outputs

    def loss_ref(params, batch):
      nonlocal count
      count += 1  # count traces
      inputs, targets = batch
      predictions = predict(params, inputs)
      return jnp.sum((predictions - targets) ** 2)

    loss = jax.jit(loss_ref, abstracted_axes=({}, {0: 'n'}))

    params = [(jnp.ones((784, 256)), jnp.ones(256)),
              (jnp.ones((256, 10)), jnp.ones(10))]

    # two different batch sizes *with bints*
    bs1 = jax.lax.convert_element_type(128, core.bint(128))
    batch1 = (jnp.ones((bs1, 784)), jnp.ones((bs1, 10)))
    bs2 = jax.lax.convert_element_type(32, core.bint(128))
    batch2 = (jnp.ones((bs2, 784)), jnp.ones((bs2, 10)))

    # count retraces (and don't crash)
    self.assertEqual(count, 0)
    _ = jax.grad(loss)(params, batch1)
    self.assertEqual(count, 1)
    g2 = jax.grad(loss)(params, batch2)
    self.assertEqual(count, 1)  # cache hit!

    # check the numbers make sense
    batch = (jnp.ones((32, 784)), jnp.ones((32, 10)))
    g2_expected = jax.grad(loss_ref)(params, batch)
    self.assertAllClose(g2, g2_expected, check_dtypes=False,
                        atol=1e-3, rtol=1e-3)

  def test_bint_basic(self):
    d = lax.convert_element_type(3, core.bint(5))
    self.assertEqual(str(d), '3{≤5}')

    @jax.jit
    def f(d):
      jnp.sin(3.)  # don't have an empty jaxpr
      return d

    f(d)  # doesn't crash

  def test_bint_iota(self):
    def f(d):
      return jnp.arange(d, dtype='int32')

    y = f(lax.convert_element_type(3, core.bint(5)))
    self.assertIsInstance(y, core.DArray)
    self.assertAllClose(y._data, np.arange(5), check_dtypes=False)

    d = lax.convert_element_type(3, core.bint(5))
    y = jax.jit(f)(d)
    self.assertIsInstance(y, core.DArray)
    self.assertAllClose(y._data, np.arange(5), check_dtypes=False)

  def test_bint_compilation_cache(self):
    count = 0

    @jax.jit
    def f(n):
      nonlocal count
      count += 1
      return jnp.zeros(n)

    f(lax.convert_element_type(3, core.bint(5)))
    f(lax.convert_element_type(4, core.bint(5)))
    self.assertEqual(count, 1)

  def test_bint_compilation_cache2(self):
    count = 0

    @partial(jax.jit, abstracted_axes=('n',))
    def f(x):
      nonlocal count
      count += 1
      return x.sum()

    d = lax.convert_element_type(3, core.bint(5))
    x = jnp.arange(d)
    y = f(x)
    self.assertEqual(y, 3)
    self.assertEqual(count, 1)

    d = lax.convert_element_type(4, core.bint(5))
    x = jnp.arange(d)
    y = f(x)
    self.assertEqual(y, 6)
    self.assertEqual(count, 1)

    d = lax.convert_element_type(4, core.bint(6))
    x = jnp.arange(d)
    y = f(x)
    self.assertEqual(y, 6)
    self.assertEqual(count, 2)

  @unittest.skip('do we want to support this?')
  def test_bint_add(self):
    d = lax.convert_element_type(4, core.bint(6))
    x = jnp.arange(d)

    @jax.jit
    def f(x):
      return x + x

    f(x)  # doesn't crash

  def test_lower_abstracted_axes(self):
    @partial(jax.jit, abstracted_axes=('n',))
    def f(x):
      return x.sum()

    f_lowered = f.lower(np.arange(3, dtype='int32'))
    mlir_str = f_lowered.compiler_ir()
    self.assertIn('tensor<?xi32>', str(mlir_str))

  def test_lower_abstracted_axes_shapedtypestruct(self):
    @partial(jax.jit, abstracted_axes=('n',))
    def f(x):
      return x.sum()

    f_lowered = f.lower(jax.ShapeDtypeStruct((3,), np.int32))
    mlir_str = f_lowered.compiler_ir()
    self.assertIn('tensor<?xi32>', str(mlir_str))

  def test_slicing_basic_lower(self):
    @partial(jax.jit, abstracted_axes=(None, 'n'))
    def f(x):
      return x[0]

    f.lower(jnp.zeros((3, 4))).compiler_ir()  # doesn't crash

  @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
  def test_slicing_basic_execute(self):
    @partial(jax.jit, abstracted_axes=(None, 'n'))
    def f(x):
      return x[0]

    y = f(jnp.arange(3 * 4).reshape(3, 4))
    self.assertAllClose(y, jnp.array([0, 1, 2, 3]))

  def test_gather_basic_bounded(self):
    x = jnp.arange(3. * 4.).reshape(3, 4)

    def f(i):
      return x[i]

    sz = jax.lax.convert_element_type(2, core.bint(3))
    idx = jnp.arange(sz)
    y = jax.jit(jax.vmap(f), abstracted_axes=('n',))(idx)
    self.assertIsInstance(y, core.DArray)
    self.assertEqual(y.shape, (sz, 4))
    self.assertAllClose(y._data, x)


@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow",
                 jax_disable_jit=True, jax_traceback_filtering='off')
class PileTest(jtu.JaxTestCase):

  def test_internal_pile(self):
    ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
    xs = jax.vmap(lambda n: jax.lax.iota('int32', n).sum())(ins)
    self.assertAllClose(xs, jnp.array([3, 0, 6]), check_dtypes=False)

  def test_make_pile_from_dynamic_shape(self):
    # We may not want to support returning piles from vmapped functions (instead
    # preferring to have a separate API which allows piles). But for now it
    # makes for a convenient way to construct piles for the other tests!
    ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
    p = jax.vmap(partial(jnp.arange, dtype='int32'),
                 out_axes=batching.pile_axis)(ins)
    self.assertIsInstance(p, batching.Pile)
    self.assertRegex(
        str(p.aval),
        r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]')
    data = jax.lax.broadcasted_iota('int32', (3, 5), 1)
    self.assertAllClose(p.data, data, check_dtypes=False)

  def test_pile_map_eltwise(self):
    ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
    p = jax.vmap(partial(jnp.arange, dtype='int32'),
                 out_axes=batching.pile_axis)(ins)
    p = pile_map(lambda x: x ** 2)(p)
    self.assertIsInstance(p, batching.Pile)
    self.assertRegex(
        str(p.aval),
        r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]')
    data = jax.lax.broadcasted_iota('int32', (3, 5), 1) ** 2
    self.assertAllClose(p.data, data, check_dtypes=False)

  def test_pile_map_vector_dot(self):
    ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
    p = jax.vmap(partial(jnp.arange, dtype='int32'),
                 out_axes=batching.pile_axis)(ins)
    y = pile_map(jnp.dot)(p, p)
    self.assertIsInstance(y, batching.Pile)
    self.assertAllClose(y.data, jnp.array([5, 0, 14], dtype='int32'))

  def test_pile_map_matrix_dot_ragged_contract(self):
    sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
    p1 = jax.vmap(lambda n: jnp.ones((7, n)),
                  out_axes=batching.pile_axis)(sizes)
    p2 = jax.vmap(lambda n: jnp.ones((n, 7)),
                  out_axes=batching.pile_axis)(sizes)
    y = jax.vmap(jnp.dot, in_axes=batching.pile_axis, out_axes=0,
                 axis_size=3)(p1, p2)
    self.assertAllClose(y, np.tile(np.array([3, 1, 4])[:, None, None], (7, 7)),
                        check_dtypes=False)

  def test_pile_map_matrix_dot_ragged_tensor(self):
    sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))

    def func(size):
      lhs_one_d = jnp.arange(size, dtype='int32') + 1
      lhs_two_d = jax.lax.broadcast_in_dim(lhs_one_d, (size, 2), (0,))
      rhs = jax.lax.broadcasted_iota('int32', (2, 4), 0) + 1
      return jnp.dot(lhs_two_d, rhs)

    p = jax.vmap(func, out_axes=batching.pile_axis)(sizes)
    self.assertIsInstance(p, batching.Pile)
    self.assertEqual(p.data.shape, (3, 5, 4))

  def test_broadcast_in_dim_while_ragged(self):
    ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))

    def func(size):
      one_d = jnp.arange(size, dtype='int32')
      two_d = jax.lax.broadcast_in_dim(one_d, (size, 7), (0,))
      return two_d

    p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
    self.assertIsInstance(p, batching.Pile)
    data = jax.lax.broadcasted_iota('int32', (3, 5, 7), 1)
    self.assertAllClose(p.data, data)

  def test_broadcast_in_dim_to_ragged(self):
    ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))

    def func(size):
      one_d = jnp.arange(12, dtype='int32')
      two_d = jax.lax.broadcast_in_dim(one_d, (size, 12), (1,))
      return two_d

    p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
    self.assertIsInstance(p, batching.Pile)
    data = jax.lax.broadcasted_iota('int32', (3, 5, 12), 2)
    self.assertAllClose(p.data, data)

  def test_broadcast_in_dim_ragged_to_static_error(self):
    ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))

    def func(size):
      one_d = jnp.arange(size, dtype='int32')
      # Broadcast should error even if the target shape is the same as the
      # underlying data shape, because the semantic size doesn't match.
      two_d = jax.lax.broadcast_in_dim(one_d, (4, 5), (1,))
      return two_d

    msg = r"got operand of shape \(\[dynamic\],\), target broadcast shape \(4, 5\)"
    with self.assertRaisesRegex(TypeError, msg):
      jax.vmap(func, out_axes=batching.pile_axis)(ins)

  def test_broadcast_in_dim_to_doubly_ragged(self):
    ins1 = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
    ins2 = lax.convert_element_type(jnp.array([2, 5, 1]), core.bint(6))

    def func(size1, size2):
      one_d = jnp.arange(size1, dtype='int32')
      two_d = jax.lax.broadcast_in_dim(one_d, (size1, size2), (0,))
      return two_d

    p = jax.vmap(func, out_axes=batching.pile_axis)(ins1, ins2)
    self.assertIsInstance(p, batching.Pile)
    data = jax.lax.broadcasted_iota('int32', (3, 5, 6), 1)
    self.assertAllClose(p.data, data)

  def test_squeeze_ragged(self):
    ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))

    def func(size):
      one_d = jnp.arange(size, dtype='int32')
      two_d = jax.lax.broadcast_in_dim(one_d, (size, 1), (0,))
      one_again = jax.lax.squeeze(two_d, dimensions=[1])
      return one_again

    p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
    self.assertIsInstance(p, batching.Pile)
    data = jax.lax.broadcasted_iota('int32', (3, 5), 1)
    self.assertAllClose(p.data, data)

  def test_broadcast_to_while_ragged(self):
    ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))

    def func(size):
      one_d = jnp.arange(size, dtype='int32')
      two_d = jnp.broadcast_to(one_d, (4, size))
      return two_d

    p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
    self.assertIsInstance(p, batching.Pile)
    data = jax.lax.broadcasted_iota('int32', (3, 4, 5), 2)
    self.assertAllClose(p.data, data)

  def test_broadcast_to_doubly_ragged(self):
    ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))

    def func(size):
      one_d = jnp.arange(size, dtype='int32')
      two_d = jnp.broadcast_to(one_d, (size, size))
      return two_d

    p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
    self.assertIsInstance(p, batching.Pile)
    data = jax.lax.broadcasted_iota('int32', (3, 5, 5), 2)
    self.assertAllClose(p.data, data)

  def test_transpose_ragged(self):
    ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))

    def func(size):
      one_d = jnp.arange(size, dtype='int32')
      two_d = jnp.broadcast_to(one_d, (7, size))
      return jnp.transpose(two_d, [1, 0])

    p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
    self.assertIsInstance(p, batching.Pile)
    data = jax.lax.broadcasted_iota('int32', (3, 5, 7), 1)
    self.assertAllClose(p.data, data)

  def test_einsum_with_ragged_tensor_dimension(self):
    x_sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))

    def fprop_layer(x_size):
      one_d = jnp.arange(x_size, dtype='int32')
      x = jax.lax.broadcast_in_dim(one_d, (x_size, 11), [0])
      wqkv = jax.lax.broadcasted_iota('int32', (3, 2, 7, 11), 1)
      qkv = jnp.einsum('te,ihqe->ithq', x, wqkv)
      return qkv

    p = jax.vmap(fprop_layer, out_axes=batching.pile_axis)(x_sizes)
    self.assertIsInstance(p, batching.Pile)
    self.assertRegex(
        str(p.aval),
        r'Var[0-9]+:3 => i32\[3,bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+,2,7\]')
    self.assertEqual(p.data.shape, (3, 3, 5, 2, 7))

  def test_einsum_with_ragged_tensor_and_contract_dimensions(self):
    ragged_sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))

    def fprop_layer(ragged_size):
      one_d = jnp.arange(ragged_size, dtype='int32')
      alpha = jax.lax.broadcast_in_dim(one_d, (ragged_size, ragged_size, 2), [1])
      v = jax.lax.broadcast_in_dim(one_d, (ragged_size, 2, 7), [0])
      inner = jnp.einsum('tsh,shq->thq', alpha, v)
      return inner

    p = jax.vmap(fprop_layer, out_axes=batching.pile_axis)(ragged_sizes)
    self.assertIsInstance(p, batching.Pile)
    self.assertRegex(
        str(p.aval),
        r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+,2,7\]')
    self.assertEqual(p.data.shape, (3, 5, 2, 7))

  def test_split_while_ragged(self):
    ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))

    def func(size):
      one_d = jnp.arange(size, dtype='int32')
      two_d = jnp.broadcast_to(one_d, (2, size))
      part_1, part_2 = two_d
      return part_1

    p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
    self.assertIsInstance(p, batching.Pile)
    data = jax.lax.broadcasted_iota('int32', (3, 5), 1)
    self.assertAllClose(p.data, data)


def pile_map(f):
  def mapped(*piles):
    return jax.vmap(f, in_axes=batching.pile_axis, out_axes=batching.pile_axis,
                    axis_size=piles[0].aval.length)(*piles)
  return mapped


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())