Add all unary operations to list of autograph-supported ops.

PiperOrigin-RevId: 227681912
This commit is contained in:
A. Unique TensorFlower 2019-01-03 06:51:42 -08:00 committed by TensorFlower Gardener
parent eb54349cb4
commit 31cf2b97f0
4 changed files with 46 additions and 72 deletions

View File

@ -38,29 +38,29 @@ from tensorflow.python.autograph.pyct import templates
SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND'
OP_MAPPING = {
gast.And: 'ag__.and_',
gast.Eq: 'ag__.eq',
gast.NotEq: 'ag__.not_eq',
gast.Lt: 'ag__.lt',
gast.LtE: 'ag__.lt_e',
gast.Gt: 'ag__.gt',
gast.GtE: 'ag__.gt_e',
gast.Is: 'ag__.is_',
gast.IsNot: 'ag__.is_not',
gast.In: 'ag__.in_',
gast.Not: 'ag__.not_',
gast.NotIn: 'ag__.not_in',
gast.Or: 'ag__.or_',
gast.UAdd: 'ag__.u_add',
gast.USub: 'ag__.u_sub',
gast.Invert: 'ag__.invert',
}
class LogicalExpressionTransformer(converter.Base):
"""Converts logical expressions to corresponding TF calls."""
def __init__(self, ctx):
super(LogicalExpressionTransformer, self).__init__(ctx)
# TODO(mdan): For completeness and consistency, overload everything.
self.op_mapping = {
gast.And: 'ag__.and_',
gast.Eq: 'ag__.eq',
gast.NotEq: 'ag__.not_eq',
gast.Lt: 'ag__.lt',
gast.LtE: 'ag__.lt_e',
gast.Gt: 'ag__.gt',
gast.GtE: 'ag__.gt_e',
gast.Is: 'ag__.is_',
gast.IsNot: 'ag__.is_not',
gast.In: 'ag__.in_',
gast.Not: 'ag__.not_',
gast.NotIn: 'ag__.not_in',
gast.Or: 'ag__.or_',
gast.USub: 'ag__.u_sub',
}
def _expect_simple_symbol(self, operand):
if isinstance(operand, gast.Name):
return
@ -74,11 +74,11 @@ class LogicalExpressionTransformer(converter.Base):
def _has_matching_func(self, operator):
op_type = type(operator)
return op_type in self.op_mapping
return op_type in OP_MAPPING
def _matching_func(self, operator):
op_type = type(operator)
return self.op_mapping[op_type]
return OP_MAPPING[op_type]
def _as_function(self, func_name, args, args_as_lambda=False):
if args_as_lambda:

View File

@ -77,6 +77,13 @@ class LogicalExpressionTest(converter_testing.TestCase):
with self.converted(test_fn, logical_expressions, {}) as result:
self.assertTrue(result.test_fn('a', ('a',)))
def test_unary_ops(self):
def test_fn(a):
return ~a, -a, +a
with self.converted(test_fn, logical_expressions, {}) as result:
self.assertEqual(result.test_fn(1), (-2, -1, 1))
if __name__ == '__main__':
test.main()

View File

@ -52,6 +52,7 @@ from tensorflow.python.autograph.operators.logical import eq
from tensorflow.python.autograph.operators.logical import gt
from tensorflow.python.autograph.operators.logical import gt_e
from tensorflow.python.autograph.operators.logical import in_
from tensorflow.python.autograph.operators.logical import invert
from tensorflow.python.autograph.operators.logical import is_
from tensorflow.python.autograph.operators.logical import is_not
from tensorflow.python.autograph.operators.logical import lt
@ -60,6 +61,7 @@ from tensorflow.python.autograph.operators.logical import not_
from tensorflow.python.autograph.operators.logical import not_eq
from tensorflow.python.autograph.operators.logical import not_in
from tensorflow.python.autograph.operators.logical import or_
from tensorflow.python.autograph.operators.logical import u_add
from tensorflow.python.autograph.operators.logical import u_sub
from tensorflow.python.autograph.operators.py_builtins import float_
from tensorflow.python.autograph.operators.py_builtins import int_

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import operator
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
@ -35,7 +37,7 @@ def and_(a, b):
a_val = a()
if tensor_util.is_tensor(a_val):
return _tf_lazy_and(a_val, b)
return _py_lazy_and(a_val, b)
return a_val and b()
def _tf_lazy_and(cond, b):
@ -44,17 +46,12 @@ def _tf_lazy_and(cond, b):
return control_flow_ops.cond(cond, b, lambda: cond)
def _py_lazy_and(cond, b):
"""Lazy-eval equivalent of "and" in Python."""
return cond and b()
def or_(a, b):
"""Functional form of "or". Uses lazy evaluation semantics."""
a_val = a()
if tensor_util.is_tensor(a_val):
return _tf_lazy_or(a_val, b)
return _py_lazy_or(a_val, b)
return a_val or b()
def _tf_lazy_or(cond, b):
@ -63,16 +60,11 @@ def _tf_lazy_or(cond, b):
return control_flow_ops.cond(cond, lambda: cond, b)
def _py_lazy_or(cond, b):
"""Lazy-eval equivalent of "or" in Python."""
return cond or b()
def eq(a, b):
"""Functional form of "equal"."""
if tensor_util.is_tensor(a) or tensor_util.is_tensor(b):
return _tf_equal(a, b)
return _py_equal(a, b)
return a == b
def _tf_equal(a, b):
@ -80,11 +72,6 @@ def _tf_equal(a, b):
return gen_math_ops.equal(a, b)
def _py_equal(a, b):
"""Overload of "equal" that falls back to Python's default implementation."""
return a == b
def not_eq(a, b):
"""Functional form of "not-equal"."""
return not_(eq(a, b))
@ -92,25 +79,8 @@ def not_eq(a, b):
# Default implementation for the remainings.
def gt(a, b):
"""Functional form of "less-than"."""
return a > b
def gt_e(a, b):
"""Functional form of "less-than"."""
return a >= b
def is_(a, b):
"""Functional form of "less-than"."""
return a is b
def is_not(a, b):
"""Functional form of "less-than"."""
return a is not b
is_ = operator.is_
is_not = operator.is_not
def in_(a, b):
@ -119,21 +89,16 @@ def in_(a, b):
return a in b
def lt(a, b):
"""Functional form of "less-than"."""
return a < b
def lt_e(a, b):
"""Functional form of "less-than"."""
return a <= b
def not_in(a, b):
"""Functional form of "less-than"."""
return a not in b
gt = operator.gt
gt_e = operator.ge
lt = operator.lt
lt_e = operator.le
def u_sub(a):
"""Functional form of "unary-sub"."""
return -a
u_add = operator.pos
u_sub = operator.neg
invert = operator.invert