Reorder class definition for easier discovery.
PiperOrigin-RevId: 334860738 Change-Id: Ibf9f7e581ea2d41ec3546fe391601025ac1ec239
This commit is contained in:
parent
1ac96303da
commit
99908062d4
@ -505,45 +505,6 @@ class Conv2DOperationParser : public TFLiteOperationParser {
|
||||
}
|
||||
};
|
||||
|
||||
// Custom op version of TRANSPOSE_CONV.
|
||||
class Convolution2DTransposeBiasParser : public TFLiteOperationParser {
|
||||
public:
|
||||
absl::Status IsSupported(const TfLiteContext* context,
|
||||
const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration) final {
|
||||
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
|
||||
const TfLiteTransposeConvParams* tf_options;
|
||||
RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
|
||||
RETURN_IF_ERROR(
|
||||
CheckStrides(tf_options->stride_height, tf_options->stride_width));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Parse(const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration,
|
||||
GraphFloat32* graph, ObjectReader* reader) final {
|
||||
auto* node = graph->NewNode();
|
||||
node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
|
||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||
|
||||
const TfLiteTransposeConvParams* tf_options;
|
||||
auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
|
||||
|
||||
ConvolutionTransposedAttributes attr;
|
||||
attr.stride = status.ok()
|
||||
? HW(tf_options->stride_height, tf_options->stride_width)
|
||||
: HW(1, 1);
|
||||
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
|
||||
reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
|
||||
|
||||
UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown,
|
||||
graph->FindInputs(node->id)[0]->tensor.shape, &attr);
|
||||
node->operation.attributes = std::move(attr);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
class DepthwiseConvolutionOperationParser : public TFLiteOperationParser {
|
||||
public:
|
||||
absl::Status IsSupported(const TfLiteContext* context,
|
||||
@ -2171,7 +2132,7 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
|
||||
};
|
||||
|
||||
// Builtin op version of TRANSPOSE_CONV.
|
||||
class TransposeConvOperationParser : public TFLiteOperationParser {
|
||||
class TransposeConvBuiltinOperationParser : public TFLiteOperationParser {
|
||||
public:
|
||||
absl::Status IsSupported(const TfLiteContext* context,
|
||||
const TfLiteNode* tflite_node,
|
||||
@ -2215,6 +2176,45 @@ class TransposeConvOperationParser : public TFLiteOperationParser {
|
||||
}
|
||||
};
|
||||
|
||||
// Custom op version of TRANSPOSE_CONV.
|
||||
class TransposeConvCustomOperationParser : public TFLiteOperationParser {
|
||||
public:
|
||||
absl::Status IsSupported(const TfLiteContext* context,
|
||||
const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration) final {
|
||||
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
|
||||
const TfLiteTransposeConvParams* tf_options;
|
||||
RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
|
||||
RETURN_IF_ERROR(
|
||||
CheckStrides(tf_options->stride_height, tf_options->stride_width));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Parse(const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration,
|
||||
GraphFloat32* graph, ObjectReader* reader) final {
|
||||
auto* node = graph->NewNode();
|
||||
node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
|
||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||
|
||||
const TfLiteTransposeConvParams* tf_options;
|
||||
auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
|
||||
|
||||
ConvolutionTransposedAttributes attr;
|
||||
attr.stride = status.ok()
|
||||
? HW(tf_options->stride_height, tf_options->stride_width)
|
||||
: HW(1, 1);
|
||||
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
|
||||
reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
|
||||
|
||||
UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown,
|
||||
graph->FindInputs(node->id)[0]->tensor.shape, &attr);
|
||||
node->operation.attributes = std::move(attr);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
class TransposeOperationParser : public TFLiteOperationParser {
|
||||
public:
|
||||
absl::Status IsSupported(const TfLiteContext* context,
|
||||
@ -2755,12 +2755,12 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
|
||||
case kTfLiteBuiltinTranspose:
|
||||
return std::make_unique<TransposeOperationParser>();
|
||||
case kTfLiteBuiltinTransposeConv:
|
||||
return std::make_unique<TransposeConvOperationParser>();
|
||||
return std::make_unique<TransposeConvBuiltinOperationParser>();
|
||||
|
||||
case kTfLiteBuiltinCustom:
|
||||
const absl::string_view custom_name = registration->custom_name;
|
||||
if (custom_name == "Convolution2DTransposeBias") {
|
||||
return std::make_unique<Convolution2DTransposeBiasParser>();
|
||||
return std::make_unique<TransposeConvCustomOperationParser>();
|
||||
}
|
||||
if (custom_name == "MaxPoolingWithArgmax2D") {
|
||||
return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
|
||||
|
Loading…
Reference in New Issue
Block a user