Use overloaded operators for the assert statement. This should remove the reliance on importing tensorflow in the generated code.
PiperOrigin-RevId: 216528047
This commit is contained in:
parent
e851764c24
commit
93226f635c
@ -33,7 +33,7 @@ class AssertTransformer(converter.Base):
|
||||
# Note: The lone tf.Assert call will be wrapped with control_dependencies
|
||||
# by side_effect_guards.
|
||||
template = """
|
||||
tf.Assert(test, (msg,))
|
||||
ag__.assert_stmt(test, lambda: msg)
|
||||
"""
|
||||
|
||||
if node.msg is None:
|
||||
|
@ -18,24 +18,30 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gast
|
||||
|
||||
from tensorflow.python.autograph.converters import asserts
|
||||
from tensorflow.python.autograph.converters import side_effect_guards
|
||||
from tensorflow.python.autograph.core import converter_testing
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.ops import gen_control_flow_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class AssertsTest(converter_testing.TestCase):
|
||||
|
||||
def test_transform(self):
|
||||
def test_basic(self):
|
||||
|
||||
def test_fn(a):
|
||||
assert a > 0
|
||||
assert a, 'test message'
|
||||
return tf.no_op() # pylint:disable=undefined-variable
|
||||
|
||||
node, ctx = self.prepare(test_fn, {})
|
||||
node = asserts.transform(node, ctx)
|
||||
|
||||
self.assertTrue(isinstance(node.body[0].value, gast.Call))
|
||||
with self.converted(test_fn, (asserts, side_effect_guards), {},
|
||||
gen_control_flow_ops.no_op) as result:
|
||||
with self.cached_session() as sess:
|
||||
op = result.test_fn(constant_op.constant(False))
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
'test message'):
|
||||
sess.run(op)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -22,6 +22,7 @@ py_library(
|
||||
"__init__.py",
|
||||
"control_flow.py",
|
||||
"data_structures.py",
|
||||
"exceptions.py",
|
||||
"py_builtins.py",
|
||||
"slices.py",
|
||||
],
|
||||
@ -62,6 +63,16 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "exceptions_test",
|
||||
srcs = ["exceptions_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":operators",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "py_builtins_test",
|
||||
srcs = ["py_builtins_test.py"],
|
||||
|
@ -45,6 +45,7 @@ from tensorflow.python.autograph.operators.data_structures import list_stack
|
||||
from tensorflow.python.autograph.operators.data_structures import ListPopOpts
|
||||
from tensorflow.python.autograph.operators.data_structures import ListStackOpts
|
||||
from tensorflow.python.autograph.operators.data_structures import new_list
|
||||
from tensorflow.python.autograph.operators.exceptions import assert_stmt
|
||||
from tensorflow.python.autograph.operators.py_builtins import float_
|
||||
from tensorflow.python.autograph.operators.py_builtins import int_
|
||||
from tensorflow.python.autograph.operators.py_builtins import len_
|
||||
|
86
tensorflow/python/autograph/operators/exceptions.py
Normal file
86
tensorflow/python/autograph/operators/exceptions.py
Normal file
@ -0,0 +1,86 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Exception handling statements: assert, etc."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
|
||||
def assert_stmt(expression1, expression2):
|
||||
"""Functional form of an assert statement.
|
||||
|
||||
This follows the semantics of the Python assert statement, however the
|
||||
concrete implementations may deviate from it. See the respective
|
||||
implementation for details.
|
||||
|
||||
In general, the assert statement should not be used for control flow.
|
||||
Furthermore, it is encouraged that the assertion expressions should not have
|
||||
side effects.
|
||||
|
||||
Args:
|
||||
expression1: Any
|
||||
expression2: Callable[[], Any], returns the expression to include in the
|
||||
error message when expression1 evaluates to False. When expression1 is
|
||||
True, the result of expression2 will not be evaluated, however,
|
||||
expression2 itself may be evaluated in some implementations.
|
||||
|
||||
Returns:
|
||||
Any, implementation-dependent.
|
||||
|
||||
Raises:
|
||||
ValueError: if any arguments are illegal.
|
||||
"""
|
||||
if not callable(expression2):
|
||||
raise ValueError('{} must be a callable'.format(expression2))
|
||||
args, _, keywords, _ = tf_inspect.getargspec(expression2)
|
||||
if args or keywords:
|
||||
raise ValueError('{} may not have any arguments'.format(expression2))
|
||||
|
||||
if tensor_util.is_tensor(expression1):
|
||||
return _tf_assert_stmt(expression1, expression2)
|
||||
else:
|
||||
return _py_assert_stmt(expression1, expression2)
|
||||
|
||||
|
||||
def _tf_assert_stmt(expression1, expression2):
|
||||
"""Overload of assert_stmt that stages a TF Assert.
|
||||
|
||||
This implementation deviates from Python semantics as follows:
|
||||
(1) the assertion is verified regardless of the state of __debug__
|
||||
(2) on assertion failure, the graph execution will fail with
|
||||
tensorflow.errors.ValueError, rather than AssertionError.
|
||||
|
||||
Args:
|
||||
expression1: tensorflow.Tensor, must evaluate to a tf.bool scalar
|
||||
expression2: Callable[[], Union[tensorflow.Tensor, List[tensorflow.Tensor]]]
|
||||
|
||||
Returns:
|
||||
tensorflow.Operation
|
||||
"""
|
||||
expression2_tensors = expression2()
|
||||
if not isinstance(expression2_tensors, list):
|
||||
expression2_tensors = [expression2_tensors]
|
||||
return control_flow_ops.Assert(expression1, expression2_tensors)
|
||||
|
||||
|
||||
def _py_assert_stmt(expression1, expression2):
|
||||
"""Overload of assert_stmt that executes a Python assert statement."""
|
||||
assert expression1, expression2()
|
||||
return None
|
87
tensorflow/python/autograph/operators/exceptions_test.py
Normal file
87
tensorflow/python/autograph/operators/exceptions_test.py
Normal file
@ -0,0 +1,87 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for exceptions module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.operators import exceptions
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ExceptionsTest(test.TestCase):
|
||||
|
||||
def test_assert_tf_untriggered(self):
|
||||
with self.cached_session() as sess:
|
||||
t = exceptions.assert_stmt(
|
||||
constant_op.constant(True), lambda: constant_op.constant('ignored'))
|
||||
sess.run(t)
|
||||
|
||||
def test_assert_tf_triggered(self):
|
||||
with self.cached_session() as sess:
|
||||
t = exceptions.assert_stmt(
|
||||
constant_op.constant(False),
|
||||
lambda: constant_op.constant('test message'))
|
||||
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
'test message'):
|
||||
sess.run(t)
|
||||
|
||||
def test_assert_tf_multiple_printed_values(self):
|
||||
two_tensors = [
|
||||
constant_op.constant('test message'),
|
||||
constant_op.constant('another message')
|
||||
]
|
||||
with self.cached_session() as sess:
|
||||
t = exceptions.assert_stmt(
|
||||
constant_op.constant(False), lambda: two_tensors)
|
||||
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
'test message.*another message'):
|
||||
sess.run(t)
|
||||
|
||||
def test_assert_python_untriggered(self):
|
||||
side_effect_trace = []
|
||||
|
||||
def expression_with_side_effects():
|
||||
side_effect_trace.append(object())
|
||||
return 'test message'
|
||||
|
||||
exceptions.assert_stmt(True, expression_with_side_effects)
|
||||
|
||||
self.assertListEqual(side_effect_trace, [])
|
||||
|
||||
def test_assert_python_triggered(self):
|
||||
if not __debug__:
|
||||
# Python assertions only be tested when in debug mode.
|
||||
return
|
||||
|
||||
side_effect_trace = []
|
||||
tracer = object()
|
||||
|
||||
def expression_with_side_effects():
|
||||
side_effect_trace.append(tracer)
|
||||
return 'test message'
|
||||
|
||||
with self.assertRaisesRegexp(AssertionError, 'test message'):
|
||||
exceptions.assert_stmt(False, expression_with_side_effects)
|
||||
self.assertListEqual(side_effect_trace, [tracer])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
Loading…
Reference in New Issue
Block a user