mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Avoid stack overflow when JITting a function that uses copy.copy or copy.deepcopy. (#1834)
This commit is contained in:
parent
3167b3ddcd
commit
a73106b37c
@ -372,6 +372,12 @@ class Tracer(object):
|
||||
def __repr__(self):
|
||||
return 'Traced<{}>with<{}>'.format(self.aval, self.trace)
|
||||
|
||||
def __copy__(self):
|
||||
return self
|
||||
|
||||
def __deepcopy__(self, unused_memo):
|
||||
return self
|
||||
|
||||
|
||||
# these can be used to set up forwarding of properties and instance methods from
|
||||
# Tracer instances to the underlying avals
|
||||
|
@ -17,6 +17,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import copy
|
||||
from functools import partial
|
||||
import unittest
|
||||
import warnings
|
||||
@ -1282,6 +1283,16 @@ class APITest(jtu.JaxTestCase):
|
||||
python_should_be_executing = False
|
||||
api.jit(f)(3)
|
||||
|
||||
def test_jit_shallow_copy(self):
|
||||
def f(x):
|
||||
return copy.copy(x)
|
||||
api.jit(f)(1)
|
||||
|
||||
def test_jit_deep_copy(self):
|
||||
def f(x):
|
||||
return copy.deepcopy(x)
|
||||
api.jit(f)(1)
|
||||
|
||||
def test_pmap_global_cache(self):
|
||||
def f(x):
|
||||
assert python_should_be_executing
|
||||
|
Loading…
x
Reference in New Issue
Block a user