15481 Commits

Author SHA1 Message Date
Jake VanderPlas
7610013ebf Improve error for tolist() and tobytes() on tracer objects 2023-03-17 09:42:49 -07:00
Ivy Zheng
08c83369be Add an optional flatten_func argument to custom node registration even when flatten_with_keys is given, for better perf for those in need.
Fixes #14844

PiperOrigin-RevId: 517308676
2023-03-16 21:35:10 -07:00
jax authors
d9598215b8 Merge pull request #15048 from mattjj:7155
PiperOrigin-RevId: 517299263
2023-03-16 20:23:30 -07:00
George Necula
ae6ad8ac3c [jax2tf] Add tests for approx_top_k.
These are important because they use a TPU custom call.

PiperOrigin-RevId: 517291210
2023-03-16 19:37:27 -07:00
Matthew Johnson
00dc1f8e6c add test for #7155, fixes #7155 2023-03-16 19:01:27 -07:00
Yash Katariya
c2d5527f72 [Jax cleanup]
* Remove lower_xla_callable and all related functions
* Remove pxla.device_put
* Remove dispatch.device_put_handlers

PiperOrigin-RevId: 517249345
2023-03-16 15:47:28 -07:00
jax authors
d5f8fd3d85 Merge pull request #15008 from jakevdp:axis-validation
PiperOrigin-RevId: 517154932
2023-03-16 10:14:40 -07:00
Yash Katariya
f9468d3879 Remove the helper jit functions from api.py
PiperOrigin-RevId: 517152277
2023-03-16 10:08:00 -07:00
Jake Vanderplas
56267f08dd Copybara import of the project:
--
371c5a45ea08c8e92136761149d0016077a58652 by Jake VanderPlas <jakevdp@google.com>:

pytree doc: add discussion of children vs aux_data

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/15007 from jakevdp:pytree-doc 371c5a45ea08c8e92136761149d0016077a58652
PiperOrigin-RevId: 517149897
2023-03-16 10:00:24 -07:00
jax authors
e275a9aa5c Merge pull request #15031 from jakevdp:einsum-signature
PiperOrigin-RevId: 517149586
2023-03-16 09:53:18 -07:00
Jake VanderPlas
1be502eeee jax.numpy reductions: validate axis for scalar input 2023-03-16 09:52:26 -07:00
Jake VanderPlas
90804fab56 jnp.einsum: make signature match documentation 2023-03-16 09:21:42 -07:00
jax authors
da19bf18a2 Merge pull request #15030 from jakevdp:nan-test
PiperOrigin-RevId: 517140475
2023-03-16 09:16:46 -07:00
Jake VanderPlas
f6bedb13f7 Add regression test for #4780 2023-03-16 09:05:23 -07:00
George Necula
860630b367 [jax2tf] Add the first version of a custom call backwards compatibility test
See the back_compat_test.py module docstring for some details.

PiperOrigin-RevId: 517031436
2023-03-15 23:10:32 -07:00
Yash Katariya
181355335c Remove references to jax.config.jax_jit_pjit_api_merge, which is always True at head.
PiperOrigin-RevId: 516998437
2023-03-15 20:07:20 -07:00
Yash Katariya
6a0c8069dc Remove the check for if not isinstance(old_token, array.ArrayImpl) since py_executable always return jax.Arrays
PiperOrigin-RevId: 516974728
2023-03-15 17:30:21 -07:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Parker Schuh
e2cce94a3d Avoid extra construction of ShapedArray in array __getitem__.
PiperOrigin-RevId: 516957331
2023-03-15 16:15:15 -07:00
jax authors
dfee197b97 jnp.einsum is parametrizable with dot_general.
PiperOrigin-RevId: 516941177
2023-03-15 15:09:00 -07:00
Parker Schuh
ee70b9612c Avoid extra construction of ShapedArray in array __getitem__.
PiperOrigin-RevId: 516916687
2023-03-15 13:38:36 -07:00
jax authors
09713d8d70 Merge pull request #14986 from mattjj:get-dtype-from-aval
PiperOrigin-RevId: 516910295
2023-03-15 13:16:09 -07:00
Yash Katariya
634035abd7 Remove GDA from JAX since jax.Array is the default type and cannot be disabled anymore as per https://jax.readthedocs.io/en/latest/jax_array_migration.html#how-can-i-disable-jax-array-for-now
PiperOrigin-RevId: 516905931
2023-03-15 13:00:00 -07:00
Matthew Johnson
54b889ca7f [dynamic-shapes] don't require buf objects have dtype attribute
Fixes iree-org/iree-jax#57

An alternative fix would've been just to add the dtype attribute to IreeBuffer.
But it seems better not to make demands on the underlying runtime objects when
we don't need to.

I had to run the test with:

`JAX_PLATFORM_NAME=iree JAX_ARRAY=0 JAX_JIT_PJIT_API_MERGE=0 python tests/dynamic_api_test.py DynamicShapeTest.test_iree_buffer_doesnt_need_dtype_attribute`
2023-03-15 12:53:43 -07:00
Kevin Gleason
6f52388ecc Add support for StableHLO Serialized Portable Artifacts in JAX2TF.
PiperOrigin-RevId: 516885716
2023-03-15 11:42:44 -07:00
Yash Katariya
88584290aa Remove GDA tests from JAX since GDA is deprecated. There are jax.Array tests for all the corresponding GDA tests
PiperOrigin-RevId: 516881635
2023-03-15 11:34:57 -07:00
Peter Hawkins
01dcd3a3fc Relax argument type annotation for lax.dynamic_slice.
PiperOrigin-RevId: 516881433
2023-03-15 11:28:22 -07:00
jax authors
46682bc0e1 Merge pull request #15006 from jakevdp:colab-tpu-install
PiperOrigin-RevId: 516877100
2023-03-15 11:14:33 -07:00
George Necula
1a9f49963c [jax2tf] Rename experimental_native_lowering to native_serialization
We refer to the feature as serialization rather than just lowering,
because the former is both more widely understood and is actually
more accurate because jax2tf will both lower to StableHLO and then
serialize to StableHLO with compatibility guarantees.

This is part of launching the new version of jax2tf with native
serialization.

For now we keep also the parameter `experimental_native_lowering` and
the flag `jax2tf_default_experimental_native_lowering`, until we transition
projects using these flags to the new ones (separate change).

PiperOrigin-RevId: 516864636
2023-03-15 10:31:25 -07:00
Jake VanderPlas
e3444a8d42 README: improve Colab TPU installation discussion 2023-03-15 08:54:23 -07:00
Peter Hawkins
28e4038933 Mark jax.numpy.DeviceArray as deprecated. Use jax.Array instead.
PiperOrigin-RevId: 516835920
2023-03-15 08:50:00 -07:00
Parker Schuh
9990ed2e64 Implement copy_to_host_async and _value with a single call to
device_replica_id_map and device_indices_map.

PiperOrigin-RevId: 516835021
2023-03-15 08:42:32 -07:00
Adam Paszke
1301968248 Optimize canonicalize_shape
I was looking at some profiles and noticed canonicalize_shape showing up as a noticeable
overhead in certain cases. Which makes sense, given that we carefully check all possible
cases before trying to consider integers as plausible elements (which are the most popular
_by far_). And this function is pretty hot, because it gets called any time we create a new
`ShapedArray`.

I wrote a small benchmark that repeatedly calls canonicalize_shape on a 4-sized tuple of
integers.

Before:
7.62µs ± 8%

After:
1.42µs ± 2%

So a pretty easy 5x improvement overall. And in more real cases, when resharding an array
onto 8 TPUs, 50% of the time was spent on creating shapes for avals of device buffers.

PiperOrigin-RevId: 516795311
2023-03-15 05:10:09 -07:00
jax authors
d978dcf7c0 Merge pull request #14975 from gnecula:native_doc
PiperOrigin-RevId: 516791903
2023-03-15 04:49:44 -07:00
George Necula
c9ca394b33 [jax2tf] Documentation for the native serialization mode 2023-03-15 08:37:29 +02:00
Yash Katariya
3c15093ff4 batched_device_put was fixed to correctly use the x64 flag so there is no need to canonicalize dtype anymore.
PiperOrigin-RevId: 516736011
2023-03-14 23:17:27 -07:00
Parker Schuh
48702171bf Add benchmarks for np.array, device_put, and _arrays.
PiperOrigin-RevId: 516692492
2023-03-14 19:06:06 -07:00
Yash Katariya
ca6564a8b1 Delete jax_jit.device_put since it is not used anywhere except for 1 test. Replace it with batched_device_put
PiperOrigin-RevId: 516691869
2023-03-14 18:59:52 -07:00
jax authors
e627b88f6e Merge pull request #14928 from jakevdp:arange-validation
PiperOrigin-RevId: 516682079
2023-03-14 17:55:03 -07:00
Jake VanderPlas
b308312986 jnp.arange: better validation of inputs 2023-03-14 16:41:58 -07:00
jax authors
f96b59f03d Merge pull request #14987 from jakevdp:bcoo-div
PiperOrigin-RevId: 516655538
2023-03-14 15:51:57 -07:00
Peter Hawkins
a0121d9b9b Improve pytype inference for Sharding type.
* Define use_cpp_class and use_cpp_method decorators as no-ops for type checking.
* Remove the use of abc.ABC when defining the Sharding type. This triggers a pytype bug: the easiest fix seems to be to skip the use of the ABC.
* Write use_cpp_class decorator differently on ArrayImpl to work around pytype bug.
* Fix a few new type errors.

PiperOrigin-RevId: 516631428
2023-03-14 14:20:17 -07:00
jax authors
f27f0579ff Merge pull request #14952 from jakevdp:index-helper-kwargs
PiperOrigin-RevId: 516605517
2023-03-14 12:47:28 -07:00
Jake VanderPlas
74242f06d9 [sparse] add BCOO lowering for div
We had avoiding this previously because dividing by zero is
a densifying operation, but we already support mul which has
similar issues if the operand contains infinities.
2023-03-14 11:58:43 -07:00
Peter Hawkins
ed8ddfb3f7 Add device_buffers property to jax.Array type.
This should be considered deprecated, but it exists and users are using it.

PiperOrigin-RevId: 516585684
2023-03-14 11:36:30 -07:00
Parker Schuh
f389781c0c Use batched_device_put for pxla.shard_sharded_device_array_slow_path.
PiperOrigin-RevId: 516577931
2023-03-14 11:17:14 -07:00
Parker Schuh
c1ae3336d6 Hide jit-of-pmap warning.
PiperOrigin-RevId: 516577489
2023-03-14 11:10:33 -07:00
Yash Katariya
a01b1d2b75 Make pxla.replicate go via batched_device_put rather than pxla.device_put.
PiperOrigin-RevId: 516571175
2023-03-14 10:49:41 -07:00
Yash Katariya
50c7378f83 Fix the usage of device_put_handlers since that is deprecated. Use batched_device_put instead
PiperOrigin-RevId: 516563782
2023-03-14 10:26:10 -07:00
Yash Katariya
b97fb56e95 If the bufs are on the same devices passed to batched_device_put then create an Array directly rather than going via xc.batched_device_put. Fixing the transfer guard problem should help in removing this workaround too.
PiperOrigin-RevId: 516561791
2023-03-14 10:19:37 -07:00