diff --git a/tensorflow/python/tpu/tpu_embedding_v2.py b/tensorflow/python/tpu/tpu_embedding_v2.py
index 00b295c475a..413c6eb2264 100644
--- a/tensorflow/python/tpu/tpu_embedding_v2.py
+++ b/tensorflow/python/tpu/tpu_embedding_v2.py
@@ -372,8 +372,9 @@ class TPUEmbedding(tracking.AutoTrackable):
     self._config_proto = self._create_config_proto()
 
-    logging.info("Initializing TPU Embedding engine with config: %s",
-                 self._config_proto)
+    logging.info("Initializing TPU Embedding engine.")
+    tpu_embedding_v2_utils.log_tpu_embedding_configuration(self._config_proto)
+
     @def_function.function
     def load_config():
       tpu.initialize_system_for_tpu_embedding(self._config_proto)
 
diff --git a/tensorflow/python/tpu/tpu_embedding_v2_utils.py b/tensorflow/python/tpu/tpu_embedding_v2_utils.py
index e04f1f0281a..33ff73ed706 100644
--- a/tensorflow/python/tpu/tpu_embedding_v2_utils.py
+++ b/tensorflow/python/tpu/tpu_embedding_v2_utils.py
@@ -23,9 +23,12 @@ import abc
 import math
 import typing
 from typing import Any, Dict, Callable, List, Optional, Text, Tuple, TypeVar, Union
+
+from absl import logging
 import six
 
 from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
+from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
 from tensorflow.python.distribute import sharded_variable
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import init_ops_v2
@@ -731,3 +734,18 @@ class FeatureConfig(object):
         max_sequence_length=self.max_sequence_length,
         name=self.name)
     )
+
+
+def log_tpu_embedding_configuration(
+    config: tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration) -> None:
+  """Logs a TPUEmbeddingConfiguration proto across multiple statements.
+
+  Args:
+    config: TPUEmbeddingConfiguration proto to log. Logged one line at a
+      time because logging.info has a maximum length per log statement,
+      which particularly large configs can exceed.
+  """
+  logging.info("Beginning log of TPUEmbeddingConfiguration.")
+  for line in str(config).splitlines():
+    logging.info(line)
+  logging.info("Done with log of TPUEmbeddingConfiguration.")
diff --git a/tensorflow/python/tpu/tpu_embedding_v2_utils_test.py b/tensorflow/python/tpu/tpu_embedding_v2_utils_test.py
index 48797b00009..770ca1fc407 100644
--- a/tensorflow/python/tpu/tpu_embedding_v2_utils_test.py
+++ b/tensorflow/python/tpu/tpu_embedding_v2_utils_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 from absl.testing import parameterized
 
+from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
 from tensorflow.python.compat import v2_compat
 from tensorflow.python.platform import test
 from tensorflow.python.tpu import tpu_embedding_v2_utils
@@ -88,6 +89,34 @@ class ConfigTest(test.TestCase):
     )
 
 
+class TPUEmbeddingConfigurationTest(test.TestCase):
+
+  def test_no_truncate(self):
+    truncate_length = 14937  # Empirical maximum loggable string length.
+
+    config = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration()
+    for i in range(500):
+      td = config.table_descriptor.add()
+      td.name = 'table_{}'.format(i)
+      td.vocabulary_size = i
+    config.num_hosts = 2
+    config.num_tensor_cores = 4
+    config.batch_size_per_tensor_core = 128
+
+    self.assertGreater(
+        len(str(config)), truncate_length,
+        'Test sanity check: generated config should be of truncating length.')
+
+    with self.assertLogs() as logs:
+      tpu_embedding_v2_utils.log_tpu_embedding_configuration(config)
+
+    self.assertIn('table_499', ''.join(logs.output))
+    for line in logs.output:
+      self.assertLess(
+          len(line), truncate_length,
+          'Logging function lines should not be of truncating length.')
+
+
 if __name__ == '__main__':
   v2_compat.enable_v2_behavior()
   test.main()