From 29d03160e3d4477577bd1925aeb781afd9ada971 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 12 Aug 2022 12:51:09 +0000 Subject: [PATCH] Remove _ prefix from functions in jax._src.dtypes. to_inexact_dtype and to_complex_dtype are used across the JAX code base, so they shouldn't have _ prefixes. --- jax/_src/dtypes.py | 6 +++--- jax/_src/nn/functions.py | 2 +- jax/_src/nn/initializers.py | 4 ++-- jax/_src/numpy/lax_numpy.py | 16 ++++++++-------- jax/_src/numpy/polynomial.py | 4 ++-- jax/_src/numpy/reductions.py | 4 ++-- jax/_src/numpy/ufuncs.py | 2 +- jax/_src/numpy/util.py | 4 ++-- jax/_src/scipy/linalg.py | 2 +- jax/_src/scipy/signal.py | 2 +- tests/fft_test.py | 3 ++- tests/lax_numpy_test.py | 14 +++++++------- tests/polynomial_test.py | 4 ++-- tests/random_test.py | 2 +- tests/scipy_signal_test.py | 8 ++++---- 15 files changed, 39 insertions(+), 38 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index d1c3c528d..d6a98edeb 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -70,14 +70,14 @@ _dtype_to_inexact = { } -def _to_inexact_dtype(dtype): +def to_inexact_dtype(dtype): """Promotes a dtype into an inexact dtype, if it is not already one.""" dtype = np.dtype(dtype) return _dtype_to_inexact.get(dtype, dtype) -def _to_complex_dtype(dtype): - ftype = _to_inexact_dtype(dtype) +def to_complex_dtype(dtype): + ftype = to_inexact_dtype(dtype) if ftype in [np.dtype('float64'), np.dtype('complex128')]: return np.dtype('complex128') return np.dtype('complex64') diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 122b47b59..f8958e9f8 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -254,7 +254,7 @@ def gelu(x: Array, approximate: bool = True) -> Array: """ # Promote to nearest float-like dtype. - x = x.astype(dtypes._to_inexact_dtype(x.dtype)) + x = x.astype(dtypes.to_inexact_dtype(x.dtype)) if approximate: sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype) diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index 00b3a17e6..f46703edb 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -184,7 +184,7 @@ def _complex_uniform(key: KeyArray, """ key_r, key_theta = random.split(key) real_dtype = np.array(0, dtype).real.dtype - dtype = dtypes._to_complex_dtype(real_dtype) + dtype = dtypes.to_complex_dtype(real_dtype) r = jnp.sqrt(2 * random.uniform(key_r, shape, real_dtype)).astype(dtype) theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype) return r * jnp.exp(1j * theta) @@ -199,7 +199,7 @@ def _complex_truncated_normal(key: KeyArray, upper: Array, """ key_r, key_theta = random.split(key) real_dtype = np.array(0, dtype).real.dtype - dtype = dtypes._to_complex_dtype(real_dtype) + dtype = dtypes.to_complex_dtype(real_dtype) t = ((1 - jnp.exp(jnp.array(-(upper ** 2), dtype))) * random.uniform(key_r, shape, real_dtype).astype(dtype)) r = jnp.sqrt(-jnp.log(1 - t)) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 85f78c561..a93fe4a56 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -413,7 +413,7 @@ def histogram_bin_edges(a, bins=10, range=None, weights=None): raise NotImplementedError("string values for `bins` not implemented.") _check_arraylike("histogram_bin_edges", a, bins) a = ravel(a) - dtype = dtypes._to_inexact_dtype(_dtype(a)) + dtype = dtypes.to_inexact_dtype(_dtype(a)) if _ndim(bins) == 1: return asarray(bins, dtype=dtype) bins = core.concrete_or_error(operator.index, bins, @@ -2178,9 +2178,9 @@ def _linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, _check_arraylike("linspace", start, stop) if dtype is None: - dtype = dtypes._to_inexact_dtype(result_type(start, stop)) + dtype = dtypes.to_inexact_dtype(result_type(start, stop)) dtype = _jnp_dtype(dtype) - computation_dtype = dtypes._to_inexact_dtype(dtype) + computation_dtype = dtypes.to_inexact_dtype(dtype) start = asarray(start, dtype=computation_dtype) stop = asarray(stop, dtype=computation_dtype) @@ -2237,9 +2237,9 @@ def _logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, """Implementation of logspace differentiable in start and stop args.""" lax_internal._check_user_dtype_supported(dtype, "logspace") if dtype is None: - dtype = dtypes._to_inexact_dtype(result_type(start, stop)) + dtype = dtypes.to_inexact_dtype(result_type(start, stop)) dtype = _jnp_dtype(dtype) - computation_dtype = dtypes._to_inexact_dtype(dtype) + computation_dtype = dtypes.to_inexact_dtype(dtype) _check_arraylike("logspace", start, stop) start = asarray(start, dtype=computation_dtype) stop = asarray(stop, dtype=computation_dtype) @@ -2259,9 +2259,9 @@ def _geomspace(start, stop, num=50, endpoint=True, dtype=None, axis: int = 0): """Implementation of geomspace differentiable in start and stop args.""" lax_internal._check_user_dtype_supported(dtype, "geomspace") if dtype is None: - dtype = dtypes._to_inexact_dtype(result_type(start, stop)) + dtype = dtypes.to_inexact_dtype(result_type(start, stop)) dtype = _jnp_dtype(dtype) - computation_dtype = dtypes._to_inexact_dtype(dtype) + computation_dtype = dtypes.to_inexact_dtype(dtype) _check_arraylike("geomspace", start, stop) start = asarray(start, dtype=computation_dtype) stop = asarray(stop, dtype=computation_dtype) @@ -3271,7 +3271,7 @@ def sort(a, axis: Optional[int] = -1, kind='quicksort', order=None): def sort_complex(a): _check_arraylike("sort_complex", a) a = lax.sort(a, dimension=0) - return lax.convert_element_type(a, dtypes._to_complex_dtype(a.dtype)) + return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype)) @_wraps(np.lexsort) @partial(jit, static_argnames=('axis',)) diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index 1931c85d7..d1aca3c20 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -33,7 +33,7 @@ import numpy as np def _roots_no_zeros(p): # build companion matrix and find its eigenvalues (the roots) if p.size < 2: - return array([], dtype=dtypes._to_complex_dtype(p.dtype)) + return array([], dtype=dtypes.to_complex_dtype(p.dtype)) A = diag(ones((p.size - 2,), p.dtype), -1) A = A.at[0, :].set(-p[1:] / p[0]) return linalg.eigvals(A) @@ -83,7 +83,7 @@ def roots(p, *, strip_zeros=True): if p.ndim != 1: raise ValueError("Input must be a rank-1 array.") if p.size < 2: - return array([], dtype=dtypes._to_complex_dtype(p.dtype)) + return array([], dtype=dtypes.to_complex_dtype(p.dtype)) num_leading_zeros = _where(all(p == 0), len(p), argmin(p == 0)) if strip_zeros: diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 9b3b7469b..48e55a076 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -277,7 +277,7 @@ def _mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims) if dtype is None: - dtype = dtypes._to_inexact_dtype(dtypes.dtype(a)) + dtype = dtypes.to_inexact_dtype(dtypes.dtype(a)) dtype = dtypes.canonicalize_dtype(dtype) return lax.div( @@ -384,7 +384,7 @@ def _var_promote_types(a_dtype, dtype): computation_dtype = dtype else: if not dtypes.issubdtype(a_dtype, np.inexact): - dtype = dtypes._to_inexact_dtype(a_dtype) + dtype = dtypes.to_inexact_dtype(a_dtype) computation_dtype = dtype else: dtype = _complex_elem_type(a_dtype) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index bd4a1ca45..d71d17e66 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -458,7 +458,7 @@ def ldexp(x1, x2): x1, x2 = _promote_shapes("ldexp", x1, x2) - dtype = dtypes.canonicalize_dtype(dtypes._to_inexact_dtype(x1_dtype)) + dtype = dtypes.canonicalize_dtype(dtypes.to_inexact_dtype(x1_dtype)) info = dtypes.finfo(dtype) int_type = _INT_DTYPES[info.bits] diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index e6b395bd6..bac102f0d 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -282,7 +282,7 @@ def _promote_dtypes_inexact(*args): Promotes arguments to an inexact type.""" to_dtype, weak_type = dtypes._lattice_result_type(*args) to_dtype = dtypes.canonicalize_dtype(to_dtype) - to_dtype_inexact = dtypes._to_inexact_dtype(to_dtype) + to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype) return [lax_internal._convert_element_type(x, to_dtype_inexact, weak_type) for x in args] @@ -293,7 +293,7 @@ def _promote_dtypes_complex(*args): Promotes arguments to a complex type.""" to_dtype, weak_type = dtypes._lattice_result_type(*args) to_dtype = dtypes.canonicalize_dtype(to_dtype) - to_dtype_complex = dtypes._to_complex_dtype(to_dtype) + to_dtype_complex = dtypes.to_complex_dtype(to_dtype) return [lax_internal._convert_element_type(x, to_dtype_complex, weak_type) for x in args] diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index ae9a84b12..44eaf51e8 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -120,7 +120,7 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False, @partial(jit, static_argnames=('output',)) def _schur(a, output): if output == "complex": - a = a.astype(dtypes._to_complex_dtype(a.dtype)) + a = a.astype(dtypes.to_complex_dtype(a.dtype)) return lax_linalg.schur(a) @_wraps(scipy.linalg.schur) diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index b8013f038..ebf58be9a 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -277,7 +277,7 @@ def _spectral_helper(x, y, except ValueError as err: raise ValueError('x and y cannot be broadcast together.') from err - result_dtype = dtypes._to_complex_dtype(x.dtype) + result_dtype = dtypes.to_complex_dtype(x.dtype) freq_dtype = np.finfo(result_dtype).dtype if nperseg is not None: # if specified by user diff --git a/tests/fft_test.py b/tests/fft_test.py index c6a024c86..590a64b5c 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -109,7 +109,8 @@ class FftTest(jtu.JaxTestCase): def testLaxIrfftDoesNotMutateInputs(self, dtype): if dtype == np.float64 and not config.x64_enabled: raise self.skipTest("float64 requires jax_enable_x64=true") - x = (1 + 1j) * jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=dtypes._to_complex_dtype(dtype)) + x = (1 + 1j) * jnp.array([[1.0, 2.0], [3.0, 4.0]], + dtype=dtypes.to_complex_dtype(dtype)) y = np.asarray(jnp.fft.irfft2(x)) z = np.asarray(jnp.fft.irfft2(x)) self.assertAllClose(y, z) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index a469fd774..8853154d4 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -842,7 +842,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def np_fun(x): x = np.asarray(x) if inexact: - x = x.astype(dtypes._to_inexact_dtype(x.dtype)) + x = x.astype(dtypes.to_inexact_dtype(x.dtype)) x_cast = x if dtype != jnp.bfloat16 else x.astype(np.float32) t = out_dtype if out_dtype != jnp.bfloat16 else np.float32 return np_op(x_cast, axis, dtype=t, keepdims=keepdims) @@ -884,7 +884,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def np_fun(x): x = np.asarray(x) if inexact: - x = x.astype(dtypes._to_inexact_dtype(x.dtype)) + x = x.astype(dtypes.to_inexact_dtype(x.dtype)) x_cast = x if not is_bf16_nan_test else x.astype(np.float32) res = np_op(x_cast, axis, keepdims=keepdims) res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) @@ -919,7 +919,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def np_fun(x): x = np.asarray(x) if inexact: - x = x.astype(dtypes._to_inexact_dtype(x.dtype)) + x = x.astype(dtypes.to_inexact_dtype(x.dtype)) x_cast = x if not is_bf16_nan_test else x.astype(np.float32) res = np_op(x_cast, axis, keepdims=keepdims, initial=initial) res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) @@ -955,7 +955,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def np_fun(x): x = np.asarray(x) if inexact: - x = x.astype(dtypes._to_inexact_dtype(x.dtype)) + x = x.astype(dtypes.to_inexact_dtype(x.dtype)) x_cast = x if not is_bf16_nan_test else x.astype(np.float32) res = np_op(x_cast, axis, keepdims=keepdims) res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) @@ -999,7 +999,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def np_fun(x): x = np.asarray(x) if inexact: - x = x.astype(dtypes._to_inexact_dtype(x.dtype)) + x = x.astype(dtypes.to_inexact_dtype(x.dtype)) x_cast = x if not is_bf16_nan_test else x.astype(np.float32) res = np_op(x_cast, axis, keepdims=keepdims, initial=initial, where=where) res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) @@ -1042,7 +1042,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def np_fun(x): x = np.asarray(x) if inexact: - x = x.astype(dtypes._to_inexact_dtype(x.dtype)) + x = x.astype(dtypes.to_inexact_dtype(x.dtype)) x_cast = x if not is_bf16_nan_test else x.astype(np.float32) res = np_op(x_cast, axis, keepdims=keepdims, where=where) res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) @@ -3145,7 +3145,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): @jtu.ignore_warning(category=RuntimeWarning, message="overflow.*") def np_fun(x1, x2): - out_dtype = dtypes._to_inexact_dtype(x1.dtype) + out_dtype = dtypes.to_inexact_dtype(x1.dtype) return np.ldexp(x1.astype(out_dtype), x2) jnp_fun = jnp.ldexp diff --git a/tests/polynomial_test.py b/tests/polynomial_test.py index ba0eec230..bf28ca4ff 100644 --- a/tests/polynomial_test.py +++ b/tests/polynomial_test.py @@ -87,7 +87,7 @@ class TestPolynomial(jtu.JaxTestCase): jnp_fun = jnp.roots def np_fun(arg): - return np.roots(arg).astype(dtypes._to_complex_dtype(arg.dtype)) + return np.roots(arg).astype(dtypes.to_complex_dtype(arg.dtype)) # Note: outputs have no defined order, so we need to use a special comparator. args = args_maker() @@ -116,7 +116,7 @@ class TestPolynomial(jtu.JaxTestCase): jnp_fun = partial(jnp.roots, strip_zeros=False) def np_fun(arg): - roots = np.roots(arg).astype(dtypes._to_complex_dtype(arg.dtype)) + roots = np.roots(arg).astype(dtypes.to_complex_dtype(arg.dtype)) if len(roots) < len(arg) - 1: roots = np.pad(roots, (0, len(arg) - len(roots) - 1), constant_values=complex(np.nan, np.nan)) diff --git a/tests/random_test.py b/tests/random_test.py index 90f717a4a..56f968743 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -639,7 +639,7 @@ class LaxRandomTest(jtu.JaxTestCase): def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis): # This is the function API that we test against (note that self.rng().choice differs) np_choice = np.random.default_rng(0).choice - p_dtype = dtypes._to_inexact_dtype(dtype) + p_dtype = dtypes.to_inexact_dtype(dtype) key = self.seed_prng(0) is_range = type(input_range_or_shape) is int diff --git a/tests/scipy_signal_test.py b/tests/scipy_signal_test.py index a5f0bf892..7eef39a19 100644 --- a/tests/scipy_signal_test.py +++ b/tests/scipy_signal_test.py @@ -61,10 +61,10 @@ default_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex _TPU_FFT_TOL = 0.15 def _real_dtype(dtype): - return jnp.finfo(dtypes._to_inexact_dtype(dtype)).dtype + return jnp.finfo(dtypes.to_inexact_dtype(dtype)).dtype def _complex_dtype(dtype): - return dtypes._to_complex_dtype(dtype) + return dtypes.to_complex_dtype(dtype) class LaxBackedScipySignalTests(jtu.JaxTestCase): @@ -133,7 +133,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase): kwds = dict(axis=axis, type=type, bp=bp) def osp_fun(x): - return osp_signal.detrend(x, **kwds).astype(dtypes._to_inexact_dtype(x.dtype)) + return osp_signal.detrend(x, **kwds).astype(dtypes.to_inexact_dtype(x.dtype)) jsp_fun = partial(jsp_signal.detrend, **kwds) if jtu.device_under_test() == 'tpu': @@ -365,7 +365,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase): kwargs['nperseg'] = nperseg else: kwargs['window'] = jnp.array(osp_signal.get_window('hann', nperseg), - dtype=dtypes._to_complex_dtype(dtype)) + dtype=dtypes.to_complex_dtype(dtype)) if use_noverlap: kwargs['noverlap'] = noverlap