[x64] Make autodiff respect weak types

This commit is contained in:
Jake VanderPlas 2021-11-23 15:04:08 -08:00
parent 28b3c46b9b
commit 496e400c71
6 changed files with 49 additions and 35 deletions

View File

@ -37,8 +37,9 @@ def make_shaped_array(x):
return ShapedArray(np.shape(x), dtype)
def zeros_like_array(x):
dtype = dtypes.canonicalize_dtype(dtypes.result_type(x))
aval = ShapedArray(np.shape(x), dtype)
dtype, weak_type = dtypes._lattice_result_type(x)
dtype = dtypes.canonicalize_dtype(dtype)
aval = ShapedArray(np.shape(x), dtype, weak_type=weak_type)
return ad_util.zeros_like_aval(aval)
array_types = {np.ndarray, np.bool_,
@ -55,7 +56,8 @@ for t in array_types:
core.literalable_types.update(array_types)
def _zeros_like_python_scalar(t, x):
return np.array(0, dtypes.python_scalar_dtypes[t])
aval = core.ShapedArray((), dtypes.python_scalar_dtypes[t], weak_type=True)
return ad_util.zeros_like_aval(aval)
def _make_concrete_python_scalar(t, x):
return ConcreteArray(

View File

@ -170,7 +170,7 @@ def checkpoint(fun: Callable, prevent_cse: bool = True,
... return z
...
>>> jax.value_and_grad(g)(2.0)
(DeviceArray(0.78907233, dtype=float32, weak_type=True), DeviceArray(-0.2556391, dtype=float32))
(DeviceArray(0.78907233, dtype=float32, weak_type=True), DeviceArray(-0.2556391, dtype=float32, weak_type=True))
Here, the same value is produced whether or not the :func:`jax.checkpoint`
decorator is present. When the decorator is not present, the values

View File

@ -1024,9 +1024,8 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
ans, vjp_py, aux = _vjp(
f_partial, *dyn_args, has_aux=True, reduce_axes=reduce_axes)
_check_scalar(ans)
dtype = _dtype(ans)
tree_map(partial(_check_output_dtype_grad, holomorphic), ans)
g = vjp_py(np.ones((), dtype=dtype))
g = vjp_py(jax.lax._one(ans))
g = g[0] if isinstance(argnums, int) else g
if not has_aux:
return ans, g
@ -1277,26 +1276,28 @@ def _std_basis(pytree):
def _jacfwd_unravel(input_pytree, output_pytree_leaf, arr):
return _unravel_array_into_pytree(
input_pytree, -1, _dtype(output_pytree_leaf), arr)
input_pytree, -1, output_pytree_leaf, arr)
def _jacrev_unravel(output_pytree, input_pytree_leaf, arr):
return _unravel_array_into_pytree(
output_pytree, 0, _dtype(input_pytree_leaf), arr)
output_pytree, 0, input_pytree_leaf, arr)
def _possible_downcast(x, dtype):
def _possible_downcast(x, example):
if (dtypes.issubdtype(x.dtype, np.complexfloating) and
not dtypes.issubdtype(dtype, np.complexfloating)):
not dtypes.issubdtype(_dtype(example), np.complexfloating)):
x = x.real
return x.astype(dtype)
dtype = None if example is None else _dtype(example)
weak_type = None if example is None else dtypes.is_weakly_typed(example)
return jax._src.lax.lax._convert_element_type(x, dtype, weak_type)
def _unravel_array_into_pytree(pytree, axis, cast_to_type, arr):
def _unravel_array_into_pytree(pytree, axis, example, arr):
"""Unravel an array into a PyTree with a given structure.
Args:
pytree: The pytree that provides the structure.
axis: The parameter axis is either -1, 0, or 1. It controls the
resulting shapes.
cast_to_type: Cast the components to the given dtype, or else use the
pytree leaf type if cast_to_type is None.
example: If not None, cast every component to the dtype and weak_type of
this value; otherwise cast each component to match its own pytree leaf.
arr: The array to be unraveled.
"""
leaves, treedef = tree_flatten(pytree)
@ -1304,8 +1305,7 @@ def _unravel_array_into_pytree(pytree, axis, cast_to_type, arr):
shapes = [arr.shape[:axis] + np.shape(l) + arr.shape[axis+1:] for l in leaves]
parts = _split(arr, np.cumsum(map(np.size, leaves[:-1])), axis)
reshaped_parts = [
_possible_downcast(np.reshape(x, shape),
_dtype(leaf) if cast_to_type is None else cast_to_type)
_possible_downcast(np.reshape(x, shape), leaf if example is None else example)
for x, shape, leaf in zip(parts, shapes, leaves)]
return tree_unflatten(treedef, reshaped_parts)
@ -2994,7 +2994,7 @@ def checkpoint(fun: Callable, concrete: bool = False, prevent_cse: bool = True,
... return z
...
>>> jax.value_and_grad(g)(2.0)
(DeviceArray(0.78907233, dtype=float32, weak_type=True), DeviceArray(-0.2556391, dtype=float32))
(DeviceArray(0.78907233, dtype=float32, weak_type=True), DeviceArray(-0.2556391, dtype=float32, weak_type=True))
Here, the same value is produced whether or not the :func:`jax.checkpoint`
decorator is present. When the decorator is not present, the values

View File

@ -1520,11 +1520,10 @@ def _device_put_raw(x, weak_type=None):
def zeros_like_shaped_array(aval):
assert isinstance(aval, ShapedArray)
scalar_zero = np.array(0).astype(aval.dtype)
if scalar_zero.dtype != aval.dtype:
# For numpy 1.17.5 we get here for float0. We use an alternate construction.
assert aval.dtype == dtypes.float0
if aval.dtype == dtypes.float0:
scalar_zero = np.zeros((), dtype=aval.dtype)
else:
scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type)
return broadcast(scalar_zero, aval.shape)
ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
@ -1586,13 +1585,13 @@ def stop_gradient(x):
For example:
>>> jax.grad(lambda x: x**2)(3.)
DeviceArray(6., dtype=float32)
DeviceArray(6., dtype=float32, weak_type=True)
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
DeviceArray(0., dtype=float32)
DeviceArray(0., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
DeviceArray(2., dtype=float32)
DeviceArray(2., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
DeviceArray(0., dtype=float32)
DeviceArray(0., dtype=float32, weak_type=True)
"""
def stop(x):
if (dtypes.issubdtype(_dtype(x), np.floating) or

View File

@ -1177,9 +1177,9 @@ class ConcreteArray(ShapedArray):
__slots__ = ['val']
array_abstraction_level = 0
def __init__(self, val, weak_type=False):
def __init__(self, val, weak_type=None):
super().__init__(np.shape(val), np.result_type(val),
weak_type=weak_type)
weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type)
# Note: canonicalized self.dtype doesn't necessarily match self.val
self.val = val
assert self.dtype != np.dtype('O'), val

View File

@ -40,8 +40,8 @@ import concurrent.futures
import jax
import jax.numpy as jnp
from jax import float0, jit, grad, device_put, jacfwd, jacrev, hessian
from jax import core, dtypes, lax
from jax._src import api
from jax import core, lax
from jax._src import api, dtypes
from jax.core import Primitive
from jax.errors import UnexpectedTracerError
from jax.interpreters import ad
@ -847,6 +847,13 @@ class APITest(jtu.JaxTestCase):
assert g(2.0) == 4.0
assert len(side) == 1
@parameterized.named_parameters(
    {"testcase_name": f"_{ad_transform.__name__}", "transform": ad_transform}
    for ad_transform in (grad, jacfwd, jacrev))
def test_ad_weak_types(self, transform):
  """Check that differentiating at a Python float gives a weakly typed result.

  The identity function is transformed with each of ``grad``, ``jacfwd`` and
  ``jacrev`` and evaluated at the Python scalar ``1.0``; the output must be
  reported as weakly typed by ``dtypes.is_weakly_typed``.
  """
  identity = lambda x: x
  derivative = transform(identity)(1.0)
  self.assertTrue(dtypes.is_weakly_typed(derivative))
def test_bad_input(self):
def f(x):
return x
@ -2325,25 +2332,31 @@ class APITest(jtu.JaxTestCase):
if not hasattr(self, "assertLogs"):
raise unittest.SkipTest("test requires assertLogs (python 3)")
lax.add(1, 2) # make sure some initial warnings are already printed
# Make sure initial warnings are already printed and cacheable operations are already cached.
api.grad(api.jit(lambda x: x))(1.0)
sin = api.jit(jnp.sin)
@api.jit
def f(x):
return jnp.sin(x)
prev_level = logging.get_verbosity()
try:
logging.set_verbosity('DEBUG')
with self.assertLogs(level=logging.DEBUG) as l:
ans1 = api.grad(sin)(2.)
ans2 = api.grad(sin)(3.)
ans1 = api.grad(f)(2.)
ans2 = api.grad(f)(3.)
finally:
logging.set_verbosity(prev_level)
self.assertLen(l.output, 2)
self.assertLen(l.output, 2) # one for fwd, one for bwd
self.assertAllClose(ans1, np.cos(2.), check_dtypes=False)
self.assertAllClose(ans2, np.cos(3.), check_dtypes=False)
def test_grad_of_jit_compilation_caching2(self):
# Like the above test, but instead of logging use our compile counters.
# Make sure the initial convert_element_type operations are pre-cached.
api.grad(api.jit(lambda x: x))(1.0)
@api.jit
def f(x):
return jnp.sin(x)