Avoid stack overflow when JITting a function that uses copy.copy or copy.deepcopy. (#1834)

This commit is contained in:
George van den Driessche 2019-12-11 02:48:51 +00:00 committed by Peter Hawkins
parent 3167b3ddcd
commit a73106b37c
2 changed files with 17 additions and 0 deletions

View File

@ -372,6 +372,12 @@ class Tracer(object):
def __repr__(self):
return 'Traced<{}>with<{}>'.format(self.aval, self.trace)
def __copy__(self):
return self
def __deepcopy__(self, unused_memo):
return self
# these can be used to set up forwarding of properties and instance methods from
# Tracer instances to the underlying avals

View File

@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function
import collections
import copy
from functools import partial
import unittest
import warnings
@ -1282,6 +1283,16 @@ class APITest(jtu.JaxTestCase):
python_should_be_executing = False
api.jit(f)(3)
def test_jit_shallow_copy(self):
def f(x):
return copy.copy(x)
api.jit(f)(1)
def test_jit_deep_copy(self):
def f(x):
return copy.deepcopy(x)
api.jit(f)(1)
def test_pmap_global_cache(self):
def f(x):
assert python_should_be_executing