# Mirror of https://github.com/ROCm/jax.git
# (synced 2025-04-14 10:56:06 +00:00).

# JAX is Autograd and XLA.

# Core JAX library target.
#
# Collects every non-test Python module at the package root and in the
# lib/, interpreters/, numpy/ and scipy/ subdirectories. experimental/ is
# not matched by these patterns; individual experimental/ modules are
# exposed through their own py_library targets that depend on this one.
py_library(
    name = "libjax",
    srcs = glob(
        [
            "*.py",
            "lib/*.py",
            "interpreters/*.py",
            "numpy/*.py",
            "scipy/*.py",
        ],
        # Keep test modules out of the library sources.
        exclude = [
            "*_test.py",
            "**/*_test.py",
        ],
    ),
    # XLA client Python bindings from the external TensorFlow repository.
    deps = [
        "@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
    ],
)
# Experimental module experimental/stax.py, built on the core :libjax target.
py_library(
    name = "stax",
    srcs = ["experimental/stax.py"],
    deps = [":libjax"],
)
# Experimental module experimental/minmax.py, built on the core :libjax target.
py_library(
    name = "minmax",
    srcs = ["experimental/minmax.py"],
    deps = [":libjax"],
)
# Experimental module experimental/lapax.py, built on the core :libjax target.
py_library(
    name = "lapax",
    srcs = ["experimental/lapax.py"],
    deps = [":libjax"],
)