Extend functionality of tf.einsum to the full NumPy spec. (part 1 of 2).

This CL creates a C++ Einsum Op which will be the basis of tf.einsum.

Functionality improvements:
 - Support repeated indices / generalized traces (e.g. ijjk,klm->ikm)
 - Improve ellipsis labels to support NumPy-style broadcasting. Previously, ellipsis
   only supported broadcasting across identical batch shapes
   (e.g. [1,1,2] was not compatible with [3,1]). See the short example below.
 - Ellipsis also supports partially-unknown and unknown shapes.

Performance improvements:
 - Take advantage of transpose_a/transpose_b params in matmul to improve performance in
   common matmul-like cases.
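
For reference, the kinds of equations this enables, illustrated with np.einsum
(whose spec this CL follows); the shapes below are arbitrary examples:

    import numpy as np
    x = np.random.randn(2, 3, 3, 4)
    y = np.random.randn(4, 5, 6)
    np.einsum('ijjk,klm->ikm', x, y)        # repeated index j: generalized trace
    a = np.ones((1, 1, 2, 3))
    b = np.ones((3, 1, 3, 4))
    np.einsum('...ij,...jk->...ik', a, b)   # batch shapes (1, 1) and (3, 1) broadcast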

PiperOrigin-RevId: 248830232
Anudhyan Boral 2019-05-17 20:37:53 -07:00 committed by TensorFlower Gardener
parent 39c4ef4be6
commit 8e8f040cec
15 changed files with 1498 additions and 0 deletions


@@ -959,6 +959,7 @@ tf_cuda_library(
"util/guarded_philox_random.h",
"util/mirror_pad_mode.h",
"util/padding.h",
"util/einsum_op_util.h",
"util/port.h",
"util/ptr_util.h",
"util/reffed_status_callback.h",


@@ -0,0 +1,100 @@
op {
graph_op_name: "Einsum"
in_arg {
name: "inputs"
description: <<END
List of 1 or 2 Tensors.
END
}
out_arg {
name: "output"
description: <<END
Output Tensor with shape depending upon `equation`.
END
}
attr {
name: "equation"
description: <<END
String describing the Einstein Summation operation; in the format of np.einsum.
END
}
summary: "Tensor contraction according to Einstein summation convention."
description: <<END
Implements generalized Tensor contraction and reduction. Each input Tensor must
have a corresponding input subscript appearing in the comma-separated left-hand
side of the equation. The right-hand side of the equation consists of the
output subscript. The input subscripts and the output subscript should consist
of zero or more named axis labels and at most one ellipsis (`...`).
The named axis labels may be any single character other than those having
special meaning, namely `,.->`. The behavior of this Op is undefined if it
receives an ill-formatted equation; since the validation is done at
graph-building time, we omit format validation checks at runtime.
Note: This Op is *not* intended to be called by the user; instead users should
call `tf.einsum` directly. It is a hidden Op used by `tf.einsum`.
Operations are applied to the input(s) according to the following rules:
(a) Generalized Diagonals: For input dimensions corresponding to axis labels
appearing more than once in the same input subscript, we take the
generalized (`k`-dimensional) diagonal.
For example, in the equation `iii->i` with input shape `[3, 3, 3]`, the
generalized diagonal would consist of `3` elements at indices `(0, 0, 0)`,
`(1, 1, 1)` and `(2, 2, 2)` to create a Tensor of shape `[3]`.
(b) Reduction: Axes corresponding to labels appearing only in one input
subscript but not in the output subscript are summed over prior to Tensor
contraction.
For example, in the equation `ab,bc->b`, the axis labels `a` and `c` are
the reduction axis labels.
(c) Batch Dimensions: Axes corresponding to labels appearing in each of the
input subscripts and also in the output subscript make up the batch
dimensions in Tensor contraction. Unnamed axis labels corresponding to
ellipsis (`...`) also correspond to batch dimensions.
For example, for the equation denoting batch matrix multiplication,
`bij,bjk->bik`, the axis label `b` corresponds to a batch dimension.
(d) Contraction: In case of binary einsum, axes corresponding to labels
appearing in two different inputs (and not in the output) are contracted
against each other.
Considering the batch matrix multiplication equation again
(`bij,bjk->bik`), the contracted axis label is `j`.
(e) Expand Diagonal: If the output subscripts contain repeated (explicit) axis
labels, the opposite operation of (a) is applied. For example, in the
equation `i->iii` with input shape `[3]`, the output of shape `[3, 3, 3]`
is all zeros, except for the (generalized) diagonal which is populated
with values from the input.
Note: This operation is not supported by `np.einsum` or `tf.einsum`; it is
provided to enable computing the symbolic gradient of `tf.einsum`.
The output subscripts must contain only labels appearing in at least one of the
input subscripts. Furthermore, all dimensions mapping to the same axis label
must be equal.
Each of the input and output subscripts may contain at most a single ellipsis
(`...`). These ellipses are mapped against dimensions not corresponding to any
named axis label. If two inputs contain ellipsis, then they are broadcasted
according to standard NumPy broadcasting
[rules](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
The broadcasted dimensions are placed in the corresponding location of the
ellipsis in the output subscript. If the broadcasted dimensions are non-empty
and the output subscripts do not contain ellipsis, then an InvalidArgument error
is raised.
@compatibility(numpy)
Similar to [`numpy.einsum`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.einsum.html).
Comparison with `numpy.einsum`:
* This Op only supports unary and binary forms of `numpy.einsum`.
* This Op does not support the implicit form (i.e. equations without `->`).
* This Op also supports repeated indices in the output subscript, which is not
supported by `numpy.einsum`.
@end_compatibility
END
}
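
As a quick reference, rules (a) through (d) above match np.einsum; rule (e) is the
extension used for gradients. A small NumPy sketch (shapes chosen arbitrarily):

    import numpy as np
    x = np.random.randn(3, 3, 3)
    np.einsum('iii->i', x)                   # (a) generalized diagonal, shape [3]
    a, b = np.random.randn(2, 3), np.random.randn(3, 4)
    np.einsum('ab,bc->b', a, b)              # (b) 'a' and 'c' are reduced (summed)
    p, q = np.random.randn(5, 2, 3), np.random.randn(5, 3, 4)
    np.einsum('bij,bjk->bik', p, q)          # (c)+(d) batch label 'b', contraction over 'j'
    # (e) 'i->iii' is rejected by np.einsum; this Op supports it for gradient computation.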


@@ -0,0 +1,4 @@
op {
graph_op_name: "Einsum"
visibility: HIDDEN
}


@@ -13,8 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/einsum_op_util.h"
namespace tensorflow {
@@ -233,6 +241,178 @@ Status MatMulShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
namespace {
// Validate that an Einsum subscript contains at most one ellipsis, and that
// periods (.) occur only within an ellipsis (...).
Status ValidateEinsumEllipsis(absl::string_view subscript,
bool* found_ellipsis) {
const int num_periods = absl::c_count(subscript, '.');
if (num_periods != 0 && num_periods != 3) {
return errors::InvalidArgument(
"Expected at most one ellipsis (...), but found ", num_periods,
" periods (.) in the input subscript: ", subscript);
}
if (num_periods == 3 && !absl::StrContains(subscript, "...")) {
return errors::InvalidArgument(
"Periods found outside of ellipsis in subscript: ", subscript);
}
*found_ellipsis = num_periods > 0;
return Status::OK();
}
} // namespace
Status EinsumShape(shape_inference::InferenceContext* c) {
// We assume that the equation has a valid format: either (x),(y)->(z)
// or (x)->(z), where each of (x), (y) and (z) is a concatenation of zero or
// more Latin letters and contains at most one ellipsis ('...').
string equation;
TF_RETURN_IF_ERROR(c->GetAttr("equation", &equation));
gtl::InlinedVector<string, 2> input_labels;
string output_labels;
TF_RETURN_IF_ERROR(
ParseEinsumEquation(equation, &input_labels, &output_labels));
if (c->num_inputs() == 0 || c->num_inputs() > 2) {
return errors::InvalidArgument("Expected either 1 or 2 inputs but got: ",
c->num_inputs());
}
if (c->num_inputs() != input_labels.size()) {
return errors::InvalidArgument("Expected ", input_labels.size(),
" inputs for equation ", equation,
" but got: ", c->num_inputs());
}
// Validate input subscripts, build the label to dimension mapping and obtain
// the broadcast shapes that map to ellipsis.
absl::flat_hash_map<char, DimensionHandle> label_to_dimension;
gtl::InlinedVector<ShapeHandle, 2> input_bcast_shapes(c->num_inputs());
for (int i = 0; i < c->num_inputs(); ++i) {
bool has_ellipsis = false;
TF_RETURN_IF_ERROR(ValidateEinsumEllipsis(input_labels[i], &has_ellipsis));
ShapeHandle input_shape = c->input(i);
// Validate that the input rank is sufficient for the given number of named
// labels.
if (c->RankKnown(input_shape)) {
if (has_ellipsis) {
const int num_named_labels =
static_cast<int>(input_labels[i].size()) - 3;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
c->WithRankAtLeast(input_shape, num_named_labels, &input_shape),
" for ", i, "th input and equation: ", equation);
} else {
const int num_named_labels = static_cast<int>(input_labels[i].size());
TF_RETURN_WITH_CONTEXT_IF_ERROR(
c->WithRank(input_shape, num_named_labels, &input_shape), " for ",
i, "th input and equation: ", equation);
}
}
bool seen_ellipsis = false;
input_bcast_shapes[i] = c->Scalar();
// Run through the input labels; populate label_to_dimension mapping and
// compute the broadcast shapes corresponding to the ellipsis (if present).
for (int label_idx = 0; label_idx < input_labels[i].size(); ++label_idx) {
const char label = input_labels[i][label_idx];
// Calculate the input axis that the current label is referring to. After
// the ellipsis, the axis may be found by using negative indices; i.e. the
// (rank - k)th dimension corresponds to the (num_labels - k)th label.
const int64 axis_before_ellipsis = label_idx;
const int64 axis_after_ellipsis =
c->RankKnown(input_shape)
? label_idx + c->Rank(input_shape) - input_labels[i].size()
: -1;
// Populate the input broadcast shape when we encounter an ellipsis (...).
if (label == '.') {
if (!c->RankKnown(input_shape)) {
input_bcast_shapes[i] = c->UnknownShape();
} else {
// The broadcast shape runs till the named label right after the
// ellipsis, the label with index (label_idx + 3).
TF_RETURN_IF_ERROR(c->Subshape(input_shape, axis_before_ellipsis,
axis_after_ellipsis + 3,
&input_bcast_shapes[i]));
}
label_idx += 2; // Skip the rest of the ellipsis.
seen_ellipsis = true;
continue;
}
// Obtain the dimension that the current label corresponds to.
int64 axis = seen_ellipsis ? axis_after_ellipsis : axis_before_ellipsis;
DimensionHandle new_dim = c->RankKnown(input_shape)
? c->Dim(input_shape, axis)
: c->UnknownDim();
// If we've seen this label before, make sure previous and current
// dimensions are compatible.
if (label_to_dimension.contains(label)) {
DimensionHandle merged;
TF_RETURN_IF_ERROR(
c->Merge(label_to_dimension[label], new_dim, &merged));
label_to_dimension[label] = merged;
} else {
label_to_dimension[label] = new_dim;
}
}
}
// For two inputs, broadcast the two input broadcast shapes to create the
// output broadcast shape. For one input, just copy the single broadcast
// shape.
ShapeHandle output_bcast_shape;
if (input_bcast_shapes.size() == 1) {
output_bcast_shape = input_bcast_shapes[0];
} else if (input_bcast_shapes.size() == 2) {
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
c, input_bcast_shapes[0], input_bcast_shapes[1], &output_bcast_shape));
}
bool output_has_ellipsis = false;
TF_RETURN_IF_ERROR(
ValidateEinsumEllipsis(output_labels, &output_has_ellipsis));
if (output_has_ellipsis) {
// If the output subscript has ellipsis and the output broadcast rank is
// unknown, then the output shape should have unknown rank.
if (!c->RankKnown(output_bcast_shape)) {
c->set_output(0, c->UnknownShape());
return Status::OK();
}
} else {
// If the output subscripts don't have ellipsis then make sure the output
// broadcasting shape is empty.
TF_RETURN_WITH_CONTEXT_IF_ERROR(
c->WithRankAtMost(output_bcast_shape, 0, &output_bcast_shape),
" for einsum equation '", equation,
"' without ellipsis (...) in the output subscripts where input(s) have "
"non-empty broadcasting shape");
output_bcast_shape = c->Scalar();
}
// Create the output shape from output labels and label_to_dimension mapping.
std::vector<DimensionHandle> output_dims;
for (int label_idx = 0; label_idx < output_labels.size(); ++label_idx) {
const char label = output_labels[label_idx];
// Append the output_bcast_shape when the ellipsis is encountered.
if (label == '.') {
for (int k = 0; k < c->Rank(output_bcast_shape); ++k) {
output_dims.push_back(c->Dim(output_bcast_shape, k));
}
label_idx += 2; // Skip the rest of the ellipsis.
continue;
}
auto dimension_it = label_to_dimension.find(label);
if (dimension_it == label_to_dimension.end()) {
return errors::InvalidArgument(
"Einsum output subscripts for equation '", equation, "' has label '",
label, "' which is not present in the input subscripts");
}
output_dims.push_back(dimension_it->second);
}
c->set_output(0, c->MakeShape(output_dims));
return Status::OK();
}
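
For intuition, the shape arithmetic above mirrors what np.einsum does with fully
known shapes: the ellipsis sub-shapes are broadcast and the named labels are
appended. A small sketch:

    import numpy as np
    a = np.zeros((5, 1, 2, 3))    # ellipsis maps to (5, 1); i=2, j=3
    b = np.zeros((7, 3, 4))       # ellipsis maps to (7,);   j=3, k=4
    out = np.einsum('...ij,...jk->...ik', a, b)
    print(out.shape)              # (5, 7, 2, 4): (5, 1) and (7,) broadcast to (5, 7)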
Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) {
ShapeHandle a_shape;
ShapeHandle b_shape;


@@ -230,6 +230,9 @@ Status MatMulShape(shape_inference::InferenceContext* c);
// batch dimensions.
Status BatchMatMulV2Shape(shape_inference::InferenceContext* c);
// Shape function for Einsum.
Status EinsumShape(shape_inference::InferenceContext* c);
// Shape function for BiasAdd-like operations.
Status BiasAddShape(shape_inference::InferenceContext* c);


@@ -213,6 +213,129 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
}
}
TEST(CommonShapeFnsTest, Einsum_ShapeFn) {
ShapeInferenceTestOp op("Einsum");
auto set_equation = [&op](int n, string equation) {
std::vector<NodeDefBuilder::NodeOut> input_list;
input_list.reserve(n);
for (int i = 0; i < n; ++i) {
input_list.emplace_back("a", 0, DT_FLOAT);
}
TF_ASSERT_OK(NodeDefBuilder("test", "Einsum")
.Input(input_list)
.Attr("equation", equation)
.Finalize(&op.node_def));
};
// Unary cases.
set_equation(1, "abc->c");
INFER_OK(op, "[?,?,?]", "[d0_2]");
set_equation(1, "abc->aabbcc");
INFER_OK(op, "[?,?,?]", "[d0_0,d0_0,d0_1,d0_1,d0_2,d0_2]");
set_equation(1, "abc->");
INFER_OK(op, "[?,?,?]", "[]");
set_equation(1, "->");
INFER_OK(op, "[]", "[]");
// Binary cases.
set_equation(2, "ij,jk->ik");
INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");
set_equation(2, "ij,jk->ik");
INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");
set_equation(2, "ab,ab->");
INFER_OK(op, "[?,?];[?,?]", "[]");
set_equation(2, "ab,->b");
INFER_OK(op, "[?,?];[]", "[d0_1]");
set_equation(2, ",->");
INFER_OK(op, "[];[]", "[]");
set_equation(2, "aaa,b->abbb");
INFER_OK(op, "[?,?,?];[?]", "[d0_0,d1_0,d1_0,d1_0]");
set_equation(2, ",abcd->badc");
INFER_OK(op, "[];[?,?,?,?]", "[d1_1,d1_0,d1_3,d1_2]");
// Ellipsis cases.
set_equation(1, "a...bc->c...");
INFER_OK(op, "[?,?,?,?,?]", "[d0_4,d0_1,d0_2]");
set_equation(2, "...ij,...jk->...ik");
INFER_OK(op, "[?,?,?,?,?];[1,?,?]", "[d0_0,d0_1,d0_2,d0_3,d1_2]");
INFER_OK(op, "[1,?,?];[?,?,?,?,?]", "[d1_0,d1_1,d1_2,d0_1,d1_4]");
// Unknown rank.
set_equation(1, "abc->c");
INFER_OK(op, "?", "[?]");
set_equation(1, "a...bc->c");
INFER_OK(op, "?", "[?]");
set_equation(1, "a...bc->c...");
INFER_OK(op, "?", "?");
set_equation(2, "...ij,...jk->...ik");
INFER_OK(op, "?;?", "?");
INFER_OK(op, "[?,?,?];?", "?");
INFER_OK(op, "?;[?,?,?]", "?");
set_equation(2, "...ij,...jk->ik");
INFER_OK(op, "?;?", "[?,?]");
set_equation(2, "abd,b...c->...cad");
INFER_OK(op, "[?,?,?];[?,?,?,?]", "[d1_1,d1_2,d1_3,d0_0,d0_2]");
set_equation(2, "...ab,b...c->ac...");
INFER_OK(op, "[?,1,?,?];[?,?,?]", "[d0_2,d1_2,d0_0,d1_1]");
// Wrong number of inputs.
set_equation(2, "ab->b");
INFER_ERROR("got: 2", op, "[?,?];[?,?]");
set_equation(1, "ab,a->b");
INFER_ERROR("got: 1", op, "[?,?]");
// Invalid format. Implicit form is not supported.
set_equation(1, "a");
INFER_ERROR("equation", op, "[2]");
set_equation(2, "ab,bc");
INFER_ERROR("equation", op, "[2,2];[2,2]");
// Wrong number of ellipsis or periods outside of ellipsis.
set_equation(1, "..a.->a...");
INFER_ERROR("ellipsis", op, "[1,1,2,1]");
set_equation(1, "...a->.a..");
INFER_ERROR("ellipsis", op, "[1,1,1,2]");
set_equation(1, "...a...->...a");
INFER_ERROR("ellipsis", op, "[1,1,1,2]");
set_equation(1, "..a..b..->...ab");
INFER_ERROR("ellipsis", op, "[1,1,2,1]");
set_equation(2, "...a...,ab->a");
INFER_ERROR("ellipsis", op, "[1,2,1];[2,1]");
set_equation(2, "a,...ab...->a");
INFER_ERROR("ellipsis", op, "[2];[1,2,1,1]");
set_equation(2, "a,ab->a......");
INFER_ERROR("ellipsis", op, "[2];[2,1]");
// Output label doesn't appear in input.
set_equation(1, "abc->d");
INFER_ERROR("'d'", op, "[?,?,?]");
// Mismatch in input rank.
set_equation(1, "abc->c");
INFER_ERROR("4", op, "[?,?,?,?]");
INFER_ERROR("2", op, "[?,?]");
set_equation(1, "...abc->...c");
INFER_ERROR("2", op, "[?,?]");
// Input dimensions are not consistent.
set_equation(2, "ab,ab->a");
INFER_ERROR("are 1 and 2", op, "[1,2];[2,1]");
set_equation(2, "aa,bb->a");
INFER_ERROR("are 1 and 2", op, "[1,2];[2,2]");
// Invalid broadcasting dimensions.
set_equation(2, "...ij,...jk->...ik");
INFER_ERROR("are 2 and 3", op, "[2,?,?];[3,?,?]");
set_equation(2, "i...j,jk...->...ik");
INFER_ERROR("are 2 and 3", op, "[?,2,?];[?,?,3]");
set_equation(2, "...ij,...jk->ik");
set_equation(2, "i...j,jk...->ik");
INFER_ERROR("non-empty broadcasting", op, "[?,2,?];[?,?]");
set_equation(2, "...ab,b...c->ac...");
INFER_OK(op, "?;[4,5,3]", "?");
}
TEST(CommonShapeFnsTest, BatchMatMulV2_ShapeFn) {
ShapeInferenceTestOp op("BatchMatMulV2");
auto set_adj = [&op](bool adj_x, bool adj_y) {


@@ -3115,6 +3115,7 @@ cc_library(
":cholesky_grad",
":cholesky_op",
":determinant_op",
":einsum_op",
":lu_op",
":matrix_exponential_op",
":matrix_inverse_op",
@@ -3313,6 +3314,21 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "einsum_op",
prefix = "einsum_op",
deps = [
":batch_matmul_op",
":reduction_ops",
":transpose_functor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/eigen3",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "linalg_ops_common",
srcs = ["linalg_ops_common.cc"],


@@ -0,0 +1,711 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_split.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
#include "tensorflow/core/kernels/reduction_ops_common.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/einsum_op_util.h"
namespace tensorflow {
namespace {
using CPUDevice = Eigen::ThreadPoolDevice;
using ShapeVec = gtl::InlinedVector<int64, 8>;
using Labels = gtl::InlinedVector<int, 8>;
using OperandLabels = gtl::InlinedVector<Labels, 2>;
using LabelCounts = gtl::InlinedVector<int, 8>;
using OperandLabelCounts = gtl::InlinedVector<LabelCounts, 2>;
using LabelToDimSizes = gtl::InlinedVector<int64, 8>;
// Dummy axis label used to denote an ellipsis in an input or output subscript.
constexpr int kEllipsisLabel = -1;
// Each dimension is categorized into exactly one of five types based on whether
// its corresponding label is present in the input and/or the output subscripts.
enum DimensionType {
// Batch dimensions are those present in two inputs as well as the output.
// They are part of the batch dimensions during Tensor contraction.
// Such dimensions may be broadcasting dimensions (those mapping to ellipsis)
// or explicit batch dimensions corresponding to named axis labels.
kBroadcasting = 0,
kBatch = 1,
// Free dimensions are present in exactly one of the inputs, and also the
// output. These are non-contracted axes in the Tensor contraction.
kFree = 2,
// Contract dimensions are present in two inputs, but not the output. These
// dimensions are contracted in Tensor contraction.
kContract = 3,
// Reduce dimensions are present in exactly one input and not in the output;
// they are summed over prior to Tensor contraction.
kReduce = 4,
};
// Returns the DimensionType given whether the corresponding label is present in
// exactly one input subscript (is_unique) and whether it is absent from the
// output subscripts (is_removed). Does not handle broadcasting dimensions.
DimensionType GetDimensionType(bool is_removed, bool is_unique) {
if (!is_removed && !is_unique)
return kBatch;
else if (!is_removed && is_unique)
return kFree;
else if (is_removed && !is_unique)
return kContract;
else // is_removed && is_unique
return kReduce;
}
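
A minimal Python sketch of the same classification (the string names are just for
illustration):

    def dimension_type(is_removed, is_unique):
        if not is_removed and not is_unique:
            return 'batch'      # in both inputs and in the output
        if not is_removed and is_unique:
            return 'free'       # in exactly one input and in the output
        if is_removed and not is_unique:
            return 'contract'   # in both inputs, absent from the output
        return 'reduce'         # in exactly one input, absent from the output

    # For 'ab,bc->b': 'b' is a batch label; 'a' and 'c' are reduce labels.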
// Maps the character labels to consecutive integers.
void MapToLabels(const string& subscript, Labels* labels,
absl::flat_hash_map<char, int>* label_mapping) {
for (int i = 0; i < subscript.size(); ++i) {
const char label_char = subscript[i];
if (label_char == '.') {
labels->push_back(kEllipsisLabel);
i += 2; // Skip next 2 characters as well.
continue;
}
if (!label_mapping->contains(label_char)) {
const int next_label = label_mapping->size();
(*label_mapping)[label_char] = next_label;
}
const int mapped_label = (*label_mapping)[label_char];
labels->push_back(mapped_label);
}
}
// Parses and validates the equation. Single character labels are integerized,
// and we populate the input and output label subscripts and their corresponding
// counts. Also creates the mapping from (named) labels to their DimensionType.
Status ParseEquation(const string& equation, OperandLabels* input_labels,
Labels* output_labels,
std::vector<DimensionType>* label_types,
OperandLabelCounts* input_label_counts,
LabelCounts* output_label_counts,
gtl::InlinedVector<bool, 2>* input_has_ellipsis,
bool* output_has_ellipsis) {
gtl::InlinedVector<string, 2> input_str;
string output_str;
TF_RETURN_IF_ERROR(ParseEinsumEquation(equation, &input_str, &output_str));
// Temporary map from single character labels to (consecutive) integer labels.
absl::flat_hash_map<char, int> label_mapping;
int num_inputs = input_str.size();
input_labels->resize(num_inputs);
// Map from single characters to integer labels.
for (int i = 0; i < num_inputs; ++i) {
MapToLabels(input_str[i], &input_labels->at(i), &label_mapping);
}
MapToLabels(output_str, output_labels, &label_mapping);
// Compute counts for input and output labels.
int num_labels = label_mapping.size();
input_label_counts->resize(num_inputs);
input_has_ellipsis->resize(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
input_label_counts->at(i).resize(num_labels);
for (const int label : input_labels->at(i)) {
if (label != kEllipsisLabel)
input_label_counts->at(i)[label] += 1;
else
input_has_ellipsis->at(i) = true;
}
}
output_label_counts->resize(num_labels);
for (const int label : *output_labels) {
if (label != kEllipsisLabel)
output_label_counts->at(label) += 1;
else
*output_has_ellipsis = true;
}
// Map each label to a unique DimensionType.
label_types->resize(num_labels);
for (int label = 0; label < num_labels; ++label) {
if (label == kEllipsisLabel) continue;
bool removed = (*output_label_counts)[label] == 0;
bool unique = num_inputs == 1 || (*input_label_counts)[0][label] == 0 ||
(*input_label_counts)[1][label] == 0;
(*label_types)[label] = GetDimensionType(removed, unique);
}
return Status::OK();
}
// Insert new (unnamed) broadcasting labels at the location of ellipsis.
void InsertBroadcastLabels(int num_bcast_dims, int num_named_labels,
int ellipsis_axis, Labels* labels,
LabelCounts* label_counts) {
labels->erase(labels->begin() + ellipsis_axis);
labels->insert(labels->begin() + ellipsis_axis, num_bcast_dims, 0);
std::iota(labels->begin() + ellipsis_axis,
labels->begin() + ellipsis_axis + num_bcast_dims, num_named_labels);
// Increment label counts. Since these are new labels, the count is set to 1.
label_counts->resize(num_named_labels + num_bcast_dims, 1);
}
// Record and validate the label to dimension mapping. Must be a named
// (non-broadcasting) label as broadcasting labels don't have a fixed dimension.
Status RecordLabelToDimension(const int label, const int axis,
const Tensor& input,
LabelToDimSizes* label_to_dim_sizes) {
const int64 input_dim = input.dim_size(axis);
// We know that label_to_dim_sizes has the size to accommodate named labels.
if (label_to_dim_sizes->at(label) != 0 &&
label_to_dim_sizes->at(label) != input_dim) {
return errors::InvalidArgument(
"Expected dimension ", label_to_dim_sizes->at(label), " at axis ", axis,
" of the input shaped ", input.shape().DebugString(),
" but got dimension ", input_dim);
}
(*label_to_dim_sizes)[label] = input_dim;
return Status::OK();
}
// Validate input dimensions and populate unnamed labels and their label counts.
Status ProcessDimensions(const OpInputList& inputs,
const gtl::InlinedVector<bool, 2>& input_has_ellipsis,
const bool output_has_ellipsis,
OperandLabels* input_labels, Labels* output_labels,
std::vector<DimensionType>* label_types,
OperandLabelCounts* input_label_counts,
LabelCounts* output_label_counts,
LabelToDimSizes* label_to_dim_sizes) {
if (inputs.size() != input_labels->size()) {
return errors::InvalidArgument("Expected ", input_labels->size(),
" inputs but got: ", inputs.size());
}
const int num_inputs = inputs.size();
// We infer the number of broadcasting dimensions by taking the maximum rank
// among the broadcasting subshapes of the input.
int max_bcast_dims = 0;
const int num_named_labels = label_types->size();
label_to_dim_sizes->resize(num_named_labels);
for (int i = 0; i < num_inputs; ++i) {
Labels* labels = &(*input_labels)[i];
if (!input_has_ellipsis[i]) {
if (inputs[i].dims() != labels->size()) {
return errors::InvalidArgument("Expected input ", i, " to have rank ",
labels->size(),
" but got: ", inputs[i].dims());
}
for (int label_idx = 0; label_idx < labels->size(); ++label_idx) {
const int label = (*labels)[label_idx];
TF_RETURN_IF_ERROR(RecordLabelToDimension(label, label_idx, inputs[i],
label_to_dim_sizes));
}
continue;
}
// Input has an ellipsis.
if (inputs[i].dims() + 1 < labels->size()) {
return errors::InvalidArgument(
"Expected input ", i, " to have rank at least ", labels->size() - 1,
" but got: ", inputs[i].dims());
}
int ellipsis_axis = -1;
const int num_bcast_dims = inputs[i].dims() - labels->size() + 1;
for (int label_idx = 0; label_idx < labels->size(); ++label_idx) {
const int label = (*labels)[label_idx];
if (label == kEllipsisLabel) {
ellipsis_axis = label_idx;
continue;
}
// Current label is not an ellipsis.
const int axis =
label_idx + (ellipsis_axis == -1 ? 0 : num_bcast_dims - 1);
TF_RETURN_IF_ERROR(
RecordLabelToDimension(label, axis, inputs[i], label_to_dim_sizes));
}
// Found an ellipsis. Replace 'kEllipsisLabel' with broadcasting dimensions.
if (ellipsis_axis != -1) {
InsertBroadcastLabels(num_bcast_dims, num_named_labels, ellipsis_axis,
labels, &input_label_counts->at(i));
max_bcast_dims = std::max(max_bcast_dims, num_bcast_dims);
}
}
if (!absl::c_linear_search(input_has_ellipsis, true) &&
!output_has_ellipsis) {
return Status::OK();
}
// Insert broadcasting dimensions in the output labels.
auto it =
std::find(output_labels->begin(), output_labels->end(), kEllipsisLabel);
if (it != output_labels->end()) {
const int ellipsis_axis = it - output_labels->begin();
InsertBroadcastLabels(max_bcast_dims, num_named_labels, ellipsis_axis,
output_labels, output_label_counts);
} else if (max_bcast_dims > 0) {
return errors::InvalidArgument("Output contains ", max_bcast_dims,
" broadcasting dimension(s) but no ellipsis "
"(...) was found in the output subscripts.");
}
// Populate DimensionType for the new broadcasting labels.
label_types->resize(num_named_labels + max_bcast_dims, kBroadcasting);
return Status::OK();
}
// Permutes the labels according to the given permutation.
void PermuteLabels(const std::vector<int>& permutation, Labels* labels) {
Labels permuted_labels(labels->size());
for (int i = 0; i < labels->size(); ++i) {
permuted_labels[i] = (*labels)[permutation[i]];
}
labels->swap(permuted_labels);
}
// Returns a reshaped input Tensor. The underlying buffer is not copied.
Status CopyFrom(const Tensor& input, const TensorShape& shape, Tensor* output) {
if (output->CopyFrom(input, shape)) return Status::OK();
return errors::Internal(
"Encountered error while reshaping a Tensor of shape ",
input.shape().DebugString(), " to shape ", shape.DebugString());
}
// Returns whether a transpose is needed; returns false (transposing would be a
// no-op) when the input has rank < 2 or the permutation is the identity.
bool ShouldTranspose(const TensorShape& input_shape,
const std::vector<int>& permutation) {
if (input_shape.dims() < 2) return false;
for (int i = 0; i < permutation.size(); ++i) {
if (permutation[i] != i) return true;
}
return false;
}
// Transpose the input given a permutation. Returns a reference to the input if
// transposing is not necessary.
template <typename Device, typename T>
Status TransposeOperand(OpKernelContext* ctx, const Tensor& input,
const std::vector<int>& permutation, Tensor* output) {
if (!ShouldTranspose(input.shape(), permutation)) {
return CopyFrom(input, input.shape(), output);
}
TensorShape transposed_shape;
for (int i = 0; i < input.dims(); ++i) {
transposed_shape.AddDim(input.dim_size(permutation[i]));
}
TF_RETURN_IF_ERROR(
ctx->allocate_temp(DataTypeToEnum<T>::value, transposed_shape, output));
const Device& device = ctx->eigen_device<Device>();
TF_RETURN_IF_ERROR(DoTranspose(device, input, permutation, output));
return Status::OK();
}
// If there are repeated labels in either the input or output, then this strides
// the input (e.g. iii->i) or inflates it (e.g. i->iii), respectively.
template <typename Device, typename T>
Status StrideOrInflate(OpKernelContext* ctx, const Tensor& input,
const Labels& labels, const LabelCounts& label_counts,
const bool should_inflate, Tensor* output) {
// Return early if there are no repeated indices.
if (absl::c_all_of(label_counts, [](int c) { return c <= 1; })) {
return CopyFrom(input, input.shape(), output);
}
// We reshape so that each repeated label is compressed to one dimension.
// E.g. for iiij -> ij, the shape [3, 3, 3, 5] would be compressed to [27, 5].
// Striding appropriately (in this case with strides 13 (=1+3+9) and 1)
// recovers the generalized diagonal of shape [3, 5].
ShapeVec reshape;
ShapeVec strides;
// Strided and inflated shapes correspond to input and output shapes,
// respectively, when should_inflate is true (and vice-versa when should_inflate
// is false). E.g. they are [3, 5] and [3, 3, 3, 5] in the above example.
ShapeVec strided_shape;
ShapeVec inflated_shape;
for (int label : labels) {
const int count = label_counts[label];
const int current_axis =
should_inflate ? strided_shape.size() : inflated_shape.size();
const int64 dim = input.dim_size(current_axis);
strided_shape.push_back(dim);
inflated_shape.insert(inflated_shape.end(), count, dim);
const int64 reshape_dim = MathUtil::IPow(dim, count);
reshape.push_back(reshape_dim);
// While taking the d-diagonal in a rank k Tensor, we take d equally-spaced
// elements including the first and last element. Then,
// (d - 1) * stride = d^k - 1, or, stride = (d^k - 1)/(d - 1).
const int64 stride =
(dim > 1 && count > 1) ? (reshape_dim - 1) / (dim - 1) : 1;
strides.push_back(stride);
}
TensorShape output_shape =
TensorShape(should_inflate ? inflated_shape : strided_shape);
TF_RETURN_IF_ERROR(
ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
const Device& device = ctx->eigen_device<Device>();
switch (reshape.size()) {
#define NDIMS_CASE(N) \
case N: { \
if (should_inflate) { \
auto output_map = output->shaped<T, N>(reshape); \
auto input_map = input.shaped<T, N>(strided_shape); \
output_map.device(device) = input_map.inflate(strides); \
} else { \
auto input_map = input.shaped<T, N>(reshape); \
auto output_map = output->shaped<T, N>(strided_shape); \
output_map.device(device) = input_map.stride(strides); \
} \
} break;
NDIMS_CASE(1);
NDIMS_CASE(2);
NDIMS_CASE(3);
NDIMS_CASE(4);
NDIMS_CASE(5);
NDIMS_CASE(6);
default:
return errors::Unimplemented(
"Unsupported rank: ", reshape.size(),
" while handling repeated indices. Up to rank 6 is supported.");
#undef NDIMS_CASE
}
return Status::OK();
}
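
The stride arithmetic used above can be checked with a small NumPy sketch, using the
shape [3, 3, 3, 5] from the comment, where the stride is (3^3 - 1)/(3 - 1) = 13:

    import numpy as np
    x = np.random.randn(3, 3, 3, 5)
    flat = x.reshape(27, 5)                  # compress the repeated label to one axis
    diag = flat[::13]                        # rows 0, 13, 26 -> generalized diagonal
    np.testing.assert_allclose(diag, np.einsum('iiij->ij', x))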
// Returns true if the input dimensions are already sorted in the order
// [broadcasting, batch, contract, free, reduce]. Used to implement an
// optimization that avoids an extra transpose by setting adj_x (or adj_y) in
// BatchMatMul instead.
bool ShouldSwapFreeAndContract(const Labels& labels,
const std::vector<DimensionType>& label_types) {
// Check that ordering is according to dimension type, with the role of
// free and contract dimensions swapped.
gtl::InlinedVector<int, 5> remap = {0, 1, 3, 2, 4};
for (int i = 0; i + 1 < labels.size(); ++i) {
const int dimtype_a = remap[label_types[labels[i]]];
const int dimtype_b = remap[label_types[labels[i + 1]]];
if (dimtype_a > dimtype_b ||
(dimtype_a == dimtype_b && labels[i] > labels[i + 1])) {
return false;
}
}
return true;
}
template <typename Device, typename T>
Status ReduceOperand(OpKernelContext* ctx, const Tensor& input,
const std::vector<DimensionType>& label_types,
const LabelCounts& label_counts, Labels* labels,
Labels* free_labels, bool* swap_free_and_contract,
Tensor* output) {
// Find the permutation to transpose the input dimensions in the order of
// DimensionType; i.e. batch, free, contract and reduce dimensions. This
// makes it more convenient to invoke Reduce/Contract operations.
std::vector<int> permutation(input.dims());
absl::c_iota(permutation, 0);
Tensor input_transposed;
// Check if we can avoid the transpose. We need to flip the adj_x (or adj_y)
// flag during BatchMatMul. This is an extra optimization not necessary for
// correctness.
if (ShouldSwapFreeAndContract(*labels, label_types)) {
*swap_free_and_contract = true;
} else {
absl::c_sort(permutation, [&](int i, int j) {
int label_i = (*labels)[i];
int label_j = (*labels)[j];
return std::tie(label_types[label_i], label_i) <
std::tie(label_types[label_j], label_j);
});
}
// Transpose the input so that DimensionTypes are in order.
TF_RETURN_IF_ERROR(
TransposeOperand<Device, T>(ctx, input, permutation, &input_transposed));
PermuteLabels(permutation, labels);
// Take the generalized diagonal for dimensions with repeated axis labels.
Tensor input_deduped;
labels->erase(std::unique(labels->begin(), labels->end()), labels->end());
TF_RETURN_IF_ERROR(
StrideOrInflate<Device, T>(ctx, input_transposed, *labels, label_counts,
false /* should_inflate */, &input_deduped));
// Reshape denotes the rank-5 shape [broadcast, batch, free, contract, reduce]
// where we've compacted the dimensions of each DimensionType.
gtl::InlinedVector<int64, 5> reshape(5, 1);
// The output shape is [batch shape] + [free size, contract size]
// That is, the batch shape is preserved (for broadcasting while contracting)
// while the free dims and contract dims are compressed to one dimension each.
TensorShape output_shape;
for (int label_idx = 0; label_idx < labels->size(); ++label_idx) {
const int label = labels->at(label_idx);
int64 dim = input_deduped.dim_size(label_idx);
if (label_types[label] == kBroadcasting || label_types[label] == kBatch) {
output_shape.AddDim(dim);
} else if (label_types[label] == kFree) {
free_labels->push_back(label);
}
reshape[label_types[label]] *= dim;
}
if (*swap_free_and_contract) std::swap(reshape[kFree], reshape[kContract]);
output_shape.AddDim(reshape[kFree]);
output_shape.AddDim(reshape[kContract]);
if (reshape[kReduce] == 1) { // No need to actually reduce.
return CopyFrom(input_deduped, output_shape, output);
}
TF_RETURN_IF_ERROR(
ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
using Reducer = Eigen::internal::SumReducer<T>;
using Index = typename TTypes<T>::Tensor::Index;
// Reduce along the last axis (i.e axis 1) of the rank-2 Tensor.
const int64 output_size = reshape[kBroadcasting] * reshape[kBatch] *
reshape[kFree] * reshape[kContract];
functor::ReduceFunctor<Device, Reducer>::Reduce(
ctx, output->shaped<T, 1>({output_size}),
const_cast<const Tensor&>(input_deduped)
.shaped<T, 2>({output_size, reshape[kReduce]}),
Eigen::array<Index, 1>({1}), Reducer());
return Status::OK();
}
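
To illustrate the reduce phase, a NumPy sketch for the hypothetical equation
'ijl,lk->ik', where 'j' is a reduce label, 'l' is the contract label, and 'i', 'k'
are free:

    import numpy as np
    x = np.random.randn(2, 3, 4)             # subscripts i, j, l
    y = np.random.randn(4, 5)                # subscripts l, k
    x_reduced = x.sum(axis=1)                # sum out 'j' -> shape [free=2, contract=4]
    out = x_reduced @ y                      # contraction over 'l' -> shape (2, 5)
    np.testing.assert_allclose(out, np.einsum('ijl,lk->ik', x, y))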
// Reshapes a Tensor of shape [b0,b1...bk,N,M] to [prod(b0,b1...bk),N,M].
Status ReshapeToRank3(const Tensor& input, int batch_size, Tensor* output) {
const int rank = input.dims();
TensorShape output_shape = {batch_size, input.dim_size(rank - 2),
input.dim_size(rank - 1)};
return CopyFrom(input, output_shape, output);
}
// Conjugates the input.
template <typename Device, typename T>
Status Conjugate(OpKernelContext* ctx, Tensor* input) {
std::vector<int> permutation(input->dims());
std::iota(permutation.begin(), permutation.end(), 0);
Tensor output;
TF_RETURN_IF_ERROR(
ctx->allocate_temp(DataTypeToEnum<T>::value, input->shape(), &output));
const Device& d = ctx->eigen_device<Device>();
TF_RETURN_IF_ERROR(DoConjugateTranspose(d, *input, permutation, &output));
std::swap(*input, output);
return Status::OK();
}
// Contracts the inputs along the last axis (or the second last if the
// corresponding value of swap_free_and_contract is true). The batch dimensions
// are broadcast to the output shape.
// TODO(anudhyan): Factor this function into a BatchMatMul functor and support
// transpose_x and transpose_y attributes (in addition to adj_x and adj_y).
// Also, the BatchMatMul might devolve into a component-wise multiplication when
// the matrix shape is [1,1]; in this case the BatchMatMul functor would be very
// inefficient. The functor should detect this case and perform componentwise
// multiplication instead.
template <typename Device, typename T>
Status ContractOperands(OpKernelContext* ctx, absl::Span<const Tensor> inputs,
absl::Span<const bool> swap_free_and_contract,
Tensor* output) {
if (inputs.size() == 1) return CopyFrom(inputs[0], inputs[0].shape(), output);
MatMulBCast bcast(inputs[0].shape().dim_sizes(),
inputs[1].shape().dim_sizes());
if (!bcast.IsValid()) {
return errors::InvalidArgument(
"Invalid broadcasting dimensions: ", inputs[0].shape().DebugString(),
" vs. ", inputs[1].shape().DebugString());
}
Tensor lhs;
TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[0], bcast.x_batch_size(), &lhs));
Tensor rhs;
TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[1], bcast.y_batch_size(), &rhs));
TensorShape output_shape = bcast.output_batch_shape();
for (int i = 0; i < inputs.size(); ++i) {
const int64 free_axis =
inputs[i].dims() - (swap_free_and_contract[i] ? 1 : 2);
output_shape.AddDim(inputs[i].dim_size(free_axis));
}
bool adj_x = swap_free_and_contract[0];
bool adj_y = !swap_free_and_contract[1];
if (is_complex<T>::value) {
if (adj_x) TF_RETURN_IF_ERROR(Conjugate<Device, T>(ctx, &lhs));
if (adj_y) TF_RETURN_IF_ERROR(Conjugate<Device, T>(ctx, &rhs));
}
TF_RETURN_IF_ERROR(
ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
Tensor output_reshaped;
TF_RETURN_IF_ERROR(
ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped));
LaunchBatchMatMul<Device, T>::Launch(ctx, lhs, rhs, adj_x, adj_y, bcast,
&output_reshaped);
return Status::OK();
}
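
The adj_y choice above avoids an explicit transpose when both reduced operands end
in the contract dimension; a NumPy sketch of the equivalence:

    import numpy as np
    lhs = np.random.randn(7, 2, 4)           # [batch, F0, C]
    rhs = np.random.randn(7, 5, 4)           # [batch, F1, C]
    out = np.matmul(lhs, np.swapaxes(rhs, -1, -2))   # matmul with rhs "adjointed"
    np.testing.assert_allclose(out, np.einsum('bfc,bgc->bfg', lhs, rhs))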
} // namespace
template <typename Device, typename T>
class EinsumOp : public OpKernel {
public:
explicit EinsumOp(OpKernelConstruction* c) : OpKernel(c) {
string equation;
OP_REQUIRES_OK(c, c->GetAttr("equation", &equation));
OP_REQUIRES_OK(c, ParseEquation(equation, &input_labels_, &output_labels_,
&label_types_, &input_label_counts_,
&output_label_counts_, &input_has_ellipsis_,
&output_has_ellipsis_));
}
void Compute(OpKernelContext* ctx) override {
OpInputList inputs;
OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &inputs));
OperandLabels input_labels(input_labels_);
Labels output_labels(output_labels_);
std::vector<DimensionType> label_types(label_types_);
OperandLabelCounts input_label_counts(input_label_counts_);
LabelCounts output_label_counts(output_label_counts_);
LabelToDimSizes label_to_dim_sizes;
OP_REQUIRES_OK(ctx, ProcessDimensions(
inputs, input_has_ellipsis_, output_has_ellipsis_,
&input_labels, &output_labels, &label_types,
&input_label_counts, &output_label_counts,
&label_to_dim_sizes));
// The reduction phase (a) sums across reduction dimensions, (b) takes
// generalized diagonals, and (c) reshapes it into shape
// [(broadcasting) batch shape] + [F,C]
// where F and C denote the total (compacted) size of free and contract
// dimensions, respectively.
const int num_inputs = inputs.size();
OperandLabels free_labels(num_inputs);
gtl::InlinedVector<Tensor, 2> inputs_reduced(num_inputs);
gtl::InlinedVector<bool, 2> swap_free_and_contract(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
OP_REQUIRES_OK(ctx,
ReduceOperand<Device, T>(
ctx, inputs[i], label_types, input_label_counts[i],
&input_labels[i], &free_labels[i],
&swap_free_and_contract[i], &inputs_reduced[i]));
}
// After reduction, the inputs should be reshaped to Tensors suitable for
// contraction. If num_inputs is 1, the reduced input is simply forwarded to
// the output.
Tensor contraction_output_reshaped;
OP_REQUIRES_OK(ctx, ContractOperands<Device, T>(
ctx, inputs_reduced, swap_free_and_contract,
&contraction_output_reshaped));
// Copy the batch labels from the contraction output. Recover the batch
// shape, which may have been broadcasted.
TensorShape result_shape = contraction_output_reshaped.shape();
result_shape.RemoveLastDims(2);
int num_labels = label_types.size();
Labels result_labels;
// All batch dimensions should be present in the contracted result. First
// the broadcasting dimensions, then the named batch dimensions.
for (int label = 0; label < num_labels; ++label) {
if (label_types[label] == kBroadcasting) result_labels.push_back(label);
}
for (int label = 0; label < num_labels; ++label) {
if (label_types[label] == kBatch) result_labels.push_back(label);
}
for (int i = 0; i < num_inputs; ++i) {
for (int label : free_labels[i]) {
result_labels.push_back(label);
result_shape.AddDim(label_to_dim_sizes[label]);
}
}
// Reshape the contraction (or reduction) result to its expanded shape:
// [(broadcasted) batch shape] + [free shape 0] + [free shape 1].
Tensor contraction_output;
OP_REQUIRES_OK(ctx, CopyFrom(contraction_output_reshaped, result_shape,
&contraction_output));
// Inflate the output if necessary. (E.g. for the equation 'i->iii' which
// may arise while computing gradient of a regular Einsum).
// TODO(anudhyan): It's possible that Eigen's contract and inflate can be
// chained here to avoid materializing an intermediate.
Tensor output_inflated;
OP_REQUIRES_OK(
ctx, StrideOrInflate<Device, T>(
ctx, contraction_output, result_labels, output_label_counts,
true /* should_inflate */, &output_inflated));
if (output_inflated.dims() > contraction_output.dims()) {
// We inflated the output. Modify result labels accordingly.
Labels inflated_labels;
for (int label : result_labels) {
inflated_labels.insert(inflated_labels.end(),
output_label_counts[label], label);
}
result_labels.swap(inflated_labels);
}
// Find the permutation to map the result labels to the output labels. Note
// that both the result and the final output may have repeated labels, in which
// case the permutation preserves the left-to-right ordering.
// E.g. if result labels are [0, 0, 1] and output labels are [0, 1, 0] then the
// permutation should be [0, 2, 1]. We also use the fact that repeated
// labels in the result are adjacent to each other.
std::vector<int> output_permutation(output_labels.size());
std::vector<int> label_to_position(num_labels, -1);
for (int i = 0; i < result_labels.size(); ++i) {
// Remember the position of only the leftmost result label.
if (label_to_position[result_labels[i]] == -1) {
label_to_position[result_labels[i]] = i;
}
}
for (int i = 0; i < output_labels.size(); ++i) {
output_permutation[i] = label_to_position[output_labels[i]];
// We have found the leftmost occurrence. The next one would be adjacent.
label_to_position[output_labels[i]] += 1;
}
Tensor output;
OP_REQUIRES_OK(ctx, TransposeOperand<Device, T>(
ctx, output_inflated, output_permutation, &output));
ctx->set_output(0, output);
}
private:
OperandLabels input_labels_;
Labels output_labels_;
std::vector<DimensionType> label_types_;
OperandLabelCounts input_label_counts_;
LabelCounts output_label_counts_;
gtl::InlinedVector<bool, 2> input_has_ellipsis_;
bool output_has_ellipsis_ = false;
};
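
The output permutation logic in Compute (record the leftmost position of each label,
then advance it on reuse) can be sketched in Python as:

    result_labels = [0, 0, 1]
    output_labels = [0, 1, 0]
    label_to_position = {}
    for i, label in enumerate(result_labels):
        label_to_position.setdefault(label, i)   # remember leftmost occurrence only
    permutation = []
    for label in output_labels:
        permutation.append(label_to_position[label])
        label_to_position[label] += 1            # repeated labels are adjacent
    print(permutation)                           # [0, 2, 1]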
#define REGISTER_EINSUM(D, TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("Einsum").Device(DEVICE_##D).TypeConstraint<TYPE>("T"), \
EinsumOp<D##Device, TYPE>);
// TODO(anudhyan): Also register GPU kernels for Einsum.
#define REGISTER_CPU(TYPE) REGISTER_EINSUM(CPU, TYPE)
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
#undef REGISTER_CPU
#undef REGISTER_EINSUM
} // namespace tensorflow


@@ -474,6 +474,14 @@ REGISTER_OP("TridiagonalSolve")
.Attr("T: {double, float, complex64, complex128}")
.SetShapeFn(TridiagonalSolveShapeFn);
REGISTER_OP("Einsum")
.Input("inputs: N * T")
.Output("output: T")
.Attr("equation: string")
.Attr("N: int >= 1")
.Attr("T: type")
.SetShapeFn(shape_inference::EinsumShape);
// Deprecated op registrations:
// Can be deleted after 3feb2017.


@@ -0,0 +1,47 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/einsum_op_util.h"
#include <string>
#include "absl/strings/str_split.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
namespace tensorflow {
Status ParseEinsumEquation(const string& equation,
gtl::InlinedVector<string, 2>* input_subscripts,
string* output_subscript) {
gtl::InlinedVector<string, 2> inputs_and_output_subscripts =
absl::StrSplit(equation, "->");
if (inputs_and_output_subscripts.size() != 2) {
return errors::InvalidArgument(
"Expecting exactly one '->' in einsum equation: ", equation);
}
*output_subscript = std::move(inputs_and_output_subscripts[1]);
*input_subscripts =
absl::StrSplit(std::move(inputs_and_output_subscripts[0]), ',');
if (input_subscripts->size() != 1 && input_subscripts->size() != 2) {
return errors::InvalidArgument(
"Expecting 1 or 2 input subscripts in equation '", equation,
"' but got: ", input_subscripts->size());
}
return Status::OK();
}
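
A rough Python counterpart of this parser (the helper name is illustrative):

    def parse_einsum_equation(equation):
        parts = equation.split('->')
        if len(parts) != 2:
            raise ValueError(
                "Expecting exactly one '->' in einsum equation: " + equation)
        input_subscripts = parts[0].split(',')
        if len(input_subscripts) not in (1, 2):
            raise ValueError('Expecting 1 or 2 input subscripts but got: %d'
                             % len(input_subscripts))
        return input_subscripts, parts[1]

    print(parse_einsum_equation('ab,bc->ac'))    # (['ab', 'bc'], 'ac')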
} // namespace tensorflow


@@ -0,0 +1,29 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_
#define TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
namespace tensorflow {
// Parses and validates an einsum equation in explicit form.
Status ParseEinsumEquation(const string& equation,
gtl::InlinedVector<string, 2>* input_subscripts,
string* output_subscript);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_


@@ -2127,6 +2127,23 @@ cuda_py_test(
xla_enable_strict_auto_jit = True,
)
cuda_py_test(
name = "einsum_op_test",
size = "small",
srcs = ["einsum_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python/ops/linalg",
],
xla_enable_strict_auto_jit = True,
)
cuda_py_test(
name = "manip_ops_test",
size = "small",


@@ -0,0 +1,251 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.Einsum."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
class EinsumOpTest(test.TestCase):
def _check(self, s, *input_shapes, **kwargs):
dtype = kwargs.pop('dtype', np.float32)
r = np.random.RandomState(0)
inputs = []
for shape in input_shapes:
arr = np.array(r.randn(*shape)).astype(dtype)
if dtype == np.complex64 or dtype == np.complex128:
arr += 1j * np.array(r.randn(*shape)).astype(dtype)
inputs.append(arr)
input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
a = np.einsum(s, *inputs)
b = self.evaluate(gen_linalg_ops.einsum(input_tensors, s))
self.assertAllClose(a, b, atol=1e-4, rtol=1e-4)
def testUnary(self):
self._check('->', ())
self._check('aa->', (3, 3))
self._check('aa->a', (3, 3))
self._check('aaa->', (3, 3, 3))
self._check('aaa->a', (3, 3, 3))
self._check('aab->a', (3, 3, 4))
self._check('ab->', (3, 3))
self._check('ab->ab', (3, 3))
self._check('abc->b', (3, 4, 5))
self._check('abc->ca', (3, 4, 5))
self._check('abc->cab', (3, 4, 5))
self._check('aabcc->a', (3, 3, 5, 4, 4))
self._check('aabcc->ac', (3, 3, 5, 4, 4))
self._check('aabcd->ad', (3, 3, 5, 4, 4))
def testUnaryEllipsis(self):
self._check('...->...', ())
self._check('...->', ())
self._check('->...', ())
# Tests from dask
self._check('a...a->a...', (2, 2))
self._check('a...a->', (2, 2))
self._check('a...a->...', (2, 5, 1, 2))
self._check('a...a->a...', (2, 1, 2))
self._check('a...a->a...', (2, 3, 4, 5, 2))
self._check('...ijk->...ki', (3, 4, 5))
self._check('...ijk->...ki', (1, 3, 4, 5))
self._check('...ijk->...ki', (2, 2, 3, 4, 5))
# Repeated indices.
self._check('i...ii->...i', (3, 2, 3, 3))
def testBinary(self):
self._check(',->', (), ())
self._check('a,a->', (3,), (3,))
self._check('a,a->a', (3,), (3,))
self._check('ba,b->', (3, 2), (3,))
self._check('ab,b->a', (3, 4), (4,))
self._check('ab,ab->', (3, 4), (3, 4))
self._check('nij,jk->nik', (5, 2, 3), (3, 4))
self._check('abc,bad->abcd', (1, 2, 3), (2, 1, 4))
# Repeated indices.
self._check('ijj,k->ik', (2, 3, 3), (4,))
self._check('aba,a->b', (3, 4, 3), (3,))
# From https://github.com/dask/dask/pull/3412#discussion_r182413444
self._check('aab,bc->ac', (2, 2, 3), (3, 4))
self._check('aab,bcc->ac', (2, 2, 3), (3, 4, 4))
# Based on https://github.com/google/jax/issues/37#issuecomment-448572187
self._check('sa,shb->shab', (2, 1), (2, 3, 4))
def testBroadcasting(self):
# Batch matmul without broadcasting.
self._check('...ij,...jk->...ik', (5, 1, 2, 3), (5, 1, 3, 4))
# Batch matmul with broadcasting.
self._check('...ij,...jk->...ik', (1, 2, 3), (3, 5))
self._check('...ij,...jk->...ik', (2, 3), (1, 3, 5))
self._check('...ij,...jk->...ik', (5, 2, 3), (3, 5))
self._check('...ij,...jk->...ik', (2, 3), (5, 3, 5))
self._check('...ij,...jk->...ik', (3, 1, 2, 3), (1, 1, 7, 3, 5))
self._check('i...j,j...k->...ik', (2, 1, 3, 1, 3), (3, 1, 7, 5))
# Broadcasting with repeated indices.
self._check('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
self._check('ij,jk...k->...i', (3, 2), (2, 4, 5, 4))
self._check('ijj,jk...k->i...', (3, 2, 2), (2, 4, 1, 4))
self._check('i...jj,jk...k->i...', (3, 3, 1, 2, 2), (2, 4, 1, 5, 4))
# Following 2 from # https://stackoverflow.com/a/19203475/1611416
self._check('...abc,...abcd->...d', (1, 1, 2, 3, 4), (5, 2, 3, 4, 6))
self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))
def testDtypes(self):
for dtype in [np.float64, np.float32, np.complex64, np.complex128]:
self._check('ij,jk->ik', (2, 2), (2, 2), dtype=dtype)
self._check('ji,jk->ik', (2, 2), (2, 2), dtype=dtype)
self._check('ji,kj->ik', (2, 2), (2, 2), dtype=dtype)
self._check('ij,jk->ki', (2, 2), (2, 2), dtype=dtype)
self._check('ji,kj->ki', (2, 2), (2, 2), dtype=dtype)
@test_util.run_in_graph_and_eager_modes
def testInvalid(self):
r = np.random.RandomState(0)
cases = [
# incorrect rank.
('ij,jk->ik', r.randn(1, 2, 3), r.randn(3, 4)),
('...ij,jk->ik', r.randn(3), r.randn(3, 4)),
# inconsistent dimensions.
('ij,jk->ik', r.randn(2, 3), r.randn(4, 4)),
# broadcasting is invalid
('...ij,...jk->...ik', r.randn(5, 2, 3), r.randn(7, 3, 4)),
# output should have ellipsis when broadcasting shape is
# non-empty.
('...ij,...jk->ik', r.randn(2, 2, 3), r.randn(3, 4)),
]
for args in cases:
with self.assertRaises((ValueError, errors.InvalidArgumentError)):
_ = self.evaluate(gen_linalg_ops.einsum(args[1:], args[0]))
placeholders = [
array_ops.placeholder_with_default(x, shape=None) for x in args[1:]
]
with self.assertRaises((ValueError, errors.InvalidArgumentError)):
_ = self.evaluate(gen_linalg_ops.einsum(placeholders, args[0]))
@test_util.run_in_graph_and_eager_modes
def testPlaceholder(self):
def check(equation, *input_and_placeholder_shapes):
r = np.random.RandomState(0)
inputs = []
input_placeholders = []
for actual_shape, placeholder_shape in input_and_placeholder_shapes:
input_np = np.array(r.randn(*actual_shape))
inputs.append(input_np)
input_placeholders.append(
array_ops.placeholder_with_default(input_np, placeholder_shape))
a = np.einsum(equation, *inputs)
b = self.evaluate(gen_linalg_ops.einsum(input_placeholders, equation))
self.assertAllClose(a, b, atol=1e-4, rtol=1e-4)
check('bijl,bjkm->bik', ((9, 2, 3, 5), (None, None, None, 5)),
((9, 3, 4, 7), (None, None, 4, None)))
check('bijl,bjkm->bik', ((9, 2, 3, 5), None), ((9, 3, 4, 7), None))
check('...ij,...->...i', ((4, 3, 1, 2), (None, 3, None, 2)),
((4, 3), (None, 3)))
check('...ij,...jk->...ik', ((3, 1, 2, 3), None), ((1, 7, 3, 4), None))
def testOutputRepeatedLabels(self):
# This is the reverse operation of repeated input labels, to be used for
# computing symbolic gradients of einsum.
r = np.random.RandomState(0)
a = r.randn(2, 2)
s = 'a->aa'
diag_a = np.diag(np.diag(a))
b = self.evaluate(gen_linalg_ops.einsum([np.diag(a)], s))
self.assertAllClose(diag_a, b, atol=1e-4, rtol=1e-4)
class EinsumBenchmark(test.Benchmark):
cases = [
# Unary cases.
['ijk->i', 100],
['ijk->kji', 100],
# Regular matmul or batch matmul.
['ij,jk->ik', 1000],
['ji,kj->ik', 1000],
['ab,ab->', 100],
['ab,ba->', 100],
['abc,abc->', 100],
['abc,bac->', 100],
['abc,cba->', 100],
['bij,bjk->bik', 100],
['bji,bjk->bki', 100],
['ikl,kji->kl', 100],
['klj,lki->ij', 100],
['ijk,ilj->kli', 100],
['kij,mkb->ijmb', 100],
['abcd,ad->bc', 40],
# Larger binary contractions.
['ijk,jklm->il', 40],
['efabc,eabcd->efd', 30],
['fabec,abcde->fde', 30],
['efabc,edabc->efd', 30],
['eadbf,dfebc->ecfad', 30],
['abcdef,bcdfg->abcdeg', 30],
]
def benchmarkEinsum(self):
for equation, dim in self.cases:
with ops.Graph().as_default(), \
session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device('/cpu:0'):
r = np.random.RandomState(0)
input_subscripts = equation.split('->')[0].split(',')
input_vars = []
for subscript in input_subscripts:
input_shape = (dim,) * len(subscript)
input_vars.append(
variables.Variable(np.array(r.randn(*input_shape), np.float32)))
variables.global_variables_initializer().run()
# Call einsum_v1.
self.run_op_benchmark(
sess,
special_math_ops.einsum(equation, *input_vars),
min_iters=50,
name='einsum_v1_cpu_({})_{}'.format(equation, dim))
# Call gen_linalg_ops.einsum.
self.run_op_benchmark(
sess,
gen_linalg_ops.einsum(input_vars, equation),
min_iters=50,
name='einsum_v2_cpu_({})_{}'.format(equation, dim))
if __name__ == '__main__':
test.main()


@@ -988,6 +988,10 @@ tf_module {
name: "EditDistance"
argspec: "args=[\'hypothesis_indices\', \'hypothesis_values\', \'hypothesis_shape\', \'truth_indices\', \'truth_values\', \'truth_shape\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "Einsum"
argspec: "args=[\'inputs\', \'equation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "Elu"
argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "


@@ -988,6 +988,10 @@ tf_module {
name: "EditDistance"
argspec: "args=[\'hypothesis_indices\', \'hypothesis_values\', \'hypothesis_shape\', \'truth_indices\', \'truth_values\', \'truth_shape\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "Einsum"
argspec: "args=[\'inputs\', \'equation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "Elu"
argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "