Add int64 shape support on GPU for stateless random ops. (#13908)
* Add int64 shape support on GPU for stateless random ops. This fix adds int64 shape support on GPU for stateless random ops `StatelessRandomUniform`, `StatelessRandomNormal`, `StatelessTruncatedNormal`. The int64 shape for stateless random ops is already supported on CPU with int32/int64 processed properly through `MakeShape`. However, on GPU a type constraint `.TypeConstraint<int32>("T")` has been improperly added. Such a type constraint actually prevents an int64 shape type to run on GPU. (As a comparision, no type constraint on CPU). This fix removes the type constraint and allows int64 shape to be run on GPU. This fix also adds test cases for int64 shape support on stateless random ops. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add test cases for int64 shape support for stateless random ops. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add int32 to shape types tested.
This commit is contained in:
parent
0d437c3beb
commit
ac0004e711
tensorflow
contrib/stateless/python/kernel_tests
core/kernels
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from tensorflow.contrib import stateless
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -79,6 +80,21 @@ class StatelessOpsTest(test.TestCase):
|
||||
for s1, v1 in values:
|
||||
self.assertEqual(s0 == s1, np.all(v0 == v1))
|
||||
|
||||
def testShapeType(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
for shape_dtype in [dtypes.int32, dtypes.int64]:
|
||||
seed_t = array_ops.placeholder(dtypes.int64, shape=[2])
|
||||
seeds = [(x, y) for x in range(5) for y in range(5)] * 3
|
||||
for stateless_op, _ in CASES:
|
||||
for shape in (), (3,), (2, 5):
|
||||
pure = stateless_op(constant_op.constant(shape, dtype=shape_dtype),
|
||||
seed=seed_t)
|
||||
values = [(seed, pure.eval(feed_dict={seed_t: seed}))
|
||||
for seed in seeds]
|
||||
for s0, v0 in values:
|
||||
for s1, v1 in values:
|
||||
self.assertEqual(s0 == s1, np.all(v0 == v1))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -137,7 +137,6 @@ TF_CALL_double(REGISTER);
|
||||
.Device(DEVICE_GPU) \
|
||||
.HostMemory("shape") \
|
||||
.HostMemory("seed") \
|
||||
.TypeConstraint<int32>("T") \
|
||||
.TypeConstraint<TYPE>("dtype"), \
|
||||
StatelessRandomOp<GPUDevice, random::UniformDistribution< \
|
||||
random::PhiloxRandom, TYPE> >); \
|
||||
@ -146,7 +145,6 @@ TF_CALL_double(REGISTER);
|
||||
.Device(DEVICE_GPU) \
|
||||
.HostMemory("shape") \
|
||||
.HostMemory("seed") \
|
||||
.TypeConstraint<int32>("T") \
|
||||
.TypeConstraint<TYPE>("dtype"), \
|
||||
StatelessRandomOp<GPUDevice, random::NormalDistribution< \
|
||||
random::PhiloxRandom, TYPE> >); \
|
||||
@ -155,7 +153,6 @@ TF_CALL_double(REGISTER);
|
||||
.Device(DEVICE_GPU) \
|
||||
.HostMemory("shape") \
|
||||
.HostMemory("seed") \
|
||||
.TypeConstraint<int32>("T") \
|
||||
.TypeConstraint<TYPE>("dtype"), \
|
||||
StatelessRandomOp< \
|
||||
GPUDevice, \
|
||||
|
Loading…
Reference in New Issue
Block a user