jax.typing: recommend instance check in Python 3.10 or newer

This commit is contained in:
Jake VanderPlas 2023-03-27 10:01:28 -07:00
parent 40fb646e35
commit ed9fa1342b

View File

@ -41,7 +41,10 @@ For example, your function might look like this::
from jax.typing import ArrayLike
def my_function(x: ArrayLike) -> Array:
# Runtime type validation:
# Runtime type validation, Python 3.10 or newer:
if not isinstance(x, ArrayLike):
raise TypeError(f"Expected arraylike input; got {x}")
# Runtime type validation, any Python version:
if not (isinstance(x, (np.ndarray, Array)) or np.isscalar(x)):
raise TypeError(f"Expected arraylike input; got {x}")