fix merge issues
This commit is contained in:
commit 5efd272aab
Changed files:
  RELEASE.md
  WORKSPACE
  tensorflow/
    BUILD
    c/
    cc/
    compiler/
      aot/
      jit/
      tf2xla/
      xla/
        client/
        literal_util.cc, literal_util.h, literal_util_test.cc,
        packed_literal_reader.cc, packed_literal_reader.h, reference_util.cc
        service/
          BUILD, allocation_tracker.cc, copy_insertion.cc,
          copy_insertion_test.cc, dfs_hlo_visitor.h,
          dfs_hlo_visitor_with_default.h, execution_tracker.cc, gpu/,
          hlo.proto, hlo_computation.cc, hlo_instruction.cc,
          hlo_instruction.h, instruction_fusion.cc,
          instruction_fusion_test.cc, llvm_ir/, service.cc, session.proto,
          transfer_manager.h, transfer_manager_test.cc, user_computation.cc,
          user_computation_test.cc
        shape_tree.h, shape_tree_test.cc, shape_util.cc
        tests/
          array_elementwise_ops_test.cc, client_library_test_base.cc,
          concat_test.cc, literal_test_util.cc, literal_test_util_test.cc,
          log_test.cc, params_test.cc, slice_test.cc,
          vector_ops_simple_test.cc
        text_literal_reader.h, text_literal_writer.h
        tools/
          dumped_computation_to_operation_list.cc,
          dumped_computation_to_text.cc, replay_computation.cc,
          show_literal.cc
        xla.proto, xla_data.proto
    contrib/
      BUILD
      batching/
      boosted_trees/lib/utils/
      cmake/
      data/python/kernel_tests/
      factorization/
      ffmpeg/default/
      layers/kernels/
      learn/python/learn/estimators/
      lookup/
      metrics/
      seq2seq/python/
      session_bundle/
    core/
RELEASE.md
@@ -41,6 +41,15 @@
   be replaced by calling `embedding_lookup` or `layers.dense` as pre- or post-
   processing of the rnn. For RNN decoding, this functionality has been replaced
   with an alternative API in `tf.contrib.seq2seq`.
 * Intel MKL Integration (https://software.intel.com/en-us/articles/tensorflow-optimizations-on-modern-intel-architecture). Intel developed a number of
   optimized deep learning primitives: In addition to matrix multiplication and
   convolution, these building blocks include:
    * Direct batched convolution
    * Pooling: maximum, minimum, average
    * Normalization: LRN, batch normalization
    * Activation: rectified linear unit (ReLU)
    * Data manipulation: multi-dimensional transposition (conversion), split,
      concat, sum and scale.
 * TensorForest Estimator now supports SavedModel export for serving.
 * Support client-provided ClusterSpec's and propagate them to all workers to enable the creation of dynamic TensorFlow clusters.
 * TensorFlow C library now available for Windows.
WORKSPACE
@@ -2,11 +2,11 @@ workspace(name = "org_tensorflow")

 http_archive(
     name = "io_bazel_rules_closure",
-    sha256 = "4be8a887f6f38f883236e77bb25c2da10d506f2bf1a8e5d785c0f35574c74ca4",
-    strip_prefix = "rules_closure-aac19edc557aec9b603cd7ffe359401264ceff0d",
+    sha256 = "edc91f556b762fc5212d1050d00b12e40dd0b0b1c1d5d96886b59e9a30a6cae4",
+    strip_prefix = "rules_closure-3f07fb6a58870afbb36051bd5d54da4479561cc6",
     urls = [
-        "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/aac19edc557aec9b603cd7ffe359401264ceff0d.tar.gz",  # 2017-05-10
-        "https://github.com/bazelbuild/rules_closure/archive/aac19edc557aec9b603cd7ffe359401264ceff0d.tar.gz",
+        "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/3f07fb6a58870afbb36051bd5d54da4479561cc6.tar.gz",  # 2017-05-31
+        "https://github.com/bazelbuild/rules_closure/archive/3f07fb6a58870afbb36051bd5d54da4479561cc6.tar.gz",
     ],
 )
tensorflow/BUILD
@@ -393,6 +393,9 @@ filegroup(
         "//tensorflow/tensorboard/demo:all_files",
         "//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files",
         "//tensorflow/tensorboard/plugins:all_files",
         "//tensorflow/tensorboard/plugins/audio:all_files",
         "//tensorflow/tensorboard/plugins/distributions:all_files",
         "//tensorflow/tensorboard/plugins/graphs:all_files",
         "//tensorflow/tensorboard/plugins/histograms:all_files",
         "//tensorflow/tensorboard/plugins/images:all_files",
         "//tensorflow/tensorboard/plugins/projector:all_files",
tensorflow/c/c_api.cc
@@ -805,6 +805,7 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
   }

   std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
+  dim_vec.reserve(num_dims);
   for (int i = 0; i < num_dims; ++i) {
     dim_vec.push_back(ic->MakeDim(dims[i]));
   }
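Many of the hunks below repeat the pattern shown here: call reserve() before a
push_back loop whose final size is already known, so the vector allocates once
instead of regrowing repeatedly. A minimal standalone sketch of the idiom
(names are illustrative, not from the TensorFlow tree):

#include <vector>

std::vector<int> MakeSquares(int n) {
  std::vector<int> squares;
  squares.reserve(n);  // single allocation; avoids repeated reallocation
  for (int i = 0; i < n; ++i) {
    squares.push_back(i * i);
  }
  return squares;
}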
tensorflow/cc/client/client_session.cc
@@ -113,10 +113,12 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs,
     feeds.emplace_back(feed.first.name(), feed.second.tensor);
   }
   std::vector<string> output_tensor_names;
+  output_tensor_names.reserve(fetch_outputs.size());
   for (auto const& output : fetch_outputs) {
     output_tensor_names.push_back(output.name());
   }
   std::vector<string> target_node_names;
+  target_node_names.reserve(run_outputs.size());
   for (auto const& output : run_outputs) {
     target_node_names.push_back(output.node()->name());
   }
tensorflow/cc/framework/gradient_checker.cc
@@ -44,6 +44,7 @@ Status ComputeTheoreticalJacobianTranspose(
   size_t x_num = x_shapes.size();
   // Call AddSymbolicGradients to get 'dxs' (we will feed 'dys').
   OutputList dys;
+  dys.reserve(y_shapes.size());
   for (const auto& y_shape : y_shapes) {
     // TODO(suharshs): This currently assumes that all x's are the same type.
     dys.push_back(Cast(scope, Const(scope, 1.0, y_shape), xs[0].type()));
tensorflow/cc/framework/testutil.cc
@@ -15,6 +15,8 @@ limitations under the License.

 #include "tensorflow/cc/framework/testutil.h"

+#include <utility>
+
 #include "tensorflow/cc/client/client_session.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/graph/default_device.h"
@@ -30,7 +32,7 @@ void GetTensors(const Scope& scope, OutputList tensors,

 void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
   std::vector<Tensor> outputs;
-  GetTensors(scope, {tensor}, &outputs);
+  GetTensors(scope, {std::move(tensor)}, &outputs);
   *out = outputs[0];
 }
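Several hunks in this commit add std::move on parameters taken by value, as in
the GetTensor change above. The callee already owns its copy, so moving it
into the next call elides a second copy. A self-contained sketch of the idiom,
with illustrative names only:

#include <string>
#include <utility>
#include <vector>

static std::vector<std::string> sink;

void Consume(std::vector<std::string> names) {  // takes ownership
  sink = std::move(names);
}

void Forward(std::vector<std::string> names) {  // 'names' arrived by value
  Consume(std::move(names));  // move the local copy instead of copying again
}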
tensorflow/compiler/aot/compile.cc
@@ -350,6 +350,7 @@ Status CompileXla(xla::CompileOnlyClient* client,
   compile_result->program_shape = *pshape_or.ValueOrDie();
   xla::ProgramShape* pshape = &compile_result->program_shape;
   std::vector<const xla::Shape*> arg_layouts;
+  arg_layouts.reserve(pshape->parameters_size());
   for (int i = 0; i < pshape->parameters_size(); ++i) {
     arg_layouts.push_back(pshape->mutable_parameters(i));
   }
tensorflow/compiler/jit/BUILD
@@ -218,6 +218,7 @@ cc_library(
     deps = [
         ":common",
         ":graph_to_functiondef",
+        ":union_find",
         "//tensorflow/compiler/jit/graphcycles",
         "//tensorflow/compiler/jit/kernels:parallel_check_op",
         "//tensorflow/compiler/jit/kernels:xla_local_launch_op",
@@ -237,6 +238,11 @@ cc_library(
     ],
 )

+cc_library(
+    name = "union_find",
+    hdrs = ["union_find.h"],
+)
+
 cc_test(
     name = "compilation_passes_test",
     size = "small",
tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+#include <utility>
+
 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"

 #include "tensorflow/cc/framework/ops.h"
@@ -101,12 +103,12 @@ Node* Input(const GraphDefBuilder::Options& opts) {
 }

 Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) {
-  return ops::UnaryOp("UnaryTest", a, opts);
+  return ops::UnaryOp("UnaryTest", std::move(a), opts);
 }

 Node* Binary(ops::NodeOut a, ops::NodeOut b,
              const GraphDefBuilder::Options& opts) {
-  return ops::BinaryOp("BinaryTest", a, b, opts);
+  return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts);
 }

 Node* AddNLike(const std::vector<ops::NodeOut>& inputs,
@@ -127,7 +129,7 @@ Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) {
   if (opts.HaveError()) return nullptr;
   NodeBuilder node_builder(opts.GetNameForOp("Retval"), "_Retval",
                            opts.op_registry());
-  node_builder.Input(a).Attr("index", index);
+  node_builder.Input(std::move(a)).Attr("index", index);
   return opts.FinalizeBuilder(&node_builder);
 }
tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
 #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
+#include "tensorflow/compiler/jit/union_find.h"
 #include "tensorflow/compiler/tf2xla/dump_graph.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/core/common_runtime/function.h"
@@ -206,70 +207,12 @@ Status FindCompilationCandidates(
   return Status::OK();
 }

-// Union-Find data structure used to compute clusters. We use our own
-// implementation because we want one key feature: when merging clusters, we
-// need to know which value becomes the representative of the merged clusters.
-// We use the representatives to name nodes in a cycle detection graph, and we
-// need to control which node is named.
-// TODO(phawkins): consider merging this code with union-find implementations
-// in Tensorflow, e.g., in SimplePlacer.
-class Cluster {
- public:
-  Cluster();
-
-  int Size() { return FindRoot()->size_; }
-
-  // Merges this cluster with 'other'. This cluster's representative becomes
-  // the representative of the merged cluster; the representative of 'other'
-  // is ignored.
-  void Merge(Cluster* other);
-
-  // Each cluster has an associated integer 'representative', initialized to -1
-  // by default.
-  int GetRepresentative() { return FindRoot()->representative_; }
-  void SetRepresentative(int representative) {
-    FindRoot()->representative_ = representative;
-  }
-
- private:
-  // Finds the root element of the cluster. Performs path compression.
-  Cluster* FindRoot();
-
-  int representative_;
-  int rank_;
-  int size_;  // Size of the cluster.
-  Cluster* parent_;
-};
-
-Cluster::Cluster()
-    : representative_(-1), rank_(0), size_(1), parent_(nullptr) {}
-
-void Cluster::Merge(Cluster* other) {
-  Cluster* a = FindRoot();
-  Cluster* b = other->FindRoot();
-  if (a == b) return;
-  if (a->rank_ > b->rank_) {
-    b->parent_ = a;
-    a->size_ += b->size_;
-    return;
-  }
-
-  a->parent_ = b;
-  if (a->rank_ == b->rank_) {
-    b->rank_++;
-  }
-  b->representative_ = a->representative_;
-  b->size_ += a->size_;
-}
-
-Cluster* Cluster::FindRoot() {
-  if (!parent_) return this;
-  // Path compression: update intermediate nodes to point to the root of the
-  // equivalence class.
-  parent_ = parent_->FindRoot();
-  return parent_;
-}
+struct Cluster {
+  // Identifies the node that represents this cluster in the cycle detection
+  // graph.
+  int representative = -1;
+};

 }  // anonymous namespace

 bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
@@ -432,10 +375,11 @@ Status MarkForCompilationPass::RunImpl(
   // Each compilation candidate belongs to a cluster. The cluster's
   // representative
   // names the node in the 'cycles' graph that represents the cluster.
-  std::vector<Cluster> clusters(graph->num_node_ids());
-  std::deque<Cluster*> worklist;
+  std::vector<UnionFind<Cluster>> clusters(graph->num_node_ids());
+  std::deque<UnionFind<Cluster>*> worklist;
   for (Node* node : compilation_candidates) {
-    clusters[node->id()].SetRepresentative(node->id());
+    Cluster& cluster = clusters[node->id()].Get();
+    cluster.representative = node->id();
     worklist.push_back(&clusters[node->id()]);
   }
@@ -445,7 +389,7 @@ Status MarkForCompilationPass::RunImpl(
   // Repeatedly contract edges between clusters that are on the same device,
   // provided the contraction would not create a cycle.
   while (!worklist.empty()) {
-    int from = worklist.front()->GetRepresentative();
+    int from = worklist.front()->Get().representative;
     worklist.pop_front();

     Node* node_from = graph->FindNodeId(from);
@@ -518,7 +462,7 @@ Status MarkForCompilationPass::RunImpl(
   // Count the number of elements in each cluster.
   std::vector<int> cluster_sizes(graph->num_node_ids());
   for (const Node* n : compilation_candidates) {
-    int cluster = clusters[n->id()].GetRepresentative();
+    int cluster = clusters[n->id()].Get().representative;
     cluster_sizes[cluster]++;
   }
@@ -532,7 +476,7 @@ Status MarkForCompilationPass::RunImpl(
   // if compilation is enabled, otherwise there will be no such candidates).
   const int min_cluster_size = flags->tf_xla_min_cluster_size;
   for (Node* n : compilation_candidates) {
-    int cluster = clusters[n->id()].GetRepresentative();
+    int cluster = clusters[n->id()].Get().representative;

     // Compile if the user marked this node _XlaCompile=true
     bool compile_attr = false;
tensorflow/compiler/jit/union_find.h (new file, 81 lines)
@@ -0,0 +1,81 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_UNION_FIND_H_
+#define TENSORFLOW_COMPILER_JIT_UNION_FIND_H_
+
+namespace tensorflow {
+
+// Union-Find data structure.
+// Each cluster has an associated value; when merging clusters we can control
+// which value becomes the representative of the merged clusters. Values must
+// be copyable.
+template <typename T>
+class UnionFind {
+ public:
+  UnionFind() : rank_(0), size_(1), parent_(nullptr) {}
+
+  // Returns the number of elements in a cluster.
+  int Size() { return FindRoot()->size_; }
+
+  // Merges this cluster with 'other'. This cluster's value becomes
+  // the value of the merged cluster; the value of 'other' is ignored.
+  void Merge(UnionFind* other);
+
+  // Each cluster has an associated value. Retrieves the value associated
+  // with this cluster.
+  T& Get() { return FindRoot()->value_; }
+
+ private:
+  // Finds the root element of the cluster. Performs path compression.
+  UnionFind* FindRoot();
+
+  int rank_;
+  int size_;  // Size of the cluster.
+  UnionFind* parent_;
+  T value_;
+};
+
+template <typename T>
+void UnionFind<T>::Merge(UnionFind* other) {
+  UnionFind<T>* a = FindRoot();
+  UnionFind<T>* b = other->FindRoot();
+  if (a == b) return;
+  if (a->rank_ > b->rank_) {
+    b->parent_ = a;
+    a->size_ += b->size_;
+    return;
+  }
+
+  a->parent_ = b;
+  if (a->rank_ == b->rank_) {
+    b->rank_++;
+  }
+  b->value_ = a->value_;
+  b->size_ += a->size_;
+}
+
+template <typename T>
+UnionFind<T>* UnionFind<T>::FindRoot() {
+  if (!parent_) return this;
+  // Path compression: update intermediate nodes to point to the root of the
+  // equivalence class.
+  parent_ = parent_->FindRoot();
+  return parent_;
+}
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_JIT_UNION_FIND_H_
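A minimal usage sketch for the UnionFind<T> template added above. The Cluster
payload mirrors the one in mark_for_compilation_pass.cc; the surrounding
main() is illustrative only:

#include <iostream>
#include <vector>

#include "tensorflow/compiler/jit/union_find.h"

struct Cluster {
  int representative = -1;
};

int main() {
  std::vector<tensorflow::UnionFind<Cluster>> clusters(4);
  for (int i = 0; i < 4; ++i) clusters[i].Get().representative = i;

  // Merge 1 into 0: per Merge's contract, cluster 0's value survives.
  clusters[0].Merge(&clusters[1]);
  std::cout << clusters[1].Get().representative << "\n";  // prints 0
  std::cout << clusters[0].Size() << "\n";                // prints 2
  return 0;
}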
tensorflow/compiler/tf2xla/kernels/fill_op.cc
@@ -50,6 +50,7 @@ class FillOp : public XlaOpKernel {
   // Convert the dims literal into a vector that we can pass to
   // ComputationBuilder.
   std::vector<int64> broadcast;
+  broadcast.reserve(dims_literal.shape().dimensions(0));
   for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) {
     broadcast.push_back(xla::LiteralUtil::Get<int>(dims_literal, {i}));
   }
tensorflow/compiler/tf2xla/kernels/slice_op.cc
@@ -50,6 +50,7 @@ class SliceOp : public XlaOpKernel {
   // slice will be an empty handle if the output has no elements.
   CHECK_EQ(begin.size(), size.size());
   std::vector<int64> limits;
+  limits.reserve(begin.size());
   for (int i = 0; i < begin.size(); ++i) {
     limits.push_back(begin[i] + size[i]);
   }
tensorflow/compiler/tf2xla/literal_util.h
@@ -18,6 +18,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
 #define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_

+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/status.h"
tensorflow/compiler/xla/client/client.cc
@@ -58,14 +58,13 @@ StatusOr<std::unique_ptr<Literal>> Client::Transfer(
         "server provided response without a literal in "
         "TransferToClient request");
   }
-  return WrapUnique(response.release_literal());
+  return MakeUnique<Literal>(response.literal());
 }

 StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
     const Literal& literal, const DeviceHandle* device_handle) {
   TransferToServerRequest request;
-  *request.mutable_literal() = literal;
+  *request.mutable_literal() = literal.ToProto();
   if (device_handle) {
     *request.mutable_device_handle() = *device_handle;
   }
@@ -93,7 +92,7 @@ StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
 Status Client::TransferToInfeed(const Literal& literal, int64 replica_id,
                                 const DeviceHandle* device_handle) {
   TransferToInfeedRequest request;
-  *request.mutable_literal() = literal;
+  *request.mutable_literal() = literal.ToProto();
   if (device_handle) {
     *request.mutable_device_handle() = *device_handle;
   }
@@ -141,7 +140,8 @@ StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed(
         "TransferToClient request");
   }
-  return WrapUnique(response.release_literal());
+  Literal literal(response.literal());
+  return MakeUnique<Literal>(literal);
 }

 Status Client::ResetDevice() {
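The pattern running through this file (and the service code below) is that
wire messages now hold a xla.LiteralProto while C++ code works with a Literal
class that is constructible from the proto and exposes ToProto(). A sketch of
the round trip, assuming only what these hunks show plus xla::MakeUnique as
used elsewhere in the diff:

#include <memory>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"

// Wrapper -> wire format: attach a literal to an outgoing request.
void AttachLiteral(const xla::Literal& literal, xla::LiteralProto* dest) {
  *dest = literal.ToProto();
}

// Wire format -> wrapper: materialize a literal from a response.
std::unique_ptr<xla::Literal> FromProto(const xla::LiteralProto& proto) {
  return xla::MakeUnique<xla::Literal>(proto);
}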
tensorflow/compiler/xla/client/client.h
@@ -21,6 +21,7 @@ limitations under the License.

 #include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/session.pb.h"
 #include "tensorflow/compiler/xla/service_interface.h"
 #include "tensorflow/compiler/xla/statusor.h"
tensorflow/compiler/xla/client/computation_builder.cc
@@ -165,9 +165,10 @@ ComputationDataHandle ComputationBuilder::ConstantOp(
   }

   ConstantRequest request;
-  Literal* literal = request.mutable_literal();
-  populate(literal);
-  VLOG(3) << "created constant: " << literal->ShortDebugString();
+  Literal literal;
+  populate(&literal);
+  *request.mutable_literal() = literal.ToProto();
+  VLOG(3) << "created constant: " << request.literal().ShortDebugString();
   OpRequest op_request;
   *op_request.mutable_constant_request() = request;
   *op_request.mutable_computation() = computation_.handle();
tensorflow/compiler/xla/client/global_data.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/global_data.h"

 #include <string>
+#include <utility>

 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/platform/logging.h"
@@ -23,7 +24,7 @@ limitations under the License.
 namespace xla {

 GlobalData::GlobalData(ServiceInterface* parent, GlobalDataHandle handle)
-    : handle_(handle), parent_(parent) {}
+    : handle_(std::move(handle)), parent_(parent) {}

 GlobalData::~GlobalData() {
   UnregisterRequest request;
tensorflow/compiler/xla/client/local_client.cc
@@ -222,8 +222,9 @@ tensorflow::Status LocalExecutable::RecordArguments(
     SessionModule* session_module) {
   session_module->clear_arguments();
   for (const ShapedBuffer* argument : arguments) {
-    TF_RETURN_IF_ERROR(
-        LiteralFromShapedBuffer(*argument, session_module->add_arguments()));
+    Literal literal;
+    TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*argument, &literal));
+    *session_module->add_arguments() = literal.ToProto();
   }
   return tensorflow::Status::OK();
 }
@@ -231,9 +232,13 @@ tensorflow::Status LocalExecutable::RecordArguments(
 tensorflow::Status LocalExecutable::RecordResult(
     const ShapedBuffer* result, SessionModule* session_module) {
   session_module->clear_result();
-  return LiteralFromShapedBuffer(*result, session_module->mutable_result());
+  Literal literal(session_module->result());
+  TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*result, &literal));
+  *session_module->mutable_result() = literal.ToProto();
+  return tensorflow::Status::OK();
 }

 // TODO(dnovillo) Change signature to return StatusOr<Literal>.
 tensorflow::Status LocalExecutable::LiteralFromShapedBuffer(
     const ShapedBuffer& shaped_buffer, Literal* literal) {
   TF_ASSIGN_OR_RETURN(
(Diffs for two files suppressed because they are too large.)
tensorflow/compiler/xla/literal_util_test.cc
@@ -856,5 +856,26 @@ TEST_F(LiteralUtilTest, ConvertR4) {
   EXPECT_TRUE(LiteralUtil::Equal(*expected, *converted));
 }

+TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
+  LiteralProto p;
+  p.mutable_shape()->set_element_type(PRED);
+  for (int len = 0; len < 25; ++len) {
+    p.mutable_shape()->clear_dimensions();
+    p.mutable_shape()->add_dimensions(len);
+    p.clear_preds();
+    for (int i = 0; i < len; ++i) {
+      p.add_preds((i % 2) == (len % 2));
+    }
+
+    Literal literal(p);
+    ASSERT_EQ(len, literal.preds_size());
+    int i = 0;
+    for (auto it = literal.preds().begin(); it < literal.preds().end(); ++it) {
+      EXPECT_EQ((i % 2) == (len % 2), *it);
+      ++i;
+    }
+  }
+}
+
 }  // namespace
 }  // namespace xla
tensorflow/compiler/xla/packed_literal_reader.cc
@@ -60,8 +60,8 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
   int64 elements = ShapeUtil::ElementsIn(shape);
   LiteralUtil::Resize(elements, std::numeric_limits<float>::quiet_NaN(),
                       result.get());
-  tensorflow::protobuf::RepeatedField<float>* field = result->mutable_f32s();
-  char* data = tensorflow::bit_cast<char*>(field->mutable_data());
+  std::vector<float>* field = result->mutable_f32s();
+  char* data = tensorflow::bit_cast<char*>(field->data());
   uint64 bytes = elements * sizeof(float);
   tensorflow::StringPiece sp;
   auto s = file_->Read(offset_, bytes, &sp, data);
tensorflow/compiler/xla/packed_literal_reader.h
@@ -18,6 +18,7 @@ limitations under the License.

 #include <memory>

+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
tensorflow/compiler/xla/reference_util.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/reference_util.h"

 #include <array>
+#include <utility>

 #include "tensorflow/compiler/xla/client/computation_builder.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
@@ -331,7 +332,8 @@ ReferenceUtil::ConvArray4DGeneralDimensions(
     std::pair<int64, int64> kernel_stride, Padding padding,
     ConvolutionDimensionNumbers dimension_numbers) {
-  return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
-                                             {1, 1}, {1, 1}, dimension_numbers);
+  return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
+                                             {1, 1}, {1, 1},
+                                             std::move(dimension_numbers));
 }

 /* static */ std::unique_ptr<Array4D<float>>
tensorflow/compiler/xla/service/BUILD
@@ -529,6 +529,7 @@ cc_library(
     srcs = ["transfer_manager.cc"],
     hdrs = ["transfer_manager.h"],
     deps = [
+        "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
@@ -1680,10 +1681,8 @@ cc_library(
     deps = [
         ":buffer_assignment",
         ":hlo",
         ":hlo_ordering",
         ":hlo_proto",
         "//tensorflow/compiler/xla:status",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:lib",
     ],
 )
tensorflow/compiler/xla/service/allocation_tracker.cc
@@ -171,6 +171,7 @@ StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
           executor, allocation->device_memory(), allocation->shape()));

   std::vector<GlobalDataHandle> element_handles;
+  element_handles.reserve(element_bases.size());
   for (int i = 0; i < element_bases.size(); ++i) {
     element_handles.push_back(RegisterInternal(
         allocation->backend(), allocation->device_ordinal(), element_bases[i],
tensorflow/compiler/xla/service/copy_insertion.cc
@@ -229,25 +229,26 @@ Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices(
   // Mapping from LogicalBuffer to index (used to detect non-distinct indices).
   FlatMap<const LogicalBuffer*, std::vector<ShapeIndex>>
       buffer_to_source_indices;
-  TF_RETURN_IF_ERROR(points_to.ForEachElement([this, &buffer_to_source_indices](
-      const ShapeIndex& index, bool /*is_leaf*/,
-      const std::vector<const LogicalBuffer*>& buffers) {
-    if (buffers.size() > 1) {
-      // Record ambiguous points-to set at 'index'.
-      if (!indices_to_copy_.element(index)) {
-        VLOG(2) << "Adding copy of buffer for instruction: "
-                << instruction_->name()
-                << " at index: " << tensorflow::str_util::Join(index, ",")
-                << " with ambiguous points-to set.";
-        RecordIndex(index);
-      }
-    }
-    // For each 'buffer': record a mapping from 'buffer' to 'index'.
-    for (const LogicalBuffer* buffer : buffers) {
-      buffer_to_source_indices[buffer].push_back(index);
-    }
-    return Status::OK();
-  }));
+  TF_RETURN_IF_ERROR(points_to.ForEachElement(
+      [this, &buffer_to_source_indices](
+          const ShapeIndex& index, bool /*is_leaf*/,
+          const std::vector<const LogicalBuffer*>& buffers) {
+        if (buffers.size() > 1) {
+          // Record ambiguous points-to set at 'index'.
+          if (!indices_to_copy_.element(index)) {
+            VLOG(2) << "Adding copy of buffer for instruction: "
+                    << instruction_->name()
+                    << " at index: " << tensorflow::str_util::Join(index, ",")
+                    << " with ambiguous points-to set.";
+            RecordIndex(index);
+          }
+        }
+        // For each 'buffer': record a mapping from 'buffer' to 'index'.
+        for (const LogicalBuffer* buffer : buffers) {
+          buffer_to_source_indices[buffer].push_back(index);
+        }
+        return Status::OK();
+      }));

   // Record all non-distinct indices detected in 'buffer_to_source_indices'.
   for (const auto& buff_to_src : buffer_to_source_indices) {
@@ -449,11 +450,15 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
     FlatMap<const HloInstruction*, HloInstruction*>* shared_copies) {
   const HloInstruction* init_hlo = while_hlo->operand(0);
   const PointsToSet& points_to = points_to_analysis.GetPointsToSet(init_hlo);
+
+  // Mapping from LogicalBuffer to index (used to detect non-distinct indices).
+  FlatSet<const LogicalBuffer*> buffer_set;
+
   ShapeTree<HloInstruction*> copy_overrides(init_hlo->shape());
   TF_RETURN_IF_ERROR(points_to.ForEachElement(
-      [init_hlo, read_only_indices, shared_copies, &copy_overrides](
-          const ShapeIndex& index, bool /*is_leaf*/,
-          const std::vector<const LogicalBuffer*>& buffers) {
+      [init_hlo, read_only_indices, shared_copies, &buffer_set,
+       &copy_overrides](const ShapeIndex& index, bool /*is_leaf*/,
+                        const std::vector<const LogicalBuffer*>& buffers) {
         // Look for read-only entry parameters.
         if (!read_only_indices->element(index)) {
           return Status::OK();
@@ -468,6 +473,7 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
         if (!is_entry_parameter && !is_constant) {
           continue;
         }
+
         // We have found an entry parameter or constant that is read-only in
         // the while body. These buffers are managed by the caller, and cannot
         // be aliased with non-parameter buffers. Revert this read-only index,
@@ -476,16 +482,17 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(

         // Optimization to allow multiple while loops that share the same
         // read-only entry parameters (or constants) to share a single copy.
-        // Only unambiguous array-shaped buffers are allowed, to reduce code
-        // complexity. The shape of the entry parameter must be identical to
-        // the shape of the init_hlo at this index, to ensure there were no
-        // intervening bitcast or GTE instructions, which are also hard to
-        // handle.
+        // Only unambiguous and distinct array-shaped buffers are allowed, to
+        // reduce code complexity. The shape of the entry parameter must be
+        // identical to the shape of the init_hlo at this index, to ensure
+        // there were no intervening bitcast or GTE instructions, which are
+        // also hard to handle.
         const Shape& pointee_shape = pointee->shape();
         const Shape& init_shape =
             ShapeUtil::GetSubshape(init_hlo->shape(), index);
         if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) &&
-            ShapeUtil::Equal(pointee_shape, init_shape)) {
+            ShapeUtil::Equal(pointee_shape, init_shape) &&
+            buffer_set.count(buffer) < 1) {
           HloInstruction** copy = &(*shared_copies)[pointee];
           if (*copy == nullptr) {
             *copy =
@@ -496,6 +503,9 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
           *copy_overrides.mutable_element(index) = *copy;
         }

+        // Tracks whether this current buffer is distinct.
+        buffer_set.insert(buffer);
+
         // We've already reverted the read-only index and handled the
         // single-copy optimization above, so there's nothing more to do.
         break;
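The buffer_set added above guards the shared-copy optimization: a copy may be
shared only the first time a given LogicalBuffer is seen, and any repeat
occurrence must get its own copy so the rewritten tuple stays distinct. A toy
standalone sketch of that guard (illustrative types, not XLA code):

#include <set>
#include <vector>

// Returns, per occurrence, whether a shared copy may be used (true) or a
// fresh copy is required because the buffer was already seen (false).
std::vector<bool> MayShareCopy(const std::vector<int>& buffer_ids) {
  std::set<int> seen;
  std::vector<bool> shareable;
  shareable.reserve(buffer_ids.size());
  for (int id : buffer_ids) {
    shareable.push_back(seen.count(id) < 1);  // first occurrence only
    seen.insert(id);
  }
  return shareable;
}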
tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -44,13 +44,20 @@ class CopyInsertionTest : public HloTestBase {
     EXPECT_IS_OK(copy_insertion.Run(module).status());

     // Verify the points to set of the root of the computation after copy
-    // insertion contains no constants or parameters.
+    // insertion contains no constants or parameters, and is distinct and
+    // non-ambiguous.
     auto points_to_analysis =
         TuplePointsToAnalysis::Run(module).ConsumeValueOrDie();
+    const auto& points_to = points_to_analysis->GetPointsToSet(
+        module->entry_computation()->root_instruction());
+    EXPECT_TRUE(points_to.IsDistinct());
+    EXPECT_TRUE(!points_to.IsAmbiguous());
+
     tensorflow::gtl::FlatSet<const LogicalBuffer*> maybe_live_out_buffers =
         points_to_analysis
             ->GetPointsToSet(module->entry_computation()->root_instruction())
             .CreateFlattenedSet();

     for (const LogicalBuffer* buffer : maybe_live_out_buffers) {
       EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kConstant);
       EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kParameter);
@@ -390,6 +397,47 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
     return builder.Build();
   }

+  // Builds a While body computation with output tuple elements dependent on
+  // the input tuple elements.
+  //
+  // EX: Body({in0, in1, in2})
+  //   out0 = Add(in0, 1)
+  //   out1 = in1
+  //   out2 = in2
+  //   Tuple(out0, out1, out2)
+  std::unique_ptr<HloComputation> BuildDependentBodyComputation2() {
+    auto builder = HloComputation::Builder(TestName() + ".Body");
+
+    const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
+        {induction_variable_shape_, data_shape_, data_shape_});
+
+    auto loop_state = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
+
+    // Update the induction variable GTE(0).
+    auto induction_variable =
+        builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+            induction_variable_shape_, loop_state, 0));
+    auto inc = builder.AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
+
+    // add0 = Add(in0, 1)
+    auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
+        induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
+    // data1 = GTE(1).
+    HloInstruction* data1 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
+
+    // data2 = GTE(2).
+    HloInstruction* data2 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 2));
+
+    // Create output Tuple.
+    builder.AddInstruction(HloInstruction::CreateTuple({add0, data1, data2}));
+
+    return builder.Build();
+  }
+
   // Builds a While body computation with read-only tuple element 0.
   // EX:
   // Body({in0, in1})
@@ -408,6 +456,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
     // Update data GTE(1).
     auto data = builder.AddInstruction(
         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
+
     // Use 'induction_variable' in computation with no path to output tuple.
     auto update = builder.AddInstruction(
         HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
@@ -431,6 +480,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
     // Create param instruction to access loop state.
     const Shape& loop_state_shape =
         nested ? nested_loop_state_shape_ : loop_state_shape_;
+
     auto loop_state = builder.AddInstruction(
         HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
     // Update the induction variable GTE(0).
@@ -972,7 +1022,8 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) {
                         op::Copy(old_init->operand(1)->operand(0)))));
 }

-// Tests while init instruction buffer which interferes with while result buffer.
+// Tests while init instruction buffer which interferes with while result
+// buffer.
 //
 // init_data = Broadcast(...)
 // add_unrelated = Add(init_data) // takes a reference to cause interference
@@ -989,5 +1040,81 @@ TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) {
                         op::Copy(old_init->operand(1))));
 }

+// Tests while init instruction buffer which has a non-distinct points-to set:
+//
+//     init = Tuple(Parameter(S32, {}), Parameter(F32, {8},
+//                  Parameter(F32, {8})))
+//
+// where the second and third parameters are identical *and* the tuple is
+// shared by another while instruction.
+//
+// Verifies that the resulting points-to set is distinct in the resulting Tuple
+// (non-identical Copys). In other words, verifies that copy sharing does not
+// insert identical copies into the resulting tuple.
+TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
+  auto condition1 = module_.AddEmbeddedComputation(BuildConditionComputation());
+  auto condition2 = module_.AddEmbeddedComputation(BuildConditionComputation());
+  // Loop body that outputs a tuple whose elements depend on the init tuple.
+  auto body1 = module_.AddEmbeddedComputation(BuildDependentBodyComputation2());
+  auto body2 = module_.AddEmbeddedComputation(BuildDependentBodyComputation2());
+
+  auto builder = HloComputation::Builder(TestName() + ".While");
+
+  auto iter_param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
+  auto data_param = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, data_shape_, "data"));
+
+  // Loop init tuple contains two identical parameter buffers.
+  auto loop_init = builder.AddInstruction(
+      HloInstruction::CreateTuple({iter_param, data_param, data_param}));
+
+  const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
+      {induction_variable_shape_, data_shape_, data_shape_});
+
+  // The two while loops share the same loop init tuple.
+  auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
+      loop_state_shape, condition1, body1, loop_init));
+  auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
+      loop_state_shape, condition2, body2, loop_init));
+
+  module_.AddEntryComputation(builder.Build());
+
+  auto points_to_analysis =
+      TuplePointsToAnalysis::Run(&module_).ConsumeValueOrDie();
+
+  // Assert that the init tuples before copy insertion are non-distinct.
+  ASSERT_FALSE(
+      points_to_analysis->GetPointsToSet(while_hlo1->operand(0)).IsDistinct());
+  ASSERT_FALSE(
+      points_to_analysis->GetPointsToSet(while_hlo2->operand(0)).IsDistinct());
+
+  auto old_init1 = while_hlo1->operand(0);
+  auto old_init2 = while_hlo2->operand(0);
+
+  InsertCopies(&module_);
+
+  EXPECT_THAT(while_hlo1->operand(0),
+              op::Tuple(op::Copy(old_init1->operand(0)),
+                        op::Copy(old_init1->operand(1)),
+                        op::Copy(old_init1->operand(2))));
+
+  EXPECT_THAT(while_hlo2->operand(0),
+              op::Tuple(op::Copy(old_init2->operand(0)),
+                        op::Copy(old_init2->operand(1)),
+                        op::Copy(old_init2->operand(2))));
+
+  // Verify that the init tuples after copy insertion are distinct.
+  points_to_analysis = TuplePointsToAnalysis::Run(&module_).ConsumeValueOrDie();
+  const auto& points_to1 =
+      points_to_analysis->GetPointsToSet(while_hlo1->operand(0));
+  EXPECT_TRUE(points_to1.IsDistinct());
+
+  const auto& points_to2 =
+      points_to_analysis->GetPointsToSet(while_hlo2->operand(0));
+  EXPECT_TRUE(points_to2.IsDistinct());
+}
+
 }  // namespace
 }  // namespace xla
tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -18,6 +18,7 @@ limitations under the License.

 #include <vector>

+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/types.h"
tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_

+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/types.h"
tensorflow/compiler/xla/service/execution_tracker.cc
@@ -31,7 +31,7 @@ AsyncExecution::AsyncExecution(Backend* backend,
     : backend_(CHECK_NOTNULL(backend)),
       streams_(std::move(streams)),
       profile_(profile),
-      result_(result) {
+      result_(std::move(result)) {
   for (const auto& stream : streams_) {
     CHECK(stream != nullptr);
   }
tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
@@ -254,6 +254,7 @@ TEST_F(HloScheduleTest, LatticeMatMul) {
   //   d40 -- layer 4
   HloComputation::Builder builder("entry_computation");
   std::vector<HloInstruction*> params;
+  params.reserve(6);
   for (int i = 0; i < 6; ++i) {
     params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
         i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i))));
tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1631,6 +1631,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildKernelThunk(

   // Compute the input buffer indices.
   std::vector<BufferAllocation::Slice> io_buffers;
+  io_buffers.reserve(io_hlos.size());
   for (const HloInstruction* io_hlo : io_hlos) {
     io_buffers.push_back(GetAllocationSlice(*LatestNonGteAncestor(io_hlo)));
   }
tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
@@ -86,6 +86,7 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) {
   //   d40 -- layer 4
   HloComputation::Builder builder("entry_computation");
   std::vector<HloInstruction*> params;
+  params.reserve(6);
   for (int i = 0; i < 6; ++i) {
     params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
         i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i))));
tensorflow/compiler/xla/service/hlo.proto
@@ -46,7 +46,7 @@ message HloInstructionProto {
   xla.OpMetadata metadata = 7;

   // Literal, only present for kConstant.
-  xla.Literal literal = 8;
+  xla.LiteralProto literal = 8;

   // Parameter info, only present for kParameter.
   int64 parameter_number = 9;
tensorflow/compiler/xla/service/hlo_computation.cc
@@ -311,7 +311,6 @@ void ComputeComputationPostOrder(

   visited->insert(computation);
   post_order->push_back(computation);
-  return;
 }

 }  // namespace
tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -65,7 +65,7 @@ using ::tensorflow::strings::StrCat;
       WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()));
   instruction->operands_.push_back(operand);
   instruction->literal_.reset(new Literal);
-  *instruction->literal_->mutable_u8s() += tag;
+  instruction->literal_->append_u8s(tag);
   return instruction;
 }

@@ -1484,6 +1484,7 @@ string HloInstruction::ToString(bool compact_operands,
   }
   if (!slice_starts_.empty() && !slice_limits_.empty()) {
     std::vector<string> bounds;
+    bounds.reserve(slice_starts_.size());
     for (int i = 0; i < slice_starts_.size(); ++i) {
       bounds.push_back(
           StrCat("[", slice_starts_[i], ":", slice_limits_[i], "]"));
@@ -1550,7 +1551,7 @@ HloInstructionProto HloInstruction::ToProto() const {
   *proto.mutable_metadata() = metadata_;
   switch (opcode_) {
     case HloOpcode::kConstant:
-      *proto.mutable_literal() = *literal_;
+      *proto.mutable_literal() = literal_->ToProto();
       break;
     case HloOpcode::kParameter:
       proto.set_parameter_number(parameter_number_);
@@ -1647,10 +1648,10 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
   trace_instruction_ = trace_instruction;
 }

-const string& HloInstruction::tracing_tag() const {
+string HloInstruction::TracingTag() const {
   CHECK_EQ(HloOpcode::kTrace, opcode());
   CHECK(literal_ != nullptr);
-  return literal_->u8s();
+  return literal_->u8s_string();
 }

 bool HloInstruction::IsFused() const {
tensorflow/compiler/xla/service/hlo_instruction.h
@@ -30,6 +30,7 @@ limitations under the License.
 #include <unordered_set>
 #include <vector>

+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -535,7 +536,7 @@ class HloInstruction {
   // Returns a tag to be used in tracing.
   //
   // Precondition: opcode() == HloOpcode::kTrace
-  const string& tracing_tag() const;
+  string TracingTag() const;

   // Returns whether the instruction is a constant.
   bool IsConstant() const;
tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -151,7 +151,26 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
     return true;
   };

-  if (std::all_of(hlo->users().begin(), hlo->users().end(),
-                  user_fusable_into_hlo)) {
+  // An "effectively unary" operation is one that has one "large"
+  // input with the others being negligible in terms of memory usage.
+  // We use "has a smaller true rank than the output" as a heuristic
+  // for "negligible" memory usage.
+  auto effectively_unary = [](HloInstruction* hlo) {
+    if (hlo->operands().size() == 1) {
+      return true;
+    }
+    auto output_rank = ShapeUtil::TrueRank(hlo->shape());
+    return std::count_if(hlo->operands().begin(), hlo->operands().end(),
+                         [output_rank](HloInstruction* operand) {
+                           return ((operand->opcode() !=
+                                    HloOpcode::kBroadcast) &&
+                                   ShapeUtil::TrueRank(operand->shape()) >=
+                                       output_rank);
+                         }) <= 1;
+  };
+
+  if (effectively_unary(hlo) ||
+      std::all_of(hlo->users().begin(), hlo->users().end(),
+                  user_fusable_into_hlo)) {
     all_consumers_fusable.insert(hlo);
   }
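A toy standalone illustration (not XLA code) of the "effectively unary"
heuristic introduced above: an op counts as effectively unary when at most one
non-broadcast operand is as large, by true rank, as the output, so duplicating
it for fusion is cheap. Types and names here are invented for the example:

#include <algorithm>
#include <vector>

struct ToyOp {
  bool is_broadcast = false;
  int true_rank = 0;  // rank ignoring degenerate dimensions
  std::vector<ToyOp*> operands;
};

bool EffectivelyUnary(const ToyOp& op) {
  if (op.operands.size() == 1) return true;
  const int output_rank = op.true_rank;
  return std::count_if(op.operands.begin(), op.operands.end(),
                       [output_rank](const ToyOp* operand) {
                         return !operand->is_broadcast &&
                                operand->true_rank >= output_rank;
                       }) <= 1;
}
// E.g. Add(big_matrix, Broadcast(scalar)) is effectively unary: only the
// matrix operand is "large", which is what AllowEffectiveUnaryDuplication
// below exercises.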
tensorflow/compiler/xla/service/instruction_fusion_test.cc
@@ -156,21 +156,67 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {

 TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
   HloComputation::Builder builder(TestName());
-  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
-      0, ShapeUtil::MakeShape(F32, {16, 16}), "0"));
-  HloInstruction* unary1 = builder.AddInstruction(HloInstruction::CreateUnary(
-      ShapeUtil::MakeShape(S32, {}), HloOpcode::kFloor, param0));
-  builder.AddInstruction(HloInstruction::CreateSend(unary1, 0));
-  HloInstruction* unary2 = builder.AddInstruction(HloInstruction::CreateUnary(
-      ShapeUtil::MakeShape(S32, {}), HloOpcode::kAbs, unary1));
+  auto shape = ShapeUtil::MakeShape(F32, {16, 16});
+  auto param0 =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
+  auto param1 =
+      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
+  HloInstruction* binary1 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
+  builder.AddInstruction(HloInstruction::CreateSend(binary1, 0));
+  HloInstruction* unary = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));

   auto module = MakeUnique<HloModule>(TestName());
   auto computation = module->AddEntryComputation(builder.Build());
-  EXPECT_EQ(unary2, computation->root_instruction());
+  EXPECT_EQ(unary, computation->root_instruction());
   EXPECT_FALSE(
       InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
           .Run(module.get())
           .ValueOrDie());
 }

+TEST_F(InstructionFusionTest, AllowUnaryDuplication) {
+  HloComputation::Builder builder(TestName());
+  auto shape = ShapeUtil::MakeShape(F32, {16, 16});
+  auto param0 =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
+  HloInstruction* unary1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kFloor, param0));
+  builder.AddInstruction(HloInstruction::CreateSend(unary1, 0));
+  HloInstruction* unary2 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1));
+
+  auto module = MakeUnique<HloModule>(TestName());
+  auto computation = module->AddEntryComputation(builder.Build());
+  EXPECT_EQ(unary2, computation->root_instruction());
+  EXPECT_TRUE(
+      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
+          .Run(module.get())
+          .ValueOrDie());
+}
+
+TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) {
+  auto shape = ShapeUtil::MakeShape(F32, {16, 16});
+  auto small_shape = ShapeUtil::MakeShape(F32, {16});
+  HloComputation::Builder builder(TestName());
+  auto param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, small_shape, "0"));
+  auto param1 =
+      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
+  HloInstruction* binary1 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
+  builder.AddInstruction(HloInstruction::CreateSend(binary1, 0));
+  HloInstruction* unary = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
+
+  auto module = MakeUnique<HloModule>(TestName());
+  auto computation = module->AddEntryComputation(builder.Build());
+  EXPECT_EQ(unary, computation->root_instruction());
+  EXPECT_TRUE(
+      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
+          .Run(module.get())
+          .ValueOrDie());
+}
+
 }  // namespace xla
tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
@@ -27,6 +27,7 @@ limitations under the License.
 #include "external/llvm/include/llvm/IR/Module.h"
 #include "external/llvm/include/llvm/IR/Value.h"
 #include "external/llvm/include/llvm/Support/raw_ostream.h"
+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
tensorflow/compiler/xla/service/service.cc
@@ -77,8 +77,10 @@ tensorflow::Status RecordArguments(
     SessionModule* module) {
   module->clear_arguments();
   for (const Allocation* allocation : arg_allocations) {
-    TF_RETURN_IF_ERROR(LiteralFromAllocation(allocation, allocation->shape(),
-                                             module->add_arguments()));
+    Literal argument;
+    TF_RETURN_IF_ERROR(
+        LiteralFromAllocation(allocation, allocation->shape(), &argument));
+    *module->add_arguments() = argument.ToProto();
   }
   return tensorflow::Status::OK();
 }
@@ -87,8 +89,11 @@ tensorflow::Status RecordArguments(
 tensorflow::Status RecordResult(const Allocation* result_allocation,
                                 SessionModule* module) {
   module->clear_result();
-  return LiteralFromAllocation(result_allocation, result_allocation->shape(),
-                               module->mutable_result());
+  Literal result;
+  TF_RETURN_IF_ERROR(LiteralFromAllocation(
+      result_allocation, result_allocation->shape(), &result));
+  *module->mutable_result() = result.ToProto();
+  return tensorflow::Status::OK();
 }

 }  // namespace
@@ -649,6 +654,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
         ResolveAndValidateArguments(request.arguments(), execute_backend_.get(),
                                     executor->device_ordinal()));
     std::vector<se::DeviceMemoryBase> arguments;
+    arguments.reserve(arg_allocations.size());
     for (const Allocation* allocation : arg_allocations) {
       arguments.push_back(allocation->device_memory());
     }
@@ -677,6 +683,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
         BuildExecutables(versioned_handles, std::move(module_configs),
                          execute_backend_.get(), executors));
     std::vector<Executable*> executable_ptrs;
+    executable_ptrs.reserve(executables.size());
     for (const auto& executable : executables) {
       executable_ptrs.push_back(executable.get());
     }
@@ -752,6 +759,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
           << module_config->entry_computation_layout().ToString();

   std::vector<se::DeviceMemoryBase> arguments;
+  arguments.reserve(arg_allocations.size());
   for (const Allocation* allocation : arg_allocations) {
     arguments.push_back(allocation->device_memory());
   }
@@ -820,6 +828,7 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
           << module_config->entry_computation_layout().ToString();

   std::vector<se::DeviceMemoryBase> arguments;
+  arguments.reserve(arg_allocations.size());
   for (const Allocation* allocation : arg_allocations) {
     arguments.push_back(allocation->device_memory());
   }
@@ -908,13 +917,15 @@ tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg,
     literal_shape = &allocation->shape();
   }

-  return LiteralFromAllocation(allocation, *literal_shape,
-                               result->mutable_literal());
+  Literal literal;
+  auto status = LiteralFromAllocation(allocation, *literal_shape, &literal);
+  *result->mutable_literal() = literal.ToProto();
+  return status;
 }

 tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
                                              TransferToServerResponse* result) {
-  const Literal& literal = arg->literal();
+  Literal literal = Literal(arg->literal());
   const Shape& shape = literal.shape();

   if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) {
@@ -978,7 +989,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
   }

   return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
-      executor, arg->literal());
+      executor, Literal(arg->literal()));
 }

 tensorflow::Status Service::TransferFromOutfeed(
@@ -1001,8 +1012,12 @@ tensorflow::Status Service::TransferFromOutfeed(
     executor = execute_backend_->Replicas()[arg->replica_id()];
   }

-  return execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
-      executor, arg->shape_with_layout(), result->mutable_literal());
+  Literal literal;
+  TF_RETURN_IF_ERROR(
+      execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
+          executor, arg->shape_with_layout(), &literal));
+  *result->mutable_literal() = literal.ToProto();
+  return tensorflow::Status::OK();
 }

 tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg,
tensorflow/compiler/xla/service/session.proto
@@ -75,10 +75,10 @@ message SessionModule {
   repeated SessionComputation embedded_computations = 2;

   // The arguments passed to the computation.
-  repeated Literal arguments = 3;
+  repeated LiteralProto arguments = 3;

   // The result of the computation.
-  Literal result = 4;
+  LiteralProto result = 4;

   // The name of the platform used to run the computation.
   string execution_platform = 5;
tensorflow/compiler/xla/service/transfer_manager.h
@@ -20,6 +20,7 @@ limitations under the License.
 #include <set>
 #include <vector>

+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
tensorflow/compiler/xla/service/transfer_manager_test.cc
@@ -121,7 +121,7 @@ TEST_F(CpuTransferManagerTest, TransferR1U8FromDevice) {
   const Shape shape = ShapeUtil::MakeShape(U8, {4});
   TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice(
       stream_exec_, memptr, shape, shape, &literal));
-  CHECK_EQ("klmn", literal.u8s());
+  CHECK_EQ("klmn", literal.u8s_string());
 }

 TEST_F(CpuTransferManagerTest, TransferBufferFromDevice) {
tensorflow/compiler/xla/service/user_computation.cc
@@ -2275,7 +2275,7 @@ void ComputationLowerer::Visit(
       const ConstantRequest& constant_request =
           request.request().constant_request();
       hlo_instruction = add_instruction(HloInstruction::CreateConstant(
-          LiteralUtil::CloneToUnique(constant_request.literal())));
+          LiteralUtil::CloneToUnique(Literal(constant_request.literal()))));
       break;
     }

@@ -2467,6 +2467,7 @@ void ComputationLowerer::Visit(
       // to append dimensions on the left the broadcast_dimensions should just
       // be the n highest dimension numbers of the output shape where n is
       // the number of input dimensions.
+      broadcast_dimensions.reserve(ShapeUtil::Rank(operand->shape()));
       for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) {
         broadcast_dimensions.push_back(i +
                                        ShapeUtil::Rank(request.output_shape()) -
tensorflow/compiler/xla/service/user_computation_test.cc
@@ -50,7 +50,7 @@ TEST_F(UserComputationTest, SimpleComputation) {

   ConstantRequest constant_request;
   *constant_request.mutable_literal() =
-      *LiteralUtil::CreateR1<float>({123.0f, 42.0f});
+      LiteralUtil::CreateR1<float>({123.0f, 42.0f})->ToProto();
   TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle constant_handle,
                          computation.AddConstantInstruction(constant_request));

@@ -160,12 +160,13 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) {
   UserComputation computation("TheComputation", handle);

   ConstantRequest a_request;
-  *a_request.mutable_literal() = *LiteralUtil::CreateR1<float>({123.0f, 42.0f});
+  *a_request.mutable_literal() =
+      LiteralUtil::CreateR1<float>({123.0f, 42.0f})->ToProto();
   TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle,
                          computation.AddConstantInstruction(a_request));

   ConstantRequest b_request;
-  *b_request.mutable_literal() = *LiteralUtil::CreateR0<float>(1.0f);
+  *b_request.mutable_literal() = LiteralUtil::CreateR0<float>(1.0f)->ToProto();
   TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle,
                          computation.AddConstantInstruction(b_request));
@ -44,6 +44,7 @@ struct ShapeTreeNode {
// Children of this node.
std::vector<std::unique_ptr<ShapeTreeNode>> children;

ShapeTreeNode() = default;
explicit ShapeTreeNode(const T& data) : data(data) {}

ShapeTreeNode(const ShapeTreeNode& other)
@ -85,8 +86,9 @@ class ShapeTree {
public:
// Default constructor creates a tree with a nil shape (i.e. an empty tuple).
ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
// Create ShapeTree with the given shape, and default T values for all nodes.
explicit ShapeTree(const Shape& shape) : ShapeTree(shape, T()) {}
// Create ShapeTree with the given shape, and default-constructed T values for
// all nodes.
explicit ShapeTree(const Shape& shape);
// Create ShapeTree with the given shape, and init_value for all nodes.
ShapeTree(const Shape& shape, const T& init_value);

@ -127,6 +129,19 @@ class ShapeTree {
const ShapeIndex& /*index*/, bool /*is_leaf*/, T* /*data*/)>;
Status ForEachMutableElement(const MutableVisitorFunction& func);

// Copy the subtree of values from 'other' rooted at ShapeIndex
// 'source_base_index' into the subtree of values in this ShapeTree rooted at
// 'target_base_index'.
//
// Precondition: The subshape of other.shape() at index source_base_index must
// be compatible with the subshape of shape() at index target_base_index.
void CopySubtreeFrom(const ShapeTree<T>& other,
const ShapeIndex& source_base_index,
const ShapeIndex& target_base_index);

bool operator==(const ShapeTree<T>& other) const;
bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); }
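Two consequences of the constructor split and the new `CopySubtreeFrom` are worth spelling out: omitting `init_value` no longer requires `T` to be copyable, and compatible subtrees can be grafted between trees. A minimal sketch, assuming shapes built the same way as in the tests later in this commit and `xla::MakeUnique` as used there:

```c++
#include <memory>

#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"

void ShapeTreeSketch() {
  const xla::Shape array = xla::ShapeUtil::MakeShape(xla::F32, {4});
  const xla::Shape pair = xla::ShapeUtil::MakeTupleShape({array, array});
  const xla::Shape nested = xla::ShapeUtil::MakeTupleShape({array, pair});

  // Default-constructed values: now usable with move-only element types.
  xla::ShapeTree<std::unique_ptr<int>> owning(pair);
  *owning.mutable_element({0}) = xla::MakeUnique<int>(42);

  // Graft the subtree rooted at {1} of src onto the root of dst; the two
  // subshapes must be compatible (here both are (array, array) tuples).
  xla::ShapeTree<int> src(nested, 0);
  xla::ShapeTree<int> dst(pair, 0);
  dst.CopySubtreeFrom(src, /*source_base_index=*/{1},
                      /*target_base_index=*/{});
}
```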
private:
using Node = internal::ShapeTreeNode<T>;

@ -134,6 +149,10 @@ class ShapeTree {
// the given 'init_value'.
void InitChildren(const Shape& shape, const T& init_value, Node* node);

// Initialize node->children based on 'shape'. All children have
// default-constructed data values.
void InitChildren(const Shape& shape, Node* node);

// Helpers for traversing the shape via ForEachElement. The helpers
// recursively traverse the subtree rooted at "index" (defined as in
// ShapeUtil::GetSubshape).
@ -165,6 +184,24 @@ void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
}
}

template <typename T>
void ShapeTree<T>::InitChildren(const Shape& shape, Node* node) {
if (ShapeUtil::IsTuple(shape)) {
for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
node->children.emplace_back(new Node());
InitChildren(shape.tuple_shapes(i), node->children.back().get());
}
}
}

template <typename T>
ShapeTree<T>::ShapeTree(const Shape& shape) : root_(), shape_(shape) {
// The shape_ field is just used to hold the structure of the shape.
// It should not be relied upon to store layout information.
LayoutUtil::ClearLayout(&shape_);
InitChildren(shape_, &root_);
}

template <typename T>
ShapeTree<T>::ShapeTree(const Shape& shape, const T& init_value)
: root_(init_value), shape_(shape) {
@ -240,6 +277,48 @@ Status ShapeTree<T>::ForEachMutableElement(const MutableVisitorFunction& func) {
return ForEachMutableHelper(func, &root_, &index);
}

template <typename T>
void ShapeTree<T>::CopySubtreeFrom(const ShapeTree<T>& other,
const ShapeIndex& source_base_index,
const ShapeIndex& target_base_index) {
CHECK(ShapeUtil::Compatible(
ShapeUtil::GetSubshape(shape(), target_base_index),
ShapeUtil::GetSubshape(other.shape(), source_base_index)));
ForEachMutableElement(
[this, &other, &source_base_index, &target_base_index](
const ShapeIndex& index, bool /*is_leaf*/, T* data) {
// Copy the data element only if index is in the
// subtree rooted at target_base_index.
for (int i = 0; i < target_base_index.size(); ++i) {
if (i >= index.size() || index[i] != target_base_index[i]) {
return Status::OK();
}
}
// Construct source element index to copy from.
ShapeIndex source_index = source_base_index;
for (int i = target_base_index.size(); i < index.size(); ++i) {
source_index.push_back(index[i]);
}
*data = other.element(source_index);
return Status::OK();
})
.IgnoreError();
}

template <typename T>
bool ShapeTree<T>::operator==(const ShapeTree<T>& other) const {
bool equal = true;
ForEachElement([this, &other, &equal](const ShapeIndex& index,
bool /*is_leaf*/, const T& data) {
if (data != other.element(index)) {
equal = false;
}
return Status::OK();
})
.IgnoreError();
return equal;
}

} // namespace xla

#endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
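In the implementation above, the lambda reduces to a prefix check plus index splicing: an element participates only when `target_base_index` is a prefix of its index, and its source is `source_base_index` with the leftover suffix appended. For instance, with `source_base_index = {1}` and `target_base_index = {2, 0}`, the target element at `{2, 0, 1}` is read from the source element at `{1, 1}`.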
@ -245,5 +245,139 @@ TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) {
EXPECT_DEATH(shape_tree.element({0, 0}), "");
}

TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) {
ShapeTree<std::unique_ptr<int>> shape_tree{tuple_shape_};
EXPECT_EQ(shape_tree.element({2}).get(), nullptr);
*shape_tree.mutable_element({2}) = MakeUnique<int>(42);
EXPECT_EQ(*shape_tree.element({2}), 42);
}

TEST_F(ShapeTreeTest, CopySubtreeFromArrayShape) {
// Test CopySubtreeFrom method for a single value copied between array-shaped
// ShapeTrees.
ShapeTree<int> source(array_shape_);
*source.mutable_element(/*index=*/{}) = 42;
ShapeTree<int> destination(array_shape_, 123);

EXPECT_EQ(destination.element(/*index=*/{}), 123);
destination.CopySubtreeFrom(source, /*source_base_index=*/{},
/*target_base_index=*/{});
EXPECT_EQ(destination.element(/*index=*/{}), 42);
}

TEST_F(ShapeTreeTest, FullCopySubtreeFromTupleShape) {
// Test CopySubtreeFrom method for a copy of all elements from one
// tuple-shaped ShapeTree to another.
ShapeTree<int> source(tuple_shape_);
*source.mutable_element(/*index=*/{}) = 10;
*source.mutable_element(/*index=*/{0}) = 11;
*source.mutable_element(/*index=*/{1}) = 12;
*source.mutable_element(/*index=*/{2}) = 13;

ShapeTree<int> destination(tuple_shape_, 0);

destination.CopySubtreeFrom(source, /*source_base_index=*/{},
/*target_base_index=*/{});
EXPECT_EQ(destination.element(/*index=*/{}), 10);
EXPECT_EQ(destination.element(/*index=*/{0}), 11);
EXPECT_EQ(destination.element(/*index=*/{1}), 12);
EXPECT_EQ(destination.element(/*index=*/{2}), 13);
}

TEST_F(ShapeTreeTest, SingleElementCopySubtreeFromTupleShape) {
// Test CopySubtreeFrom method for a copy of a single element from one
// tuple-shaped ShapeTree to another.
ShapeTree<int> source(tuple_shape_);
*source.mutable_element(/*index=*/{}) = 10;
*source.mutable_element(/*index=*/{0}) = 11;
*source.mutable_element(/*index=*/{1}) = 12;
*source.mutable_element(/*index=*/{2}) = 13;

ShapeTree<int> destination(tuple_shape_, 0);

destination.CopySubtreeFrom(source, /*source_base_index=*/{0},
/*target_base_index=*/{1});
EXPECT_EQ(destination.element(/*index=*/{}), 0);
EXPECT_EQ(destination.element(/*index=*/{0}), 0);
EXPECT_EQ(destination.element(/*index=*/{1}), 11);
EXPECT_EQ(destination.element(/*index=*/{2}), 0);
}

TEST_F(ShapeTreeTest, CopySubtreeIntoNestedShape) {
// Test CopySubtreeFrom method for a copy of a tuple-shaped ShapeTree into a
// nested-tuple-shaped ShapeTree.
ShapeTree<int> source(
ShapeUtil::MakeTupleShape({array_shape_, array_shape_}));
*source.mutable_element(/*index=*/{}) = 10;
*source.mutable_element(/*index=*/{0}) = 11;
*source.mutable_element(/*index=*/{1}) = 12;

ShapeTree<int> destination(nested_tuple_shape_, 0);

destination.CopySubtreeFrom(source, /*source_base_index=*/{},
/*target_base_index=*/{2, 0});

EXPECT_EQ(destination.element(/*index=*/{}), 0);
EXPECT_EQ(destination.element(/*index=*/{0}), 0);
EXPECT_EQ(destination.element(/*index=*/{1}), 0);
EXPECT_EQ(destination.element(/*index=*/{1, 0}), 0);
EXPECT_EQ(destination.element(/*index=*/{1, 1}), 0);
EXPECT_EQ(destination.element(/*index=*/{2}), 0);
EXPECT_EQ(destination.element(/*index=*/{2, 0}), 10);
EXPECT_EQ(destination.element(/*index=*/{2, 0, 0}), 11);
EXPECT_EQ(destination.element(/*index=*/{2, 0, 1}), 12);
EXPECT_EQ(destination.element(/*index=*/{2, 1}), 0);
}

TEST_F(ShapeTreeTest, CopySubtreeFromNestedShape) {
// Test CopySubtreeFrom method for a copy from a nested-tuple-shape.
ShapeTree<int> source(nested_tuple_shape_, 42);
*source.mutable_element(/*index=*/{1}) = 10;
*source.mutable_element(/*index=*/{1, 0}) = 11;
*source.mutable_element(/*index=*/{1, 1}) = 12;

ShapeTree<int> destination(
ShapeUtil::MakeTupleShape({array_shape_, array_shape_}), 0);

destination.CopySubtreeFrom(source, /*source_base_index=*/{1},
/*target_base_index=*/{});

EXPECT_EQ(destination.element(/*index=*/{}), 10);
EXPECT_EQ(destination.element(/*index=*/{0}), 11);
EXPECT_EQ(destination.element(/*index=*/{1}), 12);
}

TEST_F(ShapeTreeTest, OperatorEquals) {
{
ShapeTree<int> a(array_shape_, 123);
ShapeTree<int> b(array_shape_, 42);
ShapeTree<int> c(array_shape_, 42);
EXPECT_FALSE(a == b);
EXPECT_TRUE(a != b);
EXPECT_TRUE(b == c);
}
{
ShapeTree<int> a(tuple_shape_);
*a.mutable_element(/*index=*/{}) = 10;
*a.mutable_element(/*index=*/{0}) = 11;
*a.mutable_element(/*index=*/{1}) = 12;

ShapeTree<int> b(tuple_shape_);
*b.mutable_element(/*index=*/{}) = 10;
*b.mutable_element(/*index=*/{0}) = 42;
*b.mutable_element(/*index=*/{1}) = 11;

ShapeTree<int> c(tuple_shape_);
*c.mutable_element(/*index=*/{}) = 10;
*c.mutable_element(/*index=*/{0}) = 42;
*c.mutable_element(/*index=*/{1}) = 11;

EXPECT_FALSE(a == b);
EXPECT_TRUE(a != b);
EXPECT_TRUE(b == c);
EXPECT_FALSE(b != c);
}
}

} // namespace
} // namespace xla
@ -122,7 +122,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
for (const auto& shape : parameters) {
*program_shape.add_parameters() = shape;
}
*program_shape.mutable_result() = result;
*program_shape.mutable_result() = std::move(result);
return program_shape;
}
@ -829,6 +829,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
const int count = GetParam();
ComputationBuilder builder(client_, TestName());
std::vector<float> values;
values.reserve(count);
for (int i = 0; i < count; ++i) {
values.push_back(i / static_cast<float>(count));
}
@ -836,6 +837,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f));

std::vector<float> expected;
expected.reserve(values.size());
for (float value : values) {
expected.push_back(value * value);
}
@ -179,7 +179,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
VLOG(1) << "expected: " << LiteralUtil::ToString(*expected_literal);
VLOG(1) << "actual: " << LiteralUtil::ToString(*actual);

EXPECT_EQ(expected, actual->u8s());
EXPECT_EQ(expected, actual->u8s_string());
}

void ClientLibraryTestBase::ComputeAndCompareTuple(
@ -442,6 +442,39 @@ XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) {
ComputeAndCompareR1<int32>(&builder, expected, {});
}

XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) {
ComputationBuilder builder(client_, TestName());

Array3D<float> arr0(9, 17, 1);
arr0.Fill(1);

Array3D<float> arr1(9, 17, 256);
arr1.Fill(2);

Array3D<float> expected(9, 17, arr0.n3() + arr1.n3());
for (int64 i = 0; i < expected.n1(); ++i) {
for (int64 j = 0; j < expected.n2(); ++j) {
int64 kk = 0;
for (const Array3D<float>& arr : {arr0, arr1}) {
for (int64 k = 0; k < arr.n3(); ++k, ++kk) {
expected(i, j, kk) = arr(i, j, k);
}
}
}
}

ComputationDataHandle h0;
auto p0 = CreateR3Parameter<float>(arr0, /*parameter_number=*/0, "p0",
&builder, &h0);
ComputationDataHandle h1;
auto p1 = CreateR3Parameter<float>(arr1, /*parameter_number=*/1, "p1",
&builder, &h1);

auto concatenated = builder.ConcatInDim({h0, h1}, 2);

ComputeAndCompareR3<float>(&builder, expected, {p0.get(), p1.get()});
}

// Describes a binary rank-2 concatenation test.
struct R2BinarySpec {
int64 lhs_dim0;
@ -262,7 +262,7 @@ class NearComparator {
max_abs_err_ = 0.0;
*miscompares_.mutable_shape() =
ShapeUtil::ChangeElementType(actual.shape(), PRED);
miscompares_.mutable_preds()->Resize(
miscompares_.mutable_preds()->resize(
ShapeUtil::ElementsIn(miscompares_.shape()), false);
multi_index_.resize(expected.shape().dimensions_size(), 0);

@ -389,7 +389,7 @@ class NearComparator {
tensorflow::strings::Printf("tempfile-%s-%llx-%s", Hostname().c_str(),
now_usec, name.c_str()));
TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(),
filename, literal));
filename, literal.ToProto()));
LOG(ERROR) << "wrote to " << name << " file: " << filename;
}
@ -83,9 +83,10 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
LOG(INFO) << "results: [" << tensorflow::str_util::Join(results, ", ") << "]";
EXPECT_EQ(3, results.size());
for (const string& result : results) {
Literal literal;
LiteralProto literal_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result,
&literal));
&literal_proto));
Literal literal(literal_proto);
if (result.find("expected") != string::npos) {
EXPECT_EQ("2", LiteralUtil::ToString(literal));
} else if (result.find("actual") != string::npos) {
@ -47,6 +47,7 @@ TEST_F(LogTest, LogTenValues) {
builder.Log(x);

std::vector<float> expected;
expected.reserve(input.size());
for (float f : input) {
expected.push_back(std::log(f));
}
@ -246,6 +246,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
}

std::vector<GlobalData*> param_data;
param_data.reserve(param_data_owner.size());
for (const std::unique_ptr<GlobalData>& data : param_data_owner) {
param_data.push_back(data.get());
}
@ -37,6 +37,7 @@ class SliceTest : public ClientLibraryTestBase {
template <typename NativeT>
void RunSliceTenToTwo() {
std::vector<NativeT> constant;
constant.reserve(10);
for (int i = 0; i < 10; ++i) {
constant.push_back(static_cast<NativeT>(i));
}
@ -64,6 +64,7 @@ TEST_F(VecOpsSimpleTest, ExpManyValues) {
for (int count : {63, 64, 65, 127, 128, 129, 17 * 4096}) {
ComputationBuilder builder(client_, TestName());
std::vector<float> exponents;
exponents.reserve(count);
for (int i = 0; i < count; ++i) {
exponents.push_back(i / static_cast<float>(count));
}
@ -71,6 +72,7 @@ TEST_F(VecOpsSimpleTest, ExpManyValues) {
auto exp = builder.Exp(x);

std::vector<float> expected;
expected.reserve(exponents.size());
for (float exponent : exponents) {
expected.push_back(std::exp(exponent));
}
@ -18,6 +18,7 @@ limitations under the License.

#include <memory>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
#define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
@ -81,6 +81,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
client->GetComputationShape(computation).ConsumeValueOrDie();

std::vector<const Shape*> layouts;
layouts.reserve(program_shape->parameters_size());
for (int i = 0; i < program_shape->parameters_size(); ++i) {
layouts.push_back(&program_shape->parameters(i));
}
@ -56,6 +56,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool compile) {
client->GetComputationShape(computation).ConsumeValueOrDie();

std::vector<const Shape*> layouts;
layouts.reserve(program_shape->parameters_size());
for (int i = 0; i < program_shape->parameters_size(); ++i) {
layouts.push_back(&program_shape->parameters(i));
}
@ -66,7 +66,8 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(
if (use_fake_data) {
arguments = MakeFakeArgumentsOrDie(computation, client);
} else { // use recorded data if available
for (const Literal& literal : module.arguments()) {
for (const auto& proto : module.arguments()) {
Literal literal(proto);
TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
client->TransferToServer(literal));
arguments.push_back(std::move(data));
@ -74,6 +75,7 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(
}

std::vector<GlobalData*> execute_arguments;
execute_arguments.reserve(arguments.size());
for (auto& argument : arguments) {
execute_arguments.push_back(argument.get());
}
@ -100,7 +102,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool use_fake_data) {
if (module.has_result()) {
fprintf(stdout, "was %s:%s\n",
ShapeUtil::HumanString(module.result().shape()).c_str(),
LiteralUtil::ToString(module.result()).c_str());
LiteralUtil::ToString(Literal(module.result())).c_str());
}
}
}
@ -37,9 +37,10 @@ int main(int argc, char **argv) {
<< " <path-to-serialized-literal-proto>";
}

xla::Literal literal;
xla::LiteralProto literal_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1],
&literal));
LOG(INFO) << "literal: " << literal.ShortDebugString();
&literal_proto));
xla::Literal literal(literal_proto);
LOG(INFO) << "literal: " << literal_proto.ShortDebugString();
fprintf(stderr, "%s\n", xla::LiteralUtil::ToString(literal).c_str());
}
@ -92,11 +92,11 @@ message TransferToClientRequest {
}

message TransferToClientResponse {
Literal literal = 1;
LiteralProto literal = 1;
}

message TransferToServerRequest {
Literal literal = 1;
LiteralProto literal = 1;
DeviceHandle device_handle = 2;
}

@ -105,7 +105,7 @@ message TransferToServerResponse {
}

message TransferToInfeedRequest {
Literal literal = 1;
LiteralProto literal = 1;
int64 replica_id = 2;
DeviceHandle device_handle = 3;
}
@ -123,7 +123,7 @@ message TransferFromOutfeedRequest {
}

message TransferFromOutfeedResponse {
Literal literal = 1;
LiteralProto literal = 1;
}

message ResetDeviceRequest {

@ -275,7 +275,7 @@ message ChannelHandle {
//
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape.
message Literal {
message LiteralProto {
Shape shape = 1;
repeated bool preds = 2;
bytes u8s = 3;
@ -285,7 +285,7 @@ message Literal {
repeated uint64 u64s = 7;
repeated float f32s = 8;
repeated double f64s = 9;
repeated Literal tuple_literals = 10;
repeated LiteralProto tuple_literals = 10;
bytes f16s = 11; // Note: the F16s are encoded in little endian byte order
}

@ -337,7 +337,7 @@ message Window {
// field in OpRequest.

message ConstantRequest {
Literal literal = 2;
LiteralProto literal = 2;
}

message GetTupleElementRequest {
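On the client side, the renamed field is populated through the same `ToProto()` conversion used in the updated tests above; a minimal sketch (the wrapper function is illustrative):

```c++
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/xla.pb.h"

// Build a transfer request from an in-memory literal. LiteralUtil
// factories still return xla::Literal; the wire format is reached
// explicitly through ToProto().
xla::TransferToServerRequest MakeScalarTransferRequest() {
  xla::TransferToServerRequest request;
  *request.mutable_literal() =
      xla::LiteralUtil::CreateR0<float>(1.0f)->ToProto();
  return request;
}
```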
@ -85,6 +85,7 @@ cc_library(
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels",
"//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
"//tensorflow/contrib/nccl:nccl_kernels",
"//tensorflow/contrib/seq2seq:beam_search_ops_kernels",
"//tensorflow/contrib/tensor_forest:tensor_forest_kernels",
"//tensorflow/contrib/text:all_kernels",
],
@ -100,6 +101,7 @@ cc_library(
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib",
"//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
"//tensorflow/contrib/nccl:nccl_ops_op_lib",
"//tensorflow/contrib/seq2seq:beam_search_ops_op_lib",
"//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib",
"//tensorflow/contrib/text:all_ops",
],
@ -347,6 +347,7 @@ class BatchResource : public ResourceBase {

// Concatenate the tasks' ith input tensors into a big output tensor.
std::vector<Tensor> to_concatenate;
to_concatenate.reserve(batch->num_tasks());
for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
to_concatenate.push_back(batch->task(task_idx).inputs.at(i));
}
@ -139,6 +139,7 @@ TEST(SharedBatchSchedulerTest, ObeyBatchSizeConstraint) {
&callback_data](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
std::vector<size_t> batch_data;
batch_data.reserve(batch->num_tasks());
for (int i = 0; i < batch->num_tasks(); ++i) {
batch_data.push_back(batch->mutable_task(i)->size());
}
@ -295,6 +295,7 @@ void ExpectVecsEquiv(const std::vector<float>& vec1,
std::vector<float> GetWeightsByIndex(const std::vector<float>& weights,
const std::vector<int>& indices) {
std::vector<float> res;
res.reserve(indices.size());
for (const int index : indices) {
res.push_back(weights[index]);
}
@ -236,6 +236,9 @@ add_python_module("tensorflow/tensorboard")
add_python_module("tensorflow/tensorboard/backend")
add_python_module("tensorflow/tensorboard/backend/event_processing")
add_python_module("tensorflow/tensorboard/plugins")
add_python_module("tensorflow/tensorboard/plugins/audio")
add_python_module("tensorflow/tensorboard/plugins/distributions")
add_python_module("tensorflow/tensorboard/plugins/graphs")
add_python_module("tensorflow/tensorboard/plugins/histograms")
add_python_module("tensorflow/tensorboard/plugins/images")
add_python_module("tensorflow/tensorboard/plugins/projector")
@ -536,6 +539,7 @@ set(tf_python_op_gen_main_srcs
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc"
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc"
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h"
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.h"
)

add_library(tf_python_op_gen_main OBJECT ${tf_python_op_gen_main_srcs})
@ -209,10 +209,11 @@ if (tensorflow_BUILD_PYTHON_TESTS)
# Broken TensorBoard tests due to different paths in windows
"${tensorflow_source_dir}/tensorflow/tensorboard/backend/application_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/lib/python/http_util_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/audio/audio_plugin_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/images/images_plugin_test.py"
# Broken tensorboard test due to cmake issues.
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/debugger/plugin_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/images/images_plugin_test.py"
# tensor_forest tests (also note that we exclude the hybrid tests for now)
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order.
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order.
@ -150,7 +150,8 @@ class MapDatasetTest(test.TestCase):
results.append(sess.run(get_next))
except errors.OutOfRangeError:
return
threads = [self.checkedThread(target=iterator_thread) for _ in range(8)]
threads = [self.checkedThread(target=iterator_thread)
for _ in range(64)]
for t in threads:
t.start()
for t in threads:
@ -375,8 +375,8 @@ class NearestNeighborsOp : public OpKernel {
const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
const Eigen::Ref<const MatrixXfRowMajor>& centers,
const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
const Eigen::Ref<MatrixXi64RowMajor>& nearest_center_indices,
const Eigen::Ref<MatrixXfRowMajor>& nearest_center_distances) {
CHECK_LE(k, centers.rows());
if (centers.rows() <= kNearestNeighborsCentersMaxBlockSize) {
FindKNearestCentersOneBlock(k, points, points_half_squared_norm, centers,
@ -164,11 +164,12 @@ class KMeans(object):
with ops.colocate_with(inp):
# Computes Euclidean distance. Note the first and third terms are
# broadcast additions.
squared_distance = (math_ops.reduce_sum(
math_ops.square(inp), 1, keep_dims=True) - 2 * math_ops.matmul(
inp, clusters, transpose_b=True) + array_ops.transpose(
math_ops.reduce_sum(
math_ops.square(clusters), 1, keep_dims=True)))
squared_distance = (
math_ops.reduce_sum(math_ops.square(inp), 1, keep_dims=True) -
2 * math_ops.matmul(inp, clusters, transpose_b=True) +
array_ops.transpose(
math_ops.reduce_sum(
math_ops.square(clusters), 1, keep_dims=True)))
output.append(squared_distance)

return output
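The reformatted expression above is the usual expansion of the squared Euclidean distance, evaluated for every (input row, cluster center) pair, with the first and third terms broadcast against the matrix product:

$$\lVert x - c \rVert^2 = \lVert x \rVert^2 - 2\,x^\top c + \lVert c \rVert^2$$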
@ -229,12 +230,12 @@ class KMeans(object):
clusters = nn_impl.l2_normalize(clusters, dim=1)
for inp, score in zip(inputs, scores):
with ops.colocate_with(inp):
(indices,
distances) = gen_clustering_ops.nearest_neighbors(inp, clusters, 1)
(indices, distances) = gen_clustering_ops.nearest_neighbors(
inp, clusters, 1)
if self._distance_metric == COSINE_DISTANCE:
distances *= 0.5
output.append(
(score, array_ops.squeeze(distances), array_ops.squeeze(indices)))
output.append((score, array_ops.squeeze(distances),
array_ops.squeeze(indices)))
return zip(*output)

def _init_clusters_random(self):
@ -265,9 +266,7 @@ class KMeans(object):
(not self._use_mini_batch or
self._mini_batch_steps_per_iteration > 1))

def _initialize_clusters(self,
cluster_centers,
cluster_centers_initialized,
def _initialize_clusters(self, cluster_centers, cluster_centers_initialized,
cluster_centers_updated):
"""Returns an op to initialize the cluster centers."""

@ -294,22 +293,20 @@ class KMeans(object):

with ops.colocate_with(cluster_centers_initialized):
initialized = control_flow_ops.with_dependencies(
[clusters_init],
array_ops.identity(cluster_centers_initialized))
[clusters_init], array_ops.identity(cluster_centers_initialized))
with ops.colocate_with(cluster_centers):
assign_centers = state_ops.assign(cluster_centers, clusters_init,
validate_shape=False)
assign_centers = state_ops.assign(
cluster_centers, clusters_init, validate_shape=False)
if cluster_centers_updated != cluster_centers:
assign_centers = control_flow_ops.group(
assign_centers,
state_ops.assign(cluster_centers_updated, clusters_init,
validate_shape=False))
assign_centers = control_flow_ops.with_dependencies(
[assign_centers],
state_ops.assign(cluster_centers_initialized, True))
return control_flow_ops.cond(initialized,
control_flow_ops.no_op,
lambda: assign_centers).op
assign_centers = control_flow_ops.group(assign_centers,
state_ops.assign(
cluster_centers_updated,
clusters_init,
validate_shape=False))
assign_centers = control_flow_ops.with_dependencies(
[assign_centers], state_ops.assign(cluster_centers_initialized, True))
return control_flow_ops.cond(initialized, control_flow_ops.no_op,
lambda: assign_centers).op

def _create_variables(self):
"""Creates variables.
@ -327,19 +324,16 @@ class KMeans(object):
cluster_centers_updated back to cluster_centers.
"""
init_value = array_ops.constant([], dtype=dtypes.float32)
cluster_centers = variable_scope.variable(init_value,
name='clusters',
validate_shape=False)
cluster_centers_initialized = variable_scope.variable(False,
dtype=dtypes.bool,
name='initialized')
cluster_centers = variable_scope.variable(
init_value, name='clusters', validate_shape=False)
cluster_centers_initialized = variable_scope.variable(
False, dtype=dtypes.bool, name='initialized')

if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
# Copy of cluster centers actively updated each step according to
# mini-batch update rule.
cluster_centers_updated = variable_scope.variable(init_value,
name='clusters_updated',
validate_shape=False)
cluster_centers_updated = variable_scope.variable(
init_value, name='clusters_updated', validate_shape=False)
# How many steps till we copy the updated clusters to cluster_centers.
update_in_steps = variable_scope.variable(
self._mini_batch_steps_per_iteration,
@ -347,20 +341,15 @@ class KMeans(object):
name='update_in_steps')
# Count of points assigned to cluster_centers_updated.
cluster_counts = variable_scope.variable(
array_ops.zeros([self._num_clusters],
dtype=dtypes.int64))
array_ops.zeros([self._num_clusters], dtype=dtypes.int64))
else:
cluster_centers_updated = cluster_centers
update_in_steps = None
cluster_counts = (variable_scope.variable(array_ops.ones(
[self._num_clusters],
dtype=dtypes.int64))
cluster_counts = (variable_scope.variable(
array_ops.ones([self._num_clusters], dtype=dtypes.int64))
if self._use_mini_batch else None)
return (cluster_centers,
cluster_centers_initialized,
cluster_counts,
cluster_centers_updated,
update_in_steps)
return (cluster_centers, cluster_centers_initialized, cluster_counts,
cluster_centers_updated, update_in_steps)

@classmethod
def _l2_normalize_data(cls, inputs):
@ -391,11 +380,8 @@ class KMeans(object):
"""
# Implementation of kmeans.
inputs = self._inputs
(cluster_centers_var,
cluster_centers_initialized,
total_counts,
cluster_centers_updated,
update_in_steps) = self._create_variables()
(cluster_centers_var, cluster_centers_initialized, total_counts,
cluster_centers_updated, update_in_steps) = self._create_variables()
init_op = self._initialize_clusters(cluster_centers_var,
cluster_centers_initialized,
cluster_centers_updated)
@ -409,8 +395,7 @@ class KMeans(object):
all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers)
if self._use_mini_batch:
sync_updates_op = self._mini_batch_sync_updates_op(
update_in_steps,
cluster_centers_var, cluster_centers_updated,
update_in_steps, cluster_centers_var, cluster_centers_updated,
total_counts)
assert sync_updates_op is not None
with ops.control_dependencies([sync_updates_op]):
@ -421,15 +406,15 @@ class KMeans(object):
training_op = self._full_batch_training_op(inputs, cluster_idx,
cluster_centers_var)

return (all_scores, cluster_idx, scores,
cluster_centers_initialized, init_op, training_op)
return (all_scores, cluster_idx, scores, cluster_centers_initialized,
init_op, training_op)

def _mini_batch_sync_updates_op(self, update_in_steps,
cluster_centers_var, cluster_centers_updated,
total_counts):
def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var,
cluster_centers_updated, total_counts):
if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
assert update_in_steps is not None
with ops.colocate_with(update_in_steps):

def _f():
# Note that there is a race condition here, so we do best-effort
# updates here. We reset update_in_steps first so that other workers
@ -437,33 +422,36 @@ class KMeans(object):
# before resetting total_counts to avoid large updates to
# cluster_centers_updated based on partially updated
# cluster_center_vars.
with ops.control_dependencies([state_ops.assign(
update_in_steps,
self._mini_batch_steps_per_iteration - 1)]):
with ops.colocate_with(cluster_centers_updated):
with ops.control_dependencies([
state_ops.assign(update_in_steps,
self._mini_batch_steps_per_iteration - 1)
]):
with ops.colocate_with(
cluster_centers_updated, ignore_existing=True):
if self._distance_metric == COSINE_DISTANCE:
cluster_centers = nn_impl.l2_normalize(cluster_centers_updated,
dim=1)
cluster_centers = nn_impl.l2_normalize(
cluster_centers_updated, dim=1)
else:
cluster_centers = cluster_centers_updated
with ops.colocate_with(cluster_centers_var):
with ops.control_dependencies([state_ops.assign(
cluster_centers_var,
cluster_centers)]):
with ops.colocate_with(cluster_centers_var):
with ops.control_dependencies(
[state_ops.assign(cluster_centers_var, cluster_centers)]):
with ops.colocate_with(
cluster_centers_var, ignore_existing=True):
with ops.control_dependencies([
state_ops.assign(total_counts,
array_ops.zeros_like(total_counts))]):
array_ops.zeros_like(total_counts))
]):
return array_ops.identity(update_in_steps)

return control_flow_ops.cond(
update_in_steps <= 0,
_f,
update_in_steps <= 0, _f,
lambda: state_ops.assign_sub(update_in_steps, 1))
else:
return control_flow_ops.no_op()

def _mini_batch_training_op(self, inputs, cluster_idx_list,
cluster_centers, total_counts):
def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers,
total_counts):
"""Creates an op for training for mini batch case.

Args:
@ -487,17 +475,15 @@ class KMeans(object):
unique_ids, unique_idx = array_ops.unique(cluster_idx)
num_unique_cluster_idx = array_ops.size(unique_ids)
# Fetch the old values of counts and cluster_centers.
with ops.colocate_with(total_counts):
with ops.colocate_with(total_counts, ignore_existing=True):
old_counts = array_ops.gather(total_counts, unique_ids)
# TODO(agarwal): This colocation seems to run into problems. Fix it.
# with ops.colocate_with(cluster_centers):
old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
with ops.colocate_with(cluster_centers, ignore_existing=True):
old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
# Locally aggregate the increment to counts.
count_updates = math_ops.unsorted_segment_sum(
array_ops.ones_like(
unique_idx, dtype=total_counts.dtype),
unique_idx,
num_unique_cluster_idx)
array_ops.ones_like(unique_idx, dtype=total_counts.dtype),
unique_idx, num_unique_cluster_idx)
# Locally compute the sum of inputs mapped to each id.
# For a cluster with old cluster value x, old count n, and with data
# d_1,...d_k newly assigned to it, we recompute the new value as
@ -507,13 +493,12 @@ class KMeans(object):
inp, unique_idx, num_unique_cluster_idx)
# Shape to enable broadcasting count_updates and learning_rate to inp.
# It extends the shape with 1's to match the rank of inp.
broadcast_shape = array_ops.concat(
[
array_ops.reshape(num_unique_cluster_idx, [1]), array_ops.ones(
array_ops.reshape(array_ops.rank(inp) - 1, [1]),
dtype=dtypes.int32)
],
0)
broadcast_shape = array_ops.concat([
array_ops.reshape(num_unique_cluster_idx, [1]),
array_ops.ones(
array_ops.reshape(array_ops.rank(inp) - 1, [1]),
dtype=dtypes.int32)
], 0)
# Subtract k * x, see comment above.
cluster_center_updates -= math_ops.cast(
array_ops.reshape(count_updates, broadcast_shape),
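Combining the "subtract k * x" step above with the "scale by 1 / (n + k)" step in the next hunk recovers the running mean described in the earlier comment: for a center x with prior count n receiving new points d_1, ..., d_k, the scatter-add applies

$$x' = x + \frac{\sum_{i=1}^{k} d_i - k\,x}{n + k} = \frac{n\,x + \sum_{i=1}^{k} d_i}{n + k}.$$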
@ -524,14 +509,10 @@ class KMeans(object):
# scale by 1 / (n + k), see comment above.
cluster_center_updates *= learning_rate
# Apply the updates.
update_counts = state_ops.scatter_add(
total_counts,
unique_ids,
count_updates)
update_counts = state_ops.scatter_add(total_counts, unique_ids,
count_updates)
update_cluster_centers = state_ops.scatter_add(
cluster_centers,
unique_ids,
cluster_center_updates)
cluster_centers, unique_ids, cluster_center_updates)
update_ops.extend([update_counts, update_cluster_centers])
return control_flow_ops.group(*update_ops)

@ -552,7 +533,7 @@ class KMeans(object):
cluster_counts = []
epsilon = constant_op.constant(1e-6, dtype=inputs[0].dtype)
for inp, cluster_idx in zip(inputs, cluster_idx_list):
with ops.colocate_with(inp):
with ops.colocate_with(inp, ignore_existing=True):
cluster_sums.append(
math_ops.unsorted_segment_sum(inp, cluster_idx, self._num_clusters))
cluster_counts.append(
@ -561,7 +542,7 @@ class KMeans(object):
array_ops.ones(
array_ops.reshape(array_ops.shape(inp)[0], [-1])),
[-1, 1]), cluster_idx, self._num_clusters))
with ops.colocate_with(cluster_centers):
with ops.colocate_with(cluster_centers, ignore_existing=True):
new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast(
math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon)
if self._clusters_l2_normalized():
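For contrast with the mini-batch rule, the full-batch op above recomputes each center directly as the mean of its assigned points, with the epsilon keeping empty clusters finite (here a_i denotes the cluster index assigned to point x_i, summed across all input shards):

$$c_j' = \frac{\sum_i x_i\,[a_i = j]}{\sum_i [a_i = j] + \varepsilon}$$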
@ -94,6 +94,7 @@ TEST(FfmpegLibTest, TestRoundTripGeneratedWav) {
}

std::vector<float> sine_wave;
sine_wave.reserve(20000);
for (int i = 0; i < 20000; ++i) {
sine_wave.push_back(std::sin(6.28 * 440.0 * i / 20000.0));
}
@ -494,6 +494,7 @@ class SparseFeatureCrossOp : public OpKernel {
ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
&feature_start_indices);

columns.reserve(values_list_in.size());
for (int i = 0; i < values_list_in.size(); ++i) {
columns.emplace_back(new SparseTensorColumn<InternalType>(
values_list_in[i], std::move(feature_counts[i]),
@ -308,6 +308,7 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_rea
from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat
from tensorflow.contrib.learn.python.learn.estimators.head import binary_svm_head
from tensorflow.contrib.learn.python.learn.estimators.head import Head
from tensorflow.contrib.learn.python.learn.estimators.head import loss_only_head
from tensorflow.contrib.learn.python.learn.estimators.head import multi_class_head
from tensorflow.contrib.learn.python.learn.estimators.head import multi_head
from tensorflow.contrib.learn.python.learn.estimators.head import multi_label_head
@ -429,6 +429,23 @@ def multi_label_head(n_classes,
loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None)


def loss_only_head(loss_fn, head_name=None):
"""Creates a Head that contains only loss terms.

Loss only head holds additional loss terms to be added to other heads and
usually represents additional regularization terms in the objective function.

Args:
loss_fn: a function that takes no argument and returns a list of
scalar tensors.
head_name: a name for the head.

Returns:
An instance of `Head` to hold the additional losses.
"""
return _LossOnlyHead(loss_fn, head_name=head_name)


def multi_head(heads, loss_weights=None):
"""Creates a MultiHead stemming from same logits/hidden layer.

@ -1406,6 +1423,80 @@ class _MultiLabelHead(_SingleHead):
return metrics


class _LossOnlyHead(Head):
"""`Head` implementation for additional loss terms.

This class only holds loss terms unrelated to any other heads (labels),
e.g. regularization.

Common usage:
This is often combined with other heads in a multi-head setup.
```python
head = multi_head([
head1, head2, loss_only_head(regularizer, head_name='regularizer')])
```
"""

def __init__(self, loss_fn, head_name=None):
self._loss_fn = loss_fn
self.head_name = head_name or "loss_only_head"

@property
def logits_dimension(self):
return 0

def create_model_fn_ops(self,
features,
mode,
labels=None,
train_op_fn=None,
logits=None,
logits_input=None,
scope=None):
"""See `_Head.create_model_fn_ops`.

Args:
features: Not used.
mode: Estimator's `ModeKeys`.
labels: Labels `Tensor`, or `dict` of same.
train_op_fn: Function that takes a scalar loss and returns an op to
optimize with the loss.
logits: Not used.
logits_input: Not used.
scope: Optional scope for variable_scope. If provided, will be passed to
all heads. Most users will want to set this to `None`, so each head
constructs a separate variable_scope according to its `head_name`.

Returns:
A `ModelFnOps` object.

Raises:
ValueError: if `mode` is not recognized.
"""
_check_mode_valid(mode)
loss = None
train_op = None
if mode != model_fn.ModeKeys.INFER:
with variable_scope.variable_scope(scope, default_name=self.head_name):
loss = self._loss_fn()
if isinstance(loss, list):
loss = math_ops.add_n(loss)
logging_ops.scalar_summary(
_summary_key(self.head_name, mkey.LOSS), loss)
if mode == model_fn.ModeKeys.TRAIN:
if train_op_fn is None:
raise ValueError("train_op_fn can not be None in TRAIN mode")
with ops.name_scope(None, "train_op", (loss,)):
train_op = train_op_fn(loss)

return model_fn.ModelFnOps(
mode=mode,
loss=loss,
train_op=train_op,
predictions={},
eval_metric_ops={})


class _MultiHead(Head):
"""`Head` implementation for multi objective learning.

@ -1525,7 +1616,10 @@ class _MultiHead(Head):
if isinstance(logits, dict):
head_logits_pairs = []
for head in self._heads:
head_logits_pairs.append((head, logits[head.head_name]))
if isinstance(head, _LossOnlyHead):
head_logits_pairs.append((head, None))
else:
head_logits_pairs.append((head, logits[head.head_name]))
else:
# Split logits for each head.
head_logits_pairs = zip(self._heads, self._split_logits(logits))
@ -1606,6 +1700,8 @@ class _MultiHead(Head):
predictions = {}
output_alternatives = {}
for head, m in zip(self._heads, all_model_fn_ops):
if isinstance(head, _LossOnlyHead):
continue
head_name = head.head_name
output_alternatives[head_name] = m.output_alternatives[head_name]
for k, v in m.predictions.items():
@ -1638,6 +1638,21 @@ class BinarySvmHeadTest(test.TestCase):
}, model_fn_ops)


class LossOnlyHead(test.TestCase):

def testNoPredictionsAndNoMetrics(self):
head = head_lib.loss_only_head(lambda: 1, head_name="const")
model_fn_ops = head.create_model_fn_ops(
features={},
mode=model_fn.ModeKeys.TRAIN,
train_op_fn=head_lib.no_op_train_fn)
self.assertDictEqual(model_fn_ops.predictions, {})
self.assertDictEqual(model_fn_ops.eval_metric_ops, {})
self.assertIsNotNone(model_fn_ops.loss)
with session.Session() as sess:
self.assertEqual(1, sess.run(model_fn_ops.loss))


class MultiHeadTest(test.TestCase):

def testInvalidHeads(self):
@ -1672,7 +1687,8 @@ class MultiHeadTest(test.TestCase):
n_classes=3, label_name="label1", head_name="head1")
head2 = head_lib.multi_class_head(
n_classes=4, label_name="label2", head_name="head2")
head = head_lib.multi_head((head1, head2))
head3 = head_lib.loss_only_head(lambda: 1.0, head_name="const")
head = head_lib.multi_head((head1, head2, head3))
labels = {
"label1": (1,),
"label2": (1,)
@ -1691,7 +1707,7 @@ class MultiHeadTest(test.TestCase):
self.assertIsNone(model_fn_ops.output_alternatives)

with session.Session() as sess:
self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3)
self.assertAlmostEqual(3.224, sess.run(model_fn_ops.loss), places=3)

def testTrain_withHeadWeights(self):
head1 = head_lib.multi_class_head(
@ -871,7 +871,7 @@ def index_table_from_file(vocabulary_file=None,
```

Args:
vocabulary_file: The vocabulary filename.
vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
num_oov_buckets: The number of out-of-vocabulary buckets.
vocab_size: Number of the elements in the vocabulary, if known.
default_value: The value to use for out-of-vocabulary feature values.
@ -889,8 +889,9 @@ def index_table_from_file(vocabulary_file=None,
ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
than zero.
"""
if not vocabulary_file:
raise ValueError("vocabulary_file must be specified.")
if vocabulary_file is None or (
isinstance(vocabulary_file, str) and not vocabulary_file):
raise ValueError("vocabulary_file must be specified and must not be empty.")
if num_oov_buckets < 0:
raise ValueError("num_oov_buckets must be greater than or equal to 0, got %d."
% num_oov_buckets)
@ -1187,6 +1187,18 @@ class IndexTableFromFile(test.TestCase):
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())

def test_string_index_table_from_file_tensor_filename(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
with self.test_session():
vocabulary_file = constant_op.constant(vocabulary_file)
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

self.assertRaises(errors_impl.OpError, ids.eval)
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())

def test_int32_index_table_from_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab2.txt", values=("42", "1", "-1000"))
@ -1245,7 +1257,13 @@ class IndexTableFromFile(test.TestCase):
860), # 3 + fingerprint("toccata") mod 300.
ids.eval())

def test_index_table_from_file_with_only_oov_buckets(self):
def test_index_table_from_file_fails_with_empty_vocabulary_file_name(self):
self.assertRaises(
ValueError,
lookup.index_table_from_file,
vocabulary_file="")

def test_index_table_from_file_fails_with_empty_vocabulary(self):
self.assertRaises(
ValueError,
lookup.index_table_from_file,
@ -23,6 +23,7 @@ See the @{$python/contrib.metrics} guide.
@@streaming_precision
@@streaming_precision_at_thresholds
@@streaming_auc
@@streaming_curve_points
@@streaming_recall_at_k
@@streaming_mean_absolute_error
@@streaming_mean_iou
@ -76,6 +77,7 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_accuracy
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_curve_points
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives_at_thresholds
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positives
@ -733,6 +733,102 @@ def streaming_true_negatives_at_thresholds(
|
||||
return values['tn'], update_ops['tn']
|
||||
|
||||
|
||||
def streaming_curve_points(labels=None,
|
||||
predictions=None,
|
||||
weights=None,
|
||||
num_thresholds=200,
|
||||
metrics_collections=None,
|
||||
updates_collections=None,
|
||||
curve='ROC',
|
||||
name=None):
|
||||
"""Computes curve (ROC or PR) values for a prespecified number of points.
|
||||
|
||||
  The `streaming_curve_points` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  that are used to compute the curve values. To discretize the curve, a linearly
  spaced set of thresholds is used to compute pairs of recall and precision
  values.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.

  Returns:
    points: A `Tensor` with shape [num_thresholds, 2] that contains points of
      the curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(name, 'curve_points', (labels, predictions,
                                                            weights)):
    if curve != 'ROC' and curve != 'PR':
      raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
    kepsilon = 1e-7  # to account for floating point imprecisions
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds - 2)]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    values, update_ops = _streaming_confusion_matrix_at_thresholds(
        labels=labels,
        predictions=predictions,
        thresholds=thresholds,
        weights=weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6

    def compute_points(tp, fn, tn, fp):
      """Computes the roc-auc or pr-auc based on confusion counts."""
      rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        fp_rate = math_ops.div(fp, fp + tn + epsilon)
        return fp_rate, rec
      else:  # curve == 'PR'.
        prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
        return rec, prec

    xs, ys = compute_points(values['tp'], values['fn'], values['tn'],
                            values['fp'])
    points = array_ops.stack([xs, ys], axis=1)
    update_op = control_flow_ops.group(*update_ops.values())

    if metrics_collections:
      ops.add_to_collections(metrics_collections, points)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return points, update_op

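A minimal usage sketch of the metric above (not part of the diff): it assumes the function is exported as `tf.contrib.metrics.streaming_curve_points`, consistent with the `__all__` entry added below, and uses illustrative tensors.

```python
import tensorflow as tf

# Illustrative labels/scores; any same-shaped pair with scores in [0, 1] works.
labels = tf.constant([[1.0, 0.0, 1.0, 1.0]])
predictions = tf.constant([[0.9, 0.4, 0.7, 0.2]])

points, update_op = tf.contrib.metrics.streaming_curve_points(
    labels=labels, predictions=predictions, num_thresholds=3, curve='PR')

with tf.Session() as sess:
  # The confusion-count accumulators are local variables.
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)      # accumulate counts from this batch
  print(sess.run(points))  # shape [3, 2]: (recall, precision) pairs
```
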
def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
                  metrics_collections=None, updates_collections=None,
                  curve='ROC', name=None):

@ -2372,6 +2468,7 @@ __all__ = [
    'sparse_recall_at_top_k',
    'streaming_accuracy',
    'streaming_auc',
    'streaming_curve_points',
    'streaming_false_negatives',
    'streaming_false_negatives_at_thresholds',
    'streaming_false_positives',

@ -1327,6 +1327,99 @@ class StreamingRecallTest(test.TestCase):
    self.assertEqual(0, recall.eval())


class StreamingCurvePointsTest(test.TestCase):

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  def testVars(self):
    metric_ops.streaming_curve_points(
        predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
    _assert_local_variables(
        self,
        ('curve_points/true_positives:0', 'curve_points/false_negatives:0',
         'curve_points/false_positives:0', 'curve_points/true_negatives:0'))

  def testMetricsCollection(self):
    my_collection_name = '__metrics__'
    points, _ = metric_ops.streaming_curve_points(
        labels=array_ops.ones((10, 1)),
        predictions=array_ops.ones((10, 1)),
        metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [points])

  def testUpdatesCollection(self):
    my_collection_name = '__updates__'
    _, update_op = metric_ops.streaming_curve_points(
        labels=array_ops.ones((10, 1)),
        predictions=array_ops.ones((10, 1)),
        updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  def _testValueTensorIsIdempotent(self, curve):
    predictions = constant_op.constant(
        np.random.uniform(size=(10, 3)), dtype=dtypes_lib.float32)
    labels = constant_op.constant(
        np.random.uniform(high=2, size=(10, 3)), dtype=dtypes_lib.float32)

    points, update_op = metric_ops.streaming_curve_points(
        labels, predictions=predictions, curve=curve)

    with self.test_session() as sess:
      sess.run(variables.local_variables_initializer())

      sess.run(update_op)
      initial_points = points.eval()

      sess.run(update_op)
      self.assertAllClose(initial_points, points.eval())

  def testValueTensorIsIdempotentROC(self):
    self._testValueTensorIsIdempotent(curve='ROC')

  def testValueTensorIsIdempotentPR(self):
    self._testValueTensorIsIdempotent(curve='PR')

  def _testCase(self, labels, predictions, curve, expected_points):
    with self.test_session() as sess:
      predictions_tensor = constant_op.constant(
          predictions, dtype=dtypes_lib.float32)
      labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.float32)
      points, update_op = metric_ops.streaming_curve_points(
          labels=labels_tensor,
          predictions=predictions_tensor,
          num_thresholds=3,
          curve=curve)

      sess.run(variables.local_variables_initializer())
      sess.run(update_op)

      self.assertAllClose(expected_points, points.eval())

  def testEdgeCasesROC(self):
    self._testCase([[1]], [[1]], 'ROC', [[0, 1], [0, 1], [0, 0]])
    self._testCase([[0]], [[0]], 'ROC', [[1, 1], [0, 1], [0, 1]])
    self._testCase([[0]], [[1]], 'ROC', [[1, 1], [1, 1], [0, 1]])
    self._testCase([[1]], [[0]], 'ROC', [[0, 1], [0, 0], [0, 0]])

  def testManyValuesROC(self):
    self._testCase([[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]],
                   [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], 'ROC',
                   [[1.0, 1.0], [0.0, 0.75], [0.0, 0.0]])

  def testEdgeCasesPR(self):
    self._testCase([[1]], [[1]], 'PR', [[1, 1], [1, 1], [0, 1]])
    self._testCase([[0]], [[0]], 'PR', [[1, 0], [1, 1], [1, 1]])
    self._testCase([[0]], [[1]], 'PR', [[1, 0], [1, 0], [1, 1]])
    self._testCase([[1]], [[0]], 'PR', [[1, 1], [0, 1], [0, 1]])

  def testManyValuesPR(self):
    self._testCase([[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]],
                   [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], 'PR',
                   [[1.0, 4.0 / 6.0], [0.75, 1.0], [0.0, 1.0]])
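
The expected points in `testManyValuesROC` and `testManyValuesPR` can be reproduced by hand. A small NumPy check of the same arithmetic (illustrative, not part of the test file; `eps` mirrors the epsilon that `compute_points` adds to avoid 0/0):

```python
import numpy as np

labels = np.array([1, 0, 0, 1, 1, 1])
scores = np.array([0.2, 0.3, 0.4, 0.6, 0.7, 0.8])
eps = 1e-6  # same epsilon compute_points uses

for t in [0.0, 0.5, 1.0]:  # num_thresholds=3 gives thresholds ~[0, 0.5, 1]
  pred = scores > t
  tp = np.sum(pred & (labels == 1))
  fp = np.sum(pred & (labels == 0))
  fn = np.sum(~pred & (labels == 1))
  tn = np.sum(~pred & (labels == 0))
  rec = (tp + eps) / (tp + fn + eps)
  prec = (tp + eps) / (tp + fp + eps)   # PR point is (rec, prec)
  fp_rate = fp / (fp + tn + eps)        # ROC point is (fp_rate, rec)
  print('t=%.1f  ROC=(%.2f, %.2f)  PR=(%.2f, %.4f)' % (t, fp_rate, rec, rec, prec))
# ROC: (1, 1), (0, 0.75), (0, 0)   PR: (1, 4/6), (0.75, 1), (0, 1)
```
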


class StreamingAUCTest(test.TestCase):

  def setUp(self):

@ -226,8 +226,8 @@ class TestBeamStep(test.TestCase):
class BeamSearchDecoderTest(test.TestCase):

  def _testDynamicDecodeRNN(self, time_major, has_attention):
    encoder_sequence_length = [3, 2, 3, 1, 1]
    decoder_sequence_length = [2, 0, 1, 2, 3]
    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
    batch_size = 5
    decoder_max_time = 4
    input_depth = 7

@ -245,6 +245,7 @@ class BeamSearchDecoderTest(test.TestCase):
    batch_size_tensor = constant_op.constant(batch_size)
    embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
    cell = rnn_cell.LSTMCell(cell_depth)
    initial_state = cell.zero_state(batch_size, dtypes.float32)
    if has_attention:
      inputs = array_ops.placeholder_with_default(
          np.random.randn(batch_size, decoder_max_time,

@ -258,6 +259,8 @@ class BeamSearchDecoderTest(test.TestCase):
          num_units=attention_depth,
          memory=tiled_inputs,
          memory_sequence_length=tiled_sequence_length)
      initial_state = beam_search_decoder.tile_batch(
          initial_state, multiplier=beam_width)
      cell = attention_wrapper.AttentionWrapper(
          cell=cell,
          attention_mechanism=attention_mechanism,

@ -265,6 +268,9 @@ class BeamSearchDecoderTest(test.TestCase):
        alignment_history=False)
    cell_state = cell.zero_state(
        dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
    if has_attention:
      cell_state = cell_state.clone(
          cell_state=initial_state)
    bsd = beam_search_decoder.BeamSearchDecoder(
        cell=cell,
        embedding=embedding,

@ -72,10 +72,30 @@ class FinalBeamSearchDecoderOutput(
  pass


def tile_batch(t, multiplier, name=None):
  """Tile the batch dimension of tensor t.
def _tile_batch(t, multiplier):
  """Core single-tensor implementation of tile_batch."""
  t = ops.convert_to_tensor(t, name="t")
  shape_t = array_ops.shape(t)
  if t.shape.ndims is None or t.shape.ndims < 1:
    raise ValueError("t must have statically known rank")
  tiling = [1] * (t.shape.ndims + 1)
  tiling[1] = multiplier
  tiled_static_batch_size = (
      t.shape[0].value * multiplier if t.shape[0].value is not None else None)
  tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling)
  tiled = array_ops.reshape(
      tiled, array_ops.concat(([shape_t[0] * multiplier], shape_t[1:]), 0))
  tiled.set_shape(
      tensor_shape.TensorShape(
          [tiled_static_batch_size]).concatenate(t.shape[1:]))
  return tiled

  This function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of

def tile_batch(t, multiplier, name=None):
  """Tile the batch dimension of a (possibly nested structure of) tensor(s) t.

  For each tensor t in a (possibly nested structure) of tensors,
  this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
  minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape
  `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries
  `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated

@ -87,27 +107,16 @@ def tile_batch(t, multiplier, name=None):
    name: Name scope for any created operations.

  Returns:
    A `Tensor` shaped `[batch_size * multiplier, ...]`.
    A (possibly nested structure of) `Tensor` shaped
    `[batch_size * multiplier, ...]`.

  Raises:
    ValueError: if `t` does not have a statically known rank or it's < 1.
    ValueError: if tensor(s) `t` do not have a statically known rank or
      the rank is < 1.
  """
  with ops.name_scope(name, "tile_batch", [t, multiplier]):
    t = ops.convert_to_tensor(t, name="t")
    shape_t = array_ops.shape(t)
    if t.shape.ndims is None or t.shape.ndims < 1:
      raise ValueError("t must have statically known rank")
    tiling = [1] * (t.shape.ndims + 1)
    tiling[1] = multiplier
    tiled_static_batch_size = (
        t.shape[0].value * multiplier if t.shape[0].value is not None else None)
    tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling)
    tiled = array_ops.reshape(
        tiled, array_ops.concat(([shape_t[0] * multiplier], shape_t[1:]), 0))
    tiled.set_shape(
        tensor_shape.TensorShape(
            [tiled_static_batch_size]).concatenate(t.shape[1:]))
    return tiled
  flat_t = nest.flatten(t)
  with ops.name_scope(name, "tile_batch", flat_t + [multiplier]):
    return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t)

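A short sketch of what this refactor enables: `tile_batch` now accepts a nested structure and tiles every tensor in it. This example is illustrative and not part of the diff; the shapes are made up, and the import path follows the module the tests above use (`beam_search_decoder.tile_batch`).

```python
import tensorflow as tf
from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder

# A nested structure of tensors that share batch_size = 2.
state = (tf.zeros([2, 5]), tf.ones([2, 7, 3]))
tiled = beam_search_decoder.tile_batch(state, multiplier=3)

# Static shapes after tiling; entries are ordered t[0], t[0], t[0], t[1], ...
print(tiled[0].shape)  # (6, 5)
print(tiled[1].shape)  # (6, 7, 3)
```
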
def _check_maybe(t):

@ -270,7 +270,7 @@ class SessionBundleTest : public ::testing::Test {
  // MetaGraphDef.
  // Returns the path of the export.
  // ** Should only be called once per test **
  string SetupExport(MetaGraphDefTwiddler twiddler) {
  string SetupExport(const MetaGraphDefTwiddler& twiddler) {
    return SetupExport(twiddler, kVariablesFilename, kMetaGraphDefFilename);
  }
  // SetupExport that allows for the variables and meta_graph_def filenames

@ -62,6 +62,7 @@ licenses(["notice"])  # Apache 2.0

load(
    "//tensorflow:tensorflow.bzl",
    "full_path",
    "if_android",
    "if_ios",
    "if_x86",

@ -30,7 +30,11 @@ Device::Device(Env* env, const DeviceAttributes& device_attributes)
  rmgr_ = new ResourceMgr(parsed_name_.job);
}

Device::~Device() { delete rmgr_; }
Device::~Device() {
  if (rmgr_ != nullptr) {
    DeleteResourceMgr();
  }
}

// static
DeviceAttributes Device::BuildDeviceAttributes(

@ -60,7 +60,9 @@ class Device : public DeviceBase {
  const string& name() const { return device_attributes_.name(); }

  // Parsed name of this device
  const DeviceNameUtils::ParsedName parsed_name() const { return parsed_name_; }
  const DeviceNameUtils::ParsedName& parsed_name() const {
    return parsed_name_;
  }

  // Describes what kind of device this is. This is intended to be
  // human-readable and not computer-parsed, except that two devices

@ -149,6 +151,12 @@ class Device : public DeviceBase {
    return BuildDeviceAttributes(name, device, memory_limit, locality, "");
  }

 protected:
  void DeleteResourceMgr() {
    delete rmgr_;
    rmgr_ = nullptr;
  }

 private:
  const DeviceAttributes device_attributes_;
  DeviceNameUtils::ParsedName parsed_name_;

@ -53,7 +53,7 @@ Device* DeviceSet::FindDeviceByName(const string& name) const {

// static
int DeviceSet::DeviceTypeOrder(const DeviceType& d) {
  return DeviceFactory::DevicePriority(d.type());
  return DeviceFactory::DevicePriority(d.type_string());
}

static bool DeviceTypeComparator(const DeviceType& a, const DeviceType& b) {

@ -1231,7 +1231,7 @@ Status FunctionDefToBodyHelper(
  GraphConstructorOptions opts;
  opts.allow_internal_ops = true;
  opts.expect_device_spec = false;
  Status s = ConvertGraphDefToGraph(opts, result.gdef, graph);
  Status s = ConvertNodeDefsToGraph(opts, result.nodes, graph);
  if (!s.ok()) {
    delete graph;
  } else {

@ -93,7 +93,7 @@ class FunctionTest : public ::testing::Test {
    GraphConstructorOptions opts;
    opts.allow_internal_ops = true;
    opts.expect_device_spec = false;
    TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g));
    TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g));

    const int version = g->versions().producer();
    LocalExecutorParams params;

@ -949,7 +949,7 @@ GraphDef Optimize(const std::function<bool(Graph* g)>& pass,
  GraphConstructorOptions opts;
  opts.allow_internal_ops = true;
  opts.expect_device_spec = false;
  TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g.get()));
  TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g.get()));
  pass(g.get());
  std::unique_ptr<Graph> g1(new Graph(OpRegistry::Global()));
  CopyGraph(*g, g1.get());

@ -324,6 +324,7 @@ static void BM_AllocationDelayed(int iters, int delay) {
  int size_index = 0;

  std::vector<void*> ptrs;
  ptrs.reserve(delay);
  for (int i = 0; i < delay; i++) {
    ptrs.push_back(nullptr);
  }

@ -123,10 +123,12 @@ void Benchmark::RunWithArgs(
  }
  // Gets inputs' and outputs' rendezvous keys.
  std::vector<std::pair<string, Tensor>> in;
  in.reserve(inputs.size());
  for (const auto& p : inputs) {
    in.push_back({GetRendezvousKey(p.first), p.second});
  }
  std::vector<string> out;
  out.reserve(outputs.size());
  for (const auto& n : outputs) {
    out.push_back(GetRendezvousKey(n));
  }

@ -94,6 +94,7 @@ Status SessionFactory::GetFactory(const SessionOptions& options,
  // TODO(mrry): Consider providing a system-default fallback option
  // in this case.
  std::vector<string> factory_types;
  factory_types.reserve(candidate_factories.size());
  for (const auto& candidate_factory : candidate_factories) {
    factory_types.push_back(candidate_factory.first);
  }