diff --git a/tensorflow/lite/experimental/micro/kernels/BUILD b/tensorflow/lite/experimental/micro/kernels/BUILD
index 43288c9de60..5121bc3d15b 100644
--- a/tensorflow/lite/experimental/micro/kernels/BUILD
+++ b/tensorflow/lite/experimental/micro/kernels/BUILD
@@ -19,6 +19,7 @@ cc_library(
         "elementwise.cc",
         "fully_connected.cc",
         "pooling.cc",
+        "prelu.cc",
         "softmax.cc",
     ],
     hdrs = [
@@ -59,6 +60,7 @@ cc_library(
         "fully_connected.cc",
         "pooling.cc",
         "portable_optimized/depthwise_conv.cc",
+        "prelu.cc",
         "softmax.cc",
     ],
     hdrs = [
@@ -179,3 +181,16 @@ tflite_micro_cc_test(
         "//tensorflow/lite/experimental/micro/testing:micro_test",
     ],
 )
+
+tflite_micro_cc_test(
+    name = "prelu_test",
+    srcs = [
+        "prelu_test.cc",
+    ],
+    deps = [
+        ":all_ops_resolver",
+        "//tensorflow/lite/c:c_api_internal",
+        "//tensorflow/lite/experimental/micro:micro_framework",
+        "//tensorflow/lite/experimental/micro/testing:micro_test",
+    ],
+)
diff --git a/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc b/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc
index 6fb2e664802..c54cdf78f6c 100644
--- a/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc
+++ b/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc
@@ -23,6 +23,7 @@ TfLiteRegistration* Register_CONV_2D();
 TfLiteRegistration* Register_AVERAGE_POOL_2D();
 TfLiteRegistration* Register_MAX_POOL_2D();
 TfLiteRegistration* Register_ABS();
+TfLiteRegistration* Register_PRELU();
 
 AllOpsResolver::AllOpsResolver() {
   AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D());
@@ -34,6 +35,7 @@ AllOpsResolver::AllOpsResolver() {
   AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D());
   AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_2D());
   AddBuiltin(BuiltinOperator_ABS, Register_ABS());
+  AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
 }
 
 }  // namespace micro
diff --git a/tensorflow/lite/experimental/micro/kernels/prelu.cc b/tensorflow/lite/experimental/micro/kernels/prelu.cc
new file mode 100644
index 00000000000..bfa5b9a0e75
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/kernels/prelu.cc
@@ -0,0 +1,114 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/internal/reference/prelu.h"
+
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+namespace activations {
+
+TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
+  return kTfLiteOk;
+}
+
+inline void BroadcastPrelu4DSlowFloat(
+    const RuntimeShape& unextended_input1_shape, const float* input1_data,
+    const RuntimeShape& unextended_input2_shape, const float* input2_data,
+    const RuntimeShape& unextended_output_shape, float* output_data) {
+  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+  NdArrayDesc<4> desc1;
+  NdArrayDesc<4> desc2;
+  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+                                      unextended_input2_shape, &desc1, &desc2);
+
+  for (int b = 0; b < output_shape.Dims(0); ++b) {
+    for (int y = 0; y < output_shape.Dims(1); ++y) {
+      for (int x = 0; x < output_shape.Dims(2); ++x) {
+        for (int c = 0; c < output_shape.Dims(3); ++c) {
+          auto out_idx = Offset(output_shape, b, y, x, c);
+          auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
+          auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
+          auto in1_val = input1_data[in1_idx];
+          auto in2_val = input2_data[in2_idx];
+          output_data[out_idx] = in1_val >= 0.0 ? in1_val : in1_val * in2_val;
+        }
+      }
+    }
+  }
+}
+
+TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input = GetInput(context, node, 0);
+  const TfLiteTensor* alpha = GetInput(context, node, 1);
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  int32_t output_multiplier = 0;
+  int output_shift = 0;
+  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) {
+    double real_multiplier =
+        input->params.scale * alpha->params.scale / output->params.scale;
+    QuantizeMultiplierSmallerThanOneExp(real_multiplier, &output_multiplier,
+                                        &output_shift);
+  }
+  switch (input->type) {
+    case kTfLiteFloat32: {
+      BroadcastPrelu4DSlowFloat(
+          GetTensorShape(input), GetTensorData<float>(input),
+          GetTensorShape(alpha), GetTensorData<float>(alpha),
+          GetTensorShape(output), GetTensorData<float>(output));
+      return kTfLiteOk;
+    } break;
+    case kTfLiteUInt8: {
+      PreluParams op_params;
+      op_params.input_offset = -input->params.zero_point;
+      op_params.alpha_offset = -alpha->params.zero_point;
+      op_params.output_offset = output->params.zero_point;
+      op_params.output_multiplier = output_multiplier;
+      op_params.output_shift = output_shift;
+      reference_ops::BroadcastPrelu4DSlow(
+          op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+          GetTensorShape(alpha), GetTensorData<uint8_t>(alpha),
+          GetTensorShape(output), GetTensorData<uint8_t>(output));
+      return kTfLiteOk;
+    } break;
+    default:
+      context->ReportError(
+          context, "Only float32 and uint8 are supported currently, got %s.",
+          TfLiteTypeGetName(input->type));
+      return kTfLiteError;
+  }
+}
+
+}  // namespace activations
+
+TfLiteRegistration* Register_PRELU() {
+  static TfLiteRegistration r = {nullptr, nullptr, activations::PreluPrepare,
+                                 activations::PreluEval};
+  return &r;
+}
+
+}  // namespace micro
+}  // namespace ops
+}  // namespace tflite
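Note on the float path in prelu.cc above: BroadcastPrelu4DSlowFloat broadcasts the smaller alpha tensor against the input, so a per-channel alpha of shape [1, 1, 3] is applied at every (batch, y, x) position. The standalone sketch below is illustrative only and not part of the patch; it is plain C++ with shapes and values borrowed from the float unit test that follows, showing the element-wise rule the broadcast reduces to for an NHWC input of shape [1, 2, 2, 3].

```cpp
// Standalone illustration of the PReLU broadcast rule used above:
// out[b][y][x][c] = in >= 0 ? in : in * alpha[c], with alpha shaped [1, 1, 3].
#include <cstdio>

int main() {
  const int kBatch = 1, kHeight = 2, kWidth = 2, kChannels = 3;
  const float input[kBatch * kHeight * kWidth * kChannels] = {
      0.0f, 0.0f, 0.0f,  1.0f,  1.0f,  1.0f,
      -1.0f, -1.0f, -1.0f, -2.0f, -2.0f, -2.0f};
  const float alpha[kChannels] = {0.0f, 1.0f, 2.0f};  // broadcast over b, y, x
  float output[kBatch * kHeight * kWidth * kChannels];

  for (int i = 0; i < kBatch * kHeight * kWidth * kChannels; ++i) {
    const float in = input[i];
    const float a = alpha[i % kChannels];  // channel index in NHWC layout
    output[i] = in >= 0.0f ? in : in * a;
    printf("%6.2f -> %6.2f (alpha %4.2f)\n", in, output[i], a);
  }
  return 0;
}
```

Compiled with any C++11 compiler, this prints the same values the float test below expects for each channel.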
diff --git a/tensorflow/lite/experimental/micro/kernels/prelu_test.cc b/tensorflow/lite/experimental/micro/kernels/prelu_test.cc
new file mode 100644
index 00000000000..583b43ba189
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/kernels/prelu_test.cc
@@ -0,0 +1,204 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
+#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
+#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
+#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+void TestPreluFloat(std::initializer_list<int> input_dims_data,
+                    std::initializer_list<float> input_data,
+                    std::initializer_list<int> alpha_dims_data,
+                    std::initializer_list<float> alpha_data,
+                    std::initializer_list<float> expected_output_data,
+                    std::initializer_list<int> output_dims_data,
+                    float* output_data) {
+  TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
+  TfLiteIntArray* alpha_dims = IntArrayFromInitializer(alpha_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+  const int output_dims_count = ElementCount(*output_dims);
+  constexpr int inputs_size = 2;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateFloatTensor(input_data, input_dims, "input_tensor"),
+      CreateFloatTensor(alpha_data, alpha_dims, "alpha_tensor"),
+      CreateFloatTensor(output_data, output_dims, "output_tensor"),
+  };
+  TfLiteContext context;
+  PopulateContext(tensors, tensors_size, &context);
+  ::tflite::ops::micro::AllOpsResolver resolver;
+  const TfLiteRegistration* registration =
+      resolver.FindOp(tflite::BuiltinOperator_PRELU, 1);
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+  size_t init_data_size = 0;
+  void* user_data = nullptr;
+  if (registration->init) {
+    user_data = registration->init(&context, nullptr, init_data_size);
+  }
+  int inputs_array_data[] = {2, 0, 1};
+  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+  int outputs_array_data[] = {1, 2};
+  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
+  TfLiteNode node;
+  node.inputs = inputs_array;
+  node.outputs = outputs_array;
+  node.temporaries = temporaries_array;
+  node.user_data = user_data;
+  node.builtin_data = nullptr;
+  node.custom_initial_data = nullptr;
+  node.custom_initial_data_size = 0;
+  node.delegate = nullptr;
+  if (registration->prepare) {
+    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+  }
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+  if (registration->free) {
+    registration->free(&context, user_data);
+  }
+  for (int i = 0; i < output_dims_count; ++i) {
+    TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
+                              1e-5f);
+  }
+}
+
+void TestPreluQuantized(std::initializer_list<int> input_dims_data,
+                        std::initializer_list<uint8_t> input_data,
+                        float input_min, float input_max,
+                        std::initializer_list<int> alpha_dims_data,
+                        std::initializer_list<uint8_t> alpha_data,
+                        float alpha_min, float alpha_max,
+                        std::initializer_list<uint8_t> expected_output_data,
+                        std::initializer_list<int> output_dims_data,
+                        float output_min, float output_max,
+                        uint8_t* output_data) {
+  TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
+  TfLiteIntArray* alpha_dims = IntArrayFromInitializer(alpha_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+  const int output_dims_count = ElementCount(*output_dims);
+  constexpr int inputs_size = 2;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateQuantizedTensor(input_data, input_dims, "input_tensor", input_min,
+                            input_max),
+      CreateQuantizedTensor(alpha_data, alpha_dims, "alpha_tensor", alpha_min,
+                            alpha_max),
+      CreateQuantizedTensor(output_data, output_dims, "output_tensor",
+                            output_min, output_max),
+  };
+  TfLiteContext context;
+  PopulateContext(tensors, tensors_size, &context);
+  ::tflite::ops::micro::AllOpsResolver resolver;
+  const TfLiteRegistration* registration =
+      resolver.FindOp(tflite::BuiltinOperator_PRELU, 1);
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+  size_t init_data_size = 0;
+  void* user_data = nullptr;
+  if (registration->init) {
+    user_data = registration->init(&context, nullptr, init_data_size);
+  }
+  int inputs_array_data[] = {2, 0, 1};
+  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+  int outputs_array_data[] = {1, 2};
+  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
+  TfLiteNode node;
+  node.inputs = inputs_array;
+  node.outputs = outputs_array;
+  node.temporaries = temporaries_array;
+  node.user_data = user_data;
+  node.builtin_data = nullptr;
+  node.custom_initial_data = nullptr;
+  node.custom_initial_data_size = 0;
+  node.delegate = nullptr;
+  if (registration->prepare) {
+    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+  }
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+  if (registration->free) {
+    registration->free(&context, user_data);
+  }
+  for (int i = 0; i < output_dims_count; ++i) {
+    TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]);
+  }
+}
+
+}  // namespace
+}  // namespace testing
+}  // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(FloatPreluActivationsOpTest) {
+  const int output_dims_count = 12;
+  float output_data[output_dims_count];
+  tflite::testing::TestPreluFloat({1, 2, 2, 3},  // input shape
+                                  {
+                                      0.0f, 0.0f, 0.0f,     // Row 1, Column 1
+                                      1.0f, 1.0f, 1.0f,     // Row 1, Column 2
+                                      -1.0f, -1.0f, -1.0f,  // Row 2, Column 1
+                                      -2.0f, -2.0f, -2.0f,  // Row 2, Column 2
+                                  },
+                                  {1, 1, 3},           // alpha shape
+                                  {0.0f, 1.0f, 2.0f},  // alpha values
+                                  {
+                                      0.0f, 0.0f, 0.0f,    // Row 1, Column 1
+                                      1.0f, 1.0f, 1.0f,    // Row 1, Column 2
+                                      0.0f, -1.0f, -2.0f,  // Row 2, Column 1
+                                      0.0f, -2.0f, -4.0f,  // Row 2, Column 2
+                                  },
+                                  {1, 2, 2, 3},  // output shape
+                                  output_data);
+}
+
+TF_LITE_MICRO_TEST(QuantizedPreluActivationsOpTest) {
+  using tflite::testing::F2Q;
+  const float kMin = -1;
+  const float kMax = 127.f / 128.f;
+  const float kAlphaMin = -0.5f;
+  const float kAlphaMax = 0.5f;
+  const int output_dims_count = 12;
+  uint8_t output_data[output_dims_count];
+  tflite::testing::TestPreluQuantized(
+      {1, 2, 2, 3},  // input shape
+      {F2Q(0.0f, kMin, kMax), F2Q(0.0f, kMin, kMax), F2Q(0.0f, kMin, kMax),
+       F2Q(0.5f, kMin, kMax), F2Q(0.5f, kMin, kMax), F2Q(0.5f, kMin, kMax),
+       F2Q(-1.0f, kMin, kMax), F2Q(-1.0f, kMin, kMax), F2Q(-1.0f, kMin, kMax),
+       F2Q(-0.25f, kMin, kMax), F2Q(-0.25f, kMin, kMax),
+       F2Q(-0.25f, kMin, kMax)},
+      kMin, kMax, {1, 1, 3},  // alpha shape
+      {F2Q(0.0f, kMin, kMax), F2Q(0.5f, kMin, kMax), F2Q(-0.5f, kMin, kMax)},
+      kMin, kMax,
+      {F2Q(0.0f, kMin, kMax), F2Q(0.0f, kMin, kMax), F2Q(0.0f, kMin, kMax),
+       F2Q(0.5f, kMin, kMax), F2Q(0.5f, kMin, kMax), F2Q(0.5f, kMin, kMax),
+       F2Q(0.0f, kMin, kMax), F2Q(-0.5f, kMin, kMax), F2Q(0.5f, kMin, kMax),
+       F2Q(0.0f, kMin, kMax), F2Q(-0.125f, kMin, kMax),
+       F2Q(0.125f, kMin, kMax)},
+      {1, 2, 2, 3},  // output shape
+      kMin, kMax, output_data);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/experimental/micro/tools/make/Makefile b/tensorflow/lite/experimental/micro/tools/make/Makefile
index 67a3ea97db6..f3828928612 100644
--- a/tensorflow/lite/experimental/micro/tools/make/Makefile
+++ b/tensorflow/lite/experimental/micro/tools/make/Makefile
@@ -112,6 +112,7 @@ tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h \
 tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h \
 tensorflow/lite/kernels/internal/reference/fully_connected.h \
 tensorflow/lite/kernels/internal/reference/pooling.h \
+tensorflow/lite/kernels/internal/reference/prelu.h \
 tensorflow/lite/kernels/internal/reference/softmax.h \
 tensorflow/lite/kernels/internal/round.h \
 tensorflow/lite/kernels/internal/tensor_ctypes.h \
diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index a4cbd0f3271..199909ccbf8 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -365,6 +365,7 @@ cc_library(
         "reference/integer_ops/softmax.h",
         "reference/integer_ops/tanh.h",
         "reference/pooling.h",
+        "reference/prelu.h",
         "reference/reference_ops.h",
         "reference/softmax.h",
         "reference/strided_slice.h",
@@ -405,6 +406,7 @@ cc_library(
        "reference/fully_connected.h",
        "reference/legacy_reference_ops.h",
        "reference/pooling.h",
+        "reference/prelu.h",
        "reference/reference_ops.h",
        "reference/softmax.h",
        "reference/strided_slice.h",
diff --git a/tensorflow/lite/kernels/internal/reference/prelu.h b/tensorflow/lite/kernels/internal/reference/prelu.h
new file mode 100644
index 00000000000..adbbf66eb1b
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/reference/prelu.h
@@ -0,0 +1,77 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PRELU_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PRELU_H_
+
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+namespace reference_ops {
+
+// Broadcast prelu to output_shape for quantized uint8 data.
+inline void BroadcastPrelu4DSlow(const PreluParams& params,
+                                 const RuntimeShape& input_shape,
+                                 const uint8* input_data,
+                                 const RuntimeShape& alpha_shape,
+                                 const uint8* alpha_data,
+                                 const RuntimeShape& output_shape,
+                                 uint8* output_data) {
+  TFLITE_DCHECK_LE(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(alpha_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(output_shape.DimensionsCount(), 4);
+  const RuntimeShape extended_output_shape =
+      RuntimeShape::ExtendedShape(4, output_shape);
+  NdArrayDesc<4> desc1;
+  NdArrayDesc<4> desc2;
+  NdArrayDescsForElementwiseBroadcast(input_shape, alpha_shape, &desc1, &desc2);
+
+  for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+    for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+      for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+        for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+          int output_index = Offset(extended_output_shape, b, y, x, c);
+          int input_index = SubscriptToIndex(desc1, b, y, x, c);
+          const int32 input_value =
+              params.input_offset + input_data[input_index];
+          if (input_value >= 0) {
+            output_data[output_index] = input_data[input_index];
+          } else {
+            auto alpha_index = SubscriptToIndex(desc2, b, y, x, c);
+            const int32 alpha_value =
+                params.alpha_offset + alpha_data[alpha_index];
+            const int32 unclamped_output =
+                params.output_offset +
+                MultiplyByQuantizedMultiplierSmallerThanOneExp(
+                    input_value * alpha_value, params.output_multiplier,
+                    params.output_shift);
+            const int32 quantized_min = std::numeric_limits<uint8>::min();
+            const int32 quantized_max = std::numeric_limits<uint8>::max();
+            const int32 clamped_output = std::min(
+                quantized_max, std::max(quantized_min, unclamped_output));
+            output_data[output_index] = static_cast<uint8>(clamped_output);
+          }
+        }
+      }
+    }
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PRELU_H_
diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h
index a8b35ae7b92..92b3b47fb04 100644
--- a/tensorflow/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/conv.h" #include "tensorflow/lite/kernels/internal/reference/fully_connected.h" #include "tensorflow/lite/kernels/internal/reference/pooling.h" +#include "tensorflow/lite/kernels/internal/reference/prelu.h" #include "tensorflow/lite/kernels/internal/reference/softmax.h" #include "tensorflow/lite/kernels/internal/reference/strided_slice.h" #include "tensorflow/lite/kernels/internal/round.h" @@ -4403,53 +4404,6 @@ inline void ResizeNearestNeighbor( } } -inline void BroadcastPrelu4DSlow(const PreluParams& params, - const RuntimeShape& input_shape, - const uint8* input_data, - const RuntimeShape& alpha_shape, - const uint8* alpha_data, - const RuntimeShape& output_shape, - uint8* output_data) { - TFLITE_DCHECK_LE(input_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(alpha_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(output_shape.DimensionsCount(), 4); - const RuntimeShape extended_output_shape = - RuntimeShape::ExtendedShape(4, output_shape); - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input_shape, alpha_shape, &desc1, &desc2); - - for (int b = 0; b < extended_output_shape.Dims(0); ++b) { - for (int y = 0; y < extended_output_shape.Dims(1); ++y) { - for (int x = 0; x < extended_output_shape.Dims(2); ++x) { - for (int c = 0; c < extended_output_shape.Dims(3); ++c) { - int output_index = Offset(extended_output_shape, b, y, x, c); - int input_index = SubscriptToIndex(desc1, b, y, x, c); - const int32 input_value = - params.input_offset + input_data[input_index]; - if (input_value >= 0) { - output_data[output_index] = input_data[input_index]; - } else { - auto alpha_index = SubscriptToIndex(desc2, b, y, x, c); - const int32 alpha_value = - params.alpha_offset + alpha_data[alpha_index]; - const int32 unclamped_output = - params.output_offset + - MultiplyByQuantizedMultiplierSmallerThanOneExp( - input_value * alpha_value, params.output_multiplier, - params.output_shift); - const int32 quantized_min = std::numeric_limits::min(); - const int32 quantized_max = std::numeric_limits::max(); - const int32 clamped_output = std::min( - quantized_max, std::max(quantized_min, unclamped_output)); - output_data[output_index] = static_cast(clamped_output); - } - } - } - } - } -} - template void Fill(const RuntimeShape& value_shape, const T* value_data, const RuntimeShape& output_shape, T* output_data) {