mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Ensured that JAX type checks under pytype on Python 3.12

Some errors uncovered by pytype look genuine and need to be revisited in the future.

PiperOrigin-RevId: 704268742
This commit is contained in:
parent 5a1c4c5783
commit 1ac6b762dd
@@ -1120,7 +1120,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
         bufs.append(buf)
         break
     else:
-      bufs.append(buf)
+      bufs.append(candidates_list[-1])
   return pxla.batched_device_put(x.aval, sharding, bufs, devices)
 
 
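The change above addresses pytype's complaint that a for/else block may read the loop variable before it is ever bound. A minimal, self-contained sketch of the pattern (names hypothetical, not JAX's actual code):

```python
def pick_buffer(candidates: list[str], target: str) -> str:
  """Return the candidate equal to `target`, else fall back to the last one."""
  for buf in candidates:
    if buf == target:
      break
  else:
    # The pre-fix version read `buf` here; pytype flags it as possibly
    # unbound, since the loop body never runs when `candidates` is empty.
    buf = candidates[-1]  # IndexError on empty input, as in the original
  return buf

assert pick_buffer(["a", "b"], "b") == "b"
assert pick_buffer(["a", "b"], "z") == "b"
```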
@@ -1992,7 +1992,8 @@ def compute_dim_vars_from_arg_shapes(
   generate the code for computing the dimension variables. It also generates
   the shape assertions.
 
-  Returns: the values of the dimension variables, in the order determined by
+  Returns:
+    The values of the dimension variables, in the order determined by
     `all_dim_vars(args_avals)`.
   """
   dim_vars = all_dim_vars(args_avals)
@@ -2006,8 +2007,7 @@ def compute_dim_vars_from_arg_shapes(
   }
   synthetic_eval = ShapeEvaluator(synthetic_env)
   shape_constraints.shape_assertions(synthetic_eval)
-  dim_values = [synthetic_eval.evaluate(solution[var]) for var in dim_vars]
-  return tuple(dim_values)
+  return tuple(synthetic_eval.evaluate(solution[var]) for var in dim_vars)
 
 def _solve_dim_equations(
     eqns: list[_DimEquation],
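Both `compute_dim_vars_from_arg_shapes` hunks are behavior-preserving; the second one above folds a temporary list into the `tuple(...)` call. A tiny illustration with made-up dimension values (the `solution`/`evaluate` names below are stand-ins, not the real ShapeEvaluator API):

```python
solution = {"b": 2, "h": 3}   # hypothetical solved dimension expressions
dim_vars = ["b", "h"]
evaluate = lambda expr: expr  # stand-in for ShapeEvaluator.evaluate

# Before: materialize a list, then convert it.
dim_values = [evaluate(v) for v in (solution[var] for var in dim_vars)]
assert tuple(dim_values) == (2, 3)

# After: pass the generator straight to tuple(); identical result.
assert tuple(evaluate(solution[var]) for var in dim_vars) == (2, 3)
```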
@@ -2141,7 +2141,8 @@ def _solve_dim_equations(
     eqns = [eqn for eqn in eqns if not process_one_eqn(eqn)]
     if not eqns:
       add_explicit_symbolic_constraints(shape_env)
-      return shape_env, shape_constraints  # SUCCESS
+      # SUCCESS
+      return shape_env, shape_constraints  # pytype: disable=bad-return-type
     elif len(eqns) >= nr_eqns:
       break
 
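Where pytype's finding is deferred rather than fixed, the commit relies on inline directives: `# pytype: disable=<error-class>` at the end of a line silences exactly that error class on that line. A minimal sketch (hypothetical function) of a line that genuinely trips `bad-return-type`:

```python
def parse_port(raw: str) -> int:
  value = raw.strip() or 0  # the checker infers Union[str, int] here
  # Returning a possible str from an `-> int` function is bad-return-type;
  # the directive suppresses only this error class on only this line.
  return value  # pytype: disable=bad-return-type
```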
@@ -1699,6 +1699,7 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
   # For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2),
   # then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None).
   # The below custom call achieves the sharding like above example.
+  assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
   if config.use_shardy_partitioner.value:
     physical_ndim = core.physical_aval(aval).ndim
     s = sharding_impls.SdyArraySharding(
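The inserted `assert isinstance(...)` serves double duty: it checks the runtime invariant and narrows `aval`'s static type before class-specific attributes are used. A standalone sketch of the idiom (the Base/Shaped classes are illustrative, not JAX's aval hierarchy):

```python
class Base: ...

class Shaped(Base):
  def __init__(self, ndim: int):
    self.ndim = ndim

def rank(aval: Base) -> int:
  # A checker sees only `Base` for `aval`, and Base has no `.ndim`.
  # The assert documents the invariant and narrows the type for pytype.
  assert isinstance(aval, Shaped)
  return aval.ndim

assert rank(Shaped(3)) == 3
```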
@@ -177,9 +177,12 @@ class JaxprTrace(Trace['JaxprTracer']):
     if const is None:
       aval = pval.get_aval()
       if type(aval) is DShapedArray:
+        # TODO(dougalm): Fix the type error and remove the pytype pragmas.
+        # pytype: disable=attribute-error
         shape = [self.new_instantiated_const(d)
                  if isinstance(d, Tracer) and d._trace.level < self.level else d
                  for d in aval.shape]
+        # pytype: enable=attribute-error
         aval = aval.update(shape=tuple(shape))
       return JaxprTracer(self, PartialVal.unknown(aval), LambdaBinding())
     else:
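Here the pragmas come as a disable/enable pair so several lines are exempted at once, and the TODO marks the suppression as temporary. A sketch of the paired form on a similarly union-typed value (types simplified, names hypothetical):

```python
from typing import Union

class HasShape:
  shape = (2, 2)

def shape_of(aval: Union[int, HasShape]) -> tuple:
  # pytype: disable=attribute-error
  # `aval` may be an int, which has no `.shape`, so this access is flagged;
  # the matching enable below keeps the suppression as narrow as possible.
  shape = aval.shape
  # pytype: enable=attribute-error
  return shape

assert shape_of(HasShape()) == (2, 2)
```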
@@ -1776,6 +1779,9 @@ def _inline_literals(
   newvars: dict[Var, Var] = {}
   newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval))
   var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval))
+  lit_or_var = (
+      lambda a: a if isinstance(a, Literal) else (lit(a) or var(a))
+  )
   dropvar = lambda aval: DropVar(_substitute_vars_in_type(lits, newvars, aval))
 
   def vars_in_shape(aval: AbstractValue) -> Sequence[Var]:
@@ -1794,10 +1800,10 @@ def _inline_literals(
   new_invars = [var(v) for v in jaxpr.invars]
   new_eqns = []
   for eqn in jaxpr.eqns:
-    invars = [lit(x) or var(x) for x in eqn.invars]
+    invars = [lit_or_var(x) for x in eqn.invars]
     outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars]
     new_eqns.append(eqn.replace(invars=invars, outvars=outvars))
-  new_outvars = [lit(v) or var(v) for v in jaxpr.outvars]
+  new_outvars = [lit_or_var(v) for v in jaxpr.outvars]
   jaxpr_effects = make_jaxpr_effects(new_constvars, new_invars, new_outvars,
                                      new_eqns)
   new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns,
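The two `_inline_literals` hunks replace the repeated `lit(x) or var(x)` idiom with one `lit_or_var` helper that also passes through arguments that are already `Literal`s. A toy version of the refactor (the lit/var stand-ins below are simplified, not JAX's):

```python
class Literal:
  def __init__(self, val):
    self.val = val

def lit(x):
  # Stand-in: wrap plain ints as Literals, return None otherwise.
  return Literal(x) if isinstance(x, int) else None

def var(x):
  return f"v_{x}"  # stand-in for allocating a fresh Var

# One place now encodes the decision, including the already-a-Literal case.
lit_or_var = lambda a: a if isinstance(a, Literal) else (lit(a) or var(a))

assert lit_or_var(Literal(7)).val == 7  # passed through untouched
assert lit_or_var(3).val == 3           # freshly wrapped
assert lit_or_var("x0") == "v_x0"       # falls back to a variable
```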
@@ -513,7 +513,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
   # jaxpr for each branch.
   branches_known_ : list[core.ClosedJaxpr] = []
   branches_staged_: list[core.ClosedJaxpr] = []
-  branch_res_avals: list[core.AbstractValue] = []
+  branch_res_avals: list[list[core.AbstractValue]] = []
   for jaxpr in branches:
     jaxpr_known, jaxpr_staged, _, inst_out, num_res = \
         pe.partial_eval_jaxpr_custom(
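The `_cond_partial_eval_custom` hunk fixes only an annotation: each branch appends a list of residual avals, so the accumulator's element type is itself a list. A minimal sketch of the mismatch a checker reports (strings stand in for avals):

```python
# Wrong: claims elements are avals, but whole lists get appended.
# branch_res_avals: list[str] = []
branch_res_avals: list[list[str]] = []

for branch in ("true_fn", "false_fn"):
  res_avals = [f"{branch}_res0", f"{branch}_res1"]
  branch_res_avals.append(res_avals)  # appends a list per branch

assert branch_res_avals[1][0] == "false_fn_res0"
```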
@@ -1651,7 +1651,7 @@ pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \
 
 def _add_reshapes(which, jaxpr_known, jaxpr_staged):
   # add singleton axes to residuals which are from jaxpr_known and are scalars
-  which_ = [w and not v.aval.shape
+  which_ = [w and not v.aval.shape  # pytype: disable=attribute-error
             for w, v in zip(which, jaxpr_staged.invars[:len(which)])]
   if not any(which_): return jaxpr_known, jaxpr_staged
   assert not jaxpr_known.constvars and not jaxpr_staged.constvars