Refactor logic from buffer_liveness to use in HeapSimulator.
Also added some simple tests. Change: 150144113
This commit is contained in:
parent
830cde8776
commit
e0d0c676ec
tensorflow/compiler/xla/service
@ -493,6 +493,36 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "liveness_util",
|
||||
srcs = ["liveness_util.cc"],
|
||||
hdrs = ["liveness_util.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":logical_buffer",
|
||||
":tuple_points_to_analysis",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "liveness_util_test",
|
||||
srcs = ["liveness_util_test.cc"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":liveness_util",
|
||||
":tuple_points_to_analysis",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "buffer_liveness",
|
||||
srcs = [
|
||||
@ -504,6 +534,7 @@ cc_library(
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_ordering",
|
||||
":liveness_util",
|
||||
":logical_buffer",
|
||||
":tuple_points_to_analysis",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -586,6 +617,7 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
":hlo",
|
||||
":liveness_util",
|
||||
":logical_buffer",
|
||||
":tuple_points_to_analysis",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
|
@ -17,11 +17,11 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
|
||||
|
||||
#include <set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/liveness_util.h"
|
||||
#include "tensorflow/compiler/xla/service/logical_buffer.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -92,128 +92,6 @@ string BufferLiveness::ToString() const {
|
||||
return tensorflow::str_util::Join(pieces, "\n");
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns false if 'user' cannot possibly use the buffer at 'index' in
|
||||
// 'operand'. Returns true otherwise.
|
||||
// Precondition: 'operand' is an operand of 'user'.
|
||||
bool MayUseBufferInOperand(HloInstruction* operand, const ShapeIndex& index,
|
||||
HloInstruction* user,
|
||||
const TuplePointsToAnalysis& points_to_analysis) {
|
||||
if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
|
||||
// GetTupleElement instructions only access the top-level buffer of their
|
||||
// operand.
|
||||
return false;
|
||||
} else if (user->opcode() == HloOpcode::kFusion &&
|
||||
user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
|
||||
// Find fusion parameter associated with 'operand'.
|
||||
auto it = std::find_if(
|
||||
user->fused_parameters().begin(), user->fused_parameters().end(),
|
||||
[=](HloInstruction* fused_param) {
|
||||
return user->operand(fused_param->parameter_number()) == operand;
|
||||
});
|
||||
CHECK(it != user->fused_parameters().end());
|
||||
// Iterate through all users of all buffer aliases of the buffer in the
|
||||
// points-to set of fusion parameter at 'index'.
|
||||
// Returns true if any use is detected at 'index'; returns false otherwise.
|
||||
const LogicalBuffer* buffer =
|
||||
points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie();
|
||||
for (const BufferAlias& alias :
|
||||
points_to_analysis.GetBufferAliases(*buffer)) {
|
||||
for (HloInstruction* alias_user : alias.instruction()->users()) {
|
||||
if (!MayUseBufferInOperand(alias.instruction(), alias.index(),
|
||||
alias_user, points_to_analysis)) {
|
||||
continue;
|
||||
}
|
||||
// Return true: use detected at 'buffer' -> 'alias' -> 'alias_user'.
|
||||
return true;
|
||||
}
|
||||
}
|
||||
// Return false: found no uses of 'operand' at 'index' in 'user'.
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
|
||||
// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
|
||||
// where 'user' is a user of an alias of 'instruction' at 'index', and
|
||||
// 'operand_index' is the operand index at which the alias appears in the
|
||||
// operand list of 'user'.
|
||||
std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
|
||||
HloInstruction* instruction, const ShapeIndex& index,
|
||||
const TuplePointsToAnalysis& points_to_analysis) {
|
||||
std::vector<std::pair<HloInstruction*, int64>> uses;
|
||||
const std::vector<const LogicalBuffer*>& points_to =
|
||||
points_to_analysis.GetPointsToSet(instruction).element(index);
|
||||
for (const LogicalBuffer* buffer : points_to) {
|
||||
for (const BufferAlias& alias :
|
||||
points_to_analysis.GetBufferAliases(*buffer)) {
|
||||
for (HloInstruction* alias_user : alias.instruction()->users()) {
|
||||
if (!MayUseBufferInOperand(alias.instruction(), alias.index(),
|
||||
alias_user, points_to_analysis)) {
|
||||
continue;
|
||||
}
|
||||
for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) {
|
||||
uses.emplace_back(alias_user, op_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return uses;
|
||||
}
|
||||
|
||||
// Returns true if 'user' (at 'user_index') can share a buffer with its operand
|
||||
// 'operand' (at 'operand_index').
|
||||
// Returns false otherwise.
|
||||
// User and operand can share buffers iff both instructions emit the same shape
|
||||
// and layout, and 'user' meets one of the following two qualifications:
|
||||
// *) Is element-wise.
|
||||
// *) Is a loop fusion instruction where the only use of 'operand' at 'index'
|
||||
// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
|
||||
// at operand 0.
|
||||
bool CanShareOperandBufferWithUser(
|
||||
HloInstruction* operand, const ShapeIndex& operand_index,
|
||||
HloInstruction* user, const ShapeIndex& user_index,
|
||||
const TuplePointsToAnalysis& points_to_analysis) {
|
||||
Shape operand_subshape =
|
||||
ShapeUtil::GetSubshape(operand->shape(), operand_index);
|
||||
Shape user_subshape = ShapeUtil::GetSubshape(user->shape(), user_index);
|
||||
// Check that operand and user emit the same shape and layout.
|
||||
if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
|
||||
return false;
|
||||
}
|
||||
// Check if 'user' is a loop fusion instruction with a kDynamicUpdateSlice
|
||||
// fused root instruction.
|
||||
if (user->opcode() == HloOpcode::kFusion &&
|
||||
user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
|
||||
user->fused_expression_root()->opcode() ==
|
||||
HloOpcode::kDynamicUpdateSlice) {
|
||||
for (auto& fused_param : user->fused_parameters()) {
|
||||
// Find fusion parameter associated with 'operand'.
|
||||
if (user->operand(fused_param->parameter_number()) != operand) {
|
||||
continue;
|
||||
}
|
||||
// Get all uses of 'operand' at 'index' from 'user.fused_instructions'.
|
||||
auto fused_param_uses = GetAllUsesOfInstructionAtIndex(
|
||||
fused_param, operand_index, points_to_analysis);
|
||||
// Return true iff there is exactly one use of 'operand' at 'index', and
|
||||
// this singleton use is the fused root at operand index 0.
|
||||
if (fused_param_uses.size() == 1 &&
|
||||
fused_param_uses[0].first == user->fused_expression_root() &&
|
||||
fused_param_uses[0].second == 0) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
// Check if 'user' is element-wise.
|
||||
return user->IsElementwise();
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
|
||||
const LogicalBuffer& b) const {
|
||||
TF_CHECK_OK(points_to_analysis_->VerifyBuffer(a));
|
||||
@ -226,8 +104,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
|
||||
// Every user of 'a' must be a predecessor of 'b' or 'b' itself.
|
||||
for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) {
|
||||
for (auto user : alias.instruction()->users()) {
|
||||
if (!MayUseBufferInOperand(alias.instruction(), alias.index(), user,
|
||||
points_to_analysis())) {
|
||||
if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), user,
|
||||
points_to_analysis())) {
|
||||
continue;
|
||||
}
|
||||
if (user != b.instruction() &&
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/map_util.h"
|
||||
#include "tensorflow/compiler/xla/service/liveness_util.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
||||
namespace xla {
|
||||
@ -26,6 +27,8 @@ namespace xla {
|
||||
using tensorflow::gtl::FlatMap;
|
||||
using tensorflow::gtl::FlatSet;
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns the set of buffers that may be sources of all operands of the given
|
||||
// instruction. The returned buffers are guaranteed to have no duplicates, and
|
||||
// to be sorted in a deterministic order.
|
||||
@ -46,6 +49,8 @@ std::vector<const LogicalBuffer*> UniqueOperandSourceBuffers(
|
||||
return sorted;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/*static*/
|
||||
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
|
||||
std::unique_ptr<HeapAlgorithm> algorithm,
|
||||
@ -145,13 +150,10 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
|
||||
// we must be the last user of the buffer.
|
||||
bool shared = false;
|
||||
for (const LogicalBuffer* operand_buffer : operand_buffers_to_free) {
|
||||
// The operand buffer can be shared if we have the same shape, and we're
|
||||
// an elementwise instruction.
|
||||
//
|
||||
// TODO(b/35903632): Refactor and use the CanShareOperandBufferWithUser
|
||||
// logic from buffer_liveness.cc
|
||||
if (ShapeUtil::Equal(buffer->shape(), operand_buffer->shape()) &&
|
||||
instruction->IsElementwise()) {
|
||||
if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) &&
|
||||
CanShareOperandBufferWithUser(
|
||||
operand_buffer->instruction(), operand_buffer->index(),
|
||||
buffer->instruction(), buffer->index(), points_to_analysis)) {
|
||||
heap.ShareBuffer(buffer, operand_buffer);
|
||||
shared = true;
|
||||
break;
|
||||
|
151
tensorflow/compiler/xla/service/liveness_util.cc
Normal file
151
tensorflow/compiler/xla/service/liveness_util.cc
Normal file
@ -0,0 +1,151 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/liveness_util.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/logical_buffer.h"
|
||||
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
bool DoesNotUseOperandBuffer(HloInstruction* operand, const ShapeIndex& index,
|
||||
HloInstruction* user,
|
||||
const TuplePointsToAnalysis& points_to_analysis) {
|
||||
CHECK(user->IsUserOf(operand))
|
||||
<< "user: " << user->ToString() << " operand: " << operand->ToString();
|
||||
if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
|
||||
// GetTupleElement instructions only access the top-level buffer of their
|
||||
// operand.
|
||||
return true;
|
||||
} else if (user->opcode() == HloOpcode::kFusion &&
|
||||
user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
|
||||
// Find fusion parameter associated with 'operand'.
|
||||
auto it = std::find_if(
|
||||
user->fused_parameters().begin(), user->fused_parameters().end(),
|
||||
[=](HloInstruction* fused_param) {
|
||||
return user->operand(fused_param->parameter_number()) == operand;
|
||||
});
|
||||
CHECK(it != user->fused_parameters().end());
|
||||
// Iterate through all users of all buffer aliases of the buffer in the
|
||||
// points-to set of fusion parameter at 'index'.
|
||||
// Returns false if any use is detected at 'index'; returns true otherwise.
|
||||
const LogicalBuffer* buffer =
|
||||
points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie();
|
||||
for (const BufferAlias& alias :
|
||||
points_to_analysis.GetBufferAliases(*buffer)) {
|
||||
for (HloInstruction* alias_user : alias.instruction()->users()) {
|
||||
if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
|
||||
alias_user, points_to_analysis)) {
|
||||
continue;
|
||||
}
|
||||
// Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'.
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// Return true: found no uses of 'operand' at 'index' in 'user'.
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
|
||||
// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
|
||||
// where 'user' is a user of an alias of 'instruction' at 'index', and
|
||||
// 'operand_index' is the operand index at which the alias appears in the
|
||||
// operand list of 'user'.
|
||||
std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
|
||||
HloInstruction* instruction, const ShapeIndex& index,
|
||||
const TuplePointsToAnalysis& points_to_analysis) {
|
||||
std::vector<std::pair<HloInstruction*, int64>> uses;
|
||||
const std::vector<const LogicalBuffer*>& points_to =
|
||||
points_to_analysis.GetPointsToSet(instruction).element(index);
|
||||
for (const LogicalBuffer* buffer : points_to) {
|
||||
for (const BufferAlias& alias :
|
||||
points_to_analysis.GetBufferAliases(*buffer)) {
|
||||
for (HloInstruction* alias_user : alias.instruction()->users()) {
|
||||
if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
|
||||
alias_user, points_to_analysis)) {
|
||||
continue;
|
||||
}
|
||||
for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) {
|
||||
uses.emplace_back(alias_user, op_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return uses;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// User and operand can share buffers iff both instructions emit the same shape
|
||||
// and layout, and 'user' meets one of the following two qualifications:
|
||||
// *) Is element-wise.
|
||||
// *) Is a loop fusion instruction where the only use of 'operand' at 'index'
|
||||
// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
|
||||
// at operand 0.
|
||||
bool CanShareOperandBufferWithUser(
|
||||
HloInstruction* operand, const ShapeIndex& operand_index,
|
||||
HloInstruction* user, const ShapeIndex& user_index,
|
||||
const TuplePointsToAnalysis& points_to_analysis) {
|
||||
CHECK(user->IsUserOf(operand))
|
||||
<< "user: " << user->ToString() << " operand: " << operand->ToString();
|
||||
Shape operand_subshape =
|
||||
ShapeUtil::GetSubshape(operand->shape(), operand_index);
|
||||
Shape user_subshape = ShapeUtil::GetSubshape(user->shape(), user_index);
|
||||
// Check that operand and user emit the same shape and layout.
|
||||
if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
|
||||
return false;
|
||||
}
|
||||
// Check if 'user' is a loop fusion instruction with a kDynamicUpdateSlice
|
||||
// fused root instruction.
|
||||
if (user->opcode() == HloOpcode::kFusion &&
|
||||
user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
|
||||
user->fused_expression_root()->opcode() ==
|
||||
HloOpcode::kDynamicUpdateSlice) {
|
||||
for (auto& fused_param : user->fused_parameters()) {
|
||||
// Find fusion parameter associated with 'operand'.
|
||||
if (user->operand(fused_param->parameter_number()) != operand) {
|
||||
continue;
|
||||
}
|
||||
// Get all uses of 'operand' at 'index' from 'user.fused_instructions'.
|
||||
auto fused_param_uses = GetAllUsesOfInstructionAtIndex(
|
||||
fused_param, operand_index, points_to_analysis);
|
||||
// Return true iff there is exactly one use of 'operand' at 'index', and
|
||||
// this singleton use is the fused root at operand index 0.
|
||||
if (fused_param_uses.size() == 1 &&
|
||||
fused_param_uses[0].first == user->fused_expression_root() &&
|
||||
fused_param_uses[0].second == 0) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
// Check if 'user' is element-wise.
|
||||
return user->IsElementwise();
|
||||
}
|
||||
|
||||
} // namespace xla
|
51
tensorflow/compiler/xla/service/liveness_util.h
Normal file
51
tensorflow/compiler/xla/service/liveness_util.h
Normal file
@ -0,0 +1,51 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// A collection of utilities on the HLO graph.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Returns true if 'user' cannot possibly use the buffer at 'index' in
|
||||
// 'operand'. Returns false otherwise.
|
||||
//
|
||||
// REQUIRES: 'operand' is an operand of 'user'.
|
||||
bool DoesNotUseOperandBuffer(HloInstruction* operand, const ShapeIndex& index,
|
||||
HloInstruction* user,
|
||||
const TuplePointsToAnalysis& points_to_analysis);
|
||||
|
||||
// Returns true if 'user' (at 'user_index') can share a buffer with its operand
|
||||
// 'operand' (at 'operand_index').
|
||||
// Returns false otherwise.
|
||||
//
|
||||
// REQUIRES: 'operand' is an operand of 'user'.
|
||||
bool CanShareOperandBufferWithUser(
|
||||
HloInstruction* operand, const ShapeIndex& operand_index,
|
||||
HloInstruction* user, const ShapeIndex& user_index,
|
||||
const TuplePointsToAnalysis& points_to_analysis);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
|
189
tensorflow/compiler/xla/service/liveness_util_test.cc
Normal file
189
tensorflow/compiler/xla/service/liveness_util_test.cc
Normal file
@ -0,0 +1,189 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/liveness_util.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class PointsToAnalysisTestBase : public HloTestBase {
|
||||
protected:
|
||||
void BuildModule(std::unique_ptr<HloComputation> computation) {
|
||||
module_ = MakeUnique<HloModule>(TestName());
|
||||
computation_ = module_->AddEntryComputation(std::move(computation));
|
||||
}
|
||||
|
||||
void RunAnalysis() {
|
||||
CHECK_NOTNULL(module_.get());
|
||||
points_to_analysis_ =
|
||||
TuplePointsToAnalysis::Run(module_.get(),
|
||||
/*include_loop_fusion_instructions=*/true)
|
||||
.ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
|
||||
BuildModule(std::move(computation));
|
||||
RunAnalysis();
|
||||
}
|
||||
|
||||
std::unique_ptr<HloModule> module_;
|
||||
HloComputation* computation_ = nullptr;
|
||||
std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
|
||||
};
|
||||
|
||||
class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {};
|
||||
|
||||
TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
|
||||
Shape elem_shape = ShapeUtil::MakeShape(F32, {8});
|
||||
auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple"));
|
||||
auto gte0 = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0));
|
||||
auto gte1 = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1));
|
||||
|
||||
BuildModuleAndRunAnalysis(builder.Build());
|
||||
|
||||
// GetTupleElement instructions only access the top-level buffer of their
|
||||
// operand.
|
||||
EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *points_to_analysis_));
|
||||
EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *points_to_analysis_));
|
||||
EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *points_to_analysis_));
|
||||
EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *points_to_analysis_));
|
||||
}
|
||||
|
||||
TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
|
||||
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
|
||||
auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
|
||||
auto gte0 = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
|
||||
auto gte1 = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
|
||||
|
||||
// Create a DynamicUpdateSlice instruction of tuple element 1.
|
||||
auto starts = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
|
||||
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
|
||||
auto dynamic_update_slice =
|
||||
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
|
||||
data_shape, gte1, update, starts));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
|
||||
|
||||
BuildModule(builder.Build());
|
||||
auto fusion = computation_->CreateFusionInstruction(
|
||||
{dynamic_update_slice, starts, update, gte1},
|
||||
HloInstruction::FusionKind::kLoop);
|
||||
RunAnalysis();
|
||||
|
||||
// The fusion instruction never uses tuple element 0, but does use element 1.
|
||||
EXPECT_TRUE(
|
||||
DoesNotUseOperandBuffer(tuple, {0}, fusion, *points_to_analysis_));
|
||||
EXPECT_FALSE(
|
||||
DoesNotUseOperandBuffer(tuple, {1}, fusion, *points_to_analysis_));
|
||||
}
|
||||
|
||||
class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {};
|
||||
|
||||
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {8});
|
||||
auto param = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, shape, "param"));
|
||||
auto exp = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kExp, param));
|
||||
auto log = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp));
|
||||
|
||||
BuildModuleAndRunAnalysis(builder.Build());
|
||||
|
||||
EXPECT_TRUE(
|
||||
CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_));
|
||||
EXPECT_TRUE(
|
||||
CanShareOperandBufferWithUser(exp, {}, log, {}, *points_to_analysis_));
|
||||
}
|
||||
|
||||
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
|
||||
Shape in_shape = ShapeUtil::MakeShape(F32, {8});
|
||||
Shape out_shape = ShapeUtil::MakeShape(PRED, {8});
|
||||
auto param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, in_shape, "param0"));
|
||||
auto param1 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, in_shape, "param1"));
|
||||
auto result = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1));
|
||||
|
||||
BuildModuleAndRunAnalysis(builder.Build());
|
||||
|
||||
EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {},
|
||||
*points_to_analysis_));
|
||||
EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {},
|
||||
*points_to_analysis_));
|
||||
}
|
||||
|
||||
TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
|
||||
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
|
||||
auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
|
||||
auto gte0 = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
|
||||
auto gte1 = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
|
||||
|
||||
// Create a DynamicUpdateSlice instruction of tuple element 1.
|
||||
auto starts = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
|
||||
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
|
||||
auto dynamic_update_slice =
|
||||
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
|
||||
data_shape, gte1, update, starts));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
|
||||
|
||||
BuildModule(builder.Build());
|
||||
auto fusion = computation_->CreateFusionInstruction(
|
||||
{dynamic_update_slice, starts, update, gte1},
|
||||
HloInstruction::FusionKind::kLoop);
|
||||
RunAnalysis();
|
||||
|
||||
// The fusion instruction can share with tuple element 1, but not element 0.
|
||||
EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {},
|
||||
*points_to_analysis_));
|
||||
EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {},
|
||||
*points_to_analysis_));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
Loading…
Reference in New Issue
Block a user