Merge commit for internal changes

commit fc8b359214
@@ -542,7 +542,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
 // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
 // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def,
 // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation
-// of FunctionLibraryRuntime tells use that func_lib_def is not accessed by
+// of FunctionLibraryRuntime tells us that func_lib_def is not accessed by
 // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here.
 // This is quite subtle. Re-work things to make this better? (Would it make
 // sense for FunctionLibraryRuntime to ensure thread-safe access to
@@ -302,6 +302,7 @@ cc_library(
         ":array4d",
         ":shape_tree",
         ":shape_util",
+        ":sparse_index_array",
         ":status_macros",
         ":types",
         ":util",
@@ -628,6 +629,28 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "sparse_index_array",
+    srcs = ["sparse_index_array.cc"],
+    hdrs = ["sparse_index_array.h"],
+    deps = [
+        ":array2d",
+        ":shape_util",
+        ":xla_data_proto",
+        "//tensorflow/core:lib",
+    ],
+)
+
+tf_cc_test(
+    name = "sparse_index_array_test",
+    srcs = ["sparse_index_array_test.cc"],
+    deps = [
+        ":sparse_index_array",
+        ":test",
+        "//tensorflow/core:test_main",
+    ],
+)
+
 # -----------------------------------------------------------------------------
 
 filegroup(
@@ -149,4 +149,33 @@ namespace xla {
   return stride;
 }
 
+/* static */ bool IndexUtil::IndexInBounds(
+    const Shape& shape, tensorflow::gtl::ArraySlice<int64> index) {
+  int64 rank = ShapeUtil::Rank(shape);
+  if (rank != index.size()) {
+    return false;
+  }
+  for (int64 d = 0; d < rank; ++d) {
+    if (index[d] >= shape.dimensions(d)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/* static */ int IndexUtil::CompareIndices(
+    tensorflow::gtl::ArraySlice<int64> lhs,
+    tensorflow::gtl::ArraySlice<int64> rhs) {
+  int64 rank = lhs.size();
+  CHECK_EQ(rhs.size(), rank);
+  for (int64 dim = 0; dim < rank; ++dim) {
+    if (lhs[dim] < rhs[dim]) {
+      return -1;
+    } else if (lhs[dim] > rhs[dim]) {
+      return 1;
+    }
+  }
+  return 0;
+}
+
 }  // namespace xla
@@ -69,6 +69,18 @@ class IndexUtil {
   //   sizeof(dimension(3)) * sizeof(dimension(2)) == 4 * 10
   static int64 GetDimensionStride(const Shape& shape, int64 dimension);
 
+  // Returns true iff the given multi-index is contained in the bounds for the
+  // shape.
+  static bool IndexInBounds(const Shape& shape,
+                            tensorflow::gtl::ArraySlice<int64> index);
+
+  // Compares the given indices in lexicographic order. lhs[0] and rhs[0] are
+  // compared first, and lhs[rank-1] and rhs[rank-1] last. If lhs is smaller,
+  // then -1 is returned. If lhs is larger, then 1 is returned. Otherwise, 0 is
+  // returned.
+  static int CompareIndices(tensorflow::gtl::ArraySlice<int64> lhs,
+                            tensorflow::gtl::ArraySlice<int64> rhs);
+
  private:
   TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil);
 };
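For illustration, a minimal usage sketch of the two new helpers (my example, not code from this commit; it assumes the xla namespace, the int64 alias, and the CHECK macros used throughout this tree):

    // IndexInBounds: rank must match and every coordinate must be strictly
    // less than the corresponding dimension.
    xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {4, 10});
    CHECK(xla::IndexUtil::IndexInBounds(shape, {3, 9}));   // in bounds
    CHECK(!xla::IndexUtil::IndexInBounds(shape, {4, 0}));  // 4 >= dimension 0

    // CompareIndices: a lexicographic three-way comparator, usable with
    // std::sort to order sparse multi-indices.
    std::vector<std::vector<xla::int64>> indices = {{3, 4}, {1, 2}, {2, 3}};
    std::sort(indices.begin(), indices.end(),
              [](const std::vector<xla::int64>& a,
                 const std::vector<xla::int64>& b) {
                return xla::IndexUtil::CompareIndices(a, b) < 0;
              });
    // indices is now {{1, 2}, {2, 3}, {3, 4}}.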
@@ -64,6 +64,13 @@ void SetDefaultLayoutToContainer(
   return layout;
 }
 
+/* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) {
+  Layout layout;
+  layout.set_format(SPARSE);
+  layout.set_max_sparse_elements(max_sparse_elements);
+  return layout;
+}
+
 namespace {
 
 // Internal helper that creates a default layout for an array of the given rank.
@@ -234,7 +241,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
   LayoutUtil::ClearLayout(program_shape->mutable_result());
 }
 
-/* static */ bool LayoutUtil::IsDense(const Shape& shape) {
+/* static */ bool LayoutUtil::IsDenseArray(const Shape& shape) {
   return ShapeUtil::IsArray(shape) && shape.has_layout() &&
          IsDense(shape.layout());
 }
@@ -260,7 +267,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
       shape.layout().padded_dimensions_size() == 0) {
     return false;
   }
-  CHECK(IsDense(shape));
+  CHECK(IsDenseArray(shape));
   CHECK_EQ(shape.dimensions_size(), shape.layout().padded_dimensions_size());
   for (int64 i = 0; i < shape.dimensions_size(); ++i) {
     if (shape.layout().padded_dimensions(i) > shape.dimensions(i)) {
@@ -272,21 +279,35 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
 
 /* static */ tensorflow::gtl::ArraySlice<int64> LayoutUtil::PaddedDimensions(
     const Shape& shape) {
-  CHECK(IsDense(shape));
+  CHECK(IsDenseArray(shape));
   return AsInt64Slice(shape.layout().padded_dimensions());
 }
 
 /* static */ int64 LayoutUtil::PaddedDimension(const Shape& shape,
                                                int64 index) {
-  CHECK(IsDense(shape));
+  CHECK(IsDenseArray(shape));
   return shape.layout().padded_dimensions(index);
 }
 
 /* static */ PaddingValue LayoutUtil::GetPaddingValue(const Shape& shape) {
-  CHECK(IsDense(shape));
+  CHECK(IsDenseArray(shape));
   return shape.layout().padding_value();
 }
 
+/* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) {
+  return ShapeUtil::IsArray(shape) && shape.has_layout() &&
+         IsSparse(shape.layout());
+}
+
+/* static */ bool LayoutUtil::IsSparse(const Layout& layout) {
+  return layout.format() == SPARSE;
+}
+
+/* static */ int64 LayoutUtil::MaxSparseElements(const Layout& layout) {
+  CHECK(IsSparse(layout));
+  return layout.max_sparse_elements();
+}
+
 /* static */ bool LayoutUtil::HasLayout(const Shape& shape) {
   if (ShapeUtil::IsTuple(shape)) {
     // Tuple shape: all subshapes must have a layout.
@@ -313,7 +334,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
 
 /* static */ tensorflow::gtl::ArraySlice<int64> LayoutUtil::MinorToMajor(
     const Shape& shape) {
-  CHECK(IsDense(shape));
+  CHECK(IsDenseArray(shape));
   return AsInt64Slice(shape.layout().minor_to_major());
 }
 
@@ -419,6 +440,7 @@ tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src,
 
 /* static */ bool LayoutUtil::AreDimensionsConsecutive(
     const Layout& layout, tensorflow::gtl::ArraySlice<int64> dims) {
+  CHECK(IsDense(layout));
   std::vector<int64> positions_in_layout;
   for (int64 dim : dims) {
     positions_in_layout.push_back(
@@ -36,6 +36,10 @@ class LayoutUtil {
   // convenience function for protobuf construction.)
   static Layout MakeLayout(tensorflow::gtl::ArraySlice<int64> minor_to_major);
 
+  // Creates a sparse layout with the given maximum number of elements. (This is
+  // a convenience function for protobuf construction.)
+  static Layout MakeSparseLayout(int64 max_sparse_elements);
+
   // Returns default layout for the given shape.
   static Layout GetDefaultLayoutForShape(const Shape& shape);
 
@@ -72,7 +76,7 @@ class LayoutUtil {
   static void ClearLayout(ProgramShape* program_shape);
 
   // Returns whether the given Shape is an array and has a dense format layout.
-  static bool IsDense(const Shape& shape);
+  static bool IsDenseArray(const Shape& shape);
 
   // Returns whether the given Layout has a dense format.
   static bool IsDense(const Layout& layout);
@@ -107,6 +111,17 @@ class LayoutUtil {
   // an array and has a dense layout.
   static PaddingValue GetPaddingValue(const Shape& shape);
 
+  // Returns whether the given Shape is an array (i.e. not a tuple) and has a
+  // sparse format layout.
+  static bool IsSparseArray(const Shape& shape);
+
+  // Returns whether the given Layout has a sparse format.
+  static bool IsSparse(const Layout& layout);
+
+  // Returns the maximum number of elements that can be stored in a sparse
+  // layout.
+  static int64 MaxSparseElements(const Layout& layout);
+
   // Returns whether the given shape has a layout. For tuple shapes, true is
   // returned only if all elements have layouts.
   static bool HasLayout(const Shape& shape);
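A minimal sketch of the new sparse-layout helpers in use (my illustration, not part of the diff; assumes the xla namespace and CHECK macros):

    // Build a sparse layout that can hold at most 10 nonzeros and query it.
    xla::Layout layout =
        xla::LayoutUtil::MakeSparseLayout(/*max_sparse_elements=*/10);
    CHECK(xla::LayoutUtil::IsSparse(layout));
    CHECK(!xla::LayoutUtil::IsDense(layout));
    CHECK_EQ(xla::LayoutUtil::MaxSparseElements(layout), 10);

    // A shape carrying that layout is a sparse array.
    xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {100, 100});
    *shape.mutable_layout() = layout;
    CHECK(xla::LayoutUtil::IsSparseArray(shape));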
@@ -30,6 +30,14 @@ class LayoutUtilTest : public ::testing::Test {
     *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
     return shape;
   }
+
+  Shape MakeShapeWithSparseLayout(PrimitiveType element_type,
+                                  tensorflow::gtl::ArraySlice<int64> dimensions,
+                                  int64 max_sparse_elements) {
+    Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
+    *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
+    return shape;
+  }
 };
 
 TEST_F(LayoutUtilTest, TupleLayoutComparison) {
@@ -81,6 +89,29 @@ TEST_F(LayoutUtilTest, CopyLayoutArray) {
   EXPECT_FALSE(dst.has_layout());
 }
 
+TEST_F(LayoutUtilTest, CopyLayoutSparse) {
+  Shape src = MakeShapeWithSparseLayout(F32, {2, 3}, 2);
+  Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0});
+
+  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
+  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+
+  // Should work if destination has no layout.
+  dst.clear_layout();
+  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
+  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+
+  // If source is cleared, then destination should be cleared.
+  src.clear_layout();
+  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+  EXPECT_TRUE(dst.has_layout());
+  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
+  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+  EXPECT_FALSE(dst.has_layout());
+}
+
 TEST_F(LayoutUtilTest, CopyLayoutTuple) {
   Shape src = ShapeUtil::MakeTupleShape(
       {MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
@@ -100,6 +131,25 @@ TEST_F(LayoutUtilTest, CopyLayoutTuple) {
   EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
 }
 
+TEST_F(LayoutUtilTest, CopyLayoutTupleSparse) {
+  Shape src = ShapeUtil::MakeTupleShape(
+      {MakeShapeWithSparseLayout(F32, {2, 3}, 4),
+       MakeShapeWithSparseLayout(F32, {42, 123}, 4),
+       ShapeUtil::MakeTupleShape(
+           {MakeShapeWithLayout(F32, {}, {}),
+            MakeShapeWithSparseLayout(F32, {1, 2, 3}, 6)})});
+  Shape dst = ShapeUtil::MakeTupleShape(
+      {MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
+       MakeShapeWithLayout(F32, {42, 123}, {1, 0}),
+       ShapeUtil::MakeTupleShape(
+           {MakeShapeWithLayout(F32, {}, {}),
+            MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})});
+
+  EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+  EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
+  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+}
+
 TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) {
   Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1});
   Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0});
@@ -107,6 +157,13 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) {
   EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
 }
 
+TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleSameRank) {
+  Shape src = MakeShapeWithSparseLayout(F32, {123, 42, 7}, 6);
+  Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0});
+  ASSERT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
+  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+}
+
 TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) {
   Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1});
   Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0});
@@ -116,6 +173,15 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) {
               ::testing::ContainsRegex("cannot copy layout from shape"));
 }
 
+TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleDifferentRank) {
+  Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1});
+  Shape dst = MakeShapeWithSparseLayout(F32, {2, 3}, 4);
+  auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst);
+  EXPECT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(),
+              ::testing::ContainsRegex("cannot copy layout from shape"));
+}
+
 TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) {
   Shape src =
       ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
@@ -221,5 +287,10 @@ TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) {
                 ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25}))));
 }
 
+TEST_F(LayoutUtilTest, SparseLayoutMaxElements) {
+  EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)),
+            101);
+}
+
 }  // namespace
 }  // namespace xla
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -94,9 +94,15 @@ Literal::Literal(const Shape& shape, bool allocate_arrays)
     Piece& piece = pair.second;
 
     piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
-    if (ShapeUtil::IsArray(piece.subshape())) {
+    const Shape& subshape = piece.subshape();
+    if (ShapeUtil::IsArray(subshape)) {
       if (allocate_arrays) {
         piece.set_buffer(new char[piece.size_bytes()]);
+        if (LayoutUtil::IsSparseArray(subshape)) {
+          piece.set_sparse_indices(new SparseIndexArray(
+              LayoutUtil::MaxSparseElements(subshape.layout()),
+              ShapeUtil::Rank(subshape)));
+        }
       } else {
         piece.set_buffer(nullptr);
       }
@@ -112,6 +118,7 @@ void Literal::DeallocateBuffers() {
     Piece& piece = pair.second;
     if (piece.buffer() != nullptr) {
       delete[] piece.buffer();
+      delete piece.sparse_indices();
    }
  }
 }
@@ -164,6 +171,15 @@ std::unique_ptr<Literal> Literal::CreateFromShape(const Shape& shape) {
   return literal;
 }
 
+const SparseIndexArray* Literal::sparse_indices(
+    const ShapeIndex& shape_index) const {
+  return piece(shape_index).sparse_indices();
+}
+
+SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) {
+  return piece(shape_index).sparse_indices();
+}
+
 /* static */ std::unique_ptr<Literal> Literal::CreateFromDimensions(
     PrimitiveType primitive_type,
     tensorflow::gtl::ArraySlice<int64> dimensions) {
@@ -247,9 +263,12 @@ std::vector<Literal> Literal::DecomposeTuple() {
       }
       Piece& src_piece = piece(src_index);
 
-      // Move the respective buffer over to the element Literal.
+      // Move the respective buffer and sparse indices over to the element
+      // Literal.
       dest_piece.set_buffer(src_piece.buffer());
       src_piece.set_buffer(nullptr);
+      dest_piece.set_sparse_indices(src_piece.sparse_indices());
+      src_piece.set_sparse_indices(nullptr);
     }
   }
   // Set this literal to be nil-shaped.
@@ -406,6 +425,8 @@ Status Literal::MoveFrom(Literal&& src_literal,
     Piece& dest_piece = piece(dest_index);
     delete[] dest_piece.buffer();
     dest_piece.set_buffer(src_piece.buffer());
+    delete dest_piece.sparse_indices();
+    dest_piece.set_sparse_indices(src_piece.sparse_indices());
   }
 
   src_literal.shape_ = ShapeUtil::MakeNil();
@@ -764,7 +785,7 @@ std::unique_ptr<Literal> Literal::Transpose(
   // dimension has within the transposed array, a layout is affine if
   // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
   // vector of the affine layout.
-  CHECK(LayoutUtil::IsDense(permuted_shape));
+  CHECK(LayoutUtil::IsDenseArray(permuted_shape));
   Layout* layout = permuted_shape.mutable_layout();
   layout->clear_minor_to_major();
   for (auto index : LayoutUtil::MinorToMajor(shape())) {
@@ -1573,6 +1594,12 @@ LiteralProto Literal::ToProto() const {
     }
     piece.WriteToProto(proto_piece);
   }
 
+  if (LayoutUtil::IsSparseArray(shape())) {
+    CopyToRepeatedField(proto.mutable_sparse_indices(),
+                        sparse_indices()->data());
+  }
+
   return proto;
 }
 
@@ -1653,6 +1680,7 @@ LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) {
     }
     const Piece& src_piece = literal.piece(src_index);
     piece.set_buffer(src_piece.buffer());
+    piece.set_sparse_indices(src_piece.sparse_indices());
    piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
   }
 }
@@ -36,6 +36,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/shape_tree.h"
 #include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/sparse_index_array.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
@@ -103,6 +104,12 @@ class Literal {
   tensorflow::gtl::MutableArraySlice<NativeT> data(
       const ShapeIndex& shape_index = {});
 
+  // Returns a pointer to the sparse index array. Returns nullptr if the literal
+  // is not a sparse array.
+  const SparseIndexArray* sparse_indices(
+      const ShapeIndex& shape_index = {}) const;
+  SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
+
   // Returns a pointer to (or size of) the underlying buffer holding the array
   // at the given shape index. CHECKs if the subshape of the literal at the
   // given ShapeIndex is not array.
@@ -160,6 +167,56 @@ class Literal {
   // array.
   string GetR1U8AsString() const;
 
+  // Creates a literal with a sparse layout and the given indices and values.
+  // The shape is initialized from the given dimensions. The minor dimension of
+  // the indices array must equal the rank of the shape (i.e. size of the
+  // dimensions array). The major dimension of the indices array must equal the
+  // number of elements in the values array. The maximum number of elements in
+  // the array is taken from the max_indices() value of the index array.
+  //
+  // XLA assumes that sparse literals are in sorted order for all operations. If
+  // the `sort` argument is true, then the indices and values will be sorted
+  // while copying them into the literal. If you have ensured that the indices
+  // and values are already sorted, then you may set the `sort` argument to
+  // false to skip the sorting step.
+  //
+  // For example:
+  //
+  //   CreateSparse(
+  //     {12, 12, 12},
+  //     SparseIndexArray(10, 3,
+  //                      Array2D{
+  //                        {0, 1, 2},
+  //                        {3, 4, 5},
+  //                        {6, 7, 8},
+  //                        {9, 10, 11},
+  //                      }),
+  //     {1.0, 2.0, 3.0, 4.0})
+  //
+  // This creates an array with shape F64[12,12,12]sparse{10}, that has the
+  // following non-zero values:
+  //
+  //   [0, 1, 2]: 1.0
+  //   [3, 4, 5]: 2.0
+  //   [6, 7, 8]: 3.0
+  //   [9, 10, 11]: 4.0
+  //
+  template <typename NativeT>
+  static std::unique_ptr<Literal> CreateSparse(
+      tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
+      tensorflow::gtl::ArraySlice<NativeT> values, bool sort = true);
+
+  // Populates a literal with a sparse layout with the given indices and values.
+  // Each index in the indices array is CHECKed against the dimensions in the
+  // literal's shape. If sort is true, then the indices and values will be
+  // sorted. If sort is false, then the indices and values are assumed to
+  // already be in sorted order. See CreateSparse for an example of how data
+  // are populated.
+  template <typename NativeT>
+  void PopulateSparse(SparseIndexArray indices,
+                      tensorflow::gtl::ArraySlice<NativeT> values,
+                      bool sort = true);
+
   // Creates a new Literal object with the shape specified as parameter.
   // The content of the literal values is the default value of the primitive
   // type of literal itself (0 for numeric types, and false for predicates).
@@ -358,7 +415,7 @@ class Literal {
       const ShapeIndex& shape_index, NativeT value);
 
   // Overloads of Get and Set for array literals. CHECKs if the literal is not
-  // array-shaped.
+  // array-shaped and dense.
   template <typename NativeT>
   NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
   template <typename NativeT>
@@ -408,6 +465,8 @@ class Literal {
   // This function is useful if you want a polymorphic representation
   // of the tensor's elements (turning it to a string for something
   // like representation in a protobuf).
+  //
+  // This literal must have a dense layout.
   void EachCellAsString(
       const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
                                const string& value)>& per_cell) const;
@@ -447,6 +506,8 @@ class Literal {
   //
   // generator must be a callable of the type
   // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
+  //
+  // This literal must have a dense layout.
   template <typename NativeT, typename FnType>
   Status Populate(const FnType& generator);
 
@@ -485,10 +546,12 @@ class Literal {
   // admonishments about floating-point equality checks apply. We expect you to
   // use this to check for complex values that can be expressed precisely as
   // float pairs e.g. (-0.5, 1.0).
+  //
+  // This literal must have a dense layout.
   bool IsAllComplex(complex64 value) const;
 
   // Returns whether this literal is zero at the specified index. This literal
-  // must be an array.
+  // must be an array with a dense layout.
   bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
 
   // Return the count of the elements in the array at the given shape index in
@@ -563,6 +626,14 @@ class Literal {
     char* buffer() const { return buffer_; }
     void set_buffer(char* buffer) { buffer_ = buffer; }
 
+    // The array of multi-indices that provide the locations of non-zero
+    // elements in a sparse array. Only used if
+    // LayoutUtil::IsSparseArray(shape()) is true.
+    SparseIndexArray* sparse_indices() const { return sparse_indices_; }
+    void set_sparse_indices(SparseIndexArray* sparse_indices) {
+      sparse_indices_ = sparse_indices;
+    }
+
     // Gets or sets the subshape of this piece. This reference points to a
     // subshape within the shape in the containing Literal (Literal::shape_).
     const Shape& subshape() const { return *subshape_; }
@@ -598,6 +669,9 @@ class Literal {
     // For array-shaped pieces, this is the buffer holding the literal data.
     char* buffer_ = nullptr;
 
+    // For sparse arrays, this is the array of indices.
+    SparseIndexArray* sparse_indices_ = nullptr;
+
     // The shape of piece. This points into the shape of the containing Literal
     // (Literal::shape_).
    const Shape* subshape_ = nullptr;
@@ -836,6 +910,21 @@ template <typename NativeT>
   return CreateR4FromArray4DWithLayout(tmp, layout);
 }
 
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> Literal::CreateSparse(
+    tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
+    tensorflow::gtl::ArraySlice<NativeT> values, bool sort) {
+  int64 num_elements = values.size();
+  int64 rank = dimensions.size();
+  CHECK_EQ(num_elements, indices.index_count());
+  CHECK_EQ(rank, indices.rank());
+  auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithSparseLayout(
+      primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
+      indices.max_indices()));
+  literal->PopulateSparse(indices, values, sort);
+  return literal;
+}
+
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR4(
     std::initializer_list<std::initializer_list<
@@ -1044,11 +1133,35 @@ void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
   PopulateFromArray(values);
 }
 
+template <typename NativeT>
+void Literal::PopulateSparse(SparseIndexArray indices,
+                             tensorflow::gtl::ArraySlice<NativeT> values,
+                             bool sort) {
+  CHECK(LayoutUtil::IsSparseArray(shape()));
+  int rank = ShapeUtil::Rank(shape());
+  CHECK_EQ(indices.rank(), rank);
+  int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout());
+  CHECK_LE(indices.max_indices(), max_elements);
+  int64 num_elements = values.size();
+  CHECK_LE(num_elements, max_elements);
+  CHECK_EQ(num_elements, indices.index_count());
+  auto root_data = root_piece().data<NativeT>();
+  root_data.remove_suffix(max_elements - values.size());
+  std::copy(values.begin(), values.end(), root_data.begin());
+  *this->root_piece().sparse_indices() = std::move(indices);
+  if (sort) {
+    auto root_data = this->root_piece().data<NativeT>();
+    root_data.remove_suffix(root_data.size() - num_elements);
+    this->root_piece().sparse_indices()->SortWithValues(root_data);
+  }
+  DCHECK(this->root_piece().sparse_indices()->Validate(shape()));
+}
+
 template <typename NativeT, typename FnType>
 Status Literal::Populate(const FnType& generator) {
   const Shape& this_shape = shape();
   const int64 rank = ShapeUtil::Rank(this_shape);
-  TF_RET_CHECK(ShapeUtil::IsArray(this_shape));
+  TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
   TF_RET_CHECK(this_shape.element_type() ==
                primitive_util::NativeToPrimitiveType<NativeT>());
   tensorflow::gtl::MutableArraySlice<NativeT> literal_data = data<NativeT>();
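Mirroring the doc comment above and the CreateSparse test below, a small usage sketch (my illustration, not from the commit; it uses the two-argument SparseIndexArray(max_indices, indices) constructor seen in the test):

    // Build a sparse F32[8,8,8] literal with three nonzeros and capacity for
    // up to ten nonzeros.
    xla::Array2D<xla::int64> indices = {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}};
    std::vector<float> values = {8.0f, 9.0f, 7.0f};
    auto literal = xla::Literal::CreateSparse<float>(
        /*dimensions=*/{8, 8, 8}, xla::SparseIndexArray(10, indices), values,
        /*sort=*/true);
    // literal->shape() is F32[8,8,8]sparse{10}; literal->sparse_indices()
    // holds the (sorted) coordinates and literal->data<float>() the values.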
@@ -193,6 +193,34 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
   ASSERT_EQ(expected, result);
 }
 
+TEST_F(LiteralUtilTest, CreateSparse) {
+  std::vector<int64> dimensions = {8, 8, 8};
+  Array2D<int64> indices = {
+      {3, 4, 5},
+      {1, 2, 3},
+      {2, 3, 4},
+      {3, 5, 6},
+  };
+  std::vector<int64> values = {7, 8, 9, 10};
+  auto literal = Literal::CreateSparse<int64>(
+      dimensions, SparseIndexArray(indices.n1() + 3, indices), values);
+
+  Array2D<int64> expected_indices = {
+      {1, 2, 3},
+      {2, 3, 4},
+      {3, 4, 5},
+      {3, 5, 6},
+  };
+  std::vector<int64> expected_values = {8, 9, 7, 10};
+
+  EXPECT_EQ(literal->sparse_indices()->data(),
+            tensorflow::gtl::ArraySlice<int64>(
+                expected_indices.data(), expected_indices.num_elements()));
+  EXPECT_EQ(tensorflow::gtl::ArraySlice<int64>(literal->data<int64>().data(),
+                                               expected_values.size()),
+            tensorflow::gtl::ArraySlice<int64>(expected_values));
+}
+
 TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
   // clang-format off
   auto literal = Literal::CreateR4Projected<float>({
@@ -60,6 +60,12 @@ bool ContainsKey(const Collection& collection, const Key& key) {
   return collection.find(key) != collection.end();
 }
 
+// Inserts `value` into `set`. Dies if it was already present.
+template <class Set>
+void InsertOrDie(Set* const set, const typename Set::value_type& value) {
+  CHECK(set->insert(value).second) << "duplicate value: " << value;
+}
+
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_
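A quick sketch of the new helper (mine, not from the commit; assumes the xla namespace and CHECK macros):

    // InsertOrDie gives set insertion with a hard failure on duplicates.
    std::set<xla::int64> seen;
    xla::InsertOrDie(&seen, 42);     // OK: first insertion.
    // xla::InsertOrDie(&seen, 42); // Would CHECK-fail: "duplicate value: 42".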
@@ -1101,6 +1101,8 @@ cc_library(
         ":hlo",
         ":hlo_evaluator",
         ":hlo_pass",
+        ":tuple_util",
+        ":while_util",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/core:lib",
     ],
@@ -2255,6 +2257,78 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "tuple_util",
+    srcs = ["tuple_util.cc"],
+    hdrs = ["tuple_util.h"],
+    deps = [
+        ":hlo",
+        "//tensorflow/core:lib",
+    ],
+)
+
+tf_cc_test(
+    name = "tuple_util_test",
+    srcs = ["tuple_util_test.cc"],
+    deps = [
+        ":tuple_util",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla/service:hlo_matchers",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+    ],
+)
+
+cc_library(
+    name = "while_util",
+    srcs = ["while_util.cc"],
+    hdrs = ["while_util.h"],
+    deps = [
+        ":call_inliner",
+        ":hlo",
+        ":tuple_util",
+    ],
+)
+
+tf_cc_test(
+    name = "while_util_test",
+    srcs = ["while_util_test.cc"],
+    deps = [
+        ":while_util",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla/service:hlo_matchers",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+    ],
+)
+
+cc_library(
+    name = "while_loop_invariant_code_motion",
+    srcs = ["while_loop_invariant_code_motion.cc"],
+    hdrs = ["while_loop_invariant_code_motion.h"],
+    deps = [
+        ":hlo",
+        ":hlo_pass",
+        ":tuple_util",
+        ":while_util",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:lib",
+    ],
+)
+
+tf_cc_test(
+    name = "while_loop_invariant_code_motion_test",
+    srcs = ["while_loop_invariant_code_motion_test.cc"],
+    deps = [
+        ":hlo_matchers",
+        ":while_loop_invariant_code_motion",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+        "//tensorflow/core:test",
+    ],
)
+
 # -----------------------------------------------------------------------------
 
 filegroup(
@@ -82,6 +82,10 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
     return outer_->ReplaceInstruction(call_, new_root);
   }
 
+  CallInliner::InlinedInstructionMap ConsumeInstructionMap() {
+    return std::move(subcomputation_hlo_to_new_hlo_);
+  }
+
  private:
   // Resolves the callee subcomputation_hlo to the new (inline) HLO in the
   // caller computation, or returns a NotFound error if that subcomputation HLO
@@ -112,13 +116,13 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
 
   HloInstruction* call_;
   HloComputation* outer_;
-  std::unordered_map<HloInstruction*, HloInstruction*>
-      subcomputation_hlo_to_new_hlo_;
+  CallInliner::InlinedInstructionMap subcomputation_hlo_to_new_hlo_;
 };
 
 }  // namespace
 
-/* static */ Status CallInliner::Inline(HloInstruction* call) {
+/* static */ StatusOr<CallInliner::InlinedInstructionMap> CallInliner::Inline(
+    HloInstruction* call) {
   TF_RET_CHECK(call->opcode() == HloOpcode::kCall)
       << "Instruction was not a call op: " << call->opcode();
   const auto& callees = call->called_computations();
@@ -126,7 +130,8 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
   HloComputation* callee = callees[0];
   // We visit the callee, cloning its body into its caller.
   SubcomputationInsertionVisitor visitor(call);
-  return callee->Accept(&visitor);
+  TF_RETURN_IF_ERROR(callee->Accept(&visitor));
+  return visitor.ConsumeInstructionMap();
 }
 
 StatusOr<bool> CallInliner::Run(HloModule* module) {
@@ -140,7 +145,7 @@ StatusOr<bool> CallInliner::Run(HloModule* module) {
       VLOG(1) << "Visiting callsite: " << callsite.ToString();
       if (callsite.instruction()->opcode() == HloOpcode::kCall) {
         HloInstruction* call = callsite.instruction();
-        TF_RETURN_IF_ERROR(Inline(call));
+        TF_RETURN_IF_ERROR(Inline(call).status());
         did_mutate = true;
       }
     }
@@ -27,8 +27,12 @@ namespace xla {
 // called function, and proceed recursively.
 class CallInliner : public HloPassInterface {
  public:
-  // Inlines one call instruction.
-  static Status Inline(HloInstruction* call);
+  using InlinedInstructionMap =
+      std::unordered_map<HloInstruction*, HloInstruction*>;
+
+  // Inlines one call instruction. Returns a mapping from the original
+  // instructions to their inlined versions.
+  static StatusOr<InlinedInstructionMap> Inline(HloInstruction* call);
 
   ~CallInliner() override = default;
   tensorflow::StringPiece name() const override { return "CallInliner"; }
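A sketch of the new calling convention (mine, not from the commit; assumes it runs inside a function returning Status, and that `call` and `instr_in_callee` are existing HloInstruction pointers):

    // Inline the kCall, then find where a callee instruction landed in the
    // caller computation.
    TF_ASSIGN_OR_RETURN(CallInliner::InlinedInstructionMap inline_map,
                        CallInliner::Inline(call));
    HloInstruction* inlined = inline_map.at(instr_in_callee);
    // Callers that only care about success keep the old one-liner shape:
    // TF_RETURN_IF_ERROR(CallInliner::Inline(call).status());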
@@ -135,7 +135,7 @@ TEST_F(CallInlinerTest, InlineWithoutRunningPass) {
       HloInstruction::CreateCall(pred, {}, false_computation));
   auto computation = module->AddEntryComputation(call_false_builder.Build());
 
-  TF_ASSERT_OK(CallInliner::Inline(call));
+  TF_ASSERT_OK(CallInliner::Inline(call).status());
   EXPECT_THAT(computation->root_instruction(), op::Constant());
   EXPECT_THAT(computation->root_instruction()->control_successors(),
               ElementsAre(op::Constant()));
@@ -81,6 +81,7 @@ cc_library(
         ":conv_canonicalization",
        ":cpu_copy_insertion",
        ":cpu_executable",
+        ":cpu_hlo_support_checker",
        ":cpu_instruction_fusion",
        ":cpu_layout_assignment",
        ":cpu_options",
@@ -126,6 +127,7 @@ cc_library(
        "//tensorflow/compiler/xla/service:reshape_mover",
        "//tensorflow/compiler/xla/service:transpose_folding",
        "//tensorflow/compiler/xla/service:tuple_simplifier",
+        "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion",
        "//tensorflow/compiler/xla/service:while_loop_simplifier",
        "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination",
        "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",  # fixdeps: keep
@@ -873,6 +875,32 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "cpu_hlo_support_checker",
+    srcs = ["cpu_hlo_support_checker.cc"],
+    hdrs = ["cpu_hlo_support_checker.h"],
+    deps = [
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service:hlo_pass",
+        "//tensorflow/core:lib",
+    ],
+)
+
+tf_cc_test(
+    name = "cpu_hlo_support_checker_test",
+    srcs = ["cpu_hlo_support_checker_test.cc"],
+    deps = [
+        ":cpu_hlo_support_checker",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+    ],
+)
+
 # -----------------------------------------------------------------------------
 
 filegroup(
@@ -31,6 +31,7 @@ limitations under the License.
 #include "llvm/IR/Function.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/Verifier.h"
 #include "llvm/Object/ObjectFile.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/TargetRegistry.h"
@@ -50,6 +51,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
 #include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
 #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
 #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
 #include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h"
 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
@@ -85,6 +87,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/reshape_mover.h"
 #include "tensorflow/compiler/xla/service/transpose_folding.h"
 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
+#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
 #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -258,6 +261,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
   // Optimization pipeline.
   HloPassPipeline pipeline("CPU");
   pipeline.AddInvariantChecker<HloVerifier>(ShapeSizeBytesFunction());
+  pipeline.AddPass<CpuHloSupportChecker>();
 
   ReducePrecisionInsertion::AddPasses(
       &pipeline, module->config().debug_options(),
@@ -291,6 +295,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
   // elimination has to come after that pass.
   pipeline.AddPass<ZeroSizedHloElimination>();
 
+  pass.AddPass<WhileLoopInvariantCodeMotion>();
   pass.AddPass<TupleSimplifier>();
   pass.AddPass<WhileLoopSimplifier>();
   pass.AddPass<HloDCE>();
@@ -439,6 +444,21 @@ Status InitializeModuleHooks(
   return Status::OK();
 }
 
+Status VerifyLlvmModule(const llvm::Module& llvm_module) {
+  XLA_SCOPED_LOGGING_TIMER("CpuCompiler - Running LLVM verifier");
+
+  std::string err;
+  llvm::raw_string_ostream err_stream(err);
+
+  // verifyModule() returns true if the module is broken.
+  TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream))
+      << "Invalid LLVM IR before optimizations:\n"
+      << err_stream.str()
+      << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
+         "Rerun with --xla_dump_ir_to to get the IR. ";
+  return Status::OK();
+}
+
 }  // namespace
 
 StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
@@ -627,6 +647,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
   if (embed_ir_in_executable) {
     ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
   }
+  TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));
 
   // JIT compile the LLVM IR module to in-memory machine code.
   jit->AddModule(std::move(llvm_module));
@@ -704,6 +725,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
   if (embed_ir_in_executable) {
     ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
   }
+  TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));
 
   XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module));
 
@@ -875,6 +897,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
                                       &module_sequence.at(computation)));
 
   CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name));
+  TF_RETURN_IF_ERROR(VerifyLlvmModule(llvm_module));
 
   ModuleHook pre_optimization_ir_dump_hook;
   ModuleHook post_optimization_ir_dump_hook;
tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc (new file)
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+
+StatusOr<bool> CpuHloSupportChecker::Run(HloModule* module) {
+  for (auto* computation : module->computations()) {
+    for (const auto& instruction : computation->instructions()) {
+      TF_RETURN_IF_ERROR(
+          ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));
+      TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+          instruction->shape(),
+          [&instruction](const Shape& subshape, const ShapeIndex&) {
+            if (LayoutUtil::IsSparseArray(subshape)) {
+              return xla::Unimplemented(
+                  "CPU backend does not support HLO instruction %s with shape "
+                  "containing a sparse layout: %s",
+                  instruction->ToString().c_str(),
+                  ShapeUtil::HumanStringWithLayout(instruction->shape())
+                      .c_str());
+            }
+            return Status::OK();
+          }));
+    }
+  }
+  return false;
+}
+
+}  // namespace xla
tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h (new file)
@@ -0,0 +1,42 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// This pass should run early in the HLO pipeline and checks for HLO constructs
+// which are not supported by the CPU backend and cannot be removed via HLO
+// transformations (eg, sparse layouts).
+class CpuHloSupportChecker : public HloPassInterface {
+ public:
+  CpuHloSupportChecker() = default;
+  ~CpuHloSupportChecker() override = default;
+
+  tensorflow::StringPiece name() const override {
+    return "cpu_hlo_support_checker";
+  }
+
+  // Note: always returns false (no instructions are ever modified by this
+  // pass).
+  StatusOr<bool> Run(HloModule* module) override;
+};
+
+}  // namespace xla
+
+#endif  // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_
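A sketch of the checker run standalone (my illustration; `module` is assumed to be a std::unique_ptr<HloModule>):

    CpuHloSupportChecker checker;
    StatusOr<bool> changed = checker.Run(module.get());
    // With only dense layouts: changed.ValueOrDie() == false (never mutates).
    // With any sparse-layout subshape: changed.status() is UNIMPLEMENTED.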
tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc (new file)
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+using ::testing::HasSubstr;
+
+class CpuHloSupportCheckerTest : public HloTestBase {
+ protected:
+  CpuHloSupportChecker& checker() { return checker_; }
+
+ private:
+  CpuHloSupportChecker checker_;
+};
+
+TEST_F(CpuHloSupportCheckerTest, Add) {
+  HloComputation::Builder builder(TestName());
+  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, scalar_shape, "param0"));
+  HloInstruction* param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, scalar_shape, "param1"));
+  builder.AddInstruction(HloInstruction::CreateBinary(
+      scalar_shape, HloOpcode::kAdd, param0, param1));
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
+
+  TF_ASSERT_OK(checker().Run(module.get()).status());
+}
+
+TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
+  HloComputation::Builder builder(TestName());
+  const Shape sparse_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {10}, 2);
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, sparse_shape, "param0"));
+  HloInstruction* param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, sparse_shape, "param1"));
+  builder.AddInstruction(HloInstruction::CreateBinary(
+      sparse_shape, HloOpcode::kAdd, param0, param1));
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
+
+  Status status = checker().Run(module.get()).status();
+  ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
+  EXPECT_THAT(status.error_message(),
+              HasSubstr("CPU backend does not support"));
+  EXPECT_THAT(status.error_message(),
+              HasSubstr(ShapeUtil::HumanStringWithLayout(sparse_shape)));
+}
+
+}  // namespace
+}  // namespace xla
@@ -437,6 +437,7 @@ cc_library(
        ":fusion_merger",
        ":gpu_copy_insertion",
        ":gpu_executable",
+        ":gpu_hlo_support_checker",
        ":gpu_layout_assignment",
        ":hlo_schedule",
        ":instruction_fusion",
@@ -610,6 +611,32 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "gpu_hlo_support_checker",
+    srcs = ["gpu_hlo_support_checker.cc"],
+    hdrs = ["gpu_hlo_support_checker.h"],
+    deps = [
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service:hlo_pass",
+        "//tensorflow/core:lib",
+    ],
+)
+
+tf_cc_test(
+    name = "gpu_hlo_support_checker_test",
+    srcs = ["gpu_hlo_support_checker_test.cc"],
+    deps = [
+        ":gpu_hlo_support_checker",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+    ],
+)
+
 # -----------------------------------------------------------------------------
 
 filegroup(
@@ -25,6 +25,7 @@ limitations under the License.
 #include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/Verifier.h"
 #include "tensorflow/compiler/xla/protobuf_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
@@ -39,6 +40,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
@@ -137,6 +139,7 @@ tensorflow::Status OptimizeHloModule(
   {
     HloPassPipeline pipeline("optimization");
     pipeline.AddInvariantChecker<HloVerifier>(shape_size_function);
+    pipeline.AddPass<GpuHloSupportChecker>();
     ReducePrecisionInsertion::AddPasses(
         &pipeline, hlo_module->config().debug_options(),
         ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);
@@ -476,6 +479,20 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
         entry_computation->root_instruction()->Accept(&ir_emitter));
   }
 
+  {
+    XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier");
+
+    std::string err;
+    llvm::raw_string_ostream err_stream(err);
+
+    // verifyModule() returns true if the module is broken.
+    TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream))
+        << "Invalid LLVM IR before optimizations:\n"
+        << err_stream.str()
+        << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
+           "Rerun with --xla_dump_ir_to to get the IR. ";
+  }
+
   if (user_pre_optimization_hook_) {
     TF_CHECK_OK(user_pre_optimization_hook_(llvm_module));
   }
tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc (new file)
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+
+StatusOr<bool> GpuHloSupportChecker::Run(HloModule* module) {
+  for (auto* computation : module->computations()) {
+    for (const auto& instruction : computation->instructions()) {
+      TF_RETURN_IF_ERROR(
+          ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));
+      TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+          instruction->shape(),
+          [&instruction](const Shape& subshape, const ShapeIndex&) {
+            if (LayoutUtil::IsSparseArray(subshape)) {
+              return xla::Unimplemented(
+                  "GPU backend does not support HLO instruction %s with shape "
+                  "containing a sparse layout: %s",
+                  instruction->ToString().c_str(),
+                  ShapeUtil::HumanStringWithLayout(instruction->shape())
+                      .c_str());
+            }
+            return Status::OK();
+          }));
+    }
+  }
+  return false;
+}
+
+}  // namespace xla
tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h (new file)
@@ -0,0 +1,42 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// This pass should run early in the HLO pipeline and checks for HLO constructs
+// which are not supported by the GPU backend and cannot be removed via HLO
+// transformations (eg, sparse layouts).
+class GpuHloSupportChecker : public HloPassInterface {
+ public:
+  GpuHloSupportChecker() = default;
+  ~GpuHloSupportChecker() override = default;
+
+  tensorflow::StringPiece name() const override {
+    return "gpu_hlo_support_checker";
+  }
+
+  // Note: always returns false (no instructions are ever modified by this
+  // pass).
+  StatusOr<bool> Run(HloModule* module) override;
+};
+
+}  // namespace xla
+
+#endif  // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_
@ -0,0 +1,72 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
class GpuHloSupportCheckerTest : public HloTestBase {
|
||||
protected:
|
||||
GpuHloSupportChecker& checker() { return checker_; }
|
||||
|
||||
private:
|
||||
GpuHloSupportChecker checker_;
|
||||
};
|
||||
|
||||
TEST_F(GpuHloSupportCheckerTest, Add) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, scalar_shape, "param0"));
|
||||
HloInstruction* param1 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, scalar_shape, "param1"));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
scalar_shape, HloOpcode::kAdd, param0, param1));
|
||||
auto module = CreateNewModule();
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
TF_ASSERT_OK(checker().Run(module.get()).status());
|
||||
}
|
||||
|
||||
TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
const Shape sparse_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {10}, 2);
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, sparse_shape, "param0"));
|
||||
HloInstruction* param1 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, sparse_shape, "param1"));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
sparse_shape, HloOpcode::kAdd, param0, param1));
|
||||
auto module = CreateNewModule();
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
Status status = checker().Run(module.get()).status();
|
||||
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
|
||||
EXPECT_THAT(status.error_message(),
|
||||
HasSubstr("GPU backend does not support"));
|
||||
EXPECT_THAT(status.error_message(),
|
||||
HasSubstr(ShapeUtil::HumanStringWithLayout(sparse_shape)));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
61 tensorflow/compiler/xla/service/tuple_util.cc Normal file
@ -0,0 +1,61 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/tuple_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/core/lib/gtl/array_slice.h"

namespace xla {

/*static*/ HloInstruction* TupleUtil::ExtractPrefix(HloInstruction* input_tuple,
                                                    int64 elements) {
  CHECK(ShapeUtil::IsTuple(input_tuple->shape()));

  HloComputation* computation = input_tuple->parent();
  const Shape& input_shape = input_tuple->shape();

  std::vector<HloInstruction*> tuple_elements;
  tuple_elements.reserve(elements);
  for (int i = 0; i < elements; i++) {
    tuple_elements.push_back(
        computation->AddInstruction(HloInstruction::CreateGetTupleElement(
            input_shape.tuple_shapes(i), input_tuple, i)));
  }

  return computation->AddInstruction(
      HloInstruction::CreateTuple(tuple_elements));
}

/*static*/ HloInstruction* TupleUtil::AppendSuffix(
    HloInstruction* input_tuple,
    tensorflow::gtl::ArraySlice<HloInstruction*> trailing_values) {
  CHECK(ShapeUtil::IsTuple(input_tuple->shape()));

  HloComputation* computation = input_tuple->parent();
  const Shape& input_shape = input_tuple->shape();
  std::vector<HloInstruction*> tuple_elements;
  tuple_elements.reserve(input_shape.tuple_shapes_size());
  for (int i = 0; i < input_shape.tuple_shapes_size(); i++) {
    tuple_elements.push_back(
        computation->AddInstruction(HloInstruction::CreateGetTupleElement(
            input_shape.tuple_shapes(i), input_tuple, i)));
  }
  tuple_elements.insert(tuple_elements.end(), trailing_values.begin(),
                        trailing_values.end());
  return computation->AddInstruction(
      HloInstruction::CreateTuple(tuple_elements));
}

} // namespace xla
|
45 tensorflow/compiler/xla/service/tuple_util.h Normal file
@ -0,0 +1,45 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_UTIL_H_

#include "tensorflow/compiler/xla/service/hlo_instruction.h"

namespace xla {
class TupleUtil {
 public:
  // Generates HLO instructions to get a prefix tuple from `input_tuple` (which
  // must be of tuple shape) of length `elements`. Returns the root of the
  // graph of instructions generated.
  //
  // The instructions are generated into the computation containing
  // `input_tuple`.
  static HloInstruction* ExtractPrefix(HloInstruction* input_tuple,
                                       int64 elements);

  // Generates HLO instructions to create a tuple that consists of the values
  // in `trailing_values` appended to `input_tuple` (which must be of tuple
  // shape). Returns the root of the graph of instructions generated.
  //
  // The instructions are generated into the computation containing
  // `input_tuple`.
  static HloInstruction* AppendSuffix(
      HloInstruction* input_tuple,
      tensorflow::gtl::ArraySlice<HloInstruction*> trailing_values);
};
} // namespace xla

#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_UTIL_H_
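
// Example usage (a minimal sketch; `tuple` and `extra` are hypothetical
// HloInstruction pointers, with `tuple` of shape (f32[], f32[], f32[]) and
// both defined in the same computation):
//
//   // (f32[], f32[]) -- GTEs of elements 0 and 1, re-tupled.
//   HloInstruction* prefix = TupleUtil::ExtractPrefix(tuple, /*elements=*/2);
//
//   // (f32[], f32[], f32[], f32[]) -- the original elements plus `extra`.
//   HloInstruction* widened = TupleUtil::AppendSuffix(tuple, {extra});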
|
81 tensorflow/compiler/xla/service/tuple_util_test.cc Normal file
@ -0,0 +1,81 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/tuple_util.h"

#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"

namespace xla {
namespace {

namespace op = ::xla::testing::opcode_matchers;

StatusOr<std::unique_ptr<HloModule>> GetParsedModule(
    HloComputation** entry_computation, HloInstruction** param0,
    HloInstruction** param1) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY entry {
  p0 = (f32[32,32]{1,0},f32[32,32]{1,0},f32[32,32]{1,0}) parameter(0)
  ROOT p1 = f32[32,32]{1,0} parameter(1)
}
)";

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                      tools::Parse(hlo_string));

  *entry_computation = module->entry_computation();
  *param0 = (*entry_computation)->parameter_instruction(0);
  *param1 = (*entry_computation)->parameter_instruction(1);

  return std::move(module);
}

TEST(TupleUtilTest, ExtractPrefix) {
  HloInstruction *param0, *param1;
  HloComputation* entry_computation;

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      GetParsedModule(&entry_computation, &param0, &param1));

  HloInstruction* prefix = TupleUtil::ExtractPrefix(param0, 2);

  EXPECT_THAT(prefix, op::Tuple(op::GetTupleElement(op::Parameter(0), 0),
                                op::GetTupleElement(op::Parameter(0), 1)));
}

TEST(TupleUtilTest, AppendSuffix) {
  HloInstruction *param0, *param1;
  HloComputation* entry_computation;

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      GetParsedModule(&entry_computation, &param0, &param1));

  HloInstruction* with_suffix =
      TupleUtil::AppendSuffix(param0, {param1, param1});

  EXPECT_THAT(with_suffix, op::Tuple(op::GetTupleElement(op::Parameter(0), 0),
                                     op::GetTupleElement(op::Parameter(0), 1),
                                     op::GetTupleElement(op::Parameter(0), 2),
                                     op::Parameter(1), op::Parameter(1)));
}

} // namespace
} // namespace xla
|
@ -0,0 +1,296 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

namespace xla {

using tensorflow::gtl::FlatMap;
using tensorflow::gtl::FlatSet;
using tensorflow::gtl::InlinedVector;

// Copies `to_hoist` to the computation containing `while_instr`, hoisting its
// operands as needed. All of its transitive operands are expected to be either
// in `hoisted_instructions` or `unhoisted_invariant_instructions`. This
// function hoists the operands in `unhoisted_invariant_instructions` and moves
// them into `hoisted_instructions`.
static void CreateLoopInvariantCopy(
    FlatMap<HloInstruction*, HloInstruction*>* hoisted_instructions,
    FlatSet<HloInstruction*>* unhoisted_invariant_instructions,
    HloInstruction* while_instr, HloInstruction* to_hoist) {
  HloComputation* parent_of_while = while_instr->parent();
  HloComputation* while_body = while_instr->while_body();

  struct DFSFrame {
    HloInstruction* instruction;
    int64 operand_index;
  };

  InlinedVector<DFSFrame, 8> dfs_stack;
  dfs_stack.push_back({to_hoist, 0});

  HloInstruction* while_body_param = while_body->parameter_instruction(0);
  HloInstruction* while_operand = while_instr->mutable_operand(0);

  do {
    DFSFrame* frame = &dfs_stack.back();
    if (frame->operand_index == frame->instruction->operand_count()) {
      HloInstruction* old_instruction = frame->instruction;

      // All of the operands for old_instruction have been cloned, so it is
      // time to clone old_instruction itself.

      auto get_new_operand = [&](HloInstruction* old_operand) {
        return old_operand == while_body_param
                   ? while_operand
                   : FindOrDie(*hoisted_instructions, old_operand);
      };

      InlinedVector<HloInstruction*, 4> new_operands;
      c_transform(old_instruction->operands(), std::back_inserter(new_operands),
                  get_new_operand);

      HloInstruction* new_instruction =
          parent_of_while->AddInstruction(old_instruction->CloneWithNewOperands(
              old_instruction->shape(), new_operands));

      InsertOrDie(hoisted_instructions, old_instruction, new_instruction);

      // Approximately half of the instructions that would normally be present
      // in unhoisted_invariant_instructions are constants. We save a bit of
      // compile time by not putting these in the hashtable.
      CHECK_EQ(unhoisted_invariant_instructions->erase(old_instruction),
               to_hoist != old_instruction &&
                   old_instruction->opcode() != HloOpcode::kConstant);
      dfs_stack.pop_back();
      continue;
    }

    HloInstruction* next_operand =
        frame->instruction->mutable_operand(frame->operand_index++);
    if (hoisted_instructions->count(next_operand) ||
        next_operand == while_body_param) {
      continue;
    }

    dfs_stack.push_back({next_operand, 0});
  } while (!dfs_stack.empty());
}

// Returns true if `instruction` is worth hoisting only if it lets us hoist some
// instruction using it. The rationale is that hoisting these instructions will
// prevent simplification and fusion in the while body.
static bool NotWorthHoistingIndividually(const HloInstruction& instruction) {
  switch (instruction.opcode()) {
    default:
      return false;

    case HloOpcode::kBitcast:
    case HloOpcode::kBroadcast:
    case HloOpcode::kConstant:
    case HloOpcode::kReverse:
    case HloOpcode::kSlice:
    case HloOpcode::kTuple:
      return true;

    case HloOpcode::kTranspose:
      return ShapeUtil::TransposeIsBitcast(
          /*input_shape=*/instruction.operand(0)->shape(),
          /*output_shape=*/instruction.shape(), instruction.dimensions());

    case HloOpcode::kReshape:
      return ShapeUtil::ReshapeIsBitcast(
          /*input_shape=*/instruction.operand(0)->shape(),
          /*output_shape=*/instruction.shape());
  }
}

// Populates `gte_set` with the GetTupleElement instructions in `while_body`
// that access elements in the parameter tuple that don't change across
// iterations. Assumes `while_body` is the body computation of the while loop
// in question.
static void GatherInvariantGTEs(HloComputation* while_body,
                                FlatSet<HloInstruction*>* gte_set) {
  const HloInstruction::InstructionVector root_operands =
      while_body->root_instruction()->operands();
  for (int i = 0; i < root_operands.size(); i++) {
    HloInstruction* instr = root_operands[i];
    if (instr->opcode() == HloOpcode::kGetTupleElement &&
        instr->tuple_index() == i &&
        instr->operand(0) == while_body->parameter_instruction(0) &&
        ShapeUtil::IsArray(instr->shape())) {
      InsertOrDie(gte_set, instr);
    }
  }
}

static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
    HloInstruction* while_instr) {
  auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false);

  if (!ShapeUtil::IsTuple(while_instr->shape())) {
    // This restriction leaves one interesting pattern on the table:
    //
    //  while_body(f32[1024, 1024] %param) {
    //    %value = expensive_op(%param)
    //    outfeed(%value)
    //    ROOT = %param
    //  }
    //
    // If we see that pattern in the while, instead of generalizing this
    // algorithm to work with non-tuples, we should instead add a pass that
    // canonicalizes while loops like the above to use a tuple state.
    return false;
  }

  string while_instr_name = while_instr->ToString(print_no_metadata);
  VLOG(2) << "Trying to hoist from " << while_instr_name;

  HloComputation* while_body = while_instr->while_body();

  // Maps instructions in the while body to instructions hoisted outside the
  // while that compute the same value.
  FlatMap<HloInstruction*, HloInstruction*> hoisted_instructions;

  // Contains instructions that can be legally hoisted, but were deemed to be
  // unprofitable to be hoisted alone by NotWorthHoistingIndividually. When we
  // hoist an instruction in this set, we move it from
  // unhoisted_invariant_instructions to hoisted_instructions.
  FlatSet<HloInstruction*> unhoisted_invariant_instructions;

  // Invariant GTEs axiomatically satisfy the constraints for
  // unhoisted_invariant_instructions -- they can be legally hoisted, but there
  // is no benefit to hoisting them unless something that uses them is also
  // hoisted.
  GatherInvariantGTEs(while_body, &unhoisted_invariant_instructions);

  if (unhoisted_invariant_instructions.empty()) {
    // There are no obviously loop invariant elements in the state being
    // threaded through the while loop so give up. In theory this precondition
    // is too strong -- we could have code that e.g. permutes the elements in
    // the while state but uses a select to pick the same value on every
    // iteration.
    return false;
  }

  // instructions_to_replace[i] is hoisted into a loop invariant instruction
  // replacement_instructions[i].
  std::vector<HloInstruction*> instructions_to_replace;
  std::vector<HloInstruction*> replacement_instructions;

  for (auto* instruction : while_body->MakeInstructionPostOrder()) {
    if (instruction->HasSideEffect() ||
        instruction->opcode() == HloOpcode::kParameter ||
        !instruction->control_predecessors().empty() ||
        !instruction->control_successors().empty()) {
      continue;
    }

    auto is_invariant = [&](HloInstruction* op) {
      return hoisted_instructions.find(op) != hoisted_instructions.end() ||
             unhoisted_invariant_instructions.count(op) ||
             op->opcode() == HloOpcode::kConstant;
    };

    if (!c_all_of(instruction->operands(), is_invariant)) {
      continue;
    }

    if (NotWorthHoistingIndividually(*instruction)) {
      VLOG(2) << "Adding " << instruction->ToString(print_no_metadata)
              << " to unhoisted invariant set.";
      // Approximately half of the instructions that reach this point are
      // constants. We save a bit of compile time by not putting these in the
      // hashtable.
      if (instruction->opcode() != HloOpcode::kConstant) {
        InsertOrDie(&unhoisted_invariant_instructions, instruction);
      }
      continue;
    }

    VLOG(2) << "Hoisting " << instruction->ToString(print_no_metadata);

    CreateLoopInvariantCopy(&hoisted_instructions,
                            &unhoisted_invariant_instructions, while_instr,
                            instruction);

    instructions_to_replace.push_back(instruction);
    replacement_instructions.push_back(
        FindOrDie(hoisted_instructions, instruction));
  }

  if (instructions_to_replace.empty()) {
    return false;
  }

  TF_ASSIGN_OR_RETURN(
      WhileUtil::MakeInstructionsLiveInResult live_in_instructions_result,
      WhileUtil::MakeInstructionsLiveIn(while_instr, replacement_instructions));

  HloComputation* new_while_body =
      live_in_instructions_result.new_while_instr->while_body();

  for (int i = 0; i < instructions_to_replace.size(); i++) {
    HloInstruction* instruction_to_replace_in_new_while =
        FindOrDie(live_in_instructions_result.while_body_instruction_map,
                  instructions_to_replace[i]);
    TF_RETURN_IF_ERROR(new_while_body->ReplaceInstruction(
        instruction_to_replace_in_new_while,
        live_in_instructions_result.while_body_live_in_values[i]));
  }

  VLOG(1) << "Hoisted " << instructions_to_replace.size()
          << " instructions from " << while_instr_name;

  return true;
}

StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) {
  bool changed = false;
  std::vector<HloInstruction*> while_instrs;
  for (auto* comp : module->computations()) {
    c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
              [](const HloInstruction* instr) {
                return instr->opcode() == HloOpcode::kWhile;
              });
  }

  for (HloInstruction* while_instr : while_instrs) {
    // Right now we only hoist computations from the while body, but
    // TryHoistingInvariantInstructionsFromWhileBody can be generalized to
    // optimize the condition computation too, if needed.
    //
    // The transform we do here is a pessimization for while loops that execute
    // zero times*, but at this time we expect those to be rare. If this
    // becomes a problem we can consider using the conditional HLO to avoid
    // doing extra work for while loops with zero trip count.
    //
    // * We delete while loops that have a zero trip count, so this would have
    //   to be a while loop with a somewhat opaque condition expression.

    TF_ASSIGN_OR_RETURN(
        bool result,
        TryHoistingInvariantInstructionsFromWhileBody(while_instr));
    changed |= result;
  }
  return changed;
}
} // namespace xla
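
// Illustrative before/after for this pass (a sketch in HLO pseudo-syntax;
// instruction names are hypothetical):
//
//   // Before: %add is recomputed on every iteration even though both of its
//   // operands are invariant elements of the while state.
//   body { %p = (s32[], s32[], s32[]) parameter(0)
//          %gte.0 = get-tuple-element(%p), index=0
//          %gte.1 = get-tuple-element(%p), index=1
//          %add = add(%gte.0, %gte.1)
//          ROOT %t = tuple(%gte.0, %gte.1, %add) }
//
//   // After: the add is computed once in the enclosing computation and
//   // threaded through the widened while state via WhileUtil as a live-in
//   // value; the body reads it back with a get-tuple-element.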
|
@ -0,0 +1,39 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_INVARIANT_CODE_MOTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_INVARIANT_CODE_MOTION_H_

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// HLO pass that rewrites while loops to hoist loop invariant instructions in
// the while body into the computation that contains the while instruction.
class WhileLoopInvariantCodeMotion : public HloPassInterface {
 public:
  ~WhileLoopInvariantCodeMotion() override = default;

  tensorflow::StringPiece name() const override {
    return "while-loop-invariant-code-motion";
  }
  StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla

#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_INVARIANT_CODE_MOTION_H_
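
// Example usage (a minimal sketch, assuming the pass is added to a
// hypothetical HloPassPipeline named `pipeline` with `module` an HloModule*):
//
//   HloPassPipeline pipeline("while-loop-optimizations");
//   pipeline.AddPass<WhileLoopInvariantCodeMotion>();
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));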
|
@ -0,0 +1,442 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"

#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

namespace op = xla::testing::opcode_matchers;

class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase {
 public:
  // Makes a computation which has one parameter, of the given shape, and
  // always returns PRED[]{true}. This is useful as a dummy loop condition.
  HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape,
                                            HloModule* module);
};

static void FindOnlyWhileInstruction(HloComputation* computation,
                                     HloInstruction** while_instruction) {
  *while_instruction = nullptr;
  for (auto* instr : computation->instructions()) {
    if (instr->opcode() == HloOpcode::kWhile) {
      ASSERT_EQ(*while_instruction, nullptr);
      *while_instruction = instr;
    }
  }

  ASSERT_NE(*while_instruction, nullptr);
}

HloComputation* WhileLoopInvariantCodeMotionTest::MakeAlwaysTrueComputation(
    const Shape& param_shape, HloModule* module) {
  HloComputation::Builder builder(TestName() + ".always_true");
  builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param"));
  builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
  return module->AddEmbeddedComputation(builder.Build());
}

TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) {
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  Shape while_shape =
      ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});

  HloComputation* while_body = [&]() {
    HloComputation::Builder builder(TestName() + ".while_body");
    HloInstruction* param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloInstruction* gte_0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    HloInstruction* gte_1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    HloInstruction* add_result =
        builder.AddInstruction(HloInstruction::CreateBinary(
            scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
    builder.AddInstruction(
        HloInstruction::CreateTuple({gte_0, gte_1, add_result}));

    return module().AddEmbeddedComputation(builder.Build());
  }();

  HloComputation::Builder builder(TestName());
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateParameter(0, while_shape, "init_value"));
  builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
      while_body, init_value));
  HloComputation* entry_computation =
      module().AddEntryComputation(builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(&module()));
  EXPECT_TRUE(simplified_loop);

  HloInstruction* transformed_while;
  FindOnlyWhileInstruction(entry_computation, &transformed_while);

  EXPECT_THAT(entry_computation->instructions(), Contains(op::Add()));
  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::Add())));
}

TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) {
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  Shape while_shape =
      ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});

  HloComputation* while_body = [&]() {
    HloComputation::Builder builder(TestName() + ".while_body");
    HloInstruction* param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloInstruction* gte_0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    HloInstruction* gte_1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    HloInstruction* gte_2_loop_variant = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 2));

    HloInstruction* add_result =
        builder.AddInstruction(HloInstruction::CreateBinary(
            scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
    HloInstruction* mul_result =
        builder.AddInstruction(HloInstruction::CreateBinary(
            scalar_s32, HloOpcode::kMultiply, add_result, gte_1));
    HloInstruction* negate_result =
        builder.AddInstruction(HloInstruction::CreateUnary(
            scalar_s32, HloOpcode::kNegate, mul_result));
    HloInstruction* constant = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR0<int32>(4)));
    HloInstruction* sub_result =
        builder.AddInstruction(HloInstruction::CreateBinary(
            scalar_s32, HloOpcode::kSubtract, negate_result, constant));
    HloInstruction* divide_result =
        builder.AddInstruction(HloInstruction::CreateBinary(
            scalar_s32, HloOpcode::kDivide, sub_result, gte_2_loop_variant));
    builder.AddInstruction(
        HloInstruction::CreateTuple({gte_0, gte_1, divide_result}));

    return module().AddEmbeddedComputation(builder.Build());
  }();

  HloComputation::Builder builder(TestName());
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateParameter(0, while_shape, "init_value"));
  builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
      while_body, init_value));
  HloComputation* entry_computation =
      module().AddEntryComputation(builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(&module()));
  EXPECT_TRUE(simplified_loop);

  HloInstruction* transformed_while;
  FindOnlyWhileInstruction(entry_computation, &transformed_while);

  EXPECT_THAT(entry_computation->instructions(),
              AllOf(Contains(op::Add()), Contains(op::Multiply()),
                    Contains(op::Negate()), Contains(op::Subtract()),
                    Contains(op::Constant()),

                    // The division had a loop-varying operand, so it had
                    // better not be hoisted.
                    Not(Contains(op::Divide()))));

  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(AnyOf(op::Add(), op::Multiply(), op::Negate(),
                             op::Subtract(), op::Constant()))));

  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Contains(op::Divide()));
}

TEST_F(WhileLoopInvariantCodeMotionTest,
       DontHoistTriviallyLoopVaryingComputation) {
  // Basic negative test: the add expression is not loop invariant.
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});

  HloComputation* while_body = [&]() {
    HloComputation::Builder builder(TestName() + ".while_body");
    HloInstruction* param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloInstruction* gte_0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    HloInstruction* gte_1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    HloInstruction* add_result =
        builder.AddInstruction(HloInstruction::CreateBinary(
            scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
    builder.AddInstruction(HloInstruction::CreateTuple({gte_0, add_result}));

    return module().AddEmbeddedComputation(builder.Build());
  }();

  HloComputation::Builder builder(TestName());
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateParameter(0, while_shape, "init_value"));
  auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
      while_body, init_value));

  module().AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(&module()));
  EXPECT_FALSE(simplified_loop);

  EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add()));
}

TEST_F(WhileLoopInvariantCodeMotionTest,
       DontHoistLoopVaryingComputationWithAlternatingTuples) {
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  Shape while_shape =
      ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});

  HloComputation* while_body = [&]() {
    HloComputation::Builder builder(TestName() + ".while_body");
    HloInstruction* param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloInstruction* gte_0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    HloInstruction* gte_1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    HloInstruction* add_result =
        builder.AddInstruction(HloInstruction::CreateBinary(
            scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
    builder.AddInstruction(
        HloInstruction::CreateTuple({gte_1, gte_0, add_result}));

    return module().AddEmbeddedComputation(builder.Build());
  }();

  HloComputation::Builder builder(TestName());
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateParameter(0, while_shape, "init_value"));
  auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
      while_body, init_value));

  module().AddEntryComputation(builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(&module()));
  EXPECT_FALSE(simplified_loop);

  EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add()));
}

TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) {
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});

  HloComputation* while_body = [&]() {
    HloComputation::Builder builder(TestName() + ".while_body");
    HloInstruction* param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloInstruction* gte_0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    HloInstruction* gte_1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    builder.AddInstruction(
        HloInstruction::CreateOutfeed(scalar_s32, gte_0, ""));
    builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1}));

    return module().AddEmbeddedComputation(builder.Build());
  }();

  HloComputation::Builder builder(TestName());
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateParameter(0, while_shape, "init_value"));
  auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
      while_body, init_value));

  module().AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(&module()));
  EXPECT_FALSE(simplified_loop);

  EXPECT_THAT(while_inst->while_body()->instructions(),
              Contains(op::Outfeed()));
}

TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) {
  // The bitcast's user, an outfeed, can't be hoisted, so don't hoist the
  // bitcast either.
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  auto scalar_f32 = ShapeUtil::MakeShape(F32, {});
  Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});

  HloComputation* while_body = [&]() {
    HloComputation::Builder builder(TestName() + ".while_body");
    HloInstruction* param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloInstruction* gte_0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    HloInstruction* gte_1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    HloInstruction* bitcast_inst = builder.AddInstruction(
        HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0));
    builder.AddInstruction(
        HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, ""));
    builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1}));

    return module().AddEmbeddedComputation(builder.Build());
  }();

  HloComputation::Builder builder(TestName());
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateParameter(0, while_shape, "init_value"));
  auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
      while_body, init_value));

  module().AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(&module()));
  EXPECT_FALSE(simplified_loop);

  EXPECT_THAT(while_inst->while_body()->instructions(),
              Contains(op::Outfeed()));
  EXPECT_THAT(while_inst->while_body()->instructions(),
              Contains(op::Bitcast()));
}

TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) {
  // The bitcast's user can be hoisted, so hoist the bitcast too.
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  auto scalar_f32 = ShapeUtil::MakeShape(F32, {});
  Shape while_shape =
      ShapeUtil::MakeTupleShape({scalar_s32, scalar_f32, scalar_f32});

  HloComputation* while_body = [&]() {
    HloComputation::Builder builder(TestName() + ".while_body");
    HloInstruction* param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloInstruction* gte_0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    HloInstruction* gte_1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_f32, param, 1));
    HloInstruction* bitcast_inst = builder.AddInstruction(
        HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0));
    HloInstruction* add_inst =
        builder.AddInstruction(HloInstruction::CreateBinary(
            scalar_f32, HloOpcode::kAdd, bitcast_inst, gte_1));
    builder.AddInstruction(
        HloInstruction::CreateTuple({gte_0, gte_1, add_inst}));

    return module().AddEmbeddedComputation(builder.Build());
  }();

  HloComputation::Builder builder(TestName());
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateParameter(0, while_shape, "init_value"));
  builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
      while_body, init_value));

  HloComputation* entry_computation =
      module().AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(&module()));
  EXPECT_TRUE(simplified_loop);

  HloInstruction* transformed_while;
  FindOnlyWhileInstruction(entry_computation, &transformed_while);

  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::Add())));
  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::Bitcast())));
  EXPECT_THAT(entry_computation->instructions(), Contains(op::Add()));
  EXPECT_THAT(entry_computation->instructions(), Contains(op::Bitcast()));
}

TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistControlDependencies) {
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  Shape while_shape =
      ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});

  HloComputation* while_body;
  {
    HloComputation::Builder builder(TestName() + ".while_body");
    HloInstruction* param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloInstruction* gte_0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
    HloInstruction* gte_1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    HloInstruction* add_result =
        builder.AddInstruction(HloInstruction::CreateBinary(
            scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
    TF_ASSERT_OK(param->AddControlDependencyTo(add_result));
    builder.AddInstruction(
        HloInstruction::CreateTuple({gte_0, gte_1, add_result}));

    while_body = module().AddEmbeddedComputation(builder.Build());
  }

  HloComputation::Builder builder(TestName());
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateParameter(0, while_shape, "init_value"));
  builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
      while_body, init_value));
  module().AddEntryComputation(builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(&module()));
  EXPECT_FALSE(simplified_loop);
}

TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) {
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});

  HloComputation* while_body = [&]() {
    HloComputation::Builder builder(TestName() + ".passthrough");
    HloInstruction* param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, while_shape, "param"));
    HloComputation* result = module().AddEmbeddedComputation(builder.Build());

    result->AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
    return result;
  }();

  HloComputation::Builder builder(TestName());
  auto* init_value = builder.AddInstruction(
      HloInstruction::CreateParameter(0, while_shape, "init_value"));
  builder.AddInstruction(HloInstruction::CreateWhile(
      while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
      while_body, init_value));
  module().AddEntryComputation(builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopInvariantCodeMotion{}.Run(&module()));
  EXPECT_FALSE(simplified_loop);
}

} // namespace
} // namespace xla
|
@ -595,7 +595,9 @@ static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) {
    auto call_op = computation->AddInstruction(HloInstruction::CreateCall(
        while_op->shape(), while_op->operands(), while_op->while_body()));
    TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op));
    TF_RETURN_IF_ERROR(CallInliner::Inline(call_op));
    TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
                        CallInliner::Inline(call_op));
    (void)inlined_instructions_map;
    return true;
  }
  return false;
|
140 tensorflow/compiler/xla/service/while_util.cc Normal file
@ -0,0 +1,140 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"

namespace xla {

static StatusOr<HloComputation*> WidenWhileCondition(
    HloComputation* narrow_condition, const Shape& wide_shape) {
  const Shape& narrow_shape =
      narrow_condition->parameter_instruction(0)->shape();

  HloComputation* wide_while_cond = [&]() {
    HloComputation::Builder builder(
        tensorflow::strings::StrCat("wide.", narrow_condition->name()));
    builder.AddInstruction(
        HloInstruction::CreateParameter(0, wide_shape, "wide_param"));

    // This is needed so that the root instruction is shaped as a PRED[] -- we
    // need to get this right to begin with since we can't mutate the type of
    // the root instruction later. We later change the root instruction to
    // something more appropriate.
    builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    return narrow_condition->parent()->AddEmbeddedComputation(builder.Build());
  }();

  HloInstruction* truncated_parameter =
      TupleUtil::ExtractPrefix(wide_while_cond->parameter_instruction(0),
                               narrow_shape.tuple_shapes_size());
  HloInstruction* call_narrow_cond = wide_while_cond->AddInstruction(
      HloInstruction::CreateCall(ShapeUtil::MakeShape(PRED, {}),
                                 {truncated_parameter}, narrow_condition));

  wide_while_cond->set_root_instruction(call_narrow_cond);

  TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_cond).status());
  return wide_while_cond;
}

static StatusOr<std::pair<HloComputation*, CallInliner::InlinedInstructionMap>>
WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) {
  const Shape& narrow_shape = narrow_body->parameter_instruction(0)->shape();

  HloComputation* wide_while_body = [&]() {
    HloComputation::Builder builder(
        tensorflow::strings::StrCat("wide.", narrow_body->name()));
    builder.AddInstruction(
        HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
    return narrow_body->parent()->AddEmbeddedComputation(builder.Build());
  }();

  HloInstruction* wide_parameter = wide_while_body->parameter_instruction(0);
  HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
      wide_parameter, narrow_shape.tuple_shapes_size());
  HloInstruction* call_narrow_body =
      wide_while_body->AddInstruction(HloInstruction::CreateCall(
          narrow_shape, {truncated_parameter}, narrow_body));

  std::vector<HloInstruction*> live_through_values;
  for (int i = narrow_shape.tuple_shapes_size();
       i < wide_shape.tuple_shapes_size(); i++) {
    live_through_values.push_back(
        wide_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
            wide_shape.tuple_shapes(i), wide_parameter, i)));
  }

  wide_while_body->set_root_instruction(
      TupleUtil::AppendSuffix(call_narrow_body, live_through_values));

  TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
                      CallInliner::Inline(call_narrow_body));
  return {{wide_while_body, std::move(inlined_instructions_map)}};
}

/*static*/ StatusOr<WhileUtil::MakeInstructionsLiveInResult>
WhileUtil::MakeInstructionsLiveIn(
    HloInstruction* while_instr,
    tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
  CHECK(ShapeUtil::IsTuple(while_instr->shape()));

  int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size();
  Shape new_while_shape = while_instr->shape();
  for (auto* instruction : instructions) {
    *new_while_shape.add_tuple_shapes() = instruction->shape();
  }

  TF_ASSIGN_OR_RETURN(
      HloComputation * new_while_condition,
      WidenWhileCondition(while_instr->while_condition(), new_while_shape));

  HloComputation* new_while_body;
  CallInliner::InlinedInstructionMap inlined_instructions_map;
  TF_ASSIGN_OR_RETURN(
      std::tie(new_while_body, inlined_instructions_map),
      WidenWhileBody(while_instr->while_body(), new_while_shape));

  HloInstruction* new_while_init =
      TupleUtil::AppendSuffix(while_instr->mutable_operand(0), instructions);
  HloComputation* containing_computation = while_instr->parent();
  HloInstruction* new_while = containing_computation->AddInstruction(
      HloInstruction::CreateWhile(new_while_shape, new_while_condition,
                                  new_while_body, new_while_init));
  TF_RETURN_IF_ERROR(containing_computation->ReplaceInstruction(
      while_instr, TupleUtil::ExtractPrefix(
                       new_while, while_instr->shape().tuple_shapes_size())));

  HloInstruction* while_body_param = new_while_body->parameter_instruction(0);
  std::vector<HloInstruction*> live_in_instructions;
  for (int64 i = elements_in_old_while_shape;
       i < new_while_shape.tuple_shapes_size(); i++) {
    live_in_instructions.push_back(
        new_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
            instructions[i - elements_in_old_while_shape]->shape(),
            while_body_param, i)));
  }

  WhileUtil::MakeInstructionsLiveInResult result;

  result.new_while_instr = new_while;
  result.while_body_live_in_values = std::move(live_in_instructions);
  result.while_body_instruction_map = std::move(inlined_instructions_map);

  return std::move(result);
}
} // namespace xla
|
58 tensorflow/compiler/xla/service/while_util.h Normal file
@ -0,0 +1,58 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_

#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"

namespace xla {
class WhileUtil {
 public:
  // Holds a return value from MakeInstructionsLiveIn.
  struct MakeInstructionsLiveInResult {
    // The new while operation that has the requested values live in.
    HloInstruction* new_while_instr;

    // The i'th element of `while_body_live_in_values` is an instruction in the
    // while body that holds the i'th *newly added* live in value at runtime.
    std::vector<HloInstruction*> while_body_live_in_values;

    // `while_body_instruction_map` maps instructions in the original while
    // body to the corresponding instructions in the body for the newly created
    // while operation.
    CallInliner::InlinedInstructionMap while_body_instruction_map;
  };

  // Replaces `while_instr` with a new while instruction that is equivalent to
  // `while_instr`, except that it has all of the HLO instructions in
  // `instructions` as live-in, loop invariant values. These new live in values
  // are represented as new elements appended to the parameter of the while
  // loop, which must be of tuple shape. The GetTupleElement instructions
  // computing each new live-in value are returned in the
  // `while_body_live_in_values` vector.
  //
  // Precondition: `while_instr` must have a tuple shaped state.
  //
  // Every instruction in `instructions` must be contained in the computation
  // that contains `while_instr`.
  static StatusOr<MakeInstructionsLiveInResult> MakeInstructionsLiveIn(
      HloInstruction* while_instr,
      tensorflow::gtl::ArraySlice<HloInstruction*> instructions);
};
} // namespace xla

#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_
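
// Example usage (a minimal sketch; `while_instr` is a kWhile instruction with
// tuple-shaped state and `invariant` is an instruction defined in the same
// computation -- both names are hypothetical):
//
//   TF_ASSIGN_OR_RETURN(
//       WhileUtil::MakeInstructionsLiveInResult result,
//       WhileUtil::MakeInstructionsLiveIn(while_instr, {invariant}));
//   // result.while_body_live_in_values[0] yields `invariant`'s value inside
//   // the new while body on every iteration; the original while_instr has
//   // been replaced, so use result.new_while_instr from here on.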
|
130 tensorflow/compiler/xla/service/while_util_test.cc Normal file
@ -0,0 +1,130 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/while_util.h"

#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"

namespace xla {
namespace {

namespace op = ::xla::testing::opcode_matchers;

StatusOr<std::unique_ptr<HloModule>> GetParsedModule(
    HloComputation** entry_computation, HloInstruction** param0,
    HloInstruction** param1, HloInstruction** param2) {
  const char* const hlo_string = R"(
HloModule ModuleWithWhile

while_body {
  ROOT p_body = (f32[32,32]{1,0}, f32[32,32]{1,0}) parameter(0)
}

while_condition {
  p_cond = (f32[32,32]{1,0}, f32[32,32]{1,0}) parameter(0)
  ROOT result = pred[] constant(true)
}

ENTRY entry {
  p_entry_0 = f32[32,32]{1,0} parameter(0)
  p_entry_1 = s32[32,32]{1,0} parameter(1)
  p_entry_2 = s64[32,32]{1,0} parameter(2)
  while_init = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(p_entry_0, p_entry_0)
  ROOT while = (f32[32,32]{1,0}, f32[32,32]{1,0}) while(while_init), condition=while_condition, body=while_body
}
)";

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                      tools::Parse(hlo_string));

  *entry_computation = module->entry_computation();
  *param0 = (*entry_computation)->parameter_instruction(0);
  *param1 = (*entry_computation)->parameter_instruction(1);
  *param2 = (*entry_computation)->parameter_instruction(2);

  return std::move(module);
}

TEST(WhileUtilTest, MakeZeroInstructionsLiveOp) {
  HloInstruction *param0, *param1, *param2;
  HloComputation* entry_computation;

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      GetParsedModule(&entry_computation, &param0, &param1, &param2));

  HloInstruction* while_instr = entry_computation->root_instruction();
  ASSERT_EQ(while_instr->opcode(), HloOpcode::kWhile);

  TF_ASSERT_OK_AND_ASSIGN(
      WhileUtil::MakeInstructionsLiveInResult make_live_in_result,
      WhileUtil::MakeInstructionsLiveIn(while_instr, /*instructions=*/{}));

  HloInstruction* new_while_instr = make_live_in_result.new_while_instr;

  EXPECT_THAT(
      entry_computation->root_instruction(),
      op::Tuple(op::GetTupleElement(::testing::Eq(new_while_instr), 0),
                op::GetTupleElement(::testing::Eq(new_while_instr), 1)));

  auto param_reconstructed =
      op::Tuple(op::GetTupleElement(op::Parameter(0), 0),
                op::GetTupleElement(op::Parameter(0), 1));

  EXPECT_THAT(new_while_instr->while_body()->root_instruction(),
              op::Tuple(op::GetTupleElement(param_reconstructed, 0),
                        op::GetTupleElement(param_reconstructed, 1)));
}

TEST(WhileUtilTest, MakeTwoInstructionsLive) {
  HloInstruction *param0, *param1, *param2;
  HloComputation* entry_computation;

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      GetParsedModule(&entry_computation, &param0, &param1, &param2));

  HloInstruction* while_instr = entry_computation->root_instruction();
  ASSERT_EQ(while_instr->opcode(), HloOpcode::kWhile);

  TF_ASSERT_OK_AND_ASSIGN(
      WhileUtil::MakeInstructionsLiveInResult make_live_in_result,
      WhileUtil::MakeInstructionsLiveIn(while_instr,
                                        /*instructions=*/{param0, param1}));

  HloInstruction* new_while_instr = make_live_in_result.new_while_instr;

  XLA_VLOG_LINES(3, module->ToString());

  EXPECT_THAT(
      entry_computation->root_instruction(),
      op::Tuple(op::GetTupleElement(::testing::Eq(new_while_instr), 0),
                op::GetTupleElement(::testing::Eq(new_while_instr), 1)));

  auto first_half_param_reconstructed =
      op::Tuple(op::GetTupleElement(op::Parameter(0), 0),
                op::GetTupleElement(op::Parameter(0), 1));

  EXPECT_THAT(new_while_instr->while_body()->root_instruction(),
              op::Tuple(op::GetTupleElement(first_half_param_reconstructed, 0),
                        op::GetTupleElement(first_half_param_reconstructed, 1),
                        op::GetTupleElement(op::Parameter(0), 2),
                        op::GetTupleElement(op::Parameter(0), 3)));
}

} // namespace
} // namespace xla
|
@ -84,7 +84,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
|
||||
if (lhs.layout().format() != rhs.layout().format()) {
|
||||
return false;
|
||||
}
|
||||
if (LayoutUtil::IsDense(lhs)) {
|
||||
if (LayoutUtil::IsDenseArray(lhs)) {
|
||||
if (!ContainersEqual(LayoutUtil::MinorToMajor(lhs),
|
||||
LayoutUtil::MinorToMajor(rhs))) {
|
||||
VLOG(3) << "CompareShapes: lhs layout != rhs layout";
|
||||
@ -202,6 +202,17 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
||||
return MakeShapeWithLayout(element_type, dimensions, layout);
|
||||
}
|
||||
|
||||
/* static */ Shape ShapeUtil::MakeShapeWithSparseLayout(
|
||||
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
|
||||
int64 max_sparse_elements) {
|
||||
DCHECK_NE(TUPLE, element_type);
|
||||
DCHECK_NE(OPAQUE, element_type);
|
||||
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
|
||||
*shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
|
||||
TF_DCHECK_OK(ShapeUtil::ValidateShape(shape));
|
||||
return shape;
|
||||
}
|
||||
|
||||
/* static */ Shape
|
||||
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
const Shape& shape) {
|
||||
@ -249,7 +260,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
}
|
||||
|
||||
/* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
|
||||
CHECK(LayoutUtil::IsDense(*shape));
|
||||
CHECK(LayoutUtil::IsDenseArray(*shape));
|
||||
shape->mutable_layout()->add_minor_to_major(Rank(*shape));
|
||||
shape->add_dimensions(bound);
|
||||
TF_DCHECK_OK(ValidateShape(*shape));
|
||||
@ -658,23 +669,55 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
|
||||
TF_DCHECK_OK(ValidateShape(shape));
|
||||
DCHECK_NE(OPAQUE, shape.element_type());
|
||||
if (shape.element_type() == TUPLE) {
|
||||
CHECK_GT(pointer_size, 0);
|
||||
return pointer_size * shape.tuple_shapes_size();
|
||||
return ByteSizeOfTupleIndexTable(shape, pointer_size);
|
||||
}
|
||||
int64 byte_size = ByteSizeOfElements(shape);
|
||||
if (LayoutUtil::IsSparseArray(shape)) {
|
||||
byte_size += ByteSizeOfSparseIndices(shape);
|
||||
}
|
||||
return byte_size;
|
||||
}
|
||||
|
||||
/* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape,
|
||||
int64 pointer_size) {
|
||||
TF_DCHECK_OK(ValidateShape(shape));
|
||||
DCHECK_EQ(TUPLE, shape.element_type());
|
||||
CHECK_GT(pointer_size, 0);
|
||||
return pointer_size * shape.tuple_shapes_size();
|
||||
}
|
||||
|
||||
/* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) {
|
||||
TF_DCHECK_OK(ValidateShape(shape));
|
||||
DCHECK(ShapeUtil::IsArray(shape));
|
||||
int64 allocated_element_count;
|
||||
if (shape.layout().padded_dimensions_size() > 0) {
|
||||
CHECK_EQ(Rank(shape), shape.layout().padded_dimensions_size());
|
||||
allocated_element_count = 1;
|
||||
for (int64 dimension_size : shape.layout().padded_dimensions()) {
|
||||
allocated_element_count *= dimension_size;
|
||||
}
|
||||
|
||||
if (LayoutUtil::IsSparseArray(shape)) {
|
||||
allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout());
|
||||
} else {
|
||||
allocated_element_count = ElementsIn(shape);
|
||||
CHECK(LayoutUtil::IsDenseArray(shape));
|
||||
tensorflow::gtl::ArraySlice<int64> padded_dimensions =
|
||||
LayoutUtil::PaddedDimensions(shape);
|
||||
if (!padded_dimensions.empty()) {
|
||||
CHECK_EQ(Rank(shape), padded_dimensions.size());
|
||||
allocated_element_count = 1;
|
||||
for (int64 dimension_size : padded_dimensions) {
|
||||
allocated_element_count *= dimension_size;
|
||||
}
|
||||
} else {
|
||||
allocated_element_count = ElementsIn(shape);
|
||||
}
|
||||
}
|
||||
return allocated_element_count *
|
||||
ByteSizeOfPrimitiveType(shape.element_type());
|
||||
}
|
||||
|
||||
/* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) {
|
||||
TF_DCHECK_OK(ValidateShape(shape));
|
||||
DCHECK(LayoutUtil::IsSparseArray(shape));
|
||||
return LayoutUtil::MaxSparseElements(shape.layout()) *
|
||||
ShapeUtil::Rank(shape) * sizeof(int64);
|
||||
}
|
||||
|
||||
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
|
||||
const Shape& shape) {
|
||||
if (shape.element_type() == TUPLE) {
|
||||
@ -900,7 +943,7 @@ Status ForEachMutableSubshapeHelper(
|
||||
new_shape.add_dimensions(dim);
|
||||
}
|
||||
if (shape.has_layout()) {
|
||||
CHECK(LayoutUtil::IsDense(shape));
|
||||
CHECK(LayoutUtil::IsDenseArray(shape));
|
||||
Layout* new_layout = new_shape.mutable_layout();
|
||||
new_layout->set_format(DENSE);
|
||||
new_layout->clear_minor_to_major();
|
||||
|
@ -143,7 +143,10 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index);
|
||||
class ShapeUtil {
|
||||
public:
|
||||
// Returns the number of elements contained within the provided shape;
|
||||
// e.g. for rank 0 (scalars) the result is always 1.
|
||||
// e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes
|
||||
// may not actually be able to store this number of elements. See
|
||||
// LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of
|
||||
// elements that can be stored in a sparse shape.
|
||||
// Precondition: !IsTuple(shape)
|
||||
static int64 ElementsIn(const Shape& shape);
|
||||
|
||||
@ -164,6 +167,27 @@ class ShapeUtil {
|
||||
// Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)
|
||||
static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type);
|
||||
|
||||
// Returns the number of bytes required to store the tuple member pointers for
|
||||
// an allocation of shape. The `shape` must be a TUPLE shape, and
|
||||
// `pointer_size` must be larger than zero.
|
||||
static int64 ByteSizeOfTupleIndexTable(const Shape& shape,
|
||||
int64 pointer_size);
|
||||
|
||||
// Returns the number of bytes required for the elements in an allocation of
|
||||
// `shape`, which must be an array shape. The return value does not include
|
||||
// the bytes needed to store sparse indices. Dense shapes use a separate
|
||||
// memory location for each element, and so for these shapes,
|
||||
// `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. For dense shapes, this
|
||||
// size also includes padding if present in the layout. For sparse shapes,
|
||||
// `ByteSizeOf(shape) == ByteSizeOfElements(shape) +
|
||||
// ByteSizeOfSparseIndices(shape)`.
|
||||
static int64 ByteSizeOfElements(const Shape& shape);
|
||||
|
||||
// Returns the number of bytes required for the sparse indices in an
|
||||
// allocation of shape. The shape must be an array shape. The return value
|
||||
// does not include the bytes needed to store the array elements.
|
||||
static int64 ByteSizeOfSparseIndices(const Shape& shape);
|
||||
|
||||
// Returns a human-readable string that represents the given shape, with or
|
||||
// without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]".
|
||||
static string HumanString(const Shape& shape);
|
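As a worked example of the sizing contract above (a sketch assuming the xla namespace; the shape and max_sparse_elements value are invented for illustration): a sparse f32 array of shape [10,10] with max_sparse_elements = 5 stores up to 5 values and 5 rank-2 int64 indices, so:

// Sketch: byte-size accounting for a sparse shape, per the comments above.
Shape s = ShapeUtil::MakeShapeWithSparseLayout(F32, {10, 10},
                                               /*max_sparse_elements=*/5);
int64 elements = ShapeUtil::ByteSizeOfElements(s);      // 5 values * 4B == 20
int64 indices = ShapeUtil::ByteSizeOfSparseIndices(s);  // 5 * rank 2 * 8B == 80
CHECK_EQ(ShapeUtil::ByteSizeOf(s), elements + indices);  // 100 bytes in total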
||||
@ -269,6 +293,10 @@ class ShapeUtil {
|
||||
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
|
||||
tensorflow::gtl::ArraySlice<int64> minor_to_major);
|
||||
|
||||
static Shape MakeShapeWithSparseLayout(
|
||||
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
|
||||
int64 max_sparse_elements);
|
||||
|
||||
// Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
|
||||
static Shape MakeShapeWithDescendingLayout(
|
||||
PrimitiveType element_type,
|
||||
|
110
tensorflow/compiler/xla/sparse_index_array.cc
Normal file
@ -0,0 +1,110 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/sparse_index_array.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/index_util.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
SparseIndexArray::SparseIndexArray() : rank_(0), max_indices_(0) {}
|
||||
|
||||
SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank,
|
||||
std::vector<int64> indices)
|
||||
: indices_(std::move(indices)), rank_(rank), max_indices_(max_indices) {
|
||||
CHECK_GT(rank_, 0);
|
||||
CHECK_EQ(indices_.size() % rank_, 0)
|
||||
<< "indices_.size(): " << indices_.size() << ", rank_: " << rank_;
|
||||
CHECK_LE(index_count(), max_indices_);
|
||||
}
|
||||
|
||||
SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank,
|
||||
tensorflow::gtl::ArraySlice<int64> indices)
|
||||
: SparseIndexArray(max_indices, rank,
|
||||
std::vector<int64>(indices.begin(), indices.end())) {}
|
||||
|
||||
SparseIndexArray::SparseIndexArray(int64 max_indices,
|
||||
const Array2D<int64>& indices)
|
||||
: SparseIndexArray(max_indices, indices.n2(),
|
||||
std::vector<int64>(indices.begin(), indices.end())) {}
|
||||
|
||||
int64 SparseIndexArray::index_count() const {
|
||||
CHECK_GT(rank_, 0);
|
||||
CHECK_EQ(indices_.size() % rank_, 0);
|
||||
return indices_.size() / rank_;
|
||||
}
|
||||
|
||||
tensorflow::gtl::ArraySlice<int64> SparseIndexArray::At(
|
||||
int64 sparse_index_number) const {
|
||||
CHECK_GT(rank_, 0);
|
||||
CHECK_GE(sparse_index_number, 0);
|
||||
CHECK_LE(rank_ * sparse_index_number + rank_, indices_.size());
|
||||
return tensorflow::gtl::ArraySlice<int64>(
|
||||
indices_.data() + rank_ * sparse_index_number, rank_);
|
||||
}
|
||||
|
||||
tensorflow::gtl::MutableArraySlice<int64> SparseIndexArray::At(
|
||||
int64 sparse_index_number) {
|
||||
CHECK_GT(rank_, 0);
|
||||
CHECK_GE(sparse_index_number, 0);
|
||||
CHECK_LE(rank_ * sparse_index_number + rank_, indices_.size());
|
||||
return tensorflow::gtl::MutableArraySlice<int64>(
|
||||
indices_.data() + rank_ * sparse_index_number, rank_);
|
||||
}
|
||||
|
||||
void SparseIndexArray::Append(tensorflow::gtl::ArraySlice<int64> index) {
|
||||
CHECK_GT(rank_, 0);
|
||||
CHECK_EQ(index.size(), rank_);
|
||||
indices_.insert(indices_.end(), index.begin(), index.end());
|
||||
}
|
||||
|
||||
void SparseIndexArray::Clear() { indices_.clear(); }
|
||||
|
||||
void SparseIndexArray::Resize(int64 num_indices) {
|
||||
CHECK_GT(rank_, 0);
|
||||
indices_.resize(rank_ * num_indices);
|
||||
}
|
||||
|
||||
bool SparseIndexArray::Validate(const Shape& shape) const {
|
||||
if (rank_ == 0 || rank_ != ShapeUtil::Rank(shape)) {
|
||||
return false;
|
||||
}
|
||||
int64 num_indices = index_count();
|
||||
if (num_indices > LayoutUtil::MaxSparseElements(shape.layout())) {
|
||||
return false;
|
||||
}
|
||||
if (num_indices < 2) {
|
||||
return true;
|
||||
}
|
||||
tensorflow::gtl::ArraySlice<int64> last = At(0);
|
||||
if (!IndexUtil::IndexInBounds(shape, last)) {
|
||||
return false;
|
||||
}
|
||||
for (int64 n = 1; n < num_indices; ++n) {
|
||||
tensorflow::gtl::ArraySlice<int64> next = At(n);
|
||||
if (!IndexUtil::IndexInBounds(shape, next)) {
|
||||
return false;
|
||||
}
|
||||
if (IndexUtil::CompareIndices(last, next) >= 0) {
|
||||
return false;
|
||||
}
|
||||
last = next;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace xla
|
176
tensorflow/compiler/xla/sparse_index_array.h
Normal file
@ -0,0 +1,176 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Utility class for managing sparse array indices.
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/array2d.h"
|
||||
#include "tensorflow/compiler/xla/index_util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Encapsulates the array of indices for a sparse array. A SparseIndexArray
|
||||
// contains indices for up to `max_indices` elements of a sparse array. Each
|
||||
// sparse index is an array of `rank` int64 values that gives the location of a
|
||||
// value within a sparse array. Note that the dimensions of the array are not
|
||||
// checked (except for the rank). To avoid confusion, we refer to the position
|
||||
// of an index within a SparseIndexArray as a sparse index number.
|
||||
class SparseIndexArray {
|
||||
public:
|
||||
SparseIndexArray();
|
||||
SparseIndexArray(const SparseIndexArray&) = default;
|
||||
SparseIndexArray(SparseIndexArray&&) = default;
|
||||
SparseIndexArray& operator=(const SparseIndexArray&) = default;
|
||||
SparseIndexArray& operator=(SparseIndexArray&&) = default;
|
||||
|
||||
// Constructs a SparseIndexArray that can hold up to `max_indices` sparse
|
||||
// indices, with an initial contents obtained from the given array. The rank
|
||||
// is taken from the minor dimension of the array. The major dimension of the
|
||||
// array must not exceed `max_indices`.
|
||||
SparseIndexArray(int64 max_indices, const Array2D<int64>& indices);
|
||||
|
||||
// Like above, but the array is flattened. For example, the following are
|
||||
// equivalent:
|
||||
//
|
||||
// SparseIndexArray(10,
|
||||
// Array2D<int64>{
|
||||
// {0, 1, 2},
|
||||
// {3, 4, 5},
|
||||
// {6, 7, 8},
|
||||
// {9, 10, 11},
|
||||
// })
|
||||
//
|
||||
// SparseIndexArray(10, 3,
|
||||
// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11})
|
||||
//
|
||||
SparseIndexArray(int64 max_indices, int64 rank,
|
||||
std::vector<int64> indices = {});
|
||||
SparseIndexArray(int64 max_indices, int64 rank,
|
||||
tensorflow::gtl::ArraySlice<int64> indices);
|
||||
|
||||
// Returns the number of elements represented by the indices stored in the
|
||||
// array.
|
||||
int64 index_count() const;
|
||||
|
||||
// Returns a slice that refers to the given sparse index number. The argument
|
||||
// must be in the range [0, index_count()).
|
||||
tensorflow::gtl::ArraySlice<int64> At(int64 sparse_index_number) const;
|
||||
tensorflow::gtl::MutableArraySlice<int64> At(int64 sparse_index_number);
|
||||
|
||||
// Adds the given index at the end of the array. The new size of the
|
||||
// SparseIndexArray must not exceed `max_indices`.
|
||||
void Append(tensorflow::gtl::ArraySlice<int64> index);
|
||||
|
||||
// Removes all indices from the array.
|
||||
void Clear();
|
||||
|
||||
// Resizes the array to contain the given number of sparse indices. The new
|
||||
// size must be smaller than `max_indices`. If the new size is larger than
|
||||
// the old size, the value of the new indices is not specified.
|
||||
void Resize(int64 num_indices);
|
||||
|
||||
// Returns true iff all indices are unique and occur in sorted order, and are
|
||||
// valid for the given shape.
|
||||
bool Validate(const Shape& shape) const;
|
||||
|
||||
int64 rank() const { return rank_; }
|
||||
int64 max_indices() const { return max_indices_; }
|
||||
|
||||
// Returns a pointer to the int64 array that holds the sparse indices.
|
||||
tensorflow::gtl::MutableArraySlice<int64> mutable_data() { return &indices_; }
|
||||
tensorflow::gtl::ArraySlice<int64> data() const { return indices_; }
|
||||
|
||||
// Sorts this sparse index array along with the set of corresponding values.
|
||||
// The indices and values are sorted in the lexicographic order of the
|
||||
// indices, from smallest to largest.
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// std::vector<float> v{10.0, 11.0, 12.0};
|
||||
// SparseIndexArray a(10, 3,
|
||||
// {3, 4, 5,
|
||||
// 1, 2, 3,
|
||||
// 2, 3, 4});
|
||||
// a.SortWithValues(&v);
|
||||
// // Prints "11.0, 12.0, 10.0":
|
||||
// std::cout << v[0] << ", " << v[1] << ", " << v[2] << std::endl;
|
||||
//
|
||||
template <typename NativeT>
|
||||
void SortWithValues(tensorflow::gtl::MutableArraySlice<NativeT> values);
|
||||
|
||||
private:
|
||||
std::vector<int64> indices_;
|
||||
int64 rank_;
|
||||
int64 max_indices_;
|
||||
};
|
||||
|
||||
template <typename NativeT>
|
||||
void SparseIndexArray::SortWithValues(
|
||||
tensorflow::gtl::MutableArraySlice<NativeT> values) {
|
||||
int64 num_elements = index_count();
|
||||
CHECK_EQ(values.size(), num_elements);
|
||||
std::vector<int64> sort_order;
|
||||
sort_order.reserve(num_elements);
|
||||
for (int64 i = 0; i < num_elements; ++i) {
|
||||
sort_order.push_back(i);
|
||||
}
|
||||
auto sort_order_less = [this](int64 lhs, int64 rhs) {
|
||||
return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0;
|
||||
};
|
||||
std::sort(sort_order.begin(), sort_order.end(), sort_order_less);
|
||||
|
||||
// Reorder the array elements according to sort_order. Work through the array
|
||||
// and follow cycles so we can do the reorder in-place.
|
||||
tensorflow::gtl::InlinedVector<int64, 8> saved_index(rank());
|
||||
for (int64 i = 0; i < num_elements; ++i) {
|
||||
// sort_order[i] == -1 indicates the element has already been copied.
|
||||
if (sort_order[i] < 0) {
|
||||
continue;
|
||||
} else if (i == sort_order[i]) {
|
||||
// The element is already in sorted order.
|
||||
sort_order[i] = -1;
|
||||
continue;
|
||||
}
|
||||
|
||||
std::copy_n(At(i).begin(), rank(), saved_index.begin());
|
||||
NativeT saved_value = values[i];
|
||||
int64 j = i;
|
||||
for (;;) {
|
||||
if (sort_order[j] == i) {
|
||||
std::copy_n(saved_index.begin(), rank(), At(j).begin());
|
||||
values[j] = saved_value;
|
||||
sort_order[j] = -1;
|
||||
break;
|
||||
}
|
||||
|
||||
std::copy_n(At(sort_order[j]).begin(), rank(), At(j).begin());
|
||||
values[j] = values[sort_order[j]];
|
||||
|
||||
int64 k = sort_order[j];
|
||||
sort_order[j] = -1;
|
||||
j = k;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
|
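Putting the pieces above together, a minimal usage sketch (the values and the 10x10 shape are invented for illustration; only the SparseIndexArray and ShapeUtil APIs come from this commit): indices are appended unsorted, sorted together with their values, and validated against a sparse shape:

// Sketch: build, sort, and validate a small sparse index set (xla namespace).
std::vector<float> values = {2.0f, 1.0f};
SparseIndexArray indices(/*max_indices=*/4, /*rank=*/2);
indices.Append({5, 5});  // index of values[0]
indices.Append({0, 1});  // index of values[1]
indices.SortWithValues<float>(&values);
// Indices are now lexicographically sorted, {0,1} before {5,5}, and the
// values moved with them: values == {1.0f, 2.0f}.
Shape shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {10, 10},
                                                   /*max_sparse_elements=*/4);
CHECK(indices.Validate(shape));  // in bounds, sorted, and unique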
43
tensorflow/compiler/xla/sparse_index_array_test.cc
Normal file
@ -0,0 +1,43 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/sparse_index_array.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
TEST(SparseIndexArrayTest, Sort) {
|
||||
SparseIndexArray a(10, 3);
|
||||
a.Append({2, 3, 4});
|
||||
a.Append({3, 4, 5});
|
||||
a.Append({1, 2, 3});
|
||||
a.Append({5, 6, 7});
|
||||
a.Append({4, 5, 6});
|
||||
a.Append({6, 7, 8});
|
||||
std::vector<double> values = {
|
||||
12.0, 13.0, 11.0, 15.0, 14.0, 16.0,
|
||||
};
|
||||
a.SortWithValues<double>(&values);
|
||||
ASSERT_EQ(a.data(), std::vector<int64>({1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5,
|
||||
6, 7, 6, 7, 8}));
|
||||
ASSERT_EQ(values, std::vector<double>({11.0, 12.0, 13.0, 14.0, 15.0, 16.0}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
@ -1160,6 +1160,50 @@ TEST_F(WhileTest, WhileWithCallInsideCondition) {
|
||||
ComputeAndCompareR0<int32>(&builder, 5, {});
|
||||
}
|
||||
|
||||
TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
|
||||
auto matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
|
||||
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
|
||||
auto while_shape = ShapeUtil::MakeTupleShape(
|
||||
{scalar_s32, matrix_shape, matrix_shape, matrix_shape});
|
||||
|
||||
// Create a computation for the condition: repeat for 5 iterations.
|
||||
Computation condition;
|
||||
{
|
||||
ComputationBuilder builder(client_, "condition");
|
||||
auto state = builder.Parameter(0, while_shape, "state");
|
||||
builder.Gt(builder.ConstantR0<int32>(5), builder.GetTupleElement(state, 0));
|
||||
TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
|
||||
}
|
||||
|
||||
Computation body;
|
||||
{
|
||||
ComputationBuilder builder(client_, "body");
|
||||
auto state = builder.Parameter(0, while_shape, "state");
|
||||
auto indvar = builder.GetTupleElement(state, 0);
|
||||
auto input_0 = builder.GetTupleElement(state, 1);
|
||||
auto input_1 = builder.GetTupleElement(state, 2);
|
||||
auto output = builder.Tanh(builder.Dot(input_0, input_1));
|
||||
auto indvar_next = builder.Add(indvar, builder.ConstantR0<int32>(1));
|
||||
auto tuple_result = builder.Tuple({indvar_next, input_0, input_1, output});
|
||||
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
|
||||
}
|
||||
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto matrix_input = builder.Parameter(0, matrix_shape, "matrix");
|
||||
auto init = builder.Tuple(
|
||||
{builder.ConstantR0<int32>(0), matrix_input, matrix_input, matrix_input});
|
||||
auto while_instruction = builder.While(condition, body, init);
|
||||
builder.GetTupleElement(while_instruction, 3);
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto param_value,
|
||||
client_->TransferToServer(*Literal::CreateR2<float>(
|
||||
{{1.0, 2.0}, {-1.0, -2.0}})));
|
||||
|
||||
ComputeAndCompareR2<float>(
|
||||
&builder, {{-0.76159416, -0.96402758}, {0.76159416, 0.96402758}},
|
||||
{param_value.get()}, ErrorSpec(4e-5));
|
||||
}
|
||||
|
||||
void BM_WhileLoop(int num_iters) {
|
||||
// Benchmark a simple kernel to measure while loop overheads.
|
||||
tensorflow::testing::StopTiming();
|
||||
|
@ -1515,7 +1515,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return TokenError(StrCat("unsupported premitive type ",
|
||||
return TokenError(StrCat("unsupported primitive type ",
|
||||
PrimitiveType_Name(shape.element_type())));
|
||||
}
|
||||
break;
|
||||
@ -1851,7 +1851,7 @@ bool HloParser::ParseWindow(Window* window) {
|
||||
if (field_name == "rhs_reversal") {
|
||||
return ParseDxD("rhs_reversal", &rhs_reversal);
|
||||
}
|
||||
return Error(loc, StrCat("unexpected attribute name: ", field_name));
|
||||
return Error(attr_loc, StrCat("unexpected attribute name: ", field_name));
|
||||
}();
|
||||
if (!ok) {
|
||||
return false;
|
||||
|
@ -398,6 +398,31 @@ std::vector<std::pair<int64, int64>> CommonFactors(
|
||||
// Removes illegal characters from filenames.
|
||||
string SanitizeFileName(string file_name);
|
||||
|
||||
// Simple wrapper around std::all_of.
|
||||
template <typename Container, typename Predicate>
|
||||
bool c_all_of(Container container, Predicate predicate) {
|
||||
return std::all_of(std::begin(container), std::end(container), predicate);
|
||||
}
|
||||
|
||||
// Simple wrapper around std::transform.
|
||||
template <typename InputContainer, typename OutputIterator,
|
||||
typename UnaryOperation>
|
||||
OutputIterator c_transform(InputContainer input_container,
|
||||
OutputIterator output_iterator,
|
||||
UnaryOperation unary_op) {
|
||||
return std::transform(std::begin(input_container), std::end(input_container),
|
||||
output_iterator, unary_op);
|
||||
}
|
||||
|
||||
// Simple wrapper around std::copy_if.
|
||||
template <class InputContainer, class OutputIterator, class UnaryPredicate>
|
||||
OutputIterator c_copy_if(InputContainer input_container,
|
||||
OutputIterator output_iterator,
|
||||
UnaryPredicate predicate) {
|
||||
return std::copy_if(std::begin(input_container), std::end(input_container),
|
||||
output_iterator, predicate);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#define XLA_LOG_LINES(SEV, STRING) \
|
||||
|
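For comparison, a short sketch of how the wrappers above read at a call site, next to the raw iterator-pair algorithms they replace (the vectors here are arbitrary examples):

#include <iterator>
#include <vector>

// Sketch: the container-based wrappers replace explicit begin()/end() pairs.
void WrapperExamples() {
  std::vector<int> dims = {2, 3, 4};
  bool all_positive = xla::c_all_of(dims, [](int d) { return d > 0; });
  std::vector<int> doubled;
  xla::c_transform(dims, std::back_inserter(doubled),
                   [](int d) { return 2 * d; });
  std::vector<int> large;
  xla::c_copy_if(dims, std::back_inserter(large), [](int d) { return d >= 3; });
  (void)all_positive;
}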
@ -120,6 +120,9 @@ enum Format {
|
||||
// The default layout, with exactly one storage location per element (ignoring
|
||||
// padding).
|
||||
DENSE = 1;
|
||||
// A sparsely encoded layout, providing only the index/value pairs of non-zero
|
||||
// elements.
|
||||
SPARSE = 2;
|
||||
}
|
||||
|
||||
// A layout describes how the array is placed in (1D) memory space. This
|
||||
@ -151,6 +154,11 @@ message Layout {
|
||||
// field must be unset unless the format is DENSE.
|
||||
PaddingValue padding_value = 3;
|
||||
|
||||
// The maximum number of elements that can be stored for SPARSE formats. This
|
||||
// can be used to determine the maximum size in bytes of arrays stored in
|
||||
// memory. This field must be unset unless the format is SPARSE.
|
||||
int64 max_sparse_elements = 5;
|
||||
|
||||
// Important: if any field is added, be sure to modify ShapeUtil::Equal()
|
||||
// appropriately to account for the new field.
|
||||
}
|
||||
@ -333,7 +341,8 @@ message LiteralProto {
|
||||
// The F16s and BF16s are encoded in little endian byte order
|
||||
bytes f16s = 11;
|
||||
bytes bf16s = 13;
|
||||
// Next = 14
|
||||
repeated int64 sparse_indices = 14;
|
||||
// Next = 15
|
||||
}
|
||||
|
||||
message WindowDimension {
|
||||
|
@ -210,6 +210,7 @@ std::unique_ptr<TaskType> Batch<TaskType>::RemoveTask() {
|
||||
return nullptr;
|
||||
}
|
||||
std::unique_ptr<TaskType> task = std::move(tasks_.back());
|
||||
size_ -= task->size();
|
||||
tasks_.pop_back();
|
||||
return task;
|
||||
}
|
||||
|
@ -74,7 +74,9 @@ TEST(BatchTest, Basic) {
|
||||
EXPECT_EQ(task1->size(), batch.task(1).size());
|
||||
|
||||
EXPECT_EQ(7, batch.RemoveTask()->size());
|
||||
EXPECT_EQ(3, batch.size());
|
||||
EXPECT_EQ(3, batch.RemoveTask()->size());
|
||||
EXPECT_EQ(0, batch.size());
|
||||
EXPECT_TRUE(batch.empty());
|
||||
}
|
||||
|
||||
|
@ -524,7 +524,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
self.read_coordination_events[expected_element].acquire()
|
||||
else:
|
||||
self.write_coordination_events[expected_element].set()
|
||||
time.sleep(0.1) # Sleep to consistently "avoid" the race condition.
|
||||
time.sleep(0.5) # Sleep to consistently "avoid" the race condition.
|
||||
actual_element = sess.run(self.next_element)
|
||||
if not done_first_event:
|
||||
done_first_event = True
|
||||
@ -611,7 +611,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
self.read_coordination_events[expected_element].acquire()
|
||||
else:
|
||||
self.write_coordination_events[expected_element].set()
|
||||
time.sleep(0.1) # Sleep to consistently "avoid" the race condition.
|
||||
time.sleep(0.5) # Sleep to consistently "avoid" the race condition.
|
||||
actual_element = sess.run(self.next_element)
|
||||
if not done_first_event:
|
||||
done_first_event = True
|
||||
|
@ -7,7 +7,11 @@ exports_files([
|
||||
"generic_tree_model_proto.swig",
|
||||
])
|
||||
|
||||
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config.bzl",
|
||||
"tf_proto_library",
|
||||
"tf_pyclif_proto_library",
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
@ -34,3 +38,10 @@ tf_proto_library(
|
||||
protodeps = [":generic_tree_model"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
tf_pyclif_proto_library(
|
||||
name = "generic_tree_model_pyclif",
|
||||
proto_lib = ":generic_tree_model",
|
||||
proto_srcfile = "generic_tree_model.proto",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
@ -17,6 +17,7 @@ py_library(
|
||||
"python/ops/__init__.py",
|
||||
"python/ops/alpha_dropout.py",
|
||||
"python/ops/cross_entropy.py",
|
||||
"python/ops/fwd_gradients.py",
|
||||
"python/ops/sampling_ops.py",
|
||||
"python/ops/scaled_softplus.py",
|
||||
],
|
||||
@ -28,6 +29,7 @@ py_library(
|
||||
"//tensorflow/python:embedding_ops",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:function",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:nn",
|
||||
"//tensorflow/python:nn_ops",
|
||||
@ -55,6 +57,19 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "fwd_gradients_test",
|
||||
size = "small",
|
||||
srcs = ["python/ops/fwd_gradients_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":nn_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:math_ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "sampling_ops_test",
|
||||
size = "small",
|
||||
|
76
tensorflow/contrib/nn/python/ops/fwd_gradients.py
Normal file
@ -0,0 +1,76 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Forward-mode derivatives."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops.gradients_impl import gradients
|
||||
|
||||
|
||||
def fwd_gradients(ys, xs, grad_xs=None, assert_unused=False):
|
||||
"""Computes forward-mode derivatives.
|
||||
|
||||
This is accomplished in pure-python using tensorflow's existing (reverse-mode)
|
||||
gradients. There is additional overhead on graph construction, but runtime
|
||||
performance should be equal to a manual implementation [citation needed].
|
||||
|
||||
See https://j-towns.github.io/2017/06/12/A-new-trick.html and
|
||||
https://github.com/HIPS/autograd/pull/175 for the original discussion of this
|
||||
method, and https://github.com/renmengye/tensorflow-forward-ad for a "direct"
|
||||
implementation.
|
||||
|
||||
Args:
|
||||
ys: A list of tensors.
|
||||
xs: A list of tensors.
|
||||
grad_xs: An optional list of tensors. If provided, must have the same length
|
||||
and shapes compatible with xs.
|
||||
assert_unused: Add assertions that intermediate values are not computed.
|
||||
Returns:
|
||||
A list of tensors of the same shapes as ys. The directional derivatives of
|
||||
ys with respect to xs in the direction grad_xs. Leaving grad_xs unspecified
|
||||
is equivalent to passing in 1s for each x in xs.
|
||||
"""
|
||||
# This version of forward-mode autodiff is based on code by Tim Cooijmans
|
||||
# and handles list arguments and certain special cases such as when the
|
||||
# ys don't depend on one or more of the xs, and when tf.IndexedSlices are
|
||||
# generated by the first tf.gradients call.
|
||||
|
||||
us = [array_ops.zeros_like(y) + float('nan') for y in ys]
|
||||
|
||||
dydxs = gradients(ys, xs, grad_ys=us)
|
||||
|
||||
# deal with strange types that tf.gradients returns but can't deal with
|
||||
dydxs = [ops.convert_to_tensor(dydx) if isinstance(dydx, ops.IndexedSlices)
|
||||
else dydx for dydx in dydxs]
|
||||
|
||||
if assert_unused:
|
||||
with ops.control_dependencies(dydxs):
|
||||
assert_unused = control_flow_ops.Assert(False, [1], name='fwd_gradients')
|
||||
with ops.control_dependencies([assert_unused]):
|
||||
dydxs = array_ops.identity_n(dydxs)
|
||||
|
||||
dydxs = [array_ops.zeros_like(x) if dydx is None else dydx
|
||||
for x, dydx in zip(xs, dydxs)]
|
||||
for x, dydx in zip(xs, dydxs):
|
||||
dydx.set_shape(x.shape)
|
||||
|
||||
dysdx = gradients(dydxs, us, grad_ys=grad_xs)
|
||||
|
||||
return dysdx
|
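The trick the docstring above cites can be stated compactly (the notation is mine, not from the source). Let J be the Jacobian of ys with respect to xs. The first gradients call computes, for the NaN-filled placeholder cotangent u (the `us` list),

    g(u) = J^T u,

which is linear in u, so the NaN values can never reach the final result. The second call differentiates <g(u), v> with respect to u, with v = grad_xs, and since

    d/du <J^T u, v> = J v,

it recovers exactly the forward-mode directional derivative of ys along grad_xs, using two reverse-mode passes.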
52
tensorflow/contrib/nn/python/ops/fwd_gradients_test.py
Normal file
@ -0,0 +1,52 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for forward_ad.py."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.nn.python.ops import fwd_gradients
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ForwardAdTest(test.TestCase):
|
||||
|
||||
def testSquare(self):
|
||||
x = constant_op.constant(1.)
|
||||
y = math_ops.square(x)
|
||||
grad_x = 3.
|
||||
|
||||
dydx_tf = fwd_gradients.fwd_gradients([y], [x], [grad_x])[0]
|
||||
dydx_py = 2. * grad_x
|
||||
|
||||
with self.test_session() as sess:
|
||||
self.assertAllClose(sess.run(dydx_tf), dydx_py, 1e-6)
|
||||
|
||||
def testGather(self):
|
||||
x = constant_op.constant([1., 2., 3.])
|
||||
y = array_ops.gather(x, [0, 1])
|
||||
y.set_shape([2])
|
||||
dydx = fwd_gradients.fwd_gradients([y], [x], assert_unused=True)
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(dydx)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -226,6 +226,11 @@ bool IsConstantFoldable(
|
||||
if (consider && !consider(n)) {
|
||||
return false;
|
||||
}
|
||||
// PlaceholderWithDefault shouldn't be constant folded because its output can
|
||||
// be fed non-constant values.
|
||||
if (n->type_string() == "PlaceholderWithDefault") {
|
||||
return false;
|
||||
}
|
||||
if (n->IsControlFlow() || n->IsSend() || n->IsRecv()) {
|
||||
return false;
|
||||
}
|
||||
|
@ -338,6 +338,40 @@ TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) {
|
||||
EXPECT_FALSE(was_mutated);
|
||||
}
|
||||
|
||||
TEST_F(ConstantFoldingTest, Placeholders) {
|
||||
Graph g(OpRegistry::Global());
|
||||
{
|
||||
Scope s = Scope::NewRootScope();
|
||||
auto placeholder = ops::Placeholder(s, DT_DOUBLE);
|
||||
auto add = ops::Add(s, placeholder, 2.0);
|
||||
auto send =
|
||||
ops::_Send(s.WithOpName("send"), add, "add", "sender", 0, "receiver");
|
||||
TF_ASSERT_OK(s.ToGraph(&g));
|
||||
}
|
||||
bool was_mutated;
|
||||
Status s = ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
|
||||
nullptr, &g, &was_mutated);
|
||||
EXPECT_FALSE(was_mutated);
|
||||
EXPECT_FALSE(s.ok());
|
||||
EXPECT_TRUE(s.error_message().find(
|
||||
"You must feed a value for placeholder "
|
||||
"tensor 'Placeholder' with dtype double") != string::npos);
|
||||
|
||||
Graph g2(OpRegistry::Global());
|
||||
{
|
||||
Scope s = Scope::NewRootScope();
|
||||
auto placeholder = ops::PlaceholderWithDefault(s, {1.0}, {1});
|
||||
auto add = ops::Add(s, placeholder, 2.0);
|
||||
auto send =
|
||||
ops::_Send(s.WithOpName("send"), add, "add", "sender", 0, "receiver");
|
||||
TF_ASSERT_OK(s.ToGraph(&g2));
|
||||
}
|
||||
// TODO(skyewm): should this have the same behavior as Placeholder?
|
||||
TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
|
||||
nullptr, &g2, &was_mutated));
|
||||
EXPECT_FALSE(was_mutated);
|
||||
}
|
||||
|
||||
TEST_F(ConstantFoldingTest, ControlDependencies) {
|
||||
Graph g(OpRegistry::Global());
|
||||
{
|
||||
|
@ -558,6 +558,13 @@ Status ShapeRefiner::ExtractConstantSubgraph(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (target_node->type_string() == "PlaceholderWithDefault") {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// TODO(skyewm): more of the filtering applied in input nodes below should be
|
||||
// applied to target_node here
|
||||
|
||||
struct NodeAndRecursed {
|
||||
Node* new_node = nullptr;
|
||||
bool recursed = false;
|
||||
@ -608,6 +615,14 @@ Status ShapeRefiner::ExtractConstantSubgraph(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Placeholders should never be constant folded because their outputs are
|
||||
// fed by the user. Note that "Placeholder" nodes have no inputs so are
|
||||
// handled below.
|
||||
if (current_node->type_string() == "PlaceholderWithDefault") {
|
||||
*is_constant_graph = false;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// If there is nothing more to recurse down, see if
|
||||
// the generator node is a constant.
|
||||
if (current_node->num_inputs() == 0) {
|
||||
|
@ -724,6 +724,25 @@ TEST_F(ShapeRefinerTest, PropagateRange) {
|
||||
EXPECT_EQ("[1,4,7,10]", ctx->DebugString(ctx->output(0)));
|
||||
}
|
||||
|
||||
// Make sure PlaceholderWithDefaults aren't treated as constants.
|
||||
TEST_F(ShapeRefinerTest, NoPropagatePlaceholderWithDefault) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto constant = ops::Const<int>(root, 2);
|
||||
auto placeholder =
|
||||
ops::PlaceholderWithDefault(root, constant, PartialTensorShape());
|
||||
Node* shape_data;
|
||||
TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
|
||||
.Input(placeholder.node())
|
||||
.Finalize(root.graph(), &shape_data));
|
||||
|
||||
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
|
||||
TF_ASSERT_OK(m.AddNode(constant.node()));
|
||||
TF_ASSERT_OK(m.AddNode(placeholder.node()));
|
||||
TF_ASSERT_OK(m.AddNode(shape_data));
|
||||
shape_inference::InferenceContext* ic = m.GetContext(shape_data);
|
||||
EXPECT_EQ(ic->DebugString(ic->output(0)), "?");
|
||||
}
|
||||
|
||||
TEST_F(ShapeRefinerTest, ConstantValueTwoInputsToSameNode) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
// This node is used as two inputs to 'range'.
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/costs/graph_memory.h"
|
||||
#include <list>
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_description.pb.h"
|
||||
@ -163,6 +164,8 @@ void GraphMemory::InferFromTrace(const StepStats& timeline) {
|
||||
|
||||
NodeMap node_map(&item_.graph);
|
||||
for (const auto& dev_stats : timeline.dev_stats()) {
|
||||
const string& device_name = dev_stats.device();
|
||||
const bool is_gpu = (device_name.find("GPU:") != string::npos ||
device_name.find("gpu:") != string::npos);
|
||||
std::list<LiveTensor>& device_tensors =
|
||||
live_tensors_per_device[dev_stats.device()];
|
||||
for (const auto& node_stats : dev_stats.node_stats()) {
|
||||
@ -194,7 +197,24 @@ void GraphMemory::InferFromTrace(const StepStats& timeline) {
|
||||
// graph (e.g _Send/_Recv nodes).
|
||||
continue;
|
||||
}
|
||||
for (const string& input : node->input()) {
|
||||
std::unordered_set<int> swapped_inputs;
|
||||
if (is_gpu) {
|
||||
auto it = node->attr().find("_swap_to_host");
|
||||
if (it != node->attr().end()) {
|
||||
const AttrValue& val = it->second;
|
||||
for (int port_id : val.list().i()) {
|
||||
swapped_inputs.insert(port_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < node->input_size(); ++i) {
|
||||
if (swapped_inputs.find(i) != swapped_inputs.end()) {
|
||||
// The memory of swapped inputs will be released as early as possible:
|
||||
// therefore ignore this input when determining the deallocation time
|
||||
// of the tensor.
|
||||
continue;
|
||||
}
|
||||
const string& input = node->input(i);
|
||||
int position;
|
||||
string input_node = ParseNodeName(input, &position);
|
||||
if (position < 0) {
|
||||
|
@ -134,6 +134,62 @@ TEST_F(GraphMemoryTest, MultiDevice) {
|
||||
EXPECT_EQ(gpu_expected, gpu_tensors);
|
||||
}
|
||||
|
||||
TEST_F(GraphMemoryTest, GpuSwapping) {
|
||||
TrivialTestGraphInputYielder fake_input(4, 2, 1024 * 1024, false, {"/GPU:0"});
|
||||
GrapplerItem item;
|
||||
CHECK(fake_input.NextItem(&item));
|
||||
item.feed.clear();
|
||||
|
||||
{
|
||||
// Estimate the max memory usage for the graph.
|
||||
GraphMemory memory(item);
|
||||
Status s = memory.InferStatically(devices_);
|
||||
TF_CHECK_OK(s);
|
||||
|
||||
const GraphMemory::MemoryUsage& gpu_mem =
|
||||
memory.GetPeakMemoryUsage("/GPU:0");
|
||||
EXPECT_EQ(20971520, gpu_mem.used_memory);
|
||||
std::set<string> gpu_tensors;
|
||||
for (const auto& t : gpu_mem.live_tensors) {
|
||||
gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
|
||||
}
|
||||
std::set<string> gpu_expected;
|
||||
gpu_expected.insert("Square:0");
|
||||
gpu_expected.insert("Square_1:0");
|
||||
gpu_expected.insert("AddN:0");
|
||||
gpu_expected.insert("AddN_1:0");
|
||||
gpu_expected.insert("AddN_2:0");
|
||||
EXPECT_EQ(gpu_expected, gpu_tensors);
|
||||
}
|
||||
|
||||
{
|
||||
// Swap the first input to node AddN_1: its fanin (the square nodes) should
|
||||
// not appear in the max cut anymore.
|
||||
for (auto& node : *item.graph.mutable_node()) {
|
||||
if (node.name() == "AddN_1") {
|
||||
(*node.mutable_attr())["_swap_to_host"].mutable_list()->add_i(0);
|
||||
}
|
||||
}
|
||||
GraphMemory memory(item);
|
||||
Status s = memory.InferStatically(devices_);
|
||||
TF_CHECK_OK(s);
|
||||
const GraphMemory::MemoryUsage& new_gpu_mem =
|
||||
memory.GetPeakMemoryUsage("/GPU:0");
|
||||
EXPECT_EQ(20971520, new_gpu_mem.used_memory);
|
||||
std::set<string> new_gpu_tensors;
|
||||
for (const auto& t : new_gpu_mem.live_tensors) {
|
||||
new_gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
|
||||
}
|
||||
std::set<string> new_gpu_expected;
|
||||
new_gpu_expected.insert("AddN:0");
|
||||
new_gpu_expected.insert("AddN_1:0");
|
||||
new_gpu_expected.insert("AddN_2:0");
|
||||
new_gpu_expected.insert("AddN_3:0");
|
||||
new_gpu_expected.insert("AddN_4:0");
|
||||
EXPECT_EQ(new_gpu_expected, new_gpu_tensors);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphMemoryTest, CtrlDependencies) {
|
||||
// Build a simple graph with a control dependency.
|
||||
Scope s = Scope::NewRootScope();
|
||||
|
@ -31,8 +31,6 @@ namespace {
|
||||
GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
|
||||
bool use_multiple_devices, bool insert_queue,
|
||||
const std::vector<string>& device_names) {
|
||||
CHECK_GE(device_names.size(), width);
|
||||
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
@ -49,13 +47,17 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
|
||||
std::vector<Output> this_stage;
|
||||
for (int j = 0; j < width; j++) {
|
||||
if (last_stage.size() == 1) {
|
||||
Output unary_op =
|
||||
Square(s.WithDevice(device_names[use_multiple_devices ? j : 0]),
|
||||
last_stage[0]);
|
||||
Output unary_op = Square(
|
||||
s.WithDevice(
|
||||
device_names[use_multiple_devices ? j % device_names.size()
|
||||
: 0]),
|
||||
last_stage[0]);
|
||||
this_stage.push_back(unary_op);
|
||||
} else {
|
||||
Output combine =
|
||||
AddN(s.WithDevice(device_names[use_multiple_devices ? j : 0]),
|
||||
AddN(s.WithDevice(
|
||||
device_names[use_multiple_devices ? j % device_names.size()
|
||||
: 0]),
|
||||
last_stage);
|
||||
this_stage.push_back(combine);
|
||||
}
|
||||
|
@ -433,13 +433,42 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
|
||||
id = --min_id;
|
||||
}
|
||||
}
|
||||
|
||||
// Beware: the reduction dimensions computed by the BCast class are valid iff
|
||||
// we assume that two distinct symbolic dimensions can't be equal and a
|
||||
// symbolic dimension can't be equal to 1. This is often but not always true,
|
||||
// so to make this optimization safe we filter out these cases.
|
||||
const int common_dims = std::min(shape1.size(), shape2.size());
|
||||
for (int i = 0; i < common_dims; ++i) {
|
||||
if (shape1[i] >= 0 && shape2[i] >= 0) {
|
||||
continue;
|
||||
}
|
||||
if (shape1[i] != shape2[i]) {
|
||||
// We're either dealing with 2 different symbolic dimensions or a symbolic
|
||||
// and a known dimension. We can't be sure whether both are equal or not,
|
||||
// so we can't be sure whether we'll be broadcasting or not.
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
// These extra dims could be equal to 1, in which case there is no
|
||||
// broadcasting. They could also be greater than 1, in which case there would
|
||||
// be broadcasting. Since we don't know, we'll just punt.
|
||||
for (int i = common_dims; i < shape1.size(); ++i) {
|
||||
if (shape1[i] < 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
for (int i = common_dims; i < shape2.size(); ++i) {
|
||||
if (shape2[i] < 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
BCast bcast(shape1, shape2);
|
||||
if (!bcast.IsValid()) {
|
||||
return Status::OK();
|
||||
}
|
||||
// Beware: the reduction dimensions are valid iff we assume that two distinct
|
||||
// symbolic dimensions can't be equal. This is often but not always true, so
|
||||
// this optimization isn't safe.
|
||||
|
||||
BCast::Vec reduce_dims[2];
|
||||
reduce_dims[0] = bcast.grad_x_reduce_idx();
|
||||
reduce_dims[1] = bcast.grad_y_reduce_idx();
|
||||
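A concrete illustration of the reduction indices being materialized here (the example shapes are mine): broadcasting x of shape [2, 3] against y of shape [3] leaves x's gradient alone but requires y's gradient to be summed over output axis 0, and BCast reports exactly that once the shapes are fully known, which is why the symbolic-dimension filter above punts on anything less:

// Sketch: reduction indices for fully known shapes [2, 3] and [3].
#include "tensorflow/core/util/bcast.h"

tensorflow::BCast example_bcast({2, 3}, {3});
CHECK(example_bcast.IsValid());
// Expected: grad_x_reduce_idx() is empty and grad_y_reduce_idx() is {0},
// i.e. only y's gradient needs a Sum reduction, over the broadcast axis.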
@ -447,26 +476,27 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
|
||||
const DataType type = node.attr().at("T").type();
|
||||
NodeDef* out[2];
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
if (!reduce_dims[j].empty()) {
|
||||
// This is the case when a tensor dimension of 1 is matched against an
|
||||
// unknown dimension. The unknown dimension could also be equal to 1, in
|
||||
// which case there would be no reduction.
|
||||
out[j] = nullptr;
|
||||
} else {
|
||||
string const_name = OptimizedNodeName(node, strings::StrCat("-", j));
|
||||
out[j] = node_map_->GetNode(const_name);
|
||||
if (out[j] == nullptr) {
|
||||
out[j] = graph_->add_node();
|
||||
Tensor value(type, TensorShape({0}));
|
||||
*out[j] = CreateNodeDef(const_name, TensorValue(&value));
|
||||
out[j]->set_device(node.device());
|
||||
node_map_->AddNode(const_name, out[j]);
|
||||
string ctrl_dep =
|
||||
AddControlDependency(node.name(), graph_, node_map_.get());
|
||||
*out[j]->add_input() = ctrl_dep;
|
||||
node_map_->AddOutput(NodeName(ctrl_dep), const_name);
|
||||
int reduction_indices = reduce_dims[j].size();
|
||||
Tensor value(type, TensorShape({reduction_indices}));
|
||||
for (int i = 0; i < reduction_indices; ++i) {
|
||||
if (type == DT_INT32) {
|
||||
value.vec<int32>()(i) = reduce_dims[j][i];
|
||||
} else {
|
||||
value.vec<int64>()(i) = reduce_dims[j][i];
|
||||
}
|
||||
}
|
||||
string const_name = OptimizedNodeName(node, strings::StrCat("-", j));
|
||||
out[j] = node_map_->GetNode(const_name);
|
||||
if (out[j] == nullptr) {
|
||||
out[j] = graph_->add_node();
|
||||
*out[j] = CreateNodeDef(const_name, TensorValue(&value));
|
||||
out[j]->set_device(node.device());
|
||||
node_map_->AddNode(const_name, out[j]);
|
||||
string ctrl_dep =
|
||||
AddControlDependency(node.name(), graph_, node_map_.get());
|
||||
*out[j]->add_input() = ctrl_dep;
|
||||
node_map_->AddOutput(NodeName(ctrl_dep), const_name);
|
||||
}
|
||||
}
|
||||
|
||||
const std::set<NodeDef*> outputs = node_map_->GetOutputs(node.name());
|
||||
@ -584,12 +614,11 @@ Status ConstantFolding::MaterializeReductionIndices(
|
||||
|
||||
Status ConstantFolding::MaterializeConstants(
|
||||
const GraphProperties& properties) {
|
||||
const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
|
||||
const int node_count = graph_->node_size();
|
||||
for (int i = 0; i < node_count; ++i) {
|
||||
NodeDef& node = *graph_->mutable_node(i);
|
||||
const string& op = node.op();
|
||||
if (is_aggressive && op == "BroadcastGradientArgs") {
|
||||
if (op == "BroadcastGradientArgs") {
|
||||
TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties));
|
||||
} else if (IsReduction(node)) {
|
||||
TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties));
|
||||
|
@ -1373,21 +1373,14 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) {
|
||||
} else if (node.name() == "p1") {
|
||||
++found;
|
||||
EXPECT_EQ(1, node.input_size());
|
||||
EXPECT_EQ("ConstantFolding/i-0", node.input(0));
|
||||
EXPECT_EQ("i", node.input(0));
|
||||
} else if (node.name() == "p2") {
|
||||
++found;
|
||||
EXPECT_EQ(1, node.input_size());
|
||||
EXPECT_EQ("i:1", node.input(0));
|
||||
} else if (node.name() == "ConstantFolding/i-0") {
|
||||
++found;
|
||||
EXPECT_EQ("Const", node.op());
|
||||
EXPECT_EQ(1, node.input_size());
|
||||
EXPECT_EQ("^i", node.input(0));
|
||||
EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
|
||||
.num_elements());
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(7, found);
|
||||
EXPECT_EQ(6, found);
|
||||
}
|
||||
|
||||
TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
|
||||
|
@ -148,7 +148,10 @@ bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff,
|
||||
first_control_input = i;
|
||||
break;
|
||||
}
|
||||
if (actual.input(i) != expected.input(i)) {
|
||||
// Special case for inputs: "tensor" is equivalent to "tensor:0"
|
||||
if (actual.input(i) != expected.input(i) &&
|
||||
actual.input(i) != strings::StrCat(expected.input(i), ":0") &&
|
||||
strings::StrCat(actual.input(i), ":0") != expected.input(i)) {
|
||||
if (diff != nullptr) {
|
||||
*diff = strings::StrCat("Node named '", actual.name(), "' has input ",
|
||||
i, " '", actual.input(i),
|
||||
|
@ -5922,6 +5922,7 @@ func RandomDataset(scope *Scope, seed tf.Output, seed2 tf.Output, output_types [
|
||||
//
|
||||
// This op is hidden from public in Python. It is used by TensorFlow Debugger to
|
||||
// register gradient tensors for gradient debugging.
|
||||
// This op operates on non-reference-type tensors.
|
||||
func DebugGradientIdentity(scope *Scope, input tf.Output) (output tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
|
@ -344,7 +344,7 @@ def implicit_val_and_grad(f):

def grad_fn(*args):
"""Computes the gradient of the wrapped function."""
tape.push_new_tape()
this_tape = tape.push_new_tape()
try:
end_node = f(*args)
if end_node is None:
@ -352,10 +352,10 @@ def implicit_val_and_grad(f):
"did you forget to return a value from {}?".format(
f.__name__))
finally:
popped_tape = tape.pop_tape()
tape.pop_tape(this_tape)
# Sorting variables by id, which is monotonically increasing in construction
# order. This ensures unique order across executions.
variables = list(sorted(popped_tape.watched_variables(),
variables = list(sorted(this_tape.watched_variables(),
key=lambda v: v.handle._id))  # pylint: disable=protected-access
sources = [x.handle for x in variables]

@ -363,7 +363,7 @@ def implicit_val_and_grad(f):
raise ValueError("No trainable variables were accessed while the "
"function was being computed.")
grad = imperative_grad.imperative_grad(_default_vspace,
popped_tape,
this_tape,
nest.flatten(end_node),
sources)
return end_node, list(zip(grad, variables))
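The hunks above all make the tape explicit: grad_fn keeps the handle returned by push_new_tape and later passes that same handle to pop_tape and imperative_grad, rather than relying on an implicit top-of-stack tape. A hedged sketch of the resulting call pattern (the trace helper below is ours, not part of the patch):

    from tensorflow.python.eager import tape

    def trace(f, *args):
      this_tape = tape.push_new_tape()  # the handle is now returned
      try:
        result = f(*args)
      finally:
        tape.pop_tape(this_tape)  # remove exactly the tape we created
      return result, this_tape.watched_variables()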
@ -652,7 +652,7 @@ def make_vjp(f, params=None):
"""Computes the value and gradient of the decorated function."""
parameter_positions = _get_arg_spec(f, params, args)
assert not kwds, "The gradient function can't take keyword arguments."
tape.push_new_tape()
this_tape = tape.push_new_tape()
try:
sources = []
args = [
@ -673,12 +673,12 @@ def make_vjp(f, params=None):
flat_result = [gen_array_ops.identity(x) for x in flat_result]
result = nest.pack_sequence_as(result, flat_result)
finally:
t = tape.pop_tape()
tape.pop_tape(this_tape)
def vjp(dy=None):
if dy is not None:
dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)]
return imperative_grad.imperative_grad(
_default_vspace, t, nest.flatten(result), sources,
_default_vspace, this_tape, nest.flatten(result), sources,
output_gradients=dy)
return result, vjp
@ -835,11 +835,11 @@ class GradientTape(object):
self._persistent = persistent

def __enter__(self):
tape.push_new_tape(persistent=self._persistent)
self._tape = tape.push_new_tape(persistent=self._persistent)
return self

def __exit__(self, typ, value, traceback):
self._tape = tape.pop_tape()
tape.pop_tape(self._tape)

def watch(self, tensor):
"""Ensures that `tensor` is being traced by this tape.
@ -24,9 +24,12 @@ import copy
import random
import threading

from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib

GRAPH_MODE = 0
@ -398,6 +401,36 @@ class Context(object):
"""Get the list of post-execution callbacks added to the context."""
return self._post_execution_callbacks

def enable_run_metadata(self):
"""Enables tracing of op execution via RunMetadata.

To retrieve the accumulated metadata call context.export_run_metadata()
and to stop tracing call context.disable_run_metadata().
"""
pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._context_handle)

def disable_run_metadata(self):
"""Disables tracing of op execution via RunMetadata."""
pywrap_tensorflow.TFE_ContextDisableRunMetadata(self._context_handle)

def export_run_metadata(self):
"""Returns a RunMetadata proto with accumulated information.

The returned protocol buffer contains information since the most recent call
to either enable_run_metadata or export_run_metadata.

Returns:
A RunMetadata protocol buffer.
"""
with c_api_util.tf_buffer() as buffer_:
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.TFE_ContextExportRunMetadata(
self._context_handle, buffer_, status)
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
run_metadata = config_pb2.RunMetadata()
run_metadata.ParseFromString(compat.as_bytes(proto_data))
return run_metadata

_context = None
_context_lock = threading.Lock()
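The three new Context methods bracket a traced region: enable starts collection, export snapshots what has accumulated, disable stops collection. A minimal usage sketch against this API (assuming eager execution is already enabled):

    from tensorflow.python.eager import context

    ctx = context.context()
    ctx.enable_run_metadata()             # start collecting RunMetadata
    # ... execute some eager ops here ...
    metadata = ctx.export_run_metadata()  # everything since enable/last export
    ctx.disable_run_metadata()            # stop collecting
    for dev_stats in metadata.step_stats.dev_stats:
      print(dev_stats.device, len(dev_stats.node_stats))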
@ -516,3 +549,29 @@ def num_gpus():
The number of available GPU devices.
"""
return context().num_gpus()


def enable_run_metadata():
"""Enables tracing of op execution via RunMetadata.

To retrieve the accumulated metadata call context.export_run_metadata()
and to stop tracing call context.disable_run_metadata().
"""
context().enable_run_metadata()


def disable_run_metadata():
"""Disables tracing of op execution via RunMetadata."""
context().disable_run_metadata()


def export_run_metadata():
"""Returns a RunMetadata proto with accumulated information.

The returned protocol buffer contains information since the most recent call
to either enable_run_metadata or export_run_metadata.

Returns:
A RunMetadata protocol buffer.
"""
return context().export_run_metadata()
@ -84,6 +84,20 @@ class TFETest(test_util.TensorFlowTestCase):
self.assertTrue(has_cpu_device)
del ctx

def testRunMetadata(self):
context.enable_run_metadata()
t = constant_op.constant(1.0)
_ = t + t  # Runs an operation which will be in the RunMetadata
run_metadata = context.export_run_metadata()
context.disable_run_metadata()
step_stats = run_metadata.step_stats
self.assertGreater(len(step_stats.dev_stats), 0)
cpu_stats = step_stats.dev_stats[0]
self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0',
cpu_stats.device)
self.assertEqual(len(cpu_stats.node_stats), 1)
self.assertEqual(cpu_stats.node_stats[0].node_name, 'Add')

def testContextStackContainsEagerMode(self):
# Eager execution has been enabled, and no other context
# switch has occurred, so `context_stack` should contain
@ -544,11 +544,12 @@ def _defun_internal(name, func, args, kwds):
func_inputs = _get_defun_inputs(args)

with capture_tensors(captures):
tape.push_new_tape()
this_tape = tape.push_new_tape()
try:
func_outputs = func(*func_inputs, **kwds)
finally:
variables = tape.pop_tape().watched_variables()
tape.pop_tape(this_tape)
variables = this_tape.watched_variables()

# Returning a closed-over tensor as an output does not trigger a
# call to convert_to_tensor, so we manually capture all such tensors.
@ -332,7 +332,7 @@ void EagerTensor_dealloc(EagerTensor* self) {
tensorflow::ClearDecrefCache();
auto id = self->id;
Py_TYPE(self)->tp_free(self);
TFE_Py_TapeStackDeleteTrace(id);
TFE_Py_TapeSetDeleteTrace(id);
}

// Getter for `_id`.
@ -87,22 +87,25 @@ TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
// newly created type, or nullptr on error.
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);

// Pushes a new tape into the thread-local stack.
// `persistent` must be a PyBool_Type, i.e. either Py_True or Py_False
void TFE_Py_TapeStackPushNew(PyObject* persistent);
// Creates a new tape and adds it to the active set. `persistent` must be a
// PyBool_Type, i.e. either Py_True or Py_False
PyObject* TFE_Py_TapeSetNew(PyObject* persistent);

// Pops the tape from the top of the stack and returns it.
PyObject* TFE_Py_TapeStackPop();

// Pushes an existing tape onto the stack.
void TFE_Py_TapeStackPush(PyObject* tape);
// Removes the passed tape from the set of active tapes.
void TFE_Py_TapeSetRemove(PyObject* tape);

// Returns true if the tape stack is empty.
PyObject* TFE_Py_TapeStackIsEmpty();
PyObject* TFE_Py_TapeSetIsEmpty();

PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors);
void TFE_Py_TapeStackWatch(PyObject* tensor);
void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id);
PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors);
void TFE_Py_TapeSetWatch(PyObject* tensor);
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id);

// Stops any gradient recording on the current thread.
void TFE_Py_TapeSetStopOnThread();

// Restarts gradient recording on the current thread.
void TFE_Py_TapeSetRestartOnThread();

// Records an operation in the gradient tape set. type is a string for the
// operation type, used in the backprop code. output_tensors should be a list of
@ -111,13 +114,12 @@ void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id);
// operation. backward_function should be the function to be called during
// backprop to, given the gradients of the output tensors, produce the gradients
// of the input tensors.
void TFE_Py_TapeStackRecordOperation(PyObject* op_type,
PyObject* output_tensors,
PyObject* input_tensor_ids,
PyObject* backward_function);
void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
PyObject* input_tensor_ids,
PyObject* backward_function);

// Watches the given variable object on the given tape.
void TFE_Py_TapeStackWatchVariable(PyObject* variable);
void TFE_Py_TapeSetWatchVariable(PyObject* variable);

// Computes a gradient based on information recorded on the tape. `tape` must
// have been produced by TFE_Py_NewTape. `vspace` must be a
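The net effect of these declarations: the per-thread stack of tapes becomes a global set of active tapes plus a per-thread "stopped" flag. A rough Python model of the intended semantics (names are ours; the real implementation is C, with the GIL standing in for a lock):

    import threading

    _tapes = set()              # all active tapes, shared across threads
    _local = threading.local()  # per-thread recording flag

    def _stopped():
      return getattr(_local, "stopped", False)

    def tape_set_add(tape):
      _tapes.add(tape)

    def tape_set_remove(tape):
      _tapes.discard(tape)

    def tape_set_is_empty():
      # Recording is "off" if this thread stopped it or no tape is active.
      return _stopped() or not _tapes

    def stop_on_thread():
      _local.stopped = True

    def restart_on_thread():
      _local.stopped = False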
@ -538,62 +538,67 @@ static PyTypeObject TFE_Py_Tape_Type = {
"TFE_Py_Tape objects", /* tp_doc */
};

// Note: in the current design no mutex is needed here because of the python
// GIL, which is always held when any TFE_Py_* methods are called. We should
// revisit this if/when we decide to not hold the GIL while manipulating the
// tape stack.
static std::unordered_set<TFE_Py_Tape*>* tape_set = nullptr;
std::unordered_set<TFE_Py_Tape*>* GetTapeSet() {
if (tape_set == nullptr) {
tape_set = new std::unordered_set<TFE_Py_Tape*>;
}
return tape_set;
}

// xcode 7 doesn't define thread_local, so for compatibility we implement our
// own. TODO(apassos) remove once we can deprecate xcode 7.
#ifndef __APPLE__
std::vector<TFE_Py_Tape*>* GetTapeStack() {
thread_local std::vector<TFE_Py_Tape*> tape_stack;
return &tape_stack;
bool* ThreadTapeIsStopped() {
thread_local bool thread_tape_is_stopped{false};
return &thread_tape_is_stopped;
}
#else
static tensorflow::mutex stack_mu(tensorflow::LINKER_INITIALIZED);
static std::unordered_map<std::thread::id, std::vector<TFE_Py_Tape*>*>*
tape_stack GUARDED_BY(stack_mu) = nullptr;
std::vector<TFE_Py_Tape*>* GetTapeStack() {
tensorflow::mutex_lock ml(stack_mu);
if (tape_stack == nullptr) {
tape_stack =
new std::unordered_map<std::thread::id, std::vector<TFE_Py_Tape*>*>;
static std::unordered_map<std::thread::id, bool>* tape_is_stopped = nullptr;
bool* ThreadTapeIsStopped() {
if (tape_is_stopped == nullptr) {
tape_is_stopped = new std::unordered_map<std::thread::id, bool>;
}
auto it = tape_stack->find(std::this_thread::get_id());
if (it != tape_stack->end()) {
return it->second;
auto it = tape_is_stopped->find(std::this_thread::get_id());
if (it != tape_is_stopped->end()) {
return &(it->second);
}
return tape_stack
->emplace(std::this_thread::get_id(), new std::vector<TFE_Py_Tape*>)
.first->second;
return &(tape_is_stopped->emplace(std::this_thread::get_id(), false)
.first->second);
}
#endif

void TFE_Py_TapeStackPushNew(PyObject* persistent) {
void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }

void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }

PyObject* TFE_Py_TapeSetNew(PyObject* persistent) {
TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return;
if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
tape->tape = new GradientTape(persistent == Py_True);
GetTapeStack()->push_back(tape);
}

void TFE_Py_TapeStackPush(PyObject* tape) {
Py_INCREF(tape);
GetTapeStack()->push_back(reinterpret_cast<TFE_Py_Tape*>(tape));
GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape));
return reinterpret_cast<PyObject*>(tape);
}

PyObject* TFE_Py_TapeStackIsEmpty() {
if (GetTapeStack()->empty()) {
PyObject* TFE_Py_TapeSetIsEmpty() {
if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
}

PyObject* TFE_Py_TapeStackPop() {
auto* stack = GetTapeStack();
if (stack->empty()) {
PyErr_SetString(PyExc_RuntimeError, "tape stack is empty.");
return nullptr;
}
TFE_Py_Tape* top = stack->back();
stack->pop_back();
return reinterpret_cast<PyObject*>(top);
void TFE_Py_TapeSetRemove(PyObject* tape) {
auto* stack = GetTapeSet();
stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
// We kept a reference to the tape in the set to ensure it wouldn't get
// deleted under us; cleaning it up here.
Py_DECREF(tape);
}

static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
@ -620,12 +625,15 @@ static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
return tensor_ids;
}

PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors) {
PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
if (tensors == Py_None) {
Py_RETURN_FALSE;
}
auto* stack = GetTapeStack();
if (stack->empty()) {
if (*ThreadTapeIsStopped()) {
Py_RETURN_FALSE;
}
auto* tape_set = GetTapeSet();
if (tape_set->empty()) {
Py_RETURN_FALSE;
}
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
@ -642,7 +650,7 @@ PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors) {
tensor_ids.push_back(FastTensorId(item));
}
Py_DECREF(seq);
for (TFE_Py_Tape* tape : *stack) {
for (TFE_Py_Tape* tape : *tape_set) {
if (tape->tape->ShouldRecord(tensor_ids)) {
Py_RETURN_TRUE;
}
@ -650,12 +658,12 @@ PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors) {
Py_RETURN_FALSE;
}

void TFE_Py_TapeStackWatch(PyObject* tensor) {
void TFE_Py_TapeSetWatch(PyObject* tensor) {
tensorflow::int64 tensor_id = FastTensorId(tensor);
if (PyErr_Occurred()) {
return;
}
for (TFE_Py_Tape* tape : *GetTapeStack()) {
for (TFE_Py_Tape* tape : *GetTapeSet()) {
tape->tape->Watch(tensor_id);
}
}
@ -720,8 +728,8 @@ std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
return list;
}

void TFE_Py_TapeStackWatchVariable(PyObject* variable) {
for (TFE_Py_Tape* tape : *GetTapeStack()) {
void TFE_Py_TapeSetWatchVariable(PyObject* variable) {
for (TFE_Py_Tape* tape : *GetTapeSet()) {
tape->tape->WatchVariable(variable);
}
}
@ -736,12 +744,11 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
return result;
}

void TFE_Py_TapeStackRecordOperation(PyObject* op_type,
PyObject* output_tensors,
PyObject* input_tensors,
PyObject* backward_function) {
auto* stack = GetTapeStack();
if (stack->empty()) {
void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
PyObject* input_tensors,
PyObject* backward_function) {
auto* set = GetTapeSet();
if (set->empty()) {
return;
}
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
@ -776,7 +783,7 @@ void TFE_Py_TapeStackRecordOperation(PyObject* op_type,
return;
}

for (TFE_Py_Tape* tape : *stack) {
for (TFE_Py_Tape* tape : *set) {
Py_INCREF(backward_function);
tape->tape->RecordOperation(
op_type_str, output_info, input_ids, backward_function,
@ -784,8 +791,8 @@ void TFE_Py_TapeStackRecordOperation(PyObject* op_type,
}
}

void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id) {
for (TFE_Py_Tape* tape : *GetTapeStack()) {
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
for (TFE_Py_Tape* tape : *GetTapeSet()) {
tape->tape->DeleteTrace(tensor_id);
}
}
@ -35,7 +35,8 @@ class Tape(object):

def push_new_tape(persistent=False):
"""Pushes a new tape onto the tape set."""
pywrap_tensorflow.TFE_Py_TapeStackPushNew(persistent)
tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent)
return Tape(tape)


def watch(tensor):
@ -44,7 +45,7 @@ def watch(tensor):
Args:
tensor: tensor to be watched.
"""
pywrap_tensorflow.TFE_Py_TapeStackWatch(tensor)
pywrap_tensorflow.TFE_Py_TapeSetWatch(tensor)


def watch_variable(variable):
@ -53,42 +54,39 @@ def watch_variable(variable):
Args:
variable: variable to be watched.
"""
pywrap_tensorflow.TFE_Py_TapeStackWatchVariable(variable)
pywrap_tensorflow.TFE_Py_TapeSetWatchVariable(variable)


def pop_tape():
def pop_tape(tape):
"""Removes the given tape from the set of active tapes, if present."""
return Tape(pywrap_tensorflow.TFE_Py_TapeStackPop())
pywrap_tensorflow.TFE_Py_TapeSetRemove(tape._tape)  # pylint: disable=protected-access


@contextlib.contextmanager
def stop_recording():
stack = []
while not pywrap_tensorflow.TFE_Py_TapeStackIsEmpty():
stack.append(pop_tape()._tape)  # pylint: disable=protected-access
try:
pywrap_tensorflow.TFE_Py_TapeSetStopOnThread()
yield
finally:
for tape in reversed(stack):
pywrap_tensorflow.TFE_Py_TapeStackPush(tape)
pywrap_tensorflow.TFE_Py_TapeSetRestartOnThread()


def should_record(tensors):
"""Returns true if any active tape watches any of these tensors."""
return pywrap_tensorflow.TFE_Py_TapeStackShouldRecord(tensors)
return pywrap_tensorflow.TFE_Py_TapeSetShouldRecord(tensors)


def record_operation(op_type, output_tensors, input_tensors, backward_function):
"""Records the operation on all active tapes."""
pywrap_tensorflow.TFE_Py_TapeStackRecordOperation(
pywrap_tensorflow.TFE_Py_TapeSetRecordOperation(
op_type, output_tensors, input_tensors, backward_function)


def delete_trace(tensor_id):
"""Deletes traces for this Tensor from all active tapes."""
pywrap_tensorflow.TFE_Py_TapeStackDeleteTrace(tensor_id)
pywrap_tensorflow.TFE_Py_TapeSetDeleteTrace(tensor_id)


def could_possibly_record():
"""Returns True if any tape is active."""
return not pywrap_tensorflow.TFE_Py_TapeStackIsEmpty()
return not pywrap_tensorflow.TFE_Py_TapeSetIsEmpty()
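With the per-thread flag, stop_recording no longer drains and rebuilds a stack of tapes; it simply toggles recording off for the current thread and back on when the block exits. A usage sketch (the computation inside the block is illustrative only):

    from tensorflow.python.eager import tape

    with tape.stop_recording():
      # Ops executed here are invisible to every active tape, so no
      # gradient information is recorded for them.
      intermediate = some_forward_computation()  # hypothetical helper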
@ -614,7 +614,7 @@ class Estimator(object):
if isinstance(result, (list, tuple)):
if len(result) != 2:
raise ValueError(
'input_fn should return (feautures, labels) as a len 2 tuple.')
'input_fn should return (features, labels) as a len 2 tuple.')
return result[0], result[1], input_hooks
return result, None, input_hooks
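For reference, a minimal input_fn satisfying the contract the corrected error message describes, returning a (features, labels) pair (names and shapes below are illustrative):

    import numpy as np
    import tensorflow as tf

    def input_fn():
      features = {"x": tf.constant(np.arange(4, dtype=np.float32))}
      labels = tf.constant([0, 1, 0, 1])
      return features, labels  # a len-2 tuple, as required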
@ -121,7 +121,10 @@ class _WarmStartSettings(
# where ws could be defined as:

# Warm-start all weights in the model (input layer and hidden weights).
# Either the directory or a specific checkpoint can be provided (in the case
# of the former, the latest checkpoint will be used).
ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp")
ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000")

# Warm-start only the embeddings (input layer).
ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp",
@ -348,7 +351,7 @@ def _warmstart_var_with_vocab(var,
# TODO(vihanjain): Support _WarmstartSettings where class vocabularies need
# remapping too.
init = checkpoint_ops._load_and_remap_matrix_initializer(
ckpt_path=saver.latest_checkpoint(prev_ckpt),
ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
old_tensor_name=prev_tensor_name,
new_row_vocab_size=current_vocab_size,
new_col_vocab_size=v_shape[1],
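Swapping saver.latest_checkpoint for checkpoint_utils._get_checkpoint_filename is what lets ckpt_to_initialize_from name either a checkpoint directory or one specific checkpoint file. A hedged sketch of the resolution rule we understand that helper to apply (simplified; treat this as an assumption, not the helper's exact code):

    from tensorflow.python.platform import gfile
    from tensorflow.python.training import saver as saver_lib

    def resolve_checkpoint(ckpt_dir_or_file):
      # A directory resolves to its latest checkpoint; anything else is
      # treated as an explicit checkpoint prefix and returned unchanged.
      if gfile.IsDirectory(ckpt_dir_or_file):
        return saver_lib.latest_checkpoint(ckpt_dir_or_file)
      return ckpt_dir_or_file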
@ -50,9 +50,7 @@ class WarmStartingUtilTest(test.TestCase):
sess.run(variables.global_variables_initializer())
saver = saver_lib.Saver()
ckpt_prefix = os.path.join(self.get_temp_dir(), "model")
ckpt_state_name = "checkpoint"
saver.save(
sess, ckpt_prefix, global_step=0, latest_filename=ckpt_state_name)
saver.save(sess, ckpt_prefix, global_step=0)

def _create_prev_run_var(self,
var_name,
@ -408,6 +406,44 @@ class WarmStartingUtilTest(test.TestCase):
self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
sess)

def testWarmStart_ExplicitCheckpointFile(self):
# Create vocab for sparse column "sc_vocab".
vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
"vocab")
# Create feature column.
sc_vocab = fc.categorical_column_with_vocabulary_file(
"sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4)

# Save checkpoint from which to warm-start.
_, prev_vocab_val = self._create_prev_run_var(
"linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones())

partitioner = lambda shape, dtype: [1] * len(shape)
# New graph, new session WITHOUT warmstarting.
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
sess.run(variables.global_variables_initializer())
# Without warmstarting, the weights should be initialized using default
# initializer (which is init_ops.zeros_initializer).
self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [np.zeros([4, 1])]},
sess)

# New graph, new session with warmstarting.
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
# Since old vocab is not explicitly set in WarmStartSettings, the old
# vocab is assumed to be same as new vocab.
ws_util._warmstart(ws_util._WarmStartSettings(
# Explicitly provide the file prefix instead of just the dir.
os.path.join(self.get_temp_dir(), "model-0"),
vars_to_warmstart=".*sc_vocab.*"))
sess.run(variables.global_variables_initializer())
# Verify weights were correctly warmstarted.
self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
sess)

def testWarmStart_SparseColumnVocabularyConstrainedVocabSizes(self):
# Create old vocabulary, and use a size smaller than the total number of
# entries.
@ -59,6 +59,7 @@ def _TestDir(test_name):
# pylint: enable=invalid-name


@test_util.with_c_api
class SimpleMetaGraphTest(test.TestCase):

def testNoVariables(self):
@ -103,7 +104,8 @@ class SimpleMetaGraphTest(test.TestCase):
# Re-exports the current graph state for comparison to the original.
new_meta_graph_def, _ = meta_graph.export_scoped_meta_graph(filename +
"_new")
self.assertProtoEquals(meta_graph_def, new_meta_graph_def)
test_util.assert_meta_graph_protos_equal(self, meta_graph_def,
new_meta_graph_def)

# Ensures that we can still get a reference to our graph collections.
new_input_tensor = ops.get_collection("input_tensor")[0]
@ -226,7 +228,7 @@ class SimpleMetaGraphTest(test.TestCase):
double_nested_complex_node_def = None
for function_def in meta_graph_def.graph_def.library.function:
for node_def in function_def.node_def:
if node_def.name == "double_nested_complex":
if node_def.name.startswith("double_nested_complex"):
double_nested_complex_node_def = node_def
break
if double_nested_complex_node_def:
@ -258,6 +260,7 @@ class SimpleMetaGraphTest(test.TestCase):
self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)


@test_util.with_c_api
class ScopedMetaGraphTest(test.TestCase):

def _testScopedExport(self, test_dir, exported_filenames):
@ -435,10 +438,13 @@ class ScopedMetaGraphTest(test.TestCase):
]
orig_meta_graphs = self._testScopedExport(test_dir, filenames)
new_meta_graphs = self._testScopedImport(test_dir, filenames)
# Delete the unbound_inputs to allow directly calling ProtoEqual.
del orig_meta_graphs[0].collection_def["unbound_inputs"]
del new_meta_graphs[0].collection_def["unbound_inputs"]
for a, b in zip(orig_meta_graphs, new_meta_graphs):
# The unbound input strings are slightly different with the C API enabled
# ("images" vs "images:0") due to the original import_graph_def code
# vs. ImportGraphDef in C++.
# TODO(skyewm): update the pbtxts once _USE_C_API is removed.
del a.collection_def["unbound_inputs"]
del b.collection_def["unbound_inputs"]
test_util.assert_meta_graph_protos_equal(self, a, b)

def testScopedImportUnderNameScope(self):
@ -572,7 +578,8 @@ class ScopedMetaGraphTest(test.TestCase):
"exported_queue1.pbtxt")
new_meta_graph = self._testScopedImportWithQueue(
test_dir, "exported_queue1.pbtxt", "exported_new_queue1.pbtxt")
self.assertProtoEquals(orig_meta_graph, new_meta_graph)
test_util.assert_meta_graph_protos_equal(self, orig_meta_graph,
new_meta_graph)

# Verifies that we can export a subgraph in a nested name scope containing a
# "hidden1/hidden2" and import it into "new_hidden1/new_hidden2" in a new
@ -718,6 +725,7 @@ class ScopedMetaGraphTest(test.TestCase):
self.assertEqual("", str(graph2.as_graph_element("matmul").device))


@test_util.with_c_api
class MetaGraphWithVariableScopeTest(test.TestCase):

def testMetricsCollection(self):
@ -775,6 +783,7 @@ class MetaGraphWithVariableScopeTest(test.TestCase):
initializer = variables.local_variables_initializer()


@test_util.with_c_api
class ExportImportAcrossScopesTest(test.TestCase):

def testPartionedVariables(self):
@ -845,7 +854,7 @@ class ExportImportAcrossScopesTest(test.TestCase):
if shared_name_value.s:
node.attr[shared_name_attr].s = b""

self.assertProtoEquals(expected, result)
test_util.assert_meta_graph_protos_equal(self, expected, result)


if __name__ == "__main__":
@ -162,6 +162,16 @@ def assert_meta_graph_protos_equal(tester, a, b):
# proto comparison below.
a.ClearField("collection_def")
b.ClearField("collection_def")

# Check the graph_defs.
assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True)
# Check graph_def versions (ignored by assert_equal_graph_def).
tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions)
# The fields were compared directly above; remove their raw values from
# the proto comparison below.
a.ClearField("graph_def")
b.ClearField("graph_def")

tester.assertProtoEquals(a, b)


@ -178,7 +188,7 @@ def _strip_checkpoint_v2_randomized(graph_def):
if attr_tensor_value and len(attr_tensor_value.string_val) == 1:
attr_tensor_string_value = attr_tensor_value.string_val[0]
if (attr_tensor_string_value and
re.match(_SHARDED_SAVE_OP_PATTERN, attr_tensor_string_value)):
re.match(_SHARDED_SAVE_OP_PATTERN, str(attr_tensor_string_value))):
delete_keys.append(attr_key)
for attr_key in delete_keys:
del node.attr[attr_key]
@ -25,6 +25,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
@ -55,6 +56,7 @@ def _logit(x):
return np.log(x) - np.log1p(-x)


@test_util.with_c_api
class AssertCloseTest(test.TestCase):

def testAssertCloseIntegerDtype(self):
@ -145,6 +147,7 @@ class AssertCloseTest(test.TestCase):
array_ops.identity(w).eval(feed_dict=feed_dict)


@test_util.with_c_api
class GetLogitsAndProbsTest(test.TestCase):

def testImproperArguments(self):
@ -298,6 +301,7 @@ class GetLogitsAndProbsTest(test.TestCase):
logit.eval(feed_dict={l: np.ones([int(2**11+1)])})


@test_util.with_c_api
class EmbedCheckCategoricalEventShapeTest(test.TestCase):

def testTooSmall(self):
@ -335,6 +339,7 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase):
du.embed_check_categorical_event_shape(param)


@test_util.with_c_api
class EmbedCheckIntegerCastingClosedTest(test.TestCase):

def testCorrectlyAssertsNonnegative(self):
@ -370,6 +375,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase):
x_checked.eval(feed_dict={x: np.array([1, -1], dtype=np.int32)})


@test_util.with_c_api
class LogCombinationsTest(test.TestCase):

def testLogCombinationsBinomial(self):
@ -400,6 +406,7 @@ class LogCombinationsTest(test.TestCase):
self.assertEqual([2, 2], log_binom.get_shape())


@test_util.with_c_api
class DynamicShapeTest(test.TestCase):

def testSameDynamicShape(self):
@ -504,6 +511,7 @@ class DynamicShapeTest(test.TestCase):
}))


@test_util.with_c_api
class RotateTransposeTest(test.TestCase):

def _np_rotate_transpose(self, x, shift):
@ -537,6 +545,7 @@ class RotateTransposeTest(test.TestCase):
shift: shift_value}))


@test_util.with_c_api
class PickVectorTest(test.TestCase):

def testCorrectlyPicksVector(self):
@ -557,6 +566,7 @@ class PickVectorTest(test.TestCase):
constant_op.constant(False), x, y))  # No eval.


@test_util.with_c_api
class PreferStaticRankTest(test.TestCase):

def testNonEmptyConstantTensor(self):
@ -596,6 +606,7 @@ class PreferStaticRankTest(test.TestCase):
self.assertAllEqual(0, rank.eval(feed_dict={x: 1}))


@test_util.with_c_api
class PreferStaticShapeTest(test.TestCase):

def testNonEmptyConstantTensor(self):
@ -635,6 +646,7 @@ class PreferStaticShapeTest(test.TestCase):
self.assertAllEqual(np.array([]), shape.eval(feed_dict={x: 1}))


@test_util.with_c_api
class PreferStaticValueTest(test.TestCase):

def testNonEmptyConstantTensor(self):
@ -675,6 +687,7 @@ class PreferStaticValueTest(test.TestCase):
self.assertAllEqual(np.array(1), value.eval(feed_dict={x: 1}))


@test_util.with_c_api
class FillTriangularTest(test.TestCase):

def setUp(self):
@ -769,6 +782,7 @@ class FillTriangularTest(test.TestCase):
self._run_test(self._rng.randn(2, 3, int(7*8/2)), upper=True)


@test_util.with_c_api
class ReduceWeightedLogSumExp(test.TestCase):

def _reduce_weighted_logsumexp(self, logx, w, axis, keep_dims=False):
@ -865,6 +879,7 @@ class ReduceWeightedLogSumExp(test.TestCase):
du.reduce_weighted_logsumexp(x, w, axis=[0, 1]).eval())


@test_util.with_c_api
class GenNewSeedTest(test.TestCase):

def testOnlyNoneReturnsNone(self):
@ -875,6 +890,7 @@ class GenNewSeedTest(test.TestCase):
# TODO(jvdillon): Merge this test back into:
# tensorflow/python/kernel_tests/softplus_op_test.py
# once TF core is accepting new ops.
@test_util.with_c_api
class SoftplusTest(test.TestCase):

def _npSoftplus(self, np_features):
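Every test class above gains the same decorator. Conceptually, with_c_api re-runs a test class with the C API graph-construction path switched on (the document elsewhere references the ops._USE_C_API toggle); the decorator below is a rough sketch of that idea under our assumptions, not test_util's actual implementation:

    from tensorflow.python.framework import ops

    def with_c_api(cls):
      """Sketch: run a TestCase subclass with the C API enabled."""
      old_setup = cls.setUp
      old_teardown = cls.tearDown

      def setUp(self):
        self._prev_use_c_api = ops._USE_C_API
        ops._USE_C_API = True  # enable the C API for this test
        old_setup(self)

      def tearDown(self):
        old_teardown(self)
        ops._USE_C_API = self._prev_use_c_api  # restore the previous setting

      cls.setUp = setUp
      cls.tearDown = tearDown
      return cls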
@ -20,21 +20,25 @@ limitations under the License.
%rename("%s") TFE_ContextListDevices;
%rename("%s") TFE_ContextAddFunction;
%rename("%s") TFE_ContextAddFunctionDef;
%rename("%s") TFE_ContextEnableRunMetadata;
%rename("%s") TFE_ContextDisableRunMetadata;
%rename("%s") TFE_ContextExportRunMetadata;
%rename("%s") TFE_ContextClearCaches;
%rename("%s") TFE_OpNameGetAttrType;
%rename("%s") TFE_Py_InitEagerTensor;
%rename("%s") TFE_Py_RegisterExceptionClass;
%rename("%s") TFE_Py_Execute;
%rename("%s") TFE_Py_UID;
%rename("%s") TFE_Py_TapeStackPushNew;
%rename("%s") TFE_Py_TapeStackPush;
%rename("%s") TFE_Py_TapeStackPop;
%rename("%s") TFE_Py_TapeStackIsEmpty;
%rename("%s") TFE_Py_TapeStackShouldRecord;
%rename("%s") TFE_Py_TapeStackWatch;
%rename("%s") TFE_Py_TapeStackDeleteTrace;
%rename("%s") TFE_Py_TapeStackRecordOperation;
%rename("%s") TFE_Py_TapeStackWatchVariable;
%rename("%s") TFE_Py_TapeSetNew;
%rename("%s") TFE_Py_TapeSetRemove;
%rename("%s") TFE_Py_TapeSetStopOnThread;
%rename("%s") TFE_Py_TapeSetRestartOnThread;
%rename("%s") TFE_Py_TapeSetIsEmpty;
%rename("%s") TFE_Py_TapeSetShouldRecord;
%rename("%s") TFE_Py_TapeSetWatch;
%rename("%s") TFE_Py_TapeSetDeleteTrace;
%rename("%s") TFE_Py_TapeSetRecordOperation;
%rename("%s") TFE_Py_TapeSetWatchVariable;
%rename("%s") TFE_Py_TapeGradient;
%rename("%s") TFE_Py_TapeWatchedVariables;
%rename("%s") TFE_NewContextOptions;
@ -455,11 +455,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
"https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/cb27e1d0da7f30562ea6c1c4f01393afbf112620.tar.gz",
"https://github.com/llvm-mirror/llvm/archive/cb27e1d0da7f30562ea6c1c4f01393afbf112620.tar.gz",
"https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/81a623d6d61cf87847f839e80b047c267020ab0e.tar.gz",
"https://github.com/llvm-mirror/llvm/archive/81a623d6d61cf87847f839e80b047c267020ab0e.tar.gz",
],
sha256 = "d4e4d17040a786bab13bb1b73ec2dc358f0c07214f847076e0ded8de15800782",
strip_prefix = "llvm-cb27e1d0da7f30562ea6c1c4f01393afbf112620",
sha256 = "be0259c0bd5349200df346c92ba7708341e18ef313fcf7398682b5cff2469137",
strip_prefix = "llvm-81a623d6d61cf87847f839e80b047c267020ab0e",
build_file = str(Label("//third_party/llvm:llvm.BUILD")),
)
third_party/gpus/cuda_configure.bzl (vendored)
@ -110,11 +110,7 @@ def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp):
lang = "c++"
else:
lang = "c"
# TODO: We pass -no-canonical-prefixes here to match the compiler flags,
# but in cuda_clang CROSSTOOL file that is a `feature` and we should
# handle the case when it's disabled and no flag is passed
result = repository_ctx.execute([cc, "-no-canonical-prefixes",
"-E", "-x" + lang, "-", "-v"])
result = repository_ctx.execute([cc, "-E", "-x" + lang, "-", "-v"])
index1 = result.stderr.find(_INC_DIR_MARKER_BEGIN)
if index1 == -1:
return []