Changes convert_stack's return type from list to tuple to make the value hashable.

PiperOrigin-RevId: 259392206
This commit is contained in:
A. Unique TensorFlower 2019-07-22 13:17:28 -07:00 committed by TensorFlower Gardener
parent b0cd40d7c7
commit 8614aaf955
3 changed files with 36 additions and 16 deletions

View File

@ -1462,14 +1462,14 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
# Lookup should work with node name input. # Lookup should work with node name input.
traceback = dump.node_traceback("traceback/w") traceback = dump.node_traceback("traceback/w")
self.assertIsInstance(traceback, list) self.assertIsInstance(traceback, tuple)
self.assertGreater(len(traceback), 0) self.assertGreater(len(traceback), 0)
for trace in traceback: for trace in traceback:
self.assertIsInstance(trace, tuple) self.assertIsInstance(trace, tuple)
# Lookup should also work with tensor name input. # Lookup should also work with tensor name input.
traceback = dump.node_traceback("traceback/w:0") traceback = dump.node_traceback("traceback/w:0")
self.assertIsInstance(traceback, list) self.assertIsInstance(traceback, tuple)
self.assertGreater(len(traceback), 0) self.assertGreater(len(traceback), 0)
for trace in traceback: for trace in traceback:
self.assertIsInstance(trace, tuple) self.assertIsInstance(trace, tuple)

View File

@ -49,6 +49,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.keras.engine import training as keras_training from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.keras.layers import core as keras_core from tensorflow.python.keras.layers import core as keras_core
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
@ -1203,6 +1204,24 @@ class MultiWorkerMirroredStrategyTestWithChief(
self._test_summary_for_replica_zero_only(strategy) self._test_summary_for_replica_zero_only(strategy)
class MirroredVariableStopGradientTest(test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_one_gpu,
],
mode=["graph"]))
def testMirroredVariableAsStopGradient(self, distribution):
with distribution.scope():
inp = constant_op.constant(1.0)
x = variables.Variable(1.0)
y = inp*x
grads = gradients.gradients(x, y, stop_gradients=x)
self.assertIsNone(grads[0])
def _replica_id(): def _replica_id():
replica_id = ds_context.get_replica_context().replica_id_in_sync_group replica_id = ds_context.get_replica_context().replica_id_in_sync_group
if not isinstance(replica_id, ops.Tensor): if not isinstance(replica_id, ops.Tensor):

View File

@ -199,12 +199,12 @@ def convert_stack(stack, include_func_start_lineno=False):
included as the 5th entry in return tuples. included as the 5th entry in return tuples.
Returns: Returns:
A list of n 4-tuples or 5-tuples A tuple of n 4-tuples or 5-tuples
(filename, lineno, name, code, [optional: func_start_lineno]), where the (filename, lineno, name, code, [optional: func_start_lineno]), where the
code tuple element is calculated from the corresponding elements of the code tuple element is calculated from the corresponding elements of the
input tuple. input tuple.
""" """
ret = [] def _tuple_generator(): # pylint: disable=missing-docstring
for (filename, lineno, name, frame_globals, func_start_lineno) in stack: for (filename, lineno, name, frame_globals, func_start_lineno) in stack:
linecache.checkcache(filename) linecache.checkcache(filename)
line = linecache.getline(filename, lineno, frame_globals) line = linecache.getline(filename, lineno, frame_globals)
@ -213,7 +213,8 @@ def convert_stack(stack, include_func_start_lineno=False):
else: else:
line = None line = None
if include_func_start_lineno: if include_func_start_lineno:
ret.append((filename, lineno, name, line, func_start_lineno)) yield (filename, lineno, name, line, func_start_lineno)
else: else:
ret.append((filename, lineno, name, line)) yield (filename, lineno, name, line)
return ret
return tuple(_tuple_generator())