Add __str__ method to tf.distribute.InputContext.
PiperOrigin-RevId: 292228607 Change-Id: I8752cbc5feec2de60d25663c744e07453950ad71
This commit is contained in:
parent
5f24934536
commit
abb256c1f8
@ -406,6 +406,10 @@ class InputContext(object):
|
||||
(global_batch_size, self._num_replicas_in_sync))
|
||||
return global_batch_size // self._num_replicas_in_sync
|
||||
|
||||
def __str__(self):
|
||||
return "tf.distribute.InputContext(input pipeline id {}, total: {})".format(
|
||||
self.input_pipeline_id, self.num_input_pipelines)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Base classes for all distribution strategies.
|
||||
|
@ -563,6 +563,18 @@ class InputContextTest(test.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
input_context.get_per_replica_batch_size(13)
|
||||
|
||||
def testStr(self):
|
||||
input_context = distribute_lib.InputContext(
|
||||
num_input_pipelines=1, input_pipeline_id=0, num_replicas_in_sync=42)
|
||||
self.assertEqual(
|
||||
"tf.distribute.InputContext(input pipeline id 0, total: 1)",
|
||||
str(input_context))
|
||||
input_context = distribute_lib.InputContext(
|
||||
num_input_pipelines=3, input_pipeline_id=1, num_replicas_in_sync=42)
|
||||
self.assertEqual(
|
||||
"tf.distribute.InputContext(input pipeline id 1, total: 3)",
|
||||
str(input_context))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user