Changes convert_stack's return type from list to tuple to make the value hashable.
PiperOrigin-RevId: 259392206
commit 8614aaf955
parent b0cd40d7c7
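Background for the change (not part of the commit itself): a list is unhashable, so the value returned by convert_stack could not be placed in a set or used as a dictionary key; a tuple of tuples can. A minimal standalone sketch, using made-up frame tuples in the (filename, lineno, name, code) shape described in convert_stack's docstring:

# Hypothetical frame tuples, illustrating why a hashable return type helps.
frames_a = (("model.py", 10, "build", "w = make_var()"),
            ("train.py", 42, "main", "build()"))
frames_b = (("model.py", 10, "build", "w = make_var()"),
            ("train.py", 42, "main", "build()"))

seen = {frames_a}        # would raise TypeError if these were lists
assert frames_b in seen  # equal tuples hash identically, enabling dedup

origin_counts = {}
origin_counts[frames_a] = origin_counts.get(frames_a, 0) + 1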
@@ -1462,14 +1462,14 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
 
     # Lookup should work with node name input.
     traceback = dump.node_traceback("traceback/w")
-    self.assertIsInstance(traceback, list)
+    self.assertIsInstance(traceback, tuple)
     self.assertGreater(len(traceback), 0)
     for trace in traceback:
       self.assertIsInstance(trace, tuple)
 
     # Lookup should also work with tensor name input.
     traceback = dump.node_traceback("traceback/w:0")
-    self.assertIsInstance(traceback, list)
+    self.assertIsInstance(traceback, tuple)
     self.assertGreater(len(traceback), 0)
     for trace in traceback:
       self.assertIsInstance(trace, tuple)
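The hunk above only swaps the expected container type in the assertions; the per-frame shape stays (filename, lineno, name, line). As a reference point, a small standalone sketch (not from the commit) of consuming a traceback in that shape:

# Sketch only: pretty-print frames shaped like the tuples these tests check.
def format_frames(traceback_frames):
  """traceback_frames: tuple of (filename, lineno, name, line) tuples."""
  return "\n".join(
      "%s:%d in %s: %s" % (filename, lineno, name, line or "<source unavailable>")
      for filename, lineno, name, line in traceback_frames)

print(format_frames((("model.py", 42, "build", "w = tf.Variable(0.0)"),)))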
@@ -49,6 +49,7 @@ from tensorflow.python.framework import tensor_util
 from tensorflow.python.keras.engine import training as keras_training
 from tensorflow.python.keras.layers import core as keras_core
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
@@ -1203,6 +1204,24 @@ class MultiWorkerMirroredStrategyTestWithChief(
     self._test_summary_for_replica_zero_only(strategy)
 
 
+class MirroredVariableStopGradientTest(test.TestCase, parameterized.TestCase):
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_one_cpu,
+              strategy_combinations.mirrored_strategy_with_one_gpu,
+          ],
+          mode=["graph"]))
+  def testMirroredVariableAsStopGradient(self, distribution):
+    with distribution.scope():
+      inp = constant_op.constant(1.0)
+      x = variables.Variable(1.0)
+      y = inp*x
+      grads = gradients.gradients(x, y, stop_gradients=x)
+      self.assertIsNone(grads[0])
+
+
 def _replica_id():
   replica_id = ds_context.get_replica_context().replica_id_in_sync_group
   if not isinstance(replica_id, ops.Tensor):
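For readers unfamiliar with stop_gradients, a minimal graph-mode sketch separate from the test above (plain TF 1.x-style public APIs; names and values are illustrative only):

import tensorflow as tf  # assumes the TF 1.x / tf.compat.v1 graph-mode API

a = tf.constant(3.0)
b = 2 * a
c = a + b
# With b listed in stop_gradients, backprop treats b as a constant,
# so dc/da = 1 instead of 1 + 2 = 3.
grad_stopped = tf.gradients(c, a, stop_gradients=[b])
grad_full = tf.gradients(c, a)

with tf.Session() as sess:
  print(sess.run([grad_stopped, grad_full]))  # [[1.0], [3.0]]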
@@ -199,21 +199,22 @@ def convert_stack(stack, include_func_start_lineno=False):
       included as the 5th entry in return tuples.
 
   Returns:
-    A list of n 4-tuples or 5-tuples
+    A tuple of n 4-tuples or 5-tuples
     (filename, lineno, name, code, [optional: func_start_lineno]), where the
     code tuple element is calculated from the corresponding elements of the
     input tuple.
   """
-  ret = []
-  for (filename, lineno, name, frame_globals, func_start_lineno) in stack:
-    linecache.checkcache(filename)
-    line = linecache.getline(filename, lineno, frame_globals)
-    if line:
-      line = line.strip()
-    else:
-      line = None
-    if include_func_start_lineno:
-      ret.append((filename, lineno, name, line, func_start_lineno))
-    else:
-      ret.append((filename, lineno, name, line))
-  return ret
+  def _tuple_generator():  # pylint: disable=missing-docstring
+    for (filename, lineno, name, frame_globals, func_start_lineno) in stack:
+      linecache.checkcache(filename)
+      line = linecache.getline(filename, lineno, frame_globals)
+      if line:
+        line = line.strip()
+      else:
+        line = None
+      if include_func_start_lineno:
+        yield (filename, lineno, name, line, func_start_lineno)
+      else:
+        yield (filename, lineno, name, line)
+
+  return tuple(_tuple_generator())
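The rewritten body wraps a local generator in tuple(), which drains it eagerly, so callers still receive a fully materialized sequence; only the type (and hashability) changes. A toy illustration of the same pattern, unrelated to the TensorFlow frame format:

# Toy version of the list -> tuple(generator) rewrite above.
def convert(records, include_extra=False):
  def _gen():
    for name, value, extra in records:
      if include_extra:
        yield (name, value, extra)
      else:
        yield (name, value)
  return tuple(_gen())  # eager: the generator is consumed before returning

result = convert([("a", 1, "x"), ("b", 2, "y")])
print(result)        # (('a', 1), ('b', 2))
print(hash(result))  # tuples of tuples are hashable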