Split TPUEmbeddingConfiguration logging across multiple statements.
This avoids log truncation for larger models.

PiperOrigin-RevId: 338551150
Change-Id: I693cb9771e493f7070dfd93f9e56f5b417b30b66
This commit is contained in:
parent
5bb73bc2aa
commit
73ec60de15
@ -372,8 +372,9 @@ class TPUEmbedding(tracking.AutoTrackable):
|
||||
|
||||
self._config_proto = self._create_config_proto()
|
||||
|
||||
logging.info("Initializing TPU Embedding engine with config: %s",
|
||||
self._config_proto)
|
||||
logging.info("Initializing TPU Embedding engine.")
|
||||
tpu_embedding_v2_utils.log_tpu_embedding_configuration(self._config_proto)
|
||||
|
||||
@def_function.function
|
||||
def load_config():
|
||||
tpu.initialize_system_for_tpu_embedding(self._config_proto)
|
||||
|
@ -23,9 +23,12 @@ import abc
|
||||
import math
|
||||
import typing
|
||||
from typing import Any, Dict, Callable, List, Optional, Text, Tuple, TypeVar, Union
|
||||
|
||||
from absl import logging
|
||||
import six
|
||||
|
||||
from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
|
||||
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
|
||||
from tensorflow.python.distribute import sharded_variable
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import init_ops_v2
|
||||
@ -731,3 +734,18 @@ class FeatureConfig(object):
|
||||
max_sequence_length=self.max_sequence_length,
|
||||
name=self.name)
|
||||
)
|
||||
|
||||
|
||||
def log_tpu_embedding_configuration(
    config: tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration) -> None:
  """Logs a TPUEmbeddingConfiguration proto across multiple statements.

  The string form of the proto (`str(config)`) is emitted one line per
  `logging.info` call, bracketed by a "Beginning" and a "Done" marker
  statement. Splitting is necessary because logging.info has a maximum length
  to each log statement, which particularly large configs can exceed.

  Args:
    config: TPUEmbeddingConfiguration proto to log.
  """
  logging.info("Beginning log of TPUEmbeddingConfiguration.")
  for line in str(config).splitlines():
    # Pass the line as an argument rather than as the format string, so the
    # proto text is always treated as data and never as a %-format template.
    logging.info("%s", line)
  logging.info("Done with log of TPUEmbeddingConfiguration.")
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
|
||||
from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.tpu import tpu_embedding_v2_utils
|
||||
@ -88,6 +89,34 @@ class ConfigTest(test.TestCase):
|
||||
)
|
||||
|
||||
|
||||
class TPUEmbeddingConfigurationTest(test.TestCase):
  """Tests for logging of TPUEmbeddingConfiguration protos."""

  def test_no_truncate(self):
    """Logs an oversized config and checks that no log line is too long."""
    # Longest string found, by experiment, to fit in a single log statement.
    truncate_length = 14937

    config = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration()
    for table_index in range(500):
      descriptor = config.table_descriptor.add()
      descriptor.name = 'table_{}'.format(table_index)
      descriptor.vocabulary_size = table_index
    config.num_hosts = 2
    config.num_tensor_cores = 4
    config.batch_size_per_tensor_core = 128

    # Guard the test itself: the config must be long enough to be truncated
    # if it were logged in one statement.
    config_text = str(config)
    self.assertGreater(
        len(config_text), truncate_length,
        'Test sanity check: generated config should be of truncating length.')

    with self.assertLogs() as logs:
      tpu_embedding_v2_utils.log_tpu_embedding_configuration(config)

    captured = logs.output
    # The last table added only appears if the tail of the config was logged.
    self.assertIn('table_499', ''.join(captured))
    for entry in captured:
      self.assertLess(
          len(entry), truncate_length,
          'Logging function lines should not be of truncating length.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
  # Enable V2 behavior first, then hand control to the test runner.
  v2_compat.enable_v2_behavior()
  test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user