Improved AddBias transformation.

PiperOrigin-RevId: 318845057
Change-Id: I41321cdea9d8c605fa77dcff4f962a891536d985
This commit is contained in:
Raman Sarokin 2020-06-29 10:42:03 -07:00 committed by TensorFlower Gardener
parent e3036c5cc9
commit 777b6ad484

View File

@ -27,38 +27,47 @@ namespace tflite {
namespace gpu {
namespace {
template <typename T>
TransformResult FillBias(Node* node) {
auto& attr = absl::any_cast<T&>(node->operation.attributes);
if (attr.bias.data.empty()) {
const int dst_channels = attr.weights.shape.o;
attr.bias = MakeZeroTensor<Linear, DataType::FLOAT32>(Linear(dst_channels));
TransformResult FillBias(
int output_channels,
tflite::gpu::Tensor<Linear, DataType::FLOAT32>* biases) {
if (biases->data.empty()) {
*biases =
MakeZeroTensor<Linear, DataType::FLOAT32>(Linear(output_channels));
return {TransformStatus::APPLIED, "Added bias"};
}
if (biases->shape.v != output_channels) {
float last_value = biases->data.back();
biases->shape.v = output_channels;
biases->data.resize(output_channels, last_value);
return {TransformStatus::APPLIED, "Bias extended"};
}
return {TransformStatus::SKIPPED, ""};
}
template TransformResult FillBias<Convolution2DAttributes>(Node* node);
template TransformResult FillBias<ConvolutionTransposedAttributes>(Node* node);
template TransformResult FillBias<DepthwiseConvolution2DAttributes>(Node* node);
template TransformResult FillBias<FullyConnectedAttributes>(Node* node);
class AddBias : public NodeTransformation {
public:
TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
if (node->operation.type == ToString(OperationType::CONVOLUTION_2D)) {
return FillBias<Convolution2DAttributes>(node);
auto& attr =
absl::any_cast<Convolution2DAttributes&>(node->operation.attributes);
return FillBias(attr.weights.shape.o, &attr.bias);
}
if (node->operation.type ==
ToString(OperationType::CONVOLUTION_TRANSPOSED)) {
return FillBias<ConvolutionTransposedAttributes>(node);
auto& attr = absl::any_cast<ConvolutionTransposedAttributes&>(
node->operation.attributes);
return FillBias(attr.weights.shape.o, &attr.bias);
}
if (node->operation.type ==
ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
return FillBias<DepthwiseConvolution2DAttributes>(node);
auto& attr = absl::any_cast<DepthwiseConvolution2DAttributes&>(
node->operation.attributes);
return FillBias(attr.weights.shape.o * attr.weights.shape.i, &attr.bias);
}
if (node->operation.type == ToString(OperationType::FULLY_CONNECTED)) {
return FillBias<FullyConnectedAttributes>(node);
auto& attr =
absl::any_cast<FullyConnectedAttributes&>(node->operation.attributes);
return FillBias(attr.weights.shape.o, &attr.bias);
}
return {TransformStatus::SKIPPED, ""};
}