Addressed reviewer's comments.

Change-Id: If8022418adcc6b6a93354625476f32155dd53d36
This commit is contained in:
Elena Zhelezina 2020-06-04 11:05:48 +01:00
parent 829277a571
commit 165c8c5dbd

View File

@ -120,8 +120,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
template <typename integer_type>
TfLiteStatus EvalSignedInt(TfLiteContext* context, const PadContext& op_context,
const tflite::PadParams& op_params) {
TfLiteStatus EvalInt(TfLiteContext* context, const PadContext& op_context,
const tflite::PadParams& op_params) {
integer_type pad_value;
if (op_context.constant_values == nullptr) {
// Quantized Pad requires that 0 is represented in the quantized
@ -211,43 +211,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
} break;
case kTfLiteUInt8: {
uint8_t pad_value;
if (op_context.constant_values == nullptr) {
// Quantized Pad requires that 0 is represented in the quantized
// range.
TF_LITE_ENSURE(context, op_context.output->params.zero_point >=
std::numeric_limits<uint8_t>::min());
TF_LITE_ENSURE(context, op_context.output->params.zero_point <=
std::numeric_limits<uint8_t>::max());
pad_value = static_cast<uint8_t>(op_context.output->params.zero_point);
} else {
// Quantized Pad requires that 'constant_values' is represented in the
// same quantized range as the input and output tensors.
TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point,
op_context.constant_values->params.zero_point);
TF_LITE_ENSURE_EQ(context, op_context.output->params.scale,
op_context.constant_values->params.scale);
pad_value = *GetTensorData<uint8_t>(op_context.constant_values);
}
if (kernel_type == kReference) {
if (op_context.resizing_category == ResizingCategory::kImageStyle) {
TF_LITE_PAD(reference_ops, PadImageStyle, uint8_t, pad_value);
} else {
TF_LITE_PAD(reference_ops, Pad, uint8_t, pad_value);
}
} else if (kernel_type == kGenericOptimized) {
if (op_context.resizing_category == ResizingCategory::kImageStyle) {
TF_LITE_PAD(optimized_ops, PadImageStyle, uint8_t, pad_value);
} else {
TF_LITE_PAD(optimized_ops, Pad, uint8_t, pad_value);
}
}
EvalInt<uint8_t>(context, op_context, op_params);
} break;
case kTfLiteInt8: {
EvalSignedInt<int8_t>(context, op_context, op_params);
EvalInt<int8_t>(context, op_context, op_params);
} break;
case kTfLiteInt16: {
EvalSignedInt<int16_t>(context, op_context, op_params);
EvalInt<int16_t>(context, op_context, op_params);
} break;
case kTfLiteInt32: {
int32_t pad_value =