mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
[x64] Make autodiff respect weak types
This commit is contained in:
parent
28b3c46b9b
commit
496e400c71
@ -37,8 +37,9 @@ def make_shaped_array(x):
|
||||
return ShapedArray(np.shape(x), dtype)
|
||||
|
||||
def zeros_like_array(x):
|
||||
dtype = dtypes.canonicalize_dtype(dtypes.result_type(x))
|
||||
aval = ShapedArray(np.shape(x), dtype)
|
||||
dtype, weak_type = dtypes._lattice_result_type(x)
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
aval = ShapedArray(np.shape(x), dtype, weak_type=weak_type)
|
||||
return ad_util.zeros_like_aval(aval)
|
||||
|
||||
array_types = {np.ndarray, np.bool_,
|
||||
@ -55,7 +56,8 @@ for t in array_types:
|
||||
core.literalable_types.update(array_types)
|
||||
|
||||
def _zeros_like_python_scalar(t, x):
|
||||
return np.array(0, dtypes.python_scalar_dtypes[t])
|
||||
aval = core.ShapedArray((), dtypes.python_scalar_dtypes[t], weak_type=True)
|
||||
return ad_util.zeros_like_aval(aval)
|
||||
|
||||
def _make_concrete_python_scalar(t, x):
|
||||
return ConcreteArray(
|
||||
|
@ -170,7 +170,7 @@ def checkpoint(fun: Callable, prevent_cse: bool = True,
|
||||
... return z
|
||||
...
|
||||
>>> jax.value_and_grad(g)(2.0)
|
||||
(DeviceArray(0.78907233, dtype=float32, weak_type=True), DeviceArray(-0.2556391, dtype=float32))
|
||||
(DeviceArray(0.78907233, dtype=float32, weak_type=True), DeviceArray(-0.2556391, dtype=float32, weak_type=True))
|
||||
|
||||
Here, the same value is produced whether or not the :func:`jax.checkpoint`
|
||||
decorator is present. When the decorator is not present, the values
|
||||
|
@ -1024,9 +1024,8 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
ans, vjp_py, aux = _vjp(
|
||||
f_partial, *dyn_args, has_aux=True, reduce_axes=reduce_axes)
|
||||
_check_scalar(ans)
|
||||
dtype = _dtype(ans)
|
||||
tree_map(partial(_check_output_dtype_grad, holomorphic), ans)
|
||||
g = vjp_py(np.ones((), dtype=dtype))
|
||||
g = vjp_py(jax.lax._one(ans))
|
||||
g = g[0] if isinstance(argnums, int) else g
|
||||
if not has_aux:
|
||||
return ans, g
|
||||
@ -1277,26 +1276,28 @@ def _std_basis(pytree):
|
||||
|
||||
def _jacfwd_unravel(input_pytree, output_pytree_leaf, arr):
|
||||
return _unravel_array_into_pytree(
|
||||
input_pytree, -1, _dtype(output_pytree_leaf), arr)
|
||||
input_pytree, -1, output_pytree_leaf, arr)
|
||||
|
||||
def _jacrev_unravel(output_pytree, input_pytree_leaf, arr):
|
||||
return _unravel_array_into_pytree(
|
||||
output_pytree, 0, _dtype(input_pytree_leaf), arr)
|
||||
output_pytree, 0, input_pytree_leaf, arr)
|
||||
|
||||
def _possible_downcast(x, dtype):
|
||||
def _possible_downcast(x, example):
|
||||
if (dtypes.issubdtype(x.dtype, np.complexfloating) and
|
||||
not dtypes.issubdtype(dtype, np.complexfloating)):
|
||||
not dtypes.issubdtype(_dtype(example), np.complexfloating)):
|
||||
x = x.real
|
||||
return x.astype(dtype)
|
||||
dtype = None if example is None else _dtype(example)
|
||||
weak_type = None if example is None else dtypes.is_weakly_typed(example)
|
||||
return jax._src.lax.lax._convert_element_type(x, dtype, weak_type)
|
||||
|
||||
def _unravel_array_into_pytree(pytree, axis, cast_to_type, arr):
|
||||
def _unravel_array_into_pytree(pytree, axis, example, arr):
|
||||
"""Unravel an array into a PyTree with a given structure.
|
||||
Args:
|
||||
pytree: The pytree that provides the structure.
|
||||
axis: The parameter axis is either -1, 0, or 1. It controls the
|
||||
resulting shapes.
|
||||
cast_to_type: Cast the components to the given dtype, or else use the
|
||||
pytree leaf type if cast_to_type is None.
|
||||
example: If specified, cast the components to the matching dtype/weak_type,
|
||||
or else use the pytree leaf type if example is None.
|
||||
arr: The array to be unraveled.
|
||||
"""
|
||||
leaves, treedef = tree_flatten(pytree)
|
||||
@ -1304,8 +1305,7 @@ def _unravel_array_into_pytree(pytree, axis, cast_to_type, arr):
|
||||
shapes = [arr.shape[:axis] + np.shape(l) + arr.shape[axis+1:] for l in leaves]
|
||||
parts = _split(arr, np.cumsum(map(np.size, leaves[:-1])), axis)
|
||||
reshaped_parts = [
|
||||
_possible_downcast(np.reshape(x, shape),
|
||||
_dtype(leaf) if cast_to_type is None else cast_to_type)
|
||||
_possible_downcast(np.reshape(x, shape), leaf if example is None else example)
|
||||
for x, shape, leaf in zip(parts, shapes, leaves)]
|
||||
return tree_unflatten(treedef, reshaped_parts)
|
||||
|
||||
@ -2994,7 +2994,7 @@ def checkpoint(fun: Callable, concrete: bool = False, prevent_cse: bool = True,
|
||||
... return z
|
||||
...
|
||||
>>> jax.value_and_grad(g)(2.0)
|
||||
(DeviceArray(0.78907233, dtype=float32, weak_type=True), DeviceArray(-0.2556391, dtype=float32))
|
||||
(DeviceArray(0.78907233, dtype=float32, weak_type=True), DeviceArray(-0.2556391, dtype=float32, weak_type=True))
|
||||
|
||||
Here, the same value is produced whether or not the :func:`jax.checkpoint`
|
||||
decorator is present. When the decorator is not present, the values
|
||||
|
@ -1520,11 +1520,10 @@ def _device_put_raw(x, weak_type=None):
|
||||
|
||||
def zeros_like_shaped_array(aval):
|
||||
assert isinstance(aval, ShapedArray)
|
||||
scalar_zero = np.array(0).astype(aval.dtype)
|
||||
if scalar_zero.dtype != aval.dtype:
|
||||
# For numpy 1.17.5 we get here for float0. We use an alternate construction.
|
||||
assert aval.dtype == dtypes.float0
|
||||
if aval.dtype == dtypes.float0:
|
||||
scalar_zero = np.zeros((), dtype=aval.dtype)
|
||||
else:
|
||||
scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type)
|
||||
return broadcast(scalar_zero, aval.shape)
|
||||
|
||||
ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
|
||||
@ -1586,13 +1585,13 @@ def stop_gradient(x):
|
||||
For example:
|
||||
|
||||
>>> jax.grad(lambda x: x**2)(3.)
|
||||
DeviceArray(6., dtype=float32)
|
||||
DeviceArray(6., dtype=float32, weak_type=True)
|
||||
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
|
||||
DeviceArray(0., dtype=float32)
|
||||
DeviceArray(0., dtype=float32, weak_type=True)
|
||||
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
|
||||
DeviceArray(2., dtype=float32)
|
||||
DeviceArray(2., dtype=float32, weak_type=True)
|
||||
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
|
||||
DeviceArray(0., dtype=float32)
|
||||
DeviceArray(0., dtype=float32, weak_type=True)
|
||||
"""
|
||||
def stop(x):
|
||||
if (dtypes.issubdtype(_dtype(x), np.floating) or
|
||||
|
@ -1177,9 +1177,9 @@ class ConcreteArray(ShapedArray):
|
||||
__slots__ = ['val']
|
||||
array_abstraction_level = 0
|
||||
|
||||
def __init__(self, val, weak_type=False):
|
||||
def __init__(self, val, weak_type=None):
|
||||
super().__init__(np.shape(val), np.result_type(val),
|
||||
weak_type=weak_type)
|
||||
weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type)
|
||||
# Note: canonicalized self.dtype doesn't necessarily match self.val
|
||||
self.val = val
|
||||
assert self.dtype != np.dtype('O'), val
|
||||
|
@ -40,8 +40,8 @@ import concurrent.futures
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import float0, jit, grad, device_put, jacfwd, jacrev, hessian
|
||||
from jax import core, dtypes, lax
|
||||
from jax._src import api
|
||||
from jax import core, lax
|
||||
from jax._src import api, dtypes
|
||||
from jax.core import Primitive
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax.interpreters import ad
|
||||
@ -847,6 +847,13 @@ class APITest(jtu.JaxTestCase):
|
||||
assert g(2.0) == 4.0
|
||||
assert len(side) == 1
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{transform.__name__}", "transform": transform}
|
||||
for transform in [grad, jacfwd, jacrev])
|
||||
def test_ad_weak_types(self, transform):
|
||||
out = transform(lambda x: x)(1.0)
|
||||
self.assertTrue(dtypes.is_weakly_typed(out))
|
||||
|
||||
def test_bad_input(self):
|
||||
def f(x):
|
||||
return x
|
||||
@ -2325,25 +2332,31 @@ class APITest(jtu.JaxTestCase):
|
||||
if not hasattr(self, "assertLogs"):
|
||||
raise unittest.SkipTest("test requires assertLogs (python 3)")
|
||||
|
||||
lax.add(1, 2) # make sure some initial warnings are already printed
|
||||
# make sure some initial warnings & cached operations already happen.
|
||||
api.grad(api.jit(lambda x: x))(1.0)
|
||||
|
||||
sin = api.jit(jnp.sin)
|
||||
@api.jit
|
||||
def f(x):
|
||||
return jnp.sin(x)
|
||||
|
||||
prev_level = logging.get_verbosity()
|
||||
try:
|
||||
logging.set_verbosity('DEBUG')
|
||||
with self.assertLogs(level=logging.DEBUG) as l:
|
||||
ans1 = api.grad(sin)(2.)
|
||||
ans2 = api.grad(sin)(3.)
|
||||
ans1 = api.grad(f)(2.)
|
||||
ans2 = api.grad(f)(3.)
|
||||
finally:
|
||||
logging.set_verbosity(prev_level)
|
||||
self.assertLen(l.output, 2)
|
||||
|
||||
self.assertLen(l.output, 2) # one for fwd, one for bwd
|
||||
self.assertAllClose(ans1, np.cos(2.), check_dtypes=False)
|
||||
self.assertAllClose(ans2, np.cos(3.), check_dtypes=False)
|
||||
|
||||
def test_grad_of_jit_compilation_caching2(self):
|
||||
# Like the above test, but instead of logging use our compile counters.
|
||||
|
||||
# make sure some initial convert element type operations are pre-cached.
|
||||
api.grad(api.jit(lambda x: x))(1.0)
|
||||
|
||||
@api.jit
|
||||
def f(x):
|
||||
return jnp.sin(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user