[XLA] Separate bytes accessed in cost analysis by operands and outputs.

This allows finer-grained analysis to calculate memory bottlenecks, especially
when different inputs/outputs of an HLO instruction lie in different memory
spaces with different bandwidths.

PiperOrigin-RevId: 276367915
Change-Id: I7c0d29228ce5d23e7908fb38e221c732cadafb35
This commit is contained in:
Berkin Ilbeyi 2019-10-23 15:49:38 -07:00 committed by TensorFlower Gardener
parent 0d654e64e8
commit ed14965279
3 changed files with 381 additions and 36 deletions

View File

@ -55,8 +55,11 @@ Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) {
// sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does not
// handle opaque types.
float bytes_accessed = GetShapeSize(hlo->shape());
for (const HloInstruction* operand : hlo->operands()) {
SetOutputBytesAccessed(GetShapeSize(hlo->shape()));
for (int64 i = 0; i < hlo->operand_count(); ++i) {
const HloInstruction* operand = hlo->operand(i);
bytes_accessed += GetShapeSize(operand->shape());
SetOperandBytesAccessed(i, GetShapeSize(operand->shape()));
}
current_properties_[kBytesAccessedKey] = bytes_accessed;
@ -199,6 +202,7 @@ Status HloCostAnalysis::HandleReducePrecision(const HloInstruction* hlo) {
Status HloCostAnalysis::HandleParameter(const HloInstruction*) {
  // A parameter is recorded with zero total bytes, zero output bytes and zero
  // optimal seconds, and the bottleneck-time flag is cleared for it.
  current_properties_[kOptimalSecondsKey] = 0;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  current_should_compute_bottleneck_time_ = false;
  return Status::OK();
}
@ -206,6 +210,7 @@ Status HloCostAnalysis::HandleParameter(const HloInstruction*) {
Status HloCostAnalysis::HandleConstant(const HloInstruction*) {
  // Constants get the same treatment as parameters: no bottleneck-time
  // computation and zero for every byte/time property written here.
  current_should_compute_bottleneck_time_ = false;
  SetOutputBytesAccessed(0);
  current_properties_[kBytesAccessedKey] = 0;
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}
@ -214,11 +219,14 @@ Status HloCostAnalysis::HandleIota(const HloInstruction*) {
return Status::OK();
}
Status HloCostAnalysis::HandleGetTupleElement(const HloInstruction*) {
Status HloCostAnalysis::HandleGetTupleElement(
    const HloInstruction* get_tuple_element) {
  // GetTupleElement only forwards a pointer; it does not touch the elements
  // of its output, so the aggregate, the output entry and the operand 0 entry
  // are all recorded as zero bytes.
  current_properties_[kOptimalSecondsKey] = 0;
  current_properties_[kBytesAccessedKey] = 0;
  SetOperandBytesAccessed(0, 0);
  SetOutputBytesAccessed(0);
  current_should_compute_bottleneck_time_ = false;
  return Status::OK();
}
@ -237,20 +245,35 @@ Status HloCostAnalysis::HandleReverse(const HloInstruction*) {
Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) {
  // Both the operand read and the output write are charged at the size of the
  // slice's result shape; the aggregate is therefore twice that size.
  const int64 slice_bytes = GetShapeSize(slice->shape());
  SetOperandBytesAccessed(0, slice_bytes);
  SetOutputBytesAccessed(slice_bytes);
  current_properties_[kBytesAccessedKey] = slice_bytes * 2;
  return Status::OK();
}
Status HloCostAnalysis::HandleDynamicSlice(
const HloInstruction* dynamic_slice) {
current_properties_[kBytesAccessedKey] =
GetShapeSize(dynamic_slice->shape()) * 2;
GetShapeSize(dynamic_slice->shape()) * 2 +
GetShapeSize(dynamic_slice->operand(1)->shape());
SetOutputBytesAccessed(GetShapeSize(dynamic_slice->shape()));
SetOperandBytesAccessed(0, GetShapeSize(dynamic_slice->shape()));
SetOperandBytesAccessed(1, GetShapeSize(dynamic_slice->operand(1)->shape()));
return Status::OK();
}
Status HloCostAnalysis::HandleDynamicUpdateSlice(
const HloInstruction* dynamic_update_slice) {
current_properties_[kBytesAccessedKey] =
GetShapeSize(dynamic_update_slice->operand(1)->shape()) * 2;
GetShapeSize(dynamic_update_slice->operand(1)->shape()) * 2 +
GetShapeSize(dynamic_update_slice->operand(2)->shape());
// Operand 0 aliases with the output.
SetOutputBytesAccessed(
GetShapeSize(dynamic_update_slice->operand(1)->shape()));
SetOperandBytesAccessed(0, 0);
SetOperandBytesAccessed(
1, GetShapeSize(dynamic_update_slice->operand(1)->shape()));
SetOperandBytesAccessed(
2, GetShapeSize(dynamic_update_slice->operand(2)->shape()));
return Status::OK();
}
@ -260,6 +283,10 @@ Status HloCostAnalysis::HandleTuple(const HloInstruction* tuple) {
// index table of the tuple.
current_properties_[kBytesAccessedKey] = GetShapeSize(tuple->shape());
SetOutputBytesAccessed(GetShapeSize(tuple->shape()));
for (int i = 0; i < tuple->operand_count(); ++i) {
SetOperandBytesAccessed(i, 0);
}
return Status::OK();
}
@ -279,6 +306,10 @@ Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) {
// Domain does not have any computation or data transfer.
current_should_compute_bottleneck_time_ = false;
current_properties_[kBytesAccessedKey] = 0;
SetOutputBytesAccessed(0);
for (int i = 0; i < domain->operand_count(); ++i) {
SetOperandBytesAccessed(i, 0);
}
current_properties_[kOptimalSecondsKey] = 0;
return Status::OK();
}
@ -315,7 +346,7 @@ Status HloCostAnalysis::HandleMap(const HloInstruction* map) {
// Compute the cost of all elements for this Map operation.
const int64 element_count = ShapeUtil::ElementsIn(map->shape());
for (const auto& property : sub_properties) {
if (property.first != kBytesAccessedKey) {
if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
current_properties_[property.first] = property.second * element_count;
}
}
@ -339,7 +370,7 @@ Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) {
int64 reduction_count =
ShapeUtil::ElementsIn(arg->shape()) - ShapeUtil::ElementsIn(output_shape);
for (const auto& property : sub_properties) {
if (property.first != kBytesAccessedKey) {
if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
current_properties_[property.first] = property.second * reduction_count;
}
}
@ -365,7 +396,7 @@ Status HloCostAnalysis::HandleReduceWindow(
const int64 reduction_count =
(window_element_count - 1) * output_element_count;
for (const auto& property : sub_properties) {
if (property.first != kBytesAccessedKey) {
if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
current_properties_[property.first] = property.second * reduction_count;
}
}
@ -392,12 +423,12 @@ Status HloCostAnalysis::HandleSelectAndScatter(
}
const int64 select_count = source_element_count * (window_element_count - 1);
for (const auto& property : select_properties) {
if (property.first != kBytesAccessedKey) {
if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
current_properties_[property.first] += property.second * select_count;
}
}
for (const auto& property : scatter_properties) {
if (property.first != kBytesAccessedKey) {
if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
current_properties_[property.first] +=
property.second * source_element_count;
}
@ -408,6 +439,8 @@ Status HloCostAnalysis::HandleSelectAndScatter(
Status HloCostAnalysis::HandleBitcast(const HloInstruction*) {
  // A bitcast performs no computation and touches no memory, so every
  // byte-count entry (aggregate, output, operand 0) and the optimal time are
  // recorded as zero.
  current_properties_[kOptimalSecondsKey] = 0;
  SetOperandBytesAccessed(0, 0);
  SetOutputBytesAccessed(0);
  current_properties_[kBytesAccessedKey] = 0;
  return Status::OK();
}
@ -467,11 +500,15 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) {
return Status::OK();
}
Status HloCostAnalysis::HandleAfterAll(const HloInstruction*) {
Status HloCostAnalysis::HandleAfterAll(const HloInstruction* token) {
  // AfterAll exists only to enforce ordering at compile time; no code is
  // emitted for it, so nothing is charged to it.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kOptimalSecondsKey] = 0;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  // Record an explicit zero for every operand.
  int operand_idx = 0;
  while (operand_idx < token->operand_count()) {
    SetOperandBytesAccessed(operand_idx, 0);
    ++operand_idx;
  }
  return Status::OK();
}
@ -482,6 +519,10 @@ Status HloCostAnalysis::HandleAddDependency(
// emitted.
current_should_compute_bottleneck_time_ = false;
current_properties_[kBytesAccessedKey] = 0;
SetOutputBytesAccessed(0);
for (int i = 0; i < add_dependency->operand_count(); ++i) {
SetOperandBytesAccessed(i, 0);
}
current_properties_[kOptimalSecondsKey] = 0;
return Status::OK();
}
@ -627,8 +668,13 @@ Status HloCostAnalysis::HandleFft(const HloInstruction* fft) {
}
Status HloCostAnalysis::HandleTriangularSolve(const HloInstruction* hlo) {
float bytes_accessed = GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
// Half of operand 0 is read; the output and operand 1 are counted in full.
// Every contribution is added to the aggregate and also stored under its own
// output/operand key.
float bytes_accessed = GetShapeSize(hlo->shape());
SetOutputBytesAccessed(GetShapeSize(hlo->shape()));
bytes_accessed += GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
SetOperandBytesAccessed(0, GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
bytes_accessed += GetShapeSize(hlo->operand(1)->shape());
// Operand 1's bytes belong under operand index 1; writing them to index 0
// would overwrite the half-size entry stored just above and leave operand 1
// without an entry.
SetOperandBytesAccessed(1, GetShapeSize(hlo->operand(1)->shape()));
current_properties_[kBytesAccessedKey] = bytes_accessed;
const Shape& a_shape = hlo->operand(0)->shape();
@ -641,7 +687,11 @@ Status HloCostAnalysis::HandleTriangularSolve(const HloInstruction* hlo) {
}
Status HloCostAnalysis::HandleCholesky(const HloInstruction* hlo) {
// Half of operand 0 is read and half of the output will be written.
float bytes_accessed = GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
SetOutputBytesAccessed(GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
bytes_accessed += GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
SetOperandBytesAccessed(0, GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
current_properties_[kBytesAccessedKey] = bytes_accessed;
const Shape& a_shape = hlo->operand(0)->shape();
@ -728,8 +778,10 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
if (shape_index.empty()) {
if (fusion->fused_expression_root()->opcode() ==
HloOpcode::kDynamicUpdateSlice) {
current_properties_[kBytesAccessedKey] += GetShapeSize(
int64 size = GetShapeSize(
fusion->fused_expression_root()->operand(1)->shape());
current_properties_[kBytesAccessedKey] += size;
SetOutputBytesAccessed(shape_index, size);
return;
}
} else if (shape_index.size() == 1) {
@ -737,19 +789,54 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
fusion->fused_expression_root()
->operand(shape_index[0])
->opcode() == HloOpcode::kDynamicUpdateSlice) {
current_properties_[kBytesAccessedKey] +=
GetShapeSize(fusion->fused_expression_root()
->operand(shape_index[0])
->operand(1)
->shape());
int64 size = GetShapeSize(fusion->fused_expression_root()
->operand(shape_index[0])
->operand(1)
->shape());
current_properties_[kBytesAccessedKey] += size;
SetOutputBytesAccessed(shape_index, size);
return;
}
}
current_properties_[kBytesAccessedKey] += GetShapeSize(subshape);
SetOutputBytesAccessed(shape_index, GetShapeSize(subshape));
});
for (const HloInstruction* operand : fusion->fused_parameters()) {
current_properties_[kBytesAccessedKey] += FusionParameterReadBytes(operand);
if (fusion->shape().IsTuple()) {
// Propagate and accumulate the output tuple bytes from the tuple subshapes.
// This ensures we have the correct output bytes accessed for the shape
// index
// {}.
std::function<float(const Shape&, const ShapeIndex&)>
propagate_output_size_to_parent;
propagate_output_size_to_parent = [&](const Shape& shape,
const ShapeIndex& shape_index) {
auto output_bytes_it =
current_properties_.find(GetOutputBytesAccessedKey(shape_index));
if (output_bytes_it != current_properties_.end()) {
return output_bytes_it->second;
}
float bytes_accessed = 0;
for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
const Shape& subshape = shape.tuple_shapes(i);
ShapeIndex subshape_index(shape_index);
subshape_index.push_back(i);
bytes_accessed +=
propagate_output_size_to_parent(subshape, subshape_index);
}
SetOutputBytesAccessed(shape_index, bytes_accessed);
return bytes_accessed;
};
current_properties_.erase(
current_properties_.find(GetOutputBytesAccessedKey()));
propagate_output_size_to_parent(fusion->shape(), {});
}
for (int64 i = 0; i < fusion->fused_parameters().size(); ++i) {
const HloInstruction* operand = fusion->fused_parameter(i);
int64 size = FusionParameterReadBytes(operand);
current_properties_[kBytesAccessedKey] += size;
SetOperandBytesAccessed(i, size);
}
return Status::OK();
@ -762,13 +849,17 @@ Status HloCostAnalysis::HandleCall(const HloInstruction* call) {
return Status::OK();
}
Status HloCostAnalysis::HandleCustomCall(const HloInstruction*) {
Status HloCostAnalysis::HandleCustomCall(const HloInstruction* custom_call) {
// Mark applicable fields as "unknown", since we don't know what CustomCall
// does. This is better than returning an error, which would stop iteration,
// and therefore would prevent us from getting *any* stats for a computation
// which contains a CustomCall.
current_properties_[kOptimalSecondsKey] = -1;
current_properties_[kBytesAccessedKey] = -1;
SetOutputBytesAccessed(-1);
for (int i = 0; i < custom_call->operand_count(); ++i) {
SetOperandBytesAccessed(i, -1);
}
current_properties_[kFlopsKey] = -1;
current_should_compute_bottleneck_time_ = false;
return Status::OK();
@ -831,9 +922,12 @@ Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) {
Status HloCostAnalysis::HandleGather(const HloInstruction* gather) {
// Gather doesn't read the whole input buffer, it's equivalent to a copy the
// size of the output shape and a read of the gather indices.
int64 output_size = GetShapeSize(gather->shape());
current_properties_[kBytesAccessedKey] =
GetShapeSize(gather->shape()) * 2 +
GetShapeSize(gather->operand(1)->shape());
output_size * 2 + GetShapeSize(gather->operand(1)->shape());
SetOperandBytesAccessed(0, output_size);
SetOperandBytesAccessed(1, GetShapeSize(gather->operand(1)->shape()));
SetOutputBytesAccessed(output_size);
// Gather does not issue any flops.
return Status::OK();
}
@ -841,15 +935,19 @@ Status HloCostAnalysis::HandleGather(const HloInstruction* gather) {
Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) {
// Scatter accesses the equivalent of 3 update shapes (input, output, and
// updates), and the scatter indices.
int64 update_size = GetShapeSize(scatter->operand(2)->shape());
current_properties_[kBytesAccessedKey] =
GetShapeSize(scatter->operand(2)->shape()) * 3 +
GetShapeSize(scatter->operand(1)->shape());
update_size * 3 + GetShapeSize(scatter->operand(1)->shape());
SetOperandBytesAccessed(0, update_size);
SetOperandBytesAccessed(1, GetShapeSize(scatter->operand(1)->shape()));
SetOperandBytesAccessed(2, update_size);
SetOutputBytesAccessed(update_size);
const int64 element_count =
ShapeUtil::ElementsIn(scatter->operand(2)->shape());
TF_ASSIGN_OR_RETURN(const Properties sub_properties,
ProcessSubcomputation(scatter->to_apply()));
for (const auto& property : sub_properties) {
if (property.first != kBytesAccessedKey) {
if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
current_properties_[property.first] = property.second * element_count;
}
}
@ -898,6 +996,19 @@ int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const {
return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_);
}
// Looks up, in hlo_properties_, the bytes-accessed entry recorded for operand
// `operand_num` of `hlo` at shape index `index`.
int64 HloCostAnalysis::operand_bytes_accessed(const HloInstruction& hlo,
                                              int64 operand_num,
                                              ShapeIndex index) const {
  const std::string property_key =
      GetOperandBytesAccessedKey(operand_num, index);
  return GetPropertyForHlo(hlo, property_key, hlo_properties_);
}
// Looks up, in hlo_properties_, the bytes-accessed entry recorded for the
// output of `hlo` at shape index `index`.
int64 HloCostAnalysis::output_bytes_accessed(const HloInstruction& hlo,
                                             ShapeIndex index) const {
  const std::string property_key = GetOutputBytesAccessedKey(index);
  return GetPropertyForHlo(hlo, property_key, hlo_properties_);
}
// Looks up the kOptimalSecondsKey property recorded for `hlo`.
float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const {
  const float seconds =
      GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_);
  return seconds;
}
@ -917,4 +1028,33 @@ std::unique_ptr<HloCostAnalysis> HloCostAnalysis::CreateNestedCostAnalysis(
return absl::WrapUnique(new HloCostAnalysis(shape_size, per_second_rates));
}
// Records `value` as the bytes accessed by operand `operand_num` (at the
// default shape index) in the properties of the instruction currently being
// visited.
void HloCostAnalysis::SetOperandBytesAccessed(int64 operand_num, float value) {
  // Index with the std::string key directly; going through c_str() would only
  // force a second string to be built from the character pointer.
  current_properties_[GetOperandBytesAccessedKey(operand_num)] = value;
}
// Records `value` as the bytes accessed by operand `operand_num` at shape
// index `index` in the properties of the instruction currently being visited.
void HloCostAnalysis::SetOperandBytesAccessed(int64 operand_num,
                                              ShapeIndex index, float value) {
  // Index with the std::string key directly; going through c_str() would only
  // force a second string to be built from the character pointer.
  current_properties_[GetOperandBytesAccessedKey(operand_num, index)] = value;
}
// Records `value` as the bytes accessed by the output, keyed at the default
// shape index.
void HloCostAnalysis::SetOutputBytesAccessed(float value) {
  const std::string output_key = GetOutputBytesAccessedKey();
  current_properties_[output_key] = value;
}
// Records `value` as the bytes accessed by the output at shape index `index`.
void HloCostAnalysis::SetOutputBytesAccessed(ShapeIndex index, float value) {
  const std::string output_key = GetOutputBytesAccessedKey(index);
  current_properties_[output_key] = value;
}
// Builds the property key for one operand: kBytesAccessedKey, the word
// "operand", the operand number, and the textual form of the shape index.
/*static*/ std::string HloCostAnalysis::GetOperandBytesAccessedKey(
    int64 operand_num, ShapeIndex index) {
  const std::string index_text = index.ToString();
  return absl::StrCat(kBytesAccessedKey, " operand ", operand_num, " ",
                      index_text);
}
// Builds the property key for the output: kBytesAccessedKey, the word
// "output", and the textual form of the shape index.
/*static*/ std::string HloCostAnalysis::GetOutputBytesAccessedKey(
    ShapeIndex index) {
  const std::string index_text = index.ToString();
  return absl::StrCat(kBytesAccessedKey, " output ", index_text);
}
} // namespace xla

View File

@ -150,6 +150,10 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
int64 flop_count(const HloInstruction& hlo) const;
int64 transcendental_count(const HloInstruction& hlo) const;
int64 bytes_accessed(const HloInstruction& hlo) const;
// Returns the bytes-accessed property recorded for operand `operand_num` of
// `hlo` at shape index `index` (`{}` when omitted).
int64 operand_bytes_accessed(const HloInstruction& hlo, int64 operand_num,
ShapeIndex index = {}) const;
// Returns the bytes-accessed property recorded for the output of `hlo` at
// shape index `index` (`{}` when omitted).
int64 output_bytes_accessed(const HloInstruction& hlo,
ShapeIndex index = {}) const;
float optimal_seconds(const HloInstruction& hlo) const;
const Properties& properties() const { return properties_sum_; }
@ -198,6 +202,21 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
// node.
int64 FusionParameterReadBytes(const HloInstruction* hlo) const;
// Set bytes accessed by the specified operand and shape index. The value is
// stored in current_properties_ under
// GetOperandBytesAccessedKey(operand_num, index); the overload without an
// index uses that function's default `{}` shape index.
void SetOperandBytesAccessed(int64 operand_num, float value);
void SetOperandBytesAccessed(int64 operand_num, ShapeIndex index,
float value);
// Set bytes accessed by the output at the shape index. The value is stored in
// current_properties_ under GetOutputBytesAccessedKey(index); the overload
// without an index uses that function's default `{}` shape index.
void SetOutputBytesAccessed(float value);
void SetOutputBytesAccessed(ShapeIndex index, float value);
// Return the key that is used to index into Properties for the specified
// input/output at the shape index. Every key begins with kBytesAccessedKey,
// followed by " operand <operand_num> <index>" or " output <index>", where
// <index> is ShapeIndex::ToString() of `index`.
static std::string GetOperandBytesAccessedKey(int64 operand_num,
ShapeIndex index = {});
static std::string GetOutputBytesAccessedKey(ShapeIndex index = {});
// Function which computes the size of the top-level of a given shape (not
// including nested elements, if any). If null then bytes_accessed methods
// return an error.

View File

@ -155,6 +155,11 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) {
// Bytes accessed is sum of inputs and output.
EXPECT_EQ(analysis.bytes_accessed(),
sizeof(float) * (10 * 5 + 5 * 30 + 10 * 30));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float) * 10 * 5);
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(float) * 5 * 30);
EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 10 * 30);
}
TEST_F(HloCostAnalysisTest, DotGeneral) {
@ -184,6 +189,13 @@ TEST_F(HloCostAnalysisTest, DotGeneral) {
// Bytes accessed is sum of inputs and output.
EXPECT_EQ(analysis.bytes_accessed(),
sizeof(float) * (10 * 5 * 5 + 5 * 5 * 30 + 10 * 30));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0),
sizeof(float) * 10 * 5 * 5);
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1),
sizeof(float) * 5 * 5 * 30);
EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 10 * 30);
}
TEST_F(HloCostAnalysisTest, DotGeneral2) {
@ -213,6 +225,13 @@ TEST_F(HloCostAnalysisTest, DotGeneral2) {
// Bytes accessed is sum of inputs and output.
EXPECT_EQ(analysis.bytes_accessed(),
sizeof(float) * (10 * 5 * 5 + 5 * 5 * 30 + 5 * 10 * 30));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0),
sizeof(float) * 10 * 5 * 5);
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1),
sizeof(float) * 5 * 5 * 30);
EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 5 * 10 * 30);
}
TEST_F(HloCostAnalysisTest, DotGeneral3) {
@ -236,6 +255,12 @@ TEST_F(HloCostAnalysisTest, DotGeneral3) {
// Bytes accessed is sum of inputs and output.
EXPECT_EQ(analysis.bytes_accessed(),
sizeof(float) * (10 * 5 + 5 * 30 + 5 * 5 * 10 * 30));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float) * 10 * 5);
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(float) * 5 * 30);
EXPECT_EQ(analysis.output_bytes_accessed(*root),
sizeof(float) * 5 * 5 * 10 * 30);
}
TEST_F(HloCostAnalysisTest, Map) {
@ -253,6 +278,10 @@ TEST_F(HloCostAnalysisTest, Map) {
EXPECT_EQ(analysis.flop_count(), 10);
EXPECT_EQ(analysis.transcendental_count(), 10);
EXPECT_EQ(analysis.bytes_accessed(), 80);
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float) * 10);
EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 10);
}
TEST_F(HloCostAnalysisTest, Convolution) {
@ -282,6 +311,11 @@ TEST_F(HloCostAnalysisTest, Convolution) {
// Bytes accessed is sum of inputs and output.
EXPECT_EQ(analysis.bytes_accessed(),
sizeof(float) * (10 * 20 + 3 * 3 + 8 * 18));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float) * 10 * 20);
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(float) * 3 * 3);
EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 8 * 18);
}
TEST_F(HloCostAnalysisTest, ConvolutionExtreme) {
@ -357,6 +391,14 @@ TEST_F(HloCostAnalysisTest, ConvolutionWithFeatureGroup) {
// Bytes accessed is sum of inputs and output.
EXPECT_EQ(analysis.bytes_accessed(),
sizeof(float) * (120 * 10 * 20 + 120 * 3 * 3 + 120 * 8 * 18));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0),
sizeof(float) * 120 * 10 * 20);
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1),
sizeof(float) * 120 * 3 * 3);
EXPECT_EQ(analysis.output_bytes_accessed(*root),
sizeof(float) * 120 * 8 * 18);
}
TEST_F(HloCostAnalysisTest, Reduce) {
@ -374,6 +416,13 @@ TEST_F(HloCostAnalysisTest, Reduce) {
// Subtracting the output size from the input size gives the number of
// reduction operations performed.
EXPECT_EQ(analysis.flop_count(), 10 * 20 - 10);
EXPECT_EQ(analysis.bytes_accessed(), sizeof(float) * (10 * 20 + 1 + 10));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float) * 10 * 20);
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(float) * 1);
EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 10);
}
TEST_F(HloCostAnalysisTest, ReduceWindow) {
@ -391,6 +440,13 @@ TEST_F(HloCostAnalysisTest, ReduceWindow) {
// Each of [2x4] output elements are generated from reducing [4x5] elements.
EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1));
EXPECT_EQ(analysis.bytes_accessed(), sizeof(float) * (10 * 20 + 1 + 2 * 4));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float) * 10 * 20);
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(float) * 1);
EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 2 * 4);
}
TEST_F(HloCostAnalysisTest, SelectAndScatter) {
@ -411,6 +467,15 @@ TEST_F(HloCostAnalysisTest, SelectAndScatter) {
// Each of [2x4] source elements computes its destination from reducing [4x5]
// elements followed by the scatter computation.
EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1 + 1));
EXPECT_EQ(analysis.bytes_accessed(),
sizeof(float) * (10 * 20 + 2 * 4 + 1 + 10 * 20));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float) * 10 * 20);
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(float) * 2 * 4);
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 2), sizeof(float) * 1);
EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 10 * 20);
}
TEST_F(HloCostAnalysisTest, Broadcast) {
@ -421,6 +486,12 @@ TEST_F(HloCostAnalysisTest, Broadcast) {
ASSERT_IS_OK(
hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
EXPECT_EQ(analysis.flop_count(), 0);
EXPECT_EQ(analysis.bytes_accessed(), sizeof(float) * (1 + 10 * 7));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float) * 1);
EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 10 * 7);
}
// Calculates the computation cost of a graph with more than one HLO node.
@ -535,10 +606,84 @@ TEST_F(FusionCostAnalysis, LoopFusion) {
static_assert(bytes_accessed == 64, "");
EXPECT_EQ(fusion_analysis.bytes_accessed(), bytes_accessed);
EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0),
sizeof(float) * 2 * 2);
EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 1),
sizeof(float) * 2 * 2);
EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 2),
sizeof(float) * 2 * 2);
EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion),
sizeof(float) * 2 * 2);
EXPECT_EQ(fusion_analysis.optimal_seconds(), 1 << i);
}
}
// Checks the per-operand and per-output-index byte counts of a loop fusion
// whose root is the nested tuple (sub, sub, mul, (c1, c2)).
TEST_F(FusionCostAnalysis, LoopFusionTupleOutput) {
Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
// Same as above but the fusion outputs a tuple.
HloComputation::Builder builder(TestName());
auto c1 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2)));
auto c2 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2)));
auto c3 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2)));
// tuple1 becomes element {3} of the root tuple built below.
auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1, c2));
auto clamp = builder.AddInstruction(
HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp, c2, add, add));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add));
auto mul = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, exp, c3));
auto sub = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp));
// Root tuple: three array elements plus the nested tuple1.
auto tuple2 = builder.AddInstruction(
HloInstruction::CreateTuple({sub, sub, mul, tuple1}));
auto module = CreateNewVerifiedModule();
auto* computation = module->AddEntryComputation(builder.Build());
// tuple1 and the three constants are not in the fused set; the expectations
// below address four fusion operands (0 through 3).
auto* fusion = computation->CreateFusionInstruction(
{tuple2, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);
HloCostAnalysis fusion_analysis(ShapeSize);
ASSERT_IS_OK(fusion->Accept(&fusion_analysis));
EXPECT_EQ(fusion_analysis.flop_count(), 16);
EXPECT_EQ(fusion_analysis.transcendental_count(), 4);
// Total: three f32[2,2] operand reads + five f32[2,2] output leaves, plus
// kPointerSize * 2 for operand 0.
EXPECT_EQ(fusion_analysis.bytes_accessed(*fusion),
sizeof(float) * (3 + 5) * 2 * 2 + kPointerSize * 2);
// NOTE(review): operand 0 is presumably tuple1 (a two-pointer index table) --
// confirm against the operand order CreateFusionInstruction produces.
EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0),
kPointerSize * 2);
EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 1),
sizeof(float) * 2 * 2);
EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 2),
sizeof(float) * 2 * 2);
EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 3),
sizeof(float) * 2 * 2);
// The top-level output entry is the sum over all five array leaves.
EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion),
sizeof(float) * 5 * 2 * 2);
EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion, {0}),
sizeof(float) * 2 * 2);
EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion, {1}),
sizeof(float) * 2 * 2);
EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion, {2}),
sizeof(float) * 2 * 2);
// Index {3} is the nested tuple: its entry is the sum of {3, 0} and {3, 1}.
EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion, {3}),
sizeof(float) * 2 * 2 * 2);
EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion, {3, 0}),
sizeof(float) * 2 * 2);
EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion, {3, 1}),
sizeof(float) * 2 * 2);
}
TEST_F(FusionCostAnalysis, NoLayout) {
Shape shape_with_layout = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
// Instructions within a fused op may have no layout.
@ -566,24 +711,38 @@ TEST_F(FusionCostAnalysis, NoLayout) {
EXPECT_EQ(fusion_analysis.flop_count(), 120);
EXPECT_EQ(fusion_analysis.transcendental_count(), 0);
EXPECT_EQ(fusion_analysis.bytes_accessed(),
sizeof(float) * (2 * 3 * 4 * 5 + 3 + 2 * 3 * 4 * 5));
EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0),
sizeof(float) * 2 * 3 * 4 * 5);
EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 1),
sizeof(float) * 3);
EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion),
sizeof(float) * 2 * 3 * 4 * 5);
}
TEST_F(HloCostAnalysisTest, TupleCost) {
HloCostAnalysis analysis(ShapeSize);
{
XlaBuilder builder("tuple");
auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x");
auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y");
Tuple(&builder, {x, y});
auto hlo_module = BuildHloGraph(&builder);
ASSERT_IS_OK(
hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
}
XlaBuilder builder("tuple");
auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x");
auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y");
Tuple(&builder, {x, y});
auto hlo_module = BuildHloGraph(&builder);
ASSERT_IS_OK(
hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
EXPECT_EQ(analysis.flop_count(), 0);
EXPECT_EQ(analysis.transcendental_count(), 0);
EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2);
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), 0);
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), 0);
EXPECT_EQ(analysis.output_bytes_accessed(*root), kPointerSize * 2);
}
using DomainCostAnalysis = HloTestBase;
@ -650,6 +809,10 @@ TEST_F(HloCostAnalysisTest, Slice) {
hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
EXPECT_EQ(analysis.bytes_accessed(), 8);
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float));
EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float));
}
TEST_F(HloCostAnalysisTest, DynamicSlice) {
@ -665,7 +828,12 @@ TEST_F(HloCostAnalysisTest, DynamicSlice) {
ASSERT_IS_OK(
hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
EXPECT_EQ(analysis.bytes_accessed(), 8);
EXPECT_EQ(analysis.bytes_accessed(), 8 + 4);
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float));
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(int32));
EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float));
}
TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) {
@ -681,7 +849,14 @@ TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) {
ASSERT_IS_OK(
hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
EXPECT_EQ(analysis.bytes_accessed(), 8);
EXPECT_EQ(analysis.bytes_accessed(), 8 + 4);
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), 0);
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(float));
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 2), sizeof(int32));
EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float));
}
TEST_F(HloCostAnalysisTest, Gather) {
@ -707,6 +882,11 @@ TEST_F(HloCostAnalysisTest, Gather) {
hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
EXPECT_EQ(analysis.bytes_accessed(), 56);
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float) * 2 * 3);
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(int32) * 2);
EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 2 * 3);
}
TEST_F(HloCostAnalysisTest, Scatter) {
@ -734,6 +914,12 @@ TEST_F(HloCostAnalysisTest, Scatter) {
hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
EXPECT_EQ(analysis.bytes_accessed(), 4 * (2 + 3 * (2 * 3)));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float) * 2 * 3);
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(int32) * 2);
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 2), sizeof(float) * 2 * 3);
EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 2 * 3);
}
} // namespace
} // namespace xla