From c24f20968bddb57eaa70948d581594f0d02ae51e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 31 Jul 2024 11:52:46 -0700 Subject: [PATCH] [array api] use jnp.astype directly --- jax/experimental/array_api/__init__.py | 5 +-- .../array_api/_data_type_functions.py | 37 ------------------- 2 files changed, 1 insertion(+), 41 deletions(-) delete mode 100644 jax/experimental/array_api/_data_type_functions.py diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index 0d917b6ed..c3bc83112 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -58,6 +58,7 @@ from jax.numpy import ( asarray as asarray, asin as asin, asinh as asinh, + astype as astype, atan as atan, atan2 as atan2, atanh as atanh, @@ -190,7 +191,3 @@ from jax.numpy import ( zeros as zeros, zeros_like as zeros_like, ) - -from jax.experimental.array_api._data_type_functions import ( - astype as astype, -) diff --git a/jax/experimental/array_api/_data_type_functions.py b/jax/experimental/array_api/_data_type_functions.py deleted file mode 100644 index 3ff95befc..000000000 --- a/jax/experimental/array_api/_data_type_functions.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import jax.numpy as jnp - -from jax._src.lib import xla_client as xc -from jax._src.sharding import Sharding -from jax._src import dtypes as _dtypes - - -# TODO(micky774): Remove when jax.numpy.astype is deprecation is completed -def astype(x, dtype, /, *, copy: bool = True, device: xc.Device | Sharding | None = None): - src_dtype = x.dtype if hasattr(x, "dtype") else _dtypes.dtype(x) - if ( - src_dtype is not None - and _dtypes.isdtype(src_dtype, "complex floating") - and _dtypes.isdtype(dtype, ("integral", "real floating")) - ): - raise ValueError( - "Casting from complex to non-complex dtypes is not permitted. Please " - "first use jnp.real or jnp.imag to take the real/imaginary component of " - "your input." - ) - return jnp.astype(x, dtype, copy=copy, device=device)