mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
remove cast
This commit is contained in:
parent
befa29b566
commit
0e14075a35
@ -382,14 +382,6 @@ def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep,
|
||||
|
||||
# Staging
|
||||
|
||||
|
||||
def _check_avals_are_shapedarrays(
|
||||
avals: Sequence[core.AbstractValue],
|
||||
) -> Sequence[core.ShapedArray]:
|
||||
assert all(isinstance(a, core.ShapedArray) for a in avals), avals
|
||||
return cast(Sequence[core.ShapedArray], avals)
|
||||
|
||||
|
||||
def _shard_map_staging(
|
||||
trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun,
|
||||
in_tracers: Sequence[pe.DynamicJaxprTracer], *, mesh: Mesh,
|
||||
@ -402,7 +394,7 @@ def _shard_map_staging(
|
||||
main = trace.main
|
||||
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
|
||||
jaxpr, out_avals_generic, consts = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_)
|
||||
out_avals_ = _check_avals_are_shapedarrays(out_avals_generic)
|
||||
out_avals_ = map(_check_shapedarray, out_avals_generic)
|
||||
_check_names(out_names_thunk(), out_avals_)
|
||||
if check_rep:
|
||||
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
|
||||
@ -426,6 +418,10 @@ def _shard_map_staging(
|
||||
return out_tracers
|
||||
pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging
|
||||
|
||||
def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray:
|
||||
assert isinstance(aval, core.ShapedArray)
|
||||
return aval
|
||||
|
||||
def _shard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
|
||||
) -> core.AbstractValue:
|
||||
if isinstance(aval, core.ShapedArray):
|
||||
|
Loading…
x
Reference in New Issue
Block a user