Merge pull request #16231 from jakevdp:product

PiperOrigin-RevId: 537454290
This commit is contained in:
jax authors 2023-06-02 18:13:56 -07:00
commit 5639e194be
3 changed files with 3 additions and 3 deletions

View File

@ -110,7 +110,7 @@ def sharding_to_proto(sharding: SpatialSharding):
else:
proto.type = xc.OpSharding.Type.OTHER
proto.tile_assignment_dimensions = list(sharding) # type: ignore
proto.tile_assignment_devices = list(range(np.product(sharding))) # type: ignore
proto.tile_assignment_devices = list(range(np.prod(sharding))) # type: ignore
return proto
def tuple_sharding_proto(elems):

View File

@ -7,7 +7,7 @@ from jax._src.numpy.util import check_arraylike, _wraps
def _isEmpty2d(arr):
# check size first for efficiency
return arr.size == 0 and np.product(arr.shape[-2:]) == 0
return arr.size == 0 and np.prod(arr.shape[-2:]) == 0
def _assertNoEmpty2d(*arrays):

View File

@ -128,7 +128,7 @@ def _create_device_mesh_for_nd_torus(
# 4x8 or a single axis. If XLA 2D collectives support non-square plane
# soon, we can continue to preferentially map to 2D plane in general,
# otherwise, we should treat non-square 2D plane and 1D submesh equally.
if np.product(c_axes) == logical_axis_size:
if np.prod(c_axes) == logical_axis_size:
assignment[logical_axis_index] = c_indices
# Zero the assigned physical axes.
assignable_physical_mesh = [