Matthew Johnson
39c2f8b051
fixup from 5415306: remove extraneous lines
...
also add test
2022-03-11 15:19:10 -08:00
jax authors
cf1161ff8b
Merge pull request #9826 from froystig:lax-cleanup2
...
PiperOrigin-RevId: 433827272
2022-03-10 12:48:34 -08:00
Roy Frostig
64572795b7
remove _select_and_{gather,scatter}_add
from public jax.lax
module
2022-03-10 10:43:42 -08:00
jax authors
3c49cb5450
Use sharded shape to compute aliasing.
...
PiperOrigin-RevId: 433762389
2022-03-10 08:41:57 -08:00
Yin Li
c5d4aba2a9
Fix fft dtype for norm='ortho'
2022-03-10 10:39:52 -05:00
Lena Martens
76e1021cad
Checkify: Fix empty enabled_errors case wrt user checks.
...
What looked like a quick win (short-cut on empty enabled_error, don't trace)
was actually a quick bug (user_checks throw error in eager).
2022-03-10 13:33:29 +00:00
Roy Frostig
8f93629e87
remove _convert_element_type
from public jax.lax
module
2022-03-09 18:46:38 -08:00
Joan Puigcerver
caf094d06b
Support gamma distribution with PRNGKeys other than threefry2x32.
...
PiperOrigin-RevId: 433614014
2022-03-09 17:06:02 -08:00
Jean-Baptiste Lespiau
8a85544537
Add the input avals to Lowered and Compiled.
...
PiperOrigin-RevId: 433505462
2022-03-09 09:59:45 -08:00
jax authors
beb9900c81
Merge pull request #9807 from jakevdp:coo-sorted-cols
...
PiperOrigin-RevId: 433397135
2022-03-08 22:46:32 -08:00
Yash Katariya
6e63a90728
Add JAX support for pjit on CPU
...
PiperOrigin-RevId: 433354147
2022-03-08 17:43:46 -08:00
Jake VanderPlas
3679e0c714
[sparse] track sorted columns for COO GPU lowerings
2022-03-08 17:07:02 -08:00
jax authors
537e35b0fa
Merge pull request #9805 from froystig:lax-cleanup
...
PiperOrigin-RevId: 433347321
2022-03-08 17:05:47 -08:00
Roy Frostig
0cae3160f5
remove _delta
from public jax.lax
module
2022-03-08 16:34:26 -08:00
Roy Frostig
6f519576f6
remove _reduce_sum
from public jax.lax
module
2022-03-08 16:34:26 -08:00
Skye Wanderman-Milne
bcee442390
Improve TPU v2 and v3 mesh_utils.create_device_mesh logic.
...
* Fixes a bug when a non-3D mesh was requested
* Adds new logic when requesting a single-host mesh
* Extends logic to v2 as well as v3
2022-03-08 22:47:10 +00:00
Jake VanderPlas
43c3bfd324
[sparse]: COO: check for sorted rows before cusparse lowering
2022-03-08 09:21:09 -08:00
jax authors
fdb74ea42a
Merge pull request #9785 from froystig:lax-const
...
PiperOrigin-RevId: 433071851
2022-03-07 16:40:29 -08:00
jax authors
3e93fe01ef
Merge pull request #9787 from jakevdp:sparsify-refactor
...
PiperOrigin-RevId: 433035626
2022-03-07 14:09:53 -08:00
Jake VanderPlas
8c6e001e45
[sparse] refactor internal implementation of sparsify transform
2022-03-07 12:48:03 -08:00
Roy Frostig
f7731bf959
remove _const
from public jax.lax
module
...
Modify all internal call sites to use `jax._src.lax.lax._const`.
2022-03-07 12:26:25 -08:00
Jake VanderPlas
424536dcf4
[sparse] change call signature of coo primitive wrappers
2022-03-07 11:26:43 -08:00
Yash Katariya
99a103723c
Make mesh_axes
on GDA strict by only allowing PartitionSpecs to be consistent with pjit.
...
PiperOrigin-RevId: 432957496
2022-03-07 08:59:23 -08:00
Jean-Baptiste Lespiau
17f11e05e0
Add accessors on Compiled returning the args and kwargs PyTreeDef
working for all transforms.
...
This also documents the fact that `in_tree` content varies, based on the transform.
PiperOrigin-RevId: 432895923
2022-03-07 02:36:42 -08:00
Roy Frostig
947b7b88e1
re-implement custom_transpose
without upfront staging.
...
Whereas the previous `custom_transpose` implementation would stage its
callable arguments upfront, this one preserves them as callables. For
the time being, this requires callers to additionally supply the target
function's output types at call time.
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-04 16:50:51 -08:00
jax authors
2a3f936ffa
Merge pull request #9576 from nicholasjng:broadcast-validation
...
PiperOrigin-RevId: 432531230
2022-03-04 14:21:17 -08:00
Nicholas Junge
56546d3e73
Validate lax.broadcast_shape
inputs before control flow execution
...
This commit addresses previously unvalidated inputs to `jax.lax.broadcast_shapes` by adding a small validation check before control flow execution. A legal input to `lax.broadcast_shapes` hereafter is defined as an input that
1) is a sequence (i.e., implements for..in iteration) of integers and
2) said integers are all non-negative.
In addition, two tests were added to `tests.lax_vmap_test` to check that proper errors are raised when attempting to use illegal inputs with `lax.broadcast_shapes`.
2022-03-04 19:27:52 +01:00
Peter Hawkins
c978df5550
Increase minimum jaxlib version to 0.3.0.
2022-03-04 10:33:03 -05:00
jax authors
cf9a900d78
Merge pull request #9584 from ROCmSoftwarePlatform:rocm_refactor_jaxlib
...
PiperOrigin-RevId: 432236852
2022-03-03 11:11:02 -08:00
Yash Katariya
72cc567c05
Use the new mesh
property instead of the private _global_mesh
attribute.
...
PiperOrigin-RevId: 431815802
2022-03-01 17:43:40 -08:00
jax authors
d9f82f7b9b
[JAX] Move experimental.ann.approx_*_k
into lax
.
...
Updated docs, tests and the example code snippets.
PiperOrigin-RevId: 431781401
2022-03-01 14:46:33 -08:00
Matthew Johnson
4075b81e00
add regression test for #9731
2022-03-01 13:10:28 -08:00
Reza Rahimi
a0d9d81f92
Update JAX to use new math libraries in ROCm-5.0.
2022-03-01 20:02:15 +00:00
Jake VanderPlas
ed2550999f
implement jnp.copy
2022-03-01 11:56:36 -08:00
jax authors
c7508d1f2d
Merge pull request #9721 from jakevdp:poisson-nan
...
PiperOrigin-RevId: 431505317
2022-02-28 12:59:08 -08:00
Yash Katariya
d0cc3395e8
Add block_until_ready
method to GDA
...
PiperOrigin-RevId: 431504594
2022-02-28 12:54:23 -08:00
Jake VanderPlas
2c2773a5f1
jax.random.poisson: fix corner cases
2022-02-28 12:10:47 -08:00
Peter Hawkins
c339330bc1
[XLA:CPU] Relax test tolerances for tests using XLA:CPU.
...
An upcoming change to XLA:CPU will disable reassociation on floating point operators by default which is an unsound fast math optimization. This change is being made to fix numerical errors in softmax computations caused by reassocation. After that change, we will enable reassociation only in reduction operators where it is very important for performance and the XLA operator contract allows that.
Since this change alters the order of operations, it may cause small numerical changes leading to test failures. This change relaxes test tolerances to make tests pass.
PiperOrigin-RevId: 431453240
2022-02-28 09:26:54 -08:00
Roy Frostig
d636e74626
make xla_executable
a property, consistent across executable types
...
Also test IR and executable-related methods of `Lowered` and
`Compiled`.
2022-02-25 19:05:44 -08:00
Samuel Ainsworth
bf59b7d872
Relax tolerances slightly for MKL.
...
Fix https://github.com/google/jax/issues/9705 .
2022-02-25 22:02:55 +00:00
Jake VanderPlas
1b01865b89
BUG: return numpy arrays for jnp.load() with unsupported dtypes
2022-02-25 09:27:42 -08:00
Du Phan
e28ec78d7a
Make negative dimension return consistent results for image.scale_and_translate
2022-02-24 22:34:00 -05:00
jax authors
8372b98c48
[JAX] Move ann.ann_recall back to tests.
...
The function is simple enough for users to implement their own on the host.
PiperOrigin-RevId: 430696789
2022-02-24 07:23:17 -08:00
Yash Katariya
e2834d89e1
Fix the gpu tests that were failing with Future warning
...
PiperOrigin-RevId: 430532523
2022-02-23 13:58:28 -08:00
jax authors
c041694538
Merge pull request #8395 from sharadmv:name-stack-mechanism
...
PiperOrigin-RevId: 430506453
2022-02-23 12:00:20 -08:00
Yash Katariya
687a7630ee
Deprecate maps.mesh
and replace it with maps.Mesh
.
...
PiperOrigin-RevId: 430489855
2022-02-23 10:47:06 -08:00
Sharad Vikram
1b79caa6bd
Add separate mechanism for threading name stacks to the lowering
2022-02-23 09:59:09 -08:00
lenamartens
45d3ddda31
Fix tests and handle cond consts.
2022-02-23 16:11:09 +00:00
Matthew Johnson
4b1d0a466b
fixing scan and other control flow
...
Co-authored-by: Lena Martens <lenamartens@google.com>
2022-02-23 15:39:22 +00:00
jax authors
97b1bd3b65
Merge pull request #9636 from LenaMartens:changelist/429277776
...
PiperOrigin-RevId: 430268738
2022-02-22 12:21:44 -08:00