mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Merge pull request #16231 from jakevdp:product
PiperOrigin-RevId: 537454290
This commit is contained in:
commit
5639e194be
@ -110,7 +110,7 @@ def sharding_to_proto(sharding: SpatialSharding):
|
||||
else:
|
||||
proto.type = xc.OpSharding.Type.OTHER
|
||||
proto.tile_assignment_dimensions = list(sharding) # type: ignore
|
||||
proto.tile_assignment_devices = list(range(np.product(sharding))) # type: ignore
|
||||
proto.tile_assignment_devices = list(range(np.prod(sharding))) # type: ignore
|
||||
return proto
|
||||
|
||||
def tuple_sharding_proto(elems):
|
||||
|
2
jax/_src/third_party/numpy/linalg.py
vendored
2
jax/_src/third_party/numpy/linalg.py
vendored
@ -7,7 +7,7 @@ from jax._src.numpy.util import check_arraylike, _wraps
|
||||
|
||||
def _isEmpty2d(arr):
|
||||
# check size first for efficiency
|
||||
return arr.size == 0 and np.product(arr.shape[-2:]) == 0
|
||||
return arr.size == 0 and np.prod(arr.shape[-2:]) == 0
|
||||
|
||||
|
||||
def _assertNoEmpty2d(*arrays):
|
||||
|
@ -128,7 +128,7 @@ def _create_device_mesh_for_nd_torus(
|
||||
# 4x8 or a single axis. If XLA 2D collectives support non-square plane
|
||||
# soon, we can continue to preferentially map to 2D plane in general,
|
||||
# otherwise, we should treat non-square 2D plane and 1D submesh equally.
|
||||
if np.product(c_axes) == logical_axis_size:
|
||||
if np.prod(c_axes) == logical_axis_size:
|
||||
assignment[logical_axis_index] = c_indices
|
||||
# Zero the assigned physical axes.
|
||||
assignable_physical_mesh = [
|
||||
|
Loading…
x
Reference in New Issue
Block a user