Partial update of tf.keras to the Keras 2.2.0 API.
Changes included are:
- Embedding visualization is added to the TensorBoard callback (from older Keras API).
- Fix: learning phase info being left out in multi-input models (from older Keras API).
- Fix: TensorBoard callback only supports logging Embedding layer weights.
- Fix: TensorBoard callback with layers with multiple outputs.

PiperOrigin-RevId: 204659796
This commit is contained in:
parent
fe7d1d9447
commit
e6ce9ea5a1
@ -31,11 +31,16 @@ import time
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
|
from tensorflow.python.keras.engine.training_utils import standardize_input_data
|
||||||
from tensorflow.python.keras.utils.generic_utils import Progbar
|
from tensorflow.python.keras.utils.generic_utils import Progbar
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import state_ops
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.summary import summary as tf_summary
|
from tensorflow.python.summary import summary as tf_summary
|
||||||
|
from tensorflow.python.training import saver
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
@ -697,7 +702,9 @@ class TensorBoard(Callback):
|
|||||||
write_images: whether to write model weights to visualize as
|
write_images: whether to write model weights to visualize as
|
||||||
image in TensorBoard.
|
image in TensorBoard.
|
||||||
embeddings_freq: frequency (in epochs) at which selected embedding
|
embeddings_freq: frequency (in epochs) at which selected embedding
|
||||||
layers will be saved.
|
layers will be saved. If set to 0, embeddings won't be computed.
|
||||||
|
Data to be visualized in TensorBoard's Embedding tab must be passed
|
||||||
|
as `embeddings_data`.
|
||||||
embeddings_layer_names: a list of names of layers to keep eye on. If
|
embeddings_layer_names: a list of names of layers to keep eye on. If
|
||||||
None or empty list all the embedding layer will be watched.
|
None or empty list all the embedding layer will be watched.
|
||||||
embeddings_metadata: a dictionary which maps layer name to a file name
|
embeddings_metadata: a dictionary which maps layer name to a file name
|
||||||
@ -705,6 +712,10 @@ class TensorBoard(Callback):
|
|||||||
[details](https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
|
[details](https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
|
||||||
about metadata files format. In case if the same metadata file is
|
about metadata files format. In case if the same metadata file is
|
||||||
used for all embedding layers, string can be passed.
|
used for all embedding layers, string can be passed.
|
||||||
|
embeddings_data: data to be embedded at layers specified in
|
||||||
|
`embeddings_layer_names`. Numpy array (if the model has a single
|
||||||
|
input) or list of Numpy arrays (if the model has multiple inputs).
|
||||||
|
Learn [more about embeddings](https://www.tensorflow.org/programmers_guide/embedding)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: enable=line-too-long
|
# pylint: enable=line-too-long
|
||||||
@ -715,7 +726,11 @@ class TensorBoard(Callback):
|
|||||||
batch_size=32,
|
batch_size=32,
|
||||||
write_graph=True,
|
write_graph=True,
|
||||||
write_grads=False,
|
write_grads=False,
|
||||||
write_images=False):
|
write_images=False,
|
||||||
|
embeddings_freq=0,
|
||||||
|
embeddings_layer_names=None,
|
||||||
|
embeddings_metadata=None,
|
||||||
|
embeddings_data=None):
|
||||||
super(TensorBoard, self).__init__()
|
super(TensorBoard, self).__init__()
|
||||||
self.log_dir = log_dir
|
self.log_dir = log_dir
|
||||||
self.histogram_freq = histogram_freq
|
self.histogram_freq = histogram_freq
|
||||||
@ -727,6 +742,10 @@ class TensorBoard(Callback):
|
|||||||
self._current_batch = 0
|
self._current_batch = 0
|
||||||
# abstracted writer class to be able to stub for testing
|
# abstracted writer class to be able to stub for testing
|
||||||
self._writer_class = tf_summary.FileWriter
|
self._writer_class = tf_summary.FileWriter
|
||||||
|
self.embeddings_freq = embeddings_freq
|
||||||
|
self.embeddings_layer_names = embeddings_layer_names
|
||||||
|
self.embeddings_metadata = embeddings_metadata
|
||||||
|
self.embeddings_data = embeddings_data
|
||||||
|
|
||||||
def set_model(self, model):
|
def set_model(self, model):
|
||||||
"""Sets Keras model and creates summary ops."""
|
"""Sets Keras model and creates summary ops."""
|
||||||
@ -778,6 +797,10 @@ class TensorBoard(Callback):
|
|||||||
tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads)
|
tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads)
|
||||||
|
|
||||||
if hasattr(layer, 'output'):
|
if hasattr(layer, 'output'):
|
||||||
|
if isinstance(layer.output, list):
|
||||||
|
for i, output in enumerate(layer.output):
|
||||||
|
tf_summary.histogram('{}_out_{}'.format(layer.name, i), output)
|
||||||
|
else:
|
||||||
tf_summary.histogram('{}_out'.format(layer.name), layer.output)
|
tf_summary.histogram('{}_out'.format(layer.name), layer.output)
|
||||||
self.merged = tf_summary.merge_all()
|
self.merged = tf_summary.merge_all()
|
||||||
|
|
||||||
@ -786,6 +809,74 @@ class TensorBoard(Callback):
|
|||||||
else:
|
else:
|
||||||
self.writer = self._writer_class(self.log_dir)
|
self.writer = self._writer_class(self.log_dir)
|
||||||
|
|
||||||
|
# If both embedding_freq and embeddings_data are available, we will
|
||||||
|
# visualize embeddings.
|
||||||
|
if self.embeddings_freq and self.embeddings_data is not None:
|
||||||
|
self.embeddings_data = standardize_input_data(self.embeddings_data,
|
||||||
|
model.input_names)
|
||||||
|
|
||||||
|
# If embedding_layer_names are not provided, get all of the embedding
|
||||||
|
# layers from the model.
|
||||||
|
embeddings_layer_names = self.embeddings_layer_names
|
||||||
|
if not embeddings_layer_names:
|
||||||
|
embeddings_layer_names = [
|
||||||
|
layer.name
|
||||||
|
for layer in self.model.layers
|
||||||
|
if type(layer).__name__ == 'Embedding'
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assign_embeddings = []
|
||||||
|
embeddings_vars = {}
|
||||||
|
|
||||||
|
self.batch_id = batch_id = array_ops.placeholder(dtypes.int32)
|
||||||
|
self.step = step = array_ops.placeholder(dtypes.int32)
|
||||||
|
|
||||||
|
for layer in self.model.layers:
|
||||||
|
if layer.name in embeddings_layer_names:
|
||||||
|
embedding_input = self.model.get_layer(layer.name).output
|
||||||
|
embedding_size = np.prod(embedding_input.shape[1:])
|
||||||
|
embedding_input = array_ops.reshape(embedding_input,
|
||||||
|
(step, int(embedding_size)))
|
||||||
|
shape = (self.embeddings_data[0].shape[0], int(embedding_size))
|
||||||
|
embedding = variables.Variable(
|
||||||
|
array_ops.zeros(shape), name=layer.name + '_embedding')
|
||||||
|
embeddings_vars[layer.name] = embedding
|
||||||
|
batch = state_ops.assign(embedding[batch_id:batch_id + step],
|
||||||
|
embedding_input)
|
||||||
|
self.assign_embeddings.append(batch)
|
||||||
|
|
||||||
|
self.saver = saver.Saver(list(embeddings_vars.values()))
|
||||||
|
|
||||||
|
# Create embeddings_metadata dictionary
|
||||||
|
if isinstance(self.embeddings_metadata, str):
|
||||||
|
embeddings_metadata = {
|
||||||
|
layer_name: self.embeddings_metadata
|
||||||
|
for layer_name in embeddings_vars.keys()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# If embedding_metadata is already a dictionary
|
||||||
|
embeddings_metadata = self.embeddings_metadata
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tensorboard.plugins import projector
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError('Failed to import TensorBoard. Please make sure that '
|
||||||
|
'TensorBoard integration is complete."')
|
||||||
|
|
||||||
|
# TODO(psv): Add integration tests to test embedding visualization
|
||||||
|
# with TensorBoard callback. We are unable to write a unit test for this
|
||||||
|
# because TensorBoard dependency assumes TensorFlow package is installed.
|
||||||
|
config = projector.ProjectorConfig()
|
||||||
|
for layer_name, tensor in embeddings_vars.items():
|
||||||
|
embedding = config.embeddings.add()
|
||||||
|
embedding.tensor_name = tensor.name
|
||||||
|
|
||||||
|
if (embeddings_metadata is not None and
|
||||||
|
layer_name in embeddings_metadata):
|
||||||
|
embedding.metadata_path = embeddings_metadata[layer_name]
|
||||||
|
|
||||||
|
projector.visualize_embeddings(self.writer, config)
|
||||||
|
|
||||||
def _fetch_callback(self, summary):
|
def _fetch_callback(self, summary):
|
||||||
self.writer.add_summary(
|
self.writer.add_summary(
|
||||||
summary,
|
summary,
|
||||||
@ -833,6 +924,46 @@ class TensorBoard(Callback):
|
|||||||
if self.merged in self.model.test_function.fetch_callbacks:
|
if self.merged in self.model.test_function.fetch_callbacks:
|
||||||
self.model.test_function.fetch_callbacks.pop(self.merged)
|
self.model.test_function.fetch_callbacks.pop(self.merged)
|
||||||
|
|
||||||
|
if self.embeddings_data is None and self.embeddings_freq:
|
||||||
|
raise ValueError('To visualize embeddings, embeddings_data must '
|
||||||
|
'be provided.')
|
||||||
|
|
||||||
|
if self.embeddings_freq and self.embeddings_data is not None:
|
||||||
|
if epoch % self.embeddings_freq == 0:
|
||||||
|
# We need a second forward-pass here because we're passing
|
||||||
|
# the `embeddings_data` explicitly. This design allows to pass
|
||||||
|
# arbitrary data as `embeddings_data` and results from the fact
|
||||||
|
# that we need to know the size of the `tf.Variable`s which
|
||||||
|
# hold the embeddings in `set_model`. At this point, however,
|
||||||
|
# the `validation_data` is not yet set.
|
||||||
|
|
||||||
|
embeddings_data = self.embeddings_data
|
||||||
|
n_samples = embeddings_data[0].shape[0]
|
||||||
|
i = 0
|
||||||
|
while i < n_samples:
|
||||||
|
step = min(self.batch_size, n_samples - i)
|
||||||
|
batch = slice(i, i + step)
|
||||||
|
|
||||||
|
if isinstance(self.model.input, list):
|
||||||
|
feed_dict = {
|
||||||
|
model_input: embeddings_data[idx][batch]
|
||||||
|
for idx, model_input in enumerate(self.model.input)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
feed_dict = {self.model.input: embeddings_data[0][batch]}
|
||||||
|
|
||||||
|
feed_dict.update({self.batch_id: i, self.step: step})
|
||||||
|
|
||||||
|
if self.model.uses_learning_phase:
|
||||||
|
feed_dict[K.learning_phase()] = False
|
||||||
|
|
||||||
|
self.sess.run(self.assign_embeddings, feed_dict=feed_dict)
|
||||||
|
self.saver.save(self.sess,
|
||||||
|
os.path.join(self.log_dir, 'keras_embedding.ckpt'),
|
||||||
|
epoch)
|
||||||
|
|
||||||
|
i += self.batch_size
|
||||||
|
|
||||||
for name, value in logs.items():
|
for name, value in logs.items():
|
||||||
if name in ['batch', 'size']:
|
if name in ['batch', 'size']:
|
||||||
continue
|
continue
|
||||||
|
@ -5,7 +5,7 @@ tf_class {
|
|||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'log_dir\', \'histogram_freq\', \'batch_size\', \'write_graph\', \'write_grads\', \'write_images\'], varargs=None, keywords=None, defaults=[\'./logs\', \'0\', \'32\', \'True\', \'False\', \'False\'], "
|
argspec: "args=[\'self\', \'log_dir\', \'histogram_freq\', \'batch_size\', \'write_graph\', \'write_grads\', \'write_images\', \'embeddings_freq\', \'embeddings_layer_names\', \'embeddings_metadata\', \'embeddings_data\'], varargs=None, keywords=None, defaults=[\'./logs\', \'0\', \'32\', \'True\', \'False\', \'False\', \'0\', \'None\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "on_batch_begin"
|
name: "on_batch_begin"
|
||||||
|
Loading…
Reference in New Issue
Block a user