Benchmarks for GDA. Also move create_global_mesh to test_utils since it was replicated in a lot of places.

PiperOrigin-RevId: 421142813
This commit is contained in:
Yash Katariya 2022-01-11 15:42:31 -08:00 committed by jax authors
parent f235edbf5c
commit fbb8b9f8c6
4 changed files with 124 additions and 39 deletions

View File

@ -0,0 +1,95 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Microbenchmarks for JAX `api` functions."""
from functools import partial
import google_benchmark
import jax
from jax._src import test_util as jtu
from jax._src.util import prod
from jax.experimental import global_device_array as gda
import numpy as np
mesh_shapes_axes = [
((256, 8), ["x", "y"]),
((256, 8), [None]),
((256, 8), ["x"]),
((256, 8), ["y"]),
((256, 8), [("x", "y")]),
((128, 8), ["x", "y"]),
((4, 2), ["x", "y"]),
]
def gda_construction_callback(mesh_axes, state):
# Keep the mesh containing 8 local devices as using >8 local devices is
# unrealistic. Since `from_callback` measures `device_put` time as well, it
# dominates when local devices are for example 2048 (local devices will never
# be 2048).
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (2048, 2048)
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return global_input_data[index]
while state:
gda.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes, cb)
def gda_construction_raw(mesh_shape, mesh_axes, state):
# `device_put` time is not measured in this benchmark. All the devices here
# are local.
global_mesh = jtu.create_global_mesh(mesh_shape, ("x", "y"))
global_input_shape = (2048, 2048)
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
global_indices = gda.get_shard_indices(global_input_shape, global_mesh,
mesh_axes)
dbs = [
jax.device_put(global_input_data[global_indices[device]], device)
for device in global_mesh.local_devices
]
while state:
gda.GlobalDeviceArray(global_input_shape, global_mesh, mesh_axes, dbs)
def indices_replica_id_calc(mesh_shape, mesh_axes, state):
global_input_shape = (2048, 2048)
global_mesh = jtu.create_global_mesh(mesh_shape, ("x", "y"))
while state:
gda.get_shard_indices_replica_ids(global_input_shape, global_mesh, mesh_axes)
benchmarks = []
for mesh_shape, axes in mesh_shapes_axes:
benchmarks.extend([
google_benchmark.register(
partial(gda_construction_callback, axes),
name=f"gda_construction_callback_(4, 2)_{axes}"),
google_benchmark.register(
partial(gda_construction_raw, mesh_shape, axes),
name=f"gda_construction_raw_{mesh_shape}_{axes}"),
google_benchmark.register(
partial(indices_replica_id_calc, mesh_shape, axes),
name=f"indices_replica_id_calc_{mesh_shape}_{axes}"),
])
if __name__ == "__main__":
google_benchmark.main()

View File

@ -41,7 +41,7 @@ from jax._src.lib import xla_bridge
from jax._src import dispatch
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.experimental.maps import mesh
from jax.experimental.maps import mesh, Mesh
FLAGS = flags.FLAGS
@ -1129,6 +1129,16 @@ def restore_spmd_lowering_flag():
if old_spmd_lowering_flag is None: return
config.update('experimental_xmap_spmd_lowering', old_spmd_lowering_flag)
def create_global_mesh(mesh_shape, axis_names):
size = prod(mesh_shape)
if len(api.devices()) < size:
raise unittest.SkipTest(f"Test requires {size} global devices.")
devices = sorted(api.devices(), key=lambda d: d.id)
mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
global_mesh = Mesh(mesh_devices, axis_names)
return global_mesh
class _cached_property:
null = object()

View File

@ -13,7 +13,6 @@
# limitations under the License.
"""Tests for GlobalDeviceArray."""
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
@ -24,22 +23,12 @@ from jax._src import test_util as jtu
from jax._src.util import prod, safe_zip
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import Mesh
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.config import config
config.parse_flags_with_absl()
def create_global_mesh(mesh_shape, axis_names):
size = prod(mesh_shape)
if len(jax.devices()) < size:
raise unittest.SkipTest(f"Test requires {size} local devices")
mesh_devices = np.array(jax.devices()[:size]).reshape(mesh_shape)
global_mesh = Mesh(mesh_devices, axis_names)
return global_mesh
class GDATest(jtu.JaxTestCase):
@parameterized.named_parameters(
@ -76,7 +65,7 @@ class GDATest(jtu.JaxTestCase):
)
def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
expected_replica_ids, expected_is_fully_replicated):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
@ -126,7 +115,7 @@ class GDATest(jtu.JaxTestCase):
)
def test_gda_3d_shard(self, mesh_axes, expected_index, expected_shard_shape,
expected_replica_ids):
global_mesh = create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
global_mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
global_input_shape = (8, 4, 2)
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
@ -160,7 +149,7 @@ class GDATest(jtu.JaxTestCase):
)
def test_gda_1d_shard(self, mesh_axes, expected_index, expected_shard_shape,
expected_replica_ids):
global_mesh = create_global_mesh((8,), ('x'))
global_mesh = jtu.create_global_mesh((8,), ('x'))
global_input_shape = (16,)
global_input_data = np.arange(prod(global_input_shape)).reshape(-1)
def cb(index):
@ -188,7 +177,7 @@ class GDATest(jtu.JaxTestCase):
)
def test_gda_subset_devices(self, mesh_axes, expected_index,
expected_shard_shape, expected_replica_ids):
global_mesh = create_global_mesh((2, 2), ('x', 'y'))
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
global_input_shape = (8, 2)
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
@ -213,7 +202,7 @@ class GDATest(jtu.JaxTestCase):
self.assertArraysEqual(g.data, l.data)
def test_gda_batched_callback(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = [('x', 'y')]
global_input_data = np.arange(
@ -233,7 +222,7 @@ class GDATest(jtu.JaxTestCase):
expected_second_shard_value)
def test_gda_batched_callback_with_devices(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = ['x']
global_input_data = np.arange(
@ -259,7 +248,7 @@ class GDATest(jtu.JaxTestCase):
expected_second_shard_value)
def test_gda_str_repr(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = [('x', 'y')]
global_input_data = np.arange(

View File

@ -30,7 +30,7 @@ from jax.errors import JAXTypeError
from jax import lax
# TODO(skye): do we still wanna call this PartitionSpec?
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import xmap, mesh, Mesh
from jax.experimental.maps import xmap, mesh
from jax.experimental import global_device_array
import jax.experimental.pjit as pjit_lib
from jax.experimental.pjit import (pjit, pjit_p, with_sharding_constraint,
@ -72,15 +72,6 @@ def check_1d_2d_mesh(f, set_mesh):
))(jtu.with_mesh_from_kwargs(f) if set_mesh else f)
def create_global_mesh(mesh_shape, axis_names):
size = prod(mesh_shape)
if len(jax.devices()) < size:
raise unittest.SkipTest(f"Test requires {size} local devices")
mesh_devices = np.array(jax.devices()[:size]).reshape(mesh_shape)
global_mesh = Mesh(mesh_devices, axis_names)
return global_mesh
# TODO(skye): make the buffer donation utils part of JaxTestCase
class PJitTest(jtu.BufferDonationTestCase):
@ -610,7 +601,7 @@ class GDAPjitTest(jtu.JaxTestCase):
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gda_single_output(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_data = np.arange(
@ -645,7 +636,7 @@ class GDAPjitTest(jtu.JaxTestCase):
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gda_multi_input_multi_output(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
@ -718,7 +709,7 @@ class GDAPjitTest(jtu.JaxTestCase):
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gda_mixed_inputs(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_data = np.arange(
@ -783,7 +774,7 @@ class GDAPjitTest(jtu.JaxTestCase):
@jtu.with_mesh([('x', 2), ('y', 2)])
def test_pjit_gda_mesh_mismatch(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = ['x', 'y']
global_input_data = np.arange(
@ -804,7 +795,7 @@ class GDAPjitTest(jtu.JaxTestCase):
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gda_wrong_resource_for_gda_input(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = ['x']
global_input_data = np.arange(
@ -830,7 +821,7 @@ class GDAPjitTest(jtu.JaxTestCase):
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gda_caching(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_data = np.arange(
@ -858,7 +849,7 @@ class GDAPjitTest(jtu.JaxTestCase):
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_partition_spec_mismatch_semantically_equivalent(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = [None]
global_input_data = np.arange(
@ -882,7 +873,7 @@ class GDAPjitTest(jtu.JaxTestCase):
f(output_gda)
def test_from_gda_duplicates(self):
global_mesh = create_global_mesh((1, 2), ('x', 'y'))
global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = ['x', 'y']
input_gda = create_gda(global_input_shape, global_mesh, mesh_axes)
@ -896,7 +887,7 @@ class GDAPjitTest(jtu.JaxTestCase):
input_gda)
def test_no_recompilation_due_to_in_axis_resources(self):
global_mesh = create_global_mesh((1, 2), ('x', 'y'))
global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = P(None,)
input_gda = create_gda(global_input_shape, global_mesh, mesh_axes)
@ -1178,7 +1169,7 @@ class UtilTest(jtu.JaxTestCase):
self.assertEqual(pxla.array_mapping_to_axis_resources(inp), expected_out)
def test_get_input_metadata_fully_replicated(self):
global_mesh = create_global_mesh((2, 2), ('x', 'y'))
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
global_in_aval1 = jax.core.ShapedArray((4, 4), jnp.int32)
global_in_aval2 = jax.core.ShapedArray((4, 4, 4), jnp.int32)
global_in_aval3 = jax.core.ShapedArray((), jnp.int32)