Separate out parse functionality into helper functions.

Ops in this change:
* Conv2D
* DepthwiseConv2D
* Reshape

PiperOrigin-RevId: 315424381
Change-Id: If2fb9187785eabd31b9d6588322cb70345650539
parent 6ddc7f8d99
commit 47b4145e68
@@ -62,6 +62,17 @@ class SafeBuiltinDataAllocator {
   BuiltinDataAllocator* allocator_;
 };

+// All the Parse functions take some pointers as params and this function has
+// the common DCHECKs to catch if any of those are nullptr.
+void CheckParsePointerParams(const Operator* op, ErrorReporter* error_reporter,
+                             BuiltinDataAllocator* allocator,
+                             void** builtin_data) {
+  TFLITE_DCHECK(op != nullptr);
+  TFLITE_DCHECK(error_reporter != nullptr);
+  TFLITE_DCHECK(allocator != nullptr);
+  TFLITE_DCHECK(builtin_data != nullptr);
+}
+
 // Copies the contents from the flatbuffer int vector `flatbuffer` into the
 // int array `buffer`. `flat_vector` and `buffer` represent the same
 // configuration operation for a given operation.
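Note: TFLITE_DCHECK, used by the new helper above, is a debug-only assertion whose definition is not part of this diff. Purely as an illustration of that style of macro (EXAMPLE_DCHECK below is a made-up name, not TFLite's actual definition), a debug-only pointer check might look like:

#include <cstdio>
#include <cstdlib>

// Illustration only -- not the actual TFLITE_DCHECK definition. In debug
// builds the check aborts with a message; with NDEBUG it compiles to nothing.
#ifndef NDEBUG
#define EXAMPLE_DCHECK(cond)                                    \
  do {                                                          \
    if (!(cond)) {                                              \
      std::fprintf(stderr, "Check failed: %s (%s:%d)\n", #cond, \
                   __FILE__, __LINE__);                         \
      std::abort();                                             \
    }                                                           \
  } while (0)
#else
#define EXAMPLE_DCHECK(cond) \
  do {                       \
  } while (0)
#endif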
@@ -109,6 +120,17 @@ TfLiteFusedActivation ConvertActivation(ActivationFunctionType activation) {
   return kTfLiteActNone;
 }

+// Converts the flatbuffer padding enum to what is used at runtime.
+TfLitePadding ConvertPadding(Padding padding) {
+  switch (padding) {
+    case Padding_SAME:
+      return kTfLitePaddingSame;
+    case Padding_VALID:
+      return kTfLitePaddingValid;
+  }
+  return kTfLitePaddingUnknown;
+}
+
 }  // namespace

 TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
@@ -155,6 +177,74 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
   }
 }

+TfLiteStatus ParseConv2D(const Operator* op, BuiltinOperator,
+                         ErrorReporter* error_reporter,
+                         BuiltinDataAllocator* allocator, void** builtin_data) {
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+  SafeBuiltinDataAllocator safe_allocator(allocator);
+  std::unique_ptr<TfLiteConvParams,
+                  SafeBuiltinDataAllocator::BuiltinDataDeleter>
+      params = safe_allocator.Allocate<TfLiteConvParams>();
+  TF_LITE_ENSURE(error_reporter, params != nullptr);
+
+  const Conv2DOptions* schema_params = op->builtin_options_as_Conv2DOptions();
+
+  if (schema_params != nullptr) {
+    params->padding = ConvertPadding(schema_params->padding());
+    params->stride_width = schema_params->stride_w();
+    params->stride_height = schema_params->stride_h();
+    params->activation =
+        ConvertActivation(schema_params->fused_activation_function());
+
+    params->dilation_width_factor = schema_params->dilation_w_factor();
+    params->dilation_height_factor = schema_params->dilation_h_factor();
+  } else {
+    // TODO(b/157480169): We should either return kTfLiteError or fill in some
+    // reasonable defaults in the params struct. We are not doing so until we
+    // better understand the ramifications of changing the legacy behavior.
+  }
+
+  *builtin_data = params.release();
+  return kTfLiteOk;
+}
+
+TfLiteStatus ParseDepthwiseConv2D(const Operator* op, BuiltinOperator,
+                                  ErrorReporter* error_reporter,
+                                  BuiltinDataAllocator* allocator,
+                                  void** builtin_data) {
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+  SafeBuiltinDataAllocator safe_allocator(allocator);
+
+  std::unique_ptr<TfLiteDepthwiseConvParams,
+                  SafeBuiltinDataAllocator::BuiltinDataDeleter>
+      params = safe_allocator.Allocate<TfLiteDepthwiseConvParams>();
+  TF_LITE_ENSURE(error_reporter, params != nullptr);
+
+  const DepthwiseConv2DOptions* schema_params =
+      op->builtin_options_as_DepthwiseConv2DOptions();
+
+  if (schema_params != nullptr) {
+    params->padding = ConvertPadding(schema_params->padding());
+    params->stride_width = schema_params->stride_w();
+    params->stride_height = schema_params->stride_h();
+    params->depth_multiplier = schema_params->depth_multiplier();
+    params->activation =
+        ConvertActivation(schema_params->fused_activation_function());
+
+    params->dilation_width_factor = schema_params->dilation_w_factor();
+    params->dilation_height_factor = schema_params->dilation_h_factor();
+  } else {
+    // TODO(b/157480169): We should either return kTfLiteError or fill in some
+    // reasonable defaults in the params struct. We are not doing so until we
+    // better understand the ramifications of changing the legacy behavior.
+  }
+
+  *builtin_data = params.release();
+  return kTfLiteOk;
+}
+
 // We have this parse function instead of directly returning kTfLiteOk from the
 // switch-case in ParseOpData because this function is used as part of the
 // selective registration for the OpResolver implementation in micro.
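For context, here is a minimal sketch (not part of this change; the function name and error handling are hypothetical) of how a caller consumes one of the new helpers: the helper allocates the op-specific params struct through the supplied allocator and returns it through builtin_data, so the caller casts the opaque pointer back to the concrete type.

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"

TfLiteStatus PrepareConv2DData(const tflite::Operator* op,
                               tflite::ErrorReporter* error_reporter,
                               tflite::BuiltinDataAllocator* allocator) {
  void* builtin_data = nullptr;
  TfLiteStatus status =
      tflite::ParseConv2D(op, tflite::BuiltinOperator_CONV_2D, error_reporter,
                          allocator, &builtin_data);
  if (status != kTfLiteOk) return status;

  // ParseConv2D allocated a TfLiteConvParams through `allocator`, so the
  // opaque pointer can be cast back to the concrete params struct.
  auto* params = static_cast<TfLiteConvParams*>(builtin_data);
  if (params->padding == kTfLitePaddingUnknown) return kTfLiteError;
  return kTfLiteOk;
}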
@@ -167,10 +257,7 @@ TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinOperator,
                                  ErrorReporter* error_reporter,
                                  BuiltinDataAllocator* allocator,
                                  void** builtin_data) {
-  TFLITE_DCHECK(op != nullptr);
-  TFLITE_DCHECK(error_reporter != nullptr);
-  TFLITE_DCHECK(allocator != nullptr);
-  TFLITE_DCHECK(builtin_data != nullptr);
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);

   SafeBuiltinDataAllocator safe_allocator(allocator);

@@ -212,6 +299,47 @@ TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinOperator,
   return kTfLiteOk;
 }

+TfLiteStatus ParseReshape(const Operator* op, BuiltinOperator,
+                          ErrorReporter* error_reporter,
+                          BuiltinDataAllocator* allocator,
+                          void** builtin_data) {
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+  SafeBuiltinDataAllocator safe_allocator(allocator);
+
+  std::unique_ptr<TfLiteReshapeParams,
+                  SafeBuiltinDataAllocator::BuiltinDataDeleter>
+      params = safe_allocator.Allocate<TfLiteReshapeParams>();
+  TF_LITE_ENSURE(error_reporter, params != nullptr);
+
+  const ReshapeOptions* schema_params = op->builtin_options_as_ReshapeOptions();
+
+  if (schema_params != nullptr) {
+    const flatbuffers::Vector<int32_t>* new_shape = schema_params->new_shape();
+    // TODO(b/147203660): We need to figure out when dynamic reshape
+    // (new_shape is a tensor) happens, why the option is not a nullptr.
+    // But nonetheless, we should only copy when new_shape is not a nullptr.
+    if (new_shape != nullptr) {
+      TF_LITE_ENSURE_STATUS(
+          FlatBufferIntVectorToArray(sizeof(params->shape), new_shape,
+                                     params->shape, error_reporter, "reshape"));
+      params->num_dimensions = new_shape->size();
+    } else {
+      // TODO(b/157480169) TODO(b/147203660): We should either return
+      // kTfLiteError or fill in some reasonable defaults in the params struct.
+      // We are not doing so until we better understand the ramifications of
+      // changing the legacy behavior.
+    }
+  } else {
+    // TODO(b/157480169): We should either return kTfLiteError or fill in some
+    // reasonable defaults in the params struct. We are not doing so until we
+    // better understand the ramifications of changing the legacy behavior.
+  }
+
+  *builtin_data = params.release();
+  return kTfLiteOk;
+}
+
 // We have this parse function instead of directly returning kTfLiteOk from the
 // switch-case in ParseOpData because this function is used as part of the
 // selective registration for the OpResolver implementation in micro.
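A small illustrative consumer of the parsed reshape params (again not from this commit): params->shape is a fixed-size array inside TfLiteReshapeParams and params->num_dimensions says how many leading entries were copied from the flatbuffer. The sketch ignores the -1 "infer this dimension" convention that reshape supports.

#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"

// Multiplies out the requested output dimensions from the parsed params.
int64_t NumElementsFromReshapeParams(const TfLiteReshapeParams* params) {
  int64_t num_elements = 1;
  for (int i = 0; i < params->num_dimensions; ++i) {
    num_elements *= params->shape[i];
  }
  return num_elements;
}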
@@ -224,10 +352,7 @@ TfLiteStatus ParseSoftmax(const Operator* op, BuiltinOperator,
                           ErrorReporter* error_reporter,
                           BuiltinDataAllocator* allocator,
                           void** builtin_data) {
-  TFLITE_DCHECK(op != nullptr);
-  TFLITE_DCHECK(error_reporter != nullptr);
-  TFLITE_DCHECK(allocator != nullptr);
-  TFLITE_DCHECK(builtin_data != nullptr);
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);

   SafeBuiltinDataAllocator safe_allocator(allocator);
   std::unique_ptr<TfLiteSoftmaxParams,
@@ -252,10 +377,7 @@ TfLiteStatus ParseSoftmax(const Operator* op, BuiltinOperator,
 TfLiteStatus ParseSvdf(const Operator* op, BuiltinOperator,
                        ErrorReporter* error_reporter,
                        BuiltinDataAllocator* allocator, void** builtin_data) {
-  TFLITE_DCHECK(op != nullptr);
-  TFLITE_DCHECK(error_reporter != nullptr);
-  TFLITE_DCHECK(allocator != nullptr);
-  TFLITE_DCHECK(builtin_data != nullptr);
+  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);

   SafeBuiltinDataAllocator safe_allocator(allocator);
   std::unique_ptr<TfLiteSVDFParams,
@@ -283,15 +405,6 @@ TfLiteStatus ParseSvdf(const Operator* op, BuiltinOperator,
 TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
                          ErrorReporter* error_reporter,
                          BuiltinDataAllocator* allocator, void** builtin_data) {
-  auto parse_padding = [](Padding padding) {
-    switch (padding) {
-      case Padding_SAME:
-        return kTfLitePaddingSame;
-      case Padding_VALID:
-        return kTfLitePaddingValid;
-    }
-    return kTfLitePaddingUnknown;
-  };
   auto parseLSHProjectionType = [](LSHProjectionType type) {
     switch (type) {
       case LSHProjectionType_SPARSE:
@@ -317,6 +430,15 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
   SafeBuiltinDataAllocator safe_allocator(allocator);
   *builtin_data = nullptr;
   switch (op_type) {
+    case BuiltinOperator_CONV_2D: {
+      return ParseConv2D(op, op_type, error_reporter, allocator, builtin_data);
+    }
+
+    case BuiltinOperator_DEPTHWISE_CONV_2D: {
+      return ParseDepthwiseConv2D(op, op_type, error_reporter, allocator,
+                                  builtin_data);
+    }
+
     case BuiltinOperator_DEQUANTIZE: {
       return ParseDequantize(op, op_type, error_reporter, allocator,
                              builtin_data);
@@ -332,6 +454,10 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
                                  builtin_data);
     }

+    case BuiltinOperator_RESHAPE: {
+      return ParseReshape(op, op_type, error_reporter, allocator, builtin_data);
+    }
+
     case BuiltinOperator_SOFTMAX: {
       return ParseSoftmax(op, op_type, error_reporter, allocator, builtin_data);
     }
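The comment repeated above ties these helpers to selective registration in the micro OpResolver. A hypothetical sketch of that idea (this is not TFLite Micro's actual API): an application that only uses CONV_2D and RESHAPE references just those two parse helpers, so the linker can drop the remaining parsers along with ParseOpData.

#include "tensorflow/lite/core/api/flatbuffer_conversions.h"

struct ParseEntry {
  tflite::BuiltinOperator op;
  TfLiteStatus (*parse)(const tflite::Operator*, tflite::BuiltinOperator,
                        tflite::ErrorReporter*, tflite::BuiltinDataAllocator*,
                        void**);
};

// Only the ops this hypothetical application actually needs.
constexpr ParseEntry kSupportedOps[] = {
    {tflite::BuiltinOperator_CONV_2D, tflite::ParseConv2D},
    {tflite::BuiltinOperator_RESHAPE, tflite::ParseReshape},
};

TfLiteStatus ParseSupportedOp(const tflite::Operator* op,
                              tflite::BuiltinOperator op_type,
                              tflite::ErrorReporter* error_reporter,
                              tflite::BuiltinDataAllocator* allocator,
                              void** builtin_data) {
  for (const ParseEntry& entry : kSupportedOps) {
    if (entry.op == op_type) {
      return entry.parse(op, op_type, error_reporter, allocator, builtin_data);
    }
  }
  return kTfLiteError;  // Unsupported operator for this build.
}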
@@ -340,22 +466,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
       return ParseSvdf(op, op_type, error_reporter, allocator, builtin_data);
     }

-    case BuiltinOperator_CONV_2D: {
-      auto params = safe_allocator.Allocate<TfLiteConvParams>();
-      TF_LITE_ENSURE(error_reporter, params != nullptr);
-      if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
-        params->padding = parse_padding(conv_params->padding());
-        params->stride_width = conv_params->stride_w();
-        params->stride_height = conv_params->stride_h();
-        params->activation =
-            ConvertActivation(conv_params->fused_activation_function());
-
-        params->dilation_width_factor = conv_params->dilation_w_factor();
-        params->dilation_height_factor = conv_params->dilation_h_factor();
-      }
-      *builtin_data = params.release();
-      return kTfLiteOk;
-    }
     case BuiltinOperator_CAST: {
       auto params = safe_allocator.Allocate<TfLiteCastParams>();
       TF_LITE_ENSURE(error_reporter, params != nullptr);
@@ -386,7 +496,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
       auto params = safe_allocator.Allocate<TfLitePoolParams>();
       TF_LITE_ENSURE(error_reporter, params != nullptr);
       if (const auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
-        params->padding = parse_padding(pool_params->padding());
+        params->padding = ConvertPadding(pool_params->padding());
         params->stride_width = pool_params->stride_w();
         params->stride_height = pool_params->stride_h();
         params->filter_width = pool_params->filter_width();
@@ -397,24 +507,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
       *builtin_data = params.release();
       return kTfLiteOk;
     }
-    case BuiltinOperator_DEPTHWISE_CONV_2D: {
-      auto params = safe_allocator.Allocate<TfLiteDepthwiseConvParams>();
-      TF_LITE_ENSURE(error_reporter, params != nullptr);
-      if (const auto* conv_params =
-              op->builtin_options_as_DepthwiseConv2DOptions()) {
-        params->padding = parse_padding(conv_params->padding());
-        params->stride_width = conv_params->stride_w();
-        params->stride_height = conv_params->stride_h();
-        params->depth_multiplier = conv_params->depth_multiplier();
-        params->activation =
-            ConvertActivation(conv_params->fused_activation_function());
-
-        params->dilation_width_factor = conv_params->dilation_w_factor();
-        params->dilation_height_factor = conv_params->dilation_h_factor();
-      }
-      *builtin_data = params.release();
-      return kTfLiteOk;
-    }
     case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
       auto params = safe_allocator.Allocate<TfLiteSequenceRNNParams>();
       TF_LITE_ENSURE(error_reporter, params != nullptr);
@@ -644,24 +736,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
       *builtin_data = params.release();
       return kTfLiteOk;
     }
-    case BuiltinOperator_RESHAPE: {
-      auto params = safe_allocator.Allocate<TfLiteReshapeParams>();
-      TF_LITE_ENSURE(error_reporter, params != nullptr);
-      if (const auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
-        auto* new_shape = schema_params->new_shape();
-        // TODO(b/147203660): We need to figure out when dynamic reshape
-        // (new_shape is a tensor) happens, why the option is not a nullptr.
-        // But nonetheless, we should only copy when new_shape is not a nullptr.
-        if (new_shape) {
-          TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray(
-              sizeof(params->shape), new_shape, params->shape, error_reporter,
-              "reshape"));
-          params->num_dimensions = new_shape->size();
-        }
-      }
-      *builtin_data = params.release();
-      return kTfLiteOk;
-    }
     case BuiltinOperator_SKIP_GRAM: {
       auto params = safe_allocator.Allocate<TfLiteSkipGramParams>();
       TF_LITE_ENSURE(error_reporter, params != nullptr);
@@ -791,7 +865,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
       TF_LITE_ENSURE(error_reporter, params != nullptr);
       if (const auto* transpose_conv_params =
               op->builtin_options_as_TransposeConvOptions()) {
-        params->padding = parse_padding(transpose_conv_params->padding());
+        params->padding = ConvertPadding(transpose_conv_params->padding());
         params->stride_width = transpose_conv_params->stride_w();
         params->stride_height = transpose_conv_params->stride_h();
       }
@@ -75,6 +75,15 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
 // removed once we are no longer using ParseOpData for the OpResolver
 // implementation in micro.

+TfLiteStatus ParseConv2D(const Operator* op, BuiltinOperator op_type,
+                         ErrorReporter* error_reporter,
+                         BuiltinDataAllocator* allocator, void** builtin_data);
+
+TfLiteStatus ParseDepthwiseConv2D(const Operator* op, BuiltinOperator op_type,
+                                  ErrorReporter* error_reporter,
+                                  BuiltinDataAllocator* allocator,
+                                  void** builtin_data);
+
 TfLiteStatus ParseDequantize(const Operator* op, BuiltinOperator op_type,
                              ErrorReporter* error_reporter,
                              BuiltinDataAllocator* allocator,
@@ -90,6 +99,10 @@ TfLiteStatus ParseQuantize(const Operator* op, BuiltinOperator op_type,
                            BuiltinDataAllocator* allocator,
                            void** builtin_data);

+TfLiteStatus ParseReshape(const Operator* op, BuiltinOperator op_type,
+                          ErrorReporter* error_reporter,
+                          BuiltinDataAllocator* allocator, void** builtin_data);
+
 TfLiteStatus ParseSoftmax(const Operator* op, BuiltinOperator op_type,
                           ErrorReporter* error_reporter,
                           BuiltinDataAllocator* allocator, void** builtin_data);
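One consequence of the declarations above, shown here only as a hypothetical convenience (not part of this change): every Parse* helper shares a single signature, so they can all be referred to through one function-pointer alias, which is what makes table-driven or selective registration straightforward.

#include <type_traits>

#include "tensorflow/lite/core/api/flatbuffer_conversions.h"

using BuiltinParseFunction = TfLiteStatus (*)(
    const tflite::Operator* op, tflite::BuiltinOperator op_type,
    tflite::ErrorReporter* error_reporter,
    tflite::BuiltinDataAllocator* allocator, void** builtin_data);

// Sanity check that the new declarations really do share the signature.
static_assert(std::is_same<decltype(&tflite::ParseConv2D),
                           BuiltinParseFunction>::value,
              "ParseConv2D matches the shared parse signature");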