From e513d1f516046abc4a5831e1347720922118e81b Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Mon, 8 Jan 2018 17:21:43 -0800 Subject: [PATCH 01/23] Materialize BroadcastGradientArgs by default instead of just doing so in aggressive mode. This ensures that we optimize gradient computations in the presence of variable batch sizes. PiperOrigin-RevId: 181242749 --- .../grappler/optimizers/constant_folding.cc | 75 +++++++++++++------ .../optimizers/constant_folding_test.cc | 11 +-- 2 files changed, 54 insertions(+), 32 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 9f24f1c7683..68feedbcbb0 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -433,13 +433,42 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs( id = --min_id; } } + + // Beware: the reduction dimensions computed by the BCast class are valid iff + // we assume that two distinct symbolic dimensions can't be equal and a + // symbolic dimension can't be equal to 1. This is often but not always true, + // so to make this optimization safe we filter out these cases. + const int common_dims = std::min(shape1.size(), shape2.size()); + for (int i = 0; i < common_dims; ++i) { + if (shape1[i] >= 0 && shape2[i] >= 0) { + continue; + } + if (shape1[i] != shape2[i]) { + // We're either dealing with 2 different symbolic dimensions or a symbolic + // and a know dimensions. We can't be sure whether both are equal or not, + // so we can't be sure whether we'll be broadcasting or not. + return Status::OK(); + } + } + // These extra dims could be equal to 1, in which case there is no + // broadcasting. It could also be greater than 1, in which case there would + // be broadcasting. Since we don't know, we'll just punt. + for (int i = common_dims; i < shape1.size(); ++i) { + if (shape1[i] < 0) { + return Status::OK(); + } + } + for (int i = common_dims; i < shape2.size(); ++i) { + if (shape2[i] < 0) { + return Status::OK(); + } + } + BCast bcast(shape1, shape2); if (!bcast.IsValid()) { return Status::OK(); } - // Beware: the reduction dimensions are valid iff we assume that two distinct - // symbolic dimensions can't be equal. This is often but not always true, so - // this optimization isn't safe. + BCast::Vec reduce_dims[2]; reduce_dims[0] = bcast.grad_x_reduce_idx(); reduce_dims[1] = bcast.grad_y_reduce_idx(); @@ -447,26 +476,27 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs( const DataType type = node.attr().at("T").type(); NodeDef* out[2]; for (int j = 0; j < 2; ++j) { - if (!reduce_dims[j].empty()) { - // This is the case when a tensor dimension of 1 is matched against an - // unknown dimension. The unknown dimension could also be equal to 1, in - // which case there would be no reduction. - out[j] = nullptr; - } else { - string const_name = OptimizedNodeName(node, strings::StrCat("-", j)); - out[j] = node_map_->GetNode(const_name); - if (out[j] == nullptr) { - out[j] = graph_->add_node(); - Tensor value(type, TensorShape({0})); - *out[j] = CreateNodeDef(const_name, TensorValue(&value)); - out[j]->set_device(node.device()); - node_map_->AddNode(const_name, out[j]); - string ctrl_dep = - AddControlDependency(node.name(), graph_, node_map_.get()); - *out[j]->add_input() = ctrl_dep; - node_map_->AddOutput(NodeName(ctrl_dep), const_name); + int reduction_indices = reduce_dims[j].size(); + Tensor value(type, TensorShape({reduction_indices})); + for (int i = 0; i < reduction_indices; ++i) { + if (type == DT_INT32) { + value.vec()(i) = reduce_dims[j][i]; + } else { + value.vec()(i) = reduce_dims[j][i]; } } + string const_name = OptimizedNodeName(node, strings::StrCat("-", j)); + out[j] = node_map_->GetNode(const_name); + if (out[j] == nullptr) { + out[j] = graph_->add_node(); + *out[j] = CreateNodeDef(const_name, TensorValue(&value)); + out[j]->set_device(node.device()); + node_map_->AddNode(const_name, out[j]); + string ctrl_dep = + AddControlDependency(node.name(), graph_, node_map_.get()); + *out[j]->add_input() = ctrl_dep; + node_map_->AddOutput(NodeName(ctrl_dep), const_name); + } } const std::set outputs = node_map_->GetOutputs(node.name()); @@ -584,12 +614,11 @@ Status ConstantFolding::MaterializeReductionIndices( Status ConstantFolding::MaterializeConstants( const GraphProperties& properties) { - const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE; const int node_count = graph_->node_size(); for (int i = 0; i < node_count; ++i) { NodeDef& node = *graph_->mutable_node(i); const string& op = node.op(); - if (is_aggressive && op == "BroadcastGradientArgs") { + if (op == "BroadcastGradientArgs") { TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties)); } else if (IsReduction(node)) { TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties)); diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index a3b3e522eb8..c53678f727f 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -1373,21 +1373,14 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) { } else if (node.name() == "p1") { ++found; EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("ConstantFolding/i-0", node.input(0)); + EXPECT_EQ("i", node.input(0)); } else if (node.name() == "p2") { ++found; EXPECT_EQ(1, node.input_size()); EXPECT_EQ("i:1", node.input(0)); - } else if (node.name() == "ConstantFolding/i-0") { - ++found; - EXPECT_EQ("Const", node.op()); - EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("^i", node.input(0)); - EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape()) - .num_elements()); } } - EXPECT_EQ(7, found); + EXPECT_EQ(6, found); } TEST_F(ConstantFoldingTest, MaterializeReductionIndices) { From 91332633c2703727f3e776efbb4eba567cef6de1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 8 Jan 2018 17:23:57 -0800 Subject: [PATCH 02/23] [TF:XLA] Pass CompileOptions into XlaCompilationCache::Compile. PiperOrigin-RevId: 181243048 --- .../compiler/jit/kernels/xla_launch_op.cc | 4 ++- .../compiler/jit/xla_compilation_cache.cc | 9 ++++--- .../compiler/jit/xla_compilation_cache.h | 3 ++- .../compiler/tf2xla/kernels/while_op.cc | 10 ++++--- tensorflow/compiler/tf2xla/xla_compiler.cc | 27 +++++++++++++++---- 5 files changed, 38 insertions(+), 15 deletions(-) diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 4f3f17df9c6..4842877d9af 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -257,8 +257,10 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; + OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_, - variables, ctx, &kernel, &executable)); + variables, ctx, &kernel, &executable, + /*compile_options=*/nullptr)); VLOG(1) << "Executing XLA Computation..."; diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 3717c2cc242..bfff52c55a7 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -238,7 +238,8 @@ Status XlaCompilationCache::Compile( int num_constant_args, const std::vector& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable) { + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options) { VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { @@ -297,9 +298,9 @@ Status XlaCompilationCache::Compile( XlaCompiler compiler(options); entry->compiled = true; - entry->compilation_status = - compiler.CompileFunction(XlaCompiler::CompileOptions(), function, args, - &entry->compilation_result); + entry->compilation_status = compiler.CompileFunction( + compile_options ? *compile_options : XlaCompiler::CompileOptions(), + function, args, &entry->compilation_result); } *compilation_result = &entry->compilation_result; if (entry->compilation_status.ok() && executable) { diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index c3a8f68a157..0858020716f 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -66,7 +66,8 @@ class XlaCompilationCache : public ResourceBase { const std::vector& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable); + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options); xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index ee466520dd1..5aea25dc7df 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -201,10 +201,12 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_, arguments, &cond)); - xla::Shape body_input_shape = - xla::ShapeUtil::MakeTupleShape(body.xla_input_shapes); - xla::Shape cond_input_shape = - xla::ShapeUtil::MakeTupleShape(cond.xla_input_shapes); + OP_REQUIRES(ctx, body.xla_input_shapes.size() == 1, + errors::FailedPrecondition("Expected one input shape")); + xla::Shape body_input_shape = body.xla_input_shapes[0]; + OP_REQUIRES(ctx, cond.xla_input_shapes.size() == 1, + errors::FailedPrecondition("Expected one input shape")); + xla::Shape cond_input_shape = cond.xla_input_shapes[0]; VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape) << " -> " << xla::ShapeUtil::HumanString(body.xla_output_shape); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index c55719be552..310cf20ec16 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -316,15 +316,22 @@ Status BuildArguments(const Graph& graph, return Status::OK(); } - input_shapes->resize(parameters.size()); + std::vector arg_shapes; + arg_shapes.reserve(parameters.size()); input_mapping->resize(parameters.size()); for (std::vector::size_type i = 0; i < parameters.size(); ++i) { const XlaCompiler::Argument& arg = args[parameters[i]]; // Computes the shapes of non-constant arguments. - (*input_shapes)[i] = arg.shape; + arg_shapes.push_back(arg.shape); (*input_mapping)[i] = parameters[i]; } + if (use_tuple_arg) { + input_shapes->push_back(xla::ShapeUtil::MakeTupleShape(arg_shapes)); + } else { + *input_shapes = arg_shapes; + } + // Use the _Arg nodes in the graph to resolve core assignments. for (const Node* n : graph.nodes()) { if (StringPiece(n->type_string()) != "_Arg") continue; @@ -348,9 +355,19 @@ Status BuildArguments(const Graph& graph, // Build parameter handles for non-constant arguments. std::vector arg_handles(parameters.size()); if (use_tuple_arg) { - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(*input_shapes); + xla::OpSharding tuple_sharding; + tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); + for (int64 parameter : parameters) { + const int core = (*arg_cores)[parameter]; + const int root_device = 0; + *tuple_sharding.add_tuple_shardings() = + core == -1 ? xla::sharding_builder::AssignDevice(root_device) + : xla::sharding_builder::AssignDevice(core); + } + xla::ScopedShardingAssignment assign_tuple_sharding(builder, + tuple_sharding); xla::ComputationDataHandle tuple = - builder->Parameter(0, tuple_shape, "arg_tuple"); + builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); for (std::vector::size_type i = 0; i < parameters.size(); ++i) { const int core = (*arg_cores)[parameters[i]]; xla::ScopedShardingAssignment assign_sharding( @@ -374,7 +391,7 @@ Status BuildArguments(const Graph& graph, for (std::vector::size_type i = 0; i < parameters.size(); ++i) { const XlaCompiler::Argument& arg = args[parameters[i]]; VLOG(2) << " XLA arg " << i - << " shape: " << xla::ShapeUtil::HumanString((*input_shapes)[i]) + << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i]) << " name: " << arg.name << " TF arg " << parameters[i]; XlaExpression& arg_expression = (*arg_expressions)[parameters[i]]; switch (arg.kind) { From 6fda97ab7540fb003d46f7c9810d6aab6dbc6c19 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 8 Jan 2018 17:36:56 -0800 Subject: [PATCH 03/23] [XLA] Initial sparse layout support Adds SparseIndexArray and support methods to Literal. SparseIndexArray manages the array of sparse indices and is exposed by sparse Literals. Also adds HloSupportChecker classes for CPU and GPU. This will be run as the first HloPass during compilation, and verifies that the graph is supported by the backend. Currently only verifies shapes, and that the layout is not sparse since no backend supports sparse layouts yet. PiperOrigin-RevId: 181244401 --- tensorflow/compiler/xla/BUILD | 23 +++ tensorflow/compiler/xla/index_util.cc | 29 +++ tensorflow/compiler/xla/index_util.h | 12 ++ tensorflow/compiler/xla/layout_util.cc | 34 +++- tensorflow/compiler/xla/layout_util.h | 17 +- tensorflow/compiler/xla/layout_util_test.cc | 71 +++++++ tensorflow/compiler/xla/literal_util.cc | 36 +++- tensorflow/compiler/xla/literal_util.h | 119 +++++++++++- tensorflow/compiler/xla/literal_util_test.cc | 28 +++ tensorflow/compiler/xla/service/cpu/BUILD | 27 +++ .../compiler/xla/service/cpu/cpu_compiler.cc | 2 + .../service/cpu/cpu_hlo_support_checker.cc | 48 +++++ .../xla/service/cpu/cpu_hlo_support_checker.h | 42 +++++ .../cpu/cpu_hlo_support_checker_test.cc | 72 +++++++ tensorflow/compiler/xla/service/gpu/BUILD | 27 +++ .../compiler/xla/service/gpu/gpu_compiler.cc | 2 + .../service/gpu/gpu_hlo_support_checker.cc | 48 +++++ .../xla/service/gpu/gpu_hlo_support_checker.h | 42 +++++ .../gpu/gpu_hlo_support_checker_test.cc | 72 +++++++ tensorflow/compiler/xla/shape_util.cc | 67 +++++-- tensorflow/compiler/xla/shape_util.h | 30 ++- tensorflow/compiler/xla/sparse_index_array.cc | 110 +++++++++++ tensorflow/compiler/xla/sparse_index_array.h | 176 ++++++++++++++++++ .../compiler/xla/sparse_index_array_test.cc | 43 +++++ tensorflow/compiler/xla/xla_data.proto | 11 +- 25 files changed, 1160 insertions(+), 28 deletions(-) create mode 100644 tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc create mode 100644 tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h create mode 100644 tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc create mode 100644 tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc create mode 100644 tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h create mode 100644 tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc create mode 100644 tensorflow/compiler/xla/sparse_index_array.cc create mode 100644 tensorflow/compiler/xla/sparse_index_array.h create mode 100644 tensorflow/compiler/xla/sparse_index_array_test.cc diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 88de17a5ff0..dcbe1fe9e5f 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -302,6 +302,7 @@ cc_library( ":array4d", ":shape_tree", ":shape_util", + ":sparse_index_array", ":status_macros", ":types", ":util", @@ -628,6 +629,28 @@ tf_cc_test( ], ) +cc_library( + name = "sparse_index_array", + srcs = ["sparse_index_array.cc"], + hdrs = ["sparse_index_array.h"], + deps = [ + ":array2d", + ":shape_util", + ":xla_data_proto", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "sparse_index_array_test", + srcs = ["sparse_index_array_test.cc"], + deps = [ + ":sparse_index_array", + ":test", + "//tensorflow/core:test_main", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 2ee23927d86..ffd1fb79e98 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -149,4 +149,33 @@ namespace xla { return stride; } +/* static */ bool IndexUtil::IndexInBounds( + const Shape& shape, tensorflow::gtl::ArraySlice index) { + int64 rank = ShapeUtil::Rank(shape); + if (rank != index.size()) { + return false; + } + for (int64 d = 0; d < rank; ++d) { + if (index[d] >= shape.dimensions(d)) { + return false; + } + } + return true; +} + +/* static */ int IndexUtil::CompareIndices( + tensorflow::gtl::ArraySlice lhs, + tensorflow::gtl::ArraySlice rhs) { + int64 rank = lhs.size(); + CHECK_EQ(rhs.size(), rank); + for (int64 dim = 0; dim < rank; ++dim) { + if (lhs[dim] < rhs[dim]) { + return -1; + } else if (lhs[dim] > rhs[dim]) { + return 1; + } + } + return 0; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h index c9838966a5b..0b9188e8524 100644 --- a/tensorflow/compiler/xla/index_util.h +++ b/tensorflow/compiler/xla/index_util.h @@ -69,6 +69,18 @@ class IndexUtil { // sizeof(dimension(3)) * sizeof(dimension(2)) == 4 * 10 static int64 GetDimensionStride(const Shape& shape, int64 dimension); + // Returns true iff the given multi-index is contained in the bounds for the + // shape. + static bool IndexInBounds(const Shape& shape, + tensorflow::gtl::ArraySlice index); + + // Compares the given indices in lexicographic order. lhs[0] and rhs[0] are + // compared first, and lhs[rank-1] and rhs[rank-1] last. If lhs is larger, + // then -1 is returned. If rhs is larger, then 1 is returned. Otherwise, 0 is + // returned. + static int CompareIndices(tensorflow::gtl::ArraySlice lhs, + tensorflow::gtl::ArraySlice rhs); + private: TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil); }; diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 6435226fbe6..ddf091e19ff 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -64,6 +64,13 @@ void SetDefaultLayoutToContainer( return layout; } +/* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) { + Layout layout; + layout.set_format(SPARSE); + layout.set_max_sparse_elements(max_sparse_elements); + return layout; +} + namespace { // Internal helper that creates a default layout for an array of the given rank. @@ -234,7 +241,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { LayoutUtil::ClearLayout(program_shape->mutable_result()); } -/* static */ bool LayoutUtil::IsDense(const Shape& shape) { +/* static */ bool LayoutUtil::IsDenseArray(const Shape& shape) { return ShapeUtil::IsArray(shape) && shape.has_layout() && IsDense(shape.layout()); } @@ -260,7 +267,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { shape.layout().padded_dimensions_size() == 0) { return false; } - CHECK(IsDense(shape)); + CHECK(IsDenseArray(shape)); CHECK_EQ(shape.dimensions_size(), shape.layout().padded_dimensions_size()); for (int64 i = 0; i < shape.dimensions_size(); ++i) { if (shape.layout().padded_dimensions(i) > shape.dimensions(i)) { @@ -272,21 +279,35 @@ Layout CreateDefaultLayoutForRank(int64 rank) { /* static */ tensorflow::gtl::ArraySlice LayoutUtil::PaddedDimensions( const Shape& shape) { - CHECK(IsDense(shape)); + CHECK(IsDenseArray(shape)); return AsInt64Slice(shape.layout().padded_dimensions()); } /* static */ int64 LayoutUtil::PaddedDimension(const Shape& shape, int64 index) { - CHECK(IsDense(shape)); + CHECK(IsDenseArray(shape)); return shape.layout().padded_dimensions(index); } /* static */ PaddingValue LayoutUtil::GetPaddingValue(const Shape& shape) { - CHECK(IsDense(shape)); + CHECK(IsDenseArray(shape)); return shape.layout().padding_value(); } +/* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) { + return ShapeUtil::IsArray(shape) && shape.has_layout() && + IsSparse(shape.layout()); +} + +/* static */ bool LayoutUtil::IsSparse(const Layout& layout) { + return layout.format() == SPARSE; +} + +/* static */ int64 LayoutUtil::MaxSparseElements(const Layout& layout) { + CHECK(IsSparse(layout)); + return layout.max_sparse_elements(); +} + /* static */ bool LayoutUtil::HasLayout(const Shape& shape) { if (ShapeUtil::IsTuple(shape)) { // Tuple shape: all subshapes must have a layout. @@ -313,7 +334,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { /* static */ tensorflow::gtl::ArraySlice LayoutUtil::MinorToMajor( const Shape& shape) { - CHECK(IsDense(shape)); + CHECK(IsDenseArray(shape)); return AsInt64Slice(shape.layout().minor_to_major()); } @@ -419,6 +440,7 @@ tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, /* static */ bool LayoutUtil::AreDimensionsConsecutive( const Layout& layout, tensorflow::gtl::ArraySlice dims) { + CHECK(IsDense(layout)); std::vector positions_in_layout; for (int64 dim : dims) { positions_in_layout.push_back( diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index f73cc957649..7c1ba4b022e 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -36,6 +36,10 @@ class LayoutUtil { // convenience function for protobuf construction.) static Layout MakeLayout(tensorflow::gtl::ArraySlice minor_to_major); + // Creates a sparse layout with the given maximum number of elements. (This is + // a convenience function for protobuf construction.) + static Layout MakeSparseLayout(int64 max_sparse_elements); + // Returns default layout for the given shape. static Layout GetDefaultLayoutForShape(const Shape& shape); @@ -72,7 +76,7 @@ class LayoutUtil { static void ClearLayout(ProgramShape* program_shape); // Returns whether the given Shape is an array and has a dense format layout. - static bool IsDense(const Shape& shape); + static bool IsDenseArray(const Shape& shape); // Returns whether the given Layout has a dense format. static bool IsDense(const Layout& layout); @@ -107,6 +111,17 @@ class LayoutUtil { // an array and has a dense layout. static PaddingValue GetPaddingValue(const Shape& shape); + // Returns whether the given Shape is an array (i.e. not a tuple) and has a + // sparse format layout. + static bool IsSparseArray(const Shape& shape); + + // Returns whether the given Layout has a sparse format. + static bool IsSparse(const Layout& layout); + + // Returns the maximum number of elements that can be stored in a sparse + // layout. + static int64 MaxSparseElements(const Layout& layout); + // Returns whether the given shape has a layout. For tuple shapes, true is // returned only if all elements have layouts. static bool HasLayout(const Shape& shape); diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 331bb9afa94..daf4dc10ac7 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -30,6 +30,14 @@ class LayoutUtilTest : public ::testing::Test { *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); return shape; } + + Shape MakeShapeWithSparseLayout(PrimitiveType element_type, + tensorflow::gtl::ArraySlice dimensions, + int64 max_sparse_elements) { + Shape shape = ShapeUtil::MakeShape(element_type, dimensions); + *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); + return shape; + } }; TEST_F(LayoutUtilTest, TupleLayoutComparison) { @@ -81,6 +89,29 @@ TEST_F(LayoutUtilTest, CopyLayoutArray) { EXPECT_FALSE(dst.has_layout()); } +TEST_F(LayoutUtilTest, CopyLayoutSparse) { + Shape src = MakeShapeWithSparseLayout(F32, {2, 3}, 2); + Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0}); + + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + + // Should work if destination has no layout. + dst.clear_layout(); + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + + // If source is cleared, then destination should be cleared. + src.clear_layout(); + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_TRUE(dst.has_layout()); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_FALSE(dst.has_layout()); +} + TEST_F(LayoutUtilTest, CopyLayoutTuple) { Shape src = ShapeUtil::MakeTupleShape( {MakeShapeWithLayout(F32, {2, 3}, {0, 1}), @@ -100,6 +131,25 @@ TEST_F(LayoutUtilTest, CopyLayoutTuple) { EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); } +TEST_F(LayoutUtilTest, CopyLayoutTupleSparse) { + Shape src = ShapeUtil::MakeTupleShape( + {MakeShapeWithSparseLayout(F32, {2, 3}, 4), + MakeShapeWithSparseLayout(F32, {42, 123}, 4), + ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithSparseLayout(F32, {1, 2, 3}, 6)})}); + Shape dst = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), + ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})}); + + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) { Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1}); Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0}); @@ -107,6 +157,13 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) { EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); } +TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleSameRank) { + Shape src = MakeShapeWithSparseLayout(F32, {123, 42, 7}, 6); + Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0}); + ASSERT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) { Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1}); Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0}); @@ -116,6 +173,15 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) { ::testing::ContainsRegex("cannot copy layout from shape")); } +TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleDifferentRank) { + Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1}); + Shape dst = MakeShapeWithSparseLayout(F32, {2, 3}, 4); + auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::ContainsRegex("cannot copy layout from shape")); +} + TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) { Shape src = ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 3}, {0, 1}), @@ -221,5 +287,10 @@ TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) { ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25})))); } +TEST_F(LayoutUtilTest, SparseLayoutMaxElements) { + EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)), + 101); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index cc1735e6f2c..dff5c1381ab 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -94,9 +94,15 @@ Literal::Literal(const Shape& shape, bool allocate_arrays) Piece& piece = pair.second; piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - if (ShapeUtil::IsArray(piece.subshape())) { + const Shape& subshape = piece.subshape(); + if (ShapeUtil::IsArray(subshape)) { if (allocate_arrays) { piece.set_buffer(new char[piece.size_bytes()]); + if (LayoutUtil::IsSparseArray(subshape)) { + piece.set_sparse_indices(new SparseIndexArray( + LayoutUtil::MaxSparseElements(subshape.layout()), + ShapeUtil::Rank(subshape))); + } } else { piece.set_buffer(nullptr); } @@ -112,6 +118,7 @@ void Literal::DeallocateBuffers() { Piece& piece = pair.second; if (piece.buffer() != nullptr) { delete[] piece.buffer(); + delete piece.sparse_indices(); } } } @@ -164,6 +171,15 @@ std::unique_ptr Literal::CreateFromShape(const Shape& shape) { return literal; } +const SparseIndexArray* Literal::sparse_indices( + const ShapeIndex& shape_index) const { + return piece(shape_index).sparse_indices(); +} + +SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { + return piece(shape_index).sparse_indices(); +} + /* static */ std::unique_ptr Literal::CreateFromDimensions( PrimitiveType primitive_type, tensorflow::gtl::ArraySlice dimensions) { @@ -247,9 +263,12 @@ std::vector Literal::DecomposeTuple() { } Piece& src_piece = piece(src_index); - // Move the respective buffer over to the element Literal. + // Move the respective buffer and sparse indices over to the element + // Literal. dest_piece.set_buffer(src_piece.buffer()); src_piece.set_buffer(nullptr); + dest_piece.set_sparse_indices(src_piece.sparse_indices()); + src_piece.set_sparse_indices(nullptr); } } // Set this literal to be nil-shaped. @@ -406,6 +425,8 @@ Status Literal::MoveFrom(Literal&& src_literal, Piece& dest_piece = piece(dest_index); delete[] dest_piece.buffer(); dest_piece.set_buffer(src_piece.buffer()); + delete dest_piece.sparse_indices(); + dest_piece.set_sparse_indices(src_piece.sparse_indices()); } src_literal.shape_ = ShapeUtil::MakeNil(); @@ -764,7 +785,7 @@ std::unique_ptr Literal::Transpose( // dimension has within the transposed array, a layout is affine if // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major // vector of the affine layout. - CHECK(LayoutUtil::IsDense(permuted_shape)); + CHECK(LayoutUtil::IsDenseArray(permuted_shape)); Layout* layout = permuted_shape.mutable_layout(); layout->clear_minor_to_major(); for (auto index : LayoutUtil::MinorToMajor(shape())) { @@ -1573,6 +1594,12 @@ LiteralProto Literal::ToProto() const { } piece.WriteToProto(proto_piece); } + + if (LayoutUtil::IsSparseArray(shape())) { + CopyToRepeatedField(proto.mutable_sparse_indices(), + sparse_indices()->data()); + } + return proto; } @@ -1653,6 +1680,7 @@ LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) { } const Piece& src_piece = literal.piece(src_index); piece.set_buffer(src_piece.buffer()); + piece.set_sparse_indices(src_piece.sparse_indices()); piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); } } diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index dc29c6359c6..50e25bbdd0d 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/sparse_index_array.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -103,6 +104,12 @@ class Literal { tensorflow::gtl::MutableArraySlice data( const ShapeIndex& shape_index = {}); + // Returns a pointer to the sparse index array. Returns nullptr if the literal + // is not a sparse array. + const SparseIndexArray* sparse_indices( + const ShapeIndex& shape_index = {}) const; + SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + // Returns a pointer to (or size of) the underlying buffer holding the array // at the given shape index. CHECKs if the subshape of the literal at the // given ShapeIndex is not array. @@ -160,6 +167,56 @@ class Literal { // array. string GetR1U8AsString() const; + // Creates a literal with a sparse layout and the given indices and values. + // The shape is initialized from the given dimensions. The minor dimension of + // the indices array must equal the rank of the shape (i.e. size of the + // dimensions array). The major dimension of the indices array must equal the + // number of elements in the values array. The maximum number of elements in + // the array is taken from the max_indices() value of the index array. + // + // XLA assumes that sparse literals are in sorted order for all operations. If + // the `sort` argument is true, then the indices and values will be sorted + // while copying them into the literal. If you have ensured that the indices + // and values are already sorted, then you may set the `sort` argument to + // false to skip the sorting step. + // + // For example: + // + // CreateSparse( + // {12, 12, 12}, + // SparseIndexArray(10, 3, + // Array2D{ + // {0, 1, 2}, + // {3, 4, 5}, + // {6, 7, 8}, + // {9, 10, 11}, + // }), + // {1.0, 2.0 3.0, 4.0}) + // + // This creates an array with shape F64[12,12,12]sparse{10}, that has the + // following non-zero values: + // + // [0, 1, 2]: 1.0 + // [3, 4, 5]: 2.0 + // [6, 7, 8]: 3.0 + // [9, 10, 11]: 4.0 + // + template + static std::unique_ptr CreateSparse( + tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, bool sort = true); + + // Populates a literal with a sparse layout with the given indices and values. + // Each index in the indices array is CHECKed against the dimensions in the + // literal's shape. If sort is true, then the indices and values will be + // sorted. If sort is false, then the indices and values are assumed to + // already be in sorted order. See CreateSparse for an example of how data + // are populated. + template + void PopulateSparse(SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, + bool sort = true); + // Creates a new Literal object with the shape specified as parameter. // The content of the literal values is the default value of the primitive // type of literal itself (0 for numeric types, and false for predicates). @@ -358,7 +415,7 @@ class Literal { const ShapeIndex& shape_index, NativeT value); // Overloads of Get and Set for array literals. CHECKs if the literal is not - // array-shaped. + // array-shaped and dense. template NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; template @@ -408,6 +465,8 @@ class Literal { // This function is useful if you want a polymorphic representation // of the tensor's elements (turning it to a string for something // like representation in a protobuf). + // + // This literal must have a dense layout. void EachCellAsString( const std::function indices, const string& value)>& per_cell) const; @@ -447,6 +506,8 @@ class Literal { // // generator must be a callable of the type // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. + // + // This literal must have a dense layout. template Status Populate(const FnType& generator); @@ -485,10 +546,12 @@ class Literal { // admonishments about floating-point equality checks apply. We expect you to // use this to check for complex values that can be expressed precisely as // float pairs e.g. (-0.5, 1.0). + // + // This literal must have a dense layout. bool IsAllComplex(complex64 value) const; // Returns whether this literal is zero at the specified index. This literal - // must be an array. + // must be an array with a dense layout. bool IsZero(tensorflow::gtl::ArraySlice indices) const; // Return the count of the elements in the array at the given shape index in @@ -563,6 +626,14 @@ class Literal { char* buffer() const { return buffer_; } void set_buffer(char* buffer) { buffer_ = buffer; } + // The array of multi-indices that provide the locations of non-zero + // elements in a sparse array. Only used if + // LayoutUtil::IsSparseArray(shape()) is true. + SparseIndexArray* sparse_indices() const { return sparse_indices_; } + void set_sparse_indices(SparseIndexArray* sparse_indices) { + sparse_indices_ = sparse_indices; + } + // Gets or sets the subshape of this piece. This reference points to a // subshape within the shape in the containing Literal (Literal::shape_). const Shape& subshape() const { return *subshape_; } @@ -598,6 +669,9 @@ class Literal { // For array-shaped pieces, this is the buffer holding the literal data. char* buffer_ = nullptr; + // For sparse arrays, this is the array of indices. + SparseIndexArray* sparse_indices_ = nullptr; + // The shape of piece. This points into the shape of the containing Literal // (Literal::shape_). const Shape* subshape_ = nullptr; @@ -836,6 +910,21 @@ template return CreateR4FromArray4DWithLayout(tmp, layout); } +template +/* static */ std::unique_ptr Literal::CreateSparse( + tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, bool sort) { + int64 num_elements = values.size(); + int64 rank = dimensions.size(); + CHECK_EQ(num_elements, indices.index_count()); + CHECK_EQ(rank, indices.rank()); + auto literal = MakeUnique(ShapeUtil::MakeShapeWithSparseLayout( + primitive_util::NativeToPrimitiveType(), dimensions, + indices.max_indices())); + literal->PopulateSparse(indices, values, sort); + return literal; +} + template /* static */ std::unique_ptr Literal::CreateR4( std::initializer_list& values) { PopulateFromArray(values); } +template +void Literal::PopulateSparse(SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, + bool sort) { + CHECK(LayoutUtil::IsSparseArray(shape())); + int rank = ShapeUtil::Rank(shape()); + CHECK_EQ(indices.rank(), rank); + int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout()); + CHECK_LE(indices.max_indices(), max_elements); + int64 num_elements = values.size(); + CHECK_LE(num_elements, max_elements); + CHECK_EQ(num_elements, indices.index_count()); + auto root_data = root_piece().data(); + root_data.remove_suffix(max_elements - values.size()); + std::copy(values.begin(), values.end(), root_data.begin()); + *this->root_piece().sparse_indices() = std::move(indices); + if (sort) { + auto root_data = this->root_piece().data(); + root_data.remove_suffix(root_data.size() - num_elements); + this->root_piece().sparse_indices()->SortWithValues(root_data); + } + DCHECK(this->root_piece().sparse_indices()->Validate(shape())); +} + template Status Literal::Populate(const FnType& generator) { const Shape& this_shape = shape(); const int64 rank = ShapeUtil::Rank(this_shape); - TF_RET_CHECK(ShapeUtil::IsArray(this_shape)); + TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); TF_RET_CHECK(this_shape.element_type() == primitive_util::NativeToPrimitiveType()); tensorflow::gtl::MutableArraySlice literal_data = data(); diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 4974ead048d..29efb4312f2 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -193,6 +193,34 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { ASSERT_EQ(expected, result); } +TEST_F(LiteralUtilTest, CreateSparse) { + std::vector dimensions = {8, 8, 8}; + Array2D indices = { + {3, 4, 5}, + {1, 2, 3}, + {2, 3, 4}, + {3, 5, 6}, + }; + std::vector values = {7, 8, 9, 10}; + auto literal = Literal::CreateSparse( + dimensions, SparseIndexArray(indices.n1() + 3, indices), values); + + Array2D expected_indices = { + {1, 2, 3}, + {2, 3, 4}, + {3, 4, 5}, + {3, 5, 6}, + }; + std::vector expected_values = {8, 9, 7, 10}; + + EXPECT_EQ(literal->sparse_indices()->data(), + tensorflow::gtl::ArraySlice( + expected_indices.data(), expected_indices.num_elements())); + EXPECT_EQ(tensorflow::gtl::ArraySlice(literal->data().data(), + expected_values.size()), + tensorflow::gtl::ArraySlice(expected_values)); +} + TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { // clang-format off auto literal = Literal::CreateR4Projected({ diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 14c86e2e720..a6b3c0c7c42 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -81,6 +81,7 @@ cc_library( ":conv_canonicalization", ":cpu_copy_insertion", ":cpu_executable", + ":cpu_hlo_support_checker", ":cpu_instruction_fusion", ":cpu_layout_assignment", ":cpu_options", @@ -873,6 +874,32 @@ tf_cc_test( ], ) +cc_library( + name = "cpu_hlo_support_checker", + srcs = ["cpu_hlo_support_checker.cc"], + hdrs = ["cpu_hlo_support_checker.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "cpu_hlo_support_checker_test", + srcs = ["cpu_hlo_support_checker_test.cc"], + deps = [ + ":cpu_hlo_support_checker", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 8e6562c237e..42fd4f100bf 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -50,6 +50,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" #include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" #include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" @@ -258,6 +259,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // Optimization pipeline. HloPassPipeline pipeline("CPU"); pipeline.AddInvariantChecker(ShapeSizeBytesFunction()); + pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( &pipeline, module->config().debug_options(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc new file mode 100644 index 00000000000..7bd4741a04b --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +StatusOr CpuHloSupportChecker::Run(HloModule* module) { + for (auto* computation : module->computations()) { + for (const auto& instruction : computation->instructions()) { + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape())); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + instruction->shape(), + [&instruction](const Shape& subshape, const ShapeIndex&) { + if (LayoutUtil::IsSparseArray(subshape)) { + return xla::Unimplemented( + "CPU backend does not support HLO instruction %s with shape " + "containing a sparse layout: %s", + instruction->ToString().c_str(), + ShapeUtil::HumanStringWithLayout(instruction->shape()) + .c_str()); + } + return Status::OK(); + })); + } + } + return false; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h new file mode 100644 index 00000000000..2271af7b247 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// This pass should run early in the HLO pipeline and checks for HLO constructs +// which are not supported by the CPU backend and cannot be removed via HLO +// transformations (eg, sparse layouts). +class CpuHloSupportChecker : public HloPassInterface { + public: + CpuHloSupportChecker() = default; + ~CpuHloSupportChecker() override = default; + + tensorflow::StringPiece name() const override { + return "cpu_hlo_support_checker"; + } + + // Note: always returns false (no instructions are ever modified by this + // pass). + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc new file mode 100644 index 00000000000..0f463e6de62 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +using ::testing::HasSubstr; + +class CpuHloSupportCheckerTest : public HloTestBase { + protected: + CpuHloSupportChecker& checker() { return checker_; } + + private: + CpuHloSupportChecker checker_; +}; + +TEST_F(CpuHloSupportCheckerTest, Add) { + HloComputation::Builder builder(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param1")); + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param0, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK(checker().Run(module.get()).status()); +} + +TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { + HloComputation::Builder builder(TestName()); + const Shape sparse_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {10}, 2); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, sparse_shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, sparse_shape, "param1")); + builder.AddInstruction(HloInstruction::CreateBinary( + sparse_shape, HloOpcode::kAdd, param0, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + Status status = checker().Run(module.get()).status(); + ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_THAT(status.error_message(), + HasSubstr("CPU backend does not support")); + EXPECT_THAT(status.error_message(), + HasSubstr(ShapeUtil::HumanStringWithLayout(sparse_shape))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index a86d3583a6b..69ccc7179f9 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -437,6 +437,7 @@ cc_library( ":fusion_merger", ":gpu_copy_insertion", ":gpu_executable", + ":gpu_hlo_support_checker", ":gpu_layout_assignment", ":hlo_schedule", ":instruction_fusion", @@ -610,6 +611,32 @@ tf_cc_test( ], ) +cc_library( + name = "gpu_hlo_support_checker", + srcs = ["gpu_hlo_support_checker.cc"], + hdrs = ["gpu_hlo_support_checker.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "gpu_hlo_support_checker_test", + srcs = ["gpu_hlo_support_checker_test.cc"], + deps = [ + ":gpu_hlo_support_checker", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 9f34866ff51..93db9ebbee2 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" @@ -137,6 +138,7 @@ tensorflow::Status OptimizeHloModule( { HloPassPipeline pipeline("optimization"); pipeline.AddInvariantChecker(shape_size_function); + pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( &pipeline, hlo_module->config().debug_options(), ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc new file mode 100644 index 00000000000..4944c41f7d8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +StatusOr GpuHloSupportChecker::Run(HloModule* module) { + for (auto* computation : module->computations()) { + for (const auto& instruction : computation->instructions()) { + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape())); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + instruction->shape(), + [&instruction](const Shape& subshape, const ShapeIndex&) { + if (LayoutUtil::IsSparseArray(subshape)) { + return xla::Unimplemented( + "GPU backend does not support HLO instruction %s with shape " + "containing a sparse layout: %s", + instruction->ToString().c_str(), + ShapeUtil::HumanStringWithLayout(instruction->shape()) + .c_str()); + } + return Status::OK(); + })); + } + } + return false; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h new file mode 100644 index 00000000000..d9550f81b59 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// his pass should run early in the HLO pipeline and checks for HLO constructs +// which are not supported by the GPU backend and cannot be removed via HLO +// transformations (eg, sparse layouts). +class GpuHloSupportChecker : public HloPassInterface { + public: + GpuHloSupportChecker() = default; + ~GpuHloSupportChecker() override = default; + + tensorflow::StringPiece name() const override { + return "gpu_hlo_support_checker"; + } + + // Note: always returns false (no instructions are ever modified by this + // pass). + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc new file mode 100644 index 00000000000..0a4089df4c9 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +using ::testing::HasSubstr; + +class GpuHloSupportCheckerTest : public HloTestBase { + protected: + GpuHloSupportChecker& checker() { return checker_; } + + private: + GpuHloSupportChecker checker_; +}; + +TEST_F(GpuHloSupportCheckerTest, Add) { + HloComputation::Builder builder(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param1")); + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param0, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK(checker().Run(module.get()).status()); +} + +TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { + HloComputation::Builder builder(TestName()); + const Shape sparse_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {10}, 2); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, sparse_shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, sparse_shape, "param1")); + builder.AddInstruction(HloInstruction::CreateBinary( + sparse_shape, HloOpcode::kAdd, param0, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + Status status = checker().Run(module.get()).status(); + ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_THAT(status.error_message(), + HasSubstr("GPU backend does not support")); + EXPECT_THAT(status.error_message(), + HasSubstr(ShapeUtil::HumanStringWithLayout(sparse_shape))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 3d4080e353e..290ea9b496a 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -84,7 +84,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { if (lhs.layout().format() != rhs.layout().format()) { return false; } - if (LayoutUtil::IsDense(lhs)) { + if (LayoutUtil::IsDenseArray(lhs)) { if (!ContainersEqual(LayoutUtil::MinorToMajor(lhs), LayoutUtil::MinorToMajor(rhs))) { VLOG(3) << "CompareShapes: lhs layout != rhs layout"; @@ -202,6 +202,17 @@ StatusOr MakeShapeWithLayoutInternal( return MakeShapeWithLayout(element_type, dimensions, layout); } +/* static */ Shape ShapeUtil::MakeShapeWithSparseLayout( + PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, + int64 max_sparse_elements) { + DCHECK_NE(TUPLE, element_type); + DCHECK_NE(OPAQUE, element_type); + Shape shape = ShapeUtil::MakeShape(element_type, dimensions); + *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); + TF_DCHECK_OK(ShapeUtil::ValidateShape(shape)); + return shape; +} + /* static */ Shape ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( const Shape& shape) { @@ -249,7 +260,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) { - CHECK(LayoutUtil::IsDense(*shape)); + CHECK(LayoutUtil::IsDenseArray(*shape)); shape->mutable_layout()->add_minor_to_major(Rank(*shape)); shape->add_dimensions(bound); TF_DCHECK_OK(ValidateShape(*shape)); @@ -658,23 +669,55 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { TF_DCHECK_OK(ValidateShape(shape)); DCHECK_NE(OPAQUE, shape.element_type()); if (shape.element_type() == TUPLE) { - CHECK_GT(pointer_size, 0); - return pointer_size * shape.tuple_shapes_size(); + return ByteSizeOfTupleIndexTable(shape, pointer_size); } + int64 byte_size = ByteSizeOfElements(shape); + if (LayoutUtil::IsSparseArray(shape)) { + byte_size += ByteSizeOfSparseIndices(shape); + } + return byte_size; +} + +/* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape, + int64 pointer_size) { + TF_DCHECK_OK(ValidateShape(shape)); + DCHECK_EQ(TUPLE, shape.element_type()); + CHECK_GT(pointer_size, 0); + return pointer_size * shape.tuple_shapes_size(); +} + +/* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) { + TF_DCHECK_OK(ValidateShape(shape)); + DCHECK(ShapeUtil::IsArray(shape)); int64 allocated_element_count; - if (shape.layout().padded_dimensions_size() > 0) { - CHECK_EQ(Rank(shape), shape.layout().padded_dimensions_size()); - allocated_element_count = 1; - for (int64 dimension_size : shape.layout().padded_dimensions()) { - allocated_element_count *= dimension_size; - } + + if (LayoutUtil::IsSparseArray(shape)) { + allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout()); } else { - allocated_element_count = ElementsIn(shape); + CHECK(LayoutUtil::IsDenseArray(shape)); + tensorflow::gtl::ArraySlice padded_dimensions = + LayoutUtil::PaddedDimensions(shape); + if (!padded_dimensions.empty()) { + CHECK_EQ(Rank(shape), padded_dimensions.size()); + allocated_element_count = 1; + for (int64 dimension_size : padded_dimensions) { + allocated_element_count *= dimension_size; + } + } else { + allocated_element_count = ElementsIn(shape); + } } return allocated_element_count * ByteSizeOfPrimitiveType(shape.element_type()); } +/* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) { + TF_DCHECK_OK(ValidateShape(shape)); + DCHECK(LayoutUtil::IsSparseArray(shape)); + return LayoutUtil::MaxSparseElements(shape.layout()) * + ShapeUtil::Rank(shape) * sizeof(int64); +} + /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( const Shape& shape) { if (shape.element_type() == TUPLE) { @@ -900,7 +943,7 @@ Status ForEachMutableSubshapeHelper( new_shape.add_dimensions(dim); } if (shape.has_layout()) { - CHECK(LayoutUtil::IsDense(shape)); + CHECK(LayoutUtil::IsDenseArray(shape)); Layout* new_layout = new_shape.mutable_layout(); new_layout->set_format(DENSE); new_layout->clear_minor_to_major(); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 59bdffee5a8..453d4ec0472 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -143,7 +143,10 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index); class ShapeUtil { public: // Returns the number of elements are contained within the provided shape; - // e.g. for rank 0 (scalars) the result is always 1. + // e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes + // may not actually be able to store this number of elements. See + // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of + // elements that can be stored in a sparse shape. // Precondition: !IsTuple(shape) static int64 ElementsIn(const Shape& shape); @@ -164,6 +167,27 @@ class ShapeUtil { // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape) static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type); + // Returns the number of bytes required to store the tuple member pointers for + // a allocation of shape. The `shape` must be a TUPLE shape, and + // `pointer_size` must be larger than zero. + static int64 ByteSizeOfTupleIndexTable(const Shape& shape, + int64 pointer_size); + + // Returns the number of bytes required for the elements in an allocation of + // `shape`, which must be an array shape. The return value does not include + // the bytes needed to store sparse indices. Dense shapes use a separate + // memory location for each element, and so for these shapes, + // `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. For dense shapes, this + // size also includes padding if present in the layout. For sparse shapes, + // `ByteSizeOf(shape) == ByteSizeOfElements(shape) + + // ByteSizeOfSparseindices(shape)`. + static int64 ByteSizeOfElements(const Shape& shape); + + // Returns the number of bytes required for the sparse indices in an + // allocation of shape. The shape must be an array shape. The return value + // does not include the bytes needed to store sparse indices. + static int64 ByteSizeOfSparseIndices(const Shape& shape); + // Returns a human-readable string that represents the given shape, with or // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]". static string HumanString(const Shape& shape); @@ -269,6 +293,10 @@ class ShapeUtil { PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice minor_to_major); + static Shape MakeShapeWithSparseLayout( + PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, + int64 max_sparse_elements); + // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}). static Shape MakeShapeWithDescendingLayout( PrimitiveType element_type, diff --git a/tensorflow/compiler/xla/sparse_index_array.cc b/tensorflow/compiler/xla/sparse_index_array.cc new file mode 100644 index 00000000000..e7738e67903 --- /dev/null +++ b/tensorflow/compiler/xla/sparse_index_array.cc @@ -0,0 +1,110 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/sparse_index_array.h" + +#include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +SparseIndexArray::SparseIndexArray() : rank_(0), max_indices_(0) {} + +SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank, + std::vector indices) + : indices_(std::move(indices)), rank_(rank), max_indices_(max_indices) { + CHECK_GT(rank_, 0); + CHECK_EQ(indices_.size() % rank_, 0) + << "indices_.size(): " << indices_.size() << ", rank_: " << rank_; + CHECK_LT(index_count(), max_indices_); +} + +SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank, + tensorflow::gtl::ArraySlice indices) + : SparseIndexArray(max_indices, rank, + std::vector(indices.begin(), indices.end())) {} + +SparseIndexArray::SparseIndexArray(int64 max_indices, + const Array2D& indices) + : SparseIndexArray(max_indices, indices.n2(), + std::vector(indices.begin(), indices.end())) {} + +int64 SparseIndexArray::index_count() const { + CHECK_GT(rank_, 0); + CHECK_EQ(indices_.size() % rank_, 0); + return indices_.size() / rank_; +} + +tensorflow::gtl::ArraySlice SparseIndexArray::At( + int64 sparse_index_number) const { + CHECK_GT(rank_, 0); + CHECK_GE(sparse_index_number, 0); + CHECK_LE(rank_ * sparse_index_number + rank_, indices_.size()); + return tensorflow::gtl::ArraySlice( + indices_.data() + rank_ * sparse_index_number, rank_); +} + +tensorflow::gtl::MutableArraySlice SparseIndexArray::At( + int64 sparse_index_number) { + CHECK_GT(rank_, 0); + CHECK_GE(sparse_index_number, 0); + CHECK_LE(rank_ * sparse_index_number + rank_, indices_.size()); + return tensorflow::gtl::MutableArraySlice( + indices_.data() + rank_ * sparse_index_number, rank_); +} + +void SparseIndexArray::Append(tensorflow::gtl::ArraySlice index) { + CHECK_GT(rank_, 0); + CHECK_EQ(index.size(), rank_); + indices_.insert(indices_.end(), index.begin(), index.end()); +} + +void SparseIndexArray::Clear() { indices_.clear(); } + +void SparseIndexArray::Resize(int64 num_indices) { + CHECK_GT(rank_, 0); + indices_.resize(rank_ * num_indices); +} + +bool SparseIndexArray::Validate(const Shape& shape) const { + if (rank_ == 0 || rank_ != ShapeUtil::Rank(shape)) { + return false; + } + int64 num_indices = index_count(); + if (num_indices > LayoutUtil::MaxSparseElements(shape.layout())) { + return false; + } + if (num_indices < 2) { + return true; + } + tensorflow::gtl::ArraySlice last = At(0); + if (!IndexUtil::IndexInBounds(shape, last)) { + return false; + } + for (int64 n = 1; n < num_indices; ++n) { + tensorflow::gtl::ArraySlice next = At(n); + if (!IndexUtil::IndexInBounds(shape, next)) { + return false; + } + if (IndexUtil::CompareIndices(last, next) >= 0) { + return false; + } + last = next; + } + return true; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/sparse_index_array.h b/tensorflow/compiler/xla/sparse_index_array.h new file mode 100644 index 00000000000..f67f34760e0 --- /dev/null +++ b/tensorflow/compiler/xla/sparse_index_array.h @@ -0,0 +1,176 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utility class for managing sparse array indices. + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ + +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +// Encapsulates the array of indices for a sparse array. A SparseIndexArray +// contain indices for up to `max_indices` elements of a sparse array. Each +// sparse index is an array of `rank` int64 value that gives the location of a +// value within a sparse array. Note that the dimensions of the array are not +// checked (except for the rank). To avoid confusion, we refer to the position +// of an index within a SparseIndexArray as a sparse index number. +class SparseIndexArray { + public: + SparseIndexArray(); + SparseIndexArray(const SparseIndexArray&) = default; + SparseIndexArray(SparseIndexArray&&) = default; + SparseIndexArray& operator=(const SparseIndexArray&) = default; + SparseIndexArray& operator=(SparseIndexArray&&) = default; + + // Constructs a SparseIndexArray that can hold up to `max_indices` sparse + // indices, with an initial contents obtained from the given array. The rank + // is taken from the minor dimension of the array. The major dimension of the + // array must not exceed `max_indices`. + SparseIndexArray(int64 max_indices, const Array2D& indices); + + // Like above, but the array is flattened. For example, the following are + // equivalent: + // + // SparseIndexArray(10, 3, + // Array2D{ + // {0, 1, 2}, + // {3, 4, 5}, + // {6, 7, 8}, + // {9, 10, 11}, + // }) + // + // SparseIndexArray(10, 3, + // {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}) + // + SparseIndexArray(int64 max_indices, int64 rank, + std::vector indices = {}); + SparseIndexArray(int64 max_indices, int64 rank, + tensorflow::gtl::ArraySlice indices); + + // Returns the number of elements represented by the indices stored in the + // array. + int64 index_count() const; + + // Returns a slice that refers to the given sparse index number. The argument + // must be in the range [0, element_count()). + tensorflow::gtl::ArraySlice At(int64 sparse_index_number) const; + tensorflow::gtl::MutableArraySlice At(int64 sparse_index_number); + + // Adds the given index at the end of the array. The new size of the + // SparseIndexArray must not exceed `max_indices`. + void Append(tensorflow::gtl::ArraySlice index); + + // Removes all indices from the array. + void Clear(); + + // Resizes the array to contain the given number of sparse indices. The new + // size must be smaller than `max_indices`. If the new size is larger than + // the old size, the value of the new indices is not specified. + void Resize(int64 num_indices); + + // Returns true iff all indices are unique and occur in sorted order, and are + // valid for the given shape. + bool Validate(const Shape& shape) const; + + int64 rank() const { return rank_; } + int64 max_indices() const { return max_indices_; } + + // Returns a pointer to the int64 array that holds the sparse indices. + tensorflow::gtl::MutableArraySlice mutable_data() { return &indices_; } + tensorflow::gtl::ArraySlice data() const { return indices_; } + + // Sorts this sparse index array along with the set of corresponding values. + // The indices and values are sorted in the lexicographic order of the + // indices, from smallest to largest. + // + // For example: + // + // std::vector v{10.0, 11.0, 12.0}; + // SparseIndexArray a(10, 3, + // {{3, 4, 5}, + // {1, 2, 3}, + // {2, 3, 4}}); + // a.SortWithValues(&v); + // // Prints "11.0, 12.0, 10.0": + // std::cout << v[0] << ", " << v[1] << ", " << v[2] << std::endl; + // + template + void SortWithValues(tensorflow::gtl::MutableArraySlice values); + + private: + std::vector indices_; + int64 rank_; + int64 max_indices_; +}; + +template +void SparseIndexArray::SortWithValues( + tensorflow::gtl::MutableArraySlice values) { + int64 num_elements = index_count(); + CHECK_EQ(values.size(), num_elements); + std::vector sort_order; + sort_order.reserve(num_elements); + for (int64 i = 0; i < num_elements; ++i) { + sort_order.push_back(i); + } + auto sort_order_less = [this](int64 lhs, int64 rhs) { + return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0; + }; + std::sort(sort_order.begin(), sort_order.end(), sort_order_less); + + // Reorder the array elements according to sort_order. Work through the array + // and follow cycles so we can do the reorder in-place. + tensorflow::gtl::InlinedVector saved_index(rank()); + for (int64 i = 0; i < num_elements; ++i) { + // sort_order[i] == -1 indicates the element has already been copied. + if (sort_order[i] < 0) { + continue; + } else if (i == sort_order[i]) { + // The element is already in sorted order. + sort_order[i] = -1; + continue; + } + + std::copy_n(At(i).begin(), rank(), saved_index.begin()); + NativeT saved_value = values[i]; + int64 j = i; + for (;;) { + if (sort_order[j] == i) { + std::copy_n(saved_index.begin(), rank(), At(j).begin()); + values[j] = saved_value; + sort_order[j] = -1; + break; + } + + std::copy_n(At(sort_order[j]).begin(), rank(), At(j).begin()); + values[j] = values[sort_order[j]]; + + int64 k = sort_order[j]; + sort_order[j] = -1; + j = k; + } + } +} + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ diff --git a/tensorflow/compiler/xla/sparse_index_array_test.cc b/tensorflow/compiler/xla/sparse_index_array_test.cc new file mode 100644 index 00000000000..7377f88958d --- /dev/null +++ b/tensorflow/compiler/xla/sparse_index_array_test.cc @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/sparse_index_array.h" + +#include + +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace { + +TEST(SparseIndexArrayTest, Sort) { + SparseIndexArray a(10, 3); + a.Append({2, 3, 4}); + a.Append({3, 4, 5}); + a.Append({1, 2, 3}); + a.Append({5, 6, 7}); + a.Append({4, 5, 6}); + a.Append({6, 7, 8}); + std::vector values = { + 12.0, 13.0, 11.0, 15.0, 14.0, 16.0, + }; + a.SortWithValues(&values); + ASSERT_EQ(a.data(), std::vector({1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, + 6, 7, 6, 7, 8})); + ASSERT_EQ(values, std::vector({11.0, 12.0, 13.0, 14.0, 15.0, 16.0})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index e34f138b6ed..3aea0217539 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -120,6 +120,9 @@ enum Format { // The default layout, with exactly one storage location per element (ignoring // padding). DENSE = 1; + // A sparsely encoded layout, providing only the index/value pairs of non-zero + // elements. + SPARSE = 2; } // A layout describes how the array is placed in (1D) memory space. This @@ -151,6 +154,11 @@ message Layout { // field must be unset unless the format is DENSE. PaddingValue padding_value = 3; + // The maximum number of elements that can be stored for SPARSE formats. This + // can be used to determine the maximum size in bytes of arrays stored in + // memory. This field must be unset unless the format is SPARSE. + int64 max_sparse_elements = 5; + // Important: if any field is added, be sure to modify ShapeUtil::Equal() // appropriately to account for the new field. } @@ -333,7 +341,8 @@ message LiteralProto { // The F16s and BF16s are encoded in little endian byte order bytes f16s = 11; bytes bf16s = 13; - // Next = 14 + repeated int64 sparse_indices = 14; + // Next = 15 } message WindowDimension { From caa2a1b856d1080cfec26b6ab5756aa49114597e Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Mon, 8 Jan 2018 17:59:58 -0800 Subject: [PATCH 04/23] Fix the threading model of gradient tapes. The set of tapes needs to be global to enable multithreaded programming (when it's natural for tensors to cross threads during reduction operations) but each thread still needs to be able to locally pause recording while it does gradient-related bookkeeping (like custom gradients or initialization). Also removes a mutex from the thread-local structure since it's unnecessary as we're always holding the GIL while calling across the python-c boundary unless we explicitly release it. PiperOrigin-RevId: 181246570 --- tensorflow/python/eager/backprop.py | 18 ++-- tensorflow/python/eager/function.py | 5 +- tensorflow/python/eager/pywrap_tensor.cc | 2 +- tensorflow/python/eager/pywrap_tfe.h | 36 +++---- tensorflow/python/eager/pywrap_tfe_src.cc | 111 ++++++++++++---------- tensorflow/python/eager/tape.py | 26 +++-- tensorflow/python/pywrap_tfe.i | 19 ++-- 7 files changed, 113 insertions(+), 104 deletions(-) diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index a06feb16695..56a49301a22 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -344,7 +344,7 @@ def implicit_val_and_grad(f): def grad_fn(*args): """Computes the gradient of the wrapped function.""" - tape.push_new_tape() + this_tape = tape.push_new_tape() try: end_node = f(*args) if end_node is None: @@ -352,10 +352,10 @@ def implicit_val_and_grad(f): "did you forget to return a value from {}?".format( f.__name__)) finally: - popped_tape = tape.pop_tape() + tape.pop_tape(this_tape) # Sorting variables by id, which is monotonically increasing in construction # order. This ensures unique order across executions. - variables = list(sorted(popped_tape.watched_variables(), + variables = list(sorted(this_tape.watched_variables(), key=lambda v: v.handle._id)) # pylint: disable=protected-access sources = [x.handle for x in variables] @@ -363,7 +363,7 @@ def implicit_val_and_grad(f): raise ValueError("No trainable variables were accessed while the " "function was being computed.") grad = imperative_grad.imperative_grad(_default_vspace, - popped_tape, + this_tape, nest.flatten(end_node), sources) return end_node, list(zip(grad, variables)) @@ -652,7 +652,7 @@ def make_vjp(f, params=None): """Computes the value and gradient of the decorated function.""" parameter_positions = _get_arg_spec(f, params, args) assert not kwds, "The gradient function can't take keyword arguments." - tape.push_new_tape() + this_tape = tape.push_new_tape() try: sources = [] args = [ @@ -673,12 +673,12 @@ def make_vjp(f, params=None): flat_result = [gen_array_ops.identity(x) for x in flat_result] result = nest.pack_sequence_as(result, flat_result) finally: - t = tape.pop_tape() + tape.pop_tape(this_tape) def vjp(dy=None): if dy is not None: dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)] return imperative_grad.imperative_grad( - _default_vspace, t, nest.flatten(result), sources, + _default_vspace, this_tape, nest.flatten(result), sources, output_gradients=dy) return result, vjp @@ -835,11 +835,11 @@ class GradientTape(object): self._persistent = persistent def __enter__(self): - tape.push_new_tape(persistent=self._persistent) + self._tape = tape.push_new_tape(persistent=self._persistent) return self def __exit__(self, typ, value, traceback): - self._tape = tape.pop_tape() + tape.pop_tape(self._tape) def watch(self, tensor): """Ensures that `tensor` is being traced by this tape. diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 969e321dd12..f755434ad78 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -544,11 +544,12 @@ def _defun_internal(name, func, args, kwds): func_inputs = _get_defun_inputs(args) with capture_tensors(captures): - tape.push_new_tape() + this_tape = tape.push_new_tape() try: func_outputs = func(*func_inputs, **kwds) finally: - variables = tape.pop_tape().watched_variables() + tape.pop_tape(this_tape) + variables = this_tape.watched_variables() # Returning a closed-over tensor as an output does not trigger a # call to convert_to_tensor, so we manually capture all such tensors. diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index 91192fea62d..6fa076507d1 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -332,7 +332,7 @@ void EagerTensor_dealloc(EagerTensor* self) { tensorflow::ClearDecrefCache(); auto id = self->id; Py_TYPE(self)->tp_free(self); - TFE_Py_TapeStackDeleteTrace(id); + TFE_Py_TapeSetDeleteTrace(id); } // Getter for `_id`. diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index a33b17ada6f..cecef426032 100644 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -87,22 +87,25 @@ TFE_TensorHandle* EagerTensor_Handle(const PyObject* o); // newly created type, or nullptr on error. PyObject* TFE_Py_InitEagerTensor(PyObject* base_class); -// Pushes a new tape into the thread-local stack. -// `persistent` must be a PyBool_Type, i.e either Py_True or Py_False -void TFE_Py_TapeStackPushNew(PyObject* persistent); +// Creates a new tape and adds it to the active set. `persistent` must be a +// PyBool_Type, i.e either Py_True or Py_False +PyObject* TFE_Py_TapeSetNew(PyObject* persistent); -// Pops the tape from the top of the stack and returns it. -PyObject* TFE_Py_TapeStackPop(); - -// Pushes an existing tape onto the stack. -void TFE_Py_TapeStackPush(PyObject* tape); +// Removes the passed tape from the set of active tapes. +void TFE_Py_TapeSetRemove(PyObject* tape); // Returns true if the tape stack is empty. -PyObject* TFE_Py_TapeStackIsEmpty(); +PyObject* TFE_Py_TapeSetIsEmpty(); -PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors); -void TFE_Py_TapeStackWatch(PyObject* tensor); -void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id); +PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors); +void TFE_Py_TapeSetWatch(PyObject* tensor); +void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id); + +// Stops any gradient recording on the current thread. +void TFE_Py_TapeSetStopOnThread(); + +// Restarts gradient recording on the current thread. +void TFE_Py_TapeSetRestartOnThread(); // Records an operation in the gradient tape stack.type is a string for the // operation type, used in the backprop code. output_tensors should be a list of @@ -111,13 +114,12 @@ void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id); // operation. backward_function should be the function to be called during // backprop to, given the gradients of the output tensors, produce the gradients // of the input tensors. -void TFE_Py_TapeStackRecordOperation(PyObject* op_type, - PyObject* output_tensors, - PyObject* input_tensor_ids, - PyObject* backward_function); +void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, + PyObject* input_tensor_ids, + PyObject* backward_function); // Watches the given variable object on the given tape. -void TFE_Py_TapeStackWatchVariable(PyObject* variable); +void TFE_Py_TapeSetWatchVariable(PyObject* variable); // Computes a gradient based on information recorded on the tape.`tape` must // have been produced by TFE_Py_NewTape. `vspace` must be a diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 3ba81fb3d04..bdaeccf2860 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -538,62 +538,67 @@ static PyTypeObject TFE_Py_Tape_Type = { "TFE_Py_Tape objects", /* tp_doc */ }; +// Note: in the current design no mutex is needed here because of the python +// GIL, which is always held when any TFE_Py_* methods are called. We should +// revisit this if/when decide to not hold the GIL while manipulating the tape +// stack. +static std::unordered_set* tape_set = nullptr; +std::unordered_set* GetTapeSet() { + if (tape_set == nullptr) { + tape_set = new std::unordered_set; + } + return tape_set; +} + // xcode 7 doesn't define thread_local, so for compatibility we implement our // own. TODO(apassos) remove once we can deprecate xcode 7. #ifndef __APPLE__ -std::vector* GetTapeStack() { - thread_local std::vector tape_stack; - return &tape_stack; +bool* ThreadTapeIsStopped() { + thread_local bool thread_tape_is_stopped{false}; + return &thread_tape_is_stopped; } #else -static tensorflow::mutex stack_mu(tensorflow::LINKER_INITIALIZED); -static std::unordered_map*>* - tape_stack GUARDED_BY(stack_mu) = nullptr; -std::vector* GetTapeStack() { - tensorflow::mutex_lock ml(stack_mu); - if (tape_stack == nullptr) { - tape_stack = - new std::unordered_map*>; +static std::unordered_map* tape_is_stopped = nullptr; +bool* ThreadTapeIsStopped() { + if (tape_is_stopped == nullptr) { + tape_is_stopped = new std::unordered_map; } - auto it = tape_stack->find(std::this_thread::get_id()); - if (it != tape_stack->end()) { - return it->second; + auto it = tape_is_stopped->find(std::this_thread::get_id()); + if (it != tape_is_stopped->end()) { + return &(it->second); } - return tape_stack - ->emplace(std::this_thread::get_id(), new std::vector) - .first->second; + return &(tape_is_stopped->emplace(std::this_thread::get_id(), false) + .first->second); } #endif -void TFE_Py_TapeStackPushNew(PyObject* persistent) { +void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; } + +void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; } + +PyObject* TFE_Py_TapeSetNew(PyObject* persistent) { TFE_Py_Tape_Type.tp_new = PyType_GenericNew; - if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return; + if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr; TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type); tape->tape = new GradientTape(persistent == Py_True); - GetTapeStack()->push_back(tape); -} - -void TFE_Py_TapeStackPush(PyObject* tape) { Py_INCREF(tape); - GetTapeStack()->push_back(reinterpret_cast(tape)); + GetTapeSet()->insert(reinterpret_cast(tape)); + return reinterpret_cast(tape); } -PyObject* TFE_Py_TapeStackIsEmpty() { - if (GetTapeStack()->empty()) { +PyObject* TFE_Py_TapeSetIsEmpty() { + if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) { Py_RETURN_TRUE; } Py_RETURN_FALSE; } -PyObject* TFE_Py_TapeStackPop() { - auto* stack = GetTapeStack(); - if (stack->empty()) { - PyErr_SetString(PyExc_RuntimeError, "tape stack is empty."); - return nullptr; - } - TFE_Py_Tape* top = stack->back(); - stack->pop_back(); - return reinterpret_cast(top); +void TFE_Py_TapeSetRemove(PyObject* tape) { + auto* stack = GetTapeSet(); + stack->erase(reinterpret_cast(tape)); + // We kept a reference to the tape in the set to ensure it wouldn't get + // deleted under us; cleaning it up here. + Py_DECREF(tape); } static std::vector MakeIntList(PyObject* list) { @@ -620,12 +625,15 @@ static std::vector MakeIntList(PyObject* list) { return tensor_ids; } -PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors) { +PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) { if (tensors == Py_None) { Py_RETURN_FALSE; } - auto* stack = GetTapeStack(); - if (stack->empty()) { + if (*ThreadTapeIsStopped()) { + Py_RETURN_FALSE; + } + auto* tape_set = GetTapeSet(); + if (tape_set->empty()) { Py_RETURN_FALSE; } PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); @@ -642,7 +650,7 @@ PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors) { tensor_ids.push_back(FastTensorId(item)); } Py_DECREF(seq); - for (TFE_Py_Tape* tape : *stack) { + for (TFE_Py_Tape* tape : *tape_set) { if (tape->tape->ShouldRecord(tensor_ids)) { Py_RETURN_TRUE; } @@ -650,12 +658,12 @@ PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors) { Py_RETURN_FALSE; } -void TFE_Py_TapeStackWatch(PyObject* tensor) { +void TFE_Py_TapeSetWatch(PyObject* tensor) { tensorflow::int64 tensor_id = FastTensorId(tensor); if (PyErr_Occurred()) { return; } - for (TFE_Py_Tape* tape : *GetTapeStack()) { + for (TFE_Py_Tape* tape : *GetTapeSet()) { tape->tape->Watch(tensor_id); } } @@ -720,8 +728,8 @@ std::vector MakeTensorIDList(PyObject* tensors) { return list; } -void TFE_Py_TapeStackWatchVariable(PyObject* variable) { - for (TFE_Py_Tape* tape : *GetTapeStack()) { +void TFE_Py_TapeSetWatchVariable(PyObject* variable) { + for (TFE_Py_Tape* tape : *GetTapeSet()) { tape->tape->WatchVariable(variable); } } @@ -736,12 +744,11 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) { return result; } -void TFE_Py_TapeStackRecordOperation(PyObject* op_type, - PyObject* output_tensors, - PyObject* input_tensors, - PyObject* backward_function) { - auto* stack = GetTapeStack(); - if (stack->empty()) { +void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, + PyObject* input_tensors, + PyObject* backward_function) { + auto* set = GetTapeSet(); + if (set->empty()) { return; } std::vector input_ids = MakeTensorIDList(input_tensors); @@ -776,7 +783,7 @@ void TFE_Py_TapeStackRecordOperation(PyObject* op_type, return; } - for (TFE_Py_Tape* tape : *stack) { + for (TFE_Py_Tape* tape : *set) { Py_INCREF(backward_function); tape->tape->RecordOperation( op_type_str, output_info, input_ids, backward_function, @@ -784,8 +791,8 @@ void TFE_Py_TapeStackRecordOperation(PyObject* op_type, } } -void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id) { - for (TFE_Py_Tape* tape : *GetTapeStack()) { +void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) { + for (TFE_Py_Tape* tape : *GetTapeSet()) { tape->tape->DeleteTrace(tensor_id); } } diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py index 14b5238f740..ad82266beca 100644 --- a/tensorflow/python/eager/tape.py +++ b/tensorflow/python/eager/tape.py @@ -35,7 +35,8 @@ class Tape(object): def push_new_tape(persistent=False): """Pushes a new tape onto the tape stack.""" - pywrap_tensorflow.TFE_Py_TapeStackPushNew(persistent) + tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent) + return Tape(tape) def watch(tensor): @@ -44,7 +45,7 @@ def watch(tensor): Args: tensor: tensor to be watched. """ - pywrap_tensorflow.TFE_Py_TapeStackWatch(tensor) + pywrap_tensorflow.TFE_Py_TapeSetWatch(tensor) def watch_variable(variable): @@ -53,42 +54,39 @@ def watch_variable(variable): Args: variable: variable to be watched. """ - pywrap_tensorflow.TFE_Py_TapeStackWatchVariable(variable) + pywrap_tensorflow.TFE_Py_TapeSetWatchVariable(variable) -def pop_tape(): +def pop_tape(tape): """Pops the top tape in the stack, if any.""" - return Tape(pywrap_tensorflow.TFE_Py_TapeStackPop()) + pywrap_tensorflow.TFE_Py_TapeSetRemove(tape._tape) # pylint: disable=protected-access @contextlib.contextmanager def stop_recording(): - stack = [] - while not pywrap_tensorflow.TFE_Py_TapeStackIsEmpty(): - stack.append(pop_tape()._tape) # pylint: disable=protected-access try: + pywrap_tensorflow.TFE_Py_TapeSetStopOnThread() yield finally: - for tape in reversed(stack): - pywrap_tensorflow.TFE_Py_TapeStackPush(tape) + pywrap_tensorflow.TFE_Py_TapeSetRestartOnThread() def should_record(tensors): """Returns true if any tape in the stack watches any of these tensors.""" - return pywrap_tensorflow.TFE_Py_TapeStackShouldRecord(tensors) + return pywrap_tensorflow.TFE_Py_TapeSetShouldRecord(tensors) def record_operation(op_type, output_tensors, input_tensors, backward_function): """Records the operation on all tapes in the stack.""" - pywrap_tensorflow.TFE_Py_TapeStackRecordOperation( + pywrap_tensorflow.TFE_Py_TapeSetRecordOperation( op_type, output_tensors, input_tensors, backward_function) def delete_trace(tensor_id): """Deletes traces for this Tensor from all tapes in the stack.""" - pywrap_tensorflow.TFE_Py_TapeStackDeleteTrace(tensor_id) + pywrap_tensorflow.TFE_Py_TapeSetDeleteTrace(tensor_id) def could_possibly_record(): """Returns True if any tape is active.""" - return not pywrap_tensorflow.TFE_Py_TapeStackIsEmpty() + return not pywrap_tensorflow.TFE_Py_TapeSetIsEmpty() diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index d97823c17f8..42e4773df3e 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -26,15 +26,16 @@ limitations under the License. %rename("%s") TFE_Py_RegisterExceptionClass; %rename("%s") TFE_Py_Execute; %rename("%s") TFE_Py_UID; -%rename("%s") TFE_Py_TapeStackPushNew; -%rename("%s") TFE_Py_TapeStackPush; -%rename("%s") TFE_Py_TapeStackPop; -%rename("%s") TFE_Py_TapeStackIsEmpty; -%rename("%s") TFE_Py_TapeStackShouldRecord; -%rename("%s") TFE_Py_TapeStackWatch; -%rename("%s") TFE_Py_TapeStackDeleteTrace; -%rename("%s") TFE_Py_TapeStackRecordOperation; -%rename("%s") TFE_Py_TapeStackWatchVariable; +%rename("%s") TFE_Py_TapeSetNew; +%rename("%s") TFE_Py_TapeSetRemove; +%rename("%s") TFE_Py_TapeSetStopOnThread; +%rename("%s") TFE_Py_TapeSetRestartOnThread; +%rename("%s") TFE_Py_TapeSetIsEmpty; +%rename("%s") TFE_Py_TapeSetShouldRecord; +%rename("%s") TFE_Py_TapeSetWatch; +%rename("%s") TFE_Py_TapeSetDeleteTrace; +%rename("%s") TFE_Py_TapeSetRecordOperation; +%rename("%s") TFE_Py_TapeSetWatchVariable; %rename("%s") TFE_Py_TapeGradient; %rename("%s") TFE_Py_TapeWatchedVariables; %rename("%s") TFE_NewContextOptions; From 2cd288baa4a1c18c14b5572ef54fa29bc18dfce1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 8 Jan 2018 18:08:47 -0800 Subject: [PATCH 05/23] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 181247809 --- tensorflow/go/op/wrappers.go | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index e495857afe5..42c4f81b82e 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -5922,6 +5922,7 @@ func RandomDataset(scope *Scope, seed tf.Output, seed2 tf.Output, output_types [ // // This op is hidden from public in Python. It is used by TensorFlow Debugger to // register gradient tensors for gradient debugging. +// This op operates on non-reference-type tensors. func DebugGradientIdentity(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return From 20db88eec824259764b2eafba377f93ea11776b0 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Mon, 8 Jan 2018 18:15:51 -0800 Subject: [PATCH 06/23] Ignore nodes that are going to be swapped when computing max memory usage PiperOrigin-RevId: 181248577 --- .../core/grappler/costs/graph_memory.cc | 22 +++++++- .../core/grappler/costs/graph_memory_test.cc | 56 +++++++++++++++++++ .../trivial_test_graph_input_yielder.cc | 14 +++-- 3 files changed, 85 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/grappler/costs/graph_memory.cc b/tensorflow/core/grappler/costs/graph_memory.cc index 3168758c8bd..3604de392f8 100644 --- a/tensorflow/core/grappler/costs/graph_memory.cc +++ b/tensorflow/core/grappler/costs/graph_memory.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_memory.h" #include #include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/framework/tensor_description.pb.h" @@ -163,6 +164,8 @@ void GraphMemory::InferFromTrace(const StepStats& timeline) { NodeMap node_map(&item_.graph); for (const auto& dev_stats : timeline.dev_stats()) { + const string& device_name = dev_stats.device(); + const bool is_gpu = (device_name.find("GPU:") || device_name.find("gpu:")); std::list& device_tensors = live_tensors_per_device[dev_stats.device()]; for (const auto& node_stats : dev_stats.node_stats()) { @@ -194,7 +197,24 @@ void GraphMemory::InferFromTrace(const StepStats& timeline) { // graph (e.g _Send/_Recv nodes). continue; } - for (const string& input : node->input()) { + std::unordered_set swapped_inputs; + if (is_gpu) { + auto it = node->attr().find("_swap_to_host"); + if (it != node->attr().end()) { + const AttrValue& val = it->second; + for (int port_id : val.list().i()) { + swapped_inputs.insert(port_id); + } + } + } + for (int i = 0; i < node->input_size(); ++i) { + if (swapped_inputs.find(i) != swapped_inputs.end()) { + // The memory of swapped inputs will be released as early as possible: + // therefore ignore this input when determining the deallocation time + // of the tensor. + continue; + } + const string& input = node->input(i); int position; string input_node = ParseNodeName(input, &position); if (position < 0) { diff --git a/tensorflow/core/grappler/costs/graph_memory_test.cc b/tensorflow/core/grappler/costs/graph_memory_test.cc index 6f3522b068b..95170ba49b7 100644 --- a/tensorflow/core/grappler/costs/graph_memory_test.cc +++ b/tensorflow/core/grappler/costs/graph_memory_test.cc @@ -134,6 +134,62 @@ TEST_F(GraphMemoryTest, MultiDevice) { EXPECT_EQ(gpu_expected, gpu_tensors); } +TEST_F(GraphMemoryTest, GpuSwapping) { + TrivialTestGraphInputYielder fake_input(4, 2, 1024 * 1024, false, {"/GPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + item.feed.clear(); + + { + // Estimate the max memory usage for the graph. + GraphMemory memory(item); + Status s = memory.InferStatically(devices_); + TF_CHECK_OK(s); + + const GraphMemory::MemoryUsage& gpu_mem = + memory.GetPeakMemoryUsage("/GPU:0"); + EXPECT_EQ(20971520, gpu_mem.used_memory); + std::set gpu_tensors; + for (const auto& t : gpu_mem.live_tensors) { + gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id)); + } + std::set gpu_expected; + gpu_expected.insert("Square:0"); + gpu_expected.insert("Square_1:0"); + gpu_expected.insert("AddN:0"); + gpu_expected.insert("AddN_1:0"); + gpu_expected.insert("AddN_2:0"); + EXPECT_EQ(gpu_expected, gpu_tensors); + } + + { + // Swap the first input to node AddN_1: its fanin (the square nodes) should + // not appear in the max cut anymore. + for (auto& node : *item.graph.mutable_node()) { + if (node.name() == "AddN_1") { + (*node.mutable_attr())["_swap_to_host"].mutable_list()->add_i(0); + } + } + GraphMemory memory(item); + Status s = memory.InferStatically(devices_); + TF_CHECK_OK(s); + const GraphMemory::MemoryUsage& new_gpu_mem = + memory.GetPeakMemoryUsage("/GPU:0"); + EXPECT_EQ(20971520, new_gpu_mem.used_memory); + std::set new_gpu_tensors; + for (const auto& t : new_gpu_mem.live_tensors) { + new_gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id)); + } + std::set new_gpu_expected; + new_gpu_expected.insert("AddN:0"); + new_gpu_expected.insert("AddN_1:0"); + new_gpu_expected.insert("AddN_2:0"); + new_gpu_expected.insert("AddN_3:0"); + new_gpu_expected.insert("AddN_4:0"); + EXPECT_EQ(new_gpu_expected, new_gpu_tensors); + } +} + TEST_F(GraphMemoryTest, CtrlDependencies) { // Build a simple graph with a control dependency. Scope s = Scope::NewRootScope(); diff --git a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc index 6d25556770d..ec54bd5c759 100644 --- a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc +++ b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc @@ -31,8 +31,6 @@ namespace { GraphDef CreateGraphDef(int num_stages, int width, int tensor_size, bool use_multiple_devices, bool insert_queue, const std::vector& device_names) { - CHECK_GE(device_names.size(), width); - using namespace ::tensorflow::ops; // NOLINT(build/namespaces) tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -49,13 +47,17 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size, std::vector this_stage; for (int j = 0; j < width; j++) { if (last_stage.size() == 1) { - Output unary_op = - Square(s.WithDevice(device_names[use_multiple_devices ? j : 0]), - last_stage[0]); + Output unary_op = Square( + s.WithDevice( + device_names[use_multiple_devices ? j % device_names.size() + : 0]), + last_stage[0]); this_stage.push_back(unary_op); } else { Output combine = - AddN(s.WithDevice(device_names[use_multiple_devices ? j : 0]), + AddN(s.WithDevice( + device_names[use_multiple_devices ? j % device_names.size() + : 0]), last_stage); this_stage.push_back(combine); } From 3cc8e02a29e51e5ff1f02f678712f0f04260730a Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Mon, 8 Jan 2018 18:57:44 -0800 Subject: [PATCH 07/23] Make meta_graph_test.py work with C API enabled. This mainly means changing test_util.assert_meta_graph_protos_equal to call assert_equal_graph_def instead of comparing the GraphDef protos using assertEquals, since the NodeDefs are in a different order with the C API enabled. It also makes assert_equal_graph_def treat inputs 'tensor' and 'tensor:0' as equal (they're both valid). PiperOrigin-RevId: 181252488 --- tensorflow/core/util/equal_graph_def.cc | 5 +++- .../python/framework/meta_graph_test.py | 23 +++++++++++++------ tensorflow/python/framework/test_util.py | 12 +++++++++- 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/util/equal_graph_def.cc b/tensorflow/core/util/equal_graph_def.cc index a3b7db98cc0..f1ec497a677 100644 --- a/tensorflow/core/util/equal_graph_def.cc +++ b/tensorflow/core/util/equal_graph_def.cc @@ -148,7 +148,10 @@ bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff, first_control_input = i; break; } - if (actual.input(i) != expected.input(i)) { + // Special case for inputs: "tensor" is equivalent to "tensor:0" + if (actual.input(i) != expected.input(i) && + actual.input(i) != strings::StrCat(expected.input(i), ":0") && + strings::StrCat(actual.input(i), ":0") != expected.input(i)) { if (diff != nullptr) { *diff = strings::StrCat("Node named '", actual.name(), "' has input ", i, " '", actual.input(i), diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py index 7bb799a8b06..b5ed1352843 100644 --- a/tensorflow/python/framework/meta_graph_test.py +++ b/tensorflow/python/framework/meta_graph_test.py @@ -59,6 +59,7 @@ def _TestDir(test_name): # pylint: enable=invalid-name +@test_util.with_c_api class SimpleMetaGraphTest(test.TestCase): def testNoVariables(self): @@ -103,7 +104,8 @@ class SimpleMetaGraphTest(test.TestCase): # Re-exports the current graph state for comparison to the original. new_meta_graph_def, _ = meta_graph.export_scoped_meta_graph(filename + "_new") - self.assertProtoEquals(meta_graph_def, new_meta_graph_def) + test_util.assert_meta_graph_protos_equal(self, meta_graph_def, + new_meta_graph_def) # Ensures that we can still get a reference to our graph collections. new_input_tensor = ops.get_collection("input_tensor")[0] @@ -226,7 +228,7 @@ class SimpleMetaGraphTest(test.TestCase): double_nested_complex_node_def = None for function_def in meta_graph_def.graph_def.library.function: for node_def in function_def.node_def: - if node_def.name == "double_nested_complex": + if node_def.name.startswith("double_nested_complex"): double_nested_complex_node_def = node_def break if double_nested_complex_node_def: @@ -258,6 +260,7 @@ class SimpleMetaGraphTest(test.TestCase): self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs) +@test_util.with_c_api class ScopedMetaGraphTest(test.TestCase): def _testScopedExport(self, test_dir, exported_filenames): @@ -435,10 +438,13 @@ class ScopedMetaGraphTest(test.TestCase): ] orig_meta_graphs = self._testScopedExport(test_dir, filenames) new_meta_graphs = self._testScopedImport(test_dir, filenames) - # Delete the unbound_inputs to allow directly calling ProtoEqual. - del orig_meta_graphs[0].collection_def["unbound_inputs"] - del new_meta_graphs[0].collection_def["unbound_inputs"] for a, b in zip(orig_meta_graphs, new_meta_graphs): + # The unbound input strings are slightly different with the C API enabled + # ("images" vs "images:0") due to the original import_graph_def code + # vs. ImportGraphDef in C++. + # TODO(skyewm): update the pbtxts once _USE_C_API is removed. + del a.collection_def["unbound_inputs"] + del b.collection_def["unbound_inputs"] test_util.assert_meta_graph_protos_equal(self, a, b) def testScopedImportUnderNameScope(self): @@ -572,7 +578,8 @@ class ScopedMetaGraphTest(test.TestCase): "exported_queue1.pbtxt") new_meta_graph = self._testScopedImportWithQueue( test_dir, "exported_queue1.pbtxt", "exported_new_queue1.pbtxt") - self.assertProtoEquals(orig_meta_graph, new_meta_graph) + test_util.assert_meta_graph_protos_equal(self, orig_meta_graph, + new_meta_graph) # Verifies that we can export a subgraph in a nested name scope containing a # "hidden1/hidden2" and import it into "new_hidden1/new_hidden2" in a new @@ -718,6 +725,7 @@ class ScopedMetaGraphTest(test.TestCase): self.assertEqual("", str(graph2.as_graph_element("matmul").device)) +@test_util.with_c_api class MetaGraphWithVariableScopeTest(test.TestCase): def testMetricsCollection(self): @@ -775,6 +783,7 @@ class MetaGraphWithVariableScopeTest(test.TestCase): initializer = variables.local_variables_initializer() +@test_util.with_c_api class ExportImportAcrossScopesTest(test.TestCase): def testPartionedVariables(self): @@ -845,7 +854,7 @@ class ExportImportAcrossScopesTest(test.TestCase): if shared_name_value.s: node.attr[shared_name_attr].s = b"" - self.assertProtoEquals(expected, result) + test_util.assert_meta_graph_protos_equal(self, expected, result) if __name__ == "__main__": diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 5ac30537499..b2fb63dbbac 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -162,6 +162,16 @@ def assert_meta_graph_protos_equal(tester, a, b): # proto comparison below. a.ClearField("collection_def") b.ClearField("collection_def") + + # Check the graph_defs. + assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True) + # Check graph_def versions (ignored by assert_equal_graph_def). + tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions) + # Compared the fields directly, remove their raw values from the + # proto comparison below. + a.ClearField("graph_def") + b.ClearField("graph_def") + tester.assertProtoEquals(a, b) @@ -178,7 +188,7 @@ def _strip_checkpoint_v2_randomized(graph_def): if attr_tensor_value and len(attr_tensor_value.string_val) == 1: attr_tensor_string_value = attr_tensor_value.string_val[0] if (attr_tensor_string_value and - re.match(_SHARDED_SAVE_OP_PATTERN, attr_tensor_string_value)): + re.match(_SHARDED_SAVE_OP_PATTERN, str(attr_tensor_string_value))): delete_keys.append(attr_key) for attr_key in delete_keys: del node.attr[attr_key] From 33eeb44e70cfd6e5f9f5c8a398338573c6f6148a Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Mon, 8 Jan 2018 19:09:47 -0800 Subject: [PATCH 08/23] Don't treat PlaceholderWithDefaults as constants. Without this, the shape refiner will use the shape of the default, which may be more precise than the specified shape. This would cause util_test to fail with the C API enabled. PiperOrigin-RevId: 181253665 --- .../core/common_runtime/constant_folding.cc | 5 +++ .../common_runtime/constant_folding_test.cc | 34 +++++++++++++++++++ .../core/common_runtime/shape_refiner.cc | 15 ++++++++ .../core/common_runtime/shape_refiner_test.cc | 19 +++++++++++ .../kernel_tests/distributions/util_test.py | 16 +++++++++ 5 files changed, 89 insertions(+) diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 0398c2a60d1..5235e520568 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -226,6 +226,11 @@ bool IsConstantFoldable( if (consider && !consider(n)) { return false; } + // PlaceholderWithDefault shouldn't be constant folded because its output can + // be fed non-constant values. + if (n->type_string() == "PlaceholderWithDefault") { + return false; + } if (n->IsControlFlow() || n->IsSend() || n->IsRecv()) { return false; } diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc index 923a4d92493..31f41e133b9 100644 --- a/tensorflow/core/common_runtime/constant_folding_test.cc +++ b/tensorflow/core/common_runtime/constant_folding_test.cc @@ -338,6 +338,40 @@ TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) { EXPECT_FALSE(was_mutated); } +TEST_F(ConstantFoldingTest, Placeholders) { + Graph g(OpRegistry::Global()); + { + Scope s = Scope::NewRootScope(); + auto placeholder = ops::Placeholder(s, DT_DOUBLE); + auto add = ops::Add(s, placeholder, 2.0); + auto send = + ops::_Send(s.WithOpName("send"), add, "add", "sender", 0, "receiver"); + TF_ASSERT_OK(s.ToGraph(&g)); + } + bool was_mutated; + Status s = ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(), + nullptr, &g, &was_mutated); + EXPECT_FALSE(was_mutated); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE(s.error_message().find( + "You must feed a value for placeholder " + "tensor 'Placeholder' with dtype double") != string::npos); + + Graph g2(OpRegistry::Global()); + { + Scope s = Scope::NewRootScope(); + auto placeholder = ops::PlaceholderWithDefault(s, {1.0}, {1}); + auto add = ops::Add(s, placeholder, 2.0); + auto send = + ops::_Send(s.WithOpName("send"), add, "add", "sender", 0, "receiver"); + TF_ASSERT_OK(s.ToGraph(&g2)); + } + // TODO(skyewm): should this have the same behavior as Placeholder? + TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(), + nullptr, &g2, &was_mutated)); + EXPECT_FALSE(was_mutated); +} + TEST_F(ConstantFoldingTest, ControlDependencies) { Graph g(OpRegistry::Global()); { diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index 3ae52f414fa..45cdab98e06 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -558,6 +558,13 @@ Status ShapeRefiner::ExtractConstantSubgraph( return Status::OK(); } + if (target_node->type_string() == "PlaceholderWithDefault") { + return Status::OK(); + } + + // TODO(skyewm): more of the filtering applied in input nodes below should be + // applied to target_node here + struct NodeAndRecursed { Node* new_node = nullptr; bool recursed = false; @@ -608,6 +615,14 @@ Status ShapeRefiner::ExtractConstantSubgraph( return Status::OK(); } + // Placeholders should never be constant folded because their outputs are + // fed by the user. Note that "Placeholder" nodes have no inputs so are + // handled below. + if (current_node->type_string() == "PlaceholderWithDefault") { + *is_constant_graph = false; + return Status::OK(); + } + // If there is nothing more to recurse down, see if // the generator node is a constant. if (current_node->num_inputs() == 0) { diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc index e4eef1dbe28..adf5a9afff2 100644 --- a/tensorflow/core/common_runtime/shape_refiner_test.cc +++ b/tensorflow/core/common_runtime/shape_refiner_test.cc @@ -724,6 +724,25 @@ TEST_F(ShapeRefinerTest, PropagateRange) { EXPECT_EQ("[1,4,7,10]", ctx->DebugString(ctx->output(0))); } +// Make sure PlaceholderWithDefaults aren't treated as constants. +TEST_F(ShapeRefinerTest, NoPropagatePlaceholderWithDefault) { + Scope root = Scope::NewRootScope(); + auto constant = ops::Const(root, 2); + auto placeholder = + ops::PlaceholderWithDefault(root, constant, PartialTensorShape()); + Node* shape_data; + TF_ASSERT_OK(NodeBuilder("Test", "ShapeData") + .Input(placeholder.node()) + .Finalize(root.graph(), &shape_data)); + + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); + TF_ASSERT_OK(m.AddNode(constant.node())); + TF_ASSERT_OK(m.AddNode(placeholder.node())); + TF_ASSERT_OK(m.AddNode(shape_data)); + shape_inference::InferenceContext* ic = m.GetContext(shape_data); + EXPECT_EQ(ic->DebugString(ic->output(0)), "?"); +} + TEST_F(ShapeRefinerTest, ConstantValueTwoInputsToSameNode) { Scope root = Scope::NewRootScope(); // This node is used as two inputs to 'range'. diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py index 00781d01505..f54f146e0ac 100644 --- a/tensorflow/python/kernel_tests/distributions/util_test.py +++ b/tensorflow/python/kernel_tests/distributions/util_test.py @@ -25,6 +25,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl @@ -55,6 +56,7 @@ def _logit(x): return np.log(x) - np.log1p(-x) +@test_util.with_c_api class AssertCloseTest(test.TestCase): def testAssertCloseIntegerDtype(self): @@ -145,6 +147,7 @@ class AssertCloseTest(test.TestCase): array_ops.identity(w).eval(feed_dict=feed_dict) +@test_util.with_c_api class GetLogitsAndProbsTest(test.TestCase): def testImproperArguments(self): @@ -298,6 +301,7 @@ class GetLogitsAndProbsTest(test.TestCase): logit.eval(feed_dict={l: np.ones([int(2**11+1)])}) +@test_util.with_c_api class EmbedCheckCategoricalEventShapeTest(test.TestCase): def testTooSmall(self): @@ -335,6 +339,7 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase): du.embed_check_categorical_event_shape(param) +@test_util.with_c_api class EmbedCheckIntegerCastingClosedTest(test.TestCase): def testCorrectlyAssertsNonnegative(self): @@ -370,6 +375,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase): x_checked.eval(feed_dict={x: np.array([1, -1], dtype=np.int32)}) +@test_util.with_c_api class LogCombinationsTest(test.TestCase): def testLogCombinationsBinomial(self): @@ -400,6 +406,7 @@ class LogCombinationsTest(test.TestCase): self.assertEqual([2, 2], log_binom.get_shape()) +@test_util.with_c_api class DynamicShapeTest(test.TestCase): def testSameDynamicShape(self): @@ -504,6 +511,7 @@ class DynamicShapeTest(test.TestCase): })) +@test_util.with_c_api class RotateTransposeTest(test.TestCase): def _np_rotate_transpose(self, x, shift): @@ -537,6 +545,7 @@ class RotateTransposeTest(test.TestCase): shift: shift_value})) +@test_util.with_c_api class PickVectorTest(test.TestCase): def testCorrectlyPicksVector(self): @@ -557,6 +566,7 @@ class PickVectorTest(test.TestCase): constant_op.constant(False), x, y)) # No eval. +@test_util.with_c_api class PreferStaticRankTest(test.TestCase): def testNonEmptyConstantTensor(self): @@ -596,6 +606,7 @@ class PreferStaticRankTest(test.TestCase): self.assertAllEqual(0, rank.eval(feed_dict={x: 1})) +@test_util.with_c_api class PreferStaticShapeTest(test.TestCase): def testNonEmptyConstantTensor(self): @@ -635,6 +646,7 @@ class PreferStaticShapeTest(test.TestCase): self.assertAllEqual(np.array([]), shape.eval(feed_dict={x: 1})) +@test_util.with_c_api class PreferStaticValueTest(test.TestCase): def testNonEmptyConstantTensor(self): @@ -675,6 +687,7 @@ class PreferStaticValueTest(test.TestCase): self.assertAllEqual(np.array(1), value.eval(feed_dict={x: 1})) +@test_util.with_c_api class FillTriangularTest(test.TestCase): def setUp(self): @@ -769,6 +782,7 @@ class FillTriangularTest(test.TestCase): self._run_test(self._rng.randn(2, 3, int(7*8/2)), upper=True) +@test_util.with_c_api class ReduceWeightedLogSumExp(test.TestCase): def _reduce_weighted_logsumexp(self, logx, w, axis, keep_dims=False): @@ -865,6 +879,7 @@ class ReduceWeightedLogSumExp(test.TestCase): du.reduce_weighted_logsumexp(x, w, axis=[0, 1]).eval()) +@test_util.with_c_api class GenNewSeedTest(test.TestCase): def testOnlyNoneReturnsNone(self): @@ -875,6 +890,7 @@ class GenNewSeedTest(test.TestCase): # TODO(jvdillon): Merge this test back into: # tensorflow/python/kernel_tests/softplus_op_test.py # once TF core is accepting new ops. +@test_util.with_c_api class SoftplusTest(test.TestCase): def _npSoftplus(self, np_features): From bd72021bd1223ac42b1010e6599b2528a18dc33c Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Mon, 8 Jan 2018 19:39:41 -0800 Subject: [PATCH 09/23] [XLA] Fix spelling in error message; NFC PiperOrigin-RevId: 181256065 --- tensorflow/compiler/xla/tools/parser/hlo_parser.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 2fc369dc0e6..ba2b6c0dad2 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -1515,7 +1515,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, return false; } } else { - return TokenError(StrCat("unsupported premitive type ", + return TokenError(StrCat("unsupported primitive type ", PrimitiveType_Name(shape.element_type()))); } break; From 51895fe67434b6e9f5419872f69c7e6092ed69e9 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Mon, 8 Jan 2018 19:41:05 -0800 Subject: [PATCH 10/23] [TF:XLA] Add while loop invariant code motion This new pass extracts out loop invariant computations out of while loops into their parent computations. Right now this is enabled only for the CPU backend. PiperOrigin-RevId: 181256166 --- tensorflow/compiler/xla/map_util.h | 6 + tensorflow/compiler/xla/service/BUILD | 74 +++ .../compiler/xla/service/call_inliner.cc | 15 +- .../compiler/xla/service/call_inliner.h | 8 +- .../compiler/xla/service/call_inliner_test.cc | 2 +- tensorflow/compiler/xla/service/cpu/BUILD | 1 + .../compiler/xla/service/cpu/cpu_compiler.cc | 2 + tensorflow/compiler/xla/service/tuple_util.cc | 61 +++ tensorflow/compiler/xla/service/tuple_util.h | 45 ++ .../compiler/xla/service/tuple_util_test.cc | 81 ++++ .../while_loop_invariant_code_motion.cc | 296 ++++++++++++ .../while_loop_invariant_code_motion.h | 39 ++ .../while_loop_invariant_code_motion_test.cc | 442 ++++++++++++++++++ .../xla/service/while_loop_simplifier.cc | 4 +- tensorflow/compiler/xla/service/while_util.cc | 140 ++++++ tensorflow/compiler/xla/service/while_util.h | 58 +++ .../compiler/xla/service/while_util_test.cc | 130 ++++++ tensorflow/compiler/xla/tests/while_test.cc | 44 ++ tensorflow/compiler/xla/util.h | 25 + 19 files changed, 1464 insertions(+), 9 deletions(-) create mode 100644 tensorflow/compiler/xla/service/tuple_util.cc create mode 100644 tensorflow/compiler/xla/service/tuple_util.h create mode 100644 tensorflow/compiler/xla/service/tuple_util_test.cc create mode 100644 tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc create mode 100644 tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h create mode 100644 tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc create mode 100644 tensorflow/compiler/xla/service/while_util.cc create mode 100644 tensorflow/compiler/xla/service/while_util.h create mode 100644 tensorflow/compiler/xla/service/while_util_test.cc diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h index 51d0d5f86f0..50659c12405 100644 --- a/tensorflow/compiler/xla/map_util.h +++ b/tensorflow/compiler/xla/map_util.h @@ -60,6 +60,12 @@ bool ContainsKey(const Collection& collection, const Key& key) { return collection.find(key) != collection.end(); } +// Inserts `value` into `set`. Dies if it was already present. +template +void InsertOrDie(Set* const set, const typename Set::value_type& value) { + CHECK(set->insert(value).second) << "duplicate value: " << value; +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index c8afb7d6505..f26dc64fee1 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1103,6 +1103,8 @@ cc_library( ":hlo", ":hlo_evaluator", ":hlo_pass", + ":tuple_util", + ":while_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", ], @@ -2257,6 +2259,78 @@ cc_library( ], ) +cc_library( + name = "tuple_util", + srcs = ["tuple_util.cc"], + hdrs = ["tuple_util.h"], + deps = [ + ":hlo", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "tuple_util_test", + srcs = ["tuple_util_test.cc"], + deps = [ + ":tuple_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + +cc_library( + name = "while_util", + srcs = ["while_util.cc"], + hdrs = ["while_util.h"], + deps = [ + ":call_inliner", + ":hlo", + ":tuple_util", + ], +) + +tf_cc_test( + name = "while_util_test", + srcs = ["while_util_test.cc"], + deps = [ + ":while_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + +cc_library( + name = "while_loop_invariant_code_motion", + srcs = ["while_loop_invariant_code_motion.cc"], + hdrs = ["while_loop_invariant_code_motion.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":tuple_util", + ":while_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "while_loop_invariant_code_motion_test", + srcs = ["while_loop_invariant_code_motion_test.cc"], + deps = [ + ":hlo_matchers", + ":while_loop_invariant_code_motion", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/core:test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc index 3aa7f5c4d58..482ccc5b671 100644 --- a/tensorflow/compiler/xla/service/call_inliner.cc +++ b/tensorflow/compiler/xla/service/call_inliner.cc @@ -82,6 +82,10 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { return outer_->ReplaceInstruction(call_, new_root); } + CallInliner::InlinedInstructionMap ConsumeInstructionMap() { + return std::move(subcomputation_hlo_to_new_hlo_); + } + private: // Resolves the callee subcomputation_hlo to the new (inline) HLO in the // caller computation, or returns a NotFound error if that subcomputation HLO @@ -112,13 +116,13 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { HloInstruction* call_; HloComputation* outer_; - std::unordered_map - subcomputation_hlo_to_new_hlo_; + CallInliner::InlinedInstructionMap subcomputation_hlo_to_new_hlo_; }; } // namespace -/* static */ Status CallInliner::Inline(HloInstruction* call) { +/* static */ StatusOr CallInliner::Inline( + HloInstruction* call) { TF_RET_CHECK(call->opcode() == HloOpcode::kCall) << "Instruction was not a call op: " << call->opcode(); const auto& callees = call->called_computations(); @@ -126,7 +130,8 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { HloComputation* callee = callees[0]; // We visit the callee, cloning its body into its caller. SubcomputationInsertionVisitor visitor(call); - return callee->Accept(&visitor); + TF_RETURN_IF_ERROR(callee->Accept(&visitor)); + return visitor.ConsumeInstructionMap(); } StatusOr CallInliner::Run(HloModule* module) { @@ -140,7 +145,7 @@ StatusOr CallInliner::Run(HloModule* module) { VLOG(1) << "Visiting callsite: " << callsite.ToString(); if (callsite.instruction()->opcode() == HloOpcode::kCall) { HloInstruction* call = callsite.instruction(); - TF_RETURN_IF_ERROR(Inline(call)); + TF_RETURN_IF_ERROR(Inline(call).status()); did_mutate = true; } } diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h index 2dbd38bf1ac..a8345a394d4 100644 --- a/tensorflow/compiler/xla/service/call_inliner.h +++ b/tensorflow/compiler/xla/service/call_inliner.h @@ -27,8 +27,12 @@ namespace xla { // called function, and proceed recursively. class CallInliner : public HloPassInterface { public: - // Inlines one call instruction. - static Status Inline(HloInstruction* call); + using InlinedInstructionMap = + std::unordered_map; + + // Inlines one call instruction. Returns a mapping from the original + // instructions to their inlined versions. + static StatusOr Inline(HloInstruction* call); ~CallInliner() override = default; tensorflow::StringPiece name() const override { return "CallInliner"; } diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index 865ed993da1..738d00881dd 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -135,7 +135,7 @@ TEST_F(CallInlinerTest, InlineWithoutRunningPass) { HloInstruction::CreateCall(pred, {}, false_computation)); auto computation = module->AddEntryComputation(call_false_builder.Build()); - TF_ASSERT_OK(CallInliner::Inline(call)); + TF_ASSERT_OK(CallInliner::Inline(call).status()); EXPECT_THAT(computation->root_instruction(), op::Constant()); EXPECT_THAT(computation->root_instruction()->control_successors(), ElementsAre(op::Constant())); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index a6b3c0c7c42..2f025916312 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -127,6 +127,7 @@ cc_library( "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", # fixdeps: keep diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 42fd4f100bf..27bdde41af0 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -86,6 +86,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -293,6 +294,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // elimination has to come after that pass. pipeline.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); diff --git a/tensorflow/compiler/xla/service/tuple_util.cc b/tensorflow/compiler/xla/service/tuple_util.cc new file mode 100644 index 00000000000..4a530bb0b20 --- /dev/null +++ b/tensorflow/compiler/xla/service/tuple_util.cc @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/tuple_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +/*static*/ HloInstruction* TupleUtil::ExtractPrefix(HloInstruction* input_tuple, + int64 elements) { + CHECK(ShapeUtil::IsTuple(input_tuple->shape())); + + HloComputation* computation = input_tuple->parent(); + const Shape& input_shape = input_tuple->shape(); + + std::vector tuple_elements; + tuple_elements.reserve(elements); + for (int i = 0; i < elements; i++) { + tuple_elements.push_back( + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + input_shape.tuple_shapes(i), input_tuple, i))); + } + + return computation->AddInstruction( + HloInstruction::CreateTuple(tuple_elements)); +} + +/*static*/ HloInstruction* TupleUtil::AppendSuffix( + HloInstruction* input_tuple, + tensorflow::gtl::ArraySlice trailing_values) { + CHECK(ShapeUtil::IsTuple(input_tuple->shape())); + + HloComputation* computation = input_tuple->parent(); + const Shape& input_shape = input_tuple->shape(); + std::vector tuple_elements; + tuple_elements.reserve(input_shape.tuple_shapes_size()); + for (int i = 0; i < input_shape.tuple_shapes_size(); i++) { + tuple_elements.push_back( + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + input_shape.tuple_shapes(i), input_tuple, i))); + } + tuple_elements.insert(tuple_elements.end(), trailing_values.begin(), + trailing_values.end()); + return computation->AddInstruction( + HloInstruction::CreateTuple(tuple_elements)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_util.h b/tensorflow/compiler/xla/service/tuple_util.h new file mode 100644 index 00000000000..e5ff9aaa835 --- /dev/null +++ b/tensorflow/compiler/xla/service/tuple_util.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_UTIL_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { +class TupleUtil { + public: + // Generates HLO instructions to get a prefix tuple from `input_tuple` (which + // must be of tuple shape) of length `elements`. Returns the root of the + // graph of instructions generated. + // + // The instructions are generated into the computation containing + // `input_tuple`. + static HloInstruction* ExtractPrefix(HloInstruction* input_tuple, + int64 elements); + + // Generates HLO instructions to create a tuple that consists of the values in + // `trailing_values` appended to `input_tuple` (which must be of tuple shape). + // Returns the root of the graph of instructions generated. + // + // The instructions are generated into the computation containing + // `input_tuple`. + static HloInstruction* AppendSuffix( + HloInstruction* input_tuple, + tensorflow::gtl::ArraySlice trailing_values); +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/tuple_util_test.cc b/tensorflow/compiler/xla/service/tuple_util_test.cc new file mode 100644 index 00000000000..754fd8ef169 --- /dev/null +++ b/tensorflow/compiler/xla/service/tuple_util_test.cc @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/tuple_util.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace { + +namespace op = ::xla::testing::opcode_matchers; + +StatusOr> GetParsedModule( + HloComputation** entry_computation, HloInstruction** param0, + HloInstruction** param1) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[32,32]{1,0},f32[32,32]{1,0},f32[32,32]{1,0}) parameter(0) + ROOT p1 = f32[32,32]{1,0} parameter(1) +} +)"; + + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + tools::Parse(hlo_string)); + + *entry_computation = module->entry_computation(); + *param0 = (*entry_computation)->parameter_instruction(0); + *param1 = (*entry_computation)->parameter_instruction(1); + + return std::move(module); +} + +TEST(TupleUtilTest, ExtractPrefix) { + HloInstruction *param0, *param1; + HloComputation* entry_computation; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + GetParsedModule(&entry_computation, ¶m0, ¶m1)); + + HloInstruction* prefix = TupleUtil::ExtractPrefix(param0, 2); + + EXPECT_THAT(prefix, op::Tuple(op::GetTupleElement(op::Parameter(0), 0), + op::GetTupleElement(op::Parameter(0), 1))); +} + +TEST(TupleUtilTest, AppendSuffix) { + HloInstruction *param0, *param1; + HloComputation* entry_computation; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + GetParsedModule(&entry_computation, ¶m0, ¶m1)); + + HloInstruction* with_suffix = + TupleUtil::AppendSuffix(param0, {param1, param1}); + + EXPECT_THAT(with_suffix, op::Tuple(op::GetTupleElement(op::Parameter(0), 0), + op::GetTupleElement(op::Parameter(0), 1), + op::GetTupleElement(op::Parameter(0), 2), + op::Parameter(1), op::Parameter(1))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc new file mode 100644 index 00000000000..a5f9b01f011 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -0,0 +1,296 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" +#include "tensorflow/compiler/xla/service/tuple_util.h" +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace xla { + +using tensorflow::gtl::FlatMap; +using tensorflow::gtl::FlatSet; +using tensorflow::gtl::InlinedVector; + +// Copies `to_hoist` to the computation containing `while_instr`, hoisting its +// operands as needed. All of its transitive operands are expected to be either +// in `hoisted_instructions` or `unhoisted_invariant_instructions`. This +// function hoists the operands in `unhoisted_invariant_instructions` and moves +// them into `hoisted_instructions`. +static void CreateLoopInvariantCopy( + FlatMap* hoisted_instructions, + FlatSet* unhoisted_invariant_instructions, + HloInstruction* while_instr, HloInstruction* to_hoist) { + HloComputation* parent_of_while = while_instr->parent(); + HloComputation* while_body = while_instr->while_body(); + + struct DFSFrame { + HloInstruction* instruction; + int64 operand_index; + }; + + InlinedVector dfs_stack; + dfs_stack.push_back({to_hoist, 0}); + + HloInstruction* while_body_param = while_body->parameter_instruction(0); + HloInstruction* while_operand = while_instr->mutable_operand(0); + + do { + DFSFrame* frame = &dfs_stack.back(); + if (frame->operand_index == frame->instruction->operand_count()) { + HloInstruction* old_instruction = frame->instruction; + + // All of the operands for old_instruction have been cloned, so it is + // time to clone old_instruction itself. + + auto get_new_operand = [&](HloInstruction* old_operand) { + return old_operand == while_body_param + ? while_operand + : FindOrDie(*hoisted_instructions, old_operand); + }; + + InlinedVector new_operands; + c_transform(old_instruction->operands(), std::back_inserter(new_operands), + get_new_operand); + + HloInstruction* new_instruction = + parent_of_while->AddInstruction(old_instruction->CloneWithNewOperands( + old_instruction->shape(), new_operands)); + + InsertOrDie(hoisted_instructions, old_instruction, new_instruction); + + // Approximately half of the instructions that would normally be present + // in unhoisted_invariant_instructions are constants. We save a bit of + // compile time by not putting these in the hashtable. + CHECK_EQ(unhoisted_invariant_instructions->erase(old_instruction), + to_hoist != old_instruction && + old_instruction->opcode() != HloOpcode::kConstant); + dfs_stack.pop_back(); + continue; + } + + HloInstruction* next_operand = + frame->instruction->mutable_operand(frame->operand_index++); + if (hoisted_instructions->count(next_operand) || + next_operand == while_body_param) { + continue; + } + + dfs_stack.push_back({next_operand, 0}); + } while (!dfs_stack.empty()); +} + +// Returns true if `instruction` is worth hoisting only if it lets us hoist some +// instruction using it. The rationale is that hoisting these instructions will +// prevent simplification and fusion in the while body. +static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { + switch (instruction.opcode()) { + default: + return false; + + case HloOpcode::kBitcast: + case HloOpcode::kBroadcast: + case HloOpcode::kConstant: + case HloOpcode::kReverse: + case HloOpcode::kSlice: + case HloOpcode::kTuple: + return true; + + case HloOpcode::kTranspose: + return ShapeUtil::TransposeIsBitcast( + /*input_shape=*/instruction.operand(0)->shape(), + /*output_shape=*/instruction.shape(), instruction.dimensions()); + + case HloOpcode::kReshape: + return ShapeUtil::ReshapeIsBitcast( + /*input_shape=*/instruction.operand(0)->shape(), + /*output_shape=*/instruction.shape()); + } +} + +// Populates `gte_set` with the GetTupleElement instructions in `while_body` +// that access elements in the parameter tuple that don't change across +// iterations. Assumes `while_body` is the body computation of the while loop +// in question. +static void GatherInvariantGTEs(HloComputation* while_body, + FlatSet* gte_set) { + const HloInstruction::InstructionVector root_operands = + while_body->root_instruction()->operands(); + for (int i = 0; i < root_operands.size(); i++) { + HloInstruction* instr = root_operands[i]; + if (instr->opcode() == HloOpcode::kGetTupleElement && + instr->tuple_index() == i && + instr->operand(0) == while_body->parameter_instruction(0) && + ShapeUtil::IsArray(instr->shape())) { + InsertOrDie(gte_set, instr); + } + } +} + +static StatusOr TryHoistingInvariantInstructionsFromWhileBody( + HloInstruction* while_instr) { + auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false); + + if (!ShapeUtil::IsTuple(while_instr->shape())) { + // This restriction leaves one interesting pattern on the table: + // + // while_body(f32[1024, 1024] %param) { + // %value = expensive_op(%param) + // outfeed(%value) + // ROOT = %param + // } + // + // If we see that pattern in the while, instead of generalizing this + // algorithm to work with non-tuples, we should instead add a pass that + // canonicalizes while loops like the above to use a tuple state. + return false; + } + + string while_instr_name = while_instr->ToString(print_no_metadata); + VLOG(2) << "Trying to hoist from " << while_instr_name; + + HloComputation* while_body = while_instr->while_body(); + + // Maps instructions in the while body to instructions hoisted outside the + // while that compute the same value. + FlatMap hoisted_instructions; + + // Contains instructions that can be legally hoisted, but were deemed to be + // unprofitable to be hoisted alone by NotWorthHoistingIndividually. When we + // hoist an instruction in this set, we move it from + // unhoisted_invariant_instructions to hoisted_instructions. + FlatSet unhoisted_invariant_instructions; + + // Invariant GTE's axiomatically satisfy the constraints for + // unhoisted_invariant_instructions -- they can be legally hoisted, but there + // is no benefit to hoisting them unless something that uses it is also + // hoisted. + GatherInvariantGTEs(while_body, &unhoisted_invariant_instructions); + + if (unhoisted_invariant_instructions.empty()) { + // There are no obviously loop invariant elements in the state being + // threaded through the while loop so give up. In theory this precondition + // is too strong -- we could have code that e.g. permutes the elements in + // the while state but uses a select to pick the same value on every + // iteration. + return false; + } + + // instructions_to_replace[i] is hoisted into a loop invariant instruction + // replacement_instructions[i]. + std::vector instructions_to_replace; + std::vector replacement_instructions; + + for (auto* instruction : while_body->MakeInstructionPostOrder()) { + if (instruction->HasSideEffect() || + instruction->opcode() == HloOpcode::kParameter || + !instruction->control_predecessors().empty() || + !instruction->control_successors().empty()) { + continue; + } + + auto is_invariant = [&](HloInstruction* op) { + return hoisted_instructions.find(op) != hoisted_instructions.end() || + unhoisted_invariant_instructions.count(op) || + op->opcode() == HloOpcode::kConstant; + }; + + if (!c_all_of(instruction->operands(), is_invariant)) { + continue; + } + + if (NotWorthHoistingIndividually(*instruction)) { + VLOG(2) << "Adding " << instruction->ToString(print_no_metadata) + << " to unhoisted invariant set."; + // Approximately half of the instructions that reach this point are + // constants. We save a bit of compile time by not putting these in the + // hashtable. + if (instruction->opcode() != HloOpcode::kConstant) { + InsertOrDie(&unhoisted_invariant_instructions, instruction); + } + continue; + } + + VLOG(2) << "Hoisting " << instruction->ToString(print_no_metadata); + + CreateLoopInvariantCopy(&hoisted_instructions, + &unhoisted_invariant_instructions, while_instr, + instruction); + + instructions_to_replace.push_back(instruction); + replacement_instructions.push_back( + FindOrDie(hoisted_instructions, instruction)); + } + + if (instructions_to_replace.empty()) { + return false; + } + + TF_ASSIGN_OR_RETURN( + WhileUtil::MakeInstructionsLiveInResult live_in_instructions_result, + WhileUtil::MakeInstructionsLiveIn(while_instr, replacement_instructions)); + + HloComputation* new_while_body = + live_in_instructions_result.new_while_instr->while_body(); + + for (int i = 0; i < instructions_to_replace.size(); i++) { + HloInstruction* instruction_to_replace_in_new_while = + FindOrDie(live_in_instructions_result.while_body_instruction_map, + instructions_to_replace[i]); + TF_RETURN_IF_ERROR(new_while_body->ReplaceInstruction( + instruction_to_replace_in_new_while, + live_in_instructions_result.while_body_live_in_values[i])); + } + + VLOG(1) << "Hoisted " << instructions_to_replace.size() + << " instructions from " << while_instr_name; + + return true; +} + +StatusOr WhileLoopInvariantCodeMotion::Run(HloModule* module) { + bool changed = false; + std::vector while_instrs; + for (auto* comp : module->computations()) { + c_copy_if(comp->instructions(), std::back_inserter(while_instrs), + [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); + } + + for (HloInstruction* while_instr : while_instrs) { + // Right now we only hoist computations from the while body, but + // TryHoistingInvariantInstructionsFromWhileBody can be generalized to + // optimize the condition computation too, if needed. + // + // The transform we do here is a pessmization for while loops that execute + // zero times*, but at this time we expect those to be rare. If this + // becomes a problem we can consider using the conditional HLO to avoid + // doing extra work for while loops with zero trip count. + // + // * We delete while loops that have a zero trip count, so this would have + // to be a while loop with a somewhat opaque condition expression. + + TF_ASSIGN_OR_RETURN( + bool result, + TryHoistingInvariantInstructionsFromWhileBody(while_instr)); + changed |= result; + } + return changed; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h new file mode 100644 index 00000000000..8c4b765b000 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_INVARIANT_CODE_MOTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_INVARIANT_CODE_MOTION_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass that rewrites while loops to hoist loop invariant instructions in +// the while body into the computation that contains the while instruction. + +class WhileLoopInvariantCodeMotion : public HloPassInterface { + public: + ~WhileLoopInvariantCodeMotion() override = default; + + tensorflow::StringPiece name() const override { + return "while-loop-invariant-code-motion"; + } + StatusOr Run(HloModule* module) override; +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_INVARIANT_CODE_MOTION_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc new file mode 100644 index 00000000000..799340fda90 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -0,0 +1,442 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase { + public: + // Makes a computation which has one parameter, of the given shape, and always + // returns PRED[]{true}. This is useful as a dummy loop condition. + HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape, + HloModule* module); +}; + +static void FindOnlyWhileInstruction(HloComputation* computation, + HloInstruction** while_instruction) { + *while_instruction = nullptr; + for (auto* instr : computation->instructions()) { + if (instr->opcode() == HloOpcode::kWhile) { + ASSERT_EQ(*while_instruction, nullptr); + *while_instruction = instr; + } + } + + ASSERT_NE(*while_instruction, nullptr); +} + +HloComputation* WhileLoopInvariantCodeMotionTest::MakeAlwaysTrueComputation( + const Shape& param_shape, HloModule* module) { + HloComputation::Builder builder(TestName() + ".always_true"); + builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "param")); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + return module->AddEmbeddedComputation(builder.Build()); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* add_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte_0, gte_1, add_result})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + HloComputation* entry_computation = + module().AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_TRUE(simplified_loop); + + HloInstruction* transformed_while; + FindOnlyWhileInstruction(entry_computation, &transformed_while); + + EXPECT_THAT(entry_computation->instructions(), Contains(op::Add())); + EXPECT_THAT(transformed_while->while_body()->instructions(), + Each(Not(op::Add()))); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* gte_2_loop_variant = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 2)); + + HloInstruction* add_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); + HloInstruction* mul_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kMultiply, add_result, gte_1)); + HloInstruction* negate_result = + builder.AddInstruction(HloInstruction::CreateUnary( + scalar_s32, HloOpcode::kNegate, mul_result)); + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(4))); + HloInstruction* sub_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kSubtract, negate_result, constant)); + HloInstruction* divide_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kDivide, sub_result, gte_2_loop_variant)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte_0, gte_1, divide_result})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + HloComputation* entry_computation = + module().AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_TRUE(simplified_loop); + + HloInstruction* transformed_while; + FindOnlyWhileInstruction(entry_computation, &transformed_while); + + EXPECT_THAT(entry_computation->instructions(), + AllOf(Contains(op::Add()), Contains(op::Multiply()), + Contains(op::Negate()), Contains(op::Subtract()), + Contains(op::Constant()), + + // The division had a loop varying operand so that better + // not be hoisted. + Not(Contains(op::Divide())))); + + EXPECT_THAT(transformed_while->while_body()->instructions(), + Each(Not(AnyOf(op::Add(), op::Multiply(), op::Negate(), + op::Subtract(), op::Constant())))); + + EXPECT_THAT(transformed_while->while_body()->instructions(), + Contains(op::Divide())); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, + DontHoistTriviallyLoopVaryingComputation) { + // Basic negative test: the add expression is not loop invariant. + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* add_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); + builder.AddInstruction(HloInstruction::CreateTuple({gte_0, add_result})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + + module().AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); + + EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add())); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, + DontHoistLoopVaryingComputationWithAlternatingTuples) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* add_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte_1, gte_0, add_result})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + + module().AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); + + EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add())); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + builder.AddInstruction( + HloInstruction::CreateOutfeed(scalar_s32, gte_0, "")); + builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + + module().AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); + + EXPECT_THAT(while_inst->while_body()->instructions(), + Contains(op::Outfeed())); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { + // The bitcast's user, an outfeed, can't be hoisted, so don't hoist the + // bitcast either. + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); + Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* bitcast_inst = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); + builder.AddInstruction( + HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, "")); + builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + + module().AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); + + EXPECT_THAT(while_inst->while_body()->instructions(), + Contains(op::Outfeed())); + EXPECT_THAT(while_inst->while_body()->instructions(), + Contains(op::Bitcast())); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { + // The bitcast's user can be hoisted, so hoist the bitcast too. + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_f32, scalar_f32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_f32, param, 1)); + HloInstruction* bitcast_inst = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); + HloInstruction* add_inst = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_f32, HloOpcode::kAdd, bitcast_inst, gte_1)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte_0, gte_1, add_inst})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + + HloComputation* entry_computation = + module().AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_TRUE(simplified_loop); + + HloInstruction* transformed_while; + FindOnlyWhileInstruction(entry_computation, &transformed_while); + + EXPECT_THAT(transformed_while->while_body()->instructions(), + Each(Not(op::Add()))); + EXPECT_THAT(transformed_while->while_body()->instructions(), + Each(Not(op::Bitcast()))); + EXPECT_THAT(entry_computation->instructions(), Contains(op::Add())); + EXPECT_THAT(entry_computation->instructions(), Contains(op::Bitcast())); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistControlDependencies) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); + + HloComputation* while_body; + { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* add_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); + TF_ASSERT_OK(param->AddControlDependencyTo(add_result)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte_0, gte_1, add_result})); + + while_body = module().AddEmbeddedComputation(builder.Build()); + } + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + module().AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".passthrough"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloComputation* result = module().AddEmbeddedComputation(builder.Build()); + + result->AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + return result; + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + module().AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 615a089d125..87a7f86f4ec 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -595,7 +595,9 @@ static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { auto call_op = computation->AddInstruction(HloInstruction::CreateCall( while_op->shape(), while_op->operands(), while_op->while_body())); TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op)); - TF_RETURN_IF_ERROR(CallInliner::Inline(call_op)); + TF_ASSIGN_OR_RETURN(auto inlined_instructions_map, + CallInliner::Inline(call_op)); + (void)inlined_instructions_map; return true; } return false; diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc new file mode 100644 index 00000000000..e20b25e4a08 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -0,0 +1,140 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/tuple_util.h" + +namespace xla { + +static StatusOr WidenWhileCondition( + HloComputation* narrow_condition, const Shape& wide_shape) { + const Shape& narrow_shape = + narrow_condition->parameter_instruction(0)->shape(); + + HloComputation* wide_while_cond = [&]() { + HloComputation::Builder builder( + tensorflow::strings::StrCat("wide.", narrow_condition->name())); + builder.AddInstruction( + HloInstruction::CreateParameter(0, wide_shape, "wide_param")); + + // This is needed so that the root instruction is shaped as a PRED[] -- we + // need to get this right to begin with since we can't mutate the type of + // the root instruction later. We later change the root instruction to + // something more appropriate. + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + return narrow_condition->parent()->AddEmbeddedComputation(builder.Build()); + }(); + + HloInstruction* truncated_parameter = + TupleUtil::ExtractPrefix(wide_while_cond->parameter_instruction(0), + narrow_shape.tuple_shapes_size()); + HloInstruction* call_narrow_cond = wide_while_cond->AddInstruction( + HloInstruction::CreateCall(ShapeUtil::MakeShape(PRED, {}), + {truncated_parameter}, narrow_condition)); + + wide_while_cond->set_root_instruction(call_narrow_cond); + + TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_cond).status()); + return wide_while_cond; +} + +static StatusOr> +WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) { + const Shape& narrow_shape = narrow_body->parameter_instruction(0)->shape(); + + HloComputation* wide_while_body = [&]() { + HloComputation::Builder builder( + tensorflow::strings::StrCat("wide.", narrow_body->name())); + builder.AddInstruction( + HloInstruction::CreateParameter(0, wide_shape, "wide_param")); + return narrow_body->parent()->AddEmbeddedComputation(builder.Build()); + }(); + + HloInstruction* wide_parameter = wide_while_body->parameter_instruction(0); + HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix( + wide_parameter, narrow_shape.tuple_shapes_size()); + HloInstruction* call_narrow_body = + wide_while_body->AddInstruction(HloInstruction::CreateCall( + narrow_shape, {truncated_parameter}, narrow_body)); + + std::vector live_through_values; + for (int i = narrow_shape.tuple_shapes_size(); + i < wide_shape.tuple_shapes_size(); i++) { + live_through_values.push_back( + wide_while_body->AddInstruction(HloInstruction::CreateGetTupleElement( + wide_shape.tuple_shapes(i), wide_parameter, i))); + } + + wide_while_body->set_root_instruction( + TupleUtil::AppendSuffix(call_narrow_body, live_through_values)); + + TF_ASSIGN_OR_RETURN(auto inlined_instructions_map, + CallInliner::Inline(call_narrow_body)); + return {{wide_while_body, std::move(inlined_instructions_map)}}; +} + +/*static*/ StatusOr +WhileUtil::MakeInstructionsLiveIn( + HloInstruction* while_instr, + tensorflow::gtl::ArraySlice instructions) { + CHECK(ShapeUtil::IsTuple(while_instr->shape())); + + int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size(); + Shape new_while_shape = while_instr->shape(); + for (auto* instruction : instructions) { + *new_while_shape.add_tuple_shapes() = instruction->shape(); + } + + TF_ASSIGN_OR_RETURN( + HloComputation * new_while_condition, + WidenWhileCondition(while_instr->while_condition(), new_while_shape)); + + HloComputation* new_while_body; + CallInliner::InlinedInstructionMap inlined_instructions_map; + TF_ASSIGN_OR_RETURN( + std::tie(new_while_body, inlined_instructions_map), + WidenWhileBody(while_instr->while_body(), new_while_shape)); + + HloInstruction* new_while_init = + TupleUtil::AppendSuffix(while_instr->mutable_operand(0), instructions); + HloComputation* containing_computation = while_instr->parent(); + HloInstruction* new_while = containing_computation->AddInstruction( + HloInstruction::CreateWhile(new_while_shape, new_while_condition, + new_while_body, new_while_init)); + TF_RETURN_IF_ERROR(containing_computation->ReplaceInstruction( + while_instr, TupleUtil::ExtractPrefix( + new_while, while_instr->shape().tuple_shapes_size()))); + + HloInstruction* while_body_param = new_while_body->parameter_instruction(0); + std::vector live_in_instructions; + for (int64 i = elements_in_old_while_shape; + i < new_while_shape.tuple_shapes_size(); i++) { + live_in_instructions.push_back( + new_while_body->AddInstruction(HloInstruction::CreateGetTupleElement( + instructions[i - elements_in_old_while_shape]->shape(), + while_body_param, i))); + } + + WhileUtil::MakeInstructionsLiveInResult result; + + result.new_while_instr = new_while; + result.while_body_live_in_values = std::move(live_in_instructions); + result.while_body_instruction_map = std::move(inlined_instructions_map); + + return std::move(result); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h new file mode 100644 index 00000000000..3600b5a80d2 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_util.h @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ + +#include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { +class WhileUtil { + public: + // Holds a return value from MakeInstructionsLiveIn. + struct MakeInstructionsLiveInResult { + // The new while operation that has the requested values live in. + HloInstruction* new_while_instr; + + // The i'th element of `while_body_live_in_values` is an instruction in the + // while body that holds the i'th *newly added* live in value at runtime. + std::vector while_body_live_in_values; + + // `while_body_instruction_map` maps instructions in the original while body + // to the corresponding instructions in the body for the newly created while + // operation. + CallInliner::InlinedInstructionMap while_body_instruction_map; + }; + + // Replaces `while_instr` with a new while instruction that is equivalent to + // `while_instr`, except that it has all of the HLO instructions in + // `instructions` as live-in, loop invariant values. These new live in values + // are represented as new elements appended to the parameter of the while + // loop, which must be of tuple shape. GetTupleElement instructions computing + // each new live in value is returned in the `while_body_live_in_values` + // vector. + // + // Precondition: `while_instr` must have a tuple shaped state. + // + // Every instruction in `instructions` must be contained in the computation + // that contains `while_instr`. + static StatusOr MakeInstructionsLiveIn( + HloInstruction* while_instr, + tensorflow::gtl::ArraySlice instructions); +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc new file mode 100644 index 00000000000..cf0d0db99bd --- /dev/null +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -0,0 +1,130 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_util.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace { + +namespace op = ::xla::testing::opcode_matchers; + +StatusOr> GetParsedModule( + HloComputation** entry_computation, HloInstruction** param0, + HloInstruction** param1, HloInstruction** param2) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +while_body { + ROOT p_body = (f32[32,32]{1,0}, f32[32,32]{1,0}) parameter(0) +} + +while_condition { + p_cond = f32[32,32]{1,0} parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + p_entry_0 = f32[32,32]{1,0} parameter(0) + p_entry_1 = s32[32,32]{1,0} parameter(1) + p_entry_2 = s64[32,32]{1,0} parameter(2) + while_init = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(p_entry_0, p_entry_0) + ROOT while = (f32[32,32]{1,0}, f32[32,32]{1,0}) while(while_init), condition=while_condition, body=while_body +} +)"; + + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + tools::Parse(hlo_string)); + + *entry_computation = module->entry_computation(); + *param0 = (*entry_computation)->parameter_instruction(0); + *param1 = (*entry_computation)->parameter_instruction(1); + *param2 = (*entry_computation)->parameter_instruction(2); + + return std::move(module); +} + +TEST(WhileUtil, MakeZeroInstructionsLiveOp) { + HloInstruction *param0, *param1, *param2; + HloComputation* entry_computation; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + GetParsedModule(&entry_computation, ¶m0, ¶m1, ¶m2)); + + HloInstruction* while_instr = entry_computation->root_instruction(); + ASSERT_EQ(while_instr->opcode(), HloOpcode::kWhile); + + TF_ASSERT_OK_AND_ASSIGN( + WhileUtil::MakeInstructionsLiveInResult make_live_in_result, + WhileUtil::MakeInstructionsLiveIn(while_instr, /*instructions=*/{})); + + HloInstruction* new_while_instr = make_live_in_result.new_while_instr; + + EXPECT_THAT( + entry_computation->root_instruction(), + op::Tuple(op::GetTupleElement(::testing::Eq(new_while_instr), 0), + op::GetTupleElement(::testing::Eq(new_while_instr), 1))); + + auto param_reconstructed = + op::Tuple(op::GetTupleElement(op::Parameter(0), 0), + op::GetTupleElement(op::Parameter(0), 1)); + + EXPECT_THAT(new_while_instr->while_body()->root_instruction(), + op::Tuple(op::GetTupleElement(param_reconstructed, 0), + op::GetTupleElement(param_reconstructed, 1))); +} + +TEST(WhileUtilTest, MakeTwoInstructionsLive) { + HloInstruction *param0, *param1, *param2; + HloComputation* entry_computation; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + GetParsedModule(&entry_computation, ¶m0, ¶m1, ¶m2)); + + HloInstruction* while_instr = entry_computation->root_instruction(); + ASSERT_EQ(while_instr->opcode(), HloOpcode::kWhile); + + TF_ASSERT_OK_AND_ASSIGN( + WhileUtil::MakeInstructionsLiveInResult make_live_in_result, + WhileUtil::MakeInstructionsLiveIn(while_instr, + /*instructions=*/{param0, param1})); + + HloInstruction* new_while_instr = make_live_in_result.new_while_instr; + + XLA_VLOG_LINES(3, module->ToString()); + + EXPECT_THAT( + entry_computation->root_instruction(), + op::Tuple(op::GetTupleElement(::testing::Eq(new_while_instr), 0), + op::GetTupleElement(::testing::Eq(new_while_instr), 1))); + + auto first_half_param_reconstructed = + op::Tuple(op::GetTupleElement(op::Parameter(0), 0), + op::GetTupleElement(op::Parameter(0), 1)); + + EXPECT_THAT(new_while_instr->while_body()->root_instruction(), + op::Tuple(op::GetTupleElement(first_half_param_reconstructed, 0), + op::GetTupleElement(first_half_param_reconstructed, 1), + op::GetTupleElement(op::Parameter(0), 2), + op::GetTupleElement(op::Parameter(0), 3))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 0b3430ee1ee..7e7f6b14862 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -1160,6 +1160,50 @@ TEST_F(WhileTest, WhileWithCallInsideCondition) { ComputeAndCompareR0(&builder, 5, {}); } +TEST_F(WhileTest, WhileWithLoopInvariantOperation) { + auto matrix_shape = ShapeUtil::MakeShape(F32, {2, 2}); + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + auto while_shape = ShapeUtil::MakeTupleShape( + {scalar_s32, matrix_shape, matrix_shape, matrix_shape}); + + // Create a computation for the condition: repeat for 5 iterations. + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto state = builder.Parameter(0, while_shape, "state"); + builder.Gt(builder.ConstantR0(5), builder.GetTupleElement(state, 0)); + TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); + } + + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto state = builder.Parameter(0, while_shape, "state"); + auto indvar = builder.GetTupleElement(state, 0); + auto input_0 = builder.GetTupleElement(state, 1); + auto input_1 = builder.GetTupleElement(state, 2); + auto output = builder.Tanh(builder.Dot(input_0, input_1)); + auto indvar_next = builder.Add(indvar, builder.ConstantR0(1)); + auto tuple_result = builder.Tuple({indvar_next, input_0, input_1, output}); + TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); + } + + ComputationBuilder builder(client_, TestName()); + auto matrix_input = builder.Parameter(0, matrix_shape, "matrix"); + auto init = builder.Tuple( + {builder.ConstantR0(0), matrix_input, matrix_input, matrix_input}); + auto while_instruction = builder.While(condition, body, init); + builder.GetTupleElement(while_instruction, 3); + + TF_ASSERT_OK_AND_ASSIGN(auto param_value, + client_->TransferToServer(*Literal::CreateR2( + {{1.0, 2.0}, {-1.0, -2.0}}))); + + ComputeAndCompareR2( + &builder, {{-0.76159416, -0.96402758}, {0.76159416, 0.96402758}}, + {param_value.get()}, ErrorSpec(4e-5)); +} + void BM_WhileLoop(int num_iters) { // Benchmark a simple kernel to measure while loop overheads. tensorflow::testing::StopTiming(); diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 277cc5ec86f..bb2db2010c5 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -398,6 +398,31 @@ std::vector> CommonFactors( // Removes illegal characters from filenames. string SanitizeFileName(string file_name); +// Simple wrapper around std::all_of. +template +bool c_all_of(Container container, Predicate predicate) { + return std::all_of(std::begin(container), std::end(container), predicate); +} + +// Simple wrapper around std::transform. +template +OutputIterator c_transform(InputContainer input_container, + OutputIterator output_iterator, + UnaryOperation unary_op) { + return std::transform(std::begin(input_container), std::end(input_container), + output_iterator, unary_op); +} + +// Simple wrapper around std::copy_if. +template +OutputIterator c_copy_if(InputContainer input_container, + OutputIterator output_iterator, + UnaryPredicate predicate) { + return std::copy_if(std::begin(input_container), std::end(input_container), + output_iterator, predicate); +} + } // namespace xla #define XLA_LOG_LINES(SEV, STRING) \ From e5652f4027fa8b0fd90c9c958fbf8e50cdf33675 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Mon, 8 Jan 2018 20:15:33 -0800 Subject: [PATCH 11/23] [TF:XLA] Bump open source llvm revision to r321863 PiperOrigin-RevId: 181259191 --- tensorflow/workspace.bzl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 3c6209a809e..5d26c3468d6 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -455,11 +455,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/cb27e1d0da7f30562ea6c1c4f01393afbf112620.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/cb27e1d0da7f30562ea6c1c4f01393afbf112620.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/81a623d6d61cf87847f839e80b047c267020ab0e.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/81a623d6d61cf87847f839e80b047c267020ab0e.tar.gz", ], - sha256 = "d4e4d17040a786bab13bb1b73ec2dc358f0c07214f847076e0ded8de15800782", - strip_prefix = "llvm-cb27e1d0da7f30562ea6c1c4f01393afbf112620", + sha256 = "be0259c0bd5349200df346c92ba7708341e18ef313fcf7398682b5cff2469137", + strip_prefix = "llvm-81a623d6d61cf87847f839e80b047c267020ab0e", build_file = str(Label("//third_party/llvm:llvm.BUILD")), ) From ac3560e16e58c8f5ea94f736b99bb60ed46fa885 Mon Sep 17 00:00:00 2001 From: Igor Ganichev Date: Mon, 8 Jan 2018 20:35:21 -0800 Subject: [PATCH 12/23] Fix typo in comment TFE_Execute PiperOrigin-RevId: 181260751 --- tensorflow/c/eager/c_api.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 5a5e5fe0d71..7ccfc45a5ff 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -542,7 +542,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, // WARNING: kernel->Run utilizes the FunctionLibraryRuntime // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def, // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation - // of FunctionLibraryRuntime tells use that func_lib_def is not accessed by + // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here. // This is quite subtle. Re-work things to make this better? (Would it make // sense for FunctionLibraryRuntime to ensure thread-safe access to From ff2ba6bd17f71048c2c4ae44f5297a1e03644d09 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 8 Jan 2018 20:36:13 -0800 Subject: [PATCH 13/23] Automated g4 rollback of changelist 181243048 PiperOrigin-RevId: 181260801 --- .../compiler/jit/kernels/xla_launch_op.cc | 4 +-- .../compiler/jit/xla_compilation_cache.cc | 9 +++---- .../compiler/jit/xla_compilation_cache.h | 3 +-- .../compiler/tf2xla/kernels/while_op.cc | 10 +++---- tensorflow/compiler/tf2xla/xla_compiler.cc | 27 ++++--------------- 5 files changed, 15 insertions(+), 38 deletions(-) diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 4842877d9af..4f3f17df9c6 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -257,10 +257,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; - OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_, - variables, ctx, &kernel, &executable, - /*compile_options=*/nullptr)); + variables, ctx, &kernel, &executable)); VLOG(1) << "Executing XLA Computation..."; diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index bfff52c55a7..3717c2cc242 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -238,8 +238,7 @@ Status XlaCompilationCache::Compile( int num_constant_args, const std::vector& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options) { + xla::LocalExecutable** executable) { VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { @@ -298,9 +297,9 @@ Status XlaCompilationCache::Compile( XlaCompiler compiler(options); entry->compiled = true; - entry->compilation_status = compiler.CompileFunction( - compile_options ? *compile_options : XlaCompiler::CompileOptions(), - function, args, &entry->compilation_result); + entry->compilation_status = + compiler.CompileFunction(XlaCompiler::CompileOptions(), function, args, + &entry->compilation_result); } *compilation_result = &entry->compilation_result; if (entry->compilation_status.ok() && executable) { diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 0858020716f..c3a8f68a157 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -66,8 +66,7 @@ class XlaCompilationCache : public ResourceBase { const std::vector& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options); + xla::LocalExecutable** executable); xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 5aea25dc7df..ee466520dd1 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -201,12 +201,10 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_, arguments, &cond)); - OP_REQUIRES(ctx, body.xla_input_shapes.size() == 1, - errors::FailedPrecondition("Expected one input shape")); - xla::Shape body_input_shape = body.xla_input_shapes[0]; - OP_REQUIRES(ctx, cond.xla_input_shapes.size() == 1, - errors::FailedPrecondition("Expected one input shape")); - xla::Shape cond_input_shape = cond.xla_input_shapes[0]; + xla::Shape body_input_shape = + xla::ShapeUtil::MakeTupleShape(body.xla_input_shapes); + xla::Shape cond_input_shape = + xla::ShapeUtil::MakeTupleShape(cond.xla_input_shapes); VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape) << " -> " << xla::ShapeUtil::HumanString(body.xla_output_shape); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 310cf20ec16..c55719be552 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -316,22 +316,15 @@ Status BuildArguments(const Graph& graph, return Status::OK(); } - std::vector arg_shapes; - arg_shapes.reserve(parameters.size()); + input_shapes->resize(parameters.size()); input_mapping->resize(parameters.size()); for (std::vector::size_type i = 0; i < parameters.size(); ++i) { const XlaCompiler::Argument& arg = args[parameters[i]]; // Computes the shapes of non-constant arguments. - arg_shapes.push_back(arg.shape); + (*input_shapes)[i] = arg.shape; (*input_mapping)[i] = parameters[i]; } - if (use_tuple_arg) { - input_shapes->push_back(xla::ShapeUtil::MakeTupleShape(arg_shapes)); - } else { - *input_shapes = arg_shapes; - } - // Use the _Arg nodes in the graph to resolve core assignments. for (const Node* n : graph.nodes()) { if (StringPiece(n->type_string()) != "_Arg") continue; @@ -355,19 +348,9 @@ Status BuildArguments(const Graph& graph, // Build parameter handles for non-constant arguments. std::vector arg_handles(parameters.size()); if (use_tuple_arg) { - xla::OpSharding tuple_sharding; - tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); - for (int64 parameter : parameters) { - const int core = (*arg_cores)[parameter]; - const int root_device = 0; - *tuple_sharding.add_tuple_shardings() = - core == -1 ? xla::sharding_builder::AssignDevice(root_device) - : xla::sharding_builder::AssignDevice(core); - } - xla::ScopedShardingAssignment assign_tuple_sharding(builder, - tuple_sharding); + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(*input_shapes); xla::ComputationDataHandle tuple = - builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); + builder->Parameter(0, tuple_shape, "arg_tuple"); for (std::vector::size_type i = 0; i < parameters.size(); ++i) { const int core = (*arg_cores)[parameters[i]]; xla::ScopedShardingAssignment assign_sharding( @@ -391,7 +374,7 @@ Status BuildArguments(const Graph& graph, for (std::vector::size_type i = 0; i < parameters.size(); ++i) { const XlaCompiler::Argument& arg = args[parameters[i]]; VLOG(2) << " XLA arg " << i - << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i]) + << " shape: " << xla::ShapeUtil::HumanString((*input_shapes)[i]) << " name: " << arg.name << " TF arg " << parameters[i]; XlaExpression& arg_expression = (*arg_expressions)[parameters[i]]; switch (arg.kind) { From 11a19750bcb5802d44d51bb4baf74b3fc9ac52bc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 8 Jan 2018 20:53:48 -0800 Subject: [PATCH 14/23] subtract the task's size from the batch's when removing it. PiperOrigin-RevId: 181262266 --- tensorflow/contrib/batching/batch_scheduler.h | 1 + tensorflow/contrib/batching/batch_scheduler_test.cc | 2 ++ 2 files changed, 3 insertions(+) diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/batching/batch_scheduler.h index e18cf6c3505..aa8891ab4ef 100644 --- a/tensorflow/contrib/batching/batch_scheduler.h +++ b/tensorflow/contrib/batching/batch_scheduler.h @@ -210,6 +210,7 @@ std::unique_ptr Batch::RemoveTask() { return nullptr; } std::unique_ptr task = std::move(tasks_.back()); + size_ -= task->size(); tasks_.pop_back(); return task; } diff --git a/tensorflow/contrib/batching/batch_scheduler_test.cc b/tensorflow/contrib/batching/batch_scheduler_test.cc index f15d8cc8e57..b627fee972a 100644 --- a/tensorflow/contrib/batching/batch_scheduler_test.cc +++ b/tensorflow/contrib/batching/batch_scheduler_test.cc @@ -74,7 +74,9 @@ TEST(BatchTest, Basic) { EXPECT_EQ(task1->size(), batch.task(1).size()); EXPECT_EQ(7, batch.RemoveTask()->size()); + EXPECT_EQ(3, batch.size()); EXPECT_EQ(3, batch.RemoveTask()->size()); + EXPECT_EQ(0, batch.size()); EXPECT_TRUE(batch.empty()); } From c17abb0aafab284f3a8470e57e3479db5008fb52 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Mon, 8 Jan 2018 21:11:20 -0800 Subject: [PATCH 15/23] Avoid deadlock by sleeping for a longer duration. PiperOrigin-RevId: 181263637 --- .../data/python/kernel_tests/interleave_dataset_op_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index e13c60c9a71..b1937c08f34 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -524,7 +524,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.read_coordination_events[expected_element].acquire() else: self.write_coordination_events[expected_element].set() - time.sleep(0.1) # Sleep to consistently "avoid" the race condition. + time.sleep(0.5) # Sleep to consistently "avoid" the race condition. actual_element = sess.run(self.next_element) if not done_first_event: done_first_event = True @@ -611,7 +611,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.read_coordination_events[expected_element].acquire() else: self.write_coordination_events[expected_element].set() - time.sleep(0.1) # Sleep to consistently "avoid" the race condition. + time.sleep(0.5) # Sleep to consistently "avoid" the race condition. actual_element = sess.run(self.next_element) if not done_first_event: done_first_event = True From e6e31d0c2a118348c76306bcaba50b943f239c9a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 8 Jan 2018 21:45:02 -0800 Subject: [PATCH 16/23] [XLA:Tool] Fix an error message in the Hlo parser: should report unexpected subattribute's location correctly. PiperOrigin-RevId: 181265992 --- tensorflow/compiler/xla/tools/parser/hlo_parser.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index ba2b6c0dad2..75bedfabe27 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -1851,7 +1851,7 @@ bool HloParser::ParseWindow(Window* window) { if (field_name == "rhs_reversal") { return ParseDxD("rhs_reversal", &rhs_reversal); } - return Error(loc, StrCat("unexpected attribute name: ", field_name)); + return Error(attr_loc, StrCat("unexpected attribute name: ", field_name)); }(); if (!ok) { return false; From 3a0f98cf806c612e6895dd26f706c7c18efbac1b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 9 Jan 2018 03:49:52 -0800 Subject: [PATCH 17/23] Forward-mode automatic differentiation for Tensorflow. PiperOrigin-RevId: 181297995 --- tensorflow/contrib/nn/BUILD | 15 ++++ .../contrib/nn/python/ops/fwd_gradients.py | 76 +++++++++++++++++++ .../nn/python/ops/fwd_gradients_test.py | 52 +++++++++++++ 3 files changed, 143 insertions(+) create mode 100644 tensorflow/contrib/nn/python/ops/fwd_gradients.py create mode 100644 tensorflow/contrib/nn/python/ops/fwd_gradients_test.py diff --git a/tensorflow/contrib/nn/BUILD b/tensorflow/contrib/nn/BUILD index 56a24ac77f0..5543eb6c6e3 100644 --- a/tensorflow/contrib/nn/BUILD +++ b/tensorflow/contrib/nn/BUILD @@ -17,6 +17,7 @@ py_library( "python/ops/__init__.py", "python/ops/alpha_dropout.py", "python/ops/cross_entropy.py", + "python/ops/fwd_gradients.py", "python/ops/sampling_ops.py", "python/ops/scaled_softplus.py", ], @@ -28,6 +29,7 @@ py_library( "//tensorflow/python:embedding_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:function", + "//tensorflow/python:gradients", "//tensorflow/python:math_ops", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", @@ -55,6 +57,19 @@ py_test( ], ) +py_test( + name = "fwd_gradients_test", + size = "small", + srcs = ["python/ops/fwd_gradients_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":nn_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:math_ops", + ], +) + py_test( name = "sampling_ops_test", size = "small", diff --git a/tensorflow/contrib/nn/python/ops/fwd_gradients.py b/tensorflow/contrib/nn/python/ops/fwd_gradients.py new file mode 100644 index 00000000000..922497779b1 --- /dev/null +++ b/tensorflow/contrib/nn/python/ops/fwd_gradients.py @@ -0,0 +1,76 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Forward-mode derivatives.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops.gradients_impl import gradients + + +def fwd_gradients(ys, xs, grad_xs=None, assert_unused=False): + """Computes forward-mode derivatives. + + This is accomplished in pure-python using tensorflow's existing (reverse-mode) + gradients. There is additional overhead on graph construction, but runtime + performance should be equal to a manual implementation [citation needed]. + + See https://j-towns.github.io/2017/06/12/A-new-trick.html and + https://github.com/HIPS/autograd/pull/175 for the original discussion of this + method, and https://github.com/renmengye/tensorflow-forward-ad for a "direct" + implementation. + + Args: + ys: A list of tensors. + xs: A list of tensors. + grad_xs: An optional list of tensors. If provided, must have the same length + and shapes compatible with xs. + assert_unused: Add assertions that intermediate values are not computed. + Returns: + A list of tensors of the same shapes as ys. The directional derivatives of + ys with respect to xs in the direction grad_xs. Leaving grad_xs unspecified + is equivalent to passing in 1s for each x in xs. + """ + # This version of forward-mode autodiff is based on code by Tim Cooijmans + # and handles list arguments and certain special cases such as when the + # ys doesn't depend on one or more of the xs, and when tf.IndexedSlices are + # generated by the first tf.gradients call. + + us = [array_ops.zeros_like(y) + float('nan') for y in ys] + + dydxs = gradients(ys, xs, grad_ys=us) + + # deal with strange types that tf.gradients returns but can't deal with + dydxs = [ops.convert_to_tensor(dydx) if isinstance(dydx, ops.IndexedSlices) + else dydx for dydx in dydxs] + + if assert_unused: + with ops.control_dependencies(dydxs): + assert_unused = control_flow_ops.Assert(False, [1], name='fwd_gradients') + with ops.control_dependencies([assert_unused]): + dydxs = array_ops.identity_n(dydxs) + + dydxs = [array_ops.zeros_like(x) if dydx is None else dydx + for x, dydx in zip(xs, dydxs)] + for x, dydx in zip(xs, dydxs): + dydx.set_shape(x.shape) + + dysdx = gradients(dydxs, us, grad_ys=grad_xs) + + return dysdx diff --git a/tensorflow/contrib/nn/python/ops/fwd_gradients_test.py b/tensorflow/contrib/nn/python/ops/fwd_gradients_test.py new file mode 100644 index 00000000000..56062c3cab3 --- /dev/null +++ b/tensorflow/contrib/nn/python/ops/fwd_gradients_test.py @@ -0,0 +1,52 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for forward_ad.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.nn.python.ops import fwd_gradients +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class ForwardAdTest(test.TestCase): + + def testSquare(self): + x = constant_op.constant(1.) + y = math_ops.square(x) + grad_x = 3. + + dydx_tf = fwd_gradients.fwd_gradients([y], [x], [grad_x])[0] + dydx_py = 2. * grad_x + + with self.test_session() as sess: + self.assertAllClose(sess.run(dydx_tf), dydx_py, 1e-6) + + def testGather(self): + x = constant_op.constant([1., 2., 3.]) + y = array_ops.gather(x, [0, 1]) + y.set_shape([2]) + dydx = fwd_gradients.fwd_gradients([y], [x], assert_unused=True) + + with self.test_session() as sess: + sess.run(dydx) + + +if __name__ == '__main__': + test.main() From 14fa431da0b7fd69ccf7bf4a60172a5745c1773c Mon Sep 17 00:00:00 2001 From: Ilya Biryukov Date: Tue, 9 Jan 2018 05:27:43 -0800 Subject: [PATCH 18/23] Don't pass '-no-canonical-prefixes' when collecting builtin includes. It matches the way bazel's autoconf works and seems to be the right thing to do. This change should fix #14380. PiperOrigin-RevId: 181305871 --- third_party/gpus/cuda_configure.bzl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 5f1c42dbe48..c09bb222e8c 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -110,11 +110,7 @@ def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp): lang = "c++" else: lang = "c" - # TODO: We pass -no-canonical-prefixes here to match the compiler flags, - # but in cuda_clang CROSSTOOL file that is a `feature` and we should - # handle the case when it's disabled and no flag is passed - result = repository_ctx.execute([cc, "-no-canonical-prefixes", - "-E", "-x" + lang, "-", "-v"]) + result = repository_ctx.execute([cc, "-E", "-x" + lang, "-", "-v"]) index1 = result.stderr.find(_INC_DIR_MARKER_BEGIN) if index1 == -1: return [] From c5e2b0e5a3039fe98b0f22154c567c2eb425fb22 Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Tue, 9 Jan 2018 07:37:55 -0800 Subject: [PATCH 19/23] Exposes runmetadata from tfe in python. PiperOrigin-RevId: 181317960 --- tensorflow/python/eager/context.py | 59 ++++++++++++++++++++++++++++ tensorflow/python/eager/core_test.py | 14 +++++++ tensorflow/python/pywrap_tfe.i | 3 ++ 3 files changed, 76 insertions(+) diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 3173afc4240..e1ab1e7bc64 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -24,9 +24,12 @@ import copy import random import threading +from tensorflow.core.protobuf import config_pb2 from tensorflow.python import pywrap_tensorflow +from tensorflow.python.framework import c_api_util from tensorflow.python.framework import device as pydev from tensorflow.python.framework import errors +from tensorflow.python.util import compat from tensorflow.python.util import tf_contextlib GRAPH_MODE = 0 @@ -398,6 +401,36 @@ class Context(object): """Get the list of post-execution callbacks added to the context.""" return self._post_execution_callbacks + def enable_run_metadata(self): + """Enables tracing of op execution via RunMetadata. + + To retrieve the accumulated metadata call context.export_run_metadata() + and to stop tracing call context.disable_run_metadata(). + """ + pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._context_handle) + + def disable_run_metadata(self): + """Disables tracing of op execution via RunMetadata.""" + pywrap_tensorflow.TFE_ContextDisableRunMetadata(self._context_handle) + + def export_run_metadata(self): + """Returns a RunMetadata proto with accumulated information. + + The returned protocol buffer contains information since the most recent call + to either enable_run_metadata or export_run_metadata. + + Returns: + A RunMetadata protocol buffer. + """ + with c_api_util.tf_buffer() as buffer_: + with errors.raise_exception_on_not_ok_status() as status: + pywrap_tensorflow.TFE_ContextExportRunMetadata( + self._context_handle, buffer_, status) + proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_) + run_metadata = config_pb2.RunMetadata() + run_metadata.ParseFromString(compat.as_bytes(proto_data)) + return run_metadata + _context = None _context_lock = threading.Lock() @@ -516,3 +549,29 @@ def num_gpus(): The number of available GPU devices. """ return context().num_gpus() + + +def enable_run_metadata(): + """Enables tracing of op execution via RunMetadata. + + To retrieve the accumulated metadata call context.export_run_metadata() + and to stop tracing call context.disable_run_metadata(). + """ + context().enable_run_metadata() + + +def disable_run_metadata(): + """Disables tracing of op execution via RunMetadata.""" + context().disable_run_metadata() + + +def export_run_metadata(): + """Returns a RunMetadata proto with accumulated information. + + The returned protocol buffer contains information since the most recent call + to either enable_run_metadata or export_run_metadata. + + Returns: + A RunMetadata protocol buffer. + """ + return context().export_run_metadata() diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index 02694b34fe9..a70fa728048 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -84,6 +84,20 @@ class TFETest(test_util.TensorFlowTestCase): self.assertTrue(has_cpu_device) del ctx + def testRunMetadata(self): + context.enable_run_metadata() + t = constant_op.constant(1.0) + _ = t + t # Runs an operation which will be in the RunMetadata + run_metadata = context.export_run_metadata() + context.disable_run_metadata() + step_stats = run_metadata.step_stats + self.assertGreater(len(step_stats.dev_stats), 0) + cpu_stats = step_stats.dev_stats[0] + self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0', + cpu_stats.device) + self.assertEqual(len(cpu_stats.node_stats), 1) + self.assertEqual(cpu_stats.node_stats[0].node_name, 'Add') + def testContextStackContainsEagerMode(self): # Eager execution has been enabled, and no other context # switch has occurred, so `context_stack` should contain diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 42e4773df3e..083931aa836 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -20,6 +20,9 @@ limitations under the License. %rename("%s") TFE_ContextListDevices; %rename("%s") TFE_ContextAddFunction; %rename("%s") TFE_ContextAddFunctionDef; +%rename("%s") TFE_ContextEnableRunMetadata; +%rename("%s") TFE_ContextDisableRunMetadata; +%rename("%s") TFE_ContextExportRunMetadata; %rename("%s") TFE_ContextClearCaches; %rename("%s") TFE_OpNameGetAttrType; %rename("%s") TFE_Py_InitEagerTensor; From 4b7d9fec149b3bd96e3eea72c0c3475b293feeb0 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Tue, 9 Jan 2018 10:06:58 -0800 Subject: [PATCH 20/23] [XLA] Run the LLVM verifier after lowering HLO -> LLVM IR. This way if we generate bad IR, we emit a nice error message, instead of (probably) crashing somewhere in LLVM. PiperOrigin-RevId: 181334588 --- .../compiler/xla/service/cpu/cpu_compiler.cc | 19 +++++++++++++++++++ .../compiler/xla/service/gpu/gpu_compiler.cc | 15 +++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 27bdde41af0..705bcb2e9bd 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -31,6 +31,7 @@ limitations under the License. #include "llvm/IR/Function.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" #include "llvm/Object/ObjectFile.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/TargetRegistry.h" @@ -443,6 +444,21 @@ Status InitializeModuleHooks( return Status::OK(); } +Status VerifyLlvmModule(const llvm::Module& llvm_module) { + XLA_SCOPED_LOGGING_TIMER("CpuCompiler - Running LLVM verifier"); + + std::string err; + llvm::raw_string_ostream err_stream(err); + + // verifyModule() returns true if the module is broken. + TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream)) + << "Invalid LLVM IR before optimizations:\n" + << err_stream.str() + << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. " + "Rerun with --xla_dump_ir_to to get the IR. "; + return Status::OK(); +} + } // namespace StatusOr> CpuCompiler::RunHloPasses( @@ -631,6 +647,7 @@ StatusOr> CpuCompiler::RunBackend( if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); } + TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); // JIT compile the LLVM IR module to in-memory machine code. jit->AddModule(std::move(llvm_module)); @@ -708,6 +725,7 @@ StatusOr> CpuCompiler::RunBackend( if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); } + TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module)); @@ -879,6 +897,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, &module_sequence.at(computation))); CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); + TF_RETURN_IF_ERROR(VerifyLlvmModule(llvm_module)); ModuleHook pre_optimization_ir_dump_hook; ModuleHook post_optimization_ir_dump_hook; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 93db9ebbee2..9321429bdcc 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -25,6 +25,7 @@ limitations under the License. #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" @@ -478,6 +479,20 @@ StatusOr> GpuCompiler::RunBackend( entry_computation->root_instruction()->Accept(&ir_emitter)); } + { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier"); + + std::string err; + llvm::raw_string_ostream err_stream(err); + + // verifyModule() returns true if the module is broken. + TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream)) + << "Invalid LLVM IR before optimizations:\n" + << err_stream.str() + << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. " + "Rerun with --xla_dump_ir_to to get the IR. "; + } + if (user_pre_optimization_hook_) { TF_CHECK_OK(user_pre_optimization_hook_(llvm_module)); } From 9fcaafd4a9fe6f5868c875749b85c4641c217350 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 9 Jan 2018 10:41:06 -0800 Subject: [PATCH 21/23] Fix typo in ValueError string. PiperOrigin-RevId: 181340112 --- tensorflow/python/estimator/estimator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index bf175cbe01e..d0f40bd68e0 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -614,7 +614,7 @@ class Estimator(object): if isinstance(result, (list, tuple)): if len(result) != 2: raise ValueError( - 'input_fn should return (feautures, labels) as a len 2 tuple.') + 'input_fn should return (features, labels) as a len 2 tuple.') return result[0], result[1], input_hooks return result, None, input_hooks From 611f18f179e8f0a3b5df59477b4bc1d92fd3a7a4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 9 Jan 2018 10:49:16 -0800 Subject: [PATCH 22/23] Fix _warmstart_var_with_vocab to truly accept either a checkpoint dir or a direct path to checkpoint files. PiperOrigin-RevId: 181341437 --- .../python/estimator/warm_starting_util.py | 5 ++- .../estimator/warm_starting_util_test.py | 42 +++++++++++++++++-- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/estimator/warm_starting_util.py b/tensorflow/python/estimator/warm_starting_util.py index 476776daa8f..37ac8515cb8 100644 --- a/tensorflow/python/estimator/warm_starting_util.py +++ b/tensorflow/python/estimator/warm_starting_util.py @@ -121,7 +121,10 @@ class _WarmStartSettings( # where ws could be defined as: # Warm-start all weights in the model (input layer and hidden weights). + # Either the directory or a specific checkpoint can be provided (in the case + # of the former, the latest checkpoint will be used). ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp") + ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000") # Warm-start only the embeddings (input layer). ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp", @@ -348,7 +351,7 @@ def _warmstart_var_with_vocab(var, # TODO(vihanjain): Support _WarmstartSettings where class vocabularies need # remapping too. init = checkpoint_ops._load_and_remap_matrix_initializer( - ckpt_path=saver.latest_checkpoint(prev_ckpt), + ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt), old_tensor_name=prev_tensor_name, new_row_vocab_size=current_vocab_size, new_col_vocab_size=v_shape[1], diff --git a/tensorflow/python/estimator/warm_starting_util_test.py b/tensorflow/python/estimator/warm_starting_util_test.py index cf502dd60de..cc0c4efc756 100644 --- a/tensorflow/python/estimator/warm_starting_util_test.py +++ b/tensorflow/python/estimator/warm_starting_util_test.py @@ -50,9 +50,7 @@ class WarmStartingUtilTest(test.TestCase): sess.run(variables.global_variables_initializer()) saver = saver_lib.Saver() ckpt_prefix = os.path.join(self.get_temp_dir(), "model") - ckpt_state_name = "checkpoint" - saver.save( - sess, ckpt_prefix, global_step=0, latest_filename=ckpt_state_name) + saver.save(sess, ckpt_prefix, global_step=0) def _create_prev_run_var(self, var_name, @@ -408,6 +406,44 @@ class WarmStartingUtilTest(test.TestCase): self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]}, sess) + def testWarmStart_ExplicitCheckpointFile(self): + # Create vocab for sparse column "sc_vocab". + vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], + "vocab") + # Create feature column. + sc_vocab = fc.categorical_column_with_vocabulary_file( + "sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4) + + # Save checkpoint from which to warm-start. + _, prev_vocab_val = self._create_prev_run_var( + "linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones()) + + partitioner = lambda shape, dtype: [1] * len(shape) + # New graph, new session WITHOUT warmstarting. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + cols_to_vars = self._create_linear_model([sc_vocab], partitioner) + sess.run(variables.global_variables_initializer()) + # Without warmstarting, the weights should be initialized using default + # initializer (which is init_ops.zeros_initializer). + self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [np.zeros([4, 1])]}, + sess) + + # New graph, new session with warmstarting. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + cols_to_vars = self._create_linear_model([sc_vocab], partitioner) + # Since old vocab is not explicitly set in WarmStartSettings, the old + # vocab is assumed to be same as new vocab. + ws_util._warmstart(ws_util._WarmStartSettings( + # Explicitly provide the file prefix instead of just the dir. + os.path.join(self.get_temp_dir(), "model-0"), + vars_to_warmstart=".*sc_vocab.*")) + sess.run(variables.global_variables_initializer()) + # Verify weights were correctly warmstarted. + self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]}, + sess) + def testWarmStart_SparseColumnVocabularyConstrainedVocabSizes(self): # Create old vocabulary, and use a size smaller than the total number of # entries. From c4889755e0ea8dca3d8b14d7dfce7560e545fc2f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 9 Jan 2018 10:51:31 -0800 Subject: [PATCH 23/23] Adding a CLIF library for generic_tree_model.proto PiperOrigin-RevId: 181341793 --- tensorflow/contrib/decision_trees/proto/BUILD | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD index 87c80740a8f..f6de5998d73 100644 --- a/tensorflow/contrib/decision_trees/proto/BUILD +++ b/tensorflow/contrib/decision_trees/proto/BUILD @@ -7,7 +7,11 @@ exports_files([ "generic_tree_model_proto.swig", ]) -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library", + "tf_pyclif_proto_library", +) filegroup( name = "all_files", @@ -34,3 +38,10 @@ tf_proto_library( protodeps = [":generic_tree_model"], visibility = ["//visibility:public"], ) + +tf_pyclif_proto_library( + name = "generic_tree_model_pyclif", + proto_lib = ":generic_tree_model", + proto_srcfile = "generic_tree_model.proto", + visibility = ["//visibility:public"], +)