make constant handlers follow type mro

fixes #6129
This commit is contained in:
Matthew Johnson 2021-03-18 18:05:22 -07:00
parent d1a8ad076b
commit ee4ec867ec
2 changed files with 23 additions and 6 deletions

View File

@ -29,11 +29,13 @@ from absl import logging
logging._warn_preinit_stderr = 0
from ..config import flags
from jax._src import util
from jax._src import util, traceback_util
from .. import dtypes
import numpy as np
import threading
traceback_util.register_exclusion(__file__)
try:
from . import tpu_client
except ImportError:
@ -335,11 +337,12 @@ def constant(builder, py_val, canonicalize_types=True):
Returns:
A representation of the constant, either a ComputationDataHandle or None
"""
py_type = type(py_val)
if py_type in _constant_handlers:
return _constant_handlers[py_type](builder, py_val, canonicalize_types)
else:
raise TypeError("No constant handler for type: {}".format(py_type))
for t in type(py_val).mro():
handler = _constant_handlers.get(t)
if handler: return handler(builder, py_val, canonicalize_types)
if hasattr(py_val, '__jax_array__'):
return constant(builder, py_val.__jax_array__(), canonicalize_types)
raise TypeError("No constant handler for type: {}".format(type(py_val)))
# HLO instructions optionally can be annotated to say how the output should be
# spatially partitioned (represented in XLA as OpSharding protos, see

View File

@ -16,6 +16,7 @@
import collections
from contextlib import contextmanager
import copy
import enum
from functools import partial
import re
import unittest
@ -2359,6 +2360,19 @@ class APITest(jtu.JaxTestCase):
for f in [jnp.isscalar, jnp.size, jnp.shape, jnp.dtype]:
self.assertEqual(f(x), f(a))
def test_constant_handler_mro(self):
# https://github.com/google/jax/issues/6129
class Foo(enum.IntEnum):
bar = 1
@api.pmap
def f(_):
return Foo.bar
ans = f(jnp.arange(1)) # doesn't crash
expected = jnp.arange(1) + 1
self.assertAllClose(ans, expected)
class RematTest(jtu.JaxTestCase):