110 Commits

Author SHA1 Message Date
Peter Hawkins
be1cf46a49 Split sharding_impls into its own Bazel target.
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.

PiperOrigin-RevId: 523146076
2023-04-10 10:15:58 -07:00
Peter Hawkins
b4402185db Move PartitionSpec into its own file (jax/_src/partition_spec.py).
No functional changes intended.

A subsequent change will move ParsedPartitionSpec and array mapping utilities here also.

PiperOrigin-RevId: 522393166
2023-04-06 11:43:25 -07:00
Peter Hawkins
dfe95dcb4e Split ShardingSpecs and most of the helpers for constructing them into a separate file (jax/_src/sharding_specs.py).
PiperOrigin-RevId: 522360232
2023-04-06 09:48:51 -07:00
Peter Hawkins
452f3c55e3 Rename jax._src.sharding_utils to jax._src.op_shardings.
Move some more op_sharding related helpers to that module.

PiperOrigin-RevId: 522343010
2023-04-06 08:32:46 -07:00
Yash Katariya
03d5aaad96 Switch the implementation of sharded_aval to a simpler one.
Create sharding_utils.py to move utilities from pxla.py to sharding_utils.py to break cyclic deps.

PiperOrigin-RevId: 522209346
2023-04-05 18:32:00 -07:00
jax authors
3c1f3abba2 Merge pull request #15149 from sharadmv:runstate
PiperOrigin-RevId: 521809360
2023-04-04 10:56:25 -07:00
Peter Hawkins
31eeaed913 Split mlir.py and xla.py into separate Bazel targets.
PiperOrigin-RevId: 520737811
2023-03-30 14:06:16 -07:00
Peter Hawkins
47177e1417 Split more targets out the main JAX Bazel target.
Namely:
* abstract_arrays
* ad_util
* api_util
* interpreters/partial_eval
* lax_reference
PiperOrigin-RevId: 520618715
2023-03-30 06:12:45 -07:00
Peter Hawkins
3135fbcd7f [JAX] Delete _DeviceArray and DeviceArray.
PiperOrigin-RevId: 520453090
2023-03-29 15:07:14 -07:00
jax authors
a964ae7fac Internal Code Change
PiperOrigin-RevId: 520341781
2023-03-29 08:23:56 -07:00
Peter Hawkins
c2d6fcc0e6 Split core.py and several files in an SCC with it into a separate Bazel build target.
PiperOrigin-RevId: 520192610
2023-03-28 18:31:13 -07:00
Peter Hawkins
88c2898e36 Use pytype_strict_library() in Bazel build rules.
PiperOrigin-RevId: 519757928
2023-03-27 10:16:08 -07:00
Peter Hawkins
f461c4ef0c Move jax._src.typing into a separate Bazel target.
PiperOrigin-RevId: 518899136
2023-03-23 10:30:08 -07:00
Yash Katariya
634035abd7 Remove GDA from JAX since jax.Array is the default type and cannot be disabled anymore as per https://jax.readthedocs.io/en/latest/jax_array_migration.html#how-can-i-disable-jax-array-for-now
PiperOrigin-RevId: 516905931
2023-03-15 13:00:00 -07:00
Yash Katariya
88584290aa Remove GDA tests from JAX since GDA is deprecated. There are jax.Array tests for all the corresponding GDA tests
PiperOrigin-RevId: 516881635
2023-03-15 11:34:57 -07:00
Peter Hawkins
e4b154b660 Split basearray into separate Bazel module.
Move the definition of ArrayLike into basearray to avoid a cyclic dependency between array.py and basearray.

PiperOrigin-RevId: 516264828
2023-03-13 11:14:41 -07:00
Peter Hawkins
1925aa1109 Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.

PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Peter Hawkins
86f3a82014 Split _src/mesh into a separate Bazel target.
PiperOrigin-RevId: 516218798
2023-03-13 08:31:55 -07:00
Peter Hawkins
a32a7ff903 Move _src/tree_util.py into a separate Bazel target.
Fix a type error in api.py revealed by the split.

PiperOrigin-RevId: 515745227
2023-03-10 14:51:52 -08:00
Peter Hawkins
a722ec08f9 No changes.
PiperOrigin-RevId: 515737909
2023-03-10 14:23:27 -08:00
Peter Hawkins
0935a7cb31 Split _src files custom_api_util, deprecations, effects and environment_info into separate Bazel targets.
PiperOrigin-RevId: 515708165
2023-03-10 12:26:05 -08:00
Peter Hawkins
0420192d29 Split _src/profiler into a separate BUILD target.
Clean up some stale excludes as well.

PiperOrigin-RevId: 515694871
2023-03-10 11:38:46 -08:00
Peter Hawkins
cca3961cde [JAX] Split _src/xla_bridge.py into a separate Bazel target.
Include _src/distributed.py and _src/clusters/*.py in the same target because they are in a strongly-connected component.

[XLA:Python] Set type of ArrayImpl to Any, since the JAX change now allows pytype to see that some values are ArrayImpls but ArrayImpls are not instances of jax.Array to Pytype.

Fix type of buffer_from_pyval.

PiperOrigin-RevId: 515687258
2023-03-10 11:12:02 -08:00
pizzud
04def0b6ab lazy_loader_module: Move to new internal_test_util directory.
Now we no longer need to mess with sys.path in lazy_loader_test.

PiperOrigin-RevId: 515674188
2023-03-10 10:29:33 -08:00
Peter Hawkins
d58be3d4df Split source_info_util into its own Bazel target.
PiperOrigin-RevId: 515646269
2023-03-10 08:41:06 -08:00
Peter Hawkins
7bfd89a89c Split _src modules cloud_tpu_init, lazy_loader, path, monitoring into their own pytype_library Bazel targets.
PiperOrigin-RevId: 515420193
2023-03-09 13:11:04 -08:00
Peter Hawkins
7fd1e2ff47 Split _src/traceback_util.py into its own Bazel target.
Improve its type annotations.

PiperOrigin-RevId: 515376365
2023-03-09 10:33:47 -08:00
Peter Hawkins
9912a8eb56 Split _src/pretty_printer.py into its own Bazel target.
PiperOrigin-RevId: 515348089
2023-03-09 08:51:30 -08:00
Peter Hawkins
08789fd967 Exclude "util.py" and "config.py" from the main JAX bazel target.
This completes the process of splitting these targets out of :jax.

PiperOrigin-RevId: 515340312
2023-03-09 08:17:03 -08:00
Peter Hawkins
0e05a7987f Split some submodules out of //jax under Bazel.
Add separate BUILD targets
* :version - for version.py
* _src/lib - wrapping the jaxlib shims.
* :util - for util.py
* :config - for config.py

PiperOrigin-RevId: 515307923
2023-03-09 05:27:34 -08:00
pizzud
22cbf95e07 lax_vmap_test: Extend timeout so that the TPU variant can run in ASAN.
Unfortunately we can't conditionally change the timeout, as size and timeout
are both non-configurable even if jax_test supported setting the size.

PiperOrigin-RevId: 514745247
2023-03-07 08:49:42 -08:00
Peter Hawkins
0bb75afaa6 Remove global_device_array from shared jax bazel library.
Require Bazel users to depend explicitly on :global_device_array. Change in preparation for removing global device arrays.

PiperOrigin-RevId: 511273814
2023-02-21 12:27:44 -08:00
Peter Hawkins
f7734fd6a4 Limit visibility of Bazel target jax:global_device_array.
PiperOrigin-RevId: 510521459
2023-02-17 14:30:05 -08:00
pizzud
631e4ed7e0 lax_test: Create a separate module for lax-specific test utils in a new package.
These utils are currently shared with lax_vmap_test by importing lax_test as a
library, which is an odd thing to do.

The new package and the module within it are not built into the wheel, as these
are internal utilities for JAX's tests, not utilities for JAX users writing
their own tests.

Followup changes will add additional existing internal test utilities to this
package. This will allow removing sys.path manipulation from
deprecation_module_test and hopefully lazy_loader_test, as well as removing
the non-public test_util.py from _src to make it clearer that it should not be
used from outside JAX.

PiperOrigin-RevId: 510260230
2023-02-16 15:29:41 -08:00
Peter Hawkins
c368562529 Add keep_dep tag to :global_device_array build target to hint that it should be kept.
PiperOrigin-RevId: 510241400
2023-02-16 14:15:21 -08:00
Peter Hawkins
43b615c0a0 Move global_device_array into its own BUILD target.
PiperOrigin-RevId: 510229248
2023-02-16 13:30:40 -08:00
jax authors
b8d6efe22f Merge pull request #14273 from mattjj:shard-map
PiperOrigin-RevId: 506820113
2023-02-02 23:25:39 -08:00
Matthew Johnson
ff1e9b3973 shard_map (shmap) prototype and JEP
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Sholto Douglas <sholto@google.com>
2023-02-02 23:01:30 -08:00
Peter Hawkins
c90a85403b Merge pull request #14248 from jakevdp:dead-code
PiperOrigin-RevId: 506405131
2023-02-01 21:25:46 +00:00
Sharad Vikram
c9a57e1b44 Delete jax.experimental.callback
PiperOrigin-RevId: 501760507
2023-01-12 22:58:31 -08:00
Qiao Zhang
4d1c4bc761 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 491445515
2022-11-28 14:31:48 -08:00
jax authors
d1fbdbc1cf Rollback of "Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module."
PiperOrigin-RevId: 490499003
2022-11-23 07:48:05 -08:00
Qiao Zhang
78963b6020 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 490387796
2022-11-22 18:53:29 -08:00
Jake VanderPlas
66262901f0 [sparse] improve testing framework 2022-11-16 09:58:06 -08:00
Yash Katariya
9e4114f0f1 Move array.py and sharding.py from experimental/ to _src/.
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Jake VanderPlas
265b39d23f Add pytype_srcs to main jax BUILD rule
PiperOrigin-RevId: 476989241
2022-09-26 14:18:13 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
79406757d0 Remove deprecated jax.experimental.optimizers
The new location is jax.example_libraries.optimizers
2022-08-09 08:50:59 -07:00
Peter Hawkins
b865111996 Refactor BUILD files to avoid individually naming Python dependencies.
Add a parametric py_deps() macro for adding Python package dependencies for Bazel rules.

Fix build failure with dangling matplotlib reference.

PiperOrigin-RevId: 465562141
2022-08-05 07:49:20 -07:00
Jake VanderPlas
91dbcbf525 Remove deprecated jax.experimental.stax
The new location is jax.example_libraries.stax
2022-08-02 16:50:06 -07:00