[XLA] BufferValue::Color now type aliases int64.
PiperOrigin-RevId: 313404227 Change-Id: I2d393d426865c61ff210f10e3d9b8402a1813cf1
This commit is contained in:
parent
be4d17c088
commit
96ba1c3609
@ -261,7 +261,7 @@ void BufferAllocation::AddAssignment(const HloValue& buffer, int64 offset,
|
|||||||
Shape* shape = ShapeUtil::GetMutableSubshape(
|
Shape* shape = ShapeUtil::GetMutableSubshape(
|
||||||
position.instruction->mutable_shape(), position.index);
|
position.instruction->mutable_shape(), position.index);
|
||||||
if (shape->has_layout()) {
|
if (shape->has_layout()) {
|
||||||
shape->mutable_layout()->set_memory_space(buffer.color().value());
|
shape->mutable_layout()->set_memory_space(buffer.color());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -272,7 +272,7 @@ BufferAllocationProto BufferAllocation::ToProto() const {
|
|||||||
proto.set_size(size_);
|
proto.set_size(size_);
|
||||||
proto.set_is_thread_local(is_thread_local_);
|
proto.set_is_thread_local(is_thread_local_);
|
||||||
proto.set_is_tuple(is_tuple_);
|
proto.set_is_tuple(is_tuple_);
|
||||||
proto.set_color(color_.value());
|
proto.set_color(color_);
|
||||||
if (is_entry_computation_parameter_) {
|
if (is_entry_computation_parameter_) {
|
||||||
proto.set_is_entry_computation_parameter(true);
|
proto.set_is_entry_computation_parameter(true);
|
||||||
for (int64 idx : param_shape_index()) {
|
for (int64 idx : param_shape_index()) {
|
||||||
@ -336,8 +336,8 @@ static const HloInstruction* GetOutputInstruction(
|
|||||||
string BufferAllocation::ToString() const {
|
string BufferAllocation::ToString() const {
|
||||||
string output;
|
string output;
|
||||||
StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size());
|
StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size());
|
||||||
if (color().value() != 0) {
|
if (color() != 0) {
|
||||||
StrAppend(&output, ", color ", color().value());
|
StrAppend(&output, ", color ", color());
|
||||||
}
|
}
|
||||||
if (is_entry_computation_parameter()) {
|
if (is_entry_computation_parameter()) {
|
||||||
const HloInstruction* param = GetEntryParameterInstruction(*this);
|
const HloInstruction* param = GetEntryParameterInstruction(*this);
|
||||||
@ -607,9 +607,7 @@ void BufferAssignment::AddAssignment(BufferAllocation* allocation,
|
|||||||
// BufferAllocation.
|
// BufferAllocation.
|
||||||
void BufferAssignment::CombineTempAllocations() {
|
void BufferAssignment::CombineTempAllocations() {
|
||||||
VLOG(1) << "CombineTempAllocations()";
|
VLOG(1) << "CombineTempAllocations()";
|
||||||
flat_hash_map<BufferValue::Color, BufferAllocation,
|
flat_hash_map<BufferValue::Color, BufferAllocation> combined_allocation_map;
|
||||||
BufferValue::Color::Hasher>
|
|
||||||
combined_allocation_map;
|
|
||||||
|
|
||||||
// Move all temp allocations into a single run at the end of the allocations
|
// Move all temp allocations into a single run at the end of the allocations
|
||||||
// vector.
|
// vector.
|
||||||
@ -1059,8 +1057,8 @@ Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) {
|
|||||||
|
|
||||||
// The instruction or operand color is excluded because it was assigned by
|
// The instruction or operand color is excluded because it was assigned by
|
||||||
// memory_space_assignment.
|
// memory_space_assignment.
|
||||||
if (excluded_colors.contains(instruction_buffer.color().value()) ||
|
if (excluded_colors.contains(instruction_buffer.color()) ||
|
||||||
excluded_colors.contains(operand_buffer.color().value())) {
|
excluded_colors.contains(operand_buffer.color())) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1353,13 +1351,10 @@ Status BufferAssigner::AssignBuffersForComputations(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>,
|
flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>>
|
||||||
LogicalBuffer::Color::Hasher>
|
|
||||||
BufferAssigner::SplitBuffersByColor(
|
BufferAssigner::SplitBuffersByColor(
|
||||||
const flat_hash_set<const HloValue*>& buffers) {
|
const flat_hash_set<const HloValue*>& buffers) {
|
||||||
flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>,
|
flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>> color_map;
|
||||||
LogicalBuffer::Color::Hasher>
|
|
||||||
color_map;
|
|
||||||
for (auto buffer : buffers) {
|
for (auto buffer : buffers) {
|
||||||
color_map[buffer->color()].insert(buffer);
|
color_map[buffer->color()].insert(buffer);
|
||||||
}
|
}
|
||||||
@ -1374,8 +1369,7 @@ Status BufferAssigner::AssignPresetBuffers(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create an allocation for each preset color.
|
// Create an allocation for each preset color.
|
||||||
absl::flat_hash_map<LogicalBuffer::Color, BufferAllocation*,
|
absl::flat_hash_map<LogicalBuffer::Color, BufferAllocation*>
|
||||||
LogicalBuffer::Color::Hasher>
|
|
||||||
preset_allocations;
|
preset_allocations;
|
||||||
for (auto& color_and_info : preset_assignments_->assignment_informations()) {
|
for (auto& color_and_info : preset_assignments_->assignment_informations()) {
|
||||||
LogicalBuffer::Color color(color_and_info.first);
|
LogicalBuffer::Color color(color_and_info.first);
|
||||||
|
@ -673,8 +673,7 @@ class BufferAssigner {
|
|||||||
// Split a set of buffers into several sets, each of which contains buffers
|
// Split a set of buffers into several sets, each of which contains buffers
|
||||||
// colored with the same color.
|
// colored with the same color.
|
||||||
absl::flat_hash_map<LogicalBuffer::Color,
|
absl::flat_hash_map<LogicalBuffer::Color,
|
||||||
absl::flat_hash_set<const HloValue*>,
|
absl::flat_hash_set<const HloValue*>>
|
||||||
LogicalBuffer::Color::Hasher>
|
|
||||||
SplitBuffersByColor(const absl::flat_hash_set<const HloValue*>& buffers);
|
SplitBuffersByColor(const absl::flat_hash_set<const HloValue*>& buffers);
|
||||||
|
|
||||||
// If true, allocate buffers for constant instructions.
|
// If true, allocate buffers for constant instructions.
|
||||||
|
@ -59,7 +59,7 @@ LogicalBufferProto BufferValue::ToProto(const SizeFunction& size_fn) const {
|
|||||||
ToLocationProto(*instruction(), index());
|
ToLocationProto(*instruction(), index());
|
||||||
proto.mutable_defined_at()->Swap(&proto_location);
|
proto.mutable_defined_at()->Swap(&proto_location);
|
||||||
if (has_color()) {
|
if (has_color()) {
|
||||||
proto.set_color(color().value());
|
proto.set_color(color());
|
||||||
}
|
}
|
||||||
return proto;
|
return proto;
|
||||||
}
|
}
|
||||||
|
@ -25,7 +25,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/gtl/int_type.h"
|
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
@ -86,7 +85,7 @@ namespace xla {
|
|||||||
|
|
||||||
class BufferValue {
|
class BufferValue {
|
||||||
public:
|
public:
|
||||||
TF_LIB_GTL_DEFINE_INT_TYPE(Color, int64);
|
using Color = int64;
|
||||||
|
|
||||||
// Id is a unique identifier for the BufferValue to facilitate efficient
|
// Id is a unique identifier for the BufferValue to facilitate efficient
|
||||||
// collections of BufferValues with stable iteration order.
|
// collections of BufferValues with stable iteration order.
|
||||||
@ -154,7 +153,7 @@ class BufferValue {
|
|||||||
static LogicalBufferProto::Location ToLocationProto(
|
static LogicalBufferProto::Location ToLocationProto(
|
||||||
const HloInstruction& instruction, const ShapeIndex& index);
|
const HloInstruction& instruction, const ShapeIndex& index);
|
||||||
|
|
||||||
const Color kInvalidColor = Color(-1);
|
const Color kInvalidColor = -1;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
BufferValue(HloInstruction* instruction, const ShapeIndex& index, Id id);
|
BufferValue(HloInstruction* instruction, const ShapeIndex& index, Id id);
|
||||||
|
@ -91,8 +91,7 @@ string HloValue::ToShortString() const {
|
|||||||
return absl::StrFormat(
|
return absl::StrFormat(
|
||||||
"<%d %s%s%s%s>", id(), instruction()->name(),
|
"<%d %s%s%s%s>", id(), instruction()->name(),
|
||||||
instruction()->shape().IsTuple() ? index().ToString() : "",
|
instruction()->shape().IsTuple() ? index().ToString() : "",
|
||||||
is_phi() ? " (phi)" : "",
|
is_phi() ? " (phi)" : "", has_color() ? StrCat(" @", color()) : "");
|
||||||
has_color() ? StrCat(" @", color().value()) : "");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
string HloValue::ToString(int indent) const {
|
string HloValue::ToString(int indent) const {
|
||||||
|
@ -34,7 +34,7 @@ LogicalBuffer::~LogicalBuffer() {}
|
|||||||
string LogicalBuffer::ToString() const {
|
string LogicalBuffer::ToString() const {
|
||||||
string color_string;
|
string color_string;
|
||||||
if (has_color()) {
|
if (has_color()) {
|
||||||
color_string = absl::StrCat(" @", color().value());
|
color_string = absl::StrCat(" @", color());
|
||||||
}
|
}
|
||||||
return absl::StrCat(instruction_->name(), "[", absl::StrJoin(index_, ","),
|
return absl::StrCat(instruction_->name(), "[", absl::StrJoin(index_, ","),
|
||||||
"](#", id(), color_string, ")");
|
"](#", id(), color_string, ")");
|
||||||
|
Loading…
x
Reference in New Issue
Block a user