mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Added type annotations and removed unused imports (#2472)
* Added type annotations and removed unused imports * Adjusted type hints for pytype
This commit is contained in:
parent
9331fc5b42
commit
428377afb3
@ -54,7 +54,7 @@ def benchmark(f: Callable[[], Any], iters: Optional[int] = None,
|
||||
for _ in range(warmup):
|
||||
f()
|
||||
|
||||
times = []
|
||||
times: List[float] = []
|
||||
count = 0
|
||||
while (count < iters if iters is not None
|
||||
else sum(times) < target_total_secs):
|
||||
@ -64,13 +64,13 @@ def benchmark(f: Callable[[], Any], iters: Optional[int] = None,
|
||||
times.append(end - start)
|
||||
count += 1
|
||||
|
||||
times = onp.array(times)
|
||||
times_arr = onp.array(times)
|
||||
print("---------Benchmark results for %s---------" % (name or f.__name__))
|
||||
print("mean=%f std=%f %%std=%f total=%f" %
|
||||
(times.mean(), times.std(), _pstd(times), times.sum()))
|
||||
(times_arr.mean(), times_arr.std(), _pstd(times_arr), times_arr.sum()))
|
||||
print("#iters=%d #warmup=%d" % (count, warmup))
|
||||
print()
|
||||
return times
|
||||
return times_arr
|
||||
|
||||
|
||||
def benchmark_suite(prepare: Callable[..., Callable], params_list: List[Dict],
|
||||
|
@ -1,9 +1,7 @@
|
||||
Building from source
|
||||
====================
|
||||
|
||||
First, obtain the JAX source code.
|
||||
|
||||
.. code-block:: shell
|
||||
First, obtain the JAX source code::
|
||||
|
||||
git clone https://github.com/google/jax
|
||||
cd jax
|
||||
@ -20,9 +18,7 @@ Installing ``jaxlib`` with pip
|
||||
..............................
|
||||
|
||||
If you're only modifying Python portions of JAX, we recommend installing
|
||||
``jaxlib`` from a prebuilt wheel using pip:
|
||||
|
||||
.. code-block:: shell
|
||||
``jaxlib`` from a prebuilt wheel using pip::
|
||||
|
||||
pip install jaxlib
|
||||
|
||||
@ -40,9 +36,7 @@ To build ``jaxlib`` from source, you must also install some prerequisites:
|
||||
* Cython
|
||||
* six (required for during the jaxlib build only, not required at install time)
|
||||
|
||||
On Ubuntu 18.04 or Debian you can install the necessary prerequisites with:
|
||||
|
||||
.. code-block:: shell
|
||||
On Ubuntu 18.04 or Debian you can install the necessary prerequisites with::
|
||||
|
||||
sudo apt-get install g++ python python3-dev python3-numpy python3-scipy cython3 python3-six
|
||||
|
||||
@ -50,16 +44,12 @@ On Ubuntu 18.04 or Debian you can install the necessary prerequisites with:
|
||||
If you are building on a Mac, make sure XCode and the XCode command line tools
|
||||
are installed.
|
||||
|
||||
You can also install the necessary Python dependencies using ``pip``:
|
||||
You can also install the necessary Python dependencies using ``pip``::
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install numpy scipy cython six
|
||||
pip install numpy scipy cython six
|
||||
|
||||
|
||||
To build ``jaxlib`` with CUDA support, you can run
|
||||
|
||||
.. code-block:: shell
|
||||
To build ``jaxlib`` with CUDA support, you can run::
|
||||
|
||||
python build/build.py --enable_cuda
|
||||
pip install -e build # installs jaxlib (includes XLA)
|
||||
@ -70,9 +60,7 @@ specify the paths to CUDA and CUDNN, which you must have installed. Here
|
||||
``python`` should be the name of your Python 3 interpreter; on some systems, you
|
||||
may need to use ``python3`` instead.
|
||||
|
||||
To build ``jaxlib`` without CUDA GPU support (CPU only), drop the ``--enable_cuda``:
|
||||
|
||||
.. code-block:: shell
|
||||
To build ``jaxlib`` without CUDA GPU support (CPU only), drop the ``--enable_cuda``::
|
||||
|
||||
python build/build.py
|
||||
pip install -e build # installs jaxlib (includes XLA)
|
||||
@ -80,9 +68,7 @@ To build ``jaxlib`` without CUDA GPU support (CPU only), drop the ``--enable_cud
|
||||
Installing ``jax``
|
||||
------------------
|
||||
|
||||
Once ``jaxlib`` has been installed, you can install ``jax`` by running
|
||||
|
||||
.. code-block:: shell
|
||||
Once ``jaxlib`` has been installed, you can install ``jax`` by running::
|
||||
|
||||
pip install -e . # installs jax
|
||||
|
||||
@ -97,33 +83,25 @@ Running the tests
|
||||
To run all the JAX tests, we recommend using ``pytest-xdist``, which can run tests in
|
||||
parallel. First, install ``pytest-xdist`` and ``pytest-benchmark`` by running
|
||||
``pip install pytest-xdist pytest-benchmark``.
|
||||
Then, from the repository root directory run
|
||||
|
||||
.. code-block:: shell
|
||||
Then, from the repository root directory run::
|
||||
|
||||
pytest -n auto tests
|
||||
|
||||
|
||||
JAX generates test cases combinatorially, and you can control the number of
|
||||
cases that are generated and checked for each test (default is 10). The automated tests
|
||||
currently use 25:
|
||||
|
||||
.. code-block:: shell
|
||||
currently use 25::
|
||||
|
||||
JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
|
||||
|
||||
The automated tests also run the tests with default 64-bit floats and ints:
|
||||
|
||||
.. code-block:: shell
|
||||
The automated tests also run the tests with default 64-bit floats and ints::
|
||||
|
||||
JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
|
||||
|
||||
You can run a more specific set of tests using
|
||||
`pytest <https://docs.pytest.org/en/latest/usage.html#specifying-tests-selecting-tests>`_'s
|
||||
built-in selection mechanisms, or alternatively you can run a specific test
|
||||
file directly to see more detailed information about the cases being run:
|
||||
|
||||
.. code-block:: shell
|
||||
file directly to see more detailed information about the cases being run::
|
||||
|
||||
python tests/lax_numpy_test.py --num_generated_cases=5
|
||||
|
||||
@ -135,9 +113,7 @@ The Colab notebooks are tested for errors as part of the documentation build.
|
||||
Update documentation
|
||||
====================
|
||||
|
||||
To rebuild the documentation, install several packages:
|
||||
|
||||
.. code-block:: shell
|
||||
To rebuild the documentation, install several packages::
|
||||
|
||||
pip install -r docs/requirements.txt
|
||||
|
||||
@ -148,9 +124,7 @@ I have used successfully on the Mac: ``conda install -c conda-forge pandoc``.
|
||||
If you do not want to install ``pandoc`` then you should regenerate the documentation
|
||||
without the notebooks.
|
||||
|
||||
You run at top-level one of the following commands:
|
||||
|
||||
.. code-block:: shell
|
||||
You run at top-level one of the following commands::
|
||||
|
||||
sphinx-build -b html docs docs/build/html # with the notebooks
|
||||
sphinx-build -b html -D nbsphinx_execute=never docs docs/build/html # without the notebooks
|
||||
@ -195,9 +169,7 @@ branch. That branch is also built automatically, and you can
|
||||
see the generated documentation `here <https://jax.readthedocs.io/en/test-docs/>`_.
|
||||
|
||||
For a local test, I was able to do it in a fresh directory by replaying the commands
|
||||
I saw in the Readthedocs logs:
|
||||
|
||||
.. code-block:: shell
|
||||
I saw in the Readthedocs logs::
|
||||
|
||||
mkvirtualenv jax-docs # A new virtualenv
|
||||
mkdir jax-docs # A new directory
|
||||
|
@ -47,6 +47,7 @@ from .tree_util import (tree_map, tree_flatten, tree_unflatten, tree_structure,
|
||||
from .util import (unzip2, curry, partial, safe_map, safe_zip,
|
||||
WrapHashably, Hashable, prod, split_list, extend_name_stack, wrap_name)
|
||||
from .lib import xla_bridge as xb
|
||||
# Unused imports to be exported
|
||||
from .lib.xla_bridge import (device_count, local_device_count, devices, local_devices,
|
||||
host_id, host_ids, host_count)
|
||||
from .abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped
|
||||
|
33
jax/core.py
33
jax/core.py
@ -21,6 +21,7 @@ from functools import total_ordering
|
||||
import itertools as it
|
||||
from weakref import ref
|
||||
import threading
|
||||
from typing import Dict, Generator, Iterator, Sequence, Type
|
||||
import types
|
||||
from typing import Any, Callable, ClassVar, Dict, List, Optional, Set
|
||||
|
||||
@ -28,8 +29,9 @@ import numpy as onp
|
||||
|
||||
from . import dtypes
|
||||
from . import linear_util as lu
|
||||
|
||||
from .util import safe_zip, safe_map, partial, curry, prod, partialmethod
|
||||
from .pprint_util import pp, vcat, hcat, pp_kv_pairs
|
||||
from .pprint_util import pp, vcat, hcat, pp_kv_pairs, PrettyPrint
|
||||
|
||||
# TODO(dougalm): the trace cache breaks the leak detector. Consisder solving.
|
||||
check_leaks = False
|
||||
@ -64,7 +66,7 @@ class Jaxpr(object):
|
||||
__repr__ = __str__
|
||||
|
||||
|
||||
def subjaxprs(jaxpr):
|
||||
def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]:
|
||||
"""Generator for all subjaxprs found in the params of jaxpr.eqns.
|
||||
Does not descend recursively into the found subjaxprs.
|
||||
"""
|
||||
@ -77,12 +79,10 @@ def subjaxprs(jaxpr):
|
||||
|
||||
|
||||
class TypedJaxpr(object):
|
||||
def __init__(self, jaxpr, literals, in_avals, out_avals):
|
||||
assert type(jaxpr) is Jaxpr
|
||||
def __init__(self, jaxpr: Jaxpr, literals: Sequence,
|
||||
in_avals: Sequence['AbstractValue'], out_avals: Sequence['AbstractValue']):
|
||||
assert len(literals) == len(jaxpr.constvars)
|
||||
assert len(in_avals) == len(jaxpr.invars)
|
||||
assert all(isinstance(aval, AbstractValue) for aval in in_avals)
|
||||
assert all(isinstance(aval, AbstractValue) for aval in out_avals)
|
||||
|
||||
if not skip_checks:
|
||||
in_avals_raised = [raise_to_shaped(v) for v in in_avals]
|
||||
@ -106,7 +106,7 @@ class TypedJaxpr(object):
|
||||
__repr__ = __str__
|
||||
|
||||
@curry
|
||||
def jaxpr_as_fun(typed_jaxpr, *args):
|
||||
def jaxpr_as_fun(typed_jaxpr: TypedJaxpr, *args):
|
||||
return eval_jaxpr(typed_jaxpr.jaxpr, typed_jaxpr.literals, *args)
|
||||
|
||||
|
||||
@ -518,7 +518,7 @@ def cur_sublevel():
|
||||
|
||||
|
||||
@contextmanager
|
||||
def new_master(trace_type, bottom=False):
|
||||
def new_master(trace_type: Type[Trace], bottom=False) -> Generator[MasterTrace, None, None]:
|
||||
level = trace_state.trace_stack.next_level(bottom)
|
||||
master = MasterTrace(level, trace_type)
|
||||
trace_state.trace_stack.push(master, bottom)
|
||||
@ -914,7 +914,7 @@ call_p.def_impl(call_impl)
|
||||
|
||||
# ------------------- Jaxpr printed representation -------------------
|
||||
|
||||
def check_jaxpr(jaxpr):
|
||||
def check_jaxpr(jaxpr: Jaxpr):
|
||||
"""Checks well-formedness of a jaxpr.
|
||||
|
||||
Specifically it checks that all variabled used are previously defined.
|
||||
@ -922,16 +922,16 @@ def check_jaxpr(jaxpr):
|
||||
def context():
|
||||
return "\njaxpr:\n{}\n".format(jaxpr)
|
||||
|
||||
def read_env(env, v):
|
||||
def read_env(env: Set[Var], v: Var):
|
||||
if type(v) is not Literal and v not in env:
|
||||
raise Exception("Variable '{}' not defined".format(v) + context())
|
||||
|
||||
def write_env(env, v):
|
||||
def write_env(env: Set[Var], v: Var):
|
||||
if v in env:
|
||||
raise Exception("Variable {} already bound".format(v) + context())
|
||||
env.add(v)
|
||||
|
||||
env = set()
|
||||
env: Set[Var] = set()
|
||||
read = partial(read_env, env)
|
||||
write = partial(write_env, env)
|
||||
|
||||
@ -952,22 +952,23 @@ def check_jaxpr(jaxpr):
|
||||
map(read, jaxpr.outvars)
|
||||
|
||||
|
||||
def pp_vars(vs):
|
||||
def pp_vars(vs) -> str:
|
||||
return ' '.join(map(str, vs))
|
||||
|
||||
def pp_eqn_compact(primitive_name, params):
|
||||
def pp_eqn_compact(primitive_name: str, params: Dict) -> PrettyPrint:
|
||||
filtered_params = {k: v for k, v in params.items()
|
||||
if not isinstance(v, (Jaxpr, TypedJaxpr))}
|
||||
return pp(primitive_name) >> pp_kv_pairs(sorted(filtered_params.items()))
|
||||
|
||||
def pp_eqn(eqn):
|
||||
def pp_eqn(eqn: JaxprEqn) -> PrettyPrint:
|
||||
lhs = pp_vars(eqn.outvars)
|
||||
pp_subexpr = pp('')
|
||||
return (pp('{} = '.format(lhs)) >>
|
||||
pp(eqn.primitive.name) >> pp_kv_pairs(sorted(eqn.params.items()))
|
||||
>> pp(' ') >> pp(pp_vars(eqn.invars))) + pp_subexpr
|
||||
|
||||
def pp_jaxpr(jaxpr):
|
||||
|
||||
def pp_jaxpr(jaxpr) -> PrettyPrint:
|
||||
pp_outvars = str(tuple(jaxpr.outvars))
|
||||
return (pp('{{ lambda {} ; {}.'.format(pp_vars(jaxpr.constvars),
|
||||
pp_vars(jaxpr.invars))) +
|
||||
|
@ -18,7 +18,7 @@ from .interpreters import xla
|
||||
from .lib import xla_client
|
||||
from .lib import xla_bridge
|
||||
|
||||
def to_dlpack(x):
|
||||
def to_dlpack(x: xla.DeviceArray):
|
||||
"""Returns a DLPack tensor that encapsulates a DeviceArray `x`.
|
||||
|
||||
The DLPack shares memory with `x`.
|
||||
@ -48,4 +48,4 @@ def from_dlpack(dlpack, backend=None):
|
||||
xla_shape = buf.shape()
|
||||
assert not xla_shape.is_tuple()
|
||||
aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
|
||||
return xla.DeviceArray(aval, buf.device(), lazy.array(aval.shape), buf)
|
||||
return xla.DeviceArray(aval, buf.device(), lazy.array(aval.shape), buf) # pytype: disable=attribute-error
|
||||
|
@ -12,9 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as onp
|
||||
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
from .. import core
|
||||
from .. import dtypes
|
||||
|
@ -16,7 +16,6 @@
|
||||
from functools import partial
|
||||
from typing import Callable, Dict
|
||||
|
||||
|
||||
from .. import core
|
||||
from .. import linear_util as lu
|
||||
from ..core import Trace, Tracer, Primitive, new_master
|
||||
|
@ -17,7 +17,7 @@ import itertools as it
|
||||
from collections import namedtuple
|
||||
import contextlib
|
||||
import threading
|
||||
from typing import Callable, Dict, Set
|
||||
from typing import Callable, Dict, Sequence, Set
|
||||
from weakref import ref
|
||||
|
||||
import numpy as onp
|
||||
@ -331,7 +331,7 @@ class PartialVal(tuple):
|
||||
|
||||
valid_pv_types = (AbstractValue, type(None))
|
||||
|
||||
def merge_pvals(val, pval):
|
||||
def merge_pvals(val, pval: PartialVal):
|
||||
pv, const = pval
|
||||
if isinstance(pv, AbstractValue):
|
||||
return val
|
||||
@ -348,7 +348,9 @@ def partial_val_aval(pv, const):
|
||||
else:
|
||||
raise TypeError(pv)
|
||||
|
||||
def trace_to_jaxpr(fun: lu.WrappedFun, pvals, instantiate=False, stage_out_calls=False, bottom=False):
|
||||
|
||||
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
|
||||
instantiate=False, stage_out_calls=False, bottom=False):
|
||||
"""Traces a function, given abstract inputs, to a jaxpr."""
|
||||
trace_type = StagingJaxprTrace if stage_out_calls else JaxprTrace
|
||||
with new_master(trace_type, bottom=bottom) as master:
|
||||
|
@ -399,7 +399,7 @@ xla.canonicalize_dtype_handlers[ChunkedDeviceArray] = identity
|
||||
|
||||
### the xla_pmap primitive and its rules are comparable to xla_call in xla.py
|
||||
|
||||
def xla_pmap_impl(fun, *args, backend, axis_name, axis_size, global_axis_size,
|
||||
def xla_pmap_impl(fun: lu.WrappedFun, *args, backend, axis_name, axis_size, global_axis_size,
|
||||
devices, name, mapped_invars=None):
|
||||
abstract_args = map(xla.abstractify, args)
|
||||
compiled_fun = parallel_callable(fun, backend, axis_name, axis_size,
|
||||
@ -602,7 +602,7 @@ def replicate(val, axis_size, nrep, devices=None, backend=None):
|
||||
devices = xb.get_backend(backend).get_default_device_assignment(nrep)
|
||||
assert nrep == len(devices)
|
||||
|
||||
aval = xla.abstractify(val)
|
||||
aval = xla.abstractify(val) # type: ShapedArray
|
||||
aval = ShapedArray((axis_size,) + aval.shape, aval.dtype)
|
||||
device_buffers = [xla.device_put(val, d) for d in devices]
|
||||
return ShardedDeviceArray(aval, device_buffers)
|
||||
|
@ -39,6 +39,7 @@ from ..lib import xla_client as xc
|
||||
from . import partial_eval as pe
|
||||
from . import ad
|
||||
from . import masking
|
||||
from typing import Callable
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_bool('jax_debug_nans',
|
||||
@ -58,7 +59,8 @@ def _make_abstract_unit(_): return xc.Shape.array_shape(onp.dtype('bool'), ())
|
||||
def _device_put_unit(_, device):
|
||||
return xc.Buffer.from_pyval(onp.zeros((), dtype=onp.dtype('bool')), device,
|
||||
backend=xb.get_device_backend(device))
|
||||
|
||||
def _make_array_shape(a):
|
||||
return xc.Shape.array_shape(a.dtype, a.shape)
|
||||
|
||||
### handlers
|
||||
|
||||
@ -72,8 +74,9 @@ def aval_to_xla_shape(aval):
|
||||
) from err
|
||||
xla_shape_handlers: Dict[Type[core.AbstractValue], Callable] = {}
|
||||
xla_shape_handlers[core.AbstractUnit] = _make_abstract_unit
|
||||
xla_shape_handlers[ShapedArray] = lambda a: xc.Shape.array_shape(a.dtype, a.shape)
|
||||
xla_shape_handlers[ConcreteArray] = lambda a: xc.Shape.array_shape(a.dtype, a.shape)
|
||||
|
||||
xla_shape_handlers[ShapedArray] = _make_array_shape
|
||||
xla_shape_handlers[ConcreteArray] = _make_array_shape
|
||||
|
||||
def aval_to_result_handler(device, aval):
|
||||
try:
|
||||
@ -131,7 +134,7 @@ def _canonicalize_python_scalar_dtype(typ, x):
|
||||
for _t in dtypes.python_scalar_dtypes.keys():
|
||||
canonicalize_dtype_handlers[_t] = partial(_canonicalize_python_scalar_dtype, _t)
|
||||
|
||||
def abstractify(x):
|
||||
def abstractify(x) -> core.AbstractValue:
|
||||
typ = type(x)
|
||||
aval_fn = pytype_aval_mappings.get(typ)
|
||||
if aval_fn: return aval_fn(x)
|
||||
@ -908,7 +911,7 @@ def _copy_device_array_to_device(x, device):
|
||||
backend=xb.get_device_backend(device))
|
||||
return DeviceArray(x.aval, device, x._lazy_expr, moved_buf)
|
||||
|
||||
def _force(x):
|
||||
def _force(x: DeviceArray) -> DeviceArray:
|
||||
if lazy.is_trivial(x._lazy_expr):
|
||||
return x
|
||||
else:
|
||||
@ -923,8 +926,8 @@ def _force(x):
|
||||
return force_fun(x)
|
||||
|
||||
@cache()
|
||||
def _lazy_force_computation(sticky, aval, device, lexpr
|
||||
) -> Callable[[DeviceValue], DeviceArray]:
|
||||
|
||||
def _lazy_force_computation(sticky, aval, device, lexpr) -> Callable[[DeviceArray], DeviceArray]:
|
||||
c = xb.make_computation_builder("lazy_force")
|
||||
if lazy.is_constant(lexpr):
|
||||
param = None
|
||||
|
@ -675,7 +675,7 @@ class APITest(jtu.JaxTestCase):
|
||||
ans = api.grad(foo, (0, 1))(3., 4.)
|
||||
self.assertAllClose(ans, (1. + 3. + 4., 1. * 3. * 9.), check_dtypes=False)
|
||||
|
||||
def test_defvjp_all(self):
|
||||
def test_defvjp_all_custom_transforms(self):
|
||||
@api.custom_transforms
|
||||
def foo(x):
|
||||
return np.sin(x)
|
||||
|
@ -248,13 +248,13 @@ class FftTest(jtu.JaxTestCase):
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"jax.np.fft.{} only supports 2 axes. "
|
||||
"Got axes = \\[0\\].".format(name, name),
|
||||
"Got axes = \\[0\\].".format(name),
|
||||
lambda: func(rng([2, 3], dtype=onp.float64), axes=[0])
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"jax.np.fft.{} only supports 2 axes. "
|
||||
"Got axes = \\(0, 1, 2\\).".format(name, name),
|
||||
"Got axes = \\(0, 1, 2\\).".format(name),
|
||||
lambda: func(rng([2, 3, 3], dtype=onp.float64), axes=(0, 1, 2))
|
||||
)
|
||||
self.assertRaises(
|
||||
|
@ -433,7 +433,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
(isinstance(axis, tuple) and len(axis) == 1)
|
||||
else [None, 'fro', 1, 2, -1, -2, np.inf, -np.inf, 'nuc'])
|
||||
for dtype in float_types + complex_types
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
for rng_factory in [jtu.rand_default])) # type: ignore
|
||||
def testNorm(self, shape, dtype, ord, axis, keepdims, rng_factory):
|
||||
rng = rng_factory()
|
||||
_skip_if_unsupported_type(dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user