Add contrib/testing.
Change: 115578243
This commit is contained in:
parent
9ccc4b6afe
commit
d1aed6505a
@ -14,6 +14,7 @@ py_library(
|
||||
deps = [
|
||||
"//tensorflow/contrib/layers:layers_py",
|
||||
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
|
||||
"//tensorflow/contrib/testing:testing_py",
|
||||
"//tensorflow/contrib/util:util_py",
|
||||
],
|
||||
)
|
||||
|
@ -22,4 +22,5 @@ from __future__ import print_function
|
||||
# Add projects here, they will show up under tf.contrib.
|
||||
from tensorflow.contrib import layers
|
||||
from tensorflow.contrib import linear_optimizer
|
||||
from tensorflow.contrib import testing
|
||||
from tensorflow.contrib import util
|
||||
|
29
tensorflow/contrib/testing/BUILD
Normal file
29
tensorflow/contrib/testing/BUILD
Normal file
@ -0,0 +1,29 @@
|
||||
# Description:
|
||||
# contains parts of TensorFlow that are experimental or unstable and which are not supported.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
py_library(
|
||||
name = "testing_py",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"python/framework/test_util.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
22
tensorflow/contrib/testing/__init__.py
Normal file
22
tensorflow/contrib/testing/__init__.py
Normal file
@ -0,0 +1,22 @@
|
||||
# Copyright 2015 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Testing utilities."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import
|
||||
from tensorflow.contrib.testing.python.framework.test_util import *
|
118
tensorflow/contrib/testing/python/framework/test_util.py
Normal file
118
tensorflow/contrib/testing/python/framework/test_util.py
Normal file
@ -0,0 +1,118 @@
|
||||
"""Test utilities."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import glob
|
||||
import os
|
||||
import numpy as np
|
||||
from tensorflow.core.framework import summary_pb2
|
||||
from tensorflow.python.platform import logging
|
||||
from tensorflow.python.training import summary_io
|
||||
|
||||
|
||||
def assert_summary(expected_tags, expected_simple_values, summary_proto):
|
||||
"""Asserts summary contains the specified tags and values.
|
||||
|
||||
Args:
|
||||
expected_tags: All tags in summary.
|
||||
expected_simple_values: Simply values for some tags.
|
||||
summary_proto: Summary to validate.
|
||||
|
||||
Raises:
|
||||
ValueError: if expectations are not met.
|
||||
"""
|
||||
actual_tags = set()
|
||||
for value in summary_proto.value:
|
||||
actual_tags.add(value.tag)
|
||||
if value.tag in expected_simple_values:
|
||||
expected = expected_simple_values[value.tag]
|
||||
actual = value.simple_value
|
||||
np.testing.assert_almost_equal(
|
||||
actual, expected, decimal=2, err_msg=value.tag)
|
||||
expected_tags = set(expected_tags)
|
||||
if expected_tags != actual_tags:
|
||||
raise ValueError('Expected tags %s, got %s.' % (expected_tags, actual_tags))
|
||||
|
||||
|
||||
def to_summary_proto(summary_str):
|
||||
"""Create summary based on latest stats.
|
||||
|
||||
Args:
|
||||
summary_str: Serialized summary.
|
||||
Returns:
|
||||
summary_pb2.Summary.
|
||||
Raises:
|
||||
ValueError: if tensor is not a valid summary tensor.
|
||||
"""
|
||||
summary = summary_pb2.Summary()
|
||||
summary.ParseFromString(summary_str)
|
||||
return summary
|
||||
|
||||
|
||||
# TODO(ptucker): Move to a non-test package?
|
||||
def latest_event_file(base_dir):
|
||||
"""Find latest event file in `base_dir`.
|
||||
|
||||
Args:
|
||||
base_dir: Base directory in which TF event flies are stored.
|
||||
Returns:
|
||||
File path, or `None` if none exists.
|
||||
"""
|
||||
file_paths = glob.glob(os.path.join(base_dir, 'events.*'))
|
||||
return sorted(file_paths)[-1] if file_paths else None
|
||||
|
||||
|
||||
def latest_events(base_dir):
|
||||
"""Parse events from latest event file in base_dir.
|
||||
|
||||
Args:
|
||||
base_dir: Base directory in which TF event flies are stored.
|
||||
Returns:
|
||||
Iterable of event protos.
|
||||
Raises:
|
||||
ValueError: if no event files exist under base_dir.
|
||||
"""
|
||||
file_path = latest_event_file(base_dir)
|
||||
return summary_io.summary_iterator(file_path) if file_path else []
|
||||
|
||||
|
||||
def latest_summaries(base_dir):
|
||||
"""Parse summary events from latest event file in base_dir.
|
||||
|
||||
Args:
|
||||
base_dir: Base directory in which TF event flies are stored.
|
||||
Returns:
|
||||
List of event protos.
|
||||
Raises:
|
||||
ValueError: if no event files exist under base_dir.
|
||||
"""
|
||||
return [e for e in latest_events(base_dir) if e.HasField('summary')]
|
||||
|
||||
|
||||
def simple_values_from_events(events, tags):
|
||||
"""Parse summaries from events with simple_value.
|
||||
|
||||
Args:
|
||||
events: List of tensorflow.Event protos.
|
||||
tags: List of string event tags corresponding to simple_value summaries.
|
||||
Returns:
|
||||
dict of tag:value.
|
||||
Raises:
|
||||
ValueError: if a summary with a specified tag does not contain simple_value.
|
||||
"""
|
||||
step_by_tag = {}
|
||||
value_by_tag = {}
|
||||
for e in events:
|
||||
if e.HasField('summary'):
|
||||
for v in e.summary.value:
|
||||
tag = v.tag
|
||||
if tag in tags:
|
||||
if not v.HasField('simple_value'):
|
||||
raise ValueError('Summary for %s is not a simple_value.' % tag)
|
||||
# The events are mostly sorted in step order, but we explicitly check
|
||||
# just in case.
|
||||
if tag not in step_by_tag or e.step > step_by_tag[tag]:
|
||||
step_by_tag[tag] = e.step
|
||||
value_by_tag[tag] = v.simple_value
|
||||
return value_by_tag
|
||||
|
Loading…
x
Reference in New Issue
Block a user