Removed fuse add to conv transformation. This transformation incorrect for border elements, when used Zero clamping.
PiperOrigin-RevId: 316807675 Change-Id: Iffbdcd3f5a46ecbdc8c4ad1c4f8f4393410f7be3
This commit is contained in:
parent
7c5ddb830f
commit
d68cccb57e
@ -92,66 +92,12 @@ class MergeConvolutionWithAdd : public SequenceTransformation {
|
||||
}
|
||||
};
|
||||
|
||||
class MergeAddWithConvolution : public SequenceTransformation {
|
||||
public:
|
||||
int ExpectedSequenceLength() const final { return 2; }
|
||||
|
||||
TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
|
||||
GraphFloat32* graph) final {
|
||||
auto& conv_node = *sequence[1];
|
||||
auto& add_node = *sequence[0];
|
||||
if (add_node.operation.type != ToString(OperationType::ADD)) {
|
||||
return {TransformStatus::SKIPPED, ""};
|
||||
}
|
||||
AddAttributes add_attr =
|
||||
absl::any_cast<AddAttributes>(add_node.operation.attributes);
|
||||
if (!absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(
|
||||
add_attr.param) &&
|
||||
!absl::holds_alternative<float>(add_attr.param)) {
|
||||
return {TransformStatus::DECLINED,
|
||||
"This fuse applicable only for broadcast or scalar addition."};
|
||||
}
|
||||
|
||||
if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) {
|
||||
Convolution2DAttributes* conv_attr =
|
||||
absl::any_cast<Convolution2DAttributes>(
|
||||
&conv_node.operation.attributes);
|
||||
FuseAddWithConvolution2D(add_attr, conv_attr);
|
||||
} else if (conv_node.operation.type ==
|
||||
ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
|
||||
DepthwiseConvolution2DAttributes* conv_attr =
|
||||
absl::any_cast<DepthwiseConvolution2DAttributes>(
|
||||
&conv_node.operation.attributes);
|
||||
FuseAddWithDepthwiseConvolution2D(add_attr, conv_attr);
|
||||
} else if (conv_node.operation.type ==
|
||||
ToString(OperationType::FULLY_CONNECTED)) {
|
||||
FullyConnectedAttributes* conv_attr =
|
||||
absl::any_cast<FullyConnectedAttributes>(
|
||||
&conv_node.operation.attributes);
|
||||
FuseAddWithFullyConnected(add_attr, conv_attr);
|
||||
} else {
|
||||
return {TransformStatus::SKIPPED, ""};
|
||||
}
|
||||
|
||||
absl::Status status = RemovePrecedingNode(graph, &add_node, &conv_node);
|
||||
if (!status.ok()) {
|
||||
return {TransformStatus::INVALID,
|
||||
"Unable to remove add node after convolution: " +
|
||||
std::string(status.message())};
|
||||
}
|
||||
return {TransformStatus::APPLIED, ""};
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<SequenceTransformation> NewMergeConvolutionWithAdd() {
|
||||
return absl::make_unique<MergeConvolutionWithAdd>();
|
||||
}
|
||||
|
||||
std::unique_ptr<SequenceTransformation> NewMergeAddWithConvolution() {
|
||||
return absl::make_unique<MergeAddWithConvolution>();
|
||||
}
|
||||
|
||||
void FuseConvolution2DWithAdd(const AddAttributes& add_attr,
|
||||
Convolution2DAttributes* attr) {
|
||||
FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
|
||||
@ -173,65 +119,5 @@ void FuseFullyConnectedWithAdd(const AddAttributes& add_attr,
|
||||
FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
|
||||
}
|
||||
|
||||
void FuseAddWithConvolution2D(const AddAttributes& add_attr,
|
||||
Convolution2DAttributes* attr) {
|
||||
auto add = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param);
|
||||
auto add_scalar = absl::get_if<float>(&add_attr.param);
|
||||
if (attr->bias.data.empty()) {
|
||||
attr->bias = MakeZeroTensor<Linear, DataType::FLOAT32>(
|
||||
Linear(attr->weights.shape.o));
|
||||
}
|
||||
for (int d = 0; d < attr->weights.shape.o; ++d) {
|
||||
for (int s = 0; s < attr->weights.shape.i; ++s) {
|
||||
const float add_value = add ? add->data[s] : *add_scalar;
|
||||
for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
|
||||
for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
|
||||
const int index = attr->weights.shape.LinearIndex({{d, k_y, k_x, s}});
|
||||
attr->bias.data[d] += attr->weights.data[index] * add_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void FuseAddWithDepthwiseConvolution2D(const AddAttributes& add_attr,
|
||||
DepthwiseConvolution2DAttributes* attr) {
|
||||
auto add = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param);
|
||||
auto add_scalar = absl::get_if<float>(&add_attr.param);
|
||||
if (attr->bias.data.empty()) {
|
||||
attr->bias = MakeZeroTensor<Linear, DataType::FLOAT32>(
|
||||
Linear(attr->weights.shape.o * attr->weights.shape.i));
|
||||
}
|
||||
for (int s = 0; s < attr->weights.shape.i; ++s) {
|
||||
const float add_value = add ? add->data[s] : *add_scalar;
|
||||
for (int g = 0; g < attr->weights.shape.o; ++g) {
|
||||
const int d = s * attr->weights.shape.o + g;
|
||||
for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
|
||||
for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
|
||||
const int index = attr->weights.shape.LinearIndex({{g, k_y, k_x, s}});
|
||||
attr->bias.data[d] += attr->weights.data[index] * add_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void FuseAddWithFullyConnected(const AddAttributes& add_attr,
|
||||
FullyConnectedAttributes* attr) {
|
||||
auto add = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param);
|
||||
auto add_scalar = absl::get_if<float>(&add_attr.param);
|
||||
if (attr->bias.data.empty()) {
|
||||
attr->bias = MakeZeroTensor<Linear, DataType::FLOAT32>(
|
||||
Linear(attr->weights.shape.o));
|
||||
}
|
||||
for (int d = 0; d < attr->weights.shape.o; ++d) {
|
||||
for (int s = 0; s < attr->weights.shape.i; ++s) {
|
||||
const float add_value = add ? add->data[s] : *add_scalar;
|
||||
const int index = attr->weights.shape.LinearIndex({{d, 0, 0, s}});
|
||||
attr->bias.data[d] += attr->weights.data[index] * add_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -30,11 +30,6 @@ namespace gpu {
|
||||
// convolution.
|
||||
std::unique_ptr<SequenceTransformation> NewMergeConvolutionWithAdd();
|
||||
|
||||
// Fuse Add Scalar or Add Broadcast before Convolution(Convolution2D,
|
||||
// DepthWise, FullyConnected) into biases of
|
||||
// convolution.
|
||||
std::unique_ptr<SequenceTransformation> NewMergeAddWithConvolution();
|
||||
|
||||
// Modify Convolution2DAttributes so that after making convolution with
|
||||
// modified attributes we will have the same result as convolution
|
||||
// with old attributes and following add operation.
|
||||
@ -59,24 +54,6 @@ void FuseConvolutionTransposedWithAdd(const AddAttributes& add_attr,
|
||||
void FuseFullyConnectedWithAdd(const AddAttributes& add_attr,
|
||||
FullyConnectedAttributes* attr);
|
||||
|
||||
// Modify Convolution2DAttributes so that after making convolution with
|
||||
// modified attributes we will have the same result as add operation and
|
||||
// convolution with old attributes
|
||||
void FuseAddWithConvolution2D(const AddAttributes& add_attr,
|
||||
Convolution2DAttributes* attr);
|
||||
|
||||
// Modify DepthwiseConvolution2DAttributes so that after making depth wise
|
||||
// convolution with modified attributes we will have the same result as add
|
||||
// operation and depth wise convolution with old attributes
|
||||
void FuseAddWithDepthwiseConvolution2D(const AddAttributes& add_attr,
|
||||
DepthwiseConvolution2DAttributes* attr);
|
||||
|
||||
// Modify FullyConnectedAttributes so that after making fully connected
|
||||
// with modified attributes we will have the same result as add operation and
|
||||
// fully connected with old attributes
|
||||
void FuseAddWithFullyConnected(const AddAttributes& add_attr,
|
||||
FullyConnectedAttributes* attr);
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -78,57 +78,6 @@ TEST(MergeConvolutionWithAddTest, Smoke) {
|
||||
graph.nodes()[0]->operation.type);
|
||||
}
|
||||
|
||||
TEST(MergeAddWithConvolutionTest, Smoke) {
|
||||
GraphFloat32 graph;
|
||||
auto input = graph.NewValue();
|
||||
input->tensor.shape = BHWC(1, 4, 4, 8);
|
||||
|
||||
Convolution2DAttributes conv_attr;
|
||||
conv_attr.padding.prepended = HW(0, 0);
|
||||
conv_attr.padding.appended = HW(0, 0);
|
||||
conv_attr.strides = HW(1, 1);
|
||||
conv_attr.dilations = HW(1, 1);
|
||||
conv_attr.weights.shape = OHWI(16, 3, 2, 8);
|
||||
conv_attr.weights.data.resize(conv_attr.weights.shape.DimensionsProduct());
|
||||
conv_attr.bias.shape = Linear(16);
|
||||
conv_attr.bias.data.resize(16);
|
||||
|
||||
Tensor<Linear, DataType::FLOAT32> add_tensor;
|
||||
add_tensor.shape = Linear(8);
|
||||
add_tensor.data.resize(8);
|
||||
AddAttributes add_attr;
|
||||
add_attr.param = add_tensor;
|
||||
|
||||
auto conv_node = graph.NewNode();
|
||||
conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D);
|
||||
conv_node->operation.attributes = conv_attr;
|
||||
auto add_node = graph.NewNode();
|
||||
add_node->operation.type = ToString(OperationType::ADD);
|
||||
add_node->operation.attributes = add_attr;
|
||||
|
||||
ASSERT_TRUE(graph.AddConsumer(add_node->id, input->id).ok());
|
||||
|
||||
Value* output;
|
||||
ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok());
|
||||
output->tensor.shape = BHWC(1, 4, 4, 16);
|
||||
|
||||
Value* link1;
|
||||
ASSERT_TRUE(ConnectTwoNodes(&graph, add_node, conv_node, &link1).ok());
|
||||
link1->tensor.shape = BHWC(1, 4, 4, 16);
|
||||
|
||||
ASSERT_EQ(2, graph.nodes().size());
|
||||
ASSERT_EQ(3, graph.values().size());
|
||||
|
||||
auto transformation = NewMergeAddWithConvolution();
|
||||
ModelTransformer transformer(&graph, nullptr);
|
||||
transformer.Apply("merge_add_with_convolution", transformation.get());
|
||||
|
||||
EXPECT_EQ(1, graph.nodes().size());
|
||||
EXPECT_EQ(2, graph.values().size());
|
||||
EXPECT_EQ(ToString(OperationType::CONVOLUTION_2D),
|
||||
graph.nodes()[0]->operation.type);
|
||||
}
|
||||
|
||||
TEST(FuseAddAfterConvolution2DTest, Smoke) {
|
||||
Convolution2DAttributes attr;
|
||||
attr.weights.shape = OHWI(2, 1, 2, 2);
|
||||
@ -213,69 +162,6 @@ TEST(FuseAddAfterFullyConnectedTest, Smoke) {
|
||||
EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {1.4f, 1.9f}));
|
||||
}
|
||||
|
||||
TEST(FuseAddBeforeConvolution2DTest, Smoke) {
|
||||
Convolution2DAttributes attr;
|
||||
attr.weights.shape = OHWI(2, 1, 2, 2);
|
||||
attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f};
|
||||
attr.bias.shape = Linear(2);
|
||||
attr.bias.data = {1.1f, 1.2f};
|
||||
|
||||
Tensor<Linear, DataType::FLOAT32> add_tensor;
|
||||
add_tensor.shape = Linear(2);
|
||||
add_tensor.data = {2.0f, 0.5f};
|
||||
AddAttributes add_attr;
|
||||
add_attr.param = add_tensor;
|
||||
|
||||
FuseAddWithConvolution2D(add_attr, &attr);
|
||||
|
||||
EXPECT_THAT(attr.weights.data,
|
||||
Pointwise(FloatNear(1e-6),
|
||||
{0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}));
|
||||
EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {2.2f, 4.3f}));
|
||||
}
|
||||
|
||||
TEST(FuseAddBeforeDepthwiseConvolution2DTest, Smoke) {
|
||||
DepthwiseConvolution2DAttributes attr;
|
||||
attr.weights.shape = OHWI(2, 1, 2, 2);
|
||||
attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f};
|
||||
attr.bias.shape = Linear(4);
|
||||
attr.bias.data = {1.1f, 1.2f, 1.3f, 1.4f};
|
||||
|
||||
Tensor<Linear, DataType::FLOAT32> add_tensor;
|
||||
add_tensor.shape = Linear(4);
|
||||
add_tensor.data = {0.3f, 0.7f, 0.5f, 0.1f};
|
||||
AddAttributes add_attr;
|
||||
add_attr.param = add_tensor;
|
||||
|
||||
FuseAddWithDepthwiseConvolution2D(add_attr, &attr);
|
||||
|
||||
EXPECT_THAT(attr.weights.data,
|
||||
Pointwise(FloatNear(1e-6),
|
||||
{0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}));
|
||||
EXPECT_THAT(attr.bias.data,
|
||||
Pointwise(FloatNear(1e-6), {1.22f, 1.56f, 1.72f, 2.38f}));
|
||||
}
|
||||
|
||||
TEST(FuseAddBeforeFullyConnectedTest, Smoke) {
|
||||
FullyConnectedAttributes attr;
|
||||
attr.weights.shape = OHWI(2, 1, 1, 2);
|
||||
attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f};
|
||||
attr.bias.shape = Linear(2);
|
||||
attr.bias.data = {1.1f, 1.2f};
|
||||
|
||||
Tensor<Linear, DataType::FLOAT32> add_tensor;
|
||||
add_tensor.shape = Linear(2);
|
||||
add_tensor.data = {0.5f, 2.0f};
|
||||
AddAttributes add_attr;
|
||||
add_attr.param = add_tensor;
|
||||
|
||||
FuseAddWithFullyConnected(add_attr, &attr);
|
||||
|
||||
EXPECT_THAT(attr.weights.data,
|
||||
Pointwise(FloatNear(1e-6), {0.1f, 0.2f, 0.3f, 0.4f}));
|
||||
EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {1.55f, 2.15f}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -54,9 +54,7 @@ bool ApplyGeneralTransformations(ModelTransformer* transformer) {
|
||||
transformer->Apply("merge_convolution_with_add",
|
||||
NewMergeConvolutionWithAdd().get()) &&
|
||||
transformer->Apply("merge_mul_with_convolution",
|
||||
NewMergeMulWithConvolution().get()) &&
|
||||
transformer->Apply("merge_add_with_convolution",
|
||||
NewMergeAddWithConvolution().get());
|
||||
NewMergeMulWithConvolution().get());
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
|
Loading…
Reference in New Issue
Block a user