Reactivate logical operators: not, and, or, which aren't supported by Python's operator overloading. Move the equality operator under a new experimental feature which will be eventually deprecated.
PiperOrigin-RevId: 248873569
This commit is contained in:
parent
ec76369176
commit
bcbb1db93f
@ -12,10 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Converter for logical expressions.
|
||||
|
||||
e.g. `a and b -> tf.logical_and(a, b)`. This is not done automatically in TF.
|
||||
"""
|
||||
"""Converter for logical expressions, e.g. `a and b -> tf.logical_and(a, b)`."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -24,7 +21,6 @@ from __future__ import print_function
|
||||
import gast
|
||||
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
|
||||
@ -38,128 +34,104 @@ from tensorflow.python.autograph.pyct import templates
|
||||
SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND'
|
||||
|
||||
|
||||
OP_MAPPING = {
|
||||
LOGICAL_OPERATORS = {
|
||||
gast.And: 'ag__.and_',
|
||||
gast.Not: 'ag__.not_',
|
||||
gast.Or: 'ag__.or_',
|
||||
}
|
||||
|
||||
EQUALITY_OPERATORS = {
|
||||
gast.Eq: 'ag__.eq',
|
||||
gast.NotEq: 'ag__.not_eq',
|
||||
gast.Lt: 'ag__.lt',
|
||||
gast.LtE: 'ag__.lt_e',
|
||||
gast.Gt: 'ag__.gt',
|
||||
gast.GtE: 'ag__.gt_e',
|
||||
gast.Is: 'ag__.is_',
|
||||
gast.IsNot: 'ag__.is_not',
|
||||
gast.In: 'ag__.in_',
|
||||
gast.Not: 'ag__.not_',
|
||||
gast.NotIn: 'ag__.not_in',
|
||||
gast.Or: 'ag__.or_',
|
||||
gast.UAdd: 'ag__.u_add',
|
||||
gast.USub: 'ag__.u_sub',
|
||||
gast.Invert: 'ag__.invert',
|
||||
}
|
||||
|
||||
|
||||
class LogicalExpressionTransformer(converter.Base):
|
||||
"""Converts logical expressions to corresponding TF calls."""
|
||||
|
||||
def _expect_simple_symbol(self, operand):
|
||||
if isinstance(operand, gast.Name):
|
||||
return
|
||||
if anno.hasanno(operand, SAFE_BOOLEAN_OPERAND):
|
||||
return
|
||||
raise NotImplementedError(
|
||||
'only simple local variables are supported in logical and compound '
|
||||
'comparison expressions; for example, we support "a or b" but not '
|
||||
'"a.x or b"; for a workaround, assign the expression to a local '
|
||||
'variable and use that instead, for example "tmp = a.x", "tmp or b"')
|
||||
|
||||
def _has_matching_func(self, operator):
|
||||
def _overload_of(self, operator):
|
||||
op_type = type(operator)
|
||||
return op_type in OP_MAPPING
|
||||
if op_type in LOGICAL_OPERATORS:
|
||||
return LOGICAL_OPERATORS[op_type]
|
||||
if self.ctx.program.options.uses(converter.Feature.EQUALITY_OPERATORS):
|
||||
if op_type in EQUALITY_OPERATORS:
|
||||
return EQUALITY_OPERATORS[op_type]
|
||||
return None
|
||||
|
||||
def _matching_func(self, operator):
|
||||
op_type = type(operator)
|
||||
return OP_MAPPING[op_type]
|
||||
def _as_lambda(self, expr):
|
||||
return templates.replace_as_expression('lambda: expr', expr=expr)
|
||||
|
||||
def _as_function(self, func_name, args, args_as_lambda=False):
|
||||
if args_as_lambda:
|
||||
args_as_lambda = []
|
||||
for arg in args:
|
||||
template = """
|
||||
lambda: arg
|
||||
"""
|
||||
args_as_lambda.append(
|
||||
templates.replace_as_expression(template, arg=arg))
|
||||
args = args_as_lambda
|
||||
def _as_binary_function(self, func_name, arg1, arg2):
|
||||
return templates.replace_as_expression(
|
||||
'func_name(arg1, arg2)',
|
||||
func_name=parser.parse_expression(func_name),
|
||||
arg1=arg1,
|
||||
arg2=arg2)
|
||||
|
||||
if not args:
|
||||
template = """
|
||||
func_name()
|
||||
"""
|
||||
replacement = templates.replace_as_expression(
|
||||
template, func_name=parser.parse_expression(func_name))
|
||||
elif len(args) == 1:
|
||||
template = """
|
||||
func_name(arg)
|
||||
"""
|
||||
replacement = templates.replace_as_expression(
|
||||
template, func_name=parser.parse_expression(func_name), arg=args[0])
|
||||
elif len(args) == 2:
|
||||
template = """
|
||||
func_name(arg1, arg2)
|
||||
"""
|
||||
replacement = templates.replace_as_expression(
|
||||
template,
|
||||
func_name=parser.parse_expression(func_name),
|
||||
arg1=args[0],
|
||||
arg2=args[1])
|
||||
else:
|
||||
raise NotImplementedError('{} arguments for {}'.format(
|
||||
len(args), func_name))
|
||||
def _as_binary_operation(self, op, arg1, arg2):
|
||||
template = templates.replace_as_expression(
|
||||
'arg1 is arg2',
|
||||
arg1=arg1,
|
||||
arg2=arg2)
|
||||
template.ops[0] = op
|
||||
return template
|
||||
|
||||
anno.setanno(replacement, SAFE_BOOLEAN_OPERAND, True)
|
||||
return replacement
|
||||
def _as_unary_function(self, func_name, arg):
|
||||
return templates.replace_as_expression(
|
||||
'func_name(arg)', func_name=parser.parse_expression(func_name), arg=arg)
|
||||
|
||||
def visit_Compare(self, node):
|
||||
node = self.generic_visit(node)
|
||||
|
||||
if (not self.ctx.program.options.uses(
|
||||
converter.Feature.EQUALITY_OPERATORS)):
|
||||
return node
|
||||
|
||||
ops_and_comps = list(zip(node.ops, node.comparators))
|
||||
left = node.left
|
||||
op_tree = None
|
||||
|
||||
# Repeated comparisons are converted to conjunctions:
|
||||
# a < b < c -> a < b and b < c
|
||||
op_tree = None
|
||||
while ops_and_comps:
|
||||
op, right = ops_and_comps.pop(0)
|
||||
binary_comparison = self._as_function(
|
||||
self._matching_func(op), (left, right))
|
||||
if isinstance(left, gast.Name) and isinstance(right, gast.Name):
|
||||
anno.setanno(binary_comparison, SAFE_BOOLEAN_OPERAND, True)
|
||||
if op_tree:
|
||||
self._expect_simple_symbol(right)
|
||||
op_tree = self._as_function(
|
||||
'ag__.and_', (op_tree, binary_comparison), args_as_lambda=True)
|
||||
overload = self._overload_of(op)
|
||||
if overload is not None:
|
||||
binary_comparison = self._as_binary_function(overload, left, right)
|
||||
else:
|
||||
binary_comparison = self._as_binary_operation(op, left, right)
|
||||
if op_tree is not None:
|
||||
op_tree = self._as_binary_function('ag__.and_',
|
||||
self._as_lambda(op_tree),
|
||||
self._as_lambda(binary_comparison))
|
||||
else:
|
||||
op_tree = binary_comparison
|
||||
left = right
|
||||
|
||||
assert op_tree is not None
|
||||
return op_tree
|
||||
|
||||
def visit_UnaryOp(self, node):
|
||||
node = self.generic_visit(node)
|
||||
return self._as_function(self._matching_func(node.op), (node.operand,))
|
||||
|
||||
overload = self._overload_of(node.op)
|
||||
if overload is None:
|
||||
return node
|
||||
|
||||
return self._as_unary_function(overload, node.operand)
|
||||
|
||||
def visit_BoolOp(self, node):
|
||||
node = self.generic_visit(node)
|
||||
node_values = node.values
|
||||
right = node.values.pop()
|
||||
self._expect_simple_symbol(right)
|
||||
while node_values:
|
||||
left = node_values.pop()
|
||||
self._expect_simple_symbol(left)
|
||||
right = self._as_function(
|
||||
self._matching_func(node.op), (left, right), args_as_lambda=True)
|
||||
right = self._as_binary_function(
|
||||
self._overload_of(node.op), self._as_lambda(left),
|
||||
self._as_lambda(right))
|
||||
return right
|
||||
|
||||
|
||||
def transform(node, ctx):
|
||||
return LogicalExpressionTransformer(ctx).visit(node)
|
||||
transformer = LogicalExpressionTransformer(ctx)
|
||||
return transformer.visit(node)
|
||||
|
@ -96,9 +96,10 @@ class Feature(enum.Enum):
|
||||
ASSERT_STATEMENTS: Convert Tensor-dependent assert statements to tf.Assert.
|
||||
BUILTIN_FUNCTIONS: Convert builtin functions applied to Tensors to
|
||||
their TF counterparts.
|
||||
EQUALITY_OPERATORS: Whether to convert the comparison operators, like
|
||||
equality. This is soon to be deprecated as support is being added to the
|
||||
Tensor class.
|
||||
LISTS: Convert list idioms, like initializers, slices, append, etc.
|
||||
LOGICAL_EXPRESSIONS: Convert data-dependent logical expressions applied to
|
||||
Tensors to their TF counterparts.
|
||||
NAME_SCOPES: Insert name scopes that name ops according to context, like the
|
||||
function they were defined in.
|
||||
"""
|
||||
@ -108,8 +109,8 @@ class Feature(enum.Enum):
|
||||
AUTO_CONTROL_DEPS = 'AUTO_CONTROL_DEPS'
|
||||
ASSERT_STATEMENTS = 'ASSERT_STATEMENTS'
|
||||
BUILTIN_FUNCTIONS = 'BUILTIN_FUNCTIONS'
|
||||
EQUALITY_OPERATORS = 'EQUALITY_OPERATORS'
|
||||
LISTS = 'LISTS'
|
||||
LOGICAL_EXPRESSIONS = 'LOGICAL_EXPRESSIONS'
|
||||
NAME_SCOPES = 'NAME_SCOPES'
|
||||
|
||||
@classmethod
|
||||
|
@ -677,8 +677,7 @@ def node_to_graph(node, context):
|
||||
node = converter.apply_(node, context, call_trees)
|
||||
node = converter.apply_(node, context, control_flow)
|
||||
node = converter.apply_(node, context, conditional_expressions)
|
||||
if context.program.options.uses(converter.Feature.LOGICAL_EXPRESSIONS):
|
||||
node = converter.apply_(node, context, logical_expressions)
|
||||
node = converter.apply_(node, context, logical_expressions)
|
||||
if context.program.options.uses(converter.Feature.AUTO_CONTROL_DEPS):
|
||||
node = converter.apply_(node, context, side_effect_guards)
|
||||
# TODO(mdan): If function scopes ever does more, the toggle will need moving.
|
||||
|
@ -49,20 +49,9 @@ from tensorflow.python.autograph.operators.data_structures import new_list
|
||||
from tensorflow.python.autograph.operators.exceptions import assert_stmt
|
||||
from tensorflow.python.autograph.operators.logical import and_
|
||||
from tensorflow.python.autograph.operators.logical import eq
|
||||
from tensorflow.python.autograph.operators.logical import gt
|
||||
from tensorflow.python.autograph.operators.logical import gt_e
|
||||
from tensorflow.python.autograph.operators.logical import in_
|
||||
from tensorflow.python.autograph.operators.logical import invert
|
||||
from tensorflow.python.autograph.operators.logical import is_
|
||||
from tensorflow.python.autograph.operators.logical import is_not
|
||||
from tensorflow.python.autograph.operators.logical import lt
|
||||
from tensorflow.python.autograph.operators.logical import lt_e
|
||||
from tensorflow.python.autograph.operators.logical import not_
|
||||
from tensorflow.python.autograph.operators.logical import not_eq
|
||||
from tensorflow.python.autograph.operators.logical import not_in
|
||||
from tensorflow.python.autograph.operators.logical import or_
|
||||
from tensorflow.python.autograph.operators.logical import u_add
|
||||
from tensorflow.python.autograph.operators.logical import u_sub
|
||||
from tensorflow.python.autograph.operators.py_builtins import float_
|
||||
from tensorflow.python.autograph.operators.py_builtins import int_
|
||||
from tensorflow.python.autograph.operators.py_builtins import len_
|
||||
|
@ -12,24 +12,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Logical operators, including comparison and bool operators."""
|
||||
"""Logical boolean operators: not, and, or."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import operator
|
||||
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
|
||||
|
||||
# Note: the implementations in this file are split into very small-grained
|
||||
# functions in preparation for the factoring out the more generic pyct library.
|
||||
# At that time, the py_* and tf_* functions will reside in different libraries.
|
||||
|
||||
|
||||
def not_(a):
|
||||
"""Functional form of "not"."""
|
||||
if tensor_util.is_tensor(a):
|
||||
@ -105,30 +98,3 @@ def _py_equal(a, b):
|
||||
def not_eq(a, b):
|
||||
"""Functional form of "not-equal"."""
|
||||
return not_(eq(a, b))
|
||||
|
||||
|
||||
# Default implementation for the rest.
|
||||
|
||||
is_ = operator.is_
|
||||
is_not = operator.is_not
|
||||
|
||||
|
||||
def in_(a, b):
|
||||
"""Functional form of "in"."""
|
||||
# TODO(mdan): in and not_in should probably be convertible for some types.
|
||||
return a in b
|
||||
|
||||
|
||||
def not_in(a, b):
|
||||
"""Functional form of "not-in"."""
|
||||
return a not in b
|
||||
|
||||
gt = operator.gt
|
||||
gt_e = operator.ge
|
||||
lt = operator.lt
|
||||
lt_e = operator.le
|
||||
|
||||
|
||||
u_add = operator.pos
|
||||
u_sub = operator.neg
|
||||
invert = operator.invert
|
||||
|
@ -18,11 +18,11 @@ tf_class {
|
||||
mtype: "<enum \'Feature\'>"
|
||||
}
|
||||
member {
|
||||
name: "LISTS"
|
||||
name: "EQUALITY_OPERATORS"
|
||||
mtype: "<enum \'Feature\'>"
|
||||
}
|
||||
member {
|
||||
name: "LOGICAL_EXPRESSIONS"
|
||||
name: "LISTS"
|
||||
mtype: "<enum \'Feature\'>"
|
||||
}
|
||||
member {
|
||||
|
@ -18,11 +18,11 @@ tf_class {
|
||||
mtype: "<enum \'Feature\'>"
|
||||
}
|
||||
member {
|
||||
name: "LISTS"
|
||||
name: "EQUALITY_OPERATORS"
|
||||
mtype: "<enum \'Feature\'>"
|
||||
}
|
||||
member {
|
||||
name: "LOGICAL_EXPRESSIONS"
|
||||
name: "LISTS"
|
||||
mtype: "<enum \'Feature\'>"
|
||||
}
|
||||
member {
|
||||
|
Loading…
Reference in New Issue
Block a user