Separate out parse functionality into helper functions.

Ops in this change:
 * Con2D
 * DepthwiseConv2D
 * Reshape

PiperOrigin-RevId: 315424381
Change-Id: If2fb9187785eabd31b9d6588322cb70345650539
This commit is contained in:
Advait Jain 2020-06-08 22:49:07 -07:00 committed by TensorFlower Gardener
parent 6ddc7f8d99
commit 47b4145e68
2 changed files with 162 additions and 75 deletions

View File

@ -62,6 +62,17 @@ class SafeBuiltinDataAllocator {
BuiltinDataAllocator* allocator_;
};
// All the Parse functions take some pointers as params and this function has
// the common DCHECKs to catch if any of those are nullptr.
void CheckParsePointerParams(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data) {
TFLITE_DCHECK(op != nullptr);
TFLITE_DCHECK(error_reporter != nullptr);
TFLITE_DCHECK(allocator != nullptr);
TFLITE_DCHECK(builtin_data != nullptr);
}
// Copies the contents from the flatbuffer int vector `flatbuffer` into the
// int array `buffer`. `flat_vector` and `buffer` represent the same
// configuration operation for a given operation.
@ -109,6 +120,17 @@ TfLiteFusedActivation ConvertActivation(ActivationFunctionType activation) {
return kTfLiteActNone;
}
// Converts the flatbuffer padding enum to what is used at runtime.
TfLitePadding ConvertPadding(Padding padding) {
switch (padding) {
case Padding_SAME:
return kTfLitePaddingSame;
case Padding_VALID:
return kTfLitePaddingValid;
}
return kTfLitePaddingUnknown;
}
} // namespace
TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
@ -155,6 +177,74 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
}
}
TfLiteStatus ParseConv2D(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data) {
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
SafeBuiltinDataAllocator safe_allocator(allocator);
std::unique_ptr<TfLiteConvParams,
SafeBuiltinDataAllocator::BuiltinDataDeleter>
params = safe_allocator.Allocate<TfLiteConvParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
const Conv2DOptions* schema_params = op->builtin_options_as_Conv2DOptions();
if (schema_params != nullptr) {
params->padding = ConvertPadding(schema_params->padding());
params->stride_width = schema_params->stride_w();
params->stride_height = schema_params->stride_h();
params->activation =
ConvertActivation(schema_params->fused_activation_function());
params->dilation_width_factor = schema_params->dilation_w_factor();
params->dilation_height_factor = schema_params->dilation_h_factor();
} else {
// TODO(b/157480169): We should either return kTfLiteError or fill in some
// reasonable defaults in the params struct. We are not doing so until we
// better undertand the ramifications of changing the legacy behavior.
}
*builtin_data = params.release();
return kTfLiteOk;
}
TfLiteStatus ParseDepthwiseConv2D(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data) {
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
SafeBuiltinDataAllocator safe_allocator(allocator);
std::unique_ptr<TfLiteDepthwiseConvParams,
SafeBuiltinDataAllocator::BuiltinDataDeleter>
params = safe_allocator.Allocate<TfLiteDepthwiseConvParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
const DepthwiseConv2DOptions* schema_params =
op->builtin_options_as_DepthwiseConv2DOptions();
if (schema_params != nullptr) {
params->padding = ConvertPadding(schema_params->padding());
params->stride_width = schema_params->stride_w();
params->stride_height = schema_params->stride_h();
params->depth_multiplier = schema_params->depth_multiplier();
params->activation =
ConvertActivation(schema_params->fused_activation_function());
params->dilation_width_factor = schema_params->dilation_w_factor();
params->dilation_height_factor = schema_params->dilation_h_factor();
} else {
// TODO(b/157480169): We should either return kTfLiteError or fill in some
// reasonable defaults in the params struct. We are not doing so until we
// better undertand the ramifications of changing the legacy behavior.
}
*builtin_data = params.release();
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
@ -167,10 +257,7 @@ TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data) {
TFLITE_DCHECK(op != nullptr);
TFLITE_DCHECK(error_reporter != nullptr);
TFLITE_DCHECK(allocator != nullptr);
TFLITE_DCHECK(builtin_data != nullptr);
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
SafeBuiltinDataAllocator safe_allocator(allocator);
@ -212,6 +299,47 @@ TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinOperator,
return kTfLiteOk;
}
TfLiteStatus ParseReshape(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data) {
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
SafeBuiltinDataAllocator safe_allocator(allocator);
std::unique_ptr<TfLiteReshapeParams,
SafeBuiltinDataAllocator::BuiltinDataDeleter>
params = safe_allocator.Allocate<TfLiteReshapeParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
const ReshapeOptions* schema_params = op->builtin_options_as_ReshapeOptions();
if (schema_params != nullptr) {
const flatbuffers::Vector<int32_t>* new_shape = schema_params->new_shape();
// TODO(b/147203660): We need to figure out when dynamic reshape
// (new_shape is a tensor) happens, why the option is not a nullptr.
// But nonethless, we should only copy when new_shape is not a nullptr.
if (new_shape != nullptr) {
TF_LITE_ENSURE_STATUS(
FlatBufferIntVectorToArray(sizeof(params->shape), new_shape,
params->shape, error_reporter, "reshape"));
params->num_dimensions = new_shape->size();
} else {
// TODO(b/157480169) TODO(b/147203660): We should either return
// kTfLiteError or fill in some reasonable defaults in the params struct.
// We are not doing so until we better undertand the ramifications of
// changing the legacy behavior.
}
} else {
// TODO(b/157480169): We should either return kTfLiteError or fill in some
// reasonable defaults in the params struct. We are not doing so until we
// better undertand the ramifications of changing the legacy behavior.
}
*builtin_data = params.release();
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
@ -224,10 +352,7 @@ TfLiteStatus ParseSoftmax(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data) {
TFLITE_DCHECK(op != nullptr);
TFLITE_DCHECK(error_reporter != nullptr);
TFLITE_DCHECK(allocator != nullptr);
TFLITE_DCHECK(builtin_data != nullptr);
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
SafeBuiltinDataAllocator safe_allocator(allocator);
std::unique_ptr<TfLiteSoftmaxParams,
@ -252,10 +377,7 @@ TfLiteStatus ParseSoftmax(const Operator* op, BuiltinOperator,
TfLiteStatus ParseSvdf(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data) {
TFLITE_DCHECK(op != nullptr);
TFLITE_DCHECK(error_reporter != nullptr);
TFLITE_DCHECK(allocator != nullptr);
TFLITE_DCHECK(builtin_data != nullptr);
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
SafeBuiltinDataAllocator safe_allocator(allocator);
std::unique_ptr<TfLiteSVDFParams,
@ -283,15 +405,6 @@ TfLiteStatus ParseSvdf(const Operator* op, BuiltinOperator,
TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data) {
auto parse_padding = [](Padding padding) {
switch (padding) {
case Padding_SAME:
return kTfLitePaddingSame;
case Padding_VALID:
return kTfLitePaddingValid;
}
return kTfLitePaddingUnknown;
};
auto parseLSHProjectionType = [](LSHProjectionType type) {
switch (type) {
case LSHProjectionType_SPARSE:
@ -317,6 +430,15 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
SafeBuiltinDataAllocator safe_allocator(allocator);
*builtin_data = nullptr;
switch (op_type) {
case BuiltinOperator_CONV_2D: {
return ParseConv2D(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_DEPTHWISE_CONV_2D: {
return ParseDepthwiseConv2D(op, op_type, error_reporter, allocator,
builtin_data);
}
case BuiltinOperator_DEQUANTIZE: {
return ParseDequantize(op, op_type, error_reporter, allocator,
builtin_data);
@ -332,6 +454,10 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
builtin_data);
}
case BuiltinOperator_RESHAPE: {
return ParseReshape(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_SOFTMAX: {
return ParseSoftmax(op, op_type, error_reporter, allocator, builtin_data);
}
@ -340,22 +466,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
return ParseSvdf(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_CONV_2D: {
auto params = safe_allocator.Allocate<TfLiteConvParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
params->padding = parse_padding(conv_params->padding());
params->stride_width = conv_params->stride_w();
params->stride_height = conv_params->stride_h();
params->activation =
ConvertActivation(conv_params->fused_activation_function());
params->dilation_width_factor = conv_params->dilation_w_factor();
params->dilation_height_factor = conv_params->dilation_h_factor();
}
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_CAST: {
auto params = safe_allocator.Allocate<TfLiteCastParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
@ -386,7 +496,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
auto params = safe_allocator.Allocate<TfLitePoolParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
params->padding = parse_padding(pool_params->padding());
params->padding = ConvertPadding(pool_params->padding());
params->stride_width = pool_params->stride_w();
params->stride_height = pool_params->stride_h();
params->filter_width = pool_params->filter_width();
@ -397,24 +507,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_DEPTHWISE_CONV_2D: {
auto params = safe_allocator.Allocate<TfLiteDepthwiseConvParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* conv_params =
op->builtin_options_as_DepthwiseConv2DOptions()) {
params->padding = parse_padding(conv_params->padding());
params->stride_width = conv_params->stride_w();
params->stride_height = conv_params->stride_h();
params->depth_multiplier = conv_params->depth_multiplier();
params->activation =
ConvertActivation(conv_params->fused_activation_function());
params->dilation_width_factor = conv_params->dilation_w_factor();
params->dilation_height_factor = conv_params->dilation_h_factor();
}
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
auto params = safe_allocator.Allocate<TfLiteSequenceRNNParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
@ -644,24 +736,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_RESHAPE: {
auto params = safe_allocator.Allocate<TfLiteReshapeParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
auto* new_shape = schema_params->new_shape();
// TODO(b/147203660): We need to figure out when dynamic reshape
// (new_shape is a tensor) happens, why the option is not a nullptr.
// But nonethless, we should only copy when new_shape is not a nullptr.
if (new_shape) {
TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray(
sizeof(params->shape), new_shape, params->shape, error_reporter,
"reshape"));
params->num_dimensions = new_shape->size();
}
}
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_SKIP_GRAM: {
auto params = safe_allocator.Allocate<TfLiteSkipGramParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
@ -791,7 +865,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* transpose_conv_params =
op->builtin_options_as_TransposeConvOptions()) {
params->padding = parse_padding(transpose_conv_params->padding());
params->padding = ConvertPadding(transpose_conv_params->padding());
params->stride_width = transpose_conv_params->stride_w();
params->stride_height = transpose_conv_params->stride_h();
}

View File

@ -75,6 +75,15 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
// removed once we are no longer using ParseOpData for the OpResolver
// implementation in micro.
TfLiteStatus ParseConv2D(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseDepthwiseConv2D(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseDequantize(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
@ -90,6 +99,10 @@ TfLiteStatus ParseQuantize(const Operator* op, BuiltinOperator op_type,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseReshape(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseSoftmax(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);