Separate out parse functionality into helper functions.

Ops in this change:
 * L2Normalization
 * Less
 * LessEqual
 * Log
 * LogicalAnd
 * LogicalNot
 * LogicalOr
 * Logistic
 * Maximum

PiperOrigin-RevId: 318595850
Change-Id: I17605d841170ae9cbd4e92d44432327ab40b401b
This commit is contained in:
Advait Jain 2020-06-26 22:14:38 -07:00 committed by TensorFlower Gardener
parent 064994341a
commit ff28621f73
3 changed files with 190 additions and 47 deletions

View File

@ -461,6 +461,97 @@ TfLiteStatus ParseGreaterEqual(const Operator*, BuiltinOperator, ErrorReporter*,
return kTfLiteOk;
}
TfLiteStatus ParseL2Normalization(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data) {
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
SafeBuiltinDataAllocator safe_allocator(allocator);
std::unique_ptr<TfLiteL2NormParams,
SafeBuiltinDataAllocator::BuiltinDataDeleter>
params = safe_allocator.Allocate<TfLiteL2NormParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
const L2NormOptions* schema_params = op->builtin_options_as_L2NormOptions();
if (schema_params != nullptr) {
params->activation =
ConvertActivation(schema_params->fused_activation_function());
} else {
// TODO(b/157480169): We should either return kTfLiteError or fill in some
// reasonable defaults in the params struct. We are not doing so until we
// better undertand the ramifications of changing the legacy behavior.
}
*builtin_data = params.release();
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseLess(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseLessEqual(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseLog(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseLogicalAnd(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseLogicalNot(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseLogicalOr(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseLogistic(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseMaximum(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
TfLiteStatus ParsePool(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data) {
@ -684,14 +775,56 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
builtin_data);
}
case BuiltinOperator_MAX_POOL_2D: {
return ParsePool(op, op_type, error_reporter, allocator, builtin_data);
case BuiltinOperator_L2_NORMALIZATION: {
return ParseL2Normalization(op, op_type, error_reporter, allocator,
builtin_data);
}
case BuiltinOperator_L2_POOL_2D: {
return ParsePool(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_LESS: {
return ParseLess(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_LESS_EQUAL: {
return ParseLessEqual(op, op_type, error_reporter, allocator,
builtin_data);
}
case BuiltinOperator_LOG: {
return ParseLog(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_LOGICAL_AND: {
return ParseLogicalAnd(op, op_type, error_reporter, allocator,
builtin_data);
}
case BuiltinOperator_LOGICAL_NOT: {
return ParseLogicalNot(op, op_type, error_reporter, allocator,
builtin_data);
}
case BuiltinOperator_LOGICAL_OR: {
return ParseLogicalOr(op, op_type, error_reporter, allocator,
builtin_data);
}
case BuiltinOperator_LOGISTIC: {
return ParseLogistic(op, op_type, error_reporter, allocator,
builtin_data);
}
case BuiltinOperator_MAXIMUM: {
return ParseMaximum(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_MAX_POOL_2D: {
return ParsePool(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_QUANTIZE: {
return ParseQuantize(op, op_type, error_reporter, allocator,
builtin_data);
@ -820,16 +953,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_L2_NORMALIZATION: {
auto params = safe_allocator.Allocate<TfLiteL2NormParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* schema_params = op->builtin_options_as_L2NormOptions()) {
params->activation =
ConvertActivation(schema_params->fused_activation_function());
}
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
auto params = safe_allocator.Allocate<TfLiteLocalResponseNormParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
@ -1214,14 +1337,9 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_EXP:
case BuiltinOperator_EXPAND_DIMS:
case BuiltinOperator_HARD_SWISH:
case BuiltinOperator_LESS:
case BuiltinOperator_LESS_EQUAL:
case BuiltinOperator_LOG:
case BuiltinOperator_LOGISTIC:
case BuiltinOperator_LOG_SOFTMAX:
case BuiltinOperator_MATRIX_DIAG:
case BuiltinOperator_MATRIX_SET_DIAG:
case BuiltinOperator_MAXIMUM:
case BuiltinOperator_MINIMUM:
case BuiltinOperator_NEG:
case BuiltinOperator_NOT_EQUAL:
@ -1244,9 +1362,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_TOPK_V2:
case BuiltinOperator_TRANSPOSE:
case BuiltinOperator_POW:
case BuiltinOperator_LOGICAL_OR:
case BuiltinOperator_LOGICAL_AND:
case BuiltinOperator_LOGICAL_NOT:
case BuiltinOperator_FLOOR_DIV:
case BuiltinOperator_SQUARE:
case BuiltinOperator_ZEROS_LIKE:

View File

@ -140,6 +140,48 @@ TfLiteStatus ParseGreaterEqual(const Operator* op, BuiltinOperator op_type,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseL2Normalization(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseLess(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseLessEqual(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseLog(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseLogicalAnd(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseLogicalNot(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseLogicalOr(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseLogistic(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseMaximum(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParsePool(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);

View File

@ -201,67 +201,53 @@ class MicroMutableOpResolver : public MicroOpResolver {
}
TfLiteStatus AddL2Normalization() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_L2_NORMALIZATION,
*tflite::ops::micro::Register_L2_NORMALIZATION(),
ParseOpData);
ParseL2Normalization);
}
TfLiteStatus AddLess() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_LESS,
*tflite::ops::micro::Register_LESS(), ParseOpData);
*tflite::ops::micro::Register_LESS(), ParseLess);
}
TfLiteStatus AddLessEqual() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_LESS_EQUAL,
*tflite::ops::micro::Register_LESS_EQUAL(), ParseOpData);
*tflite::ops::micro::Register_LESS_EQUAL(),
ParseLessEqual);
}
TfLiteStatus AddLog() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_LOG, *tflite::ops::micro::Register_LOG(),
ParseOpData);
ParseLog);
}
TfLiteStatus AddLogicalAnd() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_LOGICAL_AND,
*tflite::ops::micro::Register_LOGICAL_AND(), ParseOpData);
*tflite::ops::micro::Register_LOGICAL_AND(),
ParseLogicalAnd);
}
TfLiteStatus AddLogicalNot() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_LOGICAL_NOT,
*tflite::ops::micro::Register_LOGICAL_NOT(), ParseOpData);
*tflite::ops::micro::Register_LOGICAL_NOT(),
ParseLogicalNot);
}
TfLiteStatus AddLogicalOr() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_LOGICAL_OR,
*tflite::ops::micro::Register_LOGICAL_OR(), ParseOpData);
*tflite::ops::micro::Register_LOGICAL_OR(),
ParseLogicalOr);
}
TfLiteStatus AddLogistic() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_LOGISTIC,
*tflite::ops::micro::Register_LOGISTIC(), ParseOpData);
*tflite::ops::micro::Register_LOGISTIC(), ParseLogistic);
}
TfLiteStatus AddMaximum() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_MAXIMUM,
*tflite::ops::micro::Register_MAXIMUM(), ParseOpData);
*tflite::ops::micro::Register_MAXIMUM(), ParseMaximum);
}
TfLiteStatus AddMaxPool2D() {