STT-tensorflow/tensorflow/python/framework/errors_test.py
Yanhua Sun 2800f688ff Add python/lib BUILD and refactor python/BUILD
PiperOrigin-RevId: 346841489
Change-Id: Id8c9e27ccf012c7b9668e7fd56d5d97d3ff09370
2020-12-10 12:42:17 -08:00

152 lines
5.9 KiB
Python

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.framework.errors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gc
import pickle
import warnings
from tensorflow.core.lib.core import error_codes_pb2
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.lib.io import _pywrap_file_io
from tensorflow.python.platform import test
from tensorflow.python.util import compat
class ErrorsTest(test.TestCase):
  """Tests for the exception classes and helpers in `errors`/`errors_impl`."""

  # Error codes exercised by the tests below, each paired with the exception
  # class that errors_impl is expected to map it to. Shared so the per-test
  # lists cannot drift apart.
  _ERROR_CODES_AND_TYPES = (
      (errors.CANCELLED, errors_impl.CancelledError),
      (errors.UNKNOWN, errors_impl.UnknownError),
      (errors.INVALID_ARGUMENT, errors_impl.InvalidArgumentError),
      (errors.DEADLINE_EXCEEDED, errors_impl.DeadlineExceededError),
      (errors.NOT_FOUND, errors_impl.NotFoundError),
      (errors.ALREADY_EXISTS, errors_impl.AlreadyExistsError),
      (errors.PERMISSION_DENIED, errors_impl.PermissionDeniedError),
      (errors.UNAUTHENTICATED, errors_impl.UnauthenticatedError),
      (errors.RESOURCE_EXHAUSTED, errors_impl.ResourceExhaustedError),
      (errors.FAILED_PRECONDITION, errors_impl.FailedPreconditionError),
      (errors.ABORTED, errors_impl.AbortedError),
      (errors.OUT_OF_RANGE, errors_impl.OutOfRangeError),
      (errors.UNIMPLEMENTED, errors_impl.UnimplementedError),
      (errors.INTERNAL, errors_impl.InternalError),
      (errors.UNAVAILABLE, errors_impl.UnavailableError),
      (errors.DATA_LOSS, errors_impl.DataLossError),
  )

  def _CountReferences(self, typeof):
    """Counts the objects tracked by the garbage collector of type |typeof|.

    Args:
      typeof: A type (or tuple of types) accepted by `isinstance`.

    Returns:
      The number of objects in `gc.get_objects()` that are instances of
      `typeof`.
    """
    ref_count = 0
    for obj in gc.get_objects():
      try:
        if isinstance(obj, typeof):
          ref_count += 1
      # Certain versions of python keep a weakref to deleted objects; checking
      # such an object can raise ReferenceError, so it is simply skipped.
      except ReferenceError:
        pass
    return ref_count

  def testUniqueClassForEachErrorCode(self):
    for error_code, exc_type in self._ERROR_CODES_AND_TYPES:
      with self.subTest(error_code=error_code):
        # pylint: disable=protected-access
        exc = errors_impl._make_specific_exception(None, None, None,
                                                   error_code)
        # pylint: enable=protected-access
        self.assertIsInstance(exc, exc_type)
        # error_code_from_exception_type and exception_type_from_error_code
        # should be consistent with the operation result, in both directions.
        self.assertEqual(error_code,
                         errors_impl.error_code_from_exception_type(exc_type))
        self.assertIs(exc_type,
                      errors_impl.exception_type_from_error_code(error_code))

  def testKnownErrorClassForEachErrorCodeInProto(self):
    # Codes in the proto enum that intentionally have no exception class.
    # pylint: disable=line-too-long
    skipped_codes = (
        error_codes_pb2.OK,
        error_codes_pb2.DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_,
    )
    # pylint: enable=line-too-long
    for error_code in error_codes_pb2.Code.values():
      if error_code in skipped_codes:
        continue
      with self.subTest(error_code=error_code):
        with warnings.catch_warnings(record=True) as w:
          # Record every warning regardless of the ambient filters, so the
          # "no warning" assertion below cannot pass just because a filter
          # dropped the warning.
          warnings.simplefilter("always")
          # pylint: disable=protected-access
          exc = errors_impl._make_specific_exception(None, None, None,
                                                     error_code)
          # pylint: enable=protected-access
        self.assertEqual(0, len(w))  # No warning is raised.
        self.assertIsInstance(exc, errors_impl.OpError)
        # The concrete class must derive directly from OpError.
        self.assertIn(errors_impl.OpError, exc.__class__.__bases__)

  def testUnknownErrorCodeCausesWarning(self):
    with warnings.catch_warnings(record=True) as w:
      # Record every warning regardless of the ambient filters.
      warnings.simplefilter("always")
      # pylint: disable=protected-access
      exc = errors_impl._make_specific_exception(None, None, None, 37)
      # pylint: enable=protected-access
    self.assertEqual(1, len(w))
    self.assertIn("Unknown error code: 37", str(w[0].message))
    self.assertIsInstance(exc, errors_impl.OpError)

    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter("always")
      # For an unrecognized class, error_code_from_exception_type warns and
      # returns an OpError instance instead of an integer code.
      result = errors_impl.error_code_from_exception_type("Unknown")
    self.assertEqual(1, len(w))
    self.assertIn("Unknown class exception", str(w[0].message))
    self.assertIsInstance(result, errors_impl.OpError)

  def testStatusDoesNotLeak(self):
    # Deleting a nonexistent path is expected to fail; presumably it raises a
    # registered OpError subclass such as NotFoundError. Only that expected
    # failure is ignored, so unrelated errors (e.g. a changed binding
    # signature) still surface. The test checks that no status object
    # outlives the failed call.
    try:
      _pywrap_file_io.DeleteFile(compat.as_bytes("/DOES_NOT_EXIST/"))
    except errors.OpError:
      pass
    gc.collect()
    self.assertEqual(0, self._CountReferences(c_api_util.ScopedTFStatus))

  def testPickleable(self):
    for error_code, _ in self._ERROR_CODES_AND_TYPES:
      with self.subTest(error_code=error_code):
        # pylint: disable=protected-access
        exc = errors_impl._make_specific_exception(None, None, None,
                                                   error_code)
        # pylint: enable=protected-access
        unpickled = pickle.loads(pickle.dumps(exc))
        # The specific subclass must survive the round-trip, not just the
        # field values.
        self.assertIs(type(exc), type(unpickled))
        self.assertEqual(exc.node_def, unpickled.node_def)
        self.assertEqual(exc.op, unpickled.op)
        self.assertEqual(exc.message, unpickled.message)
        self.assertEqual(exc.error_code, unpickled.error_code)
if __name__ == "__main__":
  # Hand control to the TensorFlow test runner when executed as a script.
  test.main()