Removed unused _get_memory_space_from_ref

PiperOrigin-RevId: 691342830
This commit is contained in:
Sergei Lebedev 2024-10-30 02:39:00 -07:00 committed by jax authors
parent e35e7f8e20
commit 908c8a8280

View File

@ -1558,12 +1558,6 @@ def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue:
return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype)
def _get_memory_space_from_ref(ref_aval: state.AbstractRef) -> Any:
if isinstance(ref_aval, pallas_core.AbstractMemoryRef):
return ref_aval.memory_space
return pallas_core.MemorySpace.ANY
@state_discharge.register_discharge_rule(pallas_call_p)
def _pallas_call_state_discharge_rule(
avals_in,