[XLA] Canonicalize dot dimension numbers on the CPU and GPU backends for cases that previously failed shape inference.

PiperOrigin-RevId: 228377214
Blake Hechtman 2019-01-08 12:22:54 -08:00 committed by TensorFlower Gardener
parent 67d8a0330b
commit a3d639a5a4
6 changed files with 208 additions and 8 deletions


@@ -1873,8 +1873,9 @@ cc_library(
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/strings",
    ],
)


@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -156,29 +158,187 @@ Status DecomposeBatchDot(HloInstruction* dot) {
  return computation->ReplaceInstruction(dot, new_dot);
}

// Convert a dot into a canonical form where non-contracting and contracting
// dimensions are reshaped together and batch dimensions are the most major
// dimensions. This requires transposing and reshaping the lhs and rhs, and
// then reshaping the output back to the original shape.
Status CanonicalizeDot(HloInstruction* original_dot) {
  auto computation = original_dot->parent();
  const auto& original_dnums = original_dot->dot_dimension_numbers();
  const int64 num_batch_dims = original_dnums.lhs_batch_dimensions_size();
  const int64 num_contracting_dims =
      original_dnums.lhs_contracting_dimensions_size();

  const auto& lhs_shape = original_dot->operand(0)->shape();
  const int64 lhs_rank = lhs_shape.rank();
  const int64 num_lhs_non_contracting_dims =
      lhs_rank - num_batch_dims - num_contracting_dims;

  std::vector<int64> lhs_non_contracting_dims;
  lhs_non_contracting_dims.reserve(num_lhs_non_contracting_dims);
  int64 lhs_contracting_size = 1;
  int64 lhs_non_contracting_size = 1;
  std::vector<int64> batch_dim_sizes;
  batch_dim_sizes.reserve(num_batch_dims);
  for (int64 i = 0; i < lhs_rank; ++i) {
    if (absl::c_linear_search(original_dnums.lhs_contracting_dimensions(), i)) {
      lhs_contracting_size *= lhs_shape.dimensions(i);
    } else if (absl::c_linear_search(original_dnums.lhs_batch_dimensions(),
                                     i)) {
      batch_dim_sizes.push_back(lhs_shape.dimensions(i));
    } else {
      lhs_non_contracting_dims.push_back(i);
      lhs_non_contracting_size *= lhs_shape.dimensions(i);
    }
  }
  // The canonical form of the lhs is
  // [BatchDims, NonContractingDims, ContractingDims].
  std::vector<int64> lhs_transpose;
  lhs_transpose.reserve(lhs_rank);
  lhs_transpose.insert(lhs_transpose.end(),
                       original_dnums.lhs_batch_dimensions().begin(),
                       original_dnums.lhs_batch_dimensions().end());
  lhs_transpose.insert(lhs_transpose.end(), lhs_non_contracting_dims.begin(),
                       lhs_non_contracting_dims.end());
  lhs_transpose.insert(lhs_transpose.end(),
                       original_dnums.lhs_contracting_dimensions().begin(),
                       original_dnums.lhs_contracting_dimensions().end());
  HloInstruction* transposed_lhs =
      computation->AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::PermuteDimensions(InversePermutation(lhs_transpose),
                                       lhs_shape),
          original_dot->mutable_operand(0), lhs_transpose));

  std::vector<int64> lhs_reshape_dims = batch_dim_sizes;
  lhs_reshape_dims.push_back(lhs_non_contracting_size);
  lhs_reshape_dims.push_back(lhs_contracting_size);
  // Reshape the contracting and non-contracting dimensions together.
  HloInstruction* reshaped_lhs =
      computation->AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(lhs_shape.element_type(), lhs_reshape_dims),
          transposed_lhs));

  const auto& rhs_shape = original_dot->operand(1)->shape();
  const int64 rhs_rank = rhs_shape.rank();
  const int64 num_rhs_non_contracting_dims =
      rhs_rank - num_batch_dims - num_contracting_dims;

  std::vector<int64> rhs_non_contracting_dims;
  rhs_non_contracting_dims.reserve(num_rhs_non_contracting_dims);
  int64 rhs_non_contracting_size = 1;
  int64 rhs_contracting_size = 1;
  for (int64 i = 0; i < rhs_rank; ++i) {
    if (absl::c_linear_search(original_dnums.rhs_contracting_dimensions(), i)) {
      rhs_contracting_size *= rhs_shape.dimensions(i);
    } else if (!absl::c_linear_search(original_dnums.rhs_batch_dimensions(),
                                      i)) {
      rhs_non_contracting_dims.push_back(i);
      rhs_non_contracting_size *= rhs_shape.dimensions(i);
    }
  }
  // The canonical form of the rhs is
  // [BatchDims, ContractingDims, NonContractingDims].
  std::vector<int64> rhs_transpose;
  rhs_transpose.reserve(rhs_rank);
  rhs_transpose.insert(rhs_transpose.end(),
                       original_dnums.rhs_batch_dimensions().begin(),
                       original_dnums.rhs_batch_dimensions().end());
  rhs_transpose.insert(rhs_transpose.end(),
                       original_dnums.rhs_contracting_dimensions().begin(),
                       original_dnums.rhs_contracting_dimensions().end());
  rhs_transpose.insert(rhs_transpose.end(), rhs_non_contracting_dims.begin(),
                       rhs_non_contracting_dims.end());
  HloInstruction* transposed_rhs =
      computation->AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::PermuteDimensions(InversePermutation(rhs_transpose),
                                       rhs_shape),
          original_dot->mutable_operand(1), rhs_transpose));

  std::vector<int64> rhs_reshape_dims = batch_dim_sizes;
  rhs_reshape_dims.push_back(rhs_contracting_size);
  rhs_reshape_dims.push_back(rhs_non_contracting_size);
  // Reshape the contracting and non-contracting dimensions together.
  HloInstruction* reshaped_rhs =
      computation->AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(rhs_shape.element_type(), rhs_reshape_dims),
          transposed_rhs));

  std::vector<int64> dot_dims = batch_dim_sizes;
  dot_dims.push_back(lhs_non_contracting_size);
  dot_dims.push_back(rhs_non_contracting_size);

  // In the canonical dot, the lhs is [BatchDims, NonContracting, Contracting]
  // and the rhs is [BatchDims, Contracting, NonContracting], so the
  // contracting dimension is num_batch_dims + 1 on the lhs and num_batch_dims
  // on the rhs.
  DotDimensionNumbers dot_dnums;
  for (int64 i = 0; i < num_batch_dims; ++i) {
    dot_dnums.add_lhs_batch_dimensions(i);
    dot_dnums.add_rhs_batch_dimensions(i);
  }
  dot_dnums.add_lhs_contracting_dimensions(num_batch_dims + 1);
  dot_dnums.add_rhs_contracting_dimensions(num_batch_dims);

  HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot(
      ShapeUtil::MakeShape(original_dot->shape().element_type(), dot_dims),
      reshaped_lhs, reshaped_rhs, dot_dnums, original_dot->precision_config()));

  return computation->ReplaceInstruction(
      original_dot, computation->AddInstruction(HloInstruction::CreateReshape(
                        original_dot->shape(), dot)));
}

}  // namespace

StatusOr<bool> DotDecomposer::Run(HloModule* module) {
  XLA_VLOG_LINES(2, "DotDecomposer ENTRY\n" + module->ToString());
  // Gather all non-canonical Dot operations.
  std::vector<HloInstruction*> non_canonical_dots;
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instruction : computation->instructions()) {
      if (instruction->opcode() != HloOpcode::kDot) {
        continue;
      }
      const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers();
      // A dot is not canonical if it has more than one contracting
      // dimension.
      if (dnums.lhs_contracting_dimensions_size() > 1) {
        non_canonical_dots.push_back(instruction);
        continue;
      }
      if (dnums.lhs_batch_dimensions().empty()) {
        continue;
      }
      // A dot is also not canonical if its batch dimensions are not the
      // most major dimensions, in order, on both operands.
      std::vector<int64> canonical_batch_dims(
          dnums.lhs_batch_dimensions_size());
      absl::c_iota(canonical_batch_dims, 0);
      if (!absl::c_equal(dnums.lhs_batch_dimensions(), canonical_batch_dims) ||
          !absl::c_equal(dnums.rhs_batch_dimensions(), canonical_batch_dims)) {
        non_canonical_dots.push_back(instruction);
      }
    }
  }
  bool changed = false;
  // Canonicalize each Dot in 'non_canonical_dots'.
  for (auto* dot : non_canonical_dots) {
    TF_RETURN_IF_ERROR(CanonicalizeDot(dot));
    changed = true;
  }
  if (decompose_batch_dot_) {
    std::vector<HloInstruction*> batch_dots;
    for (auto* computation : module->MakeNonfusionComputations()) {
      for (auto* instruction : computation->instructions()) {
        if (instruction->opcode() != HloOpcode::kDot) {
          continue;
        }
        const DotDimensionNumbers& dnums =
            instruction->dot_dimension_numbers();
        if (!dnums.lhs_batch_dimensions().empty()) {
          batch_dots.push_back(instruction);
        }
      }
    }
    // Decompose each batch Dot in 'batch_dots'.
    for (auto* dot : batch_dots) {
      TF_RETURN_IF_ERROR(DecomposeBatchDot(dot));
      changed = true;
    }
  }
  XLA_VLOG_LINES(2, "DotDecompose EXIT\n" + module->ToString());
  return changed;
}
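
For reference, the gathering loop above amounts to filtering dots with a predicate along the following lines. This is an illustrative sketch, not code from this commit; IsCanonicalDot is a hypothetical name, and the sketch assumes the lhs and rhs agree on the number of batch dimensions, which shape inference enforces for a valid dot.

// Sketch: a dot is canonical when it has at most one contracting dimension
// and its batch dimensions are exactly {0, 1, ..., n-1} on both operands.
bool IsCanonicalDot(const DotDimensionNumbers& dnums) {
  if (dnums.lhs_contracting_dimensions_size() > 1) {
    return false;  // Multiple contracting dimensions must be flattened first.
  }
  for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) {
    // Batch dimensions must be the most major dimensions, in order.
    if (dnums.lhs_batch_dimensions(i) != i ||
        dnums.rhs_batch_dimensions(i) != i) {
      return false;
    }
  }
  return true;
}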


@@ -699,6 +699,7 @@ cc_library(
        "//tensorflow/compiler/xla/service:call_inliner",
        "//tensorflow/compiler/xla/service:conditional_simplifier",
        "//tensorflow/compiler/xla/service:convolution_group_converter",
        "//tensorflow/compiler/xla/service:dot_decomposer",
        "//tensorflow/compiler/xla/service:dynamic_index_splitter",
        "//tensorflow/compiler/xla/service:executable",
        "//tensorflow/compiler/xla/service:flatten_call_graph",


@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
@@ -165,6 +166,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
      // We need a cost model for GPUs. Currently, do nothing.
      return false;
    };
    pipeline.AddPass<DotDecomposer>(false);
    pipeline.AddPass<ConvolutionGroupConverter>(
        cost_model,
        /*convert_batch_groups_only=*/true);
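
The GPU pipeline adds the pass purely for canonicalization. Assuming the constructor parameter corresponds to the decompose_batch_dot_ member seen in dot_decomposer.cc (an assumption; this commit does not show the header), the call could equivalently be written with an argument comment:

// Hypothetical clearer spelling of the same call; the parameter name is
// assumed from the pass's decompose_batch_dot_ member.
pipeline.AddPass<DotDecomposer>(/*decompose_batch_dot=*/false);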


@@ -711,6 +711,7 @@ xla_test(
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/compiler/xla/tests:test_utils",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -768,6 +769,7 @@ xla_test(
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/compiler/xla/tests:test_utils",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",


@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
@@ -1147,5 +1148,38 @@ XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) {
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}

class DotOperationTextTest : public HloTestBase {};

XLA_TEST_F(DotOperationTextTest, DotReorderedDotDims) {
  absl::string_view hlo_string =
      R"(
HloModule ComplexDotMultipleNonContracting

ENTRY %test {
  %lhs = f32[7,17,10,13]{3,2,1,0} parameter(0)
  %rhs = f32[7,9,10,13,6]{4,3,2,1,0} parameter(1)
  ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={2,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={3}, rhs_contracting_dims={3}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3}));
}

XLA_TEST_F(DotOperationTextTest, DotReorderedDotDimsAndMultipleContracting) {
  absl::string_view hlo_string =
      R"(
HloModule ComplexDotMultipleNonContracting

ENTRY %test {
  %lhs = f32[7,5,17,10,13]{4,3,2,1,0} parameter(0)
  %rhs = f32[7,9,10,13,6,5]{5,4,3,2,1,0} parameter(1)
  ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={3,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={1,4}, rhs_contracting_dims={5,3}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3}));
}
}  // namespace
}  // namespace xla
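
To see the canonicalization concretely, the following standalone sketch (illustrative only, not part of the commit; it uses plain C++ rather than the XLA utilities) reproduces the transpose and reshape bookkeeping for the lhs of the second test, f32[7,5,17,10,13] with batch dims {3,0} and contracting dims {1,4}:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> dims = {7, 5, 17, 10, 13};
  const std::vector<int64_t> batch = {3, 0};
  const std::vector<int64_t> contracting = {1, 4};
  auto contains = [](const std::vector<int64_t>& v, int64_t x) {
    return std::find(v.begin(), v.end(), x) != v.end();
  };
  // Canonical lhs order: [BatchDims, NonContractingDims, ContractingDims].
  std::vector<int64_t> transpose(batch);
  int64_t non_contracting_size = 1;
  int64_t contracting_size = 1;
  for (int64_t i = 0; i < static_cast<int64_t>(dims.size()); ++i) {
    if (contains(contracting, i)) {
      contracting_size *= dims[i];
    } else if (!contains(batch, i)) {
      transpose.push_back(i);
      non_contracting_size *= dims[i];
    }
  }
  transpose.insert(transpose.end(), contracting.begin(), contracting.end());
  // The reshape flattens each group into a single dimension.
  std::vector<int64_t> reshape;
  for (int64_t b : batch) reshape.push_back(dims[b]);
  reshape.push_back(non_contracting_size);
  reshape.push_back(contracting_size);
  for (int64_t t : transpose) std::cout << t << " ";  // 3 0 2 1 4
  std::cout << "\n";
  for (int64_t r : reshape) std::cout << r << " ";  // 10 7 17 65
  std::cout << "\n";
  return 0;
}

The rhs canonicalizes analogously to [10,7,65,54] (contracting before non-contracting), so the canonical dot produces f32[10,7,17,54], which the final reshape expands back to the expected f32[10,7,17,9,6].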