Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
remove some trailing whitespace (#3287)
This commit is contained in:
parent ea4277b030
commit c42a7f7890
@@ -18,7 +18,7 @@ This script contains a JAX implementation of Differentially Private Stochastic
Gradient Descent (https://arxiv.org/abs/1607.00133). DPSGD requires clipping
the per-example parameter gradients, which is non-trivial to implement
efficiently for convolutional neural networks. The JAX XLA compiler shines in
this setting by optimizing the minibatch-vectorized computation for
convolutional architectures. Train time takes a few seconds per epoch on a
commodity GPU.
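The pattern behind this example is per-example gradients via `vmap` over `grad`, followed by per-example clipping. A minimal sketch with a toy linear loss (not the example's actual convnet or training loop; the DP noise step is omitted):

import jax
import jax.numpy as jnp

def loss(params, example):
  # Toy linear model standing in for the convnet used in the real example.
  x, y = example
  return (jnp.dot(x, params) - y) ** 2

def clipped_mean_grad(params, batch, l2_clip=1.0):
  # vmap(grad(...)) gives one gradient per example; clip each to l2_clip
  # before averaging, which is the step that is hard to do efficiently
  # without minibatch vectorization.
  grads = jax.vmap(jax.grad(loss), in_axes=(None, 0))(params, batch)
  norms = jnp.sqrt(jnp.sum(grads ** 2, axis=1))
  scale = jnp.minimum(1.0, l2_clip / (norms + 1e-12))
  return jnp.mean(grads * scale[:, None], axis=0)

params = jnp.zeros(3)
batch = (jnp.ones((8, 3)), jnp.ones(8))
g = jax.jit(clipped_mean_grad)(params, batch)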
@@ -342,11 +342,11 @@ positional arguments and parameters:
* nr_untapped: how many positional arguments (from the tail) should not be
  passed to the tap function.
* arg_treedef: the treedef of the tapped positional arguments.
* transforms: a tuple of the transformations that have been applied. Each
  element of the tuple is itself a tuple with the first element the name
  of the transform. The remaining elements depend on the transform. For
  example, for `batch`, the parameters are the dimensions that have been
  batched, and for `mask` the logical shapes. These are unpacked by
  _ConsumerCallable before passing to the user function.

* the remaining parameters are passed to the tap function.
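The `arg_treedef` mentioned above is an ordinary JAX pytree structure descriptor. A small illustration of how such a treedef is produced and used to rebuild flattened arguments (this uses the general pytree API, not the host_callback internals):

import jax
import jax.numpy as jnp

# Flatten a tuple of (hypothetical) tapped positional arguments into flat
# leaves plus a treedef; the treedef is what lets a consumer rebuild the
# original structure on the host side.
args = (jnp.arange(3.0), {"scale": jnp.float32(2.0)})
leaves, arg_treedef = jax.tree_util.tree_flatten(args)
rebuilt = jax.tree_util.tree_unflatten(arg_treedef, leaves)
print(arg_treedef)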
@@ -23,9 +23,9 @@ What is a gufunc?
("gufuncs") are one of my favorite abstractions from NumPy. They generalize
NumPy's `broadcasting rules
<https://docs.scipy.org/doc/numpy-1.15.0/user/basics.broadcasting.html>`_ to
handle non-scalar operations. When a gufunc is applied to arrays, there are:

* "core dimensions" over which an operation is defined.
* "broadcast dimensions" over which operations can be automatically vectorized.

A string `signature <https://docs.scipy.org/doc/numpy-1.15.0/reference/c-api.generalized-ufuncs.html#details-of-signature>`_
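As a concrete illustration of core versus broadcast dimensions, `jax.numpy.vectorize` accepts the same kind of signature string (the function and signature below are my own example, not from the notebook):

import jax
import jax.numpy as jnp

# '(n),(n)->()' declares one core dimension n shared by both inputs; any
# leading dimensions are broadcast dimensions and are vectorized automatically.
cosine_similarity = jnp.vectorize(
    lambda u, v: jnp.dot(u, v) / (jnp.linalg.norm(u) * jnp.linalg.norm(v)),
    signature='(n),(n)->()')

u = jax.random.normal(jax.random.PRNGKey(0), (4, 3))  # batch of 4 vectors
v = jax.random.normal(jax.random.PRNGKey(1), (3,))    # single vector, broadcast
print(cosine_similarity(u, v).shape)                  # (4,)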
@@ -199,7 +199,7 @@ def _calculate_shapes(broadcast_shape, dim_sizes, list_of_core_dims):
  return [broadcast_shape + tuple(dim_sizes[dim] for dim in core_dims)
          for core_dims in list_of_core_dims]


# adapted from np.vectorize (again authored by shoyer@)
def broadcast_with_core_dims(args, input_core_dims, output_core_dims):
  if len(args) != len(input_core_dims):
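A tiny worked example of what the helper above computes, using a standalone copy of the same list comprehension with made-up dimension sizes:

# Standalone copy of the comprehension shown above, applied to toy inputs.
def calculate_shapes(broadcast_shape, dim_sizes, list_of_core_dims):
  return [broadcast_shape + tuple(dim_sizes[dim] for dim in core_dims)
          for core_dims in list_of_core_dims]

# For signature '(m,n),(n,p)->(m,p)' with m=2, n=3, p=4 and batch shape (5,):
print(calculate_shapes((5,), {'m': 2, 'n': 3, 'p': 4},
                       [('m', 'n'), ('n', 'p')]))
# [(5, 2, 3), (5, 3, 4)]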
@@ -245,7 +245,7 @@ def vectorize(signature):
  """Vectorize a function using JAX.

  Turns an arbitrary function into a numpy style "gufunc". Once
  you specify the behavior of the core axis, the rest will be
  broadcast naturally.

  Args:
@@ -258,7 +258,7 @@ def vectorize(signature):
      which axis should be treated as the core one.
  """
  input_core_dims, output_core_dims = _parse_gufunc_signature(signature)

  def decorator(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
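A sketch of how the decorator above is intended to be used, consistent with the VectorizeTest further down (the body of `center` is my own reconstruction, and it assumes the `vectorize` defined above is in scope):

import jax.numpy as jnp

@vectorize('(n)->(),(n)')   # one core axis n; returns a scalar and a vector
def center(array):
  mean = jnp.mean(array)
  return mean, array - mean

X = jnp.arange(12.0).reshape(3, 4)
b, a = center(X, axis=1)    # treat axis 1 as the core axis
print(b.shape, a.shape)     # (3,) (3, 4)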
@@ -78,7 +78,7 @@ kaiming_normal = he_normal = partial(variance_scaling, 2.0, "fan_in", "truncated
def orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32):
  """
  Construct an initializer for uniformly distributed orthogonal matrices.

  If the shape is not square, the matrices will have orthonormal rows or columns
  depending on which side is smaller.
  """
@@ -100,7 +100,7 @@ def orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32):

def delta_orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32):
  """
  Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.

  The shape must be 3D, 4D or 5D.
  """
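Both initializers follow the usual JAX initializer calling convention of `init(key, shape, dtype)`. A brief usage sketch via the public `jax.nn.initializers` module:

import jax
from jax.nn.initializers import orthogonal, delta_orthogonal

key = jax.random.PRNGKey(0)

# Non-square shape: rows or columns are orthonormal, whichever side is smaller.
W = orthogonal()(key, (64, 128))

# Delta orthogonal kernels need a 3D, 4D or 5D shape, e.g. a 2D conv kernel.
K = delta_orthogonal()(key, (3, 3, 16, 16))
print(W.shape, K.shape)   # (64, 128) (3, 3, 16, 16)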
@@ -170,14 +170,14 @@ def _cofactor_solve(a, b):
  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
  Then y_{n} = x_{n} / u_{nn} * prod_{i=1...n}(u_{ii})
             = x_{n} * prod_{i=1...n-1}(u_{ii})
  So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
  we can avoid the triangular_solve failing.
  To correctly compute the rest of y_{i} for i != n, we simply multiply
  x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

  For the second case, a check is done on the matrix to see if `solve`
  returns NaN or Inf, and gives a matrix of zeros as a result, as the
  gradient of the determinant of a matrix with rank less than n-1 is 0.
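A quick numerical check (my own, not part of the file) of the identity this helper exists to support: by Jacobi's formula, the gradient of the determinant is det(a) * inv(a).T, i.e. exactly the scaled solve y = det(a) * solve(a, I):

import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (4, 4), dtype=jnp.float32)

# Gradient of det via JAX autodiff vs. Jacobi's formula det(a) * inv(a).T.
g = jax.grad(jnp.linalg.det)(a)
y = jnp.linalg.det(a) * jnp.linalg.solve(a, jnp.eye(4)).T
print(jnp.allclose(g, y, rtol=1e-3, atol=1e-3))   # True for a full-rank a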
@@ -54,13 +54,13 @@ def _nonzero_range(arr):

@_wraps(np.roots, lax_description="""\
If the input polynomial coefficients of length n do not start with zero,
the polynomial is of degree n - 1 leading to n - 1 roots.
If the coefficients do have leading zeros, the polynomial they define
has a smaller degree and the number of roots (and thus the output shape)
is value dependent.

The general implementation can therefore not be transformed with jit.
If the coefficients are guaranteed to have no leading zeros, use the
keyword argument `strip_zeros=False` to get a jit-compatible variant::

    >>> roots_unsafe = jax.jit(functools.partial(jnp.roots, strip_zeros=False))
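A short usage sketch of the jit-compatible variant from the docstring above (the polynomial is my own example; it has no leading zeros, so strip_zeros=False is safe):

import functools
import jax
import jax.numpy as jnp

# x**2 - 3x + 2 has roots 1 and 2.
roots_unsafe = jax.jit(functools.partial(jnp.roots, strip_zeros=False))
print(roots_unsafe(jnp.array([1.0, -3.0, 2.0])))   # roots 1 and 2 (order may vary; complex dtype)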
@@ -52,7 +52,7 @@ def _parse_gufunc_signature(
        'not a valid gufunc signature: {}'.format(signature))
  args, retvals = ([tuple(re.findall(_DIMENSION_NAME, arg))
                    for arg in re.findall(_ARGUMENT, arg_list)]
                   for arg_list in signature.split('->'))
  return args, retvals
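For reference, what the parser above produces for a typical signature, assuming regex definitions modeled on NumPy's (the exact patterns in the file are not shown in this hunk, so these are an assumption):

import re

# Assumed regexes, modeled on np.lib.function_base.
_DIMENSION_NAME = r'\w+'
_ARGUMENT = r'\((?:{0:}(?:,{0:})*)?\)'.format(_DIMENSION_NAME)

signature = '(m,n),(n,p)->(m,p)'
args, retvals = ([tuple(re.findall(_DIMENSION_NAME, arg))
                  for arg in re.findall(_ARGUMENT, arg_list)]
                 for arg_list in signature.split('->'))
print(args, retvals)   # [('m', 'n'), ('n', 'p')] [('m', 'p')]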
@@ -26,7 +26,7 @@ Context:
Among other requirements, the JAX PRNG aims to:
(a) ensure reproducibility,
(b) parallelize well, both in terms of vectorization (generating array values)
    and multi-replica, multi-core computation. In particular it should not use
    sequencing constraints between random function calls.

The approach is based on:
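These requirements are visible in the user-facing API: randomness is threaded through explicit keys that are split, with no hidden global state or sequencing between calls. A minimal sketch:

import jax

key = jax.random.PRNGKey(42)          # explicit, reproducible seed
key1, key2 = jax.random.split(key)    # two independent streams

# The draws below depend only on their keys, so they can run in any order
# (or in parallel) and always reproduce the same values.
x = jax.random.normal(key1, (3,))
y = jax.random.uniform(key2, (2, 2))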
@@ -23,7 +23,7 @@ from .. import api
from ..numpy import lax_numpy as jnp
from ..numpy.lax_numpy import (asarray, _reduction_dims, _constant_like,
                               _promote_args_inexact)
from ..numpy._util import _wraps


@_wraps(osp_special.gammaln)
jax/third_party/numpy/linalg.py (vendored, 6 lines changed)
@@ -89,7 +89,7 @@ def tensorsolve(a, b, axes=None):
      allaxes.insert(an, k)

    a = a.transpose(allaxes)

  Q = a.shape[-(an - b.ndim):]

  prod = 1
@@ -98,10 +98,10 @@ def tensorsolve(a, b, axes=None):

  a = a.reshape(-1, prod)
  b = b.ravel()

  res = jnp.asarray(la.solve(a, b))
  res = res.reshape(Q)

  return res
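A brief usage sketch of tensorsolve as exposed through `jnp.linalg` (shapes follow the classic NumPy example):

import jax.numpy as jnp

# Solve a @ x = b where a couples the trailing axes of x: a has shape
# (6, 4, 2, 3, 4), b has shape (6, 4), and the solution x has shape (2, 3, 4).
a = jnp.eye(2 * 3 * 4).reshape(6, 4, 2, 3, 4)
b = jnp.arange(24.0).reshape(6, 4)
x = jnp.linalg.tensorsolve(a, b)
print(x.shape)                                        # (2, 3, 4)
print(jnp.allclose(jnp.tensordot(a, x, axes=3), b))   # True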
@@ -71,7 +71,7 @@ class DtypesTest(jtu.JaxTestCase):

  @parameterized.named_parameters(
    {"testcase_name": "_swap={}_jit={}".format(swap, jit),
     "swap": swap, "jit": jit}
    for swap in [False, True] for jit in [False, True])
  @jtu.skip_on_devices("tpu")  # F16 not supported on TPU
  def testBinaryPromotion(self, swap, jit):
@@ -774,7 +774,7 @@ transforms: ({'name': 'jvp'},) what: y * 3
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3
2.00""", testing_stream.output)
    testing_stream.reset()

    with hcb.outfeed_receiver():
      res_grad = grad_func(jnp.float32(5.))
@@ -843,7 +843,7 @@ transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
    with hcb.outfeed_receiver():
      assertMultiLineStrippedEqual(self, """
{ lambda ; a.
  let
  in (12.00,) }""", str(api.make_jaxpr(grad_func)(5.)))
      # Just making the Jaxpr invokes the id_print twice
      assertMultiLineStrippedEqual(self, """
@@ -1100,7 +1100,7 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
      in (d, e, h) }
    linear=(False, False, False, False, False, False)
    true_jaxpr={ lambda ; d g_ a b c h.
      let
      in (a, d, h) } ] c d e 1 2 b h
  in (f, g, i) }""", func, [y, 5])
@@ -1176,7 +1176,7 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
      in (w, t, u, x) }
    body_nconsts=2
    cond_jaxpr={ lambda ; j k l m.
      let
      in (j,) }
    cond_nconsts=0 ] b c h a 1 i
  in (d, 5, g) }""", func, [ct_body])
@@ -2027,7 +2027,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
                            rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       f"_arg{i}_ndmin={ndmin}_dtype={np.dtype(dtype) if dtype else None}",
       "arg": arg, "ndmin": ndmin, "dtype": dtype}
      for i, (arg, dtypes) in enumerate([
@@ -103,7 +103,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
  def testDetOfSingularMatrix(self):
    x = jnp.array([[-1., 3./2], [2./3, -1.]], dtype=np.float32)
    self.assertAllClose(np.float32(0), jsp.linalg.det(x))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
@@ -174,10 +174,10 @@ class NumpyLinalgTest(jtu.JaxTestCase):
    result = jnp.linalg.tensorsolve(*args_maker())
    self.assertEqual(result.shape, Q)

    self._CheckAgainstNumpy(np.linalg.tensorsolve,
                            jnp.linalg.tensorsolve, args_maker,
                            tol={np.float32: 1e-2, np.float64: 1e-3})
    self._CompileAndCheck(jnp.linalg.tensorsolve,
                          args_maker,
                          rtol={np.float64: 1e-13})
@@ -88,7 +88,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
  def testEluValue(self):
    val = nn.elu(1e4)
    self.assertAllClose(val, 1e4, check_dtypes=False)

  def testGluValue(self):
    val = nn.glu(jnp.array([1.0, 0.0]))
    self.assertAllClose(val, jnp.array([0.5]))
@@ -134,12 +134,12 @@ class VectorizeTest(jtu.JaxTestCase):
    self.assertEqual(a.shape, (3, 4))
    self.assertEqual(b.shape, (3,))
    self.assertAllClose(jnp.mean(X, axis=1), b)

    b, a = center(X, axis=0)
    self.assertEqual(a.shape, (3, 4))
    self.assertEqual(b.shape, (4,))
    self.assertAllClose(jnp.mean(X, axis=0), b)


if __name__ == "__main__":
  absltest.main()