From 08563842b9815382e890a00041c20416dfd7a1f6 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 7 Feb 2025 09:56:52 -0800 Subject: [PATCH] DOC: make clear that printoptions are NumPy aliases --- jax/_src/numpy/lax_numpy.py | 37 +++++++++++++++++++++++++++++++++---- tests/lax_numpy_test.py | 3 ++- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index e3b17ac6d..b7749ece5 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -111,11 +111,40 @@ euler_gamma = np.euler_gamma inf = np.inf nan = np.nan -# NumPy utility functions +# Wrappers for NumPy printoptions -get_printoptions = np.get_printoptions -printoptions = np.printoptions -set_printoptions = np.set_printoptions +def get_printoptions(): + """Alias of :func:`numpy.get_printoptions`. + + JAX arrays are printed via NumPy, so NumPy's `printoptions` + configurations will apply to printed JAX arrays. + + See the :func:`numpy.set_printoptions` documentation for details + on the available options and their meanings. + """ + return np.get_printoptions() + +def printoptions(*args, **kwargs): + """Alias of :func:`numpy.printoptions`. + + JAX arrays are printed via NumPy, so NumPy's `printoptions` + configurations will apply to printed JAX arrays. + + See the :func:`numpy.set_printoptions` documentation for details + on the available options and their meanings. + """ + return np.printoptions(*args, **kwargs) + +def set_printoptions(*args, **kwargs): + """Alias of :func:`numpy.set_printoptions`. + + JAX arrays are printed via NumPy, so NumPy's `printoptions` + configurations will apply to printed JAX arrays. + + See the :func:`numpy.set_printoptions` documentation for details + on the available options and their meanings. + """ + return np.set_printoptions(*args, **kwargs) @export def iscomplexobj(x: Any) -> bool: diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 00d8f8dc0..1d6cfad8a 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6567,7 +6567,8 @@ class NumpyDocTests(jtu.JaxTestCase): aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2', 'amax', 'amin', 'around', 'bitwise_invert', 'bitwise_left_shift', 'bitwise_not','bitwise_right_shift', 'conj', 'degrees', 'divide', - 'mod', 'pow', 'radians', 'round_'] + 'get_printoptions', 'mod', 'pow', 'printoptions', 'radians', 'round_', + 'set_printoptions'] skip_args_check = ['vsplit', 'hsplit', 'dsplit', 'array_split'] for name in dir(jnp):