Test identity gradient in presence of control flow.
PiperOrigin-RevId: 342259927 Change-Id: I890b84bfa1d8483f104bfd6523955564c8974082
This commit is contained in:
parent
f21a1be8e4
commit
5c34e9158b
@ -789,8 +789,6 @@ class TFRGen(transformer.CodeGenerator):
|
||||
# The out symbols are just a Tuple of names
|
||||
for out in node.args[5].elts[:nouts]:
|
||||
val, ty = self.symbol_table.lookup(out.value)
|
||||
if ty != TFRTypes.AG_UNDEFINED_VAL:
|
||||
raise ValueError('if stmt out symbol is not defined.')
|
||||
out_symbols.append(out.value)
|
||||
return self._visit_if_stmt(cond, body, orelse, get_state, out_symbols,
|
||||
node)
|
||||
@ -980,10 +978,8 @@ class TFRGen(transformer.CodeGenerator):
|
||||
if ret_ssa_values:
|
||||
self.emit(ret_str + ' = ')
|
||||
|
||||
# add ssa values to the symbol table
|
||||
out_types = []
|
||||
for symbol, ssa_value in zip(out_symbols, ret_ssa_values):
|
||||
self.symbol_table.insert_symbol(symbol, ssa_value, TFRTypes.TENSOR)
|
||||
out_types.append(str(TFRTypes.TENSOR))
|
||||
|
||||
self.emit('scf.if {} -> ({}) {{'.format(cond, ', '.join(out_types)))
|
||||
@ -1001,6 +997,10 @@ class TFRGen(transformer.CodeGenerator):
|
||||
self.visit_block(get_state.body)
|
||||
self.symbol_table.exit_scope()
|
||||
|
||||
# add ssa values to the symbol table
|
||||
for symbol, ssa_value in zip(out_symbols, ret_ssa_values):
|
||||
self.symbol_table.insert_symbol(symbol, ssa_value, TFRTypes.TENSOR)
|
||||
|
||||
self._emit_with_loc('\n}', node)
|
||||
return list(zip(ret_ssa_values, out_types))
|
||||
|
||||
|
@ -196,8 +196,8 @@ class ControlFlowTransformer(converter.Base):
|
||||
# it.
|
||||
input_only = basic_scope_vars & live_in - live_out
|
||||
|
||||
# Place the outputs first.
|
||||
scope_vars = sorted(scope_vars, key=lambda v: v in input_only)
|
||||
# Place the outputs first, then sort lexicographically.
|
||||
scope_vars = sorted(scope_vars, key=lambda v: (v in input_only, v))
|
||||
nouts = len(scope_vars) - len(input_only)
|
||||
|
||||
return scope_vars, undefined, nouts
|
||||
|
@ -107,6 +107,12 @@ class QN(object):
|
||||
def has_attr(self):
|
||||
return self._has_attr
|
||||
|
||||
@property
|
||||
def attr(self):
|
||||
if not self._has_attr:
|
||||
raise ValueError('Cannot get attr of non-attribute "%s".' % self)
|
||||
return self.qn[1]
|
||||
|
||||
@property
|
||||
def parent(self):
|
||||
if self._parent is None:
|
||||
@ -160,6 +166,18 @@ class QN(object):
|
||||
self.has_subscript() == other.has_subscript() and
|
||||
self.has_attr() == other.has_attr())
|
||||
|
||||
def __lt__(self, other):
|
||||
if isinstance(other, QN):
|
||||
return self.qn < other.qn
|
||||
else:
|
||||
return str(self) < str(other)
|
||||
|
||||
def __gt__(self, other):
|
||||
if isinstance(other, QN):
|
||||
return self.qn > other.qn
|
||||
else:
|
||||
return str(self) > str(other)
|
||||
|
||||
def __str__(self):
|
||||
root = self.qn[0]
|
||||
if self.has_subscript():
|
||||
|
Loading…
Reference in New Issue
Block a user