Split TPUEmbeddingConfiguration logging across multiple statements.

This avoids log truncation for larger models.

PiperOrigin-RevId: 338551150
Change-Id: I693cb9771e493f7070dfd93f9e56f5b417b30b66
This commit is contained in:
Revan Sopher 2020-10-22 14:36:40 -07:00 committed by TensorFlower Gardener
parent 5bb73bc2aa
commit 73ec60de15
3 changed files with 50 additions and 2 deletions

View File

@@ -372,8 +372,9 @@ class TPUEmbedding(tracking.AutoTrackable):
self._config_proto = self._create_config_proto() self._config_proto = self._create_config_proto()
logging.info("Initializing TPU Embedding engine with config: %s", logging.info("Initializing TPU Embedding engine.")
self._config_proto) tpu_embedding_v2_utils.log_tpu_embedding_configuration(self._config_proto)
@def_function.function @def_function.function
def load_config(): def load_config():
tpu.initialize_system_for_tpu_embedding(self._config_proto) tpu.initialize_system_for_tpu_embedding(self._config_proto)

View File

@@ -23,9 +23,12 @@ import abc
import math import math
import typing import typing
from typing import Any, Dict, Callable, List, Optional, Text, Tuple, TypeVar, Union from typing import Any, Dict, Callable, List, Optional, Text, Tuple, TypeVar, Union
from absl import logging
import six import six
from tensorflow.core.protobuf.tpu import optimization_parameters_pb2 from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
from tensorflow.python.distribute import sharded_variable from tensorflow.python.distribute import sharded_variable
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops_v2 from tensorflow.python.ops import init_ops_v2
@@ -731,3 +734,18 @@ class FeatureConfig(object):
max_sequence_length=self.max_sequence_length, max_sequence_length=self.max_sequence_length,
name=self.name) name=self.name)
) )
def log_tpu_embedding_configuration(
    config: tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration) -> None:
  """Logs a TPUEmbeddingConfiguration proto across multiple statements.

  logging.info enforces a maximum length per statement, which the string form
  of a particularly large config can exceed; emitting the proto one line at a
  time avoids silent truncation of the log.

  Args:
    config: TPUEmbeddingConfiguration proto to log.
  """
  logging.info("Beginning log of TPUEmbeddingConfiguration.")
  # str() of a proto is its text format; log each line separately so no
  # single statement approaches the logging length limit.
  for line in str(config).splitlines():
    logging.info(line)
  logging.info("Done with log of TPUEmbeddingConfiguration.")

View File

@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized from absl.testing import parameterized
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
from tensorflow.python.compat import v2_compat from tensorflow.python.compat import v2_compat
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_embedding_v2_utils from tensorflow.python.tpu import tpu_embedding_v2_utils
@@ -88,6 +89,34 @@ class ConfigTest(test.TestCase):
) )
class TPUEmbeddingConfigurationTest(test.TestCase):
  """Checks that large TPUEmbeddingConfiguration protos log untruncated."""

  def test_no_truncate(self):
    # Experimentally determined maximum string length loggable in one call.
    truncate_length = 14937

    # Build a config whose text form comfortably exceeds that limit.
    config = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration()
    for idx in range(500):
      descriptor = config.table_descriptor.add()
      descriptor.name = 'table_{}'.format(idx)
      descriptor.vocabulary_size = idx
    config.num_hosts = 2
    config.num_tensor_cores = 4
    config.batch_size_per_tensor_core = 128

    self.assertGreater(
        len(str(config)), truncate_length,
        'Test sanity check: generated config should be of truncating length.')

    with self.assertLogs() as logs:
      tpu_embedding_v2_utils.log_tpu_embedding_configuration(config)

    # The last table only appears in the log if nothing was truncated, and
    # no individual statement may itself approach the truncation threshold.
    self.assertIn('table_499', ''.join(logs.output))
    for line in logs.output:
      self.assertLess(
          len(line), truncate_length,
          'Logging function lines should not be of truncating length.')
if __name__ == '__main__': if __name__ == '__main__':
v2_compat.enable_v2_behavior() v2_compat.enable_v2_behavior()
test.main() test.main()