Use overloaded operators for the assert statement. This should remove the reliance on importing tensorflow in the generated code.

PiperOrigin-RevId: 216528047
This commit is contained in:
Dan Moldovan 2018-10-10 07:38:42 -07:00 committed by TensorFlower Gardener
parent e851764c24
commit 93226f635c
6 changed files with 200 additions and 9 deletions

View File

@ -33,7 +33,7 @@ class AssertTransformer(converter.Base):
# Note: The lone tf.Assert call will be wrapped with control_dependencies
# by side_effect_guards.
template = """
tf.Assert(test, (msg,))
ag__.assert_stmt(test, lambda: msg)
"""
if node.msg is None:

View File

@ -18,24 +18,30 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gast
from tensorflow.python.autograph.converters import asserts
from tensorflow.python.autograph.converters import side_effect_guards
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.platform import test
class AssertsTest(converter_testing.TestCase):
def test_transform(self):
def test_basic(self):
def test_fn(a):
assert a > 0
assert a, 'test message'
return tf.no_op() # pylint:disable=undefined-variable
node, ctx = self.prepare(test_fn, {})
node = asserts.transform(node, ctx)
self.assertTrue(isinstance(node.body[0].value, gast.Call))
with self.converted(test_fn, (asserts, side_effect_guards), {},
gen_control_flow_ops.no_op) as result:
with self.cached_session() as sess:
op = result.test_fn(constant_op.constant(False))
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
'test message'):
sess.run(op)
if __name__ == '__main__':

View File

@ -22,6 +22,7 @@ py_library(
"__init__.py",
"control_flow.py",
"data_structures.py",
"exceptions.py",
"py_builtins.py",
"slices.py",
],
@ -62,6 +63,16 @@ py_test(
],
)
py_test(
name = "exceptions_test",
srcs = ["exceptions_test.py"],
srcs_version = "PY2AND3",
deps = [
":operators",
"//tensorflow/python:client_testlib",
],
)
py_test(
name = "py_builtins_test",
srcs = ["py_builtins_test.py"],

View File

@ -45,6 +45,7 @@ from tensorflow.python.autograph.operators.data_structures import list_stack
from tensorflow.python.autograph.operators.data_structures import ListPopOpts
from tensorflow.python.autograph.operators.data_structures import ListStackOpts
from tensorflow.python.autograph.operators.data_structures import new_list
from tensorflow.python.autograph.operators.exceptions import assert_stmt
from tensorflow.python.autograph.operators.py_builtins import float_
from tensorflow.python.autograph.operators.py_builtins import int_
from tensorflow.python.autograph.operators.py_builtins import len_

View File

@ -0,0 +1,86 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Exception handling statements: assert, etc."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.util import tf_inspect
def assert_stmt(expression1, expression2):
"""Functional form of an assert statement.
This follows the semantics of the Python assert statement, however the
concrete implementations may deviate from it. See the respective
implementation for details.
In general, the assert statement should not be used for control flow.
Furthermore, it is encouraged that the assertion expressions should not have
side effects.
Args:
expression1: Any
expression2: Callable[[], Any], returns the expression to include in the
error message when expression1 evaluates to False. When expression1 is
True, the result of expression2 will not be evaluated, however,
expression2 itself may be evaluated in some implementations.
Returns:
Any, implementation-dependent.
Raises:
ValueError: if any arguments are illegal.
"""
if not callable(expression2):
raise ValueError('{} must be a callable'.format(expression2))
args, _, keywords, _ = tf_inspect.getargspec(expression2)
if args or keywords:
raise ValueError('{} may not have any arguments'.format(expression2))
if tensor_util.is_tensor(expression1):
return _tf_assert_stmt(expression1, expression2)
else:
return _py_assert_stmt(expression1, expression2)
def _tf_assert_stmt(expression1, expression2):
"""Overload of assert_stmt that stages a TF Assert.
This implementation deviates from Python semantics as follows:
(1) the assertion is verified regardless of the state of __debug__
(2) on assertion failure, the graph execution will fail with
tensorflow.errors.ValueError, rather than AssertionError.
Args:
expression1: tensorflow.Tensor, must evaluate to a tf.bool scalar
expression2: Callable[[], Union[tensorflow.Tensor, List[tensorflow.Tensor]]]
Returns:
tensorflow.Operation
"""
expression2_tensors = expression2()
if not isinstance(expression2_tensors, list):
expression2_tensors = [expression2_tensors]
return control_flow_ops.Assert(expression1, expression2_tensors)
def _py_assert_stmt(expression1, expression2):
"""Overload of assert_stmt that executes a Python assert statement."""
assert expression1, expression2()
return None

View File

@ -0,0 +1,87 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for exceptions module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.operators import exceptions
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.platform import test
class ExceptionsTest(test.TestCase):
def test_assert_tf_untriggered(self):
with self.cached_session() as sess:
t = exceptions.assert_stmt(
constant_op.constant(True), lambda: constant_op.constant('ignored'))
sess.run(t)
def test_assert_tf_triggered(self):
with self.cached_session() as sess:
t = exceptions.assert_stmt(
constant_op.constant(False),
lambda: constant_op.constant('test message'))
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
'test message'):
sess.run(t)
def test_assert_tf_multiple_printed_values(self):
two_tensors = [
constant_op.constant('test message'),
constant_op.constant('another message')
]
with self.cached_session() as sess:
t = exceptions.assert_stmt(
constant_op.constant(False), lambda: two_tensors)
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
'test message.*another message'):
sess.run(t)
def test_assert_python_untriggered(self):
side_effect_trace = []
def expression_with_side_effects():
side_effect_trace.append(object())
return 'test message'
exceptions.assert_stmt(True, expression_with_side_effects)
self.assertListEqual(side_effect_trace, [])
def test_assert_python_triggered(self):
if not __debug__:
# Python assertions only be tested when in debug mode.
return
side_effect_trace = []
tracer = object()
def expression_with_side_effects():
side_effect_trace.append(tracer)
return 'test message'
with self.assertRaisesRegexp(AssertionError, 'test message'):
exceptions.assert_stmt(False, expression_with_side_effects)
self.assertListEqual(side_effect_trace, [tracer])
if __name__ == '__main__':
test.main()