[TF:XLA] Small change in tf2xla matmul to use BatchDot instead of Transpose + Dot.
This has the advantage that we can more easily detect symmetric matmuls (e.g. A * Aᵀ) before the algebraic-simplifier passes: BatchDot simply moves the contracting dimensions (contract_dims) around instead of adding a Transpose op.

Benchmarks (JF)
---------------
Summary of changes:
  Compile time       0.99x geomean, range [0.80x, 1.58x], 1.00x arith mean
  Host memory        1.00x geomean, range [0.77x, 1.25x]
  SMEM usage         1.00x geomean, range [0.98x, 1.02x]
  Benchmark runtime  1.00x geomean, range [0.99x, 2.43x]

No changes after rounding in HBM usage, VMEM usage, Bundle count, Overlay wait time, or Static throttling.

PiperOrigin-RevId: 313255256
Change-Id: I13d781161fad9d685c7bfcb96e511130b2b9e182
parent bb34d65cd7
commit e7cc47384f
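The mechanism behind this is simple: for a rank-2 operand, a transpose flag only changes which dimension is contracted, so the dot's DotDimensionNumbers can absorb it and no Transpose op has to be materialized. A minimal C++ sketch of this idea against the XLA client API (MatMulViaDotGeneral is a hypothetical helper name; the real BatchDot lives in xla/client/lib/matrix.cc and also handles batch dimensions and precision config):

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

// Sketch only, not the real BatchDot: fold the transpose flags into the
// contracting dimensions of a DotGeneral instead of emitting Transpose ops.
xla::XlaOp MatMulViaDotGeneral(xla::XlaOp a, bool transpose_a,
                               xla::XlaOp b, bool transpose_b) {
  xla::DotDimensionNumbers dnums;
  // A [M,K] contracts on dim 1; contracting dim 0 instead computes with Aᵀ.
  dnums.add_lhs_contracting_dimensions(transpose_a ? 0 : 1);
  // B [K,N] contracts on dim 0; contracting dim 1 instead computes with Bᵀ.
  dnums.add_rhs_contracting_dimensions(transpose_b ? 1 : 0);
  return xla::DotGeneral(a, b, dnums);
}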
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/matrix.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/core/framework/op_kernel.h"
 
@@ -81,9 +82,7 @@ class MatMulOp : public XlaOpKernel {
         b = xla::ConvertElementType(b, xla::F32);
       }
     }
-    auto lhs = (transpose_a_) ? xla::Transpose(a, {1, 0}) : a;
-    auto rhs = (transpose_b_) ? xla::Transpose(b, {1, 0}) : b;
-    ctx->SetOutput(0, xla::Dot(lhs, rhs));
+    ctx->SetOutput(0, xla::BatchDot(a, transpose_a_, b, transpose_b_));
   }
 
  private:
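To illustrate the payoff for symmetry detection: with the old lowering of C = A * Aᵀ, the dot's rhs is a separate Transpose instruction, so the two dot operands differ; with BatchDot both operands are literally the same XlaOp and only the contracting dimensions differ, so a pass running before the algebraic simplifier can spot the symmetric matmul by simple operand identity. A hedged sketch of the two lowerings (BuildSymmetricMatmul is an illustrative name, not part of this change):

#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Two lowerings of C = A * Aᵀ for a rank-2 operand `a`.
void BuildSymmetricMatmul(xla::XlaOp a) {
  // Old lowering: the rhs is a distinct Transpose instruction, hiding
  // the fact that both sides of the dot come from the same value.
  xla::XlaOp via_transpose = xla::Dot(a, xla::Transpose(a, {1, 0}));

  // New lowering: both dot operands are `a` itself; the transpose is
  // encoded purely in the contracting dimensions.
  xla::XlaOp via_batch_dot =
      xla::BatchDot(a, /*transpose_x=*/false, a, /*transpose_y=*/true);
  (void)via_transpose;
  (void)via_batch_dot;
}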