Addressed reviewer's comments.
Change-Id: If8022418adcc6b6a93354625476f32155dd53d36
This commit is contained in:
parent
829277a571
commit
165c8c5dbd
@ -120,8 +120,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename integer_type>
|
template <typename integer_type>
|
||||||
TfLiteStatus EvalSignedInt(TfLiteContext* context, const PadContext& op_context,
|
TfLiteStatus EvalInt(TfLiteContext* context, const PadContext& op_context,
|
||||||
const tflite::PadParams& op_params) {
|
const tflite::PadParams& op_params) {
|
||||||
integer_type pad_value;
|
integer_type pad_value;
|
||||||
if (op_context.constant_values == nullptr) {
|
if (op_context.constant_values == nullptr) {
|
||||||
// Quantized Pad requires that 0 is represented in the quantized
|
// Quantized Pad requires that 0 is represented in the quantized
|
||||||
@ -211,43 +211,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case kTfLiteUInt8: {
|
case kTfLiteUInt8: {
|
||||||
uint8_t pad_value;
|
EvalInt<uint8_t>(context, op_context, op_params);
|
||||||
if (op_context.constant_values == nullptr) {
|
|
||||||
// Quantized Pad requires that 0 is represented in the quantized
|
|
||||||
// range.
|
|
||||||
TF_LITE_ENSURE(context, op_context.output->params.zero_point >=
|
|
||||||
std::numeric_limits<uint8_t>::min());
|
|
||||||
TF_LITE_ENSURE(context, op_context.output->params.zero_point <=
|
|
||||||
std::numeric_limits<uint8_t>::max());
|
|
||||||
pad_value = static_cast<uint8_t>(op_context.output->params.zero_point);
|
|
||||||
} else {
|
|
||||||
// Quantized Pad requires that 'constant_values' is represented in the
|
|
||||||
// same quantized range as the input and output tensors.
|
|
||||||
TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point,
|
|
||||||
op_context.constant_values->params.zero_point);
|
|
||||||
TF_LITE_ENSURE_EQ(context, op_context.output->params.scale,
|
|
||||||
op_context.constant_values->params.scale);
|
|
||||||
pad_value = *GetTensorData<uint8_t>(op_context.constant_values);
|
|
||||||
}
|
|
||||||
if (kernel_type == kReference) {
|
|
||||||
if (op_context.resizing_category == ResizingCategory::kImageStyle) {
|
|
||||||
TF_LITE_PAD(reference_ops, PadImageStyle, uint8_t, pad_value);
|
|
||||||
} else {
|
|
||||||
TF_LITE_PAD(reference_ops, Pad, uint8_t, pad_value);
|
|
||||||
}
|
|
||||||
} else if (kernel_type == kGenericOptimized) {
|
|
||||||
if (op_context.resizing_category == ResizingCategory::kImageStyle) {
|
|
||||||
TF_LITE_PAD(optimized_ops, PadImageStyle, uint8_t, pad_value);
|
|
||||||
} else {
|
|
||||||
TF_LITE_PAD(optimized_ops, Pad, uint8_t, pad_value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} break;
|
} break;
|
||||||
case kTfLiteInt8: {
|
case kTfLiteInt8: {
|
||||||
EvalSignedInt<int8_t>(context, op_context, op_params);
|
EvalInt<int8_t>(context, op_context, op_params);
|
||||||
} break;
|
} break;
|
||||||
case kTfLiteInt16: {
|
case kTfLiteInt16: {
|
||||||
EvalSignedInt<int16_t>(context, op_context, op_params);
|
EvalInt<int16_t>(context, op_context, op_params);
|
||||||
} break;
|
} break;
|
||||||
case kTfLiteInt32: {
|
case kTfLiteInt32: {
|
||||||
int32_t pad_value =
|
int32_t pad_value =
|
||||||
|
Loading…
Reference in New Issue
Block a user