mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Remove pxla.OutputType enum class now that the only output can be jax.Array
PiperOrigin-RevId: 517985356
This commit is contained in:
parent
021fadfcbc
commit
58fed7001a
@ -714,9 +714,9 @@ def _array_global_result_handler(global_aval, out_sharding, committed,
|
||||
return xc.array_result_handler(
|
||||
global_aval, out_sharding, committed=committed, _skip_checks=True
|
||||
)
|
||||
pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler
|
||||
pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_global_result_handler
|
||||
pxla.global_result_handlers[(core.AbstractToken, pxla.OutputType.Array)] = lambda *_: lambda *_: core.token
|
||||
pxla.global_result_handlers[core.ShapedArray] = _array_global_result_handler
|
||||
pxla.global_result_handlers[core.ConcreteArray] = _array_global_result_handler
|
||||
pxla.global_result_handlers[core.AbstractToken] = lambda *_: lambda *_: core.token
|
||||
|
||||
|
||||
# Only used for Arrays that come out of pmap.
|
||||
@ -729,5 +729,5 @@ def _array_local_result_handler(aval, sharding, indices):
|
||||
return xc.array_result_handler(
|
||||
aval, sharding, committed=True, _skip_checks=True
|
||||
)
|
||||
pxla.local_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_local_result_handler
|
||||
pxla.local_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_local_result_handler
|
||||
pxla.local_result_handlers[core.ShapedArray] = _array_local_result_handler
|
||||
pxla.local_result_handlers[core.ConcreteArray] = _array_local_result_handler
|
||||
|
@ -70,7 +70,7 @@ from jax._src import xla_bridge as xb
|
||||
from jax._src.abstract_arrays import array_types
|
||||
from jax._src.config import config
|
||||
from jax._src.config import flags
|
||||
from jax._src.core import ConcreteArray, ShapedArray
|
||||
from jax._src.core import ShapedArray
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
@ -566,12 +566,6 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
|
||||
return PartitionSpec(*partitions)
|
||||
|
||||
|
||||
class OutputType(enum.Enum):
|
||||
Array = 0
|
||||
GlobalDeviceArray = 1
|
||||
ShardedDeviceArray = 2
|
||||
|
||||
|
||||
def local_aval_to_result_handler(
|
||||
aval: core.AbstractValue,
|
||||
sharding: sharding_impls.XLACompatibleSharding,
|
||||
@ -591,26 +585,14 @@ def local_aval_to_result_handler(
|
||||
for this output. The function will return an object suitable for returning
|
||||
to the user, e.g. a ShardedDeviceArray.
|
||||
"""
|
||||
output_type = OutputType.Array
|
||||
try:
|
||||
return local_result_handlers[(type(aval), output_type)](aval, sharding, indices)
|
||||
return local_result_handlers[(type(aval))](aval, sharding, indices)
|
||||
except KeyError as err:
|
||||
raise TypeError(
|
||||
f"No pxla_result_handler for type: {type(aval)}") from err
|
||||
|
||||
PxlaResultHandler = Callable[..., Callable[[Sequence[xb.xla_client.Buffer]], Any]]
|
||||
local_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaResultHandler] = {}
|
||||
|
||||
def sda_array_result_handler(aval: ShapedArray, sharding, indices):
|
||||
sharding_spec = _get_sharding_specs([sharding], [aval])[0]
|
||||
if core.is_opaque_dtype(aval.dtype):
|
||||
return aval.dtype._rules.local_sharded_result_handler(
|
||||
aval, sharding, indices)
|
||||
else:
|
||||
return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
|
||||
indices)
|
||||
local_result_handlers[(ShapedArray, OutputType.ShardedDeviceArray)] = sda_array_result_handler
|
||||
local_result_handlers[(ConcreteArray, OutputType.ShardedDeviceArray)] = sda_array_result_handler
|
||||
local_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
|
||||
|
||||
|
||||
def global_aval_to_result_handler(
|
||||
@ -633,15 +615,14 @@ def global_aval_to_result_handler(
|
||||
for this output. The function will return an object suitable for returning
|
||||
to the user, e.g. a ShardedDeviceArray.
|
||||
"""
|
||||
output_type = OutputType.Array
|
||||
try:
|
||||
return global_result_handlers[(type(aval), output_type)](
|
||||
return global_result_handlers[type(aval)](
|
||||
aval, out_sharding, committed, is_out_sharding_from_xla)
|
||||
except KeyError as err:
|
||||
raise TypeError(
|
||||
f"No pxla_result_handler for type: {type(aval)}") from err
|
||||
|
||||
global_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaResultHandler] = {}
|
||||
global_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
|
||||
|
||||
### lazy device-memory persistence and result handling
|
||||
|
||||
|
@ -335,12 +335,7 @@ class KeyTyRules:
|
||||
def local_sharded_result_handler(aval, sharding, indices):
|
||||
phys_aval, = KeyTyRules.physical_avals(aval)
|
||||
key_shape = aval.dtype.impl.key_shape
|
||||
|
||||
# TODO(yashkatariya,frostig): remove this conditional and inline it when
|
||||
# the transient config ever settles
|
||||
output_type = pxla.OutputType.Array
|
||||
phys_handler_maker = pxla.local_result_handlers[
|
||||
(core.ShapedArray, output_type)]
|
||||
phys_handler_maker = pxla.local_result_handlers[core.ShapedArray]
|
||||
|
||||
# set up a grounded sharding (with a grounded sharding spec)
|
||||
if isinstance(sharding, (PmapSharding, NamedSharding)):
|
||||
@ -366,13 +361,7 @@ class KeyTyRules:
|
||||
def global_sharded_result_handler(aval, out_sharding, committed,
|
||||
is_out_sharding_from_xla):
|
||||
phys_aval, = KeyTyRules.physical_avals(aval)
|
||||
|
||||
# TODO(yashkatariya,frostig): remove this conditional and inline it when
|
||||
# the transient config ever settles
|
||||
output_type = pxla.OutputType.Array
|
||||
|
||||
phys_handler_maker = pxla.global_result_handlers[
|
||||
(core.ShapedArray, output_type)]
|
||||
phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
|
||||
|
||||
phys_sharding = make_key_array_phys_sharding(
|
||||
aval, out_sharding, is_out_sharding_from_xla)
|
||||
|
@ -32,7 +32,6 @@ from jax._src.interpreters.pxla import (
|
||||
NoSharding as NoSharding,
|
||||
OpShardingType as OpShardingType,
|
||||
OrderedDictType as OrderedDictType,
|
||||
OutputType as OutputType,
|
||||
ParallelCallableInfo as ParallelCallableInfo,
|
||||
PartitionInfo as PartitionInfo,
|
||||
PartitionsOrReplicated as PartitionsOrReplicated,
|
||||
@ -94,7 +93,6 @@ from jax._src.interpreters.pxla import (
|
||||
reconcile_num_partitions as reconcile_num_partitions,
|
||||
replicate as replicate,
|
||||
resource_typecheck as resource_typecheck,
|
||||
sda_array_result_handler as sda_array_result_handler,
|
||||
shard_arg_handlers as shard_arg_handlers,
|
||||
shard_args as shard_args,
|
||||
shard_arg as shard_arg,
|
||||
|
Loading…
x
Reference in New Issue
Block a user