From 5471ed4915581239cea98e8d1d7ce6ec3ecc03ff Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Mon, 30 Sep 2019 15:07:14 -0700 Subject: [PATCH] [XLA:CLIENT] Support single operand and implied strings in Einsum PiperOrigin-RevId: 272074915 --- .../compiler/tf2xla/kernels/einsum_op.cc | 8 +- tensorflow/compiler/xla/client/lib/matrix.cc | 96 +++++++++++++++++++ tensorflow/compiler/xla/client/lib/matrix.h | 12 +++ .../compiler/xla/client/lib/matrix_test.cc | 15 +-- .../compiler/xla/tests/dot_operation_test.cc | 35 ++++++- .../python/kernel_tests/einsum_op_test.py | 3 - .../python/ops/special_math_ops_test.py | 5 - 7 files changed, 157 insertions(+), 17 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/einsum_op.cc b/tensorflow/compiler/tf2xla/kernels/einsum_op.cc index bf9313389dd..ae4d36d986d 100644 --- a/tensorflow/compiler/tf2xla/kernels/einsum_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/einsum_op.cc @@ -38,8 +38,12 @@ class EinsumOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaOp lhs = ctx->Input(0); - xla::XlaOp rhs = ctx->Input(1); - ctx->SetOutput(0, xla::Einsum(lhs, rhs, equation_)); + if (equation_.find(",") == equation_.npos) { + ctx->SetOutput(0, xla::Einsum(lhs, equation_)); + } else { + xla::XlaOp rhs = ctx->Input(1); + ctx->SetOutput(0, xla::Einsum(lhs, rhs, equation_)); + } } private: diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc index 909ee2d8476..6819e72ad6f 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.cc +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -168,6 +168,53 @@ XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } +namespace { +std::vector EinsumDiagonalLabels(absl::Span config) { + std::vector unique_labels; + for (auto label = config.begin(); label != config.end(); ++label) { + auto first_label = absl::c_find(config, *label); + if (first_label == label) { + unique_labels.push_back(*label); + } + } + if (unique_labels.size() == config.size()) { + unique_labels.clear(); + } + return unique_labels; +} +} // namespace + +xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span config) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + if (EinsumDiagonalLabels(config).empty()) { + return x; + } + TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); + Shape iota_shape = x_shape; + iota_shape.set_element_type(S32); + XlaOp mask = ConstantR0(builder, true); + + absl::InlinedVector reduce_dims; + for (auto label = config.begin(); label != config.end(); ++label) { + const int64 dim = label - config.begin(); + auto first_label = absl::c_find(config, *label); + if (first_label == label) { + continue; + } + reduce_dims.push_back(dim); + const int64 first_dim = first_label - config.begin(); + mask = And(mask, Eq(Iota(builder, iota_shape, first_dim), + Iota(builder, iota_shape, dim))); + } + auto zero = ScalarLike(x, 0); + return Reduce(Select(mask, x, zero), zero, + CreateScalarIdentityWithZeroComputation( + x_shape.element_type(), builder), + reduce_dims); + }); +} + Status ValidateEinsumNumericDimensions(absl::Span x_config, absl::Span y_config, absl::Span output_config) { @@ -200,6 +247,7 @@ Status ValidateEinsumNumericDimensions(absl::Span x_config, } return Status::OK(); } + namespace { // Helper method to remove dimensions from a shape and dot dimension numbers // used to implment implicit broadcasting. @@ -232,6 +280,20 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, xla::XlaOp y, xla::PrecisionConfig::Precision precision) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { + auto x_diagonal_labels = EinsumDiagonalLabels(x_config); + auto y_diagonal_labels = EinsumDiagonalLabels(y_config); + if (!x_diagonal_labels.empty() && !y_diagonal_labels.empty()) { + return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels, + EinsumDiagonal(y, y_config), y_diagonal_labels, + output_config, precision); + } else if (!x_diagonal_labels.empty()) { + return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels, y, y_config, + output_config, precision); + } else if (!y_diagonal_labels.empty()) { + return Einsum(x, x_config, EinsumDiagonal(y, y_config), y_diagonal_labels, + output_config, precision); + } + TF_RETURN_IF_ERROR( ValidateEinsumNumericDimensions(x_config, y_config, output_config)); TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); @@ -484,10 +546,38 @@ StatusOr, 3>> ParseEinsumString( return einsum_config_numeric; } +std::string NormalizeEinsumString(absl::string_view einsum_config) { + if (einsum_config.find("->") != einsum_config.npos) { + return ""; + } + bool has_ellipsis = einsum_config.find("...") != einsum_config.npos; + std::map chars; + for (char c : einsum_config) { + if (absl::ascii_isalpha(c)) { + ++chars[c]; + } + } + std::string new_config(einsum_config.begin(), einsum_config.end()); + new_config.append("->"); + if (has_ellipsis) { + new_config.append("..."); + } + for (auto p : chars) { + if (p.second == 1) { + new_config.push_back(p.first); + } + } + return new_config; +} + XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config, PrecisionConfig::Precision precision) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { + auto new_config = NormalizeEinsumString(einsum_config); + if (!new_config.empty()) { + return Einsum(x, y, new_config, precision); + } TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y)); TF_ASSIGN_OR_RETURN( @@ -498,6 +588,12 @@ XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config, }); } +XlaOp Einsum(XlaOp x, absl::string_view einsum_config, + PrecisionConfig::Precision precision) { + return Einsum(ScalarLike(x, 1), x, absl::StrCat(",", einsum_config), + precision); +} + XlaOp TransposeInMinorDims(XlaOp x) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { diff --git a/tensorflow/compiler/xla/client/lib/matrix.h b/tensorflow/compiler/xla/client/lib/matrix.h index be4860880ba..fcf06d28480 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.h +++ b/tensorflow/compiler/xla/client/lib/matrix.h @@ -105,6 +105,11 @@ xla::XlaOp BatchDot( StatusOr, 3>> ParseEinsumString( absl::string_view einsum_config, int64 x_rank, int64 y_rank); +// If an einsum config does not contain an -> one will be added and the output +// config will be the sorted characters with any ellipsis at the beginning. +// Returns an empty string if the einsum string already has an ->. +std::string NormalizeEinsumString(absl::string_view einsum_config); + // Determine if each dimension label is in at least two inputs. // // NOTE: This function is meant for testing, there is no need to call it @@ -117,6 +122,13 @@ Status ValidateEinsumNumericDimensions(absl::Span x_config, xla::XlaOp Einsum( xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config, xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); +xla::XlaOp Einsum( + xla::XlaOp x, absl::string_view einsum_config, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); + +// Handles repeated indices within an operand by taking the tensor diagonal of +// the input. +xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span config); // Same as above but supporting numeric labels on dimensins. So "ab,cb->ac" // becomes: diff --git a/tensorflow/compiler/xla/client/lib/matrix_test.cc b/tensorflow/compiler/xla/client/lib/matrix_test.cc index 73b40012ea2..42a6fec4af3 100644 --- a/tensorflow/compiler/xla/client/lib/matrix_test.cc +++ b/tensorflow/compiler/xla/client/lib/matrix_test.cc @@ -213,12 +213,7 @@ XLA_TEST_F(MatrixTest, ParseEinsumString) { } std::vector einsum_strings_that_fail_parsing = { - "", - "a", - "ab->ba", - "ab,bc,cd->ad", - - "a...b...,bc->a...c", + "", "a", "ab->ba", "ab,bc,cd->ad", "a...b...,bc->a...c", }; for (auto test_case : einsum_strings_that_fail_parsing) { auto parse_result_or_status = ParseEinsumString(test_case, 3, 3); @@ -244,5 +239,13 @@ XLA_TEST_F(MatrixTest, ParseEinsumString) { } } +XLA_TEST_F(MatrixTest, NormalizeEinsumString) { + EXPECT_EQ(NormalizeEinsumString("a,b->ab"), ""); + EXPECT_EQ(NormalizeEinsumString("ba"), "ba->ab"); + EXPECT_EQ(NormalizeEinsumString("ab,dc"), "ab,dc->abcd"); + EXPECT_EQ(NormalizeEinsumString("a,b"), "a,b->ab"); + EXPECT_EQ(NormalizeEinsumString("...ba,ca..."), "...ba,ca...->...bc"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index c4daf7790f8..08519746d1d 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -1201,7 +1201,12 @@ XLA_TEST_P(EinsumTest, SimpleEinsumTest) { MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<1>(GetParam()))) .ValueOrDie(), &builder); - Einsum(x, y, std::get<2>(GetParam())); + auto config = std::get<2>(GetParam()); + if (config.find(",") == config.npos) { + Einsum(x, config); + } else { + Einsum(x, y, config); + } ComputeAndCompare(&builder, {}, ErrorSpec{1e-3, 1e-3}); } @@ -1231,11 +1236,39 @@ std::vector GetEinsumTestCases() { p{v{2, 3, 5, 6}, v{2, 3, 6, 7}, "...mk,...kn->...mn"}, p{v{5, 6}, v{6, 7}, "...mk,...kn->...mn"}, p{v{5, 6}, v{6, 7}, "...mk,kn->...mn"}, + p{v{6, 6}, v{7, 7}, "mm,nn->mn"}, p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->...mn"}, p{v{3, 1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->...mn"}, p{v{1, 2, 5, 6}, v{3, 2, 1, 6, 7}, "...mk,...kn->...mn"}, p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->n"}, p{v{1, 2, 2, 3, 77}, v{77, 2, 3, 55, 1, 2}, "...ija,aijb...->ba...ij"}, + p{v{5, 6}, v{6, 7}, "mk,kn"}, + p{v{5, 6}, v{6, 7}, "mk,kn"}, + p{v{5, 6, 11}, v{6, 11, 7}, "mkB,kBn"}, + p{v{5, 6}, v{6, 7}, "ab,cd"}, + p{v{6}, v{6, 7}, "b,bc"}, + p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc"}, + p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba"}, + p{v{77}, v{77}, "a,a"}, + p{v{77}, v{77, 55}, "a,ab"}, + p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb"}, + p{v{55}, v{}, "a"}, + p{v{11, 111}, v{11}, "ab,a"}, + p{v{16, 34}, v{16, 34}, "ab,ab"}, + p{v{16, 3, 34}, v{3, 16, 34}, "abc,bac"}, + p{v{5, 19}, v{}, "ab"}, + p{v{8, 1, 16, 64}, v{8, 12, 16, 64}, "bqhf,bkhf"}, + p{v{2, 3, 5, 6}, v{2, 3, 6, 7}, "...mk,...kn"}, + p{v{5, 6}, v{}, "...mk"}, + p{v{5, 6, 12, 13}, v{}, "...mk"}, + p{v{5, 6, 12, 13}, v{}, "m...k"}, + p{v{5, 6, 12, 13}, v{}, "mk..."}, + p{v{5, 6}, v{6, 7}, "...mk->km..."}, + p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn"}, + p{v{3, 1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn"}, + p{v{1, 2, 5, 6}, v{3, 2, 1, 6, 7}, "...mk,...kn"}, + p{v{16, 16, 16}, v{}, "iii"}, + p{v{1, 2, 2, 3, 77}, v{77, 2, 3, 55, 1, 2}, "...ija,aijb..."}, }; return test_cases; } diff --git a/tensorflow/python/kernel_tests/einsum_op_test.py b/tensorflow/python/kernel_tests/einsum_op_test.py index da85b5f9710..9b052d2f45d 100644 --- a/tensorflow/python/kernel_tests/einsum_op_test.py +++ b/tensorflow/python/kernel_tests/einsum_op_test.py @@ -68,7 +68,6 @@ class EinsumOpTest(test.TestCase): self._check('aabcc->ac', (3, 3, 5, 4, 4)) self._check('aabcd->ad', (3, 3, 5, 4, 4)) - @test_util.disable_xla('b/131919749') def testUnaryEllipsis(self): # Unary cases with ellipsis. # Edge cases. @@ -110,7 +109,6 @@ class EinsumOpTest(test.TestCase): self._check('ba,b->', (3, 2), (3,)) self._check('ab,ab->', (3, 4), (3, 4)) - @test_util.disable_xla('b/131919749') def testRepeatedIndices(self): # Repeated indices. self._check('ijj,k->ik', (2, 3, 3), (4,)) @@ -143,7 +141,6 @@ class EinsumOpTest(test.TestCase): self._check('...abc,...abcd->...d', (1, 1, 2, 3, 4), (5, 2, 3, 4, 6)) self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,)) - @test_util.disable_xla('b/131919749') def testBroadcastingWithRepeatedIndices(self): # Broadcasting with repeated indices. self._check('ij,jk...k->i...', (3, 2), (2, 4, 1, 4)) diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py index 0090ac6d23b..6629f2bb6ca 100644 --- a/tensorflow/python/ops/special_math_ops_test.py +++ b/tensorflow/python/ops/special_math_ops_test.py @@ -231,7 +231,6 @@ class EinsumTest(test.TestCase): _ = special_math_ops.einsum( 'ij,jk->ik', a, b, name='name', invalid1='value1', invalid2='value2') - @test_util.disable_xla('b/131919749') def test_unary(self): self._check('a', (3,)) self._check('aa', (3, 3)) @@ -256,7 +255,6 @@ class EinsumTest(test.TestCase): self._check('aabcc->ac', (3, 3, 5, 4, 4)) self._check('aabcd->ad', (3, 3, 5, 4, 4)) - @test_util.disable_xla('b/131919749') def test_unary_ellipsis(self): self._check('...->', ()) self._check('...ijk->...ki', (3, 4, 5)) @@ -295,12 +293,10 @@ class EinsumTest(test.TestCase): self._check('ab,b', (3, 4), (4,)) self._check('cab,b', (1, 3, 4), (4,)) - @test_util.disable_xla('b/131919749') def test_reduced_indices(self): self._check('ba,b->', (3, 2), (3,)) self._check('ab,ab->', (3, 4), (3, 4)) - @test_util.disable_xla('b/131919749') def test_repeated_indices(self): with compat.forward_compatibility_horizon(2019, 10, 19): # Repeated indices. @@ -324,7 +320,6 @@ class EinsumTest(test.TestCase): self._check('...,...->...', (2, 3), (2, 3)) # hadamard product self._check('...i,...j->...ij', (5, 2), (5, 3)) # outer product - @test_util.disable_xla('b/131919749') def test_broadcasting(self): with compat.forward_compatibility_horizon(2019, 10, 19): # Batch matmul with broadcasting.