11528 Commits

Author SHA1 Message Date
jax authors
73a973eaa8 Merge pull request #18000 from alhridoy:arange-precision-warning
PiperOrigin-RevId: 575328461
2023-10-20 15:10:12 -07:00
alhridoy
63f7cfe04c Add precision warning and workaround to jnp.arange documentation 2023-10-20 15:34:12 -06:00
jax authors
dde17cd5bc Merge pull request #18180 from carlosgmartin:fill_diagonal
PiperOrigin-RevId: 575317151
2023-10-20 14:20:14 -07:00
jax authors
7fdc06fa18 Merge pull request #17783 from gnecula:export_effects
PiperOrigin-RevId: 575310727
2023-10-20 13:55:22 -07:00
carlosgmartin
3cb504c583 Add jax.numpy.fill_diagonal. 2023-10-20 16:47:46 -04:00
George Necula
70f6a9e725 [export] Add support for exporting functions with effects
In presence of ordered effects JAX lowering produces a main
function that takes token
inputs and returns token outputs. Previously, when exporting
such a module, we would wrap the main function with a function
that does not use tokens on inputs and outputs. With this
change we actually leave the token inputs and outputs and
rely on consumers of the exported function to know how to
invoke a function with tokens.

Due to the fact that PJRT does not support passing tokens
as input and output to the top-level function, JAX native
lowering uses dummy bool[0] arrays in lieu of tokens for
the top-level function, and uses stablehlo tokens for the
inner functions. When we export a function for serialization
we want to use stablehlo tokens even at top-level, to enable
calling that function from a larger JAX computation later.

See more details about the calling convention in the
docstring for `export.export`.

We also fix and test multi-platform lowering in presence
of effects.

This introduces serialization version 9, but does not change the
default serialization version. This means that version 9 will not
be used except in tests that specifically override the
serialization version.
2023-10-20 22:27:27 +02:00
Jieying Luo
84283b19f8 Support older TPU which does not have get_library_path.
This should fix the CI failure with older TPU (the oldest supported TPU should be updated to 20230912 as well).

Tested with:
```
 pip install --pre libtpu-nightly==0.1.dev.20230912 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
 ENABLE_PJRT_COMPATIBILITY=true python -c "import jax; print(jax.devices());"
 ENABLE_PJRT_COMPATIBILITY=true python tests/api_test.py
```

PiperOrigin-RevId: 575268243
2023-10-20 11:20:37 -07:00
jax authors
ce651172e9 Merge pull request #18154 from jakevdp:keyarray
PiperOrigin-RevId: 575268045
2023-10-20 11:09:13 -07:00
jax authors
2bd5ffe464 Merge pull request #18168 from superbobry:no-config-import
PiperOrigin-RevId: 575259071
2023-10-20 10:49:07 -07:00
Jake VanderPlas
8f82f2e66f [typing] regularize types of jax.random API 2023-10-20 10:33:20 -07:00
George Necula
c227ce9262 [export] Drop backwards-compatible lowering_platform in export
Use instead `lowering_platforms`.
2023-10-20 11:23:52 +02:00
jax authors
a01e47fcef Merge pull request #18202 from gnecula:dim_var_attrs
PiperOrigin-RevId: 575075716
2023-10-19 19:55:43 -07:00
George Necula
8d5a8583ad [export] Add jax.global_constant MLIR attributes for dimension variable arguments
In presence of shape polymorphism and multi-platorm lowering
we pass the global values for the dimension variables and the
platform index to all inner functions. At the moment, prior to
compilation we run a shape refinement pass to infer which of
the arguments of a function carry such global values.
This inference can yield false positives, e.g., when a
user-defined function is called with a constant int32 as the first argument.

With this change we do not need to infer anymore the arguments
that carry global constants. This is in preparation for a more
reliable implementation of shape refinement.
2023-10-20 04:27:05 +02:00
jax authors
aba40179bc Merge pull request #18197 from jakevdp:dep-default-impl
PiperOrigin-RevId: 575058994
2023-10-19 18:09:36 -07:00
Jake VanderPlas
53c4de477e [random] deprecate jax.random.default_prng_impl() 2023-10-19 13:59:01 -07:00
Mohammed Anany
b3904c0508 Bumping Triton version
PiperOrigin-RevId: 574986341
2023-10-19 13:38:34 -07:00
Yash Katariya
613369fc22 Finish 0.4.19 jax and jaxlib release
PiperOrigin-RevId: 574983871
2023-10-19 13:27:52 -07:00
jax authors
741b71fe85 Merge pull request #18093 from mattjj:shmap-res-optimization
PiperOrigin-RevId: 574928569
2023-10-19 10:46:08 -07:00
jax authors
ea6803b8ff Merge pull request #18190 from gnecula:multi_call_exported
PiperOrigin-RevId: 574916602
2023-10-19 10:27:38 -07:00
jax authors
93a46902e7 Merge pull request #18189 from gnecula:clean_untile
PiperOrigin-RevId: 574916483
2023-10-19 10:16:51 -07:00
jax authors
8c3b956449 Merge pull request #18191 from jakevdp:numpy-core
PiperOrigin-RevId: 574903227
2023-10-19 09:43:03 -07:00
Matthew Johnson
1ce8313ec3 factor out subs_list and subs_list2 2023-10-19 09:32:44 -07:00
Jake VanderPlas
e7bcfcff4c Avoid numpy.core import for NumPy 2.0 2023-10-19 09:23:11 -07:00
Jake VanderPlas
06306274e5 Fix type checking declaration of jax.random.threefry2x32_p
Followup to https://github.com/google/jax/pull/18176

PiperOrigin-RevId: 574891218
2023-10-19 09:05:50 -07:00
George Necula
c36c428721 Cleanup untile lowering to remove platform dependence.
The workaround for cpu and gpu for booleans is not necessary anymore.
2023-10-19 17:42:18 +02:00
George Necula
82a2793fc9 [export] Improve the calling of multi-platform exported module
Previously we declared the lowering rule for call_exported to be
platform specific. This was correct, but in the case when the
caller function is lowered itself for multiple platforms this results
in multiple copies of the inner called Exported. Now instead we
make the call_exported rule be platform independent and make it
compute the platform index for the called module based on the
platform index in the caller module. This results in a single
copy of the HLO for the called module in the output.
2023-10-19 17:40:46 +02:00
Matthew Johnson
0944010186 output res forwarding optimization for shard_map and jit 2023-10-18 23:56:26 -07:00
jax authors
dfcbfc3915 Merge pull request #18161 from jakevdp:prng-private-impl
PiperOrigin-RevId: 574679979
2023-10-18 18:57:02 -07:00
jax authors
793aee5d22 Merge pull request #18176 from jakevdp:fix-random-deprecation
PiperOrigin-RevId: 574631221
2023-10-18 16:00:49 -07:00
jax authors
fb89c09c21 Merge pull request #18159 from jakevdp:export-dtypes
PiperOrigin-RevId: 574628977
2023-10-18 16:00:27 -07:00
jax authors
6aff74e7ff Merge pull request #18162 from jakevdp:physical-aval
PiperOrigin-RevId: 574627738
2023-10-18 15:49:44 -07:00
Jake VanderPlas
b865827d06 [random] deprecate jax.random.threefry_2x32 & threefry2x32_p 2023-10-18 14:42:49 -07:00
jax authors
3778265e2e Merge pull request #18126 from niqodea:wrapcauchy
PiperOrigin-RevId: 574572631
2023-10-18 13:18:20 -07:00
jax authors
88fe0da6d1 Merge pull request #18078 from ROCmSoftwarePlatform:rocm-jax-triton
PiperOrigin-RevId: 574546618
2023-10-18 11:56:01 -07:00
jax authors
9435a0ad14 Merge pull request #18138 from mattjj:shmap-axis-env-fix
PiperOrigin-RevId: 574540561
2023-10-18 11:35:02 -07:00
Jake VanderPlas
0da4be5e2a [random] make PRNG impl attributes private 2023-10-18 11:10:47 -07:00
George Necula
cf6548070d [XlaCallModule] Drop support for dim_args_spec attribute.
This attribute was used to support shape polymorphism in versions
up to and including version 4. Starting on March 28th 2023 with
JAX version 0.4.6 we stopped using this attribute. We are now
beyond the 6 month backward compatibility version and we drop
support for this attribute.

We also increase the minimum supported serialization version to 5.
See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions

PiperOrigin-RevId: 574450204
2023-10-18 05:54:11 -07:00
Nicola De Angeli
890b762a3e feat: add wrapcauchy logpdf and pdf 2023-10-18 13:47:10 +02:00
Sergei Lebedev
1079304259 MAINT Do not import the config object in JAX internals
The longer term goal here is to move away from having the config object as
part of the public API and migrate towards module-level functions instead.

Note that we can preserve the dynamic attribute lookup behavior of the
config object via a module-level `__getattr__`
2023-10-18 10:55:13 +01:00
Rahul Batra
b4b97cd8e8 [ROCm]: Add jax-triton support for ROCm 2023-10-18 07:09:20 +00:00
Jake VanderPlas
563673576e [random] cleanup internal implementation 2023-10-17 15:47:32 -07:00
Jake VanderPlas
2932c7eb91 Set public module for exported jax.dtypes APIs 2023-10-17 15:07:28 -07:00
Jake VanderPlas
6da4750c3b [random] remove internal uses of deprecated prng.seed_with_impl() 2023-10-17 13:18:08 -07:00
Sharad Vikram
86023f55ea [Pallas TPU] Add DMA descriptor abstraction for constructing but not starting DMAs
PiperOrigin-RevId: 574210634
2023-10-17 11:20:05 -07:00
Chris Jones
c16b893600 [pallas:gpu] Simplify broadcast_to, min, max lowering.
PiperOrigin-RevId: 574204406
2023-10-17 11:00:50 -07:00
Adam Paszke
b84ae9821f Make sure we don't filter stack frames of packages that start with a jax prefix
This ended up accidentally setting up filters for jax_triton. This change additionally
adds an opt-in mechanism for paths, that overrides exclusions. We use this to avoid
treating pallas ops implementations as JAX-internal.

PiperOrigin-RevId: 574167963
2023-10-17 09:06:30 -07:00
jax authors
2be6019f1c Rollback to fix internal breakage
Reverts 7d203aebfa6206affde207c884b50172e203d177

PiperOrigin-RevId: 574101804
2023-10-17 04:24:15 -07:00
jax authors
6540810670 Merge pull request #18117 from jakevdp:ad-deprecations
PiperOrigin-RevId: 573996165
2023-10-16 19:33:59 -07:00
jax authors
7d203aebfa Merge pull request #18105 from jakevdp:keyarray
PiperOrigin-RevId: 573995089
2023-10-16 19:22:41 -07:00
Jieying Luo
43fc423aeb Temporarily set TPU_LIBRARY_PATH in xla_bridge.
This will be removed once tpu_tracer removes its dependency on TPU_LIBRARY_PATH (should be within next two weeks).

PiperOrigin-RevId: 573958529
2023-10-16 16:13:05 -07:00