mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Undefined name: from ..core import JaxTuple
[flake8](http://flake8.pycqa.org) testing of https://github.com/google/jax on Python 3.7.1 $ __flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics__ ``` ./jax/interpreters/ad.py:189:20: F821 undefined name 'JaxTuple' return xt, JaxTuple(map(zeros_like_jaxval, xt)) ^ ./jax/interpreters/ad.py:196:16: F821 undefined name 'JaxTuple' return JaxTuple(map(zeros_like_jaxval, yt)), yt ^ 2 F821 undefined name 'JaxTuple' 2 ```
This commit is contained in:
parent
7a313bf622
commit
33292ff962
@ -19,7 +19,7 @@ from __future__ import print_function
|
||||
from . import partial_eval as pe
|
||||
from . import xla
|
||||
from .. import core as core
|
||||
from ..core import Trace, Tracer, new_master, get_aval, pack, call_p, Primitive
|
||||
from ..core import JaxTuple, Trace, Tracer, new_master, get_aval, pack, call_p, Primitive
|
||||
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
|
||||
zeros_like_p, zero, Zero)
|
||||
from ..util import unzip2, unzip3, safe_map, safe_zip, partial
|
||||
|
Loading…
x
Reference in New Issue
Block a user