Peter Hawkins
5617a02fa4
Remove JAX custom call implementation of batched triangular solve.
...
XLA supports batched triangular solve on GPU and has since February 2022, which is older than the minimum jaxlib version. We can therefore delete our implementation and just use XLA's implementation.
PiperOrigin-RevId: 482031830
2022-10-18 15:04:14 -07:00
Yash Katariya
d20b9fa498
Always use .device_buffers
for jax.Array because .device_buffer
can raise an error if there is more than 1 buffer present in the Array.
...
PiperOrigin-RevId: 482028624
2022-10-18 14:51:07 -07:00
jax authors
6aafb86758
Merge pull request #12853 from jakevdp:annotate-index-tricks
...
PiperOrigin-RevId: 482026886
2022-10-18 14:44:57 -07:00
jax authors
a168c2d5b5
Merge pull request #12683 from ylamidon:add-scipy-stats-mode
...
PiperOrigin-RevId: 482025648
2022-10-18 14:38:44 -07:00
jax authors
9818d6562a
Merge pull request #12843 from LenaMartens:while-errors
...
PiperOrigin-RevId: 482008689
2022-10-18 13:46:14 -07:00
jax authors
8ba7911fea
Merge pull request #12851 from jakevdp:annotate-util
...
PiperOrigin-RevId: 482008659
2022-10-18 13:39:15 -07:00
Yann Lamidon
ccbc3059b0
Add JAX equivalent of scipy.stats.mode
2022-10-18 20:45:02 +01:00
Jake VanderPlas
2fe71d29ea
[typing] annotate jax._src.numpy.index_tricks
2022-10-18 12:25:33 -07:00
jax authors
66af016df3
Merge pull request #12844 from yejingxin:main
...
PiperOrigin-RevId: 481970168
2022-10-18 11:12:20 -07:00
Jingxin Ye
59374c1cd8
skip some tests if runtime is stream_executor
...
DETAILS:
Run on CloudTPU v2-8 and found some tests in debugging_primitives_test
fail due to stream_executor runtime cannot support host callback.
Since host callback only support TFRT, so that skip all those types if
runtime type is stream_executor.
TESTED:
passed unit test on both TPU v2-8 and CPU.
2022-10-18 17:42:33 +00:00
Kuangyuan Chen
d64da3d407
Roll forward with fix: Remove the original python function fun_
from C++ PjitFunction, as the destroying fun_
may yield the thread in some cases, which causes error during deleting the python object of PjitFunction.
...
PiperOrigin-RevId: 481950912
2022-10-18 10:05:53 -07:00
Yash Katariya
4af9795668
Add default implementation for calculating devices_indices_map
to XLACompatibleSharding
by lowering to OpSharding and then using its devices_indices_map
.
...
Why? Because users don't have to write the logic for this once they have written the logic for calculating the op_sharding proto.
PiperOrigin-RevId: 481946515
2022-10-18 09:51:08 -07:00
Jake VanderPlas
d60ceeadd0
[typing] annotate util.unzip2 & util.unzip3
2022-10-18 09:47:49 -07:00
jax authors
9749183f51
Merge pull request #12781 from jakevdp:jax-dtypes-types
...
PiperOrigin-RevId: 481810269
2022-10-17 20:50:30 -07:00
jax authors
ae8eb6f500
Merge pull request #12839 from jakevdp:update-pre-commit
...
PiperOrigin-RevId: 481767929
2022-10-17 16:27:53 -07:00
jax authors
b36717e159
Merge pull request #12810 from LenaMartens:less-consts2
...
PiperOrigin-RevId: 481762484
2022-10-17 16:02:06 -07:00
jax authors
8ff293ab75
Fix xla_bridge_test on TPU
...
DETAILS:
When run xla_bridge_test on TPU v2-8 it raises the follow error about unknown backend tpu, this change set jax_platforms to be "" to eliminate this error.
```
FAILED tests/xla_bridge_test.py::GetBackendTest::test_backend_init_error - RuntimeError: Unable to initialize backend 'tpu': Unknown backend 'tpu' (set JAX_PLATFORMS='' to automatically choose an available backend)
```
TESTED:
pass unit test on both CPU and TPU
PiperOrigin-RevId: 481758573
2022-10-17 15:44:06 -07:00
lenamartens
c2a00a0526
Disallow checkify-of-vmap-of-while.
2022-10-17 23:01:43 +01:00
jax authors
7a99e5194a
Merge pull request #12834 from google:dependabot/github_actions/styfle/cancel-workflow-action-0.11.0
...
PiperOrigin-RevId: 481734398
2022-10-17 14:12:35 -07:00
jax authors
9961872006
Merge pull request #12838 from dan-zheng:fix-typo
...
PiperOrigin-RevId: 481734354
2022-10-17 14:05:38 -07:00
Jake VanderPlas
ed7a8bbc10
[typing] annotate jax._src.dtypes
2022-10-17 13:49:26 -07:00
Jake VanderPlas
1ed18fa500
add allow_opaque_dtype to dtypes.canonicalize_dtype utility
2022-10-17 13:47:42 -07:00
Kuangyuan Chen
fd2f590b3b
Allow uncommitted single device PyArray in C++ pjit path.
...
PiperOrigin-RevId: 481711690
2022-10-17 12:35:30 -07:00
Jake VanderPlas
87f1a2bac7
CI: update mypy version in pre-commit config
2022-10-17 11:25:14 -07:00
dependabot[bot]
cef5f20dbb
Bump styfle/cancel-workflow-action from 0.10.1 to 0.11.0
...
Bumps [styfle/cancel-workflow-action](https://github.com/styfle/cancel-workflow-action ) from 0.10.1 to 0.11.0.
- [Release notes](https://github.com/styfle/cancel-workflow-action/releases )
- [Commits](https://github.com/styfle/cancel-workflow-action/compare/0.10.1...0.11.0 )
---
updated-dependencies:
- dependency-name: styfle/cancel-workflow-action
dependency-type: direct:production
update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>
2022-10-17 17:18:44 +00:00
jax authors
504b3c1b25
roll forward with the fix: Make params
arg in Compiled.call() position-only so that it does not conflict with the keyword args.
...
PiperOrigin-RevId: 481666211
2022-10-17 09:50:55 -07:00
Dan Zheng
9b0c4e5b9c
Fix typo.
...
decice -> device
2022-10-14 22:12:08 -07:00
Yash Katariya
4cfa01f1cf
Improve the error message when users are trying to create SDAs and pass them into pjit/xmap when jax.Array is enabled. The error message tells them exactly what to do to fix the error.
...
PiperOrigin-RevId: 481282762
2022-10-14 19:43:32 -07:00
Tianjian Lu
69525cd96d
[sparse] Make BCSR vmappable.
...
PiperOrigin-RevId: 481257762
2022-10-14 16:27:24 -07:00
Yash Katariya
63be0c3815
Guard the new channel_handle feature on mlir_api_version for backwards compatibility
...
PiperOrigin-RevId: 481246613
2022-10-14 15:28:30 -07:00
jax authors
8bd913c6ec
Merge pull request #12813 from jakevdp:clarify-typing
...
PiperOrigin-RevId: 481239620
2022-10-14 14:54:12 -07:00
Yash Katariya
607ce88d19
jax.Array is a unified type that will subsume JAX's DeviceArray, ShardedDeviceArray and GlobalDeviceArray.
...
This change replaces uses of `local_shards` and `local_data` with `addressable_shards` and `addressable_data` which are compatible with both `GDA` and `jax.Array`.
PiperOrigin-RevId: 481229606
2022-10-14 14:09:01 -07:00
jax authors
6db8e3c872
Merge pull request #12815 from hawkinsp:py311loc
...
PiperOrigin-RevId: 481218716
2022-10-14 13:22:19 -07:00
Peter Hawkins
ec5bec6157
Include column information in Python locations under Python 3.11.
...
https://peps.python.org/pep-0657/ means that we now have richer context information, which we can propagate where we use it, for example to the MHLO location in this example:
```
In [2]: jax.jit(lambda x: x + 2).lower(7).compiler_ir().operation.print(enable_debug_info=True)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
module @jit__lambda_ {
func.func public @main(%arg0: tensor<i32> loc(unknown)) -> tensor<i32> {
%0 = mhlo.constant dense<2> : tensor<i32> loc(#loc0)
%1 = mhlo.add %arg0, %0 : tensor<i32> loc(#loc1)
return %1 : tensor<i32> loc(#loc0)
} loc(#loc0)
} loc(#loc0)
#loc1 = loc("jit(<lambda>)/jit(main)/add"("<ipython-input-2-525e569b8960>":1:18))
```
2022-10-14 19:14:35 +00:00
Jake VanderPlas
8196a6a9f0
[typing] clarify jax._src.typing
2022-10-14 11:52:04 -07:00
Kuangyuan Chen
38a7582923
roll forward with the fix: Make params
arg in Compiled.call() position-only so that it does not conflict with the keyword args.
...
PiperOrigin-RevId: 481181330
2022-10-14 10:42:15 -07:00
Adam Paszke
746dd5ab13
Add support for MANUAL lowering of ppermute
...
PiperOrigin-RevId: 481157480
2022-10-14 09:02:55 -07:00
lenamartens
3b24e772e0
Checkify: Only create one init_payload.
...
While debugging some `const` issues, I noticed a huge list of payloads
included as constants. Not sure why I made this a lambda in the first
place, maybe to avoid calling numpy at a module level? I could make this
a cached call instead.
2022-10-14 17:01:12 +01:00
jax authors
c848efa11b
Merge pull request #12808 from hawkinsp:py311
...
PiperOrigin-RevId: 481155690
2022-10-14 08:56:14 -07:00
jax authors
e8ae355c9d
Merge pull request #12806 from hawkinsp:abslpy
...
PiperOrigin-RevId: 481155619
2022-10-14 08:49:42 -07:00
Peter Hawkins
fb72c38e19
Add Python 3.11 as a compatible Python version.
2022-10-14 14:56:07 +00:00
Peter Hawkins
4988b3117d
Drop absl-py as a jaxlib dependency.
...
absl-py is unused.
2022-10-14 13:57:26 +00:00
jax authors
ed17a5b7e8
Merge pull request #12797 from sudhakarsingh27:ignore_user_warning_for_multiprocess_gpu_tests
...
PiperOrigin-RevId: 481131763
2022-10-14 06:29:51 -07:00
jax authors
1945208d34
Rollback because of failing tests internally.
...
PiperOrigin-RevId: 481103002
2022-10-14 03:12:42 -07:00
Parker Schuh
361d3fe553
Add an experimental custom_partitioner API which allows
...
customizing the partitioning rules.
PiperOrigin-RevId: 481032649
2022-10-13 18:37:21 -07:00
Peter Hawkins
5d07cbef2e
Fix lax_autodiff_test to avoid scatters into overlapping ranges.
...
Fixes a flaky test failure on GPU.
PiperOrigin-RevId: 481031833
2022-10-13 18:30:50 -07:00
Sudhakar
cbcd0cdd04
ignore UserWarning
2022-10-13 17:22:15 -07:00
jax authors
c627b47e54
Merge pull request #12769 from nicholasjng:logging-change
...
PiperOrigin-RevId: 481020481
2022-10-13 17:16:16 -07:00
jax authors
9589e5fca0
Merge pull request #12793 from alonfnt:block_until_ready-fix
...
PiperOrigin-RevId: 481003292
2022-10-13 15:50:31 -07:00
Kuangyuan Chen
d082ea0d46
Implement a fast path for pjit AOT in C++ for jax.Array inputs.
...
PiperOrigin-RevId: 480983807
2022-10-13 14:24:05 -07:00