Reorder class definition for easier discovery.
PiperOrigin-RevId: 334860738 Change-Id: Ibf9f7e581ea2d41ec3546fe391601025ac1ec239
This commit is contained in:
parent
1ac96303da
commit
99908062d4
@ -505,45 +505,6 @@ class Conv2DOperationParser : public TFLiteOperationParser {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Custom op version of TRANSPOSE_CONV.
|
|
||||||
class Convolution2DTransposeBiasParser : public TFLiteOperationParser {
|
|
||||||
public:
|
|
||||||
absl::Status IsSupported(const TfLiteContext* context,
|
|
||||||
const TfLiteNode* tflite_node,
|
|
||||||
const TfLiteRegistration* registration) final {
|
|
||||||
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
|
|
||||||
const TfLiteTransposeConvParams* tf_options;
|
|
||||||
RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
|
|
||||||
RETURN_IF_ERROR(
|
|
||||||
CheckStrides(tf_options->stride_height, tf_options->stride_width));
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Status Parse(const TfLiteNode* tflite_node,
|
|
||||||
const TfLiteRegistration* registration,
|
|
||||||
GraphFloat32* graph, ObjectReader* reader) final {
|
|
||||||
auto* node = graph->NewNode();
|
|
||||||
node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
|
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
|
||||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
|
||||||
|
|
||||||
const TfLiteTransposeConvParams* tf_options;
|
|
||||||
auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
|
|
||||||
|
|
||||||
ConvolutionTransposedAttributes attr;
|
|
||||||
attr.stride = status.ok()
|
|
||||||
? HW(tf_options->stride_height, tf_options->stride_width)
|
|
||||||
: HW(1, 1);
|
|
||||||
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
|
|
||||||
reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
|
|
||||||
|
|
||||||
UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown,
|
|
||||||
graph->FindInputs(node->id)[0]->tensor.shape, &attr);
|
|
||||||
node->operation.attributes = std::move(attr);
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class DepthwiseConvolutionOperationParser : public TFLiteOperationParser {
|
class DepthwiseConvolutionOperationParser : public TFLiteOperationParser {
|
||||||
public:
|
public:
|
||||||
absl::Status IsSupported(const TfLiteContext* context,
|
absl::Status IsSupported(const TfLiteContext* context,
|
||||||
@ -2171,7 +2132,7 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Builtin op version of TRANSPOSE_CONV.
|
// Builtin op version of TRANSPOSE_CONV.
|
||||||
class TransposeConvOperationParser : public TFLiteOperationParser {
|
class TransposeConvBuiltinOperationParser : public TFLiteOperationParser {
|
||||||
public:
|
public:
|
||||||
absl::Status IsSupported(const TfLiteContext* context,
|
absl::Status IsSupported(const TfLiteContext* context,
|
||||||
const TfLiteNode* tflite_node,
|
const TfLiteNode* tflite_node,
|
||||||
@ -2215,6 +2176,45 @@ class TransposeConvOperationParser : public TFLiteOperationParser {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Custom op version of TRANSPOSE_CONV.
|
||||||
|
class TransposeConvCustomOperationParser : public TFLiteOperationParser {
|
||||||
|
public:
|
||||||
|
absl::Status IsSupported(const TfLiteContext* context,
|
||||||
|
const TfLiteNode* tflite_node,
|
||||||
|
const TfLiteRegistration* registration) final {
|
||||||
|
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
|
||||||
|
const TfLiteTransposeConvParams* tf_options;
|
||||||
|
RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
|
||||||
|
RETURN_IF_ERROR(
|
||||||
|
CheckStrides(tf_options->stride_height, tf_options->stride_width));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Parse(const TfLiteNode* tflite_node,
|
||||||
|
const TfLiteRegistration* registration,
|
||||||
|
GraphFloat32* graph, ObjectReader* reader) final {
|
||||||
|
auto* node = graph->NewNode();
|
||||||
|
node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
|
||||||
|
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||||
|
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||||
|
|
||||||
|
const TfLiteTransposeConvParams* tf_options;
|
||||||
|
auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
|
||||||
|
|
||||||
|
ConvolutionTransposedAttributes attr;
|
||||||
|
attr.stride = status.ok()
|
||||||
|
? HW(tf_options->stride_height, tf_options->stride_width)
|
||||||
|
: HW(1, 1);
|
||||||
|
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
|
||||||
|
reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
|
||||||
|
|
||||||
|
UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown,
|
||||||
|
graph->FindInputs(node->id)[0]->tensor.shape, &attr);
|
||||||
|
node->operation.attributes = std::move(attr);
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class TransposeOperationParser : public TFLiteOperationParser {
|
class TransposeOperationParser : public TFLiteOperationParser {
|
||||||
public:
|
public:
|
||||||
absl::Status IsSupported(const TfLiteContext* context,
|
absl::Status IsSupported(const TfLiteContext* context,
|
||||||
@ -2755,12 +2755,12 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
|
|||||||
case kTfLiteBuiltinTranspose:
|
case kTfLiteBuiltinTranspose:
|
||||||
return std::make_unique<TransposeOperationParser>();
|
return std::make_unique<TransposeOperationParser>();
|
||||||
case kTfLiteBuiltinTransposeConv:
|
case kTfLiteBuiltinTransposeConv:
|
||||||
return std::make_unique<TransposeConvOperationParser>();
|
return std::make_unique<TransposeConvBuiltinOperationParser>();
|
||||||
|
|
||||||
case kTfLiteBuiltinCustom:
|
case kTfLiteBuiltinCustom:
|
||||||
const absl::string_view custom_name = registration->custom_name;
|
const absl::string_view custom_name = registration->custom_name;
|
||||||
if (custom_name == "Convolution2DTransposeBias") {
|
if (custom_name == "Convolution2DTransposeBias") {
|
||||||
return std::make_unique<Convolution2DTransposeBiasParser>();
|
return std::make_unique<TransposeConvCustomOperationParser>();
|
||||||
}
|
}
|
||||||
if (custom_name == "MaxPoolingWithArgmax2D") {
|
if (custom_name == "MaxPoolingWithArgmax2D") {
|
||||||
return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
|
return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
|
||||||
|
Loading…
Reference in New Issue
Block a user