Use RetrieveBuiltinData and RetrieveCustomInitialData instead of directly accessing builtin_data and custom_initial_data.

PiperOrigin-RevId: 314183617
Change-Id: I8494bf3fe28fc8e3eeba21611beccc48354eac96
Robert David 2020-06-01 12:44:39 -07:00 committed by TensorFlower Gardener
parent 6f22fa9376
commit 21896de7bc
1 changed file with 49 additions and 83 deletions
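
For reference, the two helpers are small status-returning wrappers around the raw builtin_data / custom_initial_data pointers. Below is a minimal sketch of plausible implementations, inferred from the call sites in this diff; the exact header location, include set, and error-message strings are assumptions, not copied from the TensorFlow tree:

// Sketch only: both helpers replace the repeated "static_cast + null check"
// pattern at every parser call site with a single status-returning call.
#include "absl/status/status.h"
#include "tensorflow/lite/c/common.h"  // TfLiteNode

template <typename ParamsT>
absl::Status RetrieveBuiltinData(const TfLiteNode* tflite_node,
                                 const ParamsT** tf_options) {
  *tf_options = static_cast<const ParamsT*>(tflite_node->builtin_data);
  if (!*tf_options) {
    return absl::InternalError("Unable to retrieve builtin_data.");
  }
  return absl::OkStatus();
}

template <typename ParamsT>
absl::Status RetrieveCustomInitialData(const TfLiteNode* tflite_node,
                                       const ParamsT** tf_options) {
  *tf_options = static_cast<const ParamsT*>(tflite_node->custom_initial_data);
  if (!*tf_options) {
    return absl::InternalError("Unable to retrieve custom_initial_data.");
  }
  return absl::OkStatus();
}

At call sites that cannot proceed without the params, RETURN_IF_ERROR propagates the failure; parsers with a legitimate fallback (Pooling2D, Convolution2DTransposeBias, the SUB/DIV cases) branch on .ok() instead.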


@@ -422,11 +422,8 @@ class AddOperationParser : public TFLiteOperationParser {
     AddAttributes attr;
     RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
     node->operation.attributes = std::move(attr);
-    const auto* tf_options =
-        static_cast<const TfLiteAddParams*>(tflite_node->builtin_data);
-    if (!tf_options) {
-      return absl::InternalError("Missing tflite params");
-    }
+    const TfLiteAddParams* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     return MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph,
                                                 node);
   }
@@ -499,11 +496,8 @@ class ConcatenationOperationParser : public TFLiteOperationParser {
         break;
       }
     }
-    const auto* tf_options = static_cast<const TfLiteConcatenationParams*>(
-        tflite_node->builtin_data);
-    if (!tf_options) {
-      return absl::InternalError("Missing tflite params");
-    }
+    const TfLiteConcatenationParams* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation,
                                                          graph, node));
     node->operation.attributes = attr;
@@ -601,11 +595,8 @@ class Conv2DOperationParser : public TFLiteOperationParser {
     }
     reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional
-    const auto* tf_options =
-        static_cast<const TfLiteConvParams*>(tflite_node->builtin_data);
-    if (!tf_options) {
-      return absl::InternalError("Missing tflite params");
-    }
+    const TfLiteConvParams* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
     attr.dilations = HW(tf_options->dilation_height_factor,
                         tf_options->dilation_width_factor);
@@ -639,17 +630,19 @@ class Convolution2DTransposeBiasParser : public TFLiteOperationParser {
     RETURN_IF_ERROR(reader->AddInput(node, 0));
     RETURN_IF_ERROR(reader->AddOutputs(node));
-    const auto* params = static_cast<const TfLiteTransposeConvParams*>(
-        tflite_node->custom_initial_data);
+    const TfLiteTransposeConvParams* tf_options;
+    auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
     ConvolutionTransposedAttributes attr;
-    attr.stride =
-        params ? HW(params->stride_height, params->stride_width) : HW(1, 1);
+    attr.stride = status.ok()
+                      ? HW(tf_options->stride_height, tf_options->stride_width)
+                      : HW(1, 1);
     RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
     reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional
-    UpdatePadding(params->padding, graph->FindInputs(node->id)[0]->tensor.shape,
-                  &attr);
+    UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown,
+                  graph->FindInputs(node->id)[0]->tensor.shape, &attr);
     node->operation.attributes = std::move(attr);
     return absl::OkStatus();
@@ -874,17 +867,15 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
     TfLiteFusedActivation activation = kTfLiteActNone;
     switch (operation_type_) {
       case OperationType::SUB: {
-        const auto* tf_options =
-            static_cast<const TfLiteSubParams*>(tflite_node->builtin_data);
-        if (tf_options != nullptr) {
+        const TfLiteSubParams* tf_options;
+        if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
           activation = tf_options->activation;
         }
         break;
       }
       case OperationType::DIV: {
-        const auto* tf_options =
-            static_cast<const TfLiteDivParams*>(tflite_node->builtin_data);
-        if (tf_options != nullptr) {
+        const TfLiteDivParams* tf_options;
+        if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
           activation = tf_options->activation;
         }
         break;
@@ -1002,8 +993,8 @@ class FullyConnectedOperationParser : public TFLiteOperationParser {
     Node* node = graph->NewNode();
     RETURN_IF_ERROR(reader->AddInput(node, 0));
-    const auto* tf_options = static_cast<const TfLiteFullyConnectedParams*>(
-        tflite_node->builtin_data);
+    const TfLiteFullyConnectedParams* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     if (tf_options->weights_format !=
         kTfLiteFullyConnectedWeightsFormatDefault) {
       return absl::UnimplementedError(
@@ -1112,12 +1103,9 @@ class LSTMOperationParser : public TFLiteOperationParser {
       return absl::InvalidArgumentError("LSTM should have 4 output tensors");
     }
-    const auto* params =
-        static_cast<const TfLiteLSTMParams*>(tflite_node->builtin_data);
-    if (!params) {
-      return absl::InternalError("Missing tflite params");
-    }
-    RETURN_IF_ERROR(CheckParameters(params));
+    const TfLiteLSTMParams* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
+    RETURN_IF_ERROR(CheckParameters(tf_options));
 
     Node* concat_node = graph->NewNode();
     concat_node->operation.type = ToString(OperationType::CONCAT);
@@ -1251,11 +1239,8 @@ class MulOperationParser : public TFLiteOperationParser {
                                               constant_dims, graph, reader));
     }
-    const auto* tf_options =
-        static_cast<const TfLiteMulParams*>(tflite_node->builtin_data);
-    if (!tf_options) {
-      return absl::InternalError("Missing TfLiteMulParams");
-    }
+    const TfLiteMulParams* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     return MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph,
                                                 node);
   }
@@ -1338,8 +1323,8 @@ class PadOperationParser : public TFLiteOperationParser {
                              const TfLiteNode* tflite_node,
                              const TfLiteRegistration* registration) final {
     if (mirror_pad_) {
-      auto* tf_options = static_cast<const TfLiteMirrorPaddingParams*>(
-          tflite_node->builtin_data);
+      const TfLiteMirrorPaddingParams* tf_options;
+      RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
       if (tf_options->mode !=
           TfLiteMirrorPaddingMode::kTfLiteMirrorPaddingReflect) {
         return absl::InvalidArgumentError(
@@ -1444,14 +1429,9 @@ class Pooling2DOperationParser : public TFLiteOperationParser {
     // is MaxPoolingWithArgmax2D. There is no way to read
     // tflite_node->builtin_code, so, simply check whether custom data is
     // available.
-    auto* tf_options =
-        static_cast<const TfLitePoolParams*>(tflite_node->custom_initial_data);
-    if (!tf_options) {
-      tf_options =
-          static_cast<const TfLitePoolParams*>(tflite_node->builtin_data);
-    }
-    if (!tf_options) {
-      return absl::InternalError("Missing tflite params");
+    const TfLitePoolParams* tf_options;
+    if (!RetrieveCustomInitialData(tflite_node, &tf_options).ok()) {
+      RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     }
 
     std::vector<uint32_t> max_tensor_id{0};
@@ -1637,10 +1617,8 @@ class Resize2DOperationParser : public TFLiteOperationParser {
   template <class T>
   absl::Status GetAlignCornersValueForType(const TfLiteNode* tflite_node,
                                            bool* align_corners) {
-    const auto* tf_options = static_cast<const T*>(tflite_node->builtin_data);
-    if (!tf_options) {
-      return absl::InternalError("Missing tflite params");
-    }
+    const T* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     *align_corners = tf_options->align_corners;
     return absl::OkStatus();
   }
@@ -1648,12 +1626,8 @@ class Resize2DOperationParser : public TFLiteOperationParser {
   absl::Status GetHalfPixelCentersValue(const TfLiteNode* tflite_node,
                                         bool* half_pixel_centers) {
     if (sampling_type_ == SamplingType::BILINEAR) {
-      const auto* tf_options =
-          static_cast<TfLiteResizeBilinearParams*>(tflite_node->builtin_data);
-      if (!tf_options) {
-        return absl::InternalError(
-            "Missing tflite params for ResizeBilinear op");
-      }
+      const TfLiteResizeBilinearParams* tf_options;
+      RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
       if (tf_options->align_corners && tf_options->half_pixel_centers) {
         return absl::InternalError(
             "If half_pixel_centers is True, align_corners must be False.");
@@ -1809,11 +1783,8 @@ class SoftmaxOperationParser : public TFLiteOperationParser {
     RETURN_IF_ERROR(reader->AddInput(node, 0));
     RETURN_IF_ERROR(reader->AddOutputs(node));
-    const auto* tf_options =
-        static_cast<const TfLiteSoftmaxParams*>(tflite_node->builtin_data);
-    if (!tf_options) {
-      return absl::InternalError("Missing tflite params");
-    }
+    const TfLiteSoftmaxParams* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     if (tf_options->beta != 1) {
       // there is multiply by scalar operation fused in softmax. Make a layer
       // out of it before softmax.
@@ -1857,8 +1828,8 @@ class SpaceToDepthOperationParser : public TFLiteOperationParser {
     node->operation.type = ToString(OperationType::SPACE_TO_DEPTH);
     RETURN_IF_ERROR(reader->AddInput(node, 0));
     RETURN_IF_ERROR(reader->AddOutputs(node));
-    const auto* tf_options =
-        static_cast<const TfLiteSpaceToDepthParams*>(tflite_node->builtin_data);
+    const TfLiteSpaceToDepthParams* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     SpaceToDepthAttributes attr;
     attr.block_size = tf_options->block_size;
     node->operation.attributes = attr;
@@ -1898,14 +1869,12 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
           "Slicing is supported for 3 or 4 dimensional tensors only.");
     }
-    const auto* tf_options =
-        static_cast<const TfLiteStridedSliceParams*>(tflite_node->builtin_data);
-    auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
-    if (!tf_options) {
-      return absl::InternalError("Missing tflite params");
-    }
+    const TfLiteStridedSliceParams* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     RETURN_IF_ERROR(CheckOptionsSupport(tf_options));
+    auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
+
     SliceAttributes attr;
     if (read_without_batch) {
       RETURN_IF_ERROR(ReadAttribsWithoutBatch(reader, tf_options,
@@ -2074,11 +2043,9 @@ class TransposeConvOperationParser : public TFLiteOperationParser {
     RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
     RETURN_IF_ERROR(reader->AddOutputs(node));
-    const auto* tf_options = static_cast<const TfLiteTransposeConvParams*>(
-        tflite_node->builtin_data);
-    if (!tf_options) {
-      return absl::InternalError("Missing tflite options.");
-    }
+    const TfLiteTransposeConvParams* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     ConvolutionTransposedAttributes attr;
     attr.stride = tf_options
                       ? HW(tf_options->stride_height, tf_options->stride_width)
@@ -2157,11 +2124,10 @@ class Unpooling2DOperationParser : public TFLiteOperationParser {
     RETURN_IF_ERROR(reader->AddOutputs(node));
     auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
     MaxUnpooling2DAttributes attr;
-    const auto* tf_options =
-        static_cast<const TfLitePoolParams*>(tflite_node->custom_initial_data);
-    if (!tf_options) {
-      return absl::InternalError("Missing tflite params");
-    }
+
+    const TfLitePoolParams* tf_options;
+    RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
     attr.kernel = ToHW(tf_options->filter_height, tf_options->filter_width);
     attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
     UpdatePadding(tf_options->padding, input_shape, &attr);