mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
663 lines
21 KiB
Python
663 lines
21 KiB
Python
# Copyright 2021 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import functools
|
|
|
|
from absl.testing import absltest
|
|
import jax
|
|
from jax import api_util
|
|
import jax.numpy as jnp
|
|
from jax._src import core
|
|
from jax import lax
|
|
from jax._src.pjit import pjit
|
|
from jax._src import linear_util as lu
|
|
from jax._src import test_util as jtu
|
|
from jax._src import ad_checkpoint
|
|
|
|
# Ask the JAX config object to parse its flags through absl when this module
# is loaded (presumably so JAX options can be given on the test command line
# -- confirm against the jax.config documentation).
jax.config.parse_flags_with_absl()
|
|
|
|
def _get_hlo(f):
|
|
def wrapped(*args, **kwargs):
|
|
return jax.jit(f).lower(*args, **kwargs).as_text('hlo', debug_info=True)
|
|
|
|
return wrapped
|
|
|
|
|
|
class NameStackTest(jtu.JaxTestCase):
  """Checks the name stacks stored on jaxpr equations and in lowered HLO text."""

  @staticmethod
  def _name_stack_of(eqn):
    """Renders the name stack held in `eqn.source_info` as a string."""
    return str(eqn.source_info.name_stack)

  def test_trivial_name_stack(self):
    # With no scope opened, every traced equation must report an empty stack.
    def f(x):
      return x + 1

    traced = jax.make_jaxpr(f)(2).jaxpr
    for eqn in traced.eqns:
      self.assertEqual(self._name_stack_of(eqn), '')

  def test_name_call_name_stack(self):
    # `jax.named_call` is expected to put the function's own name on the stack.
    @jax.named_call
    def f(x):
      return x + 1

    traced = jax.make_jaxpr(f)(2).jaxpr
    for eqn in traced.eqns:
      self.assertEqual(self._name_stack_of(eqn), 'f')

  def test_manual_name_stack(self):
    # `jax.named_scope` used as a decorator names every equation of the body.
    @jax.named_scope('foo')
    def f(x):
      return x + 1

    traced = jax.make_jaxpr(f)(2).jaxpr
    for eqn in traced.eqns:
      self.assertEqual(self._name_stack_of(eqn), 'foo')

  def test_nested_name_stack(self):
    # A scope opened inside a decorated function is joined with '/'.
    @jax.named_scope('foo')
    def f(x):
      with jax.named_scope('bar'):
        return x + 1

    traced = jax.make_jaxpr(f)(2).jaxpr
    for eqn in traced.eqns:
      self.assertEqual(self._name_stack_of(eqn), 'foo/bar')

  def test_multiple_name_stack(self):
    # Two sibling scopes: the first equation sits in 'foo', the second in
    # the nested 'bar/baz' scope.
    def f(x):
      with jax.named_scope('foo'):
        y = x + 1
      with jax.named_scope('bar'), jax.named_scope('baz'):
        return y + 1

    eqns = jax.make_jaxpr(f)(2).jaxpr.eqns
    for idx, expected in enumerate(('foo', 'bar/baz')):
      self.assertEqual(self._name_stack_of(eqns[idx]), expected)

  def test_call_primitive_jaxpr_should_not_store_outer_name_stack(self):
    # The inner call jaxpr only records 'bar'; the outer 'foo' prefix is
    # expected to show up once the function is lowered to HLO.
    @jax.named_scope('foo')
    def f(x):
      @jax.named_scope('bar')
      def _f(x):
        return [x + 1]

      dbg = api_util.debug_info("test", _f, (0,), {})
      wrapped_f = lu.wrap_init(_f, debug_info=dbg)
      return core.call(wrapped_f, x)[0]

    traced = jax.make_jaxpr(f)(2).jaxpr
    inner_eqn = traced.eqns[0].params['call_jaxpr'].eqns[0]
    self.assertEqual(self._name_stack_of(inner_eqn), 'bar')

    hlo = _get_hlo(f)(2)
    self.assertIn('foo/jit(core_call)/bar', hlo)

  def test_jit_jaxpr_should_not_store_outer_name_stack(self):
    # The jit equation itself carries 'foo'; the equations of its sub-jaxpr
    # carry only the scope opened inside the jitted function.
    @jax.named_scope('foo')
    def f(x):
      @jax.jit
      @jax.named_scope('bar')
      def _f(x):
        return x + 1

      return _f(x)

    traced = jax.make_jaxpr(f)(2).jaxpr
    jit_eqn = traced.eqns[0]
    self.assertEqual(self._name_stack_of(jit_eqn), 'foo')
    self.assertEqual(
        self._name_stack_of(jit_eqn.params['jaxpr'].eqns[0]), 'bar')

    hlo = _get_hlo(f)(2)
    self.assertIn('foo/jit(_f)/bar', hlo)

  def test_pmap_call_primitive_jaxpr_should_not_store_outer_name_stack(self):
    # Same split as above, but for the pmap call equation and its call_jaxpr.
    @jax.named_scope('foo')
    @jax.pmap
    def f(x):
      with jax.named_scope('bar'):
        return x + 1

    traced = jax.make_jaxpr(f)(jnp.ones(1)).jaxpr
    pmap_eqn = traced.eqns[0]
    self.assertEqual(self._name_stack_of(pmap_eqn), 'foo')
    self.assertEqual(
        self._name_stack_of(pmap_eqn.params['call_jaxpr'].eqns[0]), 'bar')
|
|
|
|
|
|
class NameStackTransformationTest(jtu.JaxTestCase):
  """Checks how vmap, jvp, grad, jit and remat show up in name stacks."""

  @staticmethod
  def _name_stack_of(eqn):
    """Renders the name stack held in `eqn.source_info` as a string."""
    return str(eqn.source_info.name_stack)

  def test_vmap_should_transform_name_stack(self):
    # A scope opened under vmap is reported wrapped as 'vmap(...)'.
    @jax.vmap
    def f(x):
      with jax.named_scope('foo'):
        return x + 1

    traced = jax.make_jaxpr(f)(jnp.ones(2)).jaxpr
    self.assertEqual(self._name_stack_of(traced.eqns[0]), 'vmap(foo)')

  def test_vmap_should_transform_inner_name_stacks(self):
    # Only the scope directly under vmap is wrapped; the outer 'foo' and the
    # deeper 'baz' are left as plain path components.
    @jax.named_scope('foo')
    @jax.vmap
    def f(x):
      with jax.named_scope('bar'), jax.named_scope('baz'):
        return x + 1

    traced = jax.make_jaxpr(f)(jnp.ones(2)).jaxpr
    self.assertEqual(
        self._name_stack_of(traced.eqns[0]), 'foo/vmap(bar)/baz')

  def test_vmap_should_apply_to_call_jaxpr(self):
    # In the jaxpr the jit equation has 'foo' and its body has 'bar'; the
    # vmap wrapper around the jitted call is checked in the HLO text.
    @jax.named_scope('foo')
    @jax.vmap
    def f(x):
      @jax.jit
      @jax.named_scope('bar')
      def _f(x):
        return x + 1

      return _f(x)

    traced = jax.make_jaxpr(f)(jnp.ones(2)).jaxpr
    jit_eqn = traced.eqns[0]
    self.assertEqual(self._name_stack_of(jit_eqn), 'foo')
    self.assertEqual(
        self._name_stack_of(jit_eqn.params['jaxpr'].eqns[0]), 'bar')

    hlo = _get_hlo(f)(jnp.ones(2))
    self.assertIn('foo/vmap(jit(_f))/bar', hlo)

  def test_jvp_should_transform_stacks(self):
    # jvp wraps the first scope opened inside the differentiated function.
    def f(x):
      with jax.named_scope('bar'), jax.named_scope('baz'):
        return jnp.square(x)

    g = jax.named_scope('foo')(lambda x, t: jax.jvp(f, (x,), (t,)))
    traced = jax.make_jaxpr(g)(1., 1.).jaxpr
    self.assertEqual(
        self._name_stack_of(traced.eqns[0]), 'foo/jvp(bar)/baz')

  def test_jvp_should_apply_to_call_jaxpr(self):
    # With a jitted `f`, the jaxpr keeps 'foo' outside and 'bar/baz' inside;
    # the 'jvp(jit(f))' component is checked in the HLO text.
    @jax.jit
    def f(x):
      with jax.named_scope('bar'), jax.named_scope('baz'):
        return jnp.square(x)

    g = jax.named_scope('foo')(lambda x, t: jax.jvp(f, (x,), (t,)))
    traced = jax.make_jaxpr(g)(1., 1.).jaxpr
    jit_eqn = traced.eqns[0]
    self.assertEqual(self._name_stack_of(jit_eqn), 'foo')
    self.assertEqual(
        self._name_stack_of(jit_eqn.params['jaxpr'].eqns[0]), 'bar/baz')

    hlo = _get_hlo(g)(1., 1.)
    self.assertIn('foo/jvp(jit(f))/bar/baz/mul', hlo)

  def test_grad_should_add_jvp_and_transpose_to_name_stack(self):
    # Equations 0-2 are expected under 'jvp(foo)' and equation 4 under
    # 'transpose(jvp(foo))'.
    @jax.value_and_grad
    def f(x):
      with jax.named_scope('foo'):
        return 2 * jnp.sin(x)

    eqns = jax.make_jaxpr(f)(1.).jaxpr.eqns
    for idx in range(3):
      self.assertEqual(self._name_stack_of(eqns[idx]), 'jvp(foo)')
    self.assertEqual(self._name_stack_of(eqns[4]), 'transpose(jvp(foo))')

    hlo = _get_hlo(f)(1.)
    for expected in ('jvp(foo)/sin',
                     'jvp(foo)/cos',
                     'transpose(jvp(foo))/mul'):
      self.assertIn(expected, hlo)

  def test_grad_should_add_jvp_and_transpose_to_call_jaxpr(self):
    # The two outer equations carry the jvp / transpose stacks, while the
    # equations of their sub-jaxprs only carry the inner 'bar' scope.
    @jax.value_and_grad
    @jax.named_scope('foo')
    @jax.jit
    def f(x):
      with jax.named_scope('bar'):
        # NOTE: a `jnp.sin(x)` variant was left disabled here; the test
        # calls `jax.lax.sin` instead.
        return jax.lax.sin(x)

    eqns = jax.make_jaxpr(f)(1.).jaxpr.eqns

    self.assertEqual(self._name_stack_of(eqns[0]), 'jvp(foo)')
    self.assertEqual(self._name_stack_of(eqns[1]), 'transpose(jvp(foo))')
    for outer_idx, inner_idx in ((0, 0), (0, 1), (1, 0)):
      inner_eqn = eqns[outer_idx].params['jaxpr'].eqns[inner_idx]
      self.assertEqual(self._name_stack_of(inner_eqn), 'bar')

    hlo = _get_hlo(f)(1.)
    for expected in ('jvp(foo)/jit(f)/bar/sin',
                     'jvp(foo)/jit(f)/bar/cos',
                     'transpose(jvp(foo))/jit(f)/bar/mul'):
      self.assertIn(expected, hlo)

  def test_nested_jit_stack(self):
    # Nested jits under value_and_grad: only the HLO op names are checked.
    @jax.value_and_grad
    @jax.jit
    def f(x):
      @jax.jit
      def g(y):
        return jnp.sin(y)

      return g(x)

    hlo = _get_hlo(f)(2.)
    for expected in ('jvp(jit(f))/jit(g)/sin',
                     'jvp(jit(f))/jit(g)/cos',
                     'transpose(jvp(jit(f)))/jit(g)/mul'):
      self.assertIn(expected, hlo)

  def test_nested_pjit_stack(self):
    # Same as the nested-jit test, but using `pjit` as the decorator; the
    # expected names still read 'jit(...)'.
    @jax.value_and_grad
    @pjit
    def f(x):
      @pjit
      def g(y):
        return jnp.sin(y)

      return g(x)

    hlo = _get_hlo(f)(2.)
    for expected in ('jvp(jit(f))/jit(g)/sin',
                     'jvp(jit(f))/jit(g)/cos',
                     'transpose(jvp(jit(f)))/jit(g)/mul'):
      self.assertIn(expected, hlo)

  def test_remat_appears_in_hlo(self):
    # The forward lowering must mention 'checkpoint' but not 'remat'; the
    # gradient lowering must mention 'rematted_computation'.
    @ad_checkpoint.remat
    def f(x):
      return jnp.sin(x)

    forward_hlo = _get_hlo(f)(2.)
    grad_hlo = _get_hlo(jax.grad(f))(2.)
    self.assertNotIn('rematted_computation', forward_hlo)
    self.assertNotIn('remat', forward_hlo)
    self.assertIn('checkpoint', forward_hlo)
    self.assertIn('rematted_computation', grad_hlo)
|
|
|
|
|
|
class NameStackControlFlowTest(jtu.JaxTestCase):
  """Checks name stacks for while_loop, cond and scan under transformations."""

  @staticmethod
  def _name_stack_of(eqn):
    """Renders the name stack held in `eqn.source_info` as a string."""
    return str(eqn.source_info.name_stack)

  def _check_while_sub_jaxprs(self, while_eqn):
    """Asserts the body / cond sub-jaxprs start with 'bar' / 'bar_cond'."""
    params = while_eqn.params
    self.assertEqual(
        self._name_stack_of(params['body_jaxpr'].eqns[0]), 'bar')
    self.assertEqual(
        self._name_stack_of(params['cond_jaxpr'].eqns[0]), 'bar_cond')

  def _check_cond_branches(self, cond_eqn):
    """Asserts branch 0 starts in scope 'false' and branch 1 in 'true'."""
    branches = cond_eqn.params['branches']
    self.assertEqual(self._name_stack_of(branches[0].eqns[0]), 'false')
    self.assertEqual(self._name_stack_of(branches[1].eqns[0]), 'true')

  def _check_scan_body(self, scan_eqn):
    """Asserts the first equation of the scan sub-jaxpr is in 'scan_body'."""
    self.assertEqual(
        self._name_stack_of(scan_eqn.params['jaxpr'].eqns[0]), 'scan_body')

  def test_while_loop_body_should_not_have_name_stack(self):
    # The while equation carries 'foo'; its sub-jaxprs only carry the scopes
    # of the body / cond functions. The full path is checked in the HLO text.
    @jax.named_scope('foo')
    def f(x):
      @jax.named_scope('bar')
      def body(x):
        return x + 1

      @jax.named_scope('bar_cond')
      def cond(x):
        return x < 5

      return lax.while_loop(cond, body, x)

    traced = jax.make_jaxpr(f)(0)
    while_eqn = traced.eqns[0]
    self.assertEqual(self._name_stack_of(while_eqn), 'foo')
    self._check_while_sub_jaxprs(while_eqn)

    hlo = _get_hlo(f)(1.)
    for expected in ('foo/while/body/bar', 'foo/while/cond/bar_cond'):
      self.assertIn(expected, hlo)

  def test_vmap_of_while_loop_should_transform_name_stack(self):
    # Under vmap the outer scope becomes 'vmap(foo)'; sub-jaxpr scopes stay
    # unwrapped.
    @jax.vmap
    @jax.named_scope('foo')
    def f(x):
      @jax.named_scope('bar')
      def body(x):
        return x + 1

      @jax.named_scope('bar_cond')
      def cond(x):
        return x < 5

      return lax.while_loop(cond, body, x)

    traced = jax.make_jaxpr(f)(jnp.arange(2))
    while_eqn = traced.eqns[0]
    self.assertEqual(self._name_stack_of(while_eqn), 'vmap(foo)')
    self._check_while_sub_jaxprs(while_eqn)

    hlo = _get_hlo(f)(jnp.arange(2.))
    for expected in ('vmap(foo)/while/body/bar/add',
                     'vmap(foo)/while/cond/bar_cond/lt'):
      self.assertIn(expected, hlo)

  def test_jvp_of_while_loop_transforms_name_stack(self):
    # Under jvp the outer scope becomes 'jvp(foo)'.
    @jax.named_scope('foo')
    def f(x):
      @jax.named_scope('bar')
      def body(x):
        return x + 1.

      @jax.named_scope('bar_cond')
      def cond(x):
        return x < 5.

      return lax.while_loop(cond, body, x)

    g = lambda x, t: jax.jvp(f, (x,), (t,))
    traced = jax.make_jaxpr(g)(1., 1.)
    while_eqn = traced.eqns[0]
    self.assertEqual(self._name_stack_of(while_eqn), 'jvp(foo)')
    self._check_while_sub_jaxprs(while_eqn)

    hlo = _get_hlo(g)(1., 1.)
    for expected in ('jvp(foo)/while/body/bar/add',
                     'jvp(foo)/while/cond/bar_cond/lt'):
      self.assertIn(expected, hlo)

  def test_vmap_of_jvp_of_while_loop_transforms_name_stack(self):
    # vmap of jvp nests the wrappers: 'vmap(jvp(foo))'. In the HLO text the
    # cond scope is looked up under 'body_pred' rather than 'cond'.
    @jax.named_scope('foo')
    def f(x):
      @jax.named_scope('bar')
      def body(x):
        return x + 1.

      @jax.named_scope('bar_cond')
      def cond(x):
        return x < 5.

      return lax.while_loop(cond, body, x)

    g = jax.vmap(lambda x, t: jax.jvp(f, (x,), (t,)))
    traced = jax.make_jaxpr(g)(jnp.arange(2.), jnp.ones(2))
    while_eqn = traced.eqns[0]
    self.assertEqual(self._name_stack_of(while_eqn), 'vmap(jvp(foo))')
    self._check_while_sub_jaxprs(while_eqn)

    hlo = _get_hlo(g)(jnp.arange(2.), jnp.ones(2))
    for expected in ('vmap(jvp(foo))/while/body/bar/add',
                     'vmap(jvp(foo))/while/body_pred/bar_cond'):
      self.assertIn(expected, hlo)

  def test_cond_body_should_not_have_name_stack(self):
    # Every outer equation is in 'foo'; for the cond equation the branches
    # only carry their own 'false' / 'true' scopes.
    @jax.named_scope('foo')
    def f(x, y):
      @jax.named_scope('true')
      def true_fn(x):
        return x + 1

      @jax.named_scope('false')
      def false_fn(x):
        return x - 1

      return lax.cond(y, true_fn, false_fn, x)

    traced = jax.make_jaxpr(f)(0, True)
    for eqn in traced.eqns:
      self.assertEqual(self._name_stack_of(eqn), 'foo')
      if eqn.primitive is lax.cond_p:
        self._check_cond_branches(eqn)

    hlo = _get_hlo(f)(1, True)
    for expected in ('foo/cond/branch_0_fun/false/sub',
                     'foo/cond/branch_1_fun/true/add'):
      self.assertIn(expected, hlo)

  def test_vmap_of_cond_should_transform_name_stack(self):
    # Here the outer stacks are only required to contain 'foo'; the
    # 'vmap(cond)' component is checked in the HLO text.
    @jax.named_scope('foo')
    @functools.partial(jax.vmap, in_axes=(0, None))
    def f(x, y):
      @jax.named_scope('true')
      def true_fn(x):
        return x + 1

      @jax.named_scope('false')
      def false_fn(x):
        return x - 1

      return lax.cond(y, true_fn, false_fn, x)

    traced = jax.make_jaxpr(f)(jnp.arange(2), True)
    for eqn in traced.eqns:
      self.assertIn('foo', self._name_stack_of(eqn))
      if eqn.primitive is lax.cond_p:
        self._check_cond_branches(eqn)

    hlo = _get_hlo(f)(jnp.arange(2.), True)
    for expected in ('foo/vmap(cond)/branch_0_fun/false/sub',
                     'foo/vmap(cond)/branch_1_fun/true/add'):
      self.assertIn(expected, hlo)

  def test_jvp_of_cond_transforms_name_stack(self):
    # `f` is jitted before jvp is applied, so the cond equation is looked up
    # at index 1 of the jit equation's sub-jaxpr.
    @jax.named_scope('foo')
    def f(x, y):
      @jax.named_scope('true')
      def true_fn(x):
        return x + 1

      @jax.named_scope('false')
      def false_fn(x):
        return x - 1

      return lax.cond(y, true_fn, false_fn, x)

    f_ = lambda x: jax.jit(f)(x, True)
    g = lambda x, t: jax.jvp(f_, (x,), (t,))
    traced = jax.make_jaxpr(g)(jnp.arange(2.), jnp.ones(2))
    inner_jaxpr = traced.jaxpr.eqns[0].params['jaxpr']
    cond_eqn = inner_jaxpr.eqns[1]
    self.assertEqual(self._name_stack_of(cond_eqn), 'foo')
    self._check_cond_branches(cond_eqn)

    hlo = _get_hlo(g)(jnp.arange(2.), jnp.ones(2))
    for expected in ('jvp(jit(f))/foo/cond/branch_0_fun/false/sub',
                     'jvp(jit(f))/foo/cond/branch_1_fun/true/add'):
      self.assertIn(expected, hlo)

  def test_vmap_of_jvp_of_cond_transforms_name_stack(self):
    # As above with an extra vmap. The expected HLO substrings end with a
    # double quote, so the op name must end right after 'sub' / 'add'.
    @jax.named_scope('foo')
    def f(x, y):
      @jax.named_scope('true')
      def true_fn(x):
        return x + 1

      @jax.named_scope('false')
      def false_fn(x):
        return x - 1

      return lax.cond(y, true_fn, false_fn, x)

    f_ = lambda x: jax.jit(f)(x, True)
    g = jax.vmap(lambda x, t: jax.jvp(f_, (x,), (t,)))
    traced = jax.make_jaxpr(g)(jnp.arange(2.), jnp.ones(2))

    inner_jaxpr = traced.jaxpr.eqns[0].params['jaxpr']
    self._check_cond_branches(inner_jaxpr.eqns[1])

    hlo = _get_hlo(g)(jnp.arange(2.), jnp.ones(2))
    for expected in ('vmap(jvp(jit(f)))/foo/cond/branch_0_fun/false/sub"',
                     'vmap(jvp(jit(f)))/foo/cond/branch_1_fun/true/add"'):
      self.assertIn(expected, hlo)

  def test_grad_of_cond_transforms_name_stack(self):
    # Equation 1 is expected under the jvp stack and equation 2 under the
    # transposed stack; both branches are then checked in the HLO text.
    @jax.value_and_grad
    @jax.named_scope('foo')
    def f(x, y):
      @jax.named_scope('true')
      def true_fn(x):
        return x * x * 2.

      @jax.named_scope('false')
      def false_fn(x):
        return x / jnp.square(x)

      return lax.cond(y, true_fn, false_fn, x)

    eqns = jax.make_jaxpr(f)(1., True).eqns
    self.assertEqual(self._name_stack_of(eqns[1]), 'jvp(foo)')
    self.assertEqual(self._name_stack_of(eqns[2]), 'transpose(jvp(foo))')

    hlo = _get_hlo(f)(1., True)
    for expected in ('jvp(foo)/cond/branch_0_fun/false/div',
                     'jvp(foo)/cond/branch_1_fun/true/mul',
                     'transpose(jvp(foo))/cond/branch_0_fun/false/div',
                     'transpose(jvp(foo))/cond/branch_1_fun/true/mul'):
      self.assertIn(expected, hlo)

  def test_vmap_of_grad_of_cond_transforms_name_stack(self):
    # Same as the grad-of-cond test with every stack wrapped in 'vmap(...)'.
    @functools.partial(jax.vmap, in_axes=(0, None))
    @jax.value_and_grad
    @jax.named_scope('foo')
    def f(x, y):
      @jax.named_scope('true')
      def true_fn(x):
        return x * x * 2.

      @jax.named_scope('false')
      def false_fn(x):
        return x / x / 2.

      return lax.cond(y, true_fn, false_fn, x)

    eqns = jax.make_jaxpr(f)(jnp.arange(2.), True).eqns
    self.assertEqual(self._name_stack_of(eqns[1]), 'vmap(jvp(foo))')
    self.assertEqual(
        self._name_stack_of(eqns[2]), 'vmap(transpose(jvp(foo)))')

    hlo = _get_hlo(f)(jnp.arange(2.), True)
    for expected in (
        'vmap(jvp(foo))/cond/branch_0_fun/false/div',
        'vmap(jvp(foo))/cond/branch_1_fun/true/mul',
        'vmap(transpose(jvp(foo)))/cond/branch_0_fun/false/div',
        'vmap(transpose(jvp(foo)))/cond/branch_1_fun/true/mul'):
      self.assertIn(expected, hlo)

  def test_scan_body_should_not_have_name_stack(self):
    # Equation 0 carries 'foo'; the scan equation is read at index 1 and its
    # sub-jaxpr only carries 'scan_body'.
    @jax.named_scope('foo')
    def f(x):
      @jax.named_scope('scan_body')
      def body(carry, x):
        return carry + x, carry + x

      return lax.scan(body, x, jnp.arange(5, dtype='float32'))

    eqns = jax.make_jaxpr(f)(jnp.float32(1)).eqns
    self.assertEqual(self._name_stack_of(eqns[0]), 'foo')
    self._check_scan_body(eqns[1])

    hlo = _get_hlo(f)(1.)
    self.assertIn('foo/while/body/scan_body', hlo)

  def test_vmap_of_scan_should_transform_stack(self):
    # Under vmap the outer scope becomes 'vmap(foo)'.
    @jax.vmap
    @jax.named_scope('foo')
    def f(x):
      @jax.named_scope('scan_body')
      def body(carry, x):
        return carry + x, carry + x

      return lax.scan(body, x, jnp.arange(8.))

    eqns = jax.make_jaxpr(f)(jnp.arange(2.)).eqns
    self.assertEqual(self._name_stack_of(eqns[0]), 'vmap(foo)')
    self._check_scan_body(eqns[1])

    hlo = _get_hlo(f)(jnp.arange(2.))
    self.assertIn('vmap(foo)/while/body/scan_body/add', hlo)

  def test_jvp_of_scan_should_transform_stack(self):
    # Under jvp the outer scope becomes 'jvp(foo)'.
    @jax.named_scope('foo')
    def f(x):
      @jax.named_scope('scan_body')
      def body(carry, x):
        return carry + x, carry + x

      return lax.scan(body, x, jnp.arange(8, dtype='float32'))

    g = lambda x, t: jax.jvp(f, (x,), (t,))
    eqns = jax.make_jaxpr(g)(jnp.float32(1), jnp.float32(1)).eqns
    self.assertEqual(self._name_stack_of(eqns[0]), 'jvp(foo)')
    self._check_scan_body(eqns[1])

    hlo = _get_hlo(g)(1., 1.)
    self.assertIn('jvp(foo)/while/body/scan_body/add', hlo)

  def test_grad_of_scan_should_transform_stack(self):
    # Equation 1 is expected under the jvp stack and equation 2 under the
    # transposed stack; the scan body scope is read from equation 1.
    @jax.value_and_grad
    @jax.named_scope('foo')
    def f(x):
      @jax.named_scope('scan_body')
      def body(carry, x):
        return 2 * carry * x, carry + x

      return lax.scan(body, x, jnp.arange(8., dtype='float32'))[0]

    eqns = jax.make_jaxpr(f)(jnp.float32(2)).eqns
    self.assertEqual(self._name_stack_of(eqns[1]), 'jvp(foo)')
    self.assertEqual(self._name_stack_of(eqns[2]), 'transpose(jvp(foo))')
    self._check_scan_body(eqns[1])

    hlo = _get_hlo(f)(1.)
    for expected in ('jvp(foo)/while/body/scan_body/mul',
                     'transpose(jvp(foo))/while/body/scan_body/mul'):
      self.assertIn(expected, hlo)

  def test_vmap_of_grad_of_scan_should_transform_stack(self):
    # Same as the grad-of-scan test with every stack wrapped in 'vmap(...)'.
    @jax.vmap
    @jax.value_and_grad
    @jax.named_scope('foo')
    def f(x):
      @jax.named_scope('scan_body')
      def body(carry, x):
        return carry * x, carry + x

      return lax.scan(body, x, jnp.arange(8.))[0]

    eqns = jax.make_jaxpr(f)(jnp.arange(2.)).eqns
    self.assertEqual(self._name_stack_of(eqns[1]), 'vmap(jvp(foo))')
    self.assertEqual(
        self._name_stack_of(eqns[2]), 'vmap(transpose(jvp(foo)))')
    self._check_scan_body(eqns[1])

    hlo = _get_hlo(f)(jnp.arange(2.))
    for expected in ('vmap(jvp(foo))/while/body/scan_body/mul',
                     'vmap(transpose(jvp(foo)))/while/body/scan_body/mul'):
      self.assertIn(expected, hlo)
|
|
|
|
|
|
# Script entry point: when the module is executed directly, run the tests
# through absltest using the loader provided by JAX's test utilities.
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
|