Added type annotations and removed unused imports (#2472)

* Added type annotations and removed unused imports

* Adjusted type hints for pytype
This commit is contained in:
George Necula 2020-03-21 13:54:30 +01:00 committed by GitHub
parent 9331fc5b42
commit 428377afb3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 61 additions and 84 deletions

View File

@ -54,7 +54,7 @@ def benchmark(f: Callable[[], Any], iters: Optional[int] = None,
for _ in range(warmup):
f()
times = []
times: List[float] = []
count = 0
while (count < iters if iters is not None
else sum(times) < target_total_secs):
@ -64,13 +64,13 @@ def benchmark(f: Callable[[], Any], iters: Optional[int] = None,
times.append(end - start)
count += 1
times = onp.array(times)
times_arr = onp.array(times)
print("---------Benchmark results for %s---------" % (name or f.__name__))
print("mean=%f std=%f %%std=%f total=%f" %
(times.mean(), times.std(), _pstd(times), times.sum()))
(times_arr.mean(), times_arr.std(), _pstd(times_arr), times_arr.sum()))
print("#iters=%d #warmup=%d" % (count, warmup))
print()
return times
return times_arr
def benchmark_suite(prepare: Callable[..., Callable], params_list: List[Dict],

View File

@ -1,9 +1,7 @@
Building from source
====================
First, obtain the JAX source code.
.. code-block:: shell
First, obtain the JAX source code::
git clone https://github.com/google/jax
cd jax
@ -20,9 +18,7 @@ Installing ``jaxlib`` with pip
..............................
If you're only modifying Python portions of JAX, we recommend installing
``jaxlib`` from a prebuilt wheel using pip:
.. code-block:: shell
``jaxlib`` from a prebuilt wheel using pip::
pip install jaxlib
@ -40,9 +36,7 @@ To build ``jaxlib`` from source, you must also install some prerequisites:
* Cython
* six (required for during the jaxlib build only, not required at install time)
On Ubuntu 18.04 or Debian you can install the necessary prerequisites with:
.. code-block:: shell
On Ubuntu 18.04 or Debian you can install the necessary prerequisites with::
sudo apt-get install g++ python python3-dev python3-numpy python3-scipy cython3 python3-six
@ -50,16 +44,12 @@ On Ubuntu 18.04 or Debian you can install the necessary prerequisites with:
If you are building on a Mac, make sure XCode and the XCode command line tools
are installed.
You can also install the necessary Python dependencies using ``pip``:
You can also install the necessary Python dependencies using ``pip``::
.. code-block:: shell
pip install numpy scipy cython six
pip install numpy scipy cython six
To build ``jaxlib`` with CUDA support, you can run
.. code-block:: shell
To build ``jaxlib`` with CUDA support, you can run::
python build/build.py --enable_cuda
pip install -e build # installs jaxlib (includes XLA)
@ -70,9 +60,7 @@ specify the paths to CUDA and CUDNN, which you must have installed. Here
``python`` should be the name of your Python 3 interpreter; on some systems, you
may need to use ``python3`` instead.
To build ``jaxlib`` without CUDA GPU support (CPU only), drop the ``--enable_cuda``:
.. code-block:: shell
To build ``jaxlib`` without CUDA GPU support (CPU only), drop the ``--enable_cuda``::
python build/build.py
pip install -e build # installs jaxlib (includes XLA)
@ -80,9 +68,7 @@ To build ``jaxlib`` without CUDA GPU support (CPU only), drop the ``--enable_cud
Installing ``jax``
------------------
Once ``jaxlib`` has been installed, you can install ``jax`` by running
.. code-block:: shell
Once ``jaxlib`` has been installed, you can install ``jax`` by running::
pip install -e . # installs jax
@ -97,33 +83,25 @@ Running the tests
To run all the JAX tests, we recommend using ``pytest-xdist``, which can run tests in
parallel. First, install ``pytest-xdist`` and ``pytest-benchmark`` by running
``pip install pytest-xdist pytest-benchmark``.
Then, from the repository root directory run
.. code-block:: shell
Then, from the repository root directory run::
pytest -n auto tests
JAX generates test cases combinatorially, and you can control the number of
cases that are generated and checked for each test (default is 10). The automated tests
currently use 25:
.. code-block:: shell
currently use 25::
JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
The automated tests also run the tests with default 64-bit floats and ints:
.. code-block:: shell
The automated tests also run the tests with default 64-bit floats and ints::
JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
You can run a more specific set of tests using
`pytest <https://docs.pytest.org/en/latest/usage.html#specifying-tests-selecting-tests>`_'s
built-in selection mechanisms, or alternatively you can run a specific test
file directly to see more detailed information about the cases being run:
.. code-block:: shell
file directly to see more detailed information about the cases being run::
python tests/lax_numpy_test.py --num_generated_cases=5
@ -135,9 +113,7 @@ The Colab notebooks are tested for errors as part of the documentation build.
Update documentation
====================
To rebuild the documentation, install several packages:
.. code-block:: shell
To rebuild the documentation, install several packages::
pip install -r docs/requirements.txt
@ -148,9 +124,7 @@ I have used successfully on the Mac: ``conda install -c conda-forge pandoc``.
If you do not want to install ``pandoc`` then you should regenerate the documentation
without the notebooks.
You run at top-level one of the following commands:
.. code-block:: shell
You run at top-level one of the following commands::
sphinx-build -b html docs docs/build/html # with the notebooks
sphinx-build -b html -D nbsphinx_execute=never docs docs/build/html # without the notebooks
@ -195,9 +169,7 @@ branch. That branch is also built automatically, and you can
see the generated documentation `here <https://jax.readthedocs.io/en/test-docs/>`_.
For a local test, I was able to do it in a fresh directory by replaying the commands
I saw in the Readthedocs logs:
.. code-block:: shell
I saw in the Readthedocs logs::
mkvirtualenv jax-docs # A new virtualenv
mkdir jax-docs # A new directory

View File

@ -47,6 +47,7 @@ from .tree_util import (tree_map, tree_flatten, tree_unflatten, tree_structure,
from .util import (unzip2, curry, partial, safe_map, safe_zip,
WrapHashably, Hashable, prod, split_list, extend_name_stack, wrap_name)
from .lib import xla_bridge as xb
# Unused imports to be exported
from .lib.xla_bridge import (device_count, local_device_count, devices, local_devices,
host_id, host_ids, host_count)
from .abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped

View File

@ -21,6 +21,7 @@ from functools import total_ordering
import itertools as it
from weakref import ref
import threading
from typing import Dict, Generator, Iterator, Sequence, Type
import types
from typing import Any, Callable, ClassVar, Dict, List, Optional, Set
@ -28,8 +29,9 @@ import numpy as onp
from . import dtypes
from . import linear_util as lu
from .util import safe_zip, safe_map, partial, curry, prod, partialmethod
from .pprint_util import pp, vcat, hcat, pp_kv_pairs
from .pprint_util import pp, vcat, hcat, pp_kv_pairs, PrettyPrint
# TODO(dougalm): the trace cache breaks the leak detector. Consisder solving.
check_leaks = False
@ -64,7 +66,7 @@ class Jaxpr(object):
__repr__ = __str__
def subjaxprs(jaxpr):
def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]:
"""Generator for all subjaxprs found in the params of jaxpr.eqns.
Does not descend recursively into the found subjaxprs.
"""
@ -77,12 +79,10 @@ def subjaxprs(jaxpr):
class TypedJaxpr(object):
def __init__(self, jaxpr, literals, in_avals, out_avals):
assert type(jaxpr) is Jaxpr
def __init__(self, jaxpr: Jaxpr, literals: Sequence,
in_avals: Sequence['AbstractValue'], out_avals: Sequence['AbstractValue']):
assert len(literals) == len(jaxpr.constvars)
assert len(in_avals) == len(jaxpr.invars)
assert all(isinstance(aval, AbstractValue) for aval in in_avals)
assert all(isinstance(aval, AbstractValue) for aval in out_avals)
if not skip_checks:
in_avals_raised = [raise_to_shaped(v) for v in in_avals]
@ -106,7 +106,7 @@ class TypedJaxpr(object):
__repr__ = __str__
@curry
def jaxpr_as_fun(typed_jaxpr, *args):
def jaxpr_as_fun(typed_jaxpr: TypedJaxpr, *args):
return eval_jaxpr(typed_jaxpr.jaxpr, typed_jaxpr.literals, *args)
@ -518,7 +518,7 @@ def cur_sublevel():
@contextmanager
def new_master(trace_type, bottom=False):
def new_master(trace_type: Type[Trace], bottom=False) -> Generator[MasterTrace, None, None]:
level = trace_state.trace_stack.next_level(bottom)
master = MasterTrace(level, trace_type)
trace_state.trace_stack.push(master, bottom)
@ -914,7 +914,7 @@ call_p.def_impl(call_impl)
# ------------------- Jaxpr printed representation -------------------
def check_jaxpr(jaxpr):
def check_jaxpr(jaxpr: Jaxpr):
"""Checks well-formedness of a jaxpr.
Specifically it checks that all variabled used are previously defined.
@ -922,16 +922,16 @@ def check_jaxpr(jaxpr):
def context():
return "\njaxpr:\n{}\n".format(jaxpr)
def read_env(env, v):
def read_env(env: Set[Var], v: Var):
if type(v) is not Literal and v not in env:
raise Exception("Variable '{}' not defined".format(v) + context())
def write_env(env, v):
def write_env(env: Set[Var], v: Var):
if v in env:
raise Exception("Variable {} already bound".format(v) + context())
env.add(v)
env = set()
env: Set[Var] = set()
read = partial(read_env, env)
write = partial(write_env, env)
@ -952,22 +952,23 @@ def check_jaxpr(jaxpr):
map(read, jaxpr.outvars)
def pp_vars(vs):
def pp_vars(vs) -> str:
return ' '.join(map(str, vs))
def pp_eqn_compact(primitive_name, params):
def pp_eqn_compact(primitive_name: str, params: Dict) -> PrettyPrint:
filtered_params = {k: v for k, v in params.items()
if not isinstance(v, (Jaxpr, TypedJaxpr))}
return pp(primitive_name) >> pp_kv_pairs(sorted(filtered_params.items()))
def pp_eqn(eqn):
def pp_eqn(eqn: JaxprEqn) -> PrettyPrint:
lhs = pp_vars(eqn.outvars)
pp_subexpr = pp('')
return (pp('{} = '.format(lhs)) >>
pp(eqn.primitive.name) >> pp_kv_pairs(sorted(eqn.params.items()))
>> pp(' ') >> pp(pp_vars(eqn.invars))) + pp_subexpr
def pp_jaxpr(jaxpr):
def pp_jaxpr(jaxpr) -> PrettyPrint:
pp_outvars = str(tuple(jaxpr.outvars))
return (pp('{{ lambda {} ; {}.'.format(pp_vars(jaxpr.constvars),
pp_vars(jaxpr.invars))) +

View File

@ -18,7 +18,7 @@ from .interpreters import xla
from .lib import xla_client
from .lib import xla_bridge
def to_dlpack(x):
def to_dlpack(x: xla.DeviceArray):
"""Returns a DLPack tensor that encapsulates a DeviceArray `x`.
The DLPack shares memory with `x`.
@ -48,4 +48,4 @@ def from_dlpack(dlpack, backend=None):
xla_shape = buf.shape()
assert not xla_shape.is_tuple()
aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
return xla.DeviceArray(aval, buf.device(), lazy.array(aval.shape), buf)
return xla.DeviceArray(aval, buf.device(), lazy.array(aval.shape), buf) # pytype: disable=attribute-error

View File

@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
import numpy as onp
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
from .. import core
from .. import dtypes

View File

@ -16,7 +16,6 @@
from functools import partial
from typing import Callable, Dict
from .. import core
from .. import linear_util as lu
from ..core import Trace, Tracer, Primitive, new_master

View File

@ -17,7 +17,7 @@ import itertools as it
from collections import namedtuple
import contextlib
import threading
from typing import Callable, Dict, Set
from typing import Callable, Dict, Sequence, Set
from weakref import ref
import numpy as onp
@ -331,7 +331,7 @@ class PartialVal(tuple):
valid_pv_types = (AbstractValue, type(None))
def merge_pvals(val, pval):
def merge_pvals(val, pval: PartialVal):
pv, const = pval
if isinstance(pv, AbstractValue):
return val
@ -348,7 +348,9 @@ def partial_val_aval(pv, const):
else:
raise TypeError(pv)
def trace_to_jaxpr(fun: lu.WrappedFun, pvals, instantiate=False, stage_out_calls=False, bottom=False):
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
instantiate=False, stage_out_calls=False, bottom=False):
"""Traces a function, given abstract inputs, to a jaxpr."""
trace_type = StagingJaxprTrace if stage_out_calls else JaxprTrace
with new_master(trace_type, bottom=bottom) as master:

View File

@ -399,7 +399,7 @@ xla.canonicalize_dtype_handlers[ChunkedDeviceArray] = identity
### the xla_pmap primitive and its rules are comparable to xla_call in xla.py
def xla_pmap_impl(fun, *args, backend, axis_name, axis_size, global_axis_size,
def xla_pmap_impl(fun: lu.WrappedFun, *args, backend, axis_name, axis_size, global_axis_size,
devices, name, mapped_invars=None):
abstract_args = map(xla.abstractify, args)
compiled_fun = parallel_callable(fun, backend, axis_name, axis_size,
@ -602,7 +602,7 @@ def replicate(val, axis_size, nrep, devices=None, backend=None):
devices = xb.get_backend(backend).get_default_device_assignment(nrep)
assert nrep == len(devices)
aval = xla.abstractify(val)
aval = xla.abstractify(val) # type: ShapedArray
aval = ShapedArray((axis_size,) + aval.shape, aval.dtype)
device_buffers = [xla.device_put(val, d) for d in devices]
return ShardedDeviceArray(aval, device_buffers)

View File

@ -39,6 +39,7 @@ from ..lib import xla_client as xc
from . import partial_eval as pe
from . import ad
from . import masking
from typing import Callable
FLAGS = flags.FLAGS
flags.DEFINE_bool('jax_debug_nans',
@ -58,7 +59,8 @@ def _make_abstract_unit(_): return xc.Shape.array_shape(onp.dtype('bool'), ())
def _device_put_unit(_, device):
return xc.Buffer.from_pyval(onp.zeros((), dtype=onp.dtype('bool')), device,
backend=xb.get_device_backend(device))
def _make_array_shape(a):
return xc.Shape.array_shape(a.dtype, a.shape)
### handlers
@ -72,8 +74,9 @@ def aval_to_xla_shape(aval):
) from err
xla_shape_handlers: Dict[Type[core.AbstractValue], Callable] = {}
xla_shape_handlers[core.AbstractUnit] = _make_abstract_unit
xla_shape_handlers[ShapedArray] = lambda a: xc.Shape.array_shape(a.dtype, a.shape)
xla_shape_handlers[ConcreteArray] = lambda a: xc.Shape.array_shape(a.dtype, a.shape)
xla_shape_handlers[ShapedArray] = _make_array_shape
xla_shape_handlers[ConcreteArray] = _make_array_shape
def aval_to_result_handler(device, aval):
try:
@ -131,7 +134,7 @@ def _canonicalize_python_scalar_dtype(typ, x):
for _t in dtypes.python_scalar_dtypes.keys():
canonicalize_dtype_handlers[_t] = partial(_canonicalize_python_scalar_dtype, _t)
def abstractify(x):
def abstractify(x) -> core.AbstractValue:
typ = type(x)
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
@ -908,7 +911,7 @@ def _copy_device_array_to_device(x, device):
backend=xb.get_device_backend(device))
return DeviceArray(x.aval, device, x._lazy_expr, moved_buf)
def _force(x):
def _force(x: DeviceArray) -> DeviceArray:
if lazy.is_trivial(x._lazy_expr):
return x
else:
@ -923,8 +926,8 @@ def _force(x):
return force_fun(x)
@cache()
def _lazy_force_computation(sticky, aval, device, lexpr
) -> Callable[[DeviceValue], DeviceArray]:
def _lazy_force_computation(sticky, aval, device, lexpr) -> Callable[[DeviceArray], DeviceArray]:
c = xb.make_computation_builder("lazy_force")
if lazy.is_constant(lexpr):
param = None

View File

@ -675,7 +675,7 @@ class APITest(jtu.JaxTestCase):
ans = api.grad(foo, (0, 1))(3., 4.)
self.assertAllClose(ans, (1. + 3. + 4., 1. * 3. * 9.), check_dtypes=False)
def test_defvjp_all(self):
def test_defvjp_all_custom_transforms(self):
@api.custom_transforms
def foo(x):
return np.sin(x)

View File

@ -248,13 +248,13 @@ class FftTest(jtu.JaxTestCase):
self.assertRaisesRegex(
ValueError,
"jax.np.fft.{} only supports 2 axes. "
"Got axes = \\[0\\].".format(name, name),
"Got axes = \\[0\\].".format(name),
lambda: func(rng([2, 3], dtype=onp.float64), axes=[0])
)
self.assertRaisesRegex(
ValueError,
"jax.np.fft.{} only supports 2 axes. "
"Got axes = \\(0, 1, 2\\).".format(name, name),
"Got axes = \\(0, 1, 2\\).".format(name),
lambda: func(rng([2, 3, 3], dtype=onp.float64), axes=(0, 1, 2))
)
self.assertRaises(

View File

@ -433,7 +433,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
(isinstance(axis, tuple) and len(axis) == 1)
else [None, 'fro', 1, 2, -1, -2, np.inf, -np.inf, 'nuc'])
for dtype in float_types + complex_types
for rng_factory in [jtu.rand_default]))
for rng_factory in [jtu.rand_default])) # type: ignore
def testNorm(self, shape, dtype, ord, axis, keepdims, rng_factory):
rng = rng_factory()
_skip_if_unsupported_type(dtype)