remove some trailing whitespace (#3287)

Matthew Johnson 2020-06-02 17:37:20 -07:00 committed by GitHub
parent ea4277b030
commit c42a7f7890
16 changed files with 35 additions and 35 deletions


@@ -18,7 +18,7 @@ This script contains a JAX implementation of Differentially Private Stochastic
Gradient Descent (https://arxiv.org/abs/1607.00133). DPSGD requires clipping
the per-example parameter gradients, which is non-trivial to implement
efficiently for convolutional neural networks. The JAX XLA compiler shines in
this setting by optimizing the minibatch-vectorized computation for
convolutional architectures. Train time takes a few seconds per epoch on a
commodity GPU.
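
As an aside (not code from this example file), the per-example gradient computation the description refers to can be sketched by composing jax.grad with jax.vmap; the loss, parameters, and batch below are hypothetical placeholders::

    import jax
    import jax.numpy as jnp

    def loss(params, x, y):
        # stand-in per-example loss for a model with a single weight vector
        return jnp.sum((jnp.dot(x, params) - y) ** 2)

    params = jnp.ones(3)
    xs = jnp.arange(12.).reshape(4, 3)   # a batch of 4 examples
    ys = jnp.arange(4.)

    # vmap over the batch axis while sharing params: one gradient per example,
    # which DP-SGD then clips and averages.
    per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(params, xs, ys)
    print(per_example_grads.shape)  # (4, 3)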


@@ -342,11 +342,11 @@ positional arguments and parameters:
* nr_untapped: how many positional arguments (from the tail) should not be
passed to the tap function.
* arg_treedef: the treedef of the tapped positional arguments.
* transforms: a tuple of the transformations that have been applied. Each
element of the tuple is itself a tuple with the first element the name
of the transform. The remaining elements depend on the transform. For
example, for `batch`, the parameters are the dimensions that have been
batched, and for `mask` the logical shapes. These are unpacked by
_ConsumerCallable before passing to the user function.
* the remaining parameters are passed to the tap function.
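
As a hedged illustration of the structure described above, a transforms value might look like the following; the names follow the text, while the concrete parameter values are hypothetical::

    transforms = (("jvp",),                 # transform name only (illustrative)
                  ("batch", (0,)),          # dimensions that have been batched
                  ("mask", ((3, 4),)))      # logical shapes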


@@ -23,9 +23,9 @@ What is a gufunc?
("gufuncs") are one of my favorite abstractions from NumPy. They generalize
NumPy's `broadcasting rules
<https://docs.scipy.org/doc/numpy-1.15.0/user/basics.broadcasting.html>`_ to
handle non-scalar operations. When a gufunc is applied to arrays, there are:
* "core dimensions" over which an operation is defined.
* "core dimensions" over which an operation is defined.
* "broadcast dimensions" over which operations can be automatically vectorized.
A string `signature <https://docs.scipy.org/doc/numpy-1.15.0/reference/c-api.generalized-ufuncs.html#details-of-signature>`_
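
For example, a matrix-vector product has input core dimensions (m,n) and (n) and output core dimension (m); any leading dimensions are broadcast. NumPy's own np.vectorize accepts the same notation, shown here purely to illustrate the signature string::

    import numpy as np

    matvec = np.vectorize(np.matmul, signature='(m,n),(n)->(m)')
    a = np.ones((5, 2, 3))     # 5 is a broadcast dimension; (2, 3) are core dims
    b = np.ones(3)
    print(matvec(a, b).shape)  # (5, 2)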
@@ -199,7 +199,7 @@ def _calculate_shapes(broadcast_shape, dim_sizes, list_of_core_dims):
return [broadcast_shape + tuple(dim_sizes[dim] for dim in core_dims)
for core_dims in list_of_core_dims]
# adapted from np.vectorize (again authored by shoyer@)
def broadcast_with_core_dims(args, input_core_dims, output_core_dims):
if len(args) != len(input_core_dims):
@@ -245,7 +245,7 @@ def vectorize(signature):
"""Vectorize a function using JAX.
Turns an arbitrary function into a numpy style "gufunc". Once
you specify the behavior of the core axis, the rest will be
broadcast naturally.
Args:
@@ -258,7 +258,7 @@ def vectorize(signature):
which axis should be treated as the core one.
"""
input_core_dims, output_core_dims = _parse_gufunc_signature(signature)
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
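
A minimal usage sketch of the decorator described in the docstring above, assuming the vectorize defined in this file is in scope and handles a scalar output signature as np.vectorize does::

    import jax.numpy as jnp

    @vectorize('(n)->()')
    def magnitude(x):
        # defined only for the 1D core case; vectorize adds broadcasting
        return jnp.sqrt(jnp.sum(x ** 2))

    batch = jnp.ones((4, 3))
    print(magnitude(batch).shape)  # (4,), one magnitude per row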


@@ -78,7 +78,7 @@ kaiming_normal = he_normal = partial(variance_scaling, 2.0, "fan_in", "truncated
def orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32):
"""
Construct an initializer for uniformly distributed orthogonal matrices.
If the shape is not square, the matrices will have orthonormal rows or columns
depending on which side is smaller.
"""
@@ -100,7 +100,7 @@ def orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32):
def delta_orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32):
"""
Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.
The shape must be 3D, 4D or 5D.
"""


@@ -170,14 +170,14 @@ def _cofactor_solve(a, b):
If a is rank n-1, then the lower right corner of u will be zero and the
triangular_solve will fail.
Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
Then y_{n} =
x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) =
x_{n} * prod_{i=1...n-1}(u_{ii})
So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
we can avoid the triangular_solve failing.
To correctly compute the rest of y_{i} for i != n, we simply multiply
x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.
For the second case, a check is done on the matrix to see if `solve`
returns NaN or Inf, and gives a matrix of zeros as a result, as the
gradient of the determinant of a matrix with rank less than n-1 is 0.
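
Restating the key identity from the docstring in LaTeX, with the same x, y, and u as above:

    y_n = \frac{x_n}{u_{nn}} \prod_{i=1}^{n} u_{ii} = x_n \prod_{i=1}^{n-1} u_{ii},

so overwriting u_{nn} with (\prod_{i=1}^{n-1} u_{ii})^{-1} lets the triangular solve produce the last entry of det(a) * solve(a, b) even when u_{nn} = 0.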


@@ -54,13 +54,13 @@ def _nonzero_range(arr):
@_wraps(np.roots, lax_description="""\
If the input polynomial coefficients of length n do not start with zero,
the polynomial is of degree n - 1 leading to n - 1 roots.
If the coefficients do have leading zeros, the polynomial they define
has a smaller degree and the number of roots (and thus the output shape)
is value dependent.
The general implementation can therefore not be transformed with jit.
If the coefficients are guaranteed to have no leading zeros, use the
keyword argument `strip_zeros=False` to get a jit-compatible variant::
>>> roots_unsafe = jax.jit(functools.partial(jnp.roots, strip_zeros=False))
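
A brief illustration of why the output shape is value dependent (and hence not jit-compatible in general)::

    >>> jnp.roots(jnp.array([1., 0., -1.]))   # degree 2: two roots, +1 and -1
    >>> jnp.roots(jnp.array([0., 1., -1.]))   # same input length, degree 1: one root, 1.0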


@@ -52,7 +52,7 @@ def _parse_gufunc_signature(
'not a valid gufunc signature: {}'.format(signature))
args, retvals = ([tuple(re.findall(_DIMENSION_NAME, arg))
for arg in re.findall(_ARGUMENT, arg_list)]
for arg_list in signature.split('->'))
return args, retvals
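
For illustration, the parsed structure for a matrix-vector signature would look like this (internal helper; the output assumes the regexes extract dimension names as in NumPy's implementation)::

    >>> _parse_gufunc_signature('(m,n),(n)->(m)')
    ([('m', 'n'), ('n',)], [('m',)])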


@@ -26,7 +26,7 @@ Context:
Among other requirements, the JAX PRNG aims to:
(a) ensure reproducibility,
(b) parallelize well, both in terms of vectorization (generating array values)
and multi-replica, multi-core computation. In particular it should not use
sequencing constraints between random function calls.
The approach is based on:
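
As a small user-level illustration of (a) and (b), using the standard jax.random API rather than code from this design note, keys are explicit and split rather than sequenced::

    import jax

    key = jax.random.PRNGKey(0)        # reproducible: same seed, same values
    k1, k2 = jax.random.split(key)     # independent keys, no hidden global state
    x = jax.random.normal(k1, (4,))    # vectorized draw
    y = jax.random.normal(k2, (4,))    # no ordering constraint relative to x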


@@ -23,7 +23,7 @@ from .. import api
from ..numpy import lax_numpy as jnp
from ..numpy.lax_numpy import (asarray, _reduction_dims, _constant_like,
_promote_args_inexact)
from ..numpy._util import _wraps
@_wraps(osp_special.gammaln)


@@ -89,7 +89,7 @@ def tensorsolve(a, b, axes=None):
allaxes.insert(an, k)
a = a.transpose(allaxes)
Q = a.shape[-(an - b.ndim):]
prod = 1
@@ -98,10 +98,10 @@ def tensorsolve(a, b, axes=None):
a = a.reshape(-1, prod)
b = b.ravel()
res = jnp.asarray(la.solve(a, b))
res = res.reshape(Q)
return res
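
A hedged usage sketch of the function above, mirroring NumPy's tensorsolve semantics (the trailing dimensions Q of a are solved for; the values are illustrative)::

    import numpy as np
    import jax.numpy as jnp

    a = np.eye(2 * 3).reshape(2, 3, 2, 3)   # identity viewed as a (2, 3, 2, 3) tensor
    b = np.ones((2, 3))
    x = jnp.linalg.tensorsolve(a, b)
    print(x.shape)   # (2, 3); here x equals b since a acts as the identity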


@@ -71,7 +71,7 @@ class DtypesTest(jtu.JaxTestCase):
@parameterized.named_parameters(
{"testcase_name": "_swap={}_jit={}".format(swap, jit),
"swap": swap, "jit": jit}
"swap": swap, "jit": jit}
for swap in [False, True] for jit in [False, True])
@jtu.skip_on_devices("tpu") # F16 not supported on TPU
def testBinaryPromotion(self, swap, jit):


@@ -774,7 +774,7 @@ transforms: ({'name': 'jvp'},) what: y * 3
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3
2.00""", testing_stream.output)
testing_stream.reset()
with hcb.outfeed_receiver():
res_grad = grad_func(jnp.float32(5.))
@@ -843,7 +843,7 @@ transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
with hcb.outfeed_receiver():
assertMultiLineStrippedEqual(self, """
{ lambda ; a.
let
in (12.00,) }""", str(api.make_jaxpr(grad_func)(5.)))
# Just making the Jaxpr invokes the id_print twice
assertMultiLineStrippedEqual(self, """
@@ -1100,7 +1100,7 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
in (d, e, h) }
linear=(False, False, False, False, False, False)
true_jaxpr={ lambda ; d g_ a b c h.
let
in (a, d, h) } ] c d e 1 2 b h
in (f, g, i) }""", func, [y, 5])
@@ -1176,7 +1176,7 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
in (w, t, u, x) }
body_nconsts=2
cond_jaxpr={ lambda ; j k l m.
let
in (j,) }
cond_nconsts=0 ] b c h a 1 i
in (d, 5, g) }""", func, [ct_body])


@@ -2027,7 +2027,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
rtol=tol, atol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
{"testcase_name":
f"_arg{i}_ndmin={ndmin}_dtype={np.dtype(dtype) if dtype else None}",
"arg": arg, "ndmin": ndmin, "dtype": dtype}
for i, (arg, dtypes) in enumerate([


@@ -103,7 +103,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
def testDetOfSingularMatrix(self):
x = jnp.array([[-1., 3./2], [2./3, -1.]], dtype=np.float32)
self.assertAllClose(np.float32(0), jsp.linalg.det(x))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
@@ -174,10 +174,10 @@ class NumpyLinalgTest(jtu.JaxTestCase):
result = jnp.linalg.tensorsolve(*args_maker())
self.assertEqual(result.shape, Q)
self._CheckAgainstNumpy(np.linalg.tensorsolve,
jnp.linalg.tensorsolve, args_maker,
tol={np.float32: 1e-2, np.float64: 1e-3})
self._CompileAndCheck(jnp.linalg.tensorsolve,
args_maker,
rtol={np.float64: 1e-13})


@@ -88,7 +88,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
def testEluValue(self):
val = nn.elu(1e4)
self.assertAllClose(val, 1e4, check_dtypes=False)
def testGluValue(self):
val = nn.glu(jnp.array([1.0, 0.0]))
self.assertAllClose(val, jnp.array([0.5]))


@@ -134,12 +134,12 @@ class VectorizeTest(jtu.JaxTestCase):
self.assertEqual(a.shape, (3, 4))
self.assertEqual(b.shape, (3,))
self.assertAllClose(jnp.mean(X, axis=1), b)
b, a = center(X, axis=0)
self.assertEqual(a.shape, (3, 4))
self.assertEqual(b.shape, (4,))
self.assertAllClose(jnp.mean(X, axis=0), b)
if __name__ == "__main__":
absltest.main()