From bcbb1db93f67f5cb6888bbaa0a96676d5d8e5c65 Mon Sep 17 00:00:00 2001
From: Dan Moldovan
Date: Sat, 18 May 2019 09:57:35 -0700
Subject: [PATCH] Reactivate logical operators: not, and, or, which aren't
 supported by Python's operator overloading. Move the equality operator under
 a new experimental feature which will be eventually deprecated.

PiperOrigin-RevId: 248873569
---
 .../converters/logical_expressions.py         | 144 +++++++-----------
 tensorflow/python/autograph/core/converter.py |   7 +-
 .../python/autograph/impl/conversion.py       |   3 +-
 .../python/autograph/operators/__init__.py    |  11 --
 .../python/autograph/operators/logical.py     |  36 +----
 ...flow.autograph.experimental.-feature.pbtxt |   4 +-
 ...flow.autograph.experimental.-feature.pbtxt |   4 +-
 7 files changed, 68 insertions(+), 141 deletions(-)

diff --git a/tensorflow/python/autograph/converters/logical_expressions.py b/tensorflow/python/autograph/converters/logical_expressions.py
index ea9740a22e1..6f2c0ca029b 100644
--- a/tensorflow/python/autograph/converters/logical_expressions.py
+++ b/tensorflow/python/autograph/converters/logical_expressions.py
@@ -12,10 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Converter for logical expressions.
-
-e.g. `a and b -> tf.logical_and(a, b)`. This is not done automatically in TF.
-"""
+"""Converter for logical expressions, e.g. `a and b -> tf.logical_and(a, b)`."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -24,7 +21,6 @@ from __future__ import print_function
 
 import gast
 
 from tensorflow.python.autograph.core import converter
-from tensorflow.python.autograph.pyct import anno
 from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.autograph.pyct import templates
 
@@ -38,128 +34,104 @@ from tensorflow.python.autograph.pyct import templates
 
 SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND'
 
 
-OP_MAPPING = {
+LOGICAL_OPERATORS = {
     gast.And: 'ag__.and_',
+    gast.Not: 'ag__.not_',
+    gast.Or: 'ag__.or_',
+}
+
+EQUALITY_OPERATORS = {
     gast.Eq: 'ag__.eq',
     gast.NotEq: 'ag__.not_eq',
-    gast.Lt: 'ag__.lt',
-    gast.LtE: 'ag__.lt_e',
-    gast.Gt: 'ag__.gt',
-    gast.GtE: 'ag__.gt_e',
-    gast.Is: 'ag__.is_',
-    gast.IsNot: 'ag__.is_not',
-    gast.In: 'ag__.in_',
-    gast.Not: 'ag__.not_',
-    gast.NotIn: 'ag__.not_in',
-    gast.Or: 'ag__.or_',
-    gast.UAdd: 'ag__.u_add',
-    gast.USub: 'ag__.u_sub',
-    gast.Invert: 'ag__.invert',
 }
 
 
 class LogicalExpressionTransformer(converter.Base):
   """Converts logical expressions to corresponding TF calls."""
 
-  def _expect_simple_symbol(self, operand):
-    if isinstance(operand, gast.Name):
-      return
-    if anno.hasanno(operand, SAFE_BOOLEAN_OPERAND):
-      return
-    raise NotImplementedError(
-        'only simple local variables are supported in logical and compound '
-        'comparison expressions; for example, we support "a or b" but not '
-        '"a.x or b"; for a workaround, assign the expression to a local '
-        'variable and use that instead, for example "tmp = a.x", "tmp or b"')
-
-  def _has_matching_func(self, operator):
+  def _overload_of(self, operator):
     op_type = type(operator)
-    return op_type in OP_MAPPING
+    if op_type in LOGICAL_OPERATORS:
+      return LOGICAL_OPERATORS[op_type]
+    if self.ctx.program.options.uses(converter.Feature.EQUALITY_OPERATORS):
+      if op_type in EQUALITY_OPERATORS:
+        return EQUALITY_OPERATORS[op_type]
+    return None
 
-  def _matching_func(self, operator):
-    op_type = type(operator)
-    return OP_MAPPING[op_type]
+  def _as_lambda(self, expr):
+    return templates.replace_as_expression('lambda: expr', expr=expr)
 
-  def _as_function(self, func_name, args, args_as_lambda=False):
-    if args_as_lambda:
-      args_as_lambda = []
-      for arg in args:
-        template = """
-          lambda: arg
-        """
-        args_as_lambda.append(
-            templates.replace_as_expression(template, arg=arg))
-      args = args_as_lambda
+  def _as_binary_function(self, func_name, arg1, arg2):
+    return templates.replace_as_expression(
+        'func_name(arg1, arg2)',
+        func_name=parser.parse_expression(func_name),
+        arg1=arg1,
+        arg2=arg2)
 
-    if not args:
-      template = """
-        func_name()
-      """
-      replacement = templates.replace_as_expression(
-          template, func_name=parser.parse_expression(func_name))
-    elif len(args) == 1:
-      template = """
-        func_name(arg)
-      """
-      replacement = templates.replace_as_expression(
-          template, func_name=parser.parse_expression(func_name), arg=args[0])
-    elif len(args) == 2:
-      template = """
-        func_name(arg1, arg2)
-      """
-      replacement = templates.replace_as_expression(
-          template,
-          func_name=parser.parse_expression(func_name),
-          arg1=args[0],
-          arg2=args[1])
-    else:
-      raise NotImplementedError('{} arguments for {}'.format(
-          len(args), func_name))
+  def _as_binary_operation(self, op, arg1, arg2):
+    template = templates.replace_as_expression(
+        'arg1 is arg2',
+        arg1=arg1,
+        arg2=arg2)
+    template.ops[0] = op
+    return template
 
-    anno.setanno(replacement, SAFE_BOOLEAN_OPERAND, True)
-    return replacement
+  def _as_unary_function(self, func_name, arg):
+    return templates.replace_as_expression(
+        'func_name(arg)', func_name=parser.parse_expression(func_name), arg=arg)
 
   def visit_Compare(self, node):
     node = self.generic_visit(node)
 
+    if (not self.ctx.program.options.uses(
+        converter.Feature.EQUALITY_OPERATORS)):
+      return node
+
     ops_and_comps = list(zip(node.ops, node.comparators))
     left = node.left
-    op_tree = None
 
     # Repeated comparisons are converted to conjunctions:
     # a < b < c -> a < b and b < c
+    op_tree = None
     while ops_and_comps:
       op, right = ops_and_comps.pop(0)
-      binary_comparison = self._as_function(
-          self._matching_func(op), (left, right))
-      if isinstance(left, gast.Name) and isinstance(right, gast.Name):
-        anno.setanno(binary_comparison, SAFE_BOOLEAN_OPERAND, True)
-      if op_tree:
-        self._expect_simple_symbol(right)
-        op_tree = self._as_function(
-            'ag__.and_', (op_tree, binary_comparison), args_as_lambda=True)
+      overload = self._overload_of(op)
+      if overload is not None:
+        binary_comparison = self._as_binary_function(overload, left, right)
+      else:
+        binary_comparison = self._as_binary_operation(op, left, right)
+      if op_tree is not None:
+        op_tree = self._as_binary_function('ag__.and_',
+                                           self._as_lambda(op_tree),
+                                           self._as_lambda(binary_comparison))
       else:
        op_tree = binary_comparison
       left = right
 
+    assert op_tree is not None
     return op_tree
 
   def visit_UnaryOp(self, node):
     node = self.generic_visit(node)
-    return self._as_function(self._matching_func(node.op), (node.operand,))
+
+    overload = self._overload_of(node.op)
+    if overload is None:
+      return node
+
+    return self._as_unary_function(overload, node.operand)
 
   def visit_BoolOp(self, node):
     node = self.generic_visit(node)
     node_values = node.values
     right = node.values.pop()
-    self._expect_simple_symbol(right)
     while node_values:
       left = node_values.pop()
-      self._expect_simple_symbol(left)
-      right = self._as_function(
-          self._matching_func(node.op), (left, right), args_as_lambda=True)
+      right = self._as_binary_function(
+          self._overload_of(node.op), self._as_lambda(left),
+          self._as_lambda(right))
     return right
 
 
 def transform(node, ctx):
-  return LogicalExpressionTransformer(ctx).visit(node)
+  transformer = LogicalExpressionTransformer(ctx)
+  return transformer.visit(node)
diff --git a/tensorflow/python/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py
index 061c4cfaaf3..5b265cf965c 100644
--- a/tensorflow/python/autograph/core/converter.py
+++ b/tensorflow/python/autograph/core/converter.py
@@ -96,9 +96,10 @@ class Feature(enum.Enum):
     ASSERT_STATEMENTS: Convert Tensor-dependent assert statements to tf.Assert.
     BUILTIN_FUNCTIONS: Convert builtin functions applied to Tensors to their
       TF counterparts.
+    EQUALITY_OPERATORS: Whether to convert the comparison operators, like
+      equality. This is soon to be deprecated as support is being added to the
+      Tensor class.
     LISTS: Convert list idioms, like initializers, slices, append, etc.
-    LOGICAL_EXPRESSIONS: Convert data-dependent logical expressions applied to
-      Tensors to their TF counterparts.
     NAME_SCOPES: Insert name scopes that name ops according to context, like
       the function they were defined in.
   """
@@ -108,8 +109,8 @@ class Feature(enum.Enum):
   AUTO_CONTROL_DEPS = 'AUTO_CONTROL_DEPS'
   ASSERT_STATEMENTS = 'ASSERT_STATEMENTS'
   BUILTIN_FUNCTIONS = 'BUILTIN_FUNCTIONS'
+  EQUALITY_OPERATORS = 'EQUALITY_OPERATORS'
   LISTS = 'LISTS'
-  LOGICAL_EXPRESSIONS = 'LOGICAL_EXPRESSIONS'
   NAME_SCOPES = 'NAME_SCOPES'
 
   @classmethod
diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py
index f39da7c8376..f53e1862290 100644
--- a/tensorflow/python/autograph/impl/conversion.py
+++ b/tensorflow/python/autograph/impl/conversion.py
@@ -677,8 +677,7 @@ def node_to_graph(node, context):
   node = converter.apply_(node, context, call_trees)
   node = converter.apply_(node, context, control_flow)
   node = converter.apply_(node, context, conditional_expressions)
-  if context.program.options.uses(converter.Feature.LOGICAL_EXPRESSIONS):
-    node = converter.apply_(node, context, logical_expressions)
+  node = converter.apply_(node, context, logical_expressions)
   if context.program.options.uses(converter.Feature.AUTO_CONTROL_DEPS):
     node = converter.apply_(node, context, side_effect_guards)
   # TODO(mdan): If function scopes ever does more, the toggle will need moving.
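
Note: as a rough before/after sketch of what the converter now emits (illustrative only and not part of the patch; the converter produces AST nodes rather than source text, and the names x, y, z, a and b are hypothetical):

    # `not`, `and` and `or` are now always converted, via LOGICAL_OPERATORS;
    # `ag__` is the alias the generated code uses for the operators module.
    x and y or not z
    # becomes
    ag__.or_(lambda: ag__.and_(lambda: x, lambda: y), lambda: ag__.not_(z))

    # `==` is converted only when Feature.EQUALITY_OPERATORS is enabled,
    # and is otherwise left untouched so normal operator overloading applies.
    a == b
    # becomes
    ag__.eq(a, b)
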
diff --git a/tensorflow/python/autograph/operators/__init__.py b/tensorflow/python/autograph/operators/__init__.py
index bbc684eaf2b..17f5ff39a8e 100644
--- a/tensorflow/python/autograph/operators/__init__.py
+++ b/tensorflow/python/autograph/operators/__init__.py
@@ -49,20 +49,9 @@ from tensorflow.python.autograph.operators.data_structures import new_list
 from tensorflow.python.autograph.operators.exceptions import assert_stmt
 from tensorflow.python.autograph.operators.logical import and_
 from tensorflow.python.autograph.operators.logical import eq
-from tensorflow.python.autograph.operators.logical import gt
-from tensorflow.python.autograph.operators.logical import gt_e
-from tensorflow.python.autograph.operators.logical import in_
-from tensorflow.python.autograph.operators.logical import invert
-from tensorflow.python.autograph.operators.logical import is_
-from tensorflow.python.autograph.operators.logical import is_not
-from tensorflow.python.autograph.operators.logical import lt
-from tensorflow.python.autograph.operators.logical import lt_e
 from tensorflow.python.autograph.operators.logical import not_
 from tensorflow.python.autograph.operators.logical import not_eq
-from tensorflow.python.autograph.operators.logical import not_in
 from tensorflow.python.autograph.operators.logical import or_
-from tensorflow.python.autograph.operators.logical import u_add
-from tensorflow.python.autograph.operators.logical import u_sub
 from tensorflow.python.autograph.operators.py_builtins import float_
 from tensorflow.python.autograph.operators.py_builtins import int_
 from tensorflow.python.autograph.operators.py_builtins import len_
diff --git a/tensorflow/python/autograph/operators/logical.py b/tensorflow/python/autograph/operators/logical.py
index cafb0583e8f..81457580fbc 100644
--- a/tensorflow/python/autograph/operators/logical.py
+++ b/tensorflow/python/autograph/operators/logical.py
@@ -12,24 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Logical operators, including comparison and bool operators."""
+"""Logical boolean operators: not, and, or."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import operator
-
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_math_ops
 
 
-# Note: the implementations in this file are split into very small-grained
-# functions in preparation for the factoring out the more generic pyct library.
-# At that time, the py_* and tf_* functions will reside in different libraries.
-
-
 def not_(a):
   """Functional form of "not"."""
   if tensor_util.is_tensor(a):
@@ -105,30 +98,3 @@ def _py_equal(a, b):
 def not_eq(a, b):
   """Functional form of "not-equal"."""
   return not_(eq(a, b))
-
-
-# Default implementation for the rest.
-
-is_ = operator.is_
-is_not = operator.is_not
-
-
-def in_(a, b):
-  """Functional form of "in"."""
-  # TODO(mdan): in and not_in should probably be convertible for some types.
-  return a in b
-
-
-def not_in(a, b):
-  """Functional form of "not-in"."""
-  return a not in b
-
-gt = operator.gt
-gt_e = operator.ge
-lt = operator.lt
-lt_e = operator.le
-
-
-u_add = operator.pos
-u_sub = operator.neg
-invert = operator.invert
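
For reference, a minimal sketch of how the retained functional forms behave (illustrative, not part of the patch; `and_` and `or_` are untouched by this change and, as the converter change above shows, receive their operands wrapped in lambdas so evaluation stays lazy):

    import tensorflow as tf
    from tensorflow.python.autograph.operators import logical

    logical.not_(True)                 # plain Python operand: returns False
    logical.not_(tf.constant(False))   # Tensor operand: dispatches to tf.logical_not
    logical.eq(tf.constant(1), 1)      # Tensor operand: dispatches to tf.equal
    logical.and_(lambda: True, lambda: False)  # lazy operands, as emitted by the converter
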
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.autograph.experimental.-feature.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.autograph.experimental.-feature.pbtxt
index d283fb8e14f..5d17918107c 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.autograph.experimental.-feature.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.autograph.experimental.-feature.pbtxt
@@ -18,11 +18,11 @@ tf_class {
     mtype: ""
   }
   member {
-    name: "LISTS"
+    name: "EQUALITY_OPERATORS"
     mtype: ""
   }
   member {
-    name: "LOGICAL_EXPRESSIONS"
+    name: "LISTS"
     mtype: ""
   }
   member {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.autograph.experimental.-feature.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.autograph.experimental.-feature.pbtxt
index d283fb8e14f..5d17918107c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.autograph.experimental.-feature.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.autograph.experimental.-feature.pbtxt
@@ -18,11 +18,11 @@ tf_class {
     mtype: ""
   }
   member {
-    name: "LISTS"
+    name: "EQUALITY_OPERATORS"
     mtype: ""
   }
   member {
-    name: "LOGICAL_EXPRESSIONS"
+    name: "LISTS"
     mtype: ""
   }
   member {
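
For completeness, a minimal opt-in sketch for the new experimental flag (assuming the `experimental_autograph_options` argument of `tf.function` from the same codebase; the function below is illustrative):

    import tensorflow as tf

    @tf.function(experimental_autograph_options=(
        tf.autograph.experimental.Feature.EQUALITY_OPERATORS))
    def elementwise_equal(a, b):
      # With the feature enabled, `a == b` is rewritten to ag__.eq(a, b),
      # which dispatches to tf.equal for Tensor inputs.
      return a == b

Without the flag, the comparison is left as written and resolves through the Tensor class's own operator overloading once that support is in place.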