[XLA] Remove batch dot decomposition from DotDecomposer
This is unused and confusing. The dot canonicalization is still there, which is now the one and only purpose of DotDecomposer.

PiperOrigin-RevId: 241509598
parent 850e30a5cb
commit b33476906d
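For context, the canonical form that remains DotDecomposer's sole job flattens each dot operand to [batch, non-contracting, contracting], with batch dimensions most-major. A minimal sketch of that shape arithmetic in plain C++ (not the XLA API; the example shapes are hypothetical):

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Flattens a group of dimension extents into a single extent.
int64_t Product(const std::vector<int64_t>& dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  // Hypothetical lhs f32[2,3,4,5] with batch dims {0} and contracting dims
  // {3}: the non-contracting dims {1,2} flatten to 12, so the canonical lhs
  // is f32[2,12,5], i.e. [batch, non-contracting, contracting].
  std::vector<int64_t> batch = {2};
  std::vector<int64_t> non_contracting = {3, 4};
  std::vector<int64_t> contracting = {5};
  std::cout << "canonical lhs: [" << Product(batch) << ", "
            << Product(non_contracting) << ", " << Product(contracting)
            << "]\n";  // prints: canonical lhs: [2, 12, 5]
  return 0;
}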
@@ -269,7 +269,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
   // pass.
   pipeline.AddPass<CallInliner>();
   pipeline.AddPass<BatchDotSimplification>();
-  pipeline.AddPass<DotDecomposer>(/*decompose_batch_dot=*/false);
+  pipeline.AddPass<DotDecomposer>();
   auto cost_model = [](HloInstruction* conv) {
     // We need a cost model for CPUs. Currently, do nothing.
     return false;
@@ -29,135 +29,6 @@ namespace xla {
 
 namespace {
 
-// TODO(b/69062148) Remove this code when all backends support BatchDot
-// natively.
-Status DecomposeBatchDot(HloInstruction* dot) {
-  auto computation = dot->parent();
-  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
-  HloInstruction* lhs = dot->mutable_operand(0);
-  HloInstruction* rhs = dot->mutable_operand(1);
-  const Shape& lhs_shape = lhs->shape();
-  const Shape& rhs_shape = rhs->shape();
-  const Shape& dot_shape = dot->shape();
-
-  // ShapeInference should guarantee that lhs/rhs batch dimensions match.
-  CHECK_EQ(dnums.lhs_batch_dimensions_size(),
-           dnums.rhs_batch_dimensions_size());
-  const int64 num_batch_dims = dnums.lhs_batch_dimensions_size();
-  // Calculate total batch size (note that ShapeInference requires that
-  // the batch dimensions are most-major).
-  int64 batch_size = 1;
-  for (int i = 0; i < num_batch_dims; ++i) {
-    CHECK_EQ(lhs_shape.dimensions(dnums.lhs_batch_dimensions(i)),
-             rhs_shape.dimensions(dnums.rhs_batch_dimensions(i)));
-    batch_size *= lhs_shape.dimensions(dnums.lhs_batch_dimensions(i));
-  }
-
-  // Set lhs/rhs_transpose.
-  CHECK_EQ(1, dnums.lhs_contracting_dimensions_size());
-  const int64 lhs_contracting_dim_number = dnums.lhs_contracting_dimensions(0);
-  const bool lhs_transpose = (lhs_contracting_dim_number - num_batch_dims) == 0;
-
-  CHECK_EQ(1, dnums.rhs_contracting_dimensions_size());
-  const int64 rhs_contracting_dim_number = dnums.rhs_contracting_dimensions(0);
-  const bool rhs_transpose = (rhs_contracting_dim_number - num_batch_dims) == 1;
-
-  // Compute R3 and R2 shapes for lhs.
-  PrimitiveType lhs_type = lhs_shape.element_type();
-  const int64 lhs_rows = lhs_shape.dimensions(num_batch_dims + 0);
-  const int64 lhs_cols = lhs_shape.dimensions(num_batch_dims + 1);
-  Shape lhs_shape_r3 =
-      ShapeUtil::MakeShape(lhs_type, {batch_size, lhs_rows, lhs_cols});
-  Shape lhs_slice_shape_r3 =
-      ShapeUtil::MakeShape(lhs_type, {1, lhs_rows, lhs_cols});
-  Shape lhs_slice_shape_r2 =
-      ShapeUtil::MakeShape(lhs_type, {lhs_rows, lhs_cols});
-
-  // Compute R3 and R2 shapes for rhs.
-  PrimitiveType rhs_type = rhs_shape.element_type();
-  const int64 rhs_rows = rhs_shape.dimensions(num_batch_dims + 0);
-  const int64 rhs_cols = rhs_shape.dimensions(num_batch_dims + 1);
-  Shape rhs_shape_r3 =
-      ShapeUtil::MakeShape(rhs_type, {batch_size, rhs_rows, rhs_cols});
-  Shape rhs_slice_shape_r3 =
-      ShapeUtil::MakeShape(rhs_type, {1, rhs_rows, rhs_cols});
-  Shape rhs_slice_shape_r2 =
-      ShapeUtil::MakeShape(rhs_type, {rhs_rows, rhs_cols});
-
-  // Compute R2 and R3 shapes for dot output.
-  PrimitiveType dot_type = dot_shape.element_type();
-  const int64 dot_rows = dot_shape.dimensions(num_batch_dims + 0);
-  const int64 dot_cols = dot_shape.dimensions(num_batch_dims + 1);
-  Shape dot_shape_r2 = ShapeUtil::MakeShape(dot_type, {dot_rows, dot_cols});
-  Shape dot_shape_r3 = ShapeUtil::MakeShape(dot_type, {1, dot_rows, dot_cols});
-  Shape concat_shape_r3 =
-      ShapeUtil::MakeShape(dot_type, {batch_size, dot_rows, dot_cols});
-
-  // Reshape lhs/rhs into R3.
-  auto lhs_r3 = computation->AddInstruction(
-      HloInstruction::CreateReshape(lhs_shape_r3, lhs));
-  auto rhs_r3 = computation->AddInstruction(
-      HloInstruction::CreateReshape(rhs_shape_r3, rhs));
-
-  // Loop through batch size, slicing out required lhs/rhs to compute each Dot.
-  std::vector<HloInstruction*> output_slices(batch_size);
-  for (int64 i = 0; i < batch_size; ++i) {
-    // Slice R3 shape from 'lhs' and reshape to R2.
-    auto lhs_slice_r3 = computation->AddInstruction(
-        HloInstruction::CreateSlice(lhs_slice_shape_r3, lhs_r3, {i, 0, 0},
-                                    {i + 1, lhs_rows, lhs_cols}, {1, 1, 1}));
-    auto lhs_slice_r2 = computation->AddInstruction(
-        HloInstruction::CreateReshape(lhs_slice_shape_r2, lhs_slice_r3));
-
-    // Slice R3 shape from 'rhs' and reshape to R2.
-    auto rhs_slice_r3 = computation->AddInstruction(
-        HloInstruction::CreateSlice(rhs_slice_shape_r3, rhs_r3, {i, 0, 0},
-                                    {i + 1, rhs_rows, rhs_cols}, {1, 1, 1}));
-    auto rhs_slice_r2 = computation->AddInstruction(
-        HloInstruction::CreateReshape(rhs_slice_shape_r2, rhs_slice_r3));
-
-    // Transpose lhs/rhs (if needed).
-    if (lhs_transpose) {
-      Shape lhs_slice_shape_r2_transpose =
-          ShapeUtil::MakeShape(lhs_type, {lhs_cols, lhs_rows});
-      lhs_slice_r2 =
-          computation->AddInstruction(HloInstruction::CreateTranspose(
-              lhs_slice_shape_r2_transpose, lhs_slice_r2, {1, 0}));
-    }
-    if (rhs_transpose) {
-      Shape rhs_slice_shape_r2_transpose =
-          ShapeUtil::MakeShape(rhs_type, {rhs_cols, rhs_rows});
-      rhs_slice_r2 =
-          computation->AddInstruction(HloInstruction::CreateTranspose(
-              rhs_slice_shape_r2_transpose, rhs_slice_r2, {1, 0}));
-    }
-
-    // Compute Dot of lhs/rhs R2 slices.
-    DotDimensionNumbers dot_dnums;
-    dot_dnums.add_lhs_contracting_dimensions(1);
-    dot_dnums.add_rhs_contracting_dimensions(0);
-    auto dot_r2 = computation->AddInstruction(
-        HloInstruction::CreateDot(dot_shape_r2, lhs_slice_r2, rhs_slice_r2,
-                                  dot_dnums, dot->precision_config()));
-
-    // Reshape Dot to R3 so we can concat along batch dimension.
-    auto dot_r3 = computation->AddInstruction(
-        HloInstruction::CreateReshape(dot_shape_r3, dot_r2));
-
-    output_slices[i] = dot_r3;
-  }
-
-  // Concatenate slices from 'output_slices' along batch dimension.
-  auto concat = computation->AddInstruction(
-      HloInstruction::CreateConcatenate(concat_shape_r3, output_slices, 0));
-  // Reshape output 'new_dot' to original dimensions.
-  auto new_dot = computation->AddInstruction(
-      HloInstruction::CreateReshape(dot_shape, concat));
-
-  // Replace all uses of 'dot' in 'computation' with 'new_dot'.
-  return computation->ReplaceInstruction(dot, new_dot);
-}
 
 // Convert a dot into a canonical form where non-contracting and contracting
 // dimensions are reshaped together and batch dimensions are the most major
 // dimensions. This requires transposing and reshaping the lhs and rhs and
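The DecomposeBatchDot removed above implemented the identity that a batch dot is just a stack of independent 2-D dots, one per batch element. A self-contained sketch of that semantics in plain C++ (illustrative names, not the XLA API):

#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Plain R2 dot: out[i][j] = sum_k lhs[i][k] * rhs[k][j].
Matrix Dot2D(const Matrix& lhs, const Matrix& rhs) {
  const std::size_t rows = lhs.size();
  const std::size_t inner = rhs.size();
  const std::size_t cols = rhs.empty() ? 0 : rhs[0].size();
  Matrix out(rows, std::vector<float>(cols, 0.0f));
  for (std::size_t i = 0; i < rows; ++i)
    for (std::size_t k = 0; k < inner; ++k)
      for (std::size_t j = 0; j < cols; ++j)
        out[i][j] += lhs[i][k] * rhs[k][j];
  return out;
}

// Batch dot as the removed pass realized it: slice out each batch element,
// run a plain 2-D dot on the slice, and stack the results back along the
// most-major batch dimension.
std::vector<Matrix> BatchDot(const std::vector<Matrix>& lhs,
                             const std::vector<Matrix>& rhs) {
  std::vector<Matrix> out;
  out.reserve(lhs.size());
  for (std::size_t b = 0; b < lhs.size(); ++b) {
    out.push_back(Dot2D(lhs[b], rhs[b]));
  }
  return out;
}

The HLO version additionally reshaped each R3 slice down to R2 and back, and transposed a slice whenever its contracting dimension was not in the default position.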
@@ -323,27 +194,6 @@ StatusOr<bool> DotDecomposer::Run(HloModule* module) {
     TF_RETURN_IF_ERROR(CanonicalizeDot(dot));
     changed = true;
   }
-
-  if (decompose_batch_dot_) {
-    std::vector<HloInstruction*> batch_dots;
-    for (auto* computation : module->MakeNonfusionComputations()) {
-      for (auto* instruction : computation->instructions()) {
-        if (instruction->opcode() != HloOpcode::kDot) {
-          continue;
-        }
-        const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers();
-        if (!dnums.lhs_batch_dimensions().empty()) {
-          batch_dots.push_back(instruction);
-        }
-      }
-    }
-    // Decompose each batch Dot in 'batch_dots'.
-    for (auto* dot : batch_dots) {
-      TF_RETURN_IF_ERROR(DecomposeBatchDot(dot));
-      changed = true;
-    }
-  }
   XLA_VLOG_LINES(2, "DotDecompose EXIT\n" + module->ToString());
   return changed;
 }
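One detail of the removed block worth noting: it first collected all batch dots into a vector and only then rewrote them, since replacing instructions while walking computation->instructions() would invalidate the traversal. A generic sketch of that collect-then-mutate idiom (illustrative types, not the XLA API):

#include <vector>

// Illustrative stand-in for an HLO instruction.
struct Instruction {
  bool is_batch_dot = false;
};

void RewriteBatchDots(const std::vector<Instruction*>& instructions) {
  // Pass 1: collect matches without mutating the sequence being iterated.
  std::vector<Instruction*> batch_dots;
  for (Instruction* inst : instructions) {
    if (inst->is_batch_dot) batch_dots.push_back(inst);
  }
  // Pass 2: rewrite the snapshot; it stays valid even if the rewrite adds
  // or removes instructions from the original sequence.
  for (Instruction* dot : batch_dots) {
    dot->is_batch_dot = false;  // stand-in for the actual rewrite
  }
}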
@@ -21,22 +21,16 @@ limitations under the License.
 
 namespace xla {
 
-// DotDecomposer is a pass which decomposes batch Dot operations into a
-// sequence of smaller (R2) Dot operations.
+// DotDecomposer is a pass which converts dots into a canonical form where
+// non-contracting and contracting dimensions are reshaped together and batch
+// dimensions are the most major dimensions.
 class DotDecomposer : public HloModulePass {
  public:
-  // Decomposes batch Dot operations when 'decompose_batch_dot' is true.
-  DotDecomposer(bool decompose_batch_dot = true)
-      : decompose_batch_dot_(decompose_batch_dot) {}
-  ~DotDecomposer() = default;
   absl::string_view name() const override { return "dot_decomposer"; }
 
   // Run DotDecomposer pass on computations in 'module'.
   // Returns whether the 'module' was changed.
   StatusOr<bool> Run(HloModule* module) override;
-
- private:
-  bool decompose_batch_dot_;
 };
 
 } // namespace xla
@@ -195,7 +195,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
       // We need a cost model for GPUs. Currently, do nothing.
       return false;
     };
-    pipeline.AddPass<DotDecomposer>(false);
+    pipeline.AddPass<DotDecomposer>();
     pipeline.AddPass<ConvolutionGroupConverter>(
         cost_model,
         /*convert_batch_groups_only=*/true);
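Call-site note: the old constructor defaulted decompose_batch_dot to true, which is why both the CPU and GPU pipelines had to pass an explicit false just to get canonicalization. A hedged before/after sketch of the call sites (pipeline setup elided; HloPassPipeline::AddPass forwards its arguments to the pass constructor):

// Before: the flag had to be spelled out to suppress batch-dot decomposition.
pipeline.AddPass<DotDecomposer>(/*decompose_batch_dot=*/false);
// After: the pass is default-constructed and only canonicalizes dots.
pipeline.AddPass<DotDecomposer>();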