13454 Commits

Author SHA1 Message Date
Peter Hawkins
5617a02fa4 Remove JAX custom call implementation of batched triangular solve.
XLA supports batched triangular solve on GPU and has since February 2022, which is older than the minimum jaxlib version. We can therefore delete our implementation and just use XLA's implementation.

PiperOrigin-RevId: 482031830
2022-10-18 15:04:14 -07:00
Yash Katariya
d20b9fa498 Always use .device_buffers for jax.Array because .device_buffer can raise an error if there is more than 1 buffer present in the Array.
PiperOrigin-RevId: 482028624
2022-10-18 14:51:07 -07:00
jax authors
6aafb86758 Merge pull request #12853 from jakevdp:annotate-index-tricks
PiperOrigin-RevId: 482026886
2022-10-18 14:44:57 -07:00
jax authors
a168c2d5b5 Merge pull request #12683 from ylamidon:add-scipy-stats-mode
PiperOrigin-RevId: 482025648
2022-10-18 14:38:44 -07:00
jax authors
9818d6562a Merge pull request #12843 from LenaMartens:while-errors
PiperOrigin-RevId: 482008689
2022-10-18 13:46:14 -07:00
jax authors
8ba7911fea Merge pull request #12851 from jakevdp:annotate-util
PiperOrigin-RevId: 482008659
2022-10-18 13:39:15 -07:00
Yann Lamidon
ccbc3059b0 Add JAX equivalent of scipy.stats.mode 2022-10-18 20:45:02 +01:00
Jake VanderPlas
2fe71d29ea [typing] annotate jax._src.numpy.index_tricks 2022-10-18 12:25:33 -07:00
jax authors
66af016df3 Merge pull request #12844 from yejingxin:main
PiperOrigin-RevId: 481970168
2022-10-18 11:12:20 -07:00
Jingxin Ye
59374c1cd8 skip some tests if runtime is stream_executor
DETAILS:
Run on CloudTPU v2-8 and found some tests in debugging_primitives_test
fail due to stream_executor runtime cannot support host callback.
Since host callback only support TFRT, so that skip all those types if
runtime type is stream_executor.

TESTED:
passed unit test on both TPU v2-8 and CPU.
2022-10-18 17:42:33 +00:00
Kuangyuan Chen
d64da3d407 Roll forward with fix: Remove the original python function fun_ from C++ PjitFunction, as the destroying fun_ may yield the thread in some cases, which causes error during deleting the python object of PjitFunction.
PiperOrigin-RevId: 481950912
2022-10-18 10:05:53 -07:00
Yash Katariya
4af9795668 Add default implementation for calculating devices_indices_map to XLACompatibleSharding by lowering to OpSharding and then using its devices_indices_map.
Why? Because users don't have to write the logic for this once they have written the logic for calculating the op_sharding proto.

PiperOrigin-RevId: 481946515
2022-10-18 09:51:08 -07:00
Jake VanderPlas
d60ceeadd0 [typing] annotate util.unzip2 & util.unzip3 2022-10-18 09:47:49 -07:00
jax authors
9749183f51 Merge pull request #12781 from jakevdp:jax-dtypes-types
PiperOrigin-RevId: 481810269
2022-10-17 20:50:30 -07:00
jax authors
ae8eb6f500 Merge pull request #12839 from jakevdp:update-pre-commit
PiperOrigin-RevId: 481767929
2022-10-17 16:27:53 -07:00
jax authors
b36717e159 Merge pull request #12810 from LenaMartens:less-consts2
PiperOrigin-RevId: 481762484
2022-10-17 16:02:06 -07:00
jax authors
8ff293ab75 Fix xla_bridge_test on TPU
DETAILS:
When run xla_bridge_test on TPU v2-8 it raises the follow error about unknown backend tpu, this change set jax_platforms to be "" to eliminate this error.
```
FAILED tests/xla_bridge_test.py::GetBackendTest::test_backend_init_error - RuntimeError: Unable to initialize backend 'tpu': Unknown backend 'tpu' (set JAX_PLATFORMS='' to automatically choose an available backend)
```

TESTED:
pass unit test on both CPU and TPU
PiperOrigin-RevId: 481758573
2022-10-17 15:44:06 -07:00
lenamartens
c2a00a0526 Disallow checkify-of-vmap-of-while. 2022-10-17 23:01:43 +01:00
jax authors
7a99e5194a Merge pull request #12834 from google:dependabot/github_actions/styfle/cancel-workflow-action-0.11.0
PiperOrigin-RevId: 481734398
2022-10-17 14:12:35 -07:00
jax authors
9961872006 Merge pull request #12838 from dan-zheng:fix-typo
PiperOrigin-RevId: 481734354
2022-10-17 14:05:38 -07:00
Jake VanderPlas
ed7a8bbc10 [typing] annotate jax._src.dtypes 2022-10-17 13:49:26 -07:00
Jake VanderPlas
1ed18fa500 add allow_opaque_dtype to dtypes.canonicalize_dtype utility 2022-10-17 13:47:42 -07:00
Kuangyuan Chen
fd2f590b3b Allow uncommitted single device PyArray in C++ pjit path.
PiperOrigin-RevId: 481711690
2022-10-17 12:35:30 -07:00
Jake VanderPlas
87f1a2bac7 CI: update mypy version in pre-commit config 2022-10-17 11:25:14 -07:00
dependabot[bot]
cef5f20dbb
Bump styfle/cancel-workflow-action from 0.10.1 to 0.11.0
Bumps [styfle/cancel-workflow-action](https://github.com/styfle/cancel-workflow-action) from 0.10.1 to 0.11.0.
- [Release notes](https://github.com/styfle/cancel-workflow-action/releases)
- [Commits](https://github.com/styfle/cancel-workflow-action/compare/0.10.1...0.11.0)

---
updated-dependencies:
- dependency-name: styfle/cancel-workflow-action
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2022-10-17 17:18:44 +00:00
jax authors
504b3c1b25 roll forward with the fix: Make params arg in Compiled.call() position-only so that it does not conflict with the keyword args.
PiperOrigin-RevId: 481666211
2022-10-17 09:50:55 -07:00
Dan Zheng
9b0c4e5b9c Fix typo.
decice -> device
2022-10-14 22:12:08 -07:00
Yash Katariya
4cfa01f1cf Improve the error message when users are trying to create SDAs and pass them into pjit/xmap when jax.Array is enabled. The error message tells them exactly what to do to fix the error.
PiperOrigin-RevId: 481282762
2022-10-14 19:43:32 -07:00
Tianjian Lu
69525cd96d [sparse] Make BCSR vmappable.
PiperOrigin-RevId: 481257762
2022-10-14 16:27:24 -07:00
Yash Katariya
63be0c3815 Guard the new channel_handle feature on mlir_api_version for backwards compatibility
PiperOrigin-RevId: 481246613
2022-10-14 15:28:30 -07:00
jax authors
8bd913c6ec Merge pull request #12813 from jakevdp:clarify-typing
PiperOrigin-RevId: 481239620
2022-10-14 14:54:12 -07:00
Yash Katariya
607ce88d19 jax.Array is a unified type that will subsume JAX's DeviceArray, ShardedDeviceArray and GlobalDeviceArray.
This change replaces uses of `local_shards` and `local_data` with `addressable_shards` and `addressable_data` which are compatible with both `GDA` and `jax.Array`.

PiperOrigin-RevId: 481229606
2022-10-14 14:09:01 -07:00
jax authors
6db8e3c872 Merge pull request #12815 from hawkinsp:py311loc
PiperOrigin-RevId: 481218716
2022-10-14 13:22:19 -07:00
Peter Hawkins
ec5bec6157 Include column information in Python locations under Python 3.11.
https://peps.python.org/pep-0657/ means that we now have richer context information, which we can propagate where we use it, for example to the MHLO location in this example:

```
In [2]: jax.jit(lambda x: x + 2).lower(7).compiler_ir().operation.print(enable_debug_info=True)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
module @jit__lambda_ {
  func.func public @main(%arg0: tensor<i32> loc(unknown)) -> tensor<i32> {
    %0 = mhlo.constant dense<2> : tensor<i32> loc(#loc0)
    %1 = mhlo.add %arg0, %0 : tensor<i32> loc(#loc1)
    return %1 : tensor<i32> loc(#loc0)
  } loc(#loc0)
} loc(#loc0)
#loc1 = loc("jit(<lambda>)/jit(main)/add"("<ipython-input-2-525e569b8960>":1:18))
```
2022-10-14 19:14:35 +00:00
Jake VanderPlas
8196a6a9f0 [typing] clarify jax._src.typing 2022-10-14 11:52:04 -07:00
Kuangyuan Chen
38a7582923 roll forward with the fix: Make params arg in Compiled.call() position-only so that it does not conflict with the keyword args.
PiperOrigin-RevId: 481181330
2022-10-14 10:42:15 -07:00
Adam Paszke
746dd5ab13 Add support for MANUAL lowering of ppermute
PiperOrigin-RevId: 481157480
2022-10-14 09:02:55 -07:00
lenamartens
3b24e772e0 Checkify: Only create one init_payload.
While debugging some `const` issues, I noticed a huge list of payloads
included as constants. Not sure why I made this a lambda in the first
place, maybe to avoid calling numpy at a module level? I could make this
a cached call instead.
2022-10-14 17:01:12 +01:00
jax authors
c848efa11b Merge pull request #12808 from hawkinsp:py311
PiperOrigin-RevId: 481155690
2022-10-14 08:56:14 -07:00
jax authors
e8ae355c9d Merge pull request #12806 from hawkinsp:abslpy
PiperOrigin-RevId: 481155619
2022-10-14 08:49:42 -07:00
Peter Hawkins
fb72c38e19 Add Python 3.11 as a compatible Python version. 2022-10-14 14:56:07 +00:00
Peter Hawkins
4988b3117d Drop absl-py as a jaxlib dependency.
absl-py is unused.
2022-10-14 13:57:26 +00:00
jax authors
ed17a5b7e8 Merge pull request #12797 from sudhakarsingh27:ignore_user_warning_for_multiprocess_gpu_tests
PiperOrigin-RevId: 481131763
2022-10-14 06:29:51 -07:00
jax authors
1945208d34 Rollback because of failing tests internally.
PiperOrigin-RevId: 481103002
2022-10-14 03:12:42 -07:00
Parker Schuh
361d3fe553 Add an experimental custom_partitioner API which allows
customizing the partitioning rules.

PiperOrigin-RevId: 481032649
2022-10-13 18:37:21 -07:00
Peter Hawkins
5d07cbef2e Fix lax_autodiff_test to avoid scatters into overlapping ranges.
Fixes a flaky test failure on GPU.

PiperOrigin-RevId: 481031833
2022-10-13 18:30:50 -07:00
Sudhakar
cbcd0cdd04 ignore UserWarning 2022-10-13 17:22:15 -07:00
jax authors
c627b47e54 Merge pull request #12769 from nicholasjng:logging-change
PiperOrigin-RevId: 481020481
2022-10-13 17:16:16 -07:00
jax authors
9589e5fca0 Merge pull request #12793 from alonfnt:block_until_ready-fix
PiperOrigin-RevId: 481003292
2022-10-13 15:50:31 -07:00
Kuangyuan Chen
d082ea0d46 Implement a fast path for pjit AOT in C++ for jax.Array inputs.
PiperOrigin-RevId: 480983807
2022-10-13 14:24:05 -07:00