Add broadcasting support to tf.matmul.

Add NumPy-style broadcasting in the batch dimensions for the tf.matmul op. The last two dimensions of each operand form the matrix dimensions; the dimensions before them are broadcast to a common output shape following the standard NumPy broadcasting rules (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
Note: This implementation differs from NumPy's behavior in that vectors (rank-1 Tensors) are not promoted to matrices (rank-2 Tensors) by appending/prepending dimensions.
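A minimal usage sketch of the behavior described above (assumes a TensorFlow build that includes this change; the shapes are made up for illustration):

    import numpy as np
    import tensorflow as tf

    # Batch shapes [5, 1] and [4] broadcast to [5, 4]; the trailing two
    # dimensions of each operand are matrix-multiplied as usual.
    x = tf.constant(np.random.rand(5, 1, 2, 3), dtype=tf.float32)
    y = tf.constant(np.random.rand(4, 3, 7), dtype=tf.float32)
    z = tf.matmul(x, y)  # shape [5, 4, 2, 7]

    # Unlike np.matmul, rank-1 operands are not promoted to matrices, so
    # both inputs must be at least rank 2.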
PiperOrigin-RevId: 241040476
Anudhyan Boral 2019-03-29 13:30:57 -07:00 committed by TensorFlower Gardener
parent 51dbafd723
commit 47ab68d265
19 changed files with 1093 additions and 182 deletions


@ -8,6 +8,7 @@ tensorflow/contrib/boosted_trees/ops/training_ops.cc
tensorflow/core/kernels/aggregate_ops.cc
tensorflow/core/kernels/argmax_op.cc
tensorflow/core/kernels/avgpooling_op.cc
tensorflow/core/kernels/batch_matmul_op_common.cc
tensorflow/core/kernels/batch_matmul_op_real.cc
tensorflow/core/kernels/batch_norm_op.cc
tensorflow/core/kernels/batchtospace_op.cc


@ -0,0 +1,59 @@
op {
graph_op_name: "BatchMatMulV2"
in_arg {
name: "x"
description: <<END
2-D or higher with shape `[..., r_x, c_x]`.
END
}
in_arg {
name: "y"
description: <<END
2-D or higher with shape `[..., r_y, c_y]`.
END
}
out_arg {
name: "output"
description: <<END
2-D or higher with shape `[..., r_o, c_o]`.
END
}
attr {
name: "adj_x"
description: <<END
If `True`, adjoint the slices of `x`. Defaults to `False`.
END
}
attr {
name: "adj_y"
description: <<END
If `True`, adjoint the slices of `y`. Defaults to `False`.
END
}
summary: "Multiplies slices of two tensors in batches."
description: <<END
Multiplies all slices of `Tensor` `x` and `y` (each slice can be
viewed as an element of a batch), and arranges the individual results
in a single output tensor of the same batch size. Each of the
individual slices can optionally be adjointed (to adjoint a matrix
means to transpose and conjugate it) before multiplication by setting
the `adj_x` or `adj_y` flag to `True`; these flags default to `False`.
The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]`
and `[..., r_y, c_y]`.
The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where:
r_o = c_x if adj_x else r_x
c_o = r_y if adj_y else c_y
It is computed as:
output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
*NOTE*: `BatchMatMulV2` supports broadcasting in the batch dimensions. More
about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
END
}
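For the shape rule above, a small NumPy sketch that computes the documented output shape (the helper name is ours and purely illustrative):

    import numpy as np

    def batch_matmul_v2_output_shape(x_shape, y_shape, adj_x=False, adj_y=False):
      # Matrix dimensions of the output, per the rule above.
      r_o = x_shape[-1] if adj_x else x_shape[-2]
      c_o = y_shape[-2] if adj_y else y_shape[-1]
      # Batch dimensions broadcast with standard NumPy rules.
      batch = np.broadcast(np.empty(x_shape[:-2]), np.empty(y_shape[:-2])).shape
      return batch + (r_o, c_o)

    # [3, 1, 5, 3] x [4, 3, 7] -> (3, 4, 5, 7)
    print(batch_matmul_v2_output_shape((3, 1, 5, 3), (4, 3, 7)))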


@ -0,0 +1,4 @@
op {
graph_op_name: "BatchMatMulV2"
visibility: HIDDEN
}


@ -233,6 +233,43 @@ Status MatMulShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) {
ShapeHandle a_shape;
ShapeHandle b_shape;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape));
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));
// Determine output rows and columns.
bool adj_x;
bool adj_y;
TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x));
TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y));
DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1);
// Inner dimensions should be compatible.
DimensionHandle inner_merged;
TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1),
c->Dim(b_shape, adj_y ? -1 : -2), &inner_merged));
// Batch dimensions should broadcast with each other.
ShapeHandle a_batch_shape;
ShapeHandle b_batch_shape;
ShapeHandle output_batch_shape;
TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_shape));
TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_shape));
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
c, a_batch_shape, b_batch_shape, &output_batch_shape));
ShapeHandle output_shape;
TF_RETURN_IF_ERROR(c->Concatenate(
output_batch_shape, c->Matrix(output_rows, output_cols), &output_shape));
c->set_output(0, output_shape);
return Status::OK();
}
Status BiasAddShape(shape_inference::InferenceContext* c) {
ShapeHandle input_shape;


@ -226,6 +226,10 @@ Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
// Shape function for MatMul-like operations.
Status MatMulShape(shape_inference::InferenceContext* c);
// Shape function for Batched MatMul-like operations with broadcasting across
// batch dimensions.
Status BatchMatMulV2Shape(shape_inference::InferenceContext* c);
// Shape function for BiasAdd-like operations.
Status BiasAddShape(shape_inference::InferenceContext* c);


@ -213,6 +213,74 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
}
}
TEST(CommonShapeFnsTest, BatchMatMulV2_ShapeFn) {
ShapeInferenceTestOp op("BatchMatMulV2");
auto set_adj = [&op](bool adj_x, bool adj_y) {
TF_ASSERT_OK(NodeDefBuilder("test", "BatchMatMulV2")
.Input({"a", 0, DT_FLOAT})
.Input({"b", 0, DT_FLOAT})
.Attr("adj_x", adj_x)
.Attr("adj_y", adj_y)
.Finalize(&op.node_def));
};
set_adj(false, false);
// Rank checks.
INFER_ERROR("at least rank 2", op, "[];?");
INFER_ERROR("at least rank 2", op, "[1];?");
INFER_ERROR("at least rank 2", op, "?;[]");
INFER_ERROR("at least rank 2", op, "?;[2]");
INFER_OK(op, "?;?", "?");
// 0 batch dims.
INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");
// 1 batch dims.
INFER_OK(op, "[3,?,?];[3,?,?]", "[d0_0,d0_1,d1_2]");
INFER_OK(op, "[?,?,?];[1,?,?]", "[d0_0,d0_1,d1_2]");
INFER_OK(op, "[?,?,?];[2,?,?]", "[d1_0,d0_1,d1_2]");
INFER_OK(op, "[1,?,?];[?,?,?]", "[d1_0,d0_1,d1_2]");
INFER_OK(op, "[2,?,?];[?,?,?]", "[d0_0,d0_1,d1_2]");
INFER_OK(op, "[?,?,?];[?,?,?]", "[?,d0_1,d1_2]");
// Empty batch dim with broadcasting.
INFER_OK(op, "[?,?];[?,?,?]", "[d1_0,d0_0,d1_2]");
INFER_OK(op, "[?,?,?];[?,?]", "[d0_0,d0_1,d1_1]");
INFER_OK(op, "[?,?];[?,?,?,?]", "[d1_0,d1_1,d0_0,d1_3]");
INFER_OK(op, "[?,?,?,?];[?,?]", "[d0_0,d0_1,d0_2,d1_1]");
// Unknown number of batch dims.
INFER_OK(op, "[?,?];?", "?");
INFER_OK(op, "?;[?,?]", "?");
INFER_OK(op, "[?,?,?,?];?", "?");
// Large number of batch dims.
INFER_OK(op, "[?,?,?,?,?];[1,?,?]", "[d0_0,d0_1,d0_2,d0_3,d1_2]");
INFER_OK(op, "[1,?,?];[?,?,?,?,?]", "[d1_0,d1_1,d1_2,d0_1,d1_4]");
// Batch dim mismatch.
INFER_ERROR("are 2 and 3", op, "[?,?,2,?,?];[3,?,?]");
INFER_ERROR("are 2 and 3", op, "[2,?,?];[?,?,3,?,?]");
// Test adj_x, testing output and that inner dims are compared.
set_adj(false, false);
INFER_OK(op, "[2,2,3,4];[2,2,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,3,1]"); // inner dim mismatch
set_adj(true, false);
INFER_OK(op, "[2,2,3,4];[2,2,?,?]", "[d0_0,d0_1,d0_3,d1_3]");
INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,3,1]"); // inner dim mismatch
// Test adj_y=true.
set_adj(false, true);
INFER_OK(op, "[2,2,?,?];[2,2,3,4]", "[d0_0,d0_1,d0_2,d1_2]");
INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,1,3]"); // inner dim mismatch
set_adj(true, true);
INFER_OK(op, "[2,2,?,?];[2,2,3,4]", "[d0_0,d0_1,d0_3,d1_2]");
INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,1,3]"); // inner dim mismatch
}
TEST(CommonShapeFnsTest, BiasAddShapeTest) {
OpRegistrationData op_reg_data;
TF_CHECK_OK(OpDefBuilder("BiasAdd")


@ -3572,12 +3572,26 @@ tf_cuda_cc_test(
],
)
tf_cc_test(
name = "batch_matmul_op_common_test",
size = "small",
srcs = ["batch_matmul_op_common_test.cc"],
deps = [
":batch_matmul_op",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_cuda_cc_test(
name = "batch_matmul_op_test",
size = "small",
srcs = ["batch_matmul_op_test.cc"],
deps = [
":batch_matmul_op",
":broadcast_to_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
@ -5418,6 +5432,8 @@ filegroup(
name = "mobile_srcs",
srcs = [
"avgpooling_op.h",
"batch_matmul_op_common.cc",
"batch_matmul_op_common.h",
"batch_util.h",
"cwise_ops.h",
"cwise_ops_common.h",
@ -5912,6 +5928,7 @@ filegroup(
"*_3d*",
"*.cu.*",
# Ops already in android_srcs
"batch_matmul_op_common.cc",
"pooling_ops_common.cc",
# Ops which we are currently excluding because they are likely
# not used on Android. Those ops also do not compile if included,


@ -0,0 +1,76 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/batch_matmul_op_common.h"
namespace tensorflow {
namespace {
// Returns the mapping from the output batch indices to the corresponding
// input's batch indices, given the input's "reshape" and "bcast" shapes as
// returned by the BCast helper class. The i'th element denotes the (flattened)
// batch index of the input that must be used to compute the i'th batch output.
void ComputeBatchIndices(const int64 output_batch_size,
const MatMulBCast::Vec& reshape,
const MatMulBCast::Vec& bcast,
std::vector<int64>* out_indices) {
// Populates the mapping in out_indices. This algorithm is identical to
// the following steps:
// - Reshape {0, 1, ..., input_batch_size - 1} to the input shape.
// - Broadcast to the output shape.
// - Reshape back to a flat 1D vector.
out_indices->resize(output_batch_size);
int64 num_output_elements = 1;
int64 num_input_elements = 1;
for (int64 i = reshape.size() - 1; i >= 0; --i) {
// Replicate the already populated mapping an additional (dim - 1) times.
// If we are broadcasting, just copy the existing mapping.
// Otherwise, add another dimension from the input shape.
const int64 dim = std::max(reshape[i], bcast[i]);
const int64 incr = bcast[i] > 1 ? 0 : num_input_elements;
for (int64 k = 0; k < (dim - 1) * num_output_elements; ++k) {
(*out_indices)[num_output_elements + k] = (*out_indices)[k] + incr;
}
num_output_elements *= dim;
num_input_elements *= reshape[i];
}
}
} // namespace
MatMulBCast::MatMulBCast(Vec x, Vec y) {
if (x.size() < 2 || y.size() < 2) return;
x.resize(x.size() - 2);
y.resize(y.size() - 2);
batch_bcast_ = absl::make_unique<BCast>(std::move(x), std::move(y));
if (!batch_bcast_->IsValid()) return;
x_batch_size_ = TensorShape(batch_bcast_->x_reshape()).num_elements();
y_batch_size_ = TensorShape(batch_bcast_->y_reshape()).num_elements();
output_shape_ = TensorShape(batch_bcast_->output_shape());
output_batch_size_ = output_shape_.num_elements();
broadcasting_required_ =
std::min(x_batch_size_, y_batch_size_) != output_batch_size_;
if (broadcasting_required_) {
ComputeBatchIndices(output_batch_size_, batch_bcast_->x_reshape(),
batch_bcast_->x_bcast(), &x_batch_indices_);
ComputeBatchIndices(output_batch_size_, batch_bcast_->y_reshape(),
batch_bcast_->y_bcast(), &y_batch_indices_);
}
}
} // namespace tensorflow
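The mapping built by ComputeBatchIndices above can be reproduced in a few lines of NumPy, which may make the loop easier to follow (a sketch only; the function name is ours, and `reshape`/`bcast` stand in for the vectors returned by the BCast helper):

    import numpy as np

    def compute_batch_indices(reshape, bcast):
      # Step 1: lay out the flattened input batch indices in the input shape.
      indices = np.arange(int(np.prod(reshape))).reshape(reshape)
      # Step 2: broadcast to the output batch shape. The per-axis max matches
      # the C++ loop, since for each axis either reshape[i] or bcast[i] is 1.
      indices = np.broadcast_to(indices, tuple(np.maximum(reshape, bcast)))
      # Step 3: flatten back to a 1-D vector of input batch indices.
      return indices.reshape(-1)

    # Mirrors the SimpleBroadcast test further down: x's batch shape [1] is
    # replicated four times, while y's four batches map one-to-one.
    print(compute_batch_indices([1], [4]))  # [0 0 0 0]
    print(compute_batch_indices([4], [1]))  # [0 1 2 3]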


@ -0,0 +1,70 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_COMMON_H_
#include <vector>
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
// Simple wrapper over BCast specialized for MatMul.
// Provides utilities for broadcasting across batch dimensions for binary
// MatMul-like operations.
class MatMulBCast {
public:
using Vec = BCast::Vec;
MatMulBCast(Vec x, Vec y);
bool IsValid() const { return batch_bcast_ && batch_bcast_->IsValid(); }
bool IsBroadcastingRequired() const { return broadcasting_required_; }
const int64 output_batch_size() const { return output_batch_size_; }
const int64 x_batch_size() const { return x_batch_size_; }
const int64 y_batch_size() const { return y_batch_size_; }
const TensorShape& output_batch_shape() const { return output_shape_; }
// Returns the mapping from the flattened output batch indices to x's
// flattened batch indices. The result is a vector of length
// output_batch_size(). To compute the i'th batch output, a binary matmul-like
// operation should use the `x_batch_indices()[i]`th batch index of `x`.
// Note: Returns an empty vector if broadcasting is not required. Callers
// should only use this when IsBroadcastingRequired() returns true.
const std::vector<int64>& x_batch_indices() const { return x_batch_indices_; }
// Returns the mapping from the flattened output batch indices to y's
// flattened batch indices. Similar to x_batch_indices().
// Note: Returns an empty vector if broadcasting is not required. Callers
// should only use this when IsBroadcastingRequired() returns true.
const std::vector<int64>& y_batch_indices() const { return y_batch_indices_; }
private:
std::unique_ptr<BCast> batch_bcast_;
bool broadcasting_required_ = false;
int64 x_batch_size_;
int64 y_batch_size_;
TensorShape output_shape_;
int64 output_batch_size_;
std::vector<int64> x_batch_indices_;
std::vector<int64> y_batch_indices_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_COMMON_H_


@ -0,0 +1,138 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/batch_matmul_op_common.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
string MatMulBCastToStr(const MatMulBCast& b) {
if (!b.IsValid()) {
return "invalid";
}
string ret;
strings::StrAppend(
&ret, "[", str_util::Join(b.output_batch_shape().dim_sizes(), ","), "]");
strings::StrAppend(&ret, "[", str_util::Join(b.x_batch_indices(), ","), "]");
strings::StrAppend(&ret, "[", str_util::Join(b.y_batch_indices(), ","), "]");
return ret;
}
TEST(MatMulBCastTest, SimpleBroadcast) {
MatMulBCast bcast({1, 5, 3}, {4, 3, 7});
EXPECT_TRUE(bcast.IsValid());
EXPECT_TRUE(bcast.IsBroadcastingRequired());
EXPECT_EQ(1, bcast.x_batch_size());
EXPECT_EQ(4, bcast.y_batch_size());
EXPECT_EQ(4, bcast.output_batch_size());
EXPECT_EQ("[4][0,0,0,0][0,1,2,3]", MatMulBCastToStr(bcast));
}
TEST(MatMulBCastTest, EmptyBatchBroadcast) {
MatMulBCast bcast({5, 3}, {3, 7});
EXPECT_TRUE(bcast.IsValid());
EXPECT_FALSE(bcast.IsBroadcastingRequired());
EXPECT_EQ(1, bcast.x_batch_size());
EXPECT_EQ(1, bcast.y_batch_size());
EXPECT_EQ(1, bcast.output_batch_size());
EXPECT_EQ("[][][]", MatMulBCastToStr(bcast));
}
TEST(MatMulBCastTest, BroadcastingNotRequired) {
MatMulBCast bcast({2, 4, 6, 5, 3}, {2, 4, 6, 3, 7});
EXPECT_TRUE(bcast.IsValid());
EXPECT_FALSE(bcast.IsBroadcastingRequired());
EXPECT_EQ(48, bcast.x_batch_size());
EXPECT_EQ(48, bcast.y_batch_size());
EXPECT_EQ(48, bcast.output_batch_size());
EXPECT_EQ("[2,4,6][][]", MatMulBCastToStr(bcast));
}
TEST(MatMulBCastTest, EmptyWithNonEmptyBatchBroadcast) {
MatMulBCast bcast1({5, 3}, {6, 3, 7});
EXPECT_TRUE(bcast1.IsValid());
EXPECT_TRUE(bcast1.IsBroadcastingRequired());
EXPECT_EQ(1, bcast1.x_batch_size());
EXPECT_EQ(6, bcast1.y_batch_size());
EXPECT_EQ(6, bcast1.output_batch_size());
EXPECT_EQ("[6][0,0,0,0,0,0][0,1,2,3,4,5]", MatMulBCastToStr(bcast1));
MatMulBCast bcast2({2, 5, 3}, {3, 7});
EXPECT_TRUE(bcast2.IsValid());
EXPECT_TRUE(bcast2.IsBroadcastingRequired());
EXPECT_EQ(2, bcast2.x_batch_size());
EXPECT_EQ(1, bcast2.y_batch_size());
EXPECT_EQ(2, bcast2.output_batch_size());
EXPECT_EQ("[2][0,1][0,0]", MatMulBCastToStr(bcast2));
}
TEST(MatMulBCastTest, InvalidDimensions) {
// Too few dimensions.
MatMulBCast bcast1({3, 3}, {3});
EXPECT_FALSE(bcast1.IsValid());
MatMulBCast bcast2({3}, {3, 3});
EXPECT_FALSE(bcast2.IsValid());
// Batch dimensions not broadcastable.
MatMulBCast bcast3({4, 5, 3}, {2, 3, 7});
EXPECT_FALSE(bcast3.IsValid());
MatMulBCast bcast4({2, 1, 5, 3}, {1, 3, 1, 3, 7});
EXPECT_FALSE(bcast4.IsValid());
}
TEST(MatMulBCastTest, BroadcastBothOperands) {
MatMulBCast bcast({3, 1, 5, 3}, {1, 4, 3, 7});
EXPECT_TRUE(bcast.IsValid());
EXPECT_EQ(3, bcast.x_batch_size());
EXPECT_EQ(4, bcast.y_batch_size());
EXPECT_EQ(12, bcast.output_batch_size());
EXPECT_EQ("[3,4][0,0,0,0,1,1,1,1,2,2,2,2][0,1,2,3,0,1,2,3,0,1,2,3]",
MatMulBCastToStr(bcast));
}
TEST(MatMulBCastTest, DifferentRanks) {
MatMulBCast bcast({3, 1, 5, 3}, {2, 1, 2, 3, 7});
EXPECT_TRUE(bcast.IsValid());
EXPECT_EQ(3, bcast.x_batch_size());
EXPECT_EQ(4, bcast.y_batch_size());
EXPECT_EQ(12, bcast.output_batch_size());
EXPECT_EQ("[2,3,2][0,0,1,1,2,2,0,0,1,1,2,2][0,1,0,1,0,1,2,3,2,3,2,3]",
MatMulBCastToStr(bcast));
}
} // namespace
} // namespace tensorflow


@ -21,6 +21,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -29,7 +30,10 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/batch_matmul_op_common.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"
@ -74,8 +78,8 @@ struct ParallelMatMulKernel {
}
static void Run(const OpKernelContext* context, const Tensor& in_x,
const Tensor in_y, bool adj_x, bool adj_y, Tensor* out,
int start, int limit) {
const Tensor in_y, bool adj_x, bool adj_y,
const MatMulBCast& bcast, Tensor* out, int start, int limit) {
static_assert(IsComplex, "Complex type expected.");
auto Tx = in_x.tensor<Scalar, 3>();
auto Ty = in_y.tensor<Scalar, 3>();
@ -88,14 +92,21 @@ struct ParallelMatMulKernel {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
contract_pairs[0] = ContractionDims(adj_x, adj_y);
const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
for (int i = start; i < limit; ++i) {
auto x = Tx.template chip<0>(i);
const bool should_bcast = bcast.IsBroadcastingRequired();
const auto& x_batch_indices = bcast.x_batch_indices();
const auto& y_batch_indices = bcast.y_batch_indices();
for (int64 i = start; i < limit; ++i) {
const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
auto x = Tx.template chip<0>(x_batch_index);
auto z = Tz.template chip<0>(i);
if (adj_x != adj_y) {
auto y = Ty.template chip<0>(i).conjugate();
auto y = Ty.template chip<0>(y_batch_index).conjugate();
z.device(d) = x.contract(y, contract_pairs);
} else {
auto y = Ty.template chip<0>(i);
auto y = Ty.template chip<0>(y_batch_index);
z.device(d) = x.contract(y, contract_pairs);
}
}
@ -110,18 +121,25 @@ struct ParallelMatMulKernel<Scalar, false> {
static void Conjugate(const OpKernelContext* context, Tensor* out) {}
static void Run(const OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out,
int start, int limit) {
const Tensor& in_y, bool adj_x, bool adj_y,
const MatMulBCast& bcast, Tensor* out, int start, int limit) {
auto Tx = in_x.tensor<Scalar, 3>();
auto Ty = in_y.tensor<Scalar, 3>();
auto Tz = out->tensor<Scalar, 3>();
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
contract_pairs[0] = ContractionDims(adj_x, adj_y);
const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
for (int i = start; i < limit; ++i) {
auto x = Tx.template chip<0>(i);
auto y = Ty.template chip<0>(i);
const bool should_bcast = bcast.IsBroadcastingRequired();
const auto& x_batch_indices = bcast.x_batch_indices();
const auto& y_batch_indices = bcast.y_batch_indices();
for (int64 i = start; i < limit; ++i) {
const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
auto x = Tx.template chip<0>(x_batch_index);
auto y = Ty.template chip<0>(y_batch_index);
auto z = Tz.template chip<0>(i);
z.device(d) = x.contract(y, contract_pairs);
}
}
@ -151,10 +169,16 @@ struct SequentialMatMulKernel {
}
static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x,
bool adj_y, Tensor* out, int start, int limit) {
for (int i = start; i < limit; ++i) {
auto x = ConstTensorSliceToEigenMatrix(in_x, i);
auto y = ConstTensorSliceToEigenMatrix(in_y, i);
bool adj_y, const MatMulBCast& bcast, Tensor* out, int start,
int limit) {
const bool should_bcast = bcast.IsBroadcastingRequired();
const auto& x_batch_indices = bcast.x_batch_indices();
const auto& y_batch_indices = bcast.y_batch_indices();
for (int64 i = start; i < limit; ++i) {
const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
auto x = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
auto y = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
auto z = TensorSliceToEigenMatrix(out, i);
if (!adj_x) {
if (!adj_y) {
@ -181,13 +205,14 @@ struct LaunchBatchMatMul;
template <typename Scalar>
struct LaunchBatchMatMul<CPUDevice, Scalar> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
const Tensor& in_y, bool adj_x, bool adj_y,
const MatMulBCast& bcast, Tensor* out) {
typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
ParallelMatMulKernel;
bool conjugate_result = false;
// Number of matrix multiplies i.e. size of the batch.
const int64 batch_size = in_x.dim_size(0);
const int64 batch_size = bcast.output_batch_size();
const int64 cost_per_unit =
in_x.dim_size(1) * in_x.dim_size(2) * out->dim_size(2);
const int64 small_dim = std::min(
@ -199,17 +224,17 @@ struct LaunchBatchMatMul<CPUDevice, Scalar> {
// Parallelize over inner dims.
// For large matrix products it is counter-productive to parallelize
// over the batch dimension.
ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, out, 0,
batch_size);
ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, bcast, out,
0, batch_size);
conjugate_result = adj_x;
} else {
// Parallelize over outer dims. For small matrices and large batches, it
// is counter-productive to parallelize the inner matrix multiplies.
Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
cost_per_unit,
[&in_x, &in_y, adj_x, adj_y, out](int start, int limit) {
SequentialMatMulKernel<Scalar>::Run(in_x, in_y, adj_x, adj_y, out,
start, limit);
[&in_x, &in_y, adj_x, adj_y, &bcast, out](int start, int limit) {
SequentialMatMulKernel<Scalar>::Run(in_x, in_y, adj_x, adj_y,
bcast, out, start, limit);
});
}
if (conjugate_result) {
@ -270,7 +295,8 @@ class CublasScratchAllocator : public se::ScratchAllocator {
template <typename Scalar>
struct LaunchBatchMatMul<GPUDevice, Scalar> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
const Tensor& in_y, bool adj_x, bool adj_y,
const MatMulBCast& bcast, Tensor* out) {
constexpr se::blas::Transpose kTranspose =
is_complex<Scalar>::value ? se::blas::Transpose::kConjugateTranspose
: se::blas::Transpose::kTranspose;
@ -279,7 +305,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
const uint64 batch_size = in_x.dim_size(0);
const int64 batch_size = bcast.output_batch_size();
auto blas_transpose_a = trans[adj_x];
auto blas_transpose_b = trans[adj_y];
@ -293,8 +319,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
std::vector<DeviceMemoryType*> a_ptrs;
std::vector<DeviceMemoryType*> b_ptrs;
std::vector<DeviceMemoryType*> c_ptrs;
a_device_memory.reserve(batch_size);
b_device_memory.reserve(batch_size);
a_device_memory.reserve(bcast.x_batch_size());
b_device_memory.reserve(bcast.y_batch_size());
c_device_memory.reserve(batch_size);
a_ptrs.reserve(batch_size);
b_ptrs.reserve(batch_size);
@ -302,13 +328,31 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
auto* a_base_ptr = in_x.template flat<Scalar>().data();
auto* b_base_ptr = in_y.template flat<Scalar>().data();
auto* c_base_ptr = out->template flat<Scalar>().data();
for (int64 i = 0; i < batch_size; ++i) {
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
a_ptrs.push_back(&a_device_memory.back());
b_ptrs.push_back(&b_device_memory.back());
c_ptrs.push_back(&c_device_memory.back());
if (!bcast.IsBroadcastingRequired()) {
for (int64 i = 0; i < batch_size; ++i) {
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
a_ptrs.push_back(&a_device_memory.back());
b_ptrs.push_back(&b_device_memory.back());
c_ptrs.push_back(&c_device_memory.back());
}
} else {
const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
}
for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
}
for (int64 i = 0; i < batch_size; ++i) {
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
c_ptrs.push_back(&c_device_memory.back());
}
}
typedef Scalar Coefficient;
@ -385,7 +429,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
template <>
struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
const Tensor& in_y, bool adj_x, bool adj_y,
const MatMulBCast& bcast, Tensor* out) {
typedef Eigen::half Scalar;
constexpr perftools::gputools::blas::Transpose kTranspose =
is_complex<Scalar>::value
@ -396,7 +441,7 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
const uint64 batch_size = in_x.dim_size(0);
const uint64 batch_size = bcast.output_batch_size();
auto blas_transpose_a = trans[adj_x];
auto blas_transpose_b = trans[adj_y];
@ -410,8 +455,8 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
std::vector<DeviceMemoryType*> a_ptrs;
std::vector<DeviceMemoryType*> b_ptrs;
std::vector<DeviceMemoryType*> c_ptrs;
a_device_memory.reserve(batch_size);
b_device_memory.reserve(batch_size);
a_device_memory.reserve(bcast.x_batch_size());
b_device_memory.reserve(bcast.y_batch_size());
c_device_memory.reserve(batch_size);
a_ptrs.reserve(batch_size);
b_ptrs.reserve(batch_size);
@ -419,13 +464,31 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
auto* a_base_ptr = in_x.template flat<Scalar>().data();
auto* b_base_ptr = in_y.template flat<Scalar>().data();
auto* c_base_ptr = out->template flat<Scalar>().data();
for (int64 i = 0; i < batch_size; ++i) {
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
a_ptrs.push_back(&a_device_memory.back());
b_ptrs.push_back(&b_device_memory.back());
c_ptrs.push_back(&c_device_memory.back());
if (!bcast.IsBroadcastingRequired()) {
for (int64 i = 0; i < batch_size; ++i) {
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
a_ptrs.push_back(&a_device_memory.back());
b_ptrs.push_back(&b_device_memory.back());
c_ptrs.push_back(&c_device_memory.back());
}
} else {
const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
}
for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
}
for (int64 i = 0; i < batch_size; ++i) {
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
c_ptrs.push_back(&c_device_memory.back());
}
}
typedef float Coefficient;
@ -480,17 +543,24 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
template <typename Scalar>
struct ParallelMatMulKernelSYCL {
static void Run(const OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out,
int start, int limit) {
const Tensor& in_y, bool adj_x, bool adj_y,
const MatMulBCast& bcast, Tensor* out, int start, int limit) {
auto Tx = in_x.tensor<Scalar, 3>();
auto Ty = in_y.tensor<Scalar, 3>();
auto Tz = out->tensor<Scalar, 3>();
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
contract_pairs[0] = ContractionDims(adj_x, adj_y);
auto d = context->eigen_sycl_device();
for (int i = start; i < limit; ++i) {
auto x = Tx.template chip<0>(i);
auto y = Ty.template chip<0>(i);
const bool should_bcast = bcast.IsBroadcastingRequired();
const auto& x_batch_indices = bcast.x_batch_indices();
const auto& y_batch_indices = bcast.y_batch_indices();
for (int64 i = start; i < limit; ++i) {
const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
auto x = Tx.template chip<0>(x_batch_index);
auto y = Ty.template chip<0>(y_batch_index);
auto z = Tz.template chip<0>(i);
z.device(d) = x.contract(y, contract_pairs);
}
@ -500,54 +570,58 @@ struct ParallelMatMulKernelSYCL {
template <typename Scalar>
struct LaunchBatchMatMul<SYCLDevice, Scalar> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
const Tensor& in_y, bool adj_x, bool adj_y,
const MatMulBCast& bcast, Tensor* out) {
// Number of matrix multiplies i.e. size of the batch.
const int64 batch_size = in_x.dim_size(0);
const int64 batch_size = bcast.output_batch_size();
ParallelMatMulKernelSYCL<Scalar>::Run(context, in_x, in_y, adj_x, adj_y,
out, 0, batch_size);
bcast, out, 0, batch_size);
}
};
#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename Scalar>
class BatchMatMul : public OpKernel {
class BaseBatchMatMulOp : public OpKernel {
public:
explicit BatchMatMul(OpKernelConstruction* context) : OpKernel(context) {
explicit BaseBatchMatMulOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
}
virtual ~BatchMatMul() {}
~BaseBatchMatMulOp() override {}
void Compute(OpKernelContext* ctx) override {
const Tensor& in0 = ctx->input(0);
const Tensor& in1 = ctx->input(1);
OP_REQUIRES(ctx, in0.dims() == in1.dims(),
errors::InvalidArgument("In[0] and In[1] has different ndims: ",
in0.shape().DebugString(), " vs. ",
in1.shape().DebugString()));
const int ndims = in0.dims();
ValidateInputTensors(ctx, in0, in1);
MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
OP_REQUIRES(
ctx, ndims >= 2,
errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
TensorShape out_shape;
for (int i = 0; i < ndims - 2; ++i) {
OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
errors::InvalidArgument(
"In[0].dim(", i, ") and In[1].dim(", i,
") must be the same: ", in0.shape().DebugString(), " vs ",
in1.shape().DebugString()));
out_shape.AddDim(in0.dim_size(i));
}
auto n = (ndims == 2) ? 1 : out_shape.num_elements();
auto d0 = in0.dim_size(ndims - 2);
auto d1 = in0.dim_size(ndims - 1);
ctx, bcast.IsValid(),
errors::InvalidArgument(
"In[0] and In[1] must have compatible batch dimensions: ",
in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));
TensorShape out_shape = bcast.output_batch_shape();
auto batch_size = bcast.output_batch_size();
auto d0 = in0.dim_size(in0.dims() - 2);
auto d1 = in0.dim_size(in0.dims() - 1);
Tensor in0_reshaped;
CHECK(in0_reshaped.CopyFrom(in0, TensorShape({n, d0, d1})));
auto d2 = in1.dim_size(ndims - 2);
auto d3 = in1.dim_size(ndims - 1);
OP_REQUIRES(
ctx,
in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
errors::Internal("Failed to reshape In[0] from ",
in0.shape().DebugString()));
auto d2 = in1.dim_size(in1.dims() - 2);
auto d3 = in1.dim_size(in1.dims() - 1);
Tensor in1_reshaped;
CHECK(in1_reshaped.CopyFrom(in1, TensorShape({n, d2, d3})));
OP_REQUIRES(
ctx,
in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
errors::Internal("Failed to reshape In[1] from ",
in1.shape().DebugString()));
if (adj_x_) std::swap(d0, d1);
if (adj_y_) std::swap(d2, d3);
OP_REQUIRES(ctx, d1 == d2,
@ -568,31 +642,102 @@ class BatchMatMul : public OpKernel {
return;
}
Tensor out_reshaped;
CHECK(out_reshaped.CopyFrom(*out, TensorShape({n, d0, d3})));
LaunchBatchMatMul<Device, Scalar>::Launch(ctx, in0_reshaped, in1_reshaped,
adj_x_, adj_y_, &out_reshaped);
OP_REQUIRES(ctx,
out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
errors::Internal("Failed to reshape output from ",
out->shape().DebugString()));
LaunchBatchMatMul<Device, Scalar>::Launch(
ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, bcast, &out_reshaped);
}
protected:
virtual void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
const Tensor& in1) = 0;
private:
bool adj_x_;
bool adj_y_;
};
#define REGISTER_BATCH_MATMUL_CPU(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
BatchMatMul<CPUDevice, TYPE>)
// BatchMatMul Op implementation which disallows broadcasting.
template <typename Device, typename Scalar>
class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> {
public:
explicit BatchMatMulOp(OpKernelConstruction* context)
: BaseBatchMatMulOp<Device, Scalar>(context) {}
#define REGISTER_BATCH_MATMUL_GPU(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
BatchMatMul<GPUDevice, TYPE>)
~BatchMatMulOp() override {}
private:
void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
const Tensor& in1) override {
// Disallow broadcasting support. Ensure that all batch dimensions of the
// input tensors match.
OP_REQUIRES(ctx, in0.dims() == in1.dims(),
errors::InvalidArgument("In[0] and In[1] has different ndims: ",
in0.shape().DebugString(), " vs. ",
in1.shape().DebugString()));
const int ndims = in0.dims();
OP_REQUIRES(
ctx, ndims >= 2,
errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
for (int i = 0; i < ndims - 2; ++i) {
OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
errors::InvalidArgument(
"In[0].dim(", i, ") and In[1].dim(", i,
") must be the same: ", in0.shape().DebugString(), " vs ",
in1.shape().DebugString()));
}
}
};
// BatchMatMul Op implementation with broadcasting support.
template <typename Device, typename Scalar>
class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
public:
explicit BatchMatMulV2Op(OpKernelConstruction* context)
: BaseBatchMatMulOp<Device, Scalar>(context) {}
~BatchMatMulV2Op() override {}
private:
void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
const Tensor& in1) override {
// Enable broadcasting support. Validity of broadcasting is checked in
// BaseBatchMatMulOp.
OP_REQUIRES(
ctx, in0.dims() >= 2,
errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));
OP_REQUIRES(
ctx, in1.dims() >= 2,
errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims()));
}
};
#define REGISTER_BATCH_MATMUL_CPU(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
BatchMatMulOp<CPUDevice, TYPE>); \
REGISTER_KERNEL_BUILDER( \
Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
BatchMatMulV2Op<CPUDevice, TYPE>)
#define REGISTER_BATCH_MATMUL_GPU(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
BatchMatMulOp<GPUDevice, TYPE>); \
REGISTER_KERNEL_BUILDER( \
Name("BatchMatMulV2").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
BatchMatMulV2Op<GPUDevice, TYPE>)
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_BATCH_MATMUL_SYCL(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("BatchMatMul").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
BatchMatMul<SYCLDevice, TYPE>)
#define REGISTER_BATCH_MATMUL_SYCL(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("BatchMatMul").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
BatchMatMulOp<SYCLDevice, TYPE>); \
REGISTER_KERNEL_BUILDER( \
Name("BatchMatMulV2").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
BatchMatMulV2Op<SYCLDevice, TYPE>)
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -14,11 +14,40 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/broadcast_to_op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace {
Node* BroadcastTo(Graph* g, Node* input, Node* shape) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastTo")
.Input(input)
.Input(shape)
.Finalize(g, &ret));
return ret;
}
Node* BatchMatmulV2(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMulV2")
.Input(in0)
.Input(in1)
.Attr("adj_x", adj_x)
.Attr("adj_y", adj_y)
.Finalize(g, &ret));
return ret;
}
template <typename T>
static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a,
@ -33,6 +62,40 @@ static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a,
return g;
}
template <typename T>
static Graph* BatchMatmulWithBroadcast(int b0, int b1, int m, int k, int n,
bool manual_broadcast, DataType type) {
Graph* g = new Graph(OpRegistry::Global());
Tensor in0(type, TensorShape({b0, m, k}));
in0.flat<T>().setRandom();
Tensor in1(type, TensorShape({b1, k, n}));
in1.flat<T>().setRandom();
Tensor broadcasted_in0_shape(DT_INT64, TensorShape({3}));
Tensor broadcasted_in1_shape(DT_INT64, TensorShape({3}));
Node* in0_node = nullptr;
Node* in1_node = nullptr;
if (manual_broadcast) {
for (int i = 0; i < 3; ++i) {
auto vec0 = broadcasted_in0_shape.vec<int64>();
auto vec1 = broadcasted_in1_shape.vec<int64>();
vec0(i) = (i == 0 ? std::max(b0, b1) : in0.shape().dim_size(i));
vec1(i) = (i == 0 ? std::max(b0, b1) : in1.shape().dim_size(i));
}
in0_node = BroadcastTo(g, test::graph::Constant(g, in0),
test::graph::Constant(g, broadcasted_in0_shape));
in1_node = BroadcastTo(g, test::graph::Constant(g, in1),
test::graph::Constant(g, broadcasted_in1_shape));
} else {
in0_node = test::graph::Constant(g, in0);
in1_node = test::graph::Constant(g, in1);
}
BatchMatmulV2(g, in0_node, in1_node, false, false);
return g;
}
#define BM_BatchMatmulDev(B, M, K, N, TA, TB, T, TFTYPE, DEVICE) \
static void \
BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE( \
@ -59,6 +122,64 @@ static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a,
// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu); \
// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, gpu);
// Macro argument names: --------------------------------------------------- //
// B1: batch size of LHS
// B2: batch size of RHS
// M: outer dimension of LHS
// K: inner dimensions of LHS and RHS
// N: outer dimension of RHS
// MB: boolean indicating whether to use manual broadcasting
// T: C++ type of scalars (e.g. float, std::complex)
// TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128)
// D: Device (e.g. cpu, gpu)
#define BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, T, TT, D) \
static void \
BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D( \
int iters) { \
testing::UseRealTime(); \
testing::ItemsProcessed(static_cast<int64>(iters) * std::max(B1, B2) * M * \
K * N * 2); \
test::Benchmark(#D, BatchMatmulWithBroadcast<T>(B1, B2, M, K, N, MB, TT)) \
.Run(iters); \
} \
BENCHMARK( \
BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D);
#define BM_BatchMatmulBCast(B1, B2, M, K, N, MB) \
BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, float, DT_FLOAT, cpu);
// Typical fully connected layers
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, false);
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, false);
// Square matmul.
BM_BatchMatmulBCast(1, 128, 512, 512, 512, true);
BM_BatchMatmulBCast(1, 128, 512, 512, 512, false);
BM_BatchMatmulBCast(128, 1, 512, 512, 512, true);
BM_BatchMatmulBCast(128, 1, 512, 512, 512, false);
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, true);
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, false);
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, true);
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, false);
// Matrix-vector multiplies.
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, true);
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, false);
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, true);
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, false);
// Vector-matrix multiplies.
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, true);
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, false);
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, true);
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, false);
// Typical fully connected layers
BM_BatchMatmul(1, 1, 1024, 1024, false, false);
BM_BatchMatmul(1, 8, 1024, 1024, false, false);
@ -132,4 +253,5 @@ BM_BatchMatmul(1, 1, 200, 10000, true, true);
BM_BatchMatmul(8, 1, 200, 10000, true, true);
BM_BatchMatmul(32, 1, 200, 10000, true, true);
} // end namespace tensorflow
} // namespace
} // namespace tensorflow


@ -158,6 +158,17 @@ REGISTER_OP("BatchMatMul")
return Status::OK();
});
REGISTER_OP("BatchMatMulV2")
.Input("x: T")
.Input("y: T")
.Output("output: T")
.Attr(
"T: {bfloat16, half, float, double, int32, int64, complex64, "
"complex128}")
.Attr("adj_x: bool = false")
.Attr("adj_y: bool = false")
.SetShapeFn(shape_inference::BatchMatMulV2Shape);
// --------------------------------------------------------------------------
// Casting Ops
//


@ -2782,6 +2782,7 @@ py_library(
":state_ops_gen",
":tensor_shape",
":util",
"//tensorflow/python/compat",
"//tensorflow/python/eager:context",
"//third_party/py/numpy",
],


@ -21,63 +21,43 @@ from __future__ import print_function
import numpy as np
from tensorflow.python import tf2
from tensorflow.python.client import session
from tensorflow.python.compat import compat
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
def GetRandomNormalInput(shape, dtype):
# float16 has limited range so we reduce the variance of the scalars.
scale = 10.0 if dtype != np.float16 else 0.1
loc = -10.0 if dtype != np.float16 else 0.1
vals = np.array(np.random.normal(loc, scale, np.prod(shape)), dtype=dtype)
if dtype in (np.complex64, np.complex128):
imag = np.array(np.random.normal(loc, scale, np.prod(shape)), dtype=dtype)
vals += 1j * imag
return vals.reshape(shape)
class BatchMatmulOpTest(test.TestCase):
# Uses numpy to compute batch_matmul(x, y, adjoint_a, adjoint_b).
def _npBatchMatmul(self, x, y, adjoint_a, adjoint_b):
# output's shape depends on adj[0] and adj[1]
d0 = x.shape[-2] if not adjoint_a else x.shape[-1]
d2 = y.shape[-1] if not adjoint_b else y.shape[-2]
batch_dims = x.shape[:-2]
num = np.prod(batch_dims)
z = np.empty(list(batch_dims) + [d0, d2], dtype=x.dtype)
xr = x.reshape([num, x.shape[-2], x.shape[-1]])
yr = y.reshape([num, y.shape[-2], y.shape[-1]])
zr = z.reshape([num, z.shape[-2], z.shape[-1]])
for i in range(num):
a = np.matrix(xr[i, :, :])
if adjoint_a:
a = a.transpose().conj()
b = np.matrix(yr[i, :, :])
if adjoint_b:
b = b.transpose().conj()
zr[i, :, :] = a * b
return z
if adjoint_a:
x = np.conjugate(np.swapaxes(x, -1, -2))
if adjoint_b:
y = np.conjugate(np.swapaxes(y, -1, -2))
return np.matmul(x, y)
# Test _npBatchMatMul works.
def testNpVersion(self):
x = np.array([0., 1., 2., 3.]).reshape([1, 2, 2])
y = np.array([1., 2., 3., 4.]).reshape([1, 2, 2])
z0 = self._npBatchMatmul(x, y, False, False)
z1 = np.array([3., 4., 11., 16.]).reshape([1, 2, 2])
self.assertTrue(np.array_equal(z0, z1))
x = np.array([1., (1j), (-1.), (-1j)]).reshape([1, 2, 2])
y = x * np.complex(1, 1) # rotate x 90 degree
z0 = self._npBatchMatmul(x, y, False, False)
z1 = np.array([2., (2.j), -2., (-2.j)]).reshape([1, 2, 2])
self.assertTrue(np.array_equal(z0, z1))
z0 = self._npBatchMatmul(x, y, False, True)
z1 = np.array([(2. - 2.j), (-2. + 2.j), (-2. + 2.j), (2. - 2.j)]).reshape(
[1, 2, 2])
self.assertTrue(np.array_equal(z0, z1))
z0 = self._npBatchMatmul(x, y, True, False)
z1 = np.array([(2. + 2.j), (-2. + 2.j), (2. - 2.j), (2. + 2.j)]).reshape(
[1, 2, 2])
self.assertTrue(np.array_equal(z0, z1))
# Compares _tfpBatchMatmul(x, y, alpha, adj) and _npBatchMatMul(x, y, alpha,
# adj)
def _compare(self, x_in, y_in, adjoint_a, adjoint_b, static_shape=True):
# Compares TensorFlow BatchMatmul with NumPy's matmul.
def _compare(self, x_in, y_in, adjoint_a, adjoint_b, static_shape):
x_t_shape = x_in.shape[:-2] + (x_in.shape[-1], x_in.shape[-2])
y_t_shape = y_in.shape[:-2] + (y_in.shape[-1], y_in.shape[-2])
x = x_in if not adjoint_a else x_in.reshape(x_t_shape)
@ -97,19 +77,15 @@ class BatchMatmulOpTest(test.TestCase):
z1 = self._npBatchMatmul(x, y, adjoint_a, adjoint_b)
self.assertAllClose(z0_val, z1, rtol=tol, atol=tol)
def _rand(self, shape, dtype):
vals = np.array(np.random.normal(-10, 10, np.prod(shape)), dtype=dtype)
if dtype in (np.complex64, np.complex128):
imag = np.array(np.random.normal(-10, 10, np.prod(shape)), dtype=dtype)
vals += 1j * imag
return vals.reshape(shape)
def _testNonEmpty(self, dtype, adjoint_a, adjoint_b, use_static_shape):
def CompareNonEmpty(self, a_shape, b_shape):
self._compare(
self._rand(a_shape, dtype),
self._rand(b_shape, dtype), adjoint_a, adjoint_b, use_static_shape)
GetRandomNormalInput(a_shape, dtype),
GetRandomNormalInput(b_shape, dtype),
adjoint_a,
adjoint_b,
static_shape=use_static_shape)
CompareNonEmpty(self, [1, 2, 3], [1, 3, 5])
CompareNonEmpty(self, [1, 2, 3], [1, 3, 1])
@ -121,13 +97,33 @@ class BatchMatmulOpTest(test.TestCase):
CompareNonEmpty(self, [10, 64, 75], [10, 75, 30])
CompareNonEmpty(self, [5, 7, 2, 3], [5, 7, 3, 5])
def _testBroadcasting(self, dtype, adjoint_a, adjoint_b, use_static_shape):
def CompareNonEmpty(self, a_shape, b_shape):
self._compare(
GetRandomNormalInput(a_shape, dtype),
GetRandomNormalInput(b_shape, dtype),
adjoint_a,
adjoint_b,
static_shape=use_static_shape)
CompareNonEmpty(self, [2, 3], [1, 3, 5])
CompareNonEmpty(self, [1, 2, 3], [3, 5])
CompareNonEmpty(self, [5, 1, 2, 3], [1, 7, 3, 5])
CompareNonEmpty(self, [5, 2, 2, 3], [3, 5])
CompareNonEmpty(self, [2, 3], [5, 2, 3, 5])
CompareNonEmpty(self, [4, 5, 1, 2, 3], [1, 1, 3, 5])
CompareNonEmpty(self, [1, 2, 1, 4, 2, 1, 3, 4], [3, 2, 1, 1, 1, 2, 4, 2])
def _testEmpty(self, dtype, adjoint_a, adjoint_b, use_static_shape):
def CompareEmpty(self, a_shape, b_shape):
self._compare(
np.zeros(a_shape).astype(dtype),
np.zeros(b_shape).astype(dtype), adjoint_a, adjoint_b,
use_static_shape)
np.zeros(b_shape).astype(dtype),
adjoint_a,
adjoint_b,
static_shape=use_static_shape)
CompareEmpty(self, [0, 3, 2], [0, 2, 4])
CompareEmpty(self, [3, 0, 2], [3, 2, 5])
@ -136,7 +132,6 @@ class BatchMatmulOpTest(test.TestCase):
def _GetBatchMatmulOpTest(dtype, adjoint_a, adjoint_b, use_static_shape):
@test_util.run_v1_only("b/120545219")
def Test(self):
np.random.seed(42)
self._testNonEmpty(dtype, adjoint_a, adjoint_b, use_static_shape)
@ -145,6 +140,18 @@ def _GetBatchMatmulOpTest(dtype, adjoint_a, adjoint_b, use_static_shape):
return Test
def _GetBatchMatmulOpBroadcastingTest(dtype, adjoint_a, adjoint_b,
use_static_shape):
@test_util.disable_xla("b/128537983")
def Test(self):
with compat.forward_compatibility_horizon(2019, 4, 19):
np.random.seed(42)
self._testBroadcasting(dtype, adjoint_a, adjoint_b, use_static_shape)
return Test
class BatchMatmulGradientTest(test.TestCase):
# loss = sum(batch_matmul(x, y)). Verify dl/dx and dl/dy via the
@ -155,45 +162,125 @@ class BatchMatmulGradientTest(test.TestCase):
x = x_in if not adjoint_a else x_in.reshape(x_t_shape)
y = y_in if not adjoint_b else y_in.reshape(y_t_shape)
epsilon = np.finfo(x.dtype).eps
delta = epsilon**(1.0 / 3.0)
# Since our gradient is linear, a larger delta decreases the error.
delta = 10 * epsilon**(1.0 / 3.0)
def Loss(x, y):
z = math_ops.matmul(x, y, adjoint_a, adjoint_b)
return math_ops.reduce_sum(z)
return math_ops.reduce_sum(math_ops.matmul(x, y, adjoint_a, adjoint_b))
with self.cached_session(use_gpu=True):
((x_jacob_t, y_jacob_t),
(x_jacob_n, y_jacob_n)) = gradient_checker_v2.compute_gradient(
Loss, [x, y], delta=delta)
tol = 20 * delta
tol = 10 * delta
self.assertAllClose(x_jacob_t, x_jacob_n, rtol=tol, atol=tol)
self.assertAllClose(y_jacob_t, y_jacob_n, rtol=tol, atol=tol)
# Tests a batched matmul of x, and y: x is a 3D tensor of shape [b,
# n, k] y is a 3D tensor of shape [b, k, m] the batched matmul
# computes z of shape [b, n, m], where z[i, :, :] = x[i, :, :]
# matmul y[i, :, :]
def _compare(self, b, n, k, m, dtype, adjoint_a, adjoint_b):
# Tests gradients of a batched matmul of x, and y
def _compare(self, a_shape, b_shape, dtype, adjoint_a, adjoint_b):
np.random.seed(42)
x = np.random.normal(0, 1, b * n * k).astype(dtype).reshape([b, n, k])
if dtype in (np.complex64, np.complex128):
x.imag = np.random.normal(0, 1,
b * n * k).astype(dtype).reshape([b, n, k])
y = np.random.normal(0, 1, b * k * m).astype(dtype).reshape([b, k, m])
if dtype in (np.complex64, np.complex128):
y.imag = np.random.normal(0, 1,
b * k * m).astype(dtype).reshape([b, k, m])
x = GetRandomNormalInput(a_shape, dtype)
y = GetRandomNormalInput(b_shape, dtype)
self._checkGrad(x, y, adjoint_a, adjoint_b)
def _GetBatchMatmulGradientTest(dtype, adjoint_a, adjoint_b):
@test_util.run_v1_only("b/120545219")
def Test(self):
self._compare(1, 2, 3, 5, dtype, adjoint_a, adjoint_b)
self._compare(3, 4, 7, 10, dtype, adjoint_a, adjoint_b)
def CheckGradients(self, a_shape, b_shape):
self._compare(a_shape, b_shape, dtype, adjoint_a, adjoint_b)
CheckGradients(self, [1, 2, 3], [1, 3, 5])
CheckGradients(self, [3, 4, 7], [3, 7, 10])
return Test
def _GetBatchMatmulGradientWithBroadcastingTest(dtype, adjoint_a, adjoint_b):
@test_util.disable_xla("b/128537983")
def Test(self):
def CheckGradients(self, a_shape, b_shape):
self._compare(a_shape, b_shape, dtype, adjoint_a, adjoint_b)
with compat.forward_compatibility_horizon(2019, 4, 19):
CheckGradients(self, [1, 5, 2, 3], [7, 1, 3, 2])
CheckGradients(self, [2, 3], [1, 3, 5])
CheckGradients(self, [2, 3], [5, 3, 5])
CheckGradients(self, [5, 2, 5], [5, 3])
CheckGradients(self, [5, 2, 2, 3], [3, 5])
CheckGradients(self, [4, 5, 1, 2, 3], [1, 1, 3, 5])
CheckGradients(self, [1, 2, 1, 4, 2, 1, 3, 4], [3, 2, 1, 1, 1, 2, 4, 2])
return Test
class BatchMatMulBenchmark(test.Benchmark):
# Batch sizes are 512.
shape_pairs = [
# Typical fully connected layer.
((4, 8, 4, 2, 1, 1024), (1024, 1024)),
((4, 1, 4, 1, 1, 1024), (1, 8, 1, 2, 1024, 1024)),
# Square matmul.
((4, 8, 4, 2, 512, 512), (512, 512)),
((4, 1, 4, 1, 512, 512), (1, 8, 1, 2, 512, 512)),
# Matrix-vector multiplies.
((4, 8, 4, 2, 10000, 200), (200, 1)),
((4, 1, 4, 1, 10000, 200), (1, 8, 1, 2, 200, 1)),
# Vector-matrix multiplies.
((4, 8, 4, 2, 1, 200), (200, 10000)),
((4, 1, 4, 1, 1, 200), (1, 8, 1, 2, 200, 10000)),
]
def benchmarkBatchMatMulBroadcast(self):
for (a_shape, b_shape) in self.shape_pairs:
with compat.forward_compatibility_horizon(2019, 4, 19):
with ops.Graph().as_default(), \
session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix_a = variables.Variable(
GetRandomNormalInput(a_shape, np.float32))
matrix_b = variables.Variable(
GetRandomNormalInput(b_shape, np.float32))
variables.global_variables_initializer().run()
# Use batch matmul op's internal broadcasting.
self.run_op_benchmark(
sess,
math_ops.matmul(matrix_a, matrix_b),
min_iters=50,
name="batch_matmul_cpu_{}_{}".format(a_shape, b_shape))
# Manually broadcast the input matrices using the broadcast_to op.
broadcasted_batch_shape = array_ops.broadcast_static_shape(
matrix_a.shape[:-2], matrix_b.shape[:-2])
broadcasted_a_shape = broadcasted_batch_shape.concatenate(
matrix_a.shape[-2:])
broadcasted_b_shape = broadcasted_batch_shape.concatenate(
matrix_b.shape[-2:])
self.run_op_benchmark(
sess,
math_ops.matmul(
array_ops.broadcast_to(matrix_a, broadcasted_a_shape),
array_ops.broadcast_to(matrix_b, broadcasted_b_shape)),
min_iters=50,
name="batch_matmul_manual_broadcast_cpu_{}_{}".format(
a_shape, b_shape))
# Use linear_operator_util.matmul_with_broadcast.
name_template = (
"batch_matmul_manual_broadcast_with_linear_operator_util"
"_cpu_{}_{}"
)
self.run_op_benchmark(
sess,
linear_operator_util.matmul_with_broadcast(matrix_a, matrix_b),
min_iters=50,
name=name_template.format(a_shape, b_shape))
if __name__ == "__main__":
for dtype_ in [
np.float16, np.float32, np.float64, np.complex64, np.complex128, np.int32
@ -201,13 +288,27 @@ if __name__ == "__main__":
for adjoint_a_ in False, True:
for adjoint_b_ in False, True:
name = "%s_%s_%s" % (dtype_.__name__, adjoint_a_, adjoint_b_)
# TF2 does not support placeholders under eager so we skip it
# TF2 does not support placeholders under eager so we skip it.
for use_static_shape_ in set([True, tf2.enabled()]):
setattr(BatchMatmulOpTest,
"testBatchMatmulOp_" + name + ("_%s" % use_static_shape_),
_GetBatchMatmulOpTest(dtype_, adjoint_a_, adjoint_b_,
use_static_shape_))
if dtype_ is not np.int32:
setattr(BatchMatmulGradientTest, "testBatchMatmulGradient_" + name,
_GetBatchMatmulGradientTest(dtype_, adjoint_a_, adjoint_b_))
setattr(
BatchMatmulOpTest,
"testBatchMatmulOp_" + name + "_{}".format(use_static_shape_),
_GetBatchMatmulOpTest(dtype_, adjoint_a_, adjoint_b_,
use_static_shape_))
# Broadcasting is supported only in v2.
setattr(
BatchMatmulOpTest, "testBatchMatmulBroadcasting_" + name +
("_%s" % use_static_shape_),
_GetBatchMatmulOpBroadcastingTest(dtype_, adjoint_a_, adjoint_b_,
use_static_shape_))
if dtype_ == np.int32:
continue
setattr(BatchMatmulGradientTest, "testBatchMatmulGradient_" + name,
_GetBatchMatmulGradientTest(dtype_, adjoint_a_, adjoint_b_))
# Broadcasting is supported only in v2.
setattr(
BatchMatmulGradientTest,
"testBatchMatmulGradientWithBroadcasting_" + name,
_GetBatchMatmulGradientWithBroadcastingTest(dtype_, adjoint_a_,
adjoint_b_))
test.main()


@ -1473,6 +1473,44 @@ def _BatchMatMul(op, grad):
return grad_x, grad_y
@ops.RegisterGradient("BatchMatMulV2")
def _BatchMatMulV2(op, grad):
"""Returns the gradient of x and y given the gradient of x * y."""
x = op.inputs[0]
y = op.inputs[1]
adj_x = op.get_attr("adj_x")
adj_y = op.get_attr("adj_y")
if not adj_x:
if not adj_y:
grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True)
grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False)
else:
grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False)
grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False)
else:
if not adj_y:
grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True)
grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False)
else:
grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True)
grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True)
# Reduce along the broadcasted batch dimensions, if broadcasting is required.
shape_x_static = x.get_shape()
shape_y_static = y.get_shape()
if not (shape_x_static.is_fully_defined() and
shape_y_static.is_fully_defined() and
shape_x_static == shape_y_static):
sx = array_ops.shape(x)
sy = array_ops.shape(y)
rx, ry = gen_array_ops.broadcast_gradient_args(sx[:-2], sy[:-2])
grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, rx), sx)
grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, ry), sy)
return grad_x, grad_y
ops.NotDifferentiable("Range")
ops.NotDifferentiable("LinSpace")
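The reduce_sum/reshape step in _BatchMatMulV2 above is needed because an input batch dimension that was broadcast feeds several output slices, so its gradient is the sum of those slices' contributions. A NumPy sketch of the idea (shapes are made up for illustration):

    import numpy as np

    x = np.random.rand(1, 2, 3)   # batch dim 1, broadcast against y's batch dim 4
    y = np.random.rand(4, 3, 5)
    grad = np.ones((4, 2, 5))     # upstream gradient of sum over x @ y

    # Gradient w.r.t. x before reduction carries the broadcast batch shape.
    grad_x_full = grad @ np.swapaxes(y, -1, -2)      # shape (4, 2, 3)
    # Sum over the broadcast batch axis and restore x's original shape,
    # analogous to reduce_sum over `rx` followed by the reshape above.
    grad_x = grad_x_full.sum(axis=0, keepdims=True)  # shape (1, 2, 3)
    assert grad_x.shape == x.shape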


@ -73,6 +73,7 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.compat import compat as fwd_compat
from tensorflow.python.eager import context
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
@ -2547,9 +2548,20 @@ def matmul(a,
# TODO(apassos) remove _shape_tuple here when it is not needed.
a_shape = a._shape_tuple() # pylint: disable=protected-access
b_shape = b._shape_tuple() # pylint: disable=protected-access
if fwd_compat.forward_compatible(2019, 4, 18):
output_may_have_non_empty_batch_shape = (
(a_shape is None or len(a_shape) > 2) or
(b_shape is None or len(b_shape) > 2))
batch_mat_mul_fn = gen_math_ops.batch_mat_mul_v2
else:
output_may_have_non_empty_batch_shape = (
(a_shape is None or len(a_shape) > 2) and
(b_shape is None or len(b_shape) > 2))
batch_mat_mul_fn = gen_math_ops.batch_mat_mul
if (not a_is_sparse and
not b_is_sparse) and ((a_shape is None or len(a_shape) > 2) and
(b_shape is None or len(b_shape) > 2)):
not b_is_sparse) and output_may_have_non_empty_batch_shape:
# BatchMatmul does not support transpose, so we conjugate the matrix and
# use adjoint instead. Conj() is a noop for real matrices.
if transpose_a:
@ -2558,8 +2570,7 @@ def matmul(a,
if transpose_b:
b = conj(b)
adjoint_b = True
return gen_math_ops.batch_mat_mul(
a, b, adj_x=adjoint_a, adj_y=adjoint_b, name=name)
return batch_mat_mul_fn(a, b, adj_x=adjoint_a, adj_y=adjoint_b, name=name)
# Neither matmul nor sparse_matmul support adjoint, so we conjugate
# the matrix and use transpose instead. Conj() is a noop for real


@ -316,6 +316,10 @@ tf_module {
name: "BatchMatMul"
argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
}
member_method {
name: "BatchMatMulV2"
argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
}
member_method {
name: "BatchMatrixBandPart"
argspec: "args=[\'input\', \'num_lower\', \'num_upper\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "


@ -316,6 +316,10 @@ tf_module {
name: "BatchMatMul"
argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
}
member_method {
name: "BatchMatMulV2"
argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
}
member_method {
name: "BatchMatrixBandPart"
argspec: "args=[\'input\', \'num_lower\', \'num_upper\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "