diff --git a/tensorflow/python/kernel_tests/save_restore_ops_test.py b/tensorflow/python/kernel_tests/save_restore_ops_test.py index fecc9a3800f..19fd22bd87a 100644 --- a/tensorflow/python/kernel_tests/save_restore_ops_test.py +++ b/tensorflow/python/kernel_tests/save_restore_ops_test.py @@ -23,6 +23,7 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import gen_io_ops from tensorflow.python.ops import io_ops @@ -55,24 +56,24 @@ class ShardedFileOpsTest(test.TestCase): class ShapeInferenceTest(test.TestCase): - @test_util.run_deprecated_v1 def testRestoreV2WithSliceInput(self): - op = io_ops.restore_v2("model", ["var1", "var2"], ["", "3 4 0,1:-"], - [dtypes.float32, dtypes.float32]) - self.assertEqual(2, len(op)) - self.assertFalse(op[0].get_shape().is_fully_defined()) - self.assertEqual([1, 4], op[1].get_shape()) + with ops.Graph().as_default(): + op = io_ops.restore_v2("model", ["var1", "var2"], ["", "3 4 0,1:-"], + [dtypes.float32, dtypes.float32]) + self.assertEqual(2, len(op)) + self.assertFalse(op[0].get_shape().is_fully_defined()) + self.assertEqual([1, 4], op[1].get_shape()) - @test_util.run_deprecated_v1 def testRestoreV2NumSlicesNotMatch(self): - with self.assertRaises(ValueError): - io_ops.restore_v2("model", ["var1", "var2", "var3"], ["", "3 4 0,1:-"], - [dtypes.float32, dtypes.float32]) + with ops.Graph().as_default(): + with self.assertRaises(ValueError): + io_ops.restore_v2("model", ["var1", "var2", "var3"], ["", "3 4 0,1:-"], + [dtypes.float32, dtypes.float32]) - @test_util.run_deprecated_v1 def testRestoreSlice(self): - op = gen_io_ops.restore_slice("model", "var", "3 4 0,1:-", dtypes.float32) - self.assertEqual([1, 4], op.get_shape()) + with ops.Graph().as_default(): + op = gen_io_ops.restore_slice("model", "var", "3 4 0,1:-", dtypes.float32) + self.assertEqual([1, 4], op.get_shape()) if __name__ == "__main__":