6627 Commits

Author SHA1 Message Date
jax authors
ce651172e9 Merge pull request #18154 from jakevdp:keyarray
PiperOrigin-RevId: 575268045
2023-10-20 11:09:13 -07:00
jax authors
2bd5ffe464 Merge pull request #18168 from superbobry:no-config-import
PiperOrigin-RevId: 575259071
2023-10-20 10:49:07 -07:00
Jake VanderPlas
8f82f2e66f [typing] regularize types of jax.random API 2023-10-20 10:33:20 -07:00
jax authors
a01e47fcef Merge pull request #18202 from gnecula:dim_var_attrs
PiperOrigin-RevId: 575075716
2023-10-19 19:55:43 -07:00
George Necula
8d5a8583ad [export] Add jax.global_constant MLIR attributes for dimension variable arguments
In presence of shape polymorphism and multi-platorm lowering
we pass the global values for the dimension variables and the
platform index to all inner functions. At the moment, prior to
compilation we run a shape refinement pass to infer which of
the arguments of a function carry such global values.
This inference can yield false positives, e.g., when a
user-defined function is called with a constant int32 as the first argument.

With this change we do not need to infer anymore the arguments
that carry global constants. This is in preparation for a more
reliable implementation of shape refinement.
2023-10-20 04:27:05 +02:00
Jake VanderPlas
53c4de477e [random] deprecate jax.random.default_prng_impl() 2023-10-19 13:59:01 -07:00
jax authors
741b71fe85 Merge pull request #18093 from mattjj:shmap-res-optimization
PiperOrigin-RevId: 574928569
2023-10-19 10:46:08 -07:00
George Necula
82a2793fc9 [export] Improve the calling of multi-platform exported module
Previously we declared the lowering rule for call_exported to be
platform specific. This was correct, but in the case when the
caller function is lowered itself for multiple platforms this results
in multiple copies of the inner called Exported. Now instead we
make the call_exported rule be platform independent and make it
compute the platform index for the called module based on the
platform index in the caller module. This results in a single
copy of the HLO for the called module in the output.
2023-10-19 17:40:46 +02:00
Matthew Johnson
0944010186 output res forwarding optimization for shard_map and jit 2023-10-18 23:56:26 -07:00
jax authors
dfcbfc3915 Merge pull request #18161 from jakevdp:prng-private-impl
PiperOrigin-RevId: 574679979
2023-10-18 18:57:02 -07:00
Rebecca Chen
e144f71c33 Fix or ignore some pytype errors.
PiperOrigin-RevId: 574633158
2023-10-18 16:21:59 -07:00
jax authors
3778265e2e Merge pull request #18126 from niqodea:wrapcauchy
PiperOrigin-RevId: 574572631
2023-10-18 13:18:20 -07:00
jax authors
9435a0ad14 Merge pull request #18138 from mattjj:shmap-axis-env-fix
PiperOrigin-RevId: 574540561
2023-10-18 11:35:02 -07:00
Jake VanderPlas
0da4be5e2a [random] make PRNG impl attributes private 2023-10-18 11:10:47 -07:00
Nicola De Angeli
890b762a3e feat: add wrapcauchy logpdf and pdf 2023-10-18 13:47:10 +02:00
Sergei Lebedev
1079304259 MAINT Do not import the config object in JAX internals
The longer term goal here is to move away from having the config object as
part of the public API and migrate towards module-level functions instead.

Note that we can preserve the dynamic attribute lookup behavior of the
config object via a module-level `__getattr__`
2023-10-18 10:55:13 +01:00
jax authors
59047470ef Merge pull request #18156 from jakevdp:seed_with_impl
PiperOrigin-RevId: 574327521
2023-10-17 19:10:45 -07:00
jax authors
74983770cb Add GetOpSharding to XLA/PjRt utils.
PiperOrigin-RevId: 574287268
2023-10-17 15:46:52 -07:00
Jake VanderPlas
6da4750c3b [random] remove internal uses of deprecated prng.seed_with_impl() 2023-10-17 13:18:08 -07:00
jax authors
d03bbc0d0f random_lax_test: Bump shards number for CPU config.
PiperOrigin-RevId: 574239793
2023-10-17 12:59:12 -07:00
jax authors
2be6019f1c Rollback to fix internal breakage
Reverts 7d203aebfa6206affde207c884b50172e203d177

PiperOrigin-RevId: 574101804
2023-10-17 04:24:15 -07:00
Jieying Luo
497f8091de Re-enable the test as the GPU plugin profiler fixes are in.
Reverts 4cd4f3f3b380755ee4738aa73fd2a834e6a61fd7

PiperOrigin-RevId: 574020446
2023-10-16 21:56:46 -07:00
jax authors
7d203aebfa Merge pull request #18105 from jakevdp:keyarray
PiperOrigin-RevId: 573995089
2023-10-16 19:22:41 -07:00
Peter Hawkins
89b5449882 [XLA:GPU] Fix bug in all-to-all for complex data types.
The multiplier for complex data types wasn't being applied correctly; the chunk_bytes calculation double-applied the multiplier.

Fixes https://github.com/google/jax/issues/18122

PiperOrigin-RevId: 573955671
2023-10-16 16:02:22 -07:00
jax authors
5919c1f33c Merge pull request #18104 from gnecula:multi_jax2tf
PiperOrigin-RevId: 573951693
2023-10-16 15:46:23 -07:00
Tao Wang
93fbf623a4 Fix testProfilerGetFDOProfile.
PiperOrigin-RevId: 573936890
2023-10-16 14:52:44 -07:00
Matthew Johnson
3bfe1d27bc [shard_map] fix axis env extension bug
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2023-10-16 12:36:51 -07:00
Chris Jones
dcc92e3c5d [pallas] dot fixes.
- Check that operands are 2D.
- Set `preferred_element_type`.
- Fix dot output type on GPU.

PiperOrigin-RevId: 573895904
2023-10-16 12:35:43 -07:00
George Necula
b65c1b293b [jax2tf] First step to enable multi-platform native lowering
Enable experiments with jax2tf native serialization for
multiple platforms. This feature is not yet fully functional
but we need this change to enable further testing.

Cleanup some of the places that are specific to single-platform
serialization, e.g., `lowering_platform`, and generalize
them to multiple platforms (`lowering_platforms`).
2023-10-16 07:01:23 -07:00
jax authors
8d700df5a1 Merge pull request #18111 from superbobry:fix-flags-underscore
PiperOrigin-RevId: 573466372
2023-10-14 07:19:37 -07:00
Sergei Lebedev
f9087ab0c6 MAINT Drop underscore from the name of externally-referenced state objects 2023-10-13 21:30:13 +01:00
Jieying Luo
4cd4f3f3b3 Disable pgle_test.py for GPU plugin.
PiperOrigin-RevId: 573304221
2023-10-13 13:25:11 -07:00
Tao Wang
c568110cc8 Set up an API to top trace and fdo profile in memory.
PiperOrigin-RevId: 573276173
2023-10-13 11:34:25 -07:00
Jake VanderPlas
a2623f2888 [random] Avoid references to PRNGKeyArray type
See https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html
2023-10-13 11:10:05 -07:00
Jake VanderPlas
2c64a0ac2c typing: add some type assertions to typing_test 2023-10-13 08:30:08 -07:00
Chris Jones
2bc2e173cb [pallas:gpu] Fix swap Triton lowering.
PiperOrigin-RevId: 573141426
2023-10-13 01:48:20 -07:00
Chris Jones
0da5828d03 [pallas] Simplify Slice.from_slice code and add check for Slice.size.
Slice behaviour for negative and out-of-range start/stop values now matches standard Python behaviour.

PiperOrigin-RevId: 573141218
2023-10-13 01:36:28 -07:00
Yash Katariya
ef20526a76 Return PositionalSharding if input's rank is >= 3 or a NamedSharding if a mesh is available via the context from inspect_array_sharding. Never return GSPMDSharding from inspect_array_sharding.
PiperOrigin-RevId: 573048344
2023-10-12 16:55:12 -07:00
jax authors
65cfe1a5a3 Instrument metrics for the new JAX compilation cache key generation algorithm.
Metrics:
1) '/jax/compilation_cache/cache_hits' to track the number of times the cached executable is successfully returned from a cache read using the new implementation.
2) '/jax/compilation_cache/compile_time_saved_sec' to record the time saved on cache hits using the new implementation.
PiperOrigin-RevId: 573019115
2023-10-12 14:56:02 -07:00
jax authors
4031c60434 Merge pull request #18079 from superbobry:state-objects
PiperOrigin-RevId: 572937114
2023-10-12 10:00:23 -07:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
George Necula
65e86e357f [export] Fix the call_exported in presence of shardings.
Previously, when we call_exported of an Exported module
with shardings, we invoke the right HLO but the enclosing
JAX computation does not know about the shardings of the
called module. This results in errors when invoking the
calling module.

We change call_exported lowering rules to add sharding
constraints for the inputs and the outputs and we add
a check that we call the exported module on the same
number of devices as at export time.
2023-10-12 06:19:20 -07:00
Chris Jones
b6f7441e73 [pallas] Add Triton lowering rule for custom_jvp_call_p.
PiperOrigin-RevId: 572852113
2023-10-12 03:51:32 -07:00
jax authors
b1e9628448 Testing triton integration 2023-09-14
PiperOrigin-RevId: 572808737
2023-10-12 00:35:36 -07:00
Sergei Lebedev
2f70ae700a Migrate another subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.

PiperOrigin-RevId: 572587137
2023-10-11 08:46:06 -07:00
Yash Katariya
fd09b35645 Optimize make_array_from_callback for fully replicated shardings by going via batched_device_put
Before:

```
name                                                      cpu/op
bench_make_array_from_callback_fully_replicated_sharding  467µs ± 3%

name                                                      time/op
bench_make_array_from_callback_fully_replicated_sharding  467µs ± 3%
```

After:

```
name                                                      cpu/op
bench_make_array_from_callback_fully_replicated_sharding  28.1µs ± 2%

name                                                      time/op
bench_make_array_from_callback_fully_replicated_sharding  28.1µs ± 2%
```

PiperOrigin-RevId: 572429822
2023-10-10 19:02:04 -07:00
Sergei Lebedev
5d9c39f4b0 MAINT Use a generator expression with all() and any()
There is no reason to allocate a list only for the purpose of iteration.
2023-10-10 22:33:03 +01:00
Jieying Luo
b81a3e1fd7 Remove calling configure_library_path during jax import and get libtpu path from libtpu_module.get_library_path().
PiperOrigin-RevId: 572306461
2023-10-10 10:59:37 -07:00
Jieying Luo
269d7ce5c1 Remove take_ownership support in DLPack.
When take_ownership is true, the original buffer is marked as deleted and enforced that JAX won't attempt to read or write the buffer. This provides better error checking but at the cost of one more C++ API and two more C APIs. The same semantic can be achieved by not using take_ownership and being careful. Therefore we decided to remove take_ownership support in DLPack.

PiperOrigin-RevId: 572278488
2023-10-10 09:43:02 -07:00
jax authors
4c306381c9 Merge pull request #18023 from jakevdp:make-jaxpr-name
PiperOrigin-RevId: 572094635
2023-10-09 18:22:26 -07:00