Fixed some type errors under pyright

These are mostly due to relience on submodule import side-effects, which
AFAIU are unchecked by both pytype and mypy.
This commit is contained in:
Sergei Lebedev 2024-09-04 22:17:19 +01:00
parent a8a55e0f2e
commit 51eb0d27c7
5 changed files with 35 additions and 30 deletions

View File

@ -58,7 +58,7 @@ map = util.safe_map
zip = util.safe_zip
DType = Any
Shape = jax._src.core.Shape
Shape = core.Shape
# The values of input and output sharding from the lowering.
LoweringSharding = Union[sharding.Sharding, pxla.UnspecifiedValue]
HloSharding = xla_client.HloSharding

View File

@ -21,7 +21,6 @@ import itertools as it
from functools import partial
from typing import Any
import jax
from jax._src import config
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe
@ -389,6 +388,9 @@ class JVPTrace(Trace):
def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees,
symbolic_zeros):
# Local import to prevent an import cycle.
from jax._src.lax import lax
primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
fwd_in = [(core.full_lower(p), type(t) is not Zero)
for p, t in zip(primals_in, tangents_in)]
@ -402,7 +404,7 @@ class JVPTrace(Trace):
tangents_out = custom_lin_p.bind(
*res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
tangents_out = map(jax._src.lax.lax.tie_p.bind, primals_out, tangents_out)
tangents_out = map(lax.tie_p.bind, primals_out, tangents_out)
tangents_out = map(recast_to_float0, primals_out, tangents_out)
return map(partial(JVPTracer, self), primals_out, tangents_out)

View File

@ -36,7 +36,8 @@ from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import core as pallas_core
from jax._src.pallas.primitives import uninitialized_value
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.state import discharge as state_discharge
from jax._src.util import (
safe_map,
@ -111,13 +112,15 @@ def _pad_values_to_block_dimension(value,
)
if padded_shape != value.shape:
pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape))
pad_value = uninitialized_value(shape=(), dtype=value.dtype)
pad_value = primitives.uninitialized_value(shape=(), dtype=value.dtype)
value = jnp.pad(value, pad_width, constant_values=pad_value)
return value
def _initialize_scratch_vals(scratch_avals) -> tuple[jax.Array, ...]:
scratch_avals = (jax_core.raise_to_shaped(x) for x in scratch_avals)
return tuple(uninitialized_value(a.shape, a.dtype) for a in scratch_avals)
return tuple(
primitives.uninitialized_value(a.shape, a.dtype) for a in scratch_avals
)
def _initialize_output_vals(
block_mappings_output: Iterable[BlockMapping],
@ -128,7 +131,8 @@ def _initialize_output_vals(
if i in oi_map:
output_vals.append(input_args[oi_map[i]])
else:
output_vals.append(uninitialized_value(bm.array_shape_dtype.shape,
output_vals.append(primitives.uninitialized_value(
bm.array_shape_dtype.shape,
bm.array_shape_dtype.dtype))
return output_vals
@ -212,7 +216,7 @@ def _pallas_call_impl_interpret(
if padding is not None and any(p != (0, 0) for p in padding):
if input_output_aliases:
raise NotImplementedError("Padding with aliasing not supported.")
pad_value = uninitialized_value(shape=(), dtype=x.dtype)
pad_value = primitives.uninitialized_value(shape=(), dtype=x.dtype)
x = lax.pad(x, pad_value, [(*p, 0) for p in padding])
carry.append(x)
@ -872,9 +876,9 @@ def _pallas_call_batching_rule(
val_at_ragged_dim = first_block_mapping.block_shape[ragged_axis_dim]
def when_wrapped_kernel(lengths_ref, *args, **kwargs):
b_idx = jax.experimental.pallas.program_id(stacked_axis)
b_idx = primitives.program_id(stacked_axis)
i_idx = (
jax.experimental.pallas.program_id(ragged_axis_dim)
primitives.program_id(ragged_axis_dim)
* val_at_ragged_dim
)
b_len = lengths_ref[b_idx]
@ -883,7 +887,7 @@ def _pallas_call_batching_rule(
# b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0)
# checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0")
@jax.experimental.pallas.when(i_idx < b_len)
@pallas_utils.when(i_idx < b_len)
def f():
# Important! This allows us to trace the inner kernel with the correct
# grid to preserve user program_id semantics. Ex: program_id(0) will
@ -893,7 +897,7 @@ def _pallas_call_batching_rule(
if debug_zero_fill_counterfactual:
@jax.experimental.pallas.when(i_idx >= b_len)
@pallas_utils.when(i_idx >= b_len)
def g():
for arg_ref in args:
arg_ref[...] = jnp.zeros_like(arg_ref)

View File

@ -707,7 +707,7 @@ debug_print_p = jax_core.Primitive("debug_print")
debug_print_p.multiple_results = True
def debug_print(fmt: str, *args: jax.ArrayLike):
def debug_print(fmt: str, *args: jax.typing.ArrayLike):
"""Prints scalar values from inside a Pallas kernel.
Args:
@ -732,7 +732,7 @@ def debug_print(fmt: str, *args: jax.ArrayLike):
def check_debug_print_format(
fmt: str, *args: jax.ArrayLike
fmt: str, *args: jax.typing.ArrayLike
):
n_placeholders = 0
for _, field, spec, conversion in string.Formatter().parse(fmt):

View File

@ -27,8 +27,9 @@ import jax
import jax.numpy.fft
import jax.numpy as jnp
from jax import lax
from jax._src.api_util import _ensure_index_tuple
from jax._src import core
from jax._src import dtypes
from jax._src.api_util import _ensure_index_tuple
from jax._src.lax.lax import PrecisionLike
from jax._src.numpy import linalg
from jax._src.numpy.util import (
@ -655,8 +656,7 @@ def _spectral_helper(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0,
f"Unknown boundary option '{boundary}', "
f"must be one of: {list(boundary_funcs.keys())}")
axis = jax.core.concrete_or_error(operator.index, axis,
"axis of windowed-FFT")
axis = core.concrete_or_error(operator.index, axis, "axis of windowed-FFT")
axis = canonicalize_axis(axis, x.ndim)
if y is None:
@ -686,8 +686,8 @@ def _spectral_helper(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0,
noverlap_int: int = 0
if nperseg is not None: # if specified by user
nperseg_int = jax.core.concrete_or_error(int, nperseg,
"nperseg of windowed-FFT")
nperseg_int = core.concrete_or_error(
int, nperseg, "nperseg of windowed-FFT")
if nperseg_int < 1:
raise ValueError('nperseg must be a positive integer')
# parse window; if array like, then set nperseg = win.shape
@ -698,14 +698,13 @@ def _spectral_helper(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0,
if noverlap is None:
noverlap_int = nperseg_int // 2
else:
noverlap_int = jax.core.concrete_or_error(int, noverlap,
"noverlap of windowed-FFT")
noverlap_int = core.concrete_or_error(
int, noverlap, "noverlap of windowed-FFT")
if nfft is None:
nfft_int = nperseg_int
else:
nfft_int = jax.core.concrete_or_error(int, nfft,
"nfft of windowed-FFT")
nfft_int = core.concrete_or_error(int, nfft, "nfft of windowed-FFT")
# Special cases for size == 0
if y is None:
@ -1015,8 +1014,8 @@ def _overlap_and_add(x: Array, step_size: int) -> Array:
An array with `(..., output_size)`-shape containing overlapped signal.
"""
check_arraylike("_overlap_and_add", x)
step_size = jax.core.concrete_or_error(int, step_size,
"step_size for overlap_and_add")
step_size = core.concrete_or_error(
int, step_size, "step_size for overlap_and_add")
if x.ndim < 2:
raise ValueError('Input must have (..., frames, frame_length) shape.')
@ -1114,7 +1113,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann',
n_default = (2 * (Zxx.shape[freq_axis] - 1) if input_onesided
else Zxx.shape[freq_axis])
nperseg_int = jax.core.concrete_or_error(int, nperseg or n_default,
nperseg_int = core.concrete_or_error(int, nperseg or n_default,
"nperseg: segment length of STFT")
if nperseg_int < 1:
raise ValueError('nperseg must be a positive integer')
@ -1125,13 +1124,13 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann',
if input_onesided and nperseg_int == n_default + 1:
nfft_int += 1 # Odd nperseg, no FFT padding
else:
nfft_int = jax.core.concrete_or_error(int, nfft, "nfft of STFT")
nfft_int = core.concrete_or_error(int, nfft, "nfft of STFT")
if nfft_int < nperseg_int:
raise ValueError(
f'FFT length ({nfft_int}) must be longer than nperseg ({nperseg_int}).')
noverlap_int = jax.core.concrete_or_error(int, noverlap or nperseg_int // 2,
"noverlap of STFT")
noverlap_int = core.concrete_or_error(
int, noverlap or nperseg_int // 2, "noverlap of STFT")
if noverlap_int >= nperseg_int:
raise ValueError('noverlap must be less than nperseg.')
nstep = nperseg_int - noverlap_int