Propagate nan-value error through ErrorReporter.
PiperOrigin-RevId: 289700382
Change-Id: I0a3719f0cf713268db22b3e8dcbdf9ae143ed280
parent cff8012de1
commit a250518f0d
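The diff below threads an ErrorReporter* from the Calibrator through LoggingEval, lstm_logging_kernel, Logger::LogTensorValue, and MinMax::Update, so a NaN encountered during calibration is reported to the caller's error reporter (and still fails with kTfLiteError) instead of only being written to the internal TFLite log. As a rough illustration of that pattern, here is a minimal standalone C++ sketch; the Reporter and Status types below are simplified stand-ins, not the actual TFLite API.

```cpp
// Minimal sketch of "propagate the NaN error through a reporter".
// Assumption: Reporter/Status are hypothetical stand-ins for
// tflite::ErrorReporter and TfLiteStatus; not the real TFLite headers.
#include <cmath>
#include <cstddef>
#include <cstdio>

enum Status { kOk = 0, kError = 1 };

// Stand-in for an injected error reporter.
class Reporter {
 public:
  virtual ~Reporter() {}
  virtual void Report(const char* msg) { std::fprintf(stderr, "%s\n", msg); }
};

// Mirrors the shape of MinMax::Update after this change: scan for NaN,
// surface the message through the reporter, and return an error status.
Status Update(const float* values, std::size_t tensor_size,
              Reporter* reporter) {
  for (std::size_t i = 0; i < tensor_size; ++i) {
    if (std::isnan(values[i])) {
      reporter->Report(
          "Model resulted in Nan value during calibration. Please "
          "make sure model results in all real-values during "
          "inference with provided dataset.");
      return kError;
    }
  }
  return kOk;
}

int main() {
  Reporter reporter;
  const float good[] = {1.0f, 2.0f};
  const float bad[] = {1.0f, NAN};
  std::printf("good: %d\n", Update(good, 2, &reporter));  // 0 (kOk)
  std::printf("bad:  %d\n", Update(bad, 2, &reporter));   // reports, then 1
  return 0;
}
```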
@@ -18,6 +18,7 @@ cc_library(
         ":calibration_logger",
         "//tensorflow/lite:framework",
         "//tensorflow/lite/c:common",
+        "//tensorflow/lite/core/api",
         "//tensorflow/lite/kernels:kernel_util",
         "//tensorflow/lite/kernels:lstm_shared",
         "//tensorflow/lite/kernels:op_macros",
@@ -120,8 +121,10 @@ cc_library(
     hdrs = ["calibration_logger.h"],
     copts = tflite_copts(),
     deps = [
+        "//tensorflow/lite:framework",
         "//tensorflow/lite:minimal_logging",
         "//tensorflow/lite/c:common",
+        "//tensorflow/lite/core/api",
     ],
 )
 
@@ -19,6 +19,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/core/api/error_reporter.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
@@ -64,7 +65,8 @@ inline void LstmStepWithAuxInput(
     float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
     float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
     float* output_ptr, Logger* logger,
-    const std::vector<int>& intemediate_tensor_indexes) {
+    const std::vector<int>& intemediate_tensor_indexes,
+    ErrorReporter* error_reporter) {
   // Since we have already checked that weights are all there or none, we can
   // check the existence of only one to the get the condition.
   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
@@ -158,7 +160,7 @@ inline void LstmStepWithAuxInput(
   }
   if (is_layer_norm_lstm) {
     logger->LogTensorValue(intemediate_tensor_indexes[0], input_gate_scratch,
-                           n_cell * n_batch);
+                           n_cell * n_batch, error_reporter);
     tensor_utils::MeanStddevNormalization(
         input_gate_scratch, input_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
@@ -179,7 +181,7 @@ inline void LstmStepWithAuxInput(
   }
   if (is_layer_norm_lstm) {
     logger->LogTensorValue(intemediate_tensor_indexes[1], forget_gate_scratch,
-                           n_cell * n_batch);
+                           n_cell * n_batch, error_reporter);
     tensor_utils::MeanStddevNormalization(forget_gate_scratch,
                                           forget_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
@@ -196,7 +198,7 @@ inline void LstmStepWithAuxInput(
                                          n_batch * n_cell, cell_state_ptr);
   if (is_layer_norm_lstm) {
     logger->LogTensorValue(intemediate_tensor_indexes[2], cell_scratch,
-                           n_cell * n_batch);
+                           n_cell * n_batch, error_reporter);
     tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
                                           n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
@@ -229,7 +231,7 @@ inline void LstmStepWithAuxInput(
   }
   if (is_layer_norm_lstm) {
     logger->LogTensorValue(intemediate_tensor_indexes[3], output_gate_scratch,
-                           n_cell * n_batch);
+                           n_cell * n_batch, error_reporter);
     tensor_utils::MeanStddevNormalization(output_gate_scratch,
                                           output_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
@@ -246,7 +248,7 @@ inline void LstmStepWithAuxInput(
                                          n_batch * n_cell, output_gate_scratch);
 
   logger->LogTensorValue(intemediate_tensor_indexes[4], output_gate_scratch,
-                         n_cell * n_batch);
+                         n_cell * n_batch, error_reporter);
 
   const bool use_projection_weight = (projection_weights_ptr != nullptr);
   const bool use_projection_bias = (projection_bias_ptr != nullptr);
@@ -317,7 +319,8 @@ TfLiteStatus EvalFloat(
     int output_offset, TfLiteTensor* scratch_buffer,
     TfLiteTensor* activation_state, TfLiteTensor* cell_state,
     TfLiteTensor* output, Logger* logger,
-    const std::vector<int>& intemediate_tensor_indexes) {
+    const std::vector<int>& intemediate_tensor_indexes,
+    ErrorReporter* error_reporter) {
   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
   int max_time, n_batch;
   if (input->dims->size == 3) {
@@ -404,7 +407,7 @@ TfLiteStatus EvalFloat(
           GetTensorData<float>(activation_state),
           GetTensorData<float>(cell_state), input_gate_scratch,
           forget_gate_scratch, cell_scratch, output_gate_scratch,
-          output_ptr_time, logger, intemediate_tensor_indexes);
+          output_ptr_time, logger, intemediate_tensor_indexes, error_reporter);
     }
   } else {
     for (int b = 0; b < n_batch; b++) {
@@ -465,7 +468,7 @@ TfLiteStatus EvalFloat(
           n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
           activation_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
           forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
-          output_ptr, logger, intemediate_tensor_indexes);
+          output_ptr, logger, intemediate_tensor_indexes, error_reporter);
     }
   }
 }
@@ -489,8 +492,8 @@ struct OpData {
 // Resize the output, state tensors based on the sizes of the input tensors.
 // Allocate a temporary scratch tensor. Also check that the sizes of the input
 // tensors match each other.
-TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node,
-                       Logger* logger) {
+TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
+                       ErrorReporter* error_reporter) {
   const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
 
   const TfLiteTensor* input =
@@ -585,7 +588,7 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node,
           projection_bias, params, /*forward_sequence=*/true,
           /*time_major=*/true,
           /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
-          output, logger, intemediate_tensor_indexes);
+          output, logger, intemediate_tensor_indexes, error_reporter);
     }
     case kTfLiteUInt8:
     case kTfLiteInt8:
@@ -598,8 +601,9 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node,
 }  // namespace
 
 TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
-                                 Logger* logger) {
-  return lstm_eval(context, node, logger);
+                                 Logger* logger,
+                                 ErrorReporter* error_reporter) {
+  return lstm_eval(context, node, logger, error_reporter);
 }
 
 }  // namespace builtin
 
@@ -24,7 +24,7 @@ namespace calibration {
 namespace builtin {
 
 TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
-                                 Logger* logger);
+                                 Logger* logger, ErrorReporter* error_reporter);
 
 }  // namespace builtin
 }  // namespace calibration
 
@@ -23,17 +23,17 @@ namespace tflite {
 namespace optimize {
 namespace calibration {
 
-TfLiteStatus MinMax::Update(const float* values, size_t tensor_size) {
+TfLiteStatus MinMax::Update(const float* values, size_t tensor_size,
+                            ErrorReporter* error_reporter) {
   if (tensor_size <= 0) return kTfLiteOk;
 
   // TODO(shashishekhar): Make it possible to use weighted/moving average.
   for (size_t i = 0; i < tensor_size; ++i) {
     if (std::isnan(values[i])) {
-      // TODO(suharshs): Propagate ErrorReporter here.
-      TFLITE_LOG(tflite::TFLITE_LOG_ERROR,
-                 "Model resulted in Nan value during calibration. Please "
-                 "make sure model results in all real-values during "
-                 "inference with provided dataset.");
+      error_reporter->Report(
+          "Model resulted in Nan value during calibration. Please "
+          "make sure model results in all real-values during "
+          "inference with provided dataset.");
       return kTfLiteError;
     }
   }
 
@@ -19,6 +19,7 @@ limitations under the License.
 #include <unordered_map>
 
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/error_reporter.h"
 
 namespace tflite {
 namespace optimize {
@@ -26,7 +27,8 @@ namespace calibration {
 
 class MinMax {
  public:
-  TfLiteStatus Update(const float* values, size_t tensor_size);
+  TfLiteStatus Update(const float* values, size_t tensor_size,
+                      ErrorReporter* error_reporter);
 
   bool HasValues() const { return has_values_; }
 
@@ -48,9 +50,10 @@ class Logger {
  public:
   // Log the value for tensor at |tensor_index| which has |tensor_values|
   TfLiteStatus LogTensorValue(int tensor_index, const float* tensor_values,
-                              size_t tensor_size) {
-    return tensor_id_to_stats_map_[tensor_index].Update(tensor_values,
-                                                        tensor_size);
+                              size_t tensor_size,
+                              ErrorReporter* error_reporter) {
+    return tensor_id_to_stats_map_[tensor_index].Update(
+        tensor_values, tensor_size, error_reporter);
   }
 
   // Returns a map from tensor_index -> observed min max values.
 
@@ -57,9 +57,11 @@ class Calibrator {
  public:
   Calibrator(const std::unordered_map<const TfLiteNode*, OperatorInfo>&
                  node_ptr_opinfo_map,
-             std::unique_ptr<LoggingOpResolver> logging_op_resolver)
+             std::unique_ptr<LoggingOpResolver> logging_op_resolver,
+             ErrorReporter* error_reporter)
       : node_ptr_opinfo_map_(node_ptr_opinfo_map),
-        logging_op_resolver_(std::move(logging_op_resolver)) {
+        logging_op_resolver_(std::move(logging_op_resolver)),
+        error_reporter_(error_reporter) {
     logger_ = absl::make_unique<Logger>();
   }
 
@@ -69,6 +71,9 @@ class Calibrator {
   // Gets the instance of logger associated with the current context.
   Logger* GetLogger() const { return logger_.get(); }
 
+  // Gets the error reporter.
+  ErrorReporter* GetErrorReporter() const { return error_reporter_; }
+
   // Gets the operator information about the given TfLiteNode.
   const OperatorInfo& GetOpInfo(const TfLiteNode* node) const {
     return node_ptr_opinfo_map_.at(node);
@@ -79,6 +84,7 @@ class Calibrator {
   std::unique_ptr<LoggingOpResolver> logging_op_resolver_;
   const std::unordered_map<int, OperatorInfo> index_opinfo_;
   std::unique_ptr<Logger> logger_;
+  ErrorReporter* error_reporter_;
 };
 
 KernelEvalFuncPtr Calibrator::GetKernelInvoke(const TfLiteNode* node) const {
@@ -147,7 +153,7 @@ class GlobalCalibratorRegistry {
       return kTfLiteError;
     }
     auto calibrator = absl::make_unique<Calibrator>(
-        node_to_opinfo, std::move(logging_op_resolver));
+        node_to_opinfo, std::move(logging_op_resolver), reporter);
     calibrator_registry_[context] = std::move(calibrator);
     *calibrator_ptr = calibrator_registry_.at(context).get();
     return kTfLiteOk;
@@ -189,18 +195,20 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
   auto kernel_invoke = calibrator->GetKernelInvoke(node);
   auto logger = calibrator->GetLogger();
   auto op_info = calibrator->GetOpInfo(node);
+  auto error_reporter = calibrator->GetErrorReporter();
 
   for (int i : op_info.loggable_inputs) {
     auto tensor = context->tensors[i];
-    TF_LITE_ENSURE_STATUS(
-        logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)));
+    TF_LITE_ENSURE_STATUS(logger->LogTensorValue(
+        i, tensor.data.f, tensor.bytes / sizeof(float), error_reporter));
   }
   auto kernel_invoke_intermediate = GetLoggingEvalFunc(context, node);
   TfLiteStatus status;
   if (kernel_invoke_intermediate == nullptr) {
     status = kernel_invoke(context, node);
   } else {
-    status = kernel_invoke_intermediate(context, node, calibrator->GetLogger());
+    status = kernel_invoke_intermediate(context, node, calibrator->GetLogger(),
+                                        error_reporter);
   }
 
   // TODO(shashishekhar): An intermediate tensor in graph will get logged twice
@@ -212,14 +220,14 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
   // cell.
   for (int i : op_info.loggable_inputs) {
     auto tensor = context->tensors[i];
-    TF_LITE_ENSURE_STATUS(
-        logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)));
+    TF_LITE_ENSURE_STATUS(logger->LogTensorValue(
+        i, tensor.data.f, tensor.bytes / sizeof(float), error_reporter));
   }
 
   for (int i : op_info.loggable_outputs) {
     auto tensor = context->tensors[i];
-    TF_LITE_ENSURE_STATUS(
-        logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)));
+    TF_LITE_ENSURE_STATUS(logger->LogTensorValue(
+        i, tensor.data.f, tensor.bytes / sizeof(float), error_reporter));
   }
 
   return status;
 
@@ -24,7 +24,8 @@ namespace calibration {
 
 typedef TfLiteStatus (*logging_kernel_func_ptr)(TfLiteContext* context,
                                                 TfLiteNode* node,
-                                                Logger* logger);
+                                                Logger* logger,
+                                                ErrorReporter* error_reporter);
 
 }  // namespace calibration
 }  // namespace optimize