From e6ce9ea5a156873c5b927e99d8935e32122538b9 Mon Sep 17 00:00:00 2001 From: Pavithra Vijay Date: Sun, 15 Jul 2018 11:28:49 -0700 Subject: [PATCH] Partial update of tf.keras to the Keras 2.2.0 API. Changes included are: - Embedding visualization is added to TensorBoard callback (from older Keras API.) - Fix: learning phase info being left out in multi-input models (from older Keras API.) - Fix: Tensorboard callback only supports logging Embeddings layer weights - Fix: Tensorboard callback with layer with multiple outputs PiperOrigin-RevId: 204659796 --- tensorflow/python/keras/callbacks.py | 137 +++++++++++++++++- ...orflow.keras.callbacks.-tensor-board.pbtxt | 2 +- 2 files changed, 135 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 53d907a2cc7..0857a3279f1 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -31,11 +31,16 @@ import time import numpy as np import six +from tensorflow.python.framework import dtypes from tensorflow.python.keras import backend as K +from tensorflow.python.keras.engine.training_utils import standardize_input_data from tensorflow.python.keras.utils.generic_utils import Progbar from tensorflow.python.ops import array_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary as tf_summary +from tensorflow.python.training import saver from tensorflow.python.util.tf_export import tf_export @@ -697,7 +702,9 @@ class TensorBoard(Callback): write_images: whether to write model weights to visualize as image in TensorBoard. embeddings_freq: frequency (in epochs) at which selected embedding - layers will be saved. + layers will be saved. If set to 0, embeddings won't be computed. + Data to be visualized in TensorBoard's Embedding tab must be passed + as `embeddings_data`. embeddings_layer_names: a list of names of layers to keep eye on. If None or empty list all the embedding layer will be watched. embeddings_metadata: a dictionary which maps layer name to a file name @@ -705,6 +712,10 @@ class TensorBoard(Callback): [details](https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional) about metadata files format. In case if the same metadata file is used for all embedding layers, string can be passed. + embeddings_data: data to be embedded at layers specified in + `embeddings_layer_names`. Numpy array (if the model has a single + input) or list of Numpy arrays (if the model has multiple inputs). + Learn [more about embeddings](https://www.tensorflow.org/programmers_guide/embedding) """ # pylint: enable=line-too-long @@ -715,7 +726,11 @@ class TensorBoard(Callback): batch_size=32, write_graph=True, write_grads=False, - write_images=False): + write_images=False, + embeddings_freq=0, + embeddings_layer_names=None, + embeddings_metadata=None, + embeddings_data=None): super(TensorBoard, self).__init__() self.log_dir = log_dir self.histogram_freq = histogram_freq @@ -727,6 +742,10 @@ class TensorBoard(Callback): self._current_batch = 0 # abstracted writer class to be able to stub for testing self._writer_class = tf_summary.FileWriter + self.embeddings_freq = embeddings_freq + self.embeddings_layer_names = embeddings_layer_names + self.embeddings_metadata = embeddings_metadata + self.embeddings_data = embeddings_data def set_model(self, model): """Sets Keras model and creates summary ops.""" @@ -778,7 +797,11 @@ class TensorBoard(Callback): tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads) if hasattr(layer, 'output'): - tf_summary.histogram('{}_out'.format(layer.name), layer.output) + if isinstance(layer.output, list): + for i, output in enumerate(layer.output): + tf_summary.histogram('{}_out_{}'.format(layer.name, i), output) + else: + tf_summary.histogram('{}_out'.format(layer.name), layer.output) self.merged = tf_summary.merge_all() if self.write_graph: @@ -786,6 +809,74 @@ class TensorBoard(Callback): else: self.writer = self._writer_class(self.log_dir) + # If both embedding_freq and embeddings_data are available, we will + # visualize embeddings. + if self.embeddings_freq and self.embeddings_data is not None: + self.embeddings_data = standardize_input_data(self.embeddings_data, + model.input_names) + + # If embedding_layer_names are not provided, get all of the embedding + # layers from the model. + embeddings_layer_names = self.embeddings_layer_names + if not embeddings_layer_names: + embeddings_layer_names = [ + layer.name + for layer in self.model.layers + if type(layer).__name__ == 'Embedding' + ] + + self.assign_embeddings = [] + embeddings_vars = {} + + self.batch_id = batch_id = array_ops.placeholder(dtypes.int32) + self.step = step = array_ops.placeholder(dtypes.int32) + + for layer in self.model.layers: + if layer.name in embeddings_layer_names: + embedding_input = self.model.get_layer(layer.name).output + embedding_size = np.prod(embedding_input.shape[1:]) + embedding_input = array_ops.reshape(embedding_input, + (step, int(embedding_size))) + shape = (self.embeddings_data[0].shape[0], int(embedding_size)) + embedding = variables.Variable( + array_ops.zeros(shape), name=layer.name + '_embedding') + embeddings_vars[layer.name] = embedding + batch = state_ops.assign(embedding[batch_id:batch_id + step], + embedding_input) + self.assign_embeddings.append(batch) + + self.saver = saver.Saver(list(embeddings_vars.values())) + + # Create embeddings_metadata dictionary + if isinstance(self.embeddings_metadata, str): + embeddings_metadata = { + layer_name: self.embeddings_metadata + for layer_name in embeddings_vars.keys() + } + else: + # If embedding_metadata is already a dictionary + embeddings_metadata = self.embeddings_metadata + + try: + from tensorboard.plugins import projector + except ImportError: + raise ImportError('Failed to import TensorBoard. Please make sure that ' + 'TensorBoard integration is complete."') + + # TODO(psv): Add integration tests to test embedding visualization + # with TensorBoard callback. We are unable to write a unit test for this + # because TensorBoard dependency assumes TensorFlow package is installed. + config = projector.ProjectorConfig() + for layer_name, tensor in embeddings_vars.items(): + embedding = config.embeddings.add() + embedding.tensor_name = tensor.name + + if (embeddings_metadata is not None and + layer_name in embeddings_metadata): + embedding.metadata_path = embeddings_metadata[layer_name] + + projector.visualize_embeddings(self.writer, config) + def _fetch_callback(self, summary): self.writer.add_summary( summary, @@ -833,6 +924,46 @@ class TensorBoard(Callback): if self.merged in self.model.test_function.fetch_callbacks: self.model.test_function.fetch_callbacks.pop(self.merged) + if self.embeddings_data is None and self.embeddings_freq: + raise ValueError('To visualize embeddings, embeddings_data must ' + 'be provided.') + + if self.embeddings_freq and self.embeddings_data is not None: + if epoch % self.embeddings_freq == 0: + # We need a second forward-pass here because we're passing + # the `embeddings_data` explicitly. This design allows to pass + # arbitrary data as `embeddings_data` and results from the fact + # that we need to know the size of the `tf.Variable`s which + # hold the embeddings in `set_model`. At this point, however, + # the `validation_data` is not yet set. + + embeddings_data = self.embeddings_data + n_samples = embeddings_data[0].shape[0] + i = 0 + while i < n_samples: + step = min(self.batch_size, n_samples - i) + batch = slice(i, i + step) + + if isinstance(self.model.input, list): + feed_dict = { + model_input: embeddings_data[idx][batch] + for idx, model_input in enumerate(self.model.input) + } + else: + feed_dict = {self.model.input: embeddings_data[0][batch]} + + feed_dict.update({self.batch_id: i, self.step: step}) + + if self.model.uses_learning_phase: + feed_dict[K.learning_phase()] = False + + self.sess.run(self.assign_embeddings, feed_dict=feed_dict) + self.saver.save(self.sess, + os.path.join(self.log_dir, 'keras_embedding.ckpt'), + epoch) + + i += self.batch_size + for name, value in logs.items(): if name in ['batch', 'size']: continue diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-tensor-board.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-tensor-board.pbtxt index 2f52464315d..e58ba18c1c0 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-tensor-board.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-tensor-board.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'log_dir\', \'histogram_freq\', \'batch_size\', \'write_graph\', \'write_grads\', \'write_images\'], varargs=None, keywords=None, defaults=[\'./logs\', \'0\', \'32\', \'True\', \'False\', \'False\'], " + argspec: "args=[\'self\', \'log_dir\', \'histogram_freq\', \'batch_size\', \'write_graph\', \'write_grads\', \'write_images\', \'embeddings_freq\', \'embeddings_layer_names\', \'embeddings_metadata\', \'embeddings_data\'], varargs=None, keywords=None, defaults=[\'./logs\', \'0\', \'32\', \'True\', \'False\', \'False\', \'0\', \'None\', \'None\', \'None\'], " } member_method { name: "on_batch_begin"