diff --git a/tensorflow/python/autograph/converters/logical_expressions.py b/tensorflow/python/autograph/converters/logical_expressions.py index dfcaafdc9eb..ea9740a22e1 100644 --- a/tensorflow/python/autograph/converters/logical_expressions.py +++ b/tensorflow/python/autograph/converters/logical_expressions.py @@ -38,29 +38,29 @@ from tensorflow.python.autograph.pyct import templates SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND' +OP_MAPPING = { + gast.And: 'ag__.and_', + gast.Eq: 'ag__.eq', + gast.NotEq: 'ag__.not_eq', + gast.Lt: 'ag__.lt', + gast.LtE: 'ag__.lt_e', + gast.Gt: 'ag__.gt', + gast.GtE: 'ag__.gt_e', + gast.Is: 'ag__.is_', + gast.IsNot: 'ag__.is_not', + gast.In: 'ag__.in_', + gast.Not: 'ag__.not_', + gast.NotIn: 'ag__.not_in', + gast.Or: 'ag__.or_', + gast.UAdd: 'ag__.u_add', + gast.USub: 'ag__.u_sub', + gast.Invert: 'ag__.invert', +} + + class LogicalExpressionTransformer(converter.Base): """Converts logical expressions to corresponding TF calls.""" - def __init__(self, ctx): - super(LogicalExpressionTransformer, self).__init__(ctx) - # TODO(mdan): For completeness and consistency, overload everything. - self.op_mapping = { - gast.And: 'ag__.and_', - gast.Eq: 'ag__.eq', - gast.NotEq: 'ag__.not_eq', - gast.Lt: 'ag__.lt', - gast.LtE: 'ag__.lt_e', - gast.Gt: 'ag__.gt', - gast.GtE: 'ag__.gt_e', - gast.Is: 'ag__.is_', - gast.IsNot: 'ag__.is_not', - gast.In: 'ag__.in_', - gast.Not: 'ag__.not_', - gast.NotIn: 'ag__.not_in', - gast.Or: 'ag__.or_', - gast.USub: 'ag__.u_sub', - } - def _expect_simple_symbol(self, operand): if isinstance(operand, gast.Name): return @@ -74,11 +74,11 @@ class LogicalExpressionTransformer(converter.Base): def _has_matching_func(self, operator): op_type = type(operator) - return op_type in self.op_mapping + return op_type in OP_MAPPING def _matching_func(self, operator): op_type = type(operator) - return self.op_mapping[op_type] + return OP_MAPPING[op_type] def _as_function(self, func_name, args, args_as_lambda=False): if args_as_lambda: diff --git a/tensorflow/python/autograph/converters/logical_expressions_test.py b/tensorflow/python/autograph/converters/logical_expressions_test.py index 687412750e0..67ccd1fb479 100644 --- a/tensorflow/python/autograph/converters/logical_expressions_test.py +++ b/tensorflow/python/autograph/converters/logical_expressions_test.py @@ -77,6 +77,13 @@ class LogicalExpressionTest(converter_testing.TestCase): with self.converted(test_fn, logical_expressions, {}) as result: self.assertTrue(result.test_fn('a', ('a',))) + def test_unary_ops(self): + def test_fn(a): + return ~a, -a, +a + + with self.converted(test_fn, logical_expressions, {}) as result: + self.assertEqual(result.test_fn(1), (-2, -1, 1)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/autograph/operators/__init__.py b/tensorflow/python/autograph/operators/__init__.py index 7a580fe3247..35f8028c295 100644 --- a/tensorflow/python/autograph/operators/__init__.py +++ b/tensorflow/python/autograph/operators/__init__.py @@ -52,6 +52,7 @@ from tensorflow.python.autograph.operators.logical import eq from tensorflow.python.autograph.operators.logical import gt from tensorflow.python.autograph.operators.logical import gt_e from tensorflow.python.autograph.operators.logical import in_ +from tensorflow.python.autograph.operators.logical import invert from tensorflow.python.autograph.operators.logical import is_ from tensorflow.python.autograph.operators.logical import is_not from tensorflow.python.autograph.operators.logical import lt @@ -60,6 +61,7 @@ from tensorflow.python.autograph.operators.logical import not_ from tensorflow.python.autograph.operators.logical import not_eq from tensorflow.python.autograph.operators.logical import not_in from tensorflow.python.autograph.operators.logical import or_ +from tensorflow.python.autograph.operators.logical import u_add from tensorflow.python.autograph.operators.logical import u_sub from tensorflow.python.autograph.operators.py_builtins import float_ from tensorflow.python.autograph.operators.py_builtins import int_ diff --git a/tensorflow/python/autograph/operators/logical.py b/tensorflow/python/autograph/operators/logical.py index 569db5b91bd..dadb0daf1ae 100644 --- a/tensorflow/python/autograph/operators/logical.py +++ b/tensorflow/python/autograph/operators/logical.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import operator + from tensorflow.python.framework import tensor_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_math_ops @@ -35,7 +37,7 @@ def and_(a, b): a_val = a() if tensor_util.is_tensor(a_val): return _tf_lazy_and(a_val, b) - return _py_lazy_and(a_val, b) + return a_val and b() def _tf_lazy_and(cond, b): @@ -44,17 +46,12 @@ def _tf_lazy_and(cond, b): return control_flow_ops.cond(cond, b, lambda: cond) -def _py_lazy_and(cond, b): - """Lazy-eval equivalent of "and" in Python.""" - return cond and b() - - def or_(a, b): """Functional form of "or". Uses lazy evaluation semantics.""" a_val = a() if tensor_util.is_tensor(a_val): return _tf_lazy_or(a_val, b) - return _py_lazy_or(a_val, b) + return a_val or b() def _tf_lazy_or(cond, b): @@ -63,16 +60,11 @@ def _tf_lazy_or(cond, b): return control_flow_ops.cond(cond, lambda: cond, b) -def _py_lazy_or(cond, b): - """Lazy-eval equivalent of "or" in Python.""" - return cond or b() - - def eq(a, b): """Functional form of "equal".""" if tensor_util.is_tensor(a) or tensor_util.is_tensor(b): return _tf_equal(a, b) - return _py_equal(a, b) + return a == b def _tf_equal(a, b): @@ -80,11 +72,6 @@ def _tf_equal(a, b): return gen_math_ops.equal(a, b) -def _py_equal(a, b): - """Overload of "equal" that falls back to Python's default implementation.""" - return a == b - - def not_eq(a, b): """Functional form of "not-equal".""" return not_(eq(a, b)) @@ -92,25 +79,8 @@ def not_eq(a, b): # Default implementation for the remainings. - -def gt(a, b): - """Functional form of "less-than".""" - return a > b - - -def gt_e(a, b): - """Functional form of "less-than".""" - return a >= b - - -def is_(a, b): - """Functional form of "less-than".""" - return a is b - - -def is_not(a, b): - """Functional form of "less-than".""" - return a is not b +is_ = operator.is_ +is_not = operator.is_not def in_(a, b): @@ -119,21 +89,16 @@ def in_(a, b): return a in b -def lt(a, b): - """Functional form of "less-than".""" - return a < b - - -def lt_e(a, b): - """Functional form of "less-than".""" - return a <= b - - def not_in(a, b): """Functional form of "less-than".""" return a not in b +gt = operator.gt +gt_e = operator.ge +lt = operator.lt +lt_e = operator.le -def u_sub(a): - """Functional form of "unary-sub".""" - return -a + +u_add = operator.pos +u_sub = operator.neg +invert = operator.invert