Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)

Commit 6c8e56f43f (parent: 644f881a51)

Finish 0.4.35 release by removing dead code

PiperOrigin-RevId: 689396609
@@ -1719,51 +1719,43 @@ def transfer_guard(new_val: str) -> Iterator[None]:
     yield


-if lib.xla_extension_version < 293:
-  def array_garbage_collection_guard(_val):
-    raise NotImplementedError(
-        'jaxlib version is too low for garbage collection guard'
-    )
-
-else:
-  def _update_garbage_collection_guard(state, key, val):
-    """Applies the transfer guard level within guard_lib."""
-    if val is None:
-      setattr(state, key, None)
-    elif val == 'allow':
-      setattr(state, key, guard_lib.GarbageCollectionGuardLevel.ALLOW)
-    elif val == 'log':
-      setattr(state, key, guard_lib.GarbageCollectionGuardLevel.LOG)
-    elif val == 'fatal':
-      setattr(state, key, guard_lib.GarbageCollectionGuardLevel.FATAL)
-    else:
-      assert False, f'Invalid garbage collection guard level {val}'
-
-  array_garbage_collection_guard = optional_enum_state(
-      name='jax_array_garbage_collection_guard',
-      enum_values=['allow', 'log', 'fatal'],
-      # The default is applied by guard_lib.
-      default=None,
-      help=(
-          'Select garbage collection guard level for "jax.Array" objects.\nThis'
-          ' option can be used to control what happens when a "jax.Array"'
-          ' object is garbage collected. It is desirable for "jax.Array"'
-          ' objects to be freed by Python reference couting rather than garbage'
-          ' collection in order to avoid device memory being held by the arrays'
-          ' until garbage collection occurs.\n\nValid values are:\n * "allow":'
-          ' do not log garbage collection of "jax.Array" objects.\n * "log":'
-          ' log an error when a "jax.Array" is garbage collected.\n * "fatal":'
-          ' fatal error if a "jax.Array" is garbage collected.\nDefault is'
-          ' "allow".'
-      ),
-      update_global_hook=lambda val: _update_garbage_collection_guard(
-          guard_lib.global_state(), 'garbage_collect_array', val
-      ),
-      update_thread_local_hook=lambda val: _update_garbage_collection_guard(
-          guard_lib.thread_local_state(), 'garbage_collect_array', val
-      ),
-  )
+def _update_garbage_collection_guard(state, key, val):
+  """Applies the transfer guard level within guard_lib."""
+  if val is None:
+    setattr(state, key, None)
+  elif val == 'allow':
+    setattr(state, key, guard_lib.GarbageCollectionGuardLevel.ALLOW)
+  elif val == 'log':
+    setattr(state, key, guard_lib.GarbageCollectionGuardLevel.LOG)
+  elif val == 'fatal':
+    setattr(state, key, guard_lib.GarbageCollectionGuardLevel.FATAL)
+  else:
+    assert False, f'Invalid garbage collection guard level {val}'
+
+array_garbage_collection_guard = optional_enum_state(
+    name='jax_array_garbage_collection_guard',
+    enum_values=['allow', 'log', 'fatal'],
+    # The default is applied by guard_lib.
+    default=None,
+    help=(
+        'Select garbage collection guard level for "jax.Array" objects.\nThis'
+        ' option can be used to control what happens when a "jax.Array"'
+        ' object is garbage collected. It is desirable for "jax.Array"'
+        ' objects to be freed by Python reference couting rather than garbage'
+        ' collection in order to avoid device memory being held by the arrays'
+        ' until garbage collection occurs.\n\nValid values are:\n * "allow":'
+        ' do not log garbage collection of "jax.Array" objects.\n * "log":'
+        ' log an error when a "jax.Array" is garbage collected.\n * "fatal":'
+        ' fatal error if a "jax.Array" is garbage collected.\nDefault is'
+        ' "allow".'
+    ),
+    update_global_hook=lambda val: _update_garbage_collection_guard(
+        guard_lib.global_state(), 'garbage_collect_array', val
+    ),
+    update_thread_local_hook=lambda val: _update_garbage_collection_guard(
+        guard_lib.thread_local_state(), 'garbage_collect_array', val
+    ),
+)


 def _update_debug_log_modules(module_names_str: str | None):
   logging_config.disable_all_debug_logging()
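For context, the retained option goes through the standard JAX config mechanism. A minimal usage sketch (assuming a jaxlib new enough to expose the garbage collection guard, i.e. the versions this release now requires):

```python
import jax

# Log an error whenever a jax.Array is reclaimed by the garbage collector
# instead of by reference counting. "allow" (the default) and "fatal" are
# the other accepted values, per the help text above.
jax.config.update("jax_array_garbage_collection_guard", "log")
```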
@@ -48,13 +48,12 @@ from jax._src.interpreters import pxla
 from jax._src import lib
 from jax._src.mesh import AbstractMesh, Mesh
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.monitoring import record_event_duration_secs
 from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
-    SingleDeviceSharding, NamedSharding,
-    GSPMDSharding, TransferToMemoryKind, is_single_device_sharding)
+    SingleDeviceSharding, NamedSharding, TransferToMemoryKind,
+    is_single_device_sharding)
 from jax._src.layout import Layout, DeviceLocalLayout
@@ -361,50 +360,21 @@ def _different_device_order_reshard(x, target_sharding, copy: CopySemantics):
         f"platform {inp_plat} and target sharding's device set "
         f"ids: {target_ids} on platform {target_plat}")

-  if xla_extension_version >= 292:
-    if inp_sharding.is_fully_replicated:
-      permute_order = None
-    else:
-      permute_order = np.vectorize(target_sharding._device_assignment.index,
-                                   otypes=[int])(inp_sharding._device_assignment)
-    new_mesh = Mesh(
-        target_sharding.mesh.devices.reshape(inp_sharding.mesh.axis_sizes),
-        inp_sharding.mesh.axis_names)
-    new_s = NamedSharding(
-        new_mesh, inp_sharding.spec, memory_kind=target_sharding.memory_kind,
-        _logical_device_ids=(None if permute_order is None else
-                             tuple(permute_order.tolist())))
-    new_x = array.make_array_from_single_device_arrays(x.shape, new_s, x._arrays)
-    return api.jit(_identity_fn, out_shardings=target_sharding,
-                   donate_argnums=donate_argnums)(new_x)
-
-  old_hlo_sharding = inp_sharding._to_xla_hlo_sharding(x.ndim)
-  if old_hlo_sharding.is_replicated():
-    new_hlo_sharding = old_hlo_sharding
-  else:
-    permute_order = np.vectorize(target_sharding._device_assignment.index,
-                                 otypes=[int])(inp_sharding._device_assignment)
-    # Unfortunately need to fallback to V1 sharding here.
-    new_op_sharding = old_hlo_sharding.to_proto()
-    new_op_sharding.iota_reshape_dims = []
-    new_op_sharding.iota_transpose_perm = []
-    new_op_sharding.tile_assignment_devices = np.take(
-        permute_order, old_hlo_sharding.tile_assignment_devices()
-    )
-    new_hlo_sharding = xc.HloSharding.from_proto(new_op_sharding)
-    assert (list(np.take(inp_sharding._device_assignment,
-                         old_hlo_sharding.tile_assignment_devices()))
-            == list(np.take(target_sharding._device_assignment,
-                            new_op_sharding.tile_assignment_devices)))
-
-  new_x = array.make_array_from_single_device_arrays(
-      x.shape,
-      GSPMDSharding(target_sharding._device_assignment, new_hlo_sharding,
-                    memory_kind=target_sharding.memory_kind),
-      x._arrays,
-  )
-  return api.jit(_identity_fn, out_shardings=target_sharding,
-                 donate_argnums=donate_argnums)(new_x)
+  if inp_sharding.is_fully_replicated:
+    permute_order = None
+  else:
+    permute_order = np.vectorize(target_sharding._device_assignment.index,
+                                 otypes=[int])(inp_sharding._device_assignment)
+  new_mesh = Mesh(
+      target_sharding.mesh.devices.reshape(inp_sharding.mesh.axis_sizes),
+      inp_sharding.mesh.axis_names)
+  new_s = NamedSharding(
+      new_mesh, inp_sharding.spec, memory_kind=target_sharding.memory_kind,
+      _logical_device_ids=(None if permute_order is None else
+                           tuple(permute_order.tolist())))
+  new_x = array.make_array_from_single_device_arrays(x.shape, new_s, x._arrays)
+  return api.jit(_identity_fn, out_shardings=target_sharding,
+                 donate_argnums=donate_argnums)(new_x)


 @dataclasses.dataclass(frozen=True)
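The retained path computes, for every device in the input sharding, its position in the target sharding's device assignment. A toy sketch of that `np.vectorize` idiom, with plain integers standing in for `Device` objects (the values are hypothetical, not taken from the diff):

```python
import numpy as np

# Stand-ins for two shardings' _device_assignment tuples: the same four
# devices, enumerated in a different order.
inp_assignment = (0, 1, 2, 3)
target_assignment = (2, 3, 0, 1)

# For each input device, look up its index in the target assignment,
# mirroring the np.vectorize(...index...) pattern in the code above.
permute_order = np.vectorize(target_assignment.index, otypes=[int])(inp_assignment)
print(permute_order.tolist())  # [2, 3, 0, 1]
```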
@@ -37,7 +37,6 @@ from jax._src.op_shardings import (
     are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated)
 from jax._src.partition_spec import PartitionSpec
 from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
-from jax._src.lib import xla_extension_version
 import numpy as np
@@ -242,8 +241,6 @@ class NamedSharding(sharding.Sharding):
   _parsed_pspec: ParsedPartitionSpec
   _manual_axes: frozenset[MeshAxisName]
   _logical_device_ids: tuple[int, ...] | None
-  if xla_extension_version < 292:
-    _logical_device_ids = None

   @use_cpp_method()
   def __init__(
@@ -308,15 +305,10 @@ class NamedSharding(sharding.Sharding):
       cls, mesh, parsed_pspec, *, memory_kind=None, _manual_axes=frozenset(),
       _logical_device_ids=None,
   ):
-    if xla_extension_version >= 292:
-      return cls(mesh, parsed_pspec.get_partition_spec(),
-                 memory_kind=memory_kind, _parsed_pspec=parsed_pspec,
-                 _manual_axes=_manual_axes,
-                 _logical_device_ids=_logical_device_ids)
-    else:
-      return cls(mesh, parsed_pspec.get_partition_spec(),
-                 memory_kind=memory_kind, _parsed_pspec=parsed_pspec,
-                 _manual_axes=_manual_axes)
+    return cls(mesh, parsed_pspec.get_partition_spec(),
+               memory_kind=memory_kind, _parsed_pspec=parsed_pspec,
+               _manual_axes=_manual_axes,
+               _logical_device_ids=_logical_device_ids)

   @property
   def num_devices(self) -> int:
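The constructor path touched here is internal; through the public API, a `NamedSharding` is still built from a `Mesh` and a `PartitionSpec`. A minimal sketch (the single-axis mesh over all local devices is an assumption for illustration, not part of the diff):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over all local devices and shard the leading axis over it.
mesh = Mesh(np.array(jax.devices()), ('x',))
sharding = NamedSharding(mesh, P('x'))

# Size the array so it divides evenly across however many devices exist.
x = jax.device_put(jnp.arange(2 * jax.device_count()), sharding)
print(x.sharding)
```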
@@ -20,7 +20,6 @@ from unittest import mock
 from absl.testing import absltest
 import jax
 from jax._src import config
-from jax._src.lib import xla_extension_version
 import jax._src.test_util as jtu
 import jax.numpy as jnp
@@ -46,9 +45,6 @@ def _create_array_cycle():
 class GarbageCollectionGuardTest(jtu.JaxTestCase):

   def test_gced_array_is_not_logged_by_default(self):
-    if xla_extension_version < 293:
-      self.skipTest("Requires xla_extension_version >= 293")
-
     # Create a reference cycle of two jax.Arrays.
     _create_array_cycle()
@@ -66,9 +62,6 @@ class GarbageCollectionGuardTest(jtu.JaxTestCase):
     )

   def test_gced_array_is_logged(self):
-    if xla_extension_version < 293:
-      self.skipTest("Requires xla_extension_version >= 293")
-
     # Use mock_stderr to be able to inspect stderr.
     mock_stderr = io.StringIO()
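These tests rely on a helper (`_create_array_cycle`, not shown in this hunk) that makes `jax.Array`s reachable only through a reference cycle, so they can only be freed by the cycle collector. A hedged sketch of the same idea, with a hypothetical helper name rather than the test's actual implementation:

```python
import gc
import jax
import jax.numpy as jnp

def make_array_cycle():
    # Two dicts that each hold a jax.Array and reference one another, so the
    # arrays survive until the garbage collector breaks the cycle.
    a = {"array": jnp.arange(4)}
    b = {"array": jnp.arange(4)}
    a["other"], b["other"] = b, a

jax.config.update("jax_array_garbage_collection_guard", "log")
make_array_cycle()
gc.collect()  # with the guard set to "log", collection of the arrays is reported
```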
@@ -26,7 +26,6 @@ from jax import lax
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
 from jax._src.layout import DeviceLocalLayout as DLL, Layout
-from jax._src.lib import xla_extension_version
 from jax._src import config
 from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
 import jax.numpy as jnp
@@ -655,8 +654,6 @@ class DevicePutTest(jtu.JaxTestCase):

   @jtu.run_on_devices('tpu')
   def test_ragged_copy_on_host(self):
-    if xla_extension_version < 290:
-      self.skipTest('Requires xla_extension_version >= 290')
     mesh = jtu.create_mesh((2,), ('x'))
     sharding = jax.sharding.NamedSharding(mesh, P(('x')))
     cpu_sharding = sharding.with_memory_kind('pinned_host')
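The surrounding test copies a sharded array into host memory by swapping the sharding's memory kind. A minimal sketch of the same pattern, assuming a backend that exposes a 'pinned_host' memory space (the test above runs on TPU only):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('x',))
device_sharding = NamedSharding(mesh, P('x'))
x = jax.device_put(jnp.arange(2 * jax.device_count()), device_sharding)

# Same idiom as the test: derive a host-memory sharding and copy the array over.
host_sharding = device_sharding.with_memory_kind('pinned_host')
x_host = jax.device_put(x, host_sharding)
```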
@@ -38,7 +38,6 @@ from jax import dtypes
 from jax import stages
 from jax import lax
 from jax._src.lax import lax as lax_internal
-from jax._src.lib import xla_extension_version
 from jax.lax import with_sharding_constraint
 from jax._src import prng
 from jax.sharding import PartitionSpec as P, Mesh
@@ -5604,9 +5603,6 @@ class UtilTest(jtu.JaxTestCase):
     self.assertTrue(hs4.is_tiled())

   def test_hlo_sharding_with_device_ordering(self):
-    if xla_extension_version < 291:
-      self.skipTest('Requires xla_extension_version >= 291')
-
     hs1 = xc.HloSharding.subgroup_with_device_ordering(
         np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype=np.int64),
         subgroup_types=[xc.OpSharding.Type.REPLICATED],
@@ -5718,7 +5714,6 @@ class ShardyTest(jtu.JaxTestCase):

     self.assertIn('sdy.sharding = #sdy.sharding', f.lower(arr).as_text())

-  @unittest.skipIf(xla_extension_version < 292, "Requires XLA version >=292")
   def test_lowering_with_sharding_constraint(self):
     mesh = jtu.create_mesh((2, 2), ('x', 'y'))
     arr = np.arange(16).reshape(4, 2, 2)
@@ -5744,7 +5739,6 @@
     self.assertIn('<@mesh, [{"x"}, {?}, {"y"}]>', lowered_str)

   # TODO(bartchr): run on CPU once Shardy is added to the XLA CPU pipeline.
-  @unittest.skipIf(xla_extension_version < 292, "Requires XLA version >=292")
   @jtu.skip_on_devices('cpu')
   def test_compile_with_inferred_out_sharding(self):
     mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@@ -37,7 +37,6 @@ from jax._src import config
 from jax._src import core
 from jax._src import prng
 from jax._src import test_util as jtu
-from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir.dialects import sdy
 from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
 from jax._src.ad_checkpoint import saved_residuals
@@ -2642,7 +2641,6 @@ class CustomPartitionerTest(jtu.JaxTestCase):
 @unittest.skipIf(sdy is None, "shardy is not enabled")
 class SdyIntegrationTest(jtu.JaxTestCase):

-  @unittest.skipIf(xla_extension_version < 292, "Requires XLA version >=292")
   # Verify we can lower to a `ManualComputationOp`.
   def test_shardy_collective_permute(self):
     mesh = jtu.create_mesh((2,), ('x',))