[Grappler] Fuse Relu6 and Elu into Conv2D

PiperOrigin-RevId: 242907340
This commit is contained in:
Eugene Zhulenev 2019-04-10 11:29:35 -07:00 committed by TensorFlower Gardener
parent 12b9e3ad86
commit 2ace7a7d9d
5 changed files with 218 additions and 166 deletions

View File

@ -211,6 +211,8 @@ bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
return false;
}
// Returns true iff `node` executes the "Elu" activation op.
bool IsElu(const NodeDef& node) {
  return node.op() == "Elu";
}
// Returns true iff `node` executes the "EluGrad" op (gradient of Elu).
bool IsEluGrad(const NodeDef& node) {
  return node.op() == "EluGrad";
}
bool IsEnter(const NodeDef& node) {
@ -412,6 +414,8 @@ bool IsReduction(const NodeDef& node) {
// Returns true iff `node` executes the "Relu" activation op.
bool IsRelu(const NodeDef& node) {
  return node.op() == "Relu";
}
// Returns true iff `node` executes the "Relu6" activation op.
bool IsRelu6(const NodeDef& node) {
  return node.op() == "Relu6";
}
// Returns true iff `node` executes the "ReluGrad" op (gradient of Relu).
bool IsReluGrad(const NodeDef& node) {
  return node.op() == "ReluGrad";
}
// Returns true iff `node` executes the "Relu6Grad" op (gradient of Relu6).
bool IsRelu6Grad(const NodeDef& node) {
  return node.op() == "Relu6Grad";
}

View File

@ -66,6 +66,7 @@ bool IsDequeueOp(const NodeDef& node);
bool IsDiv(const NodeDef& node);
bool IsDivNoNan(const NodeDef& node);
bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing);
bool IsElu(const NodeDef& node);
bool IsEluGrad(const NodeDef& node);
bool IsEnter(const NodeDef& node);
bool IsEqual(const NodeDef& node);
@ -134,6 +135,7 @@ bool IsReciprocalGrad(const NodeDef& node);
bool IsRecv(const NodeDef& node);
bool IsReduction(const NodeDef& node);
bool IsRelu(const NodeDef& node);
bool IsRelu6(const NodeDef& node);
bool IsRelu6Grad(const NodeDef& node);
bool IsReluGrad(const NodeDef& node);
bool IsReshape(const NodeDef& node);

View File

@ -78,16 +78,16 @@ struct Conv2DWithBiasAdd {
const NodeDef* bias_add = nullptr;
};
// Conv2D node followed by a BiasAdd and Relu.
struct Conv2DWithBiasAddAndRelu {
Conv2DWithBiasAddAndRelu() = default;
Conv2DWithBiasAddAndRelu(const NodeDef* conv2d, const NodeDef* bias_add,
const NodeDef* relu)
: conv2d(conv2d), bias_add(bias_add), relu(relu) {}
// Conv2D node followed by a BiasAdd and Activation.
struct Conv2DWithBiasAddAndActivation {
Conv2DWithBiasAddAndActivation() = default;
Conv2DWithBiasAddAndActivation(const NodeDef* conv2d, const NodeDef* bias_add,
const NodeDef* activation)
: conv2d(conv2d), bias_add(bias_add), activation(activation) {}
const NodeDef* conv2d = nullptr;
const NodeDef* bias_add = nullptr;
const NodeDef* relu = nullptr;
const NodeDef* activation = nullptr;
};
// Conv2D node followed by a Squeeze and BiasAdd.
@ -114,20 +114,21 @@ struct Conv2DWithBatchNorm {
float epsilon = 0.0;
};
// Conv2D node followed by a FusedBatchNorm and Relu.
struct Conv2DWithBatchNormAndRelu {
Conv2DWithBatchNormAndRelu() = default;
Conv2DWithBatchNormAndRelu(const NodeDef* conv2d,
const NodeDef* fused_batch_norm,
const NodeDef* relu, float epsilon = 0.0)
// Conv2D node followed by a FusedBatchNorm and Activation.
struct Conv2DWithBatchNormAndActivation {
Conv2DWithBatchNormAndActivation() = default;
Conv2DWithBatchNormAndActivation(const NodeDef* conv2d,
const NodeDef* fused_batch_norm,
const NodeDef* activation,
float epsilon = 0.0)
: conv2d(conv2d),
fused_batch_norm(fused_batch_norm),
relu(relu),
activation(activation),
epsilon(epsilon) {}
const NodeDef* conv2d = nullptr;
const NodeDef* fused_batch_norm = nullptr;
const NodeDef* relu = nullptr;
const NodeDef* activation = nullptr;
float epsilon = 0.0;
};
@ -194,7 +195,7 @@ bool IsCpuCompatible(const Pattern& matched) {
// Checks if we can rewrite a pattern to the `_FusedConv2D` on GPU device.
bool IsGpuCompatible(const RemapperContext& ctx,
const Conv2DWithBiasAddAndRelu& matched) {
const Conv2DWithBiasAddAndActivation& matched) {
const std::vector<OpInfo::TensorProperties>& input_props =
ctx.graph_properties.GetInputProperties(matched.conv2d->name());
const TensorShapeProto& filter_shape =
@ -226,6 +227,10 @@ bool IsDeviceCompatible(const RemapperContext& ctx, Pattern& matched) {
return IsCpuCompatible(matched) || IsGpuCompatible(ctx, matched);
}
// Activation ops that the remapper knows how to fuse into a _FusedConv2D
// node (same short-circuit order as the original check).
bool IsSupportedActivation(const NodeDef& node) {
  if (IsRelu(node)) return true;
  if (IsRelu6(node)) return true;
  return IsElu(node);
}
bool FindConv2DWithBias(const RemapperContext& ctx, const NodeDef* bias_add,
Conv2DWithBiasAdd* matched,
bool check_device_compatible = true) {
@ -259,31 +264,34 @@ bool FindConv2DWithBias(const RemapperContext& ctx, const NodeDef* bias_add,
return true;
}
bool FindConv2DWithBiasAndRelu(const RemapperContext& ctx, const NodeDef* relu,
Conv2DWithBiasAddAndRelu* matched) {
bool FindConv2DWithBiasAndActivation(const RemapperContext& ctx,
const NodeDef* activation,
Conv2DWithBiasAddAndActivation* matched) {
if (!EigenSupportsContractionOutputKernel()) return false;
// Root of the pattern must be a Relu.
if (!relu || !IsRelu(*relu) || HasControlFaninOrFanout(ctx.graph_view, relu))
// Root of the pattern must be an activation node.
if (!activation || !IsSupportedActivation(*activation) ||
HasControlFaninOrFanout(ctx.graph_view, activation))
return false;
// And input to Relu must match Conv2DWithBiasAdd pattern.
const auto input_port = GraphView::InputPort(relu, 0);
// And input to the activation node must match Conv2DWithBiasAdd pattern.
const auto input_port = GraphView::InputPort(activation, 0);
const auto bias_add = ctx.graph_view.GetRegularFanin(input_port);
Conv2DWithBiasAdd base;
if (!FindConv2DWithBias(ctx, bias_add.node, &base,
/*check_device_compatible=*/false) ||
!HasSingleFanoutNode(ctx.graph_view, base.bias_add) ||
!HaveSameDataType(relu, base.bias_add) ||
!HaveSameDataType(activation, base.bias_add) ||
IsInPreserveSet(ctx, base.bias_add))
return false;
// Check that data type and data format are supported on assigned device.
const Conv2DWithBiasAddAndRelu pattern{base.conv2d, base.bias_add, relu};
const Conv2DWithBiasAddAndActivation pattern{base.conv2d, base.bias_add,
activation};
if (!IsDeviceCompatible(ctx, pattern)) return false;
// We successfully found a Conv2D+BiasAdd+Relu pattern.
// We successfully found a Conv2D+BiasAdd+Activation pattern.
*matched = pattern;
return true;
@ -387,16 +395,17 @@ bool FindConv2DWithBatchNorm(const RemapperContext& ctx,
return true;
}
bool FindConv2DWithBatchNormAndRelu(const RemapperContext& ctx,
const NodeDef* node,
Conv2DWithBatchNormAndRelu* matched) {
bool FindConv2DWithBatchNormAndActivation(
const RemapperContext& ctx, const NodeDef* node,
Conv2DWithBatchNormAndActivation* matched) {
if (!EigenSupportsContractionOutputKernel()) return false;
// Root of the pattern must be a Relu.
if (!node || !IsRelu(*node) || HasControlFaninOrFanout(ctx.graph_view, node))
// Root of the pattern must be an activation node.
if (!node || !IsSupportedActivation(*node) ||
HasControlFaninOrFanout(ctx.graph_view, node))
return false;
// And input to Relu must match Conv2DWithBatchNorm pattern.
// And input to the activation node must match Conv2DWithBatchNorm pattern.
const auto input_port = GraphView::InputPort(node, 0);
const auto batch_norm = ctx.graph_view.GetRegularFanin(input_port);
@ -407,10 +416,10 @@ bool FindConv2DWithBatchNormAndRelu(const RemapperContext& ctx,
IsInPreserveSet(ctx, base.fused_batch_norm))
return false;
// We successfully found a Conv2D+FusedBatchNorm+Relu pattern.
// We successfully found a Conv2D+FusedBatchNorm+Activation pattern.
matched->conv2d = base.conv2d;
matched->fused_batch_norm = base.fused_batch_norm;
matched->relu = node;
matched->activation = node;
matched->epsilon = base.epsilon;
return true;
@ -511,19 +520,19 @@ void AddFusedConv2DNode(
}
void AddFusedConv2DNode(
const RemapperContext& ctx, const Conv2DWithBiasAddAndRelu& matched,
const RemapperContext& ctx, const Conv2DWithBiasAddAndActivation& matched,
GraphDef* optimized_graph,
absl::flat_hash_set<const NodeDef*>* invalidated_nodes) {
DCHECK(IsDeviceCompatible(ctx, matched))
<< "Unsupported fused Conv2D pattern";
VLOG(2) << "Fuse Conv2D with BiasAdd and Relu: "
<< " relu=" << matched.relu->name()
VLOG(2) << "Fuse Conv2D with BiasAdd and " << matched.activation->op() << ":"
<< " activation=" << matched.activation->name()
<< " bias_add=" << matched.bias_add->name()
<< " conv2d=" << matched.conv2d->name();
NodeDef* fused_conv2d = optimized_graph->add_node();
fused_conv2d->set_name(matched.relu->name());
fused_conv2d->set_name(matched.activation->name());
fused_conv2d->set_op(kFusedConv2D);
fused_conv2d->set_device(matched.conv2d->device());
fused_conv2d->add_input(matched.conv2d->input(0)); // 0: input
@ -531,9 +540,9 @@ void AddFusedConv2DNode(
fused_conv2d->add_input(matched.bias_add->input(1)); // 2: bias
CopyConv2DAttributes(matched.conv2d, fused_conv2d);
SetFusedConv2DAttributes(fused_conv2d, {"BiasAdd", "Relu"});
SetFusedConv2DAttributes(fused_conv2d, {"BiasAdd", matched.activation->op()});
invalidated_nodes->insert(matched.relu);
invalidated_nodes->insert(matched.activation);
invalidated_nodes->insert(matched.bias_add);
invalidated_nodes->insert(matched.conv2d);
}
@ -601,15 +610,15 @@ void AddFusedConv2DNode(
}
void AddFusedConv2DNode(
const Conv2DWithBatchNormAndRelu& matched, GraphDef* optimized_graph,
const Conv2DWithBatchNormAndActivation& matched, GraphDef* optimized_graph,
absl::flat_hash_set<const NodeDef*>* invalidated_nodes) {
VLOG(2) << "Fuse Conv2D with BatchNorm and Relu: relu="
<< matched.relu->name()
VLOG(2) << "Fuse Conv2D with BatchNorm and " << matched.activation->op()
<< ": activation=" << matched.activation->name()
<< " batch_norm=" << matched.fused_batch_norm->name()
<< " conv2d=" << matched.conv2d->name();
NodeDef* fused_conv2d = optimized_graph->add_node();
fused_conv2d->set_name(matched.relu->name());
fused_conv2d->set_name(matched.activation->name());
fused_conv2d->set_op(kFusedConv2D);
fused_conv2d->set_device(matched.conv2d->device());
fused_conv2d->add_input(matched.conv2d->input(0)); // 0: input
@ -620,10 +629,11 @@ void AddFusedConv2DNode(
fused_conv2d->add_input(matched.fused_batch_norm->input(4)); // 5: variance
CopyConv2DAttributes(matched.conv2d, fused_conv2d);
SetFusedConv2DAttributes(fused_conv2d, {"FusedBatchNorm", "Relu"},
SetFusedConv2DAttributes(fused_conv2d,
{"FusedBatchNorm", matched.activation->op()},
/*num_args=*/4, /*epsilon=*/matched.epsilon);
invalidated_nodes->insert(matched.relu);
invalidated_nodes->insert(matched.activation);
invalidated_nodes->insert(matched.fused_batch_norm);
invalidated_nodes->insert(matched.conv2d);
}
@ -778,9 +788,9 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
// clang-format off
FusedBatchNorm fused_batch_norm;
Conv2DWithBiasAdd conv2d_with_bias;
Conv2DWithBiasAddAndRelu conv2d_with_bias_and_relu;
Conv2DWithBiasAddAndActivation conv2d_with_bias_and_activation;
Conv2DWithBatchNorm conv2d_with_batch_norm;
Conv2DWithBatchNormAndRelu conv2d_with_batch_norm_and_relu;
Conv2DWithBatchNormAndActivation conv2d_with_batch_norm_and_activation;
Conv2DWithSqueezeAndBiasAdd conv2d_with_squeeze_and_bias;
// clang-format on
@ -795,7 +805,7 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
RemapperContext ctx(topo_sorted_item);
// Skip nodes that were invalidated by a remapper, e.g. do not process BiasAdd
// and Relu nodes that were fused into a Conv2D node.
// and Activation nodes that were fused into a Conv2D node.
absl::flat_hash_set<const NodeDef*> invalidated_nodes;
optimized_graph->mutable_node()->Reserve(topo_sorted_item.graph.node_size());
@ -810,9 +820,10 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
continue;
}
// Remap Conv2D+BiasAdd+Relu into the _FusedConv2D.
if (FindConv2DWithBiasAndRelu(ctx, &node, &conv2d_with_bias_and_relu)) {
AddFusedConv2DNode(ctx, conv2d_with_bias_and_relu, optimized_graph,
// Remap Conv2D+BiasAdd+Activation into the _FusedConv2D.
if (FindConv2DWithBiasAndActivation(ctx, &node,
&conv2d_with_bias_and_activation)) {
AddFusedConv2DNode(ctx, conv2d_with_bias_and_activation, optimized_graph,
&invalidated_nodes);
continue;
}
@ -835,10 +846,10 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
continue;
}
// Remap Conv2D+FusedBatchNorm+Relu into the _FusedConv2D;
if (FindConv2DWithBatchNormAndRelu(ctx, &node,
&conv2d_with_batch_norm_and_relu)) {
AddFusedConv2DNode(conv2d_with_batch_norm_and_relu, optimized_graph,
// Remap Conv2D+FusedBatchNorm+Activation into the _FusedConv2D.
if (FindConv2DWithBatchNormAndActivation(
ctx, &node, &conv2d_with_batch_norm_and_activation)) {
AddFusedConv2DNode(conv2d_with_batch_norm_and_activation, optimized_graph,
&invalidated_nodes);
continue;
}

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/devices.h"
@ -97,7 +98,7 @@ TEST_F(RemapperTest, FusedBatchNormNCHW) {
EXPECT_EQ(1, tensors_expected.size());
auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
}
}
@ -164,69 +165,85 @@ TEST_F(RemapperTest, FuseConv2DWithBias) {
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(RemapperTest, FuseConv2DWithBiasAndRelu) {
TEST_F(RemapperTest, FuseConv2DWithBiasAndActivation) {
if (!EigenSupportsContractionOutputKernel()) return;
using ::tensorflow::ops::Placeholder;
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto input_shape = Placeholder::Shape({8, 32, 32, 3});
auto filter_shape = Placeholder::Shape({1, 1, 3, 128});
auto bias_shape = Placeholder::Shape({128});
auto input_shape = Placeholder::Shape({8, 32, 32, 3});
auto filter_shape = Placeholder::Shape({1, 1, 3, 128});
auto bias_shape = Placeholder::Shape({128});
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
std::vector<int> strides = {1, 1, 1, 1};
auto conv = ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME");
auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);
auto relu = ops::Relu(s.WithOpName("relu"), bias_add);
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
std::vector<int> strides = {1, 1, 1, 1};
auto conv =
ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME");
auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);
auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3});
auto filter_t = GenerateRandomTensor<DT_FLOAT>({1, 1, 3, 128});
auto bias_t = GenerateRandomTensor<DT_FLOAT>({128});
ops::Identity fetch = [&]() -> ops::Identity {
auto activate = s.WithOpName("activation");
auto fetch = s.WithOpName("fetch");
GrapplerItem item;
item.fetch = {"fetch"};
item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
if (activation == "Relu") {
return ops::Identity(fetch, ops::Relu(activate, bias_add));
} else if (activation == "Relu6") {
return ops::Identity(fetch, ops::Relu6(activate, bias_add));
} else if (activation == "Elu") {
return ops::Identity(fetch, ops::Elu(activate, bias_add));
}
// Place all nodes on CPU.
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device("/device:CPU:0");
}
return ops::Identity(fetch, bias);
}();
Remapper optimizer(RewriterConfig::ON);
GraphDef output;
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3});
auto filter_t = GenerateRandomTensor<DT_FLOAT>({1, 1, 3, 128});
auto bias_t = GenerateRandomTensor<DT_FLOAT>({128});
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "relu") {
EXPECT_EQ("_FusedConv2D", node.op());
EXPECT_EQ("input", node.input(0));
EXPECT_EQ("filter", node.input(1));
GrapplerItem item;
item.fetch = {"fetch"};
item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
EXPECT_EQ(1, node.attr().at("num_args").i());
EXPECT_EQ("bias", node.input(2));
const auto fused_ops = node.attr().at("fused_ops").list().s();
ASSERT_EQ(2, fused_ops.size());
EXPECT_EQ("BiasAdd", fused_ops[0]);
EXPECT_EQ("Relu", fused_ops[1]);
found++;
// Place all nodes on CPU.
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device("/device:CPU:0");
}
}
EXPECT_EQ(1, found);
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
EXPECT_EQ(1, tensors_expected.size());
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
Remapper optimizer(RewriterConfig::ON);
GraphDef output;
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "activation") {
EXPECT_EQ("_FusedConv2D", node.op());
EXPECT_EQ("input", node.input(0));
EXPECT_EQ("filter", node.input(1));
EXPECT_EQ(1, node.attr().at("num_args").i());
EXPECT_EQ("bias", node.input(2));
const auto fused_ops = node.attr().at("fused_ops").list().s();
ASSERT_EQ(2, fused_ops.size());
EXPECT_EQ("BiasAdd", fused_ops[0]);
EXPECT_EQ(activation, fused_ops[1]);
found++;
}
}
EXPECT_EQ(1, found);
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
EXPECT_EQ(1, tensors_expected.size());
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
}
TEST_F(RemapperTest, FuseConv2DWithBatchNorm) {
@ -306,83 +323,100 @@ TEST_F(RemapperTest, FuseConv2DWithBatchNorm) {
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(RemapperTest, FuseConv2DWithBatchNormAndRelu) {
TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) {
if (!EigenSupportsContractionOutputKernel()) return;
using ops::Placeholder;
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto input_shape = ops::Placeholder::Shape({8, 32, 32, 3});
auto filter_shape = ops::Placeholder::Shape({1, 1, 3, 128});
auto scale_shape = ops::Placeholder::Shape({128});
auto input_shape = ops::Placeholder::Shape({8, 32, 32, 3});
auto filter_shape = ops::Placeholder::Shape({1, 1, 3, 128});
auto scale_shape = ops::Placeholder::Shape({128});
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, scale_shape);
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, scale_shape);
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, scale_shape);
auto variance = Placeholder(s.WithOpName("variance"), DT_FLOAT, scale_shape);
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, scale_shape);
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, scale_shape);
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, scale_shape);
auto variance =
Placeholder(s.WithOpName("variance"), DT_FLOAT, scale_shape);
std::vector<int> strides = {1, 1, 1, 1};
auto conv = ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME");
ops::FusedBatchNorm::Attrs attrs;
attrs = attrs.IsTraining(false);
auto batch_norm = ops::FusedBatchNorm(s.WithOpName("batch_norm"), conv, scale,
offset, mean, variance, attrs);
auto relu = ops::Relu(s.WithOpName("relu"), batch_norm.y);
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
std::vector<int> strides = {1, 1, 1, 1};
auto conv =
ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME");
ops::FusedBatchNorm::Attrs attrs;
attrs = attrs.IsTraining(false);
auto batch_norm = ops::FusedBatchNorm(s.WithOpName("batch_norm"), conv,
scale, offset, mean, variance, attrs);
auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3});
auto filter_t = GenerateRandomTensor<DT_FLOAT>({1, 1, 3, 128});
auto scale_t = GenerateRandomTensor<DT_FLOAT>({128});
auto offset_t = GenerateRandomTensor<DT_FLOAT>({128});
auto mean_t = GenerateRandomTensor<DT_FLOAT>({128});
auto variance_t = GenerateRandomTensor<DT_FLOAT>({128});
ops::Identity fetch = [&]() -> ops::Identity {
auto activate = s.WithOpName("activation");
auto fetch = s.WithOpName("fetch");
GrapplerItem item;
item.fetch = {"fetch"};
item.feed = {{"input", input_t}, {"filter", filter_t},
{"scale", scale_t}, {"offset", offset_t},
{"mean", mean_t}, {"variance", variance_t}};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
if (activation == "Relu") {
return ops::Identity(fetch, ops::Relu(activate, batch_norm.y));
} else if (activation == "Relu6") {
return ops::Identity(fetch, ops::Relu6(activate, batch_norm.y));
} else if (activation == "Elu") {
return ops::Identity(fetch, ops::Elu(activate, batch_norm.y));
}
// Place all nodes on CPU.
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device("/device:CPU:0");
}
return ops::Identity(fetch, batch_norm.y);
}();
Remapper optimizer(RewriterConfig::ON);
GraphDef output;
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3});
auto filter_t = GenerateRandomTensor<DT_FLOAT>({1, 1, 3, 128});
auto scale_t = GenerateRandomTensor<DT_FLOAT>({128});
auto offset_t = GenerateRandomTensor<DT_FLOAT>({128});
auto mean_t = GenerateRandomTensor<DT_FLOAT>({128});
auto variance_t = GenerateRandomTensor<DT_FLOAT>({128});
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "relu") {
EXPECT_EQ("_FusedConv2D", node.op());
EXPECT_EQ("input", node.input(0));
EXPECT_EQ("filter", node.input(1));
GrapplerItem item;
item.fetch = {"fetch"};
item.feed = {{"input", input_t}, {"filter", filter_t},
{"scale", scale_t}, {"offset", offset_t},
{"mean", mean_t}, {"variance", variance_t}};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
EXPECT_EQ(4, node.attr().at("num_args").i());
EXPECT_EQ("scale", node.input(2));
EXPECT_EQ("offset", node.input(3));
EXPECT_EQ("mean", node.input(4));
EXPECT_EQ("variance", node.input(5));
const auto fused_ops = node.attr().at("fused_ops").list().s();
EXPECT_EQ(2, fused_ops.size());
EXPECT_EQ("FusedBatchNorm", fused_ops[0]);
EXPECT_EQ("Relu", fused_ops[1]);
found++;
// Place all nodes on CPU.
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device("/device:CPU:0");
}
}
EXPECT_EQ(1, found);
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
EXPECT_EQ(1, tensors_expected.size());
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
Remapper optimizer(RewriterConfig::ON);
GraphDef output;
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "activation") {
EXPECT_EQ("_FusedConv2D", node.op());
EXPECT_EQ("input", node.input(0));
EXPECT_EQ("filter", node.input(1));
EXPECT_EQ(4, node.attr().at("num_args").i());
EXPECT_EQ("scale", node.input(2));
EXPECT_EQ("offset", node.input(3));
EXPECT_EQ("mean", node.input(4));
EXPECT_EQ("variance", node.input(5));
const auto fused_ops = node.attr().at("fused_ops").list().s();
EXPECT_EQ(2, fused_ops.size());
EXPECT_EQ("FusedBatchNorm", fused_ops[0]);
EXPECT_EQ(activation, fused_ops[1]);
found++;
}
}
EXPECT_EQ(1, found);
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
EXPECT_EQ(1, tensors_expected.size());
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
}
TEST_F(RemapperTest, FuseConv2DWithSqueezeAndBias) {

View File

@ -88,6 +88,7 @@ GrapplerTest::GrapplerTest() {
cfg->set_layout_optimizer(RewriterConfig::OFF);
cfg->set_loop_optimization(RewriterConfig::OFF);
cfg->set_pin_to_host_optimization(RewriterConfig::OFF);
cfg->set_remapping(RewriterConfig::OFF);
}
std::vector<Tensor> GrapplerTest::EvaluateNodes(