From a7de1a78fe5ecd0e949923fc98652e124531a866 Mon Sep 17 00:00:00 2001
From: Yunlu Li
Date: Wed, 20 Nov 2019 11:35:55 -0800
Subject: [PATCH] Add int8 transpose conv.

PiperOrigin-RevId: 281565221
Change-Id: I984c4e7e4dbb30a872c63778e52eac0bd91fd999
---
 tensorflow/lite/kernels/internal/BUILD        |   1 +
 .../reference/integer_ops/transpose_conv.h    | 118 ++++++++++++++++++
 tensorflow/lite/kernels/register.cc           |   4 +-
 tensorflow/lite/kernels/test_util.h           |   7 +-
 tensorflow/lite/kernels/transpose_conv.cc     |  99 ++++++++++++---
 .../lite/kernels/transpose_conv_test.cc       |  92 +++++++++++++-
 .../lite/tools/optimize/operator_property.cc  |  10 ++
 .../lite/tools/versioning/op_version.cc       |   7 ++
 .../lite/tools/versioning/op_version_test.cc  |  15 +++
 9 files changed, 333 insertions(+), 20 deletions(-)
 create mode 100644 tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h

diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index b30da135716..8c320720a31 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -421,6 +421,7 @@ cc_library(
         "reference/integer_ops/pooling.h",
         "reference/integer_ops/softmax.h",
         "reference/integer_ops/tanh.h",
+        "reference/integer_ops/transpose_conv.h",
         "reference/logistic.h",
         "reference/maximum_minimum.h",
         "reference/mul.h",
diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h b/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h
new file mode 100644
index 00000000000..1ad6e20f2dc
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h
@@ -0,0 +1,118 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_TRANSPOSE_CONV_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_TRANSPOSE_CONV_H_
+
+#include "tensorflow/lite/kernels/internal/common.h"
+
+namespace tflite {
+namespace reference_integer_ops {
+
+// Fixed-point per-channel-quantization transpose convolution reference kernel.
+inline void TransposeConv(
+    const ConvParams& params, const int32* output_multiplier,
+    const int32* output_shift, const RuntimeShape& input_shape,
+    const int8* input_data, const RuntimeShape& filter_shape,
+    const int8* filter_data, const RuntimeShape& output_shape,
+    int8* output_data, const RuntimeShape& im2col_shape, int8* im2col_data,
+    int32* scratch_buffer) {
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+  (void)im2col_data;   // only used in optimized code.
+  (void)im2col_shape;  // only used in optimized code.
+
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+  const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  const int32 input_offset = params.input_offset;
+  const int32 output_offset = params.output_offset;
+  const int32 output_activation_min = std::numeric_limits<int8_t>::min();
+  const int32 output_activation_max = std::numeric_limits<int8_t>::max();
+  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+
+  const int num_elements = output_shape.FlatSize();
+  // We need to initialize scratch_buffer to all 0s, as we apply the same
+  // 'scatter' based trick as in float version.
+  memset(scratch_buffer, 0, num_elements * sizeof(int32));
+
+  // Loop through input elements one at a time.
+  for (int batch = 0; batch < batches; ++batch) {
+    for (int in_y = 0; in_y < input_height; ++in_y) {
+      for (int in_x = 0; in_x < input_width; ++in_x) {
+        for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
+          // Loop through the output elements it will influence.
+          const int out_x_origin = (in_x * stride_width) - pad_width;
+          const int out_y_origin = (in_y * stride_height) - pad_height;
+          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+            for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+              for (int out_channel = 0; out_channel < output_depth;
+                   ++out_channel) {
+                // Compute output element location.
+                const int out_x = out_x_origin + filter_x;
+                const int out_y = out_y_origin + filter_y;
+                // We cannot accumulate out of bounds.
+                if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) &&
+                    (out_y < output_height)) {
+                  const int8 input_value = input_data[Offset(
+                      input_shape, batch, in_y, in_x, in_channel)];
+                  const int8 filter_value =
+                      filter_data[Offset(filter_shape, out_channel, filter_y,
+                                         filter_x, in_channel)];
+                  scratch_buffer[Offset(output_shape, batch, out_y, out_x,
+                                        out_channel)] +=
+                      (input_value + input_offset) * filter_value;
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
+  for (int batch = 0; batch < batches; ++batch) {
+    for (int out_y = 0; out_y < output_height; ++out_y) {
+      for (int out_x = 0; out_x < output_width; ++out_x) {
+        for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
+          int32 acc = scratch_buffer[Offset(output_shape, batch, out_y, out_x,
+                                            out_channel)];
+          acc = MultiplyByQuantizedMultiplier(
+              acc, output_multiplier[out_channel], output_shift[out_channel]);
+          acc += output_offset;
+          acc = std::max(acc, output_activation_min);
+          acc = std::min(acc, output_activation_max);
+          output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] =
+              static_cast<int8_t>(acc);
+        }
+      }
+    }
+  }
+}
+
+}  // namespace reference_integer_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_TRANSPOSE_CONV_H_
diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc
index 8f2c3e4c30e..68e102511eb 100644
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -205,7 +205,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
              /* max_version */ 3);
   AddBuiltin(BuiltinOperator_SIN, Register_SIN());
   AddBuiltin(BuiltinOperator_COS, Register_COS());
-  AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV());
+  AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV(),
+             /* min_version */ 1,
+             /* max_version */ 2);
   AddBuiltin(BuiltinOperator_TILE, Register_TILE());
   AddBuiltin(BuiltinOperator_SUM, Register_SUM(),
              /* min_version */ 1,
diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h
index 61197b78d1d..2180405c139 100644
--- a/tensorflow/lite/kernels/test_util.h
+++ b/tensorflow/lite/kernels/test_util.h
@@ -168,7 +168,12 @@ class SingleOpModel {
   // Templated version of AddConstInput().
   template <typename T>
   int AddConstInput(const TensorData& t, std::initializer_list<T> data) {
-    int id = AddTensor(t, data);
+    int id = 0;
+    if (t.per_channel_quantization) {
+      id = AddTensorPerChannelQuant(t);
+    } else {
+      id = AddTensor(t, data);
+    }
     inputs_.push_back(id);
     return id;
   }
diff --git a/tensorflow/lite/kernels/transpose_conv.cc b/tensorflow/lite/kernels/transpose_conv.cc
index 51b51bf885b..4ae74e628a2 100644
--- a/tensorflow/lite/kernels/transpose_conv.cc
+++ b/tensorflow/lite/kernels/transpose_conv.cc
@@ -24,6 +24,8 @@ limitations under the License.
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
 #include "tensorflow/lite/kernels/eigen_support.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+// NOLINTNEXTLINE - This header file shouldn't go to the top.
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/internal/types.h"
@@ -75,6 +77,12 @@ struct OpData {
   int32_t output_multiplier;
   int output_shift;
 
+  // Per channel output multiplier and shift.
+  // TODO(b/144846950): Add channel dimension index for the kernel to be more
+  // flexible.
+  std::vector<int32_t> per_channel_output_multiplier;
+  std::vector<int32_t> per_channel_output_shift;
+
   // The range of the fused activation layer. For example for kNone and
   // uint8_t these would be 0 and 255.
   int32_t output_activation_min;
@@ -144,7 +152,7 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
   }
 
   // Allocate scratch buffer tensor for UInt8 inputs.
-  if (input_type == kTfLiteUInt8) {
+  if (input_type == kTfLiteUInt8 || input_type == kTfLiteInt8) {
     if (data->scratch_tensor_id == kTensorNotAllocated) {
       context->AddTensors(context, 1, &data->scratch_tensor_id);
     }
@@ -214,6 +222,11 @@ TfLiteStatus ResizeAndTransposeWeights(TfLiteContext* context,
                              GetTensorData<uint8>(weights),
                              GetTensorShape(transposed_weights),
                              GetTensorData<uint8>(transposed_weights));
+  } else if (weights->type == kTfLiteInt8) {
+    optimized_ops::Transpose(transpose_params, input_shape,
+                             GetTensorData<int8>(weights),
+                             GetTensorShape(transposed_weights),
+                             GetTensorData<int8>(transposed_weights));
   } else {
     context->ReportError(
         context, "Transpose conv only support float & uint8 right now.");
@@ -242,8 +255,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
   TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 4);
-  TF_LITE_ENSURE(context,
-                 input->type == kTfLiteFloat32 || input->type == kTfLiteUInt8);
+  TF_LITE_ENSURE(context, input->type == kTfLiteFloat32 ||
+                              input->type == kTfLiteUInt8 ||
+                              input->type == kTfLiteInt8);
   TF_LITE_ENSURE_EQ(context, weights->type, input->type);
   TF_LITE_ENSURE_EQ(context, output->type, input->type);
   // Ensure that weights and inputs have the same channel dimension.
@@ -288,7 +302,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     }
   }
 
-  if (input->type == kTfLiteUInt8) {
+  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
     node->temporaries->data[data->scratch_tensor_index] =
         data->scratch_tensor_id;
     TfLiteTensor* scratch_buffer =
@@ -302,19 +316,24 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
           ResizeTensor(context, output_shape, scratch_buffer));
     }
 
-    // Calcuate output multiplier for quantization.
-    double real_multiplier = 0.0;
-    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
-        context, input, weights, output, &real_multiplier));
-    int exponent;
-    // Populate quantization parameteters with multiplier and shift.
-    QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
-    data->output_shift = -exponent;
-    // Populate max and min activation range.
-    CalculateActivationRangeUint8(kTfLiteActNone, output,
-                                  &data->output_activation_min,
-                                  &data->output_activation_max);
+    TF_LITE_ENSURE_EQ(context, weights->quantization.type,
+                      kTfLiteAffineQuantization);
+    const auto* affine_quantization =
+        reinterpret_cast<TfLiteAffineQuantization*>(
+            weights->quantization.params);
+    TF_LITE_ENSURE(context, affine_quantization);
+    TF_LITE_ENSURE(context, affine_quantization->scale);
+    const int number_channel = affine_quantization->scale->size;
+    data->per_channel_output_multiplier.resize(number_channel);
+    data->per_channel_output_shift.resize(number_channel);
+    TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
+        context, input, weights, nullptr, output, kTfLiteActNone,
+        &data->output_multiplier, &data->output_shift,
+        &data->output_activation_min, &data->output_activation_max,
+        data->per_channel_output_multiplier.data(),
+        data->per_channel_output_shift.data()));
   }
+
   return kTfLiteOk;
 }
@@ -403,6 +422,39 @@ void EvalQuantized(TfLiteContext* context,
   }
 }
 
+void EvalQuantizedPerChannel(TfLiteContext* context,
+                             const TfLiteTransposeConvParams* params,
+                             OpData* data, const TfLiteTensor* input,
+                             const TfLiteTensor* weights,
+                             const TfLiteTensor* transposed_weights,
+                             TfLiteTensor* col2im, TfLiteTensor* output,
+                             TfLiteTensor* scratch_buffer) {
+  tflite::ConvParams op_params;
+  op_params.padding_type = PaddingType::kSame;
+  op_params.padding_values.width = data->padding.width;
+  op_params.padding_values.height = data->padding.height;
+  op_params.padding_values.width_offset = data->padding.width_offset;
+  op_params.padding_values.height_offset = data->padding.height_offset;
+  op_params.stride_width = params->stride_width;
+  op_params.stride_height = params->stride_height;
+  // Need to flip the sign of input offset to add it directly to the quantized
+  // buffer.
+  op_params.input_offset = -input->params.zero_point;
+  op_params.output_offset = output->params.zero_point;
+  op_params.quantized_activation_min = data->output_activation_min;
+  op_params.quantized_activation_max = data->output_activation_max;
+
+  // TODO(b/143380105): Need to add optimized kernel for int8 quantized
+  // transpose conv.
+  reference_integer_ops::TransposeConv(
+      op_params, data->per_channel_output_multiplier.data(),
+      data->per_channel_output_shift.data(), GetTensorShape(input),
+      GetTensorData<int8>(input), GetTensorShape(weights),
+      GetTensorData<int8>(weights), GetTensorShape(output),
+      GetTensorData<int8>(output), GetTensorShape(col2im),
+      GetTensorData<int8>(col2im), GetTensorData<int32>(scratch_buffer));
+}
+
 template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   // Retrieve tensors (All should be allocated by now)
@@ -473,6 +525,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                     scratch_buffer);
       break;
     }
+    case kTfLiteInt8: {
+      TfLiteTensor* scratch_buffer =
+          GetTemporary(context, node, data->scratch_tensor_index);
+      if (IsDynamicTensor(scratch_buffer)) {
+        TF_LITE_ENSURE_OK(context,
+                          ResizeTensor(context, output_shape, scratch_buffer));
+      }
+      if (data->weights_are_transposed && !IsConstantTensor(weights)) {
+        ResizeAndTransposeWeights(context, weights, transposed_weights);
+      }
+      EvalQuantizedPerChannel(context, params, data, input, weights,
+                              transposed_weights, col2im, output,
+                              scratch_buffer);
+      break;
+    }
     default:
       context->ReportError(context, "Type '%s' is not currently supported.",
                            TfLiteTypeGetName(input->type));
diff --git a/tensorflow/lite/kernels/transpose_conv_test.cc b/tensorflow/lite/kernels/transpose_conv_test.cc
index 8f89630d2d7..9a1a950fe0f 100644
--- a/tensorflow/lite/kernels/transpose_conv_test.cc
+++ b/tensorflow/lite/kernels/transpose_conv_test.cc
@@ -50,7 +50,7 @@ class BaseTransposeConvOpModel : public SingleOpModel {
                            std::initializer_list<InputType> filter_data,
                            const TensorData& input, const TensorData& output,
                            Padding padding, int stride_w, int stride_h,
-                           TestType test_type) {
+                           TestType test_type, int version = 1) {
     // Just to be confusing, transpose_conv has an _input_ named "output_shape"
     // that sets the shape of the output tensor of the op :). It must always be
     // an int32 1D four element tensor.
@@ -70,7 +70,7 @@ class BaseTransposeConvOpModel : public SingleOpModel {
                  CreateTransposeConvOptions(builder_, padding, stride_w, stride_h)
                      .Union());
     resolver_ = absl::make_unique<SingleOpResolver>(
-        BuiltinOperator_TRANSPOSE_CONV, registration);
+        BuiltinOperator_TRANSPOSE_CONV, registration, version);
     BuildInterpreter(
         {GetShape(output_shape_), GetShape(filter_), GetShape(input_)});
 
@@ -83,6 +83,8 @@
   void SetInput(std::initializer_list<float> data) {
     if (std::is_same<InputType, uint8_t>::value) {
       QuantizeAndPopulate<uint8_t>(input_, data);
+    } else if (std::is_same<InputType, int8_t>::value) {
+      QuantizeAndPopulate<int8_t>(input_, data);
     } else {
       PopulateTensor(input_, data);
     }
@@ -313,6 +315,92 @@ TEST_P(TransposeConvOpTest, SimpleTestQuantized) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
 }
 
+class PerChannelQuantizedTransposeConvOpModel
+    : public BaseTransposeConvOpModel<int8_t> {
+ public:
+  using BaseTransposeConvOpModel::BaseTransposeConvOpModel;
+
+  std::vector<float> GetDequantizedOutput() {
+    return Dequantize<int8_t>(ExtractVector<int8_t>(output_), GetScale(output_),
+                              GetZeroPoint(output_));
+  }
+
+  void SetInput(const std::initializer_list<float>& data) {
+    QuantizeAndPopulate<int8_t>(input_, data);
+  }
+
+  void SetFilter(const std::initializer_list<float>& data) {
+    PerChannelSymmetricQuantizeAndPopulate(filter_, data);
+  }
+};
+
+TEST_P(TransposeConvOpTest, SimpleTestQuantizedPerChannelSingleChannel) {
+  // TODO(b/138722124): Enable these tests on NNAPI.
+  if (SingleOpModel::GetForceUseNnapi()) {
+    return;
+  }
+
+  const std::initializer_list<float> filter_data = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+  PerChannelQuantizedTransposeConvOpModel model(
+      GetRegistration(), {1, 4, 4, 1},
+      {TensorType_INT8, {1, 3, 3, 1}, 0, 0, 0, 0, true, {9.0 / 127}, {0}, 0},
+      {}, {TensorType_INT8, {1, 4, 4, 1}, 0, 0, 16.0 / 255, -128},
+      {TensorType_INT8, {}, 0, 0, 2, -128}, Padding_SAME, 1, 1, GetTestType(),
+      /* version */ 2);
+  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+  model.SetFilter(filter_data);
+  model.Invoke();
+
+  EXPECT_THAT(
+      model.GetDequantizedOutput(),
+      ElementsAreArray(ArrayFloatNear({28, 62, 82, 76, 98, 192, 238, 198, 206,
+                                       372, 416, 330, 262, 446, 486, 366},
+                                      1e-5)));
+
+  // GetOutputShape() should always be same as model.SetOutputShape(...);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
+}
+
+// Test data copied from the float multi-channel test above.
+TEST_P(TransposeConvOpTest, TestQuantizedPerChannelMultiChannel) {
+  // TODO(b/138722124): Enable these tests on NNAPI.
+  if (SingleOpModel::GetForceUseNnapi()) {
+    return;
+  }
+
+  const std::initializer_list<float> filter_data = {
+      1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, 8, 10, 12, 14, 16, 18};
+  PerChannelQuantizedTransposeConvOpModel model(
+      GetRegistration(), {1, 5, 5, 2},
+      {TensorType_INT8,
+       {2, 3, 3, 1},
+       0,
+       0,
+       0,
+       0,
+       true,
+       {17.0 / 127, 18.0 / 127},
+       {0, 0},
+       0},
+      {}, {TensorType_INT8, {1, 2, 2, 1}, 0, 0, 4.0 / 255, -128},
+      {TensorType_INT8, {}, 0, 0, 1, -128}, Padding_VALID, 2, 2, GetTestType(),
+      /* version */ 2);
+  model.SetInput({1, 2, 3, 4});
+  model.SetFilter(filter_data);
+  model.Invoke();
+
+  EXPECT_THAT(
+      model.GetDequantizedOutput(),
+      ElementsAreArray(ArrayFloatNear(
+          {1,  2,  3,  4,  7,  10, 6,  8,  10, 12, 7,  8,  9,  10,  25,  28, 18,
+           20, 22, 24, 16, 20, 24, 28, 62, 72, 42, 48, 54, 60,  21,  24, 27, 30,
+           61, 68, 36, 40, 44, 48, 39, 42, 45, 48, 103, 110, 60, 64, 68, 72},
+          1e-5)));
+
+  // GetOutputShape() should always be same as model.SetOutputShape(...);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 5, 5, 2}));
+}
+
 TEST_P(TransposeConvOpTest, TwoFiltersTestQuantized) {
   // Float would be {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
   // 18}
diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc
index eec8bacea23..b284e025159 100644
--- a/tensorflow/lite/tools/optimize/operator_property.cc
+++ b/tensorflow/lite/tools/optimize/operator_property.cc
@@ -93,6 +93,16 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       property.version = 3;
       break;
     }
+    case BuiltinOperator_TRANSPOSE_CONV: {
+      TensorProperty tensor_property;
+      tensor_property.per_axis = true;
+      tensor_property.per_axis_index = 0;
+      tensor_property.symmetric = true;
+      property.inputs = {{1, tensor_property}, {2, {}}};
+      property.outputs = {{0, {}}};
+      property.version = 2;
+      break;
+    }
     case BuiltinOperator_DEPTHWISE_CONV_2D: {
       TensorProperty tensor_property;
       tensor_property.per_axis = true;
diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc
index fdb361a5846..e638840606d 100644
--- a/tensorflow/lite/tools/versioning/op_version.cc
+++ b/tensorflow/lite/tools/versioning/op_version.cc
@@ -149,6 +149,13 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       }
       return 1;
 
+    case BuiltinOperator_TRANSPOSE_CONV:
+      // If the op takes int8 input, it is version 2.
+      if (op_sig.input_types.at(0) == TensorType_INT8) {
+        return 2;
+      }
+      return 1;
+
     case BuiltinOperator_LSTM:
       // If the input tensor is float and a weight is int8, this is a version
       // 3 hybrid operation.
diff --git a/tensorflow/lite/tools/versioning/op_version_test.cc b/tensorflow/lite/tools/versioning/op_version_test.cc
index 025b4a2f1a0..adb1e89e44c 100644
--- a/tensorflow/lite/tools/versioning/op_version_test.cc
+++ b/tensorflow/lite/tools/versioning/op_version_test.cc
@@ -351,4 +351,19 @@ TEST(OpVersionTest, VersioningFloorDivOperatorTest) {
   EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
 }
 
+TEST(OpVersionTest, VersioningTransposeConvOperatorTest) {
+  OpSignature fake_op_sig = {
+      .op = BuiltinOperator_TRANSPOSE_CONV,
+      .input_types =
+          std::vector<TensorType>{TensorType_FLOAT32, TensorType_UINT8},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+
+  fake_op_sig = {
+      .op = BuiltinOperator_TRANSPOSE_CONV,
+      .input_types = std::vector<TensorType>{TensorType_INT8},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+}
+
 }  // namespace tflite
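
Reviewer note (illustration only, not part of the patch): the new reference kernel works in two passes. It first scatters (input_value + input_offset) * filter_value into an int32 scratch buffer, then requantizes each accumulator with that output channel's multiplier and shift via MultiplyByQuantizedMultiplier, adds the output zero point, and clamps to the int8 range. For readers who have not seen this fixed-point requantization before, the snippet below is a minimal, self-contained sketch of that step. The helper names mirror the gemmlowp-derived utilities the kernel calls, but the bodies here are a standalone re-implementation written for this note, not the TFLite source, and the concrete numbers (a 0.5 real multiplier, a -128 output zero point, an accumulator of 125) are made-up example values.

// requantize_sketch.cc - illustrative only; not the TFLite implementation.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <limits>

// Returns the high 32 bits of 2*a*b with round-to-nearest, saturating the
// single overflow case (a == b == INT32_MIN).
int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) {
  const bool overflow = a == b && a == std::numeric_limits<int32_t>::min();
  const int64_t ab_64 = static_cast<int64_t>(a) * static_cast<int64_t>(b);
  const int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
  const int32_t ab_x2_high32 =
      static_cast<int32_t>((ab_64 + nudge) / (1ll << 31));
  return overflow ? std::numeric_limits<int32_t>::max() : ab_x2_high32;
}

// Arithmetic right shift by `exponent` with round-to-nearest.
int32_t RoundingDivideByPOT(int32_t x, int exponent) {
  const int32_t mask = static_cast<int32_t>((1ll << exponent) - 1);
  const int32_t remainder = x & mask;
  const int32_t threshold = (mask >> 1) + ((x < 0) ? 1 : 0);
  return (x >> exponent) + ((remainder > threshold) ? 1 : 0);
}

// The real-valued rescale factor is encoded as quantized_multiplier * 2^shift,
// where quantized_multiplier is a Q31 fixed-point value representing roughly
// [0.5, 1.0) and shift is negative (a right shift) for factors below 0.5.
int32_t MultiplyByQuantizedMultiplier(int32_t acc, int32_t quantized_multiplier,
                                      int shift) {
  const int left_shift = shift > 0 ? shift : 0;
  const int right_shift = shift > 0 ? 0 : -shift;
  return RoundingDivideByPOT(
      SaturatingRoundingDoublingHighMul(acc * (1 << left_shift),
                                        quantized_multiplier),
      right_shift);
}

int main() {
  // Hypothetical per-channel rescale factor of 0.5
  // (input_scale * filter_scale[channel] / output_scale), encoded as
  // round(0.5 * 2^31) = 1 << 30 with shift = 0.
  const int32_t quantized_multiplier = 1 << 30;
  const int shift = 0;
  const int32_t acc = 125;             // int32 accumulator from the scatter pass.
  const int32_t output_offset = -128;  // example int8 output zero point.
  const int32_t scaled =
      MultiplyByQuantizedMultiplier(acc, quantized_multiplier, shift);
  const int32_t clamped =
      std::min<int32_t>(127, std::max<int32_t>(-128, scaled + output_offset));
  // 125 * 0.5 = 62.5 rounds to 63, then 63 + (-128) = -65 after the offset.
  std::printf("scaled=%d, int8 output=%d\n", scaled, clamped);
  return 0;
}

The per-channel part of the patch is simply that quantized_multiplier and shift are looked up per output channel, computed from input_scale * filter_scale[channel] / output_scale by PopulateConvolutionQuantizationParams, which is why OpData grows the per_channel_output_multiplier and per_channel_output_shift vectors.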