Improved AddBias transformation.
PiperOrigin-RevId: 318845057 Change-Id: I41321cdea9d8c605fa77dcff4f962a891536d985
This commit is contained in:
parent
e3036c5cc9
commit
777b6ad484
@ -27,38 +27,47 @@ namespace tflite {
|
||||
namespace gpu {
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
TransformResult FillBias(Node* node) {
|
||||
auto& attr = absl::any_cast<T&>(node->operation.attributes);
|
||||
if (attr.bias.data.empty()) {
|
||||
const int dst_channels = attr.weights.shape.o;
|
||||
attr.bias = MakeZeroTensor<Linear, DataType::FLOAT32>(Linear(dst_channels));
|
||||
TransformResult FillBias(
|
||||
int output_channels,
|
||||
tflite::gpu::Tensor<Linear, DataType::FLOAT32>* biases) {
|
||||
if (biases->data.empty()) {
|
||||
*biases =
|
||||
MakeZeroTensor<Linear, DataType::FLOAT32>(Linear(output_channels));
|
||||
return {TransformStatus::APPLIED, "Added bias"};
|
||||
}
|
||||
if (biases->shape.v != output_channels) {
|
||||
float last_value = biases->data.back();
|
||||
biases->shape.v = output_channels;
|
||||
biases->data.resize(output_channels, last_value);
|
||||
return {TransformStatus::APPLIED, "Bias extended"};
|
||||
}
|
||||
return {TransformStatus::SKIPPED, ""};
|
||||
}
|
||||
|
||||
template TransformResult FillBias<Convolution2DAttributes>(Node* node);
|
||||
template TransformResult FillBias<ConvolutionTransposedAttributes>(Node* node);
|
||||
template TransformResult FillBias<DepthwiseConvolution2DAttributes>(Node* node);
|
||||
template TransformResult FillBias<FullyConnectedAttributes>(Node* node);
|
||||
|
||||
class AddBias : public NodeTransformation {
|
||||
public:
|
||||
TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
|
||||
if (node->operation.type == ToString(OperationType::CONVOLUTION_2D)) {
|
||||
return FillBias<Convolution2DAttributes>(node);
|
||||
auto& attr =
|
||||
absl::any_cast<Convolution2DAttributes&>(node->operation.attributes);
|
||||
return FillBias(attr.weights.shape.o, &attr.bias);
|
||||
}
|
||||
if (node->operation.type ==
|
||||
ToString(OperationType::CONVOLUTION_TRANSPOSED)) {
|
||||
return FillBias<ConvolutionTransposedAttributes>(node);
|
||||
auto& attr = absl::any_cast<ConvolutionTransposedAttributes&>(
|
||||
node->operation.attributes);
|
||||
return FillBias(attr.weights.shape.o, &attr.bias);
|
||||
}
|
||||
if (node->operation.type ==
|
||||
ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
|
||||
return FillBias<DepthwiseConvolution2DAttributes>(node);
|
||||
auto& attr = absl::any_cast<DepthwiseConvolution2DAttributes&>(
|
||||
node->operation.attributes);
|
||||
return FillBias(attr.weights.shape.o * attr.weights.shape.i, &attr.bias);
|
||||
}
|
||||
if (node->operation.type == ToString(OperationType::FULLY_CONNECTED)) {
|
||||
return FillBias<FullyConnectedAttributes>(node);
|
||||
auto& attr =
|
||||
absl::any_cast<FullyConnectedAttributes&>(node->operation.attributes);
|
||||
return FillBias(attr.weights.shape.o, &attr.bias);
|
||||
}
|
||||
return {TransformStatus::SKIPPED, ""};
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user