Propagate nan-value error through ErrorReporter.

PiperOrigin-RevId: 289700382
Change-Id: I0a3719f0cf713268db22b3e8dcbdf9ae143ed280
Author: Jian Li (committed by TensorFlower Gardener)
Date: 2020-01-14 11:58:56 -08:00
Commit: a250518f0d
Parent: cff8012de1
7 changed files with 55 additions and 36 deletions
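
The change threads the interpreter's ErrorReporter from the Calibrator down through LoggingEval, Logger::LogTensorValue, and MinMax::Update, so the NaN diagnostic raised during calibration is reported through the caller-supplied reporter instead of only being written to the TFLITE_LOG sink (the kTfLiteError return was already in place). Below is a minimal, self-contained sketch of that pattern; Reporter, UpdateStats, and Status are illustrative stand-ins, not the TFLite declarations touched in this diff.

// Minimal, self-contained sketch of the propagation pattern (illustrative
// stand-in types, not the actual TFLite declarations).
#include <cmath>
#include <cstdarg>
#include <cstddef>
#include <cstdio>

enum class Status { kOk, kError };

// Stand-in for an ErrorReporter-style interface: errors go to a
// caller-supplied reporter rather than a global logging macro.
class Reporter {
 public:
  virtual ~Reporter() = default;
  virtual void Report(const char* format, ...) {
    va_list args;
    va_start(args, format);
    std::vfprintf(stderr, format, args);
    std::fputc('\n', stderr);
    va_end(args);
  }
};

// Stand-in for MinMax::Update: the reporter is now an explicit parameter, so
// the NaN diagnostic reaches whoever constructed the calibration pipeline.
Status UpdateStats(const float* values, std::size_t n, Reporter* reporter) {
  for (std::size_t i = 0; i < n; ++i) {
    if (std::isnan(values[i])) {
      reporter->Report("Model resulted in Nan value during calibration.");
      return Status::kError;
    }
  }
  return Status::kOk;
}

int main() {
  Reporter reporter;
  const float good[] = {0.5f, 1.0f};
  const float bad[] = {0.5f, NAN};
  // kOk for finite values; kError (after reporting) once a NaN is seen.
  const bool as_expected =
      UpdateStats(good, 2, &reporter) == Status::kOk &&
      UpdateStats(bad, 2, &reporter) == Status::kError;
  return as_expected ? 0 : 1;
}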

@@ -18,6 +18,7 @@ cc_library(
         ":calibration_logger",
         "//tensorflow/lite:framework",
         "//tensorflow/lite/c:common",
+        "//tensorflow/lite/core/api",
         "//tensorflow/lite/kernels:kernel_util",
         "//tensorflow/lite/kernels:lstm_shared",
         "//tensorflow/lite/kernels:op_macros",
@@ -120,8 +121,10 @@ cc_library(
     hdrs = ["calibration_logger.h"],
     copts = tflite_copts(),
     deps = [
+        "//tensorflow/lite:framework",
         "//tensorflow/lite:minimal_logging",
         "//tensorflow/lite/c:common",
+        "//tensorflow/lite/core/api",
     ],
 )

@@ -19,6 +19,7 @@ limitations under the License.
 #include <vector>
 #include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/core/api/error_reporter.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
@@ -64,7 +65,8 @@ inline void LstmStepWithAuxInput(
     float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
     float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
     float* output_ptr, Logger* logger,
-    const std::vector<int>& intemediate_tensor_indexes) {
+    const std::vector<int>& intemediate_tensor_indexes,
+    ErrorReporter* error_reporter) {
   // Since we have already checked that weights are all there or none, we can
   // check the existence of only one to the get the condition.
   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
@@ -158,7 +160,7 @@ inline void LstmStepWithAuxInput(
     }
     if (is_layer_norm_lstm) {
       logger->LogTensorValue(intemediate_tensor_indexes[0], input_gate_scratch,
-                             n_cell * n_batch);
+                             n_cell * n_batch, error_reporter);
       tensor_utils::MeanStddevNormalization(
           input_gate_scratch, input_gate_scratch, n_cell, n_batch);
       tensor_utils::VectorBatchVectorCwiseProduct(
@@ -179,7 +181,7 @@ inline void LstmStepWithAuxInput(
     }
     if (is_layer_norm_lstm) {
       logger->LogTensorValue(intemediate_tensor_indexes[1], forget_gate_scratch,
-                             n_cell * n_batch);
+                             n_cell * n_batch, error_reporter);
       tensor_utils::MeanStddevNormalization(forget_gate_scratch,
                                             forget_gate_scratch, n_cell, n_batch);
       tensor_utils::VectorBatchVectorCwiseProduct(
@@ -196,7 +198,7 @@ inline void LstmStepWithAuxInput(
                                          n_batch * n_cell, cell_state_ptr);
     if (is_layer_norm_lstm) {
       logger->LogTensorValue(intemediate_tensor_indexes[2], cell_scratch,
-                             n_cell * n_batch);
+                             n_cell * n_batch, error_reporter);
       tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
                                             n_batch);
       tensor_utils::VectorBatchVectorCwiseProduct(
@@ -229,7 +231,7 @@ inline void LstmStepWithAuxInput(
     }
     if (is_layer_norm_lstm) {
       logger->LogTensorValue(intemediate_tensor_indexes[3], output_gate_scratch,
-                             n_cell * n_batch);
+                             n_cell * n_batch, error_reporter);
       tensor_utils::MeanStddevNormalization(output_gate_scratch,
                                             output_gate_scratch, n_cell, n_batch);
       tensor_utils::VectorBatchVectorCwiseProduct(
@@ -246,7 +248,7 @@ inline void LstmStepWithAuxInput(
                                        n_batch * n_cell, output_gate_scratch);
     logger->LogTensorValue(intemediate_tensor_indexes[4], output_gate_scratch,
-                           n_cell * n_batch);
+                           n_cell * n_batch, error_reporter);
     const bool use_projection_weight = (projection_weights_ptr != nullptr);
     const bool use_projection_bias = (projection_bias_ptr != nullptr);
@@ -317,7 +319,8 @@ TfLiteStatus EvalFloat(
     int output_offset, TfLiteTensor* scratch_buffer,
     TfLiteTensor* activation_state, TfLiteTensor* cell_state,
     TfLiteTensor* output, Logger* logger,
-    const std::vector<int>& intemediate_tensor_indexes) {
+    const std::vector<int>& intemediate_tensor_indexes,
+    ErrorReporter* error_reporter) {
   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
   int max_time, n_batch;
   if (input->dims->size == 3) {
@@ -404,7 +407,7 @@ TfLiteStatus EvalFloat(
           GetTensorData<float>(activation_state),
           GetTensorData<float>(cell_state), input_gate_scratch,
           forget_gate_scratch, cell_scratch, output_gate_scratch,
-          output_ptr_time, logger, intemediate_tensor_indexes);
+          output_ptr_time, logger, intemediate_tensor_indexes, error_reporter);
     }
   } else {
     for (int b = 0; b < n_batch; b++) {
@@ -465,7 +468,7 @@ TfLiteStatus EvalFloat(
           n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
           activation_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
           forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
-          output_ptr, logger, intemediate_tensor_indexes);
+          output_ptr, logger, intemediate_tensor_indexes, error_reporter);
       }
     }
   }
@@ -489,8 +492,8 @@ struct OpData {
 // Resize the output, state tensors based on the sizes of the input tensors.
 // Allocate a temporary scratch tensor. Also check that the sizes of the input
 // tensors match each other.
-TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node,
-                       Logger* logger) {
+TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
+                       ErrorReporter* error_reporter) {
   const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
   const TfLiteTensor* input =
@@ -585,7 +588,7 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node,
           projection_bias, params, /*forward_sequence=*/true,
           /*time_major=*/true,
           /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
-          output, logger, intemediate_tensor_indexes);
+          output, logger, intemediate_tensor_indexes, error_reporter);
     }
     case kTfLiteUInt8:
     case kTfLiteInt8:
@@ -598,8 +601,9 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node,
 }  // namespace
 TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
-                                 Logger* logger) {
-  return lstm_eval(context, node, logger);
+                                 Logger* logger,
+                                 ErrorReporter* error_reporter) {
+  return lstm_eval(context, node, logger, error_reporter);
 }
 }  // namespace builtin

@@ -24,7 +24,7 @@ namespace calibration {
 namespace builtin {
 TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
-                                 Logger* logger);
+                                 Logger* logger, ErrorReporter* error_reporter);
 }  // namespace builtin
 }  // namespace calibration

@@ -23,17 +23,17 @@ namespace tflite {
 namespace optimize {
 namespace calibration {
-TfLiteStatus MinMax::Update(const float* values, size_t tensor_size) {
+TfLiteStatus MinMax::Update(const float* values, size_t tensor_size,
+                            ErrorReporter* error_reporter) {
   if (tensor_size <= 0) return kTfLiteOk;
   // TODO(shashishekhar): Make it possible to use weighted/moving average.
   for (size_t i = 0; i < tensor_size; ++i) {
     if (std::isnan(values[i])) {
-      // TODO(suharshs): Propagate ErrorReporter here.
-      TFLITE_LOG(tflite::TFLITE_LOG_ERROR,
-                 "Model resulted in Nan value during calibration. Please "
-                 "make sure model results in all real-values during "
-                 "inference with provided dataset.");
+      error_reporter->Report(
+          "Model resulted in Nan value during calibration. Please "
+          "make sure model results in all real-values during "
+          "inference with provided dataset.");
       return kTfLiteError;
     }
   }
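
Because the message now goes through ErrorReporter::Report, a caller that drives calibration can intercept it with its own reporter rather than losing it to the platform logging sink. A hedged sketch follows, assuming only the Report(const char*, va_list) override declared in tensorflow/lite/core/api/error_reporter.h; the BufferedReporter class itself is hypothetical.

#include <cstdarg>
#include <cstdio>
#include <string>

#include "tensorflow/lite/core/api/error_reporter.h"

// Hypothetical reporter that buffers messages (such as the NaN diagnostic
// above) so a calibration driver can surface them through its own channel.
class BufferedReporter : public tflite::ErrorReporter {
 public:
  int Report(const char* format, va_list args) override {
    char buf[1024];
    const int written = std::vsnprintf(buf, sizeof(buf), format, args);
    if (written > 0) {
      messages_ += buf;
      messages_ += '\n';
    }
    return written;
  }
  const std::string& messages() const { return messages_; }

 private:
  std::string messages_;
};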

@@ -19,6 +19,7 @@ limitations under the License.
 #include <unordered_map>
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/error_reporter.h"
 namespace tflite {
 namespace optimize {
@@ -26,7 +27,8 @@ namespace calibration {
 class MinMax {
  public:
-  TfLiteStatus Update(const float* values, size_t tensor_size);
+  TfLiteStatus Update(const float* values, size_t tensor_size,
+                      ErrorReporter* error_reporter);
   bool HasValues() const { return has_values_; }
@@ -48,9 +50,10 @@ class Logger {
  public:
   // Log the value for tensor at |tensor_index| which has |tensor_values|
   TfLiteStatus LogTensorValue(int tensor_index, const float* tensor_values,
-                              size_t tensor_size) {
-    return tensor_id_to_stats_map_[tensor_index].Update(tensor_values,
-                                                        tensor_size);
+                              size_t tensor_size,
+                              ErrorReporter* error_reporter) {
+    return tensor_id_to_stats_map_[tensor_index].Update(
+        tensor_values, tensor_size, error_reporter);
   }
   // Returns a map from tensor_index -> observed min max values.

@@ -57,9 +57,11 @@ class Calibrator {
  public:
   Calibrator(const std::unordered_map<const TfLiteNode*, OperatorInfo>&
                  node_ptr_opinfo_map,
-             std::unique_ptr<LoggingOpResolver> logging_op_resolver)
+             std::unique_ptr<LoggingOpResolver> logging_op_resolver,
+             ErrorReporter* error_reporter)
       : node_ptr_opinfo_map_(node_ptr_opinfo_map),
-        logging_op_resolver_(std::move(logging_op_resolver)) {
+        logging_op_resolver_(std::move(logging_op_resolver)),
+        error_reporter_(error_reporter) {
     logger_ = absl::make_unique<Logger>();
   }
@@ -69,6 +71,9 @@ class Calibrator {
   // Gets the instance of logger associated with the current context.
   Logger* GetLogger() const { return logger_.get(); }
+  // Gets the error reporter.
+  ErrorReporter* GetErrorReporter() const { return error_reporter_; }
   // Gets the operator information about the given TfLiteNode.
   const OperatorInfo& GetOpInfo(const TfLiteNode* node) const {
     return node_ptr_opinfo_map_.at(node);
@@ -79,6 +84,7 @@ class Calibrator {
   std::unique_ptr<LoggingOpResolver> logging_op_resolver_;
   const std::unordered_map<int, OperatorInfo> index_opinfo_;
   std::unique_ptr<Logger> logger_;
+  ErrorReporter* error_reporter_;
 };
 KernelEvalFuncPtr Calibrator::GetKernelInvoke(const TfLiteNode* node) const {
@@ -147,7 +153,7 @@ class GlobalCalibratorRegistry {
       return kTfLiteError;
     }
     auto calibrator = absl::make_unique<Calibrator>(
-        node_to_opinfo, std::move(logging_op_resolver));
+        node_to_opinfo, std::move(logging_op_resolver), reporter);
     calibrator_registry_[context] = std::move(calibrator);
     *calibrator_ptr = calibrator_registry_.at(context).get();
     return kTfLiteOk;
@@ -189,18 +195,20 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
   auto kernel_invoke = calibrator->GetKernelInvoke(node);
   auto logger = calibrator->GetLogger();
   auto op_info = calibrator->GetOpInfo(node);
+  auto error_reporter = calibrator->GetErrorReporter();
   for (int i : op_info.loggable_inputs) {
     auto tensor = context->tensors[i];
-    TF_LITE_ENSURE_STATUS(
-        logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)));
+    TF_LITE_ENSURE_STATUS(logger->LogTensorValue(
+        i, tensor.data.f, tensor.bytes / sizeof(float), error_reporter));
   }
   auto kernel_invoke_intermediate = GetLoggingEvalFunc(context, node);
   TfLiteStatus status;
   if (kernel_invoke_intermediate == nullptr) {
     status = kernel_invoke(context, node);
   } else {
-    status = kernel_invoke_intermediate(context, node, calibrator->GetLogger());
+    status = kernel_invoke_intermediate(context, node, calibrator->GetLogger(),
+                                        error_reporter);
   }
   // TODO(shashishekhar): An intermediate tensor in graph will get logged twice
@@ -212,14 +220,14 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
   // cell.
   for (int i : op_info.loggable_inputs) {
     auto tensor = context->tensors[i];
-    TF_LITE_ENSURE_STATUS(
-        logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)));
+    TF_LITE_ENSURE_STATUS(logger->LogTensorValue(
+        i, tensor.data.f, tensor.bytes / sizeof(float), error_reporter));
   }
   for (int i : op_info.loggable_outputs) {
     auto tensor = context->tensors[i];
-    TF_LITE_ENSURE_STATUS(
-        logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float)));
+    TF_LITE_ENSURE_STATUS(logger->LogTensorValue(
+        i, tensor.data.f, tensor.bytes / sizeof(float), error_reporter));
   }
   return status;

@@ -24,7 +24,8 @@ namespace calibration {
 typedef TfLiteStatus (*logging_kernel_func_ptr)(TfLiteContext* context,
                                                 TfLiteNode* node,
-                                                Logger* logger);
+                                                Logger* logger,
+                                                ErrorReporter* error_reporter);
 }  // namespace calibration
 }  // namespace optimize