6382 Commits

Author SHA1 Message Date
Qiao Zhang
d4adf0095f Add default jvp and transpose rule for jax.lax.reduce_precision.
PiperOrigin-RevId: 564536160
2023-09-11 16:35:44 -07:00
jax authors
23e4f0b471 Hash serialized topology description for new cache key generation.
The original cache key generation hashes devices and backend. This
is not future proof: it does not work for accelerators other than
TPUs. Change this to use the serialized version of
PjRtTopologyDescription which is supported for all accelerators.

Note:
. CPU and PjRt C API not supported as yet.
. Stream Executor will not be supported.

Testing: revised unit test.
PiperOrigin-RevId: 564461564
2023-09-11 12:08:26 -07:00
Yash Katariya
a36598b2a7 Set the jax_enable_memories flag to True.
If all memory_kinds in the jaxpr are the default memory kind, then annotate_device_placement custom calls are not inserted. This allows for existing code to work without any changes.

If non-default memory kind is present in the jaxpr, then we allow custom calls to be inserted.

PiperOrigin-RevId: 564457393
2023-09-11 11:55:09 -07:00
Jake VanderPlas
bc91f2d182 Add more extensive tests for version strings 2023-09-08 11:33:49 -07:00
Jake VanderPlas
9e9ea5a90c jax.dtypes: fix dtypes.safe_to_cast() when output dtype is weak 2023-09-08 09:04:05 -07:00
Adam Paszke
832820e8ff CI changes
PiperOrigin-RevId: 563754696
2023-09-08 08:08:05 -07:00
Adam Paszke
592eb44d7e Fix Pallas tests that broke after recent changes
PiperOrigin-RevId: 563750775
2023-09-08 07:50:20 -07:00
Yash Katariya
b8eccb13f0 Remove the date check from jaxlib and jax version checks since it causes problem when jaxlib runs ahead of jax in CI (depending on timezones).
PiperOrigin-RevId: 563614108
2023-09-07 19:52:35 -07:00
Sharad Vikram
cb114f247a [Pallas] Refactor memory space handling
PiperOrigin-RevId: 563586933
2023-09-07 17:08:57 -07:00
jax authors
311dc9cfde Add truncated normal initializer to jax.nn
PiperOrigin-RevId: 563576354
2023-09-07 16:23:42 -07:00
Parker Schuh
bda9292523 Propagate ad.Zeros to the scan body function for jax.lax.scan for the output 'ys'.
Example of what this fixes:

```
def grad_fn(x):
  def scan_body(x, params):
    return x, x.sum()

  pred, state = jax.lax.scan(scan_body, x, None, length=2)
  return pred.sum(), state
x = np.zeros((5, 10), dtype=np.float32)
loss_grad_fn = jax.value_and_grad(grad_fn, has_aux=True)
print(jax.make_jaxpr(loss_grad_fn)(x))
```
PiperOrigin-RevId: 563544684
2023-09-07 14:36:53 -07:00
Jake VanderPlas
8412781127 Internal: add dtypes.safe_to_cast utility & use to generate indexing warning 2023-09-07 12:18:14 -07:00
jax authors
f64235acc8 Merge pull request #17453 from jakevdp:fix-version-string
PiperOrigin-RevId: 563466394
2023-09-07 10:06:32 -07:00
Jake VanderPlas
6f3f0d5e57 build: write appropriate version strings to build artifacts 2023-09-07 08:45:48 -07:00
Peter Hawkins
9b447aa3ec Relax test tolerance to fix BCSR sparse matmul test failure on P100 GPU.
PiperOrigin-RevId: 563441383
2023-09-07 08:37:31 -07:00
Peter Hawkins
429422dfea Reverts 5fcd9265b1e20c41d684659af3d52c41f25ae2f3
PiperOrigin-RevId: 563426073
2023-09-07 07:35:44 -07:00
jax authors
26d90beca2 Merge pull request #17314 from jon-chuang:jon-chuang/fix-causal-upper-bound
PiperOrigin-RevId: 563228924
2023-09-06 15:13:16 -07:00
Adam Paszke
bb8d5a0121 Rewrite simple slicing to the static slicing primitive whenever possible
This makes it a lot easier to handle within Pallas and Mosaic.

PiperOrigin-RevId: 563128943
2023-09-06 09:43:00 -07:00
George Necula
660a015652 [export] Move jax_export and shape_poly out of jax2tf.
Those modules have been developed initially for jax2tf
but they do not depend on TF anymore. They are used for JAX
native serialization. We move them under
jax.experimental.export (also renaming jax_export.py to export.py) so that we can use them without depending on TF.

We are leaving behind stub modules jax2tf.jax_export and jax2tf.shape_poly that just redirect some of the public APIs. To be cleaned later.

PiperOrigin-RevId: 562988740
2023-09-05 22:15:59 -07:00
Peter Hawkins
4f805c2d8f [JAX] Change jax.test_util utilities to have identical tolerances on all platforms.
In cases where this causes TPU tests to fail, relax test tolerances in the test cases themselves.

TPUs are less precise only for specific operations, notably matrix multiplication (for which usually enabling higher-precision matrix multiplication is the right choice if precision is needed), and certain special functions (e.g., log/exp/pow).

The net effect of this change is mostly to tighten up many test tolerances on TPU.

PiperOrigin-RevId: 562953488
2023-09-05 18:48:55 -07:00
Yash Katariya
80606cd28d Make is_fully_addressable an abstract method and implement it on each concrete Sharding.
Also, don't cache methods. Pull them out into a free function and cache that function.

PiperOrigin-RevId: 562939188
2023-09-05 17:28:22 -07:00
George Necula
f27816af30 [callback] Enable 64-bit types and add tests.
This takes advantage of a recent fix in XLA:TPU to enable
64-bit host transfers.

PiperOrigin-RevId: 562890507
2023-09-05 14:23:28 -07:00
jax authors
7224c24521 Merge pull request #17406 from jakevdp:namedtuple-mul
PiperOrigin-RevId: 562843729
2023-09-05 11:40:18 -07:00
Jake VanderPlas
7d29ed6bdd Lower jax.numpy matmul functions to mixed-precision dot_general 2023-09-05 08:37:51 -07:00
Jon Chuang
7d27c319b2 fix causal attention upper bound 2023-09-04 11:22:21 +08:00
George Necula
01c068eabd [callback] Some test cleanup.
Removes callback testing function and uses io_callback
and pure_callback instead. This allows us to remove
some tests from the PureCallbackTest class.

Renames IoPythonCallbackTest -> IoCallbackTest and PurePythonCallbackTest -> PureCallbackTest.

PiperOrigin-RevId: 562285255
2023-09-02 21:51:07 -07:00
Jake VanderPlas
72de4cd0b0 Bug: support NamedTuples in deferring binary ops 2023-09-01 13:16:20 -07:00
George Necula
efaea8ed32 [callback] Enable device_index support in terms of callback sharding support.
This is part of deprecating host_callback and moving to io_callback.

PiperOrigin-RevId: 561856023
2023-08-31 22:31:35 -07:00
George Necula
e0a6230214 [host_callback] Delete unused code paths.
This is part of deprecating host_callback and moving to io_callback.

PiperOrigin-RevId: 561851494
2023-08-31 22:08:23 -07:00
Matthew Johnson
70b58bbd30 rolling forward shard_map transpose fixes
The new efficient-transpose path, enabled by setting check_rep=True in the shard_map call, had kept working. But the change inadvertently broke the check_rep=False path. And because most tests set check_rep=True, we didn't notice it in the tests!

The issue was that with check_rep=False, we need the shard_map transpose rule to insert psums corresponding to in_specs with fan-out, and correspondingly insert division for out_specs with fan-in-consensus. (With the new check_rep=True path that this change adds, those extra operations aren't necessary as the body itself transposes correctly.) But the PR accidentally removed those!

The fix was simple: just track whether we've applied the efficient-transpose-body-rewrite (i.e. whether we're in the new body-is-transposable path or old need-extra-operations path) by adding a boolean parameter `rewrite` to the shard_map primitive, and if the rewrite hasn't been applied then include the explicit psum/div operations in the transpose rule.

Reverts 8a04dfd830ff89f46e1fe3e866ee4fb2da9c90aa

PiperOrigin-RevId: 561805840
2023-08-31 17:31:21 -07:00
jax authors
c38f67043c Hash serialized CompileOptions for new cache key generation.
The original cache key generation hashes individual fields of
CompileOptions, ExecutableBuildOptions, and DebugOptions. This
is not future proof: when a field is added to any of these
structures, the corresponding hash needs to be added to the
cache key generation. The new cache key generation algorithm
hashes the serialized representation of CompileOptions.

Some DebugOptions do not affect the compilation result;
exclude them from the computation. If additional fields are
identified, they can be added; such additions will reduce
unnecessary cache misses.

Testing: revised unit test.
PiperOrigin-RevId: 561803875
2023-08-31 17:21:57 -07:00
jax authors
80f6151110 Instrument metrics to track cache hit rate of original JAX compilation cache.
Metrics:
1) '/jax/compilation_cache/compile_requests_use_cache' to track the number of  the number of times `compile_or_get_cached` is called and `use_compilation_cache` is true.
2) '/jax/compilation_cache/cache_hits_original' to track the number of times the cached executable is successfully returned from a cache read using the original implementation.
3) '/jax/compilation_cache/cache_misses' to track the number of times cache is missed and the compiled executable is written to cache repository.

Created a context manager to register/unregister event listeners.

PiperOrigin-RevId: 561771262
2023-08-31 15:05:23 -07:00
Matthew Johnson
8a04dfd830 rolling back shard_map transposition change to fix a bug
Reverts 437d7be73534403f39fbee9d6391be1c532933a1

PiperOrigin-RevId: 561730581
2023-08-31 12:39:56 -07:00
Jake VanderPlas
f0309b49c9 jax.random: warn on unsupported dtypes 2023-08-31 10:56:05 -07:00
Jake VanderPlas
ca39457ea9 JEX: move jax.linear_util to jax.extend.linear_util 2023-08-30 18:32:12 -07:00
jax authors
437d7be735 Merge pull request #17368 from mattjj:shmap-transpose
PiperOrigin-RevId: 561476082
2023-08-30 16:14:13 -07:00
Jake VanderPlas
4b89d03147 Deprecate the contents of jax.prng 2023-08-30 15:13:32 -07:00
Matthew Johnson
fdd252f6ca [shard-map] add rewrite for efficient transposition 2023-08-30 15:08:11 -07:00
Yash Katariya
e785f89470 Raise a good error message when mesh is not provided to jax.jit when using spmd_axis_name parameter of jax.vmap
PiperOrigin-RevId: 561217612
2023-08-29 20:58:57 -07:00
Skye Wanderman-Milne
ecee8f9116 [JAX] Implement importing external dlpack-aware Python arrays.
See https://dmlc.github.io/dlpack/latest/python_spec.html.

This is the import path. The export path was implemented in
0b3cbfe4bc.

This allows for creating jax.Arrays from external GPU arrays
asynchronously.

PiperOrigin-RevId: 561172624
2023-08-29 16:39:31 -07:00
Peter Hawkins
e369445596 Remove tests for jax.numpy.in1d, which is deprecated.
PiperOrigin-RevId: 561161024
2023-08-29 15:52:34 -07:00
Tao Wang
5a578cba19 Add segment_ids support to pallas flash attention on GPU.
PiperOrigin-RevId: 561130379
2023-08-29 13:59:18 -07:00
Yash Katariya
6072d5993e Any devices passed to jax.sharding.Mesh are required to be hashable.
This is true for mock devices or user specific devices and jax.devices() too.

Fix the tests so that the mock devices are hashable.

PiperOrigin-RevId: 561103167
2023-08-29 12:20:54 -07:00
Peter Hawkins
d0a6813ea2 Make mlir.custom_call() more general and expose it as jax.interpreters.mlir.custom_call().
This change is in preparation for deprecating the XlaBuilder APIs for building non-MLIR HLO. In general JAX would be best served by adding a more user-friendly "custom kernel" API that doesn't require the user to build IR directly, but for the moment the best we can do is migrate users to use MLIR/StableHLO utilities instead of classic HLO utilities.

Since most users of custom kernels probably want to build a custom-call we can get most of the benefit by providing an ergonomic helper function for building the IR for custom calls that can be called by external primitive lowering rules.

This function has two benefits over just building the stablehlo directly:
a) it is a JAX API, and we can be more confident the API won't change because of upstream MLIR changes
b) the Python API to build stablehlo.custom_call generated by the bindings isn't that easy to use (e.g. it doesn't have sensible defaults).

Next step will be to deprecate XlaBuilder and encourage users to switch to lowering rules using this helper.

PiperOrigin-RevId: 561042402
2023-08-29 08:50:07 -07:00
Yash Katariya
a37e2159b3 Don't drop out of C++ fast path if mesh pointers are not equal.
This is done by returning the same object when constructing mesh if devices.shape, axis_names and flat device list matches.

PiperOrigin-RevId: 560828993
2023-08-28 15:04:05 -07:00
Jake VanderPlas
2f878a7168 Tests: set jax_legacy_prng_key='error' 2023-08-28 10:56:09 -07:00
Jake VanderPlas
cb7c7ad942 jnp.ufunc: add fast paths for add/prod reductions 2023-08-28 08:30:23 -07:00
jax authors
871c9f4d76 Merge pull request #17307 from froystig:wrap-key
PiperOrigin-RevId: 560536131
2023-08-27 12:58:50 -07:00
Roy Frostig
a69f134cde add jax.extend.random.wrap_key_data 2023-08-26 11:39:25 -07:00
jax authors
9aaacc5f3a Merge pull request #17275 from jakevdp:ufunc-reduce-where
PiperOrigin-RevId: 560251564
2023-08-25 19:18:32 -07:00