Extend functionality of tf.einsum to the full NumPy spec (part 1 of 2).

This CL creates a C++ Einsum Op which will be the basis of tf.einsum.

Functionality improvements:
- Support repeated indices / generalized traces (e.g. ijjk,klm->ikm).
- Improve ellipsis labels to support NumPy-style broadcasting. Previously, only identical batch shapes could be broadcast (e.g. [1,1,2] was not compatible with [3,1]).
- Ellipsis labels also support partially-unknown and unknown shapes.

Performance improvements:
- Use the transpose_a/transpose_b parameters of matmul to improve performance in common matmul-like cases.

PiperOrigin-RevId: 248830232

parent 39c4ef4be6, commit 8e8f040cec
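
A minimal NumPy sketch of the two functionality improvements (illustrative only, not part of this change; np.einsum is used as the reference semantics, since the tf.einsum wrapper itself is only rewired in part 2):

import numpy as np

# Repeated indices / generalized trace: the repeated 'j' in the first operand
# takes the diagonal along those two axes before contracting with the second.
x = np.random.rand(2, 3, 3, 4)   # i j j k
y = np.random.rand(4, 5, 6)      # k l m
print(np.einsum("ijjk,klm->ikm", x, y).shape)        # (2, 4, 6)

# NumPy-style broadcasting of ellipsis dimensions: batch shapes [1, 1, 2] and
# [3, 1] now broadcast to [1, 3, 2] instead of being rejected.
a = np.random.rand(1, 1, 2, 4, 5)  # ... i j
b = np.random.rand(3, 1, 5, 6)     # ... j k
print(np.einsum("...ij,...jk->...ik", a, b).shape)   # (1, 3, 2, 4, 6)
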
@@ -959,6 +959,7 @@ tf_cuda_library(
"util/guarded_philox_random.h",
"util/mirror_pad_mode.h",
"util/padding.h",
"util/einsum_op_util.h",
"util/port.h",
"util/ptr_util.h",
"util/reffed_status_callback.h",
tensorflow/core/api_def/base_api/api_def_Einsum.pbtxt (new file, 100 lines)
@@ -0,0 +1,100 @@
op {
graph_op_name: "Einsum"
in_arg {
name: "inputs"
description: <<END
List of 1 or 2 Tensors.
END
}
out_arg {
name: "output"
description: <<END
Output Tensor with shape depending upon `equation`.
END
}
attr {
name: "equation"
description: <<END
String describing the Einstein Summation operation; in the format of np.einsum.
END
}
summary: "Tensor contraction according to Einstein summation convention."
description: <<END
Implements generalized Tensor contraction and reduction. Each input Tensor must
have a corresponding input subscript appearing in the comma-separated left-hand
side of the equation. The right-hand side of the equation consists of the
output subscript. The input subscripts and the output subscript should consist
of zero or more named axis labels and at most one ellipsis (`...`).

The named axis labels may be any single character other than those having
special meaning, namely `,.->`. The behavior of this Op is undefined if it
receives an ill-formatted equation; since the validation is done at
graph-building time, we omit format validation checks at runtime.

Note: This Op is *not* intended to be called by the user; instead users should
call `tf.einsum` directly. It is a hidden Op used by `tf.einsum`.

Operations are applied to the input(s) according to the following rules:

(a) Generalized Diagonals: For input dimensions corresponding to axis labels
appearing more than once in the same input subscript, we take the
generalized (`k`-dimensional) diagonal.
For example, in the equation `iii->i` with input shape `[3, 3, 3]`, the
generalized diagonal would consist of `3` elements at indices `(0, 0, 0)`,
`(1, 1, 1)` and `(2, 2, 2)` to create a Tensor of shape `[3]`.

(b) Reduction: Axes corresponding to labels appearing only in one input
subscript but not in the output subscript are summed over prior to Tensor
contraction.
For example, in the equation `ab,bc->b`, the axis labels `a` and `c` are
the reduction axis labels.

(c) Batch Dimensions: Axes corresponding to labels appearing in each of the
input subscripts and also in the output subscript make up the batch
dimensions in Tensor contraction. Unnamed axis labels corresponding to
ellipsis (`...`) also correspond to batch dimensions.
For example, for the equation denoting batch matrix multiplication,
`bij,bjk->bik`, the axis label `b` corresponds to a batch dimension.

(d) Contraction: In case of binary einsum, axes corresponding to labels
appearing in two different inputs (and not in the output) are contracted
against each other.
Considering the batch matrix multiplication equation again
(`bij,bjk->bik`), the contracted axis label is `j`.

(e) Expand Diagonal: If the output subscripts contain repeated (explicit) axis
labels, the opposite operation of (a) is applied. For example, in the
equation `i->iii` with input shape `[3]`, the output of shape `[3, 3, 3]`
is all zeros, except for the (generalized) diagonal which is populated
with values from the input.
Note: This operation is not supported by `np.einsum` or `tf.einsum`; it is
provided to enable computing the symbolic gradient of `tf.einsum`.

The output subscripts must contain only labels appearing in at least one of the
input subscripts. Furthermore, all dimensions mapping to the same axis label
must be equal.

Each of the input and output subscripts may contain at most a single ellipsis
(`...`). These ellipses are mapped against dimensions not corresponding to any
named axis label. If two inputs contain an ellipsis, then they are broadcasted
according to standard NumPy broadcasting
[rules](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).

The broadcasted dimensions are placed in the corresponding location of the
ellipsis in the output subscript. If the broadcasted dimensions are non-empty
and the output subscripts do not contain an ellipsis, then an InvalidArgument
error is raised.

@compatibility(numpy)
Similar to [`numpy.einsum`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.einsum.html).

Comparison with `numpy.einsum`:

* This Op only supports unary and binary forms of `numpy.einsum`.
* This Op does not support the implicit form (i.e. equations without `->`).
* This Op also supports repeated indices in the output subscript, which is not
  supported by `numpy.einsum`.
@end_compatibility

END
}
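
The less familiar rules are (a), (b) and (e). A short NumPy sketch of what they mean (illustrative only; the expand-diagonal rule (e) has no np.einsum equivalent, so it is emulated by hand):

import numpy as np

x = np.arange(27).reshape(3, 3, 3)

# (a) Generalized diagonal: 'iii->i' keeps entries (0,0,0), (1,1,1), (2,2,2).
print(np.einsum("iii->i", x))          # [ 0 13 26]

# (b) Reduction: labels missing from the output are summed over.
print(np.einsum("abc->b", x).shape)    # (3,)

# (e) Expand diagonal ('i->iii'): the inverse of (a). np.einsum rejects it, so
# emulate it by scattering a vector onto the diagonal of a zero tensor.
v = np.array([1., 2., 3.])
expanded = np.zeros((3, 3, 3))
expanded[np.arange(3), np.arange(3), np.arange(3)] = v
print(expanded[1, 1, 1])               # 2.0
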
tensorflow/core/api_def/python_api/api_def_Einsum.pbtxt (new file, 4 lines)
@@ -0,0 +1,4 @@
op {
graph_op_name: "Einsum"
visibility: HIDDEN
}
@ -13,8 +13,16 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/util/einsum_op_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -233,6 +241,178 @@ Status MatMulShape(shape_inference::InferenceContext* c) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Validates that an Einsum subscript contains at most one ellipsis, and
// that periods (.) occur only within an ellipsis (...).
|
||||
Status ValidateEinsumEllipsis(absl::string_view subscript,
|
||||
bool* found_ellipsis) {
|
||||
const int num_periods = absl::c_count(subscript, '.');
|
||||
if (num_periods != 0 && num_periods != 3) {
|
||||
return errors::InvalidArgument(
|
||||
"Expected at most one ellipsis (...), but found ", num_periods,
|
||||
" periods (.) in the input subscript: ", subscript);
|
||||
}
|
||||
if (num_periods == 3 && !absl::StrContains(subscript, "...")) {
|
||||
return errors::InvalidArgument(
|
||||
"Periods found outside of ellipsis in subscript: ", subscript);
|
||||
}
|
||||
*found_ellipsis = num_periods > 0;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status EinsumShape(shape_inference::InferenceContext* c) {
|
||||
// We assume that the equation has a valid format: either (x),(y)->(z)
// or (x)->(z), where each of (x), (y) and (z) is a concatenation of zero or
// more Latin letters and contains at most one ellipsis ('...').
|
||||
string equation;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("equation", &equation));
|
||||
gtl::InlinedVector<string, 2> input_labels;
|
||||
string output_labels;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ParseEinsumEquation(equation, &input_labels, &output_labels));
|
||||
|
||||
if (c->num_inputs() == 0 || c->num_inputs() > 2) {
|
||||
return errors::InvalidArgument("Expected either 1 or 2 inputs but got: ",
|
||||
c->num_inputs());
|
||||
}
|
||||
if (c->num_inputs() != input_labels.size()) {
|
||||
return errors::InvalidArgument("Expected ", input_labels.size(),
|
||||
" inputs for equation ", equation,
|
||||
" but got: ", c->num_inputs());
|
||||
}
|
||||
|
||||
// Validate input subscripts, build the label to dimension mapping and obtain
|
||||
// the broadcast shapes that map to ellipsis.
|
||||
absl::flat_hash_map<char, DimensionHandle> label_to_dimension;
|
||||
gtl::InlinedVector<ShapeHandle, 2> input_bcast_shapes(c->num_inputs());
|
||||
for (int i = 0; i < c->num_inputs(); ++i) {
|
||||
bool has_ellipsis = false;
|
||||
TF_RETURN_IF_ERROR(ValidateEinsumEllipsis(input_labels[i], &has_ellipsis));
|
||||
ShapeHandle input_shape = c->input(i);
|
||||
// Validate that the input rank is sufficient for the given number of named
|
||||
// labels.
|
||||
if (c->RankKnown(input_shape)) {
|
||||
if (has_ellipsis) {
|
||||
const int num_named_labels =
|
||||
static_cast<int>(input_labels[i].size()) - 3;
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
c->WithRankAtLeast(input_shape, num_named_labels, &input_shape),
|
||||
" for ", i, "th input and equation: ", equation);
|
||||
} else {
|
||||
const int num_named_labels = static_cast<int>(input_labels[i].size());
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
c->WithRank(input_shape, num_named_labels, &input_shape), " for ",
|
||||
i, "th input and equation: ", equation);
|
||||
}
|
||||
}
|
||||
|
||||
bool seen_ellipsis = false;
|
||||
input_bcast_shapes[i] = c->Scalar();
|
||||
// Run through the input labels; populate label_to_dimension mapping and
|
||||
// compute the broadcast shapes corresponding to the ellipsis (if present).
|
||||
for (int label_idx = 0; label_idx < input_labels[i].size(); ++label_idx) {
|
||||
const char label = input_labels[i][label_idx];
|
||||
// Calculate the input axis that the current label refers to. After the
// ellipsis, the axis may be found by using negative indices; i.e. the
// (rank - k)th dimension corresponds to the (num_labels - k)th label.
|
||||
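// For example, for the subscript "a...bc" (6 characters) and an input of
// rank 6, label 'c' has label_idx 5 and maps to axis 5 + 6 - 6 = 5, i.e.
// the last input dimension.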
const int64 axis_before_ellipsis = label_idx;
|
||||
const int64 axis_after_ellipsis =
|
||||
c->RankKnown(input_shape)
|
||||
? label_idx + c->Rank(input_shape) - input_labels[i].size()
|
||||
: -1;
|
||||
|
||||
// Populate the input broadcast shape when we encounter an ellipsis (...).
|
||||
if (label == '.') {
|
||||
if (!c->RankKnown(input_shape)) {
|
||||
input_bcast_shapes[i] = c->UnknownShape();
|
||||
} else {
|
||||
// The broadcast shape runs until the named label right after the
// ellipsis, i.e. the label with index (label_idx + 3).
|
||||
TF_RETURN_IF_ERROR(c->Subshape(input_shape, axis_before_ellipsis,
|
||||
axis_after_ellipsis + 3,
|
||||
&input_bcast_shapes[i]));
|
||||
}
|
||||
label_idx += 2; // Skip the rest of the ellipsis.
|
||||
seen_ellipsis = true;
|
||||
continue;
|
||||
}
|
||||
// Obtain the dimension that the current label corresponds to.
|
||||
int64 axis = seen_ellipsis ? axis_after_ellipsis : axis_before_ellipsis;
|
||||
DimensionHandle new_dim = c->RankKnown(input_shape)
|
||||
? c->Dim(input_shape, axis)
|
||||
: c->UnknownDim();
|
||||
// If we've seen this label before, make sure previous and current
|
||||
// dimensions are compatible.
|
||||
if (label_to_dimension.contains(label)) {
|
||||
DimensionHandle merged;
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Merge(label_to_dimension[label], new_dim, &merged));
|
||||
label_to_dimension[label] = merged;
|
||||
} else {
|
||||
label_to_dimension[label] = new_dim;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For two inputs, broadcast the two input broadcast shapes to create the
|
||||
// output broadcast shape. For one input, just copy the single broadcast
|
||||
// shape.
|
||||
ShapeHandle output_bcast_shape;
|
||||
if (input_bcast_shapes.size() == 1) {
|
||||
output_bcast_shape = input_bcast_shapes[0];
|
||||
} else if (input_bcast_shapes.size() == 2) {
|
||||
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
|
||||
c, input_bcast_shapes[0], input_bcast_shapes[1], &output_bcast_shape));
|
||||
}
|
||||
|
||||
bool output_has_ellipsis = false;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ValidateEinsumEllipsis(output_labels, &output_has_ellipsis));
|
||||
if (output_has_ellipsis) {
|
||||
// If the output subscript has ellipsis and the output broadcast rank is
|
||||
// unknown, then the output shape should have unknown rank.
|
||||
if (!c->RankKnown(output_bcast_shape)) {
|
||||
c->set_output(0, c->UnknownShape());
|
||||
return Status::OK();
|
||||
}
|
||||
} else {
|
||||
// If the output subscripts don't have ellipsis then make sure the output
|
||||
// broadcasting shape is empty.
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
c->WithRankAtMost(output_bcast_shape, 0, &output_bcast_shape),
|
||||
" for einsum equation '", equation,
|
||||
"' without ellipsis (...) in the output subscripts where input(s) have "
|
||||
"non-empty broadcasting shape");
|
||||
output_bcast_shape = c->Scalar();
|
||||
}
|
||||
|
||||
// Create the output shape from output labels and label_to_dimension mapping.
|
||||
std::vector<DimensionHandle> output_dims;
|
||||
for (int label_idx = 0; label_idx < output_labels.size(); ++label_idx) {
|
||||
const char label = output_labels[label_idx];
|
||||
// Append the output_bcast_shape when the ellipsis is encountered.
|
||||
if (label == '.') {
|
||||
for (int k = 0; k < c->Rank(output_bcast_shape); ++k) {
|
||||
output_dims.push_back(c->Dim(output_bcast_shape, k));
|
||||
}
|
||||
label_idx += 2; // Skip the rest of the ellipsis.
|
||||
continue;
|
||||
}
|
||||
auto dimension_it = label_to_dimension.find(label);
|
||||
if (dimension_it == label_to_dimension.end()) {
|
||||
return errors::InvalidArgument(
|
||||
"Einsum output subscripts for equation '", equation, "' has label '",
|
||||
label, "' which is not present in the input subscripts");
|
||||
}
|
||||
output_dims.push_back(dimension_it->second);
|
||||
}
|
||||
c->set_output(0, c->MakeShape(output_dims));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) {
|
||||
ShapeHandle a_shape;
|
||||
ShapeHandle b_shape;
|
||||
|
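
The shape function above boils down to the following simplified Python sketch (a sketch only: it assumes fully known ranks and dimensions and NumPy >= 1.20 for np.broadcast_shapes; the real EinsumShape also handles unknown ranks/dims and attaches error context):

import numpy as np

def einsum_output_shape(equation, shapes):
  """Sketch of the label -> dimension bookkeeping performed by EinsumShape."""
  lhs, output = equation.split("->")
  label_to_dim = {}
  bcast_shapes = []
  for sub, shape in zip(lhs.split(","), shapes):
    named = sub.replace("...", "")
    if "..." in sub:
      start = sub.index("...")
      bcast_rank = len(shape) - len(named)
      bcast_shapes.append(shape[start:start + bcast_rank])
      shape = shape[:start] + shape[start + bcast_rank:]
    for label, dim in zip(named, shape):
      if label_to_dim.setdefault(label, dim) != dim:
        raise ValueError("Inconsistent sizes for label %r" % label)
  # np.broadcast_shapes needs NumPy >= 1.20; EinsumShape does this merge itself.
  out_bcast = np.broadcast_shapes(*bcast_shapes) if bcast_shapes else ()
  result = []
  for label in output.replace("...", "@"):
    if label == "@":
      result.extend(out_bcast)
    else:
      result.append(label_to_dim[label])
  return tuple(result)

print(einsum_output_shape("...ij,...jk->...ik", ((1, 1, 2, 4, 5), (3, 1, 5, 6))))
# (1, 3, 2, 4, 6)
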
@@ -230,6 +230,9 @@ Status MatMulShape(shape_inference::InferenceContext* c);
// batch dimensions.
Status BatchMatMulV2Shape(shape_inference::InferenceContext* c);

// Shape function for Einsum.
Status EinsumShape(shape_inference::InferenceContext* c);

// Shape function for BiasAdd-like operations.
Status BiasAddShape(shape_inference::InferenceContext* c);

@ -213,6 +213,129 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
}
|
||||
}
|
||||
|
||||
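// Note on notation: in INFER_OK, input shapes are given as a ';'-separated
// list, and the expected output uses "d<i>_<j>" to denote dimension j of
// input i and "?" for an unknown dimension or shape.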
TEST(CommonShapeFnsTest, Einsum_ShapeFn) {
|
||||
ShapeInferenceTestOp op("Einsum");
|
||||
auto set_equation = [&op](int n, string equation) {
|
||||
std::vector<NodeDefBuilder::NodeOut> input_list;
|
||||
input_list.reserve(n);
|
||||
for (int i = 0; i < n; ++i) {
|
||||
input_list.emplace_back("a", 0, DT_FLOAT);
|
||||
}
|
||||
TF_ASSERT_OK(NodeDefBuilder("test", "Einsum")
|
||||
.Input(input_list)
|
||||
.Attr("equation", equation)
|
||||
.Finalize(&op.node_def));
|
||||
};
|
||||
|
||||
// Unary cases.
|
||||
set_equation(1, "abc->c");
|
||||
INFER_OK(op, "[?,?,?]", "[d0_2]");
|
||||
set_equation(1, "abc->aabbcc");
|
||||
INFER_OK(op, "[?,?,?]", "[d0_0,d0_0,d0_1,d0_1,d0_2,d0_2]");
|
||||
set_equation(1, "abc->");
|
||||
INFER_OK(op, "[?,?,?]", "[]");
|
||||
set_equation(1, "->");
|
||||
INFER_OK(op, "[]", "[]");
|
||||
|
||||
// Binary cases.
|
||||
set_equation(2, "ij,jk->ik");
|
||||
INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");
|
||||
set_equation(2, "ij,jk->ik");
|
||||
INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");
|
||||
set_equation(2, "ab,ab->");
|
||||
INFER_OK(op, "[?,?];[?,?]", "[]");
|
||||
set_equation(2, "ab,->b");
|
||||
INFER_OK(op, "[?,?];[]", "[d0_1]");
|
||||
set_equation(2, ",->");
|
||||
INFER_OK(op, "[];[]", "[]");
|
||||
set_equation(2, "aaa,b->abbb");
|
||||
INFER_OK(op, "[?,?,?];[?]", "[d0_0,d1_0,d1_0,d1_0]");
|
||||
set_equation(2, ",abcd->badc");
|
||||
INFER_OK(op, "[];[?,?,?,?]", "[d1_1,d1_0,d1_3,d1_2]");
|
||||
|
||||
// Ellipsis cases.
|
||||
set_equation(1, "a...bc->c...");
|
||||
INFER_OK(op, "[?,?,?,?,?]", "[d0_4,d0_1,d0_2]");
|
||||
set_equation(2, "...ij,...jk->...ik");
|
||||
INFER_OK(op, "[?,?,?,?,?];[1,?,?]", "[d0_0,d0_1,d0_2,d0_3,d1_2]");
|
||||
INFER_OK(op, "[1,?,?];[?,?,?,?,?]", "[d1_0,d1_1,d1_2,d0_1,d1_4]");
|
||||
|
||||
// Unknown rank.
|
||||
set_equation(1, "abc->c");
|
||||
INFER_OK(op, "?", "[?]");
|
||||
set_equation(1, "a...bc->c");
|
||||
INFER_OK(op, "?", "[?]");
|
||||
set_equation(1, "a...bc->c...");
|
||||
INFER_OK(op, "?", "?");
|
||||
|
||||
set_equation(2, "...ij,...jk->...ik");
|
||||
INFER_OK(op, "?;?", "?");
|
||||
INFER_OK(op, "[?,?,?];?", "?");
|
||||
INFER_OK(op, "?;[?,?,?]", "?");
|
||||
set_equation(2, "...ij,...jk->ik");
|
||||
INFER_OK(op, "?;?", "[?,?]");
|
||||
set_equation(2, "abd,b...c->...cad");
|
||||
INFER_OK(op, "[?,?,?];[?,?,?,?]", "[d1_1,d1_2,d1_3,d0_0,d0_2]");
|
||||
set_equation(2, "...ab,b...c->ac...");
|
||||
INFER_OK(op, "[?,1,?,?];[?,?,?]", "[d0_2,d1_2,d0_0,d1_1]");
|
||||
|
||||
// Wrong number of inputs.
|
||||
set_equation(2, "ab->b");
|
||||
INFER_ERROR("got: 2", op, "[?,?];[?,?]");
|
||||
set_equation(1, "ab,a->b");
|
||||
INFER_ERROR("got: 1", op, "[?,?]");
|
||||
|
||||
// Invalid format. Implicit form is not supported.
|
||||
set_equation(1, "a");
|
||||
INFER_ERROR("equation", op, "[2]");
|
||||
set_equation(2, "ab,bc");
|
||||
INFER_ERROR("equation", op, "[2,2];[2,2]");
|
||||
|
||||
// Wrong number of ellipsis or periods outside of ellipsis.
|
||||
set_equation(1, "..a.->a...");
|
||||
INFER_ERROR("ellipsis", op, "[1,1,2,1]");
|
||||
set_equation(1, "...a->.a..");
|
||||
INFER_ERROR("ellipsis", op, "[1,1,1,2]");
|
||||
set_equation(1, "...a...->...a");
|
||||
INFER_ERROR("ellipsis", op, "[1,1,1,2]");
|
||||
set_equation(1, "..a..b..->...ab");
|
||||
INFER_ERROR("ellipsis", op, "[1,1,2,1]");
|
||||
set_equation(2, "...a...,ab->a");
|
||||
INFER_ERROR("ellipsis", op, "[1,2,1];[2,1]");
|
||||
set_equation(2, "a,...ab...->a");
|
||||
INFER_ERROR("ellipsis", op, "[2];[1,2,1,1]");
|
||||
set_equation(2, "a,ab->a......");
|
||||
INFER_ERROR("ellipsis", op, "[2];[2,1]");
|
||||
|
||||
// Output label doesn't appear in input.
|
||||
set_equation(1, "abc->d");
|
||||
INFER_ERROR("'d'", op, "[?,?,?]");
|
||||
|
||||
// Mismatch in input rank.
|
||||
set_equation(1, "abc->c");
|
||||
INFER_ERROR("4", op, "[?,?,?,?]");
|
||||
INFER_ERROR("2", op, "[?,?]");
|
||||
set_equation(1, "...abc->...c");
|
||||
INFER_ERROR("2", op, "[?,?]");
|
||||
|
||||
// Input dimensions are not consistent.
|
||||
set_equation(2, "ab,ab->a");
|
||||
INFER_ERROR("are 1 and 2", op, "[1,2];[2,1]");
|
||||
set_equation(2, "aa,bb->a");
|
||||
INFER_ERROR("are 1 and 2", op, "[1,2];[2,2]");
|
||||
|
||||
// Invalid broadcasting dimensions.
|
||||
set_equation(2, "...ij,...jk->...ik");
|
||||
INFER_ERROR("are 2 and 3", op, "[2,?,?];[3,?,?]");
|
||||
set_equation(2, "i...j,jk...->...ik");
|
||||
INFER_ERROR("are 2 and 3", op, "[?,2,?];[?,?,3]");
|
||||
set_equation(2, "...ij,...jk->ik");
|
||||
set_equation(2, "i...j,jk...->ik");
|
||||
INFER_ERROR("non-empty broadcasting", op, "[?,2,?];[?,?]");
|
||||
set_equation(2, "...ab,b...c->ac...");
|
||||
INFER_OK(op, "?;[4,5,3]", "?");
|
||||
}
|
||||
|
||||
TEST(CommonShapeFnsTest, BatchMatMulV2_ShapeFn) {
|
||||
ShapeInferenceTestOp op("BatchMatMulV2");
|
||||
auto set_adj = [&op](bool adj_x, bool adj_y) {
|
||||
|
@ -3115,6 +3115,7 @@ cc_library(
|
||||
":cholesky_grad",
|
||||
":cholesky_op",
|
||||
":determinant_op",
|
||||
":einsum_op",
|
||||
":lu_op",
|
||||
":matrix_exponential_op",
|
||||
":matrix_inverse_op",
|
||||
@ -3313,6 +3314,21 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "einsum_op",
|
||||
prefix = "einsum_op",
|
||||
deps = [
|
||||
":batch_matmul_op",
|
||||
":reduction_ops",
|
||||
":transpose_functor",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//third_party/eigen3",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "linalg_ops_common",
|
||||
srcs = ["linalg_ops_common.cc"],
|
||||
|
tensorflow/core/kernels/einsum_op.cc (new file, 711 lines)
@@ -0,0 +1,711 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
|
||||
#include "tensorflow/core/kernels/reduction_ops_common.h"
|
||||
#include "tensorflow/core/kernels/transpose_functor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/math/math_util.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/einsum_op_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
|
||||
using ShapeVec = gtl::InlinedVector<int64, 8>;
|
||||
using Labels = gtl::InlinedVector<int, 8>;
|
||||
using OperandLabels = gtl::InlinedVector<Labels, 2>;
|
||||
using LabelCounts = gtl::InlinedVector<int, 8>;
|
||||
using OperandLabelCounts = gtl::InlinedVector<LabelCounts, 2>;
|
||||
using LabelToDimSizes = gtl::InlinedVector<int64, 8>;
|
||||
|
||||
// Dummy axis label used to denote an ellipsis in an input or output subscript.
|
||||
constexpr int kEllipsisLabel = -1;
|
||||
|
||||
// Each dimension is categorized into exactly one of five types based on whether
|
||||
// its corresponding label is present in the input and/or the output subscripts.
|
||||
enum DimensionType {
|
||||
// Batch dimensions are those present in two inputs as well as the output.
|
||||
// They are part of the batch dimensions during Tensor contraction.
|
||||
// Such dimensions may be broadcasting dimensions (those mapping to ellipsis)
|
||||
// or explicit batch dimensions corresponding to named axis labels.
|
||||
kBroadcasting = 0,
|
||||
kBatch = 1,
|
||||
// Free dimensions are present in exactly one of the inputs, and also the
|
||||
// output. These are non-contracted axes in the Tensor contraction.
|
||||
kFree = 2,
|
||||
// Contract dimensions are present in two inputs, but not the output. These
|
||||
// dimensions are contracted in Tensor contraction.
|
||||
kContract = 3,
|
||||
// Reduce dimensions are present in exactly one input but not in the output,
// and are summed over prior to Tensor contraction.
|
||||
kReduce = 4,
|
||||
};
|
||||
|
||||
// Returns the DimensionType given whether the corresponding label is present in
|
||||
// exactly one input subscript (is_unique) and whether it is absent from the
|
||||
// output subscripts (is_removed). Does not handle broadcasting dimensions.
|
||||
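// For example, in the equation "ab,bc->b": 'b' appears in both inputs and in
// the output, so it is kBatch; 'a' and 'c' each appear in exactly one input
// and not in the output, so they are kReduce. In "...ij,...jk->...ik", 'i' and
// 'k' are kFree, 'j' is kContract, and the ellipsis dims are kBroadcasting.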
DimensionType GetDimensionType(bool is_removed, bool is_unique) {
|
||||
if (!is_removed && !is_unique)
|
||||
return kBatch;
|
||||
else if (!is_removed && is_unique)
|
||||
return kFree;
|
||||
else if (is_removed && !is_unique)
|
||||
return kContract;
|
||||
else // is_removed && is_unique
|
||||
return kReduce;
|
||||
}
|
||||
|
||||
// Maps the character labels to consecutive integers.
|
||||
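// E.g. for the equation "ab,bc->ac" this maps a->0, b->1, c->2 (in order of
// first appearance), and each '...' becomes kEllipsisLabel.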
void MapToLabels(const string& subscript, Labels* labels,
|
||||
absl::flat_hash_map<char, int>* label_mapping) {
|
||||
for (int i = 0; i < subscript.size(); ++i) {
|
||||
const char label_char = subscript[i];
|
||||
if (label_char == '.') {
|
||||
labels->push_back(kEllipsisLabel);
|
||||
i += 2; // Skip next 2 characters as well.
|
||||
continue;
|
||||
}
|
||||
if (!label_mapping->contains(label_char)) {
|
||||
const int next_label = label_mapping->size();
|
||||
(*label_mapping)[label_char] = next_label;
|
||||
}
|
||||
const int mapped_label = (*label_mapping)[label_char];
|
||||
labels->push_back(mapped_label);
|
||||
}
|
||||
}
|
||||
|
||||
// Parses and validates the equation. Single character labels are integerized,
// and the input and output label subscripts and their corresponding counts are
// populated. Also creates the mapping from (named) labels to their
// DimensionType.
|
||||
Status ParseEquation(const string& equation, OperandLabels* input_labels,
|
||||
Labels* output_labels,
|
||||
std::vector<DimensionType>* label_types,
|
||||
OperandLabelCounts* input_label_counts,
|
||||
LabelCounts* output_label_counts,
|
||||
gtl::InlinedVector<bool, 2>* input_has_ellipsis,
|
||||
bool* output_has_ellipsis) {
|
||||
gtl::InlinedVector<string, 2> input_str;
|
||||
string output_str;
|
||||
TF_RETURN_IF_ERROR(ParseEinsumEquation(equation, &input_str, &output_str));
|
||||
|
||||
// Temporary map from single character labels to (consecutive) integer labels.
|
||||
absl::flat_hash_map<char, int> label_mapping;
|
||||
int num_inputs = input_str.size();
|
||||
input_labels->resize(num_inputs);
|
||||
|
||||
// Map from single characters to integer labels.
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
MapToLabels(input_str[i], &input_labels->at(i), &label_mapping);
|
||||
}
|
||||
MapToLabels(output_str, output_labels, &label_mapping);
|
||||
|
||||
// Compute counts for input and output labels.
|
||||
int num_labels = label_mapping.size();
|
||||
input_label_counts->resize(num_inputs);
|
||||
input_has_ellipsis->resize(num_inputs);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
input_label_counts->at(i).resize(num_labels);
|
||||
for (const int label : input_labels->at(i)) {
|
||||
if (label != kEllipsisLabel)
|
||||
input_label_counts->at(i)[label] += 1;
|
||||
else
|
||||
input_has_ellipsis->at(i) = true;
|
||||
}
|
||||
}
|
||||
output_label_counts->resize(num_labels);
|
||||
for (const int label : *output_labels) {
|
||||
if (label != kEllipsisLabel)
|
||||
output_label_counts->at(label) += 1;
|
||||
else
|
||||
*output_has_ellipsis = true;
|
||||
}
|
||||
|
||||
// Map each label to a unique DimensionType.
|
||||
label_types->resize(num_labels);
|
||||
for (int label = 0; label < num_labels; ++label) {
|
||||
if (label == kEllipsisLabel) continue;
|
||||
bool removed = (*output_label_counts)[label] == 0;
|
||||
bool unique = num_inputs == 1 || (*input_label_counts)[0][label] == 0 ||
|
||||
(*input_label_counts)[1][label] == 0;
|
||||
(*label_types)[label] = GetDimensionType(removed, unique);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Insert new (unnamed) broadcasting labels at the location of ellipsis.
|
||||
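// E.g. with labels {0, kEllipsisLabel, 1}, num_named_labels = 5 and
// num_bcast_dims = 3, the labels become {0, 5, 6, 7, 1} and the three new
// labels each get a count of 1.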
void InsertBroadcastLabels(int num_bcast_dims, int num_named_labels,
|
||||
int ellipsis_axis, Labels* labels,
|
||||
LabelCounts* label_counts) {
|
||||
labels->erase(labels->begin() + ellipsis_axis);
|
||||
labels->insert(labels->begin() + ellipsis_axis, num_bcast_dims, 0);
|
||||
std::iota(labels->begin() + ellipsis_axis,
|
||||
labels->begin() + ellipsis_axis + num_bcast_dims, num_named_labels);
|
||||
// Increment label counts. Since these are new labels, the count is set to 1.
|
||||
label_counts->resize(num_named_labels + num_bcast_dims, 1);
|
||||
}
|
||||
|
||||
// Record and validate the label to dimension mapping. Must be a named
|
||||
// (non-broadcasting) label as broadcasting labels don't have a fixed dimension.
|
||||
Status RecordLabelToDimension(const int label, const int axis,
|
||||
const Tensor& input,
|
||||
LabelToDimSizes* label_to_dim_sizes) {
|
||||
const int64 input_dim = input.dim_size(axis);
|
||||
// We know that label_to_dim_sizes has the size to accommodate named labels.
|
||||
if (label_to_dim_sizes->at(label) != 0 &&
|
||||
label_to_dim_sizes->at(label) != input_dim) {
|
||||
return errors::InvalidArgument(
|
||||
"Expected dimension ", label_to_dim_sizes->at(label), " at axis ", axis,
|
||||
" of the input shaped ", input.shape().DebugString(),
|
||||
" but got dimension ", input_dim);
|
||||
}
|
||||
(*label_to_dim_sizes)[label] = input_dim;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Validate input dimensions and populate unnamed labels and their label counts.
|
||||
Status ProcessDimensions(const OpInputList& inputs,
|
||||
const gtl::InlinedVector<bool, 2>& input_has_ellipsis,
|
||||
const bool output_has_ellipsis,
|
||||
OperandLabels* input_labels, Labels* output_labels,
|
||||
std::vector<DimensionType>* label_types,
|
||||
OperandLabelCounts* input_label_counts,
|
||||
LabelCounts* output_label_counts,
|
||||
LabelToDimSizes* label_to_dim_sizes) {
|
||||
if (inputs.size() != input_labels->size()) {
|
||||
return errors::InvalidArgument("Expected ", input_labels->size(),
|
||||
" inputs but got: ", inputs.size());
|
||||
}
|
||||
const int num_inputs = inputs.size();
|
||||
|
||||
// We infer the number of broadcasting dimensions by taking the maximum rank
|
||||
// among the broadcasting subshapes of the input.
|
||||
int max_bcast_dims = 0;
|
||||
const int num_named_labels = label_types->size();
|
||||
label_to_dim_sizes->resize(num_named_labels);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
Labels* labels = &(*input_labels)[i];
|
||||
|
||||
if (!input_has_ellipsis[i]) {
|
||||
if (inputs[i].dims() != labels->size()) {
|
||||
return errors::InvalidArgument("Expected input ", i, " to have rank ",
|
||||
labels->size(),
|
||||
" but got: ", inputs[i].dims());
|
||||
}
|
||||
for (int label_idx = 0; label_idx < labels->size(); ++label_idx) {
|
||||
const int label = (*labels)[label_idx];
|
||||
TF_RETURN_IF_ERROR(RecordLabelToDimension(label, label_idx, inputs[i],
|
||||
label_to_dim_sizes));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Input has an ellipsis.
|
||||
if (inputs[i].dims() + 1 < labels->size()) {
|
||||
return errors::InvalidArgument(
|
||||
"Expected input ", i, " to have rank at least ", labels->size() - 1,
|
||||
" but got: ", inputs[i].dims());
|
||||
}
|
||||
int ellipsis_axis = -1;
|
||||
const int num_bcast_dims = inputs[i].dims() - labels->size() + 1;
|
||||
for (int label_idx = 0; label_idx < labels->size(); ++label_idx) {
|
||||
const int label = (*labels)[label_idx];
|
||||
if (label == kEllipsisLabel) {
|
||||
ellipsis_axis = label_idx;
|
||||
continue;
|
||||
}
|
||||
// Current label is not an ellipsis.
|
||||
const int axis =
|
||||
label_idx + (ellipsis_axis == -1 ? 0 : num_bcast_dims - 1);
|
||||
TF_RETURN_IF_ERROR(
|
||||
RecordLabelToDimension(label, axis, inputs[i], label_to_dim_sizes));
|
||||
}
|
||||
// Found an ellipsis. Replace 'kEllipsisLabel' with broadcasting dimensions.
|
||||
if (ellipsis_axis != -1) {
|
||||
InsertBroadcastLabels(num_bcast_dims, num_named_labels, ellipsis_axis,
|
||||
labels, &input_label_counts->at(i));
|
||||
max_bcast_dims = std::max(max_bcast_dims, num_bcast_dims);
|
||||
}
|
||||
}
|
||||
if (!absl::c_linear_search(input_has_ellipsis, true) &&
|
||||
!output_has_ellipsis) {
|
||||
return Status::OK();
|
||||
}
|
||||
// Insert broadcasting dimensions in the output labels.
|
||||
auto it =
|
||||
std::find(output_labels->begin(), output_labels->end(), kEllipsisLabel);
|
||||
if (it != output_labels->end()) {
|
||||
const int ellipsis_axis = it - output_labels->begin();
|
||||
InsertBroadcastLabels(max_bcast_dims, num_named_labels, ellipsis_axis,
|
||||
output_labels, output_label_counts);
|
||||
} else if (max_bcast_dims > 0) {
|
||||
return errors::InvalidArgument("Output contains ", max_bcast_dims,
|
||||
" broadcasting dimension(s) but no ellipsis "
|
||||
"(...) was found in the output subscripts.");
|
||||
}
|
||||
// Populate DimensionType for the new broadcasting labels.
|
||||
label_types->resize(num_named_labels + max_bcast_dims, kBroadcasting);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Permutes the labels according to the given permutation.
|
||||
void PermuteLabels(const std::vector<int>& permutation, Labels* labels) {
|
||||
Labels permuted_labels(labels->size());
|
||||
for (int i = 0; i < labels->size(); ++i) {
|
||||
permuted_labels[i] = (*labels)[permutation[i]];
|
||||
}
|
||||
labels->swap(permuted_labels);
|
||||
}
|
||||
|
||||
// Returns a reshaped input Tensor. The underlying buffer is not copied.
|
||||
Status CopyFrom(const Tensor& input, const TensorShape& shape, Tensor* output) {
|
||||
if (output->CopyFrom(input, shape)) return Status::OK();
|
||||
return errors::Internal(
|
||||
"Encountered error while reshaping a Tensor of shape ",
|
||||
input.shape().DebugString(), " to shape ", shape.DebugString());
|
||||
}
|
||||
|
||||
// Returns whether transposing is necessary; i.e. false if the input has
// rank < 2 or the permutation is the identity permutation.
|
||||
bool ShouldTranspose(const TensorShape& input_shape,
|
||||
const std::vector<int>& permutation) {
|
||||
if (input_shape.dims() < 2) return false;
|
||||
for (int i = 0; i < permutation.size(); ++i) {
|
||||
if (permutation[i] != i) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Transpose the input given a permutation. Returns a reference to the input if
|
||||
// transposing is not necessary.
|
||||
template <typename Device, typename T>
|
||||
Status TransposeOperand(OpKernelContext* ctx, const Tensor& input,
|
||||
const std::vector<int>& permutation, Tensor* output) {
|
||||
if (!ShouldTranspose(input.shape(), permutation)) {
|
||||
return CopyFrom(input, input.shape(), output);
|
||||
}
|
||||
TensorShape transposed_shape;
|
||||
for (int i = 0; i < input.dims(); ++i) {
|
||||
transposed_shape.AddDim(input.dim_size(permutation[i]));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->allocate_temp(DataTypeToEnum<T>::value, transposed_shape, output));
|
||||
const Device& device = ctx->eigen_device<Device>();
|
||||
TF_RETURN_IF_ERROR(DoTranspose(device, input, permutation, output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// If there are repeated labels in either the input or output, then this strides
|
||||
// the input (e.g. iii->i) or inflates it (e.g. i->iii), respectively.
|
||||
template <typename Device, typename T>
|
||||
Status StrideOrInflate(OpKernelContext* ctx, const Tensor& input,
|
||||
const Labels& labels, const LabelCounts& label_counts,
|
||||
const bool should_inflate, Tensor* output) {
|
||||
// Return early if there are no repeated indices.
|
||||
if (absl::c_all_of(label_counts, [](int c) { return c <= 1; })) {
|
||||
return CopyFrom(input, input.shape(), output);
|
||||
}
|
||||
// We reshape so that each repeated label is compressed to one dimension.
|
||||
// E.g. for iiij -> ij, the shape [3, 3, 3, 5] would be compressed to [27, 5].
// Striding appropriately (in this case with strides 13 (=1+3+9) and 1)
// recovers the generalized diagonal of shape [3, 5].
|
||||
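// In the inflate direction (e.g. i->iii, used by the gradient), the same
// strides scatter an input of shape [3] into a zero-filled buffer of 27
// elements at offsets 0, 13 and 26, which reshapes to the [3, 3, 3] tensor
// whose diagonal holds the input values.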
ShapeVec reshape;
|
||||
ShapeVec strides;
|
||||
// Strided and inflated shapes correspond to the input and output shapes,
// respectively, if should_inflate is true (and vice-versa if should_inflate
// is false). E.g. they are [3, 5] and [3, 3, 3, 5] in the above example.
|
||||
ShapeVec strided_shape;
|
||||
ShapeVec inflated_shape;
|
||||
for (int label : labels) {
|
||||
const int count = label_counts[label];
|
||||
const int current_axis =
|
||||
should_inflate ? strided_shape.size() : inflated_shape.size();
|
||||
const int64 dim = input.dim_size(current_axis);
|
||||
strided_shape.push_back(dim);
|
||||
inflated_shape.insert(inflated_shape.end(), count, dim);
|
||||
const int64 reshape_dim = MathUtil::IPow(dim, count);
|
||||
reshape.push_back(reshape_dim);
|
||||
// While taking the d-diagonal in a rank k Tensor, we take d equally-spaced
// elements including the first and last element. Then,
// (d - 1) * stride = d^k - 1, or, stride = (d^k - 1)/(d - 1).
|
||||
const int64 stride =
|
||||
(dim > 1 && count > 1) ? (reshape_dim - 1) / (dim - 1) : 1;
|
||||
strides.push_back(stride);
|
||||
}
|
||||
|
||||
TensorShape output_shape =
|
||||
TensorShape(should_inflate ? inflated_shape : strided_shape);
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
|
||||
const Device& device = ctx->eigen_device<Device>();
|
||||
switch (reshape.size()) {
|
||||
#define NDIMS_CASE(N) \
|
||||
case N: { \
|
||||
if (should_inflate) { \
|
||||
auto output_map = output->shaped<T, N>(reshape); \
|
||||
auto input_map = input.shaped<T, N>(strided_shape); \
|
||||
output_map.device(device) = input_map.inflate(strides); \
|
||||
} else { \
|
||||
auto input_map = input.shaped<T, N>(reshape); \
|
||||
auto output_map = output->shaped<T, N>(strided_shape); \
|
||||
output_map.device(device) = input_map.stride(strides); \
|
||||
} \
|
||||
} break;
|
||||
NDIMS_CASE(1);
|
||||
NDIMS_CASE(2);
|
||||
NDIMS_CASE(3);
|
||||
NDIMS_CASE(4);
|
||||
NDIMS_CASE(5);
|
||||
NDIMS_CASE(6);
|
||||
default:
|
||||
return errors::Unimplemented(
|
||||
"Unsupported rank: ", reshape.size(),
|
||||
" while handling repeated indices. Up to rank 6 is supported.");
|
||||
#undef NDIMS_CASE
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns true if the input dimensions are already sorted in the order
// [batch, contract, free, reduce]. Used to implement an optimization which
// avoids an extra transpose by using the adj_x/adj_y flags of BatchMatMul.
|
||||
bool ShouldSwapFreeAndContract(const Labels& labels,
|
||||
const std::vector<DimensionType>& label_types) {
|
||||
// Check that ordering is according to dimension type, with the role of
|
||||
// free and contract dimensions swapped.
|
||||
gtl::InlinedVector<int, 5> remap = {0, 1, 3, 2, 4};
|
||||
for (int i = 0; i + 1 < labels.size(); ++i) {
|
||||
const int dimtype_a = remap[label_types[labels[i]]];
|
||||
const int dimtype_b = remap[label_types[labels[i + 1]]];
|
||||
if (dimtype_a > dimtype_b ||
|
||||
(dimtype_a == dimtype_b && labels[i] > labels[i + 1])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Device, typename T>
|
||||
Status ReduceOperand(OpKernelContext* ctx, const Tensor& input,
|
||||
const std::vector<DimensionType>& label_types,
|
||||
const LabelCounts& label_counts, Labels* labels,
|
||||
Labels* free_labels, bool* swap_free_and_contract,
|
||||
Tensor* output) {
|
||||
// Find the permutation to transpose the input dimensions in the order of
|
||||
// DimensionType; i.e. batch, free, contract and reduce dimensions. This
|
||||
// makes it more convenient to invoke Reduce/Contract operations.
|
||||
std::vector<int> permutation(input.dims());
|
||||
absl::c_iota(permutation, 0);
|
||||
Tensor input_transposed;
|
||||
// Check if we can avoid the transpose. We need to flip the adj_x (or adj_y)
|
||||
// flag during BatchMatMul. This is an extra optimization not necessary for
|
||||
// correctness.
|
||||
if (ShouldSwapFreeAndContract(*labels, label_types)) {
|
||||
*swap_free_and_contract = true;
|
||||
} else {
|
||||
absl::c_sort(permutation, [&](int i, int j) {
|
||||
int label_i = (*labels)[i];
|
||||
int label_j = (*labels)[j];
|
||||
return std::tie(label_types[label_i], label_i) <
|
||||
std::tie(label_types[label_j], label_j);
|
||||
});
|
||||
}
|
||||
// Transpose the input so that DimensionTypes are in order.
|
||||
TF_RETURN_IF_ERROR(
|
||||
TransposeOperand<Device, T>(ctx, input, permutation, &input_transposed));
|
||||
PermuteLabels(permutation, labels);
|
||||
|
||||
// Take the generalized diagonal for dimensions with repeated axis labels.
|
||||
Tensor input_deduped;
|
||||
labels->erase(std::unique(labels->begin(), labels->end()), labels->end());
|
||||
TF_RETURN_IF_ERROR(
|
||||
StrideOrInflate<Device, T>(ctx, input_transposed, *labels, label_counts,
|
||||
false /* should_inflate */, &input_deduped));
|
||||
|
||||
// Reshape denotes the rank-5 shape [broadcast, batch, free, contract, reduce]
|
||||
// where we've compacted the dimensions of each DimensionType.
|
||||
gtl::InlinedVector<int64, 5> reshape(5, 1);
|
||||
// The output shape is [batch shape] + [free size, contract size]
|
||||
// That is, the batch shape is preserved (for broadcasting while contracting)
|
||||
// while the free dims and contract dims are compressed to one dimension each.
|
||||
TensorShape output_shape;
|
||||
for (int label_idx = 0; label_idx < labels->size(); ++label_idx) {
|
||||
const int label = labels->at(label_idx);
|
||||
int64 dim = input_deduped.dim_size(label_idx);
|
||||
if (label_types[label] == kBroadcasting || label_types[label] == kBatch) {
|
||||
output_shape.AddDim(dim);
|
||||
} else if (label_types[label] == kFree) {
|
||||
free_labels->push_back(label);
|
||||
}
|
||||
reshape[label_types[label]] *= dim;
|
||||
}
|
||||
if (*swap_free_and_contract) std::swap(reshape[kFree], reshape[kContract]);
|
||||
output_shape.AddDim(reshape[kFree]);
|
||||
output_shape.AddDim(reshape[kContract]);
|
||||
|
||||
if (reshape[kReduce] == 1) { // No need to actually reduce.
|
||||
return CopyFrom(input_deduped, output_shape, output);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
|
||||
using Reducer = Eigen::internal::SumReducer<T>;
|
||||
using Index = typename TTypes<T>::Tensor::Index;
|
||||
// Reduce along the last axis (i.e axis 1) of the rank-2 Tensor.
|
||||
const int64 output_size = reshape[kBroadcasting] * reshape[kBatch] *
|
||||
reshape[kFree] * reshape[kContract];
|
||||
functor::ReduceFunctor<Device, Reducer>::Reduce(
|
||||
ctx, output->shaped<T, 1>({output_size}),
|
||||
const_cast<const Tensor&>(input_deduped)
|
||||
.shaped<T, 2>({output_size, reshape[kReduce]}),
|
||||
Eigen::array<Index, 1>({1}), Reducer());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Reshapes a Tensor of shape [b0,b1...bk,N,M] to [prod(b0,b1...bk),N,M].
|
||||
Status ReshapeToRank3(const Tensor& input, int batch_size, Tensor* output) {
|
||||
const int rank = input.dims();
|
||||
TensorShape output_shape = {batch_size, input.dim_size(rank - 2),
|
||||
input.dim_size(rank - 1)};
|
||||
return CopyFrom(input, output_shape, output);
|
||||
}
|
||||
|
||||
// Conjugates the input.
|
||||
template <typename Device, typename T>
|
||||
Status Conjugate(OpKernelContext* ctx, Tensor* input) {
|
||||
std::vector<int> permutation(input->dims());
|
||||
std::iota(permutation.begin(), permutation.end(), 0);
|
||||
Tensor output;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->allocate_temp(DataTypeToEnum<T>::value, input->shape(), &output));
|
||||
const Device& d = ctx->eigen_device<Device>();
|
||||
TF_RETURN_IF_ERROR(DoConjugateTranspose(d, *input, permutation, &output));
|
||||
std::swap(*input, output);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Contracts the inputs along the last axis (or the second-to-last if the
// corresponding value of swap_free_and_contract is true). The batch dimensions
// are broadcast to the output shape.
|
||||
// TODO(anudhyan): Factor this function into a BatchMatMul functor and support
|
||||
// transpose_x and transpose_y attributes (in addition to adj_x and adj_y).
|
||||
// Also, the BatchMatMul might devolve into a component-wise multiplication when
// the matrix shape is [1,1]; in this case the BatchMatMul functor would be very
// inefficient. The functor should detect this case and perform a
// component-wise multiplication instead.
|
||||
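// Note on the adjoint flags below: a non-swapped operand arrives with layout
// [batch..., free, contract], while BatchMatMul expects the right operand as
// [contract, free], so adj_y is set exactly when operand 1 was *not* swapped;
// adj_x is set when operand 0 *was* swapped (layout [batch..., contract, free]).
// For complex types the operands are conjugated first so that the adjoint
// amounts to a plain transpose.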
template <typename Device, typename T>
|
||||
Status ContractOperands(OpKernelContext* ctx, absl::Span<const Tensor> inputs,
|
||||
absl::Span<const bool> swap_free_and_contract,
|
||||
Tensor* output) {
|
||||
if (inputs.size() == 1) return CopyFrom(inputs[0], inputs[0].shape(), output);
|
||||
MatMulBCast bcast(inputs[0].shape().dim_sizes(),
|
||||
inputs[1].shape().dim_sizes());
|
||||
if (!bcast.IsValid()) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid broadcasting dimensions: ", inputs[0].shape().DebugString(),
|
||||
" vs. ", inputs[1].shape().DebugString());
|
||||
}
|
||||
Tensor lhs;
|
||||
TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[0], bcast.x_batch_size(), &lhs));
|
||||
Tensor rhs;
|
||||
TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[1], bcast.y_batch_size(), &rhs));
|
||||
TensorShape output_shape = bcast.output_batch_shape();
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
const int64 free_axis =
|
||||
inputs[i].dims() - (swap_free_and_contract[i] ? 1 : 2);
|
||||
output_shape.AddDim(inputs[i].dim_size(free_axis));
|
||||
}
|
||||
bool adj_x = swap_free_and_contract[0];
|
||||
bool adj_y = !swap_free_and_contract[1];
|
||||
if (is_complex<T>::value) {
|
||||
if (adj_x) TF_RETURN_IF_ERROR(Conjugate<Device, T>(ctx, &lhs));
|
||||
if (adj_y) TF_RETURN_IF_ERROR(Conjugate<Device, T>(ctx, &rhs));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
|
||||
Tensor output_reshaped;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped));
|
||||
LaunchBatchMatMul<Device, T>::Launch(ctx, lhs, rhs, adj_x, adj_y, bcast,
|
||||
&output_reshaped);
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename Device, typename T>
|
||||
class EinsumOp : public OpKernel {
|
||||
public:
|
||||
explicit EinsumOp(OpKernelConstruction* c) : OpKernel(c) {
|
||||
string equation;
|
||||
OP_REQUIRES_OK(c, c->GetAttr("equation", &equation));
|
||||
OP_REQUIRES_OK(c, ParseEquation(equation, &input_labels_, &output_labels_,
|
||||
&label_types_, &input_label_counts_,
|
||||
&output_label_counts_, &input_has_ellipsis_,
|
||||
&output_has_ellipsis_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
OpInputList inputs;
|
||||
OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &inputs));
|
||||
|
||||
OperandLabels input_labels(input_labels_);
|
||||
Labels output_labels(output_labels_);
|
||||
std::vector<DimensionType> label_types(label_types_);
|
||||
OperandLabelCounts input_label_counts(input_label_counts_);
|
||||
LabelCounts output_label_counts(output_label_counts_);
|
||||
LabelToDimSizes label_to_dim_sizes;
|
||||
|
||||
OP_REQUIRES_OK(ctx, ProcessDimensions(
|
||||
inputs, input_has_ellipsis_, output_has_ellipsis_,
|
||||
&input_labels, &output_labels, &label_types,
|
||||
&input_label_counts, &output_label_counts,
|
||||
&label_to_dim_sizes));
|
||||
|
||||
// The reduction phase (a) sums across reduction dimensions, (b) takes
|
||||
// generalized diagonals, and (c) reshapes it into shape
|
||||
// [(broadcasting) batch shape] + [F,C]
|
||||
// where F and C denote the total (compacted) size of free and contract
|
||||
// dimensions, respectively.
|
||||
const int num_inputs = inputs.size();
|
||||
OperandLabels free_labels(num_inputs);
|
||||
gtl::InlinedVector<Tensor, 2> inputs_reduced(num_inputs);
|
||||
gtl::InlinedVector<bool, 2> swap_free_and_contract(num_inputs);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ReduceOperand<Device, T>(
|
||||
ctx, inputs[i], label_types, input_label_counts[i],
|
||||
&input_labels[i], &free_labels[i],
|
||||
&swap_free_and_contract[i], &inputs_reduced[i]));
|
||||
}
|
||||
|
||||
// After reduction, the inputs should be reshaped to Tensors suitable for
|
||||
// contraction. If num_inputs is 1, the reduced input is simply forwarded to
|
||||
// the output.
|
||||
Tensor contraction_output_reshaped;
|
||||
OP_REQUIRES_OK(ctx, ContractOperands<Device, T>(
|
||||
ctx, inputs_reduced, swap_free_and_contract,
|
||||
&contraction_output_reshaped));
|
||||
|
||||
// Copy the batch labels from the contraction output. Recover the batch
|
||||
// shape, which may have been broadcasted.
|
||||
TensorShape result_shape = contraction_output_reshaped.shape();
|
||||
result_shape.RemoveLastDims(2);
|
||||
|
||||
int num_labels = label_types.size();
|
||||
Labels result_labels;
|
||||
// All batch dimensions should be present in the contracted result. First
|
||||
// the broadcasting dimensions, then the named batch dimensions.
|
||||
for (int label = 0; label < num_labels; ++label) {
|
||||
if (label_types[label] == kBroadcasting) result_labels.push_back(label);
|
||||
}
|
||||
for (int label = 0; label < num_labels; ++label) {
|
||||
if (label_types[label] == kBatch) result_labels.push_back(label);
|
||||
}
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
for (int label : free_labels[i]) {
|
||||
result_labels.push_back(label);
|
||||
result_shape.AddDim(label_to_dim_sizes[label]);
|
||||
}
|
||||
}
|
||||
|
||||
// Reshape the contraction (or reduction) result to its expanded shape:
|
||||
// [(broadcasted) batch shape] + [free shape 0] + [free shape 1].
|
||||
Tensor contraction_output;
|
||||
OP_REQUIRES_OK(ctx, CopyFrom(contraction_output_reshaped, result_shape,
|
||||
&contraction_output));
|
||||
|
||||
// Inflate the output if necessary. (E.g. for the equation 'i->iii' which
|
||||
// may arise while computing gradient of a regular Einsum).
|
||||
// TODO(anudhyan): It's possible that Eigen's contract and inflate can be
|
||||
// chained here to avoid materializing an intermediate.
|
||||
Tensor output_inflated;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, StrideOrInflate<Device, T>(
|
||||
ctx, contraction_output, result_labels, output_label_counts,
|
||||
true /* should_inflate */, &output_inflated));
|
||||
if (output_inflated.dims() > contraction_output.dims()) {
|
||||
// We inflated the output. Modify result labels accordingly.
|
||||
Labels inflated_labels;
|
||||
for (int label : result_labels) {
|
||||
inflated_labels.insert(inflated_labels.end(),
|
||||
output_label_counts[label], label);
|
||||
}
|
||||
result_labels.swap(inflated_labels);
|
||||
}
|
||||
// Find the permutation to map the result labels to the output labels. Note
|
||||
// that both the result and the final output may have the repeated labels,
|
||||
// in which case the permutation preserves the left-to-right ordering.
|
||||
// E.g. if result labels are [0, 0, 1] and output is [0, 1, 0] then the
|
||||
// permutation should be [0, 2, 1]. We also use the fact that repeated
|
||||
// labels in the result are adjacent to each other.
|
||||
std::vector<int> output_permutation(output_labels.size());
|
||||
std::vector<int> label_to_position(num_labels, -1);
|
||||
for (int i = 0; i < result_labels.size(); ++i) {
|
||||
// Remember the position of only the leftmost result label.
|
||||
if (label_to_position[result_labels[i]] == -1) {
|
||||
label_to_position[result_labels[i]] = i;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < output_labels.size(); ++i) {
|
||||
output_permutation[i] = label_to_position[output_labels[i]];
|
||||
// We have found the leftmost occurrence. The next one would be adjacent.
|
||||
label_to_position[output_labels[i]] += 1;
|
||||
}
|
||||
Tensor output;
|
||||
OP_REQUIRES_OK(ctx, TransposeOperand<Device, T>(
|
||||
ctx, output_inflated, output_permutation, &output));
|
||||
ctx->set_output(0, output);
|
||||
}
|
||||
|
||||
private:
|
||||
OperandLabels input_labels_;
|
||||
Labels output_labels_;
|
||||
std::vector<DimensionType> label_types_;
|
||||
OperandLabelCounts input_label_counts_;
|
||||
LabelCounts output_label_counts_;
|
||||
gtl::InlinedVector<bool, 2> input_has_ellipsis_;
|
||||
bool output_has_ellipsis_ = false;
|
||||
};
|
||||
|
||||
#define REGISTER_EINSUM(D, TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Einsum").Device(DEVICE_##D).TypeConstraint<TYPE>("T"), \
|
||||
EinsumOp<D##Device, TYPE>);
|
||||
|
||||
// TODO(anudhyan): Also register GPU kernels for Einsum.
|
||||
#define REGISTER_CPU(TYPE) REGISTER_EINSUM(CPU, TYPE)
|
||||
TF_CALL_float(REGISTER_CPU);
|
||||
TF_CALL_double(REGISTER_CPU);
|
||||
TF_CALL_complex64(REGISTER_CPU);
|
||||
TF_CALL_complex128(REGISTER_CPU);
|
||||
#undef REGISTER_CPU
|
||||
#undef REGISTER_EINSUM
|
||||
|
||||
} // namespace tensorflow
|
@@ -474,6 +474,14 @@ REGISTER_OP("TridiagonalSolve")
.Attr("T: {double, float, complex64, complex128}")
.SetShapeFn(TridiagonalSolveShapeFn);

REGISTER_OP("Einsum")
.Input("inputs: N * T")
.Output("output: T")
.Attr("equation: string")
.Attr("N: int >= 1")
.Attr("T: type")
.SetShapeFn(shape_inference::EinsumShape);

// Deprecated op registrations:

// Can be deleted after 3feb2017.

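
Since the Op is registered with visibility HIDDEN, end users keep calling tf.einsum; part 2 of this change rewires that wrapper. For experimentation, the generated raw wrapper can be called directly. A hedged sketch (the wrapper name gen_linalg_ops.einsum is assumed from the registration in linalg_ops.cc; module layout may differ across versions):

import tensorflow as tf
from tensorflow.python.ops import gen_linalg_ops

x = tf.random.uniform([2, 3, 4])
y = tf.random.uniform([4, 5])
# "inputs" is the list attr (N * T); "equation" is the string attr defined above.
out = gen_linalg_ops.einsum([x, y], equation="bij,jk->bik")
print(out.shape)  # (2, 3, 5)
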
tensorflow/core/util/einsum_op_util.cc (new file, 47 lines)
@@ -0,0 +1,47 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/util/einsum_op_util.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status ParseEinsumEquation(const string& equation,
|
||||
gtl::InlinedVector<string, 2>* input_subscripts,
|
||||
string* output_subscript) {
|
||||
gtl::InlinedVector<string, 2> inputs_and_output_subscripts =
|
||||
absl::StrSplit(equation, "->");
|
||||
if (inputs_and_output_subscripts.size() != 2) {
|
||||
return errors::InvalidArgument(
|
||||
"Expecting exactly one '->' in einsum equation: ", equation);
|
||||
}
|
||||
*output_subscript = std::move(inputs_and_output_subscripts[1]);
|
||||
*input_subscripts =
|
||||
absl::StrSplit(std::move(inputs_and_output_subscripts[0]), ',');
|
||||
if (input_subscripts->size() != 1 && input_subscripts->size() != 2) {
|
||||
return errors::InvalidArgument(
|
||||
"Expecting 1 or 2 input subscripts in equation '", equation,
|
||||
"' but got: ", input_subscripts->size());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
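
ParseEinsumEquation above, as a small Python sketch (behavioral mirror only; the real implementation returns Status and fills InlinedVectors):

def parse_einsum_equation(equation):
  parts = equation.split("->")
  if len(parts) != 2:
    raise ValueError("Expecting exactly one '->' in einsum equation: %s" % equation)
  input_subscripts, output_subscript = parts[0].split(","), parts[1]
  if len(input_subscripts) not in (1, 2):
    raise ValueError("Expecting 1 or 2 input subscripts in equation '%s' but got: %d"
                     % (equation, len(input_subscripts)))
  return input_subscripts, output_subscript

print(parse_einsum_equation("ab,bc->ac"))  # (['ab', 'bc'], 'ac')
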
tensorflow/core/util/einsum_op_util.h (new file, 29 lines)
@@ -0,0 +1,29 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_
#define TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

namespace tensorflow {

// Parses and validates an einsum equation in explicit form.
Status ParseEinsumEquation(const string& equation,
                           gtl::InlinedVector<string, 2>* input_subscripts,
                           string* output_subscript);

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_
@ -2127,6 +2127,23 @@ cuda_py_test(
    xla_enable_strict_auto_jit = True,
)

cuda_py_test(
    name = "einsum_op_test",
    size = "small",
    srcs = ["einsum_op_test.py"],
    additional_deps = [
        "//third_party/py/numpy",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:linalg_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python/ops/linalg",
    ],
    xla_enable_strict_auto_jit = True,
)

cuda_py_test(
    name = "manip_ops_test",
    size = "small",
tensorflow/python/kernel_tests/einsum_op_test.py (new file, 251 lines)
@ -0,0 +1,251 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.Einsum."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test

class EinsumOpTest(test.TestCase):

  def _check(self, s, *input_shapes, **kwargs):
    dtype = kwargs.pop('dtype', np.float32)
    r = np.random.RandomState(0)
    inputs = []
    for shape in input_shapes:
      arr = np.array(r.randn(*shape)).astype(dtype)
      if dtype == np.complex64 or dtype == np.complex128:
        arr += 1j * np.array(r.randn(*shape)).astype(dtype)
      inputs.append(arr)
    input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
    a = np.einsum(s, *inputs)
    b = self.evaluate(gen_linalg_ops.einsum(input_tensors, s))
    self.assertAllClose(a, b, atol=1e-4, rtol=1e-4)

  def testUnary(self):
    self._check('->', ())
    self._check('aa->', (3, 3))
    self._check('aa->a', (3, 3))
    self._check('aaa->', (3, 3, 3))
    self._check('aaa->a', (3, 3, 3))
    self._check('aab->a', (3, 3, 4))
    self._check('ab->', (3, 3))
    self._check('ab->ab', (3, 3))
    self._check('abc->b', (3, 4, 5))
    self._check('abc->ca', (3, 4, 5))
    self._check('abc->cab', (3, 4, 5))
    self._check('aabcc->a', (3, 3, 5, 4, 4))
    self._check('aabcc->ac', (3, 3, 5, 4, 4))
    self._check('aabcd->ad', (3, 3, 5, 4, 4))

  def testUnaryEllipsis(self):
    self._check('...->...', ())
    self._check('...->', ())
    self._check('->...', ())

    # Tests from dask.
    self._check('a...a->a...', (2, 2))
    self._check('a...a->', (2, 2))
    self._check('a...a->...', (2, 5, 1, 2))
    self._check('a...a->a...', (2, 1, 2))
    self._check('a...a->a...', (2, 3, 4, 5, 2))

    self._check('...ijk->...ki', (3, 4, 5))
    self._check('...ijk->...ki', (1, 3, 4, 5))
    self._check('...ijk->...ki', (2, 2, 3, 4, 5))

    # Repeated indices.
    self._check('i...ii->...i', (3, 2, 3, 3))

  def testBinary(self):
    self._check(',->', (), ())
    self._check('a,a->', (3,), (3,))
    self._check('a,a->a', (3,), (3,))
    self._check('ba,b->', (3, 2), (3,))
    self._check('ab,b->a', (3, 4), (4,))
    self._check('ab,ab->', (3, 4), (3, 4))
    self._check('nij,jk->nik', (5, 2, 3), (3, 4))
    self._check('abc,bad->abcd', (1, 2, 3), (2, 1, 4))
    # Repeated indices.
    self._check('ijj,k->ik', (2, 3, 3), (4,))
    self._check('aba,a->b', (3, 4, 3), (3,))
    # From https://github.com/dask/dask/pull/3412#discussion_r182413444
    self._check('aab,bc->ac', (2, 2, 3), (3, 4))
    self._check('aab,bcc->ac', (2, 2, 3), (3, 4, 4))
    # Based on https://github.com/google/jax/issues/37#issuecomment-448572187
    self._check('sa,shb->shab', (2, 1), (2, 3, 4))

  def testBroadcasting(self):
    # Batch matmul without broadcasting.
    self._check('...ij,...jk->...ik', (5, 1, 2, 3), (5, 1, 3, 4))
    # Batch matmul with broadcasting.
    self._check('...ij,...jk->...ik', (1, 2, 3), (3, 5))
    self._check('...ij,...jk->...ik', (2, 3), (1, 3, 5))
    self._check('...ij,...jk->...ik', (5, 2, 3), (3, 5))
    self._check('...ij,...jk->...ik', (2, 3), (5, 3, 5))
    self._check('...ij,...jk->...ik', (3, 1, 2, 3), (1, 1, 7, 3, 5))
    self._check('i...j,j...k->...ik', (2, 1, 3, 1, 3), (3, 1, 7, 5))
    # Broadcasting with repeated indices.
    self._check('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
    self._check('ij,jk...k->...i', (3, 2), (2, 4, 5, 4))
    self._check('ijj,jk...k->i...', (3, 2, 2), (2, 4, 1, 4))
    self._check('i...jj,jk...k->i...', (3, 3, 1, 2, 2), (2, 4, 1, 5, 4))
    # The following two cases are from https://stackoverflow.com/a/19203475/1611416
    self._check('...abc,...abcd->...d', (1, 1, 2, 3, 4), (5, 2, 3, 4, 6))
    self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))

  def testDtypes(self):
    for dtype in [np.float64, np.float32, np.complex64, np.complex128]:
      self._check('ij,jk->ik', (2, 2), (2, 2), dtype=dtype)
      self._check('ji,jk->ik', (2, 2), (2, 2), dtype=dtype)
      self._check('ji,kj->ik', (2, 2), (2, 2), dtype=dtype)
      self._check('ij,jk->ki', (2, 2), (2, 2), dtype=dtype)
      self._check('ji,kj->ki', (2, 2), (2, 2), dtype=dtype)

  @test_util.run_in_graph_and_eager_modes
  def testInvalid(self):
    r = np.random.RandomState(0)
    cases = [
        # incorrect rank.
        ('ij,jk->ik', r.randn(1, 2, 3), r.randn(3, 4)),
        ('...ij,jk->ik', r.randn(3), r.randn(3, 4)),
        # inconsistent dimensions.
        ('ij,jk->ik', r.randn(2, 3), r.randn(4, 4)),
        # broadcasting is invalid.
        ('...ij,...jk->...ik', r.randn(5, 2, 3), r.randn(7, 3, 4)),
        # output should have ellipsis when broadcasting shape is non-empty.
        ('...ij,...jk->ik', r.randn(2, 2, 3), r.randn(3, 4)),
    ]
    for args in cases:
      with self.assertRaises((ValueError, errors.InvalidArgumentError)):
        _ = self.evaluate(gen_linalg_ops.einsum(args[1:], args[0]))

      placeholders = [
          array_ops.placeholder_with_default(x, shape=None) for x in args[1:]
      ]
      with self.assertRaises((ValueError, errors.InvalidArgumentError)):
        _ = self.evaluate(gen_linalg_ops.einsum(placeholders, args[0]))

  @test_util.run_in_graph_and_eager_modes
  def testPlaceholder(self):

    def check(equation, *input_and_placeholder_shapes):
      r = np.random.RandomState(0)
      inputs = []
      input_placeholders = []
      for actual_shape, placeholder_shape in input_and_placeholder_shapes:
        input_np = np.array(r.randn(*actual_shape))
        inputs.append(input_np)
        input_placeholders.append(
            array_ops.placeholder_with_default(input_np, placeholder_shape))

      a = np.einsum(equation, *inputs)
      b = self.evaluate(gen_linalg_ops.einsum(input_placeholders, equation))
      self.assertAllClose(a, b, atol=1e-4, rtol=1e-4)

    check('bijl,bjkm->bik', ((9, 2, 3, 5), (None, None, None, 5)),
          ((9, 3, 4, 7), (None, None, 4, None)))
    check('bijl,bjkm->bik', ((9, 2, 3, 5), None), ((9, 3, 4, 7), None))
    check('...ij,...->...i', ((4, 3, 1, 2), (None, 3, None, 2)),
          ((4, 3), (None, 3)))
    check('...ij,...jk->...ik', ((3, 1, 2, 3), None), ((1, 7, 3, 4), None))

  def testOutputRepeatedLabels(self):
    # This is the reverse operation of repeated input labels, to be used for
    # computing symbolic gradients of einsum.
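    # For example, 'a->aa' maps a vector [x, y] to the diagonal matrix
    # [[x, 0], [0, y]], inverting the 'aa->a' diagonal extraction.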
    r = np.random.RandomState(0)
    a = r.randn(2, 2)
    s = 'a->aa'
    diag_a = np.diag(np.diag(a))
    b = self.evaluate(gen_linalg_ops.einsum([np.diag(a)], s))
    self.assertAllClose(diag_a, b, atol=1e-4, rtol=1e-4)

class EinsumBenchmark(test.Benchmark):
  cases = [
      # Unary cases.
      ['ijk->i', 100],
      ['ijk->kji', 100],
      # Regular matmul or batch matmul.
      ['ij,jk->ik', 1000],
      ['ji,kj->ik', 1000],
      ['ab,ab->', 100],
      ['ab,ba->', 100],
      ['abc,abc->', 100],
      ['abc,bac->', 100],
      ['abc,cba->', 100],
      ['bij,bjk->bik', 100],
      ['bji,bjk->bki', 100],
      ['ikl,kji->kl', 100],
      ['klj,lki->ij', 100],
      ['ijk,ilj->kli', 100],
      ['kij,mkb->ijmb', 100],
      ['abcd,ad->bc', 40],
      # Larger binary contractions.
      ['ijk,jklm->il', 40],
      ['efabc,eabcd->efd', 30],
      ['fabec,abcde->fde', 30],
      ['efabc,edabc->efd', 30],
      ['eadbf,dfebc->ecfad', 30],
      ['abcdef,bcdfg->abcdeg', 30],
  ]

  def benchmarkEinsum(self):
    for equation, dim in self.cases:
      with ops.Graph().as_default(), \
          session.Session(config=benchmark.benchmark_config()) as sess, \
          ops.device('/cpu:0'):
        r = np.random.RandomState(0)
        input_subscripts = equation.split('->')[0].split(',')
        input_vars = []
        for subscript in input_subscripts:
          input_shape = (dim,) * len(subscript)
          input_vars.append(
              variables.Variable(np.array(r.randn(*input_shape), np.float32)))
        variables.global_variables_initializer().run()

        # Call einsum_v1.
        self.run_op_benchmark(
            sess,
            special_math_ops.einsum(equation, *input_vars),
            min_iters=50,
            name='einsum_v1_cpu_({})_{}'.format(equation, dim))

        # Call gen_linalg_ops.einsum.
        self.run_op_benchmark(
            sess,
            gen_linalg_ops.einsum(input_vars, equation),
            min_iters=50,
            name='einsum_v2_cpu_({})_{}'.format(equation, dim))


if __name__ == '__main__':
  test.main()

@ -988,6 +988,10 @@ tf_module {
    name: "EditDistance"
    argspec: "args=[\'hypothesis_indices\', \'hypothesis_values\', \'hypothesis_shape\', \'truth_indices\', \'truth_values\', \'truth_shape\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
  }
  member_method {
    name: "Einsum"
    argspec: "args=[\'inputs\', \'equation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "Elu"
    argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -988,6 +988,10 @@ tf_module {
|
||||
name: "EditDistance"
|
||||
argspec: "args=[\'hypothesis_indices\', \'hypothesis_values\', \'hypothesis_shape\', \'truth_indices\', \'truth_values\', \'truth_shape\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Einsum"
|
||||
argspec: "args=[\'inputs\', \'equation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Elu"
|
||||
argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|