Removed the fuse-add-to-conv transformation. The transformation is incorrect for border elements when zero clamping (zero padding) is used.

PiperOrigin-RevId: 316807675
Change-Id: Iffbdcd3f5a46ecbdc8c4ad1c4f8f4393410f7be3
This commit is contained in:
Raman Sarokin 2020-06-16 19:54:13 -07:00 committed by TensorFlower Gardener
parent 7c5ddb830f
commit d68cccb57e
4 changed files with 1 additions and 254 deletions

View File

@ -92,66 +92,12 @@ class MergeConvolutionWithAdd : public SequenceTransformation {
}
};
// Fuses a preceding broadcast/scalar Add into the bias of the following
// Convolution2D / DepthwiseConvolution2D / FullyConnected node.
//
// For the spatial convolutions this uses the identity
//   conv(x + a) == conv(x) + B,  B[d] = sum_{ky,kx,s} W[d,ky,kx,s] * a[s],
// which only holds when every kernel tap reads a real input element. With
// zero padding the border taps read 0 rather than (0 + a), so the fuse is
// DECLINED for padded convolutions (this is the "incorrect for border
// elements when zero clamping is used" defect from the commit message).
class MergeAddWithConvolution : public SequenceTransformation {
 public:
  int ExpectedSequenceLength() const final { return 2; }

  TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
                                       GraphFloat32* graph) final {
    auto& conv_node = *sequence[1];
    auto& add_node = *sequence[0];
    if (add_node.operation.type != ToString(OperationType::ADD)) {
      return {TransformStatus::SKIPPED, ""};
    }
    AddAttributes add_attr =
        absl::any_cast<AddAttributes>(add_node.operation.attributes);
    // Only a per-channel (broadcast) vector or a scalar can be folded into a
    // bias vector; tensor-shaped addends cannot.
    if (!absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(
            add_attr.param) &&
        !absl::holds_alternative<float>(add_attr.param)) {
      return {TransformStatus::DECLINED,
              "This fuse applicable only for broadcast or scalar addition."};
    }
    if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) {
      Convolution2DAttributes* conv_attr =
          absl::any_cast<Convolution2DAttributes>(
              &conv_node.operation.attributes);
      // Zero-padded border taps would not see the added value, so folding
      // the Add into the bias would change results at the borders.
      if (conv_attr->padding.prepended.h != 0 ||
          conv_attr->padding.prepended.w != 0 ||
          conv_attr->padding.appended.h != 0 ||
          conv_attr->padding.appended.w != 0) {
        return {TransformStatus::DECLINED,
                "This fuse is incorrect for convolution with padding; "
                "zero-clamped border taps do not see the added value."};
      }
      FuseAddWithConvolution2D(add_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
      DepthwiseConvolution2DAttributes* conv_attr =
          absl::any_cast<DepthwiseConvolution2DAttributes>(
              &conv_node.operation.attributes);
      // Same border-element restriction as the dense convolution above.
      if (conv_attr->padding.prepended.h != 0 ||
          conv_attr->padding.prepended.w != 0 ||
          conv_attr->padding.appended.h != 0 ||
          conv_attr->padding.appended.w != 0) {
        return {TransformStatus::DECLINED,
                "This fuse is incorrect for convolution with padding; "
                "zero-clamped border taps do not see the added value."};
      }
      FuseAddWithDepthwiseConvolution2D(add_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::FULLY_CONNECTED)) {
      // Fully connected layers have no spatial padding, so the fold is
      // always exact here.
      FullyConnectedAttributes* conv_attr =
          absl::any_cast<FullyConnectedAttributes>(
              &conv_node.operation.attributes);
      FuseAddWithFullyConnected(add_attr, conv_attr);
    } else {
      return {TransformStatus::SKIPPED, ""};
    }
    absl::Status status = RemovePrecedingNode(graph, &add_node, &conv_node);
    if (!status.ok()) {
      return {TransformStatus::INVALID,
              "Unable to remove add node after convolution: " +
                  std::string(status.message())};
    }
    return {TransformStatus::APPLIED, ""};
  }
};
} // namespace
// Factory for the transformation that merges a convolution with the Add
// node that follows it (see MergeConvolutionWithAdd above).
std::unique_ptr<SequenceTransformation> NewMergeConvolutionWithAdd() {
  return absl::make_unique<MergeConvolutionWithAdd>();
}
// Factory for the transformation that merges an Add node into the
// convolution that follows it (see MergeAddWithConvolution above).
std::unique_ptr<SequenceTransformation> NewMergeAddWithConvolution() {
  return absl::make_unique<MergeAddWithConvolution>();
}
void FuseConvolution2DWithAdd(const AddAttributes& add_attr,
Convolution2DAttributes* attr) {
FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
@ -173,65 +119,5 @@ void FuseFullyConnectedWithAdd(const AddAttributes& add_attr,
FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
}
// Folds an elementwise addition that runs *before* the convolution into the
// convolution bias: conv(x + a) == conv(x) + B, with
//   B[d] = sum_{ky,kx,s} weights[d,ky,kx,s] * a[s].
// NOTE(review): this identity assumes every kernel tap reads a real input
// element. With zero padding, border taps read 0 (not 0 + a), so the result
// is wrong at the borders of padded convolutions (see commit message) —
// callers must ensure the convolution has no padding.
void FuseAddWithConvolution2D(const AddAttributes& add_attr,
                              Convolution2DAttributes* attr) {
  // Exactly one of these is expected to be set: a per-input-channel vector,
  // or a single scalar applied to every input channel.
  auto add = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param);
  auto add_scalar = absl::get_if<float>(&add_attr.param);
  // Materialize a zero bias when the convolution previously had none.
  if (attr->bias.data.empty()) {
    attr->bias = MakeZeroTensor<Linear, DataType::FLOAT32>(
        Linear(attr->weights.shape.o));
  }
  for (int d = 0; d < attr->weights.shape.o; ++d) {
    for (int s = 0; s < attr->weights.shape.i; ++s) {
      const float add_value = add ? add->data[s] : *add_scalar;
      // Accumulate this (output, input) channel pair's kernel mass times the
      // addend into the output-channel bias.
      for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
        for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
          const int index = attr->weights.shape.LinearIndex({{d, k_y, k_x, s}});
          attr->bias.data[d] += attr->weights.data[index] * add_value;
        }
      }
    }
  }
}
// Folds an elementwise addition that runs *before* a depthwise convolution
// into its bias. Depthwise output channels are laid out as
//   d = s * channel_multiplier + g   (s = input channel, g = multiplier),
// and each output channel d accumulates its own kernel's mass times the
// addend of its source input channel s.
// NOTE(review): as with FuseAddWithConvolution2D, this is only exact when
// the convolution has no zero padding — border taps read 0, not (0 + a).
void FuseAddWithDepthwiseConvolution2D(const AddAttributes& add_attr,
                                       DepthwiseConvolution2DAttributes* attr) {
  // Per-input-channel vector or scalar addend; exactly one is set.
  auto add = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param);
  auto add_scalar = absl::get_if<float>(&add_attr.param);
  // Depthwise bias length is channels * multiplier (shape.i * shape.o).
  if (attr->bias.data.empty()) {
    attr->bias = MakeZeroTensor<Linear, DataType::FLOAT32>(
        Linear(attr->weights.shape.o * attr->weights.shape.i));
  }
  for (int s = 0; s < attr->weights.shape.i; ++s) {
    // Addend is indexed by the *input* channel s.
    const float add_value = add ? add->data[s] : *add_scalar;
    for (int g = 0; g < attr->weights.shape.o; ++g) {
      const int d = s * attr->weights.shape.o + g;
      for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
        for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
          const int index = attr->weights.shape.LinearIndex({{g, k_y, k_x, s}});
          attr->bias.data[d] += attr->weights.data[index] * add_value;
        }
      }
    }
  }
}
// Folds an addition that precedes a fully connected layer into its bias:
//   fc(x + a) == fc(x) + W * a, i.e. bias[o] += sum_i W[o, i] * a[i].
// Fully connected layers have no spatial extent, so unlike the convolution
// variants this fold is always exact.
void FuseAddWithFullyConnected(const AddAttributes& add_attr,
                               FullyConnectedAttributes* attr) {
  // The addend is either a per-input-channel vector or a single scalar.
  auto broadcast =
      absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param);
  auto scalar = absl::get_if<float>(&add_attr.param);
  const int num_outputs = attr->weights.shape.o;
  const int num_inputs = attr->weights.shape.i;
  // Create an all-zero bias if the layer had none.
  if (attr->bias.data.empty()) {
    attr->bias =
        MakeZeroTensor<Linear, DataType::FLOAT32>(Linear(num_outputs));
  }
  for (int out = 0; out < num_outputs; ++out) {
    for (int in = 0; in < num_inputs; ++in) {
      const float addend = broadcast ? broadcast->data[in] : *scalar;
      const int index = attr->weights.shape.LinearIndex({{out, 0, 0, in}});
      attr->bias.data[out] += attr->weights.data[index] * addend;
    }
  }
}
} // namespace gpu
} // namespace tflite

View File

@ -30,11 +30,6 @@ namespace gpu {
// convolution.
std::unique_ptr<SequenceTransformation> NewMergeConvolutionWithAdd();
// Fuses a scalar or broadcast Add that precedes a convolution
// (Convolution2D, DepthwiseConvolution2D, or FullyConnected) into the
// biases of that convolution.
std::unique_ptr<SequenceTransformation> NewMergeAddWithConvolution();
// Modify Convolution2DAttributes so that after making convolution with
// modified attributes we will have the same result as convolution
// with old attributes and following add operation.
@ -59,24 +54,6 @@ void FuseConvolutionTransposedWithAdd(const AddAttributes& add_attr,
void FuseFullyConnectedWithAdd(const AddAttributes& add_attr,
FullyConnectedAttributes* attr);
// Modify Convolution2DAttributes so that after making convolution with
// modified attributes we will have the same result as add operation and
// convolution with old attributes
void FuseAddWithConvolution2D(const AddAttributes& add_attr,
Convolution2DAttributes* attr);
// Modify DepthwiseConvolution2DAttributes so that after making depth wise
// convolution with modified attributes we will have the same result as add
// operation and depth wise convolution with old attributes
void FuseAddWithDepthwiseConvolution2D(const AddAttributes& add_attr,
DepthwiseConvolution2DAttributes* attr);
// Modify FullyConnectedAttributes so that after making fully connected
// with modified attributes we will have the same result as add operation and
// fully connected with old attributes
void FuseAddWithFullyConnected(const AddAttributes& add_attr,
FullyConnectedAttributes* attr);
} // namespace gpu
} // namespace tflite

View File

@ -78,57 +78,6 @@ TEST(MergeConvolutionWithAddTest, Smoke) {
graph.nodes()[0]->operation.type);
}
// Smoke test: Add(8ch) -> Conv2D(3x2, 8->16ch, no padding) collapses into a
// single conv node after the merge transformation.
// Fixture fixes vs. the original: the add->conv link carries the *add*
// output (8 channels, input spatial size), and an unpadded 3x2 conv on a
// 4x4 input produces a 2x3 spatial output — the original used
// BHWC(1, 4, 4, 16) for both, which is inconsistent with the attributes.
TEST(MergeAddWithConvolutionTest, Smoke) {
  GraphFloat32 graph;
  auto input = graph.NewValue();
  input->tensor.shape = BHWC(1, 4, 4, 8);

  // 3x2 kernel, 8 input channels -> 16 output channels, VALID (no) padding.
  Convolution2DAttributes conv_attr;
  conv_attr.padding.prepended = HW(0, 0);
  conv_attr.padding.appended = HW(0, 0);
  conv_attr.strides = HW(1, 1);
  conv_attr.dilations = HW(1, 1);
  conv_attr.weights.shape = OHWI(16, 3, 2, 8);
  conv_attr.weights.data.resize(conv_attr.weights.shape.DimensionsProduct());
  conv_attr.bias.shape = Linear(16);
  conv_attr.bias.data.resize(16);

  // Per-channel addend over the 8 input channels, applied before the conv.
  Tensor<Linear, DataType::FLOAT32> add_tensor;
  add_tensor.shape = Linear(8);
  add_tensor.data.resize(8);
  AddAttributes add_attr;
  add_attr.param = add_tensor;

  auto conv_node = graph.NewNode();
  conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D);
  conv_node->operation.attributes = conv_attr;
  auto add_node = graph.NewNode();
  add_node->operation.type = ToString(OperationType::ADD);
  add_node->operation.attributes = add_attr;

  ASSERT_TRUE(graph.AddConsumer(add_node->id, input->id).ok());

  Value* output;
  ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok());
  // Unpadded 3x2 conv on 4x4: h = 4-3+1 = 2, w = 4-2+1 = 3.
  output->tensor.shape = BHWC(1, 2, 3, 16);

  Value* link1;
  ASSERT_TRUE(ConnectTwoNodes(&graph, add_node, conv_node, &link1).ok());
  // link1 is the add output feeding the conv: same shape as the input.
  link1->tensor.shape = BHWC(1, 4, 4, 8);

  ASSERT_EQ(2, graph.nodes().size());
  ASSERT_EQ(3, graph.values().size());

  auto transformation = NewMergeAddWithConvolution();
  ModelTransformer transformer(&graph, nullptr);
  transformer.Apply("merge_add_with_convolution", transformation.get());

  // Add is folded away: one conv node, input and output values remain.
  EXPECT_EQ(1, graph.nodes().size());
  EXPECT_EQ(2, graph.values().size());
  EXPECT_EQ(ToString(OperationType::CONVOLUTION_2D),
            graph.nodes()[0]->operation.type);
}
TEST(FuseAddAfterConvolution2DTest, Smoke) {
Convolution2DAttributes attr;
attr.weights.shape = OHWI(2, 1, 2, 2);
@ -213,69 +162,6 @@ TEST(FuseAddAfterFullyConnectedTest, Smoke) {
EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {1.4f, 1.9f}));
}
// fuse-before-conv: bias'[d] = bias[d] + sum_{ky,kx,s} W[d,ky,kx,s]*add[s];
// the weights themselves must stay untouched.
// d=0: 1.1 + (0.1+0.3)*2.0 + (0.2+0.4)*0.5 = 2.2
// d=1: 1.2 + (0.5+0.7)*2.0 + (0.6+0.8)*0.5 = 4.3
TEST(FuseAddBeforeConvolution2DTest, Smoke) {
  Tensor<Linear, DataType::FLOAT32> addend;
  addend.shape = Linear(2);
  addend.data = {2.0f, 0.5f};
  AddAttributes add_attr;
  add_attr.param = addend;

  Convolution2DAttributes conv;
  conv.weights.shape = OHWI(2, 1, 2, 2);
  conv.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f};
  conv.bias.shape = Linear(2);
  conv.bias.data = {1.1f, 1.2f};

  FuseAddWithConvolution2D(add_attr, &conv);

  EXPECT_THAT(conv.weights.data,
              Pointwise(FloatNear(1e-6),
                        {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}));
  EXPECT_THAT(conv.bias.data, Pointwise(FloatNear(1e-6), {2.2f, 4.3f}));
}
// Depthwise fuse-before: bias'[s*o+g] += add[s] * sum_{ky,kx} W[g,ky,kx,s].
// With i=2 input channels and multiplier o=2:
//   d=0: 1.1 + 0.3*(0.1+0.3) = 1.22   d=1: 1.2 + 0.3*(0.5+0.7) = 1.56
//   d=2: 1.3 + 0.7*(0.2+0.4) = 1.72   d=3: 1.4 + 0.7*(0.6+0.8) = 2.38
// NOTE(review): the add tensor is Linear(4), but FuseAddWithDepthwise-
// Convolution2D indexes addends by *input* channel, so only the first
// i=2 entries (0.3, 0.7) are read — a Linear(2) tensor would have been
// equivalent; confirm the intended addend shape.
TEST(FuseAddBeforeDepthwiseConvolution2DTest, Smoke) {
  DepthwiseConvolution2DAttributes attr;
  attr.weights.shape = OHWI(2, 1, 2, 2);
  attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f};
  attr.bias.shape = Linear(4);
  attr.bias.data = {1.1f, 1.2f, 1.3f, 1.4f};
  Tensor<Linear, DataType::FLOAT32> add_tensor;
  add_tensor.shape = Linear(4);
  add_tensor.data = {0.3f, 0.7f, 0.5f, 0.1f};
  AddAttributes add_attr;
  add_attr.param = add_tensor;
  FuseAddWithDepthwiseConvolution2D(add_attr, &attr);
  // Weights must be unchanged; only the bias absorbs the addition.
  EXPECT_THAT(attr.weights.data,
              Pointwise(FloatNear(1e-6),
                        {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}));
  EXPECT_THAT(attr.bias.data,
              Pointwise(FloatNear(1e-6), {1.22f, 1.56f, 1.72f, 2.38f}));
}
// fuse-before-FC: bias'[o] = bias[o] + sum_i W[o,i]*add[i]; weights stay
// untouched.
// o=0: 1.1 + 0.1*0.5 + 0.2*2.0 = 1.55
// o=1: 1.2 + 0.3*0.5 + 0.4*2.0 = 2.15
TEST(FuseAddBeforeFullyConnectedTest, Smoke) {
  Tensor<Linear, DataType::FLOAT32> addend;
  addend.shape = Linear(2);
  addend.data = {0.5f, 2.0f};
  AddAttributes add_attr;
  add_attr.param = addend;

  FullyConnectedAttributes fc;
  fc.weights.shape = OHWI(2, 1, 1, 2);
  fc.weights.data = {0.1f, 0.2f, 0.3f, 0.4f};
  fc.bias.shape = Linear(2);
  fc.bias.data = {1.1f, 1.2f};

  FuseAddWithFullyConnected(add_attr, &fc);

  EXPECT_THAT(fc.weights.data,
              Pointwise(FloatNear(1e-6), {0.1f, 0.2f, 0.3f, 0.4f}));
  EXPECT_THAT(fc.bias.data, Pointwise(FloatNear(1e-6), {1.55f, 2.15f}));
}
} // namespace
} // namespace gpu
} // namespace tflite

View File

@ -54,9 +54,7 @@ bool ApplyGeneralTransformations(ModelTransformer* transformer) {
transformer->Apply("merge_convolution_with_add",
NewMergeConvolutionWithAdd().get()) &&
transformer->Apply("merge_mul_with_convolution",
NewMergeMulWithConvolution().get()) &&
transformer->Apply("merge_add_with_convolution",
NewMergeAddWithConvolution().get());
NewMergeMulWithConvolution().get());
}
} // namespace gpu