Internal change

PiperOrigin-RevId: 340297114
Change-Id: Ib67ea44b245f8fd66cd17c7b68f26260391793a9
Pavithra Vijay 2020-11-02 12:55:50 -08:00 committed by TensorFlower Gardener
parent 2ae6b87fe5
commit f8ba2a8d9b
5 changed files with 166 additions and 212 deletions


@@ -33,12 +33,15 @@ import time
import numpy as np
import six
from tensorflow.core.framework import summary_pb2
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.distribute import distributed_file_utils
@@ -1920,6 +1923,51 @@ class LearningRateScheduler(Callback):
logs['lr'] = K.get_value(self.model.optimizer.lr)
def keras_model_summary(name, data, step=None):
"""Writes a Keras model as JSON to as a Summary.
Writing the Keras model configuration allows the TensorBoard graph plugin to
render a conceptual graph, as opposed to graph of ops. In case the model fails
to serialize as JSON, it ignores and returns False.
Args:
name: A name for this summary. The summary tag used for TensorBoard will be
this name prefixed by any active name scopes.
data: A Keras Model to write.
step: Explicit `int64`-castable monotonic step value for this summary. If
omitted, this defaults to `tf.summary.experimental.get_step()`, which must
not be None.
Returns:
True on success, or False if no summary was written because no default
summary writer was available.
Raises:
ValueError: if a default writer exists, but no step was provided and
`tf.summary.experimental.get_step()` is None.
"""
summary_metadata = summary_pb2.SummaryMetadata()
# Hard coding a plugin name. Please refer to go/tb-plugin-name-hardcode for
# the rationale.
summary_metadata.plugin_data.plugin_name = 'graph_keras_model'
# version number = 1
summary_metadata.plugin_data.content = b'1'
try:
json_string = data.to_json()
except Exception as exc: # pylint: disable=broad-except
# An exception here should not break the model code.
logging.warn('Model failed to serialize as JSON. Ignoring... %s', exc)
return False
with summary_ops_v2.summary_scope(name, 'graph_keras_model',
[data, step]) as (tag, _):
with ops.device('cpu:0'):
tensor = constant_op.constant(json_string, dtype=dtypes.string)
return summary_ops_v2.write(
tag=tag, tensor=tensor, step=step, metadata=summary_metadata)
@keras_export('keras.callbacks.TensorBoard', v1=[])
class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
# pylint: disable=line-too-long
@@ -2164,7 +2212,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
self.model._is_graph_network or # pylint: disable=protected-access
self.model.__class__.__name__ == 'Sequential') # pylint: disable=protected-access
if summary_writable:
summary_ops_v2.keras_model('keras', self.model, step=0)  # removed
keras_model_summary('keras', self.model, step=0)  # added
def _configure_embeddings(self):
"""Configure the Projector for embeddings."""


@@ -33,20 +33,25 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.core.framework import summary_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.layers import Activation
from tensorflow.python.keras.layers import Dense
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.utils import np_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import save_options as save_options_lib
@@ -2617,5 +2622,117 @@ class MostRecentlyModifiedFileMatchingPatternTest(test.TestCase):
ckpt_file_path)
class SummaryOpsTest(test.TestCase):
def tearDown(self):
super(SummaryOpsTest, self).tearDown()
summary_ops_v2.trace_off()
def keras_model(self, *args, **kwargs):
logdir = self.get_temp_dir()
writer = summary_ops_v2.create_file_writer_v2(logdir)
with writer.as_default():
keras.callbacks.keras_model_summary(*args, **kwargs)
writer.close()
events = events_from_logdir(logdir)
# The first event contains no summary values. The written content goes to
# the second event.
return events[1]
@testing_utils.run_v2_only
def testKerasModel(self):
model = keras.Sequential(
[Dense(10, input_shape=(100,)),
Activation('relu', name='my_relu')])
event = self.keras_model(name='my_name', data=model, step=1)
first_val = event.summary.value[0]
self.assertEqual(model.to_json(), first_val.tensor.string_val[0].decode())
@testing_utils.run_v2_only
def testKerasModel_usesDefaultStep(self):
model = keras.Sequential(
[Dense(10, input_shape=(100,)),
Activation('relu', name='my_relu')])
try:
summary_ops_v2.set_step(42)
event = self.keras_model(name='my_name', data=model)
self.assertEqual(42, event.step)
finally:
# Reset to default state for other tests.
summary_ops_v2.set_step(None)
@testing_utils.run_v2_only
def testKerasModel_subclass(self):
class SimpleSubclass(keras.Model):
def __init__(self):
super(SimpleSubclass, self).__init__(name='subclass')
self.dense = Dense(10, input_shape=(100,))
self.activation = Activation('relu', name='my_relu')
def call(self, inputs):
x = self.dense(inputs)
return self.activation(x)
model = SimpleSubclass()
with test.mock.patch.object(logging, 'warn') as mock_log:
self.assertFalse(
keras.callbacks.keras_model_summary(
name='my_name', data=model, step=1))
self.assertRegex(
str(mock_log.call_args), 'Model failed to serialize as JSON.')
@testing_utils.run_v2_only
def testKerasModel_otherExceptions(self):
model = keras.Sequential()
with test.mock.patch.object(model, 'to_json') as mock_to_json:
with test.mock.patch.object(logging, 'warn') as mock_log:
mock_to_json.side_effect = Exception('oops')
self.assertFalse(
keras.callbacks.keras_model_summary(
name='my_name', data=model, step=1))
self.assertRegex(
str(mock_log.call_args),
'Model failed to serialize as JSON. Ignoring')
def events_from_file(filepath):
"""Returns all events in a single event file.
Args:
filepath: Path to the event file.
Returns:
A list of all tf.Event protos in the event file.
"""
result = []
raw_dataset = readers.TFRecordDatasetV2([filepath])
for raw_record in raw_dataset.take(10):
event = event_pb2.Event()
event.ParseFromString(raw_record.numpy())
result.append(event)
return result
def events_from_logdir(logdir):
"""Returns all events in the single eventfile in logdir.
Args:
logdir: The directory in which the single event file is sought.
Returns:
A list of all tf.Event protos from the single event file.
Raises:
AssertionError: If logdir does not contain exactly one file.
"""
assert gfile.Exists(logdir)
files = gfile.ListDirectory(logdir)
assert len(files) == 1, 'Found not exactly one file in logdir: %s' % files
return events_from_file(os.path.join(logdir, files[0]))
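As an aside, a hypothetical sketch of how the written event could be read back with these helpers (mirroring testKerasModel above; the logdir is illustrative, not part of this commit):

# Hypothetical read-back sketch: recover the model config from the event file.
event = events_from_logdir('/tmp/logdir')[1]  # events[0] carries no summary values
json_config = event.summary.value[0].tensor.string_val[0].decode()
restored = keras.models.model_from_json(json_config)  # architecture only, no weights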
if __name__ == '__main__':
test.main()


@@ -279,23 +279,6 @@ tf_py_test(
],
)
cuda_py_test(
name = "summary_ops_test",
size = "small",
srcs = ["summary_ops_test.py"],
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:lib",
"//tensorflow/python:platform",
"//tensorflow/python:summary_ops_v2",
"//tensorflow/python/keras:testing_utils",
"//tensorflow/python/keras/engine",
"//tensorflow/python/keras/layers:core",
],
)
tf_py_test(
name = "saved_model_test",
size = "small",


@@ -1,147 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for V2 summary ops from summary_ops_v2."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from tensorflow.core.util import event_pb2
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine.sequential import Sequential
from tensorflow.python.keras.engine.training import Model
from tensorflow.python.keras.layers.core import Activation
from tensorflow.python.keras.layers.core import Dense
from tensorflow.python.lib.io import tf_record
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
class SummaryOpsTest(test.TestCase):
def tearDown(self):
super(SummaryOpsTest, self).tearDown()
summary_ops.trace_off()
def keras_model(self, *args, **kwargs):
logdir = self.get_temp_dir()
writer = summary_ops.create_file_writer_v2(logdir)
with writer.as_default():
summary_ops.keras_model(*args, **kwargs)
writer.close()
events = events_from_logdir(logdir)
# The first event contains no summary values. The written content goes to
# the second event.
return events[1]
@testing_utils.run_v2_only
def testKerasModel(self):
model = Sequential(
[Dense(10, input_shape=(100,)),
Activation('relu', name='my_relu')])
event = self.keras_model(name='my_name', data=model, step=1)
first_val = event.summary.value[0]
self.assertEqual(model.to_json(), first_val.tensor.string_val[0].decode())
@testing_utils.run_v2_only
def testKerasModel_usesDefaultStep(self):
model = Sequential(
[Dense(10, input_shape=(100,)),
Activation('relu', name='my_relu')])
try:
summary_ops.set_step(42)
event = self.keras_model(name='my_name', data=model)
self.assertEqual(42, event.step)
finally:
# Reset to default state for other tests.
summary_ops.set_step(None)
@testing_utils.run_v2_only
def testKerasModel_subclass(self):
class SimpleSubclass(Model):
def __init__(self):
super(SimpleSubclass, self).__init__(name='subclass')
self.dense = Dense(10, input_shape=(100,))
self.activation = Activation('relu', name='my_relu')
def call(self, inputs):
x = self.dense(inputs)
return self.activation(x)
model = SimpleSubclass()
with test.mock.patch.object(logging, 'warn') as mock_log:
self.assertFalse(
summary_ops.keras_model(name='my_name', data=model, step=1))
self.assertRegex(
str(mock_log.call_args), 'Model failed to serialize as JSON.')
@testing_utils.run_v2_only
def testKerasModel_otherExceptions(self):
model = Sequential()
with test.mock.patch.object(model, 'to_json') as mock_to_json:
with test.mock.patch.object(logging, 'warn') as mock_log:
mock_to_json.side_effect = Exception('oops')
self.assertFalse(
summary_ops.keras_model(name='my_name', data=model, step=1))
self.assertRegex(
str(mock_log.call_args),
'Model failed to serialize as JSON. Ignoring... oops')
def events_from_file(filepath):
"""Returns all events in a single event file.
Args:
filepath: Path to the event file.
Returns:
A list of all tf.Event protos in the event file.
"""
records = list(tf_record.tf_record_iterator(filepath))
result = []
for r in records:
event = event_pb2.Event()
event.ParseFromString(r)
result.append(event)
return result
def events_from_logdir(logdir):
"""Returns all events in the single eventfile in logdir.
Args:
logdir: The directory in which the single event file is sought.
Returns:
A list of all tf.Event protos from the single event file.
Raises:
AssertionError: If logdir does not contain exactly one file.
"""
assert gfile.Exists(logdir)
files = gfile.ListDirectory(logdir)
assert len(files) == 1, 'Found not exactly one file in logdir: %s' % files
return events_from_file(os.path.join(logdir, files[0]))
if __name__ == '__main__':
test.main()


@@ -1202,53 +1202,6 @@ def run_metadata_graphs(name, data, step=None):
metadata=summary_metadata)
def keras_model(name, data, step=None):
"""Writes a Keras model as JSON to as a Summary.
Writing the Keras model configuration allows the TensorBoard graph plugin to
render a conceptual graph, as opposed to graph of ops. In case the model fails
to serialize as JSON, it ignores and returns False.
Args:
name: A name for this summary. The summary tag used for TensorBoard will be
this name prefixed by any active name scopes.
data: A Keras Model to write.
step: Explicit `int64`-castable monotonic step value for this summary. If
omitted, this defaults to `tf.summary.experimental.get_step()`, which must
not be None.
Returns:
True on success, or False if no summary was written because no default
summary writer was available.
Raises:
ValueError: if a default writer exists, but no step was provided and
`tf.summary.experimental.get_step()` is None.
"""
summary_metadata = summary_pb2.SummaryMetadata()
# Hard coding a plugin name. Please refer to go/tb-plugin-name-hardcode for
# the rationale.
summary_metadata.plugin_data.plugin_name = "graph_keras_model"
# version number = 1
summary_metadata.plugin_data.content = b"1"
try:
json_string = data.to_json()
except Exception as exc: # pylint: disable=broad-except
# An exception should not break a model code.
logging.warn("Model failed to serialize as JSON. Ignoring... %s" % exc)
return False
with summary_scope(name, "graph_keras_model", [data, step]) as (tag, _):
with ops.device("cpu:0"):
tensor = constant_op.constant(json_string, dtype=dtypes.string)
return write(
tag=tag,
tensor=tensor,
step=step,
metadata=summary_metadata)
_TraceContext = collections.namedtuple("TraceContext", ("graph", "profiler"))
_current_trace_context_lock = threading.Lock()
_current_trace_context = None