Peter Hawkins
be1cf46a49
Split sharding_impls into its own Bazel target.
...
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.
PiperOrigin-RevId: 523146076
2023-04-10 10:15:58 -07:00
Peter Hawkins
b4402185db
Move PartitionSpec into its own file (jax/_src/partition_spec.py).
...
No functional changes intended.
A subsequent change will move ParsedPartitionSpec and array mapping utilities here also.
PiperOrigin-RevId: 522393166
2023-04-06 11:43:25 -07:00
Peter Hawkins
dfe95dcb4e
Split ShardingSpecs and most of the helpers for constructing them into a separate file (jax/_src/sharding_specs.py).
...
PiperOrigin-RevId: 522360232
2023-04-06 09:48:51 -07:00
Peter Hawkins
452f3c55e3
Rename jax._src.sharding_utils to jax._src.op_shardings.
...
Move some more op_sharding related helpers to that module.
PiperOrigin-RevId: 522343010
2023-04-06 08:32:46 -07:00
Yash Katariya
03d5aaad96
Switch the implementation of sharded_aval to a simpler one.
...
Create sharding_utils.py to move utilities from pxla.py to sharding_utils.py to break cyclic deps.
PiperOrigin-RevId: 522209346
2023-04-05 18:32:00 -07:00
jax authors
3c1f3abba2
Merge pull request #15149 from sharadmv:runstate
...
PiperOrigin-RevId: 521809360
2023-04-04 10:56:25 -07:00
Peter Hawkins
31eeaed913
Split mlir.py and xla.py into separate Bazel targets.
...
PiperOrigin-RevId: 520737811
2023-03-30 14:06:16 -07:00
Peter Hawkins
47177e1417
Split more targets out the main JAX Bazel target.
...
Namely:
* abstract_arrays
* ad_util
* api_util
* interpreters/partial_eval
* lax_reference
PiperOrigin-RevId: 520618715
2023-03-30 06:12:45 -07:00
Peter Hawkins
3135fbcd7f
[JAX] Delete _DeviceArray and DeviceArray.
...
PiperOrigin-RevId: 520453090
2023-03-29 15:07:14 -07:00
jax authors
a964ae7fac
Internal Code Change
...
PiperOrigin-RevId: 520341781
2023-03-29 08:23:56 -07:00
Peter Hawkins
c2d6fcc0e6
Split core.py and several files in an SCC with it into a separate Bazel build target.
...
PiperOrigin-RevId: 520192610
2023-03-28 18:31:13 -07:00
Peter Hawkins
88c2898e36
Use pytype_strict_library() in Bazel build rules.
...
PiperOrigin-RevId: 519757928
2023-03-27 10:16:08 -07:00
Peter Hawkins
f461c4ef0c
Move jax._src.typing into a separate Bazel target.
...
PiperOrigin-RevId: 518899136
2023-03-23 10:30:08 -07:00
Yash Katariya
634035abd7
Remove GDA from JAX since jax.Array is the default type and cannot be disabled anymore as per https://jax.readthedocs.io/en/latest/jax_array_migration.html#how-can-i-disable-jax-array-for-now
...
PiperOrigin-RevId: 516905931
2023-03-15 13:00:00 -07:00
Yash Katariya
88584290aa
Remove GDA tests from JAX since GDA is deprecated. There are jax.Array tests for all the corresponding GDA tests
...
PiperOrigin-RevId: 516881635
2023-03-15 11:34:57 -07:00
Peter Hawkins
e4b154b660
Split basearray into separate Bazel module.
...
Move the definition of ArrayLike into basearray to avoid a cyclic dependency between array.py and basearray.
PiperOrigin-RevId: 516264828
2023-03-13 11:14:41 -07:00
Peter Hawkins
1925aa1109
Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
...
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.
PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Peter Hawkins
86f3a82014
Split _src/mesh into a separate Bazel target.
...
PiperOrigin-RevId: 516218798
2023-03-13 08:31:55 -07:00
Peter Hawkins
a32a7ff903
Move _src/tree_util.py into a separate Bazel target.
...
Fix a type error in api.py revealed by the split.
PiperOrigin-RevId: 515745227
2023-03-10 14:51:52 -08:00
Peter Hawkins
a722ec08f9
No changes.
...
PiperOrigin-RevId: 515737909
2023-03-10 14:23:27 -08:00
Peter Hawkins
0935a7cb31
Split _src files custom_api_util, deprecations, effects and environment_info into separate Bazel targets.
...
PiperOrigin-RevId: 515708165
2023-03-10 12:26:05 -08:00
Peter Hawkins
0420192d29
Split _src/profiler into a separate BUILD target.
...
Clean up some stale excludes as well.
PiperOrigin-RevId: 515694871
2023-03-10 11:38:46 -08:00
Peter Hawkins
cca3961cde
[JAX] Split _src/xla_bridge.py into a separate Bazel target.
...
Include _src/distributed.py and _src/clusters/*.py in the same target because they are in a strongly-connected component.
[XLA:Python] Set type of ArrayImpl to Any, since the JAX change now allows pytype to see that some values are ArrayImpls but ArrayImpls are not instances of jax.Array to Pytype.
Fix type of buffer_from_pyval.
PiperOrigin-RevId: 515687258
2023-03-10 11:12:02 -08:00
pizzud
04def0b6ab
lazy_loader_module: Move to new internal_test_util directory.
...
Now we no longer need to mess with sys.path in lazy_loader_test.
PiperOrigin-RevId: 515674188
2023-03-10 10:29:33 -08:00
Peter Hawkins
d58be3d4df
Split source_info_util into its own Bazel target.
...
PiperOrigin-RevId: 515646269
2023-03-10 08:41:06 -08:00
Peter Hawkins
7bfd89a89c
Split _src modules cloud_tpu_init, lazy_loader, path, monitoring into their own pytype_library Bazel targets.
...
PiperOrigin-RevId: 515420193
2023-03-09 13:11:04 -08:00
Peter Hawkins
7fd1e2ff47
Split _src/traceback_util.py into its own Bazel target.
...
Improve its type annotations.
PiperOrigin-RevId: 515376365
2023-03-09 10:33:47 -08:00
Peter Hawkins
9912a8eb56
Split _src/pretty_printer.py into its own Bazel target.
...
PiperOrigin-RevId: 515348089
2023-03-09 08:51:30 -08:00
Peter Hawkins
08789fd967
Exclude "util.py" and "config.py" from the main JAX bazel target.
...
This completes the process of splitting these targets out of :jax.
PiperOrigin-RevId: 515340312
2023-03-09 08:17:03 -08:00
Peter Hawkins
0e05a7987f
Split some submodules out of //jax under Bazel.
...
Add separate BUILD targets
* :version - for version.py
* _src/lib - wrapping the jaxlib shims.
* :util - for util.py
* :config - for config.py
PiperOrigin-RevId: 515307923
2023-03-09 05:27:34 -08:00
pizzud
22cbf95e07
lax_vmap_test: Extend timeout so that the TPU variant can run in ASAN.
...
Unfortunately we can't conditionally change the timeout, as size and timeout
are both non-configurable even if jax_test supported setting the size.
PiperOrigin-RevId: 514745247
2023-03-07 08:49:42 -08:00
Peter Hawkins
0bb75afaa6
Remove global_device_array from shared jax bazel library.
...
Require Bazel users to depend explicitly on :global_device_array. Change in preparation for removing global device arrays.
PiperOrigin-RevId: 511273814
2023-02-21 12:27:44 -08:00
Peter Hawkins
f7734fd6a4
Limit visibility of Bazel target jax:global_device_array.
...
PiperOrigin-RevId: 510521459
2023-02-17 14:30:05 -08:00
pizzud
631e4ed7e0
lax_test: Create a separate module for lax-specific test utils in a new package.
...
These utils are currently shared with lax_vmap_test by importing lax_test as a
library, which is an odd thing to do.
The new package and the module within it are not built into the wheel, as these
are internal utilities for JAX's tests, not utilities for JAX users writing
their own tests.
Followup changes will add additional existing internal test utilities to this
package. This will allow removing sys.path manipulation from
deprecation_module_test and hopefully lazy_loader_test, as well as removing
the non-public test_util.py from _src to make it clearer that it should not be
used from outside JAX.
PiperOrigin-RevId: 510260230
2023-02-16 15:29:41 -08:00
Peter Hawkins
c368562529
Add keep_dep tag to :global_device_array build target to hint that it should be kept.
...
PiperOrigin-RevId: 510241400
2023-02-16 14:15:21 -08:00
Peter Hawkins
43b615c0a0
Move global_device_array into its own BUILD target.
...
PiperOrigin-RevId: 510229248
2023-02-16 13:30:40 -08:00
jax authors
b8d6efe22f
Merge pull request #14273 from mattjj:shard-map
...
PiperOrigin-RevId: 506820113
2023-02-02 23:25:39 -08:00
Matthew Johnson
ff1e9b3973
shard_map (shmap) prototype and JEP
...
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Sholto Douglas <sholto@google.com>
2023-02-02 23:01:30 -08:00
Peter Hawkins
c90a85403b
Merge pull request #14248 from jakevdp:dead-code
...
PiperOrigin-RevId: 506405131
2023-02-01 21:25:46 +00:00
Sharad Vikram
c9a57e1b44
Delete jax.experimental.callback
...
PiperOrigin-RevId: 501760507
2023-01-12 22:58:31 -08:00
Qiao Zhang
4d1c4bc761
Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
...
PiperOrigin-RevId: 491445515
2022-11-28 14:31:48 -08:00
jax authors
d1fbdbc1cf
Rollback of "Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module."
...
PiperOrigin-RevId: 490499003
2022-11-23 07:48:05 -08:00
Qiao Zhang
78963b6020
Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
...
PiperOrigin-RevId: 490387796
2022-11-22 18:53:29 -08:00
Jake VanderPlas
66262901f0
[sparse] improve testing framework
2022-11-16 09:58:06 -08:00
Yash Katariya
9e4114f0f1
Move array.py
and sharding.py
from experimental/
to _src/
.
...
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Jake VanderPlas
265b39d23f
Add pytype_srcs to main jax BUILD rule
...
PiperOrigin-RevId: 476989241
2022-09-26 14:18:13 -07:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
79406757d0
Remove deprecated jax.experimental.optimizers
...
The new location is jax.example_libraries.optimizers
2022-08-09 08:50:59 -07:00
Peter Hawkins
b865111996
Refactor BUILD files to avoid individually naming Python dependencies.
...
Add a parametric py_deps() macro for adding Python package dependencies for Bazel rules.
Fix build failure with dangling matplotlib reference.
PiperOrigin-RevId: 465562141
2022-08-05 07:49:20 -07:00
Jake VanderPlas
91dbcbf525
Remove deprecated jax.experimental.stax
...
The new location is jax.example_libraries.stax
2022-08-02 16:50:06 -07:00