Changes convert_stack's return type from list to tuple to make the value hashable.

PiperOrigin-RevId: 259392206
This commit is contained in:
A. Unique TensorFlower 2019-07-22 13:17:28 -07:00 committed by TensorFlower Gardener
parent b0cd40d7c7
commit 8614aaf955
3 changed files with 36 additions and 16 deletions

View File

@ -1462,14 +1462,14 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
# Lookup should work with node name input.
traceback = dump.node_traceback("traceback/w")
self.assertIsInstance(traceback, list)
self.assertIsInstance(traceback, tuple)
self.assertGreater(len(traceback), 0)
for trace in traceback:
self.assertIsInstance(trace, tuple)
# Lookup should also work with tensor name input.
traceback = dump.node_traceback("traceback/w:0")
self.assertIsInstance(traceback, list)
self.assertIsInstance(traceback, tuple)
self.assertGreater(len(traceback), 0)
for trace in traceback:
self.assertIsInstance(trace, tuple)

View File

@ -49,6 +49,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.keras.layers import core as keras_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
@ -1203,6 +1204,24 @@ class MultiWorkerMirroredStrategyTestWithChief(
self._test_summary_for_replica_zero_only(strategy)
class MirroredVariableStopGradientTest(test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_one_gpu,
],
mode=["graph"]))
def testMirroredVariableAsStopGradient(self, distribution):
with distribution.scope():
inp = constant_op.constant(1.0)
x = variables.Variable(1.0)
y = inp*x
grads = gradients.gradients(x, y, stop_gradients=x)
self.assertIsNone(grads[0])
def _replica_id():
replica_id = ds_context.get_replica_context().replica_id_in_sync_group
if not isinstance(replica_id, ops.Tensor):

View File

@ -199,12 +199,12 @@ def convert_stack(stack, include_func_start_lineno=False):
included as the 5th entry in return tuples.
Returns:
A list of n 4-tuples or 5-tuples
A tuple of n 4-tuples or 5-tuples
(filename, lineno, name, code, [optional: func_start_lineno]), where the
code tuple element is calculated from the corresponding elements of the
input tuple.
"""
ret = []
def _tuple_generator(): # pylint: disable=missing-docstring
for (filename, lineno, name, frame_globals, func_start_lineno) in stack:
linecache.checkcache(filename)
line = linecache.getline(filename, lineno, frame_globals)
@ -213,7 +213,8 @@ def convert_stack(stack, include_func_start_lineno=False):
else:
line = None
if include_func_start_lineno:
ret.append((filename, lineno, name, line, func_start_lineno))
yield (filename, lineno, name, line, func_start_lineno)
else:
ret.append((filename, lineno, name, line))
return ret
yield (filename, lineno, name, line)
return tuple(_tuple_generator())