mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
MAINT Drop underscore from the name of externally-referenced state objects
This commit is contained in:
parent
16061e6302
commit
f9087ab0c6
@ -670,7 +670,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
name=name),
|
||||
source_info_util.new_source_info(), resource_env, {})
|
||||
jaxpr = plan.subst_axes_with_resources(jaxpr)
|
||||
use_spmd_lowering = _SPMD_LOWERING.value
|
||||
use_spmd_lowering = SPMD_LOWERING.value
|
||||
ensure_fixed_sharding = _ENSURE_FIXED_SHARDING.value
|
||||
if use_spmd_lowering and ensure_fixed_sharding:
|
||||
jaxpr = _fix_inferred_spmd_sharding(jaxpr, resource_env)
|
||||
@ -686,7 +686,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
|
||||
mesh = resource_env.physical_mesh
|
||||
tiling_method: pxla.TilingMethod
|
||||
if _SPMD_LOWERING_MANUAL.value:
|
||||
if SPMD_LOWERING_MANUAL.value:
|
||||
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
|
||||
tiling_method = pxla.TileManual(manual_mesh_axes)
|
||||
else:
|
||||
@ -1284,7 +1284,7 @@ batching.BatchTrace.post_process_xmap = _batch_trace_post_process_xmap
|
||||
|
||||
def _xmap_lowering_rule(ctx, *args, **kwargs):
|
||||
if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext):
|
||||
if _SPMD_LOWERING_MANUAL.value:
|
||||
if SPMD_LOWERING_MANUAL.value:
|
||||
return _xmap_lowering_rule_spmd_manual(ctx, *args, **kwargs)
|
||||
else:
|
||||
return _xmap_lowering_rule_spmd(ctx, *args, **kwargs)
|
||||
@ -1839,13 +1839,13 @@ def _clear_compilation_cache(_):
|
||||
|
||||
def _ensure_spmd_and(f):
|
||||
def update(v):
|
||||
if v and not _SPMD_LOWERING.value:
|
||||
if v and not SPMD_LOWERING.value:
|
||||
raise RuntimeError("This flag requires enabling the experimental_xmap_spmd_lowering flag")
|
||||
return f(v)
|
||||
return update
|
||||
|
||||
|
||||
_SPMD_LOWERING = config.define_bool_state(
|
||||
SPMD_LOWERING = config.define_bool_state(
|
||||
name="experimental_xmap_spmd_lowering",
|
||||
default=False,
|
||||
help=("When set, multi-device xmap computations will be compiled through "
|
||||
@ -1853,7 +1853,7 @@ _SPMD_LOWERING = config.define_bool_state(
|
||||
"Not supported on CPU!"),
|
||||
update_global_hook=_clear_compilation_cache,
|
||||
update_thread_local_hook=_thread_local_flag_unsupported)
|
||||
_SPMD_LOWERING_MANUAL = config.define_bool_state(
|
||||
SPMD_LOWERING_MANUAL = config.define_bool_state(
|
||||
name="experimental_xmap_spmd_lowering_manual",
|
||||
default=False,
|
||||
help=("When set, multi-device xmap computations will be compiled using "
|
||||
|
@ -67,7 +67,7 @@ _TEST_DUT = config.DEFINE_string(
|
||||
'Describes the device under test in case special consideration is required.'
|
||||
)
|
||||
|
||||
_NUM_GENERATED_CASES = config.DEFINE_integer(
|
||||
NUM_GENERATED_CASES = config.DEFINE_integer(
|
||||
'jax_num_generated_cases',
|
||||
int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
|
||||
help='Number of generated cases to test')
|
||||
@ -762,7 +762,7 @@ def assert_dot_preferred_element_type(expected, fun, *args, **kwargs):
|
||||
|
||||
def cases_from_gens(*gens):
|
||||
sizes = [1, 3, 10]
|
||||
cases_per_size = int(_NUM_GENERATED_CASES.value / len(sizes)) + 1
|
||||
cases_per_size = int(NUM_GENERATED_CASES.value / len(sizes)) + 1
|
||||
for size in sizes:
|
||||
for i in range(cases_per_size):
|
||||
yield (f'_{size}_{i}',) + tuple(gen(size) for gen in gens)
|
||||
@ -775,7 +775,7 @@ def named_cases_from_sampler(gen):
|
||||
if not isinstance(x, (list, tuple)):
|
||||
x = list(x)
|
||||
return [x[rng.randint(len(x))]]
|
||||
while (len(seen) < _NUM_GENERATED_CASES.value and
|
||||
while (len(seen) < NUM_GENERATED_CASES.value and
|
||||
retries < _MAX_CASES_SAMPLING_RETRIES.value):
|
||||
retries += 1
|
||||
cases = list(gen(choose_one))
|
||||
@ -804,7 +804,7 @@ def sample_product_testcases(*args, **kw):
|
||||
kw = [(k, list(v)) for k, v in kw.items()]
|
||||
n = math.prod(len(a) for a in args) * math.prod(len(v) for _, v in kw)
|
||||
testcases = []
|
||||
for i in _choice(n, min(n, _NUM_GENERATED_CASES.value)):
|
||||
for i in _choice(n, min(n, NUM_GENERATED_CASES.value)):
|
||||
testcase = {}
|
||||
for a in args:
|
||||
testcase.update(a[i % len(a)])
|
||||
|
@ -77,7 +77,7 @@ def setUpModule():
|
||||
# Clear any cached backends so new CPU backend will pick up the env var.
|
||||
xla_bridge.get_backend.cache_clear()
|
||||
global prev_spmd_lowering_flag
|
||||
prev_spmd_lowering_flag = maps._SPMD_LOWERING.value
|
||||
prev_spmd_lowering_flag = maps.SPMD_LOWERING.value
|
||||
config.update('experimental_xmap_spmd_lowering', True)
|
||||
|
||||
|
||||
|
@ -259,7 +259,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.xmap_spmd_lowering_enabled = maps._SPMD_LOWERING.value
|
||||
self.xmap_spmd_lowering_enabled = maps.SPMD_LOWERING.value
|
||||
jax.config.update("experimental_xmap_spmd_lowering", True)
|
||||
|
||||
def tearDown(self):
|
||||
|
@ -78,7 +78,7 @@ def setUpModule():
|
||||
# Clear any cached backends so new CPU backend will pick up the env var.
|
||||
xla_bridge.get_backend.cache_clear()
|
||||
global prev_spmd_lowering_flag
|
||||
prev_spmd_lowering_flag = maps._SPMD_LOWERING.value
|
||||
prev_spmd_lowering_flag = maps.SPMD_LOWERING.value
|
||||
config.update('experimental_xmap_spmd_lowering', True)
|
||||
|
||||
def tearDownModule():
|
||||
|
@ -632,8 +632,8 @@ class PureCallbackTest(jtu.JaxTestCase):
|
||||
if not hasattr(xla_client.OpSharding.Type, 'MANUAL'):
|
||||
raise unittest.SkipTest('Manual partitioning needed for pure_callback')
|
||||
|
||||
spmd_lowering = maps._SPMD_LOWERING.value
|
||||
spmd_manual_lowering = maps._SPMD_LOWERING_MANUAL.value
|
||||
spmd_lowering = maps.SPMD_LOWERING.value
|
||||
spmd_manual_lowering = maps.SPMD_LOWERING_MANUAL.value
|
||||
config.update('experimental_xmap_spmd_lowering', True)
|
||||
config.update('experimental_xmap_spmd_lowering_manual', True)
|
||||
try:
|
||||
|
@ -1382,7 +1382,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
|
||||
return jtu.create_global_mesh(tuple(mesh_shape.values()), tuple(mesh_shape))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
sample(jtu._NUM_GENERATED_CASES.value, sample_shmap))
|
||||
sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
|
||||
def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
|
||||
mesh = self.make_mesh(mesh)
|
||||
args = map(jnp.array, args)
|
||||
@ -1391,7 +1391,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(expected, out, check_dtypes=False)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
sample(jtu._NUM_GENERATED_CASES.value, sample_shmap))
|
||||
sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
|
||||
def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
|
||||
mesh = self.make_mesh(mesh)
|
||||
args = map(jnp.array, args)
|
||||
@ -1401,7 +1401,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
(name + f'_check_rep={check_rep}', *params, check_rep)
|
||||
for (name, *params) in sample(jtu._NUM_GENERATED_CASES.value, sample_shmap)
|
||||
for (name, *params) in sample(jtu.NUM_GENERATED_CASES.value, sample_shmap)
|
||||
for check_rep in [True, False]
|
||||
)
|
||||
@jax.default_matmul_precision("float32")
|
||||
@ -1414,7 +1414,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
sample(jtu._NUM_GENERATED_CASES.value, sample_shmap))
|
||||
sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_grads_closure(self, fun, mesh, jit, in_specs, out_specs, args, _):
|
||||
mesh = self.make_mesh(mesh)
|
||||
@ -1433,7 +1433,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(f, (0.2, *closed_over_args), order=2, atol=1e-2, rtol=1e-2)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
sample(jtu._NUM_GENERATED_CASES.value,
|
||||
sample(jtu.NUM_GENERATED_CASES.value,
|
||||
partial(sample_shmap_batched, 5)))
|
||||
def test_vmap(self, bdims, fun, mesh, jit, in_specs, out_specs, args, ref):
|
||||
mesh = self.make_mesh(mesh)
|
||||
@ -1456,7 +1456,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
sample(jtu._NUM_GENERATED_CASES.value,
|
||||
sample(jtu.NUM_GENERATED_CASES.value,
|
||||
partial(sample_shmap_batched, 5)))
|
||||
def test_vmap_closure(self, bdims, fun, mesh, jit, in_specs, out_specs, args, _):
|
||||
mesh = self.make_mesh(mesh)
|
||||
|
@ -831,7 +831,7 @@ if CAN_USE_HYPOTHESIS:
|
||||
|
||||
@hp.given(get_vmap_params())
|
||||
@hp.settings(deadline=None, print_blob=True,
|
||||
max_examples=jtu._NUM_GENERATED_CASES.value)
|
||||
max_examples=jtu.NUM_GENERATED_CASES.value)
|
||||
def test_get_vmap(self, get_vmap_param: GetVmapParams):
|
||||
|
||||
indexed_dims = get_vmap_param.vmap_index_param.index_param.indexed_dims
|
||||
@ -870,7 +870,7 @@ if CAN_USE_HYPOTHESIS:
|
||||
|
||||
@hp.given(set_vmap_params())
|
||||
@hp.settings(deadline=None, print_blob=True,
|
||||
max_examples=jtu._NUM_GENERATED_CASES.value)
|
||||
max_examples=jtu.NUM_GENERATED_CASES.value)
|
||||
def test_set_vmap(self, set_vmap_param: SetVmapParams):
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("Scatter is nondeterministic on GPU")
|
||||
@ -915,7 +915,7 @@ if CAN_USE_HYPOTHESIS:
|
||||
|
||||
@hp.given(set_vmap_params())
|
||||
@hp.settings(deadline=None, print_blob=True,
|
||||
max_examples=jtu._NUM_GENERATED_CASES.value)
|
||||
max_examples=jtu.NUM_GENERATED_CASES.value)
|
||||
def test_addupdate_vmap(self, set_vmap_param: SetVmapParams):
|
||||
|
||||
indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims
|
||||
@ -1538,7 +1538,7 @@ if CAN_USE_HYPOTHESIS:
|
||||
@jax.legacy_prng_key('allow')
|
||||
@hp.given(hps.data())
|
||||
@hp.settings(deadline=None, print_blob=True,
|
||||
max_examples=jtu._NUM_GENERATED_CASES.value)
|
||||
max_examples=jtu.NUM_GENERATED_CASES.value)
|
||||
def test_jvp(self, data):
|
||||
|
||||
spec = data.draw(func_spec())
|
||||
@ -1563,7 +1563,7 @@ if CAN_USE_HYPOTHESIS:
|
||||
@jax.legacy_prng_key('allow')
|
||||
@hp.given(hps.data())
|
||||
@hp.settings(deadline=None, print_blob=True,
|
||||
max_examples=jtu._NUM_GENERATED_CASES.value)
|
||||
max_examples=jtu.NUM_GENERATED_CASES.value)
|
||||
def test_linearize(self, data):
|
||||
|
||||
spec = data.draw(func_spec())
|
||||
@ -1589,7 +1589,7 @@ if CAN_USE_HYPOTHESIS:
|
||||
@jax.legacy_prng_key('allow')
|
||||
@hp.given(hps.data())
|
||||
@hp.settings(deadline=None, print_blob=True,
|
||||
max_examples=jtu._NUM_GENERATED_CASES.value)
|
||||
max_examples=jtu.NUM_GENERATED_CASES.value)
|
||||
def test_vjp(self, data):
|
||||
|
||||
spec = data.draw(func_spec())
|
||||
|
@ -246,7 +246,7 @@ class XMapTestCase(jtu.BufferDonationTestCase):
|
||||
class SPMDTestMixin:
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.spmd_lowering = maps._SPMD_LOWERING.value
|
||||
self.spmd_lowering = maps.SPMD_LOWERING.value
|
||||
config.update('experimental_xmap_spmd_lowering', True)
|
||||
|
||||
def tearDown(self):
|
||||
@ -258,8 +258,8 @@ class ManualSPMDTestMixin:
|
||||
if not hasattr(xla_client.OpSharding.Type, "MANUAL"):
|
||||
raise SkipTest
|
||||
super().setUp()
|
||||
self.spmd_lowering = maps._SPMD_LOWERING.value
|
||||
self.spmd_manual_lowering = maps._SPMD_LOWERING_MANUAL.value
|
||||
self.spmd_lowering = maps.SPMD_LOWERING.value
|
||||
self.spmd_manual_lowering = maps.SPMD_LOWERING_MANUAL.value
|
||||
config.update('experimental_xmap_spmd_lowering', True)
|
||||
config.update('experimental_xmap_spmd_lowering_manual', True)
|
||||
|
||||
@ -436,7 +436,7 @@ class XMapTest(XMapTestCase):
|
||||
m_size = math.prod([2] + [2] * (len(mesh) - 2))
|
||||
self.assertListEqual(y_op_sharding.tile_assignment_dimensions(),
|
||||
[2, 1, 1, m_size])
|
||||
if maps._SPMD_LOWERING.value:
|
||||
if maps.SPMD_LOWERING.value:
|
||||
hlo = f.lower(x).compiler_ir(dialect="hlo").as_hlo_text()
|
||||
# Make sure that there are non-partial sharding specs in the HLO
|
||||
if xla_extension_version >= 180:
|
||||
@ -749,7 +749,7 @@ class XMapTest(XMapTestCase):
|
||||
axis_resources={'i': 'x'})
|
||||
x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
|
||||
hlo = f.lower(x).as_text(dialect='stablehlo')
|
||||
if maps._SPMD_LOWERING.value:
|
||||
if maps.SPMD_LOWERING.value:
|
||||
self.assertIn("mhlo.num_partitions = 2", hlo)
|
||||
self.assertIn("mhlo.num_replicas = 1", hlo)
|
||||
else:
|
||||
@ -1204,7 +1204,7 @@ class NewPrimitiveTest(XMapTestCase):
|
||||
|
||||
@jtu.with_and_without_mesh
|
||||
def testGather(self, mesh, axis_resources):
|
||||
if axis_resources and not maps._SPMD_LOWERING.value:
|
||||
if axis_resources and not maps.SPMD_LOWERING.value:
|
||||
raise SkipTest("pgather over mesh axes without SPMD lowering not implemented")
|
||||
x = jnp.arange(12, dtype=np.float32).reshape((4, 3))
|
||||
y = jnp.arange(35).reshape((5, 7)) % 3
|
||||
|
Loading…
x
Reference in New Issue
Block a user