Enable int32 on GPU for tf.tile (#12183)
* Enable int32 on GPU for tf.tile. This fix enabled int32 on GPU for tf.tile, to fix the following error: ``` import tensorflow as tf with tf.device('/gpu:0'): tt = tf.tile(tf.range(4), [3]) with tf.Session() as sess: print(sess.run(tt)) ``` This fix fixes 12169. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Enable int32 for TileGradOp Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
a9e2817c35
commit
92111fdd1a
@ -536,6 +536,12 @@ REGISTER_KERNEL_BUILDER(Name("Tile")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileOp<GPUDevice>);
|
||||
REGISTER_KERNEL_BUILDER(Name("Tile")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<int32>("T")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileOp<GPUDevice>);
|
||||
REGISTER_KERNEL_BUILDER(Name("Tile")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<complex64>("T")
|
||||
@ -573,6 +579,12 @@ REGISTER_KERNEL_BUILDER(Name("TileGrad")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileGradientOp<GPUDevice>);
|
||||
REGISTER_KERNEL_BUILDER(Name("TileGrad")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<int32>("T")
|
||||
.TypeConstraint<int32>("Tmultiples")
|
||||
.HostMemory("multiples"),
|
||||
TileGradientOp<GPUDevice>);
|
||||
REGISTER_KERNEL_BUILDER(Name("TileGrad")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<complex64>("T")
|
||||
|
Loading…
Reference in New Issue
Block a user