Reactivate logical operators: not, and, or, which aren't supported by Python's operator overloading. Move the equality operator under a new experimental feature which will be eventually deprecated.

PiperOrigin-RevId: 248873569
This commit is contained in:
Dan Moldovan 2019-05-18 09:57:35 -07:00 committed by TensorFlower Gardener
parent ec76369176
commit bcbb1db93f
7 changed files with 68 additions and 141 deletions

View File

@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Converter for logical expressions. """Converter for logical expressions, e.g. `a and b -> tf.logical_and(a, b)`."""
e.g. `a and b -> tf.logical_and(a, b)`. This is not done automatically in TF.
"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -24,7 +21,6 @@ from __future__ import print_function
import gast import gast
from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import templates from tensorflow.python.autograph.pyct import templates
@ -38,128 +34,104 @@ from tensorflow.python.autograph.pyct import templates
SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND' SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND'
OP_MAPPING = { LOGICAL_OPERATORS = {
gast.And: 'ag__.and_', gast.And: 'ag__.and_',
gast.Not: 'ag__.not_',
gast.Or: 'ag__.or_',
}
EQUALITY_OPERATORS = {
gast.Eq: 'ag__.eq', gast.Eq: 'ag__.eq',
gast.NotEq: 'ag__.not_eq', gast.NotEq: 'ag__.not_eq',
gast.Lt: 'ag__.lt',
gast.LtE: 'ag__.lt_e',
gast.Gt: 'ag__.gt',
gast.GtE: 'ag__.gt_e',
gast.Is: 'ag__.is_',
gast.IsNot: 'ag__.is_not',
gast.In: 'ag__.in_',
gast.Not: 'ag__.not_',
gast.NotIn: 'ag__.not_in',
gast.Or: 'ag__.or_',
gast.UAdd: 'ag__.u_add',
gast.USub: 'ag__.u_sub',
gast.Invert: 'ag__.invert',
} }
class LogicalExpressionTransformer(converter.Base): class LogicalExpressionTransformer(converter.Base):
"""Converts logical expressions to corresponding TF calls.""" """Converts logical expressions to corresponding TF calls."""
def _expect_simple_symbol(self, operand): def _overload_of(self, operator):
if isinstance(operand, gast.Name):
return
if anno.hasanno(operand, SAFE_BOOLEAN_OPERAND):
return
raise NotImplementedError(
'only simple local variables are supported in logical and compound '
'comparison expressions; for example, we support "a or b" but not '
'"a.x or b"; for a workaround, assign the expression to a local '
'variable and use that instead, for example "tmp = a.x", "tmp or b"')
def _has_matching_func(self, operator):
op_type = type(operator) op_type = type(operator)
return op_type in OP_MAPPING if op_type in LOGICAL_OPERATORS:
return LOGICAL_OPERATORS[op_type]
if self.ctx.program.options.uses(converter.Feature.EQUALITY_OPERATORS):
if op_type in EQUALITY_OPERATORS:
return EQUALITY_OPERATORS[op_type]
return None
def _matching_func(self, operator): def _as_lambda(self, expr):
op_type = type(operator) return templates.replace_as_expression('lambda: expr', expr=expr)
return OP_MAPPING[op_type]
def _as_function(self, func_name, args, args_as_lambda=False): def _as_binary_function(self, func_name, arg1, arg2):
if args_as_lambda: return templates.replace_as_expression(
args_as_lambda = [] 'func_name(arg1, arg2)',
for arg in args: func_name=parser.parse_expression(func_name),
template = """ arg1=arg1,
lambda: arg arg2=arg2)
"""
args_as_lambda.append(
templates.replace_as_expression(template, arg=arg))
args = args_as_lambda
if not args: def _as_binary_operation(self, op, arg1, arg2):
template = """ template = templates.replace_as_expression(
func_name() 'arg1 is arg2',
""" arg1=arg1,
replacement = templates.replace_as_expression( arg2=arg2)
template, func_name=parser.parse_expression(func_name)) template.ops[0] = op
elif len(args) == 1: return template
template = """
func_name(arg)
"""
replacement = templates.replace_as_expression(
template, func_name=parser.parse_expression(func_name), arg=args[0])
elif len(args) == 2:
template = """
func_name(arg1, arg2)
"""
replacement = templates.replace_as_expression(
template,
func_name=parser.parse_expression(func_name),
arg1=args[0],
arg2=args[1])
else:
raise NotImplementedError('{} arguments for {}'.format(
len(args), func_name))
anno.setanno(replacement, SAFE_BOOLEAN_OPERAND, True) def _as_unary_function(self, func_name, arg):
return replacement return templates.replace_as_expression(
'func_name(arg)', func_name=parser.parse_expression(func_name), arg=arg)
def visit_Compare(self, node): def visit_Compare(self, node):
node = self.generic_visit(node) node = self.generic_visit(node)
if (not self.ctx.program.options.uses(
converter.Feature.EQUALITY_OPERATORS)):
return node
ops_and_comps = list(zip(node.ops, node.comparators)) ops_and_comps = list(zip(node.ops, node.comparators))
left = node.left left = node.left
op_tree = None
# Repeated comparisons are converted to conjunctions: # Repeated comparisons are converted to conjunctions:
# a < b < c -> a < b and b < c # a < b < c -> a < b and b < c
op_tree = None
while ops_and_comps: while ops_and_comps:
op, right = ops_and_comps.pop(0) op, right = ops_and_comps.pop(0)
binary_comparison = self._as_function( overload = self._overload_of(op)
self._matching_func(op), (left, right)) if overload is not None:
if isinstance(left, gast.Name) and isinstance(right, gast.Name): binary_comparison = self._as_binary_function(overload, left, right)
anno.setanno(binary_comparison, SAFE_BOOLEAN_OPERAND, True) else:
if op_tree: binary_comparison = self._as_binary_operation(op, left, right)
self._expect_simple_symbol(right) if op_tree is not None:
op_tree = self._as_function( op_tree = self._as_binary_function('ag__.and_',
'ag__.and_', (op_tree, binary_comparison), args_as_lambda=True) self._as_lambda(op_tree),
self._as_lambda(binary_comparison))
else: else:
op_tree = binary_comparison op_tree = binary_comparison
left = right left = right
assert op_tree is not None assert op_tree is not None
return op_tree return op_tree
def visit_UnaryOp(self, node): def visit_UnaryOp(self, node):
node = self.generic_visit(node) node = self.generic_visit(node)
return self._as_function(self._matching_func(node.op), (node.operand,))
overload = self._overload_of(node.op)
if overload is None:
return node
return self._as_unary_function(overload, node.operand)
def visit_BoolOp(self, node): def visit_BoolOp(self, node):
node = self.generic_visit(node) node = self.generic_visit(node)
node_values = node.values node_values = node.values
right = node.values.pop() right = node.values.pop()
self._expect_simple_symbol(right)
while node_values: while node_values:
left = node_values.pop() left = node_values.pop()
self._expect_simple_symbol(left) right = self._as_binary_function(
right = self._as_function( self._overload_of(node.op), self._as_lambda(left),
self._matching_func(node.op), (left, right), args_as_lambda=True) self._as_lambda(right))
return right return right
def transform(node, ctx): def transform(node, ctx):
return LogicalExpressionTransformer(ctx).visit(node) transformer = LogicalExpressionTransformer(ctx)
return transformer.visit(node)

View File

@ -96,9 +96,10 @@ class Feature(enum.Enum):
ASSERT_STATEMENTS: Convert Tensor-dependent assert statements to tf.Assert. ASSERT_STATEMENTS: Convert Tensor-dependent assert statements to tf.Assert.
BUILTIN_FUNCTIONS: Convert builtin functions applied to Tensors to BUILTIN_FUNCTIONS: Convert builtin functions applied to Tensors to
their TF counterparts. their TF counterparts.
EQUALITY_OPERATORS: Whether to convert the comparison operators, like
equality. This is soon to be deprecated as support is being added to the
Tensor class.
LISTS: Convert list idioms, like initializers, slices, append, etc. LISTS: Convert list idioms, like initializers, slices, append, etc.
LOGICAL_EXPRESSIONS: Convert data-dependent logical expressions applied to
Tensors to their TF counterparts.
NAME_SCOPES: Insert name scopes that name ops according to context, like the NAME_SCOPES: Insert name scopes that name ops according to context, like the
function they were defined in. function they were defined in.
""" """
@ -108,8 +109,8 @@ class Feature(enum.Enum):
AUTO_CONTROL_DEPS = 'AUTO_CONTROL_DEPS' AUTO_CONTROL_DEPS = 'AUTO_CONTROL_DEPS'
ASSERT_STATEMENTS = 'ASSERT_STATEMENTS' ASSERT_STATEMENTS = 'ASSERT_STATEMENTS'
BUILTIN_FUNCTIONS = 'BUILTIN_FUNCTIONS' BUILTIN_FUNCTIONS = 'BUILTIN_FUNCTIONS'
EQUALITY_OPERATORS = 'EQUALITY_OPERATORS'
LISTS = 'LISTS' LISTS = 'LISTS'
LOGICAL_EXPRESSIONS = 'LOGICAL_EXPRESSIONS'
NAME_SCOPES = 'NAME_SCOPES' NAME_SCOPES = 'NAME_SCOPES'
@classmethod @classmethod

View File

@ -677,8 +677,7 @@ def node_to_graph(node, context):
node = converter.apply_(node, context, call_trees) node = converter.apply_(node, context, call_trees)
node = converter.apply_(node, context, control_flow) node = converter.apply_(node, context, control_flow)
node = converter.apply_(node, context, conditional_expressions) node = converter.apply_(node, context, conditional_expressions)
if context.program.options.uses(converter.Feature.LOGICAL_EXPRESSIONS): node = converter.apply_(node, context, logical_expressions)
node = converter.apply_(node, context, logical_expressions)
if context.program.options.uses(converter.Feature.AUTO_CONTROL_DEPS): if context.program.options.uses(converter.Feature.AUTO_CONTROL_DEPS):
node = converter.apply_(node, context, side_effect_guards) node = converter.apply_(node, context, side_effect_guards)
# TODO(mdan): If function scopes ever does more, the toggle will need moving. # TODO(mdan): If function scopes ever does more, the toggle will need moving.

View File

@ -49,20 +49,9 @@ from tensorflow.python.autograph.operators.data_structures import new_list
from tensorflow.python.autograph.operators.exceptions import assert_stmt from tensorflow.python.autograph.operators.exceptions import assert_stmt
from tensorflow.python.autograph.operators.logical import and_ from tensorflow.python.autograph.operators.logical import and_
from tensorflow.python.autograph.operators.logical import eq from tensorflow.python.autograph.operators.logical import eq
from tensorflow.python.autograph.operators.logical import gt
from tensorflow.python.autograph.operators.logical import gt_e
from tensorflow.python.autograph.operators.logical import in_
from tensorflow.python.autograph.operators.logical import invert
from tensorflow.python.autograph.operators.logical import is_
from tensorflow.python.autograph.operators.logical import is_not
from tensorflow.python.autograph.operators.logical import lt
from tensorflow.python.autograph.operators.logical import lt_e
from tensorflow.python.autograph.operators.logical import not_ from tensorflow.python.autograph.operators.logical import not_
from tensorflow.python.autograph.operators.logical import not_eq from tensorflow.python.autograph.operators.logical import not_eq
from tensorflow.python.autograph.operators.logical import not_in
from tensorflow.python.autograph.operators.logical import or_ from tensorflow.python.autograph.operators.logical import or_
from tensorflow.python.autograph.operators.logical import u_add
from tensorflow.python.autograph.operators.logical import u_sub
from tensorflow.python.autograph.operators.py_builtins import float_ from tensorflow.python.autograph.operators.py_builtins import float_
from tensorflow.python.autograph.operators.py_builtins import int_ from tensorflow.python.autograph.operators.py_builtins import int_
from tensorflow.python.autograph.operators.py_builtins import len_ from tensorflow.python.autograph.operators.py_builtins import len_

View File

@ -12,24 +12,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Logical operators, including comparison and bool operators.""" """Logical boolean operators: not, and, or."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import operator
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_math_ops
# Note: the implementations in this file are split into very small-grained
# functions in preparation for the factoring out the more generic pyct library.
# At that time, the py_* and tf_* functions will reside in different libraries.
def not_(a): def not_(a):
"""Functional form of "not".""" """Functional form of "not"."""
if tensor_util.is_tensor(a): if tensor_util.is_tensor(a):
@ -105,30 +98,3 @@ def _py_equal(a, b):
def not_eq(a, b): def not_eq(a, b):
"""Functional form of "not-equal".""" """Functional form of "not-equal"."""
return not_(eq(a, b)) return not_(eq(a, b))
# Default implementation for the rest.
is_ = operator.is_
is_not = operator.is_not
def in_(a, b):
"""Functional form of "in"."""
# TODO(mdan): in and not_in should probably be convertible for some types.
return a in b
def not_in(a, b):
"""Functional form of "not-in"."""
return a not in b
gt = operator.gt
gt_e = operator.ge
lt = operator.lt
lt_e = operator.le
u_add = operator.pos
u_sub = operator.neg
invert = operator.invert

View File

@ -18,11 +18,11 @@ tf_class {
mtype: "<enum \'Feature\'>" mtype: "<enum \'Feature\'>"
} }
member { member {
name: "LISTS" name: "EQUALITY_OPERATORS"
mtype: "<enum \'Feature\'>" mtype: "<enum \'Feature\'>"
} }
member { member {
name: "LOGICAL_EXPRESSIONS" name: "LISTS"
mtype: "<enum \'Feature\'>" mtype: "<enum \'Feature\'>"
} }
member { member {

View File

@ -18,11 +18,11 @@ tf_class {
mtype: "<enum \'Feature\'>" mtype: "<enum \'Feature\'>"
} }
member { member {
name: "LISTS" name: "EQUALITY_OPERATORS"
mtype: "<enum \'Feature\'>" mtype: "<enum \'Feature\'>"
} }
member { member {
name: "LOGICAL_EXPRESSIONS" name: "LISTS"
mtype: "<enum \'Feature\'>" mtype: "<enum \'Feature\'>"
} }
member { member {