diff --git a/tensorflow/compiler/mlir/tfr/python/tfr_gen.py b/tensorflow/compiler/mlir/tfr/python/tfr_gen.py index bd09613adc2..52afb197134 100644 --- a/tensorflow/compiler/mlir/tfr/python/tfr_gen.py +++ b/tensorflow/compiler/mlir/tfr/python/tfr_gen.py @@ -789,8 +789,6 @@ class TFRGen(transformer.CodeGenerator): # The out symbols are just a Tuple of names for out in node.args[5].elts[:nouts]: val, ty = self.symbol_table.lookup(out.value) - if ty != TFRTypes.AG_UNDEFINED_VAL: - raise ValueError('if stmt out symbol is not defined.') out_symbols.append(out.value) return self._visit_if_stmt(cond, body, orelse, get_state, out_symbols, node) @@ -980,10 +978,8 @@ class TFRGen(transformer.CodeGenerator): if ret_ssa_values: self.emit(ret_str + ' = ') - # add ssa values to the symbol table out_types = [] for symbol, ssa_value in zip(out_symbols, ret_ssa_values): - self.symbol_table.insert_symbol(symbol, ssa_value, TFRTypes.TENSOR) out_types.append(str(TFRTypes.TENSOR)) self.emit('scf.if {} -> ({}) {{'.format(cond, ', '.join(out_types))) @@ -1001,6 +997,10 @@ class TFRGen(transformer.CodeGenerator): self.visit_block(get_state.body) self.symbol_table.exit_scope() + # add ssa values to the symbol table + for symbol, ssa_value in zip(out_symbols, ret_ssa_values): + self.symbol_table.insert_symbol(symbol, ssa_value, TFRTypes.TENSOR) + self._emit_with_loc('\n}', node) return list(zip(ret_ssa_values, out_types)) diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py index 7b30b5723be..3c62d88061a 100644 --- a/tensorflow/python/autograph/converters/control_flow.py +++ b/tensorflow/python/autograph/converters/control_flow.py @@ -196,8 +196,8 @@ class ControlFlowTransformer(converter.Base): # it. input_only = basic_scope_vars & live_in - live_out - # Place the outputs first. - scope_vars = sorted(scope_vars, key=lambda v: v in input_only) + # Place the outputs first, then sort lexicographically. + scope_vars = sorted(scope_vars, key=lambda v: (v in input_only, v)) nouts = len(scope_vars) - len(input_only) return scope_vars, undefined, nouts diff --git a/tensorflow/python/autograph/pyct/qual_names.py b/tensorflow/python/autograph/pyct/qual_names.py index d9491691567..d747230d51a 100644 --- a/tensorflow/python/autograph/pyct/qual_names.py +++ b/tensorflow/python/autograph/pyct/qual_names.py @@ -107,6 +107,12 @@ class QN(object): def has_attr(self): return self._has_attr + @property + def attr(self): + if not self._has_attr: + raise ValueError('Cannot get attr of non-attribute "%s".' % self) + return self.qn[1] + @property def parent(self): if self._parent is None: @@ -160,6 +166,18 @@ class QN(object): self.has_subscript() == other.has_subscript() and self.has_attr() == other.has_attr()) + def __lt__(self, other): + if isinstance(other, QN): + return self.qn < other.qn + else: + return str(self) < str(other) + + def __gt__(self, other): + if isinstance(other, QN): + return self.qn > other.qn + else: + return str(self) > str(other) + def __str__(self): root = self.qn[0] if self.has_subscript():