From 328a4aafecfa3fb9dde0516e684db43bc15a1a03 Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Wed, 11 Sep 2019 23:59:58 -0700 Subject: [PATCH] Remove contrib dependence of basic_session_run_hooks_test by moving Fake summary writer near summary writer. PiperOrigin-RevId: 268622421 --- tensorflow/python/BUILD | 20 ++- .../summary/writer/fake_summary_writer.py | 143 ++++++++++++++++++ .../training/basic_session_run_hooks_test.py | 2 +- 3 files changed, 162 insertions(+), 3 deletions(-) create mode 100644 tensorflow/python/summary/writer/fake_summary_writer.py diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index f1991366f09..7ae73b30e9e 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -5907,6 +5907,7 @@ tf_py_test( ":client", ":client_testlib", ":control_flow_ops", + ":fake_summary_writer", ":framework", ":framework_for_generated_wrappers", ":nn_grad", @@ -5916,7 +5917,6 @@ tf_py_test( ":training", ":variable_scope", ":variables", - "//tensorflow/contrib/testing:testing_py", "//tensorflow/core:protos_all_py", ], tags = [ @@ -6086,7 +6086,10 @@ py_library( name = "summary", srcs = glob( ["summary/**/*.py"], - exclude = ["**/*test*"], + exclude = [ + "**/fake*", + "**/*test*", + ], ), srcs_version = "PY2AND3", visibility = ["//visibility:public"], @@ -6110,6 +6113,19 @@ py_library( ], ) +py_library( + name = "fake_summary_writer", + testonly = 1, + srcs = ["summary/writer/fake_summary_writer.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":framework_test_lib", + ":protos_all_py", + ":summary", + ], +) + py_tests( name = "summary_tests", size = "small", diff --git a/tensorflow/python/summary/writer/fake_summary_writer.py b/tensorflow/python/summary/writer/fake_summary_writer.py new file mode 100644 index 00000000000..eac34afc4ad --- /dev/null +++ b/tensorflow/python/summary/writer/fake_summary_writer.py @@ -0,0 +1,143 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Fake summary writer for unit tests.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.framework import summary_pb2 +from tensorflow.python.framework import test_util +from tensorflow.python.summary.writer import writer +from tensorflow.python.summary.writer import writer_cache + + +# TODO(ptucker): Replace with mock framework. +class FakeSummaryWriter(object): + """Fake summary writer.""" + + _replaced_summary_writer = None + + @classmethod + def install(cls): + if cls._replaced_summary_writer: + raise ValueError('FakeSummaryWriter already installed.') + cls._replaced_summary_writer = writer.FileWriter + writer.FileWriter = FakeSummaryWriter + writer_cache.FileWriter = FakeSummaryWriter + + @classmethod + def uninstall(cls): + if not cls._replaced_summary_writer: + raise ValueError('FakeSummaryWriter not installed.') + writer.FileWriter = cls._replaced_summary_writer + writer_cache.FileWriter = cls._replaced_summary_writer + cls._replaced_summary_writer = None + + def __init__(self, logdir, graph=None): + self._logdir = logdir + self._graph = graph + self._summaries = {} + self._added_graphs = [] + self._added_meta_graphs = [] + self._added_session_logs = [] + self._added_run_metadata = {} + + @property + def summaries(self): + return self._summaries + + def assert_summaries(self, + test_case, + expected_logdir=None, + expected_graph=None, + expected_summaries=None, + expected_added_graphs=None, + expected_added_meta_graphs=None, + expected_session_logs=None): + """Assert expected items have been added to summary writer.""" + if expected_logdir is not None: + test_case.assertEqual(expected_logdir, self._logdir) + if expected_graph is not None: + test_case.assertTrue(expected_graph is self._graph) + expected_summaries = expected_summaries or {} + for step in expected_summaries: + test_case.assertTrue( + step in self._summaries, + msg='Missing step %s from %s.' % (step, self._summaries.keys())) + actual_simple_values = {} + for step_summary in self._summaries[step]: + for v in step_summary.value: + # Ignore global_step/sec since it's written by Supervisor in a + # separate thread, so it's non-deterministic how many get written. + if 'global_step/sec' != v.tag: + actual_simple_values[v.tag] = v.simple_value + test_case.assertEqual(expected_summaries[step], actual_simple_values) + if expected_added_graphs is not None: + test_case.assertEqual(expected_added_graphs, self._added_graphs) + if expected_added_meta_graphs is not None: + test_case.assertEqual(len(expected_added_meta_graphs), + len(self._added_meta_graphs)) + for expected, actual in zip(expected_added_meta_graphs, + self._added_meta_graphs): + test_util.assert_meta_graph_protos_equal(test_case, expected, actual) + if expected_session_logs is not None: + test_case.assertEqual(expected_session_logs, self._added_session_logs) + + def add_summary(self, summ, current_global_step): + """Add summary.""" + if isinstance(summ, bytes): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summ) + summ = summary_proto + if current_global_step in self._summaries: + step_summaries = self._summaries[current_global_step] + else: + step_summaries = [] + self._summaries[current_global_step] = step_summaries + step_summaries.append(summ) + + # NOTE: Ignore global_step since its value is non-deterministic. + def add_graph(self, graph, global_step=None, graph_def=None): + """Add graph.""" + if (global_step is not None) and (global_step < 0): + raise ValueError('Invalid global_step %s.' % global_step) + if graph_def is not None: + raise ValueError('Unexpected graph_def %s.' % graph_def) + self._added_graphs.append(graph) + + def add_meta_graph(self, meta_graph_def, global_step=None): + """Add metagraph.""" + if (global_step is not None) and (global_step < 0): + raise ValueError('Invalid global_step %s.' % global_step) + self._added_meta_graphs.append(meta_graph_def) + + # NOTE: Ignore global_step since its value is non-deterministic. + def add_session_log(self, session_log, global_step=None): + # pylint: disable=unused-argument + self._added_session_logs.append(session_log) + + def add_run_metadata(self, run_metadata, tag, global_step=None): + if (global_step is not None) and (global_step < 0): + raise ValueError('Invalid global_step %s.' % global_step) + self._added_run_metadata[tag] = run_metadata + + def flush(self): + pass + + def reopen(self): + pass + + def close(self): + pass diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py index f25085435fb..3e1ccfed0dc 100644 --- a/tensorflow/python/training/basic_session_run_hooks_test.py +++ b/tensorflow/python/training/basic_session_run_hooks_test.py @@ -24,7 +24,6 @@ import shutil import tempfile import time -from tensorflow.contrib.testing.python.framework import fake_summary_writer from tensorflow.python.client import session as session_lib from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op @@ -43,6 +42,7 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging from tensorflow.python.summary import summary as summary_lib +from tensorflow.python.summary.writer import fake_summary_writer from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import checkpoint_utils