diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 4360e085796..19f12569ff9 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -224,6 +224,7 @@ cc_library( name = "shape_util", srcs = [ "index_util.cc", + "layout.cc", "layout_util.cc", "primitive_util.cc", "shape.cc", @@ -231,6 +232,7 @@ cc_library( ], hdrs = [ "index_util.h", + "layout.h", "layout_util.h", "primitive_util.h", "shape.h", @@ -301,6 +303,22 @@ tf_cc_test( ], ) +tf_cc_test( + name = "layout_test", + srcs = ["layout_test.cc"], + deps = [ + ":shape_util", + ":status_macros", + ":test", + ":test_helpers", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) + tf_cc_test( name = "index_util_test", srcs = ["index_util_test.cc"], diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 74b76f92994..43127cae1e5 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -186,7 +186,7 @@ StatusOr Client::ComputeConstant(const XlaComputation& computation, ComputeConstantGraphRequest request; *request.mutable_computation() = computation.proto(); if (output_layout != nullptr) { - *request.mutable_output_layout() = *output_layout; + *request.mutable_output_layout() = output_layout->ToProto(); } ComputeConstantResponse response; diff --git a/tensorflow/compiler/xla/layout.cc b/tensorflow/compiler/xla/layout.cc new file mode 100644 index 00000000000..e3b5fcd5274 --- /dev/null +++ b/tensorflow/compiler/xla/layout.cc @@ -0,0 +1,96 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/layout.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/layout_util.h" + +namespace xla { + +TileProto Tile::ToProto() const { + TileProto tile_proto; + for (int64 i : dimensions()) { + tile_proto.add_dimensions(i); + } + return tile_proto; +} + +string Tile::ToString() const { + return absl::StrCat("(", absl::StrJoin(dimensions(), ","), ")"); +} + +/* static */ Layout Layout::CreateFromProto(const LayoutProto& proto) { + Layout layout; + layout.set_format(proto.format()); + layout.minor_to_major_.reserve(proto.minor_to_major_size()); + for (const int64 dimension : proto.minor_to_major()) { + layout.add_minor_to_major(dimension); + } + layout.set_max_sparse_elements(proto.max_sparse_elements()); + for (const TileProto& tile_proto : proto.tiles()) { + *layout.add_tiles() = Tile::CreateFromProto(tile_proto); + } + layout.set_element_size_in_bits(proto.element_size_in_bits()); + return layout; +} + +LayoutProto Layout::ToProto() const { + LayoutProto proto; + proto.set_format(format_); + proto.mutable_minor_to_major()->Reserve(minor_to_major_size()); + for (const int64 dimension : minor_to_major()) { + proto.add_minor_to_major(dimension); + } + proto.set_max_sparse_elements(max_sparse_elements_); + for (const Tile& tile : tiles()) { + *proto.add_tiles() = tile.ToProto(); + } + proto.set_element_size_in_bits(element_size_in_bits()); + return proto; +} + +string Layout::ToString() const { + // TODO(b/119839262): Emit tiles in string. + if (format() == SPARSE) { + return absl::StrCat("sparse{", max_sparse_elements(), "}"); + } else if (format() == DENSE) { + return absl::StrCat("{", absl::StrJoin(minor_to_major(), ","), "}"); + } else { + CHECK_EQ(format(), INVALID_FORMAT); + return "invalid{}"; + } +} + +bool Layout::operator==(const Layout& other) const { + return (other.format() == format() && + other.minor_to_major() == minor_to_major() && + other.element_size_in_bits() == element_size_in_bits() && + other.max_sparse_elements() == max_sparse_elements() && + other.tiles() == tiles()); +} + +std::ostream& operator<<(std::ostream& out, const Tile& tile) { + out << tile.ToString(); + return out; +} + +std::ostream& operator<<(std::ostream& out, const Layout& layout) { + out << layout.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/layout.h b/tensorflow/compiler/xla/layout.h new file mode 100644 index 00000000000..313368c39e4 --- /dev/null +++ b/tensorflow/compiler/xla/layout.h @@ -0,0 +1,187 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LAYOUT_H_ +#define TENSORFLOW_COMPILER_XLA_LAYOUT_H_ + +#include + +#include "absl/types/span.h" + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Describes a tile used in tiling-based layout. Refer to +// g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for +// details. +class Tile { + public: + Tile() = default; + explicit Tile(absl::Span dimensions) + : dimensions_(dimensions.begin(), dimensions.end()) {} + + // De/Serialize a Tile to and from a TileProto. + static Tile CreateFromProto(const TileProto& tile_proto) { + return Tile(AsInt64Slice(tile_proto.dimensions())); + } + TileProto ToProto() const; + + bool operator==(const Tile& other) const { + return dimensions() == other.dimensions(); + } + bool operator!=(const Tile& other) const { return !(*this == other); } + + string ToString() const; + + // Returns the bound of the tile in the given dimension index. + int64 dimension(int i) const { return dimensions_.at(i); } + + // Returns the dimensions of the tile. + const std::vector& dimensions() const { return dimensions_; } + + private: + // The bounds of the tile. + std::vector dimensions_; +}; + +class Layout { + public: + Layout() = default; + + // Constructs a dense layout with the given minor-to-major order. + explicit Layout(absl::Span minor_to_major) + : format_(DENSE), + minor_to_major_(minor_to_major.begin(), minor_to_major.end()) {} + + // Constructs a dense tiled layout with the given minor-to-major order and + // tiles. + Layout(absl::Span minor_to_major, absl::Span tiles) + : format_(DENSE), + minor_to_major_(minor_to_major.begin(), minor_to_major.end()), + tiles_(tiles.begin(), tiles.end()) {} + + // Construct a shape from a LayoutProto. + static Layout CreateFromProto(const LayoutProto& proto); + + // Returns a LayoutProto representation of the Layout. + LayoutProto ToProto() const; + + // Returns a human-readable string that represents this layout. + string ToString() const; + + bool operator==(const Layout& other) const; + bool operator!=(const Layout& other) const { return !(*this == other); } + + // The following methods mirror the protobuf generated code interface for the + // message LayoutProto. This enabled easy migration of this data structure + // from a proto to a proper C++ class. + // + // TODO(b/29771030): Replace or augment these methods with a more ergonomic + // interface. + + // Methods for accessing the format. + Format format() const { return format_; } + Layout& set_format(Format value) { + format_ = value; + return *this; + } + + // Methods for accessing the minor-to-major array. + int minor_to_major_size() const { return minor_to_major_.size(); } + int64 minor_to_major(int index) const { return minor_to_major_.at(index); } + Layout& set_minor_to_major(int index, int64 value) { + minor_to_major_.at(index) = value; + return *this; + } + Layout& add_minor_to_major(int64 value) { + minor_to_major_.push_back(value); + return *this; + } + Layout& clear_minor_to_major() { + minor_to_major_.clear(); + return *this; + } + const std::vector& minor_to_major() const { return minor_to_major_; } + std::vector* mutable_minor_to_major() { return &minor_to_major_; } + + // Methods for accessing the tile field. + int tiles_size() const { return tiles_.size(); } + const Tile& tiles(int index) const { return tiles_.at(index); } + Tile* mutable_tiles(int index) { return &tiles_.at(index); } + Tile* add_tiles() { + tiles_.push_back(Tile()); + return &tiles_.back(); + } + Layout& clear_tiles() { + tiles_.clear(); + return *this; + } + const std::vector& tiles() const { return tiles_; } + std::vector* mutable_tiles() { return &tiles_; } + + // Methods for accessing the int64 fields. + int64 max_sparse_elements() const { return max_sparse_elements_; } + Layout& set_max_sparse_elements(int64 value) { + max_sparse_elements_ = value; + return *this; + } + int64 element_size_in_bits() const { return element_size_in_bits_; } + Layout& set_element_size_in_bits(int64 value) { + element_size_in_bits_ = value; + return *this; + } + + void Swap(Layout* other) { + using std::swap; + swap(*this, *other); + } + + void Clear() { + format_ = INVALID_FORMAT; + minor_to_major_.clear(); + max_sparse_elements_ = 0; + element_size_in_bits_ = 0; + } + + public: + // The format of this layout. + Format format_ = INVALID_FORMAT; + + // Sequence of dimension numbers, from minor (fastest varying index) to major + // (slowest varying index). + std::vector minor_to_major_; + + // The maximum number of elements that can be stored for SPARSE formats. This + // can be used to determine the maximum size in bytes of arrays stored in + // memory. This field must be zero unless the format is SPARSE. + int64 max_sparse_elements_ = 0; + + // The number of bits used to store an individual array element. + int64 element_size_in_bits_ = 0; + + // The tiles used in tiling-based layout. + std::vector tiles_; +}; + +std::ostream& operator<<(std::ostream& out, const Tile& Tile); +std::ostream& operator<<(std::ostream& out, const Layout& layout); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LAYOUT_H_ diff --git a/tensorflow/compiler/xla/layout_test.cc b/tensorflow/compiler/xla/layout_test.cc new file mode 100644 index 00000000000..fb6abd3f652 --- /dev/null +++ b/tensorflow/compiler/xla/layout_test.cc @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/layout.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class LayoutTest : public ::testing::Test {}; + +TEST_F(LayoutTest, ToString) { + EXPECT_EQ(Layout().ToString(), "invalid{}"); + EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}"); + EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(123).ToString(), + "sparse{123}"); + EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}"); + EXPECT_EQ(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})}).ToString(), + "{3,2,1,0}"); + EXPECT_EQ( + Layout({1, 0}, {Tile({2, 55})}).set_element_size_in_bits(42).ToString(), + "{1,0}"); +} + +TEST_F(LayoutTest, StreamOut) { + { + std::ostringstream oss; + oss << Tile({7, 8}); + EXPECT_EQ(oss.str(), "(7,8)"); + } + + { + std::ostringstream oss; + oss << Layout({0, 1, 2}); + EXPECT_EQ(oss.str(), "{0,1,2}"); + } +} + +TEST_F(LayoutTest, SparseLayoutMaxElements) { + EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)), + 101); +} + +TEST_F(LayoutTest, Equality) { + EXPECT_EQ(Layout(), Layout()); + const std::vector empty_dims; + EXPECT_EQ(Layout(empty_dims), Layout(empty_dims)); + EXPECT_NE(Layout(), Layout(empty_dims)); + EXPECT_EQ(Layout({0, 1, 2, 3}), Layout({0, 1, 2, 3})); + EXPECT_NE(Layout({0, 1, 2, 3}), Layout({0, 1, 2})); + EXPECT_EQ(Layout({0, 1, 2}, {Tile({42, 44})}), + Layout({0, 1, 2}, {Tile({42, 44})})); + EXPECT_NE(Layout({0, 1, 2}, {Tile({42, 44})}), + Layout({0, 1, 2}, {Tile({42, 45})})); + EXPECT_NE(Layout({0, 1, 2}, {Tile({42, 44})}), Layout({0, 1, 2, 3})); + EXPECT_EQ(Layout({0, 1, 2}).set_element_size_in_bits(33), + Layout({0, 1, 2}).set_element_size_in_bits(33)); + EXPECT_NE(Layout({0, 1, 2}).set_element_size_in_bits(33), + Layout({0, 1, 2}).set_element_size_in_bits(7)); + EXPECT_EQ(Layout().set_format(SPARSE), Layout().set_format(SPARSE)); + EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(42), + Layout().set_format(SPARSE).set_max_sparse_elements(42)); + EXPECT_NE(Layout().set_format(SPARSE).set_max_sparse_elements(42), + Layout().set_format(SPARSE).set_max_sparse_elements(24)); +} + +TEST_F(LayoutTest, LayoutToFromProto) { + // Round-trips a Layout through proto de/serialization. + auto expect_unchanged = [](const Layout& layout) { + EXPECT_EQ(layout, Layout::CreateFromProto(layout.ToProto())); + }; + + expect_unchanged(Layout()); + expect_unchanged(Layout({1, 3, 2, 0})); + expect_unchanged(Layout().set_format(SPARSE)); + expect_unchanged(Layout().set_format(SPARSE).set_max_sparse_elements(123)); + expect_unchanged(Layout({0, 1}).set_element_size_in_bits(42)); + expect_unchanged(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index dbb81381acd..ddccd8c798d 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -41,15 +41,13 @@ namespace { // Internal helper for GetDefaultLayoutForShape and SetToDefaultLayout. Sets // minor_to_major to the value that represents the default layout. -void SetDefaultLayoutToContainer( - tensorflow::protobuf::RepeatedField* - minor_to_major) { +void SetDefaultLayoutToContainer(std::vector* minor_to_major) { // The default XLA layout is major-to-minor (dim 0 is major). // For more information on XLA layouts, see: // https://www.tensorflow.org/performance/xla/shapes const int64 size = minor_to_major->size(); for (int64 i = 0; i < size; ++i) { - minor_to_major->Set(i, size - 1 - i); + (*minor_to_major)[i] = size - 1 - i; } } @@ -94,9 +92,8 @@ namespace { Layout CreateDefaultLayoutForRank(int64 rank) { Layout layout; layout.set_format(DENSE); - tensorflow::protobuf::RepeatedField* - minor_to_major = layout.mutable_minor_to_major(); - minor_to_major->Resize(rank, 0); + std::vector* minor_to_major = layout.mutable_minor_to_major(); + minor_to_major->resize(rank, 0); SetDefaultLayoutToContainer(minor_to_major); return layout; } @@ -139,9 +136,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) { shape->clear_layout(); } else if (ShapeUtil::IsArray(*shape)) { shape->mutable_layout()->set_format(DENSE); - tensorflow::protobuf::RepeatedField* - minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); - minor_to_major->Resize(shape->dimensions_size(), 0); + auto* minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); + minor_to_major->resize(shape->dimensions_size(), 0); SetDefaultLayoutToContainer(minor_to_major); } else { // Opaque, token types etc. have no layout. @@ -210,9 +206,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) { - return InvalidArgument( - "Layout has an invalid format (%d) in layout {%s}, shape {%s}", - layout.format(), layout.ShortDebugString(), shape.ShortDebugString()); + return InvalidArgument("Layout has an invalid format (%d)", + layout.format()); } if (layout.format() == DENSE) { @@ -316,7 +311,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::Equal(const Layout& lhs, const Layout& rhs) { - return protobuf_util::ProtobufEquals(lhs, rhs); + return lhs == rhs; } /* static */ absl::Span LayoutUtil::MinorToMajor( @@ -358,11 +353,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ string LayoutUtil::HumanString(const Layout& layout) { - if (IsSparse(layout)) { - return absl::StrCat("sparse{", layout.max_sparse_elements(), "}"); - } - CHECK(IsDense(layout)); - return absl::StrCat("{", absl::StrJoin(layout.minor_to_major(), ","), "}"); + return layout.ToString(); } namespace { @@ -444,11 +435,6 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { return true; } -std::ostream& operator<<(std::ostream& out, const Layout& layout) { - out << LayoutUtil::HumanString(layout); - return out; -} - /*static*/ size_t LayoutUtil::Hash(const Layout& layout) { using tensorflow::hash; using tensorflow::Hash64Combine; diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 6c298e57252..609dba67bcd 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" @@ -195,8 +196,6 @@ class LayoutUtil { TF_DISALLOW_COPY_AND_ASSIGN(LayoutUtil); }; -std::ostream& operator<<(std::ostream& out, const Layout& layout); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_LAYOUT_UTIL_H_ diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 12ce2d2d7c6..4cc94c270cd 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -317,17 +317,6 @@ TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) { ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25})))); } -TEST_F(LayoutUtilTest, SparseLayoutMaxElements) { - EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)), - 101); -} - -TEST_F(LayoutUtilTest, StreamOut) { - std::ostringstream oss; - oss << LayoutUtil::MakeLayout({0, 1, 2}); - EXPECT_EQ(oss.str(), "{0,1,2}"); -} - TEST_F(LayoutUtilTest, ValidateLayout_ValidArrayLayout) { Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1}); auto status = diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 0f86f9f35e1..339660cf44f 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -42,8 +42,7 @@ PackedLiteralReader::~PackedLiteralReader() { delete file_; } StatusOr PackedLiteralReader::Read(const Shape& shape, const Layout* layout) { VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape) - << " layout: " - << (layout == nullptr ? "" : layout->ShortDebugString()); + << " layout: " << (layout == nullptr ? "" : layout->ToString()); Shape literal_shape = shape; if (layout != nullptr) { TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h index 1fc46bafa10..92e4d6dbbc1 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ +#include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 5ec7fe2aded..ae5bd93e7c5 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -1078,9 +1078,11 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, ProgramShape program_shape(arg->computation().host_program_shape()); TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); + absl::optional output_layout; if (arg->has_output_layout()) { + output_layout = Layout::CreateFromProto(arg->output_layout()); TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( - arg->output_layout(), program_shape.result())); + *output_layout, program_shape.result())); } HloModuleConfig config(program_shape); @@ -1096,8 +1098,8 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, // relayout here. // // TODO(b/77824332): Make HloEvaluator take care of the re-layout. - if (arg->has_output_layout()) { - result_literal = result_literal.Relayout(arg->output_layout()); + if (output_layout.has_value()) { + result_literal = result_literal.Relayout(*output_layout); } *result->mutable_literal() = result_literal.ToProto(); diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc index 746ab9e9977..b206345db2a 100644 --- a/tensorflow/compiler/xla/shape.cc +++ b/tensorflow/compiler/xla/shape.cc @@ -32,7 +32,7 @@ Shape::Shape(const ShapeProto& shape_proto) { *add_tuple_shapes() = Shape(element_shape); } if (shape_proto.has_layout()) { - *mutable_layout() = shape_proto.layout(); + *mutable_layout() = Layout::CreateFromProto(shape_proto.layout()); } } @@ -48,7 +48,7 @@ ShapeProto Shape::ToProto() const { *proto.add_tuple_shapes() = shape.ToProto(); } if (has_layout()) { - *proto.mutable_layout() = layout(); + *proto.mutable_layout() = layout().ToProto(); } return proto; } diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h index 7f6b14ab428..7643f64d8a5 100644 --- a/tensorflow/compiler/xla/shape.h +++ b/tensorflow/compiler/xla/shape.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" @@ -76,21 +77,10 @@ class Shape { std::vector* mutable_tuple_shapes() { return &tuple_shapes_; } // Methods for accessing the layout field. - bool has_layout() const { return layout_.has_value(); } - const Layout& layout() const { - if (layout_.has_value()) { - return *layout_; - } else { - return Layout::default_instance(); - } - } - Layout* mutable_layout() { - if (!layout_.has_value()) { - layout_ = Layout(); - } - return &layout_.value(); - } - void clear_layout() { layout_.reset(); } + bool has_layout() const { return layout_.format() != INVALID_FORMAT; } + const Layout& layout() const { return layout_; } + Layout* mutable_layout() { return &layout_; } + void clear_layout() { layout_.Clear(); } void Swap(Shape* other) { using std::swap; @@ -101,7 +91,7 @@ class Shape { element_type_ = PRIMITIVE_TYPE_INVALID; dimensions_.clear(); tuple_shapes_.clear(); - layout_.reset(); + clear_layout(); } string SerializeAsString() const { return ToProto().SerializeAsString(); } @@ -118,8 +108,8 @@ class Shape { // The tuple element subshapes. This is nonempty only for tuple shapes. std::vector tuple_shapes_; - // The array layout of the shape. This is present only for array shapes. - absl::optional layout_; + // The layout of the shape. Only relevant for arrays. + Layout layout_; }; // Shape of the parameters and output of an XLA computation. This is analogous diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index a4d4e1e53e7..eef2dc913dc 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -164,9 +164,9 @@ StatusOr MakeShapeWithLayoutInternal( TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::MakeValidatedShape(element_type, dimensions)); auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); - min2maj->Clear(); + min2maj->clear(); for (int64 value : minor_to_major) { - min2maj->Add(value); + min2maj->push_back(value); } if (!shape.has_layout()) { return InvalidArgument("Shape has no layout."); @@ -1618,10 +1618,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, if (LayoutUtil::HasLayout(shape)) { Layout* layout = shape.mutable_layout(); layout->set_format(DENSE); - for (size_t i = 0; i < layout->minor_to_major().size();) { + for (int64 i = 0; i < layout->minor_to_major().size();) { if (layout->minor_to_major(i) == dim_to_delete) { layout->mutable_minor_to_major()->erase( - layout->minor_to_major().begin() + i); + layout->mutable_minor_to_major()->begin() + i); continue; } if (layout->minor_to_major(i) > dim_to_delete) { diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 3622f2c1e84..df005a67097 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -133,7 +133,9 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { // Reverse the minor-to-major order of the literal. Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout(); ASSERT_EQ(2, literal_layout->minor_to_major_size()); - literal_layout->mutable_minor_to_major()->SwapElements(0, 1); + // Swap the first and second elements. + *literal_layout->mutable_minor_to_major() = { + literal_layout->minor_to_major(1), literal_layout->minor_to_major(0)}; HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 32b51c104c7..238312e36bb 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -399,7 +399,7 @@ message WaitForExecutionResponse { message ComputeConstantGraphRequest { HloModuleProto computation = 1; - Layout output_layout = 2; + LayoutProto output_layout = 2; } message ComputeConstantResponse { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 85ec83437a1..e9c86abe509 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -100,6 +100,8 @@ message PaddingConfig { // A format specifies the method used by a layout to store an array in memory. enum Format { + // TODO(b/120869032): Rename this to FORMAT_NONE or something else which + // better corresponds to its meaning. INVALID_FORMAT = 0; // The default layout, with exactly one storage location per element. DENSE = 1; @@ -109,8 +111,9 @@ enum Format { } // Describes a tile used in tiling-based layout. Refer to -// g3doc/layout_with_tiling.md for details about tiling-based layout. -message Tile { +// g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for +// details about tiling-based layout. +message TileProto { // Number of elements in each dimension of the tile. It's ordered from the // most major dimension of the tile to the most minor dimension of the tile. // The dimensions correspond to a suffix of the dimensions of the shape being @@ -128,7 +131,7 @@ message Tile { // See the XLA documentation for more information on shapes and layouts. // // LINT.IfChange -message Layout { +message LayoutProto { // The method used to store the data in memory. The format determines which of // the other fields are used by the layout. Format format = 4; @@ -153,7 +156,7 @@ message Layout { // // TODO(b/119839262): implement tiling in each backend or add Unimplemented // error. - repeated Tile tiles = 6; + repeated TileProto tiles = 6; // Bit size of each element. If the size is bigger than what the element // type requires, the value is stored in the least significant @@ -196,7 +199,7 @@ message ShapeProto { repeated ShapeProto tuple_shapes = 4; // The layout used to back this shape. - Layout layout = 5; + LayoutProto layout = 5; // Important: if any field is added, be sure to modify ShapeUtil::Equal(), // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for