Merge commit for internal changes

This commit is contained in:
Frank Chen 2018-01-09 11:14:49 -08:00
commit fc8b359214
79 changed files with 3319 additions and 209 deletions

View File

@ -542,7 +542,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
// WARNING: kernel->Run utilizes the FunctionLibraryRuntime
// (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def,
// which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation
// of FunctionLibraryRuntime tells use that func_lib_def is not accessed by
// of FunctionLibraryRuntime tells us that func_lib_def is not accessed by
// FunctionLibraryRuntime::Run(), so there is no thread-safety concern here.
// This is quite subtle. Re-work things to make this better? (Would it make
// sense for FunctionLibraryRuntime to ensure thread-safe access to

View File

@ -302,6 +302,7 @@ cc_library(
":array4d",
":shape_tree",
":shape_util",
":sparse_index_array",
":status_macros",
":types",
":util",
@ -628,6 +629,28 @@ tf_cc_test(
],
)
cc_library(
name = "sparse_index_array",
srcs = ["sparse_index_array.cc"],
hdrs = ["sparse_index_array.h"],
deps = [
":array2d",
":shape_util",
":xla_data_proto",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "sparse_index_array_test",
srcs = ["sparse_index_array_test.cc"],
deps = [
":sparse_index_array",
":test",
"//tensorflow/core:test_main",
],
)
# -----------------------------------------------------------------------------
filegroup(

View File

@ -149,4 +149,33 @@ namespace xla {
return stride;
}
/* static */ bool IndexUtil::IndexInBounds(
const Shape& shape, tensorflow::gtl::ArraySlice<int64> index) {
int64 rank = ShapeUtil::Rank(shape);
if (rank != index.size()) {
return false;
}
for (int64 d = 0; d < rank; ++d) {
if (index[d] >= shape.dimensions(d)) {
return false;
}
}
return true;
}
/* static */ int IndexUtil::CompareIndices(
tensorflow::gtl::ArraySlice<int64> lhs,
tensorflow::gtl::ArraySlice<int64> rhs) {
int64 rank = lhs.size();
CHECK_EQ(rhs.size(), rank);
for (int64 dim = 0; dim < rank; ++dim) {
if (lhs[dim] < rhs[dim]) {
return -1;
} else if (lhs[dim] > rhs[dim]) {
return 1;
}
}
return 0;
}
} // namespace xla

View File

@ -69,6 +69,18 @@ class IndexUtil {
// sizeof(dimension(3)) * sizeof(dimension(2)) == 4 * 10
static int64 GetDimensionStride(const Shape& shape, int64 dimension);
// Returns true iff the given multi-index is contained in the bounds for the
// shape.
static bool IndexInBounds(const Shape& shape,
tensorflow::gtl::ArraySlice<int64> index);
// Compares the given indices in lexicographic order. lhs[0] and rhs[0] are
// compared first, and lhs[rank-1] and rhs[rank-1] last. If lhs is smaller,
// then -1 is returned. If lhs is larger, then 1 is returned. Otherwise, 0 is
// returned.
static int CompareIndices(tensorflow::gtl::ArraySlice<int64> lhs,
tensorflow::gtl::ArraySlice<int64> rhs);
private:
TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil);
};
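
As a rough sketch of the contract documented above (illustrative only, not part of the diff; it reuses ShapeUtil::MakeShape as seen elsewhere in XLA), IndexInBounds rejects indices with the wrong rank or out-of-range coordinates, and CompareIndices behaves as a lexicographic three-way comparator:

// Sketch: expected behavior of the new IndexUtil helpers.
Shape shape = ShapeUtil::MakeShape(F32, {4, 10});
CHECK(IndexUtil::IndexInBounds(shape, {3, 9}));   // inside the bounds
CHECK(!IndexUtil::IndexInBounds(shape, {4, 0}));  // first coordinate too large
CHECK(!IndexUtil::IndexInBounds(shape, {1}));     // rank mismatch
CHECK_EQ(IndexUtil::CompareIndices({0, 5}, {1, 0}), -1);  // lhs is smaller
CHECK_EQ(IndexUtil::CompareIndices({1, 0}, {0, 5}), 1);   // lhs is larger
CHECK_EQ(IndexUtil::CompareIndices({2, 2}, {2, 2}), 0);   // equal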

View File

@ -64,6 +64,13 @@ void SetDefaultLayoutToContainer(
return layout;
}
/* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) {
Layout layout;
layout.set_format(SPARSE);
layout.set_max_sparse_elements(max_sparse_elements);
return layout;
}
namespace {
// Internal helper that creates a default layout for an array of the given rank.
@ -234,7 +241,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
LayoutUtil::ClearLayout(program_shape->mutable_result());
}
/* static */ bool LayoutUtil::IsDense(const Shape& shape) {
/* static */ bool LayoutUtil::IsDenseArray(const Shape& shape) {
return ShapeUtil::IsArray(shape) && shape.has_layout() &&
IsDense(shape.layout());
}
@ -260,7 +267,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
shape.layout().padded_dimensions_size() == 0) {
return false;
}
CHECK(IsDense(shape));
CHECK(IsDenseArray(shape));
CHECK_EQ(shape.dimensions_size(), shape.layout().padded_dimensions_size());
for (int64 i = 0; i < shape.dimensions_size(); ++i) {
if (shape.layout().padded_dimensions(i) > shape.dimensions(i)) {
@ -272,21 +279,35 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
/* static */ tensorflow::gtl::ArraySlice<int64> LayoutUtil::PaddedDimensions(
const Shape& shape) {
CHECK(IsDense(shape));
CHECK(IsDenseArray(shape));
return AsInt64Slice(shape.layout().padded_dimensions());
}
/* static */ int64 LayoutUtil::PaddedDimension(const Shape& shape,
int64 index) {
CHECK(IsDense(shape));
CHECK(IsDenseArray(shape));
return shape.layout().padded_dimensions(index);
}
/* static */ PaddingValue LayoutUtil::GetPaddingValue(const Shape& shape) {
CHECK(IsDense(shape));
CHECK(IsDenseArray(shape));
return shape.layout().padding_value();
}
/* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) {
return ShapeUtil::IsArray(shape) && shape.has_layout() &&
IsSparse(shape.layout());
}
/* static */ bool LayoutUtil::IsSparse(const Layout& layout) {
return layout.format() == SPARSE;
}
/* static */ int64 LayoutUtil::MaxSparseElements(const Layout& layout) {
CHECK(IsSparse(layout));
return layout.max_sparse_elements();
}
/* static */ bool LayoutUtil::HasLayout(const Shape& shape) {
if (ShapeUtil::IsTuple(shape)) {
// Tuple shape: all subshapes must have a layout.
@ -313,7 +334,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
/* static */ tensorflow::gtl::ArraySlice<int64> LayoutUtil::MinorToMajor(
const Shape& shape) {
CHECK(IsDense(shape));
CHECK(IsDenseArray(shape));
return AsInt64Slice(shape.layout().minor_to_major());
}
@ -419,6 +440,7 @@ tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src,
/* static */ bool LayoutUtil::AreDimensionsConsecutive(
const Layout& layout, tensorflow::gtl::ArraySlice<int64> dims) {
CHECK(IsDense(layout));
std::vector<int64> positions_in_layout;
for (int64 dim : dims) {
positions_in_layout.push_back(

View File

@ -36,6 +36,10 @@ class LayoutUtil {
// convenience function for protobuf construction.)
static Layout MakeLayout(tensorflow::gtl::ArraySlice<int64> minor_to_major);
// Creates a sparse layout with the given maximum number of elements. (This is
// a convenience function for protobuf construction.)
static Layout MakeSparseLayout(int64 max_sparse_elements);
// Returns default layout for the given shape.
static Layout GetDefaultLayoutForShape(const Shape& shape);
@ -72,7 +76,7 @@ class LayoutUtil {
static void ClearLayout(ProgramShape* program_shape);
// Returns whether the given Shape is an array and has a dense format layout.
static bool IsDense(const Shape& shape);
static bool IsDenseArray(const Shape& shape);
// Returns whether the given Layout has a dense format.
static bool IsDense(const Layout& layout);
@ -107,6 +111,17 @@ class LayoutUtil {
// an array and has a dense layout.
static PaddingValue GetPaddingValue(const Shape& shape);
// Returns whether the given Shape is an array (i.e. not a tuple) and has a
// sparse format layout.
static bool IsSparseArray(const Shape& shape);
// Returns whether the given Layout has a sparse format.
static bool IsSparse(const Layout& layout);
// Returns the maximum number of elements that can be stored in a sparse
// layout.
static int64 MaxSparseElements(const Layout& layout);
// Returns whether the given shape has a layout. For tuple shapes, true is
// returned only if all elements have layouts.
static bool HasLayout(const Shape& shape);
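
A brief usage sketch of the new sparse-layout helpers (illustrative only; it mirrors the MakeShapeWithSparseLayout pattern used in the tests below):

// Sketch: build a sparse layout holding at most 64 non-zero elements.
Layout layout = LayoutUtil::MakeSparseLayout(64);
CHECK(LayoutUtil::IsSparse(layout));
CHECK_EQ(LayoutUtil::MaxSparseElements(layout), 64);
// Attaching it to an array shape makes the shape a sparse array.
Shape shape = ShapeUtil::MakeShape(F32, {128, 128});
*shape.mutable_layout() = layout;
CHECK(LayoutUtil::IsSparseArray(shape));
CHECK(!LayoutUtil::IsDenseArray(shape));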

View File

@ -30,6 +30,14 @@ class LayoutUtilTest : public ::testing::Test {
*shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
return shape;
}
Shape MakeShapeWithSparseLayout(PrimitiveType element_type,
tensorflow::gtl::ArraySlice<int64> dimensions,
int64 max_sparse_elements) {
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
*shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
return shape;
}
};
TEST_F(LayoutUtilTest, TupleLayoutComparison) {
@ -81,6 +89,29 @@ TEST_F(LayoutUtilTest, CopyLayoutArray) {
EXPECT_FALSE(dst.has_layout());
}
TEST_F(LayoutUtilTest, CopyLayoutSparse) {
Shape src = MakeShapeWithSparseLayout(F32, {2, 3}, 2);
Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0});
EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
// Should work if destination has no layout.
dst.clear_layout();
EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
// If source is cleared, then destination should be cleared.
src.clear_layout();
EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
EXPECT_TRUE(dst.has_layout());
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
EXPECT_FALSE(dst.has_layout());
}
TEST_F(LayoutUtilTest, CopyLayoutTuple) {
Shape src = ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
@ -100,6 +131,25 @@ TEST_F(LayoutUtilTest, CopyLayoutTuple) {
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
}
TEST_F(LayoutUtilTest, CopyLayoutTupleSparse) {
Shape src = ShapeUtil::MakeTupleShape(
{MakeShapeWithSparseLayout(F32, {2, 3}, 4),
MakeShapeWithSparseLayout(F32, {42, 123}, 4),
ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {}, {}),
MakeShapeWithSparseLayout(F32, {1, 2, 3}, 6)})});
Shape dst = ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
MakeShapeWithLayout(F32, {42, 123}, {1, 0}),
ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {}, {}),
MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})});
EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
}
TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) {
Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1});
Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0});
@ -107,6 +157,13 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) {
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
}
TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleSameRank) {
Shape src = MakeShapeWithSparseLayout(F32, {123, 42, 7}, 6);
Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0});
ASSERT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
}
TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) {
Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1});
Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0});
@ -116,6 +173,15 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) {
::testing::ContainsRegex("cannot copy layout from shape"));
}
TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleDifferentRank) {
Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1});
Shape dst = MakeShapeWithSparseLayout(F32, {2, 3}, 4);
auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst);
EXPECT_FALSE(status.ok());
EXPECT_THAT(status.error_message(),
::testing::ContainsRegex("cannot copy layout from shape"));
}
TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) {
Shape src =
ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
@ -221,5 +287,10 @@ TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) {
ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25}))));
}
TEST_F(LayoutUtilTest, SparseLayoutMaxElements) {
EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)),
101);
}
} // namespace
} // namespace xla

View File

@ -1,4 +1,4 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -94,9 +94,15 @@ Literal::Literal(const Shape& shape, bool allocate_arrays)
Piece& piece = pair.second;
piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
if (ShapeUtil::IsArray(piece.subshape())) {
const Shape& subshape = piece.subshape();
if (ShapeUtil::IsArray(subshape)) {
if (allocate_arrays) {
piece.set_buffer(new char[piece.size_bytes()]);
if (LayoutUtil::IsSparseArray(subshape)) {
piece.set_sparse_indices(new SparseIndexArray(
LayoutUtil::MaxSparseElements(subshape.layout()),
ShapeUtil::Rank(subshape)));
}
} else {
piece.set_buffer(nullptr);
}
@ -112,6 +118,7 @@ void Literal::DeallocateBuffers() {
Piece& piece = pair.second;
if (piece.buffer() != nullptr) {
delete[] piece.buffer();
delete piece.sparse_indices();
}
}
}
@ -164,6 +171,15 @@ std::unique_ptr<Literal> Literal::CreateFromShape(const Shape& shape) {
return literal;
}
const SparseIndexArray* Literal::sparse_indices(
const ShapeIndex& shape_index) const {
return piece(shape_index).sparse_indices();
}
SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) {
return piece(shape_index).sparse_indices();
}
/* static */ std::unique_ptr<Literal> Literal::CreateFromDimensions(
PrimitiveType primitive_type,
tensorflow::gtl::ArraySlice<int64> dimensions) {
@ -247,9 +263,12 @@ std::vector<Literal> Literal::DecomposeTuple() {
}
Piece& src_piece = piece(src_index);
// Move the respective buffer over to the element Literal.
// Move the respective buffer and sparse indices over to the element
// Literal.
dest_piece.set_buffer(src_piece.buffer());
src_piece.set_buffer(nullptr);
dest_piece.set_sparse_indices(src_piece.sparse_indices());
src_piece.set_sparse_indices(nullptr);
}
}
// Set this literal to be nil-shaped.
@ -406,6 +425,8 @@ Status Literal::MoveFrom(Literal&& src_literal,
Piece& dest_piece = piece(dest_index);
delete[] dest_piece.buffer();
dest_piece.set_buffer(src_piece.buffer());
delete dest_piece.sparse_indices();
dest_piece.set_sparse_indices(src_piece.sparse_indices());
}
src_literal.shape_ = ShapeUtil::MakeNil();
@ -764,7 +785,7 @@ std::unique_ptr<Literal> Literal::Transpose(
// dimension has within the transposed array, a layout is affine if
// MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
// vector of the affine layout.
CHECK(LayoutUtil::IsDense(permuted_shape));
CHECK(LayoutUtil::IsDenseArray(permuted_shape));
Layout* layout = permuted_shape.mutable_layout();
layout->clear_minor_to_major();
for (auto index : LayoutUtil::MinorToMajor(shape())) {
@ -1573,6 +1594,12 @@ LiteralProto Literal::ToProto() const {
}
piece.WriteToProto(proto_piece);
}
if (LayoutUtil::IsSparseArray(shape())) {
CopyToRepeatedField(proto.mutable_sparse_indices(),
sparse_indices()->data());
}
return proto;
}
@ -1653,6 +1680,7 @@ LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) {
}
const Piece& src_piece = literal.piece(src_index);
piece.set_buffer(src_piece.buffer());
piece.set_sparse_indices(src_piece.sparse_indices());
piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
}
}

View File

@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/sparse_index_array.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
@ -103,6 +104,12 @@ class Literal {
tensorflow::gtl::MutableArraySlice<NativeT> data(
const ShapeIndex& shape_index = {});
// Returns a pointer to the sparse index array. Returns nullptr if the literal
// is not a sparse array.
const SparseIndexArray* sparse_indices(
const ShapeIndex& shape_index = {}) const;
SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
// Returns a pointer to (or size of) the underlying buffer holding the array
// at the given shape index. CHECKs if the subshape of the literal at the
// given ShapeIndex is not array.
@ -160,6 +167,56 @@ class Literal {
// array.
string GetR1U8AsString() const;
// Creates a literal with a sparse layout and the given indices and values.
// The shape is initialized from the given dimensions. The minor dimension of
// the indices array must equal the rank of the shape (i.e. size of the
// dimensions array). The major dimension of the indices array must equal the
// number of elements in the values array. The maximum number of elements in
// the array is taken from the max_indices() value of the index array.
//
// XLA assumes that sparse literals are in sorted order for all operations. If
// the `sort` argument is true, then the indices and values will be sorted
// while copying them into the literal. If you have ensured that the indices
// and values are already sorted, then you may set the `sort` argument to
// false to skip the sorting step.
//
// For example:
//
// CreateSparse(
// {12, 12, 12},
// SparseIndexArray(10, 3,
// Array2D{
// {0, 1, 2},
// {3, 4, 5},
// {6, 7, 8},
// {9, 10, 11},
// }),
// {1.0, 2.0, 3.0, 4.0})
//
// This creates an array with shape F64[12,12,12]sparse{10}, that has the
// following non-zero values:
//
// [0, 1, 2]: 1.0
// [3, 4, 5]: 2.0
// [6, 7, 8]: 3.0
// [9, 10, 11]: 4.0
//
template <typename NativeT>
static std::unique_ptr<Literal> CreateSparse(
tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
tensorflow::gtl::ArraySlice<NativeT> values, bool sort = true);
// Populates a literal with a sparse layout with the given indices and values.
// Each index in the indices array is CHECKed against the dimensions in the
// literal's shape. If sort is true, then the indices and values will be
// sorted. If sort is false, then the indices and values are assumed to
// already be in sorted order. See CreateSparse for an example of how data
// are populated.
template <typename NativeT>
void PopulateSparse(SparseIndexArray indices,
tensorflow::gtl::ArraySlice<NativeT> values,
bool sort = true);
// Creates a new Literal object with the shape specified as parameter.
// The content of the literal values is the default value of the primitive
// type of literal itself (0 for numeric types, and false for predicates).
@ -358,7 +415,7 @@ class Literal {
const ShapeIndex& shape_index, NativeT value);
// Overloads of Get and Set for array literals. CHECKs if the literal is not
// array-shaped.
// array-shaped and dense.
template <typename NativeT>
NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
template <typename NativeT>
@ -408,6 +465,8 @@ class Literal {
// This function is useful if you want a polymorphic representation
// of the tensor's elements (turning it to a string for something
// like representation in a protobuf).
//
// This literal must have a dense layout.
void EachCellAsString(
const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
const string& value)>& per_cell) const;
@ -447,6 +506,8 @@ class Literal {
//
// generator must be a callable of the type
// NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
//
// This literal must have a dense layout.
template <typename NativeT, typename FnType>
Status Populate(const FnType& generator);
@ -485,10 +546,12 @@ class Literal {
// admonishments about floating-point equality checks apply. We expect you to
// use this to check for complex values that can be expressed precisely as
// float pairs e.g. (-0.5, 1.0).
//
// This literal must have a dense layout.
bool IsAllComplex(complex64 value) const;
// Returns whether this literal is zero at the specified index. This literal
// must be an array.
// must be an array with a dense layout.
bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
// Return the count of the elements in the array at the given shape index in
@ -563,6 +626,14 @@ class Literal {
char* buffer() const { return buffer_; }
void set_buffer(char* buffer) { buffer_ = buffer; }
// The array of multi-indices that provide the locations of non-zero
// elements in a sparse array. Only used if
// LayoutUtil::IsSparseArray(shape()) is true.
SparseIndexArray* sparse_indices() const { return sparse_indices_; }
void set_sparse_indices(SparseIndexArray* sparse_indices) {
sparse_indices_ = sparse_indices;
}
// Gets or sets the subshape of this piece. This reference points to a
// subshape within the shape in the containing Literal (Literal::shape_).
const Shape& subshape() const { return *subshape_; }
@ -598,6 +669,9 @@ class Literal {
// For array-shaped pieces, this is the buffer holding the literal data.
char* buffer_ = nullptr;
// For sparse arrays, this is the array of indices.
SparseIndexArray* sparse_indices_ = nullptr;
// The shape of piece. This points into the shape of the containing Literal
// (Literal::shape_).
const Shape* subshape_ = nullptr;
@ -836,6 +910,21 @@ template <typename NativeT>
return CreateR4FromArray4DWithLayout(tmp, layout);
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateSparse(
tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
tensorflow::gtl::ArraySlice<NativeT> values, bool sort) {
int64 num_elements = values.size();
int64 rank = dimensions.size();
CHECK_EQ(num_elements, indices.index_count());
CHECK_EQ(rank, indices.rank());
auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithSparseLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
indices.max_indices()));
literal->PopulateSparse(indices, values, sort);
return literal;
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR4(
std::initializer_list<std::initializer_list<
@ -1044,11 +1133,35 @@ void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
PopulateFromArray(values);
}
template <typename NativeT>
void Literal::PopulateSparse(SparseIndexArray indices,
tensorflow::gtl::ArraySlice<NativeT> values,
bool sort) {
CHECK(LayoutUtil::IsSparseArray(shape()));
int rank = ShapeUtil::Rank(shape());
CHECK_EQ(indices.rank(), rank);
int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout());
CHECK_LE(indices.max_indices(), max_elements);
int64 num_elements = values.size();
CHECK_LE(num_elements, max_elements);
CHECK_EQ(num_elements, indices.index_count());
auto root_data = root_piece().data<NativeT>();
root_data.remove_suffix(max_elements - values.size());
std::copy(values.begin(), values.end(), root_data.begin());
*this->root_piece().sparse_indices() = std::move(indices);
if (sort) {
auto root_data = this->root_piece().data<NativeT>();
root_data.remove_suffix(root_data.size() - num_elements);
this->root_piece().sparse_indices()->SortWithValues(root_data);
}
DCHECK(this->root_piece().sparse_indices()->Validate(shape()));
}
template <typename NativeT, typename FnType>
Status Literal::Populate(const FnType& generator) {
const Shape& this_shape = shape();
const int64 rank = ShapeUtil::Rank(this_shape);
TF_RET_CHECK(ShapeUtil::IsArray(this_shape));
TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
TF_RET_CHECK(this_shape.element_type() ==
primitive_util::NativeToPrimitiveType<NativeT>());
tensorflow::gtl::MutableArraySlice<NativeT> literal_data = data<NativeT>();

View File

@ -193,6 +193,34 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
ASSERT_EQ(expected, result);
}
TEST_F(LiteralUtilTest, CreateSparse) {
std::vector<int64> dimensions = {8, 8, 8};
Array2D<int64> indices = {
{3, 4, 5},
{1, 2, 3},
{2, 3, 4},
{3, 5, 6},
};
std::vector<int64> values = {7, 8, 9, 10};
auto literal = Literal::CreateSparse<int64>(
dimensions, SparseIndexArray(indices.n1() + 3, indices), values);
Array2D<int64> expected_indices = {
{1, 2, 3},
{2, 3, 4},
{3, 4, 5},
{3, 5, 6},
};
std::vector<int64> expected_values = {8, 9, 7, 10};
EXPECT_EQ(literal->sparse_indices()->data(),
tensorflow::gtl::ArraySlice<int64>(
expected_indices.data(), expected_indices.num_elements()));
EXPECT_EQ(tensorflow::gtl::ArraySlice<int64>(literal->data<int64>().data(),
expected_values.size()),
tensorflow::gtl::ArraySlice<int64>(expected_values));
}
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
// clang-format off
auto literal = Literal::CreateR4Projected<float>({

View File

@ -60,6 +60,12 @@ bool ContainsKey(const Collection& collection, const Key& key) {
return collection.find(key) != collection.end();
}
// Inserts `value` into `set`. Dies if it was already present.
template <class Set>
void InsertOrDie(Set* const set, const typename Set::value_type& value) {
CHECK(set->insert(value).second) << "duplicate value: " << value;
}
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_
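
For illustration only (not part of the header), InsertOrDie is intended for inserts that must never collide, e.g. while building a deduplicated set:

// Sketch: the second insert would CHECK-fail with "duplicate value: 42".
std::set<int64> visited;
InsertOrDie(&visited, 42);    // fine, 42 was not present
// InsertOrDie(&visited, 42); // would die: duplicate value: 42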

View File

@ -1101,6 +1101,8 @@ cc_library(
":hlo",
":hlo_evaluator",
":hlo_pass",
":tuple_util",
":while_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",
],
@ -2255,6 +2257,78 @@ cc_library(
],
)
cc_library(
name = "tuple_util",
srcs = ["tuple_util.cc"],
hdrs = ["tuple_util.h"],
deps = [
":hlo",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "tuple_util_test",
srcs = ["tuple_util_test.cc"],
deps = [
":tuple_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
cc_library(
name = "while_util",
srcs = ["while_util.cc"],
hdrs = ["while_util.h"],
deps = [
":call_inliner",
":hlo",
":tuple_util",
],
)
tf_cc_test(
name = "while_util_test",
srcs = ["while_util_test.cc"],
deps = [
":while_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
cc_library(
name = "while_loop_invariant_code_motion",
srcs = ["while_loop_invariant_code_motion.cc"],
hdrs = ["while_loop_invariant_code_motion.h"],
deps = [
":hlo",
":hlo_pass",
":tuple_util",
":while_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "while_loop_invariant_code_motion_test",
srcs = ["while_loop_invariant_code_motion_test.cc"],
deps = [
":hlo_matchers",
":while_loop_invariant_code_motion",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/core:test",
],
)
# -----------------------------------------------------------------------------
filegroup(

View File

@ -82,6 +82,10 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
return outer_->ReplaceInstruction(call_, new_root);
}
CallInliner::InlinedInstructionMap ConsumeInstructionMap() {
return std::move(subcomputation_hlo_to_new_hlo_);
}
private:
// Resolves the callee subcomputation_hlo to the new (inline) HLO in the
// caller computation, or returns a NotFound error if that subcomputation HLO
@ -112,13 +116,13 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
HloInstruction* call_;
HloComputation* outer_;
std::unordered_map<HloInstruction*, HloInstruction*>
subcomputation_hlo_to_new_hlo_;
CallInliner::InlinedInstructionMap subcomputation_hlo_to_new_hlo_;
};
} // namespace
/* static */ Status CallInliner::Inline(HloInstruction* call) {
/* static */ StatusOr<CallInliner::InlinedInstructionMap> CallInliner::Inline(
HloInstruction* call) {
TF_RET_CHECK(call->opcode() == HloOpcode::kCall)
<< "Instruction was not a call op: " << call->opcode();
const auto& callees = call->called_computations();
@ -126,7 +130,8 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
HloComputation* callee = callees[0];
// We visit the callee, cloning its body into its caller.
SubcomputationInsertionVisitor visitor(call);
return callee->Accept(&visitor);
TF_RETURN_IF_ERROR(callee->Accept(&visitor));
return visitor.ConsumeInstructionMap();
}
StatusOr<bool> CallInliner::Run(HloModule* module) {
@ -140,7 +145,7 @@ StatusOr<bool> CallInliner::Run(HloModule* module) {
VLOG(1) << "Visiting callsite: " << callsite.ToString();
if (callsite.instruction()->opcode() == HloOpcode::kCall) {
HloInstruction* call = callsite.instruction();
TF_RETURN_IF_ERROR(Inline(call));
TF_RETURN_IF_ERROR(Inline(call).status());
did_mutate = true;
}
}

View File

@ -27,8 +27,12 @@ namespace xla {
// called function, and proceed recursively.
class CallInliner : public HloPassInterface {
public:
// Inlines one call instruction.
static Status Inline(HloInstruction* call);
using InlinedInstructionMap =
std::unordered_map<HloInstruction*, HloInstruction*>;
// Inlines one call instruction. Returns a mapping from the original
// instructions to their inlined versions.
static StatusOr<InlinedInstructionMap> Inline(HloInstruction* call);
~CallInliner() override = default;
tensorflow::StringPiece name() const override { return "CallInliner"; }

View File

@ -135,7 +135,7 @@ TEST_F(CallInlinerTest, InlineWithoutRunningPass) {
HloInstruction::CreateCall(pred, {}, false_computation));
auto computation = module->AddEntryComputation(call_false_builder.Build());
TF_ASSERT_OK(CallInliner::Inline(call));
TF_ASSERT_OK(CallInliner::Inline(call).status());
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_THAT(computation->root_instruction()->control_successors(),
ElementsAre(op::Constant()));
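
A hedged sketch, not taken from the diff, of how a caller might consume the InlinedInstructionMap now returned by CallInliner::Inline (variable names are illustrative):

// Sketch: inline the call, then inspect where each callee HLO ended up.
TF_ASSERT_OK_AND_ASSIGN(CallInliner::InlinedInstructionMap inline_map,
                        CallInliner::Inline(call));
for (const auto& mapping : inline_map) {
  VLOG(1) << mapping.first->name() << " inlined as " << mapping.second->name();
}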

View File

@ -81,6 +81,7 @@ cc_library(
":conv_canonicalization",
":cpu_copy_insertion",
":cpu_executable",
":cpu_hlo_support_checker",
":cpu_instruction_fusion",
":cpu_layout_assignment",
":cpu_options",
@ -126,6 +127,7 @@ cc_library(
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/service:tuple_simplifier",
"//tensorflow/compiler/xla/service:while_loop_invariant_code_motion",
"//tensorflow/compiler/xla/service:while_loop_simplifier",
"//tensorflow/compiler/xla/service:zero_sized_hlo_elimination",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util", # fixdeps: keep
@ -873,6 +875,32 @@ tf_cc_test(
],
)
cc_library(
name = "cpu_hlo_support_checker",
srcs = ["cpu_hlo_support_checker.cc"],
hdrs = ["cpu_hlo_support_checker.h"],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "cpu_hlo_support_checker_test",
srcs = ["cpu_hlo_support_checker_test.cc"],
deps = [
":cpu_hlo_support_checker",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
],
)
# -----------------------------------------------------------------------------
filegroup(

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "llvm/IR/Function.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Object/ObjectFile.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/TargetRegistry.h"
@ -50,6 +51,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
@ -85,6 +87,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
#include "tensorflow/compiler/xla/status_macros.h"
@ -258,6 +261,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
// Optimization pipeline.
HloPassPipeline pipeline("CPU");
pipeline.AddInvariantChecker<HloVerifier>(ShapeSizeBytesFunction());
pipeline.AddPass<CpuHloSupportChecker>();
ReducePrecisionInsertion::AddPasses(
&pipeline, module->config().debug_options(),
@ -291,6 +295,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
// elimination has to come after that pass.
pipeline.AddPass<ZeroSizedHloElimination>();
pass.AddPass<WhileLoopInvariantCodeMotion>();
pass.AddPass<TupleSimplifier>();
pass.AddPass<WhileLoopSimplifier>();
pass.AddPass<HloDCE>();
@ -439,6 +444,21 @@ Status InitializeModuleHooks(
return Status::OK();
}
Status VerifyLlvmModule(const llvm::Module& llvm_module) {
XLA_SCOPED_LOGGING_TIMER("CpuCompiler - Running LLVM verifier");
std::string err;
llvm::raw_string_ostream err_stream(err);
// verifyModule() returns true if the module is broken.
TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream))
<< "Invalid LLVM IR before optimizations:\n"
<< err_stream.str()
<< "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
"Rerun with --xla_dump_ir_to to get the IR. ";
return Status::OK();
}
} // namespace
StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
@ -627,6 +647,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
if (embed_ir_in_executable) {
ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
}
TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));
// JIT compile the LLVM IR module to in-memory machine code.
jit->AddModule(std::move(llvm_module));
@ -704,6 +725,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
if (embed_ir_in_executable) {
ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
}
TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));
XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module));
@ -875,6 +897,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
&module_sequence.at(computation)));
CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name));
TF_RETURN_IF_ERROR(VerifyLlvmModule(llvm_module));
ModuleHook pre_optimization_ir_dump_hook;
ModuleHook post_optimization_ir_dump_hook;

View File

@ -0,0 +1,48 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
StatusOr<bool> CpuHloSupportChecker::Run(HloModule* module) {
for (auto* computation : module->computations()) {
for (const auto& instruction : computation->instructions()) {
TF_RETURN_IF_ERROR(
ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
instruction->shape(),
[&instruction](const Shape& subshape, const ShapeIndex&) {
if (LayoutUtil::IsSparseArray(subshape)) {
return xla::Unimplemented(
"CPU backend does not support HLO instruction %s with shape "
"containing a sparse layout: %s",
instruction->ToString().c_str(),
ShapeUtil::HumanStringWithLayout(instruction->shape())
.c_str());
}
return Status::OK();
}));
}
}
return false;
}
} // namespace xla

View File

@ -0,0 +1,42 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
// This pass should run early in the HLO pipeline and checks for HLO constructs
// which are not supported by the CPU backend and cannot be removed via HLO
// transformations (e.g., sparse layouts).
class CpuHloSupportChecker : public HloPassInterface {
public:
CpuHloSupportChecker() = default;
~CpuHloSupportChecker() override = default;
tensorflow::StringPiece name() const override {
return "cpu_hlo_support_checker";
}
// Note: always returns false (no instructions are ever modified by this
// pass).
StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_

View File

@ -0,0 +1,72 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
using ::testing::HasSubstr;
class CpuHloSupportCheckerTest : public HloTestBase {
protected:
CpuHloSupportChecker& checker() { return checker_; }
private:
CpuHloSupportChecker checker_;
};
TEST_F(CpuHloSupportCheckerTest, Add) {
HloComputation::Builder builder(TestName());
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape, "param0"));
HloInstruction* param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, scalar_shape, "param1"));
builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape, HloOpcode::kAdd, param0, param1));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
TF_ASSERT_OK(checker().Run(module.get()).status());
}
TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
HloComputation::Builder builder(TestName());
const Shape sparse_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {10}, 2);
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, sparse_shape, "param0"));
HloInstruction* param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, sparse_shape, "param1"));
builder.AddInstruction(HloInstruction::CreateBinary(
sparse_shape, HloOpcode::kAdd, param0, param1));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
Status status = checker().Run(module.get()).status();
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
EXPECT_THAT(status.error_message(),
HasSubstr("CPU backend does not support"));
EXPECT_THAT(status.error_message(),
HasSubstr(ShapeUtil::HumanStringWithLayout(sparse_shape)));
}
} // namespace
} // namespace xla

View File

@ -437,6 +437,7 @@ cc_library(
":fusion_merger",
":gpu_copy_insertion",
":gpu_executable",
":gpu_hlo_support_checker",
":gpu_layout_assignment",
":hlo_schedule",
":instruction_fusion",
@ -610,6 +611,32 @@ tf_cc_test(
],
)
cc_library(
name = "gpu_hlo_support_checker",
srcs = ["gpu_hlo_support_checker.cc"],
hdrs = ["gpu_hlo_support_checker.h"],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "gpu_hlo_support_checker_test",
srcs = ["gpu_hlo_support_checker_test.cc"],
deps = [
":gpu_hlo_support_checker",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
],
)
# -----------------------------------------------------------------------------
filegroup(

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
@ -39,6 +40,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
@ -137,6 +139,7 @@ tensorflow::Status OptimizeHloModule(
{
HloPassPipeline pipeline("optimization");
pipeline.AddInvariantChecker<HloVerifier>(shape_size_function);
pipeline.AddPass<GpuHloSupportChecker>();
ReducePrecisionInsertion::AddPasses(
&pipeline, hlo_module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);
@ -476,6 +479,20 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
entry_computation->root_instruction()->Accept(&ir_emitter));
}
{
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier");
std::string err;
llvm::raw_string_ostream err_stream(err);
// verifyModule() returns true if the module is broken.
TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream))
<< "Invalid LLVM IR before optimizations:\n"
<< err_stream.str()
<< "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
"Rerun with --xla_dump_ir_to to get the IR. ";
}
if (user_pre_optimization_hook_) {
TF_CHECK_OK(user_pre_optimization_hook_(llvm_module));
}

View File

@ -0,0 +1,48 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
StatusOr<bool> GpuHloSupportChecker::Run(HloModule* module) {
for (auto* computation : module->computations()) {
for (const auto& instruction : computation->instructions()) {
TF_RETURN_IF_ERROR(
ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
instruction->shape(),
[&instruction](const Shape& subshape, const ShapeIndex&) {
if (LayoutUtil::IsSparseArray(subshape)) {
return xla::Unimplemented(
"GPU backend does not support HLO instruction %s with shape "
"containing a sparse layout: %s",
instruction->ToString().c_str(),
ShapeUtil::HumanStringWithLayout(instruction->shape())
.c_str());
}
return Status::OK();
}));
}
}
return false;
}
} // namespace xla

View File

@ -0,0 +1,42 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
// This pass should run early in the HLO pipeline and checks for HLO constructs
// which are not supported by the GPU backend and cannot be removed via HLO
// transformations (e.g., sparse layouts).
class GpuHloSupportChecker : public HloPassInterface {
public:
GpuHloSupportChecker() = default;
~GpuHloSupportChecker() override = default;
tensorflow::StringPiece name() const override {
return "gpu_hlo_support_checker";
}
// Note: always returns false (no instructions are ever modified by this
// pass).
StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_

View File

@ -0,0 +1,72 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
using ::testing::HasSubstr;
class GpuHloSupportCheckerTest : public HloTestBase {
protected:
GpuHloSupportChecker& checker() { return checker_; }
private:
GpuHloSupportChecker checker_;
};
TEST_F(GpuHloSupportCheckerTest, Add) {
HloComputation::Builder builder(TestName());
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape, "param0"));
HloInstruction* param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, scalar_shape, "param1"));
builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape, HloOpcode::kAdd, param0, param1));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
TF_ASSERT_OK(checker().Run(module.get()).status());
}
TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
HloComputation::Builder builder(TestName());
const Shape sparse_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {10}, 2);
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, sparse_shape, "param0"));
HloInstruction* param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, sparse_shape, "param1"));
builder.AddInstruction(HloInstruction::CreateBinary(
sparse_shape, HloOpcode::kAdd, param0, param1));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
Status status = checker().Run(module.get()).status();
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
EXPECT_THAT(status.error_message(),
HasSubstr("GPU backend does not support"));
EXPECT_THAT(status.error_message(),
HasSubstr(ShapeUtil::HumanStringWithLayout(sparse_shape)));
}
} // namespace
} // namespace xla

View File

@ -0,0 +1,61 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/tuple_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
/*static*/ HloInstruction* TupleUtil::ExtractPrefix(HloInstruction* input_tuple,
int64 elements) {
CHECK(ShapeUtil::IsTuple(input_tuple->shape()));
HloComputation* computation = input_tuple->parent();
const Shape& input_shape = input_tuple->shape();
std::vector<HloInstruction*> tuple_elements;
tuple_elements.reserve(elements);
for (int i = 0; i < elements; i++) {
tuple_elements.push_back(
computation->AddInstruction(HloInstruction::CreateGetTupleElement(
input_shape.tuple_shapes(i), input_tuple, i)));
}
return computation->AddInstruction(
HloInstruction::CreateTuple(tuple_elements));
}
/*static*/ HloInstruction* TupleUtil::AppendSuffix(
HloInstruction* input_tuple,
tensorflow::gtl::ArraySlice<HloInstruction*> trailing_values) {
CHECK(ShapeUtil::IsTuple(input_tuple->shape()));
HloComputation* computation = input_tuple->parent();
const Shape& input_shape = input_tuple->shape();
std::vector<HloInstruction*> tuple_elements;
tuple_elements.reserve(input_shape.tuple_shapes_size());
for (int i = 0; i < input_shape.tuple_shapes_size(); i++) {
tuple_elements.push_back(
computation->AddInstruction(HloInstruction::CreateGetTupleElement(
input_shape.tuple_shapes(i), input_tuple, i)));
}
tuple_elements.insert(tuple_elements.end(), trailing_values.begin(),
trailing_values.end());
return computation->AddInstruction(
HloInstruction::CreateTuple(tuple_elements));
}
} // namespace xla

View File

@ -0,0 +1,45 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_UTIL_H_
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
namespace xla {
class TupleUtil {
public:
// Generates HLO instructions to get a prefix tuple from `input_tuple` (which
// must be of tuple shape) of length `elements`. Returns the root of the
// graph of instructions generated.
//
// The instructions are generated into the computation containing
// `input_tuple`.
static HloInstruction* ExtractPrefix(HloInstruction* input_tuple,
int64 elements);
// Generates HLO instructions to create a tuple that consists of the values in
// `trailing_values` appended to `input_tuple` (which must be of tuple shape).
// Returns the root of the graph of instructions generated.
//
// The instructions are generated into the computation containing
// `input_tuple`.
static HloInstruction* AppendSuffix(
HloInstruction* input_tuple,
tensorflow::gtl::ArraySlice<HloInstruction*> trailing_values);
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_UTIL_H_

View File

@ -0,0 +1,81 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/tuple_util.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace xla {
namespace {
namespace op = ::xla::testing::opcode_matchers;
StatusOr<std::unique_ptr<HloModule>> GetParsedModule(
HloComputation** entry_computation, HloInstruction** param0,
HloInstruction** param1) {
const char* const hlo_string = R"(
HloModule Module
ENTRY entry {
p0 = (f32[32,32]{1,0},f32[32,32]{1,0},f32[32,32]{1,0}) parameter(0)
ROOT p1 = f32[32,32]{1,0} parameter(1)
}
)";
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
tools::Parse(hlo_string));
*entry_computation = module->entry_computation();
*param0 = (*entry_computation)->parameter_instruction(0);
*param1 = (*entry_computation)->parameter_instruction(1);
return std::move(module);
}
TEST(TupleUtilTest, ExtractPrefix) {
HloInstruction *param0, *param1;
HloComputation* entry_computation;
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
GetParsedModule(&entry_computation, &param0, &param1));
HloInstruction* prefix = TupleUtil::ExtractPrefix(param0, 2);
EXPECT_THAT(prefix, op::Tuple(op::GetTupleElement(op::Parameter(0), 0),
op::GetTupleElement(op::Parameter(0), 1)));
}
TEST(TupleUtilTest, AppendSuffix) {
HloInstruction *param0, *param1;
HloComputation* entry_computation;
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
GetParsedModule(&entry_computation, &param0, &param1));
HloInstruction* with_suffix =
TupleUtil::AppendSuffix(param0, {param1, param1});
EXPECT_THAT(with_suffix, op::Tuple(op::GetTupleElement(op::Parameter(0), 0),
op::GetTupleElement(op::Parameter(0), 1),
op::GetTupleElement(op::Parameter(0), 2),
op::Parameter(1), op::Parameter(1)));
}
} // namespace
} // namespace xla

View File

@ -0,0 +1,296 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
namespace xla {
using tensorflow::gtl::FlatMap;
using tensorflow::gtl::FlatSet;
using tensorflow::gtl::InlinedVector;
// Copies `to_hoist` to the computation containing `while_instr`, hoisting its
// operands as needed. All of its transitive operands are expected to be either
// in `hoisted_instructions` or `unhoisted_invariant_instructions`. This
// function hoists the operands in `unhoisted_invariant_instructions` and moves
// them into `hoisted_instructions`.
static void CreateLoopInvariantCopy(
FlatMap<HloInstruction*, HloInstruction*>* hoisted_instructions,
FlatSet<HloInstruction*>* unhoisted_invariant_instructions,
HloInstruction* while_instr, HloInstruction* to_hoist) {
HloComputation* parent_of_while = while_instr->parent();
HloComputation* while_body = while_instr->while_body();
struct DFSFrame {
HloInstruction* instruction;
int64 operand_index;
};
InlinedVector<DFSFrame, 8> dfs_stack;
dfs_stack.push_back({to_hoist, 0});
HloInstruction* while_body_param = while_body->parameter_instruction(0);
HloInstruction* while_operand = while_instr->mutable_operand(0);
do {
DFSFrame* frame = &dfs_stack.back();
if (frame->operand_index == frame->instruction->operand_count()) {
HloInstruction* old_instruction = frame->instruction;
// All of the operands for old_instruction have been cloned, so it is
// time to clone old_instruction itself.
auto get_new_operand = [&](HloInstruction* old_operand) {
return old_operand == while_body_param
? while_operand
: FindOrDie(*hoisted_instructions, old_operand);
};
InlinedVector<HloInstruction*, 4> new_operands;
c_transform(old_instruction->operands(), std::back_inserter(new_operands),
get_new_operand);
HloInstruction* new_instruction =
parent_of_while->AddInstruction(old_instruction->CloneWithNewOperands(
old_instruction->shape(), new_operands));
InsertOrDie(hoisted_instructions, old_instruction, new_instruction);
// Approximately half of the instructions that would normally be present
// in unhoisted_invariant_instructions are constants. We save a bit of
// compile time by not putting these in the hashtable.
CHECK_EQ(unhoisted_invariant_instructions->erase(old_instruction),
to_hoist != old_instruction &&
old_instruction->opcode() != HloOpcode::kConstant);
dfs_stack.pop_back();
continue;
}
HloInstruction* next_operand =
frame->instruction->mutable_operand(frame->operand_index++);
if (hoisted_instructions->count(next_operand) ||
next_operand == while_body_param) {
continue;
}
dfs_stack.push_back({next_operand, 0});
} while (!dfs_stack.empty());
}
// Returns true if `instruction` is worth hoisting only if it lets us hoist some
// instruction using it. The rationale is that hoisting these instructions will
// prevent simplification and fusion in the while body.
static bool NotWorthHoistingIndividually(const HloInstruction& instruction) {
switch (instruction.opcode()) {
default:
return false;
case HloOpcode::kBitcast:
case HloOpcode::kBroadcast:
case HloOpcode::kConstant:
case HloOpcode::kReverse:
case HloOpcode::kSlice:
case HloOpcode::kTuple:
return true;
case HloOpcode::kTranspose:
return ShapeUtil::TransposeIsBitcast(
/*input_shape=*/instruction.operand(0)->shape(),
/*output_shape=*/instruction.shape(), instruction.dimensions());
case HloOpcode::kReshape:
return ShapeUtil::ReshapeIsBitcast(
/*input_shape=*/instruction.operand(0)->shape(),
/*output_shape=*/instruction.shape());
}
}
// Populates `gte_set` with the GetTupleElement instructions in `while_body`
// that access elements in the parameter tuple that don't change across
// iterations. Assumes `while_body` is the body computation of the while loop
// in question.
static void GatherInvariantGTEs(HloComputation* while_body,
FlatSet<HloInstruction*>* gte_set) {
const HloInstruction::InstructionVector root_operands =
while_body->root_instruction()->operands();
for (int i = 0; i < root_operands.size(); i++) {
HloInstruction* instr = root_operands[i];
if (instr->opcode() == HloOpcode::kGetTupleElement &&
instr->tuple_index() == i &&
instr->operand(0) == while_body->parameter_instruction(0) &&
ShapeUtil::IsArray(instr->shape())) {
InsertOrDie(gte_set, instr);
}
}
}
static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
HloInstruction* while_instr) {
auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false);
if (!ShapeUtil::IsTuple(while_instr->shape())) {
// This restriction leaves one interesting pattern on the table:
//
// while_body(f32[1024, 1024] %param) {
// %value = expensive_op(%param)
// outfeed(%value)
// ROOT = %param
// }
//
// If we see that pattern in the while, instead of generalizing this
// algorithm to work with non-tuples, we should instead add a pass that
// canonicalizes while loops like the above to use a tuple state.
return false;
}
string while_instr_name = while_instr->ToString(print_no_metadata);
VLOG(2) << "Trying to hoist from " << while_instr_name;
HloComputation* while_body = while_instr->while_body();
// Maps instructions in the while body to instructions hoisted outside the
// while that compute the same value.
FlatMap<HloInstruction*, HloInstruction*> hoisted_instructions;
// Contains instructions that can be legally hoisted, but were deemed to be
// unprofitable to be hoisted alone by NotWorthHoistingIndividually. When we
// hoist an instruction in this set, we move it from
// unhoisted_invariant_instructions to hoisted_instructions.
FlatSet<HloInstruction*> unhoisted_invariant_instructions;
// Invariant GTE's axiomatically satisfy the constraints for
// unhoisted_invariant_instructions -- they can be legally hoisted, but there
// is no benefit to hoisting them unless something that uses them is also
// hoisted.
GatherInvariantGTEs(while_body, &unhoisted_invariant_instructions);
if (unhoisted_invariant_instructions.empty()) {
// There are no obviously loop invariant elements in the state being
// threaded through the while loop so give up. In theory this precondition
// is too strong -- we could have code that e.g. permutes the elements in
// the while state but uses a select to pick the same value on every
// iteration.
return false;
}
// instructions_to_replace[i] is hoisted into a loop invariant instruction
// replacement_instructions[i].
std::vector<HloInstruction*> instructions_to_replace;
std::vector<HloInstruction*> replacement_instructions;
for (auto* instruction : while_body->MakeInstructionPostOrder()) {
if (instruction->HasSideEffect() ||
instruction->opcode() == HloOpcode::kParameter ||
!instruction->control_predecessors().empty() ||
!instruction->control_successors().empty()) {
continue;
}
auto is_invariant = [&](HloInstruction* op) {
return hoisted_instructions.find(op) != hoisted_instructions.end() ||
unhoisted_invariant_instructions.count(op) ||
op->opcode() == HloOpcode::kConstant;
};
if (!c_all_of(instruction->operands(), is_invariant)) {
continue;
}
if (NotWorthHoistingIndividually(*instruction)) {
VLOG(2) << "Adding " << instruction->ToString(print_no_metadata)
<< " to unhoisted invariant set.";
// Approximately half of the instructions that reach this point are
// constants. We save a bit of compile time by not putting these in the
// hashtable.
if (instruction->opcode() != HloOpcode::kConstant) {
InsertOrDie(&unhoisted_invariant_instructions, instruction);
}
continue;
}
VLOG(2) << "Hoisting " << instruction->ToString(print_no_metadata);
CreateLoopInvariantCopy(&hoisted_instructions,
&unhoisted_invariant_instructions, while_instr,
instruction);
instructions_to_replace.push_back(instruction);
replacement_instructions.push_back(
FindOrDie(hoisted_instructions, instruction));
}
if (instructions_to_replace.empty()) {
return false;
}
TF_ASSIGN_OR_RETURN(
WhileUtil::MakeInstructionsLiveInResult live_in_instructions_result,
WhileUtil::MakeInstructionsLiveIn(while_instr, replacement_instructions));
HloComputation* new_while_body =
live_in_instructions_result.new_while_instr->while_body();
for (int i = 0; i < instructions_to_replace.size(); i++) {
HloInstruction* instruction_to_replace_in_new_while =
FindOrDie(live_in_instructions_result.while_body_instruction_map,
instructions_to_replace[i]);
TF_RETURN_IF_ERROR(new_while_body->ReplaceInstruction(
instruction_to_replace_in_new_while,
live_in_instructions_result.while_body_live_in_values[i]));
}
VLOG(1) << "Hoisted " << instructions_to_replace.size()
<< " instructions from " << while_instr_name;
return true;
}
StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) {
bool changed = false;
std::vector<HloInstruction*> while_instrs;
for (auto* comp : module->computations()) {
c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
[](const HloInstruction* instr) {
return instr->opcode() == HloOpcode::kWhile;
});
}
for (HloInstruction* while_instr : while_instrs) {
// Right now we only hoist computations from the while body, but
// TryHoistingInvariantInstructionsFromWhileBody can be generalized to
// optimize the condition computation too, if needed.
//
// The transform we do here is a pessimization for while loops that execute
// zero times*, but at this time we expect those to be rare. If this
// becomes a problem we can consider using the conditional HLO to avoid
// doing extra work for while loops with zero trip count.
//
// * We delete while loops that have a zero trip count, so this would have
// to be a while loop with a somewhat opaque condition expression.
TF_ASSIGN_OR_RETURN(
bool result,
TryHoistingInvariantInstructionsFromWhileBody(while_instr));
changed |= result;
}
return changed;
}
} // namespace xla

View File

@ -0,0 +1,39 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_INVARIANT_CODE_MOTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_INVARIANT_CODE_MOTION_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
// HLO pass that rewrites while loops to hoist loop invariant instructions in
// the while body into the computation that contains the while instruction.
class WhileLoopInvariantCodeMotion : public HloPassInterface {
public:
~WhileLoopInvariantCodeMotion() override = default;
tensorflow::StringPiece name() const override {
return "while-loop-invariant-code-motion";
}
StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_INVARIANT_CODE_MOTION_H_
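For orientation, a minimal sketch of how this pass might be wired into a pass pipeline (the helper below is illustrative only and not part of this change; it assumes the existing HloPassPipeline API):
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
namespace xla {
// Hypothetical helper: runs while-loop invariant code motion over `module`
// and returns true if any computation was changed.
StatusOr<bool> RunWhileLoopLicm(HloModule* module) {
  HloPassPipeline pipeline("while-licm");
  pipeline.AddPass<WhileLoopInvariantCodeMotion>();
  return pipeline.Run(module);
}
}  // namespace xla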

View File

@ -0,0 +1,442 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
namespace op = xla::testing::opcode_matchers;
class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase {
public:
// Makes a computation that has one parameter of the given shape and always
// returns PRED[]{true}. This is useful as a dummy loop condition.
HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape,
HloModule* module);
};
static void FindOnlyWhileInstruction(HloComputation* computation,
HloInstruction** while_instruction) {
*while_instruction = nullptr;
for (auto* instr : computation->instructions()) {
if (instr->opcode() == HloOpcode::kWhile) {
ASSERT_EQ(*while_instruction, nullptr);
*while_instruction = instr;
}
}
ASSERT_NE(*while_instruction, nullptr);
}
HloComputation* WhileLoopInvariantCodeMotionTest::MakeAlwaysTrueComputation(
const Shape& param_shape, HloModule* module) {
HloComputation::Builder builder(TestName() + ".always_true");
builder.AddInstruction(
HloInstruction::CreateParameter(0, param_shape, "param"));
builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
return module->AddEmbeddedComputation(builder.Build());
}
TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) {
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
Shape while_shape =
ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});
HloComputation* while_body = [&]() {
HloComputation::Builder builder(TestName() + ".while_body");
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, while_shape, "param"));
HloInstruction* gte_0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
HloInstruction* gte_1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
HloInstruction* add_result =
builder.AddInstruction(HloInstruction::CreateBinary(
scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
builder.AddInstruction(
HloInstruction::CreateTuple({gte_0, gte_1, add_result}));
return module().AddEmbeddedComputation(builder.Build());
}();
HloComputation::Builder builder(TestName());
auto* init_value = builder.AddInstruction(
HloInstruction::CreateParameter(0, while_shape, "init_value"));
builder.AddInstruction(HloInstruction::CreateWhile(
while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
while_body, init_value));
HloComputation* entry_computation =
module().AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
WhileLoopInvariantCodeMotion{}.Run(&module()));
EXPECT_TRUE(simplified_loop);
HloInstruction* transformed_while;
FindOnlyWhileInstruction(entry_computation, &transformed_while);
EXPECT_THAT(entry_computation->instructions(), Contains(op::Add()));
EXPECT_THAT(transformed_while->while_body()->instructions(),
Each(Not(op::Add())));
}
TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) {
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
Shape while_shape =
ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});
HloComputation* while_body = [&]() {
HloComputation::Builder builder(TestName() + ".while_body");
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, while_shape, "param"));
HloInstruction* gte_0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
HloInstruction* gte_1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
HloInstruction* gte_2_loop_variant = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 2));
HloInstruction* add_result =
builder.AddInstruction(HloInstruction::CreateBinary(
scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
HloInstruction* mul_result =
builder.AddInstruction(HloInstruction::CreateBinary(
scalar_s32, HloOpcode::kMultiply, add_result, gte_1));
HloInstruction* negate_result =
builder.AddInstruction(HloInstruction::CreateUnary(
scalar_s32, HloOpcode::kNegate, mul_result));
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<int32>(4)));
HloInstruction* sub_result =
builder.AddInstruction(HloInstruction::CreateBinary(
scalar_s32, HloOpcode::kSubtract, negate_result, constant));
HloInstruction* divide_result =
builder.AddInstruction(HloInstruction::CreateBinary(
scalar_s32, HloOpcode::kDivide, sub_result, gte_2_loop_variant));
builder.AddInstruction(
HloInstruction::CreateTuple({gte_0, gte_1, divide_result}));
return module().AddEmbeddedComputation(builder.Build());
}();
HloComputation::Builder builder(TestName());
auto* init_value = builder.AddInstruction(
HloInstruction::CreateParameter(0, while_shape, "init_value"));
builder.AddInstruction(HloInstruction::CreateWhile(
while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
while_body, init_value));
HloComputation* entry_computation =
module().AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
WhileLoopInvariantCodeMotion{}.Run(&module()));
EXPECT_TRUE(simplified_loop);
HloInstruction* transformed_while;
FindOnlyWhileInstruction(entry_computation, &transformed_while);
EXPECT_THAT(entry_computation->instructions(),
AllOf(Contains(op::Add()), Contains(op::Multiply()),
Contains(op::Negate()), Contains(op::Subtract()),
Contains(op::Constant()),
// The division had a loop-varying operand, so it had better
// not be hoisted.
Not(Contains(op::Divide()))));
EXPECT_THAT(transformed_while->while_body()->instructions(),
Each(Not(AnyOf(op::Add(), op::Multiply(), op::Negate(),
op::Subtract(), op::Constant()))));
EXPECT_THAT(transformed_while->while_body()->instructions(),
Contains(op::Divide()));
}
TEST_F(WhileLoopInvariantCodeMotionTest,
DontHoistTriviallyLoopVaryingComputation) {
// Basic negative test: the add expression is not loop invariant.
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
HloComputation* while_body = [&]() {
HloComputation::Builder builder(TestName() + ".while_body");
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, while_shape, "param"));
HloInstruction* gte_0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
HloInstruction* gte_1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
HloInstruction* add_result =
builder.AddInstruction(HloInstruction::CreateBinary(
scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
builder.AddInstruction(HloInstruction::CreateTuple({gte_0, add_result}));
return module().AddEmbeddedComputation(builder.Build());
}();
HloComputation::Builder builder(TestName());
auto* init_value = builder.AddInstruction(
HloInstruction::CreateParameter(0, while_shape, "init_value"));
auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
while_body, init_value));
module().AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
WhileLoopInvariantCodeMotion{}.Run(&module()));
EXPECT_FALSE(simplified_loop);
EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add()));
}
TEST_F(WhileLoopInvariantCodeMotionTest,
DontHoistLoopVaryingComputationWithAlternatingTuples) {
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
Shape while_shape =
ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});
HloComputation* while_body = [&]() {
HloComputation::Builder builder(TestName() + ".while_body");
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, while_shape, "param"));
HloInstruction* gte_0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
HloInstruction* gte_1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
HloInstruction* add_result =
builder.AddInstruction(HloInstruction::CreateBinary(
scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
builder.AddInstruction(
HloInstruction::CreateTuple({gte_1, gte_0, add_result}));
return module().AddEmbeddedComputation(builder.Build());
}();
HloComputation::Builder builder(TestName());
auto* init_value = builder.AddInstruction(
HloInstruction::CreateParameter(0, while_shape, "init_value"));
auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
while_body, init_value));
module().AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
WhileLoopInvariantCodeMotion{}.Run(&module()));
EXPECT_FALSE(simplified_loop);
EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add()));
}
TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) {
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
HloComputation* while_body = [&]() {
HloComputation::Builder builder(TestName() + ".while_body");
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, while_shape, "param"));
HloInstruction* gte_0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
HloInstruction* gte_1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
builder.AddInstruction(
HloInstruction::CreateOutfeed(scalar_s32, gte_0, ""));
builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1}));
return module().AddEmbeddedComputation(builder.Build());
}();
HloComputation::Builder builder(TestName());
auto* init_value = builder.AddInstruction(
HloInstruction::CreateParameter(0, while_shape, "init_value"));
auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
while_body, init_value));
module().AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
WhileLoopInvariantCodeMotion{}.Run(&module()));
EXPECT_FALSE(simplified_loop);
EXPECT_THAT(while_inst->while_body()->instructions(),
Contains(op::Outfeed()));
}
TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) {
// The bitcast's user, an outfeed, can't be hoisted, so don't hoist the
// bitcast either.
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
auto scalar_f32 = ShapeUtil::MakeShape(F32, {});
Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
HloComputation* while_body = [&]() {
HloComputation::Builder builder(TestName() + ".while_body");
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, while_shape, "param"));
HloInstruction* gte_0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
HloInstruction* gte_1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
HloInstruction* bitcast_inst = builder.AddInstruction(
HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0));
builder.AddInstruction(
HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, ""));
builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1}));
return module().AddEmbeddedComputation(builder.Build());
}();
HloComputation::Builder builder(TestName());
auto* init_value = builder.AddInstruction(
HloInstruction::CreateParameter(0, while_shape, "init_value"));
auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
while_body, init_value));
module().AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
WhileLoopInvariantCodeMotion{}.Run(&module()));
EXPECT_FALSE(simplified_loop);
EXPECT_THAT(while_inst->while_body()->instructions(),
Contains(op::Outfeed()));
EXPECT_THAT(while_inst->while_body()->instructions(),
Contains(op::Bitcast()));
}
TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) {
// The bitcast's user can be hoisted, so hoist the bitcast too.
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
auto scalar_f32 = ShapeUtil::MakeShape(F32, {});
Shape while_shape =
ShapeUtil::MakeTupleShape({scalar_s32, scalar_f32, scalar_f32});
HloComputation* while_body = [&]() {
HloComputation::Builder builder(TestName() + ".while_body");
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, while_shape, "param"));
HloInstruction* gte_0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
HloInstruction* gte_1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_f32, param, 1));
HloInstruction* bitcast_inst = builder.AddInstruction(
HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0));
HloInstruction* add_inst =
builder.AddInstruction(HloInstruction::CreateBinary(
scalar_f32, HloOpcode::kAdd, bitcast_inst, gte_1));
builder.AddInstruction(
HloInstruction::CreateTuple({gte_0, gte_1, add_inst}));
return module().AddEmbeddedComputation(builder.Build());
}();
HloComputation::Builder builder(TestName());
auto* init_value = builder.AddInstruction(
HloInstruction::CreateParameter(0, while_shape, "init_value"));
builder.AddInstruction(HloInstruction::CreateWhile(
while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
while_body, init_value));
HloComputation* entry_computation =
module().AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
WhileLoopInvariantCodeMotion{}.Run(&module()));
EXPECT_TRUE(simplified_loop);
HloInstruction* transformed_while;
FindOnlyWhileInstruction(entry_computation, &transformed_while);
EXPECT_THAT(transformed_while->while_body()->instructions(),
Each(Not(op::Add())));
EXPECT_THAT(transformed_while->while_body()->instructions(),
Each(Not(op::Bitcast())));
EXPECT_THAT(entry_computation->instructions(), Contains(op::Add()));
EXPECT_THAT(entry_computation->instructions(), Contains(op::Bitcast()));
}
TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistControlDependencies) {
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
Shape while_shape =
ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});
HloComputation* while_body;
{
HloComputation::Builder builder(TestName() + ".while_body");
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, while_shape, "param"));
HloInstruction* gte_0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
HloInstruction* gte_1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
HloInstruction* add_result =
builder.AddInstruction(HloInstruction::CreateBinary(
scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
TF_ASSERT_OK(param->AddControlDependencyTo(add_result));
builder.AddInstruction(
HloInstruction::CreateTuple({gte_0, gte_1, add_result}));
while_body = module().AddEmbeddedComputation(builder.Build());
}
HloComputation::Builder builder(TestName());
auto* init_value = builder.AddInstruction(
HloInstruction::CreateParameter(0, while_shape, "init_value"));
builder.AddInstruction(HloInstruction::CreateWhile(
while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
while_body, init_value));
module().AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
WhileLoopInvariantCodeMotion{}.Run(&module()));
EXPECT_FALSE(simplified_loop);
}
TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) {
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
HloComputation* while_body = [&]() {
HloComputation::Builder builder(TestName() + ".passthrough");
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, while_shape, "param"));
HloComputation* result = module().AddEmbeddedComputation(builder.Build());
result->AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
return result;
}();
HloComputation::Builder builder(TestName());
auto* init_value = builder.AddInstruction(
HloInstruction::CreateParameter(0, while_shape, "init_value"));
builder.AddInstruction(HloInstruction::CreateWhile(
while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
while_body, init_value));
module().AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
WhileLoopInvariantCodeMotion{}.Run(&module()));
EXPECT_FALSE(simplified_loop);
}
} // namespace
} // namespace xla

View File

@ -595,7 +595,9 @@ static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) {
auto call_op = computation->AddInstruction(HloInstruction::CreateCall(
while_op->shape(), while_op->operands(), while_op->while_body()));
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op));
TF_RETURN_IF_ERROR(CallInliner::Inline(call_op));
TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
CallInliner::Inline(call_op));
(void)inlined_instructions_map;
return true;
}
return false;

View File

@ -0,0 +1,140 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"
namespace xla {
static StatusOr<HloComputation*> WidenWhileCondition(
HloComputation* narrow_condition, const Shape& wide_shape) {
const Shape& narrow_shape =
narrow_condition->parameter_instruction(0)->shape();
HloComputation* wide_while_cond = [&]() {
HloComputation::Builder builder(
tensorflow::strings::StrCat("wide.", narrow_condition->name()));
builder.AddInstruction(
HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
// This is needed so that the root instruction is shaped as a PRED[] -- we
// need to get this right to begin with since we can't mutate the type of
// the root instruction later. We later change the root instruction to
// something more appropriate.
builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
return narrow_condition->parent()->AddEmbeddedComputation(builder.Build());
}();
HloInstruction* truncated_parameter =
TupleUtil::ExtractPrefix(wide_while_cond->parameter_instruction(0),
narrow_shape.tuple_shapes_size());
HloInstruction* call_narrow_cond = wide_while_cond->AddInstruction(
HloInstruction::CreateCall(ShapeUtil::MakeShape(PRED, {}),
{truncated_parameter}, narrow_condition));
wide_while_cond->set_root_instruction(call_narrow_cond);
TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_cond).status());
return wide_while_cond;
}
static StatusOr<std::pair<HloComputation*, CallInliner::InlinedInstructionMap>>
WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) {
const Shape& narrow_shape = narrow_body->parameter_instruction(0)->shape();
HloComputation* wide_while_body = [&]() {
HloComputation::Builder builder(
tensorflow::strings::StrCat("wide.", narrow_body->name()));
builder.AddInstruction(
HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
return narrow_body->parent()->AddEmbeddedComputation(builder.Build());
}();
HloInstruction* wide_parameter = wide_while_body->parameter_instruction(0);
HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
wide_parameter, narrow_shape.tuple_shapes_size());
HloInstruction* call_narrow_body =
wide_while_body->AddInstruction(HloInstruction::CreateCall(
narrow_shape, {truncated_parameter}, narrow_body));
std::vector<HloInstruction*> live_through_values;
for (int i = narrow_shape.tuple_shapes_size();
i < wide_shape.tuple_shapes_size(); i++) {
live_through_values.push_back(
wide_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
wide_shape.tuple_shapes(i), wide_parameter, i)));
}
wide_while_body->set_root_instruction(
TupleUtil::AppendSuffix(call_narrow_body, live_through_values));
TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
CallInliner::Inline(call_narrow_body));
return {{wide_while_body, std::move(inlined_instructions_map)}};
}
/*static*/ StatusOr<WhileUtil::MakeInstructionsLiveInResult>
WhileUtil::MakeInstructionsLiveIn(
HloInstruction* while_instr,
tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
CHECK(ShapeUtil::IsTuple(while_instr->shape()));
int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size();
Shape new_while_shape = while_instr->shape();
for (auto* instruction : instructions) {
*new_while_shape.add_tuple_shapes() = instruction->shape();
}
TF_ASSIGN_OR_RETURN(
HloComputation * new_while_condition,
WidenWhileCondition(while_instr->while_condition(), new_while_shape));
HloComputation* new_while_body;
CallInliner::InlinedInstructionMap inlined_instructions_map;
TF_ASSIGN_OR_RETURN(
std::tie(new_while_body, inlined_instructions_map),
WidenWhileBody(while_instr->while_body(), new_while_shape));
HloInstruction* new_while_init =
TupleUtil::AppendSuffix(while_instr->mutable_operand(0), instructions);
HloComputation* containing_computation = while_instr->parent();
HloInstruction* new_while = containing_computation->AddInstruction(
HloInstruction::CreateWhile(new_while_shape, new_while_condition,
new_while_body, new_while_init));
TF_RETURN_IF_ERROR(containing_computation->ReplaceInstruction(
while_instr, TupleUtil::ExtractPrefix(
new_while, while_instr->shape().tuple_shapes_size())));
HloInstruction* while_body_param = new_while_body->parameter_instruction(0);
std::vector<HloInstruction*> live_in_instructions;
for (int64 i = elements_in_old_while_shape;
i < new_while_shape.tuple_shapes_size(); i++) {
live_in_instructions.push_back(
new_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
instructions[i - elements_in_old_while_shape]->shape(),
while_body_param, i)));
}
WhileUtil::MakeInstructionsLiveInResult result;
result.new_while_instr = new_while;
result.while_body_live_in_values = std::move(live_in_instructions);
result.while_body_instruction_map = std::move(inlined_instructions_map);
return std::move(result);
}
} // namespace xla

View File

@ -0,0 +1,58 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
namespace xla {
class WhileUtil {
public:
// Holds a return value from MakeInstructionsLiveIn.
struct MakeInstructionsLiveInResult {
// The new while operation that has the requested values live in.
HloInstruction* new_while_instr;
// The i'th element of `while_body_live_in_values` is an instruction in the
// while body that holds the i'th *newly added* live in value at runtime.
std::vector<HloInstruction*> while_body_live_in_values;
// `while_body_instruction_map` maps instructions in the original while body
// to the corresponding instructions in the body for the newly created while
// operation.
CallInliner::InlinedInstructionMap while_body_instruction_map;
};
// Replaces `while_instr` with a new while instruction that is equivalent to
// `while_instr`, except that it has all of the HLO instructions in
// `instructions` as live-in, loop invariant values. These new live in values
// are represented as new elements appended to the parameter of the while
// loop, which must be of tuple shape. GetTupleElement instructions computing
// each new live in value are returned in the `while_body_live_in_values`
// vector.
//
// Precondition: `while_instr` must have a tuple shaped state.
//
// Every instruction in `instructions` must be contained in the computation
// that contains `while_instr`.
static StatusOr<MakeInstructionsLiveInResult> MakeInstructionsLiveIn(
HloInstruction* while_instr,
tensorflow::gtl::ArraySlice<HloInstruction*> instructions);
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_
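As an illustration of the contract above, a minimal sketch (the helper name is hypothetical, not part of this change) that makes a single instruction live in to an existing tuple-shaped while loop and returns the corresponding value inside the new body:
namespace xla {
// Sketch only: `while_instr` must have tuple-shaped state, and `value` must
// live in the computation that contains `while_instr`.
StatusOr<HloInstruction*> MakeValueLiveIn(HloInstruction* while_instr,
                                          HloInstruction* value) {
  TF_ASSIGN_OR_RETURN(WhileUtil::MakeInstructionsLiveInResult result,
                      WhileUtil::MakeInstructionsLiveIn(while_instr, {value}));
  // The returned GetTupleElement lives in the new while body and yields the
  // value of `value` on every iteration.
  return result.while_body_live_in_values[0];
}
}  // namespace xla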

View File

@ -0,0 +1,130 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace xla {
namespace {
namespace op = ::xla::testing::opcode_matchers;
StatusOr<std::unique_ptr<HloModule>> GetParsedModule(
HloComputation** entry_computation, HloInstruction** param0,
HloInstruction** param1, HloInstruction** param2) {
const char* const hlo_string = R"(
HloModule ModuleWithWhile
while_body {
ROOT p_body = (f32[32,32]{1,0}, f32[32,32]{1,0}) parameter(0)
}
while_condition {
p_cond = (f32[32,32]{1,0}, f32[32,32]{1,0}) parameter(0)
ROOT result = pred[] constant(true)
}
ENTRY entry {
p_entry_0 = f32[32,32]{1,0} parameter(0)
p_entry_1 = s32[32,32]{1,0} parameter(1)
p_entry_2 = s64[32,32]{1,0} parameter(2)
while_init = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(p_entry_0, p_entry_0)
ROOT while = (f32[32,32]{1,0}, f32[32,32]{1,0}) while(while_init), condition=while_condition, body=while_body
}
)";
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
tools::Parse(hlo_string));
*entry_computation = module->entry_computation();
*param0 = (*entry_computation)->parameter_instruction(0);
*param1 = (*entry_computation)->parameter_instruction(1);
*param2 = (*entry_computation)->parameter_instruction(2);
return std::move(module);
}
TEST(WhileUtil, MakeZeroInstructionsLiveOp) {
HloInstruction *param0, *param1, *param2;
HloComputation* entry_computation;
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
GetParsedModule(&entry_computation, &param0, &param1, &param2));
HloInstruction* while_instr = entry_computation->root_instruction();
ASSERT_EQ(while_instr->opcode(), HloOpcode::kWhile);
TF_ASSERT_OK_AND_ASSIGN(
WhileUtil::MakeInstructionsLiveInResult make_live_in_result,
WhileUtil::MakeInstructionsLiveIn(while_instr, /*instructions=*/{}));
HloInstruction* new_while_instr = make_live_in_result.new_while_instr;
EXPECT_THAT(
entry_computation->root_instruction(),
op::Tuple(op::GetTupleElement(::testing::Eq(new_while_instr), 0),
op::GetTupleElement(::testing::Eq(new_while_instr), 1)));
auto param_reconstructed =
op::Tuple(op::GetTupleElement(op::Parameter(0), 0),
op::GetTupleElement(op::Parameter(0), 1));
EXPECT_THAT(new_while_instr->while_body()->root_instruction(),
op::Tuple(op::GetTupleElement(param_reconstructed, 0),
op::GetTupleElement(param_reconstructed, 1)));
}
TEST(WhileUtilTest, MakeTwoInstructionsLive) {
HloInstruction *param0, *param1, *param2;
HloComputation* entry_computation;
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
GetParsedModule(&entry_computation, &param0, &param1, &param2));
HloInstruction* while_instr = entry_computation->root_instruction();
ASSERT_EQ(while_instr->opcode(), HloOpcode::kWhile);
TF_ASSERT_OK_AND_ASSIGN(
WhileUtil::MakeInstructionsLiveInResult make_live_in_result,
WhileUtil::MakeInstructionsLiveIn(while_instr,
/*instructions=*/{param0, param1}));
HloInstruction* new_while_instr = make_live_in_result.new_while_instr;
XLA_VLOG_LINES(3, module->ToString());
EXPECT_THAT(
entry_computation->root_instruction(),
op::Tuple(op::GetTupleElement(::testing::Eq(new_while_instr), 0),
op::GetTupleElement(::testing::Eq(new_while_instr), 1)));
auto first_half_param_reconstructed =
op::Tuple(op::GetTupleElement(op::Parameter(0), 0),
op::GetTupleElement(op::Parameter(0), 1));
EXPECT_THAT(new_while_instr->while_body()->root_instruction(),
op::Tuple(op::GetTupleElement(first_half_param_reconstructed, 0),
op::GetTupleElement(first_half_param_reconstructed, 1),
op::GetTupleElement(op::Parameter(0), 2),
op::GetTupleElement(op::Parameter(0), 3)));
}
} // namespace
} // namespace xla

View File

@ -84,7 +84,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
if (lhs.layout().format() != rhs.layout().format()) {
return false;
}
if (LayoutUtil::IsDense(lhs)) {
if (LayoutUtil::IsDenseArray(lhs)) {
if (!ContainersEqual(LayoutUtil::MinorToMajor(lhs),
LayoutUtil::MinorToMajor(rhs))) {
VLOG(3) << "CompareShapes: lhs layout != rhs layout";
@ -202,6 +202,17 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
return MakeShapeWithLayout(element_type, dimensions, layout);
}
/* static */ Shape ShapeUtil::MakeShapeWithSparseLayout(
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
int64 max_sparse_elements) {
DCHECK_NE(TUPLE, element_type);
DCHECK_NE(OPAQUE, element_type);
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
*shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
TF_DCHECK_OK(ShapeUtil::ValidateShape(shape));
return shape;
}
/* static */ Shape
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
const Shape& shape) {
@ -249,7 +260,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
}
/* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
CHECK(LayoutUtil::IsDense(*shape));
CHECK(LayoutUtil::IsDenseArray(*shape));
shape->mutable_layout()->add_minor_to_major(Rank(*shape));
shape->add_dimensions(bound);
TF_DCHECK_OK(ValidateShape(*shape));
@ -658,23 +669,55 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
TF_DCHECK_OK(ValidateShape(shape));
DCHECK_NE(OPAQUE, shape.element_type());
if (shape.element_type() == TUPLE) {
CHECK_GT(pointer_size, 0);
return pointer_size * shape.tuple_shapes_size();
return ByteSizeOfTupleIndexTable(shape, pointer_size);
}
int64 byte_size = ByteSizeOfElements(shape);
if (LayoutUtil::IsSparseArray(shape)) {
byte_size += ByteSizeOfSparseIndices(shape);
}
return byte_size;
}
/* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape,
int64 pointer_size) {
TF_DCHECK_OK(ValidateShape(shape));
DCHECK_EQ(TUPLE, shape.element_type());
CHECK_GT(pointer_size, 0);
return pointer_size * shape.tuple_shapes_size();
}
/* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) {
TF_DCHECK_OK(ValidateShape(shape));
DCHECK(ShapeUtil::IsArray(shape));
int64 allocated_element_count;
if (shape.layout().padded_dimensions_size() > 0) {
CHECK_EQ(Rank(shape), shape.layout().padded_dimensions_size());
allocated_element_count = 1;
for (int64 dimension_size : shape.layout().padded_dimensions()) {
allocated_element_count *= dimension_size;
}
if (LayoutUtil::IsSparseArray(shape)) {
allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout());
} else {
allocated_element_count = ElementsIn(shape);
CHECK(LayoutUtil::IsDenseArray(shape));
tensorflow::gtl::ArraySlice<int64> padded_dimensions =
LayoutUtil::PaddedDimensions(shape);
if (!padded_dimensions.empty()) {
CHECK_EQ(Rank(shape), padded_dimensions.size());
allocated_element_count = 1;
for (int64 dimension_size : padded_dimensions) {
allocated_element_count *= dimension_size;
}
} else {
allocated_element_count = ElementsIn(shape);
}
}
return allocated_element_count *
ByteSizeOfPrimitiveType(shape.element_type());
}
/* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) {
TF_DCHECK_OK(ValidateShape(shape));
DCHECK(LayoutUtil::IsSparseArray(shape));
return LayoutUtil::MaxSparseElements(shape.layout()) *
ShapeUtil::Rank(shape) * sizeof(int64);
}
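A quick worked example of the sparse accounting above (a sketch, not part of this change; it assumes ByteSizeOf's default pointer size, which is only consulted for tuple shapes): an F32[100,100] shape with room for 1024 nonzeros stores 1024 * 4 = 4096 bytes of values and 1024 * 2 * 8 = 16384 bytes of indices, for 20480 bytes in total.
// Sketch: exercises the sparse size helpers defined above.
Shape sparse = ShapeUtil::MakeShapeWithSparseLayout(
    F32, {100, 100}, /*max_sparse_elements=*/1024);
CHECK_EQ(ShapeUtil::ByteSizeOfElements(sparse), 4096);
CHECK_EQ(ShapeUtil::ByteSizeOfSparseIndices(sparse), 16384);
CHECK_EQ(ShapeUtil::ByteSizeOf(sparse), 20480);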
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
const Shape& shape) {
if (shape.element_type() == TUPLE) {
@ -900,7 +943,7 @@ Status ForEachMutableSubshapeHelper(
new_shape.add_dimensions(dim);
}
if (shape.has_layout()) {
CHECK(LayoutUtil::IsDense(shape));
CHECK(LayoutUtil::IsDenseArray(shape));
Layout* new_layout = new_shape.mutable_layout();
new_layout->set_format(DENSE);
new_layout->clear_minor_to_major();

View File

@ -143,7 +143,10 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index);
class ShapeUtil {
public:
// Returns the number of elements contained within the provided shape;
// e.g. for rank 0 (scalars) the result is always 1.
// e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes
// may not actually be able to store this number of elements. See
// LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of
// elements that can be stored in a sparse shape.
// Precondition: !IsTuple(shape)
static int64 ElementsIn(const Shape& shape);
@ -164,6 +167,27 @@ class ShapeUtil {
// Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)
static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type);
// Returns the number of bytes required to store the tuple member pointers for
// an allocation of the given shape. The `shape` must be a TUPLE shape, and
// `pointer_size` must be larger than zero.
static int64 ByteSizeOfTupleIndexTable(const Shape& shape,
int64 pointer_size);
// Returns the number of bytes required for the elements in an allocation of
// `shape`, which must be an array shape. The return value does not include
// the bytes needed to store sparse indices. Dense shapes use a separate
// memory location for each element, and so for these shapes,
// `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. For dense shapes, this
// size also includes padding if present in the layout. For sparse shapes,
// `ByteSizeOf(shape) == ByteSizeOfElements(shape) +
// ByteSizeOfSparseIndices(shape)`.
static int64 ByteSizeOfElements(const Shape& shape);
// Returns the number of bytes required for the sparse indices in an
// allocation of shape. The shape must be an array shape. The return value
// does not include the bytes needed to store the array elements.
static int64 ByteSizeOfSparseIndices(const Shape& shape);
// Returns a human-readable string that represents the given shape, with or
// without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]".
static string HumanString(const Shape& shape);
@ -269,6 +293,10 @@ class ShapeUtil {
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> minor_to_major);
static Shape MakeShapeWithSparseLayout(
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
int64 max_sparse_elements);
// Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
static Shape MakeShapeWithDescendingLayout(
PrimitiveType element_type,

View File

@ -0,0 +1,110 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/sparse_index_array.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
namespace xla {
SparseIndexArray::SparseIndexArray() : rank_(0), max_indices_(0) {}
SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank,
std::vector<int64> indices)
: indices_(std::move(indices)), rank_(rank), max_indices_(max_indices) {
CHECK_GT(rank_, 0);
CHECK_EQ(indices_.size() % rank_, 0)
<< "indices_.size(): " << indices_.size() << ", rank_: " << rank_;
CHECK_LT(index_count(), max_indices_);
}
SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank,
tensorflow::gtl::ArraySlice<int64> indices)
: SparseIndexArray(max_indices, rank,
std::vector<int64>(indices.begin(), indices.end())) {}
SparseIndexArray::SparseIndexArray(int64 max_indices,
const Array2D<int64>& indices)
: SparseIndexArray(max_indices, indices.n2(),
std::vector<int64>(indices.begin(), indices.end())) {}
int64 SparseIndexArray::index_count() const {
CHECK_GT(rank_, 0);
CHECK_EQ(indices_.size() % rank_, 0);
return indices_.size() / rank_;
}
tensorflow::gtl::ArraySlice<int64> SparseIndexArray::At(
int64 sparse_index_number) const {
CHECK_GT(rank_, 0);
CHECK_GE(sparse_index_number, 0);
CHECK_LE(rank_ * sparse_index_number + rank_, indices_.size());
return tensorflow::gtl::ArraySlice<int64>(
indices_.data() + rank_ * sparse_index_number, rank_);
}
tensorflow::gtl::MutableArraySlice<int64> SparseIndexArray::At(
int64 sparse_index_number) {
CHECK_GT(rank_, 0);
CHECK_GE(sparse_index_number, 0);
CHECK_LE(rank_ * sparse_index_number + rank_, indices_.size());
return tensorflow::gtl::MutableArraySlice<int64>(
indices_.data() + rank_ * sparse_index_number, rank_);
}
void SparseIndexArray::Append(tensorflow::gtl::ArraySlice<int64> index) {
CHECK_GT(rank_, 0);
CHECK_EQ(index.size(), rank_);
indices_.insert(indices_.end(), index.begin(), index.end());
}
void SparseIndexArray::Clear() { indices_.clear(); }
void SparseIndexArray::Resize(int64 num_indices) {
CHECK_GT(rank_, 0);
indices_.resize(rank_ * num_indices);
}
bool SparseIndexArray::Validate(const Shape& shape) const {
if (rank_ == 0 || rank_ != ShapeUtil::Rank(shape)) {
return false;
}
int64 num_indices = index_count();
if (num_indices > LayoutUtil::MaxSparseElements(shape.layout())) {
return false;
}
if (num_indices < 2) {
return true;
}
tensorflow::gtl::ArraySlice<int64> last = At(0);
if (!IndexUtil::IndexInBounds(shape, last)) {
return false;
}
for (int64 n = 1; n < num_indices; ++n) {
tensorflow::gtl::ArraySlice<int64> next = At(n);
if (!IndexUtil::IndexInBounds(shape, next)) {
return false;
}
if (IndexUtil::CompareIndices(last, next) >= 0) {
return false;
}
last = next;
}
return true;
}
} // namespace xla
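A short usage sketch tying SparseIndexArray to the sparse shape support above (the concrete values are illustrative, not part of this change):
// Sketch: three rank-2 indices for a sparse shape that can hold four nonzeros.
Shape shape = ShapeUtil::MakeShapeWithSparseLayout(
    F32, {8, 8}, /*max_sparse_elements=*/4);
SparseIndexArray indices(/*max_indices=*/4, /*rank=*/2,
                         std::vector<int64>{0, 1, 2, 3, 5, 0});
CHECK(indices.Validate(shape));  // In bounds, strictly increasing, unique.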

View File

@ -0,0 +1,176 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Utility class for managing sparse array indices.
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
#include <algorithm>
#include <vector>
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
// Encapsulates the array of indices for a sparse array. A SparseIndexArray
// contains indices for up to `max_indices` elements of a sparse array. Each
// sparse index is an array of `rank` int64 values that gives the location of a
// value within a sparse array. Note that the dimensions of the array are not
// checked (except for the rank). To avoid confusion, we refer to the position
// of an index within a SparseIndexArray as a sparse index number.
class SparseIndexArray {
public:
SparseIndexArray();
SparseIndexArray(const SparseIndexArray&) = default;
SparseIndexArray(SparseIndexArray&&) = default;
SparseIndexArray& operator=(const SparseIndexArray&) = default;
SparseIndexArray& operator=(SparseIndexArray&&) = default;
// Constructs a SparseIndexArray that can hold up to `max_indices` sparse
// indices, with its initial contents obtained from the given array. The rank
// is taken from the minor dimension of the array. The major dimension of the
// array must not exceed `max_indices`.
SparseIndexArray(int64 max_indices, const Array2D<int64>& indices);
// Like above, but the array is flattened. For example, the following are
// equivalent:
//
//   SparseIndexArray(10,
// Array2D{
// {0, 1, 2},
// {3, 4, 5},
// {6, 7, 8},
// {9, 10, 11},
// })
//
// SparseIndexArray(10, 3,
// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11})
//
SparseIndexArray(int64 max_indices, int64 rank,
std::vector<int64> indices = {});
SparseIndexArray(int64 max_indices, int64 rank,
tensorflow::gtl::ArraySlice<int64> indices);
// Returns the number of elements represented by the indices stored in the
// array.
int64 index_count() const;
// Returns a slice that refers to the given sparse index number. The argument
// must be in the range [0, index_count()).
tensorflow::gtl::ArraySlice<int64> At(int64 sparse_index_number) const;
tensorflow::gtl::MutableArraySlice<int64> At(int64 sparse_index_number);
// Adds the given index at the end of the array. The new size of the
// SparseIndexArray must not exceed `max_indices`.
void Append(tensorflow::gtl::ArraySlice<int64> index);
// Removes all indices from the array.
void Clear();
// Resizes the array to contain the given number of sparse indices. The new
// size must be smaller than `max_indices`. If the new size is larger than
// the old size, the values of the new indices are unspecified.
void Resize(int64 num_indices);
// Returns true iff all indices are unique and occur in sorted order, and are
// valid for the given shape.
bool Validate(const Shape& shape) const;
int64 rank() const { return rank_; }
int64 max_indices() const { return max_indices_; }
// Returns a pointer to the int64 array that holds the sparse indices.
tensorflow::gtl::MutableArraySlice<int64> mutable_data() { return &indices_; }
tensorflow::gtl::ArraySlice<int64> data() const { return indices_; }
// Sorts this sparse index array along with the set of corresponding values.
// The indices and values are sorted in the lexicographic order of the
// indices, from smallest to largest.
//
// For example:
//
// std::vector<float> v{10.0, 11.0, 12.0};
// SparseIndexArray a(10, 3,
// {{3, 4, 5},
// {1, 2, 3},
// {2, 3, 4}});
// a.SortWithValues(&v);
// // Prints "11.0, 12.0, 10.0":
// std::cout << v[0] << ", " << v[1] << ", " << v[2] << std::endl;
//
template <typename NativeT>
void SortWithValues(tensorflow::gtl::MutableArraySlice<NativeT> values);
private:
std::vector<int64> indices_;
int64 rank_;
int64 max_indices_;
};
template <typename NativeT>
void SparseIndexArray::SortWithValues(
tensorflow::gtl::MutableArraySlice<NativeT> values) {
int64 num_elements = index_count();
CHECK_EQ(values.size(), num_elements);
std::vector<int64> sort_order;
sort_order.reserve(num_elements);
for (int64 i = 0; i < num_elements; ++i) {
sort_order.push_back(i);
}
auto sort_order_less = [this](int64 lhs, int64 rhs) {
return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0;
};
std::sort(sort_order.begin(), sort_order.end(), sort_order_less);
// Reorder the array elements according to sort_order. Work through the array
// and follow cycles so we can do the reorder in-place.
tensorflow::gtl::InlinedVector<int64, 8> saved_index(rank());
for (int64 i = 0; i < num_elements; ++i) {
// sort_order[i] == -1 indicates the element has already been copied.
if (sort_order[i] < 0) {
continue;
} else if (i == sort_order[i]) {
// The element is already in sorted order.
sort_order[i] = -1;
continue;
}
std::copy_n(At(i).begin(), rank(), saved_index.begin());
NativeT saved_value = values[i];
int64 j = i;
for (;;) {
if (sort_order[j] == i) {
std::copy_n(saved_index.begin(), rank(), At(j).begin());
values[j] = saved_value;
sort_order[j] = -1;
break;
}
std::copy_n(At(sort_order[j]).begin(), rank(), At(j).begin());
values[j] = values[sort_order[j]];
int64 k = sort_order[j];
sort_order[j] = -1;
j = k;
}
}
}
} // namespace xla
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
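
Illustration (a sketch, not part of the change above): the value-reordering half of SparseIndexArray::SortWithValues written as a standalone function. The name ApplyOrderInPlace and the plain int element type are hypothetical; order[i] holds the source position of the element that should end up at position i, mirroring sort_order above.

#include <cstddef>
#include <iostream>
#include <vector>

// Applies `order` to `values` in place by following permutation cycles and
// marking consumed slots with -1, mirroring the value-reordering part of
// SortWithValues above.
void ApplyOrderInPlace(std::vector<int>* values, std::vector<int> order) {
  for (std::size_t i = 0; i < order.size(); ++i) {
    if (order[i] < 0 || order[i] == static_cast<int>(i)) {
      order[i] = -1;  // Nothing to move for this position.
      continue;
    }
    const int saved = (*values)[i];  // Value displaced by the first copy.
    std::size_t j = i;
    for (;;) {
      if (order[j] == static_cast<int>(i)) {
        (*values)[j] = saved;  // Close the cycle with the saved value.
        order[j] = -1;
        break;
      }
      const std::size_t k = order[j];
      (*values)[j] = (*values)[k];  // Pull the value from its source slot.
      order[j] = -1;
      j = k;
    }
  }
}

int main() {
  // Values and sort order matching the Sort test below: the lexicographic
  // order of the appended indices visits original positions {2, 0, 1, 4, 3, 5}.
  std::vector<int> values = {12, 13, 11, 15, 14, 16};
  ApplyOrderInPlace(&values, {2, 0, 1, 4, 3, 5});
  for (int v : values) std::cout << v << " ";  // Prints: 11 12 13 14 15 16
  std::cout << std::endl;
  return 0;
}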

View File

@ -0,0 +1,43 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/sparse_index_array.h"
#include <vector>
#include "tensorflow/compiler/xla/test.h"
namespace xla {
namespace {
TEST(SparseIndexArrayTest, Sort) {
SparseIndexArray a(10, 3);
a.Append({2, 3, 4});
a.Append({3, 4, 5});
a.Append({1, 2, 3});
a.Append({5, 6, 7});
a.Append({4, 5, 6});
a.Append({6, 7, 8});
std::vector<double> values = {
12.0, 13.0, 11.0, 15.0, 14.0, 16.0,
};
a.SortWithValues<double>(&values);
ASSERT_EQ(a.data(), std::vector<int64>({1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5,
6, 7, 6, 7, 8}));
ASSERT_EQ(values, std::vector<double>({11.0, 12.0, 13.0, 14.0, 15.0, 16.0}));
}
} // namespace
} // namespace xla

View File

@ -1160,6 +1160,50 @@ TEST_F(WhileTest, WhileWithCallInsideCondition) {
ComputeAndCompareR0<int32>(&builder, 5, {});
}
TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
auto matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
auto while_shape = ShapeUtil::MakeTupleShape(
{scalar_s32, matrix_shape, matrix_shape, matrix_shape});
// Create a computation for the condition: repeat for 5 iterations.
Computation condition;
{
ComputationBuilder builder(client_, "condition");
auto state = builder.Parameter(0, while_shape, "state");
builder.Gt(builder.ConstantR0<int32>(5), builder.GetTupleElement(state, 0));
TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
}
Computation body;
{
ComputationBuilder builder(client_, "body");
auto state = builder.Parameter(0, while_shape, "state");
auto indvar = builder.GetTupleElement(state, 0);
auto input_0 = builder.GetTupleElement(state, 1);
auto input_1 = builder.GetTupleElement(state, 2);
auto output = builder.Tanh(builder.Dot(input_0, input_1));
auto indvar_next = builder.Add(indvar, builder.ConstantR0<int32>(1));
auto tuple_result = builder.Tuple({indvar_next, input_0, input_1, output});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
ComputationBuilder builder(client_, TestName());
auto matrix_input = builder.Parameter(0, matrix_shape, "matrix");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), matrix_input, matrix_input, matrix_input});
auto while_instruction = builder.While(condition, body, init);
builder.GetTupleElement(while_instruction, 3);
TF_ASSERT_OK_AND_ASSIGN(auto param_value,
client_->TransferToServer(*Literal::CreateR2<float>(
{{1.0, 2.0}, {-1.0, -2.0}})));
ComputeAndCompareR2<float>(
&builder, {{-0.76159416, -0.96402758}, {0.76159416, 0.96402758}},
{param_value.get()}, ErrorSpec(4e-5));
}
void BM_WhileLoop(int num_iters) {
// Benchmark a simple kernel to measure while loop overheads.
tensorflow::testing::StopTiming();

View File

@ -1515,7 +1515,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
return false;
}
} else {
return TokenError(StrCat("unsupported premitive type ",
return TokenError(StrCat("unsupported primitive type ",
PrimitiveType_Name(shape.element_type())));
}
break;
@ -1851,7 +1851,7 @@ bool HloParser::ParseWindow(Window* window) {
if (field_name == "rhs_reversal") {
return ParseDxD("rhs_reversal", &rhs_reversal);
}
return Error(loc, StrCat("unexpected attribute name: ", field_name));
return Error(attr_loc, StrCat("unexpected attribute name: ", field_name));
}();
if (!ok) {
return false;

View File

@ -398,6 +398,31 @@ std::vector<std::pair<int64, int64>> CommonFactors(
// Removes illegal characters from filenames.
string SanitizeFileName(string file_name);
// Simple wrapper around std::all_of.
template <typename Container, typename Predicate>
bool c_all_of(Container container, Predicate predicate) {
return std::all_of(std::begin(container), std::end(container), predicate);
}
// Simple wrapper around std::transform.
template <typename InputContainer, typename OutputIterator,
typename UnaryOperation>
OutputIterator c_transform(InputContainer input_container,
OutputIterator output_iterator,
UnaryOperation unary_op) {
return std::transform(std::begin(input_container), std::end(input_container),
output_iterator, unary_op);
}
// Simple wrapper around std::copy_if.
template <class InputContainer, class OutputIterator, class UnaryPredicate>
OutputIterator c_copy_if(InputContainer input_container,
OutputIterator output_iterator,
UnaryPredicate predicate) {
return std::copy_if(std::begin(input_container), std::end(input_container),
output_iterator, predicate);
}
} // namespace xla
#define XLA_LOG_LINES(SEV, STRING) \

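Illustration (a sketch, not part of the change above) of how the c_all_of, c_transform, and c_copy_if wrappers are meant to be used. Only the signatures shown above are relied on; the include path is an assumption.

#include <iostream>
#include <iterator>
#include <vector>

// Assumed location of the wrappers declared above.
#include "tensorflow/compiler/xla/util.h"

int main() {
  std::vector<int> values = {1, 2, 3, 4};

  // c_all_of: is every element positive?
  const bool all_positive =
      xla::c_all_of(values, [](int v) { return v > 0; });

  // c_transform: square every element into `squares`.
  std::vector<int> squares;
  xla::c_transform(values, std::back_inserter(squares),
                   [](int v) { return v * v; });

  // c_copy_if: keep only the even elements.
  std::vector<int> evens;
  xla::c_copy_if(values, std::back_inserter(evens),
                 [](int v) { return v % 2 == 0; });

  std::cout << all_positive << " " << squares.size() << " " << evens.size()
            << std::endl;  // Prints: 1 4 2
  return 0;
}
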
View File

@ -120,6 +120,9 @@ enum Format {
// The default layout, with exactly one storage location per element (ignoring
// padding).
DENSE = 1;
// A sparsely encoded layout, providing only the index/value pairs of non-zero
// elements.
SPARSE = 2;
}
// A layout describes how the array is placed in (1D) memory space. This
@ -151,6 +154,11 @@ message Layout {
// field must be unset unless the format is DENSE.
PaddingValue padding_value = 3;
// The maximum number of elements that can be stored for SPARSE formats. This
// can be used to determine the maximum size in bytes of arrays stored in
// memory. This field must be unset unless the format is SPARSE.
int64 max_sparse_elements = 5;
// Important: if any field is added, be sure to modify ShapeUtil::Equal()
// appropriately to account for the new field.
}
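
Illustration (a sketch, not part of the change above): filling in a sparse Layout through the generated C++ proto API implied by the fields above. The function name, the parameter, and the include path are assumptions.

#include <cstdint>

#include "tensorflow/compiler/xla/xla_data.pb.h"

// Builds a Layout proto describing a sparse array that may hold at most
// `max_elements` explicitly stored index/value pairs.
xla::Layout MakeExampleSparseLayout(int64_t max_elements) {
  xla::Layout layout;
  layout.set_format(xla::SPARSE);                // Format enum value above.
  layout.set_max_sparse_elements(max_elements);  // Field 5 above.
  return layout;
}
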
@ -333,7 +341,8 @@ message LiteralProto {
// The F16s and BF16s are encoded in little endian byte order
bytes f16s = 11;
bytes bf16s = 13;
// Next = 14
repeated int64 sparse_indices = 14;
// Next = 15
}
message WindowDimension {

View File

@ -210,6 +210,7 @@ std::unique_ptr<TaskType> Batch<TaskType>::RemoveTask() {
return nullptr;
}
std::unique_ptr<TaskType> task = std::move(tasks_.back());
size_ -= task->size();
tasks_.pop_back();
return task;
}

View File

@ -74,7 +74,9 @@ TEST(BatchTest, Basic) {
EXPECT_EQ(task1->size(), batch.task(1).size());
EXPECT_EQ(7, batch.RemoveTask()->size());
EXPECT_EQ(3, batch.size());
EXPECT_EQ(3, batch.RemoveTask()->size());
EXPECT_EQ(0, batch.size());
EXPECT_TRUE(batch.empty());
}

View File

@ -524,7 +524,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.read_coordination_events[expected_element].acquire()
else:
self.write_coordination_events[expected_element].set()
time.sleep(0.1) # Sleep to consistently "avoid" the race condition.
time.sleep(0.5) # Sleep to consistently "avoid" the race condition.
actual_element = sess.run(self.next_element)
if not done_first_event:
done_first_event = True
@ -611,7 +611,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.read_coordination_events[expected_element].acquire()
else:
self.write_coordination_events[expected_element].set()
time.sleep(0.1) # Sleep to consistently "avoid" the race condition.
time.sleep(0.5) # Sleep to consistently "avoid" the race condition.
actual_element = sess.run(self.next_element)
if not done_first_event:
done_first_event = True

View File

@ -7,7 +7,11 @@ exports_files([
"generic_tree_model_proto.swig",
])
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
load(
"//tensorflow/core:platform/default/build_config.bzl",
"tf_proto_library",
"tf_pyclif_proto_library",
)
filegroup(
name = "all_files",
@ -34,3 +38,10 @@ tf_proto_library(
protodeps = [":generic_tree_model"],
visibility = ["//visibility:public"],
)
tf_pyclif_proto_library(
name = "generic_tree_model_pyclif",
proto_lib = ":generic_tree_model",
proto_srcfile = "generic_tree_model.proto",
visibility = ["//visibility:public"],
)

View File

@ -17,6 +17,7 @@ py_library(
"python/ops/__init__.py",
"python/ops/alpha_dropout.py",
"python/ops/cross_entropy.py",
"python/ops/fwd_gradients.py",
"python/ops/sampling_ops.py",
"python/ops/scaled_softplus.py",
],
@ -28,6 +29,7 @@ py_library(
"//tensorflow/python:embedding_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
"//tensorflow/python:nn_ops",
@ -55,6 +57,19 @@ py_test(
],
)
py_test(
name = "fwd_gradients_test",
size = "small",
srcs = ["python/ops/fwd_gradients_test.py"],
srcs_version = "PY2AND3",
deps = [
":nn_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:math_ops",
],
)
py_test(
name = "sampling_ops_test",
size = "small",

View File

@ -0,0 +1,76 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Forward-mode derivatives."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops.gradients_impl import gradients
def fwd_gradients(ys, xs, grad_xs=None, assert_unused=False):
"""Computes forward-mode derivatives.
This is accomplished in pure-python using tensorflow's existing (reverse-mode)
gradients. There is additional overhead on graph construction, but runtime
performance should be equal to a manual implementation [citation needed].
See https://j-towns.github.io/2017/06/12/A-new-trick.html and
https://github.com/HIPS/autograd/pull/175 for the original discussion of this
method, and https://github.com/renmengye/tensorflow-forward-ad for a "direct"
implementation.
Args:
ys: A list of tensors.
xs: A list of tensors.
grad_xs: An optional list of tensors. If provided, must have the same length
as xs and shapes compatible with xs.
assert_unused: Add assertions that intermediate values are not computed.
Returns:
A list of tensors of the same shapes as ys. The directional derivatives of
ys with respect to xs in the direction grad_xs. Leaving grad_xs unspecified
is equivalent to passing in 1s for each x in xs.
"""
# This version of forward-mode autodiff is based on code by Tim Cooijmans
# and handles list arguments and certain special cases such as when the
# ys don't depend on one or more of the xs, and when tf.IndexedSlices are
# generated by the first tf.gradients call.
us = [array_ops.zeros_like(y) + float('nan') for y in ys]
dydxs = gradients(ys, xs, grad_ys=us)
# tf.gradients sometimes returns tf.IndexedSlices, which it cannot itself
# accept as inputs, so convert them to plain tensors first.
dydxs = [ops.convert_to_tensor(dydx) if isinstance(dydx, ops.IndexedSlices)
else dydx for dydx in dydxs]
if assert_unused:
with ops.control_dependencies(dydxs):
assert_unused = control_flow_ops.Assert(False, [1], name='fwd_gradients')
with ops.control_dependencies([assert_unused]):
dydxs = array_ops.identity_n(dydxs)
dydxs = [array_ops.zeros_like(x) if dydx is None else dydx
for x, dydx in zip(xs, dydxs)]
for x, dydx in zip(xs, dydxs):
dydx.set_shape(x.shape)
dysdx = gradients(dydxs, us, grad_ys=grad_xs)
return dysdx
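
For reference, a sketch (not from the source) of the identity this double application of reverse-mode gradients relies on, writing J for the Jacobian of ys with respect to xs and w for grad_xs:

\[
  g(u) \;=\; \operatorname{gradients}(ys, xs, \mathrm{grad\_ys}=u) \;=\; J^{\top} u
\]
\[
  \operatorname{gradients}\bigl(g(u), u, \mathrm{grad\_ys}=w\bigr)
  \;=\; \Bigl(\frac{\partial (J^{\top} u)}{\partial u}\Bigr)^{\!\top} w
  \;=\; J\, w
\]

Because g is linear in u, the value of u never matters, which is why the code above seeds `us` with NaNs before the first gradients call.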

View File

@ -0,0 +1,52 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for forward_ad.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.nn.python.ops import fwd_gradients
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class ForwardAdTest(test.TestCase):
def testSquare(self):
x = constant_op.constant(1.)
y = math_ops.square(x)
grad_x = 3.
dydx_tf = fwd_gradients.fwd_gradients([y], [x], [grad_x])[0]
dydx_py = 2. * grad_x
with self.test_session() as sess:
self.assertAllClose(sess.run(dydx_tf), dydx_py, 1e-6)
def testGather(self):
x = constant_op.constant([1., 2., 3.])
y = array_ops.gather(x, [0, 1])
y.set_shape([2])
dydx = fwd_gradients.fwd_gradients([y], [x], assert_unused=True)
with self.test_session() as sess:
sess.run(dydx)
if __name__ == '__main__':
test.main()

View File

@ -226,6 +226,11 @@ bool IsConstantFoldable(
if (consider && !consider(n)) {
return false;
}
// PlaceholderWithDefault shouldn't be constant folded because its output can
// be fed non-constant values.
if (n->type_string() == "PlaceholderWithDefault") {
return false;
}
if (n->IsControlFlow() || n->IsSend() || n->IsRecv()) {
return false;
}

View File

@ -338,6 +338,40 @@ TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) {
EXPECT_FALSE(was_mutated);
}
TEST_F(ConstantFoldingTest, Placeholders) {
Graph g(OpRegistry::Global());
{
Scope s = Scope::NewRootScope();
auto placeholder = ops::Placeholder(s, DT_DOUBLE);
auto add = ops::Add(s, placeholder, 2.0);
auto send =
ops::_Send(s.WithOpName("send"), add, "add", "sender", 0, "receiver");
TF_ASSERT_OK(s.ToGraph(&g));
}
bool was_mutated;
Status s = ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
nullptr, &g, &was_mutated);
EXPECT_FALSE(was_mutated);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(s.error_message().find(
"You must feed a value for placeholder "
"tensor 'Placeholder' with dtype double") != string::npos);
Graph g2(OpRegistry::Global());
{
Scope s = Scope::NewRootScope();
auto placeholder = ops::PlaceholderWithDefault(s, {1.0}, {1});
auto add = ops::Add(s, placeholder, 2.0);
auto send =
ops::_Send(s.WithOpName("send"), add, "add", "sender", 0, "receiver");
TF_ASSERT_OK(s.ToGraph(&g2));
}
// TODO(skyewm): should this have the same behavior as Placeholder?
TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
nullptr, &g2, &was_mutated));
EXPECT_FALSE(was_mutated);
}
TEST_F(ConstantFoldingTest, ControlDependencies) {
Graph g(OpRegistry::Global());
{

View File

@ -558,6 +558,13 @@ Status ShapeRefiner::ExtractConstantSubgraph(
return Status::OK();
}
if (target_node->type_string() == "PlaceholderWithDefault") {
return Status::OK();
}
// TODO(skyewm): more of the filtering applied in input nodes below should be
// applied to target_node here
struct NodeAndRecursed {
Node* new_node = nullptr;
bool recursed = false;
@ -608,6 +615,14 @@ Status ShapeRefiner::ExtractConstantSubgraph(
return Status::OK();
}
// Placeholders should never be constant folded because their outputs are
// fed by the user. Note that "Placeholder" nodes have no inputs so are
// handled below.
if (current_node->type_string() == "PlaceholderWithDefault") {
*is_constant_graph = false;
return Status::OK();
}
// If there is nothing more to recurse down, see if
// the generator node is a constant.
if (current_node->num_inputs() == 0) {

View File

@ -724,6 +724,25 @@ TEST_F(ShapeRefinerTest, PropagateRange) {
EXPECT_EQ("[1,4,7,10]", ctx->DebugString(ctx->output(0)));
}
// Make sure PlaceholderWithDefaults aren't treated as constants.
TEST_F(ShapeRefinerTest, NoPropagatePlaceholderWithDefault) {
Scope root = Scope::NewRootScope();
auto constant = ops::Const<int>(root, 2);
auto placeholder =
ops::PlaceholderWithDefault(root, constant, PartialTensorShape());
Node* shape_data;
TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
.Input(placeholder.node())
.Finalize(root.graph(), &shape_data));
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(constant.node()));
TF_ASSERT_OK(m.AddNode(placeholder.node()));
TF_ASSERT_OK(m.AddNode(shape_data));
shape_inference::InferenceContext* ic = m.GetContext(shape_data);
EXPECT_EQ(ic->DebugString(ic->output(0)), "?");
}
TEST_F(ShapeRefinerTest, ConstantValueTwoInputsToSameNode) {
Scope root = Scope::NewRootScope();
// This node is used as two inputs to 'range'.

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_memory.h"
#include <list>
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
@ -163,6 +164,8 @@ void GraphMemory::InferFromTrace(const StepStats& timeline) {
NodeMap node_map(&item_.graph);
for (const auto& dev_stats : timeline.dev_stats()) {
const string& device_name = dev_stats.device();
const bool is_gpu = (device_name.find("GPU:") != string::npos ||
                     device_name.find("gpu:") != string::npos);
std::list<LiveTensor>& device_tensors =
live_tensors_per_device[dev_stats.device()];
for (const auto& node_stats : dev_stats.node_stats()) {
@ -194,7 +197,24 @@ void GraphMemory::InferFromTrace(const StepStats& timeline) {
// graph (e.g _Send/_Recv nodes).
continue;
}
for (const string& input : node->input()) {
std::unordered_set<int> swapped_inputs;
if (is_gpu) {
auto it = node->attr().find("_swap_to_host");
if (it != node->attr().end()) {
const AttrValue& val = it->second;
for (int port_id : val.list().i()) {
swapped_inputs.insert(port_id);
}
}
}
for (int i = 0; i < node->input_size(); ++i) {
if (swapped_inputs.find(i) != swapped_inputs.end()) {
// The memory of swapped inputs will be released as early as possible:
// therefore ignore this input when determining the deallocation time
// of the tensor.
continue;
}
const string& input = node->input(i);
int position;
string input_node = ParseNodeName(input, &position);
if (position < 0) {

View File

@ -134,6 +134,62 @@ TEST_F(GraphMemoryTest, MultiDevice) {
EXPECT_EQ(gpu_expected, gpu_tensors);
}
TEST_F(GraphMemoryTest, GpuSwapping) {
TrivialTestGraphInputYielder fake_input(4, 2, 1024 * 1024, false, {"/GPU:0"});
GrapplerItem item;
CHECK(fake_input.NextItem(&item));
item.feed.clear();
{
// Estimate the max memory usage for the graph.
GraphMemory memory(item);
Status s = memory.InferStatically(devices_);
TF_CHECK_OK(s);
const GraphMemory::MemoryUsage& gpu_mem =
memory.GetPeakMemoryUsage("/GPU:0");
EXPECT_EQ(20971520, gpu_mem.used_memory);
std::set<string> gpu_tensors;
for (const auto& t : gpu_mem.live_tensors) {
gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
}
std::set<string> gpu_expected;
gpu_expected.insert("Square:0");
gpu_expected.insert("Square_1:0");
gpu_expected.insert("AddN:0");
gpu_expected.insert("AddN_1:0");
gpu_expected.insert("AddN_2:0");
EXPECT_EQ(gpu_expected, gpu_tensors);
}
{
// Swap the first input to node AddN_1: its fanin (the square nodes) should
// not appear in the max cut anymore.
for (auto& node : *item.graph.mutable_node()) {
if (node.name() == "AddN_1") {
(*node.mutable_attr())["_swap_to_host"].mutable_list()->add_i(0);
}
}
GraphMemory memory(item);
Status s = memory.InferStatically(devices_);
TF_CHECK_OK(s);
const GraphMemory::MemoryUsage& new_gpu_mem =
memory.GetPeakMemoryUsage("/GPU:0");
EXPECT_EQ(20971520, new_gpu_mem.used_memory);
std::set<string> new_gpu_tensors;
for (const auto& t : new_gpu_mem.live_tensors) {
new_gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
}
std::set<string> new_gpu_expected;
new_gpu_expected.insert("AddN:0");
new_gpu_expected.insert("AddN_1:0");
new_gpu_expected.insert("AddN_2:0");
new_gpu_expected.insert("AddN_3:0");
new_gpu_expected.insert("AddN_4:0");
EXPECT_EQ(new_gpu_expected, new_gpu_tensors);
}
}
TEST_F(GraphMemoryTest, CtrlDependencies) {
// Build a simple graph with a control dependency.
Scope s = Scope::NewRootScope();

View File

@ -31,8 +31,6 @@ namespace {
GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
bool use_multiple_devices, bool insert_queue,
const std::vector<string>& device_names) {
CHECK_GE(device_names.size(), width);
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
@ -49,13 +47,17 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
std::vector<Output> this_stage;
for (int j = 0; j < width; j++) {
if (last_stage.size() == 1) {
Output unary_op =
Square(s.WithDevice(device_names[use_multiple_devices ? j : 0]),
last_stage[0]);
Output unary_op = Square(
s.WithDevice(
device_names[use_multiple_devices ? j % device_names.size()
: 0]),
last_stage[0]);
this_stage.push_back(unary_op);
} else {
Output combine =
AddN(s.WithDevice(device_names[use_multiple_devices ? j : 0]),
AddN(s.WithDevice(
device_names[use_multiple_devices ? j % device_names.size()
: 0]),
last_stage);
this_stage.push_back(combine);
}

View File

@ -433,13 +433,42 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
id = --min_id;
}
}
// Beware: the reduction dimensions computed by the BCast class are valid iff
// we assume that two distinct symbolic dimensions can't be equal and a
// symbolic dimension can't be equal to 1. This is often but not always true,
// so to make this optimization safe we filter out these cases.
const int common_dims = std::min(shape1.size(), shape2.size());
for (int i = 0; i < common_dims; ++i) {
if (shape1[i] >= 0 && shape2[i] >= 0) {
continue;
}
if (shape1[i] != shape2[i]) {
// We're either dealing with 2 different symbolic dimensions or a symbolic
// and a known dimension. We can't be sure whether both are equal or not,
// so we can't be sure whether we'll be broadcasting or not.
return Status::OK();
}
}
// These extra dims could be equal to 1, in which case there is no
// broadcasting. It could also be greater than 1, in which case there would
// be broadcasting. Since we don't know, we'll just punt.
for (int i = common_dims; i < shape1.size(); ++i) {
if (shape1[i] < 0) {
return Status::OK();
}
}
for (int i = common_dims; i < shape2.size(); ++i) {
if (shape2[i] < 0) {
return Status::OK();
}
}
BCast bcast(shape1, shape2);
if (!bcast.IsValid()) {
return Status::OK();
}
// Beware: the reduction dimensions are valid iff we assume that two distinct
// symbolic dimensions can't be equal. This is often but not always true, so
// this optimization isn't safe.
BCast::Vec reduce_dims[2];
reduce_dims[0] = bcast.grad_x_reduce_idx();
reduce_dims[1] = bcast.grad_y_reduce_idx();
@ -447,26 +476,27 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
const DataType type = node.attr().at("T").type();
NodeDef* out[2];
for (int j = 0; j < 2; ++j) {
if (!reduce_dims[j].empty()) {
// This is the case when a tensor dimension of 1 is matched against an
// unknown dimension. The unknown dimension could also be equal to 1, in
// which case there would be no reduction.
out[j] = nullptr;
} else {
string const_name = OptimizedNodeName(node, strings::StrCat("-", j));
out[j] = node_map_->GetNode(const_name);
if (out[j] == nullptr) {
out[j] = graph_->add_node();
Tensor value(type, TensorShape({0}));
*out[j] = CreateNodeDef(const_name, TensorValue(&value));
out[j]->set_device(node.device());
node_map_->AddNode(const_name, out[j]);
string ctrl_dep =
AddControlDependency(node.name(), graph_, node_map_.get());
*out[j]->add_input() = ctrl_dep;
node_map_->AddOutput(NodeName(ctrl_dep), const_name);
int reduction_indices = reduce_dims[j].size();
Tensor value(type, TensorShape({reduction_indices}));
for (int i = 0; i < reduction_indices; ++i) {
if (type == DT_INT32) {
value.vec<int32>()(i) = reduce_dims[j][i];
} else {
value.vec<int64>()(i) = reduce_dims[j][i];
}
}
string const_name = OptimizedNodeName(node, strings::StrCat("-", j));
out[j] = node_map_->GetNode(const_name);
if (out[j] == nullptr) {
out[j] = graph_->add_node();
*out[j] = CreateNodeDef(const_name, TensorValue(&value));
out[j]->set_device(node.device());
node_map_->AddNode(const_name, out[j]);
string ctrl_dep =
AddControlDependency(node.name(), graph_, node_map_.get());
*out[j]->add_input() = ctrl_dep;
node_map_->AddOutput(NodeName(ctrl_dep), const_name);
}
}
const std::set<NodeDef*> outputs = node_map_->GetOutputs(node.name());
@ -584,12 +614,11 @@ Status ConstantFolding::MaterializeReductionIndices(
Status ConstantFolding::MaterializeConstants(
const GraphProperties& properties) {
const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
const int node_count = graph_->node_size();
for (int i = 0; i < node_count; ++i) {
NodeDef& node = *graph_->mutable_node(i);
const string& op = node.op();
if (is_aggressive && op == "BroadcastGradientArgs") {
if (op == "BroadcastGradientArgs") {
TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties));
} else if (IsReduction(node)) {
TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties));

View File

@ -1373,21 +1373,14 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) {
} else if (node.name() == "p1") {
++found;
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("ConstantFolding/i-0", node.input(0));
EXPECT_EQ("i", node.input(0));
} else if (node.name() == "p2") {
++found;
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("i:1", node.input(0));
} else if (node.name() == "ConstantFolding/i-0") {
++found;
EXPECT_EQ("Const", node.op());
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("^i", node.input(0));
EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
.num_elements());
}
}
EXPECT_EQ(7, found);
EXPECT_EQ(6, found);
}
TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {

View File

@ -148,7 +148,10 @@ bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff,
first_control_input = i;
break;
}
if (actual.input(i) != expected.input(i)) {
// Special case for inputs: "tensor" is equivalent to "tensor:0"
if (actual.input(i) != expected.input(i) &&
actual.input(i) != strings::StrCat(expected.input(i), ":0") &&
strings::StrCat(actual.input(i), ":0") != expected.input(i)) {
if (diff != nullptr) {
*diff = strings::StrCat("Node named '", actual.name(), "' has input ",
i, " '", actual.input(i),

View File

@ -5922,6 +5922,7 @@ func RandomDataset(scope *Scope, seed tf.Output, seed2 tf.Output, output_types [
//
// This op is hidden from public in Python. It is used by TensorFlow Debugger to
// register gradient tensors for gradient debugging.
// This op operates on non-reference-type tensors.
func DebugGradientIdentity(scope *Scope, input tf.Output) (output tf.Output) {
if scope.Err() != nil {
return

View File

@ -344,7 +344,7 @@ def implicit_val_and_grad(f):
def grad_fn(*args):
"""Computes the gradient of the wrapped function."""
tape.push_new_tape()
this_tape = tape.push_new_tape()
try:
end_node = f(*args)
if end_node is None:
@ -352,10 +352,10 @@ def implicit_val_and_grad(f):
"did you forget to return a value from {}?".format(
f.__name__))
finally:
popped_tape = tape.pop_tape()
tape.pop_tape(this_tape)
# Sorting variables by id, which is monotonically increasing in construction
# order. This ensures unique order across executions.
variables = list(sorted(popped_tape.watched_variables(),
variables = list(sorted(this_tape.watched_variables(),
key=lambda v: v.handle._id)) # pylint: disable=protected-access
sources = [x.handle for x in variables]
@ -363,7 +363,7 @@ def implicit_val_and_grad(f):
raise ValueError("No trainable variables were accessed while the "
"function was being computed.")
grad = imperative_grad.imperative_grad(_default_vspace,
popped_tape,
this_tape,
nest.flatten(end_node),
sources)
return end_node, list(zip(grad, variables))
@ -652,7 +652,7 @@ def make_vjp(f, params=None):
"""Computes the value and gradient of the decorated function."""
parameter_positions = _get_arg_spec(f, params, args)
assert not kwds, "The gradient function can't take keyword arguments."
tape.push_new_tape()
this_tape = tape.push_new_tape()
try:
sources = []
args = [
@ -673,12 +673,12 @@ def make_vjp(f, params=None):
flat_result = [gen_array_ops.identity(x) for x in flat_result]
result = nest.pack_sequence_as(result, flat_result)
finally:
t = tape.pop_tape()
tape.pop_tape(this_tape)
def vjp(dy=None):
if dy is not None:
dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)]
return imperative_grad.imperative_grad(
_default_vspace, t, nest.flatten(result), sources,
_default_vspace, this_tape, nest.flatten(result), sources,
output_gradients=dy)
return result, vjp
@ -835,11 +835,11 @@ class GradientTape(object):
self._persistent = persistent
def __enter__(self):
tape.push_new_tape(persistent=self._persistent)
self._tape = tape.push_new_tape(persistent=self._persistent)
return self
def __exit__(self, typ, value, traceback):
self._tape = tape.pop_tape()
tape.pop_tape(self._tape)
def watch(self, tensor):
"""Ensures that `tensor` is being traced by this tape.

View File

@ -24,9 +24,12 @@ import copy
import random
import threading
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib
GRAPH_MODE = 0
@ -398,6 +401,36 @@ class Context(object):
"""Get the list of post-execution callbacks added to the context."""
return self._post_execution_callbacks
def enable_run_metadata(self):
"""Enables tracing of op execution via RunMetadata.
To retrieve the accumulated metadata call context.export_run_metadata()
and to stop tracing call context.disable_run_metadata().
"""
pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._context_handle)
def disable_run_metadata(self):
"""Disables tracing of op execution via RunMetadata."""
pywrap_tensorflow.TFE_ContextDisableRunMetadata(self._context_handle)
def export_run_metadata(self):
"""Returns a RunMetadata proto with accumulated information.
The returned protocol buffer contains information since the most recent call
to either enable_run_metadata or export_run_metadata.
Returns:
A RunMetadata protocol buffer.
"""
with c_api_util.tf_buffer() as buffer_:
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.TFE_ContextExportRunMetadata(
self._context_handle, buffer_, status)
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
run_metadata = config_pb2.RunMetadata()
run_metadata.ParseFromString(compat.as_bytes(proto_data))
return run_metadata
_context = None
_context_lock = threading.Lock()
@ -516,3 +549,29 @@ def num_gpus():
The number of available GPU devices.
"""
return context().num_gpus()
def enable_run_metadata():
"""Enables tracing of op execution via RunMetadata.
To retrieve the accumulated metadata call context.export_run_metadata()
and to stop tracing call context.disable_run_metadata().
"""
context().enable_run_metadata()
def disable_run_metadata():
"""Disables tracing of op execution via RunMetadata."""
context().disable_run_metadata()
def export_run_metadata():
"""Returns a RunMetadata proto with accumulated information.
The returned protocol buffer contains information since the most recent call
to either enable_run_metadata or export_run_metadata.
Returns:
A RunMetadata protocol buffer.
"""
return context().export_run_metadata()

View File

@ -84,6 +84,20 @@ class TFETest(test_util.TensorFlowTestCase):
self.assertTrue(has_cpu_device)
del ctx
def testRunMetadata(self):
context.enable_run_metadata()
t = constant_op.constant(1.0)
_ = t + t # Runs an operation which will be in the RunMetadata
run_metadata = context.export_run_metadata()
context.disable_run_metadata()
step_stats = run_metadata.step_stats
self.assertGreater(len(step_stats.dev_stats), 0)
cpu_stats = step_stats.dev_stats[0]
self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0',
cpu_stats.device)
self.assertEqual(len(cpu_stats.node_stats), 1)
self.assertEqual(cpu_stats.node_stats[0].node_name, 'Add')
def testContextStackContainsEagerMode(self):
# Eager execution has been enabled, and no other context
# switch has occurred, so `context_stack` should contain

View File

@ -544,11 +544,12 @@ def _defun_internal(name, func, args, kwds):
func_inputs = _get_defun_inputs(args)
with capture_tensors(captures):
tape.push_new_tape()
this_tape = tape.push_new_tape()
try:
func_outputs = func(*func_inputs, **kwds)
finally:
variables = tape.pop_tape().watched_variables()
tape.pop_tape(this_tape)
variables = this_tape.watched_variables()
# Returning a closed-over tensor as an output does not trigger a
# call to convert_to_tensor, so we manually capture all such tensors.

View File

@ -332,7 +332,7 @@ void EagerTensor_dealloc(EagerTensor* self) {
tensorflow::ClearDecrefCache();
auto id = self->id;
Py_TYPE(self)->tp_free(self);
TFE_Py_TapeStackDeleteTrace(id);
TFE_Py_TapeSetDeleteTrace(id);
}
// Getter for `_id`.

View File

@ -87,22 +87,25 @@ TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
// newly created type, or nullptr on error.
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);
// Pushes a new tape into the thread-local stack.
// `persistent` must be a PyBool_Type, i.e either Py_True or Py_False
void TFE_Py_TapeStackPushNew(PyObject* persistent);
// Creates a new tape and adds it to the active set. `persistent` must be a
// PyBool_Type, i.e. either Py_True or Py_False
PyObject* TFE_Py_TapeSetNew(PyObject* persistent);
// Pops the tape from the top of the stack and returns it.
PyObject* TFE_Py_TapeStackPop();
// Pushes an existing tape onto the stack.
void TFE_Py_TapeStackPush(PyObject* tape);
// Removes the passed tape from the set of active tapes.
void TFE_Py_TapeSetRemove(PyObject* tape);
// Returns true if the tape stack is empty.
PyObject* TFE_Py_TapeStackIsEmpty();
PyObject* TFE_Py_TapeSetIsEmpty();
PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors);
void TFE_Py_TapeStackWatch(PyObject* tensor);
void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id);
PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors);
void TFE_Py_TapeSetWatch(PyObject* tensor);
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id);
// Stops any gradient recording on the current thread.
void TFE_Py_TapeSetStopOnThread();
// Restarts gradient recording on the current thread.
void TFE_Py_TapeSetRestartOnThread();
// Records an operation on all active tapes. `type` is a string for the
// operation type, used in the backprop code. output_tensors should be a list of
@ -111,13 +114,12 @@ void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id);
// operation. backward_function should be the function to be called during
// backprop to, given the gradients of the output tensors, produce the gradients
// of the input tensors.
void TFE_Py_TapeStackRecordOperation(PyObject* op_type,
PyObject* output_tensors,
PyObject* input_tensor_ids,
PyObject* backward_function);
void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
PyObject* input_tensor_ids,
PyObject* backward_function);
// Watches the given variable object on the given tape.
void TFE_Py_TapeStackWatchVariable(PyObject* variable);
void TFE_Py_TapeSetWatchVariable(PyObject* variable);
// Computes a gradient based on information recorded on the tape. `tape` must
// have been produced by TFE_Py_NewTape. `vspace` must be a

View File

@ -538,62 +538,67 @@ static PyTypeObject TFE_Py_Tape_Type = {
"TFE_Py_Tape objects", /* tp_doc */
};
// Note: in the current design no mutex is needed here because of the python
// GIL, which is always held when any TFE_Py_* methods are called. We should
// revisit this if/when we decide to not hold the GIL while manipulating the
// set of active tapes.
static std::unordered_set<TFE_Py_Tape*>* tape_set = nullptr;
std::unordered_set<TFE_Py_Tape*>* GetTapeSet() {
if (tape_set == nullptr) {
tape_set = new std::unordered_set<TFE_Py_Tape*>;
}
return tape_set;
}
// xcode 7 doesn't define thread_local, so for compatibility we implement our
// own. TODO(apassos) remove once we can deprecate xcode 7.
#ifndef __APPLE__
std::vector<TFE_Py_Tape*>* GetTapeStack() {
thread_local std::vector<TFE_Py_Tape*> tape_stack;
return &tape_stack;
bool* ThreadTapeIsStopped() {
thread_local bool thread_tape_is_stopped{false};
return &thread_tape_is_stopped;
}
#else
static tensorflow::mutex stack_mu(tensorflow::LINKER_INITIALIZED);
static std::unordered_map<std::thread::id, std::vector<TFE_Py_Tape*>*>*
tape_stack GUARDED_BY(stack_mu) = nullptr;
std::vector<TFE_Py_Tape*>* GetTapeStack() {
tensorflow::mutex_lock ml(stack_mu);
if (tape_stack == nullptr) {
tape_stack =
new std::unordered_map<std::thread::id, std::vector<TFE_Py_Tape*>*>;
static std::unordered_map<std::thread::id, bool>* tape_is_stopped = nullptr;
bool* ThreadTapeIsStopped() {
if (tape_is_stopped == nullptr) {
tape_is_stopped = new std::unordered_map<std::thread::id, bool>;
}
auto it = tape_stack->find(std::this_thread::get_id());
if (it != tape_stack->end()) {
return it->second;
auto it = tape_is_stopped->find(std::this_thread::get_id());
if (it != tape_is_stopped->end()) {
return &(it->second);
}
return tape_stack
->emplace(std::this_thread::get_id(), new std::vector<TFE_Py_Tape*>)
.first->second;
return &(tape_is_stopped->emplace(std::this_thread::get_id(), false)
.first->second);
}
#endif
void TFE_Py_TapeStackPushNew(PyObject* persistent) {
void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }
void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }
PyObject* TFE_Py_TapeSetNew(PyObject* persistent) {
TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return;
if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
tape->tape = new GradientTape(persistent == Py_True);
GetTapeStack()->push_back(tape);
}
void TFE_Py_TapeStackPush(PyObject* tape) {
Py_INCREF(tape);
GetTapeStack()->push_back(reinterpret_cast<TFE_Py_Tape*>(tape));
GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape));
return reinterpret_cast<PyObject*>(tape);
}
PyObject* TFE_Py_TapeStackIsEmpty() {
if (GetTapeStack()->empty()) {
PyObject* TFE_Py_TapeSetIsEmpty() {
if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
}
PyObject* TFE_Py_TapeStackPop() {
auto* stack = GetTapeStack();
if (stack->empty()) {
PyErr_SetString(PyExc_RuntimeError, "tape stack is empty.");
return nullptr;
}
TFE_Py_Tape* top = stack->back();
stack->pop_back();
return reinterpret_cast<PyObject*>(top);
void TFE_Py_TapeSetRemove(PyObject* tape) {
auto* stack = GetTapeSet();
stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
// We kept a reference to the tape in the set to ensure it wouldn't get
// deleted under us; cleaning it up here.
Py_DECREF(tape);
}
static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
@ -620,12 +625,15 @@ static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
return tensor_ids;
}
PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors) {
PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
if (tensors == Py_None) {
Py_RETURN_FALSE;
}
auto* stack = GetTapeStack();
if (stack->empty()) {
if (*ThreadTapeIsStopped()) {
Py_RETURN_FALSE;
}
auto* tape_set = GetTapeSet();
if (tape_set->empty()) {
Py_RETURN_FALSE;
}
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
@ -642,7 +650,7 @@ PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors) {
tensor_ids.push_back(FastTensorId(item));
}
Py_DECREF(seq);
for (TFE_Py_Tape* tape : *stack) {
for (TFE_Py_Tape* tape : *tape_set) {
if (tape->tape->ShouldRecord(tensor_ids)) {
Py_RETURN_TRUE;
}
@ -650,12 +658,12 @@ PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors) {
Py_RETURN_FALSE;
}
void TFE_Py_TapeStackWatch(PyObject* tensor) {
void TFE_Py_TapeSetWatch(PyObject* tensor) {
tensorflow::int64 tensor_id = FastTensorId(tensor);
if (PyErr_Occurred()) {
return;
}
for (TFE_Py_Tape* tape : *GetTapeStack()) {
for (TFE_Py_Tape* tape : *GetTapeSet()) {
tape->tape->Watch(tensor_id);
}
}
@ -720,8 +728,8 @@ std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
return list;
}
void TFE_Py_TapeStackWatchVariable(PyObject* variable) {
for (TFE_Py_Tape* tape : *GetTapeStack()) {
void TFE_Py_TapeSetWatchVariable(PyObject* variable) {
for (TFE_Py_Tape* tape : *GetTapeSet()) {
tape->tape->WatchVariable(variable);
}
}
@ -736,12 +744,11 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
return result;
}
void TFE_Py_TapeStackRecordOperation(PyObject* op_type,
PyObject* output_tensors,
PyObject* input_tensors,
PyObject* backward_function) {
auto* stack = GetTapeStack();
if (stack->empty()) {
void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
PyObject* input_tensors,
PyObject* backward_function) {
auto* set = GetTapeSet();
if (set->empty()) {
return;
}
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
@ -776,7 +783,7 @@ void TFE_Py_TapeStackRecordOperation(PyObject* op_type,
return;
}
for (TFE_Py_Tape* tape : *stack) {
for (TFE_Py_Tape* tape : *set) {
Py_INCREF(backward_function);
tape->tape->RecordOperation(
op_type_str, output_info, input_ids, backward_function,
@ -784,8 +791,8 @@ void TFE_Py_TapeStackRecordOperation(PyObject* op_type,
}
}
void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id) {
for (TFE_Py_Tape* tape : *GetTapeStack()) {
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
for (TFE_Py_Tape* tape : *GetTapeSet()) {
tape->tape->DeleteTrace(tensor_id);
}
}

View File

@ -35,7 +35,8 @@ class Tape(object):
def push_new_tape(persistent=False):
"""Pushes a new tape onto the tape stack."""
pywrap_tensorflow.TFE_Py_TapeStackPushNew(persistent)
tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent)
return Tape(tape)
def watch(tensor):
@ -44,7 +45,7 @@ def watch(tensor):
Args:
tensor: tensor to be watched.
"""
pywrap_tensorflow.TFE_Py_TapeStackWatch(tensor)
pywrap_tensorflow.TFE_Py_TapeSetWatch(tensor)
def watch_variable(variable):
@ -53,42 +54,39 @@ def watch_variable(variable):
Args:
variable: variable to be watched.
"""
pywrap_tensorflow.TFE_Py_TapeStackWatchVariable(variable)
pywrap_tensorflow.TFE_Py_TapeSetWatchVariable(variable)
def pop_tape():
def pop_tape(tape):
"""Pops the top tape in the stack, if any."""
return Tape(pywrap_tensorflow.TFE_Py_TapeStackPop())
pywrap_tensorflow.TFE_Py_TapeSetRemove(tape._tape) # pylint: disable=protected-access
@contextlib.contextmanager
def stop_recording():
stack = []
while not pywrap_tensorflow.TFE_Py_TapeStackIsEmpty():
stack.append(pop_tape()._tape) # pylint: disable=protected-access
try:
pywrap_tensorflow.TFE_Py_TapeSetStopOnThread()
yield
finally:
for tape in reversed(stack):
pywrap_tensorflow.TFE_Py_TapeStackPush(tape)
pywrap_tensorflow.TFE_Py_TapeSetRestartOnThread()
def should_record(tensors):
"""Returns true if any tape in the stack watches any of these tensors."""
return pywrap_tensorflow.TFE_Py_TapeStackShouldRecord(tensors)
return pywrap_tensorflow.TFE_Py_TapeSetShouldRecord(tensors)
def record_operation(op_type, output_tensors, input_tensors, backward_function):
"""Records the operation on all tapes in the stack."""
pywrap_tensorflow.TFE_Py_TapeStackRecordOperation(
pywrap_tensorflow.TFE_Py_TapeSetRecordOperation(
op_type, output_tensors, input_tensors, backward_function)
def delete_trace(tensor_id):
"""Deletes traces for this Tensor from all tapes in the stack."""
pywrap_tensorflow.TFE_Py_TapeStackDeleteTrace(tensor_id)
pywrap_tensorflow.TFE_Py_TapeSetDeleteTrace(tensor_id)
def could_possibly_record():
"""Returns True if any tape is active."""
return not pywrap_tensorflow.TFE_Py_TapeStackIsEmpty()
return not pywrap_tensorflow.TFE_Py_TapeSetIsEmpty()

View File

@ -614,7 +614,7 @@ class Estimator(object):
if isinstance(result, (list, tuple)):
if len(result) != 2:
raise ValueError(
'input_fn should return (feautures, labels) as a len 2 tuple.')
'input_fn should return (features, labels) as a len 2 tuple.')
return result[0], result[1], input_hooks
return result, None, input_hooks

View File

@ -121,7 +121,10 @@ class _WarmStartSettings(
# where ws could be defined as:
# Warm-start all weights in the model (input layer and hidden weights).
# Either the directory or a specific checkpoint can be provided (in the case
# of the former, the latest checkpoint will be used).
ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp")
ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000")
# Warm-start only the embeddings (input layer).
ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp",
@ -348,7 +351,7 @@ def _warmstart_var_with_vocab(var,
# TODO(vihanjain): Support _WarmstartSettings where class vocabularies need
# remapping too.
init = checkpoint_ops._load_and_remap_matrix_initializer(
ckpt_path=saver.latest_checkpoint(prev_ckpt),
ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
old_tensor_name=prev_tensor_name,
new_row_vocab_size=current_vocab_size,
new_col_vocab_size=v_shape[1],

View File

@ -50,9 +50,7 @@ class WarmStartingUtilTest(test.TestCase):
sess.run(variables.global_variables_initializer())
saver = saver_lib.Saver()
ckpt_prefix = os.path.join(self.get_temp_dir(), "model")
ckpt_state_name = "checkpoint"
saver.save(
sess, ckpt_prefix, global_step=0, latest_filename=ckpt_state_name)
saver.save(sess, ckpt_prefix, global_step=0)
def _create_prev_run_var(self,
var_name,
@ -408,6 +406,44 @@ class WarmStartingUtilTest(test.TestCase):
self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
sess)
def testWarmStart_ExplicitCheckpointFile(self):
# Create vocab for sparse column "sc_vocab".
vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
"vocab")
# Create feature column.
sc_vocab = fc.categorical_column_with_vocabulary_file(
"sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4)
# Save checkpoint from which to warm-start.
_, prev_vocab_val = self._create_prev_run_var(
"linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones())
partitioner = lambda shape, dtype: [1] * len(shape)
# New graph, new session WITHOUT warmstarting.
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
sess.run(variables.global_variables_initializer())
# Without warmstarting, the weights should be initialized using default
# initializer (which is init_ops.zeros_initializer).
self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [np.zeros([4, 1])]},
sess)
# New graph, new session with warmstarting.
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
# Since old vocab is not explicitly set in WarmStartSettings, the old
# vocab is assumed to be the same as the new vocab.
ws_util._warmstart(ws_util._WarmStartSettings(
# Explicitly provide the file prefix instead of just the dir.
os.path.join(self.get_temp_dir(), "model-0"),
vars_to_warmstart=".*sc_vocab.*"))
sess.run(variables.global_variables_initializer())
# Verify weights were correctly warmstarted.
self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
sess)
def testWarmStart_SparseColumnVocabularyConstrainedVocabSizes(self):
# Create old vocabulary, and use a size smaller than the total number of
# entries.

View File

@ -59,6 +59,7 @@ def _TestDir(test_name):
# pylint: enable=invalid-name
@test_util.with_c_api
class SimpleMetaGraphTest(test.TestCase):
def testNoVariables(self):
@ -103,7 +104,8 @@ class SimpleMetaGraphTest(test.TestCase):
# Re-exports the current graph state for comparison to the original.
new_meta_graph_def, _ = meta_graph.export_scoped_meta_graph(filename +
"_new")
self.assertProtoEquals(meta_graph_def, new_meta_graph_def)
test_util.assert_meta_graph_protos_equal(self, meta_graph_def,
new_meta_graph_def)
# Ensures that we can still get a reference to our graph collections.
new_input_tensor = ops.get_collection("input_tensor")[0]
@ -226,7 +228,7 @@ class SimpleMetaGraphTest(test.TestCase):
double_nested_complex_node_def = None
for function_def in meta_graph_def.graph_def.library.function:
for node_def in function_def.node_def:
if node_def.name == "double_nested_complex":
if node_def.name.startswith("double_nested_complex"):
double_nested_complex_node_def = node_def
break
if double_nested_complex_node_def:
@ -258,6 +260,7 @@ class SimpleMetaGraphTest(test.TestCase):
self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
@test_util.with_c_api
class ScopedMetaGraphTest(test.TestCase):
def _testScopedExport(self, test_dir, exported_filenames):
@ -435,10 +438,13 @@ class ScopedMetaGraphTest(test.TestCase):
]
orig_meta_graphs = self._testScopedExport(test_dir, filenames)
new_meta_graphs = self._testScopedImport(test_dir, filenames)
# Delete the unbound_inputs to allow directly calling ProtoEqual.
del orig_meta_graphs[0].collection_def["unbound_inputs"]
del new_meta_graphs[0].collection_def["unbound_inputs"]
for a, b in zip(orig_meta_graphs, new_meta_graphs):
# The unbound input strings are slightly different with the C API enabled
# ("images" vs "images:0") due to the original import_graph_def code
# vs. ImportGraphDef in C++.
# TODO(skyewm): update the pbtxts once _USE_C_API is removed.
del a.collection_def["unbound_inputs"]
del b.collection_def["unbound_inputs"]
test_util.assert_meta_graph_protos_equal(self, a, b)
def testScopedImportUnderNameScope(self):
@ -572,7 +578,8 @@ class ScopedMetaGraphTest(test.TestCase):
"exported_queue1.pbtxt")
new_meta_graph = self._testScopedImportWithQueue(
test_dir, "exported_queue1.pbtxt", "exported_new_queue1.pbtxt")
self.assertProtoEquals(orig_meta_graph, new_meta_graph)
test_util.assert_meta_graph_protos_equal(self, orig_meta_graph,
new_meta_graph)
# Verifies that we can export a subgraph in a nested name scope containing a
# "hidden1/hidden2" and import it into "new_hidden1/new_hidden2" in a new
@ -718,6 +725,7 @@ class ScopedMetaGraphTest(test.TestCase):
self.assertEqual("", str(graph2.as_graph_element("matmul").device))
@test_util.with_c_api
class MetaGraphWithVariableScopeTest(test.TestCase):
def testMetricsCollection(self):
@ -775,6 +783,7 @@ class MetaGraphWithVariableScopeTest(test.TestCase):
initializer = variables.local_variables_initializer()
@test_util.with_c_api
class ExportImportAcrossScopesTest(test.TestCase):
def testPartionedVariables(self):
@ -845,7 +854,7 @@ class ExportImportAcrossScopesTest(test.TestCase):
if shared_name_value.s:
node.attr[shared_name_attr].s = b""
self.assertProtoEquals(expected, result)
test_util.assert_meta_graph_protos_equal(self, expected, result)
if __name__ == "__main__":

View File

@ -162,6 +162,16 @@ def assert_meta_graph_protos_equal(tester, a, b):
# proto comparison below.
a.ClearField("collection_def")
b.ClearField("collection_def")
# Check the graph_defs.
assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True)
# Check graph_def versions (ignored by assert_equal_graph_def).
tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions)
# The graph_def fields have been compared directly above; remove their raw
# values from the proto comparison below.
a.ClearField("graph_def")
b.ClearField("graph_def")
tester.assertProtoEquals(a, b)
@ -178,7 +188,7 @@ def _strip_checkpoint_v2_randomized(graph_def):
if attr_tensor_value and len(attr_tensor_value.string_val) == 1:
attr_tensor_string_value = attr_tensor_value.string_val[0]
if (attr_tensor_string_value and
re.match(_SHARDED_SAVE_OP_PATTERN, attr_tensor_string_value)):
re.match(_SHARDED_SAVE_OP_PATTERN, str(attr_tensor_string_value))):
delete_keys.append(attr_key)
for attr_key in delete_keys:
del node.attr[attr_key]

View File

@ -25,6 +25,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
@ -55,6 +56,7 @@ def _logit(x):
return np.log(x) - np.log1p(-x)
@test_util.with_c_api
class AssertCloseTest(test.TestCase):
def testAssertCloseIntegerDtype(self):
@ -145,6 +147,7 @@ class AssertCloseTest(test.TestCase):
array_ops.identity(w).eval(feed_dict=feed_dict)
@test_util.with_c_api
class GetLogitsAndProbsTest(test.TestCase):
def testImproperArguments(self):
@ -298,6 +301,7 @@ class GetLogitsAndProbsTest(test.TestCase):
logit.eval(feed_dict={l: np.ones([int(2**11+1)])})
@test_util.with_c_api
class EmbedCheckCategoricalEventShapeTest(test.TestCase):
def testTooSmall(self):
@ -335,6 +339,7 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase):
du.embed_check_categorical_event_shape(param)
@test_util.with_c_api
class EmbedCheckIntegerCastingClosedTest(test.TestCase):
def testCorrectlyAssertsNonnegative(self):
@ -370,6 +375,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase):
x_checked.eval(feed_dict={x: np.array([1, -1], dtype=np.int32)})
@test_util.with_c_api
class LogCombinationsTest(test.TestCase):
def testLogCombinationsBinomial(self):
@ -400,6 +406,7 @@ class LogCombinationsTest(test.TestCase):
self.assertEqual([2, 2], log_binom.get_shape())
@test_util.with_c_api
class DynamicShapeTest(test.TestCase):
def testSameDynamicShape(self):
@ -504,6 +511,7 @@ class DynamicShapeTest(test.TestCase):
}))
@test_util.with_c_api
class RotateTransposeTest(test.TestCase):
def _np_rotate_transpose(self, x, shift):
@ -537,6 +545,7 @@ class RotateTransposeTest(test.TestCase):
shift: shift_value}))
@test_util.with_c_api
class PickVectorTest(test.TestCase):
def testCorrectlyPicksVector(self):
@ -557,6 +566,7 @@ class PickVectorTest(test.TestCase):
constant_op.constant(False), x, y)) # No eval.
@test_util.with_c_api
class PreferStaticRankTest(test.TestCase):
def testNonEmptyConstantTensor(self):
@ -596,6 +606,7 @@ class PreferStaticRankTest(test.TestCase):
self.assertAllEqual(0, rank.eval(feed_dict={x: 1}))
@test_util.with_c_api
class PreferStaticShapeTest(test.TestCase):
def testNonEmptyConstantTensor(self):
@ -635,6 +646,7 @@ class PreferStaticShapeTest(test.TestCase):
self.assertAllEqual(np.array([]), shape.eval(feed_dict={x: 1}))
@test_util.with_c_api
class PreferStaticValueTest(test.TestCase):
def testNonEmptyConstantTensor(self):
@ -675,6 +687,7 @@ class PreferStaticValueTest(test.TestCase):
self.assertAllEqual(np.array(1), value.eval(feed_dict={x: 1}))
@test_util.with_c_api
class FillTriangularTest(test.TestCase):
def setUp(self):
@ -769,6 +782,7 @@ class FillTriangularTest(test.TestCase):
self._run_test(self._rng.randn(2, 3, int(7*8/2)), upper=True)
@test_util.with_c_api
class ReduceWeightedLogSumExp(test.TestCase):
def _reduce_weighted_logsumexp(self, logx, w, axis, keep_dims=False):
@ -865,6 +879,7 @@ class ReduceWeightedLogSumExp(test.TestCase):
du.reduce_weighted_logsumexp(x, w, axis=[0, 1]).eval())
@test_util.with_c_api
class GenNewSeedTest(test.TestCase):
def testOnlyNoneReturnsNone(self):
@ -875,6 +890,7 @@ class GenNewSeedTest(test.TestCase):
# TODO(jvdillon): Merge this test back into:
# tensorflow/python/kernel_tests/softplus_op_test.py
# once TF core is accepting new ops.
@test_util.with_c_api
class SoftplusTest(test.TestCase):
def _npSoftplus(self, np_features):

View File

@ -20,21 +20,25 @@ limitations under the License.
%rename("%s") TFE_ContextListDevices;
%rename("%s") TFE_ContextAddFunction;
%rename("%s") TFE_ContextAddFunctionDef;
%rename("%s") TFE_ContextEnableRunMetadata;
%rename("%s") TFE_ContextDisableRunMetadata;
%rename("%s") TFE_ContextExportRunMetadata;
%rename("%s") TFE_ContextClearCaches;
%rename("%s") TFE_OpNameGetAttrType;
%rename("%s") TFE_Py_InitEagerTensor;
%rename("%s") TFE_Py_RegisterExceptionClass;
%rename("%s") TFE_Py_Execute;
%rename("%s") TFE_Py_UID;
%rename("%s") TFE_Py_TapeStackPushNew;
%rename("%s") TFE_Py_TapeStackPush;
%rename("%s") TFE_Py_TapeStackPop;
%rename("%s") TFE_Py_TapeStackIsEmpty;
%rename("%s") TFE_Py_TapeStackShouldRecord;
%rename("%s") TFE_Py_TapeStackWatch;
%rename("%s") TFE_Py_TapeStackDeleteTrace;
%rename("%s") TFE_Py_TapeStackRecordOperation;
%rename("%s") TFE_Py_TapeStackWatchVariable;
%rename("%s") TFE_Py_TapeSetNew;
%rename("%s") TFE_Py_TapeSetRemove;
%rename("%s") TFE_Py_TapeSetStopOnThread;
%rename("%s") TFE_Py_TapeSetRestartOnThread;
%rename("%s") TFE_Py_TapeSetIsEmpty;
%rename("%s") TFE_Py_TapeSetShouldRecord;
%rename("%s") TFE_Py_TapeSetWatch;
%rename("%s") TFE_Py_TapeSetDeleteTrace;
%rename("%s") TFE_Py_TapeSetRecordOperation;
%rename("%s") TFE_Py_TapeSetWatchVariable;
%rename("%s") TFE_Py_TapeGradient;
%rename("%s") TFE_Py_TapeWatchedVariables;
%rename("%s") TFE_NewContextOptions;

View File

@ -455,11 +455,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
"https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/cb27e1d0da7f30562ea6c1c4f01393afbf112620.tar.gz",
"https://github.com/llvm-mirror/llvm/archive/cb27e1d0da7f30562ea6c1c4f01393afbf112620.tar.gz",
"https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/81a623d6d61cf87847f839e80b047c267020ab0e.tar.gz",
"https://github.com/llvm-mirror/llvm/archive/81a623d6d61cf87847f839e80b047c267020ab0e.tar.gz",
],
sha256 = "d4e4d17040a786bab13bb1b73ec2dc358f0c07214f847076e0ded8de15800782",
strip_prefix = "llvm-cb27e1d0da7f30562ea6c1c4f01393afbf112620",
sha256 = "be0259c0bd5349200df346c92ba7708341e18ef313fcf7398682b5cff2469137",
strip_prefix = "llvm-81a623d6d61cf87847f839e80b047c267020ab0e",
build_file = str(Label("//third_party/llvm:llvm.BUILD")),
)

View File

@ -110,11 +110,7 @@ def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp):
lang = "c++"
else:
lang = "c"
# TODO: We pass -no-canonical-prefixes here to match the compiler flags,
# but in cuda_clang CROSSTOOL file that is a `feature` and we should
# handle the case when it's disabled and no flag is passed
result = repository_ctx.execute([cc, "-no-canonical-prefixes",
"-E", "-x" + lang, "-", "-v"])
result = repository_ctx.execute([cc, "-E", "-x" + lang, "-", "-v"])
index1 = result.stderr.find(_INC_DIR_MARKER_BEGIN)
if index1 == -1:
return []