Use RetrieveBuiltinData and RetrieveCustomInitialData instead of directly accessing builtin_data and custom_initial_data.
PiperOrigin-RevId: 314183617 Change-Id: I8494bf3fe28fc8e3eeba21611beccc48354eac96
This commit is contained in:
parent
6f22fa9376
commit
21896de7bc
|
@ -422,11 +422,8 @@ class AddOperationParser : public TFLiteOperationParser {
|
|||
AddAttributes attr;
|
||||
RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
|
||||
node->operation.attributes = std::move(attr);
|
||||
const auto* tf_options =
|
||||
static_cast<const TfLiteAddParams*>(tflite_node->builtin_data);
|
||||
if (!tf_options) {
|
||||
return absl::InternalError("Missing tflite params");
|
||||
}
|
||||
const TfLiteAddParams* tf_options;
|
||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||
return MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph,
|
||||
node);
|
||||
}
|
||||
|
@ -499,11 +496,8 @@ class ConcatenationOperationParser : public TFLiteOperationParser {
|
|||
break;
|
||||
}
|
||||
}
|
||||
const auto* tf_options = static_cast<const TfLiteConcatenationParams*>(
|
||||
tflite_node->builtin_data);
|
||||
if (!tf_options) {
|
||||
return absl::InternalError("Missing tflite params");
|
||||
}
|
||||
const TfLiteConcatenationParams* tf_options;
|
||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||
RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation,
|
||||
graph, node));
|
||||
node->operation.attributes = attr;
|
||||
|
@ -601,11 +595,8 @@ class Conv2DOperationParser : public TFLiteOperationParser {
|
|||
}
|
||||
reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
|
||||
|
||||
const auto* tf_options =
|
||||
static_cast<const TfLiteConvParams*>(tflite_node->builtin_data);
|
||||
if (!tf_options) {
|
||||
return absl::InternalError("Missing tflite params");
|
||||
}
|
||||
const TfLiteConvParams* tf_options;
|
||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||
attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
|
||||
attr.dilations = HW(tf_options->dilation_height_factor,
|
||||
tf_options->dilation_width_factor);
|
||||
|
@ -639,17 +630,19 @@ class Convolution2DTransposeBiasParser : public TFLiteOperationParser {
|
|||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||
|
||||
const auto* params = static_cast<const TfLiteTransposeConvParams*>(
|
||||
tflite_node->custom_initial_data);
|
||||
const TfLiteTransposeConvParams* tf_options;
|
||||
auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
|
||||
|
||||
ConvolutionTransposedAttributes attr;
|
||||
attr.stride =
|
||||
params ? HW(params->stride_height, params->stride_width) : HW(1, 1);
|
||||
attr.stride = status.ok()
|
||||
? HW(tf_options->stride_height, tf_options->stride_width)
|
||||
: HW(1, 1);
|
||||
|
||||
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
|
||||
reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
|
||||
|
||||
UpdatePadding(params->padding, graph->FindInputs(node->id)[0]->tensor.shape,
|
||||
&attr);
|
||||
UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown,
|
||||
graph->FindInputs(node->id)[0]->tensor.shape, &attr);
|
||||
|
||||
node->operation.attributes = std::move(attr);
|
||||
return absl::OkStatus();
|
||||
|
@ -874,17 +867,15 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
|
|||
TfLiteFusedActivation activation = kTfLiteActNone;
|
||||
switch (operation_type_) {
|
||||
case OperationType::SUB: {
|
||||
const auto* tf_options =
|
||||
static_cast<const TfLiteSubParams*>(tflite_node->builtin_data);
|
||||
if (tf_options != nullptr) {
|
||||
const TfLiteSubParams* tf_options;
|
||||
if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
|
||||
activation = tf_options->activation;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case OperationType::DIV: {
|
||||
const auto* tf_options =
|
||||
static_cast<const TfLiteDivParams*>(tflite_node->builtin_data);
|
||||
if (tf_options != nullptr) {
|
||||
const TfLiteDivParams* tf_options;
|
||||
if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
|
||||
activation = tf_options->activation;
|
||||
}
|
||||
break;
|
||||
|
@ -1002,8 +993,8 @@ class FullyConnectedOperationParser : public TFLiteOperationParser {
|
|||
Node* node = graph->NewNode();
|
||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||
|
||||
const auto* tf_options = static_cast<const TfLiteFullyConnectedParams*>(
|
||||
tflite_node->builtin_data);
|
||||
const TfLiteFullyConnectedParams* tf_options;
|
||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||
if (tf_options->weights_format !=
|
||||
kTfLiteFullyConnectedWeightsFormatDefault) {
|
||||
return absl::UnimplementedError(
|
||||
|
@ -1112,12 +1103,9 @@ class LSTMOperationParser : public TFLiteOperationParser {
|
|||
return absl::InvalidArgumentError("LSTM should have 4 output tensors");
|
||||
}
|
||||
|
||||
const auto* params =
|
||||
static_cast<const TfLiteLSTMParams*>(tflite_node->builtin_data);
|
||||
if (!params) {
|
||||
return absl::InternalError("Missing tflite params");
|
||||
}
|
||||
RETURN_IF_ERROR(CheckParameters(params));
|
||||
const TfLiteLSTMParams* tf_options;
|
||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||
RETURN_IF_ERROR(CheckParameters(tf_options));
|
||||
|
||||
Node* concat_node = graph->NewNode();
|
||||
concat_node->operation.type = ToString(OperationType::CONCAT);
|
||||
|
@ -1251,11 +1239,8 @@ class MulOperationParser : public TFLiteOperationParser {
|
|||
constant_dims, graph, reader));
|
||||
}
|
||||
|
||||
const auto* tf_options =
|
||||
static_cast<const TfLiteMulParams*>(tflite_node->builtin_data);
|
||||
if (!tf_options) {
|
||||
return absl::InternalError("Missing TfLiteMulParams");
|
||||
}
|
||||
const TfLiteMulParams* tf_options;
|
||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||
return MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph,
|
||||
node);
|
||||
}
|
||||
|
@ -1338,8 +1323,8 @@ class PadOperationParser : public TFLiteOperationParser {
|
|||
const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration) final {
|
||||
if (mirror_pad_) {
|
||||
auto* tf_options = static_cast<const TfLiteMirrorPaddingParams*>(
|
||||
tflite_node->builtin_data);
|
||||
const TfLiteMirrorPaddingParams* tf_options;
|
||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||
if (tf_options->mode !=
|
||||
TfLiteMirrorPaddingMode::kTfLiteMirrorPaddingReflect) {
|
||||
return absl::InvalidArgumentError(
|
||||
|
@ -1444,14 +1429,9 @@ class Pooling2DOperationParser : public TFLiteOperationParser {
|
|||
// is MaxPoolingWithArgmax2D. There is no way to read
|
||||
// tflite_node->builtin_code, so, simply check whether custom data is
|
||||
// available.
|
||||
auto* tf_options =
|
||||
static_cast<const TfLitePoolParams*>(tflite_node->custom_initial_data);
|
||||
if (!tf_options) {
|
||||
tf_options =
|
||||
static_cast<const TfLitePoolParams*>(tflite_node->builtin_data);
|
||||
}
|
||||
if (!tf_options) {
|
||||
return absl::InternalError("Missing tflite params");
|
||||
const TfLitePoolParams* tf_options;
|
||||
if (!RetrieveCustomInitialData(tflite_node, &tf_options).ok()) {
|
||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||
}
|
||||
|
||||
std::vector<uint32_t> max_tensor_id{0};
|
||||
|
@ -1637,10 +1617,8 @@ class Resize2DOperationParser : public TFLiteOperationParser {
|
|||
template <class T>
|
||||
absl::Status GetAlignCornersValueForType(const TfLiteNode* tflite_node,
|
||||
bool* align_corners) {
|
||||
const auto* tf_options = static_cast<const T*>(tflite_node->builtin_data);
|
||||
if (!tf_options) {
|
||||
return absl::InternalError("Missing tflite params");
|
||||
}
|
||||
const T* tf_options;
|
||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||
*align_corners = tf_options->align_corners;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -1648,12 +1626,8 @@ class Resize2DOperationParser : public TFLiteOperationParser {
|
|||
absl::Status GetHalfPixelCentersValue(const TfLiteNode* tflite_node,
|
||||
bool* half_pixel_centers) {
|
||||
if (sampling_type_ == SamplingType::BILINEAR) {
|
||||
const auto* tf_options =
|
||||
static_cast<TfLiteResizeBilinearParams*>(tflite_node->builtin_data);
|
||||
if (!tf_options) {
|
||||
return absl::InternalError(
|
||||
"Missing tflite params for ResizeBilinear op");
|
||||
}
|
||||
const TfLiteResizeBilinearParams* tf_options;
|
||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||
if (tf_options->align_corners && tf_options->half_pixel_centers) {
|
||||
return absl::InternalError(
|
||||
"If half_pixel_centers is True, align_corners must be False.");
|
||||
|
@ -1809,11 +1783,8 @@ class SoftmaxOperationParser : public TFLiteOperationParser {
|
|||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||
|
||||
const auto* tf_options =
|
||||
static_cast<const TfLiteSoftmaxParams*>(tflite_node->builtin_data);
|
||||
if (!tf_options) {
|
||||
return absl::InternalError("Missing tflite params");
|
||||
}
|
||||
const TfLiteSoftmaxParams* tf_options;
|
||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||
if (tf_options->beta != 1) {
|
||||
// there is multiply by scalar operation fused in softmax. Make a layer
|
||||
// out of it before softmax.
|
||||
|
@ -1857,8 +1828,8 @@ class SpaceToDepthOperationParser : public TFLiteOperationParser {
|
|||
node->operation.type = ToString(OperationType::SPACE_TO_DEPTH);
|
||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||
const auto* tf_options =
|
||||
static_cast<const TfLiteSpaceToDepthParams*>(tflite_node->builtin_data);
|
||||
const TfLiteSpaceToDepthParams* tf_options;
|
||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||
SpaceToDepthAttributes attr;
|
||||
attr.block_size = tf_options->block_size;
|
||||
node->operation.attributes = attr;
|
||||
|
@ -1898,14 +1869,12 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
|
|||
"Slicing is supported for 3 or 4 dimensional tensors only.");
|
||||
}
|
||||
|
||||
const auto* tf_options =
|
||||
static_cast<const TfLiteStridedSliceParams*>(tflite_node->builtin_data);
|
||||
auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
|
||||
if (!tf_options) {
|
||||
return absl::InternalError("Missing tflite params");
|
||||
}
|
||||
const TfLiteStridedSliceParams* tf_options;
|
||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||
RETURN_IF_ERROR(CheckOptionsSupport(tf_options));
|
||||
|
||||
auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
|
||||
|
||||
SliceAttributes attr;
|
||||
if (read_without_batch) {
|
||||
RETURN_IF_ERROR(ReadAttribsWithoutBatch(reader, tf_options,
|
||||
|
@ -2074,11 +2043,9 @@ class TransposeConvOperationParser : public TFLiteOperationParser {
|
|||
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
|
||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||
|
||||
const auto* tf_options = static_cast<const TfLiteTransposeConvParams*>(
|
||||
tflite_node->builtin_data);
|
||||
if (!tf_options) {
|
||||
return absl::InternalError("Missing tflite options.");
|
||||
}
|
||||
const TfLiteTransposeConvParams* tf_options;
|
||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||
|
||||
ConvolutionTransposedAttributes attr;
|
||||
attr.stride = tf_options
|
||||
? HW(tf_options->stride_height, tf_options->stride_width)
|
||||
|
@ -2157,11 +2124,10 @@ class Unpooling2DOperationParser : public TFLiteOperationParser {
|
|||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||
auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
|
||||
MaxUnpooling2DAttributes attr;
|
||||
const auto* tf_options =
|
||||
static_cast<const TfLitePoolParams*>(tflite_node->custom_initial_data);
|
||||
if (!tf_options) {
|
||||
return absl::InternalError("Missing tflite params");
|
||||
}
|
||||
|
||||
const TfLitePoolParams* tf_options;
|
||||
RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
|
||||
|
||||
attr.kernel = ToHW(tf_options->filter_height, tf_options->filter_width);
|
||||
attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
|
||||
UpdatePadding(tf_options->padding, input_shape, &attr);
|
||||
|
|
Loading…
Reference in New Issue