Fix regression in (non-batch) matrix multiplication:
1. Introduce a specialization for real non-batch calls to the Eigen contraction kernel. 2. Reduce overhead in Tensor::CopyFromInternal and inline it. 3. Reduce overhead in the constructor of MatMulBCast and inline it. PiperOrigin-RevId: 342963388 Change-Id: I21fbac97ce7b47a8c585ae8399670d31d53bd108
This commit is contained in:
parent
e87fba482e
commit
347d28b0ce
@ -675,20 +675,6 @@ void Tensor::CheckIsAlignedAndSingleElement() const {
|
||||
|
||||
Tensor::~Tensor() { UnrefIfNonNull(buf_); }
|
||||
|
||||
void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) {
|
||||
CHECK_EQ(shape.num_elements(), other.NumElements());
|
||||
// Data type will be overwritten if this == &other, since dtype is part of
|
||||
// shape.
|
||||
DataType other_dtype = other.dtype();
|
||||
shape_ = shape;
|
||||
set_dtype(other_dtype);
|
||||
if (buf_ != other.buf_) {
|
||||
UnrefIfNonNull(buf_);
|
||||
buf_ = other.buf_;
|
||||
RefIfNonNull(buf_);
|
||||
}
|
||||
}
|
||||
|
||||
Status Tensor::BitcastFrom(const Tensor& other, DataType dtype,
|
||||
const TensorShape& shape) {
|
||||
int in_size = DataTypeSize(other.dtype());
|
||||
|
@ -697,7 +697,19 @@ class Tensor {
|
||||
set_dtype(dt);
|
||||
}
|
||||
|
||||
void CopyFromInternal(const Tensor& other, const TensorShape& shape);
|
||||
inline void CopyFromInternal(const Tensor& other, const TensorShape& shape) {
|
||||
DCHECK_EQ(shape.num_elements(), other.NumElements());
|
||||
// Data type will be overwritten if this == &other, since dtype is part of
|
||||
// shape.
|
||||
DataType other_dtype = other.dtype();
|
||||
shape_ = shape;
|
||||
set_dtype(other_dtype);
|
||||
if (buf_ != other.buf_) {
|
||||
if (buf_) buf_->Unref();
|
||||
buf_ = other.buf_;
|
||||
if (buf_) buf_->Ref();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T* base() const;
|
||||
|
@ -77,7 +77,7 @@ struct ParallelMatMulKernel {
|
||||
static void Run(const OpKernelContext* context, const Tensor& in_x,
|
||||
const Tensor in_y, bool adj_x, bool adj_y, bool trans_x,
|
||||
bool trans_y, const MatMulBCast& bcast, Tensor* out,
|
||||
int start, int limit) {
|
||||
int batch_size) {
|
||||
static_assert(IsComplex, "Complex type expected.");
|
||||
auto Tx = in_x.tensor<Scalar, 3>();
|
||||
auto Ty = in_y.tensor<Scalar, 3>();
|
||||
@ -94,7 +94,8 @@ struct ParallelMatMulKernel {
|
||||
const bool should_bcast = bcast.IsBroadcastingRequired();
|
||||
const auto& x_batch_indices = bcast.x_batch_indices();
|
||||
const auto& y_batch_indices = bcast.y_batch_indices();
|
||||
for (int64 i = start; i < limit; ++i) {
|
||||
// TODO(rmlarsen): Consider launching these contractions asynchronously.
|
||||
for (int64 i = 0; i < batch_size; ++i) {
|
||||
const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
|
||||
const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
|
||||
|
||||
@ -121,25 +122,32 @@ struct ParallelMatMulKernel<Scalar, false> {
|
||||
static void Run(const OpKernelContext* context, const Tensor& in_x,
|
||||
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
|
||||
bool trans_y, const MatMulBCast& bcast, Tensor* out,
|
||||
int start, int limit) {
|
||||
auto Tx = in_x.tensor<Scalar, 3>();
|
||||
auto Ty = in_y.tensor<Scalar, 3>();
|
||||
auto Tz = out->tensor<Scalar, 3>();
|
||||
int batch_size) {
|
||||
const bool should_bcast = bcast.IsBroadcastingRequired();
|
||||
const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
|
||||
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
|
||||
contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
|
||||
const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
|
||||
if (batch_size == 1 && !should_bcast) {
|
||||
auto Tx = in_x.flat_inner_dims<Scalar, 2>();
|
||||
auto Ty = in_y.flat_inner_dims<Scalar, 2>();
|
||||
auto Tz = out->flat_inner_dims<Scalar, 2>();
|
||||
Tz.device(d) = Tx.contract(Ty, contract_pairs);
|
||||
} else {
|
||||
auto Tx = in_x.tensor<Scalar, 3>();
|
||||
auto Ty = in_y.tensor<Scalar, 3>();
|
||||
auto Tz = out->tensor<Scalar, 3>();
|
||||
const auto& x_batch_indices = bcast.x_batch_indices();
|
||||
const auto& y_batch_indices = bcast.y_batch_indices();
|
||||
// TODO(rmlarsen): Consider launching these contractions asynchronously.
|
||||
for (int64 i = 0; i < batch_size; ++i) {
|
||||
const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
|
||||
const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
|
||||
auto x = Tx.template chip<0>(x_batch_index);
|
||||
auto y = Ty.template chip<0>(y_batch_index);
|
||||
auto z = Tz.template chip<0>(i);
|
||||
|
||||
const bool should_bcast = bcast.IsBroadcastingRequired();
|
||||
const auto& x_batch_indices = bcast.x_batch_indices();
|
||||
const auto& y_batch_indices = bcast.y_batch_indices();
|
||||
for (int64 i = start; i < limit; ++i) {
|
||||
const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
|
||||
const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
|
||||
auto x = Tx.template chip<0>(x_batch_index);
|
||||
auto y = Ty.template chip<0>(y_batch_index);
|
||||
auto z = Tz.template chip<0>(i);
|
||||
|
||||
z.device(d) = x.contract(y, contract_pairs);
|
||||
z.device(d) = x.contract(y, contract_pairs);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -234,13 +242,15 @@ struct LaunchBatchMatMul<CPUDevice, Scalar> {
|
||||
// Jan 21, 2020.
|
||||
const int64 kMaxCostOuterParallelism = 128 * 128; // heuristic.
|
||||
auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
|
||||
// TODO(rmlarsen): Reconsider the heuristics now that we have asynchronous
|
||||
// evaluation in Eigen Tensor.
|
||||
if (small_dim > 1 &&
|
||||
(batch_size == 1 || cost_per_unit > kMaxCostOuterParallelism)) {
|
||||
// Parallelize over inner dims.
|
||||
// For large matrix products it is counter-productive to parallelize
|
||||
// over the batch dimension.
|
||||
ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, trans_x,
|
||||
trans_y, bcast, out, 0, batch_size);
|
||||
trans_y, bcast, out, batch_size);
|
||||
conjugate_result = adj_x;
|
||||
} else {
|
||||
// Parallelize over outer dims. For small matrices and large batches, it
|
||||
@ -656,7 +666,11 @@ class BaseBatchMatMulOp : public OpKernel {
|
||||
const Tensor& in0 = ctx->input(0);
|
||||
const Tensor& in1 = ctx->input(1);
|
||||
|
||||
ValidateInputTensors(ctx, in0, in1);
|
||||
const Status s = ValidateInputTensors(ctx, in0, in1);
|
||||
if (!s.ok()) {
|
||||
ctx->SetStatus(s);
|
||||
return;
|
||||
}
|
||||
|
||||
MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
|
||||
OP_REQUIRES(
|
||||
@ -740,8 +754,8 @@ class BaseBatchMatMulOp : public OpKernel {
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
|
||||
const Tensor& in1) = 0;
|
||||
virtual Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
|
||||
const Tensor& in1) = 0;
|
||||
|
||||
private:
|
||||
// TODO(171979567) Make the ops take both adj and transpose attributes.
|
||||
@ -761,31 +775,36 @@ class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> {
|
||||
~BatchMatMulOp() override {}
|
||||
|
||||
private:
|
||||
void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
|
||||
const Tensor& in1) override {
|
||||
Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
|
||||
const Tensor& in1) override {
|
||||
// Disallow broadcasting support. Ensure that all batch dimensions of the
|
||||
// input tensors match.
|
||||
OP_REQUIRES(ctx, in0.dims() == in1.dims(),
|
||||
errors::InvalidArgument("In[0] and In[1] has different ndims: ",
|
||||
in0.shape().DebugString(), " vs. ",
|
||||
in1.shape().DebugString()));
|
||||
if (in0.dims() != in1.dims()) {
|
||||
return errors::InvalidArgument(
|
||||
"In[0] and In[1] has different ndims: ", in0.shape().DebugString(),
|
||||
" vs. ", in1.shape().DebugString());
|
||||
}
|
||||
const int ndims = in0.dims();
|
||||
if (is_legacy_matmul) {
|
||||
OP_REQUIRES(ctx, ndims == 2,
|
||||
errors::InvalidArgument(
|
||||
"In[0] and In[1] ndims must be == 2: ", ndims));
|
||||
if (ndims != 2) {
|
||||
return errors::InvalidArgument("In[0] and In[1] ndims must be == 2: ",
|
||||
ndims);
|
||||
}
|
||||
} else {
|
||||
OP_REQUIRES(ctx, ndims >= 2,
|
||||
errors::InvalidArgument(
|
||||
"In[0] and In[1] ndims must be >= 2: ", ndims));
|
||||
if (ndims < 2) {
|
||||
return errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ",
|
||||
ndims);
|
||||
}
|
||||
for (int i = 0; i < ndims - 2; ++i) {
|
||||
OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
|
||||
errors::InvalidArgument(
|
||||
"In[0].dim(", i, ") and In[1].dim(", i,
|
||||
") must be the same: ", in0.shape().DebugString(),
|
||||
" vs ", in1.shape().DebugString()));
|
||||
if (in0.dim_size(i) != in1.dim_size(i)) {
|
||||
return errors::InvalidArgument(
|
||||
"In[0].dim(", i, ") and In[1].dim(", i,
|
||||
") must be the same: ", in0.shape().DebugString(), " vs ",
|
||||
in1.shape().DebugString());
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
@ -800,16 +819,17 @@ class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
|
||||
~BatchMatMulV2Op() override {}
|
||||
|
||||
private:
|
||||
void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
|
||||
const Tensor& in1) override {
|
||||
Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
|
||||
const Tensor& in1) override {
|
||||
// Enable broadcasting support. Validity of broadcasting is checked in
|
||||
// BaseBatchMatMulOp.
|
||||
OP_REQUIRES(
|
||||
ctx, in0.dims() >= 2,
|
||||
errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));
|
||||
OP_REQUIRES(
|
||||
ctx, in1.dims() >= 2,
|
||||
errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims()));
|
||||
if (in0.dims() < 2) {
|
||||
return errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims());
|
||||
}
|
||||
if (in1.dims() < 2) {
|
||||
return errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -99,7 +99,6 @@ filegroup(
|
||||
"guarded_philox_random.h",
|
||||
"matmul_autotune.cc",
|
||||
"matmul_autotune.h",
|
||||
"matmul_bcast.cc",
|
||||
"matmul_bcast.h",
|
||||
"mirror_pad_mode.cc",
|
||||
"mirror_pad_mode.h",
|
||||
@ -221,7 +220,6 @@ filegroup(
|
||||
"example_proto_helper.cc",
|
||||
"guarded_philox_random.cc",
|
||||
"matmul_autotune.cc",
|
||||
"matmul_bcast.cc",
|
||||
"mirror_pad_mode.cc",
|
||||
"saved_tensor_slice_util.cc",
|
||||
"stat_summarizer.cc",
|
||||
|
@ -1,42 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/util/matmul_bcast.h"
|
||||
|
||||
namespace tensorflow {
|
||||
MatMulBCast::MatMulBCast(Vec x, Vec y) {
|
||||
if (x.size() < 2 || y.size() < 2) return;
|
||||
x.resize(x.size() - 2);
|
||||
y.resize(y.size() - 2);
|
||||
|
||||
batch_bcast_ = absl::make_unique<BCast>(std::move(x), std::move(y));
|
||||
if (!batch_bcast_->IsValid()) return;
|
||||
|
||||
x_batch_size_ = TensorShape(batch_bcast_->x_reshape()).num_elements();
|
||||
y_batch_size_ = TensorShape(batch_bcast_->y_reshape()).num_elements();
|
||||
output_shape_ = TensorShape(batch_bcast_->output_shape());
|
||||
output_batch_size_ = output_shape_.num_elements();
|
||||
broadcasting_required_ =
|
||||
std::min(x_batch_size_, y_batch_size_) != output_batch_size_;
|
||||
|
||||
if (broadcasting_required_) {
|
||||
ComputeBatchIndices(output_batch_size_, batch_bcast_->x_reshape(),
|
||||
batch_bcast_->x_bcast(), &x_batch_indices_);
|
||||
ComputeBatchIndices(output_batch_size_, batch_bcast_->y_reshape(),
|
||||
batch_bcast_->y_bcast(), &y_batch_indices_);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -26,20 +26,50 @@ namespace tensorflow {
|
||||
|
||||
// Simple wrapper over BCast specialized for MatMul.
|
||||
// Provides utilities for broadcasting across batch dimensions for binary
|
||||
// MatMul-like operations.
|
||||
// MatMul-like operations. If neither argument has batch dimensions (rank <= 2)
|
||||
// then no broadcasting is needed and the operation MatMul operation is
|
||||
// considered valid.
|
||||
class MatMulBCast {
|
||||
public:
|
||||
using Vec = BCast::Vec;
|
||||
|
||||
MatMulBCast(Vec x, Vec y);
|
||||
MatMulBCast(const Vec& x, const Vec& y) {
|
||||
if (std::max(x.size(), y.size()) == 2) return;
|
||||
const Vec x_resized(x.begin(), x.end() - 2);
|
||||
const Vec y_resized(y.begin(), y.end() - 2);
|
||||
|
||||
bool IsValid() const { return batch_bcast_ && batch_bcast_->IsValid(); }
|
||||
batch_bcast_ =
|
||||
absl::make_unique<BCast>(std::move(x_resized), std::move(y_resized));
|
||||
if (!batch_bcast_->IsValid()) {
|
||||
// Set broadcasting_required_ to true to make IsValid() return false;
|
||||
broadcasting_required_ = true;
|
||||
return;
|
||||
}
|
||||
|
||||
x_batch_size_ = TensorShape(batch_bcast_->x_reshape()).num_elements();
|
||||
y_batch_size_ = TensorShape(batch_bcast_->y_reshape()).num_elements();
|
||||
output_batch_shape_ = TensorShape(batch_bcast_->output_shape());
|
||||
output_batch_size_ = output_batch_shape_.num_elements();
|
||||
broadcasting_required_ =
|
||||
std::min(x_batch_size_, y_batch_size_) != output_batch_size_;
|
||||
|
||||
if (broadcasting_required_) {
|
||||
ComputeBatchIndices(output_batch_size_, batch_bcast_->x_reshape(),
|
||||
batch_bcast_->x_bcast(), &x_batch_indices_);
|
||||
ComputeBatchIndices(output_batch_size_, batch_bcast_->y_reshape(),
|
||||
batch_bcast_->y_bcast(), &y_batch_indices_);
|
||||
}
|
||||
}
|
||||
|
||||
bool IsValid() const {
|
||||
return !broadcasting_required_ || (batch_bcast_ && batch_bcast_->IsValid());
|
||||
}
|
||||
bool IsBroadcastingRequired() const { return broadcasting_required_; }
|
||||
|
||||
const int64 output_batch_size() const { return output_batch_size_; }
|
||||
const int64 x_batch_size() const { return x_batch_size_; }
|
||||
const int64 y_batch_size() const { return y_batch_size_; }
|
||||
const TensorShape& output_batch_shape() const { return output_shape_; }
|
||||
const TensorShape& output_batch_shape() const { return output_batch_shape_; }
|
||||
|
||||
// Returns the mapping from the flattened output batch indices to x's
|
||||
// flattened batch indices. The result is a vector of length
|
||||
@ -57,10 +87,10 @@ class MatMulBCast {
|
||||
private:
|
||||
std::unique_ptr<BCast> batch_bcast_;
|
||||
bool broadcasting_required_ = false;
|
||||
int64 x_batch_size_;
|
||||
int64 y_batch_size_;
|
||||
TensorShape output_shape_;
|
||||
int64 output_batch_size_;
|
||||
int64 x_batch_size_ = 1;
|
||||
int64 y_batch_size_ = 1;
|
||||
TensorShape output_batch_shape_;
|
||||
int64 output_batch_size_ = 1;
|
||||
std::vector<int64> x_batch_indices_;
|
||||
std::vector<int64> y_batch_indices_;
|
||||
};
|
||||
|
@ -94,14 +94,18 @@ TEST(MatMulBCastTest, EmptyWithNonEmptyBatchBroadcast) {
|
||||
EXPECT_EQ("[2][0,1][0,0]", MatMulBCastToStr(bcast2));
|
||||
}
|
||||
|
||||
TEST(MatMulBCastTest, InvalidDimensions) {
|
||||
// Too few dimensions.
|
||||
TEST(MatMulBCastTest, NoBathcDimensions) {
|
||||
MatMulBCast bcast1({3, 3}, {3});
|
||||
EXPECT_FALSE(bcast1.IsValid());
|
||||
EXPECT_TRUE(bcast1.IsValid());
|
||||
|
||||
MatMulBCast bcast2({3}, {3, 3});
|
||||
EXPECT_FALSE(bcast2.IsValid());
|
||||
EXPECT_TRUE(bcast2.IsValid());
|
||||
|
||||
MatMulBCast bcast3({3, 3}, {3, 3});
|
||||
EXPECT_TRUE(bcast3.IsValid());
|
||||
}
|
||||
|
||||
TEST(MatMulBCastTest, InvalidDimensions) {
|
||||
// Batch dimensions not broadcastable.
|
||||
MatMulBCast bcast3({4, 5, 3}, {2, 3, 7});
|
||||
EXPECT_FALSE(bcast3.IsValid());
|
||||
|
Loading…
x
Reference in New Issue
Block a user