mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #21070 from shuhand0:rel0.0.7
PiperOrigin-RevId: 631218770
This commit is contained in:
commit
eee2783e85
5
.github/workflows/metal_plugin_ci.yml
vendored
5
.github/workflows/metal_plugin_ci.yml
vendored
@ -17,7 +17,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false # don't cancel all jobs on failure
|
||||
matrix:
|
||||
jaxlib-version: ["plugin_latest"]
|
||||
jaxlib-version: ["pypi_latest", "nightly"]
|
||||
name: "Jax-Metal plugin test (jaxlib=${{ matrix.jaxlib-version }})"
|
||||
runs-on: [self-hosted, macOS, ARM64]
|
||||
|
||||
@ -32,13 +32,14 @@ jobs:
|
||||
python3 -m venv ${GITHUB_WORKSPACE}/jax-metal-venv
|
||||
source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate
|
||||
pip install -U pip numpy wheel
|
||||
pip install jax-metal absl-py pytest
|
||||
pip install absl-py pytest
|
||||
if [[ "${{ matrix.jaxlib-version }}" == "nightly" ]]; then
|
||||
pip install --pre jaxlib \
|
||||
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
fi;
|
||||
cd jax
|
||||
pip install .
|
||||
pip install jax-metal
|
||||
- name: Run test
|
||||
run: |
|
||||
source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate
|
||||
|
@ -302,13 +302,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
|
||||
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
|
||||
def testNonzero(self, shape, dtype):
|
||||
rng = jtu.rand_some_zero(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="Calling nonzero on 0d arrays.*"):
|
||||
self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False)
|
||||
self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, fill_value=fill_value)
|
||||
@ -370,20 +368,16 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
|
||||
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
|
||||
def testArgWhere(self, shape, dtype):
|
||||
rng = jtu.rand_some_zero(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="Calling nonzero on 0d arrays.*"):
|
||||
self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False)
|
||||
self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False)
|
||||
|
||||
# JIT compilation requires specifying a size statically. Full test of this
|
||||
# behavior is in testNonzeroSize().
|
||||
jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2)
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="Calling nonzero on 0d arrays.*"):
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, fill_value=fill_value)
|
||||
@ -2055,7 +2049,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self.assertAllClose(np_input, expected_np_input_after_call)
|
||||
self.assertAllClose(jnp_input, expected_jnp_input_after_call)
|
||||
|
||||
@unittest.skip("Jax-metal fail to convert 1D convolution op.")
|
||||
@jtu.sample_product(
|
||||
mode=['full', 'same', 'valid'],
|
||||
op=['convolve', 'correlate'],
|
||||
@ -2077,7 +2070,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@unittest.skip("Jax-metal fail to convert 1D convolution op.")
|
||||
@jtu.sample_product(
|
||||
mode=['full', 'same', 'valid'],
|
||||
op=['convolve', 'correlate'],
|
||||
@ -4431,15 +4423,12 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=all_shapes, dtype=all_dtypes,
|
||||
shape=nonzerodim_shapes, dtype=all_dtypes,
|
||||
)
|
||||
def testWhereOneArgument(self, shape, dtype):
|
||||
rng = jtu.rand_some_zero(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="Calling nonzero on 0d arrays.*"):
|
||||
self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False)
|
||||
self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False)
|
||||
|
||||
# JIT compilation requires specifying a size statically. Full test of
|
||||
# this behavior is in testNonzeroSize().
|
||||
@ -5724,7 +5713,7 @@ class ReportedIssuesTests(jtu.JaxTestCase):
|
||||
#loc = loc(unknown)
|
||||
module @jit_gather attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
|
||||
func.func public @main(%arg0: tensor<3x2x3xf32> {mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<3x2xi32> {mhlo.sharding = "{replicated}"} loc(unknown)) -> tensor<3x2xf32> {
|
||||
%0 = "stablehlo.gather"(%arg0, %arg1) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0, 2], start_index_map = [0, 2], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = dense<[1, 2, 1]> : tensor<3xi64>} : (tensor<3x2x3xf32>, tensor<3x2xi32>) -> tensor<3x2xf32> loc(#loc2)
|
||||
%0 = "stablehlo.gather"(%arg0, %arg1) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0, 2], start_index_map = [0, 2], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 2, 1>} : (tensor<3x2x3xf32>, tensor<3x2xi32>) -> tensor<3x2xf32> loc(#loc2)
|
||||
return %0 : tensor<3x2xf32> loc(#loc)
|
||||
} loc(#loc)
|
||||
} loc(#loc)
|
||||
|
Loading…
x
Reference in New Issue
Block a user