[XLA:CLIENT] Support single operand and implied strings in Einsum
PiperOrigin-RevId: 272074915
This commit is contained in:
parent
52c04ee452
commit
5471ed4915
@ -38,9 +38,13 @@ class EinsumOp : public XlaOpKernel {
|
|||||||
|
|
||||||
void Compile(XlaOpKernelContext* ctx) override {
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
xla::XlaOp lhs = ctx->Input(0);
|
xla::XlaOp lhs = ctx->Input(0);
|
||||||
|
if (equation_.find(",") == equation_.npos) {
|
||||||
|
ctx->SetOutput(0, xla::Einsum(lhs, equation_));
|
||||||
|
} else {
|
||||||
xla::XlaOp rhs = ctx->Input(1);
|
xla::XlaOp rhs = ctx->Input(1);
|
||||||
ctx->SetOutput(0, xla::Einsum(lhs, rhs, equation_));
|
ctx->SetOutput(0, xla::Einsum(lhs, rhs, equation_));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
string equation_;
|
string equation_;
|
||||||
|
|||||||
@ -168,6 +168,53 @@ XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); }
|
|||||||
|
|
||||||
XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); }
|
XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); }
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
std::vector<int64> EinsumDiagonalLabels(absl::Span<const int64> config) {
|
||||||
|
std::vector<int64> unique_labels;
|
||||||
|
for (auto label = config.begin(); label != config.end(); ++label) {
|
||||||
|
auto first_label = absl::c_find(config, *label);
|
||||||
|
if (first_label == label) {
|
||||||
|
unique_labels.push_back(*label);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (unique_labels.size() == config.size()) {
|
||||||
|
unique_labels.clear();
|
||||||
|
}
|
||||||
|
return unique_labels;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64> config) {
|
||||||
|
XlaBuilder* builder = x.builder();
|
||||||
|
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
|
if (EinsumDiagonalLabels(config).empty()) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
|
||||||
|
Shape iota_shape = x_shape;
|
||||||
|
iota_shape.set_element_type(S32);
|
||||||
|
XlaOp mask = ConstantR0(builder, true);
|
||||||
|
|
||||||
|
absl::InlinedVector<int64, 8> reduce_dims;
|
||||||
|
for (auto label = config.begin(); label != config.end(); ++label) {
|
||||||
|
const int64 dim = label - config.begin();
|
||||||
|
auto first_label = absl::c_find(config, *label);
|
||||||
|
if (first_label == label) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
reduce_dims.push_back(dim);
|
||||||
|
const int64 first_dim = first_label - config.begin();
|
||||||
|
mask = And(mask, Eq(Iota(builder, iota_shape, first_dim),
|
||||||
|
Iota(builder, iota_shape, dim)));
|
||||||
|
}
|
||||||
|
auto zero = ScalarLike(x, 0);
|
||||||
|
return Reduce(Select(mask, x, zero), zero,
|
||||||
|
CreateScalarIdentityWithZeroComputation(
|
||||||
|
x_shape.element_type(), builder),
|
||||||
|
reduce_dims);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
|
Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
|
||||||
absl::Span<const int64> y_config,
|
absl::Span<const int64> y_config,
|
||||||
absl::Span<const int64> output_config) {
|
absl::Span<const int64> output_config) {
|
||||||
@ -200,6 +247,7 @@ Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
|
|||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Helper method to remove dimensions from a shape and dot dimension numbers
|
// Helper method to remove dimensions from a shape and dot dimension numbers
|
||||||
// used to implment implicit broadcasting.
|
// used to implment implicit broadcasting.
|
||||||
@ -232,6 +280,20 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
|
|||||||
xla::PrecisionConfig::Precision precision) {
|
xla::PrecisionConfig::Precision precision) {
|
||||||
XlaBuilder* builder = x.builder();
|
XlaBuilder* builder = x.builder();
|
||||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
|
auto x_diagonal_labels = EinsumDiagonalLabels(x_config);
|
||||||
|
auto y_diagonal_labels = EinsumDiagonalLabels(y_config);
|
||||||
|
if (!x_diagonal_labels.empty() && !y_diagonal_labels.empty()) {
|
||||||
|
return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels,
|
||||||
|
EinsumDiagonal(y, y_config), y_diagonal_labels,
|
||||||
|
output_config, precision);
|
||||||
|
} else if (!x_diagonal_labels.empty()) {
|
||||||
|
return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels, y, y_config,
|
||||||
|
output_config, precision);
|
||||||
|
} else if (!y_diagonal_labels.empty()) {
|
||||||
|
return Einsum(x, x_config, EinsumDiagonal(y, y_config), y_diagonal_labels,
|
||||||
|
output_config, precision);
|
||||||
|
}
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
ValidateEinsumNumericDimensions(x_config, y_config, output_config));
|
ValidateEinsumNumericDimensions(x_config, y_config, output_config));
|
||||||
TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
|
TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
|
||||||
@ -484,10 +546,38 @@ StatusOr<std::array<std::vector<int64>, 3>> ParseEinsumString(
|
|||||||
return einsum_config_numeric;
|
return einsum_config_numeric;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string NormalizeEinsumString(absl::string_view einsum_config) {
|
||||||
|
if (einsum_config.find("->") != einsum_config.npos) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
bool has_ellipsis = einsum_config.find("...") != einsum_config.npos;
|
||||||
|
std::map<char, int64> chars;
|
||||||
|
for (char c : einsum_config) {
|
||||||
|
if (absl::ascii_isalpha(c)) {
|
||||||
|
++chars[c];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::string new_config(einsum_config.begin(), einsum_config.end());
|
||||||
|
new_config.append("->");
|
||||||
|
if (has_ellipsis) {
|
||||||
|
new_config.append("...");
|
||||||
|
}
|
||||||
|
for (auto p : chars) {
|
||||||
|
if (p.second == 1) {
|
||||||
|
new_config.push_back(p.first);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return new_config;
|
||||||
|
}
|
||||||
|
|
||||||
XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config,
|
XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config,
|
||||||
PrecisionConfig::Precision precision) {
|
PrecisionConfig::Precision precision) {
|
||||||
XlaBuilder* builder = x.builder();
|
XlaBuilder* builder = x.builder();
|
||||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
|
auto new_config = NormalizeEinsumString(einsum_config);
|
||||||
|
if (!new_config.empty()) {
|
||||||
|
return Einsum(x, y, new_config, precision);
|
||||||
|
}
|
||||||
TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
|
TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
|
||||||
TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y));
|
TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y));
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
@ -498,6 +588,12 @@ XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config,
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XlaOp Einsum(XlaOp x, absl::string_view einsum_config,
|
||||||
|
PrecisionConfig::Precision precision) {
|
||||||
|
return Einsum(ScalarLike(x, 1), x, absl::StrCat(",", einsum_config),
|
||||||
|
precision);
|
||||||
|
}
|
||||||
|
|
||||||
XlaOp TransposeInMinorDims(XlaOp x) {
|
XlaOp TransposeInMinorDims(XlaOp x) {
|
||||||
XlaBuilder* builder = x.builder();
|
XlaBuilder* builder = x.builder();
|
||||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
|
|||||||
@ -105,6 +105,11 @@ xla::XlaOp BatchDot(
|
|||||||
StatusOr<std::array<std::vector<int64>, 3>> ParseEinsumString(
|
StatusOr<std::array<std::vector<int64>, 3>> ParseEinsumString(
|
||||||
absl::string_view einsum_config, int64 x_rank, int64 y_rank);
|
absl::string_view einsum_config, int64 x_rank, int64 y_rank);
|
||||||
|
|
||||||
|
// If an einsum config does not contain an -> one will be added and the output
|
||||||
|
// config will be the sorted characters with any ellipsis at the beginning.
|
||||||
|
// Returns an empty string if the einsum string already has an ->.
|
||||||
|
std::string NormalizeEinsumString(absl::string_view einsum_config);
|
||||||
|
|
||||||
// Determine if each dimension label is in at least two inputs.
|
// Determine if each dimension label is in at least two inputs.
|
||||||
//
|
//
|
||||||
// NOTE: This function is meant for testing, there is no need to call it
|
// NOTE: This function is meant for testing, there is no need to call it
|
||||||
@ -117,6 +122,13 @@ Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
|
|||||||
xla::XlaOp Einsum(
|
xla::XlaOp Einsum(
|
||||||
xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config,
|
xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config,
|
||||||
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);
|
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);
|
||||||
|
xla::XlaOp Einsum(
|
||||||
|
xla::XlaOp x, absl::string_view einsum_config,
|
||||||
|
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);
|
||||||
|
|
||||||
|
// Handles repeated indices within an operand by taking the tensor diagonal of
|
||||||
|
// the input.
|
||||||
|
xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64> config);
|
||||||
|
|
||||||
// Same as above but supporting numeric labels on dimensins. So "ab,cb->ac"
|
// Same as above but supporting numeric labels on dimensins. So "ab,cb->ac"
|
||||||
// becomes:
|
// becomes:
|
||||||
|
|||||||
@ -213,12 +213,7 @@ XLA_TEST_F(MatrixTest, ParseEinsumString) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<string> einsum_strings_that_fail_parsing = {
|
std::vector<string> einsum_strings_that_fail_parsing = {
|
||||||
"",
|
"", "a", "ab->ba", "ab,bc,cd->ad", "a...b...,bc->a...c",
|
||||||
"a",
|
|
||||||
"ab->ba",
|
|
||||||
"ab,bc,cd->ad",
|
|
||||||
|
|
||||||
"a...b...,bc->a...c",
|
|
||||||
};
|
};
|
||||||
for (auto test_case : einsum_strings_that_fail_parsing) {
|
for (auto test_case : einsum_strings_that_fail_parsing) {
|
||||||
auto parse_result_or_status = ParseEinsumString(test_case, 3, 3);
|
auto parse_result_or_status = ParseEinsumString(test_case, 3, 3);
|
||||||
@ -244,5 +239,13 @@ XLA_TEST_F(MatrixTest, ParseEinsumString) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(MatrixTest, NormalizeEinsumString) {
|
||||||
|
EXPECT_EQ(NormalizeEinsumString("a,b->ab"), "");
|
||||||
|
EXPECT_EQ(NormalizeEinsumString("ba"), "ba->ab");
|
||||||
|
EXPECT_EQ(NormalizeEinsumString("ab,dc"), "ab,dc->abcd");
|
||||||
|
EXPECT_EQ(NormalizeEinsumString("a,b"), "a,b->ab");
|
||||||
|
EXPECT_EQ(NormalizeEinsumString("...ba,ca..."), "...ba,ca...->...bc");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|||||||
@ -1201,7 +1201,12 @@ XLA_TEST_P(EinsumTest, SimpleEinsumTest) {
|
|||||||
MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<1>(GetParam())))
|
MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<1>(GetParam())))
|
||||||
.ValueOrDie(),
|
.ValueOrDie(),
|
||||||
&builder);
|
&builder);
|
||||||
Einsum(x, y, std::get<2>(GetParam()));
|
auto config = std::get<2>(GetParam());
|
||||||
|
if (config.find(",") == config.npos) {
|
||||||
|
Einsum(x, config);
|
||||||
|
} else {
|
||||||
|
Einsum(x, y, config);
|
||||||
|
}
|
||||||
ComputeAndCompare(&builder, {}, ErrorSpec{1e-3, 1e-3});
|
ComputeAndCompare(&builder, {}, ErrorSpec{1e-3, 1e-3});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1231,11 +1236,39 @@ std::vector<EinsumParamType> GetEinsumTestCases() {
|
|||||||
p{v{2, 3, 5, 6}, v{2, 3, 6, 7}, "...mk,...kn->...mn"},
|
p{v{2, 3, 5, 6}, v{2, 3, 6, 7}, "...mk,...kn->...mn"},
|
||||||
p{v{5, 6}, v{6, 7}, "...mk,...kn->...mn"},
|
p{v{5, 6}, v{6, 7}, "...mk,...kn->...mn"},
|
||||||
p{v{5, 6}, v{6, 7}, "...mk,kn->...mn"},
|
p{v{5, 6}, v{6, 7}, "...mk,kn->...mn"},
|
||||||
|
p{v{6, 6}, v{7, 7}, "mm,nn->mn"},
|
||||||
p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->...mn"},
|
p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->...mn"},
|
||||||
p{v{3, 1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->...mn"},
|
p{v{3, 1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->...mn"},
|
||||||
p{v{1, 2, 5, 6}, v{3, 2, 1, 6, 7}, "...mk,...kn->...mn"},
|
p{v{1, 2, 5, 6}, v{3, 2, 1, 6, 7}, "...mk,...kn->...mn"},
|
||||||
p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->n"},
|
p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->n"},
|
||||||
p{v{1, 2, 2, 3, 77}, v{77, 2, 3, 55, 1, 2}, "...ija,aijb...->ba...ij"},
|
p{v{1, 2, 2, 3, 77}, v{77, 2, 3, 55, 1, 2}, "...ija,aijb...->ba...ij"},
|
||||||
|
p{v{5, 6}, v{6, 7}, "mk,kn"},
|
||||||
|
p{v{5, 6}, v{6, 7}, "mk,kn"},
|
||||||
|
p{v{5, 6, 11}, v{6, 11, 7}, "mkB,kBn"},
|
||||||
|
p{v{5, 6}, v{6, 7}, "ab,cd"},
|
||||||
|
p{v{6}, v{6, 7}, "b,bc"},
|
||||||
|
p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc"},
|
||||||
|
p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba"},
|
||||||
|
p{v{77}, v{77}, "a,a"},
|
||||||
|
p{v{77}, v{77, 55}, "a,ab"},
|
||||||
|
p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb"},
|
||||||
|
p{v{55}, v{}, "a"},
|
||||||
|
p{v{11, 111}, v{11}, "ab,a"},
|
||||||
|
p{v{16, 34}, v{16, 34}, "ab,ab"},
|
||||||
|
p{v{16, 3, 34}, v{3, 16, 34}, "abc,bac"},
|
||||||
|
p{v{5, 19}, v{}, "ab"},
|
||||||
|
p{v{8, 1, 16, 64}, v{8, 12, 16, 64}, "bqhf,bkhf"},
|
||||||
|
p{v{2, 3, 5, 6}, v{2, 3, 6, 7}, "...mk,...kn"},
|
||||||
|
p{v{5, 6}, v{}, "...mk"},
|
||||||
|
p{v{5, 6, 12, 13}, v{}, "...mk"},
|
||||||
|
p{v{5, 6, 12, 13}, v{}, "m...k"},
|
||||||
|
p{v{5, 6, 12, 13}, v{}, "mk..."},
|
||||||
|
p{v{5, 6}, v{6, 7}, "...mk->km..."},
|
||||||
|
p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn"},
|
||||||
|
p{v{3, 1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn"},
|
||||||
|
p{v{1, 2, 5, 6}, v{3, 2, 1, 6, 7}, "...mk,...kn"},
|
||||||
|
p{v{16, 16, 16}, v{}, "iii"},
|
||||||
|
p{v{1, 2, 2, 3, 77}, v{77, 2, 3, 55, 1, 2}, "...ija,aijb..."},
|
||||||
};
|
};
|
||||||
return test_cases;
|
return test_cases;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -68,7 +68,6 @@ class EinsumOpTest(test.TestCase):
|
|||||||
self._check('aabcc->ac', (3, 3, 5, 4, 4))
|
self._check('aabcc->ac', (3, 3, 5, 4, 4))
|
||||||
self._check('aabcd->ad', (3, 3, 5, 4, 4))
|
self._check('aabcd->ad', (3, 3, 5, 4, 4))
|
||||||
|
|
||||||
@test_util.disable_xla('b/131919749')
|
|
||||||
def testUnaryEllipsis(self):
|
def testUnaryEllipsis(self):
|
||||||
# Unary cases with ellipsis.
|
# Unary cases with ellipsis.
|
||||||
# Edge cases.
|
# Edge cases.
|
||||||
@ -110,7 +109,6 @@ class EinsumOpTest(test.TestCase):
|
|||||||
self._check('ba,b->', (3, 2), (3,))
|
self._check('ba,b->', (3, 2), (3,))
|
||||||
self._check('ab,ab->', (3, 4), (3, 4))
|
self._check('ab,ab->', (3, 4), (3, 4))
|
||||||
|
|
||||||
@test_util.disable_xla('b/131919749')
|
|
||||||
def testRepeatedIndices(self):
|
def testRepeatedIndices(self):
|
||||||
# Repeated indices.
|
# Repeated indices.
|
||||||
self._check('ijj,k->ik', (2, 3, 3), (4,))
|
self._check('ijj,k->ik', (2, 3, 3), (4,))
|
||||||
@ -143,7 +141,6 @@ class EinsumOpTest(test.TestCase):
|
|||||||
self._check('...abc,...abcd->...d', (1, 1, 2, 3, 4), (5, 2, 3, 4, 6))
|
self._check('...abc,...abcd->...d', (1, 1, 2, 3, 4), (5, 2, 3, 4, 6))
|
||||||
self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))
|
self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))
|
||||||
|
|
||||||
@test_util.disable_xla('b/131919749')
|
|
||||||
def testBroadcastingWithRepeatedIndices(self):
|
def testBroadcastingWithRepeatedIndices(self):
|
||||||
# Broadcasting with repeated indices.
|
# Broadcasting with repeated indices.
|
||||||
self._check('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
|
self._check('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
|
||||||
|
|||||||
@ -231,7 +231,6 @@ class EinsumTest(test.TestCase):
|
|||||||
_ = special_math_ops.einsum(
|
_ = special_math_ops.einsum(
|
||||||
'ij,jk->ik', a, b, name='name', invalid1='value1', invalid2='value2')
|
'ij,jk->ik', a, b, name='name', invalid1='value1', invalid2='value2')
|
||||||
|
|
||||||
@test_util.disable_xla('b/131919749')
|
|
||||||
def test_unary(self):
|
def test_unary(self):
|
||||||
self._check('a', (3,))
|
self._check('a', (3,))
|
||||||
self._check('aa', (3, 3))
|
self._check('aa', (3, 3))
|
||||||
@ -256,7 +255,6 @@ class EinsumTest(test.TestCase):
|
|||||||
self._check('aabcc->ac', (3, 3, 5, 4, 4))
|
self._check('aabcc->ac', (3, 3, 5, 4, 4))
|
||||||
self._check('aabcd->ad', (3, 3, 5, 4, 4))
|
self._check('aabcd->ad', (3, 3, 5, 4, 4))
|
||||||
|
|
||||||
@test_util.disable_xla('b/131919749')
|
|
||||||
def test_unary_ellipsis(self):
|
def test_unary_ellipsis(self):
|
||||||
self._check('...->', ())
|
self._check('...->', ())
|
||||||
self._check('...ijk->...ki', (3, 4, 5))
|
self._check('...ijk->...ki', (3, 4, 5))
|
||||||
@ -295,12 +293,10 @@ class EinsumTest(test.TestCase):
|
|||||||
self._check('ab,b', (3, 4), (4,))
|
self._check('ab,b', (3, 4), (4,))
|
||||||
self._check('cab,b', (1, 3, 4), (4,))
|
self._check('cab,b', (1, 3, 4), (4,))
|
||||||
|
|
||||||
@test_util.disable_xla('b/131919749')
|
|
||||||
def test_reduced_indices(self):
|
def test_reduced_indices(self):
|
||||||
self._check('ba,b->', (3, 2), (3,))
|
self._check('ba,b->', (3, 2), (3,))
|
||||||
self._check('ab,ab->', (3, 4), (3, 4))
|
self._check('ab,ab->', (3, 4), (3, 4))
|
||||||
|
|
||||||
@test_util.disable_xla('b/131919749')
|
|
||||||
def test_repeated_indices(self):
|
def test_repeated_indices(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||||
# Repeated indices.
|
# Repeated indices.
|
||||||
@ -324,7 +320,6 @@ class EinsumTest(test.TestCase):
|
|||||||
self._check('...,...->...', (2, 3), (2, 3)) # hadamard product
|
self._check('...,...->...', (2, 3), (2, 3)) # hadamard product
|
||||||
self._check('...i,...j->...ij', (5, 2), (5, 3)) # outer product
|
self._check('...i,...j->...ij', (5, 2), (5, 3)) # outer product
|
||||||
|
|
||||||
@test_util.disable_xla('b/131919749')
|
|
||||||
def test_broadcasting(self):
|
def test_broadcasting(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||||
# Batch matmul with broadcasting.
|
# Batch matmul with broadcasting.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user