Add all unary operations to list of autograph-supported ops.

PiperOrigin-RevId: 227681912
This commit is contained in:
A. Unique TensorFlower 2019-01-03 06:51:42 -08:00 committed by TensorFlower Gardener
parent eb54349cb4
commit 31cf2b97f0
4 changed files with 46 additions and 72 deletions

View File

@ -38,29 +38,29 @@ from tensorflow.python.autograph.pyct import templates
SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND' SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND'
OP_MAPPING = {
gast.And: 'ag__.and_',
gast.Eq: 'ag__.eq',
gast.NotEq: 'ag__.not_eq',
gast.Lt: 'ag__.lt',
gast.LtE: 'ag__.lt_e',
gast.Gt: 'ag__.gt',
gast.GtE: 'ag__.gt_e',
gast.Is: 'ag__.is_',
gast.IsNot: 'ag__.is_not',
gast.In: 'ag__.in_',
gast.Not: 'ag__.not_',
gast.NotIn: 'ag__.not_in',
gast.Or: 'ag__.or_',
gast.UAdd: 'ag__.u_add',
gast.USub: 'ag__.u_sub',
gast.Invert: 'ag__.invert',
}
class LogicalExpressionTransformer(converter.Base): class LogicalExpressionTransformer(converter.Base):
"""Converts logical expressions to corresponding TF calls.""" """Converts logical expressions to corresponding TF calls."""
def __init__(self, ctx):
super(LogicalExpressionTransformer, self).__init__(ctx)
# TODO(mdan): For completeness and consistency, overload everything.
self.op_mapping = {
gast.And: 'ag__.and_',
gast.Eq: 'ag__.eq',
gast.NotEq: 'ag__.not_eq',
gast.Lt: 'ag__.lt',
gast.LtE: 'ag__.lt_e',
gast.Gt: 'ag__.gt',
gast.GtE: 'ag__.gt_e',
gast.Is: 'ag__.is_',
gast.IsNot: 'ag__.is_not',
gast.In: 'ag__.in_',
gast.Not: 'ag__.not_',
gast.NotIn: 'ag__.not_in',
gast.Or: 'ag__.or_',
gast.USub: 'ag__.u_sub',
}
def _expect_simple_symbol(self, operand): def _expect_simple_symbol(self, operand):
if isinstance(operand, gast.Name): if isinstance(operand, gast.Name):
return return
@ -74,11 +74,11 @@ class LogicalExpressionTransformer(converter.Base):
def _has_matching_func(self, operator): def _has_matching_func(self, operator):
op_type = type(operator) op_type = type(operator)
return op_type in self.op_mapping return op_type in OP_MAPPING
def _matching_func(self, operator): def _matching_func(self, operator):
op_type = type(operator) op_type = type(operator)
return self.op_mapping[op_type] return OP_MAPPING[op_type]
def _as_function(self, func_name, args, args_as_lambda=False): def _as_function(self, func_name, args, args_as_lambda=False):
if args_as_lambda: if args_as_lambda:

View File

@ -77,6 +77,13 @@ class LogicalExpressionTest(converter_testing.TestCase):
with self.converted(test_fn, logical_expressions, {}) as result: with self.converted(test_fn, logical_expressions, {}) as result:
self.assertTrue(result.test_fn('a', ('a',))) self.assertTrue(result.test_fn('a', ('a',)))
def test_unary_ops(self):
def test_fn(a):
return ~a, -a, +a
with self.converted(test_fn, logical_expressions, {}) as result:
self.assertEqual(result.test_fn(1), (-2, -1, 1))
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()

View File

@ -52,6 +52,7 @@ from tensorflow.python.autograph.operators.logical import eq
from tensorflow.python.autograph.operators.logical import gt from tensorflow.python.autograph.operators.logical import gt
from tensorflow.python.autograph.operators.logical import gt_e from tensorflow.python.autograph.operators.logical import gt_e
from tensorflow.python.autograph.operators.logical import in_ from tensorflow.python.autograph.operators.logical import in_
from tensorflow.python.autograph.operators.logical import invert
from tensorflow.python.autograph.operators.logical import is_ from tensorflow.python.autograph.operators.logical import is_
from tensorflow.python.autograph.operators.logical import is_not from tensorflow.python.autograph.operators.logical import is_not
from tensorflow.python.autograph.operators.logical import lt from tensorflow.python.autograph.operators.logical import lt
@ -60,6 +61,7 @@ from tensorflow.python.autograph.operators.logical import not_
from tensorflow.python.autograph.operators.logical import not_eq from tensorflow.python.autograph.operators.logical import not_eq
from tensorflow.python.autograph.operators.logical import not_in from tensorflow.python.autograph.operators.logical import not_in
from tensorflow.python.autograph.operators.logical import or_ from tensorflow.python.autograph.operators.logical import or_
from tensorflow.python.autograph.operators.logical import u_add
from tensorflow.python.autograph.operators.logical import u_sub from tensorflow.python.autograph.operators.logical import u_sub
from tensorflow.python.autograph.operators.py_builtins import float_ from tensorflow.python.autograph.operators.py_builtins import float_
from tensorflow.python.autograph.operators.py_builtins import int_ from tensorflow.python.autograph.operators.py_builtins import int_

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import operator
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_math_ops
@ -35,7 +37,7 @@ def and_(a, b):
a_val = a() a_val = a()
if tensor_util.is_tensor(a_val): if tensor_util.is_tensor(a_val):
return _tf_lazy_and(a_val, b) return _tf_lazy_and(a_val, b)
return _py_lazy_and(a_val, b) return a_val and b()
def _tf_lazy_and(cond, b): def _tf_lazy_and(cond, b):
@ -44,17 +46,12 @@ def _tf_lazy_and(cond, b):
return control_flow_ops.cond(cond, b, lambda: cond) return control_flow_ops.cond(cond, b, lambda: cond)
def _py_lazy_and(cond, b):
"""Lazy-eval equivalent of "and" in Python."""
return cond and b()
def or_(a, b): def or_(a, b):
"""Functional form of "or". Uses lazy evaluation semantics.""" """Functional form of "or". Uses lazy evaluation semantics."""
a_val = a() a_val = a()
if tensor_util.is_tensor(a_val): if tensor_util.is_tensor(a_val):
return _tf_lazy_or(a_val, b) return _tf_lazy_or(a_val, b)
return _py_lazy_or(a_val, b) return a_val or b()
def _tf_lazy_or(cond, b): def _tf_lazy_or(cond, b):
@ -63,16 +60,11 @@ def _tf_lazy_or(cond, b):
return control_flow_ops.cond(cond, lambda: cond, b) return control_flow_ops.cond(cond, lambda: cond, b)
def _py_lazy_or(cond, b):
"""Lazy-eval equivalent of "or" in Python."""
return cond or b()
def eq(a, b): def eq(a, b):
"""Functional form of "equal".""" """Functional form of "equal"."""
if tensor_util.is_tensor(a) or tensor_util.is_tensor(b): if tensor_util.is_tensor(a) or tensor_util.is_tensor(b):
return _tf_equal(a, b) return _tf_equal(a, b)
return _py_equal(a, b) return a == b
def _tf_equal(a, b): def _tf_equal(a, b):
@ -80,11 +72,6 @@ def _tf_equal(a, b):
return gen_math_ops.equal(a, b) return gen_math_ops.equal(a, b)
def _py_equal(a, b):
"""Overload of "equal" that falls back to Python's default implementation."""
return a == b
def not_eq(a, b): def not_eq(a, b):
"""Functional form of "not-equal".""" """Functional form of "not-equal"."""
return not_(eq(a, b)) return not_(eq(a, b))
@ -92,25 +79,8 @@ def not_eq(a, b):
# Default implementation for the remainings. # Default implementation for the remainings.
is_ = operator.is_
def gt(a, b): is_not = operator.is_not
"""Functional form of "less-than"."""
return a > b
def gt_e(a, b):
"""Functional form of "less-than"."""
return a >= b
def is_(a, b):
"""Functional form of "less-than"."""
return a is b
def is_not(a, b):
"""Functional form of "less-than"."""
return a is not b
def in_(a, b): def in_(a, b):
@ -119,21 +89,16 @@ def in_(a, b):
return a in b return a in b
def lt(a, b):
"""Functional form of "less-than"."""
return a < b
def lt_e(a, b):
"""Functional form of "less-than"."""
return a <= b
def not_in(a, b): def not_in(a, b):
"""Functional form of "less-than".""" """Functional form of "less-than"."""
return a not in b return a not in b
gt = operator.gt
gt_e = operator.ge
lt = operator.lt
lt_e = operator.le
def u_sub(a):
"""Functional form of "unary-sub".""" u_add = operator.pos
return -a u_sub = operator.neg
invert = operator.invert