From 21896de7bc18e5ac0fb876c86e05adda0669e547 Mon Sep 17 00:00:00 2001 From: Robert David Date: Mon, 1 Jun 2020 12:44:39 -0700 Subject: [PATCH] Use RetrieveBuiltinData and RetrieveCustomInitialData instead of directly accessing builtin_data and custom_initial_data. PiperOrigin-RevId: 314183617 Change-Id: I8494bf3fe28fc8e3eeba21611beccc48354eac96 --- .../delegates/gpu/common/model_builder.cc | 132 +++++++----------- 1 file changed, 49 insertions(+), 83 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index 29c819f7800..4501ec0f0e0 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -422,11 +422,8 @@ class AddOperationParser : public TFLiteOperationParser { AddAttributes attr; RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param)); node->operation.attributes = std::move(attr); - const auto* tf_options = - static_cast(tflite_node->builtin_data); - if (!tf_options) { - return absl::InternalError("Missing tflite params"); - } + const TfLiteAddParams* tf_options; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); return MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph, node); } @@ -499,11 +496,8 @@ class ConcatenationOperationParser : public TFLiteOperationParser { break; } } - const auto* tf_options = static_cast( - tflite_node->builtin_data); - if (!tf_options) { - return absl::InternalError("Missing tflite params"); - } + const TfLiteConcatenationParams* tf_options; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph, node)); node->operation.attributes = attr; @@ -601,11 +595,8 @@ class Conv2DOperationParser : public TFLiteOperationParser { } reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional - const auto* tf_options = - static_cast(tflite_node->builtin_data); - if (!tf_options) { - return absl::InternalError("Missing tflite params"); - } + const TfLiteConvParams* tf_options; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width); attr.dilations = HW(tf_options->dilation_height_factor, tf_options->dilation_width_factor); @@ -639,17 +630,19 @@ class Convolution2DTransposeBiasParser : public TFLiteOperationParser { RETURN_IF_ERROR(reader->AddInput(node, 0)); RETURN_IF_ERROR(reader->AddOutputs(node)); - const auto* params = static_cast( - tflite_node->custom_initial_data); + const TfLiteTransposeConvParams* tf_options; + auto status = RetrieveCustomInitialData(tflite_node, &tf_options); + ConvolutionTransposedAttributes attr; - attr.stride = - params ? HW(params->stride_height, params->stride_width) : HW(1, 1); + attr.stride = status.ok() + ? HW(tf_options->stride_height, tf_options->stride_width) + : HW(1, 1); RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights)); reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional - UpdatePadding(params->padding, graph->FindInputs(node->id)[0]->tensor.shape, - &attr); + UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown, + graph->FindInputs(node->id)[0]->tensor.shape, &attr); node->operation.attributes = std::move(attr); return absl::OkStatus(); @@ -874,17 +867,15 @@ class ElementwiseOperationParser : public TFLiteOperationParser { TfLiteFusedActivation activation = kTfLiteActNone; switch (operation_type_) { case OperationType::SUB: { - const auto* tf_options = - static_cast(tflite_node->builtin_data); - if (tf_options != nullptr) { + const TfLiteSubParams* tf_options; + if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) { activation = tf_options->activation; } break; } case OperationType::DIV: { - const auto* tf_options = - static_cast(tflite_node->builtin_data); - if (tf_options != nullptr) { + const TfLiteDivParams* tf_options; + if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) { activation = tf_options->activation; } break; @@ -1002,8 +993,8 @@ class FullyConnectedOperationParser : public TFLiteOperationParser { Node* node = graph->NewNode(); RETURN_IF_ERROR(reader->AddInput(node, 0)); - const auto* tf_options = static_cast( - tflite_node->builtin_data); + const TfLiteFullyConnectedParams* tf_options; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); if (tf_options->weights_format != kTfLiteFullyConnectedWeightsFormatDefault) { return absl::UnimplementedError( @@ -1112,12 +1103,9 @@ class LSTMOperationParser : public TFLiteOperationParser { return absl::InvalidArgumentError("LSTM should have 4 output tensors"); } - const auto* params = - static_cast(tflite_node->builtin_data); - if (!params) { - return absl::InternalError("Missing tflite params"); - } - RETURN_IF_ERROR(CheckParameters(params)); + const TfLiteLSTMParams* tf_options; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); + RETURN_IF_ERROR(CheckParameters(tf_options)); Node* concat_node = graph->NewNode(); concat_node->operation.type = ToString(OperationType::CONCAT); @@ -1251,11 +1239,8 @@ class MulOperationParser : public TFLiteOperationParser { constant_dims, graph, reader)); } - const auto* tf_options = - static_cast(tflite_node->builtin_data); - if (!tf_options) { - return absl::InternalError("Missing TfLiteMulParams"); - } + const TfLiteMulParams* tf_options; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); return MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph, node); } @@ -1338,8 +1323,8 @@ class PadOperationParser : public TFLiteOperationParser { const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { if (mirror_pad_) { - auto* tf_options = static_cast( - tflite_node->builtin_data); + const TfLiteMirrorPaddingParams* tf_options; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); if (tf_options->mode != TfLiteMirrorPaddingMode::kTfLiteMirrorPaddingReflect) { return absl::InvalidArgumentError( @@ -1444,14 +1429,9 @@ class Pooling2DOperationParser : public TFLiteOperationParser { // is MaxPoolingWithArgmax2D. There is no way to read // tflite_node->builtin_code, so, simply check whether custom data is // available. - auto* tf_options = - static_cast(tflite_node->custom_initial_data); - if (!tf_options) { - tf_options = - static_cast(tflite_node->builtin_data); - } - if (!tf_options) { - return absl::InternalError("Missing tflite params"); + const TfLitePoolParams* tf_options; + if (!RetrieveCustomInitialData(tflite_node, &tf_options).ok()) { + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); } std::vector max_tensor_id{0}; @@ -1637,10 +1617,8 @@ class Resize2DOperationParser : public TFLiteOperationParser { template absl::Status GetAlignCornersValueForType(const TfLiteNode* tflite_node, bool* align_corners) { - const auto* tf_options = static_cast(tflite_node->builtin_data); - if (!tf_options) { - return absl::InternalError("Missing tflite params"); - } + const T* tf_options; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); *align_corners = tf_options->align_corners; return absl::OkStatus(); } @@ -1648,12 +1626,8 @@ class Resize2DOperationParser : public TFLiteOperationParser { absl::Status GetHalfPixelCentersValue(const TfLiteNode* tflite_node, bool* half_pixel_centers) { if (sampling_type_ == SamplingType::BILINEAR) { - const auto* tf_options = - static_cast(tflite_node->builtin_data); - if (!tf_options) { - return absl::InternalError( - "Missing tflite params for ResizeBilinear op"); - } + const TfLiteResizeBilinearParams* tf_options; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); if (tf_options->align_corners && tf_options->half_pixel_centers) { return absl::InternalError( "If half_pixel_centers is True, align_corners must be False."); @@ -1809,11 +1783,8 @@ class SoftmaxOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(reader->AddInput(node, 0)); RETURN_IF_ERROR(reader->AddOutputs(node)); - const auto* tf_options = - static_cast(tflite_node->builtin_data); - if (!tf_options) { - return absl::InternalError("Missing tflite params"); - } + const TfLiteSoftmaxParams* tf_options; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); if (tf_options->beta != 1) { // there is multiply by scalar operation fused in softmax. Make a layer // out of it before softmax. @@ -1857,8 +1828,8 @@ class SpaceToDepthOperationParser : public TFLiteOperationParser { node->operation.type = ToString(OperationType::SPACE_TO_DEPTH); RETURN_IF_ERROR(reader->AddInput(node, 0)); RETURN_IF_ERROR(reader->AddOutputs(node)); - const auto* tf_options = - static_cast(tflite_node->builtin_data); + const TfLiteSpaceToDepthParams* tf_options; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); SpaceToDepthAttributes attr; attr.block_size = tf_options->block_size; node->operation.attributes = attr; @@ -1898,14 +1869,12 @@ class StridedSliceOperationParser : public TFLiteOperationParser { "Slicing is supported for 3 or 4 dimensional tensors only."); } - const auto* tf_options = - static_cast(tflite_node->builtin_data); - auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape; - if (!tf_options) { - return absl::InternalError("Missing tflite params"); - } + const TfLiteStridedSliceParams* tf_options; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); RETURN_IF_ERROR(CheckOptionsSupport(tf_options)); + auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape; + SliceAttributes attr; if (read_without_batch) { RETURN_IF_ERROR(ReadAttribsWithoutBatch(reader, tf_options, @@ -2074,11 +2043,9 @@ class TransposeConvOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id)); RETURN_IF_ERROR(reader->AddOutputs(node)); - const auto* tf_options = static_cast( - tflite_node->builtin_data); - if (!tf_options) { - return absl::InternalError("Missing tflite options."); - } + const TfLiteTransposeConvParams* tf_options; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); + ConvolutionTransposedAttributes attr; attr.stride = tf_options ? HW(tf_options->stride_height, tf_options->stride_width) @@ -2157,11 +2124,10 @@ class Unpooling2DOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(reader->AddOutputs(node)); auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape; MaxUnpooling2DAttributes attr; - const auto* tf_options = - static_cast(tflite_node->custom_initial_data); - if (!tf_options) { - return absl::InternalError("Missing tflite params"); - } + + const TfLitePoolParams* tf_options; + RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options)); + attr.kernel = ToHW(tf_options->filter_height, tf_options->filter_width); attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width); UpdatePadding(tf_options->padding, input_shape, &attr);