Update tf._FusedMatMul to support BF16 element type.
Updated the _FusedMatMul op definition in the TensorFlow op registry to always support bf16, regardless of build configuration.

PiperOrigin-RevId: 323516484
Change-Id: Idba319a22f2042beec5432cf696dd63207edb0dd
parent: e8029af3b2
commit: 22b60f146e
@@ -11661,7 +11661,7 @@ create these operators.
   TF_DerivedOperandSizeAttr num_args = TF_DerivedOperandSizeAttr<2>;
 }
 
-def TF__FusedMatMulOp : TF_Op<"_FusedMatMul", [NoSideEffect]> {
+def TF__FusedMatMulOp : TF_Op<"_FusedMatMul", [NoSideEffect, SameOperandsAndResultElementType]> {
   let summary = [{
 Performs a MatMul followed by a specified series of operations.
   }];
@@ -11687,9 +11687,9 @@ expected to create these operators.
   }];
 
   let arguments = (ins
-    F32Tensor:$a,
-    F32Tensor:$b,
-    Variadic<F32Tensor>:$args,
+    TensorOf<[BF16, F32]>:$a,
+    TensorOf<[BF16, F32]>:$b,
+    Variadic<TensorOf<[BF16, F32]>>:$args,
 
     DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
     DefaultValuedAttr<BoolAttr, "false">:$transpose_b,
@@ -11698,7 +11698,7 @@ expected to create these operators.
   );
 
   let results = (outs
-    F32Tensor:$product
+    TensorOf<[BF16, F32]>:$product
   );
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
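A note on the ODS pieces above: `TensorOf<[BF16, F32]>` lets each operand and the result be either bf16 or f32, `TF_DerivedOperandTypeAttr<0>` keeps deriving the `T` attribute from the first operand's element type, and the newly added `SameOperandsAndResultElementType` trait is what stops the widened operands from mixing element types. A conceptual C++ sketch of the invariant the trait enforces (illustrative only, not the real MLIR verifier):

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Stand-in for SameOperandsAndResultElementType: every operand and result
// must share one element type, e.g. {"bf16", "bf16", "bf16"} passes while
// {"bf16", "f32", "bf16"} fails.
bool SameElementType(const std::vector<std::string>& types) {
  if (types.empty()) return true;
  return std::all_of(types.begin(), types.end(),
                     [&](const std::string& t) { return t == types.front(); });
}
```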
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/grappler/devices.h"
 #include "tensorflow/core/grappler/grappler_item.h"
@@ -388,65 +389,80 @@ TEST_F(RemapperTest, FuseConv2DWithBias) {
   test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
 }
 
-TEST_F(RemapperTest, FuseMatMulWithBias) {
-  using ::tensorflow::ops::Placeholder;
+class RemapperFuseMatMulWithBiasTest : public RemapperTest {
+ public:
+  template <DataType DTYPE>
+  void RunTest() {
+    using ::tensorflow::ops::Placeholder;
 
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
 
-  auto lhs_shape = ops::Placeholder::Shape({8, 32});
-  auto rhs_shape = ops::Placeholder::Shape({32, 64});
-  auto bias_shape = ops::Placeholder::Shape({64});
+    auto lhs_shape = ops::Placeholder::Shape({8, 32});
+    auto rhs_shape = ops::Placeholder::Shape({32, 64});
+    auto bias_shape = ops::Placeholder::Shape({64});
 
-  auto lhs = Placeholder(s.WithOpName("lhs"), DT_FLOAT, lhs_shape);
-  auto rhs = Placeholder(s.WithOpName("rhs"), DT_FLOAT, rhs_shape);
-  auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
+    auto lhs = Placeholder(s.WithOpName("lhs"), DTYPE, lhs_shape);
+    auto rhs = Placeholder(s.WithOpName("rhs"), DTYPE, rhs_shape);
+    auto bias = Placeholder(s.WithOpName("bias"), DTYPE, bias_shape);
 
-  auto matmul = ops::MatMul(s.WithOpName("matmul"), lhs, rhs);
-  auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
-  auto fetch = ops::Identity(s.WithOpName("fetch"), bias_add);
+    auto matmul = ops::MatMul(s.WithOpName("matmul"), lhs, rhs);
+    auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
+    auto fetch = ops::Identity(s.WithOpName("fetch"), bias_add);
 
-  auto lhs_t = GenerateRandomTensor<DT_FLOAT>({8, 32});
-  auto rhs_t = GenerateRandomTensor<DT_FLOAT>({32, 64});
-  auto bias_t = GenerateRandomTensor<DT_FLOAT>({64});
+    auto lhs_t = GenerateTensorWithSetRandom<DTYPE>({8, 32});
+    auto rhs_t = GenerateTensorWithSetRandom<DTYPE>({32, 64});
+    auto bias_t = GenerateTensorWithSetRandom<DTYPE>({64});
 
-  GrapplerItem item;
-  item.fetch = {"fetch"};
-  item.feed = {{"lhs", lhs_t}, {"rhs", rhs_t}, {"bias", bias_t}};
-  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+    GrapplerItem item;
+    item.fetch = {"fetch"};
+    item.feed = {{"lhs", lhs_t}, {"rhs", rhs_t}, {"bias", bias_t}};
+    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
 
-  // Place all nodes on CPU.
-  for (int i = 0; i < item.graph.node_size(); ++i) {
-    item.graph.mutable_node(i)->set_device("/device:CPU:0");
-  }
+    // Place all nodes on CPU.
+    for (int i = 0; i < item.graph.node_size(); ++i) {
+      item.graph.mutable_node(i)->set_device("/device:CPU:0");
+    }
 
-  Remapper optimizer(RewriterConfig::ON);
-  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+    Remapper optimizer(RewriterConfig::ON);
+    GraphDef output;
+    TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
 
-  int found = 0;
-  for (const NodeDef& node : output.node()) {
-    if (node.name() == "bias_add") {
-      EXPECT_EQ(node.op(), "_FusedMatMul");
-      ASSERT_GE(node.input_size(), 3);
-      EXPECT_EQ(node.input(0), "lhs");
-      EXPECT_EQ(node.input(1), "rhs");
-
-      EXPECT_EQ(node.attr().at("num_args").i(), 1);
-      EXPECT_EQ(node.input(2), "bias");
-
-      const auto fused_ops = node.attr().at("fused_ops").list().s();
-      ASSERT_EQ(fused_ops.size(), 1);
-      EXPECT_EQ(fused_ops[0], "BiasAdd");
-      found++;
-    }
-  }
-  EXPECT_EQ(1, found);
+    int found = 0;
+    for (const NodeDef& node : output.node()) {
+      if (node.name() == "bias_add") {
+        EXPECT_EQ(node.op(), "_FusedMatMul");
+        ASSERT_GE(node.input_size(), 3);
+        EXPECT_EQ(node.input(0), "lhs");
+        EXPECT_EQ(node.input(1), "rhs");
+
+        EXPECT_EQ(node.attr().at("num_args").i(), 1);
+        EXPECT_EQ(node.input(2), "bias");
+
+        const auto fused_ops = node.attr().at("fused_ops").list().s();
+        ASSERT_EQ(fused_ops.size(), 1);
+        EXPECT_EQ(fused_ops[0], "BiasAdd");
+        found++;
+      }
+    }
+    EXPECT_EQ(1, found);
 
-  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
-  ASSERT_EQ(tensors_expected.size(), 1);
-  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
-  ASSERT_EQ(tensors.size(), 1);
-  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
-}
+    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+    ASSERT_EQ(tensors_expected.size(), 1);
+    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+    ASSERT_EQ(tensors.size(), 1);
+    typedef typename EnumToDataType<DTYPE>::Type T;
+    test::ExpectTensorNear<T>(tensors[0], tensors_expected[0], 1e-6);
+  }
+};
+
+TEST_F(RemapperFuseMatMulWithBiasTest, F32) { RunTest<DT_FLOAT>(); }
+
+TEST_F(RemapperFuseMatMulWithBiasTest, Bf16) {
+#if !defined(INTEL_MKL) || !defined(ENABLE_INTEL_MKL_BFLOAT16)
+  GTEST_SKIP() << "Intel MKL with bfloat16 support is not enabled, skipping "
+                  "FuseMatMulWithBias with bfloat16.";
+#endif  // !defined(INTEL_MKL) || !defined(ENABLE_INTEL_MKL_BFLOAT16)
+  RunTest<DT_BFLOAT16>();  // NOLINT
+}
 
 // TODO(b/161005848): Fix flaky test.
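The refactoring above follows a common googletest pattern: the test body moves into a fixture method templated on the `DataType` enum, and one thin `TEST_F` per element type instantiates it. A minimal, self-contained sketch of the pattern (hypothetical fixture and stand-in values, not the TensorFlow code):

```cpp
#include "gtest/gtest.h"

// Hypothetical fixture mirroring the shape of RemapperFuseMatMulWithBiasTest:
// the body is written once, templated on a non-type enum parameter.
class DtypeParamSketch : public ::testing::Test {
 public:
  template <int DTYPE>  // stands in for the tensorflow::DataType enum
  void RunTest() {
    // Real assertions would build and optimize a graph here.
    EXPECT_GT(DTYPE, 0);
  }
};

// One thin wrapper per element type; in TensorFlow's types.proto,
// DT_FLOAT is 1 and DT_BFLOAT16 is 14.
TEST_F(DtypeParamSketch, F32) { RunTest<1>(); }
TEST_F(DtypeParamSketch, Bf16) { RunTest<14>(); }
```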
@@ -602,82 +618,99 @@ TEST_F(RemapperTest, FuseConv2DWithBiasAndActivation) {
   }
 }
 
-TEST_F(RemapperTest, FuseMatMulWithBiasAndActivation) {
-  using ::tensorflow::ops::Placeholder;
+class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
+ public:
+  template <DataType DTYPE>
+  void RunTest() {
+    using ::tensorflow::ops::Placeholder;
 
-  for (const string& activation : {"Relu", "Relu6", "Elu"}) {
-    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+    for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+      tensorflow::Scope s = tensorflow::Scope::NewRootScope();
 
-    auto lhs_shape = ops::Placeholder::Shape({8, 32});
-    auto rhs_shape = ops::Placeholder::Shape({32, 64});
-    auto bias_shape = ops::Placeholder::Shape({64});
+      auto lhs_shape = ops::Placeholder::Shape({8, 32});
+      auto rhs_shape = ops::Placeholder::Shape({32, 64});
+      auto bias_shape = ops::Placeholder::Shape({64});
 
-    auto lhs = Placeholder(s.WithOpName("lhs"), DT_FLOAT, lhs_shape);
-    auto rhs = Placeholder(s.WithOpName("rhs"), DT_FLOAT, rhs_shape);
-    auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
+      auto lhs = Placeholder(s.WithOpName("lhs"), DTYPE, lhs_shape);
+      auto rhs = Placeholder(s.WithOpName("rhs"), DTYPE, rhs_shape);
+      auto bias = Placeholder(s.WithOpName("bias"), DTYPE, bias_shape);
 
-    auto matmul = ops::MatMul(s.WithOpName("matmul"), lhs, rhs);
-    auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
+      auto matmul = ops::MatMul(s.WithOpName("matmul"), lhs, rhs);
+      auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
 
-    ops::Identity fetch = [&]() -> ops::Identity {
-      auto activate = s.WithOpName("activation");
-      auto fetch = s.WithOpName("fetch");
-
-      if (activation == "Relu") {
-        return ops::Identity(fetch, ops::Relu(activate, bias_add));
-      } else if (activation == "Relu6") {
-        return ops::Identity(fetch, ops::Relu6(activate, bias_add));
-      } else if (activation == "Elu") {
-        return ops::Identity(fetch, ops::Elu(activate, bias_add));
-      }
-
-      return ops::Identity(fetch, bias);
-    }();
+      ops::Identity fetch = [&]() -> ops::Identity {
+        auto activate = s.WithOpName("activation");
+        auto fetch = s.WithOpName("fetch");
+
+        if (activation == "Relu") {
+          return ops::Identity(fetch, ops::Relu(activate, bias_add));
+        } else if (activation == "Relu6") {
+          return ops::Identity(fetch, ops::Relu6(activate, bias_add));
+        } else if (activation == "Elu") {
+          return ops::Identity(fetch, ops::Elu(activate, bias_add));
+        }
+
+        return ops::Identity(fetch, bias);
+      }();
 
-    auto lhs_t = GenerateRandomTensor<DT_FLOAT>({8, 32});
-    auto rhs_t = GenerateRandomTensor<DT_FLOAT>({32, 64});
-    auto bias_t = GenerateRandomTensor<DT_FLOAT>({64});
+      auto lhs_t = GenerateTensorWithSetRandom<DTYPE>({8, 32});
+      auto rhs_t = GenerateTensorWithSetRandom<DTYPE>({32, 64});
+      auto bias_t = GenerateTensorWithSetRandom<DTYPE>({64});
 
-    GrapplerItem item;
-    item.fetch = {"fetch"};
-    item.feed = {{"lhs", lhs_t}, {"rhs", rhs_t}, {"bias", bias_t}};
-    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+      GrapplerItem item;
+      item.fetch = {"fetch"};
+      item.feed = {{"lhs", lhs_t}, {"rhs", rhs_t}, {"bias", bias_t}};
+      TF_ASSERT_OK(s.ToGraphDef(&item.graph));
 
-    // Place all nodes on CPU.
-    for (int i = 0; i < item.graph.node_size(); ++i) {
-      item.graph.mutable_node(i)->set_device("/device:CPU:0");
-    }
+      // Place all nodes on CPU.
+      for (int i = 0; i < item.graph.node_size(); ++i) {
+        item.graph.mutable_node(i)->set_device("/device:CPU:0");
+      }
 
-    Remapper optimizer(RewriterConfig::ON);
-    GraphDef output;
-    TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+      Remapper optimizer(RewriterConfig::ON);
+      GraphDef output;
+      TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
 
-    int found = 0;
-    for (const NodeDef& node : output.node()) {
-      if (node.name() == "activation") {
-        EXPECT_EQ(node.op(), "_FusedMatMul");
-        ASSERT_GE(node.input_size(), 3);
-        EXPECT_EQ(node.input(0), "lhs");
-        EXPECT_EQ(node.input(1), "rhs");
-
-        EXPECT_EQ(node.attr().at("num_args").i(), 1);
-        EXPECT_EQ(node.input(2), "bias");
-
-        const auto fused_ops = node.attr().at("fused_ops").list().s();
-        ASSERT_EQ(fused_ops.size(), 2);
-        EXPECT_EQ(fused_ops[0], "BiasAdd");
-        EXPECT_EQ(fused_ops[1], activation);
-        found++;
-      }
-    }
-    EXPECT_EQ(1, found);
+      int found = 0;
+      for (const NodeDef& node : output.node()) {
+        if (node.name() == "activation") {
+          EXPECT_EQ(node.op(), "_FusedMatMul");
+          ASSERT_GE(node.input_size(), 3);
+          EXPECT_EQ(node.input(0), "lhs");
+          EXPECT_EQ(node.input(1), "rhs");
+
+          EXPECT_EQ(node.attr().at("num_args").i(), 1);
+          EXPECT_EQ(node.input(2), "bias");
+
+          const auto fused_ops = node.attr().at("fused_ops").list().s();
+          ASSERT_EQ(fused_ops.size(), 2);
+          EXPECT_EQ(fused_ops[0], "BiasAdd");
+          EXPECT_EQ(fused_ops[1], activation);
+          found++;
+        }
+      }
+      EXPECT_EQ(1, found);
 
-    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
-    ASSERT_EQ(tensors_expected.size(), 1);
-    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
-    ASSERT_EQ(tensors.size(), 1);
-    test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
-  }
-}
+      auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+      ASSERT_EQ(tensors_expected.size(), 1);
+      auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+      ASSERT_EQ(tensors.size(), 1);
+      typedef typename EnumToDataType<DTYPE>::Type T;
+      test::ExpectTensorNear<T>(tensors[0], tensors_expected[0], 1e-6);
+    }
+  }
+};
+
+TEST_F(RemapperFuseMatMulWithBiasAndActivationTest, F32) {
+  RunTest<DT_FLOAT>();
+}
+
+TEST_F(RemapperFuseMatMulWithBiasAndActivationTest, Bf16) {
+#if !defined(INTEL_MKL) || !defined(ENABLE_INTEL_MKL_BFLOAT16)
+  GTEST_SKIP() << "Intel MKL with bfloat16 support is not enabled, skipping "
+                  "FuseMatMulWithBiasAndActivation with bfloat16.";
+#endif  // !defined(INTEL_MKL) || !defined(ENABLE_INTEL_MKL_BFLOAT16)
+  RunTest<DT_BFLOAT16>();  // NOLINT
+}
 
 #ifndef INTEL_MKL
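Both `Bf16` variants guard themselves with `GTEST_SKIP()`, which aborts the current test body and reports the test as skipped rather than failed on builds without MKL bfloat16 support. A small sketch of that guard in isolation (the build flag here is hypothetical, for illustration only):

```cpp
#include "gtest/gtest.h"

TEST(SkipGuardSketch, RunsOnlyWhenFeatureIsCompiledIn) {
#if !defined(MY_FEATURE_ENABLED)  // hypothetical flag, for illustration only
  GTEST_SKIP() << "feature not compiled in; skipping.";
#endif
  // Reached only when the feature is available.
  SUCCEED();
}
```

Note that `GTEST_SKIP()` requires googletest 1.10 or newer.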
@@ -89,6 +89,15 @@ class GrapplerTest : public ::testing::Test {
     return tensor;
   }
 
+  // Creates a random tensor with given shape using `setRandom`.
+  template <DataType DTYPE>
+  Tensor GenerateTensorWithSetRandom(const TensorShape& shape) const {
+    typedef typename EnumToDataType<DTYPE>::Type T;
+    Tensor tensor(DTYPE, shape);
+    tensor.flat<T>().setRandom();
+    return tensor;
+  }
+
   // Get a constant tensor with given shape.
   template <DataType DTYPE>
   Tensor GenerateConstantTensor(
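The helper takes the `DataType` enum as a template parameter, derives the C++ element type through `EnumToDataType` (e.g. `DT_BFLOAT16` resolves to `bfloat16`), and fills the flat buffer with Eigen's `setRandom()`, which works uniformly across the supported element types. A hedged sketch of a call site inside a test that derives from `GrapplerTest`:

```cpp
// Hypothetical call sites, mirroring the remapper test above:
auto lhs_t = GenerateTensorWithSetRandom<DT_BFLOAT16>({8, 32});
auto rhs_t = GenerateTensorWithSetRandom<DT_BFLOAT16>({32, 64});

// Rough expansion for DT_BFLOAT16 (sketch):
//   typedef bfloat16 T;            // EnumToDataType<DT_BFLOAT16>::Type
//   Tensor t(DT_BFLOAT16, {8, 32});
//   t.flat<T>().setRandom();       // Eigen fills the buffer with randoms
```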
@@ -952,11 +952,7 @@ REGISTER_OP("_FusedMatMul")
     .Output("product: T")
     .Attr("transpose_a: bool = false")
     .Attr("transpose_b: bool = false")
-#if defined(INTEL_MKL) && defined(ENABLE_INTEL_MKL_BFLOAT16)
     .Attr("T: {bfloat16, float}")
-#else
-    .Attr("T: {float}")
-#endif
     .Attr("num_args: int >= 0")
     .Attr("fused_ops: list(string) = []")
     // Attributes for the FusedBatchNorm ----------- //
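With the registry constraint now unconditional, a `_FusedMatMul` NodeDef with `T = DT_BFLOAT16` satisfies the op definition on every build, not just MKL bfloat16 builds (whether a bf16 CPU kernel is actually registered remains build-dependent). A hedged sketch of constructing such a node with `NodeDefBuilder` (the function and input node names are illustrative):

```cpp
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"

tensorflow::Status BuildBf16FusedMatMul(tensorflow::NodeDef* def) {
  using tensorflow::DT_BFLOAT16;
  using tensorflow::NodeDefBuilder;
  // "lhs", "rhs", and "bias" are illustrative producer-node names.
  return NodeDefBuilder("fused_matmul", "_FusedMatMul")
      .Input("lhs", 0, DT_BFLOAT16)
      .Input("rhs", 0, DT_BFLOAT16)
      .Input({{"bias", 0, DT_BFLOAT16}})  // the variadic `args`; num_args = 1
      .Attr("num_args", 1)
      .Attr("fused_ops", std::vector<std::string>{"BiasAdd"})
      .Finalize(def);  // transpose_a/transpose_b keep their defaults
}
```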