Update tf._FusedMatMul to support BF16 element type.

Updated the _FusedMatMul op definition in the TensorFlow op registry to always support bf16, regardless of build configuration, and relaxed the corresponding TF MLIR op definition to accept bf16 operands and results.

PiperOrigin-RevId: 323516484
Change-Id: Idba319a22f2042beec5432cf696dd63207edb0dd
Andy Ly 2020-07-27 23:26:39 -07:00 committed by TensorFlower Gardener
parent e8029af3b2
commit 22b60f146e
4 changed files with 159 additions and 121 deletions
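
Taken together, the changes below mean that a _FusedMatMul NodeDef with T = DT_BFLOAT16 now passes op-registry validation in every build configuration. A minimal sketch of that check, using the standard NodeDefBuilder / ValidateNodeDef APIs; the function, node name, and input names are illustrative and not part of this commit:

// Sketch only: builds a bfloat16 _FusedMatMul NodeDef and validates it
// against the (now unconditional) "T: {bfloat16, float}" constraint.
#include <vector>

#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

Status BuildAndValidateBf16FusedMatMul(NodeDef* node_def) {
  // One fused argument (the bias) feeds the variadic `args` input;
  // NodeDefBuilder infers the `T` and `num_args` attrs from the inputs.
  std::vector<NodeDefBuilder::NodeOut> args = {{"bias", 0, DT_BFLOAT16}};
  TF_RETURN_IF_ERROR(NodeDefBuilder("fused_matmul", "_FusedMatMul")
                         .Input("lhs", 0, DT_BFLOAT16)
                         .Input("rhs", 0, DT_BFLOAT16)
                         .Input(args)
                         .Attr("fused_ops", std::vector<string>{"BiasAdd"})
                         .Finalize(node_def));
  const OpDef* op_def = nullptr;
  TF_RETURN_IF_ERROR(
      OpRegistry::Global()->LookUpOpDef("_FusedMatMul", &op_def));
  // Checks, among other things, that DT_BFLOAT16 satisfies the "T" attr.
  return ValidateNodeDef(*node_def, *op_def);
}

}  // namespace tensorflow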

View File

@@ -11661,7 +11661,7 @@ create these operators.
   TF_DerivedOperandSizeAttr num_args = TF_DerivedOperandSizeAttr<2>;
 }
 
-def TF__FusedMatMulOp : TF_Op<"_FusedMatMul", [NoSideEffect]> {
+def TF__FusedMatMulOp : TF_Op<"_FusedMatMul", [NoSideEffect, SameOperandsAndResultElementType]> {
   let summary = [{
 Performs a MatMul followed by a specified series of operations.
   }];
@@ -11687,9 +11687,9 @@ expected to create these operators.
   }];
 
   let arguments = (ins
-    F32Tensor:$a,
-    F32Tensor:$b,
-    Variadic<F32Tensor>:$args,
+    TensorOf<[BF16, F32]>:$a,
+    TensorOf<[BF16, F32]>:$b,
+    Variadic<TensorOf<[BF16, F32]>>:$args,
 
     DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
     DefaultValuedAttr<BoolAttr, "false">:$transpose_b,
@@ -11698,7 +11698,7 @@ expected to create these operators.
   );
 
   let results = (outs
-    F32Tensor:$product
+    TensorOf<[BF16, F32]>:$product
   );
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;

View File

@@ -17,6 +17,7 @@ limitations under the License.
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/grappler/devices.h"
 #include "tensorflow/core/grappler/grappler_item.h"
@@ -388,65 +389,80 @@ TEST_F(RemapperTest, FuseConv2DWithBias) {
   test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
 }
 
-TEST_F(RemapperTest, FuseMatMulWithBias) {
-  using ::tensorflow::ops::Placeholder;
-
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-
-  auto lhs_shape = ops::Placeholder::Shape({8, 32});
-  auto rhs_shape = ops::Placeholder::Shape({32, 64});
-  auto bias_shape = ops::Placeholder::Shape({64});
-
-  auto lhs = Placeholder(s.WithOpName("lhs"), DT_FLOAT, lhs_shape);
-  auto rhs = Placeholder(s.WithOpName("rhs"), DT_FLOAT, rhs_shape);
-  auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
-
-  auto matmul = ops::MatMul(s.WithOpName("matmul"), lhs, rhs);
-  auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
-  auto fetch = ops::Identity(s.WithOpName("fetch"), bias_add);
-
-  auto lhs_t = GenerateRandomTensor<DT_FLOAT>({8, 32});
-  auto rhs_t = GenerateRandomTensor<DT_FLOAT>({32, 64});
-  auto bias_t = GenerateRandomTensor<DT_FLOAT>({64});
-
-  GrapplerItem item;
-  item.fetch = {"fetch"};
-  item.feed = {{"lhs", lhs_t}, {"rhs", rhs_t}, {"bias", bias_t}};
-  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
-
-  // Place all nodes on CPU.
-  for (int i = 0; i < item.graph.node_size(); ++i) {
-    item.graph.mutable_node(i)->set_device("/device:CPU:0");
-  }
-
-  Remapper optimizer(RewriterConfig::ON);
-  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
-
-  int found = 0;
-  for (const NodeDef& node : output.node()) {
-    if (node.name() == "bias_add") {
-      EXPECT_EQ(node.op(), "_FusedMatMul");
-      ASSERT_GE(node.input_size(), 3);
-      EXPECT_EQ(node.input(0), "lhs");
-      EXPECT_EQ(node.input(1), "rhs");
-
-      EXPECT_EQ(node.attr().at("num_args").i(), 1);
-      EXPECT_EQ(node.input(2), "bias");
-
-      const auto fused_ops = node.attr().at("fused_ops").list().s();
-      ASSERT_EQ(fused_ops.size(), 1);
-      EXPECT_EQ(fused_ops[0], "BiasAdd");
-      found++;
-    }
-  }
-  EXPECT_EQ(1, found);
-
-  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
-  ASSERT_EQ(tensors_expected.size(), 1);
-  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
-  ASSERT_EQ(tensors.size(), 1);
-  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
-}
+class RemapperFuseMatMulWithBiasTest : public RemapperTest {
+ public:
+  template <DataType DTYPE>
+  void RunTest() {
+    using ::tensorflow::ops::Placeholder;
+
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+    auto lhs_shape = ops::Placeholder::Shape({8, 32});
+    auto rhs_shape = ops::Placeholder::Shape({32, 64});
+    auto bias_shape = ops::Placeholder::Shape({64});
+
+    auto lhs = Placeholder(s.WithOpName("lhs"), DTYPE, lhs_shape);
+    auto rhs = Placeholder(s.WithOpName("rhs"), DTYPE, rhs_shape);
+    auto bias = Placeholder(s.WithOpName("bias"), DTYPE, bias_shape);
+
+    auto matmul = ops::MatMul(s.WithOpName("matmul"), lhs, rhs);
+    auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
+    auto fetch = ops::Identity(s.WithOpName("fetch"), bias_add);
+
+    auto lhs_t = GenerateTensorWithSetRandom<DTYPE>({8, 32});
+    auto rhs_t = GenerateTensorWithSetRandom<DTYPE>({32, 64});
+    auto bias_t = GenerateTensorWithSetRandom<DTYPE>({64});
+
+    GrapplerItem item;
+    item.fetch = {"fetch"};
+    item.feed = {{"lhs", lhs_t}, {"rhs", rhs_t}, {"bias", bias_t}};
+    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+
+    // Place all nodes on CPU.
+    for (int i = 0; i < item.graph.node_size(); ++i) {
+      item.graph.mutable_node(i)->set_device("/device:CPU:0");
+    }
+
+    Remapper optimizer(RewriterConfig::ON);
+    GraphDef output;
+    TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+    int found = 0;
+    for (const NodeDef& node : output.node()) {
+      if (node.name() == "bias_add") {
+        EXPECT_EQ(node.op(), "_FusedMatMul");
+        ASSERT_GE(node.input_size(), 3);
+        EXPECT_EQ(node.input(0), "lhs");
+        EXPECT_EQ(node.input(1), "rhs");
+
+        EXPECT_EQ(node.attr().at("num_args").i(), 1);
+        EXPECT_EQ(node.input(2), "bias");
+
+        const auto fused_ops = node.attr().at("fused_ops").list().s();
+        ASSERT_EQ(fused_ops.size(), 1);
+        EXPECT_EQ(fused_ops[0], "BiasAdd");
+        found++;
+      }
+    }
+    EXPECT_EQ(1, found);
+
+    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+    ASSERT_EQ(tensors_expected.size(), 1);
+    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+    ASSERT_EQ(tensors.size(), 1);
+    typedef typename EnumToDataType<DTYPE>::Type T;
+    test::ExpectTensorNear<T>(tensors[0], tensors_expected[0], 1e-6);
+  }
+};
+
+TEST_F(RemapperFuseMatMulWithBiasTest, F32) { RunTest<DT_FLOAT>(); }
+
+TEST_F(RemapperFuseMatMulWithBiasTest, Bf16) {
+#if !defined(INTEL_MKL) || !defined(ENABLE_INTEL_MKL_BFLOAT16)
+  GTEST_SKIP() << "Intel MKL with bfloat16 support is not enabled, skipping "
+                  "FuseMatMulWithBias with bfloat16.";
+#endif  // !defined(INTEL_MKL) || !defined(ENABLE_INTEL_MKL_BFLOAT16)
+  RunTest<DT_BFLOAT16>();  // NOLINT
+}
 
 // TODO(b/161005848): Fix flaky test.
@@ -602,82 +618,99 @@ TEST_F(RemapperTest, FuseConv2DWithBiasAndActivation) {
   }
 }
 
-TEST_F(RemapperTest, FuseMatMulWithBiasAndActivation) {
-  using ::tensorflow::ops::Placeholder;
-
-  for (const string& activation : {"Relu", "Relu6", "Elu"}) {
-    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-
-    auto lhs_shape = ops::Placeholder::Shape({8, 32});
-    auto rhs_shape = ops::Placeholder::Shape({32, 64});
-    auto bias_shape = ops::Placeholder::Shape({64});
-
-    auto lhs = Placeholder(s.WithOpName("lhs"), DT_FLOAT, lhs_shape);
-    auto rhs = Placeholder(s.WithOpName("rhs"), DT_FLOAT, rhs_shape);
-    auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
-
-    auto matmul = ops::MatMul(s.WithOpName("matmul"), lhs, rhs);
-    auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
-
-    ops::Identity fetch = [&]() -> ops::Identity {
-      auto activate = s.WithOpName("activation");
-      auto fetch = s.WithOpName("fetch");
-
-      if (activation == "Relu") {
-        return ops::Identity(fetch, ops::Relu(activate, bias_add));
-      } else if (activation == "Relu6") {
-        return ops::Identity(fetch, ops::Relu6(activate, bias_add));
-      } else if (activation == "Elu") {
-        return ops::Identity(fetch, ops::Elu(activate, bias_add));
-      }
-
-      return ops::Identity(fetch, bias);
-    }();
-
-    auto lhs_t = GenerateRandomTensor<DT_FLOAT>({8, 32});
-    auto rhs_t = GenerateRandomTensor<DT_FLOAT>({32, 64});
-    auto bias_t = GenerateRandomTensor<DT_FLOAT>({64});
-
-    GrapplerItem item;
-    item.fetch = {"fetch"};
-    item.feed = {{"lhs", lhs_t}, {"rhs", rhs_t}, {"bias", bias_t}};
-    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
-
-    // Place all nodes on CPU.
-    for (int i = 0; i < item.graph.node_size(); ++i) {
-      item.graph.mutable_node(i)->set_device("/device:CPU:0");
-    }
-
-    Remapper optimizer(RewriterConfig::ON);
-    GraphDef output;
-    TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
-
-    int found = 0;
-    for (const NodeDef& node : output.node()) {
-      if (node.name() == "activation") {
-        EXPECT_EQ(node.op(), "_FusedMatMul");
-        ASSERT_GE(node.input_size(), 3);
-        EXPECT_EQ(node.input(0), "lhs");
-        EXPECT_EQ(node.input(1), "rhs");
-
-        EXPECT_EQ(node.attr().at("num_args").i(), 1);
-        EXPECT_EQ(node.input(2), "bias");
-
-        const auto fused_ops = node.attr().at("fused_ops").list().s();
-        ASSERT_EQ(fused_ops.size(), 2);
-        EXPECT_EQ(fused_ops[0], "BiasAdd");
-        EXPECT_EQ(fused_ops[1], activation);
-        found++;
-      }
-    }
-    EXPECT_EQ(1, found);
-
-    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
-    ASSERT_EQ(tensors_expected.size(), 1);
-    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
-    ASSERT_EQ(tensors.size(), 1);
-    test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
-  }
-}
+class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
+ public:
+  template <DataType DTYPE>
+  void RunTest() {
+    using ::tensorflow::ops::Placeholder;
+
+    for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+      tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+      auto lhs_shape = ops::Placeholder::Shape({8, 32});
+      auto rhs_shape = ops::Placeholder::Shape({32, 64});
+      auto bias_shape = ops::Placeholder::Shape({64});
+
+      auto lhs = Placeholder(s.WithOpName("lhs"), DTYPE, lhs_shape);
+      auto rhs = Placeholder(s.WithOpName("rhs"), DTYPE, rhs_shape);
+      auto bias = Placeholder(s.WithOpName("bias"), DTYPE, bias_shape);
+
+      auto matmul = ops::MatMul(s.WithOpName("matmul"), lhs, rhs);
+      auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
+
+      ops::Identity fetch = [&]() -> ops::Identity {
+        auto activate = s.WithOpName("activation");
+        auto fetch = s.WithOpName("fetch");
+
+        if (activation == "Relu") {
+          return ops::Identity(fetch, ops::Relu(activate, bias_add));
+        } else if (activation == "Relu6") {
+          return ops::Identity(fetch, ops::Relu6(activate, bias_add));
+        } else if (activation == "Elu") {
+          return ops::Identity(fetch, ops::Elu(activate, bias_add));
+        }
+
+        return ops::Identity(fetch, bias);
+      }();
+
+      auto lhs_t = GenerateTensorWithSetRandom<DTYPE>({8, 32});
+      auto rhs_t = GenerateTensorWithSetRandom<DTYPE>({32, 64});
+      auto bias_t = GenerateTensorWithSetRandom<DTYPE>({64});
+
+      GrapplerItem item;
+      item.fetch = {"fetch"};
+      item.feed = {{"lhs", lhs_t}, {"rhs", rhs_t}, {"bias", bias_t}};
+      TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+
+      // Place all nodes on CPU.
+      for (int i = 0; i < item.graph.node_size(); ++i) {
+        item.graph.mutable_node(i)->set_device("/device:CPU:0");
+      }
+
+      Remapper optimizer(RewriterConfig::ON);
+      GraphDef output;
+      TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+      int found = 0;
+      for (const NodeDef& node : output.node()) {
+        if (node.name() == "activation") {
+          EXPECT_EQ(node.op(), "_FusedMatMul");
+          ASSERT_GE(node.input_size(), 3);
+          EXPECT_EQ(node.input(0), "lhs");
+          EXPECT_EQ(node.input(1), "rhs");
+
+          EXPECT_EQ(node.attr().at("num_args").i(), 1);
+          EXPECT_EQ(node.input(2), "bias");
+
+          const auto fused_ops = node.attr().at("fused_ops").list().s();
+          ASSERT_EQ(fused_ops.size(), 2);
+          EXPECT_EQ(fused_ops[0], "BiasAdd");
+          EXPECT_EQ(fused_ops[1], activation);
+          found++;
+        }
+      }
+      EXPECT_EQ(1, found);
+
+      auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+      ASSERT_EQ(tensors_expected.size(), 1);
+      auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+      ASSERT_EQ(tensors.size(), 1);
+      typedef typename EnumToDataType<DTYPE>::Type T;
+      test::ExpectTensorNear<T>(tensors[0], tensors_expected[0], 1e-6);
+    }
+  }
+};
+
+TEST_F(RemapperFuseMatMulWithBiasAndActivationTest, F32) {
+  RunTest<DT_FLOAT>();
+}
+
+TEST_F(RemapperFuseMatMulWithBiasAndActivationTest, Bf16) {
+#if !defined(INTEL_MKL) || !defined(ENABLE_INTEL_MKL_BFLOAT16)
+  GTEST_SKIP() << "Intel MKL with bfloat16 support is not enabled, skipping "
+                  "FuseMatMulWithBiasAndActivation with bfloat16.";
+#endif  // !defined(INTEL_MKL) || !defined(ENABLE_INTEL_MKL_BFLOAT16)
+  RunTest<DT_BFLOAT16>();  // NOLINT
+}
 
 #ifndef INTEL_MKL

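The refactored tests above use a fixture-template idiom: RunTest is parameterized on the DataType enum value, and EnumToDataType<DTYPE>::Type recovers the corresponding C++ scalar type for the tensor comparison. A small self-contained sketch of that type mapping (illustrative, not from the commit; CheckScalarTypeMapping is a hypothetical name):

#include "tensorflow/core/framework/types.h"

namespace tensorflow {

template <DataType DTYPE>
void CheckScalarTypeMapping() {
  // e.g. DT_BFLOAT16 -> tensorflow::bfloat16, DT_FLOAT -> float.
  typedef typename EnumToDataType<DTYPE>::Type T;
  static_assert(DataTypeToEnum<T>::value == DTYPE,
                "EnumToDataType and DataTypeToEnum are inverse mappings");
}

// Instantiate once per dtype under test, mirroring the F32/Bf16 test pair.
template void CheckScalarTypeMapping<DT_FLOAT>();
template void CheckScalarTypeMapping<DT_BFLOAT16>();

}  // namespace tensorflow
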
View File

@@ -89,6 +89,15 @@ class GrapplerTest : public ::testing::Test {
     return tensor;
   }
 
+  // Creates a random tensor with given shape using `setRandom`.
+  template <DataType DTYPE>
+  Tensor GenerateTensorWithSetRandom(const TensorShape& shape) const {
+    typedef typename EnumToDataType<DTYPE>::Type T;
+    Tensor tensor(DTYPE, shape);
+    tensor.flat<T>().setRandom();
+    return tensor;
+  }
+
   // Get a constant tensor with given shape.
   template <DataType DTYPE>
   Tensor GenerateConstantTensor(

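A note on why the helper above was added, stated as an assumption: the pre-existing GenerateRandomTensor presumably relies on a random generator that is not defined for bfloat16, while Eigen's setRandom() is, so DT_BFLOAT16 tensors are filled through setRandom instead. A minimal standalone sketch of the same pattern (MakeRandomBf16Tensor is a hypothetical name):

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {

Tensor MakeRandomBf16Tensor(const TensorShape& shape) {
  Tensor t(DT_BFLOAT16, shape);
  // setRandom() fills the flattened tensor with random bfloat16 values.
  t.flat<bfloat16>().setRandom();
  return t;
}

}  // namespace tensorflow
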
View File

@@ -952,11 +952,7 @@ REGISTER_OP("_FusedMatMul")
     .Output("product: T")
     .Attr("transpose_a: bool = false")
     .Attr("transpose_b: bool = false")
-#if defined(INTEL_MKL) && defined(ENABLE_INTEL_MKL_BFLOAT16)
     .Attr("T: {bfloat16, float}")
-#else
-    .Attr("T: {float}")
-#endif
     .Attr("num_args: int >= 0")
     .Attr("fused_ops: list(string) = []")
     // Attributes for the FusedBatchNorm ----------- //
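
One caveat, stated as an assumption rather than something this commit changes: the op definition now always admits bfloat16, but whether a kernel for a bf16 _FusedMatMul is actually registered remains build-dependent, which is why the new tests skip when Intel MKL bfloat16 support is disabled. A hedged sketch of probing kernel availability (HasCpuKernelFor is a hypothetical helper):

#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {

// Returns true if some CPU kernel is registered for `node_def`'s op and
// attrs (e.g. a _FusedMatMul NodeDef with T = DT_BFLOAT16).
bool HasCpuKernelFor(const NodeDef& node_def) {
  const KernelDef* kernel_def = nullptr;
  return FindKernelDef(DeviceType(DEVICE_CPU), node_def, &kernel_def,
                       /*kernel_class_name=*/nullptr)
      .ok();
}

}  // namespace tensorflow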