mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
d1a8ad076b
commit
ee4ec867ec
@ -29,11 +29,13 @@ from absl import logging
|
||||
logging._warn_preinit_stderr = 0
|
||||
|
||||
from ..config import flags
|
||||
from jax._src import util
|
||||
from jax._src import util, traceback_util
|
||||
from .. import dtypes
|
||||
import numpy as np
|
||||
import threading
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
try:
|
||||
from . import tpu_client
|
||||
except ImportError:
|
||||
@ -335,11 +337,12 @@ def constant(builder, py_val, canonicalize_types=True):
|
||||
Returns:
|
||||
A representation of the constant, either a ComputationDataHandle or None
|
||||
"""
|
||||
py_type = type(py_val)
|
||||
if py_type in _constant_handlers:
|
||||
return _constant_handlers[py_type](builder, py_val, canonicalize_types)
|
||||
else:
|
||||
raise TypeError("No constant handler for type: {}".format(py_type))
|
||||
for t in type(py_val).mro():
|
||||
handler = _constant_handlers.get(t)
|
||||
if handler: return handler(builder, py_val, canonicalize_types)
|
||||
if hasattr(py_val, '__jax_array__'):
|
||||
return constant(builder, py_val.__jax_array__(), canonicalize_types)
|
||||
raise TypeError("No constant handler for type: {}".format(type(py_val)))
|
||||
|
||||
# HLO instructions optionally can be annotated to say how the output should be
|
||||
# spatially partitioned (represented in XLA as OpSharding protos, see
|
||||
|
@ -16,6 +16,7 @@
|
||||
import collections
|
||||
from contextlib import contextmanager
|
||||
import copy
|
||||
import enum
|
||||
from functools import partial
|
||||
import re
|
||||
import unittest
|
||||
@ -2359,6 +2360,19 @@ class APITest(jtu.JaxTestCase):
|
||||
for f in [jnp.isscalar, jnp.size, jnp.shape, jnp.dtype]:
|
||||
self.assertEqual(f(x), f(a))
|
||||
|
||||
def test_constant_handler_mro(self):
|
||||
# https://github.com/google/jax/issues/6129
|
||||
|
||||
class Foo(enum.IntEnum):
|
||||
bar = 1
|
||||
|
||||
@api.pmap
|
||||
def f(_):
|
||||
return Foo.bar
|
||||
|
||||
ans = f(jnp.arange(1)) # doesn't crash
|
||||
expected = jnp.arange(1) + 1
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
|
||||
class RematTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user