fix merge issues

This commit is contained in:
Andrew Harp 2017-06-03 00:00:53 -04:00
commit 5efd272aab
378 changed files with 9541 additions and 4762 deletions

View File

@ -41,6 +41,15 @@
be replaced by calling `embedding_lookup` or `layers.dense` as pre- or post- be replaced by calling `embedding_lookup` or `layers.dense` as pre- or post-
processing of the rnn. For RNN decoding, this functionality has been replaced processing of the rnn. For RNN decoding, this functionality has been replaced
with an alternative API in `tf.contrib.seq2seq`. with an alternative API in `tf.contrib.seq2seq`.
* Intel MKL Integration (https://software.intel.com/en-us/articles/tensorflow-optimizations-on-modern-intel-architecture). Intel developed a number of
optimized deep learning primitives. In addition to matrix multiplication and
convolution, these building blocks include:
  * Direct batched convolution
  * Pooling: maximum, minimum, average
  * Normalization: LRN, batch normalization
  * Activation: rectified linear unit (ReLU)
  * Data manipulation: multi-dimensional transposition (conversion), split,
    concat, sum and scale.
* TensorForest Estimator now supports SavedModel export for serving. * TensorForest Estimator now supports SavedModel export for serving.
* Support client-provided ClusterSpec's and propagate them to all workers to enable the creation of dynamic TensorFlow clusters. * Support client-provided ClusterSpec's and propagate them to all workers to enable the creation of dynamic TensorFlow clusters.
* TensorFlow C library now available for Windows. * TensorFlow C library now available for Windows.

View File

@ -2,11 +2,11 @@ workspace(name = "org_tensorflow")
http_archive( http_archive(
name = "io_bazel_rules_closure", name = "io_bazel_rules_closure",
sha256 = "4be8a887f6f38f883236e77bb25c2da10d506f2bf1a8e5d785c0f35574c74ca4", sha256 = "edc91f556b762fc5212d1050d00b12e40dd0b0b1c1d5d96886b59e9a30a6cae4",
strip_prefix = "rules_closure-aac19edc557aec9b603cd7ffe359401264ceff0d", strip_prefix = "rules_closure-3f07fb6a58870afbb36051bd5d54da4479561cc6",
urls = [ urls = [
"http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/aac19edc557aec9b603cd7ffe359401264ceff0d.tar.gz", # 2017-05-10 "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/3f07fb6a58870afbb36051bd5d54da4479561cc6.tar.gz", # 2017-05-31
"https://github.com/bazelbuild/rules_closure/archive/aac19edc557aec9b603cd7ffe359401264ceff0d.tar.gz", "https://github.com/bazelbuild/rules_closure/archive/3f07fb6a58870afbb36051bd5d54da4479561cc6.tar.gz",
], ],
) )

View File

@ -393,6 +393,9 @@ filegroup(
"//tensorflow/tensorboard/demo:all_files", "//tensorflow/tensorboard/demo:all_files",
"//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files", "//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files",
"//tensorflow/tensorboard/plugins:all_files", "//tensorflow/tensorboard/plugins:all_files",
"//tensorflow/tensorboard/plugins/audio:all_files",
"//tensorflow/tensorboard/plugins/distributions:all_files",
"//tensorflow/tensorboard/plugins/graphs:all_files",
"//tensorflow/tensorboard/plugins/histograms:all_files", "//tensorflow/tensorboard/plugins/histograms:all_files",
"//tensorflow/tensorboard/plugins/images:all_files", "//tensorflow/tensorboard/plugins/images:all_files",
"//tensorflow/tensorboard/plugins/projector:all_files", "//tensorflow/tensorboard/plugins/projector:all_files",

View File

@ -805,6 +805,7 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
} }
std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec; std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
dim_vec.reserve(num_dims);
for (int i = 0; i < num_dims; ++i) { for (int i = 0; i < num_dims; ++i) {
dim_vec.push_back(ic->MakeDim(dims[i])); dim_vec.push_back(ic->MakeDim(dims[i]));
} }

View File

@ -113,10 +113,12 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs,
feeds.emplace_back(feed.first.name(), feed.second.tensor); feeds.emplace_back(feed.first.name(), feed.second.tensor);
} }
std::vector<string> output_tensor_names; std::vector<string> output_tensor_names;
output_tensor_names.reserve(fetch_outputs.size());
for (auto const& output : fetch_outputs) { for (auto const& output : fetch_outputs) {
output_tensor_names.push_back(output.name()); output_tensor_names.push_back(output.name());
} }
std::vector<string> target_node_names; std::vector<string> target_node_names;
target_node_names.reserve(run_outputs.size());
for (auto const& output : run_outputs) { for (auto const& output : run_outputs) {
target_node_names.push_back(output.node()->name()); target_node_names.push_back(output.node()->name());
} }

View File

@ -44,6 +44,7 @@ Status ComputeTheoreticalJacobianTranspose(
size_t x_num = x_shapes.size(); size_t x_num = x_shapes.size();
// Call AddSymbolicGradients to get 'dxs' (we will feed 'dys'). // Call AddSymbolicGradients to get 'dxs' (we will feed 'dys').
OutputList dys; OutputList dys;
dys.reserve(y_shapes.size());
for (const auto& y_shape : y_shapes) { for (const auto& y_shape : y_shapes) {
// TODO(suharshs): This currently assumes that all x's are the same type. // TODO(suharshs): This currently assumes that all x's are the same type.
dys.push_back(Cast(scope, Const(scope, 1.0, y_shape), xs[0].type())); dys.push_back(Cast(scope, Const(scope, 1.0, y_shape), xs[0].type()));

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/framework/testutil.h"
#include <utility>
#include "tensorflow/cc/client/client_session.h" #include "tensorflow/cc/client/client_session.h"
#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/default_device.h" #include "tensorflow/core/graph/default_device.h"
@ -30,7 +32,7 @@ void GetTensors(const Scope& scope, OutputList tensors,
void GetTensor(const Scope& scope, Output tensor, Tensor* out) { void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
std::vector<Tensor> outputs; std::vector<Tensor> outputs;
GetTensors(scope, {tensor}, &outputs); GetTensors(scope, {std::move(tensor)}, &outputs);
*out = outputs[0]; *out = outputs[0];
} }

View File

@ -350,6 +350,7 @@ Status CompileXla(xla::CompileOnlyClient* client,
compile_result->program_shape = *pshape_or.ValueOrDie(); compile_result->program_shape = *pshape_or.ValueOrDie();
xla::ProgramShape* pshape = &compile_result->program_shape; xla::ProgramShape* pshape = &compile_result->program_shape;
std::vector<const xla::Shape*> arg_layouts; std::vector<const xla::Shape*> arg_layouts;
arg_layouts.reserve(pshape->parameters_size());
for (int i = 0; i < pshape->parameters_size(); ++i) { for (int i = 0; i < pshape->parameters_size(); ++i) {
arg_layouts.push_back(pshape->mutable_parameters(i)); arg_layouts.push_back(pshape->mutable_parameters(i));
} }

View File

@ -218,6 +218,7 @@ cc_library(
deps = [ deps = [
":common", ":common",
":graph_to_functiondef", ":graph_to_functiondef",
":union_find",
"//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/kernels:parallel_check_op", "//tensorflow/compiler/jit/kernels:parallel_check_op",
"//tensorflow/compiler/jit/kernels:xla_local_launch_op", "//tensorflow/compiler/jit/kernels:xla_local_launch_op",
@ -237,6 +238,11 @@ cc_library(
], ],
) )
cc_library(
name = "union_find",
hdrs = ["union_find.h"],
)
cc_test( cc_test(
name = "compilation_passes_test", name = "compilation_passes_test",
size = "small", size = "small",

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <utility>
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/ops.h"
@ -101,12 +103,12 @@ Node* Input(const GraphDefBuilder::Options& opts) {
} }
Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) { Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) {
return ops::UnaryOp("UnaryTest", a, opts); return ops::UnaryOp("UnaryTest", std::move(a), opts);
} }
Node* Binary(ops::NodeOut a, ops::NodeOut b, Node* Binary(ops::NodeOut a, ops::NodeOut b,
const GraphDefBuilder::Options& opts) { const GraphDefBuilder::Options& opts) {
return ops::BinaryOp("BinaryTest", a, b, opts); return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts);
} }
Node* AddNLike(const std::vector<ops::NodeOut>& inputs, Node* AddNLike(const std::vector<ops::NodeOut>& inputs,
@ -127,7 +129,7 @@ Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr; if (opts.HaveError()) return nullptr;
NodeBuilder node_builder(opts.GetNameForOp("Retval"), "_Retval", NodeBuilder node_builder(opts.GetNameForOp("Retval"), "_Retval",
opts.op_registry()); opts.op_registry());
node_builder.Input(a).Attr("index", index); node_builder.Input(std::move(a)).Attr("index", index);
return opts.FinalizeBuilder(&node_builder); return opts.FinalizeBuilder(&node_builder);
} }

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function.h"
@ -206,70 +207,12 @@ Status FindCompilationCandidates(
return Status::OK(); return Status::OK();
} }
// Union-Find data structure used to compute clusters. We use our own struct Cluster {
// implementation because we want one key feature: when merging clusters, we // Identifies the node that represents this cluster in the cycle detection
// need to know which value becomes the representative of the merged clusters. // graph.
// We use the representatives to name nodes in a cycle detection graph, and we int representative = -1;
// need to control which node is named.
// TODO(phawkins): consider merging this code with union-find implementations
// in Tensorflow, e.g., in SimplePlacer.
class Cluster {
public:
Cluster();
int Size() { return FindRoot()->size_; }
// Merges this cluster with 'other'. This cluster's representative becomes
// the representative of the merged cluster; the representative of 'other'
// is ignored.
void Merge(Cluster* other);
// Each cluster has an associated integer 'representative', initialized to -1
// by default.
int GetRepresentative() { return FindRoot()->representative_; }
void SetRepresentative(int representative) {
FindRoot()->representative_ = representative;
}
private:
// Finds the root element of the cluster. Performs path compression.
Cluster* FindRoot();
int representative_;
int rank_;
int size_; // Size of the cluster.
Cluster* parent_;
}; };
Cluster::Cluster()
: representative_(-1), rank_(0), size_(1), parent_(nullptr) {}
void Cluster::Merge(Cluster* other) {
Cluster* a = FindRoot();
Cluster* b = other->FindRoot();
if (a == b) return;
if (a->rank_ > b->rank_) {
b->parent_ = a;
a->size_ += b->size_;
return;
}
a->parent_ = b;
if (a->rank_ == b->rank_) {
b->rank_++;
}
b->representative_ = a->representative_;
b->size_ += a->size_;
}
Cluster* Cluster::FindRoot() {
if (!parent_) return this;
// Path compression: update intermediate nodes to point to the root of the
// equivalence class.
parent_ = parent_->FindRoot();
return parent_;
}
} // anonymous namespace } // anonymous namespace
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
@ -432,10 +375,11 @@ Status MarkForCompilationPass::RunImpl(
// Each compilation candidate belongs to a cluster. The cluster's // Each compilation candidate belongs to a cluster. The cluster's
// representative // representative
// names the node in the 'cycles' graph that represents the cluster. // names the node in the 'cycles' graph that represents the cluster.
std::vector<Cluster> clusters(graph->num_node_ids()); std::vector<UnionFind<Cluster>> clusters(graph->num_node_ids());
std::deque<Cluster*> worklist; std::deque<UnionFind<Cluster>*> worklist;
for (Node* node : compilation_candidates) { for (Node* node : compilation_candidates) {
clusters[node->id()].SetRepresentative(node->id()); Cluster& cluster = clusters[node->id()].Get();
cluster.representative = node->id();
worklist.push_back(&clusters[node->id()]); worklist.push_back(&clusters[node->id()]);
} }
@ -445,7 +389,7 @@ Status MarkForCompilationPass::RunImpl(
// Repeatedly contract edges between clusters that are on the same device, // Repeatedly contract edges between clusters that are on the same device,
// provided the contraction would not create a cycle. // provided the contraction would not create a cycle.
while (!worklist.empty()) { while (!worklist.empty()) {
int from = worklist.front()->GetRepresentative(); int from = worklist.front()->Get().representative;
worklist.pop_front(); worklist.pop_front();
Node* node_from = graph->FindNodeId(from); Node* node_from = graph->FindNodeId(from);
@ -518,7 +462,7 @@ Status MarkForCompilationPass::RunImpl(
// Count the number of elements in each cluster. // Count the number of elements in each cluster.
std::vector<int> cluster_sizes(graph->num_node_ids()); std::vector<int> cluster_sizes(graph->num_node_ids());
for (const Node* n : compilation_candidates) { for (const Node* n : compilation_candidates) {
int cluster = clusters[n->id()].GetRepresentative(); int cluster = clusters[n->id()].Get().representative;
cluster_sizes[cluster]++; cluster_sizes[cluster]++;
} }
@ -532,7 +476,7 @@ Status MarkForCompilationPass::RunImpl(
// if compilation is enabled, otherwise there will be no such candidates). // if compilation is enabled, otherwise there will be no such candidates).
const int min_cluster_size = flags->tf_xla_min_cluster_size; const int min_cluster_size = flags->tf_xla_min_cluster_size;
for (Node* n : compilation_candidates) { for (Node* n : compilation_candidates) {
int cluster = clusters[n->id()].GetRepresentative(); int cluster = clusters[n->id()].Get().representative;
// Compile if the user marked this node _XlaCompile=true // Compile if the user marked this node _XlaCompile=true
bool compile_attr = false; bool compile_attr = false;

View File

@ -0,0 +1,81 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_UNION_FIND_H_
#define TENSORFLOW_COMPILER_JIT_UNION_FIND_H_
namespace tensorflow {
// Union-Find data structure.
//
// Every element starts out as its own single-element cluster. Each cluster has
// an associated value of type T, stored on the cluster's root element; when
// merging clusters we can control which value becomes the representative of
// the merged clusters. Values must be default-constructible and copyable
// (Merge() may copy-assign the surviving value onto the new root).
//
// Elements refer to one another by address (see parent_), so an element must
// not be copied or relocated once it has taken part in a Merge(); e.g. size a
// std::vector<UnionFind<T>> up front rather than growing it afterwards.
template <typename T>
class UnionFind {
 public:
  // Creates a single-element cluster. The value is value-initialized so that
  // Get() on a fresh cluster is well-defined even when T is a scalar or
  // pointer type (0 / nullptr); class types still run their default
  // constructor and default member initializers.
  UnionFind() : rank_(0), size_(1), parent_(nullptr), value_() {}

  // Returns the number of elements in a cluster.
  // Not const: looking up the root performs path compression.
  int Size() { return FindRoot()->size_; }

  // Merges this cluster with 'other'. This cluster's value becomes
  // the value of the merged cluster; the value of 'other' is ignored.
  // Merging two elements that already share a cluster is a no-op.
  // 'other' must not be null.
  void Merge(UnionFind* other);

  // Each cluster has an associated value. Retrieves the value associated
  // with this cluster. The returned reference aliases the value stored on the
  // current root, so writes through it are seen by every element of the
  // cluster.
  T& Get() { return FindRoot()->value_; }

 private:
  // Finds the root element of the cluster. Performs path compression.
  UnionFind* FindRoot();

  int rank_;          // Rank used to decide which root survives a merge.
  int size_;          // Size of the cluster; only meaningful on the root.
  UnionFind* parent_; // Null for a root, otherwise the next element upward.
  T value_;           // Cluster value; only meaningful on the root.
};

template <typename T>
void UnionFind<T>::Merge(UnionFind* other) {
  UnionFind<T>* a = FindRoot();
  UnionFind<T>* b = other->FindRoot();
  if (a == b) return;  // Already in the same cluster.
  if (a->rank_ > b->rank_) {
    // 'a' stays the root, so its value is kept without any copy.
    b->parent_ = a;
    a->size_ += b->size_;
    return;
  }

  // 'b' becomes the root of the merged cluster.
  a->parent_ = b;
  if (a->rank_ == b->rank_) {
    b->rank_++;
  }
  // The caller's cluster ('a') must provide the merged value, so copy it onto
  // the surviving root.
  b->value_ = a->value_;
  b->size_ += a->size_;
}

template <typename T>
UnionFind<T>* UnionFind<T>::FindRoot() {
  if (!parent_) return this;
  // Path compression: update intermediate nodes to point to the root of the
  // equivalence class.
  parent_ = parent_->FindRoot();
  return parent_;
}
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_UNION_FIND_H_

View File

@ -50,6 +50,7 @@ class FillOp : public XlaOpKernel {
// Convert the dims literal into a vector that we can pass to // Convert the dims literal into a vector that we can pass to
// ComputationBuilder. // ComputationBuilder.
std::vector<int64> broadcast; std::vector<int64> broadcast;
broadcast.reserve(dims_literal.shape().dimensions(0));
for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) { for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) {
broadcast.push_back(xla::LiteralUtil::Get<int>(dims_literal, {i})); broadcast.push_back(xla::LiteralUtil::Get<int>(dims_literal, {i}));
} }

View File

@ -50,6 +50,7 @@ class SliceOp : public XlaOpKernel {
// slice will be an empty handle if the output has no elements. // slice will be an empty handle if the output has no elements.
CHECK_EQ(begin.size(), size.size()); CHECK_EQ(begin.size(), size.size());
std::vector<int64> limits; std::vector<int64> limits;
limits.reserve(begin.size());
for (int i = 0; i < begin.size(); ++i) { for (int i = 0; i < begin.size(); ++i) {
limits.push_back(begin[i] + size[i]); limits.push_back(begin[i] + size[i]);
} }

View File

@ -18,6 +18,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ #ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"

View File

@ -58,14 +58,13 @@ StatusOr<std::unique_ptr<Literal>> Client::Transfer(
"server provided response without a literal in " "server provided response without a literal in "
"TransferToClient request"); "TransferToClient request");
} }
return MakeUnique<Literal>(response.literal());
return WrapUnique(response.release_literal());
} }
StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer( StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
const Literal& literal, const DeviceHandle* device_handle) { const Literal& literal, const DeviceHandle* device_handle) {
TransferToServerRequest request; TransferToServerRequest request;
*request.mutable_literal() = literal; *request.mutable_literal() = literal.ToProto();
if (device_handle) { if (device_handle) {
*request.mutable_device_handle() = *device_handle; *request.mutable_device_handle() = *device_handle;
} }
@ -93,7 +92,7 @@ StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
Status Client::TransferToInfeed(const Literal& literal, int64 replica_id, Status Client::TransferToInfeed(const Literal& literal, int64 replica_id,
const DeviceHandle* device_handle) { const DeviceHandle* device_handle) {
TransferToInfeedRequest request; TransferToInfeedRequest request;
*request.mutable_literal() = literal; *request.mutable_literal() = literal.ToProto();
if (device_handle) { if (device_handle) {
*request.mutable_device_handle() = *device_handle; *request.mutable_device_handle() = *device_handle;
} }
@ -141,7 +140,8 @@ StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed(
"TransferToClient request"); "TransferToClient request");
} }
return WrapUnique(response.release_literal()); Literal literal(response.literal());
return MakeUnique<Literal>(literal);
} }
Status Client::ResetDevice() { Status Client::ResetDevice() {

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/service_interface.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"

View File

@ -165,9 +165,10 @@ ComputationDataHandle ComputationBuilder::ConstantOp(
} }
ConstantRequest request; ConstantRequest request;
Literal* literal = request.mutable_literal(); Literal literal;
populate(literal); populate(&literal);
VLOG(3) << "created constant: " << literal->ShortDebugString(); *request.mutable_literal() = literal.ToProto();
VLOG(3) << "created constant: " << request.literal().ShortDebugString();
OpRequest op_request; OpRequest op_request;
*op_request.mutable_constant_request() = request; *op_request.mutable_constant_request() = request;
*op_request.mutable_computation() = computation_.handle(); *op_request.mutable_computation() = computation_.handle();

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/global_data.h"
#include <string> #include <string>
#include <utility>
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
@ -23,7 +24,7 @@ limitations under the License.
namespace xla { namespace xla {
GlobalData::GlobalData(ServiceInterface* parent, GlobalDataHandle handle) GlobalData::GlobalData(ServiceInterface* parent, GlobalDataHandle handle)
: handle_(handle), parent_(parent) {} : handle_(std::move(handle)), parent_(parent) {}
GlobalData::~GlobalData() { GlobalData::~GlobalData() {
UnregisterRequest request; UnregisterRequest request;

View File

@ -222,8 +222,9 @@ tensorflow::Status LocalExecutable::RecordArguments(
SessionModule* session_module) { SessionModule* session_module) {
session_module->clear_arguments(); session_module->clear_arguments();
for (const ShapedBuffer* argument : arguments) { for (const ShapedBuffer* argument : arguments) {
TF_RETURN_IF_ERROR( Literal literal;
LiteralFromShapedBuffer(*argument, session_module->add_arguments())); TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*argument, &literal));
*session_module->add_arguments() = literal.ToProto();
} }
return tensorflow::Status::OK(); return tensorflow::Status::OK();
} }
@ -231,9 +232,13 @@ tensorflow::Status LocalExecutable::RecordArguments(
tensorflow::Status LocalExecutable::RecordResult( tensorflow::Status LocalExecutable::RecordResult(
const ShapedBuffer* result, SessionModule* session_module) { const ShapedBuffer* result, SessionModule* session_module) {
session_module->clear_result(); session_module->clear_result();
return LiteralFromShapedBuffer(*result, session_module->mutable_result()); Literal literal(session_module->result());
TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*result, &literal));
*session_module->mutable_result() = literal.ToProto();
return tensorflow::Status::OK();
} }
// TODO(dnovillo) Change signature to return StatusOr<Literal>.
tensorflow::Status LocalExecutable::LiteralFromShapedBuffer( tensorflow::Status LocalExecutable::LiteralFromShapedBuffer(
const ShapedBuffer& shaped_buffer, Literal* literal) { const ShapedBuffer& shaped_buffer, Literal* literal) {
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -856,5 +856,26 @@ TEST_F(LiteralUtilTest, ConvertR4) {
EXPECT_TRUE(LiteralUtil::Equal(*expected, *converted)); EXPECT_TRUE(LiteralUtil::Equal(*expected, *converted));
} }
TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
LiteralProto p;
p.mutable_shape()->set_element_type(PRED);
for (int len = 0; len < 25; ++len) {
p.mutable_shape()->clear_dimensions();
p.mutable_shape()->add_dimensions(len);
p.clear_preds();
for (int i = 0; i < len; ++i) {
p.add_preds((i % 2) == (len % 2));
}
Literal literal(p);
ASSERT_EQ(len, literal.preds_size());
int i = 0;
for (auto it = literal.preds().begin(); it < literal.preds().end(); ++it) {
EXPECT_EQ((i % 2) == (len % 2), *it);
++i;
}
}
}
} // namespace } // namespace
} // namespace xla } // namespace xla

View File

@ -60,8 +60,8 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
int64 elements = ShapeUtil::ElementsIn(shape); int64 elements = ShapeUtil::ElementsIn(shape);
LiteralUtil::Resize(elements, std::numeric_limits<float>::quiet_NaN(), LiteralUtil::Resize(elements, std::numeric_limits<float>::quiet_NaN(),
result.get()); result.get());
tensorflow::protobuf::RepeatedField<float>* field = result->mutable_f32s(); std::vector<float>* field = result->mutable_f32s();
char* data = tensorflow::bit_cast<char*>(field->mutable_data()); char* data = tensorflow::bit_cast<char*>(field->data());
uint64 bytes = elements * sizeof(float); uint64 bytes = elements * sizeof(float);
tensorflow::StringPiece sp; tensorflow::StringPiece sp;
auto s = file_->Read(offset_, bytes, &sp, data); auto s = file_->Read(offset_, bytes, &sp, data);

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <memory> #include <memory>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/reference_util.h"
#include <array> #include <array>
#include <utility>
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
@ -331,7 +332,8 @@ ReferenceUtil::ConvArray4DGeneralDimensions(
std::pair<int64, int64> kernel_stride, Padding padding, std::pair<int64, int64> kernel_stride, Padding padding,
ConvolutionDimensionNumbers dimension_numbers) { ConvolutionDimensionNumbers dimension_numbers) {
return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding, return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
{1, 1}, {1, 1}, dimension_numbers); {1, 1}, {1, 1},
std::move(dimension_numbers));
} }
/* static */ std::unique_ptr<Array4D<float>> /* static */ std::unique_ptr<Array4D<float>>

View File

@ -529,6 +529,7 @@ cc_library(
srcs = ["transfer_manager.cc"], srcs = ["transfer_manager.cc"],
hdrs = ["transfer_manager.h"], hdrs = ["transfer_manager.h"],
deps = [ deps = [
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",
@ -1680,10 +1681,8 @@ cc_library(
deps = [ deps = [
":buffer_assignment", ":buffer_assignment",
":hlo", ":hlo",
":hlo_ordering",
":hlo_proto", ":hlo_proto",
"//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib", "//tensorflow/core:lib",
], ],
) )

View File

@ -171,6 +171,7 @@ StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
executor, allocation->device_memory(), allocation->shape())); executor, allocation->device_memory(), allocation->shape()));
std::vector<GlobalDataHandle> element_handles; std::vector<GlobalDataHandle> element_handles;
element_handles.reserve(element_bases.size());
for (int i = 0; i < element_bases.size(); ++i) { for (int i = 0; i < element_bases.size(); ++i) {
element_handles.push_back(RegisterInternal( element_handles.push_back(RegisterInternal(
allocation->backend(), allocation->device_ordinal(), element_bases[i], allocation->backend(), allocation->device_ordinal(), element_bases[i],

View File

@ -229,7 +229,8 @@ Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices(
// Mapping from LogicalBuffer to index (used to detect non-distinct indices). // Mapping from LogicalBuffer to index (used to detect non-distinct indices).
FlatMap<const LogicalBuffer*, std::vector<ShapeIndex>> FlatMap<const LogicalBuffer*, std::vector<ShapeIndex>>
buffer_to_source_indices; buffer_to_source_indices;
TF_RETURN_IF_ERROR(points_to.ForEachElement([this, &buffer_to_source_indices]( TF_RETURN_IF_ERROR(points_to.ForEachElement(
[this, &buffer_to_source_indices](
const ShapeIndex& index, bool /*is_leaf*/, const ShapeIndex& index, bool /*is_leaf*/,
const std::vector<const LogicalBuffer*>& buffers) { const std::vector<const LogicalBuffer*>& buffers) {
if (buffers.size() > 1) { if (buffers.size() > 1) {
@ -449,10 +450,14 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
FlatMap<const HloInstruction*, HloInstruction*>* shared_copies) { FlatMap<const HloInstruction*, HloInstruction*>* shared_copies) {
const HloInstruction* init_hlo = while_hlo->operand(0); const HloInstruction* init_hlo = while_hlo->operand(0);
const PointsToSet& points_to = points_to_analysis.GetPointsToSet(init_hlo); const PointsToSet& points_to = points_to_analysis.GetPointsToSet(init_hlo);
// Mapping from LogicalBuffer to index (used to detect non-distinct indices).
FlatSet<const LogicalBuffer*> buffer_set;
ShapeTree<HloInstruction*> copy_overrides(init_hlo->shape()); ShapeTree<HloInstruction*> copy_overrides(init_hlo->shape());
TF_RETURN_IF_ERROR(points_to.ForEachElement( TF_RETURN_IF_ERROR(points_to.ForEachElement(
[init_hlo, read_only_indices, shared_copies, &copy_overrides]( [init_hlo, read_only_indices, shared_copies, &buffer_set,
const ShapeIndex& index, bool /*is_leaf*/, &copy_overrides](const ShapeIndex& index, bool /*is_leaf*/,
const std::vector<const LogicalBuffer*>& buffers) { const std::vector<const LogicalBuffer*>& buffers) {
// Look for read-only entry parameters. // Look for read-only entry parameters.
if (!read_only_indices->element(index)) { if (!read_only_indices->element(index)) {
@ -468,6 +473,7 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
if (!is_entry_parameter && !is_constant) { if (!is_entry_parameter && !is_constant) {
continue; continue;
} }
// We have found an entry parameter or constant that is read-only in // We have found an entry parameter or constant that is read-only in
// the while body. These buffers are managed by the caller, and cannot // the while body. These buffers are managed by the caller, and cannot
// be aliased with non-parameter buffers. Revert this read-only index, // be aliased with non-parameter buffers. Revert this read-only index,
@ -476,16 +482,17 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
// Optimization to allow multiple while loops that share the same // Optimization to allow multiple while loops that share the same
// read-only entry parameters (or constants) to share a single copy. // read-only entry parameters (or constants) to share a single copy.
// Only unambiguous array-shaped buffers are allowed, to reduce code // Only unambiguous and distinct array-shaped buffers are allowed, to
// complexity. The shape of the entry parameter must be identical to // reduce code complexity. The shape of the entry parameter must be
// the shape of the init_hlo at this index, to ensure there were no // identical to the shape of the init_hlo at this index, to ensure
// intervening bitcast or GTE instructions, which are also hard to // there were no intervening bitcast or GTE instructions, which are
// handle. // also hard to handle.
const Shape& pointee_shape = pointee->shape(); const Shape& pointee_shape = pointee->shape();
const Shape& init_shape = const Shape& init_shape =
ShapeUtil::GetSubshape(init_hlo->shape(), index); ShapeUtil::GetSubshape(init_hlo->shape(), index);
if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) && if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) &&
ShapeUtil::Equal(pointee_shape, init_shape)) { ShapeUtil::Equal(pointee_shape, init_shape) &&
buffer_set.count(buffer) < 1) {
HloInstruction** copy = &(*shared_copies)[pointee]; HloInstruction** copy = &(*shared_copies)[pointee];
if (*copy == nullptr) { if (*copy == nullptr) {
*copy = *copy =
@ -496,6 +503,9 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
*copy_overrides.mutable_element(index) = *copy; *copy_overrides.mutable_element(index) = *copy;
} }
// Record this buffer so that any later index pointing to it is detected as non-distinct.
buffer_set.insert(buffer);
// We've already reverted the read-only index and handled the // We've already reverted the read-only index and handled the
// single-copy optimization above, so there's nothing more to do. // single-copy optimization above, so there's nothing more to do.
break; break;

View File

@ -44,13 +44,20 @@ class CopyInsertionTest : public HloTestBase {
EXPECT_IS_OK(copy_insertion.Run(module).status()); EXPECT_IS_OK(copy_insertion.Run(module).status());
// Verify the points to set of the root of the computation after copy // Verify the points to set of the root of the computation after copy
// insertion contains no constants or parameters. // insertion contains no constants or parameters, and is distinct and
// non-ambiguous.
auto points_to_analysis = auto points_to_analysis =
TuplePointsToAnalysis::Run(module).ConsumeValueOrDie(); TuplePointsToAnalysis::Run(module).ConsumeValueOrDie();
const auto& points_to = points_to_analysis->GetPointsToSet(
module->entry_computation()->root_instruction());
EXPECT_TRUE(points_to.IsDistinct());
EXPECT_TRUE(!points_to.IsAmbiguous());
tensorflow::gtl::FlatSet<const LogicalBuffer*> maybe_live_out_buffers = tensorflow::gtl::FlatSet<const LogicalBuffer*> maybe_live_out_buffers =
points_to_analysis points_to_analysis
->GetPointsToSet(module->entry_computation()->root_instruction()) ->GetPointsToSet(module->entry_computation()->root_instruction())
.CreateFlattenedSet(); .CreateFlattenedSet();
for (const LogicalBuffer* buffer : maybe_live_out_buffers) { for (const LogicalBuffer* buffer : maybe_live_out_buffers) {
EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kConstant); EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kConstant);
EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kParameter); EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kParameter);
@ -390,6 +397,47 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
return builder.Build(); return builder.Build();
} }
// Builds a While body computation with two output tuple elements dependent on
// both input tuple elements.
//
// EX: Body({in0, in1, in2})
// out0 = Add(in0, 1)
// out1 = in1
// out2 = in2
// Tuple(out0, out1, out2)
std::unique_ptr<HloComputation> BuildDependentBodyComputation2() {
auto builder = HloComputation::Builder(TestName() + ".Body");
const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
{induction_variable_shape_, data_shape_, data_shape_});
auto loop_state = builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
// Update the induction variable GTE(0).
auto induction_variable =
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
induction_variable_shape_, loop_state, 0));
auto inc = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
// add0 = Add(in0, 1)
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
// data1 = GTE(1).
HloInstruction* data1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
// data2 = GTE(2).
HloInstruction* data2 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 2));
// Create output Tuple.
builder.AddInstruction(HloInstruction::CreateTuple({add0, data1, data2}));
return builder.Build();
}
// Builds a While body computation with read-only tuple element 0. // Builds a While body computation with read-only tuple element 0.
// EX: // EX:
// Body({in0, in1}) // Body({in0, in1})
@ -408,6 +456,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
// Update data GTE(1). // Update data GTE(1).
auto data = builder.AddInstruction( auto data = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
// Use 'induction_variable' in computation with no path to output tuple. // Use 'induction_variable' in computation with no path to output tuple.
auto update = builder.AddInstruction( auto update = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8})); HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
@ -431,6 +480,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
// Create param instruction to access loop state. // Create param instruction to access loop state.
const Shape& loop_state_shape = const Shape& loop_state_shape =
nested ? nested_loop_state_shape_ : loop_state_shape_; nested ? nested_loop_state_shape_ : loop_state_shape_;
auto loop_state = builder.AddInstruction( auto loop_state = builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
// Update the induction variable GTE(0). // Update the induction variable GTE(0).
@ -972,7 +1022,8 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) {
op::Copy(old_init->operand(1)->operand(0))))); op::Copy(old_init->operand(1)->operand(0)))));
} }
// Tests while init instruction buffer which interferes with while result buffer. // Tests while init instruction buffer which interferes with while result
// buffer.
// //
// init_data = Broadcast(...) // init_data = Broadcast(...)
// add_unrelated = Add(init_data) // takes a reference to cause interference // add_unrelated = Add(init_data) // takes a reference to cause interference
@ -989,5 +1040,81 @@ TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) {
op::Copy(old_init->operand(1)))); op::Copy(old_init->operand(1))));
} }
// Tests a while init tuple that has a non-distinct points-to set:
//
//   init = Tuple(Parameter(S32, {}), Parameter(F32, {8}),
//                Parameter(F32, {8}))
//
// where the second and third elements are the same parameter *and* the tuple
// is shared by a second while instruction.
//
// Verifies that, after copy insertion, the points-to set of each while's init
// tuple is distinct (non-identical Copys). In other words, verifies that copy
// sharing does not place the same copy at two indices of the resulting tuple.
TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
  auto cond_a = module_.AddEmbeddedComputation(BuildConditionComputation());
  auto cond_b = module_.AddEmbeddedComputation(BuildConditionComputation());
  // Each loop body emits a three-element tuple derived from the init tuple.
  auto body_a = module_.AddEmbeddedComputation(BuildDependentBodyComputation2());
  auto body_b = module_.AddEmbeddedComputation(BuildDependentBodyComputation2());

  HloComputation::Builder builder(TestName() + ".While");
  auto iter = builder.AddInstruction(
      HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
  auto data = builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "data"));

  // The init tuple refers to the same parameter buffer at indices 1 and 2.
  auto shared_init = builder.AddInstruction(
      HloInstruction::CreateTuple({iter, data, data}));

  const Shape state_shape = ShapeUtil::MakeTupleShape(
      {induction_variable_shape_, data_shape_, data_shape_});

  // Both while loops consume the very same init tuple.
  auto while_a = builder.AddInstruction(
      HloInstruction::CreateWhile(state_shape, cond_a, body_a, shared_init));
  auto while_b = builder.AddInstruction(
      HloInstruction::CreateWhile(state_shape, cond_b, body_b, shared_init));

  module_.AddEntryComputation(builder.Build());

  auto points_to_analysis =
      TuplePointsToAnalysis::Run(&module_).ConsumeValueOrDie();

  // Precondition: before copy insertion both init tuples are non-distinct.
  const auto& init_points_to_a =
      points_to_analysis->GetPointsToSet(while_a->operand(0));
  ASSERT_FALSE(init_points_to_a.IsDistinct());
  const auto& init_points_to_b =
      points_to_analysis->GetPointsToSet(while_b->operand(0));
  ASSERT_FALSE(init_points_to_b.IsDistinct());

  // Remember the original init operands; copy insertion replaces them.
  auto old_init_a = while_a->operand(0);
  auto old_init_b = while_b->operand(0);

  InsertCopies(&module_);

  // Every element of each new init tuple must be a Copy of the old element.
  EXPECT_THAT(while_a->operand(0),
              op::Tuple(op::Copy(old_init_a->operand(0)),
                        op::Copy(old_init_a->operand(1)),
                        op::Copy(old_init_a->operand(2))));
  EXPECT_THAT(while_b->operand(0),
              op::Tuple(op::Copy(old_init_b->operand(0)),
                        op::Copy(old_init_b->operand(1)),
                        op::Copy(old_init_b->operand(2))));

  // Re-run the analysis: after copy insertion both init tuples are distinct.
  points_to_analysis = TuplePointsToAnalysis::Run(&module_).ConsumeValueOrDie();
  EXPECT_TRUE(
      points_to_analysis->GetPointsToSet(while_a->operand(0)).IsDistinct());
  EXPECT_TRUE(
      points_to_analysis->GetPointsToSet(while_b->operand(0)).IsDistinct());
}
} // namespace } // namespace
} // namespace xla } // namespace xla

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <vector> #include <vector>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"

View File

@ -31,7 +31,7 @@ AsyncExecution::AsyncExecution(Backend* backend,
: backend_(CHECK_NOTNULL(backend)), : backend_(CHECK_NOTNULL(backend)),
streams_(std::move(streams)), streams_(std::move(streams)),
profile_(profile), profile_(profile),
result_(result) { result_(std::move(result)) {
for (const auto& stream : streams_) { for (const auto& stream : streams_) {
CHECK(stream != nullptr); CHECK(stream != nullptr);
} }

View File

@ -254,6 +254,7 @@ TEST_F(HloScheduleTest, LatticeMatMul) {
// d40 -- layer 4 // d40 -- layer 4
HloComputation::Builder builder("entry_computation"); HloComputation::Builder builder("entry_computation");
std::vector<HloInstruction*> params; std::vector<HloInstruction*> params;
params.reserve(6);
for (int i = 0; i < 6; ++i) { for (int i = 0; i < 6; ++i) {
params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i))));

View File

@ -1631,6 +1631,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildKernelThunk(
// Compute the input buffer indices. // Compute the input buffer indices.
std::vector<BufferAllocation::Slice> io_buffers; std::vector<BufferAllocation::Slice> io_buffers;
io_buffers.reserve(io_hlos.size());
for (const HloInstruction* io_hlo : io_hlos) { for (const HloInstruction* io_hlo : io_hlos) {
io_buffers.push_back(GetAllocationSlice(*LatestNonGteAncestor(io_hlo))); io_buffers.push_back(GetAllocationSlice(*LatestNonGteAncestor(io_hlo)));
} }

View File

@ -86,6 +86,7 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) {
// d40 -- layer 4 // d40 -- layer 4
HloComputation::Builder builder("entry_computation"); HloComputation::Builder builder("entry_computation");
std::vector<HloInstruction*> params; std::vector<HloInstruction*> params;
params.reserve(6);
for (int i = 0; i < 6; ++i) { for (int i = 0; i < 6; ++i) {
params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i))));

View File

@ -46,7 +46,7 @@ message HloInstructionProto {
xla.OpMetadata metadata = 7; xla.OpMetadata metadata = 7;
// Literal, only present for kConstant. // Literal, only present for kConstant.
xla.Literal literal = 8; xla.LiteralProto literal = 8;
// Parameter info, only present for kParameter. // Parameter info, only present for kParameter.
int64 parameter_number = 9; int64 parameter_number = 9;

View File

@ -311,7 +311,6 @@ void ComputeComputationPostOrder(
visited->insert(computation); visited->insert(computation);
post_order->push_back(computation); post_order->push_back(computation);
return;
} }
} // namespace } // namespace

View File

@ -65,7 +65,7 @@ using ::tensorflow::strings::StrCat;
WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil())); WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()));
instruction->operands_.push_back(operand); instruction->operands_.push_back(operand);
instruction->literal_.reset(new Literal); instruction->literal_.reset(new Literal);
*instruction->literal_->mutable_u8s() += tag; instruction->literal_->append_u8s(tag);
return instruction; return instruction;
} }
@ -1484,6 +1484,7 @@ string HloInstruction::ToString(bool compact_operands,
} }
if (!slice_starts_.empty() && !slice_limits_.empty()) { if (!slice_starts_.empty() && !slice_limits_.empty()) {
std::vector<string> bounds; std::vector<string> bounds;
bounds.reserve(slice_starts_.size());
for (int i = 0; i < slice_starts_.size(); ++i) { for (int i = 0; i < slice_starts_.size(); ++i) {
bounds.push_back( bounds.push_back(
StrCat("[", slice_starts_[i], ":", slice_limits_[i], "]")); StrCat("[", slice_starts_[i], ":", slice_limits_[i], "]"));
@ -1550,7 +1551,7 @@ HloInstructionProto HloInstruction::ToProto() const {
*proto.mutable_metadata() = metadata_; *proto.mutable_metadata() = metadata_;
switch (opcode_) { switch (opcode_) {
case HloOpcode::kConstant: case HloOpcode::kConstant:
*proto.mutable_literal() = *literal_; *proto.mutable_literal() = literal_->ToProto();
break; break;
case HloOpcode::kParameter: case HloOpcode::kParameter:
proto.set_parameter_number(parameter_number_); proto.set_parameter_number(parameter_number_);
@ -1647,10 +1648,10 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
trace_instruction_ = trace_instruction; trace_instruction_ = trace_instruction;
} }
const string& HloInstruction::tracing_tag() const { string HloInstruction::TracingTag() const {
CHECK_EQ(HloOpcode::kTrace, opcode()); CHECK_EQ(HloOpcode::kTrace, opcode());
CHECK(literal_ != nullptr); CHECK(literal_ != nullptr);
return literal_->u8s(); return literal_->u8s_string();
} }
bool HloInstruction::IsFused() const { bool HloInstruction::IsFused() const {

View File

@ -30,6 +30,7 @@ limitations under the License.
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@ -535,7 +536,7 @@ class HloInstruction {
// Returns a tag to be used in tracing. // Returns a tag to be used in tracing.
// //
// Precondition: opcode() == HloOpcode::kTrace // Precondition: opcode() == HloOpcode::kTrace
const string& tracing_tag() const; string TracingTag() const;
// Returns whether the instruction is a constant. // Returns whether the instruction is a constant.
bool IsConstant() const; bool IsConstant() const;

View File

@ -151,7 +151,26 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
return true; return true;
}; };
if (std::all_of(hlo->users().begin(), hlo->users().end(), // An "effectively unary" operation is one that has one "large"
// input with the others being negligible in terms of memory usage.
// We use "has a smaller true rank than the output" as a heuristic
// for "negligible" memory usage.
auto effectively_unary = [](HloInstruction* hlo) {
if (hlo->operands().size() == 1) {
return true;
}
auto output_rank = ShapeUtil::TrueRank(hlo->shape());
return std::count_if(
hlo->operands().begin(), hlo->operands().end(),
[output_rank](HloInstruction* operand) {
return ((operand->opcode() != HloOpcode::kBroadcast) &&
ShapeUtil::TrueRank(operand->shape()) >=
output_rank);
}) <= 1;
};
if (effectively_unary(hlo) ||
std::all_of(hlo->users().begin(), hlo->users().end(),
user_fusable_into_hlo)) { user_fusable_into_hlo)) {
all_consumers_fusable.insert(hlo); all_consumers_fusable.insert(hlo);
} }

View File

@ -156,21 +156,67 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {
TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
HloComputation::Builder builder(TestName()); HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( auto shape = ShapeUtil::MakeShape(F32, {16, 16});
0, ShapeUtil::MakeShape(F32, {16, 16}), "0")); auto param0 =
HloInstruction* unary1 = builder.AddInstruction(HloInstruction::CreateUnary( builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
ShapeUtil::MakeShape(S32, {}), HloOpcode::kFloor, param0)); auto param1 =
builder.AddInstruction(HloInstruction::CreateSend(unary1, 0)); builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
HloInstruction* unary2 = builder.AddInstruction(HloInstruction::CreateUnary( HloInstruction* binary1 = builder.AddInstruction(
ShapeUtil::MakeShape(S32, {}), HloOpcode::kAbs, unary1)); HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
builder.AddInstruction(HloInstruction::CreateSend(binary1, 0));
HloInstruction* unary = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
auto module = MakeUnique<HloModule>(TestName()); auto module = MakeUnique<HloModule>(TestName());
auto computation = module->AddEntryComputation(builder.Build()); auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(unary2, computation->root_instruction()); EXPECT_EQ(unary, computation->root_instruction());
EXPECT_FALSE( EXPECT_FALSE(
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
.Run(module.get()) .Run(module.get())
.ValueOrDie()); .ValueOrDie());
} }
// Builds Floor(param) with two users, a Send and an Abs (the root), and
// expects the fusion pass, run with duplication allowed, to report a change.
TEST_F(InstructionFusionTest, AllowUnaryDuplication) {
  HloComputation::Builder builder(TestName());
  const Shape shape = ShapeUtil::MakeShape(F32, {16, 16});

  HloInstruction* input =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
  HloInstruction* floor_of_input = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kFloor, input));
  // First user of the floor: a Send.
  builder.AddInstruction(HloInstruction::CreateSend(floor_of_input, 0));
  // Second user of the floor: Abs, which is added last.
  HloInstruction* abs_of_floor = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kAbs, floor_of_input));

  auto module = MakeUnique<HloModule>(TestName());
  auto entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(abs_of_floor, entry->root_instruction());

  InstructionFusion fusion(InstructionFusion::IsExpensive,
                           /*may_duplicate=*/true);
  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
}
// Builds Add(small_param, large_param), where only one operand has the rank
// of the output, with two users: a Send and an Abs (the root). Expects the
// fusion pass, run with duplication allowed, to report a change.
TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) {
  const Shape large_shape = ShapeUtil::MakeShape(F32, {16, 16});
  const Shape small_shape = ShapeUtil::MakeShape(F32, {16});

  HloComputation::Builder builder(TestName());
  HloInstruction* small_input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, small_shape, "0"));
  HloInstruction* large_input = builder.AddInstruction(
      HloInstruction::CreateParameter(1, large_shape, "1"));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      large_shape, HloOpcode::kAdd, small_input, large_input));
  // First user of the sum: a Send.
  builder.AddInstruction(HloInstruction::CreateSend(sum, 0));
  // Second user of the sum: Abs, which is added last.
  HloInstruction* abs_of_sum = builder.AddInstruction(
      HloInstruction::CreateUnary(large_shape, HloOpcode::kAbs, sum));

  auto module = MakeUnique<HloModule>(TestName());
  auto entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(abs_of_sum, entry->root_instruction());

  InstructionFusion fusion(InstructionFusion::IsExpensive,
                           /*may_duplicate=*/true);
  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
}
} // namespace xla } // namespace xla

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "external/llvm/include/llvm/IR/Module.h" #include "external/llvm/include/llvm/IR/Module.h"
#include "external/llvm/include/llvm/IR/Value.h" #include "external/llvm/include/llvm/IR/Value.h"
#include "external/llvm/include/llvm/Support/raw_ostream.h" #include "external/llvm/include/llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/stringpiece.h"

View File

@ -77,8 +77,10 @@ tensorflow::Status RecordArguments(
SessionModule* module) { SessionModule* module) {
module->clear_arguments(); module->clear_arguments();
for (const Allocation* allocation : arg_allocations) { for (const Allocation* allocation : arg_allocations) {
TF_RETURN_IF_ERROR(LiteralFromAllocation(allocation, allocation->shape(), Literal argument;
module->add_arguments())); TF_RETURN_IF_ERROR(
LiteralFromAllocation(allocation, allocation->shape(), &argument));
*module->add_arguments() = argument.ToProto();
} }
return tensorflow::Status::OK(); return tensorflow::Status::OK();
} }
@ -87,8 +89,11 @@ tensorflow::Status RecordArguments(
tensorflow::Status RecordResult(const Allocation* result_allocation, tensorflow::Status RecordResult(const Allocation* result_allocation,
SessionModule* module) { SessionModule* module) {
module->clear_result(); module->clear_result();
return LiteralFromAllocation(result_allocation, result_allocation->shape(), Literal result;
module->mutable_result()); TF_RETURN_IF_ERROR(LiteralFromAllocation(
result_allocation, result_allocation->shape(), &result));
*module->mutable_result() = result.ToProto();
return tensorflow::Status::OK();
} }
} // namespace } // namespace
@ -649,6 +654,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
ResolveAndValidateArguments(request.arguments(), execute_backend_.get(), ResolveAndValidateArguments(request.arguments(), execute_backend_.get(),
executor->device_ordinal())); executor->device_ordinal()));
std::vector<se::DeviceMemoryBase> arguments; std::vector<se::DeviceMemoryBase> arguments;
arguments.reserve(arg_allocations.size());
for (const Allocation* allocation : arg_allocations) { for (const Allocation* allocation : arg_allocations) {
arguments.push_back(allocation->device_memory()); arguments.push_back(allocation->device_memory());
} }
@ -677,6 +683,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
BuildExecutables(versioned_handles, std::move(module_configs), BuildExecutables(versioned_handles, std::move(module_configs),
execute_backend_.get(), executors)); execute_backend_.get(), executors));
std::vector<Executable*> executable_ptrs; std::vector<Executable*> executable_ptrs;
executable_ptrs.reserve(executables.size());
for (const auto& executable : executables) { for (const auto& executable : executables) {
executable_ptrs.push_back(executable.get()); executable_ptrs.push_back(executable.get());
} }
@ -752,6 +759,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
<< module_config->entry_computation_layout().ToString(); << module_config->entry_computation_layout().ToString();
std::vector<se::DeviceMemoryBase> arguments; std::vector<se::DeviceMemoryBase> arguments;
arguments.reserve(arg_allocations.size());
for (const Allocation* allocation : arg_allocations) { for (const Allocation* allocation : arg_allocations) {
arguments.push_back(allocation->device_memory()); arguments.push_back(allocation->device_memory());
} }
@ -820,6 +828,7 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
<< module_config->entry_computation_layout().ToString(); << module_config->entry_computation_layout().ToString();
std::vector<se::DeviceMemoryBase> arguments; std::vector<se::DeviceMemoryBase> arguments;
arguments.reserve(arg_allocations.size());
for (const Allocation* allocation : arg_allocations) { for (const Allocation* allocation : arg_allocations) {
arguments.push_back(allocation->device_memory()); arguments.push_back(allocation->device_memory());
} }
@ -908,13 +917,15 @@ tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg,
literal_shape = &allocation->shape(); literal_shape = &allocation->shape();
} }
return LiteralFromAllocation(allocation, *literal_shape, Literal literal;
result->mutable_literal()); auto status = LiteralFromAllocation(allocation, *literal_shape, &literal);
*result->mutable_literal() = literal.ToProto();
return status;
} }
tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
TransferToServerResponse* result) { TransferToServerResponse* result) {
const Literal& literal = arg->literal(); Literal literal = Literal(arg->literal());
const Shape& shape = literal.shape(); const Shape& shape = literal.shape();
if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) { if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) {
@ -978,7 +989,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
} }
return execute_backend_->transfer_manager()->TransferLiteralToInfeed( return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
executor, arg->literal()); executor, Literal(arg->literal()));
} }
tensorflow::Status Service::TransferFromOutfeed( tensorflow::Status Service::TransferFromOutfeed(
@ -1001,8 +1012,12 @@ tensorflow::Status Service::TransferFromOutfeed(
executor = execute_backend_->Replicas()[arg->replica_id()]; executor = execute_backend_->Replicas()[arg->replica_id()];
} }
return execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( Literal literal;
executor, arg->shape_with_layout(), result->mutable_literal()); TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
executor, arg->shape_with_layout(), &literal));
*result->mutable_literal() = literal.ToProto();
return tensorflow::Status::OK();
} }
tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg, tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg,

View File

@ -75,10 +75,10 @@ message SessionModule {
repeated SessionComputation embedded_computations = 2; repeated SessionComputation embedded_computations = 2;
// The arguments passed to the computation. // The arguments passed to the computation.
repeated Literal arguments = 3; repeated LiteralProto arguments = 3;
// The result of the computation. // The result of the computation.
Literal result = 4; LiteralProto result = 4;
// The name of the platform used to run the computation. // The name of the platform used to run the computation.
string execution_platform = 5; string execution_platform = 5;

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <set> #include <set>
#include <vector> #include <vector>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"

View File

@ -121,7 +121,7 @@ TEST_F(CpuTransferManagerTest, TransferR1U8FromDevice) {
const Shape shape = ShapeUtil::MakeShape(U8, {4}); const Shape shape = ShapeUtil::MakeShape(U8, {4});
TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice( TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice(
stream_exec_, memptr, shape, shape, &literal)); stream_exec_, memptr, shape, shape, &literal));
CHECK_EQ("klmn", literal.u8s()); CHECK_EQ("klmn", literal.u8s_string());
} }
TEST_F(CpuTransferManagerTest, TransferBufferFromDevice) { TEST_F(CpuTransferManagerTest, TransferBufferFromDevice) {

View File

@ -2275,7 +2275,7 @@ void ComputationLowerer::Visit(
const ConstantRequest& constant_request = const ConstantRequest& constant_request =
request.request().constant_request(); request.request().constant_request();
hlo_instruction = add_instruction(HloInstruction::CreateConstant( hlo_instruction = add_instruction(HloInstruction::CreateConstant(
LiteralUtil::CloneToUnique(constant_request.literal()))); LiteralUtil::CloneToUnique(Literal(constant_request.literal()))));
break; break;
} }
@ -2467,6 +2467,7 @@ void ComputationLowerer::Visit(
// to append dimensions on the left the broadcast_dimensions should just // to append dimensions on the left the broadcast_dimensions should just
// be the n highest dimension numbers of the output shape where n is // be the n highest dimension numbers of the output shape where n is
// the number of input dimensions. // the number of input dimensions.
broadcast_dimensions.reserve(ShapeUtil::Rank(operand->shape()));
for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) {
broadcast_dimensions.push_back(i + broadcast_dimensions.push_back(i +
ShapeUtil::Rank(request.output_shape()) - ShapeUtil::Rank(request.output_shape()) -

View File

@ -50,7 +50,7 @@ TEST_F(UserComputationTest, SimpleComputation) {
ConstantRequest constant_request; ConstantRequest constant_request;
*constant_request.mutable_literal() = *constant_request.mutable_literal() =
*LiteralUtil::CreateR1<float>({123.0f, 42.0f}); LiteralUtil::CreateR1<float>({123.0f, 42.0f})->ToProto();
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle constant_handle, TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle constant_handle,
computation.AddConstantInstruction(constant_request)); computation.AddConstantInstruction(constant_request));
@ -160,12 +160,13 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) {
UserComputation computation("TheComputation", handle); UserComputation computation("TheComputation", handle);
ConstantRequest a_request; ConstantRequest a_request;
*a_request.mutable_literal() = *LiteralUtil::CreateR1<float>({123.0f, 42.0f}); *a_request.mutable_literal() =
LiteralUtil::CreateR1<float>({123.0f, 42.0f})->ToProto();
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle, TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle,
computation.AddConstantInstruction(a_request)); computation.AddConstantInstruction(a_request));
ConstantRequest b_request; ConstantRequest b_request;
*b_request.mutable_literal() = *LiteralUtil::CreateR0<float>(1.0f); *b_request.mutable_literal() = LiteralUtil::CreateR0<float>(1.0f)->ToProto();
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle, TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle,
computation.AddConstantInstruction(b_request)); computation.AddConstantInstruction(b_request));

View File

@ -44,6 +44,7 @@ struct ShapeTreeNode {
// Children of this node. // Children of this node.
std::vector<std::unique_ptr<ShapeTreeNode>> children; std::vector<std::unique_ptr<ShapeTreeNode>> children;
ShapeTreeNode() = default;
explicit ShapeTreeNode(const T& data) : data(data) {} explicit ShapeTreeNode(const T& data) : data(data) {}
ShapeTreeNode(const ShapeTreeNode& other) ShapeTreeNode(const ShapeTreeNode& other)
@ -85,8 +86,9 @@ class ShapeTree {
public: public:
// Default constructor creates a tree with a nil shape (i.e. an empty tuple). // Default constructor creates a tree with a nil shape (i.e. an empty tuple).
ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {} ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
// Create ShapeTree with the given shape, and default T values for all nodes. // Create ShapeTree with the given shape, and default-constructed T values for
explicit ShapeTree(const Shape& shape) : ShapeTree(shape, T()) {} // all nodes.
explicit ShapeTree(const Shape& shape);
// Create ShapeTree with the given shape, and init_value for all nodes. // Create ShapeTree with the given shape, and init_value for all nodes.
ShapeTree(const Shape& shape, const T& init_value); ShapeTree(const Shape& shape, const T& init_value);
@ -127,6 +129,19 @@ class ShapeTree {
const ShapeIndex& /*index*/, bool /*is_leaf*/, T* /*data*/)>; const ShapeIndex& /*index*/, bool /*is_leaf*/, T* /*data*/)>;
Status ForEachMutableElement(const MutableVisitorFunction& func); Status ForEachMutableElement(const MutableVisitorFunction& func);
// Copy the subtree of values from 'other' rooted at ShapeIndex
// 'source_base_index' into the subtree of value in this ShapeTree rooted at
// 'target_base_index'.
//
// Precondition: The subshape of other.shape() at index source_base_index must
// be compatible with the subshape of shape() at index target_base_index.
void CopySubtreeFrom(const ShapeTree<T>& other,
const ShapeIndex& source_base_index,
const ShapeIndex& target_base_index);
bool operator==(const ShapeTree<T>& other) const;
bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); }
private: private:
using Node = internal::ShapeTreeNode<T>; using Node = internal::ShapeTreeNode<T>;
@ -134,6 +149,10 @@ class ShapeTree {
// the given 'init_value'. // the given 'init_value'.
void InitChildren(const Shape& shape, const T& init_value, Node* node); void InitChildren(const Shape& shape, const T& init_value, Node* node);
// Initialize node->children based on 'shape'. All children have
// default-constructed data values.
void InitChildren(const Shape& shape, Node* node);
// Helpers for traversing the shape via ForEachElement. The helpers // Helpers for traversing the shape via ForEachElement. The helpers
// recursively traverse the subtree rooted at "index" (defined as in // recursively traverse the subtree rooted at "index" (defined as in
// ShapeUtil::GetSubshape). // ShapeUtil::GetSubshape).
@ -165,6 +184,24 @@ void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
} }
} }
// Recursively populates node->children so that the subtree rooted at 'node'
// mirrors the tuple structure of 'shape'. Each child is created with
// 'new Node()', i.e. without an explicit initial value. A non-tuple shape adds
// no children.
template <typename T>
void ShapeTree<T>::InitChildren(const Shape& shape, Node* node) {
  if (!ShapeUtil::IsTuple(shape)) {
    return;
  }
  const auto element_count = ShapeUtil::TupleElementCount(shape);
  for (int element = 0; element < element_count; ++element) {
    node->children.emplace_back(new Node());
    Node* const child = node->children.back().get();
    InitChildren(shape.tuple_shapes(element), child);
  }
}
// Builds a tree whose structure follows 'shape'. The root is initialized with
// 'root_()' and the descendants are created by the value-less InitChildren
// overload, so no initial T value has to be supplied or copied.
template <typename T>
ShapeTree<T>::ShapeTree(const Shape& shape) : root_(), shape_(shape) {
  // shape_ only records the structure of the tree; its layout is cleared here
  // and must not be relied upon afterwards.
  LayoutUtil::ClearLayout(&shape_);
  Node* const root = &root_;
  InitChildren(shape_, root);
}
template <typename T> template <typename T>
ShapeTree<T>::ShapeTree(const Shape& shape, const T& init_value) ShapeTree<T>::ShapeTree(const Shape& shape, const T& init_value)
: root_(init_value), shape_(shape) { : root_(init_value), shape_(shape) {
@ -240,6 +277,48 @@ Status ShapeTree<T>::ForEachMutableElement(const MutableVisitorFunction& func) {
return ForEachMutableHelper(func, &root_, &index); return ForEachMutableHelper(func, &root_, &index);
} }
// Overwrites the values in the subtree of this tree rooted at
// 'target_base_index' with the values of the subtree of 'other' rooted at
// 'source_base_index'. Elements outside the target subtree keep their current
// values. CHECK-fails unless the two subshapes are compatible.
template <typename T>
void ShapeTree<T>::CopySubtreeFrom(const ShapeTree<T>& other,
                                   const ShapeIndex& source_base_index,
                                   const ShapeIndex& target_base_index) {
  CHECK(ShapeUtil::Compatible(
      ShapeUtil::GetSubshape(shape(), target_base_index),
      ShapeUtil::GetSubshape(other.shape(), source_base_index)));
  ForEachMutableElement(
      [&other, &source_base_index, &target_base_index](
          const ShapeIndex& index, bool /*is_leaf*/, T* data) {
        // An element belongs to the target subtree only if its index starts
        // with target_base_index; anything else is left untouched.
        const auto prefix_length = target_base_index.size();
        if (index.size() < prefix_length) {
          return Status::OK();
        }
        for (int position = 0; position < prefix_length; ++position) {
          if (index[position] != target_base_index[position]) {
            return Status::OK();
          }
        }
        // Translate the index into 'other': source_base_index followed by the
        // part of 'index' below target_base_index.
        ShapeIndex source_index = source_base_index;
        for (int position = prefix_length; position < index.size();
             ++position) {
          source_index.push_back(index[position]);
        }
        *data = other.element(source_index);
        return Status::OK();
      })
      .IgnoreError();
}
// Returns true iff every element of this tree compares equal (using T's
// operator!=) to the element of 'other' at the same index.
//
// NOTE(review): the shapes of the two trees are not compared. 'other' is
// looked up at each index of this tree that is visited, so it is assumed to
// contain those indices -- confirm that callers only compare trees of the same
// structure.
template <typename T>
bool ShapeTree<T>::operator==(const ShapeTree<T>& other) const {
  bool equal = true;
  ForEachElement([&other, &equal](const ShapeIndex& index, bool /*is_leaf*/,
                                  const T& data) {
    // Once one mismatch has been found the result is known, so the remaining
    // lookups and comparisons are skipped.
    if (equal && data != other.element(index)) {
      equal = false;
    }
    return Status::OK();
  })
      .IgnoreError();
  return equal;
}
} // namespace xla } // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_ #endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_

View File

@ -245,5 +245,139 @@ TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) {
EXPECT_DEATH(shape_tree.element({0, 0}), ""); EXPECT_DEATH(shape_tree.element({0, 0}), "");
} }
TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) {
  // A ShapeTree built from just a shape must work for a move-only element
  // type such as std::unique_ptr.
  ShapeTree<std::unique_ptr<int>> tree{tuple_shape_};

  // Elements start out as null pointers.
  EXPECT_EQ(tree.element({2}).get(), nullptr);

  // Move-assigning a new value into an element makes it observable.
  *tree.mutable_element({2}) = MakeUnique<int>(42);
  EXPECT_EQ(*tree.element({2}), 42);
}
TEST_F(ShapeTreeTest, CopySubtreeFromArrayShape) {
  // CopySubtreeFrom between two array-shaped trees: the only value involved
  // is the one at the root (empty index).
  ShapeTree<int> src(array_shape_);
  *src.mutable_element(/*index=*/{}) = 42;

  ShapeTree<int> dst(array_shape_, 123);
  EXPECT_EQ(dst.element(/*index=*/{}), 123);

  dst.CopySubtreeFrom(src, /*source_base_index=*/{},
                      /*target_base_index=*/{});
  EXPECT_EQ(dst.element(/*index=*/{}), 42);
}
TEST_F(ShapeTreeTest, FullCopySubtreeFromTupleShape) {
  // Root-to-root copy between two tuple-shaped trees: the root value and every
  // tuple element must be transferred.
  ShapeTree<int> src(tuple_shape_);
  *src.mutable_element(/*index=*/{}) = 10;
  *src.mutable_element(/*index=*/{0}) = 11;
  *src.mutable_element(/*index=*/{1}) = 12;
  *src.mutable_element(/*index=*/{2}) = 13;

  ShapeTree<int> dst(tuple_shape_, 0);
  dst.CopySubtreeFrom(src, /*source_base_index=*/{},
                      /*target_base_index=*/{});

  EXPECT_EQ(dst.element(/*index=*/{}), 10);
  EXPECT_EQ(dst.element(/*index=*/{0}), 11);
  EXPECT_EQ(dst.element(/*index=*/{1}), 12);
  EXPECT_EQ(dst.element(/*index=*/{2}), 13);
}
TEST_F(ShapeTreeTest, SingleElementCopySubtreeFromTupleShape) {
  // Copy exactly one tuple element (source {0} -> destination {1}); all other
  // destination values must keep their initial value of 0.
  ShapeTree<int> src(tuple_shape_);
  *src.mutable_element(/*index=*/{}) = 10;
  *src.mutable_element(/*index=*/{0}) = 11;
  *src.mutable_element(/*index=*/{1}) = 12;
  *src.mutable_element(/*index=*/{2}) = 13;

  ShapeTree<int> dst(tuple_shape_, 0);
  dst.CopySubtreeFrom(src, /*source_base_index=*/{0},
                      /*target_base_index=*/{1});

  EXPECT_EQ(dst.element(/*index=*/{}), 0);
  EXPECT_EQ(dst.element(/*index=*/{0}), 0);
  EXPECT_EQ(dst.element(/*index=*/{1}), 11);
  EXPECT_EQ(dst.element(/*index=*/{2}), 0);
}
TEST_F(ShapeTreeTest, CopySubtreeIntoNestedShape) {
  // Copy a whole two-element tuple tree into position {2, 0} of a
  // nested-tuple-shaped tree.
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({array_shape_, array_shape_});
  ShapeTree<int> src(pair_shape);
  *src.mutable_element(/*index=*/{}) = 10;
  *src.mutable_element(/*index=*/{0}) = 11;
  *src.mutable_element(/*index=*/{1}) = 12;

  ShapeTree<int> dst(nested_tuple_shape_, 0);
  dst.CopySubtreeFrom(src, /*source_base_index=*/{},
                      /*target_base_index=*/{2, 0});

  // Everything outside the {2, 0} subtree keeps its initial value.
  EXPECT_EQ(dst.element(/*index=*/{}), 0);
  EXPECT_EQ(dst.element(/*index=*/{0}), 0);
  EXPECT_EQ(dst.element(/*index=*/{1}), 0);
  EXPECT_EQ(dst.element(/*index=*/{1, 0}), 0);
  EXPECT_EQ(dst.element(/*index=*/{1, 1}), 0);
  EXPECT_EQ(dst.element(/*index=*/{2}), 0);
  // The {2, 0} subtree now holds the source values.
  EXPECT_EQ(dst.element(/*index=*/{2, 0}), 10);
  EXPECT_EQ(dst.element(/*index=*/{2, 0, 0}), 11);
  EXPECT_EQ(dst.element(/*index=*/{2, 0, 1}), 12);
  EXPECT_EQ(dst.element(/*index=*/{2, 1}), 0);
}
TEST_F(ShapeTreeTest, CopySubtreeFromNestedShape) {
  // Copy the subtree at {1} of a nested-tuple-shaped tree onto the root of a
  // two-element tuple tree.
  ShapeTree<int> src(nested_tuple_shape_, 42);
  *src.mutable_element(/*index=*/{1}) = 10;
  *src.mutable_element(/*index=*/{1, 0}) = 11;
  *src.mutable_element(/*index=*/{1, 1}) = 12;

  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({array_shape_, array_shape_});
  ShapeTree<int> dst(pair_shape, 0);
  dst.CopySubtreeFrom(src, /*source_base_index=*/{1},
                      /*target_base_index=*/{});

  EXPECT_EQ(dst.element(/*index=*/{}), 10);
  EXPECT_EQ(dst.element(/*index=*/{0}), 11);
  EXPECT_EQ(dst.element(/*index=*/{1}), 12);
}
TEST_F(ShapeTreeTest, OperatorEquals) {
  {
    // Array-shaped trees: equality depends only on the single root value.
    ShapeTree<int> outlier(array_shape_, 123);
    ShapeTree<int> twin_a(array_shape_, 42);
    ShapeTree<int> twin_b(array_shape_, 42);
    EXPECT_FALSE(outlier == twin_a);
    EXPECT_TRUE(outlier != twin_a);
    EXPECT_TRUE(twin_a == twin_b);
  }
  {
    // Tuple-shaped trees: 'outlier' differs from the twins in elements {0}
    // and {1}, while the twins hold identical values.
    ShapeTree<int> outlier(tuple_shape_);
    *outlier.mutable_element(/*index=*/{}) = 10;
    *outlier.mutable_element(/*index=*/{0}) = 11;
    *outlier.mutable_element(/*index=*/{1}) = 12;

    ShapeTree<int> twin_a(tuple_shape_);
    *twin_a.mutable_element(/*index=*/{}) = 10;
    *twin_a.mutable_element(/*index=*/{0}) = 42;
    *twin_a.mutable_element(/*index=*/{1}) = 11;

    ShapeTree<int> twin_b(tuple_shape_);
    *twin_b.mutable_element(/*index=*/{}) = 10;
    *twin_b.mutable_element(/*index=*/{0}) = 42;
    *twin_b.mutable_element(/*index=*/{1}) = 11;

    EXPECT_FALSE(outlier == twin_a);
    EXPECT_TRUE(outlier != twin_a);
    EXPECT_TRUE(twin_a == twin_b);
    EXPECT_FALSE(twin_a != twin_b);
  }
}
} // namespace } // namespace
} // namespace xla } // namespace xla

View File

@ -122,7 +122,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
for (const auto& shape : parameters) { for (const auto& shape : parameters) {
*program_shape.add_parameters() = shape; *program_shape.add_parameters() = shape;
} }
*program_shape.mutable_result() = result; *program_shape.mutable_result() = std::move(result);
return program_shape; return program_shape;
} }

View File

@ -829,6 +829,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
const int count = GetParam(); const int count = GetParam();
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
std::vector<float> values; std::vector<float> values;
values.reserve(count);
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
values.push_back(i / static_cast<float>(count)); values.push_back(i / static_cast<float>(count));
} }
@ -836,6 +837,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f)); auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f));
std::vector<float> expected; std::vector<float> expected;
expected.reserve(values.size());
for (float value : values) { for (float value : values) {
expected.push_back(value * value); expected.push_back(value * value);
} }

View File

@ -179,7 +179,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
VLOG(1) << "expected: " << LiteralUtil::ToString(*expected_literal); VLOG(1) << "expected: " << LiteralUtil::ToString(*expected_literal);
VLOG(1) << "actual: " << LiteralUtil::ToString(*actual); VLOG(1) << "actual: " << LiteralUtil::ToString(*actual);
EXPECT_EQ(expected, actual->u8s()); EXPECT_EQ(expected, actual->u8s_string());
} }
void ClientLibraryTestBase::ComputeAndCompareTuple( void ClientLibraryTestBase::ComputeAndCompareTuple(

View File

@ -442,6 +442,39 @@ XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) {
ComputeAndCompareR1<int32>(&builder, expected, {}); ComputeAndCompareR1<int32>(&builder, expected, {});
} }
// Concatenates a 9x17x1 and a 9x17x256 rank-3 array along dimension 2 and
// compares the result against a reference built on the host.
XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) {
  ComputationBuilder builder(client_, TestName());

  Array3D<float> arr0(9, 17, 1);
  arr0.Fill(1);

  Array3D<float> arr1(9, 17, 256);
  arr1.Fill(2);

  // Reference result: arr0 followed by arr1 along the last dimension.
  Array3D<float> expected(9, 17, arr0.n3() + arr1.n3());
  for (int64 i = 0; i < expected.n1(); ++i) {
    for (int64 j = 0; j < expected.n2(); ++j) {
      int64 kk = 0;
      // Iterate over pointers: a braced list of the arrays themselves would
      // copy both inputs on every (i, j) iteration.
      for (const Array3D<float>* arr : {&arr0, &arr1}) {
        for (int64 k = 0; k < arr->n3(); ++k, ++kk) {
          expected(i, j, kk) = (*arr)(i, j, k);
        }
      }
    }
  }

  ComputationDataHandle h0;
  auto p0 = CreateR3Parameter<float>(arr0, /*parameter_number=*/0, "p0",
                                     &builder, &h0);
  ComputationDataHandle h1;
  auto p1 = CreateR3Parameter<float>(arr1, /*parameter_number=*/1, "p1",
                                     &builder, &h1);

  auto concatenated = builder.ConcatInDim({h0, h1}, 2);

  ComputeAndCompareR3<float>(&builder, expected, {p0.get(), p1.get()});
}
// Describes a binary rank-2 concatenation test. // Describes a binary rank-2 concatenation test.
struct R2BinarySpec { struct R2BinarySpec {
int64 lhs_dim0; int64 lhs_dim0;

View File

@ -262,7 +262,7 @@ class NearComparator {
max_abs_err_ = 0.0; max_abs_err_ = 0.0;
*miscompares_.mutable_shape() = *miscompares_.mutable_shape() =
ShapeUtil::ChangeElementType(actual.shape(), PRED); ShapeUtil::ChangeElementType(actual.shape(), PRED);
miscompares_.mutable_preds()->Resize( miscompares_.mutable_preds()->resize(
ShapeUtil::ElementsIn(miscompares_.shape()), false); ShapeUtil::ElementsIn(miscompares_.shape()), false);
multi_index_.resize(expected.shape().dimensions_size(), 0); multi_index_.resize(expected.shape().dimensions_size(), 0);
@ -389,7 +389,7 @@ class NearComparator {
tensorflow::strings::Printf("tempfile-%s-%llx-%s", Hostname().c_str(), tensorflow::strings::Printf("tempfile-%s-%llx-%s", Hostname().c_str(),
now_usec, name.c_str())); now_usec, name.c_str()));
TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(),
filename, literal)); filename, literal.ToProto()));
LOG(ERROR) << "wrote to " << name << " file: " << filename; LOG(ERROR) << "wrote to " << name << " file: " << filename;
} }

View File

@ -83,9 +83,10 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
LOG(INFO) << "results: [" << tensorflow::str_util::Join(results, ", ") << "]"; LOG(INFO) << "results: [" << tensorflow::str_util::Join(results, ", ") << "]";
EXPECT_EQ(3, results.size()); EXPECT_EQ(3, results.size());
for (const string& result : results) { for (const string& result : results) {
Literal literal; LiteralProto literal_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result, TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result,
&literal)); &literal_proto));
Literal literal(literal_proto);
if (result.find("expected") != string::npos) { if (result.find("expected") != string::npos) {
EXPECT_EQ("2", LiteralUtil::ToString(literal)); EXPECT_EQ("2", LiteralUtil::ToString(literal));
} else if (result.find("actual") != string::npos) { } else if (result.find("actual") != string::npos) {

View File

@ -47,6 +47,7 @@ TEST_F(LogTest, LogTenValues) {
builder.Log(x); builder.Log(x);
std::vector<float> expected; std::vector<float> expected;
expected.reserve(input.size());
for (float f : input) { for (float f : input) {
expected.push_back(std::log(f)); expected.push_back(std::log(f));
} }

View File

@ -246,6 +246,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
} }
std::vector<GlobalData*> param_data; std::vector<GlobalData*> param_data;
param_data.reserve(param_data_owner.size());
for (const std::unique_ptr<GlobalData>& data : param_data_owner) { for (const std::unique_ptr<GlobalData>& data : param_data_owner) {
param_data.push_back(data.get()); param_data.push_back(data.get());
} }

View File

@ -37,6 +37,7 @@ class SliceTest : public ClientLibraryTestBase {
template <typename NativeT> template <typename NativeT>
void RunSliceTenToTwo() { void RunSliceTenToTwo() {
std::vector<NativeT> constant; std::vector<NativeT> constant;
constant.reserve(10);
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
constant.push_back(static_cast<NativeT>(i)); constant.push_back(static_cast<NativeT>(i));
} }

View File

@ -64,6 +64,7 @@ TEST_F(VecOpsSimpleTest, ExpManyValues) {
for (int count : {63, 64, 65, 127, 128, 129, 17 * 4096}) { for (int count : {63, 64, 65, 127, 128, 129, 17 * 4096}) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
std::vector<float> exponents; std::vector<float> exponents;
exponents.reserve(count);
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
exponents.push_back(i / static_cast<float>(count)); exponents.push_back(i / static_cast<float>(count));
} }
@ -71,6 +72,7 @@ TEST_F(VecOpsSimpleTest, ExpManyValues) {
auto exp = builder.Exp(x); auto exp = builder.Exp(x);
std::vector<float> expected; std::vector<float> expected;
expected.reserve(exponents.size());
for (float exponent : exponents) { for (float exponent : exponents) {
expected.push_back(std::exp(exponent)); expected.push_back(std::exp(exponent));
} }

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <memory> #include <memory>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ #ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
#define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ #define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"

View File

@ -81,6 +81,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
client->GetComputationShape(computation).ConsumeValueOrDie(); client->GetComputationShape(computation).ConsumeValueOrDie();
std::vector<const Shape*> layouts; std::vector<const Shape*> layouts;
layouts.reserve(program_shape->parameters_size());
for (int i = 0; i < program_shape->parameters_size(); ++i) { for (int i = 0; i < program_shape->parameters_size(); ++i) {
layouts.push_back(&program_shape->parameters(i)); layouts.push_back(&program_shape->parameters(i));
} }

View File

@ -56,6 +56,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool compile) {
client->GetComputationShape(computation).ConsumeValueOrDie(); client->GetComputationShape(computation).ConsumeValueOrDie();
std::vector<const Shape*> layouts; std::vector<const Shape*> layouts;
layouts.reserve(program_shape->parameters_size());
for (int i = 0; i < program_shape->parameters_size(); ++i) { for (int i = 0; i < program_shape->parameters_size(); ++i) {
layouts.push_back(&program_shape->parameters(i)); layouts.push_back(&program_shape->parameters(i));
} }

View File

@ -66,7 +66,8 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(
if (use_fake_data) { if (use_fake_data) {
arguments = MakeFakeArgumentsOrDie(computation, client); arguments = MakeFakeArgumentsOrDie(computation, client);
} else { // use recorded data if available } else { // use recorded data if available
for (const Literal& literal : module.arguments()) { for (const auto& proto : module.arguments()) {
Literal literal(proto);
TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data, TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
client->TransferToServer(literal)); client->TransferToServer(literal));
arguments.push_back(std::move(data)); arguments.push_back(std::move(data));
@ -74,6 +75,7 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(
} }
std::vector<GlobalData*> execute_arguments; std::vector<GlobalData*> execute_arguments;
execute_arguments.reserve(arguments.size());
for (auto& argument : arguments) { for (auto& argument : arguments) {
execute_arguments.push_back(argument.get()); execute_arguments.push_back(argument.get());
} }
@ -100,7 +102,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool use_fake_data) {
if (module.has_result()) { if (module.has_result()) {
fprintf(stdout, "was %s:%s\n", fprintf(stdout, "was %s:%s\n",
ShapeUtil::HumanString(module.result().shape()).c_str(), ShapeUtil::HumanString(module.result().shape()).c_str(),
LiteralUtil::ToString(module.result()).c_str()); LiteralUtil::ToString(Literal(module.result())).c_str());
} }
} }
} }

View File

@ -37,9 +37,10 @@ int main(int argc, char **argv) {
<< " <path-to-serialized-literal-proto>"; << " <path-to-serialized-literal-proto>";
} }
xla::Literal literal; xla::LiteralProto literal_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1], TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1],
&literal)); &literal_proto));
LOG(INFO) << "literal: " << literal.ShortDebugString(); xla::Literal literal(literal_proto);
LOG(INFO) << "literal: " << literal_proto.ShortDebugString();
fprintf(stderr, "%s\n", xla::LiteralUtil::ToString(literal).c_str()); fprintf(stderr, "%s\n", xla::LiteralUtil::ToString(literal).c_str());
} }

View File

@ -92,11 +92,11 @@ message TransferToClientRequest {
} }
message TransferToClientResponse { message TransferToClientResponse {
Literal literal = 1; LiteralProto literal = 1;
} }
message TransferToServerRequest { message TransferToServerRequest {
Literal literal = 1; LiteralProto literal = 1;
DeviceHandle device_handle = 2; DeviceHandle device_handle = 2;
} }
@ -105,7 +105,7 @@ message TransferToServerResponse {
} }
message TransferToInfeedRequest { message TransferToInfeedRequest {
Literal literal = 1; LiteralProto literal = 1;
int64 replica_id = 2; int64 replica_id = 2;
DeviceHandle device_handle = 3; DeviceHandle device_handle = 3;
} }
@ -123,7 +123,7 @@ message TransferFromOutfeedRequest {
} }
message TransferFromOutfeedResponse { message TransferFromOutfeedResponse {
Literal literal = 1; LiteralProto literal = 1;
} }
message ResetDeviceRequest { message ResetDeviceRequest {

View File

@ -275,7 +275,7 @@ message ChannelHandle {
// //
// Transfers to/from the client are encoded in literal form, and the structure // Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape. // of the repeated fields is implied by the shape.
message Literal { message LiteralProto {
Shape shape = 1; Shape shape = 1;
repeated bool preds = 2; repeated bool preds = 2;
bytes u8s = 3; bytes u8s = 3;
@ -285,7 +285,7 @@ message Literal {
repeated uint64 u64s = 7; repeated uint64 u64s = 7;
repeated float f32s = 8; repeated float f32s = 8;
repeated double f64s = 9; repeated double f64s = 9;
repeated Literal tuple_literals = 10; repeated LiteralProto tuple_literals = 10;
bytes f16s = 11; // Note: the F16s are encoded in little endian byte order bytes f16s = 11; // Note: the F16s are encoded in little endian byte order
} }
@ -337,7 +337,7 @@ message Window {
// field in OpRequest. // field in OpRequest.
message ConstantRequest { message ConstantRequest {
Literal literal = 2; LiteralProto literal = 2;
} }
message GetTupleElementRequest { message GetTupleElementRequest {

View File

@ -85,6 +85,7 @@ cc_library(
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels",
"//tensorflow/contrib/layers:sparse_feature_cross_op_kernel", "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
"//tensorflow/contrib/nccl:nccl_kernels", "//tensorflow/contrib/nccl:nccl_kernels",
"//tensorflow/contrib/seq2seq:beam_search_ops_kernels",
"//tensorflow/contrib/tensor_forest:tensor_forest_kernels", "//tensorflow/contrib/tensor_forest:tensor_forest_kernels",
"//tensorflow/contrib/text:all_kernels", "//tensorflow/contrib/text:all_kernels",
], ],
@ -100,6 +101,7 @@ cc_library(
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib",
"//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib", "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
"//tensorflow/contrib/nccl:nccl_ops_op_lib", "//tensorflow/contrib/nccl:nccl_ops_op_lib",
"//tensorflow/contrib/seq2seq:beam_search_ops_op_lib",
"//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib", "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib",
"//tensorflow/contrib/text:all_ops", "//tensorflow/contrib/text:all_ops",
], ],

View File

@ -347,6 +347,7 @@ class BatchResource : public ResourceBase {
// Concatenate the tasks ith input tensors into a big output tensor. // Concatenate the tasks ith input tensors into a big output tensor.
std::vector<Tensor> to_concatenate; std::vector<Tensor> to_concatenate;
to_concatenate.reserve(batch->num_tasks());
for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) { for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
to_concatenate.push_back(batch->task(task_idx).inputs.at(i)); to_concatenate.push_back(batch->task(task_idx).inputs.at(i));
} }

View File

@ -139,6 +139,7 @@ TEST(SharedBatchSchedulerTest, ObeyBatchSizeConstraint) {
&callback_data](std::unique_ptr<Batch<FakeTask>> batch) { &callback_data](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed()); ASSERT_TRUE(batch->IsClosed());
std::vector<size_t> batch_data; std::vector<size_t> batch_data;
batch_data.reserve(batch->num_tasks());
for (int i = 0; i < batch->num_tasks(); ++i) { for (int i = 0; i < batch->num_tasks(); ++i) {
batch_data.push_back(batch->mutable_task(i)->size()); batch_data.push_back(batch->mutable_task(i)->size());
} }

View File

@ -295,6 +295,7 @@ void ExpectVecsEquiv(const std::vector<float>& vec1,
std::vector<float> GetWeightsByIndex(const std::vector<float>& weights, std::vector<float> GetWeightsByIndex(const std::vector<float>& weights,
const std::vector<int>& indices) { const std::vector<int>& indices) {
std::vector<float> res; std::vector<float> res;
res.reserve(indices.size());
for (const int index : indices) { for (const int index : indices) {
res.push_back(weights[index]); res.push_back(weights[index]);
} }

View File

@ -236,6 +236,9 @@ add_python_module("tensorflow/tensorboard")
add_python_module("tensorflow/tensorboard/backend") add_python_module("tensorflow/tensorboard/backend")
add_python_module("tensorflow/tensorboard/backend/event_processing") add_python_module("tensorflow/tensorboard/backend/event_processing")
add_python_module("tensorflow/tensorboard/plugins") add_python_module("tensorflow/tensorboard/plugins")
add_python_module("tensorflow/tensorboard/plugins/audio")
add_python_module("tensorflow/tensorboard/plugins/distributions")
add_python_module("tensorflow/tensorboard/plugins/graphs")
add_python_module("tensorflow/tensorboard/plugins/histograms") add_python_module("tensorflow/tensorboard/plugins/histograms")
add_python_module("tensorflow/tensorboard/plugins/images") add_python_module("tensorflow/tensorboard/plugins/images")
add_python_module("tensorflow/tensorboard/plugins/projector") add_python_module("tensorflow/tensorboard/plugins/projector")
@ -536,6 +539,7 @@ set(tf_python_op_gen_main_srcs
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc"
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc"
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h"
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.h"
) )
add_library(tf_python_op_gen_main OBJECT ${tf_python_op_gen_main_srcs}) add_library(tf_python_op_gen_main OBJECT ${tf_python_op_gen_main_srcs})

View File

@ -209,10 +209,11 @@ if (tensorflow_BUILD_PYTHON_TESTS)
# Broken TensorBoard tests due to different paths in windows # Broken TensorBoard tests due to different paths in windows
"${tensorflow_source_dir}/tensorflow/tensorboard/backend/application_test.py" "${tensorflow_source_dir}/tensorflow/tensorboard/backend/application_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/lib/python/http_util_test.py" "${tensorflow_source_dir}/tensorflow/tensorboard/lib/python/http_util_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/audio/audio_plugin_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/images/images_plugin_test.py"
# Broken tensorboard test due to cmake issues. # Broken tensorboard test due to cmake issues.
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/debugger/plugin_test.py" "${tensorflow_source_dir}/tensorflow/tensorboard/plugins/debugger/plugin_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/images/images_plugin_test.py"
# tensor_forest tests (also note that we exclude the hybrid tests for now) # tensor_forest tests (also note that we exclude the hybrid tests for now)
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order. "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order.
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order. "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order.

View File

@ -150,7 +150,8 @@ class MapDatasetTest(test.TestCase):
results.append(sess.run(get_next)) results.append(sess.run(get_next))
except errors.OutOfRangeError: except errors.OutOfRangeError:
return return
threads = [self.checkedThread(target=iterator_thread) for _ in range(8)] threads = [self.checkedThread(target=iterator_thread)
for _ in range(64)]
for t in threads: for t in threads:
t.start() t.start()
for t in threads: for t in threads:

View File

@ -375,8 +375,8 @@ class NearestNeighborsOp : public OpKernel {
const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm, const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
const Eigen::Ref<const MatrixXfRowMajor>& centers, const Eigen::Ref<const MatrixXfRowMajor>& centers,
const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm, const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices, const Eigen::Ref<MatrixXi64RowMajor>& nearest_center_indices,
Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) { const Eigen::Ref<MatrixXfRowMajor>& nearest_center_distances) {
CHECK_LE(k, centers.rows()); CHECK_LE(k, centers.rows());
if (centers.rows() <= kNearestNeighborsCentersMaxBlockSize) { if (centers.rows() <= kNearestNeighborsCentersMaxBlockSize) {
FindKNearestCentersOneBlock(k, points, points_half_squared_norm, centers, FindKNearestCentersOneBlock(k, points, points_half_squared_norm, centers,

View File

@ -164,9 +164,10 @@ class KMeans(object):
with ops.colocate_with(inp): with ops.colocate_with(inp):
# Computes Euclidean distance. Note the first and third terms are # Computes Euclidean distance. Note the first and third terms are
# broadcast additions. # broadcast additions.
squared_distance = (math_ops.reduce_sum( squared_distance = (
math_ops.square(inp), 1, keep_dims=True) - 2 * math_ops.matmul( math_ops.reduce_sum(math_ops.square(inp), 1, keep_dims=True) -
inp, clusters, transpose_b=True) + array_ops.transpose( 2 * math_ops.matmul(inp, clusters, transpose_b=True) +
array_ops.transpose(
math_ops.reduce_sum( math_ops.reduce_sum(
math_ops.square(clusters), 1, keep_dims=True))) math_ops.square(clusters), 1, keep_dims=True)))
output.append(squared_distance) output.append(squared_distance)
@ -229,12 +230,12 @@ class KMeans(object):
clusters = nn_impl.l2_normalize(clusters, dim=1) clusters = nn_impl.l2_normalize(clusters, dim=1)
for inp, score in zip(inputs, scores): for inp, score in zip(inputs, scores):
with ops.colocate_with(inp): with ops.colocate_with(inp):
(indices, (indices, distances) = gen_clustering_ops.nearest_neighbors(
distances) = gen_clustering_ops.nearest_neighbors(inp, clusters, 1) inp, clusters, 1)
if self._distance_metric == COSINE_DISTANCE: if self._distance_metric == COSINE_DISTANCE:
distances *= 0.5 distances *= 0.5
output.append( output.append((score, array_ops.squeeze(distances),
(score, array_ops.squeeze(distances), array_ops.squeeze(indices))) array_ops.squeeze(indices)))
return zip(*output) return zip(*output)
def _init_clusters_random(self): def _init_clusters_random(self):
@ -265,9 +266,7 @@ class KMeans(object):
(not self._use_mini_batch or (not self._use_mini_batch or
self._mini_batch_steps_per_iteration > 1)) self._mini_batch_steps_per_iteration > 1))
def _initialize_clusters(self, def _initialize_clusters(self, cluster_centers, cluster_centers_initialized,
cluster_centers,
cluster_centers_initialized,
cluster_centers_updated): cluster_centers_updated):
"""Returns an op to initialize the cluster centers.""" """Returns an op to initialize the cluster centers."""
@ -294,21 +293,19 @@ class KMeans(object):
with ops.colocate_with(cluster_centers_initialized): with ops.colocate_with(cluster_centers_initialized):
initialized = control_flow_ops.with_dependencies( initialized = control_flow_ops.with_dependencies(
[clusters_init], [clusters_init], array_ops.identity(cluster_centers_initialized))
array_ops.identity(cluster_centers_initialized))
with ops.colocate_with(cluster_centers): with ops.colocate_with(cluster_centers):
assign_centers = state_ops.assign(cluster_centers, clusters_init, assign_centers = state_ops.assign(
validate_shape=False) cluster_centers, clusters_init, validate_shape=False)
if cluster_centers_updated != cluster_centers: if cluster_centers_updated != cluster_centers:
assign_centers = control_flow_ops.group( assign_centers = control_flow_ops.group(assign_centers,
assign_centers, state_ops.assign(
state_ops.assign(cluster_centers_updated, clusters_init, cluster_centers_updated,
clusters_init,
validate_shape=False)) validate_shape=False))
assign_centers = control_flow_ops.with_dependencies( assign_centers = control_flow_ops.with_dependencies(
[assign_centers], [assign_centers], state_ops.assign(cluster_centers_initialized, True))
state_ops.assign(cluster_centers_initialized, True)) return control_flow_ops.cond(initialized, control_flow_ops.no_op,
return control_flow_ops.cond(initialized,
control_flow_ops.no_op,
lambda: assign_centers).op lambda: assign_centers).op
def _create_variables(self): def _create_variables(self):
@ -327,19 +324,16 @@ class KMeans(object):
cluster_centers_updated back to cluster_centers. cluster_centers_updated back to cluster_centers.
""" """
init_value = array_ops.constant([], dtype=dtypes.float32) init_value = array_ops.constant([], dtype=dtypes.float32)
cluster_centers = variable_scope.variable(init_value, cluster_centers = variable_scope.variable(
name='clusters', init_value, name='clusters', validate_shape=False)
validate_shape=False) cluster_centers_initialized = variable_scope.variable(
cluster_centers_initialized = variable_scope.variable(False, False, dtype=dtypes.bool, name='initialized')
dtype=dtypes.bool,
name='initialized')
if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1: if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
# Copy of cluster centers actively updated each step according to # Copy of cluster centers actively updated each step according to
# mini-batch update rule. # mini-batch update rule.
cluster_centers_updated = variable_scope.variable(init_value, cluster_centers_updated = variable_scope.variable(
name='clusters_updated', init_value, name='clusters_updated', validate_shape=False)
validate_shape=False)
# How many steps till we copy the updated clusters to cluster_centers. # How many steps till we copy the updated clusters to cluster_centers.
update_in_steps = variable_scope.variable( update_in_steps = variable_scope.variable(
self._mini_batch_steps_per_iteration, self._mini_batch_steps_per_iteration,
@ -347,20 +341,15 @@ class KMeans(object):
name='update_in_steps') name='update_in_steps')
# Count of points assigned to cluster_centers_updated. # Count of points assigned to cluster_centers_updated.
cluster_counts = variable_scope.variable( cluster_counts = variable_scope.variable(
array_ops.zeros([self._num_clusters], array_ops.zeros([self._num_clusters], dtype=dtypes.int64))
dtype=dtypes.int64))
else: else:
cluster_centers_updated = cluster_centers cluster_centers_updated = cluster_centers
update_in_steps = None update_in_steps = None
cluster_counts = (variable_scope.variable(array_ops.ones( cluster_counts = (variable_scope.variable(
[self._num_clusters], array_ops.ones([self._num_clusters], dtype=dtypes.int64))
dtype=dtypes.int64))
if self._use_mini_batch else None) if self._use_mini_batch else None)
return (cluster_centers, return (cluster_centers, cluster_centers_initialized, cluster_counts,
cluster_centers_initialized, cluster_centers_updated, update_in_steps)
cluster_counts,
cluster_centers_updated,
update_in_steps)
@classmethod @classmethod
def _l2_normalize_data(cls, inputs): def _l2_normalize_data(cls, inputs):
@ -391,11 +380,8 @@ class KMeans(object):
""" """
# Implementation of kmeans. # Implementation of kmeans.
inputs = self._inputs inputs = self._inputs
(cluster_centers_var, (cluster_centers_var, cluster_centers_initialized, total_counts,
cluster_centers_initialized, cluster_centers_updated, update_in_steps) = self._create_variables()
total_counts,
cluster_centers_updated,
update_in_steps) = self._create_variables()
init_op = self._initialize_clusters(cluster_centers_var, init_op = self._initialize_clusters(cluster_centers_var,
cluster_centers_initialized, cluster_centers_initialized,
cluster_centers_updated) cluster_centers_updated)
@ -409,8 +395,7 @@ class KMeans(object):
all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers) all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers)
if self._use_mini_batch: if self._use_mini_batch:
sync_updates_op = self._mini_batch_sync_updates_op( sync_updates_op = self._mini_batch_sync_updates_op(
update_in_steps, update_in_steps, cluster_centers_var, cluster_centers_updated,
cluster_centers_var, cluster_centers_updated,
total_counts) total_counts)
assert sync_updates_op is not None assert sync_updates_op is not None
with ops.control_dependencies([sync_updates_op]): with ops.control_dependencies([sync_updates_op]):
@ -421,15 +406,15 @@ class KMeans(object):
training_op = self._full_batch_training_op(inputs, cluster_idx, training_op = self._full_batch_training_op(inputs, cluster_idx,
cluster_centers_var) cluster_centers_var)
return (all_scores, cluster_idx, scores, return (all_scores, cluster_idx, scores, cluster_centers_initialized,
cluster_centers_initialized, init_op, training_op) init_op, training_op)
def _mini_batch_sync_updates_op(self, update_in_steps, def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var,
cluster_centers_var, cluster_centers_updated, cluster_centers_updated, total_counts):
total_counts):
if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1: if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
assert update_in_steps is not None assert update_in_steps is not None
with ops.colocate_with(update_in_steps): with ops.colocate_with(update_in_steps):
def _f(): def _f():
# Note that there is a race condition here, so we do a best effort # Note that there is a race condition here, so we do a best effort
# updates here. We reset update_in_steps first so that other workers # updates here. We reset update_in_steps first so that other workers
@ -437,33 +422,36 @@ class KMeans(object):
# before resetting total_counts to avoid large updates to # before resetting total_counts to avoid large updates to
# cluster_centers_updated based on partially updated # cluster_centers_updated based on partially updated
# cluster_center_vars. # cluster_center_vars.
with ops.control_dependencies([state_ops.assign( with ops.control_dependencies([
update_in_steps, state_ops.assign(update_in_steps,
self._mini_batch_steps_per_iteration - 1)]): self._mini_batch_steps_per_iteration - 1)
with ops.colocate_with(cluster_centers_updated): ]):
with ops.colocate_with(
cluster_centers_updated, ignore_existing=True):
if self._distance_metric == COSINE_DISTANCE: if self._distance_metric == COSINE_DISTANCE:
cluster_centers = nn_impl.l2_normalize(cluster_centers_updated, cluster_centers = nn_impl.l2_normalize(
dim=1) cluster_centers_updated, dim=1)
else: else:
cluster_centers = cluster_centers_updated cluster_centers = cluster_centers_updated
with ops.colocate_with(cluster_centers_var): with ops.colocate_with(cluster_centers_var):
with ops.control_dependencies([state_ops.assign( with ops.control_dependencies(
cluster_centers_var, [state_ops.assign(cluster_centers_var, cluster_centers)]):
cluster_centers)]): with ops.colocate_with(
with ops.colocate_with(cluster_centers_var): cluster_centers_var, ignore_existing=True):
with ops.control_dependencies([ with ops.control_dependencies([
state_ops.assign(total_counts, state_ops.assign(total_counts,
array_ops.zeros_like(total_counts))]): array_ops.zeros_like(total_counts))
]):
return array_ops.identity(update_in_steps) return array_ops.identity(update_in_steps)
return control_flow_ops.cond( return control_flow_ops.cond(
update_in_steps <= 0, update_in_steps <= 0, _f,
_f,
lambda: state_ops.assign_sub(update_in_steps, 1)) lambda: state_ops.assign_sub(update_in_steps, 1))
else: else:
return control_flow_ops.no_op() return control_flow_ops.no_op()
def _mini_batch_training_op(self, inputs, cluster_idx_list, def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers,
cluster_centers, total_counts): total_counts):
"""Creates an op for training for mini batch case. """Creates an op for training for mini batch case.
Args: Args:
@ -487,17 +475,15 @@ class KMeans(object):
unique_ids, unique_idx = array_ops.unique(cluster_idx) unique_ids, unique_idx = array_ops.unique(cluster_idx)
num_unique_cluster_idx = array_ops.size(unique_ids) num_unique_cluster_idx = array_ops.size(unique_ids)
# Fetch the old values of counts and cluster_centers. # Fetch the old values of counts and cluster_centers.
with ops.colocate_with(total_counts): with ops.colocate_with(total_counts, ignore_existing=True):
old_counts = array_ops.gather(total_counts, unique_ids) old_counts = array_ops.gather(total_counts, unique_ids)
# TODO(agarwal): This colocation seems to run into problems. Fix it. # TODO(agarwal): This colocation seems to run into problems. Fix it.
# with ops.colocate_with(cluster_centers): with ops.colocate_with(cluster_centers, ignore_existing=True):
old_cluster_centers = array_ops.gather(cluster_centers, unique_ids) old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
# Locally aggregate the increment to counts. # Locally aggregate the increment to counts.
count_updates = math_ops.unsorted_segment_sum( count_updates = math_ops.unsorted_segment_sum(
array_ops.ones_like( array_ops.ones_like(unique_idx, dtype=total_counts.dtype),
unique_idx, dtype=total_counts.dtype), unique_idx, num_unique_cluster_idx)
unique_idx,
num_unique_cluster_idx)
# Locally compute the sum of inputs mapped to each id. # Locally compute the sum of inputs mapped to each id.
# For a cluster with old cluster value x, old count n, and with data # For a cluster with old cluster value x, old count n, and with data
# d_1,...d_k newly assigned to it, we recompute the new value as # d_1,...d_k newly assigned to it, we recompute the new value as
@ -507,13 +493,12 @@ class KMeans(object):
inp, unique_idx, num_unique_cluster_idx) inp, unique_idx, num_unique_cluster_idx)
# Shape to enable broadcasting count_updates and learning_rate to inp. # Shape to enable broadcasting count_updates and learning_rate to inp.
# It extends the shape with 1's to match the rank of inp. # It extends the shape with 1's to match the rank of inp.
broadcast_shape = array_ops.concat( broadcast_shape = array_ops.concat([
[ array_ops.reshape(num_unique_cluster_idx, [1]),
array_ops.reshape(num_unique_cluster_idx, [1]), array_ops.ones( array_ops.ones(
array_ops.reshape(array_ops.rank(inp) - 1, [1]), array_ops.reshape(array_ops.rank(inp) - 1, [1]),
dtype=dtypes.int32) dtype=dtypes.int32)
], ], 0)
0)
# Subtract k * x, see comment above. # Subtract k * x, see comment above.
cluster_center_updates -= math_ops.cast( cluster_center_updates -= math_ops.cast(
array_ops.reshape(count_updates, broadcast_shape), array_ops.reshape(count_updates, broadcast_shape),
@ -524,14 +509,10 @@ class KMeans(object):
# scale by 1 / (n + k), see comment above. # scale by 1 / (n + k), see comment above.
cluster_center_updates *= learning_rate cluster_center_updates *= learning_rate
# Apply the updates. # Apply the updates.
update_counts = state_ops.scatter_add( update_counts = state_ops.scatter_add(total_counts, unique_ids,
total_counts,
unique_ids,
count_updates) count_updates)
update_cluster_centers = state_ops.scatter_add( update_cluster_centers = state_ops.scatter_add(
cluster_centers, cluster_centers, unique_ids, cluster_center_updates)
unique_ids,
cluster_center_updates)
update_ops.extend([update_counts, update_cluster_centers]) update_ops.extend([update_counts, update_cluster_centers])
return control_flow_ops.group(*update_ops) return control_flow_ops.group(*update_ops)
@ -552,7 +533,7 @@ class KMeans(object):
cluster_counts = [] cluster_counts = []
epsilon = constant_op.constant(1e-6, dtype=inputs[0].dtype) epsilon = constant_op.constant(1e-6, dtype=inputs[0].dtype)
for inp, cluster_idx in zip(inputs, cluster_idx_list): for inp, cluster_idx in zip(inputs, cluster_idx_list):
with ops.colocate_with(inp): with ops.colocate_with(inp, ignore_existing=True):
cluster_sums.append( cluster_sums.append(
math_ops.unsorted_segment_sum(inp, cluster_idx, self._num_clusters)) math_ops.unsorted_segment_sum(inp, cluster_idx, self._num_clusters))
cluster_counts.append( cluster_counts.append(
@ -561,7 +542,7 @@ class KMeans(object):
array_ops.ones( array_ops.ones(
array_ops.reshape(array_ops.shape(inp)[0], [-1])), array_ops.reshape(array_ops.shape(inp)[0], [-1])),
[-1, 1]), cluster_idx, self._num_clusters)) [-1, 1]), cluster_idx, self._num_clusters))
with ops.colocate_with(cluster_centers): with ops.colocate_with(cluster_centers, ignore_existing=True):
new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast( new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast(
math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon) math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon)
if self._clusters_l2_normalized(): if self._clusters_l2_normalized():

View File

@ -94,6 +94,7 @@ TEST(FfmpegLibTest, TestRoundTripGeneratedWav) {
} }
std::vector<float> sine_wave; std::vector<float> sine_wave;
sine_wave.reserve(20000);
for (int i = 0; i < 20000; ++i) { for (int i = 0; i < 20000; ++i) {
sine_wave.push_back(std::sin(6.28 * 440.0 * i / 20000.0)); sine_wave.push_back(std::sin(6.28 * 440.0 * i / 20000.0));
} }

View File

@ -494,6 +494,7 @@ class SparseFeatureCrossOp : public OpKernel {
ExtractFeatureData(indices_list_in, batch_size, &feature_counts, ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
&feature_start_indices); &feature_start_indices);
columns.reserve(values_list_in.size());
for (int i = 0; i < values_list_in.size(); ++i) { for (int i = 0; i < values_list_in.size(); ++i) {
columns.emplace_back(new SparseTensorColumn<InternalType>( columns.emplace_back(new SparseTensorColumn<InternalType>(
values_list_in[i], std::move(feature_counts[i]), values_list_in[i], std::move(feature_counts[i]),

View File

@ -308,6 +308,7 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_rea
from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat
from tensorflow.contrib.learn.python.learn.estimators.head import binary_svm_head from tensorflow.contrib.learn.python.learn.estimators.head import binary_svm_head
from tensorflow.contrib.learn.python.learn.estimators.head import Head from tensorflow.contrib.learn.python.learn.estimators.head import Head
from tensorflow.contrib.learn.python.learn.estimators.head import loss_only_head
from tensorflow.contrib.learn.python.learn.estimators.head import multi_class_head from tensorflow.contrib.learn.python.learn.estimators.head import multi_class_head
from tensorflow.contrib.learn.python.learn.estimators.head import multi_head from tensorflow.contrib.learn.python.learn.estimators.head import multi_head
from tensorflow.contrib.learn.python.learn.estimators.head import multi_label_head from tensorflow.contrib.learn.python.learn.estimators.head import multi_label_head

View File

@ -429,6 +429,23 @@ def multi_label_head(n_classes,
loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None) loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None)
def loss_only_head(loss_fn, head_name=None):
  """Creates a Head that contains only loss terms.

  Loss only head holds additional loss terms to be added to other heads and
  usually represents additional regularization terms in the objective function.

  Args:
    loss_fn: a function that takes no argument and returns a list of
        scalar tensors (summed into a single loss) or a single scalar loss.
    head_name: a name for the head.

  Returns:
    An instance of `Head` to hold the additional losses.
  """
  return _LossOnlyHead(loss_fn, head_name=head_name)
def multi_head(heads, loss_weights=None): def multi_head(heads, loss_weights=None):
"""Creates a MultiHead stemming from same logits/hidden layer. """Creates a MultiHead stemming from same logits/hidden layer.
@ -1406,6 +1423,80 @@ class _MultiLabelHead(_SingleHead):
return metrics return metrics
class _LossOnlyHead(Head):
  """`Head` implementation for additional loss terms.

  This class only holds loss terms unrelated to any other heads (labels),
  e.g. regularization.

  Common usage:
  This is often combined with other heads in a multi head setup.
    ```python
    head = multi_head([
        head1, head2, loss_only_head(regularizer, head_name='regularizer')])
    ```
  """

  def __init__(self, loss_fn, head_name=None):
    """Stores the loss callable and the head name.

    Args:
      loss_fn: a function that takes no argument and returns a scalar loss or
        a list/tuple of scalar losses.
      head_name: a name for the head; defaults to "loss_only_head".
    """
    self._loss_fn = loss_fn
    self.head_name = head_name or "loss_only_head"

  @property
  def logits_dimension(self):
    # This head consumes no logits.
    return 0

  def create_model_fn_ops(self,
                          features,
                          mode,
                          labels=None,
                          train_op_fn=None,
                          logits=None,
                          logits_input=None,
                          scope=None):
    """See `_Head.create_model_fn_ops`.

    Args:
      features: Not used.
      mode: Estimator's `ModeKeys`.
      labels: Not used.
      train_op_fn: Function that takes a scalar loss and returns an op to
          optimize with the loss. Required in TRAIN mode.
      logits: Not used.
      logits_input: Not used.
      scope: Optional scope for variable_scope. If `None`, a scope named after
          `head_name` is used.

    Returns:
      A `ModelFnOps` object with empty predictions and eval metrics. Its loss
      and train op are `None` in INFER mode.

    Raises:
      ValueError: if `mode` is not recognized, or if `train_op_fn` is `None`
          in TRAIN mode.
    """
    _check_mode_valid(mode)
    loss = None
    train_op = None
    if mode != model_fn.ModeKeys.INFER:
      with variable_scope.variable_scope(scope, default_name=self.head_name):
        loss = self._loss_fn()
        # Several loss terms are collapsed into one scalar; a single loss is
        # used as-is.
        if isinstance(loss, (list, tuple)):
          loss = math_ops.add_n(loss)
        logging_ops.scalar_summary(
            _summary_key(self.head_name, mkey.LOSS), loss)
        if mode == model_fn.ModeKeys.TRAIN:
          if train_op_fn is None:
            raise ValueError("train_op_fn can not be None in TRAIN mode")
          with ops.name_scope(None, "train_op", (loss,)):
            train_op = train_op_fn(loss)
    return model_fn.ModelFnOps(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions={},
        eval_metric_ops={})
class _MultiHead(Head): class _MultiHead(Head):
"""`Head` implementation for multi objective learning. """`Head` implementation for multi objective learning.
@ -1525,6 +1616,9 @@ class _MultiHead(Head):
if isinstance(logits, dict): if isinstance(logits, dict):
head_logits_pairs = [] head_logits_pairs = []
for head in self._heads: for head in self._heads:
if isinstance(head, _LossOnlyHead):
head_logits_pairs.append((head, None))
else:
head_logits_pairs.append((head, logits[head.head_name])) head_logits_pairs.append((head, logits[head.head_name]))
else: else:
# Split logits for each head. # Split logits for each head.
@ -1606,6 +1700,8 @@ class _MultiHead(Head):
predictions = {} predictions = {}
output_alternatives = {} output_alternatives = {}
for head, m in zip(self._heads, all_model_fn_ops): for head, m in zip(self._heads, all_model_fn_ops):
if isinstance(head, _LossOnlyHead):
continue
head_name = head.head_name head_name = head.head_name
output_alternatives[head_name] = m.output_alternatives[head_name] output_alternatives[head_name] = m.output_alternatives[head_name]
for k, v in m.predictions.items(): for k, v in m.predictions.items():

View File

@ -1638,6 +1638,21 @@ class BinarySvmHeadTest(test.TestCase):
}, model_fn_ops) }, model_fn_ops)
class LossOnlyHead(test.TestCase):
  """Tests for the head created by `head_lib.loss_only_head`."""

  def testNoPredictionsAndNoMetrics(self):
    """A loss-only head yields a loss but no predictions or metrics."""

    def _constant_loss():
      return 1

    loss_head = head_lib.loss_only_head(_constant_loss, head_name="const")
    result = loss_head.create_model_fn_ops(
        features={},
        mode=model_fn.ModeKeys.TRAIN,
        train_op_fn=head_lib.no_op_train_fn)

    self.assertDictEqual(result.predictions, {})
    self.assertDictEqual(result.eval_metric_ops, {})
    self.assertIsNotNone(result.loss)
    with session.Session() as sess:
      loss_value = sess.run(result.loss)
      self.assertEqual(1, loss_value)
class MultiHeadTest(test.TestCase): class MultiHeadTest(test.TestCase):
def testInvalidHeads(self): def testInvalidHeads(self):
@ -1672,7 +1687,8 @@ class MultiHeadTest(test.TestCase):
n_classes=3, label_name="label1", head_name="head1") n_classes=3, label_name="label1", head_name="head1")
head2 = head_lib.multi_class_head( head2 = head_lib.multi_class_head(
n_classes=4, label_name="label2", head_name="head2") n_classes=4, label_name="label2", head_name="head2")
head = head_lib.multi_head((head1, head2)) head3 = head_lib.loss_only_head(lambda: 1.0, head_name="const")
head = head_lib.multi_head((head1, head2, head3))
labels = { labels = {
"label1": (1,), "label1": (1,),
"label2": (1,) "label2": (1,)
@ -1691,7 +1707,7 @@ class MultiHeadTest(test.TestCase):
self.assertIsNone(model_fn_ops.output_alternatives) self.assertIsNone(model_fn_ops.output_alternatives)
with session.Session() as sess: with session.Session() as sess:
self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3) self.assertAlmostEqual(3.224, sess.run(model_fn_ops.loss), places=3)
def testTrain_withHeadWeights(self): def testTrain_withHeadWeights(self):
head1 = head_lib.multi_class_head( head1 = head_lib.multi_class_head(

View File

@ -871,7 +871,7 @@ def index_table_from_file(vocabulary_file=None,
``` ```
Args: Args:
vocabulary_file: The vocabulary filename. vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
num_oov_buckets: The number of out-of-vocabulary buckets. num_oov_buckets: The number of out-of-vocabulary buckets.
vocab_size: Number of the elements in the vocabulary, if known. vocab_size: Number of the elements in the vocabulary, if known.
default_value: The value to use for out-of-vocabulary feature values. default_value: The value to use for out-of-vocabulary feature values.
@ -889,8 +889,9 @@ def index_table_from_file(vocabulary_file=None,
ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
than zero. than zero.
""" """
if not vocabulary_file: if vocabulary_file is None or (
raise ValueError("vocabulary_file must be specified.") isinstance(vocabulary_file, str) and not vocabulary_file):
raise ValueError("vocabulary_file must be specified and must not be empty.")
if num_oov_buckets < 0: if num_oov_buckets < 0:
raise ValueError("num_oov_buckets must be greater or equal than 0, got %d." raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
% num_oov_buckets) % num_oov_buckets)

View File

@ -1187,6 +1187,18 @@ class IndexTableFromFile(test.TestCase):
lookup_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval()) self.assertAllEqual((1, 2, 3), ids.eval())
def test_string_index_table_from_file_tensor_filename(self):
  """The vocabulary filename may be given as a constant string tensor."""
  vocab_path = self._createVocabFile("f2i_vocab1.txt")
  with self.test_session():
    vocab_tensor = constant_op.constant(vocab_path)
    table = lookup.index_table_from_file(
        vocabulary_file=vocab_tensor, num_oov_buckets=1)
    tokens = constant_op.constant(["salad", "surgery", "tarkus"])
    ids = table.lookup(tokens)

    # Looking up before the table is initialized must fail.
    with self.assertRaises(errors_impl.OpError):
      ids.eval()
    lookup_ops.tables_initializer().run()
    self.assertAllEqual((1, 2, 3), ids.eval())
def test_int32_index_table_from_file(self): def test_int32_index_table_from_file(self):
vocabulary_file = self._createVocabFile( vocabulary_file = self._createVocabFile(
"f2i_vocab2.txt", values=("42", "1", "-1000")) "f2i_vocab2.txt", values=("42", "1", "-1000"))
@ -1245,7 +1257,13 @@ class IndexTableFromFile(test.TestCase):
860), # 3 + fingerprint("toccata") mod 300. 860), # 3 + fingerprint("toccata") mod 300.
ids.eval()) ids.eval())
def test_index_table_from_file_with_only_oov_buckets(self): def test_index_table_from_file_fails_with_empty_vocabulary_file_name(self):
self.assertRaises(
ValueError,
lookup.index_table_from_file,
vocabulary_file="")
def test_index_table_from_file_fails_with_empty_vocabulary(self):
self.assertRaises( self.assertRaises(
ValueError, ValueError,
lookup.index_table_from_file, lookup.index_table_from_file,

View File

@ -23,6 +23,7 @@ See the @{$python/contrib.metrics} guide.
@@streaming_precision @@streaming_precision
@@streaming_precision_at_thresholds @@streaming_precision_at_thresholds
@@streaming_auc @@streaming_auc
@@streaming_curve_points
@@streaming_recall_at_k @@streaming_recall_at_k
@@streaming_mean_absolute_error @@streaming_mean_absolute_error
@@streaming_mean_iou @@streaming_mean_iou
@ -76,6 +77,7 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_accuracy
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_curve_points
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives_at_thresholds from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives_at_thresholds
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positives from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positives

View File

@ -733,6 +733,102 @@ def streaming_true_negatives_at_thresholds(
return values['tn'], update_ops['tn'] return values['tn'], update_ops['tn']
def streaming_curve_points(labels=None,
                           predictions=None,
                           weights=None,
                           num_thresholds=200,
                           metrics_collections=None,
                           updates_collections=None,
                           curve='ROC',
                           name=None):
  """Computes curve (ROC or PR) values for a prespecified number of points.

  The `streaming_curve_points` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  that are used to compute the curve values. To discretize the curve, a linearly
  spaced set of thresholds is used; each threshold yields one point of the
  curve: (false positive rate, recall) for 'ROC' and (recall, precision) for
  'PR'.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the
      curve.
    metrics_collections: An optional list of collections that `points` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.

  Returns:
    points: A `Tensor` with shape [num_thresholds, 2] that contains points of
      the curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables.

  Raises:
    ValueError: If `curve` is neither 'ROC' nor 'PR', if `predictions` and
      `labels` have mismatched shapes, or if `weights` is not `None` and its
      shape doesn't match `predictions`, or if either `metrics_collections` or
      `updates_collections` are not a list or tuple.
  """
  # Reject an unknown curve before opening the variable scope, so a bad
  # argument does not consume a scope name in the graph.
  if curve not in ('ROC', 'PR'):
    raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))

  with variable_scope.variable_scope(name, 'curve_points', (labels, predictions,
                                                            weights)):
    kepsilon = 1e-7  # to account for floating point imprecisions
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds - 2)]
    # The end thresholds are pushed slightly outside [0, 1].
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    values, update_ops = _streaming_confusion_matrix_at_thresholds(
        labels=labels,
        predictions=predictions,
        thresholds=thresholds,
        weights=weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6

    def compute_points(tp, fn, tn, fp):
      """Computes the (x, y) coordinates of the curve from confusion counts."""
      rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        fp_rate = math_ops.div(fp, fp + tn + epsilon)
        return fp_rate, rec
      else:  # curve == 'PR'.
        prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
        return rec, prec

    xs, ys = compute_points(values['tp'], values['fn'], values['tn'],
                            values['fp'])
    # One row per threshold, columns are (x, y).
    points = array_ops.stack([xs, ys], axis=1)
    update_op = control_flow_ops.group(*update_ops.values())

    if metrics_collections:
      ops.add_to_collections(metrics_collections, points)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return points, update_op
def streaming_auc(predictions, labels, weights=None, num_thresholds=200, def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
metrics_collections=None, updates_collections=None, metrics_collections=None, updates_collections=None,
curve='ROC', name=None): curve='ROC', name=None):
@ -2372,6 +2468,7 @@ __all__ = [
'sparse_recall_at_top_k', 'sparse_recall_at_top_k',
'streaming_accuracy', 'streaming_accuracy',
'streaming_auc', 'streaming_auc',
'streaming_curve_points',
'streaming_false_negatives', 'streaming_false_negatives',
'streaming_false_negatives_at_thresholds', 'streaming_false_negatives_at_thresholds',
'streaming_false_positives', 'streaming_false_positives',

View File

@ -1327,6 +1327,99 @@ class StreamingRecallTest(test.TestCase):
self.assertEqual(0, recall.eval()) self.assertEqual(0, recall.eval())
class StreamingCurvePointsTest(test.TestCase):
  """Tests for `metric_ops.streaming_curve_points` (ROC and PR curves)."""

  def setUp(self):
    # Fixed seed so the random inputs are reproducible; fresh graph per test.
    np.random.seed(1)
    ops.reset_default_graph()

  def testVars(self):
    """The metric creates the four confusion-matrix local variables."""
    ones_predictions = array_ops.ones((10, 1))
    ones_labels = array_ops.ones((10, 1))
    metric_ops.streaming_curve_points(
        predictions=ones_predictions, labels=ones_labels)
    expected_variable_names = (
        'curve_points/true_positives:0',
        'curve_points/false_negatives:0',
        'curve_points/false_positives:0',
        'curve_points/true_negatives:0',
    )
    _assert_local_variables(self, expected_variable_names)

  def testMetricsCollection(self):
    """The value tensor is added to the requested metrics collection."""
    collection = '__metrics__'
    ones_labels = array_ops.ones((10, 1))
    ones_predictions = array_ops.ones((10, 1))
    curve_points, _ = metric_ops.streaming_curve_points(
        labels=ones_labels,
        predictions=ones_predictions,
        metrics_collections=[collection])
    self.assertListEqual(ops.get_collection(collection), [curve_points])

  def testUpdatesCollection(self):
    """The update op is added to the requested updates collection."""
    collection = '__updates__'
    ones_labels = array_ops.ones((10, 1))
    ones_predictions = array_ops.ones((10, 1))
    _, update = metric_ops.streaming_curve_points(
        labels=ones_labels,
        predictions=ones_predictions,
        updates_collections=[collection])
    self.assertListEqual(ops.get_collection(collection), [update])

  def _testValueTensorIsIdempotent(self, curve):
    """Evaluating `points` repeatedly must not change it for `curve`."""
    # Predictions are sampled before labels; keep this order so the seeded
    # random stream yields the same inputs.
    random_predictions = np.random.uniform(size=(10, 3))
    predictions = constant_op.constant(
        random_predictions, dtype=dtypes_lib.float32)
    # NOTE(review): labels are floats drawn from [0, 2); if the metric casts
    # them to bool, nearly every label is True -- confirm this is intended.
    random_labels = np.random.uniform(high=2, size=(10, 3))
    labels = constant_op.constant(random_labels, dtype=dtypes_lib.float32)

    curve_points, update = metric_ops.streaming_curve_points(
        labels, predictions=predictions, curve=curve)

    with self.test_session() as session:
      session.run(variables.local_variables_initializer())
      session.run(update)
      first_evaluation = curve_points.eval()
      session.run(update)
      self.assertAllClose(first_evaluation, curve_points.eval())

  def testValueTensorIsIdempotentROC(self):
    self._testValueTensorIsIdempotent(curve='ROC')

  def testValueTensorIsIdempotentPR(self):
    self._testValueTensorIsIdempotent(curve='PR')

  def _testCase(self, labels, predictions, curve, expected_points):
    """Runs one update with 3 thresholds and checks the resulting points."""
    with self.test_session() as session:
      predictions_tensor = constant_op.constant(
          predictions, dtype=dtypes_lib.float32)
      labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.float32)
      curve_points, update = metric_ops.streaming_curve_points(
          labels=labels_tensor,
          predictions=predictions_tensor,
          num_thresholds=3,
          curve=curve)

      session.run(variables.local_variables_initializer())
      session.run(update)
      self.assertAllClose(expected_points, curve_points.eval())

  def testEdgeCasesROC(self):
    # Each entry: (labels, predictions, expected ROC points).
    single_example_cases = (
        ([[1]], [[1]], [[0, 1], [0, 1], [0, 0]]),
        ([[0]], [[0]], [[1, 1], [0, 1], [0, 1]]),
        ([[0]], [[1]], [[1, 1], [1, 1], [0, 1]]),
        ([[1]], [[0]], [[0, 1], [0, 0], [0, 0]]),
    )
    for labels, predictions, expected in single_example_cases:
      self._testCase(labels, predictions, 'ROC', expected)

  def testManyValuesROC(self):
    labels = [[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]]
    predictions = [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]]
    expected = [[1.0, 1.0], [0.0, 0.75], [0.0, 0.0]]
    self._testCase(labels, predictions, 'ROC', expected)

  def testEdgeCasesPR(self):
    # Each entry: (labels, predictions, expected PR points).
    single_example_cases = (
        ([[1]], [[1]], [[1, 1], [1, 1], [0, 1]]),
        ([[0]], [[0]], [[1, 0], [1, 1], [1, 1]]),
        ([[0]], [[1]], [[1, 0], [1, 0], [1, 1]]),
        ([[1]], [[0]], [[1, 1], [0, 1], [0, 1]]),
    )
    for labels, predictions, expected in single_example_cases:
      self._testCase(labels, predictions, 'PR', expected)

  def testManyValuesPR(self):
    labels = [[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]]
    predictions = [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]]
    expected = [[1.0, 4.0 / 6.0], [0.75, 1.0], [0.0, 1.0]]
    self._testCase(labels, predictions, 'PR', expected)
class StreamingAUCTest(test.TestCase): class StreamingAUCTest(test.TestCase):
def setUp(self): def setUp(self):

View File

@ -226,8 +226,8 @@ class TestBeamStep(test.TestCase):
class BeamSearchDecoderTest(test.TestCase): class BeamSearchDecoderTest(test.TestCase):
def _testDynamicDecodeRNN(self, time_major, has_attention): def _testDynamicDecodeRNN(self, time_major, has_attention):
encoder_sequence_length = [3, 2, 3, 1, 1] encoder_sequence_length = np.array([3, 2, 3, 1, 1])
decoder_sequence_length = [2, 0, 1, 2, 3] decoder_sequence_length = np.array([2, 0, 1, 2, 3])
batch_size = 5 batch_size = 5
decoder_max_time = 4 decoder_max_time = 4
input_depth = 7 input_depth = 7
@ -245,6 +245,7 @@ class BeamSearchDecoderTest(test.TestCase):
batch_size_tensor = constant_op.constant(batch_size) batch_size_tensor = constant_op.constant(batch_size)
embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32) embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
cell = rnn_cell.LSTMCell(cell_depth) cell = rnn_cell.LSTMCell(cell_depth)
initial_state = cell.zero_state(batch_size, dtypes.float32)
if has_attention: if has_attention:
inputs = array_ops.placeholder_with_default( inputs = array_ops.placeholder_with_default(
np.random.randn(batch_size, decoder_max_time, np.random.randn(batch_size, decoder_max_time,
@ -258,6 +259,8 @@ class BeamSearchDecoderTest(test.TestCase):
num_units=attention_depth, num_units=attention_depth,
memory=tiled_inputs, memory=tiled_inputs,
memory_sequence_length=tiled_sequence_length) memory_sequence_length=tiled_sequence_length)
initial_state = beam_search_decoder.tile_batch(
initial_state, multiplier=beam_width)
cell = attention_wrapper.AttentionWrapper( cell = attention_wrapper.AttentionWrapper(
cell=cell, cell=cell,
attention_mechanism=attention_mechanism, attention_mechanism=attention_mechanism,
@ -265,6 +268,9 @@ class BeamSearchDecoderTest(test.TestCase):
alignment_history=False) alignment_history=False)
cell_state = cell.zero_state( cell_state = cell.zero_state(
dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width) dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
if has_attention:
cell_state = cell_state.clone(
cell_state=initial_state)
bsd = beam_search_decoder.BeamSearchDecoder( bsd = beam_search_decoder.BeamSearchDecoder(
cell=cell, cell=cell,
embedding=embedding, embedding=embedding,

View File

@ -72,27 +72,8 @@ class FinalBeamSearchDecoderOutput(
pass pass
def tile_batch(t, multiplier, name=None): def _tile_batch(t, multiplier):
"""Tile the batch dimension of tensor t. """Core single-tensor implementation of tile_batch."""
This function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape
`[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries
`t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated
`multiplier` times.
Args:
t: `Tensor` shaped `[batch_size, ...]`.
multiplier: Python int.
name: Name scope for any created operations.
Returns:
A `Tensor` shaped `[batch_size * multiplier, ...]`.
Raises:
ValueError: if `t` does not have a statically known rank or it's < 1.
"""
with ops.name_scope(name, "tile_batch", [t, multiplier]):
t = ops.convert_to_tensor(t, name="t") t = ops.convert_to_tensor(t, name="t")
shape_t = array_ops.shape(t) shape_t = array_ops.shape(t)
if t.shape.ndims is None or t.shape.ndims < 1: if t.shape.ndims is None or t.shape.ndims < 1:
@ -110,6 +91,34 @@ def tile_batch(t, multiplier, name=None):
return tiled return tiled
def tile_batch(t, multiplier, name=None):
  """Tiles the batch dimension of every tensor in `t`.

  `t` may be a single tensor or a (possibly nested) structure of tensors.
  Each tensor shaped `[batch_size, s0, s1, ...]`, made of minibatch entries
  `t[0], ..., t[batch_size - 1]`, is tiled to shape
  `[batch_size * multiplier, s0, s1, ...]` with entries
  `t[0], t[0], ..., t[1], t[1], ...`, i.e. each minibatch entry is repeated
  `multiplier` times.

  Args:
    t: `Tensor` shaped `[batch_size, ...]`, or a (possibly nested) structure
      of such tensors.
    multiplier: Python int.
    name: Name scope for any created operations.

  Returns:
    A (possibly nested structure of) `Tensor` shaped
    `[batch_size * multiplier, ...]`, with the same structure as `t`.

  Raises:
    ValueError: if tensor(s) `t` do not have a statically known rank or
    the rank is < 1.
  """

  def _tile_one(tensor):
    # Per-tensor work is delegated to the single-tensor implementation.
    return _tile_batch(tensor, multiplier)

  # The flattened leaves plus the multiplier are the values for the name scope.
  scope_values = nest.flatten(t) + [multiplier]
  with ops.name_scope(name, "tile_batch", scope_values):
    tiled = nest.map_structure(_tile_one, t)
  return tiled
def _check_maybe(t): def _check_maybe(t):
if isinstance(t, tensor_array_ops.TensorArray): if isinstance(t, tensor_array_ops.TensorArray):
raise TypeError( raise TypeError(

View File

@ -270,7 +270,7 @@ class SessionBundleTest : public ::testing::Test {
// MetaGraphDef. // MetaGraphDef.
// Returns the path of the export. // Returns the path of the export.
// ** Should only be called once per test ** // ** Should only be called once per test **
string SetupExport(MetaGraphDefTwiddler twiddler) { string SetupExport(const MetaGraphDefTwiddler& twiddler) {
return SetupExport(twiddler, kVariablesFilename, kMetaGraphDefFilename); return SetupExport(twiddler, kVariablesFilename, kMetaGraphDefFilename);
} }
// SetupExport that allows for the variables and meta_graph_def filenames // SetupExport that allows for the variables and meta_graph_def filenames

View File

@ -62,6 +62,7 @@ licenses(["notice"]) # Apache 2.0
load( load(
"//tensorflow:tensorflow.bzl", "//tensorflow:tensorflow.bzl",
"full_path",
"if_android", "if_android",
"if_ios", "if_ios",
"if_x86", "if_x86",

View File

@ -30,7 +30,11 @@ Device::Device(Env* env, const DeviceAttributes& device_attributes)
rmgr_ = new ResourceMgr(parsed_name_.job); rmgr_ = new ResourceMgr(parsed_name_.job);
} }
Device::~Device() { delete rmgr_; } Device::~Device() {
if (rmgr_ != nullptr) {
DeleteResourceMgr();
}
}
// static // static
DeviceAttributes Device::BuildDeviceAttributes( DeviceAttributes Device::BuildDeviceAttributes(

View File

@ -60,7 +60,9 @@ class Device : public DeviceBase {
const string& name() const { return device_attributes_.name(); } const string& name() const { return device_attributes_.name(); }
// Parsed name of this device // Parsed name of this device
const DeviceNameUtils::ParsedName parsed_name() const { return parsed_name_; } const DeviceNameUtils::ParsedName& parsed_name() const {
return parsed_name_;
}
// Describes what kind of device this is. This is intended to be // Describes what kind of device this is. This is intended to be
// human-readable and not computer-parsed, except that two devices // human-readable and not computer-parsed, except that two devices
@ -149,6 +151,12 @@ class Device : public DeviceBase {
return BuildDeviceAttributes(name, device, memory_limit, locality, ""); return BuildDeviceAttributes(name, device, memory_limit, locality, "");
} }
protected:
// Destroys the ResourceMgr pointed to by rmgr_ and resets rmgr_ to nullptr.
// Resetting the pointer means ~Device(), which only calls this when
// rmgr_ != nullptr, does not delete it again if this has already run.
// NOTE(review): declared protected, presumably so subclasses can release the
// resource manager before the base destructor runs -- confirm with callers.
void DeleteResourceMgr() {
delete rmgr_;
rmgr_ = nullptr;
}
private: private:
const DeviceAttributes device_attributes_; const DeviceAttributes device_attributes_;
DeviceNameUtils::ParsedName parsed_name_; DeviceNameUtils::ParsedName parsed_name_;

View File

@ -53,7 +53,7 @@ Device* DeviceSet::FindDeviceByName(const string& name) const {
// static // static
int DeviceSet::DeviceTypeOrder(const DeviceType& d) { int DeviceSet::DeviceTypeOrder(const DeviceType& d) {
return DeviceFactory::DevicePriority(d.type()); return DeviceFactory::DevicePriority(d.type_string());
} }
static bool DeviceTypeComparator(const DeviceType& a, const DeviceType& b) { static bool DeviceTypeComparator(const DeviceType& a, const DeviceType& b) {

View File

@ -1231,7 +1231,7 @@ Status FunctionDefToBodyHelper(
GraphConstructorOptions opts; GraphConstructorOptions opts;
opts.allow_internal_ops = true; opts.allow_internal_ops = true;
opts.expect_device_spec = false; opts.expect_device_spec = false;
Status s = ConvertGraphDefToGraph(opts, result.gdef, graph); Status s = ConvertNodeDefsToGraph(opts, result.nodes, graph);
if (!s.ok()) { if (!s.ok()) {
delete graph; delete graph;
} else { } else {

View File

@ -93,7 +93,7 @@ class FunctionTest : public ::testing::Test {
GraphConstructorOptions opts; GraphConstructorOptions opts;
opts.allow_internal_ops = true; opts.allow_internal_ops = true;
opts.expect_device_spec = false; opts.expect_device_spec = false;
TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g)); TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g));
const int version = g->versions().producer(); const int version = g->versions().producer();
LocalExecutorParams params; LocalExecutorParams params;
@ -949,7 +949,7 @@ GraphDef Optimize(const std::function<bool(Graph* g)>& pass,
GraphConstructorOptions opts; GraphConstructorOptions opts;
opts.allow_internal_ops = true; opts.allow_internal_ops = true;
opts.expect_device_spec = false; opts.expect_device_spec = false;
TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g.get())); TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g.get()));
pass(g.get()); pass(g.get());
std::unique_ptr<Graph> g1(new Graph(OpRegistry::Global())); std::unique_ptr<Graph> g1(new Graph(OpRegistry::Global()));
CopyGraph(*g, g1.get()); CopyGraph(*g, g1.get());

View File

@ -324,6 +324,7 @@ static void BM_AllocationDelayed(int iters, int delay) {
int size_index = 0; int size_index = 0;
std::vector<void*> ptrs; std::vector<void*> ptrs;
ptrs.reserve(delay);
for (int i = 0; i < delay; i++) { for (int i = 0; i < delay; i++) {
ptrs.push_back(nullptr); ptrs.push_back(nullptr);
} }

View File

@ -123,10 +123,12 @@ void Benchmark::RunWithArgs(
} }
// Gets inputs' and outputs' rendezvous keys. // Gets inputs' and outputs' rendezvous keys.
std::vector<std::pair<string, Tensor>> in; std::vector<std::pair<string, Tensor>> in;
in.reserve(inputs.size());
for (const auto& p : inputs) { for (const auto& p : inputs) {
in.push_back({GetRendezvousKey(p.first), p.second}); in.push_back({GetRendezvousKey(p.first), p.second});
} }
std::vector<string> out; std::vector<string> out;
out.reserve(outputs.size());
for (const auto& n : outputs) { for (const auto& n : outputs) {
out.push_back(GetRendezvousKey(n)); out.push_back(GetRendezvousKey(n));
} }

View File

@ -94,6 +94,7 @@ Status SessionFactory::GetFactory(const SessionOptions& options,
// TODO(mrry): Consider providing a system-default fallback option // TODO(mrry): Consider providing a system-default fallback option
// in this case. // in this case.
std::vector<string> factory_types; std::vector<string> factory_types;
factory_types.reserve(candidate_factories.size());
for (const auto& candidate_factory : candidate_factories) { for (const auto& candidate_factory : candidate_factories) {
factory_types.push_back(candidate_factory.first); factory_types.push_back(candidate_factory.first);
} }

Some files were not shown because too many files have changed in this diff Show More