[XLA] Enable truncated normal for double.

Fix a problem in testTruncatedNormalIsInRange that causes the test not actually
run.

Add testTruncatedNormalIsNotConstant for double.

PiperOrigin-RevId: 257417015
This commit is contained in:
Bixia Zheng 2019-07-10 09:14:31 -07:00 committed by TensorFlower Gardener
parent 7c7c449f17
commit d7fbbc0023
2 changed files with 5 additions and 3 deletions

View File

@ -116,12 +116,14 @@ class RandomOpsTest(xla_test.XLATestCase):
def rng(dtype):
return random_ops.truncated_normal(shape=[2], dtype=dtype)
self._testRngIsNotConstant(rng, dtypes.float32)
# TODO(b/34339814): make this test work with 16 bit float types.
for dtype in self._random_types() & {np.float32, np.float64}:
self._testRngIsNotConstant(rng, dtype)
def testTruncatedNormalIsInRange(self):
count = 10000000
# TODO(b/34339814): make this test work with 16 bit float types.
for dtype in self._random_types() & {dtypes.float32, dtypes.float64}:
for dtype in self._random_types() & {np.float32, np.float64}:
with self.session() as sess:
with self.test_scope():
x = random_ops.truncated_normal(shape=[count], dtype=dtype)

View File

@ -293,7 +293,7 @@ class TruncatedNormalOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("TruncatedNormal")
.CompileTimeConstantInput("shape")
.TypeConstraint("dtype", DT_FLOAT),
.TypeConstraint("dtype", {DT_FLOAT, DT_DOUBLE}),
TruncatedNormalOp);
} // namespace