[XLA] Enable truncated normal for double.
Fix a problem in testTruncatedNormalIsInRange that causes the test not actually run. Add testTruncatedNormalIsNotConstant for double. PiperOrigin-RevId: 257417015
This commit is contained in:
parent
7c7c449f17
commit
d7fbbc0023
@ -116,12 +116,14 @@ class RandomOpsTest(xla_test.XLATestCase):
|
|||||||
def rng(dtype):
|
def rng(dtype):
|
||||||
return random_ops.truncated_normal(shape=[2], dtype=dtype)
|
return random_ops.truncated_normal(shape=[2], dtype=dtype)
|
||||||
|
|
||||||
self._testRngIsNotConstant(rng, dtypes.float32)
|
# TODO(b/34339814): make this test work with 16 bit float types.
|
||||||
|
for dtype in self._random_types() & {np.float32, np.float64}:
|
||||||
|
self._testRngIsNotConstant(rng, dtype)
|
||||||
|
|
||||||
def testTruncatedNormalIsInRange(self):
|
def testTruncatedNormalIsInRange(self):
|
||||||
count = 10000000
|
count = 10000000
|
||||||
# TODO(b/34339814): make this test work with 16 bit float types.
|
# TODO(b/34339814): make this test work with 16 bit float types.
|
||||||
for dtype in self._random_types() & {dtypes.float32, dtypes.float64}:
|
for dtype in self._random_types() & {np.float32, np.float64}:
|
||||||
with self.session() as sess:
|
with self.session() as sess:
|
||||||
with self.test_scope():
|
with self.test_scope():
|
||||||
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
|
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
|
||||||
|
@ -293,7 +293,7 @@ class TruncatedNormalOp : public XlaOpKernel {
|
|||||||
|
|
||||||
REGISTER_XLA_OP(Name("TruncatedNormal")
|
REGISTER_XLA_OP(Name("TruncatedNormal")
|
||||||
.CompileTimeConstantInput("shape")
|
.CompileTimeConstantInput("shape")
|
||||||
.TypeConstraint("dtype", DT_FLOAT),
|
.TypeConstraint("dtype", {DT_FLOAT, DT_DOUBLE}),
|
||||||
TruncatedNormalOp);
|
TruncatedNormalOp);
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
Loading…
Reference in New Issue
Block a user