diff --git a/RELEASE.md b/RELEASE.md index 4280d5dd156..4698e65fd87 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -41,6 +41,15 @@ be replaced by calling `embedding_lookup` or `layers.dense` as pre- or post- processing of the rnn. For RNN decoding, this functionality has been replaced with an alternative API in `tf.contrib.seq2seq`. +* Intel MKL Integration (https://software.intel.com/en-us/articles/tensorflow-optimizations-on-modern-intel-architecture). Intel developed a number of + optimized deep learning primitives: In addition to matrix multiplication and + convolution, these building blocks include: + Direct batched convolution + Pooling: maximum, minimum, average + Normalization: LRN, batch normalization + Activation: rectified linear unit (ReLU) + Data manipulation: multi-dimensional transposition (conversion), split, + concat, sum and scale. * TensorForest Estimator now supports SavedModel export for serving. * Support client-provided ClusterSpec's and propagate them to all workers to enable the creation of dynamic TensorFlow clusters. * TensorFlow C library now available for Windows. diff --git a/WORKSPACE b/WORKSPACE index edf655f6a7b..c9d7b458a90 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -2,11 +2,11 @@ workspace(name = "org_tensorflow") http_archive( name = "io_bazel_rules_closure", - sha256 = "4be8a887f6f38f883236e77bb25c2da10d506f2bf1a8e5d785c0f35574c74ca4", - strip_prefix = "rules_closure-aac19edc557aec9b603cd7ffe359401264ceff0d", + sha256 = "edc91f556b762fc5212d1050d00b12e40dd0b0b1c1d5d96886b59e9a30a6cae4", + strip_prefix = "rules_closure-3f07fb6a58870afbb36051bd5d54da4479561cc6", urls = [ - "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/aac19edc557aec9b603cd7ffe359401264ceff0d.tar.gz", # 2017-05-10 - "https://github.com/bazelbuild/rules_closure/archive/aac19edc557aec9b603cd7ffe359401264ceff0d.tar.gz", + "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/3f07fb6a58870afbb36051bd5d54da4479561cc6.tar.gz", # 2017-05-31 + "https://github.com/bazelbuild/rules_closure/archive/3f07fb6a58870afbb36051bd5d54da4479561cc6.tar.gz", ], ) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index bf27c360184..abbfe8d54af 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -393,6 +393,9 @@ filegroup( "//tensorflow/tensorboard/demo:all_files", "//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files", "//tensorflow/tensorboard/plugins:all_files", + "//tensorflow/tensorboard/plugins/audio:all_files", + "//tensorflow/tensorboard/plugins/distributions:all_files", + "//tensorflow/tensorboard/plugins/graphs:all_files", "//tensorflow/tensorboard/plugins/histograms:all_files", "//tensorflow/tensorboard/plugins/images:all_files", "//tensorflow/tensorboard/plugins/projector:all_files", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index a9644a5555d..77faa475ed4 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -805,6 +805,7 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, } std::vector dim_vec; + dim_vec.reserve(num_dims); for (int i = 0; i < num_dims; ++i) { dim_vec.push_back(ic->MakeDim(dims[i])); } diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc index 2879445441d..ba056a8f3a8 100644 --- a/tensorflow/cc/client/client_session.cc +++ b/tensorflow/cc/client/client_session.cc @@ -113,10 +113,12 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs, feeds.emplace_back(feed.first.name(), feed.second.tensor); } std::vector output_tensor_names; + output_tensor_names.reserve(fetch_outputs.size()); for (auto const& output : fetch_outputs) { output_tensor_names.push_back(output.name()); } std::vector target_node_names; + target_node_names.reserve(run_outputs.size()); for (auto const& output : run_outputs) { target_node_names.push_back(output.node()->name()); } diff --git a/tensorflow/cc/framework/gradient_checker.cc b/tensorflow/cc/framework/gradient_checker.cc index 8f20ff1457b..b8e5411bf71 100644 --- a/tensorflow/cc/framework/gradient_checker.cc +++ b/tensorflow/cc/framework/gradient_checker.cc @@ -44,6 +44,7 @@ Status ComputeTheoreticalJacobianTranspose( size_t x_num = x_shapes.size(); // Call AddSymbolicGradients to get 'dxs' (we will feed 'dys'). OutputList dys; + dys.reserve(y_shapes.size()); for (const auto& y_shape : y_shapes) { // TODO(suharshs): This currently assumes that all x's are the same type. dys.push_back(Cast(scope, Const(scope, 1.0, y_shape), xs[0].type())); diff --git a/tensorflow/cc/framework/testutil.cc b/tensorflow/cc/framework/testutil.cc index b0746913a16..ca78f31db51 100644 --- a/tensorflow/cc/framework/testutil.cc +++ b/tensorflow/cc/framework/testutil.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/cc/framework/testutil.h" +#include + #include "tensorflow/cc/client/client_session.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/graph/default_device.h" @@ -30,7 +32,7 @@ void GetTensors(const Scope& scope, OutputList tensors, void GetTensor(const Scope& scope, Output tensor, Tensor* out) { std::vector outputs; - GetTensors(scope, {tensor}, &outputs); + GetTensors(scope, {std::move(tensor)}, &outputs); *out = outputs[0]; } diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index bb79fe81ab3..ca17c5ab690 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -350,6 +350,7 @@ Status CompileXla(xla::CompileOnlyClient* client, compile_result->program_shape = *pshape_or.ValueOrDie(); xla::ProgramShape* pshape = &compile_result->program_shape; std::vector arg_layouts; + arg_layouts.reserve(pshape->parameters_size()); for (int i = 0; i < pshape->parameters_size(); ++i) { arg_layouts.push_back(pshape->mutable_parameters(i)); } diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 0fe0d8e89b2..277e3c99068 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -218,6 +218,7 @@ cc_library( deps = [ ":common", ":graph_to_functiondef", + ":union_find", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/kernels:parallel_check_op", "//tensorflow/compiler/jit/kernels:xla_local_launch_op", @@ -237,6 +238,11 @@ cc_library( ], ) +cc_library( + name = "union_find", + hdrs = ["union_find.h"], +) + cc_test( name = "compilation_passes_test", size = "small", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index a8869c8e2a7..4a1dbaf05dc 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/cc/framework/ops.h" @@ -101,12 +103,12 @@ Node* Input(const GraphDefBuilder::Options& opts) { } Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) { - return ops::UnaryOp("UnaryTest", a, opts); + return ops::UnaryOp("UnaryTest", std::move(a), opts); } Node* Binary(ops::NodeOut a, ops::NodeOut b, const GraphDefBuilder::Options& opts) { - return ops::BinaryOp("BinaryTest", a, b, opts); + return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts); } Node* AddNLike(const std::vector& inputs, @@ -127,7 +129,7 @@ Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("Retval"), "_Retval", opts.op_registry()); - node_builder.Input(a).Attr("index", index); + node_builder.Input(std::move(a)).Attr("index", index); return opts.FinalizeBuilder(&node_builder); } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index ed9d9ad70e4..f1fef85f994 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/function.h" @@ -206,70 +207,12 @@ Status FindCompilationCandidates( return Status::OK(); } -// Union-Find data structure used to compute clusters. We use our own -// implementation because we want one key feature: when merging clusters, we -// need to know which value becomes the representative of the merged clusters. -// We use the representatives to name nodes in a cycle detection graph, and we -// need to control which node is named. -// TODO(phawkins): consider merging this code with union-find implementations -// in Tensorflow, e.g., in SimplePlacer. -class Cluster { - public: - Cluster(); - - int Size() { return FindRoot()->size_; } - - // Merges this cluster with 'other'. This cluster's representative becomes - // the representative of the merged cluster; the representative of 'other' - // is ignored. - void Merge(Cluster* other); - - // Each cluster has an associated integer 'representative', initialized to -1 - // by default. - int GetRepresentative() { return FindRoot()->representative_; } - void SetRepresentative(int representative) { - FindRoot()->representative_ = representative; - } - - private: - // Finds the root element of the cluster. Performs path compression. - Cluster* FindRoot(); - - int representative_; - int rank_; - int size_; // Size of the cluster. - Cluster* parent_; +struct Cluster { + // Identifies the node that represents this cluster in the cycle detection + // graph. + int representative = -1; }; -Cluster::Cluster() - : representative_(-1), rank_(0), size_(1), parent_(nullptr) {} - -void Cluster::Merge(Cluster* other) { - Cluster* a = FindRoot(); - Cluster* b = other->FindRoot(); - if (a == b) return; - if (a->rank_ > b->rank_) { - b->parent_ = a; - a->size_ += b->size_; - return; - } - - a->parent_ = b; - if (a->rank_ == b->rank_) { - b->rank_++; - } - b->representative_ = a->representative_; - b->size_ += a->size_; -} - -Cluster* Cluster::FindRoot() { - if (!parent_) return this; - // Path compression: update intermediate nodes to point to the root of the - // equivalence class. - parent_ = parent_->FindRoot(); - return parent_; -} - } // anonymous namespace bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { @@ -432,10 +375,11 @@ Status MarkForCompilationPass::RunImpl( // Each compilation candidate belongs to a cluster. The cluster's // representative // names the node in the 'cycles' graph that represents the cluster. - std::vector clusters(graph->num_node_ids()); - std::deque worklist; + std::vector> clusters(graph->num_node_ids()); + std::deque*> worklist; for (Node* node : compilation_candidates) { - clusters[node->id()].SetRepresentative(node->id()); + Cluster& cluster = clusters[node->id()].Get(); + cluster.representative = node->id(); worklist.push_back(&clusters[node->id()]); } @@ -445,7 +389,7 @@ Status MarkForCompilationPass::RunImpl( // Repeatedly contract edges between clusters that are on the same device, // provided the contraction would not create a cycle. while (!worklist.empty()) { - int from = worklist.front()->GetRepresentative(); + int from = worklist.front()->Get().representative; worklist.pop_front(); Node* node_from = graph->FindNodeId(from); @@ -518,7 +462,7 @@ Status MarkForCompilationPass::RunImpl( // Count the number of elements in each cluster. std::vector cluster_sizes(graph->num_node_ids()); for (const Node* n : compilation_candidates) { - int cluster = clusters[n->id()].GetRepresentative(); + int cluster = clusters[n->id()].Get().representative; cluster_sizes[cluster]++; } @@ -532,7 +476,7 @@ Status MarkForCompilationPass::RunImpl( // if compilation is enabled, otherwise there will be no such candidates). const int min_cluster_size = flags->tf_xla_min_cluster_size; for (Node* n : compilation_candidates) { - int cluster = clusters[n->id()].GetRepresentative(); + int cluster = clusters[n->id()].Get().representative; // Compile if the user marked this node _XlaCompile=true bool compile_attr = false; diff --git a/tensorflow/compiler/jit/union_find.h b/tensorflow/compiler/jit/union_find.h new file mode 100644 index 00000000000..a1a7a6a4d0d --- /dev/null +++ b/tensorflow/compiler/jit/union_find.h @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_UNION_FIND_H_ +#define TENSORFLOW_COMPILER_JIT_UNION_FIND_H_ + +namespace tensorflow { + +// Union-Find data structure. +// Each cluster has an associated value; when merging clusters we can control +// which value becomes the representative of the merged clusters. Values must be +// copyable. +template +class UnionFind { + public: + UnionFind() : rank_(0), size_(1), parent_(nullptr) {} + + // Returns the number of elements in a cluster. + int Size() { return FindRoot()->size_; } + + // Merges this cluster with 'other'. This cluster's value becomes + // the value of the merged cluster; the value of 'other' is ignored. + void Merge(UnionFind* other); + + // Each cluster has an associated value. Retrieves the value associated + // with this cluster. + T& Get() { return FindRoot()->value_; } + + private: + // Finds the root element of the cluster. Performs path compression. + UnionFind* FindRoot(); + + int rank_; + int size_; // Size of the cluster. + UnionFind* parent_; + T value_; +}; + +template +void UnionFind::Merge(UnionFind* other) { + UnionFind* a = FindRoot(); + UnionFind* b = other->FindRoot(); + if (a == b) return; + if (a->rank_ > b->rank_) { + b->parent_ = a; + a->size_ += b->size_; + return; + } + + a->parent_ = b; + if (a->rank_ == b->rank_) { + b->rank_++; + } + b->value_ = a->value_; + b->size_ += a->size_; +} + +template +UnionFind* UnionFind::FindRoot() { + if (!parent_) return this; + // Path compression: update intermediate nodes to point to the root of the + // equivalence class. + parent_ = parent_->FindRoot(); + return parent_; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_UNION_FIND_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index 1b5f94d4e5f..1e1d2a1b4b3 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -50,6 +50,7 @@ class FillOp : public XlaOpKernel { // Convert the dims literal into a vector that we can pass to // ComputationBuilder. std::vector broadcast; + broadcast.reserve(dims_literal.shape().dimensions(0)); for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) { broadcast.push_back(xla::LiteralUtil::Get(dims_literal, {i})); } diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 87cd266708b..51c97d85d7f 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -50,6 +50,7 @@ class SliceOp : public XlaOpKernel { // slice will be an empty handle if the output has no elements. CHECK_EQ(begin.size(), size.size()); std::vector limits; + limits.reserve(begin.size()); for (int i = 0; i < begin.size(); ++i) { limits.push_back(begin[i] + size[i]); } diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index e8b2233853d..fe08e83c239 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 6b38f856442..454d0fbd965 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -58,14 +58,13 @@ StatusOr> Client::Transfer( "server provided response without a literal in " "TransferToClient request"); } - - return WrapUnique(response.release_literal()); + return MakeUnique(response.literal()); } StatusOr> Client::TransferToServer( const Literal& literal, const DeviceHandle* device_handle) { TransferToServerRequest request; - *request.mutable_literal() = literal; + *request.mutable_literal() = literal.ToProto(); if (device_handle) { *request.mutable_device_handle() = *device_handle; } @@ -93,7 +92,7 @@ StatusOr> Client::TransferToServer( Status Client::TransferToInfeed(const Literal& literal, int64 replica_id, const DeviceHandle* device_handle) { TransferToInfeedRequest request; - *request.mutable_literal() = literal; + *request.mutable_literal() = literal.ToProto(); if (device_handle) { *request.mutable_device_handle() = *device_handle; } @@ -141,7 +140,8 @@ StatusOr> Client::TransferFromOutfeed( "TransferToClient request"); } - return WrapUnique(response.release_literal()); + Literal literal(response.literal()); + return MakeUnique(literal); } Status Client::ResetDevice() { diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index 50de730a52b..797835160fa 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 22a70681468..940d38c44e7 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -165,9 +165,10 @@ ComputationDataHandle ComputationBuilder::ConstantOp( } ConstantRequest request; - Literal* literal = request.mutable_literal(); - populate(literal); - VLOG(3) << "created constant: " << literal->ShortDebugString(); + Literal literal; + populate(&literal); + *request.mutable_literal() = literal.ToProto(); + VLOG(3) << "created constant: " << request.literal().ShortDebugString(); OpRequest op_request; *op_request.mutable_constant_request() = request; *op_request.mutable_computation() = computation_.handle(); diff --git a/tensorflow/compiler/xla/client/global_data.cc b/tensorflow/compiler/xla/client/global_data.cc index be706f7d232..40f59eaa68e 100644 --- a/tensorflow/compiler/xla/client/global_data.cc +++ b/tensorflow/compiler/xla/client/global_data.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include +#include #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" @@ -23,7 +24,7 @@ limitations under the License. namespace xla { GlobalData::GlobalData(ServiceInterface* parent, GlobalDataHandle handle) - : handle_(handle), parent_(parent) {} + : handle_(std::move(handle)), parent_(parent) {} GlobalData::~GlobalData() { UnregisterRequest request; diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 6f2914b4718..96944a53b7e 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -222,8 +222,9 @@ tensorflow::Status LocalExecutable::RecordArguments( SessionModule* session_module) { session_module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { - TF_RETURN_IF_ERROR( - LiteralFromShapedBuffer(*argument, session_module->add_arguments())); + Literal literal; + TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*argument, &literal)); + *session_module->add_arguments() = literal.ToProto(); } return tensorflow::Status::OK(); } @@ -231,9 +232,13 @@ tensorflow::Status LocalExecutable::RecordArguments( tensorflow::Status LocalExecutable::RecordResult( const ShapedBuffer* result, SessionModule* session_module) { session_module->clear_result(); - return LiteralFromShapedBuffer(*result, session_module->mutable_result()); + Literal literal(session_module->result()); + TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*result, &literal)); + *session_module->mutable_result() = literal.ToProto(); + return tensorflow::Status::OK(); } +// TODO(dnovillo) Change signature to return StatusOr. tensorflow::Status LocalExecutable::LiteralFromShapedBuffer( const ShapedBuffer& shaped_buffer, Literal* literal) { TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index ec4012a7036..4648680dc53 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -36,7 +36,7 @@ limitations under the License. namespace xla { -LiteralUtil::StrideConfig::StrideConfig( +Literal::StrideConfig::StrideConfig( const Shape& source_shape, const Shape& dest_shape, tensorflow::gtl::ArraySlice dimensions) : dimensions(dimensions), @@ -59,30 +59,28 @@ LiteralUtil::StrideConfig::StrideConfig( } } -/* static */ std::unique_ptr LiteralUtil::CreateFromShape( - const Shape& shape) { +std::unique_ptr Literal::CreateFromShape(const Shape& shape) { auto literal = MakeUnique(); *literal->mutable_shape() = shape; - Reserve(ShapeUtil::ElementsIn(literal->shape()), literal.get()); + literal->Reserve(ShapeUtil::ElementsIn(literal->shape())); return literal; } -/* static */ std::unique_ptr LiteralUtil::CreateFromDimensions( +/* static */ std::unique_ptr Literal::CreateFromDimensions( PrimitiveType primitive_type, tensorflow::gtl::ArraySlice dimensions) { return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions)); } template -/* static */ Status LiteralUtil::CopyRange( - const Literal& src_literal, tensorflow::gtl::ArraySlice src_base, - Literal* dest_literal, tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { +Status Literal::CopyRange(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { const Shape& src_shape = src_literal.shape(); - const Shape& dest_shape = dest_literal->shape(); - tensorflow::gtl::ArraySlice src_data = GetArraySlice(src_literal); - tensorflow::gtl::MutableArraySlice dest_data = - GetMutableArraySlice(dest_literal); + const Shape& dest_shape = shape(); + tensorflow::gtl::ArraySlice src_data = src_literal.GetArraySlice(); + tensorflow::gtl::MutableArraySlice dest_data = GetMutableArraySlice(); TF_RET_CHECK(ShapeUtil::Rank(src_shape) == src_base.size()); TF_RET_CHECK(ShapeUtil::Rank(dest_shape) == dest_base.size()); @@ -90,8 +88,8 @@ template // If any of the two shapes are scalars, we can just call the StridedCopy() // directly, and we know we will be copying only one value. TF_RET_CHECK(copy_size.empty()); - StridedCopy(dest_data, LinearIndex(*dest_literal, dest_base), 0, src_data, - LinearIndex(src_literal, src_base), 0, 1); + StridedCopy(dest_data, LinearIndex(dest_base), 0, src_data, + src_literal.LinearIndex(src_base), 0, 1); } else if (!ShapeUtil::HasZeroElements(dest_shape)) { TF_RET_CHECK(!ShapeUtil::HasZeroElements(src_shape)); TF_RET_CHECK(src_base.size() == dest_base.size()); @@ -113,8 +111,8 @@ template std::transform(indexes.begin(), indexes.end(), dest_base.begin(), dest_indexes.begin(), std::plus()); - int64 src_index = LinearIndex(src_literal, src_indexes); - int64 dest_index = LinearIndex(*dest_literal, dest_indexes); + int64 src_index = src_literal.LinearIndex(src_indexes); + int64 dest_index = LinearIndex(dest_indexes); StridedCopy(dest_data, dest_index, stride_config.dest_stride, src_data, src_index, stride_config.source_stride, @@ -129,37 +127,28 @@ template return Status::OK(); } -/* static */ Status LiteralUtil::Copy( - const Literal& src_literal, tensorflow::gtl::ArraySlice src_base, - Literal* dest_literal, tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { - TF_RET_CHECK( - ShapeUtil::SameElementType(src_literal.shape(), dest_literal->shape())); +Status Literal::Copy(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { + TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape())); switch (src_literal.shape().element_type()) { case U32: - return CopyRange(src_literal, src_base, dest_literal, dest_base, - copy_size); + return CopyRange(src_literal, src_base, dest_base, copy_size); case U64: - return CopyRange(src_literal, src_base, dest_literal, dest_base, - copy_size); + return CopyRange(src_literal, src_base, dest_base, copy_size); case S32: - return CopyRange(src_literal, src_base, dest_literal, dest_base, - copy_size); + return CopyRange(src_literal, src_base, dest_base, copy_size); case S64: - return CopyRange(src_literal, src_base, dest_literal, dest_base, - copy_size); + return CopyRange(src_literal, src_base, dest_base, copy_size); case F16: - return CopyRange(src_literal, src_base, dest_literal, dest_base, - copy_size); + return CopyRange(src_literal, src_base, dest_base, copy_size); case F32: - return CopyRange(src_literal, src_base, dest_literal, dest_base, - copy_size); + return CopyRange(src_literal, src_base, dest_base, copy_size); case F64: - return CopyRange(src_literal, src_base, dest_literal, dest_base, - copy_size); + return CopyRange(src_literal, src_base, dest_base, copy_size); case PRED: - return CopyRange(src_literal, src_base, dest_literal, dest_base, - copy_size); + return CopyRange(src_literal, src_base, dest_base, copy_size); default: break; } @@ -167,28 +156,28 @@ template src_literal.shape().element_type()); } -/* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) { +/* static */ Literal Literal::Zero(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case U32: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case U64: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case S8: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case S32: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case S64: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case F16: - return *LiteralUtil::CreateR0(static_cast(0.0f)); + return *Literal::CreateR0(static_cast(0.0f)); case F32: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case F64: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case PRED: - return *LiteralUtil::CreateR0(false); + return *Literal::CreateR0(false); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; @@ -201,31 +190,31 @@ template } } -/* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) { +/* static */ Literal Literal::One(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case U32: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case U64: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case S8: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case S32: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case S64: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case F32: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case F64: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case PRED: - return *LiteralUtil::CreateR0(true); + return *Literal::CreateR0(true); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return *LiteralUtil::CreateR0(static_cast(1.0f)); + return *Literal::CreateR0(static_cast(1.0f)); case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 1"; case OPAQUE: @@ -235,33 +224,32 @@ template } } -/* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) { +/* static */ Literal Literal::MinValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return *LiteralUtil::CreateR0(std::numeric_limits::min()); + return *Literal::CreateR0(std::numeric_limits::min()); case U32: - return *LiteralUtil::CreateR0(std::numeric_limits::min()); + return *Literal::CreateR0(std::numeric_limits::min()); case U64: - return *LiteralUtil::CreateR0(std::numeric_limits::min()); + return *Literal::CreateR0(std::numeric_limits::min()); case S8: - return *LiteralUtil::CreateR0(std::numeric_limits::min()); + return *Literal::CreateR0(std::numeric_limits::min()); case S32: - return *LiteralUtil::CreateR0(std::numeric_limits::min()); + return *Literal::CreateR0(std::numeric_limits::min()); case S64: - return *LiteralUtil::CreateR0(std::numeric_limits::min()); + return *Literal::CreateR0(std::numeric_limits::min()); case F32: - return *LiteralUtil::CreateR0( - -std::numeric_limits::infinity()); + return *Literal::CreateR0(-std::numeric_limits::infinity()); case F64: - return *LiteralUtil::CreateR0( + return *Literal::CreateR0( -std::numeric_limits::infinity()); case PRED: - return *LiteralUtil::CreateR0(false); + return *Literal::CreateR0(false); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return *LiteralUtil::CreateR0( + return *Literal::CreateR0( static_cast(-std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; @@ -272,33 +260,32 @@ template } } -/* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) { +/* static */ Literal Literal::MaxValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return *LiteralUtil::CreateR0(std::numeric_limits::max()); + return *Literal::CreateR0(std::numeric_limits::max()); case U32: - return *LiteralUtil::CreateR0(std::numeric_limits::max()); + return *Literal::CreateR0(std::numeric_limits::max()); case U64: - return *LiteralUtil::CreateR0(std::numeric_limits::max()); + return *Literal::CreateR0(std::numeric_limits::max()); case S8: - return *LiteralUtil::CreateR0(std::numeric_limits::max()); + return *Literal::CreateR0(std::numeric_limits::max()); case S32: - return *LiteralUtil::CreateR0(std::numeric_limits::max()); + return *Literal::CreateR0(std::numeric_limits::max()); case S64: - return *LiteralUtil::CreateR0(std::numeric_limits::max()); + return *Literal::CreateR0(std::numeric_limits::max()); case F32: - return *LiteralUtil::CreateR0( - std::numeric_limits::infinity()); + return *Literal::CreateR0(std::numeric_limits::infinity()); case F64: - return *LiteralUtil::CreateR0( + return *Literal::CreateR0( std::numeric_limits::infinity()); case PRED: - return *LiteralUtil::CreateR0(true); + return *Literal::CreateR0(true); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return *LiteralUtil::CreateR0( + return *Literal::CreateR0( static_cast(std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; @@ -309,166 +296,161 @@ template } } -/* static */ std::unique_ptr LiteralUtil::CreateR1( +/* static */ std::unique_ptr Literal::CreateR1( const tensorflow::core::Bitmap& values) { auto literal = MakeUnique(); - PopulateR1(values, literal.get()); + literal->PopulateR1(values); return literal; } -/* static */ std::unique_ptr LiteralUtil::CreateR1U8( +/* static */ std::unique_ptr Literal::CreateR1U8( tensorflow::StringPiece value) { auto literal = MakeUnique(); *literal->mutable_shape() = ShapeUtil::MakeShape(U8, {static_cast(value.size())}); - literal->set_u8s(value.ToString()); + literal->set_u8s(tensorflow::StringPiece(value.ToString())); return literal; } -/* static */ std::unique_ptr LiteralUtil::CreateR2F32Linspace( - float from, float to, int64 rows, int64 cols) { +/* static */ std::unique_ptr Literal::CreateR2F32Linspace(float from, + float to, + int64 rows, + int64 cols) { auto value = MakeLinspaceArray2D(from, to, rows, cols); return CreateR2FromArray2D(*value); } -/* static */ std::unique_ptr LiteralUtil::Relayout( - const Literal& original, const Layout& layout) { - std::unique_ptr result = CloneToUnique(original); +std::unique_ptr Literal::Relayout(const Layout& layout) const { + std::unique_ptr result = CloneToUnique(); *result->mutable_shape()->mutable_layout() = layout; - const Shape& shape = original.shape(); - DimensionVector base(ShapeUtil::Rank(shape), 0); - DimensionVector copy_size(shape.dimensions().begin(), - shape.dimensions().end()); + DimensionVector base(ShapeUtil::Rank(shape()), 0); + DimensionVector copy_size(shape().dimensions().begin(), + shape().dimensions().end()); - TF_CHECK_OK(Copy(original, base, result.get(), base, copy_size)); + TF_CHECK_OK(result->Copy(*this, base, base, copy_size)); return result; } -/* static */ StatusOr> LiteralUtil::Reshape( - const xla::Literal& input, tensorflow::gtl::ArraySlice dimensions) { - if (ShapeUtil::IsTuple(input.shape())) { +StatusOr> Literal::Reshape( + tensorflow::gtl::ArraySlice dimensions) const { + if (ShapeUtil::IsTuple(shape())) { return InvalidArgument("Reshape does not support tuples."); } std::unique_ptr output; - if (!LayoutUtil::IsMonotonicWithDim0Major(input.shape().layout())) { - std::vector minor_to_major(ShapeUtil::Rank(input.shape())); + if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { + std::vector minor_to_major(ShapeUtil::Rank(shape())); std::iota(minor_to_major.rbegin(), minor_to_major.rend(), static_cast(0)); - output = Relayout(input, LayoutUtil::MakeLayout(minor_to_major)); + output = Relayout(LayoutUtil::MakeLayout(minor_to_major)); } else { - output = CloneToUnique(input); + output = CloneToUnique(); } // Because the layout is monotonic, we can simply reuse the same sequence of // values without changing their order. *output->mutable_shape() = - ShapeUtil::MakeShape(input.shape().element_type(), dimensions); + ShapeUtil::MakeShape(shape().element_type(), dimensions); - int64 elements_before = ShapeUtil::ElementsIn(input.shape()); + int64 elements_before = ShapeUtil::ElementsIn(shape()); int64 elements_after = ShapeUtil::ElementsIn(output->shape()); if (elements_before != elements_after) { return InvalidArgument( - "Shapes before and after LiteralUtil::Reshape have different numbers " + "Shapes before and after Literal::Reshape have different numbers " "of elements: %s vs %s.", - ShapeUtil::HumanString(input.shape()).c_str(), + ShapeUtil::HumanString(shape()).c_str(), ShapeUtil::HumanString(output->shape()).c_str()); } return std::move(output); } -/* static */ std::unique_ptr LiteralUtil::Transpose( - const Literal& original, tensorflow::gtl::ArraySlice permutation) { - CHECK(!ShapeUtil::IsTuple(original.shape())) - << "Tuple is not supported for transpose"; - CHECK(IsPermutation(permutation, ShapeUtil::Rank(original.shape()))) +std::unique_ptr Literal::Transpose( + tensorflow::gtl::ArraySlice permutation) const { + CHECK(!ShapeUtil::IsTuple(shape())) << "Tuple is not supported for transpose"; + CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) << "Given permutation is not a permutation of dimension numbers"; // To transpose the array, we just permute the dimensions and layout, and // do a straight memory copy of the raw data set. // This is considerably faster than iterating over every array element using // the EachCell<>() and Set<>() APIs. std::vector inverse_permutation = InversePermutation(permutation); - Shape shape = - ShapeUtil::PermuteDimensions(inverse_permutation, original.shape()); - // Replace the layout with one affine to the original shape, such that a + Shape permuted_shape = + ShapeUtil::PermuteDimensions(inverse_permutation, shape()); + // Replace the layout with one affine to this shape, such that a // transpose operation can be performed by leaving the flat values // representation intact. // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation. // The shape with affine layout resulting from that operation will be - // F32[8,11]{0,1}, since it leave the original most minor (the 8 sized), the + // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the // most minor. // Essentially, given MinMaj(Di) the position of the Di dimension within the // minor to major vector, and given T(Di) the index that the original Di // dimension has within the transposed array, a layout is affine if // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major // vector of the affine layout. - Layout* layout = shape.mutable_layout(); + Layout* layout = permuted_shape.mutable_layout(); layout->clear_minor_to_major(); - for (auto index : original.shape().layout().minor_to_major()) { + for (auto index : shape().layout().minor_to_major()) { layout->add_minor_to_major(inverse_permutation[index]); } - std::unique_ptr new_literal = CreateFromShape(shape); + std::unique_ptr new_literal = CreateFromShape(permuted_shape); DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()), - ShapeUtil::ByteSizeOf(original.shape())); - std::memcpy(MutableInternalData(new_literal.get()), InternalData(original), - ShapeUtil::ByteSizeOf(original.shape())); + ShapeUtil::ByteSizeOf(shape())); + std::memcpy(new_literal->MutableInternalData(), InternalData(), + ShapeUtil::ByteSizeOf(shape())); return new_literal; } -/* static */ std::unique_ptr LiteralUtil::Slice( - const Literal& literal, tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) { - CHECK(!ShapeUtil::IsTuple(literal.shape())) - << "tuple is not supported for reshape"; +std::unique_ptr Literal::Slice( + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) const { + CHECK(!ShapeUtil::IsTuple(shape())) << "tuple is not supported for reshape"; DimensionVector result_dimensions; - for (int64 dnum = 0; dnum < ShapeUtil::Rank(literal.shape()); ++dnum) { + for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) { CHECK_GE(start_indices[dnum], 0); - CHECK_LE(limit_indices[dnum], literal.shape().dimensions(dnum)); + CHECK_LE(limit_indices[dnum], shape().dimensions(dnum)); int64 dimension = limit_indices[dnum] - start_indices[dnum]; CHECK_GT(dimension, 0); result_dimensions.push_back(dimension); } const auto result_shape = ShapeUtil::MakeShapeWithLayout( - literal.shape().element_type(), result_dimensions, - AsInt64Slice(literal.shape().layout().minor_to_major())); + shape().element_type(), result_dimensions, + AsInt64Slice(shape().layout().minor_to_major())); auto result_literal = MakeUnique(); *result_literal->mutable_shape() = result_shape; - Reserve(ShapeUtil::ElementsIn(result_shape), result_literal.get()); + result_literal->Reserve(ShapeUtil::ElementsIn(result_shape)); DimensionVector new_indices(ShapeUtil::Rank(result_shape)); switch (result_shape.element_type()) { case F32: - LiteralUtil::EachCell( - *result_literal, + result_literal->EachCell( [&](tensorflow::gtl::ArraySlice indices, float /*value*/) { for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { new_indices[i] = indices[i] + start_indices[i]; } - float value = LiteralUtil::Get(literal, new_indices); - LiteralUtil::Set(result_literal.get(), indices, value); + float value = Get(new_indices); + result_literal->Set(indices, value); }); return result_literal; case S32: - LiteralUtil::EachCell( - *result_literal, + result_literal->EachCell( [&](tensorflow::gtl::ArraySlice indices, int32 /*value*/) { for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { new_indices[i] = indices[i] + start_indices[i]; } - int32 value = LiteralUtil::Get(literal, new_indices); - LiteralUtil::Set(result_literal.get(), indices, value); + int32 value = Get(new_indices); + result_literal->Set(indices, value); }); return result_literal; case U32: - LiteralUtil::EachCell( - *result_literal, + result_literal->EachCell( [&](tensorflow::gtl::ArraySlice indices, uint32 /*value*/) { for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { new_indices[i] = indices[i] + start_indices[i]; } - uint32 value = LiteralUtil::Get(literal, new_indices); - LiteralUtil::Set(result_literal.get(), indices, value); + uint32 value = Get(new_indices); + result_literal->Set(indices, value); }); return result_literal; default: @@ -477,98 +459,95 @@ template } } -/* static */ std::unique_ptr LiteralUtil::CloneToUnique( - const Literal& literal) { +std::unique_ptr Literal::CloneToUnique() const { auto unique = MakeUnique(); - *unique = literal; + *unique = *this; return unique; } -/* static */ string LiteralUtil::GetAsString( - const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { - switch (literal.shape().element_type()) { +string Literal::GetAsString( + tensorflow::gtl::ArraySlice multi_index) const { + switch (shape().element_type()) { case PRED: - return Get(literal, multi_index) ? "true" : "false"; + return Get(multi_index) ? "true" : "false"; case U8: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); case S32: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); case S64: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); case U32: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); case U64: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); case F32: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); case F64: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); case F16: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); default: return tensorflow::strings::StrCat( - "[", PrimitiveType_Name(literal.shape().element_type()), "]"); + "[", PrimitiveType_Name(shape().element_type()), "]"); } } -/* static */ int64 LiteralUtil::LinearIndex( - const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { - return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(), - multi_index); +int64 Literal::LinearIndex( + tensorflow::gtl::ArraySlice multi_index) const { + return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index); } -/* static */ string LiteralUtil::ToString(const Literal& literal) { - const Shape& shape = literal.shape(); +string Literal::ToString() const { std::vector pieces; auto element_to_string = - [&literal](tensorflow::gtl::ArraySlice indices) -> string { - PrimitiveType element_type = literal.shape().element_type(); + [this](tensorflow::gtl::ArraySlice indices) -> string { + PrimitiveType element_type = shape().element_type(); if (element_type == PRED) { // We display predicates in a densely packed form. - return Get(literal, indices) ? "1" : "0"; + return Get(indices) ? "1" : "0"; } return ((!indices.empty() && indices.back() > 0) ? ", " : "") + - GetAsString(literal, indices); + GetAsString(indices); }; // TODO(b/32894291): refactor this code to reduce code duplication. - if (ShapeUtil::IsTuple(shape)) { - pieces.push_back(ShapeUtil::HumanString(shape)); + if (ShapeUtil::IsTuple(shape())) { + pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" (\n"); - for (const auto& element_literal : literal.tuple_literals()) { - pieces.push_back(ToString(element_literal)); + for (const auto& element_literal : tuple_literals()) { + pieces.push_back(element_literal.ToString()); pieces.push_back(",\n"); } pieces.push_back(")"); - } else if (ShapeUtil::Rank(shape) == 0) { - pieces.push_back(GetAsString(literal, {})); - } else if (ShapeUtil::Rank(shape) == 1) { + } else if (ShapeUtil::Rank(shape()) == 0) { + pieces.push_back(GetAsString({})); + } else if (ShapeUtil::Rank(shape()) == 1) { pieces.push_back("{"); - for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) { + for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(element_to_string({i0})); } pieces.push_back("}"); - } else if (ShapeUtil::Rank(shape) == 2) { - pieces.push_back(ShapeUtil::HumanString(shape)); + } else if (ShapeUtil::Rank(shape()) == 2) { + pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" {\n"); - for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) { + for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(" { "); - for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1) { + for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { pieces.push_back(element_to_string({i0, i1})); } pieces.push_back(" "); pieces.push_back("},\n"); } pieces.push_back("}"); - } else if (ShapeUtil::Rank(shape) == 3) { - pieces.push_back(ShapeUtil::HumanString(shape)); + } else if (ShapeUtil::Rank(shape()) == 3) { + pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" {\n"); - for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) { + for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(i0 > 0 ? ",\n{" : "{"); - for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1) { + for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { pieces.push_back(i1 > 0 ? ",\n { " : " { "); - for (int64 i2 = 0; i2 < shape.dimensions(2); ++i2) { + for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) { pieces.push_back(element_to_string({i0, i1, i2})); } pieces.push_back(" }"); @@ -576,17 +555,17 @@ template pieces.push_back(" }"); } pieces.push_back("\n}"); - } else if (ShapeUtil::Rank(shape) == 4) { - pieces.push_back(ShapeUtil::HumanString(shape)); + } else if (ShapeUtil::Rank(shape()) == 4) { + pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" {\n"); - for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) { + for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(tensorflow::strings::Printf(" { // i0=%lld\n", i0)); - for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1) { + for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { pieces.push_back( tensorflow::strings::Printf(" { // i1=%lld\n", i1)); - for (int64 i2 = 0; i2 < shape.dimensions(2); ++i2) { + for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) { pieces.push_back(" {"); - for (int64 i3 = 0; i3 < shape.dimensions(3); ++i3) { + for (int64 i3 = 0; i3 < shape().dimensions(3); ++i3) { pieces.push_back(element_to_string({i0, i1, i2, i3})); } pieces.push_back("},\n"); @@ -596,20 +575,20 @@ template pieces.push_back(" },\n"); } pieces.push_back("}"); - } else if (ShapeUtil::Rank(shape) == 5) { - pieces.push_back(ShapeUtil::HumanString(shape)); + } else if (ShapeUtil::Rank(shape()) == 5) { + pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" {\n"); - for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) { + for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(tensorflow::strings::Printf(" { // i0=%lld\n", i0)); - for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1) { + for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { pieces.push_back( tensorflow::strings::Printf(" { // i1=%lld\n", i1)); - for (int64 i2 = 0; i2 < shape.dimensions(2); ++i2) { + for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) { pieces.push_back( tensorflow::strings::Printf(" { // i2=%lld\n", i2)); - for (int64 i3 = 0; i3 < shape.dimensions(3); ++i3) { + for (int64 i3 = 0; i3 < shape().dimensions(3); ++i3) { pieces.push_back(" {"); - for (int64 i4 = 0; i4 < shape.dimensions(4); ++i4) { + for (int64 i4 = 0; i4 < shape().dimensions(4); ++i4) { pieces.push_back(element_to_string({i0, i1, i2, i3, i4})); } pieces.push_back("},\n"); @@ -622,14 +601,14 @@ template } pieces.push_back("}"); } else { - pieces.push_back(ShapeUtil::HumanString(shape)); + pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" {...}"); } return tensorflow::str_util::Join(pieces, ""); } -/* static */ std::unique_ptr LiteralUtil::MakeTuple( +/* static */ std::unique_ptr Literal::MakeTuple( tensorflow::gtl::ArraySlice elements) { auto literal = MakeUnique(); std::vector shape; @@ -641,136 +620,137 @@ template return literal; } -/* static */ const void* LiteralUtil::InternalData(const Literal& literal) { - switch (literal.shape().element_type()) { +const void* Literal::InternalData() const { + return const_cast( + const_cast(this)->MutableInternalData()); +} + +void* Literal::MutableInternalData() { + // NOTE: We access the vectors directly to avoid the const reference + // created by the accessor functions. + switch (shape().element_type()) { case PRED: - return reinterpret_cast(literal.preds().data()); + return reinterpret_cast(preds_.data()); case U8: - return reinterpret_cast(literal.u8s().data()); + return reinterpret_cast(u8s_.data()); case S32: - return reinterpret_cast(literal.s32s().data()); + return reinterpret_cast(s32s_.data()); case S64: - return reinterpret_cast(literal.s64s().data()); + return reinterpret_cast(s64s_.data()); case U32: - return reinterpret_cast(literal.u32s().data()); + return reinterpret_cast(u32s_.data()); case U64: - return reinterpret_cast(literal.u64s().data()); + return reinterpret_cast(u64s_.data()); case F32: - return reinterpret_cast(literal.f32s().data()); + return reinterpret_cast(f32s_.data()); case F64: - return reinterpret_cast(literal.f64s().data()); + return reinterpret_cast(f64s_.data()); case F16: - return reinterpret_cast(literal.f16s().data()); + return reinterpret_cast(f16s_.data()); default: LOG(FATAL) << "primitive type not supported in literals: " - << PrimitiveType_Name(literal.shape().element_type()); + << PrimitiveType_Name(shape().element_type()); } } -/* static */ void* LiteralUtil::MutableInternalData(Literal* literal) { - return const_cast(LiteralUtil::InternalData(*literal)); -} - -/* static */ void LiteralUtil::Reserve(int64 num_elements, Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - switch (literal->shape().element_type()) { +void Literal::Reserve(int64 num_elements) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + switch (shape().element_type()) { case PRED: - Resize(num_elements, false, literal); + Resize(num_elements, false); break; case S8: - Resize(num_elements, 0, literal); + Resize(num_elements, 0); break; case U8: - Resize(num_elements, 0, literal); + Resize(num_elements, 0); break; case S32: - Resize(num_elements, 0, literal); + Resize(num_elements, 0); break; case S64: - Resize(num_elements, 0, literal); + Resize(num_elements, 0); break; case U32: - Resize(num_elements, 0, literal); + Resize(num_elements, 0); break; case U64: - Resize(num_elements, 0, literal); + Resize(num_elements, 0); break; case F32: - Resize(num_elements, 0, literal); + Resize(num_elements, 0); break; case F64: - Resize(num_elements, 0, literal); + Resize(num_elements, 0); break; case F16: - Resize(num_elements, static_cast(0.0f), literal); + Resize(num_elements, static_cast(0.0f)); break; default: LOG(FATAL) << "primitive type not supported in literals: " - << PrimitiveType_Name(literal->shape().element_type()); + << PrimitiveType_Name(shape().element_type()); } } -/* static */ tensorflow::Status LiteralUtil::ValidateLiteral( - const Literal& literal) { - TF_CHECK_OK(ShapeUtil::ValidateShape(literal.shape())); - int64 expected = ShapeUtil::ElementsIn(literal.shape()); +tensorflow::Status Literal::ValidateLiteral() const { + TF_CHECK_OK(ShapeUtil::ValidateShape(shape())); + int64 expected = ShapeUtil::ElementsIn(shape()); int64 actual = -1; - switch (literal.shape().element_type()) { + switch (shape().element_type()) { case PRED: - actual = literal.preds().size(); + actual = preds_size(); break; case U8: - actual = literal.u8s().size(); + actual = u8s_size(); break; case S32: - actual = literal.s32s_size(); + actual = s32s_size(); break; case U32: - actual = literal.u32s_size(); + actual = u32s_size(); break; case S64: - actual = literal.s64s_size(); + actual = s64s_size(); break; case U64: - actual = literal.u64s_size(); + actual = u64s_size(); break; case F32: - actual = literal.f32s_size(); + actual = f32s_size(); break; case F64: - actual = literal.f64s_size(); + actual = f64s_size(); break; case F16: - actual = literal.f16s().size() / sizeof(half); + actual = f16s().size() / sizeof(half); break; default: return tensorflow::errors::Unimplemented( "unhandled element type for literal validation: " + - PrimitiveType_Name(literal.shape().element_type())); + PrimitiveType_Name(shape().element_type())); } if (expected != actual) { return tensorflow::errors::InvalidArgument(tensorflow::strings::Printf( "literal has bad number of elements for its shape %s: want %lld " "got %lld", - ShapeUtil::HumanString(literal.shape()).c_str(), expected, actual)); + ShapeUtil::HumanString(shape()).c_str(), expected, actual)); } return tensorflow::Status::OK(); } -/* static */ void LiteralUtil::EachCellAsString( - const Literal& literal, +void Literal::EachCellAsString( const std::function indices, - const string& value)>& per_cell) { - if (ShapeUtil::HasZeroElements(literal.shape())) { + const string& value)>& per_cell) const { + if (ShapeUtil::HasZeroElements(shape())) { return; } std::vector indices = IndexUtil::LinearIndexToMultidimensionalIndex( - literal.shape(), /*linear_index=*/0); + shape(), /*linear_index=*/0); do { - per_cell(indices, GetAsString(literal, indices)); - } while (IndexUtil::BumpIndices(literal.shape(), &indices)); + per_cell(indices, GetAsString(indices)); + } while (IndexUtil::BumpIndices(shape(), &indices)); } namespace { @@ -784,8 +764,8 @@ template bool EqualElements(const Literal& literal1, const Literal& literal2, int dimension, std::vector* multi_index) { if (dimension == ShapeUtil::Rank(literal1.shape())) { - return (LiteralUtil::Get(literal1, *multi_index) == - LiteralUtil::Get(literal2, *multi_index)); + return (literal1.Get(*multi_index) == + literal2.Get(*multi_index)); } for (int64 i = 0; i < literal1.shape().dimensions(dimension); ++i) { (*multi_index)[dimension] = i; @@ -799,219 +779,197 @@ bool EqualElements(const Literal& literal1, const Literal& literal2, } // namespace -/* static */ bool LiteralUtil::Equal(const Literal& literal1, - const Literal& literal2) { - if (!ShapeUtil::Compatible(literal1.shape(), literal2.shape())) { +bool Literal::Equal(const Literal& literal2) const { + if (!ShapeUtil::Compatible(shape(), literal2.shape())) { return false; } - if (ShapeUtil::IsTuple(literal1.shape())) { + if (ShapeUtil::IsTuple(shape())) { // Because the shapes are compatible, they must have the same number of // tuple elements. - CHECK_EQ(literal1.tuple_literals_size(), literal2.tuple_literals_size()); - for (int i = 0; i < literal1.tuple_literals_size(); ++i) { - if (!Equal(literal1.tuple_literals(i), literal2.tuple_literals(i))) { + CHECK_EQ(tuple_literals_size(), literal2.tuple_literals_size()); + for (int i = 0; i < tuple_literals_size(); ++i) { + if (!tuple_literals(i).Equal(literal2.tuple_literals(i))) { return false; } } return true; } else { - std::vector multi_index(ShapeUtil::Rank(literal1.shape()), 0); - switch (literal1.shape().element_type()) { + std::vector multi_index(ShapeUtil::Rank(shape()), 0); + switch (shape().element_type()) { case PRED: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case U8: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case S32: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case S64: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case U32: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case U64: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case F32: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case F64: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case F16: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); default: - LOG(FATAL) << "Unimplemented: LiteralUtil::Equal for type " - << PrimitiveType_Name(literal1.shape().element_type()); + LOG(FATAL) << "Unimplemented: Literal::Equal for type " + << PrimitiveType_Name(shape().element_type()); } } } template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal) { - auto values = literal->mutable_preds(); - return tensorflow::gtl::MutableArraySlice(values->mutable_data(), +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { + auto values = mutable_preds(); + return tensorflow::gtl::MutableArraySlice(values->data(), values->size()); } template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal) { +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { // C++11 standard, basic_string 21.4.1.5, values should be stored // contiguously. From C++17 a mutable data() member will be provided. - auto values = literal->mutable_u8s(); + auto values = mutable_u8s(); return tensorflow::gtl::MutableArraySlice( reinterpret_cast(&(*values)[0]), values->size()); } template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal) { +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { // C++11 standard, basic_string 21.4.1.5, values should be stored // contiguously. From C++17 a mutable data() member will be provided. - auto values = literal->mutable_u8s(); + auto values = mutable_u8s(); return tensorflow::gtl::MutableArraySlice( reinterpret_cast(&(*values)[0]), values->size()); } template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal) { - auto values = literal->mutable_s32s(); - return tensorflow::gtl::MutableArraySlice(values->mutable_data(), +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { + auto values = mutable_s32s(); + return tensorflow::gtl::MutableArraySlice(values->data(), values->size()); } template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal) { - auto values = literal->mutable_u32s(); - return tensorflow::gtl::MutableArraySlice(values->mutable_data(), +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { + auto values = mutable_u32s(); + return tensorflow::gtl::MutableArraySlice(values->data(), values->size()); } template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal) { +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { static_assert(sizeof(int64) == sizeof(tensorflow::protobuf_int64) && alignof(int64) == alignof(tensorflow::protobuf_int64), "The int64 and tensorflow::protobuf_int64 types are not " "compatible"); - auto values = literal->mutable_s64s(); + auto values = mutable_s64s(); // Because of the fact that tensorflow::protobuf_int64 is defined as int64_t // while tensorflow::int64 is defined as long long, a reinterpret_cast<> is // necessary from the raw data pointer returned by the mutable_data() API. return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(values->mutable_data()), values->size()); + reinterpret_cast(values->data()), values->size()); } template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal) { +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { static_assert(sizeof(uint64) == sizeof(tensorflow::protobuf_uint64) && alignof(uint64) == alignof(tensorflow::protobuf_uint64), "The uint64 and tensorflow::protobuf_uint64 types are not " "compatible"); - auto values = literal->mutable_u64s(); + auto values = mutable_u64s(); // Because of the fact that tensorflow::protobuf_uint64 is defined as uint64_t // while tensorflow::uint64 is defined as unsigned long long, a // reinterpret_cast<> is necessary from the raw data pointer returned by the // mutable_data() API. return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(values->mutable_data()), values->size()); + reinterpret_cast(values->data()), values->size()); } template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal) { - auto values = literal->mutable_f32s(); - return tensorflow::gtl::MutableArraySlice(values->mutable_data(), +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { + auto values = mutable_f32s(); + return tensorflow::gtl::MutableArraySlice(values->data(), values->size()); } template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal) { - auto values = literal->mutable_f64s(); - return tensorflow::gtl::MutableArraySlice(values->mutable_data(), +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { + auto values = mutable_f64s(); + return tensorflow::gtl::MutableArraySlice(values->data(), values->size()); } template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal) { +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { // C++11 standard, basic_string 21.4.1.5, values should be stored // contiguously. From C++17 a mutable data() member will be provided. // TODO - there is an endianess problem here. fix it, or wait for uint16 // support in protobuf - auto values = literal->mutable_f16s(); + auto values = mutable_f16s(); return tensorflow::gtl::MutableArraySlice( reinterpret_cast(&(*values)[0]), values->size() / sizeof(half)); } template <> -/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( - const Literal& literal) { - CHECK_EQ(literal.shape().element_type(), PRED); - return literal.preds(); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), PRED); + return tensorflow::gtl::ArraySlice(preds().data(), preds().size()); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK_EQ(literal.shape().element_type(), U8); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), U8); return tensorflow::gtl::ArraySlice( - reinterpret_cast(literal.u8s().data()), - literal.u8s().size()); + reinterpret_cast(u8s().data()), u8s().size()); } template <> -/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( - const Literal& literal) { - CHECK_EQ(literal.shape().element_type(), S8); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), S8); return tensorflow::gtl::ArraySlice( - reinterpret_cast(literal.u8s().data()), - literal.u8s().size()); + reinterpret_cast(u8s().data()), u8s().size()); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK_EQ(literal.shape().element_type(), U32); - return literal.u32s(); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), U32); + return u32s(); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK_EQ(literal.shape().element_type(), U64); - return AsUInt64Slice(literal.u64s()); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), U64); + return u64s(); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK_EQ(literal.shape().element_type(), S32); - return literal.s32s(); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), S32); + return s32s(); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK_EQ(literal.shape().element_type(), S64); - return AsInt64Slice(literal.s64s()); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), S64); + return s64s(); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK_EQ(literal.shape().element_type(), F64); - return literal.f64s(); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), F64); + return f64s(); } template <> -/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( - const Literal& literal) { - CHECK_EQ(literal.shape().element_type(), F16); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), F16); return tensorflow::gtl::ArraySlice( - reinterpret_cast(literal.f16s().data()), - literal.f16s().size() / sizeof(half)); + reinterpret_cast(f16s().data()), + f16s().size() / sizeof(half)); } template @@ -1019,48 +977,48 @@ static bool AllElementsEqualValue(const Literal& literal, NativeT value) { for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { auto multi_index = IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i); - if (LiteralUtil::Get(literal, multi_index) != value) { + if (literal.Get(multi_index) != value) { return false; } } return true; } -/* static */ bool LiteralUtil::IsAll(const Literal& literal, int8 value) { - switch (literal.shape().element_type()) { +bool Literal::IsAll(int8 value) const { + switch (shape().element_type()) { case U8: if (value >= 0) { - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); } return false; case U32: if (value >= 0) { - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); } return false; case U64: if (value >= 0) { - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); } return false; case S8: - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); case S32: - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); case S64: - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); case F32: - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); case F64: - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); case F16: - return AllElementsEqualValue(literal, static_cast(value)); + return AllElementsEqualValue(*this, static_cast(value)); case PRED: if (value == 0) { - return AllElementsEqualValue(literal, false); + return AllElementsEqualValue(*this, false); } if (value == 1) { - return AllElementsEqualValue(literal, true); + return AllElementsEqualValue(*this, true); } return false; default: @@ -1068,119 +1026,223 @@ static bool AllElementsEqualValue(const Literal& literal, NativeT value) { } } -/* static */ bool LiteralUtil::IsAllFloat(const Literal& literal, float value) { - switch (literal.shape().element_type()) { +bool Literal::IsAllFloat(float value) const { + switch (shape().element_type()) { case F32: - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); case F64: - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); case F16: - return AllElementsEqualValue(literal, static_cast(value)); + return AllElementsEqualValue(*this, static_cast(value)); default: return false; } } -/* static */ bool LiteralUtil::IsZero( - const Literal& literal, tensorflow::gtl::ArraySlice indices) { - switch (literal.shape().element_type()) { +bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { + switch (shape().element_type()) { case U8: - return Get(literal, indices) == 0; + return Get(indices) == 0; case U32: - return Get(literal, indices) == 0; + return Get(indices) == 0; case U64: - return Get(literal, indices) == 0; + return Get(indices) == 0; case S8: - return Get(literal, indices) == 0; + return Get(indices) == 0; case S32: - return Get(literal, indices) == 0; + return Get(indices) == 0; case S64: - return Get(literal, indices) == 0; + return Get(indices) == 0; case F32: - return Get(literal, indices) == 0.0f; + return Get(indices) == 0.0f; case F64: - return Get(literal, indices) == 0.0; + return Get(indices) == 0.0; case F16: - return Get(literal, indices) == static_cast(0.0f); + return Get(indices) == static_cast(0.0f); case PRED: - return Get(literal, indices) == false; + return Get(indices) == false; default: LOG(FATAL) << "Input literal must be an array."; } } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, bool value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - literal->mutable_preds()->Resize(num_elements, value); +/* static */ void Literal::Resize(int64 num_elements, bool value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_preds()->resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int8 value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - literal->mutable_u8s()->resize(num_elements, value); +void Literal::Resize(int64 num_elements, int8 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_u8s()->resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint8 value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - literal->mutable_u8s()->resize(num_elements, value); +void Literal::Resize(int64 num_elements, uint8 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_u8s()->resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int32 value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - literal->mutable_s32s()->Resize(num_elements, value); +void Literal::Resize(int64 num_elements, int32 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_s32s()->resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint32 value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - literal->mutable_u32s()->Resize(num_elements, value); +void Literal::Resize(int64 num_elements, uint32 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_u32s()->resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int64 value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - literal->mutable_s64s()->Resize(num_elements, value); +void Literal::Resize(int64 num_elements, int64 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_s64s()->resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint64 value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - literal->mutable_u64s()->Resize(num_elements, value); +void Literal::Resize(int64 num_elements, uint64 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_u64s()->resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, float value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - literal->mutable_f32s()->Resize(num_elements, value); +void Literal::Resize(int64 num_elements, float value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_f32s()->resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, double value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - literal->mutable_f64s()->Resize(num_elements, value); +void Literal::Resize(int64 num_elements, double value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_f64s()->resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, half value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - literal->mutable_f16s()->resize(num_elements * sizeof(half)); - auto data = GetMutableArraySlice(literal); +void Literal::Resize(int64 num_elements, half value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_f16s()->resize(num_elements * sizeof(half)); + auto data = GetMutableArraySlice(); for (int i = 0; i < num_elements; i++) { data[i] = value; } } +template +static void CopyToRepeatedField(RepeatedFieldT* dest, + const std::vector& src) { + *dest = RepeatedFieldT(src.begin(), src.end()); +} + +template +static void CopyToRepeatedBoolField(RepeatedFieldT* dest, + const BoolVector& src) { + *dest = RepeatedFieldT(src.begin(), src.end()); +} + +LiteralProto Literal::ToProto() const { + LiteralProto proto; + proto.Clear(); + *proto.mutable_shape() = shape(); + switch (shape().element_type()) { + case PRED: + if (preds().begin()) { + CopyToRepeatedBoolField(proto.mutable_preds(), preds()); + } + break; + case U8: + *proto.mutable_u8s() = u8s_string(); + break; + case S32: + CopyToRepeatedField(proto.mutable_s32s(), s32s()); + break; + case S64: + CopyToRepeatedField(proto.mutable_s64s(), s64s()); + break; + case U32: + CopyToRepeatedField(proto.mutable_u32s(), u32s()); + break; + case U64: + CopyToRepeatedField(proto.mutable_u64s(), u64s()); + break; + case F16: + *proto.mutable_f16s() = + string(reinterpret_cast(f16s_.data()), + f16s_.size() / sizeof(half)); + break; + case F32: + CopyToRepeatedField(proto.mutable_f32s(), f32s()); + break; + case F64: + CopyToRepeatedField(proto.mutable_f64s(), f64s()); + break; + case TUPLE: + for (const auto& tuple : tuple_literals()) { + *proto.add_tuple_literals() = tuple.ToProto(); + } + break; + default: + LOG(FATAL) << "Unhandled primitive type " << shape().element_type(); + } + + return proto; +} + +template +static void CopyFromRepeatedField(std::vector* dest, + const RepeatedFieldT& src) { + *dest = std::vector(src.begin(), src.end()); +} + +void Literal::CopyFromProto(const LiteralProto& literal_proto) { + if (!literal_proto.has_shape()) { + return; + } + + *mutable_shape() = literal_proto.shape(); + switch (shape().element_type()) { + case PRED: + *mutable_preds() = BoolVector(literal_proto.preds().begin(), + literal_proto.preds().end()); + break; + case U8: + set_u8s(literal_proto.u8s()); + break; + case S32: + CopyFromRepeatedField(mutable_s32s(), literal_proto.s32s()); + break; + case S64: + CopyFromRepeatedField(mutable_s64s(), literal_proto.s64s()); + break; + case U32: + CopyFromRepeatedField(mutable_u32s(), literal_proto.u32s()); + break; + case U64: + CopyFromRepeatedField(mutable_u64s(), literal_proto.u64s()); + break; + case F16: { + const string& s(literal_proto.f16s()); + CHECK_EQ(0, s.size() % sizeof(half)); + f16s_ = std::vector(s.size() / sizeof(half)); + memcpy(f16s_.data(), s.data(), s.size() / sizeof(half)); + break; + } + case F32: + CopyFromRepeatedField(mutable_f32s(), literal_proto.f32s()); + break; + case F64: + CopyFromRepeatedField(mutable_f64s(), literal_proto.f64s()); + break; + case TUPLE: + for (const auto& proto : literal_proto.tuple_literals()) { + mutable_tuple_literals()->push_back(Literal(proto)); + } + break; + default: + LOG(FATAL) << "Unhandled primitive type " << shape().element_type(); + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 8e06f35b33d..8f6a70ffff9 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -48,15 +49,210 @@ limitations under the License. namespace xla { +// This class is a simple vector of boolean values. It's used to workaround some +// implementations of std::vector that use a bitset which does not have +// the semantics expected by Literal::preds(). +class BoolVector { + public: + typedef bool* iterator; + typedef const bool* const_iterator; + + BoolVector() : bits_(nullptr), size_(0), capacity_(0) {} + + BoolVector(const_iterator other_begin, const_iterator other_end) + : bits_(nullptr), size_(0), capacity_(0) { + if (other_begin && other_end) { + resize(other_end - other_begin); + memcpy(begin(), other_begin, size()); + } + } + + BoolVector(const BoolVector& other) { CopyFrom(other); } + + BoolVector& operator=(const BoolVector& other) { + CopyFrom(other); + return *this; + } + + void push_back(const bool& value) { + resize(size_ + 1); + bits_[size_ - 1] = value; + } + + bool* data() const { return bits_.get(); } + + size_t size() const { return size_; } + + size_t capacity() const { return capacity_; } + + void resize(size_t new_size, bool val = false) { + if (new_size == 0) { + bits_.reset(nullptr); + size_ = 0; + capacity_ = 0; + } else { + size_t old_size = size(); + if (new_size > old_size) { + grow(new_size); + } + if (old_size < new_size) { + memset(&bits_[old_size], val, new_size - old_size); + } + size_ = new_size; + } + } + + void clear() { + bits_.reset(nullptr); + size_ = 0; + capacity_ = 0; + } + + iterator begin() { return &bits_[0]; } + iterator end() { return &bits_[size()]; } + const_iterator begin() const { return &bits_[0]; } + const_iterator end() const { return &bits_[size()]; } + + private: + void grow(size_t n) { + if (capacity_ < n) { + capacity_ = 2 * n; + bool* new_bits = new bool[capacity_](); + if (size_ > 0) { + memcpy(new_bits, bits_.get(), size_); + } + bits_.reset(new_bits); + } + } + + void CopyFrom(const BoolVector& other) { + bits_ = MakeUnique(other.capacity()); + memcpy(begin(), other.begin(), other.size()); + size_ = other.size(); + capacity_ = other.capacity(); + } + + std::unique_ptr bits_; + size_t size_; + size_t capacity_; +}; + // Utility class for dealing with XLA literal values. Most methods are // templated by native (host) type which corresponds to a unique XLA // PrimitiveType. See ComputationBuilder for details. Not all primitive types // defined in xla_data.proto have a corresponding native type or even have a // storage location in the Literal proto yet (for example, primitive type F16). -class LiteralUtil { +class Literal { public: - // Create new literal of a given rank. To minimize ambiguity (for users and - // the compiler) these CreateR[0-2] methods should explicitly specify the + Literal() {} + + Literal(const Literal& other) = default; + + explicit Literal(const LiteralProto& other) { CopyFromProto(other); } + + Literal& operator=(const Literal& other) = default; + + LiteralProto ToProto() const; + + bool has_shape() const { + return shape_.element_type() != PRIMITIVE_TYPE_INVALID; + } + + // Basic accessor functions. Names mirror the original protobuf + // functions for convenience. + string DebugString() const { return ToProto().DebugString(); } + string ShortDebugString() const { return ToProto().ShortDebugString(); } + + void Clear() { + shape_.Clear(); + preds_.clear(); + u8s_.clear(); + s32s_.clear(); + s64s_.clear(); + u32s_.clear(); + u64s_.clear(); + f16s_.clear(); + f32s_.clear(); + f64s_.clear(); + tuple_literals_.clear(); + } + + int preds_size() const { return preds().size(); } + const BoolVector& preds() const { return preds_; } + BoolVector* mutable_preds() { return &preds_; } + + int s32s_size() const { return s32s().size(); } + int32 s32s(int i) const { return s32s_[i]; } + const std::vector& s32s() const { return s32s_; } + std::vector* mutable_s32s() { return &s32s_; } + + int s64s_size() const { return s64s().size(); } + void add_s64s(int64 value) { s64s_.push_back(value); } + const std::vector& s64s() const { return s64s_; } + std::vector* mutable_s64s() { return &s64s_; } + + int u32s_size() const { return u32s().size(); } + uint32 u32s(int i) const { return u32s_[i]; } + const std::vector& u32s() const { return u32s_; } + std::vector* mutable_u32s() { return &u32s_; } + + int u64s_size() const { return u64s().size(); } + const std::vector& u64s() const { return u64s_; } + std::vector* mutable_u64s() { return &u64s_; } + + int f16s_size() const { return f16s().size(); } + half f16s(int i) const { return f16s_[i]; } + const std::vector& f16s() const { return f16s_; } + std::vector* mutable_f16s() { return &f16s_; } + + int f32s_size() const { return f32s().size(); } + float f32s(int i) const { return f32s_[i]; } + void add_f32s(float value) { f32s_.push_back(value); } + const std::vector& f32s() const { return f32s_; } + std::vector& f32s() { return f32s_; } + std::vector* mutable_f32s() { return &f32s_; } + + int f64s_size() const { return f64s().size(); } + const std::vector& f64s() const { return f64s_; } + std::vector* mutable_f64s() { return &f64s_; } + + int tuple_literals_size() const { return tuple_literals().size(); } + const Literal& tuple_literals(int i) const { return tuple_literals_[i]; } + Literal* add_tuple_literals() { + tuple_literals_.push_back(Literal()); + return &tuple_literals_.back(); + } + std::vector* mutable_tuple_literals() { return &tuple_literals_; } + const std::vector& tuple_literals() const { return tuple_literals_; } + + int u8s_size() const { return u8s().size(); } + const std::vector& u8s() const { return u8s_; } + void set_u8s(const std::vector& value) { u8s_ = value; } + void set_u8s(tensorflow::StringPiece value) { + u8s_ = std::vector(value.size()); + u8s_.clear(); + append_u8s(value); + } + + void append_u8s(tensorflow::StringPiece value) { + u8s_.insert(u8s_.end(), value.begin(), value.end()); + } + + string u8s_string() const { return string(u8s().begin(), u8s().end()); } + + std::vector* mutable_u8s() { return &u8s_; } + + const Shape& shape() const { return shape_; } + Shape* mutable_shape() { return &shape_; } + + void Swap(Literal* other) { + Literal temp = *this; + *this = *other; + *other = temp; + } + + // CreatesCreate new literal of a given rank. To minimize ambiguity (for users + // and the compiler) these CreateR[0-2] methods should explicitly specify the // native type. For example: // // CreateR1({1.0, 42.0}); @@ -101,12 +297,12 @@ class LiteralUtil { values, const Layout& layout); - // Create a new Literal object with the shape specified as parameter. + // Creates a new Literal object with the shape specified as parameter. // The content of the literal values is the default value of the primitive // type of literal itself (0 for numeric types, and false for predicates). static std::unique_ptr CreateFromShape(const Shape& shape); - // Create a new Literal object with its values havings the primitive_type + // Creates a new Literal object with its values havings the primitive_type // type, and with dimensions defined by the dimensions parameter. // The content of the literal values is the default value of the primitive // type of literal itself (0 for numeric types, and false for predicates). @@ -115,86 +311,84 @@ class LiteralUtil { tensorflow::gtl::ArraySlice dimensions); // Copies the values from src_literal, starting at src_base shape indexes, - // to dest_literal, starting at dest_base, where the copy size in each + // to this literal, starting at dest_base, where the copy size in each // dimension is specified by copy_size. - // The src_literal and dest_literal must have the same primitive type, + // The src_literal and this literal must have the same primitive type, // src_base+copy_size must fit the source literal dimensions, as well as // dest_base+copy_size must fit the destination literal dimensions. - static Status Copy(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_base, - Literal* dest_literal, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); + Status Copy(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); - // Creates a new value that has the equivalent value as literal, but conforms - // to new_layout; e.g. a literal matrix that was in {0, 1} minor-to-major - // dimension layout can be re-laid-out as {1, 0} minor-to-major dimension - // layout and the value in the cell at any given logical index (i0, i1) will - // be the same. + // Creates a new value that has the equivalent value as this literal, but + // conforms to new_layout; e.g. a literal matrix that was in {0, 1} + // minor-to-major dimension layout can be re-layed-out as {1, 0} + // minor-to-major dimension layout and the value in the cell at any given + // logical index (i0, i1) will be the same. // // Note: this is useful when the client wants to ensure that a value placed in // the XLA allocation tracker has a particular layout; for efficiency // purposes or avoiding unimplemented operation/layout combinations. - static std::unique_ptr Relayout(const Literal& literal, - const Layout& new_layout); + std::unique_ptr Relayout(const Layout& new_layout) const; - // Reshapes literal 'input' to have 'shape'. Both the original shape and - // 'shape' must contain the same number of elements. The implementation - // currently only supports monotonic dim0-major layouts. - static StatusOr> Reshape( - const xla::Literal& input, tensorflow::gtl::ArraySlice shape); + // Creates a new literal by reshaping this literal to have 'shape'. Both the + // original shape and 'shape' must contain the same number of elements. The + // implementation currently only supports monotonic dim0-major layouts. + StatusOr> Reshape( + tensorflow::gtl::ArraySlice shape) const; - // Creates a new literal by reordering the dimensions of the original literal. + // Creates a new literal by reordering the dimensions of this literal. // The given `permutation` must be a permutation of the dimension numbers // in the original literal, and it specifies the order of the new dimensions // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). // For example, a transpose call on a literal of shape [3 x 8 x 4] and // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. - static std::unique_ptr Transpose( - const Literal& literal, tensorflow::gtl::ArraySlice permutation); + std::unique_ptr Transpose( + tensorflow::gtl::ArraySlice permutation) const; - // Creates a sub-array from the given literal by extracting the indices + // Creates a sub-array from this literal by extracting the indices // [start_index, limit_index) of each dimension. The result literal has the // same rank and layout as for the given literal. The number of indices in // start_indices and limit_indices must be the rank of the literal, and the // indices follow the order of the dimensions. - static std::unique_ptr Slice( - const Literal& literal, tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices); + std::unique_ptr Slice( + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) const; // Creates a literal with a prepended dimension with bound "times"; e.g. a - // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from the input + // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this // literal replicated four times. template - static std::unique_ptr Replicate(const Literal& input, int64 times); + std::unique_ptr Replicate(int64 times) const; - // Create a literal by converting each element in an original literal to a new + // Creates a literal by converting each element in this literal to a new // type. template - static std::unique_ptr Convert(const Literal& literal); + std::unique_ptr Convert() const; - // Create a literal value zero of the given primitive type. + // Creates a literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); - // Create a literal value one of the given primitive type. + // Creates a literal value one of the given primitive type. static Literal One(PrimitiveType primitive_type); // Creates a literal value containing the minimum value of the given // primitive type. For floating-point types, returns -inf. static Literal MinValue(PrimitiveType primitive_type); - // Create a literal value containing the maximum value of the given + // Creates a literal value containing the maximum value of the given // primitive type. For floating-point types, returns inf. static Literal MaxValue(PrimitiveType primitive_type); - // Create a literal of the given shape where each element is `value`. + // Creates a literal of the given shape where each element is `value`. template static std::unique_ptr CreateFullWithMonotonicDim0MajorLayout( tensorflow::gtl::ArraySlice dimensions, NativeT value); - // Create a new literal from an array. The variants not ending with WithLayout - // use the default XLA layout for the literal's linear representation in - // memory. + // Creates a new literal from an array. The variants not ending with + // WithLayout use the default XLA layout for the literal's linear + // representation in memory. template static std::unique_ptr CreateR2FromArray2D( const Array2D& values); @@ -236,39 +430,33 @@ class LiteralUtil { std::initializer_list> values, int64 projection_p, int64 projection_z); - // Clones literal into an owned unique_ptr version. - static std::unique_ptr CloneToUnique(const Literal& literal); + // Clones this literal into an owned unique_ptr version. + std::unique_ptr CloneToUnique() const; - // Returns the linear index of the given index within the literal's + // Returns the linear index of the given index within this literal's // element_type repeated field. - static int64 LinearIndex(const Literal& literal, - tensorflow::gtl::ArraySlice multi_index); + int64 LinearIndex(tensorflow::gtl::ArraySlice multi_index) const; // Gets or sets an element in the literal at the given index. The index is // CHECKed against the dimension sizes. template - static NativeT Get(const Literal& literal, - tensorflow::gtl::ArraySlice multi_index); + NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; template - static void Set(Literal* literal, - tensorflow::gtl::ArraySlice multi_index, - NativeT value); + void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); // Retrieves the mutable array slice interface which can be used to manipulate // pre-allocated literal values. template - static tensorflow::gtl::MutableArraySlice GetMutableArraySlice( - Literal* literal); + tensorflow::gtl::MutableArraySlice GetMutableArraySlice(); // Returns the element value at index (0, ..., 0), however many zeroes are // required for that index. template - static NativeT GetFirstElement(const Literal& literal); + NativeT GetFirstElement() const; // As Get(), but determines the correct type and converts the value // into text. - static string GetAsString(const Literal& literal, - tensorflow::gtl::ArraySlice multi_index); + string GetAsString(tensorflow::gtl::ArraySlice multi_index) const; // Returns an identity matrix (rank 2) with the given row and column count. template @@ -280,10 +468,10 @@ class LiteralUtil { // Validates that the data payload of the literal matches the literal shape; // if it does not, an appropriate status is returned. - static tensorflow::Status ValidateLiteral(const Literal& literal); + tensorflow::Status ValidateLiteral() const; // Returns a string representation of the literal value. - static string ToString(const Literal& literal); + string ToString() const; // Invokes the "per cell" callback for each element in the provided // literal with the element's indices and a string representation of @@ -292,112 +480,97 @@ class LiteralUtil { // This function is useful if you want a polymorphic representation // of the tensor's elements (turning it to a string for something // like representation in a protobuf). - static void EachCellAsString( - const Literal& literal, + void EachCellAsString( const std::function indices, - const string& value)>& per_cell); + const string& value)>& per_cell) const; template - static void EachCell( - const Literal& literal, - std::function indices, - NativeT value)> - per_cell); + void EachCell(std::function indices, + NativeT value)> + per_cell) const; - // Templated methods which populate the given repeated field in the Literal - // proto with the given value(s). The Shape field of the Literal proto is set + // Templated methods which populate the given repeated field in this literal + // with the given value(s). The Shape field of this literal is set // to match the array dimensions and type. Examples: // // // Populate with floats. // Array2D float_values = ... - // PopulateR2FromArray2D(values, literal); + // literal.PopulateR2FromArray2D(values); // // // Populate with int32s. - // PopulateR2({{1, 2}, {3, 4}}, literal); + // literal.PopulateR2({{1, 2}, {3, 4}}); // template - static void PopulateR0(NativeT values, Literal* literal); + void PopulateR0(NativeT values); template - static void PopulateR1(tensorflow::gtl::ArraySlice values, - Literal* literal); - static void PopulateR1(const tensorflow::core::Bitmap& values, - Literal* literal); + void PopulateR1(tensorflow::gtl::ArraySlice values); + void PopulateR1(const tensorflow::core::Bitmap& values); template - static void PopulateR2( + void PopulateR2(std::initializer_list> values); + template + void PopulateR2WithLayout( std::initializer_list> values, - Literal* literal); + const Layout& layout); template - static void PopulateR2WithLayout( - std::initializer_list> values, - const Layout& layout, Literal* literal); + void PopulateR2FromArray2D(const Array2D& values); template - static void PopulateR2FromArray2D(const Array2D& values, - Literal* literal); + void PopulateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout); template - static void PopulateR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout, - Literal* literal); + void PopulateR3FromArray3D(const Array3D& values); template - static void PopulateR3FromArray3D(const Array3D& values, - Literal* literal); + void PopulateR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout); template - static void PopulateR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout, - Literal* literal); + void PopulateR4FromArray4D(const Array4D& values); template - static void PopulateR4FromArray4D(const Array4D& values, - Literal* literal); - template - static void PopulateR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout, - Literal* literal); + void PopulateR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout); // Populates literal values by calling the generator function for every cell - // in the literal object. + // in this literal object. template - static Status Populate( - Literal* literal, + Status Populate( const std::function indexes)>& generator); // Creates a Literal of the given dimensions with all elements set to the // given value. template - static void PopulateWithValue(NativeT value, - tensorflow::gtl::ArraySlice dimensions, - Literal* literal); + void PopulateWithValue(NativeT value, + tensorflow::gtl::ArraySlice dimensions); - // Returns a pointer to the underlying buffer in the protobuf containing the - // array data. Use with care. - static const void* InternalData(const Literal& literal); - static void* MutableInternalData(Literal* literal); - - // Allocates space in the repeated_field of the literal sufficient to hold - // num_elements of the literal's primitive type. Values in the buffer are set - // to zero. num_elements must equal the number of elements in the literals + // Returns a pointer to the underlying vector corresponding to the Literal's // shape. - static void Reserve(int64 num_elements, Literal* literal); + const void* InternalData() const; + void* MutableInternalData(); - // Allocates space in the repeated_field of the literal sufficient to hold - // num_elements of the literal's primitive type and sets each element in the + // Allocates space in the underlying vector of this literal sufficient to hold + // num_elements of this literal's primitive type. Values in the vector are set + // to zero. num_elements must equal the number of elements in the literal's + // shape. + void Reserve(int64 num_elements); + + // Allocates space in the underlying vector of this literal sufficient to hold + // num_elements of this literal's primitive type and sets each element in this // literal to the given value. num_elements must equal the number of elements - // in the literals shape. + // in this literal's shape. template - static void Resize(int64 num_elements, NativeT value, Literal* literal); + void Resize(int64 num_elements, NativeT value); - // Returns true if the two given literals have the same shape and - // values. Layout is not considered in the comparison. - static bool Equal(const Literal& literal1, const Literal& literal2); + // Returns true if this literal has the same shape and value as the given + // literal. Layout is not considered in the comparison. + bool Equal(const Literal& literal2) const; - // Returns whether every element in the given literal is equal to value. + // Returns whether every element in this literal is equal to value. // // value is an int8 because we expect this to be called with small // compile-time constants (0, -1, etc.) and so that whatever value you pass // can be represented exactly by floating-point types as small as 16 bits. // - // If value doesn't fit in literal's type, returns false. Values of 1/0 are - // considered equal to true/false; other values are not considered equal to - // true. - static bool IsAll(const Literal& literal, int8 value); + // If value doesn't fit in this literal's type, returns false. Values of 1/0 + // are considered equal to true/false; other values are not considered equal + // to true. + bool IsAll(int8 value) const; // Like IsAll(const Literal&, int8), except we check whether the literal is // equal to a particular floating-point number. @@ -408,34 +581,34 @@ class LiteralUtil { // admonishments about floating-point equality checks apply. We expect you to // use this to check for values that can be expressed precisely as a float, // e.g. -0.5. - static bool IsAllFloat(const Literal& literal, float value); + bool IsAllFloat(float value) const; - // Returns whether the literal is zero at the specified index. The literal + // Returns whether this literal is zero at the specified index. This literal // must be an array. - static bool IsZero(const Literal& literal, - tensorflow::gtl::ArraySlice indices); + bool IsZero(tensorflow::gtl::ArraySlice indices) const; private: - // Returns an ArraySlice view of the array for the given literal for the - // given NativeT (e.g., float). These - // functions map native type to XLA PrimitiveType via template - // specialization. The unspecialized forms below aborts to handle the error - // case where the given native type does not map to an XLA primitive type. + // Returns an ArraySlice view of the array for this literal for the given + // NativeT (e.g., float). These functions map native type to XLA PrimitiveType + // via template specialization. The unspecialized forms below aborts to handle + // the error case where the given native type does not map to an XLA primitive + // type. template - static tensorflow::gtl::ArraySlice GetArraySlice( - const Literal& literal) { + tensorflow::gtl::ArraySlice GetArraySlice() const { static_assert(!std::is_same::value, "Cannot map native type to primitive type."); } + // Copy from a LiteralProto instance. + void CopyFromProto(const LiteralProto& literal_proto); + // Internal template helper for the Copy() API, matching its arguments one by // one. template - static Status CopyRange(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_base, - Literal* dest_literal, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); + Status CopyRange(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); // Utility structure which is used to create the optimal configuration for // a ShapeUtil::ForEachIndex() scan across two literals. @@ -460,6 +633,549 @@ class LiteralUtil { int64 minor_loop_size = 1; }; + Shape shape_; + BoolVector preds_; + std::vector u8s_; + std::vector s32s_; + std::vector s64s_; + std::vector u32s_; + std::vector u64s_; + std::vector f16s_; + std::vector f32s_; + std::vector f64s_; + std::vector tuple_literals_; +}; + +// Utility class for dealing with XLA literal values. Most methods are +// templated by native (host) type which corresponds to a unique XLA +// PrimitiveType. See ComputationBuilder for details. Not all primitive types +// defined in xla_data.proto have a corresponding native type or even have a +// storage location in the Literal proto yet (for example, primitive type F16). +// +// TODO(dnovillo) - All functions in this class simply redirect to the +// corresponding function in class Literal. Remove this class after converting +// all user code to use Literal directly. +class LiteralUtil { + public: + // Creates new literal of a given rank. To minimize ambiguity (for users and + // the compiler) these CreateR[0-2] methods should explicitly specify the + // native type. For example: + // + // CreateR1({1.0, 42.0}); + // CreateR2({{1, 2}, {3, 4}}); + // + // The variants not ending with WithLayout use the default XLA layout for the + // literal's linear representation in memory. + template + static std::unique_ptr CreateR0(NativeT value) { + return Literal::CreateR0(value); + } + + template + static std::unique_ptr CreateR1( + tensorflow::gtl::ArraySlice values) { + return Literal::CreateR1(values); + } + + static std::unique_ptr CreateR1( + const tensorflow::core::Bitmap& values) { + return Literal::CreateR1(values); + } + + template + static std::unique_ptr CreateR2( + std::initializer_list> values) { + return Literal::CreateR2(values); + } + + template + static std::unique_ptr CreateR2WithLayout( + std::initializer_list> values, + const Layout& layout) { + return Literal::CreateR2WithLayout(values, layout); + } + + template + static std::unique_ptr CreateR3( + std::initializer_list< + std::initializer_list>> + values) { + return Literal::CreateR3(values); + } + + template + static std::unique_ptr CreateR3WithLayout( + std::initializer_list< + std::initializer_list>> + values, + const Layout& layout) { + return Literal::CreateR3WithLayout(values, layout); + } + + template + static std::unique_ptr CreateR4( + std::initializer_list>>> + values) { + return Literal::CreateR4(values); + } + + template + static std::unique_ptr CreateR4WithLayout( + std::initializer_list>>> + values, + const Layout& layout) { + return Literal::CreateR4WithLayout(values, layout); + } + + // Creates a new Literal object with the shape specified as parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + static std::unique_ptr CreateFromShape(const Shape& shape) { + return Literal::CreateFromShape(shape); + } + + // Creates a new Literal object with its values havings the primitive_type + // type, and with dimensions defined by the dimensions parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + static std::unique_ptr CreateFromDimensions( + PrimitiveType primitive_type, + tensorflow::gtl::ArraySlice dimensions) { + return Literal::CreateFromDimensions(primitive_type, dimensions); + } + + // Copies the values from src_literal, starting at src_base shape indexes, + // to dest_literal, starting at dest_base, where the copy size in each + // dimension is specified by copy_size. + // + // The src_literal and dest_literal must have the same primitive type, + // src_base+copy_size must fit the source literal dimensions, as well as + // dest_base+copy_size must fit the destination literal dimensions. + static Status Copy(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + Literal* dest_literal, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { + return dest_literal->Copy(src_literal, src_base, dest_base, copy_size); + } + + // Creates a new value that has the equivalent value as literal, but conforms + // to new_layout; e.g. a literal matrix that was in {0, 1} minor-to-major + // dimension layout can be re-laid-out as {1, 0} minor-to-major dimension + // layout and the value in the cell at any given logical index (i0, i1) will + // be the same. + // + // Note: this is useful when the client wants to ensure that a value placed in + // the XLA allocation tracker has a particular layout; for efficiency + // purposes or avoiding unimplemented operation/layout combinations. + static std::unique_ptr Relayout(const Literal& literal, + const Layout& new_layout) { + return literal.Relayout(new_layout); + } + + // Reshapes literal 'input' to have 'shape'. Both the original shape and + // 'shape' must contain the same number of elements. The implementation + // currently only supports monotonic dim0-major layouts. + static StatusOr> Reshape( + const xla::Literal& input, tensorflow::gtl::ArraySlice shape) { + return input.Reshape(shape); + } + + // Creates a new literal by reordering the dimensions of the original literal. + // The given `permutation` must be a permutation of the dimension numbers + // in the original literal, and it specifies the order of the new dimensions + // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). + // For example, a transpose call on a literal of shape [3 x 8 x 4] and + // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. + static std::unique_ptr Transpose( + const Literal& literal, tensorflow::gtl::ArraySlice permutation) { + return literal.Transpose(permutation); + } + + // Creates a sub-array from the given literal by extracting the indices + // [start_index, limit_index) of each dimension. The result literal has the + // same rank and layout as for the given literal. The number of indices in + // start_indices and limit_indices must be the rank of the literal, and the + // indices follow the order of the dimensions. + static std::unique_ptr Slice( + const Literal& literal, tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) { + return literal.Slice(start_indices, limit_indices); + } + + // Creates a literal with a prepended dimension with bound "times"; e.g. a + // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from the input + // literal replicated four times. + template + static std::unique_ptr Replicate(const Literal& input, int64 times) { + return input.Replicate(times); + } + + // Creates a literal by converting each element in an original literal to a + // new type. + template + static std::unique_ptr Convert(const Literal& literal) { + return literal.Convert(); + } + + // Creates a literal value zero of the given primitive type. + static Literal Zero(PrimitiveType primitive_type) { + return Literal::Zero(primitive_type); + } + + // Creates a literal value one of the given primitive type. + static Literal One(PrimitiveType primitive_type) { + return Literal::One(primitive_type); + } + + // Creates a literal value containing the minimum value of the given + // primitive type. For floating-point types, returns -inf. + static Literal MinValue(PrimitiveType primitive_type) { + return Literal::MinValue(primitive_type); + } + + // Creates a literal value containing the maximum value of the given + // primitive type. For floating-point types, returns inf. + static Literal MaxValue(PrimitiveType primitive_type) { + return Literal::MaxValue(primitive_type); + } + + // Creates a literal of the given shape where each element is `value`. + template + static std::unique_ptr CreateFullWithMonotonicDim0MajorLayout( + tensorflow::gtl::ArraySlice dimensions, NativeT value) { + return Literal::CreateFullWithMonotonicDim0MajorLayout(dimensions, value); + } + + // Creates a new literal from an array. The variants not ending with + // WithLayout use the default XLA layout for the literal's linear + // representation in memory. + template + static std::unique_ptr CreateR2FromArray2D( + const Array2D& values) { + return Literal::CreateR2FromArray2D(values); + } + + template + static std::unique_ptr CreateR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { + return Literal::CreateR2FromArray2DWithLayout(values, layout); + } + + template + static std::unique_ptr CreateR3FromArray3D( + const Array3D& values) { + return Literal::CreateR3FromArray3D(values); + } + + template + static std::unique_ptr CreateR3FromArray3DWithLayout( + const Array3D& values, const Layout& layout) { + return Literal::CreateR3FromArray3DWithLayout(values, layout); + } + + template + static std::unique_ptr CreateR4FromArray4D( + const Array4D& values) { + return Literal::CreateR4FromArray4D(values); + } + + template + static std::unique_ptr CreateR4FromArray4DWithLayout( + const Array4D& values, const Layout& layout) { + return Literal::CreateR4FromArray4DWithLayout(values, layout); + } + + // Creates a new vector of U8s literal value from a string. + static std::unique_ptr CreateR1U8(tensorflow::StringPiece value) { + return Literal::CreateR1U8(value); + } + + // Creates a linspace-populated literal with the given number of rows and + // columns. + static std::unique_ptr CreateR2F32Linspace(float from, float to, + int64 rows, int64 cols) { + return Literal::CreateR2F32Linspace(from, to, rows, cols); + } + + // Creates a literal that projects the (x, y) dimensions given in values into + // the z dimension given by "projection". + template + static std::unique_ptr CreateR3Projected( + std::initializer_list> values, + int64 projection) { + return Literal::CreateR3Projected(values, projection); + } + + // Creates a literal that projects the (x, y) dimensions given in values into + // the z and p dimensions given. + template + static std::unique_ptr CreateR4Projected( + std::initializer_list> values, + int64 projection_p, int64 projection_z) { + return Literal::CreateR4Projected(values, projection_p, projection_z); + } + + // Clones literal into an owned unique_ptr version. + static std::unique_ptr CloneToUnique(const Literal& literal) { + return literal.CloneToUnique(); + } + + // Returns the linear index of the given index within the literal's + // element_type repeated field. + static int64 LinearIndex(const Literal& literal, + tensorflow::gtl::ArraySlice multi_index) { + return literal.LinearIndex(multi_index); + } + + // Gets or sets an element in the literal at the given index. The index is + // CHECKed against the dimension sizes. + template + static NativeT Get(const Literal& literal, + tensorflow::gtl::ArraySlice multi_index) { + return literal.Get(multi_index); + } + + template + static void Set(Literal* literal, + tensorflow::gtl::ArraySlice multi_index, + NativeT value) { + literal->Set(multi_index, value); + } + + // Retrieves the mutable array slice interface which can be used to manipulate + // pre-allocated literal values. + template + static tensorflow::gtl::MutableArraySlice GetMutableArraySlice( + Literal* literal) { + return literal->GetMutableArraySlice(); + } + + // Returns the element value at index (0, ..., 0), however many zeroes are + // required for that index. + template + static NativeT GetFirstElement(const Literal& literal) { + return literal.GetFirstElement(); + } + + // As Get(), but determines the correct type and converts the value + // into text. + static string GetAsString(const Literal& literal, + tensorflow::gtl::ArraySlice multi_index) { + return literal.GetAsString(multi_index); + } + + // Returns an identity matrix (rank 2) with the given row and column count. + template + static std::unique_ptr MakeIdentityR2(int64 size) { + return Literal::MakeIdentityR2(size); + } + + // Returns a tuple literal composed of given literals. + static std::unique_ptr MakeTuple( + tensorflow::gtl::ArraySlice elements) { + return Literal::MakeTuple(elements); + } + + // Validates that the data payload of the literal matches the literal shape; + // if it does not, an appropriate status is returned. + static tensorflow::Status ValidateLiteral(const Literal& literal) { + return literal.ValidateLiteral(); + } + + // Returns a string representation of the literal value. + static string ToString(const Literal& literal) { return literal.ToString(); } + + // Invokes the "per cell" callback for each element in the provided + // literal with the element's indices and a string representation of + // the element's value. + // + // This function is useful if you want a polymorphic representation + // of the tensor's elements (turning it to a string for something + // like representation in a protobuf). + static void EachCellAsString( + const Literal& literal, + const std::function indices, + const string& value)>& per_cell) { + literal.EachCellAsString(per_cell); + } + + template + static void EachCell( + const Literal& literal, + std::function indices, + NativeT value)> + per_cell) { + literal.EachCell(per_cell); + } + + // Templated methods which populate the given repeated field in the Literal + // proto with the given value(s). The Shape field of the Literal proto is set + // to match the array dimensions and type. Examples: + // + // // Populate with floats. + // Array2D float_values = ... + // PopulateR2FromArray2D(values, literal); + // + // // Populate with int32s. + // PopulateR2({{1, 2}, {3, 4}}, literal); + // + template + static void PopulateR0(NativeT values, Literal* literal) { + literal->PopulateR0(values); + } + + template + static void PopulateR1(tensorflow::gtl::ArraySlice values, + Literal* literal) { + literal->PopulateR1(values); + } + + static void PopulateR1(const tensorflow::core::Bitmap& values, + Literal* literal) { + literal->PopulateR1(values); + } + + template + static void PopulateR2( + std::initializer_list> values, + Literal* literal) { + literal->PopulateR2(values); + } + + template + static void PopulateR2WithLayout( + std::initializer_list> values, + const Layout& layout, Literal* literal) { + literal->PopulateR2WithLayout(values, layout); + } + + template + static void PopulateR2FromArray2D(const Array2D& values, + Literal* literal) { + literal->PopulateR2FromArray2D(values); + } + + template + static void PopulateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout, + Literal* literal) { + literal->PopulateR2FromArray2DWithLayout(values, layout); + } + + template + static void PopulateR3FromArray3D(const Array3D& values, + Literal* literal) { + literal->PopulateR3FromArray3D(values); + } + + template + static void PopulateR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout, + Literal* literal) { + literal->PopulateR3FromArray3DWithLayout(values, layout); + } + + template + static void PopulateR4FromArray4D(const Array4D& values, + Literal* literal) { + literal->PopulateR4FromArray4D(values); + } + + template + static void PopulateR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout, + Literal* literal) { + literal->PopulateR4FromArray4DWithLayout(values, layout); + } + + // Populates literal values by calling the generator function for every cell + // in the literal object. + template + static Status Populate( + Literal* literal, + const std::function indexes)>& + generator) { + return literal->Populate(generator); + } + + // Creates a Literal of the given dimensions with all elements set to the + // given value. + template + static void PopulateWithValue(NativeT value, + tensorflow::gtl::ArraySlice dimensions, + Literal* literal) { + return literal->PopulateWithValue(value, dimensions); + } + + // Returns a pointer to the underlying vector containing the array data. Use + // with care. + static const void* InternalData(const Literal& literal) { + return literal.InternalData(); + } + + static void* MutableInternalData(Literal* literal) { + return literal->MutableInternalData(); + } + + // Allocates space in the underlying vector of the literal sufficient to hold + // num_elements of the literal's primitive type. Values in the vector are set + // to zero. num_elements must equal the number of elements in the literals + // shape. + static void Reserve(int64 num_elements, Literal* literal) { + literal->Reserve(num_elements); + } + + // Allocates space in the underlying vector of the literal sufficient to hold + // num_elements of the literal's primitive type and sets each element in the + // literal to the given value. num_elements must equal the number of elements + // in the literals shape. + template + static void Resize(int64 num_elements, NativeT value, Literal* literal) { + literal->Resize(num_elements, value); + } + + // Returns true if the two given literals have the same shape and + // values. Layout is not considered in the comparison. + static bool Equal(const Literal& literal1, const Literal& literal2) { + return literal1.Equal(literal2); + } + + // Returns whether every element in the given literal is equal to value. + // + // value is an int8 because we expect this to be called with small + // compile-time constants (0, -1, etc.) and so that whatever value you pass + // can be represented exactly by floating-point types as small as 16 bits. + // + // If value doesn't fit in literal's type, returns false. Values of 1/0 are + // considered equal to true/false; other values are not considered equal to + // true. + static bool IsAll(const Literal& literal, int8 value) { + return literal.IsAll(value); + } + + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular floating-point number. + // + // If the literal is not a floating-point value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for values that can be expressed precisely as a float, + // e.g. -0.5. + static bool IsAllFloat(const Literal& literal, float value) { + return literal.IsAllFloat(value); + } + + // Returns whether the literal is zero at the specified index. The literal + // must be an array. + static bool IsZero(const Literal& literal, + tensorflow::gtl::ArraySlice indices) { + return literal.IsZero(indices); + } + TF_DISALLOW_COPY_AND_ASSIGN(LiteralUtil); }; @@ -467,160 +1183,131 @@ class LiteralUtil { // GetMutableArraySlice. The specializations map native type to XLA primitive // type. template <> -/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( - const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( - const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ inline tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - DCHECK(literal.shape().element_type() == F32); - return literal.f32s(); +inline tensorflow::gtl::ArraySlice Literal::GetArraySlice() + const { + DCHECK(shape().element_type() == F32); + return f32s(); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( - const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ tensorflow::gtl::MutableArraySlice -LiteralUtil::GetMutableArraySlice(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, bool value, - Literal* literal); +void Literal::Resize(int64 num_elements, bool value); template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int8 value, - Literal* literal); +void Literal::Resize(int64 num_elements, int8 value); template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint8 value, - Literal* literal); +void Literal::Resize(int64 num_elements, uint8 value); template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int32 value, - Literal* literal); +void Literal::Resize(int64 num_elements, int32 value); template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint32 value, - Literal* literal); +void Literal::Resize(int64 num_elements, uint32 value); template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int64 value, - Literal* literal); +void Literal::Resize(int64 num_elements, int64 value); template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint64 value, - Literal* literal); +void Literal::Resize(int64 num_elements, uint64 value); template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, float value, - Literal* literal); +void Literal::Resize(int64 num_elements, float value); template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, double value, - Literal* literal); +void Literal::Resize(int64 num_elements, double value); template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, half value, - Literal* literal); +void Literal::Resize(int64 num_elements, half value); template -/* static */ std::unique_ptr LiteralUtil::CreateR0(NativeT value) { +/* static */ std::unique_ptr Literal::CreateR0(NativeT value) { auto literal = MakeUnique(); - PopulateR0(value, literal.get()); + literal->PopulateR0(value); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR1( +/* static */ std::unique_ptr Literal::CreateR1( tensorflow::gtl::ArraySlice values) { auto literal = MakeUnique(); - PopulateR1(values, literal.get()); + literal->PopulateR1(values); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR2WithLayout( +/* static */ std::unique_ptr Literal::CreateR2WithLayout( std::initializer_list> values, const Layout& layout) { auto literal = MakeUnique(); - PopulateR2WithLayout(values, layout, literal.get()); + literal->PopulateR2WithLayout(values, layout); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR2( +/* static */ std::unique_ptr Literal::CreateR2( std::initializer_list> values) { return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3WithLayout( +/* static */ std::unique_ptr Literal::CreateR3WithLayout( std::initializer_list>> values, const Layout& layout) { @@ -645,14 +1332,14 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR3( +/* static */ std::unique_ptr Literal::CreateR3( std::initializer_list>> values) { return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); } template -/* static */ std::unique_ptr LiteralUtil::CreateR4WithLayout( +/* static */ std::unique_ptr Literal::CreateR4WithLayout( std::initializer_list>>> values, @@ -683,7 +1370,7 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR4( +/* static */ std::unique_ptr Literal::CreateR4( std::initializer_list>>> values) { @@ -691,38 +1378,37 @@ template } template -/* static */ std::unique_ptr -LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout) { +/* static */ std::unique_ptr Literal::CreateR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { auto literal = MakeUnique(); - PopulateR2FromArray2DWithLayout(values, layout, literal.get()); + literal->PopulateR2FromArray2DWithLayout(values, layout); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR2FromArray2D( +/* static */ std::unique_ptr Literal::CreateR2FromArray2D( const Array2D& values) { return CreateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } + template -/* static */ std::unique_ptr -LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout) { +/* static */ std::unique_ptr Literal::CreateR3FromArray3DWithLayout( + const Array3D& values, const Layout& layout) { auto literal = MakeUnique(); - PopulateR3FromArray3DWithLayout(values, layout, literal.get()); + literal->PopulateR3FromArray3DWithLayout(values, layout); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR3FromArray3D( +/* static */ std::unique_ptr Literal::CreateR3FromArray3D( const Array3D& values) { return CreateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3Projected( +/* static */ std::unique_ptr Literal::CreateR3Projected( std::initializer_list> values, int64 projection) { int64 dim0_size = projection; @@ -747,7 +1433,7 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR4Projected( +/* static */ std::unique_ptr Literal::CreateR4Projected( std::initializer_list> values, int64 projection_p, int64 projection_z) { int64 dim0_size = projection_p; @@ -775,99 +1461,92 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR4FromArray4D( +/* static */ std::unique_ptr Literal::CreateR4FromArray4D( const Array4D& values) { return CreateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4()); } template -/* static */ std::unique_ptr -LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout) { +/* static */ std::unique_ptr Literal::CreateR4FromArray4DWithLayout( + const Array4D& values, const Layout& layout) { auto literal = MakeUnique(); - PopulateR4FromArray4DWithLayout(values, layout, literal.get()); + literal->PopulateR4FromArray4DWithLayout(values, layout); return literal; } template -/* static */ NativeT LiteralUtil::Get( - const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { - int64 linear_index = LinearIndex(literal, multi_index); - return GetArraySlice(literal).at(linear_index); +NativeT Literal::Get(tensorflow::gtl::ArraySlice multi_index) const { + int64 linear_index = LinearIndex(multi_index); + return GetArraySlice().at(linear_index); } template -/* static */ NativeT LiteralUtil::GetFirstElement(const Literal& literal) { - return GetArraySlice(literal).at(0); +NativeT Literal::GetFirstElement() const { + return GetArraySlice().at(0); } template <> -/* static */ inline uint8 LiteralUtil::Get( - const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { - CHECK(literal.shape().element_type() == U8); - int64 linear_index = LinearIndex(literal, multi_index); - return literal.u8s()[linear_index]; +inline uint8 Literal::Get( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(shape().element_type() == U8); + int64 linear_index = LinearIndex(multi_index); + return u8s()[linear_index]; } template <> -/* static */ inline int8 LiteralUtil::Get( - const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { - CHECK(literal.shape().element_type() == S8); - int64 linear_index = LinearIndex(literal, multi_index); - return literal.u8s()[linear_index]; +inline int8 Literal::Get( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(shape().element_type() == S8); + int64 linear_index = LinearIndex(multi_index); + return u8s()[linear_index]; } template <> -/* static */ inline half LiteralUtil::Get( - const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { - CHECK(literal.shape().element_type() == F16); - int64 linear_index = LinearIndex(literal, multi_index); - return GetArraySlice(literal)[linear_index]; +inline half Literal::Get( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(shape().element_type() == F16); + int64 linear_index = LinearIndex(multi_index); + return GetArraySlice()[linear_index]; } template -/* static */ void LiteralUtil::Set( - Literal* literal, tensorflow::gtl::ArraySlice multi_index, - NativeT value) { - int64 linear_index = LinearIndex(*literal, multi_index); - GetMutableArraySlice(literal).at(linear_index) = value; +void Literal::Set(tensorflow::gtl::ArraySlice multi_index, + NativeT value) { + int64 linear_index = LinearIndex(multi_index); + GetMutableArraySlice().at(linear_index) = value; } template <> -/* static */ inline void LiteralUtil::Set( - Literal* literal, tensorflow::gtl::ArraySlice multi_index, - uint8 value) { - int64 linear_index = LinearIndex(*literal, multi_index); - (*literal->mutable_u8s())[linear_index] = value; +inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, + uint8 value) { + int64 linear_index = LinearIndex(multi_index); + (*mutable_u8s())[linear_index] = value; } template <> -/* static */ inline void LiteralUtil::Set( - Literal* literal, tensorflow::gtl::ArraySlice multi_index, - int8 value) { - return Set(literal, multi_index, value); +inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, + int8 value) { + return Set(multi_index, value); } template <> -/* static */ inline void LiteralUtil::Set( - Literal* literal, tensorflow::gtl::ArraySlice multi_index, - int64 value) { - int64 linear_index = LinearIndex(*literal, multi_index); - (*literal->mutable_s64s())[linear_index] = value; +inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, + int64 value) { + int64 linear_index = LinearIndex(multi_index); + (*mutable_s64s())[linear_index] = value; } template <> -/* static */ inline void LiteralUtil::Set( - Literal* literal, tensorflow::gtl::ArraySlice multi_index, - uint64 value) { - int64 linear_index = LinearIndex(*literal, multi_index); - (*literal->mutable_u64s())[linear_index] = value; +/* static */ inline void Literal::Set( + tensorflow::gtl::ArraySlice multi_index, uint64 value) { + int64 linear_index = LinearIndex(multi_index); + (*mutable_u64s())[linear_index] = value; } // Returns an identity matrix (rank 2) with the given row and column count. template -/* static */ std::unique_ptr LiteralUtil::MakeIdentityR2(int64 size) { +/* static */ std::unique_ptr Literal::MakeIdentityR2(int64 size) { Array2D array(size, size, 0); for (int64 i = 0; i < size; ++i) { array(i, i) = 1; @@ -876,55 +1555,51 @@ template } template -/* static */ void LiteralUtil::EachCell( - const Literal& literal, +void Literal::EachCell( std::function indices, NativeT value)> - per_cell) { - if (ShapeUtil::HasZeroElements(literal.shape())) { + per_cell) const { + if (ShapeUtil::HasZeroElements(shape())) { return; } - std::vector indices(ShapeUtil::Rank(literal.shape()), 0); + std::vector indices(ShapeUtil::Rank(shape()), 0); do { - per_cell(indices, Get(literal, indices)); - } while (IndexUtil::BumpIndices(literal.shape(), &indices)); + per_cell(indices, Get(indices)); + } while (IndexUtil::BumpIndices(shape(), &indices)); } template -/* static */ inline void LiteralUtil::PopulateR0(NativeT value, - Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShape( +inline void Literal::PopulateR0(NativeT value) { + *mutable_shape() = ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), {}); - Resize(1, value, literal); + Resize(1, value); } template -/* static */ void LiteralUtil::PopulateR1( - tensorflow::gtl::ArraySlice values, Literal* literal) { - *literal->mutable_shape() = +inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice values) { + *mutable_shape() = ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {static_cast(values.size())}); - Reserve(values.size(), literal); + Reserve(values.size()); for (int64 i = 0; i < values.size(); ++i) { - Set(literal, {i}, values[i]); + Set({i}, values[i]); } } -/* static */ inline void LiteralUtil::PopulateR1( - const tensorflow::core::Bitmap& values, Literal* literal) { - *literal->mutable_shape() = +inline void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { + *mutable_shape() = ShapeUtil::MakeShape(PRED, {static_cast(values.bits())}); - Reserve(values.bits(), literal); + Reserve(values.bits()); for (int64 i = 0; i < static_cast(values.bits()); ++i) { - Set(literal, {i}, values.get(i)); + Set({i}, values.get(i)); } } template -/* static */ void LiteralUtil::PopulateR2WithLayout( +void Literal::PopulateR2WithLayout( std::initializer_list> values, - const Layout& layout, Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShapeWithLayout( + const Layout& layout) { + *mutable_shape() = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), {static_cast(values.size()), static_cast(values.begin()->size())}, @@ -932,17 +1607,17 @@ template const int64 dim0_size = values.size(); const int64 dim1_size = values.begin()->size(); - CHECK_EQ(dim0_size, literal->shape().dimensions(0)); - CHECK_EQ(dim1_size, literal->shape().dimensions(1)); + CHECK_EQ(dim0_size, shape().dimensions(0)); + CHECK_EQ(dim1_size, shape().dimensions(1)); const int64 num_elements = dim1_size * dim0_size; - Reserve(num_elements, literal); + Reserve(num_elements); int64 dim0 = 0; for (auto inner_list : values) { int64 dim1 = 0; for (auto value : inner_list) { - Set(literal, {dim0, dim1}, value); + Set({dim0, dim1}, value); ++dim1; } CHECK_EQ(dim1_size, dim1); @@ -951,84 +1626,79 @@ template } template -/* static */ void LiteralUtil::PopulateR2( - std::initializer_list> values, - Literal* literal) { - PopulateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2(), literal); +void Literal::PopulateR2( + std::initializer_list> values) { + PopulateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } template -/* static */ void LiteralUtil::PopulateR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout, Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShapeWithLayout( +void Literal::PopulateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout) { + *mutable_shape() = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), {values.height(), values.width()}, AsInt64Slice(layout.minor_to_major())); const int64 dim1_size = values.width(); const int64 dim0_size = values.height(); - CHECK_EQ(dim0_size, literal->shape().dimensions(0)); - CHECK_EQ(dim1_size, literal->shape().dimensions(1)); - Reserve(dim1_size * dim0_size, literal); + CHECK_EQ(dim0_size, shape().dimensions(0)); + CHECK_EQ(dim1_size, shape().dimensions(1)); + Reserve(dim1_size * dim0_size); for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) { for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) { - Set(literal, {dim0, dim1}, values(dim0, dim1)); + Set({dim0, dim1}, values(dim0, dim1)); } } } template -/* static */ void LiteralUtil::PopulateR2FromArray2D( - const Array2D& values, Literal* literal) { - PopulateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2(), - literal); +void Literal::PopulateR2FromArray2D(const Array2D& values) { + PopulateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } + template -/* static */ void LiteralUtil::PopulateR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout, Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShapeWithLayout( +void Literal::PopulateR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout) { + *mutable_shape() = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), {values.n1(), values.n2(), values.n3()}, AsInt64Slice(layout.minor_to_major())); - CHECK_EQ(values.n1(), literal->shape().dimensions(0)); - CHECK_EQ(values.n2(), literal->shape().dimensions(1)); - CHECK_EQ(values.n3(), literal->shape().dimensions(2)); - Reserve(values.n1() * values.n2() * values.n3(), literal); + CHECK_EQ(values.n1(), shape().dimensions(0)); + CHECK_EQ(values.n2(), shape().dimensions(1)); + CHECK_EQ(values.n3(), shape().dimensions(2)); + Reserve(values.n1() * values.n2() * values.n3()); for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) { for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) { for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) { - Set(literal, {dim0, dim1, dim2}, values(dim0, dim1, dim2)); + Set({dim0, dim1, dim2}, values(dim0, dim1, dim2)); } } } } template -/* static */ void LiteralUtil::PopulateR3FromArray3D( - const Array3D& values, Literal* literal) { - PopulateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3(), - literal); +void Literal::PopulateR3FromArray3D(const Array3D& values) { + PopulateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); } template -/* static */ void LiteralUtil::PopulateR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout, Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShapeWithLayout( +void Literal::PopulateR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout) { + *mutable_shape() = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), {values.planes(), values.depth(), values.height(), values.width()}, AsInt64Slice(layout.minor_to_major())); - CHECK_EQ(values.n1(), literal->shape().dimensions(0)); - CHECK_EQ(values.n2(), literal->shape().dimensions(1)); - CHECK_EQ(values.n3(), literal->shape().dimensions(2)); - CHECK_EQ(values.n4(), literal->shape().dimensions(3)); - Reserve(values.n1() * values.n2() * values.n3() * values.n4(), literal); + CHECK_EQ(values.n1(), shape().dimensions(0)); + CHECK_EQ(values.n2(), shape().dimensions(1)); + CHECK_EQ(values.n3(), shape().dimensions(2)); + CHECK_EQ(values.n4(), shape().dimensions(3)); + Reserve(values.n1() * values.n2() * values.n3() * values.n4()); for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) { for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) { for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) { for (int64 dim3 = 0; dim3 < values.n4(); ++dim3) { - Set(literal, {dim0, dim1, dim2, dim3}, - values(dim0, dim1, dim2, dim3)); + Set({dim0, dim1, dim2, dim3}, values(dim0, dim1, dim2, dim3)); } } } @@ -1036,31 +1706,29 @@ template } template -/* static */ void LiteralUtil::PopulateR4FromArray4D( - const Array4D& values, Literal* literal) { - PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4(), - literal); +void Literal::PopulateR4FromArray4D(const Array4D& values) { + PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4()); } template -/* static */ Status LiteralUtil::Populate( - Literal* literal, +Status Literal::Populate( const std::function indexes)>& generator) { - const Shape& shape = literal->shape(); - int64 rank = ShapeUtil::Rank(shape); - TF_RET_CHECK(shape.element_type() == + const Shape& this_shape = shape(); + int64 rank = ShapeUtil::Rank(this_shape); + TF_RET_CHECK(this_shape.element_type() == primitive_util::NativeToPrimitiveType()); tensorflow::gtl::MutableArraySlice data = - GetMutableArraySlice(literal); + GetMutableArraySlice(); if (rank > 0) { - StrideConfig stride_config(shape, shape, AsInt64Slice(shape.dimensions())); + StrideConfig stride_config(this_shape, this_shape, + AsInt64Slice(this_shape.dimensions())); DimensionVector minor_scan_indexes(rank, 0); int64 minor_dimension_size = - ShapeUtil::GetDimension(shape, stride_config.minor_dimension); + ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); auto init_function = [&](const std::vector& indexes) { - int64 index = LinearIndex(*literal, indexes); + int64 index = LinearIndex(indexes); std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); for (int64 i = 0; i < minor_dimension_size; ++i) { minor_scan_indexes[stride_config.minor_dimension] = i; @@ -1068,8 +1736,9 @@ template } return true; }; - ShapeUtil::ForEachIndex(shape, stride_config.base, stride_config.dimensions, - stride_config.step, init_function); + ShapeUtil::ForEachIndex(this_shape, stride_config.base, + stride_config.dimensions, stride_config.step, + init_function); } else { data.at(0) = generator({}); } @@ -1077,30 +1746,27 @@ template } template -/* static */ void LiteralUtil::PopulateWithValue( - NativeT value, tensorflow::gtl::ArraySlice dimensions, - Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShape( +void Literal::PopulateWithValue(NativeT value, + tensorflow::gtl::ArraySlice dimensions) { + *mutable_shape() = ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), dimensions); - Resize(ShapeUtil::ElementsIn(literal->shape()), value, literal); + Resize(ShapeUtil::ElementsIn(shape()), value); } template -/* static */ std::unique_ptr LiteralUtil::Convert( - const Literal& literal) { - const Shape& shape = literal.shape(); +std::unique_ptr Literal::Convert() const { + const Shape& this_shape = shape(); auto result_literal = MakeUnique(); Shape* result_shape = result_literal->mutable_shape(); - *result_shape = shape; + *result_shape = this_shape; result_shape->set_element_type( primitive_util::NativeToPrimitiveType()); - LiteralUtil::Reserve(ShapeUtil::ElementsIn(*result_shape), - result_literal.get()); + result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape)); tensorflow::gtl::ArraySlice src_data = - GetArraySlice(literal); + GetArraySlice(); tensorflow::gtl::MutableArraySlice dest_data = - GetMutableArraySlice(result_literal.get()); - int64 num_elements = ShapeUtil::ElementsIn(shape); + result_literal->GetMutableArraySlice(); + int64 num_elements = ShapeUtil::ElementsIn(this_shape); for (int64 i = 0; i < num_elements; ++i) { dest_data[i] = static_cast(src_data[i]); @@ -1110,36 +1776,35 @@ template template /* static */ std::unique_ptr -LiteralUtil::CreateFullWithMonotonicDim0MajorLayout( +Literal::CreateFullWithMonotonicDim0MajorLayout( tensorflow::gtl::ArraySlice dimensions, NativeT value) { - Shape shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + Shape this_shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( primitive_util::NativeToPrimitiveType(), dimensions); auto literal = MakeUnique(); - *literal->mutable_shape() = shape; - Reserve(ShapeUtil::ElementsIn(shape), literal.get()); + *literal->mutable_shape() = this_shape; + literal->Reserve(ShapeUtil::ElementsIn(this_shape)); std::vector index(dimensions.size(), 0); do { - Set(literal.get(), index, value); - } while (IndexUtil::BumpIndices(shape, &index)); + literal->Set(index, value); + } while (IndexUtil::BumpIndices(this_shape, &index)); return literal; } template -/* static */ std::unique_ptr LiteralUtil::Replicate( - const Literal& input, int64 times) { +std::unique_ptr Literal::Replicate(int64 times) const { DimensionVector bounds = {times}; - bounds.reserve(input.shape().dimensions_size() + 1); - for (int64 bound : input.shape().dimensions()) { + bounds.reserve(shape().dimensions_size() + 1); + for (int64 bound : shape().dimensions()) { bounds.push_back(bound); } auto literal = MakeUnique(); *literal->mutable_shape() = - ShapeUtil::MakeShape(input.shape().element_type(), bounds); + ShapeUtil::MakeShape(shape().element_type(), bounds); int64 elements = ShapeUtil::ElementsIn(literal->shape()); if (elements == 0) { return literal; } - Reserve(elements, literal.get()); + literal->Reserve(elements); DimensionVector output_indices(bounds.size(), 0); tensorflow::gtl::ArraySlice input_indices = output_indices; @@ -1147,8 +1812,8 @@ template bool done = false; while (!done) { - const auto element = Get(input, input_indices); - Set(literal.get(), output_indices, element); + const auto element = Get(input_indices); + literal->Set(output_indices, element); done = true; for (int n = 0; n < output_indices.size(); ++n) { diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index b6f8863b513..aaab36dc8c5 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -856,5 +856,26 @@ TEST_F(LiteralUtilTest, ConvertR4) { EXPECT_TRUE(LiteralUtil::Equal(*expected, *converted)); } +TEST_F(LiteralUtilTest, CopyFromProto_Bool) { + LiteralProto p; + p.mutable_shape()->set_element_type(PRED); + for (int len = 0; len < 25; ++len) { + p.mutable_shape()->clear_dimensions(); + p.mutable_shape()->add_dimensions(len); + p.clear_preds(); + for (int i = 0; i < len; ++i) { + p.add_preds((i % 2) == (len % 2)); + } + + Literal literal(p); + ASSERT_EQ(len, literal.preds_size()); + int i = 0; + for (auto it = literal.preds().begin(); it < literal.preds().end(); ++it) { + EXPECT_EQ((i % 2) == (len % 2), *it); + ++i; + } + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 21766a2a0c8..d488830a6cd 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -60,8 +60,8 @@ StatusOr> PackedLiteralReader::Read( int64 elements = ShapeUtil::ElementsIn(shape); LiteralUtil::Resize(elements, std::numeric_limits::quiet_NaN(), result.get()); - tensorflow::protobuf::RepeatedField* field = result->mutable_f32s(); - char* data = tensorflow::bit_cast(field->mutable_data()); + std::vector* field = result->mutable_f32s(); + char* data = tensorflow::bit_cast(field->data()); uint64 bytes = elements * sizeof(float); tensorflow::StringPiece sp; auto s = file_->Read(offset_, bytes, &sp, data); diff --git a/tensorflow/compiler/xla/packed_literal_reader.h b/tensorflow/compiler/xla/packed_literal_reader.h index 563d978cf5d..45a9fe01278 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.h +++ b/tensorflow/compiler/xla/packed_literal_reader.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 32c3c3ae206..e8de559a5ef 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/reference_util.h" #include +#include #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" @@ -331,7 +332,8 @@ ReferenceUtil::ConvArray4DGeneralDimensions( std::pair kernel_stride, Padding padding, ConvolutionDimensionNumbers dimension_numbers) { return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding, - {1, 1}, {1, 1}, dimension_numbers); + {1, 1}, {1, 1}, + std::move(dimension_numbers)); } /* static */ std::unique_ptr> diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 73b31e26f4c..8fa7d044d1b 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -529,6 +529,7 @@ cc_library( srcs = ["transfer_manager.cc"], hdrs = ["transfer_manager.h"], deps = [ + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1680,10 +1681,8 @@ cc_library( deps = [ ":buffer_assignment", ":hlo", - ":hlo_ordering", ":hlo_proto", "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 83759a7a0c6..ad2fee2d39a 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -171,6 +171,7 @@ StatusOr> AllocationTracker::DeconstructTuple( executor, allocation->device_memory(), allocation->shape())); std::vector element_handles; + element_handles.reserve(element_bases.size()); for (int i = 0; i < element_bases.size(); ++i) { element_handles.push_back(RegisterInternal( allocation->backend(), allocation->device_ordinal(), element_bases[i], diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 3a1a9fe8709..b9c8589f731 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -229,25 +229,26 @@ Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices( // Mapping from LogicalBuffer to index (used to detect non-distinct indices). FlatMap> buffer_to_source_indices; - TF_RETURN_IF_ERROR(points_to.ForEachElement([this, &buffer_to_source_indices]( - const ShapeIndex& index, bool /*is_leaf*/, - const std::vector& buffers) { - if (buffers.size() > 1) { - // Record ambiguous points-to set at 'index'. - if (!indices_to_copy_.element(index)) { - VLOG(2) << "Adding copy of buffer for instruction: " - << instruction_->name() - << " at index: " << tensorflow::str_util::Join(index, ",") - << " with ambiguous points-to set."; - RecordIndex(index); - } - } - // For each 'buffer': record a mapping from 'buffer' to 'index'. - for (const LogicalBuffer* buffer : buffers) { - buffer_to_source_indices[buffer].push_back(index); - } - return Status::OK(); - })); + TF_RETURN_IF_ERROR(points_to.ForEachElement( + [this, &buffer_to_source_indices]( + const ShapeIndex& index, bool /*is_leaf*/, + const std::vector& buffers) { + if (buffers.size() > 1) { + // Record ambiguous points-to set at 'index'. + if (!indices_to_copy_.element(index)) { + VLOG(2) << "Adding copy of buffer for instruction: " + << instruction_->name() + << " at index: " << tensorflow::str_util::Join(index, ",") + << " with ambiguous points-to set."; + RecordIndex(index); + } + } + // For each 'buffer': record a mapping from 'buffer' to 'index'. + for (const LogicalBuffer* buffer : buffers) { + buffer_to_source_indices[buffer].push_back(index); + } + return Status::OK(); + })); // Record all non-distinct indices detected in 'buffer_to_source_indices'. for (const auto& buff_to_src : buffer_to_source_indices) { @@ -449,11 +450,15 @@ RevertReadOnlyIndicesForEntryParamsAndConstants( FlatMap* shared_copies) { const HloInstruction* init_hlo = while_hlo->operand(0); const PointsToSet& points_to = points_to_analysis.GetPointsToSet(init_hlo); + + // Mapping from LogicalBuffer to index (used to detect non-distinct indices). + FlatSet buffer_set; + ShapeTree copy_overrides(init_hlo->shape()); TF_RETURN_IF_ERROR(points_to.ForEachElement( - [init_hlo, read_only_indices, shared_copies, ©_overrides]( - const ShapeIndex& index, bool /*is_leaf*/, - const std::vector& buffers) { + [init_hlo, read_only_indices, shared_copies, &buffer_set, + ©_overrides](const ShapeIndex& index, bool /*is_leaf*/, + const std::vector& buffers) { // Look for read-only entry parameters. if (!read_only_indices->element(index)) { return Status::OK(); @@ -468,6 +473,7 @@ RevertReadOnlyIndicesForEntryParamsAndConstants( if (!is_entry_parameter && !is_constant) { continue; } + // We have found an entry parameter or constant that is read-only in // the while body. These buffers are managed by the caller, and cannot // be aliased with non-parameter buffers. Revert this read-only index, @@ -476,16 +482,17 @@ RevertReadOnlyIndicesForEntryParamsAndConstants( // Optimization to allow multiple while loops that share the same // read-only entry parameters (or constants) to share a single copy. - // Only unambiguous array-shaped buffers are allowed, to reduce code - // complexity. The shape of the entry parameter must be identical to - // the shape of the init_hlo at this index, to ensure there were no - // intervening bitcast or GTE instructions, which are also hard to - // handle. + // Only unambiguous and distinct array-shaped buffers are allowed, to + // reduce code complexity. The shape of the entry parameter must be + // identical to the shape of the init_hlo at this index, to ensure + // there were no intervening bitcast or GTE instructions, which are + // also hard to handle. const Shape& pointee_shape = pointee->shape(); const Shape& init_shape = ShapeUtil::GetSubshape(init_hlo->shape(), index); if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) && - ShapeUtil::Equal(pointee_shape, init_shape)) { + ShapeUtil::Equal(pointee_shape, init_shape) && + buffer_set.count(buffer) < 1) { HloInstruction** copy = &(*shared_copies)[pointee]; if (*copy == nullptr) { *copy = @@ -496,6 +503,9 @@ RevertReadOnlyIndicesForEntryParamsAndConstants( *copy_overrides.mutable_element(index) = *copy; } + // Tracks whether this current buffer is distinct. + buffer_set.insert(buffer); + // We've already reverted the read-only index and handled the // single-copy optimization above, so there's nothing more to do. break; diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 661f682e38a..cb9682392ea 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -44,13 +44,20 @@ class CopyInsertionTest : public HloTestBase { EXPECT_IS_OK(copy_insertion.Run(module).status()); // Verify the points to set of the root of the computation after copy - // insertion contains no constants or parameters. + // insertion contains no constants or parameters, and is distinct and + // non-ambiguous. auto points_to_analysis = TuplePointsToAnalysis::Run(module).ConsumeValueOrDie(); + const auto& points_to = points_to_analysis->GetPointsToSet( + module->entry_computation()->root_instruction()); + EXPECT_TRUE(points_to.IsDistinct()); + EXPECT_TRUE(!points_to.IsAmbiguous()); + tensorflow::gtl::FlatSet maybe_live_out_buffers = points_to_analysis ->GetPointsToSet(module->entry_computation()->root_instruction()) .CreateFlattenedSet(); + for (const LogicalBuffer* buffer : maybe_live_out_buffers) { EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kConstant); EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kParameter); @@ -390,6 +397,47 @@ class WhileCopyInsertionTest : public CopyInsertionTest { return builder.Build(); } + // Builds a While body computation with two output tuple elements dependent on + // both input tuple elements. + // + // EX: Body({in0, in1, in2}) + // out0 = Add(in0, 1) + // out1 = in1 + // out2 = in2 + // Tuple(out0, out1, out2) + std::unique_ptr BuildDependentBodyComputation2() { + auto builder = HloComputation::Builder(TestName() + ".Body"); + + const Shape& loop_state_shape = ShapeUtil::MakeTupleShape( + {induction_variable_shape_, data_shape_, data_shape_}); + + auto loop_state = builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); + + // Update the induction variable GTE(0). + auto induction_variable = + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + induction_variable_shape_, loop_state, 0)); + auto inc = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + + // add0 = Add(in0, 1) + auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( + induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); + // data1 = GTE(1). + HloInstruction* data1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); + + // data2 = GTE(2). + HloInstruction* data2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 2)); + + // Create output Tuple. + builder.AddInstruction(HloInstruction::CreateTuple({add0, data1, data2})); + + return builder.Build(); + } + // Builds a While body computation with read-only tuple element 0. // EX: // Body({in0, in1}) @@ -408,6 +456,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // Update data GTE(1). auto data = builder.AddInstruction( HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); + // Use 'induction_variable' in computation with no path to output tuple. auto update = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8})); @@ -431,6 +480,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // Create param instruction to access loop state. const Shape& loop_state_shape = nested ? nested_loop_state_shape_ : loop_state_shape_; + auto loop_state = builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); // Update the induction variable GTE(0). @@ -972,7 +1022,8 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) { op::Copy(old_init->operand(1)->operand(0))))); } -// Tests while init instruction buffer which interferes with while result buffer. +// Tests while init instruction buffer which interferes with while result +// buffer. // // init_data = Broadcast(...) // add_unrelated = Add(init_data) // takes a reference to cause interference @@ -989,5 +1040,81 @@ TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) { op::Copy(old_init->operand(1)))); } +// Tests while init instruction buffer which has a non-distinct points-to set: +// +// init = Tuple(Parameter(S32, {}), Parameter(F32, {8}, +// Parameter(F32, {8}))) +// +// where the second and third parameters are identical *and* the tuple shared +// by another while instruction.. +// +// Verifies that the resulting point-to set is distinct in the resulting Tuple +// (non-identical Copys). In other words, verifies that copy sharing does not +// insert identical copies to the resulting tuple. +TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { + auto condition1 = module_.AddEmbeddedComputation(BuildConditionComputation()); + auto condition2 = module_.AddEmbeddedComputation(BuildConditionComputation()); + // Loop body that outputs tuple comprises two elements dependent on the init + // tuple. + auto body1 = module_.AddEmbeddedComputation(BuildDependentBodyComputation2()); + auto body2 = module_.AddEmbeddedComputation(BuildDependentBodyComputation2()); + + auto builder = HloComputation::Builder(TestName() + ".While"); + + auto iter_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, induction_variable_shape_, "iter")); + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, data_shape_, "data")); + + // Loop init tuple contains two identical parameter buffers. + auto loop_init = builder.AddInstruction( + HloInstruction::CreateTuple({iter_param, data_param, data_param})); + + const Shape& loop_state_shape = ShapeUtil::MakeTupleShape( + {induction_variable_shape_, data_shape_, data_shape_}); + + // Two while loops shares the same loop init tuple. + auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( + loop_state_shape, condition1, body1, loop_init)); + auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( + loop_state_shape, condition2, body2, loop_init)); + + module_.AddEntryComputation(builder.Build()); + + auto points_to_analysis = + TuplePointsToAnalysis::Run(&module_).ConsumeValueOrDie(); + + // Asserts that the init tuples before copy insertion is non-distinct. + ASSERT_FALSE( + points_to_analysis->GetPointsToSet(while_hlo1->operand(0)).IsDistinct()); + ASSERT_FALSE( + points_to_analysis->GetPointsToSet(while_hlo2->operand(0)).IsDistinct()); + + auto old_init1 = while_hlo1->operand(0); + auto old_init2 = while_hlo2->operand(0); + + InsertCopies(&module_); + + EXPECT_THAT(while_hlo1->operand(0), + op::Tuple(op::Copy(old_init1->operand(0)), + op::Copy(old_init1->operand(1)), + op::Copy(old_init1->operand(2)))); + + EXPECT_THAT(while_hlo2->operand(0), + op::Tuple(op::Copy(old_init2->operand(0)), + op::Copy(old_init2->operand(1)), + op::Copy(old_init2->operand(2)))); + + // Verifies the init tuples after copy insertion is distinct. + points_to_analysis = TuplePointsToAnalysis::Run(&module_).ConsumeValueOrDie(); + const auto& points_to1 = + points_to_analysis->GetPointsToSet(while_hlo1->operand(0)); + EXPECT_TRUE(points_to1.IsDistinct()); + + const auto& points_to2 = + points_to_analysis->GetPointsToSet(while_hlo2->operand(0)); + EXPECT_TRUE(points_to2.IsDistinct()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 49e9874cda2..78a398f8efa 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index c27710fbdb2..6557c3aa8e6 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index 8d79d07f942..c225e62e3e1 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -31,7 +31,7 @@ AsyncExecution::AsyncExecution(Backend* backend, : backend_(CHECK_NOTNULL(backend)), streams_(std::move(streams)), profile_(profile), - result_(result) { + result_(std::move(result)) { for (const auto& stream : streams_) { CHECK(stream != nullptr); } diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc index c6749851dbb..dc421695cb1 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc @@ -254,6 +254,7 @@ TEST_F(HloScheduleTest, LatticeMatMul) { // d40 -- layer 4 HloComputation::Builder builder("entry_computation"); std::vector params; + params.reserve(6); for (int i = 0; i < 6; ++i) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 00471f72c99..9a09d2c02bb 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -1631,6 +1631,7 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( // Compute the input buffer indices. std::vector io_buffers; + io_buffers.reserve(io_hlos.size()); for (const HloInstruction* io_hlo : io_hlos) { io_buffers.push_back(GetAllocationSlice(*LatestNonGteAncestor(io_hlo))); } diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 28d47d2b0f8..56e3ff99fa9 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -86,6 +86,7 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { // d40 -- layer 4 HloComputation::Builder builder("entry_computation"); std::vector params; + params.reserve(6); for (int i = 0; i < 6; ++i) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 6583e509674..cfd1f0f53b7 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -46,7 +46,7 @@ message HloInstructionProto { xla.OpMetadata metadata = 7; // Literal, only present for kConstant. - xla.Literal literal = 8; + xla.LiteralProto literal = 8; // Parameter info, only present for kParameter. int64 parameter_number = 9; diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 77279cdbc5c..24504b5ade7 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -311,7 +311,6 @@ void ComputeComputationPostOrder( visited->insert(computation); post_order->push_back(computation); - return; } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index deb355145a8..b02089206e9 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -65,7 +65,7 @@ using ::tensorflow::strings::StrCat; WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil())); instruction->operands_.push_back(operand); instruction->literal_.reset(new Literal); - *instruction->literal_->mutable_u8s() += tag; + instruction->literal_->append_u8s(tag); return instruction; } @@ -1484,6 +1484,7 @@ string HloInstruction::ToString(bool compact_operands, } if (!slice_starts_.empty() && !slice_limits_.empty()) { std::vector bounds; + bounds.reserve(slice_starts_.size()); for (int i = 0; i < slice_starts_.size(); ++i) { bounds.push_back( StrCat("[", slice_starts_[i], ":", slice_limits_[i], "]")); @@ -1550,7 +1551,7 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_metadata() = metadata_; switch (opcode_) { case HloOpcode::kConstant: - *proto.mutable_literal() = *literal_; + *proto.mutable_literal() = literal_->ToProto(); break; case HloOpcode::kParameter: proto.set_parameter_number(parameter_number_); @@ -1647,10 +1648,10 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) { trace_instruction_ = trace_instruction; } -const string& HloInstruction::tracing_tag() const { +string HloInstruction::TracingTag() const { CHECK_EQ(HloOpcode::kTrace, opcode()); CHECK(literal_ != nullptr); - return literal_->u8s(); + return literal_->u8s_string(); } bool HloInstruction::IsFused() const { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 3db185896da..3bf46341be2 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -30,6 +30,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -535,7 +536,7 @@ class HloInstruction { // Returns a tag to be used in tracing. // // Precondition: opcode() == HloOpcode::kTrace - const string& tracing_tag() const; + string TracingTag() const; // Returns whether the instruction is a constant. bool IsConstant() const; diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 5069215031b..721640cdbd8 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -151,7 +151,26 @@ StatusOr InstructionFusion::Run(HloModule* module) { return true; }; - if (std::all_of(hlo->users().begin(), hlo->users().end(), + // An "effectively unary" operation is one that has one "large" + // input with the others being negligible in terms of memory usage. + // We use "has a smaller true rank than the output" as a heuristic + // for "negligible" memory usage. + auto effectively_unary = [](HloInstruction* hlo) { + if (hlo->operands().size() == 1) { + return true; + } + auto output_rank = ShapeUtil::TrueRank(hlo->shape()); + return std::count_if( + hlo->operands().begin(), hlo->operands().end(), + [output_rank](HloInstruction* operand) { + return ((operand->opcode() != HloOpcode::kBroadcast) && + ShapeUtil::TrueRank(operand->shape()) >= + output_rank); + }) <= 1; + }; + + if (effectively_unary(hlo) || + std::all_of(hlo->users().begin(), hlo->users().end(), user_fusable_into_hlo)) { all_consumers_fusable.insert(hlo); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 9a79e4c3824..d2df0b699ef 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -156,21 +156,67 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { HloComputation::Builder builder(TestName()); - auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {16, 16}), "0")); - HloInstruction* unary1 = builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeShape(S32, {}), HloOpcode::kFloor, param0)); - builder.AddInstruction(HloInstruction::CreateSend(unary1, 0)); - HloInstruction* unary2 = builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeShape(S32, {}), HloOpcode::kAbs, unary1)); + auto shape = ShapeUtil::MakeShape(F32, {16, 16}); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1")); + HloInstruction* binary1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + builder.AddInstruction(HloInstruction::CreateSend(binary1, 0)); + HloInstruction* unary = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(unary2, computation->root_instruction()); + EXPECT_EQ(unary, computation->root_instruction()); EXPECT_FALSE( InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) .Run(module.get()) .ValueOrDie()); } +TEST_F(InstructionFusionTest, AllowUnaryDuplication) { + HloComputation::Builder builder(TestName()); + auto shape = ShapeUtil::MakeShape(F32, {16, 16}); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0")); + HloInstruction* unary1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kFloor, param0)); + builder.AddInstruction(HloInstruction::CreateSend(unary1, 0)); + HloInstruction* unary2 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_EQ(unary2, computation->root_instruction()); + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { + auto shape = ShapeUtil::MakeShape(F32, {16, 16}); + auto small_shape = ShapeUtil::MakeShape(F32, {16}); + HloComputation::Builder builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, small_shape, "0")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1")); + HloInstruction* binary1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + builder.AddInstruction(HloInstruction::CreateSend(binary1, 0)); + HloInstruction* unary = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_EQ(unary, computation->root_instruction()); + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 4e94678ecf5..7b09c1f8314 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -27,6 +27,7 @@ limitations under the License. #include "external/llvm/include/llvm/IR/Module.h" #include "external/llvm/include/llvm/IR/Value.h" #include "external/llvm/include/llvm/Support/raw_ostream.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 0b94b37d376..2157604518d 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -77,8 +77,10 @@ tensorflow::Status RecordArguments( SessionModule* module) { module->clear_arguments(); for (const Allocation* allocation : arg_allocations) { - TF_RETURN_IF_ERROR(LiteralFromAllocation(allocation, allocation->shape(), - module->add_arguments())); + Literal argument; + TF_RETURN_IF_ERROR( + LiteralFromAllocation(allocation, allocation->shape(), &argument)); + *module->add_arguments() = argument.ToProto(); } return tensorflow::Status::OK(); } @@ -87,8 +89,11 @@ tensorflow::Status RecordArguments( tensorflow::Status RecordResult(const Allocation* result_allocation, SessionModule* module) { module->clear_result(); - return LiteralFromAllocation(result_allocation, result_allocation->shape(), - module->mutable_result()); + Literal result; + TF_RETURN_IF_ERROR(LiteralFromAllocation( + result_allocation, result_allocation->shape(), &result)); + *module->mutable_result() = result.ToProto(); + return tensorflow::Status::OK(); } } // namespace @@ -649,6 +654,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, ResolveAndValidateArguments(request.arguments(), execute_backend_.get(), executor->device_ordinal())); std::vector arguments; + arguments.reserve(arg_allocations.size()); for (const Allocation* allocation : arg_allocations) { arguments.push_back(allocation->device_memory()); } @@ -677,6 +683,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, BuildExecutables(versioned_handles, std::move(module_configs), execute_backend_.get(), executors)); std::vector executable_ptrs; + executable_ptrs.reserve(executables.size()); for (const auto& executable : executables) { executable_ptrs.push_back(executable.get()); } @@ -752,6 +759,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, << module_config->entry_computation_layout().ToString(); std::vector arguments; + arguments.reserve(arg_allocations.size()); for (const Allocation* allocation : arg_allocations) { arguments.push_back(allocation->device_memory()); } @@ -820,6 +828,7 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, << module_config->entry_computation_layout().ToString(); std::vector arguments; + arguments.reserve(arg_allocations.size()); for (const Allocation* allocation : arg_allocations) { arguments.push_back(allocation->device_memory()); } @@ -908,13 +917,15 @@ tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, literal_shape = &allocation->shape(); } - return LiteralFromAllocation(allocation, *literal_shape, - result->mutable_literal()); + Literal literal; + auto status = LiteralFromAllocation(allocation, *literal_shape, &literal); + *result->mutable_literal() = literal.ToProto(); + return status; } tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, TransferToServerResponse* result) { - const Literal& literal = arg->literal(); + Literal literal = Literal(arg->literal()); const Shape& shape = literal.shape(); if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) { @@ -978,7 +989,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, } return execute_backend_->transfer_manager()->TransferLiteralToInfeed( - executor, arg->literal()); + executor, Literal(arg->literal())); } tensorflow::Status Service::TransferFromOutfeed( @@ -1001,8 +1012,12 @@ tensorflow::Status Service::TransferFromOutfeed( executor = execute_backend_->Replicas()[arg->replica_id()]; } - return execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( - executor, arg->shape_with_layout(), result->mutable_literal()); + Literal literal; + TF_RETURN_IF_ERROR( + execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( + executor, arg->shape_with_layout(), &literal)); + *result->mutable_literal() = literal.ToProto(); + return tensorflow::Status::OK(); } tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg, diff --git a/tensorflow/compiler/xla/service/session.proto b/tensorflow/compiler/xla/service/session.proto index 4902cb521c2..bb8d1cd2a10 100644 --- a/tensorflow/compiler/xla/service/session.proto +++ b/tensorflow/compiler/xla/service/session.proto @@ -75,10 +75,10 @@ message SessionModule { repeated SessionComputation embedded_computations = 2; // The arguments passed to the computation. - repeated Literal arguments = 3; + repeated LiteralProto arguments = 3; // The result of the computation. - Literal result = 4; + LiteralProto result = 4; // The name of the platform used to run the computation. string execution_platform = 5; diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index a417b988bfe..15f6b7bfb4a 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/service/transfer_manager_test.cc b/tensorflow/compiler/xla/service/transfer_manager_test.cc index 564111c4f2b..ca38601d919 100644 --- a/tensorflow/compiler/xla/service/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/service/transfer_manager_test.cc @@ -121,7 +121,7 @@ TEST_F(CpuTransferManagerTest, TransferR1U8FromDevice) { const Shape shape = ShapeUtil::MakeShape(U8, {4}); TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice( stream_exec_, memptr, shape, shape, &literal)); - CHECK_EQ("klmn", literal.u8s()); + CHECK_EQ("klmn", literal.u8s_string()); } TEST_F(CpuTransferManagerTest, TransferBufferFromDevice) { diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index ac5f67418ed..b97823d2dc0 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -2275,7 +2275,7 @@ void ComputationLowerer::Visit( const ConstantRequest& constant_request = request.request().constant_request(); hlo_instruction = add_instruction(HloInstruction::CreateConstant( - LiteralUtil::CloneToUnique(constant_request.literal()))); + LiteralUtil::CloneToUnique(Literal(constant_request.literal())))); break; } @@ -2467,6 +2467,7 @@ void ComputationLowerer::Visit( // to append dimensions on the left the broadcast_dimensions should just // be the n highest dimension numbers of the output shape where n is // the number of input dimensions. + broadcast_dimensions.reserve(ShapeUtil::Rank(operand->shape())); for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { broadcast_dimensions.push_back(i + ShapeUtil::Rank(request.output_shape()) - diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index ddd13edeb86..ea691201263 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -50,7 +50,7 @@ TEST_F(UserComputationTest, SimpleComputation) { ConstantRequest constant_request; *constant_request.mutable_literal() = - *LiteralUtil::CreateR1({123.0f, 42.0f}); + LiteralUtil::CreateR1({123.0f, 42.0f})->ToProto(); TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle constant_handle, computation.AddConstantInstruction(constant_request)); @@ -160,12 +160,13 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) { UserComputation computation("TheComputation", handle); ConstantRequest a_request; - *a_request.mutable_literal() = *LiteralUtil::CreateR1({123.0f, 42.0f}); + *a_request.mutable_literal() = + LiteralUtil::CreateR1({123.0f, 42.0f})->ToProto(); TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle, computation.AddConstantInstruction(a_request)); ConstantRequest b_request; - *b_request.mutable_literal() = *LiteralUtil::CreateR0(1.0f); + *b_request.mutable_literal() = LiteralUtil::CreateR0(1.0f)->ToProto(); TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle, computation.AddConstantInstruction(b_request)); diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index aa4341d18e1..122d6ce4a98 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -44,6 +44,7 @@ struct ShapeTreeNode { // Children of this node. std::vector> children; + ShapeTreeNode() = default; explicit ShapeTreeNode(const T& data) : data(data) {} ShapeTreeNode(const ShapeTreeNode& other) @@ -85,8 +86,9 @@ class ShapeTree { public: // Default constructor creates a tree with a nil shape (i.e. an empty tuple). ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {} - // Create ShapeTree with the given shape, and default T values for all nodes. - explicit ShapeTree(const Shape& shape) : ShapeTree(shape, T()) {} + // Create ShapeTree with the given shape, and default-constructed T values for + // all nodes. + explicit ShapeTree(const Shape& shape); // Create ShapeTree with the given shape, and init_value for all nodes. ShapeTree(const Shape& shape, const T& init_value); @@ -127,6 +129,19 @@ class ShapeTree { const ShapeIndex& /*index*/, bool /*is_leaf*/, T* /*data*/)>; Status ForEachMutableElement(const MutableVisitorFunction& func); + // Copy the subtree of values from 'other' rooted at ShapeIndex + // 'source_base_index' into the subtree of value in this ShapeTree rooted at + // 'target_base_index'. + // + // Precondition: The subshape of other.shape() at index source_base_index must + // be compatible with the subshape of shape() at index target_base_index. + void CopySubtreeFrom(const ShapeTree& other, + const ShapeIndex& source_base_index, + const ShapeIndex& target_base_index); + + bool operator==(const ShapeTree& other) const; + bool operator!=(const ShapeTree& other) const { return !(*this == other); } + private: using Node = internal::ShapeTreeNode; @@ -134,6 +149,10 @@ class ShapeTree { // the given 'init_value'. void InitChildren(const Shape& shape, const T& init_value, Node* node); + // Initialize node->children based on 'shape'. All children have + // default-constructed data values. + void InitChildren(const Shape& shape, Node* node); + // Helpers for traversing the shape via ForEachElement. The helpers // recursively traverse the subtree rooted at "index" (defined as in // ShapeUtil::GetSubshape). @@ -165,6 +184,24 @@ void ShapeTree::InitChildren(const Shape& shape, const T& init_value, } } +template +void ShapeTree::InitChildren(const Shape& shape, Node* node) { + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + node->children.emplace_back(new Node()); + InitChildren(shape.tuple_shapes(i), node->children.back().get()); + } + } +} + +template +ShapeTree::ShapeTree(const Shape& shape) : root_(), shape_(shape) { + // The shape_ field is just used to hold the structure of the shape. + // It should not be relied upon to store layout information. + LayoutUtil::ClearLayout(&shape_); + InitChildren(shape_, &root_); +} + template ShapeTree::ShapeTree(const Shape& shape, const T& init_value) : root_(init_value), shape_(shape) { @@ -240,6 +277,48 @@ Status ShapeTree::ForEachMutableElement(const MutableVisitorFunction& func) { return ForEachMutableHelper(func, &root_, &index); } +template +void ShapeTree::CopySubtreeFrom(const ShapeTree& other, + const ShapeIndex& source_base_index, + const ShapeIndex& target_base_index) { + CHECK(ShapeUtil::Compatible( + ShapeUtil::GetSubshape(shape(), target_base_index), + ShapeUtil::GetSubshape(other.shape(), source_base_index))); + ForEachMutableElement( + [this, &other, &source_base_index, &target_base_index]( + const ShapeIndex& index, bool /*is_leaf*/, T* data) { + // Copy the data element only if index is in the + // subtree rooted at target_base_index. + for (int i = 0; i < target_base_index.size(); ++i) { + if (i >= index.size() || index[i] != target_base_index[i]) { + return Status::OK(); + } + } + // Construct source element index to copy from. + ShapeIndex source_index = source_base_index; + for (int i = target_base_index.size(); i < index.size(); ++i) { + source_index.push_back(index[i]); + } + *data = other.element(source_index); + return Status::OK(); + }) + .IgnoreError(); +} + +template +bool ShapeTree::operator==(const ShapeTree& other) const { + bool equal = true; + ForEachElement([this, &other, &equal](const ShapeIndex& index, + bool /*is_leaf*/, const T& data) { + if (data != other.element(index)) { + equal = false; + } + return Status::OK(); + }) + .IgnoreError(); + return equal; +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_ diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index efb6f422e00..1b9e18023ef 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -245,5 +245,139 @@ TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) { EXPECT_DEATH(shape_tree.element({0, 0}), ""); } +TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) { + ShapeTree> shape_tree{tuple_shape_}; + EXPECT_EQ(shape_tree.element({2}).get(), nullptr); + *shape_tree.mutable_element({2}) = MakeUnique(42); + EXPECT_EQ(*shape_tree.element({2}), 42); +} + +TEST_F(ShapeTreeTest, CopySubtreeFromArrayShape) { + // Test CopySubtreeFrom method for a single value copied between array-shaped + // ShapeTrees. + ShapeTree source(array_shape_); + *source.mutable_element(/*index=*/{}) = 42; + ShapeTree destination(array_shape_, 123); + + EXPECT_EQ(destination.element(/*index=*/{}), 123); + destination.CopySubtreeFrom(source, /*source_base_index=*/{}, + /*target_base_index=*/{}); + EXPECT_EQ(destination.element(/*index=*/{}), 42); +} + +TEST_F(ShapeTreeTest, FullCopySubtreeFromTupleShape) { + // Test CopySubtreeFrom method for a copy of all elements from one + // tuple-shaped ShapeTree to another. + ShapeTree source(tuple_shape_); + *source.mutable_element(/*index=*/{}) = 10; + *source.mutable_element(/*index=*/{0}) = 11; + *source.mutable_element(/*index=*/{1}) = 12; + *source.mutable_element(/*index=*/{2}) = 13; + + ShapeTree destination(tuple_shape_, 0); + + destination.CopySubtreeFrom(source, /*source_base_index=*/{}, + /*target_base_index=*/{}); + EXPECT_EQ(destination.element(/*index=*/{}), 10); + EXPECT_EQ(destination.element(/*index=*/{0}), 11); + EXPECT_EQ(destination.element(/*index=*/{1}), 12); + EXPECT_EQ(destination.element(/*index=*/{2}), 13); +} + +TEST_F(ShapeTreeTest, SingleElementCopySubtreeFromTupleShape) { + // Test CopySubtreeFrom method for a copy of a single element from one + // tuple-shaped ShapeTree to another. + ShapeTree source(tuple_shape_); + *source.mutable_element(/*index=*/{}) = 10; + *source.mutable_element(/*index=*/{0}) = 11; + *source.mutable_element(/*index=*/{1}) = 12; + *source.mutable_element(/*index=*/{2}) = 13; + + ShapeTree destination(tuple_shape_, 0); + + destination.CopySubtreeFrom(source, /*source_base_index=*/{0}, + /*target_base_index=*/{1}); + EXPECT_EQ(destination.element(/*index=*/{}), 0); + EXPECT_EQ(destination.element(/*index=*/{0}), 0); + EXPECT_EQ(destination.element(/*index=*/{1}), 11); + EXPECT_EQ(destination.element(/*index=*/{2}), 0); +} + +TEST_F(ShapeTreeTest, CopySubtreeIntoNestedShape) { + // Test CopySubtreeFrom method for a copy of a tuple-shaped ShapeTree into a + // nested-tuple-shaped ShapeTree. + ShapeTree source( + ShapeUtil::MakeTupleShape({array_shape_, array_shape_})); + *source.mutable_element(/*index=*/{}) = 10; + *source.mutable_element(/*index=*/{0}) = 11; + *source.mutable_element(/*index=*/{1}) = 12; + + ShapeTree destination(nested_tuple_shape_, 0); + + destination.CopySubtreeFrom(source, /*source_base_index=*/{}, + /*target_base_index=*/{2, 0}); + + EXPECT_EQ(destination.element(/*index=*/{}), 0); + EXPECT_EQ(destination.element(/*index=*/{0}), 0); + EXPECT_EQ(destination.element(/*index=*/{1}), 0); + EXPECT_EQ(destination.element(/*index=*/{1, 0}), 0); + EXPECT_EQ(destination.element(/*index=*/{1, 1}), 0); + EXPECT_EQ(destination.element(/*index=*/{2}), 0); + EXPECT_EQ(destination.element(/*index=*/{2, 0}), 10); + EXPECT_EQ(destination.element(/*index=*/{2, 0, 0}), 11); + EXPECT_EQ(destination.element(/*index=*/{2, 0, 1}), 12); + EXPECT_EQ(destination.element(/*index=*/{2, 1}), 0); +} + +TEST_F(ShapeTreeTest, CopySubtreeFromNestedShape) { + // Test CopySubtreeFrom method for a copy from a nested-tuple-shape. + ShapeTree source(nested_tuple_shape_, 42); + *source.mutable_element(/*index=*/{1}) = 10; + *source.mutable_element(/*index=*/{1, 0}) = 11; + *source.mutable_element(/*index=*/{1, 1}) = 12; + + ShapeTree destination( + ShapeUtil::MakeTupleShape({array_shape_, array_shape_}), 0); + + destination.CopySubtreeFrom(source, /*source_base_index=*/{1}, + /*target_base_index=*/{}); + + EXPECT_EQ(destination.element(/*index=*/{}), 10); + EXPECT_EQ(destination.element(/*index=*/{0}), 11); + EXPECT_EQ(destination.element(/*index=*/{1}), 12); +} + +TEST_F(ShapeTreeTest, OperatorEquals) { + { + ShapeTree a(array_shape_, 123); + ShapeTree b(array_shape_, 42); + ShapeTree c(array_shape_, 42); + EXPECT_FALSE(a == b); + EXPECT_TRUE(a != b); + EXPECT_TRUE(b == c); + } + { + ShapeTree a(tuple_shape_); + *a.mutable_element(/*index=*/{}) = 10; + *a.mutable_element(/*index=*/{0}) = 11; + *a.mutable_element(/*index=*/{1}) = 12; + + ShapeTree b(tuple_shape_); + *b.mutable_element(/*index=*/{}) = 10; + *b.mutable_element(/*index=*/{0}) = 42; + *b.mutable_element(/*index=*/{1}) = 11; + + ShapeTree c(tuple_shape_); + *c.mutable_element(/*index=*/{}) = 10; + *c.mutable_element(/*index=*/{0}) = 42; + *c.mutable_element(/*index=*/{1}) = 11; + + EXPECT_FALSE(a == b); + EXPECT_TRUE(a != b); + EXPECT_TRUE(b == c); + EXPECT_FALSE(b != c); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index ccc1dc63e78..8d04935a0bc 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -122,7 +122,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { for (const auto& shape : parameters) { *program_shape.add_parameters() = shape; } - *program_shape.mutable_result() = result; + *program_shape.mutable_result() = std::move(result); return program_shape; } diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index e2e6e25c06c..7a512166171 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -829,6 +829,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { const int count = GetParam(); ComputationBuilder builder(client_, TestName()); std::vector values; + values.reserve(count); for (int i = 0; i < count; ++i) { values.push_back(i / static_cast(count)); } @@ -836,6 +837,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { auto exp = builder.Pow(x, builder.ConstantR0(2.0f)); std::vector expected; + expected.reserve(values.size()); for (float value : values) { expected.push_back(value * value); } diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 7bf1168dc39..08e3f81a283 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -179,7 +179,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( VLOG(1) << "expected: " << LiteralUtil::ToString(*expected_literal); VLOG(1) << "actual: " << LiteralUtil::ToString(*actual); - EXPECT_EQ(expected, actual->u8s()); + EXPECT_EQ(expected, actual->u8s_string()); } void ClientLibraryTestBase::ComputeAndCompareTuple( diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index 63bfac441d3..fcdbe130d0b 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -442,6 +442,39 @@ XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) { ComputeAndCompareR1(&builder, expected, {}); } +XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) { + ComputationBuilder builder(client_, TestName()); + + Array3D arr0(9, 17, 1); + arr0.Fill(1); + + Array3D arr1(9, 17, 256); + arr1.Fill(2); + + Array3D expected(9, 17, arr0.n3() + arr1.n3()); + for (int64 i = 0; i < expected.n1(); ++i) { + for (int64 j = 0; j < expected.n2(); ++j) { + int64 kk = 0; + for (const Array3D& arr : {arr0, arr1}) { + for (int64 k = 0; k < arr.n3(); ++k, ++kk) { + expected(i, j, kk) = arr(i, j, k); + } + } + } + } + + ComputationDataHandle h0; + auto p0 = CreateR3Parameter(arr0, /*parameter_number=*/0, "p0", + &builder, &h0); + ComputationDataHandle h1; + auto p1 = CreateR3Parameter(arr1, /*parameter_number=*/1, "p1", + &builder, &h1); + + auto concatenated = builder.ConcatInDim({h0, h1}, 2); + + ComputeAndCompareR3(&builder, expected, {p0.get(), p1.get()}); +} + // Describes a binary rank-2 concatenation test. struct R2BinarySpec { int64 lhs_dim0; diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 23453db57bc..eb979ad189d 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -262,7 +262,7 @@ class NearComparator { max_abs_err_ = 0.0; *miscompares_.mutable_shape() = ShapeUtil::ChangeElementType(actual.shape(), PRED); - miscompares_.mutable_preds()->Resize( + miscompares_.mutable_preds()->resize( ShapeUtil::ElementsIn(miscompares_.shape()), false); multi_index_.resize(expected.shape().dimensions_size(), 0); @@ -389,7 +389,7 @@ class NearComparator { tensorflow::strings::Printf("tempfile-%s-%llx-%s", Hostname().c_str(), now_usec, name.c_str())); TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), - filename, literal)); + filename, literal.ToProto())); LOG(ERROR) << "wrote to " << name << " file: " << filename; } diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index fdec11c0e98..a94f45f73b7 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -83,9 +83,10 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { LOG(INFO) << "results: [" << tensorflow::str_util::Join(results, ", ") << "]"; EXPECT_EQ(3, results.size()); for (const string& result : results) { - Literal literal; + LiteralProto literal_proto; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result, - &literal)); + &literal_proto)); + Literal literal(literal_proto); if (result.find("expected") != string::npos) { EXPECT_EQ("2", LiteralUtil::ToString(literal)); } else if (result.find("actual") != string::npos) { diff --git a/tensorflow/compiler/xla/tests/log_test.cc b/tensorflow/compiler/xla/tests/log_test.cc index b520d89de3c..d3d1039e1bb 100644 --- a/tensorflow/compiler/xla/tests/log_test.cc +++ b/tensorflow/compiler/xla/tests/log_test.cc @@ -47,6 +47,7 @@ TEST_F(LogTest, LogTenValues) { builder.Log(x); std::vector expected; + expected.reserve(input.size()); for (float f : input) { expected.push_back(std::log(f)); } diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index 2f05576ceeb..cd8f06efd82 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -246,6 +246,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { } std::vector param_data; + param_data.reserve(param_data_owner.size()); for (const std::unique_ptr& data : param_data_owner) { param_data.push_back(data.get()); } diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index d63582fb98a..82bdd6d35f0 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -37,6 +37,7 @@ class SliceTest : public ClientLibraryTestBase { template void RunSliceTenToTwo() { std::vector constant; + constant.reserve(10); for (int i = 0; i < 10; ++i) { constant.push_back(static_cast(i)); } diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index 4ab4c84aa56..41bac6234da 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -64,6 +64,7 @@ TEST_F(VecOpsSimpleTest, ExpManyValues) { for (int count : {63, 64, 65, 127, 128, 129, 17 * 4096}) { ComputationBuilder builder(client_, TestName()); std::vector exponents; + exponents.reserve(count); for (int i = 0; i < count; ++i) { exponents.push_back(i / static_cast(count)); } @@ -71,6 +72,7 @@ TEST_F(VecOpsSimpleTest, ExpManyValues) { auto exp = builder.Exp(x); std::vector expected; + expected.reserve(exponents.size()); for (float exponent : exponents) { expected.push_back(std::exp(exponent)); } diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h index 3cfbb2c7fbf..e45e5291c9b 100644 --- a/tensorflow/compiler/xla/text_literal_reader.h +++ b/tensorflow/compiler/xla/text_literal_reader.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h index 545bd22da91..7375493f430 100644 --- a/tensorflow/compiler/xla/text_literal_writer.h +++ b/tensorflow/compiler/xla/text_literal_writer.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ #define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index 4c242abc9b7..8d7f7fd1237 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -81,6 +81,7 @@ void RealMain(tensorflow::gtl::ArraySlice args) { client->GetComputationShape(computation).ConsumeValueOrDie(); std::vector layouts; + layouts.reserve(program_shape->parameters_size()); for (int i = 0; i < program_shape->parameters_size(); ++i) { layouts.push_back(&program_shape->parameters(i)); } diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 152e0dcf56a..2a3a8803283 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -56,6 +56,7 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { client->GetComputationShape(computation).ConsumeValueOrDie(); std::vector layouts; + layouts.reserve(program_shape->parameters_size()); for (int i = 0; i < program_shape->parameters_size(); ++i) { layouts.push_back(&program_shape->parameters(i)); } diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index ffb2d5aefba..3a75bf64954 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -66,7 +66,8 @@ StatusOr> ReplayComputation( if (use_fake_data) { arguments = MakeFakeArgumentsOrDie(computation, client); } else { // use recorded data if available - for (const Literal& literal : module.arguments()) { + for (const auto& proto : module.arguments()) { + Literal literal(proto); TF_ASSIGN_OR_RETURN(std::unique_ptr data, client->TransferToServer(literal)); arguments.push_back(std::move(data)); @@ -74,6 +75,7 @@ StatusOr> ReplayComputation( } std::vector execute_arguments; + execute_arguments.reserve(arguments.size()); for (auto& argument : arguments) { execute_arguments.push_back(argument.get()); } @@ -100,7 +102,7 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool use_fake_data) { if (module.has_result()) { fprintf(stdout, "was %s:%s\n", ShapeUtil::HumanString(module.result().shape()).c_str(), - LiteralUtil::ToString(module.result()).c_str()); + LiteralUtil::ToString(Literal(module.result())).c_str()); } } } diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc index cf363913b15..b6538f5de07 100644 --- a/tensorflow/compiler/xla/tools/show_literal.cc +++ b/tensorflow/compiler/xla/tools/show_literal.cc @@ -37,9 +37,10 @@ int main(int argc, char **argv) { << " "; } - xla::Literal literal; + xla::LiteralProto literal_proto; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1], - &literal)); - LOG(INFO) << "literal: " << literal.ShortDebugString(); + &literal_proto)); + xla::Literal literal(literal_proto); + LOG(INFO) << "literal: " << literal_proto.ShortDebugString(); fprintf(stderr, "%s\n", xla::LiteralUtil::ToString(literal).c_str()); } diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 716eb424424..193ae49afee 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -92,11 +92,11 @@ message TransferToClientRequest { } message TransferToClientResponse { - Literal literal = 1; + LiteralProto literal = 1; } message TransferToServerRequest { - Literal literal = 1; + LiteralProto literal = 1; DeviceHandle device_handle = 2; } @@ -105,7 +105,7 @@ message TransferToServerResponse { } message TransferToInfeedRequest { - Literal literal = 1; + LiteralProto literal = 1; int64 replica_id = 2; DeviceHandle device_handle = 3; } @@ -123,7 +123,7 @@ message TransferFromOutfeedRequest { } message TransferFromOutfeedResponse { - Literal literal = 1; + LiteralProto literal = 1; } message ResetDeviceRequest { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 1239816c50e..44a94e171fa 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -275,7 +275,7 @@ message ChannelHandle { // // Transfers to/from the client are encoded in literal form, and the structure // of the repeated fields is implied by the shape. -message Literal { +message LiteralProto { Shape shape = 1; repeated bool preds = 2; bytes u8s = 3; @@ -285,7 +285,7 @@ message Literal { repeated uint64 u64s = 7; repeated float f32s = 8; repeated double f64s = 9; - repeated Literal tuple_literals = 10; + repeated LiteralProto tuple_literals = 10; bytes f16s = 11; // Note: the F16s are encoded in little endian byte order } @@ -337,7 +337,7 @@ message Window { // field in OpRequest. message ConstantRequest { - Literal literal = 2; + LiteralProto literal = 2; } message GetTupleElementRequest { diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 8fc6c27e355..7286cce03c6 100755 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -85,6 +85,7 @@ cc_library( "//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels", "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel", "//tensorflow/contrib/nccl:nccl_kernels", + "//tensorflow/contrib/seq2seq:beam_search_ops_kernels", "//tensorflow/contrib/tensor_forest:tensor_forest_kernels", "//tensorflow/contrib/text:all_kernels", ], @@ -100,6 +101,7 @@ cc_library( "//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib", "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib", "//tensorflow/contrib/nccl:nccl_ops_op_lib", + "//tensorflow/contrib/seq2seq:beam_search_ops_op_lib", "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib", "//tensorflow/contrib/text:all_ops", ], diff --git a/tensorflow/contrib/batching/kernels/batch_kernels.cc b/tensorflow/contrib/batching/kernels/batch_kernels.cc index 1e0957298ba..3c06325651f 100644 --- a/tensorflow/contrib/batching/kernels/batch_kernels.cc +++ b/tensorflow/contrib/batching/kernels/batch_kernels.cc @@ -347,6 +347,7 @@ class BatchResource : public ResourceBase { // Concatenate the tasks ith input tensors into a big output tensor. std::vector to_concatenate; + to_concatenate.reserve(batch->num_tasks()); for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) { to_concatenate.push_back(batch->task(task_idx).inputs.at(i)); } diff --git a/tensorflow/contrib/batching/shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/shared_batch_scheduler_test.cc index 809958c737e..3e924ae5f13 100644 --- a/tensorflow/contrib/batching/shared_batch_scheduler_test.cc +++ b/tensorflow/contrib/batching/shared_batch_scheduler_test.cc @@ -139,6 +139,7 @@ TEST(SharedBatchSchedulerTest, ObeyBatchSizeConstraint) { &callback_data](std::unique_ptr> batch) { ASSERT_TRUE(batch->IsClosed()); std::vector batch_data; + batch_data.reserve(batch->num_tasks()); for (int i = 0; i < batch->num_tasks(); ++i) { batch_data.push_back(batch->mutable_task(i)->size()); } diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc index 66e0995ecd0..f658532acb2 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc @@ -295,6 +295,7 @@ void ExpectVecsEquiv(const std::vector& vec1, std::vector GetWeightsByIndex(const std::vector& weights, const std::vector& indices) { std::vector res; + res.reserve(indices.size()); for (const int index : indices) { res.push_back(weights[index]); } diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 19e608f5abe..8287d6838a5 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -236,6 +236,9 @@ add_python_module("tensorflow/tensorboard") add_python_module("tensorflow/tensorboard/backend") add_python_module("tensorflow/tensorboard/backend/event_processing") add_python_module("tensorflow/tensorboard/plugins") +add_python_module("tensorflow/tensorboard/plugins/audio") +add_python_module("tensorflow/tensorboard/plugins/distributions") +add_python_module("tensorflow/tensorboard/plugins/graphs") add_python_module("tensorflow/tensorboard/plugins/histograms") add_python_module("tensorflow/tensorboard/plugins/images") add_python_module("tensorflow/tensorboard/plugins/projector") @@ -536,6 +539,7 @@ set(tf_python_op_gen_main_srcs "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.h" ) add_library(tf_python_op_gen_main OBJECT ${tf_python_op_gen_main_srcs}) diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 482a6f62c42..55e9e311f92 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -209,10 +209,11 @@ if (tensorflow_BUILD_PYTHON_TESTS) # Broken TensorBoard tests due to different paths in windows "${tensorflow_source_dir}/tensorflow/tensorboard/backend/application_test.py" "${tensorflow_source_dir}/tensorflow/tensorboard/lib/python/http_util_test.py" + "${tensorflow_source_dir}/tensorflow/tensorboard/plugins/audio/audio_plugin_test.py" + "${tensorflow_source_dir}/tensorflow/tensorboard/plugins/images/images_plugin_test.py" # Broken tensorboard test due to cmake issues. "${tensorflow_source_dir}/tensorflow/tensorboard/plugins/debugger/plugin_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" - "${tensorflow_source_dir}/tensorflow/tensorboard/plugins/images/images_plugin_test.py" # tensor_forest tests (also note that we exclude the hybrid tests for now) "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order. "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order. diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index 2f7f8ebbae8..68cd3623c00 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -150,7 +150,8 @@ class MapDatasetTest(test.TestCase): results.append(sess.run(get_next)) except errors.OutOfRangeError: return - threads = [self.checkedThread(target=iterator_thread) for _ in range(8)] + threads = [self.checkedThread(target=iterator_thread) + for _ in range(64)] for t in threads: t.start() for t in threads: diff --git a/tensorflow/contrib/factorization/kernels/clustering_ops.cc b/tensorflow/contrib/factorization/kernels/clustering_ops.cc index 3a964311820..a2136c08bbc 100644 --- a/tensorflow/contrib/factorization/kernels/clustering_ops.cc +++ b/tensorflow/contrib/factorization/kernels/clustering_ops.cc @@ -375,8 +375,8 @@ class NearestNeighborsOp : public OpKernel { const Eigen::Ref& points_half_squared_norm, const Eigen::Ref& centers, const Eigen::Ref& centers_half_squared_norm, - Eigen::Ref nearest_center_indices, - Eigen::Ref nearest_center_distances) { + const Eigen::Ref& nearest_center_indices, + const Eigen::Ref& nearest_center_distances) { CHECK_LE(k, centers.rows()); if (centers.rows() <= kNearestNeighborsCentersMaxBlockSize) { FindKNearestCentersOneBlock(k, points, points_half_squared_norm, centers, diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py index d3fa233a124..42815664adf 100644 --- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py +++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py @@ -164,11 +164,12 @@ class KMeans(object): with ops.colocate_with(inp): # Computes Euclidean distance. Note the first and third terms are # broadcast additions. - squared_distance = (math_ops.reduce_sum( - math_ops.square(inp), 1, keep_dims=True) - 2 * math_ops.matmul( - inp, clusters, transpose_b=True) + array_ops.transpose( - math_ops.reduce_sum( - math_ops.square(clusters), 1, keep_dims=True))) + squared_distance = ( + math_ops.reduce_sum(math_ops.square(inp), 1, keep_dims=True) - + 2 * math_ops.matmul(inp, clusters, transpose_b=True) + + array_ops.transpose( + math_ops.reduce_sum( + math_ops.square(clusters), 1, keep_dims=True))) output.append(squared_distance) return output @@ -229,12 +230,12 @@ class KMeans(object): clusters = nn_impl.l2_normalize(clusters, dim=1) for inp, score in zip(inputs, scores): with ops.colocate_with(inp): - (indices, - distances) = gen_clustering_ops.nearest_neighbors(inp, clusters, 1) + (indices, distances) = gen_clustering_ops.nearest_neighbors( + inp, clusters, 1) if self._distance_metric == COSINE_DISTANCE: distances *= 0.5 - output.append( - (score, array_ops.squeeze(distances), array_ops.squeeze(indices))) + output.append((score, array_ops.squeeze(distances), + array_ops.squeeze(indices))) return zip(*output) def _init_clusters_random(self): @@ -265,9 +266,7 @@ class KMeans(object): (not self._use_mini_batch or self._mini_batch_steps_per_iteration > 1)) - def _initialize_clusters(self, - cluster_centers, - cluster_centers_initialized, + def _initialize_clusters(self, cluster_centers, cluster_centers_initialized, cluster_centers_updated): """Returns an op to initialize the cluster centers.""" @@ -294,22 +293,20 @@ class KMeans(object): with ops.colocate_with(cluster_centers_initialized): initialized = control_flow_ops.with_dependencies( - [clusters_init], - array_ops.identity(cluster_centers_initialized)) + [clusters_init], array_ops.identity(cluster_centers_initialized)) with ops.colocate_with(cluster_centers): - assign_centers = state_ops.assign(cluster_centers, clusters_init, - validate_shape=False) + assign_centers = state_ops.assign( + cluster_centers, clusters_init, validate_shape=False) if cluster_centers_updated != cluster_centers: - assign_centers = control_flow_ops.group( - assign_centers, - state_ops.assign(cluster_centers_updated, clusters_init, - validate_shape=False)) - assign_centers = control_flow_ops.with_dependencies( - [assign_centers], - state_ops.assign(cluster_centers_initialized, True)) - return control_flow_ops.cond(initialized, - control_flow_ops.no_op, - lambda: assign_centers).op + assign_centers = control_flow_ops.group(assign_centers, + state_ops.assign( + cluster_centers_updated, + clusters_init, + validate_shape=False)) + assign_centers = control_flow_ops.with_dependencies( + [assign_centers], state_ops.assign(cluster_centers_initialized, True)) + return control_flow_ops.cond(initialized, control_flow_ops.no_op, + lambda: assign_centers).op def _create_variables(self): """Creates variables. @@ -327,19 +324,16 @@ class KMeans(object): cluster_centers_updated back to cluster_centers. """ init_value = array_ops.constant([], dtype=dtypes.float32) - cluster_centers = variable_scope.variable(init_value, - name='clusters', - validate_shape=False) - cluster_centers_initialized = variable_scope.variable(False, - dtype=dtypes.bool, - name='initialized') + cluster_centers = variable_scope.variable( + init_value, name='clusters', validate_shape=False) + cluster_centers_initialized = variable_scope.variable( + False, dtype=dtypes.bool, name='initialized') if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1: # Copy of cluster centers actively updated each step according to # mini-batch update rule. - cluster_centers_updated = variable_scope.variable(init_value, - name='clusters_updated', - validate_shape=False) + cluster_centers_updated = variable_scope.variable( + init_value, name='clusters_updated', validate_shape=False) # How many steps till we copy the updated clusters to cluster_centers. update_in_steps = variable_scope.variable( self._mini_batch_steps_per_iteration, @@ -347,20 +341,15 @@ class KMeans(object): name='update_in_steps') # Count of points assigned to cluster_centers_updated. cluster_counts = variable_scope.variable( - array_ops.zeros([self._num_clusters], - dtype=dtypes.int64)) + array_ops.zeros([self._num_clusters], dtype=dtypes.int64)) else: cluster_centers_updated = cluster_centers update_in_steps = None - cluster_counts = (variable_scope.variable(array_ops.ones( - [self._num_clusters], - dtype=dtypes.int64)) + cluster_counts = (variable_scope.variable( + array_ops.ones([self._num_clusters], dtype=dtypes.int64)) if self._use_mini_batch else None) - return (cluster_centers, - cluster_centers_initialized, - cluster_counts, - cluster_centers_updated, - update_in_steps) + return (cluster_centers, cluster_centers_initialized, cluster_counts, + cluster_centers_updated, update_in_steps) @classmethod def _l2_normalize_data(cls, inputs): @@ -391,11 +380,8 @@ class KMeans(object): """ # Implementation of kmeans. inputs = self._inputs - (cluster_centers_var, - cluster_centers_initialized, - total_counts, - cluster_centers_updated, - update_in_steps) = self._create_variables() + (cluster_centers_var, cluster_centers_initialized, total_counts, + cluster_centers_updated, update_in_steps) = self._create_variables() init_op = self._initialize_clusters(cluster_centers_var, cluster_centers_initialized, cluster_centers_updated) @@ -409,8 +395,7 @@ class KMeans(object): all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers) if self._use_mini_batch: sync_updates_op = self._mini_batch_sync_updates_op( - update_in_steps, - cluster_centers_var, cluster_centers_updated, + update_in_steps, cluster_centers_var, cluster_centers_updated, total_counts) assert sync_updates_op is not None with ops.control_dependencies([sync_updates_op]): @@ -421,15 +406,15 @@ class KMeans(object): training_op = self._full_batch_training_op(inputs, cluster_idx, cluster_centers_var) - return (all_scores, cluster_idx, scores, - cluster_centers_initialized, init_op, training_op) + return (all_scores, cluster_idx, scores, cluster_centers_initialized, + init_op, training_op) - def _mini_batch_sync_updates_op(self, update_in_steps, - cluster_centers_var, cluster_centers_updated, - total_counts): + def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var, + cluster_centers_updated, total_counts): if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1: assert update_in_steps is not None with ops.colocate_with(update_in_steps): + def _f(): # Note that there is a race condition here, so we do a best effort # updates here. We reset update_in_steps first so that other workers @@ -437,33 +422,36 @@ class KMeans(object): # before resetting total_counts to avoid large updates to # cluster_centers_updated based on partially updated # cluster_center_vars. - with ops.control_dependencies([state_ops.assign( - update_in_steps, - self._mini_batch_steps_per_iteration - 1)]): - with ops.colocate_with(cluster_centers_updated): + with ops.control_dependencies([ + state_ops.assign(update_in_steps, + self._mini_batch_steps_per_iteration - 1) + ]): + with ops.colocate_with( + cluster_centers_updated, ignore_existing=True): if self._distance_metric == COSINE_DISTANCE: - cluster_centers = nn_impl.l2_normalize(cluster_centers_updated, - dim=1) + cluster_centers = nn_impl.l2_normalize( + cluster_centers_updated, dim=1) else: cluster_centers = cluster_centers_updated with ops.colocate_with(cluster_centers_var): - with ops.control_dependencies([state_ops.assign( - cluster_centers_var, - cluster_centers)]): - with ops.colocate_with(cluster_centers_var): + with ops.control_dependencies( + [state_ops.assign(cluster_centers_var, cluster_centers)]): + with ops.colocate_with( + cluster_centers_var, ignore_existing=True): with ops.control_dependencies([ state_ops.assign(total_counts, - array_ops.zeros_like(total_counts))]): + array_ops.zeros_like(total_counts)) + ]): return array_ops.identity(update_in_steps) + return control_flow_ops.cond( - update_in_steps <= 0, - _f, + update_in_steps <= 0, _f, lambda: state_ops.assign_sub(update_in_steps, 1)) else: return control_flow_ops.no_op() - def _mini_batch_training_op(self, inputs, cluster_idx_list, - cluster_centers, total_counts): + def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers, + total_counts): """Creates an op for training for mini batch case. Args: @@ -487,17 +475,15 @@ class KMeans(object): unique_ids, unique_idx = array_ops.unique(cluster_idx) num_unique_cluster_idx = array_ops.size(unique_ids) # Fetch the old values of counts and cluster_centers. - with ops.colocate_with(total_counts): + with ops.colocate_with(total_counts, ignore_existing=True): old_counts = array_ops.gather(total_counts, unique_ids) # TODO(agarwal): This colocation seems to run into problems. Fix it. - # with ops.colocate_with(cluster_centers): - old_cluster_centers = array_ops.gather(cluster_centers, unique_ids) + with ops.colocate_with(cluster_centers, ignore_existing=True): + old_cluster_centers = array_ops.gather(cluster_centers, unique_ids) # Locally aggregate the increment to counts. count_updates = math_ops.unsorted_segment_sum( - array_ops.ones_like( - unique_idx, dtype=total_counts.dtype), - unique_idx, - num_unique_cluster_idx) + array_ops.ones_like(unique_idx, dtype=total_counts.dtype), + unique_idx, num_unique_cluster_idx) # Locally compute the sum of inputs mapped to each id. # For a cluster with old cluster value x, old count n, and with data # d_1,...d_k newly assigned to it, we recompute the new value as @@ -507,13 +493,12 @@ class KMeans(object): inp, unique_idx, num_unique_cluster_idx) # Shape to enable broadcasting count_updates and learning_rate to inp. # It extends the shape with 1's to match the rank of inp. - broadcast_shape = array_ops.concat( - [ - array_ops.reshape(num_unique_cluster_idx, [1]), array_ops.ones( - array_ops.reshape(array_ops.rank(inp) - 1, [1]), - dtype=dtypes.int32) - ], - 0) + broadcast_shape = array_ops.concat([ + array_ops.reshape(num_unique_cluster_idx, [1]), + array_ops.ones( + array_ops.reshape(array_ops.rank(inp) - 1, [1]), + dtype=dtypes.int32) + ], 0) # Subtract k * x, see comment above. cluster_center_updates -= math_ops.cast( array_ops.reshape(count_updates, broadcast_shape), @@ -524,14 +509,10 @@ class KMeans(object): # scale by 1 / (n + k), see comment above. cluster_center_updates *= learning_rate # Apply the updates. - update_counts = state_ops.scatter_add( - total_counts, - unique_ids, - count_updates) + update_counts = state_ops.scatter_add(total_counts, unique_ids, + count_updates) update_cluster_centers = state_ops.scatter_add( - cluster_centers, - unique_ids, - cluster_center_updates) + cluster_centers, unique_ids, cluster_center_updates) update_ops.extend([update_counts, update_cluster_centers]) return control_flow_ops.group(*update_ops) @@ -552,7 +533,7 @@ class KMeans(object): cluster_counts = [] epsilon = constant_op.constant(1e-6, dtype=inputs[0].dtype) for inp, cluster_idx in zip(inputs, cluster_idx_list): - with ops.colocate_with(inp): + with ops.colocate_with(inp, ignore_existing=True): cluster_sums.append( math_ops.unsorted_segment_sum(inp, cluster_idx, self._num_clusters)) cluster_counts.append( @@ -561,7 +542,7 @@ class KMeans(object): array_ops.ones( array_ops.reshape(array_ops.shape(inp)[0], [-1])), [-1, 1]), cluster_idx, self._num_clusters)) - with ops.colocate_with(cluster_centers): + with ops.colocate_with(cluster_centers, ignore_existing=True): new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast( math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon) if self._clusters_l2_normalized(): diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc index 2c6e278fec7..2871c146289 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc @@ -94,6 +94,7 @@ TEST(FfmpegLibTest, TestRoundTripGeneratedWav) { } std::vector sine_wave; + sine_wave.reserve(20000); for (int i = 0; i < 20000; ++i) { sine_wave.push_back(std::sin(6.28 * 440.0 * i / 20000.0)); } diff --git a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc index 47a5b2a2077..219473153bd 100644 --- a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc +++ b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc @@ -494,6 +494,7 @@ class SparseFeatureCrossOp : public OpKernel { ExtractFeatureData(indices_list_in, batch_size, &feature_counts, &feature_start_indices); + columns.reserve(values_list_in.size()); for (int i = 0; i < values_list_in.size(); ++i) { columns.emplace_back(new SparseTensorColumn( values_list_in[i], std::move(feature_counts[i]), diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py index a40cbc04490..bba479a00ee 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py +++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py @@ -308,6 +308,7 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_rea from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat from tensorflow.contrib.learn.python.learn.estimators.head import binary_svm_head from tensorflow.contrib.learn.python.learn.estimators.head import Head +from tensorflow.contrib.learn.python.learn.estimators.head import loss_only_head from tensorflow.contrib.learn.python.learn.estimators.head import multi_class_head from tensorflow.contrib.learn.python.learn.estimators.head import multi_head from tensorflow.contrib.learn.python.learn.estimators.head import multi_label_head diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index d270d89c12b..6e15e7891e9 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -429,6 +429,23 @@ def multi_label_head(n_classes, loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None) +def loss_only_head(loss_fn, head_name=None): + """Creates a Head that contains only loss terms. + + Loss only head holds additional loss terms to be added to other heads and + usually represents additional regularization terms in the objective function. + + Args: + loss_fn: a function that takes no argument and returns a list of + scalar tensors. + head_name: a name for for the head. + + Returns: + An instance of `Head` to hold the additional losses. + """ + return _LossOnlyHead(loss_fn, head_name=head_name) + + def multi_head(heads, loss_weights=None): """Creates a MultiHead stemming from same logits/hidden layer. @@ -1406,6 +1423,80 @@ class _MultiLabelHead(_SingleHead): return metrics +class _LossOnlyHead(Head): + """`Head` implementation for additional loss terms. + + This class only holds loss terms unrelated to any other heads (labels), + e.g. regularization. + + Common usage: + This is oftem combine with other heads in a multi head setup. + ```python + head = multi_head([ + head1, head2, loss_only_head('regularizer', regularizer)]) + ``` + """ + + def __init__(self, loss_fn, head_name=None): + self._loss_fn = loss_fn + self.head_name = head_name or "loss_only_head" + + @property + def logits_dimension(self): + return 0 + + def create_model_fn_ops(self, + features, + mode, + labels=None, + train_op_fn=None, + logits=None, + logits_input=None, + scope=None): + """See `_Head.create_model_fn_ops`. + + Args: + features: Not been used. + mode: Estimator's `ModeKeys`. + labels: Labels `Tensor`, or `dict` of same. + train_op_fn: Function that takes a scalar loss and returns an op to + optimize with the loss. + logits: Not been used. + logits_input: Not been used. + scope: Optional scope for variable_scope. If provided, will be passed to + all heads. Most users will want to set this to `None`, so each head + constructs a separate variable_scope according to its `head_name`. + + Returns: + A `ModelFnOps` object. + + Raises: + ValueError: if `mode` is not recognition. + """ + _check_mode_valid(mode) + loss = None + train_op = None + if mode != model_fn.ModeKeys.INFER: + with variable_scope.variable_scope(scope, default_name=self.head_name): + loss = self._loss_fn() + if isinstance(loss, list): + loss = math_ops.add_n(loss) + logging_ops.scalar_summary( + _summary_key(self.head_name, mkey.LOSS), loss) + if mode == model_fn.ModeKeys.TRAIN: + if train_op_fn is None: + raise ValueError("train_op_fn can not be None in TRAIN mode") + with ops.name_scope(None, "train_op", (loss,)): + train_op = train_op_fn(loss) + + return model_fn.ModelFnOps( + mode=mode, + loss=loss, + train_op=train_op, + predictions={}, + eval_metric_ops={}) + + class _MultiHead(Head): """`Head` implementation for multi objective learning. @@ -1525,7 +1616,10 @@ class _MultiHead(Head): if isinstance(logits, dict): head_logits_pairs = [] for head in self._heads: - head_logits_pairs.append((head, logits[head.head_name])) + if isinstance(head, _LossOnlyHead): + head_logits_pairs.append((head, None)) + else: + head_logits_pairs.append((head, logits[head.head_name])) else: # Split logits for each head. head_logits_pairs = zip(self._heads, self._split_logits(logits)) @@ -1606,6 +1700,8 @@ class _MultiHead(Head): predictions = {} output_alternatives = {} for head, m in zip(self._heads, all_model_fn_ops): + if isinstance(head, _LossOnlyHead): + continue head_name = head.head_name output_alternatives[head_name] = m.output_alternatives[head_name] for k, v in m.predictions.items(): diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index 012b919d631..25a66748587 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -1638,6 +1638,21 @@ class BinarySvmHeadTest(test.TestCase): }, model_fn_ops) +class LossOnlyHead(test.TestCase): + + def testNoPredictionsAndNoMetrics(self): + head = head_lib.loss_only_head(lambda: 1, head_name="const") + model_fn_ops = head.create_model_fn_ops( + features={}, + mode=model_fn.ModeKeys.TRAIN, + train_op_fn=head_lib.no_op_train_fn) + self.assertDictEqual(model_fn_ops.predictions, {}) + self.assertDictEqual(model_fn_ops.eval_metric_ops, {}) + self.assertIsNotNone(model_fn_ops.loss) + with session.Session() as sess: + self.assertEqual(1, sess.run(model_fn_ops.loss)) + + class MultiHeadTest(test.TestCase): def testInvalidHeads(self): @@ -1672,7 +1687,8 @@ class MultiHeadTest(test.TestCase): n_classes=3, label_name="label1", head_name="head1") head2 = head_lib.multi_class_head( n_classes=4, label_name="label2", head_name="head2") - head = head_lib.multi_head((head1, head2)) + head3 = head_lib.loss_only_head(lambda: 1.0, head_name="const") + head = head_lib.multi_head((head1, head2, head3)) labels = { "label1": (1,), "label2": (1,) @@ -1691,7 +1707,7 @@ class MultiHeadTest(test.TestCase): self.assertIsNone(model_fn_ops.output_alternatives) with session.Session() as sess: - self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3) + self.assertAlmostEqual(3.224, sess.run(model_fn_ops.loss), places=3) def testTrain_withHeadWeights(self): head1 = head_lib.multi_class_head( diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index 65474f03fa0..e49b62afa28 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -871,7 +871,7 @@ def index_table_from_file(vocabulary_file=None, ``` Args: - vocabulary_file: The vocabulary filename. + vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`. num_oov_buckets: The number of out-of-vocabulary buckets. vocab_size: Number of the elements in the vocabulary, if known. default_value: The value to use for out-of-vocabulary feature values. @@ -889,8 +889,9 @@ def index_table_from_file(vocabulary_file=None, ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater than zero. """ - if not vocabulary_file: - raise ValueError("vocabulary_file must be specified.") + if vocabulary_file is None or ( + isinstance(vocabulary_file, str) and not vocabulary_file): + raise ValueError("vocabulary_file must be specified and must not be empty.") if num_oov_buckets < 0: raise ValueError("num_oov_buckets must be greater or equal than 0, got %d." % num_oov_buckets) diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 5ec169b6db4..180dfefe29d 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -1187,6 +1187,18 @@ class IndexTableFromFile(test.TestCase): lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) + def test_string_index_table_from_file_tensor_filename(self): + vocabulary_file = self._createVocabFile("f2i_vocab1.txt") + with self.test_session(): + vocabulary_file = constant_op.constant(vocabulary_file) + table = lookup.index_table_from_file( + vocabulary_file=vocabulary_file, num_oov_buckets=1) + ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) + + self.assertRaises(errors_impl.OpError, ids.eval) + lookup_ops.tables_initializer().run() + self.assertAllEqual((1, 2, 3), ids.eval()) + def test_int32_index_table_from_file(self): vocabulary_file = self._createVocabFile( "f2i_vocab2.txt", values=("42", "1", "-1000")) @@ -1245,7 +1257,13 @@ class IndexTableFromFile(test.TestCase): 860), # 3 + fingerprint("toccata") mod 300. ids.eval()) - def test_index_table_from_file_with_only_oov_buckets(self): + def test_index_table_from_file_fails_with_empty_vocabulary_file_name(self): + self.assertRaises( + ValueError, + lookup.index_table_from_file, + vocabulary_file="") + + def test_index_table_from_file_fails_with_empty_vocabulary(self): self.assertRaises( ValueError, lookup.index_table_from_file, diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py index 85eecfc6cd7..4c16fb50407 100644 --- a/tensorflow/contrib/metrics/__init__.py +++ b/tensorflow/contrib/metrics/__init__.py @@ -23,6 +23,7 @@ See the @{$python/contrib.metrics} guide. @@streaming_precision @@streaming_precision_at_thresholds @@streaming_auc +@@streaming_curve_points @@streaming_recall_at_k @@streaming_mean_absolute_error @@streaming_mean_iou @@ -76,6 +77,7 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_accuracy from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance +from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_curve_points from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives_at_thresholds from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positives diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 727cdd9597a..c2211961dfb 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -733,6 +733,102 @@ def streaming_true_negatives_at_thresholds( return values['tn'], update_ops['tn'] +def streaming_curve_points(labels=None, + predictions=None, + weights=None, + num_thresholds=200, + metrics_collections=None, + updates_collections=None, + curve='ROC', + name=None): + """Computes curve (ROC or PR) values for a prespecified number of points. + + The `streaming_curve_points` function creates four local variables, + `true_positives`, `true_negatives`, `false_positives` and `false_negatives` + that are used to compute the curve values. To discretize the curve, a linearly + spaced set of thresholds is used to compute pairs of recall and precision + values. + + For best results, `predictions` should be distributed approximately uniformly + in the range [0, 1] and not peaked around 0 or 1. + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + labels: A `Tensor` whose shape matches `predictions`. Will be cast to + `bool`. + predictions: A floating point `Tensor` of arbitrary shape and whose values + are in the range `[0, 1]`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + num_thresholds: The number of thresholds to use when discretizing the roc + curve. + metrics_collections: An optional list of collections that `auc` should be + added to. + updates_collections: An optional list of collections that `update_op` should + be added to. + curve: Specifies the name of the curve to be computed, 'ROC' [default] or + 'PR' for the Precision-Recall-curve. + name: An optional variable_scope name. + + Returns: + points: A `Tensor` with shape [num_thresholds, 2] that contains points of + the curve. + update_op: An operation that increments the `true_positives`, + `true_negatives`, `false_positives` and `false_negatives` variables. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + """ + with variable_scope.variable_scope(name, 'curve_points', (labels, predictions, + weights)): + if curve != 'ROC' and curve != 'PR': + raise ValueError('curve must be either ROC or PR, %s unknown' % (curve)) + kepsilon = 1e-7 # to account for floating point imprecisions + thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) + for i in range(num_thresholds - 2)] + thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] + + values, update_ops = _streaming_confusion_matrix_at_thresholds( + labels=labels, + predictions=predictions, + thresholds=thresholds, + weights=weights) + + # Add epsilons to avoid dividing by 0. + epsilon = 1.0e-6 + + def compute_points(tp, fn, tn, fp): + """Computes the roc-auc or pr-auc based on confusion counts.""" + rec = math_ops.div(tp + epsilon, tp + fn + epsilon) + if curve == 'ROC': + fp_rate = math_ops.div(fp, fp + tn + epsilon) + return fp_rate, rec + else: # curve == 'PR'. + prec = math_ops.div(tp + epsilon, tp + fp + epsilon) + return rec, prec + + xs, ys = compute_points(values['tp'], values['fn'], values['tn'], + values['fp']) + points = array_ops.stack([xs, ys], axis=1) + update_op = control_flow_ops.group(*update_ops.values()) + + if metrics_collections: + ops.add_to_collections(metrics_collections, points) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return points, update_op + + def streaming_auc(predictions, labels, weights=None, num_thresholds=200, metrics_collections=None, updates_collections=None, curve='ROC', name=None): @@ -2372,6 +2468,7 @@ __all__ = [ 'sparse_recall_at_top_k', 'streaming_accuracy', 'streaming_auc', + 'streaming_curve_points', 'streaming_false_negatives', 'streaming_false_negatives_at_thresholds', 'streaming_false_positives', diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index f97f03e30e1..f93b1945a69 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -1327,6 +1327,99 @@ class StreamingRecallTest(test.TestCase): self.assertEqual(0, recall.eval()) +class StreamingCurvePointsTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metric_ops.streaming_curve_points( + predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) + _assert_local_variables( + self, + ('curve_points/true_positives:0', 'curve_points/false_negatives:0', + 'curve_points/false_positives:0', 'curve_points/true_negatives:0')) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + points, _ = metric_ops.streaming_curve_points( + labels=array_ops.ones((10, 1)), + predictions=array_ops.ones((10, 1)), + metrics_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [points]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, update_op = metric_ops.streaming_curve_points( + labels=array_ops.ones((10, 1)), + predictions=array_ops.ones((10, 1)), + updates_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) + + def _testValueTensorIsIdempotent(self, curve): + predictions = constant_op.constant( + np.random.uniform(size=(10, 3)), dtype=dtypes_lib.float32) + labels = constant_op.constant( + np.random.uniform(high=2, size=(10, 3)), dtype=dtypes_lib.float32) + + points, update_op = metric_ops.streaming_curve_points( + labels, predictions=predictions, curve=curve) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + + sess.run(update_op) + initial_points = points.eval() + + sess.run(update_op) + self.assertAllClose(initial_points, points.eval()) + + def testValueTensorIsIdempotentROC(self): + self._testValueTensorIsIdempotent(curve='ROC') + + def testValueTensorIsIdempotentPR(self): + self._testValueTensorIsIdempotent(curve='PR') + + def _testCase(self, labels, predictions, curve, expected_points): + with self.test_session() as sess: + predictions_tensor = constant_op.constant( + predictions, dtype=dtypes_lib.float32) + labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.float32) + points, update_op = metric_ops.streaming_curve_points( + labels=labels_tensor, + predictions=predictions_tensor, + num_thresholds=3, + curve=curve) + + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + + self.assertAllClose(expected_points, points.eval()) + + def testEdgeCasesROC(self): + self._testCase([[1]], [[1]], 'ROC', [[0, 1], [0, 1], [0, 0]]) + self._testCase([[0]], [[0]], 'ROC', [[1, 1], [0, 1], [0, 1]]) + self._testCase([[0]], [[1]], 'ROC', [[1, 1], [1, 1], [0, 1]]) + self._testCase([[1]], [[0]], 'ROC', [[0, 1], [0, 0], [0, 0]]) + + def testManyValuesROC(self): + self._testCase([[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]], + [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], 'ROC', + [[1.0, 1.0], [0.0, 0.75], [0.0, 0.0]]) + + def testEdgeCasesPR(self): + self._testCase([[1]], [[1]], 'PR', [[1, 1], [1, 1], [0, 1]]) + self._testCase([[0]], [[0]], 'PR', [[1, 0], [1, 1], [1, 1]]) + self._testCase([[0]], [[1]], 'PR', [[1, 0], [1, 0], [1, 1]]) + self._testCase([[1]], [[0]], 'PR', [[1, 1], [0, 1], [0, 1]]) + + def testManyValuesPR(self): + self._testCase([[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]], + [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], 'PR', + [[1.0, 4.0 / 6.0], [0.75, 1.0], [0.0, 1.0]]) + + class StreamingAUCTest(test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index aeafe7c3e59..3d0627467aa 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -226,8 +226,8 @@ class TestBeamStep(test.TestCase): class BeamSearchDecoderTest(test.TestCase): def _testDynamicDecodeRNN(self, time_major, has_attention): - encoder_sequence_length = [3, 2, 3, 1, 1] - decoder_sequence_length = [2, 0, 1, 2, 3] + encoder_sequence_length = np.array([3, 2, 3, 1, 1]) + decoder_sequence_length = np.array([2, 0, 1, 2, 3]) batch_size = 5 decoder_max_time = 4 input_depth = 7 @@ -245,6 +245,7 @@ class BeamSearchDecoderTest(test.TestCase): batch_size_tensor = constant_op.constant(batch_size) embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32) cell = rnn_cell.LSTMCell(cell_depth) + initial_state = cell.zero_state(batch_size, dtypes.float32) if has_attention: inputs = array_ops.placeholder_with_default( np.random.randn(batch_size, decoder_max_time, @@ -258,6 +259,8 @@ class BeamSearchDecoderTest(test.TestCase): num_units=attention_depth, memory=tiled_inputs, memory_sequence_length=tiled_sequence_length) + initial_state = beam_search_decoder.tile_batch( + initial_state, multiplier=beam_width) cell = attention_wrapper.AttentionWrapper( cell=cell, attention_mechanism=attention_mechanism, @@ -265,6 +268,9 @@ class BeamSearchDecoderTest(test.TestCase): alignment_history=False) cell_state = cell.zero_state( dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width) + if has_attention: + cell_state = cell_state.clone( + cell_state=initial_state) bsd = beam_search_decoder.BeamSearchDecoder( cell=cell, embedding=embedding, diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index f1d0ab07711..1d1babda163 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -72,10 +72,30 @@ class FinalBeamSearchDecoderOutput( pass -def tile_batch(t, multiplier, name=None): - """Tile the batch dimension of tensor t. +def _tile_batch(t, multiplier): + """Core single-tensor implementation of tile_batch.""" + t = ops.convert_to_tensor(t, name="t") + shape_t = array_ops.shape(t) + if t.shape.ndims is None or t.shape.ndims < 1: + raise ValueError("t must have statically known rank") + tiling = [1] * (t.shape.ndims + 1) + tiling[1] = multiplier + tiled_static_batch_size = ( + t.shape[0].value * multiplier if t.shape[0].value is not None else None) + tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling) + tiled = array_ops.reshape( + tiled, array_ops.concat(([shape_t[0] * multiplier], shape_t[1:]), 0)) + tiled.set_shape( + tensor_shape.TensorShape( + [tiled_static_batch_size]).concatenate(t.shape[1:])) + return tiled - This function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of + +def tile_batch(t, multiplier, name=None): + """Tile the batch dimension of a (possibly nested structure of) tensor(s) t. + + For each tensor t in a (possibly nested structure) of tensors, + this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated @@ -87,27 +107,16 @@ def tile_batch(t, multiplier, name=None): name: Name scope for any created operations. Returns: - A `Tensor` shaped `[batch_size * multiplier, ...]`. + A (possibly nested structure of) `Tensor` shaped + `[batch_size * multiplier, ...]`. Raises: - ValueError: if `t` does not have a statically known rank or it's < 1. + ValueError: if tensor(s) `t` do not have a statically known rank or + the rank is < 1. """ - with ops.name_scope(name, "tile_batch", [t, multiplier]): - t = ops.convert_to_tensor(t, name="t") - shape_t = array_ops.shape(t) - if t.shape.ndims is None or t.shape.ndims < 1: - raise ValueError("t must have statically known rank") - tiling = [1] * (t.shape.ndims + 1) - tiling[1] = multiplier - tiled_static_batch_size = ( - t.shape[0].value * multiplier if t.shape[0].value is not None else None) - tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling) - tiled = array_ops.reshape( - tiled, array_ops.concat(([shape_t[0] * multiplier], shape_t[1:]), 0)) - tiled.set_shape( - tensor_shape.TensorShape( - [tiled_static_batch_size]).concatenate(t.shape[1:])) - return tiled + flat_t = nest.flatten(t) + with ops.name_scope(name, "tile_batch", flat_t + [multiplier]): + return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t) def _check_maybe(t): diff --git a/tensorflow/contrib/session_bundle/session_bundle_test.cc b/tensorflow/contrib/session_bundle/session_bundle_test.cc index ad6264d5c8a..eb36d79e0f4 100644 --- a/tensorflow/contrib/session_bundle/session_bundle_test.cc +++ b/tensorflow/contrib/session_bundle/session_bundle_test.cc @@ -270,7 +270,7 @@ class SessionBundleTest : public ::testing::Test { // MetaGraphDef. // Returns the path of the export. // ** Should only be called once per test ** - string SetupExport(MetaGraphDefTwiddler twiddler) { + string SetupExport(const MetaGraphDefTwiddler& twiddler) { return SetupExport(twiddler, kVariablesFilename, kMetaGraphDefFilename); } // SetupExport that allows for the variables and meta_graph_def filenames diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 76ed53dc33b..1cc712c2e1a 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -62,6 +62,7 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", + "full_path", "if_android", "if_ios", "if_x86", diff --git a/tensorflow/core/common_runtime/device.cc b/tensorflow/core/common_runtime/device.cc index aa8a2d989bf..8fc64fff69a 100644 --- a/tensorflow/core/common_runtime/device.cc +++ b/tensorflow/core/common_runtime/device.cc @@ -30,7 +30,11 @@ Device::Device(Env* env, const DeviceAttributes& device_attributes) rmgr_ = new ResourceMgr(parsed_name_.job); } -Device::~Device() { delete rmgr_; } +Device::~Device() { + if (rmgr_ != nullptr) { + DeleteResourceMgr(); + } +} // static DeviceAttributes Device::BuildDeviceAttributes( diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h index c0e58f143e3..7312226f388 100644 --- a/tensorflow/core/common_runtime/device.h +++ b/tensorflow/core/common_runtime/device.h @@ -60,7 +60,9 @@ class Device : public DeviceBase { const string& name() const { return device_attributes_.name(); } // Parsed name of this device - const DeviceNameUtils::ParsedName parsed_name() const { return parsed_name_; } + const DeviceNameUtils::ParsedName& parsed_name() const { + return parsed_name_; + } // Describes what kind of device this is. This is intended to be // human-readable and not computer-parsed, except that two devices @@ -149,6 +151,12 @@ class Device : public DeviceBase { return BuildDeviceAttributes(name, device, memory_limit, locality, ""); } + protected: + void DeleteResourceMgr() { + delete rmgr_; + rmgr_ = nullptr; + } + private: const DeviceAttributes device_attributes_; DeviceNameUtils::ParsedName parsed_name_; diff --git a/tensorflow/core/common_runtime/device_set.cc b/tensorflow/core/common_runtime/device_set.cc index 0ed9470655b..493349176ea 100644 --- a/tensorflow/core/common_runtime/device_set.cc +++ b/tensorflow/core/common_runtime/device_set.cc @@ -53,7 +53,7 @@ Device* DeviceSet::FindDeviceByName(const string& name) const { // static int DeviceSet::DeviceTypeOrder(const DeviceType& d) { - return DeviceFactory::DevicePriority(d.type()); + return DeviceFactory::DevicePriority(d.type_string()); } static bool DeviceTypeComparator(const DeviceType& a, const DeviceType& b) { diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 93bd3a6adbe..6e0f312bc04 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -1231,7 +1231,7 @@ Status FunctionDefToBodyHelper( GraphConstructorOptions opts; opts.allow_internal_ops = true; opts.expect_device_spec = false; - Status s = ConvertGraphDefToGraph(opts, result.gdef, graph); + Status s = ConvertNodeDefsToGraph(opts, result.nodes, graph); if (!s.ok()) { delete graph; } else { diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index e27fc3898dc..dec6ca996aa 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -93,7 +93,7 @@ class FunctionTest : public ::testing::Test { GraphConstructorOptions opts; opts.allow_internal_ops = true; opts.expect_device_spec = false; - TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g)); + TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g)); const int version = g->versions().producer(); LocalExecutorParams params; @@ -949,7 +949,7 @@ GraphDef Optimize(const std::function& pass, GraphConstructorOptions opts; opts.allow_internal_ops = true; opts.expect_device_spec = false; - TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g.get())); + TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g.get())); pass(g.get()); std::unique_ptr g1(new Graph(OpRegistry::Global())); CopyGraph(*g, g1.get()); diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc index 9bc86ef6ef8..1c4aaa5f748 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc @@ -324,6 +324,7 @@ static void BM_AllocationDelayed(int iters, int delay) { int size_index = 0; std::vector ptrs; + ptrs.reserve(delay); for (int i = 0; i < delay; i++) { ptrs.push_back(nullptr); } diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc index 4e14e6fe1a6..7b5cc1c5cba 100644 --- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc +++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc @@ -123,10 +123,12 @@ void Benchmark::RunWithArgs( } // Gets inputs' and outputs' rendezvous keys. std::vector> in; + in.reserve(inputs.size()); for (const auto& p : inputs) { in.push_back({GetRendezvousKey(p.first), p.second}); } std::vector out; + out.reserve(outputs.size()); for (const auto& n : outputs) { out.push_back(GetRendezvousKey(n)); } diff --git a/tensorflow/core/common_runtime/session_factory.cc b/tensorflow/core/common_runtime/session_factory.cc index 2e81811b7c2..dba7a9253e9 100644 --- a/tensorflow/core/common_runtime/session_factory.cc +++ b/tensorflow/core/common_runtime/session_factory.cc @@ -94,6 +94,7 @@ Status SessionFactory::GetFactory(const SessionOptions& options, // TODO(mrry): Consider providing a system-default fallback option // in this case. std::vector factory_types; + factory_types.reserve(candidate_factories.size()); for (const auto& candidate_factory : candidate_factories) { factory_types.push_back(candidate_factory.first); } diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc index 986b657e0e9..466b779e9b0 100644 --- a/tensorflow/core/common_runtime/shape_refiner_test.cc +++ b/tensorflow/core/common_runtime/shape_refiner_test.cc @@ -259,6 +259,7 @@ REGISTER_OP("ShapeData") } std::vector dims; + dims.reserve(shape_data->NumElements()); for (int i = 0; i < shape_data->NumElements(); ++i) { dims.emplace_back(c->MakeDim(shape_data->flat()(i))); } diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc index 601ed174af2..6b7c47f8fe5 100644 --- a/tensorflow/core/common_runtime/simple_placer.cc +++ b/tensorflow/core/common_runtime/simple_placer.cc @@ -69,36 +69,6 @@ std::vector FilterSupportedDevices( return filtered_devices; } -// Returns the name of the colocation group of the node by inspecting -// the kColocationAttrName attribute of the NodeDef. -void ColocationGroups(const Node& node, - std::vector* colocation_groups) { - std::vector class_specs; - // TODO(vrv): We should consider adding a GetNodeAttr that returns a - // StringPiece, to avoid a copy. - if (!GetNodeAttrSimple(node.attrs(), kColocationAttrNameStringPiece, - &class_specs)) { - // No attribute value is equivalent to the empty colocation_group. - *colocation_groups = { - strings::StrCat(kColocationGroupPrefixStringPiece, node.name())}; - return; - } - - bool found_spec = false; - for (const string& class_spec : class_specs) { - StringPiece spec(class_spec); - if (spec.Consume(kColocationGroupPrefixStringPiece)) { - found_spec = true; - colocation_groups->emplace_back(class_spec); - } - } - - if (!found_spec) { - *colocation_groups = { - strings::StrCat(kColocationGroupPrefixStringPiece, node.name())}; - } -} - // This class maintains the connected components of a colocation // constraint graph, and uses this information to assign a satisfying // device placement to the nodes of the graph. @@ -130,51 +100,96 @@ void ColocationGroups(const Node& node, class ColocationGraph { public: ColocationGraph(Graph* graph, const DeviceSet* device_set, - const SessionOptions* options) - : device_set_(device_set), + bool allow_soft_placement) + : graph_(graph), + device_set_(device_set), device_types_(device_set->PrioritizedDeviceTypeList()), - options_(options) { - members_.reserve(graph->num_node_ids()); + allow_soft_placement_(allow_soft_placement) { + members_.resize(graph->num_node_ids()); } - // Adds the given node to this ColocationGraph as a singleton. + // Adds each node of the Graph to this ColocationGraph as a singleton. // // NOTE: The implementation assumes that the ids of nodes passed to // this method are dense and zero-based; the memory used will be linear in // the largest node ID. // NOTE: If this method returns an error, *this is left in an undefined // state. - Status AddNode(const Node& node) { - Member member; - TF_RETURN_IF_ERROR(InitializeMember(node, &member)); - CHECK_GE(member.parent, 0); - members_.resize(member.parent + 1); - members_[member.parent] = std::move(member); + Status ColocateAllNodes() { + // This maps from a colocation group identifier to the 'root' of that + // colocation group. Note that the keys in this map are StringPiece; the + // actual strings are stored under the NodeDef. The lifetime of this map + // is limited to this ColocateAllNodes() method, and no part of the + // NodeDef trees are changed during the lifetime of this method, so using + // StringPiece as a key is safe. + // + // Also, as a further optimization, we remove the "loc:@" prefix from + // "class" attribute values, when they are used as keys in this table. + // This allows us to use StringPiece values that refer to substrings of + // 'string' values stored in NodeDef attribute lists, as well as StringPiece + // values that refer to 'string' values from NodeDef::name(), without + // performing any string allocations. + std::unordered_map + colocation_group_root; - // When adding the node, identify whether it is part of a - // colocation group. - std::vector colocation_groups; - ColocationGroups(node, &colocation_groups); - Status s; - for (const string& colocation_group : colocation_groups) { - auto it = colocation_group_root_.find(colocation_group); - if (it == colocation_group_root_.end()) { - // This is the first node of the colocation group, so - // designate this node as the 'root' of that colocation group. - colocation_group_root_[colocation_group] = &node; - } else { - // Try to colocate the node with the root. If there is an - // error, return it. - s = ColocateNodes(node, *(it->second)); - if (!s.ok()) { - return s; + for (Node* node : graph_->nodes()) { + if (!node->IsOp()) { + continue; + } + + // When adding the node, identify whether it is part of a + // colocation group. + + // This code is effectively the equivalent of GetNodeAttr() for a string + // array, but it avoids all internal allocations (the allocation of the + // backing store of the std::vector as well as the copies of the + // strings within it). Instead, we combine the query of the colocation + // attribute with the calls to ColocateNodeToGroup. + bool found_spec = false; + const AttrValue* attr_value = + AttrSlice(node->def()).Find(kColocationAttrNameStringPiece); + if (attr_value != nullptr && attr_value->has_list()) { + for (const string& class_spec : attr_value->list().s()) { + StringPiece spec(class_spec); + if (spec.Consume(kColocationGroupPrefixStringPiece)) { + found_spec = true; + TF_RETURN_IF_ERROR( + ColocateNodeToGroup(&colocation_group_root, node, spec)); + } } } + + if (!found_spec) { + // If the node does not specify a colocation group, then use the + // name of this node as the colocation group. + TF_RETURN_IF_ERROR( + ColocateNodeToGroup(&colocation_group_root, node, node->name())); + } } return Status::OK(); } + Status ColocateNodeToGroup( + std::unordered_map* + colocation_group_root, + Node* node, StringPiece colocation_group) { + const Node*& root_node = (*colocation_group_root)[colocation_group]; + if (root_node == nullptr) { + // This is the first node of the colocation group, so + // designate this node as the 'root' of that colocation group. + root_node = node; + } else { + // Try to colocate the node with the root. If there is an + // error, return it. + Status s = ColocateNodes(*node, *root_node); + if (!s.ok()) { + return AttachDef(s, node->def()); + } + } + return Status::OK(); + } + // Merge the (possibly disjoint) sets containing nodes "x" and // "y". Returns OK if the all nodes in the union of these sets can // be placed on the same device type. @@ -184,105 +199,104 @@ class ColocationGraph { Status ColocateNodes(const Node& x, const Node& y) { int x_root = FindRoot(x.id()); int y_root = FindRoot(y.id()); + return ColocateNodes(x, x_root, y, y_root); + } - Status s; - if (x_root != y_root) { - // Merge the sets by swinging the parent pointer of the smaller - // tree to point to the root of the larger tree. Together with - // path compression in ColocationGraph::FindRoot, this ensures - // that we do not experience pathological performance on graphs - // such as chains. - int new_root, old_root; - if (members_[x_root].rank < members_[y_root].rank) { - // The tree rooted at x_root is shallower, so connect it to - // y_root. The rank of y_root is unchanged because its new - // child has strictly less rank. - members_[x_root].parent = y_root; - new_root = y_root; - old_root = x_root; - } else if (members_[x_root].rank > members_[y_root].rank) { - // The tree rooted at y_root is shallower, so connect it to - // x_root. The rank of x_root is unchanged because its new - // child has strictly less rank. - members_[y_root].parent = x_root; - new_root = x_root; - old_root = y_root; - } else { - // Both trees have the same rank, so break the tie by choosing - // x_root as the new root. - members_[y_root].parent = x_root; - // Increment the rank of the tree rooted at x_root, because it - // is now strictly deeper than before. - ++members_[x_root].rank; - new_root = x_root; - old_root = y_root; - } - - // Merge the partial device specifications, and ensure that they are - // compatible. NULL options_ is treated as allowing soft placement. - // TODO(mrry): Consider enriching the error message by pointing - // out which nodes have the explicit partial device - // specifications that caused this conflict. - s = DeviceNameUtils::MergeDevNames( - &members_[new_root].device_name, members_[old_root].device_name, - options_ == nullptr || options_->config.allow_soft_placement()); - if (!s.ok()) { - return errors::InvalidArgument("Cannot colocate nodes '", x.name(), - "' and '", y.name(), ": ", - s.error_message()); - } - - // Transfer ids in the old group to the new one. - members_[new_root].ids_in_group.insert( - members_[old_root].ids_in_group.begin(), - members_[old_root].ids_in_group.end()); - members_[old_root].ids_in_group.clear(); - - // Ensure that the common root has at least one supported device - // type, by computing the intersection of - // members_[new_root].supported_device_types and - // members_[old_root].supported_device_types. - MergeSupportedDevices(&members_[new_root].supported_device_types, - members_[old_root].supported_device_types); - if (members_[new_root].supported_device_types.empty()) { - string debug_info; - AddDebugInfo(x_root, &debug_info); - AddDebugInfo(y_root, &debug_info); - return errors::InvalidArgument( - "Cannot colocate nodes '", x.name(), "' and '", y.name(), - "' because no device type supports both of those nodes and the " - "other nodes colocated with them.", - debug_info); - } + // This overload of ColocateNodes() allows a caller to provide the root node + // ids for the two nodes. For large graphs, this noticeably reduces the + // graph load time. + Status ColocateNodes(const Node& x, int x_root, const Node& y, int y_root) { + if (x_root == y_root) { + return Status::OK(); } + + DCHECK_EQ(x_root, FindRoot(x.id())); + DCHECK_EQ(y_root, FindRoot(y.id())); + + Member& x_root_member = members_[x_root]; + Member& y_root_member = members_[y_root]; + + // Merge the sets by swinging the parent pointer of the smaller + // tree to point to the root of the larger tree. Together with + // path compression in ColocationGraph::FindRoot, this ensures + // that we do not experience pathological performance on graphs + // such as chains. + int new_root, old_root; + if (x_root_member.rank < y_root_member.rank) { + // The tree rooted at x_root is shallower, so connect it to + // y_root. The rank of y_root is unchanged because its new + // child has strictly less rank. + x_root_member.parent = y_root; + new_root = y_root; + old_root = x_root; + } else if (x_root_member.rank > y_root_member.rank) { + // The tree rooted at y_root is shallower, so connect it to + // x_root. The rank of x_root is unchanged because its new + // child has strictly less rank. + y_root_member.parent = x_root; + new_root = x_root; + old_root = y_root; + } else { + // Both trees have the same rank, so break the tie by choosing + // x_root as the new root. + y_root_member.parent = x_root; + // Increment the rank of the tree rooted at x_root, because it + // is now strictly deeper than before. + ++x_root_member.rank; + new_root = x_root; + old_root = y_root; + } + + Member& new_root_member = members_[new_root]; + Member& old_root_member = members_[old_root]; + + // Merge the partial device specifications, and ensure that they are + // compatible. NULL options_ is treated as allowing soft placement. + // TODO(mrry): Consider enriching the error message by pointing + // out which nodes have the explicit partial device + // specifications that caused this conflict. + Status s = DeviceNameUtils::MergeDevNames(&new_root_member.device_name, + old_root_member.device_name, + allow_soft_placement_); + if (!s.ok()) { + return errors::InvalidArgument("Cannot colocate nodes '", x.name(), + "' and '", y.name(), ": ", + s.error_message()); + } + + // Ensure that the common root has at least one supported device + // type, by computing the intersection of + // new_root_member.supported_device_types and + // old_root_member.supported_device_types. + MergeSupportedDevices(&new_root_member.supported_device_types, + old_root_member.supported_device_types); + if (new_root_member.supported_device_types.empty()) { + return errors::InvalidArgument( + "Cannot colocate nodes '", x.name(), "' and '", y.name(), + "' because no device type supports both of those nodes and the " + "other nodes colocated with them.", + DebugInfo(x_root), DebugInfo(y_root)); + } + return Status::OK(); } - // Returns the device name associated with 'node'. - DeviceNameUtils::ParsedName DeviceForNode(const Node& node) { - int node_root = FindRoot(node.id()); - return members_[node_root].device_name; - } - - void SetDeviceForNode(Node* node, const DeviceNameUtils::ParsedName& device) { - int node_root = FindRoot(node->id()); - members_[node_root].device_name = device; - } - // For the given node, subject to the constraints previously given // to this ColocationGraph, set its assigned_device_name. Returns OK // if a satisfying device can be found, otherwise an error. - Status GetDevicesForNode(Node* node, std::vector* possible_devices) { - possible_devices->clear(); + // + // Note: This method returns a pointer to a field within members_. + // The caller must not use the returned pointer after there is any possibility + // that the members_[i].possible_devices field has been modified. + Status GetDevicesForNode(Node* node, + std::vector** possible_devices) { + *possible_devices = nullptr; const int node_root = FindRoot(node->id()); if (!members_[node_root].possible_devices.empty()) { - *possible_devices = members_[node_root].possible_devices; + *possible_devices = &members_[node_root].possible_devices; return Status::OK(); } - // String containing additional debugging info on failures. - string debug_info; - // We have not yet computed the possible devices for the // colocated node set containing 'node', so we do so now using the // constraints on the root node. @@ -304,10 +318,8 @@ class ColocationGraph { devices, members_[node_root].supported_device_types); } - // Perform soft placement if allow_soft_placement is set. options_ - // being NULL is treated as allowing soft placement. - if (devices.empty() && - (options_ == nullptr || options_->config.allow_soft_placement())) { + // Perform soft placement if allow_soft_placement_ is set. + if (devices.empty() && allow_soft_placement_) { // The soft_device_name is the same as the node's device name // without specifying the device type or ID. DeviceNameUtils::ParsedName soft_device_name = @@ -326,7 +338,7 @@ class ColocationGraph { // Return an error when a physical device that matches an explicit // device specification is not found. This ensures that we don't // assign a node to GPU when the user wanted to force it on CPU. - AddDebugInfo(node_root, &debug_info); + string debug_info = DebugInfo(node_root); DeviceNameUtils::ParsedName specified_device_name; if (DeviceNameUtils::ParseFullName(node->requested_device(), @@ -386,21 +398,32 @@ class ColocationGraph { device_set_->devices(), members_[node_root].supported_device_types); if (devices.empty()) { - AddDebugInfo(node_root, &debug_info); return errors::InvalidArgument( "Node had no OpKernel registered to support this operation: ", "Operation was ", node->type_string(), " and inputs were ", - DataTypeVectorString(node->input_types()), debug_info); + DataTypeVectorString(node->input_types()), DebugInfo(node_root)); } } // Cache the result of the possible devices for this node group. - members_[node_root].possible_devices = devices; - *possible_devices = members_[node_root].possible_devices; + members_[node_root].possible_devices = std::move(devices); + *possible_devices = &members_[node_root].possible_devices; + return Status::OK(); + } + + Status InitializeMembers() { + for (Node* node : graph_->nodes()) { + if (!node->IsOp()) { + continue; + } + Status status = InitializeMember(*node, &members_[node->id()]); + if (!status.ok()) { + return AttachDef(status, node->def()); + } + } return Status::OK(); } - private: // Represents a node in the disjoint node set forest, and the // accumulated constraints on the device used by that node. struct Member { @@ -409,15 +432,6 @@ class ColocationGraph { // id if it is a root. parent <= 0 indicates that this member is invalid. int parent = -1; - // The set of ids that are part of the disjoint node set forest. - // - // This is only fully specified in the root of a disjoint - // node set forest. - std::set ids_in_group; - - // The type of the op for this node. - string op_type; - // A proxy for the depth of the tree that is used to prefer // connecting smaller trees to larger trees when merging disjoint // sets. @@ -438,49 +452,56 @@ class ColocationGraph { std::vector possible_devices; }; - // Adds debugging info to 'output' for the node referred to by - // 'node_root'. - void AddDebugInfo(const int node_root, string* output) { - if (members_[node_root].ids_in_group.size() > 1) { - strings::StrAppend(output, "\nColocation Debug Info:\n"); + // Returns debugging info for the node referred to by 'node_root'. + string DebugInfo(const int node_root) { + string text( + "\nColocation Debug Info:\n" + "Colocation group had the following types and devices: "); - // If this node is part of a colocation group, then we want to - // collect the mapping of ops to supported devices, so that - // the user can see why an unsatisfiable placement occurred. - strings::StrAppend( - output, "Colocation group had the following types and devices: "); + // If this node is part of a colocation group, then we want to + // collect the mapping of ops to supported devices, so that + // the user can see why an unsatisfiable placement occurred. - std::unordered_map type_to_devices; - for (const int id : members_[node_root].ids_in_group) { - const string& op_type = members_[id].op_type; - string devices_registered; - for (const auto& device_type : members_[id].supported_device_types) { - strings::StrAppend(&devices_registered, DeviceTypeString(device_type), - " "); - } + std::unordered_map type_to_devices; + int num_nodes_found = 0; - type_to_devices[op_type] = devices_registered; + for (const Node* node : graph_->nodes()) { + if (!node->IsOp()) { + continue; + } + int id = node->id(); + if (FindRoot(id) != node_root) { + continue; + } + ++num_nodes_found; + const string& op_type = node->type_string(); + string devices_registered; + for (const auto& device_type : members_[id].supported_device_types) { + strings::StrAppend(&devices_registered, DeviceTypeString(device_type), + " "); } - for (const auto& td : type_to_devices) { - strings::StrAppend(output, "\n", td.first, ": ", td.second); - } + type_to_devices[op_type] = std::move(devices_registered); } + + for (const auto& td : type_to_devices) { + strings::StrAppend(&text, "\n", td.first, ": ", td.second); + } + + if (num_nodes_found <= 1) { + text.clear(); + } + return text; } Status InitializeMember(const Node& node, Member* member) { const int id = node.id(); - member->ids_in_group.insert(id); - member->op_type = node.type_string(); - - if (id < 0) { - return errors::InvalidArgument("Node id was not positive: ", id); - } + DCHECK_GE(id, 0); member->parent = id; TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode( device_types_, node.def(), &member->supported_device_types)); - if (!node.assigned_device_name().empty()) { + if (node.has_assigned_device_name()) { // This node has already been assigned to a device, so we // respect this placement, after sanity-checking it. The // device_name and supported_device_types for this node reflect @@ -490,17 +511,16 @@ class ColocationGraph { // NOTE: Since any assignment must have been performed by // the TensorFlow runtime, we consider errors in this branch to // be INTERNAL. - if (!DeviceNameUtils::ParseFullName(node.assigned_device_name(), + const string& assigned_device_name = node.assigned_device_name(); + if (!DeviceNameUtils::ParseFullName(assigned_device_name, &member->device_name)) { return errors::Internal("Malformed assigned device '", - node.assigned_device_name(), "'"); + assigned_device_name, "'"); } - std::vector devices; const Device* assigned_device = - device_set_->FindDeviceByName(node.assigned_device_name()); + device_set_->FindDeviceByName(assigned_device_name); if (assigned_device == nullptr) { - return errors::Internal("Assigned device '", - node.assigned_device_name(), + return errors::Internal("Assigned device '", assigned_device_name, "' does not match any device"); } @@ -510,7 +530,7 @@ class ColocationGraph { } } - return errors::Internal("Assigned device '", node.assigned_device_name(), + return errors::Internal("Assigned device '", assigned_device_name, "' does not have registered OpKernel support " "for ", node.type_string()); @@ -577,32 +597,32 @@ class ColocationGraph { // Returns the root node of the disjoint tree to which the node with the // given id is connected. int FindRoot(int node_id) { - DCHECK_GE(members_[node_id].parent, 0); - if (members_[node_id].parent != node_id) { + Member& member = members_[node_id]; + + int parent = member.parent; + DCHECK_GE(parent, 0); + + if (parent != node_id) { // NOTE: Compress paths from node_id to its root, so that future // calls to FindRoot and ColocateNodes are more efficient. - members_[node_id].parent = FindRoot(members_[node_id].parent); + int root = FindRoot(parent); + if (parent != root) { + parent = root; + member.parent = root; + } } - return members_[node_id].parent; + + DCHECK_GE(parent, 0); + return parent; } + Graph* const graph_; // Not owned. std::vector members_; const DeviceSet* device_set_; // Not owned. const std::vector device_types_; - const SessionOptions* options_; // Not owned; - - // Maps from a colocation group identifier to the 'root' of that - // colocation group. - std::unordered_map colocation_group_root_; + const bool allow_soft_placement_; }; -// Returns true if the node only depends on its input's metadata -// (shape). Not necessarily a complete list. -bool IsMetadataNode(const Node* node) { - const string& node_type = node->type_string(); - return (node_type == "Size" || node_type == "Shape" || node_type == "Rank"); -} - // Returns true if the node has no inputs and produces outputs // that are consumed by a single node. // @@ -618,12 +638,14 @@ bool IsGeneratorNode(const Node* node) { SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices, const SessionOptions* options) - : graph_(graph), devices_(devices), options_(options) {} + : graph_(graph), + devices_(devices), + options_(options), + log_device_placement_(options != nullptr && + options->config.log_device_placement()) {} SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices) - : graph_(graph), devices_(devices) { - options_ = nullptr; -} + : SimplePlacer(graph, devices, nullptr) {} SimplePlacer::~SimplePlacer() {} @@ -632,91 +654,93 @@ Status SimplePlacer::Run() { return errors::FailedPrecondition("No devices are registered"); } - ColocationGraph colocation_graph(graph_, devices_, options_); - Status status; + ColocationGraph colocation_graph( + graph_, devices_, + options_ == nullptr || options_->config.allow_soft_placement()); + + TF_RETURN_IF_ERROR(colocation_graph.InitializeMembers()); // 1. First add all of the nodes. Note that steps (1) and (2) // requires two passes over the nodes because the graph (and hence // the constraints) may not be acyclic. - for (Node* node : graph_->op_nodes()) { - status = colocation_graph.AddNode(*node); - if (!status.ok()) return AttachDef(status, *node); - } + TF_RETURN_IF_ERROR(colocation_graph.ColocateAllNodes()); // 2. Enumerate the constraint edges, and use them to update the disjoint // node set. - for (Node* node : graph_->op_nodes()) { - // If `node` has an input edge with reference type, add an - // edge from the source of that edge to `node`. - for (const auto& edge : node->in_edges()) { - if (!edge->IsControlEdge() && - (IsRefType(node->input_type(edge->dst_input())) || - node->input_type(edge->dst_input()) == DT_RESOURCE)) { - // If both the source node and this node have partially - // specified a device, then 'node's device should be - // cleared: the reference edge forces 'node' to be on the - // same device as the source node. - auto source_parsed_name = colocation_graph.DeviceForNode(*edge->src()); - auto dest_parsed_name = colocation_graph.DeviceForNode(*node); - if (DeviceNameUtils::HasSomeDetails(source_parsed_name) && - DeviceNameUtils::HasSomeDetails(dest_parsed_name)) { - // Add a log saying that we are ignoring a specified device - // for 'node' if the two names were incompatible. - if (!DeviceNameUtils::AreCompatibleDevNames(source_parsed_name, - dest_parsed_name)) { - LOG(INFO) << "Ignoring device specification " - << DeviceNameUtils::ParsedNameToString( - colocation_graph.DeviceForNode(*node)) - << " for node '" << node->name() - << "' because the input edge from '" - << edge->src()->name() - << "' is a reference connection and already has a device " - "field set to " - << DeviceNameUtils::ParsedNameToString( - colocation_graph.DeviceForNode(*edge->src())); - // Make 'node' colocated with the source - colocation_graph.SetDeviceForNode(node, source_parsed_name); + // If `node` has an input edge with reference type, add an + // edge from the source of that edge to `node`. + for (const Edge* edge : graph_->edges()) { + if (edge->IsControlEdge()) { + continue; + } + Node* src = edge->src(); + Node* dst = edge->dst(); + DataType input_type = dst->input_type(edge->dst_input()); + if (input_type == DT_RESOURCE || IsRefType(input_type)) { + int src_root_id = colocation_graph.FindRoot(src->id()); + int dst_root_id = colocation_graph.FindRoot(dst->id()); + auto& src_root = colocation_graph.members_[src_root_id]; + auto& dst_root = colocation_graph.members_[dst_root_id]; + // If both the source node and this node have paritally + // specified a device, then 'node's device should be + // cleared: the reference edge forces 'node' to be on the + // same device as the source node. + const auto& source_parsed_name = src_root.device_name; + const auto& dest_parsed_name = dst_root.device_name; + if (DeviceNameUtils::HasSomeDetails(source_parsed_name) && + DeviceNameUtils::HasSomeDetails(dest_parsed_name)) { + // Add a log saying that we are ignoring a specified device + // for 'dst' if the two names were incompatible. + if (!DeviceNameUtils::AreCompatibleDevNames(source_parsed_name, + dest_parsed_name)) { + LOG(INFO) << "Ignoring device specification " + << DeviceNameUtils::ParsedNameToString(dest_parsed_name) + << " for node '" << dst->name() + << "' because the input edge from '" << src->name() + << "' is a reference connection and already has a device " + "field set to " + << DeviceNameUtils::ParsedNameToString(source_parsed_name); + + // Make 'dst' colocated with the source + dst_root.device_name = source_parsed_name; + } else { + bool source_subset_of_dest = DeviceNameUtils::IsSpecification( + source_parsed_name, dest_parsed_name); + bool dest_subset_of_source = DeviceNameUtils::IsSpecification( + dest_parsed_name, source_parsed_name); + + if (source_subset_of_dest && !dest_subset_of_source) { + src_root.device_name = dest_parsed_name; } else { - bool source_subset_of_dest = DeviceNameUtils::IsSpecification( - source_parsed_name, dest_parsed_name); - bool dest_subset_of_source = DeviceNameUtils::IsSpecification( - dest_parsed_name, source_parsed_name); - - if (source_subset_of_dest && !dest_subset_of_source) { - colocation_graph.SetDeviceForNode(edge->src(), dest_parsed_name); - } else { - colocation_graph.SetDeviceForNode(node, source_parsed_name); - } + dst_root.device_name = source_parsed_name; } } + } - status = colocation_graph.ColocateNodes(*edge->src(), *node); - if (!status.ok()) { - return AttachDef(errors::InvalidArgument( - "Nodes were connected by a " - "reference connection (requiring them to " - "be on the same device), but the two nodes " - "were assigned two different devices: ", - status.error_message()), - *node); - } + Status status = + colocation_graph.ColocateNodes(*src, src_root_id, *dst, dst_root_id); + if (!status.ok()) { + return AttachDef( + errors::InvalidArgument("Nodes were connected by a " + "reference connection (requiring them to " + "be on the same device), but the two nodes " + "were assigned two different devices: ", + status.error_message()), + dst->def()); } } } // 3. For each node, assign a device based on the constraints in the // disjoint node set. - std::vector devices; std::vector second_pass; for (Node* node : graph_->op_nodes()) { // The graph may have come pre-populated by the framework with assigned // devices (e.g., for stateful placements), so the placer should not try to // place nodes that are already placed. - if (!node->assigned_device_name().empty()) { - // Although the device is already assigned, we run this function to - // possibly log pre-assigned placements. - AssignAndLog(node->assigned_device_name(), node); + if (node->has_assigned_device_name()) { + LogDeviceAssignment(node); continue; } @@ -731,7 +755,8 @@ Status SimplePlacer::Run() { continue; } - status = colocation_graph.GetDevicesForNode(node, &devices); + std::vector* devices; + Status status = colocation_graph.GetDevicesForNode(node, &devices); if (!status.ok()) { return AttachDef( errors::InvalidArgument("Cannot assign a device for operation '", @@ -748,12 +773,12 @@ Status SimplePlacer::Run() { // given a choice of devices. Once we have a better idea of the // types of heuristics we want to use and the information needed // to perform good placement we can add an interface for this. - string assigned_device = devices[0]->name(); + int assigned_device = -1; // Heuristic B: If the node only operates on metadata, not data, // then it is desirable to place that metadata node with its // input. - if (IsMetadataNode(node)) { + if (IsMetadata(node)) { // Make sure that the input device type is in the list of supported // device types for this node. const Node* input = (*node->in_edges().begin())->src(); @@ -761,19 +786,24 @@ Status SimplePlacer::Run() { // node's assignment to the second pass, so that we handle the // case where a metadata node's input comes from a backedge // of a loop. - const string& input_device_name = input->assigned_device_name(); - if (CanAssignToDevice(input_device_name, devices)) { - assigned_device = input_device_name; + if (CanAssignToDevice(input->assigned_device_name(), *devices)) { + assigned_device = input->assigned_device_name_index(); } } + // Provide the default, if necessary. + if (assigned_device == -1) { + assigned_device = graph_->InternDeviceName((*devices)[0]->name()); + } + AssignAndLog(assigned_device, node); } // 4. Perform a second pass assignment for those nodes explicitly // skipped during the first pass. for (Node* node : second_pass) { - status = colocation_graph.GetDevicesForNode(node, &devices); + std::vector* devices; + Status status = colocation_graph.GetDevicesForNode(node, &devices); if (!status.ok()) { return AttachDef( errors::InvalidArgument("Cannot assign a device for operation '", @@ -781,25 +811,30 @@ Status SimplePlacer::Run() { *node); } - string assigned_device = devices[0]->name(); + int assigned_device = -1; // Heuristic A application. if (IsGeneratorNode(node)) { const Node* output = (*node->out_edges().begin())->dst(); - const string& output_device_name = output->assigned_device_name(); + int output_device_name = output->assigned_device_name_index(); const bool consumers_on_same_device = std::all_of( node->out_edges().begin(), node->out_edges().end(), [output_device_name](const Edge* e) { - return e->dst()->assigned_device_name() == output_device_name; + return e->dst()->assigned_device_name_index() == output_device_name; }); if (consumers_on_same_device && - CanAssignToDevice(output_device_name, devices)) { + CanAssignToDevice(output->assigned_device_name(), *devices)) { assigned_device = output_device_name; } } + // Provide the default, if necessary. + if (assigned_device == -1) { + assigned_device = graph_->InternDeviceName((*devices)[0]->name()); + } + AssignAndLog(assigned_device, node); } @@ -824,11 +859,14 @@ bool SimplePlacer::CanAssignToDevice( return false; } -void SimplePlacer::AssignAndLog(const string& assigned_device, - Node* node) const { - node->set_assigned_device_name(assigned_device); +void SimplePlacer::AssignAndLog(int assigned_device, Node* node) const { + node->set_assigned_device_name_index(assigned_device); + LogDeviceAssignment(node); +} + +void SimplePlacer::LogDeviceAssignment(const Node* node) const { // Log placement if log_device_placement is set. - if (options_ && options_->config.log_device_placement()) { + if (log_device_placement_) { printf("%s: (%s): %s\n", node->name().c_str(), node->type_string().c_str(), node->assigned_device_name().c_str()); LOG(INFO) << node->name() << ": " diff --git a/tensorflow/core/common_runtime/simple_placer.h b/tensorflow/core/common_runtime/simple_placer.h index a041e968309..9c63cef40bb 100644 --- a/tensorflow/core/common_runtime/simple_placer.h +++ b/tensorflow/core/common_runtime/simple_placer.h @@ -86,11 +86,13 @@ class SimplePlacer { // Assigns 'node's devices to 'assigned_device', and logs the // placement if the SessionOptions entry in 'options_' requests it. - void AssignAndLog(const string& assigned_device, Node* node) const; + void AssignAndLog(int assigned_device, Node* node) const; + void LogDeviceAssignment(const Node* node) const; Graph* const graph_; // Not owned. const DeviceSet* const devices_; // Not owned. const SessionOptions* options_; // Not owned. + const bool log_device_placement_; TF_DISALLOW_COPY_AND_ASSIGN(SimplePlacer); }; diff --git a/tensorflow/core/common_runtime/stats_publisher_interface.cc b/tensorflow/core/common_runtime/stats_publisher_interface.cc index 408c901d170..f589140cd6f 100644 --- a/tensorflow/core/common_runtime/stats_publisher_interface.cc +++ b/tensorflow/core/common_runtime/stats_publisher_interface.cc @@ -15,29 +15,30 @@ limitations under the License. #include "tensorflow/core/common_runtime/stats_publisher_interface.h" -namespace tensorflow { +#include "tensorflow/core/framework/graph.pb.h" +namespace tensorflow { namespace { + // NoOpStatsPublisher provides an dummy/no-op implementation of // StatsPublisherInterface. class NoOpStatsPublisher : public StatsPublisherInterface { public: - NoOpStatsPublisher(){}; + NoOpStatsPublisher() = default; - void PublishStatsProto(const StepStats& step_stats) override { return; } + void PublishStatsProto(const StepStats& step_stats) override {} void PublishGraphProto( - const std::vector& graph_defs) override { - return; - } + const std::vector& graph_defs) override {} std::unique_ptr GetProfileHandler( uint64 step, int64 execution_count, const RunOptions& ropts) override { return nullptr; } - ~NoOpStatsPublisher() override {} + ~NoOpStatsPublisher() override = default; }; + } // namespace std::unique_ptr CreateNoOpStatsPublisher( diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index d5e6e293d6d..035bceb640f 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -746,6 +746,7 @@ Status ConcatShapeHelper(InferenceContext* c, int start_value_index, } // Build result of different unknown dims. std::vector dims; + dims.reserve(rank); for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim()); c->set_output(0, c->MakeShape(dims)); return Status::OK(); diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 186095201d1..9026075a2f0 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -140,7 +140,7 @@ class FunctionInstantiationHelper { FunctionInstantiationHelper(GetFunctionSignature get_function, InstantiationResult* result) : get_function_(std ::move(get_function)), result_(*result) { - result_.gdef.Clear(); + result_.nodes.clear(); } // Builds index for nodes that can be used as node's input arguments. @@ -151,15 +151,14 @@ class FunctionInstantiationHelper { TF_RETURN_IF_ERROR( ArgNumType(attr_values, arg_def, &is_type_list, &dtypes)); CHECK_GE(dtypes.size(), size_t{1}); - GraphDef* gdef = &result_.gdef; - int arg_index = gdef->node_size(); + int arg_index = result_.nodes.size(); TF_RETURN_IF_ERROR( AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes})); - // Creates dtypes.size() nodes in the gdef. + // Creates dtypes.size() nodes in the graph. for (size_t i = 0; i < dtypes.size(); ++i) { TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i), {true, arg_index, 0, false, {dtypes[i]}})); - DCHECK_EQ(arg_index, gdef->node_size()); + DCHECK_EQ(arg_index, result_.nodes.size()); string name = arg_def.name(); if (dtypes.size() > 1) { strings::StrAppend(&name, "_", i); @@ -332,13 +331,13 @@ class FunctionInstantiationHelper { // Adds the actual node inputs to the result graph by converting indexes to // the node names. void AddNodeInputs() { - for (int i = 0; i < result_.gdef.node_size(); i++) { + for (int i = 0; i < result_.nodes.size(); i++) { NodeInfo& node_info = nodes_[i]; for (const auto& p : node_info.data_inputs) { - result_.gdef.mutable_node(i)->add_input(Name(p.first, p.second)); + result_.nodes[i].add_input(Name(p.first, p.second)); } for (int index : node_info.control_inputs) { - result_.gdef.mutable_node(i)->add_input(Dep(index)); + result_.nodes[i].add_input(Dep(index)); } } } @@ -348,11 +347,10 @@ class FunctionInstantiationHelper { // node's input arguments. // // If is_func_arg is true, the name is a function's argument. In - // this case, the produced graph def has gdef.node[nid ... nid + - // dtype.size()). + // this case, the produced graph def has node[nid:nid + dtype.size()]. // // Otherwise, the name is a function body's node return value. In - // this case, the produced graph def has one node gdef.node[nid] and + // this case, the produced graph def has one node node[nid] and // the node's output index [idx ... idx + num) corresponds to the // named outputs. // @@ -398,10 +396,11 @@ class FunctionInstantiationHelper { } NodeDef* AddNode(const string& name) { - NodeDef* gnode = result_.gdef.add_node(); + result_.nodes.emplace_back(); + NodeDef* gnode = &result_.nodes.back(); gnode->set_name(name); nodes_.push_back({name, {}, {}}); - CHECK_EQ(result_.gdef.node_size(), nodes_.size()); + CHECK_EQ(result_.nodes.size(), nodes_.size()); return gnode; } @@ -429,7 +428,7 @@ class FunctionInstantiationHelper { // Control inputs (dependencies). std::vector control_inputs; }; - // nodes_[i] is the information about result_.gdef.node(i). + // nodes_[i] is the information about result_.nodes[i]. std::vector nodes_; }; @@ -545,17 +544,17 @@ string Print(const FunctionDef& fdef) { return out; } -string Print(const GraphDef& gdef) { +string Print(gtl::ArraySlice nodes) { std::vector arg; std::vector ret; std::vector body; - for (const NodeDef& n : gdef.node()) { - if (n.op() == "_Arg") { - arg.push_back(&n); - } else if (n.op() == "_Retval") { - ret.push_back(&n); + for (const NodeDef* n : nodes) { + if (n->op() == "_Arg") { + arg.push_back(n); + } else if (n->op() == "_Retval") { + ret.push_back(n); } else { - body.push_back(&n); + body.push_back(n); } } auto comp = [](const NodeDef* x, const NodeDef* y) { @@ -570,12 +569,11 @@ string Print(const GraphDef& gdef) { string out; strings::StrAppend(&out, "\n("); auto get_type = [](const NodeDef& n) { - for (auto a : n.attr()) { - if (a.first == "T") { - return DataTypeString(a.second.type()); - } + DataType dt; + if (!GetNodeAttr(n, "T", &dt).ok()) { + dt = DT_INVALID; } - return DataTypeString(DT_INVALID); + return DataTypeString(dt); }; for (size_t i = 0; i < arg.size(); ++i) { const NodeDef* n = arg[i]; @@ -663,13 +661,13 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, for (int i = 0; i < fdef.node_def_size(); ++i) { s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]), - result->gdef.node_size() + i); + result->nodes.size() + i); if (!s.ok()) { errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i))); return s; } } - // Emits one gdef.node for each fdef.node_def. + // Emits one node for each fdef.node_def. for (int i = 0; i < fdef.node_def_size(); ++i) { s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i])); if (!s.ok()) { @@ -697,7 +695,19 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, string DebugString(const FunctionDef& func_def) { return Print(func_def); } string DebugString(const GraphDef& instantiated_func_def) { - return Print(instantiated_func_def); + std::vector ptrs; + for (const NodeDef& n : instantiated_func_def.node()) { + ptrs.push_back(&n); + } + return Print(ptrs); +} + +string DebugString(gtl::ArraySlice instantiated_func_nodes) { + std::vector ptrs; + for (const NodeDef& n : instantiated_func_nodes) { + ptrs.push_back(&n); + } + return Print(ptrs); } string DebugStringWhole(const GraphDef& gdef) { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 188c3855c6e..6c2da84790c 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -200,7 +200,7 @@ typedef std::function struct InstantiationResult { DataTypeVector arg_types; DataTypeVector ret_types; - GraphDef gdef; + std::vector nodes; }; Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, GetFunctionSignature get_function, @@ -216,6 +216,7 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, // etc.) string DebugString(const FunctionDef& func_def); string DebugString(const GraphDef& instantiated_func_def); +string DebugString(gtl::ArraySlice instantiated_func_nodes); // Returns a debug string for a top level graph (the main program and // its supporting functions defined in its library). diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc index f3ad935c787..137093ab378 100644 --- a/tensorflow/core/framework/function_test.cc +++ b/tensorflow/core/framework/function_test.cc @@ -108,7 +108,7 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) { )P"; EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); - EXPECT_EQ(DebugString(result.gdef), e2); + EXPECT_EQ(DebugString(result.nodes), e2); } TEST(TFunc, ControlDep) { @@ -154,7 +154,7 @@ ControlDep(x:int32) -> (y:int32) { )P"; EXPECT_EQ(result.arg_types, DataTypeVector({DT_INT32})); EXPECT_EQ(result.ret_types, DataTypeVector({DT_INT32})); - EXPECT_EQ(DebugString(result.gdef), e2); + EXPECT_EQ(DebugString(result.nodes), e2); } REGISTER_OP("HasDefaultType") @@ -198,7 +198,7 @@ BackCompat() -> (y:float) { )P"; EXPECT_EQ(result.arg_types, DataTypeVector()); EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); - EXPECT_EQ(DebugString(result.gdef), e2); + EXPECT_EQ(DebugString(result.nodes), e2); } TEST(TFunc, NTimesT) { @@ -234,7 +234,7 @@ NTimesT(x:float, y:float) -> (z:float) { )P"; EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT})); EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); - EXPECT_EQ(DebugString(result.gdef), e2); + EXPECT_EQ(DebugString(result.nodes), e2); } // NOTE: This is the simplest Map op. It takes a f:T->U. @@ -299,7 +299,7 @@ AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) { )P"; EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT, DT_FLOAT})); EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); - EXPECT_EQ(DebugString(result.gdef), e2); + EXPECT_EQ(DebugString(result.nodes), e2); } TEST(TFunc, ControlDeps) { @@ -344,7 +344,7 @@ ControlDeps(x:float) -> () { )P"; EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); EXPECT_EQ(result.ret_types, DataTypeVector({})); - EXPECT_EQ(DebugString(result.gdef), e2); + EXPECT_EQ(DebugString(result.nodes), e2); } TEST(TFunc, XTimesTwo) { @@ -425,7 +425,7 @@ Test(i:float) -> (o:float) { )P"; EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); - EXPECT_EQ(DebugString(result.gdef), e2); + EXPECT_EQ(DebugString(result.nodes), e2); } REGISTER_OP("Cond") @@ -493,7 +493,7 @@ MySelect(x:float) -> (z:float) { )P"; EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); - EXPECT_EQ(DebugString(result.gdef), e2); + EXPECT_EQ(DebugString(result.nodes), e2); } static void HasError(const Status& s, const string& substr) { @@ -1028,7 +1028,7 @@ TEST(FunctionLibraryDefinitionTest, AddLibrary) { *proto.add_gradient() = grad; FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto); TF_EXPECT_OK(lib_def.AddLibrary(lib_def3)); -}; +} TEST(FunctionLibraryDefinitionTest, ToProto) { FunctionDefLibrary proto1; diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc index 5ddac6b1982..fe333dc9ffa 100644 --- a/tensorflow/core/framework/op.cc +++ b/tensorflow/core/framework/op.cc @@ -48,7 +48,7 @@ OpRegistry::~OpRegistry() { for (const auto& e : registry_) delete e.second; } -void OpRegistry::Register(OpRegistrationDataFactory op_data_factory) { +void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) { mutex_lock lock(mu_); if (initialized_) { TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory)); @@ -181,7 +181,7 @@ Status OpRegistry::CallDeferred() const { } Status OpRegistry::RegisterAlreadyLocked( - OpRegistrationDataFactory op_data_factory) const { + const OpRegistrationDataFactory& op_data_factory) const { std::unique_ptr op_reg_data(new OpRegistrationData); Status s = op_data_factory(op_reg_data.get()); if (s.ok()) { diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h index 892ed9b60b4..c5a0983a547 100644 --- a/tensorflow/core/framework/op.h +++ b/tensorflow/core/framework/op.h @@ -70,7 +70,7 @@ class OpRegistry : public OpRegistryInterface { OpRegistry(); ~OpRegistry() override; - void Register(OpRegistrationDataFactory op_data_factory); + void Register(const OpRegistrationDataFactory& op_data_factory); Status LookUp(const string& op_type_name, const OpRegistrationData** op_reg_data) const override; @@ -138,8 +138,8 @@ class OpRegistry : public OpRegistryInterface { // Add 'def' to the registry with additional data 'data'. On failure, or if // there is already an OpDef with that name registered, returns a non-okay // status. - Status RegisterAlreadyLocked(OpRegistrationDataFactory op_data_factory) const - EXCLUSIVE_LOCKS_REQUIRED(mu_); + Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory) + const EXCLUSIVE_LOCKS_REQUIRED(mu_); mutable mutex mu_; // Functions in deferred_ may only be called with mu_ held. diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index e8e931b52e4..f87b7178449 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -456,7 +456,8 @@ class OpKernelBuilderTest : public ::testing::Test { } } - string GetKernelClassName(const string& op_type, DeviceType device_type, + string GetKernelClassName(const string& op_type, + const DeviceType& device_type, const std::vector& attrs, DataTypeSlice input_types = {}) { NodeDef def = CreateNodeDef(op_type, attrs); diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc index 7dc4a509257..6e578cdbab4 100644 --- a/tensorflow/core/framework/rendezvous.cc +++ b/tensorflow/core/framework/rendezvous.cc @@ -283,7 +283,6 @@ class LocalRendezvousImpl : public Rendezvous { } CHECK(table_.insert({key_hash, item}).second); mu_.unlock(); - return; } void StartAbort(const Status& status) override { diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc index 55860d92271..4365a861e52 100644 --- a/tensorflow/core/framework/resource_mgr.cc +++ b/tensorflow/core/framework/resource_mgr.cc @@ -124,6 +124,7 @@ string ResourceMgr::DebugString() const { } } std::vector text; + text.reserve(lines.size()); for (const Line& line : lines) { text.push_back(strings::Printf( "%-20s | %-40s | %-40s | %-s", line.container->c_str(), diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 59f6f6218e7..a4597080f01 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -569,6 +569,7 @@ Status InferenceContext::MakeShapeFromTensor(const Tensor* t, } const auto num_dims = Value(shape_dim); std::vector dims; + dims.reserve(num_dims); for (int i = 0; i < num_dims; i++) dims.push_back(UnknownDim()); return ReturnCreatedShape(dims, out); } diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc index 6f63937108c..a9c0303d4cb 100644 --- a/tensorflow/core/framework/shape_inference_test.cc +++ b/tensorflow/core/framework/shape_inference_test.cc @@ -783,6 +783,7 @@ TEST_F(ShapeInferenceTest, MakeShape) { std::vector dims; auto in0 = c.input(0); const int rank = c.Rank(in0); + dims.reserve(rank); for (int i = 0; i < rank; ++i) { dims.push_back(c.Dim(in0, rank - i - 1)); } diff --git a/tensorflow/core/framework/shape_inference_testutil_test.cc b/tensorflow/core/framework/shape_inference_testutil_test.cc index de14c071b46..20a6807064b 100644 --- a/tensorflow/core/framework/shape_inference_testutil_test.cc +++ b/tensorflow/core/framework/shape_inference_testutil_test.cc @@ -51,6 +51,7 @@ string RunInferShapes(const string& op_name, const string& ins, ShapeInferenceTestOp op(op_name); const int num_inputs = 1 + std::count(ins.begin(), ins.end(), ';'); std::vector src_list; + src_list.reserve(num_inputs); for (int i = 0; i < num_inputs; ++i) src_list.emplace_back("a", 0, DT_FLOAT); NodeDef node_def; TF_CHECK_OK(NodeDefBuilder("dummy", op_name) diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc index fdab7daf76a..369f64e9e2d 100644 --- a/tensorflow/core/framework/tensor_test.cc +++ b/tensorflow/core/framework/tensor_test.cc @@ -820,16 +820,14 @@ namespace { // failures to allocate. class DummyCPUAllocator : public Allocator { public: - DummyCPUAllocator() {} + DummyCPUAllocator() = default; string Name() override { return "cpu"; } void* AllocateRaw(size_t alignment, size_t num_bytes) override { return nullptr; } - void DeallocateRaw(void* ptr) override { return; } + void DeallocateRaw(void* ptr) override {} }; -} // namespace - TEST(Tensor, FailureToAllocate) { TensorShape shape({1}); DummyCPUAllocator allocator; @@ -1080,4 +1078,5 @@ static void BM_CreateAndMoveCtrWithBuf(int iters) { } BENCHMARK(BM_CreateAndMoveCtrWithBuf); +} // namespace } // namespace tensorflow diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h index 0a81b1cb9f3..f562880e7cf 100644 --- a/tensorflow/core/framework/types.h +++ b/tensorflow/core/framework/types.h @@ -57,6 +57,7 @@ class DeviceType { explicit DeviceType(StringPiece type) : type_(type.data(), type.size()) {} const char* type() const { return type_.c_str(); } + const string& type_string() const { return type_; } bool operator<(const DeviceType& other) const; bool operator==(const DeviceType& other) const; diff --git a/tensorflow/core/graph/costmodel.cc b/tensorflow/core/graph/costmodel.cc index 69247a4f621..f798af85e15 100644 --- a/tensorflow/core/graph/costmodel.cc +++ b/tensorflow/core/graph/costmodel.cc @@ -476,6 +476,14 @@ static void EstimateComputationCosts(const Graph& g, CostModel* cost_model) { } // namespace void CostModel::InitFromGraph(const Graph& g) { + const int num_node_ids = g.num_node_ids(); + slot_bytes_.reserve(num_node_ids); + count_.reserve(num_node_ids); + time_.reserve(num_node_ids); + max_mem_usage_.reserve(num_node_ids); + max_exec_time_.reserve(num_node_ids); + output_port_alloc_ids_.reserve(num_node_ids); + AddNodesToCostModel(g, this); AssignSizes(g, this); EstimateComputationCosts(g, this); diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 80161ceb56b..dcb8520cf73 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -56,6 +56,9 @@ const std::unordered_map& Node::kNodeClassTable = {"GetSessionHandleV2", NC_GET_SESSION_HANDLE}, {"GetSessionTensor", NC_GET_SESSION_TENSOR}, {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR}, + {"Size", NC_METADATA}, + {"Shape", NC_METADATA}, + {"Rank", NC_METADATA}, }); #undef REF_CLASS @@ -77,7 +80,7 @@ string Node::DebugString() const { strings::StrAppend(&ret, " sink}"); } else { strings::StrAppend(&ret, " op device:"); - strings::StrAppend(&ret, "{", assigned_device_name_, "}"); + strings::StrAppend(&ret, "{", assigned_device_name(), "}"); strings::StrAppend(&ret, " def:{", SummarizeNode(*this), "}}"); } return ret; @@ -88,7 +91,7 @@ Node::Node() cost_id_(-1), class_(NC_UNINITIALIZED), props_(nullptr), - assigned_device_name_() {} + assigned_device_name_index_(0) {} Node::~Node() { if (props_) { @@ -124,7 +127,7 @@ void Node::Clear() { props_ = nullptr; } - assigned_device_name_.clear(); + assigned_device_name_index_ = 0; } gtl::iterator_range Node::out_nodes() const { @@ -241,6 +244,10 @@ Graph::Graph(const OpRegistryInterface* ops) versions_.set_producer(TF_GRAPH_DEF_VERSION); versions_.set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER); + // Initialize the name interning table for assigned_device_name. + device_names_.push_back(""); + DCHECK_EQ(0, InternDeviceName("")); + // Source and sink have no endpoints, just control edges. NodeDef def; def.set_name("_SOURCE"); @@ -503,6 +510,7 @@ Node* Graph::AllocateNode(Node::Properties* props, const Node* cost_node) { node = free_nodes_.back(); free_nodes_.pop_back(); } + node->graph_ = this; const int id = nodes_.size(); int cost_id = cost_node ? cost_node->cost_id() : id; node->Initialize(id, cost_id, props); @@ -519,4 +527,26 @@ void Graph::ReleaseNode(Node* node) { node->Clear(); } +// Ensures that 'device_name' is present in the device name table, and returns +// the index of that device name. The index is stable, and can be used in +// calls to Node::set_assigned_device_name_index(). +int Graph::InternDeviceName(const string& device_name) { + // Special case, very common. Also, this allows us to use a single map + // lookup below, instead of two. The 'if (index_cell > 0)' test below + // relies on this check. + if (device_name.empty()) { + return 0; + } + + int& index_cell = device_names_map_[device_name]; + if (index_cell > 0) { + return index_cell; + } + + const int index = device_names_map_.size(); + index_cell = index; + device_names_.push_back(device_name); + return index; +} + } // namespace tensorflow diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index e82580f204b..8cb270170e9 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -104,10 +104,13 @@ class Node { // fully specifies a device, and satisfies def().device(). // TODO(josh11b): Move assigned_device_name outside of Node into a // NodeId->DeviceName map. - string assigned_device_name() const { return assigned_device_name_; } - void set_assigned_device_name(const string& device_name) { - assigned_device_name_ = device_name; + const string& assigned_device_name() const; + void set_assigned_device_name(const string& device_name); + bool has_assigned_device_name() const { + return assigned_device_name_index_ > 0; } + int assigned_device_name_index() const { return assigned_device_name_index_; } + void set_assigned_device_name_index(int index); // Read only access to attributes AttrSlice attrs() const { return AttrSlice(def()); } @@ -155,6 +158,8 @@ class Node { bool IsHostSend() const { return class_ == NC_HOST_SEND; } bool IsHostRecv() const { return class_ == NC_HOST_RECV; } + bool IsMetadata() const { return class_ == NC_METADATA; } + template void AddAttr(const string& name, const T& val) { MaybeCopyOnWrite(); @@ -232,6 +237,7 @@ class Node { NC_GET_SESSION_HANDLE, NC_GET_SESSION_TENSOR, NC_DELETE_SESSION_TENSOR, + NC_METADATA, NC_OTHER // Not a special kind of node }; @@ -248,8 +254,16 @@ class Node { Properties* props_; - // Name of device assigned to perform this computation. - string assigned_device_name_; + // Index within Graph::device_names_ of the name of device assigned + // to perform this computation. + int assigned_device_name_index_; + + // A back-pointer to the Graph that owns this node. Currently, this exists + // solely to allow Node::[set_]assigned_device_name() to work. However, if all + // callers of Node::[set_]assigned_device_name() are modified to use the + // equivalent methods defined directly on Graph, then we can remove this + // field and reclaim that memory. + Graph* graph_; TF_DISALLOW_COPY_AND_ASSIGN(Node); }; @@ -478,6 +492,26 @@ class Graph { const OpRegistryInterface* op_registry() const { return &ops_; } const FunctionLibraryDefinition& flib_def() const { return ops_; } + void CheckDeviceNameIndex(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, static_cast(device_names_.size())); + } + + int InternDeviceName(const string& device_name); + + const string& get_assigned_device_name(const Node& node) const { + return device_names_[node.assigned_device_name_index()]; + } + + void set_assigned_device_name_index(Node* node, int device_name_index) { + CheckDeviceNameIndex(device_name_index); + node->assigned_device_name_index_ = device_name_index; + } + + void set_assigned_device_name(Node* node, const string& device_name) { + node->assigned_device_name_index_ = InternDeviceName(device_name); + } + // TODO(josh11b): uint64 hash() const; private: @@ -518,6 +552,30 @@ class Graph { // For generating unique names. int name_counter_ = 0; + // In most graphs, the number of unique values used for the + // Node::assigned_device_name() property is quite small. If the graph is + // large, then this duplication of values can consume a significant amount of + // memory. Instead, we represent the same information using an interning + // table, which consists of a vector of unique strings (device_names_), as + // well a map (device_names_map_) from unique strings to indices within the + // unique string table. + // + // The InternDeviceName() method handles adding a new entry into the table, + // or locating the index of an existing entry. + // + // The fact that Node::assigned_device_name() is implemented using an + // interning table is intentionally public. This allows algorithms that + // frequently access this field to do so efficiently, especially for the case + // where the assigned_device_name of one Node is copied directly from that + // of another Node. + + // A table of the unique assigned device names. Indices do NOT correspond + // to node IDs. Index 0 is always the empty string. + std::vector device_names_; + + // Maps unique device names to indices within device_names_[i]. + std::unordered_map device_names_map_; + TF_DISALLOW_COPY_AND_ASSIGN(Graph); }; @@ -550,6 +608,10 @@ inline bool IsIdentity(const Node* node) { return node->IsIdentity(); } // Returns true iff 'n' is a control flow node. inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); } +// Returns true if the node only depends on its input's metadata +// (shape). Specifically, returns true for "Size", "Shape" and "Rank" ops. +inline bool IsMetadata(const Node* n) { return n->IsMetadata(); } + inline bool IsHostMemoryPreserving(const Node* node) { return IsIdentity(node) || IsControlFlow(node); } @@ -666,6 +728,19 @@ inline gtl::iterator_range Graph::op_nodes() const { return gtl::make_range(begin, end); } +inline void Node::set_assigned_device_name_index(int index) { + graph_->CheckDeviceNameIndex(index); + assigned_device_name_index_ = index; +} + +inline void Node::set_assigned_device_name(const string& device_name) { + graph_->set_assigned_device_name(this, device_name); +} + +inline const string& Node::assigned_device_name() const { + return graph_->get_assigned_device_name(*this); +} + } // namespace tensorflow #endif // TENSORFLOW_GRAPH_GRAPH_H_ diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 1d7eea2206f..28ebf7e8c3d 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -91,24 +91,36 @@ class GraphConstructor { bool importing; }; - static Status Construct(const Options& opts, const GraphDef* gdef, Graph* g, + typedef gtl::ArraySlice NodeDefSlice; + + // versions and library may be nullptr + static Status Construct(const Options& opts, NodeDefSlice node_defs, + const VersionDef* versions, + const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner, std::vector>* return_tensors) { - TF_RETURN_IF_ERROR(CheckVersions(gdef->versions(), TF_GRAPH_DEF_VERSION, - TF_GRAPH_DEF_VERSION_MIN_PRODUCER, - "GraphDef", "graph")); - GraphConstructor c(opts, gdef, g, refiner, return_tensors); + if (versions) { + TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION, + TF_GRAPH_DEF_VERSION_MIN_PRODUCER, + "GraphDef", "graph")); + } + GraphConstructor c(opts, node_defs, versions, library, g, refiner, + return_tensors); const Status s = c.TryImport(); if (!s.ok()) c.Undo(); return s; } private: - GraphConstructor(const Options& opts, const GraphDef* gdef, Graph* g, + GraphConstructor(const Options& opts, NodeDefSlice node_defs, + const VersionDef* versions, + const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner, std::vector>* return_tensors) : opts_(opts), - gdef_(gdef), + node_defs_(node_defs), + versions_(versions), + library_(library), g_(g), original_versions_(g->versions()), refiner_(refiner), @@ -159,7 +171,9 @@ class GraphConstructor { // From constructor const Options opts_; - const GraphDef* gdef_; + const NodeDefSlice node_defs_; + const VersionDef* versions_; + const FunctionDefLibrary* library_; Graph* g_; const VersionDef original_versions_; @@ -168,7 +182,7 @@ class GraphConstructor { // May be null. Not owned. std::vector>* return_tensors_; - // Mapping from node name to the index within gdef_ + // Mapping from node name to the index within node_defs_ struct NodeInfo { explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {} // std::unordered_map<> requires that we have a default constructor. @@ -183,18 +197,18 @@ class GraphConstructor { // Mapping from node name to the existing node in g_ std::unordered_map existing_nodes_; - // Index of NodeDefs in gdef_ with all inputs already converted. + // Index of NodeDefs in node_defs_ with all inputs already converted. std::vector ready_; - // Mapping between index within gdef_ and the number of inputs that + // Mapping between index within node_defs_ and the number of inputs that // still need to be converted. std::vector pending_count_; - // Mapping between index within gdef_ and the index within gdef_ of + // Mapping between index within node_defs_ and the index within node_defs_ of // all nodes it outputs to. std::vector> outputs_; - // Used in the conversion from gdef_ to g_ to represent the ith input + // Used in the conversion from node_defs_ to g_ to represent the ith input // of a node. struct InputInfo { explicit InputInfo(const string& node_name, Node* n, int i) @@ -205,7 +219,7 @@ class GraphConstructor { int index; }; - // Used in the conversion from gdef_ to g_ to represent an edge from + // Used in the conversion from node_defs_ to g_ to represent an edge from // the node named 'name' to node 'n'. struct EdgeInfo { explicit EdgeInfo(const string& name, int i1, Node* n, int i2) @@ -254,8 +268,8 @@ Status GraphConstructor::EnsureNoNameCollisions() { } } if (opts_.prefix.empty() && opts_.importing) { - for (int n = 0; n < gdef_->node_size(); ++n) { - const string& name = gdef_->node(n).name(); + for (const NodeDef* n : node_defs_) { + const string& name = n->name(); if (existing_nodes_.find(name) != existing_nodes_.end()) { return errors::InvalidArgument("Node '", name, "' already exists in the Graph"); @@ -312,8 +326,8 @@ Status GraphConstructor::ValidateInputMapAndControlDependencies() { Status GraphConstructor::BuildNodeIndex() { // Validate the node names and add them to gdef_nodes_. - for (int n = 0; n < gdef_->node_size(); ++n) { - const NodeDef& node_def(gdef_->node(n)); + for (int n = 0; n < node_defs_.size(); ++n) { + const NodeDef& node_def = *node_defs_[n]; if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) { return errors::InvalidArgument( "Node '", node_def.name(), @@ -351,13 +365,13 @@ Status GraphConstructor::BuildNodeIndex() { } Status GraphConstructor::InitFromEdges() { - const int num_nodes = gdef_->node_size(); + const int num_nodes = node_defs_.size(); pending_count_.reserve(num_nodes); outputs_.resize(num_nodes); // Parse the inputs for each node. for (int n = 0; n < num_nodes; ++n) { - const NodeDef& node_def(gdef_->node(n)); + const NodeDef& node_def = *node_defs_[n]; if (IsMerge(node_def)) { // for merge only wait for one non-control input. int32 num_control_edges = 0; @@ -489,13 +503,16 @@ Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) { TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def)); AddDefaultsToNodeDef(*op_def, node_def); TF_RETURN_IF_ERROR(ValidateNodeDef(*node_def, *op_def)); - TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, gdef_->versions().producer())); + if (versions_) { + TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, versions_->producer())); + } return Status::OK(); } void RemoveInputs(NodeDef* node_def, const std::vector& inputs_to_remove) { // TODO(skyewm): is there a better way to do this? std::vector inputs; + inputs.reserve(node_def->input_size()); for (int i = 0; i < node_def->input_size(); ++i) { inputs.push_back(node_def->input(i)); } @@ -607,10 +624,15 @@ void GraphConstructor::AddPrefixToNodeDef( Status GraphConstructor::Convert() { // Import functions before adding nodes, since imported nodes may refer to // functions - TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(gdef_->library())); + if (library_) { + TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(*library_)); + } std::vector inputs; int processed = 0; + + std::vector input_already_exists; + // Process the NodeDefs in topological order. // (InitFromEdges() sets this up by filling in ready_ with nodes that have no // inputs, pending_counts_ with the number of inputs for each node and @@ -622,16 +644,16 @@ Status GraphConstructor::Convert() { inputs.clear(); bool has_data_back_edge = false; - const NodeDef& original_node_def = gdef_->node(o); + const NodeDef& original_node_def = *node_defs_[o]; NodeDef imported_node_def; const NodeDef* node_def; // input_already_exists[i] is true iff the i-th input of the node we're // importing refers to a preexisting node in g_ (i.e. input[i] existed prior - // to importing gdef_). Conversely, input_already_exists[i] is false iff - // the input refers to a node in gdef_. - std::vector input_already_exists(original_node_def.input_size(), - false); + // to importing node_defs_). Conversely, input_already_exists[i] is false + // iff the input refers to a node in node_defs_. + input_already_exists.clear(); + input_already_exists.resize(original_node_def.input_size(), false); if (opts_.importing) { // TODO(ashankar): The line below means an additional copy of the NodeDef, @@ -727,8 +749,8 @@ Status GraphConstructor::Convert() { } } - if (processed < gdef_->node_size()) { - return errors::InvalidArgument(gdef_->node_size() - processed, + if (processed < node_defs_.size()) { + return errors::InvalidArgument(node_defs_.size() - processed, " nodes in a cycle"); } return Status::OK(); @@ -752,20 +774,21 @@ Status GraphConstructor::AddBackEdges() { } Status GraphConstructor::UpdateVersionDef() { + if (versions_ == nullptr) return Status::OK(); + if (!opts_.importing) { - g_->set_versions(gdef_->versions()); + g_->set_versions(*versions_); return Status::OK(); } VersionDef versions = g_->versions(); - versions.set_producer( - std::min(versions.producer(), gdef_->versions().producer())); + versions.set_producer(std::min(versions.producer(), versions_->producer())); versions.set_min_consumer( - std::max(versions.min_consumer(), gdef_->versions().min_consumer())); - if (gdef_->versions().bad_consumers_size() > 0) { + std::max(versions.min_consumer(), versions_->min_consumer())); + if (versions_->bad_consumers_size() > 0) { std::set bad(versions.bad_consumers().begin(), versions.bad_consumers().end()); - bad.insert(gdef_->versions().bad_consumers().begin(), - gdef_->versions().bad_consumers().end()); + bad.insert(versions_->bad_consumers().begin(), + versions_->bad_consumers().end()); versions.clear_bad_consumers(); for (int v : bad) { versions.add_bad_consumers(v); @@ -833,7 +856,20 @@ Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst, Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, const GraphDef& gdef, Graph* g) { ShapeRefiner refiner(gdef.versions().producer(), g->op_registry()); - return GraphConstructor::Construct(opts, &gdef, g, &refiner, nullptr); + return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(), + &gdef.library(), g, &refiner, nullptr); +} + +Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, + gtl::ArraySlice nodes, Graph* g) { + ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry()); + // TODO(irving): Copy will go away once NodeInfo exists + std::vector node_defs; + for (const auto& n : nodes) { + node_defs.push_back(&n); + } + return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g, + &refiner, nullptr); } Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, @@ -882,7 +918,9 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, refiner->set_graph_def_version( std::min(refiner->graph_def_version(), gdef.versions().producer())); - return GraphConstructor::Construct(opts, &gdef, g, refiner, return_tensors); + return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(), + &gdef.library(), g, refiner, + return_tensors); } void CopyGraph(const Graph& src, Graph* dest) { diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h index 54d38cac65c..7c34dd536cc 100644 --- a/tensorflow/core/graph/graph_constructor.h +++ b/tensorflow/core/graph/graph_constructor.h @@ -46,6 +46,12 @@ struct GraphConstructorOptions { extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, const GraphDef& gdef, Graph* g); +// Same as ConvertGraphDefToGraph, but takes just nodes. Used by function +// instantiation. +// TODO(irving): This will turn into std::vector soon. +extern Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, + gtl::ArraySlice nodes, Graph* g); + // Add the graph in GraphDef gdef into an existing Graph *g. // // On error, returns non-OK and leaves *g unmodified. diff --git a/tensorflow/core/graph/graph_def_builder.cc b/tensorflow/core/graph/graph_def_builder.cc index ec1c1b6cea2..33d2021f381 100644 --- a/tensorflow/core/graph/graph_def_builder.cc +++ b/tensorflow/core/graph/graph_def_builder.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" +#include + #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" @@ -119,7 +121,7 @@ Node* UnaryOp(const string& op_name, NodeOut input, if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name, opts.op_registry()); - node_builder.Input(input); + node_builder.Input(std::move(input)); return opts.FinalizeBuilder(&node_builder); } @@ -128,7 +130,7 @@ Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b, if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name, opts.op_registry()); - node_builder.Input(a).Input(b); + node_builder.Input(std::move(a)).Input(std::move(b)); return opts.FinalizeBuilder(&node_builder); } diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index 60363175594..f8c6895dfa1 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include "tensorflow/core/framework/memory_types.h" @@ -392,7 +393,8 @@ Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g, Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2, const string& device_name, const GraphDefBuilder::Options& bopts) { - Node* res_node = ops::BinaryOp("Switch", input1, input2, bopts); + Node* res_node = + ops::BinaryOp("Switch", std::move(input1), std::move(input2), bopts); if (bopts.HaveError()) return nullptr; res_node->set_assigned_device_name(device_name); return res_node; @@ -401,7 +403,7 @@ Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2, // A next_iteration node for control flow. Node* AddControlNext(NodeBuilder::NodeOut input, const string& device_name, const GraphDefBuilder::Options& bopts) { - Node* res_node = ops::UnaryOp("NextIteration", input, bopts); + Node* res_node = ops::UnaryOp("NextIteration", std::move(input), bopts); if (bopts.HaveError()) return nullptr; res_node->set_assigned_device_name(device_name); return res_node; diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc index ee545dbfbfa..ca49ea0ac49 100644 --- a/tensorflow/core/graph/graph_partition_test.cc +++ b/tensorflow/core/graph/graph_partition_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_partition.h" #include +#include #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" @@ -159,7 +160,7 @@ Output BoolInput(const Scope& scope) { } Output Combine(const Scope& scope, Input a, Input b) { - return ConstructOp(scope, "Combine", {a, b}); + return ConstructOp(scope, "Combine", {std::move(a), std::move(b)}); } class GraphPartitionTest : public ::testing::Test { diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc index 89784c631f0..68848ae8c84 100644 --- a/tensorflow/core/graph/graph_test.cc +++ b/tensorflow/core/graph/graph_test.cc @@ -110,6 +110,7 @@ class GraphTest : public ::testing::Test { // are readable. static std::vector Stringify(const std::vector& nodes) { std::vector result; + result.reserve(nodes.size()); for (Node* n : nodes) { result.push_back(n->DebugString()); } diff --git a/tensorflow/core/graph/optimizer_cse_test.cc b/tensorflow/core/graph/optimizer_cse_test.cc index 94250240eb7..21a63662cf2 100644 --- a/tensorflow/core/graph/optimizer_cse_test.cc +++ b/tensorflow/core/graph/optimizer_cse_test.cc @@ -86,7 +86,7 @@ class OptimizerCSETest : public ::testing::Test { str_util::Join(edges, ";")); } - string DoCSE(std::function consider_fn = nullptr) { + string DoCSE(const std::function& consider_fn = nullptr) { string before = CanonicalGraphString(&graph_); LOG(ERROR) << "Before rewrites: " << before; diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD index 1431641a8fb..fd2f2b32492 100644 --- a/tensorflow/core/grappler/clusters/BUILD +++ b/tensorflow/core/grappler/clusters/BUILD @@ -1,5 +1,7 @@ licenses(["notice"]) # Apache 2.0 +load("//tensorflow:tensorflow.bzl", "tf_cuda_library") + filegroup( name = "all_files", srcs = glob( @@ -20,7 +22,7 @@ config_setting( }, ) -cc_library( +tf_cuda_library( name = "utils", srcs = ["utils.cc"], hdrs = [ diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 5118a2530b2..d40e66cd168 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -1,5 +1,7 @@ licenses(["notice"]) # Apache 2.0 +load("//tensorflow:tensorflow.bzl", "tf_cuda_library") + filegroup( name = "all_files", srcs = glob( @@ -108,25 +110,21 @@ cc_test( ], ) -cc_library( +tf_cuda_library( name = "utils", srcs = ["utils.cc"], hdrs = ["utils.h"], - defines = if_cuda(["GOOGLE_CUDA=1"]), visibility = ["//visibility:public"], deps = [ ":op_performance_data_cc", - "//third_party/eigen3", - "//tensorflow/core/grappler/clusters:utils", "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:protos_all_cc", - ] + if_cuda([ - "//tensorflow/core:cuda", - "@local_config_cuda//cuda:cuda_headers", - ]), + "//tensorflow/core/grappler/clusters:utils", + "//third_party/eigen3", + ], ) cc_library( @@ -183,6 +181,28 @@ cc_library( ], ) +cc_test( + name = "virtual_scheduler_test", + srcs = ["virtual_scheduler_test.cc"], + deps = [ + ":graph_properties", + ":utils", + ":virtual_placer", + ":virtual_scheduler", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:utils", + "//tensorflow/core/grappler/clusters:virtual_cluster", + "//tensorflow/core/grappler/costs:cost_estimator", + ], +) + cc_library( name = "measuring_cost_estimator", srcs = ["measuring_cost_estimator.cc"], @@ -247,3 +267,18 @@ cc_library( "//tensorflow/core/grappler:grappler_item", ], ) + +cc_test( + name = "analytical_cost_estimator_test", + srcs = ["analytical_cost_estimator_test.cc"], + deps = [ + ":analytical_cost_estimator", + ":virtual_scheduler", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler/clusters:virtual_cluster", + ], +) diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc index 7a1e7fcacef..651c77ad9a1 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc @@ -97,7 +97,7 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph, node_costs.compute_time.asMicroSeconds().count()); cost_node->set_memory_time( node_costs.memory_time.asMicroSeconds().count()); - for (const auto& output : node_info.outputs) { + for (const auto& output : node_info.op_info.outputs()) { auto output_info = cost_node->add_output_info(); output_info->set_dtype(output.dtype()); auto shape = output_info->mutable_shape(); diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc new file mode 100644 index 00000000000..9e3dd38b09f --- /dev/null +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc @@ -0,0 +1,110 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/costs/virtual_scheduler.h" + +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/grappler/clusters/virtual_cluster.h" +#include "tensorflow/core/grappler/costs/analytical_cost_estimator.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { + +class AnalyticalCostEstimatorTest : public ::testing::Test { + protected: + void SetUp() override { + // Initializes cluster_ and placer_. + std::unordered_map devices; + DeviceProperties cpu_device; + cpu_device.set_type("CPU"); + cpu_device.set_num_cores(4); + cpu_device.set_frequency(2600); + cpu_device.set_bandwidth(24 * 1024 * 1024); + devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device; + DeviceProperties gpu_device; + gpu_device.set_type("GPU"); + gpu_device.set_num_cores(12); + gpu_device.set_frequency(1100); + gpu_device.set_bandwidth(180 * 1024 * 1024); + (*gpu_device.mutable_environment())["architecture"] = "6"; + devices["/job:localhost/replica:0/task:0/gpu:0"] = gpu_device; + + cluster_.reset(new VirtualCluster(devices)); + } + + GrapplerItem CreateMiniGraph() { + const int batch = 1; + const int width = 28; + const int height = 28; + const int num_channels = 1; + const int num_labels = 10; + const int kernel_size = 3; + const int conv_filters = 32; + + Scope s = Scope::NewRootScope(); + auto images = ops::RandomUniform( + s.WithOpName("image"), {batch, width, height, num_channels}, DT_FLOAT); + auto labels = ops::RandomUniform(s.WithOpName("label"), {batch, num_labels}, + DT_FLOAT); + auto w = ops::Variable( + s.WithOpName("W"), + {kernel_size, kernel_size, num_channels, conv_filters}, DT_FLOAT); + auto b = ops::Variable(s.WithOpName("B"), {conv_filters}, DT_FLOAT); + auto conv = + ops::Conv2D(s.WithOpName("conv"), images, w, {1, 1, 1, 1}, "SAME"); + auto bias = ops::Add(s.WithOpName("bias"), conv, b); + auto relu = ops::Relu(s.WithOpName("relu"), bias); + auto flat_shape = ops::Const(s.WithOpName("flat_shape"), + {batch, width * height * conv_filters}); + auto flat = ops::Reshape(s.WithOpName("flat"), relu, flat_shape); + + auto w2 = + ops::Variable(s.WithOpName("W2"), + {width * height * conv_filters, num_labels}, DT_FLOAT); + auto b2 = ops::Variable(s.WithOpName("B2"), {num_labels}, DT_FLOAT); + auto matmul = ops::MatMul(s.WithOpName("matmul"), flat, w2); + auto logits = ops::Add(s.WithOpName("logits"), matmul, b2); + auto softmax = ops::Softmax(s.WithOpName("softmax"), logits); + auto lsm = ops::Log(s.WithOpName("lsm"), softmax); + + GrapplerItem item; + item.fetch.push_back("lsm"); + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + return item; + } + + std::unique_ptr cluster_; +}; + +TEST_F(AnalyticalCostEstimatorTest, SimpleTest) { + GrapplerItem item = CreateMiniGraph(); + + AnalyticalCostEstimator estimator(cluster_.get(), true); + TF_ASSERT_OK(estimator.Initialize(item)); + + CostGraphDef cost_graph; + Costs summary; + TF_ASSERT_OK(estimator.PredictCosts(item.graph, &cost_graph, &summary)); + + EXPECT_EQ(Costs::NanoSeconds(9108), summary.execution_time); + EXPECT_FALSE(summary.inaccurate); +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 22c0c803e85..11a57921e56 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -76,32 +76,21 @@ std::pair OpLevelCostEstimator::GetDeviceInfo( const DeviceProperties& device) const { double gflops = -1; double bandwidth = -1; - if (device.bandwidth() > 0) { - bandwidth = device.bandwidth() / 1e6; - } if (device.type() == "CPU") { - DeviceProperties local_cpu; - if (device.num_cores() <= 0 || device.frequency() <= 0) { - local_cpu = GetLocalCPUInfo(); - } else { - local_cpu = device; - } - // Check if vector instructions are available, and refine performance // prediction based on this. // Frequencies are stored in MHz in the DeviceProperties. - gflops = local_cpu.num_cores() * local_cpu.frequency() * 1e-3; + gflops = device.num_cores() * device.frequency() * 1e-3; if (bandwidth < 0) { - if (local_cpu.bandwidth() > 0) { - bandwidth = local_cpu.bandwidth() / 1e6; + if (device.bandwidth() > 0) { + bandwidth = device.bandwidth() / 1e6; } else { bandwidth = 32; } } } else if (device.type() == "GPU") { - const DeviceProperties local_gpu = GetLocalGPUInfo(0); - const string architecture = local_gpu.environment().at("architecture"); + const string architecture = device.environment().at("architecture"); int cores_per_multiprocessor; if (architecture < "3") { // Fermi @@ -110,17 +99,18 @@ std::pair OpLevelCostEstimator::GetDeviceInfo( // Kepler cores_per_multiprocessor = 192; } else if (architecture < "6") { - // Maxwell + // Maxwell cores_per_multiprocessor = 128; } else { - // Pascal. + // Pascal cores_per_multiprocessor = 64; } - gflops = local_gpu.num_cores() * local_gpu.frequency() * 1e-3 * + gflops = device.num_cores() * device.frequency() * 1e-3 * cores_per_multiprocessor * kOpsPerMac; - if (bandwidth < 0) { - CHECK(local_gpu.bandwidth() > 0); - bandwidth = local_gpu.bandwidth() / 1e6; + if (device.bandwidth() > 0) { + bandwidth = device.bandwidth() / 1e6; + } else { + bandwidth = 100; } } @@ -507,14 +497,13 @@ int64 OpLevelCostEstimator::CountConv2DBackPropInputOperations( return ops; } - if (op_features.attr().find("_output_shapes") == op_features.attr().end()) { + if (op_features.outputs_size() != 1) { // Need _output_shapes for input shape. - LOG(ERROR) << "No output shape in Conv2DBackPropInput op feaure."; + LOG(ERROR) << "No output shape in Conv2DBackPropInput op."; return ops; } - const auto& input_shape = - op_features.attr().at("_output_shapes").list().shape(0); + const auto& input_shape = op_features.outputs(0).shape(); ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs( input_shape, op_features.inputs(1).shape(), op_features, found_unknown_shapes); @@ -542,14 +531,13 @@ int64 OpLevelCostEstimator::CountConv2DBackPropFilterOperations( return ops; } - if (op_features.attr().find("_output_shapes") == op_features.attr().end()) { - // Need _output_shapes for filter shape. - LOG(ERROR) << "No output shape in Conv2DBackPropFilter op feaure."; + if (op_features.outputs_size() != 1) { + // Need _output_shapes for input shape. + LOG(ERROR) << "No output shape in Conv2DBackPropFilter op."; return ops; } - const auto& filter_shape = - op_features.attr().at("_output_shapes").list().shape(0); + const auto& filter_shape = op_features.outputs(0).shape(); ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs( op_features.inputs(0).shape(), filter_shape, op_features, found_unknown_shapes); @@ -598,28 +586,19 @@ int64 OpLevelCostEstimator::CalculateOutputSize( const OpInfo& op_features, bool* found_unknown_shapes) const { int64 total_output_size = 0; // use float as default for calculations - DataType dt = DT_FLOAT; - for (const auto& item : op_features.attr()) { - VLOG(1) << "Key:" << item.first - << " Value:" << SummarizeAttrValue(item.second); - if (item.first == "_output_shapes") { - for (const auto& original_output_shape : item.second.list().shape()) { - int64 output_size = 1; - int num_dims = std::max(1, original_output_shape.dim_size()); - auto output_shape = MaybeGetMinimumShape( - original_output_shape, num_dims, found_unknown_shapes); - for (const auto& dim : output_shape.dim()) { - output_size *= dim.size(); - } - output_size *= DataTypeSize(dt); - total_output_size += output_size; - VLOG(1) << "Output Size: " << output_size - << " Total Output Size:" << total_output_size; - } - } - if (item.first == "T") { - dt = item.second.type(); + for (const auto& output : op_features.outputs()) { + DataType dt = output.dtype(); + const auto& original_output_shape = output.shape(); + int64 output_size = DataTypeSize(dt); + int num_dims = std::max(1, original_output_shape.dim_size()); + auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims, + found_unknown_shapes); + for (const auto& dim : output_shape.dim()) { + output_size *= dim.size(); } + total_output_size += output_size; + VLOG(1) << "Output Size: " << output_size + << " Total Output Size:" << total_output_size; } return total_output_size; } diff --git a/tensorflow/core/grappler/costs/op_performance_data.proto b/tensorflow/core/grappler/costs/op_performance_data.proto index 887a714c0f7..0d6b337d5a3 100644 --- a/tensorflow/core/grappler/costs/op_performance_data.proto +++ b/tensorflow/core/grappler/costs/op_performance_data.proto @@ -33,7 +33,7 @@ message OpInfo { // Custom parameters impacting the behavior of the op. map attr = 2; - // Input types, shapes and values if known. + // Input data types, shapes and values if known. message TensorProperties { DataType dtype = 1; TensorShapeProto shape = 2; @@ -41,6 +41,9 @@ message OpInfo { }; repeated TensorProperties inputs = 3; + // Optional description of the op outputs + repeated TensorProperties outputs = 5; + // Device on which the operation is run. DeviceProperties device = 4; } diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc index 3cc92b56d20..1eca141e57c 100644 --- a/tensorflow/core/grappler/costs/utils.cc +++ b/tensorflow/core/grappler/costs/utils.cc @@ -167,12 +167,16 @@ std::vector FindInputFeatures( inputs.push_back(UnknownInput()); } else { const CostGraphDef::Node* input_cost = it->second; - const CostGraphDef::Node::OutputInfo& output = - input_cost->output_info(output_index); - OpInfo::TensorProperties input; - input.set_dtype(output.dtype()); - *input.mutable_shape() = output.shape(); - inputs.push_back(input); + if (input_cost->output_info_size() == 0) { + inputs.push_back(UnknownInput()); + } else { + const CostGraphDef::Node::OutputInfo& output = + input_cost->output_info(output_index); + OpInfo::TensorProperties input; + input.set_dtype(output.dtype()); + *input.mutable_shape() = output.shape(); + inputs.push_back(input); + } } } diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 32b4b3c8bc0..8d8d246078c 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -316,13 +316,17 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const { NodeInfo node_info; node_info.name = node->name(); node_info.device_name = graph_properties_.GetDeviceName(node->name()); - node_info.outputs = graph_properties_.GetOutputProperties(node->name()); + std::vector outputs = + graph_properties_.GetOutputProperties(node->name()); auto& op_info = node_info.op_info; op_info.set_op(node->op()); *op_info.mutable_attr() = node->attr(); for (auto& input : inputs) { op_info.add_inputs()->Swap(&input); } + for (auto& output : outputs) { + op_info.add_outputs()->Swap(&output); + } op_info.mutable_device()->Swap(&device); // add some more to the node_info. return node_info; diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index 310f6cca09c..7764bdc478a 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -95,7 +95,6 @@ struct NodeInfo { OpInfo op_info; string name; string device_name; - std::vector outputs; }; // The virtual scheduler emulates execution of nodes in a graph, considering diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc new file mode 100644 index 00000000000..dad2104b754 --- /dev/null +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -0,0 +1,136 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/costs/virtual_scheduler.h" + +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/grappler/clusters/virtual_cluster.h" +#include "tensorflow/core/grappler/costs/virtual_placer.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { + +class VirtualSchedulerTest : public ::testing::Test { + protected: + void SetUp() override { + // Initializes cluster_ and placer_. + std::unordered_map devices; + DeviceProperties cpu_device; + cpu_device.set_type("CPU"); + devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device; + DeviceProperties gpu_device; + gpu_device.set_type("GPU"); + devices["/job:localhost/replica:0/task:0/gpu:0"] = gpu_device; + + cluster_.reset(new VirtualCluster(devices)); + placer_.reset(new VirtualPlacer(cluster_.get())); + } + + void CreateSchedulerWithConv2Ds() { + // Create a scheduler with a simple graph: 3 Conv2Ds, where only 2 are in + // fetch nodes. + const int bs = 4; + const int width = 10; + const int height = 10; + const int depth_in = 8; + const int kernel = 3; + const int depth_out = 16; + + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto x = tensorflow::ops::RandomUniform( + s.WithOpName("x"), {bs, width, height, depth_in}, DT_FLOAT); + auto y = tensorflow::ops::RandomUniform( + s.WithOpName("y"), {bs, width, height, depth_in}, DT_FLOAT); + auto z = tensorflow::ops::RandomUniform( + s.WithOpName("z"), {bs, width, height, depth_in}, DT_FLOAT); + auto f = tensorflow::ops::RandomUniform( + s.WithOpName("f"), {kernel, kernel, depth_in, depth_out}, DT_FLOAT); + std::vector strides = {1, 1, 1, 1}; + auto c0 = + tensorflow::ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME"); + auto c1 = + tensorflow::ops::Conv2D(s.WithOpName("c1"), y, f, strides, "SAME"); + auto c2 = + tensorflow::ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME"); + GraphDef def; + TF_CHECK_OK(s.ToGraphDef(&def)); + LOG(INFO) << def.DebugString(); + + grappler_item_.reset(new GrapplerItem); + grappler_item_->id = "test_conv2d_graph"; + grappler_item_->graph = def; + grappler_item_->fetch = {"c0", "c1"}; + + scheduler_.reset(new VirtualScheduler( + grappler_item_.get(), true /* use_static_shapes */, + "CPU" /* default_device_type */, cluster_.get(), placer_.get())); + TF_CHECK_OK(scheduler_->Init()); + } + + // SetUp() inits cluster_ and placer_. + std::unique_ptr cluster_; + std::unique_ptr placer_; + + // grappler_item_ and scheduler_ will be initialized differently for each test + // case + std::unique_ptr grappler_item_; + std::unique_ptr scheduler_; +}; + +TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) { + CreateSchedulerWithConv2Ds(); // init scheduler_. + + Costs zero_costs = Costs::ZeroCosts(); + std::unordered_map ops_executed; + do { + NodeInfo node_info = scheduler_->GetCurrNodeInfo(); + ops_executed[node_info.name] = node_info; + + // Check scheduling order: x and f before c0, and y and f before c1. + if (node_info.name == "c0") { + EXPECT_GT(ops_executed.count("x"), 0); + EXPECT_GT(ops_executed.count("f"), 0); + } else if (node_info.name == "c1") { + EXPECT_GT(ops_executed.count("y"), 0); + EXPECT_GT(ops_executed.count("f"), 0); + } + } while (scheduler_->MarkCurrNodeExecuted(zero_costs)); + + // [const and rand] * (x, y, f), and c0 and c1. c2 and z shouldn't be + // executed. + EXPECT_EQ(8, ops_executed.size()); + + // x, y, f, c0, and c1 should be in the ops executed. + EXPECT_GT(ops_executed.count("x"), 0); + EXPECT_GT(ops_executed.count("y"), 0); + EXPECT_GT(ops_executed.count("f"), 0); + EXPECT_GT(ops_executed.count("c0"), 0); + EXPECT_GT(ops_executed.count("c1"), 0); + + // z and c2 shouldn't be part of it. + EXPECT_EQ(ops_executed.count("z"), 0); + EXPECT_EQ(ops_executed.count("c2"), 0); + + // Check input / output properties. + EXPECT_EQ(1, ops_executed["x"].op_info.outputs_size()); + EXPECT_EQ(1, ops_executed["y"].op_info.outputs_size()); + EXPECT_EQ(1, ops_executed["f"].op_info.outputs_size()); + EXPECT_EQ(2, ops_executed["c0"].op_info.inputs_size()); + EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size()); +} +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 9ac0303da92..1ad3cbb4cb9 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #include "tensorflow/core/grappler/grappler_item_builder.h" #include @@ -70,7 +69,8 @@ void InitializeTensor(DataType type, Tensor* tensor) { // of the cluster type (E.g: single cpu, multiple gpu, etc) being simulated in // order to get the correct session options and environment, and performing the // correct optimizations. -Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def) { +Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def, + const ItemConfig& cfg) { // Create a session option for a single GPU device. SessionOptions options; @@ -94,7 +94,12 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def) { // Optimizer options: L1 and inlining. L1 is default. OptimizerOptions* optimizer_opts = options.config.mutable_graph_options()->mutable_optimizer_options(); - optimizer_opts->set_do_function_inlining(true); + if (cfg.apply_optimizations) { + optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions_Level_L1); + } else { + optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions_Level_L0); + } + optimizer_opts->set_do_function_inlining(cfg.inline_functions); // Create the function library runtime. std::unique_ptr flib(NewFunctionLibraryRuntime( @@ -129,16 +134,6 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( new_item->id = id; new_item->graph = meta_graph.graph_def(); - // Optimize the graph (function inlining, l1 optimizations, etc). - if (cfg.apply_optimizations) { - Status optimize_status = - OptimizeGraph(meta_graph.graph_def(), &new_item->graph); - if (!optimize_status.ok()) { - LOG(ERROR) << "Function optimization failed: " << optimize_status; - return nullptr; - } - } - // Attempt to detect the fetch node(s). if (meta_graph.collection_def().count("train_op") > 0) { const CollectionDef& nodes = meta_graph.collection_def().at("train_op"); @@ -180,13 +175,17 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( // from it. We do this because in newer protos, the input placeholder // shape is not empty if the shape is partially defined. TensorShape shape; + TensorShapeProto shape_proto; std::vector dims; for (const auto& dim_proto : node.attr().at("shape").shape().dim()) { if (cfg.placeholder_unknown_output_shape_dim >= 0 && dim_proto.size() == -1) { dims.push_back(cfg.placeholder_unknown_output_shape_dim); + shape_proto.add_dim()->set_size( + cfg.placeholder_unknown_output_shape_dim); } else { dims.push_back(dim_proto.size()); + shape_proto.add_dim()->set_size(dim_proto.size()); } } Status make_shape_status = @@ -211,6 +210,7 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( (shape.dims() == 0) && (node.attr().count("_output_shapes") == 1) && (node.attr().at("_output_shapes").list().shape(0).dim_size() != 0)) { shape.Clear(); + shape_proto.clear_dim(); for (int dim_i = 0; dim_i < node.attr().at("_output_shapes").list().shape(0).dim_size(); @@ -219,21 +219,33 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( node.attr().at("_output_shapes").list().shape(0).dim(dim_i); if (dim.size() == -1) { shape.AddDim(cfg.placeholder_unknown_output_shape_dim); + shape_proto.add_dim()->set_size( + cfg.placeholder_unknown_output_shape_dim); } else { - shape.AddDim(node.attr() - .at("_output_shapes") - .list() - .shape(0) - .dim(dim_i) - .size()); + int size = node.attr() + .at("_output_shapes") + .list() + .shape(0) + .dim(dim_i) + .size(); + shape.AddDim(size); + shape_proto.add_dim()->set_size(size); } } } Tensor fake_input(type, shape); InitializeTensor(type, &fake_input); new_item->feed.emplace_back(node.name(), fake_input); + // Set the shape of the node in the graph. This is needed for statically + // inferring shapes and is a no-op when dynamically inferring shapes as + // the Placeholder shape will match the shape passed from new_item->feed. + *(node.mutable_attr()->at("shape").mutable_shape()) = shape_proto; } + // Erase the recorded result of any previous shape inference to start again + // from scratch. + node.mutable_attr()->erase("_output_shapes"); + // Delete user specified placement if requested. if (cfg.ignore_user_placement) { node.clear_device(); @@ -313,6 +325,14 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( } } + // Optimize the graph (function inlining, l1 optimizations, etc). + Status optimize_status = + OptimizeGraph(new_item->graph, &new_item->graph, cfg); + if (!optimize_status.ok()) { + LOG(ERROR) << "Function optimization failed: " << optimize_status; + return nullptr; + } + return new_item; } diff --git a/tensorflow/core/grappler/grappler_item_builder.h b/tensorflow/core/grappler/grappler_item_builder.h index 62be8dfe14f..3aa1d2027f5 100644 --- a/tensorflow/core/grappler/grappler_item_builder.h +++ b/tensorflow/core/grappler/grappler_item_builder.h @@ -31,7 +31,8 @@ struct ItemConfig { : ignore_user_placement(true), ignore_colocation(true), placeholder_unknown_output_shape_dim(-1), - apply_optimizations(true) {} + apply_optimizations(true), + inline_functions(true) {} // If true, ignore all user specified node placement. bool ignore_user_placement; @@ -40,8 +41,10 @@ struct ItemConfig { // Dimension to use if a placeholder node has an _output_shapes attribute with // a dimension of -1. int placeholder_unknown_output_shape_dim; - // If true, does inlining and L1 optimizations. + // If true, does L1 optimizations. bool apply_optimizations; + // If true, does inlining. + bool inline_functions; }; // Factory method for creating a GrapplerItem from a MetaGraphDef. diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc index 54400f7051c..92225ffb1b4 100644 --- a/tensorflow/core/grappler/grappler_item_builder_test.cc +++ b/tensorflow/core/grappler/grappler_item_builder_test.cc @@ -70,6 +70,7 @@ std::unique_ptr CreateGrapplerItem(const GraphDef &def, const CollectionDef &fetches) { MetaGraphDef meta_def; ItemConfig cfg; + cfg.inline_functions = true; *meta_def.mutable_graph_def() = def; (*meta_def.mutable_collection_def())["train_op"] = fetches; return GrapplerItemFromMetaGraphDef("0", meta_def, cfg); diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 5c2438e258e..ebe380070de 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -29,10 +29,10 @@ bool IsConstant(const NodeDef& node) { } bool IsDequeueOp(const NodeDef& node) { - static const std::set dequeue_ops = { - "QueueDequeueManyV2", "QueueDequeueMany", "QueueDequeueV2", - "QueueDequeue"}; - return dequeue_ops.count(node.op()) > 0; + const auto& op = node.op(); + return op == "QueueDequeueManyV2" || op == "QueueDequeueMany" || + op == "QueueDequeueV2" || op == "QueueDequeue" || + op == "QueueDequeueUpToV2" || op == "QueueDequeueUpTo"; } bool IsMerge(const NodeDef& node) { @@ -46,6 +46,12 @@ bool IsPlaceholder(const NodeDef& node) { op == "PlaceholderWithDefault"; } +bool IsReduction(const NodeDef& node) { + const auto& op = node.op(); + return op == "Sum" || op == "Prod" || op == "Min" || op == "Max" || + op == "Mean" || op == "Any" || op == "All"; +} + bool IsTranspose(const NodeDef& node) { const auto op = node.op(); return op == "Transpose"; diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 04bb78149f7..d32487c1286 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -26,6 +26,7 @@ bool IsConstant(const NodeDef& node); bool IsDequeueOp(const NodeDef& node); bool IsMerge(const NodeDef& node); bool IsPlaceholder(const NodeDef& node); +bool IsReduction(const NodeDef& node); bool IsTranspose(const NodeDef& node); bool IsVariable(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 291d7f35bc4..c2df76e4315 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -101,6 +101,11 @@ Status NumOutputs(const NodeDef& node, int* num_outputs) { } } // namespace +ConstantFolding::ConstantFolding() { + ops_to_preserve_ = + std::regex("Placeholder.*|Const|.*Save.*|.*Restore.*|.*Reader"); +} + Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) { GraphProperties properties(item); TF_RETURN_IF_ERROR(properties.InferStatically()); @@ -176,7 +181,7 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) { // Turn the inputs into control dependencies. CHECK_EQ(1, node.input_size()); - node.set_input(0, strings::StrCat("^", node.input(0))); + node.set_input(0, strings::StrCat("^", NodeName(node.input(0)))); } } } @@ -184,28 +189,19 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) { } bool ConstantFolding::IsFoldable(const NodeDef& node) const { - DeviceTypeVector device_types; - auto status = SupportedDeviceTypesForNode({DeviceType(DEVICE_CPU)}, node, - &device_types); - if (!status.ok()) { - return false; - } - // Only fold ops with a CPU implementation available. - if (device_types[0] != DeviceType(DEVICE_CPU)) { - return false; - } - + // Skips nodes that must be preserved, and op_types that don't benefit from + // folding if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) { return false; } - - if (ops_to_preserve_.find(node.op()) != ops_to_preserve_.end()) { + std::cmatch match; + if (std::regex_match(node.op().c_str(), match, ops_to_preserve_)) { return false; } // Don't fold stateful ops such as TruncatedNormal. const OpDef* op_def = nullptr; - status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); + Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); if (!status.ok()) { return false; } @@ -217,6 +213,17 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { return false; } + DeviceTypeVector device_types; + status = SupportedDeviceTypesForNode({DeviceType(DEVICE_CPU)}, node, + &device_types); + if (!status.ok()) { + return false; + } + // Only fold ops with a CPU implementation available. + if (device_types[0] != DeviceType(DEVICE_CPU)) { + return false; + } + // Folding not applicable to ops with no inputs. if (node.input().empty()) { return false; @@ -232,7 +239,7 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { } for (const auto& input : node.input()) { - if (input[0] == '^') { + if (IsControlInput(input)) { continue; } bool is_const = IsConstant(*node_map_->GetNode(input)); @@ -267,7 +274,7 @@ NodeDef ConstantFolding::CreateNodeDef(const string& name, Status ConstantFolding::EvaluateNode(const NodeDef& node, const TensorVector& inputs, - TensorVector* output) { + TensorVector* output) const { Status status; auto op_kernel = CreateOpKernel("CPU", device_.get(), device_->GetAllocator({}), node, @@ -299,7 +306,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node, std::vector* outputs) { TensorVector inputs; for (const auto& input : node.input()) { - if (input[0] == '^') { + if (IsControlInput(input)) { break; } TensorVector output; @@ -337,12 +344,12 @@ Status ConstantFolding::FoldNode(const NodeDef& node, GraphDef* output) { node_map_->AddNode(added_node->name(), added_node); for (const auto& input : node.input()) { - if (input[0] == '^') { + if (IsControlInput(input)) { *added_node->add_input() = input; } else { NodeDef* input_node = node_map_->GetNode(input); for (const auto& fanin_of_input : input_node->input()) { - if (fanin_of_input[0] == '^') { + if (IsControlInput(fanin_of_input)) { *added_node->add_input() = fanin_of_input; } } @@ -396,6 +403,60 @@ Status ConstantFolding::FoldGraph(GraphDef* output) { return Status::OK(); } +// Returns true iff this reduction can be reduced to an identity (i.e if the set +// of dimensions to reduce along is empty). This happens often in the gradient +// graphs. +bool ConstantFolding::IsSimplifiableReduction(const NodeDef& node) const { + if (IsReduction(node)) { + CHECK_LE(2, node.input_size()); + const NodeDef* reductions_indices = node_map_->GetNode(node.input(1)); + if (IsConstant(*reductions_indices)) { + TensorVector output; + Status s = EvaluateNode(*reductions_indices, TensorVector(), &output); + if (!s.ok()) { + return false; + } + CHECK_EQ(1, output.size()); + int output_size = output[0]->NumElements(); + delete output[0].tensor; + if (output_size == 0) { + return true; + } + } + } + return false; +} + +Status ConstantFolding::SimplifyGraph(GraphDef* output) { + for (auto& node : *output->mutable_node()) { + if (IsSimplifiableReduction(node)) { + // Replace the reduction node with an identity node, that can be further + // optimized by the model pruner. + const NodeDef* reductions_indices = node_map_->GetNode(node.input(1)); + DataType output_type; + if (node.attr().count("T") > 0) { + output_type = node.attr().at("T").type(); + } else { + // This is an 'any' or 'all' reduction. The output is always boolean. + output_type = DT_BOOL; + } + node.set_op("Identity"); + node.clear_attr(); + (*node.mutable_attr())["T"].set_type(output_type); + if (node.input_size() > 2) { + node.mutable_input()->SwapElements(1, node.input_size() - 1); + } + node.mutable_input()->RemoveLast(); + for (const auto& input : reductions_indices->input()) { + if (IsControlInput(input)) { + *node.add_input() = input; + } + } + } + } + return Status::OK(); +} + Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* output) { graph_ = item.graph; @@ -404,10 +465,14 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item, for (const auto& node : item.fetch) { nodes_to_preserve_.insert(NodeName(node)); } + for (const auto& node : item.feed) { + nodes_to_preserve_.insert(NodeName(node.first)); + } device_.reset(new DeviceSimple()); *output = GraphDef(); TF_RETURN_IF_ERROR(MaterializeShapes(item)); TF_RETURN_IF_ERROR(FoldGraph(output)); + TF_RETURN_IF_ERROR(SimplifyGraph(output)); LOG(INFO) << "Optimized graph size: " << output->node_size(); return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index fd77fc945e3..cb9729ef1ee 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_ #define TENSORFLOW_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_ +#include #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" @@ -29,7 +30,7 @@ const char kConstantFoldingConst[] = "ConstantFolding"; // Contant folding optimization for a graph. class ConstantFolding : public GraphOptimizer { public: - ConstantFolding() {} + ConstantFolding(); ~ConstantFolding() override {} @@ -50,7 +51,7 @@ class ConstantFolding : public GraphOptimizer { Status EvaluateNode(const NodeDef& node, const gtl::InlinedVector& inputs, - gtl::InlinedVector* output); + gtl::InlinedVector* output) const; Status EvaluateOneFoldable(const NodeDef& node, std::vector* outputs); @@ -59,18 +60,14 @@ class ConstantFolding : public GraphOptimizer { Status FoldGraph(GraphDef* output); + bool IsSimplifiableReduction(const NodeDef& node) const; + Status SimplifyGraph(GraphDef* output); + std::unique_ptr device_; GraphDef graph_; std::unique_ptr node_map_; std::set nodes_to_preserve_; - std::set ops_to_preserve_ = {"Save", - "SaveV2", - "SaveSlices", - "Restore", - "RestoreV2", - "RestoreSlice", - "PlaceholderWithDefault", - "Const"}; + std::regex ops_to_preserve_; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 58bbb817d0b..87e42c72e24 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -245,6 +245,47 @@ TEST_F(ConstantFoldingTest, ShapeMaterialization) { EXPECT_EQ(3, found); } +TEST_F(ConstantFoldingTest, NoOpReduction) { + // Build a simple graph with a reduction that can be reduced to the identity. + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + Output d = ops::Const(scope.WithOpName("d"), 3.14f, {3, 5, 7}); + Output v = ops::PlaceholderWithDefault(scope.WithOpName("v"), d, {3, 5, 7}); + Output c = ops::Const(scope.WithOpName("c"), 0, {0}); + Output i = ops::Identity(scope.WithOpName("i"), c); + Output p = ops::Prod(scope.WithOpName("p"), v, i); + Output s = ops::Square(scope.WithOpName("s"), p); + + GrapplerItem item; + item.fetch.push_back("s"); + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + ASSERT_EQ("c", item.graph.node(2).name()); + (*item.graph.mutable_node(2)->add_input()) = "^v"; + + ConstantFolding fold; + GraphDef output; + Status status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + auto expected = EvaluateNodes(item.graph, {"s"}); + auto optimized = EvaluateNodes(output, {"s"}); + EXPECT_EQ(1, expected.size()); + EXPECT_EQ(1, optimized.size()); + test::ExpectTensorEqual(expected[0], optimized[0]); + + bool found = false; + for (const auto& node : output.node()) { + if (node.name() == "p") { + found = true; + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("v", node.input(0)); + EXPECT_EQ("^v", node.input(1)); + } + } + EXPECT_TRUE(found); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index e37c4a5b36a..c42218e447b 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -536,6 +536,7 @@ class AddNProcessor : public AgnosticNodeProcessor { protected: std::vector GetInputPos() const override { std::vector input_pos; + input_pos.reserve(node_->input_size()); for (int i = 0; i < node_->input_size(); i++) { input_pos.push_back(i); } diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index 06ef61a9613..b7a04f4423d 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -83,6 +83,10 @@ string ParseNodeName(const string& name, int* position) { } } +bool IsControlInput(const string& name) { + return !name.empty() && name[0] == '^'; +} + string NodeName(const string& name) { int position; return ParseNodeName(name, &position); diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index 0fb531ef1bd..5a3c0614e7b 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -46,6 +46,10 @@ class NodeMap { std::unordered_map> outputs_; }; +// True iff 'name' refers to a control inputs, i.e. a node name prefixed with +// the ^ character. +bool IsControlInput(const string& name); + // Return the node name corresponding to 'name' if name is valid, or the empty // string otherwise. string NodeName(const string& name); diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index c7516beda24..1e5bc0ceab8 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -5404,6 +5404,7 @@ tf_kernel_library( srcs = ["iterator_ops.cc"], deps = [ ":dataset", + ":ops_util", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", diff --git a/tensorflow/core/kernels/adjust_contrast_op_test.cc b/tensorflow/core/kernels/adjust_contrast_op_test.cc index 06fd7ca419b..d028c0bc591 100644 --- a/tensorflow/core/kernels/adjust_contrast_op_test.cc +++ b/tensorflow/core/kernels/adjust_contrast_op_test.cc @@ -73,6 +73,7 @@ TEST_F(AdjustContrastOpTest, Big_99x99x3) { TF_EXPECT_OK(InitOp()); std::vector values; + values.reserve(99 * 99 * 3); for (int i = 0; i < 99 * 99 * 3; ++i) { values.push_back(i % 255); } diff --git a/tensorflow/core/kernels/cholesky_op.cc b/tensorflow/core/kernels/cholesky_op.cc index 5c7102f6f67..755ce7c43bd 100644 --- a/tensorflow/core/kernels/cholesky_op.cc +++ b/tensorflow/core/kernels/cholesky_op.cc @@ -64,11 +64,11 @@ class CholeskyOp : public LinearAlgebraOp { Eigen::Matrix> llt_decomposition(input); - // Output the lower triangular in a dense form. - outputs->at(0) = llt_decomposition.matrixL(); - OP_REQUIRES(context, llt_decomposition.info() == Eigen::Success, errors::InvalidArgument(kErrMsg)); + + // Output the lower triangular in a dense form. + outputs->at(0) = llt_decomposition.matrixL(); } }; diff --git a/tensorflow/core/kernels/dequantize_op_test.cc b/tensorflow/core/kernels/dequantize_op_test.cc index efce8101754..8992629d426 100644 --- a/tensorflow/core/kernels/dequantize_op_test.cc +++ b/tensorflow/core/kernels/dequantize_op_test.cc @@ -105,6 +105,7 @@ static void BM_DequantizeMinCombinedCpu(int iters) { auto root = Scope::NewRootScope().ExitOnError(); const int64 num_values = 1500 * 250; std::vector inputs; + inputs.reserve(num_values); for (int i = 0; i < num_values; ++i) inputs.push_back(i); ops::Dequantize(root, test::AsTensor(inputs), test::AsTensor({-1.5f}), diff --git a/tensorflow/core/kernels/dynamic_partition_op.cc b/tensorflow/core/kernels/dynamic_partition_op.cc index 06765d8ee3a..861e16b2fd0 100644 --- a/tensorflow/core/kernels/dynamic_partition_op.cc +++ b/tensorflow/core/kernels/dynamic_partition_op.cc @@ -104,6 +104,7 @@ class DynamicPartitionOp : public DynamicPartitionOp_Shared { const auto data_flat = data->flat(); std::vector, Eigen::Aligned> > out_vec; + out_vec.reserve(num_partitions_); for (int p = 0; p < num_partitions_; p++) { out_vec.push_back(outputs[p]->vec()); } @@ -124,6 +125,7 @@ class DynamicPartitionOp : public DynamicPartitionOp_Shared { // If data has extra dimensions, use Eigen slices std::vector, Eigen::Aligned> > out_flat; + out_flat.reserve(num_partitions_); for (int p = 0; p < num_partitions_; p++) { out_flat.push_back(outputs[p]->flat_outer_dims()); } diff --git a/tensorflow/core/kernels/fractional_max_pool_op.cc b/tensorflow/core/kernels/fractional_max_pool_op.cc index dfba8e01e4e..33d73c84776 100644 --- a/tensorflow/core/kernels/fractional_max_pool_op.cc +++ b/tensorflow/core/kernels/fractional_max_pool_op.cc @@ -245,9 +245,11 @@ class FractionalMaxPoolGradOp : public OpKernel { constexpr int tensor_in_and_out_dims = 4; std::vector input_size; std::vector output_size; + input_size.reserve(tensor_in_and_out_dims); for (int i = 0; i < tensor_in_and_out_dims; ++i) { input_size.push_back(tensor_in.dim_size(i)); } + output_size.reserve(tensor_in_and_out_dims); for (int i = 0; i < tensor_in_and_out_dims; ++i) { output_size.push_back(tensor_out.dim_size(i)); } diff --git a/tensorflow/core/kernels/gather_op_test.cc b/tensorflow/core/kernels/gather_op_test.cc index 23645dafad4..37c1462f10c 100644 --- a/tensorflow/core/kernels/gather_op_test.cc +++ b/tensorflow/core/kernels/gather_op_test.cc @@ -164,6 +164,7 @@ static Graph* Gather(int dim) { random::PhiloxRandom philox(301, 17); random::SimplePhilox rnd(&philox); std::vector indices_vec; + indices_vec.reserve(kLookups); for (int i = 0; i < kLookups; i++) { indices_vec.push_back(rnd.Uniform(kRows)); } diff --git a/tensorflow/core/kernels/iterator_ops.cc b/tensorflow/core/kernels/iterator_ops.cc index 880c6a7e824..fa3f3a4db67 100644 --- a/tensorflow/core/kernels/iterator_ops.cc +++ b/tensorflow/core/kernels/iterator_ops.cc @@ -18,7 +18,10 @@ limitations under the License. #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/resource_op_kernel.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -282,38 +285,54 @@ class OneShotIteratorOp : public OpKernel { IteratorResource* iterator_resource_ = nullptr; }; -class IteratorGetNextOp : public OpKernel { +class IteratorGetNextOp : public AsyncOpKernel { public: - explicit IteratorGetNextOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + explicit IteratorGetNextOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx), + thread_pool_(new thread::ThreadPool( + ctx->env(), ThreadOptions(), + strings::StrCat("iterator_get_next_thread_", + SanitizeThreadSuffix(def().name())), + 1 /* num_threads */, false /* low_latency_hint */)) {} - // TODO(mrry): Convert this to an async op, because - // `iterator->GetNext()` could trigger long-running operations - // (e.g. a QueueDequeue or a remote read). - void Compute(OpKernelContext* ctx) override { + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { IteratorResource* iterator; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator)); - core::ScopedUnref unref_iterator(iterator); - std::vector components; - bool end_of_sequence; + // The call to `iterator->GetNext()` may block and depend on an + // inter-op thread pool thread, so we issue the call from the + // owned thread pool. + thread_pool_->Schedule([this, ctx, iterator, done]() { + core::ScopedUnref unref_iterator(iterator); - IteratorContext::Params params; - params.env = ctx->env(); - params.step_id = ctx->step_id(); - params.resource_manager = ctx->resource_manager(); - params.runner = *(ctx->runner()); - IteratorContext iter_ctx(std::move(params)); + std::vector components; + bool end_of_sequence; - OP_REQUIRES_OK(ctx, - iterator->GetNext(&iter_ctx, &components, &end_of_sequence)); - OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence")); + IteratorContext::Params params; + params.env = ctx->env(); + params.step_id = ctx->step_id(); + params.resource_manager = ctx->resource_manager(); + params.runner = *(ctx->runner()); + IteratorContext iter_ctx(std::move(params)); - for (int i = 0; i < components.size(); ++i) { - // TODO(mrry): Check that the shapes match the shape attrs. - ctx->set_output(i, components[i]); - } + OP_REQUIRES_OK_ASYNC( + ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence), + done); + OP_REQUIRES_ASYNC(ctx, !end_of_sequence, + errors::OutOfRange("End of sequence"), done); + + for (int i = 0; i < components.size(); ++i) { + // TODO(mrry): Check that the shapes match the shape attrs. + ctx->set_output(i, components[i]); + } + + done(); + }); } + + private: + std::unique_ptr thread_pool_; }; class IteratorDisposeOp : public OpKernel { diff --git a/tensorflow/core/kernels/matrix_solve_ls_op.cc b/tensorflow/core/kernels/matrix_solve_ls_op.cc index 11e7c94faf3..381a5ec7b9d 100644 --- a/tensorflow/core/kernels/matrix_solve_ls_op.cc +++ b/tensorflow/core/kernels/matrix_solve_ls_op.cc @@ -105,18 +105,19 @@ class MatrixSolveLsOp : public LinearAlgebraOp { // using Cholesky decomposition. Matrix gramian(cols, cols); gramian.template triangularView() = - matrix.transpose() * matrix; + matrix.adjoint() * matrix; if (l2_regularizer > 0) { gramian += (Scalar(l2_regularizer) * Matrix::Ones(cols, 1)).asDiagonal(); } - const Eigen::LLT llt(gramian); + const Eigen::LLT, Eigen::Lower> llt(gramian); OP_REQUIRES( context, llt.info() == Eigen::Success, errors::InvalidArgument("Input matrix was rank deficient or " "ill-conditioned. Try setting fast=False " "or provide a larger l2_regularizer > 0.")); - outputs->at(0) = llt.solve(matrix.transpose() * rhs); + outputs->at(0).noalias() = matrix.adjoint() * rhs; + llt.solveInPlace(outputs->at(0)); } else { // Underdetermined case (rows < cols): Solves the minimum-norm problem // min ||X||_F^2 s.t. A*X = RHS @@ -125,18 +126,18 @@ class MatrixSolveLsOp : public LinearAlgebraOp { // using Cholesky decomposition. Matrix gramian(rows, rows); gramian.template triangularView() = - matrix * matrix.transpose(); + matrix * matrix.adjoint(); if (l2_regularizer > 0) { gramian += (Scalar(l2_regularizer) * Matrix::Ones(rows, 1)).asDiagonal(); } - const Eigen::LLT llt(gramian); + const Eigen::LLT, Eigen::Lower> llt(gramian); OP_REQUIRES( context, llt.info() == Eigen::Success, errors::InvalidArgument("Input matrix was rank deficient or " "ill-conditioned. Try setting fast=False " "or provide an l2_regularizer > 0.")); - outputs->at(0) = matrix.transpose() * llt.solve(rhs); + outputs->at(0).noalias() = matrix.adjoint() * llt.solve(rhs); } } else { // Use complete orthogonal decomposition which is backwards stable and diff --git a/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc b/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc index c3a7e779403..602dfeb4e54 100644 --- a/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc +++ b/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc @@ -29,6 +29,7 @@ TEST(MfccMelFilterbankTest, AgreesWithPythonGoldenValues) { std::vector input; const int kSampleCount = 513; + input.reserve(kSampleCount); for (int i = 0; i < kSampleCount; ++i) { input.push_back(i + 1); } diff --git a/tensorflow/core/kernels/mfcc_test.cc b/tensorflow/core/kernels/mfcc_test.cc index 7efecba85b1..cb32df8811e 100644 --- a/tensorflow/core/kernels/mfcc_test.cc +++ b/tensorflow/core/kernels/mfcc_test.cc @@ -26,6 +26,7 @@ TEST(MfccTest, AgreesWithPythonGoldenValues) { Mfcc mfcc; std::vector input; const int kSampleCount = 513; + input.reserve(kSampleCount); for (int i = 0; i < kSampleCount; ++i) { input.push_back(i + 1); } @@ -51,6 +52,7 @@ TEST(MfccTest, AvoidsNansWithZeroInput) { Mfcc mfcc; std::vector input; const int kSampleCount = 513; + input.reserve(kSampleCount); for (int i = 0; i < kSampleCount; ++i) { input.push_back(0.0); } diff --git a/tensorflow/core/kernels/quantization_utils_test.cc b/tensorflow/core/kernels/quantization_utils_test.cc index c547b166eee..901ea65bdc1 100644 --- a/tensorflow/core/kernels/quantization_utils_test.cc +++ b/tensorflow/core/kernels/quantization_utils_test.cc @@ -37,6 +37,7 @@ void TestRequantizeMany(Eigen::ThreadPoolDevice* eigen_device, float input_min, int tolerance = 1) { const int values_count = values_quantized.size(); std::vector expected_values; + expected_values.reserve(values_count); for (int value_index = 0; value_index < values_count; ++value_index) { expected_values.push_back(FloatToQuantized( QuantizedToFloat(values_quantized[value_index], input_min, input_max), @@ -78,6 +79,7 @@ void TestRequantizeMany8To32Bit(float input_min, float input_max, int tolerance = 256) { const int values_count = values_quantized.size(); std::vector expected_values; + expected_values.reserve(values_count); for (int value_index = 0; value_index < values_count; ++value_index) { expected_values.push_back(FloatToQuantized( QuantizedToFloat(values_quantized[value_index], input_min, input_max), diff --git a/tensorflow/core/kernels/sdca_ops_test.cc b/tensorflow/core/kernels/sdca_ops_test.cc index 400f330ce7b..ce50116a2d0 100644 --- a/tensorflow/core/kernels/sdca_ops_test.cc +++ b/tensorflow/core/kernels/sdca_ops_test.cc @@ -57,6 +57,7 @@ Node* Var(Graph* const g, const int n) { std::vector VarVector(Graph* const g, const int nodes, const int node_size) { std::vector result; + result.reserve(nodes); for (int i = 0; i < nodes; ++i) { result.push_back(Var(g, node_size)); } @@ -164,6 +165,7 @@ void GetGraphs(const int32 num_examples, const int32 num_sparse_feature_groups, sparse_weights.push_back(NodeBuilder::NodeOut(n)); } std::vector dense_weights; + dense_weights.reserve(dense_weight_nodes.size()); for (Node* n : dense_weight_nodes) { dense_weights.push_back(NodeBuilder::NodeOut(n)); } @@ -171,20 +173,24 @@ void GetGraphs(const int32 num_examples, const int32 num_sparse_feature_groups, std::vector sparse_example_indices; std::vector sparse_feature_indices; std::vector sparse_values; + sparse_example_indices.reserve(num_sparse_feature_groups); for (int i = 0; i < num_sparse_feature_groups; ++i) { sparse_example_indices.push_back(NodeBuilder::NodeOut( SparseExampleIndices(g, sparse_features_per_group, num_examples))); } + sparse_feature_indices.reserve(num_sparse_feature_groups); for (int i = 0; i < num_sparse_feature_groups; ++i) { sparse_feature_indices.push_back(NodeBuilder::NodeOut( SparseFeatureIndices(g, sparse_features_per_group, num_examples))); } + sparse_values.reserve(num_sparse_feature_groups); for (int i = 0; i < num_sparse_feature_groups; ++i) { sparse_values.push_back( NodeBuilder::NodeOut(RandomZeroOrOne(g, num_examples * 4))); } std::vector dense_features; + dense_features.reserve(num_dense_feature_groups); for (int i = 0; i < num_dense_feature_groups; ++i) { dense_features.push_back(NodeBuilder::NodeOut( RandomZeroOrOneMatrix(g, num_examples, dense_features_per_group))); diff --git a/tensorflow/core/kernels/serialize_sparse_op.cc b/tensorflow/core/kernels/serialize_sparse_op.cc index 4f73583ed80..4d04a206754 100644 --- a/tensorflow/core/kernels/serialize_sparse_op.cc +++ b/tensorflow/core/kernels/serialize_sparse_op.cc @@ -361,6 +361,7 @@ class DeserializeManySparseOp : public OpKernel { std::iota(std_order.begin(), std_order.end(), 0); std::vector tensors_to_concat; + tensors_to_concat.reserve(num_sparse_tensors); for (int i = 0; i < num_sparse_tensors; ++i) { tensors_to_concat.emplace_back(indices_to_concat[i], values_to_concat[i], preconcat_shape, std_order); diff --git a/tensorflow/core/kernels/sparse_cross_op.cc b/tensorflow/core/kernels/sparse_cross_op.cc index 2b4d5effdad..ed93caad331 100644 --- a/tensorflow/core/kernels/sparse_cross_op.cc +++ b/tensorflow/core/kernels/sparse_cross_op.cc @@ -452,6 +452,7 @@ class SparseCrossOp : public OpKernel { ExtractFeatureData(indices_list_in, batch_size, &feature_counts, &feature_start_indices); + columns.reserve(values_list_in.size()); for (int i = 0; i < values_list_in.size(); ++i) { columns.emplace_back(new SparseTensorColumn( values_list_in[i], std::move(feature_counts[i]), diff --git a/tensorflow/core/kernels/sparse_tensors_map_ops.cc b/tensorflow/core/kernels/sparse_tensors_map_ops.cc index f7b609191af..047e7c9e5d7 100644 --- a/tensorflow/core/kernels/sparse_tensors_map_ops.cc +++ b/tensorflow/core/kernels/sparse_tensors_map_ops.cc @@ -463,6 +463,7 @@ class TakeManySparseFromTensorsMapOp : public SparseTensorAccessingOp { std::iota(std_order.begin(), std_order.end(), 0); std::vector tensors_to_concat; + tensors_to_concat.reserve(N); for (int i = 0; i < N; ++i) { tensors_to_concat.emplace_back(std::move(indices_to_concat[i]), std::move(values_to_concat[i]), diff --git a/tensorflow/core/kernels/stage_op.cc b/tensorflow/core/kernels/stage_op.cc index d56941b332d..e216855c27e 100644 --- a/tensorflow/core/kernels/stage_op.cc +++ b/tensorflow/core/kernels/stage_op.cc @@ -220,7 +220,8 @@ class StageOp : public OpKernel { OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf)); core::ScopedUnref scope(buf); Buffer::Tuple tuple; - for (std::size_t i = 0; i < ctx->num_inputs(); ++i) { + tuple.reserve(ctx->num_inputs()); + for (int i = 0; i < ctx->num_inputs(); ++i) { tuple.push_back(ctx->input(i)); } OP_REQUIRES_OK(ctx, buf->Put(&tuple)); diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index b46b405ffbf..075bacb432b 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" @@ -101,7 +102,7 @@ Status SetupFlowControlInputs(OpKernelContext* ctx, bool set_output) { class TensorArrayCreationOp : public OpKernel { public: explicit TensorArrayCreationOp(OpKernelConstruction* context) - : OpKernel(context) {} + : OpKernel(context), device_type_(context->device_type()) {} void Compute(OpKernelContext* ctx) override { Tensor tensor_array_output_handle; @@ -133,6 +134,12 @@ class TensorArrayCreationOp : public OpKernel { // Create the flow output. Tensor* flow; OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &flow)); + if (device_type_ == DEVICE_CPU) { + // Value doesn't matter, but this makes msan not complaint about + // copying an uninitialized value. To do this on GPU would require + // a kernel launch or a host->device memcpy, so we avoid that. + flow->flat()(0) = 0; + } } } @@ -140,6 +147,9 @@ class TensorArrayCreationOp : public OpKernel { virtual Status CreateTensorArray(OpKernelContext* ctx, ResourceMgr* rm, Tensor* tensor_array_output_handle, TensorArray** output_tensor_array) = 0; + + private: + const DeviceType device_type_; }; // A per-run local tensor array. The tensor array uses a "per-step" resource diff --git a/tensorflow/core/lib/gtl/inlined_vector_test.cc b/tensorflow/core/lib/gtl/inlined_vector_test.cc index ef1d44fa944..2721885c4a7 100644 --- a/tensorflow/core/lib/gtl/inlined_vector_test.cc +++ b/tensorflow/core/lib/gtl/inlined_vector_test.cc @@ -778,6 +778,7 @@ BENCHMARK(BM_InlinedVectorFillRange)->Range(0, 1024); static void BM_StdVectorFill(int iters, int len) { for (int i = 0; i < iters; i++) { std::vector v; + v.reserve(len); for (int j = 0; j < len; j++) { v.push_back(j); } @@ -810,6 +811,7 @@ static void BM_StdVectorFillString(int iters, int len) { "012345678901234567", "to cause allocation"}; for (int i = 0; i < iters; i++) { std::vector v; + v.reserve(len); for (int j = 0; j < len; j++) { v.push_back(strings[j & 3]); } diff --git a/tensorflow/core/lib/gtl/optional_test.cc b/tensorflow/core/lib/gtl/optional_test.cc index bd203b9e859..547bee7b75f 100644 --- a/tensorflow/core/lib/gtl/optional_test.cc +++ b/tensorflow/core/lib/gtl/optional_test.cc @@ -1078,6 +1078,7 @@ TEST(optionalTest, NoExcept) { static_assert( !std::is_nothrow_move_constructible>::value, ""); std::vector> v; + v.reserve(10); for (int i = 0; i < 10; ++i) v.emplace_back(); } diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc index e9c313e9031..325dbc48835 100644 --- a/tensorflow/core/ops/array_grad.cc +++ b/tensorflow/core/ops/array_grad.cc @@ -248,6 +248,7 @@ Status ArrayToListGrad(const AttrSlice& attrs, FunctionDef* g) { int N; TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N", &N)); std::vector dys; + dys.reserve(N); for (int i = 0; i < N; ++i) { dys.push_back(strings::StrCat("dy:", i)); } diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 1fa5a4ed25e..85a6cfcac91 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -613,6 +613,7 @@ REGISTER_OP("Const") TF_RETURN_IF_ERROR(TensorShape::IsValidShape(proto->tensor_shape())); TensorShape shape(proto->tensor_shape()); std::vector dims; + dims.reserve(shape.dims()); for (int i = 0; i < shape.dims(); ++i) { dims.push_back(c->MakeDim(shape.dim_size(i))); } @@ -894,6 +895,7 @@ REGISTER_OP("MatrixDiagPart") } const int32 rank = c->Rank(in); std::vector dims; + dims.reserve(rank - 2); for (int i = 0; i < rank - 2; ++i) dims.push_back(c->Dim(in, i)); DimensionHandle min_dim; diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index 1be68b6000e..a7b4422bab6 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -31,6 +31,7 @@ TEST(ArrayOpsTest, Pack_ShapeFn) { auto set_axis = [&op](int axis) { int n = 3; std::vector src_list; + src_list.reserve(n); for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT); TF_ASSERT_OK(NodeDefBuilder("test", "Pack") .Input(src_list) @@ -281,6 +282,7 @@ TEST(ArrayOpsTest, ShapeN_ShapeFn) { ShapeInferenceTestOp op("ShapeN"); int n = 3; std::vector src_list; + src_list.reserve(n); for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT); TF_ASSERT_OK(NodeDefBuilder("test", "ShapeN") .Input(src_list) @@ -546,6 +548,7 @@ TEST(ArrayOpsTest, Concat_ShapeFn) { ShapeInferenceTestOp op("Concat"); auto set_n = [&op](int n) { std::vector src_list; + src_list.reserve(n); for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT); TF_ASSERT_OK(NodeDefBuilder("test", "Concat") .Input({"concat_dim", 0, DT_INT32}) @@ -619,6 +622,7 @@ TEST(ArrayOpsTest, ConcatV2_ShapeFn) { ShapeInferenceTestOp op("ConcatV2"); auto set_n = [&op](int n) { std::vector src_list; + src_list.reserve(n); for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT); TF_ASSERT_OK(NodeDefBuilder("test", "ConcatV2") .Input(src_list) @@ -695,6 +699,7 @@ TEST(ArrayOpsTest, ConcatOffset_ShapeFn) { const int n = 4; std::vector src_list; + src_list.reserve(n); for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_INT32); TF_ASSERT_OK(NodeDefBuilder("test", "ConcatOffset") .Input({"concat_dim", 0, DT_INT32}) diff --git a/tensorflow/core/ops/control_flow_ops_test.cc b/tensorflow/core/ops/control_flow_ops_test.cc index 9aa14e27a0a..b6abafc51b8 100644 --- a/tensorflow/core/ops/control_flow_ops_test.cc +++ b/tensorflow/core/ops/control_flow_ops_test.cc @@ -28,6 +28,7 @@ TEST(ControlFlowOpsTest, Merge_ShapeFn) { int n = 3; std::vector src_list; + src_list.reserve(n); for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT); TF_ASSERT_OK(NodeDefBuilder("test", "Merge") .Input(src_list) @@ -54,6 +55,7 @@ TEST(ControlFlowOpsTest, RefSelect_ShapeFn) { int n = 3; std::vector src_list; + src_list.reserve(n); for (int i = 0; i < n; ++i) src_list.emplace_back("a", 1, DT_FLOAT_REF); TF_ASSERT_OK(NodeDefBuilder("test", "RefSelect") .Input("index", 0, DT_INT32) diff --git a/tensorflow/core/ops/functional_ops_test.cc b/tensorflow/core/ops/functional_ops_test.cc index 37ee301c3bd..64b5ccea5a8 100644 --- a/tensorflow/core/ops/functional_ops_test.cc +++ b/tensorflow/core/ops/functional_ops_test.cc @@ -33,6 +33,7 @@ TEST(FunctionalOpsTest, SymbolicGradient_ShapeFn) { in_type_list.emplace_back(DT_FLOAT); src_list.emplace_back("a", 0, DT_FLOAT); } + out_type_list.reserve(num_outputs); for (int i = 0; i < num_outputs; ++i) { out_type_list.emplace_back(DT_FLOAT); } diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index 31bbe916f43..c10e667f564 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -27,6 +27,7 @@ TEST(MathOpsTest, AddN_ShapeFn) { ShapeInferenceTestOp op("AddN"); auto set_n = [&op](int n) { std::vector src_list; + src_list.reserve(n); for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT); TF_ASSERT_OK(NodeDefBuilder("test", "AddN") .Input(src_list) diff --git a/tensorflow/core/ops/sparse_ops_test.cc b/tensorflow/core/ops/sparse_ops_test.cc index b3ee92fa21e..21b27346889 100644 --- a/tensorflow/core/ops/sparse_ops_test.cc +++ b/tensorflow/core/ops/sparse_ops_test.cc @@ -255,6 +255,7 @@ TEST(SparseOpsTest, SparseConcat_ShapeFn) { ShapeInferenceTestOp op("SparseConcat"); std::vector src_list; int n = 2; + src_list.reserve(n); for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_INT64); TF_ASSERT_OK(NodeDefBuilder("test", "SparseConcat") .Input(src_list) diff --git a/tensorflow/core/ops/string_ops_test.cc b/tensorflow/core/ops/string_ops_test.cc index 79130bae2c0..f4d3adbb2a3 100644 --- a/tensorflow/core/ops/string_ops_test.cc +++ b/tensorflow/core/ops/string_ops_test.cc @@ -27,6 +27,7 @@ TEST(StringOpsTest, StringJoin_ShapeFn) { ShapeInferenceTestOp op("StringJoin"); int n = 3; std::vector src_list; + src_list.reserve(n); for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_STRING); TF_ASSERT_OK(NodeDefBuilder("test", "StringJoin") .Input(src_list) diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/cloud/retrying_file_system_test.cc index aced1aa8baf..232dcb3e71a 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc +++ b/tensorflow/core/platform/cloud/retrying_file_system_test.cc @@ -25,6 +25,7 @@ typedef std::vector> ExpectedCalls; ExpectedCalls CreateRetriableErrors(const string& method, int n) { ExpectedCalls expected_calls; + expected_calls.reserve(n); for (int i = 0; i < n; i++) { expected_calls.emplace_back(std::make_tuple( method, errors::Unavailable(strings::StrCat("Retriable error #", i)))); diff --git a/tensorflow/core/protobuf/meta_graph.proto b/tensorflow/core/protobuf/meta_graph.proto index 5b2022321e5..47ec2aa1efe 100644 --- a/tensorflow/core/protobuf/meta_graph.proto +++ b/tensorflow/core/protobuf/meta_graph.proto @@ -202,8 +202,34 @@ message CollectionDef { // Information about a Tensor necessary for feeding or retrieval. message TensorInfo { - string name = 1; + // For sparse tensors, The COO encoding stores a triple of values, indices, + // and shape. + message CooSparse { + // The shape of the values Tensor is [?]. Its dtype must be the dtype of + // the SparseTensor as a whole, given in the enclosing TensorInfo. + string values_tensor_name = 1; + + // The indices Tensor must have dtype int64 and shape [?, ?]. + string indices_tensor_name = 2; + + // The dynamic logical shape represented by the SparseTensor is recorded in + // the Tensor referenced here. It must have dtype int64 and shape [?]. + string dense_shape_tensor_name = 3; + } + + oneof encoding { + // For dense `Tensor`s, the name of the tensor in the graph. + string name = 1; + // There are many possible encodings of sparse matrices + // (https://en.wikipedia.org/wiki/Sparse_matrix). Currently, TensorFlow + // uses only the COO encoding. This is supported and documented in the + // SparseTensor Python class. + CooSparse coo_sparse = 4; + } DataType dtype = 2; + // The static shape should be recorded here, to the extent that it can + // be known in advance. In the case of a SparseTensor, this field describes + // the logical shape of the represented tensor (aka dense_shape). TensorShapeProto tensor_shape = 3; } diff --git a/tensorflow/core/util/command_line_flags_test.cc b/tensorflow/core/util/command_line_flags_test.cc index 62025463af7..c86a70ec9d0 100644 --- a/tensorflow/core/util/command_line_flags_test.cc +++ b/tensorflow/core/util/command_line_flags_test.cc @@ -27,6 +27,7 @@ namespace { std::vector CharPointerVectorFromStrings( const std::vector &strings) { std::vector result; + result.reserve(strings.size()); for (const string &string : strings) { result.push_back(const_cast(string.c_str())); } diff --git a/tensorflow/core/util/ctc/ctc_beam_search_test.cc b/tensorflow/core/util/ctc/ctc_beam_search_test.cc index 8c723e8e4fe..b2d5ef56adf 100644 --- a/tensorflow/core/util/ctc/ctc_beam_search_test.cc +++ b/tensorflow/core/util/ctc/ctc_beam_search_test.cc @@ -150,6 +150,7 @@ TEST(CtcBeamSearch, DecodingWithAndWithoutDictionary) { // using Eigen::Map. Eigen::Map seq_len(&sequence_lengths[0], batch_size); std::vector> inputs; + inputs.reserve(timesteps); for (int t = 0; t < timesteps; ++t) { inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes); } @@ -199,6 +200,7 @@ TEST(CtcBeamSearch, AllBeamElementsHaveFiniteScores) { // using Eigen::Map. Eigen::Map seq_len(&sequence_lengths[0], batch_size); std::vector> inputs; + inputs.reserve(timesteps); for (int t = 0; t < timesteps; ++t) { inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes); } @@ -293,6 +295,7 @@ TEST(CtcBeamSearch, LabelSelection) { // using Eigen::Map. Eigen::Map seq_len(&sequence_lengths[0], batch_size); std::vector> inputs; + inputs.reserve(timesteps); for (int t = 0; t < timesteps; ++t) { inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes); } diff --git a/tensorflow/core/util/equal_graph_def_test.cc b/tensorflow/core/util/equal_graph_def_test.cc index af870c5c607..054cc92c169 100644 --- a/tensorflow/core/util/equal_graph_def_test.cc +++ b/tensorflow/core/util/equal_graph_def_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/core/util/equal_graph_def.h" #include "tensorflow/core/framework/node_def_util.h" @@ -40,7 +42,7 @@ Node* Alternate(const GraphDefBuilder::Options& opts) { Node* Combine(ops::NodeOut a, ops::NodeOut b, const GraphDefBuilder::Options& opts) { - return ops::BinaryOp("Combine", a, b, opts); + return ops::BinaryOp("Combine", std::move(a), std::move(b), opts); } class EqualGraphDefTest : public ::testing::Test { diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index dd04cea40d1..b495bc31b1f 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -255,6 +255,16 @@ Status CorruptFileError(const Status& in_status, const string& filename, detail, "): ", in_status.error_message())); } +table::Options TableBuilderOptions() { + table::Options o; + // Compressed tables cannot be read by TensorFlow releases prior to 1.1. + // To smoothen the transition, compressed writes are disabled for now + // (version 1.2) with the intention that they will be enabled again at + // some point (perhaps the 1.3 release?). + o.compression = table::kNoCompression; + return o; +} + } // namespace BundleWriter::BundleWriter(Env* env, StringPiece prefix) @@ -442,7 +452,7 @@ static Status MergeOneBundle(Env* env, StringPiece prefix, table::Table* table = nullptr; TF_RETURN_IF_ERROR( - table::Table::Open(table::Options(), file.get(), file_size, &table)); + table::Table::Open(TableBuilderOptions(), file.get(), file_size, &table)); std::unique_ptr table_deleter(table); std::unique_ptr iter(table->NewIterator()); @@ -555,7 +565,7 @@ Status MergeBundles(Env* env, gtl::ArraySlice prefixes, TF_RETURN_IF_ERROR( env->NewWritableFile(MetaFilename(merged_prefix), &merged_metadata)); { - table::TableBuilder builder(table::Options(), merged_metadata.get()); + table::TableBuilder builder(TableBuilderOptions(), merged_metadata.get()); // Header entry. BundleHeaderProto header; header.set_num_shards(merge.num_shards); @@ -630,6 +640,12 @@ BundleReader::~BundleReader() { delete metadata_; delete iter_; delete table_; + // InputBuffer does not own the underlying RandomAccessFile. + for (auto pair : data_) { + if (pair.second->file() != nullptr) { + delete pair.second->file(); + } + } gtl::STLDeleteValues(&data_); gtl::STLDeleteValues(&tensor_slices_); } @@ -684,14 +700,16 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) { } } - // Open the data file if not opened it. - std::unique_ptr file = nullptr; - std::unique_ptr buffered_file(data_[entry.shard_id()]); + // Open the data file if it has not been opened. + io::InputBuffer* buffered_file = data_[entry.shard_id()]; if (buffered_file == nullptr) { + std::unique_ptr file = nullptr; TF_RETURN_IF_ERROR(env_->NewRandomAccessFile( DataFilename(prefix_, entry.shard_id(), num_shards_), &file)); - buffered_file.reset( - new io::InputBuffer(file.get(), 256 << 10 /* 256KB buffer */)); + buffered_file = + new io::InputBuffer(file.release(), 256 << 10 /* 256KB buffer */); + // The InputBuffer and RandomAccessFile objects are both released in dtor. + data_[entry.shard_id()] = buffered_file; } CHECK(buffered_file != nullptr); @@ -710,7 +728,7 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) { // Relies on io::InputBuffer's buffering, because we issue many neighboring // reads for a single string tensor. TF_RETURN_IF_ERROR(ReadStringTensor( - buffered_file.get(), ret->NumElements(), entry.offset(), entry.size(), + buffered_file, ret->NumElements(), entry.offset(), entry.size(), GetStringBackingBuffer(*ret), &actual_crc32c)); } if (crc32c::Unmask(entry.crc32c()) != actual_crc32c) { diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.h b/tensorflow/core/util/tensor_bundle/tensor_bundle.h index 2c40388250c..9cc1f9fd21b 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.h +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.h @@ -273,6 +273,7 @@ class BundleReader { RandomAccessFile* metadata_; // Owned. table::Table* table_; table::Iterator* iter_; + // Owned the InputBuffer objects and their underlying RandomAccessFile's. std::unordered_map data_; // Maps each partitioned tensor's key to its stored slices (represented in a diff --git a/tensorflow/docs_src/programmers_guide/embedding.md b/tensorflow/docs_src/programmers_guide/embedding.md new file mode 100644 index 00000000000..975850349f0 --- /dev/null +++ b/tensorflow/docs_src/programmers_guide/embedding.md @@ -0,0 +1,352 @@ +# Embeddings + +[TOC] + +## Introduction + +An embedding is a mapping from discrete objects, such as words, to vectors of +real numbers. For example, a 300-dimensional embedding for English words could +include: + +``` +blue: (0.01359, 0.00075997, 0.24608, ..., -0.2524, 1.0048, 0.06259) +blues: (0.01396, 0.11887, -0.48963, ..., 0.033483, -0.10007, 0.1158) +orange: (-0.24776, -0.12359, 0.20986, ..., 0.079717, 0.23865, -0.014213) +oranges: (-0.35609, 0.21854, 0.080944, ..., -0.35413, 0.38511, -0.070976) +``` + +Embeddings let you apply machine learning to discrete inputs. Classifiers, and +neural networks more generally, are designed to work with dense continuous +vectors, where all values contribute to define what an object is. If discrete +objects are naively encoded as discrete atoms, e.g., unique id numbers, they +hinder learning and generalization. One way to think of embeddings is as a way +to transform non-vector objects into useful inputs for machine learning. + +Embeddings are also useful as outputs of machine learning. Because embeddings +map objects to vectors, applications can use similarity in vector space (e.g., +Euclidean distance or the angle between vectors) as a robust and flexible +measure of object similarity. One common use is to find nearest neighbors. +Using the same word embeddings above, for instance, here are the three nearest +neighbors for each word and the corresponding angles (in degrees): + +``` +blue: (red, 47.6°), (yellow, 51.9°), (purple, 52.4°) +blues: (jazz, 53.3°), (folk, 59.1°), (bluegrass, 60.6°) +orange: (yellow, 53.5°), (colored, 58.0°), (bright, 59.9°) +oranges: (apples, 45.3°), (lemons, 48.3°), (mangoes, 50.4°) +``` + +This would tell an application that apples and oranges are in some way more +similar (45.3° apart) than lemons and oranges (48.3° apart). + +## Training an Embedding + +To train word embeddings in TensorFlow, we first need to split the text into +words and assign an integer to every word in the vocabulary. Let us assume that +this has already been done, and that `word_ids` is a vector of these integers. +For example, the sentence “I have a cat.” could be split into +`[“I”, “have”, “a”, “cat”, “.”]` and then the corresponding `word_ids` tensor +would have shape `[5]` and consist of 5 integers. To get these word ids +embedded, we need to create the embedding variable and use the `tf.gather` +function as follows: + +``` +word_embeddings = tf.get_variable(“word_embeddings”, + [vocabulary_size, embedding_size]) +embedded_word_ids = tf.gather(word_embeddings, word_ids) +``` + +After this, the tensor `embedded_word_ids` will have shape `[5, embedding_size]` +in our example and contain the embeddings (dense vectors) for each of the 5 +words. The variable `word_embeddings` will be learned and at the end of the +training it will contain the embeddings for all words in the vocabulary. +The embeddings can be trained in many ways, depending on the data available. +For example, one could use a recurrent neural network to predict the next word +from the previous one given a large corpus of sentences, or one could train +two networks to do multi-lingual translation. These methods are described in +[Vector Representations of Words](../tutorials/word2vec.md) tutorial, but in +all cases there is an embedding variable like above and words are embedded +using `tf.gather`, as shown. + +## Visualizing Embeddings + +TensorBoard has a built-in visualizer, called the Embedding Projector, +for interactive visualization of embeddings. The embedding projector will read +the embeddings from your checkpoint file and project them into 3 dimensions using +[principal component analysis](https://en.wikipedia.org/wiki/Principal_component_analysis). +For a visual explanation of PCA, see +[this article](http://setosa.io/ev/principal-component-analysis/). Another +very useful projection you can use is +[t-SNE](https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding). + +If you are working with an embedding, you'll probably want to attach +labels/images to the data points. You can do this by generating a +[metadata file](#metadata) containing the labels for each point and configuring +the projector either by using our Python API, or manually constructing and +saving a +[projector_config.pbtxt](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/plugins/projector/projector_config.proto) +in the same directory as your checkpoint file. + +### Setup + +For in depth information on how to run TensorBoard and make sure you are +logging all the necessary information, see +[TensorBoard: Visualizing Learning](../get_started/summaries_and_tensorboard.md). + +To visualize your embeddings, there are 3 things you need to do: + +1) Setup a 2D tensor that holds your embedding(s). + +```python +embedding_var = tf.get_variable(....) +``` + +2) Periodically save your model variables in a checkpoint in +LOG_DIR. + +```python +saver = tf.train.Saver() +saver.save(session, os.path.join(LOG_DIR, "model.ckpt"), step) +``` + +3) (Optional) Associate metadata with your embedding. + +If you have any metadata (labels, images) associated with your embedding, you +can tell TensorBoard about it either by directly storing a +[projector_config.pbtxt](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/plugins/projector/projector_config.proto) +in the LOG_DIR, or use our python API. + +For instance, the following projector_config.ptxt associates the +word_embedding tensor with metadata stored in $LOG_DIR/metadata.tsv: + +``` +embeddings { + tensor_name: 'word_embedding' + metadata_path: '$LOG_DIR/metadata.tsv' +} +``` + +The same config can be produced programmatically using the following code snippet: + +```python +from tensorflow.contrib.tensorboard.plugins import projector + +# Create randomly initialized embedding weights which will be trained. +vocabulary_size = 10000 +embedding_size = 200 +embedding_var = tf.get_variable('word_embedding', [vocabulary_size, embedding_size]) + +# Format: tensorflow/tensorboard/plugins/projector/projector_config.proto +config = projector.ProjectorConfig() + +# You can add multiple embeddings. Here we add only one. +embedding = config.embeddings.add() +embedding.tensor_name = embedding_var.name +# Link this tensor to its metadata file (e.g. labels). +embedding.metadata_path = os.path.join(LOG_DIR, 'metadata.tsv') + +# Use the same LOG_DIR where you stored your checkpoint. +summary_writer = tf.summary.FileWriter(LOG_DIR) + +# The next line writes a projector_config.pbtxt in the LOG_DIR. TensorBoard will +# read this file during startup. +projector.visualize_embeddings(summary_writer, config) +``` + +After running your model and training your embeddings, run TensorBoard and point +it to the LOG_DIR of the job. + +```python +tensorboard --logdir=LOG_DIR +``` + +Then click on the *Embeddings* tab on the top pane +and select the appropriate run (if there are more than one run). + + +### Metadata +Usually embeddings have metadata associated with it (e.g. labels, images). The +metadata should be stored in a separate file outside of the model checkpoint +since the metadata is not a trainable parameter of the model. The format should +be a [TSV file](https://en.wikipedia.org/wiki/Tab-separated_values) +(tab characters shown in red) with the first line containing column headers +(shown in bold) and subsequent lines contain the metadata values: + + +Word\tFrequency
+ Airplane\t345
+ Car\t241
+ ... +
+ +There is no explicit key shared with the main data file; instead, the order in +the metadata file is assumed to match the order in the embedding tensor. In +other words, the first line is the header information and the (i+1)-th line in +the metadata file corresponds to the i-th row of the embedding tensor stored in +the checkpoint. + +Note: If the TSV metadata file has only a single column, then we don’t expect a +header row, and assume each row is the label of the embedding. We include this +exception because it matches the commonly-used "vocab file" format. + +### Images +If you have images associated with your embeddings, you will need to +produce a single image consisting of small thumbnails of each data point. +This is known as the +[sprite image](https://www.google.com/webhp#q=what+is+a+sprite+image). +The sprite should have the same number of rows and columns with thumbnails +stored in row-first order: the first data point placed in the top left and the +last data point in the bottom right: + + + + + + + + + + + + + + + + + +
012
345
67
+ +Note in the example above that the last row doesn't have to be filled. For a +concrete example of a sprite, see +[this sprite image](https://www.tensorflow.org/images/mnist_10k_sprite.png) of 10,000 MNIST digits +(100x100). + +Note: We currently support sprites up to 8192px X 8192px. + +After constructing the sprite, you need to tell the Embedding Projector where +to find it: + + +```python +embedding.sprite.image_path = PATH_TO_SPRITE_IMAGE +# Specify the width and height of a single thumbnail. +embedding.sprite.single_image_dim.extend([w, h]) +``` + +### Interaction + +The Embedding Projector has three panels: + +1. *Data panel* on the top left, where you can choose the run, the embedding + tensor and data columns to color and label points by. +2. *Projections panel* on the bottom left, where you choose the type of + projection (e.g. PCA, t-SNE). +3. *Inspector panel* on the right side, where you can search for particular + points and see a list of nearest neighbors. + +### Projections +The Embedding Projector has three methods of reducing the dimensionality of a +data set: two linear and one nonlinear. Each method can be used to create either +a two- or three-dimensional view. + +**Principal Component Analysis** A straightforward technique for reducing +dimensions is Principal Component Analysis (PCA). The Embedding Projector +computes the top 10 principal components. The menu lets you project those +components onto any combination of two or three. PCA is a linear projection, +often effective at examining global geometry. + +**t-SNE** A popular non-linear dimensionality reduction technique is t-SNE. +The Embedding Projector offers both two- and three-dimensional t-SNE views. +Layout is performed client-side animating every step of the algorithm. Because +t-SNE often preserves some local structure, it is useful for exploring local +neighborhoods and finding clusters. Although extremely useful for visualizing +high-dimensional data, t-SNE plots can sometimes be mysterious or misleading. +See this [great article](http://distill.pub/2016/misread-tsne/) for how to use +t-SNE effectively. + +**Custom** You can also construct specialized linear projections based on text +searches for finding meaningful directions in space. To define a projection +axis, enter two search strings or regular expressions. The program computes the +centroids of the sets of points whose labels match these searches, and uses the +difference vector between centroids as a projection axis. + +### Navigation + +To explore a data set, you can navigate the views in either a 2D or a 3D mode, +zooming, rotating, and panning using natural click-and-drag gestures. +Clicking on a point causes the right pane to show an explicit textual list of +nearest neighbors, along with distances to the current point. The +nearest-neighbor points themselves are highlighted on the projection. + +Zooming into the cluster gives some information, but it is sometimes more +helpful to restrict the view to a subset of points and perform projections only +on those points. To do so, you can select points in multiple ways: + +1. After clicking on a point, its nearest neighbors are also selected. +2. After a search, the points matching the query are selected. +3. Enabling selection, clicking on a point and dragging defines a selection + sphere. + +After selecting a set of points, you can isolate those points for +further analysis on their own with the "Isolate Points" button in the Inspector +pane on the right hand side. + + +![Selection of nearest neighbors](https://www.tensorflow.org/images/embedding-nearest-points.png "Selection of nearest neighbors") +*Selection of the nearest neighbors of “important” in a word embedding dataset.* + +The combination of filtering with custom projection can be powerful. Below, we filtered +the 100 nearest neighbors of “politics” and projected them onto the +“best” - “worst” vector as an x axis. The y axis is random. + +You can see that on the right side we have “ideas”, “science”, “perspective”, +“journalism” while on the left we have “crisis”, “violence” and “conflict”. + + + + + + + + + + +
+ Custom controls panel + + Custom projection +
+ Custom projection controls. + + Custom projection of neighbors of "politics" onto "best" - "worst" vector. +
+ +### Collaborative Features + +To share your findings, you can use the bookmark panel in the bottom right +corner and save the current state (including computed coordinates of any +projection) as a small file. The Projector can then be pointed to a set of one +or more of these files, producing the panel below. Other users can then walk +through a sequence of bookmarks. + +Bookmark panel + + +## Mini-FAQ + +**Is "embedding" an action or a thing?** +Both. People talk about embedding words in a vector space (action) and about +producing word embeddings (things). Common to both is the notion of embedding +as a mapping from discrete objects to vectors. Creating or applying that +mapping is an action, but the mapping itself is a thing. + +**Are embeddings high-dimensional or low-dimensional?** +It depends. A 300-dimensional vector space of words and phrases, for instance, +is often called low-dimensional (and dense) when compared to the millions of +words and phrases it can contain. But mathematically it is high-dimensional, +displaying many properties that are dramatically different from what our human +intuition has learned about 2- and 3-dimensional spaces. + +**Is an embedding the same as an embedding layer?** +No; an embedding layer is a part of neural network, but an embedding is a more +general concept. diff --git a/tensorflow/examples/wav_to_spectrogram/BUILD b/tensorflow/examples/wav_to_spectrogram/BUILD index 1e72324fb05..5923fa59293 100644 --- a/tensorflow/examples/wav_to_spectrogram/BUILD +++ b/tensorflow/examples/wav_to_spectrogram/BUILD @@ -15,12 +15,8 @@ exports_files(["LICENSE"]) cc_library( name = "wav_to_spectrogram_lib", - srcs = [ - "wav_to_spectrogram.cc", - ], - hdrs = [ - "wav_to_spectrogram.h", - ], + srcs = ["wav_to_spectrogram.cc"], + hdrs = ["wav_to_spectrogram.h"], deps = [ "//tensorflow/cc:cc_ops", "//tensorflow/core:framework_internal", @@ -30,13 +26,10 @@ cc_library( cc_binary( name = "wav_to_spectrogram", - srcs = [ - "main.cc", - ], + srcs = ["main.cc"], deps = [ ":wav_to_spectrogram_lib", "//tensorflow/core:framework_internal", - "//tensorflow/core:tensorflow", ], ) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index d4c83222c04..d709abef2b4 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -208,6 +208,16 @@ func FakeQuantWithMinMaxVarsPerChannelGradientNumBits(value int64) FakeQuantWith } } +// FakeQuantWithMinMaxVarsPerChannelGradientNarrowRange sets the optional narrow_range attribute to value. +// +// value: Whether to quantize into 2^num_bits - 1 distinct values. +// If not specified, defaults to false +func FakeQuantWithMinMaxVarsPerChannelGradientNarrowRange(value bool) FakeQuantWithMinMaxVarsPerChannelGradientAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + // Compute gradients for a FakeQuantWithMinMaxVarsPerChannel operation. // // Arguments: @@ -254,16 +264,26 @@ func FakeQuantWithMinMaxVarsNumBits(value int64) FakeQuantWithMinMaxVarsAttr { } } +// FakeQuantWithMinMaxVarsNarrowRange sets the optional narrow_range attribute to value. +// If not specified, defaults to false +func FakeQuantWithMinMaxVarsNarrowRange(value bool) FakeQuantWithMinMaxVarsAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + // Fake-quantize the 'inputs' tensor of type float via global float scalars `min` // // and `max` to 'outputs' tensor of same shape as `inputs`. // -// [min; max] is the clamping range for the 'inputs' data. Op divides this range -// into 255 steps (total of 256 values), then replaces each 'inputs' value with the -// closest of the quantized step values. -// 'num_bits' is the bitwidth of the quantization; between 2 and 8, inclusive. +// `[min; max]` define the clamping range for the `inputs` data. +// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` +// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and +// then de-quantized and output as floats in `[min; max]` interval. +// `num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive. // -// This operation has a gradient and thus allows for training `min` and `max` values. +// This operation has a gradient and thus allows for training `min` and `max` +// values. func FakeQuantWithMinMaxVars(scope *Scope, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsAttr) (outputs tf.Output) { if scope.Err() != nil { return @@ -2149,7 +2169,7 @@ func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) { // dimension. Must sum to the dimension of value along split_dim. // Can contain one -1 indicating that dimension is to be inferred. // split_dim: 0-D. The dimension along which to split. Must be in the range -// `[0, rank(value))`. +// `[-rank(value), rank(value))`. // // // Returns Tensors whose shape matches that of `value` @@ -2184,7 +2204,7 @@ func SplitV(scope *Scope, value tf.Output, size_splits tf.Output, split_dim tf.O // // Arguments: // split_dim: 0-D. The dimension along which to split. Must be in the range -// `[0, rank(value))`. +// `[-rank(value), rank(value))`. // value: The tensor to split. // num_split: The number of ways to split. Must evenly divide // `value.shape[split_dim]`. @@ -3325,12 +3345,21 @@ func FakeQuantWithMinMaxArgsNumBits(value int64) FakeQuantWithMinMaxArgsAttr { } } +// FakeQuantWithMinMaxArgsNarrowRange sets the optional narrow_range attribute to value. +// If not specified, defaults to false +func FakeQuantWithMinMaxArgsNarrowRange(value bool) FakeQuantWithMinMaxArgsAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + // Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type. // -// Attributes [min; max] define the clamping range for the 'inputs' data. Op -// divides this range into 255 steps (total of 256 values), then replaces each -// 'inputs' value with the closest of the quantized step values. -// 'num_bits' is the bitwidth of the quantization; between 2 and 8, inclusive. +// Attributes `[min; max]` define the clamping range for the `inputs` data. +// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` +// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and +// then de-quantized and output as floats in `[min; max]` interval. +// `num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive. // // Quantization is called fake since the output is still in floating point. func FakeQuantWithMinMaxArgs(scope *Scope, inputs tf.Output, optional ...FakeQuantWithMinMaxArgsAttr) (outputs tf.Output) { @@ -6410,6 +6439,14 @@ func FakeQuantWithMinMaxArgsGradientNumBits(value int64) FakeQuantWithMinMaxArgs } } +// FakeQuantWithMinMaxArgsGradientNarrowRange sets the optional narrow_range attribute to value. +// If not specified, defaults to false +func FakeQuantWithMinMaxArgsGradientNarrowRange(value bool) FakeQuantWithMinMaxArgsGradientAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + // Compute gradients for a FakeQuantWithMinMaxArgs operation. // // Arguments: @@ -8601,17 +8638,27 @@ func FakeQuantWithMinMaxVarsPerChannelNumBits(value int64) FakeQuantWithMinMaxVa } } +// FakeQuantWithMinMaxVarsPerChannelNarrowRange sets the optional narrow_range attribute to value. +// If not specified, defaults to false +func FakeQuantWithMinMaxVarsPerChannelNarrowRange(value bool) FakeQuantWithMinMaxVarsPerChannelAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + // Fake-quantize the 'inputs' tensor of type float and one of the shapes: `[d]`, // // `[b, d]` `[b, h, w, d]` via per-channel floats `min` and `max` of shape `[d]` // to 'outputs' tensor of same shape as `inputs`. // -// [min; max] is the clamping range for the 'inputs' data in the corresponding -// depth channel. Op divides this range into 255 steps (total of 256 values), then -// replaces each 'inputs' value with the closest of the quantized step values. -// 'num_bits' is the bitwidth of the quantization; between 2 and 8, inclusive. +// `[min; max]` define the clamping range for the `inputs` data. +// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` +// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and +// then de-quantized and output as floats in `[min; max]` interval. +// `num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive. // -// This operation has a gradient and thus allows for training `min` and `max` values. +// This operation has a gradient and thus allows for training `min` and `max` +// values. func FakeQuantWithMinMaxVarsPerChannel(scope *Scope, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsPerChannelAttr) (outputs tf.Output) { if scope.Err() != nil { return @@ -21779,6 +21826,16 @@ func FakeQuantWithMinMaxVarsGradientNumBits(value int64) FakeQuantWithMinMaxVars } } +// FakeQuantWithMinMaxVarsGradientNarrowRange sets the optional narrow_range attribute to value. +// +// value: Whether to quantize into 2^num_bits - 1 distinct values. +// If not specified, defaults to false +func FakeQuantWithMinMaxVarsGradientNarrowRange(value bool) FakeQuantWithMinMaxVarsGradientAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + // Compute gradients for a FakeQuantWithMinMaxVars operation. // // Arguments: diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index b735964bc6b..f23c528e0dd 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -320,7 +320,10 @@ py_test( cc_library( name = "python_op_gen", srcs = ["framework/python_op_gen.cc"], - hdrs = ["framework/python_op_gen.h"], + hdrs = [ + "framework/python_op_gen.h", + "framework/python_op_gen_internal.h", + ], visibility = ["//visibility:public"], deps = [ "//tensorflow/core:framework", @@ -2954,39 +2957,34 @@ tf_cuda_library( alwayslink = 1, ) -# Disabled due to http://b/62145493 -# py_test( -# name = "session_test", -# size = "medium", # http://b/62144199 -# srcs = ["client/session_test.py"], -# srcs_version = "PY2AND3", -# tags = [ -# "no_gpu", -# "no_pip_gpu", # testInteractivePlacePrunedGraph fails on invalid assumption about GPU ops. -# "no_windows", -# ], -# deps = [ -# ":array_ops", -# ":client", -# ":construction_fails_op", -# ":control_flow_ops", -# ":data_flow_ops", -# ":errors", -# ":framework", -# ":framework_for_generated_wrappers", -# ":framework_test_lib", -# ":math_ops", -# ":platform_test", -# ":state_ops", -# ":training", -# ":util", -# ":variables", -# "//third_party/py/numpy", -# "@six_archive//:six", -# "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", -# "//tensorflow/core/distributed_runtime/rpc:grpc_session", -# ], -# ) +py_test( + name = "session_test", + size = "small", + srcs = ["client/session_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_gpu", + "no_pip_gpu", # testInteractivePlacePrunedGraph fails on invalid assumption about GPU ops. + ], + deps = [ + ":array_ops", + ":client", + ":control_flow_ops", + ":data_flow_ops", + ":errors", + ":framework", + ":framework_for_generated_wrappers", + ":framework_test_lib", + ":math_ops", + ":platform_test", + ":state_ops", + ":training", + ":util", + ":variables", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) py_test( name = "session_clusterspec_prop_test", @@ -3575,13 +3573,11 @@ py_test( ], ) -py_test( +cuda_py_test( name = "layers_normalization_test", size = "small", srcs = ["layers/normalization_test.py"], - main = "layers/normalization_test.py", - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ ":array_ops", ":client_testlib", ":framework_for_generated_wrappers", @@ -3591,6 +3587,7 @@ py_test( ":variables", "//third_party/py/numpy", ], + main = "layers/normalization_test.py", ) # ----------------------------------------------------------------------------- @@ -3820,10 +3817,17 @@ py_test( srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ + ":array_ops", ":client_testlib", ":cost_analyzer", ":framework_for_generated_wrappers", ":math_ops", + ":nn", + ":nn_grad", + ":random_ops", + ":state_ops", + ":training", + ":variables", "//tensorflow/core:protos_all_py", "//third_party/py/numpy", ], diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py new file mode 100644 index 00000000000..db5a72b8f21 --- /dev/null +++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py @@ -0,0 +1,373 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow estimators for Linear and DNN joined training models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import six + +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn +from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.canned import optimizers +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import ops +from tensorflow.python.layers import core as core_layers +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.summary import summary +from tensorflow.python.training import sync_replicas_optimizer +from tensorflow.python.training import training_util + +# The default learning rates are a historical artifact of the initial +# implementation, but seem a reasonable choice. +_DNN_LEARNING_RATE = 0.05 +_LINEAR_LEARNING_RATE = 0.2 + + +def _check_no_sync_replicas_optimizer(optimizer): + if isinstance(optimizer, sync_replicas_optimizer.SyncReplicasOptimizer): + raise ValueError( + 'SyncReplicasOptimizer does not support multi optimizers case. ' + 'Therefore, it is not supported in DNNLinearCombined model. ' + 'If you want to use this optimizer, please use either DNN or Linear ' + 'model.') + + +def _linear_learning_rate(num_linear_feature_columns): + """Returns the default learning rate of the linear model. + + The calculation is a historical artifact of this initial implementation, but + has proven a reasonable choice. + + Args: + num_linear_feature_columns: The number of feature columns of the linear + model. + + Returns: + A float. + """ + default_learning_rate = 1. / math.sqrt(num_linear_feature_columns) + return min(_LINEAR_LEARNING_RATE, default_learning_rate) + + +def _add_layer_summary(value, tag): + summary.scalar('%s/fraction_of_zero_values' % tag, nn.zero_fraction(value)) + summary.histogram('%s/activation' % tag, value) + + +def _dnn_linear_combined_model_fn( + features, labels, mode, head, + linear_feature_columns=None, linear_optimizer='Ftrl', + dnn_feature_columns=None, dnn_optimizer='Adagrad', dnn_hidden_units=None, + dnn_activation_fn=nn.relu, dnn_dropout=None, + input_layer_partitioner=None, config=None): + """Deep Neural Net and Linear combined model_fn. + + Args: + features: `Tensor` or dict of `Tensor` (depends on data passed to `fit`). + labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of dtype + `int32` or `int64` in the range `[0, n_classes)`. + mode: Defines whether this is training, evaluation or prediction. + See `ModeKeys`. + head: A `Head` instance. + linear_feature_columns: An iterable containing all the feature columns used + by the Linear model. + linear_optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training the Linear model. Defaults to the Ftrl + optimizer. + dnn_feature_columns: An iterable containing all the feature columns used by + the DNN model. + dnn_optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training the DNN model. Defaults to the Adagrad + optimizer. + dnn_hidden_units: List of hidden units per DNN layer. + dnn_activation_fn: Activation function applied to each DNN layer. If `None`, + will use `tf.nn.relu`. + dnn_dropout: When not `None`, the probability we will drop out a given DNN + coordinate. + input_layer_partitioner: Partitioner for input layer. + config: `RunConfig` object to configure the runtime settings. + + Returns: + `ModelFnOps` + + Raises: + ValueError: If both `linear_feature_columns` and `dnn_features_columns` + are empty at the same time, or `input_layer_partitioner` is missing. + """ + if not linear_feature_columns and not dnn_feature_columns: + raise ValueError( + 'Either linear_feature_columns or dnn_feature_columns must be defined.') + num_ps_replicas = config.num_ps_replicas if config else 0 + input_layer_partitioner = input_layer_partitioner or ( + partitioned_variables.min_max_variable_partitioner( + max_partitions=num_ps_replicas, + min_slice_size=64 << 20)) + + linear_optimizer = optimizers.get_optimizer_instance( + linear_optimizer, + learning_rate=_linear_learning_rate(len(linear_feature_columns))) + _check_no_sync_replicas_optimizer(linear_optimizer) + + dnn_optimizer = optimizers.get_optimizer_instance( + dnn_optimizer, + learning_rate=_DNN_LEARNING_RATE) + _check_no_sync_replicas_optimizer(dnn_optimizer) + + # Build DNN Logits. + dnn_parent_scope = 'dnn' + + if not dnn_feature_columns: + dnn_logits = None + else: + if not dnn_hidden_units: + raise ValueError( + 'dnn_hidden_units must be defined when dnn_feature_columns is ' + 'specified.') + dnn_partitioner = ( + partitioned_variables.min_max_variable_partitioner( + max_partitions=num_ps_replicas)) + with variable_scope.variable_scope( + dnn_parent_scope, + values=tuple(six.itervalues(features)), + partitioner=dnn_partitioner): + with variable_scope.variable_scope('input', + partitioner=input_layer_partitioner): + net = feature_column_lib.input_layer( + features=features, + feature_columns=dnn_feature_columns) + + for layer_id, num_hidden_units in enumerate(dnn_hidden_units): + with variable_scope.variable_scope( + 'hiddenlayer_%d' % layer_id, + values=(net,)) as dnn_hidden_layer_scope: + net = core_layers.dense( + net, + units=num_hidden_units, + activation=dnn_activation_fn, + kernel_initializer=init_ops.glorot_uniform_initializer(), + name=dnn_hidden_layer_scope) + if dnn_dropout is not None and mode == model_fn.ModeKeys.TRAIN: + net = core_layers.dropout(net, rate=dnn_dropout, training=True) + _add_layer_summary(net, dnn_hidden_layer_scope.name) + + with variable_scope.variable_scope( + 'logits', + values=(net,)) as dnn_logits_scope: + logits = core_layers.dense( + net, + units=head.logits_dimension, + activation=None, + kernel_initializer=init_ops.glorot_uniform_initializer(), + name=dnn_logits_scope) + _add_layer_summary(dnn_logits, dnn_logits_scope.name) + + linear_parent_scope = 'linear' + + if not linear_feature_columns: + linear_logits = None + else: + with variable_scope.variable_scope( + linear_parent_scope, + values=tuple(six.itervalues(features)), + partitioner=input_layer_partitioner) as scope: + linear_logits = feature_column_lib.linear_model( + features=features, + feature_columns=linear_feature_columns, + units=head.logits_dimension) + _add_layer_summary(linear_logits, scope.name) + + # Combine logits and build full model. + if dnn_logits is not None and linear_logits is not None: + logits = dnn_logits + linear_logits + elif dnn_logits is not None: + logits = dnn_logits + else: + logits = linear_logits + + def _train_op_fn(loss): + """Returns the op to optimize the loss.""" + train_ops = [] + global_step = training_util.get_global_step() + if dnn_logits is not None: + train_ops.append( + dnn_optimizer.minimize( + loss, + var_list=ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES, + scope=dnn_parent_scope))) + if linear_logits is not None: + train_ops.append( + linear_optimizer.minimize( + loss, + var_list=ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES, + scope=linear_parent_scope))) + + train_op = control_flow_ops.group(*train_ops) + with ops.control_dependencies([train_op]): + with ops.colocate_with(global_step): + return state_ops.assign_add(global_step, 1) + + return head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + + +class DNNLinearCombinedRegressor(estimator.Estimator): + """An estimator for TensorFlow Linear and DNN joined models for regresssion. + + Note: This estimator is also known as wide-n-deep. + + Example: + + ```python + numeric_feature = numeric_column(...) + sparse_column_a = categorical_column_with_hash_bucket(...) + sparse_column_b = categorical_column_with_hash_bucket(...) + + sparse_feature_a_x_sparse_feature_b = crossed_column(...) + sparse_feature_a_emb = embedding_column(sparse_id_column=sparse_feature_a, + ...) + sparse_feature_b_emb = embedding_column(sparse_id_column=sparse_feature_b, + ...) + + estimator = DNNLinearCombinedRegressor( + # wide settings + linear_feature_columns=[sparse_feature_a_x_sparse_feature_b], + linear_optimizer=tf.train.FtrlOptimizer(...), + # deep settings + dnn_feature_columns=[ + sparse_feature_a_emb, sparse_feature_b_emb, numeric_feature], + dnn_hidden_units=[1000, 500, 100], + dnn_optimizer=tf.train.ProximalAdagradOptimizer(...)) + + # To apply L1 and L2 regularization, you can set optimizers as follows: + tf.train.ProximalAdagradOptimizer( + learning_rate=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=0.001) + # It is same for FtrlOptimizer. + + # Input builders + def input_fn_train: # returns x, y + pass + estimator.train(input_fn=input_fn_train, steps=100) + + def input_fn_eval: # returns x, y + pass + metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10) + def input_fn_predict: # returns x, None + pass + predictions = estimator.predict(input_fn=input_fn_predict) + ``` + + Input of `train` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * for each `column` in `dnn_feature_columns` + `linear_feature_columns`: + - if `column` is a `_CategoricalColumn`, a feature with `key=column.name` + whose `value` is a `SparseTensor`. + - if `column` is a `_WeightedCategoricalColumn`, two features: the first + with `key` the id column name, the second with `key` the weight column + name. Both features' `value` must be a `SparseTensor`. + - if `column` is a `_DenseColumn`, a feature with `key=column.name` + whose `value` is a `Tensor`. + + """ + + def __init__(self, + model_dir=None, + linear_feature_columns=None, + linear_optimizer=None, + dnn_feature_columns=None, + dnn_optimizer=None, + dnn_hidden_units=None, + dnn_activation_fn=nn.relu, + dnn_dropout=None, + label_dimension=1, + input_layer_partitioner=None, + config=None): + """Initializes a DNNLinearCombinedRegressor instance. + + Args: + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator + to continue training a previously saved model. + linear_feature_columns: An iterable containing all the feature columns + used by linear part of the model. All items in the set must be + instances of classes derived from `FeatureColumn`. + linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to + the linear part of the model. If `None`, will use a FTRL optimizer. + dnn_feature_columns: An iterable containing all the feature columns used + by deep part of the model. All items in the set must be instances of + classes derived from `FeatureColumn`. + dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to + the deep part of the model. If `None`, will use an Adagrad optimizer. + dnn_hidden_units: List of hidden units per layer. All layers are fully + connected. + dnn_activation_fn: Activation function applied to each layer. If None, + will use `tf.nn.relu`. + dnn_dropout: When not None, the probability we will drop out + a given coordinate. + label_dimension: Number of regression targets per example. This is the + size of the last dimension of the labels and logits `Tensor` objects + (typically, these have shape `[batch_size, label_dimension]`). + input_layer_partitioner: Partitioner for input layer. Defaults to + `min_max_variable_partitioner` with `min_slice_size` 64 << 20. + config: RunConfig object to configure the runtime settings. + + Raises: + ValueError: If both linear_feature_columns and dnn_features_columns are + empty at the same time. + """ + linear_feature_columns = linear_feature_columns or [] + dnn_feature_columns = dnn_feature_columns or [] + self._feature_columns = linear_feature_columns + dnn_feature_columns + if not self._feature_columns: + raise ValueError('Either linear_feature_columns or dnn_feature_columns ' + 'must be defined.') + + def _model_fn(features, labels, mode, config): + return _dnn_linear_combined_model_fn( + features=features, + labels=labels, + mode=mode, + head=head_lib._regression_head_with_mean_squared_error_loss( # pylint: disable=protected-access + label_dimension=label_dimension), + linear_feature_columns=linear_feature_columns, + linear_optimizer=linear_optimizer, + dnn_feature_columns=dnn_feature_columns, + dnn_optimizer=dnn_optimizer, + dnn_hidden_units=dnn_hidden_units, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + input_layer_partitioner=input_layer_partitioner, + config=config) + + super(DNNLinearCombinedRegressor, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/python/estimator/canned/dnn_test.py b/tensorflow/python/estimator/canned/dnn_test.py index c5af8e86876..a424cc0ed4c 100644 --- a/tensorflow/python/estimator/canned/dnn_test.py +++ b/tensorflow/python/estimator/canned/dnn_test.py @@ -462,6 +462,86 @@ class DNNRegressorEvaluateTest(test.TestCase): }, dnn_regressor.evaluate(input_fn=_input_fn, steps=1)) +class DNNClassifierEvaluateTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + shutil.rmtree(self._model_dir) + + def test_one_dim(self): + """Asserts evaluation metrics for one-dimensional input and logits.""" + global_step = 100 + _create_checkpoint(( + ([[.6, .5]], [.1, -.1]), + ([[1., .8], [-.8, -1.]], [.2, -.2]), + ([[-1.], [1.]], [.3]), + ), global_step, self._model_dir) + + dnn_classifier = dnn.DNNClassifier( + hidden_units=(2, 2), + feature_columns=[feature_column.numeric_column('age')], + model_dir=self._model_dir) + def _input_fn(): + # batch_size = 2, one false label, and one true. + return {'age': [[10.], [10.]]}, [[1], [0]] + # Uses identical numbers as DNNModelTest.test_one_dim_logits. + # See that test for calculation of logits. + # logits = [[-2.08], [-2.08]] => + # logistic = 1/(1 + exp(-logits)) = [[0.11105597], [0.11105597]] + # loss = -1. * log(0.111) -1. * log(0.889) = 2.31544200 + expected_loss = 2.31544200 + self.assertAllClose({ + metric_keys.MetricKeys.LOSS: expected_loss, + metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2., + metric_keys.MetricKeys.ACCURACY: 0.5, + metric_keys.MetricKeys.PREDICTION_MEAN: 0.11105597, + metric_keys.MetricKeys.LABEL_MEAN: 0.5, + metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5, + # There is no good way to calculate AUC for only two data points. But + # that is what the algorithm returns. + metric_keys.MetricKeys.AUC: 0.5, + metric_keys.MetricKeys.AUC_PR: 0.75, + ops.GraphKeys.GLOBAL_STEP: global_step + }, dnn_classifier.evaluate(input_fn=_input_fn, steps=1)) + + def test_multi_dim(self): + """Asserts evaluation metrics for multi-dimensional input and logits.""" + global_step = 100 + _create_checkpoint(( + ([[.6, .5], [-.6, -.5]], [.1, -.1]), + ([[1., .8], [-.8, -1.]], [.2, -.2]), + ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]), + ), global_step, self._model_dir) + n_classes = 3 + + dnn_classifier = dnn.DNNClassifier( + hidden_units=(2, 2), + feature_columns=[feature_column.numeric_column('age', shape=[2])], + n_classes=n_classes, + model_dir=self._model_dir) + def _input_fn(): + # batch_size = 2, one false label, and one true. + return {'age': [[10., 8.], [10., 8.]]}, [[1], [0]] + # Uses identical numbers as + # DNNModelFnTest.test_multi_dim_input_multi_dim_logits. + # See that test for calculation of logits. + # logits = [[-0.48, 0.48, 0.39], [-0.48, 0.48, 0.39]] + # probabilities = exp(logits)/sum(exp(logits)) + # = [[0.16670536, 0.43538380, 0.39791084], + # [0.16670536, 0.43538380, 0.39791084]] + # loss = -log(0.43538380) - log(0.16670536) + expected_loss = 2.62305466 + self.assertAllClose({ + metric_keys.MetricKeys.LOSS: expected_loss, + metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2, + metric_keys.MetricKeys.ACCURACY: 0.5, + ops.GraphKeys.GLOBAL_STEP: global_step + }, dnn_classifier.evaluate(input_fn=_input_fn, steps=1)) + + class DNNRegressorPredictTest(test.TestCase): def setUp(self): @@ -524,6 +604,85 @@ class DNNRegressorPredictTest(test.TestCase): }, next(dnn_regressor.predict(input_fn=input_fn))) +class DNNClassifierPredictTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + shutil.rmtree(self._model_dir) + + def test_one_dim(self): + """Asserts predictions for one-dimensional input and logits.""" + _create_checkpoint(( + ([[.6, .5]], [.1, -.1]), + ([[1., .8], [-.8, -1.]], [.2, -.2]), + ([[-1.], [1.]], [.3]), + ), global_step=0, model_dir=self._model_dir) + + dnn_classifier = dnn.DNNClassifier( + hidden_units=(2, 2), + feature_columns=(feature_column.numeric_column('x'),), + model_dir=self._model_dir) + input_fn = numpy_io.numpy_input_fn( + x={'x': np.array([[10.]])}, batch_size=1, shuffle=False) + # Uses identical numbers as DNNModelTest.test_one_dim_logits. + # See that test for calculation of logits. + # logits = [-2.08] => + # logistic = exp(-2.08)/(1 + exp(-2.08)) = 0.11105597 + # probabilities = [1-logistic, logistic] = [0.88894403, 0.11105597] + # class_ids = argmax(probabilities) = [0] + self.assertAllClose({ + prediction_keys.PredictionKeys.LOGITS: [-2.08], + prediction_keys.PredictionKeys.LOGISTIC: [0.11105597], + prediction_keys.PredictionKeys.PROBABILITIES: [0.88894403, 0.11105597], + prediction_keys.PredictionKeys.CLASS_IDS: [0], + }, next(dnn_classifier.predict(input_fn=input_fn))) + + def test_multi_dim(self): + """Asserts predictions for multi-dimensional input and logits.""" + _create_checkpoint(( + ([[.6, .5], [-.6, -.5]], [.1, -.1]), + ([[1., .8], [-.8, -1.]], [.2, -.2]), + ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]), + ), global_step=0, model_dir=self._model_dir) + + dnn_classifier = dnn.DNNClassifier( + hidden_units=(2, 2), + feature_columns=(feature_column.numeric_column('x', shape=(2,)),), + n_classes=3, + model_dir=self._model_dir) + input_fn = numpy_io.numpy_input_fn( + # Inputs shape is (batch_size, num_inputs). + x={'x': np.array([[10., 8.]])}, + batch_size=1, + shuffle=False) + # Uses identical numbers as + # DNNModelFnTest.test_multi_dim_input_multi_dim_logits. + # See that test for calculation of logits. + # logits = [-0.48, 0.48, 0.39] => + # probabilities[i] = exp(logits[i]) / sum_j exp(logits[j]) => + # probabilities = [0.16670536, 0.43538380, 0.39791084] + # class_ids = argmax(probabilities) = [1] + predictions = next(dnn_classifier.predict(input_fn=input_fn)) + self.assertItemsEqual( + [prediction_keys.PredictionKeys.LOGITS, + prediction_keys.PredictionKeys.PROBABILITIES, + prediction_keys.PredictionKeys.CLASS_IDS, + prediction_keys.PredictionKeys.CLASSES], + six.iterkeys(predictions)) + self.assertAllClose( + [-0.48, 0.48, 0.39], predictions[prediction_keys.PredictionKeys.LOGITS]) + self.assertAllClose( + [0.16670536, 0.43538380, 0.39791084], + predictions[prediction_keys.PredictionKeys.PROBABILITIES]) + self.assertAllEqual( + [1], predictions[prediction_keys.PredictionKeys.CLASS_IDS]) + self.assertAllEqual( + [b'1'], predictions[prediction_keys.PredictionKeys.CLASSES]) + + def _queue_parsed_features(feature_map): tensors_to_enqueue = [] keys = [] diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py index b06940ae611..631ddfc5dfc 100644 --- a/tensorflow/python/estimator/canned/head.py +++ b/tensorflow/python/estimator/canned/head.py @@ -459,7 +459,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): def _binary_logistic_head_with_sigmoid_cross_entropy_loss( - weight_feature_key=None, thresholds=(0.5,)): + weight_feature_key=None, thresholds=None): """Creates a `Head` for single label binary classification. This head uses `sigmoid_cross_entropy_with_logits` loss. @@ -482,6 +482,7 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss( Raises: ValueError: if `thresholds` contains a value outside of `(0, 1)`. """ + thresholds = tuple(thresholds) if thresholds else tuple() for threshold in thresholds: if (threshold <= 0.0) or (threshold >= 1.0): raise ValueError('thresholds not in (0, 1): %s.' % (thresholds,)) @@ -494,7 +495,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): def __init__(self, weight_feature_key=None, thresholds=None): self._weight_feature_key = weight_feature_key - self._thresholds = tuple(thresholds) + self._thresholds = thresholds @property def logits_dimension(self): diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py index 34c1eb6c828..0efafac87ab 100644 --- a/tensorflow/python/estimator/canned/head_test.py +++ b/tensorflow/python/estimator/canned/head_test.py @@ -845,7 +845,6 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): prediction_keys.PredictionKeys.CLASS_IDS: np.array(((1,), (0,)), dtype=np.int64), } - default_threshold = .5 keys = metric_keys.MetricKeys expected_metrics = { # loss = sum(cross_entropy(labels, logits)) = sum(0, 41) = 41 @@ -857,9 +856,6 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): keys.ACCURACY_BASELINE: 2./2, keys.AUC: 0., keys.AUC_PR: 1., - keys.ACCURACY_AT_THRESHOLD % default_threshold: 1./2, - keys.PRECISION_AT_THRESHOLD % default_threshold: 2./2, - keys.RECALL_AT_THRESHOLD % default_threshold: 1./2, } # Assert spec contains expected tensors. @@ -888,6 +884,44 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): self.assertAllClose( expected_metrics, {k: value_ops[k].eval() for k in value_ops}) + def test_eval_with_thresholds(self): + thresholds = [0.25, 0.5, 0.75] + head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + thresholds=thresholds) + + # Create estimator spec. + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.float32)}, + mode=model_fn.ModeKeys.EVAL, + logits=np.array(((-1,), (1,),), dtype=np.float32), + labels=np.array(((1,), (1,),), dtype=np.int32)) + + # probabilities[i] = 1/(1 + exp(-logits[i])) => + # probabilities = [1/(1 + exp(1)), 1/(1 + exp(-1))] = [0.269, 0.731] + # loss = -sum(ln(probabilities[label[i]])) = -ln(0.269) -ln(0.731) + # = 1.62652338 + keys = metric_keys.MetricKeys + expected_metrics = { + keys.LOSS_MEAN: 1.62652338 / 2., + keys.ACCURACY: 1./2, + keys.PREDICTION_MEAN: 1./2, + keys.LABEL_MEAN: 2./2, + keys.ACCURACY_BASELINE: 2./2, + keys.AUC: 0., + keys.AUC_PR: 1., + keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 1., + keys.PRECISION_AT_THRESHOLD % thresholds[0]: 1., + keys.RECALL_AT_THRESHOLD % thresholds[0]: 1., + keys.ACCURACY_AT_THRESHOLD % thresholds[1]: .5, + keys.PRECISION_AT_THRESHOLD % thresholds[1]: 1., + keys.RECALL_AT_THRESHOLD % thresholds[1]: .5, + keys.ACCURACY_AT_THRESHOLD % thresholds[2]: 0., + keys.PRECISION_AT_THRESHOLD % thresholds[2]: 0., + keys.RECALL_AT_THRESHOLD % thresholds[2]: 0., + } + + self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys()) + def test_train(self): head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss() @@ -1000,7 +1034,6 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): logits=logits, labels=np.array(((1,), (1,), (0,)), dtype=np.int32)) - default_threshold = .5 # label_mean = (1*1 + .1*1 + 1.5*0)/(1 + .1 + 1.5) = 1.1/2.6 # = .42307692307 expected_label_mean = .42307692307 @@ -1021,11 +1054,6 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): keys.ACCURACY_BASELINE: 1 - expected_label_mean, keys.AUC: .45454565, keys.AUC_PR: .6737757325172424, - keys.ACCURACY_AT_THRESHOLD % default_threshold: .38461538461, - # precision = (1*1 + 1.5*0)/(1 + 1.5) = 1/2.5 = .4 - keys.PRECISION_AT_THRESHOLD % default_threshold: .4, - # recall = (1*1 + .1*0)/(1 + .1) = 1/1.1 = .90909090909 - keys.RECALL_AT_THRESHOLD % default_threshold: .90909090909, } # Assert spec contains expected tensors. diff --git a/tensorflow/python/estimator/canned/linear_test.py b/tensorflow/python/estimator/canned/linear_test.py index ad6cdbf5e80..7fa6ef91323 100644 --- a/tensorflow/python/estimator/canned/linear_test.py +++ b/tensorflow/python/estimator/canned/linear_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import math import os import shutil import tempfile @@ -629,7 +630,7 @@ def _assert_close(expected, actual, rtol=1e-04, name='assert_close'): with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope: expected = ops.convert_to_tensor(expected, name='expected') actual = ops.convert_to_tensor(actual, name='actual') - rdiff = math_ops.abs(expected - actual, 'diff') / expected + rdiff = math_ops.abs(expected - actual, 'diff') / math_ops.abs(expected) rtol = ops.convert_to_tensor(rtol, name='rtol') return check_ops.assert_less( rdiff, @@ -654,7 +655,7 @@ class LinearRegressorTrainingTest(test.TestCase): writer_cache.FileWriterCache.clear() shutil.rmtree(self._model_dir) - def _mockOptimizer(self, expected_loss=None): + def _mock_optimizer(self, expected_loss=None): expected_var_names = [ '%s/part_0:0' % _AGE_WEIGHT_NAME, '%s/part_0:0' % _BIAS_NAME @@ -686,7 +687,7 @@ class LinearRegressorTrainingTest(test.TestCase): mock_optimizer.__deepcopy__ = lambda _: mock_optimizer return mock_optimizer - def _assertCheckpoint( + def _assert_checkpoint( self, expected_global_step, expected_age_weight=None, expected_bias=None): shapes = { name: shape for (name, shape) in @@ -723,7 +724,7 @@ class LinearRegressorTrainingTest(test.TestCase): num_steps = 10 linear_regressor.train( input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) - self._assertCheckpoint(num_steps) + self._assert_checkpoint(num_steps) def testTrainWithOneDimLabel(self): label_dimension = 1 @@ -742,7 +743,7 @@ class LinearRegressorTrainingTest(test.TestCase): batch_size=batch_size, num_epochs=None, shuffle=True) est.train(train_input_fn, steps=200) - self._assertCheckpoint(200) + self._assert_checkpoint(200) def testTrainWithOneDimWeight(self): label_dimension = 1 @@ -763,14 +764,14 @@ class LinearRegressorTrainingTest(test.TestCase): batch_size=batch_size, num_epochs=None, shuffle=True) est.train(train_input_fn, steps=200) - self._assertCheckpoint(200) + self._assert_checkpoint(200) def testFromScratch(self): # Create LinearRegressor. label = 5. age = 17 # loss = (logits - label)^2 = (0 - 5.)^2 = 25. - mock_optimizer = self._mockOptimizer(expected_loss=25.) + mock_optimizer = self._mock_optimizer(expected_loss=25.) linear_regressor = linear.LinearRegressor( feature_columns=(feature_column_lib.numeric_column('age'),), model_dir=self._model_dir, optimizer=mock_optimizer) @@ -781,7 +782,7 @@ class LinearRegressorTrainingTest(test.TestCase): linear_regressor.train( input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) self.assertEqual(1, mock_optimizer.minimize.call_count) - self._assertCheckpoint( + self._assert_checkpoint( expected_global_step=num_steps, expected_age_weight=0., expected_bias=0.) @@ -801,7 +802,7 @@ class LinearRegressorTrainingTest(test.TestCase): # logits = age * age_weight + bias = 17 * 10. + 5. = 175 # loss = (logits - label)^2 = (175 - 5)^2 = 28900 - mock_optimizer = self._mockOptimizer(expected_loss=28900.) + mock_optimizer = self._mock_optimizer(expected_loss=28900.) linear_regressor = linear.LinearRegressor( feature_columns=(feature_column_lib.numeric_column('age'),), model_dir=self._model_dir, optimizer=mock_optimizer) @@ -812,7 +813,7 @@ class LinearRegressorTrainingTest(test.TestCase): linear_regressor.train( input_fn=lambda: ({'age': ((17,),)}, ((5.,),)), steps=num_steps) self.assertEqual(1, mock_optimizer.minimize.call_count) - self._assertCheckpoint( + self._assert_checkpoint( expected_global_step=initial_global_step + num_steps, expected_age_weight=age_weight, expected_bias=bias) @@ -834,7 +835,7 @@ class LinearRegressorTrainingTest(test.TestCase): # logits[0] = 17 * 10. + 5. = 175 # logits[1] = 15 * 10. + 5. = 155 # loss = sum(logits - label)^2 = (175 - 5)^2 + (155 - 3)^2 = 52004 - mock_optimizer = self._mockOptimizer(expected_loss=52004.) + mock_optimizer = self._mock_optimizer(expected_loss=52004.) linear_regressor = linear.LinearRegressor( feature_columns=(feature_column_lib.numeric_column('age'),), model_dir=self._model_dir, optimizer=mock_optimizer) @@ -846,10 +847,351 @@ class LinearRegressorTrainingTest(test.TestCase): input_fn=lambda: ({'age': ((17,), (15,))}, ((5.,), (3.,))), steps=num_steps) self.assertEqual(1, mock_optimizer.minimize.call_count) - self._assertCheckpoint( + self._assert_checkpoint( expected_global_step=initial_global_step + num_steps, expected_age_weight=age_weight, expected_bias=bias) + +class _BaseLinearClassiferTrainingTest(object): + + def __init__(self, n_classes): + self._n_classes = n_classes + self._logits_dimensions = ( + self._n_classes if self._n_classes > 2 else 1) + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + shutil.rmtree(self._model_dir) + + def _mock_optimizer(self, expected_loss=None): + expected_var_names = [ + '%s/part_0:0' % _AGE_WEIGHT_NAME, + '%s/part_0:0' % _BIAS_NAME + ] + + def _minimize(loss, global_step): + trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertItemsEqual( + expected_var_names, + [var.name for var in trainable_vars]) + + # Verify loss. We can't check the value directly, so we add an assert op. + self.assertEquals(0, loss.shape.ndims) + if expected_loss is None: + return state_ops.assign_add(global_step, 1).op + assert_loss = _assert_close( + math_ops.to_float(expected_loss, name='expected'), loss, + name='assert_loss') + with ops.control_dependencies((assert_loss,)): + return state_ops.assign_add(global_step, 1).op + + mock_optimizer = test.mock.NonCallableMock( + spec=optimizer.Optimizer, + wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer')) + mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize) + + # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks. + # So, return mock_optimizer itself for deepcopy. + mock_optimizer.__deepcopy__ = lambda _: mock_optimizer + return mock_optimizer + + def _assert_checkpoint( + self, expected_global_step, expected_age_weight=None, expected_bias=None): + logits_dimension = self._logits_dimensions + + shapes = { + name: shape for (name, shape) in + checkpoint_utils.list_variables(self._model_dir) + } + + self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP]) + self.assertEqual( + expected_global_step, + checkpoint_utils.load_variable( + self._model_dir, ops.GraphKeys.GLOBAL_STEP)) + + self.assertEqual([1, logits_dimension], shapes[_AGE_WEIGHT_NAME]) + if expected_age_weight is not None: + self.assertAllEqual( + expected_age_weight, + checkpoint_utils.load_variable(self._model_dir, _AGE_WEIGHT_NAME)) + + self.assertEqual([logits_dimension], shapes[_BIAS_NAME]) + if expected_bias is not None: + self.assertAllEqual( + expected_bias, + checkpoint_utils.load_variable(self._model_dir, _BIAS_NAME)) + + def testFromScratchWithDefaultOptimizer(self): + n_classes = self._n_classes + label = 0 + age = 17 + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + model_dir=self._model_dir) + + # Train for a few steps, and validate final checkpoint. + num_steps = 10 + est.train( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) + self._assert_checkpoint(num_steps) + + def testTrainWithTwoDimsLabel(self): + n_classes = self._n_classes + batch_size = 20 + + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + model_dir=self._model_dir) + data_rank_1 = np.array([0, 1]) + data_rank_2 = np.array([[0], [1]]) + self.assertEqual((2,), data_rank_1.shape) + self.assertEqual((2, 1), data_rank_2.shape) + + train_input_fn = numpy_io.numpy_input_fn( + x={'age': data_rank_1}, + y=data_rank_2, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + est.train(train_input_fn, steps=200) + self._assert_checkpoint(200) + + def testTrainWithOneDimLabel(self): + n_classes = self._n_classes + batch_size = 20 + + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + model_dir=self._model_dir) + data_rank_1 = np.array([0, 1]) + self.assertEqual((2,), data_rank_1.shape) + + train_input_fn = numpy_io.numpy_input_fn( + x={'age': data_rank_1}, + y=data_rank_1, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + est.train(train_input_fn, steps=200) + self._assert_checkpoint(200) + + def testTrainWithTwoDimsWeight(self): + n_classes = self._n_classes + batch_size = 20 + + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + weight_feature_key='w', + n_classes=n_classes, + model_dir=self._model_dir) + data_rank_1 = np.array([0, 1]) + data_rank_2 = np.array([[0], [1]]) + self.assertEqual((2,), data_rank_1.shape) + self.assertEqual((2, 1), data_rank_2.shape) + + train_input_fn = numpy_io.numpy_input_fn( + x={'age': data_rank_1, 'w': data_rank_2}, y=data_rank_1, + batch_size=batch_size, num_epochs=None, + shuffle=True) + est.train(train_input_fn, steps=200) + self._assert_checkpoint(200) + + def testTrainWithOneDimWeight(self): + n_classes = self._n_classes + batch_size = 20 + + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + weight_feature_key='w', + n_classes=n_classes, + model_dir=self._model_dir) + data_rank_1 = np.array([0, 1]) + self.assertEqual((2,), data_rank_1.shape) + + train_input_fn = numpy_io.numpy_input_fn( + x={'age': data_rank_1, 'w': data_rank_1}, y=data_rank_1, + batch_size=batch_size, num_epochs=None, + shuffle=True) + est.train(train_input_fn, steps=200) + self._assert_checkpoint(200) + + def testFromScratch(self): + n_classes = self._n_classes + label = 1 + age = 17 + # For binary classifer: + # loss = sigmoid_cross_entropy(logits, label) where logits=0 (weights are + # all zero initially) and label = 1 so, + # loss = 1 * -log ( sigmoid(logits) ) = 0.69315 + # For multi class classifer: + # loss = cross_entropy(logits, label) where logits are all 0s (weights are + # all zero initially) and label = 1 so, + # loss = 1 * -log ( 1.0 / n_classes ) + # For this particular test case, as logits are same, the formular + # 1 * -log ( 1.0 / n_classes ) covers both binary and multi class cases. + mock_optimizer = self._mock_optimizer( + expected_loss=-1 * math.log(1.0/n_classes)) + + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + optimizer=mock_optimizer, + model_dir=self._model_dir) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + est.train( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + expected_global_step=num_steps, + expected_age_weight=[[0.]] if n_classes == 2 else [[0.] * n_classes], + expected_bias=[0.] if n_classes == 2 else [.0] * n_classes) + + def testFromCheckpoint(self): + # Create initial checkpoint. + n_classes = self._n_classes + label = 1 + age = 17 + # For binary case, the expected weight has shape (1,1). For multi class + # case, the shape is (1, n_classes). In order to test the weights, set + # weights as 2.0 * range(n_classes). + age_weight = [[2.0]] if n_classes == 2 else ( + np.reshape(2.0 * np.array(list(range(n_classes)), dtype=np.float32), + (1, n_classes))) + bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes + initial_global_step = 100 + with ops.Graph().as_default(): + variables.Variable(age_weight, name=_AGE_WEIGHT_NAME) + variables.Variable(bias, name=_BIAS_NAME) + variables.Variable( + initial_global_step, name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + _save_variables_to_ckpt(self._model_dir) + + # For binary classifer: + # logits = age * age_weight + bias = 17 * 2. - 35. = -1. + # loss = sigmoid_cross_entropy(logits, label) + # so, loss = 1 * -log ( sigmoid(-1) ) = 1.3133 + # For multi class classifer: + # loss = cross_entropy(logits, label) + # where logits = 17 * age_weight + bias and label = 1 + # so, loss = 1 * -log ( soft_max(logits)[1] ) + if n_classes == 2: + expected_loss = 1.3133 + else: + logits = age_weight * age + bias + logits_exp = np.exp(logits) + softmax = logits_exp / logits_exp.sum() + expected_loss = -1 * math.log(softmax[0, label]) + + mock_optimizer = self._mock_optimizer(expected_loss=expected_loss) + + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + optimizer=mock_optimizer, + model_dir=self._model_dir) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + est.train( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + expected_global_step=initial_global_step + num_steps, + expected_age_weight=age_weight, + expected_bias=bias) + + def testFromCheckpointMultiBatch(self): + # Create initial checkpoint. + n_classes = self._n_classes + label = [1, 0] + age = [17, 18.5] + # For binary case, the expected weight has shape (1,1). For multi class + # case, the shape is (1, n_classes). In order to test the weights, set + # weights as 2.0 * range(n_classes). + age_weight = [[2.0]] if n_classes == 2 else ( + np.reshape(2.0 * np.array(list(range(n_classes)), dtype=np.float32), + (1, n_classes))) + bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes + initial_global_step = 100 + with ops.Graph().as_default(): + variables.Variable(age_weight, name=_AGE_WEIGHT_NAME) + variables.Variable(bias, name=_BIAS_NAME) + variables.Variable( + initial_global_step, name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + _save_variables_to_ckpt(self._model_dir) + + # For binary classifer: + # logits = age * age_weight + bias + # logits[0] = 17 * 2. - 35. = -1. + # logits[1] = 18.5 * 2. - 35. = 2. + # loss = sigmoid_cross_entropy(logits, label) + # so, loss[0] = 1 * -log ( sigmoid(-1) ) = 1.3133 + # loss[1] = (1 - 0) * -log ( 1- sigmoid(2) ) = 2.1269 + # For multi class classifer: + # loss = cross_entropy(logits, label) + # where logits = [17, 18.5] * age_weight + bias and label = [1, 0] + # so, loss = 1 * -log ( soft_max(logits)[label] ) + if n_classes == 2: + expected_loss = (1.3133 + 2.1269) + else: + logits = age_weight * np.reshape(age, (2, 1)) + bias + logits_exp = np.exp(logits) + softmax_row_0 = logits_exp[0] / logits_exp[0].sum() + softmax_row_1 = logits_exp[1] / logits_exp[1].sum() + expected_loss_0 = -1 * math.log(softmax_row_0[label[0]]) + expected_loss_1 = -1 * math.log(softmax_row_1[label[1]]) + expected_loss = expected_loss_0 + expected_loss_1 + + mock_optimizer = self._mock_optimizer(expected_loss=expected_loss) + + est = linear.LinearClassifier( + feature_columns=(feature_column_lib.numeric_column('age'),), + n_classes=n_classes, + optimizer=mock_optimizer, + model_dir=self._model_dir) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + est.train( + input_fn=lambda: ({'age': (age)}, (label)), + steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + expected_global_step=initial_global_step + num_steps, + expected_age_weight=age_weight, + expected_bias=bias) + + +class LinearClassiferWithBinaryClassesTrainingTest( + _BaseLinearClassiferTrainingTest, test.TestCase): + + def __init__(self, methodName='runTest'): + test.TestCase.__init__(self, methodName) + _BaseLinearClassiferTrainingTest.__init__(self, n_classes=2) + + +class LinearClassiferWithMultiClassesTrainingTest( + _BaseLinearClassiferTrainingTest, test.TestCase): + + def __init__(self, methodName='runTest'): + test.TestCase.__init__(self, methodName) + _BaseLinearClassiferTrainingTest.__init__(self, n_classes=4) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/estimator/canned/metric_keys.py b/tensorflow/python/estimator/canned/metric_keys.py index 1261d1dcfb1..91e3bf1d83a 100644 --- a/tensorflow/python/estimator/canned/metric_keys.py +++ b/tensorflow/python/estimator/canned/metric_keys.py @@ -29,6 +29,8 @@ class MetricKeys(object): LOSS_MEAN = model_fn.MetricKeys.AVERAGE_LOSS ACCURACY = 'accuracy' + # This is the best the model could do by always predicting one class. + # Should be < ACCURACY in a trained model. ACCURACY_BASELINE = 'accuracy_baseline' AUC = 'auc' AUC_PR = 'auc_precision_recall' diff --git a/tensorflow/python/framework/cpp_shape_inference.cc b/tensorflow/python/framework/cpp_shape_inference.cc index 04bcbddde46..2931b8c378c 100644 --- a/tensorflow/python/framework/cpp_shape_inference.cc +++ b/tensorflow/python/framework/cpp_shape_inference.cc @@ -182,6 +182,7 @@ std::vector RunCppShapeInference( std::vector input_constant_tensor_values_v; int cnt = PyList_Size(input_constant_tensor_values); + input_constant_tensor_values_v.reserve(cnt); for (int i = 0; i < cnt; ++i) { input_constant_tensor_values_v.push_back( PyList_GetItem(input_constant_tensor_values, i)); diff --git a/tensorflow/python/framework/graph_io.py b/tensorflow/python/framework/graph_io.py index 0033a370883..f909bcd62d2 100644 --- a/tensorflow/python/framework/graph_io.py +++ b/tensorflow/python/framework/graph_io.py @@ -21,6 +21,7 @@ from __future__ import print_function import os import os.path +from google.protobuf import text_format from tensorflow.python.framework import ops from tensorflow.python.lib.io import file_io @@ -64,7 +65,8 @@ def write_graph(graph_or_graph_def, logdir, name, as_text=True): file_io.recursive_create_dir(logdir) path = os.path.join(logdir, name) if as_text: - file_io.atomic_write_string_to_file(path, str(graph_def)) + file_io.atomic_write_string_to_file(path, + text_format.MessageToString(graph_def)) else: file_io.atomic_write_string_to_file(path, graph_def.SerializeToString()) return path diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index a3168a00883..00260fe0bf7 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -36,9 +36,10 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/python/framework/python_op_gen_internal.h" namespace tensorflow { -namespace { +namespace python_op_gen_internal { const int kRightMargin = 78; @@ -67,15 +68,11 @@ bool IsPythonReserved(const string& s) { "UnicodeEncodeError", "UnicodeError", "UnicodeTranslateError", "UnicodeWarning", "UserWarning", "ValueError", "Warning", "ZeroDivisionError", "__debug__", "__doc__", "__import__", "__name__", - "__package__", - // Imports and symbols used in the generated code: - "_text_format", "_op_def_pb2", "_common_shapes", "_op_def_registry", - "_ops", "_op_def_library"}); + "__package__"}); return kPythonReserved->count(s) > 0; } -// Add a _ to the end of s if necessary to avoid a Python keyword or built-in. string AvoidPythonReserved(const string& s) { if (IsPythonReserved(s)) return strings::StrCat(s, "_"); return s; @@ -323,8 +320,8 @@ string StringToPython(const string& str) { return strings::StrCat("\"", str_util::CEscape(str), "\""); } -string DataTypeToPython(DataType dtype) { - return strings::StrCat("tf.", PythonDataTypeString(dtype)); +string DataTypeToPython(DataType dtype, const string& dtype_module) { + return strings::StrCat(dtype_module, PythonDataTypeString(dtype)); } string ShapeToPython(const TensorShapeProto& shape) { @@ -346,7 +343,8 @@ string TensorToPython(const TensorProto& proto) { return ProtoShortDebugString(proto); } -string AttrListToPython(const AttrValue& value) { +string AttrListToPython(const AttrValue& value, + const string& dtype_module = "tf.") { string ret; if (value.list().s_size() > 0) { for (int i = 0; i < value.list().s_size(); ++i) { @@ -371,7 +369,8 @@ string AttrListToPython(const AttrValue& value) { } else if (value.list().type_size() > 0) { for (int i = 0; i < value.list().type_size(); ++i) { if (i > 0) strings::StrAppend(&ret, ", "); - strings::StrAppend(&ret, DataTypeToPython(value.list().type(i))); + strings::StrAppend(&ret, + DataTypeToPython(value.list().type(i), dtype_module)); } } else if (value.list().shape_size() > 0) { for (int i = 0; i < value.list().shape_size(); ++i) { @@ -392,7 +391,8 @@ string AttrListToPython(const AttrValue& value) { return ret; } -string AttrValueToPython(const string& type, const AttrValue& value) { +string AttrValueToPython(const string& type, const AttrValue& value, + const string& dtype_module) { if (type == "string") { return StringToPython(value.s()); } else if (type == "int") { @@ -402,7 +402,7 @@ string AttrValueToPython(const string& type, const AttrValue& value) { } else if (type == "bool") { return value.b() ? "True" : "False"; } else if (type == "type") { - return DataTypeToPython(value.type()); + return DataTypeToPython(value.type(), dtype_module); } else if (type == "shape") { return ShapeToPython(value.shape()); } else if (type == "tensor") { @@ -410,7 +410,7 @@ string AttrValueToPython(const string& type, const AttrValue& value) { } else if (type == "func") { return StringToPython(value.func().name()); } else if (StringPiece(type).starts_with("list(")) { - return strings::StrCat("[", AttrListToPython(value), "]"); + return strings::StrCat("[", AttrListToPython(value, dtype_module), "]"); } else { return "?"; } @@ -432,35 +432,41 @@ void GenerateLowerCaseOpName(const string& str, string* result) { } } -} // namespace +static void AddDelimiter(string* append_to, const string& delim) { + if (!append_to->empty()) strings::StrAppend(append_to, delim); +} -string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name) { - string result; - // Map from attr name to the first input arg it is inferred from. - std::unordered_map inferred_attrs; +GenPythonOp::GenPythonOp(const OpDef& op_def, const string& function_name) + : op_def_(op_def), + function_name_(function_name), + num_outs_(op_def.output_arg_size()) {} + +GenPythonOp::~GenPythonOp() {} + +string GenPythonOp::Code() { // This has all the input args followed by those attrs that don't have // defaults. std::vector args_no_default; // The parameters with defaults (these have to be listed after those without). // No input args are included, just attrs. std::vector args_with_defaults; - for (int i = 0; i < op_def.input_arg_size(); ++i) { - const auto& arg(op_def.input_arg(i)); + for (int i = 0; i < op_def_.input_arg_size(); ++i) { + const auto& arg(op_def_.input_arg(i)); args_no_default.push_back(arg.name()); if (!arg.type_attr().empty()) { - gtl::InsertIfNotPresent(&inferred_attrs, arg.type_attr(), arg.name()); + gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_attr(), arg.name()); } else if (!arg.type_list_attr().empty()) { - gtl::InsertIfNotPresent(&inferred_attrs, arg.type_list_attr(), + gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_list_attr(), arg.name()); } if (!arg.number_attr().empty()) { - gtl::InsertIfNotPresent(&inferred_attrs, arg.number_attr(), arg.name()); + gtl::InsertIfNotPresent(&inferred_attrs_, arg.number_attr(), arg.name()); } } - for (int i = 0; i < op_def.attr_size(); ++i) { - const auto& attr(op_def.attr(i)); + for (int i = 0; i < op_def_.attr_size(); ++i) { + const auto& attr(op_def_.attr(i)); // Do not add inferred attrs to the Python function signature. - if (inferred_attrs.find(attr.name()) == inferred_attrs.end()) { + if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) { if (attr.has_default_value()) { args_with_defaults.push_back(attr.name()); } else { @@ -471,110 +477,92 @@ string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name) { // Save the list of attr parameters (attrs that won't be inferred), // those with defaults go at the end. - std::vector attrs; // Get the attrs in the order we want by taking the attrs without defaults // from the end of args_no_default, and adding args_no_default. - attrs.reserve(args_no_default.size() - op_def.input_arg_size() + - args_with_defaults.size()); - attrs.insert(attrs.end(), args_no_default.begin() + op_def.input_arg_size(), - args_no_default.end()); - attrs.insert(attrs.end(), args_with_defaults.begin(), - args_with_defaults.end()); + attrs_.reserve(args_no_default.size() - op_def_.input_arg_size() + + args_with_defaults.size()); + attrs_.insert(attrs_.end(), + args_no_default.begin() + op_def_.input_arg_size(), + args_no_default.end()); + attrs_.insert(attrs_.end(), args_with_defaults.begin(), + args_with_defaults.end()); - std::vector param_names; - param_names.reserve(args_no_default.size() + args_with_defaults.size()); + param_names_.reserve(args_no_default.size() + args_with_defaults.size()); string parameters; for (const string& name : args_no_default) { - if (!parameters.empty()) strings::StrAppend(¶meters, ", "); + AddDelimiter(¶meters, ", "); const string param = AvoidPythonReserved(name); strings::StrAppend(¶meters, param); - param_names.push_back(param); + param_names_.push_back(param); } for (const string& name : args_with_defaults) { - if (!parameters.empty()) strings::StrAppend(¶meters, ", "); + AddDelimiter(¶meters, ", "); const string param = AvoidPythonReserved(name); strings::StrAppend(¶meters, param, "=None"); - param_names.push_back(param); + param_names_.push_back(param); } + AddDelimiter(¶meters, ", "); + strings::StrAppend(¶meters, "name=None"); - const string lower_op_name = strings::StrCat(is_hidden ? "_" : "", op_name); + AddDefLine(parameters); + AddDocStringDescription(); + AddDocStringArgs(); + AddDocStringInputs(); + AddDocStringAttrs(); + AddDocStringNameArg(); + AddOutputGlobals(); + AddDocStringOutputs(); + strings::StrAppend(&result_, " \"\"\"\n"); + AddBody(" "); + strings::StrAppend(&result_, "\n\n"); - const int num_outs = op_def.output_arg_size(); - // Prepare a NamedTuple type to hold the outputs, if there are multiple - if (num_outs > 1) { - // Prepare the list of output names - std::vector out_names(num_outs); - for (int i = 0; i < num_outs; ++i) { - if (!op_def.output_arg(i).name().empty()) { - out_names[i] = op_def.output_arg(i).name(); - } else { - out_names[i] = strings::StrCat("output", i); - } - } - string out_names_list = - strings::StrCat("[\"", str_util::Join(out_names, "\", \""), "\"]"); + return prelude_ + result_; +} - // Provide the output names as a Python list - string lower_op_name_outputs = - strings::StrCat("_", lower_op_name, "_outputs"); - const string outputs_prefix = strings::StrCat(lower_op_name_outputs, " = "); - strings::StrAppend(&result, "\n", - WordWrap(outputs_prefix, out_names_list, kRightMargin), - "\n"); +void GenPythonOp::AddDefLine(const string& parameters) { + const string def_prefix = strings::StrCat("def ", function_name_, "("); + strings::StrAppend( + &result_, WordWrap(def_prefix, parameters + "):", kRightMargin), "\n"); +} - strings::StrAppend(&result, "_", op_def.name(), - "Output = _collections.namedtuple(\n"); - const string tuple_type_prefix = " "; - const string tuple_type_suffix = strings::StrCat( - "\"", op_def.name(), "\", ", lower_op_name_outputs, ")"); - strings::StrAppend( - &result, WordWrap(tuple_type_prefix, tuple_type_suffix, kRightMargin), - "\n\n"); - } - strings::StrAppend(&result, "\n"); - - // Print: def Function(parameters): - const string def_prefix = strings::StrCat("def ", lower_op_name, "("); - const bool has_args = args_no_default.size() + args_with_defaults.size() > 0; - const string def_suffix = - strings::StrCat(parameters, has_args ? ", " : "", "name=None):"); - - strings::StrAppend(&result, WordWrap(def_prefix, def_suffix, kRightMargin), - "\n"); - - // Format the Op's descriptions so that it can be a Python docstring. +void GenPythonOp::AddDocStringDescription() { string comment; - if (op_def.summary().empty()) { + if (op_def_.summary().empty()) { comment = "TODO: add doc.\n"; } else { - comment = strings::StrCat(op_def.summary(), "\n"); - if (!op_def.description().empty()) { - strings::StrAppend(&comment, "\n", Indent(2, 2, op_def.description())); + comment = strings::StrCat(op_def_.summary(), "\n"); + if (!op_def_.description().empty()) { + strings::StrAppend(&comment, "\n", Indent(2, 2, op_def_.description())); } } + strings::StrAppend(&result_, " r\"\"\"", comment, "\n"); +} - strings::StrAppend(&result, " r\"\"\"", comment, "\n Args:\n"); +void GenPythonOp::AddDocStringArgs() { + strings::StrAppend(&result_, " Args:\n"); +} - // Inputs - for (int i = 0; i < op_def.input_arg_size(); ++i) { - const auto& arg(op_def.input_arg(i)); - StringPiece description = op_def.input_arg(i).description(); +void GenPythonOp::AddDocStringInputs() { + for (int i = 0; i < op_def_.input_arg_size(); ++i) { + const auto& arg(op_def_.input_arg(i)); + StringPiece description = op_def_.input_arg(i).description(); string desc; if (ConsumeEquals(&description)) { // Skip the generated type info. - desc = strings::StrCat(param_names[i], ": "); + desc = strings::StrCat(param_names_[i], ": "); } else { - desc = strings::StrCat(param_names[i], ": ", - ArgTypeName(op_def, arg, inferred_attrs, false)); + desc = strings::StrCat(param_names_[i], ": ", + ArgTypeName(op_def_, arg, inferred_attrs_, false)); } if (!description.empty()) { AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */); } - strings::StrAppend(&result, Indent(4, 6, desc)); + strings::StrAppend(&result_, Indent(4, 6, desc)); } +} - // Attrs - for (const string& name : attrs) { - const auto& attr = *FindAttr(name, op_def); +void GenPythonOp::AddDocStringAttrs() { + for (const string& name : attrs_) { + const auto& attr = *FindAttr(name, op_def_); string desc = strings::StrCat(AvoidPythonReserved(name), ": "); static const char* const kAttrTypeName[][2] = { @@ -638,40 +626,86 @@ string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name) { AppendWithinWidth(&desc, attr.description(), kRightMargin - 4 /* indent */); } - strings::StrAppend(&result, Indent(4, 6, desc)); + strings::StrAppend(&result_, Indent(4, 6, desc)); } +} - strings::StrAppend(&result, +void GenPythonOp::AddDocStringNameArg() { + strings::StrAppend(&result_, " name: A name for the operation (optional).\n"); +} - std::vector output_type_string; - output_type_string.reserve(num_outs); - for (int i = 0; i < num_outs; ++i) { - output_type_string.push_back( - ArgTypeName(op_def, op_def.output_arg(i), inferred_attrs, true)); +void GenPythonOp::AddOutputGlobals() { + // Prepare a NamedTuple type to hold the outputs, if there are multiple + if (num_outs_ > 1) { + // Prepare the list of output names + std::vector out_names(num_outs_); + for (int i = 0; i < num_outs_; ++i) { + if (!op_def_.output_arg(i).name().empty()) { + out_names[i] = op_def_.output_arg(i).name(); + } else { + out_names[i] = strings::StrCat("output", i); + } + } + string out_names_list = + strings::StrCat("[\"", str_util::Join(out_names, "\", \""), "\"]"); + + // Provide the output names as a Python list + string lower_op_name_outputs = + strings::StrCat("_", function_name_, "_outputs"); + const string outputs_prefix = strings::StrCat(lower_op_name_outputs, " = "); + strings::StrAppend(&prelude_, "\n", + WordWrap(outputs_prefix, out_names_list, kRightMargin), + "\n"); + + strings::StrAppend(&prelude_, "_", op_def_.name(), + "Output = _collections.namedtuple(\n"); + const string tuple_type_prefix = " "; + const string tuple_type_suffix = strings::StrCat( + "\"", op_def_.name(), "\", ", lower_op_name_outputs, ")"); + strings::StrAppend( + &prelude_, WordWrap(tuple_type_prefix, tuple_type_suffix, kRightMargin), + "\n\n"); } - strings::StrAppend(&result, GetReturns(op_def, output_type_string)); + strings::StrAppend(&prelude_, "\n"); +} - string return_prefix = strings::StrCat(" result = _op_def_lib.apply_op("); - string return_args = strings::StrCat("\"", op_def.name(), "\", "); - for (size_t i = 0; i < param_names.size(); ++i) { - strings::StrAppend(&return_args, param_names[i], "=", param_names[i], ", "); +void GenPythonOp::AddDocStringOutputs() { + std::vector output_type_string; + output_type_string.reserve(num_outs_); + for (int i = 0; i < num_outs_; ++i) { + output_type_string.push_back( + ArgTypeName(op_def_, op_def_.output_arg(i), inferred_attrs_, true)); + } + strings::StrAppend(&result_, GetReturns(op_def_, output_type_string)); +} + +void GenPythonOp::AddBody(const string& prefix) { + string return_prefix = + strings::StrCat(prefix, "result = _op_def_lib.apply_op("); + string return_args = strings::StrCat("\"", op_def_.name(), "\", "); + for (size_t i = 0; i < param_names_.size(); ++i) { + strings::StrAppend(&return_args, param_names_[i], "=", param_names_[i], + ", "); } strings::StrAppend(&return_args, "name=name)"); - strings::StrAppend(&result, " \"\"\"\n", + strings::StrAppend(&result_, // Wrap the arguments, and indent to the (. WordWrap(return_prefix, return_args, kRightMargin), "\n"); - if (num_outs <= 1) { - strings::StrAppend(&result, " return result\n"); + if (num_outs_ <= 1) { + strings::StrAppend(&result_, prefix, "return result\n"); } else { - strings::StrAppend(&result, " return _", op_def.name(), + strings::StrAppend(&result_, prefix, "return _", op_def_.name(), "Output._make(result)\n"); } - strings::StrAppend(&result, "\n\n"); +} - return result; +} // namespace python_op_gen_internal + +string GetPythonOp(const OpDef& op_def, const string& function_name) { + return python_op_gen_internal::GenPythonOp(op_def, function_name).Code(); } string GetPythonOps(const OpList& ops, const std::vector& hidden_ops, @@ -711,20 +745,20 @@ from tensorflow.python.framework import op_def_library as _op_def_library } } - // PrintPythonOp(op_def, is_hidden, op_def.name()); - string lower_case_name; - GenerateLowerCaseOpName(op_def.name(), &lower_case_name); + string function_name; + python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(), + &function_name); + if (is_hidden) function_name = strings::StrCat("_", function_name); // When users create custom python wrappers, they may link in the // default op registry by accident, and because they can't // enumerate all 'hidden' symbols, this guard is to prevent // instantiating a python reserved word in their wrapper. - if (!is_hidden && IsPythonReserved(lower_case_name)) { + if (python_op_gen_internal::IsPythonReserved(function_name)) { continue; } - strings::StrAppend(&result, - GetPythonOp(op_def, is_hidden, lower_case_name)); + strings::StrAppend(&result, GetPythonOp(op_def, function_name)); if (!require_shapes) { strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(), diff --git a/tensorflow/python/framework/python_op_gen.h b/tensorflow/python/framework/python_op_gen.h index d865c238743..f485044c5af 100644 --- a/tensorflow/python/framework/python_op_gen.h +++ b/tensorflow/python/framework/python_op_gen.h @@ -31,7 +31,7 @@ void PrintPythonOps(const OpList& ops, const std::vector& hidden_ops, bool require_shapes); string GetPythonOps(const OpList& ops, const std::vector& hidden_ops, bool require_shapes); -string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name); +string GetPythonOp(const OpDef& op_def, const string& function_name); // Get the python wrappers for a list of ops in a OpList. // `op_list_buf` should be a pointer to a buffer containing diff --git a/tensorflow/python/framework/python_op_gen_internal.h b/tensorflow/python/framework/python_op_gen_internal.h new file mode 100644 index 00000000000..44b1aed71f1 --- /dev/null +++ b/tensorflow/python/framework/python_op_gen_internal.h @@ -0,0 +1,86 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_ +#define THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_ + +#include + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace python_op_gen_internal { + +// Returns true if s is a Python keyword or built-in. +bool IsPythonReserved(const string& s); + +// Add a _ to the end of s if necessary to avoid a Python keyword or built-in. +string AvoidPythonReserved(const string& s); + +// Convert an AttrValue with type `type` to the Python representation for +// that value. +string AttrValueToPython(const string& type, const AttrValue& value, + const string& dtype_module = "tf."); + +void GenerateLowerCaseOpName(const string& str, string* result); + +class GenPythonOp { + public: + GenPythonOp(const OpDef& op_def, const string& function_name); + virtual ~GenPythonOp(); + + virtual string Code(); + + protected: + // Print: def Function(parameters): + void AddDefLine(const string& parameters); + + // Format the Op's descriptions so that it can be a Python docstring. + void AddDocStringDescription(); + + void AddDocStringArgs(); + void AddDocStringInputs(); + void AddDocStringAttrs(); + void AddDocStringNameArg(); + void AddOutputGlobals(); + void AddDocStringOutputs(); + void AddBody(const string& prefix); + + // From constructor arguments + const OpDef& op_def_; + const string& function_name_; + const int num_outs_; + + // Return value from Code() is prelude_ + result_. + string prelude_; // Code before function definition + string result_; // Function definition + + // Map from attr name to the first input arg it is inferred from + std::unordered_map inferred_attrs_; + + // The names of the non-inferred attrs, in parameter order + std::vector attrs_; + + // All parameters, including inputs & non-inferred attrs, required and those + // with defaults, except "name" + std::vector param_names_; +}; + +} // namespace python_op_gen_internal +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_ diff --git a/tensorflow/python/grappler/cost_analyzer.cc b/tensorflow/python/grappler/cost_analyzer.cc index 273a74dd286..29976b79495 100644 --- a/tensorflow/python/grappler/cost_analyzer.cc +++ b/tensorflow/python/grappler/cost_analyzer.cc @@ -56,8 +56,8 @@ void CostAnalyzer::GatherCosts() { CostGraphDef cost_graph_measured; PredictCosts(&measure_estimator_, &cost_graph_measured, &total_time_measured_); + VLOG(1) << "Graph size: " << item_->graph.node_size(); VLOG(1) << "cost_graph_measured size: " << cost_graph_measured.node_size(); - op_perf_ = CostGraphToOpPerformanceData(cost_graph_measured, item_->graph); CostGraphDef cost_graph_analytical; PredictCosts(&analytical_estimator_, &cost_graph_analytical, @@ -66,25 +66,32 @@ void CostAnalyzer::GatherCosts() { << cost_graph_analytical.node_size(); CostGraphDef cost_graph_analytical_filtered; - std::set cost_nodes; - for (auto& node : cost_graph_measured.node()) { - cost_nodes.insert(node.name()); + CostGraphDef cost_graph_measured_filtered; + std::map measured_nodes; + for (const auto& node : cost_graph_measured.node()) { + measured_nodes[node.name()] = &node; } for (const auto& node : cost_graph_analytical.node()) { - auto it = cost_nodes.find(node.name()); + auto it = measured_nodes.find(node.name()); // Filter the nodes that are not the cost nodes returned by // MeasuringCostEstimator. - if (it == cost_nodes.end()) { + if (it == measured_nodes.end()) { continue; } - auto added_node = cost_graph_analytical_filtered.add_node(); - *added_node = node; + auto added_node_analytical = cost_graph_analytical_filtered.add_node(); + auto added_node_measured = cost_graph_measured_filtered.add_node(); + *added_node_analytical = node; + *added_node_measured = *(it->second); } VLOG(1) << "cost_graph_analytical_filtered size: " << cost_graph_analytical_filtered.node_size(); + // TODO(yaozhang): add a test to make sure that op_perf_analytical_ and + // op_perf_ cover the same set of nodes. op_perf_analytical_ = CostGraphToOpPerformanceData( cost_graph_analytical_filtered, item_->graph); + op_perf_ = + CostGraphToOpPerformanceData(cost_graph_measured_filtered, item_->graph); } void CostAnalyzer::PreprocessCosts() { diff --git a/tensorflow/python/grappler/cost_analyzer_test.py b/tensorflow/python/grappler/cost_analyzer_test.py index 19d3c9695bf..726db29f3c1 100644 --- a/tensorflow/python/grappler/cost_analyzer_test.py +++ b/tensorflow/python/grappler/cost_analyzer_test.py @@ -19,11 +19,18 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.grappler import cost_analyzer +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_grad # pylint: disable=unused-import +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training import adam class PyWrapOptimizeGraphTest(test.TestCase): @@ -51,6 +58,40 @@ class PyWrapOptimizeGraphTest(test.TestCase): # Also print the report to make it easier to debug print("{}".format(report)) + def testSmallNetwork(self): + image = array_ops.placeholder(dtypes.float32, shape=[1, 28, 28, 1]) + label = array_ops.placeholder(dtypes.float32, shape=[1, 10]) + w = variables.Variable( + random_ops.truncated_normal([5, 5, 1, 32], stddev=0.1)) + b = variables.Variable(random_ops.truncated_normal([32], stddev=0.1)) + conv = nn_ops.conv2d(image, w, strides=[1, 1, 1, 1], padding="SAME") + h_conv = nn_ops.relu(conv + b) + h_conv_flat = array_ops.reshape(h_conv, [1, -1]) + + w_fc = variables.Variable( + random_ops.truncated_normal([25088, 10], stddev=0.1)) + b_fc = variables.Variable(random_ops.truncated_normal([10], stddev=0.1)) + y_conv = nn_ops.softmax(math_ops.matmul(h_conv_flat, w_fc) + b_fc) + + cross_entropy = math_ops.reduce_mean(-math_ops.reduce_sum( + label * math_ops.log(y_conv), reduction_indices=[1])) + _ = adam.AdamOptimizer(1e-4).minimize(cross_entropy) + + mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph()) + report = cost_analyzer.GenerateCostReport(mg) + + self.assertTrue(b"MatMul" in report) + self.assertTrue(b"ApplyAdam" in report) + self.assertTrue(b"Conv2D" in report) + self.assertTrue(b"Conv2DBackpropInput" in report) + self.assertTrue(b"Conv2DBackpropFilter" in report) + self.assertTrue(b"Softmax" in report) + + # Also print the report to make it easier to debug + print("{}".format(report)) + + +# print("{}".format(mg.graph_def)) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/grappler/tf_optimizer.i b/tensorflow/python/grappler/tf_optimizer.i index 404ce351801..a8067467d91 100644 --- a/tensorflow/python/grappler/tf_optimizer.i +++ b/tensorflow/python/grappler/tf_optimizer.i @@ -67,7 +67,9 @@ PyObject* TF_OptimizeGraph( const tensorflow::RewriterConfig& rewriter_config, const tensorflow::MetaGraphDef& metagraph, const string& graph_id, TF_Status* out_status) { - const tensorflow::grappler::ItemConfig item_config; + tensorflow::grappler::ItemConfig item_config; + item_config.inline_functions = false; + item_config.apply_optimizations = false; std::unique_ptr grappler_item = tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config); std::unordered_map device_map; diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 27b69b5a638..79c7905427e 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1954,7 +1954,6 @@ cuda_py_test( "//tensorflow/python:variables", ], flaky = 1, # create_local_cluster sometimes times out. - tags = ["nomsan"], # b/38390993 ) cuda_py_test( diff --git a/tensorflow/python/kernel_tests/basic_gpu_test.py b/tensorflow/python/kernel_tests/basic_gpu_test.py index e6d0c06d140..dbbc2de811e 100644 --- a/tensorflow/python/kernel_tests/basic_gpu_test.py +++ b/tensorflow/python/kernel_tests/basic_gpu_test.py @@ -228,9 +228,9 @@ class BroadcastSimpleTest(test.TestCase): class GpuMultiSessionMemoryTest(test_util.TensorFlowTestCase): """Tests concurrent sessions executing on the same GPU.""" - def _run_session(self, results): + def _run_session(self, session, results): n_iterations = 500 - with self.test_session(use_gpu=True) as s: + with session as s: data = variables.Variable(1.0) with ops.device('/gpu:0'): random_seed.set_random_seed(1) @@ -245,29 +245,29 @@ class GpuMultiSessionMemoryTest(test_util.TensorFlowTestCase): for _ in xrange(n_iterations): value = s.run(x4) - results.append(value) - if value != results[0]: + results.add(value.flat[0]) + if len(results) != 1: break def testConcurrentSessions(self): - if not test.is_gpu_available(): - return - n_threads = 4 - results = [[]] * n_threads - threads = [ - threading.Thread(target=self._run_session, args=(results[i],)) - for i in xrange(n_threads) - ] + threads = [] + results = [] + for _ in xrange(n_threads): + session = self.test_session(graph=ops.Graph(), use_gpu=True) + results.append(set()) + args = (session, results[-1]) + threads.append(threading.Thread(target=self._run_session, args=args)) + for thread in threads: thread.start() for thread in threads: thread.join() - flat_results = [x for x in itertools.chain(*results)] - self.assertNotEqual(0, len(flat_results)) - for result in flat_results: - self.assertEqual(result, flat_results[0]) + flat_results = set([x for x in itertools.chain(*results)]) + self.assertEqual(1, + len(flat_results), + 'Expected single value, got %r' % flat_results) if __name__ == '__main__': diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py index 9a1c918b150..b7f8f5c51f6 100644 --- a/tensorflow/python/kernel_tests/cholesky_op_test.py +++ b/tensorflow/python/kernel_tests/cholesky_op_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_linalg_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl @@ -303,7 +304,7 @@ class CholeskyBenchmark(test.Benchmark): ops.device("/cpu:0"): l = linalg_ops.cholesky(data) self.run_op_benchmark( - sess, l, + sess, control_flow_ops.group(l,), min_iters=25, name="cholesky_cpu_{size}".format(size=size)) @@ -328,7 +329,7 @@ class CholeskyBenchmark(test.Benchmark): ops.device(device): grad = grad_fn(l, grad_data) self.run_op_benchmark( - sess, grad, + sess, control_flow_ops.group(grad,), min_iters=25, name="{name}_{dev}_{size}".format( name=name, dev=grad.device, size=size)) diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py index d009f4e9319..2f8f85866df 100644 --- a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py +++ b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py @@ -252,7 +252,7 @@ class DirichletMultinomialTest(test.TestCase): ]) self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.04) self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.05) - self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.03) + self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.05) self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.02) def testCovariance(self): diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py index 2a90bc539bb..79254cb28c2 100644 --- a/tensorflow/python/kernel_tests/lookup_ops_test.py +++ b/tensorflow/python/kernel_tests/lookup_ops_test.py @@ -280,6 +280,18 @@ class IndexTableFromFile(test.TestCase): lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) + def test_string_index_table_from_file_tensor_filename(self): + vocabulary_file = self._createVocabFile("f2i_vocab1.txt") + with self.test_session(): + vocabulary_file = constant_op.constant(vocabulary_file) + table = lookup_ops.index_table_from_file( + vocabulary_file=vocabulary_file, num_oov_buckets=1) + ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) + + self.assertRaises(errors_impl.OpError, ids.eval) + lookup_ops.tables_initializer().run() + self.assertAllEqual((1, 2, 3), ids.eval()) + def test_int32_index_table_from_file(self): vocabulary_file = self._createVocabFile( "f2i_vocab2.txt", values=("42", "1", "-1000")) @@ -340,7 +352,11 @@ class IndexTableFromFile(test.TestCase): 860), # 3 + fingerprint("toccata") mod 300. ids.eval()) - def test_index_table_from_file_with_only_oov_buckets(self): + def test_index_table_from_file_fails_with_empty_vocabulary_file_name(self): + self.assertRaises( + ValueError, lookup_ops.index_table_from_file, vocabulary_file="") + + def test_index_table_from_file_fails_with_empty_vocabulary(self): self.assertRaises( ValueError, lookup_ops.index_table_from_file, vocabulary_file=None) diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index ea6f55281ed..780d1c2b8e0 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -66,9 +66,6 @@ class BatchNormalization(base.Layer): moving_variance_initializer: Initializer for the moving variance. beta_regularizer: Optional regularizer for the beta weight. gamma_regularizer: Optional regularizer for the gamma weight. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). - name: A string, the name of the layer. renorm: Whether to use Batch Renormalization (https://arxiv.org/abs/1702.03275). This adds extra variables during training. The inference is the same for either value of this parameter. @@ -82,6 +79,11 @@ class BatchNormalization(base.Layer): and should be neither too small (which would add noise) nor too large (which would give stale estimates). Note that `momentum` is still applied to get the means and variances for inference. + fused: if `True`, use a faster, fused implementation based on + nn.fused_batch_norm. If `None`, use the fused implementation if possible. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). + name: A string, the name of the layer. """ def __init__(self, @@ -99,6 +101,7 @@ class BatchNormalization(base.Layer): renorm=False, renorm_clipping=None, renorm_momentum=0.99, + fused=False, trainable=True, name=None, **kwargs): @@ -116,6 +119,10 @@ class BatchNormalization(base.Layer): self.beta_regularizer = beta_regularizer self.gamma_regularizer = gamma_regularizer self.renorm = renorm + self.fused = fused + if self.fused and renorm: + raise ValueError( + 'Batch renorm is currently not supported with fused batch norm.') if renorm: renorm_clipping = renorm_clipping or {} keys = ['rmax', 'rmin', 'dmax'] @@ -130,6 +137,13 @@ class BatchNormalization(base.Layer): if not input_shape.ndims: raise ValueError('Input has undefined rank:', input_shape) ndim = len(input_shape) + # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the + # output back to its original shape accordingly. + if self.fused and ndim != 4: + raise ValueError( + 'Only 4D inputs are currently supported with fused batch norm. ' + 'Consider reshaping the input to 4D and reshape the output back ' + 'to its original shape. Got input rank: ', ndim) if self.axis < 0: axis = ndim + self.axis else: @@ -137,6 +151,20 @@ class BatchNormalization(base.Layer): if axis < 0 or axis >= ndim: raise ValueError('Value of `axis` argument ' + str(self.axis) + ' is out of range for input with rank ' + str(ndim)) + + if self.fused is None: + self.fused = not self.renorm and ndim == 4 and axis in [1, 3] + + if self.fused: + if axis == 1: + self._data_format = 'NCHW' + elif axis == 3: + self._data_format = 'NHWC' + else: + raise ValueError( + 'Only axis 1 and 3 are currently supported dimensions for ' + 'fused batch norm. Got `axis` dimension: ', axis) + param_dim = input_shape[axis] if not param_dim.value: raise ValueError('Input has undefined `axis` dimension. Input shape: ', @@ -152,6 +180,8 @@ class BatchNormalization(base.Layer): trainable=True) else: self.beta = None + if self.fused: + self._beta_const = array_ops.constant(0.0, shape=(param_dim,)) if self.scale: self.gamma = self.add_variable(name='gamma', shape=(param_dim,), @@ -160,6 +190,8 @@ class BatchNormalization(base.Layer): trainable=True) else: self.gamma = None + if self.fused: + self._gamma_const = array_ops.constant(1.0, shape=(param_dim,)) # Disable variable partitioning when creating the moving mean and variance partitioner = self._scope.partitioner @@ -205,6 +237,45 @@ class BatchNormalization(base.Layer): self._scope.set_partitioner(partitioner) self.built = True + def _fused_batch_norm(self, inputs, training): + """Returns the output of fused batch norm.""" + beta = self.beta if self.center else self._beta_const + gamma = self.gamma if self.scale else self._gamma_const + + def _fused_batch_norm_training(): + return nn.fused_batch_norm( + inputs, + gamma, + beta, + epsilon=self.epsilon, + data_format=self._data_format) + + def _fused_batch_norm_inference(): + return nn.fused_batch_norm( + inputs, + gamma, + beta, + mean=self.moving_mean, + variance=self.moving_variance, + epsilon=self.epsilon, + is_training=False, + data_format=self._data_format) + + output, mean, variance = utils.smart_cond( + training, _fused_batch_norm_training, _fused_batch_norm_inference) + + training_value = utils.constant_value(training) + if training_value is not False: + decay = _smart_select(training, lambda: self.momentum, lambda: 1.) + mean_update = moving_averages.assign_moving_average( + self.moving_mean, mean, decay, zero_debias=False) + variance_update = moving_averages.assign_moving_average( + self.moving_variance, variance, decay, zero_debias=False) + self.add_update(mean_update, inputs=inputs) + self.add_update(variance_update, inputs=inputs) + + return output + def _renorm_correction_and_moments(self, mean, variance, training): """Returns the correction and update values for renorm.""" stddev = math_ops.sqrt(variance + self.epsilon) @@ -265,6 +336,9 @@ class BatchNormalization(base.Layer): return (r, d, new_mean, new_variance) def call(self, inputs, training=False): + if self.fused: + return self._fused_batch_norm(inputs, training=training) + # First, compute the axes along which to reduce the mean / variance, # as well as the broadcast shape to be used for all parameters. input_shape = inputs.get_shape() @@ -353,7 +427,8 @@ def batch_normalization(inputs, reuse=None, renorm=False, renorm_clipping=None, - renorm_momentum=0.99): + renorm_momentum=0.99, + fused=False): """Functional interface for the batch normalization layer. Reference: http://arxiv.org/abs/1502.03167 @@ -415,6 +490,8 @@ def batch_normalization(inputs, and should be neither too small (which would add noise) nor too large (which would give stale estimates). Note that `momentum` is still applied to get the means and variances for inference. + fused: if `True`, use a faster, fused implementation based on + nn.fused_batch_norm. If `None`, use the fused implementation if possible. Returns: Output tensor. @@ -431,10 +508,11 @@ def batch_normalization(inputs, moving_variance_initializer=moving_variance_initializer, beta_regularizer=beta_regularizer, gamma_regularizer=gamma_regularizer, - trainable=trainable, renorm=renorm, renorm_clipping=renorm_clipping, renorm_momentum=renorm_momentum, + fused=fused, + trainable=trainable, name=name, _reuse=reuse, _scope=name) diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py index 933f196e011..fa6c9c4a5db 100644 --- a/tensorflow/python/layers/normalization_test.py +++ b/tensorflow/python/layers/normalization_test.py @@ -262,6 +262,87 @@ class BNTest(test.TestCase): self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + def test4DInputAxis3Fused(self): + epsilon = 1e-3 + bn = normalization_layers.BatchNormalization( + axis=3, epsilon=epsilon, momentum=0.9, fused=True) + inputs = variables.Variable( + np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32) + training = array_ops.placeholder(dtype='bool') + outputs = bn.apply(inputs, training=training) + + with self.test_session() as sess: + # Test training with placeholder learning phase. + sess.run(variables.global_variables_initializer()) + np_gamma, np_beta = sess.run([bn.gamma, bn.beta]) + np_gamma = np.reshape(np_gamma, (1, 1, 1, 6)) + np_beta = np.reshape(np_beta, (1, 1, 1, 6)) + for _ in range(100): + np_output, _, _ = sess.run( + [outputs] + bn.updates, feed_dict={training: True}) + # Verify that the axis is normalized during training. + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + # Verify that the statistics are updated during training. + moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance]) + np_inputs = sess.run(inputs) + mean = np.mean(np_inputs, axis=(0, 1, 2)) + std = np.std(np_inputs, axis=(0, 1, 2)) + variance = np.square(std) + self.assertAllClose(mean, moving_mean, atol=1e-2) + self.assertAllClose(variance, moving_var, atol=1e-2) + + # Test inference with placeholder learning phase. + np_output = sess.run(outputs, feed_dict={training: False}) + + # Verify that the axis is normalized during inference. + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + def test4DInputAxis1Fused(self): + if test.is_gpu_available(cuda_only=True): + epsilon = 1e-3 + bn = normalization_layers.BatchNormalization( + axis=1, epsilon=epsilon, momentum=0.9, fused=True) + inputs = variables.Variable( + np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32) + training = array_ops.placeholder(dtype='bool') + outputs = bn.apply(inputs, training=training) + + with self.test_session() as sess: + # Test training with placeholder learning phase. + sess.run(variables.global_variables_initializer()) + np_gamma, np_beta = sess.run([bn.gamma, bn.beta]) + np_gamma = np.reshape(np_gamma, (1, 4, 1, 1)) + np_beta = np.reshape(np_beta, (1, 4, 1, 1)) + for _ in range(100): + np_output, _, _ = sess.run( + [outputs] + bn.updates, feed_dict={training: True}) + # Verify that the axis is normalized during training. + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + # Verify that the statistics are updated during training. + moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance]) + np_inputs = sess.run(inputs) + mean = np.mean(np_inputs, axis=(0, 2, 3)) + std = np.std(np_inputs, axis=(0, 2, 3)) + variance = np.square(std) + self.assertAllClose(mean, moving_mean, atol=1e-2) + self.assertAllClose(variance, moving_var, atol=1e-2) + + # Test inference with placeholder learning phase. + np_output = sess.run(outputs, feed_dict={training: False}) + + # Verify that the axis is normalized during inference. + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + def testNegativeAxis(self): epsilon = 1e-3 bn = normalization_layers.BatchNormalization( diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index 9b2d7618837..c48296eccb0 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -347,6 +347,7 @@ Status ConvertTensorToNdarray(const Tensor& t, PyObject** ret) { PyArray_Descr* descr = PyArray_DescrFromType(typenum); CHECK(descr); std::vector dims; + dims.reserve(t.dims()); for (int i = 0; i < t.dims(); ++i) { dims.push_back(t.dim_size(i)); } diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index 82277ebaccb..bffaf6324fc 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -893,7 +893,7 @@ def index_table_from_file(vocabulary_file=None, ``` Args: - vocabulary_file: The vocabulary filename. + vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`. num_oov_buckets: The number of out-of-vocabulary buckets. vocab_size: Number of the elements in the vocabulary, if known. default_value: The value to use for out-of-vocabulary feature values. @@ -911,8 +911,9 @@ def index_table_from_file(vocabulary_file=None, ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater than zero. """ - if not vocabulary_file: - raise ValueError("vocabulary_file must be specified.") + if vocabulary_file is None or ( + isinstance(vocabulary_file, str) and not vocabulary_file): + raise ValueError("vocabulary_file must be specified and must not be empty.") if num_oov_buckets < 0: raise ValueError("num_oov_buckets must be greater or equal than 0, got %d." % num_oov_buckets) diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py index 5cb2c152b04..a307347f606 100644 --- a/tensorflow/python/platform/test.py +++ b/tensorflow/python/platform/test.py @@ -61,6 +61,9 @@ else: # Import Benchmark class Benchmark = _googletest.Benchmark # pylint: disable=invalid-name +# Import StubOutForTesting class +StubOutForTesting = _googletest.StubOutForTesting # pylint: disable=invalid-name + def main(argv=None): """Runs all unit tests.""" @@ -117,6 +120,7 @@ _allowed_symbols = [ # We piggy-back googletest documentation. 'Benchmark', 'mock', + 'StubOutForTesting', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index b8554abb4ff..ff77470a824 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -279,7 +279,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name For a chief, this utility sets proper session initializer/restorer. It also creates hooks related to checkpoint and summary saving. For workers, this utility sets proper session creator which waits for the chief to - initialize/restore. + initialize/restore. Please check `tf.train.MonitoredSession` for more + information. Args: @@ -633,6 +634,12 @@ class MonitoredSession(_MonitoredSession): See `MonitoredTrainingSession` for an example usage based on chief or worker. + Note: This is not a `tf.Session`. For example, it cannot do following: + + * it cannot be set as default session. + * it cannot be sent to saver.save. + * it cannot be sent to tf.train.start_queue_runners. + Args: session_creator: A factory object to create session. Typically a `ChiefSessionCreator` which is the default one. diff --git a/tensorflow/python/training/queue_runner_impl.py b/tensorflow/python/training/queue_runner_impl.py index d713e222aee..4e58602a6f7 100644 --- a/tensorflow/python/training/queue_runner_impl.py +++ b/tensorflow/python/training/queue_runner_impl.py @@ -22,6 +22,7 @@ import threading import weakref from tensorflow.core.protobuf import queue_runner_pb2 +from tensorflow.python.client import session from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging @@ -401,6 +402,10 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True, collection: A `GraphKey` specifying the graph collection to get the queue runners from. Defaults to `GraphKeys.QUEUE_RUNNERS`. + Raises: + ValueError: if `sess` is None and there isn't any default session. + TypeError: if `sess` is not a `tf.Session` object. + Returns: A list of threads. """ @@ -410,6 +415,15 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True, raise ValueError("Cannot start queue runners: No default session is " "registered. Use `with sess.as_default()` or pass an " "explicit session to tf.start_queue_runners(sess=sess)") + + if not isinstance(sess, session.SessionInterface): + # Following check is due to backward compatibility. (b/62061352) + if sess.__class__.__name__ in [ + "MonitoredSession", "SingularMonitoredSession"]: + return [] + raise TypeError("sess must be a `tf.Session` object. " + "Given class: {}".format(sess.__class__)) + with sess.graph.as_default(): threads = [] for qr in ops.get_collection(collection): diff --git a/tensorflow/python/training/queue_runner_test.py b/tensorflow/python/training/queue_runner_test.py index 5b00ac9fc31..51c0eecf46a 100644 --- a/tensorflow/python/training/queue_runner_test.py +++ b/tensorflow/python/training/queue_runner_test.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import coordinator +from tensorflow.python.training import monitored_session from tensorflow.python.training import queue_runner_impl @@ -247,6 +248,33 @@ class QueueRunnerTest(test.TestCase): # The variable should be 3. self.assertEqual(3, var.eval()) + def testStartQueueRunnersRaisesIfNotASession(self): + zero64 = constant_op.constant(0, dtype=dtypes.int64) + var = variables.Variable(zero64) + count_up_to = var.count_up_to(3) + queue = data_flow_ops.FIFOQueue(10, dtypes.float32) + init_op = variables.global_variables_initializer() + qr = queue_runner_impl.QueueRunner(queue, [count_up_to]) + queue_runner_impl.add_queue_runner(qr) + with self.test_session(): + init_op.run() + with self.assertRaisesRegexp(TypeError, "tf.Session"): + queue_runner_impl.start_queue_runners("NotASession") + + def testStartQueueRunnersIgnoresMonitoredSession(self): + zero64 = constant_op.constant(0, dtype=dtypes.int64) + var = variables.Variable(zero64) + count_up_to = var.count_up_to(3) + queue = data_flow_ops.FIFOQueue(10, dtypes.float32) + init_op = variables.global_variables_initializer() + qr = queue_runner_impl.QueueRunner(queue, [count_up_to]) + queue_runner_impl.add_queue_runner(qr) + with self.test_session(): + init_op.run() + threads = queue_runner_impl.start_queue_runners( + monitored_session.MonitoredSession()) + self.assertFalse(threads) + def testStartQueueRunnersNonDefaultGraph(self): # CountUpTo will raise OUT_OF_RANGE when it reaches the count. graph = ops.Graph() diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index e1674745c84..cd8994f73a0 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -239,6 +239,17 @@ CUDNN_DNN_ROUTINE_EACH_R5(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) #undef CUDNN_DNN_ROUTINE_EACH_R5 #endif +// APIs in R6 +// clang-format off +#if CUDNN_VERSION >= 6000 +#define CUDNN_DNN_ROUTINE_EACH_R6(__macro) \ + __macro(cudnnConvolutionBiasActivationForward) + +// clang-format on +CUDNN_DNN_ROUTINE_EACH_R6(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) +#undef CUDNN_DNN_ROUTINE_EACH_R6 +#endif + #undef CUDNN_DNN_ROUTINE_EACH } // namespace wrap @@ -1791,6 +1802,7 @@ bool CudnnSupport::DoConvolveImpl( const FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, const ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, dnn::ActivationMode activation_mode, const BatchDescriptor& output_descriptor, DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, @@ -1917,6 +1929,26 @@ bool CudnnSupport::DoConvolveImpl( } } + const bool has_biases = (biases != nullptr); + const bool supported_activation_mode = + (activation_mode == dnn::ActivationMode::kRelu6 || + activation_mode == dnn::ActivationMode::kReluX || + activation_mode == dnn::ActivationMode::kRelu); + + if (has_biases && !supported_activation_mode) { + LOG(ERROR) << "cudnnConvolutionBiasActivationForward() only " + "support relu activation."; + return false; + } + + if (has_biases && activation_mode != dnn::ActivationMode::kNone) { + LOG(ERROR) << "To use cudnnConvolutionBiasActivationForward() " + "with a valid biases tensor, need to also provide " + "a valid activation mode (currently only supports " + "kRelu6, kReluX, and kRelu)."; + return false; + } + std::unique_ptr timer; if (is_profiling) { timer.reset(new CUDATimer(parent_)); @@ -1931,14 +1963,45 @@ bool CudnnSupport::DoConvolveImpl( return false; } } - status = wrap::cudnnConvolutionForward( - parent_, ToHandle(dnn_handle_), - /*alpha=*/&alpha, /*srcDesc=*/input_nd.handle(), - /*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(), - /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(), - /*algo=*/algo, /*workSpace=*/scratch.opaque(), - /*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/&beta, - /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque()); + if (has_biases) { + CHECK(supported_activation_mode); +#if CUDNN_VERSION < 6000 + LOG(ERROR) << "cudnnConvolutionBiasActivationForward() is only " + "supported for cuDNN version >= 6."; + return false; +#else + BatchDescriptor bias_dimensions; + bias_dimensions.set_count(1) + .set_feature_map_count(output_descriptor.feature_map_count()) + .set_height(1) + .set_width(1) + .set_layout(dnn::DataLayout::kBatchYXDepth); + ScopedTensorDescriptor bias_descriptor{ + parent_, bias_dimensions, static_cast(cudnn_type)}; + ScopedActivationDescriptor activation_desc{parent_, activation_mode, + output_descriptor.value_max()}; + status = wrap::cudnnConvolutionBiasActivationForward( + parent_, ToHandle(dnn_handle_), + /*alpha1=*/&alpha, /*srcDesc=*/input_nd.handle(), + /*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(), + /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(), + /*algo=*/algo, /*workSpace=*/scratch.opaque(), + /*workSpaceSizeInBytes=*/scratch.size(), /*alpha2=*/&beta, + /*zDesc=*/output_nd.handle(), /*z=*/nullptr, + /*biasDesc=*/bias_descriptor.handle(), + /*bias=*/biases.opaque(), /*activationDesc=*/activation_desc.handle(), + /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque()); +#endif // CUDNN_VERSION < 6000 + } else { + status = wrap::cudnnConvolutionForward( + parent_, ToHandle(dnn_handle_), + /*alpha=*/&alpha, /*srcDesc=*/input_nd.handle(), + /*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(), + /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(), + /*algo=*/algo, /*workSpace=*/scratch.opaque(), + /*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/&beta, + /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque()); + } if (is_profiling) { if (!timer->Stop(AsCUDAStream(stream))) { timer->Destroy(); @@ -2211,16 +2274,48 @@ bool CudnnSupport::DoConvolve( const FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, const ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, dnn::ActivationMode activation_mode, const BatchDescriptor& output_descriptor, DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { return DoConvolveImpl( stream, CUDNN_DATA_FLOAT, batch_descriptor, input_data, filter_descriptor, - filter_data, convolution_descriptor, output_descriptor, output_data, + filter_data, convolution_descriptor, biases, activation_mode, + output_descriptor, output_data, scratch_allocator, algorithm_config, + output_profile_result); +} + +bool CudnnSupport::DoConvolve( + Stream* stream, const BatchDescriptor& batch_descriptor, + const DeviceMemory& input_data, + const FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const ConvolutionDescriptor& convolution_descriptor, + const BatchDescriptor& output_descriptor, DeviceMemory* output_data, + ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + dnn::ProfileResult* output_profile_result) { + return DoConvolveImpl( + stream, CUDNN_DATA_FLOAT, batch_descriptor, input_data, filter_descriptor, + filter_data, convolution_descriptor, /*biases=*/nullptr, + dnn::ActivationMode::kNone, output_descriptor, output_data, scratch_allocator, algorithm_config, output_profile_result); } +bool CudnnSupport::DoConvolve( + Stream* stream, const BatchDescriptor& batch_descriptor, + const DeviceMemory& input_data, + const FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, dnn::ActivationMode activation_mode, + const BatchDescriptor& output_descriptor, + DeviceMemory* output_data) { + LOG(ERROR) << "double-based DNN not yet implemented"; + return false; +} + bool CudnnSupport::DoConvolve( Stream* stream, const BatchDescriptor& batch_descriptor, const DeviceMemory& input_data, @@ -2239,13 +2334,33 @@ bool CudnnSupport::DoConvolve( const FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, const ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, + dnn::ActivationMode activation_mode, const BatchDescriptor& output_descriptor, DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { return DoConvolveImpl( stream, CUDNN_DATA_HALF, batch_descriptor, input_data, filter_descriptor, - filter_data, convolution_descriptor, output_descriptor, output_data, + filter_data, convolution_descriptor, biases, activation_mode, + output_descriptor, output_data, scratch_allocator, algorithm_config, + output_profile_result); +} + +bool CudnnSupport::DoConvolve( + Stream* stream, const BatchDescriptor& batch_descriptor, + const DeviceMemory& input_data, + const FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const ConvolutionDescriptor& convolution_descriptor, + const BatchDescriptor& output_descriptor, + DeviceMemory* output_data, ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + dnn::ProfileResult* output_profile_result) { + return DoConvolveImpl( + stream, CUDNN_DATA_HALF, batch_descriptor, input_data, filter_descriptor, + filter_data, convolution_descriptor, /*biases=*/nullptr, + dnn::ActivationMode::kNone, output_descriptor, output_data, scratch_allocator, algorithm_config, output_profile_result); } @@ -2942,6 +3057,7 @@ bool CudnnSupport::DoMatMul(Stream* stream, } const auto toPtrs = [](std::vector>& v) { std::vector*> ptrs; + ptrs.reserve(v.size()); for (auto& mem : v) { ptrs.push_back(&mem); } diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index 2c8ed9a3353..7824885e1b3 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -137,7 +137,43 @@ class CudnnSupport : public dnn::DnnSupport { DeviceMemory* x_backprop, DeviceMemory* scale_backprop, DeviceMemory* offset_backprop) override; - bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& input_descriptor, + bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor, + const DeviceMemory& input_data, + const dnn::FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, + dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemory* output_data, + ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + dnn::ProfileResult* output_profile_result) override; + + bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor, + const DeviceMemory& input_data, + const dnn::FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, + dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemory* output_data) override; + + bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor, + const DeviceMemory& input_data, + const dnn::FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, + dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemory* output_data, + ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + dnn::ProfileResult* output_profile_result) override; + + bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor, const DeviceMemory& input_data, const dnn::FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, @@ -156,7 +192,7 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_descriptor, DeviceMemory* output_data) override; - bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& input_descriptor, + bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor, const DeviceMemory& input_data, const dnn::FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, @@ -477,6 +513,8 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, const dnn::ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, + dnn::ActivationMode activation_mode, const dnn::BatchDescriptor& output_descriptor, DeviceMemory* output_data, ScratchAllocator* scratch_allocator, diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 8e56933ba38..e8b5bbf5b1a 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -796,6 +796,7 @@ class NormalizeDescriptor { // Describes a kind of non-linearity (threshold-like mathematical function). enum class ActivationMode { + kNone, kSigmoid, // Rectified linear activation: f(x) = x < 0 ? 0 : x kRelu, @@ -910,9 +911,11 @@ class DnnSupport { // input_data: un-owned device memory region which contains the // convolution input. // filter_descriptor: dimensions of the convolution filter. - // weights: coefficients for the convolution filter, these are multiplied - // against values in the input that the filter convolves over. // convolution_descriptor: stride of the convolution filter. + // biases: un-owned device memory region containing biases to add to the + // input. This can be DeviceMemory pointing to NULL only when activation_mode + // is kNone. + // activation_mode: Type of activation to perform. // output_descriptor: dimensions of the output layer. // output_data: un-owned device memory region in which to place the // convolution result. @@ -939,6 +942,55 @@ class DnnSupport { // that if the inverse of the filter is applied to the output in VALID mode // the result is the same size as the input - this requires even more // padding of the input. + virtual bool DoConvolve( + Stream* stream, const dnn::BatchDescriptor& input_descriptor, + const DeviceMemory& input_data, + const dnn::FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemory* output_data, ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + ProfileResult* output_profile_result) { + return false; + } + + // Enqueues a double-precision fused convolution, bias add, and activation + // operation onto the stream. See DoConvolve above for argument details. + virtual bool DoConvolve( + Stream* stream, const dnn::BatchDescriptor& batch_descriptor, + const DeviceMemory& input_data, + const dnn::FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemory* output_data) { + return false; + } + + // Enqueues a half-precision fused convolution, bias add, and activation + // operation onto the stream. See DoConvolve above for argument details. + virtual bool DoConvolve( + Stream* stream, const dnn::BatchDescriptor& batch_descriptor, + const DeviceMemory& input_data, + const dnn::FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + const DeviceMemory& biases, + dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemory* output_data, + ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + ProfileResult* output_profile_result) { + return false; + } + + // Enqueues a single-precision convolution operation (without bias add + // or activation) onto the stream. + // See DoConvolve above for argument details. virtual bool DoConvolve( Stream* stream, const dnn::BatchDescriptor& input_descriptor, const DeviceMemory& input_data, @@ -950,11 +1002,8 @@ class DnnSupport { const dnn::AlgorithmConfig& algorithm_config, ProfileResult* output_profile_result) = 0; - // Return a list of algorithms supported by the forward convolution pass. - virtual bool GetConvolveAlgorithms( - bool with_winograd_nonfused, std::vector* out_algorithms); - - // Enqueues a double-precision convolution operation onto the stream. + // Enqueues a double-precision convolution operation (without bias add + // or activation) onto the stream. // See DoConvolve above for argument details. virtual bool DoConvolve( Stream* stream, const dnn::BatchDescriptor& batch_descriptor, @@ -965,7 +1014,8 @@ class DnnSupport { const dnn::BatchDescriptor& output_descriptor, DeviceMemory* output_data) = 0; - // Enqueues a half-precision convolution operation onto the stream. + // Enqueues a half-precision convolution operation (without bias add + // or activation) onto the stream. // See DoConvolve above for argument details. virtual bool DoConvolve( Stream* stream, const dnn::BatchDescriptor& batch_descriptor, @@ -979,6 +1029,10 @@ class DnnSupport { const dnn::AlgorithmConfig& algorithm_config, ProfileResult* output_profile_result) = 0; + // Return a list of algorithms supported by the forward convolution pass. + virtual bool GetConvolveAlgorithms( + bool with_winograd_nonfused, std::vector* out_algorithms); + // Version of DoConvolve that uses pre-quantized 8 bit coefficients. // coefficient_scales specifies the scaling of each column of coefficients: // original float coefficient[row * num_columns + column] = diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index a393b077034..bb586c58485 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -350,9 +350,65 @@ Stream &Stream::ThenConvolveWithScratch( const dnn::FilterDescriptor &filter_descriptor, const DeviceMemory &filter_data, const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory &biases, + dnn::ActivationMode activation_mode, const dnn::BatchDescriptor &output_descriptor, - DeviceMemory *output, + DeviceMemory *output, ScratchAllocator *scratch_allocator) { + VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), + PARAM(filter_descriptor), PARAM(filter_data), + PARAM(convolution_descriptor), PARAM(biases), + PARAM(activation_mode), PARAM(output_descriptor), PARAM(output)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + CheckError(dnn->DoConvolve( + this, input_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, biases, activation_mode, output_descriptor, + output, scratch_allocator, dnn::AlgorithmConfig(), + /*output_profile_result=*/nullptr)); + } else { + SetErrorAndLogNoDnnSupport(); + } + } + return *this; +} + +Stream &Stream::ThenConvolveWithScratch( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory &biases, dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, DeviceMemory *output, ScratchAllocator *scratch_allocator) { + VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), + PARAM(filter_descriptor), PARAM(filter_data), + PARAM(convolution_descriptor), PARAM(biases), + PARAM(activation_mode), PARAM(output_descriptor), PARAM(output)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + CheckError(dnn->DoConvolve( + this, input_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, biases, activation_mode, output_descriptor, + output, scratch_allocator, dnn::AlgorithmConfig(), + /*output_profile_result=*/nullptr)); + } else { + SetErrorAndLogNoDnnSupport(); + } + } + return *this; +} + +Stream &Stream::ThenConvolveWithScratch( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory *output, ScratchAllocator *scratch_allocator) { VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), PARAM(filter_descriptor), PARAM(filter_data), PARAM(convolution_descriptor), PARAM(output_descriptor), @@ -362,9 +418,9 @@ Stream &Stream::ThenConvolveWithScratch( if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoConvolve( this, input_descriptor, input_data, filter_descriptor, filter_data, - convolution_descriptor, output_descriptor, output, - /*scratch_allocator=*/scratch_allocator, dnn::AlgorithmConfig(), - nullptr)); + convolution_descriptor, output_descriptor, output, scratch_allocator, + dnn::AlgorithmConfig(), + /*output_profile_result=*/nullptr)); } else { SetErrorAndLogNoDnnSupport(); } @@ -389,9 +445,74 @@ Stream &Stream::ThenConvolveWithScratch( if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoConvolve( this, input_descriptor, input_data, filter_descriptor, filter_data, - convolution_descriptor, output_descriptor, output, - /*scratch_allocator=*/scratch_allocator, dnn::AlgorithmConfig(), - nullptr)); + convolution_descriptor, output_descriptor, output, scratch_allocator, + dnn::AlgorithmConfig(), + /*output_profile_result=*/nullptr)); + } else { + SetErrorAndLogNoDnnSupport(); + } + } + return *this; +} + +Stream &Stream::ThenConvolveWithAlgorithm( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory &biases, dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, DeviceMemory *output, + ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), + PARAM(filter_descriptor), PARAM(filter_data), + PARAM(convolution_descriptor), PARAM(biases), + PARAM(activation_mode), PARAM(output_descriptor), PARAM(output)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + auto status = dnn->DoConvolve( + this, input_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, biases, activation_mode, output_descriptor, + output, scratch_allocator, algorithm_config, output_profile_result); + if (!status && !output_profile_result) { + SetError(); + } + } else { + SetErrorAndLogNoDnnSupport(); + } + } + return *this; +} + +Stream &Stream::ThenConvolveWithAlgorithm( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory &biases, + dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory *output, ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), + PARAM(filter_descriptor), PARAM(filter_data), + PARAM(convolution_descriptor), PARAM(biases), + PARAM(activation_mode), PARAM(output_descriptor), PARAM(output)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + auto status = dnn->DoConvolve( + this, input_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, biases, activation_mode, output_descriptor, + output, scratch_allocator, algorithm_config, output_profile_result); + if (!status && !output_profile_result) { + SetError(); + } } else { SetErrorAndLogNoDnnSupport(); } @@ -461,6 +582,21 @@ Stream &Stream::ThenConvolveWithAlgorithm( return *this; } +Stream &Stream::ThenConvolve( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory &biases, dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory *output) { + return ThenConvolveWithScratch( + input_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, biases, activation_mode, output_descriptor, + output, /*scratch_allocator=*/nullptr); +} + Stream &Stream::ThenConvolve( const dnn::BatchDescriptor &input_descriptor, const DeviceMemory &input_data, @@ -582,7 +718,7 @@ Stream &Stream::ThenConvolveBackwardDataWithScratch( this, filter_descriptor, filter_data, output_descriptor, backward_output_data, convolution_descriptor, input_descriptor, backward_input_data, scratch_allocator, dnn::AlgorithmConfig(), - nullptr)); + /*output_profile_result=*/nullptr)); } else { SetErrorAndLogNoDnnSupport(); } @@ -676,7 +812,7 @@ Stream &Stream::ThenConvolveBackwardDataWithScratch( this, filter_descriptor, filter_data, output_descriptor, backward_output_data, convolution_descriptor, input_descriptor, backward_input_data, scratch_allocator, dnn::AlgorithmConfig(), - nullptr)); + /*output_profile_result=*/nullptr)); } else { SetErrorAndLogNoDnnSupport(); } @@ -718,7 +854,7 @@ Stream &Stream::ThenConvolveBackwardFilterWithScratch( this, input_descriptor, input_data, output_descriptor, backward_output_data, convolution_descriptor, filter_descriptor, backward_filter_data, scratch_allocator, dnn::AlgorithmConfig(), - nullptr)); + /*output_profile_result=*/nullptr)); } else { SetErrorAndLogNoDnnSupport(); } @@ -779,7 +915,7 @@ Stream &Stream::ThenConvolveBackwardFilterWithScratch( this, input_descriptor, input_data, output_descriptor, backward_output_data, convolution_descriptor, filter_descriptor, backward_filter_data, scratch_allocator, dnn::AlgorithmConfig(), - nullptr)); + /*output_profile_result=*/nullptr)); } else { SetErrorAndLogNoDnnSupport(); } @@ -3868,7 +4004,7 @@ Stream &Stream::ThenBlasGemmBatched( int batch_count) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - nullptr); + /*scratch_allocator=*/nullptr); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -3900,7 +4036,7 @@ Stream &Stream::ThenBlasGemmBatched( int batch_count) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - nullptr); + /*scratch_allocator=*/nullptr); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -3934,7 +4070,7 @@ Stream &Stream::ThenBlasGemmBatched( int batch_count) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - nullptr); + /*scratch_allocator=*/nullptr); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -3973,7 +4109,7 @@ Stream &Stream::ThenBlasGemmBatched( int batch_count) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - nullptr); + /*scratch_allocator=*/nullptr); } Stream &Stream::ThenBlasGemmBatchedWithScratch( diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index 5b46b86f54a..bc1d05cc08c 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -240,6 +240,16 @@ class Stream { DeviceMemory *offset_backprop); // TODO(leary) add double-precision version of this interface. + Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory &biases, + dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory *output); + Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor, const DeviceMemory &input_data, const dnn::FilterDescriptor &filter_descriptor, @@ -268,6 +278,27 @@ class Stream { const dnn::BatchDescriptor &output_descriptor, DeviceMemory *output_data); + Stream &ThenConvolveWithScratch( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory &biases, + dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory *output, ScratchAllocator *scratch_allocator); + + Stream &ThenConvolveWithScratch( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory &biases, dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory *output, ScratchAllocator *scratch_allocator); + Stream &ThenConvolveWithScratch( const dnn::BatchDescriptor &input_descriptor, const DeviceMemory &input_data, @@ -286,6 +317,31 @@ class Stream { const dnn::BatchDescriptor &output_descriptor, DeviceMemory *output, ScratchAllocator *scratch_allocator); + Stream &ThenConvolveWithAlgorithm( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory &biases, dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory *output, ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result); + + Stream &ThenConvolveWithAlgorithm( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory &biases, + dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory *output, ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result); + Stream &ThenConvolveWithAlgorithm( const dnn::BatchDescriptor &input_descriptor, const DeviceMemory &input_data, diff --git a/tensorflow/tensorboard/BUILD b/tensorflow/tensorboard/BUILD index b5bff9eaf72..1eb5b124157 100644 --- a/tensorflow/tensorboard/BUILD +++ b/tensorflow/tensorboard/BUILD @@ -13,6 +13,9 @@ py_binary( deps = [ "//tensorflow/tensorboard/backend:application", "//tensorflow/tensorboard/backend/event_processing:event_file_inspector", + "//tensorflow/tensorboard/plugins/audio:audio_plugin", + "//tensorflow/tensorboard/plugins/distributions:distributions_plugin", + "//tensorflow/tensorboard/plugins/graphs:graphs_plugin", "//tensorflow/tensorboard/plugins/histograms:histograms_plugin", "//tensorflow/tensorboard/plugins/images:images_plugin", "//tensorflow/tensorboard/plugins/projector:projector_plugin", diff --git a/tensorflow/tensorboard/README.md b/tensorflow/tensorboard/README.md index f5d55690230..a9ab4d3bd2a 100644 --- a/tensorflow/tensorboard/README.md +++ b/tensorflow/tensorboard/README.md @@ -330,7 +330,9 @@ TensorBoard uses [reservoir sampling](https://en.wikipedia.org/wiki/Reservoir_sampling) to downsample your data so that it can be loaded into RAM. You can modify the number of elements it will keep per tag in -[tensorboard/backend/server.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/backend/server.py). +[tensorboard/backend/application.py](https://www.github.com/tensorflow/tensorflow/blob/r1.1/tensorflow/tensorboard/backend/application.py). +See this [StackOverflow question](http://stackoverflow.com/questions/43702546/tensorboard-doesnt-show-all-data-points/) +for some more information. ### I get a network security popup every time I run TensorBoard on a mac! diff --git a/tensorflow/tensorboard/backend/BUILD b/tensorflow/tensorboard/backend/BUILD index 3b5ce4c6e3e..c7f22b1b6ab 100644 --- a/tensorflow/tensorboard/backend/BUILD +++ b/tensorflow/tensorboard/backend/BUILD @@ -63,7 +63,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":http_util", - ":process_graph", "//tensorflow:tensorflow_py", "//tensorflow/tensorboard/backend/event_processing:event_accumulator", "//tensorflow/tensorboard/backend/event_processing:event_multiplexer", diff --git a/tensorflow/tensorboard/backend/application.py b/tensorflow/tensorboard/backend/application.py index cf1c376be08..9c492e7dd39 100644 --- a/tensorflow/tensorboard/backend/application.py +++ b/tensorflow/tensorboard/backend/application.py @@ -22,22 +22,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import csv import os import re import threading import time import six -from six import StringIO -from six.moves import urllib -from six.moves import xrange # pylint: disable=redefined-builtin from six.moves.urllib import parse as urlparse import tensorflow as tf from werkzeug import wrappers from tensorflow.tensorboard.backend import http_util -from tensorflow.tensorboard.backend import process_graph from tensorflow.tensorboard.backend.event_processing import event_accumulator from tensorflow.tensorboard.backend.event_processing import event_multiplexer @@ -59,8 +54,12 @@ DEFAULT_SIZE_GUIDANCE = { # Once everything has been migrated, we should be able to delete # /data/runs entirely. _MIGRATED_DATA_KEYS = frozenset(( + 'audio', + 'distributions', + 'graph', 'histograms', 'images', + 'run_metadata', 'scalars', )) @@ -69,11 +68,6 @@ LOGDIR_ROUTE = '/logdir' RUNS_ROUTE = '/runs' PLUGIN_PREFIX = '/plugin' PLUGINS_LISTING_ROUTE = '/plugins_listing' -AUDIO_ROUTE = '/' + event_accumulator.AUDIO -COMPRESSED_HISTOGRAMS_ROUTE = '/' + event_accumulator.COMPRESSED_HISTOGRAMS -INDIVIDUAL_AUDIO_ROUTE = '/individualAudio' -GRAPH_ROUTE = '/' + event_accumulator.GRAPH -RUN_METADATA_ROUTE = '/' + event_accumulator.RUN_METADATA TAB_ROUTES = ['', '/events', '/images', '/audio', '/graphs', '/histograms'] # Slashes in a plugin name could throw the router for a loop. An empty @@ -82,16 +76,6 @@ TAB_ROUTES = ['', '/events', '/images', '/audio', '/graphs', '/histograms'] _VALID_PLUGIN_RE = re.compile(r'^[A-Za-z0-9_.-]+$') -class _OutputFormat(object): - """An enum used to list the valid output formats for API calls. - - Not all API calls support all formats (for example, only scalars and - compressed histograms support CSV). - """ - JSON = 'json' - CSV = 'csv' - - def standard_tensorboard_wsgi( logdir, purge_orphaned_data, @@ -161,22 +145,12 @@ class TensorBoardWSGIApp(object): reload_multiplexer(self._multiplexer, path_to_run) self.data_applications = { - DATA_PREFIX + AUDIO_ROUTE: - self._serve_audio, - DATA_PREFIX + COMPRESSED_HISTOGRAMS_ROUTE: - self._serve_compressed_histograms, - DATA_PREFIX + GRAPH_ROUTE: - self._serve_graph, - DATA_PREFIX + INDIVIDUAL_AUDIO_ROUTE: - self._serve_individual_audio, DATA_PREFIX + LOGDIR_ROUTE: self._serve_logdir, # TODO(chizeng): Delete this RPC once we have skylark rules that obviate # the need for the frontend to determine which plugins are active. DATA_PREFIX + PLUGINS_LISTING_ROUTE: self._serve_plugins_listing, - DATA_PREFIX + RUN_METADATA_ROUTE: - self._serve_run_metadata, DATA_PREFIX + RUNS_ROUTE: self._serve_runs, } @@ -209,30 +183,6 @@ class TensorBoardWSGIApp(object): path = DATA_PREFIX + PLUGIN_PREFIX + '/' + plugin.plugin_name + route self.data_applications[path] = app - # We use underscore_names for consistency with inherited methods. - - def _audio_response_for_run(self, run_audio, run, tag): - """Builds a JSON-serializable object with information about run_audio. - - Args: - run_audio: A list of event_accumulator.AudioValueEvent objects. - run: The name of the run. - tag: The name of the tag the audio files all belong to. - - Returns: - A list of dictionaries containing the wall time, step, URL, and - content_type for each audio clip. - """ - response = [] - for index, run_audio_clip in enumerate(run_audio): - response.append({ - 'wall_time': run_audio_clip.wall_time, - 'step': run_audio_clip.step, - 'content_type': run_audio_clip.content_type, - 'query': self._query_for_individual_audio(run, tag, index) - }) - return response - def _path_is_safe(self, path): """Check path is safe (stays within current directory). @@ -257,141 +207,6 @@ class TensorBoardWSGIApp(object): return http_util.Respond( request, {'logdir': self._logdir}, 'application/json') - @wrappers.Request.application - def _serve_graph(self, request): - """Given a single run, return the graph definition in json format.""" - run = request.args.get('run', None) - if run is None: - return http_util.Respond( - request, 'query parameter "run" is required', 'text/plain', 400) - - try: - graph = self._multiplexer.Graph(run) - except ValueError: - return http_util.Respond( - request, '404 Not Found', 'text/plain; charset=UTF-8', code=404) - - limit_attr_size = request.args.get('limit_attr_size', None) - if limit_attr_size is not None: - try: - limit_attr_size = int(limit_attr_size) - except ValueError: - return http_util.Respond( - request, 'query parameter `limit_attr_size` must be integer', - 'text/plain', 400) - - large_attrs_key = request.args.get('large_attrs_key', None) - try: - process_graph.prepare_graph_for_ui(graph, limit_attr_size, - large_attrs_key) - except ValueError as e: - return http_util.Respond(request, e.message, 'text/plain', 400) - - return http_util.Respond(request, str(graph), 'text/x-protobuf') # pbtxt - - @wrappers.Request.application - def _serve_run_metadata(self, request): - """Given a tag and a TensorFlow run, return the session.run() metadata.""" - tag = request.args.get('tag', None) - run = request.args.get('run', None) - if tag is None: - return http_util.Respond( - request, 'query parameter "tag" is required', 'text/plain', 400) - if run is None: - return http_util.Respond( - request, 'query parameter "run" is required', 'text/plain', 400) - try: - run_metadata = self._multiplexer.RunMetadata(run, tag) - except ValueError: - return http_util.Respond( - request, '404 Not Found', 'text/plain; charset=UTF-8', code=404) - return http_util.Respond( - request, str(run_metadata), 'text/x-protobuf') # pbtxt - - @wrappers.Request.application - def _serve_compressed_histograms(self, request): - """Given a tag and single run, return an array of compressed histograms.""" - tag = request.args.get('tag') - run = request.args.get('run') - compressed_histograms = self._multiplexer.CompressedHistograms(run, tag) - if request.args.get('format') == _OutputFormat.CSV: - string_io = StringIO() - writer = csv.writer(string_io) - - # Build the headers; we have two columns for timing and two columns for - # each compressed histogram bucket. - headers = ['Wall time', 'Step'] - if compressed_histograms: - bucket_count = len(compressed_histograms[0].compressed_histogram_values) - for i in xrange(bucket_count): - headers += ['Edge %d basis points' % i, 'Edge %d value' % i] - writer.writerow(headers) - - for compressed_histogram in compressed_histograms: - row = [compressed_histogram.wall_time, compressed_histogram.step] - for value in compressed_histogram.compressed_histogram_values: - row += [value.rank_in_bps, value.value] - writer.writerow(row) - return http_util.Respond(request, string_io.getvalue(), 'text/csv') - else: - return http_util.Respond( - request, compressed_histograms, 'application/json') - - @wrappers.Request.application - def _serve_audio(self, request): - """Given a tag and list of runs, serve a list of audio. - - Note that the audio clips themselves are not sent; instead, we respond with - URLs to the audio. The frontend should treat these URLs as opaque and should - not try to parse information about them or generate them itself, as the - format may change. - - Args: - request: A werkzeug.wrappers.Request object. - - Returns: - A werkzeug.Response application. - """ - tag = request.args.get('tag') - run = request.args.get('run') - - audio_list = self._multiplexer.Audio(run, tag) - response = self._audio_response_for_run(audio_list, run, tag) - return http_util.Respond(request, response, 'application/json') - - @wrappers.Request.application - def _serve_individual_audio(self, request): - """Serves an individual audio clip.""" - tag = request.args.get('tag') - run = request.args.get('run') - index = int(request.args.get('index')) - audio = self._multiplexer.Audio(run, tag)[index] - return http_util.Respond( - request, audio.encoded_audio_string, audio.content_type) - - def _query_for_individual_audio(self, run, tag, index): - """Builds a URL for accessing the specified audio. - - This should be kept in sync with _serve_individual_audio. Note that the URL - is *not* guaranteed to always return the same audio, since audio may be - unloaded from the reservoir as new audio comes in. - - Args: - run: The name of the run. - tag: The tag. - index: The index of the audio. Negative values are OK. - - Returns: - A string representation of a URL that will load the index-th - sampled audio in the given run with the given tag. - """ - query_string = urllib.parse.urlencode({ - 'run': run, - 'tag': tag, - 'index': index - }) - return query_string - @wrappers.Request.application def _serve_plugins_listing(self, request): """Serves an object mapping plugin name to whether it is enabled. @@ -418,8 +233,7 @@ class TensorBoardWSGIApp(object): Returns: A werkzeug Response with the following content: - {runName: {audio: [tag4, tag5, tag6], - firstEventTimestamp: 123456.789}} + {runName: {firstEventTimestamp: 123456.789}} """ runs = self._multiplexer.Runs() for run_name, run_data in runs.items(): diff --git a/tensorflow/tensorboard/backend/application_test.py b/tensorflow/tensorboard/backend/application_test.py index 08f3485047a..87cfdbc1d8d 100644 --- a/tensorflow/tensorboard/backend/application_test.py +++ b/tensorflow/tensorboard/backend/application_test.py @@ -35,7 +35,6 @@ from six.moves import http_client import tensorflow as tf from werkzeug import serving -from google.protobuf import text_format from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.tensorboard import tensorboard @@ -167,12 +166,8 @@ class TensorboardServerTest(tf.test.TestCase): run_json, { 'run1': { - 'compressedHistograms': ['histogram'], - 'audio': ['audio'], # if only_use_meta_graph, the graph is from the metagraph - 'graph': True, 'meta_graph': self._only_use_meta_graph, - 'run_metadata': ['test run'], 'tensors': [], } }) @@ -193,8 +188,7 @@ class TensorboardServerTest(tf.test.TestCase): def testDataPaths_disableAllCaching(self): """Test the format of the /data/runs endpoint.""" - for path in ('/data/runs', '/data/logdir', '/data/audio?run=run1&tag=audio', - '/data/run_metadata?run=run1&tag=test%20run'): + for path in ('/data/runs', '/data/logdir'): connection = http_client.HTTPConnection('localhost', self._server.server_address[1]) connection.request('GET', path) @@ -204,88 +198,11 @@ class TensorboardServerTest(tf.test.TestCase): response.read() connection.close() - def testAudio(self): - """Test listing audio and retrieving an individual audio clip.""" - audio_json = self._getJson('/data/audio?tag=audio&run=run1') - audio_query = audio_json[0]['query'] - # We don't care about the format of the audio query. - del audio_json[0]['query'] - self.assertEqual(audio_json, [{ - 'wall_time': 0, - 'step': 0, - 'content_type': 'audio/wav' - }]) - response = self._get('/data/individualAudio?%s' % audio_query) - self.assertEqual(response.status, 200) - - def testGraph(self): - """Test retrieving the graph definition.""" - response = self._get('/data/graph?run=run1&limit_attr_size=1024' - '&large_attrs_key=_very_large_attrs') - self.assertEqual(response.status, 200) - graph_pbtxt = response.read() - # Parse the graph from pbtxt into a graph message. - graph = tf.GraphDef() - graph = text_format.Parse(graph_pbtxt, graph) - self.assertEqual(len(graph.node), 2) - self.assertEqual(graph.node[0].name, 'a') - self.assertEqual(graph.node[1].name, 'b') - # Make sure the second node has an attribute that was filtered out because - # it was too large and was added to the "too large" attributes list. - self.assertEqual(list(graph.node[1].attr.keys()), ['_very_large_attrs']) - self.assertEqual(graph.node[1].attr['_very_large_attrs'].list.s, - [b'very_large_attr']) - - def testAcceptGzip_compressesResponse(self): - response = self._get('/data/graph?run=run1&limit_attr_size=1024' - '&large_attrs_key=_very_large_attrs', - {'Accept-Encoding': 'gzip'}) - self.assertEqual(response.status, 200) - self.assertEqual(response.getheader('Content-Encoding'), 'gzip') - pbtxt = gzip.GzipFile('', 'rb', 9, BytesIO(response.read())).read() - graph = text_format.Parse(pbtxt, tf.GraphDef()) - self.assertEqual(len(graph.node), 2) - - def testAcceptAnyEncoding_compressesResponse(self): - response = self._get('/data/graph?run=run1&limit_attr_size=1024' - '&large_attrs_key=_very_large_attrs', - {'Accept-Encoding': '*'}) - self.assertEqual(response.status, 200) - self.assertEqual(response.getheader('Content-Encoding'), 'gzip') - pbtxt = gzip.GzipFile('', 'rb', 9, BytesIO(response.read())).read() - graph = text_format.Parse(pbtxt, tf.GraphDef()) - self.assertEqual(len(graph.node), 2) - - def testAcceptDoodleEncoding_doesNotCompressResponse(self): - response = self._get('/data/graph?run=run1&limit_attr_size=1024' - '&large_attrs_key=_very_large_attrs', - {'Accept-Encoding': 'doodle'}) - self.assertEqual(response.status, 200) - self.assertIsNone(response.getheader('Content-Encoding')) - graph = text_format.Parse(response.read(), tf.GraphDef()) - self.assertEqual(len(graph.node), 2) - - def testRunMetadata(self): - """Test retrieving the run metadata information.""" - response = self._get('/data/run_metadata?run=run1&tag=test%20run') - self.assertEqual(response.status, 200) - run_metadata_pbtxt = response.read() - # Parse from pbtxt into a message. - run_metadata = tf.RunMetadata() - text_format.Parse(run_metadata_pbtxt, run_metadata) - self.assertEqual(len(run_metadata.step_stats.dev_stats), 1) - self.assertEqual(run_metadata.step_stats.dev_stats[0].device, 'test device') - def _GenerateTestData(self): """Generates the test data directory. The test data has a single run named run1 which contains: - - a histogram [1] - - a graph definition - - [1]: Histograms no longer appear in `/runs`, but compressed - histograms do, and they use the same test data. Thus, histograms are - still here for now. + - a graph definition and metagraph definition Returns: temp_dir: The directory the test data is generated under. @@ -296,14 +213,6 @@ class TensorboardServerTest(tf.test.TestCase): os.makedirs(run1_path) writer = tf.summary.FileWriter(run1_path) - histogram_value = tf.HistogramProto( - min=0, - max=2, - num=3, - sum=6, - sum_squares=5, - bucket_limit=[0, 1, 2], - bucket=[1, 1, 1]) # Add a simple graph event. graph_def = tf.GraphDef() node1 = graph_def.node.add() @@ -319,27 +228,6 @@ class TensorboardServerTest(tf.test.TestCase): else: writer.add_graph(graph_def) - # Add a simple run metadata event. - run_metadata = tf.RunMetadata() - device_stats = run_metadata.step_stats.dev_stats.add() - device_stats.device = 'test device' - writer.add_run_metadata(run_metadata, 'test run') - - audio_value = tf.Summary.Audio( - sample_rate=44100, - length_frames=22050, - num_channels=2, - encoded_audio_string=b'', - content_type='audio/wav') - writer.add_event( - tf.Event( - wall_time=0, - step=0, - summary=tf.Summary(value=[ - tf.Summary.Value(tag='histogram', histo=histogram_value), - tf.Summary.Value(tag='audio', audio=audio_value) - ]))) - writer.flush() writer.close() diff --git a/tensorflow/tensorboard/backend/event_processing/directory_watcher_test.py b/tensorflow/tensorboard/backend/event_processing/directory_watcher_test.py index 190ae6a96b4..d44f74a8a43 100644 --- a/tensorflow/tensorboard/backend/event_processing/directory_watcher_test.py +++ b/tensorflow/tensorboard/backend/event_processing/directory_watcher_test.py @@ -24,7 +24,6 @@ import shutil import tensorflow as tf -from tensorflow.python.platform import googletest from tensorflow.tensorboard.backend.event_processing import directory_watcher from tensorflow.tensorboard.backend.event_processing import io_wrapper @@ -55,7 +54,7 @@ class DirectoryWatcherTest(tf.test.TestCase): os.mkdir(self._directory) self._watcher = directory_watcher.DirectoryWatcher(self._directory, _ByteLoader) - self.stubs = googletest.StubOutForTesting() + self.stubs = tf.test.StubOutForTesting() def tearDown(self): self.stubs.CleanUp() diff --git a/tensorflow/tensorboard/backend/event_processing/event_accumulator.py b/tensorflow/tensorboard/backend/event_processing/event_accumulator.py index 1669c060844..1562f0f8339 100644 --- a/tensorflow/tensorboard/backend/event_processing/event_accumulator.py +++ b/tensorflow/tensorboard/backend/event_processing/event_accumulator.py @@ -72,7 +72,7 @@ SUMMARY_TYPES = { ## The tagTypes below are just arbitrary strings chosen to pass the type ## information of the tag from the backend to the frontend -COMPRESSED_HISTOGRAMS = 'compressedHistograms' +COMPRESSED_HISTOGRAMS = 'distributions' HISTOGRAMS = 'histograms' IMAGES = 'images' AUDIO = 'audio' diff --git a/tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py b/tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py index a2ac371a931..9efd64bd2ef 100644 --- a/tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py +++ b/tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py @@ -24,7 +24,6 @@ import six from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf -from tensorflow.python.platform import googletest from tensorflow.python.summary.writer.writer import SummaryToEventTransformer from tensorflow.tensorboard.backend.event_processing import event_accumulator as ea @@ -182,7 +181,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): def setUp(self): super(MockingEventAccumulatorTest, self).setUp() - self.stubs = googletest.StubOutForTesting() + self.stubs = tf.test.StubOutForTesting() self._real_constructor = ea.EventAccumulator self._real_generator = ea._GeneratorFromPath diff --git a/tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py b/tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py index a7a6413ad1f..ea536dfaad6 100644 --- a/tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py +++ b/tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py @@ -24,7 +24,6 @@ import shutil import tensorflow as tf -from tensorflow.python.platform import googletest from tensorflow.tensorboard.backend.event_processing import event_accumulator from tensorflow.tensorboard.backend.event_processing import event_multiplexer @@ -116,7 +115,7 @@ class EventMultiplexerTest(tf.test.TestCase): def setUp(self): super(EventMultiplexerTest, self).setUp() - self.stubs = googletest.StubOutForTesting() + self.stubs = tf.test.StubOutForTesting() self.stubs.Set(event_accumulator, 'EventAccumulator', _GetFakeAccumulator) diff --git a/tensorflow/tensorboard/components/BUILD b/tensorflow/tensorboard/components/BUILD index 1cc2b7caf28..6a0052b793d 100644 --- a/tensorflow/tensorboard/components/BUILD +++ b/tensorflow/tensorboard/components/BUILD @@ -1,21 +1,18 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") load("//tensorflow/tensorboard:vulcanize.bzl", "tensorboard_html_binary") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tensorboard", srcs = [ "analytics.html", "tensorboard.html", ], path = "/", - deps = [ - "//tensorflow/tensorboard/components/tf_tensorboard", - "@org_polymer_webcomponentsjs", - ], + deps = ["//tensorflow/tensorboard/components/tf_tensorboard"], ) tensorboard_html_binary( diff --git a/tensorflow/tensorboard/components/tensorboard.html b/tensorflow/tensorboard/components/tensorboard.html index 0652902cfac..afaf396614f 100644 --- a/tensorflow/tensorboard/components/tensorboard.html +++ b/tensorflow/tensorboard/components/tensorboard.html @@ -19,7 +19,6 @@ limitations under the License. TensorBoard - diff --git a/tensorflow/tensorboard/components/tf_audio_dashboard/BUILD b/tensorflow/tensorboard/components/tf_audio_dashboard/BUILD index 1e599cb710f..18009043d23 100644 --- a/tensorflow/tensorboard/components/tf_audio_dashboard/BUILD +++ b/tensorflow/tensorboard/components/tf_audio_dashboard/BUILD @@ -1,10 +1,10 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_audio_dashboard", srcs = [ "tf-audio-dashboard.html", @@ -17,7 +17,7 @@ web_library( "//tensorflow/tensorboard/components/tf_dashboard_common", "//tensorflow/tensorboard/components/tf_imports:d3", "//tensorflow/tensorboard/components/tf_imports:lodash", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", "@org_polymer_paper_icon_button", "@org_polymer_paper_slider", "@org_polymer_paper_spinner", @@ -25,7 +25,7 @@ web_library( ], ) -web_library( +ts_web_library( name = "index", srcs = [ "demo/index.html", @@ -35,11 +35,11 @@ web_library( deps = [ ":tf_audio_dashboard", "//tensorflow/tensorboard/components/tf_imports:d3", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "//tensorflow/tensorboard/demo:demo_data", "@org_polymer_iron_component_page", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", ], ) diff --git a/tensorflow/tensorboard/components/tf_audio_dashboard/demo/index.html b/tensorflow/tensorboard/components/tf_audio_dashboard/demo/index.html index 177bc85db0d..dc8cd91d439 100644 --- a/tensorflow/tensorboard/components/tf_audio_dashboard/demo/index.html +++ b/tensorflow/tensorboard/components/tf_audio_dashboard/demo/index.html @@ -42,14 +42,16 @@ limitations under the License. + + + diff --git a/tensorflow/tensorboard/components/tf_backend/tf-backend.html b/tensorflow/tensorboard/components/tf_backend/tf-backend.html index 5bf26633628..4cfed247a5e 100644 --- a/tensorflow/tensorboard/components/tf_backend/tf-backend.html +++ b/tensorflow/tensorboard/components/tf_backend/tf-backend.html @@ -20,4 +20,8 @@ limitations under the License. - + + + + + diff --git a/tensorflow/tensorboard/components/tf_color_scale/BUILD b/tensorflow/tensorboard/components/tf_color_scale/BUILD index 3ec3d26051f..afe98ec5b5e 100644 --- a/tensorflow/tensorboard/components/tf_color_scale/BUILD +++ b/tensorflow/tensorboard/components/tf_color_scale/BUILD @@ -1,57 +1,37 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_color_scale", srcs = [ - "bundle.js", + "colorScale.ts", + "palettes.ts", "tf-color-scale.html", ], path = "/tf-color-scale", deps = [ "//tensorflow/tensorboard/components/tf_imports:d3", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", ], ) -web_library( +ts_web_library( name = "demo", srcs = ["index.html"], path = "/tf-color-scale", deps = [ ":tf_color_scale", "//tensorflow/tensorboard/components/tf_imports:d3", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_button", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", ], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["bundle.ts"], - typings = [ - "@org_definitelytyped//:polymer.d.ts", - "@org_definitelytyped//:webcomponents.js.d.ts", - "//tensorflow/tensorboard/components/tf_imports:d3.d.ts", - ], -) - -tensorboard_typescript_bundle( - name = "bundle", - out = "bundle.ts", - namespace_srcs = {"TF": [ - "palettes.ts", - "colorScale.ts", - ]}, -) - filegroup( name = "all_files", srcs = glob(["**"]), diff --git a/tensorflow/tensorboard/components/tf_color_scale/colorScale.ts b/tensorflow/tensorboard/components/tf_color_scale/colorScale.ts index ff90d46aa24..6916e3bb2dd 100644 --- a/tensorflow/tensorboard/components/tf_color_scale/colorScale.ts +++ b/tensorflow/tensorboard/components/tf_color_scale/colorScale.ts @@ -19,9 +19,8 @@ limitations under the License. // ccs.domain(runs); // ccs.getColor("train"); // ccs.getColor("test1"); -import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 -import {palettes} from './palettes' +import {palettes} from './palettes' export class ColorScale { private palette: string[]; @@ -29,8 +28,8 @@ export class ColorScale { /** * Creates a color scale with optional custom palette. - * @param {string[]} [palette=palettes.googleColorBlind] - The color - * palette you want as an Array of hex strings. + * @param {Array} [palette=palettes.googleColorBlind] - The color + * palette you want as an Array of hex strings. */ constructor(palette: string[] = palettes.googleColorBlindAssist) { this.palette = palette; @@ -38,8 +37,8 @@ export class ColorScale { /** * Set the domain of strings. - * @param {string[]} strings - An array of possible strings to use as the - * domain for your scale. + * @param {Array} strings - An array of possible strings to use as the + * domain for your scale. */ public domain(strings: string[]): this { this.identifiers = d3.map(); diff --git a/tensorflow/tensorboard/components/tf_color_scale/test/BUILD b/tensorflow/tensorboard/components/tf_color_scale/test/BUILD index 6071b20886e..dab2779dc3c 100644 --- a/tensorflow/tensorboard/components/tf_color_scale/test/BUILD +++ b/tensorflow/tensorboard/components/tf_color_scale/test/BUILD @@ -3,43 +3,25 @@ package( default_visibility = ["//tensorflow:internal"], ) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "test", srcs = [ - "bundle.js", + "colorScaleTests.ts", "tests.html", ], path = "/tf-color-scale/test", deps = [ "//tensorflow/tensorboard/components/tf_color_scale", - "@org_npmjs_registry_web_component_tester", - "@org_polymer", - "@org_polymer_webcomponentsjs", + "//tensorflow/tensorboard/components/tf_imports:polymer", + "//tensorflow/tensorboard/components/tf_imports:web_component_tester", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", ], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["bundle.ts"], - typings = [ - "@org_definitelytyped//:chai.d.ts", - "@org_definitelytyped//:mocha.d.ts", - "//tensorflow/tensorboard/components/tf_color_scale:bundle.d.ts", - ], -) - -tensorboard_typescript_bundle( - name = "bundle", - out = "bundle.ts", - namespace_srcs = {"TF": ["colorScaleTests.ts"]}, -) - filegroup( name = "all_files", testonly = 0, diff --git a/tensorflow/tensorboard/components/tf_color_scale/test/tests.html b/tensorflow/tensorboard/components/tf_color_scale/test/tests.html index eccc32cdec5..59c802d02bf 100644 --- a/tensorflow/tensorboard/components/tf_color_scale/test/tests.html +++ b/tensorflow/tensorboard/components/tf_color_scale/test/tests.html @@ -21,4 +21,4 @@ limitations under the License. - + diff --git a/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html b/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html index 3dedfaf1a1c..a325f0a04cd 100644 --- a/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html +++ b/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html @@ -26,5 +26,6 @@ a set of colors. @element tf-color-scale --> - + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/BUILD b/tensorflow/tensorboard/components/tf_dashboard_common/BUILD index f9a990e3799..b504fe79f99 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/BUILD +++ b/tensorflow/tensorboard/components/tf_dashboard_common/BUILD @@ -1,33 +1,33 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_dashboard_common", srcs = [ + "dashboard-behavior.ts", "dashboard-style.html", + "reload-behavior.ts", "run-color-style.html", "scrollbar-style.html", "tensorboard-color.html", "tf-categorizer.html", - "tf-categorizer-bundle.js", + "tf-categorizer.ts", "tf-chart-scaffold.html", "tf-collapsable-pane.html", "tf-dashboard.html", - "tf-dashboard.js", "tf-dashboard-layout.html", "tf-downloader.html", "tf-multi-checkbox.html", - "tf-multi-checkbox-bundle.js", + "tf-multi-checkbox.ts", "tf-no-data-warning.html", "tf-option-selector.html", "tf-panes-helper.html", "tf-regex-group.html", - "tf-regex-group-bundle.js", + "tf-regex-group.ts", "tf-run-selector.html", "tf-sidebar-helper.html", ], @@ -35,9 +35,9 @@ web_library( deps = [ "//tensorflow/tensorboard/components/tf_imports:d3", "//tensorflow/tensorboard/components/tf_imports:lodash", + "//tensorflow/tensorboard/components/tf_imports:polymer", "//tensorflow/tensorboard/components/tf_storage", "//tensorflow/tensorboard/components/vz_sorting", - "@org_polymer", "@org_polymer_iron_ajax", "@org_polymer_iron_collapse", "@org_polymer_iron_icons", @@ -56,7 +56,7 @@ web_library( ], ) -web_library( +ts_web_library( name = "demo", srcs = [ "tf-categorizer-demo.html", @@ -73,91 +73,9 @@ web_library( ], ) -tensorboard_typescript_bundle( - name = "tf_categorizer_bundle", - out = "tf-categorizer-bundle.ts", - namespace_srcs = {"TF.Dashboard.Categorizer": ["tf-categorizer.ts"]}, - namespace_symbol_aliases = {"TF.Dashboard.Categorizer": {"compareTagNames": "VZ.Sorting.compareTagNames"}}, -) - -tensorboard_typescript_genrule( - name = "tf_categorizer_ts", - srcs = ["tf-categorizer-bundle.ts"], - typings = [ - "@org_definitelytyped//:lodash.d.ts", - "@org_definitelytyped//:polymer.d.ts", - "@org_definitelytyped//:webcomponents.js.d.ts", - "//tensorflow/tensorboard/components/tf_imports:d3.d.ts", - "//tensorflow/tensorboard/components/vz_sorting:bundle.d.ts", - ], -) - -tensorboard_typescript_bundle( - name = "tf_regex_group_bundle", - out = "tf-regex-group-bundle.ts", - namespace_srcs = {"TF.Dashboard.RegexGroup": ["tf-regex-group.ts"]}, - namespace_symbol_aliases = {"TF.Dashboard.RegexGroup": {"storage": "TF.URIStorage"}}, -) - -tensorboard_typescript_genrule( - name = "tf_regex_group_ts", - srcs = ["tf-regex-group-bundle.ts"], - typings = [ - "@org_definitelytyped//:polymer.d.ts", - "@org_definitelytyped//:webcomponents.js.d.ts", - "//tensorflow/tensorboard/components/tf_storage:bundle.d.ts", - ], -) - -tensorboard_typescript_bundle( - name = "tf_multi_checkbox_bundle", - out = "tf-multi-checkbox-bundle.ts", - namespace_srcs = {"TF.Dashboard.MultiCheckbox": ["tf-multi-checkbox.ts"]}, - namespace_symbol_aliases = {"TF.Dashboard.MultiCheckbox": {"storage": "TF.URIStorage"}}, -) - -tensorboard_typescript_genrule( - name = "tf_multi_checkbox_ts", - srcs = ["tf-multi-checkbox-bundle.ts"], - typings = [ - "@org_definitelytyped//:lodash.d.ts", - "@org_definitelytyped//:polymer.d.ts", - "@org_definitelytyped//:webcomponents.js.d.ts", - "//tensorflow/tensorboard/components/tf_storage:bundle.d.ts", - ], -) - -tensorboard_typescript_bundle( - name = "tf_dashboard_bundle", - out = "tf-dashboard.ts", - namespace_srcs = { - "TF.Dashboard": [ - "dashboard-behavior.ts", - "reload-behavior.ts", - ], - }, -) - -tensorboard_typescript_genrule( - name = "tf_dashboard_ts", - srcs = ["tf-dashboard.ts"], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) - -################################################################################ -# MARKED FOR DELETION - -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") - tensorboard_webcomponent_library( name = "legacy", - srcs = glob(["*.html"]) + [":legacy_ts"], + srcs = [":tf_dashboard_common"], destdir = "tf-dashboard-common", deps = [ "//tensorflow/tensorboard/components/tf_imports_google:lib", @@ -182,19 +100,8 @@ tensorboard_webcomponent_library( ], ) -tensorboard_ts_library( - name = "legacy_ts", - srcs = [ - "dashboard-behavior.ts", - "reload-behavior.ts", - "tf-categorizer.ts", - ], - deps_mgmt = "off", - runtime = "nodejs", - deps = [ - "//tensorflow/tensorboard/components/vz_sorting:legacy_ts", - "//third_party/javascript/typings/d3_v4:bundle", - "//third_party/javascript/typings/lodash", - "//third_party/javascript/typings/polymer:polymer_without_externs", - ], +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], ) diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/dashboard-behavior.ts b/tensorflow/tensorboard/components/tf_dashboard_common/dashboard-behavior.ts index 3e40da14528..aa063c74220 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/dashboard-behavior.ts +++ b/tensorflow/tensorboard/components/tf_dashboard_common/dashboard-behavior.ts @@ -16,6 +16,8 @@ limitations under the License. /** * A behavior that TensorBoard dashboards must implement. This behavior serves * the purpose of an interface. + * + * @polymerBehavior */ export function DashboardBehavior(dashboardName) { return { diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/reload-behavior.ts b/tensorflow/tensorboard/components/tf_dashboard_common/reload-behavior.ts index 8b5ca120d60..61fe0c07812 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/reload-behavior.ts +++ b/tensorflow/tensorboard/components/tf_dashboard_common/reload-behavior.ts @@ -20,6 +20,8 @@ limitations under the License. * and call a `reload` method on that child. * May later extend it so it has more sophisticated logic, e.g. reloading * only tags that are in view. + * + * @polymerBehavior */ export function ReloadBehavior(tagName) { return { diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/test/BUILD b/tensorflow/tensorboard/components/tf_dashboard_common/test/BUILD index e82c4bd63cd..3cad646b967 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/test/BUILD +++ b/tensorflow/tensorboard/components/tf_dashboard_common/test/BUILD @@ -3,44 +3,25 @@ package( default_visibility = ["//tensorflow:internal"], ) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "test", srcs = [ - "bundle.js", "tests.html", + "tf-categorizer-tests.ts", ], path = "/tf-dashboard-common/test", deps = [ "//tensorflow/tensorboard/components/tf_dashboard_common", - "@org_npmjs_registry_web_component_tester", - "@org_polymer", - "@org_polymer_webcomponentsjs", + "//tensorflow/tensorboard/components/tf_imports:polymer", + "//tensorflow/tensorboard/components/tf_imports:web_component_tester", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", ], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["bundle.ts"], - typings = [ - "@org_definitelytyped//:chai.d.ts", - "@org_definitelytyped//:mocha.d.ts", - "//tensorflow/tensorboard/components/tf_dashboard_common:tf-categorizer-bundle.d.ts", - ], -) - -tensorboard_typescript_bundle( - name = "bundle", - out = "bundle.ts", - namespace_srcs = {"TF.Dashboard": ["tf-categorizer-tests.ts"]}, - namespace_symbol_aliases = {"TF.Dashboard": {"cat": "TF.Dashboard.Categorizer"}}, -) - filegroup( name = "all_files", testonly = 0, diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/test/tests.html b/tensorflow/tensorboard/components/tf_dashboard_common/test/tests.html index cd33cee4742..c9ad14730f0 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/test/tests.html +++ b/tensorflow/tensorboard/components/tf_dashboard_common/test/tests.html @@ -21,4 +21,4 @@ limitations under the License. - + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.html index 6388ab5e7d4..f09eb03582d 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.html +++ b/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.html @@ -59,5 +59,5 @@ categories are exclusive. } - + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.ts b/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.ts index ebece842461..0eaf852ff13 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.ts +++ b/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.ts @@ -13,10 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 -import * as _ from 'lodash'; - -import {compareTagNames} from '../vz_sorting/sorting'; +import {compareTagNames} from '../vz-sorting/sorting'; /** * This module contains methods that allow sorting tags into 'categories'. diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-dashboard.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-dashboard.html index 475c2cef3bd..9e2f6b9589b 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-dashboard.html +++ b/tensorflow/tensorboard/components/tf_dashboard_common/tf-dashboard.html @@ -22,4 +22,5 @@ limitations under the License. - + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.html index 8a56616f820..fad4642963f 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.html +++ b/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.html @@ -156,5 +156,5 @@ handle these situations gracefully. } - + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.ts b/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.ts index 44a14a21cfe..4b38d82b14e 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.ts +++ b/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.ts @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import * as _ from 'lodash'; import * as storage from '../tf-storage/storage'; Polymer({ diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-regex-group.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-regex-group.html index e68b306ee33..c1d3cf06aea 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-regex-group.html +++ b/tensorflow/tensorboard/components/tf_dashboard_common/tf-regex-group.html @@ -95,5 +95,5 @@ more regexes). - + diff --git a/tensorflow/tensorboard/components/tf_distribution_dashboard/BUILD b/tensorflow/tensorboard/components/tf_distribution_dashboard/BUILD index dcd5047bf49..fe089b80b42 100644 --- a/tensorflow/tensorboard/components/tf_distribution_dashboard/BUILD +++ b/tensorflow/tensorboard/components/tf_distribution_dashboard/BUILD @@ -1,10 +1,10 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_distribution_dashboard", srcs = ["tf-distribution-dashboard.html"], path = "/tf-distribution-dashboard", @@ -13,24 +13,24 @@ web_library( "//tensorflow/tensorboard/components/tf_color_scale", "//tensorflow/tensorboard/components/tf_dashboard_common", "//tensorflow/tensorboard/components/tf_imports:lodash", + "//tensorflow/tensorboard/components/tf_imports:polymer", "//tensorflow/tensorboard/components/vz_distribution_chart", - "@org_polymer", "@org_polymer_iron_collapse", "@org_polymer_paper_icon_button", "@org_polymer_paper_styles", ], ) -web_library( +ts_web_library( name = "demo", srcs = ["index.html"] + glob(["data/**"]), path = "/tf-distribution-dashboard", deps = [ ":tf_distribution_dashboard", "//tensorflow/tensorboard/components/tf_imports:d3", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", ], ) diff --git a/tensorflow/tensorboard/components/tf_distribution_dashboard/index.html b/tensorflow/tensorboard/components/tf_distribution_dashboard/index.html index 5e825f13f5c..2c300446480 100644 --- a/tensorflow/tensorboard/components/tf_distribution_dashboard/index.html +++ b/tensorflow/tensorboard/components/tf_distribution_dashboard/index.html @@ -43,14 +43,17 @@ limitations under the License. diff --git a/tensorflow/tensorboard/components/tf_globals/BUILD b/tensorflow/tensorboard/components/tf_globals/BUILD index ca59c2fb93a..0ffefd79682 100644 --- a/tensorflow/tensorboard/components/tf_globals/BUILD +++ b/tensorflow/tensorboard/components/tf_globals/BUILD @@ -1,29 +1,23 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_globals", srcs = [ - "bundle.js", + "globals.ts", "tf-globals.html", ], path = "/tf-globals", ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["bundle.ts"], -) - -tensorboard_typescript_bundle( - name = "bundle", - out = "bundle.ts", - namespace_srcs = {"TF.Globals": ["globals.ts"]}, +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":tf_globals"], + destdir = "tf-globals", ) filegroup( @@ -31,25 +25,3 @@ filegroup( srcs = glob(["**"]), tags = ["notsan"], ) - -################################################################################ -# MARKED FOR DELETION - -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [ - "tf-globals.html", - ":legacy_ts", - ], - destdir = "tf-globals", -) - -tensorboard_ts_library( - name = "legacy_ts", - srcs = ["globals.ts"], - deps_mgmt = "off", - runtime = "nodejs", -) diff --git a/tensorflow/tensorboard/components/tf_globals/globals.ts b/tensorflow/tensorboard/components/tf_globals/globals.ts index 7d4229dccb0..fb6bb83b97f 100644 --- a/tensorflow/tensorboard/components/tf_globals/globals.ts +++ b/tensorflow/tensorboard/components/tf_globals/globals.ts @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - - // The names of TensorBoard tabs. export const TABS = [ 'scalars', 'images', 'audio', 'graphs', 'distributions', 'histograms', diff --git a/tensorflow/tensorboard/components/tf_globals/tf-globals.html b/tensorflow/tensorboard/components/tf_globals/tf-globals.html index b0fd74d4f20..efb8e92e080 100644 --- a/tensorflow/tensorboard/components/tf_globals/tf-globals.html +++ b/tensorflow/tensorboard/components/tf_globals/tf-globals.html @@ -15,5 +15,5 @@ See the License for the specific language governing permissions and limitations under the License. --> - + diff --git a/tensorflow/tensorboard/components/tf_graph/BUILD b/tensorflow/tensorboard/components/tf_graph/BUILD index 115964a59bd..92d2e8a42a1 100644 --- a/tensorflow/tensorboard/components/tf_graph/BUILD +++ b/tensorflow/tensorboard/components/tf_graph/BUILD @@ -1,10 +1,11 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_graph", srcs = [ "tf-graph.html", @@ -15,7 +16,7 @@ web_library( deps = [ "//tensorflow/tensorboard/components/tf_dashboard_common", "//tensorflow/tensorboard/components/tf_graph_common", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", "@org_polymer_iron_flex_layout", "@org_polymer_iron_icons", "@org_polymer_paper_button", @@ -28,26 +29,28 @@ web_library( ], ) +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":tf_graph"], + destdir = "tf-graph", + deps = [ + "//tensorflow/tensorboard/components/tf_dashboard_common:legacy", + "//tensorflow/tensorboard/components/tf_graph_common:legacy", + "//third_party/javascript/polymer/v1/iron-flex-layout:lib", + "//third_party/javascript/polymer/v1/iron-icons:lib", + "//third_party/javascript/polymer/v1/paper-button:lib", + "//third_party/javascript/polymer/v1/paper-dropdown-menu:lib", + "//third_party/javascript/polymer/v1/paper-input:lib", + "//third_party/javascript/polymer/v1/paper-menu:lib", + "//third_party/javascript/polymer/v1/paper-radio-group:lib", + "//third_party/javascript/polymer/v1/paper-toggle-button:lib", + "//third_party/javascript/polymer/v1/paper-tooltip:lib", + "//third_party/javascript/polymer/v1/polymer:lib", + ], +) + filegroup( name = "all_files", srcs = glob(["**"]), tags = ["notsan"], ) - -################################################################################ -# MARKED FOR DELETION - -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [ - "tf-graph.html", - "tf-graph-minimap.html", - "tf-graph-scene.html", - ], - destdir = "tf-graph", - deps = [ - "//tensorflow/tensorboard/components/tf_graph_common:legacy", - ], -) diff --git a/tensorflow/tensorboard/components/tf_graph/demo/BUILD b/tensorflow/tensorboard/components/tf_graph/demo/BUILD index 524d0ff7679..b578a51798b 100644 --- a/tensorflow/tensorboard/components/tf_graph/demo/BUILD +++ b/tensorflow/tensorboard/components/tf_graph/demo/BUILD @@ -1,11 +1,11 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 # bazel run //third_party/tensorflow/tensorboard/components/tf_graph/demo -web_library( +ts_web_library( name = "demo", srcs = ["index.html"] + glob(["data/**"]), path = "/tf-graph/demo", @@ -13,9 +13,9 @@ web_library( "//tensorflow/tensorboard/components/tf_graph", "//tensorflow/tensorboard/components/tf_graph_common", "//tensorflow/tensorboard/components/tf_graph_loader", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", ], ) diff --git a/tensorflow/tensorboard/components/tf_graph/tf-graph-scene.html b/tensorflow/tensorboard/components/tf_graph/tf-graph-scene.html index 35705713b98..65306acf8bf 100644 --- a/tensorflow/tensorboard/components/tf_graph/tf-graph-scene.html +++ b/tensorflow/tensorboard/components/tf_graph/tf-graph-scene.html @@ -941,7 +941,7 @@ Polymer({ delete this._nodeGroupIndex[n]; }, addEdgeGroup: function(n, selection) { - this._edgeGroupIndex[e] = selection; + this._edgeGroupIndex[n] = selection; }, getEdgeGroup: function(e) { return this._edgeGroupIndex[e]; diff --git a/tensorflow/tensorboard/components/tf_graph_app/BUILD b/tensorflow/tensorboard/components/tf_graph_app/BUILD index 415b20598ec..af568893821 100644 --- a/tensorflow/tensorboard/components/tf_graph_app/BUILD +++ b/tensorflow/tensorboard/components/tf_graph_app/BUILD @@ -1,11 +1,11 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_graph_app", srcs = [ "index.html", @@ -16,9 +16,23 @@ web_library( "//tensorflow/tensorboard/components/tf_graph_board", "//tensorflow/tensorboard/components/tf_graph_controls", "//tensorflow/tensorboard/components/tf_graph_loader", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "@org_polymer_iron_component_page", - "@org_polymer_webcomponentsjs", + ], +) + +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":tf_graph_app"], + destdir = "tf-graph-app", + deps = [ + "//tensorflow/tensorboard/components/tf_graph_board:legacy", + "//tensorflow/tensorboard/components/tf_graph_controls:legacy", + "//tensorflow/tensorboard/components/tf_graph_loader:legacy", + "//third_party/javascript/polymer/v1/iron-component-page:lib", + "//third_party/javascript/polymer/v1/polymer:lib", + "//third_party/javascript/polymer/v1/webcomponentsjs:lib", ], ) @@ -27,23 +41,3 @@ filegroup( srcs = glob(["**"]), tags = ["notsan"], ) - -################################################################################ -# MARKED FOR DELETION - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [ - "index.html", - "tf-graph-app.html", - ], - destdir = "tf-graph-app", - deps = [ - "//tensorflow/tensorboard/components/tf_graph_board:legacy", - "//tensorflow/tensorboard/components/tf_graph_controls:legacy", - "//tensorflow/tensorboard/components/tf_graph_loader:legacy", - "//third_party/javascript/polymer/v1/iron-list:lib", - "//third_party/javascript/polymer/v1/paper-radio-group:lib", - "//third_party/javascript/polymer/v1/paper-tooltip:lib", - ], -) diff --git a/tensorflow/tensorboard/components/tf_graph_app/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_app/demo/BUILD index 147cb0947c4..0f984664ce2 100644 --- a/tensorflow/tensorboard/components/tf_graph_app/demo/BUILD +++ b/tensorflow/tensorboard/components/tf_graph_app/demo/BUILD @@ -1,11 +1,11 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 # bazel run //third_party/tensorflow/tensorboard/components/tf_graph_app/demo -web_library( +ts_web_library( name = "demo", srcs = ["index.html"] + glob(["data/**"]), path = "/tf-graph-app/demo", diff --git a/tensorflow/tensorboard/components/tf_graph_board/BUILD b/tensorflow/tensorboard/components/tf_graph_board/BUILD index f1c1ed1fc0f..14a66166582 100644 --- a/tensorflow/tensorboard/components/tf_graph_board/BUILD +++ b/tensorflow/tensorboard/components/tf_graph_board/BUILD @@ -1,44 +1,38 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_graph_board", - srcs = [ - "tf-graph-board.html", - ], + srcs = ["tf-graph-board.html"], path = "/tf-graph-board", deps = [ "//tensorflow/tensorboard/components/tf_graph", "//tensorflow/tensorboard/components/tf_graph_common", "//tensorflow/tensorboard/components/tf_graph_info", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", "@org_polymer_paper_progress", ], ) +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":tf_graph_board"], + destdir = "tf-graph-board", + deps = [ + "//tensorflow/tensorboard/components/tf_graph:legacy", + "//tensorflow/tensorboard/components/tf_graph_common:legacy", + "//tensorflow/tensorboard/components/tf_graph_info:legacy", + "//third_party/javascript/polymer/v1/paper-progress:lib", + "//third_party/javascript/polymer/v1/polymer:lib", + ], +) + filegroup( name = "all_files", srcs = glob(["**"]), tags = ["notsan"], ) - -################################################################################ -# MARKED FOR DELETION - -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [ - "tf-graph-board.html", - ], - destdir = "tf-graph-board", - deps = [ - "//tensorflow/tensorboard/components/tf_graph:legacy", - "//tensorflow/tensorboard/components/tf_graph_common:legacy", - "//tensorflow/tensorboard/components/tf_graph_info:legacy", - ], -) diff --git a/tensorflow/tensorboard/components/tf_graph_board/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_board/demo/BUILD index 2d668769e62..4bf52c5a567 100644 --- a/tensorflow/tensorboard/components/tf_graph_board/demo/BUILD +++ b/tensorflow/tensorboard/components/tf_graph_board/demo/BUILD @@ -1,11 +1,11 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 # bazel run //third_party/tensorflow/tensorboard/components/tf_graph_board/demo -web_library( +ts_web_library( name = "demo", srcs = ["index.html"] + glob(["data/**"]), path = "/tf-graph-board/demo", @@ -13,9 +13,9 @@ web_library( "//tensorflow/tensorboard/components/tf_graph_board", "//tensorflow/tensorboard/components/tf_graph_common", "//tensorflow/tensorboard/components/tf_graph_loader", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", ], ) diff --git a/tensorflow/tensorboard/components/tf_graph_board/tf-graph-board.html b/tensorflow/tensorboard/components/tf_graph_board/tf-graph-board.html index 0ee694e1e66..79409ce2a0c 100644 --- a/tensorflow/tensorboard/components/tf_graph_board/tf-graph-board.html +++ b/tensorflow/tensorboard/components/tf_graph_board/tf-graph-board.html @@ -180,10 +180,9 @@ Polymer({ graph: Object, stats: Object, /** - * @type {value: number, msg: string} - * * A number between 0 and 100 denoting the % of progress * for the progress bar and the displayed message. + * @type {{value: number, msg: string}} */ progress: Object, colorBy: String, diff --git a/tensorflow/tensorboard/components/tf_graph_common/BUILD b/tensorflow/tensorboard/components/tf_graph_common/BUILD index a372ab8279b..25e0403aa34 100644 --- a/tensorflow/tensorboard/components/tf_graph_common/BUILD +++ b/tensorflow/tensorboard/components/tf_graph_common/BUILD @@ -1,15 +1,31 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_graph_common", srcs = [ + "annotation.ts", + "colors.ts", + "common.ts", + "contextmenu.ts", + "edge.ts", + "externs.ts", + "graph.ts", + "hierarchy.ts", + "layout.ts", + "minimap.ts", + "node.ts", + "parser.ts", + "proto.ts", + "render.ts", + "scene.ts", + "template.ts", "tf-graph-common.html", - ":ts", + "util.ts", ], path = "/tf-graph-common", deps = [ @@ -17,18 +33,17 @@ web_library( "//tensorflow/tensorboard/components/tf_imports:dagre", "//tensorflow/tensorboard/components/tf_imports:graphlib", "//tensorflow/tensorboard/components/tf_imports:lodash", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", ], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = glob(["*.ts"]), - typings = [ - "//tensorflow/tensorboard/components/tf_imports:d3.d.ts", - "@org_definitelytyped//:lodash.d.ts", - "@org_definitelytyped//:polymer.d.ts", - "@org_definitelytyped//:webcomponents.js.d.ts", +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":tf_graph_common"], + destdir = "tf-graph-common", + deps = [ + "//tensorflow/tensorboard/components/tf_imports_google:lib", + "//third_party/javascript/polymer/v1/polymer:lib", ], ) @@ -37,36 +52,3 @@ filegroup( srcs = glob(["**"]), tags = ["notsan"], ) - -################################################################################ -# MARKED FOR DELETION - -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [ - "tf-graph-common.html", - ":legacy_ts", - ], - destdir = "tf-graph-common", - deps = [ - "//tensorflow/tensorboard/components/tf_imports_google:lib", - "//third_party/javascript/polymer/v1/polymer:lib", - ], -) - -tensorboard_ts_library( - name = "legacy_ts", - srcs = glob(["*.ts"]), - deps_mgmt = "off", - runtime = "nodejs", - deps = [ - "//third_party/javascript/node_modules/typescript:es2015.promise", - "//third_party/javascript/typings/d3_v4:bundle", - "//third_party/javascript/typings/lodash", - "//third_party/javascript/typings/polymer:polymer_without_externs", - "//third_party/javascript/typings/webcomponents_js", - ], -) diff --git a/tensorflow/tensorboard/components/tf_graph_controls/BUILD b/tensorflow/tensorboard/components/tf_graph_controls/BUILD index 65cafa9570a..7004b7145a3 100644 --- a/tensorflow/tensorboard/components/tf_graph_controls/BUILD +++ b/tensorflow/tensorboard/components/tf_graph_controls/BUILD @@ -1,19 +1,18 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_graph_controls", - srcs = [ - "tf-graph-controls.html", - ], + srcs = ["tf-graph-controls.html"], path = "/tf-graph-controls", deps = [ "//tensorflow/tensorboard/components/tf_dashboard_common", "//tensorflow/tensorboard/components/tf_graph_common", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", "@org_polymer_paper_button", "@org_polymer_paper_dropdown_menu", "@org_polymer_paper_menu", @@ -23,25 +22,25 @@ web_library( ], ) +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":tf_graph_controls"], + destdir = "tf-graph-controls", + deps = [ + "//tensorflow/tensorboard/components/tf_dashboard_common:legacy", + "//tensorflow/tensorboard/components/tf_graph_common:legacy", + "//third_party/javascript/polymer/v1/paper-button:lib", + "//third_party/javascript/polymer/v1/paper-dropdown-menu:lib", + "//third_party/javascript/polymer/v1/paper-menu:lib", + "//third_party/javascript/polymer/v1/paper-radio-group:lib", + "//third_party/javascript/polymer/v1/paper-toggle-button:lib", + "//third_party/javascript/polymer/v1/paper-tooltip:lib", + "//third_party/javascript/polymer/v1/polymer:lib", + ], +) + filegroup( name = "all_files", srcs = glob(["**"]), tags = ["notsan"], ) - -################################################################################ -# MARKED FOR DELETION - -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [ - "tf-graph-controls.html", - ], - destdir = "tf-graph-controls", - deps = [ - "//tensorflow/tensorboard/components/tf_dashboard_common:legacy", - "//tensorflow/tensorboard/components/tf_graph_common:legacy", - ], -) diff --git a/tensorflow/tensorboard/components/tf_graph_controls/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_controls/demo/BUILD index c47cb90a03e..cd86ac7320a 100644 --- a/tensorflow/tensorboard/components/tf_graph_controls/demo/BUILD +++ b/tensorflow/tensorboard/components/tf_graph_controls/demo/BUILD @@ -1,19 +1,19 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 # bazel run //third_party/tensorflow/tensorboard/components/tf_graph_controls/demo -web_library( +ts_web_library( name = "demo", srcs = ["index.html"], path = "/tf-graph-controls/demo", deps = [ "//tensorflow/tensorboard/components/tf_graph_controls", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", ], ) diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard/BUILD b/tensorflow/tensorboard/components/tf_graph_dashboard/BUILD index d1866b5d807..20f9d3990b5 100644 --- a/tensorflow/tensorboard/components/tf_graph_dashboard/BUILD +++ b/tensorflow/tensorboard/components/tf_graph_dashboard/BUILD @@ -1,14 +1,13 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_graph_dashboard", - srcs = [ - "tf-graph-dashboard.html", - ], + srcs = ["tf-graph-dashboard.html"], path = "/tf-graph-dashboard", deps = [ "//tensorflow/tensorboard/components/tf_backend", @@ -17,7 +16,24 @@ web_library( "//tensorflow/tensorboard/components/tf_graph_board", "//tensorflow/tensorboard/components/tf_graph_controls", "//tensorflow/tensorboard/components/tf_graph_loader", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", + "//tensorflow/tensorboard/components/vz_sorting", + ], +) + +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":tf_graph_dashboard"], + destdir = "tf-graph-dashboard", + deps = [ + "//tensorflow/tensorboard/components/tf_backend:legacy", + "//tensorflow/tensorboard/components/tf_dashboard_common:legacy", + "//tensorflow/tensorboard/components/tf_graph:legacy", + "//tensorflow/tensorboard/components/tf_graph_board:legacy", + "//tensorflow/tensorboard/components/tf_graph_controls:legacy", + "//tensorflow/tensorboard/components/tf_graph_loader:legacy", + "//tensorflow/tensorboard/components/vz_sorting:legacy", + "//third_party/javascript/polymer/v1/polymer:lib", ], ) @@ -26,23 +42,3 @@ filegroup( srcs = glob(["**"]), tags = ["notsan"], ) - -################################################################################ -# MARKED FOR DELETION - -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [ - "tf-graph-dashboard.html", - ], - destdir = "tf-graph-dashboard", - deps = [ - "//tensorflow/tensorboard/components/tf_dashboard_common:legacy", - "//tensorflow/tensorboard/components/tf_graph:legacy", - "//tensorflow/tensorboard/components/tf_graph_board:legacy", - "//tensorflow/tensorboard/components/tf_graph_controls:legacy", - "//tensorflow/tensorboard/components/tf_graph_loader:legacy", - ], -) diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_dashboard/demo/BUILD index 3658f45b153..58cd2854c57 100644 --- a/tensorflow/tensorboard/components/tf_graph_dashboard/demo/BUILD +++ b/tensorflow/tensorboard/components/tf_graph_dashboard/demo/BUILD @@ -1,19 +1,19 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 # bazel run //third_party/tensorflow/tensorboard/components/tf_graph_dashboard/demo -web_library( +ts_web_library( name = "demo", srcs = ["index.html"] + glob(["data/**"]), path = "/tf-graph-dashboard/demo", deps = [ "//tensorflow/tensorboard/components/tf_graph_dashboard", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", ], ) diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard/demo/index.html b/tensorflow/tensorboard/components/tf_graph_dashboard/demo/index.html index 67756cc1298..2035e87898a 100644 --- a/tensorflow/tensorboard/components/tf_graph_dashboard/demo/index.html +++ b/tensorflow/tensorboard/components/tf_graph_dashboard/demo/index.html @@ -37,14 +37,17 @@ limitations under the License. + diff --git a/tensorflow/tensorboard/components/tf_imports/dagre.html b/tensorflow/tensorboard/components/tf_imports/dagre.html index 1e2f6ef9af6..cb57b9a5cd8 100644 --- a/tensorflow/tensorboard/components/tf_imports/dagre.html +++ b/tensorflow/tensorboard/components/tf_imports/dagre.html @@ -42,4 +42,4 @@ THE SOFTWARE. - + diff --git a/tensorflow/tensorboard/components/tf_imports/graphlib.html b/tensorflow/tensorboard/components/tf_imports/graphlib.html index 783e33be0a6..05942123ab0 100644 --- a/tensorflow/tensorboard/components/tf_imports/graphlib.html +++ b/tensorflow/tensorboard/components/tf_imports/graphlib.html @@ -17,4 +17,4 @@ limitations under the License. - + diff --git a/tensorflow/tensorboard/components/tf_imports/lodash.html b/tensorflow/tensorboard/components/tf_imports/lodash.html index cbe35f10505..65ff6a4b032 100644 --- a/tensorflow/tensorboard/components/tf_imports/lodash.html +++ b/tensorflow/tensorboard/components/tf_imports/lodash.html @@ -15,4 +15,4 @@ See the License for the specific language governing permissions and limitations under the License. --> - + diff --git a/tensorflow/tensorboard/components/tf_imports/numericjs.html b/tensorflow/tensorboard/components/tf_imports/numericjs.html index 7559054aaba..81fa9491688 100644 --- a/tensorflow/tensorboard/components/tf_imports/numericjs.html +++ b/tensorflow/tensorboard/components/tf_imports/numericjs.html @@ -40,4 +40,4 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. --> - + diff --git a/tensorflow/tensorboard/components/tf_imports/plottable.html b/tensorflow/tensorboard/components/tf_imports/plottable.html index 2c3e10a7c44..77ad544d5a0 100644 --- a/tensorflow/tensorboard/components/tf_imports/plottable.html +++ b/tensorflow/tensorboard/components/tf_imports/plottable.html @@ -40,5 +40,5 @@ THE SOFTWARE. --> - + diff --git a/tensorflow/tensorboard/components/tf_imports/threejs.html b/tensorflow/tensorboard/components/tf_imports/threejs.html index d6adad43b03..7f4233b5713 100644 --- a/tensorflow/tensorboard/components/tf_imports/threejs.html +++ b/tensorflow/tensorboard/components/tf_imports/threejs.html @@ -39,5 +39,5 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. --> - - + + diff --git a/tensorflow/tensorboard/components/tf_imports/weblas.html b/tensorflow/tensorboard/components/tf_imports/weblas.html index 054d04ea85e..c07020598fc 100644 --- a/tensorflow/tensorboard/components/tf_imports/weblas.html +++ b/tensorflow/tensorboard/components/tf_imports/weblas.html @@ -39,4 +39,4 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. --> - + diff --git a/tensorflow/tensorboard/components/tf_option_selector/BUILD b/tensorflow/tensorboard/components/tf_option_selector/BUILD index 6f79ac536ab..cd0150529e7 100644 --- a/tensorflow/tensorboard/components/tf_option_selector/BUILD +++ b/tensorflow/tensorboard/components/tf_option_selector/BUILD @@ -1,16 +1,16 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_option_selector", srcs = ["tf-option-selector.html"], path = "/tf-option-selector", deps = [ "//tensorflow/tensorboard/components/tf_dashboard_common", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", ], ) diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/BUILD b/tensorflow/tensorboard/components/tf_scalar_dashboard/BUILD index f2a491a2b25..2de11a231e6 100644 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/BUILD +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/BUILD @@ -1,10 +1,10 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_scalar_dashboard", srcs = [ "tf-scalar-dashboard.html", @@ -16,8 +16,8 @@ web_library( "//tensorflow/tensorboard/components/tf_color_scale", "//tensorflow/tensorboard/components/tf_dashboard_common", "//tensorflow/tensorboard/components/tf_imports:lodash", + "//tensorflow/tensorboard/components/tf_imports:polymer", "//tensorflow/tensorboard/components/vz_line_chart", - "@org_polymer", "@org_polymer_iron_collapse", "@org_polymer_paper_checkbox", "@org_polymer_paper_dropdown_menu", diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/BUILD b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/BUILD index 3b135d68afc..497767363ec 100644 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/BUILD +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/BUILD @@ -1,22 +1,22 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "demo", srcs = ["index.html"], path = "/tf-scalar-dashboard/demo", deps = [ "//tensorflow/tensorboard/components/tf_backend", "//tensorflow/tensorboard/components/tf_imports:d3", + "//tensorflow/tensorboard/components/tf_imports:polymer", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "//tensorflow/tensorboard/components/tf_scalar_dashboard", "//tensorflow/tensorboard/demo:demo_data", - "@org_polymer", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", ], ) diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/index.html b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/index.html index 7429c87b873..10cf83b2e9a 100644 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/index.html +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/index.html @@ -45,14 +45,17 @@ limitations under the License. + diff --git a/tensorflow/tensorboard/components/tf_storage/tf-storage.html b/tensorflow/tensorboard/components/tf_storage/tf-storage.html index 91b8976519d..ff3f7b0ad4a 100644 --- a/tensorflow/tensorboard/components/tf_storage/tf-storage.html +++ b/tensorflow/tensorboard/components/tf_storage/tf-storage.html @@ -18,4 +18,4 @@ limitations under the License. - + diff --git a/tensorflow/tensorboard/components/tf_tensorboard/BUILD b/tensorflow/tensorboard/components/tf_tensorboard/BUILD index b649bb53f2a..72f9a0852ae 100644 --- a/tensorflow/tensorboard/components/tf_tensorboard/BUILD +++ b/tensorflow/tensorboard/components/tf_tensorboard/BUILD @@ -1,16 +1,16 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") +load("//tensorflow/tensorboard:vulcanize.bzl", "tensorboard_html_binary") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_tensorboard", srcs = [ + "autoReloadBehavior.ts", "style.html", "tf-tensorboard.html", - ":ts", ], path = "/tf-tensorboard", visibility = ["//visibility:public"], @@ -23,11 +23,11 @@ web_library( "//tensorflow/tensorboard/components/tf_graph_dashboard", "//tensorflow/tensorboard/components/tf_histogram_dashboard", "//tensorflow/tensorboard/components/tf_image_dashboard", + "//tensorflow/tensorboard/components/tf_imports:polymer", "//tensorflow/tensorboard/components/tf_scalar_dashboard", "//tensorflow/tensorboard/components/tf_storage", "//tensorflow/tensorboard/components/tf_text_dashboard", "//tensorflow/tensorboard/components/vz_projector", - "@org_polymer", "@org_polymer_font_roboto", "@org_polymer_iron_icons", "@org_polymer_paper_button", @@ -40,20 +40,22 @@ web_library( ], ) -web_library( +ts_web_library( name = "demo", srcs = ["demo.html"], path = "/tf-tensorboard", deps = [ ":tf_tensorboard", "//tensorflow/tensorboard/demo:demo_data", - "@org_polymer_webcomponentsjs", ], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["autoReloadBehavior.ts"], +tensorboard_html_binary( + name = "devserver", + testonly = 1, + input_path = "/tf-tensorboard/demo.html", + output_path = "/index.html", + deps = [":demo"], ) filegroup( diff --git a/tensorflow/tensorboard/components/tf_tensorboard/autoReloadBehavior.ts b/tensorflow/tensorboard/components/tf_tensorboard/autoReloadBehavior.ts index 1f6b4cf6419..54df16f5b5d 100644 --- a/tensorflow/tensorboard/components/tf_tensorboard/autoReloadBehavior.ts +++ b/tensorflow/tensorboard/components/tf_tensorboard/autoReloadBehavior.ts @@ -12,49 +12,51 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -module TF.TensorBoard { - export var AUTORELOAD_LOCALSTORAGE_KEY = 'TF.TensorBoard.autoReloadEnabled'; - var getAutoReloadFromLocalStorage: () => boolean = () => { - var val = window.localStorage.getItem(AUTORELOAD_LOCALSTORAGE_KEY); - return val === 'true' || val == null; // defaults to true - }; +export var AUTORELOAD_LOCALSTORAGE_KEY = 'TF.TensorBoard.autoReloadEnabled'; - export var AutoReloadBehavior = { - properties: { - autoReloadEnabled: { - type: Boolean, - observer: '_autoReloadObserver', - value: getAutoReloadFromLocalStorage, - }, - _autoReloadId: { - type: Number, - }, - autoReloadIntervalSecs: { - type: Number, - value: 30, - }, +var getAutoReloadFromLocalStorage: () => boolean = () => { + var val = window.localStorage.getItem(AUTORELOAD_LOCALSTORAGE_KEY); + return val === 'true' || val == null; // defaults to true +}; + +/** + * @polymerBehavior + */ +export var AutoReloadBehavior = { + properties: { + autoReloadEnabled: { + type: Boolean, + observer: '_autoReloadObserver', + value: getAutoReloadFromLocalStorage, }, - detached: function() { - window.clearTimeout(this._autoReloadId); + _autoReloadId: { + type: Number, }, - _autoReloadObserver: function(autoReload) { - window.localStorage.setItem(AUTORELOAD_LOCALSTORAGE_KEY, autoReload); - if (autoReload) { - var _this = this; - this._autoReloadId = window.setTimeout( - this._doAutoReload.bind(this), this.autoReloadIntervalSecs * 1000); - } else { - window.clearTimeout(this._autoReloadId); - } + autoReloadIntervalSecs: { + type: Number, + value: 30, }, - _doAutoReload: function() { - if (this.reload == null) { - throw new Error('AutoReloadBehavior requires a reload method'); - } - this.reload(); + }, + detached: function() { + window.clearTimeout(this._autoReloadId); + }, + _autoReloadObserver: function(autoReload) { + window.localStorage.setItem(AUTORELOAD_LOCALSTORAGE_KEY, autoReload); + if (autoReload) { + var _this = this; this._autoReloadId = window.setTimeout( this._doAutoReload.bind(this), this.autoReloadIntervalSecs * 1000); + } else { + window.clearTimeout(this._autoReloadId); } - }; -} + }, + _doAutoReload: function() { + if (this.reload == null) { + throw new Error('AutoReloadBehavior requires a reload method'); + } + this.reload(); + this._autoReloadId = window.setTimeout( + this._doAutoReload.bind(this), this.autoReloadIntervalSecs * 1000); + } +}; diff --git a/tensorflow/tensorboard/components/tf_tensorboard/demo.html b/tensorflow/tensorboard/components/tf_tensorboard/demo.html index c8a9238aef0..f691f6211bc 100644 --- a/tensorflow/tensorboard/components/tf_tensorboard/demo.html +++ b/tensorflow/tensorboard/components/tf_tensorboard/demo.html @@ -18,7 +18,6 @@ limitations under the License. TensorBoard Demo - diff --git a/tensorflow/tensorboard/components/tf_tensorboard/test/autoReloadTests.ts b/tensorflow/tensorboard/components/tf_tensorboard/test/autoReloadTests.ts index 0f049d40ab6..b68fd8c9438 100644 --- a/tensorflow/tensorboard/components/tf_tensorboard/test/autoReloadTests.ts +++ b/tensorflow/tensorboard/components/tf_tensorboard/test/autoReloadTests.ts @@ -12,19 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + +import {AUTORELOAD_LOCALSTORAGE_KEY, AutoReloadBehavior} from '../autoReloadBehavior'; + declare function fixture(id: string): void; + window.HTMLImports.whenReady(() => { Polymer({ is: 'autoreload-test-element', - behaviors: [TF.TensorBoard.AutoReloadBehavior], + behaviors: [AutoReloadBehavior], }); describe('autoReload-behavior', function() { - var testElement; - var ls = window.localStorage; - var key = TF.TensorBoard.AUTORELOAD_LOCALSTORAGE_KEY; - var clock; - var callCount: number; + let testElement; + const ls = window.localStorage; + const key = AUTORELOAD_LOCALSTORAGE_KEY; + let clock; + let callCount: number; beforeEach(function() { ls.setItem(key, 'false'); // start it turned off so we can mutate fns diff --git a/tensorflow/tensorboard/components/tf_tensorboard/test/e2eTests.ts b/tensorflow/tensorboard/components/tf_tensorboard/test/e2eTests.ts index 2308298ced9..a00027963be 100644 --- a/tensorflow/tensorboard/components/tf_tensorboard/test/e2eTests.ts +++ b/tensorflow/tensorboard/components/tf_tensorboard/test/e2eTests.ts @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +import {TABS} from '../../tf-globals/globals'; + describe('end-to-end test', () => { window.HTMLImports.whenReady(() => { let tb = d3.select('tf-tensorboard'); var tabs = (tb.node()).$.tabs; function testTab(tabIndex: number) { - it(`selecting ${TF.Globals.TABS[tabIndex]} tab`, done => { + it(`selecting ${TABS[tabIndex]} tab`, done => { // Every dashboard emits a rendered event when it is done rendering. tb.on('rendered', () => done()); tabs.set('selected', tabIndex); @@ -32,7 +34,7 @@ describe('end-to-end test', () => { // have failed. Re-selecting the default tab and listening for // "rendered" event won't work since the content is not re-stamped. let selected = +tabs.get('selected'); - for (let i = 0; i < TF.Globals.TABS.length; i++) { + for (let i = 0; i < TABS.length; i++) { if (i !== selected) { testTab(i); } diff --git a/tensorflow/tensorboard/components/tf_tensorboard/test/fastTabSwitch.ts b/tensorflow/tensorboard/components/tf_tensorboard/test/fastTabSwitch.ts index 4dd62a0c382..905ed4ee4aa 100644 --- a/tensorflow/tensorboard/components/tf_tensorboard/test/fastTabSwitch.ts +++ b/tensorflow/tensorboard/components/tf_tensorboard/test/fastTabSwitch.ts @@ -13,9 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +import {TABS} from '../../tf-globals/globals'; + describe('fast tab switch', () => { window.HTMLImports.whenReady(() => { let tb = d3.select('tf-tensorboard'); + // tslint:disable-next-line:no-any be quiet tsc var tabs = (tb.node()).$.tabs; // This test will select the events tab. Once the events tab @@ -23,9 +26,9 @@ describe('fast tab switch', () => { // the images tab wihout waiting for the graph tab to finish // rendering. Finally, it finishes when the images tab // has rendered and no errors were thrown. - let eventsTabIndex = TF.Globals.TABS.indexOf('events'); - let imagesTabIndex = TF.Globals.TABS.indexOf('images'); - let graphTabIndex = TF.Globals.TABS.indexOf('graphs'); + const eventsTabIndex = TABS.indexOf('events'); + const imagesTabIndex = TABS.indexOf('images'); + const graphTabIndex = TABS.indexOf('graphs'); // Listen for when the events tab rendered. tb.on('rendered', () => { diff --git a/tensorflow/tensorboard/components/tf_tensorboard/test/tensorboardTests.ts b/tensorflow/tensorboard/components/tf_tensorboard/test/tensorboardTests.ts index 3c7fe2c9e72..33e11e3094d 100644 --- a/tensorflow/tensorboard/components/tf_tensorboard/test/tensorboardTests.ts +++ b/tensorflow/tensorboard/components/tf_tensorboard/test/tensorboardTests.ts @@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + +import * as backend_router from '../../tf-backend/router'; +import {TABS} from '../../tf-globals/globals'; + describe('tf-tensorboard tests', () => { window.HTMLImports.whenReady(() => { let tensorboard: any; @@ -25,16 +29,16 @@ describe('tf-tensorboard tests', () => { setTimeout(function() { let tabs = tensorboard.$.tabs.getElementsByTagName('paper-tab'); let tabMode = Array.prototype.map.call(tabs, (x) => x.dataMode); - chai.assert.deepEqual(tabMode, TF.Globals.TABS, 'mode is correct'); + chai.assert.deepEqual(tabMode, TABS, 'mode is correct'); let tabText = Array.prototype.map.call(tabs, (x) => x.innerText.toLowerCase()); - chai.assert.deepEqual(tabText, TF.Globals.TABS, 'text is correct'); + chai.assert.deepEqual(tabText, TABS, 'text is correct'); done(); }); }); it('respects router manually provided', function() { - let router = TF.Backend.router('data', true); + const router = backend_router.router('data', true); tensorboard.router = router; tensorboard.demoDir = null; chai.assert.equal(tensorboard._backend.router, router); @@ -46,7 +50,7 @@ describe('tf-tensorboard tests', () => { }); describe('reloading the selected dashboard', function() { - TF.Globals.TABS.forEach((name, tabIndex) => { + TABS.forEach((name, tabIndex) => { // These tabs do not support reload mode. if (name === 'graphs' || name === 'projections') { return; @@ -70,7 +74,7 @@ describe('tf-tensorboard tests', () => { }); it('reload is disabled for graph dashboard', function(done) { - let idx = TF.Globals.TABS.indexOf('graphs'); + const idx = TABS.indexOf('graphs'); chai.assert.notEqual(idx, -1, 'graphs was found'); tensorboard.$.tabs.set('selected', idx); setTimeout( diff --git a/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html b/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html index ac3132fadaf..00a30686f69 100644 --- a/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html +++ b/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html @@ -44,7 +44,6 @@ tf-tensorboard is the frontend entry point for TensorBoard. It implements a toolbar (via paper-header-panel and paper-toolbar) that allows the user to toggle between various dashboards. --> - + - diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/BUILD b/tensorflow/tensorboard/components/tf_text_dashboard/BUILD index a1a97778280..b6dfdbefb4c 100644 --- a/tensorflow/tensorboard/components/tf_text_dashboard/BUILD +++ b/tensorflow/tensorboard/components/tf_text_dashboard/BUILD @@ -1,10 +1,10 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "tf_text_dashboard", srcs = [ "tf-text-dashboard.html", @@ -17,7 +17,7 @@ web_library( "//tensorflow/tensorboard/components/tf_dashboard_common", "//tensorflow/tensorboard/components/tf_imports:d3", "//tensorflow/tensorboard/components/tf_imports:lodash", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", "@org_polymer_paper_dialog", "@org_polymer_paper_icon_button", "@org_polymer_paper_material", @@ -26,15 +26,15 @@ web_library( ], ) -web_library( +ts_web_library( name = "demo", srcs = ["index.html"] + glob(["data/**"]), path = "/tf-text-dashboard", deps = [ ":tf_text_dashboard", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", ], ) diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/index.html b/tensorflow/tensorboard/components/tf_text_dashboard/index.html index 77d19b948c9..d01f4777ed3 100644 --- a/tensorflow/tensorboard/components/tf_text_dashboard/index.html +++ b/tensorflow/tensorboard/components/tf_text_dashboard/index.html @@ -44,6 +44,9 @@ limitations under the License. + diff --git a/tensorflow/tensorboard/components/vz_distribution_chart/vz-distribution-chart.ts b/tensorflow/tensorboard/components/vz_distribution_chart/vz-distribution-chart.ts index 17e35978249..f3911d301d9 100644 --- a/tensorflow/tensorboard/components/vz_distribution_chart/vz-distribution-chart.ts +++ b/tensorflow/tensorboard/components/vz_distribution_chart/vz-distribution-chart.ts @@ -12,13 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -/* tslint:disable:no-namespace variable-name */ -import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 -import * as _ from 'lodash' -import * as Plottable from 'Plottable/plottable'; // from //third_party/javascript/plottable -import {Dataset} from 'Plottable/plottable'; -import * as ChartHelpers from '../vz_line_chart/vz-chart-helpers'; +import * as ChartHelpers from '../vz-line-chart/vz-chart-helpers'; export class DistributionChart { private run2datasets: {[run: string]: Plottable.Dataset}; diff --git a/tensorflow/tensorboard/components/vz_histogram_timeseries/BUILD b/tensorflow/tensorboard/components/vz_histogram_timeseries/BUILD index 005090b8e06..6f21df0c865 100644 --- a/tensorflow/tensorboard/components/vz_histogram_timeseries/BUILD +++ b/tensorflow/tensorboard/components/vz_histogram_timeseries/BUILD @@ -1,29 +1,41 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "vz_histogram_timeseries", srcs = ["vz-histogram-timeseries.html"], path = "/vz-histogram-timeseries", deps = [ "//tensorflow/tensorboard/components/tf_imports:d3", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", ], ) -web_library( +ts_web_library( name = "demo", srcs = ["index.html"], path = "/vz-histogram-timeseries", deps = [ ":vz_histogram_timeseries", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_button", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", + ], +) + +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":vz_histogram_timeseries"], + visibility = ["//learning/vis/vz_elements/catalog:__pkg__"], + destdir = "vz-histogram-timeseries", + deps = [ + "//tensorflow/tensorboard/components/tf_imports_google:lib", + "//third_party/javascript/polymer/v1/polymer:lib", ], ) @@ -32,22 +44,3 @@ filegroup( srcs = glob(["**"]), tags = ["notsan"], ) - -################################################################################ -# MARKED FOR DELETION - -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [ - "index.html", - "vz-histogram-timeseries.html", - ], - visibility = ["//learning/vis/vz_elements/catalog:__pkg__"], - destdir = "vz-histogram-timeseries", - deps = [ - "//tensorflow/tensorboard/components/tf_imports_google:lib", - "//third_party/javascript/polymer/v1/polymer:lib", - ], -) diff --git a/tensorflow/tensorboard/components/vz_line_chart/BUILD b/tensorflow/tensorboard/components/vz_line_chart/BUILD index c641587158b..7d8d0d60749 100644 --- a/tensorflow/tensorboard/components/vz_line_chart/BUILD +++ b/tensorflow/tensorboard/components/vz_line_chart/BUILD @@ -1,16 +1,17 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "vz_line_chart", srcs = [ - "bundle.js", + "dragZoomInteraction.ts", + "vz-chart-helpers.ts", "vz-line-chart.html", + "vz-line-chart.ts", ], path = "/vz-line-chart", visibility = ["//visibility:public"], @@ -18,11 +19,11 @@ web_library( "//tensorflow/tensorboard/components/tf_imports:d3", "//tensorflow/tensorboard/components/tf_imports:lodash", "//tensorflow/tensorboard/components/tf_imports:plottable", - "@org_polymer", + "//tensorflow/tensorboard/components/tf_imports:polymer", ], ) -web_library( +ts_web_library( name = "demo", srcs = ["index.html"], path = "/vz-line-chart", @@ -30,60 +31,12 @@ web_library( ":vz_line_chart", "@org_polymer_iron_demo_helpers", "@org_polymer_paper_styles", - "@org_polymer_webcomponentsjs", ], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["bundle.ts"], - typings = [ - "@org_definitelytyped//:lodash.d.ts", - "@org_definitelytyped//:polymer.d.ts", - "@org_definitelytyped//:webcomponents.js.d.ts", - "//tensorflow/tensorboard/components/tf_imports:d3.d.ts", - "//tensorflow/tensorboard/components/tf_imports:plottable.d.ts", - ], -) - -tensorboard_typescript_bundle( - name = "bundle", - out = "bundle.ts", - namespace_srcs = { - "VZ.ChartHelpers": [ - "vz-chart-helpers.ts", - ], - "VZ": [ - "vz-line-chart.ts", - "dragZoomInteraction.ts", - ], - }, - namespace_symbol_aliases = { - "VZ.ChartHelpers": { - "Dataset": "Plottable.Dataset", - }, - }, -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) - -################################################################################ -# MARKED FOR DELETION - -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") - tensorboard_webcomponent_library( name = "legacy", - srcs = [ - "index.html", - "vz-line-chart.html", - ":legacy_ts", - ], + srcs = [":vz_line_chart"], visibility = ["//learning/vis/vz_elements/catalog:__pkg__"], destdir = "vz-line-chart", deps = [ @@ -93,24 +46,8 @@ tensorboard_webcomponent_library( ], ) -tensorboard_ts_library( - name = "legacy_ts", - srcs = [ - "dragZoomInteraction.ts", - "vz-chart-helpers.ts", - "vz-line-chart.ts", - ], - deps_mgmt = "off", - runtime = "nodejs", - deps = [ - "//third_party/javascript/node_modules/typescript:es2015.promise", - "//third_party/javascript/plottable:bundle", - "//third_party/javascript/typings/chai", - "//third_party/javascript/typings/d3_v4:bundle", - "//third_party/javascript/typings/lodash", - "//third_party/javascript/typings/mocha", - "//third_party/javascript/typings/polymer:polymer_without_externs", - "//third_party/javascript/typings/sinon", - "//third_party/javascript/typings/webcomponents_js", - ], +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], ) diff --git a/tensorflow/tensorboard/components/vz_line_chart/dragZoomInteraction.ts b/tensorflow/tensorboard/components/vz_line_chart/dragZoomInteraction.ts index 2c1f4989c4c..c7f1f30e76b 100644 --- a/tensorflow/tensorboard/components/vz_line_chart/dragZoomInteraction.ts +++ b/tensorflow/tensorboard/components/vz_line_chart/dragZoomInteraction.ts @@ -13,11 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 -import * as Plottable from 'Plottable/plottable'; // from //third_party/javascript/plottable - - export class DragZoomLayer extends Plottable.Components.SelectionBoxLayer { private _dragInteraction: Plottable.Interactions.Drag; private _doubleClickInteraction: Plottable.Interactions.Click; diff --git a/tensorflow/tensorboard/components/vz_line_chart/index.html b/tensorflow/tensorboard/components/vz_line_chart/index.html index fb571a51837..856ab7d1efe 100644 --- a/tensorflow/tensorboard/components/vz_line_chart/index.html +++ b/tensorflow/tensorboard/components/vz_line_chart/index.html @@ -21,7 +21,6 @@ limitations under the License. vz-line-chart demo - diff --git a/tensorflow/tensorboard/components/vz_line_chart/vz-chart-helpers.ts b/tensorflow/tensorboard/components/vz_line_chart/vz-chart-helpers.ts index cd8f1376172..fa89e06ada1 100644 --- a/tensorflow/tensorboard/components/vz_line_chart/vz-chart-helpers.ts +++ b/tensorflow/tensorboard/components/vz_line_chart/vz-chart-helpers.ts @@ -12,12 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -/* tslint:disable:no-namespace variable-name */ - - -import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 -import * as Plottable from 'Plottable/plottable'; // from //third_party/javascript/plottable -import {Dataset} from 'Plottable/plottable'; export interface Datum { wall_time: Date; @@ -123,6 +117,7 @@ export function computeDomain(values: number[], ignoreOutliers: boolean) { } export function accessorize(key: string): Plottable.IAccessor { + // tslint:disable-next-line:no-any be quiet tsc return (d: any, index: number, dataset: Plottable.Dataset) => d[key]; } @@ -157,19 +152,21 @@ export function wallX(): XComponents { accessor: (d: Datum) => d.wall_time, }; } -export let relativeAccessor = (d: any, index: number, dataset: Dataset) => { - // We may be rendering the final-point datum for scatterplot. - // If so, we will have already provided the 'relative' property - if (d.relative != null) { - return d.relative; - } - let data = dataset.data(); - // I can't imagine how this function would be called when the data is - // empty (after all, it iterates over the data), but lets guard just - // to be safe. - let first = data.length > 0 ? +data[0].wall_time : 0; - return (+d.wall_time - first) / (60 * 60 * 1000); // ms to hours -}; +export let relativeAccessor = + // tslint:disable-next-line:no-any be quiet tsc + (d: any, index: number, dataset: Plottable.Dataset) => { + // We may be rendering the final-point datum for scatterplot. + // If so, we will have already provided the 'relative' property + if (d.relative != null) { + return d.relative; + } + let data = dataset.data(); + // I can't imagine how this function would be called when the data is + // empty (after all, it iterates over the data), but lets guard just + // to be safe. + let first = data.length > 0 ? +data[0].wall_time : 0; + return (+d.wall_time - first) / (60 * 60 * 1000); // ms to hours + }; export let relativeFormatter = (n: number) => { // we will always show 2 units of precision, e.g days and hours, or diff --git a/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.html b/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.html index 85e24ae4be0..38e0d7cb8d8 100644 --- a/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.html +++ b/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.html @@ -125,5 +125,7 @@ such as different X scales (linear and temporal), tooltips and smoothing. - + + + diff --git a/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.ts b/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.ts index d50a7834f5f..5da6190ea24 100644 --- a/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.ts +++ b/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.ts @@ -14,10 +14,6 @@ limitations under the License. ==============================================================================*/ /* tslint:disable:no-namespace variable-name */ -import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 -import * as _ from 'lodash' -import * as Plottable from 'Plottable/plottable'; // from //third_party/javascript/plottable - import {DragZoomLayer} from './dragZoomInteraction' import * as ChartHelpers from './vz-chart-helpers' @@ -142,7 +138,7 @@ Polymer({ * Sets the series that the chart displays. Series with other names will * not be displayed. * - * @param {String[]} names Array with the names of the series to + * @param {Array} names Array with the names of the series to * display. */ setVisibleSeries: function(names) { @@ -157,8 +153,8 @@ Polymer({ * Sets the data of one of the series. Note that to display this series * its name must be in the setVisibleSeries() array. * - * @param {String} name Name of the series. - * @param {VZ.ChartHelpers.ScalarDatum[]} data Data of the series. This is + * @param {string} name Name of the series. + * @param {Array} data Data of the series. This is * an array of objects with at least the following properties: * - step: (Number) - index of the datum. * - wall_time: (Date) - Date object with the datum's time. diff --git a/tensorflow/tensorboard/components/vz_projector/BUILD b/tensorflow/tensorboard/components/vz_projector/BUILD index c1adeabbf53..6d22554efa5 100644 --- a/tensorflow/tensorboard/components/vz_projector/BUILD +++ b/tensorflow/tensorboard/components/vz_projector/BUILD @@ -1,38 +1,69 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "vz_projector", srcs = [ + "analyticsLogger.ts", "bundle.html", - "bundle.js", + "data.ts", + "data-provider.ts", + "data-provider-demo.ts", + "data-provider-proto.ts", + "data-provider-server.ts", + "external.d.ts", + "knn.ts", + "label.ts", + "logging.ts", + "projectorEventContext.ts", + "projectorScatterPlotAdapter.ts", + "renderContext.ts", + "scatterPlot.ts", + "scatterPlotRectangleSelector.ts", + "scatterPlotVisualizer.ts", + "scatterPlotVisualizer3DLabels.ts", + "scatterPlotVisualizerCanvasLabels.ts", + "scatterPlotVisualizerPolylines.ts", + "scatterPlotVisualizerSprites.ts", "styles.html", + "util.ts", + "vector.ts", "vz-projector.html", + "vz-projector.ts", "vz-projector-app.html", "vz-projector-bookmark-panel.html", + "vz-projector-bookmark-panel.ts", "vz-projector-colab.html", "vz-projector-dashboard.html", "vz-projector-data-panel.html", + "vz-projector-data-panel.ts", "vz-projector-input.html", + "vz-projector-input.ts", "vz-projector-inspector-panel.html", + "vz-projector-inspector-panel.ts", "vz-projector-legend.html", + "vz-projector-legend.ts", "vz-projector-metadata-card.html", + "vz-projector-metadata-card.ts", "vz-projector-projections-panel.html", + "vz-projector-projections-panel.ts", + "vz-projector-util.ts", ], path = "/vz-projector", visibility = ["//visibility:public"], deps = [ + ":bh_tsne", + ":heap", + ":sptree", "//tensorflow/tensorboard/components/tf_dashboard_common", "//tensorflow/tensorboard/components/tf_imports:d3", "//tensorflow/tensorboard/components/tf_imports:numericjs", + "//tensorflow/tensorboard/components/tf_imports:polymer", "//tensorflow/tensorboard/components/tf_imports:threejs", "//tensorflow/tensorboard/components/tf_imports:weblas", - "@org_polymer", "@org_polymer_iron_collapse", "@org_polymer_iron_icons", "@org_polymer_paper_button", @@ -53,298 +84,23 @@ web_library( ], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["bundle.ts"], - typings = [ - "external.d.ts", - "@org_definitelytyped//:polymer.d.ts", - "@org_definitelytyped//:three.d.ts", - "@org_definitelytyped//:webcomponents.js.d.ts", - "//tensorflow/tensorboard/components/tf_imports:d3.d.ts", - ], +ts_web_library( + name = "heap", + srcs = ["heap.ts"], + path = "/vz-projector", ) -tensorboard_typescript_bundle( - name = "bundle", - out = "bundle.ts", - namespace_srcs = { - "VZ.Projector.Heap": ["heap.ts"], - "VZ.Projector.Label": ["label.ts"], - "VZ.Projector.SPTree": ["sptree.ts"], - "VZ.Projector.BhTsne": ["bh_tsne.ts"], - "VZ.Projector.Logging": ["logging.ts"], - "VZ.Projector.RenderContext": ["renderContext.ts"], - "VZ.Projector.ScatterPlotRectangleSelector": ["scatterPlotRectangleSelector.ts"], - "VZ.Projector.AnalyticsLogger": ["analyticsLogger.ts"], - "VZ.Projector.Util": ["util.ts"], - "VZ.Projector.Vector": ["vector.ts"], - "VZ.Projector.Knn": ["knn.ts"], - "VZ.Projector.Data": ["data.ts"], - "VZ.Projector.DataProvider": ["data-provider.ts"], - "VZ.Projector.DataProviderDemo": ["data-provider-demo.ts"], - "VZ.Projector.DataProviderProto": ["data-provider-proto.ts"], - "VZ.Projector.DataProviderServer": ["data-provider-server.ts"], - "VZ.Projector.ProjectorEventContext": ["projectorEventContext.ts"], - "VZ.Projector.ScatterPlot": ["scatterPlot.ts"], - "VZ.Projector.ScatterPlotVisualizer3DLabels": ["scatterPlotVisualizer3DLabels.ts"], - "VZ.Projector.ScatterPlotVisualizerCanvasLabels": ["scatterPlotVisualizerCanvasLabels.ts"], - "VZ.Projector.ScatterPlotVisualizerPolylines": ["scatterPlotVisualizerPolylines.ts"], - "VZ.Projector.ScatterPlotVisualizerSprites": ["scatterPlotVisualizerSprites.ts"], - "VZ.Projector.ScatterPlotVisualizer": ["scatterPlotVisualizer.ts"], - "VZ.Projector.ProjectorScatterPlotAdapter": ["projectorScatterPlotAdapter.ts"], - "VZ.Projector.ProjectorUtil": ["vz-projector-util.ts"], - "VZ.Projector.ProjectorBookmarkPanel": ["vz-projector-bookmark-panel.ts"], - "VZ.Projector.ProjectorDataPanel": ["vz-projector-data-panel.ts"], - "VZ.Projector.ProjectorInput": ["vz-projector-input.ts"], - "VZ.Projector.ProjectorInspectorPanel": ["vz-projector-inspector-panel.ts"], - "VZ.Projector.ProjectorLegend": ["vz-projector-legend.ts"], - "VZ.Projector.ProjectorMetadataCard": ["vz-projector-metadata-card.ts"], - "VZ.Projector.ProjectorProjectionsPanel": ["vz-projector-projections-panel.ts"], - "VZ.Projector": ["vz-projector.ts"], - }, - namespace_symbol_aliases = { - "VZ.Projector.AnalyticsLogger": { - "ProjectionType": "VZ.Projector.Data.ProjectionType", - }, - "VZ.Projector.BhTsne": { - "SPNode": "VZ.Projector.SPTree.SPNode", - "SPTree": "VZ.Projector.SPTree.SPTree", - }, - "VZ.Projector.DataProviderDemo": { - "DataProvider": "VZ.Projector.DataProvider.DataProvider", - "DataSet": "VZ.Projector.Data.DataSet", - "EmbeddingInfo": "VZ.Projector.DataProvider.EmbeddingInfo", - "ProjectorConfig": "VZ.Projector.DataProvider.ProjectorConfig", - "SpriteAndMetadataInfo": "VZ.Projector.Data.SpriteAndMetadataInfo", - "State": "VZ.Projector.Data.State", - "TENSORS_MSG_ID": "VZ.Projector.DataProvider.TENSORS_MSG_ID", - "dataProvider": "VZ.Projector.DataProvider", - "logging": "VZ.Projector.Logging", - }, - "VZ.Projector.DataProviderProto": { - "DataPoint": "VZ.Projector.Data.DataPoint", - "DataProto": "VZ.Projector.Data.DataProto", - "DataProvider": "VZ.Projector.DataProvider.DataProvider", - "DataSet": "VZ.Projector.Data.DataSet", - "PointMetadata": "VZ.Projector.Data.PointMetadata", - "ProjectorConfig": "VZ.Projector.DataProvider.ProjectorConfig", - "SpriteAndMetadataInfo": "VZ.Projector.Data.SpriteAndMetadataInfo", - "State": "VZ.Projector.Data.State", - "analyzeMetadata": "VZ.Projector.DataProvider.analyzeMetadata", - }, - "VZ.Projector.DataProviderServer": { - "DataProvider": "VZ.Projector.DataProvider.DataProvider", - "DataSet": "VZ.Projector.Data.DataSet", - "EmbeddingInfo": "VZ.Projector.DataProvider.EmbeddingInfo", - "ProjectorConfig": "VZ.Projector.DataProvider.ProjectorConfig", - "SpriteAndMetadataInfo": "VZ.Projector.Data.SpriteAndMetadataInfo", - "State": "VZ.Projector.Data.State", - "dataProvider": "VZ.Projector.DataProvider", - "logging": "VZ.Projector.Logging", - }, - "VZ.Projector.DataProvider": { - "ColumnStats": "VZ.Projector.Data.ColumnStats", - "DataPoint": "VZ.Projector.Data.DataPoint", - "DataSet": "VZ.Projector.Data.DataSet", - "PointMetadata": "VZ.Projector.Data.PointMetadata", - "SpriteAndMetadataInfo": "VZ.Projector.Data.SpriteAndMetadataInfo", - "State": "VZ.Projector.Data.State", - "logging": "VZ.Projector.Logging", - "runAsyncTask": "VZ.Projector.Util.runAsyncTask", - }, - "VZ.Projector.Data": { - "SpriteMetadata": "VZ.Projector.DataProvider.SpriteMetadata", - "TSNE": "VZ.Projector.BhTsne.TSNE", - "knn": "VZ.Projector.Knn", - "logging": "VZ.Projector.Logging", - "scatterPlot": "VZ.Projector.ScatterPlot", - "util": "VZ.Projector.Util", - "vector": "VZ.Projector.Vector", - }, - "VZ.Projector.Knn": { - "KMin": "VZ.Projector.Heap.KMin", - "Vector": "VZ.Projector.Vector.Vector", - "logging": "VZ.Projector.Logging", - "runAsyncTask": "VZ.Projector.Util.runAsyncTask", - "vector": "VZ.Projector.Vector", - }, - "VZ.Projector.ProjectorEventContext": { - "DistanceFunction": "VZ.Projector.Data.DistanceFunction", - "NearestEntry": "VZ.Projector.Knn.NearestEntry", - "Projection": "VZ.Projector.Data.Projection", - }, - "VZ.Projector.ProjectorScatterPlotAdapter": { - "DataSet": "VZ.Projector.Data.DataSet", - "DistanceFunction": "VZ.Projector.Data.DistanceFunction", - "LabelRenderParams": "VZ.Projector.RenderContext.LabelRenderParams", - "NearestEntry": "VZ.Projector.Knn.NearestEntry", - "Projection": "VZ.Projector.Data.Projection", - "ProjectionComponents3D": "VZ.Projector.Data.ProjectionComponents3D", - "ProjectorEventContext": "VZ.Projector.ProjectorEventContext.ProjectorEventContext", - "ScatterPlot": "VZ.Projector.ScatterPlot.ScatterPlot", - "ScatterPlotVisualizer3DLabels": "VZ.Projector.ScatterPlotVisualizer3DLabels.ScatterPlotVisualizer3DLabels", - "ScatterPlotVisualizerCanvasLabels": "VZ.Projector.ScatterPlotVisualizerCanvasLabels.ScatterPlotVisualizerCanvasLabels", - "ScatterPlotVisualizerPolylines": "VZ.Projector.ScatterPlotVisualizerPolylines.ScatterPlotVisualizerPolylines", - "ScatterPlotVisualizerSprites": "VZ.Projector.ScatterPlotVisualizerSprites.ScatterPlotVisualizerSprites", - "State": "VZ.Projector.Data.State", - "vector": "VZ.Projector.Vector", - }, - "VZ.Projector.ScatterPlot": { - "BoundingBox": "VZ.Projector.ScatterPlotRectangleSelector.BoundingBox", - "CameraType": "VZ.Projector.RenderContext.CameraType", - "LabelRenderParams": "VZ.Projector.RenderContext.LabelRenderParams", - "Point2D": "VZ.Projector.Vector.Point2D", - "Point3D": "VZ.Projector.Vector.Point3D", - "ProjectorEventContext": "VZ.Projector.ProjectorEventContext.ProjectorEventContext", - "RenderContext": "VZ.Projector.RenderContext.RenderContext", - "ScatterPlotRectangleSelector": "VZ.Projector.ScatterPlotRectangleSelector.ScatterPlotRectangleSelector", - "ScatterPlotVisualizer": "VZ.Projector.ScatterPlotVisualizer.ScatterPlotVisualizer", - "util": "VZ.Projector.Util", - }, - "VZ.Projector.ScatterPlotVisualizer3DLabels": { - "RenderContext": "VZ.Projector.RenderContext.RenderContext", - "ScatterPlotVisualizer": "VZ.Projector.ScatterPlotVisualizer.ScatterPlotVisualizer", - "util": "VZ.Projector.Util", - }, - "VZ.Projector.ScatterPlotVisualizerCanvasLabels": { - "BoundingBox": "VZ.Projector.Label.BoundingBox", - "CameraType": "VZ.Projector.RenderContext.CameraType", - "CollisionGrid": "VZ.Projector.Label.CollisionGrid", - "RenderContext": "VZ.Projector.RenderContext.RenderContext", - "ScatterPlotVisualizer": "VZ.Projector.ScatterPlotVisualizer.ScatterPlotVisualizer", - "util": "VZ.Projector.Util", - }, - "VZ.Projector.ScatterPlotVisualizerPolylines": { - "DataSet": "VZ.Projector.Data.DataSet", - "RenderContext": "VZ.Projector.RenderContext.RenderContext", - "ScatterPlotVisualizer": "VZ.Projector.ScatterPlotVisualizer.ScatterPlotVisualizer", - "util": "VZ.Projector.Util", - }, - "VZ.Projector.ScatterPlotVisualizerSprites": { - "CameraType": "VZ.Projector.RenderContext.CameraType", - "RenderContext": "VZ.Projector.RenderContext.RenderContext", - "ScatterPlotVisualizer": "VZ.Projector.ScatterPlotVisualizer.ScatterPlotVisualizer", - "util": "VZ.Projector.Util", - }, - "VZ.Projector.ScatterPlotVisualizer": { - "RenderContext": "VZ.Projector.RenderContext.RenderContext", - }, - "VZ.Projector.Util": { - "DataPoint": "VZ.Projector.Data.DataPoint", - "Point2D": "VZ.Projector.Vector.Point2D", - "logging": "VZ.Projector.Logging", - }, - "VZ.Projector.Vector": { - "assert": "VZ.Projector.Util.assert", - }, - "VZ.Projector.ProjectorBookmarkPanel": { - "DataProvider": "VZ.Projector.DataProvider.DataProvider", - "EmbeddingInfo": "VZ.Projector.DataProvider.EmbeddingInfo", - "PolymerElement": "VZ.Projector.ProjectorUtil.PolymerElement", - "PolymerHTMLElement": "VZ.Projector.ProjectorUtil.PolymerHTMLElement", - "Projector": "VZ.Projector.Projector", - "ProjectorEventContext": "VZ.Projector.ProjectorEventContext.ProjectorEventContext", - "State": "VZ.Projector.Data.State", - "logging": "VZ.Projector.Logging", - }, - "VZ.Projector.ProjectorDataPanel": { - "ColorLegendRenderInfo": "VZ.Projector.ProjectorLegend.ColorLegendRenderInfo", - "ColorLegendThreshold": "VZ.Projector.ProjectorLegend.ColorLegendThreshold", - "ColorOption": "VZ.Projector.Data.ColorOption", - "ColumnStats": "VZ.Projector.Data.ColumnStats", - "DataProvider": "VZ.Projector.DataProvider.DataProvider", - "EmbeddingInfo": "VZ.Projector.DataProvider.EmbeddingInfo", - "PolymerElement": "VZ.Projector.ProjectorUtil.PolymerElement", - "PolymerHTMLElement": "VZ.Projector.ProjectorUtil.PolymerHTMLElement", - "Projector": "VZ.Projector.Projector", - "ProjectorConfig": "VZ.Projector.DataProvider.ProjectorConfig", - "SpriteAndMetadataInfo": "VZ.Projector.Data.SpriteAndMetadataInfo", - "parseRawMetadata": "VZ.Projector.DataProvider.parseRawMetadata", - "parseRawTensors": "VZ.Projector.DataProvider.parseRawTensors", - "util": "VZ.Projector.Util", - }, - "VZ.Projector.ProjectorInput": { - "PolymerElement": "VZ.Projector.ProjectorUtil.PolymerElement", - "PolymerHTMLElement": "VZ.Projector.ProjectorUtil.PolymerHTMLElement", - }, - "VZ.Projector.ProjectorInspectorPanel": { - "DistanceFunction": "VZ.Projector.Data.DistanceFunction", - "PolymerElement": "VZ.Projector.ProjectorUtil.PolymerElement", - "PolymerHTMLElement": "VZ.Projector.ProjectorUtil.PolymerHTMLElement", - "Projector": "VZ.Projector.Projector", - "ProjectorEventContext": "VZ.Projector.ProjectorEventContext.ProjectorEventContext", - "ProjectorInput": "VZ.Projector.ProjectorInput.ProjectorInput", - "SpriteAndMetadataInfo": "VZ.Projector.Data.SpriteAndMetadataInfo", - "State": "VZ.Projector.Data.State", - "adapter": "VZ.Projector.ProjectorScatterPlotAdapter", - "knn": "VZ.Projector.Knn", - "util": "VZ.Projector.Util", - "vector": "VZ.Projector.Vector", - }, - "VZ.Projector.ProjectorLegend": { - "PolymerElement": "VZ.Projector.ProjectorUtil.PolymerElement", - "PolymerHTMLElement": "VZ.Projector.ProjectorUtil.PolymerHTMLElement", - }, - "VZ.Projector.ProjectorMetadataCard": { - "PointMetadata": "VZ.Projector.Data.PointMetadata", - "PolymerElement": "VZ.Projector.ProjectorUtil.PolymerElement", - "PolymerHTMLElement": "VZ.Projector.ProjectorUtil.PolymerHTMLElement", - }, - "VZ.Projector.ProjectorProjectionsPanel": { - "DataSet": "VZ.Projector.Data.DataSet", - "PolymerElement": "VZ.Projector.ProjectorUtil.PolymerElement", - "PolymerHTMLElement": "VZ.Projector.ProjectorUtil.PolymerHTMLElement", - "Projection": "VZ.Projector.Data.Projection", - "ProjectionType": "VZ.Projector.Data.ProjectionType", - "Projector": "VZ.Projector.Projector", - "ProjectorInput": "VZ.Projector.ProjectorInput.ProjectorInput", - "SpriteAndMetadataInfo": "VZ.Projector.Data.SpriteAndMetadataInfo", - "State": "VZ.Projector.Data.State", - "Vector": "VZ.Projector.Vector.Vector", - "data": "VZ.Projector.Data", - "util": "VZ.Projector.Util", - "vector": "VZ.Projector.Vector", - }, - "VZ.Projector": { - "AnalyticsLogger": "VZ.Projector.AnalyticsLogger.AnalyticsLogger", - "BookmarkPanel": "VZ.Projector.ProjectorBookmarkPanel.BookmarkPanel", - "ColorOption": "VZ.Projector.Data.ColorOption", - "ColumnStats": "VZ.Projector.Data.ColumnStats", - "DataPanel": "VZ.Projector.ProjectorDataPanel.DataPanel", - "DataPoint": "VZ.Projector.Data.DataPoint", - "DataProto": "VZ.Projector.Data.DataProto", - "DataProvider": "VZ.Projector.DataProvider.DataProvider", - "DataSet": "VZ.Projector.Data.DataSet", - "DemoDataProvider": "VZ.Projector.DataProviderDemo.DemoDataProvider", - "DistanceFunction": "VZ.Projector.Data.DistanceFunction", - "DistanceMetricChangedListener": "VZ.Projector.ProjectorEventContext.DistanceMetricChangedListener", - "EmbeddingInfo": "VZ.Projector.DataProvider.EmbeddingInfo", - "HoverListener": "VZ.Projector.ProjectorEventContext.HoverListener", - "InspectorPanel": "VZ.Projector.ProjectorInspectorPanel.InspectorPanel", - "MetadataCard": "VZ.Projector.ProjectorMetadataCard.MetadataCard", - "MouseMode": "VZ.Projector.ScatterPlot.MouseMode", - "PointMetadata": "VZ.Projector.Data.PointMetadata", - "PolymerElement": "VZ.Projector.ProjectorUtil.PolymerElement", - "PolymerHTMLElement": "VZ.Projector.ProjectorUtil.PolymerHTMLElement", - "Projection": "VZ.Projector.Data.Projection", - "ProjectionChangedListener": "VZ.Projector.ProjectorEventContext.ProjectionChangedListener", - "ProjectionsPanel": "VZ.Projector.ProjectorProjectionsPanel.ProjectionsPanel", - "ProjectorEventContext": "VZ.Projector.ProjectorEventContext.ProjectorEventContext", - "ProjectorScatterPlotAdapter": "VZ.Projector.ProjectorScatterPlotAdapter.ProjectorScatterPlotAdapter", - "ProtoDataProvider": "VZ.Projector.DataProviderProto.ProtoDataProvider", - "SelectionChangedListener": "VZ.Projector.ProjectorEventContext.SelectionChangedListener", - "ServerDataProvider": "VZ.Projector.DataProviderServer.ServerDataProvider", - "ServingMode": "VZ.Projector.DataProvider.ServingMode", - "SpriteAndMetadataInfo": "VZ.Projector.Data.SpriteAndMetadataInfo", - "State": "VZ.Projector.Data.State", - "data": "VZ.Projector.Data", - "knn": "VZ.Projector.Knn", - "logging": "VZ.Projector.Logging", - "stateGetAccessorDimensions": "VZ.Projector.Data.stateGetAccessorDimensions", - "util": "VZ.Projector.Util", - }, - }, +ts_web_library( + name = "sptree", + srcs = ["sptree.ts"], + path = "/vz-projector", +) + +ts_web_library( + name = "bh_tsne", + srcs = ["bh_tsne.ts"], + path = "/vz-projector", + deps = [":sptree"], ) filegroup( @@ -352,97 +108,3 @@ filegroup( srcs = glob(["**"]), tags = ["notsan"], ) - -#### Legacy for other consumers -load( - "//tensorflow/tensorboard:defs.bzl", - "tensorboard_webcomponent_library", - "tensorboard_ts_library", - "tensorboard_ts_declaration", -) - -# Standalone embedding projector demos should depend on this target. We -# exclude the HTML file for the dashboard itself. Demos do not need that -# HTML file. This was introduced because standalone demos as of today -# have an additional Closure pass that uses a compilation configuration -# stricter than that of TensorBoard. - -_PROJECTOR_LIB_TS_LIB_DEPS = [ - ":ts_lib", - ":tsne_ts_lib", -] - -_PROJECTOR_DESTDIR = "vz-projector" - -_PROJECTOR_LIB_DEPS = [ - "//third_party/javascript/polymer/v1/iron-collapse:lib", - "//third_party/javascript/polymer/v1/iron-icons:lib", - "//third_party/javascript/polymer/v1/paper-button:lib", - "//third_party/javascript/polymer/v1/paper-checkbox:lib", - "//third_party/javascript/polymer/v1/paper-dialog:lib", - "//third_party/javascript/polymer/v1/paper-dialog-scrollable:lib", - "//third_party/javascript/polymer/v1/paper-dropdown-menu:lib", - "//third_party/javascript/polymer/v1/paper-icon-button:lib", - "//third_party/javascript/polymer/v1/paper-input:lib", - "//third_party/javascript/polymer/v1/paper-item:lib", - "//third_party/javascript/polymer/v1/paper-listbox:lib", - "//third_party/javascript/polymer/v1/paper-slider:lib", - "//third_party/javascript/polymer/v1/paper-spinner:lib", - "//third_party/javascript/polymer/v1/paper-toast:lib", - "//third_party/javascript/polymer/v1/paper-toggle-button:lib", - "//third_party/javascript/polymer/v1/paper-tooltip:lib", - "//third_party/javascript/polymer/v1/polymer:lib", -] - -tensorboard_ts_library( - name = "tsne_ts_lib", - srcs = [ - "bh_tsne.ts", - "sptree.ts", - ], -) - -tensorboard_ts_declaration( - name = "external", - srcs = ["external.d.ts"], -) - -tensorboard_ts_library( - name = "ts_lib", - srcs = glob( - ["*.ts"], - exclude = [ - "*.d.ts", - "*_test.ts", - "bh_tsne.ts", - "sptree.ts", - ], - ), - runtime_deps = [ - "//third_party/javascript/d3/v4:d3", - "//third_party/javascript/numericjs", - "//third_party/javascript/threejs/r77:threejs", - "//third_party/javascript/threejs/r77/examples/js/controls:orbitcontrols", - "//third_party/javascript/weblas", - ], - deps = [ - ":external", - ":tsne_ts_lib", - "//third_party/javascript/node_modules/typescript:es2015.promise", - "//third_party/javascript/typings/d3_v4:bundle", - "//third_party/javascript/typings/polymer:polymer_without_externs", - "//third_party/javascript/typings/threejs:three", - "//third_party/javascript/typings/webcomponents_js", - ], -) - -tensorboard_webcomponent_library( - name = "lib", - srcs = glob( - ["*.html"], - exclude = ["vz-projector-dashboard.html"], - ), - ts_lib_deps = _PROJECTOR_LIB_TS_LIB_DEPS, - destdir = _PROJECTOR_DESTDIR, - deps = _PROJECTOR_LIB_DEPS, -) diff --git a/tensorflow/tensorboard/components/vz_projector/bundle.html b/tensorflow/tensorboard/components/vz_projector/bundle.html index 2837fed8708..de87763673b 100644 --- a/tensorflow/tensorboard/components/vz_projector/bundle.html +++ b/tensorflow/tensorboard/components/vz_projector/bundle.html @@ -21,4 +21,36 @@ limitations under the License. - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts b/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts index 9d6df953d65..c0da9526598 100644 --- a/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts +++ b/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 - import {DataSet, DistanceFunction, Projection, ProjectionComponents3D, State} from './data'; import {NearestEntry} from './knn'; import {ProjectorEventContext} from './projectorEventContext'; diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerCanvasLabels.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerCanvasLabels.ts index ece4d84ef28..2f3146d213c 100644 --- a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerCanvasLabels.ts +++ b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerCanvasLabels.ts @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 import {BoundingBox, CollisionGrid} from './label'; import {CameraType, RenderContext} from './renderContext'; import {ScatterPlotVisualizer} from './scatterPlotVisualizer'; diff --git a/tensorflow/tensorboard/components/vz_projector/test/BUILD b/tensorflow/tensorboard/components/vz_projector/test/BUILD index 7629272c350..a73c50dcd6d 100644 --- a/tensorflow/tensorboard/components/vz_projector/test/BUILD +++ b/tensorflow/tensorboard/components/vz_projector/test/BUILD @@ -3,76 +3,31 @@ package( default_visibility = ["//tensorflow:internal"], ) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "test", srcs = [ - "bundle.js", + "assert.ts", + "data-provider_test.ts", + "data_test.ts", + "sptree_test.ts", "tests.html", + "util_test.ts", + # "scatterPlotRectangleSelector_test.ts", + # "vz-projector-projections-panel_test.ts", ], path = "/vz-projector/test", deps = [ + "//tensorflow/tensorboard/components/tf_imports:polymer", + "//tensorflow/tensorboard/components/tf_imports:web_component_tester", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", "//tensorflow/tensorboard/components/vz_projector", - "@org_npmjs_registry_web_component_tester", - "@org_polymer", - "@org_polymer_webcomponentsjs", ], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["bundle.ts"], - typings = [ - "@org_definitelytyped//:chai.d.ts", - "@org_definitelytyped//:mocha.d.ts", - "@org_definitelytyped//:polymer.d.ts", - "@org_definitelytyped//:three.d.ts", - "@org_definitelytyped//:webcomponents.js.d.ts", - "//tensorflow/tensorboard/components/tf_imports:d3.d.ts", - "//tensorflow/tensorboard/components/tf_imports:plottable.d.ts", - "//tensorflow/tensorboard/components/vz_projector:bundle.d.ts", - ], -) - -tensorboard_typescript_bundle( - name = "bundle", - out = "bundle.ts", - namespace_srcs = { - "VZ.Projector.Test": [ - "assert.ts", - "sptree_test.ts", - "data_test.ts", - "data-provider_test.ts", - "util_test.ts", - - # TODO(smilkov): Migrate these away from jasmine. - # "scatterPlotRectangleSelector_test.ts", - # "vz-projector-projections-panel_test.ts", - ], - }, - namespace_symbol_aliases = { - "VZ.Projector.Test": { - "BoundingBox": "VZ.Projector.ScatterPlotRectangleSelector.BoundingBox", - "DataPoint": "VZ.Projector.Data.DataPoint", - "DataSet": "VZ.Projector.Data.DataSet", - "ProjectionsPanel": "VZ.Projector.ProjectorProjectionsPanel.ProjectionsPanel", - "SPTree": "VZ.Projector.SPTree.SPTree", - "ScatterPlotRectangleSelector": "VZ.Projector.ScatterPlotRectangleSelector.ScatterPlotRectangleSelector", - "SpriteAndMetadataInfo": "VZ.Projector.Data.SpriteAndMetadataInfo", - "State": "VZ.Projector.Data.State", - "State": "VZ.Projector.Data.State", - "data_provider": "VZ.Projector.DataProvider", - "stateGetAccessorDimensions": "VZ.Projector.Data.stateGetAccessorDimensions", - "util": "VZ.Projector.Util", - }, - }, -) - filegroup( name = "all_files", testonly = 0, diff --git a/tensorflow/tensorboard/components/vz_projector/test/tests.html b/tensorflow/tensorboard/components/vz_projector/test/tests.html index dd43079bde1..a6843d0d6b8 100644 --- a/tensorflow/tensorboard/components/vz_projector/test/tests.html +++ b/tensorflow/tensorboard/components/vz_projector/test/tests.html @@ -21,4 +21,11 @@ limitations under the License. - + + + + + + + diff --git a/tensorflow/tensorboard/components/vz_projector/vector.ts b/tensorflow/tensorboard/components/vz_projector/vector.ts index 0de78ad85df..cab30483138 100644 --- a/tensorflow/tensorboard/components/vz_projector/vector.ts +++ b/tensorflow/tensorboard/components/vz_projector/vector.ts @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 import {assert} from './util'; /** diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-dashboard.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-dashboard.html index 55c15da5ed7..8223c503ecd 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-dashboard.html +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-dashboard.html @@ -37,10 +37,9 @@ limitations under the License. diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.ts index a6847ed3c87..a9b6f6c5a06 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.ts +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.ts @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 import {ColorOption, ColumnStats, SpriteAndMetadataInfo} from './data'; import {DataProvider, EmbeddingInfo, parseRawMetadata, parseRawTensors, ProjectorConfig} from './data-provider'; import * as util from './util'; diff --git a/tensorflow/tensorboard/components/vz_sorting/BUILD b/tensorflow/tensorboard/components/vz_sorting/BUILD index 96e270ce21f..fc309ce4a5d 100644 --- a/tensorflow/tensorboard/components/vz_sorting/BUILD +++ b/tensorflow/tensorboard/components/vz_sorting/BUILD @@ -1,30 +1,24 @@ package(default_visibility = ["//tensorflow:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "vz_sorting", srcs = [ - "bundle.js", + "sorting.ts", "vz-sorting.html", ], path = "/vz-sorting", visibility = ["//visibility:public"], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["bundle.ts"], -) - -tensorboard_typescript_bundle( - name = "bundle", - out = "bundle.ts", - namespace_srcs = {"VZ.Sorting": ["sorting.ts"]}, +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":vz_sorting"], + destdir = "vz-sorting", ) filegroup( @@ -32,25 +26,3 @@ filegroup( srcs = glob(["**"]), tags = ["notsan"], ) - -################################################################################ -# MARKED FOR DELETION - -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [ - "vz-sorting.html", - ":legacy_ts", - ], - destdir = "vz-sorting", -) - -tensorboard_ts_library( - name = "legacy_ts", - srcs = ["sorting.ts"], - deps_mgmt = "off", - runtime = "nodejs", -) diff --git a/tensorflow/tensorboard/components/vz_sorting/test/BUILD b/tensorflow/tensorboard/components/vz_sorting/test/BUILD index 07913e3cbde..5f3a951f689 100644 --- a/tensorflow/tensorboard/components/vz_sorting/test/BUILD +++ b/tensorflow/tensorboard/components/vz_sorting/test/BUILD @@ -3,39 +3,30 @@ package( default_visibility = ["//tensorflow:internal"], ) -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") -load("//tensorflow/tensorboard:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:vulcanize.bzl", "tensorboard_html_binary") +load("//tensorflow/tensorboard:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -web_library( +ts_web_library( name = "test", srcs = [ - "bundle.js", + "sortingTests.ts", "tests.html", ], path = "/vz-sorting/test", deps = [ + "//tensorflow/tensorboard/components/tf_imports:web_component_tester", "//tensorflow/tensorboard/components/vz_sorting", - "@org_npmjs_registry_web_component_tester", ], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["bundle.ts"], - typings = [ - "@org_definitelytyped//:mocha.d.ts", - "@org_definitelytyped//:chai.d.ts", - "//tensorflow/tensorboard/components/vz_sorting:bundle.d.ts", - ], -) - -tensorboard_typescript_bundle( - name = "bundle", - out = "bundle.ts", - namespace_srcs = {"VZ.Sorting": ["sortingTests.ts"]}, +tensorboard_html_binary( + name = "devserver", + compilation_level = "WHITESPACE_ONLY", + input_path = "/vz-sorting/test/tests.html", + output_path = "/vz-sorting/test/tests.html", + deps = [":test"], ) filegroup( diff --git a/tensorflow/tensorboard/components/vz_sorting/test/tests.html b/tensorflow/tensorboard/components/vz_sorting/test/tests.html index d1b4a1db31c..f92c608cdb1 100644 --- a/tensorflow/tensorboard/components/vz_sorting/test/tests.html +++ b/tensorflow/tensorboard/components/vz_sorting/test/tests.html @@ -17,7 +17,7 @@ limitations under the License. --> - + + - - + diff --git a/tensorflow/tensorboard/components/vz_sorting/vz-sorting.html b/tensorflow/tensorboard/components/vz_sorting/vz-sorting.html index 9f925951cb2..5ff6f311589 100644 --- a/tensorflow/tensorboard/components/vz_sorting/vz-sorting.html +++ b/tensorflow/tensorboard/components/vz_sorting/vz-sorting.html @@ -15,4 +15,4 @@ See the License for the specific language governing permissions and limitations under the License. --> - + diff --git a/tensorflow/tensorboard/defs.bzl b/tensorflow/tensorboard/defs.bzl index 827a74b173f..b3712a8156d 100644 --- a/tensorflow/tensorboard/defs.bzl +++ b/tensorflow/tensorboard/defs.bzl @@ -12,83 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -_DEFAULT_TYPINGS = [ - "@com_microsoft_typescript//:lib.es6.d.ts", -] - -def tensorboard_typescript_genrule(name, srcs, typings=[], **kwargs): - """Filegroup of compiled TypeScript sources. - - This is a very unsophisticated TypeScript rule where the user is responsible - for passing all typings and sources via srcs. It's meant as a stopgap because - TypeScript rules currently don't exist for Bazel. The definition of this rule - will need to evolve as more ts_library rules are migrated. - """ - for src in srcs: - if (src.startswith("/") or - src.endswith(".d.ts") or - not src.endswith(".ts")): - fail("srcs must be typescript sources in same package") - typings_out = [src[:-3] + ".d.ts" for src in srcs] - inputs = _DEFAULT_TYPINGS + typings + srcs - # These inputs are meant to work around a sandbox bug in Bazel. If we list - # @com_microsoft_typescript//:tsc.sh under tools, then its - # data attribute won't be considered when --genrule_strategy=sandboxed. See - # https://github.com/bazelbuild/bazel/issues/1147 and its linked issues. - data = [ - "@org_nodejs", - "@com_microsoft_typescript", - ] - native.genrule( - name = name, - srcs = inputs + data, - outs = [src[:-3] + ".js" for src in srcs] + typings_out, - cmd = "$(location @com_microsoft_typescript//:tsc.sh)" + - " --inlineSourceMap" + - " --inlineSources" + - # Do not follow triple slash references within typings. - " --noResolve" + - " --declaration" + - " --module es6" + - " --outDir $(@D) " + - " ".join(["$(locations %s)" % i for i in inputs]), - tools = ["@com_microsoft_typescript//:tsc.sh"], - **kwargs - ) - native.filegroup( - name = name + "_typings", - srcs = typings_out, - **kwargs - ) - -def tensorboard_karma_web_test_suite(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass - -def tensorboard_ts_config(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass - -def tensorboard_ts_declaration(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass - -def tensorboard_ts_development_sources(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass - -def tensorboard_ts_devserver(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass - -def tensorboard_ts_library(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass - def tensorboard_webcomponent_library(**kwargs): """Rules referencing this will be deleted from the codebase soon.""" pass - -def tensorboard_wct_test_suite(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass diff --git a/tensorflow/tensorboard/http_api.md b/tensorflow/tensorboard/http_api.md index 0cf788601a7..c62de0376d2 100644 --- a/tensorflow/tensorboard/http_api.md +++ b/tensorflow/tensorboard/http_api.md @@ -55,17 +55,9 @@ all of the data available from the TensorBoard server. Here is an example: { "train_run": { - "compressedHistograms": ["foo_histogram", "bar_histogram"], - "audio": ["input_audio"], - "graph": true, "firstEventTimestamp": 123456.789 - "run_metadata": ["forward prop", "inference"] }, "eval": { - "compressedHistograms": ["foo_histogram", "bar_histogram"], - "audio": ["input_audio"], - "graph": false, - "run_metadata": [] } } @@ -80,9 +72,13 @@ will have the same tag type across different runs. Each of the following tag types `` has been migrated to `/data/plugin//tags`, and will not appear in the output from this route: + - `audio` - `images` - `scalars` + - `compressedHistograms`, moved to `distributions` - `histograms` + - `graph`, as `/data/plugin/graphs/runs` + - `run_metadata`, as `/data/plugin/graphs/run_metadata_tags` ## `/data/plugin/scalars/tags` @@ -161,7 +157,21 @@ Annotated Example: (note - real data is higher precision) ] ] -## '/data/compressedHistograms?run=foo&tag=bar' +## `/data/plugin/distributions/tags` + +Returns a dictionary mapping from `run_name` (quoted string) to arrays of +`tag_name` (quoted string), where each array contains the names of all +distribution tags present in the corresponding run. Here is an example: + + { + "train_run": ["foo_histogram", "bar_histogram"], + "eval": ["foo_histogram", "bar_histogram"] + } + +Note that runs without any distribution tags are included as keys with +value the empty array. + +## `/data/plugin/distributions/distributions?run=foo&tag=bar` Returns an array of event_accumulator.CompressedHistogramEvents ([wall_time, step, CompressedHistogramValues]) for the given run and tag. @@ -181,8 +191,8 @@ Annotated Example: (note - real data is higher precision) [ 1441154832.580509, # wall_time 5, # step - [ [0, -3.67], # CompressedHistogramValue for 0th percentile - [2500, -4.19], # CompressedHistogramValue for 25th percentile + [ [0, -3.67], # CompressedHistogramValue for 0th percentile + [2500, -4.19], # CompressedHistogramValue for 25th percentile [5000, 6.29], [7500, 1.64], [10000, 3.67] @@ -238,13 +248,13 @@ tags present in the corresponding run. Here is an example: Note that runs without any image tags are included as keys with value the empty array. -## `/audio?run=foo&tag=bar` +## `/data/plugin/audio/audio?run=foo&tag=bar` Gets a sample of AudioMetadatas for the given run and tag. Returns an array of objects containing information about available audio, crucially including the query parameter that may be used to retrieve that audio. -(See /individualAudio for details.) +(See /data/plugin/audio/individualAudio for details.) For example: @@ -256,7 +266,7 @@ For example: # param for /individualAudio } -## `/individualAudio?{{query}}` +## `/data/plugin/audio/individualAudio?{{query}}` Retrieves an individual audio clip. The audio query should not be generated by the frontend, but instead acquired from calling the /audio route (the audio @@ -270,11 +280,33 @@ replaced with other clips. (See Notes for details on the reservoir sampling.) An example call to this route would look like this: /individualAudio?index=0&tagname=input%2Faudio%2F2&run=train -## `/data/graph?run=foo&limit_attr_size=1024&large_attrs_key=key` +## `/data/plugin/audio/tags` -Returns the graph definition for the given run in gzipped pbtxt format. The -graph is composed of a list of nodes, where each node is a specific TensorFlow -operation which takes as inputs other nodes (operations). +Returns a dictionary mapping from `run_name` (quoted string) to arrays of +`tag_name` (quoted string), where each array contains the names of all audio +tags present in the corresponding run. Here is an example: + + { + "train": ["foo_audio", "bar_audio"], + "eval": ["foo_audio", "bar_audio"], + } + +Note that runs without any audio tags are included as keys with value the empty +array. + +## `/data/plugin/graphs/runs` + +Returns a list of runs that have associated graphs. + +For example: + + ["train"] + +## `/data/plugin/graphs/graph?run=foo&limit_attr_size=1024&large_attrs_key=key` + +Returns the graph definition for the given run in pbtxt format. The +graph is composed of a list of nodes, where each node is a specific +TensorFlow operation which takes as inputs other nodes (operations). The query parameters `limit_attr_size` and `large_attrs_key` are optional. @@ -287,7 +319,10 @@ attributes that are too large. The value of this key (list of strings) should be used by the client in order to determine which attributes have been filtered. Must be specified if `limit_attr_size` is specified. -For the query `/graph?run=foo&limit_attr_size=1024&large_attrs_key=_too_large`, +For the query + + /data/plugin/graphs/graph?run=foo&limit_attr_size=1024&large_attrs_key=_too_large, + here is an example pbtxt response of a graph with 3 nodes, where the second node had two large attributes "a" and "b" that were filtered out (size > 1024): diff --git a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/BUILD b/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/BUILD index 447dff55a3f..f2ea14503a0 100644 --- a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/BUILD +++ b/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/BUILD @@ -5,6 +5,11 @@ licenses(["notice"]) # Apache 2.0 java_binary( name = "Vulcanize", srcs = ["Vulcanize.java"], + jvm_flags = [ + "-Xss20m", # JSCompiler needs big stacks for recursive parsing + "-XX:+UseParallelGC", # Best GC when app isn't latency sensitive + "-Djava.util.logging.SimpleFormatter.format='%1$$tY-%1$$tm-%1$$td %1$$tH:%1$$tM:%1$$tS.%1$$tL %4$$-6s %5$$s%6$$s%n'", # Less log spam + ], visibility = ["//visibility:public"], deps = [ "@com_google_guava", @@ -29,6 +34,21 @@ java_binary( ], ) +# These JS files are always taken into consideration by the Closure Compiler +# when vulcanizing, per vulcanize.bzl. +filegroup( + name = "jslibs", + srcs = [ + # Ordering probably matters + "@com_google_javascript_closure_compiler_externs", + "@com_google_javascript_closure_compiler_externs_polymer", + "externs.js", + "@com_google_javascript_closure_library//:closure/goog/base.js", + "@com_google_javascript_closure_library//:closure/goog/deps.js", + ], + visibility = ["//visibility:public"], +) + filegroup( name = "all_files", srcs = glob(["**"]), diff --git a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java b/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java index e572415856c..2635f9b12f1 100644 --- a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java +++ b/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java @@ -15,23 +15,33 @@ package org.tensorflow.tensorboard.vulcanize; import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Verify.verify; import static com.google.common.base.Verify.verifyNotNull; import static java.nio.charset.StandardCharsets.UTF_8; import com.google.common.base.CharMatcher; import com.google.common.base.Joiner; +import com.google.common.base.Optional; +import com.google.common.base.Splitter; +import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; -import com.google.javascript.jscomp.BasicErrorManager; +import com.google.common.collect.Lists; +import com.google.common.collect.Multimap; import com.google.javascript.jscomp.CheckLevel; +import com.google.javascript.jscomp.CompilationLevel; import com.google.javascript.jscomp.Compiler; import com.google.javascript.jscomp.CompilerOptions; -import com.google.javascript.jscomp.CompilerOptions.LanguageMode; -import com.google.javascript.jscomp.CompilerOptions.Reach; +import com.google.javascript.jscomp.DiagnosticGroup; +import com.google.javascript.jscomp.DiagnosticGroups; +import com.google.javascript.jscomp.DiagnosticType; import com.google.javascript.jscomp.JSError; import com.google.javascript.jscomp.PropertyRenamingPolicy; +import com.google.javascript.jscomp.Result; import com.google.javascript.jscomp.SourceFile; -import com.google.javascript.jscomp.VariableRenamingPolicy; +import com.google.javascript.jscomp.WarningsGuard; import com.google.protobuf.TextFormat; import io.bazel.rules.closure.Webpath; import io.bazel.rules.closure.webfiles.BuildInfo.Webfiles; @@ -44,12 +54,17 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.nio.file.StandardOpenOption; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import org.jsoup.Jsoup; +import org.jsoup.nodes.Attribute; import org.jsoup.nodes.Comment; import org.jsoup.nodes.DataNode; import org.jsoup.nodes.Document; @@ -63,21 +78,45 @@ import org.jsoup.parser.Tag; /** Simple one-off solution for TensorBoard vulcanization. */ public final class Vulcanize { + private static final Pattern IGNORE_PATHS_PATTERN = + Pattern.compile("/(?:polymer|marked-element)/.*"); + + private static final ImmutableSet EXTRA_JSDOC_TAGS = + ImmutableSet.of("attribute", "hero", "group", "required"); + + private static final Pattern WEBPATH_PATTERN = Pattern.compile("//~~WEBPATH~~([^\n]+)"); + private static final Parser parser = Parser.htmlParser(); private static final Map webfiles = new HashMap<>(); private static final Set alreadyInlined = new HashSet<>(); private static final Set legalese = new HashSet<>(); private static final List licenses = new ArrayList<>(); private static final List stack = new ArrayList<>(); + private static final List sourcesFromJsLibraries = new ArrayList<>(); + private static final Map sourcesFromScriptTags = new LinkedHashMap<>(); + private static final Map sourceTags = new LinkedHashMap<>(); + private static final Multimap suppressions = HashMultimap.create(); + private static CompilationLevel compilationLevel; private static Webpath outputPath; + private static Node firstCompiledScript; private static Node licenseComment; - private static boolean nominify; + private static int insideDemoSnippet; + private static boolean testOnly; public static void main(String[] args) throws IOException { - Webpath inputPath = Webpath.get(args[0]); - outputPath = Webpath.get(args[1]); - Path output = Paths.get(args[2]); - for (int i = 3; i < args.length; i++) { + compilationLevel = CompilationLevel.fromString(args[0]); + testOnly = args[1].equals("true"); + Webpath inputPath = Webpath.get(args[2]); + outputPath = Webpath.get(args[3]); + Path output = Paths.get(args[4]); + for (int i = 5; i < args.length; i++) { + if (args[i].endsWith(".js")) { + sourcesFromJsLibraries.add(SourceFile.fromFile(args[i])); + continue; + } + if (!args[i].endsWith(".pbtxt")) { + continue; + } Webfiles manifest = loadWebfilesPbtxt(Paths.get(args[i])); for (WebfilesSource src : manifest.getSrcList()) { webfiles.put(Webpath.get(src.getWebpath()), Paths.get(src.getPath())); @@ -86,6 +125,7 @@ public final class Vulcanize { stack.add(inputPath); Document document = parse(Files.readAllBytes(webfiles.get(inputPath))); transform(document); + compile(); if (licenseComment != null) { licenseComment.attr("comment", String.format("\n%s\n", Joiner.on("\n\n").join(licenses))); } @@ -134,72 +174,32 @@ public final class Vulcanize { } private static Node enterNode(Node node) throws IOException { - Node newNode = node; + if (node.nodeName().equals("demo-snippet")) { + insideDemoSnippet++; + } + if (insideDemoSnippet > 0) { + return node; + } if (node instanceof Element) { - if (node.nodeName().equals("link") && node.attr("rel").equals("import")) { - // Inline HTML. - Webpath href = me().lookup(Webpath.get(node.attr("href"))); - if (alreadyInlined.add(href)) { - newNode = - parse(Files.readAllBytes(checkNotNull(webfiles.get(href), "%s in %s", href, me()))); - stack.add(href); - node.replaceWith(newNode); - } else { - newNode = new TextNode("", node.baseUri()); - node.replaceWith(newNode); - } - } else if (node.nodeName().equals("script")) { - nominify = node.hasAttr("nominify"); - node.removeAttr("nominify"); - Webpath src; - String script; - if (node.attr("src").isEmpty()) { - // Minify JavaScript. - StringBuilder sb = new StringBuilder(); - for (Node child : node.childNodes()) { - if (child instanceof DataNode) { - sb.append(((DataNode) child).getWholeData()); - } - } - src = me(); - script = sb.toString(); - } else { - // Inline JavaScript. - src = me().lookup(Webpath.get(node.attr("src"))); - Path other = webfiles.get(src); - if (other != null) { - script = new String(Files.readAllBytes(other), UTF_8); - node.removeAttr("src"); - } else { - src = me(); - script = ""; - } - } - script = minify(src, script); - newNode = - new Element(Tag.valueOf("script"), node.baseUri(), node.attributes()) - .appendChild(new DataNode(script, node.baseUri())); - node.replaceWith(newNode); - } else if (node.nodeName().equals("link") - && node.attr("rel").equals("stylesheet") - && !node.attr("href").isEmpty()) { - // Inline CSS. - Webpath href = me().lookup(Webpath.get(node.attr("href"))); - Path other = webfiles.get(href); - if (other != null) { - newNode = - new Element(Tag.valueOf("style"), node.baseUri(), node.attributes()) - .appendChild( - new DataNode(new String(Files.readAllBytes(other), UTF_8), node.baseUri())); - newNode.removeAttr("rel"); - newNode.removeAttr("href"); - node.replaceWith(newNode); + if (!getAttrTransitive(node, "vulcanize-noinline").isPresent()) { + if (node.nodeName().equals("link") && node.attr("rel").equals("import")) { + // Inline HTML. + node = visitHtmlImport(node); + } else if (node.nodeName().equals("script") + && !shouldIgnoreUri(node.attr("src")) + && !node.hasAttr("jscomp-ignore")) { + node = visitScript(node); + } else if (node.nodeName().equals("link") + && node.attr("rel").equals("stylesheet") + && !node.attr("href").isEmpty() + && !shouldIgnoreUri(node.attr("href"))) { + node = visitStylesheet(node); } } - rootifyAttribute(newNode, "href"); - rootifyAttribute(newNode, "src"); - rootifyAttribute(newNode, "action"); - rootifyAttribute(newNode, "assetpath"); + rootifyAttribute(node, "href"); + rootifyAttribute(node, "src"); + rootifyAttribute(node, "action"); + rootifyAttribute(node, "assetpath"); } else if (node instanceof Comment) { String text = ((Comment) node).getData(); if (text.contains("@license")) { @@ -207,53 +207,230 @@ public final class Vulcanize { if (licenseComment == null) { licenseComment = node; } else { - newNode = new TextNode("", node.baseUri()); - node.replaceWith(newNode); + node = replaceNode(node, new TextNode("", node.baseUri())); } } else { - newNode = new TextNode("", node.baseUri()); - node.replaceWith(newNode); + node = replaceNode(node, new TextNode("", node.baseUri())); } } + return node; + } + + private static Node leaveNode(Node node) { + if (node instanceof Document) { + stack.remove(stack.size() - 1); + } else if (node.nodeName().equals("demo-snippet")) { + insideDemoSnippet--; + } + return node; + } + + private static Node visitHtmlImport(Node node) throws IOException { + Webpath href = me().lookup(Webpath.get(node.attr("href"))); + if (alreadyInlined.add(href)) { + stack.add(href); + Document subdocument = parse(Files.readAllBytes(getWebfile(href))); + for (Attribute attr : node.attributes()) { + subdocument.attr(attr.getKey(), attr.getValue()); + } + return replaceNode(node, subdocument); + } else { + return replaceNode(node, new TextNode("", node.baseUri())); + } + } + + private static Node visitScript(Node node) throws IOException { + Webpath path; + String script; + if (node.attr("src").isEmpty()) { + path = makeSyntheticName(".js"); + script = getInlineScriptFromNode(node); + } else { + path = me().lookup(Webpath.get(node.attr("src"))); + script = new String(Files.readAllBytes(getWebfile(path)), UTF_8); + } + if (node.attr("src").endsWith(".min.js") + || getAttrTransitive(node, "jscomp-nocompile").isPresent()) { + Node newScript = + new Element(Tag.valueOf("script"), node.baseUri(), node.attributes()) + .appendChild(new DataNode(script, node.baseUri())) + .removeAttr("src") + .removeAttr("jscomp-nocompile"); + if (firstCompiledScript != null) { + firstCompiledScript.before(newScript); + return replaceNode(node, new TextNode("", node.baseUri())); + } else { + return replaceNode(node, newScript); + } + } else { + if (firstCompiledScript == null) { + firstCompiledScript = node; + } + sourcesFromScriptTags.put(path, script); + sourceTags.put(path, node); + Optional suppress = getAttrTransitive(node, "jscomp-suppress"); + if (suppress.isPresent()) { + if (suppress.get().isEmpty()) { + suppressions.put(path, "*"); + } else { + suppressions.putAll(path, Splitter.on(' ').split(suppress.get())); + } + } + return node; + } + } + + private static Node visitStylesheet(Node node) throws IOException { + Webpath href = me().lookup(Webpath.get(node.attr("href"))); + return replaceNode( + node, + new Element(Tag.valueOf("style"), node.baseUri(), node.attributes()) + .appendChild( + new DataNode( + new String(Files.readAllBytes(getWebfile(href)), UTF_8), node.baseUri())) + .removeAttr("rel") + .removeAttr("href")); + } + + private static Optional getAttrTransitive(Node node, String attr) { + while (node != null) { + if (node.hasAttr(attr)) { + return Optional.of(node.attr(attr)); + } + node = node.parent(); + } + return Optional.absent(); + } + + private static Node replaceNode(Node oldNode, Node newNode) { + oldNode.replaceWith(newNode); return newNode; } - private static String minify(Webpath src, String script) { - if (nominify) { - return script; + private static Path getWebfile(Webpath path) { + return verifyNotNull(webfiles.get(path), "Bad ref: %s -> %s", me(), path); + } + + private static void compile() { + if (sourcesFromScriptTags.isEmpty()) { + return; } - Compiler compiler = new Compiler(new JsPrintlessErrorManager()); + CompilerOptions options = new CompilerOptions(); - options.skipAllCompilerPasses(); // too lazy to get externs - options.setLanguageIn(LanguageMode.ECMASCRIPT_2016); - options.setLanguageOut(LanguageMode.ECMASCRIPT5); + compilationLevel.setOptionsForCompilationLevel(options); + + // Nice options. + options.setColorizeErrorOutput(true); options.setContinueAfterErrors(true); - options.setManageClosureDependencies(false); - options.setRenamingPolicy(VariableRenamingPolicy.LOCAL, PropertyRenamingPolicy.OFF); - options.setShadowVariables(true); - options.setInlineVariables(Reach.LOCAL_ONLY); - options.setFlowSensitiveInlineVariables(true); - options.setInlineFunctions(Reach.LOCAL_ONLY); - options.setAssumeClosuresOnlyCaptureReferences(false); + options.setLanguageIn(CompilerOptions.LanguageMode.ECMASCRIPT_2016); + options.setLanguageOut(CompilerOptions.LanguageMode.ECMASCRIPT5); + options.setGenerateExports(true); + options.setStrictModeInput(false); + options.setExtraAnnotationNames(EXTRA_JSDOC_TAGS); + + // So we can chop JS binary back up into the original script tags. + options.setPrintInputDelimiter(true); + options.setInputDelimiter("//~~WEBPATH~~%name%"); + + // Optimizations that are too advanced for us right now. + options.setPropertyRenaming(PropertyRenamingPolicy.OFF); options.setCheckGlobalThisLevel(CheckLevel.OFF); - options.setFoldConstants(true); - options.setCoalesceVariableNames(true); - options.setDeadAssignmentElimination(true); - options.setCollapseVariableDeclarations(true); - options.setConvertToDottedProperties(true); - options.setLabelRenaming(true); - options.setRemoveDeadCode(true); - options.setOptimizeArgumentsArray(true); - options.setRemoveUnusedVariables(Reach.LOCAL_ONLY); - options.setCollapseObjectLiterals(true); - options.setProtectHiddenSideEffects(true); - //options.setPrettyPrint(true); + options.setRemoveUnusedPrototypeProperties(false); + options.setRemoveUnusedPrototypePropertiesInExterns(false); + options.setRemoveUnusedClassProperties(false); + + // Closure pass. + options.setClosurePass(true); + options.setManageClosureDependencies(true); + options.getDependencyOptions().setDependencyPruning(true); + options.getDependencyOptions().setDependencySorting(false); + options.getDependencyOptions().setMoocherDropping(false); + + // Polymer pass. + options.setPolymerVersion(1); + + // Debug flags. + if (testOnly) { + options.setPrettyPrint(true); + options.setGeneratePseudoNames(true); + options.setExportTestFunctions(true); + } + + // Don't print warnings from