Removed the transformation that fuses a preceding Add into a convolution. This transformation is incorrect for border elements when zero clamping is used.
PiperOrigin-RevId: 316807675 Change-Id: Iffbdcd3f5a46ecbdc8c4ad1c4f8f4393410f7be3
This commit is contained in:
parent
7c5ddb830f
commit
d68cccb57e
@ -92,66 +92,12 @@ class MergeConvolutionWithAdd : public SequenceTransformation {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class MergeAddWithConvolution : public SequenceTransformation {
|
|
||||||
public:
|
|
||||||
int ExpectedSequenceLength() const final { return 2; }
|
|
||||||
|
|
||||||
TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
|
|
||||||
GraphFloat32* graph) final {
|
|
||||||
auto& conv_node = *sequence[1];
|
|
||||||
auto& add_node = *sequence[0];
|
|
||||||
if (add_node.operation.type != ToString(OperationType::ADD)) {
|
|
||||||
return {TransformStatus::SKIPPED, ""};
|
|
||||||
}
|
|
||||||
AddAttributes add_attr =
|
|
||||||
absl::any_cast<AddAttributes>(add_node.operation.attributes);
|
|
||||||
if (!absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(
|
|
||||||
add_attr.param) &&
|
|
||||||
!absl::holds_alternative<float>(add_attr.param)) {
|
|
||||||
return {TransformStatus::DECLINED,
|
|
||||||
"This fuse applicable only for broadcast or scalar addition."};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) {
|
|
||||||
Convolution2DAttributes* conv_attr =
|
|
||||||
absl::any_cast<Convolution2DAttributes>(
|
|
||||||
&conv_node.operation.attributes);
|
|
||||||
FuseAddWithConvolution2D(add_attr, conv_attr);
|
|
||||||
} else if (conv_node.operation.type ==
|
|
||||||
ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
|
|
||||||
DepthwiseConvolution2DAttributes* conv_attr =
|
|
||||||
absl::any_cast<DepthwiseConvolution2DAttributes>(
|
|
||||||
&conv_node.operation.attributes);
|
|
||||||
FuseAddWithDepthwiseConvolution2D(add_attr, conv_attr);
|
|
||||||
} else if (conv_node.operation.type ==
|
|
||||||
ToString(OperationType::FULLY_CONNECTED)) {
|
|
||||||
FullyConnectedAttributes* conv_attr =
|
|
||||||
absl::any_cast<FullyConnectedAttributes>(
|
|
||||||
&conv_node.operation.attributes);
|
|
||||||
FuseAddWithFullyConnected(add_attr, conv_attr);
|
|
||||||
} else {
|
|
||||||
return {TransformStatus::SKIPPED, ""};
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Status status = RemovePrecedingNode(graph, &add_node, &conv_node);
|
|
||||||
if (!status.ok()) {
|
|
||||||
return {TransformStatus::INVALID,
|
|
||||||
"Unable to remove add node after convolution: " +
|
|
||||||
std::string(status.message())};
|
|
||||||
}
|
|
||||||
return {TransformStatus::APPLIED, ""};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<SequenceTransformation> NewMergeConvolutionWithAdd() {
|
std::unique_ptr<SequenceTransformation> NewMergeConvolutionWithAdd() {
|
||||||
return absl::make_unique<MergeConvolutionWithAdd>();
|
return absl::make_unique<MergeConvolutionWithAdd>();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<SequenceTransformation> NewMergeAddWithConvolution() {
|
|
||||||
return absl::make_unique<MergeAddWithConvolution>();
|
|
||||||
}
|
|
||||||
|
|
||||||
void FuseConvolution2DWithAdd(const AddAttributes& add_attr,
|
void FuseConvolution2DWithAdd(const AddAttributes& add_attr,
|
||||||
Convolution2DAttributes* attr) {
|
Convolution2DAttributes* attr) {
|
||||||
FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
|
FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
|
||||||
@ -173,65 +119,5 @@ void FuseFullyConnectedWithAdd(const AddAttributes& add_attr,
|
|||||||
FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
|
FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuseAddWithConvolution2D(const AddAttributes& add_attr,
|
|
||||||
Convolution2DAttributes* attr) {
|
|
||||||
auto add = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param);
|
|
||||||
auto add_scalar = absl::get_if<float>(&add_attr.param);
|
|
||||||
if (attr->bias.data.empty()) {
|
|
||||||
attr->bias = MakeZeroTensor<Linear, DataType::FLOAT32>(
|
|
||||||
Linear(attr->weights.shape.o));
|
|
||||||
}
|
|
||||||
for (int d = 0; d < attr->weights.shape.o; ++d) {
|
|
||||||
for (int s = 0; s < attr->weights.shape.i; ++s) {
|
|
||||||
const float add_value = add ? add->data[s] : *add_scalar;
|
|
||||||
for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
|
|
||||||
for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
|
|
||||||
const int index = attr->weights.shape.LinearIndex({{d, k_y, k_x, s}});
|
|
||||||
attr->bias.data[d] += attr->weights.data[index] * add_value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void FuseAddWithDepthwiseConvolution2D(const AddAttributes& add_attr,
|
|
||||||
DepthwiseConvolution2DAttributes* attr) {
|
|
||||||
auto add = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param);
|
|
||||||
auto add_scalar = absl::get_if<float>(&add_attr.param);
|
|
||||||
if (attr->bias.data.empty()) {
|
|
||||||
attr->bias = MakeZeroTensor<Linear, DataType::FLOAT32>(
|
|
||||||
Linear(attr->weights.shape.o * attr->weights.shape.i));
|
|
||||||
}
|
|
||||||
for (int s = 0; s < attr->weights.shape.i; ++s) {
|
|
||||||
const float add_value = add ? add->data[s] : *add_scalar;
|
|
||||||
for (int g = 0; g < attr->weights.shape.o; ++g) {
|
|
||||||
const int d = s * attr->weights.shape.o + g;
|
|
||||||
for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
|
|
||||||
for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
|
|
||||||
const int index = attr->weights.shape.LinearIndex({{g, k_y, k_x, s}});
|
|
||||||
attr->bias.data[d] += attr->weights.data[index] * add_value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void FuseAddWithFullyConnected(const AddAttributes& add_attr,
|
|
||||||
FullyConnectedAttributes* attr) {
|
|
||||||
auto add = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param);
|
|
||||||
auto add_scalar = absl::get_if<float>(&add_attr.param);
|
|
||||||
if (attr->bias.data.empty()) {
|
|
||||||
attr->bias = MakeZeroTensor<Linear, DataType::FLOAT32>(
|
|
||||||
Linear(attr->weights.shape.o));
|
|
||||||
}
|
|
||||||
for (int d = 0; d < attr->weights.shape.o; ++d) {
|
|
||||||
for (int s = 0; s < attr->weights.shape.i; ++s) {
|
|
||||||
const float add_value = add ? add->data[s] : *add_scalar;
|
|
||||||
const int index = attr->weights.shape.LinearIndex({{d, 0, 0, s}});
|
|
||||||
attr->bias.data[d] += attr->weights.data[index] * add_value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -30,11 +30,6 @@ namespace gpu {
|
|||||||
// convolution.
|
// convolution.
|
||||||
std::unique_ptr<SequenceTransformation> NewMergeConvolutionWithAdd();
|
std::unique_ptr<SequenceTransformation> NewMergeConvolutionWithAdd();
|
||||||
|
|
||||||
// Fuse Add Scalar or Add Broadcast before Convolution(Convolution2D,
|
|
||||||
// DepthWise, FullyConnected) into biases of
|
|
||||||
// convolution.
|
|
||||||
std::unique_ptr<SequenceTransformation> NewMergeAddWithConvolution();
|
|
||||||
|
|
||||||
// Modify Convolution2DAttributes so that after making convolution with
|
// Modify Convolution2DAttributes so that after making convolution with
|
||||||
// modified attributes we will have the same result as convolution
|
// modified attributes we will have the same result as convolution
|
||||||
// with old attributes and following add operation.
|
// with old attributes and following add operation.
|
||||||
@ -59,24 +54,6 @@ void FuseConvolutionTransposedWithAdd(const AddAttributes& add_attr,
|
|||||||
void FuseFullyConnectedWithAdd(const AddAttributes& add_attr,
|
void FuseFullyConnectedWithAdd(const AddAttributes& add_attr,
|
||||||
FullyConnectedAttributes* attr);
|
FullyConnectedAttributes* attr);
|
||||||
|
|
||||||
// Modify Convolution2DAttributes so that after making convolution with
|
|
||||||
// modified attributes we will have the same result as add operation and
|
|
||||||
// convolution with old attributes. Exact only when the convolution does not read zero-padded border elements.
|
|
||||||
void FuseAddWithConvolution2D(const AddAttributes& add_attr,
|
|
||||||
Convolution2DAttributes* attr);
|
|
||||||
|
|
||||||
// Modify DepthwiseConvolution2DAttributes so that after making depth wise
|
|
||||||
// convolution with modified attributes we will have the same result as add
|
|
||||||
// operation and depth wise convolution with old attributes. Exact only when the convolution does not read zero-padded border elements.
|
|
||||||
void FuseAddWithDepthwiseConvolution2D(const AddAttributes& add_attr,
|
|
||||||
DepthwiseConvolution2DAttributes* attr);
|
|
||||||
|
|
||||||
// Modify FullyConnectedAttributes so that after making fully connected
|
|
||||||
// with modified attributes we will have the same result as add operation and
|
|
||||||
// fully connected with old attributes.
|
|
||||||
void FuseAddWithFullyConnected(const AddAttributes& add_attr,
|
|
||||||
FullyConnectedAttributes* attr);
|
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
|
@ -78,57 +78,6 @@ TEST(MergeConvolutionWithAddTest, Smoke) {
|
|||||||
graph.nodes()[0]->operation.type);
|
graph.nodes()[0]->operation.type);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Builds the graph input -> ADD -> CONVOLUTION_2D and checks that the
// transformation removes the ADD node, leaving a single convolution node and
// two values (graph input and graph output).
TEST(MergeAddWithConvolutionTest, Smoke) {
  GraphFloat32 graph;
  auto input = graph.NewValue();
  input->tensor.shape = BHWC(1, 4, 4, 8);

  // Unpadded 3x2 convolution from 8 to 16 channels.
  Convolution2DAttributes conv_attr;
  conv_attr.padding.prepended = HW(0, 0);
  conv_attr.padding.appended = HW(0, 0);
  conv_attr.strides = HW(1, 1);
  conv_attr.dilations = HW(1, 1);
  conv_attr.weights.shape = OHWI(16, 3, 2, 8);
  conv_attr.weights.data.resize(conv_attr.weights.shape.DimensionsProduct());
  conv_attr.bias.shape = Linear(16);
  conv_attr.bias.data.resize(16);

  // Per-channel addition over the 8 input channels.
  Tensor<Linear, DataType::FLOAT32> add_tensor;
  add_tensor.shape = Linear(8);
  add_tensor.data.resize(8);
  AddAttributes add_attr;
  add_attr.param = add_tensor;

  auto conv_node = graph.NewNode();
  conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D);
  conv_node->operation.attributes = conv_attr;
  auto add_node = graph.NewNode();
  add_node->operation.type = ToString(OperationType::ADD);
  add_node->operation.attributes = add_attr;

  ASSERT_TRUE(graph.AddConsumer(add_node->id, input->id).ok());

  Value* output;
  ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok());
  // A 3x2 kernel with stride 1 and no padding shrinks 4x4 to 2x3.
  output->tensor.shape = BHWC(1, 2, 3, 16);

  Value* link1;
  ASSERT_TRUE(ConnectTwoNodes(&graph, add_node, conv_node, &link1).ok());
  // The ADD output keeps the input's shape, so it has 8 channels, not 16.
  link1->tensor.shape = BHWC(1, 4, 4, 8);

  ASSERT_EQ(2, graph.nodes().size());
  ASSERT_EQ(3, graph.values().size());

  auto transformation = NewMergeAddWithConvolution();
  ModelTransformer transformer(&graph, nullptr);
  transformer.Apply("merge_add_with_convolution", transformation.get());

  EXPECT_EQ(1, graph.nodes().size());
  EXPECT_EQ(2, graph.values().size());
  EXPECT_EQ(ToString(OperationType::CONVOLUTION_2D),
            graph.nodes()[0]->operation.type);
}
|
|
||||||
|
|
||||||
TEST(FuseAddAfterConvolution2DTest, Smoke) {
|
TEST(FuseAddAfterConvolution2DTest, Smoke) {
|
||||||
Convolution2DAttributes attr;
|
Convolution2DAttributes attr;
|
||||||
attr.weights.shape = OHWI(2, 1, 2, 2);
|
attr.weights.shape = OHWI(2, 1, 2, 2);
|
||||||
@ -213,69 +162,6 @@ TEST(FuseAddAfterFullyConnectedTest, Smoke) {
|
|||||||
EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {1.4f, 1.9f}));
|
EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {1.4f, 1.9f}));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fusing a per-channel add placed before a 1x2 convolution must leave the
// weights untouched and add sum(weight * added value) to each output bias.
TEST(FuseAddBeforeConvolution2DTest, Smoke) {
  // Addition applied to the two input channels.
  Tensor<Linear, DataType::FLOAT32> per_channel_add;
  per_channel_add.shape = Linear(2);
  per_channel_add.data = {2.0f, 0.5f};
  AddAttributes add_attr;
  add_attr.param = per_channel_add;

  // Convolution with 2 output channels, a 1x2 kernel and 2 input channels.
  Convolution2DAttributes conv;
  conv.weights.shape = OHWI(2, 1, 2, 2);
  conv.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f};
  conv.bias.shape = Linear(2);
  conv.bias.data = {1.1f, 1.2f};

  FuseAddWithConvolution2D(add_attr, &conv);

  EXPECT_THAT(conv.weights.data,
              Pointwise(FloatNear(1e-6),
                        {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}));
  EXPECT_THAT(conv.bias.data, Pointwise(FloatNear(1e-6), {2.2f, 4.3f}));
}
|
|
||||||
|
|
||||||
// Fusing a per-channel add placed before a depthwise convolution must leave
// the weights untouched and update the weights.shape.o * weights.shape.i bias
// entries.
TEST(FuseAddBeforeDepthwiseConvolution2DTest, Smoke) {
  DepthwiseConvolution2DAttributes attr;
  attr.weights.shape = OHWI(2, 1, 2, 2);
  attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f};
  attr.bias.shape = Linear(4);
  attr.bias.data = {1.1f, 1.2f, 1.3f, 1.4f};

  // The add runs on the convolution input, so it has one value per input
  // channel (weights.shape.i == 2), not one per bias entry.
  Tensor<Linear, DataType::FLOAT32> add_tensor;
  add_tensor.shape = Linear(2);
  add_tensor.data = {0.3f, 0.7f};
  AddAttributes add_attr;
  add_attr.param = add_tensor;

  FuseAddWithDepthwiseConvolution2D(add_attr, &attr);

  EXPECT_THAT(attr.weights.data,
              Pointwise(FloatNear(1e-6),
                        {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}));
  EXPECT_THAT(attr.bias.data,
              Pointwise(FloatNear(1e-6), {1.22f, 1.56f, 1.72f, 2.38f}));
}
|
|
||||||
|
|
||||||
// Fusing a per-channel add placed before a fully connected layer must leave
// the weights untouched and add sum(weight * added value) to each output bias.
TEST(FuseAddBeforeFullyConnectedTest, Smoke) {
  // Addition applied to the two input channels.
  Tensor<Linear, DataType::FLOAT32> per_channel_add;
  per_channel_add.shape = Linear(2);
  per_channel_add.data = {0.5f, 2.0f};
  AddAttributes add_attr;
  add_attr.param = per_channel_add;

  // Fully connected layer with 2 outputs and 2 inputs.
  FullyConnectedAttributes fc;
  fc.weights.shape = OHWI(2, 1, 1, 2);
  fc.weights.data = {0.1f, 0.2f, 0.3f, 0.4f};
  fc.bias.shape = Linear(2);
  fc.bias.data = {1.1f, 1.2f};

  FuseAddWithFullyConnected(add_attr, &fc);

  EXPECT_THAT(fc.weights.data,
              Pointwise(FloatNear(1e-6), {0.1f, 0.2f, 0.3f, 0.4f}));
  EXPECT_THAT(fc.bias.data, Pointwise(FloatNear(1e-6), {1.55f, 2.15f}));
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -54,9 +54,7 @@ bool ApplyGeneralTransformations(ModelTransformer* transformer) {
|
|||||||
transformer->Apply("merge_convolution_with_add",
|
transformer->Apply("merge_convolution_with_add",
|
||||||
NewMergeConvolutionWithAdd().get()) &&
|
NewMergeConvolutionWithAdd().get()) &&
|
||||||
transformer->Apply("merge_mul_with_convolution",
|
transformer->Apply("merge_mul_with_convolution",
|
||||||
NewMergeMulWithConvolution().get()) &&
|
NewMergeMulWithConvolution().get());
|
||||||
transformer->Apply("merge_add_with_convolution",
|
|
||||||
NewMergeAddWithConvolution().get());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
|
Loading…
Reference in New Issue
Block a user