diff --git a/jax/LICENSE b/LICENSE similarity index 100% rename from jax/LICENSE rename to LICENSE diff --git a/jax/LICENSE_SHORT b/LICENSE_SHORT similarity index 100% rename from jax/LICENSE_SHORT rename to LICENSE_SHORT diff --git a/examples/BUILD b/examples/BUILD new file mode 100644 index 000000000..227c1bce7 --- /dev/null +++ b/examples/BUILD @@ -0,0 +1,30 @@ +py_binary( + name = "interactive", + srcs = ["interactive.py"], + deps = ["//jax:libjax"], +) + +py_library( + name = "datasets", + srcs = ["datasets.py"], +) + +py_binary( + name = "mnist_classifier", + srcs = ["mnist_classifier.py"], + deps = [ + ":datasets", + "//jax:libjax", + ], +) + +py_binary( + name = "mnist_vae", + srcs = ["mnist_vae.py"], + deps = [ + ":datasets", + ":minmax", + ":stax", + "//jax:libjax", + ], +) diff --git a/jax/BUILD b/jax/BUILD index 0b3567534..c5155e6e7 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1,14 +1,4 @@ # JAX is Autograd and XLA -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2 - -exports_files(["LICENSE"]) - -load(":build_defs.bzl", "jax_test") - -package_group(name = "jax") - py_library( name = "libjax", srcs = glob( @@ -29,128 +19,20 @@ py_library( ], ) -jax_test( - name = "core_test", - srcs = ["tests/core_test.py"], - shard_count = { - "cpu": 5, - }, -) - -jax_test( - name = "lax_test", - srcs = ["tests/lax_test.py"], - shard_count = { - "cpu": 40, - "gpu": 20, - }, -) - -jax_test( - name = "lax_numpy_test", - srcs = ["tests/lax_numpy_test.py"], - shard_count = { - "cpu": 40, - "gpu": 20, - }, -) - -jax_test( - name = "lax_numpy_indexing_test", - srcs = ["tests/lax_numpy_indexing_test.py"], - shard_count = { - "cpu": 10, - "gpu": 2, - }, -) - -jax_test( - name = "lax_scipy_test", - srcs = ["tests/lax_scipy_test.py"], - shard_count = { - "cpu": 10, - "gpu": 2, - }, -) - -jax_test( - name = "random_test", - srcs = ["tests/random_test.py"], -) - -jax_test( - name = "api_test", - srcs = ["tests/api_test.py"], -) - -jax_test( - name = "batching_test", - srcs = ["tests/batching_test.py"], -) - -py_binary( - name = "interactive", - srcs = ["examples/interactive.py"], - deps = [":libjax"], -) - py_library( name = "stax", srcs = ["experimental/stax.py"], deps = [":libjax"], ) -jax_test( - name = "stax_test", - srcs = ["tests/stax_test.py"], - deps = [":stax"], -) - py_library( name = "minmax", srcs = ["experimental/minmax.py"], deps = [":libjax"], ) -jax_test( - name = "minmax_test", - srcs = ["tests/minmax_test.py"], - deps = [":minmax"], -) - py_library( name = "lapax", srcs = ["experimental/lapax.py"], deps = [":libjax"], ) - -jax_test( - name = "lapax_test", - srcs = ["tests/lapax_test.py"], - deps = [":lapax"], -) - -py_library( - name = "datasets", - srcs = ["examples/datasets.py"], -) - -py_binary( - name = "mnist_classifier", - srcs = ["examples/mnist_classifier.py"], - deps = [ - ":datasets", - ":libjax", - ], -) - -py_binary( - name = "mnist_vae", - srcs = ["examples/mnist_vae.py"], - deps = [ - ":datasets", - ":libjax", - ":minmax", - ":stax", - ], -) diff --git a/tests/BUILD b/tests/BUILD new file mode 100644 index 000000000..3d4a13e3b --- /dev/null +++ b/tests/BUILD @@ -0,0 +1,78 @@ +load(":build_defs.bzl", "jax_test") + +jax_test( + name = "core_test", + srcs = ["tests/core_test.py"], + shard_count = { + "cpu": 5, + }, +) + +jax_test( + name = "lax_test", + srcs = ["tests/lax_test.py"], + shard_count = { + "cpu": 40, + "gpu": 20, + }, +) + +jax_test( + name = "lax_numpy_test", + srcs = ["tests/lax_numpy_test.py"], + shard_count = { + "cpu": 40, + "gpu": 20, + }, +) + +jax_test( + name = "lax_numpy_indexing_test", + srcs = ["tests/lax_numpy_indexing_test.py"], + shard_count = { + "cpu": 10, + "gpu": 2, + }, +) + +jax_test( + name = "lax_scipy_test", + srcs = ["tests/lax_scipy_test.py"], + shard_count = { + "cpu": 10, + "gpu": 2, + }, +) + +jax_test( + name = "random_test", + srcs = ["tests/random_test.py"], +) + +jax_test( + name = "api_test", + srcs = ["tests/api_test.py"], +) + +jax_test( + name = "batching_test", + srcs = ["tests/batching_test.py"], +) + +jax_test( + name = "stax_test", + srcs = ["tests/stax_test.py"], + deps = [":stax"], +) + +jax_test( + name = "minmax_test", + srcs = ["tests/minmax_test.py"], + deps = [":minmax"], +) + +jax_test( + name = "lapax_test", + srcs = ["tests/lapax_test.py"], + deps = [":lapax"], +) diff --git a/jax/build_defs.bzl b/tests/build_defs.bzl similarity index 100% rename from jax/build_defs.bzl rename to tests/build_defs.bzl