22644 Commits

Author SHA1 Message Date
Jake VanderPlas
09fd345de9 pre-commit: update hooks & pin using hashes 2024-08-27 15:23:13 -07:00
jax authors
45dcfde293 Merge pull request #23279 from jakevdp:ruff-061
PiperOrigin-RevId: 668158604
2024-08-27 15:16:33 -07:00
Jake VanderPlas
68be5b5085 CI: update ruff to v0.6.1 2024-08-27 14:54:11 -07:00
jax authors
88a2008829 Merge pull request #22972 from mgoldfarb-nvidia:mgoldfarb-nvidia/pgo_nsys_converter_update
PiperOrigin-RevId: 668147156
2024-08-27 14:51:22 -07:00
jax authors
db9e44fe56 Update XLA dependency to use revision
4170b9b1a9.

PiperOrigin-RevId: 668146966
2024-08-27 14:47:29 -07:00
jax authors
538b0bac5e Merge pull request #23271 from oliverdutton-iso:fast_det_jvp
PiperOrigin-RevId: 668144793
2024-08-27 14:42:05 -07:00
Yash Katariya
afff0e09aa Improve the error message to specify shapes too
PiperOrigin-RevId: 668117141
2024-08-27 13:30:55 -07:00
jax authors
6ee41369ee Merge pull request #23233 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 668061985
2024-08-27 11:16:11 -07:00
jax authors
140955dce0 Merge pull request #23224 from ROCm:ci_build_script
PiperOrigin-RevId: 668061975
2024-08-27 11:12:05 -07:00
jax authors
ac9e9e016b Merge pull request #23274 from ayaka14732:mypy-error
PiperOrigin-RevId: 668010899
2024-08-27 09:06:25 -07:00
Ayaka
859eacb5a1 Fix mypy error 2024-08-27 16:57:53 +01:00
rajasekharporeddy
2f3d428e78 Improved docs for jnp.fix and jnp.trunc 2024-08-27 19:40:28 +05:30
Oliver Dutton
087f569759 Fast jvp for 3x3 and 2x2 determinants
Speed up jvp's for 3x3 and 2x2 determinants

The current det implementation wraps everything in a single custom_jvp, so although there are fast functions for the 2x2 and 3x3 cases, their derivatives still go through the slow general jvp. This PR localises the custom_jvp to the generic case.

This general case is ~10x slower on GPU (A100) and ~250x slower on TPU (v2).
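
A minimal sketch of what localising the custom_jvp could look like, assuming a dispatch on the trailing matrix size. The names `det_sketch`, `_det_2x2`, `_det_3x3` and `_det_general` are hypothetical and are not taken from the PR; the general-case rule uses Jacobi's formula, d det(A) = det(A) * tr(A^-1 dA). Only the general path carries a custom rule, so the closed-form paths are differentiated directly by JAX.

```python
import jax
import jax.numpy as jnp


def _det_2x2(a: jax.Array) -> jax.Array:
  # Closed form; differentiated directly by JAX, no custom rule involved.
  return a[..., 0, 0] * a[..., 1, 1] - a[..., 0, 1] * a[..., 1, 0]


def _det_3x3(a: jax.Array) -> jax.Array:
  # Closed form (rule of Sarrus); also differentiated directly.
  return (a[..., 0, 0] * a[..., 1, 1] * a[..., 2, 2] +
          a[..., 0, 1] * a[..., 1, 2] * a[..., 2, 0] +
          a[..., 0, 2] * a[..., 1, 0] * a[..., 2, 1] -
          a[..., 0, 2] * a[..., 1, 1] * a[..., 2, 0] -
          a[..., 0, 0] * a[..., 1, 2] * a[..., 2, 1] -
          a[..., 0, 1] * a[..., 1, 0] * a[..., 2, 2])


@jax.custom_jvp
def _det_general(a: jax.Array) -> jax.Array:
  # General case via the sign and log-magnitude of the determinant.
  sign, logabsdet = jnp.linalg.slogdet(a)
  return sign * jnp.exp(logabsdet)


@_det_general.defjvp
def _det_general_jvp(primals, tangents):
  (a,), (a_dot,) = primals, tangents
  y = _det_general(a)
  # Jacobi's formula: the tangent is det(A) * trace(A^-1 @ dA).
  z = jnp.linalg.solve(a, a_dot)
  return y, y * jnp.trace(z, axis1=-1, axis2=-2)


def det_sketch(a: jax.Array) -> jax.Array:
  # Shapes are static under jit, so this Python-level dispatch is safe.
  n = a.shape[-1]
  if n == 2:
    return _det_2x2(a)
  if n == 3:
    return _det_3x3(a)
  return _det_general(a)
```

With this layout, `jax.grad(det_sketch)` on a 3x3 input never touches the linear solve in the general rule.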

```python
import jax
from jax import numpy as jnp
from jax import random

def det_3x3(a: jax.Array) -> jax.Array:
  return (a[..., 0, 0] * a[..., 1, 1] * a[..., 2, 2] +
          a[..., 0, 1] * a[..., 1, 2] * a[..., 2, 0] +
          a[..., 0, 2] * a[..., 1, 0] * a[..., 2, 1] -
          a[..., 0, 2] * a[..., 1, 1] * a[..., 2, 0] -
          a[..., 0, 0] * a[..., 1, 2] * a[..., 2, 1] -
          a[..., 0, 1] * a[..., 1, 0] * a[..., 2, 2])

key = random.key(42)
x = random.normal(key, (int(1e5), 3, 3))

general_grad = jax.grad(lambda x: jnp.linalg.det(x).sum())
direct_3by3_grad = jax.vmap(jax.grad(det_3x3))
general_grad, direct_3by3_grad = (jax.jit(f) for f in (general_grad, direct_3by3_grad))

_ = jax.block_until_ready(general_grad(x))
_ = jax.block_until_ready(direct_3by3_grad(x))

# %timeit is an IPython magic; run these lines in IPython or Jupyter.
%timeit _ = jax.block_until_ready(general_grad(x))
%timeit _ = jax.block_until_ready(direct_3by3_grad(x))
```
2024-08-27 14:55:54 +01:00
jax authors
be6e1549c3 Merge pull request #23207 from carlosgmartin:docs-jacobian-tall-wide
PiperOrigin-RevId: 667797293
2024-08-26 19:15:49 -07:00
Parker Schuh
d63df39744 Support donating arrays with non-default layouts by setting up XLA donation
directly instead of defining aliasing for arrays with potentially incompatible
layouts. We only fall back to XLA donation for exactly those arrays which
have input and output layouts explicitly set to conflicting layouts.

PiperOrigin-RevId: 667770224
2024-08-26 17:19:33 -07:00
Peter Hawkins
45b871950e Fix a number of minor problems in the ROCm build.
Change in preparation for adding more presubmits for AMD ROCm.

PiperOrigin-RevId: 667766343
2024-08-26 17:04:01 -07:00
Justin Fu
9027101737 Update usages of mosaic compiler params with TPUCompilerParams.
PiperOrigin-RevId: 667762992
2024-08-26 16:51:43 -07:00
carlosgmartin
f812d0f28b Clarify meaning of tall and wide Jacobian matrices in autodiff docs. 2024-08-26 16:00:36 -07:00
jax authors
57c0d59d04 Update XLA dependency to use revision
baf026d13b.

PiperOrigin-RevId: 667720484
2024-08-26 14:39:56 -07:00
jax authors
c556803bdf Merge pull request #23240 from jakevdp:broadcast-docs
PiperOrigin-RevId: 667710935
2024-08-26 14:12:03 -07:00
Bryan Massoth
b38f985b01 Add a callout that LibTPU now supports profiling of SparseCore for TPUv5p chips, which will be viewable in TensorBoard Profiler's TraceViewer tool.
PiperOrigin-RevId: 667708094
2024-08-26 14:04:43 -07:00
jax authors
c33ce85784 Small fix for the jax trace dumping path
PiperOrigin-RevId: 667639334
2024-08-26 10:51:08 -07:00
Jake VanderPlas
416f79bb5c DOC: update docstrings for broadcast-related functions 2024-08-26 10:48:29 -07:00
jax authors
be13d4055e Merge pull request #23236 from jakevdp:fix-mean-norm
PiperOrigin-RevId: 667634188
2024-08-26 10:38:01 -07:00
Michael Goldfarb
d2b1ebd0aa Update pgo_nsys_converter.py to use the NVTX kern sum report when available. 2024-08-26 17:27:23 +00:00
jax authors
9e689e4e01 Merge pull request #23229 from froystig:scan-unroll-err
PiperOrigin-RevId: 667617328
2024-08-26 09:52:08 -07:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Jake VanderPlas
4b1c9f483c jnp.mean: fix normalizer for large arrays 2024-08-26 09:04:15 -07:00
jax authors
550607a45d Merge pull request #23197 from jakevdp:quantile-docs
PiperOrigin-RevId: 667602295
2024-08-26 09:02:31 -07:00
Roy Frostig
b3e3115391 improve scan error message on non-concrete unroll argument 2024-08-24 23:09:12 -07:00
jax authors
e3e0860184 Merge pull request #23228 from froystig:scanlen
PiperOrigin-RevId: 667251807
2024-08-24 22:51:45 -07:00
Roy Frostig
a9b41e9fe7 improve scan error message on non-concrete length argument
Specifically, make it speak concretely about the `length` argument.
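
A minimal sketch of the situation this message concerns, assuming the usual way a non-concrete `length` arises: passing it through `jax.jit` as a traced argument. The function `cumulative_sum` is hypothetical and used only for illustration; the exact error type and wording depend on the JAX version, so the example catches a generic exception.

```python
import jax
import jax.numpy as jnp


def cumulative_sum(xs, length):
  # Running sum over xs; `length` is forwarded to lax.scan.
  def step(carry, x):
    carry = carry + x
    return carry, carry
  _, ys = jax.lax.scan(step, 0.0, xs, length=length)
  return ys


xs = jnp.arange(4.0)

# Marking `length` as static keeps it a concrete Python int during tracing.
static_version = jax.jit(cumulative_sum, static_argnames="length")
ys = static_version(xs, length=4)

# Without static_argnames, `length` becomes a tracer and scan rejects it.
traced_version = jax.jit(cumulative_sum)
try:
  traced_version(xs, 4)
  traced_failed = False
except Exception as err:  # error class and text vary across JAX versions
  traced_failed = True
  message = str(err)
```

Printing `message` shows the error text that this commit aims to make more specific.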
2024-08-24 22:30:33 -07:00
jax authors
e9143623e0 Update XLA dependency to use revision
4bfb5c82a4.

PiperOrigin-RevId: 667177243
2024-08-24 14:34:48 -07:00
Ruturaj4
9ce8de5fb0 [ROCm] add build file. 2024-08-23 18:11:48 -05:00
Justin Fu
7253b9ac8b [Pallas] Fix pallas interpret mode DMA test failures.
PiperOrigin-RevId: 666953373
2024-08-23 16:07:53 -07:00
jax authors
0c505b79b4 Merge pull request #23222 from mattjj:rafi
PiperOrigin-RevId: 666950244
2024-08-23 15:57:28 -07:00
jax authors
a2a351f88b Fix pallas int4->int8 conversion
PiperOrigin-RevId: 666939965
2024-08-23 15:19:02 -07:00
jax authors
6a5ca0bb52 Update XLA dependency to use revision
9738684ff8.

PiperOrigin-RevId: 666937202
2024-08-23 15:10:36 -07:00
Colin Gaffney
276c87eba0 Add a more helpful error message in create_hybrid_device_mesh for missing attribute `process_index` or `slice_index`.
PiperOrigin-RevId: 666928476
2024-08-23 14:42:48 -07:00
Matthew Johnson
670a648b7b add experimental jax.no_tracing context manager 2024-08-23 21:21:55 +00:00
Jake VanderPlas
9090b8a4f9 Better docs for jnp quantile & percentile 2024-08-23 13:38:20 -07:00
jax authors
c6c701e6a7 Merge pull request #23196 from jakevdp:register-deprecations
PiperOrigin-RevId: 666900363
2024-08-23 13:16:09 -07:00
jax authors
20d13abfa0 Update XLA dependency to use revision
b0d313b58e.

PiperOrigin-RevId: 666868666
2024-08-23 11:38:33 -07:00
jax authors
279977c61d Refactor hermetic CUDA flags and update --config=cuda to add CUDA dependencies both for bazel build and bazel test phases.
Add `--@local_config_cuda//cuda:override_include_cuda_libs` to override settings for TF wheel.

Forbid building TF wheel with `--@local_config_cuda//cuda:include_cuda_libs=true`

PiperOrigin-RevId: 666848518
2024-08-23 10:44:32 -07:00
Adam Paszke
be59f6ec47 [Mosaic GPU] Support tiled stores of arrays with fewer columns than swizzling
PiperOrigin-RevId: 666798285
2024-08-23 08:06:25 -07:00
Bart Chrzaszcz
71b7e78916 Add jax_test configs for shardy and enable it for pjit_test.py and fix any tests.
Tests fixed include:

- `test_globally_sharded_key_array_8x4_multi_device`
  - Issue was in `replicate_trailing_dims` where an `xc.OpSharding` was always created. Fixed by creating an equivalent SDY sharding.
- `test_aot_out_info`
  - Issue was that there was no mesh because there weren't any NamedShardings. Fixed by not asserting a mesh tuple exists in `lower_jaxpr_to_module` when adding the sdy MeshOp (there won't be any propagation).
- `test_concurrent_pjit`
  - In Shardy, if a tensor dimension of size 0 was sharded on an axis, we'd emit a verification error. But if the axis is of size 1, JAX says this is okay, so have Shardy assume the same.
- `test_globally_sharded_key_array_result_8x4_single_device`
  - This test adds a WSC when no `mesh_shape_tuple` exists (`"sdy.sharding_constraint"(%8) <{sharding = #sdy.sharding<@mesh, [{?}, {?}, {}]>}>`), so we should create a mesh named `mesh` with a single device id in case it doesn't exist.
- `testLowerCostAnalysis`
  - This calls into `mlir_module_to_xla_computation` which calls its own MLIR parsing function in `//third_party/tensorflow/compiler/xla/python/mlir.cc`. Needed to register the SDY dialect in it.
- `testShardingConstraintWithArray`
  - This calls `.compiler_ir(dialect="hlo")` which calls `PyMlirModuleToXlaComputation` which converts the MLIR to HLO, but the Sdy dialect is still inside. Export it before converting it to HLO.

PiperOrigin-RevId: 666777167
2024-08-23 06:51:13 -07:00
Adam Paszke
f54e220430 [Mosaic GPU] Add support for short n dimension in WGMMA
PiperOrigin-RevId: 666766079
2024-08-23 06:08:37 -07:00
Adam Paszke
c76787571b [Mosaic GPU] Expose wait_parity on collective barrier
PiperOrigin-RevId: 666761011
2024-08-23 05:49:06 -07:00
Paweł Paruzel
c430b0c5e3 Activate QR Factorization to XLA's FFI
PiperOrigin-RevId: 666722604
2024-08-23 03:21:43 -07:00
Justin Fu
07767e81a0 [Pallas] Add support for casting to/from unsigned integer types.
PiperOrigin-RevId: 666663406
2024-08-22 23:57:17 -07:00