mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fixed some type errors under pyright
These are mostly due to relience on submodule import side-effects, which AFAIU are unchecked by both pytype and mypy.
This commit is contained in:
parent
a8a55e0f2e
commit
51eb0d27c7
@ -58,7 +58,7 @@ map = util.safe_map
|
||||
zip = util.safe_zip
|
||||
|
||||
DType = Any
|
||||
Shape = jax._src.core.Shape
|
||||
Shape = core.Shape
|
||||
# The values of input and output sharding from the lowering.
|
||||
LoweringSharding = Union[sharding.Sharding, pxla.UnspecifiedValue]
|
||||
HloSharding = xla_client.HloSharding
|
||||
|
@ -21,7 +21,6 @@ import itertools as it
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax._src import config
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
@ -389,6 +388,9 @@ class JVPTrace(Trace):
|
||||
|
||||
def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees,
|
||||
symbolic_zeros):
|
||||
# Local import to prevent an import cycle.
|
||||
from jax._src.lax import lax
|
||||
|
||||
primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
|
||||
fwd_in = [(core.full_lower(p), type(t) is not Zero)
|
||||
for p, t in zip(primals_in, tangents_in)]
|
||||
@ -402,7 +404,7 @@ class JVPTrace(Trace):
|
||||
tangents_out = custom_lin_p.bind(
|
||||
*res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
|
||||
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
|
||||
tangents_out = map(jax._src.lax.lax.tie_p.bind, primals_out, tangents_out)
|
||||
tangents_out = map(lax.tie_p.bind, primals_out, tangents_out)
|
||||
tangents_out = map(recast_to_float0, primals_out, tangents_out)
|
||||
return map(partial(JVPTracer, self), primals_out, tangents_out)
|
||||
|
||||
|
@ -36,7 +36,8 @@ from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas.primitives import uninitialized_value
|
||||
from jax._src.pallas import primitives
|
||||
from jax._src.pallas import utils as pallas_utils
|
||||
from jax._src.state import discharge as state_discharge
|
||||
from jax._src.util import (
|
||||
safe_map,
|
||||
@ -111,13 +112,15 @@ def _pad_values_to_block_dimension(value,
|
||||
)
|
||||
if padded_shape != value.shape:
|
||||
pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape))
|
||||
pad_value = uninitialized_value(shape=(), dtype=value.dtype)
|
||||
pad_value = primitives.uninitialized_value(shape=(), dtype=value.dtype)
|
||||
value = jnp.pad(value, pad_width, constant_values=pad_value)
|
||||
return value
|
||||
|
||||
def _initialize_scratch_vals(scratch_avals) -> tuple[jax.Array, ...]:
|
||||
scratch_avals = (jax_core.raise_to_shaped(x) for x in scratch_avals)
|
||||
return tuple(uninitialized_value(a.shape, a.dtype) for a in scratch_avals)
|
||||
return tuple(
|
||||
primitives.uninitialized_value(a.shape, a.dtype) for a in scratch_avals
|
||||
)
|
||||
|
||||
def _initialize_output_vals(
|
||||
block_mappings_output: Iterable[BlockMapping],
|
||||
@ -128,7 +131,8 @@ def _initialize_output_vals(
|
||||
if i in oi_map:
|
||||
output_vals.append(input_args[oi_map[i]])
|
||||
else:
|
||||
output_vals.append(uninitialized_value(bm.array_shape_dtype.shape,
|
||||
output_vals.append(primitives.uninitialized_value(
|
||||
bm.array_shape_dtype.shape,
|
||||
bm.array_shape_dtype.dtype))
|
||||
return output_vals
|
||||
|
||||
@ -212,7 +216,7 @@ def _pallas_call_impl_interpret(
|
||||
if padding is not None and any(p != (0, 0) for p in padding):
|
||||
if input_output_aliases:
|
||||
raise NotImplementedError("Padding with aliasing not supported.")
|
||||
pad_value = uninitialized_value(shape=(), dtype=x.dtype)
|
||||
pad_value = primitives.uninitialized_value(shape=(), dtype=x.dtype)
|
||||
x = lax.pad(x, pad_value, [(*p, 0) for p in padding])
|
||||
carry.append(x)
|
||||
|
||||
@ -872,9 +876,9 @@ def _pallas_call_batching_rule(
|
||||
val_at_ragged_dim = first_block_mapping.block_shape[ragged_axis_dim]
|
||||
|
||||
def when_wrapped_kernel(lengths_ref, *args, **kwargs):
|
||||
b_idx = jax.experimental.pallas.program_id(stacked_axis)
|
||||
b_idx = primitives.program_id(stacked_axis)
|
||||
i_idx = (
|
||||
jax.experimental.pallas.program_id(ragged_axis_dim)
|
||||
primitives.program_id(ragged_axis_dim)
|
||||
* val_at_ragged_dim
|
||||
)
|
||||
b_len = lengths_ref[b_idx]
|
||||
@ -883,7 +887,7 @@ def _pallas_call_batching_rule(
|
||||
# b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0)
|
||||
# checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0")
|
||||
|
||||
@jax.experimental.pallas.when(i_idx < b_len)
|
||||
@pallas_utils.when(i_idx < b_len)
|
||||
def f():
|
||||
# Important! This allows us to trace the inner kernel with the correct
|
||||
# grid to preserve user program_id semantics. Ex: program_id(0) will
|
||||
@ -893,7 +897,7 @@ def _pallas_call_batching_rule(
|
||||
|
||||
if debug_zero_fill_counterfactual:
|
||||
|
||||
@jax.experimental.pallas.when(i_idx >= b_len)
|
||||
@pallas_utils.when(i_idx >= b_len)
|
||||
def g():
|
||||
for arg_ref in args:
|
||||
arg_ref[...] = jnp.zeros_like(arg_ref)
|
||||
|
@ -707,7 +707,7 @@ debug_print_p = jax_core.Primitive("debug_print")
|
||||
debug_print_p.multiple_results = True
|
||||
|
||||
|
||||
def debug_print(fmt: str, *args: jax.ArrayLike):
|
||||
def debug_print(fmt: str, *args: jax.typing.ArrayLike):
|
||||
"""Prints scalar values from inside a Pallas kernel.
|
||||
|
||||
Args:
|
||||
@ -732,7 +732,7 @@ def debug_print(fmt: str, *args: jax.ArrayLike):
|
||||
|
||||
|
||||
def check_debug_print_format(
|
||||
fmt: str, *args: jax.ArrayLike
|
||||
fmt: str, *args: jax.typing.ArrayLike
|
||||
):
|
||||
n_placeholders = 0
|
||||
for _, field, spec, conversion in string.Formatter().parse(fmt):
|
||||
|
@ -27,8 +27,9 @@ import jax
|
||||
import jax.numpy.fft
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax._src.api_util import _ensure_index_tuple
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.api_util import _ensure_index_tuple
|
||||
from jax._src.lax.lax import PrecisionLike
|
||||
from jax._src.numpy import linalg
|
||||
from jax._src.numpy.util import (
|
||||
@ -655,8 +656,7 @@ def _spectral_helper(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0,
|
||||
f"Unknown boundary option '{boundary}', "
|
||||
f"must be one of: {list(boundary_funcs.keys())}")
|
||||
|
||||
axis = jax.core.concrete_or_error(operator.index, axis,
|
||||
"axis of windowed-FFT")
|
||||
axis = core.concrete_or_error(operator.index, axis, "axis of windowed-FFT")
|
||||
axis = canonicalize_axis(axis, x.ndim)
|
||||
|
||||
if y is None:
|
||||
@ -686,8 +686,8 @@ def _spectral_helper(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0,
|
||||
noverlap_int: int = 0
|
||||
|
||||
if nperseg is not None: # if specified by user
|
||||
nperseg_int = jax.core.concrete_or_error(int, nperseg,
|
||||
"nperseg of windowed-FFT")
|
||||
nperseg_int = core.concrete_or_error(
|
||||
int, nperseg, "nperseg of windowed-FFT")
|
||||
if nperseg_int < 1:
|
||||
raise ValueError('nperseg must be a positive integer')
|
||||
# parse window; if array like, then set nperseg = win.shape
|
||||
@ -698,14 +698,13 @@ def _spectral_helper(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0,
|
||||
if noverlap is None:
|
||||
noverlap_int = nperseg_int // 2
|
||||
else:
|
||||
noverlap_int = jax.core.concrete_or_error(int, noverlap,
|
||||
"noverlap of windowed-FFT")
|
||||
noverlap_int = core.concrete_or_error(
|
||||
int, noverlap, "noverlap of windowed-FFT")
|
||||
|
||||
if nfft is None:
|
||||
nfft_int = nperseg_int
|
||||
else:
|
||||
nfft_int = jax.core.concrete_or_error(int, nfft,
|
||||
"nfft of windowed-FFT")
|
||||
nfft_int = core.concrete_or_error(int, nfft, "nfft of windowed-FFT")
|
||||
|
||||
# Special cases for size == 0
|
||||
if y is None:
|
||||
@ -1015,8 +1014,8 @@ def _overlap_and_add(x: Array, step_size: int) -> Array:
|
||||
An array with `(..., output_size)`-shape containing overlapped signal.
|
||||
"""
|
||||
check_arraylike("_overlap_and_add", x)
|
||||
step_size = jax.core.concrete_or_error(int, step_size,
|
||||
"step_size for overlap_and_add")
|
||||
step_size = core.concrete_or_error(
|
||||
int, step_size, "step_size for overlap_and_add")
|
||||
if x.ndim < 2:
|
||||
raise ValueError('Input must have (..., frames, frame_length) shape.')
|
||||
|
||||
@ -1114,7 +1113,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann',
|
||||
n_default = (2 * (Zxx.shape[freq_axis] - 1) if input_onesided
|
||||
else Zxx.shape[freq_axis])
|
||||
|
||||
nperseg_int = jax.core.concrete_or_error(int, nperseg or n_default,
|
||||
nperseg_int = core.concrete_or_error(int, nperseg or n_default,
|
||||
"nperseg: segment length of STFT")
|
||||
if nperseg_int < 1:
|
||||
raise ValueError('nperseg must be a positive integer')
|
||||
@ -1125,13 +1124,13 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann',
|
||||
if input_onesided and nperseg_int == n_default + 1:
|
||||
nfft_int += 1 # Odd nperseg, no FFT padding
|
||||
else:
|
||||
nfft_int = jax.core.concrete_or_error(int, nfft, "nfft of STFT")
|
||||
nfft_int = core.concrete_or_error(int, nfft, "nfft of STFT")
|
||||
if nfft_int < nperseg_int:
|
||||
raise ValueError(
|
||||
f'FFT length ({nfft_int}) must be longer than nperseg ({nperseg_int}).')
|
||||
|
||||
noverlap_int = jax.core.concrete_or_error(int, noverlap or nperseg_int // 2,
|
||||
"noverlap of STFT")
|
||||
noverlap_int = core.concrete_or_error(
|
||||
int, noverlap or nperseg_int // 2, "noverlap of STFT")
|
||||
if noverlap_int >= nperseg_int:
|
||||
raise ValueError('noverlap must be less than nperseg.')
|
||||
nstep = nperseg_int - noverlap_int
|
||||
|
Loading…
x
Reference in New Issue
Block a user