Split TPUEmbeddingConfiguration logging across multiple statements.
This avoids log truncation for larger models. PiperOrigin-RevId: 338551150 Change-Id: I693cb9771e493f7070dfd93f9e56f5b417b30b66
This commit is contained in:
parent
5bb73bc2aa
commit
73ec60de15
@ -372,8 +372,9 @@ class TPUEmbedding(tracking.AutoTrackable):
|
|||||||
|
|
||||||
self._config_proto = self._create_config_proto()
|
self._config_proto = self._create_config_proto()
|
||||||
|
|
||||||
logging.info("Initializing TPU Embedding engine with config: %s",
|
logging.info("Initializing TPU Embedding engine.")
|
||||||
self._config_proto)
|
tpu_embedding_v2_utils.log_tpu_embedding_configuration(self._config_proto)
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
def load_config():
|
def load_config():
|
||||||
tpu.initialize_system_for_tpu_embedding(self._config_proto)
|
tpu.initialize_system_for_tpu_embedding(self._config_proto)
|
||||||
|
@ -23,9 +23,12 @@ import abc
|
|||||||
import math
|
import math
|
||||||
import typing
|
import typing
|
||||||
from typing import Any, Dict, Callable, List, Optional, Text, Tuple, TypeVar, Union
|
from typing import Any, Dict, Callable, List, Optional, Text, Tuple, TypeVar, Union
|
||||||
|
|
||||||
|
from absl import logging
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
|
from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
|
||||||
|
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
|
||||||
from tensorflow.python.distribute import sharded_variable
|
from tensorflow.python.distribute import sharded_variable
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import init_ops_v2
|
from tensorflow.python.ops import init_ops_v2
|
||||||
@ -731,3 +734,18 @@ class FeatureConfig(object):
|
|||||||
max_sequence_length=self.max_sequence_length,
|
max_sequence_length=self.max_sequence_length,
|
||||||
name=self.name)
|
name=self.name)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def log_tpu_embedding_configuration(
|
||||||
|
config: tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration) -> None:
|
||||||
|
"""Logs a TPUEmbeddingConfiguration proto across multiple statements.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: TPUEmbeddingConfiguration proto to log. Necessary because
|
||||||
|
logging.info has a maximum length to each log statement, which
|
||||||
|
particularly large configs can exceed.
|
||||||
|
"""
|
||||||
|
logging.info("Beginning log of TPUEmbeddingConfiguration.")
|
||||||
|
for line in str(config).splitlines():
|
||||||
|
logging.info(line)
|
||||||
|
logging.info("Done with log of TPUEmbeddingConfiguration.")
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
|
||||||
from tensorflow.python.compat import v2_compat
|
from tensorflow.python.compat import v2_compat
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.tpu import tpu_embedding_v2_utils
|
from tensorflow.python.tpu import tpu_embedding_v2_utils
|
||||||
@ -88,6 +89,34 @@ class ConfigTest(test.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TPUEmbeddingConfigurationTest(test.TestCase):
|
||||||
|
|
||||||
|
def test_no_truncate(self):
|
||||||
|
truncate_length = 14937 # Experimentally maximum string length loggable.
|
||||||
|
|
||||||
|
config = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration()
|
||||||
|
for i in range(500):
|
||||||
|
td = config.table_descriptor.add()
|
||||||
|
td.name = 'table_{}'.format(i)
|
||||||
|
td.vocabulary_size = i
|
||||||
|
config.num_hosts = 2
|
||||||
|
config.num_tensor_cores = 4
|
||||||
|
config.batch_size_per_tensor_core = 128
|
||||||
|
|
||||||
|
self.assertGreater(
|
||||||
|
len(str(config)), truncate_length,
|
||||||
|
'Test sanity check: generated config should be of truncating length.')
|
||||||
|
|
||||||
|
with self.assertLogs() as logs:
|
||||||
|
tpu_embedding_v2_utils.log_tpu_embedding_configuration(config)
|
||||||
|
|
||||||
|
self.assertIn('table_499', ''.join(logs.output))
|
||||||
|
for line in logs.output:
|
||||||
|
self.assertLess(
|
||||||
|
len(line), truncate_length,
|
||||||
|
'Logging function lines should not be of truncating length.')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
v2_compat.enable_v2_behavior()
|
v2_compat.enable_v2_behavior()
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user