Use RetrieveBuiltinData and RetrieveCustomInitialData instead of directly accessing builtin_data and custom_initial_data.

PiperOrigin-RevId: 314183617
Change-Id: I8494bf3fe28fc8e3eeba21611beccc48354eac96
Robert David 2020-06-01 12:44:39 -07:00 committed by TensorFlower Gardener
parent 6f22fa9376
commit 21896de7bc
1 changed file with 49 additions and 83 deletions
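For context, both helpers replace the repeated cast-and-null-check pattern that this change deletes from every parser below. A minimal sketch of what they presumably look like, assuming templated helpers over the params type (the exact signatures, location, and error messages in the delegate may differ):

template <typename ParamsT>
absl::Status RetrieveBuiltinData(const TfLiteNode* tflite_node,
                                 const ParamsT** tf_options) {
  // Cast the node's builtin_data once, and report a status instead of
  // leaving callers to handle a raw null pointer.
  *tf_options = static_cast<const ParamsT*>(tflite_node->builtin_data);
  if (!*tf_options) {
    return absl::InternalError("Unable to retrieve builtin_data.");
  }
  return absl::OkStatus();
}

template <typename ParamsT>
absl::Status RetrieveCustomInitialData(const TfLiteNode* tflite_node,
                                       const ParamsT** tf_options) {
  // Same pattern for custom ops, whose options arrive in
  // custom_initial_data rather than builtin_data.
  *tf_options =
      static_cast<const ParamsT*>(tflite_node->custom_initial_data);
  if (!*tf_options) {
    return absl::InternalError("Unable to retrieve custom_initial_data.");
  }
  return absl::OkStatus();
}

Callers either propagate failure with RETURN_IF_ERROR when the options are required, or branch on .ok() when they are optional, as in ElementwiseOperationParser below.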


@@ -422,11 +422,8 @@ class AddOperationParser : public TFLiteOperationParser {
     AddAttributes attr;
     RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
     node->operation.attributes = std::move(attr);
-    const auto* tf_options =
-        static_cast<const TfLiteAddParams*>(tflite_node->builtin_data);
-    if (!tf_options) {
-      return absl::InternalError("Missing tflite params");
-    }
+    const TfLiteAddParams* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     return MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph,
                                                 node);
   }
@@ -499,11 +496,8 @@ class ConcatenationOperationParser : public TFLiteOperationParser {
         break;
       }
     }
-    const auto* tf_options = static_cast<const TfLiteConcatenationParams*>(
-        tflite_node->builtin_data);
-    if (!tf_options) {
-      return absl::InternalError("Missing tflite params");
-    }
+    const TfLiteConcatenationParams* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation,
                                                          graph, node));
     node->operation.attributes = attr;
@@ -601,11 +595,8 @@ class Conv2DOperationParser : public TFLiteOperationParser {
     }
     reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional
-    const auto* tf_options =
-        static_cast<const TfLiteConvParams*>(tflite_node->builtin_data);
-    if (!tf_options) {
-      return absl::InternalError("Missing tflite params");
-    }
+    const TfLiteConvParams* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
     attr.dilations = HW(tf_options->dilation_height_factor,
                         tf_options->dilation_width_factor);
@@ -639,17 +630,19 @@ class Convolution2DTransposeBiasParser : public TFLiteOperationParser {
     RETURN_IF_ERROR(reader->AddInput(node, 0));
     RETURN_IF_ERROR(reader->AddOutputs(node));
-    const auto* params = static_cast<const TfLiteTransposeConvParams*>(
-        tflite_node->custom_initial_data);
+    const TfLiteTransposeConvParams* tf_options;
+    auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
     ConvolutionTransposedAttributes attr;
-    attr.stride =
-        params ? HW(params->stride_height, params->stride_width) : HW(1, 1);
+    attr.stride = status.ok()
+                      ? HW(tf_options->stride_height, tf_options->stride_width)
+                      : HW(1, 1);
     RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
     reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional
-    UpdatePadding(params->padding, graph->FindInputs(node->id)[0]->tensor.shape,
-                  &attr);
+    UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown,
+                  graph->FindInputs(node->id)[0]->tensor.shape, &attr);
     node->operation.attributes = std::move(attr);
     return absl::OkStatus();
@@ -874,17 +867,15 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
     TfLiteFusedActivation activation = kTfLiteActNone;
     switch (operation_type_) {
       case OperationType::SUB: {
-        const auto* tf_options =
-            static_cast<const TfLiteSubParams*>(tflite_node->builtin_data);
-        if (tf_options != nullptr) {
+        const TfLiteSubParams* tf_options;
+        if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
           activation = tf_options->activation;
         }
         break;
       }
       case OperationType::DIV: {
-        const auto* tf_options =
-            static_cast<const TfLiteDivParams*>(tflite_node->builtin_data);
-        if (tf_options != nullptr) {
+        const TfLiteDivParams* tf_options;
+        if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
           activation = tf_options->activation;
         }
         break;
@@ -1002,8 +993,8 @@ class FullyConnectedOperationParser : public TFLiteOperationParser {
     Node* node = graph->NewNode();
     RETURN_IF_ERROR(reader->AddInput(node, 0));
-    const auto* tf_options = static_cast<const TfLiteFullyConnectedParams*>(
-        tflite_node->builtin_data);
+    const TfLiteFullyConnectedParams* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     if (tf_options->weights_format !=
         kTfLiteFullyConnectedWeightsFormatDefault) {
       return absl::UnimplementedError(
@@ -1112,12 +1103,9 @@ class LSTMOperationParser : public TFLiteOperationParser {
       return absl::InvalidArgumentError("LSTM should have 4 output tensors");
     }
-    const auto* params =
-        static_cast<const TfLiteLSTMParams*>(tflite_node->builtin_data);
-    if (!params) {
-      return absl::InternalError("Missing tflite params");
-    }
-    RETURN_IF_ERROR(CheckParameters(params));
+    const TfLiteLSTMParams* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
+    RETURN_IF_ERROR(CheckParameters(tf_options));
     Node* concat_node = graph->NewNode();
     concat_node->operation.type = ToString(OperationType::CONCAT);
@@ -1251,11 +1239,8 @@ class MulOperationParser : public TFLiteOperationParser {
                                               constant_dims, graph, reader));
     }
-    const auto* tf_options =
-        static_cast<const TfLiteMulParams*>(tflite_node->builtin_data);
-    if (!tf_options) {
-      return absl::InternalError("Missing TfLiteMulParams");
-    }
+    const TfLiteMulParams* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     return MaybeFuseActivationToTheSingleOutput(tf_options->activation, graph,
                                                 node);
   }
@@ -1338,8 +1323,8 @@ class PadOperationParser : public TFLiteOperationParser {
                            const TfLiteNode* tflite_node,
                            const TfLiteRegistration* registration) final {
     if (mirror_pad_) {
-      auto* tf_options = static_cast<const TfLiteMirrorPaddingParams*>(
-          tflite_node->builtin_data);
+      const TfLiteMirrorPaddingParams* tf_options;
+      RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
       if (tf_options->mode !=
           TfLiteMirrorPaddingMode::kTfLiteMirrorPaddingReflect) {
         return absl::InvalidArgumentError(
@@ -1444,14 +1429,9 @@ class Pooling2DOperationParser : public TFLiteOperationParser {
     // is MaxPoolingWithArgmax2D. There is no way to read
     // tflite_node->builtin_code, so, simply check whether custom data is
     // available.
-    auto* tf_options =
-        static_cast<const TfLitePoolParams*>(tflite_node->custom_initial_data);
-    if (!tf_options) {
-      tf_options =
-          static_cast<const TfLitePoolParams*>(tflite_node->builtin_data);
-    }
-    if (!tf_options) {
-      return absl::InternalError("Missing tflite params");
+    const TfLitePoolParams* tf_options;
+    if (!RetrieveCustomInitialData(tflite_node, &tf_options).ok()) {
+      RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     }
     std::vector<uint32_t> max_tensor_id{0};
@@ -1637,10 +1617,8 @@ class Resize2DOperationParser : public TFLiteOperationParser {
   template <class T>
   absl::Status GetAlignCornersValueForType(const TfLiteNode* tflite_node,
                                            bool* align_corners) {
-    const auto* tf_options = static_cast<const T*>(tflite_node->builtin_data);
-    if (!tf_options) {
-      return absl::InternalError("Missing tflite params");
-    }
+    const T* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     *align_corners = tf_options->align_corners;
     return absl::OkStatus();
   }
@@ -1648,12 +1626,8 @@ class Resize2DOperationParser : public TFLiteOperationParser {
   absl::Status GetHalfPixelCentersValue(const TfLiteNode* tflite_node,
                                         bool* half_pixel_centers) {
     if (sampling_type_ == SamplingType::BILINEAR) {
-      const auto* tf_options =
-          static_cast<TfLiteResizeBilinearParams*>(tflite_node->builtin_data);
-      if (!tf_options) {
-        return absl::InternalError(
-            "Missing tflite params for ResizeBilinear op");
-      }
+      const TfLiteResizeBilinearParams* tf_options;
+      RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
       if (tf_options->align_corners && tf_options->half_pixel_centers) {
         return absl::InternalError(
             "If half_pixel_centers is True, align_corners must be False.");
@@ -1809,11 +1783,8 @@ class SoftmaxOperationParser : public TFLiteOperationParser {
     RETURN_IF_ERROR(reader->AddInput(node, 0));
     RETURN_IF_ERROR(reader->AddOutputs(node));
-    const auto* tf_options =
-        static_cast<const TfLiteSoftmaxParams*>(tflite_node->builtin_data);
-    if (!tf_options) {
-      return absl::InternalError("Missing tflite params");
-    }
+    const TfLiteSoftmaxParams* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     if (tf_options->beta != 1) {
       // there is multiply by scalar operation fused in softmax. Make a layer
       // out of it before softmax.
@@ -1857,8 +1828,8 @@ class SpaceToDepthOperationParser : public TFLiteOperationParser {
     node->operation.type = ToString(OperationType::SPACE_TO_DEPTH);
     RETURN_IF_ERROR(reader->AddInput(node, 0));
     RETURN_IF_ERROR(reader->AddOutputs(node));
-    const auto* tf_options =
-        static_cast<const TfLiteSpaceToDepthParams*>(tflite_node->builtin_data);
+    const TfLiteSpaceToDepthParams* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     SpaceToDepthAttributes attr;
     attr.block_size = tf_options->block_size;
     node->operation.attributes = attr;
@@ -1898,14 +1869,12 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
           "Slicing is supported for 3 or 4 dimensional tensors only.");
     }
-    const auto* tf_options =
-        static_cast<const TfLiteStridedSliceParams*>(tflite_node->builtin_data);
-    auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
-    if (!tf_options) {
-      return absl::InternalError("Missing tflite params");
-    }
+    const TfLiteStridedSliceParams* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     RETURN_IF_ERROR(CheckOptionsSupport(tf_options));
+    auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
     SliceAttributes attr;
     if (read_without_batch) {
       RETURN_IF_ERROR(ReadAttribsWithoutBatch(reader, tf_options,
@@ -2074,11 +2043,9 @@ class TransposeConvOperationParser : public TFLiteOperationParser {
     RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
     RETURN_IF_ERROR(reader->AddOutputs(node));
-    const auto* tf_options = static_cast<const TfLiteTransposeConvParams*>(
-        tflite_node->builtin_data);
-    if (!tf_options) {
-      return absl::InternalError("Missing tflite options.");
-    }
+    const TfLiteTransposeConvParams* tf_options;
+    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     ConvolutionTransposedAttributes attr;
     attr.stride = tf_options
                       ? HW(tf_options->stride_height, tf_options->stride_width)
                       : HW(1, 1);
@@ -2157,11 +2124,10 @@ class Unpooling2DOperationParser : public TFLiteOperationParser {
     RETURN_IF_ERROR(reader->AddOutputs(node));
     auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
     MaxUnpooling2DAttributes attr;
-    const auto* tf_options =
-        static_cast<const TfLitePoolParams*>(tflite_node->custom_initial_data);
-    if (!tf_options) {
-      return absl::InternalError("Missing tflite params");
-    }
+    const TfLitePoolParams* tf_options;
+    RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
     attr.kernel = ToHW(tf_options->filter_height, tf_options->filter_width);
     attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
     UpdatePadding(tf_options->padding, input_shape, &attr);