4044 Commits

Author SHA1 Message Date
Matthew Johnson
39c2f8b051 fixup from 5415306: remove extraneous lines
also add test
2022-03-11 15:19:10 -08:00
jax authors
cf1161ff8b Merge pull request #9826 from froystig:lax-cleanup2
PiperOrigin-RevId: 433827272
2022-03-10 12:48:34 -08:00
Roy Frostig
64572795b7 remove _select_and_{gather,scatter}_add from public jax.lax module 2022-03-10 10:43:42 -08:00
jax authors
3c49cb5450 Use sharded shape to compute aliasing.
PiperOrigin-RevId: 433762389
2022-03-10 08:41:57 -08:00
Yin Li
c5d4aba2a9 Fix fft dtype for norm='ortho' 2022-03-10 10:39:52 -05:00
Lena Martens
76e1021cad Checkify: Fix empty enabled_errors case wrt user checks.
What looked like a quick win (short-cut on empty enabled_error, don't trace)
was actually a quick bug (user_checks throw error in eager).
2022-03-10 13:33:29 +00:00
Roy Frostig
8f93629e87 remove _convert_element_type from public jax.lax module 2022-03-09 18:46:38 -08:00
Joan Puigcerver
caf094d06b Support gamma distribution with PRNGKeys other than threefry2x32.
PiperOrigin-RevId: 433614014
2022-03-09 17:06:02 -08:00
Jean-Baptiste Lespiau
8a85544537 Add the input avals to Lowered and Compiled.
PiperOrigin-RevId: 433505462
2022-03-09 09:59:45 -08:00
jax authors
beb9900c81 Merge pull request #9807 from jakevdp:coo-sorted-cols
PiperOrigin-RevId: 433397135
2022-03-08 22:46:32 -08:00
Yash Katariya
6e63a90728 Add JAX support for pjit on CPU
PiperOrigin-RevId: 433354147
2022-03-08 17:43:46 -08:00
Jake VanderPlas
3679e0c714 [sparse] track sorted columns for COO GPU lowerings 2022-03-08 17:07:02 -08:00
jax authors
537e35b0fa Merge pull request #9805 from froystig:lax-cleanup
PiperOrigin-RevId: 433347321
2022-03-08 17:05:47 -08:00
Roy Frostig
0cae3160f5 remove _delta from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
6f519576f6 remove _reduce_sum from public jax.lax module 2022-03-08 16:34:26 -08:00
Skye Wanderman-Milne
bcee442390 Improve TPU v2 and v3 mesh_utils.create_device_mesh logic.
* Fixes a bug when a non-3D mesh was requested
* Adds new logic when requesting a single-host mesh
* Extends logic to v2 as well as v3
2022-03-08 22:47:10 +00:00
Jake VanderPlas
43c3bfd324 [sparse]: COO: check for sorted rows before cusparse lowering 2022-03-08 09:21:09 -08:00
jax authors
fdb74ea42a Merge pull request #9785 from froystig:lax-const
PiperOrigin-RevId: 433071851
2022-03-07 16:40:29 -08:00
jax authors
3e93fe01ef Merge pull request #9787 from jakevdp:sparsify-refactor
PiperOrigin-RevId: 433035626
2022-03-07 14:09:53 -08:00
Jake VanderPlas
8c6e001e45 [sparse] refactor internal implementation of sparsify transform 2022-03-07 12:48:03 -08:00
Roy Frostig
f7731bf959 remove _const from public jax.lax module
Modify all internal call sites to use `jax._src.lax.lax._const`.
2022-03-07 12:26:25 -08:00
Jake VanderPlas
424536dcf4 [sparse] change call signature of coo primitive wrappers 2022-03-07 11:26:43 -08:00
Yash Katariya
99a103723c Make mesh_axes on GDA strict by only allowing PartitionSpecs to be consistent with pjit.
PiperOrigin-RevId: 432957496
2022-03-07 08:59:23 -08:00
Jean-Baptiste Lespiau
17f11e05e0 Add accessors on Compiled returning the args and kwargs PyTreeDef working for all transforms.
This also documents the fact that `in_tree` content varies, based on the transform.

PiperOrigin-RevId: 432895923
2022-03-07 02:36:42 -08:00
Roy Frostig
947b7b88e1 re-implement custom_transpose without upfront staging.
Whereas the previous `custom_transpose` implementation would stage its
callable arguments upfront, this one preserves them as callables. For
the time being, this requires callers to additionally supply the target
function's output types at call time.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-04 16:50:51 -08:00
jax authors
2a3f936ffa Merge pull request #9576 from nicholasjng:broadcast-validation
PiperOrigin-RevId: 432531230
2022-03-04 14:21:17 -08:00
Nicholas Junge
56546d3e73 Validate lax.broadcast_shape inputs before control flow execution
This commit addresses previously unvalidated inputs to `jax.lax.broadcast_shapes` by adding a small validation check before control flow execution. A legal input to `lax.broadcast_shapes` hereafter is defined as an input that
1) is a sequence (i.e., implements for..in iteration) of integers and
2) said integers are all non-negative.

In addition, two tests were added to `tests.lax_vmap_test` to check that proper errors are raised when attempting to use illegal inputs with `lax.broadcast_shapes`.
2022-03-04 19:27:52 +01:00
Peter Hawkins
c978df5550 Increase minimum jaxlib version to 0.3.0. 2022-03-04 10:33:03 -05:00
jax authors
cf9a900d78 Merge pull request #9584 from ROCmSoftwarePlatform:rocm_refactor_jaxlib
PiperOrigin-RevId: 432236852
2022-03-03 11:11:02 -08:00
Yash Katariya
72cc567c05 Use the new mesh property instead of the private _global_mesh attribute.
PiperOrigin-RevId: 431815802
2022-03-01 17:43:40 -08:00
jax authors
d9f82f7b9b [JAX] Move experimental.ann.approx_*_k into lax.
Updated docs, tests and the example code snippets.

PiperOrigin-RevId: 431781401
2022-03-01 14:46:33 -08:00
Matthew Johnson
4075b81e00 add regression test for #9731 2022-03-01 13:10:28 -08:00
Reza Rahimi
a0d9d81f92 Update JAX to use new math libraries in ROCm-5.0. 2022-03-01 20:02:15 +00:00
Jake VanderPlas
ed2550999f implement jnp.copy 2022-03-01 11:56:36 -08:00
jax authors
c7508d1f2d Merge pull request #9721 from jakevdp:poisson-nan
PiperOrigin-RevId: 431505317
2022-02-28 12:59:08 -08:00
Yash Katariya
d0cc3395e8 Add block_until_ready method to GDA
PiperOrigin-RevId: 431504594
2022-02-28 12:54:23 -08:00
Jake VanderPlas
2c2773a5f1 jax.random.poisson: fix corner cases 2022-02-28 12:10:47 -08:00
Peter Hawkins
c339330bc1 [XLA:CPU] Relax test tolerances for tests using XLA:CPU.
An upcoming change to XLA:CPU will disable reassociation on floating point operators by default which is an unsound fast math optimization. This change is being made to fix numerical errors in softmax computations caused by reassocation. After that change, we will enable reassociation only in reduction operators where it is very important for performance and the XLA operator contract allows that.

Since this change alters the order of operations, it may cause small numerical changes leading to test failures. This change relaxes test tolerances to make tests pass.

PiperOrigin-RevId: 431453240
2022-02-28 09:26:54 -08:00
Roy Frostig
d636e74626 make xla_executable a property, consistent across executable types
Also test IR and executable-related methods of `Lowered` and
`Compiled`.
2022-02-25 19:05:44 -08:00
Samuel Ainsworth
bf59b7d872 Relax tolerances slightly for MKL.
Fix https://github.com/google/jax/issues/9705.
2022-02-25 22:02:55 +00:00
Jake VanderPlas
1b01865b89 BUG: return numpy arrays for jnp.load() with unsupported dtypes 2022-02-25 09:27:42 -08:00
Du Phan
e28ec78d7a Make negative dimension return consistent results for image.scale_and_translate 2022-02-24 22:34:00 -05:00
jax authors
8372b98c48 [JAX] Move ann.ann_recall back to tests.
The function is simple enough for users to implement their own on the host.

PiperOrigin-RevId: 430696789
2022-02-24 07:23:17 -08:00
Yash Katariya
e2834d89e1 Fix the gpu tests that were failing with Future warning
PiperOrigin-RevId: 430532523
2022-02-23 13:58:28 -08:00
jax authors
c041694538 Merge pull request #8395 from sharadmv:name-stack-mechanism
PiperOrigin-RevId: 430506453
2022-02-23 12:00:20 -08:00
Yash Katariya
687a7630ee Deprecate maps.mesh and replace it with maps.Mesh.
PiperOrigin-RevId: 430489855
2022-02-23 10:47:06 -08:00
Sharad Vikram
1b79caa6bd Add separate mechanism for threading name stacks to the lowering 2022-02-23 09:59:09 -08:00
lenamartens
45d3ddda31 Fix tests and handle cond consts. 2022-02-23 16:11:09 +00:00
Matthew Johnson
4b1d0a466b fixing scan and other control flow
Co-authored-by: Lena Martens <lenamartens@google.com>
2022-02-23 15:39:22 +00:00
jax authors
97b1bd3b65 Merge pull request #9636 from LenaMartens:changelist/429277776
PiperOrigin-RevId: 430268738
2022-02-22 12:21:44 -08:00