Merge pull request #39335 from Intel-tensorflow:yifeng/depthwise_conv2d_bf16_fusion
PiperOrigin-RevId: 314566506 Change-Id: I33d1f2f2245dabadc8927ff64f9cdff60501d624
This commit is contained in:
commit
87b9b4216d
tensorflow/core
@ -23,6 +23,19 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
#define REGISTER_TEST_FLOAT32(TEST) REGISTER_TEST(TEST, DT_FLOAT, Float32Input);
|
||||
|
||||
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
||||
#define REGISTER_TEST_BFLOAT16(TEST) \
|
||||
REGISTER_TEST(TEST, DT_BFLOAT16, BFloat16Input);
|
||||
|
||||
#define REGISTER_TEST_ALL_TYPES(TEST) \
|
||||
REGISTER_TEST_FLOAT32(TEST); \
|
||||
REGISTER_TEST_BFLOAT16(TEST);
|
||||
#else
|
||||
#define REGISTER_TEST_ALL_TYPES(TEST) REGISTER_TEST_FLOAT32(TEST);
|
||||
#endif // ENABLE_INTEL_MKL_BFLOAT16
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
@ -207,93 +220,99 @@ CREATE_CONV2DFUSION_ADD_BCAST_TEST(AddV2);
|
||||
#undef CREATE_CONV2DFUSION_ADD_BCAST_TEST
|
||||
#undef CREATE_CONV2DFUSION_TEST
|
||||
|
||||
TEST_F(MklRemapperTest, FuseDepthwiseConv2DWithBiasAndActivation) {
|
||||
using ::tensorflow::ops::Placeholder;
|
||||
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu", "None"}) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
|
||||
auto input_shape = Placeholder::Shape({8, 32, 32, 3});
|
||||
auto filter_shape = Placeholder::Shape({1, 1, 3, 1});
|
||||
auto bias_shape = Placeholder::Shape({3});
|
||||
|
||||
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
|
||||
auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
|
||||
auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
|
||||
|
||||
std::vector<int> strides = {1, 1, 1, 1};
|
||||
auto conv = ops::DepthwiseConv2dNative(s.WithOpName("depthwise_conv"),
|
||||
input, filter, strides, "SAME");
|
||||
auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);
|
||||
|
||||
ops::Identity fetch = [&]() -> ops::Identity {
|
||||
auto activate = s.WithOpName("activation");
|
||||
auto fetch = s.WithOpName("fetch");
|
||||
|
||||
if (activation == "Relu") {
|
||||
return ops::Identity(fetch, ops::Relu(activate, bias_add));
|
||||
} else if (activation == "Relu6") {
|
||||
return ops::Identity(fetch, ops::Relu6(activate, bias_add));
|
||||
} else if (activation == "Elu") {
|
||||
return ops::Identity(fetch, ops::Elu(activate, bias_add));
|
||||
}
|
||||
|
||||
DCHECK(activation == "None");
|
||||
return ops::Identity(fetch, bias_add);
|
||||
}();
|
||||
|
||||
auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3});
|
||||
auto filter_t = GenerateRandomTensor<DT_FLOAT>({1, 1, 3, 1});
|
||||
auto bias_t = GenerateRandomTensor<DT_FLOAT>({3});
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"fetch"};
|
||||
item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
// Place all nodes on CPU.
|
||||
for (int i = 0; i < item.graph.node_size(); ++i) {
|
||||
item.graph.mutable_node(i)->set_device("/device:CPU:0");
|
||||
}
|
||||
|
||||
Remapper optimizer(RewriterConfig::ON);
|
||||
GraphDef output;
|
||||
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
int found = 0;
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.name() != "bias_add" && node.name() != "activation") continue;
|
||||
|
||||
EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");
|
||||
ASSERT_EQ(node.input_size(), 3);
|
||||
EXPECT_EQ(node.input(0), "input");
|
||||
EXPECT_EQ(node.input(1), "filter");
|
||||
|
||||
EXPECT_EQ(node.attr().at("num_args").i(), 1);
|
||||
EXPECT_EQ(node.input(2), "bias");
|
||||
|
||||
const auto fused_ops = node.attr().at("fused_ops").list().s();
|
||||
if (node.name() == "bias_add") {
|
||||
ASSERT_EQ(fused_ops.size(), 1);
|
||||
EXPECT_EQ(fused_ops[0], "BiasAdd");
|
||||
found++;
|
||||
}
|
||||
if (node.name() == "activation") {
|
||||
ASSERT_EQ(fused_ops.size(), 2);
|
||||
EXPECT_EQ(fused_ops[0], "BiasAdd");
|
||||
EXPECT_EQ(fused_ops[1], activation);
|
||||
found++;
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(found, 1);
|
||||
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
|
||||
ASSERT_EQ(tensors_expected.size(), 1);
|
||||
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
||||
ASSERT_EQ(tensors.size(), 1);
|
||||
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
|
||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||
TEST_F(MklRemapperTest, NAME##_##T) { \
|
||||
using ::tensorflow::ops::Placeholder; \
|
||||
\
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu", "None"}) { \
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope(); \
|
||||
\
|
||||
auto input_shape = Placeholder::Shape({8, 32, 32, 3}); \
|
||||
auto filter_shape = Placeholder::Shape({1, 1, 3, 1}); \
|
||||
auto bias_shape = Placeholder::Shape({3}); \
|
||||
\
|
||||
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape); \
|
||||
auto filter = \
|
||||
Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape); \
|
||||
auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape); \
|
||||
\
|
||||
std::vector<int> strides = {1, 1, 1, 1}; \
|
||||
auto conv = ops::DepthwiseConv2dNative(s.WithOpName("depthwise_conv"), \
|
||||
input, filter, strides, "SAME"); \
|
||||
auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias); \
|
||||
\
|
||||
ops::Identity fetch = [&]() -> ops::Identity { \
|
||||
auto activate = s.WithOpName("activation"); \
|
||||
auto fetch = s.WithOpName("fetch"); \
|
||||
\
|
||||
if (activation == "Relu") { \
|
||||
return ops::Identity(fetch, ops::Relu(activate, bias_add)); \
|
||||
} else if (activation == "Relu6") { \
|
||||
return ops::Identity(fetch, ops::Relu6(activate, bias_add)); \
|
||||
} else if (activation == "Elu") { \
|
||||
return ops::Identity(fetch, ops::Elu(activate, bias_add)); \
|
||||
} \
|
||||
\
|
||||
DCHECK(activation == "None"); \
|
||||
return ops::Identity(fetch, bias_add); \
|
||||
}(); \
|
||||
\
|
||||
auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3}); \
|
||||
auto filter_t = GenerateRandomTensor<DT_FLOAT>({1, 1, 3, 1}); \
|
||||
auto bias_t = GenerateRandomTensor<DT_FLOAT>({3}); \
|
||||
\
|
||||
GrapplerItem item; \
|
||||
item.fetch = {"fetch"}; \
|
||||
item.feed = { \
|
||||
{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}}; \
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph)); \
|
||||
\
|
||||
for (int i = 0; i < item.graph.node_size(); ++i) { \
|
||||
item.graph.mutable_node(i)->set_device("/device:CPU:0"); \
|
||||
} \
|
||||
\
|
||||
Remapper optimizer(RewriterConfig::ON); \
|
||||
GraphDef output; \
|
||||
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output)); \
|
||||
\
|
||||
int found = 0; \
|
||||
for (const NodeDef& node : output.node()) { \
|
||||
if (node.name() != "bias_add" && node.name() != "activation") \
|
||||
continue; \
|
||||
\
|
||||
EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative"); \
|
||||
ASSERT_EQ(node.input_size(), 3); \
|
||||
EXPECT_EQ(node.input(0), "input"); \
|
||||
EXPECT_EQ(node.input(1), "filter"); \
|
||||
\
|
||||
EXPECT_EQ(node.attr().at("num_args").i(), 1); \
|
||||
EXPECT_EQ(node.input(2), "bias"); \
|
||||
\
|
||||
const auto fused_ops = node.attr().at("fused_ops").list().s(); \
|
||||
if (node.name() == "bias_add") { \
|
||||
ASSERT_EQ(fused_ops.size(), 1); \
|
||||
EXPECT_EQ(fused_ops[0], "BiasAdd"); \
|
||||
found++; \
|
||||
} \
|
||||
if (node.name() == "activation") { \
|
||||
ASSERT_EQ(fused_ops.size(), 2); \
|
||||
EXPECT_EQ(fused_ops[0], "BiasAdd"); \
|
||||
EXPECT_EQ(fused_ops[1], activation); \
|
||||
found++; \
|
||||
} \
|
||||
} \
|
||||
EXPECT_EQ(found, 1); \
|
||||
\
|
||||
auto tensors_expected = \
|
||||
EvaluateNodes(item.graph, item.fetch, item.feed); \
|
||||
ASSERT_EQ(tensors_expected.size(), 1); \
|
||||
auto tensors = EvaluateNodes(output, item.fetch, item.feed); \
|
||||
ASSERT_EQ(tensors.size(), 1); \
|
||||
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6); \
|
||||
} \
|
||||
}
|
||||
}
|
||||
REGISTER_TEST_ALL_TYPES(FuseDepthwiseConv2DWithBiasAndActivation);
|
||||
#undef REGISTER_TEST
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
TEST_F(MklRemapperTest, FuseBatchNormWithRelu) {
|
||||
|
@ -2309,11 +2309,20 @@ REGISTER_KERNEL_BUILDER(
|
||||
.TypeConstraint<quint8>("out_type"),
|
||||
NoOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("_FusedDepthwiseConv2dNative")
|
||||
REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<float>("T"),
|
||||
.TypeConstraint<bfloat16>("T"),
|
||||
NoOp);
|
||||
|
||||
#define REGISTER_NO_OP_CPU_2D_DEPTHWISE(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("_FusedDepthwiseConv2dNative") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T"), \
|
||||
NoOp);
|
||||
|
||||
TF_CALL_float(REGISTER_NO_OP_CPU_2D_DEPTHWISE);
|
||||
TF_CALL_bfloat16(REGISTER_NO_OP_CPU_2D_DEPTHWISE);
|
||||
|
||||
// Register templatized MKL kernels for non-fused and fused-versions of
|
||||
// QuantizedDepthwiseConv2D.
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedDepthwiseConv2D")
|
||||
@ -2367,14 +2376,6 @@ REGISTER_KERNEL_BUILDER(
|
||||
MklQuantizedConv2DReluOp<CPUDevice, quint8, qint32, quint8, quint8, true,
|
||||
true>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("_MklFusedDepthwiseConv2dNative")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<float>("T")
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel),
|
||||
MklFusedDepthwiseConvOp<CPUDevice, float, float, float, float, float, int32,
|
||||
false, true, true>);
|
||||
|
||||
// Register 2D operations
|
||||
#define REGISTER_MKL_CPU_2D(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
@ -2426,13 +2427,20 @@ REGISTER_KERNEL_BUILDER(
|
||||
TF_CALL_float(REGISTER_MKL_CPU_2D);
|
||||
TF_CALL_bfloat16(REGISTER_MKL_CPU_2D);
|
||||
|
||||
#define REGISTER_MKL_CPU_2D_DEPTHWISE(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_MklDepthwiseConv2dNative") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, true, false>);
|
||||
#define REGISTER_MKL_CPU_2D_DEPTHWISE(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_MklDepthwiseConv2dNative") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, true, false>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_MklFusedDepthwiseConv2dNative") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklFusedDepthwiseConvOp<CPUDevice, T, T, T, T, T, int32, false, true, \
|
||||
true>);
|
||||
|
||||
TF_CALL_float(REGISTER_MKL_CPU_2D_DEPTHWISE);
|
||||
TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_DEPTHWISE);
|
||||
|
Loading…
Reference in New Issue
Block a user