Add all unary operations to list of autograph-supported ops.
PiperOrigin-RevId: 227681912
This commit is contained in:
parent
eb54349cb4
commit
31cf2b97f0
tensorflow/python/autograph
@ -38,29 +38,29 @@ from tensorflow.python.autograph.pyct import templates
|
|||||||
SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND'
|
SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND'
|
||||||
|
|
||||||
|
|
||||||
|
OP_MAPPING = {
|
||||||
|
gast.And: 'ag__.and_',
|
||||||
|
gast.Eq: 'ag__.eq',
|
||||||
|
gast.NotEq: 'ag__.not_eq',
|
||||||
|
gast.Lt: 'ag__.lt',
|
||||||
|
gast.LtE: 'ag__.lt_e',
|
||||||
|
gast.Gt: 'ag__.gt',
|
||||||
|
gast.GtE: 'ag__.gt_e',
|
||||||
|
gast.Is: 'ag__.is_',
|
||||||
|
gast.IsNot: 'ag__.is_not',
|
||||||
|
gast.In: 'ag__.in_',
|
||||||
|
gast.Not: 'ag__.not_',
|
||||||
|
gast.NotIn: 'ag__.not_in',
|
||||||
|
gast.Or: 'ag__.or_',
|
||||||
|
gast.UAdd: 'ag__.u_add',
|
||||||
|
gast.USub: 'ag__.u_sub',
|
||||||
|
gast.Invert: 'ag__.invert',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class LogicalExpressionTransformer(converter.Base):
|
class LogicalExpressionTransformer(converter.Base):
|
||||||
"""Converts logical expressions to corresponding TF calls."""
|
"""Converts logical expressions to corresponding TF calls."""
|
||||||
|
|
||||||
def __init__(self, ctx):
|
|
||||||
super(LogicalExpressionTransformer, self).__init__(ctx)
|
|
||||||
# TODO(mdan): For completeness and consistency, overload everything.
|
|
||||||
self.op_mapping = {
|
|
||||||
gast.And: 'ag__.and_',
|
|
||||||
gast.Eq: 'ag__.eq',
|
|
||||||
gast.NotEq: 'ag__.not_eq',
|
|
||||||
gast.Lt: 'ag__.lt',
|
|
||||||
gast.LtE: 'ag__.lt_e',
|
|
||||||
gast.Gt: 'ag__.gt',
|
|
||||||
gast.GtE: 'ag__.gt_e',
|
|
||||||
gast.Is: 'ag__.is_',
|
|
||||||
gast.IsNot: 'ag__.is_not',
|
|
||||||
gast.In: 'ag__.in_',
|
|
||||||
gast.Not: 'ag__.not_',
|
|
||||||
gast.NotIn: 'ag__.not_in',
|
|
||||||
gast.Or: 'ag__.or_',
|
|
||||||
gast.USub: 'ag__.u_sub',
|
|
||||||
}
|
|
||||||
|
|
||||||
def _expect_simple_symbol(self, operand):
|
def _expect_simple_symbol(self, operand):
|
||||||
if isinstance(operand, gast.Name):
|
if isinstance(operand, gast.Name):
|
||||||
return
|
return
|
||||||
@ -74,11 +74,11 @@ class LogicalExpressionTransformer(converter.Base):
|
|||||||
|
|
||||||
def _has_matching_func(self, operator):
|
def _has_matching_func(self, operator):
|
||||||
op_type = type(operator)
|
op_type = type(operator)
|
||||||
return op_type in self.op_mapping
|
return op_type in OP_MAPPING
|
||||||
|
|
||||||
def _matching_func(self, operator):
|
def _matching_func(self, operator):
|
||||||
op_type = type(operator)
|
op_type = type(operator)
|
||||||
return self.op_mapping[op_type]
|
return OP_MAPPING[op_type]
|
||||||
|
|
||||||
def _as_function(self, func_name, args, args_as_lambda=False):
|
def _as_function(self, func_name, args, args_as_lambda=False):
|
||||||
if args_as_lambda:
|
if args_as_lambda:
|
||||||
|
@ -77,6 +77,13 @@ class LogicalExpressionTest(converter_testing.TestCase):
|
|||||||
with self.converted(test_fn, logical_expressions, {}) as result:
|
with self.converted(test_fn, logical_expressions, {}) as result:
|
||||||
self.assertTrue(result.test_fn('a', ('a',)))
|
self.assertTrue(result.test_fn('a', ('a',)))
|
||||||
|
|
||||||
|
def test_unary_ops(self):
|
||||||
|
def test_fn(a):
|
||||||
|
return ~a, -a, +a
|
||||||
|
|
||||||
|
with self.converted(test_fn, logical_expressions, {}) as result:
|
||||||
|
self.assertEqual(result.test_fn(1), (-2, -1, 1))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -52,6 +52,7 @@ from tensorflow.python.autograph.operators.logical import eq
|
|||||||
from tensorflow.python.autograph.operators.logical import gt
|
from tensorflow.python.autograph.operators.logical import gt
|
||||||
from tensorflow.python.autograph.operators.logical import gt_e
|
from tensorflow.python.autograph.operators.logical import gt_e
|
||||||
from tensorflow.python.autograph.operators.logical import in_
|
from tensorflow.python.autograph.operators.logical import in_
|
||||||
|
from tensorflow.python.autograph.operators.logical import invert
|
||||||
from tensorflow.python.autograph.operators.logical import is_
|
from tensorflow.python.autograph.operators.logical import is_
|
||||||
from tensorflow.python.autograph.operators.logical import is_not
|
from tensorflow.python.autograph.operators.logical import is_not
|
||||||
from tensorflow.python.autograph.operators.logical import lt
|
from tensorflow.python.autograph.operators.logical import lt
|
||||||
@ -60,6 +61,7 @@ from tensorflow.python.autograph.operators.logical import not_
|
|||||||
from tensorflow.python.autograph.operators.logical import not_eq
|
from tensorflow.python.autograph.operators.logical import not_eq
|
||||||
from tensorflow.python.autograph.operators.logical import not_in
|
from tensorflow.python.autograph.operators.logical import not_in
|
||||||
from tensorflow.python.autograph.operators.logical import or_
|
from tensorflow.python.autograph.operators.logical import or_
|
||||||
|
from tensorflow.python.autograph.operators.logical import u_add
|
||||||
from tensorflow.python.autograph.operators.logical import u_sub
|
from tensorflow.python.autograph.operators.logical import u_sub
|
||||||
from tensorflow.python.autograph.operators.py_builtins import float_
|
from tensorflow.python.autograph.operators.py_builtins import float_
|
||||||
from tensorflow.python.autograph.operators.py_builtins import int_
|
from tensorflow.python.autograph.operators.py_builtins import int_
|
||||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import operator
|
||||||
|
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import gen_math_ops
|
from tensorflow.python.ops import gen_math_ops
|
||||||
@ -35,7 +37,7 @@ def and_(a, b):
|
|||||||
a_val = a()
|
a_val = a()
|
||||||
if tensor_util.is_tensor(a_val):
|
if tensor_util.is_tensor(a_val):
|
||||||
return _tf_lazy_and(a_val, b)
|
return _tf_lazy_and(a_val, b)
|
||||||
return _py_lazy_and(a_val, b)
|
return a_val and b()
|
||||||
|
|
||||||
|
|
||||||
def _tf_lazy_and(cond, b):
|
def _tf_lazy_and(cond, b):
|
||||||
@ -44,17 +46,12 @@ def _tf_lazy_and(cond, b):
|
|||||||
return control_flow_ops.cond(cond, b, lambda: cond)
|
return control_flow_ops.cond(cond, b, lambda: cond)
|
||||||
|
|
||||||
|
|
||||||
def _py_lazy_and(cond, b):
|
|
||||||
"""Lazy-eval equivalent of "and" in Python."""
|
|
||||||
return cond and b()
|
|
||||||
|
|
||||||
|
|
||||||
def or_(a, b):
|
def or_(a, b):
|
||||||
"""Functional form of "or". Uses lazy evaluation semantics."""
|
"""Functional form of "or". Uses lazy evaluation semantics."""
|
||||||
a_val = a()
|
a_val = a()
|
||||||
if tensor_util.is_tensor(a_val):
|
if tensor_util.is_tensor(a_val):
|
||||||
return _tf_lazy_or(a_val, b)
|
return _tf_lazy_or(a_val, b)
|
||||||
return _py_lazy_or(a_val, b)
|
return a_val or b()
|
||||||
|
|
||||||
|
|
||||||
def _tf_lazy_or(cond, b):
|
def _tf_lazy_or(cond, b):
|
||||||
@ -63,16 +60,11 @@ def _tf_lazy_or(cond, b):
|
|||||||
return control_flow_ops.cond(cond, lambda: cond, b)
|
return control_flow_ops.cond(cond, lambda: cond, b)
|
||||||
|
|
||||||
|
|
||||||
def _py_lazy_or(cond, b):
|
|
||||||
"""Lazy-eval equivalent of "or" in Python."""
|
|
||||||
return cond or b()
|
|
||||||
|
|
||||||
|
|
||||||
def eq(a, b):
|
def eq(a, b):
|
||||||
"""Functional form of "equal"."""
|
"""Functional form of "equal"."""
|
||||||
if tensor_util.is_tensor(a) or tensor_util.is_tensor(b):
|
if tensor_util.is_tensor(a) or tensor_util.is_tensor(b):
|
||||||
return _tf_equal(a, b)
|
return _tf_equal(a, b)
|
||||||
return _py_equal(a, b)
|
return a == b
|
||||||
|
|
||||||
|
|
||||||
def _tf_equal(a, b):
|
def _tf_equal(a, b):
|
||||||
@ -80,11 +72,6 @@ def _tf_equal(a, b):
|
|||||||
return gen_math_ops.equal(a, b)
|
return gen_math_ops.equal(a, b)
|
||||||
|
|
||||||
|
|
||||||
def _py_equal(a, b):
|
|
||||||
"""Overload of "equal" that falls back to Python's default implementation."""
|
|
||||||
return a == b
|
|
||||||
|
|
||||||
|
|
||||||
def not_eq(a, b):
|
def not_eq(a, b):
|
||||||
"""Functional form of "not-equal"."""
|
"""Functional form of "not-equal"."""
|
||||||
return not_(eq(a, b))
|
return not_(eq(a, b))
|
||||||
@ -92,25 +79,8 @@ def not_eq(a, b):
|
|||||||
|
|
||||||
# Default implementation for the remainings.
|
# Default implementation for the remainings.
|
||||||
|
|
||||||
|
is_ = operator.is_
|
||||||
def gt(a, b):
|
is_not = operator.is_not
|
||||||
"""Functional form of "less-than"."""
|
|
||||||
return a > b
|
|
||||||
|
|
||||||
|
|
||||||
def gt_e(a, b):
|
|
||||||
"""Functional form of "less-than"."""
|
|
||||||
return a >= b
|
|
||||||
|
|
||||||
|
|
||||||
def is_(a, b):
|
|
||||||
"""Functional form of "less-than"."""
|
|
||||||
return a is b
|
|
||||||
|
|
||||||
|
|
||||||
def is_not(a, b):
|
|
||||||
"""Functional form of "less-than"."""
|
|
||||||
return a is not b
|
|
||||||
|
|
||||||
|
|
||||||
def in_(a, b):
|
def in_(a, b):
|
||||||
@ -119,21 +89,16 @@ def in_(a, b):
|
|||||||
return a in b
|
return a in b
|
||||||
|
|
||||||
|
|
||||||
def lt(a, b):
|
|
||||||
"""Functional form of "less-than"."""
|
|
||||||
return a < b
|
|
||||||
|
|
||||||
|
|
||||||
def lt_e(a, b):
|
|
||||||
"""Functional form of "less-than"."""
|
|
||||||
return a <= b
|
|
||||||
|
|
||||||
|
|
||||||
def not_in(a, b):
|
def not_in(a, b):
|
||||||
"""Functional form of "less-than"."""
|
"""Functional form of "less-than"."""
|
||||||
return a not in b
|
return a not in b
|
||||||
|
|
||||||
|
gt = operator.gt
|
||||||
|
gt_e = operator.ge
|
||||||
|
lt = operator.lt
|
||||||
|
lt_e = operator.le
|
||||||
|
|
||||||
def u_sub(a):
|
|
||||||
"""Functional form of "unary-sub"."""
|
u_add = operator.pos
|
||||||
return -a
|
u_sub = operator.neg
|
||||||
|
invert = operator.invert
|
||||||
|
Loading…
Reference in New Issue
Block a user