Add int64 shape support on GPU for stateless random ops. (#13908)

* Add int64 shape support on GPU for stateless random ops.

This fix adds int64 shape support on GPU for stateless random ops
`StatelessRandomUniform`, `StatelessRandomNormal`, `StatelessTruncatedNormal`.

The int64 shape for stateless random ops is already supported on CPU
with int32/int64 processed properly through `MakeShape`.

However, on GPU a type constraint `.TypeConstraint<int32>("T")`
has been improperly added. Such a type constraint actually prevents
an int64 shape type from running on GPU. (For comparison, there is no
such type constraint on CPU.)

This fix removes the type constraint and allows an int64 shape to be used on GPU.

This fix also adds test cases for int64 shape support on stateless random ops.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add test cases for int64 shape support for stateless random ops.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add int32 to shape types tested.
This commit is contained in:
Yong Tang 2017-10-22 22:50:20 -07:00 committed by Vijay Vasudevan
parent 0d437c3beb
commit ac0004e711
2 changed files with 16 additions and 3 deletions

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib import stateless
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
@ -79,6 +80,21 @@ class StatelessOpsTest(test.TestCase):
for s1, v1 in values:
self.assertEqual(s0 == s1, np.all(v0 == v1))
def testShapeType(self):
  """Checks that stateless ops accept both int32 and int64 shape tensors.

  For every op listed in CASES and every tested output shape, the shape is
  passed as a constant of each integer dtype. The op is evaluated once per
  seed, and the results are compared pairwise: two outputs must be equal
  exactly when the seeds that produced them are equal.
  """
  shape_dtypes = [dtypes.int32, dtypes.int64]
  output_shapes = ((), (3,), (2, 5))
  # Each seed pair is listed three times so that identical seeds are also
  # compared across separate evaluations.
  seed_pairs = [(a, b) for a in range(5) for b in range(5)] * 3
  with self.test_session(use_gpu=True):
    for shape_dtype in shape_dtypes:
      seed_t = array_ops.placeholder(dtypes.int64, shape=[2])
      for stateless_op, _ in CASES:
        for shape in output_shapes:
          shape_t = constant_op.constant(shape, dtype=shape_dtype)
          pure = stateless_op(shape_t, seed=seed_t)
          samples = []
          for seed in seed_pairs:
            samples.append((seed, pure.eval(feed_dict={seed_t: seed})))
          for seed_a, value_a in samples:
            for seed_b, value_b in samples:
              # Outputs match if and only if the seeds match.
              self.assertEqual(seed_a == seed_b,
                               np.all(value_a == value_b))
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
test.main()

View File

@ -137,7 +137,6 @@ TF_CALL_double(REGISTER);
.Device(DEVICE_GPU) \
.HostMemory("shape") \
.HostMemory("seed") \
.TypeConstraint<int32>("T") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomOp<GPUDevice, random::UniformDistribution< \
random::PhiloxRandom, TYPE> >); \
@ -146,7 +145,6 @@ TF_CALL_double(REGISTER);
.Device(DEVICE_GPU) \
.HostMemory("shape") \
.HostMemory("seed") \
.TypeConstraint<int32>("T") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomOp<GPUDevice, random::NormalDistribution< \
random::PhiloxRandom, TYPE> >); \
@ -155,7 +153,6 @@ TF_CALL_double(REGISTER);
.Device(DEVICE_GPU) \
.HostMemory("shape") \
.HostMemory("seed") \
.TypeConstraint<int32>("T") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomOp< \
GPUDevice, \