[XLA] Remove batch dot decomposition from DotDecomposer

This is unused and confusing. The dot canonicalization is still there, which is now the one and only purpose of DotDecomposer.
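At the call sites touched below, the change amounts to dropping the constructor flag (see the CPU and GPU pipeline diffs):

  // before
  pipeline.AddPass<DotDecomposer>(/*decompose_batch_dot=*/false);
  // after
  pipeline.AddPass<DotDecomposer>();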

PiperOrigin-RevId: 241509598
Benjamin Kramer 2019-04-02 05:30:51 -07:00 committed by TensorFlower Gardener
parent 850e30a5cb
commit b33476906d
4 changed files with 5 additions and 161 deletions


@@ -269,7 +269,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
// pass.
pipeline.AddPass<CallInliner>();
pipeline.AddPass<BatchDotSimplification>();
- pipeline.AddPass<DotDecomposer>(/*decompose_batch_dot=*/false);
+ pipeline.AddPass<DotDecomposer>();
auto cost_model = [](HloInstruction* conv) {
// We need a cost model for CPUs. Currently, do nothing.
return false;


@@ -29,135 +29,6 @@ namespace xla {
namespace {
- // TODO(b/69062148) Remove this code when all backends support BatchDot
- // natively.
- Status DecomposeBatchDot(HloInstruction* dot) {
- auto computation = dot->parent();
- const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
- HloInstruction* lhs = dot->mutable_operand(0);
- HloInstruction* rhs = dot->mutable_operand(1);
- const Shape& lhs_shape = lhs->shape();
- const Shape& rhs_shape = rhs->shape();
- const Shape& dot_shape = dot->shape();
- // ShapeInference should guarantee that lhs/rhs batch dimensions match.
- CHECK_EQ(dnums.lhs_batch_dimensions_size(),
- dnums.rhs_batch_dimensions_size());
- const int64 num_batch_dims = dnums.lhs_batch_dimensions_size();
- // Calculate total batch size (note that ShapeInference requires that
- // the batch dimensions are most-major).
- int64 batch_size = 1;
- for (int i = 0; i < num_batch_dims; ++i) {
- CHECK_EQ(lhs_shape.dimensions(dnums.lhs_batch_dimensions(i)),
- rhs_shape.dimensions(dnums.rhs_batch_dimensions(i)));
- batch_size *= lhs_shape.dimensions(dnums.lhs_batch_dimensions(i));
- }
- // Set lhs/rhs_transpose.
- CHECK_EQ(1, dnums.lhs_contracting_dimensions_size());
- const int64 lhs_contracting_dim_number = dnums.lhs_contracting_dimensions(0);
- const bool lhs_transpose = (lhs_contracting_dim_number - num_batch_dims) == 0;
- CHECK_EQ(1, dnums.rhs_contracting_dimensions_size());
- const int64 rhs_contracting_dim_number = dnums.rhs_contracting_dimensions(0);
- const bool rhs_transpose = (rhs_contracting_dim_number - num_batch_dims) == 1;
- // Compute R3 and R3 shapes for lhs.
- PrimitiveType lhs_type = lhs_shape.element_type();
- const int64 lhs_rows = lhs_shape.dimensions(num_batch_dims + 0);
- const int64 lhs_cols = lhs_shape.dimensions(num_batch_dims + 1);
- Shape lhs_shape_r3 =
- ShapeUtil::MakeShape(lhs_type, {batch_size, lhs_rows, lhs_cols});
- Shape lhs_slice_shape_r3 =
- ShapeUtil::MakeShape(lhs_type, {1, lhs_rows, lhs_cols});
- Shape lhs_slice_shape_r2 =
- ShapeUtil::MakeShape(lhs_type, {lhs_rows, lhs_cols});
- // Compute R3 and R3 shapes for rhs.
- PrimitiveType rhs_type = rhs_shape.element_type();
- const int64 rhs_rows = rhs_shape.dimensions(num_batch_dims + 0);
- const int64 rhs_cols = rhs_shape.dimensions(num_batch_dims + 1);
- Shape rhs_shape_r3 =
- ShapeUtil::MakeShape(rhs_type, {batch_size, rhs_rows, rhs_cols});
- Shape rhs_slice_shape_r3 =
- ShapeUtil::MakeShape(rhs_type, {1, rhs_rows, rhs_cols});
- Shape rhs_slice_shape_r2 =
- ShapeUtil::MakeShape(rhs_type, {rhs_rows, rhs_cols});
- // Compute R3 and R3 shapes for dot output.
- PrimitiveType dot_type = dot_shape.element_type();
- const int64 dot_rows = dot_shape.dimensions(num_batch_dims + 0);
- const int64 dot_cols = dot_shape.dimensions(num_batch_dims + 1);
- Shape dot_shape_r2 = ShapeUtil::MakeShape(dot_type, {dot_rows, dot_cols});
- Shape dot_shape_r3 = ShapeUtil::MakeShape(dot_type, {1, dot_rows, dot_cols});
- Shape concat_shape_r3 =
- ShapeUtil::MakeShape(dot_type, {batch_size, dot_rows, dot_cols});
- // Reshape lhs/rhs into R3.
- auto lhs_r3 = computation->AddInstruction(
- HloInstruction::CreateReshape(lhs_shape_r3, lhs));
- auto rhs_r3 = computation->AddInstruction(
- HloInstruction::CreateReshape(rhs_shape_r3, rhs));
- // Loop through batch size, slicing out required lhs/rhs to compute each Dot.
- std::vector<HloInstruction*> output_slices(batch_size);
- for (int64 i = 0; i < batch_size; ++i) {
- // Slice R3 shape from 'lhs' and reshape to R2.
- auto lhs_slice_r3 = computation->AddInstruction(
- HloInstruction::CreateSlice(lhs_slice_shape_r3, lhs_r3, {i, 0, 0},
- {i + 1, lhs_rows, lhs_cols}, {1, 1, 1}));
- auto lhs_slice_r2 = computation->AddInstruction(
- HloInstruction::CreateReshape(lhs_slice_shape_r2, lhs_slice_r3));
- // Slice R3 shape from 'rhs' and reshape to R2.
- auto rhs_slice_r3 = computation->AddInstruction(
- HloInstruction::CreateSlice(rhs_slice_shape_r3, rhs_r3, {i, 0, 0},
- {i + 1, rhs_rows, rhs_cols}, {1, 1, 1}));
- auto rhs_slice_r2 = computation->AddInstruction(
- HloInstruction::CreateReshape(rhs_slice_shape_r2, rhs_slice_r3));
- // Transpose lhs/rhs (if needed).
- if (lhs_transpose) {
- Shape lhs_slice_shape_r2_transpose =
- ShapeUtil::MakeShape(lhs_type, {lhs_cols, lhs_rows});
- lhs_slice_r2 =
- computation->AddInstruction(HloInstruction::CreateTranspose(
- lhs_slice_shape_r2_transpose, lhs_slice_r2, {1, 0}));
- }
- if (rhs_transpose) {
- Shape rhs_slice_shape_r2_transpose =
- ShapeUtil::MakeShape(rhs_type, {rhs_cols, rhs_rows});
- rhs_slice_r2 =
- computation->AddInstruction(HloInstruction::CreateTranspose(
- rhs_slice_shape_r2_transpose, rhs_slice_r2, {1, 0}));
- }
- // Compute Dot of lhs/rhs R2 slices.
- DotDimensionNumbers dot_dnums;
- dot_dnums.add_lhs_contracting_dimensions(1);
- dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot_r2 = computation->AddInstruction(
- HloInstruction::CreateDot(dot_shape_r2, lhs_slice_r2, rhs_slice_r2,
- dot_dnums, dot->precision_config()));
- // Reshape Dot to R3 so we can concat along batch dimension.
- auto dot_r3 = computation->AddInstruction(
- HloInstruction::CreateReshape(dot_shape_r3, dot_r2));
- output_slices[i] = dot_r3;
- }
- // Concatenate slices from 'output_slices' along batch dimension.
- auto concat = computation->AddInstruction(
- HloInstruction::CreateConcatenate(concat_shape_r3, output_slices, 0));
- // Reshape output 'new_dot' to original dimensions.
- auto new_dot = computation->AddInstruction(
- HloInstruction::CreateReshape(dot_shape, concat));
- // Replace all uses of 'dot' in 'computation' with 'new_dot'.
- return computation->ReplaceInstruction(dot, new_dot);
- }
// Convert a dot into a canonical form where non-contracting and contracting
// dimensions are reshaped together and batch dimensions are the most major
// dimensions. This requires transposes and reshapes of the lhs and rhs, and a
// reshape of the result back to the dot's original shape.
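To make the canonical form concrete, here is a minimal standalone sketch of the shape arithmetic, assuming the dimension sizes have already been grouped into batch / non-contracting / contracting sets (hypothetical helper, not code from this commit; shown for an lhs operand -- the rhs would swap the last two groups):

  #include <cstdint>
  #include <functional>
  #include <numeric>
  #include <vector>

  // Batch dims stay most-major; all non-contracting dims fold into one
  // dimension and all contracting dims fold into another, so a canonical
  // dot operand has rank batch_dims.size() + 2.
  std::vector<int64_t> CanonicalLhsShape(
      const std::vector<int64_t>& batch_dims,
      const std::vector<int64_t>& noncontracting_dims,
      const std::vector<int64_t>& contracting_dims) {
    auto product = [](const std::vector<int64_t>& dims) {
      return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                             std::multiplies<int64_t>());
    };
    std::vector<int64_t> shape(batch_dims);         // unchanged, most-major
    shape.push_back(product(noncontracting_dims));  // folded together
    shape.push_back(product(contracting_dims));     // folded together
    return shape;
  }

  // E.g. an lhs of shape [2,3,4,5] with batch {2}, non-contracting {3,4}
  // and contracting {5} canonicalizes to [2, 12, 5].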
@@ -323,27 +194,6 @@ StatusOr<bool> DotDecomposer::Run(HloModule* module) {
TF_RETURN_IF_ERROR(CanonicalizeDot(dot));
changed = true;
}
- if (decompose_batch_dot_) {
- std::vector<HloInstruction*> batch_dots;
- for (auto* computation : module->MakeNonfusionComputations()) {
- for (auto* instruction : computation->instructions()) {
- if (instruction->opcode() != HloOpcode::kDot) {
- continue;
- }
- const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers();
- if (!dnums.lhs_batch_dimensions().empty()) {
- batch_dots.push_back(instruction);
- }
- }
- }
- // Decompose each batch Dot in 'batch_dots'.
- for (auto* dot : batch_dots) {
- TF_RETURN_IF_ERROR(DecomposeBatchDot(dot));
- changed = true;
- }
- }
XLA_VLOG_LINES(2, "DotDecompose EXIT\n" + module->ToString());
return changed;
}
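For reference, the DecomposeBatchDot path deleted above computed each batch element with its own R2 dot and concatenated the results. A minimal sketch of that semantics in plain C++ (illustrative only, with hypothetical fixed sizes; it ignores the transpose handling the real code performed):

  // Batch dot [B,M,K] x [B,K,N] -> [B,M,N], lowered to B independent
  // R2 (matrix-matrix) dots, one per batch element.
  constexpr int B = 2, M = 2, K = 3, N = 2;

  void BatchDotViaR2Dots(const float lhs[B][M][K], const float rhs[B][K][N],
                         float out[B][M][N]) {
    for (int b = 0; b < B; ++b) {    // "slice" out batch element b
      for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
          float acc = 0.f;
          for (int k = 0; k < K; ++k) acc += lhs[b][m][k] * rhs[b][k][n];
          out[b][m][n] = acc;        // writing into out[b] plays the role
        }                            // of the concatenate along batch dim
      }
    }
  }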


@@ -21,22 +21,16 @@ limitations under the License.
namespace xla {
- // DotDecomposer is a pass which decomposes batch Dot operations into a
- // sequence of smaller (R2) Dot operations.
+ // DotDecomposer is a pass which converts dots into a canonical form where
+ // non-contracting and contracting dimensions are reshaped together and batch
+ // dimensions are the most major dimensions.
class DotDecomposer : public HloModulePass {
public:
- // Decomposes batch Dot operations when 'decompose_batch_dot' is true.
- DotDecomposer(bool decompose_batch_dot = true)
- : decompose_batch_dot_(decompose_batch_dot) {}
~DotDecomposer() = default;
absl::string_view name() const override { return "dot_decomposer"; }
// Run DotDecomposer pass on computations in 'module'.
// Returns whether the 'module' was changed.
StatusOr<bool> Run(HloModule* module) override;
- private:
- bool decompose_batch_dot_;
};
} // namespace xla
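With the flag gone, typical use of the pass reduces to the following fragment (a sketch, assuming an HloModule* named module is in scope):

  DotDecomposer pass;                         // no decompose_batch_dot argument
  StatusOr<bool> changed = pass.Run(module);  // true iff some dot was rewritten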


@@ -195,7 +195,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// We need a cost model for GPUs. Currently, do nothing.
return false;
};
- pipeline.AddPass<DotDecomposer>(false);
+ pipeline.AddPass<DotDecomposer>();
pipeline.AddPass<ConvolutionGroupConverter>(
cost_model,
/*convert_batch_groups_only=*/true);