From 58fed7001afa9cefeba9aa05efeb11a7e0298c0b Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 20 Mar 2023 09:09:15 -0700 Subject: [PATCH] Remove pxla.OutputType enum class now that the only output can be jax.Array PiperOrigin-RevId: 517985356 --- jax/_src/array.py | 10 +++++----- jax/_src/interpreters/pxla.py | 29 +++++------------------------ jax/_src/prng.py | 15 ++------------- jax/interpreters/pxla.py | 2 -- 4 files changed, 12 insertions(+), 44 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 5226dc2a0..38d8dddce 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -714,9 +714,9 @@ def _array_global_result_handler(global_aval, out_sharding, committed, return xc.array_result_handler( global_aval, out_sharding, committed=committed, _skip_checks=True ) -pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler -pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_global_result_handler -pxla.global_result_handlers[(core.AbstractToken, pxla.OutputType.Array)] = lambda *_: lambda *_: core.token +pxla.global_result_handlers[core.ShapedArray] = _array_global_result_handler +pxla.global_result_handlers[core.ConcreteArray] = _array_global_result_handler +pxla.global_result_handlers[core.AbstractToken] = lambda *_: lambda *_: core.token # Only used for Arrays that come out of pmap. @@ -729,5 +729,5 @@ def _array_local_result_handler(aval, sharding, indices): return xc.array_result_handler( aval, sharding, committed=True, _skip_checks=True ) -pxla.local_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_local_result_handler -pxla.local_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_local_result_handler +pxla.local_result_handlers[core.ShapedArray] = _array_local_result_handler +pxla.local_result_handlers[core.ConcreteArray] = _array_local_result_handler diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 9ff7314c5..44028b1a5 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -70,7 +70,7 @@ from jax._src import xla_bridge as xb from jax._src.abstract_arrays import array_types from jax._src.config import config from jax._src.config import flags -from jax._src.core import ConcreteArray, ShapedArray +from jax._src.core import ShapedArray from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -566,12 +566,6 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping): return PartitionSpec(*partitions) -class OutputType(enum.Enum): - Array = 0 - GlobalDeviceArray = 1 - ShardedDeviceArray = 2 - - def local_aval_to_result_handler( aval: core.AbstractValue, sharding: sharding_impls.XLACompatibleSharding, @@ -591,26 +585,14 @@ def local_aval_to_result_handler( for this output. The function will return an object suitable for returning to the user, e.g. a ShardedDeviceArray. """ - output_type = OutputType.Array try: - return local_result_handlers[(type(aval), output_type)](aval, sharding, indices) + return local_result_handlers[(type(aval))](aval, sharding, indices) except KeyError as err: raise TypeError( f"No pxla_result_handler for type: {type(aval)}") from err PxlaResultHandler = Callable[..., Callable[[Sequence[xb.xla_client.Buffer]], Any]] -local_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaResultHandler] = {} - -def sda_array_result_handler(aval: ShapedArray, sharding, indices): - sharding_spec = _get_sharding_specs([sharding], [aval])[0] - if core.is_opaque_dtype(aval.dtype): - return aval.dtype._rules.local_sharded_result_handler( - aval, sharding, indices) - else: - return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs, - indices) -local_result_handlers[(ShapedArray, OutputType.ShardedDeviceArray)] = sda_array_result_handler -local_result_handlers[(ConcreteArray, OutputType.ShardedDeviceArray)] = sda_array_result_handler +local_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {} def global_aval_to_result_handler( @@ -633,15 +615,14 @@ def global_aval_to_result_handler( for this output. The function will return an object suitable for returning to the user, e.g. a ShardedDeviceArray. """ - output_type = OutputType.Array try: - return global_result_handlers[(type(aval), output_type)]( + return global_result_handlers[type(aval)]( aval, out_sharding, committed, is_out_sharding_from_xla) except KeyError as err: raise TypeError( f"No pxla_result_handler for type: {type(aval)}") from err -global_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaResultHandler] = {} +global_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {} ### lazy device-memory persistence and result handling diff --git a/jax/_src/prng.py b/jax/_src/prng.py index b8bf5bea2..8185ca404 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -335,12 +335,7 @@ class KeyTyRules: def local_sharded_result_handler(aval, sharding, indices): phys_aval, = KeyTyRules.physical_avals(aval) key_shape = aval.dtype.impl.key_shape - - # TODO(yashkatariya,frostig): remove this conditional and inline it when - # the transient config ever settles - output_type = pxla.OutputType.Array - phys_handler_maker = pxla.local_result_handlers[ - (core.ShapedArray, output_type)] + phys_handler_maker = pxla.local_result_handlers[core.ShapedArray] # set up a grounded sharding (with a grounded sharding spec) if isinstance(sharding, (PmapSharding, NamedSharding)): @@ -366,13 +361,7 @@ class KeyTyRules: def global_sharded_result_handler(aval, out_sharding, committed, is_out_sharding_from_xla): phys_aval, = KeyTyRules.physical_avals(aval) - - # TODO(yashkatariya,frostig): remove this conditional and inline it when - # the transient config ever settles - output_type = pxla.OutputType.Array - - phys_handler_maker = pxla.global_result_handlers[ - (core.ShapedArray, output_type)] + phys_handler_maker = pxla.global_result_handlers[core.ShapedArray] phys_sharding = make_key_array_phys_sharding( aval, out_sharding, is_out_sharding_from_xla) diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index b9ef7957c..998d81e53 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -32,7 +32,6 @@ from jax._src.interpreters.pxla import ( NoSharding as NoSharding, OpShardingType as OpShardingType, OrderedDictType as OrderedDictType, - OutputType as OutputType, ParallelCallableInfo as ParallelCallableInfo, PartitionInfo as PartitionInfo, PartitionsOrReplicated as PartitionsOrReplicated, @@ -94,7 +93,6 @@ from jax._src.interpreters.pxla import ( reconcile_num_partitions as reconcile_num_partitions, replicate as replicate, resource_typecheck as resource_typecheck, - sda_array_result_handler as sda_array_result_handler, shard_arg_handlers as shard_arg_handlers, shard_args as shard_args, shard_arg as shard_arg,