diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 72407213f97..cd6163415c3 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -962,6 +962,7 @@ tf_cuda_library( "util/guarded_philox_random.h", "util/mirror_pad_mode.h", "util/padding.h", + "util/einsum_op_util.h", "util/port.h", "util/ptr_util.h", "util/reffed_status_callback.h", diff --git a/tensorflow/core/api_def/base_api/api_def_Einsum.pbtxt b/tensorflow/core/api_def/base_api/api_def_Einsum.pbtxt new file mode 100644 index 00000000000..f84fd23e5e2 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_Einsum.pbtxt @@ -0,0 +1,100 @@ +op { + graph_op_name: "Einsum" + in_arg { + name: "inputs" + description: <<END +List of 1 or 2 Tensors. +END + } + out_arg { + name: "output" + description: <<END +Output Tensor with shape depending upon `equation`. +END + } + attr { + name: "equation" + description: <<END +String describing the Einstein Summation operation; in the format of np.einsum. +END + } + summary: "Tensor contraction according to Einstein summation convention." + description: <<END +Implements generalized Tensor contraction and reduction. Each input Tensor must +have a corresponding input subscript appearing in the comma-separated left-hand +side of the equation. The right-hand side of the equation consists of the +output subscript. The input subscripts and the output subscript should consist +of zero or more named axis labels and at most one ellipsis (`...`). + +The named axis labels may be any single character other than those having +special meaning, namely `,.->`. The behavior of this Op is undefined if it +receives an ill-formatted equation; since the validation is done at +graph-building time, we omit format validation checks at runtime. + +Note: This Op is *not* intended to be called by the user; instead users should +call `tf.einsum` directly. It is a hidden Op used by `tf.einsum`. + +Operations are applied to the input(s) according to the following rules: + + (a) Generalized Diagonals: For input dimensions corresponding to axis labels + appearing more than once in the same input subscript, we take the + generalized (`k`-dimensional) diagonal. + For example, in the equation `iii->i` with input shape `[3, 3, 3]`, the + generalized diagonal would consist of `3` elements at indices `(0, 0, 0)`, + `(1, 1, 1)` and `(2, 2, 2)` to create a Tensor of shape `[3]`. + + (b) Reduction: Axes corresponding to labels appearing only in one input + subscript but not in the output subscript are summed over prior to Tensor + contraction. + For example, in the equation `ab,bc->b`, the axis labels `a` and `c` are + the reduction axis labels. + + (c) Batch Dimensions: Axes corresponding to labels appearing in each of the + input subscripts and also in the output subscript make up the batch + dimensions in Tensor contraction. Unnamed axis labels corresponding to + ellipsis (`...`) also correspond to batch dimensions. + For example, for the equation denoting batch matrix multiplication, + `bij,bjk->bik`, the axis label `b` corresponds to a batch dimension. + + (d) Contraction: In case of binary einsum, axes corresponding to labels + appearing in two different inputs (and not in the output) are contracted + against each other. + Considering the batch matrix multiplication equation again + (`bij,bjk->bik`), the contracted axis label is `j`. + + (e) Expand Diagonal: If the output subcripts contain repeated (explicit) axis + labels, the opposite operation of (a) is applied. 
For example, in the + equation `i->iii`, and input shape `[3]`, the output of shape `[3, 3, 3]` + are all zeros, except for the (generalized) diagonal which is populated + with values from the input. + Note: This operation is not supported by `np.einsum` or `tf.einsum`; it is + provided to enable computing the symbolic gradient of `tf.einsum`. + +The output subcripts must contain only labels appearing in at least one of the +input subscripts. Furthermore, all dimensions mapping to the same axis label +must be equal. + +Any of the input and output subscripts may contain at most a single ellipsis +(`...`). These ellipsis are mapped against dimensions not corresponding to any +named axis label. If two inputs contain ellipsis, then they are broadcasted +according to standard NumPy broadcasting +[rules](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html). + +The broadcasted dimensions are placed in the corresponding location of the +ellipsis in the output subscript. If the broadcasted dimensions are non-empty +and the output subcripts do not contain ellipsis, then an InvalidArgument error +is raised. + +@compatibility(numpy) +Similar to [`numpy.einsum`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.einsum.html). + +Comparison with `numpy.einsum`: + + * This Op only supports unary and binary forms of `numpy.einsum`. + * This Op does not support implicit form. (i.e. equations without `->`). + * This Op also supports repeated indices in the output subscript, which is not + supported by `numpy.einsum`. +@end_compatibility + +END +} diff --git a/tensorflow/core/api_def/python_api/api_def_Einsum.pbtxt b/tensorflow/core/api_def/python_api/api_def_Einsum.pbtxt new file mode 100644 index 00000000000..5178c80df4b --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_Einsum.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "Einsum" + visibility: HIDDEN +} diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 8691716ef6a..fb98a880aeb 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -12,12 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/framework/common_shape_fns.h" - #include <unordered_set> +#include "absl/container/flat_hash_map.h" +#include "absl/strings/match.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/util/einsum_op_util.h" namespace tensorflow { @@ -236,6 +242,178 @@ Status MatMulShape(shape_inference::InferenceContext* c) { return Status::OK(); } +namespace { + +// Validate that an Einsum subscript contains exactly one or zero ellipsis; and +// that periods (.) occur only within an ellipses (...). 
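+// For example, "ab...c" and "abc" pass this validation, whereas "a..b" (two
+// periods) and "a...b..." (two ellipses) are rejected.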
+Status ValidateEinsumEllipsis(absl::string_view subscript, + bool* found_ellipsis) { + const int num_periods = absl::c_count(subscript, '.'); + if (num_periods != 0 && num_periods != 3) { + return errors::InvalidArgument( + "Expected at most one ellipsis (...), but found ", num_periods, + " periods (.) in the input subscript: ", subscript); + } + if (num_periods == 3 && !absl::StrContains(subscript, "...")) { + return errors::InvalidArgument( + "Periods found outside of ellipsis in subscript: ", subscript); + } + *found_ellipsis = num_periods > 0; + return Status::OK(); +} + +} // namespace + +Status EinsumShape(shape_inference::InferenceContext* c) { + // We assume that the equation has a valid format. Either (x),(y)->(z) + // or (x)->(z), where each of (x), (y) and (z) are concatenation of zero or + // more latin alphabets and contains at most one ellipsis ('...'). + string equation; + TF_RETURN_IF_ERROR(c->GetAttr("equation", &equation)); + gtl::InlinedVector<string, 2> input_labels; + string output_labels; + TF_RETURN_IF_ERROR( + ParseEinsumEquation(equation, &input_labels, &output_labels)); + + if (c->num_inputs() == 0 || c->num_inputs() > 2) { + return errors::InvalidArgument("Expected either 1 or 2 inputs but got: ", + c->num_inputs()); + } + if (c->num_inputs() != input_labels.size()) { + return errors::InvalidArgument("Expected ", input_labels.size(), + " inputs for equation ", equation, + " but got: ", c->num_inputs()); + } + + // Validate input subscripts, build the label to dimension mapping and obtain + // the broadcast shapes that map to ellipsis. + absl::flat_hash_map<char, DimensionHandle> label_to_dimension; + gtl::InlinedVector<ShapeHandle, 2> input_bcast_shapes(c->num_inputs()); + for (int i = 0; i < c->num_inputs(); ++i) { + bool has_ellipsis = false; + TF_RETURN_IF_ERROR(ValidateEinsumEllipsis(input_labels[i], &has_ellipsis)); + ShapeHandle input_shape = c->input(i); + // Validate that the input rank is sufficient for the given number of named + // labels. + if (c->RankKnown(input_shape)) { + if (has_ellipsis) { + const int num_named_labels = + static_cast<int>(input_labels[i].size()) - 3; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + c->WithRankAtLeast(input_shape, num_named_labels, &input_shape), + " for ", i, "th input and equation: ", equation); + } else { + const int num_named_labels = static_cast<int>(input_labels[i].size()); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + c->WithRank(input_shape, num_named_labels, &input_shape), " for ", + i, "th input and equation: ", equation); + } + } + + bool seen_ellipsis = false; + input_bcast_shapes[i] = c->Scalar(); + // Run through the input labels; populate label_to_dimension mapping and + // compute the broadcast shapes corresponding to the ellipsis (if present). + for (int label_idx = 0; label_idx < input_labels[i].size(); ++label_idx) { + const char label = input_labels[i][label_idx]; + // Calculate the input axis that the current label is referring to. After + // the ellipsis, the axis may be found by using negative indices; i.e the + // (rank - k)th dimension corresponds to the (num_labels - k)th label. + const int64 axis_before_ellipsis = label_idx; + const int64 axis_after_ellipsis = + c->RankKnown(input_shape) + ? label_idx + c->Rank(input_shape) - input_labels[i].size() + : -1; + + // Populate the input broadcast shape when we encounter an ellipsis (...). 
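+      // For example, for the subscript "a...bc" and a rank-5 input, the
+      // ellipsis maps to dimensions 1 and 2 of that input.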
+ if (label == '.') { + if (!c->RankKnown(input_shape)) { + input_bcast_shapes[i] = c->UnknownShape(); + } else { + // The broadcast shape runs till the named label right after the + // ellipsis, the label with index (label_idx + 3). + TF_RETURN_IF_ERROR(c->Subshape(input_shape, axis_before_ellipsis, + axis_after_ellipsis + 3, + &input_bcast_shapes[i])); + } + label_idx += 2; // Skip the rest of the ellipsis. + seen_ellipsis = true; + continue; + } + // Obtain the dimension that the current label corresponds to. + int64 axis = seen_ellipsis ? axis_after_ellipsis : axis_before_ellipsis; + DimensionHandle new_dim = c->RankKnown(input_shape) + ? c->Dim(input_shape, axis) + : c->UnknownDim(); + // If we've seen this label before, make sure previous and current + // dimensions are compatible. + if (label_to_dimension.contains(label)) { + DimensionHandle merged; + TF_RETURN_IF_ERROR( + c->Merge(label_to_dimension[label], new_dim, &merged)); + label_to_dimension[label] = merged; + } else { + label_to_dimension[label] = new_dim; + } + } + } + + // For two inputs, broadcast the two input broadcast shapes to create the + // output broadcast shape. For one input, just copy the single broadcast + // shape. + ShapeHandle output_bcast_shape; + if (input_bcast_shapes.size() == 1) { + output_bcast_shape = input_bcast_shapes[0]; + } else if (input_bcast_shapes.size() == 2) { + TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( + c, input_bcast_shapes[0], input_bcast_shapes[1], &output_bcast_shape)); + } + + bool output_has_ellipsis = false; + TF_RETURN_IF_ERROR( + ValidateEinsumEllipsis(output_labels, &output_has_ellipsis)); + if (output_has_ellipsis) { + // If the output subscript has ellipsis and the output broadcast rank is + // unknown, then the output shape should have unknown rank. + if (!c->RankKnown(output_bcast_shape)) { + c->set_output(0, c->UnknownShape()); + return Status::OK(); + } + } else { + // If the output subscripts don't have ellipsis then make sure the output + // broadcasting shape is empty. + TF_RETURN_WITH_CONTEXT_IF_ERROR( + c->WithRankAtMost(output_bcast_shape, 0, &output_bcast_shape), + " for einsum equation '", equation, + "' without ellipsis (...) in the output subscripts where input(s) have " + "non-empty broadcasting shape"); + output_bcast_shape = c->Scalar(); + } + + // Create the output shape from output labels and label_to_dimension mapping. + std::vector<DimensionHandle> output_dims; + for (int label_idx = 0; label_idx < output_labels.size(); ++label_idx) { + const char label = output_labels[label_idx]; + // Append the output_bcast_shape when the ellipsis is encountered. + if (label == '.') { + for (int k = 0; k < c->Rank(output_bcast_shape); ++k) { + output_dims.push_back(c->Dim(output_bcast_shape, k)); + } + label_idx += 2; // Skip the rest of the ellipsis. 
+ continue; + } + auto dimension_it = label_to_dimension.find(label); + if (dimension_it == label_to_dimension.end()) { + return errors::InvalidArgument( + "Einsum output subscripts for equation '", equation, "' has label '", + label, "' which is not present in the input subscripts"); + } + output_dims.push_back(dimension_it->second); + } + c->set_output(0, c->MakeShape(output_dims)); + return Status::OK(); +} + Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) { ShapeHandle a_shape; ShapeHandle b_shape; diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index 5050270a2f2..e11a8557f60 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -230,6 +230,9 @@ Status MatMulShape(shape_inference::InferenceContext* c); // batch dimensions. Status BatchMatMulV2Shape(shape_inference::InferenceContext* c); +// Shape function for Einsum. +Status EinsumShape(shape_inference::InferenceContext* c); + // Shape function for BiasAdd-like operations. Status BiasAddShape(shape_inference::InferenceContext* c); diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index f30fd2aea3a..40ca891f929 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -213,6 +213,129 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { } } +TEST(CommonShapeFnsTest, Einsum_ShapeFn) { + ShapeInferenceTestOp op("Einsum"); + auto set_equation = [&op](int n, string equation) { + std::vector<NodeDefBuilder::NodeOut> input_list; + input_list.reserve(n); + for (int i = 0; i < n; ++i) { + input_list.emplace_back("a", 0, DT_FLOAT); + } + TF_ASSERT_OK(NodeDefBuilder("test", "Einsum") + .Input(input_list) + .Attr("equation", equation) + .Finalize(&op.node_def)); + }; + + // Unary cases. + set_equation(1, "abc->c"); + INFER_OK(op, "[?,?,?]", "[d0_2]"); + set_equation(1, "abc->aabbcc"); + INFER_OK(op, "[?,?,?]", "[d0_0,d0_0,d0_1,d0_1,d0_2,d0_2]"); + set_equation(1, "abc->"); + INFER_OK(op, "[?,?,?]", "[]"); + set_equation(1, "->"); + INFER_OK(op, "[]", "[]"); + + // Binary cases. + set_equation(2, "ij,jk->ik"); + INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]"); + set_equation(2, "ij,jk->ik"); + INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]"); + set_equation(2, "ab,ab->"); + INFER_OK(op, "[?,?];[?,?]", "[]"); + set_equation(2, "ab,->b"); + INFER_OK(op, "[?,?];[]", "[d0_1]"); + set_equation(2, ",->"); + INFER_OK(op, "[];[]", "[]"); + set_equation(2, "aaa,b->abbb"); + INFER_OK(op, "[?,?,?];[?]", "[d0_0,d1_0,d1_0,d1_0]"); + set_equation(2, ",abcd->badc"); + INFER_OK(op, "[];[?,?,?,?]", "[d1_1,d1_0,d1_3,d1_2]"); + + // Ellipsis cases. + set_equation(1, "a...bc->c..."); + INFER_OK(op, "[?,?,?,?,?]", "[d0_4,d0_1,d0_2]"); + set_equation(2, "...ij,...jk->...ik"); + INFER_OK(op, "[?,?,?,?,?];[1,?,?]", "[d0_0,d0_1,d0_2,d0_3,d1_2]"); + INFER_OK(op, "[1,?,?];[?,?,?,?,?]", "[d1_0,d1_1,d1_2,d0_1,d1_4]"); + + // Unknown rank. 
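+  // Note that the output rank is still known whenever the output subscript has
+  // no ellipsis, even if the input rank is unknown.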
+ set_equation(1, "abc->c"); + INFER_OK(op, "?", "[?]"); + set_equation(1, "a...bc->c"); + INFER_OK(op, "?", "[?]"); + set_equation(1, "a...bc->c..."); + INFER_OK(op, "?", "?"); + + set_equation(2, "...ij,...jk->...ik"); + INFER_OK(op, "?;?", "?"); + INFER_OK(op, "[?,?,?];?", "?"); + INFER_OK(op, "?;[?,?,?]", "?"); + set_equation(2, "...ij,...jk->ik"); + INFER_OK(op, "?;?", "[?,?]"); + set_equation(2, "abd,b...c->...cad"); + INFER_OK(op, "[?,?,?];[?,?,?,?]", "[d1_1,d1_2,d1_3,d0_0,d0_2]"); + set_equation(2, "...ab,b...c->ac..."); + INFER_OK(op, "[?,1,?,?];[?,?,?]", "[d0_2,d1_2,d0_0,d1_1]"); + + // Wrong number of inputs. + set_equation(2, "ab->b"); + INFER_ERROR("got: 2", op, "[?,?];[?,?]"); + set_equation(1, "ab,a->b"); + INFER_ERROR("got: 1", op, "[?,?]"); + + // Invalid format. Implicit form is not supported. + set_equation(1, "a"); + INFER_ERROR("equation", op, "[2]"); + set_equation(2, "ab,bc"); + INFER_ERROR("equation", op, "[2,2];[2,2]"); + + // Wrong number of ellipsis or periods outside of ellipsis. + set_equation(1, "..a.->a..."); + INFER_ERROR("ellipsis", op, "[1,1,2,1]"); + set_equation(1, "...a->.a.."); + INFER_ERROR("ellipsis", op, "[1,1,1,2]"); + set_equation(1, "...a...->...a"); + INFER_ERROR("ellipsis", op, "[1,1,1,2]"); + set_equation(1, "..a..b..->...ab"); + INFER_ERROR("ellipsis", op, "[1,1,2,1]"); + set_equation(2, "...a...,ab->a"); + INFER_ERROR("ellipsis", op, "[1,2,1];[2,1]"); + set_equation(2, "a,...ab...->a"); + INFER_ERROR("ellipsis", op, "[2];[1,2,1,1]"); + set_equation(2, "a,ab->a......"); + INFER_ERROR("ellipsis", op, "[2];[2,1]"); + + // Output label doesn't appear in input. + set_equation(1, "abc->d"); + INFER_ERROR("'d'", op, "[?,?,?]"); + + // Mismatch in input rank. + set_equation(1, "abc->c"); + INFER_ERROR("4", op, "[?,?,?,?]"); + INFER_ERROR("2", op, "[?,?]"); + set_equation(1, "...abc->...c"); + INFER_ERROR("2", op, "[?,?]"); + + // Input dimensions are not consistent. + set_equation(2, "ab,ab->a"); + INFER_ERROR("are 1 and 2", op, "[1,2];[2,1]"); + set_equation(2, "aa,bb->a"); + INFER_ERROR("are 1 and 2", op, "[1,2];[2,2]"); + + // Invalid broadcasting dimensions. 
+ set_equation(2, "...ij,...jk->...ik"); + INFER_ERROR("are 2 and 3", op, "[2,?,?];[3,?,?]"); + set_equation(2, "i...j,jk...->...ik"); + INFER_ERROR("are 2 and 3", op, "[?,2,?];[?,?,3]"); + set_equation(2, "...ij,...jk->ik"); + set_equation(2, "i...j,jk...->ik"); + INFER_ERROR("non-empty broadcasting", op, "[?,2,?];[?,?]"); + set_equation(2, "...ab,b...c->ac..."); + INFER_OK(op, "?;[4,5,3]", "?"); +} + TEST(CommonShapeFnsTest, BatchMatMulV2_ShapeFn) { ShapeInferenceTestOp op("BatchMatMulV2"); auto set_adj = [&op](bool adj_x, bool adj_y) { diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 644e591100c..bd330539c7f 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3216,6 +3216,7 @@ cc_library( ":cholesky_grad", ":cholesky_op", ":determinant_op", + ":einsum_op", ":lu_op", ":matrix_exponential_op", ":matrix_inverse_op", @@ -3414,6 +3415,21 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "einsum_op", + prefix = "einsum_op", + deps = [ + ":batch_matmul_op", + ":reduction_ops", + ":transpose_functor", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//third_party/eigen3", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "linalg_ops_common", srcs = ["linalg_ops_common.cc"], diff --git a/tensorflow/core/kernels/einsum_op.cc b/tensorflow/core/kernels/einsum_op.cc new file mode 100644 index 00000000000..ae5733c19d6 --- /dev/null +++ b/tensorflow/core/kernels/einsum_op.cc @@ -0,0 +1,712 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_split.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/batch_matmul_op_impl.h" +#include "tensorflow/core/kernels/reduction_ops_common.h" +#include "tensorflow/core/kernels/transpose_functor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/einsum_op_util.h" + +namespace tensorflow { + +using CPUDevice = Eigen::ThreadPoolDevice; + +namespace { + +using ShapeVec = gtl::InlinedVector<int64, 8>; +using Labels = gtl::InlinedVector<int, 8>; +using OperandLabels = gtl::InlinedVector<Labels, 2>; +using LabelCounts = gtl::InlinedVector<int, 8>; +using OperandLabelCounts = gtl::InlinedVector<LabelCounts, 2>; +using LabelToDimSizes = gtl::InlinedVector<int64, 8>; + +// Dummy axis label used to denote an ellipsis in an input or output subscript. +constexpr int kEllipsisLabel = -1; + +// Each dimension is categorized into exactly one of five types based on whether +// its corresponding label is present in the input and/or the output subscripts. +enum DimensionType { + // Batch dimensions are those present in two inputs as well as the output. + // They are part of the batch dimensions during Tensor contraction. + // Such dimensions may be broadcasting dimensions (those mapping to ellipsis) + // or explicit batch dimensions corresponding to named axis labels. + kBroadcasting = 0, + kBatch = 1, + // Free dimensions are present in exactly one of the inputs, and also the + // output. These are non-contracted axes in the Tensor contraction. + kFree = 2, + // Contract dimensions are present in two inputs, but not the output. These + // dimensions are contracted in Tensor contraction. + kContract = 3, + // Reduce dimensions are present in exactly one input; and not in the output + // and are summed over prior to Tensor contraction. + kReduce = 4, +}; + +// Returns the DimensionType given whether the corresponding label is present in +// exactly one input subscript (is_unique) and whether it is absent from the +// output subscripts (is_removed). Does not handle broadcasting dimensions. +DimensionType GetDimensionType(bool is_removed, bool is_unique) { + if (!is_removed && !is_unique) + return kBatch; + else if (!is_removed && is_unique) + return kFree; + else if (is_removed && !is_unique) + return kContract; + else // is_removed && is_unique + return kReduce; +} + +// Maps the character labels to consecutive integers. +void MapToLabels(const string& subscript, Labels* labels, + absl::flat_hash_map<char, int>* label_mapping) { + for (int i = 0; i < subscript.size(); ++i) { + const char label_char = subscript[i]; + if (label_char == '.') { + labels->push_back(kEllipsisLabel); + i += 2; // Skip next 2 characters as well. 
+ continue; + } + if (!label_mapping->contains(label_char)) { + const int next_label = label_mapping->size(); + (*label_mapping)[label_char] = next_label; + } + const int mapped_label = (*label_mapping)[label_char]; + labels->push_back(mapped_label); + } +} + +// Parses and validates the equation and the input shapes. Single character +// labels are integerized and we populate input and output label subscripts and +// corresponding counts. Also create the mapping from (named) labels to their +// DimensionType. +Status ParseEquation(const string& equation, OperandLabels* input_labels, + Labels* output_labels, + std::vector<DimensionType>* label_types, + OperandLabelCounts* input_label_counts, + LabelCounts* output_label_counts, + gtl::InlinedVector<bool, 2>* input_has_ellipsis, + bool* output_has_ellipsis) { + gtl::InlinedVector<string, 2> input_str; + string output_str; + TF_RETURN_IF_ERROR(ParseEinsumEquation(equation, &input_str, &output_str)); + + // Temporary map from single character labels to (consecutive) integer labels. + absl::flat_hash_map<char, int> label_mapping; + int num_inputs = input_str.size(); + input_labels->resize(num_inputs); + + // Map from single characters to integer labels. + for (int i = 0; i < num_inputs; ++i) { + MapToLabels(input_str[i], &input_labels->at(i), &label_mapping); + } + MapToLabels(output_str, output_labels, &label_mapping); + + // Compute counts for input and output labels. + int num_labels = label_mapping.size(); + input_label_counts->resize(num_inputs); + input_has_ellipsis->resize(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + input_label_counts->at(i).resize(num_labels); + for (const int label : input_labels->at(i)) { + if (label != kEllipsisLabel) + input_label_counts->at(i)[label] += 1; + else + input_has_ellipsis->at(i) = true; + } + } + output_label_counts->resize(num_labels); + for (const int label : *output_labels) { + if (label != kEllipsisLabel) + output_label_counts->at(label) += 1; + else + *output_has_ellipsis = true; + } + + // Map each label to a unique DimensionType. + label_types->resize(num_labels); + for (int label = 0; label < num_labels; ++label) { + if (label == kEllipsisLabel) continue; + bool removed = (*output_label_counts)[label] == 0; + bool unique = num_inputs == 1 || (*input_label_counts)[0][label] == 0 || + (*input_label_counts)[1][label] == 0; + (*label_types)[label] = GetDimensionType(removed, unique); + } + return Status::OK(); +} + +// Insert new (unnamed) broadcasting labels at the location of ellipsis. +void InsertBroadcastLabels(int num_bcast_dims, int num_named_labels, + int ellipsis_axis, Labels* labels, + LabelCounts* label_counts) { + labels->erase(labels->begin() + ellipsis_axis); + labels->insert(labels->begin() + ellipsis_axis, num_bcast_dims, 0); + std::iota(labels->begin() + ellipsis_axis, + labels->begin() + ellipsis_axis + num_bcast_dims, num_named_labels); + // Increment label counts. Since these are new labels, the count is set to 1. + label_counts->resize(num_named_labels + num_bcast_dims, 1); +} + +// Record and validate the label to dimension mapping. Must be a named +// (non-broadcasting) label as broadcasting labels don't have a fixed dimension. +Status RecordLabelToDimension(const int label, const int axis, + const Tensor& input, + LabelToDimSizes* label_to_dim_sizes) { + const int64 input_dim = input.dim_size(axis); + // We know that label_to_dim_sizes has the size to accommodate named labels. 
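+  // A mismatch means the same label was mapped to two different sizes, e.g.
+  // 'ab,ab->a' with input shapes [1,2] and [2,1].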
+ if (label_to_dim_sizes->at(label) != 0 && + label_to_dim_sizes->at(label) != input_dim) { + return errors::InvalidArgument( + "Expected dimension ", label_to_dim_sizes->at(label), " at axis ", axis, + " of the input shaped ", input.shape().DebugString(), + " but got dimension ", input_dim); + } + (*label_to_dim_sizes)[label] = input_dim; + return Status::OK(); +} + +// Validate input dimensions and populate unnamed labels and their label counts. +Status ProcessDimensions(const OpInputList& inputs, + const gtl::InlinedVector<bool, 2>& input_has_ellipsis, + const bool output_has_ellipsis, + OperandLabels* input_labels, Labels* output_labels, + std::vector<DimensionType>* label_types, + OperandLabelCounts* input_label_counts, + LabelCounts* output_label_counts, + LabelToDimSizes* label_to_dim_sizes) { + if (inputs.size() != input_labels->size()) { + return errors::InvalidArgument("Expected ", input_labels->size(), + " inputs but got: ", inputs.size()); + } + const int num_inputs = inputs.size(); + + // We infer the number of broadcasting dimensions by taking the maximum rank + // among the broadcasting subshapes of the input. + int max_bcast_dims = 0; + const int num_named_labels = label_types->size(); + label_to_dim_sizes->resize(num_named_labels); + for (int i = 0; i < num_inputs; ++i) { + Labels* labels = &(*input_labels)[i]; + + if (!input_has_ellipsis[i]) { + if (inputs[i].dims() != labels->size()) { + return errors::InvalidArgument("Expected input ", i, " to have rank ", + labels->size(), + " but got: ", inputs[i].dims()); + } + for (int label_idx = 0; label_idx < labels->size(); ++label_idx) { + const int label = (*labels)[label_idx]; + TF_RETURN_IF_ERROR(RecordLabelToDimension(label, label_idx, inputs[i], + label_to_dim_sizes)); + } + continue; + } + + // Input has an ellipsis. + if (inputs[i].dims() + 1 < labels->size()) { + return errors::InvalidArgument( + "Expected input ", i, " to have rank at least ", labels->size() - 1, + " but got: ", inputs[i].dims()); + } + int ellipsis_axis = -1; + const int num_bcast_dims = inputs[i].dims() - labels->size() + 1; + for (int label_idx = 0; label_idx < labels->size(); ++label_idx) { + const int label = (*labels)[label_idx]; + if (label == kEllipsisLabel) { + ellipsis_axis = label_idx; + continue; + } + // Current label is not an ellipsis. + const int axis = + label_idx + (ellipsis_axis == -1 ? 0 : num_bcast_dims - 1); + TF_RETURN_IF_ERROR( + RecordLabelToDimension(label, axis, inputs[i], label_to_dim_sizes)); + } + // Found an ellipsis. Replace 'kEllipsisLabel' with broadcasting dimensions. + if (ellipsis_axis != -1) { + InsertBroadcastLabels(num_bcast_dims, num_named_labels, ellipsis_axis, + labels, &input_label_counts->at(i)); + max_bcast_dims = std::max(max_bcast_dims, num_bcast_dims); + } + } + if (!absl::c_linear_search(input_has_ellipsis, true) && + !output_has_ellipsis) { + return Status::OK(); + } + // Insert broadcasting dimensions in the output labels. + auto it = + std::find(output_labels->begin(), output_labels->end(), kEllipsisLabel); + if (it != output_labels->end()) { + const int ellipsis_axis = it - output_labels->begin(); + InsertBroadcastLabels(max_bcast_dims, num_named_labels, ellipsis_axis, + output_labels, output_label_counts); + } else if (max_bcast_dims > 0) { + return errors::InvalidArgument("Output contains ", max_bcast_dims, + " broadcasting dimension(s) but no ellipsis " + "(...) was found in the output subscripts."); + } + // Populate DimensionType for the new broadcasting labels. 
+ label_types->resize(num_named_labels + max_bcast_dims, kBroadcasting); + return Status::OK(); +} + +// Permutes the labels according to the given permutation. +void PermuteLabels(const std::vector<int>& permutation, Labels* labels) { + Labels permuted_labels(labels->size()); + for (int i = 0; i < labels->size(); ++i) { + permuted_labels[i] = (*labels)[permutation[i]]; + } + labels->swap(permuted_labels); +} + +// Returns a reshaped input Tensor. The underlying buffer is not copied. +Status CopyFrom(const Tensor& input, const TensorShape& shape, Tensor* output) { + if (output->CopyFrom(input, shape)) return Status::OK(); + return errors::Internal( + "Encountered error while reshaping a Tensor of shape ", + input.shape().DebugString(), " to shape ", shape.DebugString()); +} + +// Returns whether transposing would be a no-op; whether input has rank < 2 or +// the permutation is the identity permutation. +bool ShouldTranspose(const TensorShape& input_shape, + const std::vector<int>& permutation) { + if (input_shape.dims() < 2) return false; + for (int i = 0; i < permutation.size(); ++i) { + if (permutation[i] != i) return true; + } + return false; +} + +// Transpose the input given a permutation. Returns a reference to the input if +// transposing is not necessary. +template <typename Device, typename T> +Status TransposeOperand(OpKernelContext* ctx, const Tensor& input, + const std::vector<int>& permutation, Tensor* output) { + if (!ShouldTranspose(input.shape(), permutation)) { + return CopyFrom(input, input.shape(), output); + } + TensorShape transposed_shape; + for (int i = 0; i < input.dims(); ++i) { + transposed_shape.AddDim(input.dim_size(permutation[i])); + } + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum<T>::value, transposed_shape, output)); + const Device& device = ctx->eigen_device<Device>(); + TF_RETURN_IF_ERROR(DoTranspose(device, input, permutation, output)); + return Status::OK(); +} + +// If there are repeated labels in either the input or output, then this strides +// the input (e.g. iii->i) or inflates it (e.g. i->iii), respectively. +template <typename Device, typename T> +Status StrideOrInflate(OpKernelContext* ctx, const Tensor& input, + const Labels& labels, const LabelCounts& label_counts, + const bool should_inflate, Tensor* output) { + // Return early if there are no repeated indices. + if (absl::c_all_of(label_counts, [](int c) { return c <= 1; })) { + return CopyFrom(input, input.shape(), output); + } + // We reshape so that each repeated label is compressed to one dimension. + // E.g. For iiij -> ij, The shape [3, 3, 3, 5] would be compressed to [27, 5]. + // Striding appropriately (in this case with strides 14 (=1+3+9) and 1) + // recovers the generalized diagonal of shape [3, 5]. + ShapeVec reshape; + ShapeVec strides; + // Strided and inflated shapes correspond to input and output shapes, + // respectively, should_inflate is true (vice-versa if should_inflate is + // false). E.g. they are [3, 5] and [3, 3, 3, 5] in the above example. + ShapeVec strided_shape; + ShapeVec inflated_shape; + for (int label : labels) { + const int count = label_counts[label]; + const int current_axis = + should_inflate ? 
strided_shape.size() : inflated_shape.size(); + const int64 dim = input.dim_size(current_axis); + strided_shape.push_back(dim); + inflated_shape.insert(inflated_shape.end(), count, dim); + const int64 reshape_dim = MathUtil::IPow(dim, count); + reshape.push_back(reshape_dim); + // While taking the d-diagonal in a rank k Tensor, we take d equally-spaced + // elements including the first and last element. Then, + // (k - 1) * stride = d^k - 1, or, stride = (d^k - 1)/(d - 1). + const int64 stride = + (dim > 1 && count > 1) ? (reshape_dim - 1) / (dim - 1) : 1; + strides.push_back(stride); + } + + TensorShape output_shape = + TensorShape(should_inflate ? inflated_shape : strided_shape); + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output)); + const Device& device = ctx->eigen_device<Device>(); + switch (reshape.size()) { +#define NDIMS_CASE(N) \ + case N: { \ + if (should_inflate) { \ + auto output_map = output->shaped<T, N>(reshape); \ + auto input_map = input.shaped<T, N>(strided_shape); \ + output_map.device(device) = input_map.inflate(strides); \ + } else { \ + auto input_map = input.shaped<T, N>(reshape); \ + auto output_map = output->shaped<T, N>(strided_shape); \ + output_map.device(device) = input_map.stride(strides); \ + } \ + } break; + NDIMS_CASE(1); + NDIMS_CASE(2); + NDIMS_CASE(3); + NDIMS_CASE(4); + NDIMS_CASE(5); + NDIMS_CASE(6); + default: + return errors::Unimplemented( + "Unsupported rank: ", reshape.size(), + " while handling repeated indices. Up to rank 6 is supported."); +#undef NDIMS_CASE + } + return Status::OK(); +} + +// Returns true if the input dimensions are already sorted in the order +// [batch, contract, free, reduce]. Used to implement an optimization to avoid +// an extra transpose and instead uses (adj_x and adj_y) in BatchMatMul. +bool ShouldSwapFreeAndContract(const Labels& labels, + const std::vector<DimensionType>& label_types) { + // Check that ordering is according to dimension type, with the role of + // free and contract dimensions swapped. + gtl::InlinedVector<int, 5> remap = {0, 1, 3, 2, 4}; + for (int i = 0; i + 1 < labels.size(); ++i) { + const int dimtype_a = remap[label_types[labels[i]]]; + const int dimtype_b = remap[label_types[labels[i + 1]]]; + if (dimtype_a > dimtype_b || + (dimtype_a == dimtype_b && labels[i] > labels[i + 1])) { + return false; + } + } + return true; +} + +template <typename Device, typename T> +Status ReduceOperand(OpKernelContext* ctx, const Tensor& input, + const std::vector<DimensionType>& label_types, + const LabelCounts& label_counts, Labels* labels, + Labels* free_labels, bool* swap_free_and_contract, + Tensor* output) { + // Find the permutation to transpose the input dimensions in the order of + // DimensionType; i.e. batch, free, contract and reduce dimensions. This + // makes it more convenient to invoke Reduce/Contract operations. + std::vector<int> permutation(input.dims()); + absl::c_iota(permutation, 0); + Tensor input_transposed; + // Check if we can avoid the transpose. We need to flip the adj_x (or adj_y) + // flag during BatchMatMul. This is an extra optimization not necessary for + // correctness. + if (ShouldSwapFreeAndContract(*labels, label_types)) { + *swap_free_and_contract = true; + } else { + absl::c_sort(permutation, [&](int i, int j) { + int label_i = (*labels)[i]; + int label_j = (*labels)[j]; + return std::tie(label_types[label_i], label_i) < + std::tie(label_types[label_j], label_j); + }); + } + // Transpose the input so that DimensionTypes are in order. 
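+  // If ShouldSwapFreeAndContract returned true above, the permutation is still
+  // the identity and TransposeOperand is a no-op.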
+ TF_RETURN_IF_ERROR( + TransposeOperand<Device, T>(ctx, input, permutation, &input_transposed)); + PermuteLabels(permutation, labels); + + // Take the generalized diagonal for dimensions with repeated axis labels. + Tensor input_deduped; + labels->erase(std::unique(labels->begin(), labels->end()), labels->end()); + TF_RETURN_IF_ERROR( + StrideOrInflate<Device, T>(ctx, input_transposed, *labels, label_counts, + false /* should_inflate */, &input_deduped)); + + // Reshape denotes the rank-5 shape [broadcast, batch, free, contract, reduce] + // where we've compacted the dimensions of each DimensionType. + gtl::InlinedVector<int64, 5> reshape(5, 1); + // The output shape is [batch shape] + [free size, contract size] + // That is, the batch shape is preserved (for broadcasting while contracting) + // while the free dims and contract dims are compressed to one dimension each. + TensorShape output_shape; + for (int label_idx = 0; label_idx < labels->size(); ++label_idx) { + const int label = labels->at(label_idx); + int64 dim = input_deduped.dim_size(label_idx); + if (label_types[label] == kBroadcasting || label_types[label] == kBatch) { + output_shape.AddDim(dim); + } else if (label_types[label] == kFree) { + free_labels->push_back(label); + } + reshape[label_types[label]] *= dim; + } + if (*swap_free_and_contract) std::swap(reshape[kFree], reshape[kContract]); + output_shape.AddDim(reshape[kFree]); + output_shape.AddDim(reshape[kContract]); + + if (reshape[kReduce] == 1) { // No need to actually reduce. + return CopyFrom(input_deduped, output_shape, output); + } + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output)); + using Reducer = Eigen::internal::SumReducer<T>; + using Index = typename TTypes<T>::Tensor::Index; + // Reduce along the last axis (i.e axis 1) of the rank-2 Tensor. + const int64 output_size = reshape[kBroadcasting] * reshape[kBatch] * + reshape[kFree] * reshape[kContract]; + functor::ReduceFunctor<Device, Reducer>::Reduce( + ctx, output->shaped<T, 1>({output_size}), + const_cast<const Tensor&>(input_deduped) + .shaped<T, 2>({output_size, reshape[kReduce]}), + Eigen::array<Index, 1>({1}), Reducer()); + return Status::OK(); +} + +// Reshapes a Tensor of shape [b0,b1...bk,N,M] to [prod(b0,b1...bk),N,M]. +Status ReshapeToRank3(const Tensor& input, int batch_size, Tensor* output) { + const int rank = input.dims(); + TensorShape output_shape = {batch_size, input.dim_size(rank - 2), + input.dim_size(rank - 1)}; + return CopyFrom(input, output_shape, output); +} + +// Conjugates the input. +template <typename Device, typename T> +Status Conjugate(OpKernelContext* ctx, Tensor* input) { + std::vector<int> permutation(input->dims()); + std::iota(permutation.begin(), permutation.end(), 0); + Tensor output; + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum<T>::value, input->shape(), &output)); + const Device& d = ctx->eigen_device<Device>(); + TF_RETURN_IF_ERROR(DoConjugateTranspose(d, *input, permutation, &output)); + std::swap(*input, output); + return Status::OK(); +} + +// Contracts the inputs along the last axis. (or the second last if the +// corresponding value of swap_free_and_contract is true). The batch dimensions +// are broadcast to the output shape. +// TODO(anudhyan): Factor this function into a BatchMatMul functor and support +// transpose_x and transpose_y attributes (in addition to adj_x and adj_y). 
+// Also, the BatchMatMul might devolve into a component-wise multiplication when +// the matrix shape is [1,1]; in this case BatchMatMul functor would be very +// inefficient. The functor should detect if this is the case and perform +// componentwise multiplication functor instead. +template <typename Device, typename T> +Status ContractOperands(OpKernelContext* ctx, absl::Span<const Tensor> inputs, + absl::Span<const bool> swap_free_and_contract, + Tensor* output) { + if (inputs.size() == 1) return CopyFrom(inputs[0], inputs[0].shape(), output); + MatMulBCast bcast(inputs[0].shape().dim_sizes(), + inputs[1].shape().dim_sizes()); + if (!bcast.IsValid()) { + return errors::InvalidArgument( + "Invalid broadcasting dimensions: ", inputs[0].shape().DebugString(), + " vs. ", inputs[1].shape().DebugString()); + } + Tensor lhs; + TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[0], bcast.x_batch_size(), &lhs)); + Tensor rhs; + TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[1], bcast.y_batch_size(), &rhs)); + TensorShape output_shape = bcast.output_batch_shape(); + for (int i = 0; i < inputs.size(); ++i) { + const int64 free_axis = + inputs[i].dims() - (swap_free_and_contract[i] ? 1 : 2); + output_shape.AddDim(inputs[i].dim_size(free_axis)); + } + bool adj_x = swap_free_and_contract[0]; + bool adj_y = !swap_free_and_contract[1]; + if (is_complex<T>::value) { + if (adj_x) TF_RETURN_IF_ERROR(Conjugate<Device, T>(ctx, &lhs)); + if (adj_y) TF_RETURN_IF_ERROR(Conjugate<Device, T>(ctx, &rhs)); + } + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output)); + Tensor output_reshaped; + TF_RETURN_IF_ERROR( + ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped)); + LaunchBatchMatMul<Device, T>::Launch(ctx, lhs, rhs, adj_x, adj_y, bcast, + &output_reshaped); + return Status::OK(); +} +} // namespace + +template <typename Device, typename T> +class EinsumOp : public OpKernel { + public: + explicit EinsumOp(OpKernelConstruction* c) : OpKernel(c) { + string equation; + OP_REQUIRES_OK(c, c->GetAttr("equation", &equation)); + OP_REQUIRES_OK(c, ParseEquation(equation, &input_labels_, &output_labels_, + &label_types_, &input_label_counts_, + &output_label_counts_, &input_has_ellipsis_, + &output_has_ellipsis_)); + } + + void Compute(OpKernelContext* ctx) override { + OpInputList inputs; + OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &inputs)); + + OperandLabels input_labels(input_labels_); + Labels output_labels(output_labels_); + std::vector<DimensionType> label_types(label_types_); + OperandLabelCounts input_label_counts(input_label_counts_); + LabelCounts output_label_counts(output_label_counts_); + LabelToDimSizes label_to_dim_sizes; + + OP_REQUIRES_OK(ctx, ProcessDimensions( + inputs, input_has_ellipsis_, output_has_ellipsis_, + &input_labels, &output_labels, &label_types, + &input_label_counts, &output_label_counts, + &label_to_dim_sizes)); + + // The reduction phase (a) sums across reduction dimensions, (b) takes + // generalized diagonals, and (c) reshapes it into shape + // [(broadcasting) batch shape] + [F,C] + // where F and C denote the total (compacted) size of free and contract + // dimensions, respectively. 
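+    // For example, for the equation 'bij,bjk->bik' with inputs of shapes
+    // [7,2,3] and [7,3,4], the reduced operands keep shapes [7,2,3] and
+    // [7,3,4]: 'b' is a batch dimension and the free and contract dimensions
+    // are already compacted to one axis each.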
+ const int num_inputs = inputs.size(); + OperandLabels free_labels(num_inputs); + gtl::InlinedVector<Tensor, 2> inputs_reduced(num_inputs); + gtl::InlinedVector<bool, 2> swap_free_and_contract(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + OP_REQUIRES_OK(ctx, + ReduceOperand<Device, T>( + ctx, inputs[i], label_types, input_label_counts[i], + &input_labels[i], &free_labels[i], + &swap_free_and_contract[i], &inputs_reduced[i])); + } + + // After reduction, the inputs should be reshaped to Tensors suitable for + // contraction. If num_inputs is 1, the reduced input is simply forwarded to + // the output. + Tensor contraction_output_reshaped; + OP_REQUIRES_OK(ctx, ContractOperands<Device, T>( + ctx, inputs_reduced, swap_free_and_contract, + &contraction_output_reshaped)); + + // Copy the batch labels from the contraction output. Recover the batch + // shape, which may have been broadcasted. + TensorShape result_shape = contraction_output_reshaped.shape(); + result_shape.RemoveLastDims(2); + + int num_labels = label_types.size(); + Labels result_labels; + // All batch dimensions should be present in the contracted result. First + // the broadcasting dimensions, then the named batch dimensions. + for (int label = 0; label < num_labels; ++label) { + if (label_types[label] == kBroadcasting) result_labels.push_back(label); + } + for (int label = 0; label < num_labels; ++label) { + if (label_types[label] == kBatch) result_labels.push_back(label); + } + for (int i = 0; i < num_inputs; ++i) { + for (int label : free_labels[i]) { + result_labels.push_back(label); + result_shape.AddDim(label_to_dim_sizes[label]); + } + } + + // Reshape the contraction (or reduction) result to its expanded shape: + // [(broadcasted) batch shape] + [free shape 0] + [free shape 1]. + Tensor contraction_output; + OP_REQUIRES_OK(ctx, CopyFrom(contraction_output_reshaped, result_shape, + &contraction_output)); + + // Inflate the output if necessary. (E.g. for the equation 'i->iii' which + // may arise while computing gradient of a regular Einsum). + // TODO(anudhyan): It's possible that Eigen's contract and inflate can be + // chained here to avoid materializing an intermediate. + Tensor output_inflated; + OP_REQUIRES_OK( + ctx, StrideOrInflate<Device, T>( + ctx, contraction_output, result_labels, output_label_counts, + true /* should_inflate */, &output_inflated)); + if (output_inflated.dims() > contraction_output.dims()) { + // We inflated the output. Modify result labels accordingly. + Labels inflated_labels; + for (int label : result_labels) { + inflated_labels.insert(inflated_labels.end(), + output_label_counts[label], label); + } + result_labels.swap(inflated_labels); + } + // Find the permutation to map the result labels to the output labels. Note + // that both the result and the final output may have the repeated labels, + // in which case the permutation preserves the left-to-right ordering. + // E.g. if result labels are [0, 0, 1] and output is [0, l, 0] then the + // permutation should be [0, 2, 1]. We also use the fact that repeated + // labels in the result are adjacent to each other. + std::vector<int> output_permutation(output_labels.size()); + std::vector<int> label_to_position(num_labels, -1); + for (int i = 0; i < result_labels.size(); ++i) { + // Remember the position of only the leftmost result label. 
+ if (label_to_position[result_labels[i]] == -1) { + label_to_position[result_labels[i]] = i; + } + } + for (int i = 0; i < output_labels.size(); ++i) { + output_permutation[i] = label_to_position[output_labels[i]]; + // We have found the leftmost occurrence. The next one would be adjacent. + label_to_position[output_labels[i]] += 1; + } + Tensor output; + OP_REQUIRES_OK(ctx, TransposeOperand<Device, T>( + ctx, output_inflated, output_permutation, &output)); + ctx->set_output(0, output); + } + + private: + OperandLabels input_labels_; + Labels output_labels_; + std::vector<DimensionType> label_types_; + OperandLabelCounts input_label_counts_; + LabelCounts output_label_counts_; + gtl::InlinedVector<bool, 2> input_has_ellipsis_; + bool output_has_ellipsis_ = false; +}; + +#define REGISTER_EINSUM(D, TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Einsum").Device(DEVICE_##D).TypeConstraint<TYPE>("T"), \ + EinsumOp<D##Device, TYPE>); + +// TODO(anudhyan): Also register GPU kernels for Einsum. +#define REGISTER_CPU(TYPE) REGISTER_EINSUM(CPU, TYPE) +TF_CALL_float(REGISTER_CPU); +TF_CALL_double(REGISTER_CPU); +TF_CALL_complex64(REGISTER_CPU); +TF_CALL_complex128(REGISTER_CPU); +#undef REGISTER_CPU +#undef REGISTER_EINSUM + +} // namespace tensorflow diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc index 51ab3e268c6..f037d38ef81 100644 --- a/tensorflow/core/ops/linalg_ops.cc +++ b/tensorflow/core/ops/linalg_ops.cc @@ -474,6 +474,14 @@ REGISTER_OP("TridiagonalSolve") .Attr("T: {double, float, complex64, complex128}") .SetShapeFn(TridiagonalSolveShapeFn); +REGISTER_OP("Einsum") + .Input("inputs: N * T") + .Output("output: T") + .Attr("equation: string") + .Attr("N: int >= 1") + .Attr("T: type") + .SetShapeFn(shape_inference::EinsumShape); + // Deprecated op registrations: // Can be deleted after 3feb2017. diff --git a/tensorflow/core/util/einsum_op_util.cc b/tensorflow/core/util/einsum_op_util.cc new file mode 100644 index 00000000000..ec339a5128a --- /dev/null +++ b/tensorflow/core/util/einsum_op_util.cc @@ -0,0 +1,47 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/util/einsum_op_util.h" + +#include <string> + +#include "absl/strings/str_split.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace tensorflow { + +Status ParseEinsumEquation(const string& equation, + gtl::InlinedVector<string, 2>* input_subscripts, + string* output_subscript) { + gtl::InlinedVector<string, 2> inputs_and_output_subscripts = + absl::StrSplit(equation, "->"); + if (inputs_and_output_subscripts.size() != 2) { + return errors::InvalidArgument( + "Expecting exactly one '->' in einsum equation: ", equation); + } + *output_subscript = std::move(inputs_and_output_subscripts[1]); + *input_subscripts = + absl::StrSplit(std::move(inputs_and_output_subscripts[0]), ','); + if (input_subscripts->size() != 1 && input_subscripts->size() != 2) { + return errors::InvalidArgument( + "Expecting 1 or 2 input subscripts in equation '", equation, + "' but got: ", input_subscripts->size()); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/einsum_op_util.h b/tensorflow/core/util/einsum_op_util.h new file mode 100644 index 00000000000..f12af7bb7b8 --- /dev/null +++ b/tensorflow/core/util/einsum_op_util.h @@ -0,0 +1,29 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_ +#define TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_ + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace tensorflow { + +// Parses and validates an einsum equation in explicit form. 
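+// For example, the equation "ij,jk->ik" is parsed into input subscripts
+// {"ij", "jk"} and output subscript "ik".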
+Status ParseEinsumEquation(const string& equation, + gtl::InlinedVector<string, 2>* input_subscripts, + string* output_subscript); +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_ diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 3399bdc9992..df8633236c5 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -2136,6 +2136,23 @@ cuda_py_test( xla_enable_strict_auto_jit = True, ) +cuda_py_test( + name = "einsum_op_test", + size = "small", + srcs = ["einsum_op_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python/ops/linalg", + ], + xla_enable_strict_auto_jit = True, +) + cuda_py_test( name = "manip_ops_test", size = "small", diff --git a/tensorflow/python/kernel_tests/einsum_op_test.py b/tensorflow/python/kernel_tests/einsum_op_test.py new file mode 100644 index 00000000000..b51b91ddbf4 --- /dev/null +++ b/tensorflow/python/kernel_tests/einsum_op_test.py @@ -0,0 +1,251 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tensorflow.ops.Einsum.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_linalg_ops +from tensorflow.python.ops import special_math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import benchmark +from tensorflow.python.platform import test + + +class EinsumOpTest(test.TestCase): + + def _check(self, s, *input_shapes, **kwargs): + dtype = kwargs.pop('dtype', np.float32) + r = np.random.RandomState(0) + inputs = [] + for shape in input_shapes: + arr = np.array(r.randn(*shape)).astype(dtype) + if dtype == np.complex64 or dtype == np.complex128: + arr += 1j * np.array(r.randn(*shape)).astype(dtype) + inputs.append(arr) + input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs] + a = np.einsum(s, *inputs) + b = self.evaluate(gen_linalg_ops.einsum(input_tensors, s)) + self.assertAllClose(a, b, atol=1e-4, rtol=1e-4) + + def testUnary(self): + self._check('->', ()) + self._check('aa->', (3, 3)) + self._check('aa->a', (3, 3)) + self._check('aaa->', (3, 3, 3)) + self._check('aaa->a', (3, 3, 3)) + self._check('aab->a', (3, 3, 4)) + self._check('ab->', (3, 3)) + self._check('ab->ab', (3, 3)) + self._check('abc->b', (3, 4, 5)) + self._check('abc->ca', (3, 4, 5)) + self._check('abc->cab', (3, 4, 5)) + self._check('aabcc->a', (3, 3, 5, 4, 4)) + self._check('aabcc->ac', (3, 3, 5, 4, 4)) + self._check('aabcd->ad', (3, 3, 5, 4, 4)) + + def testUnaryEllipsis(self): + self._check('...->...', ()) + self._check('...->', ()) + self._check('->...', ()) + + # Tests from dask + self._check('a...a->a...', (2, 2)) + self._check('a...a->', (2, 2)) + self._check('a...a->...', (2, 5, 1, 2)) + self._check('a...a->a...', (2, 1, 2)) + self._check('a...a->a...', (2, 3, 4, 5, 2)) + + self._check('...ijk->...ki', (3, 4, 5)) + self._check('...ijk->...ki', (1, 3, 4, 5)) + self._check('...ijk->...ki', (2, 2, 3, 4, 5)) + + # Repeated indices. + self._check('i...ii->...i', (3, 2, 3, 3)) + + def testBinary(self): + self._check(',->', (), ()) + self._check('a,a->', (3,), (3,)) + self._check('a,a->a', (3,), (3,)) + self._check('ba,b->', (3, 2), (3,)) + self._check('ab,b->a', (3, 4), (4,)) + self._check('ab,ab->', (3, 4), (3, 4)) + self._check('nij,jk->nik', (5, 2, 3), (3, 4)) + self._check('abc,bad->abcd', (1, 2, 3), (2, 1, 4)) + # Repeated indices. + self._check('ijj,k->ik', (2, 3, 3), (4,)) + self._check('aba,a->b', (3, 4, 3), (3,)) + # From https://github.com/dask/dask/pull/3412#discussion_r182413444 + self._check('aab,bc->ac', (2, 2, 3), (3, 4)) + self._check('aab,bcc->ac', (2, 2, 3), (3, 4, 4)) + # Based on https://github.com/google/jax/issues/37#issuecomment-448572187 + self._check('sa,shb->shab', (2, 1), (2, 3, 4)) + + def testBroadcasting(self): + # Batch matmul without broadcasting. + self._check('...ij,...jk->...ik', (5, 1, 2, 3), (5, 1, 3, 4)) + # Batch matmul with broadcasting. 
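+    # A missing leading batch dimension behaves like size 1 under the NumPy
+    # broadcasting rules.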
+    self._check('...ij,...jk->...ik', (1, 2, 3), (3, 5))
+    self._check('...ij,...jk->...ik', (2, 3), (1, 3, 5))
+    self._check('...ij,...jk->...ik', (5, 2, 3), (3, 5))
+    self._check('...ij,...jk->...ik', (2, 3), (5, 3, 5))
+    self._check('...ij,...jk->...ik', (3, 1, 2, 3), (1, 1, 7, 3, 5))
+    self._check('i...j,j...k->...ik', (2, 1, 3, 1, 3), (3, 1, 7, 5))
+    # Broadcasting with repeated indices.
+    self._check('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
+    self._check('ij,jk...k->...i', (3, 2), (2, 4, 5, 4))
+    self._check('ijj,jk...k->i...', (3, 2, 2), (2, 4, 1, 4))
+    self._check('i...jj,jk...k->i...', (3, 3, 1, 2, 2), (2, 4, 1, 5, 4))
+    # The following two are from https://stackoverflow.com/a/19203475/1611416
+    self._check('...abc,...abcd->...d', (1, 1, 2, 3, 4), (5, 2, 3, 4, 6))
+    self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))
+
+  def testDtypes(self):
+    for dtype in [np.float64, np.float32, np.complex64, np.complex128]:
+      self._check('ij,jk->ik', (2, 2), (2, 2), dtype=dtype)
+      self._check('ji,jk->ik', (2, 2), (2, 2), dtype=dtype)
+      self._check('ji,kj->ik', (2, 2), (2, 2), dtype=dtype)
+      self._check('ij,jk->ki', (2, 2), (2, 2), dtype=dtype)
+      self._check('ji,kj->ki', (2, 2), (2, 2), dtype=dtype)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testInvalid(self):
+    r = np.random.RandomState(0)
+    cases = [
+        # Incorrect rank.
+        ('ij,jk->ik', r.randn(1, 2, 3), r.randn(3, 4)),
+        ('...ij,jk->ik', r.randn(3), r.randn(3, 4)),
+        # Inconsistent dimensions.
+        ('ij,jk->ik', r.randn(2, 3), r.randn(4, 4)),
+        # Broadcasting is invalid.
+        ('...ij,...jk->...ik', r.randn(5, 2, 3), r.randn(7, 3, 4)),
+        # Output should have ellipsis when broadcasting shape is non-empty.
+        ('...ij,...jk->ik', r.randn(2, 2, 3), r.randn(3, 4)),
+    ]
+    for args in cases:
+      with self.assertRaises((ValueError, errors.InvalidArgumentError)):
+        _ = self.evaluate(gen_linalg_ops.einsum(args[1:], args[0]))
+
+      placeholders = [
+          array_ops.placeholder_with_default(x, shape=None) for x in args[1:]
+      ]
+      with self.assertRaises((ValueError, errors.InvalidArgumentError)):
+        _ = self.evaluate(gen_linalg_ops.einsum(placeholders, args[0]))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testPlaceholder(self):
+
+    def check(equation, *input_and_placeholder_shapes):
+      r = np.random.RandomState(0)
+      inputs = []
+      input_placeholders = []
+      for actual_shape, placeholder_shape in input_and_placeholder_shapes:
+        input_np = np.array(r.randn(*actual_shape))
+        inputs.append(input_np)
+        input_placeholders.append(
+            array_ops.placeholder_with_default(input_np, placeholder_shape))
+
+      a = np.einsum(equation, *inputs)
+      b = self.evaluate(gen_linalg_ops.einsum(input_placeholders, equation))
+      self.assertAllClose(a, b, atol=1e-4, rtol=1e-4)
+
+    check('bijl,bjkm->bik', ((9, 2, 3, 5), (None, None, None, 5)),
+          ((9, 3, 4, 7), (None, None, 4, None)))
+    check('bijl,bjkm->bik', ((9, 2, 3, 5), None), ((9, 3, 4, 7), None))
+    check('...ij,...->...i', ((4, 3, 1, 2), (None, 3, None, 2)),
+          ((4, 3), (None, 3)))
+    check('...ij,...jk->...ik', ((3, 1, 2, 3), None), ((1, 7, 3, 4), None))
+
+  def testOutputRepeatedLabels(self):
+    # This is the reverse operation of repeated input labels, to be used for
+    # computing symbolic gradients of einsum.
+    r = np.random.RandomState(0)
+    a = r.randn(2, 2)
+    s = 'a->aa'
+    diag_a = np.diag(np.diag(a))
+    b = self.evaluate(gen_linalg_ops.einsum([np.diag(a)], s))
+    self.assertAllClose(diag_a, b, atol=1e-4, rtol=1e-4)
+
+
+class EinsumBenchmark(test.Benchmark):
+  cases = [
+      # Unary cases.
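+      # (Each entry below is [equation, dim]; benchmarkEinsum gives every axis
+      # of every operand extent `dim` when building its random inputs.)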
+      ['ijk->i', 100],
+      ['ijk->kji', 100],
+      # Regular matmul or batch matmul.
+      ['ij,jk->ik', 1000],
+      ['ji,kj->ik', 1000],
+      ['ab,ab->', 100],
+      ['ab,ba->', 100],
+      ['abc,abc->', 100],
+      ['abc,bac->', 100],
+      ['abc,cba->', 100],
+      ['bij,bjk->bik', 100],
+      ['bji,bjk->bki', 100],
+      ['ikl,kji->kl', 100],
+      ['klj,lki->ij', 100],
+      ['ijk,ilj->kli', 100],
+      ['kij,mkb->ijmb', 100],
+      ['abcd,ad->bc', 40],
+      # Larger binary contractions.
+      ['ijk,jklm->il', 40],
+      ['efabc,eabcd->efd', 30],
+      ['fabec,abcde->fde', 30],
+      ['efabc,edabc->efd', 30],
+      ['eadbf,dfebc->ecfad', 30],
+      ['abcdef,bcdfg->abcdeg', 30],
+  ]
+
+  def benchmarkEinsum(self):
+    for equation, dim in self.cases:
+      with ops.Graph().as_default(), \
+          session.Session(config=benchmark.benchmark_config()) as sess, \
+          ops.device('/cpu:0'):
+        r = np.random.RandomState(0)
+        input_subscripts = equation.split('->')[0].split(',')
+        input_vars = []
+        for subscript in input_subscripts:
+          input_shape = (dim,) * len(subscript)
+          input_vars.append(
+              variables.Variable(np.array(r.randn(*input_shape), np.float32)))
+        variables.global_variables_initializer().run()
+
+        # Call einsum_v1.
+        self.run_op_benchmark(
+            sess,
+            special_math_ops.einsum(equation, *input_vars),
+            min_iters=50,
+            name='einsum_v1_cpu_({})_{}'.format(equation, dim))
+
+        # Call gen_linalg_ops.einsum.
+        self.run_op_benchmark(
+            sess,
+            gen_linalg_ops.einsum(input_vars, equation),
+            min_iters=50,
+            name='einsum_v2_cpu_({})_{}'.format(equation, dim))
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 53cd28b9dce..ca661f2a572 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -1000,6 +1000,10 @@ tf_module {
   member_method {
     name: "EditDistance"
    argspec: "args=[\'hypothesis_indices\', \'hypothesis_values\', \'hypothesis_shape\', \'truth_indices\', \'truth_values\', \'truth_shape\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
+  member_method {
+    name: "Einsum"
+    argspec: "args=[\'inputs\', \'equation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "Elu"
    argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 53cd28b9dce..ca661f2a572 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -1000,6 +1000,10 @@ tf_module {
   member_method {
     name: "EditDistance"
    argspec: "args=[\'hypothesis_indices\', \'hypothesis_values\', \'hypothesis_shape\', \'truth_indices\', \'truth_values\', \'truth_shape\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
+  member_method {
+    name: "Einsum"
+    argspec: "args=[\'inputs\', \'equation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "Elu"
    argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
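
For reference, a minimal usage sketch of the new raw op through the generated Python wrapper exercised by the tests above. This assumes eager execution and a build in which the Einsum kernel is registered; everything other than gen_linalg_ops.einsum and np.einsum is illustrative.

    import numpy as np
    from tensorflow.python.ops import gen_linalg_ops

    x = np.random.randn(2, 3).astype(np.float32)
    y = np.random.randn(3, 4).astype(np.float32)
    # 'ij,jk->ik' is a plain matrix multiply, so the op should agree with np.einsum.
    z = gen_linalg_ops.einsum([x, y], 'ij,jk->ik')
    assert np.allclose(z, np.einsum('ij,jk->ik', x, y))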