mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #5717 from google:dynamic-shapes2
PiperOrigin-RevId: 357851603
This commit is contained in:
commit
babf249705
@ -131,3 +131,10 @@ pytype_library(
|
||||
":jax",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "djax",
|
||||
srcs = ["experimental/djax.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [":jax"],
|
||||
)
|
||||
|
14
jax/core.py
14
jax/core.py
@ -860,6 +860,9 @@ class AbstractValue:
|
||||
def update(self, **kwargs):
|
||||
raise NotImplementedError("must override")
|
||||
|
||||
def str_short(self):
|
||||
raise NotImplementedError("must override")
|
||||
|
||||
class Bot(AbstractValue): pass
|
||||
|
||||
bot = Bot()
|
||||
@ -1570,12 +1573,15 @@ def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool:
|
||||
Ignores weak_type and named_shape, other than to check that an axis name isn't
|
||||
used with different sizes.
|
||||
"""
|
||||
aval1 = raise_to_shaped(aval1, weak_type=False)
|
||||
aval2 = raise_to_shaped(aval2, weak_type=False)
|
||||
if aval1 == aval2: return True
|
||||
# unequal avals may still represent the same type, because type is represented
|
||||
# by avals at the shaped level, and because weak type tags and (for now) named
|
||||
# shape components aren't considered part of the type
|
||||
if isinstance(aval1, ShapedArray) and isinstance(aval2, ShapedArray):
|
||||
# check for named shape conflicts
|
||||
# a bonus check for whether any named axes have inconsistent sizes
|
||||
join_named_shapes(aval1.named_shape, aval2.named_shape)
|
||||
return aval1.strip_named_shape() == aval2.strip_named_shape()
|
||||
return (raise_to_shaped(aval1, weak_type=False).strip_named_shape() ==
|
||||
raise_to_shaped(aval2, weak_type=False).strip_named_shape())
|
||||
|
||||
class JaxprTypeError(TypeError): pass
|
||||
|
||||
|
1430
jax/experimental/djax.py
Normal file
1430
jax/experimental/djax.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -293,7 +293,8 @@ class JVPTrace(Trace):
|
||||
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
|
||||
nonzero_tangents, tangent_tree_def = tree_flatten(tangents)
|
||||
nz_tangents = [type(t) is not Zero for t in tangents]
|
||||
params = dict(params, name=wrap_name(params['name'], 'jvp'))
|
||||
if 'name' in params:
|
||||
params = dict(params, name=wrap_name(params['name'], 'jvp'))
|
||||
f_jvp = jvp_subtrace(f, self.main)
|
||||
if isinstance(call_primitive, core.MapPrimitive):
|
||||
in_axes = params['in_axes']
|
||||
|
@ -943,11 +943,11 @@ class DynamicJaxprTracer(core.Tracer):
|
||||
raise core.escaped_tracer_error(self, None)
|
||||
|
||||
class JaxprStackFrame:
|
||||
__slots__ = ['newvar', 'tracer_to_var', 'constid_to_var', 'constvar_to_val',
|
||||
__slots__ = ['gensym', 'tracer_to_var', 'constid_to_var', 'constvar_to_val',
|
||||
'tracers', 'eqns', 'invars']
|
||||
|
||||
def __init__(self):
|
||||
self.newvar = core.gensym()
|
||||
self.gensym = core.gensym()
|
||||
self.tracer_to_var = {}
|
||||
self.constid_to_var = {}
|
||||
self.constvar_to_val = {}
|
||||
@ -964,6 +964,9 @@ class JaxprStackFrame:
|
||||
out_avals = [t.aval for t in out_tracers]
|
||||
return jaxpr, out_avals, constvals
|
||||
|
||||
def newvar(self, aval):
|
||||
return self.gensym(aval)
|
||||
|
||||
def find_progenitors(self, tracer):
|
||||
var = self.tracer_to_var.get(id(tracer))
|
||||
if not var:
|
||||
|
178
tests/djax_test.py
Normal file
178
tests/djax_test.py
Normal file
@ -0,0 +1,178 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest import skipIf
|
||||
|
||||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax.util import safe_map, safe_zip
|
||||
|
||||
from jax.experimental import djax
|
||||
from jax.experimental.djax import (
|
||||
bbarray, ones_like, sin, add, iota, nonzero, reduce_sum, broadcast)
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@skipIf(not jax.config.omnistaging_enabled, "requires omnistaging")
|
||||
class DJaxTests(jtu.JaxTestCase):
|
||||
|
||||
def test_identity_typechecks(self):
|
||||
def f(x):
|
||||
return x
|
||||
x = jnp.array([0, 1])
|
||||
jaxpr, _, _ = djax.make_djaxpr(f, x)
|
||||
djax.typecheck_jaxpr(jaxpr)
|
||||
|
||||
def test_sin_typechecks(self):
|
||||
def f(x):
|
||||
return sin(x)
|
||||
x = bbarray((5,), jnp.arange(3.))
|
||||
jaxpr, _, _ = djax.make_djaxpr(f, x)
|
||||
djax.typecheck_jaxpr(jaxpr)
|
||||
|
||||
def test_sin_and_add_typechecks(self):
|
||||
def f(x):
|
||||
y = sin(x)
|
||||
z = sin(y)
|
||||
return add(y, z)
|
||||
x = bbarray((5,), jnp.arange(3.))
|
||||
jaxpr, _, _ = djax.make_djaxpr(f, x)
|
||||
djax.typecheck_jaxpr(jaxpr)
|
||||
|
||||
def test_iota_typechecks(self):
|
||||
def f():
|
||||
return iota(3)
|
||||
jaxpr, _, _ = djax.make_djaxpr(f)
|
||||
djax.typecheck_jaxpr(jaxpr)
|
||||
|
||||
def test_nonzero_typechecks(self):
|
||||
def f(x):
|
||||
return nonzero(x)
|
||||
x = jnp.array([1, 0, -2, 0, 3, 0])
|
||||
jaxpr, _, _ = djax.make_djaxpr(f, x)
|
||||
djax.typecheck_jaxpr(jaxpr)
|
||||
|
||||
def test_sum_of_nonzero_typechecks(self):
|
||||
def f(x):
|
||||
return reduce_sum(nonzero(x), tuple(range(len(x.shape))))
|
||||
x = jnp.array([1, 0, -2, 0, 3, 0])
|
||||
jaxpr, _, _ = djax.make_djaxpr(f, x)
|
||||
djax.typecheck_jaxpr(jaxpr)
|
||||
|
||||
|
||||
@skipIf(not jax.config.omnistaging_enabled, "requires omnistaging")
|
||||
@skipIf(jax.config.x64_enabled, "only 32bit for now")
|
||||
class DJaxXLATests(jtu.JaxTestCase):
|
||||
|
||||
def test_reduce_sum_of_nonzero(self):
|
||||
@djax.djit
|
||||
def f(x):
|
||||
nonzero_idx = nonzero(x)
|
||||
return reduce_sum(nonzero_idx)
|
||||
|
||||
x = jnp.array([0, 1, 0, 1, 0, 1])
|
||||
ans = f(x)
|
||||
expected = np.sum(np.nonzero(x)[0])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_nonzero(self):
|
||||
@djax.djit
|
||||
def f(x):
|
||||
return nonzero(x)
|
||||
x = jnp.array([0, 1, 0, 1, 0, 1])
|
||||
ans = f(x)
|
||||
expected, = np.nonzero(x)
|
||||
self.assertAllClose(np.array(ans), expected, check_dtypes=False)
|
||||
|
||||
def test_iota(self):
|
||||
@djax.djit
|
||||
def f(i):
|
||||
return iota(i)
|
||||
ans = f(djax.BoundedInt(3, 5))
|
||||
expected = np.arange(3)
|
||||
self.assertAllClose(np.array(ans), expected, check_dtypes=False)
|
||||
|
||||
def test_broadcast(self):
|
||||
@djax.djit
|
||||
def f(x, n):
|
||||
y = nonzero(x)
|
||||
return broadcast(y, n)
|
||||
x = np.arange(3)
|
||||
n = djax.BoundedInt(4, 5)
|
||||
ans = f(x, n)
|
||||
expected = np.broadcast_to(np.nonzero(x)[0], (4, 2))
|
||||
self.assertAllClose(np.array(ans), expected, check_dtypes=False)
|
||||
|
||||
|
||||
@skipIf(not jax.config.omnistaging_enabled, "requires omnistaging")
|
||||
@skipIf(jax.config.x64_enabled, "only 32bit for now")
|
||||
class DJaxADTests(jtu.JaxTestCase):
|
||||
|
||||
def test_jvp(self):
|
||||
@djax.djit
|
||||
def f(x):
|
||||
y = sin(x)
|
||||
return reduce_sum(y, axes=(0,))
|
||||
x = bbarray((5,), jnp.arange(2.))
|
||||
z, z_dot = jax.jvp(f, (x,), (ones_like(x),))
|
||||
|
||||
def g(x):
|
||||
return jnp.sin(x).sum()
|
||||
expected_z, expected_z_dot = jax.jvp(g, (np.arange(2.),), (np.ones(2),))
|
||||
|
||||
self.assertAllClose(np.array(z), expected_z, check_dtypes=False)
|
||||
self.assertAllClose(np.array(z_dot), expected_z_dot, check_dtypes=False)
|
||||
|
||||
def test_linearize(self):
|
||||
@djax.djit
|
||||
def f(x):
|
||||
y = sin(x)
|
||||
return reduce_sum(y, axes=(0,))
|
||||
x = bbarray((5,), jnp.arange(2.))
|
||||
with jax.core.skipping_checks(): # TODO implement dxla_call abs eval rule
|
||||
z, f_lin = jax.linearize(f, x)
|
||||
z_dot = f_lin(ones_like(x))
|
||||
|
||||
def g(x):
|
||||
return jnp.sin(x).sum()
|
||||
expected_z, expected_z_dot = jax.jvp(g, (np.arange(2.),), (np.ones(2),))
|
||||
|
||||
self.assertAllClose(np.array(z), expected_z, check_dtypes=False)
|
||||
self.assertAllClose(np.array(z_dot), expected_z_dot, check_dtypes=False)
|
||||
|
||||
|
||||
@skipIf(not jax.config.omnistaging_enabled, "requires omnistaging")
|
||||
@skipIf(jax.config.x64_enabled, "only 32bit for now")
|
||||
class DJaxBatchingTests(jtu.JaxTestCase):
|
||||
|
||||
def test_nonzero(self):
|
||||
@djax.djit
|
||||
def f(x):
|
||||
return nonzero(x)
|
||||
xs = jnp.array([[0, 1, 0, 1, 0, 1],
|
||||
[1, 1, 1, 1, 0, 1]])
|
||||
jax.vmap(f)(xs) # doesn't crash
|
||||
# TODO check value
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user