[XLA:CLIENT] Support all gradients of <= 2 operand einsums
PiperOrigin-RevId: 328171480 Change-Id: I1adbe658c9e3f435d4a42c2627cbbfef297e02f3
This commit is contained in:
parent
2d0592a000
commit
116792db45
@ -199,6 +199,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
@ -235,85 +236,93 @@ XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); }
|
||||
XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); }
|
||||
|
||||
namespace {
|
||||
std::vector<int64> EinsumDiagonalLabels(absl::Span<const int64> config) {
|
||||
absl::optional<std::array<std::vector<int64>, 3>> EinsumDiagonalLabels(
|
||||
absl::Span<const int64> config) {
|
||||
std::vector<int64> unique_labels;
|
||||
std::vector<int64> reduce_dims;
|
||||
std::vector<int64> broadcast_dims;
|
||||
for (auto label = config.begin(); label != config.end(); ++label) {
|
||||
auto first_label = absl::c_find(config, *label);
|
||||
auto dim = label - config.begin();
|
||||
if (first_label == label) {
|
||||
unique_labels.push_back(*label);
|
||||
broadcast_dims.push_back(dim);
|
||||
} else {
|
||||
reduce_dims.push_back(dim);
|
||||
}
|
||||
}
|
||||
if (unique_labels.size() == config.size()) {
|
||||
unique_labels.clear();
|
||||
return absl::nullopt;
|
||||
}
|
||||
return unique_labels;
|
||||
return {{unique_labels, reduce_dims, broadcast_dims}};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64> config) {
|
||||
// Masks a tensor such that only the diagonal of repeated indices are non-zero.
|
||||
// The result of this can be used to create a diagonal matrix with an identity
|
||||
// reduction.
|
||||
xla::XlaOp EinsumDiagonalMask(XlaOp x, absl::Span<const int64> config) {
|
||||
XlaBuilder* builder = x.builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
if (EinsumDiagonalLabels(config).empty()) {
|
||||
return x;
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
|
||||
Shape iota_shape = x_shape;
|
||||
iota_shape.set_element_type(S32);
|
||||
XlaOp mask = ConstantR0(builder, true);
|
||||
|
||||
absl::InlinedVector<int64, 8> reduce_dims;
|
||||
for (auto label = config.begin(); label != config.end(); ++label) {
|
||||
const int64 dim = label - config.begin();
|
||||
auto first_label = absl::c_find(config, *label);
|
||||
if (first_label == label) {
|
||||
continue;
|
||||
if (first_label != label) {
|
||||
const int64 first_dim = first_label - config.begin();
|
||||
mask = And(mask, Eq(Iota(builder, iota_shape, first_dim),
|
||||
Iota(builder, iota_shape, dim)));
|
||||
}
|
||||
reduce_dims.push_back(dim);
|
||||
const int64 first_dim = first_label - config.begin();
|
||||
mask = And(mask, Eq(Iota(builder, iota_shape, first_dim),
|
||||
Iota(builder, iota_shape, dim)));
|
||||
}
|
||||
auto zero = ScalarLike(x, 0);
|
||||
return Reduce(Select(mask, x, zero), zero,
|
||||
CreateScalarIdentityWithZeroComputation(
|
||||
x_shape.element_type(), builder),
|
||||
reduce_dims);
|
||||
return Select(mask, x, ZerosLike(x));
|
||||
});
|
||||
}
|
||||
|
||||
Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
|
||||
absl::Span<const int64> y_config,
|
||||
absl::Span<const int64> output_config) {
|
||||
for (auto dim : output_config) {
|
||||
if (absl::c_linear_search(x_config, dim) ||
|
||||
absl::c_linear_search(y_config, dim)) {
|
||||
if (absl::c_count(output_config, dim) > 1) {
|
||||
return InvalidArgument("Einsum has repeated output dimension.");
|
||||
}
|
||||
continue;
|
||||
xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64> config) {
|
||||
XlaBuilder* builder = x.builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
auto labels = EinsumDiagonalLabels(config);
|
||||
if (!labels) {
|
||||
return x;
|
||||
}
|
||||
return InvalidArgument(
|
||||
"Einsum has output dimension without corresponding input dimension.");
|
||||
}
|
||||
for (auto dim : x_config) {
|
||||
if (absl::c_linear_search(y_config, dim) ||
|
||||
absl::c_linear_search(output_config, dim)) {
|
||||
if (absl::c_count(x_config, dim) > 1) {
|
||||
return InvalidArgument("Einsum has repeated lhs dimension.");
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto dim : y_config) {
|
||||
if (absl::c_linear_search(x_config, dim) ||
|
||||
absl::c_linear_search(output_config, dim)) {
|
||||
if (absl::c_count(y_config, dim) > 1) {
|
||||
return InvalidArgument("Einsum has repeated rhs dimension.");
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
auto zero = ScalarLike(x, 0);
|
||||
TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
|
||||
return Reduce(EinsumDiagonalMask(x, config), zero,
|
||||
CreateScalarIdentityWithZeroComputation(
|
||||
x_shape.element_type(), builder),
|
||||
labels->at(1));
|
||||
});
|
||||
}
|
||||
|
||||
xla::XlaOp EinsumInverseDiagonal(XlaOp x, absl::Span<const int64> config) {
|
||||
XlaBuilder* builder = x.builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
auto labels = EinsumDiagonalLabels(config);
|
||||
if (!labels) {
|
||||
return x;
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
|
||||
std::vector<int64> broadcast_sizes;
|
||||
int64 x_dim = 0;
|
||||
for (auto label = config.begin(); label != config.end(); ++label) {
|
||||
auto first_label = absl::c_find(config, *label);
|
||||
if (first_label == label) {
|
||||
broadcast_sizes.push_back(x_shape.dimensions(x_dim));
|
||||
++x_dim;
|
||||
} else {
|
||||
broadcast_sizes.push_back(
|
||||
broadcast_sizes[first_label - config.begin()]);
|
||||
}
|
||||
}
|
||||
x = BroadcastInDim(x, broadcast_sizes, labels->at(2));
|
||||
return EinsumDiagonalMask(x, config);
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Helper method to remove dimensions from a shape and dot dimension numbers
|
||||
// used to implement implicit broadcasting.
|
||||
@ -347,21 +356,23 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
|
||||
XlaBuilder* builder = x.builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
auto x_diagonal_labels = EinsumDiagonalLabels(x_config);
|
||||
if (x_diagonal_labels) {
|
||||
return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels->at(0), y,
|
||||
y_config, output_config, precision);
|
||||
}
|
||||
auto y_diagonal_labels = EinsumDiagonalLabels(y_config);
|
||||
if (!x_diagonal_labels.empty() && !y_diagonal_labels.empty()) {
|
||||
return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels,
|
||||
EinsumDiagonal(y, y_config), y_diagonal_labels,
|
||||
output_config, precision);
|
||||
} else if (!x_diagonal_labels.empty()) {
|
||||
return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels, y, y_config,
|
||||
output_config, precision);
|
||||
} else if (!y_diagonal_labels.empty()) {
|
||||
return Einsum(x, x_config, EinsumDiagonal(y, y_config), y_diagonal_labels,
|
||||
output_config, precision);
|
||||
if (y_diagonal_labels) {
|
||||
return Einsum(x, x_config, EinsumDiagonal(y, y_config),
|
||||
y_diagonal_labels->at(0), output_config, precision);
|
||||
}
|
||||
auto output_diagonal_labels = EinsumDiagonalLabels(output_config);
|
||||
if (output_diagonal_labels) {
|
||||
return EinsumInverseDiagonal(
|
||||
Einsum(x, x_config, y, y_config, output_diagonal_labels->at(0),
|
||||
precision),
|
||||
output_config);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
ValidateEinsumNumericDimensions(x_config, y_config, output_config));
|
||||
TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
|
||||
TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y));
|
||||
const int64 x_rank = x_config.size();
|
||||
@ -372,21 +383,15 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
|
||||
absl::flat_hash_set<int64> output_map;
|
||||
|
||||
for (auto d : x_config) {
|
||||
if (!x_map.insert(d).second) {
|
||||
return InvalidArgument("XLA Einsum does not support rhs tracing");
|
||||
}
|
||||
x_map.insert(d);
|
||||
}
|
||||
|
||||
for (auto d : y_config) {
|
||||
if (!y_map.insert(d).second) {
|
||||
return InvalidArgument("XLA Einsum does not support lhs tracing");
|
||||
}
|
||||
y_map.insert(d);
|
||||
}
|
||||
|
||||
for (auto d : output_config) {
|
||||
if (!output_map.insert(d).second) {
|
||||
return InvalidArgument("XLA Einsum does not support output tracing");
|
||||
}
|
||||
output_map.insert(d);
|
||||
}
|
||||
|
||||
DotDimensionNumbers dnums;
|
||||
@ -397,6 +402,7 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
|
||||
auto is_contracting = [&](int64 d) {
|
||||
return x_map.contains(d) && y_map.contains(d);
|
||||
};
|
||||
|
||||
auto rhs_dimension_number = [&](int64 d) {
|
||||
return absl::c_find(y_config, d) - y_config.begin();
|
||||
};
|
||||
@ -468,8 +474,9 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
|
||||
output_dimension_number(y_config[d]);
|
||||
}
|
||||
|
||||
const int64 transpose_rank = output_transpose_dims.size();
|
||||
std::vector<int64> transpose_dims(output_rank);
|
||||
for (int64 i = 0; i < output_rank; ++i) {
|
||||
for (int64 i = 0; i < transpose_rank; ++i) {
|
||||
transpose_dims[output_transpose_dims[i]] = i;
|
||||
}
|
||||
|
||||
@ -498,7 +505,27 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
|
||||
CreateScalarAddComputation(x_shape.element_type(), builder),
|
||||
output_reduce_dims);
|
||||
}
|
||||
return Transpose(dot, transpose_dims);
|
||||
dot = Transpose(dot, transpose_dims);
|
||||
if (transpose_rank == output_rank) {
|
||||
return dot;
|
||||
}
|
||||
|
||||
auto is_output_only = [&](int64 d) {
|
||||
return output_map.contains(d) && !x_map.contains(d) && !y_map.contains(d);
|
||||
};
|
||||
|
||||
int64 dot_dim = 0;
|
||||
std::vector<int64> new_dims;
|
||||
new_dims.reserve(output_rank);
|
||||
TF_ASSIGN_OR_RETURN(Shape dot_shape, builder->GetShape(dot));
|
||||
for (auto d : output_config) {
|
||||
if (is_output_only(d)) {
|
||||
new_dims.push_back(1);
|
||||
} else {
|
||||
new_dims.push_back(dot_shape.dimensions(dot_dim));
|
||||
}
|
||||
}
|
||||
return Reshape(dot, new_dims);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -112,14 +112,6 @@ StatusOr<std::array<std::vector<int64>, 3>> ParseEinsumString(
|
||||
// Returns an empty string if the einsum string already has an ->.
|
||||
std::string NormalizeEinsumString(absl::string_view einsum_config);
|
||||
|
||||
// Determine if each dimension label is in at least two inputs.
|
||||
//
|
||||
// NOTE: This function is meant for testing, there is no need to call it
|
||||
// directly.
|
||||
Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
|
||||
absl::Span<const int64> y_config,
|
||||
absl::Span<const int64> output_config);
|
||||
|
||||
// Supports two operand einsum notation like "ab,cb->ac".
|
||||
xla::XlaOp Einsum(
|
||||
xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config,
|
||||
@ -128,9 +120,6 @@ xla::XlaOp Einsum(
|
||||
xla::XlaOp x, absl::string_view einsum_config,
|
||||
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);
|
||||
|
||||
// Handles repeated indices within an operand by taking the tensor diagonal of
|
||||
// the input.
|
||||
xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64> config);
|
||||
|
||||
// Same as above but supporting numeric labels on dimensions. So "ab,cb->ac"
|
||||
// becomes:
|
||||
|
@ -233,12 +233,23 @@ XLA_TEST_F(MatrixTest, ParseEinsumString) {
|
||||
};
|
||||
|
||||
std::vector<std::vector<string>> good_test_cases = {
|
||||
{"ab", "bc", "ac"}, {"Bab", "Bbc", "Bac"},
|
||||
{"ab", "cd", "dcba"}, {"abc", "abd", "cbd"},
|
||||
{"...ab", "...bc", "...ac"}, {"a...bc", "...abd", "cbd..."},
|
||||
{"...ab", "...bc", "ac"}, {"...b", "...bc", "...c"},
|
||||
{"...abz", "...bc", "...ac"}, {"...ab", "...bcz", "...ac"},
|
||||
{"abz", "bc", "ac"}, {"ab", "bcz", "ac"},
|
||||
{"ab", "bc", "ac"},
|
||||
{"Bab", "Bbc", "Bac"},
|
||||
{"ab", "cd", "dcba"},
|
||||
{"abc", "abd", "cbd"},
|
||||
{"...ab", "...bc", "...ac"},
|
||||
{"a...bc", "...abd", "cbd..."},
|
||||
{"...ab", "...bc", "ac"},
|
||||
{"...b", "...bc", "...c"},
|
||||
{"...abz", "...bc", "...ac"},
|
||||
{"...ab", "...bcz", "...ac"},
|
||||
{"abz", "bc", "ac"},
|
||||
{"ab", "bcz", "ac"},
|
||||
|
||||
{"a", "b", "c"},
|
||||
{"...a", "...b", "...c"},
|
||||
{"abb", "bcc", "ac"},
|
||||
{"ab", "bc", "ad"},
|
||||
};
|
||||
for (auto test_case : good_test_cases) {
|
||||
auto parse_result_or_status =
|
||||
@ -249,9 +260,6 @@ XLA_TEST_F(MatrixTest, ParseEinsumString) {
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
EXPECT_EQ(parse_result[i], to_vec(test_case[i]));
|
||||
}
|
||||
EXPECT_TRUE(ValidateEinsumNumericDimensions(
|
||||
parse_result[0], parse_result[1], parse_result[2])
|
||||
.ok());
|
||||
}
|
||||
|
||||
std::vector<string> einsum_strings_that_fail_parsing = {
|
||||
@ -261,24 +269,6 @@ XLA_TEST_F(MatrixTest, ParseEinsumString) {
|
||||
auto parse_result_or_status = ParseEinsumString(test_case, 3, 3);
|
||||
EXPECT_FALSE(parse_result_or_status.status().ok());
|
||||
}
|
||||
std::vector<std::vector<string>> einsum_strings_that_fail_numeric_validation =
|
||||
{
|
||||
{"a", "b", "c"},
|
||||
{"...a", "...b", "...c"},
|
||||
{"abb", "bcc", "ac"},
|
||||
{"ab", "bc", "ad"},
|
||||
};
|
||||
|
||||
for (auto test_case : einsum_strings_that_fail_numeric_validation) {
|
||||
auto parse_result_or_status =
|
||||
ParseEinsumString(to_string(test_case[0], test_case[1], test_case[2]),
|
||||
test_case[0].size(), test_case[1].size());
|
||||
EXPECT_TRUE(parse_result_or_status.status().ok());
|
||||
auto parse_result = parse_result_or_status.ValueOrDie();
|
||||
EXPECT_FALSE(ValidateEinsumNumericDimensions(
|
||||
parse_result[0], parse_result[1], parse_result[2])
|
||||
.ok());
|
||||
}
|
||||
}
|
||||
|
||||
XLA_TEST_F(MatrixTest, NormalizeEinsumString) {
|
||||
|
@ -237,7 +237,6 @@ class EinsumOpTest(test.TestCase):
|
||||
((4, 3), (None, 3)))
|
||||
check('...ij,...jk->...ik', ((3, 1, 2, 3), None), ((1, 7, 3, 4), None))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def testOutputRepeatedLabels(self):
|
||||
# This is the reverse operation of generalized traces, to be used for
|
||||
# computing symbolic gradients of einsum. Note: this operation is not
|
||||
@ -264,7 +263,6 @@ class EinsumOpTest(test.TestCase):
|
||||
# From transformer xl.
|
||||
check('ibnd,ijbn->jnd', [(1, 0, 5, 10), (1, 1, 0, 5)], (1, 5, 10))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def testEmptyWithRepeatedLabels(self):
|
||||
|
||||
def check(equation, input_shapes, output_shape):
|
||||
@ -310,7 +308,6 @@ class EinsumGradTest(test.TestCase):
|
||||
self.assertLess(
|
||||
gradient_checker_v2.max_error(analytical, numerical), tol)
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def testUnary(self):
|
||||
# Unary cases.
|
||||
self._check_gradient('->', ())
|
||||
@ -319,7 +316,6 @@ class EinsumGradTest(test.TestCase):
|
||||
self._check_gradient('aabcd->add', (3, 3, 5, 4, 4))
|
||||
self._check_gradient('abcd->da', (3, 5, 4, 2))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def testUnaryEllipsis(self):
|
||||
self._check_gradient('...->...', ())
|
||||
self._check_gradient('...->', ())
|
||||
@ -362,11 +358,9 @@ class EinsumGradTest(test.TestCase):
|
||||
self._check_gradient('ijkm,ijln->ijmn', (2, 3, 3, 4), (2, 3, 3, 2))
|
||||
self._check_gradient('abce,badf->abcd', (1, 2, 3, 4), (2, 1, 4, 3))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def testReducedIndicesWithRepeatedLabels(self):
|
||||
self._check_gradient('abce,badf->bcba', (1, 2, 3, 4), (2, 1, 4, 3))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def testRepeatedLabels(self):
|
||||
# Repeated indices.
|
||||
self._check_gradient('aba,a->b', (3, 4, 3), (3,))
|
||||
@ -376,7 +370,6 @@ class EinsumGradTest(test.TestCase):
|
||||
self._check_gradient('aab,bc->ac', (1, 1, 3), (3, 4))
|
||||
self._check_gradient('aab,bcc->ac', (2, 2, 3), (3, 4, 4))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def testEmptyWithRepeatedLabels(self):
|
||||
self._check_gradient('aab,bc->ac', (0, 0, 10), (10, 10))
|
||||
self._check_gradient('aab,bc->ac', (1, 1, 0), (0, 10))
|
||||
@ -388,7 +381,6 @@ class EinsumGradTest(test.TestCase):
|
||||
self._check_gradient('...ij,...jk->...ik', (3, 1, 3, 2), (1, 5, 2, 4))
|
||||
self._check_gradient('i...j,j...k->i...k', (3, 1, 2, 2), (2, 2, 3, 1, 4))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def testBroadcastingWithRepeatedLabels(self):
|
||||
self._check_gradient('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
|
||||
self._check_gradient('aab,b...c->a...c', (1, 1, 3), (3, 1, 1, 4))
|
||||
|
Loading…
x
Reference in New Issue
Block a user