Replace Layout and Tile protos with C++ classes in XLA.
No functional change. Rename the proto message Layout to LayoutProto, and Tile to TileProto. Define in-place replacement C++ classes named Layout and Tile with an interface which mirrors the protobuf generated code interface. Having these data structures as C++ classes enables greater flexibility in the interface, enables enforcement of invariants, and potential performance improvements. PiperOrigin-RevId: 225121052
This commit is contained in:
parent
c6245fa0b4
commit
2381067873
@ -224,6 +224,7 @@ cc_library(
|
||||
name = "shape_util",
|
||||
srcs = [
|
||||
"index_util.cc",
|
||||
"layout.cc",
|
||||
"layout_util.cc",
|
||||
"primitive_util.cc",
|
||||
"shape.cc",
|
||||
@ -231,6 +232,7 @@ cc_library(
|
||||
],
|
||||
hdrs = [
|
||||
"index_util.h",
|
||||
"layout.h",
|
||||
"layout_util.h",
|
||||
"primitive_util.h",
|
||||
"shape.h",
|
||||
@ -301,6 +303,22 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "layout_test",
|
||||
srcs = ["layout_test.cc"],
|
||||
deps = [
|
||||
":shape_util",
|
||||
":status_macros",
|
||||
":test",
|
||||
":test_helpers",
|
||||
":types",
|
||||
":util",
|
||||
":xla_data_proto",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "index_util_test",
|
||||
srcs = ["index_util_test.cc"],
|
||||
|
@ -186,7 +186,7 @@ StatusOr<Literal> Client::ComputeConstant(const XlaComputation& computation,
|
||||
ComputeConstantGraphRequest request;
|
||||
*request.mutable_computation() = computation.proto();
|
||||
if (output_layout != nullptr) {
|
||||
*request.mutable_output_layout() = *output_layout;
|
||||
*request.mutable_output_layout() = output_layout->ToProto();
|
||||
}
|
||||
|
||||
ComputeConstantResponse response;
|
||||
|
96
tensorflow/compiler/xla/layout.cc
Normal file
96
tensorflow/compiler/xla/layout.cc
Normal file
@ -0,0 +1,96 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/layout.h"
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
TileProto Tile::ToProto() const {
|
||||
TileProto tile_proto;
|
||||
for (int64 i : dimensions()) {
|
||||
tile_proto.add_dimensions(i);
|
||||
}
|
||||
return tile_proto;
|
||||
}
|
||||
|
||||
string Tile::ToString() const {
|
||||
return absl::StrCat("(", absl::StrJoin(dimensions(), ","), ")");
|
||||
}
|
||||
|
||||
/* static */ Layout Layout::CreateFromProto(const LayoutProto& proto) {
|
||||
Layout layout;
|
||||
layout.set_format(proto.format());
|
||||
layout.minor_to_major_.reserve(proto.minor_to_major_size());
|
||||
for (const int64 dimension : proto.minor_to_major()) {
|
||||
layout.add_minor_to_major(dimension);
|
||||
}
|
||||
layout.set_max_sparse_elements(proto.max_sparse_elements());
|
||||
for (const TileProto& tile_proto : proto.tiles()) {
|
||||
*layout.add_tiles() = Tile::CreateFromProto(tile_proto);
|
||||
}
|
||||
layout.set_element_size_in_bits(proto.element_size_in_bits());
|
||||
return layout;
|
||||
}
|
||||
|
||||
LayoutProto Layout::ToProto() const {
|
||||
LayoutProto proto;
|
||||
proto.set_format(format_);
|
||||
proto.mutable_minor_to_major()->Reserve(minor_to_major_size());
|
||||
for (const int64 dimension : minor_to_major()) {
|
||||
proto.add_minor_to_major(dimension);
|
||||
}
|
||||
proto.set_max_sparse_elements(max_sparse_elements_);
|
||||
for (const Tile& tile : tiles()) {
|
||||
*proto.add_tiles() = tile.ToProto();
|
||||
}
|
||||
proto.set_element_size_in_bits(element_size_in_bits());
|
||||
return proto;
|
||||
}
|
||||
|
||||
string Layout::ToString() const {
|
||||
// TODO(b/119839262): Emit tiles in string.
|
||||
if (format() == SPARSE) {
|
||||
return absl::StrCat("sparse{", max_sparse_elements(), "}");
|
||||
} else if (format() == DENSE) {
|
||||
return absl::StrCat("{", absl::StrJoin(minor_to_major(), ","), "}");
|
||||
} else {
|
||||
CHECK_EQ(format(), INVALID_FORMAT);
|
||||
return "invalid{}";
|
||||
}
|
||||
}
|
||||
|
||||
bool Layout::operator==(const Layout& other) const {
|
||||
return (other.format() == format() &&
|
||||
other.minor_to_major() == minor_to_major() &&
|
||||
other.element_size_in_bits() == element_size_in_bits() &&
|
||||
other.max_sparse_elements() == max_sparse_elements() &&
|
||||
other.tiles() == tiles());
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const Tile& tile) {
|
||||
out << tile.ToString();
|
||||
return out;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const Layout& layout) {
|
||||
out << layout.ToString();
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace xla
|
187
tensorflow/compiler/xla/layout.h
Normal file
187
tensorflow/compiler/xla/layout.h
Normal file
@ -0,0 +1,187 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_LAYOUT_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_LAYOUT_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Describes a tile used in tiling-based layout. Refer to
|
||||
// g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for
|
||||
// details.
|
||||
class Tile {
|
||||
public:
|
||||
Tile() = default;
|
||||
explicit Tile(absl::Span<const int64> dimensions)
|
||||
: dimensions_(dimensions.begin(), dimensions.end()) {}
|
||||
|
||||
// De/Serialize a Tile to and from a TileProto.
|
||||
static Tile CreateFromProto(const TileProto& tile_proto) {
|
||||
return Tile(AsInt64Slice(tile_proto.dimensions()));
|
||||
}
|
||||
TileProto ToProto() const;
|
||||
|
||||
bool operator==(const Tile& other) const {
|
||||
return dimensions() == other.dimensions();
|
||||
}
|
||||
bool operator!=(const Tile& other) const { return !(*this == other); }
|
||||
|
||||
string ToString() const;
|
||||
|
||||
// Returns the bound of the tile in the given dimension index.
|
||||
int64 dimension(int i) const { return dimensions_.at(i); }
|
||||
|
||||
// Returns the dimensions of the tile.
|
||||
const std::vector<int64>& dimensions() const { return dimensions_; }
|
||||
|
||||
private:
|
||||
// The bounds of the tile.
|
||||
std::vector<int64> dimensions_;
|
||||
};
|
||||
|
||||
class Layout {
|
||||
public:
|
||||
Layout() = default;
|
||||
|
||||
// Constructs a dense layout with the given minor-to-major order.
|
||||
explicit Layout(absl::Span<const int64> minor_to_major)
|
||||
: format_(DENSE),
|
||||
minor_to_major_(minor_to_major.begin(), minor_to_major.end()) {}
|
||||
|
||||
// Constructs a dense tiled layout with the given minor-to-major order and
|
||||
// tiles.
|
||||
Layout(absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles)
|
||||
: format_(DENSE),
|
||||
minor_to_major_(minor_to_major.begin(), minor_to_major.end()),
|
||||
tiles_(tiles.begin(), tiles.end()) {}
|
||||
|
||||
// Construct a shape from a LayoutProto.
|
||||
static Layout CreateFromProto(const LayoutProto& proto);
|
||||
|
||||
// Returns a LayoutProto representation of the Layout.
|
||||
LayoutProto ToProto() const;
|
||||
|
||||
// Returns a human-readable string that represents this layout.
|
||||
string ToString() const;
|
||||
|
||||
bool operator==(const Layout& other) const;
|
||||
bool operator!=(const Layout& other) const { return !(*this == other); }
|
||||
|
||||
// The following methods mirror the protobuf generated code interface for the
|
||||
// message LayoutProto. This enabled easy migration of this data structure
|
||||
// from a proto to a proper C++ class.
|
||||
//
|
||||
// TODO(b/29771030): Replace or augment these methods with a more ergonomic
|
||||
// interface.
|
||||
|
||||
// Methods for accessing the format.
|
||||
Format format() const { return format_; }
|
||||
Layout& set_format(Format value) {
|
||||
format_ = value;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Methods for accessing the minor-to-major array.
|
||||
int minor_to_major_size() const { return minor_to_major_.size(); }
|
||||
int64 minor_to_major(int index) const { return minor_to_major_.at(index); }
|
||||
Layout& set_minor_to_major(int index, int64 value) {
|
||||
minor_to_major_.at(index) = value;
|
||||
return *this;
|
||||
}
|
||||
Layout& add_minor_to_major(int64 value) {
|
||||
minor_to_major_.push_back(value);
|
||||
return *this;
|
||||
}
|
||||
Layout& clear_minor_to_major() {
|
||||
minor_to_major_.clear();
|
||||
return *this;
|
||||
}
|
||||
const std::vector<int64>& minor_to_major() const { return minor_to_major_; }
|
||||
std::vector<int64>* mutable_minor_to_major() { return &minor_to_major_; }
|
||||
|
||||
// Methods for accessing the tile field.
|
||||
int tiles_size() const { return tiles_.size(); }
|
||||
const Tile& tiles(int index) const { return tiles_.at(index); }
|
||||
Tile* mutable_tiles(int index) { return &tiles_.at(index); }
|
||||
Tile* add_tiles() {
|
||||
tiles_.push_back(Tile());
|
||||
return &tiles_.back();
|
||||
}
|
||||
Layout& clear_tiles() {
|
||||
tiles_.clear();
|
||||
return *this;
|
||||
}
|
||||
const std::vector<Tile>& tiles() const { return tiles_; }
|
||||
std::vector<Tile>* mutable_tiles() { return &tiles_; }
|
||||
|
||||
// Methods for accessing the int64 fields.
|
||||
int64 max_sparse_elements() const { return max_sparse_elements_; }
|
||||
Layout& set_max_sparse_elements(int64 value) {
|
||||
max_sparse_elements_ = value;
|
||||
return *this;
|
||||
}
|
||||
int64 element_size_in_bits() const { return element_size_in_bits_; }
|
||||
Layout& set_element_size_in_bits(int64 value) {
|
||||
element_size_in_bits_ = value;
|
||||
return *this;
|
||||
}
|
||||
|
||||
void Swap(Layout* other) {
|
||||
using std::swap;
|
||||
swap(*this, *other);
|
||||
}
|
||||
|
||||
void Clear() {
|
||||
format_ = INVALID_FORMAT;
|
||||
minor_to_major_.clear();
|
||||
max_sparse_elements_ = 0;
|
||||
element_size_in_bits_ = 0;
|
||||
}
|
||||
|
||||
public:
|
||||
// The format of this layout.
|
||||
Format format_ = INVALID_FORMAT;
|
||||
|
||||
// Sequence of dimension numbers, from minor (fastest varying index) to major
|
||||
// (slowest varying index).
|
||||
std::vector<int64> minor_to_major_;
|
||||
|
||||
// The maximum number of elements that can be stored for SPARSE formats. This
|
||||
// can be used to determine the maximum size in bytes of arrays stored in
|
||||
// memory. This field must be zero unless the format is SPARSE.
|
||||
int64 max_sparse_elements_ = 0;
|
||||
|
||||
// The number of bits used to store an individual array element.
|
||||
int64 element_size_in_bits_ = 0;
|
||||
|
||||
// The tiles used in tiling-based layout.
|
||||
std::vector<Tile> tiles_;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const Tile& Tile);
|
||||
std::ostream& operator<<(std::ostream& out, const Layout& layout);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_LAYOUT_H_
|
104
tensorflow/compiler/xla/layout_test.cc
Normal file
104
tensorflow/compiler/xla/layout_test.cc
Normal file
@ -0,0 +1,104 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/layout.h"
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class LayoutTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(LayoutTest, ToString) {
|
||||
EXPECT_EQ(Layout().ToString(), "invalid{}");
|
||||
EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}");
|
||||
EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(123).ToString(),
|
||||
"sparse{123}");
|
||||
EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}");
|
||||
EXPECT_EQ(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})}).ToString(),
|
||||
"{3,2,1,0}");
|
||||
EXPECT_EQ(
|
||||
Layout({1, 0}, {Tile({2, 55})}).set_element_size_in_bits(42).ToString(),
|
||||
"{1,0}");
|
||||
}
|
||||
|
||||
TEST_F(LayoutTest, StreamOut) {
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << Tile({7, 8});
|
||||
EXPECT_EQ(oss.str(), "(7,8)");
|
||||
}
|
||||
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << Layout({0, 1, 2});
|
||||
EXPECT_EQ(oss.str(), "{0,1,2}");
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(LayoutTest, SparseLayoutMaxElements) {
|
||||
EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)),
|
||||
101);
|
||||
}
|
||||
|
||||
TEST_F(LayoutTest, Equality) {
|
||||
EXPECT_EQ(Layout(), Layout());
|
||||
const std::vector<int64> empty_dims;
|
||||
EXPECT_EQ(Layout(empty_dims), Layout(empty_dims));
|
||||
EXPECT_NE(Layout(), Layout(empty_dims));
|
||||
EXPECT_EQ(Layout({0, 1, 2, 3}), Layout({0, 1, 2, 3}));
|
||||
EXPECT_NE(Layout({0, 1, 2, 3}), Layout({0, 1, 2}));
|
||||
EXPECT_EQ(Layout({0, 1, 2}, {Tile({42, 44})}),
|
||||
Layout({0, 1, 2}, {Tile({42, 44})}));
|
||||
EXPECT_NE(Layout({0, 1, 2}, {Tile({42, 44})}),
|
||||
Layout({0, 1, 2}, {Tile({42, 45})}));
|
||||
EXPECT_NE(Layout({0, 1, 2}, {Tile({42, 44})}), Layout({0, 1, 2, 3}));
|
||||
EXPECT_EQ(Layout({0, 1, 2}).set_element_size_in_bits(33),
|
||||
Layout({0, 1, 2}).set_element_size_in_bits(33));
|
||||
EXPECT_NE(Layout({0, 1, 2}).set_element_size_in_bits(33),
|
||||
Layout({0, 1, 2}).set_element_size_in_bits(7));
|
||||
EXPECT_EQ(Layout().set_format(SPARSE), Layout().set_format(SPARSE));
|
||||
EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(42),
|
||||
Layout().set_format(SPARSE).set_max_sparse_elements(42));
|
||||
EXPECT_NE(Layout().set_format(SPARSE).set_max_sparse_elements(42),
|
||||
Layout().set_format(SPARSE).set_max_sparse_elements(24));
|
||||
}
|
||||
|
||||
TEST_F(LayoutTest, LayoutToFromProto) {
|
||||
// Round-trips a Layout through proto de/serialization.
|
||||
auto expect_unchanged = [](const Layout& layout) {
|
||||
EXPECT_EQ(layout, Layout::CreateFromProto(layout.ToProto()));
|
||||
};
|
||||
|
||||
expect_unchanged(Layout());
|
||||
expect_unchanged(Layout({1, 3, 2, 0}));
|
||||
expect_unchanged(Layout().set_format(SPARSE));
|
||||
expect_unchanged(Layout().set_format(SPARSE).set_max_sparse_elements(123));
|
||||
expect_unchanged(Layout({0, 1}).set_element_size_in_bits(42));
|
||||
expect_unchanged(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
@ -41,15 +41,13 @@ namespace {
|
||||
|
||||
// Internal helper for GetDefaultLayoutForShape and SetToDefaultLayout. Sets
|
||||
// minor_to_major to the value that represents the default layout.
|
||||
void SetDefaultLayoutToContainer(
|
||||
tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>*
|
||||
minor_to_major) {
|
||||
void SetDefaultLayoutToContainer(std::vector<int64>* minor_to_major) {
|
||||
// The default XLA layout is major-to-minor (dim 0 is major).
|
||||
// For more information on XLA layouts, see:
|
||||
// https://www.tensorflow.org/performance/xla/shapes
|
||||
const int64 size = minor_to_major->size();
|
||||
for (int64 i = 0; i < size; ++i) {
|
||||
minor_to_major->Set(i, size - 1 - i);
|
||||
(*minor_to_major)[i] = size - 1 - i;
|
||||
}
|
||||
}
|
||||
|
||||
@ -94,9 +92,8 @@ namespace {
|
||||
Layout CreateDefaultLayoutForRank(int64 rank) {
|
||||
Layout layout;
|
||||
layout.set_format(DENSE);
|
||||
tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>*
|
||||
minor_to_major = layout.mutable_minor_to_major();
|
||||
minor_to_major->Resize(rank, 0);
|
||||
std::vector<int64>* minor_to_major = layout.mutable_minor_to_major();
|
||||
minor_to_major->resize(rank, 0);
|
||||
SetDefaultLayoutToContainer(minor_to_major);
|
||||
return layout;
|
||||
}
|
||||
@ -139,9 +136,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
|
||||
shape->clear_layout();
|
||||
} else if (ShapeUtil::IsArray(*shape)) {
|
||||
shape->mutable_layout()->set_format(DENSE);
|
||||
tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>*
|
||||
minor_to_major = shape->mutable_layout()->mutable_minor_to_major();
|
||||
minor_to_major->Resize(shape->dimensions_size(), 0);
|
||||
auto* minor_to_major = shape->mutable_layout()->mutable_minor_to_major();
|
||||
minor_to_major->resize(shape->dimensions_size(), 0);
|
||||
SetDefaultLayoutToContainer(minor_to_major);
|
||||
} else {
|
||||
// Opaque, token types etc. have no layout.
|
||||
@ -210,9 +206,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
|
||||
}
|
||||
|
||||
if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) {
|
||||
return InvalidArgument(
|
||||
"Layout has an invalid format (%d) in layout {%s}, shape {%s}",
|
||||
layout.format(), layout.ShortDebugString(), shape.ShortDebugString());
|
||||
return InvalidArgument("Layout has an invalid format (%d)",
|
||||
layout.format());
|
||||
}
|
||||
|
||||
if (layout.format() == DENSE) {
|
||||
@ -316,7 +311,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
|
||||
}
|
||||
|
||||
/* static */ bool LayoutUtil::Equal(const Layout& lhs, const Layout& rhs) {
|
||||
return protobuf_util::ProtobufEquals(lhs, rhs);
|
||||
return lhs == rhs;
|
||||
}
|
||||
|
||||
/* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
|
||||
@ -358,11 +353,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
|
||||
}
|
||||
|
||||
/* static */ string LayoutUtil::HumanString(const Layout& layout) {
|
||||
if (IsSparse(layout)) {
|
||||
return absl::StrCat("sparse{", layout.max_sparse_elements(), "}");
|
||||
}
|
||||
CHECK(IsDense(layout));
|
||||
return absl::StrCat("{", absl::StrJoin(layout.minor_to_major(), ","), "}");
|
||||
return layout.ToString();
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -444,11 +435,6 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const Layout& layout) {
|
||||
out << LayoutUtil::HumanString(layout);
|
||||
return out;
|
||||
}
|
||||
|
||||
/*static*/ size_t LayoutUtil::Hash(const Layout& layout) {
|
||||
using tensorflow::hash;
|
||||
using tensorflow::Hash64Combine;
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/layout.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
@ -195,8 +196,6 @@ class LayoutUtil {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(LayoutUtil);
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const Layout& layout);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_LAYOUT_UTIL_H_
|
||||
|
@ -317,17 +317,6 @@ TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) {
|
||||
ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25}))));
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, SparseLayoutMaxElements) {
|
||||
EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)),
|
||||
101);
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, StreamOut) {
|
||||
std::ostringstream oss;
|
||||
oss << LayoutUtil::MakeLayout({0, 1, 2});
|
||||
EXPECT_EQ(oss.str(), "{0,1,2}");
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, ValidateLayout_ValidArrayLayout) {
|
||||
Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1});
|
||||
auto status =
|
||||
|
@ -42,8 +42,7 @@ PackedLiteralReader::~PackedLiteralReader() { delete file_; }
|
||||
StatusOr<Literal> PackedLiteralReader::Read(const Shape& shape,
|
||||
const Layout* layout) {
|
||||
VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape)
|
||||
<< " layout: "
|
||||
<< (layout == nullptr ? "<none>" : layout->ShortDebugString());
|
||||
<< " layout: " << (layout == nullptr ? "<none>" : layout->ToString());
|
||||
Shape literal_shape = shape;
|
||||
if (layout != nullptr) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/layout.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
@ -1078,9 +1078,11 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
|
||||
|
||||
ProgramShape program_shape(arg->computation().host_program_shape());
|
||||
TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
|
||||
absl::optional<Layout> output_layout;
|
||||
if (arg->has_output_layout()) {
|
||||
output_layout = Layout::CreateFromProto(arg->output_layout());
|
||||
TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
|
||||
arg->output_layout(), program_shape.result()));
|
||||
*output_layout, program_shape.result()));
|
||||
}
|
||||
|
||||
HloModuleConfig config(program_shape);
|
||||
@ -1096,8 +1098,8 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
|
||||
// relayout here.
|
||||
//
|
||||
// TODO(b/77824332): Make HloEvaluator take care of the re-layout.
|
||||
if (arg->has_output_layout()) {
|
||||
result_literal = result_literal.Relayout(arg->output_layout());
|
||||
if (output_layout.has_value()) {
|
||||
result_literal = result_literal.Relayout(*output_layout);
|
||||
}
|
||||
*result->mutable_literal() = result_literal.ToProto();
|
||||
|
||||
|
@ -32,7 +32,7 @@ Shape::Shape(const ShapeProto& shape_proto) {
|
||||
*add_tuple_shapes() = Shape(element_shape);
|
||||
}
|
||||
if (shape_proto.has_layout()) {
|
||||
*mutable_layout() = shape_proto.layout();
|
||||
*mutable_layout() = Layout::CreateFromProto(shape_proto.layout());
|
||||
}
|
||||
}
|
||||
|
||||
@ -48,7 +48,7 @@ ShapeProto Shape::ToProto() const {
|
||||
*proto.add_tuple_shapes() = shape.ToProto();
|
||||
}
|
||||
if (has_layout()) {
|
||||
*proto.mutable_layout() = layout();
|
||||
*proto.mutable_layout() = layout().ToProto();
|
||||
}
|
||||
return proto;
|
||||
}
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/layout.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -76,21 +77,10 @@ class Shape {
|
||||
std::vector<Shape>* mutable_tuple_shapes() { return &tuple_shapes_; }
|
||||
|
||||
// Methods for accessing the layout field.
|
||||
bool has_layout() const { return layout_.has_value(); }
|
||||
const Layout& layout() const {
|
||||
if (layout_.has_value()) {
|
||||
return *layout_;
|
||||
} else {
|
||||
return Layout::default_instance();
|
||||
}
|
||||
}
|
||||
Layout* mutable_layout() {
|
||||
if (!layout_.has_value()) {
|
||||
layout_ = Layout();
|
||||
}
|
||||
return &layout_.value();
|
||||
}
|
||||
void clear_layout() { layout_.reset(); }
|
||||
bool has_layout() const { return layout_.format() != INVALID_FORMAT; }
|
||||
const Layout& layout() const { return layout_; }
|
||||
Layout* mutable_layout() { return &layout_; }
|
||||
void clear_layout() { layout_.Clear(); }
|
||||
|
||||
void Swap(Shape* other) {
|
||||
using std::swap;
|
||||
@ -101,7 +91,7 @@ class Shape {
|
||||
element_type_ = PRIMITIVE_TYPE_INVALID;
|
||||
dimensions_.clear();
|
||||
tuple_shapes_.clear();
|
||||
layout_.reset();
|
||||
clear_layout();
|
||||
}
|
||||
|
||||
string SerializeAsString() const { return ToProto().SerializeAsString(); }
|
||||
@ -118,8 +108,8 @@ class Shape {
|
||||
// The tuple element subshapes. This is nonempty only for tuple shapes.
|
||||
std::vector<Shape> tuple_shapes_;
|
||||
|
||||
// The array layout of the shape. This is present only for array shapes.
|
||||
absl::optional<Layout> layout_;
|
||||
// The layout of the shape. Only relevant for arrays.
|
||||
Layout layout_;
|
||||
};
|
||||
|
||||
// Shape of the parameters and output of an XLA computation. This is analogous
|
||||
|
@ -164,9 +164,9 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
||||
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||
ShapeUtil::MakeValidatedShape(element_type, dimensions));
|
||||
auto min2maj = shape.mutable_layout()->mutable_minor_to_major();
|
||||
min2maj->Clear();
|
||||
min2maj->clear();
|
||||
for (int64 value : minor_to_major) {
|
||||
min2maj->Add(value);
|
||||
min2maj->push_back(value);
|
||||
}
|
||||
if (!shape.has_layout()) {
|
||||
return InvalidArgument("Shape has no layout.");
|
||||
@ -1618,10 +1618,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
|
||||
if (LayoutUtil::HasLayout(shape)) {
|
||||
Layout* layout = shape.mutable_layout();
|
||||
layout->set_format(DENSE);
|
||||
for (size_t i = 0; i < layout->minor_to_major().size();) {
|
||||
for (int64 i = 0; i < layout->minor_to_major().size();) {
|
||||
if (layout->minor_to_major(i) == dim_to_delete) {
|
||||
layout->mutable_minor_to_major()->erase(
|
||||
layout->minor_to_major().begin() + i);
|
||||
layout->mutable_minor_to_major()->begin() + i);
|
||||
continue;
|
||||
}
|
||||
if (layout->minor_to_major(i) > dim_to_delete) {
|
||||
|
@ -133,7 +133,9 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
|
||||
// Reverse the minor-to-major order of the literal.
|
||||
Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout();
|
||||
ASSERT_EQ(2, literal_layout->minor_to_major_size());
|
||||
literal_layout->mutable_minor_to_major()->SwapElements(0, 1);
|
||||
// Swap the first and second elements.
|
||||
*literal_layout->mutable_minor_to_major() = {
|
||||
literal_layout->minor_to_major(1), literal_layout->minor_to_major(0)};
|
||||
|
||||
HloInstruction* constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(std::move(literal)));
|
||||
|
@ -399,7 +399,7 @@ message WaitForExecutionResponse {
|
||||
|
||||
message ComputeConstantGraphRequest {
|
||||
HloModuleProto computation = 1;
|
||||
Layout output_layout = 2;
|
||||
LayoutProto output_layout = 2;
|
||||
}
|
||||
|
||||
message ComputeConstantResponse {
|
||||
|
@ -100,6 +100,8 @@ message PaddingConfig {
|
||||
|
||||
// A format specifies the method used by a layout to store an array in memory.
|
||||
enum Format {
|
||||
// TODO(b/120869032): Rename this to FORMAT_NONE or something else which
|
||||
// better corresponds to its meaning.
|
||||
INVALID_FORMAT = 0;
|
||||
// The default layout, with exactly one storage location per element.
|
||||
DENSE = 1;
|
||||
@ -109,8 +111,9 @@ enum Format {
|
||||
}
|
||||
|
||||
// Describes a tile used in tiling-based layout. Refer to
|
||||
// g3doc/layout_with_tiling.md for details about tiling-based layout.
|
||||
message Tile {
|
||||
// g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for
|
||||
// details about tiling-based layout.
|
||||
message TileProto {
|
||||
// Number of elements in each dimension of the tile. It's ordered from the
|
||||
// most major dimension of the tile to the most minor dimension of the tile.
|
||||
// The dimensions correspond to a suffix of the dimensions of the shape being
|
||||
@ -128,7 +131,7 @@ message Tile {
|
||||
// See the XLA documentation for more information on shapes and layouts.
|
||||
//
|
||||
// LINT.IfChange
|
||||
message Layout {
|
||||
message LayoutProto {
|
||||
// The method used to store the data in memory. The format determines which of
|
||||
// the other fields are used by the layout.
|
||||
Format format = 4;
|
||||
@ -153,7 +156,7 @@ message Layout {
|
||||
//
|
||||
// TODO(b/119839262): implement tiling in each backend or add Unimplemented
|
||||
// error.
|
||||
repeated Tile tiles = 6;
|
||||
repeated TileProto tiles = 6;
|
||||
|
||||
// Bit size of each element. If the size is bigger than what the element
|
||||
// type requires, the value is stored in the least significant
|
||||
@ -196,7 +199,7 @@ message ShapeProto {
|
||||
repeated ShapeProto tuple_shapes = 4;
|
||||
|
||||
// The layout used to back this shape.
|
||||
Layout layout = 5;
|
||||
LayoutProto layout = 5;
|
||||
|
||||
// Important: if any field is added, be sure to modify ShapeUtil::Equal(),
|
||||
// ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for
|
||||
|
Loading…
x
Reference in New Issue
Block a user