Ensured that JAX type checks under pytype on Python 3.12

Some errors uncovered by pytype look genuine and need to be revisited in
the in the future.

PiperOrigin-RevId: 704268742
This commit is contained in:
Sergei Lebedev 2024-12-09 06:52:25 -08:00 committed by jax authors
parent 5a1c4c5783
commit 1ac6b762dd
6 changed files with 17 additions and 9 deletions

View File

@ -1120,7 +1120,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
bufs.append(buf)
break
else:
bufs.append(buf)
bufs.append(candidates_list[-1])
return pxla.batched_device_put(x.aval, sharding, bufs, devices)

View File

@ -1992,7 +1992,8 @@ def compute_dim_vars_from_arg_shapes(
generate the code for computing the dimension variables. It also generates
the shape assertions.
Returns: the values of the dimension variables, in the order determined by
Returns:
The values of the dimension variables, in the order determined by
`all_dim_vars(args_avals)`.
"""
dim_vars = all_dim_vars(args_avals)
@ -2006,8 +2007,7 @@ def compute_dim_vars_from_arg_shapes(
}
synthetic_eval = ShapeEvaluator(synthetic_env)
shape_constraints.shape_assertions(synthetic_eval)
dim_values = [synthetic_eval.evaluate(solution[var]) for var in dim_vars]
return tuple(dim_values)
return tuple(synthetic_eval.evaluate(solution[var]) for var in dim_vars)
def _solve_dim_equations(
eqns: list[_DimEquation],
@ -2141,7 +2141,8 @@ def _solve_dim_equations(
eqns = [eqn for eqn in eqns if not process_one_eqn(eqn)]
if not eqns:
add_explicit_symbolic_constraints(shape_env)
return shape_env, shape_constraints # SUCCESS
# SUCCESS
return shape_env, shape_constraints # pytype: disable=bad-return-type
elif len(eqns) >= nr_eqns:
break

View File

@ -1699,6 +1699,7 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
# For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2),
# then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None).
# The below custom call achieves the sharding like above example.
assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
if config.use_shardy_partitioner.value:
physical_ndim = core.physical_aval(aval).ndim
s = sharding_impls.SdyArraySharding(

View File

@ -177,9 +177,12 @@ class JaxprTrace(Trace['JaxprTracer']):
if const is None:
aval = pval.get_aval()
if type(aval) is DShapedArray:
# TODO(dougalm): Fix the type error and remove the pytype pragmas.
# pytype: disable=attribute-error
shape = [self.new_instantiated_const(d)
if isinstance(d, Tracer) and d._trace.level < self.level else d
for d in aval.shape]
# pytype: enable=attribute-error
aval = aval.update(shape=tuple(shape))
return JaxprTracer(self, PartialVal.unknown(aval), LambdaBinding())
else:
@ -1776,6 +1779,9 @@ def _inline_literals(
newvars: dict[Var, Var] = {}
newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval))
var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval))
lit_or_var = (
lambda a: a if isinstance(a, Literal) else (lit(a) or var(a))
)
dropvar = lambda aval: DropVar(_substitute_vars_in_type(lits, newvars, aval))
def vars_in_shape(aval: AbstractValue) -> Sequence[Var]:
@ -1794,10 +1800,10 @@ def _inline_literals(
new_invars = [var(v) for v in jaxpr.invars]
new_eqns = []
for eqn in jaxpr.eqns:
invars = [lit(x) or var(x) for x in eqn.invars]
invars = [lit_or_var(x) for x in eqn.invars]
outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars]
new_eqns.append(eqn.replace(invars=invars, outvars=outvars))
new_outvars = [lit(v) or var(v) for v in jaxpr.outvars]
new_outvars = [lit_or_var(v) for v in jaxpr.outvars]
jaxpr_effects = make_jaxpr_effects(new_constvars, new_invars, new_outvars,
new_eqns)
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns,

View File

@ -513,7 +513,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
# jaxpr for each branch.
branches_known_ : list[core.ClosedJaxpr] = []
branches_staged_: list[core.ClosedJaxpr] = []
branch_res_avals: list[core.AbstractValue] = []
branch_res_avals: list[list[core.AbstractValue]] = []
for jaxpr in branches:
jaxpr_known, jaxpr_staged, _, inst_out, num_res = \
pe.partial_eval_jaxpr_custom(

View File

@ -1651,7 +1651,7 @@ pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \
def _add_reshapes(which, jaxpr_known, jaxpr_staged):
# add singleton axes to residuals which are from jaxpr_known and are scalars
which_ = [w and not v.aval.shape
which_ = [w and not v.aval.shape # pytype: disable=attribute-error
for w, v in zip(which, jaxpr_staged.invars[:len(which)])]
if not any(which_): return jaxpr_known, jaxpr_staged
assert not jaxpr_known.constvars and not jaxpr_staged.constvars