Test identity gradient in presence of control flow.

PiperOrigin-RevId: 342259927
Change-Id: I890b84bfa1d8483f104bfd6523955564c8974082
This commit is contained in:
Dan Moldovan 2020-11-13 07:39:11 -08:00 committed by TensorFlower Gardener
parent f21a1be8e4
commit 5c34e9158b
3 changed files with 24 additions and 6 deletions

View File

@ -789,8 +789,6 @@ class TFRGen(transformer.CodeGenerator):
# The out symbols are just a Tuple of names
for out in node.args[5].elts[:nouts]:
val, ty = self.symbol_table.lookup(out.value)
if ty != TFRTypes.AG_UNDEFINED_VAL:
raise ValueError('if stmt out symbol is not defined.')
out_symbols.append(out.value)
return self._visit_if_stmt(cond, body, orelse, get_state, out_symbols,
node)
@ -980,10 +978,8 @@ class TFRGen(transformer.CodeGenerator):
if ret_ssa_values:
self.emit(ret_str + ' = ')
# add ssa values to the symbol table
out_types = []
for symbol, ssa_value in zip(out_symbols, ret_ssa_values):
self.symbol_table.insert_symbol(symbol, ssa_value, TFRTypes.TENSOR)
out_types.append(str(TFRTypes.TENSOR))
self.emit('scf.if {} -> ({}) {{'.format(cond, ', '.join(out_types)))
@ -1001,6 +997,10 @@ class TFRGen(transformer.CodeGenerator):
self.visit_block(get_state.body)
self.symbol_table.exit_scope()
# add ssa values to the symbol table
for symbol, ssa_value in zip(out_symbols, ret_ssa_values):
self.symbol_table.insert_symbol(symbol, ssa_value, TFRTypes.TENSOR)
self._emit_with_loc('\n}', node)
return list(zip(ret_ssa_values, out_types))

View File

@ -196,8 +196,8 @@ class ControlFlowTransformer(converter.Base):
# it.
input_only = basic_scope_vars & live_in - live_out
# Place the outputs first.
scope_vars = sorted(scope_vars, key=lambda v: v in input_only)
# Place the outputs first, then sort lexicographically.
scope_vars = sorted(scope_vars, key=lambda v: (v in input_only, v))
nouts = len(scope_vars) - len(input_only)
return scope_vars, undefined, nouts

View File

@ -107,6 +107,12 @@ class QN(object):
def has_attr(self):
return self._has_attr
@property
def attr(self):
if not self._has_attr:
raise ValueError('Cannot get attr of non-attribute "%s".' % self)
return self.qn[1]
@property
def parent(self):
if self._parent is None:
@ -160,6 +166,18 @@ class QN(object):
self.has_subscript() == other.has_subscript() and
self.has_attr() == other.has_attr())
def __lt__(self, other):
if isinstance(other, QN):
return self.qn < other.qn
else:
return str(self) < str(other)
def __gt__(self, other):
if isinstance(other, QN):
return self.qn > other.qn
else:
return str(self) > str(other)
def __str__(self):
root = self.qn[0]
if self.has_subscript():