8566 Commits

Author SHA1 Message Date
Jake VanderPlas
05faf0f40d Remove deprecated functionality from jax.test_util
PiperOrigin-RevId: 480360504
2022-10-11 08:16:34 -07:00
Adam Paszke
9c994985a4 Support MANUAL collectives in top-level xmaps
It's a bit of a weird use-case, since MANUAL mode is meant for xmaps that
are nested inside pjits, but it doesn't hurt us to support it.

PiperOrigin-RevId: 480342531
2022-10-11 06:47:30 -07:00
jax authors
bf7869707a Merge pull request #12727 from jakevdp:scipy-linalg-types
PiperOrigin-RevId: 480340339
2022-10-11 06:34:13 -07:00
Rishabh Kabra
e45df46a51 Clarify docs for fori_loop, noting that negative or custom increments are not supported.
PiperOrigin-RevId: 480317277
2022-10-11 04:41:10 -07:00
George Necula
9c879adb73 [jax2tf] Implement jax2tf(pjit) for experimental_native_lowering
This implementation is for the case jax2tf.convert(pjit(f_jax)),
that is, the `pjit` appears at the top-level of the function to
be lowered.
2022-10-11 09:45:07 +02:00
Yash Katariya
ff17d3d9fe Add support for calculating the device_assignment when there are no inputs to jit and pjit.
Also look at the shardings inside the jaxpr for `sharding_constraint_p` and `pjit_p` primitives since with `jax.Array`, each `with_sharding_constraint`/`pjit` inside a computation can contain a different sharding (so we need to check if the device_assignment is the same).

Also the output is `committed` if there are jaxpr shardings inside the computation via `with_sharding_constraint`/`pjit` or if any of the inputs are committed or `output_sharding` is specified.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 480256796
2022-10-10 22:08:42 -07:00
Matthew Johnson
df5f7cb8d3 Rolling forward https://github.com/google/jax/pull/12707 after rollback, due to changes in relatively trivial jax.numpy shape validation code failed in some downstream user tests.
PiperOrigin-RevId: 480229237
2022-10-10 18:51:37 -07:00
Yash Katariya
9b3e864731 Add weak_type attribute to Array since it exists on DA (but doesn't exist on SDA).
PiperOrigin-RevId: 480223116
2022-10-10 18:11:11 -07:00
Jake VanderPlas
afe74b4710 [typing] add type annotations to jax.scipy.linalg 2022-10-10 16:54:29 -07:00
Kuangyuan Chen
ec5b1c93d7 Turn on cpp pjit py default
PiperOrigin-RevId: 480185387
2022-10-10 15:01:04 -07:00
Yash Katariya
76d8c08317 Fix the type annotation of return type of device_buffer and device_buffers which return ArrayImpl instead of DeviceArray.
PiperOrigin-RevId: 480181798
2022-10-10 14:45:12 -07:00
Yash Katariya
752c3ffcd9 Lift lambda x: x to the top level so that we don't recompile on every invocation of process_allgather.
PiperOrigin-RevId: 480155482
2022-10-10 12:51:17 -07:00
Tianjian Lu
34eb6ce36b [sparse] BCSR fromdense and todense.
PiperOrigin-RevId: 480141918
2022-10-10 11:54:22 -07:00
jax authors
707b07c1e9 Merge pull request #12697 from jakevdp:lax-slicing-types
PiperOrigin-RevId: 480131675
2022-10-10 11:22:35 -07:00
Jake VanderPlas
124021d720 [typing] add annotations to jax.scipy.ndimage 2022-10-10 09:11:13 -07:00
Jake VanderPlas
76e4a1d40f [typing] add annotations to jax.scipy.fft 2022-10-09 05:18:45 -07:00
Jake VanderPlas
ae9f8eeb0c [typing] annotate lax.slicing 2022-10-09 04:20:46 -07:00
jax authors
9cabd227d7 Copybara import of the project:
--
6d2aaac2454117d54997243714c1a009827707ca by Matthew Johnson <mattjj@google.com>:

implement bint arrays (opaque dtypes), add padding rules

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
PiperOrigin-RevId: 479883102
2022-10-09 01:25:50 -07:00
jax authors
25c6ef7ff9 Merge pull request #12707 from mattjj:djax-slice-sick4
PiperOrigin-RevId: 479876971
2022-10-09 00:23:39 -07:00
Matthew Johnson
6d2aaac245 implement bint arrays (opaque dtypes), add padding rules
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-10-08 22:57:29 -07:00
Yash Katariya
75b2d05989 Make is_fully_replicated and is_fully_addressble a property rather than a method.
Why?

1. Because it's easy to cache a property than a method with only the `self` argument. (See below for article)

2. There's no harm in making them a property because both of them return a bool without any side-effects and are cached (so its fast). Why cache `is_fully_addressable`? Because its very expensive to calculate when you have 1000s of devices.

PiperOrigin-RevId: 479850850
2022-10-08 19:24:12 -07:00
jax authors
674038ca47 Merge pull request #12705 from mattjj:fix-prng-key-array-device-put
PiperOrigin-RevId: 479813689
2022-10-08 11:39:05 -07:00
Matthew Johnson
0a0f492a3d make device_put(prngkeyarray, sharding) for Array
Co-authored-by: Yash Katariya <yashkatariya@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2022-10-07 16:50:16 -07:00
jax authors
e8ba61d82b Merge pull request #12677 from mattjj:jit-pjit-lower-sharding
PiperOrigin-RevId: 479669125
2022-10-07 14:28:51 -07:00
jax authors
58cd8376ee Merge pull request #12675 from mattjj:device-put2
PiperOrigin-RevId: 479660808
2022-10-07 13:49:57 -07:00
Jake VanderPlas
730f9930aa Update scipy version in jax.scipy.fft 2022-10-07 09:57:39 -07:00
George Necula
fb2141fc3b [jax2tf] Allow the use of DimPolynomial with jnp.array and binary operations
Prior to this the user had to explicitly call core.dimension_as_value whenever
using a potentially polymorphic shape in the computation, e.g., x +
core.dimension_as_value(x.shape[0]). Furthermore, jnp.array(x.shape[0])
would fail.

Now, these operations are allowed implicitly,
and the user can call `jnp.array(x.shape[0])`.

This uses an internal extensibility mechanism called __jax_array__
that is experimental and probably not fully implemented.
2022-10-07 17:58:41 +03:00
George Necula
7c7c94c8dd Expand support for __jax_array__ in jnp.array.
This relates to the long discussion in #4725 and #10065.
2022-10-07 14:25:07 +03:00
jax authors
6c70e4dcaa Merge pull request #12691 from mattjj:issue12688
PiperOrigin-RevId: 479507232
2022-10-07 00:10:18 -07:00
Matthew Johnson
076a7348d0 fix -O / PYTHONOPTIMIZE bug
fixes #12688

I'm not sure how to write test cases for PYTHONOPTIMIZE=1 (without growing our
whole test matrix), so I'm leaving this untested...
2022-10-06 23:15:22 -07:00
Tianjian Lu
7a825362fa [sparse] Bug fix in _validate_bcsr.
PiperOrigin-RevId: 479452053
2022-10-06 17:28:52 -07:00
Matthew Johnson
bcca6fb57a add test, small fixes
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-10-06 16:45:34 -07:00
Jake VanderPlas
9104536a98 [typing] add type annotations to lax.linalg functions 2022-10-06 16:19:00 -07:00
Matthew Johnson
ce95ebad94 make device_put work with Sharding 2nd arg
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-10-06 16:14:15 -07:00
Parker Schuh
f49d3d441d Rename Executable to LoadedExecutable within jax.
PiperOrigin-RevId: 479423951
2022-10-06 15:14:33 -07:00
jax authors
df4f399b33 Merge pull request #12678 from jakevdp:average-axis-tuple
PiperOrigin-RevId: 479414002
2022-10-06 14:38:13 -07:00
jax authors
ba2d8035dc Merge pull request #12674 from jakevdp:svd-type
PiperOrigin-RevId: 479413950
2022-10-06 14:31:33 -07:00
Yash Katariya
d174b3dce3 Take shardings as a parameter to deserialize and run_deserialization instead of mesh and pspecs.
PiperOrigin-RevId: 479346552
2022-10-06 10:20:49 -07:00
Jake VanderPlas
32ef3ba37b jnp.average: support tuple axis 2022-10-06 10:20:46 -07:00
Jake VanderPlas
d94327c9e9 Move promote_like_jnp to jax.test_util 2022-10-06 10:20:26 -07:00
Roy Frostig
1b5e9e45ae add input/output sharding to executable protocol 2022-10-05 17:30:56 -07:00
jax authors
d65145fe33 Merge pull request #12532 from jakevdp:reduction-dtype
PiperOrigin-RevId: 479180617
2022-10-05 17:23:30 -07:00
jax authors
3f08f855c5 Merge pull request #12666 from jakevdp:polynomial-typing
PiperOrigin-RevId: 479180604
2022-10-05 17:17:06 -07:00
jax authors
2525aa69f6 Merge pull request #12642 from RissyRan:platform
PiperOrigin-RevId: 479176113
2022-10-05 16:54:11 -07:00
Matthew Johnson
e8dc6d14e4 improve jit(f).lower(duck_args) and pjit(f).lower(duck_args)
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-10-05 15:47:59 -07:00
Jake VanderPlas
c1328a67c8 [typing] overloads for jnp.linalg.svd & jnp.linalg.qr 2022-10-05 14:29:00 -07:00
Yash Katariya
9d12e216d5 Add addressable_shards to SDA and DA as a compatibility API to match with jax.Array. This will aid in transition to jax.Array.
PiperOrigin-RevId: 479115126
2022-10-05 12:35:17 -07:00
jax authors
167004d247 Merge pull request #12654 from jakevdp:fft-typing
PiperOrigin-RevId: 479106034
2022-10-05 12:00:10 -07:00
jax authors
3ec5e456b8 Merge pull request #12660 from hawkinsp:testing
PiperOrigin-RevId: 479098519
2022-10-05 11:33:55 -07:00
Jake VanderPlas
6a348f9666 [typing] add types for jax.numpy.polynomial 2022-10-05 11:23:45 -07:00