[XLA] BufferValue::Color now type aliases int64.

PiperOrigin-RevId: 313404227
Change-Id: I2d393d426865c61ff210f10e3d9b8402a1813cf1
This commit is contained in:
Berkin Ilbeyi 2020-05-27 09:31:24 -07:00 committed by TensorFlower Gardener
parent be4d17c088
commit 96ba1c3609
6 changed files with 16 additions and 25 deletions

View File

@ -261,7 +261,7 @@ void BufferAllocation::AddAssignment(const HloValue& buffer, int64 offset,
Shape* shape = ShapeUtil::GetMutableSubshape(
position.instruction->mutable_shape(), position.index);
if (shape->has_layout()) {
shape->mutable_layout()->set_memory_space(buffer.color().value());
shape->mutable_layout()->set_memory_space(buffer.color());
}
}
}
@ -272,7 +272,7 @@ BufferAllocationProto BufferAllocation::ToProto() const {
proto.set_size(size_);
proto.set_is_thread_local(is_thread_local_);
proto.set_is_tuple(is_tuple_);
proto.set_color(color_.value());
proto.set_color(color_);
if (is_entry_computation_parameter_) {
proto.set_is_entry_computation_parameter(true);
for (int64 idx : param_shape_index()) {
@ -336,8 +336,8 @@ static const HloInstruction* GetOutputInstruction(
string BufferAllocation::ToString() const {
string output;
StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size());
if (color().value() != 0) {
StrAppend(&output, ", color ", color().value());
if (color() != 0) {
StrAppend(&output, ", color ", color());
}
if (is_entry_computation_parameter()) {
const HloInstruction* param = GetEntryParameterInstruction(*this);
@ -607,9 +607,7 @@ void BufferAssignment::AddAssignment(BufferAllocation* allocation,
// BufferAllocation.
void BufferAssignment::CombineTempAllocations() {
VLOG(1) << "CombineTempAllocations()";
flat_hash_map<BufferValue::Color, BufferAllocation,
BufferValue::Color::Hasher>
combined_allocation_map;
flat_hash_map<BufferValue::Color, BufferAllocation> combined_allocation_map;
// Move all temp allocations into a single run at the end of the allocations
// vector.
@ -1059,8 +1057,8 @@ Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) {
// The instruction or operand color is excluded because it was assigned by
// memory_space_assignment.
if (excluded_colors.contains(instruction_buffer.color().value()) ||
excluded_colors.contains(operand_buffer.color().value())) {
if (excluded_colors.contains(instruction_buffer.color()) ||
excluded_colors.contains(operand_buffer.color())) {
continue;
}
@ -1353,13 +1351,10 @@ Status BufferAssigner::AssignBuffersForComputations(
return Status::OK();
}
flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>,
LogicalBuffer::Color::Hasher>
flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>>
BufferAssigner::SplitBuffersByColor(
const flat_hash_set<const HloValue*>& buffers) {
flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>,
LogicalBuffer::Color::Hasher>
color_map;
flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>> color_map;
for (auto buffer : buffers) {
color_map[buffer->color()].insert(buffer);
}
@ -1374,8 +1369,7 @@ Status BufferAssigner::AssignPresetBuffers(
}
// Create an allocation for each preset color.
absl::flat_hash_map<LogicalBuffer::Color, BufferAllocation*,
LogicalBuffer::Color::Hasher>
absl::flat_hash_map<LogicalBuffer::Color, BufferAllocation*>
preset_allocations;
for (auto& color_and_info : preset_assignments_->assignment_informations()) {
LogicalBuffer::Color color(color_and_info.first);

View File

@ -673,8 +673,7 @@ class BufferAssigner {
// Split a set of buffers into several sets, each of which contains buffers
// colored with the same color.
absl::flat_hash_map<LogicalBuffer::Color,
absl::flat_hash_set<const HloValue*>,
LogicalBuffer::Color::Hasher>
absl::flat_hash_set<const HloValue*>>
SplitBuffersByColor(const absl::flat_hash_set<const HloValue*>& buffers);
// If true, allocate buffers for constant instructions.

View File

@ -59,7 +59,7 @@ LogicalBufferProto BufferValue::ToProto(const SizeFunction& size_fn) const {
ToLocationProto(*instruction(), index());
proto.mutable_defined_at()->Swap(&proto_location);
if (has_color()) {
proto.set_color(color().value());
proto.set_color(color());
}
return proto;
}

View File

@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/int_type.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@ -86,7 +85,7 @@ namespace xla {
class BufferValue {
public:
TF_LIB_GTL_DEFINE_INT_TYPE(Color, int64);
using Color = int64;
// Id is a unique identifier for the BufferValue to facilitate efficient
// collections of BufferValues with stable iteration order.
@ -154,7 +153,7 @@ class BufferValue {
static LogicalBufferProto::Location ToLocationProto(
const HloInstruction& instruction, const ShapeIndex& index);
const Color kInvalidColor = Color(-1);
const Color kInvalidColor = -1;
protected:
BufferValue(HloInstruction* instruction, const ShapeIndex& index, Id id);

View File

@ -91,8 +91,7 @@ string HloValue::ToShortString() const {
return absl::StrFormat(
"<%d %s%s%s%s>", id(), instruction()->name(),
instruction()->shape().IsTuple() ? index().ToString() : "",
is_phi() ? " (phi)" : "",
has_color() ? StrCat(" @", color().value()) : "");
is_phi() ? " (phi)" : "", has_color() ? StrCat(" @", color()) : "");
}
string HloValue::ToString(int indent) const {

View File

@ -34,7 +34,7 @@ LogicalBuffer::~LogicalBuffer() {}
string LogicalBuffer::ToString() const {
string color_string;
if (has_color()) {
color_string = absl::StrCat(" @", color().value());
color_string = absl::StrCat(" @", color());
}
return absl::StrCat(instruction_->name(), "[", absl::StrJoin(index_, ","),
"](#", id(), color_string, ")");