Enable flake8 checks for spaces around operators.

This commit is contained in:
Peter Hawkins 2021-07-29 09:51:41 -04:00
parent 66cbb2225b
commit b232d09440
9 changed files with 29 additions and 29 deletions

View File

@ -35,7 +35,7 @@ def _fill_keys_cubic_kernel(x):
# IEEE Transactions on Acoustics, Speech, and Signal Processing,
# 29(6):11531160, 1981.
out = ((1.5 * x - 2.5) * x) * x + 1.
out = jnp.where(x >= 1., ((-0.5* x + 2.5) * x - 4.) * x + 2., out)
out = jnp.where(x >= 1., ((-0.5 * x + 2.5) * x - 4.) * x + 2., out)
return jnp.where(x >= 2., 0., out)

View File

@ -1924,7 +1924,7 @@ def collapse(operand: Array, start_dimension: int,
def slice_in_dim(operand: Array, start_index: Optional[int],
limit_index: Optional[int],
stride: int = 1, axis: int = 0)-> Array:
stride: int = 1, axis: int = 0) -> Array:
"""Convenience wrapper around slice applying to only one dimension."""
start_indices = [0] * operand.ndim
limit_indices = list(operand.shape)

View File

@ -3345,8 +3345,8 @@ def meshgrid(*args, **kwargs):
def _make_1d_grid_from_slice(s: slice, op_name: str):
start =core.concrete_or_error(None, s.start,
f"slice start of jnp.{op_name}") or 0
start = core.concrete_or_error(None, s.start,
f"slice start of jnp.{op_name}") or 0
stop = core.concrete_or_error(None, s.stop,
f"slice stop of jnp.{op_name}")
step = core.concrete_or_error(None, s.step,
@ -3907,7 +3907,7 @@ def diag(v, k=0):
else:
raise ValueError("diag input must be 1d or 2d")
_SCALAR_VALUE_DOC="""\
_SCALAR_VALUE_DOC = """\
This differs from np.diagflat for some scalar values of v,
jax always returns a two-dimensional array, whereas numpy may
return a scalar depending on the type of v.
@ -3929,7 +3929,7 @@ def diagflat(v, k=0):
res = res.reshape(adj_length,adj_length)
return res
_POLY_DOC="""\
_POLY_DOC = """\
This differs from np.poly when an integer array is given.
np.poly returns a result with dtype float64 in this case.
jax returns a result with an inexact type, but not necessarily
@ -4032,7 +4032,7 @@ def trim_zeros(filt, trim='fb'):
end = argmin(nz[::-1]) if 'b' in trim.lower() else 0
return filt[start:len(filt) - end]
_LEADING_ZEROS_DOC="""\
_LEADING_ZEROS_DOC = """\
Setting trim_leading_zeros=True makes the output match that of numpy.
But prevents the function from being able to be used in compiled code.
"""

View File

@ -288,10 +288,10 @@ def _calc_P_Q(A):
n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
A = A / 2**n_squarings
U13, V13 = _pade13(A)
conds=jnp.array([1.495585217958292e-002, 2.539398330063230e-001,
9.504178996162932e-001, 2.097847961257068e+000])
U = jnp.select((A_L1<conds), (U3, U5, U7, U9), U13)
V = jnp.select((A_L1<conds), (V3, V5, V7, V9), V13)
conds = jnp.array([1.495585217958292e-002, 2.539398330063230e-001,
9.504178996162932e-001, 2.097847961257068e+000])
U = jnp.select((A_L1 < conds), (U3, U5, U7, U9), U13)
V = jnp.select((A_L1 < conds), (V3, V5, V7, V9), V13)
elif A.dtype == 'float32' or A.dtype == 'complex64':
U3,V3 = _pade3(A)
U5,V5 = _pade5(A)
@ -299,9 +299,9 @@ def _calc_P_Q(A):
n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
A = A / 2**n_squarings
U7,V7 = _pade7(A)
conds=jnp.array([4.258730016922831e-001, 1.880152677804762e+000])
U = jnp.select((A_L1<conds), (U3, U5), U7)
V = jnp.select((A_L1<conds), (V3, V5), V7)
conds = jnp.array([4.258730016922831e-001, 1.880152677804762e+000])
U = jnp.select((A_L1 < conds), (U3, U5), U7)
V = jnp.select((A_L1 < conds), (V3, V5), V7)
else:
raise TypeError("A.dtype={} is not supported.".format(A.dtype))
P = U + V # p_m(A) : numerator

View File

@ -364,7 +364,7 @@ def _exp_taylor(primals_in, series_in):
u = [x] + series
v = [lax.exp(x)] + [None] * len(series)
for k in range(1,len(v)):
v[k] = fact(k-1) * sum([_scale(k, j)* v[k-j] * u[j] for j in range(1, k+1)])
v[k] = fact(k-1) * sum([_scale(k, j) * v[k-j] * u[j] for j in range(1, k+1)])
primal_out, *series_out = v
return primal_out, series_out
jet_rules[lax.exp_p] = _exp_taylor
@ -377,7 +377,7 @@ def _pow_taylor(primals_in, series_in):
u = [x] + series
v = [u_ ** r_] + [None] * len(series)
for k in range(1, len(v)):
v[k] = fact(k-1) * sum([_scale(k, j)* v[k-j] * u[j] for j in range(1, k+1)])
v[k] = fact(k-1) * sum([_scale(k, j) * v[k-j] * u[j] for j in range(1, k+1)])
primal_out, *series_out = v
return primal_out, series_out
@ -412,7 +412,7 @@ def _expit_taylor(primals_in, series_in):
e = [v[0] * (1 - v[0])] + [None] * len(series) # terms for sigmoid' = sigmoid * (1 - sigmoid)
for k in range(1, len(v)):
v[k] = fact(k-1) * sum([_scale(k, j) * e[k-j] * u[j] for j in range(1, k+1)])
e[k] = (1 - v[0]) * v[k] - fact(k) * sum([_scale2(k, j)* v[j] * v[k-j] for j in range(1, k+1)])
e[k] = (1 - v[0]) * v[k] - fact(k) * sum([_scale2(k, j) * v[j] * v[k-j] for j in range(1, k+1)])
primal_out, *series_out = v
return primal_out, series_out

View File

@ -6,7 +6,7 @@ ignore =
E121 # line continuations
W503, W504 # line breaks around binary operators
max-complexity = 18
select = B,C,F,W,T4,B9
select = B,C,F,W,T4,B9,E225,E227,E228
exclude =
.git,
build,

View File

@ -1222,11 +1222,11 @@ class APITest(jtu.JaxTestCase):
return jnp.sum(jnp.cos(jnp.abs(z)))
ans = grad(f)(zs)
expected = np.array([ 0. +0.j,
-0.80430663+0.40215331j,
-0.70368982+0.35184491j,
0.1886467 -0.09432335j,
0.86873727-0.43436864j])
expected = np.array([ 0. + 0.j,
-0.80430663 + 0.40215331j,
-0.70368982 + 0.35184491j,
0.1886467 - 0.09432335j,
0.86873727 - 0.43436864j])
self.assertAllClose(ans, expected, check_dtypes=False,
atol=jtu.default_gradient_tolerance,
rtol=jtu.default_gradient_tolerance)

View File

@ -440,7 +440,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
new_w, new_v = f(new_a)
new_a = (new_a + np.conj(new_a.T)) / 2
# Assert rtol eigenvalue delta between perturbed eigenvectors vs new true eigenvalues.
RTOL=1e-2
RTOL = 1e-2
assert np.max(
np.abs((np.diag(np.dot(np.conj((v+dv).T), np.dot(new_a,(v+dv)))) - new_w) / new_w)) < RTOL
# Redundant to above, but also assert rtol for eigenvector property with new true eigenvalues.
@ -676,7 +676,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
# Check that q is close to unitary.
self.assertTrue(np.all(
norm(np.eye(k) -np.matmul(np.conj(T(lq)), lq)) < 5))
norm(np.eye(k) - np.matmul(np.conj(T(lq)), lq)) < 5))
if not full_matrices and m >= n:
jtu.check_jvp(jnp.linalg.qr, partial(jvp, jnp.linalg.qr), (a,), atol=3e-3)

View File

@ -910,8 +910,8 @@ class LaxRandomTest(jtu.JaxTestCase):
# 1/2 CDF_one_maxwell((x - loc) / scale))
# + 1/2 (1 - CDF_one_maxwell(- (x - loc) / scale)))
def double_sided_maxwell_cdf(x, loc, scale):
pos = scipy.stats.maxwell().cdf((x - loc)/ scale)
neg = (1 - scipy.stats.maxwell().cdf((-x + loc)/ scale))
pos = scipy.stats.maxwell().cdf((x - loc) / scale)
neg = (1 - scipy.stats.maxwell().cdf((-x + loc) / scale))
return (pos + neg) / 2
for samples in [uncompiled_samples, compiled_samples]:
@ -939,9 +939,9 @@ class LaxRandomTest(jtu.JaxTestCase):
assert len(counts) == 2
self.assertAllClose(
counts[0]/ num_samples, 0.5, rtol=1e-02, atol=1e-02)
counts[0] / num_samples, 0.5, rtol=1e-02, atol=1e-02)
self.assertAllClose(
counts[1]/ num_samples, 0.5, rtol=1e-02, atol=1e-02)
counts[1] / num_samples, 0.5, rtol=1e-02, atol=1e-02)
def testChoiceShapeIsNotSequenceError(self):
key = random.PRNGKey(0)