mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Enable flake8 checks for spaces around operators.
This commit is contained in:
parent
66cbb2225b
commit
b232d09440
@ -35,7 +35,7 @@ def _fill_keys_cubic_kernel(x):
|
||||
# IEEE Transactions on Acoustics, Speech, and Signal Processing,
|
||||
# 29(6):1153–1160, 1981.
|
||||
out = ((1.5 * x - 2.5) * x) * x + 1.
|
||||
out = jnp.where(x >= 1., ((-0.5* x + 2.5) * x - 4.) * x + 2., out)
|
||||
out = jnp.where(x >= 1., ((-0.5 * x + 2.5) * x - 4.) * x + 2., out)
|
||||
return jnp.where(x >= 2., 0., out)
|
||||
|
||||
|
||||
|
@ -1924,7 +1924,7 @@ def collapse(operand: Array, start_dimension: int,
|
||||
|
||||
def slice_in_dim(operand: Array, start_index: Optional[int],
|
||||
limit_index: Optional[int],
|
||||
stride: int = 1, axis: int = 0)-> Array:
|
||||
stride: int = 1, axis: int = 0) -> Array:
|
||||
"""Convenience wrapper around slice applying to only one dimension."""
|
||||
start_indices = [0] * operand.ndim
|
||||
limit_indices = list(operand.shape)
|
||||
|
@ -3345,8 +3345,8 @@ def meshgrid(*args, **kwargs):
|
||||
|
||||
|
||||
def _make_1d_grid_from_slice(s: slice, op_name: str):
|
||||
start =core.concrete_or_error(None, s.start,
|
||||
f"slice start of jnp.{op_name}") or 0
|
||||
start = core.concrete_or_error(None, s.start,
|
||||
f"slice start of jnp.{op_name}") or 0
|
||||
stop = core.concrete_or_error(None, s.stop,
|
||||
f"slice stop of jnp.{op_name}")
|
||||
step = core.concrete_or_error(None, s.step,
|
||||
@ -3907,7 +3907,7 @@ def diag(v, k=0):
|
||||
else:
|
||||
raise ValueError("diag input must be 1d or 2d")
|
||||
|
||||
_SCALAR_VALUE_DOC="""\
|
||||
_SCALAR_VALUE_DOC = """\
|
||||
This differs from np.diagflat for some scalar values of v,
|
||||
jax always returns a two-dimensional array, whereas numpy may
|
||||
return a scalar depending on the type of v.
|
||||
@ -3929,7 +3929,7 @@ def diagflat(v, k=0):
|
||||
res = res.reshape(adj_length,adj_length)
|
||||
return res
|
||||
|
||||
_POLY_DOC="""\
|
||||
_POLY_DOC = """\
|
||||
This differs from np.poly when an integer array is given.
|
||||
np.poly returns a result with dtype float64 in this case.
|
||||
jax returns a result with an inexact type, but not necessarily
|
||||
@ -4032,7 +4032,7 @@ def trim_zeros(filt, trim='fb'):
|
||||
end = argmin(nz[::-1]) if 'b' in trim.lower() else 0
|
||||
return filt[start:len(filt) - end]
|
||||
|
||||
_LEADING_ZEROS_DOC="""\
|
||||
_LEADING_ZEROS_DOC = """\
|
||||
Setting trim_leading_zeros=True makes the output match that of numpy.
|
||||
But prevents the function from being able to be used in compiled code.
|
||||
"""
|
||||
|
@ -288,10 +288,10 @@ def _calc_P_Q(A):
|
||||
n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
|
||||
A = A / 2**n_squarings
|
||||
U13, V13 = _pade13(A)
|
||||
conds=jnp.array([1.495585217958292e-002, 2.539398330063230e-001,
|
||||
9.504178996162932e-001, 2.097847961257068e+000])
|
||||
U = jnp.select((A_L1<conds), (U3, U5, U7, U9), U13)
|
||||
V = jnp.select((A_L1<conds), (V3, V5, V7, V9), V13)
|
||||
conds = jnp.array([1.495585217958292e-002, 2.539398330063230e-001,
|
||||
9.504178996162932e-001, 2.097847961257068e+000])
|
||||
U = jnp.select((A_L1 < conds), (U3, U5, U7, U9), U13)
|
||||
V = jnp.select((A_L1 < conds), (V3, V5, V7, V9), V13)
|
||||
elif A.dtype == 'float32' or A.dtype == 'complex64':
|
||||
U3,V3 = _pade3(A)
|
||||
U5,V5 = _pade5(A)
|
||||
@ -299,9 +299,9 @@ def _calc_P_Q(A):
|
||||
n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
|
||||
A = A / 2**n_squarings
|
||||
U7,V7 = _pade7(A)
|
||||
conds=jnp.array([4.258730016922831e-001, 1.880152677804762e+000])
|
||||
U = jnp.select((A_L1<conds), (U3, U5), U7)
|
||||
V = jnp.select((A_L1<conds), (V3, V5), V7)
|
||||
conds = jnp.array([4.258730016922831e-001, 1.880152677804762e+000])
|
||||
U = jnp.select((A_L1 < conds), (U3, U5), U7)
|
||||
V = jnp.select((A_L1 < conds), (V3, V5), V7)
|
||||
else:
|
||||
raise TypeError("A.dtype={} is not supported.".format(A.dtype))
|
||||
P = U + V # p_m(A) : numerator
|
||||
|
@ -364,7 +364,7 @@ def _exp_taylor(primals_in, series_in):
|
||||
u = [x] + series
|
||||
v = [lax.exp(x)] + [None] * len(series)
|
||||
for k in range(1,len(v)):
|
||||
v[k] = fact(k-1) * sum([_scale(k, j)* v[k-j] * u[j] for j in range(1, k+1)])
|
||||
v[k] = fact(k-1) * sum([_scale(k, j) * v[k-j] * u[j] for j in range(1, k+1)])
|
||||
primal_out, *series_out = v
|
||||
return primal_out, series_out
|
||||
jet_rules[lax.exp_p] = _exp_taylor
|
||||
@ -377,7 +377,7 @@ def _pow_taylor(primals_in, series_in):
|
||||
u = [x] + series
|
||||
v = [u_ ** r_] + [None] * len(series)
|
||||
for k in range(1, len(v)):
|
||||
v[k] = fact(k-1) * sum([_scale(k, j)* v[k-j] * u[j] for j in range(1, k+1)])
|
||||
v[k] = fact(k-1) * sum([_scale(k, j) * v[k-j] * u[j] for j in range(1, k+1)])
|
||||
primal_out, *series_out = v
|
||||
|
||||
return primal_out, series_out
|
||||
@ -412,7 +412,7 @@ def _expit_taylor(primals_in, series_in):
|
||||
e = [v[0] * (1 - v[0])] + [None] * len(series) # terms for sigmoid' = sigmoid * (1 - sigmoid)
|
||||
for k in range(1, len(v)):
|
||||
v[k] = fact(k-1) * sum([_scale(k, j) * e[k-j] * u[j] for j in range(1, k+1)])
|
||||
e[k] = (1 - v[0]) * v[k] - fact(k) * sum([_scale2(k, j)* v[j] * v[k-j] for j in range(1, k+1)])
|
||||
e[k] = (1 - v[0]) * v[k] - fact(k) * sum([_scale2(k, j) * v[j] * v[k-j] for j in range(1, k+1)])
|
||||
|
||||
primal_out, *series_out = v
|
||||
return primal_out, series_out
|
||||
|
@ -6,7 +6,7 @@ ignore =
|
||||
E121 # line continuations
|
||||
W503, W504 # line breaks around binary operators
|
||||
max-complexity = 18
|
||||
select = B,C,F,W,T4,B9
|
||||
select = B,C,F,W,T4,B9,E225,E227,E228
|
||||
exclude =
|
||||
.git,
|
||||
build,
|
||||
|
@ -1222,11 +1222,11 @@ class APITest(jtu.JaxTestCase):
|
||||
return jnp.sum(jnp.cos(jnp.abs(z)))
|
||||
|
||||
ans = grad(f)(zs)
|
||||
expected = np.array([ 0. +0.j,
|
||||
-0.80430663+0.40215331j,
|
||||
-0.70368982+0.35184491j,
|
||||
0.1886467 -0.09432335j,
|
||||
0.86873727-0.43436864j])
|
||||
expected = np.array([ 0. + 0.j,
|
||||
-0.80430663 + 0.40215331j,
|
||||
-0.70368982 + 0.35184491j,
|
||||
0.1886467 - 0.09432335j,
|
||||
0.86873727 - 0.43436864j])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False,
|
||||
atol=jtu.default_gradient_tolerance,
|
||||
rtol=jtu.default_gradient_tolerance)
|
||||
|
@ -440,7 +440,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
new_w, new_v = f(new_a)
|
||||
new_a = (new_a + np.conj(new_a.T)) / 2
|
||||
# Assert rtol eigenvalue delta between perturbed eigenvectors vs new true eigenvalues.
|
||||
RTOL=1e-2
|
||||
RTOL = 1e-2
|
||||
assert np.max(
|
||||
np.abs((np.diag(np.dot(np.conj((v+dv).T), np.dot(new_a,(v+dv)))) - new_w) / new_w)) < RTOL
|
||||
# Redundant to above, but also assert rtol for eigenvector property with new true eigenvalues.
|
||||
@ -676,7 +676,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
# Check that q is close to unitary.
|
||||
self.assertTrue(np.all(
|
||||
norm(np.eye(k) -np.matmul(np.conj(T(lq)), lq)) < 5))
|
||||
norm(np.eye(k) - np.matmul(np.conj(T(lq)), lq)) < 5))
|
||||
|
||||
if not full_matrices and m >= n:
|
||||
jtu.check_jvp(jnp.linalg.qr, partial(jvp, jnp.linalg.qr), (a,), atol=3e-3)
|
||||
|
@ -910,8 +910,8 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
# 1/2 CDF_one_maxwell((x - loc) / scale))
|
||||
# + 1/2 (1 - CDF_one_maxwell(- (x - loc) / scale)))
|
||||
def double_sided_maxwell_cdf(x, loc, scale):
|
||||
pos = scipy.stats.maxwell().cdf((x - loc)/ scale)
|
||||
neg = (1 - scipy.stats.maxwell().cdf((-x + loc)/ scale))
|
||||
pos = scipy.stats.maxwell().cdf((x - loc) / scale)
|
||||
neg = (1 - scipy.stats.maxwell().cdf((-x + loc) / scale))
|
||||
return (pos + neg) / 2
|
||||
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
@ -939,9 +939,9 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
assert len(counts) == 2
|
||||
|
||||
self.assertAllClose(
|
||||
counts[0]/ num_samples, 0.5, rtol=1e-02, atol=1e-02)
|
||||
counts[0] / num_samples, 0.5, rtol=1e-02, atol=1e-02)
|
||||
self.assertAllClose(
|
||||
counts[1]/ num_samples, 0.5, rtol=1e-02, atol=1e-02)
|
||||
counts[1] / num_samples, 0.5, rtol=1e-02, atol=1e-02)
|
||||
|
||||
def testChoiceShapeIsNotSequenceError(self):
|
||||
key = random.PRNGKey(0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user