diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index a3fcb4d4b8f..bd6f58453df 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -81,9 +82,7 @@ class MatMulOp : public XlaOpKernel { b = xla::ConvertElementType(b, xla::F32); } } - auto lhs = (transpose_a_) ? xla::Transpose(a, {1, 0}) : a; - auto rhs = (transpose_b_) ? xla::Transpose(b, {1, 0}) : b; - ctx->SetOutput(0, xla::Dot(lhs, rhs)); + ctx->SetOutput(0, xla::BatchDot(a, transpose_a_, b, transpose_b_)); } private: