[XLA] Simplify batch dots that have no contracting dimensions into multiplies.

PiperOrigin-RevId: 233637534
Blake Hechtman 2019-02-12 10:54:26 -08:00 committed by TensorFlower Gardener
parent ce3ee9b6a4
commit ba941c212d
6 changed files with 70 additions and 40 deletions
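The rewrite this commit implements rests on a simple identity: when a dot has no contracting dimensions, nothing is summed, so every output element is the product of exactly one lhs element and one rhs element. The dot is therefore equivalent to an elementwise multiply of suitably transposed and broadcast operands. A minimal standalone sketch of the identity (plain C++, not XLA code; the shapes {B, M} and {B, N} with batch dimension 0 are assumed for illustration):

    #include <cassert>
    #include <vector>

    int main() {
      const int B = 2, M = 3, N = 4;
      std::vector<double> lhs(B * M), rhs(B * N);
      for (int i = 0; i < B * M; ++i) lhs[i] = i + 1;
      for (int i = 0; i < B * N; ++i) rhs[i] = 2 * i + 1;

      // dot(lhs, rhs) with batch dims {0} and no contracting dims has shape
      // {B, M, N}: out[b, m, n] = lhs[b, m] * rhs[b, n]. No summation occurs.
      std::vector<double> dot(B * M * N);
      for (int b = 0; b < B; ++b)
        for (int m = 0; m < M; ++m)
          for (int n = 0; n < N; ++n)
            dot[(b * M + m) * N + n] = lhs[b * M + m] * rhs[b * N + n];

      // The rewrite materializes that directly: broadcast lhs into {B, M, N}
      // with broadcast_dims {0, 1}, rhs with broadcast_dims {0, 2}, multiply.
      std::vector<double> blhs(B * M * N), brhs(B * M * N), mul(B * M * N);
      for (int b = 0; b < B; ++b)
        for (int m = 0; m < M; ++m)
          for (int n = 0; n < N; ++n) {
            const int idx = (b * M + m) * N + n;
            blhs[idx] = lhs[b * M + m];
            brhs[idx] = rhs[b * N + n];
            mul[idx] = blhs[idx] * brhs[idx];
          }
      assert(mul == dot);  // the multiply reproduces the dot exactly
      return 0;
    }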


@@ -1605,29 +1605,50 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
     return Status::OK();
   }
 
-  // If a dot only contains batch dimensions, then transpose the rhs and lhs
-  // according to the batch dimension numbers and do a simple multiply.
-  if (lhs->shape().rank() ==
-          dot->dot_dimension_numbers().lhs_batch_dimensions_size() &&
-      rhs->shape().rank() ==
-          dot->dot_dimension_numbers().rhs_batch_dimensions_size()) {
-    HloInstruction* new_rhs = rhs;
-    HloInstruction* new_lhs = lhs;
-    if (lhs->shape().rank() > 1) {
-      TF_ASSIGN_OR_RETURN(
-          new_rhs,
-          MakeTransposeHlo(
-              rhs, AsInt64Slice(
-                       dot->dot_dimension_numbers().rhs_batch_dimensions())));
-      TF_ASSIGN_OR_RETURN(
-          new_lhs,
-          MakeTransposeHlo(
-              lhs, AsInt64Slice(
-                       dot->dot_dimension_numbers().lhs_batch_dimensions())));
+  // If there are no contracting dimensions, a dot can be rewritten as
+  // mul(broadcast(transpose(x)), broadcast(transpose(y)))
+  if (dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) {
+    std::vector<int64> lhs_transpose(
+        dot->dot_dimension_numbers().lhs_batch_dimensions().begin(),
+        dot->dot_dimension_numbers().lhs_batch_dimensions().end());
+    for (int64 i = 0; i < lhs->shape().rank(); ++i) {
+      if (!absl::c_linear_search(
+              dot->dot_dimension_numbers().lhs_batch_dimensions(), i)) {
+        lhs_transpose.push_back(i);
+      }
     }
-    TF_ASSIGN_OR_RETURN(auto new_dot,
-                        MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs));
-    return ReplaceInstruction(dot, new_dot);
+    TF_ASSIGN_OR_RETURN(HloInstruction * new_lhs,
+                        MakeTransposeHlo(lhs, lhs_transpose));
+    if (dot->shape().rank() != lhs->shape().rank()) {
+      std::vector<int64> lhs_broadcast_dims(lhs->shape().rank());
+      absl::c_iota(lhs_broadcast_dims, 0);
+      new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
+          dot->shape(), new_lhs, lhs_broadcast_dims));
+    }
+    std::vector<int64> rhs_transpose(
+        dot->dot_dimension_numbers().rhs_batch_dimensions().begin(),
+        dot->dot_dimension_numbers().rhs_batch_dimensions().end());
+    for (int64 i = 0; i < rhs->shape().rank(); ++i) {
+      if (!absl::c_linear_search(
+              dot->dot_dimension_numbers().rhs_batch_dimensions(), i)) {
+        rhs_transpose.push_back(i);
+      }
+    }
+    TF_ASSIGN_OR_RETURN(HloInstruction * new_rhs,
+                        MakeTransposeHlo(rhs, rhs_transpose));
+    if (dot->shape().rank() != rhs->shape().rank()) {
+      std::vector<int64> rhs_broadcast_dims(
+          dot->dot_dimension_numbers().lhs_batch_dimensions_size());
+      absl::c_iota(rhs_broadcast_dims, 0);
+      for (int64 i = lhs->shape().rank(); i < dot->shape().rank(); ++i) {
+        rhs_broadcast_dims.push_back(i);
+      }
+      new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
+          dot->shape(), new_rhs, rhs_broadcast_dims));
+    }
+    return ReplaceWithNewInstruction(
+        dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply,
+                                          new_lhs, new_rhs));
   }
 
   if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 ||

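A standalone sketch of the dimension bookkeeping in the hunk above, since it is easy to lose in the diff. It assumes, as the pass does, that the dot output shape is the batch dimensions followed by the lhs free dimensions followed by the rhs free dimensions; the helper names are illustrative only, not XLA APIs:

    #include <algorithm>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    // Transpose permutation for one operand: its batch dimensions first, in
    // the order the dot lists them, then every remaining dimension in order.
    std::vector<int64_t> NoContractTranspose(
        const std::vector<int64_t>& batch_dims, int64_t operand_rank) {
      std::vector<int64_t> perm(batch_dims);
      for (int64_t i = 0; i < operand_rank; ++i) {
        if (std::find(batch_dims.begin(), batch_dims.end(), i) ==
            batch_dims.end()) {
          perm.push_back(i);
        }
      }
      return perm;
    }

    // The transposed lhs fills the leading output dimensions, so its
    // broadcast dims are simply 0..lhs_rank-1. The transposed rhs keeps the
    // shared batch dims and then fills the trailing output dimensions that
    // remain for its free dims.
    std::vector<int64_t> RhsBroadcastDims(int64_t num_batch_dims,
                                          int64_t lhs_rank,
                                          int64_t output_rank) {
      std::vector<int64_t> dims(num_batch_dims);
      std::iota(dims.begin(), dims.end(), 0);
      for (int64_t i = lhs_rank; i < output_rank; ++i) dims.push_back(i);
      return dims;
    }

For the {B, M} x {B, N} example above: both transposes are the identity {0, 1} (the batch dimension already leads), the lhs broadcast dims are {0, 1}, and RhsBroadcastDims(1, 2, 3) yields {0, 2}.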

@@ -4222,8 +4222,10 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) {
   std::tie(m, k, n, element_type) = GetParam();
   Shape dot_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, n});
-  Shape lhs_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, k});
-  Shape rhs_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, k, n});
+  Shape lhs_shape = k > 0 ? ShapeUtil::MakeShape(element_type, {1, 3, 5, m, k})
+                          : ShapeUtil::MakeShape(element_type, {1, 3, 5, m});
+  Shape rhs_shape = k > 0 ? ShapeUtil::MakeShape(element_type, {1, 3, 5, k, n})
+                          : ShapeUtil::MakeShape(element_type, {1, 3, 5, n});
   HloComputation::Builder builder(TestName());
   auto lhs = builder.AddInstruction(
@@ -4237,14 +4239,16 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) {
   dot_dnums.add_rhs_batch_dimensions(0);
   dot_dnums.add_rhs_batch_dimensions(1);
   dot_dnums.add_rhs_batch_dimensions(2);
-  dot_dnums.add_lhs_contracting_dimensions(4);
-  dot_dnums.add_rhs_contracting_dimensions(3);
+  if (k > 0) {
+    dot_dnums.add_lhs_contracting_dimensions(4);
+    dot_dnums.add_rhs_contracting_dimensions(3);
+  }
   builder.AddInstruction(HloInstruction::CreateDot(
       dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
   auto computation = module->AddEntryComputation(builder.Build());
   AlgebraicSimplifier simplifier(default_options_);
   TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
-  const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1;
+  const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1 || k == -1;
   const bool computation_should_be_modified = dot_should_be_transformed;
   EXPECT_EQ(changed, computation_should_be_modified);
   bool has_no_dot = true;
@@ -4259,7 +4263,7 @@
 INSTANTIATE_TEST_SUITE_P(
     BatchDotStrengthReductionTestInstantiation, BatchDotStrengthReductionTest,
-    ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2),
+    ::testing::Combine(::testing::Values(1, 2), ::testing::Values(-1, 1, 2),
                        ::testing::Values(1, 2), ::testing::Values(F32, BF16)));
 
 class DotStrengthReductionTest

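In the test changes above, the new parameter value k == -1 encodes "no contracting dimension": the operand shapes drop to rank 4, and the if (k > 0) guard skips the add_*_contracting_dimensions calls, so the dot carries only batch and free dimensions and is now expected to be strength-reduced. What the test builds in that case amounts to (a condensed sketch of the code shown above):

    // k == -1 path only: three batch dimensions, empty contracting lists.
    // lhs: {1, 3, 5, m}, rhs: {1, 3, 5, n}, dot output: {1, 3, 5, m, n}.
    DotDimensionNumbers dot_dnums;
    for (int i = 0; i < 3; ++i) {
      dot_dnums.add_lhs_batch_dimensions(i);
      dot_dnums.add_rhs_batch_dimensions(i);
    }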

@@ -2700,11 +2700,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
     const Shape& operand, absl::Span<const int64> dimensions) {
   TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
 
-  std::vector<int64> indices(operand.rank());
-  std::iota(indices.begin(), indices.end(), 0);
-  if (dimensions.size() != operand.rank() ||
-      !std::is_permutation(dimensions.begin(), dimensions.end(),
-                           indices.begin())) {
+  if (!IsPermutation(dimensions, operand.rank())) {
     return InvalidArgument(
         "Transpose dimensions [%s] are not a permutation of the operand "
         "dimensions (operand shape is %s).",

@@ -1572,6 +1572,16 @@ TEST_F(ShapeInferenceTest, Transpose) {
       ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
 }
 
+TEST_F(ShapeInferenceTest, Rank1Transpose) {
+  Shape a_shape = ShapeUtil::MakeShape(F32, {5});
+  auto inferred_shape_and_status =
+      ShapeInference::InferTransposeShape(a_shape, {0});
+  EXPECT_IS_OK(inferred_shape_and_status);
+  Shape inferred_shape = inferred_shape_and_status.ValueOrDie();
+  EXPECT_TRUE(
+      ShapeUtil::Compatible(inferred_shape, ShapeUtil::MakeShape(F32, {5})));
+}
+
 TEST_F(ShapeInferenceTest, Conditional) {
   auto inferred_status0 = ShapeInference::InferConditionalShape(
       pred_, vector_32_, vector_64_,


@@ -1189,6 +1189,8 @@ std::vector<EinsumParamType> GetEinsumTestCases() {
       p{v{5, 6}, v{6, 7}, "ab,cd->dcba"},
       p{v{6}, v{6, 7}, "b,bc->c"},
       p{v{77}, v{77}, "a,a->a"},
+      p{v{77}, v{77, 55}, "a,ab->ba"},
+      p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb->baij"},
       p{v{55}, v{}, "a,->a"},
       p{v{11, 111}, v{11}, "ab,a->ab"},
       p{v{16, 34}, v{16, 34}, "ab,ab->ab"},

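Both new einsum cases are constructed so that nothing is contracted: every index that appears in both inputs ('a' in the first case; 'i', 'j', 'a' in the second) also appears in the output, which makes it a batch dimension rather than a contracting one. Lowered to dots, these cases exercise the new multiply rewrite. A hand-rolled check of the first case, "a,ab->ba", using the test's sizes (a sketch, not test code):

    #include <cassert>
    #include <vector>

    int main() {
      const int A = 77, B = 55;
      std::vector<float> x(A), y(A * B);
      for (int i = 0; i < A; ++i) x[i] = i + 1;
      for (int i = 0; i < A * B; ++i) y[i] = i % 7;
      // "a,ab->ba": out[b, a] = x[a] * y[a, b], a pure elementwise multiply
      // of broadcasts; no index is summed away.
      std::vector<float> out(B * A);
      for (int b = 0; b < B; ++b)
        for (int a = 0; a < A; ++a)
          out[b * A + a] = x[a] * y[a * B + b];
      assert(out[2 * A + 3] == 4.0f * 6.0f);  // x[3] = 4, y[3*B+2] = 167 % 7 = 6
      return 0;
    }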

@@ -18,6 +18,7 @@ limitations under the License.
 #include <stdarg.h>
 #include <numeric>
 
+#include "absl/container/inlined_vector.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
@@ -80,13 +81,9 @@ bool IsPermutation(absl::Span<const int64> permutation, int64 rank) {
   if (rank != permutation.size()) {
     return false;
   }
-  std::vector<int64> output(permutation.size(), -1);
-  for (auto index : permutation) {
-    CHECK_GE(index, 0);
-    CHECK_LT(index, rank);
-    output[index] = 0;
-  }
-  return !absl::c_linear_search(output, -1);
+  absl::InlinedVector<int64, 8> trivial_permutation(rank);
+  absl::c_iota(trivial_permutation, 0);
+  return absl::c_is_permutation(permutation, trivial_permutation);
 }
 
 std::vector<int64> InversePermutation(
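For in-range inputs the rewritten IsPermutation is behaviorally identical: a span is a permutation of rank r iff it rearranges 0..r-1, and absl::c_is_permutation checks exactly that against the iota'd vector. One visible difference from the deleted loop: an out-of-range index now simply makes the comparison fail and the function return false, where the old code would CHECK-fail. A few illustrative calls (a sketch; int64 is XLA's alias for a 64-bit signed integer):

    void IsPermutationExamples() {
      std::vector<int64> rearranged = {2, 0, 1};
      CHECK(IsPermutation(rearranged, 3));   // rearranges 0..2
      std::vector<int64> duplicated = {0, 0, 1};
      CHECK(!IsPermutation(duplicated, 3));  // duplicate entry
      std::vector<int64> too_short = {0, 1};
      CHECK(!IsPermutation(too_short, 3));   // size != rank
      std::vector<int64> rank1 = {0};
      CHECK(IsPermutation(rank1, 1));        // the Rank1Transpose case above
    }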