diff --git a/benchmarks/gda_benchmark.py b/benchmarks/gda_benchmark.py new file mode 100644 index 000000000..e0705b8e0 --- /dev/null +++ b/benchmarks/gda_benchmark.py @@ -0,0 +1,95 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Microbenchmarks for JAX `api` functions.""" + +from functools import partial + +import google_benchmark +import jax +from jax._src import test_util as jtu +from jax._src.util import prod +from jax.experimental import global_device_array as gda +import numpy as np + +mesh_shapes_axes = [ + ((256, 8), ["x", "y"]), + ((256, 8), [None]), + ((256, 8), ["x"]), + ((256, 8), ["y"]), + ((256, 8), [("x", "y")]), + ((128, 8), ["x", "y"]), + ((4, 2), ["x", "y"]), +] + + +def gda_construction_callback(mesh_axes, state): + # Keep the mesh containing 8 local devices as using >8 local devices is + # unrealistic. Since `from_callback` measures `device_put` time as well, it + # dominates when local devices are for example 2048 (local devices will never + # be 2048). + global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_input_shape = (2048, 2048) + global_input_data = np.arange( + prod(global_input_shape)).reshape(global_input_shape) + def cb(index): + return global_input_data[index] + + while state: + gda.GlobalDeviceArray.from_callback( + global_input_shape, global_mesh, mesh_axes, cb) + + +def gda_construction_raw(mesh_shape, mesh_axes, state): + # `device_put` time is not measured in this benchmark. All the devices here + # are local. + global_mesh = jtu.create_global_mesh(mesh_shape, ("x", "y")) + global_input_shape = (2048, 2048) + global_input_data = np.arange( + prod(global_input_shape)).reshape(global_input_shape) + global_indices = gda.get_shard_indices(global_input_shape, global_mesh, + mesh_axes) + dbs = [ + jax.device_put(global_input_data[global_indices[device]], device) + for device in global_mesh.local_devices + ] + + while state: + gda.GlobalDeviceArray(global_input_shape, global_mesh, mesh_axes, dbs) + + +def indices_replica_id_calc(mesh_shape, mesh_axes, state): + global_input_shape = (2048, 2048) + global_mesh = jtu.create_global_mesh(mesh_shape, ("x", "y")) + + while state: + gda.get_shard_indices_replica_ids(global_input_shape, global_mesh, mesh_axes) + + +benchmarks = [] +for mesh_shape, axes in mesh_shapes_axes: + benchmarks.extend([ + google_benchmark.register( + partial(gda_construction_callback, axes), + name=f"gda_construction_callback_(4, 2)_{axes}"), + google_benchmark.register( + partial(gda_construction_raw, mesh_shape, axes), + name=f"gda_construction_raw_{mesh_shape}_{axes}"), + google_benchmark.register( + partial(indices_replica_id_calc, mesh_shape, axes), + name=f"indices_replica_id_calc_{mesh_shape}_{axes}"), + ]) + + +if __name__ == "__main__": + google_benchmark.main() diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 30c096ba4..bf61920d2 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -41,7 +41,7 @@ from jax._src.lib import xla_bridge from jax._src import dispatch from jax.interpreters import mlir from jax.interpreters import xla -from jax.experimental.maps import mesh +from jax.experimental.maps import mesh, Mesh FLAGS = flags.FLAGS @@ -1129,6 +1129,16 @@ def restore_spmd_lowering_flag(): if old_spmd_lowering_flag is None: return config.update('experimental_xmap_spmd_lowering', old_spmd_lowering_flag) +def create_global_mesh(mesh_shape, axis_names): + size = prod(mesh_shape) + if len(api.devices()) < size: + raise unittest.SkipTest(f"Test requires {size} global devices.") + devices = sorted(api.devices(), key=lambda d: d.id) + mesh_devices = np.array(devices[:size]).reshape(mesh_shape) + global_mesh = Mesh(mesh_devices, axis_names) + return global_mesh + + class _cached_property: null = object() diff --git a/tests/global_device_array_test.py b/tests/global_device_array_test.py index 32ecf2e8b..91aa9457b 100644 --- a/tests/global_device_array_test.py +++ b/tests/global_device_array_test.py @@ -13,7 +13,6 @@ # limitations under the License. """Tests for GlobalDeviceArray.""" -import unittest from absl.testing import absltest from absl.testing import parameterized import numpy as np @@ -24,22 +23,12 @@ from jax._src import test_util as jtu from jax._src.util import prod, safe_zip from jax.experimental import PartitionSpec as P -from jax.experimental.maps import Mesh from jax.experimental.global_device_array import GlobalDeviceArray from jax.config import config config.parse_flags_with_absl() -def create_global_mesh(mesh_shape, axis_names): - size = prod(mesh_shape) - if len(jax.devices()) < size: - raise unittest.SkipTest(f"Test requires {size} local devices") - mesh_devices = np.array(jax.devices()[:size]).reshape(mesh_shape) - global_mesh = Mesh(mesh_devices, axis_names) - return global_mesh - - class GDATest(jtu.JaxTestCase): @parameterized.named_parameters( @@ -76,7 +65,7 @@ class GDATest(jtu.JaxTestCase): ) def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape, expected_replica_ids, expected_is_fully_replicated): - global_mesh = create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) global_input_shape = (8, 2) global_input_data = np.arange( prod(global_input_shape)).reshape(global_input_shape) @@ -126,7 +115,7 @@ class GDATest(jtu.JaxTestCase): ) def test_gda_3d_shard(self, mesh_axes, expected_index, expected_shard_shape, expected_replica_ids): - global_mesh = create_global_mesh((2, 2, 2), ('x', 'y', 'z')) + global_mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z')) global_input_shape = (8, 4, 2) global_input_data = np.arange( prod(global_input_shape)).reshape(global_input_shape) @@ -160,7 +149,7 @@ class GDATest(jtu.JaxTestCase): ) def test_gda_1d_shard(self, mesh_axes, expected_index, expected_shard_shape, expected_replica_ids): - global_mesh = create_global_mesh((8,), ('x')) + global_mesh = jtu.create_global_mesh((8,), ('x')) global_input_shape = (16,) global_input_data = np.arange(prod(global_input_shape)).reshape(-1) def cb(index): @@ -188,7 +177,7 @@ class GDATest(jtu.JaxTestCase): ) def test_gda_subset_devices(self, mesh_axes, expected_index, expected_shard_shape, expected_replica_ids): - global_mesh = create_global_mesh((2, 2), ('x', 'y')) + global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) global_input_shape = (8, 2) global_input_data = np.arange( prod(global_input_shape)).reshape(global_input_shape) @@ -213,7 +202,7 @@ class GDATest(jtu.JaxTestCase): self.assertArraysEqual(g.data, l.data) def test_gda_batched_callback(self): - global_mesh = create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) global_input_shape = (8, 2) mesh_axes = [('x', 'y')] global_input_data = np.arange( @@ -233,7 +222,7 @@ class GDATest(jtu.JaxTestCase): expected_second_shard_value) def test_gda_batched_callback_with_devices(self): - global_mesh = create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) global_input_shape = (8, 2) mesh_axes = ['x'] global_input_data = np.arange( @@ -259,7 +248,7 @@ class GDATest(jtu.JaxTestCase): expected_second_shard_value) def test_gda_str_repr(self): - global_mesh = create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) global_input_shape = (8, 2) mesh_axes = [('x', 'y')] global_input_data = np.arange( diff --git a/tests/pjit_test.py b/tests/pjit_test.py index f87e5bcc5..9a0beec0d 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -30,7 +30,7 @@ from jax.errors import JAXTypeError from jax import lax # TODO(skye): do we still wanna call this PartitionSpec? from jax.experimental import PartitionSpec as P -from jax.experimental.maps import xmap, mesh, Mesh +from jax.experimental.maps import xmap, mesh from jax.experimental import global_device_array import jax.experimental.pjit as pjit_lib from jax.experimental.pjit import (pjit, pjit_p, with_sharding_constraint, @@ -72,15 +72,6 @@ def check_1d_2d_mesh(f, set_mesh): ))(jtu.with_mesh_from_kwargs(f) if set_mesh else f) -def create_global_mesh(mesh_shape, axis_names): - size = prod(mesh_shape) - if len(jax.devices()) < size: - raise unittest.SkipTest(f"Test requires {size} local devices") - mesh_devices = np.array(jax.devices()[:size]).reshape(mesh_shape) - global_mesh = Mesh(mesh_devices, axis_names) - return global_mesh - - # TODO(skye): make the buffer donation utils part of JaxTestCase class PJitTest(jtu.BufferDonationTestCase): @@ -610,7 +601,7 @@ class GDAPjitTest(jtu.JaxTestCase): @jtu.with_mesh([('x', 4), ('y', 2)]) def test_pjit_gda_single_output(self): - global_mesh = create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) global_input_shape = (8, 2) mesh_axes = P('x', 'y') input_data = np.arange( @@ -645,7 +636,7 @@ class GDAPjitTest(jtu.JaxTestCase): @jtu.with_mesh([('x', 4), ('y', 2)]) def test_pjit_gda_multi_input_multi_output(self): - global_mesh = create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) global_input_shape = (8, 2) input_data = np.arange( prod(global_input_shape)).reshape(global_input_shape) @@ -718,7 +709,7 @@ class GDAPjitTest(jtu.JaxTestCase): @jtu.with_mesh([('x', 4), ('y', 2)]) def test_pjit_gda_mixed_inputs(self): - global_mesh = create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) global_input_shape = (8, 2) mesh_axes = P('x', 'y') input_data = np.arange( @@ -783,7 +774,7 @@ class GDAPjitTest(jtu.JaxTestCase): @jtu.with_mesh([('x', 2), ('y', 2)]) def test_pjit_gda_mesh_mismatch(self): - global_mesh = create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) global_input_shape = (8, 2) mesh_axes = ['x', 'y'] global_input_data = np.arange( @@ -804,7 +795,7 @@ class GDAPjitTest(jtu.JaxTestCase): @jtu.with_mesh([('x', 4), ('y', 2)]) def test_pjit_gda_wrong_resource_for_gda_input(self): - global_mesh = create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) global_input_shape = (8, 2) mesh_axes = ['x'] global_input_data = np.arange( @@ -830,7 +821,7 @@ class GDAPjitTest(jtu.JaxTestCase): @jtu.with_mesh([('x', 4), ('y', 2)]) def test_pjit_gda_caching(self): - global_mesh = create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) mesh_axes = P('x', 'y') input_data = np.arange( @@ -858,7 +849,7 @@ class GDAPjitTest(jtu.JaxTestCase): @jtu.with_mesh([('x', 4), ('y', 2)]) def test_partition_spec_mismatch_semantically_equivalent(self): - global_mesh = create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) global_input_shape = (8, 2) mesh_axes = [None] global_input_data = np.arange( @@ -882,7 +873,7 @@ class GDAPjitTest(jtu.JaxTestCase): f(output_gda) def test_from_gda_duplicates(self): - global_mesh = create_global_mesh((1, 2), ('x', 'y')) + global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y')) global_input_shape = (8, 2) mesh_axes = ['x', 'y'] input_gda = create_gda(global_input_shape, global_mesh, mesh_axes) @@ -896,7 +887,7 @@ class GDAPjitTest(jtu.JaxTestCase): input_gda) def test_no_recompilation_due_to_in_axis_resources(self): - global_mesh = create_global_mesh((1, 2), ('x', 'y')) + global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y')) global_input_shape = (8, 2) mesh_axes = P(None,) input_gda = create_gda(global_input_shape, global_mesh, mesh_axes) @@ -1178,7 +1169,7 @@ class UtilTest(jtu.JaxTestCase): self.assertEqual(pxla.array_mapping_to_axis_resources(inp), expected_out) def test_get_input_metadata_fully_replicated(self): - global_mesh = create_global_mesh((2, 2), ('x', 'y')) + global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) global_in_aval1 = jax.core.ShapedArray((4, 4), jnp.int32) global_in_aval2 = jax.core.ShapedArray((4, 4, 4), jnp.int32) global_in_aval3 = jax.core.ShapedArray((), jnp.int32)