diff --git a/tensorflow/python/tpu/tensor_tracer_flags.py b/tensorflow/python/tpu/tensor_tracer_flags.py index a8d0ceafd16..923985eabc3 100644 --- a/tensorflow/python/tpu/tensor_tracer_flags.py +++ b/tensorflow/python/tpu/tensor_tracer_flags.py @@ -141,6 +141,13 @@ class TTParameters(object): self.collect_summary_per_core = self.is_flag_on(FLAG_NAME_SUMMARY_PER_CORE) self.flush_summaries_with_outside_compile = self.is_flag_on( FLAG_FLUSH_SUMMARY) + self._check_flag_errors() + + def _check_flag_errors(self): + if self.trace_mode in (TRACE_MODE_SUMMARY, TRACE_MODE_FULL_TENSOR_SUMMARY): + if not self.trace_dir: + raise ValueError('trace_dir must be explicitly provided in ' + 'TENSOR_TRACER_FLAGS when summary mode is used.') def _get_report_filepath(self): """Sets the path of the output report file."""