Make calibrator throw error if values during inference are Nan.
PiperOrigin-RevId: 240590732
This commit is contained in:
parent
225727b257
commit
047abbc428
@ -89,6 +89,7 @@ cc_library(
|
|||||||
name = "calibration_logger",
|
name = "calibration_logger",
|
||||||
hdrs = ["calibration_logger.h"],
|
hdrs = ["calibration_logger.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/core:tflite_portable_logging",
|
||||||
"//tensorflow/lite/c:c_api_internal",
|
"//tensorflow/lite/c:c_api_internal",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -12,38 +12,42 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_LOGGER_H_
|
#ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_CALIBRATION_LOGGER_H_
|
||||||
#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_LOGGER_H_
|
#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_CALIBRATION_LOGGER_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <limits>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/lite/c/c_api_internal.h"
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace optimize {
|
namespace optimize {
|
||||||
namespace calibration {
|
namespace calibration {
|
||||||
|
|
||||||
class MinMax {
|
class MinMax {
|
||||||
public:
|
public:
|
||||||
void Update(const float* values, size_t tensor_size) {
|
TfLiteStatus Update(const float* values, size_t tensor_size) {
|
||||||
// TODO(shashishekhar): Really slow implementation, optimize
|
if (tensor_size <= 0) return kTfLiteOk;
|
||||||
if (tensor_size <= 0) return;
|
|
||||||
|
|
||||||
if (!has_values_) {
|
|
||||||
min_ = max_ = values[0];
|
|
||||||
has_values_ = true;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// We are only logging absolute min/max here.
|
|
||||||
// TODO(shashishekhar): Make it possible to use weighted/moving average.
|
// TODO(shashishekhar): Make it possible to use weighted/moving average.
|
||||||
for (size_t i = 0; i < tensor_size; i++) {
|
for (size_t i = 0; i < tensor_size; ++i) {
|
||||||
float val = values[i];
|
if (std::isnan(values[i])) {
|
||||||
if (min_ > val) {
|
// TODO(suharshs): Propagate ErrorReporter here.
|
||||||
min_ = val;
|
LOG(ERROR) << "Model resulted in Nan value during calibration. Please "
|
||||||
} else if (max_ < val) {
|
"make sure model results in all real-values during "
|
||||||
max_ = val;
|
"inference with provided dataset.";
|
||||||
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// We are only logging absolute min/max here.
|
||||||
|
const auto minmax = std::minmax_element(values, values + tensor_size);
|
||||||
|
min_ = std::min<float>(min_, *minmax.first);
|
||||||
|
max_ = std::max<float>(max_, *minmax.second);
|
||||||
|
|
||||||
|
if (!has_values_) has_values_ = true;
|
||||||
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HasValues() const { return has_values_; }
|
bool HasValues() const { return has_values_; }
|
||||||
@ -56,17 +60,19 @@ class MinMax {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool has_values_;
|
bool has_values_ = false;
|
||||||
float min_, max_;
|
float min_ = std::numeric_limits<float>::max();
|
||||||
|
float max_ = std::numeric_limits<float>::min();
|
||||||
};
|
};
|
||||||
|
|
||||||
// Captures min max values for tensors.
|
// Captures min max values for tensors.
|
||||||
class Logger {
|
class Logger {
|
||||||
public:
|
public:
|
||||||
// Log the value for tensor at |tensor_index| which has |tensor_values|
|
// Log the value for tensor at |tensor_index| which has |tensor_values|
|
||||||
void LogTensorValue(int tensor_index, const float* tensor_values,
|
TfLiteStatus LogTensorValue(int tensor_index, const float* tensor_values,
|
||||||
size_t tensor_size) {
|
size_t tensor_size) {
|
||||||
tensor_id_to_stats_map_[tensor_index].Update(tensor_values, tensor_size);
|
return tensor_id_to_stats_map_[tensor_index].Update(tensor_values,
|
||||||
|
tensor_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns a map from tensor_index -> observed min max values.
|
// Returns a map from tensor_index -> observed min max values.
|
||||||
@ -82,4 +88,4 @@ class Logger {
|
|||||||
} // namespace optimize
|
} // namespace optimize
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
#endif // TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_LOGGER_H_
|
#endif // TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_CALIBRATION_LOGGER_H_
|
||||||
|
@ -171,7 +171,8 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
|
|
||||||
for (int i : op_info.loggable_inputs) {
|
for (int i : op_info.loggable_inputs) {
|
||||||
auto tensor = context->tensors[i];
|
auto tensor = context->tensors[i];
|
||||||
logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float));
|
TF_LITE_ENSURE_STATUS(
|
||||||
|
logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto status = kernel_invoke(context, node);
|
auto status = kernel_invoke(context, node);
|
||||||
@ -182,7 +183,8 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
|
|
||||||
for (int i : op_info.loggable_outputs) {
|
for (int i : op_info.loggable_outputs) {
|
||||||
auto tensor = context->tensors[i];
|
auto tensor = context->tensors[i];
|
||||||
logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float));
|
TF_LITE_ENSURE_STATUS(
|
||||||
|
logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)));
|
||||||
}
|
}
|
||||||
|
|
||||||
return status;
|
return status;
|
||||||
@ -218,9 +220,7 @@ TfLiteStatus GetNodeOpInfoMapAndContext(
|
|||||||
const std::unordered_map<int, OperatorInfo>& node_to_opinfo,
|
const std::unordered_map<int, OperatorInfo>& node_to_opinfo,
|
||||||
tflite::Interpreter* const interpreter,
|
tflite::Interpreter* const interpreter,
|
||||||
std::unordered_map<const TfLiteNode*, OperatorInfo>* node_ptr_opinfo_map,
|
std::unordered_map<const TfLiteNode*, OperatorInfo>* node_ptr_opinfo_map,
|
||||||
const TfLiteContext** context
|
const TfLiteContext** context) {
|
||||||
|
|
||||||
) {
|
|
||||||
NodeInfoDelegateObserver delegate_observer(node_to_opinfo,
|
NodeInfoDelegateObserver delegate_observer(node_to_opinfo,
|
||||||
node_ptr_opinfo_map);
|
node_ptr_opinfo_map);
|
||||||
NodeInfoDelegateParams delegate_params;
|
NodeInfoDelegateParams delegate_params;
|
||||||
|
Loading…
Reference in New Issue
Block a user