Split TPUEmbeddingConfiguration logging across multiple statements.

This avoids log truncation for larger models.

PiperOrigin-RevId: 338551150
Change-Id: I693cb9771e493f7070dfd93f9e56f5b417b30b66
This commit is contained in:
Revan Sopher 2020-10-22 14:36:40 -07:00 committed by TensorFlower Gardener
parent 5bb73bc2aa
commit 73ec60de15
3 changed files with 50 additions and 2 deletions

View File

@@ -372,8 +372,9 @@ class TPUEmbedding(tracking.AutoTrackable):
self._config_proto = self._create_config_proto()
logging.info("Initializing TPU Embedding engine with config: %s",
self._config_proto)
logging.info("Initializing TPU Embedding engine.")
tpu_embedding_v2_utils.log_tpu_embedding_configuration(self._config_proto)
@def_function.function
def load_config():
tpu.initialize_system_for_tpu_embedding(self._config_proto)

View File

@@ -23,9 +23,12 @@ import abc
import math
import typing
from typing import Any, Dict, Callable, List, Optional, Text, Tuple, TypeVar, Union
from absl import logging
import six
from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops_v2
@@ -731,3 +734,18 @@ class FeatureConfig(object):
max_sequence_length=self.max_sequence_length,
name=self.name)
)
def log_tpu_embedding_configuration(
    config: tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration) -> None:
  """Logs a TPUEmbeddingConfiguration proto, one line per log statement.

  logging.info enforces a maximum length on each individual log statement, so
  a sufficiently large config logged in a single call gets truncated. Emitting
  the proto's text representation line by line keeps every statement well
  under that limit.

  Args:
    config: TPUEmbeddingConfiguration proto to log. Necessary because
      logging.info has a maximum length to each log statement, which
      particularly large configs can exceed.
  """
  logging.info("Beginning log of TPUEmbeddingConfiguration.")
  config_lines = str(config).splitlines()
  for config_line in config_lines:
    logging.info(config_line)
  logging.info("Done with log of TPUEmbeddingConfiguration.")

View File

@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
from tensorflow.python.compat import v2_compat
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_embedding_v2_utils
@@ -88,6 +89,34 @@ class ConfigTest(test.TestCase):
)
class TPUEmbeddingConfigurationTest(test.TestCase):
  """Checks that large TPUEmbeddingConfiguration protos are logged untruncated."""

  def test_no_truncate(self):
    # Experimentally maximum string length loggable.
    truncate_length = 14937

    # Build a config large enough that its single-string representation would
    # exceed the logging truncation limit.
    config = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration()
    for table_index in range(500):
      table = config.table_descriptor.add()
      table.name = 'table_{}'.format(table_index)
      table.vocabulary_size = table_index
    config.num_hosts = 2
    config.num_tensor_cores = 4
    config.batch_size_per_tensor_core = 128

    self.assertGreater(
        len(str(config)), truncate_length,
        'Test sanity check: generated config should be of truncating length.')

    with self.assertLogs() as logs:
      tpu_embedding_v2_utils.log_tpu_embedding_configuration(config)

    # The last table must survive logging, and no individual statement may be
    # long enough to have been truncated.
    combined_output = ''.join(logs.output)
    self.assertIn('table_499', combined_output)
    for logged_line in logs.output:
      self.assertLess(
          len(logged_line), truncate_length,
          'Logging function lines should not be of truncating length.')
# Script entry point: enables TF2 (v2) behavior before running the tests,
# since the code under test targets the v2 TPU embedding API.
if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
  test.main()