Make calibrator throw error if values during inference are Nan.

PiperOrigin-RevId: 240590732
This commit is contained in:
Suharsh Sivakumar 2019-03-27 10:29:17 -07:00 committed by TensorFlower Gardener
parent 225727b257
commit 047abbc428
3 changed files with 36 additions and 29 deletions

View File

@ -89,6 +89,7 @@ cc_library(
name = "calibration_logger",
hdrs = ["calibration_logger.h"],
deps = [
"//tensorflow/core:tflite_portable_logging",
"//tensorflow/lite/c:c_api_internal",
],
)

View File

@ -12,38 +12,42 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_LOGGER_H_
#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_LOGGER_H_
#ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_CALIBRATION_LOGGER_H_
#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_CALIBRATION_LOGGER_H_
#include <algorithm>
#include <limits>
#include <unordered_map>
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/c_api_internal.h"
namespace tflite {
namespace optimize {
namespace calibration {
class MinMax {
public:
void Update(const float* values, size_t tensor_size) {
// TODO(shashishekhar): Really slow implementation, optimize
if (tensor_size <= 0) return;
TfLiteStatus Update(const float* values, size_t tensor_size) {
if (tensor_size <= 0) return kTfLiteOk;
if (!has_values_) {
min_ = max_ = values[0];
has_values_ = true;
return;
}
// We are only logging absolute min/max here.
// TODO(shashishekhar): Make it possible to use weighted/moving average.
for (size_t i = 0; i < tensor_size; i++) {
float val = values[i];
if (min_ > val) {
min_ = val;
} else if (max_ < val) {
max_ = val;
for (size_t i = 0; i < tensor_size; ++i) {
if (std::isnan(values[i])) {
// TODO(suharshs): Propagate ErrorReporter here.
LOG(ERROR) << "Model resulted in Nan value during calibration. Please "
"make sure model results in all real-values during "
"inference with provided dataset.";
return kTfLiteError;
}
}
// We are only logging absolute min/max here.
const auto minmax = std::minmax_element(values, values + tensor_size);
min_ = std::min<float>(min_, *minmax.first);
max_ = std::max<float>(max_, *minmax.second);
if (!has_values_) has_values_ = true;
return kTfLiteOk;
}
bool HasValues() const { return has_values_; }
@ -56,17 +60,19 @@ class MinMax {
}
private:
bool has_values_;
float min_, max_;
bool has_values_ = false;
float min_ = std::numeric_limits<float>::max();
float max_ = std::numeric_limits<float>::min();
};
// Captures min max values for tensors.
class Logger {
public:
// Log the value for tensor at |tensor_index| which has |tensor_values|
void LogTensorValue(int tensor_index, const float* tensor_values,
TfLiteStatus LogTensorValue(int tensor_index, const float* tensor_values,
size_t tensor_size) {
tensor_id_to_stats_map_[tensor_index].Update(tensor_values, tensor_size);
return tensor_id_to_stats_map_[tensor_index].Update(tensor_values,
tensor_size);
}
// Returns a map from tensor_index -> observed min max values.
@ -82,4 +88,4 @@ class Logger {
} // namespace optimize
} // namespace tflite
#endif // TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_LOGGER_H_
#endif // TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_CALIBRATION_LOGGER_H_

View File

@ -171,7 +171,8 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
for (int i : op_info.loggable_inputs) {
auto tensor = context->tensors[i];
logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float));
TF_LITE_ENSURE_STATUS(
logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)));
}
auto status = kernel_invoke(context, node);
@ -182,7 +183,8 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
for (int i : op_info.loggable_outputs) {
auto tensor = context->tensors[i];
logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float));
TF_LITE_ENSURE_STATUS(
logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)));
}
return status;
@ -218,9 +220,7 @@ TfLiteStatus GetNodeOpInfoMapAndContext(
const std::unordered_map<int, OperatorInfo>& node_to_opinfo,
tflite::Interpreter* const interpreter,
std::unordered_map<const TfLiteNode*, OperatorInfo>* node_ptr_opinfo_map,
const TfLiteContext** context
) {
const TfLiteContext** context) {
NodeInfoDelegateObserver delegate_observer(node_to_opinfo,
node_ptr_opinfo_map);
NodeInfoDelegateParams delegate_params;