Fix regression in (non-batch) matrix multiplication:
1. Introduce a specialization for real non-batch calls to the Eigen contraction kernel.
2. Reduce overhead in Tensor::CopyFromInternal and inline it.
3. Reduce overhead in the constructor of MatMulBCast and inline it.

PiperOrigin-RevId: 342963388
Change-Id: I21fbac97ce7b47a8c585ae8399670d31d53bd108
Parent: e87fba482e
Commit: 347d28b0ce
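Change (1) is the core of the regression fix: when there are no batch dimensions and no broadcasting, the CPU kernel now issues a single rank-2 Eigen contraction instead of driving the batched rank-3 code path. As a minimal standalone sketch (plain Eigen, not TensorFlow code; names are illustrative), a one-index-pair contraction of two rank-2 tensors is exactly an ordinary matrix product:

// eigen_matmul_sketch.cc -- illustrative only, not part of the commit.
#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 2> a(2, 3);  // M x K
  Eigen::Tensor<float, 2> b(3, 4);  // K x N
  a.setRandom();
  b.setRandom();

  // Contract a's dimension 1 with b's dimension 0:
  // c(m, n) = sum_k a(m, k) * b(k, n).
  Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> pairs = {
      Eigen::IndexPair<Eigen::DenseIndex>(1, 0)};
  Eigen::Tensor<float, 2> c = a.contract(b, pairs);

  std::cout << "result is " << c.dimension(0) << " x " << c.dimension(1)
            << std::endl;  // 2 x 4
  return 0;
}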
@@ -675,20 +675,6 @@ void Tensor::CheckIsAlignedAndSingleElement() const {
 
 Tensor::~Tensor() { UnrefIfNonNull(buf_); }
 
-void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) {
-  CHECK_EQ(shape.num_elements(), other.NumElements());
-  // Data type will be overwritten if this == &other, since dtype is part of
-  // shape.
-  DataType other_dtype = other.dtype();
-  shape_ = shape;
-  set_dtype(other_dtype);
-  if (buf_ != other.buf_) {
-    UnrefIfNonNull(buf_);
-    buf_ = other.buf_;
-    RefIfNonNull(buf_);
-  }
-}
-
 Status Tensor::BitcastFrom(const Tensor& other, DataType dtype,
                            const TensorShape& shape) {
   int in_size = DataTypeSize(other.dtype());
@@ -697,7 +697,19 @@ class Tensor {
     set_dtype(dt);
   }
 
-  void CopyFromInternal(const Tensor& other, const TensorShape& shape);
+  inline void CopyFromInternal(const Tensor& other, const TensorShape& shape) {
+    DCHECK_EQ(shape.num_elements(), other.NumElements());
+    // Data type will be overwritten if this == &other, since dtype is part of
+    // shape.
+    DataType other_dtype = other.dtype();
+    shape_ = shape;
+    set_dtype(other_dtype);
+    if (buf_ != other.buf_) {
+      if (buf_) buf_->Unref();
+      buf_ = other.buf_;
+      if (buf_) buf_->Ref();
+    }
+  }
 
   template <typename T>
   T* base() const;
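The two hunks above move CopyFromInternal into the header so it can inline, demote the element-count CHECK to a DCHECK (which compiles away in optimized builds), and spell out the former UnrefIfNonNull/RefIfNonNull helpers as explicit null checks. The sketch below is a simplified, self-contained model of that buffer-sharing logic; the Buffer and Holder classes are hypothetical stand-ins, not TensorFlow's TensorBuffer and Tensor.

// refcount_sketch.cc -- simplified illustration only.
#include <cassert>
#include <iostream>

class Buffer {
 public:
  void Ref() { ++count_; }
  void Unref() {
    if (--count_ == 0) delete this;
  }
  int count() const { return count_; }

 private:
  int count_ = 1;  // the creator holds the first reference
};

class Holder {
 public:
  ~Holder() {
    if (buf_) buf_->Unref();
  }
  // Mirrors the inlined CopyFromInternal: adopt the other holder's buffer,
  // dropping one reference on ours and taking one on theirs. The pointer
  // comparison makes self-assignment a no-op for the refcounts.
  void CopyFrom(const Holder& other) {
    if (buf_ != other.buf_) {
      if (buf_) buf_->Unref();
      buf_ = other.buf_;
      if (buf_) buf_->Ref();
    }
  }
  void Reset(Buffer* b) { buf_ = b; }
  Buffer* buf() const { return buf_; }

 private:
  Buffer* buf_ = nullptr;
};

int main() {
  Holder a, b;
  a.Reset(new Buffer);  // refcount 1
  b.CopyFrom(a);        // both holders share the buffer, refcount 2
  assert(a.buf() == b.buf());
  std::cout << "shared refcount: " << a.buf()->count() << std::endl;  // 2
  return 0;
}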
@@ -77,7 +77,7 @@ struct ParallelMatMulKernel {
   static void Run(const OpKernelContext* context, const Tensor& in_x,
                   const Tensor in_y, bool adj_x, bool adj_y, bool trans_x,
                   bool trans_y, const MatMulBCast& bcast, Tensor* out,
-                  int start, int limit) {
+                  int batch_size) {
     static_assert(IsComplex, "Complex type expected.");
     auto Tx = in_x.tensor<Scalar, 3>();
     auto Ty = in_y.tensor<Scalar, 3>();
@@ -94,7 +94,8 @@ struct ParallelMatMulKernel {
     const bool should_bcast = bcast.IsBroadcastingRequired();
     const auto& x_batch_indices = bcast.x_batch_indices();
     const auto& y_batch_indices = bcast.y_batch_indices();
-    for (int64 i = start; i < limit; ++i) {
+    // TODO(rmlarsen): Consider launching these contractions asynchronously.
+    for (int64 i = 0; i < batch_size; ++i) {
       const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
       const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
 
@@ -121,18 +122,24 @@ struct ParallelMatMulKernel<Scalar, false> {
   static void Run(const OpKernelContext* context, const Tensor& in_x,
                   const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                   bool trans_y, const MatMulBCast& bcast, Tensor* out,
-                  int start, int limit) {
+                  int batch_size) {
+    const bool should_bcast = bcast.IsBroadcastingRequired();
+    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
+    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
+    contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
+    if (batch_size == 1 && !should_bcast) {
+      auto Tx = in_x.flat_inner_dims<Scalar, 2>();
+      auto Ty = in_y.flat_inner_dims<Scalar, 2>();
+      auto Tz = out->flat_inner_dims<Scalar, 2>();
+      Tz.device(d) = Tx.contract(Ty, contract_pairs);
+    } else {
     auto Tx = in_x.tensor<Scalar, 3>();
     auto Ty = in_y.tensor<Scalar, 3>();
     auto Tz = out->tensor<Scalar, 3>();
-    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
-    contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
-    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
-
-    const bool should_bcast = bcast.IsBroadcastingRequired();
     const auto& x_batch_indices = bcast.x_batch_indices();
     const auto& y_batch_indices = bcast.y_batch_indices();
-    for (int64 i = start; i < limit; ++i) {
+    // TODO(rmlarsen): Consider launching these contractions asynchronously.
+    for (int64 i = 0; i < batch_size; ++i) {
       const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
       const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
       auto x = Tx.template chip<0>(x_batch_index);
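In both Run specializations the adjoint/transpose flags are folded into the single contraction index pair via ContractionDims(adj_x || trans_x, adj_y || trans_y), so no operand is ever materialized in transposed form. The standalone Eigen sketch below assumes that mapping (contract dimension 0 of a transposed-stored operand instead of dimension 1); it is an illustration, not a copy of the helper in this file:

// contraction_dims_sketch.cc -- illustrative only.
#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

using IndexPair = Eigen::IndexPair<Eigen::DenseIndex>;

int main() {
  const int M = 2, K = 3, N = 4;
  Eigen::Tensor<float, 2> x_t(K, M);  // x stored transposed: op(x) is M x K
  Eigen::Tensor<float, 2> y(K, N);    // y stored plain: K x N
  x_t.setRandom();
  y.setRandom();

  // op(x) * y without materializing the transpose: contract x_t's dim 0 (K)
  // against y's dim 0 (K).
  Eigen::array<IndexPair, 1> pairs = {IndexPair(0, 0)};
  Eigen::Tensor<float, 2> z = x_t.contract(y, pairs);

  // Reference: explicitly transpose x_t, then do the plain (1, 0) contraction.
  Eigen::array<int, 2> shuffle = {1, 0};
  Eigen::Tensor<float, 2> x = x_t.shuffle(shuffle);
  Eigen::array<IndexPair, 1> plain = {IndexPair(1, 0)};
  Eigen::Tensor<float, 2> z_ref = x.contract(y, plain);

  Eigen::Tensor<float, 0> max_abs_diff = (z - z_ref).abs().maximum();
  std::cout << "max |diff| = " << max_abs_diff() << std::endl;  // ~0
  return 0;
}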
@@ -142,6 +149,7 @@ struct ParallelMatMulKernel<Scalar, false> {
       z.device(d) = x.contract(y, contract_pairs);
     }
   }
+  }
 };
 
 // Sequential batch matmul kernel that calls the regular Eigen matmul.
@@ -234,13 +242,15 @@ struct LaunchBatchMatMul<CPUDevice, Scalar> {
     // Jan 21, 2020.
     const int64 kMaxCostOuterParallelism = 128 * 128;  // heuristic.
     auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+    // TODO(rmlarsen): Reconsider the heuristics now that we have asynchronous
+    // evaluation in Eigen Tensor.
     if (small_dim > 1 &&
         (batch_size == 1 || cost_per_unit > kMaxCostOuterParallelism)) {
       // Parallelize over inner dims.
       // For large matrix products it is counter-productive to parallelize
       // over the batch dimension.
       ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, trans_x,
-                                trans_y, bcast, out, 0, batch_size);
+                                trans_y, bcast, out, batch_size);
       conjugate_result = adj_x;
     } else {
       // Parallelize over outer dims. For small matrices and large batches, it
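The heuristic in the LaunchBatchMatMul hunk above chooses between parallelizing a single contraction (inner dims) and sharding work across the batch; with the new fast path, a batch of one always lands on the inner-dims branch. Below is a restatement of that condition as a tiny standalone helper for illustration; the exact definition of cost_per_unit comes from surrounding code that is not shown in this diff.

// heuristic_sketch.cc -- illustrative restatement only.
#include <cstdint>
#include <iostream>

bool ParallelizeOverInnerDims(int64_t small_dim, int64_t batch_size,
                              int64_t cost_per_unit) {
  const int64_t kMaxCostOuterParallelism = 128 * 128;  // heuristic
  return small_dim > 1 &&
         (batch_size == 1 || cost_per_unit > kMaxCostOuterParallelism);
}

int main() {
  // A single large matmul: parallelize the contraction itself.
  std::cout << ParallelizeOverInnerDims(512, 1, 512 * 512) << std::endl;  // 1
  // Many small matmuls: shard across the batch instead.
  std::cout << ParallelizeOverInnerDims(32, 64, 32 * 32) << std::endl;    // 0
  return 0;
}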
@@ -656,7 +666,11 @@ class BaseBatchMatMulOp : public OpKernel {
     const Tensor& in0 = ctx->input(0);
     const Tensor& in1 = ctx->input(1);
 
-    ValidateInputTensors(ctx, in0, in1);
+    const Status s = ValidateInputTensors(ctx, in0, in1);
+    if (!s.ok()) {
+      ctx->SetStatus(s);
+      return;
+    }
 
     MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
     OP_REQUIRES(
@@ -740,7 +754,7 @@ class BaseBatchMatMulOp : public OpKernel {
   }
 
  protected:
-  virtual void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
+  virtual Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                                     const Tensor& in1) = 0;
 
  private:
@@ -761,32 +775,37 @@ class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> {
   ~BatchMatMulOp() override {}
 
  private:
-  void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
+  Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                             const Tensor& in1) override {
     // Disallow broadcasting support. Ensure that all batch dimensions of the
     // input tensors match.
-    OP_REQUIRES(ctx, in0.dims() == in1.dims(),
-                errors::InvalidArgument("In[0] and In[1] has different ndims: ",
-                                        in0.shape().DebugString(), " vs. ",
-                                        in1.shape().DebugString()));
+    if (in0.dims() != in1.dims()) {
+      return errors::InvalidArgument(
+          "In[0] and In[1] has different ndims: ", in0.shape().DebugString(),
+          " vs. ", in1.shape().DebugString());
+    }
     const int ndims = in0.dims();
     if (is_legacy_matmul) {
-      OP_REQUIRES(ctx, ndims == 2,
-                  errors::InvalidArgument(
-                      "In[0] and In[1] ndims must be == 2: ", ndims));
+      if (ndims != 2) {
+        return errors::InvalidArgument("In[0] and In[1] ndims must be == 2: ",
+                                       ndims);
+      }
     } else {
-      OP_REQUIRES(ctx, ndims >= 2,
-                  errors::InvalidArgument(
-                      "In[0] and In[1] ndims must be >= 2: ", ndims));
+      if (ndims < 2) {
+        return errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ",
+                                       ndims);
+      }
       for (int i = 0; i < ndims - 2; ++i) {
-        OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
-                    errors::InvalidArgument(
+        if (in0.dim_size(i) != in1.dim_size(i)) {
+          return errors::InvalidArgument(
               "In[0].dim(", i, ") and In[1].dim(", i,
-              ") must be the same: ", in0.shape().DebugString(),
-              " vs ", in1.shape().DebugString()));
+              ") must be the same: ", in0.shape().DebugString(), " vs ",
+              in1.shape().DebugString());
        }
      }
    }
+    return Status::OK();
+  }
 };
 
 // BatchMatMul Op implementation with broadcasting support.
@@ -800,16 +819,17 @@ class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
   ~BatchMatMulV2Op() override {}
 
  private:
-  void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
+  Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                             const Tensor& in1) override {
     // Enable broadcasting support. Validity of broadcasting is checked in
     // BaseBatchMatMulOp.
-    OP_REQUIRES(
-        ctx, in0.dims() >= 2,
-        errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));
-    OP_REQUIRES(
-        ctx, in1.dims() >= 2,
-        errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims()));
+    if (in0.dims() < 2) {
+      return errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims());
+    }
+    if (in1.dims() < 2) {
+      return errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims());
+    }
+    return Status::OK();
   }
 };
 
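The ValidateInputTensors refactor above works around the fact that OP_REQUIRES expands to a bare `return;` on failure and therefore only fits functions returning void; returning a Status lets the base op forward the error with ctx->SetStatus and bail out, as the BaseBatchMatMulOp hunk shows. Below is a compilable sketch of the same pattern, with hypothetical names and include paths assumed to match this era of the tree:

// status_validation_sketch.cc -- illustrative only.
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"

namespace example {

tensorflow::Status ValidateRank(int ndims) {
  if (ndims < 2) {
    return tensorflow::errors::InvalidArgument("ndims must be >= 2: ", ndims);
  }
  return tensorflow::Status::OK();
}

// A caller that cannot itself return Status converts the failure explicitly,
// mirroring the `ctx->SetStatus(s); return;` pattern in the hunk above.
bool ValidateOrReport(int ndims, tensorflow::Status* out_status) {
  const tensorflow::Status s = ValidateRank(ndims);
  if (!s.ok()) {
    *out_status = s;
    return false;
  }
  return true;
}

}  // namespace example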
@@ -99,7 +99,6 @@ filegroup(
         "guarded_philox_random.h",
         "matmul_autotune.cc",
         "matmul_autotune.h",
-        "matmul_bcast.cc",
         "matmul_bcast.h",
         "mirror_pad_mode.cc",
         "mirror_pad_mode.h",
@@ -221,7 +220,6 @@ filegroup(
         "example_proto_helper.cc",
         "guarded_philox_random.cc",
         "matmul_autotune.cc",
-        "matmul_bcast.cc",
         "mirror_pad_mode.cc",
         "saved_tensor_slice_util.cc",
         "stat_summarizer.cc",
@@ -1,42 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/util/matmul_bcast.h"
-
-namespace tensorflow {
-MatMulBCast::MatMulBCast(Vec x, Vec y) {
-  if (x.size() < 2 || y.size() < 2) return;
-  x.resize(x.size() - 2);
-  y.resize(y.size() - 2);
-
-  batch_bcast_ = absl::make_unique<BCast>(std::move(x), std::move(y));
-  if (!batch_bcast_->IsValid()) return;
-
-  x_batch_size_ = TensorShape(batch_bcast_->x_reshape()).num_elements();
-  y_batch_size_ = TensorShape(batch_bcast_->y_reshape()).num_elements();
-  output_shape_ = TensorShape(batch_bcast_->output_shape());
-  output_batch_size_ = output_shape_.num_elements();
-  broadcasting_required_ =
-      std::min(x_batch_size_, y_batch_size_) != output_batch_size_;
-
-  if (broadcasting_required_) {
-    ComputeBatchIndices(output_batch_size_, batch_bcast_->x_reshape(),
-                        batch_bcast_->x_bcast(), &x_batch_indices_);
-    ComputeBatchIndices(output_batch_size_, batch_bcast_->y_reshape(),
-                        batch_bcast_->y_bcast(), &y_batch_indices_);
-  }
-}
-
-}  // namespace tensorflow
@@ -26,20 +26,50 @@ namespace tensorflow {
 
 // Simple wrapper over BCast specialized for MatMul.
 // Provides utilities for broadcasting across batch dimensions for binary
-// MatMul-like operations.
+// MatMul-like operations. If neither argument has batch dimensions (rank <= 2)
+// then no broadcasting is needed and the operation MatMul operation is
+// considered valid.
 class MatMulBCast {
  public:
   using Vec = BCast::Vec;
 
-  MatMulBCast(Vec x, Vec y);
+  MatMulBCast(const Vec& x, const Vec& y) {
+    if (std::max(x.size(), y.size()) == 2) return;
+    const Vec x_resized(x.begin(), x.end() - 2);
+    const Vec y_resized(y.begin(), y.end() - 2);
 
-  bool IsValid() const { return batch_bcast_ && batch_bcast_->IsValid(); }
+    batch_bcast_ =
+        absl::make_unique<BCast>(std::move(x_resized), std::move(y_resized));
+    if (!batch_bcast_->IsValid()) {
+      // Set broadcasting_required_ to true to make IsValid() return false;
+      broadcasting_required_ = true;
+      return;
+    }
+
+    x_batch_size_ = TensorShape(batch_bcast_->x_reshape()).num_elements();
+    y_batch_size_ = TensorShape(batch_bcast_->y_reshape()).num_elements();
+    output_batch_shape_ = TensorShape(batch_bcast_->output_shape());
+    output_batch_size_ = output_batch_shape_.num_elements();
+    broadcasting_required_ =
+        std::min(x_batch_size_, y_batch_size_) != output_batch_size_;
+
+    if (broadcasting_required_) {
+      ComputeBatchIndices(output_batch_size_, batch_bcast_->x_reshape(),
+                          batch_bcast_->x_bcast(), &x_batch_indices_);
+      ComputeBatchIndices(output_batch_size_, batch_bcast_->y_reshape(),
+                          batch_bcast_->y_bcast(), &y_batch_indices_);
+    }
+  }
+
+  bool IsValid() const {
+    return !broadcasting_required_ || (batch_bcast_ && batch_bcast_->IsValid());
+  }
   bool IsBroadcastingRequired() const { return broadcasting_required_; }
 
   const int64 output_batch_size() const { return output_batch_size_; }
   const int64 x_batch_size() const { return x_batch_size_; }
   const int64 y_batch_size() const { return y_batch_size_; }
-  const TensorShape& output_batch_shape() const { return output_shape_; }
+  const TensorShape& output_batch_shape() const { return output_batch_shape_; }
 
   // Returns the mapping from the flattened output batch indices to x's
   // flattened batch indices. The result is a vector of length
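With the constructor now inline in matmul_bcast.h (and matmul_bcast.cc deleted), a rank-2 call constructs a trivially valid MatMulBCast with batch size 1, which is what the new non-batch fast path relies on. A small usage sketch, mirroring the style of the updated unit test:

// matmul_bcast_usage_sketch.cc -- illustrative only.
#include <iostream>

#include "tensorflow/core/util/matmul_bcast.h"

int main() {
  using tensorflow::MatMulBCast;

  // Plain MatMul: [M, K] x [K, N]. No batch dims, so the constructor returns
  // early; IsValid() is true and no batch index tables are built.
  MatMulBCast plain({2, 3}, {3, 4});
  std::cout << plain.IsValid() << " "                 // 1
            << plain.IsBroadcastingRequired() << " "  // 0
            << plain.output_batch_size() << std::endl;  // 1

  // Broadcast batch dims: [5, 1, M, K] x [1, 7, K, N] -> 35 output matrices.
  MatMulBCast bcast({5, 1, 2, 3}, {1, 7, 3, 4});
  std::cout << bcast.IsValid() << " "                 // 1
            << bcast.IsBroadcastingRequired() << " "  // 1
            << bcast.output_batch_size() << std::endl;  // 35
  return 0;
}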
@@ -57,10 +87,10 @@ class MatMulBCast {
  private:
   std::unique_ptr<BCast> batch_bcast_;
   bool broadcasting_required_ = false;
-  int64 x_batch_size_;
-  int64 y_batch_size_;
-  TensorShape output_shape_;
-  int64 output_batch_size_;
+  int64 x_batch_size_ = 1;
+  int64 y_batch_size_ = 1;
+  TensorShape output_batch_shape_;
+  int64 output_batch_size_ = 1;
   std::vector<int64> x_batch_indices_;
   std::vector<int64> y_batch_indices_;
 };
@@ -94,14 +94,18 @@ TEST(MatMulBCastTest, EmptyWithNonEmptyBatchBroadcast) {
   EXPECT_EQ("[2][0,1][0,0]", MatMulBCastToStr(bcast2));
 }
 
-TEST(MatMulBCastTest, InvalidDimensions) {
-  // Too few dimensions.
+TEST(MatMulBCastTest, NoBathcDimensions) {
   MatMulBCast bcast1({3, 3}, {3});
-  EXPECT_FALSE(bcast1.IsValid());
+  EXPECT_TRUE(bcast1.IsValid());
 
   MatMulBCast bcast2({3}, {3, 3});
-  EXPECT_FALSE(bcast2.IsValid());
+  EXPECT_TRUE(bcast2.IsValid());
 
+  MatMulBCast bcast3({3, 3}, {3, 3});
+  EXPECT_TRUE(bcast3.IsValid());
+}
+
+TEST(MatMulBCastTest, InvalidDimensions) {
   // Batch dimensions not broadcastable.
   MatMulBCast bcast3({4, 5, 3}, {2, 3, 7});
   EXPECT_FALSE(bcast3.IsValid());