[XLA:CLIENT] Support single operand and implied strings in Einsum

PiperOrigin-RevId: 272074915
This commit is contained in:
Blake Hechtman 2019-09-30 15:07:14 -07:00 committed by TensorFlower Gardener
parent 52c04ee452
commit 5471ed4915
7 changed files with 157 additions and 17 deletions

View File

@ -38,8 +38,12 @@ class EinsumOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaOp lhs = ctx->Input(0);
xla::XlaOp rhs = ctx->Input(1);
ctx->SetOutput(0, xla::Einsum(lhs, rhs, equation_));
if (equation_.find(",") == equation_.npos) {
ctx->SetOutput(0, xla::Einsum(lhs, equation_));
} else {
xla::XlaOp rhs = ctx->Input(1);
ctx->SetOutput(0, xla::Einsum(lhs, rhs, equation_));
}
}
private:

View File

@ -168,6 +168,53 @@ XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); }
XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); }
namespace {
std::vector<int64> EinsumDiagonalLabels(absl::Span<const int64> config) {
std::vector<int64> unique_labels;
for (auto label = config.begin(); label != config.end(); ++label) {
auto first_label = absl::c_find(config, *label);
if (first_label == label) {
unique_labels.push_back(*label);
}
}
if (unique_labels.size() == config.size()) {
unique_labels.clear();
}
return unique_labels;
}
} // namespace
xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64> config) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (EinsumDiagonalLabels(config).empty()) {
return x;
}
TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
Shape iota_shape = x_shape;
iota_shape.set_element_type(S32);
XlaOp mask = ConstantR0(builder, true);
absl::InlinedVector<int64, 8> reduce_dims;
for (auto label = config.begin(); label != config.end(); ++label) {
const int64 dim = label - config.begin();
auto first_label = absl::c_find(config, *label);
if (first_label == label) {
continue;
}
reduce_dims.push_back(dim);
const int64 first_dim = first_label - config.begin();
mask = And(mask, Eq(Iota(builder, iota_shape, first_dim),
Iota(builder, iota_shape, dim)));
}
auto zero = ScalarLike(x, 0);
return Reduce(Select(mask, x, zero), zero,
CreateScalarIdentityWithZeroComputation(
x_shape.element_type(), builder),
reduce_dims);
});
}
Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
absl::Span<const int64> y_config,
absl::Span<const int64> output_config) {
@ -200,6 +247,7 @@ Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
}
return Status::OK();
}
namespace {
// Helper method to remove dimensions from a shape and dot dimension numbers
// used to implment implicit broadcasting.
@ -232,6 +280,20 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
xla::PrecisionConfig::Precision precision) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
auto x_diagonal_labels = EinsumDiagonalLabels(x_config);
auto y_diagonal_labels = EinsumDiagonalLabels(y_config);
if (!x_diagonal_labels.empty() && !y_diagonal_labels.empty()) {
return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels,
EinsumDiagonal(y, y_config), y_diagonal_labels,
output_config, precision);
} else if (!x_diagonal_labels.empty()) {
return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels, y, y_config,
output_config, precision);
} else if (!y_diagonal_labels.empty()) {
return Einsum(x, x_config, EinsumDiagonal(y, y_config), y_diagonal_labels,
output_config, precision);
}
TF_RETURN_IF_ERROR(
ValidateEinsumNumericDimensions(x_config, y_config, output_config));
TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
@ -484,10 +546,38 @@ StatusOr<std::array<std::vector<int64>, 3>> ParseEinsumString(
return einsum_config_numeric;
}
std::string NormalizeEinsumString(absl::string_view einsum_config) {
if (einsum_config.find("->") != einsum_config.npos) {
return "";
}
bool has_ellipsis = einsum_config.find("...") != einsum_config.npos;
std::map<char, int64> chars;
for (char c : einsum_config) {
if (absl::ascii_isalpha(c)) {
++chars[c];
}
}
std::string new_config(einsum_config.begin(), einsum_config.end());
new_config.append("->");
if (has_ellipsis) {
new_config.append("...");
}
for (auto p : chars) {
if (p.second == 1) {
new_config.push_back(p.first);
}
}
return new_config;
}
XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config,
PrecisionConfig::Precision precision) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
auto new_config = NormalizeEinsumString(einsum_config);
if (!new_config.empty()) {
return Einsum(x, y, new_config, precision);
}
TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y));
TF_ASSIGN_OR_RETURN(
@ -498,6 +588,12 @@ XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config,
});
}
XlaOp Einsum(XlaOp x, absl::string_view einsum_config,
PrecisionConfig::Precision precision) {
return Einsum(ScalarLike(x, 1), x, absl::StrCat(",", einsum_config),
precision);
}
XlaOp TransposeInMinorDims(XlaOp x) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {

View File

@ -105,6 +105,11 @@ xla::XlaOp BatchDot(
StatusOr<std::array<std::vector<int64>, 3>> ParseEinsumString(
absl::string_view einsum_config, int64 x_rank, int64 y_rank);
// If an einsum config does not contain an -> one will be added and the output
// config will be the sorted characters with any ellipsis at the beginning.
// Returns an empty string if the einsum string already has an ->.
std::string NormalizeEinsumString(absl::string_view einsum_config);
// Determine if each dimension label is in at least two inputs.
//
// NOTE: This function is meant for testing, there is no need to call it
@ -117,6 +122,13 @@ Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
xla::XlaOp Einsum(
xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config,
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);
xla::XlaOp Einsum(
xla::XlaOp x, absl::string_view einsum_config,
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);
// Handles repeated indices within an operand by taking the tensor diagonal of
// the input.
xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64> config);
// Same as above but supporting numeric labels on dimensins. So "ab,cb->ac"
// becomes:

View File

@ -213,12 +213,7 @@ XLA_TEST_F(MatrixTest, ParseEinsumString) {
}
std::vector<string> einsum_strings_that_fail_parsing = {
"",
"a",
"ab->ba",
"ab,bc,cd->ad",
"a...b...,bc->a...c",
"", "a", "ab->ba", "ab,bc,cd->ad", "a...b...,bc->a...c",
};
for (auto test_case : einsum_strings_that_fail_parsing) {
auto parse_result_or_status = ParseEinsumString(test_case, 3, 3);
@ -244,5 +239,13 @@ XLA_TEST_F(MatrixTest, ParseEinsumString) {
}
}
XLA_TEST_F(MatrixTest, NormalizeEinsumString) {
EXPECT_EQ(NormalizeEinsumString("a,b->ab"), "");
EXPECT_EQ(NormalizeEinsumString("ba"), "ba->ab");
EXPECT_EQ(NormalizeEinsumString("ab,dc"), "ab,dc->abcd");
EXPECT_EQ(NormalizeEinsumString("a,b"), "a,b->ab");
EXPECT_EQ(NormalizeEinsumString("...ba,ca..."), "...ba,ca...->...bc");
}
} // namespace
} // namespace xla

View File

@ -1201,7 +1201,12 @@ XLA_TEST_P(EinsumTest, SimpleEinsumTest) {
MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<1>(GetParam())))
.ValueOrDie(),
&builder);
Einsum(x, y, std::get<2>(GetParam()));
auto config = std::get<2>(GetParam());
if (config.find(",") == config.npos) {
Einsum(x, config);
} else {
Einsum(x, y, config);
}
ComputeAndCompare(&builder, {}, ErrorSpec{1e-3, 1e-3});
}
@ -1231,11 +1236,39 @@ std::vector<EinsumParamType> GetEinsumTestCases() {
p{v{2, 3, 5, 6}, v{2, 3, 6, 7}, "...mk,...kn->...mn"},
p{v{5, 6}, v{6, 7}, "...mk,...kn->...mn"},
p{v{5, 6}, v{6, 7}, "...mk,kn->...mn"},
p{v{6, 6}, v{7, 7}, "mm,nn->mn"},
p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->...mn"},
p{v{3, 1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->...mn"},
p{v{1, 2, 5, 6}, v{3, 2, 1, 6, 7}, "...mk,...kn->...mn"},
p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->n"},
p{v{1, 2, 2, 3, 77}, v{77, 2, 3, 55, 1, 2}, "...ija,aijb...->ba...ij"},
p{v{5, 6}, v{6, 7}, "mk,kn"},
p{v{5, 6}, v{6, 7}, "mk,kn"},
p{v{5, 6, 11}, v{6, 11, 7}, "mkB,kBn"},
p{v{5, 6}, v{6, 7}, "ab,cd"},
p{v{6}, v{6, 7}, "b,bc"},
p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc"},
p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba"},
p{v{77}, v{77}, "a,a"},
p{v{77}, v{77, 55}, "a,ab"},
p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb"},
p{v{55}, v{}, "a"},
p{v{11, 111}, v{11}, "ab,a"},
p{v{16, 34}, v{16, 34}, "ab,ab"},
p{v{16, 3, 34}, v{3, 16, 34}, "abc,bac"},
p{v{5, 19}, v{}, "ab"},
p{v{8, 1, 16, 64}, v{8, 12, 16, 64}, "bqhf,bkhf"},
p{v{2, 3, 5, 6}, v{2, 3, 6, 7}, "...mk,...kn"},
p{v{5, 6}, v{}, "...mk"},
p{v{5, 6, 12, 13}, v{}, "...mk"},
p{v{5, 6, 12, 13}, v{}, "m...k"},
p{v{5, 6, 12, 13}, v{}, "mk..."},
p{v{5, 6}, v{6, 7}, "...mk->km..."},
p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn"},
p{v{3, 1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn"},
p{v{1, 2, 5, 6}, v{3, 2, 1, 6, 7}, "...mk,...kn"},
p{v{16, 16, 16}, v{}, "iii"},
p{v{1, 2, 2, 3, 77}, v{77, 2, 3, 55, 1, 2}, "...ija,aijb..."},
};
return test_cases;
}

View File

@ -68,7 +68,6 @@ class EinsumOpTest(test.TestCase):
self._check('aabcc->ac', (3, 3, 5, 4, 4))
self._check('aabcd->ad', (3, 3, 5, 4, 4))
@test_util.disable_xla('b/131919749')
def testUnaryEllipsis(self):
# Unary cases with ellipsis.
# Edge cases.
@ -110,7 +109,6 @@ class EinsumOpTest(test.TestCase):
self._check('ba,b->', (3, 2), (3,))
self._check('ab,ab->', (3, 4), (3, 4))
@test_util.disable_xla('b/131919749')
def testRepeatedIndices(self):
# Repeated indices.
self._check('ijj,k->ik', (2, 3, 3), (4,))
@ -143,7 +141,6 @@ class EinsumOpTest(test.TestCase):
self._check('...abc,...abcd->...d', (1, 1, 2, 3, 4), (5, 2, 3, 4, 6))
self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))
@test_util.disable_xla('b/131919749')
def testBroadcastingWithRepeatedIndices(self):
# Broadcasting with repeated indices.
self._check('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))

View File

@ -231,7 +231,6 @@ class EinsumTest(test.TestCase):
_ = special_math_ops.einsum(
'ij,jk->ik', a, b, name='name', invalid1='value1', invalid2='value2')
@test_util.disable_xla('b/131919749')
def test_unary(self):
self._check('a', (3,))
self._check('aa', (3, 3))
@ -256,7 +255,6 @@ class EinsumTest(test.TestCase):
self._check('aabcc->ac', (3, 3, 5, 4, 4))
self._check('aabcd->ad', (3, 3, 5, 4, 4))
@test_util.disable_xla('b/131919749')
def test_unary_ellipsis(self):
self._check('...->', ())
self._check('...ijk->...ki', (3, 4, 5))
@ -295,12 +293,10 @@ class EinsumTest(test.TestCase):
self._check('ab,b', (3, 4), (4,))
self._check('cab,b', (1, 3, 4), (4,))
@test_util.disable_xla('b/131919749')
def test_reduced_indices(self):
self._check('ba,b->', (3, 2), (3,))
self._check('ab,ab->', (3, 4), (3, 4))
@test_util.disable_xla('b/131919749')
def test_repeated_indices(self):
with compat.forward_compatibility_horizon(2019, 10, 19):
# Repeated indices.
@ -324,7 +320,6 @@ class EinsumTest(test.TestCase):
self._check('...,...->...', (2, 3), (2, 3)) # hadamard product
self._check('...i,...j->...ij', (5, 2), (5, 3)) # outer product
@test_util.disable_xla('b/131919749')
def test_broadcasting(self):
with compat.forward_compatibility_horizon(2019, 10, 19):
# Batch matmul with broadcasting.