Update naming of embedding in callback for TensorBoard integration.

PiperOrigin-RevId: 314693951
Change-Id: I6a3aba5b98ce6865da8814eda554bac1fb6bcaea
This commit is contained in:
A. Unique TensorFlower 2020-06-04 02:35:03 -07:00 committed by TensorFlower Gardener
parent a8ab56eb94
commit 0ca0c442c0
3 changed files with 14 additions and 7 deletions

View File

@ -2000,7 +2000,10 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
for layer in self.model.layers:
if isinstance(layer, embeddings.Embedding):
embedding = config.embeddings.add()
embedding.tensor_name = layer.name + '/.ATTRIBUTES/VARIABLE_VALUE'
# Embeddings are always the first layer, so this naming should be
# consistent in any keras models checkpoints.
name = 'layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE'
embedding.tensor_name = name
if self.embeddings_metadata is not None:
if isinstance(self.embeddings_metadata, str):

View File

@ -1975,12 +1975,12 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
callbacks=[tb_cbk])
with open(os.path.join(self.logdir, 'projector_config.pbtxt')) as f:
self.assertEqual(
f.readlines(), [
'embeddings {\n',
' tensor_name: "test_embedding/.ATTRIBUTES/VARIABLE_VALUE"\n',
' metadata_path: "metadata.tsv"\n',
'}\n'])
self.assertEqual(f.readlines(), [
'embeddings {\n',
(' tensor_name: '
'"layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE"\n'),
' metadata_path: "metadata.tsv"\n', '}\n'
])
def test_custom_summary(self):
if not context.executing_eagerly():

View File

@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file is a copy of the TensorBoard ProjectorConfig proto.
// Keep this file in sync with the source proto definition at
// https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/projector/projector_config.proto
syntax = "proto3";
package third_party.tensorflow.python.keras.protobuf;