STT-tensorflow/tensorflow/python/util/tf_should_use_test.py
Gaurav Jain f618ab4955 Move away from deprecated asserts
- assertEquals -> assertEqual
- assertRaisesRegexp -> assertRaisesRegex
- assertRegexpMatches -> assertRegex

PiperOrigin-RevId: 319118081
Change-Id: Ieb457128522920ab55d6b69a7f244ab798a7d689
2020-06-30 16:10:22 -07:00

144 lines
4.9 KiB
Python

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for tf_should_use."""
# pylint: disable=unused-import
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import gc
import sys
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.python.util import tf_should_use
@contextlib.contextmanager
def reroute_error():
  """Temporarily replaces `tf_should_use.tf_logging.error` with a mock.

  The patch is active only while the `with` block runs.

  Yields:
    The mock standing in for `tf_logging.error`, so the caller can inspect
    how (or whether) it was called.
  """
  patcher = test.mock.patch.object(tf_should_use.tf_logging, 'error')
  with patcher as mock_error:
    yield mock_error
class TfShouldUseTest(test.TestCase):
  """Tests for the unused-result reporting implemented in `tf_should_use`."""

  def _logged_message(self, error):
    """Returns the text passed to the most recent mocked error-log call.

    Args:
      error: The mock yielded by `reroute_error()`.

    Returns:
      The positional arguments of the last call, joined with newlines.
    """
    # `call_args` is None when the mock was never called; assert first so a
    # missing log entry fails with a readable message instead of a TypeError.
    self.assertTrue(error.called, 'tf_logging.error was never called')
    return '\n'.join(error.call_args[0])

  def _assert_no_garbage(self):
    """Runs a full collection and asserts that `gc.garbage` is empty."""
    # Collect first so the check reflects objects created by the current test.
    gc.collect()
    self.assertFalse(gc.garbage)

  def testAddShouldUseWarningWhenNotUsed(self):
    """A wrapped tensor deleted without use logs an 'never used' error."""
    c = constant_op.constant(0, name='blah0')

    def in_this_function():
      h = tf_should_use._add_should_use_warning(c, warn_in_eager=True)
      del h

    with reroute_error() as error:
      in_this_function()
    msg = self._logged_message(error)
    self.assertIn('Object was never used', msg)
    # The tensor name is only checked outside eager execution.
    if not context.executing_eagerly():
      self.assertIn('blah0:0', msg)
    self.assertIn('in_this_function', msg)
    self._assert_no_garbage()

  def testAddShouldUseExceptionInEagerAndFunction(self):
    """With error_in_function=True, eager logs and a tf.function raises."""

    def in_this_function():
      c = constant_op.constant(0, name='blah0')
      h = tf_should_use._add_should_use_warning(
          c, warn_in_eager=True, error_in_function=True)
      del h

    if context.executing_eagerly():
      with reroute_error() as error:
        in_this_function()
      msg = self._logged_message(error)
      self.assertIn('Object was never used', msg)
      self.assertIn('in_this_function', msg)
      self._assert_no_garbage()

    tf_fn_in_this_function = def_function.function(in_this_function)
    with self.assertRaisesRegex(RuntimeError,
                                r'Object was never used.*blah0:0'):
      tf_fn_in_this_function()
    self._assert_no_garbage()

  def _testAddShouldUseWarningWhenUsed(self, fn, name):
    """Checks that no error is logged when `fn` consumes the wrapped tensor.

    Args:
      fn: Callable invoked with the wrapped tensor.
      name: Name given to the constant that gets wrapped.
    """
    c = constant_op.constant(0, name=name)
    with reroute_error() as error:
      h = tf_should_use._add_should_use_warning(c, warn_in_eager=True)
      fn(h)
      del h
    error.assert_not_called()

  def testAddShouldUseWarningWhenUsedWithAdd(self):
    """Adding to the wrapped tensor counts as using it."""

    def add(h):
      _ = h + 1

    self._testAddShouldUseWarningWhenUsed(add, name='blah_add')
    self._assert_no_garbage()

  def testAddShouldUseWarningWhenUsedWithGetShape(self):
    """Reading `.shape` on the wrapped tensor counts as using it."""

    def get_shape(h):
      _ = h.shape

    self._testAddShouldUseWarningWhenUsed(get_shape, name='blah_get_name')
    self._assert_no_garbage()

  def testShouldUseResult(self):
    """Discarding the result of a decorated function logs an error."""

    @tf_should_use.should_use_result(warn_in_eager=True)
    def return_const(value):
      return constant_op.constant(value, name='blah2')

    with reroute_error() as error:
      return_const(0.0)
    msg = self._logged_message(error)
    self.assertIn('Object was never used', msg)
    if not context.executing_eagerly():
      self.assertIn('blah2:0', msg)
    self.assertIn('return_const', msg)
    self._assert_no_garbage()

  def testShouldUseResultWhenNotReallyUsed(self):
    """Evaluating an unrelated op does not mark the discarded result used."""

    @tf_should_use.should_use_result(warn_in_eager=True)
    def return_const(value):
      return constant_op.constant(value, name='blah3')

    with reroute_error() as error:
      with self.cached_session():
        return_const(0.0)
        # Creating another op and executing it does not mark the
        # unused op as being "used".
        v = constant_op.constant(1.0, name='meh')
        self.evaluate(v)
    msg = self._logged_message(error)
    self.assertIn('Object was never used', msg)
    if not context.executing_eagerly():
      self.assertIn('blah3:0', msg)
    self.assertIn('return_const', msg)
    self._assert_no_garbage()

  def testMarkUsed(self):
    """Tests that mark_used is available in the API."""

    @tf_should_use.should_use_result(warn_in_eager=True)
    def return_const(value):
      return constant_op.constant(value, name='blah3')

    with self.cached_session():
      return_const(0.0).mark_used()
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
  test.main()