From e7cc47384f4d57cc04ec550dbc2c08e467e42a4a Mon Sep 17 00:00:00 2001
From: Anudhyan Boral
Date: Tue, 26 May 2020 13:15:31 -0700
Subject: [PATCH] [TF:XLA] Small change in tf2xla matmul to use BatchDot
 instead of Transpose + Dot.

This has the advantage that we can more easily detect symmetric matmuls
(e.g. A * At) before the algebraic simplifier passes. BatchDot simply
moves around contract_dims instead of adding a Transpose op.

Benchmarks (JF)
---------------
Summary of changes:
Compile time       0.99x geomean, range [ 0.80x, 1.58x], 1.00x arith mean
Host memory        1.00x geomean, range [ 0.77x, 1.25x]
SMEM usage         1.00x geomean, range [ 0.98x, 1.02x]
Benchmark runtime  1.00x geomean, range [ 0.99x, 2.43x]

No changes after rounding in HBM usage, VMEM usage, Bundle count,
Overlay wait time, Static throttling

PiperOrigin-RevId: 313255256
Change-Id: I13d781161fad9d685c7bfcb96e511130b2b9e182
---
 tensorflow/compiler/tf2xla/kernels/matmul_op.cc | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
index a3fcb4d4b8f..bd6f58453df 100644
--- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/matrix.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/core/framework/op_kernel.h"
 
@@ -81,9 +82,7 @@ class MatMulOp : public XlaOpKernel {
         b = xla::ConvertElementType(b, xla::F32);
       }
     }
-    auto lhs = (transpose_a_) ? xla::Transpose(a, {1, 0}) : a;
-    auto rhs = (transpose_b_) ? xla::Transpose(b, {1, 0}) : b;
-    ctx->SetOutput(0, xla::Dot(lhs, rhs));
+    ctx->SetOutput(0, xla::BatchDot(a, transpose_a_, b, transpose_b_));
   }
 
  private:
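
For context, a minimal sketch (not part of the patch) contrasting the two
lowerings at the XLA client-API level. The wrapper names MatMulOld and
MatMulNew are hypothetical; the xla::Transpose, xla::Dot, and xla::BatchDot
calls are the ones appearing in the diff above.

#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Old lowering: materialize an explicit Transpose HLO before the Dot.
// The transpose survives into the HLO graph, so a symmetric product such
// as A * At reaches the algebraic simplifier as Dot(Transpose(A), A),
// where the two operands are no longer obviously the same value.
xla::XlaOp MatMulOld(xla::XlaOp a, xla::XlaOp b,
                     bool transpose_a, bool transpose_b) {
  auto lhs = transpose_a ? xla::Transpose(a, {1, 0}) : a;
  auto rhs = transpose_b ? xla::Transpose(b, {1, 0}) : b;
  return xla::Dot(lhs, rhs);
}

// New lowering: BatchDot folds the transposes into the Dot's contracting
// dimension numbers (contract_dims), so no Transpose op is emitted and
// both operands of a symmetric matmul remain the same XlaOp.
xla::XlaOp MatMulNew(xla::XlaOp a, xla::XlaOp b,
                     bool transpose_a, bool transpose_b) {
  return xla::BatchDot(a, transpose_a, b, transpose_b);
}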