mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Use jnp.arange to break the pjit cache (when jit and pjit are merged) because pytest runs tests non-hermetically.
PiperOrigin-RevId: 508114498
This commit is contained in:
parent
55c2b6dad6
commit
7b1128fdc4
@ -852,8 +852,8 @@ def disable_jit(disable: bool = True):
|
||||
... print("Value of y is", y)
|
||||
... return y + 3
|
||||
...
|
||||
>>> print(f(jax.numpy.array([1, 2, 3])))
|
||||
Value of y is Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace(level=0/1)>
|
||||
>>> print(f(jax.numpy.array([1, 2, 3]))) # doctest:+ELLIPSIS
|
||||
Value of y is Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace...>
|
||||
[5 7 9]
|
||||
|
||||
Here ``y`` has been abstracted by :py:func:`jit` to a :py:class:`ShapedArray`,
|
||||
|
@ -354,8 +354,8 @@ class Partial(functools.partial):
|
||||
>>> print_zero = Partial(print, 0)
|
||||
>>> print_zero()
|
||||
0
|
||||
>>> call_func(print_zero)
|
||||
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
|
||||
>>> call_func(print_zero) # doctest:+ELLIPSIS
|
||||
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace...>
|
||||
"""
|
||||
def __new__(klass, func, *args, **kw):
|
||||
# In Python 3.10+, if func is itself a functools.partial instance,
|
||||
|
@ -99,7 +99,7 @@ class JaxprStatsTest(jtu.JaxTestCase):
|
||||
# comes from contextlib.
|
||||
return jax.named_call(jnp.cos, name='test')(x)
|
||||
|
||||
hist = jaxpr_util.source_locations(make_jaxpr(f)(1.).jaxpr)
|
||||
hist = jaxpr_util.source_locations(make_jaxpr(f)(jnp.arange(8.)).jaxpr)
|
||||
for filename in hist.keys():
|
||||
self.assertIn(os.path.basename(__file__), filename)
|
||||
|
||||
|
@ -81,7 +81,8 @@ class MetadataTest(jtu.JaxTestCase):
|
||||
|
||||
def test_source_file_prefix_removal(self):
|
||||
def make_hlo():
|
||||
return jax.xla_computation(jnp.sin)(1.).get_hlo_module().to_string()
|
||||
return jax.xla_computation(jnp.sin)(
|
||||
jnp.arange(8.)).get_hlo_module().to_string()
|
||||
|
||||
# Sanity check
|
||||
self.assertIn("/tests/metadata_test.py", make_hlo())
|
||||
|
Loading…
x
Reference in New Issue
Block a user