diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index 7b13e98f811..d9a70209e48 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -406,6 +406,10 @@ class InputContext(object): (global_batch_size, self._num_replicas_in_sync)) return global_batch_size // self._num_replicas_in_sync + def __str__(self): + return "tf.distribute.InputContext(input pipeline id {}, total: {})".format( + self.input_pipeline_id, self.num_input_pipelines) + # ------------------------------------------------------------------------------ # Base classes for all distribution strategies. diff --git a/tensorflow/python/distribute/distribute_lib_test.py b/tensorflow/python/distribute/distribute_lib_test.py index d8b4902bd85..1d171bc5cd5 100644 --- a/tensorflow/python/distribute/distribute_lib_test.py +++ b/tensorflow/python/distribute/distribute_lib_test.py @@ -563,6 +563,18 @@ class InputContextTest(test.TestCase): with self.assertRaises(ValueError): input_context.get_per_replica_batch_size(13) + def testStr(self): + input_context = distribute_lib.InputContext( + num_input_pipelines=1, input_pipeline_id=0, num_replicas_in_sync=42) + self.assertEqual( + "tf.distribute.InputContext(input pipeline id 0, total: 1)", + str(input_context)) + input_context = distribute_lib.InputContext( + num_input_pipelines=3, input_pipeline_id=1, num_replicas_in_sync=42) + self.assertEqual( + "tf.distribute.InputContext(input pipeline id 1, total: 3)", + str(input_context)) + if __name__ == "__main__": test.main()