Propagate NaN-value error through ErrorReporter.

PiperOrigin-RevId: 289700382
Change-Id: I0a3719f0cf713268db22b3e8dcbdf9ae143ed280
This commit is contained in:
Jian Li 2020-01-14 11:58:56 -08:00 committed by TensorFlower Gardener
parent cff8012de1
commit a250518f0d
7 changed files with 55 additions and 36 deletions

View File

@ -18,6 +18,7 @@ cc_library(
":calibration_logger",
"//tensorflow/lite:framework",
"//tensorflow/lite/c:common",
"//tensorflow/lite/core/api",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/kernels:lstm_shared",
"//tensorflow/lite/kernels:op_macros",
@ -120,8 +121,10 @@ cc_library(
hdrs = ["calibration_logger.h"],
copts = tflite_copts(),
deps = [
"//tensorflow/lite:framework",
"//tensorflow/lite:minimal_logging",
"//tensorflow/lite/c:common",
"//tensorflow/lite/core/api",
],
)

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
@ -64,7 +65,8 @@ inline void LstmStepWithAuxInput(
float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
float* output_ptr, Logger* logger,
const std::vector<int>& intemediate_tensor_indexes) {
const std::vector<int>& intemediate_tensor_indexes,
ErrorReporter* error_reporter) {
// Since we have already checked that weights are all there or none, we can
// check the existence of only one to the get the condition.
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
@ -158,7 +160,7 @@ inline void LstmStepWithAuxInput(
}
if (is_layer_norm_lstm) {
logger->LogTensorValue(intemediate_tensor_indexes[0], input_gate_scratch,
n_cell * n_batch);
n_cell * n_batch, error_reporter);
tensor_utils::MeanStddevNormalization(
input_gate_scratch, input_gate_scratch, n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
@ -179,7 +181,7 @@ inline void LstmStepWithAuxInput(
}
if (is_layer_norm_lstm) {
logger->LogTensorValue(intemediate_tensor_indexes[1], forget_gate_scratch,
n_cell * n_batch);
n_cell * n_batch, error_reporter);
tensor_utils::MeanStddevNormalization(forget_gate_scratch,
forget_gate_scratch, n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
@ -196,7 +198,7 @@ inline void LstmStepWithAuxInput(
n_batch * n_cell, cell_state_ptr);
if (is_layer_norm_lstm) {
logger->LogTensorValue(intemediate_tensor_indexes[2], cell_scratch,
n_cell * n_batch);
n_cell * n_batch, error_reporter);
tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
@ -229,7 +231,7 @@ inline void LstmStepWithAuxInput(
}
if (is_layer_norm_lstm) {
logger->LogTensorValue(intemediate_tensor_indexes[3], output_gate_scratch,
n_cell * n_batch);
n_cell * n_batch, error_reporter);
tensor_utils::MeanStddevNormalization(output_gate_scratch,
output_gate_scratch, n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
@ -246,7 +248,7 @@ inline void LstmStepWithAuxInput(
n_batch * n_cell, output_gate_scratch);
logger->LogTensorValue(intemediate_tensor_indexes[4], output_gate_scratch,
n_cell * n_batch);
n_cell * n_batch, error_reporter);
const bool use_projection_weight = (projection_weights_ptr != nullptr);
const bool use_projection_bias = (projection_bias_ptr != nullptr);
@ -317,7 +319,8 @@ TfLiteStatus EvalFloat(
int output_offset, TfLiteTensor* scratch_buffer,
TfLiteTensor* activation_state, TfLiteTensor* cell_state,
TfLiteTensor* output, Logger* logger,
const std::vector<int>& intemediate_tensor_indexes) {
const std::vector<int>& intemediate_tensor_indexes,
ErrorReporter* error_reporter) {
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
int max_time, n_batch;
if (input->dims->size == 3) {
@ -404,7 +407,7 @@ TfLiteStatus EvalFloat(
GetTensorData<float>(activation_state),
GetTensorData<float>(cell_state), input_gate_scratch,
forget_gate_scratch, cell_scratch, output_gate_scratch,
output_ptr_time, logger, intemediate_tensor_indexes);
output_ptr_time, logger, intemediate_tensor_indexes, error_reporter);
}
} else {
for (int b = 0; b < n_batch; b++) {
@ -465,7 +468,7 @@ TfLiteStatus EvalFloat(
n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
activation_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
output_ptr, logger, intemediate_tensor_indexes);
output_ptr, logger, intemediate_tensor_indexes, error_reporter);
}
}
}
@ -489,8 +492,8 @@ struct OpData {
// Resize the output, state tensors based on the sizes of the input tensors.
// Allocate a temporary scratch tensor. Also check that the sizes of the input
// tensors match each other.
TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node,
Logger* logger) {
TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
ErrorReporter* error_reporter) {
const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
const TfLiteTensor* input =
@ -585,7 +588,7 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node,
projection_bias, params, /*forward_sequence=*/true,
/*time_major=*/true,
/*output_offset=*/0, scratch_buffer, activation_state, cell_state,
output, logger, intemediate_tensor_indexes);
output, logger, intemediate_tensor_indexes, error_reporter);
}
case kTfLiteUInt8:
case kTfLiteInt8:
@ -598,8 +601,9 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node,
} // namespace
TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
Logger* logger) {
return lstm_eval(context, node, logger);
Logger* logger,
ErrorReporter* error_reporter) {
return lstm_eval(context, node, logger, error_reporter);
}
} // namespace builtin

View File

@ -24,7 +24,7 @@ namespace calibration {
namespace builtin {
TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
Logger* logger);
Logger* logger, ErrorReporter* error_reporter);
} // namespace builtin
} // namespace calibration

View File

@ -23,17 +23,17 @@ namespace tflite {
namespace optimize {
namespace calibration {
TfLiteStatus MinMax::Update(const float* values, size_t tensor_size) {
TfLiteStatus MinMax::Update(const float* values, size_t tensor_size,
ErrorReporter* error_reporter) {
if (tensor_size <= 0) return kTfLiteOk;
// TODO(shashishekhar): Make it possible to use weighted/moving average.
for (size_t i = 0; i < tensor_size; ++i) {
if (std::isnan(values[i])) {
// TODO(suharshs): Propagate ErrorReporter here.
TFLITE_LOG(tflite::TFLITE_LOG_ERROR,
"Model resulted in Nan value during calibration. Please "
"make sure model results in all real-values during "
"inference with provided dataset.");
error_reporter->Report(
"Model resulted in Nan value during calibration. Please "
"make sure model results in all real-values during "
"inference with provided dataset.");
return kTfLiteError;
}
}

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
namespace tflite {
namespace optimize {
@ -26,7 +27,8 @@ namespace calibration {
class MinMax {
public:
TfLiteStatus Update(const float* values, size_t tensor_size);
TfLiteStatus Update(const float* values, size_t tensor_size,
ErrorReporter* error_reporter);
bool HasValues() const { return has_values_; }
@ -48,9 +50,10 @@ class Logger {
public:
// Log the value for tensor at |tensor_index| which has |tensor_values|
TfLiteStatus LogTensorValue(int tensor_index, const float* tensor_values,
size_t tensor_size) {
return tensor_id_to_stats_map_[tensor_index].Update(tensor_values,
tensor_size);
size_t tensor_size,
ErrorReporter* error_reporter) {
return tensor_id_to_stats_map_[tensor_index].Update(
tensor_values, tensor_size, error_reporter);
}
// Returns a map from tensor_index -> observed min max values.

View File

@ -57,9 +57,11 @@ class Calibrator {
public:
Calibrator(const std::unordered_map<const TfLiteNode*, OperatorInfo>&
node_ptr_opinfo_map,
std::unique_ptr<LoggingOpResolver> logging_op_resolver)
std::unique_ptr<LoggingOpResolver> logging_op_resolver,
ErrorReporter* error_reporter)
: node_ptr_opinfo_map_(node_ptr_opinfo_map),
logging_op_resolver_(std::move(logging_op_resolver)) {
logging_op_resolver_(std::move(logging_op_resolver)),
error_reporter_(error_reporter) {
logger_ = absl::make_unique<Logger>();
}
@ -69,6 +71,9 @@ class Calibrator {
// Gets the instance of logger associated with the current context.
Logger* GetLogger() const { return logger_.get(); }
// Gets the error reporter.
ErrorReporter* GetErrorReporter() const { return error_reporter_; }
// Gets the operator information about the given TfLiteNode.
const OperatorInfo& GetOpInfo(const TfLiteNode* node) const {
return node_ptr_opinfo_map_.at(node);
@ -79,6 +84,7 @@ class Calibrator {
std::unique_ptr<LoggingOpResolver> logging_op_resolver_;
const std::unordered_map<int, OperatorInfo> index_opinfo_;
std::unique_ptr<Logger> logger_;
ErrorReporter* error_reporter_;
};
KernelEvalFuncPtr Calibrator::GetKernelInvoke(const TfLiteNode* node) const {
@ -147,7 +153,7 @@ class GlobalCalibratorRegistry {
return kTfLiteError;
}
auto calibrator = absl::make_unique<Calibrator>(
node_to_opinfo, std::move(logging_op_resolver));
node_to_opinfo, std::move(logging_op_resolver), reporter);
calibrator_registry_[context] = std::move(calibrator);
*calibrator_ptr = calibrator_registry_.at(context).get();
return kTfLiteOk;
@ -189,18 +195,20 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
auto kernel_invoke = calibrator->GetKernelInvoke(node);
auto logger = calibrator->GetLogger();
auto op_info = calibrator->GetOpInfo(node);
auto error_reporter = calibrator->GetErrorReporter();
for (int i : op_info.loggable_inputs) {
auto tensor = context->tensors[i];
TF_LITE_ENSURE_STATUS(
logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)));
TF_LITE_ENSURE_STATUS(logger->LogTensorValue(
i, tensor.data.f, tensor.bytes / sizeof(float), error_reporter));
}
auto kernel_invoke_intermediate = GetLoggingEvalFunc(context, node);
TfLiteStatus status;
if (kernel_invoke_intermediate == nullptr) {
status = kernel_invoke(context, node);
} else {
status = kernel_invoke_intermediate(context, node, calibrator->GetLogger());
status = kernel_invoke_intermediate(context, node, calibrator->GetLogger(),
error_reporter);
}
// TODO(shashishekhar): An intermediate tensor in graph will get logged twice
@ -212,14 +220,14 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
// cell.
for (int i : op_info.loggable_inputs) {
auto tensor = context->tensors[i];
TF_LITE_ENSURE_STATUS(
logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)));
TF_LITE_ENSURE_STATUS(logger->LogTensorValue(
i, tensor.data.f, tensor.bytes / sizeof(float), error_reporter));
}
for (int i : op_info.loggable_outputs) {
auto tensor = context->tensors[i];
TF_LITE_ENSURE_STATUS(
logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)));
TF_LITE_ENSURE_STATUS(logger->LogTensorValue(
i, tensor.data.f, tensor.bytes / sizeof(float), error_reporter));
}
return status;

View File

@ -24,7 +24,8 @@ namespace calibration {
typedef TfLiteStatus (*logging_kernel_func_ptr)(TfLiteContext* context,
TfLiteNode* node,
Logger* logger);
Logger* logger,
ErrorReporter* error_reporter);
} // namespace calibration
} // namespace optimize