[XLA] BufferValue::Color now type aliases int64.
PiperOrigin-RevId: 313404227 Change-Id: I2d393d426865c61ff210f10e3d9b8402a1813cf1
This commit is contained in:
parent
be4d17c088
commit
96ba1c3609
@ -261,7 +261,7 @@ void BufferAllocation::AddAssignment(const HloValue& buffer, int64 offset,
|
||||
Shape* shape = ShapeUtil::GetMutableSubshape(
|
||||
position.instruction->mutable_shape(), position.index);
|
||||
if (shape->has_layout()) {
|
||||
shape->mutable_layout()->set_memory_space(buffer.color().value());
|
||||
shape->mutable_layout()->set_memory_space(buffer.color());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -272,7 +272,7 @@ BufferAllocationProto BufferAllocation::ToProto() const {
|
||||
proto.set_size(size_);
|
||||
proto.set_is_thread_local(is_thread_local_);
|
||||
proto.set_is_tuple(is_tuple_);
|
||||
proto.set_color(color_.value());
|
||||
proto.set_color(color_);
|
||||
if (is_entry_computation_parameter_) {
|
||||
proto.set_is_entry_computation_parameter(true);
|
||||
for (int64 idx : param_shape_index()) {
|
||||
@ -336,8 +336,8 @@ static const HloInstruction* GetOutputInstruction(
|
||||
string BufferAllocation::ToString() const {
|
||||
string output;
|
||||
StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size());
|
||||
if (color().value() != 0) {
|
||||
StrAppend(&output, ", color ", color().value());
|
||||
if (color() != 0) {
|
||||
StrAppend(&output, ", color ", color());
|
||||
}
|
||||
if (is_entry_computation_parameter()) {
|
||||
const HloInstruction* param = GetEntryParameterInstruction(*this);
|
||||
@ -607,9 +607,7 @@ void BufferAssignment::AddAssignment(BufferAllocation* allocation,
|
||||
// BufferAllocation.
|
||||
void BufferAssignment::CombineTempAllocations() {
|
||||
VLOG(1) << "CombineTempAllocations()";
|
||||
flat_hash_map<BufferValue::Color, BufferAllocation,
|
||||
BufferValue::Color::Hasher>
|
||||
combined_allocation_map;
|
||||
flat_hash_map<BufferValue::Color, BufferAllocation> combined_allocation_map;
|
||||
|
||||
// Move all temp allocations into a single run at the end of the allocations
|
||||
// vector.
|
||||
@ -1059,8 +1057,8 @@ Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) {
|
||||
|
||||
// The instruction or operand color is excluded because it was assigned by
|
||||
// memory_space_assignment.
|
||||
if (excluded_colors.contains(instruction_buffer.color().value()) ||
|
||||
excluded_colors.contains(operand_buffer.color().value())) {
|
||||
if (excluded_colors.contains(instruction_buffer.color()) ||
|
||||
excluded_colors.contains(operand_buffer.color())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -1353,13 +1351,10 @@ Status BufferAssigner::AssignBuffersForComputations(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>,
|
||||
LogicalBuffer::Color::Hasher>
|
||||
flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>>
|
||||
BufferAssigner::SplitBuffersByColor(
|
||||
const flat_hash_set<const HloValue*>& buffers) {
|
||||
flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>,
|
||||
LogicalBuffer::Color::Hasher>
|
||||
color_map;
|
||||
flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>> color_map;
|
||||
for (auto buffer : buffers) {
|
||||
color_map[buffer->color()].insert(buffer);
|
||||
}
|
||||
@ -1374,8 +1369,7 @@ Status BufferAssigner::AssignPresetBuffers(
|
||||
}
|
||||
|
||||
// Create an allocation for each preset color.
|
||||
absl::flat_hash_map<LogicalBuffer::Color, BufferAllocation*,
|
||||
LogicalBuffer::Color::Hasher>
|
||||
absl::flat_hash_map<LogicalBuffer::Color, BufferAllocation*>
|
||||
preset_allocations;
|
||||
for (auto& color_and_info : preset_assignments_->assignment_informations()) {
|
||||
LogicalBuffer::Color color(color_and_info.first);
|
||||
|
@ -673,8 +673,7 @@ class BufferAssigner {
|
||||
// Split a set of buffers into several sets, each of which contains buffers
|
||||
// colored with the same color.
|
||||
absl::flat_hash_map<LogicalBuffer::Color,
|
||||
absl::flat_hash_set<const HloValue*>,
|
||||
LogicalBuffer::Color::Hasher>
|
||||
absl::flat_hash_set<const HloValue*>>
|
||||
SplitBuffersByColor(const absl::flat_hash_set<const HloValue*>& buffers);
|
||||
|
||||
// If true, allocate buffers for constant instructions.
|
||||
|
@ -59,7 +59,7 @@ LogicalBufferProto BufferValue::ToProto(const SizeFunction& size_fn) const {
|
||||
ToLocationProto(*instruction(), index());
|
||||
proto.mutable_defined_at()->Swap(&proto_location);
|
||||
if (has_color()) {
|
||||
proto.set_color(color().value());
|
||||
proto.set_color(color());
|
||||
}
|
||||
return proto;
|
||||
}
|
||||
|
@ -25,7 +25,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/int_type.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -86,7 +85,7 @@ namespace xla {
|
||||
|
||||
class BufferValue {
|
||||
public:
|
||||
TF_LIB_GTL_DEFINE_INT_TYPE(Color, int64);
|
||||
using Color = int64;
|
||||
|
||||
// Id is a unique identifier for the BufferValue to facilitate efficient
|
||||
// collections of BufferValues with stable iteration order.
|
||||
@ -154,7 +153,7 @@ class BufferValue {
|
||||
static LogicalBufferProto::Location ToLocationProto(
|
||||
const HloInstruction& instruction, const ShapeIndex& index);
|
||||
|
||||
const Color kInvalidColor = Color(-1);
|
||||
const Color kInvalidColor = -1;
|
||||
|
||||
protected:
|
||||
BufferValue(HloInstruction* instruction, const ShapeIndex& index, Id id);
|
||||
|
@ -91,8 +91,7 @@ string HloValue::ToShortString() const {
|
||||
return absl::StrFormat(
|
||||
"<%d %s%s%s%s>", id(), instruction()->name(),
|
||||
instruction()->shape().IsTuple() ? index().ToString() : "",
|
||||
is_phi() ? " (phi)" : "",
|
||||
has_color() ? StrCat(" @", color().value()) : "");
|
||||
is_phi() ? " (phi)" : "", has_color() ? StrCat(" @", color()) : "");
|
||||
}
|
||||
|
||||
string HloValue::ToString(int indent) const {
|
||||
|
@ -34,7 +34,7 @@ LogicalBuffer::~LogicalBuffer() {}
|
||||
string LogicalBuffer::ToString() const {
|
||||
string color_string;
|
||||
if (has_color()) {
|
||||
color_string = absl::StrCat(" @", color().value());
|
||||
color_string = absl::StrCat(" @", color());
|
||||
}
|
||||
return absl::StrCat(instruction_->name(), "[", absl::StrJoin(index_, ","),
|
||||
"](#", id(), color_string, ")");
|
||||
|
Loading…
x
Reference in New Issue
Block a user