Replace Layout and Tile protos with C++ classes in XLA.

No functional change. Rename the proto messages Layout and Tile to LayoutProto and TileProto, and define in-place replacement C++ classes named Layout and Tile whose interfaces mirror the protobuf-generated code. Having these data structures as C++ classes allows greater flexibility in the interface, enforcement of invariants, and potential performance improvements.
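An illustrative sketch (editor's addition, not part of the commit) of how the replacement classes are meant to be used; the wrapper function is hypothetical, but the identifiers are the ones introduced in this change:

    #include "tensorflow/compiler/xla/layout.h"

    namespace xla {
    void LayoutClassSketch() {
      // Construct a dense layout with minor_to_major = {1, 0}.
      Layout layout({1, 0});

      // Setters return Layout&, so calls chain; the protobuf-generated
      // API offers no such fluency.
      layout.set_element_size_in_bits(32).set_format(DENSE);

      // The proto-style accessors are preserved.
      for (int i = 0; i < layout.minor_to_major_size(); ++i) {
        const int64 dim = layout.minor_to_major(i);
        (void)dim;  // Unused in this sketch.
      }

      // Round-trip through the renamed proto message.
      LayoutProto proto = layout.ToProto();
      CHECK(Layout::CreateFromProto(proto) == layout);
    }
    }  // namespace xla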

PiperOrigin-RevId: 225121052
Mark Heffernan, 2018-12-11 20:56:59 -08:00; committed by TensorFlower Gardener
parent c6245fa0b4
commit 2381067873
17 changed files with 450 additions and 74 deletions

tensorflow/compiler/xla/BUILD

@@ -224,6 +224,7 @@ cc_library(
name = "shape_util",
srcs = [
"index_util.cc",
+ "layout.cc",
"layout_util.cc",
"primitive_util.cc",
"shape.cc",
@@ -231,6 +232,7 @@ cc_library(
],
hdrs = [
"index_util.h",
+ "layout.h",
"layout_util.h",
"primitive_util.h",
"shape.h",
@@ -301,6 +303,22 @@ tf_cc_test(
],
)
+ tf_cc_test(
+ name = "layout_test",
+ srcs = ["layout_test.cc"],
+ deps = [
+ ":shape_util",
+ ":status_macros",
+ ":test",
+ ":test_helpers",
+ ":types",
+ ":util",
+ ":xla_data_proto",
+ "//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
+ ],
+ )
tf_cc_test(
name = "index_util_test",
srcs = ["index_util_test.cc"],

tensorflow/compiler/xla/client/client.cc

@@ -186,7 +186,7 @@ StatusOr<Literal> Client::ComputeConstant(const XlaComputation& computation,
ComputeConstantGraphRequest request;
*request.mutable_computation() = computation.proto();
if (output_layout != nullptr) {
- *request.mutable_output_layout() = *output_layout;
+ *request.mutable_output_layout() = output_layout->ToProto();
}
ComputeConstantResponse response;

tensorflow/compiler/xla/layout.cc

@@ -0,0 +1,96 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/layout.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
namespace xla {
TileProto Tile::ToProto() const {
TileProto tile_proto;
for (int64 i : dimensions()) {
tile_proto.add_dimensions(i);
}
return tile_proto;
}
string Tile::ToString() const {
return absl::StrCat("(", absl::StrJoin(dimensions(), ","), ")");
}
/* static */ Layout Layout::CreateFromProto(const LayoutProto& proto) {
Layout layout;
layout.set_format(proto.format());
layout.minor_to_major_.reserve(proto.minor_to_major_size());
for (const int64 dimension : proto.minor_to_major()) {
layout.add_minor_to_major(dimension);
}
layout.set_max_sparse_elements(proto.max_sparse_elements());
for (const TileProto& tile_proto : proto.tiles()) {
*layout.add_tiles() = Tile::CreateFromProto(tile_proto);
}
layout.set_element_size_in_bits(proto.element_size_in_bits());
return layout;
}
LayoutProto Layout::ToProto() const {
LayoutProto proto;
proto.set_format(format_);
proto.mutable_minor_to_major()->Reserve(minor_to_major_size());
for (const int64 dimension : minor_to_major()) {
proto.add_minor_to_major(dimension);
}
proto.set_max_sparse_elements(max_sparse_elements_);
for (const Tile& tile : tiles()) {
*proto.add_tiles() = tile.ToProto();
}
proto.set_element_size_in_bits(element_size_in_bits());
return proto;
}
string Layout::ToString() const {
// TODO(b/119839262): Emit tiles in string.
if (format() == SPARSE) {
return absl::StrCat("sparse{", max_sparse_elements(), "}");
} else if (format() == DENSE) {
return absl::StrCat("{", absl::StrJoin(minor_to_major(), ","), "}");
} else {
CHECK_EQ(format(), INVALID_FORMAT);
return "invalid{}";
}
}
bool Layout::operator==(const Layout& other) const {
return (other.format() == format() &&
other.minor_to_major() == minor_to_major() &&
other.element_size_in_bits() == element_size_in_bits() &&
other.max_sparse_elements() == max_sparse_elements() &&
other.tiles() == tiles());
}
std::ostream& operator<<(std::ostream& out, const Tile& tile) {
out << tile.ToString();
return out;
}
std::ostream& operator<<(std::ostream& out, const Layout& layout) {
out << layout.ToString();
return out;
}
} // namespace xla

tensorflow/compiler/xla/layout.h

@@ -0,0 +1,187 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_LAYOUT_H_
#define TENSORFLOW_COMPILER_XLA_LAYOUT_H_
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
// Describes a tile used in tiling-based layout. Refer to
// g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for
// details.
class Tile {
public:
Tile() = default;
explicit Tile(absl::Span<const int64> dimensions)
: dimensions_(dimensions.begin(), dimensions.end()) {}
// De/Serialize a Tile to and from a TileProto.
static Tile CreateFromProto(const TileProto& tile_proto) {
return Tile(AsInt64Slice(tile_proto.dimensions()));
}
TileProto ToProto() const;
bool operator==(const Tile& other) const {
return dimensions() == other.dimensions();
}
bool operator!=(const Tile& other) const { return !(*this == other); }
string ToString() const;
// Returns the bound of the tile in the given dimension index.
int64 dimension(int i) const { return dimensions_.at(i); }
// Returns the dimensions of the tile.
const std::vector<int64>& dimensions() const { return dimensions_; }
private:
// The bounds of the tile.
std::vector<int64> dimensions_;
};
class Layout {
public:
Layout() = default;
// Constructs a dense layout with the given minor-to-major order.
explicit Layout(absl::Span<const int64> minor_to_major)
: format_(DENSE),
minor_to_major_(minor_to_major.begin(), minor_to_major.end()) {}
// Constructs a dense tiled layout with the given minor-to-major order and
// tiles.
Layout(absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles)
: format_(DENSE),
minor_to_major_(minor_to_major.begin(), minor_to_major.end()),
tiles_(tiles.begin(), tiles.end()) {}
// Construct a Layout from a LayoutProto.
static Layout CreateFromProto(const LayoutProto& proto);
// Returns a LayoutProto representation of the Layout.
LayoutProto ToProto() const;
// Returns a human-readable string that represents this layout.
string ToString() const;
bool operator==(const Layout& other) const;
bool operator!=(const Layout& other) const { return !(*this == other); }
// The following methods mirror the protobuf generated code interface for the
// message LayoutProto. This enabled easy migration of this data structure
// from a proto to a proper C++ class.
//
// TODO(b/29771030): Replace or augment these methods with a more ergonomic
// interface.
// Methods for accessing the format.
Format format() const { return format_; }
Layout& set_format(Format value) {
format_ = value;
return *this;
}
// Methods for accessing the minor-to-major array.
int minor_to_major_size() const { return minor_to_major_.size(); }
int64 minor_to_major(int index) const { return minor_to_major_.at(index); }
Layout& set_minor_to_major(int index, int64 value) {
minor_to_major_.at(index) = value;
return *this;
}
Layout& add_minor_to_major(int64 value) {
minor_to_major_.push_back(value);
return *this;
}
Layout& clear_minor_to_major() {
minor_to_major_.clear();
return *this;
}
const std::vector<int64>& minor_to_major() const { return minor_to_major_; }
std::vector<int64>* mutable_minor_to_major() { return &minor_to_major_; }
// Methods for accessing the tile field.
int tiles_size() const { return tiles_.size(); }
const Tile& tiles(int index) const { return tiles_.at(index); }
Tile* mutable_tiles(int index) { return &tiles_.at(index); }
Tile* add_tiles() {
tiles_.push_back(Tile());
return &tiles_.back();
}
Layout& clear_tiles() {
tiles_.clear();
return *this;
}
const std::vector<Tile>& tiles() const { return tiles_; }
std::vector<Tile>* mutable_tiles() { return &tiles_; }
// Methods for accessing the int64 fields.
int64 max_sparse_elements() const { return max_sparse_elements_; }
Layout& set_max_sparse_elements(int64 value) {
max_sparse_elements_ = value;
return *this;
}
int64 element_size_in_bits() const { return element_size_in_bits_; }
Layout& set_element_size_in_bits(int64 value) {
element_size_in_bits_ = value;
return *this;
}
void Swap(Layout* other) {
using std::swap;
swap(*this, *other);
}
void Clear() {
format_ = INVALID_FORMAT;
minor_to_major_.clear();
max_sparse_elements_ = 0;
element_size_in_bits_ = 0;
}
private:
// The format of this layout.
Format format_ = INVALID_FORMAT;
// Sequence of dimension numbers, from minor (fastest varying index) to major
// (slowest varying index).
std::vector<int64> minor_to_major_;
// The maximum number of elements that can be stored for SPARSE formats. This
// can be used to determine the maximum size in bytes of arrays stored in
// memory. This field must be zero unless the format is SPARSE.
int64 max_sparse_elements_ = 0;
// The number of bits used to store an individual array element.
int64 element_size_in_bits_ = 0;
// The tiles used in tiling-based layout.
std::vector<Tile> tiles_;
};
std::ostream& operator<<(std::ostream& out, const Tile& tile);
std::ostream& operator<<(std::ostream& out, const Layout& layout);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_LAYOUT_H_

tensorflow/compiler/xla/layout_test.cc

@@ -0,0 +1,104 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/layout.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace {
class LayoutTest : public ::testing::Test {};
TEST_F(LayoutTest, ToString) {
EXPECT_EQ(Layout().ToString(), "invalid{}");
EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}");
EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(123).ToString(),
"sparse{123}");
EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}");
EXPECT_EQ(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})}).ToString(),
"{3,2,1,0}");
EXPECT_EQ(
Layout({1, 0}, {Tile({2, 55})}).set_element_size_in_bits(42).ToString(),
"{1,0}");
}
TEST_F(LayoutTest, StreamOut) {
{
std::ostringstream oss;
oss << Tile({7, 8});
EXPECT_EQ(oss.str(), "(7,8)");
}
{
std::ostringstream oss;
oss << Layout({0, 1, 2});
EXPECT_EQ(oss.str(), "{0,1,2}");
}
}
TEST_F(LayoutTest, SparseLayoutMaxElements) {
EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)),
101);
}
TEST_F(LayoutTest, Equality) {
EXPECT_EQ(Layout(), Layout());
const std::vector<int64> empty_dims;
EXPECT_EQ(Layout(empty_dims), Layout(empty_dims));
EXPECT_NE(Layout(), Layout(empty_dims));
EXPECT_EQ(Layout({0, 1, 2, 3}), Layout({0, 1, 2, 3}));
EXPECT_NE(Layout({0, 1, 2, 3}), Layout({0, 1, 2}));
EXPECT_EQ(Layout({0, 1, 2}, {Tile({42, 44})}),
Layout({0, 1, 2}, {Tile({42, 44})}));
EXPECT_NE(Layout({0, 1, 2}, {Tile({42, 44})}),
Layout({0, 1, 2}, {Tile({42, 45})}));
EXPECT_NE(Layout({0, 1, 2}, {Tile({42, 44})}), Layout({0, 1, 2, 3}));
EXPECT_EQ(Layout({0, 1, 2}).set_element_size_in_bits(33),
Layout({0, 1, 2}).set_element_size_in_bits(33));
EXPECT_NE(Layout({0, 1, 2}).set_element_size_in_bits(33),
Layout({0, 1, 2}).set_element_size_in_bits(7));
EXPECT_EQ(Layout().set_format(SPARSE), Layout().set_format(SPARSE));
EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(42),
Layout().set_format(SPARSE).set_max_sparse_elements(42));
EXPECT_NE(Layout().set_format(SPARSE).set_max_sparse_elements(42),
Layout().set_format(SPARSE).set_max_sparse_elements(24));
}
TEST_F(LayoutTest, LayoutToFromProto) {
// Round-trips a Layout through proto de/serialization.
auto expect_unchanged = [](const Layout& layout) {
EXPECT_EQ(layout, Layout::CreateFromProto(layout.ToProto()));
};
expect_unchanged(Layout());
expect_unchanged(Layout({1, 3, 2, 0}));
expect_unchanged(Layout().set_format(SPARSE));
expect_unchanged(Layout().set_format(SPARSE).set_max_sparse_elements(123));
expect_unchanged(Layout({0, 1}).set_element_size_in_bits(42));
expect_unchanged(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})}));
}
} // namespace
} // namespace xla

tensorflow/compiler/xla/layout_util.cc

@@ -41,15 +41,13 @@ namespace {
// Internal helper for GetDefaultLayoutForShape and SetToDefaultLayout. Sets
// minor_to_major to the value that represents the default layout.
- void SetDefaultLayoutToContainer(
- tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>*
- minor_to_major) {
+ void SetDefaultLayoutToContainer(std::vector<int64>* minor_to_major) {
// The default XLA layout is major-to-minor (dim 0 is major).
// For more information on XLA layouts, see:
// https://www.tensorflow.org/performance/xla/shapes
const int64 size = minor_to_major->size();
for (int64 i = 0; i < size; ++i) {
- minor_to_major->Set(i, size - 1 - i);
+ (*minor_to_major)[i] = size - 1 - i;
}
}
@@ -94,9 +92,8 @@ namespace {
Layout CreateDefaultLayoutForRank(int64 rank) {
Layout layout;
layout.set_format(DENSE);
- tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>*
- minor_to_major = layout.mutable_minor_to_major();
- minor_to_major->Resize(rank, 0);
+ std::vector<int64>* minor_to_major = layout.mutable_minor_to_major();
+ minor_to_major->resize(rank, 0);
SetDefaultLayoutToContainer(minor_to_major);
return layout;
}
@@ -139,9 +136,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
shape->clear_layout();
} else if (ShapeUtil::IsArray(*shape)) {
shape->mutable_layout()->set_format(DENSE);
- tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>*
- minor_to_major = shape->mutable_layout()->mutable_minor_to_major();
- minor_to_major->Resize(shape->dimensions_size(), 0);
+ auto* minor_to_major = shape->mutable_layout()->mutable_minor_to_major();
+ minor_to_major->resize(shape->dimensions_size(), 0);
SetDefaultLayoutToContainer(minor_to_major);
} else {
// Opaque, token types etc. have no layout.
@@ -210,9 +206,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
}
if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) {
- return InvalidArgument(
- "Layout has an invalid format (%d) in layout {%s}, shape {%s}",
- layout.format(), layout.ShortDebugString(), shape.ShortDebugString());
+ return InvalidArgument("Layout has an invalid format (%d)",
+ layout.format());
}
if (layout.format() == DENSE) {
@@ -316,7 +311,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
}
/* static */ bool LayoutUtil::Equal(const Layout& lhs, const Layout& rhs) {
- return protobuf_util::ProtobufEquals(lhs, rhs);
+ return lhs == rhs;
}
/* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
@@ -358,11 +353,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
}
/* static */ string LayoutUtil::HumanString(const Layout& layout) {
- if (IsSparse(layout)) {
- return absl::StrCat("sparse{", layout.max_sparse_elements(), "}");
- }
- CHECK(IsDense(layout));
- return absl::StrCat("{", absl::StrJoin(layout.minor_to_major(), ","), "}");
+ return layout.ToString();
}
namespace {
@@ -444,11 +435,6 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
return true;
}
- std::ostream& operator<<(std::ostream& out, const Layout& layout) {
- out << LayoutUtil::HumanString(layout);
- return out;
- }
/*static*/ size_t LayoutUtil::Hash(const Layout& layout) {
using tensorflow::hash;
using tensorflow::Hash64Combine;
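
Editor's aside (not part of the diff): with the container switched from a proto RepeatedField to std::vector<int64>, the default-layout logic above still fills minor_to_major in reverse order, so dimension 0 remains major. A hypothetical check, assuming the existing LayoutUtil::GetDefaultLayoutForRank helper:

    // Rank 3 yields minor_to_major = {2, 1, 0}, i.e. dimension 0 is major.
    xla::Layout layout = xla::LayoutUtil::GetDefaultLayoutForRank(3);
    CHECK(layout.minor_to_major() == std::vector<xla::int64>({2, 1, 0}));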

tensorflow/compiler/xla/layout_util.h

@@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include "absl/types/span.h"
+ #include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
@@ -195,8 +196,6 @@ class LayoutUtil {
TF_DISALLOW_COPY_AND_ASSIGN(LayoutUtil);
};
- std::ostream& operator<<(std::ostream& out, const Layout& layout);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_LAYOUT_UTIL_H_

tensorflow/compiler/xla/layout_util_test.cc

@@ -317,17 +317,6 @@ TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) {
ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25}))));
}
- TEST_F(LayoutUtilTest, SparseLayoutMaxElements) {
- EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)),
- 101);
- }
- TEST_F(LayoutUtilTest, StreamOut) {
- std::ostringstream oss;
- oss << LayoutUtil::MakeLayout({0, 1, 2});
- EXPECT_EQ(oss.str(), "{0,1,2}");
- }
TEST_F(LayoutUtilTest, ValidateLayout_ValidArrayLayout) {
Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1});
auto status =

tensorflow/compiler/xla/packed_literal_reader.cc

@@ -42,8 +42,7 @@ PackedLiteralReader::~PackedLiteralReader() { delete file_; }
StatusOr<Literal> PackedLiteralReader::Read(const Shape& shape,
const Layout* layout) {
VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape)
- << " layout: "
- << (layout == nullptr ? "<none>" : layout->ShortDebugString());
+ << " layout: " << (layout == nullptr ? "<none>" : layout->ToString());
Shape literal_shape = shape;
if (layout != nullptr) {
TF_RETURN_IF_ERROR(

tensorflow/compiler/xla/service/gpu/stream_executor_util.h

@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_
+ #include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

tensorflow/compiler/xla/service/service.cc

@@ -1078,9 +1078,11 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
ProgramShape program_shape(arg->computation().host_program_shape());
TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
+ absl::optional<Layout> output_layout;
if (arg->has_output_layout()) {
+ output_layout = Layout::CreateFromProto(arg->output_layout());
TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
- arg->output_layout(), program_shape.result()));
+ *output_layout, program_shape.result()));
}
HloModuleConfig config(program_shape);
@@ -1096,8 +1098,8 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
// relayout here.
//
// TODO(b/77824332): Make HloEvaluator take care of the re-layout.
- if (arg->has_output_layout()) {
- result_literal = result_literal.Relayout(arg->output_layout());
+ if (output_layout.has_value()) {
+ result_literal = result_literal.Relayout(*output_layout);
}
*result->mutable_literal() = result_literal.ToProto();

tensorflow/compiler/xla/shape.cc

@@ -32,7 +32,7 @@ Shape::Shape(const ShapeProto& shape_proto) {
*add_tuple_shapes() = Shape(element_shape);
}
if (shape_proto.has_layout()) {
- *mutable_layout() = shape_proto.layout();
+ *mutable_layout() = Layout::CreateFromProto(shape_proto.layout());
}
}
@@ -48,7 +48,7 @@ ShapeProto Shape::ToProto() const {
*proto.add_tuple_shapes() = shape.ToProto();
}
if (has_layout()) {
- *proto.mutable_layout() = layout();
+ *proto.mutable_layout() = layout().ToProto();
}
return proto;
}

tensorflow/compiler/xla/shape.h

@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "absl/types/optional.h"
+ #include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
@@ -76,21 +77,10 @@ class Shape {
std::vector<Shape>* mutable_tuple_shapes() { return &tuple_shapes_; }
// Methods for accessing the layout field.
- bool has_layout() const { return layout_.has_value(); }
- const Layout& layout() const {
- if (layout_.has_value()) {
- return *layout_;
- } else {
- return Layout::default_instance();
- }
- }
- Layout* mutable_layout() {
- if (!layout_.has_value()) {
- layout_ = Layout();
- }
- return &layout_.value();
- }
- void clear_layout() { layout_.reset(); }
+ bool has_layout() const { return layout_.format() != INVALID_FORMAT; }
+ const Layout& layout() const { return layout_; }
+ Layout* mutable_layout() { return &layout_; }
+ void clear_layout() { layout_.Clear(); }
void Swap(Shape* other) {
using std::swap;
@@ -101,7 +91,7 @@ class Shape {
element_type_ = PRIMITIVE_TYPE_INVALID;
dimensions_.clear();
tuple_shapes_.clear();
- layout_.reset();
+ clear_layout();
}
string SerializeAsString() const { return ToProto().SerializeAsString(); }
@@ -118,8 +108,8 @@ class Shape {
// The tuple element subshapes. This is nonempty only for tuple shapes.
std::vector<Shape> tuple_shapes_;
- // The array layout of the shape. This is present only for array shapes.
- absl::optional<Layout> layout_;
+ // The layout of the shape. Only relevant for arrays.
+ Layout layout_;
};
// Shape of the parameters and output of an XLA computation. This is analogous

tensorflow/compiler/xla/shape_util.cc

@@ -164,9 +164,9 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeUtil::MakeValidatedShape(element_type, dimensions));
auto min2maj = shape.mutable_layout()->mutable_minor_to_major();
- min2maj->Clear();
+ min2maj->clear();
for (int64 value : minor_to_major) {
- min2maj->Add(value);
+ min2maj->push_back(value);
}
if (!shape.has_layout()) {
return InvalidArgument("Shape has no layout.");
@@ -1618,10 +1618,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
if (LayoutUtil::HasLayout(shape)) {
Layout* layout = shape.mutable_layout();
layout->set_format(DENSE);
- for (size_t i = 0; i < layout->minor_to_major().size();) {
+ for (int64 i = 0; i < layout->minor_to_major().size();) {
if (layout->minor_to_major(i) == dim_to_delete) {
layout->mutable_minor_to_major()->erase(
- layout->minor_to_major().begin() + i);
+ layout->mutable_minor_to_major()->begin() + i);
continue;
}
if (layout->minor_to_major(i) > dim_to_delete) {

tensorflow/compiler/xla/tests/copy_test.cc

@@ -133,7 +133,9 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
// Reverse the minor-to-major order of the literal.
Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout();
ASSERT_EQ(2, literal_layout->minor_to_major_size());
- literal_layout->mutable_minor_to_major()->SwapElements(0, 1);
+ // Swap the first and second elements.
+ *literal_layout->mutable_minor_to_major() = {
+ literal_layout->minor_to_major(1), literal_layout->minor_to_major(0)};
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));

tensorflow/compiler/xla/xla.proto

@@ -399,7 +399,7 @@ message WaitForExecutionResponse {
message ComputeConstantGraphRequest {
HloModuleProto computation = 1;
- Layout output_layout = 2;
+ LayoutProto output_layout = 2;
}
message ComputeConstantResponse {

tensorflow/compiler/xla/xla_data.proto

@@ -100,6 +100,8 @@ message PaddingConfig {
// A format specifies the method used by a layout to store an array in memory.
enum Format {
+ // TODO(b/120869032): Rename this to FORMAT_NONE or something else which
+ // better corresponds to its meaning.
INVALID_FORMAT = 0;
// The default layout, with exactly one storage location per element.
DENSE = 1;
@@ -109,8 +111,9 @@
}
// Describes a tile used in tiling-based layout. Refer to
- // g3doc/layout_with_tiling.md for details about tiling-based layout.
- message Tile {
+ // g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for
+ // details about tiling-based layout.
+ message TileProto {
// Number of elements in each dimension of the tile. It's ordered from the
// most major dimension of the tile to the most minor dimension of the tile.
// The dimensions correspond to a suffix of the dimensions of the shape being
@@ -128,7 +131,7 @@ message Tile {
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
- message Layout {
+ message LayoutProto {
// The method used to store the data in memory. The format determines which of
// the other fields are used by the layout.
Format format = 4;
@@ -153,7 +156,7 @@ message Layout {
//
// TODO(b/119839262): implement tiling in each backend or add Unimplemented
// error.
- repeated Tile tiles = 6;
+ repeated TileProto tiles = 6;
// Bit size of each element. If the size is bigger than what the element
// type requires, the value is stored in the least significant
@@ -196,7 +199,7 @@ message ShapeProto {
repeated ShapeProto tuple_shapes = 4;
// The layout used to back this shape.
- Layout layout = 5;
+ LayoutProto layout = 5;
// Important: if any field is added, be sure to modify ShapeUtil::Equal(),
// ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for