[XLA] Canonicalize dot dimension numbers on the CPU and GPU backends for cases that previously failed shape inference.
PiperOrigin-RevId: 228377214
commit a3d639a5a4
parent 67d8a0330b
@@ -1873,8 +1873,9 @@ cc_library(
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/strings",
    ],
)
@@ -15,6 +15,8 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/dot_decomposer.h"

#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -156,29 +158,187 @@ Status DecomposeBatchDot(HloInstruction* dot) {
  return computation->ReplaceInstruction(dot, new_dot);
}

// Convert a dot into a canonical form where non-contracting and contracting
// dimensions are reshaped together and batch dimensions are the most major
// dimensions. This requires transposing and reshaping the lhs and rhs, and
// reshaping the output batch to the original shape.
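//
// For illustration (hypothetical shapes, not taken from this change): an lhs
// of f32[2,3,4,5] with batch={0} and contracting={3} becomes f32[2,12,5], an
// rhs of f32[2,5,6,7] with batch={0} and contracting={1} becomes f32[2,5,42],
// the canonical dot produces f32[2,12,42], and the result is reshaped back to
// the original dot shape f32[2,3,4,6,7].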
Status CanonicalizeDot(HloInstruction* original_dot) {
  auto computation = original_dot->parent();
  const auto& original_dnums = original_dot->dot_dimension_numbers();
  const int64 num_batch_dims = original_dnums.lhs_batch_dimensions_size();
  const int64 num_contracting_dims =
      original_dnums.lhs_contracting_dimensions_size();

  const auto& lhs_shape = original_dot->operand(0)->shape();
  const int64 lhs_rank = lhs_shape.rank();
  const int64 num_lhs_non_contracting_dims =
      lhs_rank - num_batch_dims - num_contracting_dims;

  std::vector<int64> lhs_non_contracting_dims;
  lhs_non_contracting_dims.reserve(num_lhs_non_contracting_dims);
  int64 lhs_contracting_size = 1;
  int64 lhs_non_contracting_size = 1;
  std::vector<int64> batch_dim_sizes;
  batch_dim_sizes.reserve(num_batch_dims);
  for (int64 i = 0; i < lhs_rank; ++i) {
    if (absl::c_linear_search(original_dnums.lhs_contracting_dimensions(), i)) {
      lhs_contracting_size *= lhs_shape.dimensions(i);
    } else if (absl::c_linear_search(original_dnums.lhs_batch_dimensions(),
                                     i)) {
      batch_dim_sizes.push_back(lhs_shape.dimensions(i));
    } else {
      lhs_non_contracting_dims.push_back(i);
      lhs_non_contracting_size *= lhs_shape.dimensions(i);
    }
  }
  // The canonical form of the lhs is
  // [BatchDims, NonContractingDims, ContractingDims]
  std::vector<int64> lhs_transpose;
  lhs_transpose.reserve(lhs_rank);
  lhs_transpose.insert(lhs_transpose.end(),
                       original_dnums.lhs_batch_dimensions().begin(),
                       original_dnums.lhs_batch_dimensions().end());
  lhs_transpose.insert(lhs_transpose.end(), lhs_non_contracting_dims.begin(),
                       lhs_non_contracting_dims.end());
  lhs_transpose.insert(lhs_transpose.end(),
                       original_dnums.lhs_contracting_dimensions().begin(),
                       original_dnums.lhs_contracting_dimensions().end());
  HloInstruction* transposed_lhs =
      computation->AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::PermuteDimensions(InversePermutation(lhs_transpose),
                                       lhs_shape),
          original_dot->mutable_operand(0), lhs_transpose));
  std::vector<int64> lhs_reshape_dims = batch_dim_sizes;
  lhs_reshape_dims.push_back(lhs_non_contracting_size);
  lhs_reshape_dims.push_back(lhs_contracting_size);
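  // lhs_reshape_dims is now [batch dim sizes..., merged non-contracting size,
  // merged contracting size].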
  // Reshape the contracting and non-contracting dimensions together.
  HloInstruction* reshaped_lhs =
      computation->AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(lhs_shape.element_type(), lhs_reshape_dims),
          transposed_lhs));

  const auto& rhs_shape = original_dot->operand(1)->shape();
  const int64 rhs_rank = rhs_shape.rank();
  const int64 num_rhs_non_contracting_dims =
      rhs_rank - num_batch_dims - num_contracting_dims;
  std::vector<int64> rhs_non_contracting_dims;
  rhs_non_contracting_dims.reserve(num_rhs_non_contracting_dims);
  int64 rhs_non_contracting_size = 1;
  int64 rhs_contracting_size = 1;
  for (int64 i = 0; i < rhs_rank; ++i) {
    if (absl::c_linear_search(original_dnums.rhs_contracting_dimensions(), i)) {
      rhs_contracting_size *= rhs_shape.dimensions(i);
    } else if (!absl::c_linear_search(original_dnums.rhs_batch_dimensions(),
                                      i)) {
      rhs_non_contracting_dims.push_back(i);
      rhs_non_contracting_size *= rhs_shape.dimensions(i);
    }
  }

  // The canonical form of the rhs is
  // [BatchDims, ContractingDims, NonContractingDims]
  std::vector<int64> rhs_transpose;
  rhs_transpose.reserve(rhs_rank);
  rhs_transpose.insert(rhs_transpose.end(),
                       original_dnums.rhs_batch_dimensions().begin(),
                       original_dnums.rhs_batch_dimensions().end());
  rhs_transpose.insert(rhs_transpose.end(),
                       original_dnums.rhs_contracting_dimensions().begin(),
                       original_dnums.rhs_contracting_dimensions().end());
  rhs_transpose.insert(rhs_transpose.end(), rhs_non_contracting_dims.begin(),
                       rhs_non_contracting_dims.end());
  HloInstruction* transposed_rhs =
      computation->AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::PermuteDimensions(InversePermutation(rhs_transpose),
                                       rhs_shape),
          original_dot->mutable_operand(1), rhs_transpose));

  std::vector<int64> rhs_reshape_dims = batch_dim_sizes;
  rhs_reshape_dims.push_back(rhs_contracting_size);
  rhs_reshape_dims.push_back(rhs_non_contracting_size);
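  // rhs_reshape_dims is now [batch dim sizes..., merged contracting size,
  // merged non-contracting size].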
  // Reshape the contracting and non-contracting dimensions together.
  HloInstruction* reshaped_rhs =
      computation->AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(rhs_shape.element_type(), rhs_reshape_dims),
          transposed_rhs));

  std::vector<int64> dot_dims = batch_dim_sizes;
  dot_dims.push_back(lhs_non_contracting_size);
  dot_dims.push_back(rhs_non_contracting_size);

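  // In the canonical dot, batch dimensions come first on both operands
  // (numbered 0..num_batch_dims-1), the lhs contracting dimension is its most
  // minor dimension, and the rhs contracting dimension immediately follows the
  // batch dimensions.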
  DotDimensionNumbers dot_dnums;
  for (int64 i = 0; i < num_batch_dims; ++i) {
    dot_dnums.add_lhs_batch_dimensions(i);
    dot_dnums.add_rhs_batch_dimensions(i);
  }
  dot_dnums.add_lhs_contracting_dimensions(num_batch_dims + 1);
  dot_dnums.add_rhs_contracting_dimensions(num_batch_dims);

  HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot(
      ShapeUtil::MakeShape(original_dot->shape().element_type(), dot_dims),
      reshaped_lhs, reshaped_rhs, dot_dnums, original_dot->precision_config()));

  return computation->ReplaceInstruction(
      original_dot, computation->AddInstruction(HloInstruction::CreateReshape(
                        original_dot->shape(), dot)));
}

}  // namespace

StatusOr<bool> DotDecomposer::Run(HloModule* module) {
  XLA_VLOG_LINES(2, "DotDecomposer ENTRY\n" + module->ToString());
  // Gather all batch Dot operations.
  std::vector<HloInstruction*> batch_dots;
  // Gather all Non-canonical Dot operations.
  std::vector<HloInstruction*> non_canonical_dots;
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instruction : computation->instructions()) {
      if (instruction->opcode() != HloOpcode::kDot) {
        continue;
      }
      const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers();
      if (dnums.lhs_batch_dimensions_size() > 0 && decompose_batch_dot_) {
        batch_dots.push_back(instruction);
      // A dot is not canonical if there is more than one contracting
      // dimension.
      if (dnums.lhs_contracting_dimensions_size() > 1) {
        non_canonical_dots.push_back(instruction);
        continue;
      }
      if (dnums.lhs_batch_dimensions().empty()) {
        continue;
      }
      std::vector<int64> canonical_batch_dims(
          dnums.lhs_batch_dimensions_size());
      absl::c_iota(canonical_batch_dims, 0);
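      // canonical_batch_dims is now {0, 1, ..., n-1}; a dot whose batch
      // dimensions appear in any other order (e.g. lhs_batch_dims={2,0}) is
      // treated as non-canonical.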
      if (!absl::c_equal(dnums.lhs_batch_dimensions(), canonical_batch_dims) ||
          !absl::c_equal(dnums.rhs_batch_dimensions(), canonical_batch_dims)) {
        non_canonical_dots.push_back(instruction);
      }
    }
  }
  // Decompose each batch Dot in 'batch_dots'.
  bool changed = false;
  for (auto* dot : batch_dots) {
    TF_RETURN_IF_ERROR(DecomposeBatchDot(dot));
  for (auto* dot : non_canonical_dots) {
    TF_RETURN_IF_ERROR(CanonicalizeDot(dot));
    changed = true;
  }

  if (decompose_batch_dot_) {
    std::vector<HloInstruction*> batch_dots;
    for (auto* computation : module->MakeNonfusionComputations()) {
      for (auto* instruction : computation->instructions()) {
        if (instruction->opcode() != HloOpcode::kDot) {
          continue;
        }
        const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers();
        if (!dnums.lhs_batch_dimensions().empty()) {
          batch_dots.push_back(instruction);
        }
      }
    }
    // Decompose each batch Dot in 'batch_dots'.

    for (auto* dot : batch_dots) {
      TF_RETURN_IF_ERROR(DecomposeBatchDot(dot));
      changed = true;
    }
  }
  XLA_VLOG_LINES(2, "DotDecompose EXIT\n" + module->ToString());
  return changed;
}
@@ -699,6 +699,7 @@ cc_library(
        "//tensorflow/compiler/xla/service:call_inliner",
        "//tensorflow/compiler/xla/service:conditional_simplifier",
        "//tensorflow/compiler/xla/service:convolution_group_converter",
        "//tensorflow/compiler/xla/service:dot_decomposer",
        "//tensorflow/compiler/xla/service:dynamic_index_splitter",
        "//tensorflow/compiler/xla/service:executable",
        "//tensorflow/compiler/xla/service:flatten_call_graph",
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
@@ -165,6 +166,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
      // We need a cost model for GPUs. Currently, do nothing.
      return false;
    };
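    // The constructor argument is decompose_batch_dot; passing false makes the
    // pass only canonicalize dot dimension numbers without decomposing batch
    // dots.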
    pipeline.AddPass<DotDecomposer>(false);
    pipeline.AddPass<ConvolutionGroupConverter>(
        cost_model,
        /*convert_batch_groups_only=*/true);
@@ -711,6 +711,7 @@ xla_test(
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/compiler/xla/tests:test_utils",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -768,6 +769,7 @@ xla_test(
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/compiler/xla/tests:test_utils",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
@@ -1147,5 +1148,38 @@ XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) {

  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}

class DotOperationTextTest : public HloTestBase {};

XLA_TEST_F(DotOperationTextTest, DotReorderedDotDims) {
  absl::string_view hlo_string =
      R"(
HloModule ComplexDotMultipleNonContracting

ENTRY %test {
  %lhs = f32[7,17,10,13]{3,2,1,0} parameter(0)
  %rhs = f32[7,9,10,13,6]{4,3,2,1,0} parameter(1)
  ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={2,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={3}, rhs_contracting_dims={3}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3}));
}

XLA_TEST_F(DotOperationTextTest, DotReorderedDotDimsAndMultipleContracting) {
  absl::string_view hlo_string =
      R"(
HloModule ComplexDotMultipleNonContracting

ENTRY %test {
  %lhs = f32[7,5,17,10,13]{4,3,2,1,0} parameter(0)
  %rhs = f32[7,9,10,13,6,5]{5,4,3,2,1,0} parameter(1)
  ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={3,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={1,4}, rhs_contracting_dims={5,3}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3}));
}

}  // namespace
}  // namespace xla