mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Removed unused _get_memory_space_from_ref
PiperOrigin-RevId: 691342830
This commit is contained in:
parent
e35e7f8e20
commit
908c8a8280
@ -1558,12 +1558,6 @@ def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue:
|
||||
return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype)
|
||||
|
||||
|
||||
def _get_memory_space_from_ref(ref_aval: state.AbstractRef) -> Any:
|
||||
if isinstance(ref_aval, pallas_core.AbstractMemoryRef):
|
||||
return ref_aval.memory_space
|
||||
return pallas_core.MemorySpace.ANY
|
||||
|
||||
|
||||
@state_discharge.register_discharge_rule(pallas_call_p)
|
||||
def _pallas_call_state_discharge_rule(
|
||||
avals_in,
|
||||
|
Loading…
x
Reference in New Issue
Block a user