Fixed add bias transformation.

Added check for convolution with dynamic weights.

PiperOrigin-RevId: 320996352
Change-Id: Ie88eb026151c8ce49e9987867bc2807e13176cea
This commit is contained in:
Raman Sarokin 2020-07-13 11:26:15 -07:00 committed by TensorFlower Gardener
parent 3576b73743
commit 264eb6ed1d

View File

@ -48,6 +48,11 @@ class AddBias : public NodeTransformation {
public:
TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
if (node->operation.type == ToString(OperationType::CONVOLUTION_2D)) {
if (graph->FindInputs(node->id).size() != 1) {
return {TransformStatus::DECLINED,
"This transformation is only applicable to conv with one "
"runtime input."};
}
auto& attr =
absl::any_cast<Convolution2DAttributes&>(node->operation.attributes);
return FillBias(attr.weights.shape.o, &attr.bias);