[XLA] Enable truncated normal for double.
Fix a problem in testTruncatedNormalIsInRange that causes the test not actually run. Add testTruncatedNormalIsNotConstant for double. PiperOrigin-RevId: 257417015
This commit is contained in:
parent
7c7c449f17
commit
d7fbbc0023
@ -116,12 +116,14 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
def rng(dtype):
|
||||
return random_ops.truncated_normal(shape=[2], dtype=dtype)
|
||||
|
||||
self._testRngIsNotConstant(rng, dtypes.float32)
|
||||
# TODO(b/34339814): make this test work with 16 bit float types.
|
||||
for dtype in self._random_types() & {np.float32, np.float64}:
|
||||
self._testRngIsNotConstant(rng, dtype)
|
||||
|
||||
def testTruncatedNormalIsInRange(self):
|
||||
count = 10000000
|
||||
# TODO(b/34339814): make this test work with 16 bit float types.
|
||||
for dtype in self._random_types() & {dtypes.float32, dtypes.float64}:
|
||||
for dtype in self._random_types() & {np.float32, np.float64}:
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
|
||||
|
@ -293,7 +293,7 @@ class TruncatedNormalOp : public XlaOpKernel {
|
||||
|
||||
REGISTER_XLA_OP(Name("TruncatedNormal")
|
||||
.CompileTimeConstantInput("shape")
|
||||
.TypeConstraint("dtype", DT_FLOAT),
|
||||
.TypeConstraint("dtype", {DT_FLOAT, DT_DOUBLE}),
|
||||
TruncatedNormalOp);
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user