13359 Commits

Author SHA1 Message Date
jax authors
41417eed6f Merge pull request #12752 from skye:workspace
PiperOrigin-RevId: 480405935
jaxlib-v0.3.22 jax-v0.3.22 jax-v0.3.22-rc 0.3.22
2022-10-11 11:12:26 -07:00
Skye Wanderman-Milne
63be2201aa Update WORKSPACE and setup.py for jaxlib 0.3.22 release 2022-10-11 17:55:05 +00:00
jax authors
af98d8b00f Merge pull request #12746 from jakevdp:fix-changelog
PiperOrigin-RevId: 480375907
2022-10-11 09:21:28 -07:00
Jake VanderPlas
05faf0f40d Remove deprecated functionality from jax.test_util
PiperOrigin-RevId: 480360504
2022-10-11 08:16:34 -07:00
Adam Paszke
9c994985a4 Support MANUAL collectives in top-level xmaps
It's a bit of a weird use case, since MANUAL mode is meant for xmaps that
are nested inside pjits, but it doesn't hurt us to support it.

PiperOrigin-RevId: 480342531
2022-10-11 06:47:30 -07:00
jax authors
bf7869707a Merge pull request #12727 from jakevdp:scipy-linalg-types
PiperOrigin-RevId: 480340339
2022-10-11 06:34:13 -07:00
Jake VanderPlas
943129419c changelog: add missing github commit links 2022-10-11 06:24:38 -07:00
Rishabh Kabra
e45df46a51 Clarify docs for fori_loop, noting that negative or custom increments are not supported.
PiperOrigin-RevId: 480317277
2022-10-11 04:41:10 -07:00
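The restriction documented above follows from `fori_loop`'s reference semantics, which JAX's docs give as equivalent to a pure-Python loop (a sketch; the index always advances by +1, which is why negative or custom increments aren't expressible):

```python
def fori_loop(lower, upper, body_fun, init_val):
    """Pure-Python reference semantics of jax.lax.fori_loop.

    The loop index always advances by +1 from lower to upper,
    so negative or custom increments are not supported.
    """
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val

# Sum the integers 0..4 by carrying an accumulator through the loop.
total = fori_loop(0, 5, lambda i, acc: acc + i, 0)
```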
jax authors
eade3683af Merge pull request #12324 from gnecula:tf_pjit
PiperOrigin-RevId: 480310826
2022-10-11 04:01:12 -07:00
George Necula
9c879adb73 [jax2tf] Implement jax2tf(pjit) for experimental_native_lowering
This implementation is for the case jax2tf.convert(pjit(f_jax)),
that is, the `pjit` appears at the top-level of the function to
be lowered.
2022-10-11 09:45:07 +02:00
Yash Katariya
ff17d3d9fe Add support for calculating the device_assignment when there are no inputs to jit and pjit.
Also look at the shardings inside the jaxpr for `sharding_constraint_p` and `pjit_p` primitives since with `jax.Array`, each `with_sharding_constraint`/`pjit` inside a computation can contain a different sharding (so we need to check if the device_assignment is the same).

Also the output is `committed` if there are jaxpr shardings inside the computation via `with_sharding_constraint`/`pjit` or if any of the inputs are committed or `output_sharding` is specified.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 480256796
2022-10-10 22:08:42 -07:00
Matthew Johnson
df5f7cb8d3 Rolling forward https://github.com/google/jax/pull/12707 after a rollback; the rollback happened because changes to relatively trivial jax.numpy shape validation code failed some downstream user tests.
PiperOrigin-RevId: 480229237
2022-10-10 18:51:37 -07:00
Yash Katariya
9b3e864731 Add weak_type attribute to Array since it exists on DA (but doesn't exist on SDA).
PiperOrigin-RevId: 480223116
2022-10-10 18:11:11 -07:00
Jake VanderPlas
afe74b4710 [typing] add type annotations to jax.scipy.linalg 2022-10-10 16:54:29 -07:00
Kuangyuan Chen
ec5b1c93d7 Turn on cpp pjit by default
PiperOrigin-RevId: 480185387
2022-10-10 15:01:04 -07:00
Yash Katariya
76d8c08317 Fix the return type annotations of `device_buffer` and `device_buffers`, which return ArrayImpl instead of DeviceArray.
PiperOrigin-RevId: 480181798
2022-10-10 14:45:12 -07:00
Peter Hawkins
a3a2206d49 Fix compilation failure in lapack kernel under msan.
a_size wasn't defined, but the error would only be caught under memory sanitizer.

PiperOrigin-RevId: 480176934
2022-10-10 14:24:59 -07:00
Peter Hawkins
2246887f7b Add input-output aliasing annotations for LAPACK calls on CPU.
PiperOrigin-RevId: 480156067
2022-10-10 12:57:29 -07:00
Yash Katariya
752c3ffcd9 Lift `lambda x: x` to the top level so that we don't recompile on every invocation of `process_allgather`.
PiperOrigin-RevId: 480155482
2022-10-10 12:51:17 -07:00
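The pitfall this commit fixes can be illustrated with plain `functools.lru_cache`, which, like jit's compilation cache, keys on function identity: a lambda created inline is a fresh object on every call and never hits the cache (a sketch of the caching behavior, not JAX's actual cache):

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def compile_fn(fn):
    # Stand-in for an expensive compilation step keyed on function identity.
    return f"compiled:{id(fn)}"

# A lambda defined at the call site is a new object each time: cache miss.
miss_a = compile_fn(lambda x: x)
miss_b = compile_fn(lambda x: x)

# A function lifted to the top level is one object, reused: cache hit.
_identity = lambda x: x
hit_a = compile_fn(_identity)
hit_b = compile_fn(_identity)
```

Lifting the lambda to module scope (as the commit does) makes every call key on the same object, so the compiled result is reused.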
jax authors
90e9abe278 Merge pull request #12722 from hawkinsp:tests
PiperOrigin-RevId: 480149233
2022-10-10 12:22:52 -07:00
Peter Hawkins
2ba0396ddb Add changes accidentally omitted from
https://github.com/google/jax/pull/12717
2022-10-10 19:11:58 +00:00
Tianjian Lu
34eb6ce36b [sparse] BCSR fromdense and todense.
PiperOrigin-RevId: 480141918
2022-10-10 11:54:22 -07:00
Peter Hawkins
c657449528 Copybara import of the project:
--
d39bdefb33a19e407c352df27fb04127f4fe8a1d by Peter Hawkins <phawkins@google.com>:

Migrate more tests from jtu.cases_from_list to jtu.sample_product.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/12717 from hawkinsp:sampletest d39bdefb33a19e407c352df27fb04127f4fe8a1d
PiperOrigin-RevId: 480136538
2022-10-10 11:35:32 -07:00
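The idea behind `jtu.sample_product` is to draw a fixed-size random sample from the Cartesian product of parameter spaces instead of enumerating every combination, as `cases_from_list` did. A stdlib-only sketch of that idea (hypothetical `sample_product` helper, not JAX's implementation):

```python
import itertools
import random

def sample_product(num_samples, seed=0, **param_spaces):
    """Draw a deterministic random sample of parameter combinations.

    Sketch of the sampling idea: build the full Cartesian product of the
    named parameter spaces, then keep only num_samples of them.
    """
    names = sorted(param_spaces)
    combos = list(itertools.product(*(param_spaces[n] for n in names)))
    rng = random.Random(seed)
    chosen = rng.sample(combos, min(num_samples, len(combos)))
    return [dict(zip(names, c)) for c in chosen]

# 2 shapes x 2 dtypes = 4 combinations; sampling 4 keeps them all.
cases = sample_product(4, shape=[(2, 3), (4,)], dtype=["float32", "int32"])
```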
Peter Hawkins
22cd50535b Reapply: Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

It turns out some users are relying on the API contract of the custom calls within serialized HLO remaining stable. For the moment, we reapply only the Python changes. The C++ code is already tolerant of both aliased and unaliased outputs, and this gets us all the benefit of saving a copy. We can break backwards compatibility on the serialized HLO after users upgrade their saved HLO to the aliased version.

PiperOrigin-RevId: 480134780
2022-10-10 11:29:18 -07:00
jax authors
707b07c1e9 Merge pull request #12697 from jakevdp:lax-slicing-types
PiperOrigin-RevId: 480131675
2022-10-10 11:22:35 -07:00
jax authors
ad4dcc4817 Merge pull request #12700 from jakevdp:scipy-types
PiperOrigin-RevId: 480131410
2022-10-10 11:15:46 -07:00
Jake VanderPlas
124021d720 [typing] add annotations to jax.scipy.ndimage 2022-10-10 09:11:13 -07:00
Jake VanderPlas
76e4a1d40f [typing] add annotations to jax.scipy.fft 2022-10-09 05:18:45 -07:00
Jake VanderPlas
ae9f8eeb0c [typing] annotate lax.slicing 2022-10-09 04:20:46 -07:00
jax authors
9cabd227d7 Copybara import of the project:
--
6d2aaac2454117d54997243714c1a009827707ca by Matthew Johnson <mattjj@google.com>:

implement bint arrays (opaque dtypes), add padding rules

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
PiperOrigin-RevId: 479883102
2022-10-09 01:25:50 -07:00
jax authors
25c6ef7ff9 Merge pull request #12707 from mattjj:djax-slice-sick4
PiperOrigin-RevId: 479876971
2022-10-09 00:23:39 -07:00
Matthew Johnson
6d2aaac245 implement bint arrays (opaque dtypes), add padding rules
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-10-08 22:57:29 -07:00
Yash Katariya
75b2d05989 Make is_fully_replicated and is_fully_addressable properties rather than methods.
Why?

1. Because it's easier to cache a property than a method with only the `self` argument.

2. There's no harm in making them properties because both return a bool without side effects and are cached (so it's fast). Why cache `is_fully_addressable`? Because it's very expensive to calculate when you have thousands of devices.

PiperOrigin-RevId: 479850850
2022-10-08 19:24:12 -07:00
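The caching rationale above can be sketched with `functools.cached_property`, which memoizes a zero-argument computation on the instance after the first access (hypothetical `Sharding` class for illustration, not JAX's actual implementation):

```python
from functools import cached_property

class Sharding:
    """Toy stand-in for a sharding object with an expensive query."""

    def __init__(self, num_devices):
        self.num_devices = num_devices
        self.compute_count = 0  # tracks how often the body actually runs

    @cached_property
    def is_fully_addressable(self):
        # Expensive with thousands of devices, so compute once and cache
        # the bool on the instance; later accesses skip this body entirely.
        self.compute_count += 1
        return self.num_devices == 1

s = Sharding(1)
first = s.is_fully_addressable   # computes and caches
second = s.is_fully_addressable  # served from the cache
```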
jax authors
674038ca47 Merge pull request #12705 from mattjj:fix-prng-key-array-device-put
PiperOrigin-RevId: 479813689
2022-10-08 11:39:05 -07:00
Matthew Johnson
0a0f492a3d make device_put(prngkeyarray, sharding) for Array
Co-authored-by: Yash Katariya <yashkatariya@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2022-10-07 16:50:16 -07:00
Peter Hawkins
2693afa263 Revert: Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

This seems to be causing some test failures downstream, so reverting this for the moment until I can debug them.

PiperOrigin-RevId: 479670565
2022-10-07 14:36:24 -07:00
jax authors
e8ba61d82b Merge pull request #12677 from mattjj:jit-pjit-lower-sharding
PiperOrigin-RevId: 479669125
2022-10-07 14:28:51 -07:00
jax authors
58cd8376ee Merge pull request #12675 from mattjj:device-put2
PiperOrigin-RevId: 479660808
2022-10-07 13:49:57 -07:00
jax authors
33f18fe27b Merge pull request #12703 from ROCmSoftwarePlatform:rocm-ci-update2
PiperOrigin-RevId: 479652964
2022-10-07 13:12:57 -07:00
Jason Furmanek
34f6646050 Add default setting for TENSORFLOW_ROCM_COMMIT 2022-10-07 19:57:53 +00:00
Peter Hawkins
93b839ace4 Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

PiperOrigin-RevId: 479642543
2022-10-07 12:22:04 -07:00
jax authors
0995f7909a Merge pull request #12698 from jakevdp:scipy-fft-todo
PiperOrigin-RevId: 479615134
2022-10-07 10:29:18 -07:00
Jake VanderPlas
730f9930aa Update scipy version in jax.scipy.fft 2022-10-07 09:57:39 -07:00
jax authors
fde444e735 Merge pull request #12695 from gnecula:tf_dimension
PiperOrigin-RevId: 479602588
2022-10-07 09:38:45 -07:00
George Necula
fb2141fc3b [jax2tf] Allow the use of DimPolynomial with jnp.array and binary operations
Prior to this, the user had to explicitly call `core.dimension_as_value` whenever
using a potentially polymorphic shape in a computation, e.g.,
`x + core.dimension_as_value(x.shape[0])`. Furthermore, `jnp.array(x.shape[0])`
would fail.

Now these operations are allowed implicitly,
and the user can call `jnp.array(x.shape[0])`.

This uses an internal extensibility mechanism called `__jax_array__`
that is experimental and probably not fully implemented.
2022-10-07 17:58:41 +03:00
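The `__jax_array__` mechanism mentioned above is a duck-typing hook: when `jnp.array` receives an object that defines `__jax_array__`, it calls that method to obtain an array value. A pure-Python sketch of the dispatch pattern (hypothetical `DimPoly` and `to_array` names; not JAX's real classes):

```python
class DimPoly:
    """Toy stand-in for a symbolic dimension that opts into conversion."""

    def __init__(self, value):
        self._value = value

    def __jax_array__(self):
        # The hook jnp.array consults; here it just unwraps the value.
        return self._value

def to_array(x):
    # Sketch of the dispatch jnp.array performs: prefer the object's
    # __jax_array__ method when it defines one, else use x as-is.
    if hasattr(x, "__jax_array__"):
        return x.__jax_array__()
    return x
```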
jax authors
363cc124e3 Merge pull request #12197 from ROCmSoftwarePlatform:fixedRocmUnitTestsSkip
PiperOrigin-RevId: 479566021
2022-10-07 06:36:11 -07:00
George Necula
7c7c94c8dd Expand support for __jax_array__ in jnp.array.
This relates to the long discussion in #4725 and #10065.
2022-10-07 14:25:07 +03:00
jax authors
6c70e4dcaa Merge pull request #12691 from mattjj:issue12688
PiperOrigin-RevId: 479507232
2022-10-07 00:10:18 -07:00
Matthew Johnson
076a7348d0 fix -O / PYTHONOPTIMIZE bug
fixes #12688

I'm not sure how to write test cases for PYTHONOPTIMIZE=1 (without growing our
whole test matrix), so I'm leaving this untested...
2022-10-06 23:15:22 -07:00
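Bugs in this class typically come from logic guarded by `assert` or `__debug__`, which the interpreter strips at optimization level 1 and above. A minimal sketch of the assert-free validation pattern that survives -O (hypothetical `check_rank` helper, not the actual fix):

```python
import sys

def check_rank(shape, expected):
    # Under -O / PYTHONOPTIMIZE, `assert` statements are compiled away,
    # so validation that must always run needs an explicit raise instead.
    if len(shape) != expected:
        raise ValueError(f"expected rank {expected}, got shape {shape}")
    return shape

# sys.flags.optimize reports the active optimization level (0, 1, or 2),
# which is how code can detect it is running under -O.
optimize_level = sys.flags.optimize
```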
Tianjian Lu
7a825362fa [sparse] Bug fix in _validate_bcsr.
PiperOrigin-RevId: 479452053
2022-10-06 17:28:52 -07:00