mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Merge pull request #17669 from andportnoy:aportnoy/xmap-test-use-float32
PiperOrigin-RevId: 566732138
This commit is contained in:
commit
33d862fb93
@ -553,6 +553,7 @@ class XMapTest(XMapTestCase):
|
||||
"vmap_as_xmap": vmap_as_xmap}
|
||||
|
||||
@parameterized.named_parameters(jtu.named_cases_from_sampler(VmapOfXmapCases))
|
||||
@jax.default_matmul_precision("float32")
|
||||
def testNestedMap(self,
|
||||
xmap_in_axes, xmap_out_axes,
|
||||
vmap_in_axes, vmap_out_axes, vmap_result_axis,
|
||||
@ -1341,6 +1342,7 @@ class PDotTests(XMapTestCase):
|
||||
self.assertAllClose(z, jnp.dot(x, y))
|
||||
|
||||
@jtu.with_mesh([('r1', 2)])
|
||||
@jax.default_matmul_precision("float32")
|
||||
def testPdotBatching(self):
|
||||
def f(x, y):
|
||||
return lax.pdot(x, y, 'i')
|
||||
|
Loading…
x
Reference in New Issue
Block a user