Clean up duplicated code in matmul_op/batch_matmul_op and the associated tests. Keeping a separate, redundant implementation for non-batch matmul is a maintenance burden.
We now register the batch kernel under the "MatMul" symbol (with modified shape validation) so that legacy graphs keep working.

PiperOrigin-RevId: 339715052
Change-Id: Iad428cdf62a656de56f62d996db44be682af5d13
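For orientation before the diff: the dedicated MatMulOp is deleted, and BatchMatMulOp gains an is_legacy_matmul template flag under which it reads the old transpose_a/transpose_b attributes (instead of adj_x/adj_y) and requires rank-2 inputs. A minimal sketch of the resulting registration pattern, condensed from the REGISTER_BATCH_MATMUL_CPU macro in the diff below (illustration only, not the full macro):

    // Regular batch kernel: adj_x/adj_y attrs, ndims >= 2 with matching
    // leading batch dims.
    REGISTER_KERNEL_BUILDER(
        Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<float>("T"),
        BatchMatMulOp<CPUDevice, float>);
    // Same kernel reused for legacy graphs: the third template argument
    // switches to transpose_a/transpose_b attrs and ndims == 2 validation.
    REGISTER_KERNEL_BUILDER(
        Name("MatMul").Device(DEVICE_CPU).TypeConstraint<float>("T"),
        BatchMatMulOp<CPUDevice, float, /*is_legacy_matmul=*/true>);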
This commit is contained in:
parent 47926752a8
commit 0e4a3676d9
tensorflow
c/eager
core
distributed_runtime/rpc
framework
kernels
python
@@ -769,7 +769,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
   TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
   EXPECT_NE(TF_OK, TF_GetCode(status));
   EXPECT_EQ(nullptr, t);
-  const char* msg = "Matrix size-incompatible: In[0]: [2,2], In[1]: [3,2]";
+  const char* msg = "In[0] mismatch In[1] shape: 2 vs. 3: [2,2] [3,2]";
   EXPECT_TRUE(strstr(TF_Message(status), msg) != nullptr)
       << TF_Message(status);
   // Since error is not cleared, the following copy with correct device will
@@ -931,7 +931,6 @@ TEST(SessionTest, InvalidOpInputName) {
       attr { key: 'T' value { type: DT_FLOAT } }
       attr { key: 'transpose_a' value { b: false } }
       attr { key: 'transpose_b' value { b: false } }
-      attr { key: '_kernel' value { s: 'eigen' } }
     }
   )",
       "Illegal op input name");
@@ -950,7 +949,6 @@ TEST(SessionTest, InvalidOpInputName) {
       attr { key: 'T' value { type: DT_FLOAT } }
       attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
-      attr { key: '_kernel' value { s: 'eigen' } }
     }
   )",
       "Illegal op input name");
@@ -969,7 +967,6 @@ TEST(SessionTest, InvalidOpInputName) {
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
-      attr { key: '_kernel' value { s: 'eigen' } }
     }
   )",
       "Illegal op input name");
@@ -988,7 +985,6 @@ TEST(SessionTest, InvalidOpInputName) {
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
-      attr { key: '_kernel' value { s: 'eigen' } }
     }
   )",
       "Illegal op input name");
@@ -1026,7 +1022,6 @@ TEST(SessionTest, ExtendValidation) {
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
-      attr { key: '_kernel' value { s: 'eigen' } }
     }
   )",
       &extension);
@@ -1043,7 +1038,6 @@ TEST(SessionTest, ExtendValidation) {
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
-      attr { key: '_kernel' value { s: 'eigen' } }
     }
   )",
       &extension);
@@ -1057,7 +1051,6 @@ TEST(SessionTest, ExtendValidation) {
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
-      attr { key: '_kernel' value { s: 'eigen' } }
     }
   )",
       &extension);
@@ -406,7 +406,7 @@ XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
 TEST(TFunc, WXPlusB) {
   auto expect = R"P(
 WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) {
-  mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x)
+  mm = MatMul[T=$T, transpose_a=false, transpose_b=false](w, x)
   y = Add[T=$T](mm:product:0, b)
   return y = y:z:0
 }
@@ -346,10 +346,7 @@ FunctionDef WXPlusB() {
       {{{"mm"},
         "MatMul",
         {"w", "x"},
-        {{"T", "$T"},
-         {"transpose_a", false},
-         {"transpose_b", false},
-         {"_kernel", "eigen"}}},
+        {{"T", "$T"}, {"transpose_a", false}, {"transpose_b", false}}},
        {{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}});
 }

@@ -48,7 +48,6 @@ load(
 )
 load(
     "//third_party/mkl:build_defs.bzl",
-    "if_mkl_ml",
     "mkl_deps",
 )
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
@@ -3241,7 +3240,6 @@ cc_library(
     deps = [
         ":aggregate_ops",
         ":argmax_op",
-        ":batch_matmul_op",
         ":betainc_op",
         ":bincount_op",
         ":bucketize_op",
@@ -3337,14 +3335,27 @@ tf_kernel_library(
 
 tf_kernel_library(
     name = "batch_matmul_op",
+    deps = [":matmul_op"],
+)
+
+tf_kernel_library(
+    name = "matmul_op",
     # <prefix>*impl.h are excluded by default from the CPU build, add explicitly.
-    hdrs = ["batch_matmul_op_impl.h"],
-    prefix = "batch_matmul_op",
-    deps = MATH_DEPS + [":eigen_contraction_kernel"] + if_mkl_ml([
-        "//third_party/mkl:intel_binary_blob",
-    ]) + if_cuda_or_rocm([
-        "//tensorflow/core/kernels:gpu_utils",
-    ]),
+    hdrs = ["matmul_op_impl.h"],
+    defines = select({
+        ":xsmm": ["TENSORFLOW_USE_LIBXSMM"],
+        "//conditions:default": [],
+    }),
+    prefix = "matmul_op",
+    deps = MATH_DEPS + [
+        ":eigen_contraction_kernel",
+        ":fused_eigen_output_kernels",
+    ] + select({
+        ":xsmm": ["@libxsmm_archive//:xsmm_avx"],
+        "//conditions:default": [],
+    }) + mkl_deps() + if_cuda([
+        "//tensorflow/core/platform/default/build_config:cublas_plugin",
+    ]) + if_cuda_or_rocm([":gpu_utils"]),
 )
 
 tf_kernel_library(
@@ -3406,28 +3417,6 @@ tf_kernel_library(
     ]),
 )
 
-tf_kernel_library(
-    name = "matmul_op",
-    srcs = [
-        "matmul_op.cc",
-        "matmul_op_fused.cc",
-    ],
-    hdrs = ["matmul_op.h"],
-    defines = select({
-        ":xsmm": ["TENSORFLOW_USE_LIBXSMM"],
-        "//conditions:default": [],
-    }),
-    deps = MATH_DEPS + [
-        ":eigen_contraction_kernel",
-        ":fused_eigen_output_kernels",
-    ] + select({
-        ":xsmm": ["@libxsmm_archive//:xsmm_avx"],
-        "//conditions:default": [],
-    }) + mkl_deps() + if_cuda([
-        "//tensorflow/core/platform/default/build_config:cublas_plugin",
-    ]) + if_cuda_or_rocm([":gpu_utils"]),
-)
-
 tf_kernel_library(
     name = "reduction_ops",
     gpu_srcs = ["reduction_gpu_kernels.cu.h"],
@@ -3620,25 +3609,6 @@ tf_cuda_cc_test(
     ],
 )
 
-tf_cuda_cc_test(
-    name = "batch_matmul_op_test",
-    size = "small",
-    srcs = ["batch_matmul_op_test.cc"],
-    deps = [
-        ":batch_matmul_op",
-        ":broadcast_to_op",
-        ":ops_testutil",
-        ":ops_util",
-        "//tensorflow/core:core_cpu",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:test",
-        "//tensorflow/core:test_main",
-        "//tensorflow/core:testlib",
-    ],
-)
-
 tf_cuda_cc_test(
     name = "scan_ops_test",
     size = "small",
@@ -5868,8 +5838,8 @@ filegroup(
         "identity_op.h",
         "immutable_constant_op.cc",
         "immutable_constant_op.h",
-        "matmul_op.cc",
-        "matmul_op.h",
+        "matmul_op_impl.h",
+        "matmul_op_real.cc",
         "no_op.cc",
         "no_op.h",
         "one_hot_op.cc",
@@ -5948,7 +5918,6 @@ filegroup(
     srcs = [
         "argmax_op.h",
         "avgpooling_op.h",
-        "batch_matmul_op_impl.h",
         "batch_norm_op.h",
         "bincount_op.h",
         "broadcast_to_op.h",
@@ -6039,7 +6008,6 @@ filegroup(
         ":android_extended_ops_headers",
         "argmax_op.cc",
         "avgpooling_op.cc",
-        "batch_matmul_op_real.cc",
         "batch_norm_op.cc",
         "bcast_ops.cc",
         "check_numerics_op.cc",
@@ -7431,7 +7399,6 @@ test_suite(
         "manual",  # Avoid redundancy when using wildcard test patterns.
     ],
     tests = [
-        ":batch_matmul_op_test",
         ":batch_norm_op_test",
        ":broadcast_to_op_test",
        ":cast_op_test",
@@ -1,257 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/graph/node_builder.h"
-#include "tensorflow/core/graph/testlib.h"
-#include "tensorflow/core/kernels/broadcast_to_op.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/test_benchmark.h"
-
-namespace tensorflow {
-namespace {
-
-Node* BroadcastTo(Graph* g, Node* input, Node* shape) {
-  Node* ret;
-  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastTo")
-                  .Input(input)
-                  .Input(shape)
-                  .Finalize(g, &ret));
-  return ret;
-}
-
-Node* BatchMatmulV2(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) {
-  Node* ret;
-  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMulV2")
-                  .Input(in0)
-                  .Input(in1)
-                  .Attr("adj_x", adj_x)
-                  .Attr("adj_y", adj_y)
-                  .Finalize(g, &ret));
-  return ret;
-}
-
-template <typename T>
-static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a,
-                          bool adjoint_b, DataType type) {
-  Graph* g = new Graph(OpRegistry::Global());
-  Tensor in0(type, adjoint_a ? TensorShape({b, k, m}) : TensorShape({b, m, k}));
-  in0.flat<T>().setRandom();
-  Tensor in1(type, adjoint_b ? TensorShape({b, n, k}) : TensorShape({b, k, n}));
-  in1.flat<T>().setRandom();
-  test::graph::BatchMatmul(g, test::graph::Constant(g, in0),
-                           test::graph::Constant(g, in1), adjoint_a, adjoint_b);
-  return g;
-}
-
-template <typename T>
-static Graph* BatchMatmulWithBroadcast(int b0, int b1, int m, int k, int n,
-                                       bool manual_broadcast, DataType type) {
-  Graph* g = new Graph(OpRegistry::Global());
-  Tensor in0(type, TensorShape({b0, m, k}));
-  in0.flat<T>().setRandom();
-  Tensor in1(type, TensorShape({b1, k, n}));
-  in1.flat<T>().setRandom();
-
-  Tensor broadcasted_in0_shape(DT_INT64, TensorShape({3}));
-  Tensor broadcasted_in1_shape(DT_INT64, TensorShape({3}));
-
-  Node* in0_node = nullptr;
-  Node* in1_node = nullptr;
-  if (manual_broadcast) {
-    for (int i = 0; i < 3; ++i) {
-      auto vec0 = broadcasted_in0_shape.vec<int64>();
-      auto vec1 = broadcasted_in1_shape.vec<int64>();
-      vec0(i) = (i == 0 ? std::max(b0, b1) : in0.shape().dim_size(i));
-      vec1(i) = (i == 0 ? std::max(b0, b1) : in1.shape().dim_size(i));
-    }
-    in0_node = BroadcastTo(g, test::graph::Constant(g, in0),
-                           test::graph::Constant(g, broadcasted_in0_shape));
-    in1_node = BroadcastTo(g, test::graph::Constant(g, in1),
-                           test::graph::Constant(g, broadcasted_in1_shape));
-  } else {
-    in0_node = test::graph::Constant(g, in0);
-    in1_node = test::graph::Constant(g, in1);
-  }
-
-  BatchMatmulV2(g, in0_node, in1_node, false, false);
-  return g;
-}
-
-#define BM_BatchMatmulDev(B, M, K, N, TA, TB, T, TFTYPE, DEVICE)              \
-  static void                                                                 \
-  BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE( \
-      int iters) {                                                            \
-    testing::UseRealTime();                                                   \
-    testing::ItemsProcessed(static_cast<int64>(iters) * B * M * K * N * 2);   \
-    test::Benchmark(#DEVICE, BatchMatmul<T>(B, M, K, N, TA, TB, TFTYPE))      \
-        .Run(iters);                                                          \
-  }                                                                           \
-  BENCHMARK(                                                                  \
-      BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE);
-
-#define BM_BatchMatmul(B, M, K, N, TA, TB) \
-  BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, cpu);
-// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
-// cpu);
-// BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, gpu);
-/* Uncomment to enable benchmarks for double & complex types: */
-// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
-// gpu);
-// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, cpu); \
-// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, cpu);
-// \
-// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu); \
-// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, gpu);
-
-// Macro arguments names: --------------------------------------------------- //
-//   B1: batch size of LHS
-//   B2: batch size of RHS
-//   M: outer dimension of LHS
-//   K: inner dimensions of LHS and RHS
-//   N: outer dimension of RHS
-//   MB: boolean indicating whether to use manual broadcasting
-//   T: C++ type of scalars (e.g. float, std::complex)
-//   TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128
-//   D: Device (e.g. cpu, gpu)
-#define BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, T, TT, D)                  \
-  static void                                                                  \
-  BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D(     \
-      int iters) {                                                             \
-    testing::UseRealTime();                                                    \
-    testing::ItemsProcessed(static_cast<int64>(iters) * std::max(B1, B2) * M * \
-                            K * N * 2);                                        \
-    test::Benchmark(#D, BatchMatmulWithBroadcast<T>(B1, B2, M, K, N, MB, TT))  \
-        .Run(iters);                                                           \
-  }                                                                            \
-  BENCHMARK(                                                                   \
-      BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D);
-
-#define BM_BatchMatmulBCast(B1, B2, M, K, N, MB) \
-  BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, float, DT_FLOAT, cpu);
-
-// Typical fully connected layers
-BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, true);
-BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, false);
-BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, true);
-BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, false);
-BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, true);
-BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, false);
-BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, true);
-BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, false);
-
-// Square matmul.
-BM_BatchMatmulBCast(1, 128, 512, 512, 512, true);
-BM_BatchMatmulBCast(1, 128, 512, 512, 512, false);
-BM_BatchMatmulBCast(128, 1, 512, 512, 512, true);
-BM_BatchMatmulBCast(128, 1, 512, 512, 512, false);
-BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, true);
-BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, false);
-BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, true);
-BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, false);
-
-// Matrix-vector multiplies.
-BM_BatchMatmulBCast(1, 128, 10000, 200, 1, true);
-BM_BatchMatmulBCast(1, 128, 10000, 200, 1, false);
-BM_BatchMatmulBCast(128, 1, 10000, 200, 1, true);
-BM_BatchMatmulBCast(128, 1, 10000, 200, 1, false);
-
-// Vector-matrix multiplies.
-BM_BatchMatmulBCast(1, 128, 1, 200, 10000, true);
-BM_BatchMatmulBCast(1, 128, 1, 200, 10000, false);
-BM_BatchMatmulBCast(128, 1, 1, 200, 10000, true);
-BM_BatchMatmulBCast(128, 1, 1, 200, 10000, false);
-
-// Typical fully connected layers
-BM_BatchMatmul(1, 1, 1024, 1024, false, false);
-BM_BatchMatmul(1, 8, 1024, 1024, false, false);
-BM_BatchMatmul(1, 16, 1024, 1024, false, false);
-BM_BatchMatmul(1, 128, 1024, 1024, false, false);
-BM_BatchMatmul(2, 1, 1024, 1024, false, false);
-BM_BatchMatmul(2, 8, 1024, 1024, false, false);
-BM_BatchMatmul(2, 16, 1024, 1024, false, false);
-BM_BatchMatmul(2, 128, 1024, 1024, false, false);
-BM_BatchMatmul(8, 1, 1024, 1024, false, false);
-BM_BatchMatmul(8, 8, 1024, 1024, false, false);
-BM_BatchMatmul(8, 16, 1024, 1024, false, false);
-BM_BatchMatmul(8, 128, 1024, 1024, false, false);
-BM_BatchMatmul(32, 1, 1024, 1024, false, false);
-BM_BatchMatmul(32, 8, 1024, 1024, false, false);
-BM_BatchMatmul(32, 16, 1024, 1024, false, false);
-BM_BatchMatmul(32, 128, 1024, 1024, false, false);
-
-// Square matmul.
-BM_BatchMatmul(1, 32, 32, 32, false, false);
-BM_BatchMatmul(1, 128, 128, 128, false, false);
-BM_BatchMatmul(1, 256, 256, 256, false, false);
-BM_BatchMatmul(1, 1024, 1024, 1024, false, false);
-BM_BatchMatmul(1, 2048, 2048, 2048, false, false);
-BM_BatchMatmul(2, 32, 32, 32, false, false);
-BM_BatchMatmul(2, 128, 128, 128, false, false);
-BM_BatchMatmul(2, 256, 256, 256, false, false);
-BM_BatchMatmul(2, 1024, 1024, 1024, false, false);
-BM_BatchMatmul(2, 2048, 2048, 2048, false, false);
-BM_BatchMatmul(4, 32, 32, 32, false, false);
-BM_BatchMatmul(4, 128, 128, 128, false, false);
-BM_BatchMatmul(4, 256, 256, 256, false, false);
-BM_BatchMatmul(4, 1024, 1024, 1024, false, false);
-BM_BatchMatmul(4, 2048, 2048, 2048, false, false);
-BM_BatchMatmul(8, 32, 32, 32, false, false);
-BM_BatchMatmul(8, 128, 128, 128, false, false);
-BM_BatchMatmul(8, 256, 256, 256, false, false);
-BM_BatchMatmul(8, 1024, 1024, 1024, false, false);
-BM_BatchMatmul(8, 2048, 2048, 2048, false, false);
-BM_BatchMatmul(32, 32, 32, 32, false, false);
-BM_BatchMatmul(32, 128, 128, 128, false, false);
-BM_BatchMatmul(32, 256, 256, 256, false, false);
-BM_BatchMatmul(32, 1024, 1024, 1024, false, false);
-BM_BatchMatmul(32, 2048, 2048, 2048, false, false);
-
-// Matrix-vector multiplies.
-BM_BatchMatmul(1, 10000, 200, 1, false, false);
-BM_BatchMatmul(8, 10000, 200, 1, false, false);
-BM_BatchMatmul(32, 10000, 200, 1, false, false);
-BM_BatchMatmul(1, 10000, 200, 1, true, false);
-BM_BatchMatmul(8, 10000, 200, 1, true, false);
-BM_BatchMatmul(32, 10000, 200, 1, true, false);
-BM_BatchMatmul(1, 10000, 200, 1, false, true);
-BM_BatchMatmul(8, 10000, 200, 1, false, true);
-BM_BatchMatmul(32, 10000, 200, 1, false, true);
-BM_BatchMatmul(1, 10000, 200, 1, true, true);
-BM_BatchMatmul(8, 10000, 200, 1, true, true);
-BM_BatchMatmul(32, 10000, 200, 1, true, true);
-
-// Vector-matrix multiplies.
-BM_BatchMatmul(1, 1, 200, 10000, false, false);
-BM_BatchMatmul(8, 1, 200, 10000, false, false);
-BM_BatchMatmul(32, 1, 200, 10000, false, false);
-BM_BatchMatmul(1, 1, 200, 10000, true, false);
-BM_BatchMatmul(8, 1, 200, 10000, true, false);
-BM_BatchMatmul(32, 1, 200, 10000, true, false);
-BM_BatchMatmul(1, 1, 200, 10000, false, true);
-BM_BatchMatmul(8, 1, 200, 10000, false, true);
-BM_BatchMatmul(32, 1, 200, 10000, false, true);
-BM_BatchMatmul(1, 1, 200, 10000, true, true);
-BM_BatchMatmul(8, 1, 200, 10000, true, true);
-BM_BatchMatmul(32, 1, 200, 10000, true, true);
-
-}  // namespace
-}  // namespace tensorflow
@@ -30,9 +30,9 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_types.h"
-#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
 #include "tensorflow/core/kernels/fill_functor.h"
 #include "tensorflow/core/kernels/linalg/einsum_op.h"
+#include "tensorflow/core/kernels/matmul_op_impl.h"
 #include "tensorflow/core/kernels/reduction_ops_common.h"
 #include "tensorflow/core/kernels/transpose_functor.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -1,567 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// See docs in ../ops/math_ops.cc.
-
-#define EIGEN_USE_THREADS
-
-#include "tensorflow/core/kernels/matmul_op.h"
-
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/register_types.h"
-#include "tensorflow/core/kernels/fill_functor.h"
-#include "tensorflow/core/util/matmul_autotune.h"
-#if GOOGLE_CUDA
-#include "third_party/gpus/cuda/include/cuda.h"
-#endif
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-#include "tensorflow/core/kernels/gpu_utils.h"
-#include "tensorflow/core/platform/stream_executor.h"
-#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-
-namespace tensorflow {
-
-typedef Eigen::ThreadPoolDevice CPUDevice;
-typedef Eigen::GpuDevice GPUDevice;
-
-template <typename Device, typename T, bool USE_CUBLAS>
-struct LaunchMatMul;
-
-namespace {
-// Converts a TensorFlow Tensor to an Eigen Matrix.
-template <typename T>
-Eigen::Map<
-    const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
-ToEigenMatrix(const Tensor& tensor) {
-  auto matrix = tensor.matrix<T>();
-  return Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>::Map(
-      matrix.data(), matrix.dimension(0), matrix.dimension(1));
-}
-
-// Converts a TensorFlow Tensor to an Eigen Vector.
-template <typename T>
-Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(Tensor* tensor) {
-  auto v = tensor->flat<T>();
-  return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0));
-}
-template <typename T>
-Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(
-    const Tensor& tensor) {
-  auto v = tensor.flat<T>();
-  return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0));
-}
-}  // namespace
-
-// If either side can be represented as a vector, do an explicit vector
-// matrix multiply and return true; else return false.
-//
-// Note: this uses plain Eigen and not Eigen Tensor because it is more
-// efficient.
-template <typename T>
-bool ExplicitVectorMatrixOptimization(
-    const Tensor& a, const Tensor& b,
-    const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
-    Tensor* out) {
-  if (out->dim_size(0) == 1) {
-    if (dim_pair[0].second == 0) {
-      // Note: this case is optimized in Eigen Tensors.
-      return false;
-    } else {
-      auto out_v = ToEigenVector<T>(out);
-      auto a_v = ToEigenVector<T>(a);
-      auto b_m = ToEigenMatrix<T>(b);
-      out_v.noalias() = b_m * a_v;
-    }
-    return true;
-  } else if (out->dim_size(1) == 1) {
-    auto out_v = ToEigenVector<T>(out);
-    auto a_m = ToEigenMatrix<T>(a);
-    auto b_v = ToEigenVector<T>(b);
-    if (dim_pair[0].first == 0) {
-      out_v.noalias() = a_m.transpose() * b_v;
-    } else {
-      out_v.noalias() = a_m * b_v;
-    }
-    return true;
-  }
-  return false;
-}
-// Half is not supported.
-template <>
-bool ExplicitVectorMatrixOptimization<Eigen::half>(
-    const Tensor& a, const Tensor& b,
-    const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
-    Tensor* out) {
-  return false;
-}
-
-template <typename Device, typename T>
-struct LaunchMatMulBase {
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-  typedef se::blas::AlgorithmType AlgorithmType;
-#else
-  typedef int64 AlgorithmType;
-#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-
-  static void launch(
-      OpKernelContext* ctx, const Tensor& a, const Tensor& b,
-      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
-      std::vector<AlgorithmType>* algorithms, bool use_autotune, Tensor* out) {
-    // An explicit vector-matrix multiply is much better optimized than an
-    // implicit one and this is a bottleneck during non-batched inference.
-    bool was_vector = ExplicitVectorMatrixOptimization<T>(a, b, dim_pair, out);
-    if (!was_vector) {
-      functor::MatMulFunctor<Device, T>()(ctx->eigen_device<Device>(),
-                                          out->matrix<T>(), a.matrix<T>(),
-                                          b.matrix<T>(), dim_pair);
-    }
-  }
-
-  static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
-                                   std::vector<int64>* algorithms,
-                                   bool* algorithm_set_flag) {}
-};
-// On CPUs, we ignore USE_CUBLAS
-template <typename T>
-struct LaunchMatMulCPU : LaunchMatMulBase<CPUDevice, T> {};
-
-template <typename T, bool USE_CUBLAS>
-struct LaunchMatMul<CPUDevice, T, USE_CUBLAS> : public LaunchMatMulCPU<T> {};
-
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-
-namespace {
-
-template <typename T>
-struct LaunchBlasGemv {
-  static void Compute(OpKernelContext* ctx, se::Stream* stream, bool trans,
-                      uint64 m, uint64 n, const se::DeviceMemory<T>& a,
-                      const se::DeviceMemory<T>& b, se::DeviceMemory<T>* c,
-                      se::blas::ProfileResult* output_profile) {
-    const auto blas_trans = trans ? se::blas::Transpose::kTranspose
-                                  : se::blas::Transpose::kNoTranspose;
-    if (output_profile == nullptr) {
-      bool blas_launch_status =
-          stream
-              ->ThenBlasGemv(blas_trans, m, n, static_cast<T>(1.0), a, m, b, 1,
-                             static_cast<T>(0.0), c, 1)
-              .ok();
-      if (!blas_launch_status) {
-        ctx->SetStatus(
-            errors::Internal("Blas GEMV launch failed: m=", m, ", n=", n));
-      }
-    } else {
-      bool blas_launch_status =
-          stream
-              ->ThenBlasGemvWithProfiling(blas_trans, m, n, static_cast<T>(1.0),
-                                          a, m, b, 1, static_cast<T>(0.0), c, 1,
-                                          output_profile)
-              .ok();
-      if (!blas_launch_status) {
-        ctx->SetStatus(errors::Internal(
-            "Blas GEMV with profiling launch failed: m=", m, ", n=", n));
-      }
-    }
-  }
-
-  static bool IsSupported() { return true; }
-};
-
-template <>
-void LaunchBlasGemv<Eigen::half>::Compute(
-    OpKernelContext* ctx, se::Stream* stream, bool trans, uint64 m, uint64 n,
-    const se::DeviceMemory<Eigen::half>& a,
-    const se::DeviceMemory<Eigen::half>& b, se::DeviceMemory<Eigen::half>* c,
-    se::blas::ProfileResult* output_profile) {
-  ctx->SetStatus(errors::Internal(
-      "Blas GEMV launch failed: GEMV is not implemented for float16."));
-}
-
-template <>
-bool LaunchBlasGemv<Eigen::half>::IsSupported() {
-  return false;
-}
-
-template <typename T>
-bool ShouldUseGemv(uint64 n) {
-  return (LaunchBlasGemv<T>::IsSupported() && n == 1);
-}
-
-}  // namespace
-
-bool GetCublasAutotuneComputationType(const DataType& dtype,
-                                      se::blas::ComputationType* compute_type) {
-  using se::blas::ComputationType;
-  switch (dtype) {
-    case DT_HALF:
-    case DT_BFLOAT16:
-      static bool use_f32_for_f16_computation =
-          MatmulDoFP32ComputationFP16Input();
-      if (use_f32_for_f16_computation) {
-        *compute_type = ComputationType::kF32;
-      } else {
-        *compute_type = ComputationType::kF16;
-      }
-      return false;
-    case DT_FLOAT:
-      *compute_type = ComputationType::kF32;
-      return true;
-    case DT_DOUBLE:
-      *compute_type = ComputationType::kF64;
-      return true;
-    default:
-      // Unsupported compute_type, return false.
-      return false;
-  }
-}
-
-// A dummy type to group matmul autotune results together.
-struct MatmulAutoTuneGroup {
-  static string name() { return "Matmul"; }
-};
-typedef AutoTuneSingleton<MatmulAutoTuneGroup, MatmulParameters,
-                          se::blas::AlgorithmConfig>
-    AutoTuneMatmul;
-
-template <typename T>
-struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
-  static void launch(
-      OpKernelContext* ctx, const Tensor& a, const Tensor& b,
-      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
-      std::vector<int64>* algorithms, bool use_autotune, Tensor* out) {
-    using se::blas::AlgorithmConfig;
-    using se::blas::ComputationType;
-    using se::blas::kDefaultAlgorithm;
-    using se::blas::kDefaultBlasGemm;
-    using se::blas::kDefaultBlasGemv;
-    using se::blas::kNoAlgorithm;
-    using se::blas::ProfileResult;
-    using se::blas::Transpose;
-    Transpose trans[] = {Transpose::kNoTranspose, Transpose::kTranspose};
-    const uint64 m = a.dim_size(1 - dim_pair[0].first);
-    const uint64 k = a.dim_size(dim_pair[0].first);
-    const uint64 n = b.dim_size(1 - dim_pair[0].second);
-    bool transpose_a = dim_pair[0].first == 0;
-    bool transpose_b = dim_pair[0].second == 1;
-    auto blas_transpose_a = trans[transpose_a];
-    auto blas_transpose_b = trans[transpose_b];
-
-    auto* stream = ctx->op_device_context()->stream();
-    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
-
-    auto a_ptr = AsDeviceMemory(a.template flat<T>().data(),
-                                a.template flat<T>().size());
-    auto b_ptr = AsDeviceMemory(b.template flat<T>().data(),
-                                b.template flat<T>().size());
-    auto c_ptr = AsDeviceMemory(out->template flat<T>().data(),
-                                out->template flat<T>().size());
-    auto alpha = static_cast<T>(1.0);
-    auto beta = static_cast<T>(0.0);
-
-    int device_id = stream->parent()->device_ordinal();
-    DataType dtype = a.dtype();
-    MatmulParameters matmul_parameters = {
-        transpose_a, transpose_b, m, n, k, dtype, device_id,
-    };
-    AlgorithmConfig algorithm_config(kNoAlgorithm);
-
-    ComputationType computation_type;
-    bool compute_type_supported =
-        GetCublasAutotuneComputationType(dtype, &computation_type);
-    if (use_autotune && compute_type_supported && !algorithms->empty()) {
-      ProfileResult best_result;
-      // TODO(yangzihao): Unify this code with conv autotuning.
-      if (!AutoTuneMatmul::GetInstance()->Find(matmul_parameters,
-                                               &algorithm_config)) {
-        ProfileResult profile_result;
-        for (auto profile_algorithm : (*algorithms)) {
-          // Cublas does
-          // C = A x B
-          // where A, B and C are assumed to be in column major.
-          // We want the output to be in row-major, so we can compute
-          // C' = B' x A' (' stands for transpose)
-          bool cublas_launch_status =
-              stream
-                  ->ThenBlasGemmWithAlgorithm(
-                      blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
-                      transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
-                      &c_ptr, n, computation_type, profile_algorithm,
-                      &profile_result)
-                  .ok();
-          if (cublas_launch_status) {
-            if (profile_result.is_valid()) {
-              if (profile_result.elapsed_time_in_ms() <
-                  best_result.elapsed_time_in_ms()) {
-                best_result = profile_result;
-              }
-            }
-          }
-        }
-        // Try BlasGemmWithProfiling
-        bool cublas_launch_status =
-            stream
-                ->ThenBlasGemmWithProfiling(
-                    blas_transpose_b, blas_transpose_a, n, m, k, 1.0, b_ptr,
-                    transpose_b ? k : n, a_ptr, transpose_a ? m : k, 0.0,
-                    &c_ptr, n, &profile_result)
-                .ok();
-        if (cublas_launch_status) {
-          if (profile_result.is_valid()) {
-            if (profile_result.elapsed_time_in_ms() <
-                best_result.elapsed_time_in_ms()) {
-              best_result = profile_result;
-            }
-          }
-        }
-        // Try BlasGemvWithProfiling
-        if (ShouldUseGemv<T>(n)) {
-          LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
-                                     transpose_a ? m : k, transpose_a ? k : m,
-                                     a_ptr, b_ptr, &c_ptr, &profile_result);
-          if (profile_result.is_valid()) {
-            if (profile_result.elapsed_time_in_ms() <
-                best_result.elapsed_time_in_ms()) {
-              best_result = profile_result;
-            }
-          }
-        }
-      }
-      // We make sure that each matmul parameter set only gets one pass of
-      // autotune. If the best result is found, assign it to algorithm_type
-      // and insert it to autotune map. If all internal kernels of
-      // cublasGemmEx() returns invalid results, we add kNoAlgorithm to the
-      // autotune map.
-      if (best_result.is_valid()) {
-        algorithm_config.set_algorithm(best_result.algorithm());
-      }
-      AutoTuneMatmul::GetInstance()->Insert(matmul_parameters,
-                                            algorithm_config);
-      if (algorithm_config.algorithm() != kNoAlgorithm &&
-          algorithm_config.algorithm() != kDefaultBlasGemm &&
-          algorithm_config.algorithm() != kDefaultBlasGemv) {
-        bool cublas_launch_status =
-            stream
-                ->ThenBlasGemmWithAlgorithm(
-                    blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
-                    transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
-                    &c_ptr, n, computation_type, algorithm_config.algorithm(),
-                    nullptr)
-                .ok();
-        if (!cublas_launch_status) {
-          ctx->SetStatus(errors::Internal(
-              "Blas GEMM with algorithm launch failed : a.shape=(",
-              a.dim_size(0), ", ", a.dim_size(1), "), b.shape=(", b.dim_size(0),
-              ", ", b.dim_size(1), "), m=", m, ", n=", n, ", k=", k));
-        }
-      }
-    }
-    // For the following case, we use normal BlasGemm():
-    // 1) We didn't set the use_autotune flag;
-    // 2) compute type does not support autotune;
-    // 3) no algorithm is found;
-    // 4) all internal kernels in autotune return invalid results.
-    // For the following case, we use normal BlasGemv():
-    // 1) We didn't set the use_autotune flag but LaunchBlasGemv is supported
-    //    and n == 1.
-    // 2) We set the use_autotune flag and it picked up BlasGemv() and set the
-    //    algorithm_config.algorithm() to be kDefaultBlasGemv.
-    if (!use_autotune || !compute_type_supported || algorithms->empty() ||
-        algorithm_config.algorithm() == kNoAlgorithm ||
-        algorithm_config.algorithm() == kDefaultBlasGemm ||
-        algorithm_config.algorithm() == kDefaultBlasGemv) {
-      if (algorithm_config.algorithm() == kDefaultBlasGemv ||
-          ShouldUseGemv<T>(n)) {
-        // This is a matrix*vector multiply so use GEMV to compute A * b.
-        // Here we are multiplying in the natural order, so we have to flip
-        // the transposition flag to compensate for the tensor being stored
-        // row-major.
-        // TODO(yangzihao): Add Gemv as an autotuning option too.
-        LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
-                                   transpose_a ? m : k, transpose_a ? k : m,
-                                   a_ptr, b_ptr, &c_ptr, nullptr);
-      } else {
-        // Use C' = B' x A' (' stands for transpose)
-        bool blas_launch_status =
-            stream
-                ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
-                               1.0f, b_ptr, transpose_b ? k : n, a_ptr,
-                               transpose_a ? m : k, 0.0f, &c_ptr, n)
-                .ok();
-        if (!blas_launch_status) {
-          ctx->SetStatus(errors::Internal(
-              "Blas GEMM launch failed : a.shape=(", a.dim_size(0), ", ",
-              a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1),
-              "), m=", m, ", n=", n, ", k=", k));
-        }
-      }
-    }
-  }
-
-  static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
-                                   std::vector<int64>* algorithms,
-                                   bool* algorithm_set_flag) {
-    if (*algorithm_set_flag == false) {
-      auto* stream = ctx->device()->tensorflow_gpu_device_info()->stream;
-      stream->parent()->GetBlasGemmAlgorithms(algorithms);
-      *algorithm_set_flag = true;
-    }
-  }
-};
-
-#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-
-template <typename Device, typename T, bool USE_CUBLAS>
-class MatMulOp : public OpKernel {
- public:
-  explicit MatMulOp(OpKernelConstruction* ctx)
-      : OpKernel(ctx), algorithms_set_already_(false) {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
-
-    LaunchMatMul<Device, T, USE_CUBLAS>::GetBlasGemmAlgorithm(
-        ctx, &algorithms_, &algorithms_set_already_);
-    use_autotune_ = MatmulAutotuneEnable();
-  }
-
-  void Compute(OpKernelContext* ctx) override {
-    const Tensor& a = ctx->input(0);
-    const Tensor& b = ctx->input(1);
-
-    // Check that the dimensions of the two matrices are valid.
-    OP_REQUIRES(
-        ctx, TensorShapeUtils::IsMatrix(a.shape()),
-        errors::InvalidArgument("In[0] is not a matrix. Instead it has shape ",
-                                a.shape().DebugString()));
-    OP_REQUIRES(
-        ctx, TensorShapeUtils::IsMatrix(b.shape()),
-        errors::InvalidArgument("In[1] is not a matrix. Instead it has shape ",
-                                b.shape().DebugString()));
-    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
-    dim_pair[0].first = transpose_a_ ? 0 : 1;
-    dim_pair[0].second = transpose_b_ ? 1 : 0;
-
-    OP_REQUIRES(
-        ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
-        errors::InvalidArgument(
-            "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
-            ", In[1]: ", b.shape().DebugString()));
-    int a_dim_remaining = 1 - dim_pair[0].first;
-    int b_dim_remaining = 1 - dim_pair[0].second;
-    TensorShape out_shape(
-        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
-    Tensor* out = nullptr;
-    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
-
-    if (out->NumElements() == 0) {
-      // If a has shape [0, x] or b has shape [x, 0], the output shape
-      // is a 0-element matrix, so there is nothing to do.
-      return;
-    }
-
-    if (a.NumElements() == 0 && b.NumElements() == 0) {
-      // If a has shape [x, 0] and b has shape [0, y], the
-      // output shape is [x, y] where x and y are non-zero, so we fill
-      // the output with zeros.
-      functor::SetZeroFunctor<Device, T> f;
-      f(ctx->eigen_device<Device>(), out->flat<T>());
-      return;
-    }
-
-    if (std::is_same<T, bfloat16>::value) {
-      bool is_cpu = std::is_same<Device, CPUDevice>::value;
-      OP_REQUIRES(ctx, is_cpu,
-                  errors::Internal("bfloat16 matmul is not supported by GPU"));
-      Tensor a_float, b_float, out_float;
-      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, a.shape(), &a_float));
-      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, b.shape(), &b_float));
-      OP_REQUIRES_OK(ctx,
-                     ctx->allocate_temp(DT_FLOAT, out->shape(), &out_float));
-
-      // TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
-      BFloat16ToFloat(a.flat<bfloat16>().data(), a_float.flat<float>().data(),
-                      a.NumElements());
-      BFloat16ToFloat(b.flat<bfloat16>().data(), b_float.flat<float>().data(),
-                      b.NumElements());
-
-      LaunchMatMul<Device, float, USE_CUBLAS>::launch(
-          ctx, a_float, b_float, dim_pair, &algorithms_, use_autotune_,
-          &out_float);
-      FloatToBFloat16(out_float.flat<float>().data(),
-                      out->flat<bfloat16>().data(), out->NumElements());
-    } else {
-      LaunchMatMul<Device, T, USE_CUBLAS>::launch(
-          ctx, a, b, dim_pair, &algorithms_, use_autotune_, out);
-    }
-  }
-
- private:
-  std::vector<int64> algorithms_;
-  bool algorithms_set_already_;
-  bool use_autotune_;
-  bool transpose_a_;
-  bool transpose_b_;
-};
-
-namespace functor {
-
-// Partial specialization MatMulFunctor<Device=CPUDevice, T>.
-template <typename T>
-struct MatMulFunctor<CPUDevice, T> {
-  void operator()(
-      const CPUDevice& d, typename MatMulTypes<T>::out_type out,
-      typename MatMulTypes<T>::in_type in0,
-      typename MatMulTypes<T>::in_type in1,
-      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
-    MatMul<CPUDevice>(d, out, in0, in1, dim_pair);
-  }
-};
-
-}  // end namespace functor
-
-#define REGISTER_CPU_EIGEN(T)                                                  \
-  REGISTER_KERNEL_BUILDER(                                                     \
-      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen"), \
-      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);
-
-#define REGISTER_CPU(T)                                             \
-  REGISTER_KERNEL_BUILDER(                                          \
-      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"),     \
-      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); \
-  REGISTER_CPU_EIGEN(T);
-
-#define REGISTER_GPU(T)                                            \
-  REGISTER_KERNEL_BUILDER(                                         \
-      Name("MatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"),    \
-      MatMulOp<GPUDevice, T, true /* cublas, true by default */>); \
-  REGISTER_KERNEL_BUILDER(Name("MatMul")                           \
-                              .Device(DEVICE_GPU)                  \
-                              .TypeConstraint<T>("T")              \
-                              .Label("cublas"),                    \
-                          MatMulOp<GPUDevice, T, true /* cublas */>)
-
-TF_CALL_int32(REGISTER_CPU);
-TF_CALL_int64(REGISTER_CPU);
-TF_CALL_FLOAT_TYPES(REGISTER_CPU);
-TF_CALL_COMPLEX_TYPES(REGISTER_CPU);
-
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
-TF_CALL_COMPLEX_TYPES(REGISTER_GPU);
-#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-
-}  // namespace tensorflow
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
+#include "tensorflow/core/kernels/matmul_op_impl.h"
 
 namespace tensorflow {
 
@@ -15,8 +15,8 @@ limitations under the License.
 
 // See docs in ../ops/math_ops.cc.
 
-#ifndef TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
-#define TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
 
 #define EIGEN_USE_THREADS
 
@@ -633,10 +633,21 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
 template <typename Device, typename Scalar>
 class BaseBatchMatMulOp : public OpKernel {
  public:
-  explicit BaseBatchMatMulOp(OpKernelConstruction* context)
+  explicit BaseBatchMatMulOp(OpKernelConstruction* context,
+                             bool is_legacy_matmul)
       : OpKernel(context) {
-    OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
-    OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
+    if (is_legacy_matmul) {
+      // The old MatMul kernel has "transpose_a/transpose_b" attributes.
+      OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &trans_x_));
+      OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &trans_y_));
+      adj_x_ = false;
+      adj_y_ = false;
+    } else {
+      OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
+      OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
+      trans_x_ = false;
+      trans_y_ = false;
+    }
   }
 
   ~BaseBatchMatMulOp() override {}
@@ -672,8 +683,8 @@ class BaseBatchMatMulOp : public OpKernel {
         in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
         errors::Internal("Failed to reshape In[1] from ",
                          in1.shape().DebugString()));
-    if (adj_x_) std::swap(d0, d1);
-    if (adj_y_) std::swap(d2, d3);
+    if (adj_x_ || trans_x_) std::swap(d0, d1);
+    if (adj_y_ || trans_y_) std::swap(d2, d3);
     OP_REQUIRES(ctx, d1 == d2,
                 errors::InvalidArgument(
                     "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
@@ -696,9 +707,36 @@ class BaseBatchMatMulOp : public OpKernel {
         out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
         errors::Internal("Failed to reshape output from ",
                          out->shape().DebugString()));
-    LaunchBatchMatMul<Device, Scalar>::Launch(
-        ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, /*trans_x=*/false,
-        /*trans_y=*/false, bcast, &out_reshaped);
+    if (std::is_same<Scalar, bfloat16>::value) {
+      bool is_cpu = std::is_same<Device, CPUDevice>::value;
+      OP_REQUIRES(ctx, is_cpu,
+                  errors::Internal("bfloat16 matmul is not supported by GPU"));
+      Tensor in0_reshaped_float, in1_reshaped_float, out_reshaped_float;
+      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in0_reshaped.shape(),
+                                             &in0_reshaped_float));
+      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in1_reshaped.shape(),
+                                             &in1_reshaped_float));
+      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, out_reshaped.shape(),
+                                             &out_reshaped_float));
+
+      // TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
+      BFloat16ToFloat(in0_reshaped.flat<bfloat16>().data(),
+                      in0_reshaped_float.flat<float>().data(),
+                      in0_reshaped.NumElements());
+      BFloat16ToFloat(in1_reshaped.flat<bfloat16>().data(),
+                      in1_reshaped_float.flat<float>().data(),
+                      in1_reshaped.NumElements());
+
+      LaunchBatchMatMul<Device, float>::Launch(
+          ctx, in0_reshaped_float, in1_reshaped_float, adj_x_, adj_y_, trans_x_,
+          trans_y_, bcast, &out_reshaped_float);
+      FloatToBFloat16(out_reshaped_float.flat<float>().data(),
+                      out_reshaped.flat<bfloat16>().data(), out->NumElements());
+    } else {
+      LaunchBatchMatMul<Device, Scalar>::Launch(ctx, in0_reshaped, in1_reshaped,
+                                                adj_x_, adj_y_, trans_x_,
+                                                trans_y_, bcast, &out_reshaped);
+    }
   }
 
  protected:
@@ -706,16 +744,19 @@ class BaseBatchMatMulOp : public OpKernel {
                                      const Tensor& in1) = 0;
 
  private:
+  // TODO(171979567) Make the ops take both adj and transpose attributes.
   bool adj_x_;
   bool adj_y_;
+  bool trans_x_;
+  bool trans_y_;
 };
 
 // BatchMatMul Op implementation which disallows broadcasting.
-template <typename Device, typename Scalar>
+template <typename Device, typename Scalar, bool is_legacy_matmul = false>
 class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> {
  public:
   explicit BatchMatMulOp(OpKernelConstruction* context)
-      : BaseBatchMatMulOp<Device, Scalar>(context) {}
+      : BaseBatchMatMulOp<Device, Scalar>(context, is_legacy_matmul) {}
 
   ~BatchMatMulOp() override {}
 
@@ -729,15 +770,21 @@ class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> {
                                         in0.shape().DebugString(), " vs. ",
                                         in1.shape().DebugString()));
     const int ndims = in0.dims();
-    OP_REQUIRES(
-        ctx, ndims >= 2,
-        errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
-    for (int i = 0; i < ndims - 2; ++i) {
-      OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
-                  errors::InvalidArgument(
-                      "In[0].dim(", i, ") and In[1].dim(", i,
-                      ") must be the same: ", in0.shape().DebugString(), " vs ",
-                      in1.shape().DebugString()));
-    }
+    if (is_legacy_matmul) {
+      OP_REQUIRES(ctx, ndims == 2,
+                  errors::InvalidArgument(
+                      "In[0] and In[1] ndims must be == 2: ", ndims));
+    } else {
+      OP_REQUIRES(ctx, ndims >= 2,
+                  errors::InvalidArgument(
+                      "In[0] and In[1] ndims must be >= 2: ", ndims));
+      for (int i = 0; i < ndims - 2; ++i) {
+        OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
+                    errors::InvalidArgument(
+                        "In[0].dim(", i, ") and In[1].dim(", i,
+                        ") must be the same: ", in0.shape().DebugString(),
+                        " vs ", in1.shape().DebugString()));
+      }
+    }
   }
 };
@@ -747,7 +794,8 @@ template <typename Device, typename Scalar>
 class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
  public:
   explicit BatchMatMulV2Op(OpKernelConstruction* context)
-      : BaseBatchMatMulOp<Device, Scalar>(context) {}
+      : BaseBatchMatMulOp<Device, Scalar>(context,
+                                          /* is_legacy_matmul= */ false) {}
 
   ~BatchMatMulV2Op() override {}
 
@@ -771,7 +819,10 @@ class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
       BatchMatMulOp<CPUDevice, TYPE>);                                    \
   REGISTER_KERNEL_BUILDER(                                                \
       Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
-      BatchMatMulV2Op<CPUDevice, TYPE>)
+      BatchMatMulV2Op<CPUDevice, TYPE>);                                  \
+  REGISTER_KERNEL_BUILDER(                                                \
+      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),        \
+      BatchMatMulOp<CPUDevice, TYPE, /* is_legacy_matmul=*/true>)
 
 #define REGISTER_BATCH_MATMUL_GPU(TYPE)                                   \
   REGISTER_KERNEL_BUILDER(                                                \
@@ -779,8 +830,11 @@ class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
       BatchMatMulOp<GPUDevice, TYPE>);                                    \
   REGISTER_KERNEL_BUILDER(                                                \
       Name("BatchMatMulV2").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
-      BatchMatMulV2Op<GPUDevice, TYPE>)
+      BatchMatMulV2Op<GPUDevice, TYPE>);                                  \
+  REGISTER_KERNEL_BUILDER(                                                \
+      Name("MatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),        \
+      BatchMatMulOp<GPUDevice, TYPE, /* is_legacy_matmul=*/true>)
 
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
+#endif  // TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
+#include "tensorflow/core/kernels/matmul_op_impl.h"
 
 #if GOOGLE_CUDA
 #include "third_party/gpus/cuda/include/cuda.h"
@@ -21,17 +21,13 @@ limitations under the License.
 
 namespace tensorflow {
 
-TF_CALL_float(REGISTER_BATCH_MATMUL_CPU);
-TF_CALL_double(REGISTER_BATCH_MATMUL_CPU);
-TF_CALL_half(REGISTER_BATCH_MATMUL_CPU);
+TF_CALL_FLOAT_TYPES(REGISTER_BATCH_MATMUL_CPU);
 TF_CALL_int16(REGISTER_BATCH_MATMUL_CPU);
 TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
 TF_CALL_int64(REGISTER_BATCH_MATMUL_CPU);
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-TF_CALL_float(REGISTER_BATCH_MATMUL_GPU);
-TF_CALL_double(REGISTER_BATCH_MATMUL_GPU);
-TF_CALL_half(REGISTER_BATCH_MATMUL_GPU);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_BATCH_MATMUL_GPU);
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/core/public/session.h"
 
 namespace tensorflow {
+namespace {
 
 template <typename T>
 class FusedMatMulOpTest : public OpsTestBase {
@ -459,4 +460,230 @@ BM_Matmul(2000, 1, 2000, true, false);
|
||||
BM_Matmul(2000, 1, 2000, false, true);
|
||||
BM_Matmul(2000, 1, 2000, true, true);
|
||||
|
||||
} // end namespace tensorflow
|
||||
// Benchmarks for batched matmul with broadcasting.
|
||||
Node* BroadcastTo(Graph* g, Node* input, Node* shape) {
|
||||
Node* ret;
|
||||
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastTo")
|
||||
.Input(input)
|
||||
.Input(shape)
|
||||
.Finalize(g, &ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
Node* BatchMatmulV2(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) {
|
||||
Node* ret;
|
||||
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMulV2")
|
||||
.Input(in0)
|
||||
.Input(in1)
|
||||
.Attr("adj_x", adj_x)
|
||||
.Attr("adj_y", adj_y)
|
||||
.Finalize(g, &ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a,
|
||||
bool adjoint_b, DataType type) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
Tensor in0(type, adjoint_a ? TensorShape({b, k, m}) : TensorShape({b, m, k}));
|
||||
in0.flat<T>().setRandom();
|
||||
Tensor in1(type, adjoint_b ? TensorShape({b, n, k}) : TensorShape({b, k, n}));
|
||||
in1.flat<T>().setRandom();
|
||||
test::graph::BatchMatmul(g, test::graph::Constant(g, in0),
|
||||
test::graph::Constant(g, in1), adjoint_a, adjoint_b);
|
||||
return g;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static Graph* BatchMatmulWithBroadcast(int b0, int b1, int m, int k, int n,
|
||||
bool manual_broadcast, DataType type) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
Tensor in0(type, TensorShape({b0, m, k}));
|
||||
in0.flat<T>().setRandom();
|
||||
Tensor in1(type, TensorShape({b1, k, n}));
|
||||
in1.flat<T>().setRandom();
|
||||
|
||||
Tensor broadcasted_in0_shape(DT_INT64, TensorShape({3}));
|
||||
Tensor broadcasted_in1_shape(DT_INT64, TensorShape({3}));
|
||||
|
||||
Node* in0_node = nullptr;
|
||||
Node* in1_node = nullptr;
|
||||
if (manual_broadcast) {
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
auto vec0 = broadcasted_in0_shape.vec<int64>();
|
||||
auto vec1 = broadcasted_in1_shape.vec<int64>();
|
||||
vec0(i) = (i == 0 ? std::max(b0, b1) : in0.shape().dim_size(i));
|
||||
vec1(i) = (i == 0 ? std::max(b0, b1) : in1.shape().dim_size(i));
|
||||
}
|
||||
in0_node = BroadcastTo(g, test::graph::Constant(g, in0),
|
||||
test::graph::Constant(g, broadcasted_in0_shape));
|
||||
in1_node = BroadcastTo(g, test::graph::Constant(g, in1),
|
||||
test::graph::Constant(g, broadcasted_in1_shape));
|
||||
} else {
|
||||
in0_node = test::graph::Constant(g, in0);
|
||||
in1_node = test::graph::Constant(g, in1);
|
||||
}
|
||||
|
||||
BatchMatmulV2(g, in0_node, in1_node, false, false);
|
||||
return g;
|
||||
}
|
||||
|
||||
#define BM_BatchMatmulDev(B, M, K, N, TA, TB, T, TFTYPE, DEVICE) \
|
||||
static void \
|
||||
BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE( \
|
||||
int iters) { \
|
||||
testing::UseRealTime(); \
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * B * M * K * N * 2); \
|
||||
test::Benchmark(#DEVICE, BatchMatmul<T>(B, M, K, N, TA, TB, TFTYPE)) \
|
||||
.Run(iters); \
|
||||
} \
|
||||
BENCHMARK( \
|
||||
BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE);
|
||||
|
||||
#define BM_BatchMatmul(B, M, K, N, TA, TB) \
|
||||
BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, cpu);
|
||||
// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
|
||||
// cpu);
|
||||
// BM_BatchMatmulDev(B, M, K, N, TA, TB, float, DT_FLOAT, gpu);
|
||||
/* Uncomment to enable benchmarks for double & complex types: */
|
||||
// BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex<float>, DT_COMPLEX64,
|
||||
// gpu);
|
||||
// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, cpu); \
|
||||
// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, cpu);
|
||||
// \
|
||||
// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu); \
|
||||
// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, gpu);

// Macro argument names: ----------------------------------------------------- //
// B1: batch size of LHS
// B2: batch size of RHS
// M: outer dimension of LHS
// K: inner dimension of LHS and RHS
// N: outer dimension of RHS
// MB: boolean indicating whether to use manual broadcasting
// T: C++ type of scalars (e.g. float, std::complex<float>)
// TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128)
// D: device (e.g. cpu, gpu)
#define BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, T, TT, D)                  \
  static void                                                                  \
      BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D( \
          int iters) {                                                         \
    testing::UseRealTime();                                                    \
    testing::ItemsProcessed(static_cast<int64>(iters) * std::max(B1, B2) * M * \
                            K * N * 2);                                        \
    test::Benchmark(#D, BatchMatmulWithBroadcast<T>(B1, B2, M, K, N, MB, TT))  \
        .Run(iters);                                                           \
  }                                                                            \
  BENCHMARK(                                                                   \
      BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D);

#define BM_BatchMatmulBCast(B1, B2, M, K, N, MB) \
  BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, float, DT_FLOAT, cpu);
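
The ItemsProcessed figure in the macros above counts flops: an [M, K] x [K, N] matmul performs roughly 2 * M * K * N operations (one multiply and one add per inner-product step), scaled by the effective batch size max(B1, B2). A quick sketch of that accounting, for illustration only (the helper name is hypothetical, not part of the test file):

# Sketch of the per-iteration items-processed accounting used by the macros.
def batch_matmul_items(batch, m, k, n):
    # 2 flops (multiply + add) per scalar step of each K-length inner product.
    return batch * m * k * n * 2

# One of the broadcast cases below: max(B1, B2) = 128, M = 1, K = N = 1024.
print(batch_matmul_items(128, 1, 1024, 1024))  # 268435456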

// Typical fully connected layers
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, false);
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, false);

// Square matmul.
BM_BatchMatmulBCast(1, 128, 512, 512, 512, true);
BM_BatchMatmulBCast(1, 128, 512, 512, 512, false);
BM_BatchMatmulBCast(128, 1, 512, 512, 512, true);
BM_BatchMatmulBCast(128, 1, 512, 512, 512, false);
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, false);

// Matrix-vector multiplies.
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, true);
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, false);
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, true);
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, false);

// Vector-matrix multiplies.
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, true);
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, false);
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, true);
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, false);

// Typical fully connected layers
BM_BatchMatmul(1, 1, 1024, 1024, false, false);
BM_BatchMatmul(1, 8, 1024, 1024, false, false);
BM_BatchMatmul(1, 16, 1024, 1024, false, false);
BM_BatchMatmul(1, 128, 1024, 1024, false, false);
BM_BatchMatmul(2, 1, 1024, 1024, false, false);
BM_BatchMatmul(2, 8, 1024, 1024, false, false);
BM_BatchMatmul(2, 16, 1024, 1024, false, false);
BM_BatchMatmul(2, 128, 1024, 1024, false, false);
BM_BatchMatmul(8, 1, 1024, 1024, false, false);
BM_BatchMatmul(8, 8, 1024, 1024, false, false);
BM_BatchMatmul(8, 16, 1024, 1024, false, false);
BM_BatchMatmul(8, 128, 1024, 1024, false, false);
BM_BatchMatmul(32, 1, 1024, 1024, false, false);
BM_BatchMatmul(32, 8, 1024, 1024, false, false);
BM_BatchMatmul(32, 16, 1024, 1024, false, false);
BM_BatchMatmul(32, 128, 1024, 1024, false, false);

// Square matmul.
BM_BatchMatmul(1, 32, 32, 32, false, false);
BM_BatchMatmul(1, 128, 128, 128, false, false);
BM_BatchMatmul(1, 256, 256, 256, false, false);
BM_BatchMatmul(1, 1024, 1024, 1024, false, false);
BM_BatchMatmul(1, 2048, 2048, 2048, false, false);
BM_BatchMatmul(2, 32, 32, 32, false, false);
BM_BatchMatmul(2, 128, 128, 128, false, false);
BM_BatchMatmul(2, 256, 256, 256, false, false);
BM_BatchMatmul(2, 1024, 1024, 1024, false, false);
BM_BatchMatmul(2, 2048, 2048, 2048, false, false);
BM_BatchMatmul(4, 32, 32, 32, false, false);
BM_BatchMatmul(4, 128, 128, 128, false, false);
BM_BatchMatmul(4, 256, 256, 256, false, false);
BM_BatchMatmul(4, 1024, 1024, 1024, false, false);
BM_BatchMatmul(4, 2048, 2048, 2048, false, false);
BM_BatchMatmul(8, 32, 32, 32, false, false);
BM_BatchMatmul(8, 128, 128, 128, false, false);
BM_BatchMatmul(8, 256, 256, 256, false, false);
BM_BatchMatmul(8, 1024, 1024, 1024, false, false);
BM_BatchMatmul(8, 2048, 2048, 2048, false, false);
BM_BatchMatmul(32, 32, 32, 32, false, false);
BM_BatchMatmul(32, 128, 128, 128, false, false);
BM_BatchMatmul(32, 256, 256, 256, false, false);
BM_BatchMatmul(32, 1024, 1024, 1024, false, false);
BM_BatchMatmul(32, 2048, 2048, 2048, false, false);

// Matrix-vector multiplies.
BM_BatchMatmul(1, 10000, 200, 1, false, false);
BM_BatchMatmul(8, 10000, 200, 1, false, false);
BM_BatchMatmul(32, 10000, 200, 1, false, false);
BM_BatchMatmul(1, 10000, 200, 1, true, false);
BM_BatchMatmul(8, 10000, 200, 1, true, false);
BM_BatchMatmul(32, 10000, 200, 1, true, false);
BM_BatchMatmul(1, 10000, 200, 1, false, true);
BM_BatchMatmul(8, 10000, 200, 1, false, true);
BM_BatchMatmul(32, 10000, 200, 1, false, true);
BM_BatchMatmul(1, 10000, 200, 1, true, true);
BM_BatchMatmul(8, 10000, 200, 1, true, true);
BM_BatchMatmul(32, 10000, 200, 1, true, true);

// Vector-matrix multiplies.
BM_BatchMatmul(1, 1, 200, 10000, false, false);
BM_BatchMatmul(8, 1, 200, 10000, false, false);
BM_BatchMatmul(32, 1, 200, 10000, false, false);
BM_BatchMatmul(1, 1, 200, 10000, true, false);
BM_BatchMatmul(8, 1, 200, 10000, true, false);
BM_BatchMatmul(32, 1, 200, 10000, true, false);
BM_BatchMatmul(1, 1, 200, 10000, false, true);
BM_BatchMatmul(8, 1, 200, 10000, false, true);
BM_BatchMatmul(32, 1, 200, 10000, false, true);
BM_BatchMatmul(1, 1, 200, 10000, true, true);
BM_BatchMatmul(8, 1, 200, 10000, true, true);
BM_BatchMatmul(32, 1, 200, 10000, true, true);

}  // namespace
}  // namespace tensorflow
@ -33,8 +33,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"
#include "tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@ -197,14 +197,16 @@ class MatMulInfixOperatorTest(test_lib.TestCase):

  def testMismatchedShape(self):
    with self.assertRaisesRegex(
        Exception, "(Shape must be rank 2 but is rank 1|is not a matrix)"):
        Exception, (r"(In\[0\] and In\[1\] has different ndims|In\[0\] "
                    r"ndims must be >= 2|Shape must be rank 2 but is rank 1)")):
      infix_matmul(
          ops.convert_to_tensor([10.0, 20.0, 30.0]),
          ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]]))

  def testMismatchedDimensions(self):
    with self.assertRaisesRegex(
        Exception, "(Dimensions must be equal|Matrix size-incompatible)"):
        Exception,
        r"(In\[0\] mismatch In\[1\] shape|Dimensions must be equal)"):
      infix_matmul(
          ops.convert_to_tensor([[10.0, 20.0, 30.0]]),
          ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]]))
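
For context, the updated regexes match the merged kernel's error text. A minimal eager-mode sketch of the dimension-mismatch case, assuming TensorFlow 2.x (the exact message wording may vary by version):

import tensorflow as tf

a = tf.constant([[10.0, 20.0, 30.0]])          # shape [1, 3]
b = tf.constant([[40.0, 50.0], [60.0, 70.0]])  # shape [2, 2]
try:
    tf.matmul(a, b)
except tf.errors.InvalidArgumentError as e:
    # e.g. "In[0] mismatch In[1] shape: 3 vs. 2: [1,3] [2,2]"
    print(e.message)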

@ -234,9 +236,10 @@ if __name__ == "__main__":
      # TF2 does not support placeholders under eager so we skip it
      for use_static_shape in set([True, tf2.enabled()]):
        for dtype in dtypes_to_test:
          if not use_static_shape and (dtype == np.int32 or dtype == np.int64):
            # TODO(rmlarsen): Re-enable this test when we have fixed the
            # underlying bug in Windows (b/35935459).
          if test_util.is_xla_enabled() and (dtype == np.int32 or
                                             dtype == np.int64):
            # TODO(b/171924639): Enable this test when XLA DOT supports
            # integer types.
            continue
          for m in sizes:
            for n in sizes:
@ -55,8 +55,8 @@ class TensordotTest(test_lib.TestCase):
    if context.executing_eagerly():
      return
    with self.cached_session() as sess:
      with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                  "Matrix size-incompatible"):
      with self.assertRaisesOpError(
          r"In\[0\] mismatch In\[1\] shape: 2 vs\. 3: \[2,2\] \[3,2\]"):
        a_ph = array_ops.placeholder(dtypes.float32)
        b_ph = array_ops.placeholder(dtypes.float32)
        axes_ph = array_ops.placeholder(dtypes.int32)
@ -108,15 +108,14 @@ class PrintOpFilegroupTest(test.TestCase):

    ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
        'rawproto', self.WriteGraphFiles(graphs), default_ops)
    matmul_prefix = ''
    matmul_prefix = 'Batch'

    self.assertListEqual(
        [
            ('AccumulateNV2', None),  #
            ('BiasAdd', 'BiasOp<CPUDevice, float>'),  #
            ('MatMul',
             matmul_prefix + 'MatMulOp<CPUDevice, double, false >'),  #
            ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, float, false >'),  #
            ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, double, true>'),  #
            ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, float, true>'),  #
            ('NoOp', 'NoOp'),  #
            ('Reshape', 'ReshapeOp'),  #
            ('_Recv', 'RecvOp'),  #
@ -132,9 +131,8 @@ class PrintOpFilegroupTest(test.TestCase):
        [
            ('AccumulateNV2', None),  #
            ('BiasAdd', 'BiasOp<CPUDevice, float>'),  #
            ('MatMul',
             matmul_prefix + 'MatMulOp<CPUDevice, double, false >'),  #
            ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, float, false >'),  #
            ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, double, true>'),  #
            ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, float, true>'),  #
            ('NoOp', 'NoOp'),  #
            ('Reshape', 'ReshapeOp'),  #
            ('_Recv', 'RecvOp'),  #
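
The matmul_prefix change reflects the point of this commit: the "MatMul" op symbol is now served by the batch kernel classes, so selective registration sees Batch-prefixed kernel names. A hedged sketch showing that a legacy graph using a plain "MatMul" node still builds and runs, assuming TensorFlow 2.x with the compat.v1 API:

import tensorflow.compat.v1 as tf1
tf1.disable_eager_execution()

g = tf1.Graph()
with g.as_default():
    a = tf1.constant([[1.0, 2.0], [3.0, 4.0]])
    b = tf1.constant([[5.0, 6.0], [7.0, 8.0]])
    c = tf1.matmul(a, b)              # emits a plain "MatMul" node
    assert c.op.type == "MatMul"
with tf1.Session(graph=g) as sess:
    print(sess.run(c))                # [[19. 22.] [43. 50.]]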