Fix Conv3DBackpropFilterOp with int64 input_sizes on GPU.

Re-enable the test that was failing before.

PiperOrigin-RevId: 236230029
Author: Akshay Modi, 2019-02-28 17:34:00 -08:00 (committed by TensorFlower Gardener)
Parent: e91d746e0b
Commit: f86747fe82
3 changed files with 16 additions and 22 deletions

tensorflow/core/kernels/conv_grad_ops_3d.cc

@@ -1145,8 +1145,7 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
     TensorShape input_shape;
     if (takes_shape_) {
       const Tensor& input_sizes = context->input(0);
-      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
-                                  input_sizes.vec<int32>(), &input_shape));
+      OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
     } else {
       input_shape = context->input(0).shape();
     }
@@ -1530,8 +1529,7 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
     TensorShape filter_shape;
     if (takes_shape_) {
       const Tensor& filter_sizes = context->input(1);
-      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
-                                  filter_sizes.vec<int32>(), &filter_shape));
+      OP_REQUIRES_OK(context, MakeShape(filter_sizes, &filter_shape));
     } else {
       filter_shape = context->input(1).shape();
     }

tensorflow/python/kernel_tests/conv3d_transpose_test.py

@@ -135,19 +135,20 @@ class Conv3DTransposeTest(test.TestCase):
 
   def testConv3DTransposeOutputShapeType(self):
     # Test case for GitHub issue 18887
-    for dtype in [dtypes.int32]:  # b/126733996 fails with dtypes.int64 in tf2
-      x_shape = [2, 5, 6, 4, 3]
-      y_shape = [2, 5, 6, 4, 2]
-      f_shape = [3, 3, 3, 2, 3]
-      strides = [1, 1, 1, 1, 1]
-      x_value = constant_op.constant(
-          1.0, shape=x_shape, name="x", dtype=dtypes.float32)
-      f_value = constant_op.constant(
-          1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
-      output = nn_ops.conv3d_transpose(
-          x_value, f_value, constant_op.constant(y_shape, dtype=dtype),
-          strides=strides, padding="SAME")
-      self.evaluate(output)
+    for dtype in [dtypes.int32, dtypes.int64]:
+      with self.cached_session():
+        x_shape = [2, 5, 6, 4, 3]
+        y_shape = [2, 5, 6, 4, 2]
+        f_shape = [3, 3, 3, 2, 3]
+        strides = [1, 1, 1, 1, 1]
+        x_value = constant_op.constant(
+            1.0, shape=x_shape, name="x", dtype=dtypes.float32)
+        f_value = constant_op.constant(
+            1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
+        output = nn_ops.conv3d_transpose(
+            x_value, f_value, constant_op.constant(y_shape, dtype=dtype),
+            strides=strides, padding="SAME")
+        self.evaluate(output)
 
   def testConv3DTransposeValid(self):
     with self.cached_session():

tensorflow/python/kernel_tests/functional_ops_test.py

@@ -475,7 +475,6 @@ class FunctionalOpsTest(test.TestCase):
       mul = self.evaluate(remote_op)
       self.assertEqual(mul, [6])
 
-  @test_util.run_deprecated_v1
   def testRemoteFunctionCPUGPU(self):
     if not test_util.is_gpu_available():
       self.skipTest("No GPU available")
@@ -500,7 +499,6 @@ class FunctionalOpsTest(test.TestCase):
       mul = self.evaluate(remote_op)
       self.assertEqual(mul, 9.0)
 
-  @test_util.run_deprecated_v1
   def testRemoteFunctionGPUCPU(self):
     if not test_util.is_gpu_available():
       self.skipTest("No GPU available")
@@ -525,7+523,6 @@ class FunctionalOpsTest(test.TestCase):
      mul = self.evaluate(remote_op)
      self.assertEqual(mul, 9.0)
 
-  @test_util.run_deprecated_v1
   def testRemoteFunctionGPUCPUStrings(self):
     if not test_util.is_gpu_available():
       self.skipTest("No GPU available")
@@ -984,7 +981,6 @@ class PartitionedCallTest(test.TestCase):
                                   constant_op.constant(2.)], f=Body)
       self.assertEqual(output.eval(), 12.)
 
-  @test_util.run_deprecated_v1
   def testBasicMultiDeviceGPU(self):
     if not test_util.is_gpu_available():
       return
@@ -1065,7 +1061,6 @@ class PartitionedCallTest(test.TestCase):
       value = self.evaluate(v.read_value())
       self.assertEqual(value, 2.0)
 
-  @test_util.run_deprecated_v1
   def testFunctionWithResourcesOnDifferentDevices(self):
     if not test_util.is_gpu_available():
       self.skipTest("No GPUs available.")