Switching contrib.summaries API to be context-manager-centric
PiperOrigin-RevId: 173129793
This commit is contained in:
		
							parent
							
								
									03b02ffc9e
								
							
						
					
					
						commit
						4ec6f2b07c
					
				| @ -27,6 +27,7 @@ from tensorflow.python.framework import ops | ||||
| from tensorflow.python.layers import utils | ||||
| from tensorflow.python.ops import summary_op_util | ||||
| from tensorflow.python.training import training_util | ||||
| from tensorflow.python.util import tf_contextlib | ||||
| 
 | ||||
| # Name for a collection which is expected to have at most a single boolean | ||||
| # Tensor. If this tensor is True the summary ops will record summaries. | ||||
| @ -46,22 +47,50 @@ def should_record_summaries(): | ||||
| 
 | ||||
| 
 | ||||
| # TODO(apassos) consider how to handle local step here. | ||||
| @tf_contextlib.contextmanager | ||||
| def record_summaries_every_n_global_steps(n): | ||||
|   """Sets the should_record_summaries Tensor to true if global_step % n == 0.""" | ||||
|   collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) | ||||
|   old = collection_ref[:] | ||||
|   collection_ref[:] = [training_util.get_global_step() % n == 0] | ||||
|   yield | ||||
|   collection_ref[:] = old | ||||
| 
 | ||||
| 
 | ||||
| @tf_contextlib.contextmanager | ||||
| def always_record_summaries(): | ||||
|   """Sets the should_record_summaries Tensor to always true.""" | ||||
|   collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) | ||||
|   old = collection_ref[:] | ||||
|   collection_ref[:] = [True] | ||||
|   yield | ||||
|   collection_ref[:] = old | ||||
| 
 | ||||
| 
 | ||||
| @tf_contextlib.contextmanager | ||||
| def never_record_summaries(): | ||||
|   """Sets the should_record_summaries Tensor to always false.""" | ||||
|   collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) | ||||
|   old = collection_ref[:] | ||||
|   collection_ref[:] = [False] | ||||
|   yield | ||||
|   collection_ref[:] = old | ||||
| 
 | ||||
| 
 | ||||
| class SummaryWriter(object): | ||||
| 
 | ||||
|   def __init__(self, resource): | ||||
|     self._resource = resource | ||||
| 
 | ||||
|   def set_as_default(self): | ||||
|     context.context().summary_writer_resource = self._resource | ||||
| 
 | ||||
|   @tf_contextlib.contextmanager | ||||
|   def as_default(self): | ||||
|     old = context.context().summary_writer_resource | ||||
|     context.context().summary_writer_resource = self._resource | ||||
|     yield | ||||
|     context.context().summary_writer_resource = old | ||||
| 
 | ||||
| 
 | ||||
| def create_summary_file_writer(logdir, | ||||
| @ -77,9 +106,11 @@ def create_summary_file_writer(logdir, | ||||
|   if filename_suffix is None: | ||||
|     filename_suffix = constant_op.constant("") | ||||
|   resource = gen_summary_ops.summary_writer(shared_name=name) | ||||
|   # TODO(apassos) ensure the initialization op runs when in graph mode; consider | ||||
|   # calling session.run here. | ||||
|   gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue, | ||||
|                                              flush_secs, filename_suffix) | ||||
|   context.context().summary_writer_resource = resource | ||||
|   return SummaryWriter(resource) | ||||
| 
 | ||||
| 
 | ||||
| def _nothing(): | ||||
|  | ||||
| @ -41,60 +41,65 @@ class TargetTest(test_util.TensorFlowTestCase): | ||||
| 
 | ||||
|   def testShouldRecordSummary(self): | ||||
|     self.assertFalse(summary_ops.should_record_summaries()) | ||||
|     summary_ops.always_record_summaries() | ||||
|     self.assertTrue(summary_ops.should_record_summaries()) | ||||
|     with summary_ops.always_record_summaries(): | ||||
|       self.assertTrue(summary_ops.should_record_summaries()) | ||||
| 
 | ||||
|   def testSummaryOps(self): | ||||
|     training_util.get_or_create_global_step() | ||||
|     logdir = tempfile.mkdtemp() | ||||
|     summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0') | ||||
|     summary_ops.always_record_summaries() | ||||
|     summary_ops.generic('tensor', 1, '') | ||||
|     summary_ops.scalar('scalar', 2.0) | ||||
|     summary_ops.histogram('histogram', [1.0]) | ||||
|     summary_ops.image('image', [[[[1.0]]]]) | ||||
|     summary_ops.audio('audio', [[1.0]], 1.0, 1) | ||||
|     # The working condition of the ops is tested in the C++ test so we just | ||||
|     # test here that we're calling them correctly. | ||||
|     self.assertTrue(gfile.Exists(logdir)) | ||||
|     with summary_ops.create_summary_file_writer( | ||||
|         logdir, max_queue=0, | ||||
|         name='t0').as_default(), summary_ops.always_record_summaries(): | ||||
|       summary_ops.generic('tensor', 1, '') | ||||
|       summary_ops.scalar('scalar', 2.0) | ||||
|       summary_ops.histogram('histogram', [1.0]) | ||||
|       summary_ops.image('image', [[[[1.0]]]]) | ||||
|       summary_ops.audio('audio', [[1.0]], 1.0, 1) | ||||
|       # The working condition of the ops is tested in the C++ test so we just | ||||
|       # test here that we're calling them correctly. | ||||
|       self.assertTrue(gfile.Exists(logdir)) | ||||
| 
 | ||||
|   def testDefunSummarys(self): | ||||
|     training_util.get_or_create_global_step() | ||||
|     logdir = tempfile.mkdtemp() | ||||
|     summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t1') | ||||
|     summary_ops.always_record_summaries() | ||||
|     with summary_ops.create_summary_file_writer( | ||||
|         logdir, max_queue=0, | ||||
|         name='t1').as_default(), summary_ops.always_record_summaries(): | ||||
| 
 | ||||
|     @function.defun | ||||
|     def write(): | ||||
|       summary_ops.scalar('scalar', 2.0) | ||||
|       @function.defun | ||||
|       def write(): | ||||
|         summary_ops.scalar('scalar', 2.0) | ||||
| 
 | ||||
|     write() | ||||
|       write() | ||||
| 
 | ||||
|     self.assertTrue(gfile.Exists(logdir)) | ||||
|     files = gfile.ListDirectory(logdir) | ||||
|     self.assertEqual(len(files), 1) | ||||
|     records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) | ||||
|     self.assertEqual(len(records), 2) | ||||
|     event = event_pb2.Event() | ||||
|     event.ParseFromString(records[1]) | ||||
|     self.assertEqual(event.summary.value[0].simple_value, 2.0) | ||||
|       self.assertTrue(gfile.Exists(logdir)) | ||||
|       files = gfile.ListDirectory(logdir) | ||||
|       self.assertEqual(len(files), 1) | ||||
|       records = list( | ||||
|           tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) | ||||
|       self.assertEqual(len(records), 2) | ||||
|       event = event_pb2.Event() | ||||
|       event.ParseFromString(records[1]) | ||||
|       self.assertEqual(event.summary.value[0].simple_value, 2.0) | ||||
| 
 | ||||
|   def testSummaryName(self): | ||||
|     training_util.get_or_create_global_step() | ||||
|     logdir = tempfile.mkdtemp() | ||||
|     summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t2') | ||||
|     summary_ops.always_record_summaries() | ||||
|     with summary_ops.create_summary_file_writer( | ||||
|         logdir, max_queue=0, | ||||
|         name='t2').as_default(), summary_ops.always_record_summaries(): | ||||
| 
 | ||||
|     summary_ops.scalar('scalar', 2.0) | ||||
|       summary_ops.scalar('scalar', 2.0) | ||||
| 
 | ||||
|     self.assertTrue(gfile.Exists(logdir)) | ||||
|     files = gfile.ListDirectory(logdir) | ||||
|     self.assertEqual(len(files), 1) | ||||
|     records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) | ||||
|     self.assertEqual(len(records), 2) | ||||
|     event = event_pb2.Event() | ||||
|     event.ParseFromString(records[1]) | ||||
|     self.assertEqual(event.summary.value[0].tag, 'scalar') | ||||
|       self.assertTrue(gfile.Exists(logdir)) | ||||
|       files = gfile.ListDirectory(logdir) | ||||
|       self.assertEqual(len(files), 1) | ||||
|       records = list( | ||||
|           tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) | ||||
|       self.assertEqual(len(records), 2) | ||||
|       event = event_pb2.Event() | ||||
|       event.ParseFromString(records[1]) | ||||
|       self.assertEqual(event.summary.value[0].tag, 'scalar') | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|  | ||||
| @ -58,6 +58,7 @@ class _EagerContext(threading.local): | ||||
|     self.mode = _default_mode | ||||
|     self.scope_name = "" | ||||
|     self.recording_summaries = False | ||||
|     self.summary_writer_resource = None | ||||
|     self.scalar_cache = {} | ||||
| 
 | ||||
| 
 | ||||
| @ -86,7 +87,6 @@ class Context(object): | ||||
|     self._eager_context = _EagerContext() | ||||
|     self._context_handle = None | ||||
|     self._context_devices = None | ||||
|     self._summary_writer_resource = None | ||||
|     self._post_execution_callbacks = [] | ||||
|     self._config = config | ||||
|     self._seed = None | ||||
| @ -213,12 +213,12 @@ class Context(object): | ||||
|   @property | ||||
|   def summary_writer_resource(self): | ||||
|     """Returns summary writer resource.""" | ||||
|     return self._summary_writer_resource | ||||
|     return self._eager_context.summary_writer_resource | ||||
| 
 | ||||
|   @summary_writer_resource.setter | ||||
|   def summary_writer_resource(self, resource): | ||||
|     """Sets summary writer resource.""" | ||||
|     self._summary_writer_resource = resource | ||||
|     self._eager_context.summary_writer_resource = resource | ||||
| 
 | ||||
|   @property | ||||
|   def device_name(self): | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user