Reactivate logical operators: not, and, or, which aren't supported by Python's operator overloading. Move the equality operator under a new experimental feature which will be eventually deprecated.
PiperOrigin-RevId: 248873569
This commit is contained in:
parent
ec76369176
commit
bcbb1db93f
@ -12,10 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Converter for logical expressions.
|
"""Converter for logical expressions, e.g. `a and b -> tf.logical_and(a, b)`."""
|
||||||
|
|
||||||
e.g. `a and b -> tf.logical_and(a, b)`. This is not done automatically in TF.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
@ -24,7 +21,6 @@ from __future__ import print_function
|
|||||||
import gast
|
import gast
|
||||||
|
|
||||||
from tensorflow.python.autograph.core import converter
|
from tensorflow.python.autograph.core import converter
|
||||||
from tensorflow.python.autograph.pyct import anno
|
|
||||||
from tensorflow.python.autograph.pyct import parser
|
from tensorflow.python.autograph.pyct import parser
|
||||||
from tensorflow.python.autograph.pyct import templates
|
from tensorflow.python.autograph.pyct import templates
|
||||||
|
|
||||||
@ -38,128 +34,104 @@ from tensorflow.python.autograph.pyct import templates
|
|||||||
SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND'
|
SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND'
|
||||||
|
|
||||||
|
|
||||||
OP_MAPPING = {
|
LOGICAL_OPERATORS = {
|
||||||
gast.And: 'ag__.and_',
|
gast.And: 'ag__.and_',
|
||||||
|
gast.Not: 'ag__.not_',
|
||||||
|
gast.Or: 'ag__.or_',
|
||||||
|
}
|
||||||
|
|
||||||
|
EQUALITY_OPERATORS = {
|
||||||
gast.Eq: 'ag__.eq',
|
gast.Eq: 'ag__.eq',
|
||||||
gast.NotEq: 'ag__.not_eq',
|
gast.NotEq: 'ag__.not_eq',
|
||||||
gast.Lt: 'ag__.lt',
|
|
||||||
gast.LtE: 'ag__.lt_e',
|
|
||||||
gast.Gt: 'ag__.gt',
|
|
||||||
gast.GtE: 'ag__.gt_e',
|
|
||||||
gast.Is: 'ag__.is_',
|
|
||||||
gast.IsNot: 'ag__.is_not',
|
|
||||||
gast.In: 'ag__.in_',
|
|
||||||
gast.Not: 'ag__.not_',
|
|
||||||
gast.NotIn: 'ag__.not_in',
|
|
||||||
gast.Or: 'ag__.or_',
|
|
||||||
gast.UAdd: 'ag__.u_add',
|
|
||||||
gast.USub: 'ag__.u_sub',
|
|
||||||
gast.Invert: 'ag__.invert',
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class LogicalExpressionTransformer(converter.Base):
|
class LogicalExpressionTransformer(converter.Base):
|
||||||
"""Converts logical expressions to corresponding TF calls."""
|
"""Converts logical expressions to corresponding TF calls."""
|
||||||
|
|
||||||
def _expect_simple_symbol(self, operand):
|
def _overload_of(self, operator):
|
||||||
if isinstance(operand, gast.Name):
|
|
||||||
return
|
|
||||||
if anno.hasanno(operand, SAFE_BOOLEAN_OPERAND):
|
|
||||||
return
|
|
||||||
raise NotImplementedError(
|
|
||||||
'only simple local variables are supported in logical and compound '
|
|
||||||
'comparison expressions; for example, we support "a or b" but not '
|
|
||||||
'"a.x or b"; for a workaround, assign the expression to a local '
|
|
||||||
'variable and use that instead, for example "tmp = a.x", "tmp or b"')
|
|
||||||
|
|
||||||
def _has_matching_func(self, operator):
|
|
||||||
op_type = type(operator)
|
op_type = type(operator)
|
||||||
return op_type in OP_MAPPING
|
if op_type in LOGICAL_OPERATORS:
|
||||||
|
return LOGICAL_OPERATORS[op_type]
|
||||||
|
if self.ctx.program.options.uses(converter.Feature.EQUALITY_OPERATORS):
|
||||||
|
if op_type in EQUALITY_OPERATORS:
|
||||||
|
return EQUALITY_OPERATORS[op_type]
|
||||||
|
return None
|
||||||
|
|
||||||
def _matching_func(self, operator):
|
def _as_lambda(self, expr):
|
||||||
op_type = type(operator)
|
return templates.replace_as_expression('lambda: expr', expr=expr)
|
||||||
return OP_MAPPING[op_type]
|
|
||||||
|
|
||||||
def _as_function(self, func_name, args, args_as_lambda=False):
|
def _as_binary_function(self, func_name, arg1, arg2):
|
||||||
if args_as_lambda:
|
return templates.replace_as_expression(
|
||||||
args_as_lambda = []
|
'func_name(arg1, arg2)',
|
||||||
for arg in args:
|
func_name=parser.parse_expression(func_name),
|
||||||
template = """
|
arg1=arg1,
|
||||||
lambda: arg
|
arg2=arg2)
|
||||||
"""
|
|
||||||
args_as_lambda.append(
|
|
||||||
templates.replace_as_expression(template, arg=arg))
|
|
||||||
args = args_as_lambda
|
|
||||||
|
|
||||||
if not args:
|
def _as_binary_operation(self, op, arg1, arg2):
|
||||||
template = """
|
template = templates.replace_as_expression(
|
||||||
func_name()
|
'arg1 is arg2',
|
||||||
"""
|
arg1=arg1,
|
||||||
replacement = templates.replace_as_expression(
|
arg2=arg2)
|
||||||
template, func_name=parser.parse_expression(func_name))
|
template.ops[0] = op
|
||||||
elif len(args) == 1:
|
return template
|
||||||
template = """
|
|
||||||
func_name(arg)
|
|
||||||
"""
|
|
||||||
replacement = templates.replace_as_expression(
|
|
||||||
template, func_name=parser.parse_expression(func_name), arg=args[0])
|
|
||||||
elif len(args) == 2:
|
|
||||||
template = """
|
|
||||||
func_name(arg1, arg2)
|
|
||||||
"""
|
|
||||||
replacement = templates.replace_as_expression(
|
|
||||||
template,
|
|
||||||
func_name=parser.parse_expression(func_name),
|
|
||||||
arg1=args[0],
|
|
||||||
arg2=args[1])
|
|
||||||
else:
|
|
||||||
raise NotImplementedError('{} arguments for {}'.format(
|
|
||||||
len(args), func_name))
|
|
||||||
|
|
||||||
anno.setanno(replacement, SAFE_BOOLEAN_OPERAND, True)
|
def _as_unary_function(self, func_name, arg):
|
||||||
return replacement
|
return templates.replace_as_expression(
|
||||||
|
'func_name(arg)', func_name=parser.parse_expression(func_name), arg=arg)
|
||||||
|
|
||||||
def visit_Compare(self, node):
|
def visit_Compare(self, node):
|
||||||
node = self.generic_visit(node)
|
node = self.generic_visit(node)
|
||||||
|
|
||||||
|
if (not self.ctx.program.options.uses(
|
||||||
|
converter.Feature.EQUALITY_OPERATORS)):
|
||||||
|
return node
|
||||||
|
|
||||||
ops_and_comps = list(zip(node.ops, node.comparators))
|
ops_and_comps = list(zip(node.ops, node.comparators))
|
||||||
left = node.left
|
left = node.left
|
||||||
op_tree = None
|
|
||||||
|
|
||||||
# Repeated comparisons are converted to conjunctions:
|
# Repeated comparisons are converted to conjunctions:
|
||||||
# a < b < c -> a < b and b < c
|
# a < b < c -> a < b and b < c
|
||||||
|
op_tree = None
|
||||||
while ops_and_comps:
|
while ops_and_comps:
|
||||||
op, right = ops_and_comps.pop(0)
|
op, right = ops_and_comps.pop(0)
|
||||||
binary_comparison = self._as_function(
|
overload = self._overload_of(op)
|
||||||
self._matching_func(op), (left, right))
|
if overload is not None:
|
||||||
if isinstance(left, gast.Name) and isinstance(right, gast.Name):
|
binary_comparison = self._as_binary_function(overload, left, right)
|
||||||
anno.setanno(binary_comparison, SAFE_BOOLEAN_OPERAND, True)
|
else:
|
||||||
if op_tree:
|
binary_comparison = self._as_binary_operation(op, left, right)
|
||||||
self._expect_simple_symbol(right)
|
if op_tree is not None:
|
||||||
op_tree = self._as_function(
|
op_tree = self._as_binary_function('ag__.and_',
|
||||||
'ag__.and_', (op_tree, binary_comparison), args_as_lambda=True)
|
self._as_lambda(op_tree),
|
||||||
|
self._as_lambda(binary_comparison))
|
||||||
else:
|
else:
|
||||||
op_tree = binary_comparison
|
op_tree = binary_comparison
|
||||||
left = right
|
left = right
|
||||||
|
|
||||||
assert op_tree is not None
|
assert op_tree is not None
|
||||||
return op_tree
|
return op_tree
|
||||||
|
|
||||||
def visit_UnaryOp(self, node):
|
def visit_UnaryOp(self, node):
|
||||||
node = self.generic_visit(node)
|
node = self.generic_visit(node)
|
||||||
return self._as_function(self._matching_func(node.op), (node.operand,))
|
|
||||||
|
overload = self._overload_of(node.op)
|
||||||
|
if overload is None:
|
||||||
|
return node
|
||||||
|
|
||||||
|
return self._as_unary_function(overload, node.operand)
|
||||||
|
|
||||||
def visit_BoolOp(self, node):
|
def visit_BoolOp(self, node):
|
||||||
node = self.generic_visit(node)
|
node = self.generic_visit(node)
|
||||||
node_values = node.values
|
node_values = node.values
|
||||||
right = node.values.pop()
|
right = node.values.pop()
|
||||||
self._expect_simple_symbol(right)
|
|
||||||
while node_values:
|
while node_values:
|
||||||
left = node_values.pop()
|
left = node_values.pop()
|
||||||
self._expect_simple_symbol(left)
|
right = self._as_binary_function(
|
||||||
right = self._as_function(
|
self._overload_of(node.op), self._as_lambda(left),
|
||||||
self._matching_func(node.op), (left, right), args_as_lambda=True)
|
self._as_lambda(right))
|
||||||
return right
|
return right
|
||||||
|
|
||||||
|
|
||||||
def transform(node, ctx):
|
def transform(node, ctx):
|
||||||
return LogicalExpressionTransformer(ctx).visit(node)
|
transformer = LogicalExpressionTransformer(ctx)
|
||||||
|
return transformer.visit(node)
|
||||||
|
@ -96,9 +96,10 @@ class Feature(enum.Enum):
|
|||||||
ASSERT_STATEMENTS: Convert Tensor-dependent assert statements to tf.Assert.
|
ASSERT_STATEMENTS: Convert Tensor-dependent assert statements to tf.Assert.
|
||||||
BUILTIN_FUNCTIONS: Convert builtin functions applied to Tensors to
|
BUILTIN_FUNCTIONS: Convert builtin functions applied to Tensors to
|
||||||
their TF counterparts.
|
their TF counterparts.
|
||||||
|
EQUALITY_OPERATORS: Whether to convert the comparison operators, like
|
||||||
|
equality. This is soon to be deprecated as support is being added to the
|
||||||
|
Tensor class.
|
||||||
LISTS: Convert list idioms, like initializers, slices, append, etc.
|
LISTS: Convert list idioms, like initializers, slices, append, etc.
|
||||||
LOGICAL_EXPRESSIONS: Convert data-dependent logical expressions applied to
|
|
||||||
Tensors to their TF counterparts.
|
|
||||||
NAME_SCOPES: Insert name scopes that name ops according to context, like the
|
NAME_SCOPES: Insert name scopes that name ops according to context, like the
|
||||||
function they were defined in.
|
function they were defined in.
|
||||||
"""
|
"""
|
||||||
@ -108,8 +109,8 @@ class Feature(enum.Enum):
|
|||||||
AUTO_CONTROL_DEPS = 'AUTO_CONTROL_DEPS'
|
AUTO_CONTROL_DEPS = 'AUTO_CONTROL_DEPS'
|
||||||
ASSERT_STATEMENTS = 'ASSERT_STATEMENTS'
|
ASSERT_STATEMENTS = 'ASSERT_STATEMENTS'
|
||||||
BUILTIN_FUNCTIONS = 'BUILTIN_FUNCTIONS'
|
BUILTIN_FUNCTIONS = 'BUILTIN_FUNCTIONS'
|
||||||
|
EQUALITY_OPERATORS = 'EQUALITY_OPERATORS'
|
||||||
LISTS = 'LISTS'
|
LISTS = 'LISTS'
|
||||||
LOGICAL_EXPRESSIONS = 'LOGICAL_EXPRESSIONS'
|
|
||||||
NAME_SCOPES = 'NAME_SCOPES'
|
NAME_SCOPES = 'NAME_SCOPES'
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -677,8 +677,7 @@ def node_to_graph(node, context):
|
|||||||
node = converter.apply_(node, context, call_trees)
|
node = converter.apply_(node, context, call_trees)
|
||||||
node = converter.apply_(node, context, control_flow)
|
node = converter.apply_(node, context, control_flow)
|
||||||
node = converter.apply_(node, context, conditional_expressions)
|
node = converter.apply_(node, context, conditional_expressions)
|
||||||
if context.program.options.uses(converter.Feature.LOGICAL_EXPRESSIONS):
|
node = converter.apply_(node, context, logical_expressions)
|
||||||
node = converter.apply_(node, context, logical_expressions)
|
|
||||||
if context.program.options.uses(converter.Feature.AUTO_CONTROL_DEPS):
|
if context.program.options.uses(converter.Feature.AUTO_CONTROL_DEPS):
|
||||||
node = converter.apply_(node, context, side_effect_guards)
|
node = converter.apply_(node, context, side_effect_guards)
|
||||||
# TODO(mdan): If function scopes ever does more, the toggle will need moving.
|
# TODO(mdan): If function scopes ever does more, the toggle will need moving.
|
||||||
|
@ -49,20 +49,9 @@ from tensorflow.python.autograph.operators.data_structures import new_list
|
|||||||
from tensorflow.python.autograph.operators.exceptions import assert_stmt
|
from tensorflow.python.autograph.operators.exceptions import assert_stmt
|
||||||
from tensorflow.python.autograph.operators.logical import and_
|
from tensorflow.python.autograph.operators.logical import and_
|
||||||
from tensorflow.python.autograph.operators.logical import eq
|
from tensorflow.python.autograph.operators.logical import eq
|
||||||
from tensorflow.python.autograph.operators.logical import gt
|
|
||||||
from tensorflow.python.autograph.operators.logical import gt_e
|
|
||||||
from tensorflow.python.autograph.operators.logical import in_
|
|
||||||
from tensorflow.python.autograph.operators.logical import invert
|
|
||||||
from tensorflow.python.autograph.operators.logical import is_
|
|
||||||
from tensorflow.python.autograph.operators.logical import is_not
|
|
||||||
from tensorflow.python.autograph.operators.logical import lt
|
|
||||||
from tensorflow.python.autograph.operators.logical import lt_e
|
|
||||||
from tensorflow.python.autograph.operators.logical import not_
|
from tensorflow.python.autograph.operators.logical import not_
|
||||||
from tensorflow.python.autograph.operators.logical import not_eq
|
from tensorflow.python.autograph.operators.logical import not_eq
|
||||||
from tensorflow.python.autograph.operators.logical import not_in
|
|
||||||
from tensorflow.python.autograph.operators.logical import or_
|
from tensorflow.python.autograph.operators.logical import or_
|
||||||
from tensorflow.python.autograph.operators.logical import u_add
|
|
||||||
from tensorflow.python.autograph.operators.logical import u_sub
|
|
||||||
from tensorflow.python.autograph.operators.py_builtins import float_
|
from tensorflow.python.autograph.operators.py_builtins import float_
|
||||||
from tensorflow.python.autograph.operators.py_builtins import int_
|
from tensorflow.python.autograph.operators.py_builtins import int_
|
||||||
from tensorflow.python.autograph.operators.py_builtins import len_
|
from tensorflow.python.autograph.operators.py_builtins import len_
|
||||||
|
@ -12,24 +12,17 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Logical operators, including comparison and bool operators."""
|
"""Logical boolean operators: not, and, or."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import operator
|
|
||||||
|
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import gen_math_ops
|
from tensorflow.python.ops import gen_math_ops
|
||||||
|
|
||||||
|
|
||||||
# Note: the implementations in this file are split into very small-grained
|
|
||||||
# functions in preparation for the factoring out the more generic pyct library.
|
|
||||||
# At that time, the py_* and tf_* functions will reside in different libraries.
|
|
||||||
|
|
||||||
|
|
||||||
def not_(a):
|
def not_(a):
|
||||||
"""Functional form of "not"."""
|
"""Functional form of "not"."""
|
||||||
if tensor_util.is_tensor(a):
|
if tensor_util.is_tensor(a):
|
||||||
@ -105,30 +98,3 @@ def _py_equal(a, b):
|
|||||||
def not_eq(a, b):
|
def not_eq(a, b):
|
||||||
"""Functional form of "not-equal"."""
|
"""Functional form of "not-equal"."""
|
||||||
return not_(eq(a, b))
|
return not_(eq(a, b))
|
||||||
|
|
||||||
|
|
||||||
# Default implementation for the rest.
|
|
||||||
|
|
||||||
is_ = operator.is_
|
|
||||||
is_not = operator.is_not
|
|
||||||
|
|
||||||
|
|
||||||
def in_(a, b):
|
|
||||||
"""Functional form of "in"."""
|
|
||||||
# TODO(mdan): in and not_in should probably be convertible for some types.
|
|
||||||
return a in b
|
|
||||||
|
|
||||||
|
|
||||||
def not_in(a, b):
|
|
||||||
"""Functional form of "not-in"."""
|
|
||||||
return a not in b
|
|
||||||
|
|
||||||
gt = operator.gt
|
|
||||||
gt_e = operator.ge
|
|
||||||
lt = operator.lt
|
|
||||||
lt_e = operator.le
|
|
||||||
|
|
||||||
|
|
||||||
u_add = operator.pos
|
|
||||||
u_sub = operator.neg
|
|
||||||
invert = operator.invert
|
|
||||||
|
@ -18,11 +18,11 @@ tf_class {
|
|||||||
mtype: "<enum \'Feature\'>"
|
mtype: "<enum \'Feature\'>"
|
||||||
}
|
}
|
||||||
member {
|
member {
|
||||||
name: "LISTS"
|
name: "EQUALITY_OPERATORS"
|
||||||
mtype: "<enum \'Feature\'>"
|
mtype: "<enum \'Feature\'>"
|
||||||
}
|
}
|
||||||
member {
|
member {
|
||||||
name: "LOGICAL_EXPRESSIONS"
|
name: "LISTS"
|
||||||
mtype: "<enum \'Feature\'>"
|
mtype: "<enum \'Feature\'>"
|
||||||
}
|
}
|
||||||
member {
|
member {
|
||||||
|
@ -18,11 +18,11 @@ tf_class {
|
|||||||
mtype: "<enum \'Feature\'>"
|
mtype: "<enum \'Feature\'>"
|
||||||
}
|
}
|
||||||
member {
|
member {
|
||||||
name: "LISTS"
|
name: "EQUALITY_OPERATORS"
|
||||||
mtype: "<enum \'Feature\'>"
|
mtype: "<enum \'Feature\'>"
|
||||||
}
|
}
|
||||||
member {
|
member {
|
||||||
name: "LOGICAL_EXPRESSIONS"
|
name: "LISTS"
|
||||||
mtype: "<enum \'Feature\'>"
|
mtype: "<enum \'Feature\'>"
|
||||||
}
|
}
|
||||||
member {
|
member {
|
||||||
|
Loading…
Reference in New Issue
Block a user