Add __str__ method to tf.distribute.InputContext.

PiperOrigin-RevId: 292228607
Change-Id: I8752cbc5feec2de60d25663c744e07453950ad71
This commit is contained in:
A. Unique TensorFlower 2020-01-29 15:15:56 -08:00 committed by TensorFlower Gardener
parent 5f24934536
commit abb256c1f8
2 changed files with 16 additions and 0 deletions

View File

@ -406,6 +406,10 @@ class InputContext(object):
(global_batch_size, self._num_replicas_in_sync))
return global_batch_size // self._num_replicas_in_sync
def __str__(self):
return "tf.distribute.InputContext(input pipeline id {}, total: {})".format(
self.input_pipeline_id, self.num_input_pipelines)
# ------------------------------------------------------------------------------
# Base classes for all distribution strategies.

View File

@ -563,6 +563,18 @@ class InputContextTest(test.TestCase):
with self.assertRaises(ValueError):
input_context.get_per_replica_batch_size(13)
def testStr(self):
input_context = distribute_lib.InputContext(
num_input_pipelines=1, input_pipeline_id=0, num_replicas_in_sync=42)
self.assertEqual(
"tf.distribute.InputContext(input pipeline id 0, total: 1)",
str(input_context))
input_context = distribute_lib.InputContext(
num_input_pipelines=3, input_pipeline_id=1, num_replicas_in_sync=42)
self.assertEqual(
"tf.distribute.InputContext(input pipeline id 1, total: 3)",
str(input_context))
if __name__ == "__main__":
test.main()