4 Commits

Author SHA1 Message Date
Jieying Luo
b403c2a083 [PJRT C API] Add parsing PJRT client create options from json file.
PiperOrigin-RevId: 518418760
2023-03-21 16:57:34 -07:00
Peter Hawkins
cca3961cde [JAX] Split _src/xla_bridge.py into a separate Bazel target.
Include _src/distributed.py and _src/clusters/*.py in the same target because they are in a strongly-connected component.

[XLA:Python] Set type of ArrayImpl to Any, since the JAX change now allows pytype to see that some values are ArrayImpls but ArrayImpls are not instances of jax.Array to Pytype.

Fix type of buffer_from_pyval.

PiperOrigin-RevId: 515687258
2023-03-10 11:12:02 -08:00
Peter Hawkins
ed491b3056 Shorten alias chains for names exported in jax. namespace.
Add some additional type annotations on public APIs.

This allows pytype to do a better job of type inference.

PiperOrigin-RevId: 513255770
2023-03-01 09:19:44 -08:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00