From 777b6ad484d1647a1b7a64ab862b1cb3ff706a4f Mon Sep 17 00:00:00 2001
From: Raman Sarokin <sorokin@google.com>
Date: Mon, 29 Jun 2020 10:42:03 -0700
Subject: [PATCH] Improved AddBias transformation.

PiperOrigin-RevId: 318845057
Change-Id: I41321cdea9d8c605fa77dcff4f962a891536d985
---
 .../gpu/common/transformations/add_bias.cc    | 39 ++++++++++++-------
 1 file changed, 24 insertions(+), 15 deletions(-)

diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc b/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc
index 7feac824ef7..ec2474138a3 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc
+++ b/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc
@@ -27,38 +27,47 @@ namespace tflite {
 namespace gpu {
 namespace {
 
-template <typename T>
-TransformResult FillBias(Node* node) {
-  auto& attr = absl::any_cast<T&>(node->operation.attributes);
-  if (attr.bias.data.empty()) {
-    const int dst_channels = attr.weights.shape.o;
-    attr.bias = MakeZeroTensor<Linear, DataType::FLOAT32>(Linear(dst_channels));
+TransformResult FillBias(
+    int output_channels,
+    tflite::gpu::Tensor<Linear, DataType::FLOAT32>* biases) {
+  if (biases->data.empty()) {
+    *biases =
+        MakeZeroTensor<Linear, DataType::FLOAT32>(Linear(output_channels));
     return {TransformStatus::APPLIED, "Added bias"};
   }
+  if (biases->shape.v != output_channels) {
+    float last_value = biases->data.back();
+    biases->shape.v = output_channels;
+    biases->data.resize(output_channels, last_value);
+    return {TransformStatus::APPLIED, "Bias extended"};
+  }
   return {TransformStatus::SKIPPED, ""};
 }
 
-template TransformResult FillBias<Convolution2DAttributes>(Node* node);
-template TransformResult FillBias<ConvolutionTransposedAttributes>(Node* node);
-template TransformResult FillBias<DepthwiseConvolution2DAttributes>(Node* node);
-template TransformResult FillBias<FullyConnectedAttributes>(Node* node);
-
 class AddBias : public NodeTransformation {
  public:
   TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
     if (node->operation.type == ToString(OperationType::CONVOLUTION_2D)) {
-      return FillBias<Convolution2DAttributes>(node);
+      auto& attr =
+          absl::any_cast<Convolution2DAttributes&>(node->operation.attributes);
+      return FillBias(attr.weights.shape.o, &attr.bias);
     }
     if (node->operation.type ==
         ToString(OperationType::CONVOLUTION_TRANSPOSED)) {
-      return FillBias<ConvolutionTransposedAttributes>(node);
+      auto& attr = absl::any_cast<ConvolutionTransposedAttributes&>(
+          node->operation.attributes);
+      return FillBias(attr.weights.shape.o, &attr.bias);
     }
     if (node->operation.type ==
         ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
-      return FillBias<DepthwiseConvolution2DAttributes>(node);
+      auto& attr = absl::any_cast<DepthwiseConvolution2DAttributes&>(
+          node->operation.attributes);
+      return FillBias(attr.weights.shape.o * attr.weights.shape.i, &attr.bias);
     }
     if (node->operation.type == ToString(OperationType::FULLY_CONNECTED)) {
-      return FillBias<FullyConnectedAttributes>(node);
+      auto& attr =
+          absl::any_cast<FullyConnectedAttributes&>(node->operation.attributes);
+      return FillBias(attr.weights.shape.o, &attr.bias);
     }
     return {TransformStatus::SKIPPED, ""};
   }
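
Note (not part of the patch): a minimal standalone sketch of the bias rule the new FillBias applies: zero-fill when the bias is empty, otherwise resize it to the output channel count, repeating the last stored value. Here std::vector<float> stands in for the real tflite::gpu::Tensor<Linear, DataType::FLOAT32>, and FillBiasSketch is a hypothetical name used only for illustration.

// Illustrative sketch only; std::vector<float> stands in for the real
// tflite::gpu::Tensor<Linear, DataType::FLOAT32> bias tensor.
#include <cstdio>
#include <vector>

// Mirrors the patched FillBias logic: add a zero bias when none exists, or
// extend a mismatched bias to output_channels by repeating its last value.
void FillBiasSketch(int output_channels, std::vector<float>* biases) {
  if (biases->empty()) {
    biases->assign(output_channels, 0.0f);  // "Added bias"
    return;
  }
  if (static_cast<int>(biases->size()) != output_channels) {
    biases->resize(output_channels, biases->back());  // "Bias extended"
  }
}

int main() {
  std::vector<float> bias = {0.5f};  // e.g. a single broadcast bias value
  FillBiasSketch(4, &bias);          // bias becomes {0.5, 0.5, 0.5, 0.5}
  for (float b : bias) std::printf("%.2f\n", b);
  return 0;
}

For DEPTHWISE_CONVOLUTION the patch passes attr.weights.shape.o * attr.weights.shape.i as the channel count, since a depthwise convolution produces input_channels * depth_multiplier output channels rather than weights.shape.o alone.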