Changes convert_stack's return type from list to tuple to make the value hashable.
PiperOrigin-RevId: 259392206
This commit is contained in:
parent
b0cd40d7c7
commit
8614aaf955
@ -1462,14 +1462,14 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
|
||||
|
||||
# Lookup should work with node name input.
|
||||
traceback = dump.node_traceback("traceback/w")
|
||||
self.assertIsInstance(traceback, list)
|
||||
self.assertIsInstance(traceback, tuple)
|
||||
self.assertGreater(len(traceback), 0)
|
||||
for trace in traceback:
|
||||
self.assertIsInstance(trace, tuple)
|
||||
|
||||
# Lookup should also work with tensor name input.
|
||||
traceback = dump.node_traceback("traceback/w:0")
|
||||
self.assertIsInstance(traceback, list)
|
||||
self.assertIsInstance(traceback, tuple)
|
||||
self.assertGreater(len(traceback), 0)
|
||||
for trace in traceback:
|
||||
self.assertIsInstance(trace, tuple)
|
||||
|
@ -49,6 +49,7 @@ from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.keras.engine import training as keras_training
|
||||
from tensorflow.python.keras.layers import core as keras_core
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradients
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
@ -1203,6 +1204,24 @@ class MultiWorkerMirroredStrategyTestWithChief(
|
||||
self._test_summary_for_replica_zero_only(strategy)
|
||||
|
||||
|
||||
class MirroredVariableStopGradientTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_one_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_one_gpu,
|
||||
],
|
||||
mode=["graph"]))
|
||||
def testMirroredVariableAsStopGradient(self, distribution):
|
||||
with distribution.scope():
|
||||
inp = constant_op.constant(1.0)
|
||||
x = variables.Variable(1.0)
|
||||
y = inp*x
|
||||
grads = gradients.gradients(x, y, stop_gradients=x)
|
||||
self.assertIsNone(grads[0])
|
||||
|
||||
|
||||
def _replica_id():
|
||||
replica_id = ds_context.get_replica_context().replica_id_in_sync_group
|
||||
if not isinstance(replica_id, ops.Tensor):
|
||||
|
@ -199,21 +199,22 @@ def convert_stack(stack, include_func_start_lineno=False):
|
||||
included as the 5th entry in return tuples.
|
||||
|
||||
Returns:
|
||||
A list of n 4-tuples or 5-tuples
|
||||
A tuple of n 4-tuples or 5-tuples
|
||||
(filename, lineno, name, code, [optional: func_start_lineno]), where the
|
||||
code tuple element is calculated from the corresponding elements of the
|
||||
input tuple.
|
||||
"""
|
||||
ret = []
|
||||
for (filename, lineno, name, frame_globals, func_start_lineno) in stack:
|
||||
linecache.checkcache(filename)
|
||||
line = linecache.getline(filename, lineno, frame_globals)
|
||||
if line:
|
||||
line = line.strip()
|
||||
else:
|
||||
line = None
|
||||
if include_func_start_lineno:
|
||||
ret.append((filename, lineno, name, line, func_start_lineno))
|
||||
else:
|
||||
ret.append((filename, lineno, name, line))
|
||||
return ret
|
||||
def _tuple_generator(): # pylint: disable=missing-docstring
|
||||
for (filename, lineno, name, frame_globals, func_start_lineno) in stack:
|
||||
linecache.checkcache(filename)
|
||||
line = linecache.getline(filename, lineno, frame_globals)
|
||||
if line:
|
||||
line = line.strip()
|
||||
else:
|
||||
line = None
|
||||
if include_func_start_lineno:
|
||||
yield (filename, lineno, name, line, func_start_lineno)
|
||||
else:
|
||||
yield (filename, lineno, name, line)
|
||||
|
||||
return tuple(_tuple_generator())
|
||||
|
Loading…
Reference in New Issue
Block a user