Reactivate logical operators: not, and, or, which aren't supported by Python's operator overloading. Move the equality operator under a new experimental feature which will be eventually deprecated.

PiperOrigin-RevId: 248873569
This commit is contained in:
Dan Moldovan 2019-05-18 09:57:35 -07:00 committed by TensorFlower Gardener
parent ec76369176
commit bcbb1db93f
7 changed files with 68 additions and 141 deletions

View File

@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converter for logical expressions.
e.g. `a and b -> tf.logical_and(a, b)`. This is not done automatically in TF.
"""
"""Converter for logical expressions, e.g. `a and b -> tf.logical_and(a, b)`."""
from __future__ import absolute_import
from __future__ import division
@ -24,7 +21,6 @@ from __future__ import print_function
import gast
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import templates
@ -38,128 +34,104 @@ from tensorflow.python.autograph.pyct import templates
SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND'
OP_MAPPING = {
LOGICAL_OPERATORS = {
gast.And: 'ag__.and_',
gast.Not: 'ag__.not_',
gast.Or: 'ag__.or_',
}
EQUALITY_OPERATORS = {
gast.Eq: 'ag__.eq',
gast.NotEq: 'ag__.not_eq',
gast.Lt: 'ag__.lt',
gast.LtE: 'ag__.lt_e',
gast.Gt: 'ag__.gt',
gast.GtE: 'ag__.gt_e',
gast.Is: 'ag__.is_',
gast.IsNot: 'ag__.is_not',
gast.In: 'ag__.in_',
gast.Not: 'ag__.not_',
gast.NotIn: 'ag__.not_in',
gast.Or: 'ag__.or_',
gast.UAdd: 'ag__.u_add',
gast.USub: 'ag__.u_sub',
gast.Invert: 'ag__.invert',
}
class LogicalExpressionTransformer(converter.Base):
"""Converts logical expressions to corresponding TF calls."""
def _expect_simple_symbol(self, operand):
if isinstance(operand, gast.Name):
return
if anno.hasanno(operand, SAFE_BOOLEAN_OPERAND):
return
raise NotImplementedError(
'only simple local variables are supported in logical and compound '
'comparison expressions; for example, we support "a or b" but not '
'"a.x or b"; for a workaround, assign the expression to a local '
'variable and use that instead, for example "tmp = a.x", "tmp or b"')
def _has_matching_func(self, operator):
def _overload_of(self, operator):
op_type = type(operator)
return op_type in OP_MAPPING
if op_type in LOGICAL_OPERATORS:
return LOGICAL_OPERATORS[op_type]
if self.ctx.program.options.uses(converter.Feature.EQUALITY_OPERATORS):
if op_type in EQUALITY_OPERATORS:
return EQUALITY_OPERATORS[op_type]
return None
def _matching_func(self, operator):
op_type = type(operator)
return OP_MAPPING[op_type]
def _as_lambda(self, expr):
return templates.replace_as_expression('lambda: expr', expr=expr)
def _as_function(self, func_name, args, args_as_lambda=False):
if args_as_lambda:
args_as_lambda = []
for arg in args:
template = """
lambda: arg
"""
args_as_lambda.append(
templates.replace_as_expression(template, arg=arg))
args = args_as_lambda
def _as_binary_function(self, func_name, arg1, arg2):
return templates.replace_as_expression(
'func_name(arg1, arg2)',
func_name=parser.parse_expression(func_name),
arg1=arg1,
arg2=arg2)
if not args:
template = """
func_name()
"""
replacement = templates.replace_as_expression(
template, func_name=parser.parse_expression(func_name))
elif len(args) == 1:
template = """
func_name(arg)
"""
replacement = templates.replace_as_expression(
template, func_name=parser.parse_expression(func_name), arg=args[0])
elif len(args) == 2:
template = """
func_name(arg1, arg2)
"""
replacement = templates.replace_as_expression(
template,
func_name=parser.parse_expression(func_name),
arg1=args[0],
arg2=args[1])
else:
raise NotImplementedError('{} arguments for {}'.format(
len(args), func_name))
def _as_binary_operation(self, op, arg1, arg2):
template = templates.replace_as_expression(
'arg1 is arg2',
arg1=arg1,
arg2=arg2)
template.ops[0] = op
return template
anno.setanno(replacement, SAFE_BOOLEAN_OPERAND, True)
return replacement
def _as_unary_function(self, func_name, arg):
return templates.replace_as_expression(
'func_name(arg)', func_name=parser.parse_expression(func_name), arg=arg)
def visit_Compare(self, node):
node = self.generic_visit(node)
if (not self.ctx.program.options.uses(
converter.Feature.EQUALITY_OPERATORS)):
return node
ops_and_comps = list(zip(node.ops, node.comparators))
left = node.left
op_tree = None
# Repeated comparisons are converted to conjunctions:
# a < b < c -> a < b and b < c
op_tree = None
while ops_and_comps:
op, right = ops_and_comps.pop(0)
binary_comparison = self._as_function(
self._matching_func(op), (left, right))
if isinstance(left, gast.Name) and isinstance(right, gast.Name):
anno.setanno(binary_comparison, SAFE_BOOLEAN_OPERAND, True)
if op_tree:
self._expect_simple_symbol(right)
op_tree = self._as_function(
'ag__.and_', (op_tree, binary_comparison), args_as_lambda=True)
overload = self._overload_of(op)
if overload is not None:
binary_comparison = self._as_binary_function(overload, left, right)
else:
binary_comparison = self._as_binary_operation(op, left, right)
if op_tree is not None:
op_tree = self._as_binary_function('ag__.and_',
self._as_lambda(op_tree),
self._as_lambda(binary_comparison))
else:
op_tree = binary_comparison
left = right
assert op_tree is not None
return op_tree
def visit_UnaryOp(self, node):
node = self.generic_visit(node)
return self._as_function(self._matching_func(node.op), (node.operand,))
overload = self._overload_of(node.op)
if overload is None:
return node
return self._as_unary_function(overload, node.operand)
def visit_BoolOp(self, node):
node = self.generic_visit(node)
node_values = node.values
right = node.values.pop()
self._expect_simple_symbol(right)
while node_values:
left = node_values.pop()
self._expect_simple_symbol(left)
right = self._as_function(
self._matching_func(node.op), (left, right), args_as_lambda=True)
right = self._as_binary_function(
self._overload_of(node.op), self._as_lambda(left),
self._as_lambda(right))
return right
def transform(node, ctx):
return LogicalExpressionTransformer(ctx).visit(node)
transformer = LogicalExpressionTransformer(ctx)
return transformer.visit(node)

View File

@ -96,9 +96,10 @@ class Feature(enum.Enum):
ASSERT_STATEMENTS: Convert Tensor-dependent assert statements to tf.Assert.
BUILTIN_FUNCTIONS: Convert builtin functions applied to Tensors to
their TF counterparts.
EQUALITY_OPERATORS: Whether to convert the comparison operators, like
equality. This is soon to be deprecated as support is being added to the
Tensor class.
LISTS: Convert list idioms, like initializers, slices, append, etc.
LOGICAL_EXPRESSIONS: Convert data-dependent logical expressions applied to
Tensors to their TF counterparts.
NAME_SCOPES: Insert name scopes that name ops according to context, like the
function they were defined in.
"""
@ -108,8 +109,8 @@ class Feature(enum.Enum):
AUTO_CONTROL_DEPS = 'AUTO_CONTROL_DEPS'
ASSERT_STATEMENTS = 'ASSERT_STATEMENTS'
BUILTIN_FUNCTIONS = 'BUILTIN_FUNCTIONS'
EQUALITY_OPERATORS = 'EQUALITY_OPERATORS'
LISTS = 'LISTS'
LOGICAL_EXPRESSIONS = 'LOGICAL_EXPRESSIONS'
NAME_SCOPES = 'NAME_SCOPES'
@classmethod

View File

@ -677,8 +677,7 @@ def node_to_graph(node, context):
node = converter.apply_(node, context, call_trees)
node = converter.apply_(node, context, control_flow)
node = converter.apply_(node, context, conditional_expressions)
if context.program.options.uses(converter.Feature.LOGICAL_EXPRESSIONS):
node = converter.apply_(node, context, logical_expressions)
node = converter.apply_(node, context, logical_expressions)
if context.program.options.uses(converter.Feature.AUTO_CONTROL_DEPS):
node = converter.apply_(node, context, side_effect_guards)
# TODO(mdan): If function scopes ever does more, the toggle will need moving.

View File

@ -49,20 +49,9 @@ from tensorflow.python.autograph.operators.data_structures import new_list
from tensorflow.python.autograph.operators.exceptions import assert_stmt
from tensorflow.python.autograph.operators.logical import and_
from tensorflow.python.autograph.operators.logical import eq
from tensorflow.python.autograph.operators.logical import gt
from tensorflow.python.autograph.operators.logical import gt_e
from tensorflow.python.autograph.operators.logical import in_
from tensorflow.python.autograph.operators.logical import invert
from tensorflow.python.autograph.operators.logical import is_
from tensorflow.python.autograph.operators.logical import is_not
from tensorflow.python.autograph.operators.logical import lt
from tensorflow.python.autograph.operators.logical import lt_e
from tensorflow.python.autograph.operators.logical import not_
from tensorflow.python.autograph.operators.logical import not_eq
from tensorflow.python.autograph.operators.logical import not_in
from tensorflow.python.autograph.operators.logical import or_
from tensorflow.python.autograph.operators.logical import u_add
from tensorflow.python.autograph.operators.logical import u_sub
from tensorflow.python.autograph.operators.py_builtins import float_
from tensorflow.python.autograph.operators.py_builtins import int_
from tensorflow.python.autograph.operators.py_builtins import len_

View File

@ -12,24 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logical operators, including comparison and bool operators."""
"""Logical boolean operators: not, and, or."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import operator
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
# Note: the implementations in this file are split into very small-grained
# functions in preparation for the factoring out the more generic pyct library.
# At that time, the py_* and tf_* functions will reside in different libraries.
def not_(a):
"""Functional form of "not"."""
if tensor_util.is_tensor(a):
@ -105,30 +98,3 @@ def _py_equal(a, b):
def not_eq(a, b):
"""Functional form of "not-equal"."""
return not_(eq(a, b))
# Default implementation for the rest.
is_ = operator.is_
is_not = operator.is_not
def in_(a, b):
"""Functional form of "in"."""
# TODO(mdan): in and not_in should probably be convertible for some types.
return a in b
def not_in(a, b):
"""Functional form of "not-in"."""
return a not in b
gt = operator.gt
gt_e = operator.ge
lt = operator.lt
lt_e = operator.le
u_add = operator.pos
u_sub = operator.neg
invert = operator.invert

View File

@ -18,11 +18,11 @@ tf_class {
mtype: "<enum \'Feature\'>"
}
member {
name: "LISTS"
name: "EQUALITY_OPERATORS"
mtype: "<enum \'Feature\'>"
}
member {
name: "LOGICAL_EXPRESSIONS"
name: "LISTS"
mtype: "<enum \'Feature\'>"
}
member {

View File

@ -18,11 +18,11 @@ tf_class {
mtype: "<enum \'Feature\'>"
}
member {
name: "LISTS"
name: "EQUALITY_OPERATORS"
mtype: "<enum \'Feature\'>"
}
member {
name: "LOGICAL_EXPRESSIONS"
name: "LISTS"
mtype: "<enum \'Feature\'>"
}
member {