MAINT Drop underscore from the names of externally-referenced state objects

Sergei Lebedev 2023-10-13 21:27:14 +01:00
parent 16061e6302
commit f9087ab0c6
9 changed files with 33 additions and 33 deletions
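
The pattern throughout this commit is the same: state objects created through the config API that other modules read lose their leading underscore, since a leading underscore conventionally marks a Python name as module-private. A minimal before/after sketch, using only the names and the config call visible in the first file below:

```python
# Before: the underscore signals "private", yet other modules reach in
# anyway, e.g. maps._SPMD_LOWERING.value in the test files further down.
_SPMD_LOWERING = config.define_bool_state(
    name="experimental_xmap_spmd_lowering", default=False, help="...")

# After: externally-referenced state gets a public name; callers read
# maps.SPMD_LOWERING.value instead.
SPMD_LOWERING = config.define_bool_state(
    name="experimental_xmap_spmd_lowering", default=False, help="...")
```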

View File

@@ -670,7 +670,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
name=name),
source_info_util.new_source_info(), resource_env, {})
jaxpr = plan.subst_axes_with_resources(jaxpr)
-use_spmd_lowering = _SPMD_LOWERING.value
+use_spmd_lowering = SPMD_LOWERING.value
ensure_fixed_sharding = _ENSURE_FIXED_SHARDING.value
if use_spmd_lowering and ensure_fixed_sharding:
jaxpr = _fix_inferred_spmd_sharding(jaxpr, resource_env)
@@ -686,7 +686,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
mesh = resource_env.physical_mesh
tiling_method: pxla.TilingMethod
-if _SPMD_LOWERING_MANUAL.value:
+if SPMD_LOWERING_MANUAL.value:
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
tiling_method = pxla.TileManual(manual_mesh_axes)
else:
@@ -1284,7 +1284,7 @@ batching.BatchTrace.post_process_xmap = _batch_trace_post_process_xmap
def _xmap_lowering_rule(ctx, *args, **kwargs):
if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext):
-if _SPMD_LOWERING_MANUAL.value:
+if SPMD_LOWERING_MANUAL.value:
return _xmap_lowering_rule_spmd_manual(ctx, *args, **kwargs)
else:
return _xmap_lowering_rule_spmd(ctx, *args, **kwargs)
@@ -1839,13 +1839,13 @@ def _clear_compilation_cache(_):
def _ensure_spmd_and(f):
def update(v):
-if v and not _SPMD_LOWERING.value:
+if v and not SPMD_LOWERING.value:
raise RuntimeError("This flag requires enabling the experimental_xmap_spmd_lowering flag")
return f(v)
return update
-_SPMD_LOWERING = config.define_bool_state(
+SPMD_LOWERING = config.define_bool_state(
name="experimental_xmap_spmd_lowering",
default=False,
help=("When set, multi-device xmap computations will be compiled through "
@@ -1853,7 +1853,7 @@ _SPMD_LOWERING = config.define_bool_state(
"Not supported on CPU!"),
update_global_hook=_clear_compilation_cache,
update_thread_local_hook=_thread_local_flag_unsupported)
-_SPMD_LOWERING_MANUAL = config.define_bool_state(
+SPMD_LOWERING_MANUAL = config.define_bool_state(
name="experimental_xmap_spmd_lowering_manual",
default=False,
help=("When set, multi-device xmap computations will be compiled using "

View File

@@ -67,7 +67,7 @@ _TEST_DUT = config.DEFINE_string(
'Describes the device under test in case special consideration is required.'
)
-_NUM_GENERATED_CASES = config.DEFINE_integer(
+NUM_GENERATED_CASES = config.DEFINE_integer(
'jax_num_generated_cases',
int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
help='Number of generated cases to test')
@@ -762,7 +762,7 @@ def assert_dot_preferred_element_type(expected, fun, *args, **kwargs):
def cases_from_gens(*gens):
sizes = [1, 3, 10]
-cases_per_size = int(_NUM_GENERATED_CASES.value / len(sizes)) + 1
+cases_per_size = int(NUM_GENERATED_CASES.value / len(sizes)) + 1
for size in sizes:
for i in range(cases_per_size):
yield (f'_{size}_{i}',) + tuple(gen(size) for gen in gens)
@@ -775,7 +775,7 @@ def named_cases_from_sampler(gen):
if not isinstance(x, (list, tuple)):
x = list(x)
return [x[rng.randint(len(x))]]
-while (len(seen) < _NUM_GENERATED_CASES.value and
+while (len(seen) < NUM_GENERATED_CASES.value and
retries < _MAX_CASES_SAMPLING_RETRIES.value):
retries += 1
cases = list(gen(choose_one))
@@ -804,7 +804,7 @@ def sample_product_testcases(*args, **kw):
kw = [(k, list(v)) for k, v in kw.items()]
n = math.prod(len(a) for a in args) * math.prod(len(v) for _, v in kw)
testcases = []
-for i in _choice(n, min(n, _NUM_GENERATED_CASES.value)):
+for i in _choice(n, min(n, NUM_GENERATED_CASES.value)):
testcase = {}
for a in args:
testcase.update(a[i % len(a)])
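
A short illustration of the arithmetic in cases_from_gens above: the global case budget is split evenly across the three sizes, rounding up so each size gets at least one case. The helper name below is hypothetical; only the formula is taken from the diff:

```python
def cases_per_size(num_generated_cases: int, num_sizes: int = 3) -> int:
    # Same expression as cases_from_gens: integer-divide the budget
    # across the sizes, then add one so no size gets zero cases.
    return int(num_generated_cases / num_sizes) + 1

# With the default budget of 10 (JAX_NUM_GENERATED_CASES), each of the
# sizes 1, 3, and 10 yields 4 cases, i.e. 12 generated cases in total.
assert cases_per_size(10) == 4
```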

View File

@@ -77,7 +77,7 @@ def setUpModule():
# Clear any cached backends so new CPU backend will pick up the env var.
xla_bridge.get_backend.cache_clear()
global prev_spmd_lowering_flag
-prev_spmd_lowering_flag = maps._SPMD_LOWERING.value
+prev_spmd_lowering_flag = maps.SPMD_LOWERING.value
config.update('experimental_xmap_spmd_lowering', True)

View File

@@ -259,7 +259,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
-self.xmap_spmd_lowering_enabled = maps._SPMD_LOWERING.value
+self.xmap_spmd_lowering_enabled = maps.SPMD_LOWERING.value
jax.config.update("experimental_xmap_spmd_lowering", True)
def tearDown(self):

View File

@@ -78,7 +78,7 @@ def setUpModule():
# Clear any cached backends so new CPU backend will pick up the env var.
xla_bridge.get_backend.cache_clear()
global prev_spmd_lowering_flag
-prev_spmd_lowering_flag = maps._SPMD_LOWERING.value
+prev_spmd_lowering_flag = maps.SPMD_LOWERING.value
config.update('experimental_xmap_spmd_lowering', True)
def tearDownModule():

View File

@@ -632,8 +632,8 @@ class PureCallbackTest(jtu.JaxTestCase):
if not hasattr(xla_client.OpSharding.Type, 'MANUAL'):
raise unittest.SkipTest('Manual partitioning needed for pure_callback')
-spmd_lowering = maps._SPMD_LOWERING.value
-spmd_manual_lowering = maps._SPMD_LOWERING_MANUAL.value
+spmd_lowering = maps.SPMD_LOWERING.value
+spmd_manual_lowering = maps.SPMD_LOWERING_MANUAL.value
config.update('experimental_xmap_spmd_lowering', True)
config.update('experimental_xmap_spmd_lowering_manual', True)
try:
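
The hunk above saves the current flag values before enabling both lowerings; presumably the body that follows restores them in a finally block. A hedged sketch of that save/flip/restore idiom, built only from calls visible in this diff (run_test is a hypothetical stand-in for the test body):

```python
spmd_lowering = maps.SPMD_LOWERING.value
spmd_manual_lowering = maps.SPMD_LOWERING_MANUAL.value
config.update('experimental_xmap_spmd_lowering', True)
config.update('experimental_xmap_spmd_lowering_manual', True)
try:
    run_test()  # hypothetical test body
finally:
    # Restore the process-wide flags for whatever runs next.
    config.update('experimental_xmap_spmd_lowering', spmd_lowering)
    config.update('experimental_xmap_spmd_lowering_manual', spmd_manual_lowering)
```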

View File

@@ -1382,7 +1382,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
return jtu.create_global_mesh(tuple(mesh_shape.values()), tuple(mesh_shape))
@parameterized.named_parameters(
-sample(jtu._NUM_GENERATED_CASES.value, sample_shmap))
+sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
mesh = self.make_mesh(mesh)
args = map(jnp.array, args)
@@ -1391,7 +1391,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
self.assertAllClose(expected, out, check_dtypes=False)
@parameterized.named_parameters(
-sample(jtu._NUM_GENERATED_CASES.value, sample_shmap))
+sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
mesh = self.make_mesh(mesh)
args = map(jnp.array, args)
@@ -1401,7 +1401,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
@parameterized.named_parameters(
(name + f'_check_rep={check_rep}', *params, check_rep)
-for (name, *params) in sample(jtu._NUM_GENERATED_CASES.value, sample_shmap)
+for (name, *params) in sample(jtu.NUM_GENERATED_CASES.value, sample_shmap)
for check_rep in [True, False]
)
@jax.default_matmul_precision("float32")
@@ -1414,7 +1414,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2)
@parameterized.named_parameters(
-sample(jtu._NUM_GENERATED_CASES.value, sample_shmap))
+sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
@jax.default_matmul_precision("float32")
def test_grads_closure(self, fun, mesh, jit, in_specs, out_specs, args, _):
mesh = self.make_mesh(mesh)
@@ -1433,7 +1433,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
jtu.check_grads(f, (0.2, *closed_over_args), order=2, atol=1e-2, rtol=1e-2)
@parameterized.named_parameters(
-sample(jtu._NUM_GENERATED_CASES.value,
+sample(jtu.NUM_GENERATED_CASES.value,
partial(sample_shmap_batched, 5)))
def test_vmap(self, bdims, fun, mesh, jit, in_specs, out_specs, args, ref):
mesh = self.make_mesh(mesh)
@@ -1456,7 +1456,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol)
@parameterized.named_parameters(
-sample(jtu._NUM_GENERATED_CASES.value,
+sample(jtu.NUM_GENERATED_CASES.value,
partial(sample_shmap_batched, 5)))
def test_vmap_closure(self, bdims, fun, mesh, jit, in_specs, out_specs, args, _):
mesh = self.make_mesh(mesh)
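
For context on the repeated sample(jtu.NUM_GENERATED_CASES.value, ...) calls: parameterized.named_parameters consumes (name, *args) tuples, with the first element naming the generated case. A minimal self-contained sketch of that pattern; sample_shmap itself is not shown in this diff, so a toy generator stands in for it:

```python
from absl.testing import parameterized

class ExampleTest(parameterized.TestCase):
  @parameterized.named_parameters(
      # Each tuple: a case name first, then the test-method arguments.
      (f'_case{i}', i, 2 * i) for i in range(3))
  def test_doubling(self, x, doubled):
    self.assertEqual(doubled, 2 * x)
```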

View File

@@ -831,7 +831,7 @@ if CAN_USE_HYPOTHESIS:
@hp.given(get_vmap_params())
@hp.settings(deadline=None, print_blob=True,
-max_examples=jtu._NUM_GENERATED_CASES.value)
+max_examples=jtu.NUM_GENERATED_CASES.value)
def test_get_vmap(self, get_vmap_param: GetVmapParams):
indexed_dims = get_vmap_param.vmap_index_param.index_param.indexed_dims
@@ -870,7 +870,7 @@ if CAN_USE_HYPOTHESIS:
@hp.given(set_vmap_params())
@hp.settings(deadline=None, print_blob=True,
-max_examples=jtu._NUM_GENERATED_CASES.value)
+max_examples=jtu.NUM_GENERATED_CASES.value)
def test_set_vmap(self, set_vmap_param: SetVmapParams):
if jtu.test_device_matches(["gpu"]):
self.skipTest("Scatter is nondeterministic on GPU")
@@ -915,7 +915,7 @@ if CAN_USE_HYPOTHESIS:
@hp.given(set_vmap_params())
@hp.settings(deadline=None, print_blob=True,
-max_examples=jtu._NUM_GENERATED_CASES.value)
+max_examples=jtu.NUM_GENERATED_CASES.value)
def test_addupdate_vmap(self, set_vmap_param: SetVmapParams):
indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims
@@ -1538,7 +1538,7 @@ if CAN_USE_HYPOTHESIS:
@jax.legacy_prng_key('allow')
@hp.given(hps.data())
@hp.settings(deadline=None, print_blob=True,
-max_examples=jtu._NUM_GENERATED_CASES.value)
+max_examples=jtu.NUM_GENERATED_CASES.value)
def test_jvp(self, data):
spec = data.draw(func_spec())
@@ -1563,7 +1563,7 @@ if CAN_USE_HYPOTHESIS:
@jax.legacy_prng_key('allow')
@hp.given(hps.data())
@hp.settings(deadline=None, print_blob=True,
-max_examples=jtu._NUM_GENERATED_CASES.value)
+max_examples=jtu.NUM_GENERATED_CASES.value)
def test_linearize(self, data):
spec = data.draw(func_spec())
@@ -1589,7 +1589,7 @@ if CAN_USE_HYPOTHESIS:
@jax.legacy_prng_key('allow')
@hp.given(hps.data())
@hp.settings(deadline=None, print_blob=True,
-max_examples=jtu._NUM_GENERATED_CASES.value)
+max_examples=jtu.NUM_GENERATED_CASES.value)
def test_vjp(self, data):
spec = data.draw(func_spec())
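
The hypothesis-based tests above all bound their search with the same budget. A minimal sketch, assuming jtu is JAX's test_util module as in the hunks above and that hp/hps are the hypothesis aliases this file already uses:

```python
import hypothesis as hp
import hypothesis.strategies as hps

@hp.given(hps.integers())
@hp.settings(deadline=None, print_blob=True,
             max_examples=jtu.NUM_GENERATED_CASES.value)
def test_int_roundtrip(x):
  # Each drawn example counts against the shared generated-case budget.
  assert int(str(x)) == x
```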

View File

@@ -246,7 +246,7 @@ class XMapTestCase(jtu.BufferDonationTestCase):
class SPMDTestMixin:
def setUp(self):
super().setUp()
-self.spmd_lowering = maps._SPMD_LOWERING.value
+self.spmd_lowering = maps.SPMD_LOWERING.value
config.update('experimental_xmap_spmd_lowering', True)
def tearDown(self):
@@ -258,8 +258,8 @@ class ManualSPMDTestMixin:
if not hasattr(xla_client.OpSharding.Type, "MANUAL"):
raise SkipTest
super().setUp()
-self.spmd_lowering = maps._SPMD_LOWERING.value
-self.spmd_manual_lowering = maps._SPMD_LOWERING_MANUAL.value
+self.spmd_lowering = maps.SPMD_LOWERING.value
+self.spmd_manual_lowering = maps.SPMD_LOWERING_MANUAL.value
config.update('experimental_xmap_spmd_lowering', True)
config.update('experimental_xmap_spmd_lowering_manual', True)
@@ -436,7 +436,7 @@ class XMapTest(XMapTestCase):
m_size = math.prod([2] + [2] * (len(mesh) - 2))
self.assertListEqual(y_op_sharding.tile_assignment_dimensions(),
[2, 1, 1, m_size])
-if maps._SPMD_LOWERING.value:
+if maps.SPMD_LOWERING.value:
hlo = f.lower(x).compiler_ir(dialect="hlo").as_hlo_text()
# Make sure that there are non-partial sharding specs in the HLO
if xla_extension_version >= 180:
@@ -749,7 +749,7 @@ class XMapTest(XMapTestCase):
axis_resources={'i': 'x'})
x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
hlo = f.lower(x).as_text(dialect='stablehlo')
-if maps._SPMD_LOWERING.value:
+if maps.SPMD_LOWERING.value:
self.assertIn("mhlo.num_partitions = 2", hlo)
self.assertIn("mhlo.num_replicas = 1", hlo)
else:
@@ -1204,7 +1204,7 @@ class NewPrimitiveTest(XMapTestCase):
@jtu.with_and_without_mesh
def testGather(self, mesh, axis_resources):
-if axis_resources and not maps._SPMD_LOWERING.value:
+if axis_resources and not maps.SPMD_LOWERING.value:
raise SkipTest("pgather over mesh axes without SPMD lowering not implemented")
x = jnp.arange(12, dtype=np.float32).reshape((4, 3))
y = jnp.arange(35).reshape((5, 7)) % 3
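
SPMDTestMixin and ManualSPMDTestMixin above save the flag values in setUp; their tearDown methods (outside the context lines shown in this diff) presumably restore them. A hedged sketch of the full mixin pattern with the renamed objects:

```python
class SPMDTestMixin:
  def setUp(self):
    super().setUp()
    # Remember the process-wide value so tearDown can put it back.
    self.spmd_lowering = maps.SPMD_LOWERING.value
    config.update('experimental_xmap_spmd_lowering', True)

  def tearDown(self):
    # Restore the flag for tests that run after this mixin.
    config.update('experimental_xmap_spmd_lowering', self.spmd_lowering)
    super().tearDown()
```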