diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 5a5e5fe0d71..7ccfc45a5ff 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -542,7 +542,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, // WARNING: kernel->Run utilizes the FunctionLibraryRuntime // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def, // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation - // of FunctionLibraryRuntime tells use that func_lib_def is not accessed by + // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here. // This is quite subtle. Re-work things to make this better? (Would it make // sense for FunctionLibraryRuntime to ensure thread-safe access to diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 88de17a5ff0..dcbe1fe9e5f 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -302,6 +302,7 @@ cc_library( ":array4d", ":shape_tree", ":shape_util", + ":sparse_index_array", ":status_macros", ":types", ":util", @@ -628,6 +629,28 @@ tf_cc_test( ], ) +cc_library( + name = "sparse_index_array", + srcs = ["sparse_index_array.cc"], + hdrs = ["sparse_index_array.h"], + deps = [ + ":array2d", + ":shape_util", + ":xla_data_proto", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "sparse_index_array_test", + srcs = ["sparse_index_array_test.cc"], + deps = [ + ":sparse_index_array", + ":test", + "//tensorflow/core:test_main", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 2ee23927d86..ffd1fb79e98 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -149,4 +149,33 @@ namespace xla { return stride; } +/* static */ bool IndexUtil::IndexInBounds( + const Shape& shape, tensorflow::gtl::ArraySlice index) { + int64 rank = ShapeUtil::Rank(shape); + if (rank != index.size()) { + return false; + } + for (int64 d = 0; d < rank; ++d) { + if (index[d] >= shape.dimensions(d)) { + return false; + } + } + return true; +} + +/* static */ int IndexUtil::CompareIndices( + tensorflow::gtl::ArraySlice lhs, + tensorflow::gtl::ArraySlice rhs) { + int64 rank = lhs.size(); + CHECK_EQ(rhs.size(), rank); + for (int64 dim = 0; dim < rank; ++dim) { + if (lhs[dim] < rhs[dim]) { + return -1; + } else if (lhs[dim] > rhs[dim]) { + return 1; + } + } + return 0; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h index c9838966a5b..0b9188e8524 100644 --- a/tensorflow/compiler/xla/index_util.h +++ b/tensorflow/compiler/xla/index_util.h @@ -69,6 +69,18 @@ class IndexUtil { // sizeof(dimension(3)) * sizeof(dimension(2)) == 4 * 10 static int64 GetDimensionStride(const Shape& shape, int64 dimension); + // Returns true iff the given multi-index is contained in the bounds for the + // shape. + static bool IndexInBounds(const Shape& shape, + tensorflow::gtl::ArraySlice index); + + // Compares the given indices in lexicographic order. lhs[0] and rhs[0] are + // compared first, and lhs[rank-1] and rhs[rank-1] last. If lhs is larger, + // then -1 is returned. If rhs is larger, then 1 is returned. Otherwise, 0 is + // returned. 
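// Illustrative usage sketch (assumed usage; `ValidateAndSortIndices` is a
// hypothetical helper, not added by this change). It shows how the two new
// IndexUtil helpers compose when validating and ordering sparse multi-indices.
// Per the implementation above, CompareIndices() returns a strcmp-style value:
// negative when `lhs` precedes `rhs` lexicographically.
#include <algorithm>
#include <vector>

#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

void ValidateAndSortIndices(const Shape& shape,
                            std::vector<std::vector<int64>>* indices) {
  for (const std::vector<int64>& index : *indices) {
    // Every multi-index must have the shape's rank and stay within its bounds.
    CHECK(IndexUtil::IndexInBounds(shape, index))
        << "index out of bounds for " << ShapeUtil::HumanString(shape);
  }
  // Sort into the lexicographic order that sparse literals assume.
  std::sort(indices->begin(), indices->end(),
            [](const std::vector<int64>& lhs, const std::vector<int64>& rhs) {
              return IndexUtil::CompareIndices(lhs, rhs) < 0;
            });
}

}  // namespace xla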
+ static int CompareIndices(tensorflow::gtl::ArraySlice lhs, + tensorflow::gtl::ArraySlice rhs); + private: TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil); }; diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 6435226fbe6..ddf091e19ff 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -64,6 +64,13 @@ void SetDefaultLayoutToContainer( return layout; } +/* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) { + Layout layout; + layout.set_format(SPARSE); + layout.set_max_sparse_elements(max_sparse_elements); + return layout; +} + namespace { // Internal helper that creates a default layout for an array of the given rank. @@ -234,7 +241,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { LayoutUtil::ClearLayout(program_shape->mutable_result()); } -/* static */ bool LayoutUtil::IsDense(const Shape& shape) { +/* static */ bool LayoutUtil::IsDenseArray(const Shape& shape) { return ShapeUtil::IsArray(shape) && shape.has_layout() && IsDense(shape.layout()); } @@ -260,7 +267,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { shape.layout().padded_dimensions_size() == 0) { return false; } - CHECK(IsDense(shape)); + CHECK(IsDenseArray(shape)); CHECK_EQ(shape.dimensions_size(), shape.layout().padded_dimensions_size()); for (int64 i = 0; i < shape.dimensions_size(); ++i) { if (shape.layout().padded_dimensions(i) > shape.dimensions(i)) { @@ -272,21 +279,35 @@ Layout CreateDefaultLayoutForRank(int64 rank) { /* static */ tensorflow::gtl::ArraySlice LayoutUtil::PaddedDimensions( const Shape& shape) { - CHECK(IsDense(shape)); + CHECK(IsDenseArray(shape)); return AsInt64Slice(shape.layout().padded_dimensions()); } /* static */ int64 LayoutUtil::PaddedDimension(const Shape& shape, int64 index) { - CHECK(IsDense(shape)); + CHECK(IsDenseArray(shape)); return shape.layout().padded_dimensions(index); } /* static */ PaddingValue LayoutUtil::GetPaddingValue(const Shape& shape) { - CHECK(IsDense(shape)); + CHECK(IsDenseArray(shape)); return shape.layout().padding_value(); } +/* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) { + return ShapeUtil::IsArray(shape) && shape.has_layout() && + IsSparse(shape.layout()); +} + +/* static */ bool LayoutUtil::IsSparse(const Layout& layout) { + return layout.format() == SPARSE; +} + +/* static */ int64 LayoutUtil::MaxSparseElements(const Layout& layout) { + CHECK(IsSparse(layout)); + return layout.max_sparse_elements(); +} + /* static */ bool LayoutUtil::HasLayout(const Shape& shape) { if (ShapeUtil::IsTuple(shape)) { // Tuple shape: all subshapes must have a layout. @@ -313,7 +334,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { /* static */ tensorflow::gtl::ArraySlice LayoutUtil::MinorToMajor( const Shape& shape) { - CHECK(IsDense(shape)); + CHECK(IsDenseArray(shape)); return AsInt64Slice(shape.layout().minor_to_major()); } @@ -419,6 +440,7 @@ tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, /* static */ bool LayoutUtil::AreDimensionsConsecutive( const Layout& layout, tensorflow::gtl::ArraySlice dims) { + CHECK(IsDense(layout)); std::vector positions_in_layout; for (int64 dim : dims) { positions_in_layout.push_back( diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index f73cc957649..7c1ba4b022e 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -36,6 +36,10 @@ class LayoutUtil { // convenience function for protobuf construction.) 
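// Illustrative sketch (assumed usage, not code added by this change): building
// a shape with the new SPARSE layout by hand. This mirrors the
// MakeShapeWithSparseLayout test helper added to layout_util_test.cc below; a
// sparse layout carries only a maximum element count, not a dimension order.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

Shape MakeExampleSparseShape() {
  Shape shape = ShapeUtil::MakeShape(F32, {10, 10});
  *shape.mutable_layout() =
      LayoutUtil::MakeSparseLayout(/*max_sparse_elements=*/16);
  CHECK(LayoutUtil::IsSparseArray(shape));
  CHECK(!LayoutUtil::IsDenseArray(shape));
  CHECK_EQ(LayoutUtil::MaxSparseElements(shape.layout()), 16);
  return shape;
}

}  // namespace xla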
static Layout MakeLayout(tensorflow::gtl::ArraySlice minor_to_major); + // Creates a sparse layout with the given maximum number of elements. (This is + // a convenience function for protobuf construction.) + static Layout MakeSparseLayout(int64 max_sparse_elements); + // Returns default layout for the given shape. static Layout GetDefaultLayoutForShape(const Shape& shape); @@ -72,7 +76,7 @@ class LayoutUtil { static void ClearLayout(ProgramShape* program_shape); // Returns whether the given Shape is an array and has a dense format layout. - static bool IsDense(const Shape& shape); + static bool IsDenseArray(const Shape& shape); // Returns whether the given Layout has a dense format. static bool IsDense(const Layout& layout); @@ -107,6 +111,17 @@ class LayoutUtil { // an array and has a dense layout. static PaddingValue GetPaddingValue(const Shape& shape); + // Returns whether the given Shape is an array (i.e. not a tuple) and has a + // sparse format layout. + static bool IsSparseArray(const Shape& shape); + + // Returns whether the given Layout has a sparse format. + static bool IsSparse(const Layout& layout); + + // Returns the maximum number of elements that can be stored in a sparse + // layout. + static int64 MaxSparseElements(const Layout& layout); + // Returns whether the given shape has a layout. For tuple shapes, true is // returned only if all elements have layouts. static bool HasLayout(const Shape& shape); diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 331bb9afa94..daf4dc10ac7 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -30,6 +30,14 @@ class LayoutUtilTest : public ::testing::Test { *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); return shape; } + + Shape MakeShapeWithSparseLayout(PrimitiveType element_type, + tensorflow::gtl::ArraySlice dimensions, + int64 max_sparse_elements) { + Shape shape = ShapeUtil::MakeShape(element_type, dimensions); + *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); + return shape; + } }; TEST_F(LayoutUtilTest, TupleLayoutComparison) { @@ -81,6 +89,29 @@ TEST_F(LayoutUtilTest, CopyLayoutArray) { EXPECT_FALSE(dst.has_layout()); } +TEST_F(LayoutUtilTest, CopyLayoutSparse) { + Shape src = MakeShapeWithSparseLayout(F32, {2, 3}, 2); + Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0}); + + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + + // Should work if destination has no layout. + dst.clear_layout(); + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + + // If source is cleared, then destination should be cleared. 
+ src.clear_layout(); + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_TRUE(dst.has_layout()); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_FALSE(dst.has_layout()); +} + TEST_F(LayoutUtilTest, CopyLayoutTuple) { Shape src = ShapeUtil::MakeTupleShape( {MakeShapeWithLayout(F32, {2, 3}, {0, 1}), @@ -100,6 +131,25 @@ TEST_F(LayoutUtilTest, CopyLayoutTuple) { EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); } +TEST_F(LayoutUtilTest, CopyLayoutTupleSparse) { + Shape src = ShapeUtil::MakeTupleShape( + {MakeShapeWithSparseLayout(F32, {2, 3}, 4), + MakeShapeWithSparseLayout(F32, {42, 123}, 4), + ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithSparseLayout(F32, {1, 2, 3}, 6)})}); + Shape dst = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), + ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})}); + + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) { Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1}); Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0}); @@ -107,6 +157,13 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) { EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); } +TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleSameRank) { + Shape src = MakeShapeWithSparseLayout(F32, {123, 42, 7}, 6); + Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0}); + ASSERT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) { Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1}); Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0}); @@ -116,6 +173,15 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) { ::testing::ContainsRegex("cannot copy layout from shape")); } +TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleDifferentRank) { + Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1}); + Shape dst = MakeShapeWithSparseLayout(F32, {2, 3}, 4); + auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::ContainsRegex("cannot copy layout from shape")); +} + TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) { Shape src = ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 3}, {0, 1}), @@ -221,5 +287,10 @@ TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) { ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25})))); } +TEST_F(LayoutUtilTest, SparseLayoutMaxElements) { + EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)), + 101); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index cc1735e6f2c..dff5c1381ab 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -94,9 +94,15 @@ Literal::Literal(const Shape& shape, bool allocate_arrays) Piece& piece = pair.second; piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - if (ShapeUtil::IsArray(piece.subshape())) { + const Shape& subshape = piece.subshape(); + if (ShapeUtil::IsArray(subshape)) { if (allocate_arrays) { piece.set_buffer(new char[piece.size_bytes()]); + if (LayoutUtil::IsSparseArray(subshape)) { + piece.set_sparse_indices(new SparseIndexArray( + LayoutUtil::MaxSparseElements(subshape.layout()), + ShapeUtil::Rank(subshape))); + } } else { piece.set_buffer(nullptr); } @@ -112,6 +118,7 @@ void Literal::DeallocateBuffers() { Piece& piece = pair.second; if (piece.buffer() != nullptr) { delete[] piece.buffer(); + delete piece.sparse_indices(); } } } @@ -164,6 +171,15 @@ std::unique_ptr Literal::CreateFromShape(const Shape& shape) { return literal; } +const SparseIndexArray* Literal::sparse_indices( + const ShapeIndex& shape_index) const { + return piece(shape_index).sparse_indices(); +} + +SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { + return piece(shape_index).sparse_indices(); +} + /* static */ std::unique_ptr Literal::CreateFromDimensions( PrimitiveType primitive_type, tensorflow::gtl::ArraySlice dimensions) { @@ -247,9 +263,12 @@ std::vector Literal::DecomposeTuple() { } Piece& src_piece = piece(src_index); - // Move the respective buffer over to the element Literal. + // Move the respective buffer and sparse indices over to the element + // Literal. dest_piece.set_buffer(src_piece.buffer()); src_piece.set_buffer(nullptr); + dest_piece.set_sparse_indices(src_piece.sparse_indices()); + src_piece.set_sparse_indices(nullptr); } } // Set this literal to be nil-shaped. @@ -406,6 +425,8 @@ Status Literal::MoveFrom(Literal&& src_literal, Piece& dest_piece = piece(dest_index); delete[] dest_piece.buffer(); dest_piece.set_buffer(src_piece.buffer()); + delete dest_piece.sparse_indices(); + dest_piece.set_sparse_indices(src_piece.sparse_indices()); } src_literal.shape_ = ShapeUtil::MakeNil(); @@ -764,7 +785,7 @@ std::unique_ptr Literal::Transpose( // dimension has within the transposed array, a layout is affine if // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major // vector of the affine layout. - CHECK(LayoutUtil::IsDense(permuted_shape)); + CHECK(LayoutUtil::IsDenseArray(permuted_shape)); Layout* layout = permuted_shape.mutable_layout(); layout->clear_minor_to_major(); for (auto index : LayoutUtil::MinorToMajor(shape())) { @@ -1573,6 +1594,12 @@ LiteralProto Literal::ToProto() const { } piece.WriteToProto(proto_piece); } + + if (LayoutUtil::IsSparseArray(shape())) { + CopyToRepeatedField(proto.mutable_sparse_indices(), + sparse_indices()->data()); + } + return proto; } @@ -1653,6 +1680,7 @@ LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) { } const Piece& src_piece = literal.piece(src_index); piece.set_buffer(src_piece.buffer()); + piece.set_sparse_indices(src_piece.sparse_indices()); piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); } } diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index dc29c6359c6..50e25bbdd0d 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -36,6 +36,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/sparse_index_array.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -103,6 +104,12 @@ class Literal { tensorflow::gtl::MutableArraySlice data( const ShapeIndex& shape_index = {}); + // Returns a pointer to the sparse index array. Returns nullptr if the literal + // is not a sparse array. + const SparseIndexArray* sparse_indices( + const ShapeIndex& shape_index = {}) const; + SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + // Returns a pointer to (or size of) the underlying buffer holding the array // at the given shape index. CHECKs if the subshape of the literal at the // given ShapeIndex is not array. @@ -160,6 +167,56 @@ class Literal { // array. string GetR1U8AsString() const; + // Creates a literal with a sparse layout and the given indices and values. + // The shape is initialized from the given dimensions. The minor dimension of + // the indices array must equal the rank of the shape (i.e. size of the + // dimensions array). The major dimension of the indices array must equal the + // number of elements in the values array. The maximum number of elements in + // the array is taken from the max_indices() value of the index array. + // + // XLA assumes that sparse literals are in sorted order for all operations. If + // the `sort` argument is true, then the indices and values will be sorted + // while copying them into the literal. If you have ensured that the indices + // and values are already sorted, then you may set the `sort` argument to + // false to skip the sorting step. + // + // For example: + // + // CreateSparse( + // {12, 12, 12}, + // SparseIndexArray(10, 3, + // Array2D{ + // {0, 1, 2}, + // {3, 4, 5}, + // {6, 7, 8}, + // {9, 10, 11}, + // }), + // {1.0, 2.0 3.0, 4.0}) + // + // This creates an array with shape F64[12,12,12]sparse{10}, that has the + // following non-zero values: + // + // [0, 1, 2]: 1.0 + // [3, 4, 5]: 2.0 + // [6, 7, 8]: 3.0 + // [9, 10, 11]: 4.0 + // + template + static std::unique_ptr CreateSparse( + tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, bool sort = true); + + // Populates a literal with a sparse layout with the given indices and values. + // Each index in the indices array is CHECKed against the dimensions in the + // literal's shape. If sort is true, then the indices and values will be + // sorted. If sort is false, then the indices and values are assumed to + // already be in sorted order. See CreateSparse for an example of how data + // are populated. + template + void PopulateSparse(SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, + bool sort = true); + // Creates a new Literal object with the shape specified as parameter. // The content of the literal values is the default value of the primitive // type of literal itself (0 for numeric types, and false for predicates). @@ -358,7 +415,7 @@ class Literal { const ShapeIndex& shape_index, NativeT value); // Overloads of Get and Set for array literals. CHECKs if the literal is not - // array-shaped. + // array-shaped and dense. 
template NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; template @@ -408,6 +465,8 @@ class Literal { // This function is useful if you want a polymorphic representation // of the tensor's elements (turning it to a string for something // like representation in a protobuf). + // + // This literal must have a dense layout. void EachCellAsString( const std::function indices, const string& value)>& per_cell) const; @@ -447,6 +506,8 @@ class Literal { // // generator must be a callable of the type // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. + // + // This literal must have a dense layout. template Status Populate(const FnType& generator); @@ -485,10 +546,12 @@ class Literal { // admonishments about floating-point equality checks apply. We expect you to // use this to check for complex values that can be expressed precisely as // float pairs e.g. (-0.5, 1.0). + // + // This literal must have a dense layout. bool IsAllComplex(complex64 value) const; // Returns whether this literal is zero at the specified index. This literal - // must be an array. + // must be an array with a dense layout. bool IsZero(tensorflow::gtl::ArraySlice indices) const; // Return the count of the elements in the array at the given shape index in @@ -563,6 +626,14 @@ class Literal { char* buffer() const { return buffer_; } void set_buffer(char* buffer) { buffer_ = buffer; } + // The array of multi-indices that provide the locations of non-zero + // elements in a sparse array. Only used if + // LayoutUtil::IsSparseArray(shape()) is true. + SparseIndexArray* sparse_indices() const { return sparse_indices_; } + void set_sparse_indices(SparseIndexArray* sparse_indices) { + sparse_indices_ = sparse_indices; + } + // Gets or sets the subshape of this piece. This reference points to a // subshape within the shape in the containing Literal (Literal::shape_). const Shape& subshape() const { return *subshape_; } @@ -598,6 +669,9 @@ class Literal { // For array-shaped pieces, this is the buffer holding the literal data. char* buffer_ = nullptr; + // For sparse arrays, this is the array of indices. + SparseIndexArray* sparse_indices_ = nullptr; + // The shape of piece. This points into the shape of the containing Literal // (Literal::shape_). 
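// Illustrative sketch (assumed usage; `MakeIota2x3` is a hypothetical helper,
// not code added by this change) contrasting the dense-only Populate() with
// PopulateSparse(): Populate() requires a dense layout, which the TF_RET_CHECK
// added to it below enforces.
#include <memory>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"

namespace xla {

std::unique_ptr<Literal> MakeIota2x3() {
  auto literal = Literal::CreateFromShape(ShapeUtil::MakeShape(S64, {2, 3}));
  // The generator receives the multi-index of each element of the dense array.
  TF_CHECK_OK(literal->Populate<int64>(
      [](tensorflow::gtl::ArraySlice<int64> indexes) -> int64 {
        return indexes[0] * 3 + indexes[1];  // values 0..5 by logical index
      }));
  return literal;
}

}  // namespace xla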
const Shape* subshape_ = nullptr; @@ -836,6 +910,21 @@ template return CreateR4FromArray4DWithLayout(tmp, layout); } +template +/* static */ std::unique_ptr Literal::CreateSparse( + tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, bool sort) { + int64 num_elements = values.size(); + int64 rank = dimensions.size(); + CHECK_EQ(num_elements, indices.index_count()); + CHECK_EQ(rank, indices.rank()); + auto literal = MakeUnique(ShapeUtil::MakeShapeWithSparseLayout( + primitive_util::NativeToPrimitiveType(), dimensions, + indices.max_indices())); + literal->PopulateSparse(indices, values, sort); + return literal; +} + template /* static */ std::unique_ptr Literal::CreateR4( std::initializer_list& values) { PopulateFromArray(values); } +template +void Literal::PopulateSparse(SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, + bool sort) { + CHECK(LayoutUtil::IsSparseArray(shape())); + int rank = ShapeUtil::Rank(shape()); + CHECK_EQ(indices.rank(), rank); + int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout()); + CHECK_LE(indices.max_indices(), max_elements); + int64 num_elements = values.size(); + CHECK_LE(num_elements, max_elements); + CHECK_EQ(num_elements, indices.index_count()); + auto root_data = root_piece().data(); + root_data.remove_suffix(max_elements - values.size()); + std::copy(values.begin(), values.end(), root_data.begin()); + *this->root_piece().sparse_indices() = std::move(indices); + if (sort) { + auto root_data = this->root_piece().data(); + root_data.remove_suffix(root_data.size() - num_elements); + this->root_piece().sparse_indices()->SortWithValues(root_data); + } + DCHECK(this->root_piece().sparse_indices()->Validate(shape())); +} + template Status Literal::Populate(const FnType& generator) { const Shape& this_shape = shape(); const int64 rank = ShapeUtil::Rank(this_shape); - TF_RET_CHECK(ShapeUtil::IsArray(this_shape)); + TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); TF_RET_CHECK(this_shape.element_type() == primitive_util::NativeToPrimitiveType()); tensorflow::gtl::MutableArraySlice literal_data = data(); diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 4974ead048d..29efb4312f2 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -193,6 +193,34 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { ASSERT_EQ(expected, result); } +TEST_F(LiteralUtilTest, CreateSparse) { + std::vector dimensions = {8, 8, 8}; + Array2D indices = { + {3, 4, 5}, + {1, 2, 3}, + {2, 3, 4}, + {3, 5, 6}, + }; + std::vector values = {7, 8, 9, 10}; + auto literal = Literal::CreateSparse( + dimensions, SparseIndexArray(indices.n1() + 3, indices), values); + + Array2D expected_indices = { + {1, 2, 3}, + {2, 3, 4}, + {3, 4, 5}, + {3, 5, 6}, + }; + std::vector expected_values = {8, 9, 7, 10}; + + EXPECT_EQ(literal->sparse_indices()->data(), + tensorflow::gtl::ArraySlice( + expected_indices.data(), expected_indices.num_elements())); + EXPECT_EQ(tensorflow::gtl::ArraySlice(literal->data().data(), + expected_values.size()), + tensorflow::gtl::ArraySlice(expected_values)); +} + TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { // clang-format off auto literal = Literal::CreateR4Projected({ diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h index 51d0d5f86f0..50659c12405 100644 --- a/tensorflow/compiler/xla/map_util.h +++ b/tensorflow/compiler/xla/map_util.h @@ 
-60,6 +60,12 @@ bool ContainsKey(const Collection& collection, const Key& key) { return collection.find(key) != collection.end(); } +// Inserts `value` into `set`. Dies if it was already present. +template +void InsertOrDie(Set* const set, const typename Set::value_type& value) { + CHECK(set->insert(value).second) << "duplicate value: " << value; +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 7652b492d5c..7637ff57e00 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1101,6 +1101,8 @@ cc_library( ":hlo", ":hlo_evaluator", ":hlo_pass", + ":tuple_util", + ":while_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", ], @@ -2255,6 +2257,78 @@ cc_library( ], ) +cc_library( + name = "tuple_util", + srcs = ["tuple_util.cc"], + hdrs = ["tuple_util.h"], + deps = [ + ":hlo", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "tuple_util_test", + srcs = ["tuple_util_test.cc"], + deps = [ + ":tuple_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + +cc_library( + name = "while_util", + srcs = ["while_util.cc"], + hdrs = ["while_util.h"], + deps = [ + ":call_inliner", + ":hlo", + ":tuple_util", + ], +) + +tf_cc_test( + name = "while_util_test", + srcs = ["while_util_test.cc"], + deps = [ + ":while_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + +cc_library( + name = "while_loop_invariant_code_motion", + srcs = ["while_loop_invariant_code_motion.cc"], + hdrs = ["while_loop_invariant_code_motion.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":tuple_util", + ":while_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "while_loop_invariant_code_motion_test", + srcs = ["while_loop_invariant_code_motion_test.cc"], + deps = [ + ":hlo_matchers", + ":while_loop_invariant_code_motion", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/core:test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc index 3aa7f5c4d58..482ccc5b671 100644 --- a/tensorflow/compiler/xla/service/call_inliner.cc +++ b/tensorflow/compiler/xla/service/call_inliner.cc @@ -82,6 +82,10 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { return outer_->ReplaceInstruction(call_, new_root); } + CallInliner::InlinedInstructionMap ConsumeInstructionMap() { + return std::move(subcomputation_hlo_to_new_hlo_); + } + private: // Resolves the callee subcomputation_hlo to the new (inline) HLO in the // caller computation, or returns a NotFound error if that subcomputation HLO @@ -112,13 +116,13 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { HloInstruction* call_; HloComputation* outer_; - std::unordered_map - subcomputation_hlo_to_new_hlo_; + CallInliner::InlinedInstructionMap subcomputation_hlo_to_new_hlo_; }; } // namespace -/* static */ Status 
CallInliner::Inline(HloInstruction* call) { +/* static */ StatusOr CallInliner::Inline( + HloInstruction* call) { TF_RET_CHECK(call->opcode() == HloOpcode::kCall) << "Instruction was not a call op: " << call->opcode(); const auto& callees = call->called_computations(); @@ -126,7 +130,8 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { HloComputation* callee = callees[0]; // We visit the callee, cloning its body into its caller. SubcomputationInsertionVisitor visitor(call); - return callee->Accept(&visitor); + TF_RETURN_IF_ERROR(callee->Accept(&visitor)); + return visitor.ConsumeInstructionMap(); } StatusOr CallInliner::Run(HloModule* module) { @@ -140,7 +145,7 @@ StatusOr CallInliner::Run(HloModule* module) { VLOG(1) << "Visiting callsite: " << callsite.ToString(); if (callsite.instruction()->opcode() == HloOpcode::kCall) { HloInstruction* call = callsite.instruction(); - TF_RETURN_IF_ERROR(Inline(call)); + TF_RETURN_IF_ERROR(Inline(call).status()); did_mutate = true; } } diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h index 2dbd38bf1ac..a8345a394d4 100644 --- a/tensorflow/compiler/xla/service/call_inliner.h +++ b/tensorflow/compiler/xla/service/call_inliner.h @@ -27,8 +27,12 @@ namespace xla { // called function, and proceed recursively. class CallInliner : public HloPassInterface { public: - // Inlines one call instruction. - static Status Inline(HloInstruction* call); + using InlinedInstructionMap = + std::unordered_map; + + // Inlines one call instruction. Returns a mapping from the original + // instructions to their inlined versions. + static StatusOr Inline(HloInstruction* call); ~CallInliner() override = default; tensorflow::StringPiece name() const override { return "CallInliner"; } diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index 865ed993da1..738d00881dd 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -135,7 +135,7 @@ TEST_F(CallInlinerTest, InlineWithoutRunningPass) { HloInstruction::CreateCall(pred, {}, false_computation)); auto computation = module->AddEntryComputation(call_false_builder.Build()); - TF_ASSERT_OK(CallInliner::Inline(call)); + TF_ASSERT_OK(CallInliner::Inline(call).status()); EXPECT_THAT(computation->root_instruction(), op::Constant()); EXPECT_THAT(computation->root_instruction()->control_successors(), ElementsAre(op::Constant())); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 14c86e2e720..2f025916312 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -81,6 +81,7 @@ cc_library( ":conv_canonicalization", ":cpu_copy_insertion", ":cpu_executable", + ":cpu_hlo_support_checker", ":cpu_instruction_fusion", ":cpu_layout_assignment", ":cpu_options", @@ -126,6 +127,7 @@ cc_library( "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", # fixdeps: keep @@ -873,6 +875,32 @@ tf_cc_test( ], ) +cc_library( + name = "cpu_hlo_support_checker", + srcs = 
["cpu_hlo_support_checker.cc"], + hdrs = ["cpu_hlo_support_checker.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "cpu_hlo_support_checker_test", + srcs = ["cpu_hlo_support_checker_test.cc"], + deps = [ + ":cpu_hlo_support_checker", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 8e6562c237e..705bcb2e9bd 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -31,6 +31,7 @@ limitations under the License. #include "llvm/IR/Function.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" #include "llvm/Object/ObjectFile.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/TargetRegistry.h" @@ -50,6 +51,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" #include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" #include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" @@ -85,6 +87,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -258,6 +261,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // Optimization pipeline. HloPassPipeline pipeline("CPU"); pipeline.AddInvariantChecker(ShapeSizeBytesFunction()); + pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( &pipeline, module->config().debug_options(), @@ -291,6 +295,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // elimination has to come after that pass. pipeline.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -439,6 +444,21 @@ Status InitializeModuleHooks( return Status::OK(); } +Status VerifyLlvmModule(const llvm::Module& llvm_module) { + XLA_SCOPED_LOGGING_TIMER("CpuCompiler - Running LLVM verifier"); + + std::string err; + llvm::raw_string_ostream err_stream(err); + + // verifyModule() returns true if the module is broken. + TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream)) + << "Invalid LLVM IR before optimizations:\n" + << err_stream.str() + << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. " + "Rerun with --xla_dump_ir_to to get the IR. 
"; + return Status::OK(); +} + } // namespace StatusOr> CpuCompiler::RunHloPasses( @@ -627,6 +647,7 @@ StatusOr> CpuCompiler::RunBackend( if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); } + TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); // JIT compile the LLVM IR module to in-memory machine code. jit->AddModule(std::move(llvm_module)); @@ -704,6 +725,7 @@ StatusOr> CpuCompiler::RunBackend( if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); } + TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module)); @@ -875,6 +897,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, &module_sequence.at(computation))); CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); + TF_RETURN_IF_ERROR(VerifyLlvmModule(llvm_module)); ModuleHook pre_optimization_ir_dump_hook; ModuleHook post_optimization_ir_dump_hook; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc new file mode 100644 index 00000000000..7bd4741a04b --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +StatusOr CpuHloSupportChecker::Run(HloModule* module) { + for (auto* computation : module->computations()) { + for (const auto& instruction : computation->instructions()) { + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape())); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + instruction->shape(), + [&instruction](const Shape& subshape, const ShapeIndex&) { + if (LayoutUtil::IsSparseArray(subshape)) { + return xla::Unimplemented( + "CPU backend does not support HLO instruction %s with shape " + "containing a sparse layout: %s", + instruction->ToString().c_str(), + ShapeUtil::HumanStringWithLayout(instruction->shape()) + .c_str()); + } + return Status::OK(); + })); + } + } + return false; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h new file mode 100644 index 00000000000..2271af7b247 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// This pass should run early in the HLO pipeline and checks for HLO constructs +// which are not supported by the CPU backend and cannot be removed via HLO +// transformations (eg, sparse layouts). +class CpuHloSupportChecker : public HloPassInterface { + public: + CpuHloSupportChecker() = default; + ~CpuHloSupportChecker() override = default; + + tensorflow::StringPiece name() const override { + return "cpu_hlo_support_checker"; + } + + // Note: always returns false (no instructions are ever modified by this + // pass). + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc new file mode 100644 index 00000000000..0f463e6de62 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +using ::testing::HasSubstr; + +class CpuHloSupportCheckerTest : public HloTestBase { + protected: + CpuHloSupportChecker& checker() { return checker_; } + + private: + CpuHloSupportChecker checker_; +}; + +TEST_F(CpuHloSupportCheckerTest, Add) { + HloComputation::Builder builder(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param1")); + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param0, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK(checker().Run(module.get()).status()); +} + +TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { + HloComputation::Builder builder(TestName()); + const Shape sparse_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {10}, 2); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, sparse_shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, sparse_shape, "param1")); + builder.AddInstruction(HloInstruction::CreateBinary( + sparse_shape, HloOpcode::kAdd, param0, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + Status status = checker().Run(module.get()).status(); + ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_THAT(status.error_message(), + HasSubstr("CPU backend does not support")); + EXPECT_THAT(status.error_message(), + HasSubstr(ShapeUtil::HumanStringWithLayout(sparse_shape))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index a86d3583a6b..69ccc7179f9 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -437,6 +437,7 @@ cc_library( ":fusion_merger", ":gpu_copy_insertion", ":gpu_executable", + ":gpu_hlo_support_checker", ":gpu_layout_assignment", ":hlo_schedule", ":instruction_fusion", @@ -610,6 +611,32 @@ tf_cc_test( ], ) +cc_library( + name = "gpu_hlo_support_checker", + srcs = ["gpu_hlo_support_checker.cc"], + hdrs = ["gpu_hlo_support_checker.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "gpu_hlo_support_checker_test", + srcs = ["gpu_hlo_support_checker_test.cc"], + deps = [ + ":gpu_hlo_support_checker", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git 
a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 9f34866ff51..9321429bdcc 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -25,6 +25,7 @@ limitations under the License. #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" @@ -39,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" @@ -137,6 +139,7 @@ tensorflow::Status OptimizeHloModule( { HloPassPipeline pipeline("optimization"); pipeline.AddInvariantChecker(shape_size_function); + pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( &pipeline, hlo_module->config().debug_options(), ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); @@ -476,6 +479,20 @@ StatusOr> GpuCompiler::RunBackend( entry_computation->root_instruction()->Accept(&ir_emitter)); } + { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier"); + + std::string err; + llvm::raw_string_ostream err_stream(err); + + // verifyModule() returns true if the module is broken. + TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream)) + << "Invalid LLVM IR before optimizations:\n" + << err_stream.str() + << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. " + "Rerun with --xla_dump_ir_to to get the IR. "; + } + if (user_pre_optimization_hook_) { TF_CHECK_OK(user_pre_optimization_hook_(llvm_module)); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc new file mode 100644 index 00000000000..4944c41f7d8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +StatusOr GpuHloSupportChecker::Run(HloModule* module) { + for (auto* computation : module->computations()) { + for (const auto& instruction : computation->instructions()) { + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape())); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + instruction->shape(), + [&instruction](const Shape& subshape, const ShapeIndex&) { + if (LayoutUtil::IsSparseArray(subshape)) { + return xla::Unimplemented( + "GPU backend does not support HLO instruction %s with shape " + "containing a sparse layout: %s", + instruction->ToString().c_str(), + ShapeUtil::HumanStringWithLayout(instruction->shape()) + .c_str()); + } + return Status::OK(); + })); + } + } + return false; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h new file mode 100644 index 00000000000..d9550f81b59 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// his pass should run early in the HLO pipeline and checks for HLO constructs +// which are not supported by the GPU backend and cannot be removed via HLO +// transformations (eg, sparse layouts). +class GpuHloSupportChecker : public HloPassInterface { + public: + GpuHloSupportChecker() = default; + ~GpuHloSupportChecker() override = default; + + tensorflow::StringPiece name() const override { + return "gpu_hlo_support_checker"; + } + + // Note: always returns false (no instructions are ever modified by this + // pass). + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc new file mode 100644 index 00000000000..0a4089df4c9 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +using ::testing::HasSubstr; + +class GpuHloSupportCheckerTest : public HloTestBase { + protected: + GpuHloSupportChecker& checker() { return checker_; } + + private: + GpuHloSupportChecker checker_; +}; + +TEST_F(GpuHloSupportCheckerTest, Add) { + HloComputation::Builder builder(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param1")); + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param0, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK(checker().Run(module.get()).status()); +} + +TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { + HloComputation::Builder builder(TestName()); + const Shape sparse_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {10}, 2); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, sparse_shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, sparse_shape, "param1")); + builder.AddInstruction(HloInstruction::CreateBinary( + sparse_shape, HloOpcode::kAdd, param0, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + Status status = checker().Run(module.get()).status(); + ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_THAT(status.error_message(), + HasSubstr("GPU backend does not support")); + EXPECT_THAT(status.error_message(), + HasSubstr(ShapeUtil::HumanStringWithLayout(sparse_shape))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_util.cc b/tensorflow/compiler/xla/service/tuple_util.cc new file mode 100644 index 00000000000..4a530bb0b20 --- /dev/null +++ b/tensorflow/compiler/xla/service/tuple_util.cc @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/tuple_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +/*static*/ HloInstruction* TupleUtil::ExtractPrefix(HloInstruction* input_tuple, + int64 elements) { + CHECK(ShapeUtil::IsTuple(input_tuple->shape())); + + HloComputation* computation = input_tuple->parent(); + const Shape& input_shape = input_tuple->shape(); + + std::vector tuple_elements; + tuple_elements.reserve(elements); + for (int i = 0; i < elements; i++) { + tuple_elements.push_back( + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + input_shape.tuple_shapes(i), input_tuple, i))); + } + + return computation->AddInstruction( + HloInstruction::CreateTuple(tuple_elements)); +} + +/*static*/ HloInstruction* TupleUtil::AppendSuffix( + HloInstruction* input_tuple, + tensorflow::gtl::ArraySlice trailing_values) { + CHECK(ShapeUtil::IsTuple(input_tuple->shape())); + + HloComputation* computation = input_tuple->parent(); + const Shape& input_shape = input_tuple->shape(); + std::vector tuple_elements; + tuple_elements.reserve(input_shape.tuple_shapes_size()); + for (int i = 0; i < input_shape.tuple_shapes_size(); i++) { + tuple_elements.push_back( + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + input_shape.tuple_shapes(i), input_tuple, i))); + } + tuple_elements.insert(tuple_elements.end(), trailing_values.begin(), + trailing_values.end()); + return computation->AddInstruction( + HloInstruction::CreateTuple(tuple_elements)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_util.h b/tensorflow/compiler/xla/service/tuple_util.h new file mode 100644 index 00000000000..e5ff9aaa835 --- /dev/null +++ b/tensorflow/compiler/xla/service/tuple_util.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_UTIL_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { +class TupleUtil { + public: + // Generates HLO instructions to get a prefix tuple from `input_tuple` (which + // must be of tuple shape) of length `elements`. Returns the root of the + // graph of instructions generated. + // + // The instructions are generated into the computation containing + // `input_tuple`. + static HloInstruction* ExtractPrefix(HloInstruction* input_tuple, + int64 elements); + + // Generates HLO instructions to create a tuple that consists of the values in + // `trailing_values` appended to `input_tuple` (which must be of tuple shape). + // Returns the root of the graph of instructions generated. 
+ // + // The instructions are generated into the computation containing + // `input_tuple`. + static HloInstruction* AppendSuffix( + HloInstruction* input_tuple, + tensorflow::gtl::ArraySlice trailing_values); +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/tuple_util_test.cc b/tensorflow/compiler/xla/service/tuple_util_test.cc new file mode 100644 index 00000000000..754fd8ef169 --- /dev/null +++ b/tensorflow/compiler/xla/service/tuple_util_test.cc @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/tuple_util.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace { + +namespace op = ::xla::testing::opcode_matchers; + +StatusOr> GetParsedModule( + HloComputation** entry_computation, HloInstruction** param0, + HloInstruction** param1) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[32,32]{1,0},f32[32,32]{1,0},f32[32,32]{1,0}) parameter(0) + ROOT p1 = f32[32,32]{1,0} parameter(1) +} +)"; + + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + tools::Parse(hlo_string)); + + *entry_computation = module->entry_computation(); + *param0 = (*entry_computation)->parameter_instruction(0); + *param1 = (*entry_computation)->parameter_instruction(1); + + return std::move(module); +} + +TEST(TupleUtilTest, ExtractPrefix) { + HloInstruction *param0, *param1; + HloComputation* entry_computation; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + GetParsedModule(&entry_computation, ¶m0, ¶m1)); + + HloInstruction* prefix = TupleUtil::ExtractPrefix(param0, 2); + + EXPECT_THAT(prefix, op::Tuple(op::GetTupleElement(op::Parameter(0), 0), + op::GetTupleElement(op::Parameter(0), 1))); +} + +TEST(TupleUtilTest, AppendSuffix) { + HloInstruction *param0, *param1; + HloComputation* entry_computation; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + GetParsedModule(&entry_computation, ¶m0, ¶m1)); + + HloInstruction* with_suffix = + TupleUtil::AppendSuffix(param0, {param1, param1}); + + EXPECT_THAT(with_suffix, op::Tuple(op::GetTupleElement(op::Parameter(0), 0), + op::GetTupleElement(op::Parameter(0), 1), + op::GetTupleElement(op::Parameter(0), 2), + op::Parameter(1), op::Parameter(1))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc new file mode 100644 index 00000000000..a5f9b01f011 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -0,0 +1,296 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" +#include "tensorflow/compiler/xla/service/tuple_util.h" +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace xla { + +using tensorflow::gtl::FlatMap; +using tensorflow::gtl::FlatSet; +using tensorflow::gtl::InlinedVector; + +// Copies `to_hoist` to the computation containing `while_instr`, hoisting its +// operands as needed. All of its transitive operands are expected to be either +// in `hoisted_instructions` or `unhoisted_invariant_instructions`. This +// function hoists the operands in `unhoisted_invariant_instructions` and moves +// them into `hoisted_instructions`. +static void CreateLoopInvariantCopy( + FlatMap* hoisted_instructions, + FlatSet* unhoisted_invariant_instructions, + HloInstruction* while_instr, HloInstruction* to_hoist) { + HloComputation* parent_of_while = while_instr->parent(); + HloComputation* while_body = while_instr->while_body(); + + struct DFSFrame { + HloInstruction* instruction; + int64 operand_index; + }; + + InlinedVector dfs_stack; + dfs_stack.push_back({to_hoist, 0}); + + HloInstruction* while_body_param = while_body->parameter_instruction(0); + HloInstruction* while_operand = while_instr->mutable_operand(0); + + do { + DFSFrame* frame = &dfs_stack.back(); + if (frame->operand_index == frame->instruction->operand_count()) { + HloInstruction* old_instruction = frame->instruction; + + // All of the operands for old_instruction have been cloned, so it is + // time to clone old_instruction itself. + + auto get_new_operand = [&](HloInstruction* old_operand) { + return old_operand == while_body_param + ? while_operand + : FindOrDie(*hoisted_instructions, old_operand); + }; + + InlinedVector new_operands; + c_transform(old_instruction->operands(), std::back_inserter(new_operands), + get_new_operand); + + HloInstruction* new_instruction = + parent_of_while->AddInstruction(old_instruction->CloneWithNewOperands( + old_instruction->shape(), new_operands)); + + InsertOrDie(hoisted_instructions, old_instruction, new_instruction); + + // Approximately half of the instructions that would normally be present + // in unhoisted_invariant_instructions are constants. We save a bit of + // compile time by not putting these in the hashtable. 
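+      //
+      // The CHECK below asserts exactly that invariant: erase() must succeed
+      // iff old_instruction was being tracked, i.e. it is neither the
+      // instruction being hoisted nor a constant.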
+ CHECK_EQ(unhoisted_invariant_instructions->erase(old_instruction), + to_hoist != old_instruction && + old_instruction->opcode() != HloOpcode::kConstant); + dfs_stack.pop_back(); + continue; + } + + HloInstruction* next_operand = + frame->instruction->mutable_operand(frame->operand_index++); + if (hoisted_instructions->count(next_operand) || + next_operand == while_body_param) { + continue; + } + + dfs_stack.push_back({next_operand, 0}); + } while (!dfs_stack.empty()); +} + +// Returns true if `instruction` is worth hoisting only if it lets us hoist some +// instruction using it. The rationale is that hoisting these instructions will +// prevent simplification and fusion in the while body. +static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { + switch (instruction.opcode()) { + default: + return false; + + case HloOpcode::kBitcast: + case HloOpcode::kBroadcast: + case HloOpcode::kConstant: + case HloOpcode::kReverse: + case HloOpcode::kSlice: + case HloOpcode::kTuple: + return true; + + case HloOpcode::kTranspose: + return ShapeUtil::TransposeIsBitcast( + /*input_shape=*/instruction.operand(0)->shape(), + /*output_shape=*/instruction.shape(), instruction.dimensions()); + + case HloOpcode::kReshape: + return ShapeUtil::ReshapeIsBitcast( + /*input_shape=*/instruction.operand(0)->shape(), + /*output_shape=*/instruction.shape()); + } +} + +// Populates `gte_set` with the GetTupleElement instructions in `while_body` +// that access elements in the parameter tuple that don't change across +// iterations. Assumes `while_body` is the body computation of the while loop +// in question. +static void GatherInvariantGTEs(HloComputation* while_body, + FlatSet* gte_set) { + const HloInstruction::InstructionVector root_operands = + while_body->root_instruction()->operands(); + for (int i = 0; i < root_operands.size(); i++) { + HloInstruction* instr = root_operands[i]; + if (instr->opcode() == HloOpcode::kGetTupleElement && + instr->tuple_index() == i && + instr->operand(0) == while_body->parameter_instruction(0) && + ShapeUtil::IsArray(instr->shape())) { + InsertOrDie(gte_set, instr); + } + } +} + +static StatusOr TryHoistingInvariantInstructionsFromWhileBody( + HloInstruction* while_instr) { + auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false); + + if (!ShapeUtil::IsTuple(while_instr->shape())) { + // This restriction leaves one interesting pattern on the table: + // + // while_body(f32[1024, 1024] %param) { + // %value = expensive_op(%param) + // outfeed(%value) + // ROOT = %param + // } + // + // If we see that pattern in the while, instead of generalizing this + // algorithm to work with non-tuples, we should instead add a pass that + // canonicalizes while loops like the above to use a tuple state. + return false; + } + + string while_instr_name = while_instr->ToString(print_no_metadata); + VLOG(2) << "Trying to hoist from " << while_instr_name; + + HloComputation* while_body = while_instr->while_body(); + + // Maps instructions in the while body to instructions hoisted outside the + // while that compute the same value. + FlatMap hoisted_instructions; + + // Contains instructions that can be legally hoisted, but were deemed to be + // unprofitable to be hoisted alone by NotWorthHoistingIndividually. When we + // hoist an instruction in this set, we move it from + // unhoisted_invariant_instructions to hoisted_instructions. 
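+  // Typical members are cheap shape-manipulating ops (broadcasts, reverses,
+  // slices, tuples, and bitcast-like transposes and reshapes, per
+  // NotWorthHoistingIndividually), plus the invariant GTEs gathered below.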
+ FlatSet unhoisted_invariant_instructions; + + // Invariant GTE's axiomatically satisfy the constraints for + // unhoisted_invariant_instructions -- they can be legally hoisted, but there + // is no benefit to hoisting them unless something that uses it is also + // hoisted. + GatherInvariantGTEs(while_body, &unhoisted_invariant_instructions); + + if (unhoisted_invariant_instructions.empty()) { + // There are no obviously loop invariant elements in the state being + // threaded through the while loop so give up. In theory this precondition + // is too strong -- we could have code that e.g. permutes the elements in + // the while state but uses a select to pick the same value on every + // iteration. + return false; + } + + // instructions_to_replace[i] is hoisted into a loop invariant instruction + // replacement_instructions[i]. + std::vector instructions_to_replace; + std::vector replacement_instructions; + + for (auto* instruction : while_body->MakeInstructionPostOrder()) { + if (instruction->HasSideEffect() || + instruction->opcode() == HloOpcode::kParameter || + !instruction->control_predecessors().empty() || + !instruction->control_successors().empty()) { + continue; + } + + auto is_invariant = [&](HloInstruction* op) { + return hoisted_instructions.find(op) != hoisted_instructions.end() || + unhoisted_invariant_instructions.count(op) || + op->opcode() == HloOpcode::kConstant; + }; + + if (!c_all_of(instruction->operands(), is_invariant)) { + continue; + } + + if (NotWorthHoistingIndividually(*instruction)) { + VLOG(2) << "Adding " << instruction->ToString(print_no_metadata) + << " to unhoisted invariant set."; + // Approximately half of the instructions that reach this point are + // constants. We save a bit of compile time by not putting these in the + // hashtable. 
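+      // (is_invariant above already special-cases kConstant, so skipping
+      // constants here loses nothing.)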
+ if (instruction->opcode() != HloOpcode::kConstant) { + InsertOrDie(&unhoisted_invariant_instructions, instruction); + } + continue; + } + + VLOG(2) << "Hoisting " << instruction->ToString(print_no_metadata); + + CreateLoopInvariantCopy(&hoisted_instructions, + &unhoisted_invariant_instructions, while_instr, + instruction); + + instructions_to_replace.push_back(instruction); + replacement_instructions.push_back( + FindOrDie(hoisted_instructions, instruction)); + } + + if (instructions_to_replace.empty()) { + return false; + } + + TF_ASSIGN_OR_RETURN( + WhileUtil::MakeInstructionsLiveInResult live_in_instructions_result, + WhileUtil::MakeInstructionsLiveIn(while_instr, replacement_instructions)); + + HloComputation* new_while_body = + live_in_instructions_result.new_while_instr->while_body(); + + for (int i = 0; i < instructions_to_replace.size(); i++) { + HloInstruction* instruction_to_replace_in_new_while = + FindOrDie(live_in_instructions_result.while_body_instruction_map, + instructions_to_replace[i]); + TF_RETURN_IF_ERROR(new_while_body->ReplaceInstruction( + instruction_to_replace_in_new_while, + live_in_instructions_result.while_body_live_in_values[i])); + } + + VLOG(1) << "Hoisted " << instructions_to_replace.size() + << " instructions from " << while_instr_name; + + return true; +} + +StatusOr WhileLoopInvariantCodeMotion::Run(HloModule* module) { + bool changed = false; + std::vector while_instrs; + for (auto* comp : module->computations()) { + c_copy_if(comp->instructions(), std::back_inserter(while_instrs), + [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); + } + + for (HloInstruction* while_instr : while_instrs) { + // Right now we only hoist computations from the while body, but + // TryHoistingInvariantInstructionsFromWhileBody can be generalized to + // optimize the condition computation too, if needed. + // + // The transform we do here is a pessmization for while loops that execute + // zero times*, but at this time we expect those to be rare. If this + // becomes a problem we can consider using the conditional HLO to avoid + // doing extra work for while loops with zero trip count. + // + // * We delete while loops that have a zero trip count, so this would have + // to be a while loop with a somewhat opaque condition expression. + + TF_ASSIGN_OR_RETURN( + bool result, + TryHoistingInvariantInstructionsFromWhileBody(while_instr)); + changed |= result; + } + return changed; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h new file mode 100644 index 00000000000..8c4b765b000 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_INVARIANT_CODE_MOTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_INVARIANT_CODE_MOTION_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass that rewrites while loops to hoist loop invariant instructions in +// the while body into the computation that contains the while instruction. + +class WhileLoopInvariantCodeMotion : public HloPassInterface { + public: + ~WhileLoopInvariantCodeMotion() override = default; + + tensorflow::StringPiece name() const override { + return "while-loop-invariant-code-motion"; + } + StatusOr Run(HloModule* module) override; +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_INVARIANT_CODE_MOTION_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc new file mode 100644 index 00000000000..799340fda90 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -0,0 +1,442 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase { + public: + // Makes a computation which has one parameter, of the given shape, and always + // returns PRED[]{true}. This is useful as a dummy loop condition. 
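+  // The resulting loop would never terminate if executed, but these tests only
+  // inspect the transformed HLO; they never run it.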
+ HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape, + HloModule* module); +}; + +static void FindOnlyWhileInstruction(HloComputation* computation, + HloInstruction** while_instruction) { + *while_instruction = nullptr; + for (auto* instr : computation->instructions()) { + if (instr->opcode() == HloOpcode::kWhile) { + ASSERT_EQ(*while_instruction, nullptr); + *while_instruction = instr; + } + } + + ASSERT_NE(*while_instruction, nullptr); +} + +HloComputation* WhileLoopInvariantCodeMotionTest::MakeAlwaysTrueComputation( + const Shape& param_shape, HloModule* module) { + HloComputation::Builder builder(TestName() + ".always_true"); + builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "param")); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + return module->AddEmbeddedComputation(builder.Build()); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* add_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte_0, gte_1, add_result})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + HloComputation* entry_computation = + module().AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_TRUE(simplified_loop); + + HloInstruction* transformed_while; + FindOnlyWhileInstruction(entry_computation, &transformed_while); + + EXPECT_THAT(entry_computation->instructions(), Contains(op::Add())); + EXPECT_THAT(transformed_while->while_body()->instructions(), + Each(Not(op::Add()))); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* gte_2_loop_variant = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 2)); + + HloInstruction* add_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); + 
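+    // Together with the add above, the multiply/negate/subtract chain depends
+    // only on gte_0, gte_1 and a constant, so it is loop invariant; the final
+    // divide consumes gte_2_loop_variant and must stay in the body.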
HloInstruction* mul_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kMultiply, add_result, gte_1)); + HloInstruction* negate_result = + builder.AddInstruction(HloInstruction::CreateUnary( + scalar_s32, HloOpcode::kNegate, mul_result)); + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(4))); + HloInstruction* sub_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kSubtract, negate_result, constant)); + HloInstruction* divide_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kDivide, sub_result, gte_2_loop_variant)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte_0, gte_1, divide_result})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + HloComputation* entry_computation = + module().AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_TRUE(simplified_loop); + + HloInstruction* transformed_while; + FindOnlyWhileInstruction(entry_computation, &transformed_while); + + EXPECT_THAT(entry_computation->instructions(), + AllOf(Contains(op::Add()), Contains(op::Multiply()), + Contains(op::Negate()), Contains(op::Subtract()), + Contains(op::Constant()), + + // The division had a loop varying operand so that better + // not be hoisted. + Not(Contains(op::Divide())))); + + EXPECT_THAT(transformed_while->while_body()->instructions(), + Each(Not(AnyOf(op::Add(), op::Multiply(), op::Negate(), + op::Subtract(), op::Constant())))); + + EXPECT_THAT(transformed_while->while_body()->instructions(), + Contains(op::Divide())); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, + DontHoistTriviallyLoopVaryingComputation) { + // Basic negative test: the add expression is not loop invariant. 
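+  // Tuple element 1 is rebound to the add result on every iteration, so gte_1
+  // (and hence the add itself) varies across iterations.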
+ auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* add_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); + builder.AddInstruction(HloInstruction::CreateTuple({gte_0, add_result})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + + module().AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); + + EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add())); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, + DontHoistLoopVaryingComputationWithAlternatingTuples) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* add_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte_1, gte_0, add_result})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + + module().AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); + + EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add())); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + 
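+    // The outfeed created below is side-effecting, so the pass must not hoist
+    // it even though its operand is loop invariant.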
HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + builder.AddInstruction( + HloInstruction::CreateOutfeed(scalar_s32, gte_0, "")); + builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + + module().AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); + + EXPECT_THAT(while_inst->while_body()->instructions(), + Contains(op::Outfeed())); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { + // The bitcast's user, an outfeed, can't be hoisted, so don't hoist the + // bitcast either. + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); + Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* bitcast_inst = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); + builder.AddInstruction( + HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, "")); + builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + + module().AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); + + EXPECT_THAT(while_inst->while_body()->instructions(), + Contains(op::Outfeed())); + EXPECT_THAT(while_inst->while_body()->instructions(), + Contains(op::Bitcast())); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { + // The bitcast's user can be hoisted, so hoist the bitcast too. 
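+  // The bitcast by itself is NotWorthHoistingIndividually, but when the add
+  // that uses it is hoisted, CreateLoopInvariantCopy pulls the bitcast out as
+  // one of the add's operands.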
+ auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_f32, scalar_f32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_f32, param, 1)); + HloInstruction* bitcast_inst = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); + HloInstruction* add_inst = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_f32, HloOpcode::kAdd, bitcast_inst, gte_1)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte_0, gte_1, add_inst})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + + HloComputation* entry_computation = + module().AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_TRUE(simplified_loop); + + HloInstruction* transformed_while; + FindOnlyWhileInstruction(entry_computation, &transformed_while); + + EXPECT_THAT(transformed_while->while_body()->instructions(), + Each(Not(op::Add()))); + EXPECT_THAT(transformed_while->while_body()->instructions(), + Each(Not(op::Bitcast()))); + EXPECT_THAT(entry_computation->instructions(), Contains(op::Add())); + EXPECT_THAT(entry_computation->instructions(), Contains(op::Bitcast())); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistControlDependencies) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); + + HloComputation* while_body; + { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* add_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); + TF_ASSERT_OK(param->AddControlDependencyTo(add_result)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte_0, gte_1, add_result})); + + while_body = module().AddEmbeddedComputation(builder.Build()); + } + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + module().AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); +} + 
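+// Negative test: with a non-tuple body root there are no invariant
+// get-tuple-element instructions to seed the pass with, so it reports no
+// change.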
+TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".passthrough"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloComputation* result = module().AddEmbeddedComputation(builder.Build()); + + result->AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + return result; + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + module().AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 615a089d125..87a7f86f4ec 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -595,7 +595,9 @@ static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { auto call_op = computation->AddInstruction(HloInstruction::CreateCall( while_op->shape(), while_op->operands(), while_op->while_body())); TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op)); - TF_RETURN_IF_ERROR(CallInliner::Inline(call_op)); + TF_ASSIGN_OR_RETURN(auto inlined_instructions_map, + CallInliner::Inline(call_op)); + (void)inlined_instructions_map; return true; } return false; diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc new file mode 100644 index 00000000000..e20b25e4a08 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -0,0 +1,140 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/tuple_util.h" + +namespace xla { + +static StatusOr WidenWhileCondition( + HloComputation* narrow_condition, const Shape& wide_shape) { + const Shape& narrow_shape = + narrow_condition->parameter_instruction(0)->shape(); + + HloComputation* wide_while_cond = [&]() { + HloComputation::Builder builder( + tensorflow::strings::StrCat("wide.", narrow_condition->name())); + builder.AddInstruction( + HloInstruction::CreateParameter(0, wide_shape, "wide_param")); + + // This is needed so that the root instruction is shaped as a PRED[] -- we + // need to get this right to begin with since we can't mutate the type of + // the root instruction later. We later change the root instruction to + // something more appropriate. + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + return narrow_condition->parent()->AddEmbeddedComputation(builder.Build()); + }(); + + HloInstruction* truncated_parameter = + TupleUtil::ExtractPrefix(wide_while_cond->parameter_instruction(0), + narrow_shape.tuple_shapes_size()); + HloInstruction* call_narrow_cond = wide_while_cond->AddInstruction( + HloInstruction::CreateCall(ShapeUtil::MakeShape(PRED, {}), + {truncated_parameter}, narrow_condition)); + + wide_while_cond->set_root_instruction(call_narrow_cond); + + TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_cond).status()); + return wide_while_cond; +} + +static StatusOr> +WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) { + const Shape& narrow_shape = narrow_body->parameter_instruction(0)->shape(); + + HloComputation* wide_while_body = [&]() { + HloComputation::Builder builder( + tensorflow::strings::StrCat("wide.", narrow_body->name())); + builder.AddInstruction( + HloInstruction::CreateParameter(0, wide_shape, "wide_param")); + return narrow_body->parent()->AddEmbeddedComputation(builder.Build()); + }(); + + HloInstruction* wide_parameter = wide_while_body->parameter_instruction(0); + HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix( + wide_parameter, narrow_shape.tuple_shapes_size()); + HloInstruction* call_narrow_body = + wide_while_body->AddInstruction(HloInstruction::CreateCall( + narrow_shape, {truncated_parameter}, narrow_body)); + + std::vector live_through_values; + for (int i = narrow_shape.tuple_shapes_size(); + i < wide_shape.tuple_shapes_size(); i++) { + live_through_values.push_back( + wide_while_body->AddInstruction(HloInstruction::CreateGetTupleElement( + wide_shape.tuple_shapes(i), wide_parameter, i))); + } + + wide_while_body->set_root_instruction( + TupleUtil::AppendSuffix(call_narrow_body, live_through_values)); + + TF_ASSIGN_OR_RETURN(auto inlined_instructions_map, + CallInliner::Inline(call_narrow_body)); + return {{wide_while_body, std::move(inlined_instructions_map)}}; +} + +/*static*/ StatusOr +WhileUtil::MakeInstructionsLiveIn( + HloInstruction* while_instr, + tensorflow::gtl::ArraySlice instructions) { + CHECK(ShapeUtil::IsTuple(while_instr->shape())); + + int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size(); + Shape new_while_shape = while_instr->shape(); + for (auto* instruction : instructions) { + *new_while_shape.add_tuple_shapes() = instruction->shape(); + } + + TF_ASSIGN_OR_RETURN( + HloComputation * new_while_condition, + 
WidenWhileCondition(while_instr->while_condition(), new_while_shape)); + + HloComputation* new_while_body; + CallInliner::InlinedInstructionMap inlined_instructions_map; + TF_ASSIGN_OR_RETURN( + std::tie(new_while_body, inlined_instructions_map), + WidenWhileBody(while_instr->while_body(), new_while_shape)); + + HloInstruction* new_while_init = + TupleUtil::AppendSuffix(while_instr->mutable_operand(0), instructions); + HloComputation* containing_computation = while_instr->parent(); + HloInstruction* new_while = containing_computation->AddInstruction( + HloInstruction::CreateWhile(new_while_shape, new_while_condition, + new_while_body, new_while_init)); + TF_RETURN_IF_ERROR(containing_computation->ReplaceInstruction( + while_instr, TupleUtil::ExtractPrefix( + new_while, while_instr->shape().tuple_shapes_size()))); + + HloInstruction* while_body_param = new_while_body->parameter_instruction(0); + std::vector live_in_instructions; + for (int64 i = elements_in_old_while_shape; + i < new_while_shape.tuple_shapes_size(); i++) { + live_in_instructions.push_back( + new_while_body->AddInstruction(HloInstruction::CreateGetTupleElement( + instructions[i - elements_in_old_while_shape]->shape(), + while_body_param, i))); + } + + WhileUtil::MakeInstructionsLiveInResult result; + + result.new_while_instr = new_while; + result.while_body_live_in_values = std::move(live_in_instructions); + result.while_body_instruction_map = std::move(inlined_instructions_map); + + return std::move(result); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h new file mode 100644 index 00000000000..3600b5a80d2 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_util.h @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ + +#include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { +class WhileUtil { + public: + // Holds a return value from MakeInstructionsLiveIn. + struct MakeInstructionsLiveInResult { + // The new while operation that has the requested values live in. + HloInstruction* new_while_instr; + + // The i'th element of `while_body_live_in_values` is an instruction in the + // while body that holds the i'th *newly added* live in value at runtime. + std::vector while_body_live_in_values; + + // `while_body_instruction_map` maps instructions in the original while body + // to the corresponding instructions in the body for the newly created while + // operation. 
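+    //
+    // Callers such as WhileLoopInvariantCodeMotion use this map to locate, in
+    // the new while body, the instructions they want to replace with the new
+    // live-in values.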
+ CallInliner::InlinedInstructionMap while_body_instruction_map; + }; + + // Replaces `while_instr` with a new while instruction that is equivalent to + // `while_instr`, except that it has all of the HLO instructions in + // `instructions` as live-in, loop invariant values. These new live in values + // are represented as new elements appended to the parameter of the while + // loop, which must be of tuple shape. GetTupleElement instructions computing + // each new live in value is returned in the `while_body_live_in_values` + // vector. + // + // Precondition: `while_instr` must have a tuple shaped state. + // + // Every instruction in `instructions` must be contained in the computation + // that contains `while_instr`. + static StatusOr MakeInstructionsLiveIn( + HloInstruction* while_instr, + tensorflow::gtl::ArraySlice instructions); +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc new file mode 100644 index 00000000000..cf0d0db99bd --- /dev/null +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -0,0 +1,130 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_util.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace { + +namespace op = ::xla::testing::opcode_matchers; + +StatusOr> GetParsedModule( + HloComputation** entry_computation, HloInstruction** param0, + HloInstruction** param1, HloInstruction** param2) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +while_body { + ROOT p_body = (f32[32,32]{1,0}, f32[32,32]{1,0}) parameter(0) +} + +while_condition { + p_cond = f32[32,32]{1,0} parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + p_entry_0 = f32[32,32]{1,0} parameter(0) + p_entry_1 = s32[32,32]{1,0} parameter(1) + p_entry_2 = s64[32,32]{1,0} parameter(2) + while_init = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(p_entry_0, p_entry_0) + ROOT while = (f32[32,32]{1,0}, f32[32,32]{1,0}) while(while_init), condition=while_condition, body=while_body +} +)"; + + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + tools::Parse(hlo_string)); + + *entry_computation = module->entry_computation(); + *param0 = (*entry_computation)->parameter_instruction(0); + *param1 = (*entry_computation)->parameter_instruction(1); + *param2 = (*entry_computation)->parameter_instruction(2); + + return std::move(module); +} + +TEST(WhileUtil, MakeZeroInstructionsLiveOp) { + HloInstruction *param0, *param1, *param2; + HloComputation* entry_computation; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + GetParsedModule(&entry_computation, ¶m0, ¶m1, ¶m2)); + + HloInstruction* while_instr = entry_computation->root_instruction(); + ASSERT_EQ(while_instr->opcode(), HloOpcode::kWhile); + + TF_ASSERT_OK_AND_ASSIGN( + WhileUtil::MakeInstructionsLiveInResult make_live_in_result, + WhileUtil::MakeInstructionsLiveIn(while_instr, /*instructions=*/{})); + + HloInstruction* new_while_instr = make_live_in_result.new_while_instr; + + EXPECT_THAT( + entry_computation->root_instruction(), + op::Tuple(op::GetTupleElement(::testing::Eq(new_while_instr), 0), + op::GetTupleElement(::testing::Eq(new_while_instr), 1))); + + auto param_reconstructed = + op::Tuple(op::GetTupleElement(op::Parameter(0), 0), + op::GetTupleElement(op::Parameter(0), 1)); + + EXPECT_THAT(new_while_instr->while_body()->root_instruction(), + op::Tuple(op::GetTupleElement(param_reconstructed, 0), + op::GetTupleElement(param_reconstructed, 1))); +} + +TEST(WhileUtilTest, MakeTwoInstructionsLive) { + HloInstruction *param0, *param1, *param2; + HloComputation* entry_computation; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + GetParsedModule(&entry_computation, ¶m0, ¶m1, ¶m2)); + + HloInstruction* while_instr = entry_computation->root_instruction(); + ASSERT_EQ(while_instr->opcode(), HloOpcode::kWhile); + + TF_ASSERT_OK_AND_ASSIGN( + WhileUtil::MakeInstructionsLiveInResult make_live_in_result, + WhileUtil::MakeInstructionsLiveIn(while_instr, + /*instructions=*/{param0, param1})); + + HloInstruction* new_while_instr = make_live_in_result.new_while_instr; + + XLA_VLOG_LINES(3, module->ToString()); + + EXPECT_THAT( + entry_computation->root_instruction(), + op::Tuple(op::GetTupleElement(::testing::Eq(new_while_instr), 0), + op::GetTupleElement(::testing::Eq(new_while_instr), 1))); + + auto first_half_param_reconstructed = + op::Tuple(op::GetTupleElement(op::Parameter(0), 0), + 
op::GetTupleElement(op::Parameter(0), 1)); + + EXPECT_THAT(new_while_instr->while_body()->root_instruction(), + op::Tuple(op::GetTupleElement(first_half_param_reconstructed, 0), + op::GetTupleElement(first_half_param_reconstructed, 1), + op::GetTupleElement(op::Parameter(0), 2), + op::GetTupleElement(op::Parameter(0), 3))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 3d4080e353e..290ea9b496a 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -84,7 +84,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { if (lhs.layout().format() != rhs.layout().format()) { return false; } - if (LayoutUtil::IsDense(lhs)) { + if (LayoutUtil::IsDenseArray(lhs)) { if (!ContainersEqual(LayoutUtil::MinorToMajor(lhs), LayoutUtil::MinorToMajor(rhs))) { VLOG(3) << "CompareShapes: lhs layout != rhs layout"; @@ -202,6 +202,17 @@ StatusOr MakeShapeWithLayoutInternal( return MakeShapeWithLayout(element_type, dimensions, layout); } +/* static */ Shape ShapeUtil::MakeShapeWithSparseLayout( + PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, + int64 max_sparse_elements) { + DCHECK_NE(TUPLE, element_type); + DCHECK_NE(OPAQUE, element_type); + Shape shape = ShapeUtil::MakeShape(element_type, dimensions); + *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); + TF_DCHECK_OK(ShapeUtil::ValidateShape(shape)); + return shape; +} + /* static */ Shape ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( const Shape& shape) { @@ -249,7 +260,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) { - CHECK(LayoutUtil::IsDense(*shape)); + CHECK(LayoutUtil::IsDenseArray(*shape)); shape->mutable_layout()->add_minor_to_major(Rank(*shape)); shape->add_dimensions(bound); TF_DCHECK_OK(ValidateShape(*shape)); @@ -658,23 +669,55 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { TF_DCHECK_OK(ValidateShape(shape)); DCHECK_NE(OPAQUE, shape.element_type()); if (shape.element_type() == TUPLE) { - CHECK_GT(pointer_size, 0); - return pointer_size * shape.tuple_shapes_size(); + return ByteSizeOfTupleIndexTable(shape, pointer_size); } + int64 byte_size = ByteSizeOfElements(shape); + if (LayoutUtil::IsSparseArray(shape)) { + byte_size += ByteSizeOfSparseIndices(shape); + } + return byte_size; +} + +/* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape, + int64 pointer_size) { + TF_DCHECK_OK(ValidateShape(shape)); + DCHECK_EQ(TUPLE, shape.element_type()); + CHECK_GT(pointer_size, 0); + return pointer_size * shape.tuple_shapes_size(); +} + +/* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) { + TF_DCHECK_OK(ValidateShape(shape)); + DCHECK(ShapeUtil::IsArray(shape)); int64 allocated_element_count; - if (shape.layout().padded_dimensions_size() > 0) { - CHECK_EQ(Rank(shape), shape.layout().padded_dimensions_size()); - allocated_element_count = 1; - for (int64 dimension_size : shape.layout().padded_dimensions()) { - allocated_element_count *= dimension_size; - } + + if (LayoutUtil::IsSparseArray(shape)) { + allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout()); } else { - allocated_element_count = ElementsIn(shape); + CHECK(LayoutUtil::IsDenseArray(shape)); + tensorflow::gtl::ArraySlice padded_dimensions = + LayoutUtil::PaddedDimensions(shape); + if 
(!padded_dimensions.empty()) { + CHECK_EQ(Rank(shape), padded_dimensions.size()); + allocated_element_count = 1; + for (int64 dimension_size : padded_dimensions) { + allocated_element_count *= dimension_size; + } + } else { + allocated_element_count = ElementsIn(shape); + } } return allocated_element_count * ByteSizeOfPrimitiveType(shape.element_type()); } +/* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) { + TF_DCHECK_OK(ValidateShape(shape)); + DCHECK(LayoutUtil::IsSparseArray(shape)); + return LayoutUtil::MaxSparseElements(shape.layout()) * + ShapeUtil::Rank(shape) * sizeof(int64); +} + /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( const Shape& shape) { if (shape.element_type() == TUPLE) { @@ -900,7 +943,7 @@ Status ForEachMutableSubshapeHelper( new_shape.add_dimensions(dim); } if (shape.has_layout()) { - CHECK(LayoutUtil::IsDense(shape)); + CHECK(LayoutUtil::IsDenseArray(shape)); Layout* new_layout = new_shape.mutable_layout(); new_layout->set_format(DENSE); new_layout->clear_minor_to_major(); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 59bdffee5a8..453d4ec0472 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -143,7 +143,10 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index); class ShapeUtil { public: // Returns the number of elements are contained within the provided shape; - // e.g. for rank 0 (scalars) the result is always 1. + // e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes + // may not actually be able to store this number of elements. See + // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of + // elements that can be stored in a sparse shape. // Precondition: !IsTuple(shape) static int64 ElementsIn(const Shape& shape); @@ -164,6 +167,27 @@ class ShapeUtil { // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape) static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type); + // Returns the number of bytes required to store the tuple member pointers for + // a allocation of shape. The `shape` must be a TUPLE shape, and + // `pointer_size` must be larger than zero. + static int64 ByteSizeOfTupleIndexTable(const Shape& shape, + int64 pointer_size); + + // Returns the number of bytes required for the elements in an allocation of + // `shape`, which must be an array shape. The return value does not include + // the bytes needed to store sparse indices. Dense shapes use a separate + // memory location for each element, and so for these shapes, + // `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. For dense shapes, this + // size also includes padding if present in the layout. For sparse shapes, + // `ByteSizeOf(shape) == ByteSizeOfElements(shape) + + // ByteSizeOfSparseindices(shape)`. + static int64 ByteSizeOfElements(const Shape& shape); + + // Returns the number of bytes required for the sparse indices in an + // allocation of shape. The shape must be an array shape. The return value + // does not include the bytes needed to store sparse indices. + static int64 ByteSizeOfSparseIndices(const Shape& shape); + // Returns a human-readable string that represents the given shape, with or // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]". 
static string HumanString(const Shape& shape); @@ -269,6 +293,10 @@ class ShapeUtil { PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice minor_to_major); + static Shape MakeShapeWithSparseLayout( + PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, + int64 max_sparse_elements); + // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}). static Shape MakeShapeWithDescendingLayout( PrimitiveType element_type, diff --git a/tensorflow/compiler/xla/sparse_index_array.cc b/tensorflow/compiler/xla/sparse_index_array.cc new file mode 100644 index 00000000000..e7738e67903 --- /dev/null +++ b/tensorflow/compiler/xla/sparse_index_array.cc @@ -0,0 +1,110 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/sparse_index_array.h" + +#include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +SparseIndexArray::SparseIndexArray() : rank_(0), max_indices_(0) {} + +SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank, + std::vector indices) + : indices_(std::move(indices)), rank_(rank), max_indices_(max_indices) { + CHECK_GT(rank_, 0); + CHECK_EQ(indices_.size() % rank_, 0) + << "indices_.size(): " << indices_.size() << ", rank_: " << rank_; + CHECK_LT(index_count(), max_indices_); +} + +SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank, + tensorflow::gtl::ArraySlice indices) + : SparseIndexArray(max_indices, rank, + std::vector(indices.begin(), indices.end())) {} + +SparseIndexArray::SparseIndexArray(int64 max_indices, + const Array2D& indices) + : SparseIndexArray(max_indices, indices.n2(), + std::vector(indices.begin(), indices.end())) {} + +int64 SparseIndexArray::index_count() const { + CHECK_GT(rank_, 0); + CHECK_EQ(indices_.size() % rank_, 0); + return indices_.size() / rank_; +} + +tensorflow::gtl::ArraySlice SparseIndexArray::At( + int64 sparse_index_number) const { + CHECK_GT(rank_, 0); + CHECK_GE(sparse_index_number, 0); + CHECK_LE(rank_ * sparse_index_number + rank_, indices_.size()); + return tensorflow::gtl::ArraySlice( + indices_.data() + rank_ * sparse_index_number, rank_); +} + +tensorflow::gtl::MutableArraySlice SparseIndexArray::At( + int64 sparse_index_number) { + CHECK_GT(rank_, 0); + CHECK_GE(sparse_index_number, 0); + CHECK_LE(rank_ * sparse_index_number + rank_, indices_.size()); + return tensorflow::gtl::MutableArraySlice( + indices_.data() + rank_ * sparse_index_number, rank_); +} + +void SparseIndexArray::Append(tensorflow::gtl::ArraySlice index) { + CHECK_GT(rank_, 0); + CHECK_EQ(index.size(), rank_); + indices_.insert(indices_.end(), index.begin(), index.end()); +} + +void SparseIndexArray::Clear() { indices_.clear(); } + +void SparseIndexArray::Resize(int64 num_indices) { + CHECK_GT(rank_, 0); + 
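+  // Growing the vector value-initializes the new slots, but callers should
+  // treat the contents of newly added index slots as unspecified, per the
+  // contract in sparse_index_array.h.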
indices_.resize(rank_ * num_indices); +} + +bool SparseIndexArray::Validate(const Shape& shape) const { + if (rank_ == 0 || rank_ != ShapeUtil::Rank(shape)) { + return false; + } + int64 num_indices = index_count(); + if (num_indices > LayoutUtil::MaxSparseElements(shape.layout())) { + return false; + } + if (num_indices < 2) { + return true; + } + tensorflow::gtl::ArraySlice last = At(0); + if (!IndexUtil::IndexInBounds(shape, last)) { + return false; + } + for (int64 n = 1; n < num_indices; ++n) { + tensorflow::gtl::ArraySlice next = At(n); + if (!IndexUtil::IndexInBounds(shape, next)) { + return false; + } + if (IndexUtil::CompareIndices(last, next) >= 0) { + return false; + } + last = next; + } + return true; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/sparse_index_array.h b/tensorflow/compiler/xla/sparse_index_array.h new file mode 100644 index 00000000000..f67f34760e0 --- /dev/null +++ b/tensorflow/compiler/xla/sparse_index_array.h @@ -0,0 +1,176 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utility class for managing sparse array indices. + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ + +#include <vector> + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +// Encapsulates the array of indices for a sparse array. A SparseIndexArray +// contains indices for up to `max_indices` elements of a sparse array. Each +// sparse index is an array of `rank` int64 values that gives the location of a +// value within a sparse array. Note that the dimensions of the array are not +// checked (except for the rank). To avoid confusion, we refer to the position +// of an index within a SparseIndexArray as a sparse index number. +class SparseIndexArray { + public: + SparseIndexArray(); + SparseIndexArray(const SparseIndexArray&) = default; + SparseIndexArray(SparseIndexArray&&) = default; + SparseIndexArray& operator=(const SparseIndexArray&) = default; + SparseIndexArray& operator=(SparseIndexArray&&) = default; + + // Constructs a SparseIndexArray that can hold up to `max_indices` sparse + // indices, with initial contents obtained from the given array. The rank + // is taken from the minor dimension of the array. The major dimension of the + // array must not exceed `max_indices`. + SparseIndexArray(int64 max_indices, const Array2D& indices); + + // Like above, but the array is flattened.
For example, the following are + // equivalent: + // + // SparseIndexArray(10, 3, + // Array2D{ + // {0, 1, 2}, + // {3, 4, 5}, + // {6, 7, 8}, + // {9, 10, 11}, + // }) + // + // SparseIndexArray(10, 3, + // {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}) + // + SparseIndexArray(int64 max_indices, int64 rank, + std::vector indices = {}); + SparseIndexArray(int64 max_indices, int64 rank, + tensorflow::gtl::ArraySlice indices); + + // Returns the number of elements represented by the indices stored in the + // array. + int64 index_count() const; + + // Returns a slice that refers to the given sparse index number. The argument + // must be in the range [0, index_count()). + tensorflow::gtl::ArraySlice At(int64 sparse_index_number) const; + tensorflow::gtl::MutableArraySlice At(int64 sparse_index_number); + + // Adds the given index at the end of the array. The new size of the + // SparseIndexArray must not exceed `max_indices`. + void Append(tensorflow::gtl::ArraySlice index); + + // Removes all indices from the array. + void Clear(); + + // Resizes the array to contain the given number of sparse indices. The new + // size must be smaller than `max_indices`. If the new size is larger than + // the old size, the value of the new indices is not specified. + void Resize(int64 num_indices); + + // Returns true iff all indices are unique and occur in sorted order, and are + // valid for the given shape. + bool Validate(const Shape& shape) const; + + int64 rank() const { return rank_; } + int64 max_indices() const { return max_indices_; } + + // Returns a pointer to the int64 array that holds the sparse indices. + tensorflow::gtl::MutableArraySlice mutable_data() { return &indices_; } + tensorflow::gtl::ArraySlice data() const { return indices_; } + + // Sorts this sparse index array along with the set of corresponding values. + // The indices and values are sorted in the lexicographic order of the + // indices, from smallest to largest. + // + // For example: + // + // std::vector v{10.0, 11.0, 12.0}; + // SparseIndexArray a(10, 3, + // {{3, 4, 5}, + // {1, 2, 3}, + // {2, 3, 4}}); + // a.SortWithValues(&v); + // // Prints "11.0, 12.0, 10.0": + // std::cout << v[0] << ", " << v[1] << ", " << v[2] << std::endl; + // + template + void SortWithValues(tensorflow::gtl::MutableArraySlice values); + + private: + std::vector indices_; + int64 rank_; + int64 max_indices_; +}; + +template +void SparseIndexArray::SortWithValues( + tensorflow::gtl::MutableArraySlice values) { + int64 num_elements = index_count(); + CHECK_EQ(values.size(), num_elements); + std::vector sort_order; + sort_order.reserve(num_elements); + for (int64 i = 0; i < num_elements; ++i) { + sort_order.push_back(i); + } + auto sort_order_less = [this](int64 lhs, int64 rhs) { + return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0; + }; + std::sort(sort_order.begin(), sort_order.end(), sort_order_less); + + // Reorder the array elements according to sort_order. Work through the array + // and follow cycles so we can do the reorder in-place. + tensorflow::gtl::InlinedVector saved_index(rank()); + for (int64 i = 0; i < num_elements; ++i) { + // sort_order[i] == -1 indicates the element has already been copied. + if (sort_order[i] < 0) { + continue; + } else if (i == sort_order[i]) { + // The element is already in sorted order.
+ sort_order[i] = -1; + continue; + } + + std::copy_n(At(i).begin(), rank(), saved_index.begin()); + NativeT saved_value = values[i]; + int64 j = i; + for (;;) { + if (sort_order[j] == i) { + std::copy_n(saved_index.begin(), rank(), At(j).begin()); + values[j] = saved_value; + sort_order[j] = -1; + break; + } + + std::copy_n(At(sort_order[j]).begin(), rank(), At(j).begin()); + values[j] = values[sort_order[j]]; + + int64 k = sort_order[j]; + sort_order[j] = -1; + j = k; + } + } +} + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ diff --git a/tensorflow/compiler/xla/sparse_index_array_test.cc b/tensorflow/compiler/xla/sparse_index_array_test.cc new file mode 100644 index 00000000000..7377f88958d --- /dev/null +++ b/tensorflow/compiler/xla/sparse_index_array_test.cc @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/sparse_index_array.h" + +#include + +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace { + +TEST(SparseIndexArrayTest, Sort) { + SparseIndexArray a(10, 3); + a.Append({2, 3, 4}); + a.Append({3, 4, 5}); + a.Append({1, 2, 3}); + a.Append({5, 6, 7}); + a.Append({4, 5, 6}); + a.Append({6, 7, 8}); + std::vector values = { + 12.0, 13.0, 11.0, 15.0, 14.0, 16.0, + }; + a.SortWithValues(&values); + ASSERT_EQ(a.data(), std::vector({1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, + 6, 7, 6, 7, 8})); + ASSERT_EQ(values, std::vector({11.0, 12.0, 13.0, 14.0, 15.0, 16.0})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 0b3430ee1ee..7e7f6b14862 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -1160,6 +1160,50 @@ TEST_F(WhileTest, WhileWithCallInsideCondition) { ComputeAndCompareR0(&builder, 5, {}); } +TEST_F(WhileTest, WhileWithLoopInvariantOperation) { + auto matrix_shape = ShapeUtil::MakeShape(F32, {2, 2}); + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + auto while_shape = ShapeUtil::MakeTupleShape( + {scalar_s32, matrix_shape, matrix_shape, matrix_shape}); + + // Create a computation for the condition: repeat for 5 iterations. 
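The reorder loop in SortWithValues above applies the precomputed permutation in place by walking cycles and using -1 as a "done" sentinel. A standalone sketch of the same technique (hypothetical helper; plain std::vector in place of the slice types used above):

    #include <cstdint>
    #include <vector>

    // Gather-style in-place permutation: afterwards data[i] holds the value that
    // was previously at data[order[i]]. Consumed slots of `order` are set to -1,
    // mirroring the sentinel used by SortWithValues.
    void ApplyPermutationInPlace(std::vector<int64_t> order,
                                 std::vector<double>* data) {
      for (int64_t i = 0; i < static_cast<int64_t>(order.size()); ++i) {
        if (order[i] < 0 || order[i] == i) {
          order[i] = -1;  // already placed, or already in position
          continue;
        }
        double saved = (*data)[i];  // value displaced by the first copy below
        int64_t j = i;
        while (order[j] != i) {
          (*data)[j] = (*data)[order[j]];
          int64_t next = order[j];
          order[j] = -1;
          j = next;
        }
        (*data)[j] = saved;  // close the cycle
        order[j] = -1;
      }
    }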
+ Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto state = builder.Parameter(0, while_shape, "state"); + builder.Gt(builder.ConstantR0(5), builder.GetTupleElement(state, 0)); + TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); + } + + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto state = builder.Parameter(0, while_shape, "state"); + auto indvar = builder.GetTupleElement(state, 0); + auto input_0 = builder.GetTupleElement(state, 1); + auto input_1 = builder.GetTupleElement(state, 2); + auto output = builder.Tanh(builder.Dot(input_0, input_1)); + auto indvar_next = builder.Add(indvar, builder.ConstantR0(1)); + auto tuple_result = builder.Tuple({indvar_next, input_0, input_1, output}); + TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); + } + + ComputationBuilder builder(client_, TestName()); + auto matrix_input = builder.Parameter(0, matrix_shape, "matrix"); + auto init = builder.Tuple( + {builder.ConstantR0(0), matrix_input, matrix_input, matrix_input}); + auto while_instruction = builder.While(condition, body, init); + builder.GetTupleElement(while_instruction, 3); + + TF_ASSERT_OK_AND_ASSIGN(auto param_value, + client_->TransferToServer(*Literal::CreateR2( + {{1.0, 2.0}, {-1.0, -2.0}}))); + + ComputeAndCompareR2( + &builder, {{-0.76159416, -0.96402758}, {0.76159416, 0.96402758}}, + {param_value.get()}, ErrorSpec(4e-5)); +} + void BM_WhileLoop(int num_iters) { // Benchmark a simple kernel to measure while loop overheads. tensorflow::testing::StopTiming(); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 2fc369dc0e6..75bedfabe27 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -1515,7 +1515,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, return false; } } else { - return TokenError(StrCat("unsupported premitive type ", + return TokenError(StrCat("unsupported primitive type ", PrimitiveType_Name(shape.element_type()))); } break; @@ -1851,7 +1851,7 @@ bool HloParser::ParseWindow(Window* window) { if (field_name == "rhs_reversal") { return ParseDxD("rhs_reversal", &rhs_reversal); } - return Error(loc, StrCat("unexpected attribute name: ", field_name)); + return Error(attr_loc, StrCat("unexpected attribute name: ", field_name)); }(); if (!ok) { return false; diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 277cc5ec86f..bb2db2010c5 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -398,6 +398,31 @@ std::vector> CommonFactors( // Removes illegal characters from filenames. string SanitizeFileName(string file_name); +// Simple wrapper around std::all_of. +template +bool c_all_of(Container container, Predicate predicate) { + return std::all_of(std::begin(container), std::end(container), predicate); +} + +// Simple wrapper around std::transform. +template +OutputIterator c_transform(InputContainer input_container, + OutputIterator output_iterator, + UnaryOperation unary_op) { + return std::transform(std::begin(input_container), std::end(input_container), + output_iterator, unary_op); +} + +// Simple wrapper around std::copy_if. 
+template +OutputIterator c_copy_if(InputContainer input_container, + OutputIterator output_iterator, + UnaryPredicate predicate) { + return std::copy_if(std::begin(input_container), std::end(input_container), + output_iterator, predicate); +} + } // namespace xla #define XLA_LOG_LINES(SEV, STRING) \ diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index e34f138b6ed..3aea0217539 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -120,6 +120,9 @@ enum Format { // The default layout, with exactly one storage location per element (ignoring // padding). DENSE = 1; + // A sparsely encoded layout, providing only the index/value pairs of non-zero + // elements. + SPARSE = 2; } // A layout describes how the array is placed in (1D) memory space. This @@ -151,6 +154,11 @@ message Layout { // field must be unset unless the format is DENSE. PaddingValue padding_value = 3; + // The maximum number of elements that can be stored for SPARSE formats. This + // can be used to determine the maximum size in bytes of arrays stored in + // memory. This field must be unset unless the format is SPARSE. + int64 max_sparse_elements = 5; + // Important: if any field is added, be sure to modify ShapeUtil::Equal() // appropriately to account for the new field. } @@ -333,7 +341,8 @@ message LiteralProto { // The F16s and BF16s are encoded in little endian byte order bytes f16s = 11; bytes bf16s = 13; - // Next = 14 + repeated int64 sparse_indices = 14; + // Next = 15 } message WindowDimension { diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/batching/batch_scheduler.h index e18cf6c3505..aa8891ab4ef 100644 --- a/tensorflow/contrib/batching/batch_scheduler.h +++ b/tensorflow/contrib/batching/batch_scheduler.h @@ -210,6 +210,7 @@ std::unique_ptr Batch::RemoveTask() { return nullptr; } std::unique_ptr task = std::move(tasks_.back()); + size_ -= task->size(); tasks_.pop_back(); return task; } diff --git a/tensorflow/contrib/batching/batch_scheduler_test.cc b/tensorflow/contrib/batching/batch_scheduler_test.cc index f15d8cc8e57..b627fee972a 100644 --- a/tensorflow/contrib/batching/batch_scheduler_test.cc +++ b/tensorflow/contrib/batching/batch_scheduler_test.cc @@ -74,7 +74,9 @@ TEST(BatchTest, Basic) { EXPECT_EQ(task1->size(), batch.task(1).size()); EXPECT_EQ(7, batch.RemoveTask()->size()); + EXPECT_EQ(3, batch.size()); EXPECT_EQ(3, batch.RemoveTask()->size()); + EXPECT_EQ(0, batch.size()); EXPECT_TRUE(batch.empty()); } diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index e13c60c9a71..b1937c08f34 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -524,7 +524,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.read_coordination_events[expected_element].acquire() else: self.write_coordination_events[expected_element].set() - time.sleep(0.1) # Sleep to consistently "avoid" the race condition. + time.sleep(0.5) # Sleep to consistently "avoid" the race condition. 
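For reference, the c_all_of / c_transform / c_copy_if helpers added to xla/util.h above are thin range-style wrappers over the corresponding <algorithm> functions; a minimal usage sketch (hypothetical values):

    #include <cmath>
    #include <iterator>
    #include <vector>

    void ContainerHelpersExample() {
      std::vector<double> xs = {1.0, 4.0, 9.0};
      // True: every element satisfies the predicate.
      bool all_positive = xla::c_all_of(xs, [](double x) { return x > 0.0; });
      // roots == {1.0, 2.0, 3.0}.
      std::vector<double> roots;
      xla::c_transform(xs, std::back_inserter(roots),
                       [](double x) { return std::sqrt(x); });
      // big == {4.0, 9.0}.
      std::vector<double> big;
      xla::c_copy_if(xs, std::back_inserter(big),
                     [](double x) { return x > 2.0; });
      (void)all_positive;
    }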
actual_element = sess.run(self.next_element) if not done_first_event: done_first_event = True @@ -611,7 +611,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.read_coordination_events[expected_element].acquire() else: self.write_coordination_events[expected_element].set() - time.sleep(0.1) # Sleep to consistently "avoid" the race condition. + time.sleep(0.5) # Sleep to consistently "avoid" the race condition. actual_element = sess.run(self.next_element) if not done_first_event: done_first_event = True diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD index 87c80740a8f..f6de5998d73 100644 --- a/tensorflow/contrib/decision_trees/proto/BUILD +++ b/tensorflow/contrib/decision_trees/proto/BUILD @@ -7,7 +7,11 @@ exports_files([ "generic_tree_model_proto.swig", ]) -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library", + "tf_pyclif_proto_library", +) filegroup( name = "all_files", @@ -34,3 +38,10 @@ tf_proto_library( protodeps = [":generic_tree_model"], visibility = ["//visibility:public"], ) + +tf_pyclif_proto_library( + name = "generic_tree_model_pyclif", + proto_lib = ":generic_tree_model", + proto_srcfile = "generic_tree_model.proto", + visibility = ["//visibility:public"], +) diff --git a/tensorflow/contrib/nn/BUILD b/tensorflow/contrib/nn/BUILD index 56a24ac77f0..5543eb6c6e3 100644 --- a/tensorflow/contrib/nn/BUILD +++ b/tensorflow/contrib/nn/BUILD @@ -17,6 +17,7 @@ py_library( "python/ops/__init__.py", "python/ops/alpha_dropout.py", "python/ops/cross_entropy.py", + "python/ops/fwd_gradients.py", "python/ops/sampling_ops.py", "python/ops/scaled_softplus.py", ], @@ -28,6 +29,7 @@ py_library( "//tensorflow/python:embedding_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:function", + "//tensorflow/python:gradients", "//tensorflow/python:math_ops", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", @@ -55,6 +57,19 @@ py_test( ], ) +py_test( + name = "fwd_gradients_test", + size = "small", + srcs = ["python/ops/fwd_gradients_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":nn_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:math_ops", + ], +) + py_test( name = "sampling_ops_test", size = "small", diff --git a/tensorflow/contrib/nn/python/ops/fwd_gradients.py b/tensorflow/contrib/nn/python/ops/fwd_gradients.py new file mode 100644 index 00000000000..922497779b1 --- /dev/null +++ b/tensorflow/contrib/nn/python/ops/fwd_gradients.py @@ -0,0 +1,76 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Forward-mode derivatives.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops.gradients_impl import gradients + + +def fwd_gradients(ys, xs, grad_xs=None, assert_unused=False): + """Computes forward-mode derivatives. + + This is accomplished in pure Python using TensorFlow's existing (reverse-mode) + gradients. There is additional overhead on graph construction, but runtime + performance should be equal to a manual implementation [citation needed]. + + See https://j-towns.github.io/2017/06/12/A-new-trick.html and + https://github.com/HIPS/autograd/pull/175 for the original discussion of this + method, and https://github.com/renmengye/tensorflow-forward-ad for a "direct" + implementation. + + Args: + ys: A list of tensors. + xs: A list of tensors. + grad_xs: An optional list of tensors. If provided, must have the same length + as xs and shapes compatible with xs. + assert_unused: Add assertions that intermediate values are not computed. + Returns: + A list of tensors of the same shapes as ys. The directional derivatives of + ys with respect to xs in the direction grad_xs. Leaving grad_xs unspecified + is equivalent to passing in 1s for each x in xs. + """ + # This version of forward-mode autodiff is based on code by Tim Cooijmans + # and handles list arguments and certain special cases such as when the + # ys don't depend on one or more of the xs, and when tf.IndexedSlices are + # generated by the first tf.gradients call. + + us = [array_ops.zeros_like(y) + float('nan') for y in ys] + + dydxs = gradients(ys, xs, grad_ys=us) + + # deal with strange types that tf.gradients returns but can't accept as input + dydxs = [ops.convert_to_tensor(dydx) if isinstance(dydx, ops.IndexedSlices) + else dydx for dydx in dydxs] + + if assert_unused: + with ops.control_dependencies(dydxs): + assert_unused = control_flow_ops.Assert(False, [1], name='fwd_gradients') + with ops.control_dependencies([assert_unused]): + dydxs = array_ops.identity_n(dydxs) + + dydxs = [array_ops.zeros_like(x) if dydx is None else dydx + for x, dydx in zip(xs, dydxs)] + for x, dydx in zip(xs, dydxs): + dydx.set_shape(x.shape) + + dysdx = gradients(dydxs, us, grad_ys=grad_xs) + + return dysdx diff --git a/tensorflow/contrib/nn/python/ops/fwd_gradients_test.py b/tensorflow/contrib/nn/python/ops/fwd_gradients_test.py new file mode 100644 index 00000000000..56062c3cab3 --- /dev/null +++ b/tensorflow/contrib/nn/python/ops/fwd_gradients_test.py @@ -0,0 +1,52 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== +"""Tests for forward_ad.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.nn.python.ops import fwd_gradients +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class ForwardAdTest(test.TestCase): + + def testSquare(self): + x = constant_op.constant(1.) + y = math_ops.square(x) + grad_x = 3. + + dydx_tf = fwd_gradients.fwd_gradients([y], [x], [grad_x])[0] + dydx_py = 2. * grad_x + + with self.test_session() as sess: + self.assertAllClose(sess.run(dydx_tf), dydx_py, 1e-6) + + def testGather(self): + x = constant_op.constant([1., 2., 3.]) + y = array_ops.gather(x, [0, 1]) + y.set_shape([2]) + dydx = fwd_gradients.fwd_gradients([y], [x], assert_unused=True) + + with self.test_session() as sess: + sess.run(dydx) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 0398c2a60d1..5235e520568 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -226,6 +226,11 @@ bool IsConstantFoldable( if (consider && !consider(n)) { return false; } + // PlaceholderWithDefault shouldn't be constant folded because its output can + // be fed non-constant values. + if (n->type_string() == "PlaceholderWithDefault") { + return false; + } if (n->IsControlFlow() || n->IsSend() || n->IsRecv()) { return false; } diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc index 923a4d92493..31f41e133b9 100644 --- a/tensorflow/core/common_runtime/constant_folding_test.cc +++ b/tensorflow/core/common_runtime/constant_folding_test.cc @@ -338,6 +338,40 @@ TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) { EXPECT_FALSE(was_mutated); } +TEST_F(ConstantFoldingTest, Placeholders) { + Graph g(OpRegistry::Global()); + { + Scope s = Scope::NewRootScope(); + auto placeholder = ops::Placeholder(s, DT_DOUBLE); + auto add = ops::Add(s, placeholder, 2.0); + auto send = + ops::_Send(s.WithOpName("send"), add, "add", "sender", 0, "receiver"); + TF_ASSERT_OK(s.ToGraph(&g)); + } + bool was_mutated; + Status s = ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(), + nullptr, &g, &was_mutated); + EXPECT_FALSE(was_mutated); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE(s.error_message().find( + "You must feed a value for placeholder " + "tensor 'Placeholder' with dtype double") != string::npos); + + Graph g2(OpRegistry::Global()); + { + Scope s = Scope::NewRootScope(); + auto placeholder = ops::PlaceholderWithDefault(s, {1.0}, {1}); + auto add = ops::Add(s, placeholder, 2.0); + auto send = + ops::_Send(s.WithOpName("send"), add, "add", "sender", 0, "receiver"); + TF_ASSERT_OK(s.ToGraph(&g2)); + } + // TODO(skyewm): should this have the same behavior as Placeholder? 
+ TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(), + nullptr, &g2, &was_mutated)); + EXPECT_FALSE(was_mutated); +} + TEST_F(ConstantFoldingTest, ControlDependencies) { Graph g(OpRegistry::Global()); { diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index 3ae52f414fa..45cdab98e06 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -558,6 +558,13 @@ Status ShapeRefiner::ExtractConstantSubgraph( return Status::OK(); } + if (target_node->type_string() == "PlaceholderWithDefault") { + return Status::OK(); + } + + // TODO(skyewm): more of the filtering applied in input nodes below should be + // applied to target_node here + struct NodeAndRecursed { Node* new_node = nullptr; bool recursed = false; @@ -608,6 +615,14 @@ Status ShapeRefiner::ExtractConstantSubgraph( return Status::OK(); } + // Placeholders should never be constant folded because their outputs are + // fed by the user. Note that "Placeholder" nodes have no inputs so are + // handled below. + if (current_node->type_string() == "PlaceholderWithDefault") { + *is_constant_graph = false; + return Status::OK(); + } + // If there is nothing more to recurse down, see if // the generator node is a constant. if (current_node->num_inputs() == 0) { diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc index e4eef1dbe28..adf5a9afff2 100644 --- a/tensorflow/core/common_runtime/shape_refiner_test.cc +++ b/tensorflow/core/common_runtime/shape_refiner_test.cc @@ -724,6 +724,25 @@ TEST_F(ShapeRefinerTest, PropagateRange) { EXPECT_EQ("[1,4,7,10]", ctx->DebugString(ctx->output(0))); } +// Make sure PlaceholderWithDefaults aren't treated as constants. +TEST_F(ShapeRefinerTest, NoPropagatePlaceholderWithDefault) { + Scope root = Scope::NewRootScope(); + auto constant = ops::Const(root, 2); + auto placeholder = + ops::PlaceholderWithDefault(root, constant, PartialTensorShape()); + Node* shape_data; + TF_ASSERT_OK(NodeBuilder("Test", "ShapeData") + .Input(placeholder.node()) + .Finalize(root.graph(), &shape_data)); + + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); + TF_ASSERT_OK(m.AddNode(constant.node())); + TF_ASSERT_OK(m.AddNode(placeholder.node())); + TF_ASSERT_OK(m.AddNode(shape_data)); + shape_inference::InferenceContext* ic = m.GetContext(shape_data); + EXPECT_EQ(ic->DebugString(ic->output(0)), "?"); +} + TEST_F(ShapeRefinerTest, ConstantValueTwoInputsToSameNode) { Scope root = Scope::NewRootScope(); // This node is used as two inputs to 'range'. diff --git a/tensorflow/core/grappler/costs/graph_memory.cc b/tensorflow/core/grappler/costs/graph_memory.cc index 3168758c8bd..3604de392f8 100644 --- a/tensorflow/core/grappler/costs/graph_memory.cc +++ b/tensorflow/core/grappler/costs/graph_memory.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/core/grappler/costs/graph_memory.h" #include #include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/framework/tensor_description.pb.h" @@ -163,6 +164,8 @@ void GraphMemory::InferFromTrace(const StepStats& timeline) { NodeMap node_map(&item_.graph); for (const auto& dev_stats : timeline.dev_stats()) { + const string& device_name = dev_stats.device(); + const bool is_gpu = (device_name.find("GPU:") || device_name.find("gpu:")); std::list& device_tensors = live_tensors_per_device[dev_stats.device()]; for (const auto& node_stats : dev_stats.node_stats()) { @@ -194,7 +197,24 @@ void GraphMemory::InferFromTrace(const StepStats& timeline) { // graph (e.g _Send/_Recv nodes). continue; } - for (const string& input : node->input()) { + std::unordered_set swapped_inputs; + if (is_gpu) { + auto it = node->attr().find("_swap_to_host"); + if (it != node->attr().end()) { + const AttrValue& val = it->second; + for (int port_id : val.list().i()) { + swapped_inputs.insert(port_id); + } + } + } + for (int i = 0; i < node->input_size(); ++i) { + if (swapped_inputs.find(i) != swapped_inputs.end()) { + // The memory of swapped inputs will be released as early as possible: + // therefore ignore this input when determining the deallocation time + // of the tensor. + continue; + } + const string& input = node->input(i); int position; string input_node = ParseNodeName(input, &position); if (position < 0) { diff --git a/tensorflow/core/grappler/costs/graph_memory_test.cc b/tensorflow/core/grappler/costs/graph_memory_test.cc index 6f3522b068b..95170ba49b7 100644 --- a/tensorflow/core/grappler/costs/graph_memory_test.cc +++ b/tensorflow/core/grappler/costs/graph_memory_test.cc @@ -134,6 +134,62 @@ TEST_F(GraphMemoryTest, MultiDevice) { EXPECT_EQ(gpu_expected, gpu_tensors); } +TEST_F(GraphMemoryTest, GpuSwapping) { + TrivialTestGraphInputYielder fake_input(4, 2, 1024 * 1024, false, {"/GPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + item.feed.clear(); + + { + // Estimate the max memory usage for the graph. + GraphMemory memory(item); + Status s = memory.InferStatically(devices_); + TF_CHECK_OK(s); + + const GraphMemory::MemoryUsage& gpu_mem = + memory.GetPeakMemoryUsage("/GPU:0"); + EXPECT_EQ(20971520, gpu_mem.used_memory); + std::set gpu_tensors; + for (const auto& t : gpu_mem.live_tensors) { + gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id)); + } + std::set gpu_expected; + gpu_expected.insert("Square:0"); + gpu_expected.insert("Square_1:0"); + gpu_expected.insert("AddN:0"); + gpu_expected.insert("AddN_1:0"); + gpu_expected.insert("AddN_2:0"); + EXPECT_EQ(gpu_expected, gpu_tensors); + } + + { + // Swap the first input to node AddN_1: its fanin (the square nodes) should + // not appear in the max cut anymore. 
+ for (auto& node : *item.graph.mutable_node()) { + if (node.name() == "AddN_1") { + (*node.mutable_attr())["_swap_to_host"].mutable_list()->add_i(0); + } + } + GraphMemory memory(item); + Status s = memory.InferStatically(devices_); + TF_CHECK_OK(s); + const GraphMemory::MemoryUsage& new_gpu_mem = + memory.GetPeakMemoryUsage("/GPU:0"); + EXPECT_EQ(20971520, new_gpu_mem.used_memory); + std::set new_gpu_tensors; + for (const auto& t : new_gpu_mem.live_tensors) { + new_gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id)); + } + std::set new_gpu_expected; + new_gpu_expected.insert("AddN:0"); + new_gpu_expected.insert("AddN_1:0"); + new_gpu_expected.insert("AddN_2:0"); + new_gpu_expected.insert("AddN_3:0"); + new_gpu_expected.insert("AddN_4:0"); + EXPECT_EQ(new_gpu_expected, new_gpu_tensors); + } +} + TEST_F(GraphMemoryTest, CtrlDependencies) { // Build a simple graph with a control dependency. Scope s = Scope::NewRootScope(); diff --git a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc index 6d25556770d..ec54bd5c759 100644 --- a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc +++ b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc @@ -31,8 +31,6 @@ namespace { GraphDef CreateGraphDef(int num_stages, int width, int tensor_size, bool use_multiple_devices, bool insert_queue, const std::vector& device_names) { - CHECK_GE(device_names.size(), width); - using namespace ::tensorflow::ops; // NOLINT(build/namespaces) tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -49,13 +47,17 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size, std::vector this_stage; for (int j = 0; j < width; j++) { if (last_stage.size() == 1) { - Output unary_op = - Square(s.WithDevice(device_names[use_multiple_devices ? j : 0]), - last_stage[0]); + Output unary_op = Square( + s.WithDevice( + device_names[use_multiple_devices ? j % device_names.size() : 0]), + last_stage[0]); this_stage.push_back(unary_op); } else { Output combine = - AddN(s.WithDevice(device_names[use_multiple_devices ? j : 0]), + AddN(s.WithDevice( + device_names[use_multiple_devices ? j % device_names.size() : 0]), last_stage); this_stage.push_back(combine); } diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 9f24f1c7683..68feedbcbb0 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -433,13 +433,42 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs( id = --min_id; } } + + // Beware: the reduction dimensions computed by the BCast class are valid iff + // we assume that two distinct symbolic dimensions can't be equal and a + // symbolic dimension can't be equal to 1. This is often but not always true, + // so to make this optimization safe we filter out these cases. + const int common_dims = std::min(shape1.size(), shape2.size()); + for (int i = 0; i < common_dims; ++i) { + if (shape1[i] >= 0 && shape2[i] >= 0) { + continue; + } + if (shape1[i] != shape2[i]) { + // We're either dealing with 2 different symbolic dimensions or a symbolic + // and a known dimension. We can't be sure whether both are equal or not, + // so we can't be sure whether we'll be broadcasting or not. + return Status::OK(); + } + } + // These extra dims could be equal to 1, in which case there is no + // broadcasting.
It could also be greater than 1, in which case there would + // be broadcasting. Since we don't know, we'll just punt. + for (int i = common_dims; i < shape1.size(); ++i) { + if (shape1[i] < 0) { + return Status::OK(); + } + } + for (int i = common_dims; i < shape2.size(); ++i) { + if (shape2[i] < 0) { + return Status::OK(); + } + } + BCast bcast(shape1, shape2); if (!bcast.IsValid()) { return Status::OK(); } - // Beware: the reduction dimensions are valid iff we assume that two distinct - // symbolic dimensions can't be equal. This is often but not always true, so - // this optimization isn't safe. + BCast::Vec reduce_dims[2]; reduce_dims[0] = bcast.grad_x_reduce_idx(); reduce_dims[1] = bcast.grad_y_reduce_idx(); @@ -447,26 +476,27 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs( const DataType type = node.attr().at("T").type(); NodeDef* out[2]; for (int j = 0; j < 2; ++j) { - if (!reduce_dims[j].empty()) { - // This is the case when a tensor dimension of 1 is matched against an - // unknown dimension. The unknown dimension could also be equal to 1, in - // which case there would be no reduction. - out[j] = nullptr; - } else { - string const_name = OptimizedNodeName(node, strings::StrCat("-", j)); - out[j] = node_map_->GetNode(const_name); - if (out[j] == nullptr) { - out[j] = graph_->add_node(); - Tensor value(type, TensorShape({0})); - *out[j] = CreateNodeDef(const_name, TensorValue(&value)); - out[j]->set_device(node.device()); - node_map_->AddNode(const_name, out[j]); - string ctrl_dep = - AddControlDependency(node.name(), graph_, node_map_.get()); - *out[j]->add_input() = ctrl_dep; - node_map_->AddOutput(NodeName(ctrl_dep), const_name); + int reduction_indices = reduce_dims[j].size(); + Tensor value(type, TensorShape({reduction_indices})); + for (int i = 0; i < reduction_indices; ++i) { + if (type == DT_INT32) { + value.vec()(i) = reduce_dims[j][i]; + } else { + value.vec()(i) = reduce_dims[j][i]; } } + string const_name = OptimizedNodeName(node, strings::StrCat("-", j)); + out[j] = node_map_->GetNode(const_name); + if (out[j] == nullptr) { + out[j] = graph_->add_node(); + *out[j] = CreateNodeDef(const_name, TensorValue(&value)); + out[j]->set_device(node.device()); + node_map_->AddNode(const_name, out[j]); + string ctrl_dep = + AddControlDependency(node.name(), graph_, node_map_.get()); + *out[j]->add_input() = ctrl_dep; + node_map_->AddOutput(NodeName(ctrl_dep), const_name); + } } const std::set outputs = node_map_->GetOutputs(node.name()); @@ -584,12 +614,11 @@ Status ConstantFolding::MaterializeReductionIndices( Status ConstantFolding::MaterializeConstants( const GraphProperties& properties) { - const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE; const int node_count = graph_->node_size(); for (int i = 0; i < node_count; ++i) { NodeDef& node = *graph_->mutable_node(i); const string& op = node.op(); - if (is_aggressive && op == "BroadcastGradientArgs") { + if (op == "BroadcastGradientArgs") { TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties)); } else if (IsReduction(node)) { TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties)); diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index a3b3e522eb8..c53678f727f 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -1373,21 +1373,14 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) { } 
else if (node.name() == "p1") { ++found; EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("ConstantFolding/i-0", node.input(0)); + EXPECT_EQ("i", node.input(0)); } else if (node.name() == "p2") { ++found; EXPECT_EQ(1, node.input_size()); EXPECT_EQ("i:1", node.input(0)); - } else if (node.name() == "ConstantFolding/i-0") { - ++found; - EXPECT_EQ("Const", node.op()); - EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("^i", node.input(0)); - EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape()) - .num_elements()); } } - EXPECT_EQ(7, found); + EXPECT_EQ(6, found); } TEST_F(ConstantFoldingTest, MaterializeReductionIndices) { diff --git a/tensorflow/core/util/equal_graph_def.cc b/tensorflow/core/util/equal_graph_def.cc index a3b7db98cc0..f1ec497a677 100644 --- a/tensorflow/core/util/equal_graph_def.cc +++ b/tensorflow/core/util/equal_graph_def.cc @@ -148,7 +148,10 @@ bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff, first_control_input = i; break; } - if (actual.input(i) != expected.input(i)) { + // Special case for inputs: "tensor" is equivalent to "tensor:0" + if (actual.input(i) != expected.input(i) && + actual.input(i) != strings::StrCat(expected.input(i), ":0") && + strings::StrCat(actual.input(i), ":0") != expected.input(i)) { if (diff != nullptr) { *diff = strings::StrCat("Node named '", actual.name(), "' has input ", i, " '", actual.input(i), diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index e495857afe5..42c4f81b82e 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -5922,6 +5922,7 @@ func RandomDataset(scope *Scope, seed tf.Output, seed2 tf.Output, output_types [ // // This op is hidden from public in Python. It is used by TensorFlow Debugger to // register gradient tensors for gradient debugging. +// This op operates on non-reference-type tensors. func DebugGradientIdentity(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 048dc92d049..ec31bc9bb71 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -344,7 +344,7 @@ def implicit_val_and_grad(f): def grad_fn(*args): """Computes the gradient of the wrapped function.""" - tape.push_new_tape() + this_tape = tape.push_new_tape() try: end_node = f(*args) if end_node is None: @@ -352,10 +352,10 @@ def implicit_val_and_grad(f): "did you forget to return a value from {}?".format( f.__name__)) finally: - popped_tape = tape.pop_tape() + tape.pop_tape(this_tape) # Sorting variables by id, which is monotonically increasing in construction # order. This ensures unique order across executions. - variables = list(sorted(popped_tape.watched_variables(), + variables = list(sorted(this_tape.watched_variables(), key=lambda v: v.handle._id)) # pylint: disable=protected-access sources = [x.handle for x in variables] @@ -363,7 +363,7 @@ def implicit_val_and_grad(f): raise ValueError("No trainable variables were accessed while the " "function was being computed.") grad = imperative_grad.imperative_grad(_default_vspace, - popped_tape, + this_tape, nest.flatten(end_node), sources) return end_node, list(zip(grad, variables)) @@ -652,7 +652,7 @@ def make_vjp(f, params=None): """Computes the value and gradient of the decorated function.""" parameter_positions = _get_arg_spec(f, params, args) assert not kwds, "The gradient function can't take keyword arguments." 
- tape.push_new_tape() + this_tape = tape.push_new_tape() try: sources = [] args = [ @@ -673,12 +673,12 @@ def make_vjp(f, params=None): flat_result = [gen_array_ops.identity(x) for x in flat_result] result = nest.pack_sequence_as(result, flat_result) finally: - t = tape.pop_tape() + tape.pop_tape(this_tape) def vjp(dy=None): if dy is not None: dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)] return imperative_grad.imperative_grad( - _default_vspace, t, nest.flatten(result), sources, + _default_vspace, this_tape, nest.flatten(result), sources, output_gradients=dy) return result, vjp @@ -835,11 +835,11 @@ class GradientTape(object): self._persistent = persistent def __enter__(self): - tape.push_new_tape(persistent=self._persistent) + self._tape = tape.push_new_tape(persistent=self._persistent) return self def __exit__(self, typ, value, traceback): - self._tape = tape.pop_tape() + tape.pop_tape(self._tape) def watch(self, tensor): """Ensures that `tensor` is being traced by this tape. diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 3173afc4240..e1ab1e7bc64 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -24,9 +24,12 @@ import copy import random import threading +from tensorflow.core.protobuf import config_pb2 from tensorflow.python import pywrap_tensorflow +from tensorflow.python.framework import c_api_util from tensorflow.python.framework import device as pydev from tensorflow.python.framework import errors +from tensorflow.python.util import compat from tensorflow.python.util import tf_contextlib GRAPH_MODE = 0 @@ -398,6 +401,36 @@ class Context(object): """Get the list of post-execution callbacks added to the context.""" return self._post_execution_callbacks + def enable_run_metadata(self): + """Enables tracing of op execution via RunMetadata. + + To retrieve the accumulated metadata call context.export_run_metadata() + and to stop tracing call context.disable_run_metadata(). + """ + pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._context_handle) + + def disable_run_metadata(self): + """Disables tracing of op execution via RunMetadata.""" + pywrap_tensorflow.TFE_ContextDisableRunMetadata(self._context_handle) + + def export_run_metadata(self): + """Returns a RunMetadata proto with accumulated information. + + The returned protocol buffer contains information since the most recent call + to either enable_run_metadata or export_run_metadata. + + Returns: + A RunMetadata protocol buffer. + """ + with c_api_util.tf_buffer() as buffer_: + with errors.raise_exception_on_not_ok_status() as status: + pywrap_tensorflow.TFE_ContextExportRunMetadata( + self._context_handle, buffer_, status) + proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_) + run_metadata = config_pb2.RunMetadata() + run_metadata.ParseFromString(compat.as_bytes(proto_data)) + return run_metadata + _context = None _context_lock = threading.Lock() @@ -516,3 +549,29 @@ def num_gpus(): The number of available GPU devices. """ return context().num_gpus() + + +def enable_run_metadata(): + """Enables tracing of op execution via RunMetadata. + + To retrieve the accumulated metadata call context.export_run_metadata() + and to stop tracing call context.disable_run_metadata(). + """ + context().enable_run_metadata() + + +def disable_run_metadata(): + """Disables tracing of op execution via RunMetadata.""" + context().disable_run_metadata() + + +def export_run_metadata(): + """Returns a RunMetadata proto with accumulated information. 
+ + The returned protocol buffer contains information since the most recent call + to either enable_run_metadata or export_run_metadata. + + Returns: + A RunMetadata protocol buffer. + """ + return context().export_run_metadata() diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index 02694b34fe9..a70fa728048 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -84,6 +84,20 @@ class TFETest(test_util.TensorFlowTestCase): self.assertTrue(has_cpu_device) del ctx + def testRunMetadata(self): + context.enable_run_metadata() + t = constant_op.constant(1.0) + _ = t + t # Runs an operation which will be in the RunMetadata + run_metadata = context.export_run_metadata() + context.disable_run_metadata() + step_stats = run_metadata.step_stats + self.assertGreater(len(step_stats.dev_stats), 0) + cpu_stats = step_stats.dev_stats[0] + self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0', + cpu_stats.device) + self.assertEqual(len(cpu_stats.node_stats), 1) + self.assertEqual(cpu_stats.node_stats[0].node_name, 'Add') + def testContextStackContainsEagerMode(self): # Eager execution has been enabled, and no other context # switch has occurred, so `context_stack` should contain diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 969e321dd12..f755434ad78 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -544,11 +544,12 @@ def _defun_internal(name, func, args, kwds): func_inputs = _get_defun_inputs(args) with capture_tensors(captures): - tape.push_new_tape() + this_tape = tape.push_new_tape() try: func_outputs = func(*func_inputs, **kwds) finally: - variables = tape.pop_tape().watched_variables() + tape.pop_tape(this_tape) + variables = this_tape.watched_variables() # Returning a closed-over tensor as an output does not trigger a # call to convert_to_tensor, so we manually capture all such tensors. diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index 91192fea62d..6fa076507d1 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -332,7 +332,7 @@ void EagerTensor_dealloc(EagerTensor* self) { tensorflow::ClearDecrefCache(); auto id = self->id; Py_TYPE(self)->tp_free(self); - TFE_Py_TapeStackDeleteTrace(id); + TFE_Py_TapeSetDeleteTrace(id); } // Getter for `_id`. diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index a33b17ada6f..cecef426032 100644 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -87,22 +87,25 @@ TFE_TensorHandle* EagerTensor_Handle(const PyObject* o); // newly created type, or nullptr on error. PyObject* TFE_Py_InitEagerTensor(PyObject* base_class); -// Pushes a new tape into the thread-local stack. -// `persistent` must be a PyBool_Type, i.e either Py_True or Py_False -void TFE_Py_TapeStackPushNew(PyObject* persistent); +// Creates a new tape and adds it to the active set. `persistent` must be a +// PyBool_Type, i.e either Py_True or Py_False +PyObject* TFE_Py_TapeSetNew(PyObject* persistent); -// Pops the tape from the top of the stack and returns it. -PyObject* TFE_Py_TapeStackPop(); - -// Pushes an existing tape onto the stack. -void TFE_Py_TapeStackPush(PyObject* tape); +// Removes the passed tape from the set of active tapes. +void TFE_Py_TapeSetRemove(PyObject* tape); // Returns true if the tape stack is empty. 
-PyObject* TFE_Py_TapeStackIsEmpty(); +PyObject* TFE_Py_TapeSetIsEmpty(); -PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors); -void TFE_Py_TapeStackWatch(PyObject* tensor); -void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id); +PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors); +void TFE_Py_TapeSetWatch(PyObject* tensor); +void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id); + +// Stops any gradient recording on the current thread. +void TFE_Py_TapeSetStopOnThread(); + +// Restarts gradient recording on the current thread. +void TFE_Py_TapeSetRestartOnThread(); // Records an operation in the gradient tape stack.type is a string for the // operation type, used in the backprop code. output_tensors should be a list of @@ -111,13 +114,12 @@ void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id); // operation. backward_function should be the function to be called during // backprop to, given the gradients of the output tensors, produce the gradients // of the input tensors. -void TFE_Py_TapeStackRecordOperation(PyObject* op_type, - PyObject* output_tensors, - PyObject* input_tensor_ids, - PyObject* backward_function); +void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, + PyObject* input_tensor_ids, + PyObject* backward_function); // Watches the given variable object on the given tape. -void TFE_Py_TapeStackWatchVariable(PyObject* variable); +void TFE_Py_TapeSetWatchVariable(PyObject* variable); // Computes a gradient based on information recorded on the tape.`tape` must // have been produced by TFE_Py_NewTape. `vspace` must be a diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 3ba81fb3d04..bdaeccf2860 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -538,62 +538,67 @@ static PyTypeObject TFE_Py_Tape_Type = { "TFE_Py_Tape objects", /* tp_doc */ }; +// Note: in the current design no mutex is needed here because of the python +// GIL, which is always held when any TFE_Py_* methods are called. We should +// revisit this if/when decide to not hold the GIL while manipulating the tape +// stack. +static std::unordered_set* tape_set = nullptr; +std::unordered_set* GetTapeSet() { + if (tape_set == nullptr) { + tape_set = new std::unordered_set; + } + return tape_set; +} + // xcode 7 doesn't define thread_local, so for compatibility we implement our // own. TODO(apassos) remove once we can deprecate xcode 7. 
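The bookkeeping introduced in this hunk, reduced to a minimal sketch (hypothetical names, plain thread_local rather than the xcode 7 workaround that follows): a process-wide set of active tapes, serialized by the Python GIL, plus a per-thread flag that pauses recording without removing any tapes.

    #include <unordered_set>

    struct Tape;  // stand-in for TFE_Py_Tape

    // Process-wide set of active tapes; in the real code the GIL acts as the lock.
    std::unordered_set<Tape*>* ActiveTapes() {
      static auto* tapes = new std::unordered_set<Tape*>;
      return tapes;
    }

    // Per-thread switch used by stop_recording() to pause gradient recording.
    thread_local bool recording_stopped = false;

    bool ShouldRecordAnything() {
      return !recording_stopped && !ActiveTapes()->empty();
    }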
#ifndef __APPLE__ -std::vector* GetTapeStack() { - thread_local std::vector tape_stack; - return &tape_stack; +bool* ThreadTapeIsStopped() { + thread_local bool thread_tape_is_stopped{false}; + return &thread_tape_is_stopped; } #else -static tensorflow::mutex stack_mu(tensorflow::LINKER_INITIALIZED); -static std::unordered_map*>* - tape_stack GUARDED_BY(stack_mu) = nullptr; -std::vector* GetTapeStack() { - tensorflow::mutex_lock ml(stack_mu); - if (tape_stack == nullptr) { - tape_stack = - new std::unordered_map*>; +static std::unordered_map* tape_is_stopped = nullptr; +bool* ThreadTapeIsStopped() { + if (tape_is_stopped == nullptr) { + tape_is_stopped = new std::unordered_map; } - auto it = tape_stack->find(std::this_thread::get_id()); - if (it != tape_stack->end()) { - return it->second; + auto it = tape_is_stopped->find(std::this_thread::get_id()); + if (it != tape_is_stopped->end()) { + return &(it->second); } - return tape_stack - ->emplace(std::this_thread::get_id(), new std::vector) - .first->second; + return &(tape_is_stopped->emplace(std::this_thread::get_id(), false) + .first->second); } #endif -void TFE_Py_TapeStackPushNew(PyObject* persistent) { +void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; } + +void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; } + +PyObject* TFE_Py_TapeSetNew(PyObject* persistent) { TFE_Py_Tape_Type.tp_new = PyType_GenericNew; - if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return; + if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr; TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type); tape->tape = new GradientTape(persistent == Py_True); - GetTapeStack()->push_back(tape); -} - -void TFE_Py_TapeStackPush(PyObject* tape) { Py_INCREF(tape); - GetTapeStack()->push_back(reinterpret_cast(tape)); + GetTapeSet()->insert(reinterpret_cast(tape)); + return reinterpret_cast(tape); } -PyObject* TFE_Py_TapeStackIsEmpty() { - if (GetTapeStack()->empty()) { +PyObject* TFE_Py_TapeSetIsEmpty() { + if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) { Py_RETURN_TRUE; } Py_RETURN_FALSE; } -PyObject* TFE_Py_TapeStackPop() { - auto* stack = GetTapeStack(); - if (stack->empty()) { - PyErr_SetString(PyExc_RuntimeError, "tape stack is empty."); - return nullptr; - } - TFE_Py_Tape* top = stack->back(); - stack->pop_back(); - return reinterpret_cast(top); +void TFE_Py_TapeSetRemove(PyObject* tape) { + auto* stack = GetTapeSet(); + stack->erase(reinterpret_cast(tape)); + // We kept a reference to the tape in the set to ensure it wouldn't get + // deleted under us; cleaning it up here. 
+ Py_DECREF(tape); } static std::vector MakeIntList(PyObject* list) { @@ -620,12 +625,15 @@ static std::vector MakeIntList(PyObject* list) { return tensor_ids; } -PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors) { +PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) { if (tensors == Py_None) { Py_RETURN_FALSE; } - auto* stack = GetTapeStack(); - if (stack->empty()) { + if (*ThreadTapeIsStopped()) { + Py_RETURN_FALSE; + } + auto* tape_set = GetTapeSet(); + if (tape_set->empty()) { Py_RETURN_FALSE; } PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); @@ -642,7 +650,7 @@ PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors) { tensor_ids.push_back(FastTensorId(item)); } Py_DECREF(seq); - for (TFE_Py_Tape* tape : *stack) { + for (TFE_Py_Tape* tape : *tape_set) { if (tape->tape->ShouldRecord(tensor_ids)) { Py_RETURN_TRUE; } @@ -650,12 +658,12 @@ PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors) { Py_RETURN_FALSE; } -void TFE_Py_TapeStackWatch(PyObject* tensor) { +void TFE_Py_TapeSetWatch(PyObject* tensor) { tensorflow::int64 tensor_id = FastTensorId(tensor); if (PyErr_Occurred()) { return; } - for (TFE_Py_Tape* tape : *GetTapeStack()) { + for (TFE_Py_Tape* tape : *GetTapeSet()) { tape->tape->Watch(tensor_id); } } @@ -720,8 +728,8 @@ std::vector MakeTensorIDList(PyObject* tensors) { return list; } -void TFE_Py_TapeStackWatchVariable(PyObject* variable) { - for (TFE_Py_Tape* tape : *GetTapeStack()) { +void TFE_Py_TapeSetWatchVariable(PyObject* variable) { + for (TFE_Py_Tape* tape : *GetTapeSet()) { tape->tape->WatchVariable(variable); } } @@ -736,12 +744,11 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) { return result; } -void TFE_Py_TapeStackRecordOperation(PyObject* op_type, - PyObject* output_tensors, - PyObject* input_tensors, - PyObject* backward_function) { - auto* stack = GetTapeStack(); - if (stack->empty()) { +void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, + PyObject* input_tensors, + PyObject* backward_function) { + auto* set = GetTapeSet(); + if (set->empty()) { return; } std::vector input_ids = MakeTensorIDList(input_tensors); @@ -776,7 +783,7 @@ void TFE_Py_TapeStackRecordOperation(PyObject* op_type, return; } - for (TFE_Py_Tape* tape : *stack) { + for (TFE_Py_Tape* tape : *set) { Py_INCREF(backward_function); tape->tape->RecordOperation( op_type_str, output_info, input_ids, backward_function, @@ -784,8 +791,8 @@ void TFE_Py_TapeStackRecordOperation(PyObject* op_type, } } -void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id) { - for (TFE_Py_Tape* tape : *GetTapeStack()) { +void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) { + for (TFE_Py_Tape* tape : *GetTapeSet()) { tape->tape->DeleteTrace(tensor_id); } } diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py index 14b5238f740..ad82266beca 100644 --- a/tensorflow/python/eager/tape.py +++ b/tensorflow/python/eager/tape.py @@ -35,7 +35,8 @@ class Tape(object): def push_new_tape(persistent=False): """Pushes a new tape onto the tape stack.""" - pywrap_tensorflow.TFE_Py_TapeStackPushNew(persistent) + tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent) + return Tape(tape) def watch(tensor): @@ -44,7 +45,7 @@ def watch(tensor): Args: tensor: tensor to be watched. """ - pywrap_tensorflow.TFE_Py_TapeStackWatch(tensor) + pywrap_tensorflow.TFE_Py_TapeSetWatch(tensor) def watch_variable(variable): @@ -53,42 +54,39 @@ def watch_variable(variable): Args: variable: variable to be watched. 
""" - pywrap_tensorflow.TFE_Py_TapeStackWatchVariable(variable) + pywrap_tensorflow.TFE_Py_TapeSetWatchVariable(variable) -def pop_tape(): +def pop_tape(tape): """Pops the top tape in the stack, if any.""" - return Tape(pywrap_tensorflow.TFE_Py_TapeStackPop()) + pywrap_tensorflow.TFE_Py_TapeSetRemove(tape._tape) # pylint: disable=protected-access @contextlib.contextmanager def stop_recording(): - stack = [] - while not pywrap_tensorflow.TFE_Py_TapeStackIsEmpty(): - stack.append(pop_tape()._tape) # pylint: disable=protected-access try: + pywrap_tensorflow.TFE_Py_TapeSetStopOnThread() yield finally: - for tape in reversed(stack): - pywrap_tensorflow.TFE_Py_TapeStackPush(tape) + pywrap_tensorflow.TFE_Py_TapeSetRestartOnThread() def should_record(tensors): """Returns true if any tape in the stack watches any of these tensors.""" - return pywrap_tensorflow.TFE_Py_TapeStackShouldRecord(tensors) + return pywrap_tensorflow.TFE_Py_TapeSetShouldRecord(tensors) def record_operation(op_type, output_tensors, input_tensors, backward_function): """Records the operation on all tapes in the stack.""" - pywrap_tensorflow.TFE_Py_TapeStackRecordOperation( + pywrap_tensorflow.TFE_Py_TapeSetRecordOperation( op_type, output_tensors, input_tensors, backward_function) def delete_trace(tensor_id): """Deletes traces for this Tensor from all tapes in the stack.""" - pywrap_tensorflow.TFE_Py_TapeStackDeleteTrace(tensor_id) + pywrap_tensorflow.TFE_Py_TapeSetDeleteTrace(tensor_id) def could_possibly_record(): """Returns True if any tape is active.""" - return not pywrap_tensorflow.TFE_Py_TapeStackIsEmpty() + return not pywrap_tensorflow.TFE_Py_TapeSetIsEmpty() diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index bf175cbe01e..d0f40bd68e0 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -614,7 +614,7 @@ class Estimator(object): if isinstance(result, (list, tuple)): if len(result) != 2: raise ValueError( - 'input_fn should return (feautures, labels) as a len 2 tuple.') + 'input_fn should return (features, labels) as a len 2 tuple.') return result[0], result[1], input_hooks return result, None, input_hooks diff --git a/tensorflow/python/estimator/warm_starting_util.py b/tensorflow/python/estimator/warm_starting_util.py index 476776daa8f..37ac8515cb8 100644 --- a/tensorflow/python/estimator/warm_starting_util.py +++ b/tensorflow/python/estimator/warm_starting_util.py @@ -121,7 +121,10 @@ class _WarmStartSettings( # where ws could be defined as: # Warm-start all weights in the model (input layer and hidden weights). + # Either the directory or a specific checkpoint can be provided (in the case + # of the former, the latest checkpoint will be used). ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp") + ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000") # Warm-start only the embeddings (input layer). ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp", @@ -348,7 +351,7 @@ def _warmstart_var_with_vocab(var, # TODO(vihanjain): Support _WarmstartSettings where class vocabularies need # remapping too. 
init = checkpoint_ops._load_and_remap_matrix_initializer( - ckpt_path=saver.latest_checkpoint(prev_ckpt), + ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt), old_tensor_name=prev_tensor_name, new_row_vocab_size=current_vocab_size, new_col_vocab_size=v_shape[1], diff --git a/tensorflow/python/estimator/warm_starting_util_test.py b/tensorflow/python/estimator/warm_starting_util_test.py index cf502dd60de..cc0c4efc756 100644 --- a/tensorflow/python/estimator/warm_starting_util_test.py +++ b/tensorflow/python/estimator/warm_starting_util_test.py @@ -50,9 +50,7 @@ class WarmStartingUtilTest(test.TestCase): sess.run(variables.global_variables_initializer()) saver = saver_lib.Saver() ckpt_prefix = os.path.join(self.get_temp_dir(), "model") - ckpt_state_name = "checkpoint" - saver.save( - sess, ckpt_prefix, global_step=0, latest_filename=ckpt_state_name) + saver.save(sess, ckpt_prefix, global_step=0) def _create_prev_run_var(self, var_name, @@ -408,6 +406,44 @@ class WarmStartingUtilTest(test.TestCase): self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]}, sess) + def testWarmStart_ExplicitCheckpointFile(self): + # Create vocab for sparse column "sc_vocab". + vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], + "vocab") + # Create feature column. + sc_vocab = fc.categorical_column_with_vocabulary_file( + "sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4) + + # Save checkpoint from which to warm-start. + _, prev_vocab_val = self._create_prev_run_var( + "linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones()) + + partitioner = lambda shape, dtype: [1] * len(shape) + # New graph, new session WITHOUT warmstarting. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + cols_to_vars = self._create_linear_model([sc_vocab], partitioner) + sess.run(variables.global_variables_initializer()) + # Without warmstarting, the weights should be initialized using default + # initializer (which is init_ops.zeros_initializer). + self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [np.zeros([4, 1])]}, + sess) + + # New graph, new session with warmstarting. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + cols_to_vars = self._create_linear_model([sc_vocab], partitioner) + # Since old vocab is not explicitly set in WarmStartSettings, the old + # vocab is assumed to be same as new vocab. + ws_util._warmstart(ws_util._WarmStartSettings( + # Explicitly provide the file prefix instead of just the dir. + os.path.join(self.get_temp_dir(), "model-0"), + vars_to_warmstart=".*sc_vocab.*")) + sess.run(variables.global_variables_initializer()) + # Verify weights were correctly warmstarted. + self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]}, + sess) + def testWarmStart_SparseColumnVocabularyConstrainedVocabSizes(self): # Create old vocabulary, and use a size smaller than the total number of # entries. diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py index 7bb799a8b06..b5ed1352843 100644 --- a/tensorflow/python/framework/meta_graph_test.py +++ b/tensorflow/python/framework/meta_graph_test.py @@ -59,6 +59,7 @@ def _TestDir(test_name): # pylint: enable=invalid-name +@test_util.with_c_api class SimpleMetaGraphTest(test.TestCase): def testNoVariables(self): @@ -103,7 +104,8 @@ class SimpleMetaGraphTest(test.TestCase): # Re-exports the current graph state for comparison to the original. 
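The warm-starting changes above let ckpt_to_initialize_from name either a checkpoint directory or an explicit checkpoint prefix, with _warmstart_var_with_vocab resolving it through checkpoint_utils._get_checkpoint_filename instead of always asking for the latest checkpoint in a directory. A rough sketch of that resolution rule, where resolve_ckpt and its latest_checkpoint argument are illustrative names rather than the TensorFlow helpers:

    import os

    def resolve_ckpt(ckpt_dir_or_file, latest_checkpoint):
        # Directory, e.g. "/tmp": defer to its most recent checkpoint.
        # Explicit prefix, e.g. "/tmp/model-1000": use it as given.
        if os.path.isdir(ckpt_dir_or_file):
            return latest_checkpoint(ckpt_dir_or_file)
        return ckpt_dir_or_file

testWarmStart_ExplicitCheckpointFile exercises the second branch by warm-starting from os.path.join(self.get_temp_dir(), "model-0") rather than from the temp directory itself.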
new_meta_graph_def, _ = meta_graph.export_scoped_meta_graph(filename + "_new") - self.assertProtoEquals(meta_graph_def, new_meta_graph_def) + test_util.assert_meta_graph_protos_equal(self, meta_graph_def, + new_meta_graph_def) # Ensures that we can still get a reference to our graph collections. new_input_tensor = ops.get_collection("input_tensor")[0] @@ -226,7 +228,7 @@ class SimpleMetaGraphTest(test.TestCase): double_nested_complex_node_def = None for function_def in meta_graph_def.graph_def.library.function: for node_def in function_def.node_def: - if node_def.name == "double_nested_complex": + if node_def.name.startswith("double_nested_complex"): double_nested_complex_node_def = node_def break if double_nested_complex_node_def: @@ -258,6 +260,7 @@ class SimpleMetaGraphTest(test.TestCase): self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs) +@test_util.with_c_api class ScopedMetaGraphTest(test.TestCase): def _testScopedExport(self, test_dir, exported_filenames): @@ -435,10 +438,13 @@ class ScopedMetaGraphTest(test.TestCase): ] orig_meta_graphs = self._testScopedExport(test_dir, filenames) new_meta_graphs = self._testScopedImport(test_dir, filenames) - # Delete the unbound_inputs to allow directly calling ProtoEqual. - del orig_meta_graphs[0].collection_def["unbound_inputs"] - del new_meta_graphs[0].collection_def["unbound_inputs"] for a, b in zip(orig_meta_graphs, new_meta_graphs): + # The unbound input strings are slightly different with the C API enabled + # ("images" vs "images:0") due to the original import_graph_def code + # vs. ImportGraphDef in C++. + # TODO(skyewm): update the pbtxts once _USE_C_API is removed. + del a.collection_def["unbound_inputs"] + del b.collection_def["unbound_inputs"] test_util.assert_meta_graph_protos_equal(self, a, b) def testScopedImportUnderNameScope(self): @@ -572,7 +578,8 @@ class ScopedMetaGraphTest(test.TestCase): "exported_queue1.pbtxt") new_meta_graph = self._testScopedImportWithQueue( test_dir, "exported_queue1.pbtxt", "exported_new_queue1.pbtxt") - self.assertProtoEquals(orig_meta_graph, new_meta_graph) + test_util.assert_meta_graph_protos_equal(self, orig_meta_graph, + new_meta_graph) # Verifies that we can export a subgraph in a nested name scope containing a # "hidden1/hidden2" and import it into "new_hidden1/new_hidden2" in a new @@ -718,6 +725,7 @@ class ScopedMetaGraphTest(test.TestCase): self.assertEqual("", str(graph2.as_graph_element("matmul").device)) +@test_util.with_c_api class MetaGraphWithVariableScopeTest(test.TestCase): def testMetricsCollection(self): @@ -775,6 +783,7 @@ class MetaGraphWithVariableScopeTest(test.TestCase): initializer = variables.local_variables_initializer() +@test_util.with_c_api class ExportImportAcrossScopesTest(test.TestCase): def testPartionedVariables(self): @@ -845,7 +854,7 @@ class ExportImportAcrossScopesTest(test.TestCase): if shared_name_value.s: node.attr[shared_name_attr].s = b"" - self.assertProtoEquals(expected, result) + test_util.assert_meta_graph_protos_equal(self, expected, result) if __name__ == "__main__": diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 5ac30537499..b2fb63dbbac 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -162,6 +162,16 @@ def assert_meta_graph_protos_equal(tester, a, b): # proto comparison below. a.ClearField("collection_def") b.ClearField("collection_def") + + # Check the graph_defs. 
+ assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True) + # Check graph_def versions (ignored by assert_equal_graph_def). + tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions) + # Compared the fields directly, remove their raw values from the + # proto comparison below. + a.ClearField("graph_def") + b.ClearField("graph_def") + tester.assertProtoEquals(a, b) @@ -178,7 +188,7 @@ def _strip_checkpoint_v2_randomized(graph_def): if attr_tensor_value and len(attr_tensor_value.string_val) == 1: attr_tensor_string_value = attr_tensor_value.string_val[0] if (attr_tensor_string_value and - re.match(_SHARDED_SAVE_OP_PATTERN, attr_tensor_string_value)): + re.match(_SHARDED_SAVE_OP_PATTERN, str(attr_tensor_string_value))): delete_keys.append(attr_key) for attr_key in delete_keys: del node.attr[attr_key] diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py index 00781d01505..f54f146e0ac 100644 --- a/tensorflow/python/kernel_tests/distributions/util_test.py +++ b/tensorflow/python/kernel_tests/distributions/util_test.py @@ -25,6 +25,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl @@ -55,6 +56,7 @@ def _logit(x): return np.log(x) - np.log1p(-x) +@test_util.with_c_api class AssertCloseTest(test.TestCase): def testAssertCloseIntegerDtype(self): @@ -145,6 +147,7 @@ class AssertCloseTest(test.TestCase): array_ops.identity(w).eval(feed_dict=feed_dict) +@test_util.with_c_api class GetLogitsAndProbsTest(test.TestCase): def testImproperArguments(self): @@ -298,6 +301,7 @@ class GetLogitsAndProbsTest(test.TestCase): logit.eval(feed_dict={l: np.ones([int(2**11+1)])}) +@test_util.with_c_api class EmbedCheckCategoricalEventShapeTest(test.TestCase): def testTooSmall(self): @@ -335,6 +339,7 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase): du.embed_check_categorical_event_shape(param) +@test_util.with_c_api class EmbedCheckIntegerCastingClosedTest(test.TestCase): def testCorrectlyAssertsNonnegative(self): @@ -370,6 +375,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase): x_checked.eval(feed_dict={x: np.array([1, -1], dtype=np.int32)}) +@test_util.with_c_api class LogCombinationsTest(test.TestCase): def testLogCombinationsBinomial(self): @@ -400,6 +406,7 @@ class LogCombinationsTest(test.TestCase): self.assertEqual([2, 2], log_binom.get_shape()) +@test_util.with_c_api class DynamicShapeTest(test.TestCase): def testSameDynamicShape(self): @@ -504,6 +511,7 @@ class DynamicShapeTest(test.TestCase): })) +@test_util.with_c_api class RotateTransposeTest(test.TestCase): def _np_rotate_transpose(self, x, shift): @@ -537,6 +545,7 @@ class RotateTransposeTest(test.TestCase): shift: shift_value})) +@test_util.with_c_api class PickVectorTest(test.TestCase): def testCorrectlyPicksVector(self): @@ -557,6 +566,7 @@ class PickVectorTest(test.TestCase): constant_op.constant(False), x, y)) # No eval. 
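The str(...) wrapped around attr_tensor_string_value in _strip_checkpoint_v2_randomized above is presumably a Python 3 accommodation: TensorProto.string_val entries arrive as bytes, and re.match raises a TypeError when a text pattern is applied to a bytes object. A small standalone illustration of the failure mode being avoided (the pattern and value below are placeholders, not _SHARDED_SAVE_OP_PATTERN or real checkpoint data):

    import re

    pattern = r".*_temp_[0-9a-f]+/part"                  # placeholder pattern
    value = b"save/ShardedFilename_temp_00c0ffee/part"   # string_val yields bytes

    try:
        re.match(pattern, value)       # Python 3: TypeError, "cannot use a
                                       # string pattern on a bytes-like object"
    except TypeError:
        print(re.match(pattern, str(value)) is not None)  # coercion avoids it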
+@test_util.with_c_api class PreferStaticRankTest(test.TestCase): def testNonEmptyConstantTensor(self): @@ -596,6 +606,7 @@ class PreferStaticRankTest(test.TestCase): self.assertAllEqual(0, rank.eval(feed_dict={x: 1})) +@test_util.with_c_api class PreferStaticShapeTest(test.TestCase): def testNonEmptyConstantTensor(self): @@ -635,6 +646,7 @@ class PreferStaticShapeTest(test.TestCase): self.assertAllEqual(np.array([]), shape.eval(feed_dict={x: 1})) +@test_util.with_c_api class PreferStaticValueTest(test.TestCase): def testNonEmptyConstantTensor(self): @@ -675,6 +687,7 @@ class PreferStaticValueTest(test.TestCase): self.assertAllEqual(np.array(1), value.eval(feed_dict={x: 1})) +@test_util.with_c_api class FillTriangularTest(test.TestCase): def setUp(self): @@ -769,6 +782,7 @@ class FillTriangularTest(test.TestCase): self._run_test(self._rng.randn(2, 3, int(7*8/2)), upper=True) +@test_util.with_c_api class ReduceWeightedLogSumExp(test.TestCase): def _reduce_weighted_logsumexp(self, logx, w, axis, keep_dims=False): @@ -865,6 +879,7 @@ class ReduceWeightedLogSumExp(test.TestCase): du.reduce_weighted_logsumexp(x, w, axis=[0, 1]).eval()) +@test_util.with_c_api class GenNewSeedTest(test.TestCase): def testOnlyNoneReturnsNone(self): @@ -875,6 +890,7 @@ class GenNewSeedTest(test.TestCase): # TODO(jvdillon): Merge this test back into: # tensorflow/python/kernel_tests/softplus_op_test.py # once TF core is accepting new ops. +@test_util.with_c_api class SoftplusTest(test.TestCase): def _npSoftplus(self, np_features): diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index d97823c17f8..083931aa836 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -20,21 +20,25 @@ limitations under the License. %rename("%s") TFE_ContextListDevices; %rename("%s") TFE_ContextAddFunction; %rename("%s") TFE_ContextAddFunctionDef; +%rename("%s") TFE_ContextEnableRunMetadata; +%rename("%s") TFE_ContextDisableRunMetadata; +%rename("%s") TFE_ContextExportRunMetadata; %rename("%s") TFE_ContextClearCaches; %rename("%s") TFE_OpNameGetAttrType; %rename("%s") TFE_Py_InitEagerTensor; %rename("%s") TFE_Py_RegisterExceptionClass; %rename("%s") TFE_Py_Execute; %rename("%s") TFE_Py_UID; -%rename("%s") TFE_Py_TapeStackPushNew; -%rename("%s") TFE_Py_TapeStackPush; -%rename("%s") TFE_Py_TapeStackPop; -%rename("%s") TFE_Py_TapeStackIsEmpty; -%rename("%s") TFE_Py_TapeStackShouldRecord; -%rename("%s") TFE_Py_TapeStackWatch; -%rename("%s") TFE_Py_TapeStackDeleteTrace; -%rename("%s") TFE_Py_TapeStackRecordOperation; -%rename("%s") TFE_Py_TapeStackWatchVariable; +%rename("%s") TFE_Py_TapeSetNew; +%rename("%s") TFE_Py_TapeSetRemove; +%rename("%s") TFE_Py_TapeSetStopOnThread; +%rename("%s") TFE_Py_TapeSetRestartOnThread; +%rename("%s") TFE_Py_TapeSetIsEmpty; +%rename("%s") TFE_Py_TapeSetShouldRecord; +%rename("%s") TFE_Py_TapeSetWatch; +%rename("%s") TFE_Py_TapeSetDeleteTrace; +%rename("%s") TFE_Py_TapeSetRecordOperation; +%rename("%s") TFE_Py_TapeSetWatchVariable; %rename("%s") TFE_Py_TapeGradient; %rename("%s") TFE_Py_TapeWatchedVariables; %rename("%s") TFE_NewContextOptions; diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index cebff5cf309..ed9d11d8561 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -455,11 +455,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/cb27e1d0da7f30562ea6c1c4f01393afbf112620.tar.gz", - 
"https://github.com/llvm-mirror/llvm/archive/cb27e1d0da7f30562ea6c1c4f01393afbf112620.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/81a623d6d61cf87847f839e80b047c267020ab0e.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/81a623d6d61cf87847f839e80b047c267020ab0e.tar.gz", ], - sha256 = "d4e4d17040a786bab13bb1b73ec2dc358f0c07214f847076e0ded8de15800782", - strip_prefix = "llvm-cb27e1d0da7f30562ea6c1c4f01393afbf112620", + sha256 = "be0259c0bd5349200df346c92ba7708341e18ef313fcf7398682b5cff2469137", + strip_prefix = "llvm-81a623d6d61cf87847f839e80b047c267020ab0e", build_file = str(Label("//third_party/llvm:llvm.BUILD")), ) diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 5f1c42dbe48..c09bb222e8c 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -110,11 +110,7 @@ def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp): lang = "c++" else: lang = "c" - # TODO: We pass -no-canonical-prefixes here to match the compiler flags, - # but in cuda_clang CROSSTOOL file that is a `feature` and we should - # handle the case when it's disabled and no flag is passed - result = repository_ctx.execute([cc, "-no-canonical-prefixes", - "-E", "-x" + lang, "-", "-v"]) + result = repository_ctx.execute([cc, "-E", "-x" + lang, "-", "-v"]) index1 = result.stderr.find(_INC_DIR_MARKER_BEGIN) if index1 == -1: return []