Add a test to demonstrate how to force XLA to choose
a different sharding.
Also it is possible to return the wrong
shape from a partition function. We should error in this case.
PiperOrigin-RevId: 525606690
--
1f0eaa0059321f0b9301012d3bae7921056b5c9d by Jake VanderPlas <jakevdp@google.com>:
Test: fix TPU tolerance for Beta test
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/15674 from jakevdp:beta-tpu-test 1f0eaa0059321f0b9301012d3bae7921056b5c9d
PiperOrigin-RevId: 525568586
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.
There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.
This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.
PiperOrigin-RevId: 525534901
- Add gamma fcn api in scipy.special
- Add tests for this purpose
- Add function to the docs
Currently, there is no implementation of the gamma function in jax
but there is one in scipy.special. This breaks some higher level
jit-compilation like in the blackjax backend for pymc. This commit
adds the missing gamma function.
Resolves: #15409
Formatting ops/attributes into str could be expensive. Instead, this uses a proper MLIR API to access `StringAttr` without printers.
PiperOrigin-RevId: 524999042
Start off with two functions: one for retrieving the attached topology, and the other for producing a mesh from the topology (modeling how `mesh_utils` might be adapted).
Use as:
```
topo = jax.topologies.get_attached_topology() // Discovers local devices.
mesh = jax.topologies.make_mesh(topo, mesh_shape, axis_names) # see mesh_utils.create_device_mesh.
```
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 524909149
Use a Protocol instead of an abstract base class for the CacheInterface since it allows us to use one fewer file.
No functional change intended.
PiperOrigin-RevId: 524855263