Partial update of tf.keras to the Keras 2.2.0 API.

Changes included are:
- Embedding visualization added to the TensorBoard callback (ported from the
  older Keras API); see the usage sketch below.
- Fix: learning-phase info was being left out for multi-input models (ported
  from the older Keras API).
- Fix: the TensorBoard callback previously supported logging only Embedding
  layer weights.
- Fix: the TensorBoard callback failed on layers with multiple outputs.
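A minimal usage sketch of the new options (illustrative only: the toy model,
the data, the layer name 'word_embedding', and 'metadata.tsv' are assumptions,
not part of this commit):

    import numpy as np
    import tensorflow as tf

    # Toy data: integer word ids in, binary label out.
    x_train = np.random.randint(0, 1000, size=(256, 10))
    y_train = np.random.randint(0, 2, size=(256, 1))

    model = tf.keras.Sequential([
        tf.keras.layers.Embedding(1000, 16, name='word_embedding'),
        tf.keras.layers.GlobalAveragePooling1D(),
        tf.keras.layers.Dense(1, activation='sigmoid'),
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy')

    tensorboard = tf.keras.callbacks.TensorBoard(
        log_dir='./logs',
        embeddings_freq=1,  # save embeddings every epoch; 0 disables them
        embeddings_layer_names=['word_embedding'],  # None watches all Embedding layers
        embeddings_metadata='metadata.tsv',  # one shared file; a dict maps per layer
        embeddings_data=x_train)  # required whenever embeddings_freq > 0
    model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard])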

PiperOrigin-RevId: 204659796
Commit: e6ce9ea5a1 (parent: fe7d1d9447)
Author: Pavithra Vijay, 2018-07-15 11:28:49 -07:00; committed by TensorFlower Gardener
2 changed files with 135 additions and 4 deletions

File: tensorflow/python/keras/callbacks.py

@@ -31,11 +31,16 @@ import time
 
 import numpy as np
 import six
 
+from tensorflow.python.framework import dtypes
 from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.engine.training_utils import standardize_input_data
 from tensorflow.python.keras.utils.generic_utils import Progbar
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.summary import summary as tf_summary
+from tensorflow.python.training import saver
 from tensorflow.python.util.tf_export import tf_export
@@ -697,7 +702,9 @@ class TensorBoard(Callback):
       write_images: whether to write model weights to visualize as
           image in TensorBoard.
       embeddings_freq: frequency (in epochs) at which selected embedding
-          layers will be saved.
+          layers will be saved. If set to 0, embeddings won't be computed.
+          Data to be visualized in TensorBoard's Embedding tab must be passed
+          as `embeddings_data`.
       embeddings_layer_names: a list of names of layers to keep eye on. If
           None or empty list all the embedding layer will be watched.
       embeddings_metadata: a dictionary which maps layer name to a file name
@@ -705,6 +712,10 @@ class TensorBoard(Callback):
           [details](https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
           about metadata files format. In case if the same metadata file is
           used for all embedding layers, string can be passed.
+      embeddings_data: data to be embedded at layers specified in
+          `embeddings_layer_names`. Numpy array (if the model has a single
+          input) or list of Numpy arrays (if the model has multiple inputs).
+          Learn [more about embeddings](https://www.tensorflow.org/programmers_guide/embedding)
   """
   # pylint: enable=line-too-long
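For reference, the two accepted forms of the metadata argument look like this
(the layer and file names are hypothetical):

    # One metadata file shared by every embedding layer:
    embeddings_metadata = 'metadata.tsv'
    # Or a per-layer mapping:
    embeddings_metadata = {'word_embedding': 'word_metadata.tsv'}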
@@ -715,7 +726,11 @@ class TensorBoard(Callback):
                batch_size=32,
                write_graph=True,
                write_grads=False,
-               write_images=False):
+               write_images=False,
+               embeddings_freq=0,
+               embeddings_layer_names=None,
+               embeddings_metadata=None,
+               embeddings_data=None):
     super(TensorBoard, self).__init__()
     self.log_dir = log_dir
     self.histogram_freq = histogram_freq
@@ -727,6 +742,10 @@ class TensorBoard(Callback):
     self._current_batch = 0
     # abstracted writer class to be able to stub for testing
     self._writer_class = tf_summary.FileWriter
+    self.embeddings_freq = embeddings_freq
+    self.embeddings_layer_names = embeddings_layer_names
+    self.embeddings_metadata = embeddings_metadata
+    self.embeddings_data = embeddings_data
 
   def set_model(self, model):
     """Sets Keras model and creates summary ops."""
@@ -778,7 +797,11 @@ class TensorBoard(Callback):
             tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads)
         if hasattr(layer, 'output'):
-          tf_summary.histogram('{}_out'.format(layer.name), layer.output)
+          if isinstance(layer.output, list):
+            for i, output in enumerate(layer.output):
+              tf_summary.histogram('{}_out_{}'.format(layer.name, i), output)
+          else:
+            tf_summary.histogram('{}_out'.format(layer.name), layer.output)
 
     self.merged = tf_summary.merge_all()
 
     if self.write_graph:
@@ -786,6 +809,74 @@ class TensorBoard(Callback):
     else:
       self.writer = self._writer_class(self.log_dir)
 
+    # If both embeddings_freq and embeddings_data are available, we will
+    # visualize embeddings.
+    if self.embeddings_freq and self.embeddings_data is not None:
+      self.embeddings_data = standardize_input_data(self.embeddings_data,
+                                                    model.input_names)
+
+      # If embeddings_layer_names are not provided, get all of the embedding
+      # layers from the model.
+      embeddings_layer_names = self.embeddings_layer_names
+      if not embeddings_layer_names:
+        embeddings_layer_names = [
+            layer.name
+            for layer in self.model.layers
+            if type(layer).__name__ == 'Embedding'
+        ]
+
+      self.assign_embeddings = []
+      embeddings_vars = {}
+
+      self.batch_id = batch_id = array_ops.placeholder(dtypes.int32)
+      self.step = step = array_ops.placeholder(dtypes.int32)
+
+      for layer in self.model.layers:
+        if layer.name in embeddings_layer_names:
+          embedding_input = self.model.get_layer(layer.name).output
+          embedding_size = np.prod(embedding_input.shape[1:])
+          embedding_input = array_ops.reshape(embedding_input,
+                                              (step, int(embedding_size)))
+          shape = (self.embeddings_data[0].shape[0], int(embedding_size))
+          embedding = variables.Variable(
+              array_ops.zeros(shape), name=layer.name + '_embedding')
+          embeddings_vars[layer.name] = embedding
+          batch = state_ops.assign(embedding[batch_id:batch_id + step],
+                                   embedding_input)
+          self.assign_embeddings.append(batch)
+
+      self.saver = saver.Saver(list(embeddings_vars.values()))
+
+      # Create embeddings_metadata dictionary
+      if isinstance(self.embeddings_metadata, str):
+        embeddings_metadata = {
+            layer_name: self.embeddings_metadata
+            for layer_name in embeddings_vars.keys()
+        }
+      else:
+        # If embeddings_metadata is already a dictionary
+        embeddings_metadata = self.embeddings_metadata
+
+      try:
+        from tensorboard.plugins import projector
+      except ImportError:
+        raise ImportError('Failed to import TensorBoard. Please make sure that '
+                          'TensorBoard integration is complete.')
+
+      # TODO(psv): Add integration tests to test embedding visualization
+      # with TensorBoard callback. We are unable to write a unit test for this
+      # because TensorBoard dependency assumes TensorFlow package is installed.
+      config = projector.ProjectorConfig()
+      for layer_name, tensor in embeddings_vars.items():
+        embedding = config.embeddings.add()
+        embedding.tensor_name = tensor.name
+
+        if (embeddings_metadata is not None and
+            layer_name in embeddings_metadata):
+          embedding.metadata_path = embeddings_metadata[layer_name]
+
+      projector.visualize_embeddings(self.writer, config)
+
   def _fetch_callback(self, summary):
     self.writer.add_summary(
         summary,
@@ -833,6 +924,46 @@ class TensorBoard(Callback):
       if self.merged in self.model.test_function.fetch_callbacks:
         self.model.test_function.fetch_callbacks.pop(self.merged)
 
+    if self.embeddings_data is None and self.embeddings_freq:
+      raise ValueError('To visualize embeddings, embeddings_data must '
+                       'be provided.')
+
+    if self.embeddings_freq and self.embeddings_data is not None:
+      if epoch % self.embeddings_freq == 0:
+        # We need a second forward-pass here because we're passing
+        # the `embeddings_data` explicitly. This design allows us to pass
+        # arbitrary data as `embeddings_data` and results from the fact
+        # that we need to know the size of the `tf.Variable`s which
+        # hold the embeddings in `set_model`. At this point, however,
+        # the `validation_data` is not yet set.
+        embeddings_data = self.embeddings_data
+        n_samples = embeddings_data[0].shape[0]
+        i = 0
+        while i < n_samples:
+          step = min(self.batch_size, n_samples - i)
+          batch = slice(i, i + step)
+
+          if isinstance(self.model.input, list):
+            feed_dict = {
+                model_input: embeddings_data[idx][batch]
+                for idx, model_input in enumerate(self.model.input)
+            }
+          else:
+            feed_dict = {self.model.input: embeddings_data[0][batch]}
+
+          feed_dict.update({self.batch_id: i, self.step: step})
+
+          if self.model.uses_learning_phase:
+            feed_dict[K.learning_phase()] = False
+
+          self.sess.run(self.assign_embeddings, feed_dict=feed_dict)
+          self.saver.save(self.sess,
+                          os.path.join(self.log_dir, 'keras_embedding.ckpt'),
+                          epoch)
+
+          i += self.batch_size
+
     for name, value in logs.items():
       if name in ['batch', 'size']:
         continue
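For reference, the slice-assignment scheme that `set_model` and `on_epoch_end`
implement above can be sketched standalone (a TF1 graph-mode sketch; the sizes
and names here are illustrative, not from this commit):

    import numpy as np
    import tensorflow as tf

    n_samples, embedding_size, batch_size = 100, 8, 32

    # One variable holds an embedding row for every sample; each run of the
    # assign op writes one batch into a slice of it.
    embedding_var = tf.Variable(tf.zeros((n_samples, embedding_size)),
                                name='layer_embedding')
    batch_id = tf.placeholder(tf.int32)  # start index of the current batch
    step = tf.placeholder(tf.int32)      # size of the current batch
    # Stand-in for the reshaped layer output that the callback feeds.
    embedding_input = tf.placeholder(tf.float32, (None, embedding_size))
    assign_op = tf.assign(embedding_var[batch_id:batch_id + step],
                          embedding_input)

    data = np.random.rand(n_samples, embedding_size).astype(np.float32)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      i = 0
      while i < n_samples:
        size = min(batch_size, n_samples - i)
        sess.run(assign_op, feed_dict={batch_id: i, step: size,
                                       embedding_input: data[i:i + size]})
        i += batch_size
      # embedding_var now holds embeddings for all samples; the callback then
      # checkpoints it with a Saver so the TensorBoard projector can read it.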

File: tensorflow/tools/api/golden/tensorflow.keras.callbacks.-tensor-board.pbtxt

@@ -5,7 +5,7 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'log_dir\', \'histogram_freq\', \'batch_size\', \'write_graph\', \'write_grads\', \'write_images\'], varargs=None, keywords=None, defaults=[\'./logs\', \'0\', \'32\', \'True\', \'False\', \'False\'], "
+    argspec: "args=[\'self\', \'log_dir\', \'histogram_freq\', \'batch_size\', \'write_graph\', \'write_grads\', \'write_images\', \'embeddings_freq\', \'embeddings_layer_names\', \'embeddings_metadata\', \'embeddings_data\'], varargs=None, keywords=None, defaults=[\'./logs\', \'0\', \'32\', \'True\', \'False\', \'False\', \'0\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "on_batch_begin"