Make calibrator throw error if values during inference are Nan.
PiperOrigin-RevId: 240590732
This commit is contained in:
parent
225727b257
commit
047abbc428
@ -89,6 +89,7 @@ cc_library(
|
||||
name = "calibration_logger",
|
||||
hdrs = ["calibration_logger.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:tflite_portable_logging",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
],
|
||||
)
|
||||
|
@ -12,38 +12,42 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_LOGGER_H_
|
||||
#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_LOGGER_H_
|
||||
#ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_CALIBRATION_LOGGER_H_
|
||||
#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_CALIBRATION_LOGGER_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace optimize {
|
||||
namespace calibration {
|
||||
|
||||
class MinMax {
|
||||
public:
|
||||
void Update(const float* values, size_t tensor_size) {
|
||||
// TODO(shashishekhar): Really slow implementation, optimize
|
||||
if (tensor_size <= 0) return;
|
||||
TfLiteStatus Update(const float* values, size_t tensor_size) {
|
||||
if (tensor_size <= 0) return kTfLiteOk;
|
||||
|
||||
if (!has_values_) {
|
||||
min_ = max_ = values[0];
|
||||
has_values_ = true;
|
||||
return;
|
||||
}
|
||||
|
||||
// We are only logging absolute min/max here.
|
||||
// TODO(shashishekhar): Make it possible to use weighted/moving average.
|
||||
for (size_t i = 0; i < tensor_size; i++) {
|
||||
float val = values[i];
|
||||
if (min_ > val) {
|
||||
min_ = val;
|
||||
} else if (max_ < val) {
|
||||
max_ = val;
|
||||
for (size_t i = 0; i < tensor_size; ++i) {
|
||||
if (std::isnan(values[i])) {
|
||||
// TODO(suharshs): Propagate ErrorReporter here.
|
||||
LOG(ERROR) << "Model resulted in Nan value during calibration. Please "
|
||||
"make sure model results in all real-values during "
|
||||
"inference with provided dataset.";
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
// We are only logging absolute min/max here.
|
||||
const auto minmax = std::minmax_element(values, values + tensor_size);
|
||||
min_ = std::min<float>(min_, *minmax.first);
|
||||
max_ = std::max<float>(max_, *minmax.second);
|
||||
|
||||
if (!has_values_) has_values_ = true;
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
bool HasValues() const { return has_values_; }
|
||||
@ -56,17 +60,19 @@ class MinMax {
|
||||
}
|
||||
|
||||
private:
|
||||
bool has_values_;
|
||||
float min_, max_;
|
||||
bool has_values_ = false;
|
||||
float min_ = std::numeric_limits<float>::max();
|
||||
float max_ = std::numeric_limits<float>::min();
|
||||
};
|
||||
|
||||
// Captures min max values for tensors.
|
||||
class Logger {
|
||||
public:
|
||||
// Log the value for tensor at |tensor_index| which has |tensor_values|
|
||||
void LogTensorValue(int tensor_index, const float* tensor_values,
|
||||
size_t tensor_size) {
|
||||
tensor_id_to_stats_map_[tensor_index].Update(tensor_values, tensor_size);
|
||||
TfLiteStatus LogTensorValue(int tensor_index, const float* tensor_values,
|
||||
size_t tensor_size) {
|
||||
return tensor_id_to_stats_map_[tensor_index].Update(tensor_values,
|
||||
tensor_size);
|
||||
}
|
||||
|
||||
// Returns a map from tensor_index -> observed min max values.
|
||||
@ -82,4 +88,4 @@ class Logger {
|
||||
} // namespace optimize
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_LOGGER_H_
|
||||
#endif // TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_CALIBRATION_LOGGER_H_
|
||||
|
@ -171,7 +171,8 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
for (int i : op_info.loggable_inputs) {
|
||||
auto tensor = context->tensors[i];
|
||||
logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float));
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)));
|
||||
}
|
||||
|
||||
auto status = kernel_invoke(context, node);
|
||||
@ -182,7 +183,8 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
for (int i : op_info.loggable_outputs) {
|
||||
auto tensor = context->tensors[i];
|
||||
logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float));
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)));
|
||||
}
|
||||
|
||||
return status;
|
||||
@ -218,9 +220,7 @@ TfLiteStatus GetNodeOpInfoMapAndContext(
|
||||
const std::unordered_map<int, OperatorInfo>& node_to_opinfo,
|
||||
tflite::Interpreter* const interpreter,
|
||||
std::unordered_map<const TfLiteNode*, OperatorInfo>* node_ptr_opinfo_map,
|
||||
const TfLiteContext** context
|
||||
|
||||
) {
|
||||
const TfLiteContext** context) {
|
||||
NodeInfoDelegateObserver delegate_observer(node_to_opinfo,
|
||||
node_ptr_opinfo_map);
|
||||
NodeInfoDelegateParams delegate_params;
|
||||
|
Loading…
Reference in New Issue
Block a user