diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 727d8e40eff..d27557133ee 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -33,12 +33,15 @@ import time import numpy as np import six +from tensorflow.core.framework import summary_pb2 from tensorflow.python.data.ops import iterator_ops from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.distribute import tpu_strategy from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.keras import backend as K from tensorflow.python.keras.distribute import distributed_file_utils @@ -1920,6 +1923,51 @@ class LearningRateScheduler(Callback): logs['lr'] = K.get_value(self.model.optimizer.lr) +def keras_model_summary(name, data, step=None): + """Writes a Keras model as JSON to as a Summary. + + Writing the Keras model configuration allows the TensorBoard graph plugin to + render a conceptual graph, as opposed to graph of ops. In case the model fails + to serialize as JSON, it ignores and returns False. + + Args: + name: A name for this summary. The summary tag used for TensorBoard will be + this name prefixed by any active name scopes. + data: A Keras Model to write. + step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + + Returns: + True on success, or False if no summary was written because no default + summary writer was available. + + Raises: + ValueError: if a default writer exists, but no step was provided and + `tf.summary.experimental.get_step()` is None. + """ + summary_metadata = summary_pb2.SummaryMetadata() + # Hard coding a plugin name. Please refer to go/tb-plugin-name-hardcode for + # the rationale. + summary_metadata.plugin_data.plugin_name = 'graph_keras_model' + # version number = 1 + summary_metadata.plugin_data.content = b'1' + + try: + json_string = data.to_json() + except Exception as exc: # pylint: disable=broad-except + # An exception should not break a model code. + logging.warn('Model failed to serialize as JSON. Ignoring... %s', exc) + return False + + with summary_ops_v2.summary_scope(name, 'graph_keras_model', + [data, step]) as (tag, _): + with ops.device('cpu:0'): + tensor = constant_op.constant(json_string, dtype=dtypes.string) + return summary_ops_v2.write( + tag=tag, tensor=tensor, step=step, metadata=summary_metadata) + + @keras_export('keras.callbacks.TensorBoard', v1=[]) class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): # pylint: disable=line-too-long @@ -2164,7 +2212,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): self.model._is_graph_network or # pylint: disable=protected-access self.model.__class__.__name__ == 'Sequential') # pylint: disable=protected-access if summary_writable: - summary_ops_v2.keras_model('keras', self.model, step=0) + keras_model_summary('keras', self.model, step=0) def _configure_embeddings(self): """Configure the Projector for embeddings.""" diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index 5cb33e73622..538f981d509 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -33,20 +33,25 @@ from absl.testing import parameterized import numpy as np from tensorflow.core.framework import summary_pb2 +from tensorflow.core.util import event_pb2 from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import sequential +from tensorflow.python.keras.layers import Activation +from tensorflow.python.keras.layers import Dense from tensorflow.python.keras.optimizer_v2 import gradient_descent from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule from tensorflow.python.keras.utils import np_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import summary_ops_v2 +from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import save_options as save_options_lib @@ -2617,5 +2622,117 @@ class MostRecentlyModifiedFileMatchingPatternTest(test.TestCase): ckpt_file_path) +class SummaryOpsTest(test.TestCase): + + def tearDown(self): + super(SummaryOpsTest, self).tearDown() + summary_ops_v2.trace_off() + + def keras_model(self, *args, **kwargs): + logdir = self.get_temp_dir() + writer = summary_ops_v2.create_file_writer_v2(logdir) + with writer.as_default(): + keras.callbacks.keras_model_summary(*args, **kwargs) + writer.close() + events = events_from_logdir(logdir) + # The first event contains no summary values. The written content goes to + # the second event. + return events[1] + + @testing_utils.run_v2_only + def testKerasModel(self): + model = keras.Sequential( + [Dense(10, input_shape=(100,)), + Activation('relu', name='my_relu')]) + event = self.keras_model(name='my_name', data=model, step=1) + first_val = event.summary.value[0] + self.assertEqual(model.to_json(), first_val.tensor.string_val[0].decode()) + + @testing_utils.run_v2_only + def testKerasModel_usesDefaultStep(self): + model = keras.Sequential( + [Dense(10, input_shape=(100,)), + Activation('relu', name='my_relu')]) + try: + summary_ops_v2.set_step(42) + event = self.keras_model(name='my_name', data=model) + self.assertEqual(42, event.step) + finally: + # Reset to default state for other tests. + summary_ops_v2.set_step(None) + + @testing_utils.run_v2_only + def testKerasModel_subclass(self): + + class SimpleSubclass(keras.Model): + + def __init__(self): + super(SimpleSubclass, self).__init__(name='subclass') + self.dense = Dense(10, input_shape=(100,)) + self.activation = Activation('relu', name='my_relu') + + def call(self, inputs): + x = self.dense(inputs) + return self.activation(x) + + model = SimpleSubclass() + with test.mock.patch.object(logging, 'warn') as mock_log: + self.assertFalse( + keras.callbacks.keras_model_summary( + name='my_name', data=model, step=1)) + self.assertRegex( + str(mock_log.call_args), 'Model failed to serialize as JSON.') + + @testing_utils.run_v2_only + def testKerasModel_otherExceptions(self): + model = keras.Sequential() + + with test.mock.patch.object(model, 'to_json') as mock_to_json: + with test.mock.patch.object(logging, 'warn') as mock_log: + mock_to_json.side_effect = Exception('oops') + self.assertFalse( + keras.callbacks.keras_model_summary( + name='my_name', data=model, step=1)) + self.assertRegex( + str(mock_log.call_args), + 'Model failed to serialize as JSON. Ignoring') + + +def events_from_file(filepath): + """Returns all events in a single event file. + + Args: + filepath: Path to the event file. + + Returns: + A list of all tf.Event protos in the event file. + """ + result = [] + raw_dataset = readers.TFRecordDatasetV2([filepath]) + for raw_record in raw_dataset.take(10): + event = event_pb2.Event() + event.ParseFromString(raw_record.numpy()) + result.append(event) + return result + + +def events_from_logdir(logdir): + """Returns all events in the single eventfile in logdir. + + Args: + logdir: The directory in which the single event file is sought. + + Returns: + A list of all tf.Event protos from the single event file. + + Raises: + AssertionError: If logdir does not contain exactly one file. + """ + assert gfile.Exists(logdir) + files = gfile.ListDirectory(logdir) + assert len(files) == 1, 'Found not exactly one file in logdir: %s' % files + return events_from_file(os.path.join(logdir, files[0])) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/tests/BUILD b/tensorflow/python/keras/tests/BUILD index d9f1710e24c..e7db6c45e04 100644 --- a/tensorflow/python/keras/tests/BUILD +++ b/tensorflow/python/keras/tests/BUILD @@ -279,23 +279,6 @@ tf_py_test( ], ) -cuda_py_test( - name = "summary_ops_test", - size = "small", - srcs = ["summary_ops_test.py"], - deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:lib", - "//tensorflow/python:platform", - "//tensorflow/python:summary_ops_v2", - "//tensorflow/python/keras:testing_utils", - "//tensorflow/python/keras/engine", - "//tensorflow/python/keras/layers:core", - ], -) - tf_py_test( name = "saved_model_test", size = "small", diff --git a/tensorflow/python/keras/tests/summary_ops_test.py b/tensorflow/python/keras/tests/summary_ops_test.py deleted file mode 100644 index e0e91c0d759..00000000000 --- a/tensorflow/python/keras/tests/summary_ops_test.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for V2 summary ops from summary_ops_v2.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os - -from tensorflow.core.util import event_pb2 -from tensorflow.python.keras import testing_utils -from tensorflow.python.keras.engine.sequential import Sequential -from tensorflow.python.keras.engine.training import Model -from tensorflow.python.keras.layers.core import Activation -from tensorflow.python.keras.layers.core import Dense -from tensorflow.python.lib.io import tf_record -from tensorflow.python.ops import summary_ops_v2 as summary_ops -from tensorflow.python.platform import gfile -from tensorflow.python.platform import test -from tensorflow.python.platform import tf_logging as logging - - -class SummaryOpsTest(test.TestCase): - - def tearDown(self): - super(SummaryOpsTest, self).tearDown() - summary_ops.trace_off() - - def keras_model(self, *args, **kwargs): - logdir = self.get_temp_dir() - writer = summary_ops.create_file_writer_v2(logdir) - with writer.as_default(): - summary_ops.keras_model(*args, **kwargs) - writer.close() - events = events_from_logdir(logdir) - # The first event contains no summary values. The written content goes to - # the second event. - return events[1] - - @testing_utils.run_v2_only - def testKerasModel(self): - model = Sequential( - [Dense(10, input_shape=(100,)), - Activation('relu', name='my_relu')]) - event = self.keras_model(name='my_name', data=model, step=1) - first_val = event.summary.value[0] - self.assertEqual(model.to_json(), first_val.tensor.string_val[0].decode()) - - @testing_utils.run_v2_only - def testKerasModel_usesDefaultStep(self): - model = Sequential( - [Dense(10, input_shape=(100,)), - Activation('relu', name='my_relu')]) - try: - summary_ops.set_step(42) - event = self.keras_model(name='my_name', data=model) - self.assertEqual(42, event.step) - finally: - # Reset to default state for other tests. - summary_ops.set_step(None) - - @testing_utils.run_v2_only - def testKerasModel_subclass(self): - - class SimpleSubclass(Model): - - def __init__(self): - super(SimpleSubclass, self).__init__(name='subclass') - self.dense = Dense(10, input_shape=(100,)) - self.activation = Activation('relu', name='my_relu') - - def call(self, inputs): - x = self.dense(inputs) - return self.activation(x) - - model = SimpleSubclass() - with test.mock.patch.object(logging, 'warn') as mock_log: - self.assertFalse( - summary_ops.keras_model(name='my_name', data=model, step=1)) - self.assertRegex( - str(mock_log.call_args), 'Model failed to serialize as JSON.') - - @testing_utils.run_v2_only - def testKerasModel_otherExceptions(self): - model = Sequential() - - with test.mock.patch.object(model, 'to_json') as mock_to_json: - with test.mock.patch.object(logging, 'warn') as mock_log: - mock_to_json.side_effect = Exception('oops') - self.assertFalse( - summary_ops.keras_model(name='my_name', data=model, step=1)) - self.assertRegex( - str(mock_log.call_args), - 'Model failed to serialize as JSON. Ignoring... oops') - - -def events_from_file(filepath): - """Returns all events in a single event file. - - Args: - filepath: Path to the event file. - - Returns: - A list of all tf.Event protos in the event file. - """ - records = list(tf_record.tf_record_iterator(filepath)) - result = [] - for r in records: - event = event_pb2.Event() - event.ParseFromString(r) - result.append(event) - return result - - -def events_from_logdir(logdir): - """Returns all events in the single eventfile in logdir. - - Args: - logdir: The directory in which the single event file is sought. - - Returns: - A list of all tf.Event protos from the single event file. - - Raises: - AssertionError: If logdir does not contain exactly one file. - """ - assert gfile.Exists(logdir) - files = gfile.ListDirectory(logdir) - assert len(files) == 1, 'Found not exactly one file in logdir: %s' % files - return events_from_file(os.path.join(logdir, files[0])) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py index db9227c97cb..7b020d2b9f0 100644 --- a/tensorflow/python/ops/summary_ops_v2.py +++ b/tensorflow/python/ops/summary_ops_v2.py @@ -1202,53 +1202,6 @@ def run_metadata_graphs(name, data, step=None): metadata=summary_metadata) -def keras_model(name, data, step=None): - """Writes a Keras model as JSON to as a Summary. - - Writing the Keras model configuration allows the TensorBoard graph plugin to - render a conceptual graph, as opposed to graph of ops. In case the model fails - to serialize as JSON, it ignores and returns False. - - Args: - name: A name for this summary. The summary tag used for TensorBoard will be - this name prefixed by any active name scopes. - data: A Keras Model to write. - step: Explicit `int64`-castable monotonic step value for this summary. If - omitted, this defaults to `tf.summary.experimental.get_step()`, which must - not be None. - - Returns: - True on success, or False if no summary was written because no default - summary writer was available. - - Raises: - ValueError: if a default writer exists, but no step was provided and - `tf.summary.experimental.get_step()` is None. - """ - summary_metadata = summary_pb2.SummaryMetadata() - # Hard coding a plugin name. Please refer to go/tb-plugin-name-hardcode for - # the rationale. - summary_metadata.plugin_data.plugin_name = "graph_keras_model" - # version number = 1 - summary_metadata.plugin_data.content = b"1" - - try: - json_string = data.to_json() - except Exception as exc: # pylint: disable=broad-except - # An exception should not break a model code. - logging.warn("Model failed to serialize as JSON. Ignoring... %s" % exc) - return False - - with summary_scope(name, "graph_keras_model", [data, step]) as (tag, _): - with ops.device("cpu:0"): - tensor = constant_op.constant(json_string, dtype=dtypes.string) - return write( - tag=tag, - tensor=tensor, - step=step, - metadata=summary_metadata) - - _TraceContext = collections.namedtuple("TraceContext", ("graph", "profiler")) _current_trace_context_lock = threading.Lock() _current_trace_context = None