1307 Commits

Author SHA1 Message Date
Yash Katariya
fbc05ee5ac Remove global_arg_shapes from pmap since it was only used for sharded_jit and sharded_jit was removed from JAX a long time ago
PiperOrigin-RevId: 520356179
2023-03-29 09:23:22 -07:00
Peter Hawkins
d9b0f3cd6f Recommend --local_test_jobs in bazel test command line on GPU. 2023-03-29 09:28:53 -04:00
Skye Wanderman-Milne
473d1c3685 Turn on PJRT C API by default.
I forgot that the default setting is actually in jaxlib:
fbe9a80fdb/xla/python/xla_client.py (L135)

To be able to make this change as a jax-only release, I manually set
the env var on Cloud TPU if it isn't already set.
2023-03-28 15:28:13 -07:00
jax authors
af4d4943a7 Merge pull request #8633 from shawwn:2021-11-19/autodidax-fix-jaxpr-subcomp-return-type
PiperOrigin-RevId: 519745476
2023-03-27 09:52:20 -07:00
Ravin Kumar
8c2549519b
Update user_guides.rst
Fix minor typo
2023-03-26 17:21:35 -07:00
Peter Hawkins
6ed66ada0f Delete remote TPU support.
TPU VMs are the only supported way to use TPUs as of JAX 0.4.0.

PiperOrigin-RevId: 519211267
2023-03-24 12:33:33 -07:00
jax authors
383cf41848 Merge pull request #14937 from b0nce:fix-stats
PiperOrigin-RevId: 518888600
2023-03-23 10:01:54 -07:00
jax authors
00e6c73b68 Merge pull request #15114 from JiaYaobo:add_wald_random
PiperOrigin-RevId: 518592428
2023-03-22 09:37:44 -07:00
jiayaobo
f7a14d65d2 add wald random generator
add wald to random.py
2023-03-22 11:06:59 +08:00
Jake VanderPlas
4a9ed3eaa8 Document ShapeDtypeStruct 2023-03-21 13:53:20 -07:00
Misha
83b3f5b759 Fix loc and scale parameters in scipy.logistic. Add CDF and SF for several distributions. 2023-03-21 00:16:13 +01:00
jax authors
fa9d9ae05f Merge pull request #14900 from JiaYaobo:add_rayleigh_random
PiperOrigin-RevId: 518015562
2023-03-20 10:51:49 -07:00
Mark Sandler
bab1098866 Fixes broken examples, and (invalid) comment for PartitionSpec
PiperOrigin-RevId: 517531823
2023-03-17 16:09:45 -07:00
jax authors
c25ea3f0f2 Merge pull request #15064 from jakevdp:sharp-bits-indexing
PiperOrigin-RevId: 517498861
2023-03-17 13:50:14 -07:00
Jake VanderPlas
c7c9cb652e Sharp bits: refer to ndarray.at in out-of-bound indexing discussion 2023-03-17 13:29:05 -07:00
Jake VanderPlas
912d646076 DOC: remove jax 0.4.1 banner from index page 2023-03-17 13:17:47 -07:00
Jake Vanderplas
56267f08dd Copybara import of the project:
--
371c5a45ea08c8e92136761149d0016077a58652 by Jake VanderPlas <jakevdp@google.com>:

pytree doc: add discussion of children vs aux_data

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/15007 from jakevdp:pytree-doc 371c5a45ea08c8e92136761149d0016077a58652
PiperOrigin-RevId: 517149897
2023-03-16 10:00:24 -07:00
Yash Katariya
634035abd7 Remove GDA from JAX since jax.Array is the default type and cannot be disabled anymore as per https://jax.readthedocs.io/en/latest/jax_array_migration.html#how-can-i-disable-jax-array-for-now
PiperOrigin-RevId: 516905931
2023-03-15 13:00:00 -07:00
jiayaobo
05c47033b2 add rayleigh distribution to random.py
add rayleigh distribution to random.py

add rayleigh distribution to random.py

add rayleigh to random.py
2023-03-14 09:55:54 +08:00
jax authors
42ef649e65 Merge pull request #14475 from hawkinsp:openxla
PiperOrigin-RevId: 516316330
2023-03-13 14:04:41 -07:00
Peter Hawkins
172a831219 Switch JAX to use the OpenXLA repository. 2023-03-13 18:38:26 +00:00
Matthew Johnson
3eb9c7a6e7 update docs to remove stale reference to laziness optimization 2023-03-13 10:17:46 -07:00
jax authors
d9395959ca Merge pull request #14837 from nouiz:doc_gpu_custom
PiperOrigin-RevId: 515444284
2023-03-09 14:36:30 -08:00
Frederic Bastien
97fc9b4f23 Add reference to the C code where I was looking for them. Also add some high-level description of what is needed. 2023-03-09 10:26:23 -08:00
Matthew Johnson
1f67351f56 [shard_map] make debug_print work with shard_map, eager and jit 2023-03-08 20:38:03 -08:00
Parker Schuh
d62fc88fb1 Roll back #14792
Breaks tests. lax.sub requires arguments to have the same dtypes, got float32, float64. (Tip: jnp.subtract is a similar function that does automatic type promotion on inputs).

PiperOrigin-RevId: 514897538
2023-03-07 18:31:19 -08:00
Misha
feb9ab33af Fixed loc and scale parameters for logistic distribution. CDF and SF have been added for several distributions, including cauchy, gamma, logistic, chi2 and beta. ISF and PPF have also been added for cauchy and logistic. 2023-03-07 07:56:47 +01:00
Jake VanderPlas
55d9c06267 DOC: update sphinx & sphinx-autodoc-typehints 2023-03-01 11:03:20 -08:00
jax authors
b348fce4dd Merge pull request #14736 from jakevdp:fix-rtd
PiperOrigin-RevId: 513280156
2023-03-01 10:48:54 -08:00
Peter Hawkins
71f120beed Add "Open in Kaggle" buttons to Jupyter notebooks. 2023-03-01 13:15:42 -05:00
Jake VanderPlas
f1b0b6ac65 DOC: fix readthedocs for sphinx-book-theme=1.0 2023-03-01 10:11:31 -08:00
jax authors
fa1ea37704 Merge pull request #14658 from JiaYaobo:chisq_and_f_dist
PiperOrigin-RevId: 513220241
2023-03-01 06:35:34 -08:00
jiayaobo
fdf8ac18d6 add random.chisquare and random.f
add chi2 and F random variables methods

add chi2 and F random variables methods

fix F rv shape broadcasting

fix shape broadcasting
2023-03-01 15:03:50 +08:00
Yash Katariya
52a7701dda Replace usage of {in|out}_axis_resources with {in|out}_shardings
PiperOrigin-RevId: 513040164
2023-02-28 14:29:09 -08:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
jax authors
bcf378f6b4 Merge pull request #14701 from jakevdp:doc-devicearray
PiperOrigin-RevId: 512684443
2023-02-27 11:33:07 -08:00
jax authors
f0d816f899 Merge pull request #14673 from nouiz:gpu_doc
PiperOrigin-RevId: 512669380
2023-02-27 10:49:52 -08:00
Jake VanderPlas
b09b4ba51f DOC: fix jax.numpy.Array discussion 2023-02-27 10:45:06 -08:00
Frederic Bastien
ec817974aa Add a new link instead of a TODO. 2023-02-24 13:54:16 -08:00
Frederic Bastien
86191077ff Small fix as the module name changed. 2023-02-24 12:37:56 -08:00
Matthew Johnson
c22da81d5d fixes from reviewers 2023-02-23 15:06:55 -08:00
Matthew Johnson
141996ec11 add remat tutorial docs 2023-02-23 14:37:52 -08:00
Jake VanderPlas
de673ce297 DOC: improve usage recommendation in jax.typing 2023-02-21 04:58:21 -08:00
Yash Katariya
0ffdeb3de2 Rename jax.sharding.OpShardingSharding to jax.sharding.GSPMDSharding. jax.sharding.OpShardingSharding will be removed in 3 months from Feb 17, 2023.
PiperOrigin-RevId: 510556189
2023-02-17 17:11:06 -08:00
Jake VanderPlas
47ec553c40 DOC: add alternative for pytree initialization 2023-02-17 06:04:33 -08:00
jax authors
bb04686a98 Merge pull request #14503 from sharadmv:fstring-docs
PiperOrigin-RevId: 510181634
2023-02-16 10:28:53 -08:00
jax authors
4eeca92e54 Merge pull request #14482 from gijskoning:patch-1
PiperOrigin-RevId: 510176609
2023-02-16 10:10:34 -08:00
Peter Hawkins
0af9fff5ca Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 510027595
2023-02-15 21:03:03 -08:00
Sharad Vikram
6f1714e57a Add some info in the docs about using jax.debug.print with f-strings 2023-02-15 15:16:37 -08:00
Gijs Koning
71e0d92920
Small update to jax profiling docs
Creating the trace object with context manager also requires a logdir.
2023-02-15 11:33:19 +01:00