[XLA:CLIENT] Support single operand and implied strings in Einsum
PiperOrigin-RevId: 272074915
This commit is contained in:
parent
52c04ee452
commit
5471ed4915
@ -38,8 +38,12 @@ class EinsumOp : public XlaOpKernel {
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::XlaOp lhs = ctx->Input(0);
|
||||
xla::XlaOp rhs = ctx->Input(1);
|
||||
ctx->SetOutput(0, xla::Einsum(lhs, rhs, equation_));
|
||||
if (equation_.find(",") == equation_.npos) {
|
||||
ctx->SetOutput(0, xla::Einsum(lhs, equation_));
|
||||
} else {
|
||||
xla::XlaOp rhs = ctx->Input(1);
|
||||
ctx->SetOutput(0, xla::Einsum(lhs, rhs, equation_));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -168,6 +168,53 @@ XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); }
|
||||
|
||||
XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); }
|
||||
|
||||
namespace {
|
||||
std::vector<int64> EinsumDiagonalLabels(absl::Span<const int64> config) {
|
||||
std::vector<int64> unique_labels;
|
||||
for (auto label = config.begin(); label != config.end(); ++label) {
|
||||
auto first_label = absl::c_find(config, *label);
|
||||
if (first_label == label) {
|
||||
unique_labels.push_back(*label);
|
||||
}
|
||||
}
|
||||
if (unique_labels.size() == config.size()) {
|
||||
unique_labels.clear();
|
||||
}
|
||||
return unique_labels;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64> config) {
|
||||
XlaBuilder* builder = x.builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
if (EinsumDiagonalLabels(config).empty()) {
|
||||
return x;
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
|
||||
Shape iota_shape = x_shape;
|
||||
iota_shape.set_element_type(S32);
|
||||
XlaOp mask = ConstantR0(builder, true);
|
||||
|
||||
absl::InlinedVector<int64, 8> reduce_dims;
|
||||
for (auto label = config.begin(); label != config.end(); ++label) {
|
||||
const int64 dim = label - config.begin();
|
||||
auto first_label = absl::c_find(config, *label);
|
||||
if (first_label == label) {
|
||||
continue;
|
||||
}
|
||||
reduce_dims.push_back(dim);
|
||||
const int64 first_dim = first_label - config.begin();
|
||||
mask = And(mask, Eq(Iota(builder, iota_shape, first_dim),
|
||||
Iota(builder, iota_shape, dim)));
|
||||
}
|
||||
auto zero = ScalarLike(x, 0);
|
||||
return Reduce(Select(mask, x, zero), zero,
|
||||
CreateScalarIdentityWithZeroComputation(
|
||||
x_shape.element_type(), builder),
|
||||
reduce_dims);
|
||||
});
|
||||
}
|
||||
|
||||
Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
|
||||
absl::Span<const int64> y_config,
|
||||
absl::Span<const int64> output_config) {
|
||||
@ -200,6 +247,7 @@ Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Helper method to remove dimensions from a shape and dot dimension numbers
|
||||
// used to implment implicit broadcasting.
|
||||
@ -232,6 +280,20 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
|
||||
xla::PrecisionConfig::Precision precision) {
|
||||
XlaBuilder* builder = x.builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
auto x_diagonal_labels = EinsumDiagonalLabels(x_config);
|
||||
auto y_diagonal_labels = EinsumDiagonalLabels(y_config);
|
||||
if (!x_diagonal_labels.empty() && !y_diagonal_labels.empty()) {
|
||||
return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels,
|
||||
EinsumDiagonal(y, y_config), y_diagonal_labels,
|
||||
output_config, precision);
|
||||
} else if (!x_diagonal_labels.empty()) {
|
||||
return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels, y, y_config,
|
||||
output_config, precision);
|
||||
} else if (!y_diagonal_labels.empty()) {
|
||||
return Einsum(x, x_config, EinsumDiagonal(y, y_config), y_diagonal_labels,
|
||||
output_config, precision);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
ValidateEinsumNumericDimensions(x_config, y_config, output_config));
|
||||
TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
|
||||
@ -484,10 +546,38 @@ StatusOr<std::array<std::vector<int64>, 3>> ParseEinsumString(
|
||||
return einsum_config_numeric;
|
||||
}
|
||||
|
||||
std::string NormalizeEinsumString(absl::string_view einsum_config) {
|
||||
if (einsum_config.find("->") != einsum_config.npos) {
|
||||
return "";
|
||||
}
|
||||
bool has_ellipsis = einsum_config.find("...") != einsum_config.npos;
|
||||
std::map<char, int64> chars;
|
||||
for (char c : einsum_config) {
|
||||
if (absl::ascii_isalpha(c)) {
|
||||
++chars[c];
|
||||
}
|
||||
}
|
||||
std::string new_config(einsum_config.begin(), einsum_config.end());
|
||||
new_config.append("->");
|
||||
if (has_ellipsis) {
|
||||
new_config.append("...");
|
||||
}
|
||||
for (auto p : chars) {
|
||||
if (p.second == 1) {
|
||||
new_config.push_back(p.first);
|
||||
}
|
||||
}
|
||||
return new_config;
|
||||
}
|
||||
|
||||
XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config,
|
||||
PrecisionConfig::Precision precision) {
|
||||
XlaBuilder* builder = x.builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
auto new_config = NormalizeEinsumString(einsum_config);
|
||||
if (!new_config.empty()) {
|
||||
return Einsum(x, y, new_config, precision);
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
|
||||
TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
@ -498,6 +588,12 @@ XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config,
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp Einsum(XlaOp x, absl::string_view einsum_config,
|
||||
PrecisionConfig::Precision precision) {
|
||||
return Einsum(ScalarLike(x, 1), x, absl::StrCat(",", einsum_config),
|
||||
precision);
|
||||
}
|
||||
|
||||
XlaOp TransposeInMinorDims(XlaOp x) {
|
||||
XlaBuilder* builder = x.builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
|
@ -105,6 +105,11 @@ xla::XlaOp BatchDot(
|
||||
StatusOr<std::array<std::vector<int64>, 3>> ParseEinsumString(
|
||||
absl::string_view einsum_config, int64 x_rank, int64 y_rank);
|
||||
|
||||
// If an einsum config does not contain an -> one will be added and the output
|
||||
// config will be the sorted characters with any ellipsis at the beginning.
|
||||
// Returns an empty string if the einsum string already has an ->.
|
||||
std::string NormalizeEinsumString(absl::string_view einsum_config);
|
||||
|
||||
// Determine if each dimension label is in at least two inputs.
|
||||
//
|
||||
// NOTE: This function is meant for testing, there is no need to call it
|
||||
@ -117,6 +122,13 @@ Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
|
||||
xla::XlaOp Einsum(
|
||||
xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config,
|
||||
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);
|
||||
xla::XlaOp Einsum(
|
||||
xla::XlaOp x, absl::string_view einsum_config,
|
||||
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);
|
||||
|
||||
// Handles repeated indices within an operand by taking the tensor diagonal of
|
||||
// the input.
|
||||
xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64> config);
|
||||
|
||||
// Same as above but supporting numeric labels on dimensins. So "ab,cb->ac"
|
||||
// becomes:
|
||||
|
@ -213,12 +213,7 @@ XLA_TEST_F(MatrixTest, ParseEinsumString) {
|
||||
}
|
||||
|
||||
std::vector<string> einsum_strings_that_fail_parsing = {
|
||||
"",
|
||||
"a",
|
||||
"ab->ba",
|
||||
"ab,bc,cd->ad",
|
||||
|
||||
"a...b...,bc->a...c",
|
||||
"", "a", "ab->ba", "ab,bc,cd->ad", "a...b...,bc->a...c",
|
||||
};
|
||||
for (auto test_case : einsum_strings_that_fail_parsing) {
|
||||
auto parse_result_or_status = ParseEinsumString(test_case, 3, 3);
|
||||
@ -244,5 +239,13 @@ XLA_TEST_F(MatrixTest, ParseEinsumString) {
|
||||
}
|
||||
}
|
||||
|
||||
XLA_TEST_F(MatrixTest, NormalizeEinsumString) {
|
||||
EXPECT_EQ(NormalizeEinsumString("a,b->ab"), "");
|
||||
EXPECT_EQ(NormalizeEinsumString("ba"), "ba->ab");
|
||||
EXPECT_EQ(NormalizeEinsumString("ab,dc"), "ab,dc->abcd");
|
||||
EXPECT_EQ(NormalizeEinsumString("a,b"), "a,b->ab");
|
||||
EXPECT_EQ(NormalizeEinsumString("...ba,ca..."), "...ba,ca...->...bc");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -1201,7 +1201,12 @@ XLA_TEST_P(EinsumTest, SimpleEinsumTest) {
|
||||
MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<1>(GetParam())))
|
||||
.ValueOrDie(),
|
||||
&builder);
|
||||
Einsum(x, y, std::get<2>(GetParam()));
|
||||
auto config = std::get<2>(GetParam());
|
||||
if (config.find(",") == config.npos) {
|
||||
Einsum(x, config);
|
||||
} else {
|
||||
Einsum(x, y, config);
|
||||
}
|
||||
ComputeAndCompare(&builder, {}, ErrorSpec{1e-3, 1e-3});
|
||||
}
|
||||
|
||||
@ -1231,11 +1236,39 @@ std::vector<EinsumParamType> GetEinsumTestCases() {
|
||||
p{v{2, 3, 5, 6}, v{2, 3, 6, 7}, "...mk,...kn->...mn"},
|
||||
p{v{5, 6}, v{6, 7}, "...mk,...kn->...mn"},
|
||||
p{v{5, 6}, v{6, 7}, "...mk,kn->...mn"},
|
||||
p{v{6, 6}, v{7, 7}, "mm,nn->mn"},
|
||||
p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->...mn"},
|
||||
p{v{3, 1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->...mn"},
|
||||
p{v{1, 2, 5, 6}, v{3, 2, 1, 6, 7}, "...mk,...kn->...mn"},
|
||||
p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->n"},
|
||||
p{v{1, 2, 2, 3, 77}, v{77, 2, 3, 55, 1, 2}, "...ija,aijb...->ba...ij"},
|
||||
p{v{5, 6}, v{6, 7}, "mk,kn"},
|
||||
p{v{5, 6}, v{6, 7}, "mk,kn"},
|
||||
p{v{5, 6, 11}, v{6, 11, 7}, "mkB,kBn"},
|
||||
p{v{5, 6}, v{6, 7}, "ab,cd"},
|
||||
p{v{6}, v{6, 7}, "b,bc"},
|
||||
p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc"},
|
||||
p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba"},
|
||||
p{v{77}, v{77}, "a,a"},
|
||||
p{v{77}, v{77, 55}, "a,ab"},
|
||||
p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb"},
|
||||
p{v{55}, v{}, "a"},
|
||||
p{v{11, 111}, v{11}, "ab,a"},
|
||||
p{v{16, 34}, v{16, 34}, "ab,ab"},
|
||||
p{v{16, 3, 34}, v{3, 16, 34}, "abc,bac"},
|
||||
p{v{5, 19}, v{}, "ab"},
|
||||
p{v{8, 1, 16, 64}, v{8, 12, 16, 64}, "bqhf,bkhf"},
|
||||
p{v{2, 3, 5, 6}, v{2, 3, 6, 7}, "...mk,...kn"},
|
||||
p{v{5, 6}, v{}, "...mk"},
|
||||
p{v{5, 6, 12, 13}, v{}, "...mk"},
|
||||
p{v{5, 6, 12, 13}, v{}, "m...k"},
|
||||
p{v{5, 6, 12, 13}, v{}, "mk..."},
|
||||
p{v{5, 6}, v{6, 7}, "...mk->km..."},
|
||||
p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn"},
|
||||
p{v{3, 1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn"},
|
||||
p{v{1, 2, 5, 6}, v{3, 2, 1, 6, 7}, "...mk,...kn"},
|
||||
p{v{16, 16, 16}, v{}, "iii"},
|
||||
p{v{1, 2, 2, 3, 77}, v{77, 2, 3, 55, 1, 2}, "...ija,aijb..."},
|
||||
};
|
||||
return test_cases;
|
||||
}
|
||||
|
@ -68,7 +68,6 @@ class EinsumOpTest(test.TestCase):
|
||||
self._check('aabcc->ac', (3, 3, 5, 4, 4))
|
||||
self._check('aabcd->ad', (3, 3, 5, 4, 4))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def testUnaryEllipsis(self):
|
||||
# Unary cases with ellipsis.
|
||||
# Edge cases.
|
||||
@ -110,7 +109,6 @@ class EinsumOpTest(test.TestCase):
|
||||
self._check('ba,b->', (3, 2), (3,))
|
||||
self._check('ab,ab->', (3, 4), (3, 4))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def testRepeatedIndices(self):
|
||||
# Repeated indices.
|
||||
self._check('ijj,k->ik', (2, 3, 3), (4,))
|
||||
@ -143,7 +141,6 @@ class EinsumOpTest(test.TestCase):
|
||||
self._check('...abc,...abcd->...d', (1, 1, 2, 3, 4), (5, 2, 3, 4, 6))
|
||||
self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def testBroadcastingWithRepeatedIndices(self):
|
||||
# Broadcasting with repeated indices.
|
||||
self._check('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
|
||||
|
@ -231,7 +231,6 @@ class EinsumTest(test.TestCase):
|
||||
_ = special_math_ops.einsum(
|
||||
'ij,jk->ik', a, b, name='name', invalid1='value1', invalid2='value2')
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_unary(self):
|
||||
self._check('a', (3,))
|
||||
self._check('aa', (3, 3))
|
||||
@ -256,7 +255,6 @@ class EinsumTest(test.TestCase):
|
||||
self._check('aabcc->ac', (3, 3, 5, 4, 4))
|
||||
self._check('aabcd->ad', (3, 3, 5, 4, 4))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_unary_ellipsis(self):
|
||||
self._check('...->', ())
|
||||
self._check('...ijk->...ki', (3, 4, 5))
|
||||
@ -295,12 +293,10 @@ class EinsumTest(test.TestCase):
|
||||
self._check('ab,b', (3, 4), (4,))
|
||||
self._check('cab,b', (1, 3, 4), (4,))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_reduced_indices(self):
|
||||
self._check('ba,b->', (3, 2), (3,))
|
||||
self._check('ab,ab->', (3, 4), (3, 4))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_repeated_indices(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
# Repeated indices.
|
||||
@ -324,7 +320,6 @@ class EinsumTest(test.TestCase):
|
||||
self._check('...,...->...', (2, 3), (2, 3)) # hadamard product
|
||||
self._check('...i,...j->...ij', (5, 2), (5, 3)) # outer product
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_broadcasting(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
# Batch matmul with broadcasting.
|
||||
|
Loading…
x
Reference in New Issue
Block a user