Remove contrib dependence of basic_session_run_hooks_test by moving Fake
summary writer near summary writer. PiperOrigin-RevId: 268622421
This commit is contained in:
parent
bc9877fa53
commit
328a4aafec
@ -5907,6 +5907,7 @@ tf_py_test(
|
||||
":client",
|
||||
":client_testlib",
|
||||
":control_flow_ops",
|
||||
":fake_summary_writer",
|
||||
":framework",
|
||||
":framework_for_generated_wrappers",
|
||||
":nn_grad",
|
||||
@ -5916,7 +5917,6 @@ tf_py_test(
|
||||
":training",
|
||||
":variable_scope",
|
||||
":variables",
|
||||
"//tensorflow/contrib/testing:testing_py",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
],
|
||||
tags = [
|
||||
@ -6086,7 +6086,10 @@ py_library(
|
||||
name = "summary",
|
||||
srcs = glob(
|
||||
["summary/**/*.py"],
|
||||
exclude = ["**/*test*"],
|
||||
exclude = [
|
||||
"**/fake*",
|
||||
"**/*test*",
|
||||
],
|
||||
),
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
@ -6110,6 +6113,19 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "fake_summary_writer",
|
||||
testonly = 1,
|
||||
srcs = ["summary/writer/fake_summary_writer.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":framework_test_lib",
|
||||
":protos_all_py",
|
||||
":summary",
|
||||
],
|
||||
)
|
||||
|
||||
py_tests(
|
||||
name = "summary_tests",
|
||||
size = "small",
|
||||
|
143
tensorflow/python/summary/writer/fake_summary_writer.py
Normal file
143
tensorflow/python/summary/writer/fake_summary_writer.py
Normal file
@ -0,0 +1,143 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Fake summary writer for unit tests."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.framework import summary_pb2
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.summary.writer import writer
|
||||
from tensorflow.python.summary.writer import writer_cache
|
||||
|
||||
|
||||
# TODO(ptucker): Replace with mock framework.
|
||||
class FakeSummaryWriter(object):
|
||||
"""Fake summary writer."""
|
||||
|
||||
_replaced_summary_writer = None
|
||||
|
||||
@classmethod
|
||||
def install(cls):
|
||||
if cls._replaced_summary_writer:
|
||||
raise ValueError('FakeSummaryWriter already installed.')
|
||||
cls._replaced_summary_writer = writer.FileWriter
|
||||
writer.FileWriter = FakeSummaryWriter
|
||||
writer_cache.FileWriter = FakeSummaryWriter
|
||||
|
||||
@classmethod
|
||||
def uninstall(cls):
|
||||
if not cls._replaced_summary_writer:
|
||||
raise ValueError('FakeSummaryWriter not installed.')
|
||||
writer.FileWriter = cls._replaced_summary_writer
|
||||
writer_cache.FileWriter = cls._replaced_summary_writer
|
||||
cls._replaced_summary_writer = None
|
||||
|
||||
def __init__(self, logdir, graph=None):
|
||||
self._logdir = logdir
|
||||
self._graph = graph
|
||||
self._summaries = {}
|
||||
self._added_graphs = []
|
||||
self._added_meta_graphs = []
|
||||
self._added_session_logs = []
|
||||
self._added_run_metadata = {}
|
||||
|
||||
@property
|
||||
def summaries(self):
|
||||
return self._summaries
|
||||
|
||||
def assert_summaries(self,
|
||||
test_case,
|
||||
expected_logdir=None,
|
||||
expected_graph=None,
|
||||
expected_summaries=None,
|
||||
expected_added_graphs=None,
|
||||
expected_added_meta_graphs=None,
|
||||
expected_session_logs=None):
|
||||
"""Assert expected items have been added to summary writer."""
|
||||
if expected_logdir is not None:
|
||||
test_case.assertEqual(expected_logdir, self._logdir)
|
||||
if expected_graph is not None:
|
||||
test_case.assertTrue(expected_graph is self._graph)
|
||||
expected_summaries = expected_summaries or {}
|
||||
for step in expected_summaries:
|
||||
test_case.assertTrue(
|
||||
step in self._summaries,
|
||||
msg='Missing step %s from %s.' % (step, self._summaries.keys()))
|
||||
actual_simple_values = {}
|
||||
for step_summary in self._summaries[step]:
|
||||
for v in step_summary.value:
|
||||
# Ignore global_step/sec since it's written by Supervisor in a
|
||||
# separate thread, so it's non-deterministic how many get written.
|
||||
if 'global_step/sec' != v.tag:
|
||||
actual_simple_values[v.tag] = v.simple_value
|
||||
test_case.assertEqual(expected_summaries[step], actual_simple_values)
|
||||
if expected_added_graphs is not None:
|
||||
test_case.assertEqual(expected_added_graphs, self._added_graphs)
|
||||
if expected_added_meta_graphs is not None:
|
||||
test_case.assertEqual(len(expected_added_meta_graphs),
|
||||
len(self._added_meta_graphs))
|
||||
for expected, actual in zip(expected_added_meta_graphs,
|
||||
self._added_meta_graphs):
|
||||
test_util.assert_meta_graph_protos_equal(test_case, expected, actual)
|
||||
if expected_session_logs is not None:
|
||||
test_case.assertEqual(expected_session_logs, self._added_session_logs)
|
||||
|
||||
def add_summary(self, summ, current_global_step):
|
||||
"""Add summary."""
|
||||
if isinstance(summ, bytes):
|
||||
summary_proto = summary_pb2.Summary()
|
||||
summary_proto.ParseFromString(summ)
|
||||
summ = summary_proto
|
||||
if current_global_step in self._summaries:
|
||||
step_summaries = self._summaries[current_global_step]
|
||||
else:
|
||||
step_summaries = []
|
||||
self._summaries[current_global_step] = step_summaries
|
||||
step_summaries.append(summ)
|
||||
|
||||
# NOTE: Ignore global_step since its value is non-deterministic.
|
||||
def add_graph(self, graph, global_step=None, graph_def=None):
|
||||
"""Add graph."""
|
||||
if (global_step is not None) and (global_step < 0):
|
||||
raise ValueError('Invalid global_step %s.' % global_step)
|
||||
if graph_def is not None:
|
||||
raise ValueError('Unexpected graph_def %s.' % graph_def)
|
||||
self._added_graphs.append(graph)
|
||||
|
||||
def add_meta_graph(self, meta_graph_def, global_step=None):
|
||||
"""Add metagraph."""
|
||||
if (global_step is not None) and (global_step < 0):
|
||||
raise ValueError('Invalid global_step %s.' % global_step)
|
||||
self._added_meta_graphs.append(meta_graph_def)
|
||||
|
||||
# NOTE: Ignore global_step since its value is non-deterministic.
|
||||
def add_session_log(self, session_log, global_step=None):
|
||||
# pylint: disable=unused-argument
|
||||
self._added_session_logs.append(session_log)
|
||||
|
||||
def add_run_metadata(self, run_metadata, tag, global_step=None):
|
||||
if (global_step is not None) and (global_step < 0):
|
||||
raise ValueError('Invalid global_step %s.' % global_step)
|
||||
self._added_run_metadata[tag] = run_metadata
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
def reopen(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
@ -24,7 +24,6 @@ import shutil
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
from tensorflow.contrib.testing.python.framework import fake_summary_writer
|
||||
from tensorflow.python.client import session as session_lib
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -43,6 +42,7 @@ from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.summary import summary as summary_lib
|
||||
from tensorflow.python.summary.writer import fake_summary_writer
|
||||
from tensorflow.python.summary.writer import writer_cache
|
||||
from tensorflow.python.training import basic_session_run_hooks
|
||||
from tensorflow.python.training import checkpoint_utils
|
||||
|
Loading…
Reference in New Issue
Block a user