From 51eb0d27c73df9bba44e6f5155913336c58ff270 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 4 Sep 2024 22:17:19 +0100 Subject: [PATCH] Fixed some type errors under pyright These are mostly due to relience on submodule import side-effects, which AFAIU are unchecked by both pytype and mypy. --- jax/_src/export/_export.py | 2 +- jax/_src/interpreters/ad.py | 6 ++++-- jax/_src/pallas/pallas_call.py | 24 ++++++++++++++---------- jax/_src/pallas/primitives.py | 4 ++-- jax/_src/scipy/signal.py | 29 ++++++++++++++--------------- 5 files changed, 35 insertions(+), 30 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index ee1b0dabb..d0159f7a4 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -58,7 +58,7 @@ map = util.safe_map zip = util.safe_zip DType = Any -Shape = jax._src.core.Shape +Shape = core.Shape # The values of input and output sharding from the lowering. LoweringSharding = Union[sharding.Sharding, pxla.UnspecifiedValue] HloSharding = xla_client.HloSharding diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index ea9da4574..f1b25cf96 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -21,7 +21,6 @@ import itertools as it from functools import partial from typing import Any -import jax from jax._src import config from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe @@ -389,6 +388,9 @@ class JVPTrace(Trace): def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees, symbolic_zeros): + # Local import to prevent an import cycle. + from jax._src.lax import lax + primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) fwd_in = [(core.full_lower(p), type(t) is not Zero) for p, t in zip(primals_in, tangents_in)] @@ -402,7 +404,7 @@ class JVPTrace(Trace): tangents_out = custom_lin_p.bind( *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out, symbolic_zeros=symbolic_zeros) - tangents_out = map(jax._src.lax.lax.tie_p.bind, primals_out, tangents_out) + tangents_out = map(lax.tie_p.bind, primals_out, tangents_out) tangents_out = map(recast_to_float0, primals_out, tangents_out) return map(partial(JVPTracer, self), primals_out, tangents_out) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index d3b0ea8ca..28123f0a0 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -36,7 +36,8 @@ from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pallas_core -from jax._src.pallas.primitives import uninitialized_value +from jax._src.pallas import primitives +from jax._src.pallas import utils as pallas_utils from jax._src.state import discharge as state_discharge from jax._src.util import ( safe_map, @@ -111,13 +112,15 @@ def _pad_values_to_block_dimension(value, ) if padded_shape != value.shape: pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape)) - pad_value = uninitialized_value(shape=(), dtype=value.dtype) + pad_value = primitives.uninitialized_value(shape=(), dtype=value.dtype) value = jnp.pad(value, pad_width, constant_values=pad_value) return value def _initialize_scratch_vals(scratch_avals) -> tuple[jax.Array, ...]: scratch_avals = (jax_core.raise_to_shaped(x) for x in scratch_avals) - return tuple(uninitialized_value(a.shape, a.dtype) for a in scratch_avals) + return tuple( + primitives.uninitialized_value(a.shape, a.dtype) for a in scratch_avals + ) def _initialize_output_vals( block_mappings_output: Iterable[BlockMapping], @@ -128,8 +131,9 @@ def _initialize_output_vals( if i in oi_map: output_vals.append(input_args[oi_map[i]]) else: - output_vals.append(uninitialized_value(bm.array_shape_dtype.shape, - bm.array_shape_dtype.dtype)) + output_vals.append(primitives.uninitialized_value( + bm.array_shape_dtype.shape, + bm.array_shape_dtype.dtype)) return output_vals def _logical_to_interpret_mode_dtype(dtype): @@ -212,7 +216,7 @@ def _pallas_call_impl_interpret( if padding is not None and any(p != (0, 0) for p in padding): if input_output_aliases: raise NotImplementedError("Padding with aliasing not supported.") - pad_value = uninitialized_value(shape=(), dtype=x.dtype) + pad_value = primitives.uninitialized_value(shape=(), dtype=x.dtype) x = lax.pad(x, pad_value, [(*p, 0) for p in padding]) carry.append(x) @@ -872,9 +876,9 @@ def _pallas_call_batching_rule( val_at_ragged_dim = first_block_mapping.block_shape[ragged_axis_dim] def when_wrapped_kernel(lengths_ref, *args, **kwargs): - b_idx = jax.experimental.pallas.program_id(stacked_axis) + b_idx = primitives.program_id(stacked_axis) i_idx = ( - jax.experimental.pallas.program_id(ragged_axis_dim) + primitives.program_id(ragged_axis_dim) * val_at_ragged_dim ) b_len = lengths_ref[b_idx] @@ -883,7 +887,7 @@ def _pallas_call_batching_rule( # b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0) # checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0") - @jax.experimental.pallas.when(i_idx < b_len) + @pallas_utils.when(i_idx < b_len) def f(): # Important! This allows us to trace the inner kernel with the correct # grid to preserve user program_id semantics. Ex: program_id(0) will @@ -893,7 +897,7 @@ def _pallas_call_batching_rule( if debug_zero_fill_counterfactual: - @jax.experimental.pallas.when(i_idx >= b_len) + @pallas_utils.when(i_idx >= b_len) def g(): for arg_ref in args: arg_ref[...] = jnp.zeros_like(arg_ref) diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 53227478c..e41a8cf59 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -707,7 +707,7 @@ debug_print_p = jax_core.Primitive("debug_print") debug_print_p.multiple_results = True -def debug_print(fmt: str, *args: jax.ArrayLike): +def debug_print(fmt: str, *args: jax.typing.ArrayLike): """Prints scalar values from inside a Pallas kernel. Args: @@ -732,7 +732,7 @@ def debug_print(fmt: str, *args: jax.ArrayLike): def check_debug_print_format( - fmt: str, *args: jax.ArrayLike + fmt: str, *args: jax.typing.ArrayLike ): n_placeholders = 0 for _, field, spec, conversion in string.Formatter().parse(fmt): diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index cb3719faf..d950cd2ea 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -27,8 +27,9 @@ import jax import jax.numpy.fft import jax.numpy as jnp from jax import lax -from jax._src.api_util import _ensure_index_tuple +from jax._src import core from jax._src import dtypes +from jax._src.api_util import _ensure_index_tuple from jax._src.lax.lax import PrecisionLike from jax._src.numpy import linalg from jax._src.numpy.util import ( @@ -655,8 +656,7 @@ def _spectral_helper(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0, f"Unknown boundary option '{boundary}', " f"must be one of: {list(boundary_funcs.keys())}") - axis = jax.core.concrete_or_error(operator.index, axis, - "axis of windowed-FFT") + axis = core.concrete_or_error(operator.index, axis, "axis of windowed-FFT") axis = canonicalize_axis(axis, x.ndim) if y is None: @@ -686,8 +686,8 @@ def _spectral_helper(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0, noverlap_int: int = 0 if nperseg is not None: # if specified by user - nperseg_int = jax.core.concrete_or_error(int, nperseg, - "nperseg of windowed-FFT") + nperseg_int = core.concrete_or_error( + int, nperseg, "nperseg of windowed-FFT") if nperseg_int < 1: raise ValueError('nperseg must be a positive integer') # parse window; if array like, then set nperseg = win.shape @@ -698,14 +698,13 @@ def _spectral_helper(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0, if noverlap is None: noverlap_int = nperseg_int // 2 else: - noverlap_int = jax.core.concrete_or_error(int, noverlap, - "noverlap of windowed-FFT") + noverlap_int = core.concrete_or_error( + int, noverlap, "noverlap of windowed-FFT") if nfft is None: nfft_int = nperseg_int else: - nfft_int = jax.core.concrete_or_error(int, nfft, - "nfft of windowed-FFT") + nfft_int = core.concrete_or_error(int, nfft, "nfft of windowed-FFT") # Special cases for size == 0 if y is None: @@ -1015,8 +1014,8 @@ def _overlap_and_add(x: Array, step_size: int) -> Array: An array with `(..., output_size)`-shape containing overlapped signal. """ check_arraylike("_overlap_and_add", x) - step_size = jax.core.concrete_or_error(int, step_size, - "step_size for overlap_and_add") + step_size = core.concrete_or_error( + int, step_size, "step_size for overlap_and_add") if x.ndim < 2: raise ValueError('Input must have (..., frames, frame_length) shape.') @@ -1114,7 +1113,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', n_default = (2 * (Zxx.shape[freq_axis] - 1) if input_onesided else Zxx.shape[freq_axis]) - nperseg_int = jax.core.concrete_or_error(int, nperseg or n_default, + nperseg_int = core.concrete_or_error(int, nperseg or n_default, "nperseg: segment length of STFT") if nperseg_int < 1: raise ValueError('nperseg must be a positive integer') @@ -1125,13 +1124,13 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', if input_onesided and nperseg_int == n_default + 1: nfft_int += 1 # Odd nperseg, no FFT padding else: - nfft_int = jax.core.concrete_or_error(int, nfft, "nfft of STFT") + nfft_int = core.concrete_or_error(int, nfft, "nfft of STFT") if nfft_int < nperseg_int: raise ValueError( f'FFT length ({nfft_int}) must be longer than nperseg ({nperseg_int}).') - noverlap_int = jax.core.concrete_or_error(int, noverlap or nperseg_int // 2, - "noverlap of STFT") + noverlap_int = core.concrete_or_error( + int, noverlap or nperseg_int // 2, "noverlap of STFT") if noverlap_int >= nperseg_int: raise ValueError('noverlap must be less than nperseg.') nstep = nperseg_int - noverlap_int