Disable graph rewrites in tests that assert a given graph structure (round 2). This will make it possible to optimize these graphs in the future without having to update the test with every improvement.

PiperOrigin-RevId: 160997959
This commit is contained in:
Benoit Steiner 2017-07-05 13:32:08 -07:00 committed by TensorFlower Gardener
parent d13fd82289
commit 70804d820b
4 changed files with 46 additions and 15 deletions

View File

@ -25,6 +25,7 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.cli import analyzer_cli
from tensorflow.python.debug.cli import cli_shared
@ -43,6 +44,13 @@ from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
def no_rewrite_session_config():
rewriter_config = rewriter_config_pb2.RewriterConfig(
disable_model_pruning=True)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)
def line_number_above():
return tf_inspect.stack()[1][2] - 1
@ -506,7 +514,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
cls._curr_file_path = os.path.abspath(
tf_inspect.getfile(tf_inspect.currentframe()))
cls._sess = session.Session()
cls._sess = session.Session(config=no_rewrite_session_config())
with cls._sess as sess:
u_init_val = np.array([[5.0, 3.0], [-1.0, 0.0]])
v_init_val = np.array([[2.0], [-1.0]])
@ -1382,7 +1390,7 @@ class AnalyzerCLIPrintLargeTensorTest(test_util.TensorFlowTestCase):
def setUpClass(cls):
cls._dump_root = tempfile.mkdtemp()
with session.Session() as sess:
with session.Session(config=no_rewrite_session_config()) as sess:
# 2400 elements should exceed the default threshold (2000).
x = constant_op.constant(np.zeros([300, 8]), name="large_tensors/x")
@ -1459,7 +1467,7 @@ class AnalyzerCLIControlDepTest(test_util.TensorFlowTestCase):
else:
cls._main_device = "/job:localhost/replica:0/task:0/cpu:0"
with session.Session() as sess:
with session.Session(config=no_rewrite_session_config()) as sess:
x_init_val = np.array([5.0, 3.0])
x_init = constant_op.constant(x_init_val, shape=[2])
x = variables.Variable(x_init, name="control_deps/x")
@ -1799,7 +1807,7 @@ class AnalyzerCLIWhileLoopTest(test_util.TensorFlowTestCase):
def setUpClass(cls):
cls._dump_root = tempfile.mkdtemp()
with session.Session() as sess:
with session.Session(config=no_rewrite_session_config()) as sess:
loop_var = constant_op.constant(0, name="while_loop_test/loop_var")
cond = lambda loop_var: math_ops.less(loop_var, 10)
body = lambda loop_var: math_ops.add(loop_var, 1)

View File

@ -22,6 +22,7 @@ import shutil
import tempfile
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.lib import debug_utils
@ -35,6 +36,12 @@ from tensorflow.python.platform import googletest
class SessionDebugTest(session_debug_testlib.SessionDebugTestBase):
def _no_rewrite_session_config(self):
rewriter_config = rewriter_config_pb2.RewriterConfig(
disable_model_pruning=True)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)
def _debug_urls(self, run_number=None):
return ["file://%s" % self._debug_dump_dir(run_number=run_number)]
@ -47,7 +54,7 @@ class SessionDebugTest(session_debug_testlib.SessionDebugTestBase):
def testAllowsDifferentWatchesOnDifferentRuns(self):
"""Test watching different tensors on different runs of the same graph."""
with session.Session() as sess:
with session.Session(config=self._no_rewrite_session_config()) as sess:
u_init_val = [[5.0, 3.0], [-1.0, 0.0]]
v_init_val = [[2.0], [-1.0]]

View File

@ -29,6 +29,7 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_data
@ -53,6 +54,13 @@ from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
def no_rewrite_session_config():
rewriter_config = rewriter_config_pb2.RewriterConfig(
disable_model_pruning=True)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)
class _RNNCellForTest(rnn_cell_impl.RNNCell):
"""RNN cell for testing."""
@ -160,7 +168,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
validate=validate)
def _generate_dump_from_simple_addition_graph(self):
with session.Session() as sess:
with session.Session(config=no_rewrite_session_config()) as sess:
u_init_val = np.array([[5.0, 3.0], [-1.0, 0.0]])
v_init_val = np.array([[2.0], [-1.0]])
@ -304,7 +312,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
results.dump.node_op_type("foo_bar")
def testDumpStringTensorsWorks(self):
with session.Session() as sess:
with session.Session(config=no_rewrite_session_config()) as sess:
str1_init_val = np.array(b"abc")
str2_init_val = np.array(b"def")
@ -419,7 +427,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
self.assertEqual(s_init_val, sess.run(s))
def testDebugWhileLoopGeneratesMultipleDumps(self):
with session.Session() as sess:
with session.Session(config=no_rewrite_session_config()) as sess:
num_iter = 10
# "u" is the Variable being updated in the loop.
@ -659,7 +667,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
self.assertEqual(x_name, first_bad_datum[0].node_name)
def _session_run_for_graph_structure_lookup(self):
with session.Session() as sess:
with session.Session(config=no_rewrite_session_config()) as sess:
u_name = "testDumpGraphStructureLookup/u"
v_name = "testDumpGraphStructureLookup/v"
w_name = "testDumpGraphStructureLookup/w"
@ -798,7 +806,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
self.assertTrue(dump.loaded_partition_graphs())
def testGraphPathFindingOnControlEdgesWorks(self):
with session.Session() as sess:
with session.Session(config=no_rewrite_session_config()) as sess:
v1 = variables.Variable(1.0, name="v1")
v2 = variables.Variable(2.0, name="v2")
v3 = variables.Variable(3.0, name="v3")
@ -814,7 +822,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
self.assertIsNone(dump.find_some_path("v1", "c", include_control=False))
def testGraphPathFindingReverseRefEdgeWorks(self):
with session.Session() as sess:
with session.Session(config=no_rewrite_session_config()) as sess:
v = variables.Variable(10.0, name="v")
delta = variables.Variable(1.0, name="delta")
inc_v = state_ops.assign_add(v, delta, name="inc_v")
@ -1164,7 +1172,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
self.assertAllClose(np.array([[-3.0, 0.0]]), x_dumps[0].get_tensor())
def testDebugNumericSummaryOnInitializedTensorGivesCorrectResult(self):
with session.Session() as sess:
with session.Session(config=no_rewrite_session_config()) as sess:
a = variables.Variable(
[
np.nan, np.nan, 0.0, 0.0, 0.0, -1.0, -3.0, 3.0, 7.0, -np.inf,
@ -1252,7 +1260,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
self.assertIn("m:0:DebugNumericSummary", dump.debug_watch_keys("m"))
def testDebugNumericSummaryInvalidAttributesStringAreCaught(self):
with session.Session() as sess:
with session.Session(config=no_rewrite_session_config()) as sess:
a = variables.Variable(10.0, name="a")
b = variables.Variable(0.0, name="b")
c = variables.Variable(0.0, name="c")
@ -1300,7 +1308,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
sess.run(y, options=run_options, run_metadata=run_metadata)
def testDebugNumericSummaryMuteOnHealthyMutesOnlyHealthyTensorDumps(self):
with session.Session() as sess:
with session.Session(config=no_rewrite_session_config()) as sess:
a = variables.Variable(10.0, name="a")
b = variables.Variable(0.0, name="b")
c = variables.Variable(0.0, name="c")

View File

@ -24,6 +24,8 @@ import threading
import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.wrappers import framework
@ -139,6 +141,12 @@ class TestDebugWrapperSessionBadAction(framework.BaseDebugWrapperSession):
class DebugWrapperSessionTest(test_util.TensorFlowTestCase):
def _no_rewrite_session_config(self):
rewriter_config = rewriter_config_pb2.RewriterConfig(
disable_model_pruning=True)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)
def setUp(self):
self._observer = {
"sess_init_count": 0,
@ -153,7 +161,7 @@ class DebugWrapperSessionTest(test_util.TensorFlowTestCase):
self._dump_root = tempfile.mkdtemp()
self._sess = session.Session()
self._sess = session.Session(config=self._no_rewrite_session_config())
self._a_init_val = np.array([[5.0, 3.0], [-1.0, 0.0]])
self._b_init_val = np.array([[2.0], [-1.0]])