Remove the pxla.OutputType enum class, now that the only possible output type is jax.Array

PiperOrigin-RevId: 517985356
This commit is contained in:
Yash Katariya 2023-03-20 09:09:15 -07:00 committed by jax authors
parent 021fadfcbc
commit 58fed7001a
4 changed files with 12 additions and 44 deletions

View File

@ -714,9 +714,9 @@ def _array_global_result_handler(global_aval, out_sharding, committed,
return xc.array_result_handler(
global_aval, out_sharding, committed=committed, _skip_checks=True
)
pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler
pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_global_result_handler
pxla.global_result_handlers[(core.AbstractToken, pxla.OutputType.Array)] = lambda *_: lambda *_: core.token
pxla.global_result_handlers[core.ShapedArray] = _array_global_result_handler
pxla.global_result_handlers[core.ConcreteArray] = _array_global_result_handler
pxla.global_result_handlers[core.AbstractToken] = lambda *_: lambda *_: core.token
# Only used for Arrays that come out of pmap.
@ -729,5 +729,5 @@ def _array_local_result_handler(aval, sharding, indices):
return xc.array_result_handler(
aval, sharding, committed=True, _skip_checks=True
)
pxla.local_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_local_result_handler
pxla.local_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_local_result_handler
pxla.local_result_handlers[core.ShapedArray] = _array_local_result_handler
pxla.local_result_handlers[core.ConcreteArray] = _array_local_result_handler

View File

@ -70,7 +70,7 @@ from jax._src import xla_bridge as xb
from jax._src.abstract_arrays import array_types
from jax._src.config import config
from jax._src.config import flags
from jax._src.core import ConcreteArray, ShapedArray
from jax._src.core import ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -566,12 +566,6 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
return PartitionSpec(*partitions)
class OutputType(enum.Enum):
Array = 0
GlobalDeviceArray = 1
ShardedDeviceArray = 2
def local_aval_to_result_handler(
aval: core.AbstractValue,
sharding: sharding_impls.XLACompatibleSharding,
@ -591,26 +585,14 @@ def local_aval_to_result_handler(
for this output. The function will return an object suitable for returning
to the user, e.g. a ShardedDeviceArray.
"""
output_type = OutputType.Array
try:
return local_result_handlers[(type(aval), output_type)](aval, sharding, indices)
return local_result_handlers[(type(aval))](aval, sharding, indices)
except KeyError as err:
raise TypeError(
f"No pxla_result_handler for type: {type(aval)}") from err
PxlaResultHandler = Callable[..., Callable[[Sequence[xb.xla_client.Buffer]], Any]]
local_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaResultHandler] = {}
def sda_array_result_handler(aval: ShapedArray, sharding, indices):
sharding_spec = _get_sharding_specs([sharding], [aval])[0]
if core.is_opaque_dtype(aval.dtype):
return aval.dtype._rules.local_sharded_result_handler(
aval, sharding, indices)
else:
return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
indices)
local_result_handlers[(ShapedArray, OutputType.ShardedDeviceArray)] = sda_array_result_handler
local_result_handlers[(ConcreteArray, OutputType.ShardedDeviceArray)] = sda_array_result_handler
local_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
def global_aval_to_result_handler(
@ -633,15 +615,14 @@ def global_aval_to_result_handler(
for this output. The function will return an object suitable for returning
to the user, e.g. a ShardedDeviceArray.
"""
output_type = OutputType.Array
try:
return global_result_handlers[(type(aval), output_type)](
return global_result_handlers[type(aval)](
aval, out_sharding, committed, is_out_sharding_from_xla)
except KeyError as err:
raise TypeError(
f"No pxla_result_handler for type: {type(aval)}") from err
global_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaResultHandler] = {}
global_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
### lazy device-memory persistence and result handling

View File

@ -335,12 +335,7 @@ class KeyTyRules:
def local_sharded_result_handler(aval, sharding, indices):
phys_aval, = KeyTyRules.physical_avals(aval)
key_shape = aval.dtype.impl.key_shape
# TODO(yashkatariya,frostig): remove this conditional and inline it when
# the transient config ever settles
output_type = pxla.OutputType.Array
phys_handler_maker = pxla.local_result_handlers[
(core.ShapedArray, output_type)]
phys_handler_maker = pxla.local_result_handlers[core.ShapedArray]
# set up a grounded sharding (with a grounded sharding spec)
if isinstance(sharding, (PmapSharding, NamedSharding)):
@ -366,13 +361,7 @@ class KeyTyRules:
def global_sharded_result_handler(aval, out_sharding, committed,
is_out_sharding_from_xla):
phys_aval, = KeyTyRules.physical_avals(aval)
# TODO(yashkatariya,frostig): remove this conditional and inline it when
# the transient config ever settles
output_type = pxla.OutputType.Array
phys_handler_maker = pxla.global_result_handlers[
(core.ShapedArray, output_type)]
phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
phys_sharding = make_key_array_phys_sharding(
aval, out_sharding, is_out_sharding_from_xla)

View File

@ -32,7 +32,6 @@ from jax._src.interpreters.pxla import (
NoSharding as NoSharding,
OpShardingType as OpShardingType,
OrderedDictType as OrderedDictType,
OutputType as OutputType,
ParallelCallableInfo as ParallelCallableInfo,
PartitionInfo as PartitionInfo,
PartitionsOrReplicated as PartitionsOrReplicated,
@ -94,7 +93,6 @@ from jax._src.interpreters.pxla import (
reconcile_num_partitions as reconcile_num_partitions,
replicate as replicate,
resource_typecheck as resource_typecheck,
sda_array_result_handler as sda_array_result_handler,
shard_arg_handlers as shard_arg_handlers,
shard_args as shard_args,
shard_arg as shard_arg,