Add broadcasting support to tf.matmul.
Add NumPy-style broadcasting in the batch dimensions for the tf.matmul op. The last two dimensions of both operands constitute the matrix dimensions; the dimensions beyond these are broadcast to form a common output shape, following the standard NumPy broadcasting rules (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).

Note: This implementation differs from NumPy's behavior in that vectors (rank-1 tensors) are not promoted to matrices (rank-2 tensors) by appending/prepending dimensions.

PiperOrigin-RevId: 241040476
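As a quick illustration of the intended Python-level behavior (an editorial sketch; the shapes below are made up rather than taken from this change's tests):

    import tensorflow as tf

    x = tf.random.normal([5, 1, 2, 3])  # batch dims [5, 1], matrices of shape 2x3
    y = tf.random.normal([4, 3, 7])     # batch dims [4],    matrices of shape 3x7
    z = tf.matmul(x, y)                 # batch dims broadcast to [5, 4]
    print(z.shape)                      # (5, 4, 2, 7)

    # Unlike NumPy, rank-1 operands are not promoted to matrices, so
    # tf.matmul(tf.ones([3]), tf.ones([3, 4])) remains an error instead of
    # being treated as a 1x3 times 3x4 product.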
This commit is contained in:
parent 51dbafd723 · commit 47ab68d265
Changed paths:
- tensorflow/contrib/makefile
- tensorflow/core/api_def
- tensorflow/core/framework
- tensorflow/core/kernels: BUILD, batch_matmul_op_common.cc, batch_matmul_op_common.h, batch_matmul_op_common_test.cc, batch_matmul_op_impl.h, batch_matmul_op_test.cc
- tensorflow/core/ops
- tensorflow/python
- tensorflow/tools/api/golden
@@ -8,6 +8,7 @@ tensorflow/contrib/boosted_trees/ops/training_ops.cc
tensorflow/core/kernels/aggregate_ops.cc
tensorflow/core/kernels/argmax_op.cc
tensorflow/core/kernels/avgpooling_op.cc
tensorflow/core/kernels/batch_matmul_op_common.cc
tensorflow/core/kernels/batch_matmul_op_real.cc
tensorflow/core/kernels/batch_norm_op.cc
tensorflow/core/kernels/batchtospace_op.cc
tensorflow/core/api_def/base_api/api_def_BatchMatMulV2.pbtxt (new file, 59 lines)
@@ -0,0 +1,59 @@
op {
  graph_op_name: "BatchMatMulV2"
  in_arg {
    name: "x"
    description: <<END
2-D or higher with shape `[..., r_x, c_x]`.
END
  }
  in_arg {
    name: "y"
    description: <<END
2-D or higher with shape `[..., r_y, c_y]`.
END
  }
  out_arg {
    name: "output"
    description: <<END
3-D or higher with shape `[..., r_o, c_o]`
END
  }
  attr {
    name: "adj_x"
    description: <<END
If `True`, adjoint the slices of `x`. Defaults to `False`.
END
  }
  attr {
    name: "adj_y"
    description: <<END
If `True`, adjoint the slices of `y`. Defaults to `False`.
END
  }
  summary: "Multiplies slices of two tensors in batches."
  description: <<END
Multiplies all slices of `Tensor` `x` and `y` (each slice can be
viewed as an element of a batch), and arranges the individual results
in a single output tensor of the same batch size. Each of the
individual slices can optionally be adjointed (to adjoint a matrix
means to transpose and conjugate it) before multiplication by setting
the `adj_x` or `adj_y` flag to `True`, which are by default `False`.

The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]`
and `[..., r_y, c_y]`.

The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where:

    r_o = c_x if adj_x else r_x
    c_o = r_y if adj_y else c_y

It is computed as:

    output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])

*NOTE*: `BatchMatMulV2` supports broadcasting in the batch dimensions. More
about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).

END
}
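A rough NumPy sketch of the adjoint and output-shape rules described above (editorial illustration only; the shapes are made up):

    import numpy as np

    x = np.random.randn(4, 3, 5)  # [..., r_x=3, c_x=5]
    y = np.random.randn(4, 2, 3)  # [..., r_y=2, c_y=3]

    # With adj_x=True and adj_y=True, each output slice is adjoint(x) @ adjoint(y).
    out = np.conj(np.swapaxes(x, -1, -2)) @ np.conj(np.swapaxes(y, -1, -2))
    print(out.shape)  # (4, 5, 2): r_o = c_x = 5, c_o = r_y = 2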
@@ -0,0 +1,4 @@
op {
  graph_op_name: "BatchMatMulV2"
  visibility: HIDDEN
}
@@ -233,6 +233,43 @@ Status MatMulShape(shape_inference::InferenceContext* c) {
  return Status::OK();
}

Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) {
  ShapeHandle a_shape;
  ShapeHandle b_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));

  // Determine output rows and columns.
  bool adj_x;
  bool adj_y;
  TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x));
  TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y));
  DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
  DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1);

  // Inner dimensions should be compatible.
  DimensionHandle inner_merged;
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1),
                              c->Dim(b_shape, adj_y ? -1 : -2), &inner_merged));

  // Batch dimensions should broadcast with each other.
  ShapeHandle a_batch_shape;
  ShapeHandle b_batch_shape;
  ShapeHandle output_batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_shape));

  TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
      c, a_batch_shape, b_batch_shape, &output_batch_shape));

  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(c->Concatenate(
      output_batch_shape, c->Matrix(output_rows, output_cols), &output_shape));

  c->set_output(0, output_shape);
  return Status::OK();
}

Status BiasAddShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
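For fully known shapes, the BatchMatMulV2Shape function above behaves like the following NumPy sketch (editorial only; the real shape function also handles unknown ranks and dimensions, which this helper does not):

    import numpy as np

    def batch_matmul_v2_output_shape(a_shape, b_shape, adj_x=False, adj_y=False):
        rows = a_shape[-1] if adj_x else a_shape[-2]
        cols = b_shape[-2] if adj_y else b_shape[-1]
        # np.broadcast_shapes needs NumPy >= 1.20.
        batch = np.broadcast_shapes(tuple(a_shape[:-2]), tuple(b_shape[:-2]))
        return tuple(batch) + (rows, cols)

    print(batch_matmul_v2_output_shape([3, 1, 5, 3], [4, 3, 7]))  # (3, 4, 5, 7)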
@@ -226,6 +226,10 @@ Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
// Shape function for MatMul-like operations.
Status MatMulShape(shape_inference::InferenceContext* c);

// Shape function for Batched MatMul-like operations with broadcasting across
// batch dimensions.
Status BatchMatMulV2Shape(shape_inference::InferenceContext* c);

// Shape function for BiasAdd-like operations.
Status BiasAddShape(shape_inference::InferenceContext* c);
@@ -213,6 +213,74 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
  }
}

TEST(CommonShapeFnsTest, BatchMatMulV2_ShapeFn) {
  ShapeInferenceTestOp op("BatchMatMulV2");
  auto set_adj = [&op](bool adj_x, bool adj_y) {
    TF_ASSERT_OK(NodeDefBuilder("test", "BatchMatMulV2")
                     .Input({"a", 0, DT_FLOAT})
                     .Input({"b", 0, DT_FLOAT})
                     .Attr("adj_x", adj_x)
                     .Attr("adj_y", adj_y)
                     .Finalize(&op.node_def));
  };

  set_adj(false, false);

  // Rank checks.
  INFER_ERROR("at least rank 2", op, "[];?");
  INFER_ERROR("at least rank 2", op, "[1];?");
  INFER_ERROR("at least rank 2", op, "?;[]");
  INFER_ERROR("at least rank 2", op, "?;[2]");

  INFER_OK(op, "?;?", "?");

  // 0 batch dims.
  INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");

  // 1 batch dims.
  INFER_OK(op, "[3,?,?];[3,?,?]", "[d0_0,d0_1,d1_2]");
  INFER_OK(op, "[?,?,?];[1,?,?]", "[d0_0,d0_1,d1_2]");
  INFER_OK(op, "[?,?,?];[2,?,?]", "[d1_0,d0_1,d1_2]");
  INFER_OK(op, "[1,?,?];[?,?,?]", "[d1_0,d0_1,d1_2]");
  INFER_OK(op, "[2,?,?];[?,?,?]", "[d0_0,d0_1,d1_2]");
  INFER_OK(op, "[?,?,?];[?,?,?]", "[?,d0_1,d1_2]");

  // Empty batch dim with broadcasting.
  INFER_OK(op, "[?,?];[?,?,?]", "[d1_0,d0_0,d1_2]");
  INFER_OK(op, "[?,?,?];[?,?]", "[d0_0,d0_1,d1_1]");
  INFER_OK(op, "[?,?];[?,?,?,?]", "[d1_0,d1_1,d0_0,d1_3]");
  INFER_OK(op, "[?,?,?,?];[?,?]", "[d0_0,d0_1,d0_2,d1_1]");

  // Unknown number of batch dims.
  INFER_OK(op, "[?,?];?", "?");
  INFER_OK(op, "?;[?,?]", "?");
  INFER_OK(op, "[?,?,?,?];?", "?");

  // Large number of batch dims.
  INFER_OK(op, "[?,?,?,?,?];[1,?,?]", "[d0_0,d0_1,d0_2,d0_3,d1_2]");
  INFER_OK(op, "[1,?,?];[?,?,?,?,?]", "[d1_0,d1_1,d1_2,d0_1,d1_4]");

  // Batch dim mismatch.
  INFER_ERROR("are 2 and 3", op, "[?,?,2,?,?];[3,?,?]");
  INFER_ERROR("are 2 and 3", op, "[2,?,?];[?,?,3,?,?]");

  // Test adj_a, testing output and that inner dims are compared.
  set_adj(false, false);
  INFER_OK(op, "[2,2,3,4];[2,2,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
  INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,3,1]");  // inner dim mismatch
  set_adj(true, false);
  INFER_OK(op, "[2,2,3,4];[2,2,?,?]", "[d0_0,d0_1,d0_3,d1_3]");
  INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,3,1]");  // inner dim mismatch

  // Test adj_b=true.
  set_adj(false, true);
  INFER_OK(op, "[2,2,?,?];[2,2,3,4]", "[d0_0,d0_1,d0_2,d1_2]");
  INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,1,3]");  // inner dim mismatch
  set_adj(true, true);
  INFER_OK(op, "[2,2,?,?];[2,2,3,4]", "[d0_0,d0_1,d0_3,d1_2]");
  INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,1,3]");  // inner dim mismatch
}

TEST(CommonShapeFnsTest, BiasAddShapeTest) {
  OpRegistrationData op_reg_data;
  TF_CHECK_OK(OpDefBuilder("BiasAdd")
@@ -3572,12 +3572,26 @@ tf_cuda_cc_test(
    ],
)

tf_cc_test(
    name = "batch_matmul_op_common_test",
    size = "small",
    srcs = ["batch_matmul_op_common_test.cc"],
    deps = [
        ":batch_matmul_op",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)

tf_cuda_cc_test(
    name = "batch_matmul_op_test",
    size = "small",
    srcs = ["batch_matmul_op_test.cc"],
    deps = [
        ":batch_matmul_op",
        ":broadcast_to_op",
        ":ops_testutil",
        ":ops_util",
        "//tensorflow/core:core_cpu",

@@ -5418,6 +5432,8 @@ filegroup(
    name = "mobile_srcs",
    srcs = [
        "avgpooling_op.h",
        "batch_matmul_op_common.cc",
        "batch_matmul_op_common.h",
        "batch_util.h",
        "cwise_ops.h",
        "cwise_ops_common.h",

@@ -5912,6 +5928,7 @@ filegroup(
        "*_3d*",
        "*.cu.*",
        # Ops already in android_srcs
        "batch_matmul_op_common.cc",
        "pooling_ops_common.cc",
        # Ops which we are currently excluding because they are likely
        # not used on Android. Those ops also do not compile if included,
tensorflow/core/kernels/batch_matmul_op_common.cc (new file, 76 lines)
@@ -0,0 +1,76 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/batch_matmul_op_common.h"

namespace tensorflow {
namespace {

// Returns the mapping from the output batch indices to the corresponding
// input's batch indices, given the input's "reshape" and "bcast" shapes as
// returned by the BCast helper class. The i'th element denotes the (flattened)
// batch index of the input that must be used to compute the i'th batch output.
void ComputeBatchIndices(const int64 output_batch_size,
                         const MatMulBCast::Vec& reshape,
                         const MatMulBCast::Vec& bcast,
                         std::vector<int64>* out_indices) {
  // Populates the mapping in out_indices. This algorithm is identical to
  // the following steps:
  // - Reshape {0, 1, ..., input_batch_size - 1} to the input shape.
  // - Broadcast to the output shape.
  // - Reshape back to a flat 1D vector.
  out_indices->resize(output_batch_size);
  int64 num_output_elements = 1;
  int64 num_input_elements = 1;
  for (int64 i = reshape.size() - 1; i >= 0; --i) {
    // Replicate the already populated mapping an additional (dim - 1) times.
    // If we are broadcasting, just copy the existing mapping.
    // Otherwise, add another dimension from the input shape.
    const int64 dim = std::max(reshape[i], bcast[i]);
    const int64 incr = bcast[i] > 1 ? 0 : num_input_elements;
    for (int64 k = 0; k < (dim - 1) * num_output_elements; ++k) {
      (*out_indices)[num_output_elements + k] = (*out_indices)[k] + incr;
    }
    num_output_elements *= dim;
    num_input_elements *= reshape[i];
  }
}

}  // namespace

MatMulBCast::MatMulBCast(Vec x, Vec y) {
  if (x.size() < 2 || y.size() < 2) return;
  x.resize(x.size() - 2);
  y.resize(y.size() - 2);

  batch_bcast_ = absl::make_unique<BCast>(std::move(x), std::move(y));
  if (!batch_bcast_->IsValid()) return;

  x_batch_size_ = TensorShape(batch_bcast_->x_reshape()).num_elements();
  y_batch_size_ = TensorShape(batch_bcast_->y_reshape()).num_elements();
  output_shape_ = TensorShape(batch_bcast_->output_shape());
  output_batch_size_ = output_shape_.num_elements();
  broadcasting_required_ =
      std::min(x_batch_size_, y_batch_size_) != output_batch_size_;

  if (broadcasting_required_) {
    ComputeBatchIndices(output_batch_size_, batch_bcast_->x_reshape(),
                        batch_bcast_->x_bcast(), &x_batch_indices_);
    ComputeBatchIndices(output_batch_size_, batch_bcast_->y_reshape(),
                        batch_bcast_->y_bcast(), &y_batch_indices_);
  }
}

}  // namespace tensorflow
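The mapping built by ComputeBatchIndices above is equivalent to the reshape/broadcast/flatten recipe spelled out in its comment; a small NumPy sketch of that equivalence (editorial, with made-up batch shapes):

    import numpy as np

    def compute_batch_indices(input_batch_shape, output_batch_shape):
        # Reshape {0, ..., input_batch_size - 1} to the input batch shape,
        # broadcast it to the output batch shape, then flatten.
        input_batch_size = int(np.prod(input_batch_shape))
        idx = np.arange(input_batch_size).reshape(input_batch_shape)
        return np.broadcast_to(idx, output_batch_shape).reshape(-1)

    # x batch shape [3, 1] against y batch shape [1, 4] -> output batch shape [3, 4].
    print(compute_batch_indices([3, 1], [3, 4]))  # [0 0 0 0 1 1 1 1 2 2 2 2]
    print(compute_batch_indices([1, 4], [3, 4]))  # [0 1 2 3 0 1 2 3 0 1 2 3]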
tensorflow/core/kernels/batch_matmul_op_common.h (new file, 70 lines)
@@ -0,0 +1,70 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_COMMON_H_

#include <vector>

#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

// Simple wrapper over BCast specialized for MatMul.
// Provides utilities for broadcasting across batch dimensions for binary
// MatMul-like operations.
class MatMulBCast {
 public:
  using Vec = BCast::Vec;

  MatMulBCast(Vec x, Vec y);

  bool IsValid() const { return batch_bcast_ && batch_bcast_->IsValid(); }
  bool IsBroadcastingRequired() const { return broadcasting_required_; }

  const int64 output_batch_size() const { return output_batch_size_; }
  const int64 x_batch_size() const { return x_batch_size_; }
  const int64 y_batch_size() const { return y_batch_size_; }
  const TensorShape& output_batch_shape() const { return output_shape_; }

  // Returns the mapping from the flattened output batch indices to x's
  // flattened batch indices. The result is a vector of length
  // output_batch_size(). To compute the i'th batch output, a binary matmul-like
  // operation should use the `x_batch_indices()[i]`th batch index of `x`.
  // Note: Returns an empty vector if broadcasting is not required. Callers
  // should only use this when IsBroadcastingRequired() returns true.
  const std::vector<int64>& x_batch_indices() const { return x_batch_indices_; }
  // Returns the mapping from the flattened output batch indices to y's
  // flattened batch indices. Similar to x_batch_indices().
  // Note: Returns an empty vector if broadcasting is not required. Callers
  // should only use this when IsBroadcastingRequired() returns true.
  const std::vector<int64>& y_batch_indices() const { return y_batch_indices_; }

 private:
  std::unique_ptr<BCast> batch_bcast_;
  bool broadcasting_required_ = false;
  int64 x_batch_size_;
  int64 y_batch_size_;
  TensorShape output_shape_;
  int64 output_batch_size_;
  std::vector<int64> x_batch_indices_;
  std::vector<int64> y_batch_indices_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_COMMON_H_
tensorflow/core/kernels/batch_matmul_op_common_test.cc (new file, 138 lines)
@@ -0,0 +1,138 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/batch_matmul_op_common.h"

#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

string MatMulBCastToStr(const MatMulBCast& b) {
  if (!b.IsValid()) {
    return "invalid";
  }
  string ret;
  strings::StrAppend(
      &ret, "[", str_util::Join(b.output_batch_shape().dim_sizes(), ","), "]");
  strings::StrAppend(&ret, "[", str_util::Join(b.x_batch_indices(), ","), "]");
  strings::StrAppend(&ret, "[", str_util::Join(b.y_batch_indices(), ","), "]");
  return ret;
}

TEST(MatMulBCastTest, SimpleBroadcast) {
  MatMulBCast bcast({1, 5, 3}, {4, 3, 7});

  EXPECT_TRUE(bcast.IsValid());
  EXPECT_TRUE(bcast.IsBroadcastingRequired());

  EXPECT_EQ(1, bcast.x_batch_size());
  EXPECT_EQ(4, bcast.y_batch_size());
  EXPECT_EQ(4, bcast.output_batch_size());

  EXPECT_EQ("[4][0,0,0,0][0,1,2,3]", MatMulBCastToStr(bcast));
}

TEST(MatMulBCastTest, EmptyBatchBroadcast) {
  MatMulBCast bcast({5, 3}, {3, 7});

  EXPECT_TRUE(bcast.IsValid());
  EXPECT_FALSE(bcast.IsBroadcastingRequired());

  EXPECT_EQ(1, bcast.x_batch_size());
  EXPECT_EQ(1, bcast.y_batch_size());
  EXPECT_EQ(1, bcast.output_batch_size());

  EXPECT_EQ("[][][]", MatMulBCastToStr(bcast));
}

TEST(MatMulBCastTest, BroadcastingNotRequired) {
  MatMulBCast bcast({2, 4, 6, 5, 3}, {2, 4, 6, 3, 7});

  EXPECT_TRUE(bcast.IsValid());
  EXPECT_FALSE(bcast.IsBroadcastingRequired());

  EXPECT_EQ(48, bcast.x_batch_size());
  EXPECT_EQ(48, bcast.y_batch_size());
  EXPECT_EQ(48, bcast.output_batch_size());

  EXPECT_EQ("[2,4,6][][]", MatMulBCastToStr(bcast));
}

TEST(MatMulBCastTest, EmptyWithNonEmptyBatchBroadcast) {
  MatMulBCast bcast1({5, 3}, {6, 3, 7});

  EXPECT_TRUE(bcast1.IsValid());
  EXPECT_TRUE(bcast1.IsBroadcastingRequired());

  EXPECT_EQ(1, bcast1.x_batch_size());
  EXPECT_EQ(6, bcast1.y_batch_size());
  EXPECT_EQ(6, bcast1.output_batch_size());
  EXPECT_EQ("[6][0,0,0,0,0,0][0,1,2,3,4,5]", MatMulBCastToStr(bcast1));

  MatMulBCast bcast2({2, 5, 3}, {3, 7});
  EXPECT_TRUE(bcast2.IsValid());
  EXPECT_TRUE(bcast2.IsBroadcastingRequired());

  EXPECT_EQ(2, bcast2.x_batch_size());
  EXPECT_EQ(1, bcast2.y_batch_size());
  EXPECT_EQ(2, bcast2.output_batch_size());
  EXPECT_EQ("[2][0,1][0,0]", MatMulBCastToStr(bcast2));
}

TEST(MatMulBCastTest, InvalidDimensions) {
  // Too few dimensions.
  MatMulBCast bcast1({3, 3}, {3});
  EXPECT_FALSE(bcast1.IsValid());

  MatMulBCast bcast2({3}, {3, 3});
  EXPECT_FALSE(bcast2.IsValid());

  // Batch dimensions not broadcastable.
  MatMulBCast bcast3({4, 5, 3}, {2, 3, 7});
  EXPECT_FALSE(bcast3.IsValid());

  MatMulBCast bcast4({2, 1, 5, 3}, {1, 3, 1, 3, 7});
  EXPECT_FALSE(bcast4.IsValid());
}

TEST(MatMulBCastTest, BroadcastBothOperands) {
  MatMulBCast bcast({3, 1, 5, 3}, {1, 4, 3, 7});
  EXPECT_TRUE(bcast.IsValid());

  EXPECT_EQ(3, bcast.x_batch_size());
  EXPECT_EQ(4, bcast.y_batch_size());
  EXPECT_EQ(12, bcast.output_batch_size());

  EXPECT_EQ("[3,4][0,0,0,0,1,1,1,1,2,2,2,2][0,1,2,3,0,1,2,3,0,1,2,3]",
            MatMulBCastToStr(bcast));
}

TEST(MatMulBCastTest, DifferentRanks) {
  MatMulBCast bcast({3, 1, 5, 3}, {2, 1, 2, 3, 7});
  EXPECT_TRUE(bcast.IsValid());

  EXPECT_EQ(3, bcast.x_batch_size());
  EXPECT_EQ(4, bcast.y_batch_size());
  EXPECT_EQ(12, bcast.output_batch_size());

  EXPECT_EQ("[2,3,2][0,0,1,1,2,2,0,0,1,1,2,2][0,1,0,1,0,1,2,3,2,3,2,3]",
            MatMulBCastToStr(bcast));
}

}  // namespace
}  // namespace tensorflow
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
@ -29,7 +30,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/type_traits.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/batch_matmul_op_common.h"
|
||||
#include "tensorflow/core/kernels/fill_functor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
@ -74,8 +78,8 @@ struct ParallelMatMulKernel {
|
||||
}
|
||||
|
||||
static void Run(const OpKernelContext* context, const Tensor& in_x,
|
||||
const Tensor in_y, bool adj_x, bool adj_y, Tensor* out,
|
||||
int start, int limit) {
|
||||
const Tensor in_y, bool adj_x, bool adj_y,
|
||||
const MatMulBCast& bcast, Tensor* out, int start, int limit) {
|
||||
static_assert(IsComplex, "Complex type expected.");
|
||||
auto Tx = in_x.tensor<Scalar, 3>();
|
||||
auto Ty = in_y.tensor<Scalar, 3>();
|
||||
@ -88,14 +92,21 @@ struct ParallelMatMulKernel {
|
||||
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
|
||||
contract_pairs[0] = ContractionDims(adj_x, adj_y);
|
||||
const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
|
||||
for (int i = start; i < limit; ++i) {
|
||||
auto x = Tx.template chip<0>(i);
|
||||
|
||||
const bool should_bcast = bcast.IsBroadcastingRequired();
|
||||
const auto& x_batch_indices = bcast.x_batch_indices();
|
||||
const auto& y_batch_indices = bcast.y_batch_indices();
|
||||
for (int64 i = start; i < limit; ++i) {
|
||||
const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
|
||||
const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
|
||||
|
||||
auto x = Tx.template chip<0>(x_batch_index);
|
||||
auto z = Tz.template chip<0>(i);
|
||||
if (adj_x != adj_y) {
|
||||
auto y = Ty.template chip<0>(i).conjugate();
|
||||
auto y = Ty.template chip<0>(y_batch_index).conjugate();
|
||||
z.device(d) = x.contract(y, contract_pairs);
|
||||
} else {
|
||||
auto y = Ty.template chip<0>(i);
|
||||
auto y = Ty.template chip<0>(y_batch_index);
|
||||
z.device(d) = x.contract(y, contract_pairs);
|
||||
}
|
||||
}
|
||||
@ -110,18 +121,25 @@ struct ParallelMatMulKernel<Scalar, false> {
|
||||
static void Conjugate(const OpKernelContext* context, Tensor* out) {}
|
||||
|
||||
static void Run(const OpKernelContext* context, const Tensor& in_x,
|
||||
const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out,
|
||||
int start, int limit) {
|
||||
const Tensor& in_y, bool adj_x, bool adj_y,
|
||||
const MatMulBCast& bcast, Tensor* out, int start, int limit) {
|
||||
auto Tx = in_x.tensor<Scalar, 3>();
|
||||
auto Ty = in_y.tensor<Scalar, 3>();
|
||||
auto Tz = out->tensor<Scalar, 3>();
|
||||
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
|
||||
contract_pairs[0] = ContractionDims(adj_x, adj_y);
|
||||
const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
|
||||
for (int i = start; i < limit; ++i) {
|
||||
auto x = Tx.template chip<0>(i);
|
||||
auto y = Ty.template chip<0>(i);
|
||||
|
||||
const bool should_bcast = bcast.IsBroadcastingRequired();
|
||||
const auto& x_batch_indices = bcast.x_batch_indices();
|
||||
const auto& y_batch_indices = bcast.y_batch_indices();
|
||||
for (int64 i = start; i < limit; ++i) {
|
||||
const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
|
||||
const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
|
||||
auto x = Tx.template chip<0>(x_batch_index);
|
||||
auto y = Ty.template chip<0>(y_batch_index);
|
||||
auto z = Tz.template chip<0>(i);
|
||||
|
||||
z.device(d) = x.contract(y, contract_pairs);
|
||||
}
|
||||
}
|
||||
@ -151,10 +169,16 @@ struct SequentialMatMulKernel {
|
||||
}
|
||||
|
||||
static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x,
|
||||
bool adj_y, Tensor* out, int start, int limit) {
|
||||
for (int i = start; i < limit; ++i) {
|
||||
auto x = ConstTensorSliceToEigenMatrix(in_x, i);
|
||||
auto y = ConstTensorSliceToEigenMatrix(in_y, i);
|
||||
bool adj_y, const MatMulBCast& bcast, Tensor* out, int start,
|
||||
int limit) {
|
||||
const bool should_bcast = bcast.IsBroadcastingRequired();
|
||||
const auto& x_batch_indices = bcast.x_batch_indices();
|
||||
const auto& y_batch_indices = bcast.y_batch_indices();
|
||||
for (int64 i = start; i < limit; ++i) {
|
||||
const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
|
||||
const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
|
||||
auto x = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
|
||||
auto y = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
|
||||
auto z = TensorSliceToEigenMatrix(out, i);
|
||||
if (!adj_x) {
|
||||
if (!adj_y) {
|
||||
@ -181,13 +205,14 @@ struct LaunchBatchMatMul;
|
||||
template <typename Scalar>
|
||||
struct LaunchBatchMatMul<CPUDevice, Scalar> {
|
||||
static void Launch(OpKernelContext* context, const Tensor& in_x,
|
||||
const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
|
||||
const Tensor& in_y, bool adj_x, bool adj_y,
|
||||
const MatMulBCast& bcast, Tensor* out) {
|
||||
typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
|
||||
ParallelMatMulKernel;
|
||||
bool conjugate_result = false;
|
||||
|
||||
// Number of matrix multiplies i.e. size of the batch.
|
||||
const int64 batch_size = in_x.dim_size(0);
|
||||
const int64 batch_size = bcast.output_batch_size();
|
||||
const int64 cost_per_unit =
|
||||
in_x.dim_size(1) * in_x.dim_size(2) * out->dim_size(2);
|
||||
const int64 small_dim = std::min(
|
||||
@ -199,17 +224,17 @@ struct LaunchBatchMatMul<CPUDevice, Scalar> {
|
||||
// Parallelize over inner dims.
|
||||
// For large matrix products it is counter-productive to parallelize
|
||||
// over the batch dimension.
|
||||
ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, out, 0,
|
||||
batch_size);
|
||||
ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, bcast, out,
|
||||
0, batch_size);
|
||||
conjugate_result = adj_x;
|
||||
} else {
|
||||
// Parallelize over outer dims. For small matrices and large batches, it
|
||||
// is counter-productive to parallelize the inner matrix multiplies.
|
||||
Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
|
||||
cost_per_unit,
|
||||
[&in_x, &in_y, adj_x, adj_y, out](int start, int limit) {
|
||||
SequentialMatMulKernel<Scalar>::Run(in_x, in_y, adj_x, adj_y, out,
|
||||
start, limit);
|
||||
[&in_x, &in_y, adj_x, adj_y, &bcast, out](int start, int limit) {
|
||||
SequentialMatMulKernel<Scalar>::Run(in_x, in_y, adj_x, adj_y,
|
||||
bcast, out, start, limit);
|
||||
});
|
||||
}
|
||||
if (conjugate_result) {
|
||||
@ -270,7 +295,8 @@ class CublasScratchAllocator : public se::ScratchAllocator {
|
||||
template <typename Scalar>
|
||||
struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
||||
static void Launch(OpKernelContext* context, const Tensor& in_x,
|
||||
const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
|
||||
const Tensor& in_y, bool adj_x, bool adj_y,
|
||||
const MatMulBCast& bcast, Tensor* out) {
|
||||
constexpr se::blas::Transpose kTranspose =
|
||||
is_complex<Scalar>::value ? se::blas::Transpose::kConjugateTranspose
|
||||
: se::blas::Transpose::kTranspose;
|
||||
@ -279,7 +305,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
||||
const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
|
||||
const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
|
||||
const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
|
||||
const uint64 batch_size = in_x.dim_size(0);
|
||||
const int64 batch_size = bcast.output_batch_size();
|
||||
auto blas_transpose_a = trans[adj_x];
|
||||
auto blas_transpose_b = trans[adj_y];
|
||||
|
||||
@ -293,8 +319,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
||||
std::vector<DeviceMemoryType*> a_ptrs;
|
||||
std::vector<DeviceMemoryType*> b_ptrs;
|
||||
std::vector<DeviceMemoryType*> c_ptrs;
|
||||
a_device_memory.reserve(batch_size);
|
||||
b_device_memory.reserve(batch_size);
|
||||
a_device_memory.reserve(bcast.x_batch_size());
|
||||
b_device_memory.reserve(bcast.y_batch_size());
|
||||
c_device_memory.reserve(batch_size);
|
||||
a_ptrs.reserve(batch_size);
|
||||
b_ptrs.reserve(batch_size);
|
||||
@ -302,13 +328,31 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
||||
auto* a_base_ptr = in_x.template flat<Scalar>().data();
|
||||
auto* b_base_ptr = in_y.template flat<Scalar>().data();
|
||||
auto* c_base_ptr = out->template flat<Scalar>().data();
|
||||
for (int64 i = 0; i < batch_size; ++i) {
|
||||
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
|
||||
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
|
||||
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
|
||||
a_ptrs.push_back(&a_device_memory.back());
|
||||
b_ptrs.push_back(&b_device_memory.back());
|
||||
c_ptrs.push_back(&c_device_memory.back());
|
||||
|
||||
if (!bcast.IsBroadcastingRequired()) {
|
||||
for (int64 i = 0; i < batch_size; ++i) {
|
||||
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
|
||||
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
|
||||
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
|
||||
a_ptrs.push_back(&a_device_memory.back());
|
||||
b_ptrs.push_back(&b_device_memory.back());
|
||||
c_ptrs.push_back(&c_device_memory.back());
|
||||
}
|
||||
} else {
|
||||
const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
|
||||
const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
|
||||
for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
|
||||
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
|
||||
}
|
||||
for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
|
||||
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
|
||||
}
|
||||
for (int64 i = 0; i < batch_size; ++i) {
|
||||
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
|
||||
a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
|
||||
b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
|
||||
c_ptrs.push_back(&c_device_memory.back());
|
||||
}
|
||||
}
|
||||
|
||||
typedef Scalar Coefficient;
|
||||
@ -385,7 +429,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
||||
template <>
|
||||
struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
|
||||
static void Launch(OpKernelContext* context, const Tensor& in_x,
|
||||
const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
|
||||
const Tensor& in_y, bool adj_x, bool adj_y,
|
||||
const MatMulBCast& bcast, Tensor* out) {
|
||||
typedef Eigen::half Scalar;
|
||||
constexpr perftools::gputools::blas::Transpose kTranspose =
|
||||
is_complex<Scalar>::value
|
||||
@ -396,7 +441,7 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
|
||||
const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
|
||||
const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
|
||||
const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
|
||||
const uint64 batch_size = in_x.dim_size(0);
|
||||
const uint64 batch_size = bcast.output_batch_size();
|
||||
auto blas_transpose_a = trans[adj_x];
|
||||
auto blas_transpose_b = trans[adj_y];
|
||||
|
||||
@ -410,8 +455,8 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
|
||||
std::vector<DeviceMemoryType*> a_ptrs;
|
||||
std::vector<DeviceMemoryType*> b_ptrs;
|
||||
std::vector<DeviceMemoryType*> c_ptrs;
|
||||
a_device_memory.reserve(batch_size);
|
||||
b_device_memory.reserve(batch_size);
|
||||
a_device_memory.reserve(bcast.x_batch_size());
|
||||
b_device_memory.reserve(bcast.y_batch_size());
|
||||
c_device_memory.reserve(batch_size);
|
||||
a_ptrs.reserve(batch_size);
|
||||
b_ptrs.reserve(batch_size);
|
||||
@ -419,13 +464,31 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
|
||||
auto* a_base_ptr = in_x.template flat<Scalar>().data();
|
||||
auto* b_base_ptr = in_y.template flat<Scalar>().data();
|
||||
auto* c_base_ptr = out->template flat<Scalar>().data();
|
||||
for (int64 i = 0; i < batch_size; ++i) {
|
||||
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
|
||||
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
|
||||
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
|
||||
a_ptrs.push_back(&a_device_memory.back());
|
||||
b_ptrs.push_back(&b_device_memory.back());
|
||||
c_ptrs.push_back(&c_device_memory.back());
|
||||
|
||||
if (!bcast.IsBroadcastingRequired()) {
|
||||
for (int64 i = 0; i < batch_size; ++i) {
|
||||
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
|
||||
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
|
||||
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
|
||||
a_ptrs.push_back(&a_device_memory.back());
|
||||
b_ptrs.push_back(&b_device_memory.back());
|
||||
c_ptrs.push_back(&c_device_memory.back());
|
||||
}
|
||||
} else {
|
||||
const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
|
||||
const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
|
||||
for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
|
||||
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
|
||||
}
|
||||
for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
|
||||
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
|
||||
}
|
||||
for (int64 i = 0; i < batch_size; ++i) {
|
||||
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
|
||||
a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
|
||||
b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
|
||||
c_ptrs.push_back(&c_device_memory.back());
|
||||
}
|
||||
}
|
||||
|
||||
typedef float Coefficient;
|
||||
@ -480,17 +543,24 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
|
||||
template <typename Scalar>
|
||||
struct ParallelMatMulKernelSYCL {
|
||||
static void Run(const OpKernelContext* context, const Tensor& in_x,
|
||||
const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out,
|
||||
int start, int limit) {
|
||||
const Tensor& in_y, bool adj_x, bool adj_y,
|
||||
const MatMulBCast& bcast, Tensor* out, int start, int limit) {
|
||||
auto Tx = in_x.tensor<Scalar, 3>();
|
||||
auto Ty = in_y.tensor<Scalar, 3>();
|
||||
auto Tz = out->tensor<Scalar, 3>();
|
||||
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
|
||||
contract_pairs[0] = ContractionDims(adj_x, adj_y);
|
||||
auto d = context->eigen_sycl_device();
|
||||
for (int i = start; i < limit; ++i) {
|
||||
auto x = Tx.template chip<0>(i);
|
||||
auto y = Ty.template chip<0>(i);
|
||||
|
||||
const bool should_bcast = bcast.IsBroadcastingRequired();
|
||||
const auto& x_batch_indices = bcast.x_batch_indices();
|
||||
const auto& y_batch_indices = bcast.y_batch_indices();
|
||||
for (int64 i = start; i < limit; ++i) {
|
||||
const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
|
||||
const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
|
||||
|
||||
auto x = Tx.template chip<0>(x_batch_index);
|
||||
auto y = Ty.template chip<0>(y_batch_index);
|
||||
auto z = Tz.template chip<0>(i);
|
||||
z.device(d) = x.contract(y, contract_pairs);
|
||||
}
|
||||
@ -500,54 +570,58 @@ struct ParallelMatMulKernelSYCL {
|
||||
template <typename Scalar>
|
||||
struct LaunchBatchMatMul<SYCLDevice, Scalar> {
|
||||
static void Launch(OpKernelContext* context, const Tensor& in_x,
|
||||
const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
|
||||
const Tensor& in_y, bool adj_x, bool adj_y,
|
||||
const MatMulBCast& bcast, Tensor* out) {
|
||||
// Number of matrix multiplies i.e. size of the batch.
|
||||
const int64 batch_size = in_x.dim_size(0);
|
||||
const int64 batch_size = bcast.output_batch_size();
|
||||
ParallelMatMulKernelSYCL<Scalar>::Run(context, in_x, in_y, adj_x, adj_y,
|
||||
out, 0, batch_size);
|
||||
bcast, out, 0, batch_size);
|
||||
}
|
||||
};
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
template <typename Device, typename Scalar>
|
||||
class BatchMatMul : public OpKernel {
|
||||
class BaseBatchMatMulOp : public OpKernel {
|
||||
public:
|
||||
explicit BatchMatMul(OpKernelConstruction* context) : OpKernel(context) {
|
||||
explicit BaseBatchMatMulOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
|
||||
}
|
||||
|
||||
virtual ~BatchMatMul() {}
|
||||
~BaseBatchMatMulOp() override {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& in0 = ctx->input(0);
|
||||
const Tensor& in1 = ctx->input(1);
|
||||
OP_REQUIRES(ctx, in0.dims() == in1.dims(),
|
||||
errors::InvalidArgument("In[0] and In[1] has different ndims: ",
|
||||
in0.shape().DebugString(), " vs. ",
|
||||
in1.shape().DebugString()));
|
||||
const int ndims = in0.dims();
|
||||
|
||||
ValidateInputTensors(ctx, in0, in1);
|
||||
|
||||
MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
|
||||
OP_REQUIRES(
|
||||
ctx, ndims >= 2,
|
||||
errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
|
||||
TensorShape out_shape;
|
||||
for (int i = 0; i < ndims - 2; ++i) {
|
||||
OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
|
||||
errors::InvalidArgument(
|
||||
"In[0].dim(", i, ") and In[1].dim(", i,
|
||||
") must be the same: ", in0.shape().DebugString(), " vs ",
|
||||
in1.shape().DebugString()));
|
||||
out_shape.AddDim(in0.dim_size(i));
|
||||
}
|
||||
auto n = (ndims == 2) ? 1 : out_shape.num_elements();
|
||||
auto d0 = in0.dim_size(ndims - 2);
|
||||
auto d1 = in0.dim_size(ndims - 1);
|
||||
ctx, bcast.IsValid(),
|
||||
errors::InvalidArgument(
|
||||
"In[0] and In[1] must have compatible batch dimensions: ",
|
||||
in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));
|
||||
|
||||
TensorShape out_shape = bcast.output_batch_shape();
|
||||
auto batch_size = bcast.output_batch_size();
|
||||
auto d0 = in0.dim_size(in0.dims() - 2);
|
||||
auto d1 = in0.dim_size(in0.dims() - 1);
|
||||
Tensor in0_reshaped;
|
||||
CHECK(in0_reshaped.CopyFrom(in0, TensorShape({n, d0, d1})));
|
||||
auto d2 = in1.dim_size(ndims - 2);
|
||||
auto d3 = in1.dim_size(ndims - 1);
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
|
||||
errors::Internal("Failed to reshape In[0] from ",
|
||||
in0.shape().DebugString()));
|
||||
auto d2 = in1.dim_size(in1.dims() - 2);
|
||||
auto d3 = in1.dim_size(in1.dims() - 1);
|
||||
Tensor in1_reshaped;
|
||||
CHECK(in1_reshaped.CopyFrom(in1, TensorShape({n, d2, d3})));
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
|
||||
errors::Internal("Failed to reshape In[1] from ",
|
||||
in1.shape().DebugString()));
|
||||
if (adj_x_) std::swap(d0, d1);
|
||||
if (adj_y_) std::swap(d2, d3);
|
||||
OP_REQUIRES(ctx, d1 == d2,
|
||||
@ -568,31 +642,102 @@ class BatchMatMul : public OpKernel {
|
||||
return;
|
||||
}
|
||||
Tensor out_reshaped;
|
||||
CHECK(out_reshaped.CopyFrom(*out, TensorShape({n, d0, d3})));
|
||||
LaunchBatchMatMul<Device, Scalar>::Launch(ctx, in0_reshaped, in1_reshaped,
|
||||
adj_x_, adj_y_, &out_reshaped);
|
||||
OP_REQUIRES(ctx,
|
||||
out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
|
||||
errors::Internal("Failed to reshape output from ",
|
||||
out->shape().DebugString()));
|
||||
LaunchBatchMatMul<Device, Scalar>::Launch(
|
||||
ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, bcast, &out_reshaped);
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
|
||||
const Tensor& in1) = 0;
|
||||
|
||||
private:
|
||||
bool adj_x_;
|
||||
bool adj_y_;
|
||||
};
|
||||
|
||||
#define REGISTER_BATCH_MATMUL_CPU(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
|
||||
BatchMatMul<CPUDevice, TYPE>)
|
||||
// BatchMatMul Op implementation which disallows broadcasting.
|
||||
template <typename Device, typename Scalar>
|
||||
class BatchMatMulOp : public BaseBatchMatMulOp<Device, Scalar> {
|
||||
public:
|
||||
explicit BatchMatMulOp(OpKernelConstruction* context)
|
||||
: BaseBatchMatMulOp<Device, Scalar>(context) {}
|
||||
|
||||
#define REGISTER_BATCH_MATMUL_GPU(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
|
||||
BatchMatMul<GPUDevice, TYPE>)
|
||||
~BatchMatMulOp() override {}
|
||||
|
||||
private:
|
||||
void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
|
||||
const Tensor& in1) override {
|
||||
// Disallow broadcasting support. Ensure that all batch dimensions of the
|
||||
// input tensors match.
|
||||
OP_REQUIRES(ctx, in0.dims() == in1.dims(),
|
||||
errors::InvalidArgument("In[0] and In[1] has different ndims: ",
|
||||
in0.shape().DebugString(), " vs. ",
|
||||
in1.shape().DebugString()));
|
||||
const int ndims = in0.dims();
|
||||
OP_REQUIRES(
|
||||
ctx, ndims >= 2,
|
||||
errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
|
||||
for (int i = 0; i < ndims - 2; ++i) {
|
||||
OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
|
||||
errors::InvalidArgument(
|
||||
"In[0].dim(", i, ") and In[1].dim(", i,
|
||||
") must be the same: ", in0.shape().DebugString(), " vs ",
|
||||
in1.shape().DebugString()));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// BatchMatMul Op implementation with broadcasting support.
|
||||
template <typename Device, typename Scalar>
|
||||
class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
|
||||
public:
|
||||
explicit BatchMatMulV2Op(OpKernelConstruction* context)
|
||||
: BaseBatchMatMulOp<Device, Scalar>(context) {}
|
||||
|
||||
~BatchMatMulV2Op() override {}
|
||||
|
||||
private:
|
||||
void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
|
||||
const Tensor& in1) override {
|
||||
// Enable broadcasting support. Validity of broadcasting is checked in
|
||||
// BaseBatchMatMulOp.
|
||||
OP_REQUIRES(
|
||||
ctx, in0.dims() >= 2,
|
||||
errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));
|
||||
OP_REQUIRES(
|
||||
ctx, in1.dims() >= 2,
|
||||
errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims()));
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_BATCH_MATMUL_CPU(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
|
||||
BatchMatMulOp<CPUDevice, TYPE>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
|
||||
BatchMatMulV2Op<CPUDevice, TYPE>)
|
||||
|
||||
#define REGISTER_BATCH_MATMUL_GPU(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
|
||||
BatchMatMulOp<GPUDevice, TYPE>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("BatchMatMulV2").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
|
||||
BatchMatMulV2Op<GPUDevice, TYPE>)
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
#define REGISTER_BATCH_MATMUL_SYCL(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("BatchMatMul").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
|
||||
BatchMatMul<SYCLDevice, TYPE>)
|
||||
#define REGISTER_BATCH_MATMUL_SYCL(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("BatchMatMul").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
|
||||
BatchMatMulOp<SYCLDevice, TYPE>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("BatchMatMulV2").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
|
||||
BatchMatMulV2Op<SYCLDevice, TYPE>)
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -14,11 +14,40 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/graph/testlib.h"
|
||||
#include "tensorflow/core/kernels/broadcast_to_op.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Node* BroadcastTo(Graph* g, Node* input, Node* shape) {
|
||||
Node* ret;
|
||||
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastTo")
|
||||
.Input(input)
|
||||
.Input(shape)
|
||||
.Finalize(g, &ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
Node* BatchMatmulV2(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) {
|
||||
Node* ret;
|
||||
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMulV2")
|
||||
.Input(in0)
|
||||
.Input(in1)
|
||||
.Attr("adj_x", adj_x)
|
||||
.Attr("adj_y", adj_y)
|
||||
.Finalize(g, &ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a,
|
||||
@ -33,6 +62,40 @@ static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a,
|
||||
return g;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static Graph* BatchMatmulWithBroadcast(int b0, int b1, int m, int k, int n,
|
||||
bool manual_broadcast, DataType type) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
Tensor in0(type, TensorShape({b0, m, k}));
|
||||
in0.flat<T>().setRandom();
|
||||
Tensor in1(type, TensorShape({b1, k, n}));
|
||||
in1.flat<T>().setRandom();
|
||||
|
||||
Tensor broadcasted_in0_shape(DT_INT64, TensorShape({3}));
|
||||
Tensor broadcasted_in1_shape(DT_INT64, TensorShape({3}));
|
||||
|
||||
Node* in0_node = nullptr;
|
||||
Node* in1_node = nullptr;
|
||||
if (manual_broadcast) {
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
auto vec0 = broadcasted_in0_shape.vec<int64>();
|
||||
auto vec1 = broadcasted_in1_shape.vec<int64>();
|
||||
vec0(i) = (i == 0 ? std::max(b0, b1) : in0.shape().dim_size(i));
|
||||
vec1(i) = (i == 0 ? std::max(b0, b1) : in1.shape().dim_size(i));
|
||||
}
|
||||
in0_node = BroadcastTo(g, test::graph::Constant(g, in0),
|
||||
test::graph::Constant(g, broadcasted_in0_shape));
|
||||
in1_node = BroadcastTo(g, test::graph::Constant(g, in1),
|
||||
test::graph::Constant(g, broadcasted_in1_shape));
|
||||
} else {
|
||||
in0_node = test::graph::Constant(g, in0);
|
||||
in1_node = test::graph::Constant(g, in1);
|
||||
}
|
||||
|
||||
BatchMatmulV2(g, in0_node, in1_node, false, false);
|
||||
return g;
|
||||
}
|
||||
|
||||
#define BM_BatchMatmulDev(B, M, K, N, TA, TB, T, TFTYPE, DEVICE) \
|
||||
static void \
|
||||
BM_BatchMatmul##_##B##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE( \
|
||||
@ -59,6 +122,64 @@ static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a,
|
||||
// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu); \
|
||||
// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex<double>, DT_COMPLEX128, gpu);
|
||||
|
||||
// Macro arguments names: --------------------------------------------------- //
|
||||
// B1: batch size of LHS
|
||||
// B2: batch size of RHS
|
||||
// M: outer dimension of LHS
|
||||
// K: inner dimensions of LHS and RHS
|
||||
// N: outer dimension of RHS
|
||||
// MB: boolean indicating whether to use manual broadcasting
|
||||
// T: C++ type of scalars (e.g. float, std::complex)
|
||||
// TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128
|
||||
// D: Device (e.g. cpu, gpu)
|
||||
#define BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, T, TT, D) \
|
||||
static void \
|
||||
BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D( \
|
||||
int iters) { \
|
||||
testing::UseRealTime(); \
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * std::max(B1, B2) * M * \
|
||||
K * N * 2); \
|
||||
test::Benchmark(#D, BatchMatmulWithBroadcast<T>(B1, B2, M, K, N, MB, TT)) \
|
||||
.Run(iters); \
|
||||
} \
|
||||
BENCHMARK( \
|
||||
BM_BatchMatmulBCast##_##B1##_##B2##_##M##_##K##_##N##_##MB##_##TT##_##D);
|
||||
|
||||
#define BM_BatchMatmulBCast(B1, B2, M, K, N, MB) \
|
||||
BM_BatchMatmulBCastDev(B1, B2, M, K, N, MB, float, DT_FLOAT, cpu);
|
||||
|
||||
// Typical fully connected layers
|
||||
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, true);
|
||||
BM_BatchMatmulBCast(1, 128, 1, 1024, 1024, false);
|
||||
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, true);
|
||||
BM_BatchMatmulBCast(128, 1, 1, 1024, 1024, false);
|
||||
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, true);
|
||||
BM_BatchMatmulBCast(1, 128, 128, 1024, 1024, false);
|
||||
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, true);
|
||||
BM_BatchMatmulBCast(128, 1, 128, 1024, 1024, false);
|
||||
|
||||
// Square matmul.
|
||||
BM_BatchMatmulBCast(1, 128, 512, 512, 512, true);
|
||||
BM_BatchMatmulBCast(1, 128, 512, 512, 512, false);
|
||||
BM_BatchMatmulBCast(128, 1, 512, 512, 512, true);
|
||||
BM_BatchMatmulBCast(128, 1, 512, 512, 512, false);
|
||||
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, true);
|
||||
BM_BatchMatmulBCast(1, 128, 1024, 1024, 1024, false);
|
||||
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, true);
|
||||
BM_BatchMatmulBCast(128, 1, 1024, 1024, 1024, false);
|
||||
|
||||
// Matrix-vector multiplies.
|
||||
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, true);
|
||||
BM_BatchMatmulBCast(1, 128, 10000, 200, 1, false);
|
||||
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, true);
|
||||
BM_BatchMatmulBCast(128, 1, 10000, 200, 1, false);
|
||||
|
||||
// Vector-matrix multiplies.
|
||||
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, true);
|
||||
BM_BatchMatmulBCast(1, 128, 1, 200, 10000, false);
|
||||
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, true);
|
||||
BM_BatchMatmulBCast(128, 1, 1, 200, 10000, false);
|
||||
|
||||
// Typical fully connected layers
|
||||
BM_BatchMatmul(1, 1, 1024, 1024, false, false);
|
||||
BM_BatchMatmul(1, 8, 1024, 1024, false, false);
|
||||
@ -132,4 +253,5 @@ BM_BatchMatmul(1, 1, 200, 10000, true, true);
|
||||
BM_BatchMatmul(8, 1, 200, 10000, true, true);
|
||||
BM_BatchMatmul(32, 1, 200, 10000, true, true);
|
||||
|
||||
} // end namespace tensorflow
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@@ -158,6 +158,17 @@ REGISTER_OP("BatchMatMul")
      return Status::OK();
    });

REGISTER_OP("BatchMatMulV2")
    .Input("x: T")
    .Input("y: T")
    .Output("output: T")
    .Attr(
        "T: {bfloat16, half, float, double, int32, int64, complex64, "
        "complex128}")
    .Attr("adj_x: bool = false")
    .Attr("adj_y: bool = false")
    .SetShapeFn(shape_inference::BatchMatMulV2Shape);

// --------------------------------------------------------------------------
// Casting Ops
//
@@ -2782,6 +2782,7 @@ py_library(
        ":state_ops_gen",
        ":tensor_shape",
        ":util",
        "//tensorflow/python/compat",
        "//tensorflow/python/eager:context",
        "//third_party/py/numpy",
    ],
@ -21,63 +21,43 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradient_checker_v2
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.linalg import linear_operator_util
|
||||
from tensorflow.python.platform import benchmark
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def GetRandomNormalInput(shape, dtype):
|
||||
# float16 has limited range so we reduce the variance of the scalars.
|
||||
scale = 10.0 if dtype != np.float16 else 0.1
|
||||
loc = -10.0 if dtype != np.float16 else 0.1
|
||||
vals = np.array(np.random.normal(loc, scale, np.prod(shape)), dtype=dtype)
|
||||
if dtype in (np.complex64, np.complex128):
|
||||
imag = np.array(np.random.normal(loc, scale, np.prod(shape)), dtype=dtype)
|
||||
vals += 1j * imag
|
||||
return vals.reshape(shape)
|
||||
|
||||
|
||||
class BatchMatmulOpTest(test.TestCase):
|
||||
|
||||
# Uses numpy to compute batch_matmul(x, y, adjoint_a, adjoint_b).
|
||||
def _npBatchMatmul(self, x, y, adjoint_a, adjoint_b):
|
||||
# output's shape depends on adj[0] and adj[1]
|
||||
d0 = x.shape[-2] if not adjoint_a else x.shape[-1]
|
||||
d2 = y.shape[-1] if not adjoint_b else y.shape[-2]
|
||||
batch_dims = x.shape[:-2]
|
||||
num = np.prod(batch_dims)
|
||||
z = np.empty(list(batch_dims) + [d0, d2], dtype=x.dtype)
|
||||
xr = x.reshape([num, x.shape[-2], x.shape[-1]])
|
||||
yr = y.reshape([num, y.shape[-2], y.shape[-1]])
|
||||
zr = z.reshape([num, z.shape[-2], z.shape[-1]])
|
||||
for i in range(num):
|
||||
a = np.matrix(xr[i, :, :])
|
||||
if adjoint_a:
|
||||
a = a.transpose().conj()
|
||||
b = np.matrix(yr[i, :, :])
|
||||
if adjoint_b:
|
||||
b = b.transpose().conj()
|
||||
zr[i, :, :] = a * b
|
||||
return z
|
||||
if adjoint_a:
|
||||
x = np.conjugate(np.swapaxes(x, -1, -2))
|
||||
if adjoint_b:
|
||||
y = np.conjugate(np.swapaxes(y, -1, -2))
|
||||
return np.matmul(x, y)
|
||||
|
||||
  # Test _npBatchMatMul works.
  def testNpVersion(self):
    x = np.array([0., 1., 2., 3.]).reshape([1, 2, 2])
    y = np.array([1., 2., 3., 4.]).reshape([1, 2, 2])
    z0 = self._npBatchMatmul(x, y, False, False)
    z1 = np.array([3., 4., 11., 16.]).reshape([1, 2, 2])
    self.assertTrue(np.array_equal(z0, z1))

    x = np.array([1., (1j), (-1.), (-1j)]).reshape([1, 2, 2])
    y = x * np.complex(1, 1)  # rotate x 90 degree
    z0 = self._npBatchMatmul(x, y, False, False)
    z1 = np.array([2., (2.j), -2., (-2.j)]).reshape([1, 2, 2])
    self.assertTrue(np.array_equal(z0, z1))

    z0 = self._npBatchMatmul(x, y, False, True)
    z1 = np.array([(2. - 2.j), (-2. + 2.j), (-2. + 2.j), (2. - 2.j)]).reshape(
        [1, 2, 2])
    self.assertTrue(np.array_equal(z0, z1))

    z0 = self._npBatchMatmul(x, y, True, False)
    z1 = np.array([(2. + 2.j), (-2. + 2.j), (2. - 2.j), (2. + 2.j)]).reshape(
        [1, 2, 2])
    self.assertTrue(np.array_equal(z0, z1))

  # Compares _tfpBatchMatmul(x, y, alpha, adj) and _npBatchMatMul(x, y, alpha,
  # adj)
  def _compare(self, x_in, y_in, adjoint_a, adjoint_b, static_shape=True):
  # Compares TensorFlow BatchMatmul with NumPy's matmul.
  def _compare(self, x_in, y_in, adjoint_a, adjoint_b, static_shape):
    x_t_shape = x_in.shape[:-2] + (x_in.shape[-1], x_in.shape[-2])
    y_t_shape = y_in.shape[:-2] + (y_in.shape[-1], y_in.shape[-2])
    x = x_in if not adjoint_a else x_in.reshape(x_t_shape)
@ -97,19 +77,15 @@ class BatchMatmulOpTest(test.TestCase):
      z1 = self._npBatchMatmul(x, y, adjoint_a, adjoint_b)
      self.assertAllClose(z0_val, z1, rtol=tol, atol=tol)

  def _rand(self, shape, dtype):
    vals = np.array(np.random.normal(-10, 10, np.prod(shape)), dtype=dtype)
    if dtype in (np.complex64, np.complex128):
      imag = np.array(np.random.normal(-10, 10, np.prod(shape)), dtype=dtype)
      vals += 1j * imag
    return vals.reshape(shape)

  def _testNonEmpty(self, dtype, adjoint_a, adjoint_b, use_static_shape):

    def CompareNonEmpty(self, a_shape, b_shape):
      self._compare(
          self._rand(a_shape, dtype),
          self._rand(b_shape, dtype), adjoint_a, adjoint_b, use_static_shape)
          GetRandomNormalInput(a_shape, dtype),
          GetRandomNormalInput(b_shape, dtype),
          adjoint_a,
          adjoint_b,
          static_shape=use_static_shape)

    CompareNonEmpty(self, [1, 2, 3], [1, 3, 5])
    CompareNonEmpty(self, [1, 2, 3], [1, 3, 1])
@ -121,13 +97,33 @@ class BatchMatmulOpTest(test.TestCase):
    CompareNonEmpty(self, [10, 64, 75], [10, 75, 30])
    CompareNonEmpty(self, [5, 7, 2, 3], [5, 7, 3, 5])

  def _testBroadcasting(self, dtype, adjoint_a, adjoint_b, use_static_shape):

    def CompareNonEmpty(self, a_shape, b_shape):
      self._compare(
          GetRandomNormalInput(a_shape, dtype),
          GetRandomNormalInput(b_shape, dtype),
          adjoint_a,
          adjoint_b,
          static_shape=use_static_shape)

    CompareNonEmpty(self, [2, 3], [1, 3, 5])
    CompareNonEmpty(self, [1, 2, 3], [3, 5])
    CompareNonEmpty(self, [5, 1, 2, 3], [1, 7, 3, 5])
    CompareNonEmpty(self, [5, 2, 2, 3], [3, 5])
    CompareNonEmpty(self, [2, 3], [5, 2, 3, 5])
    CompareNonEmpty(self, [4, 5, 1, 2, 3], [1, 1, 3, 5])
    CompareNonEmpty(self, [1, 2, 1, 4, 2, 1, 3, 4], [3, 2, 1, 1, 1, 2, 4, 2])
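
For reference, the batch shape each pair above is expected to broadcast to can be computed with plain NumPy; a small sketch (the helper name is ours):

import numpy as np

def broadcast_batch_shape(a_shape, b_shape):
  # Broadcast everything except the trailing two (matrix) dimensions.
  return np.broadcast(np.empty(a_shape[:-2]), np.empty(b_shape[:-2])).shape

print(broadcast_batch_shape([5, 1, 2, 3], [1, 7, 3, 5]))     # (5, 7)
print(broadcast_batch_shape([4, 5, 1, 2, 3], [1, 1, 3, 5]))  # (4, 5, 1)
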

  def _testEmpty(self, dtype, adjoint_a, adjoint_b, use_static_shape):

    def CompareEmpty(self, a_shape, b_shape):
      self._compare(
          np.zeros(a_shape).astype(dtype),
          np.zeros(b_shape).astype(dtype), adjoint_a, adjoint_b,
          use_static_shape)
          np.zeros(b_shape).astype(dtype),
          adjoint_a,
          adjoint_b,
          static_shape=use_static_shape)

    CompareEmpty(self, [0, 3, 2], [0, 2, 4])
    CompareEmpty(self, [3, 0, 2], [3, 2, 5])

@ -136,7 +132,6 @@ class BatchMatmulOpTest(test.TestCase):

def _GetBatchMatmulOpTest(dtype, adjoint_a, adjoint_b, use_static_shape):

  @test_util.run_v1_only("b/120545219")
  def Test(self):
    np.random.seed(42)
    self._testNonEmpty(dtype, adjoint_a, adjoint_b, use_static_shape)

@ -145,6 +140,18 @@ def _GetBatchMatmulOpTest(dtype, adjoint_a, adjoint_b, use_static_shape):
  return Test


def _GetBatchMatmulOpBroadcastingTest(dtype, adjoint_a, adjoint_b,
                                      use_static_shape):

  @test_util.disable_xla("b/128537983")
  def Test(self):
    with compat.forward_compatibility_horizon(2019, 4, 19):
      np.random.seed(42)
      self._testBroadcasting(dtype, adjoint_a, adjoint_b, use_static_shape)

  return Test
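
The forward-compatibility gate used by the broadcasting test above follows the usual TensorFlow pattern: production code checks compat.forward_compatible(), and tests pin the horizon so the new path is exercised. A minimal sketch of that pattern (the helper function is hypothetical; the dates mirror the ones in this diff):

from tensorflow.python.compat import compat

def pick_kernel():
  # Guarded code takes the new path once the compatibility window has passed.
  if compat.forward_compatible(2019, 4, 18):
    return "BatchMatMulV2"
  return "BatchMatMul"

# Tests advance the horizon explicitly instead of waiting for real time to pass.
with compat.forward_compatibility_horizon(2019, 4, 19):
  assert pick_kernel() == "BatchMatMulV2"
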


class BatchMatmulGradientTest(test.TestCase):

  # loss = sum(batch_matmul(x, y)). Verify dl/dx and dl/dy via the
@ -155,45 +162,125 @@ class BatchMatmulGradientTest(test.TestCase):
    x = x_in if not adjoint_a else x_in.reshape(x_t_shape)
    y = y_in if not adjoint_b else y_in.reshape(y_t_shape)
    epsilon = np.finfo(x.dtype).eps
    delta = epsilon**(1.0 / 3.0)
    # Since our gradient is linear, a larger delta decreases the error.
    delta = 10 * epsilon**(1.0 / 3.0)

    def Loss(x, y):
      z = math_ops.matmul(x, y, adjoint_a, adjoint_b)
      return math_ops.reduce_sum(z)
      return math_ops.reduce_sum(math_ops.matmul(x, y, adjoint_a, adjoint_b))

    with self.cached_session(use_gpu=True):
      ((x_jacob_t, y_jacob_t),
       (x_jacob_n, y_jacob_n)) = gradient_checker_v2.compute_gradient(
           Loss, [x, y], delta=delta)
      tol = 20 * delta
      tol = 10 * delta
      self.assertAllClose(x_jacob_t, x_jacob_n, rtol=tol, atol=tol)
      self.assertAllClose(y_jacob_t, y_jacob_n, rtol=tol, atol=tol)

  # Tests a batched matmul of x, and y: x is a 3D tensor of shape [b,
  # n, k] y is a 3D tensor of shape [b, k, m] the batched matmul
  # computes z of shape [b, n, m], where z[i, :, :] = x[i, :, :]
  # matmul y[i, :, :]
  def _compare(self, b, n, k, m, dtype, adjoint_a, adjoint_b):
  # Tests gradients of a batched matmul of x, and y
  def _compare(self, a_shape, b_shape, dtype, adjoint_a, adjoint_b):
    np.random.seed(42)
    x = np.random.normal(0, 1, b * n * k).astype(dtype).reshape([b, n, k])
    if dtype in (np.complex64, np.complex128):
      x.imag = np.random.normal(0, 1,
                                b * n * k).astype(dtype).reshape([b, n, k])
    y = np.random.normal(0, 1, b * k * m).astype(dtype).reshape([b, k, m])
    if dtype in (np.complex64, np.complex128):
      y.imag = np.random.normal(0, 1,
                                b * k * m).astype(dtype).reshape([b, k, m])
    x = GetRandomNormalInput(a_shape, dtype)
    y = GetRandomNormalInput(b_shape, dtype)
    self._checkGrad(x, y, adjoint_a, adjoint_b)


def _GetBatchMatmulGradientTest(dtype, adjoint_a, adjoint_b):

  @test_util.run_v1_only("b/120545219")
  def Test(self):
    self._compare(1, 2, 3, 5, dtype, adjoint_a, adjoint_b)
    self._compare(3, 4, 7, 10, dtype, adjoint_a, adjoint_b)

    def CheckGradients(self, a_shape, b_shape):
      self._compare(a_shape, b_shape, dtype, adjoint_a, adjoint_b)

    CheckGradients(self, [1, 2, 3], [1, 3, 5])
    CheckGradients(self, [3, 4, 7], [3, 7, 10])

  return Test


def _GetBatchMatmulGradientWithBroadcastingTest(dtype, adjoint_a, adjoint_b):

  @test_util.disable_xla("b/128537983")
  def Test(self):

    def CheckGradients(self, a_shape, b_shape):
      self._compare(a_shape, b_shape, dtype, adjoint_a, adjoint_b)

    with compat.forward_compatibility_horizon(2019, 4, 19):
      CheckGradients(self, [1, 5, 2, 3], [7, 1, 3, 2])
      CheckGradients(self, [2, 3], [1, 3, 5])
      CheckGradients(self, [2, 3], [5, 3, 5])
      CheckGradients(self, [5, 2, 5], [5, 3])
      CheckGradients(self, [5, 2, 2, 3], [3, 5])
      CheckGradients(self, [4, 5, 1, 2, 3], [1, 1, 3, 5])
      CheckGradients(self, [1, 2, 1, 4, 2, 1, 3, 4], [3, 2, 1, 1, 1, 2, 4, 2])

  return Test


class BatchMatMulBenchmark(test.Benchmark):
  # Batch sizes are 512.
  shape_pairs = [
      # Typical fully connected layer.
      ((4, 8, 4, 2, 1, 1024), (1024, 1024)),
      ((4, 1, 4, 1, 1, 1024), (1, 8, 1, 2, 1024, 1024)),
      # Square matmul.
      ((4, 8, 4, 2, 512, 512), (512, 512)),
      ((4, 1, 4, 1, 512, 512), (1, 8, 1, 2, 512, 512)),
      # Matrix-vector multiplies.
      ((4, 8, 4, 2, 10000, 200), (200, 1)),
      ((4, 1, 4, 1, 10000, 200), (1, 8, 1, 2, 200, 1)),
      # Vector-matrix multiplies.
      ((4, 8, 4, 2, 1, 200), (200, 10000)),
      ((4, 1, 4, 1, 1, 200), (1, 8, 1, 2, 200, 10000)),
  ]

  def benchmarkBatchMatMulBroadcast(self):
    for (a_shape, b_shape) in self.shape_pairs:
      with compat.forward_compatibility_horizon(2019, 4, 19):
        with ops.Graph().as_default(), \
            session.Session(config=benchmark.benchmark_config()) as sess, \
            ops.device("/cpu:0"):
          matrix_a = variables.Variable(
              GetRandomNormalInput(a_shape, np.float32))
          matrix_b = variables.Variable(
              GetRandomNormalInput(b_shape, np.float32))
          variables.global_variables_initializer().run()

          # Use batch matmul op's internal broadcasting.
          self.run_op_benchmark(
              sess,
              math_ops.matmul(matrix_a, matrix_b),
              min_iters=50,
              name="batch_matmul_cpu_{}_{}".format(a_shape, b_shape))

          # Manually broadcast the input matrices using the broadcast_to op.
          broadcasted_batch_shape = array_ops.broadcast_static_shape(
              matrix_a.shape[:-2], matrix_b.shape[:-2])
          broadcasted_a_shape = broadcasted_batch_shape.concatenate(
              matrix_a.shape[-2:])
          broadcasted_b_shape = broadcasted_batch_shape.concatenate(
              matrix_b.shape[-2:])
          self.run_op_benchmark(
              sess,
              math_ops.matmul(
                  array_ops.broadcast_to(matrix_a, broadcasted_a_shape),
                  array_ops.broadcast_to(matrix_b, broadcasted_b_shape)),
              min_iters=50,
              name="batch_matmul_manual_broadcast_cpu_{}_{}".format(
                  a_shape, b_shape))

          # Use linear_operator_util.matmul_with_broadcast.
          name_template = (
              "batch_matmul_manual_broadcast_with_linear_operator_util"
              "_cpu_{}_{}"
          )
          self.run_op_benchmark(
              sess,
              linear_operator_util.matmul_with_broadcast(matrix_a, matrix_b),
              min_iters=50,
              name=name_template.format(a_shape, b_shape))
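
The benchmark above compares the op's internal broadcasting against explicitly materializing the operands with broadcast_to; the two paths should agree numerically. A small eager-mode equivalence sketch (ours; the shapes are illustrative, not the benchmark shapes):

import numpy as np
import tensorflow as tf

a = np.random.rand(5, 1, 2, 3).astype(np.float32)
b = np.random.rand(1, 7, 3, 4).astype(np.float32)

implicit = tf.matmul(a, b)  # broadcasting handled inside the op
explicit = tf.matmul(tf.broadcast_to(a, [5, 7, 2, 3]),
                     tf.broadcast_to(b, [5, 7, 3, 4]))  # materialized copies
np.testing.assert_allclose(implicit.numpy(), explicit.numpy(), rtol=1e-5)
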


if __name__ == "__main__":
  for dtype_ in [
      np.float16, np.float32, np.float64, np.complex64, np.complex128, np.int32
@ -201,13 +288,27 @@ if __name__ == "__main__":
    for adjoint_a_ in False, True:
      for adjoint_b_ in False, True:
        name = "%s_%s_%s" % (dtype_.__name__, adjoint_a_, adjoint_b_)
        # TF2 does not support placeholders under eager so we skip it
        # TF2 does not support placeholders under eager so we skip it.
        for use_static_shape_ in set([True, tf2.enabled()]):
          setattr(BatchMatmulOpTest,
                  "testBatchMatmulOp_" + name + ("_%s" % use_static_shape_),
                  _GetBatchMatmulOpTest(dtype_, adjoint_a_, adjoint_b_,
                                        use_static_shape_))
      if dtype_ is not np.int32:
        setattr(BatchMatmulGradientTest, "testBatchMatmulGradient_" + name,
                _GetBatchMatmulGradientTest(dtype_, adjoint_a_, adjoint_b_))
          setattr(
              BatchMatmulOpTest,
              "testBatchMatmulOp_" + name + "_{}".format(use_static_shape_),
              _GetBatchMatmulOpTest(dtype_, adjoint_a_, adjoint_b_,
                                    use_static_shape_))
          # Broadcasting is supported only in v2.
          setattr(
              BatchMatmulOpTest, "testBatchMatmulBroadcasting_" + name +
              ("_%s" % use_static_shape_),
              _GetBatchMatmulOpBroadcastingTest(dtype_, adjoint_a_, adjoint_b_,
                                                use_static_shape_))
        if dtype_ == np.int32:
          continue
        setattr(BatchMatmulGradientTest, "testBatchMatmulGradient_" + name,
                _GetBatchMatmulGradientTest(dtype_, adjoint_a_, adjoint_b_))
        # Broadcasting is supported only in v2.
        setattr(
            BatchMatmulGradientTest,
            "testBatchMatmulGradientWithBroadcasting_" + name,
            _GetBatchMatmulGradientWithBroadcastingTest(dtype_, adjoint_a_,
                                                        adjoint_b_))
  test.main()

@ -1473,6 +1473,44 @@ def _BatchMatMul(op, grad):
  return grad_x, grad_y


@ops.RegisterGradient("BatchMatMulV2")
def _BatchMatMulV2(op, grad):
  """Returns the gradient of x and y given the gradient of x * y."""
  x = op.inputs[0]
  y = op.inputs[1]
  adj_x = op.get_attr("adj_x")
  adj_y = op.get_attr("adj_y")

  if not adj_x:
    if not adj_y:
      grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True)
      grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False)
    else:
      grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False)
      grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False)
  else:
    if not adj_y:
      grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True)
      grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False)
    else:
      grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True)
      grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True)

  # Reduce along the broadcasted batch dimensions, if broadcasting is required.
  shape_x_static = x.get_shape()
  shape_y_static = y.get_shape()
  if not (shape_x_static.is_fully_defined() and
          shape_y_static.is_fully_defined() and
          shape_x_static == shape_y_static):
    sx = array_ops.shape(x)
    sy = array_ops.shape(y)
    rx, ry = gen_array_ops.broadcast_gradient_args(sx[:-2], sy[:-2])
    grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, rx), sx)
    grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, ry), sy)

  return grad_x, grad_y
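
A worked example of the batch-dimension reduction in the gradient above: if x has batch shape [5, 1] and y has batch shape [7], the incoming gradient has batch shape [5, 7], so dL/dx must be summed over the axes where x was broadcast and dL/dy over the axes where y was broadcast before reshaping back. A NumPy sketch of the same reduction (ours, for illustration):

import numpy as np

x = np.random.rand(5, 1, 2, 3)     # batch shape [5, 1]
y = np.random.rand(7, 3, 4)        # batch shape [7]
grad = np.random.rand(5, 7, 2, 4)  # output batch shape broadcasts to [5, 7]

grad_x_full = np.matmul(grad, np.swapaxes(y, -1, -2))  # (5, 7, 2, 3)
grad_x = grad_x_full.sum(axis=1, keepdims=True)        # reduce broadcast axis -> (5, 1, 2, 3)

grad_y_full = np.matmul(np.swapaxes(x, -1, -2), grad)  # (5, 7, 3, 4)
grad_y = grad_y_full.sum(axis=0)                       # reduce broadcast axis -> (7, 3, 4)

assert grad_x.shape == x.shape and grad_y.shape == y.shape
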
ops.NotDifferentiable("Range")
|
||||
ops.NotDifferentiable("LinSpace")
|
||||
|
||||
|

@ -73,6 +73,7 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.compat import compat as fwd_compat
from tensorflow.python.eager import context
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
@ -2547,9 +2548,20 @@ def matmul(a,
    # TODO(apassos) remove _shape_tuple here when it is not needed.
    a_shape = a._shape_tuple()  # pylint: disable=protected-access
    b_shape = b._shape_tuple()  # pylint: disable=protected-access

    if fwd_compat.forward_compatible(2019, 4, 18):
      output_may_have_non_empty_batch_shape = (
          (a_shape is None or len(a_shape) > 2) or
          (b_shape is None or len(b_shape) > 2))
      batch_mat_mul_fn = gen_math_ops.batch_mat_mul_v2
    else:
      output_may_have_non_empty_batch_shape = (
          (a_shape is None or len(a_shape) > 2) and
          (b_shape is None or len(b_shape) > 2))
      batch_mat_mul_fn = gen_math_ops.batch_mat_mul

    if (not a_is_sparse and
        not b_is_sparse) and ((a_shape is None or len(a_shape) > 2) and
                              (b_shape is None or len(b_shape) > 2)):
        not b_is_sparse) and output_may_have_non_empty_batch_shape:
      # BatchMatmul does not support transpose, so we conjugate the matrix and
      # use adjoint instead. Conj() is a noop for real matrices.
      if transpose_a:
@ -2558,8 +2570,7 @@ def matmul(a,
      if transpose_b:
        b = conj(b)
        adjoint_b = True
      return gen_math_ops.batch_mat_mul(
          a, b, adj_x=adjoint_a, adj_y=adjoint_b, name=name)
      return batch_mat_mul_fn(a, b, adj_x=adjoint_a, adj_y=adjoint_b, name=name)

    # Neither matmul nor sparse_matmul support adjoint, so we conjugate
    # the matrix and use transpose instead. Conj() is a noop for real
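
With the V2 path above, the batch route is taken when either operand has rank greater than 2 (previously both had to), so a rank-2 weight matrix can be applied to a batched input without tiling. A short sketch of the user-visible effect (shapes are illustrative):

import numpy as np
import tensorflow as tf

batched_input = np.ones([8, 4, 16], dtype=np.float32)  # rank 3
weights = np.ones([16, 32], dtype=np.float32)          # rank 2

out = tf.matmul(batched_input, weights)  # rank-2 operand broadcasts over the batch
print(out.shape)  # expected: (8, 4, 32)
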

@ -316,6 +316,10 @@ tf_module {
    name: "BatchMatMul"
    argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
  }
  member_method {
    name: "BatchMatMulV2"
    argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
  }
  member_method {
    name: "BatchMatrixBandPart"
    argspec: "args=[\'input\', \'num_lower\', \'num_upper\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

@ -316,6 +316,10 @@ tf_module {
    name: "BatchMatMul"
    argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
  }
  member_method {
    name: "BatchMatMulV2"
    argspec: "args=[\'x\', \'y\', \'adj_x\', \'adj_y\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
  }
  member_method {
    name: "BatchMatrixBandPart"
    argspec: "args=[\'input\', \'num_lower\', \'num_upper\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "