Use jnp.arange to break the pjit cache (when jit and pjit are merged) because pytest runs tests non-hermetically.

PiperOrigin-RevId: 508114498
This commit is contained in:
Yash Katariya 2023-02-08 10:16:42 -08:00 committed by jax authors
parent 55c2b6dad6
commit 7b1128fdc4
4 changed files with 7 additions and 6 deletions

View File

@ -852,8 +852,8 @@ def disable_jit(disable: bool = True):
... print("Value of y is", y)
... return y + 3
...
>>> print(f(jax.numpy.array([1, 2, 3])))
Value of y is Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace(level=0/1)>
>>> print(f(jax.numpy.array([1, 2, 3]))) # doctest:+ELLIPSIS
Value of y is Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace...>
[5 7 9]
Here ``y`` has been abstracted by :py:func:`jit` to a :py:class:`ShapedArray`,

View File

@ -354,8 +354,8 @@ class Partial(functools.partial):
>>> print_zero = Partial(print, 0)
>>> print_zero()
0
>>> call_func(print_zero)
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
>>> call_func(print_zero) # doctest:+ELLIPSIS
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace...>
"""
def __new__(klass, func, *args, **kw):
# In Python 3.10+, if func is itself a functools.partial instance,

View File

@ -99,7 +99,7 @@ class JaxprStatsTest(jtu.JaxTestCase):
# comes from contextlib.
return jax.named_call(jnp.cos, name='test')(x)
hist = jaxpr_util.source_locations(make_jaxpr(f)(1.).jaxpr)
hist = jaxpr_util.source_locations(make_jaxpr(f)(jnp.arange(8.)).jaxpr)
for filename in hist.keys():
self.assertIn(os.path.basename(__file__), filename)

View File

@ -81,7 +81,8 @@ class MetadataTest(jtu.JaxTestCase):
def test_source_file_prefix_removal(self):
def make_hlo():
return jax.xla_computation(jnp.sin)(1.).get_hlo_module().to_string()
return jax.xla_computation(jnp.sin)(
jnp.arange(8.)).get_hlo_module().to_string()
# Sanity check
self.assertIn("/tests/metadata_test.py", make_hlo())