remove cast

This commit is contained in:
Matthew Johnson 2023-05-09 14:44:05 -07:00
parent befa29b566
commit 0e14075a35

View File

@ -382,14 +382,6 @@ def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep,
# Staging
def _check_avals_are_shapedarrays(
avals: Sequence[core.AbstractValue],
) -> Sequence[core.ShapedArray]:
assert all(isinstance(a, core.ShapedArray) for a in avals), avals
return cast(Sequence[core.ShapedArray], avals)
def _shard_map_staging(
trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun,
in_tracers: Sequence[pe.DynamicJaxprTracer], *, mesh: Mesh,
@ -402,7 +394,7 @@ def _shard_map_staging(
main = trace.main
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, out_avals_generic, consts = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_)
out_avals_ = _check_avals_are_shapedarrays(out_avals_generic)
out_avals_ = map(_check_shapedarray, out_avals_generic)
_check_names(out_names_thunk(), out_avals_)
if check_rep:
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
@ -426,6 +418,10 @@ def _shard_map_staging(
return out_tracers
pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging
def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray:
assert isinstance(aval, core.ShapedArray)
return aval
def _shard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
) -> core.AbstractValue:
if isinstance(aval, core.ShapedArray):