From 047abbc428f16a946abed69654d86bd8cfd69573 Mon Sep 17 00:00:00 2001
From: Suharsh Sivakumar
Date: Wed, 27 Mar 2019 10:29:17 -0700
Subject: [PATCH] Make calibrator throw error if values during inference are Nan.

PiperOrigin-RevId: 240590732
---
 .../lite/tools/optimize/calibration/BUILD     |  1 +
 .../optimize/calibration/calibration_logger.h | 54 ++++++++++---------
 .../tools/optimize/calibration/calibrator.cc  | 10 ++--
 3 files changed, 36 insertions(+), 29 deletions(-)

diff --git a/tensorflow/lite/tools/optimize/calibration/BUILD b/tensorflow/lite/tools/optimize/calibration/BUILD
index c1d2ad2bca8..efd1261f228 100644
--- a/tensorflow/lite/tools/optimize/calibration/BUILD
+++ b/tensorflow/lite/tools/optimize/calibration/BUILD
@@ -89,6 +89,7 @@ cc_library(
     name = "calibration_logger",
     hdrs = ["calibration_logger.h"],
     deps = [
+        "//tensorflow/core:tflite_portable_logging",
         "//tensorflow/lite/c:c_api_internal",
     ],
 )
diff --git a/tensorflow/lite/tools/optimize/calibration/calibration_logger.h b/tensorflow/lite/tools/optimize/calibration/calibration_logger.h
index 8fd380423a3..e8b4e30cbbc 100644
--- a/tensorflow/lite/tools/optimize/calibration/calibration_logger.h
+++ b/tensorflow/lite/tools/optimize/calibration/calibration_logger.h
@@ -12,38 +12,42 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_LOGGER_H_
-#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_LOGGER_H_
+#ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_CALIBRATION_LOGGER_H_
+#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_CALIBRATION_LOGGER_H_
 
+#include <algorithm>
+#include <cmath>
 #include <unordered_map>
 
+#include "tensorflow/core/platform/logging.h"
 #include "tensorflow/lite/c/c_api_internal.h"
 
 namespace tflite {
 namespace optimize {
 namespace calibration {
+
 class MinMax {
  public:
-  void Update(const float* values, size_t tensor_size) {
-    // TODO(shashishekhar): Really slow implementation, optimize
-    if (tensor_size <= 0) return;
+  TfLiteStatus Update(const float* values, size_t tensor_size) {
+    if (tensor_size <= 0) return kTfLiteOk;
 
-    if (!has_values_) {
-      min_ = max_ = values[0];
-      has_values_ = true;
-      return;
-    }
-
-    // We are only logging absolute min/max here.
     // TODO(shashishekhar): Make it possible to use weighted/moving average.
-    for (size_t i = 0; i < tensor_size; i++) {
-      float val = values[i];
-      if (min_ > val) {
-        min_ = val;
-      } else if (max_ < val) {
-        max_ = val;
+    for (size_t i = 0; i < tensor_size; ++i) {
+      if (std::isnan(values[i])) {
+        // TODO(suharshs): Propagate ErrorReporter here.
+        LOG(ERROR) << "Model resulted in Nan value during calibration. Please "
+                      "make sure model results in all real-values during "
+                      "inference with provided dataset.";
+        return kTfLiteError;
       }
     }
+    // We are only logging absolute min/max here.
+    const auto minmax = std::minmax_element(values, values + tensor_size);
+    min_ = std::min(min_, *minmax.first);
+    max_ = std::max(max_, *minmax.second);
+
+    if (!has_values_) has_values_ = true;
+    return kTfLiteOk;
   }
 
   bool HasValues() const { return has_values_; }
@@ -56,17 +60,19 @@ class MinMax {
   }
 
  private:
-  bool has_values_;
-  float min_, max_;
+  bool has_values_ = false;
+  float min_ = std::numeric_limits<float>::max();
+  float max_ = std::numeric_limits<float>::min();
 };
 
 // Captures min max values for tensors.
 class Logger {
  public:
   // Log the value for tensor at |tensor_index| which has |tensor_values|
-  void LogTensorValue(int tensor_index, const float* tensor_values,
-                      size_t tensor_size) {
-    tensor_id_to_stats_map_[tensor_index].Update(tensor_values, tensor_size);
+  TfLiteStatus LogTensorValue(int tensor_index, const float* tensor_values,
+                              size_t tensor_size) {
+    return tensor_id_to_stats_map_[tensor_index].Update(tensor_values,
+                                                        tensor_size);
   }
 
   // Returns a map from tensor_index -> observed min max values.
@@ -82,4 +88,4 @@ class Logger {
 }  // namespace optimize
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_LOGGER_H_
+#endif  // TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_CALIBRATION_LOGGER_H_
diff --git a/tensorflow/lite/tools/optimize/calibration/calibrator.cc b/tensorflow/lite/tools/optimize/calibration/calibrator.cc
index eead4e590f8..7a9a4943704 100644
--- a/tensorflow/lite/tools/optimize/calibration/calibrator.cc
+++ b/tensorflow/lite/tools/optimize/calibration/calibrator.cc
@@ -171,7 +171,8 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
 
   for (int i : op_info.loggable_inputs) {
     auto tensor = context->tensors[i];
-    logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float));
+    TF_LITE_ENSURE_STATUS(
+        logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)));
   }
 
   auto status = kernel_invoke(context, node);
@@ -182,7 +183,8 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
 
   for (int i : op_info.loggable_outputs) {
     auto tensor = context->tensors[i];
-    logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float));
+    TF_LITE_ENSURE_STATUS(
+        logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)));
   }
 
   return status;
@@ -218,9 +220,7 @@ TfLiteStatus GetNodeOpInfoMapAndContext(
     const std::unordered_map<int, OperatorInfo>& node_to_opinfo,
     tflite::Interpreter* const interpreter,
     std::unordered_map<const TfLiteNode*, OperatorInfo>* node_ptr_opinfo_map,
-    const TfLiteContext** context
-
-) {
+    const TfLiteContext** context) {
   NodeInfoDelegateObserver delegate_observer(node_to_opinfo,
                                              node_ptr_opinfo_map);
   NodeInfoDelegateParams delegate_params;
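
For illustration only, not part of the patch: the heart of the change is the NaN guard in MinMax::Update, which rejects a tensor's values before folding them into the running min/max, and the new TfLiteStatus return that lets LoggingEval bail out via TF_LITE_ENSURE_STATUS. A minimal standalone C++ sketch of that update logic, using hypothetical stand-in names (SketchStatus, SketchMinMax) rather than the TFLite types, might look like this:

// Standalone sketch of the NaN-guarded min/max update shown in the patch.
// SketchStatus and SketchMinMax are illustrative stand-ins, not TFLite APIs.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <limits>
#include <vector>

enum SketchStatus { kSketchOk, kSketchError };

struct SketchMinMax {
  bool has_values = false;
  float min = std::numeric_limits<float>::max();
  // Note: lowest(), not min(), so all-negative batches still update max.
  float max = std::numeric_limits<float>::lowest();

  SketchStatus Update(const float* values, std::size_t tensor_size) {
    if (tensor_size == 0) return kSketchOk;
    // Reject the whole batch if any value is NaN, before touching min/max.
    for (std::size_t i = 0; i < tensor_size; ++i) {
      if (std::isnan(values[i])) return kSketchError;
    }
    const auto mm = std::minmax_element(values, values + tensor_size);
    min = std::min(min, *mm.first);
    max = std::max(max, *mm.second);
    has_values = true;
    return kSketchOk;
  }
};

int main() {
  SketchMinMax stats;
  std::vector<float> good = {0.5f, -1.25f, 3.0f};
  std::vector<float> bad = {1.0f, std::nanf(""), 2.0f};
  std::printf("good batch: %s\n",
              stats.Update(good.data(), good.size()) == kSketchOk ? "ok" : "error");
  std::printf("bad batch:  %s\n",
              stats.Update(bad.data(), bad.size()) == kSketchOk ? "ok" : "error");
  // Only the good batch contributed: min == -1.25, max == 3.0.
  std::printf("observed min=%f max=%f\n", stats.min, stats.max);
  return 0;
}

In the patch itself, the same check additionally logs via LOG(ERROR) and returns kTfLiteError, which TF_LITE_ENSURE_STATUS in calibrator.cc then propagates out of LoggingEval.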