Clean up duplicated code in matmul_op/batch_matmul_op and associated tests. Having a separate redundant implementation for non-batch matmul is a maintenance burden.

We add a registration of the batch kernel under the "MatMul" symbol (with modified shape validation) for legacy graphs.

PiperOrigin-RevId: 339715052
Change-Id: Iad428cdf62a656de56f62d996db44be682af5d13
This commit is contained in:
A. Unique TensorFlower 2020-10-29 11:52:31 -07:00 committed by TensorFlower Gardener
parent 47926752a8
commit 0e4a3676d9
16 changed files with 352 additions and 941 deletions

View File

@ -769,7 +769,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
EXPECT_NE(TF_OK, TF_GetCode(status));
EXPECT_EQ(nullptr, t);
const char* msg = "Matrix size-incompatible: In[0]: [2,2], In[1]: [3,2]";
const char* msg = "In[0] mismatch In[1] shape: 2 vs. 3: [2,2] [3,2]";
EXPECT_TRUE(strstr(TF_Message(status), msg) != nullptr)
<< TF_Message(status);
// Since error is not cleared, the following copy with correct device will

View File

@ -931,7 +931,6 @@ TEST(SessionTest, InvalidOpInputName) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
"Illegal op input name");
@ -950,7 +949,6 @@ TEST(SessionTest, InvalidOpInputName) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
"Illegal op input name");
@ -969,7 +967,6 @@ TEST(SessionTest, InvalidOpInputName) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
"Illegal op input name");
@ -988,7 +985,6 @@ TEST(SessionTest, InvalidOpInputName) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
"Illegal op input name");
@ -1026,7 +1022,6 @@ TEST(SessionTest, ExtendValidation) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
&extension);
@ -1043,7 +1038,6 @@ TEST(SessionTest, ExtendValidation) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
&extension);
@ -1057,7 +1051,6 @@ TEST(SessionTest, ExtendValidation) {
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
&extension);

View File

@ -406,7 +406,7 @@ XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
TEST(TFunc, WXPlusB) {
auto expect = R"P(
WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) {
mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x)
mm = MatMul[T=$T, transpose_a=false, transpose_b=false](w, x)
y = Add[T=$T](mm:product:0, b)
return y = y:z:0
}

View File

@ -346,10 +346,7 @@ FunctionDef WXPlusB() {
{{{"mm"},
"MatMul",
{"w", "x"},
{{"T", "$T"},
{"transpose_a", false},
{"transpose_b", false},
{"_kernel", "eigen"}}},
{{"T", "$T"}, {"transpose_a", false}, {"transpose_b", false}}},
{{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}});
}

View File

@ -48,7 +48,6 @@ load(
)
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl_ml",
"mkl_deps",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
@ -3241,7 +3240,6 @@ cc_library(
deps = [
":aggregate_ops",
":argmax_op",
":batch_matmul_op",
":betainc_op",
":bincount_op",
":bucketize_op",
@ -3337,14 +3335,27 @@ tf_kernel_library(
tf_kernel_library(
name = "batch_matmul_op",
deps = [":matmul_op"],
)
tf_kernel_library(
name = "matmul_op",
# <prefix>*impl.h are excluded by default from the CPU build, add explicitly.
hdrs = ["batch_matmul_op_impl.h"],
prefix = "batch_matmul_op",
deps = MATH_DEPS + [":eigen_contraction_kernel"] + if_mkl_ml([
"//third_party/mkl:intel_binary_blob",
]) + if_cuda_or_rocm([
"//tensorflow/core/kernels:gpu_utils",
]),
hdrs = ["matmul_op_impl.h"],
defines = select({
":xsmm": ["TENSORFLOW_USE_LIBXSMM"],
"//conditions:default": [],
}),
prefix = "matmul_op",
deps = MATH_DEPS + [
":eigen_contraction_kernel",
":fused_eigen_output_kernels",
] + select({
":xsmm": ["@libxsmm_archive//:xsmm_avx"],
"//conditions:default": [],
}) + mkl_deps() + if_cuda([
"//tensorflow/core/platform/default/build_config:cublas_plugin",
]) + if_cuda_or_rocm([":gpu_utils"]),
)
tf_kernel_library(
@ -3406,28 +3417,6 @@ tf_kernel_library(
]),
)
tf_kernel_library(
name = "matmul_op",
srcs = [
"matmul_op.cc",
"matmul_op_fused.cc",
],
hdrs = ["matmul_op.h"],
defines = select({
":xsmm": ["TENSORFLOW_USE_LIBXSMM"],
"//conditions:default": [],
}),
deps = MATH_DEPS + [
":eigen_contraction_kernel",
":fused_eigen_output_kernels",
] + select({
":xsmm": ["@libxsmm_archive//:xsmm_avx"],
"//conditions:default": [],
}) + mkl_deps() + if_cuda([
"//tensorflow/core/platform/default/build_config:cublas_plugin",
]) + if_cuda_or_rocm([":gpu_utils"]),
)
tf_kernel_library(
name = "reduction_ops",
gpu_srcs = ["reduction_gpu_kernels.cu.h"],
@ -3620,25 +3609,6 @@ tf_cuda_cc_test(
],
)
tf_cuda_cc_test(
name = "batch_matmul_op_test",
size = "small",
srcs = ["batch_matmul_op_test.cc"],
deps = [
":batch_matmul_op",
":broadcast_to_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_cuda_cc_test(
name = "scan_ops_test",
size = "small",
@ -5868,8 +5838,8 @@ filegroup(
"identity_op.h",
"immutable_constant_op.cc",
"immutable_constant_op.h",
"matmul_op.cc",
"matmul_op.h",
"matmul_op_impl.h",
"matmul_op_real.cc",
"no_op.cc",
"no_op.h",
"one_hot_op.cc",
@ -5948,7 +5918,6 @@ filegroup(
srcs = [
"argmax_op.h",
"avgpooling_op.h",
"batch_matmul_op_impl.h",
"batch_norm_op.h",
"bincount_op.h",
"broadcast_to_op.h",
@ -6039,7 +6008,6 @@ filegroup(
":android_extended_ops_headers",
"argmax_op.cc",
"avgpooling_op.cc",
"batch_matmul_op_real.cc",
"batch_norm_op.cc",
"bcast_ops.cc",
"check_numerics_op.cc",
@ -7431,7 +7399,6 @@ test_suite(
"manual", # Avoid redundancy when using wildcard test patterns.
],
tests = [
":batch_matmul_op_test",
":batch_norm_op_test",
":broadcast_to_op_test",
":cast_op_test",

View File

@ -1,257 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/broadcast_to_op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace {
Node* BroadcastTo(Graph* g, Node* input, Node* shape) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastTo")
.Input(input)
.Input(shape)
.Finalize(g, &ret));
return ret;
}
Node* BatchMatmulV2(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMulV2")
.Input(in0)
.Input(in1)
.Attr("adj_x", adj_x)
.Attr("adj_y", adj_y)
.Finalize(g, &ret));
return ret;
}
template <typename T>
static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a,
bool adjoint_b, DataType type) {
Graph* g = new Graph(OpRegistry::Global());
Tensor in0(type, adjoint_a ? TensorShape({b, k, m}) : TensorShape({b, m, k}));
in0.flat<T>().setRandom();
Tensor in1(type, adjoint_b ? TensorShape({b, n, k}) : TensorShape({b, k, n}));
in1.flat<T>().setRandom();
test::graph::BatchMatmul(g, test::graph::Constant(g, in0),
test::graph::Constant(g, in1), adjoint_a, adjoint_b);
return g;
}
template <typename T>
static Graph* BatchMatmulWithBroadcast(int b0, int b1, int m, int k, int n,
bool manual_broadcast, DataType type) {
Graph* g = new Graph(OpRegistry::Global());
Tensor in0(type, TensorShape({b0, m, k}));
in0.flat<T>().setRandom();
Tensor in1(type, TensorShape({b1, k, n}));
in1.flat<T>().setRandom();
Tensor broadcasted_in0_shape(DT_INT64, TensorShape({3}));
Tensor broadcasted_in1_shape(DT_INT64, TensorShape({3}));
Node* in0_node = nullptr;
Node* in1_node = nullptr;
if (manual_broadcast) {
for (int i = 0; i < 3; ++i) {
auto vec0 = broadcasted_in0_shape.vec<int64>();
auto vec1 = broadcasted_in1_shape.vec<int64>();
vec0(i) = (i == 0 ? std::max(b0, b1) : in0.shape().dim_size(i));
vec1(i) = (i == 0 ? std::max(b0, b1) : in1.shape().dim_size(i));
}
in0_node = BroadcastTo(g, test::graph::Constant(g, in0),
test::graph::Constant(g, broadcasted_in0_shape));
in1_node = BroadcastTo(g, test::graph::Constant(g, in1),
test::graph::Constant(g, broadcasted_in1_shape));
} else {
in0_node = test::graph::Constant(g, in0);
in1_node = test::graph::Constant(g, in1);
}
BatchMatmulV2(g, in0_node, in1_node, false, false);
return g;
}
#define BM_BatchMatmulDev(B, M, K, N, TA, TB, T, TFTYPE, DEVICE) \
static void \
BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE( \
int iters) { \
testing::UseRealTime(); \
testing::ItemsProcessed(static_cast<int64>(iters) * B * M * K * N * 2); \
test::Benchmark(#DEVICE, BatchMatmul<T>(B, M, K, N, TA, TB, TFTYPE)) \
.Run(iters); \
} \
BENCHMARK( \
BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE);
#define BM_BatchMatmul(B, M, K, N, TA, TB) \
BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, cpu);
// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
// cpu);
// BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, gpu);
/* Uncomment to enable benchmarks for double & complex types: */
// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
// gpu);
// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, cpu); \
// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, cpu);
// \
// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu); \
// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, gpu);
// Macro arguments names: --------------------------------------------------- //
// B1: batch size of LHS
// B2: batch size of RHS
// M: outer dimension of LHS
// K: inner dimensions of LHS and RHS
// N: outer dimension of RHS
// MB: boolean indicating whether to use manual broadcasting
// T: C++ type of scalars (e.g. float, std::complex)
// TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128
// D: Device (e.g. cpu, gpu)
#define BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, T, TT, D) \
static void \
BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D( \
int iters) { \
testing::UseRealTime(); \
testing::ItemsProcessed(static_cast<int64>(iters) * std::max(B1, B2) * M * \
K * N * 2); \
test::Benchmark(#D, BatchMatmulWithBroadcast<T>(B1, B2, M, K, N, MB, TT)) \
.Run(iters); \
} \
BENCHMARK( \
BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D);
#define BM_BatchMatmulBCast(B1, B2, M, K, N, MB) \
BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, float, DT_FLOAT, cpu);
// Typical fully connected layers
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, false);
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, false);
// Square matmul.
BM_BatchMatmulBCast(1, 128, 512, 512, 512, true);
BM_BatchMatmulBCast(1, 128, 512, 512, 512, false);
BM_BatchMatmulBCast(128, 1, 512, 512, 512, true);
BM_BatchMatmulBCast(128, 1, 512, 512, 512, false);
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, false);
// Matrix-vector multiplies.
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, true);
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, false);
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, true);
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, false);
// Vector-matrix multiplies.
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, true);
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, false);
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, true);
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, false);
// Typical fully connected layers
BM_BatchMatmul(1, 1, 1024, 1024, false, false);
BM_BatchMatmul(1, 8, 1024, 1024, false, false);
BM_BatchMatmul(1, 16, 1024, 1024, false, false);
BM_BatchMatmul(1, 128, 1024, 1024, false, false);
BM_BatchMatmul(2, 1, 1024, 1024, false, false);
BM_BatchMatmul(2, 8, 1024, 1024, false, false);
BM_BatchMatmul(2, 16, 1024, 1024, false, false);
BM_BatchMatmul(2, 128, 1024, 1024, false, false);
BM_BatchMatmul(8, 1, 1024, 1024, false, false);
BM_BatchMatmul(8, 8, 1024, 1024, false, false);
BM_BatchMatmul(8, 16, 1024, 1024, false, false);
BM_BatchMatmul(8, 128, 1024, 1024, false, false);
BM_BatchMatmul(32, 1, 1024, 1024, false, false);
BM_BatchMatmul(32, 8, 1024, 1024, false, false);
BM_BatchMatmul(32, 16, 1024, 1024, false, false);
BM_BatchMatmul(32, 128, 1024, 1024, false, false);
// Square matmul.
BM_BatchMatmul(1, 32, 32, 32, false, false);
BM_BatchMatmul(1, 128, 128, 128, false, false);
BM_BatchMatmul(1, 256, 256, 256, false, false);
BM_BatchMatmul(1, 1024, 1024, 1024, false, false);
BM_BatchMatmul(1, 2048, 2048, 2048, false, false);
BM_BatchMatmul(2, 32, 32, 32, false, false);
BM_BatchMatmul(2, 128, 128, 128, false, false);
BM_BatchMatmul(2, 256, 256, 256, false, false);
BM_BatchMatmul(2, 1024, 1024, 1024, false, false);
BM_BatchMatmul(2, 2048, 2048, 2048, false, false);
BM_BatchMatmul(4, 32, 32, 32, false, false);
BM_BatchMatmul(4, 128, 128, 128, false, false);
BM_BatchMatmul(4, 256, 256, 256, false, false);
BM_BatchMatmul(4, 1024, 1024, 1024, false, false);
BM_BatchMatmul(4, 2048, 2048, 2048, false, false);
BM_BatchMatmul(8, 32, 32, 32, false, false);
BM_BatchMatmul(8, 128, 128, 128, false, false);
BM_BatchMatmul(8, 256, 256, 256, false, false);
BM_BatchMatmul(8, 1024, 1024, 1024, false, false);
BM_BatchMatmul(8, 2048, 2048, 2048, false, false);
BM_BatchMatmul(32, 32, 32, 32, false, false);
BM_BatchMatmul(32, 128, 128, 128, false, false);
BM_BatchMatmul(32, 256, 256, 256, false, false);
BM_BatchMatmul(32, 1024, 1024, 1024, false, false);
BM_BatchMatmul(32, 2048, 2048, 2048, false, false);
// Matrix-vector multiplies.
BM_BatchMatmul(1, 10000, 200, 1, false, false);
BM_BatchMatmul(8, 10000, 200, 1, false, false);
BM_BatchMatmul(32, 10000, 200, 1, false, false);
BM_BatchMatmul(1, 10000, 200, 1, true, false);
BM_BatchMatmul(8, 10000, 200, 1, true, false);
BM_BatchMatmul(32, 10000, 200, 1, true, false);
BM_BatchMatmul(1, 10000, 200, 1, false, true);
BM_BatchMatmul(8, 10000, 200, 1, false, true);
BM_BatchMatmul(32, 10000, 200, 1, false, true);
BM_BatchMatmul(1, 10000, 200, 1, true, true);
BM_BatchMatmul(8, 10000, 200, 1, true, true);
BM_BatchMatmul(32, 10000, 200, 1, true, true);
// Vector-matrix multiplies.
BM_BatchMatmul(1, 1, 200, 10000, false, false);
BM_BatchMatmul(8, 1, 200, 10000, false, false);
BM_BatchMatmul(32, 1, 200, 10000, false, false);
BM_BatchMatmul(1, 1, 200, 10000, true, false);
BM_BatchMatmul(8, 1, 200, 10000, true, false);
BM_BatchMatmul(32, 1, 200, 10000, true, false);
BM_BatchMatmul(1, 1, 200, 10000, false, true);
BM_BatchMatmul(8, 1, 200, 10000, false, true);
BM_BatchMatmul(32, 1, 200, 10000, false, true);
BM_BatchMatmul(1, 1, 200, 10000, true, true);
BM_BatchMatmul(8, 1, 200, 10000, true, true);
BM_BatchMatmul(32, 1, 200, 10000, true, true);
} // namespace
} // namespace tensorflow

View File

@ -30,9 +30,9 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/linalg/einsum_op.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"
#include "tensorflow/core/kernels/reduction_ops_common.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"

View File

@ -1,567 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/math_ops.cc.
#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/matmul_op.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/util/matmul_autotune.h"
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
template <typename Device, typename T, bool USE_CUBLAS>
struct LaunchMatMul;
namespace {
// Converts a TensorFlow Tensor to an Eigen Matrix.
template <typename T>
Eigen::Map<
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
ToEigenMatrix(const Tensor& tensor) {
auto matrix = tensor.matrix<T>();
return Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>::Map(
matrix.data(), matrix.dimension(0), matrix.dimension(1));
}
// Converts a TensorFlow Tensor to an Eigen Vector.
template <typename T>
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(Tensor* tensor) {
auto v = tensor->flat<T>();
return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0));
}
template <typename T>
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(
const Tensor& tensor) {
auto v = tensor.flat<T>();
return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0));
}
} // namespace
// If either side can be represented as a vector, do an explicit vector
// matrix multiply and return true; else return false.
//
// Note: this uses plain Eigen and not Eigen Tensor because it is more
// efficient.
template <typename T>
bool ExplicitVectorMatrixOptimization(
const Tensor& a, const Tensor& b,
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
Tensor* out) {
if (out->dim_size(0) == 1) {
if (dim_pair[0].second == 0) {
// Note: this case is optimized in Eigen Tensors.
return false;
} else {
auto out_v = ToEigenVector<T>(out);
auto a_v = ToEigenVector<T>(a);
auto b_m = ToEigenMatrix<T>(b);
out_v.noalias() = b_m * a_v;
}
return true;
} else if (out->dim_size(1) == 1) {
auto out_v = ToEigenVector<T>(out);
auto a_m = ToEigenMatrix<T>(a);
auto b_v = ToEigenVector<T>(b);
if (dim_pair[0].first == 0) {
out_v.noalias() = a_m.transpose() * b_v;
} else {
out_v.noalias() = a_m * b_v;
}
return true;
}
return false;
}
// Half is not supported.
template <>
bool ExplicitVectorMatrixOptimization<Eigen::half>(
const Tensor& a, const Tensor& b,
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
Tensor* out) {
return false;
}
template <typename Device, typename T>
struct LaunchMatMulBase {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
typedef se::blas::AlgorithmType AlgorithmType;
#else
typedef int64 AlgorithmType;
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
static void launch(
OpKernelContext* ctx, const Tensor& a, const Tensor& b,
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
std::vector<AlgorithmType>* algorithms, bool use_autotune, Tensor* out) {
// An explicit vector-matrix multiply is much better optimized than an
// implicit one and this is a bottleneck during non-batched inference.
bool was_vector = ExplicitVectorMatrixOptimization<T>(a, b, dim_pair, out);
if (!was_vector) {
functor::MatMulFunctor<Device, T>()(ctx->eigen_device<Device>(),
out->matrix<T>(), a.matrix<T>(),
b.matrix<T>(), dim_pair);
}
}
static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
std::vector<int64>* algorithms,
bool* algorithm_set_flag) {}
};
// On CPUs, we ignore USE_CUBLAS
template <typename T>
struct LaunchMatMulCPU : LaunchMatMulBase<CPUDevice, T> {};
template <typename T, bool USE_CUBLAS>
struct LaunchMatMul<CPUDevice, T, USE_CUBLAS> : public LaunchMatMulCPU<T> {};
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace {
template <typename T>
struct LaunchBlasGemv {
static void Compute(OpKernelContext* ctx, se::Stream* stream, bool trans,
uint64 m, uint64 n, const se::DeviceMemory<T>& a,
const se::DeviceMemory<T>& b, se::DeviceMemory<T>* c,
se::blas::ProfileResult* output_profile) {
const auto blas_trans = trans ? se::blas::Transpose::kTranspose
: se::blas::Transpose::kNoTranspose;
if (output_profile == nullptr) {
bool blas_launch_status =
stream
->ThenBlasGemv(blas_trans, m, n, static_cast<T>(1.0), a, m, b, 1,
static_cast<T>(0.0), c, 1)
.ok();
if (!blas_launch_status) {
ctx->SetStatus(
errors::Internal("Blas GEMV launch failed: m=", m, ", n=", n));
}
} else {
bool blas_launch_status =
stream
->ThenBlasGemvWithProfiling(blas_trans, m, n, static_cast<T>(1.0),
a, m, b, 1, static_cast<T>(0.0), c, 1,
output_profile)
.ok();
if (!blas_launch_status) {
ctx->SetStatus(errors::Internal(
"Blas GEMV with profiling launch failed: m=", m, ", n=", n));
}
}
}
static bool IsSupported() { return true; }
};
template <>
void LaunchBlasGemv<Eigen::half>::Compute(
OpKernelContext* ctx, se::Stream* stream, bool trans, uint64 m, uint64 n,
const se::DeviceMemory<Eigen::half>& a,
const se::DeviceMemory<Eigen::half>& b, se::DeviceMemory<Eigen::half>* c,
se::blas::ProfileResult* output_profile) {
ctx->SetStatus(errors::Internal(
"Blas GEMV launch failed: GEMV is not implemented for float16."));
}
template <>
bool LaunchBlasGemv<Eigen::half>::IsSupported() {
return false;
}
template <typename T>
bool ShouldUseGemv(uint64 n) {
return (LaunchBlasGemv<T>::IsSupported() && n == 1);
}
} // namespace
bool GetCublasAutotuneComputationType(const DataType& dtype,
se::blas::ComputationType* compute_type) {
using se::blas::ComputationType;
switch (dtype) {
case DT_HALF:
case DT_BFLOAT16:
static bool use_f32_for_f16_computation =
MatmulDoFP32ComputationFP16Input();
if (use_f32_for_f16_computation) {
*compute_type = ComputationType::kF32;
} else {
*compute_type = ComputationType::kF16;
}
return false;
case DT_FLOAT:
*compute_type = ComputationType::kF32;
return true;
case DT_DOUBLE:
*compute_type = ComputationType::kF64;
return true;
default:
// Unsupported compute_type, return false.
return false;
}
}
// A dummy type to group matmul autotune results together.
struct MatmulAutoTuneGroup {
static string name() { return "Matmul"; }
};
typedef AutoTuneSingleton<MatmulAutoTuneGroup, MatmulParameters,
se::blas::AlgorithmConfig>
AutoTuneMatmul;
template <typename T>
struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
static void launch(
OpKernelContext* ctx, const Tensor& a, const Tensor& b,
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
std::vector<int64>* algorithms, bool use_autotune, Tensor* out) {
using se::blas::AlgorithmConfig;
using se::blas::ComputationType;
using se::blas::kDefaultAlgorithm;
using se::blas::kDefaultBlasGemm;
using se::blas::kDefaultBlasGemv;
using se::blas::kNoAlgorithm;
using se::blas::ProfileResult;
using se::blas::Transpose;
Transpose trans[] = {Transpose::kNoTranspose, Transpose::kTranspose};
const uint64 m = a.dim_size(1 - dim_pair[0].first);
const uint64 k = a.dim_size(dim_pair[0].first);
const uint64 n = b.dim_size(1 - dim_pair[0].second);
bool transpose_a = dim_pair[0].first == 0;
bool transpose_b = dim_pair[0].second == 1;
auto blas_transpose_a = trans[transpose_a];
auto blas_transpose_b = trans[transpose_b];
auto* stream = ctx->op_device_context()->stream();
OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
auto a_ptr = AsDeviceMemory(a.template flat<T>().data(),
a.template flat<T>().size());
auto b_ptr = AsDeviceMemory(b.template flat<T>().data(),
b.template flat<T>().size());
auto c_ptr = AsDeviceMemory(out->template flat<T>().data(),
out->template flat<T>().size());
auto alpha = static_cast<T>(1.0);
auto beta = static_cast<T>(0.0);
int device_id = stream->parent()->device_ordinal();
DataType dtype = a.dtype();
MatmulParameters matmul_parameters = {
transpose_a, transpose_b, m, n, k, dtype, device_id,
};
AlgorithmConfig algorithm_config(kNoAlgorithm);
ComputationType computation_type;
bool compute_type_supported =
GetCublasAutotuneComputationType(dtype, &computation_type);
if (use_autotune && compute_type_supported && !algorithms->empty()) {
ProfileResult best_result;
// TODO(yangzihao): Unify this code with conv autotuning.
if (!AutoTuneMatmul::GetInstance()->Find(matmul_parameters,
&algorithm_config)) {
ProfileResult profile_result;
for (auto profile_algorithm : (*algorithms)) {
// Cublas does
// C = A x B
// where A, B and C are assumed to be in column major.
// We want the output to be in row-major, so we can compute
// C' = B' x A' (' stands for transpose)
bool cublas_launch_status =
stream
->ThenBlasGemmWithAlgorithm(
blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
&c_ptr, n, computation_type, profile_algorithm,
&profile_result)
.ok();
if (cublas_launch_status) {
if (profile_result.is_valid()) {
if (profile_result.elapsed_time_in_ms() <
best_result.elapsed_time_in_ms()) {
best_result = profile_result;
}
}
}
}
// Try BlasGemmWithProfiling
bool cublas_launch_status =
stream
->ThenBlasGemmWithProfiling(
blas_transpose_b, blas_transpose_a, n, m, k, 1.0, b_ptr,
transpose_b ? k : n, a_ptr, transpose_a ? m : k, 0.0,
&c_ptr, n, &profile_result)
.ok();
if (cublas_launch_status) {
if (profile_result.is_valid()) {
if (profile_result.elapsed_time_in_ms() <
best_result.elapsed_time_in_ms()) {
best_result = profile_result;
}
}
}
// Try BlasGemvWithProfiling
if (ShouldUseGemv<T>(n)) {
LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
transpose_a ? m : k, transpose_a ? k : m,
a_ptr, b_ptr, &c_ptr, &profile_result);
if (profile_result.is_valid()) {
if (profile_result.elapsed_time_in_ms() <
best_result.elapsed_time_in_ms()) {
best_result = profile_result;
}
}
}
}
// We make sure that each matmul parameter set only gets one pass of
// autotune. If the best result is found, assign it to algorithm_type
// and insert it to autotune map. If all internal kernels of
// cublasGemmEx() returns invalid results, we add kNoAlgorithm to the
// autotune map.
if (best_result.is_valid()) {
algorithm_config.set_algorithm(best_result.algorithm());
}
AutoTuneMatmul::GetInstance()->Insert(matmul_parameters,
algorithm_config);
if (algorithm_config.algorithm() != kNoAlgorithm &&
algorithm_config.algorithm() != kDefaultBlasGemm &&
algorithm_config.algorithm() != kDefaultBlasGemv) {
bool cublas_launch_status =
stream
->ThenBlasGemmWithAlgorithm(
blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
&c_ptr, n, computation_type, algorithm_config.algorithm(),
nullptr)
.ok();
if (!cublas_launch_status) {
ctx->SetStatus(errors::Internal(
"Blas GEMM with algorithm launch failed : a.shape=(",
a.dim_size(0), ", ", a.dim_size(1), "), b.shape=(", b.dim_size(0),
", ", b.dim_size(1), "), m=", m, ", n=", n, ", k=", k));
}
}
}
// For the following case, we use normal BlasGemm():
// 1) We didn't set the use_autotune flag;
// 2) compute type does not support autotune;
// 3) no algorithm is found;
// 4) all internal kernels in autotune return invalid results.
// For the following case, we use normal BlasGemv():
// 1) We didn't set the use_autotune flag but LaunchBlasGemv is supported
// and n == 1.
// 2) We set the use_autotune flag and it picked up BlasGemv() and set the
// algorithm_config.algorithm() to be kDefaultBlasGemv.
if (!use_autotune || !compute_type_supported || algorithms->empty() ||
algorithm_config.algorithm() == kNoAlgorithm ||
algorithm_config.algorithm() == kDefaultBlasGemm ||
algorithm_config.algorithm() == kDefaultBlasGemv) {
if (algorithm_config.algorithm() == kDefaultBlasGemv ||
ShouldUseGemv<T>(n)) {
// This is a matrix*vector multiply so use GEMV to compute A * b.
// Here we are multiplying in the natural order, so we have to flip
// the transposition flag to compensate for the tensor being stored
// row-major.
// TODO(yangzihao): Add Gemv as an autotuning option too.
LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
transpose_a ? m : k, transpose_a ? k : m,
a_ptr, b_ptr, &c_ptr, nullptr);
} else {
// Use C' = B' x A' (' stands for transpose)
bool blas_launch_status =
stream
->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
1.0f, b_ptr, transpose_b ? k : n, a_ptr,
transpose_a ? m : k, 0.0f, &c_ptr, n)
.ok();
if (!blas_launch_status) {
ctx->SetStatus(errors::Internal(
"Blas GEMM launch failed : a.shape=(", a.dim_size(0), ", ",
a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1),
"), m=", m, ", n=", n, ", k=", k));
}
}
}
}
static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
std::vector<int64>* algorithms,
bool* algorithm_set_flag) {
if (*algorithm_set_flag == false) {
auto* stream = ctx->device()->tensorflow_gpu_device_info()->stream;
stream->parent()->GetBlasGemmAlgorithms(algorithms);
*algorithm_set_flag = true;
}
}
};
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename Device, typename T, bool USE_CUBLAS>
class MatMulOp : public OpKernel {
public:
explicit MatMulOp(OpKernelConstruction* ctx)
: OpKernel(ctx), algorithms_set_already_(false) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
LaunchMatMul<Device, T, USE_CUBLAS>::GetBlasGemmAlgorithm(
ctx, &algorithms_, &algorithms_set_already_);
use_autotune_ = MatmulAutotuneEnable();
}
void Compute(OpKernelContext* ctx) override {
const Tensor& a = ctx->input(0);
const Tensor& b = ctx->input(1);
// Check that the dimensions of the two matrices are valid.
OP_REQUIRES(
ctx, TensorShapeUtils::IsMatrix(a.shape()),
errors::InvalidArgument("In[0] is not a matrix. Instead it has shape ",
a.shape().DebugString()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsMatrix(b.shape()),
errors::InvalidArgument("In[1] is not a matrix. Instead it has shape ",
b.shape().DebugString()));
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
dim_pair[0].first = transpose_a_ ? 0 : 1;
dim_pair[0].second = transpose_b_ ? 1 : 0;
OP_REQUIRES(
ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
errors::InvalidArgument(
"Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
", In[1]: ", b.shape().DebugString()));
int a_dim_remaining = 1 - dim_pair[0].first;
int b_dim_remaining = 1 - dim_pair[0].second;
TensorShape out_shape(
{a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
Tensor* out = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
if (out->NumElements() == 0) {
// If a has shape [0, x] or b has shape [x, 0], the output shape
// is a 0-element matrix, so there is nothing to do.
return;
}
if (a.NumElements() == 0 && b.NumElements() == 0) {
// If a has shape [x, 0] and b has shape [0, y], the
// output shape is [x, y] where x and y are non-zero, so we fill
// the output with zeros.
functor::SetZeroFunctor<Device, T> f;
f(ctx->eigen_device<Device>(), out->flat<T>());
return;
}
if (std::is_same<T, bfloat16>::value) {
bool is_cpu = std::is_same<Device, CPUDevice>::value;
OP_REQUIRES(ctx, is_cpu,
errors::Internal("bfloat16 matmul is not supported by GPU"));
Tensor a_float, b_float, out_float;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, a.shape(), &a_float));
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, b.shape(), &b_float));
OP_REQUIRES_OK(ctx,
ctx->allocate_temp(DT_FLOAT, out->shape(), &out_float));
// TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
BFloat16ToFloat(a.flat<bfloat16>().data(), a_float.flat<float>().data(),
a.NumElements());
BFloat16ToFloat(b.flat<bfloat16>().data(), b_float.flat<float>().data(),
b.NumElements());
LaunchMatMul<Device, float, USE_CUBLAS>::launch(
ctx, a_float, b_float, dim_pair, &algorithms_, use_autotune_,
&out_float);
FloatToBFloat16(out_float.flat<float>().data(),
out->flat<bfloat16>().data(), out->NumElements());
} else {
LaunchMatMul<Device, T, USE_CUBLAS>::launch(
ctx, a, b, dim_pair, &algorithms_, use_autotune_, out);
}
}
private:
std::vector<int64> algorithms_;
bool algorithms_set_already_;
bool use_autotune_;
bool transpose_a_;
bool transpose_b_;
};
namespace functor {
// Partial specialization MatMulFunctor<Device=CPUDevice, T>.
template <typename T>
struct MatMulFunctor<CPUDevice, T> {
void operator()(
const CPUDevice& d, typename MatMulTypes<T>::out_type out,
typename MatMulTypes<T>::in_type in0,
typename MatMulTypes<T>::in_type in1,
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
MatMul<CPUDevice>(d, out, in0, in1, dim_pair);
}
};
} // end namespace functor
#define REGISTER_CPU_EIGEN(T) \
REGISTER_KERNEL_BUILDER( \
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen"), \
MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); \
REGISTER_CPU_EIGEN(T);
#define REGISTER_GPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("MatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
MatMulOp<GPUDevice, T, true /* cublas, true by default */>); \
REGISTER_KERNEL_BUILDER(Name("MatMul") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.Label("cublas"), \
MatMulOp<GPUDevice, T, true /* cublas */>)
TF_CALL_int32(REGISTER_CPU);
TF_CALL_int64(REGISTER_CPU);
TF_CALL_FLOAT_TYPES(REGISTER_CPU);
TF_CALL_COMPLEX_TYPES(REGISTER_CPU);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_COMPLEX_TYPES(REGISTER_GPU);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"
namespace tensorflow {

View File

@ -15,8 +15,8 @@ limitations under the License.
// See docs in ../ops/math_ops.cc.
#ifndef TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
#define EIGEN_USE_THREADS
@ -633,10 +633,21 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
template <typename Device, typename Scalar>
class BaseBatchMatMulOp : public OpKernel {
public:
explicit BaseBatchMatMulOp(OpKernelConstruction* context)
explicit BaseBatchMatMulOp(OpKernelConstruction* context,
bool is_legacy_matmul)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
if (is_legacy_matmul) {
// The old MatMul kernel has "transpose_a/transpose_b" attributes.
OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &trans_x_));
OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &trans_y_));
adj_x_ = false;
adj_y_ = false;
} else {
OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
trans_x_ = false;
trans_y_ = false;
}
}
~BaseBatchMatMulOp() override {}
@ -672,8 +683,8 @@ class BaseBatchMatMulOp : public OpKernel {
in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
errors::Internal("Failed to reshape In[1] from ",
in1.shape().DebugString()));
if (adj_x_) std::swap(d0, d1);
if (adj_y_) std::swap(d2, d3);
if (adj_x_ || trans_x_) std::swap(d0, d1);
if (adj_y_ || trans_y_) std::swap(d2, d3);
OP_REQUIRES(ctx, d1 == d2,
errors::InvalidArgument(
"In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
@ -696,9 +707,36 @@ class BaseBatchMatMulOp : public OpKernel {
out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
errors::Internal("Failed to reshape output from ",
out->shape().DebugString()));
LaunchBatchMatMul<Device, Scalar>::Launch(
ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, /*trans_x=*/false,
/*trans_y=*/false, bcast, &out_reshaped);
if (std::is_same<Scalar, bfloat16>::value) {
bool is_cpu = std::is_same<Device, CPUDevice>::value;
OP_REQUIRES(ctx, is_cpu,
errors::Internal("bfloat16 matmul is not supported by GPU"));
Tensor in0_reshaped_float, in1_reshaped_float, out_reshaped_float;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in0_reshaped.shape(),
&in0_reshaped_float));
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in1_reshaped.shape(),
&in1_reshaped_float));
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, out_reshaped.shape(),
&out_reshaped_float));
// TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
BFloat16ToFloat(in0_reshaped.flat<bfloat16>().data(),
in0_reshaped_float.flat<float>().data(),
in0_reshaped.NumElements());
BFloat16ToFloat(in1_reshaped.flat<bfloat16>().data(),
in1_reshaped_float.flat<float>().data(),
in1_reshaped.NumElements());
LaunchBatchMatMul<Device, float>::Launch(
ctx, in0_reshaped_float, in1_reshaped_float, adj_x_, adj_y_, trans_x_,
trans_y_, bcast, &out_reshaped_float);
FloatToBFloat16(out_reshaped_float.flat<float>().data(),
out_reshaped.flat<bfloat16>().data(), out->NumElements());
} else {
LaunchBatchMatMul<Device, Scalar>::Launch(ctx, in0_reshaped, in1_reshaped,
adj_x_, adj_y_, trans_x_,
trans_y_, bcast, &out_reshaped);
}
}
protected:
@ -706,16 +744,19 @@ class BaseBatchMatMulOp : public OpKernel {
const Tensor& in1) = 0;
private:
// TODO(171979567) Make the ops take both adj and transpose attributes.
bool adj_x_;
bool adj_y_;
bool trans_x_;
bool trans_y_;
};
// BatchMatMul Op implementation which disallows broadcasting.
template <typename Device, typename Scalar>
template <typename Device, typename Scalar, bool is_legacy_matmul = false>
class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> {
public:
explicit BatchMatMulOp(OpKernelConstruction* context)
: BaseBatchMatMulOp<Device, Scalar>(context) {}
: BaseBatchMatMulOp<Device, Scalar>(context, is_legacy_matmul) {}
~BatchMatMulOp() override {}
@ -729,15 +770,21 @@ class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> {
in0.shape().DebugString(), " vs. ",
in1.shape().DebugString()));
const int ndims = in0.dims();
OP_REQUIRES(
ctx, ndims >= 2,
errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
for (int i = 0; i < ndims - 2; ++i) {
OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
if (is_legacy_matmul) {
OP_REQUIRES(ctx, ndims == 2,
errors::InvalidArgument(
"In[0].dim(", i, ") and In[1].dim(", i,
") must be the same: ", in0.shape().DebugString(), " vs ",
in1.shape().DebugString()));
"In[0] and In[1] ndims must be == 2: ", ndims));
} else {
OP_REQUIRES(ctx, ndims >= 2,
errors::InvalidArgument(
"In[0] and In[1] ndims must be >= 2: ", ndims));
for (int i = 0; i < ndims - 2; ++i) {
OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
errors::InvalidArgument(
"In[0].dim(", i, ") and In[1].dim(", i,
") must be the same: ", in0.shape().DebugString(),
" vs ", in1.shape().DebugString()));
}
}
}
};
@ -747,7 +794,8 @@ template <typename Device, typename Scalar>
class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
public:
explicit BatchMatMulV2Op(OpKernelConstruction* context)
: BaseBatchMatMulOp<Device, Scalar>(context) {}
: BaseBatchMatMulOp<Device, Scalar>(context,
/* is_legacy_matmul= */ false) {}
~BatchMatMulV2Op() override {}
@ -771,7 +819,10 @@ class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
BatchMatMulOp<CPUDevice, TYPE>); \
REGISTER_KERNEL_BUILDER( \
Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
BatchMatMulV2Op<CPUDevice, TYPE>)
BatchMatMulV2Op<CPUDevice, TYPE>); \
REGISTER_KERNEL_BUILDER( \
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
BatchMatMulOp<CPUDevice, TYPE, /* is_legacy_matmul=*/true>)
#define REGISTER_BATCH_MATMUL_GPU(TYPE) \
REGISTER_KERNEL_BUILDER( \
@ -779,8 +830,11 @@ class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
BatchMatMulOp<GPUDevice, TYPE>); \
REGISTER_KERNEL_BUILDER( \
Name("BatchMatMulV2").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
BatchMatMulV2Op<GPUDevice, TYPE>)
BatchMatMulV2Op<GPUDevice, TYPE>); \
REGISTER_KERNEL_BUILDER( \
Name("MatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
BatchMatMulOp<GPUDevice, TYPE, /* is_legacy_matmul=*/true>)
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
#endif // TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
@ -21,17 +21,13 @@ limitations under the License.
namespace tensorflow {
TF_CALL_float(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_double(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_half(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_FLOAT_TYPES(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_int16(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_int64(REGISTER_BATCH_MATMUL_CPU);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_float(REGISTER_BATCH_MATMUL_GPU);
TF_CALL_double(REGISTER_BATCH_MATMUL_GPU);
TF_CALL_half(REGISTER_BATCH_MATMUL_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_BATCH_MATMUL_GPU);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/public/session.h"
namespace tensorflow {
namespace {
template <typename T>
class FusedMatMulOpTest : public OpsTestBase {
@ -459,4 +460,230 @@ BM_Matmul(2000, 1, 2000, true, false);
BM_Matmul(2000, 1, 2000, false, true);
BM_Matmul(2000, 1, 2000, true, true);
} // end namespace tensorflow
// Benchmarks for batched matmul with broadcasting.
Node* BroadcastTo(Graph* g, Node* input, Node* shape) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastTo")
.Input(input)
.Input(shape)
.Finalize(g, &ret));
return ret;
}
Node* BatchMatmulV2(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMulV2")
.Input(in0)
.Input(in1)
.Attr("adj_x", adj_x)
.Attr("adj_y", adj_y)
.Finalize(g, &ret));
return ret;
}
template <typename T>
static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a,
bool adjoint_b, DataType type) {
Graph* g = new Graph(OpRegistry::Global());
Tensor in0(type, adjoint_a ? TensorShape({b, k, m}) : TensorShape({b, m, k}));
in0.flat<T>().setRandom();
Tensor in1(type, adjoint_b ? TensorShape({b, n, k}) : TensorShape({b, k, n}));
in1.flat<T>().setRandom();
test::graph::BatchMatmul(g, test::graph::Constant(g, in0),
test::graph::Constant(g, in1), adjoint_a, adjoint_b);
return g;
}
template <typename T>
static Graph* BatchMatmulWithBroadcast(int b0, int b1, int m, int k, int n,
bool manual_broadcast, DataType type) {
Graph* g = new Graph(OpRegistry::Global());
Tensor in0(type, TensorShape({b0, m, k}));
in0.flat<T>().setRandom();
Tensor in1(type, TensorShape({b1, k, n}));
in1.flat<T>().setRandom();
Tensor broadcasted_in0_shape(DT_INT64, TensorShape({3}));
Tensor broadcasted_in1_shape(DT_INT64, TensorShape({3}));
Node* in0_node = nullptr;
Node* in1_node = nullptr;
if (manual_broadcast) {
for (int i = 0; i < 3; ++i) {
auto vec0 = broadcasted_in0_shape.vec<int64>();
auto vec1 = broadcasted_in1_shape.vec<int64>();
vec0(i) = (i == 0 ? std::max(b0, b1) : in0.shape().dim_size(i));
vec1(i) = (i == 0 ? std::max(b0, b1) : in1.shape().dim_size(i));
}
in0_node = BroadcastTo(g, test::graph::Constant(g, in0),
test::graph::Constant(g, broadcasted_in0_shape));
in1_node = BroadcastTo(g, test::graph::Constant(g, in1),
test::graph::Constant(g, broadcasted_in1_shape));
} else {
in0_node = test::graph::Constant(g, in0);
in1_node = test::graph::Constant(g, in1);
}
BatchMatmulV2(g, in0_node, in1_node, false, false);
return g;
}
#define BM_BatchMatmulDev(B, M, K, N, TA, TB, T, TFTYPE, DEVICE) \
static void \
BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE( \
int iters) { \
testing::UseRealTime(); \
testing::ItemsProcessed(static_cast<int64>(iters) * B * M * K * N * 2); \
test::Benchmark(#DEVICE, BatchMatmul<T>(B, M, K, N, TA, TB, TFTYPE)) \
.Run(iters); \
} \
BENCHMARK( \
BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE);
#define BM_BatchMatmul(B, M, K, N, TA, TB) \
BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, cpu);
// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
// cpu);
// BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, gpu);
/* Uncomment to enable benchmarks for double & complex types: */
// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
// gpu);
// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, cpu); \
// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, cpu);
// \
// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu); \
// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, gpu);
// Macro arguments names: --------------------------------------------------- //
// B1: batch size of LHS
// B2: batch size of RHS
// M: outer dimension of LHS
// K: inner dimensions of LHS and RHS
// N: outer dimension of RHS
// MB: boolean indicating whether to use manual broadcasting
// T: C++ type of scalars (e.g. float, std::complex)
// TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128
// D: Device (e.g. cpu, gpu)
#define BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, T, TT, D) \
static void \
BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D( \
int iters) { \
testing::UseRealTime(); \
testing::ItemsProcessed(static_cast<int64>(iters) * std::max(B1, B2) * M * \
K * N * 2); \
test::Benchmark(#D, BatchMatmulWithBroadcast<T>(B1, B2, M, K, N, MB, TT)) \
.Run(iters); \
} \
BENCHMARK( \
BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D);
#define BM_BatchMatmulBCast(B1, B2, M, K, N, MB) \
BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, float, DT_FLOAT, cpu);
// Typical fully connected layers
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, false);
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, false);
// Square matmul.
BM_BatchMatmulBCast(1, 128, 512, 512, 512, true);
BM_BatchMatmulBCast(1, 128, 512, 512, 512, false);
BM_BatchMatmulBCast(128, 1, 512, 512, 512, true);
BM_BatchMatmulBCast(128, 1, 512, 512, 512, false);
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, false);
// Matrix-vector multiplies.
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, true);
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, false);
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, true);
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, false);
// Vector-matrix multiplies.
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, true);
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, false);
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, true);
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, false);
// Typical fully connected layers
BM_BatchMatmul(1, 1, 1024, 1024, false, false);
BM_BatchMatmul(1, 8, 1024, 1024, false, false);
BM_BatchMatmul(1, 16, 1024, 1024, false, false);
BM_BatchMatmul(1, 128, 1024, 1024, false, false);
BM_BatchMatmul(2, 1, 1024, 1024, false, false);
BM_BatchMatmul(2, 8, 1024, 1024, false, false);
BM_BatchMatmul(2, 16, 1024, 1024, false, false);
BM_BatchMatmul(2, 128, 1024, 1024, false, false);
BM_BatchMatmul(8, 1, 1024, 1024, false, false);
BM_BatchMatmul(8, 8, 1024, 1024, false, false);
BM_BatchMatmul(8, 16, 1024, 1024, false, false);
BM_BatchMatmul(8, 128, 1024, 1024, false, false);
BM_BatchMatmul(32, 1, 1024, 1024, false, false);
BM_BatchMatmul(32, 8, 1024, 1024, false, false);
BM_BatchMatmul(32, 16, 1024, 1024, false, false);
BM_BatchMatmul(32, 128, 1024, 1024, false, false);
// Square matmul.
BM_BatchMatmul(1, 32, 32, 32, false, false);
BM_BatchMatmul(1, 128, 128, 128, false, false);
BM_BatchMatmul(1, 256, 256, 256, false, false);
BM_BatchMatmul(1, 1024, 1024, 1024, false, false);
BM_BatchMatmul(1, 2048, 2048, 2048, false, false);
BM_BatchMatmul(2, 32, 32, 32, false, false);
BM_BatchMatmul(2, 128, 128, 128, false, false);
BM_BatchMatmul(2, 256, 256, 256, false, false);
BM_BatchMatmul(2, 1024, 1024, 1024, false, false);
BM_BatchMatmul(2, 2048, 2048, 2048, false, false);
BM_BatchMatmul(4, 32, 32, 32, false, false);
BM_BatchMatmul(4, 128, 128, 128, false, false);
BM_BatchMatmul(4, 256, 256, 256, false, false);
BM_BatchMatmul(4, 1024, 1024, 1024, false, false);
BM_BatchMatmul(4, 2048, 2048, 2048, false, false);
BM_BatchMatmul(8, 32, 32, 32, false, false);
BM_BatchMatmul(8, 128, 128, 128, false, false);
BM_BatchMatmul(8, 256, 256, 256, false, false);
BM_BatchMatmul(8, 1024, 1024, 1024, false, false);
BM_BatchMatmul(8, 2048, 2048, 2048, false, false);
BM_BatchMatmul(32, 32, 32, 32, false, false);
BM_BatchMatmul(32, 128, 128, 128, false, false);
BM_BatchMatmul(32, 256, 256, 256, false, false);
BM_BatchMatmul(32, 1024, 1024, 1024, false, false);
BM_BatchMatmul(32, 2048, 2048, 2048, false, false);
// Matrix-vector multiplies.
BM_BatchMatmul(1, 10000, 200, 1, false, false);
BM_BatchMatmul(8, 10000, 200, 1, false, false);
BM_BatchMatmul(32, 10000, 200, 1, false, false);
BM_BatchMatmul(1, 10000, 200, 1, true, false);
BM_BatchMatmul(8, 10000, 200, 1, true, false);
BM_BatchMatmul(32, 10000, 200, 1, true, false);
BM_BatchMatmul(1, 10000, 200, 1, false, true);
BM_BatchMatmul(8, 10000, 200, 1, false, true);
BM_BatchMatmul(32, 10000, 200, 1, false, true);
BM_BatchMatmul(1, 10000, 200, 1, true, true);
BM_BatchMatmul(8, 10000, 200, 1, true, true);
BM_BatchMatmul(32, 10000, 200, 1, true, true);
// Vector-matrix multiplies.
BM_BatchMatmul(1, 1, 200, 10000, false, false);
BM_BatchMatmul(8, 1, 200, 10000, false, false);
BM_BatchMatmul(32, 1, 200, 10000, false, false);
BM_BatchMatmul(1, 1, 200, 10000, true, false);
BM_BatchMatmul(8, 1, 200, 10000, true, false);
BM_BatchMatmul(32, 1, 200, 10000, true, false);
BM_BatchMatmul(1, 1, 200, 10000, false, true);
BM_BatchMatmul(8, 1, 200, 10000, false, true);
BM_BatchMatmul(32, 1, 200, 10000, false, true);
BM_BatchMatmul(1, 1, 200, 10000, true, true);
BM_BatchMatmul(8, 1, 200, 10000, true, true);
BM_BatchMatmul(32, 1, 200, 10000, true, true);
} // namespace
} // namespace tensorflow

View File

@ -33,8 +33,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"
#include "tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

View File

@ -197,14 +197,16 @@ class MatMulInfixOperatorTest(test_lib.TestCase):
def testMismatchedShape(self):
with self.assertRaisesRegex(
Exception, "(Shape must be rank 2 but is rank 1|is not a matrix)"):
Exception, (r"(In\[0\] and In\[1\] has different ndims|In\[0\] "
r"ndims must be >= 2|Shape must be rank 2 but is rank 1)")):
infix_matmul(
ops.convert_to_tensor([10.0, 20.0, 30.0]),
ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]]))
def testMismatchedDimensions(self):
with self.assertRaisesRegex(
Exception, "(Dimensions must be equal|Matrix size-incompatible)"):
Exception,
r"(In\[0\] mismatch In\[1\] shape|Dimensions must be equal)"):
infix_matmul(
ops.convert_to_tensor([[10.0, 20.0, 30.0]]),
ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]]))
@ -234,9 +236,10 @@ if __name__ == "__main__":
# TF2 does not support placeholders under eager so we skip it
for use_static_shape in set([True, tf2.enabled()]):
for dtype in dtypes_to_test:
if not use_static_shape and (dtype == np.int32 or dtype == np.int64):
# TODO(rmlarsen): Re-enable this test when we have fixed the underlying
# bug in Windows (b/35935459).
if test_util.is_xla_enabled() and (dtype == np.int32 or
dtype == np.int64):
# TODO(b/171924639): Enable this test when XLA DOT supports
# integer types.
continue
for m in sizes:
for n in sizes:

View File

@ -55,8 +55,8 @@ class TensordotTest(test_lib.TestCase):
if context.executing_eagerly():
return
with self.cached_session() as sess:
with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
"Matrix size-incompatible"):
with self.assertRaisesOpError(
r"In\[0\] mismatch In\[1\] shape: 2 vs\. 3: \[2,2\] \[3,2\]"):
a_ph = array_ops.placeholder(dtypes.float32)
b_ph = array_ops.placeholder(dtypes.float32)
axes_ph = array_ops.placeholder(dtypes.int32)

View File

@ -108,15 +108,14 @@ class PrintOpFilegroupTest(test.TestCase):
ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
'rawproto', self.WriteGraphFiles(graphs), default_ops)
matmul_prefix = ''
matmul_prefix = 'Batch'
self.assertListEqual(
[
('AccumulateNV2', None), #
('BiasAdd', 'BiasOp<CPUDevice, float>'), #
('MatMul',
matmul_prefix + 'MatMulOp<CPUDevice, double, false >'), #
('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, float, false >'), #
('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, double, true>'), #
('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, float, true>'), #
('NoOp', 'NoOp'), #
('Reshape', 'ReshapeOp'), #
('_Recv', 'RecvOp'), #
@ -132,9 +131,8 @@ class PrintOpFilegroupTest(test.TestCase):
[
('AccumulateNV2', None), #
('BiasAdd', 'BiasOp<CPUDevice, float>'), #
('MatMul',
matmul_prefix + 'MatMulOp<CPUDevice, double, false >'), #
('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, float, false >'), #
('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, double, true>'), #
('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, float, true>'), #
('NoOp', 'NoOp'), #
('Reshape', 'ReshapeOp'), #
('_Recv', 'RecvOp'), #