Fix regression in (non-batch) matrix multiplication:

1. Introduce a specialization for real-valued, non-batch calls to the Eigen contraction kernel (see the sketch below).
2. Reduce overhead in Tensor::CopyFromInternal and inline it.
3. Reduce overhead in the constructor of MatMulBCast and inline it.

PiperOrigin-RevId: 342963388
Change-Id: I21fbac97ce7b47a8c585ae8399670d31d53bd108
Authored by A. Unique TensorFlower on 2020-11-17 15:50:43 -08:00; committed by TensorFlower Gardener
Parent: e87fba482e
Commit: 347d28b0ce
7 changed files with 126 additions and 118 deletions
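The specialization in item 1 dispatches to a single rank-2 Eigen contraction when batch_size == 1 and no broadcasting is required, instead of chipping a rank-3 tensor inside a loop. As a rough illustration, here is a minimal standalone sketch using plain Eigen (not TensorFlow code); the contraction pair (1, 0) says "contract x's inner dimension with y's first dimension", i.e. an ordinary matrix product:

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  // Two rank-2 tensors standing in for a single (non-batch) MatMul call.
  Eigen::Tensor<float, 2> x(2, 3);
  Eigen::Tensor<float, 2> y(3, 4);
  x.setRandom();
  y.setRandom();

  // Contract x's dimension 1 with y's dimension 0: a plain matrix product.
  Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs = {
      Eigen::IndexPair<Eigen::DenseIndex>(1, 0)};
  Eigen::Tensor<float, 2> z = x.contract(y, contract_pairs);

  std::cout << "z is " << z.dimension(0) << " x " << z.dimension(1) << "\n";
  return 0;
}

In the kernel change below, the same kind of expression is evaluated on the thread-pool device (Tz.device(d) = Tx.contract(Ty, contract_pairs)), with Tx/Ty/Tz obtained via flat_inner_dims<Scalar, 2>() so no per-batch chip() is needed.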


@@ -675,20 +675,6 @@ void Tensor::CheckIsAlignedAndSingleElement() const {
Tensor::~Tensor() { UnrefIfNonNull(buf_); }
void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) {
CHECK_EQ(shape.num_elements(), other.NumElements());
// Data type will be overwritten if this == &other, since dtype is part of
// shape.
DataType other_dtype = other.dtype();
shape_ = shape;
set_dtype(other_dtype);
if (buf_ != other.buf_) {
UnrefIfNonNull(buf_);
buf_ = other.buf_;
RefIfNonNull(buf_);
}
}
Status Tensor::BitcastFrom(const Tensor& other, DataType dtype,
const TensorShape& shape) {
int in_size = DataTypeSize(other.dtype());


@@ -697,7 +697,19 @@ class Tensor {
set_dtype(dt);
}
void CopyFromInternal(const Tensor& other, const TensorShape& shape);
inline void CopyFromInternal(const Tensor& other, const TensorShape& shape) {
DCHECK_EQ(shape.num_elements(), other.NumElements());
// Data type will be overwritten if this == &other, since dtype is part of
// shape.
DataType other_dtype = other.dtype();
shape_ = shape;
set_dtype(other_dtype);
if (buf_ != other.buf_) {
if (buf_) buf_->Unref();
buf_ = other.buf_;
if (buf_) buf_->Ref();
}
}
template <typename T>
T* base() const;
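The inlined CopyFromInternal above shares the other tensor's buffer by adjusting an intrusive reference count, and the buf_ != other.buf_ test is what makes self-copy a no-op and skips two atomic operations when the buffers are already shared. A stripped-down stand-in (this is not TensorFlow's RefCounted/TensorBuffer machinery, just a sketch of the pattern):

#include <atomic>
#include <cassert>

struct Buffer {
  std::atomic<int> refs{1};
  void Ref() { refs.fetch_add(1, std::memory_order_relaxed); }
  void Unref() {
    // Delete when the last reference is dropped.
    if (refs.fetch_sub(1, std::memory_order_acq_rel) == 1) delete this;
  }
};

struct Handle {
  Buffer* buf_ = nullptr;
  void ShareFrom(const Handle& other) {
    if (buf_ != other.buf_) {   // no-op (and no atomics) if already shared
      if (buf_) buf_->Unref();  // release our old buffer, if any
      buf_ = other.buf_;
      if (buf_) buf_->Ref();    // take a reference on the new one
    }
  }
};

int main() {
  Handle a, b;
  a.buf_ = new Buffer;             // refs == 1
  b.ShareFrom(a);                  // refs == 2
  assert(a.buf_->refs.load() == 2);
  b.ShareFrom(a);                  // same buffer: branch not taken
  assert(a.buf_->refs.load() == 2);
  b.buf_->Unref();                 // refs == 1
  a.buf_->Unref();                 // last reference, buffer is deleted
  return 0;
}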


@@ -77,7 +77,7 @@ struct ParallelMatMulKernel {
static void Run(const OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
bool trans_y, const MatMulBCast& bcast, Tensor* out,
int start, int limit) {
int batch_size) {
static_assert(IsComplex, "Complex type expected.");
auto Tx = in_x.tensor<Scalar, 3>();
auto Ty = in_y.tensor<Scalar, 3>();
@@ -94,7 +94,8 @@ struct ParallelMatMulKernel {
const bool should_bcast = bcast.IsBroadcastingRequired();
const auto& x_batch_indices = bcast.x_batch_indices();
const auto& y_batch_indices = bcast.y_batch_indices();
for (int64 i = start; i < limit; ++i) {
// TODO(rmlarsen): Consider launching these contractions asynchronously.
for (int64 i = 0; i < batch_size; ++i) {
const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
@@ -121,25 +122,32 @@ struct ParallelMatMulKernel<Scalar, false> {
static void Run(const OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
bool trans_y, const MatMulBCast& bcast, Tensor* out,
int start, int limit) {
auto Tx = in_x.tensor<Scalar, 3>();
auto Ty = in_y.tensor<Scalar, 3>();
auto Tz = out->tensor<Scalar, 3>();
int batch_size) {
const bool should_bcast = bcast.IsBroadcastingRequired();
const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
if (batch_size == 1 && !should_bcast) {
auto Tx = in_x.flat_inner_dims<Scalar, 2>();
auto Ty = in_y.flat_inner_dims<Scalar, 2>();
auto Tz = out->flat_inner_dims<Scalar, 2>();
Tz.device(d) = Tx.contract(Ty, contract_pairs);
} else {
auto Tx = in_x.tensor<Scalar, 3>();
auto Ty = in_y.tensor<Scalar, 3>();
auto Tz = out->tensor<Scalar, 3>();
const auto& x_batch_indices = bcast.x_batch_indices();
const auto& y_batch_indices = bcast.y_batch_indices();
// TODO(rmlarsen): Consider launching these contractions asynchronously.
for (int64 i = 0; i < batch_size; ++i) {
const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
auto x = Tx.template chip<0>(x_batch_index);
auto y = Ty.template chip<0>(y_batch_index);
auto z = Tz.template chip<0>(i);
const bool should_bcast = bcast.IsBroadcastingRequired();
const auto& x_batch_indices = bcast.x_batch_indices();
const auto& y_batch_indices = bcast.y_batch_indices();
for (int64 i = start; i < limit; ++i) {
const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
auto x = Tx.template chip<0>(x_batch_index);
auto y = Ty.template chip<0>(y_batch_index);
auto z = Tz.template chip<0>(i);
z.device(d) = x.contract(y, contract_pairs);
z.device(d) = x.contract(y, contract_pairs);
}
}
}
};
@@ -234,13 +242,15 @@ struct LaunchBatchMatMul<CPUDevice, Scalar> {
// Jan 21, 2020.
const int64 kMaxCostOuterParallelism = 128 * 128; // heuristic.
auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
// TODO(rmlarsen): Reconsider the heuristics now that we have asynchronous
// evaluation in Eigen Tensor.
if (small_dim > 1 &&
(batch_size == 1 || cost_per_unit > kMaxCostOuterParallelism)) {
// Parallelize over inner dims.
// For large matrix products it is counter-productive to parallelize
// over the batch dimension.
ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, trans_x,
trans_y, bcast, out, 0, batch_size);
trans_y, bcast, out, batch_size);
conjugate_result = adj_x;
} else {
// Parallelize over outer dims. For small matrices and large batches, it
@@ -656,7 +666,11 @@ class BaseBatchMatMulOp : public OpKernel {
const Tensor& in0 = ctx->input(0);
const Tensor& in1 = ctx->input(1);
ValidateInputTensors(ctx, in0, in1);
const Status s = ValidateInputTensors(ctx, in0, in1);
if (!s.ok()) {
ctx->SetStatus(s);
return;
}
MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
OP_REQUIRES(
@@ -740,8 +754,8 @@ class BaseBatchMatMulOp : public OpKernel {
}
protected:
virtual void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
const Tensor& in1) = 0;
virtual Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
const Tensor& in1) = 0;
private:
// TODO(171979567) Make the ops take both adj and transpose attributes.
@@ -761,31 +775,36 @@ class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> {
~BatchMatMulOp() override {}
private:
void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
const Tensor& in1) override {
Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
const Tensor& in1) override {
// Disallow broadcasting support. Ensure that all batch dimensions of the
// input tensors match.
OP_REQUIRES(ctx, in0.dims() == in1.dims(),
errors::InvalidArgument("In[0] and In[1] has different ndims: ",
in0.shape().DebugString(), " vs. ",
in1.shape().DebugString()));
if (in0.dims() != in1.dims()) {
return errors::InvalidArgument(
"In[0] and In[1] has different ndims: ", in0.shape().DebugString(),
" vs. ", in1.shape().DebugString());
}
const int ndims = in0.dims();
if (is_legacy_matmul) {
OP_REQUIRES(ctx, ndims == 2,
errors::InvalidArgument(
"In[0] and In[1] ndims must be == 2: ", ndims));
if (ndims != 2) {
return errors::InvalidArgument("In[0] and In[1] ndims must be == 2: ",
ndims);
}
} else {
OP_REQUIRES(ctx, ndims >= 2,
errors::InvalidArgument(
"In[0] and In[1] ndims must be >= 2: ", ndims));
if (ndims < 2) {
return errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ",
ndims);
}
for (int i = 0; i < ndims - 2; ++i) {
OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
errors::InvalidArgument(
"In[0].dim(", i, ") and In[1].dim(", i,
") must be the same: ", in0.shape().DebugString(),
" vs ", in1.shape().DebugString()));
if (in0.dim_size(i) != in1.dim_size(i)) {
return errors::InvalidArgument(
"In[0].dim(", i, ") and In[1].dim(", i,
") must be the same: ", in0.shape().DebugString(), " vs ",
in1.shape().DebugString());
}
}
}
return Status::OK();
}
};
@@ -800,16 +819,17 @@ class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
~BatchMatMulV2Op() override {}
private:
void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
const Tensor& in1) override {
Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
const Tensor& in1) override {
// Enable broadcasting support. Validity of broadcasting is checked in
// BaseBatchMatMulOp.
OP_REQUIRES(
ctx, in0.dims() >= 2,
errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));
OP_REQUIRES(
ctx, in1.dims() >= 2,
errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims()));
if (in0.dims() < 2) {
return errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims());
}
if (in1.dims() < 2) {
return errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims());
}
return Status::OK();
}
};
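Beyond the signature change above, the reason for returning Status is worth spelling out: OP_REQUIRES only returns from the function in which the macro is expanded, so a failing check inside a void ValidateInputTensors helper never stopped the caller, which carried on with unvalidated inputs. The new code returns a Status and the caller bails out explicitly. A self-contained sketch of the two patterns, using simplified stand-ins rather than TensorFlow's real OpKernelContext/Status/OP_REQUIRES:

#include <iostream>
#include <string>
#include <utility>

// Simplified stand-ins, for illustration only.
struct Status {
  bool ok_ = true;
  std::string msg;
  bool ok() const { return ok_; }
  static Status OK() { return {}; }
  static Status Invalid(std::string m) { return {false, std::move(m)}; }
};

struct Context {
  Status status = Status::OK();
  void SetStatus(const Status& s) { status = s; }
};

// Mimics the OP_REQUIRES pattern: record the error and return -- but only
// from the function the macro is expanded in.
#define REQUIRES(ctx, cond, bad_status) \
  do {                                  \
    if (!(cond)) {                      \
      (ctx)->SetStatus(bad_status);     \
      return;                           \
    }                                   \
  } while (0)

void ValidateVoidStyle(Context* ctx, int ndims) {
  REQUIRES(ctx, ndims >= 2, Status::Invalid("ndims must be >= 2"));
}

Status ValidateStatusStyle(int ndims) {
  if (ndims < 2) return Status::Invalid("ndims must be >= 2");
  return Status::OK();
}

void ComputeVoidStyle(Context* ctx) {
  ValidateVoidStyle(ctx, /*ndims=*/1);
  // Still running even though validation failed; the caller has to remember
  // to re-check ctx->status before touching the inputs.
  std::cout << "void style keeps going, status ok? " << ctx->status.ok() << "\n";
}

void ComputeStatusStyle(Context* ctx) {
  const Status s = ValidateStatusStyle(/*ndims=*/1);
  if (!s.ok()) {
    ctx->SetStatus(s);
    return;  // stops immediately, as in the new BaseBatchMatMulOp code path
  }
  std::cout << "status style: not reached on failure\n";
}

int main() {
  Context c1, c2;
  ComputeVoidStyle(&c1);
  ComputeStatusStyle(&c2);
  return 0;
}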


@@ -99,7 +99,6 @@ filegroup(
"guarded_philox_random.h",
"matmul_autotune.cc",
"matmul_autotune.h",
"matmul_bcast.cc",
"matmul_bcast.h",
"mirror_pad_mode.cc",
"mirror_pad_mode.h",
@@ -221,7 +220,6 @@ filegroup(
"example_proto_helper.cc",
"guarded_philox_random.cc",
"matmul_autotune.cc",
"matmul_bcast.cc",
"mirror_pad_mode.cc",
"saved_tensor_slice_util.cc",
"stat_summarizer.cc",


@@ -1,42 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/matmul_bcast.h"
namespace tensorflow {
MatMulBCast::MatMulBCast(Vec x, Vec y) {
if (x.size() < 2 || y.size() < 2) return;
x.resize(x.size() - 2);
y.resize(y.size() - 2);
batch_bcast_ = absl::make_unique<BCast>(std::move(x), std::move(y));
if (!batch_bcast_->IsValid()) return;
x_batch_size_ = TensorShape(batch_bcast_->x_reshape()).num_elements();
y_batch_size_ = TensorShape(batch_bcast_->y_reshape()).num_elements();
output_shape_ = TensorShape(batch_bcast_->output_shape());
output_batch_size_ = output_shape_.num_elements();
broadcasting_required_ =
std::min(x_batch_size_, y_batch_size_) != output_batch_size_;
if (broadcasting_required_) {
ComputeBatchIndices(output_batch_size_, batch_bcast_->x_reshape(),
batch_bcast_->x_bcast(), &x_batch_indices_);
ComputeBatchIndices(output_batch_size_, batch_bcast_->y_reshape(),
batch_bcast_->y_bcast(), &y_batch_indices_);
}
}
} // namespace tensorflow


@@ -26,20 +26,50 @@ namespace tensorflow {
// Simple wrapper over BCast specialized for MatMul.
// Provides utilities for broadcasting across batch dimensions for binary
// MatMul-like operations.
// MatMul-like operations. If neither argument has batch dimensions (rank <= 2)
// then no broadcasting is needed and the MatMul operation is considered valid.
class MatMulBCast {
public:
using Vec = BCast::Vec;
MatMulBCast(Vec x, Vec y);
MatMulBCast(const Vec& x, const Vec& y) {
if (std::max(x.size(), y.size()) == 2) return;
const Vec x_resized(x.begin(), x.end() - 2);
const Vec y_resized(y.begin(), y.end() - 2);
bool IsValid() const { return batch_bcast_ && batch_bcast_->IsValid(); }
batch_bcast_ =
absl::make_unique<BCast>(std::move(x_resized), std::move(y_resized));
if (!batch_bcast_->IsValid()) {
// Set broadcasting_required_ to true to make IsValid() return false.
broadcasting_required_ = true;
return;
}
x_batch_size_ = TensorShape(batch_bcast_->x_reshape()).num_elements();
y_batch_size_ = TensorShape(batch_bcast_->y_reshape()).num_elements();
output_batch_shape_ = TensorShape(batch_bcast_->output_shape());
output_batch_size_ = output_batch_shape_.num_elements();
broadcasting_required_ =
std::min(x_batch_size_, y_batch_size_) != output_batch_size_;
if (broadcasting_required_) {
ComputeBatchIndices(output_batch_size_, batch_bcast_->x_reshape(),
batch_bcast_->x_bcast(), &x_batch_indices_);
ComputeBatchIndices(output_batch_size_, batch_bcast_->y_reshape(),
batch_bcast_->y_bcast(), &y_batch_indices_);
}
}
bool IsValid() const {
return !broadcasting_required_ || (batch_bcast_ && batch_bcast_->IsValid());
}
bool IsBroadcastingRequired() const { return broadcasting_required_; }
const int64 output_batch_size() const { return output_batch_size_; }
const int64 x_batch_size() const { return x_batch_size_; }
const int64 y_batch_size() const { return y_batch_size_; }
const TensorShape& output_batch_shape() const { return output_shape_; }
const TensorShape& output_batch_shape() const { return output_batch_shape_; }
// Returns the mapping from the flattened output batch indices to x's
// flattened batch indices. The result is a vector of length
@@ -57,10 +87,10 @@ class MatMulBCast {
private:
std::unique_ptr<BCast> batch_bcast_;
bool broadcasting_required_ = false;
int64 x_batch_size_;
int64 y_batch_size_;
TensorShape output_shape_;
int64 output_batch_size_;
int64 x_batch_size_ = 1;
int64 y_batch_size_ = 1;
TensorShape output_batch_shape_;
int64 output_batch_size_ = 1;
std::vector<int64> x_batch_indices_;
std::vector<int64> y_batch_indices_;
};
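With matmul_bcast.cc gone, MatMulBCast is header-only and the constructor runs entirely inline. As a usage illustration (not part of the change; it mirrors the calls in the test file below, and the values in the comments are what the broadcast logic above should produce for these shapes):

#include "tensorflow/core/util/matmul_bcast.h"

void Example() {
  using tensorflow::MatMulBCast;
  // x has batch dims [2, 1] and a 4x5 matrix; y has batch dims [1, 3] and a
  // 5x6 matrix. The batch dims broadcast to [2, 3], i.e. 6 output batches.
  MatMulBCast bcast({2, 1, 4, 5}, {1, 3, 5, 6});
  // bcast.IsValid()                -> true
  // bcast.IsBroadcastingRequired() -> true
  // bcast.output_batch_size()      -> 6
  // bcast.x_batch_indices()        -> {0, 0, 0, 1, 1, 1}
  // bcast.y_batch_indices()        -> {0, 1, 2, 0, 1, 2}
}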


@@ -94,14 +94,18 @@ TEST(MatMulBCastTest, EmptyWithNonEmptyBatchBroadcast) {
EXPECT_EQ("[2][0,1][0,0]", MatMulBCastToStr(bcast2));
}
TEST(MatMulBCastTest, InvalidDimensions) {
// Too few dimensions.
TEST(MatMulBCastTest, NoBatchDimensions) {
MatMulBCast bcast1({3, 3}, {3});
EXPECT_FALSE(bcast1.IsValid());
EXPECT_TRUE(bcast1.IsValid());
MatMulBCast bcast2({3}, {3, 3});
EXPECT_FALSE(bcast2.IsValid());
EXPECT_TRUE(bcast2.IsValid());
MatMulBCast bcast3({3, 3}, {3, 3});
EXPECT_TRUE(bcast3.IsValid());
}
TEST(MatMulBCastTest, InvalidDimensions) {
// Batch dimensions not broadcastable.
MatMulBCast bcast3({4, 5, 3}, {2, 3, 7});
EXPECT_FALSE(bcast3.IsValid());