Merge pull request #5717 from google:dynamic-shapes2

PiperOrigin-RevId: 357851603
This commit is contained in:
jax authors 2021-02-16 18:45:13 -08:00
commit babf249705
6 changed files with 1632 additions and 7 deletions

View File

@ -131,3 +131,10 @@ pytype_library(
":jax",
],
)
pytype_library(
name = "djax",
srcs = ["experimental/djax.py"],
srcs_version = "PY3",
deps = [":jax"],
)

View File

@ -860,6 +860,9 @@ class AbstractValue:
def update(self, **kwargs):
raise NotImplementedError("must override")
def str_short(self):
raise NotImplementedError("must override")
class Bot(AbstractValue): pass
bot = Bot()
@ -1570,12 +1573,15 @@ def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool:
Ignores weak_type and named_shape, other than to check that an axis name isn't
used with different sizes.
"""
aval1 = raise_to_shaped(aval1, weak_type=False)
aval2 = raise_to_shaped(aval2, weak_type=False)
if aval1 == aval2: return True
# unequal avals may still represent the same type, because type is represented
# by avals at the shaped level, and because weak type tags and (for now) named
# shape components aren't considered part of the type
if isinstance(aval1, ShapedArray) and isinstance(aval2, ShapedArray):
# check for named shape conflicts
# a bonus check for whether any named axes have inconsistent sizes
join_named_shapes(aval1.named_shape, aval2.named_shape)
return aval1.strip_named_shape() == aval2.strip_named_shape()
return (raise_to_shaped(aval1, weak_type=False).strip_named_shape() ==
raise_to_shaped(aval2, weak_type=False).strip_named_shape())
class JaxprTypeError(TypeError): pass

1430
jax/experimental/djax.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -293,7 +293,8 @@ class JVPTrace(Trace):
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
nonzero_tangents, tangent_tree_def = tree_flatten(tangents)
nz_tangents = [type(t) is not Zero for t in tangents]
params = dict(params, name=wrap_name(params['name'], 'jvp'))
if 'name' in params:
params = dict(params, name=wrap_name(params['name'], 'jvp'))
f_jvp = jvp_subtrace(f, self.main)
if isinstance(call_primitive, core.MapPrimitive):
in_axes = params['in_axes']

View File

@ -943,11 +943,11 @@ class DynamicJaxprTracer(core.Tracer):
raise core.escaped_tracer_error(self, None)
class JaxprStackFrame:
__slots__ = ['newvar', 'tracer_to_var', 'constid_to_var', 'constvar_to_val',
__slots__ = ['gensym', 'tracer_to_var', 'constid_to_var', 'constvar_to_val',
'tracers', 'eqns', 'invars']
def __init__(self):
self.newvar = core.gensym()
self.gensym = core.gensym()
self.tracer_to_var = {}
self.constid_to_var = {}
self.constvar_to_val = {}
@ -964,6 +964,9 @@ class JaxprStackFrame:
out_avals = [t.aval for t in out_tracers]
return jaxpr, out_avals, constvals
def newvar(self, aval):
return self.gensym(aval)
def find_progenitors(self, tracer):
var = self.tracer_to_var.get(id(tracer))
if not var:

178
tests/djax_test.py Normal file
View File

@ -0,0 +1,178 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import skipIf
from absl.testing import absltest
import numpy as np
import jax
import jax.numpy as jnp
from jax import test_util as jtu
from jax.util import safe_map, safe_zip
from jax.experimental import djax
from jax.experimental.djax import (
bbarray, ones_like, sin, add, iota, nonzero, reduce_sum, broadcast)
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
jax.config.parse_flags_with_absl()
@skipIf(not jax.config.omnistaging_enabled, "requires omnistaging")
class DJaxTests(jtu.JaxTestCase):
def test_identity_typechecks(self):
def f(x):
return x
x = jnp.array([0, 1])
jaxpr, _, _ = djax.make_djaxpr(f, x)
djax.typecheck_jaxpr(jaxpr)
def test_sin_typechecks(self):
def f(x):
return sin(x)
x = bbarray((5,), jnp.arange(3.))
jaxpr, _, _ = djax.make_djaxpr(f, x)
djax.typecheck_jaxpr(jaxpr)
def test_sin_and_add_typechecks(self):
def f(x):
y = sin(x)
z = sin(y)
return add(y, z)
x = bbarray((5,), jnp.arange(3.))
jaxpr, _, _ = djax.make_djaxpr(f, x)
djax.typecheck_jaxpr(jaxpr)
def test_iota_typechecks(self):
def f():
return iota(3)
jaxpr, _, _ = djax.make_djaxpr(f)
djax.typecheck_jaxpr(jaxpr)
def test_nonzero_typechecks(self):
def f(x):
return nonzero(x)
x = jnp.array([1, 0, -2, 0, 3, 0])
jaxpr, _, _ = djax.make_djaxpr(f, x)
djax.typecheck_jaxpr(jaxpr)
def test_sum_of_nonzero_typechecks(self):
def f(x):
return reduce_sum(nonzero(x), tuple(range(len(x.shape))))
x = jnp.array([1, 0, -2, 0, 3, 0])
jaxpr, _, _ = djax.make_djaxpr(f, x)
djax.typecheck_jaxpr(jaxpr)
@skipIf(not jax.config.omnistaging_enabled, "requires omnistaging")
@skipIf(jax.config.x64_enabled, "only 32bit for now")
class DJaxXLATests(jtu.JaxTestCase):
def test_reduce_sum_of_nonzero(self):
@djax.djit
def f(x):
nonzero_idx = nonzero(x)
return reduce_sum(nonzero_idx)
x = jnp.array([0, 1, 0, 1, 0, 1])
ans = f(x)
expected = np.sum(np.nonzero(x)[0])
self.assertAllClose(ans, expected, check_dtypes=False)
def test_nonzero(self):
@djax.djit
def f(x):
return nonzero(x)
x = jnp.array([0, 1, 0, 1, 0, 1])
ans = f(x)
expected, = np.nonzero(x)
self.assertAllClose(np.array(ans), expected, check_dtypes=False)
def test_iota(self):
@djax.djit
def f(i):
return iota(i)
ans = f(djax.BoundedInt(3, 5))
expected = np.arange(3)
self.assertAllClose(np.array(ans), expected, check_dtypes=False)
def test_broadcast(self):
@djax.djit
def f(x, n):
y = nonzero(x)
return broadcast(y, n)
x = np.arange(3)
n = djax.BoundedInt(4, 5)
ans = f(x, n)
expected = np.broadcast_to(np.nonzero(x)[0], (4, 2))
self.assertAllClose(np.array(ans), expected, check_dtypes=False)
@skipIf(not jax.config.omnistaging_enabled, "requires omnistaging")
@skipIf(jax.config.x64_enabled, "only 32bit for now")
class DJaxADTests(jtu.JaxTestCase):
def test_jvp(self):
@djax.djit
def f(x):
y = sin(x)
return reduce_sum(y, axes=(0,))
x = bbarray((5,), jnp.arange(2.))
z, z_dot = jax.jvp(f, (x,), (ones_like(x),))
def g(x):
return jnp.sin(x).sum()
expected_z, expected_z_dot = jax.jvp(g, (np.arange(2.),), (np.ones(2),))
self.assertAllClose(np.array(z), expected_z, check_dtypes=False)
self.assertAllClose(np.array(z_dot), expected_z_dot, check_dtypes=False)
def test_linearize(self):
@djax.djit
def f(x):
y = sin(x)
return reduce_sum(y, axes=(0,))
x = bbarray((5,), jnp.arange(2.))
with jax.core.skipping_checks(): # TODO implement dxla_call abs eval rule
z, f_lin = jax.linearize(f, x)
z_dot = f_lin(ones_like(x))
def g(x):
return jnp.sin(x).sum()
expected_z, expected_z_dot = jax.jvp(g, (np.arange(2.),), (np.ones(2),))
self.assertAllClose(np.array(z), expected_z, check_dtypes=False)
self.assertAllClose(np.array(z_dot), expected_z_dot, check_dtypes=False)
@skipIf(not jax.config.omnistaging_enabled, "requires omnistaging")
@skipIf(jax.config.x64_enabled, "only 32bit for now")
class DJaxBatchingTests(jtu.JaxTestCase):
def test_nonzero(self):
@djax.djit
def f(x):
return nonzero(x)
xs = jnp.array([[0, 1, 0, 1, 0, 1],
[1, 1, 1, 1, 0, 1]])
jax.vmap(f)(xs) # doesn't crash
# TODO check value
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())