[TF:XLA] Small change in tf2xla matmul to use BatchDot instead of Transpose + Dot.

This has the advantage that we can more easily detect symmetric matmuls (e.g. A * Aᵀ) before the algebraic simplifier passes run. BatchDot simply moves around contract_dims (the dot's contracting dimensions) instead of adding a Transpose op.
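As an illustration only (not part of this change), here is a minimal sketch of the idea using the standard xla::XlaBuilder client API: a transposed left operand can be expressed by picking a different contracting dimension in DotGeneral instead of materializing a Transpose. The helper name MatMulTransposeA is hypothetical.

// Illustrative sketch, not from this commit: compute A^T * B by choosing
// contracting dimensions rather than emitting a Transpose op.
#include "tensorflow/compiler/xla/client/xla_builder.h"

xla::XlaOp MatMulTransposeA(xla::XlaOp a, xla::XlaOp b) {
  // Old style would materialize the transpose and contract the default dims:
  //   return xla::Dot(xla::Transpose(a, {1, 0}), b);
  // Instead, contract dim 0 of `a` (its rows) against dim 0 of `b`, which
  // computes A^T * B with no Transpose op in the graph.
  xla::DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(0);  // Dot's default lhs dim would be 1.
  dnums.add_rhs_contracting_dimensions(0);
  return xla::DotGeneral(a, b, dnums);
}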

Benchmarks (JF)
---------------
Summary of changes:
        Compile time  0.99x geomean, range [ 0.80x,  1.58x],  1.00x arith mean
         Host memory  1.00x geomean, range [ 0.77x,  1.25x]
          SMEM usage  1.00x geomean, range [ 0.98x,  1.02x]
   Benchmark runtime  1.00x geomean, range [ 0.99x,  2.43x]
No changes after rounding in HBM usage, VMEM usage, Bundle count, Overlay wait time, Static throttling

PiperOrigin-RevId: 313255256
Change-Id: I13d781161fad9d685c7bfcb96e511130b2b9e182
Author:    Anudhyan Boral (committed by TensorFlower Gardener)
Date:      2020-05-26 13:15:31 -07:00
Parent:    bb34d65cd7
Commit:    e7cc47384f

@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/matrix.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -81,9 +82,7 @@ class MatMulOp : public XlaOpKernel {
         b = xla::ConvertElementType(b, xla::F32);
       }
     }
-    auto lhs = (transpose_a_) ? xla::Transpose(a, {1, 0}) : a;
-    auto rhs = (transpose_b_) ? xla::Transpose(b, {1, 0}) : b;
-    ctx->SetOutput(0, xla::Dot(lhs, rhs));
+    ctx->SetOutput(0, xla::BatchDot(a, transpose_a_, b, transpose_b_));
   }

  private:
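
For context on the symmetric case mentioned above, a hedged sketch (not from the commit) comparing the two lowerings of A * Aᵀ; the function name and setup are hypothetical.

// Illustration only: with the old lowering, the Dot's operands differ until
// the algebraic simplifier folds the Transpose; with BatchDot, both operands
// of the underlying dot are literally the same XlaOp.
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

void LowerSymmetricMatMul(xla::XlaOp a) {
  // Old lowering: one operand is Transpose(a), so the A * A^T pattern is only
  // visible after simplification.
  xla::XlaOp old_result = xla::Dot(a, xla::Transpose(a, {1, 0}));

  // New lowering: both operands are `a`; only the contracting dimensions
  // differ, so the symmetric pattern is detectable before any passes run.
  xla::XlaOp new_result =
      xla::BatchDot(a, /*transpose_x=*/false, a, /*transpose_y=*/true);
  (void)old_result;
  (void)new_result;
}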