Makes failing tests pass in tf2
PiperOrigin-RevId: 236162595
This commit is contained in:
parent
21ec9b010d
commit
b9f81dfea1
@ -135,20 +135,19 @@ class Conv3DTransposeTest(test.TestCase):
|
|||||||
|
|
||||||
def testConv3DTransposeOutputShapeType(self):
|
def testConv3DTransposeOutputShapeType(self):
|
||||||
# Test case for GitHub issue 18887
|
# Test case for GitHub issue 18887
|
||||||
for dtype in [dtypes.int32, dtypes.int64]:
|
for dtype in [dtypes.int32]: # b/126733996 fails with dtypes.int64 in tf2
|
||||||
with self.cached_session():
|
x_shape = [2, 5, 6, 4, 3]
|
||||||
x_shape = [2, 5, 6, 4, 3]
|
y_shape = [2, 5, 6, 4, 2]
|
||||||
y_shape = [2, 5, 6, 4, 2]
|
f_shape = [3, 3, 3, 2, 3]
|
||||||
f_shape = [3, 3, 3, 2, 3]
|
strides = [1, 1, 1, 1, 1]
|
||||||
strides = [1, 1, 1, 1, 1]
|
x_value = constant_op.constant(
|
||||||
x_value = constant_op.constant(
|
1.0, shape=x_shape, name="x", dtype=dtypes.float32)
|
||||||
1.0, shape=x_shape, name="x", dtype=dtypes.float32)
|
f_value = constant_op.constant(
|
||||||
f_value = constant_op.constant(
|
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
|
||||||
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
|
output = nn_ops.conv3d_transpose(
|
||||||
output = nn_ops.conv3d_transpose(
|
x_value, f_value, constant_op.constant(y_shape, dtype=dtype),
|
||||||
x_value, f_value, constant_op.constant(y_shape, dtype=dtype),
|
strides=strides, padding="SAME")
|
||||||
strides=strides, padding="SAME")
|
self.evaluate(output)
|
||||||
self.evaluate(output)
|
|
||||||
|
|
||||||
def testConv3DTransposeValid(self):
|
def testConv3DTransposeValid(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
|
@ -475,6 +475,7 @@ class FunctionalOpsTest(test.TestCase):
|
|||||||
mul = self.evaluate(remote_op)
|
mul = self.evaluate(remote_op)
|
||||||
self.assertEqual(mul, [6])
|
self.assertEqual(mul, [6])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testRemoteFunctionCPUGPU(self):
|
def testRemoteFunctionCPUGPU(self):
|
||||||
if not test_util.is_gpu_available():
|
if not test_util.is_gpu_available():
|
||||||
self.skipTest("No GPU available")
|
self.skipTest("No GPU available")
|
||||||
@ -499,6 +500,7 @@ class FunctionalOpsTest(test.TestCase):
|
|||||||
mul = self.evaluate(remote_op)
|
mul = self.evaluate(remote_op)
|
||||||
self.assertEqual(mul, 9.0)
|
self.assertEqual(mul, 9.0)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testRemoteFunctionGPUCPU(self):
|
def testRemoteFunctionGPUCPU(self):
|
||||||
if not test_util.is_gpu_available():
|
if not test_util.is_gpu_available():
|
||||||
self.skipTest("No GPU available")
|
self.skipTest("No GPU available")
|
||||||
@ -523,6 +525,7 @@ class FunctionalOpsTest(test.TestCase):
|
|||||||
mul = self.evaluate(remote_op)
|
mul = self.evaluate(remote_op)
|
||||||
self.assertEqual(mul, 9.0)
|
self.assertEqual(mul, 9.0)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testRemoteFunctionGPUCPUStrings(self):
|
def testRemoteFunctionGPUCPUStrings(self):
|
||||||
if not test_util.is_gpu_available():
|
if not test_util.is_gpu_available():
|
||||||
self.skipTest("No GPU available")
|
self.skipTest("No GPU available")
|
||||||
@ -981,6 +984,7 @@ class PartitionedCallTest(test.TestCase):
|
|||||||
constant_op.constant(2.)], f=Body)
|
constant_op.constant(2.)], f=Body)
|
||||||
self.assertEqual(output.eval(), 12.)
|
self.assertEqual(output.eval(), 12.)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testBasicMultiDeviceGPU(self):
|
def testBasicMultiDeviceGPU(self):
|
||||||
if not test_util.is_gpu_available():
|
if not test_util.is_gpu_available():
|
||||||
return
|
return
|
||||||
@ -1061,6 +1065,7 @@ class PartitionedCallTest(test.TestCase):
|
|||||||
value = self.evaluate(v.read_value())
|
value = self.evaluate(v.read_value())
|
||||||
self.assertEqual(value, 2.0)
|
self.assertEqual(value, 2.0)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFunctionWithResourcesOnDifferentDevices(self):
|
def testFunctionWithResourcesOnDifferentDevices(self):
|
||||||
if not test_util.is_gpu_available():
|
if not test_util.is_gpu_available():
|
||||||
self.skipTest("No GPUs available.")
|
self.skipTest("No GPUs available.")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user