From b232d0944071041f739d2015430edbc2ff7c0612 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 29 Jul 2021 09:51:41 -0400 Subject: [PATCH] Enable flake8 checks for spaces around operators. --- jax/_src/image/scale.py | 2 +- jax/_src/lax/lax.py | 2 +- jax/_src/numpy/lax_numpy.py | 10 +++++----- jax/_src/scipy/linalg.py | 14 +++++++------- jax/experimental/jet.py | 6 +++--- setup.cfg | 2 +- tests/api_test.py | 10 +++++----- tests/linalg_test.py | 4 ++-- tests/random_test.py | 8 ++++---- 9 files changed, 29 insertions(+), 29 deletions(-) diff --git a/jax/_src/image/scale.py b/jax/_src/image/scale.py index b63fd98fc..c1a753bd9 100644 --- a/jax/_src/image/scale.py +++ b/jax/_src/image/scale.py @@ -35,7 +35,7 @@ def _fill_keys_cubic_kernel(x): # IEEE Transactions on Acoustics, Speech, and Signal Processing, # 29(6):1153–1160, 1981. out = ((1.5 * x - 2.5) * x) * x + 1. - out = jnp.where(x >= 1., ((-0.5* x + 2.5) * x - 4.) * x + 2., out) + out = jnp.where(x >= 1., ((-0.5 * x + 2.5) * x - 4.) * x + 2., out) return jnp.where(x >= 2., 0., out) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index fc0f82f25..17ef566a3 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1924,7 +1924,7 @@ def collapse(operand: Array, start_dimension: int, def slice_in_dim(operand: Array, start_index: Optional[int], limit_index: Optional[int], - stride: int = 1, axis: int = 0)-> Array: + stride: int = 1, axis: int = 0) -> Array: """Convenience wrapper around slice applying to only one dimension.""" start_indices = [0] * operand.ndim limit_indices = list(operand.shape) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 43d038c4f..35868b432 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3345,8 +3345,8 @@ def meshgrid(*args, **kwargs): def _make_1d_grid_from_slice(s: slice, op_name: str): - start =core.concrete_or_error(None, s.start, - f"slice start of jnp.{op_name}") or 0 + start = core.concrete_or_error(None, s.start, + f"slice start of jnp.{op_name}") or 0 stop = core.concrete_or_error(None, s.stop, f"slice stop of jnp.{op_name}") step = core.concrete_or_error(None, s.step, @@ -3907,7 +3907,7 @@ def diag(v, k=0): else: raise ValueError("diag input must be 1d or 2d") -_SCALAR_VALUE_DOC="""\ +_SCALAR_VALUE_DOC = """\ This differs from np.diagflat for some scalar values of v, jax always returns a two-dimensional array, whereas numpy may return a scalar depending on the type of v. @@ -3929,7 +3929,7 @@ def diagflat(v, k=0): res = res.reshape(adj_length,adj_length) return res -_POLY_DOC="""\ +_POLY_DOC = """\ This differs from np.poly when an integer array is given. np.poly returns a result with dtype float64 in this case. jax returns a result with an inexact type, but not necessarily @@ -4032,7 +4032,7 @@ def trim_zeros(filt, trim='fb'): end = argmin(nz[::-1]) if 'b' in trim.lower() else 0 return filt[start:len(filt) - end] -_LEADING_ZEROS_DOC="""\ +_LEADING_ZEROS_DOC = """\ Setting trim_leading_zeros=True makes the output match that of numpy. But prevents the function from being able to be used in compiled code. """ diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index 65518564d..6d57b274c 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -288,10 +288,10 @@ def _calc_P_Q(A): n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm))) A = A / 2**n_squarings U13, V13 = _pade13(A) - conds=jnp.array([1.495585217958292e-002, 2.539398330063230e-001, - 9.504178996162932e-001, 2.097847961257068e+000]) - U = jnp.select((A_L1= n: jtu.check_jvp(jnp.linalg.qr, partial(jvp, jnp.linalg.qr), (a,), atol=3e-3) diff --git a/tests/random_test.py b/tests/random_test.py index 7b031f6fb..78bc03705 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -910,8 +910,8 @@ class LaxRandomTest(jtu.JaxTestCase): # 1/2 CDF_one_maxwell((x - loc) / scale)) # + 1/2 (1 - CDF_one_maxwell(- (x - loc) / scale))) def double_sided_maxwell_cdf(x, loc, scale): - pos = scipy.stats.maxwell().cdf((x - loc)/ scale) - neg = (1 - scipy.stats.maxwell().cdf((-x + loc)/ scale)) + pos = scipy.stats.maxwell().cdf((x - loc) / scale) + neg = (1 - scipy.stats.maxwell().cdf((-x + loc) / scale)) return (pos + neg) / 2 for samples in [uncompiled_samples, compiled_samples]: @@ -939,9 +939,9 @@ class LaxRandomTest(jtu.JaxTestCase): assert len(counts) == 2 self.assertAllClose( - counts[0]/ num_samples, 0.5, rtol=1e-02, atol=1e-02) + counts[0] / num_samples, 0.5, rtol=1e-02, atol=1e-02) self.assertAllClose( - counts[1]/ num_samples, 0.5, rtol=1e-02, atol=1e-02) + counts[1] / num_samples, 0.5, rtol=1e-02, atol=1e-02) def testChoiceShapeIsNotSequenceError(self): key = random.PRNGKey(0)