fix merge issues

This commit is contained in:
Andrew Harp 2017-06-03 00:00:53 -04:00
commit 5efd272aab
378 changed files with 9541 additions and 4762 deletions
RELEASE.mdWORKSPACE
tensorflow
BUILD
c
cc
compiler
contrib
core

View File

@ -41,6 +41,15 @@
be replaced by calling `embedding_lookup` or `layers.dense` as pre- or post-
processing of the rnn. For RNN decoding, this functionality has been replaced
with an alternative API in `tf.contrib.seq2seq`.
* Intel MKL Integration (https://software.intel.com/en-us/articles/tensorflow-optimizations-on-modern-intel-architecture). Intel developed a number of
optimized deep learning primitives: In addition to matrix multiplication and
convolution, these building blocks include:
Direct batched convolution
Pooling: maximum, minimum, average
Normalization: LRN, batch normalization
Activation: rectified linear unit (ReLU)
Data manipulation: multi-dimensional transposition (conversion), split,
concat, sum and scale.
* TensorForest Estimator now supports SavedModel export for serving.
* Support client-provided ClusterSpec's and propagate them to all workers to enable the creation of dynamic TensorFlow clusters.
* TensorFlow C library now available for Windows.

View File

@ -2,11 +2,11 @@ workspace(name = "org_tensorflow")
http_archive(
name = "io_bazel_rules_closure",
sha256 = "4be8a887f6f38f883236e77bb25c2da10d506f2bf1a8e5d785c0f35574c74ca4",
strip_prefix = "rules_closure-aac19edc557aec9b603cd7ffe359401264ceff0d",
sha256 = "edc91f556b762fc5212d1050d00b12e40dd0b0b1c1d5d96886b59e9a30a6cae4",
strip_prefix = "rules_closure-3f07fb6a58870afbb36051bd5d54da4479561cc6",
urls = [
"http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/aac19edc557aec9b603cd7ffe359401264ceff0d.tar.gz", # 2017-05-10
"https://github.com/bazelbuild/rules_closure/archive/aac19edc557aec9b603cd7ffe359401264ceff0d.tar.gz",
"http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/3f07fb6a58870afbb36051bd5d54da4479561cc6.tar.gz", # 2017-05-31
"https://github.com/bazelbuild/rules_closure/archive/3f07fb6a58870afbb36051bd5d54da4479561cc6.tar.gz",
],
)

View File

@ -393,6 +393,9 @@ filegroup(
"//tensorflow/tensorboard/demo:all_files",
"//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files",
"//tensorflow/tensorboard/plugins:all_files",
"//tensorflow/tensorboard/plugins/audio:all_files",
"//tensorflow/tensorboard/plugins/distributions:all_files",
"//tensorflow/tensorboard/plugins/graphs:all_files",
"//tensorflow/tensorboard/plugins/histograms:all_files",
"//tensorflow/tensorboard/plugins/images:all_files",
"//tensorflow/tensorboard/plugins/projector:all_files",

View File

@ -805,6 +805,7 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
}
std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
dim_vec.reserve(num_dims);
for (int i = 0; i < num_dims; ++i) {
dim_vec.push_back(ic->MakeDim(dims[i]));
}

View File

@ -113,10 +113,12 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs,
feeds.emplace_back(feed.first.name(), feed.second.tensor);
}
std::vector<string> output_tensor_names;
output_tensor_names.reserve(fetch_outputs.size());
for (auto const& output : fetch_outputs) {
output_tensor_names.push_back(output.name());
}
std::vector<string> target_node_names;
target_node_names.reserve(run_outputs.size());
for (auto const& output : run_outputs) {
target_node_names.push_back(output.node()->name());
}

View File

@ -44,6 +44,7 @@ Status ComputeTheoreticalJacobianTranspose(
size_t x_num = x_shapes.size();
// Call AddSymbolicGradients to get 'dxs' (we will feed 'dys').
OutputList dys;
dys.reserve(y_shapes.size());
for (const auto& y_shape : y_shapes) {
// TODO(suharshs): This currently assumes that all x's are the same type.
dys.push_back(Cast(scope, Const(scope, 1.0, y_shape), xs[0].type()));

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/cc/framework/testutil.h"
#include <utility>
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/default_device.h"
@ -30,7 +32,7 @@ void GetTensors(const Scope& scope, OutputList tensors,
void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
std::vector<Tensor> outputs;
GetTensors(scope, {tensor}, &outputs);
GetTensors(scope, {std::move(tensor)}, &outputs);
*out = outputs[0];
}

View File

@ -350,6 +350,7 @@ Status CompileXla(xla::CompileOnlyClient* client,
compile_result->program_shape = *pshape_or.ValueOrDie();
xla::ProgramShape* pshape = &compile_result->program_shape;
std::vector<const xla::Shape*> arg_layouts;
arg_layouts.reserve(pshape->parameters_size());
for (int i = 0; i < pshape->parameters_size(); ++i) {
arg_layouts.push_back(pshape->mutable_parameters(i));
}

View File

@ -218,6 +218,7 @@ cc_library(
deps = [
":common",
":graph_to_functiondef",
":union_find",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/kernels:parallel_check_op",
"//tensorflow/compiler/jit/kernels:xla_local_launch_op",
@ -237,6 +238,11 @@ cc_library(
],
)
cc_library(
name = "union_find",
hdrs = ["union_find.h"],
)
cc_test(
name = "compilation_passes_test",
size = "small",

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <utility>
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/cc/framework/ops.h"
@ -101,12 +103,12 @@ Node* Input(const GraphDefBuilder::Options& opts) {
}
Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) {
return ops::UnaryOp("UnaryTest", a, opts);
return ops::UnaryOp("UnaryTest", std::move(a), opts);
}
Node* Binary(ops::NodeOut a, ops::NodeOut b,
const GraphDefBuilder::Options& opts) {
return ops::BinaryOp("BinaryTest", a, b, opts);
return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts);
}
Node* AddNLike(const std::vector<ops::NodeOut>& inputs,
@ -127,7 +129,7 @@ Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
NodeBuilder node_builder(opts.GetNameForOp("Retval"), "_Retval",
opts.op_registry());
node_builder.Input(a).Attr("index", index);
node_builder.Input(std::move(a)).Attr("index", index);
return opts.FinalizeBuilder(&node_builder);
}

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
@ -206,70 +207,12 @@ Status FindCompilationCandidates(
return Status::OK();
}
// Union-Find data structure used to compute clusters. We use our own
// implementation because we want one key feature: when merging clusters, we
// need to know which value becomes the representative of the merged clusters.
// We use the representatives to name nodes in a cycle detection graph, and we
// need to control which node is named.
// TODO(phawkins): consider merging this code with union-find implementations
// in Tensorflow, e.g., in SimplePlacer.
class Cluster {
public:
Cluster();
int Size() { return FindRoot()->size_; }
// Merges this cluster with 'other'. This cluster's representative becomes
// the representative of the merged cluster; the representative of 'other'
// is ignored.
void Merge(Cluster* other);
// Each cluster has an associated integer 'representative', initialized to -1
// by default.
int GetRepresentative() { return FindRoot()->representative_; }
void SetRepresentative(int representative) {
FindRoot()->representative_ = representative;
}
private:
// Finds the root element of the cluster. Performs path compression.
Cluster* FindRoot();
int representative_;
int rank_;
int size_; // Size of the cluster.
Cluster* parent_;
struct Cluster {
// Identifies the node that represents this cluster in the cycle detection
// graph.
int representative = -1;
};
Cluster::Cluster()
: representative_(-1), rank_(0), size_(1), parent_(nullptr) {}
void Cluster::Merge(Cluster* other) {
Cluster* a = FindRoot();
Cluster* b = other->FindRoot();
if (a == b) return;
if (a->rank_ > b->rank_) {
b->parent_ = a;
a->size_ += b->size_;
return;
}
a->parent_ = b;
if (a->rank_ == b->rank_) {
b->rank_++;
}
b->representative_ = a->representative_;
b->size_ += a->size_;
}
Cluster* Cluster::FindRoot() {
if (!parent_) return this;
// Path compression: update intermediate nodes to point to the root of the
// equivalence class.
parent_ = parent_->FindRoot();
return parent_;
}
} // anonymous namespace
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
@ -432,10 +375,11 @@ Status MarkForCompilationPass::RunImpl(
// Each compilation candidate belongs to a cluster. The cluster's
// representative
// names the node in the 'cycles' graph that represents the cluster.
std::vector<Cluster> clusters(graph->num_node_ids());
std::deque<Cluster*> worklist;
std::vector<UnionFind<Cluster>> clusters(graph->num_node_ids());
std::deque<UnionFind<Cluster>*> worklist;
for (Node* node : compilation_candidates) {
clusters[node->id()].SetRepresentative(node->id());
Cluster& cluster = clusters[node->id()].Get();
cluster.representative = node->id();
worklist.push_back(&clusters[node->id()]);
}
@ -445,7 +389,7 @@ Status MarkForCompilationPass::RunImpl(
// Repeatedly contract edges between clusters that are on the same device,
// provided the contraction would not create a cycle.
while (!worklist.empty()) {
int from = worklist.front()->GetRepresentative();
int from = worklist.front()->Get().representative;
worklist.pop_front();
Node* node_from = graph->FindNodeId(from);
@ -518,7 +462,7 @@ Status MarkForCompilationPass::RunImpl(
// Count the number of elements in each cluster.
std::vector<int> cluster_sizes(graph->num_node_ids());
for (const Node* n : compilation_candidates) {
int cluster = clusters[n->id()].GetRepresentative();
int cluster = clusters[n->id()].Get().representative;
cluster_sizes[cluster]++;
}
@ -532,7 +476,7 @@ Status MarkForCompilationPass::RunImpl(
// if compilation is enabled, otherwise there will be no such candidates).
const int min_cluster_size = flags->tf_xla_min_cluster_size;
for (Node* n : compilation_candidates) {
int cluster = clusters[n->id()].GetRepresentative();
int cluster = clusters[n->id()].Get().representative;
// Compile if the user marked this node _XlaCompile=true
bool compile_attr = false;

View File

@ -0,0 +1,81 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_UNION_FIND_H_
#define TENSORFLOW_COMPILER_JIT_UNION_FIND_H_
namespace tensorflow {
// Union-Find data structure.
// Each cluster has an associated value; when merging clusters we can control
// which value becomes the representative of the merged clusters. Values must be
// copyable.
template <typename T>
class UnionFind {
public:
UnionFind() : rank_(0), size_(1), parent_(nullptr) {}
// Returns the number of elements in a cluster.
int Size() { return FindRoot()->size_; }
// Merges this cluster with 'other'. This cluster's value becomes
// the value of the merged cluster; the value of 'other' is ignored.
void Merge(UnionFind* other);
// Each cluster has an associated value. Retrieves the value associated
// with this cluster.
T& Get() { return FindRoot()->value_; }
private:
// Finds the root element of the cluster. Performs path compression.
UnionFind* FindRoot();
int rank_;
int size_; // Size of the cluster.
UnionFind* parent_;
T value_;
};
template <typename T>
void UnionFind<T>::Merge(UnionFind* other) {
UnionFind<T>* a = FindRoot();
UnionFind<T>* b = other->FindRoot();
if (a == b) return;
if (a->rank_ > b->rank_) {
b->parent_ = a;
a->size_ += b->size_;
return;
}
a->parent_ = b;
if (a->rank_ == b->rank_) {
b->rank_++;
}
b->value_ = a->value_;
b->size_ += a->size_;
}
template <typename T>
UnionFind<T>* UnionFind<T>::FindRoot() {
if (!parent_) return this;
// Path compression: update intermediate nodes to point to the root of the
// equivalence class.
parent_ = parent_->FindRoot();
return parent_;
}
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_UNION_FIND_H_

View File

@ -50,6 +50,7 @@ class FillOp : public XlaOpKernel {
// Convert the dims literal into a vector that we can pass to
// ComputationBuilder.
std::vector<int64> broadcast;
broadcast.reserve(dims_literal.shape().dimensions(0));
for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) {
broadcast.push_back(xla::LiteralUtil::Get<int>(dims_literal, {i}));
}

View File

@ -50,6 +50,7 @@ class SliceOp : public XlaOpKernel {
// slice will be an empty handle if the output has no elements.
CHECK_EQ(begin.size(), size.size());
std::vector<int64> limits;
limits.reserve(begin.size());
for (int i = 0; i < begin.size(); ++i) {
limits.push_back(begin[i] + size[i]);
}

View File

@ -18,6 +18,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"

View File

@ -58,14 +58,13 @@ StatusOr<std::unique_ptr<Literal>> Client::Transfer(
"server provided response without a literal in "
"TransferToClient request");
}
return WrapUnique(response.release_literal());
return MakeUnique<Literal>(response.literal());
}
StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
const Literal& literal, const DeviceHandle* device_handle) {
TransferToServerRequest request;
*request.mutable_literal() = literal;
*request.mutable_literal() = literal.ToProto();
if (device_handle) {
*request.mutable_device_handle() = *device_handle;
}
@ -93,7 +92,7 @@ StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
Status Client::TransferToInfeed(const Literal& literal, int64 replica_id,
const DeviceHandle* device_handle) {
TransferToInfeedRequest request;
*request.mutable_literal() = literal;
*request.mutable_literal() = literal.ToProto();
if (device_handle) {
*request.mutable_device_handle() = *device_handle;
}
@ -141,7 +140,8 @@ StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed(
"TransferToClient request");
}
return WrapUnique(response.release_literal());
Literal literal(response.literal());
return MakeUnique<Literal>(literal);
}
Status Client::ResetDevice() {

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/service_interface.h"
#include "tensorflow/compiler/xla/statusor.h"

View File

@ -165,9 +165,10 @@ ComputationDataHandle ComputationBuilder::ConstantOp(
}
ConstantRequest request;
Literal* literal = request.mutable_literal();
populate(literal);
VLOG(3) << "created constant: " << literal->ShortDebugString();
Literal literal;
populate(&literal);
*request.mutable_literal() = literal.ToProto();
VLOG(3) << "created constant: " << request.literal().ShortDebugString();
OpRequest op_request;
*op_request.mutable_constant_request() = request;
*op_request.mutable_computation() = computation_.handle();

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include <string>
#include <utility>
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
@ -23,7 +24,7 @@ limitations under the License.
namespace xla {
GlobalData::GlobalData(ServiceInterface* parent, GlobalDataHandle handle)
: handle_(handle), parent_(parent) {}
: handle_(std::move(handle)), parent_(parent) {}
GlobalData::~GlobalData() {
UnregisterRequest request;

View File

@ -222,8 +222,9 @@ tensorflow::Status LocalExecutable::RecordArguments(
SessionModule* session_module) {
session_module->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
TF_RETURN_IF_ERROR(
LiteralFromShapedBuffer(*argument, session_module->add_arguments()));
Literal literal;
TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*argument, &literal));
*session_module->add_arguments() = literal.ToProto();
}
return tensorflow::Status::OK();
}
@ -231,9 +232,13 @@ tensorflow::Status LocalExecutable::RecordArguments(
tensorflow::Status LocalExecutable::RecordResult(
const ShapedBuffer* result, SessionModule* session_module) {
session_module->clear_result();
return LiteralFromShapedBuffer(*result, session_module->mutable_result());
Literal literal(session_module->result());
TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*result, &literal));
*session_module->mutable_result() = literal.ToProto();
return tensorflow::Status::OK();
}
// TODO(dnovillo) Change signature to return StatusOr<Literal>.
tensorflow::Status LocalExecutable::LiteralFromShapedBuffer(
const ShapedBuffer& shaped_buffer, Literal* literal) {
TF_ASSIGN_OR_RETURN(

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -856,5 +856,26 @@ TEST_F(LiteralUtilTest, ConvertR4) {
EXPECT_TRUE(LiteralUtil::Equal(*expected, *converted));
}
TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
LiteralProto p;
p.mutable_shape()->set_element_type(PRED);
for (int len = 0; len < 25; ++len) {
p.mutable_shape()->clear_dimensions();
p.mutable_shape()->add_dimensions(len);
p.clear_preds();
for (int i = 0; i < len; ++i) {
p.add_preds((i % 2) == (len % 2));
}
Literal literal(p);
ASSERT_EQ(len, literal.preds_size());
int i = 0;
for (auto it = literal.preds().begin(); it < literal.preds().end(); ++it) {
EXPECT_EQ((i % 2) == (len % 2), *it);
++i;
}
}
}
} // namespace
} // namespace xla

View File

@ -60,8 +60,8 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
int64 elements = ShapeUtil::ElementsIn(shape);
LiteralUtil::Resize(elements, std::numeric_limits<float>::quiet_NaN(),
result.get());
tensorflow::protobuf::RepeatedField<float>* field = result->mutable_f32s();
char* data = tensorflow::bit_cast<char*>(field->mutable_data());
std::vector<float>* field = result->mutable_f32s();
char* data = tensorflow::bit_cast<char*>(field->data());
uint64 bytes = elements * sizeof(float);
tensorflow::StringPiece sp;
auto s = file_->Read(offset_, bytes, &sp, data);

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/reference_util.h"
#include <array>
#include <utility>
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
@ -331,7 +332,8 @@ ReferenceUtil::ConvArray4DGeneralDimensions(
std::pair<int64, int64> kernel_stride, Padding padding,
ConvolutionDimensionNumbers dimension_numbers) {
return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
{1, 1}, {1, 1}, dimension_numbers);
{1, 1}, {1, 1},
std::move(dimension_numbers));
}
/* static */ std::unique_ptr<Array4D<float>>

View File

@ -529,6 +529,7 @@ cc_library(
srcs = ["transfer_manager.cc"],
hdrs = ["transfer_manager.h"],
deps = [
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@ -1680,10 +1681,8 @@ cc_library(
deps = [
":buffer_assignment",
":hlo",
":hlo_ordering",
":hlo_proto",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
],
)

View File

@ -171,6 +171,7 @@ StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
executor, allocation->device_memory(), allocation->shape()));
std::vector<GlobalDataHandle> element_handles;
element_handles.reserve(element_bases.size());
for (int i = 0; i < element_bases.size(); ++i) {
element_handles.push_back(RegisterInternal(
allocation->backend(), allocation->device_ordinal(), element_bases[i],

View File

@ -229,25 +229,26 @@ Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices(
// Mapping from LogicalBuffer to index (used to detect non-distinct indices).
FlatMap<const LogicalBuffer*, std::vector<ShapeIndex>>
buffer_to_source_indices;
TF_RETURN_IF_ERROR(points_to.ForEachElement([this, &buffer_to_source_indices](
const ShapeIndex& index, bool /*is_leaf*/,
const std::vector<const LogicalBuffer*>& buffers) {
if (buffers.size() > 1) {
// Record ambiguous points-to set at 'index'.
if (!indices_to_copy_.element(index)) {
VLOG(2) << "Adding copy of buffer for instruction: "
<< instruction_->name()
<< " at index: " << tensorflow::str_util::Join(index, ",")
<< " with ambiguous points-to set.";
RecordIndex(index);
}
}
// For each 'buffer': record a mapping from 'buffer' to 'index'.
for (const LogicalBuffer* buffer : buffers) {
buffer_to_source_indices[buffer].push_back(index);
}
return Status::OK();
}));
TF_RETURN_IF_ERROR(points_to.ForEachElement(
[this, &buffer_to_source_indices](
const ShapeIndex& index, bool /*is_leaf*/,
const std::vector<const LogicalBuffer*>& buffers) {
if (buffers.size() > 1) {
// Record ambiguous points-to set at 'index'.
if (!indices_to_copy_.element(index)) {
VLOG(2) << "Adding copy of buffer for instruction: "
<< instruction_->name()
<< " at index: " << tensorflow::str_util::Join(index, ",")
<< " with ambiguous points-to set.";
RecordIndex(index);
}
}
// For each 'buffer': record a mapping from 'buffer' to 'index'.
for (const LogicalBuffer* buffer : buffers) {
buffer_to_source_indices[buffer].push_back(index);
}
return Status::OK();
}));
// Record all non-distinct indices detected in 'buffer_to_source_indices'.
for (const auto& buff_to_src : buffer_to_source_indices) {
@ -449,11 +450,15 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
FlatMap<const HloInstruction*, HloInstruction*>* shared_copies) {
const HloInstruction* init_hlo = while_hlo->operand(0);
const PointsToSet& points_to = points_to_analysis.GetPointsToSet(init_hlo);
// Mapping from LogicalBuffer to index (used to detect non-distinct indices).
FlatSet<const LogicalBuffer*> buffer_set;
ShapeTree<HloInstruction*> copy_overrides(init_hlo->shape());
TF_RETURN_IF_ERROR(points_to.ForEachElement(
[init_hlo, read_only_indices, shared_copies, &copy_overrides](
const ShapeIndex& index, bool /*is_leaf*/,
const std::vector<const LogicalBuffer*>& buffers) {
[init_hlo, read_only_indices, shared_copies, &buffer_set,
&copy_overrides](const ShapeIndex& index, bool /*is_leaf*/,
const std::vector<const LogicalBuffer*>& buffers) {
// Look for read-only entry parameters.
if (!read_only_indices->element(index)) {
return Status::OK();
@ -468,6 +473,7 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
if (!is_entry_parameter && !is_constant) {
continue;
}
// We have found an entry parameter or constant that is read-only in
// the while body. These buffers are managed by the caller, and cannot
// be aliased with non-parameter buffers. Revert this read-only index,
@ -476,16 +482,17 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
// Optimization to allow multiple while loops that share the same
// read-only entry parameters (or constants) to share a single copy.
// Only unambiguous array-shaped buffers are allowed, to reduce code
// complexity. The shape of the entry parameter must be identical to
// the shape of the init_hlo at this index, to ensure there were no
// intervening bitcast or GTE instructions, which are also hard to
// handle.
// Only unambiguous and distinct array-shaped buffers are allowed, to
// reduce code complexity. The shape of the entry parameter must be
// identical to the shape of the init_hlo at this index, to ensure
// there were no intervening bitcast or GTE instructions, which are
// also hard to handle.
const Shape& pointee_shape = pointee->shape();
const Shape& init_shape =
ShapeUtil::GetSubshape(init_hlo->shape(), index);
if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) &&
ShapeUtil::Equal(pointee_shape, init_shape)) {
ShapeUtil::Equal(pointee_shape, init_shape) &&
buffer_set.count(buffer) < 1) {
HloInstruction** copy = &(*shared_copies)[pointee];
if (*copy == nullptr) {
*copy =
@ -496,6 +503,9 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
*copy_overrides.mutable_element(index) = *copy;
}
// Tracks whether this current buffer is distinct.
buffer_set.insert(buffer);
// We've already reverted the read-only index and handled the
// single-copy optimization above, so there's nothing more to do.
break;

View File

@ -44,13 +44,20 @@ class CopyInsertionTest : public HloTestBase {
EXPECT_IS_OK(copy_insertion.Run(module).status());
// Verify the points to set of the root of the computation after copy
// insertion contains no constants or parameters.
// insertion contains no constants or parameters, and is distinct and
// non-ambiguous.
auto points_to_analysis =
TuplePointsToAnalysis::Run(module).ConsumeValueOrDie();
const auto& points_to = points_to_analysis->GetPointsToSet(
module->entry_computation()->root_instruction());
EXPECT_TRUE(points_to.IsDistinct());
EXPECT_TRUE(!points_to.IsAmbiguous());
tensorflow::gtl::FlatSet<const LogicalBuffer*> maybe_live_out_buffers =
points_to_analysis
->GetPointsToSet(module->entry_computation()->root_instruction())
.CreateFlattenedSet();
for (const LogicalBuffer* buffer : maybe_live_out_buffers) {
EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kConstant);
EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kParameter);
@ -390,6 +397,47 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
return builder.Build();
}
// Builds a While body computation with two output tuple elements dependent on
// both input tuple elements.
//
// EX: Body({in0, in1, in2})
// out0 = Add(in0, 1)
// out1 = in1
// out2 = in2
// Tuple(out0, out1, out2)
std::unique_ptr<HloComputation> BuildDependentBodyComputation2() {
auto builder = HloComputation::Builder(TestName() + ".Body");
const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
{induction_variable_shape_, data_shape_, data_shape_});
auto loop_state = builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
// Update the induction variable GTE(0).
auto induction_variable =
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
induction_variable_shape_, loop_state, 0));
auto inc = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
// add0 = Add(in0, 1)
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
// data1 = GTE(1).
HloInstruction* data1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
// data2 = GTE(2).
HloInstruction* data2 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 2));
// Create output Tuple.
builder.AddInstruction(HloInstruction::CreateTuple({add0, data1, data2}));
return builder.Build();
}
// Builds a While body computation with read-only tuple element 0.
// EX:
// Body({in0, in1})
@ -408,6 +456,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
// Update data GTE(1).
auto data = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
// Use 'induction_variable' in computation with no path to output tuple.
auto update = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
@ -431,6 +480,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
// Create param instruction to access loop state.
const Shape& loop_state_shape =
nested ? nested_loop_state_shape_ : loop_state_shape_;
auto loop_state = builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
// Update the induction variable GTE(0).
@ -972,7 +1022,8 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) {
op::Copy(old_init->operand(1)->operand(0)))));
}
// Tests while init instruction buffer which interferes with while result buffer.
// Tests while init instruction buffer which interferes with while result
// buffer.
//
// init_data = Broadcast(...)
// add_unrelated = Add(init_data) // takes a reference to cause interference
@ -989,5 +1040,81 @@ TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) {
op::Copy(old_init->operand(1))));
}
// Tests while init instruction buffer which has a non-distinct points-to set:
//
// init = Tuple(Parameter(S32, {}), Parameter(F32, {8},
// Parameter(F32, {8})))
//
// where the second and third parameters are identical *and* the tuple shared
// by another while instruction..
//
// Verifies that the resulting point-to set is distinct in the resulting Tuple
// (non-identical Copys). In other words, verifies that copy sharing does not
// insert identical copies to the resulting tuple.
TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
auto condition1 = module_.AddEmbeddedComputation(BuildConditionComputation());
auto condition2 = module_.AddEmbeddedComputation(BuildConditionComputation());
// Loop body that outputs tuple comprises two elements dependent on the init
// tuple.
auto body1 = module_.AddEmbeddedComputation(BuildDependentBodyComputation2());
auto body2 = module_.AddEmbeddedComputation(BuildDependentBodyComputation2());
auto builder = HloComputation::Builder(TestName() + ".While");
auto iter_param = builder.AddInstruction(
HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
auto data_param = builder.AddInstruction(
HloInstruction::CreateParameter(1, data_shape_, "data"));
// Loop init tuple contains two identical parameter buffers.
auto loop_init = builder.AddInstruction(
HloInstruction::CreateTuple({iter_param, data_param, data_param}));
const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
{induction_variable_shape_, data_shape_, data_shape_});
// Two while loops shares the same loop init tuple.
auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
loop_state_shape, condition1, body1, loop_init));
auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
loop_state_shape, condition2, body2, loop_init));
module_.AddEntryComputation(builder.Build());
auto points_to_analysis =
TuplePointsToAnalysis::Run(&module_).ConsumeValueOrDie();
// Asserts that the init tuples before copy insertion is non-distinct.
ASSERT_FALSE(
points_to_analysis->GetPointsToSet(while_hlo1->operand(0)).IsDistinct());
ASSERT_FALSE(
points_to_analysis->GetPointsToSet(while_hlo2->operand(0)).IsDistinct());
auto old_init1 = while_hlo1->operand(0);
auto old_init2 = while_hlo2->operand(0);
InsertCopies(&module_);
EXPECT_THAT(while_hlo1->operand(0),
op::Tuple(op::Copy(old_init1->operand(0)),
op::Copy(old_init1->operand(1)),
op::Copy(old_init1->operand(2))));
EXPECT_THAT(while_hlo2->operand(0),
op::Tuple(op::Copy(old_init2->operand(0)),
op::Copy(old_init2->operand(1)),
op::Copy(old_init2->operand(2))));
// Verifies the init tuples after copy insertion is distinct.
points_to_analysis = TuplePointsToAnalysis::Run(&module_).ConsumeValueOrDie();
const auto& points_to1 =
points_to_analysis->GetPointsToSet(while_hlo1->operand(0));
EXPECT_TRUE(points_to1.IsDistinct());
const auto& points_to2 =
points_to_analysis->GetPointsToSet(while_hlo2->operand(0));
EXPECT_TRUE(points_to2.IsDistinct());
}
} // namespace
} // namespace xla

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/types.h"

View File

@ -31,7 +31,7 @@ AsyncExecution::AsyncExecution(Backend* backend,
: backend_(CHECK_NOTNULL(backend)),
streams_(std::move(streams)),
profile_(profile),
result_(result) {
result_(std::move(result)) {
for (const auto& stream : streams_) {
CHECK(stream != nullptr);
}

View File

@ -254,6 +254,7 @@ TEST_F(HloScheduleTest, LatticeMatMul) {
// d40 -- layer 4
HloComputation::Builder builder("entry_computation");
std::vector<HloInstruction*> params;
params.reserve(6);
for (int i = 0; i < 6; ++i) {
params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i))));

View File

@ -1631,6 +1631,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildKernelThunk(
// Compute the input buffer indices.
std::vector<BufferAllocation::Slice> io_buffers;
io_buffers.reserve(io_hlos.size());
for (const HloInstruction* io_hlo : io_hlos) {
io_buffers.push_back(GetAllocationSlice(*LatestNonGteAncestor(io_hlo)));
}

View File

@ -86,6 +86,7 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) {
// d40 -- layer 4
HloComputation::Builder builder("entry_computation");
std::vector<HloInstruction*> params;
params.reserve(6);
for (int i = 0; i < 6; ++i) {
params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i))));

View File

@ -46,7 +46,7 @@ message HloInstructionProto {
xla.OpMetadata metadata = 7;
// Literal, only present for kConstant.
xla.Literal literal = 8;
xla.LiteralProto literal = 8;
// Parameter info, only present for kParameter.
int64 parameter_number = 9;

View File

@ -311,7 +311,6 @@ void ComputeComputationPostOrder(
visited->insert(computation);
post_order->push_back(computation);
return;
}
} // namespace

View File

@ -65,7 +65,7 @@ using ::tensorflow::strings::StrCat;
WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()));
instruction->operands_.push_back(operand);
instruction->literal_.reset(new Literal);
*instruction->literal_->mutable_u8s() += tag;
instruction->literal_->append_u8s(tag);
return instruction;
}
@ -1484,6 +1484,7 @@ string HloInstruction::ToString(bool compact_operands,
}
if (!slice_starts_.empty() && !slice_limits_.empty()) {
std::vector<string> bounds;
bounds.reserve(slice_starts_.size());
for (int i = 0; i < slice_starts_.size(); ++i) {
bounds.push_back(
StrCat("[", slice_starts_[i], ":", slice_limits_[i], "]"));
@ -1550,7 +1551,7 @@ HloInstructionProto HloInstruction::ToProto() const {
*proto.mutable_metadata() = metadata_;
switch (opcode_) {
case HloOpcode::kConstant:
*proto.mutable_literal() = *literal_;
*proto.mutable_literal() = literal_->ToProto();
break;
case HloOpcode::kParameter:
proto.set_parameter_number(parameter_number_);
@ -1647,10 +1648,10 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
trace_instruction_ = trace_instruction;
}
const string& HloInstruction::tracing_tag() const {
string HloInstruction::TracingTag() const {
CHECK_EQ(HloOpcode::kTrace, opcode());
CHECK(literal_ != nullptr);
return literal_->u8s();
return literal_->u8s_string();
}
bool HloInstruction::IsFused() const {

View File

@ -30,6 +30,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@ -535,7 +536,7 @@ class HloInstruction {
// Returns a tag to be used in tracing.
//
// Precondition: opcode() == HloOpcode::kTrace
const string& tracing_tag() const;
string TracingTag() const;
// Returns whether the instruction is a constant.
bool IsConstant() const;

View File

@ -151,7 +151,26 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
return true;
};
if (std::all_of(hlo->users().begin(), hlo->users().end(),
// An "effectively unary" operation is one that has one "large"
// input with the others being negligible in terms of memory usage.
// We use "has a smaller true rank than the output" as a heuristic
// for "negligible" memory usage.
auto effectively_unary = [](HloInstruction* hlo) {
if (hlo->operands().size() == 1) {
return true;
}
auto output_rank = ShapeUtil::TrueRank(hlo->shape());
return std::count_if(
hlo->operands().begin(), hlo->operands().end(),
[output_rank](HloInstruction* operand) {
return ((operand->opcode() != HloOpcode::kBroadcast) &&
ShapeUtil::TrueRank(operand->shape()) >=
output_rank);
}) <= 1;
};
if (effectively_unary(hlo) ||
std::all_of(hlo->users().begin(), hlo->users().end(),
user_fusable_into_hlo)) {
all_consumers_fusable.insert(hlo);
}

View File

@ -156,21 +156,67 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {
TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {16, 16}), "0"));
HloInstruction* unary1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {}), HloOpcode::kFloor, param0));
builder.AddInstruction(HloInstruction::CreateSend(unary1, 0));
HloInstruction* unary2 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {}), HloOpcode::kAbs, unary1));
auto shape = ShapeUtil::MakeShape(F32, {16, 16});
auto param0 =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
auto param1 =
builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
HloInstruction* binary1 = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
builder.AddInstruction(HloInstruction::CreateSend(binary1, 0));
HloInstruction* unary = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
auto module = MakeUnique<HloModule>(TestName());
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(unary2, computation->root_instruction());
EXPECT_EQ(unary, computation->root_instruction());
EXPECT_FALSE(
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
.Run(module.get())
.ValueOrDie());
}
TEST_F(InstructionFusionTest, AllowUnaryDuplication) {
HloComputation::Builder builder(TestName());
auto shape = ShapeUtil::MakeShape(F32, {16, 16});
auto param0 =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
HloInstruction* unary1 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kFloor, param0));
builder.AddInstruction(HloInstruction::CreateSend(unary1, 0));
HloInstruction* unary2 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1));
auto module = MakeUnique<HloModule>(TestName());
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(unary2, computation->root_instruction());
EXPECT_TRUE(
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
.Run(module.get())
.ValueOrDie());
}
TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) {
auto shape = ShapeUtil::MakeShape(F32, {16, 16});
auto small_shape = ShapeUtil::MakeShape(F32, {16});
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, small_shape, "0"));
auto param1 =
builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
HloInstruction* binary1 = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
builder.AddInstruction(HloInstruction::CreateSend(binary1, 0));
HloInstruction* unary = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
auto module = MakeUnique<HloModule>(TestName());
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(unary, computation->root_instruction());
EXPECT_TRUE(
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
.Run(module.get())
.ValueOrDie());
}
} // namespace xla

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "external/llvm/include/llvm/IR/Module.h"
#include "external/llvm/include/llvm/IR/Value.h"
#include "external/llvm/include/llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"

View File

@ -77,8 +77,10 @@ tensorflow::Status RecordArguments(
SessionModule* module) {
module->clear_arguments();
for (const Allocation* allocation : arg_allocations) {
TF_RETURN_IF_ERROR(LiteralFromAllocation(allocation, allocation->shape(),
module->add_arguments()));
Literal argument;
TF_RETURN_IF_ERROR(
LiteralFromAllocation(allocation, allocation->shape(), &argument));
*module->add_arguments() = argument.ToProto();
}
return tensorflow::Status::OK();
}
@ -87,8 +89,11 @@ tensorflow::Status RecordArguments(
tensorflow::Status RecordResult(const Allocation* result_allocation,
SessionModule* module) {
module->clear_result();
return LiteralFromAllocation(result_allocation, result_allocation->shape(),
module->mutable_result());
Literal result;
TF_RETURN_IF_ERROR(LiteralFromAllocation(
result_allocation, result_allocation->shape(), &result));
*module->mutable_result() = result.ToProto();
return tensorflow::Status::OK();
}
} // namespace
@ -649,6 +654,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
ResolveAndValidateArguments(request.arguments(), execute_backend_.get(),
executor->device_ordinal()));
std::vector<se::DeviceMemoryBase> arguments;
arguments.reserve(arg_allocations.size());
for (const Allocation* allocation : arg_allocations) {
arguments.push_back(allocation->device_memory());
}
@ -677,6 +683,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
BuildExecutables(versioned_handles, std::move(module_configs),
execute_backend_.get(), executors));
std::vector<Executable*> executable_ptrs;
executable_ptrs.reserve(executables.size());
for (const auto& executable : executables) {
executable_ptrs.push_back(executable.get());
}
@ -752,6 +759,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
<< module_config->entry_computation_layout().ToString();
std::vector<se::DeviceMemoryBase> arguments;
arguments.reserve(arg_allocations.size());
for (const Allocation* allocation : arg_allocations) {
arguments.push_back(allocation->device_memory());
}
@ -820,6 +828,7 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
<< module_config->entry_computation_layout().ToString();
std::vector<se::DeviceMemoryBase> arguments;
arguments.reserve(arg_allocations.size());
for (const Allocation* allocation : arg_allocations) {
arguments.push_back(allocation->device_memory());
}
@ -908,13 +917,15 @@ tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg,
literal_shape = &allocation->shape();
}
return LiteralFromAllocation(allocation, *literal_shape,
result->mutable_literal());
Literal literal;
auto status = LiteralFromAllocation(allocation, *literal_shape, &literal);
*result->mutable_literal() = literal.ToProto();
return status;
}
tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
TransferToServerResponse* result) {
const Literal& literal = arg->literal();
Literal literal = Literal(arg->literal());
const Shape& shape = literal.shape();
if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) {
@ -978,7 +989,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
}
return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
executor, arg->literal());
executor, Literal(arg->literal()));
}
tensorflow::Status Service::TransferFromOutfeed(
@ -1001,8 +1012,12 @@ tensorflow::Status Service::TransferFromOutfeed(
executor = execute_backend_->Replicas()[arg->replica_id()];
}
return execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
executor, arg->shape_with_layout(), result->mutable_literal());
Literal literal;
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
executor, arg->shape_with_layout(), &literal));
*result->mutable_literal() = literal.ToProto();
return tensorflow::Status::OK();
}
tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg,

View File

@ -75,10 +75,10 @@ message SessionModule {
repeated SessionComputation embedded_computations = 2;
// The arguments passed to the computation.
repeated Literal arguments = 3;
repeated LiteralProto arguments = 3;
// The result of the computation.
Literal result = 4;
LiteralProto result = 4;
// The name of the platform used to run the computation.
string execution_platform = 5;

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <set>
#include <vector>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

View File

@ -121,7 +121,7 @@ TEST_F(CpuTransferManagerTest, TransferR1U8FromDevice) {
const Shape shape = ShapeUtil::MakeShape(U8, {4});
TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice(
stream_exec_, memptr, shape, shape, &literal));
CHECK_EQ("klmn", literal.u8s());
CHECK_EQ("klmn", literal.u8s_string());
}
TEST_F(CpuTransferManagerTest, TransferBufferFromDevice) {

View File

@ -2275,7 +2275,7 @@ void ComputationLowerer::Visit(
const ConstantRequest& constant_request =
request.request().constant_request();
hlo_instruction = add_instruction(HloInstruction::CreateConstant(
LiteralUtil::CloneToUnique(constant_request.literal())));
LiteralUtil::CloneToUnique(Literal(constant_request.literal()))));
break;
}
@ -2467,6 +2467,7 @@ void ComputationLowerer::Visit(
// to append dimensions on the left the broadcast_dimensions should just
// be the n highest dimension numbers of the output shape where n is
// the number of input dimensions.
broadcast_dimensions.reserve(ShapeUtil::Rank(operand->shape()));
for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) {
broadcast_dimensions.push_back(i +
ShapeUtil::Rank(request.output_shape()) -

View File

@ -50,7 +50,7 @@ TEST_F(UserComputationTest, SimpleComputation) {
ConstantRequest constant_request;
*constant_request.mutable_literal() =
*LiteralUtil::CreateR1<float>({123.0f, 42.0f});
LiteralUtil::CreateR1<float>({123.0f, 42.0f})->ToProto();
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle constant_handle,
computation.AddConstantInstruction(constant_request));
@ -160,12 +160,13 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) {
UserComputation computation("TheComputation", handle);
ConstantRequest a_request;
*a_request.mutable_literal() = *LiteralUtil::CreateR1<float>({123.0f, 42.0f});
*a_request.mutable_literal() =
LiteralUtil::CreateR1<float>({123.0f, 42.0f})->ToProto();
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle,
computation.AddConstantInstruction(a_request));
ConstantRequest b_request;
*b_request.mutable_literal() = *LiteralUtil::CreateR0<float>(1.0f);
*b_request.mutable_literal() = LiteralUtil::CreateR0<float>(1.0f)->ToProto();
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle,
computation.AddConstantInstruction(b_request));

View File

@ -44,6 +44,7 @@ struct ShapeTreeNode {
// Children of this node.
std::vector<std::unique_ptr<ShapeTreeNode>> children;
ShapeTreeNode() = default;
explicit ShapeTreeNode(const T& data) : data(data) {}
ShapeTreeNode(const ShapeTreeNode& other)
@ -85,8 +86,9 @@ class ShapeTree {
public:
// Default constructor creates a tree with a nil shape (i.e. an empty tuple).
ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
// Create ShapeTree with the given shape, and default T values for all nodes.
explicit ShapeTree(const Shape& shape) : ShapeTree(shape, T()) {}
// Create ShapeTree with the given shape, and default-constructed T values for
// all nodes.
explicit ShapeTree(const Shape& shape);
// Create ShapeTree with the given shape, and init_value for all nodes.
ShapeTree(const Shape& shape, const T& init_value);
@ -127,6 +129,19 @@ class ShapeTree {
const ShapeIndex& /*index*/, bool /*is_leaf*/, T* /*data*/)>;
Status ForEachMutableElement(const MutableVisitorFunction& func);
// Copy the subtree of values from 'other' rooted at ShapeIndex
// 'source_base_index' into the subtree of value in this ShapeTree rooted at
// 'target_base_index'.
//
// Precondition: The subshape of other.shape() at index source_base_index must
// be compatible with the subshape of shape() at index target_base_index.
void CopySubtreeFrom(const ShapeTree<T>& other,
const ShapeIndex& source_base_index,
const ShapeIndex& target_base_index);
bool operator==(const ShapeTree<T>& other) const;
bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); }
private:
using Node = internal::ShapeTreeNode<T>;
@ -134,6 +149,10 @@ class ShapeTree {
// the given 'init_value'.
void InitChildren(const Shape& shape, const T& init_value, Node* node);
// Initialize node->children based on 'shape'. All children have
// default-constructed data values.
void InitChildren(const Shape& shape, Node* node);
// Helpers for traversing the shape via ForEachElement. The helpers
// recursively traverse the subtree rooted at "index" (defined as in
// ShapeUtil::GetSubshape).
@ -165,6 +184,24 @@ void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
}
}
template <typename T>
void ShapeTree<T>::InitChildren(const Shape& shape, Node* node) {
if (ShapeUtil::IsTuple(shape)) {
for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
node->children.emplace_back(new Node());
InitChildren(shape.tuple_shapes(i), node->children.back().get());
}
}
}
template <typename T>
ShapeTree<T>::ShapeTree(const Shape& shape) : root_(), shape_(shape) {
// The shape_ field is just used to hold the structure of the shape.
// It should not be relied upon to store layout information.
LayoutUtil::ClearLayout(&shape_);
InitChildren(shape_, &root_);
}
template <typename T>
ShapeTree<T>::ShapeTree(const Shape& shape, const T& init_value)
: root_(init_value), shape_(shape) {
@ -240,6 +277,48 @@ Status ShapeTree<T>::ForEachMutableElement(const MutableVisitorFunction& func) {
return ForEachMutableHelper(func, &root_, &index);
}
template <typename T>
void ShapeTree<T>::CopySubtreeFrom(const ShapeTree<T>& other,
const ShapeIndex& source_base_index,
const ShapeIndex& target_base_index) {
CHECK(ShapeUtil::Compatible(
ShapeUtil::GetSubshape(shape(), target_base_index),
ShapeUtil::GetSubshape(other.shape(), source_base_index)));
ForEachMutableElement(
[this, &other, &source_base_index, &target_base_index](
const ShapeIndex& index, bool /*is_leaf*/, T* data) {
// Copy the data element only if index is in the
// subtree rooted at target_base_index.
for (int i = 0; i < target_base_index.size(); ++i) {
if (i >= index.size() || index[i] != target_base_index[i]) {
return Status::OK();
}
}
// Construct source element index to copy from.
ShapeIndex source_index = source_base_index;
for (int i = target_base_index.size(); i < index.size(); ++i) {
source_index.push_back(index[i]);
}
*data = other.element(source_index);
return Status::OK();
})
.IgnoreError();
}
template <typename T>
bool ShapeTree<T>::operator==(const ShapeTree<T>& other) const {
bool equal = true;
ForEachElement([this, &other, &equal](const ShapeIndex& index,
bool /*is_leaf*/, const T& data) {
if (data != other.element(index)) {
equal = false;
}
return Status::OK();
})
.IgnoreError();
return equal;
}
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_

View File

@ -245,5 +245,139 @@ TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) {
EXPECT_DEATH(shape_tree.element({0, 0}), "");
}
TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) {
ShapeTree<std::unique_ptr<int>> shape_tree{tuple_shape_};
EXPECT_EQ(shape_tree.element({2}).get(), nullptr);
*shape_tree.mutable_element({2}) = MakeUnique<int>(42);
EXPECT_EQ(*shape_tree.element({2}), 42);
}
TEST_F(ShapeTreeTest, CopySubtreeFromArrayShape) {
// Test CopySubtreeFrom method for a single value copied between array-shaped
// ShapeTrees.
ShapeTree<int> source(array_shape_);
*source.mutable_element(/*index=*/{}) = 42;
ShapeTree<int> destination(array_shape_, 123);
EXPECT_EQ(destination.element(/*index=*/{}), 123);
destination.CopySubtreeFrom(source, /*source_base_index=*/{},
/*target_base_index=*/{});
EXPECT_EQ(destination.element(/*index=*/{}), 42);
}
TEST_F(ShapeTreeTest, FullCopySubtreeFromTupleShape) {
// Test CopySubtreeFrom method for a copy of all elements from one
// tuple-shaped ShapeTree to another.
ShapeTree<int> source(tuple_shape_);
*source.mutable_element(/*index=*/{}) = 10;
*source.mutable_element(/*index=*/{0}) = 11;
*source.mutable_element(/*index=*/{1}) = 12;
*source.mutable_element(/*index=*/{2}) = 13;
ShapeTree<int> destination(tuple_shape_, 0);
destination.CopySubtreeFrom(source, /*source_base_index=*/{},
/*target_base_index=*/{});
EXPECT_EQ(destination.element(/*index=*/{}), 10);
EXPECT_EQ(destination.element(/*index=*/{0}), 11);
EXPECT_EQ(destination.element(/*index=*/{1}), 12);
EXPECT_EQ(destination.element(/*index=*/{2}), 13);
}
TEST_F(ShapeTreeTest, SingleElementCopySubtreeFromTupleShape) {
// Test CopySubtreeFrom method for a copy of a single element from one
// tuple-shaped ShapeTree to another.
ShapeTree<int> source(tuple_shape_);
*source.mutable_element(/*index=*/{}) = 10;
*source.mutable_element(/*index=*/{0}) = 11;
*source.mutable_element(/*index=*/{1}) = 12;
*source.mutable_element(/*index=*/{2}) = 13;
ShapeTree<int> destination(tuple_shape_, 0);
destination.CopySubtreeFrom(source, /*source_base_index=*/{0},
/*target_base_index=*/{1});
EXPECT_EQ(destination.element(/*index=*/{}), 0);
EXPECT_EQ(destination.element(/*index=*/{0}), 0);
EXPECT_EQ(destination.element(/*index=*/{1}), 11);
EXPECT_EQ(destination.element(/*index=*/{2}), 0);
}
TEST_F(ShapeTreeTest, CopySubtreeIntoNestedShape) {
// Test CopySubtreeFrom method for a copy of a tuple-shaped ShapeTree into a
// nested-tuple-shaped ShapeTree.
ShapeTree<int> source(
ShapeUtil::MakeTupleShape({array_shape_, array_shape_}));
*source.mutable_element(/*index=*/{}) = 10;
*source.mutable_element(/*index=*/{0}) = 11;
*source.mutable_element(/*index=*/{1}) = 12;
ShapeTree<int> destination(nested_tuple_shape_, 0);
destination.CopySubtreeFrom(source, /*source_base_index=*/{},
/*target_base_index=*/{2, 0});
EXPECT_EQ(destination.element(/*index=*/{}), 0);
EXPECT_EQ(destination.element(/*index=*/{0}), 0);
EXPECT_EQ(destination.element(/*index=*/{1}), 0);
EXPECT_EQ(destination.element(/*index=*/{1, 0}), 0);
EXPECT_EQ(destination.element(/*index=*/{1, 1}), 0);
EXPECT_EQ(destination.element(/*index=*/{2}), 0);
EXPECT_EQ(destination.element(/*index=*/{2, 0}), 10);
EXPECT_EQ(destination.element(/*index=*/{2, 0, 0}), 11);
EXPECT_EQ(destination.element(/*index=*/{2, 0, 1}), 12);
EXPECT_EQ(destination.element(/*index=*/{2, 1}), 0);
}
TEST_F(ShapeTreeTest, CopySubtreeFromNestedShape) {
// Test CopySubtreeFrom method for a copy from a nested-tuple-shape.
ShapeTree<int> source(nested_tuple_shape_, 42);
*source.mutable_element(/*index=*/{1}) = 10;
*source.mutable_element(/*index=*/{1, 0}) = 11;
*source.mutable_element(/*index=*/{1, 1}) = 12;
ShapeTree<int> destination(
ShapeUtil::MakeTupleShape({array_shape_, array_shape_}), 0);
destination.CopySubtreeFrom(source, /*source_base_index=*/{1},
/*target_base_index=*/{});
EXPECT_EQ(destination.element(/*index=*/{}), 10);
EXPECT_EQ(destination.element(/*index=*/{0}), 11);
EXPECT_EQ(destination.element(/*index=*/{1}), 12);
}
TEST_F(ShapeTreeTest, OperatorEquals) {
{
ShapeTree<int> a(array_shape_, 123);
ShapeTree<int> b(array_shape_, 42);
ShapeTree<int> c(array_shape_, 42);
EXPECT_FALSE(a == b);
EXPECT_TRUE(a != b);
EXPECT_TRUE(b == c);
}
{
ShapeTree<int> a(tuple_shape_);
*a.mutable_element(/*index=*/{}) = 10;
*a.mutable_element(/*index=*/{0}) = 11;
*a.mutable_element(/*index=*/{1}) = 12;
ShapeTree<int> b(tuple_shape_);
*b.mutable_element(/*index=*/{}) = 10;
*b.mutable_element(/*index=*/{0}) = 42;
*b.mutable_element(/*index=*/{1}) = 11;
ShapeTree<int> c(tuple_shape_);
*c.mutable_element(/*index=*/{}) = 10;
*c.mutable_element(/*index=*/{0}) = 42;
*c.mutable_element(/*index=*/{1}) = 11;
EXPECT_FALSE(a == b);
EXPECT_TRUE(a != b);
EXPECT_TRUE(b == c);
EXPECT_FALSE(b != c);
}
}
} // namespace
} // namespace xla

View File

@ -122,7 +122,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
for (const auto& shape : parameters) {
*program_shape.add_parameters() = shape;
}
*program_shape.mutable_result() = result;
*program_shape.mutable_result() = std::move(result);
return program_shape;
}

View File

@ -829,6 +829,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
const int count = GetParam();
ComputationBuilder builder(client_, TestName());
std::vector<float> values;
values.reserve(count);
for (int i = 0; i < count; ++i) {
values.push_back(i / static_cast<float>(count));
}
@ -836,6 +837,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f));
std::vector<float> expected;
expected.reserve(values.size());
for (float value : values) {
expected.push_back(value * value);
}

View File

@ -179,7 +179,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
VLOG(1) << "expected: " << LiteralUtil::ToString(*expected_literal);
VLOG(1) << "actual: " << LiteralUtil::ToString(*actual);
EXPECT_EQ(expected, actual->u8s());
EXPECT_EQ(expected, actual->u8s_string());
}
void ClientLibraryTestBase::ComputeAndCompareTuple(

View File

@ -442,6 +442,39 @@ XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) {
ComputeAndCompareR1<int32>(&builder, expected, {});
}
XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) {
ComputationBuilder builder(client_, TestName());
Array3D<float> arr0(9, 17, 1);
arr0.Fill(1);
Array3D<float> arr1(9, 17, 256);
arr1.Fill(2);
Array3D<float> expected(9, 17, arr0.n3() + arr1.n3());
for (int64 i = 0; i < expected.n1(); ++i) {
for (int64 j = 0; j < expected.n2(); ++j) {
int64 kk = 0;
for (const Array3D<float>& arr : {arr0, arr1}) {
for (int64 k = 0; k < arr.n3(); ++k, ++kk) {
expected(i, j, kk) = arr(i, j, k);
}
}
}
}
ComputationDataHandle h0;
auto p0 = CreateR3Parameter<float>(arr0, /*parameter_number=*/0, "p0",
&builder, &h0);
ComputationDataHandle h1;
auto p1 = CreateR3Parameter<float>(arr1, /*parameter_number=*/1, "p1",
&builder, &h1);
auto concatenated = builder.ConcatInDim({h0, h1}, 2);
ComputeAndCompareR3<float>(&builder, expected, {p0.get(), p1.get()});
}
// Describes a binary rank-2 concatenation test.
struct R2BinarySpec {
int64 lhs_dim0;

View File

@ -262,7 +262,7 @@ class NearComparator {
max_abs_err_ = 0.0;
*miscompares_.mutable_shape() =
ShapeUtil::ChangeElementType(actual.shape(), PRED);
miscompares_.mutable_preds()->Resize(
miscompares_.mutable_preds()->resize(
ShapeUtil::ElementsIn(miscompares_.shape()), false);
multi_index_.resize(expected.shape().dimensions_size(), 0);
@ -389,7 +389,7 @@ class NearComparator {
tensorflow::strings::Printf("tempfile-%s-%llx-%s", Hostname().c_str(),
now_usec, name.c_str()));
TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(),
filename, literal));
filename, literal.ToProto()));
LOG(ERROR) << "wrote to " << name << " file: " << filename;
}

View File

@ -83,9 +83,10 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
LOG(INFO) << "results: [" << tensorflow::str_util::Join(results, ", ") << "]";
EXPECT_EQ(3, results.size());
for (const string& result : results) {
Literal literal;
LiteralProto literal_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result,
&literal));
&literal_proto));
Literal literal(literal_proto);
if (result.find("expected") != string::npos) {
EXPECT_EQ("2", LiteralUtil::ToString(literal));
} else if (result.find("actual") != string::npos) {

View File

@ -47,6 +47,7 @@ TEST_F(LogTest, LogTenValues) {
builder.Log(x);
std::vector<float> expected;
expected.reserve(input.size());
for (float f : input) {
expected.push_back(std::log(f));
}

View File

@ -246,6 +246,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
}
std::vector<GlobalData*> param_data;
param_data.reserve(param_data_owner.size());
for (const std::unique_ptr<GlobalData>& data : param_data_owner) {
param_data.push_back(data.get());
}

View File

@ -37,6 +37,7 @@ class SliceTest : public ClientLibraryTestBase {
template <typename NativeT>
void RunSliceTenToTwo() {
std::vector<NativeT> constant;
constant.reserve(10);
for (int i = 0; i < 10; ++i) {
constant.push_back(static_cast<NativeT>(i));
}

View File

@ -64,6 +64,7 @@ TEST_F(VecOpsSimpleTest, ExpManyValues) {
for (int count : {63, 64, 65, 127, 128, 129, 17 * 4096}) {
ComputationBuilder builder(client_, TestName());
std::vector<float> exponents;
exponents.reserve(count);
for (int i = 0; i < count; ++i) {
exponents.push_back(i / static_cast<float>(count));
}
@ -71,6 +72,7 @@ TEST_F(VecOpsSimpleTest, ExpManyValues) {
auto exp = builder.Exp(x);
std::vector<float> expected;
expected.reserve(exponents.size());
for (float exponent : exponents) {
expected.push_back(std::exp(exponent));
}

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
#define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"

View File

@ -81,6 +81,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
client->GetComputationShape(computation).ConsumeValueOrDie();
std::vector<const Shape*> layouts;
layouts.reserve(program_shape->parameters_size());
for (int i = 0; i < program_shape->parameters_size(); ++i) {
layouts.push_back(&program_shape->parameters(i));
}

View File

@ -56,6 +56,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool compile) {
client->GetComputationShape(computation).ConsumeValueOrDie();
std::vector<const Shape*> layouts;
layouts.reserve(program_shape->parameters_size());
for (int i = 0; i < program_shape->parameters_size(); ++i) {
layouts.push_back(&program_shape->parameters(i));
}

View File

@ -66,7 +66,8 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(
if (use_fake_data) {
arguments = MakeFakeArgumentsOrDie(computation, client);
} else { // use recorded data if available
for (const Literal& literal : module.arguments()) {
for (const auto& proto : module.arguments()) {
Literal literal(proto);
TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
client->TransferToServer(literal));
arguments.push_back(std::move(data));
@ -74,6 +75,7 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(
}
std::vector<GlobalData*> execute_arguments;
execute_arguments.reserve(arguments.size());
for (auto& argument : arguments) {
execute_arguments.push_back(argument.get());
}
@ -100,7 +102,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool use_fake_data) {
if (module.has_result()) {
fprintf(stdout, "was %s:%s\n",
ShapeUtil::HumanString(module.result().shape()).c_str(),
LiteralUtil::ToString(module.result()).c_str());
LiteralUtil::ToString(Literal(module.result())).c_str());
}
}
}

View File

@ -37,9 +37,10 @@ int main(int argc, char **argv) {
<< " <path-to-serialized-literal-proto>";
}
xla::Literal literal;
xla::LiteralProto literal_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1],
&literal));
LOG(INFO) << "literal: " << literal.ShortDebugString();
&literal_proto));
xla::Literal literal(literal_proto);
LOG(INFO) << "literal: " << literal_proto.ShortDebugString();
fprintf(stderr, "%s\n", xla::LiteralUtil::ToString(literal).c_str());
}

View File

@ -92,11 +92,11 @@ message TransferToClientRequest {
}
message TransferToClientResponse {
Literal literal = 1;
LiteralProto literal = 1;
}
message TransferToServerRequest {
Literal literal = 1;
LiteralProto literal = 1;
DeviceHandle device_handle = 2;
}
@ -105,7 +105,7 @@ message TransferToServerResponse {
}
message TransferToInfeedRequest {
Literal literal = 1;
LiteralProto literal = 1;
int64 replica_id = 2;
DeviceHandle device_handle = 3;
}
@ -123,7 +123,7 @@ message TransferFromOutfeedRequest {
}
message TransferFromOutfeedResponse {
Literal literal = 1;
LiteralProto literal = 1;
}
message ResetDeviceRequest {

View File

@ -275,7 +275,7 @@ message ChannelHandle {
//
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape.
message Literal {
message LiteralProto {
Shape shape = 1;
repeated bool preds = 2;
bytes u8s = 3;
@ -285,7 +285,7 @@ message Literal {
repeated uint64 u64s = 7;
repeated float f32s = 8;
repeated double f64s = 9;
repeated Literal tuple_literals = 10;
repeated LiteralProto tuple_literals = 10;
bytes f16s = 11; // Note: the F16s are encoded in little endian byte order
}
@ -337,7 +337,7 @@ message Window {
// field in OpRequest.
message ConstantRequest {
Literal literal = 2;
LiteralProto literal = 2;
}
message GetTupleElementRequest {

View File

@ -85,6 +85,7 @@ cc_library(
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels",
"//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
"//tensorflow/contrib/nccl:nccl_kernels",
"//tensorflow/contrib/seq2seq:beam_search_ops_kernels",
"//tensorflow/contrib/tensor_forest:tensor_forest_kernels",
"//tensorflow/contrib/text:all_kernels",
],
@ -100,6 +101,7 @@ cc_library(
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib",
"//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
"//tensorflow/contrib/nccl:nccl_ops_op_lib",
"//tensorflow/contrib/seq2seq:beam_search_ops_op_lib",
"//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib",
"//tensorflow/contrib/text:all_ops",
],

View File

@ -347,6 +347,7 @@ class BatchResource : public ResourceBase {
// Concatenate the tasks ith input tensors into a big output tensor.
std::vector<Tensor> to_concatenate;
to_concatenate.reserve(batch->num_tasks());
for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
to_concatenate.push_back(batch->task(task_idx).inputs.at(i));
}

View File

@ -139,6 +139,7 @@ TEST(SharedBatchSchedulerTest, ObeyBatchSizeConstraint) {
&callback_data](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
std::vector<size_t> batch_data;
batch_data.reserve(batch->num_tasks());
for (int i = 0; i < batch->num_tasks(); ++i) {
batch_data.push_back(batch->mutable_task(i)->size());
}

View File

@ -295,6 +295,7 @@ void ExpectVecsEquiv(const std::vector<float>& vec1,
std::vector<float> GetWeightsByIndex(const std::vector<float>& weights,
const std::vector<int>& indices) {
std::vector<float> res;
res.reserve(indices.size());
for (const int index : indices) {
res.push_back(weights[index]);
}

View File

@ -236,6 +236,9 @@ add_python_module("tensorflow/tensorboard")
add_python_module("tensorflow/tensorboard/backend")
add_python_module("tensorflow/tensorboard/backend/event_processing")
add_python_module("tensorflow/tensorboard/plugins")
add_python_module("tensorflow/tensorboard/plugins/audio")
add_python_module("tensorflow/tensorboard/plugins/distributions")
add_python_module("tensorflow/tensorboard/plugins/graphs")
add_python_module("tensorflow/tensorboard/plugins/histograms")
add_python_module("tensorflow/tensorboard/plugins/images")
add_python_module("tensorflow/tensorboard/plugins/projector")
@ -536,6 +539,7 @@ set(tf_python_op_gen_main_srcs
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc"
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc"
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h"
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.h"
)
add_library(tf_python_op_gen_main OBJECT ${tf_python_op_gen_main_srcs})

View File

@ -209,10 +209,11 @@ if (tensorflow_BUILD_PYTHON_TESTS)
# Broken TensorBoard tests due to different paths in windows
"${tensorflow_source_dir}/tensorflow/tensorboard/backend/application_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/lib/python/http_util_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/audio/audio_plugin_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/images/images_plugin_test.py"
# Broken tensorboard test due to cmake issues.
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/debugger/plugin_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/images/images_plugin_test.py"
# tensor_forest tests (also note that we exclude the hybrid tests for now)
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order.
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order.

View File

@ -150,7 +150,8 @@ class MapDatasetTest(test.TestCase):
results.append(sess.run(get_next))
except errors.OutOfRangeError:
return
threads = [self.checkedThread(target=iterator_thread) for _ in range(8)]
threads = [self.checkedThread(target=iterator_thread)
for _ in range(64)]
for t in threads:
t.start()
for t in threads:

View File

@ -375,8 +375,8 @@ class NearestNeighborsOp : public OpKernel {
const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
const Eigen::Ref<const MatrixXfRowMajor>& centers,
const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
const Eigen::Ref<MatrixXi64RowMajor>& nearest_center_indices,
const Eigen::Ref<MatrixXfRowMajor>& nearest_center_distances) {
CHECK_LE(k, centers.rows());
if (centers.rows() <= kNearestNeighborsCentersMaxBlockSize) {
FindKNearestCentersOneBlock(k, points, points_half_squared_norm, centers,

View File

@ -164,11 +164,12 @@ class KMeans(object):
with ops.colocate_with(inp):
# Computes Euclidean distance. Note the first and third terms are
# broadcast additions.
squared_distance = (math_ops.reduce_sum(
math_ops.square(inp), 1, keep_dims=True) - 2 * math_ops.matmul(
inp, clusters, transpose_b=True) + array_ops.transpose(
math_ops.reduce_sum(
math_ops.square(clusters), 1, keep_dims=True)))
squared_distance = (
math_ops.reduce_sum(math_ops.square(inp), 1, keep_dims=True) -
2 * math_ops.matmul(inp, clusters, transpose_b=True) +
array_ops.transpose(
math_ops.reduce_sum(
math_ops.square(clusters), 1, keep_dims=True)))
output.append(squared_distance)
return output
@ -229,12 +230,12 @@ class KMeans(object):
clusters = nn_impl.l2_normalize(clusters, dim=1)
for inp, score in zip(inputs, scores):
with ops.colocate_with(inp):
(indices,
distances) = gen_clustering_ops.nearest_neighbors(inp, clusters, 1)
(indices, distances) = gen_clustering_ops.nearest_neighbors(
inp, clusters, 1)
if self._distance_metric == COSINE_DISTANCE:
distances *= 0.5
output.append(
(score, array_ops.squeeze(distances), array_ops.squeeze(indices)))
output.append((score, array_ops.squeeze(distances),
array_ops.squeeze(indices)))
return zip(*output)
def _init_clusters_random(self):
@ -265,9 +266,7 @@ class KMeans(object):
(not self._use_mini_batch or
self._mini_batch_steps_per_iteration > 1))
def _initialize_clusters(self,
cluster_centers,
cluster_centers_initialized,
def _initialize_clusters(self, cluster_centers, cluster_centers_initialized,
cluster_centers_updated):
"""Returns an op to initialize the cluster centers."""
@ -294,22 +293,20 @@ class KMeans(object):
with ops.colocate_with(cluster_centers_initialized):
initialized = control_flow_ops.with_dependencies(
[clusters_init],
array_ops.identity(cluster_centers_initialized))
[clusters_init], array_ops.identity(cluster_centers_initialized))
with ops.colocate_with(cluster_centers):
assign_centers = state_ops.assign(cluster_centers, clusters_init,
validate_shape=False)
assign_centers = state_ops.assign(
cluster_centers, clusters_init, validate_shape=False)
if cluster_centers_updated != cluster_centers:
assign_centers = control_flow_ops.group(
assign_centers,
state_ops.assign(cluster_centers_updated, clusters_init,
validate_shape=False))
assign_centers = control_flow_ops.with_dependencies(
[assign_centers],
state_ops.assign(cluster_centers_initialized, True))
return control_flow_ops.cond(initialized,
control_flow_ops.no_op,
lambda: assign_centers).op
assign_centers = control_flow_ops.group(assign_centers,
state_ops.assign(
cluster_centers_updated,
clusters_init,
validate_shape=False))
assign_centers = control_flow_ops.with_dependencies(
[assign_centers], state_ops.assign(cluster_centers_initialized, True))
return control_flow_ops.cond(initialized, control_flow_ops.no_op,
lambda: assign_centers).op
def _create_variables(self):
"""Creates variables.
@ -327,19 +324,16 @@ class KMeans(object):
cluster_centers_updated back to cluster_centers.
"""
init_value = array_ops.constant([], dtype=dtypes.float32)
cluster_centers = variable_scope.variable(init_value,
name='clusters',
validate_shape=False)
cluster_centers_initialized = variable_scope.variable(False,
dtype=dtypes.bool,
name='initialized')
cluster_centers = variable_scope.variable(
init_value, name='clusters', validate_shape=False)
cluster_centers_initialized = variable_scope.variable(
False, dtype=dtypes.bool, name='initialized')
if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
# Copy of cluster centers actively updated each step according to
# mini-batch update rule.
cluster_centers_updated = variable_scope.variable(init_value,
name='clusters_updated',
validate_shape=False)
cluster_centers_updated = variable_scope.variable(
init_value, name='clusters_updated', validate_shape=False)
# How many steps till we copy the updated clusters to cluster_centers.
update_in_steps = variable_scope.variable(
self._mini_batch_steps_per_iteration,
@ -347,20 +341,15 @@ class KMeans(object):
name='update_in_steps')
# Count of points assigned to cluster_centers_updated.
cluster_counts = variable_scope.variable(
array_ops.zeros([self._num_clusters],
dtype=dtypes.int64))
array_ops.zeros([self._num_clusters], dtype=dtypes.int64))
else:
cluster_centers_updated = cluster_centers
update_in_steps = None
cluster_counts = (variable_scope.variable(array_ops.ones(
[self._num_clusters],
dtype=dtypes.int64))
cluster_counts = (variable_scope.variable(
array_ops.ones([self._num_clusters], dtype=dtypes.int64))
if self._use_mini_batch else None)
return (cluster_centers,
cluster_centers_initialized,
cluster_counts,
cluster_centers_updated,
update_in_steps)
return (cluster_centers, cluster_centers_initialized, cluster_counts,
cluster_centers_updated, update_in_steps)
@classmethod
def _l2_normalize_data(cls, inputs):
@ -391,11 +380,8 @@ class KMeans(object):
"""
# Implementation of kmeans.
inputs = self._inputs
(cluster_centers_var,
cluster_centers_initialized,
total_counts,
cluster_centers_updated,
update_in_steps) = self._create_variables()
(cluster_centers_var, cluster_centers_initialized, total_counts,
cluster_centers_updated, update_in_steps) = self._create_variables()
init_op = self._initialize_clusters(cluster_centers_var,
cluster_centers_initialized,
cluster_centers_updated)
@ -409,8 +395,7 @@ class KMeans(object):
all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers)
if self._use_mini_batch:
sync_updates_op = self._mini_batch_sync_updates_op(
update_in_steps,
cluster_centers_var, cluster_centers_updated,
update_in_steps, cluster_centers_var, cluster_centers_updated,
total_counts)
assert sync_updates_op is not None
with ops.control_dependencies([sync_updates_op]):
@ -421,15 +406,15 @@ class KMeans(object):
training_op = self._full_batch_training_op(inputs, cluster_idx,
cluster_centers_var)
return (all_scores, cluster_idx, scores,
cluster_centers_initialized, init_op, training_op)
return (all_scores, cluster_idx, scores, cluster_centers_initialized,
init_op, training_op)
def _mini_batch_sync_updates_op(self, update_in_steps,
cluster_centers_var, cluster_centers_updated,
total_counts):
def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var,
cluster_centers_updated, total_counts):
if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
assert update_in_steps is not None
with ops.colocate_with(update_in_steps):
def _f():
# Note that there is a race condition here, so we do a best effort
# updates here. We reset update_in_steps first so that other workers
@ -437,33 +422,36 @@ class KMeans(object):
# before resetting total_counts to avoid large updates to
# cluster_centers_updated based on partially updated
# cluster_center_vars.
with ops.control_dependencies([state_ops.assign(
update_in_steps,
self._mini_batch_steps_per_iteration - 1)]):
with ops.colocate_with(cluster_centers_updated):
with ops.control_dependencies([
state_ops.assign(update_in_steps,
self._mini_batch_steps_per_iteration - 1)
]):
with ops.colocate_with(
cluster_centers_updated, ignore_existing=True):
if self._distance_metric == COSINE_DISTANCE:
cluster_centers = nn_impl.l2_normalize(cluster_centers_updated,
dim=1)
cluster_centers = nn_impl.l2_normalize(
cluster_centers_updated, dim=1)
else:
cluster_centers = cluster_centers_updated
with ops.colocate_with(cluster_centers_var):
with ops.control_dependencies([state_ops.assign(
cluster_centers_var,
cluster_centers)]):
with ops.colocate_with(cluster_centers_var):
with ops.control_dependencies(
[state_ops.assign(cluster_centers_var, cluster_centers)]):
with ops.colocate_with(
cluster_centers_var, ignore_existing=True):
with ops.control_dependencies([
state_ops.assign(total_counts,
array_ops.zeros_like(total_counts))]):
array_ops.zeros_like(total_counts))
]):
return array_ops.identity(update_in_steps)
return control_flow_ops.cond(
update_in_steps <= 0,
_f,
update_in_steps <= 0, _f,
lambda: state_ops.assign_sub(update_in_steps, 1))
else:
return control_flow_ops.no_op()
def _mini_batch_training_op(self, inputs, cluster_idx_list,
cluster_centers, total_counts):
def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers,
total_counts):
"""Creates an op for training for mini batch case.
Args:
@ -487,17 +475,15 @@ class KMeans(object):
unique_ids, unique_idx = array_ops.unique(cluster_idx)
num_unique_cluster_idx = array_ops.size(unique_ids)
# Fetch the old values of counts and cluster_centers.
with ops.colocate_with(total_counts):
with ops.colocate_with(total_counts, ignore_existing=True):
old_counts = array_ops.gather(total_counts, unique_ids)
# TODO(agarwal): This colocation seems to run into problems. Fix it.
# with ops.colocate_with(cluster_centers):
old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
with ops.colocate_with(cluster_centers, ignore_existing=True):
old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
# Locally aggregate the increment to counts.
count_updates = math_ops.unsorted_segment_sum(
array_ops.ones_like(
unique_idx, dtype=total_counts.dtype),
unique_idx,
num_unique_cluster_idx)
array_ops.ones_like(unique_idx, dtype=total_counts.dtype),
unique_idx, num_unique_cluster_idx)
# Locally compute the sum of inputs mapped to each id.
# For a cluster with old cluster value x, old count n, and with data
# d_1,...d_k newly assigned to it, we recompute the new value as
@ -507,13 +493,12 @@ class KMeans(object):
inp, unique_idx, num_unique_cluster_idx)
# Shape to enable broadcasting count_updates and learning_rate to inp.
# It extends the shape with 1's to match the rank of inp.
broadcast_shape = array_ops.concat(
[
array_ops.reshape(num_unique_cluster_idx, [1]), array_ops.ones(
array_ops.reshape(array_ops.rank(inp) - 1, [1]),
dtype=dtypes.int32)
],
0)
broadcast_shape = array_ops.concat([
array_ops.reshape(num_unique_cluster_idx, [1]),
array_ops.ones(
array_ops.reshape(array_ops.rank(inp) - 1, [1]),
dtype=dtypes.int32)
], 0)
# Subtract k * x, see comment above.
cluster_center_updates -= math_ops.cast(
array_ops.reshape(count_updates, broadcast_shape),
@ -524,14 +509,10 @@ class KMeans(object):
# scale by 1 / (n + k), see comment above.
cluster_center_updates *= learning_rate
# Apply the updates.
update_counts = state_ops.scatter_add(
total_counts,
unique_ids,
count_updates)
update_counts = state_ops.scatter_add(total_counts, unique_ids,
count_updates)
update_cluster_centers = state_ops.scatter_add(
cluster_centers,
unique_ids,
cluster_center_updates)
cluster_centers, unique_ids, cluster_center_updates)
update_ops.extend([update_counts, update_cluster_centers])
return control_flow_ops.group(*update_ops)
@ -552,7 +533,7 @@ class KMeans(object):
cluster_counts = []
epsilon = constant_op.constant(1e-6, dtype=inputs[0].dtype)
for inp, cluster_idx in zip(inputs, cluster_idx_list):
with ops.colocate_with(inp):
with ops.colocate_with(inp, ignore_existing=True):
cluster_sums.append(
math_ops.unsorted_segment_sum(inp, cluster_idx, self._num_clusters))
cluster_counts.append(
@ -561,7 +542,7 @@ class KMeans(object):
array_ops.ones(
array_ops.reshape(array_ops.shape(inp)[0], [-1])),
[-1, 1]), cluster_idx, self._num_clusters))
with ops.colocate_with(cluster_centers):
with ops.colocate_with(cluster_centers, ignore_existing=True):
new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast(
math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon)
if self._clusters_l2_normalized():

View File

@ -94,6 +94,7 @@ TEST(FfmpegLibTest, TestRoundTripGeneratedWav) {
}
std::vector<float> sine_wave;
sine_wave.reserve(20000);
for (int i = 0; i < 20000; ++i) {
sine_wave.push_back(std::sin(6.28 * 440.0 * i / 20000.0));
}

View File

@ -494,6 +494,7 @@ class SparseFeatureCrossOp : public OpKernel {
ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
&feature_start_indices);
columns.reserve(values_list_in.size());
for (int i = 0; i < values_list_in.size(); ++i) {
columns.emplace_back(new SparseTensorColumn<InternalType>(
values_list_in[i], std::move(feature_counts[i]),

View File

@ -308,6 +308,7 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_rea
from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat
from tensorflow.contrib.learn.python.learn.estimators.head import binary_svm_head
from tensorflow.contrib.learn.python.learn.estimators.head import Head
from tensorflow.contrib.learn.python.learn.estimators.head import loss_only_head
from tensorflow.contrib.learn.python.learn.estimators.head import multi_class_head
from tensorflow.contrib.learn.python.learn.estimators.head import multi_head
from tensorflow.contrib.learn.python.learn.estimators.head import multi_label_head

View File

@ -429,6 +429,23 @@ def multi_label_head(n_classes,
loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None)
def loss_only_head(loss_fn, head_name=None):
"""Creates a Head that contains only loss terms.
Loss only head holds additional loss terms to be added to other heads and
usually represents additional regularization terms in the objective function.
Args:
loss_fn: a function that takes no argument and returns a list of
scalar tensors.
head_name: a name for for the head.
Returns:
An instance of `Head` to hold the additional losses.
"""
return _LossOnlyHead(loss_fn, head_name=head_name)
def multi_head(heads, loss_weights=None):
"""Creates a MultiHead stemming from same logits/hidden layer.
@ -1406,6 +1423,80 @@ class _MultiLabelHead(_SingleHead):
return metrics
class _LossOnlyHead(Head):
"""`Head` implementation for additional loss terms.
This class only holds loss terms unrelated to any other heads (labels),
e.g. regularization.
Common usage:
This is oftem combine with other heads in a multi head setup.
```python
head = multi_head([
head1, head2, loss_only_head('regularizer', regularizer)])
```
"""
def __init__(self, loss_fn, head_name=None):
self._loss_fn = loss_fn
self.head_name = head_name or "loss_only_head"
@property
def logits_dimension(self):
return 0
def create_model_fn_ops(self,
features,
mode,
labels=None,
train_op_fn=None,
logits=None,
logits_input=None,
scope=None):
"""See `_Head.create_model_fn_ops`.
Args:
features: Not been used.
mode: Estimator's `ModeKeys`.
labels: Labels `Tensor`, or `dict` of same.
train_op_fn: Function that takes a scalar loss and returns an op to
optimize with the loss.
logits: Not been used.
logits_input: Not been used.
scope: Optional scope for variable_scope. If provided, will be passed to
all heads. Most users will want to set this to `None`, so each head
constructs a separate variable_scope according to its `head_name`.
Returns:
A `ModelFnOps` object.
Raises:
ValueError: if `mode` is not recognition.
"""
_check_mode_valid(mode)
loss = None
train_op = None
if mode != model_fn.ModeKeys.INFER:
with variable_scope.variable_scope(scope, default_name=self.head_name):
loss = self._loss_fn()
if isinstance(loss, list):
loss = math_ops.add_n(loss)
logging_ops.scalar_summary(
_summary_key(self.head_name, mkey.LOSS), loss)
if mode == model_fn.ModeKeys.TRAIN:
if train_op_fn is None:
raise ValueError("train_op_fn can not be None in TRAIN mode")
with ops.name_scope(None, "train_op", (loss,)):
train_op = train_op_fn(loss)
return model_fn.ModelFnOps(
mode=mode,
loss=loss,
train_op=train_op,
predictions={},
eval_metric_ops={})
class _MultiHead(Head):
"""`Head` implementation for multi objective learning.
@ -1525,7 +1616,10 @@ class _MultiHead(Head):
if isinstance(logits, dict):
head_logits_pairs = []
for head in self._heads:
head_logits_pairs.append((head, logits[head.head_name]))
if isinstance(head, _LossOnlyHead):
head_logits_pairs.append((head, None))
else:
head_logits_pairs.append((head, logits[head.head_name]))
else:
# Split logits for each head.
head_logits_pairs = zip(self._heads, self._split_logits(logits))
@ -1606,6 +1700,8 @@ class _MultiHead(Head):
predictions = {}
output_alternatives = {}
for head, m in zip(self._heads, all_model_fn_ops):
if isinstance(head, _LossOnlyHead):
continue
head_name = head.head_name
output_alternatives[head_name] = m.output_alternatives[head_name]
for k, v in m.predictions.items():

View File

@ -1638,6 +1638,21 @@ class BinarySvmHeadTest(test.TestCase):
}, model_fn_ops)
class LossOnlyHead(test.TestCase):
def testNoPredictionsAndNoMetrics(self):
head = head_lib.loss_only_head(lambda: 1, head_name="const")
model_fn_ops = head.create_model_fn_ops(
features={},
mode=model_fn.ModeKeys.TRAIN,
train_op_fn=head_lib.no_op_train_fn)
self.assertDictEqual(model_fn_ops.predictions, {})
self.assertDictEqual(model_fn_ops.eval_metric_ops, {})
self.assertIsNotNone(model_fn_ops.loss)
with session.Session() as sess:
self.assertEqual(1, sess.run(model_fn_ops.loss))
class MultiHeadTest(test.TestCase):
def testInvalidHeads(self):
@ -1672,7 +1687,8 @@ class MultiHeadTest(test.TestCase):
n_classes=3, label_name="label1", head_name="head1")
head2 = head_lib.multi_class_head(
n_classes=4, label_name="label2", head_name="head2")
head = head_lib.multi_head((head1, head2))
head3 = head_lib.loss_only_head(lambda: 1.0, head_name="const")
head = head_lib.multi_head((head1, head2, head3))
labels = {
"label1": (1,),
"label2": (1,)
@ -1691,7 +1707,7 @@ class MultiHeadTest(test.TestCase):
self.assertIsNone(model_fn_ops.output_alternatives)
with session.Session() as sess:
self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3)
self.assertAlmostEqual(3.224, sess.run(model_fn_ops.loss), places=3)
def testTrain_withHeadWeights(self):
head1 = head_lib.multi_class_head(

View File

@ -871,7 +871,7 @@ def index_table_from_file(vocabulary_file=None,
```
Args:
vocabulary_file: The vocabulary filename.
vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
num_oov_buckets: The number of out-of-vocabulary buckets.
vocab_size: Number of the elements in the vocabulary, if known.
default_value: The value to use for out-of-vocabulary feature values.
@ -889,8 +889,9 @@ def index_table_from_file(vocabulary_file=None,
ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
than zero.
"""
if not vocabulary_file:
raise ValueError("vocabulary_file must be specified.")
if vocabulary_file is None or (
isinstance(vocabulary_file, str) and not vocabulary_file):
raise ValueError("vocabulary_file must be specified and must not be empty.")
if num_oov_buckets < 0:
raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
% num_oov_buckets)

View File

@ -1187,6 +1187,18 @@ class IndexTableFromFile(test.TestCase):
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_string_index_table_from_file_tensor_filename(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
with self.test_session():
vocabulary_file = constant_op.constant(vocabulary_file)
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval)
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int32_index_table_from_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab2.txt", values=("42", "1", "-1000"))
@ -1245,7 +1257,13 @@ class IndexTableFromFile(test.TestCase):
860), # 3 + fingerprint("toccata") mod 300.
ids.eval())
def test_index_table_from_file_with_only_oov_buckets(self):
def test_index_table_from_file_fails_with_empty_vocabulary_file_name(self):
self.assertRaises(
ValueError,
lookup.index_table_from_file,
vocabulary_file="")
def test_index_table_from_file_fails_with_empty_vocabulary(self):
self.assertRaises(
ValueError,
lookup.index_table_from_file,

View File

@ -23,6 +23,7 @@ See the @{$python/contrib.metrics} guide.
@@streaming_precision
@@streaming_precision_at_thresholds
@@streaming_auc
@@streaming_curve_points
@@streaming_recall_at_k
@@streaming_mean_absolute_error
@@streaming_mean_iou
@ -76,6 +77,7 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_accuracy
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_curve_points
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives_at_thresholds
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positives

View File

@ -733,6 +733,102 @@ def streaming_true_negatives_at_thresholds(
return values['tn'], update_ops['tn']
def streaming_curve_points(labels=None,
predictions=None,
weights=None,
num_thresholds=200,
metrics_collections=None,
updates_collections=None,
curve='ROC',
name=None):
"""Computes curve (ROC or PR) values for a prespecified number of points.
The `streaming_curve_points` function creates four local variables,
`true_positives`, `true_negatives`, `false_positives` and `false_negatives`
that are used to compute the curve values. To discretize the curve, a linearly
spaced set of thresholds is used to compute pairs of recall and precision
values.
For best results, `predictions` should be distributed approximately uniformly
in the range [0, 1] and not peaked around 0 or 1.
For estimation of the metric over a stream of data, the function creates an
`update_op` operation that updates these variables.
If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Args:
labels: A `Tensor` whose shape matches `predictions`. Will be cast to
`bool`.
predictions: A floating point `Tensor` of arbitrary shape and whose values
are in the range `[0, 1]`.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
be either `1`, or the same as the corresponding `labels` dimension).
num_thresholds: The number of thresholds to use when discretizing the roc
curve.
metrics_collections: An optional list of collections that `auc` should be
added to.
updates_collections: An optional list of collections that `update_op` should
be added to.
curve: Specifies the name of the curve to be computed, 'ROC' [default] or
'PR' for the Precision-Recall-curve.
name: An optional variable_scope name.
Returns:
points: A `Tensor` with shape [num_thresholds, 2] that contains points of
the curve.
update_op: An operation that increments the `true_positives`,
`true_negatives`, `false_positives` and `false_negatives` variables.
Raises:
ValueError: If `predictions` and `labels` have mismatched shapes, or if
`weights` is not `None` and its shape doesn't match `predictions`, or if
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
with variable_scope.variable_scope(name, 'curve_points', (labels, predictions,
weights)):
if curve != 'ROC' and curve != 'PR':
raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
kepsilon = 1e-7 # to account for floating point imprecisions
thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
for i in range(num_thresholds - 2)]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
values, update_ops = _streaming_confusion_matrix_at_thresholds(
labels=labels,
predictions=predictions,
thresholds=thresholds,
weights=weights)
# Add epsilons to avoid dividing by 0.
epsilon = 1.0e-6
def compute_points(tp, fn, tn, fp):
"""Computes the roc-auc or pr-auc based on confusion counts."""
rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
if curve == 'ROC':
fp_rate = math_ops.div(fp, fp + tn + epsilon)
return fp_rate, rec
else: # curve == 'PR'.
prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
return rec, prec
xs, ys = compute_points(values['tp'], values['fn'], values['tn'],
values['fp'])
points = array_ops.stack([xs, ys], axis=1)
update_op = control_flow_ops.group(*update_ops.values())
if metrics_collections:
ops.add_to_collections(metrics_collections, points)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
return points, update_op
def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
metrics_collections=None, updates_collections=None,
curve='ROC', name=None):
@ -2372,6 +2468,7 @@ __all__ = [
'sparse_recall_at_top_k',
'streaming_accuracy',
'streaming_auc',
'streaming_curve_points',
'streaming_false_negatives',
'streaming_false_negatives_at_thresholds',
'streaming_false_positives',

View File

@ -1327,6 +1327,99 @@ class StreamingRecallTest(test.TestCase):
self.assertEqual(0, recall.eval())
class StreamingCurvePointsTest(test.TestCase):
def setUp(self):
np.random.seed(1)
ops.reset_default_graph()
def testVars(self):
metric_ops.streaming_curve_points(
predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
_assert_local_variables(
self,
('curve_points/true_positives:0', 'curve_points/false_negatives:0',
'curve_points/false_positives:0', 'curve_points/true_negatives:0'))
def testMetricsCollection(self):
my_collection_name = '__metrics__'
points, _ = metric_ops.streaming_curve_points(
labels=array_ops.ones((10, 1)),
predictions=array_ops.ones((10, 1)),
metrics_collections=[my_collection_name])
self.assertListEqual(ops.get_collection(my_collection_name), [points])
def testUpdatesCollection(self):
my_collection_name = '__updates__'
_, update_op = metric_ops.streaming_curve_points(
labels=array_ops.ones((10, 1)),
predictions=array_ops.ones((10, 1)),
updates_collections=[my_collection_name])
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def _testValueTensorIsIdempotent(self, curve):
predictions = constant_op.constant(
np.random.uniform(size=(10, 3)), dtype=dtypes_lib.float32)
labels = constant_op.constant(
np.random.uniform(high=2, size=(10, 3)), dtype=dtypes_lib.float32)
points, update_op = metric_ops.streaming_curve_points(
labels, predictions=predictions, curve=curve)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
initial_points = points.eval()
sess.run(update_op)
self.assertAllClose(initial_points, points.eval())
def testValueTensorIsIdempotentROC(self):
self._testValueTensorIsIdempotent(curve='ROC')
def testValueTensorIsIdempotentPR(self):
self._testValueTensorIsIdempotent(curve='PR')
def _testCase(self, labels, predictions, curve, expected_points):
with self.test_session() as sess:
predictions_tensor = constant_op.constant(
predictions, dtype=dtypes_lib.float32)
labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.float32)
points, update_op = metric_ops.streaming_curve_points(
labels=labels_tensor,
predictions=predictions_tensor,
num_thresholds=3,
curve=curve)
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertAllClose(expected_points, points.eval())
def testEdgeCasesROC(self):
self._testCase([[1]], [[1]], 'ROC', [[0, 1], [0, 1], [0, 0]])
self._testCase([[0]], [[0]], 'ROC', [[1, 1], [0, 1], [0, 1]])
self._testCase([[0]], [[1]], 'ROC', [[1, 1], [1, 1], [0, 1]])
self._testCase([[1]], [[0]], 'ROC', [[0, 1], [0, 0], [0, 0]])
def testManyValuesROC(self):
self._testCase([[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]],
[[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], 'ROC',
[[1.0, 1.0], [0.0, 0.75], [0.0, 0.0]])
def testEdgeCasesPR(self):
self._testCase([[1]], [[1]], 'PR', [[1, 1], [1, 1], [0, 1]])
self._testCase([[0]], [[0]], 'PR', [[1, 0], [1, 1], [1, 1]])
self._testCase([[0]], [[1]], 'PR', [[1, 0], [1, 0], [1, 1]])
self._testCase([[1]], [[0]], 'PR', [[1, 1], [0, 1], [0, 1]])
def testManyValuesPR(self):
self._testCase([[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]],
[[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], 'PR',
[[1.0, 4.0 / 6.0], [0.75, 1.0], [0.0, 1.0]])
class StreamingAUCTest(test.TestCase):
def setUp(self):

View File

@ -226,8 +226,8 @@ class TestBeamStep(test.TestCase):
class BeamSearchDecoderTest(test.TestCase):
def _testDynamicDecodeRNN(self, time_major, has_attention):
encoder_sequence_length = [3, 2, 3, 1, 1]
decoder_sequence_length = [2, 0, 1, 2, 3]
encoder_sequence_length = np.array([3, 2, 3, 1, 1])
decoder_sequence_length = np.array([2, 0, 1, 2, 3])
batch_size = 5
decoder_max_time = 4
input_depth = 7
@ -245,6 +245,7 @@ class BeamSearchDecoderTest(test.TestCase):
batch_size_tensor = constant_op.constant(batch_size)
embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
cell = rnn_cell.LSTMCell(cell_depth)
initial_state = cell.zero_state(batch_size, dtypes.float32)
if has_attention:
inputs = array_ops.placeholder_with_default(
np.random.randn(batch_size, decoder_max_time,
@ -258,6 +259,8 @@ class BeamSearchDecoderTest(test.TestCase):
num_units=attention_depth,
memory=tiled_inputs,
memory_sequence_length=tiled_sequence_length)
initial_state = beam_search_decoder.tile_batch(
initial_state, multiplier=beam_width)
cell = attention_wrapper.AttentionWrapper(
cell=cell,
attention_mechanism=attention_mechanism,
@ -265,6 +268,9 @@ class BeamSearchDecoderTest(test.TestCase):
alignment_history=False)
cell_state = cell.zero_state(
dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
if has_attention:
cell_state = cell_state.clone(
cell_state=initial_state)
bsd = beam_search_decoder.BeamSearchDecoder(
cell=cell,
embedding=embedding,

View File

@ -72,10 +72,30 @@ class FinalBeamSearchDecoderOutput(
pass
def tile_batch(t, multiplier, name=None):
"""Tile the batch dimension of tensor t.
def _tile_batch(t, multiplier):
"""Core single-tensor implementation of tile_batch."""
t = ops.convert_to_tensor(t, name="t")
shape_t = array_ops.shape(t)
if t.shape.ndims is None or t.shape.ndims < 1:
raise ValueError("t must have statically known rank")
tiling = [1] * (t.shape.ndims + 1)
tiling[1] = multiplier
tiled_static_batch_size = (
t.shape[0].value * multiplier if t.shape[0].value is not None else None)
tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling)
tiled = array_ops.reshape(
tiled, array_ops.concat(([shape_t[0] * multiplier], shape_t[1:]), 0))
tiled.set_shape(
tensor_shape.TensorShape(
[tiled_static_batch_size]).concatenate(t.shape[1:]))
return tiled
This function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
def tile_batch(t, multiplier, name=None):
"""Tile the batch dimension of a (possibly nested structure of) tensor(s) t.
For each tensor t in a (possibly nested structure) of tensors,
this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape
`[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries
`t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated
@ -87,27 +107,16 @@ def tile_batch(t, multiplier, name=None):
name: Name scope for any created operations.
Returns:
A `Tensor` shaped `[batch_size * multiplier, ...]`.
A (possibly nested structure of) `Tensor` shaped
`[batch_size * multiplier, ...]`.
Raises:
ValueError: if `t` does not have a statically known rank or it's < 1.
ValueError: if tensor(s) `t` do not have a statically known rank or
the rank is < 1.
"""
with ops.name_scope(name, "tile_batch", [t, multiplier]):
t = ops.convert_to_tensor(t, name="t")
shape_t = array_ops.shape(t)
if t.shape.ndims is None or t.shape.ndims < 1:
raise ValueError("t must have statically known rank")
tiling = [1] * (t.shape.ndims + 1)
tiling[1] = multiplier
tiled_static_batch_size = (
t.shape[0].value * multiplier if t.shape[0].value is not None else None)
tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling)
tiled = array_ops.reshape(
tiled, array_ops.concat(([shape_t[0] * multiplier], shape_t[1:]), 0))
tiled.set_shape(
tensor_shape.TensorShape(
[tiled_static_batch_size]).concatenate(t.shape[1:]))
return tiled
flat_t = nest.flatten(t)
with ops.name_scope(name, "tile_batch", flat_t + [multiplier]):
return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t)
def _check_maybe(t):

View File

@ -270,7 +270,7 @@ class SessionBundleTest : public ::testing::Test {
// MetaGraphDef.
// Returns the path of the export.
// ** Should only be called once per test **
string SetupExport(MetaGraphDefTwiddler twiddler) {
string SetupExport(const MetaGraphDefTwiddler& twiddler) {
return SetupExport(twiddler, kVariablesFilename, kMetaGraphDefFilename);
}
// SetupExport that allows for the variables and meta_graph_def filenames

View File

@ -62,6 +62,7 @@ licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
"full_path",
"if_android",
"if_ios",
"if_x86",

View File

@ -30,7 +30,11 @@ Device::Device(Env* env, const DeviceAttributes& device_attributes)
rmgr_ = new ResourceMgr(parsed_name_.job);
}
Device::~Device() { delete rmgr_; }
Device::~Device() {
if (rmgr_ != nullptr) {
DeleteResourceMgr();
}
}
// static
DeviceAttributes Device::BuildDeviceAttributes(

View File

@ -60,7 +60,9 @@ class Device : public DeviceBase {
const string& name() const { return device_attributes_.name(); }
// Parsed name of this device
const DeviceNameUtils::ParsedName parsed_name() const { return parsed_name_; }
const DeviceNameUtils::ParsedName& parsed_name() const {
return parsed_name_;
}
// Describes what kind of device this is. This is intended to be
// human-readable and not computer-parsed, except that two devices
@ -149,6 +151,12 @@ class Device : public DeviceBase {
return BuildDeviceAttributes(name, device, memory_limit, locality, "");
}
protected:
void DeleteResourceMgr() {
delete rmgr_;
rmgr_ = nullptr;
}
private:
const DeviceAttributes device_attributes_;
DeviceNameUtils::ParsedName parsed_name_;

View File

@ -53,7 +53,7 @@ Device* DeviceSet::FindDeviceByName(const string& name) const {
// static
int DeviceSet::DeviceTypeOrder(const DeviceType& d) {
return DeviceFactory::DevicePriority(d.type());
return DeviceFactory::DevicePriority(d.type_string());
}
static bool DeviceTypeComparator(const DeviceType& a, const DeviceType& b) {

View File

@ -1231,7 +1231,7 @@ Status FunctionDefToBodyHelper(
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.expect_device_spec = false;
Status s = ConvertGraphDefToGraph(opts, result.gdef, graph);
Status s = ConvertNodeDefsToGraph(opts, result.nodes, graph);
if (!s.ok()) {
delete graph;
} else {

View File

@ -93,7 +93,7 @@ class FunctionTest : public ::testing::Test {
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.expect_device_spec = false;
TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g));
TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g));
const int version = g->versions().producer();
LocalExecutorParams params;
@ -949,7 +949,7 @@ GraphDef Optimize(const std::function<bool(Graph* g)>& pass,
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.expect_device_spec = false;
TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g.get()));
TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g.get()));
pass(g.get());
std::unique_ptr<Graph> g1(new Graph(OpRegistry::Global()));
CopyGraph(*g, g1.get());

View File

@ -324,6 +324,7 @@ static void BM_AllocationDelayed(int iters, int delay) {
int size_index = 0;
std::vector<void*> ptrs;
ptrs.reserve(delay);
for (int i = 0; i < delay; i++) {
ptrs.push_back(nullptr);
}

View File

@ -123,10 +123,12 @@ void Benchmark::RunWithArgs(
}
// Gets inputs' and outputs' rendezvous keys.
std::vector<std::pair<string, Tensor>> in;
in.reserve(inputs.size());
for (const auto& p : inputs) {
in.push_back({GetRendezvousKey(p.first), p.second});
}
std::vector<string> out;
out.reserve(outputs.size());
for (const auto& n : outputs) {
out.push_back(GetRendezvousKey(n));
}

View File

@ -94,6 +94,7 @@ Status SessionFactory::GetFactory(const SessionOptions& options,
// TODO(mrry): Consider providing a system-default fallback option
// in this case.
std::vector<string> factory_types;
factory_types.reserve(candidate_factories.size());
for (const auto& candidate_factory : candidate_factories) {
factory_types.push_back(candidate_factory.first);
}

Some files were not shown because too many files have changed in this diff Show More