Add all unary operations to list of autograph-supported ops.
PiperOrigin-RevId: 227681912
This commit is contained in:
parent
eb54349cb4
commit
31cf2b97f0
@ -38,29 +38,29 @@ from tensorflow.python.autograph.pyct import templates
|
||||
SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND'
|
||||
|
||||
|
||||
OP_MAPPING = {
|
||||
gast.And: 'ag__.and_',
|
||||
gast.Eq: 'ag__.eq',
|
||||
gast.NotEq: 'ag__.not_eq',
|
||||
gast.Lt: 'ag__.lt',
|
||||
gast.LtE: 'ag__.lt_e',
|
||||
gast.Gt: 'ag__.gt',
|
||||
gast.GtE: 'ag__.gt_e',
|
||||
gast.Is: 'ag__.is_',
|
||||
gast.IsNot: 'ag__.is_not',
|
||||
gast.In: 'ag__.in_',
|
||||
gast.Not: 'ag__.not_',
|
||||
gast.NotIn: 'ag__.not_in',
|
||||
gast.Or: 'ag__.or_',
|
||||
gast.UAdd: 'ag__.u_add',
|
||||
gast.USub: 'ag__.u_sub',
|
||||
gast.Invert: 'ag__.invert',
|
||||
}
|
||||
|
||||
|
||||
class LogicalExpressionTransformer(converter.Base):
|
||||
"""Converts logical expressions to corresponding TF calls."""
|
||||
|
||||
def __init__(self, ctx):
|
||||
super(LogicalExpressionTransformer, self).__init__(ctx)
|
||||
# TODO(mdan): For completeness and consistency, overload everything.
|
||||
self.op_mapping = {
|
||||
gast.And: 'ag__.and_',
|
||||
gast.Eq: 'ag__.eq',
|
||||
gast.NotEq: 'ag__.not_eq',
|
||||
gast.Lt: 'ag__.lt',
|
||||
gast.LtE: 'ag__.lt_e',
|
||||
gast.Gt: 'ag__.gt',
|
||||
gast.GtE: 'ag__.gt_e',
|
||||
gast.Is: 'ag__.is_',
|
||||
gast.IsNot: 'ag__.is_not',
|
||||
gast.In: 'ag__.in_',
|
||||
gast.Not: 'ag__.not_',
|
||||
gast.NotIn: 'ag__.not_in',
|
||||
gast.Or: 'ag__.or_',
|
||||
gast.USub: 'ag__.u_sub',
|
||||
}
|
||||
|
||||
def _expect_simple_symbol(self, operand):
|
||||
if isinstance(operand, gast.Name):
|
||||
return
|
||||
@ -74,11 +74,11 @@ class LogicalExpressionTransformer(converter.Base):
|
||||
|
||||
def _has_matching_func(self, operator):
|
||||
op_type = type(operator)
|
||||
return op_type in self.op_mapping
|
||||
return op_type in OP_MAPPING
|
||||
|
||||
def _matching_func(self, operator):
|
||||
op_type = type(operator)
|
||||
return self.op_mapping[op_type]
|
||||
return OP_MAPPING[op_type]
|
||||
|
||||
def _as_function(self, func_name, args, args_as_lambda=False):
|
||||
if args_as_lambda:
|
||||
|
@ -77,6 +77,13 @@ class LogicalExpressionTest(converter_testing.TestCase):
|
||||
with self.converted(test_fn, logical_expressions, {}) as result:
|
||||
self.assertTrue(result.test_fn('a', ('a',)))
|
||||
|
||||
def test_unary_ops(self):
|
||||
def test_fn(a):
|
||||
return ~a, -a, +a
|
||||
|
||||
with self.converted(test_fn, logical_expressions, {}) as result:
|
||||
self.assertEqual(result.test_fn(1), (-2, -1, 1))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -52,6 +52,7 @@ from tensorflow.python.autograph.operators.logical import eq
|
||||
from tensorflow.python.autograph.operators.logical import gt
|
||||
from tensorflow.python.autograph.operators.logical import gt_e
|
||||
from tensorflow.python.autograph.operators.logical import in_
|
||||
from tensorflow.python.autograph.operators.logical import invert
|
||||
from tensorflow.python.autograph.operators.logical import is_
|
||||
from tensorflow.python.autograph.operators.logical import is_not
|
||||
from tensorflow.python.autograph.operators.logical import lt
|
||||
@ -60,6 +61,7 @@ from tensorflow.python.autograph.operators.logical import not_
|
||||
from tensorflow.python.autograph.operators.logical import not_eq
|
||||
from tensorflow.python.autograph.operators.logical import not_in
|
||||
from tensorflow.python.autograph.operators.logical import or_
|
||||
from tensorflow.python.autograph.operators.logical import u_add
|
||||
from tensorflow.python.autograph.operators.logical import u_sub
|
||||
from tensorflow.python.autograph.operators.py_builtins import float_
|
||||
from tensorflow.python.autograph.operators.py_builtins import int_
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import operator
|
||||
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
@ -35,7 +37,7 @@ def and_(a, b):
|
||||
a_val = a()
|
||||
if tensor_util.is_tensor(a_val):
|
||||
return _tf_lazy_and(a_val, b)
|
||||
return _py_lazy_and(a_val, b)
|
||||
return a_val and b()
|
||||
|
||||
|
||||
def _tf_lazy_and(cond, b):
|
||||
@ -44,17 +46,12 @@ def _tf_lazy_and(cond, b):
|
||||
return control_flow_ops.cond(cond, b, lambda: cond)
|
||||
|
||||
|
||||
def _py_lazy_and(cond, b):
|
||||
"""Lazy-eval equivalent of "and" in Python."""
|
||||
return cond and b()
|
||||
|
||||
|
||||
def or_(a, b):
|
||||
"""Functional form of "or". Uses lazy evaluation semantics."""
|
||||
a_val = a()
|
||||
if tensor_util.is_tensor(a_val):
|
||||
return _tf_lazy_or(a_val, b)
|
||||
return _py_lazy_or(a_val, b)
|
||||
return a_val or b()
|
||||
|
||||
|
||||
def _tf_lazy_or(cond, b):
|
||||
@ -63,16 +60,11 @@ def _tf_lazy_or(cond, b):
|
||||
return control_flow_ops.cond(cond, lambda: cond, b)
|
||||
|
||||
|
||||
def _py_lazy_or(cond, b):
|
||||
"""Lazy-eval equivalent of "or" in Python."""
|
||||
return cond or b()
|
||||
|
||||
|
||||
def eq(a, b):
|
||||
"""Functional form of "equal"."""
|
||||
if tensor_util.is_tensor(a) or tensor_util.is_tensor(b):
|
||||
return _tf_equal(a, b)
|
||||
return _py_equal(a, b)
|
||||
return a == b
|
||||
|
||||
|
||||
def _tf_equal(a, b):
|
||||
@ -80,11 +72,6 @@ def _tf_equal(a, b):
|
||||
return gen_math_ops.equal(a, b)
|
||||
|
||||
|
||||
def _py_equal(a, b):
|
||||
"""Overload of "equal" that falls back to Python's default implementation."""
|
||||
return a == b
|
||||
|
||||
|
||||
def not_eq(a, b):
|
||||
"""Functional form of "not-equal"."""
|
||||
return not_(eq(a, b))
|
||||
@ -92,25 +79,8 @@ def not_eq(a, b):
|
||||
|
||||
# Default implementation for the remainings.
|
||||
|
||||
|
||||
def gt(a, b):
|
||||
"""Functional form of "less-than"."""
|
||||
return a > b
|
||||
|
||||
|
||||
def gt_e(a, b):
|
||||
"""Functional form of "less-than"."""
|
||||
return a >= b
|
||||
|
||||
|
||||
def is_(a, b):
|
||||
"""Functional form of "less-than"."""
|
||||
return a is b
|
||||
|
||||
|
||||
def is_not(a, b):
|
||||
"""Functional form of "less-than"."""
|
||||
return a is not b
|
||||
is_ = operator.is_
|
||||
is_not = operator.is_not
|
||||
|
||||
|
||||
def in_(a, b):
|
||||
@ -119,21 +89,16 @@ def in_(a, b):
|
||||
return a in b
|
||||
|
||||
|
||||
def lt(a, b):
|
||||
"""Functional form of "less-than"."""
|
||||
return a < b
|
||||
|
||||
|
||||
def lt_e(a, b):
|
||||
"""Functional form of "less-than"."""
|
||||
return a <= b
|
||||
|
||||
|
||||
def not_in(a, b):
|
||||
"""Functional form of "less-than"."""
|
||||
return a not in b
|
||||
|
||||
gt = operator.gt
|
||||
gt_e = operator.ge
|
||||
lt = operator.lt
|
||||
lt_e = operator.le
|
||||
|
||||
def u_sub(a):
|
||||
"""Functional form of "unary-sub"."""
|
||||
return -a
|
||||
|
||||
u_add = operator.pos
|
||||
u_sub = operator.neg
|
||||
invert = operator.invert
|
||||
|
Loading…
Reference in New Issue
Block a user