Make calibrator throw an error if values during inference are NaN.

PiperOrigin-RevId: 240590732
This commit is contained in:
Suharsh Sivakumar 2019-03-27 10:29:17 -07:00 committed by TensorFlower Gardener
parent 225727b257
commit 047abbc428
3 changed files with 36 additions and 29 deletions

View File

@@ -89,6 +89,7 @@ cc_library(
name = "calibration_logger", name = "calibration_logger",
hdrs = ["calibration_logger.h"], hdrs = ["calibration_logger.h"],
deps = [ deps = [
"//tensorflow/core:tflite_portable_logging",
"//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/c:c_api_internal",
], ],
) )

View File

@@ -12,38 +12,42 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_LOGGER_H_ #ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_CALIBRATION_LOGGER_H_
#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_LOGGER_H_ #define TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_CALIBRATION_LOGGER_H_
#include <algorithm>
#include <limits>
#include <unordered_map> #include <unordered_map>
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/c/c_api_internal.h"
namespace tflite { namespace tflite {
namespace optimize { namespace optimize {
namespace calibration { namespace calibration {
class MinMax { class MinMax {
public: public:
void Update(const float* values, size_t tensor_size) { TfLiteStatus Update(const float* values, size_t tensor_size) {
// TODO(shashishekhar): Really slow implementation, optimize if (tensor_size <= 0) return kTfLiteOk;
if (tensor_size <= 0) return;
if (!has_values_) {
min_ = max_ = values[0];
has_values_ = true;
return;
}
// We are only logging absolute min/max here.
// TODO(shashishekhar): Make it possible to use weighted/moving average. // TODO(shashishekhar): Make it possible to use weighted/moving average.
for (size_t i = 0; i < tensor_size; i++) { for (size_t i = 0; i < tensor_size; ++i) {
float val = values[i]; if (std::isnan(values[i])) {
if (min_ > val) { // TODO(suharshs): Propagate ErrorReporter here.
min_ = val; LOG(ERROR) << "Model resulted in Nan value during calibration. Please "
} else if (max_ < val) { "make sure model results in all real-values during "
max_ = val; "inference with provided dataset.";
return kTfLiteError;
} }
} }
// We are only logging absolute min/max here.
const auto minmax = std::minmax_element(values, values + tensor_size);
min_ = std::min<float>(min_, *minmax.first);
max_ = std::max<float>(max_, *minmax.second);
if (!has_values_) has_values_ = true;
return kTfLiteOk;
} }
bool HasValues() const { return has_values_; } bool HasValues() const { return has_values_; }
@@ -56,17 +60,19 @@ class MinMax {
} }
private: private:
bool has_values_; bool has_values_ = false;
float min_, max_; float min_ = std::numeric_limits<float>::max();
float max_ = std::numeric_limits<float>::min();
}; };
// Captures min max values for tensors. // Captures min max values for tensors.
class Logger { class Logger {
public: public:
// Log the value for tensor at |tensor_index| which has |tensor_values| // Log the value for tensor at |tensor_index| which has |tensor_values|
void LogTensorValue(int tensor_index, const float* tensor_values, TfLiteStatus LogTensorValue(int tensor_index, const float* tensor_values,
size_t tensor_size) { size_t tensor_size) {
tensor_id_to_stats_map_[tensor_index].Update(tensor_values, tensor_size); return tensor_id_to_stats_map_[tensor_index].Update(tensor_values,
tensor_size);
} }
// Returns a map from tensor_index -> observed min max values. // Returns a map from tensor_index -> observed min max values.
@@ -82,4 +88,4 @@ class Logger {
} // namespace optimize } // namespace optimize
} // namespace tflite } // namespace tflite
#endif // TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_LOGGER_H_ #endif // TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_CALIBRATION_LOGGER_H_

View File

@@ -171,7 +171,8 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
for (int i : op_info.loggable_inputs) { for (int i : op_info.loggable_inputs) {
auto tensor = context->tensors[i]; auto tensor = context->tensors[i];
logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)); TF_LITE_ENSURE_STATUS(
logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)));
} }
auto status = kernel_invoke(context, node); auto status = kernel_invoke(context, node);
@@ -182,7 +183,8 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
for (int i : op_info.loggable_outputs) { for (int i : op_info.loggable_outputs) {
auto tensor = context->tensors[i]; auto tensor = context->tensors[i];
logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)); TF_LITE_ENSURE_STATUS(
logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)));
} }
return status; return status;
@@ -218,9 +220,7 @@ TfLiteStatus GetNodeOpInfoMapAndContext(
const std::unordered_map<int, OperatorInfo>& node_to_opinfo, const std::unordered_map<int, OperatorInfo>& node_to_opinfo,
tflite::Interpreter* const interpreter, tflite::Interpreter* const interpreter,
std::unordered_map<const TfLiteNode*, OperatorInfo>* node_ptr_opinfo_map, std::unordered_map<const TfLiteNode*, OperatorInfo>* node_ptr_opinfo_map,
const TfLiteContext** context const TfLiteContext** context) {
) {
NodeInfoDelegateObserver delegate_observer(node_to_opinfo, NodeInfoDelegateObserver delegate_observer(node_to_opinfo,
node_ptr_opinfo_map); node_ptr_opinfo_map);
NodeInfoDelegateParams delegate_params; NodeInfoDelegateParams delegate_params;