Reorder class definition for easier discovery.

PiperOrigin-RevId: 334860738
Change-Id: Ibf9f7e581ea2d41ec3546fe391601025ac1ec239
This commit is contained in:
Juhyun Lee 2020-10-01 11:35:38 -07:00 committed by TensorFlower Gardener
parent 1ac96303da
commit 99908062d4

View File

@ -505,45 +505,6 @@ class Conv2DOperationParser : public TFLiteOperationParser {
}
};
// Custom op version of TRANSPOSE_CONV.
class Convolution2DTransposeBiasParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
const TfLiteTransposeConvParams* tf_options;
RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
RETURN_IF_ERROR(
CheckStrides(tf_options->stride_height, tf_options->stride_width));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
auto* node = graph->NewNode();
node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
const TfLiteTransposeConvParams* tf_options;
auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
ConvolutionTransposedAttributes attr;
attr.stride = status.ok()
? HW(tf_options->stride_height, tf_options->stride_width)
: HW(1, 1);
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown,
graph->FindInputs(node->id)[0]->tensor.shape, &attr);
node->operation.attributes = std::move(attr);
return absl::OkStatus();
}
};
class DepthwiseConvolutionOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
@ -2171,7 +2132,7 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
};
// Builtin op version of TRANSPOSE_CONV.
class TransposeConvOperationParser : public TFLiteOperationParser {
class TransposeConvBuiltinOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
@ -2215,6 +2176,45 @@ class TransposeConvOperationParser : public TFLiteOperationParser {
}
};
// Custom op version of TRANSPOSE_CONV.
class TransposeConvCustomOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
const TfLiteTransposeConvParams* tf_options;
RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
RETURN_IF_ERROR(
CheckStrides(tf_options->stride_height, tf_options->stride_width));
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
auto* node = graph->NewNode();
node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
const TfLiteTransposeConvParams* tf_options;
auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
ConvolutionTransposedAttributes attr;
attr.stride = status.ok()
? HW(tf_options->stride_height, tf_options->stride_width)
: HW(1, 1);
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown,
graph->FindInputs(node->id)[0]->tensor.shape, &attr);
node->operation.attributes = std::move(attr);
return absl::OkStatus();
}
};
class TransposeOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
@ -2755,12 +2755,12 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
case kTfLiteBuiltinTranspose:
return std::make_unique<TransposeOperationParser>();
case kTfLiteBuiltinTransposeConv:
return std::make_unique<TransposeConvOperationParser>();
return std::make_unique<TransposeConvBuiltinOperationParser>();
case kTfLiteBuiltinCustom:
const absl::string_view custom_name = registration->custom_name;
if (custom_name == "Convolution2DTransposeBias") {
return std::make_unique<Convolution2DTransposeBiasParser>();
return std::make_unique<TransposeConvCustomOperationParser>();
}
if (custom_name == "MaxPoolingWithArgmax2D") {
return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);