fix merge issues
commit 5efd272aab
@@ -41,6 +41,15 @@
   be replaced by calling `embedding_lookup` or `layers.dense` as pre- or post-
   processing of the rnn. For RNN decoding, this functionality has been replaced
   with an alternative API in `tf.contrib.seq2seq`.
+* Intel MKL Integration (https://software.intel.com/en-us/articles/tensorflow-optimizations-on-modern-intel-architecture). Intel developed a number of
+  optimized deep learning primitives: In addition to matrix multiplication and
+  convolution, these building blocks include:
+  Direct batched convolution
+  Pooling: maximum, minimum, average
+  Normalization: LRN, batch normalization
+  Activation: rectified linear unit (ReLU)
+  Data manipulation: multi-dimensional transposition (conversion), split,
+  concat, sum and scale.
 * TensorForest Estimator now supports SavedModel export for serving.
 * Support client-provided ClusterSpec's and propagate them to all workers to enable the creation of dynamic TensorFlow clusters.
 * TensorFlow C library now available for Windows.
@@ -2,11 +2,11 @@ workspace(name = "org_tensorflow")

 http_archive(
     name = "io_bazel_rules_closure",
-    sha256 = "4be8a887f6f38f883236e77bb25c2da10d506f2bf1a8e5d785c0f35574c74ca4",
-    strip_prefix = "rules_closure-aac19edc557aec9b603cd7ffe359401264ceff0d",
+    sha256 = "edc91f556b762fc5212d1050d00b12e40dd0b0b1c1d5d96886b59e9a30a6cae4",
+    strip_prefix = "rules_closure-3f07fb6a58870afbb36051bd5d54da4479561cc6",
     urls = [
-        "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/aac19edc557aec9b603cd7ffe359401264ceff0d.tar.gz", # 2017-05-10
-        "https://github.com/bazelbuild/rules_closure/archive/aac19edc557aec9b603cd7ffe359401264ceff0d.tar.gz",
+        "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/3f07fb6a58870afbb36051bd5d54da4479561cc6.tar.gz", # 2017-05-31
+        "https://github.com/bazelbuild/rules_closure/archive/3f07fb6a58870afbb36051bd5d54da4479561cc6.tar.gz",
     ],
 )

@@ -393,6 +393,9 @@ filegroup(
         "//tensorflow/tensorboard/demo:all_files",
         "//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files",
         "//tensorflow/tensorboard/plugins:all_files",
+        "//tensorflow/tensorboard/plugins/audio:all_files",
+        "//tensorflow/tensorboard/plugins/distributions:all_files",
+        "//tensorflow/tensorboard/plugins/graphs:all_files",
         "//tensorflow/tensorboard/plugins/histograms:all_files",
         "//tensorflow/tensorboard/plugins/images:all_files",
         "//tensorflow/tensorboard/plugins/projector:all_files",
@@ -805,6 +805,7 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
   }

   std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
+  dim_vec.reserve(num_dims);
   for (int i = 0; i < num_dims; ++i) {
     dim_vec.push_back(ic->MakeDim(dims[i]));
   }
@@ -113,10 +113,12 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs,
     feeds.emplace_back(feed.first.name(), feed.second.tensor);
   }
   std::vector<string> output_tensor_names;
+  output_tensor_names.reserve(fetch_outputs.size());
   for (auto const& output : fetch_outputs) {
     output_tensor_names.push_back(output.name());
   }
   std::vector<string> target_node_names;
+  target_node_names.reserve(run_outputs.size());
   for (auto const& output : run_outputs) {
     target_node_names.push_back(output.node()->name());
   }
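Note: many of the hunks above and below apply the same micro-optimization: when the final number of elements is known up front, call reserve() before a push_back loop so the vector allocates its storage once instead of reallocating as it grows. A minimal standalone C++ illustration of the pattern (illustrative only, not taken from this commit):

#include <string>
#include <vector>

std::vector<std::string> CopyNames(const std::vector<std::string>& inputs) {
  std::vector<std::string> names;
  names.reserve(inputs.size());  // single allocation up front
  for (const auto& input : inputs) {
    names.push_back(input);      // no reallocation inside the loop
  }
  return names;
}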
@@ -44,6 +44,7 @@ Status ComputeTheoreticalJacobianTranspose(
   size_t x_num = x_shapes.size();
   // Call AddSymbolicGradients to get 'dxs' (we will feed 'dys').
   OutputList dys;
+  dys.reserve(y_shapes.size());
   for (const auto& y_shape : y_shapes) {
     // TODO(suharshs): This currently assumes that all x's are the same type.
     dys.push_back(Cast(scope, Const(scope, 1.0, y_shape), xs[0].type()));
@@ -15,6 +15,8 @@ limitations under the License.

 #include "tensorflow/cc/framework/testutil.h"

+#include <utility>
+
 #include "tensorflow/cc/client/client_session.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/graph/default_device.h"
@@ -30,7 +32,7 @@ void GetTensors(const Scope& scope, OutputList tensors,

 void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
   std::vector<Tensor> outputs;
-  GetTensors(scope, {tensor}, &outputs);
+  GetTensors(scope, {std::move(tensor)}, &outputs);
   *out = outputs[0];
 }

@@ -350,6 +350,7 @@ Status CompileXla(xla::CompileOnlyClient* client,
   compile_result->program_shape = *pshape_or.ValueOrDie();
   xla::ProgramShape* pshape = &compile_result->program_shape;
   std::vector<const xla::Shape*> arg_layouts;
+  arg_layouts.reserve(pshape->parameters_size());
   for (int i = 0; i < pshape->parameters_size(); ++i) {
     arg_layouts.push_back(pshape->mutable_parameters(i));
   }
@@ -218,6 +218,7 @@ cc_library(
     deps = [
         ":common",
         ":graph_to_functiondef",
+        ":union_find",
        "//tensorflow/compiler/jit/graphcycles",
        "//tensorflow/compiler/jit/kernels:parallel_check_op",
        "//tensorflow/compiler/jit/kernels:xla_local_launch_op",
@@ -237,6 +238,11 @@ cc_library(
     ],
 )

+cc_library(
+    name = "union_find",
+    hdrs = ["union_find.h"],
+)
+
 cc_test(
     name = "compilation_passes_test",
     size = "small",
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+#include <utility>
+
 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"

 #include "tensorflow/cc/framework/ops.h"
@@ -101,12 +103,12 @@ Node* Input(const GraphDefBuilder::Options& opts) {
 }

 Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) {
-  return ops::UnaryOp("UnaryTest", a, opts);
+  return ops::UnaryOp("UnaryTest", std::move(a), opts);
 }

 Node* Binary(ops::NodeOut a, ops::NodeOut b,
              const GraphDefBuilder::Options& opts) {
-  return ops::BinaryOp("BinaryTest", a, b, opts);
+  return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts);
 }

 Node* AddNLike(const std::vector<ops::NodeOut>& inputs,
@@ -127,7 +129,7 @@ Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) {
   if (opts.HaveError()) return nullptr;
   NodeBuilder node_builder(opts.GetNameForOp("Retval"), "_Retval",
                            opts.op_registry());
-  node_builder.Input(a).Attr("index", index);
+  node_builder.Input(std::move(a)).Attr("index", index);
   return opts.FinalizeBuilder(&node_builder);
 }

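Note: another recurring change in this commit is forwarding by-value parameters with std::move (as in GetTensors, UnaryOp, BinaryOp and NodeBuilder::Input above), so the argument that was already copied or moved into the parameter is transferred onward instead of being copied a second time. A minimal standalone C++ sketch of the idiom, illustrative only and not taken from the commit:

#include <string>
#include <utility>
#include <vector>

// 'name' is taken by value; the caller's argument is copied or moved into it,
// and std::move then transfers that buffer into the vector without a second copy.
void AppendName(std::vector<std::string>* names, std::string name) {
  names->push_back(std::move(name));
}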
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
 #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
+#include "tensorflow/compiler/jit/union_find.h"
 #include "tensorflow/compiler/tf2xla/dump_graph.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/core/common_runtime/function.h"
@@ -206,70 +207,12 @@ Status FindCompilationCandidates(
   return Status::OK();
 }

-// Union-Find data structure used to compute clusters. We use our own
-// implementation because we want one key feature: when merging clusters, we
-// need to know which value becomes the representative of the merged clusters.
-// We use the representatives to name nodes in a cycle detection graph, and we
-// need to control which node is named.
-// TODO(phawkins): consider merging this code with union-find implementations
-// in Tensorflow, e.g., in SimplePlacer.
-class Cluster {
- public:
-  Cluster();
-
-  int Size() { return FindRoot()->size_; }
-
-  // Merges this cluster with 'other'. This cluster's representative becomes
-  // the representative of the merged cluster; the representative of 'other'
-  // is ignored.
-  void Merge(Cluster* other);
-
-  // Each cluster has an associated integer 'representative', initialized to -1
-  // by default.
-  int GetRepresentative() { return FindRoot()->representative_; }
-  void SetRepresentative(int representative) {
-    FindRoot()->representative_ = representative;
-  }
-
- private:
-  // Finds the root element of the cluster. Performs path compression.
-  Cluster* FindRoot();
-
-  int representative_;
-  int rank_;
-  int size_;  // Size of the cluster.
-  Cluster* parent_;
+struct Cluster {
+  // Identifies the node that represents this cluster in the cycle detection
+  // graph.
+  int representative = -1;
 };
-
-Cluster::Cluster()
-    : representative_(-1), rank_(0), size_(1), parent_(nullptr) {}
-
-void Cluster::Merge(Cluster* other) {
-  Cluster* a = FindRoot();
-  Cluster* b = other->FindRoot();
-  if (a == b) return;
-  if (a->rank_ > b->rank_) {
-    b->parent_ = a;
-    a->size_ += b->size_;
-    return;
-  }
-
-  a->parent_ = b;
-  if (a->rank_ == b->rank_) {
-    b->rank_++;
-  }
-  b->representative_ = a->representative_;
-  b->size_ += a->size_;
-}
-
-Cluster* Cluster::FindRoot() {
-  if (!parent_) return this;
-  // Path compression: update intermediate nodes to point to the root of the
-  // equivalence class.
-  parent_ = parent_->FindRoot();
-  return parent_;
-}

 } // anonymous namespace

 bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
@@ -432,10 +375,11 @@ Status MarkForCompilationPass::RunImpl(
   // Each compilation candidate belongs to a cluster. The cluster's
   // representative
   // names the node in the 'cycles' graph that represents the cluster.
-  std::vector<Cluster> clusters(graph->num_node_ids());
-  std::deque<Cluster*> worklist;
+  std::vector<UnionFind<Cluster>> clusters(graph->num_node_ids());
+  std::deque<UnionFind<Cluster>*> worklist;
   for (Node* node : compilation_candidates) {
-    clusters[node->id()].SetRepresentative(node->id());
+    Cluster& cluster = clusters[node->id()].Get();
+    cluster.representative = node->id();
     worklist.push_back(&clusters[node->id()]);
   }

@@ -445,7 +389,7 @@ Status MarkForCompilationPass::RunImpl(
   // Repeatedly contract edges between clusters that are on the same device,
   // provided the contraction would not create a cycle.
   while (!worklist.empty()) {
-    int from = worklist.front()->GetRepresentative();
+    int from = worklist.front()->Get().representative;
     worklist.pop_front();

     Node* node_from = graph->FindNodeId(from);
@@ -518,7 +462,7 @@ Status MarkForCompilationPass::RunImpl(
   // Count the number of elements in each cluster.
   std::vector<int> cluster_sizes(graph->num_node_ids());
   for (const Node* n : compilation_candidates) {
-    int cluster = clusters[n->id()].GetRepresentative();
+    int cluster = clusters[n->id()].Get().representative;
     cluster_sizes[cluster]++;
   }

@@ -532,7 +476,7 @@ Status MarkForCompilationPass::RunImpl(
   // if compilation is enabled, otherwise there will be no such candidates).
   const int min_cluster_size = flags->tf_xla_min_cluster_size;
   for (Node* n : compilation_candidates) {
-    int cluster = clusters[n->id()].GetRepresentative();
+    int cluster = clusters[n->id()].Get().representative;

     // Compile if the user marked this node _XlaCompile=true
     bool compile_attr = false;
tensorflow/compiler/jit/union_find.h (new file, 81 lines)
@@ -0,0 +1,81 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_UNION_FIND_H_
+#define TENSORFLOW_COMPILER_JIT_UNION_FIND_H_
+
+namespace tensorflow {
+
+// Union-Find data structure.
+// Each cluster has an associated value; when merging clusters we can control
+// which value becomes the representative of the merged clusters. Values must be
+// copyable.
+template <typename T>
+class UnionFind {
+ public:
+  UnionFind() : rank_(0), size_(1), parent_(nullptr) {}
+
+  // Returns the number of elements in a cluster.
+  int Size() { return FindRoot()->size_; }
+
+  // Merges this cluster with 'other'. This cluster's value becomes
+  // the value of the merged cluster; the value of 'other' is ignored.
+  void Merge(UnionFind* other);
+
+  // Each cluster has an associated value. Retrieves the value associated
+  // with this cluster.
+  T& Get() { return FindRoot()->value_; }
+
+ private:
+  // Finds the root element of the cluster. Performs path compression.
+  UnionFind* FindRoot();
+
+  int rank_;
+  int size_;  // Size of the cluster.
+  UnionFind* parent_;
+  T value_;
+};
+
+template <typename T>
+void UnionFind<T>::Merge(UnionFind* other) {
+  UnionFind<T>* a = FindRoot();
+  UnionFind<T>* b = other->FindRoot();
+  if (a == b) return;
+  if (a->rank_ > b->rank_) {
+    b->parent_ = a;
+    a->size_ += b->size_;
+    return;
+  }
+
+  a->parent_ = b;
+  if (a->rank_ == b->rank_) {
+    b->rank_++;
+  }
+  b->value_ = a->value_;
+  b->size_ += a->size_;
+}
+
+template <typename T>
+UnionFind<T>* UnionFind<T>::FindRoot() {
+  if (!parent_) return this;
+  // Path compression: update intermediate nodes to point to the root of the
+  // equivalence class.
+  parent_ = parent_->FindRoot();
+  return parent_;
+}
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_JIT_UNION_FIND_H_
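Note: an illustrative sketch (not part of the commit) of how MarkForCompilationPass uses the new UnionFind<T> template together with the trimmed-down Cluster struct shown above. Merge() keeps the left-hand cluster's value, which is the property the old hand-written Cluster class existed to guarantee; the Example function and its assertions are hypothetical usage, not TensorFlow code.

#include <cassert>

#include "tensorflow/compiler/jit/union_find.h"

struct Cluster {
  // Node id that names this cluster in the cycle-detection graph.
  int representative = -1;
};

void Example() {
  tensorflow::UnionFind<Cluster> a, b;
  a.Get().representative = 1;
  b.Get().representative = 2;

  a.Merge(&b);  // 'a' absorbs 'b'; a's value survives the merge.

  assert(a.Size() == 2);
  assert(a.Get().representative == 1);
  assert(b.Get().representative == 1);  // both now resolve to the same root
}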
@@ -50,6 +50,7 @@ class FillOp : public XlaOpKernel {
     // Convert the dims literal into a vector that we can pass to
     // ComputationBuilder.
     std::vector<int64> broadcast;
+    broadcast.reserve(dims_literal.shape().dimensions(0));
     for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) {
       broadcast.push_back(xla::LiteralUtil::Get<int>(dims_literal, {i}));
     }
@@ -50,6 +50,7 @@ class SliceOp : public XlaOpKernel {
     // slice will be an empty handle if the output has no elements.
     CHECK_EQ(begin.size(), size.size());
     std::vector<int64> limits;
+    limits.reserve(begin.size());
     for (int i = 0; i < begin.size(); ++i) {
       limits.push_back(begin[i] + size[i]);
     }
@@ -18,6 +18,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
 #define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_

+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -58,14 +58,13 @@ StatusOr<std::unique_ptr<Literal>> Client::Transfer(
         "server provided response without a literal in "
         "TransferToClient request");
   }
-  return WrapUnique(response.release_literal());
+  return MakeUnique<Literal>(response.literal());
 }

 StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
     const Literal& literal, const DeviceHandle* device_handle) {
   TransferToServerRequest request;
-  *request.mutable_literal() = literal;
+  *request.mutable_literal() = literal.ToProto();
   if (device_handle) {
     *request.mutable_device_handle() = *device_handle;
   }
@@ -93,7 +92,7 @@ StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
 Status Client::TransferToInfeed(const Literal& literal, int64 replica_id,
                                 const DeviceHandle* device_handle) {
   TransferToInfeedRequest request;
-  *request.mutable_literal() = literal;
+  *request.mutable_literal() = literal.ToProto();
   if (device_handle) {
     *request.mutable_device_handle() = *device_handle;
   }
@@ -141,7 +140,8 @@ StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed(
         "TransferToClient request");
   }

-  return WrapUnique(response.release_literal());
+  Literal literal(response.literal());
+  return MakeUnique<Literal>(literal);
 }

 Status Client::ResetDevice() {
@@ -21,6 +21,7 @@ limitations under the License.

 #include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/session.pb.h"
 #include "tensorflow/compiler/xla/service_interface.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -165,9 +165,10 @@ ComputationDataHandle ComputationBuilder::ConstantOp(
   }

   ConstantRequest request;
-  Literal* literal = request.mutable_literal();
-  populate(literal);
-  VLOG(3) << "created constant: " << literal->ShortDebugString();
+  Literal literal;
+  populate(&literal);
+  *request.mutable_literal() = literal.ToProto();
+  VLOG(3) << "created constant: " << request.literal().ShortDebugString();
   OpRequest op_request;
   *op_request.mutable_constant_request() = request;
   *op_request.mutable_computation() = computation_.handle();
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/global_data.h"

 #include <string>
+#include <utility>

 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/platform/logging.h"
@@ -23,7 +24,7 @@ limitations under the License.
 namespace xla {

 GlobalData::GlobalData(ServiceInterface* parent, GlobalDataHandle handle)
-    : handle_(handle), parent_(parent) {}
+    : handle_(std::move(handle)), parent_(parent) {}

 GlobalData::~GlobalData() {
   UnregisterRequest request;
@@ -222,8 +222,9 @@ tensorflow::Status LocalExecutable::RecordArguments(
     SessionModule* session_module) {
   session_module->clear_arguments();
   for (const ShapedBuffer* argument : arguments) {
-    TF_RETURN_IF_ERROR(
-        LiteralFromShapedBuffer(*argument, session_module->add_arguments()));
+    Literal literal;
+    TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*argument, &literal));
+    *session_module->add_arguments() = literal.ToProto();
   }
   return tensorflow::Status::OK();
 }
@@ -231,9 +232,13 @@ tensorflow::Status LocalExecutable::RecordArguments(
 tensorflow::Status LocalExecutable::RecordResult(
     const ShapedBuffer* result, SessionModule* session_module) {
   session_module->clear_result();
-  return LiteralFromShapedBuffer(*result, session_module->mutable_result());
+  Literal literal(session_module->result());
+  TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*result, &literal));
+  *session_module->mutable_result() = literal.ToProto();
+  return tensorflow::Status::OK();
 }

+// TODO(dnovillo) Change signature to return StatusOr<Literal>.
 tensorflow::Status LocalExecutable::LiteralFromShapedBuffer(
     const ShapedBuffer& shaped_buffer, Literal* literal) {
   TF_ASSIGN_OR_RETURN(
[File diff suppressed because it is too large]
@@ -856,5 +856,26 @@ TEST_F(LiteralUtilTest, ConvertR4) {
   EXPECT_TRUE(LiteralUtil::Equal(*expected, *converted));
 }

+TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
+  LiteralProto p;
+  p.mutable_shape()->set_element_type(PRED);
+  for (int len = 0; len < 25; ++len) {
+    p.mutable_shape()->clear_dimensions();
+    p.mutable_shape()->add_dimensions(len);
+    p.clear_preds();
+    for (int i = 0; i < len; ++i) {
+      p.add_preds((i % 2) == (len % 2));
+    }
+
+    Literal literal(p);
+    ASSERT_EQ(len, literal.preds_size());
+    int i = 0;
+    for (auto it = literal.preds().begin(); it < literal.preds().end(); ++it) {
+      EXPECT_EQ((i % 2) == (len % 2), *it);
+      ++i;
+    }
+  }
+}
+
 } // namespace
 } // namespace xla
@@ -60,8 +60,8 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
   int64 elements = ShapeUtil::ElementsIn(shape);
   LiteralUtil::Resize(elements, std::numeric_limits<float>::quiet_NaN(),
                       result.get());
-  tensorflow::protobuf::RepeatedField<float>* field = result->mutable_f32s();
-  char* data = tensorflow::bit_cast<char*>(field->mutable_data());
+  std::vector<float>* field = result->mutable_f32s();
+  char* data = tensorflow::bit_cast<char*>(field->data());
   uint64 bytes = elements * sizeof(float);
   tensorflow::StringPiece sp;
   auto s = file_->Read(offset_, bytes, &sp, data);
@@ -18,6 +18,7 @@ limitations under the License.

 #include <memory>

+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/reference_util.h"

 #include <array>
+#include <utility>

 #include "tensorflow/compiler/xla/client/computation_builder.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
@@ -331,7 +332,8 @@ ReferenceUtil::ConvArray4DGeneralDimensions(
     std::pair<int64, int64> kernel_stride, Padding padding,
     ConvolutionDimensionNumbers dimension_numbers) {
   return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
-                                             {1, 1}, {1, 1}, dimension_numbers);
+                                             {1, 1}, {1, 1},
+                                             std::move(dimension_numbers));
 }

 /* static */ std::unique_ptr<Array4D<float>>
@@ -529,6 +529,7 @@ cc_library(
     srcs = ["transfer_manager.cc"],
     hdrs = ["transfer_manager.h"],
     deps = [
+        "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
@@ -1680,10 +1681,8 @@ cc_library(
     deps = [
         ":buffer_assignment",
         ":hlo",
-        ":hlo_ordering",
         ":hlo_proto",
         "//tensorflow/compiler/xla:status",
-        "//tensorflow/compiler/xla:util",
         "//tensorflow/core:lib",
     ],
 )
@@ -171,6 +171,7 @@ StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
       executor, allocation->device_memory(), allocation->shape()));

   std::vector<GlobalDataHandle> element_handles;
+  element_handles.reserve(element_bases.size());
   for (int i = 0; i < element_bases.size(); ++i) {
     element_handles.push_back(RegisterInternal(
         allocation->backend(), allocation->device_ordinal(), element_bases[i],
@@ -229,7 +229,8 @@ Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices(
   // Mapping from LogicalBuffer to index (used to detect non-distinct indices).
   FlatMap<const LogicalBuffer*, std::vector<ShapeIndex>>
       buffer_to_source_indices;
-  TF_RETURN_IF_ERROR(points_to.ForEachElement([this, &buffer_to_source_indices](
+  TF_RETURN_IF_ERROR(points_to.ForEachElement(
+      [this, &buffer_to_source_indices](
           const ShapeIndex& index, bool /*is_leaf*/,
           const std::vector<const LogicalBuffer*>& buffers) {
         if (buffers.size() > 1) {
@@ -449,10 +450,14 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
     FlatMap<const HloInstruction*, HloInstruction*>* shared_copies) {
   const HloInstruction* init_hlo = while_hlo->operand(0);
   const PointsToSet& points_to = points_to_analysis.GetPointsToSet(init_hlo);
+
+  // Mapping from LogicalBuffer to index (used to detect non-distinct indices).
+  FlatSet<const LogicalBuffer*> buffer_set;
+
   ShapeTree<HloInstruction*> copy_overrides(init_hlo->shape());
   TF_RETURN_IF_ERROR(points_to.ForEachElement(
-      [init_hlo, read_only_indices, shared_copies, &copy_overrides](
-          const ShapeIndex& index, bool /*is_leaf*/,
+      [init_hlo, read_only_indices, shared_copies, &buffer_set,
+       &copy_overrides](const ShapeIndex& index, bool /*is_leaf*/,
           const std::vector<const LogicalBuffer*>& buffers) {
         // Look for read-only entry parameters.
         if (!read_only_indices->element(index)) {
@@ -468,6 +473,7 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
         if (!is_entry_parameter && !is_constant) {
           continue;
         }
+
         // We have found an entry parameter or constant that is read-only in
         // the while body. These buffers are managed by the caller, and cannot
         // be aliased with non-parameter buffers. Revert this read-only index,
@@ -476,16 +482,17 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(

         // Optimization to allow multiple while loops that share the same
         // read-only entry parameters (or constants) to share a single copy.
-        // Only unambiguous array-shaped buffers are allowed, to reduce code
-        // complexity. The shape of the entry parameter must be identical to
-        // the shape of the init_hlo at this index, to ensure there were no
-        // intervening bitcast or GTE instructions, which are also hard to
-        // handle.
+        // Only unambiguous and distinct array-shaped buffers are allowed, to
+        // reduce code complexity. The shape of the entry parameter must be
+        // identical to the shape of the init_hlo at this index, to ensure
+        // there were no intervening bitcast or GTE instructions, which are
+        // also hard to handle.
         const Shape& pointee_shape = pointee->shape();
         const Shape& init_shape =
             ShapeUtil::GetSubshape(init_hlo->shape(), index);
         if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) &&
-            ShapeUtil::Equal(pointee_shape, init_shape)) {
+            ShapeUtil::Equal(pointee_shape, init_shape) &&
+            buffer_set.count(buffer) < 1) {
           HloInstruction** copy = &(*shared_copies)[pointee];
           if (*copy == nullptr) {
             *copy =
@@ -496,6 +503,9 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
           *copy_overrides.mutable_element(index) = *copy;
         }

+        // Tracks whether this current buffer is distinct.
+        buffer_set.insert(buffer);
+
         // We've already reverted the read-only index and handled the
         // single-copy optimization above, so there's nothing more to do.
         break;
@@ -44,13 +44,20 @@ class CopyInsertionTest : public HloTestBase {
     EXPECT_IS_OK(copy_insertion.Run(module).status());

     // Verify the points to set of the root of the computation after copy
-    // insertion contains no constants or parameters.
+    // insertion contains no constants or parameters, and is distinct and
+    // non-ambiguous.
     auto points_to_analysis =
         TuplePointsToAnalysis::Run(module).ConsumeValueOrDie();
+    const auto& points_to = points_to_analysis->GetPointsToSet(
+        module->entry_computation()->root_instruction());
+    EXPECT_TRUE(points_to.IsDistinct());
+    EXPECT_TRUE(!points_to.IsAmbiguous());
+
     tensorflow::gtl::FlatSet<const LogicalBuffer*> maybe_live_out_buffers =
         points_to_analysis
             ->GetPointsToSet(module->entry_computation()->root_instruction())
             .CreateFlattenedSet();

     for (const LogicalBuffer* buffer : maybe_live_out_buffers) {
       EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kConstant);
       EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kParameter);
@@ -390,6 +397,47 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
     return builder.Build();
   }

+  // Builds a While body computation with two output tuple elements dependent on
+  // both input tuple elements.
+  //
+  // EX: Body({in0, in1, in2})
+  //   out0 = Add(in0, 1)
+  //   out1 = in1
+  //   out2 = in2
+  //   Tuple(out0, out1, out2)
+  std::unique_ptr<HloComputation> BuildDependentBodyComputation2() {
+    auto builder = HloComputation::Builder(TestName() + ".Body");
+
+    const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
+        {induction_variable_shape_, data_shape_, data_shape_});
+
+    auto loop_state = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
+
+    // Update the induction variable GTE(0).
+    auto induction_variable =
+        builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+            induction_variable_shape_, loop_state, 0));
+    auto inc = builder.AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
+
+    // add0 = Add(in0, 1)
+    auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
+        induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
+    // data1 = GTE(1).
+    HloInstruction* data1 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
+
+    // data2 = GTE(2).
+    HloInstruction* data2 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 2));
+
+    // Create output Tuple.
+    builder.AddInstruction(HloInstruction::CreateTuple({add0, data1, data2}));
+
+    return builder.Build();
+  }
+
   // Builds a While body computation with read-only tuple element 0.
   // EX:
   // Body({in0, in1})
@@ -408,6 +456,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
     // Update data GTE(1).
     auto data = builder.AddInstruction(
         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
+
     // Use 'induction_variable' in computation with no path to output tuple.
     auto update = builder.AddInstruction(
         HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
@@ -431,6 +480,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
     // Create param instruction to access loop state.
     const Shape& loop_state_shape =
         nested ? nested_loop_state_shape_ : loop_state_shape_;
+
     auto loop_state = builder.AddInstruction(
         HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
     // Update the induction variable GTE(0).
@@ -972,7 +1022,8 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) {
                   op::Copy(old_init->operand(1)->operand(0)))));
 }

-// Tests while init instruction buffer which interferes with while result buffer.
+// Tests while init instruction buffer which interferes with while result
+// buffer.
 //
 // init_data = Broadcast(...)
 // add_unrelated = Add(init_data) // takes a reference to cause interference
@@ -989,5 +1040,81 @@ TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) {
                   op::Copy(old_init->operand(1))));
 }

+// Tests while init instruction buffer which has a non-distinct points-to set:
+//
+// init = Tuple(Parameter(S32, {}), Parameter(F32, {8},
+//              Parameter(F32, {8})))
+//
+// where the second and third parameters are identical *and* the tuple shared
+// by another while instruction..
+//
+// Verifies that the resulting point-to set is distinct in the resulting Tuple
+// (non-identical Copys). In other words, verifies that copy sharing does not
+// insert identical copies to the resulting tuple.
+TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
+  auto condition1 = module_.AddEmbeddedComputation(BuildConditionComputation());
+  auto condition2 = module_.AddEmbeddedComputation(BuildConditionComputation());
+  // Loop body that outputs tuple comprises two elements dependent on the init
+  // tuple.
+  auto body1 = module_.AddEmbeddedComputation(BuildDependentBodyComputation2());
+  auto body2 = module_.AddEmbeddedComputation(BuildDependentBodyComputation2());
+
+  auto builder = HloComputation::Builder(TestName() + ".While");
+
+  auto iter_param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
+  auto data_param = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, data_shape_, "data"));
+
+  // Loop init tuple contains two identical parameter buffers.
+  auto loop_init = builder.AddInstruction(
+      HloInstruction::CreateTuple({iter_param, data_param, data_param}));
+
+  const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
+      {induction_variable_shape_, data_shape_, data_shape_});
+
+  // Two while loops shares the same loop init tuple.
+  auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
+      loop_state_shape, condition1, body1, loop_init));
+  auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
+      loop_state_shape, condition2, body2, loop_init));
+
+  module_.AddEntryComputation(builder.Build());
+
+  auto points_to_analysis =
+      TuplePointsToAnalysis::Run(&module_).ConsumeValueOrDie();
+
+  // Asserts that the init tuples before copy insertion is non-distinct.
+  ASSERT_FALSE(
+      points_to_analysis->GetPointsToSet(while_hlo1->operand(0)).IsDistinct());
+  ASSERT_FALSE(
+      points_to_analysis->GetPointsToSet(while_hlo2->operand(0)).IsDistinct());
+
+  auto old_init1 = while_hlo1->operand(0);
+  auto old_init2 = while_hlo2->operand(0);
+
+  InsertCopies(&module_);
+
+  EXPECT_THAT(while_hlo1->operand(0),
+              op::Tuple(op::Copy(old_init1->operand(0)),
+                        op::Copy(old_init1->operand(1)),
+                        op::Copy(old_init1->operand(2))));
+
+  EXPECT_THAT(while_hlo2->operand(0),
+              op::Tuple(op::Copy(old_init2->operand(0)),
+                        op::Copy(old_init2->operand(1)),
+                        op::Copy(old_init2->operand(2))));
+
+  // Verifies the init tuples after copy insertion is distinct.
+  points_to_analysis = TuplePointsToAnalysis::Run(&module_).ConsumeValueOrDie();
+  const auto& points_to1 =
+      points_to_analysis->GetPointsToSet(while_hlo1->operand(0));
+  EXPECT_TRUE(points_to1.IsDistinct());
+
+  const auto& points_to2 =
+      points_to_analysis->GetPointsToSet(while_hlo2->operand(0));
+  EXPECT_TRUE(points_to2.IsDistinct());
+}
+
 } // namespace
 } // namespace xla
@@ -18,6 +18,7 @@ limitations under the License.

 #include <vector>

+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_

+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -31,7 +31,7 @@ AsyncExecution::AsyncExecution(Backend* backend,
     : backend_(CHECK_NOTNULL(backend)),
       streams_(std::move(streams)),
       profile_(profile),
-      result_(result) {
+      result_(std::move(result)) {
   for (const auto& stream : streams_) {
     CHECK(stream != nullptr);
   }
@@ -254,6 +254,7 @@ TEST_F(HloScheduleTest, LatticeMatMul) {
   // d40 -- layer 4
   HloComputation::Builder builder("entry_computation");
   std::vector<HloInstruction*> params;
+  params.reserve(6);
   for (int i = 0; i < 6; ++i) {
     params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
         i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i))));
@@ -1631,6 +1631,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildKernelThunk(

   // Compute the input buffer indices.
   std::vector<BufferAllocation::Slice> io_buffers;
+  io_buffers.reserve(io_hlos.size());
   for (const HloInstruction* io_hlo : io_hlos) {
     io_buffers.push_back(GetAllocationSlice(*LatestNonGteAncestor(io_hlo)));
   }
@@ -86,6 +86,7 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) {
   // d40 -- layer 4
   HloComputation::Builder builder("entry_computation");
   std::vector<HloInstruction*> params;
+  params.reserve(6);
   for (int i = 0; i < 6; ++i) {
     params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
         i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i))));
@@ -46,7 +46,7 @@ message HloInstructionProto {
   xla.OpMetadata metadata = 7;

   // Literal, only present for kConstant.
-  xla.Literal literal = 8;
+  xla.LiteralProto literal = 8;

   // Parameter info, only present for kParameter.
   int64 parameter_number = 9;
@@ -311,7 +311,6 @@ void ComputeComputationPostOrder(

   visited->insert(computation);
   post_order->push_back(computation);
-  return;
 }

 } // namespace
@@ -65,7 +65,7 @@ using ::tensorflow::strings::StrCat;
       WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()));
   instruction->operands_.push_back(operand);
   instruction->literal_.reset(new Literal);
-  *instruction->literal_->mutable_u8s() += tag;
+  instruction->literal_->append_u8s(tag);
   return instruction;
 }

@@ -1484,6 +1484,7 @@ string HloInstruction::ToString(bool compact_operands,
   }
   if (!slice_starts_.empty() && !slice_limits_.empty()) {
     std::vector<string> bounds;
+    bounds.reserve(slice_starts_.size());
     for (int i = 0; i < slice_starts_.size(); ++i) {
       bounds.push_back(
           StrCat("[", slice_starts_[i], ":", slice_limits_[i], "]"));
@@ -1550,7 +1551,7 @@ HloInstructionProto HloInstruction::ToProto() const {
   *proto.mutable_metadata() = metadata_;
   switch (opcode_) {
     case HloOpcode::kConstant:
-      *proto.mutable_literal() = *literal_;
+      *proto.mutable_literal() = literal_->ToProto();
       break;
     case HloOpcode::kParameter:
       proto.set_parameter_number(parameter_number_);
@ -1647,10 +1648,10 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
|
|||||||
trace_instruction_ = trace_instruction;
|
trace_instruction_ = trace_instruction;
|
||||||
}
|
}
|
||||||
|
|
||||||
const string& HloInstruction::tracing_tag() const {
|
string HloInstruction::TracingTag() const {
|
||||||
CHECK_EQ(HloOpcode::kTrace, opcode());
|
CHECK_EQ(HloOpcode::kTrace, opcode());
|
||||||
CHECK(literal_ != nullptr);
|
CHECK(literal_ != nullptr);
|
||||||
return literal_->u8s();
|
return literal_->u8s_string();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HloInstruction::IsFused() const {
|
bool HloInstruction::IsFused() const {
|
||||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
|||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/literal_util.h"
|
||||||
#include "tensorflow/compiler/xla/map_util.h"
|
#include "tensorflow/compiler/xla/map_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
|
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
|
||||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||||
@ -535,7 +536,7 @@ class HloInstruction {
|
|||||||
// Returns a tag to be used in tracing.
|
// Returns a tag to be used in tracing.
|
||||||
//
|
//
|
||||||
// Precondition: opcode() == HloOpcode::kTrace
|
// Precondition: opcode() == HloOpcode::kTrace
|
||||||
const string& tracing_tag() const;
|
string TracingTag() const;
|
||||||
|
|
||||||
// Returns whether the instruction is a constant.
|
// Returns whether the instruction is a constant.
|
||||||
bool IsConstant() const;
|
bool IsConstant() const;
|
||||||
|
@ -151,7 +151,26 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
|
|||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
if (std::all_of(hlo->users().begin(), hlo->users().end(),
|
// An "effectively unary" operation is one that has one "large"
|
||||||
|
// input with the others being negligible in terms of memory usage.
|
||||||
|
// We use "has a smaller true rank than the output" as a heuristic
|
||||||
|
// for "negligible" memory usage.
|
||||||
|
auto effectively_unary = [](HloInstruction* hlo) {
|
||||||
|
if (hlo->operands().size() == 1) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
auto output_rank = ShapeUtil::TrueRank(hlo->shape());
|
||||||
|
return std::count_if(
|
||||||
|
hlo->operands().begin(), hlo->operands().end(),
|
||||||
|
[output_rank](HloInstruction* operand) {
|
||||||
|
return ((operand->opcode() != HloOpcode::kBroadcast) &&
|
||||||
|
ShapeUtil::TrueRank(operand->shape()) >=
|
||||||
|
output_rank);
|
||||||
|
}) <= 1;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (effectively_unary(hlo) ||
|
||||||
|
std::all_of(hlo->users().begin(), hlo->users().end(),
|
||||||
user_fusable_into_hlo)) {
|
user_fusable_into_hlo)) {
|
||||||
all_consumers_fusable.insert(hlo);
|
all_consumers_fusable.insert(hlo);
|
||||||
}
|
}
|
||||||
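Aside: the heuristic added above can be read on its own: an instruction is treated as effectively unary when at most one non-broadcast operand has a true rank at least as large as the output's. A self-contained toy restatement (not from the commit; the Operand struct and the sample ranks are hypothetical stand-ins):

#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for an operand: whether it is a broadcast and its
// "true rank" (rank ignoring degenerate size-1 dimensions).
struct Operand {
  bool is_broadcast;
  int true_rank;
};

// At most one "large" operand (non-broadcast, rank >= output rank) means the
// operation counts as effectively unary for fusion purposes.
bool EffectivelyUnary(const std::vector<Operand>& operands, int output_rank) {
  if (operands.size() == 1) return true;
  return std::count_if(operands.begin(), operands.end(),
                       [output_rank](const Operand& o) {
                         return !o.is_broadcast && o.true_rank >= output_rank;
                       }) <= 1;
}

int main() {
  // add(param[16,16], broadcast(scalar)): effectively unary, prints 1.
  std::printf("%d\n", EffectivelyUnary({{false, 2}, {true, 2}}, 2));
  // add(param[16,16], param[16,16]): not effectively unary, prints 0.
  std::printf("%d\n", EffectivelyUnary({{false, 2}, {false, 2}}, 2));
  return 0;
}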
@@ -156,21 +156,67 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {

 TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
 HloComputation::Builder builder(TestName());
-auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
-0, ShapeUtil::MakeShape(F32, {16, 16}), "0"));
-HloInstruction* unary1 = builder.AddInstruction(HloInstruction::CreateUnary(
-ShapeUtil::MakeShape(S32, {}), HloOpcode::kFloor, param0));
-builder.AddInstruction(HloInstruction::CreateSend(unary1, 0));
-HloInstruction* unary2 = builder.AddInstruction(HloInstruction::CreateUnary(
-ShapeUtil::MakeShape(S32, {}), HloOpcode::kAbs, unary1));
+auto shape = ShapeUtil::MakeShape(F32, {16, 16});
+auto param0 =
+builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
+auto param1 =
+builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
+HloInstruction* binary1 = builder.AddInstruction(
+HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
+builder.AddInstruction(HloInstruction::CreateSend(binary1, 0));
+HloInstruction* unary = builder.AddInstruction(
+HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));

 auto module = MakeUnique<HloModule>(TestName());
 auto computation = module->AddEntryComputation(builder.Build());
-EXPECT_EQ(unary2, computation->root_instruction());
+EXPECT_EQ(unary, computation->root_instruction());
 EXPECT_FALSE(
 InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
 .Run(module.get())
 .ValueOrDie());
 }

+TEST_F(InstructionFusionTest, AllowUnaryDuplication) {
+HloComputation::Builder builder(TestName());
+auto shape = ShapeUtil::MakeShape(F32, {16, 16});
+auto param0 =
+builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
+HloInstruction* unary1 = builder.AddInstruction(
+HloInstruction::CreateUnary(shape, HloOpcode::kFloor, param0));
+builder.AddInstruction(HloInstruction::CreateSend(unary1, 0));
+HloInstruction* unary2 = builder.AddInstruction(
+HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1));
+
+auto module = MakeUnique<HloModule>(TestName());
+auto computation = module->AddEntryComputation(builder.Build());
+EXPECT_EQ(unary2, computation->root_instruction());
+EXPECT_TRUE(
+InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
+.Run(module.get())
+.ValueOrDie());
+}
+
+TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) {
+auto shape = ShapeUtil::MakeShape(F32, {16, 16});
+auto small_shape = ShapeUtil::MakeShape(F32, {16});
+HloComputation::Builder builder(TestName());
+auto param0 = builder.AddInstruction(
+HloInstruction::CreateParameter(0, small_shape, "0"));
+auto param1 =
+builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
+HloInstruction* binary1 = builder.AddInstruction(
+HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
+builder.AddInstruction(HloInstruction::CreateSend(binary1, 0));
+HloInstruction* unary = builder.AddInstruction(
+HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
+
+auto module = MakeUnique<HloModule>(TestName());
+auto computation = module->AddEntryComputation(builder.Build());
+EXPECT_EQ(unary, computation->root_instruction());
+EXPECT_TRUE(
+InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
+.Run(module.get())
+.ValueOrDie());
+}
+
 } // namespace xla
@@ -27,6 +27,7 @@ limitations under the License.
 #include "external/llvm/include/llvm/IR/Module.h"
 #include "external/llvm/include/llvm/IR/Value.h"
 #include "external/llvm/include/llvm/Support/raw_ostream.h"
+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
@@ -77,8 +77,10 @@ tensorflow::Status RecordArguments(
 SessionModule* module) {
 module->clear_arguments();
 for (const Allocation* allocation : arg_allocations) {
-TF_RETURN_IF_ERROR(LiteralFromAllocation(allocation, allocation->shape(),
-module->add_arguments()));
+Literal argument;
+TF_RETURN_IF_ERROR(
+LiteralFromAllocation(allocation, allocation->shape(), &argument));
+*module->add_arguments() = argument.ToProto();
 }
 return tensorflow::Status::OK();
 }
@@ -87,8 +89,11 @@ tensorflow::Status RecordArguments(
 tensorflow::Status RecordResult(const Allocation* result_allocation,
 SessionModule* module) {
 module->clear_result();
-return LiteralFromAllocation(result_allocation, result_allocation->shape(),
-module->mutable_result());
+Literal result;
+TF_RETURN_IF_ERROR(LiteralFromAllocation(
+result_allocation, result_allocation->shape(), &result));
+*module->mutable_result() = result.ToProto();
+return tensorflow::Status::OK();
 }

 } // namespace
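Aside: the two hunks above show the conversion pattern applied throughout the rest of this diff: proto fields now carry LiteralProto, and the in-memory xla::Literal crosses that boundary via ToProto() and the Literal(proto) constructor. A minimal round-trip sketch using only calls that appear in this diff (the include paths are assumptions):

#include <cstdio>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

int main() {
  // Build an in-memory literal, as the user_computation tests below do.
  auto original = xla::LiteralUtil::CreateR1<float>({123.0f, 42.0f});

  // Convert to the wire form, e.g. to store in a SessionModule field.
  xla::LiteralProto proto = original->ToProto();

  // Rebuild an in-memory literal from the proto on the consuming side.
  xla::Literal round_tripped(proto);
  std::fprintf(stderr, "%s\n",
               xla::LiteralUtil::ToString(round_tripped).c_str());
  return 0;
}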
@ -649,6 +654,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
|
|||||||
ResolveAndValidateArguments(request.arguments(), execute_backend_.get(),
|
ResolveAndValidateArguments(request.arguments(), execute_backend_.get(),
|
||||||
executor->device_ordinal()));
|
executor->device_ordinal()));
|
||||||
std::vector<se::DeviceMemoryBase> arguments;
|
std::vector<se::DeviceMemoryBase> arguments;
|
||||||
|
arguments.reserve(arg_allocations.size());
|
||||||
for (const Allocation* allocation : arg_allocations) {
|
for (const Allocation* allocation : arg_allocations) {
|
||||||
arguments.push_back(allocation->device_memory());
|
arguments.push_back(allocation->device_memory());
|
||||||
}
|
}
|
||||||
@ -677,6 +683,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
|
|||||||
BuildExecutables(versioned_handles, std::move(module_configs),
|
BuildExecutables(versioned_handles, std::move(module_configs),
|
||||||
execute_backend_.get(), executors));
|
execute_backend_.get(), executors));
|
||||||
std::vector<Executable*> executable_ptrs;
|
std::vector<Executable*> executable_ptrs;
|
||||||
|
executable_ptrs.reserve(executables.size());
|
||||||
for (const auto& executable : executables) {
|
for (const auto& executable : executables) {
|
||||||
executable_ptrs.push_back(executable.get());
|
executable_ptrs.push_back(executable.get());
|
||||||
}
|
}
|
||||||
@ -752,6 +759,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
|
|||||||
<< module_config->entry_computation_layout().ToString();
|
<< module_config->entry_computation_layout().ToString();
|
||||||
|
|
||||||
std::vector<se::DeviceMemoryBase> arguments;
|
std::vector<se::DeviceMemoryBase> arguments;
|
||||||
|
arguments.reserve(arg_allocations.size());
|
||||||
for (const Allocation* allocation : arg_allocations) {
|
for (const Allocation* allocation : arg_allocations) {
|
||||||
arguments.push_back(allocation->device_memory());
|
arguments.push_back(allocation->device_memory());
|
||||||
}
|
}
|
||||||
@ -820,6 +828,7 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
|
|||||||
<< module_config->entry_computation_layout().ToString();
|
<< module_config->entry_computation_layout().ToString();
|
||||||
|
|
||||||
std::vector<se::DeviceMemoryBase> arguments;
|
std::vector<se::DeviceMemoryBase> arguments;
|
||||||
|
arguments.reserve(arg_allocations.size());
|
||||||
for (const Allocation* allocation : arg_allocations) {
|
for (const Allocation* allocation : arg_allocations) {
|
||||||
arguments.push_back(allocation->device_memory());
|
arguments.push_back(allocation->device_memory());
|
||||||
}
|
}
|
||||||
@ -908,13 +917,15 @@ tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg,
|
|||||||
literal_shape = &allocation->shape();
|
literal_shape = &allocation->shape();
|
||||||
}
|
}
|
||||||
|
|
||||||
return LiteralFromAllocation(allocation, *literal_shape,
|
Literal literal;
|
||||||
result->mutable_literal());
|
auto status = LiteralFromAllocation(allocation, *literal_shape, &literal);
|
||||||
|
*result->mutable_literal() = literal.ToProto();
|
||||||
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
|
tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
|
||||||
TransferToServerResponse* result) {
|
TransferToServerResponse* result) {
|
||||||
const Literal& literal = arg->literal();
|
Literal literal = Literal(arg->literal());
|
||||||
const Shape& shape = literal.shape();
|
const Shape& shape = literal.shape();
|
||||||
|
|
||||||
if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) {
|
if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) {
|
||||||
@ -978,7 +989,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
|
|||||||
}
|
}
|
||||||
|
|
||||||
return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
|
return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
|
||||||
executor, arg->literal());
|
executor, Literal(arg->literal()));
|
||||||
}
|
}
|
||||||
|
|
||||||
tensorflow::Status Service::TransferFromOutfeed(
|
tensorflow::Status Service::TransferFromOutfeed(
|
||||||
@ -1001,8 +1012,12 @@ tensorflow::Status Service::TransferFromOutfeed(
|
|||||||
executor = execute_backend_->Replicas()[arg->replica_id()];
|
executor = execute_backend_->Replicas()[arg->replica_id()];
|
||||||
}
|
}
|
||||||
|
|
||||||
return execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
|
Literal literal;
|
||||||
executor, arg->shape_with_layout(), result->mutable_literal());
|
TF_RETURN_IF_ERROR(
|
||||||
|
execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
|
||||||
|
executor, arg->shape_with_layout(), &literal));
|
||||||
|
*result->mutable_literal() = literal.ToProto();
|
||||||
|
return tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg,
|
tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg,
|
||||||
|
@ -75,10 +75,10 @@ message SessionModule {
|
|||||||
repeated SessionComputation embedded_computations = 2;
|
repeated SessionComputation embedded_computations = 2;
|
||||||
|
|
||||||
// The arguments passed to the computation.
|
// The arguments passed to the computation.
|
||||||
repeated Literal arguments = 3;
|
repeated LiteralProto arguments = 3;
|
||||||
|
|
||||||
// The result of the computation.
|
// The result of the computation.
|
||||||
Literal result = 4;
|
LiteralProto result = 4;
|
||||||
|
|
||||||
// The name of the platform used to run the computation.
|
// The name of the platform used to run the computation.
|
||||||
string execution_platform = 5;
|
string execution_platform = 5;
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <set>
|
#include <set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/literal_util.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
@ -121,7 +121,7 @@ TEST_F(CpuTransferManagerTest, TransferR1U8FromDevice) {
|
|||||||
const Shape shape = ShapeUtil::MakeShape(U8, {4});
|
const Shape shape = ShapeUtil::MakeShape(U8, {4});
|
||||||
TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice(
|
TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice(
|
||||||
stream_exec_, memptr, shape, shape, &literal));
|
stream_exec_, memptr, shape, shape, &literal));
|
||||||
CHECK_EQ("klmn", literal.u8s());
|
CHECK_EQ("klmn", literal.u8s_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CpuTransferManagerTest, TransferBufferFromDevice) {
|
TEST_F(CpuTransferManagerTest, TransferBufferFromDevice) {
|
||||||
|
@ -2275,7 +2275,7 @@ void ComputationLowerer::Visit(
|
|||||||
const ConstantRequest& constant_request =
|
const ConstantRequest& constant_request =
|
||||||
request.request().constant_request();
|
request.request().constant_request();
|
||||||
hlo_instruction = add_instruction(HloInstruction::CreateConstant(
|
hlo_instruction = add_instruction(HloInstruction::CreateConstant(
|
||||||
LiteralUtil::CloneToUnique(constant_request.literal())));
|
LiteralUtil::CloneToUnique(Literal(constant_request.literal()))));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2467,6 +2467,7 @@ void ComputationLowerer::Visit(
|
|||||||
// to append dimensions on the left the broadcast_dimensions should just
|
// to append dimensions on the left the broadcast_dimensions should just
|
||||||
// be the n highest dimension numbers of the output shape where n is
|
// be the n highest dimension numbers of the output shape where n is
|
||||||
// the number of input dimensions.
|
// the number of input dimensions.
|
||||||
|
broadcast_dimensions.reserve(ShapeUtil::Rank(operand->shape()));
|
||||||
for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) {
|
for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) {
|
||||||
broadcast_dimensions.push_back(i +
|
broadcast_dimensions.push_back(i +
|
||||||
ShapeUtil::Rank(request.output_shape()) -
|
ShapeUtil::Rank(request.output_shape()) -
|
||||||
|
@ -50,7 +50,7 @@ TEST_F(UserComputationTest, SimpleComputation) {
|
|||||||
|
|
||||||
ConstantRequest constant_request;
|
ConstantRequest constant_request;
|
||||||
*constant_request.mutable_literal() =
|
*constant_request.mutable_literal() =
|
||||||
*LiteralUtil::CreateR1<float>({123.0f, 42.0f});
|
LiteralUtil::CreateR1<float>({123.0f, 42.0f})->ToProto();
|
||||||
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle constant_handle,
|
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle constant_handle,
|
||||||
computation.AddConstantInstruction(constant_request));
|
computation.AddConstantInstruction(constant_request));
|
||||||
|
|
||||||
@ -160,12 +160,13 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) {
|
|||||||
UserComputation computation("TheComputation", handle);
|
UserComputation computation("TheComputation", handle);
|
||||||
|
|
||||||
ConstantRequest a_request;
|
ConstantRequest a_request;
|
||||||
*a_request.mutable_literal() = *LiteralUtil::CreateR1<float>({123.0f, 42.0f});
|
*a_request.mutable_literal() =
|
||||||
|
LiteralUtil::CreateR1<float>({123.0f, 42.0f})->ToProto();
|
||||||
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle,
|
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle,
|
||||||
computation.AddConstantInstruction(a_request));
|
computation.AddConstantInstruction(a_request));
|
||||||
|
|
||||||
ConstantRequest b_request;
|
ConstantRequest b_request;
|
||||||
*b_request.mutable_literal() = *LiteralUtil::CreateR0<float>(1.0f);
|
*b_request.mutable_literal() = LiteralUtil::CreateR0<float>(1.0f)->ToProto();
|
||||||
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle,
|
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle,
|
||||||
computation.AddConstantInstruction(b_request));
|
computation.AddConstantInstruction(b_request));
|
||||||
|
|
||||||
|
@ -44,6 +44,7 @@ struct ShapeTreeNode {
|
|||||||
// Children of this node.
|
// Children of this node.
|
||||||
std::vector<std::unique_ptr<ShapeTreeNode>> children;
|
std::vector<std::unique_ptr<ShapeTreeNode>> children;
|
||||||
|
|
||||||
|
ShapeTreeNode() = default;
|
||||||
explicit ShapeTreeNode(const T& data) : data(data) {}
|
explicit ShapeTreeNode(const T& data) : data(data) {}
|
||||||
|
|
||||||
ShapeTreeNode(const ShapeTreeNode& other)
|
ShapeTreeNode(const ShapeTreeNode& other)
|
||||||
@ -85,8 +86,9 @@ class ShapeTree {
|
|||||||
public:
|
public:
|
||||||
// Default constructor creates a tree with a nil shape (i.e. an empty tuple).
|
// Default constructor creates a tree with a nil shape (i.e. an empty tuple).
|
||||||
ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
|
ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
|
||||||
// Create ShapeTree with the given shape, and default T values for all nodes.
|
// Create ShapeTree with the given shape, and default-constructed T values for
|
||||||
explicit ShapeTree(const Shape& shape) : ShapeTree(shape, T()) {}
|
// all nodes.
|
||||||
|
explicit ShapeTree(const Shape& shape);
|
||||||
// Create ShapeTree with the given shape, and init_value for all nodes.
|
// Create ShapeTree with the given shape, and init_value for all nodes.
|
||||||
ShapeTree(const Shape& shape, const T& init_value);
|
ShapeTree(const Shape& shape, const T& init_value);
|
||||||
|
|
||||||
@ -127,6 +129,19 @@ class ShapeTree {
|
|||||||
const ShapeIndex& /*index*/, bool /*is_leaf*/, T* /*data*/)>;
|
const ShapeIndex& /*index*/, bool /*is_leaf*/, T* /*data*/)>;
|
||||||
Status ForEachMutableElement(const MutableVisitorFunction& func);
|
Status ForEachMutableElement(const MutableVisitorFunction& func);
|
||||||
|
|
||||||
|
// Copy the subtree of values from 'other' rooted at ShapeIndex
|
||||||
|
// 'source_base_index' into the subtree of value in this ShapeTree rooted at
|
||||||
|
// 'target_base_index'.
|
||||||
|
//
|
||||||
|
// Precondition: The subshape of other.shape() at index source_base_index must
|
||||||
|
// be compatible with the subshape of shape() at index target_base_index.
|
||||||
|
void CopySubtreeFrom(const ShapeTree<T>& other,
|
||||||
|
const ShapeIndex& source_base_index,
|
||||||
|
const ShapeIndex& target_base_index);
|
||||||
|
|
||||||
|
bool operator==(const ShapeTree<T>& other) const;
|
||||||
|
bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
using Node = internal::ShapeTreeNode<T>;
|
using Node = internal::ShapeTreeNode<T>;
|
||||||
|
|
||||||
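Aside: a minimal usage sketch for the CopySubtreeFrom and operator== declarations above, written against the ShapeTree API exactly as declared in this diff (the include paths are assumptions); the shape_tree_test hunks further down exercise the same behavior:

#include <cstdio>

#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"

void CopySubtreeFromSketch() {
  const xla::Shape array = xla::ShapeUtil::MakeShape(xla::F32, {2, 2});
  const xla::Shape tuple = xla::ShapeUtil::MakeTupleShape({array, array});

  // A source tree holding one int per node; {} indexes the root.
  xla::ShapeTree<int> source(tuple, 0);
  *source.mutable_element({0}) = 11;
  *source.mutable_element({1}) = 12;

  // Copy the subtree rooted at {1} of 'source' onto the root of a
  // shape-compatible, array-shaped destination tree.
  xla::ShapeTree<int> destination(array, 0);
  destination.CopySubtreeFrom(source, /*source_base_index=*/{1},
                              /*target_base_index=*/{});
  std::printf("%d\n", destination.element({}));  // prints 12

  // operator== compares the data stored at every index.
  std::printf("%d\n", source == source);  // prints 1
}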
@ -134,6 +149,10 @@ class ShapeTree {
|
|||||||
// the given 'init_value'.
|
// the given 'init_value'.
|
||||||
void InitChildren(const Shape& shape, const T& init_value, Node* node);
|
void InitChildren(const Shape& shape, const T& init_value, Node* node);
|
||||||
|
|
||||||
|
// Initialize node->children based on 'shape'. All children have
|
||||||
|
// default-constructed data values.
|
||||||
|
void InitChildren(const Shape& shape, Node* node);
|
||||||
|
|
||||||
// Helpers for traversing the shape via ForEachElement. The helpers
|
// Helpers for traversing the shape via ForEachElement. The helpers
|
||||||
// recursively traverse the subtree rooted at "index" (defined as in
|
// recursively traverse the subtree rooted at "index" (defined as in
|
||||||
// ShapeUtil::GetSubshape).
|
// ShapeUtil::GetSubshape).
|
||||||
@ -165,6 +184,24 @@ void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void ShapeTree<T>::InitChildren(const Shape& shape, Node* node) {
|
||||||
|
if (ShapeUtil::IsTuple(shape)) {
|
||||||
|
for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
|
||||||
|
node->children.emplace_back(new Node());
|
||||||
|
InitChildren(shape.tuple_shapes(i), node->children.back().get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
ShapeTree<T>::ShapeTree(const Shape& shape) : root_(), shape_(shape) {
|
||||||
|
// The shape_ field is just used to hold the structure of the shape.
|
||||||
|
// It should not be relied upon to store layout information.
|
||||||
|
LayoutUtil::ClearLayout(&shape_);
|
||||||
|
InitChildren(shape_, &root_);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
ShapeTree<T>::ShapeTree(const Shape& shape, const T& init_value)
|
ShapeTree<T>::ShapeTree(const Shape& shape, const T& init_value)
|
||||||
: root_(init_value), shape_(shape) {
|
: root_(init_value), shape_(shape) {
|
||||||
@ -240,6 +277,48 @@ Status ShapeTree<T>::ForEachMutableElement(const MutableVisitorFunction& func) {
|
|||||||
return ForEachMutableHelper(func, &root_, &index);
|
return ForEachMutableHelper(func, &root_, &index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void ShapeTree<T>::CopySubtreeFrom(const ShapeTree<T>& other,
|
||||||
|
const ShapeIndex& source_base_index,
|
||||||
|
const ShapeIndex& target_base_index) {
|
||||||
|
CHECK(ShapeUtil::Compatible(
|
||||||
|
ShapeUtil::GetSubshape(shape(), target_base_index),
|
||||||
|
ShapeUtil::GetSubshape(other.shape(), source_base_index)));
|
||||||
|
ForEachMutableElement(
|
||||||
|
[this, &other, &source_base_index, &target_base_index](
|
||||||
|
const ShapeIndex& index, bool /*is_leaf*/, T* data) {
|
||||||
|
// Copy the data element only if index is in the
|
||||||
|
// subtree rooted at target_base_index.
|
||||||
|
for (int i = 0; i < target_base_index.size(); ++i) {
|
||||||
|
if (i >= index.size() || index[i] != target_base_index[i]) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Construct source element index to copy from.
|
||||||
|
ShapeIndex source_index = source_base_index;
|
||||||
|
for (int i = target_base_index.size(); i < index.size(); ++i) {
|
||||||
|
source_index.push_back(index[i]);
|
||||||
|
}
|
||||||
|
*data = other.element(source_index);
|
||||||
|
return Status::OK();
|
||||||
|
})
|
||||||
|
.IgnoreError();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
bool ShapeTree<T>::operator==(const ShapeTree<T>& other) const {
|
||||||
|
bool equal = true;
|
||||||
|
ForEachElement([this, &other, &equal](const ShapeIndex& index,
|
||||||
|
bool /*is_leaf*/, const T& data) {
|
||||||
|
if (data != other.element(index)) {
|
||||||
|
equal = false;
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
})
|
||||||
|
.IgnoreError();
|
||||||
|
return equal;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
|
#endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
|
||||||
@@ -245,5 +245,139 @@ TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) {
 EXPECT_DEATH(shape_tree.element({0, 0}), "");
 }

+TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) {
+ShapeTree<std::unique_ptr<int>> shape_tree{tuple_shape_};
+EXPECT_EQ(shape_tree.element({2}).get(), nullptr);
+*shape_tree.mutable_element({2}) = MakeUnique<int>(42);
+EXPECT_EQ(*shape_tree.element({2}), 42);
+}
+
+TEST_F(ShapeTreeTest, CopySubtreeFromArrayShape) {
+// Test CopySubtreeFrom method for a single value copied between array-shaped
+// ShapeTrees.
+ShapeTree<int> source(array_shape_);
+*source.mutable_element(/*index=*/{}) = 42;
+ShapeTree<int> destination(array_shape_, 123);
+
+EXPECT_EQ(destination.element(/*index=*/{}), 123);
+destination.CopySubtreeFrom(source, /*source_base_index=*/{},
+/*target_base_index=*/{});
+EXPECT_EQ(destination.element(/*index=*/{}), 42);
+}
+
+TEST_F(ShapeTreeTest, FullCopySubtreeFromTupleShape) {
+// Test CopySubtreeFrom method for a copy of all elements from one
+// tuple-shaped ShapeTree to another.
+ShapeTree<int> source(tuple_shape_);
+*source.mutable_element(/*index=*/{}) = 10;
+*source.mutable_element(/*index=*/{0}) = 11;
+*source.mutable_element(/*index=*/{1}) = 12;
+*source.mutable_element(/*index=*/{2}) = 13;
+
+ShapeTree<int> destination(tuple_shape_, 0);
+
+destination.CopySubtreeFrom(source, /*source_base_index=*/{},
+/*target_base_index=*/{});
+EXPECT_EQ(destination.element(/*index=*/{}), 10);
+EXPECT_EQ(destination.element(/*index=*/{0}), 11);
+EXPECT_EQ(destination.element(/*index=*/{1}), 12);
+EXPECT_EQ(destination.element(/*index=*/{2}), 13);
+}
+
+TEST_F(ShapeTreeTest, SingleElementCopySubtreeFromTupleShape) {
+// Test CopySubtreeFrom method for a copy of a single element from one
+// tuple-shaped ShapeTree to another.
+ShapeTree<int> source(tuple_shape_);
+*source.mutable_element(/*index=*/{}) = 10;
+*source.mutable_element(/*index=*/{0}) = 11;
+*source.mutable_element(/*index=*/{1}) = 12;
+*source.mutable_element(/*index=*/{2}) = 13;
+
+ShapeTree<int> destination(tuple_shape_, 0);
+
+destination.CopySubtreeFrom(source, /*source_base_index=*/{0},
+/*target_base_index=*/{1});
+EXPECT_EQ(destination.element(/*index=*/{}), 0);
+EXPECT_EQ(destination.element(/*index=*/{0}), 0);
+EXPECT_EQ(destination.element(/*index=*/{1}), 11);
+EXPECT_EQ(destination.element(/*index=*/{2}), 0);
+}
+
+TEST_F(ShapeTreeTest, CopySubtreeIntoNestedShape) {
+// Test CopySubtreeFrom method for a copy of a tuple-shaped ShapeTree into a
+// nested-tuple-shaped ShapeTree.
+ShapeTree<int> source(
+ShapeUtil::MakeTupleShape({array_shape_, array_shape_}));
+*source.mutable_element(/*index=*/{}) = 10;
+*source.mutable_element(/*index=*/{0}) = 11;
+*source.mutable_element(/*index=*/{1}) = 12;
+
+ShapeTree<int> destination(nested_tuple_shape_, 0);
+
+destination.CopySubtreeFrom(source, /*source_base_index=*/{},
+/*target_base_index=*/{2, 0});
+
+EXPECT_EQ(destination.element(/*index=*/{}), 0);
+EXPECT_EQ(destination.element(/*index=*/{0}), 0);
+EXPECT_EQ(destination.element(/*index=*/{1}), 0);
+EXPECT_EQ(destination.element(/*index=*/{1, 0}), 0);
+EXPECT_EQ(destination.element(/*index=*/{1, 1}), 0);
+EXPECT_EQ(destination.element(/*index=*/{2}), 0);
+EXPECT_EQ(destination.element(/*index=*/{2, 0}), 10);
+EXPECT_EQ(destination.element(/*index=*/{2, 0, 0}), 11);
+EXPECT_EQ(destination.element(/*index=*/{2, 0, 1}), 12);
+EXPECT_EQ(destination.element(/*index=*/{2, 1}), 0);
+}
+
+TEST_F(ShapeTreeTest, CopySubtreeFromNestedShape) {
+// Test CopySubtreeFrom method for a copy from a nested-tuple-shape.
+ShapeTree<int> source(nested_tuple_shape_, 42);
+*source.mutable_element(/*index=*/{1}) = 10;
+*source.mutable_element(/*index=*/{1, 0}) = 11;
+*source.mutable_element(/*index=*/{1, 1}) = 12;
+
+ShapeTree<int> destination(
+ShapeUtil::MakeTupleShape({array_shape_, array_shape_}), 0);
+
+destination.CopySubtreeFrom(source, /*source_base_index=*/{1},
+/*target_base_index=*/{});
+
+EXPECT_EQ(destination.element(/*index=*/{}), 10);
+EXPECT_EQ(destination.element(/*index=*/{0}), 11);
+EXPECT_EQ(destination.element(/*index=*/{1}), 12);
+}
+
+TEST_F(ShapeTreeTest, OperatorEquals) {
+{
+ShapeTree<int> a(array_shape_, 123);
+ShapeTree<int> b(array_shape_, 42);
+ShapeTree<int> c(array_shape_, 42);
+EXPECT_FALSE(a == b);
+EXPECT_TRUE(a != b);
+EXPECT_TRUE(b == c);
+}
+{
+ShapeTree<int> a(tuple_shape_);
+*a.mutable_element(/*index=*/{}) = 10;
+*a.mutable_element(/*index=*/{0}) = 11;
+*a.mutable_element(/*index=*/{1}) = 12;
+
+ShapeTree<int> b(tuple_shape_);
+*b.mutable_element(/*index=*/{}) = 10;
+*b.mutable_element(/*index=*/{0}) = 42;
+*b.mutable_element(/*index=*/{1}) = 11;
+
+ShapeTree<int> c(tuple_shape_);
+*c.mutable_element(/*index=*/{}) = 10;
+*c.mutable_element(/*index=*/{0}) = 42;
+*c.mutable_element(/*index=*/{1}) = 11;
+
+EXPECT_FALSE(a == b);
+EXPECT_TRUE(a != b);
+EXPECT_TRUE(b == c);
+EXPECT_FALSE(b != c);
+}
+}
+
 } // namespace
 } // namespace xla
@ -122,7 +122,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
|
|||||||
for (const auto& shape : parameters) {
|
for (const auto& shape : parameters) {
|
||||||
*program_shape.add_parameters() = shape;
|
*program_shape.add_parameters() = shape;
|
||||||
}
|
}
|
||||||
*program_shape.mutable_result() = result;
|
*program_shape.mutable_result() = std::move(result);
|
||||||
return program_shape;
|
return program_shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -829,6 +829,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
|
|||||||
const int count = GetParam();
|
const int count = GetParam();
|
||||||
ComputationBuilder builder(client_, TestName());
|
ComputationBuilder builder(client_, TestName());
|
||||||
std::vector<float> values;
|
std::vector<float> values;
|
||||||
|
values.reserve(count);
|
||||||
for (int i = 0; i < count; ++i) {
|
for (int i = 0; i < count; ++i) {
|
||||||
values.push_back(i / static_cast<float>(count));
|
values.push_back(i / static_cast<float>(count));
|
||||||
}
|
}
|
||||||
@ -836,6 +837,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
|
|||||||
auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f));
|
auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f));
|
||||||
|
|
||||||
std::vector<float> expected;
|
std::vector<float> expected;
|
||||||
|
expected.reserve(values.size());
|
||||||
for (float value : values) {
|
for (float value : values) {
|
||||||
expected.push_back(value * value);
|
expected.push_back(value * value);
|
||||||
}
|
}
|
||||||
|
@ -179,7 +179,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
|
|||||||
VLOG(1) << "expected: " << LiteralUtil::ToString(*expected_literal);
|
VLOG(1) << "expected: " << LiteralUtil::ToString(*expected_literal);
|
||||||
VLOG(1) << "actual: " << LiteralUtil::ToString(*actual);
|
VLOG(1) << "actual: " << LiteralUtil::ToString(*actual);
|
||||||
|
|
||||||
EXPECT_EQ(expected, actual->u8s());
|
EXPECT_EQ(expected, actual->u8s_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
void ClientLibraryTestBase::ComputeAndCompareTuple(
|
void ClientLibraryTestBase::ComputeAndCompareTuple(
|
||||||
|
@ -442,6 +442,39 @@ XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) {
|
|||||||
ComputeAndCompareR1<int32>(&builder, expected, {});
|
ComputeAndCompareR1<int32>(&builder, expected, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) {
|
||||||
|
ComputationBuilder builder(client_, TestName());
|
||||||
|
|
||||||
|
Array3D<float> arr0(9, 17, 1);
|
||||||
|
arr0.Fill(1);
|
||||||
|
|
||||||
|
Array3D<float> arr1(9, 17, 256);
|
||||||
|
arr1.Fill(2);
|
||||||
|
|
||||||
|
Array3D<float> expected(9, 17, arr0.n3() + arr1.n3());
|
||||||
|
for (int64 i = 0; i < expected.n1(); ++i) {
|
||||||
|
for (int64 j = 0; j < expected.n2(); ++j) {
|
||||||
|
int64 kk = 0;
|
||||||
|
for (const Array3D<float>& arr : {arr0, arr1}) {
|
||||||
|
for (int64 k = 0; k < arr.n3(); ++k, ++kk) {
|
||||||
|
expected(i, j, kk) = arr(i, j, k);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ComputationDataHandle h0;
|
||||||
|
auto p0 = CreateR3Parameter<float>(arr0, /*parameter_number=*/0, "p0",
|
||||||
|
&builder, &h0);
|
||||||
|
ComputationDataHandle h1;
|
||||||
|
auto p1 = CreateR3Parameter<float>(arr1, /*parameter_number=*/1, "p1",
|
||||||
|
&builder, &h1);
|
||||||
|
|
||||||
|
auto concatenated = builder.ConcatInDim({h0, h1}, 2);
|
||||||
|
|
||||||
|
ComputeAndCompareR3<float>(&builder, expected, {p0.get(), p1.get()});
|
||||||
|
}
|
||||||
|
|
||||||
// Describes a binary rank-2 concatenation test.
|
// Describes a binary rank-2 concatenation test.
|
||||||
struct R2BinarySpec {
|
struct R2BinarySpec {
|
||||||
int64 lhs_dim0;
|
int64 lhs_dim0;
|
||||||
|
@ -262,7 +262,7 @@ class NearComparator {
|
|||||||
max_abs_err_ = 0.0;
|
max_abs_err_ = 0.0;
|
||||||
*miscompares_.mutable_shape() =
|
*miscompares_.mutable_shape() =
|
||||||
ShapeUtil::ChangeElementType(actual.shape(), PRED);
|
ShapeUtil::ChangeElementType(actual.shape(), PRED);
|
||||||
miscompares_.mutable_preds()->Resize(
|
miscompares_.mutable_preds()->resize(
|
||||||
ShapeUtil::ElementsIn(miscompares_.shape()), false);
|
ShapeUtil::ElementsIn(miscompares_.shape()), false);
|
||||||
multi_index_.resize(expected.shape().dimensions_size(), 0);
|
multi_index_.resize(expected.shape().dimensions_size(), 0);
|
||||||
|
|
||||||
@ -389,7 +389,7 @@ class NearComparator {
|
|||||||
tensorflow::strings::Printf("tempfile-%s-%llx-%s", Hostname().c_str(),
|
tensorflow::strings::Printf("tempfile-%s-%llx-%s", Hostname().c_str(),
|
||||||
now_usec, name.c_str()));
|
now_usec, name.c_str()));
|
||||||
TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(),
|
TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(),
|
||||||
filename, literal));
|
filename, literal.ToProto()));
|
||||||
LOG(ERROR) << "wrote to " << name << " file: " << filename;
|
LOG(ERROR) << "wrote to " << name << " file: " << filename;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,9 +83,10 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
|
|||||||
LOG(INFO) << "results: [" << tensorflow::str_util::Join(results, ", ") << "]";
|
LOG(INFO) << "results: [" << tensorflow::str_util::Join(results, ", ") << "]";
|
||||||
EXPECT_EQ(3, results.size());
|
EXPECT_EQ(3, results.size());
|
||||||
for (const string& result : results) {
|
for (const string& result : results) {
|
||||||
Literal literal;
|
LiteralProto literal_proto;
|
||||||
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result,
|
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result,
|
||||||
&literal));
|
&literal_proto));
|
||||||
|
Literal literal(literal_proto);
|
||||||
if (result.find("expected") != string::npos) {
|
if (result.find("expected") != string::npos) {
|
||||||
EXPECT_EQ("2", LiteralUtil::ToString(literal));
|
EXPECT_EQ("2", LiteralUtil::ToString(literal));
|
||||||
} else if (result.find("actual") != string::npos) {
|
} else if (result.find("actual") != string::npos) {
|
||||||
|
@ -47,6 +47,7 @@ TEST_F(LogTest, LogTenValues) {
|
|||||||
builder.Log(x);
|
builder.Log(x);
|
||||||
|
|
||||||
std::vector<float> expected;
|
std::vector<float> expected;
|
||||||
|
expected.reserve(input.size());
|
||||||
for (float f : input) {
|
for (float f : input) {
|
||||||
expected.push_back(std::log(f));
|
expected.push_back(std::log(f));
|
||||||
}
|
}
|
||||||
|
@ -246,6 +246,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<GlobalData*> param_data;
|
std::vector<GlobalData*> param_data;
|
||||||
|
param_data.reserve(param_data_owner.size());
|
||||||
for (const std::unique_ptr<GlobalData>& data : param_data_owner) {
|
for (const std::unique_ptr<GlobalData>& data : param_data_owner) {
|
||||||
param_data.push_back(data.get());
|
param_data.push_back(data.get());
|
||||||
}
|
}
|
||||||
|
@ -37,6 +37,7 @@ class SliceTest : public ClientLibraryTestBase {
|
|||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
void RunSliceTenToTwo() {
|
void RunSliceTenToTwo() {
|
||||||
std::vector<NativeT> constant;
|
std::vector<NativeT> constant;
|
||||||
|
constant.reserve(10);
|
||||||
for (int i = 0; i < 10; ++i) {
|
for (int i = 0; i < 10; ++i) {
|
||||||
constant.push_back(static_cast<NativeT>(i));
|
constant.push_back(static_cast<NativeT>(i));
|
||||||
}
|
}
|
||||||
|
@ -64,6 +64,7 @@ TEST_F(VecOpsSimpleTest, ExpManyValues) {
|
|||||||
for (int count : {63, 64, 65, 127, 128, 129, 17 * 4096}) {
|
for (int count : {63, 64, 65, 127, 128, 129, 17 * 4096}) {
|
||||||
ComputationBuilder builder(client_, TestName());
|
ComputationBuilder builder(client_, TestName());
|
||||||
std::vector<float> exponents;
|
std::vector<float> exponents;
|
||||||
|
exponents.reserve(count);
|
||||||
for (int i = 0; i < count; ++i) {
|
for (int i = 0; i < count; ++i) {
|
||||||
exponents.push_back(i / static_cast<float>(count));
|
exponents.push_back(i / static_cast<float>(count));
|
||||||
}
|
}
|
||||||
@ -71,6 +72,7 @@ TEST_F(VecOpsSimpleTest, ExpManyValues) {
|
|||||||
auto exp = builder.Exp(x);
|
auto exp = builder.Exp(x);
|
||||||
|
|
||||||
std::vector<float> expected;
|
std::vector<float> expected;
|
||||||
|
expected.reserve(exponents.size());
|
||||||
for (float exponent : exponents) {
|
for (float exponent : exponents) {
|
||||||
expected.push_back(std::exp(exponent));
|
expected.push_back(std::exp(exponent));
|
||||||
}
|
}
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/literal_util.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
|
#ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
|
||||||
#define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
|
#define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/literal_util.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
@@ -81,6 +81,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
 client->GetComputationShape(computation).ConsumeValueOrDie();

 std::vector<const Shape*> layouts;
+layouts.reserve(program_shape->parameters_size());
 for (int i = 0; i < program_shape->parameters_size(); ++i) {
 layouts.push_back(&program_shape->parameters(i));
 }
@@ -56,6 +56,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool compile) {
 client->GetComputationShape(computation).ConsumeValueOrDie();

 std::vector<const Shape*> layouts;
+layouts.reserve(program_shape->parameters_size());
 for (int i = 0; i < program_shape->parameters_size(); ++i) {
 layouts.push_back(&program_shape->parameters(i));
 }
@@ -66,7 +66,8 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(
 if (use_fake_data) {
 arguments = MakeFakeArgumentsOrDie(computation, client);
 } else { // use recorded data if available
-for (const Literal& literal : module.arguments()) {
+for (const auto& proto : module.arguments()) {
+Literal literal(proto);
 TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
 client->TransferToServer(literal));
 arguments.push_back(std::move(data));
@@ -74,6 +75,7 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(
 }

 std::vector<GlobalData*> execute_arguments;
+execute_arguments.reserve(arguments.size());
 for (auto& argument : arguments) {
 execute_arguments.push_back(argument.get());
 }
@@ -100,7 +102,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool use_fake_data) {
 if (module.has_result()) {
 fprintf(stdout, "was %s:%s\n",
 ShapeUtil::HumanString(module.result().shape()).c_str(),
-LiteralUtil::ToString(module.result()).c_str());
+LiteralUtil::ToString(Literal(module.result())).c_str());
 }
 }
 }
@@ -37,9 +37,10 @@ int main(int argc, char **argv) {
 << " <path-to-serialized-literal-proto>";
 }

-xla::Literal literal;
+xla::LiteralProto literal_proto;
 TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1],
-&literal));
-LOG(INFO) << "literal: " << literal.ShortDebugString();
+&literal_proto));
+xla::Literal literal(literal_proto);
+LOG(INFO) << "literal: " << literal_proto.ShortDebugString();
 fprintf(stderr, "%s\n", xla::LiteralUtil::ToString(literal).c_str());
 }
@@ -92,11 +92,11 @@ message TransferToClientRequest {
 }

 message TransferToClientResponse {
-Literal literal = 1;
+LiteralProto literal = 1;
 }

 message TransferToServerRequest {
-Literal literal = 1;
+LiteralProto literal = 1;
 DeviceHandle device_handle = 2;
 }

@@ -105,7 +105,7 @@ message TransferToServerResponse {
 }

 message TransferToInfeedRequest {
-Literal literal = 1;
+LiteralProto literal = 1;
 int64 replica_id = 2;
 DeviceHandle device_handle = 3;
 }
@@ -123,7 +123,7 @@ message TransferFromOutfeedRequest {
 }

 message TransferFromOutfeedResponse {
-Literal literal = 1;
+LiteralProto literal = 1;
 }

 message ResetDeviceRequest {
@@ -275,7 +275,7 @@ message ChannelHandle {
 //
 // Transfers to/from the client are encoded in literal form, and the structure
 // of the repeated fields is implied by the shape.
-message Literal {
+message LiteralProto {
 Shape shape = 1;
 repeated bool preds = 2;
 bytes u8s = 3;
@@ -285,7 +285,7 @@ message Literal {
 repeated uint64 u64s = 7;
 repeated float f32s = 8;
 repeated double f64s = 9;
-repeated Literal tuple_literals = 10;
+repeated LiteralProto tuple_literals = 10;
 bytes f16s = 11; // Note: the F16s are encoded in little endian byte order
 }

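Aside: the renamed LiteralProto message above is the form that the tools and tests in this diff serialize to disk. A small sketch of that round trip, restricted to calls that appear elsewhere in the diff (the include paths and the /tmp file name are assumptions):

#include <cstdio>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"

int main() {
  const char* kPath = "/tmp/example_literal.pb";  // hypothetical path

  // Write a literal to disk as a binary LiteralProto.
  auto literal = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
  TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), kPath,
                                           literal->ToProto()));

  // Read the proto back and rewrap it as an in-memory xla::Literal.
  xla::LiteralProto literal_proto;
  TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), kPath,
                                          &literal_proto));
  xla::Literal round_tripped(literal_proto);
  std::fprintf(stderr, "%s\n",
               xla::LiteralUtil::ToString(round_tripped).c_str());
  return 0;
}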
@ -337,7 +337,7 @@ message Window {
|
|||||||
// field in OpRequest.
|
// field in OpRequest.
|
||||||
|
|
||||||
message ConstantRequest {
|
message ConstantRequest {
|
||||||
Literal literal = 2;
|
LiteralProto literal = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message GetTupleElementRequest {
|
message GetTupleElementRequest {
|
||||||
|
@ -85,6 +85,7 @@ cc_library(
|
|||||||
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels",
|
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels",
|
||||||
"//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
|
"//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
|
||||||
"//tensorflow/contrib/nccl:nccl_kernels",
|
"//tensorflow/contrib/nccl:nccl_kernels",
|
||||||
|
"//tensorflow/contrib/seq2seq:beam_search_ops_kernels",
|
||||||
"//tensorflow/contrib/tensor_forest:tensor_forest_kernels",
|
"//tensorflow/contrib/tensor_forest:tensor_forest_kernels",
|
||||||
"//tensorflow/contrib/text:all_kernels",
|
"//tensorflow/contrib/text:all_kernels",
|
||||||
],
|
],
|
||||||
@ -100,6 +101,7 @@ cc_library(
|
|||||||
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib",
|
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib",
|
||||||
"//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
|
"//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
|
||||||
"//tensorflow/contrib/nccl:nccl_ops_op_lib",
|
"//tensorflow/contrib/nccl:nccl_ops_op_lib",
|
||||||
|
"//tensorflow/contrib/seq2seq:beam_search_ops_op_lib",
|
||||||
"//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib",
|
"//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib",
|
||||||
"//tensorflow/contrib/text:all_ops",
|
"//tensorflow/contrib/text:all_ops",
|
||||||
],
|
],
|
||||||
|
@ -347,6 +347,7 @@ class BatchResource : public ResourceBase {
|
|||||||
|
|
||||||
// Concatenate the tasks ith input tensors into a big output tensor.
|
// Concatenate the tasks ith input tensors into a big output tensor.
|
||||||
std::vector<Tensor> to_concatenate;
|
std::vector<Tensor> to_concatenate;
|
||||||
|
to_concatenate.reserve(batch->num_tasks());
|
||||||
for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
|
for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
|
||||||
to_concatenate.push_back(batch->task(task_idx).inputs.at(i));
|
to_concatenate.push_back(batch->task(task_idx).inputs.at(i));
|
||||||
}
|
}
|
||||||
|
@ -139,6 +139,7 @@ TEST(SharedBatchSchedulerTest, ObeyBatchSizeConstraint) {
|
|||||||
&callback_data](std::unique_ptr<Batch<FakeTask>> batch) {
|
&callback_data](std::unique_ptr<Batch<FakeTask>> batch) {
|
||||||
ASSERT_TRUE(batch->IsClosed());
|
ASSERT_TRUE(batch->IsClosed());
|
||||||
std::vector<size_t> batch_data;
|
std::vector<size_t> batch_data;
|
||||||
|
batch_data.reserve(batch->num_tasks());
|
||||||
for (int i = 0; i < batch->num_tasks(); ++i) {
|
for (int i = 0; i < batch->num_tasks(); ++i) {
|
||||||
batch_data.push_back(batch->mutable_task(i)->size());
|
batch_data.push_back(batch->mutable_task(i)->size());
|
||||||
}
|
}
|
||||||
|
@ -295,6 +295,7 @@ void ExpectVecsEquiv(const std::vector<float>& vec1,
|
|||||||
std::vector<float> GetWeightsByIndex(const std::vector<float>& weights,
|
std::vector<float> GetWeightsByIndex(const std::vector<float>& weights,
|
||||||
const std::vector<int>& indices) {
|
const std::vector<int>& indices) {
|
||||||
std::vector<float> res;
|
std::vector<float> res;
|
||||||
|
res.reserve(indices.size());
|
||||||
for (const int index : indices) {
|
for (const int index : indices) {
|
||||||
res.push_back(weights[index]);
|
      res.push_back(weights[index]);
    }

@ -236,6 +236,9 @@ add_python_module("tensorflow/tensorboard")
 add_python_module("tensorflow/tensorboard/backend")
 add_python_module("tensorflow/tensorboard/backend/event_processing")
 add_python_module("tensorflow/tensorboard/plugins")
+add_python_module("tensorflow/tensorboard/plugins/audio")
+add_python_module("tensorflow/tensorboard/plugins/distributions")
+add_python_module("tensorflow/tensorboard/plugins/graphs")
 add_python_module("tensorflow/tensorboard/plugins/histograms")
 add_python_module("tensorflow/tensorboard/plugins/images")
 add_python_module("tensorflow/tensorboard/plugins/projector")
@ -536,6 +539,7 @@ set(tf_python_op_gen_main_srcs
     "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc"
     "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc"
     "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h"
+    "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.h"
 )

 add_library(tf_python_op_gen_main OBJECT ${tf_python_op_gen_main_srcs})
@ -209,10 +209,11 @@ if (tensorflow_BUILD_PYTHON_TESTS)
       # Broken TensorBoard tests due to different paths in windows
       "${tensorflow_source_dir}/tensorflow/tensorboard/backend/application_test.py"
       "${tensorflow_source_dir}/tensorflow/tensorboard/lib/python/http_util_test.py"
+      "${tensorflow_source_dir}/tensorflow/tensorboard/plugins/audio/audio_plugin_test.py"
+      "${tensorflow_source_dir}/tensorflow/tensorboard/plugins/images/images_plugin_test.py"
       # Broken tensorboard test due to cmake issues.
       "${tensorflow_source_dir}/tensorflow/tensorboard/plugins/debugger/plugin_test.py"
       "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py"
-      "${tensorflow_source_dir}/tensorflow/tensorboard/plugins/images/images_plugin_test.py"
       # tensor_forest tests (also note that we exclude the hybrid tests for now)
       "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order.
       "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order.
@ -150,7 +150,8 @@ class MapDatasetTest(test.TestCase):
           results.append(sess.run(get_next))
         except errors.OutOfRangeError:
           return
-    threads = [self.checkedThread(target=iterator_thread) for _ in range(8)]
+    threads = [self.checkedThread(target=iterator_thread)
+               for _ in range(64)]
     for t in threads:
       t.start()
     for t in threads:
@ -375,8 +375,8 @@ class NearestNeighborsOp : public OpKernel {
       const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
       const Eigen::Ref<const MatrixXfRowMajor>& centers,
       const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
-      Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
-      Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
+      const Eigen::Ref<MatrixXi64RowMajor>& nearest_center_indices,
+      const Eigen::Ref<MatrixXfRowMajor>& nearest_center_distances) {
     CHECK_LE(k, centers.rows());
     if (centers.rows() <= kNearestNeighborsCentersMaxBlockSize) {
       FindKNearestCentersOneBlock(k, points, points_half_squared_norm, centers,
@ -164,9 +164,10 @@ class KMeans(object):
       with ops.colocate_with(inp):
         # Computes Euclidean distance. Note the first and third terms are
        # broadcast additions.
-        squared_distance = (math_ops.reduce_sum(
-            math_ops.square(inp), 1, keep_dims=True) - 2 * math_ops.matmul(
-                inp, clusters, transpose_b=True) + array_ops.transpose(
+        squared_distance = (
+            math_ops.reduce_sum(math_ops.square(inp), 1, keep_dims=True) -
+            2 * math_ops.matmul(inp, clusters, transpose_b=True) +
+            array_ops.transpose(
                 math_ops.reduce_sum(
                     math_ops.square(clusters), 1, keep_dims=True)))
         output.append(squared_distance)
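The reformatted expression above still computes pairwise squared Euclidean distances through the expansion ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2, with the first and third terms broadcast, exactly as the comment says. A minimal NumPy sketch of that identity (array names and shapes are illustrative, not from the commit):

```python
import numpy as np

inp = np.random.rand(4, 3)       # 4 input points, 3 dimensions (made-up shapes)
clusters = np.random.rand(2, 3)  # 2 cluster centers

# ||x||^2 - 2 x.c + ||c||^2, with the first and third terms broadcast.
squared_distance = (np.sum(inp ** 2, axis=1, keepdims=True)
                    - 2.0 * inp.dot(clusters.T)
                    + np.sum(clusters ** 2, axis=1, keepdims=True).T)

# Same result as the direct pairwise computation.
direct = ((inp[:, None, :] - clusters[None, :, :]) ** 2).sum(axis=-1)
assert np.allclose(squared_distance, direct)
```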
@ -229,12 +230,12 @@ class KMeans(object):
       clusters = nn_impl.l2_normalize(clusters, dim=1)
     for inp, score in zip(inputs, scores):
       with ops.colocate_with(inp):
-        (indices,
-         distances) = gen_clustering_ops.nearest_neighbors(inp, clusters, 1)
+        (indices, distances) = gen_clustering_ops.nearest_neighbors(
+            inp, clusters, 1)
         if self._distance_metric == COSINE_DISTANCE:
           distances *= 0.5
-        output.append(
-            (score, array_ops.squeeze(distances), array_ops.squeeze(indices)))
+        output.append((score, array_ops.squeeze(distances),
+                       array_ops.squeeze(indices)))
     return zip(*output)

   def _init_clusters_random(self):
@ -265,9 +266,7 @@ class KMeans(object):
             (not self._use_mini_batch or
              self._mini_batch_steps_per_iteration > 1))

-  def _initialize_clusters(self,
-                           cluster_centers,
-                           cluster_centers_initialized,
+  def _initialize_clusters(self, cluster_centers, cluster_centers_initialized,
                            cluster_centers_updated):
     """Returns an op to initialize the cluster centers."""

@ -294,21 +293,19 @@ class KMeans(object):

     with ops.colocate_with(cluster_centers_initialized):
       initialized = control_flow_ops.with_dependencies(
-          [clusters_init],
-          array_ops.identity(cluster_centers_initialized))
+          [clusters_init], array_ops.identity(cluster_centers_initialized))
     with ops.colocate_with(cluster_centers):
-      assign_centers = state_ops.assign(cluster_centers, clusters_init,
-                                        validate_shape=False)
+      assign_centers = state_ops.assign(
+          cluster_centers, clusters_init, validate_shape=False)
       if cluster_centers_updated != cluster_centers:
-        assign_centers = control_flow_ops.group(
-            assign_centers,
-            state_ops.assign(cluster_centers_updated, clusters_init,
+        assign_centers = control_flow_ops.group(assign_centers,
+                                                state_ops.assign(
+                                                    cluster_centers_updated,
+                                                    clusters_init,
                                                     validate_shape=False))
       assign_centers = control_flow_ops.with_dependencies(
-          [assign_centers],
-          state_ops.assign(cluster_centers_initialized, True))
-      return control_flow_ops.cond(initialized,
-                                   control_flow_ops.no_op,
+          [assign_centers], state_ops.assign(cluster_centers_initialized, True))
+      return control_flow_ops.cond(initialized, control_flow_ops.no_op,
                                    lambda: assign_centers).op

   def _create_variables(self):
@ -327,19 +324,16 @@ class KMeans(object):
       cluster_centers_updated back to cluster_centers.
     """
     init_value = array_ops.constant([], dtype=dtypes.float32)
-    cluster_centers = variable_scope.variable(init_value,
-                                              name='clusters',
-                                              validate_shape=False)
-    cluster_centers_initialized = variable_scope.variable(False,
-                                                          dtype=dtypes.bool,
-                                                          name='initialized')
+    cluster_centers = variable_scope.variable(
+        init_value, name='clusters', validate_shape=False)
+    cluster_centers_initialized = variable_scope.variable(
+        False, dtype=dtypes.bool, name='initialized')

     if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
       # Copy of cluster centers actively updated each step according to
       # mini-batch update rule.
-      cluster_centers_updated = variable_scope.variable(init_value,
-                                                        name='clusters_updated',
-                                                        validate_shape=False)
+      cluster_centers_updated = variable_scope.variable(
+          init_value, name='clusters_updated', validate_shape=False)
       # How many steps till we copy the updated clusters to cluster_centers.
       update_in_steps = variable_scope.variable(
           self._mini_batch_steps_per_iteration,
@ -347,20 +341,15 @@ class KMeans(object):
           name='update_in_steps')
       # Count of points assigned to cluster_centers_updated.
       cluster_counts = variable_scope.variable(
-          array_ops.zeros([self._num_clusters],
-                          dtype=dtypes.int64))
+          array_ops.zeros([self._num_clusters], dtype=dtypes.int64))
     else:
       cluster_centers_updated = cluster_centers
       update_in_steps = None
-      cluster_counts = (variable_scope.variable(array_ops.ones(
-          [self._num_clusters],
-          dtype=dtypes.int64))
+      cluster_counts = (variable_scope.variable(
+          array_ops.ones([self._num_clusters], dtype=dtypes.int64))
                         if self._use_mini_batch else None)
-    return (cluster_centers,
-            cluster_centers_initialized,
-            cluster_counts,
-            cluster_centers_updated,
-            update_in_steps)
+    return (cluster_centers, cluster_centers_initialized, cluster_counts,
+            cluster_centers_updated, update_in_steps)

   @classmethod
   def _l2_normalize_data(cls, inputs):
@ -391,11 +380,8 @@ class KMeans(object):
     """
     # Implementation of kmeans.
     inputs = self._inputs
-    (cluster_centers_var,
-     cluster_centers_initialized,
-     total_counts,
-     cluster_centers_updated,
-     update_in_steps) = self._create_variables()
+    (cluster_centers_var, cluster_centers_initialized, total_counts,
+     cluster_centers_updated, update_in_steps) = self._create_variables()
     init_op = self._initialize_clusters(cluster_centers_var,
                                         cluster_centers_initialized,
                                         cluster_centers_updated)
@ -409,8 +395,7 @@ class KMeans(object):
     all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers)
     if self._use_mini_batch:
       sync_updates_op = self._mini_batch_sync_updates_op(
-          update_in_steps,
-          cluster_centers_var, cluster_centers_updated,
+          update_in_steps, cluster_centers_var, cluster_centers_updated,
           total_counts)
       assert sync_updates_op is not None
       with ops.control_dependencies([sync_updates_op]):
@ -421,15 +406,15 @@ class KMeans(object):
       training_op = self._full_batch_training_op(inputs, cluster_idx,
                                                  cluster_centers_var)

-    return (all_scores, cluster_idx, scores,
-            cluster_centers_initialized, init_op, training_op)
+    return (all_scores, cluster_idx, scores, cluster_centers_initialized,
+            init_op, training_op)

-  def _mini_batch_sync_updates_op(self, update_in_steps,
-                                  cluster_centers_var, cluster_centers_updated,
-                                  total_counts):
+  def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var,
+                                  cluster_centers_updated, total_counts):
     if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
       assert update_in_steps is not None
       with ops.colocate_with(update_in_steps):

         def _f():
           # Note that there is a race condition here, so we do a best effort
           # updates here. We reset update_in_steps first so that other workers
@ -437,33 +422,36 @@ class KMeans(object):
           # before resetting total_counts to avoid large updates to
           # cluster_centers_updated based on partially updated
           # cluster_center_vars.
-          with ops.control_dependencies([state_ops.assign(
-              update_in_steps,
-              self._mini_batch_steps_per_iteration - 1)]):
-            with ops.colocate_with(cluster_centers_updated):
+          with ops.control_dependencies([
+              state_ops.assign(update_in_steps,
+                               self._mini_batch_steps_per_iteration - 1)
+          ]):
+            with ops.colocate_with(
+                cluster_centers_updated, ignore_existing=True):
               if self._distance_metric == COSINE_DISTANCE:
-                cluster_centers = nn_impl.l2_normalize(cluster_centers_updated,
-                                                       dim=1)
+                cluster_centers = nn_impl.l2_normalize(
+                    cluster_centers_updated, dim=1)
               else:
                 cluster_centers = cluster_centers_updated
             with ops.colocate_with(cluster_centers_var):
-              with ops.control_dependencies([state_ops.assign(
-                  cluster_centers_var,
-                  cluster_centers)]):
-                with ops.colocate_with(cluster_centers_var):
+              with ops.control_dependencies(
+                  [state_ops.assign(cluster_centers_var, cluster_centers)]):
+                with ops.colocate_with(
+                    cluster_centers_var, ignore_existing=True):
                   with ops.control_dependencies([
                       state_ops.assign(total_counts,
-                                       array_ops.zeros_like(total_counts))]):
+                                       array_ops.zeros_like(total_counts))
+                  ]):
                     return array_ops.identity(update_in_steps)

         return control_flow_ops.cond(
-            update_in_steps <= 0,
-            _f,
+            update_in_steps <= 0, _f,
             lambda: state_ops.assign_sub(update_in_steps, 1))
     else:
       return control_flow_ops.no_op()

-  def _mini_batch_training_op(self, inputs, cluster_idx_list,
-                              cluster_centers, total_counts):
+  def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers,
                               total_counts):
     """Creates an op for training for mini batch case.

     Args:
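The hunks above and below repeatedly switch to `ops.colocate_with(..., ignore_existing=True)`. For reference, a small sketch of that pattern in TF 1.x graph mode (the variable below is hypothetical): passing `ignore_existing=True` drops colocation constraints inherited from enclosing scopes and pins the new ops only to the given op.

```python
import tensorflow as tf

counts = tf.Variable(tf.zeros([10], dtype=tf.int64), name="counts")

with tf.colocate_with(counts, ignore_existing=True):
  # Ops created here are colocated with `counts` only; any colocation
  # constraint active in an outer scope is ignored for this block.
  gathered = tf.gather(counts, [0, 1, 2])
```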
@ -487,17 +475,15 @@ class KMeans(object):
       unique_ids, unique_idx = array_ops.unique(cluster_idx)
       num_unique_cluster_idx = array_ops.size(unique_ids)
       # Fetch the old values of counts and cluster_centers.
-      with ops.colocate_with(total_counts):
+      with ops.colocate_with(total_counts, ignore_existing=True):
         old_counts = array_ops.gather(total_counts, unique_ids)
       # TODO(agarwal): This colocation seems to run into problems. Fix it.
-      # with ops.colocate_with(cluster_centers):
-      old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
+      with ops.colocate_with(cluster_centers, ignore_existing=True):
+        old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
       # Locally aggregate the increment to counts.
       count_updates = math_ops.unsorted_segment_sum(
-          array_ops.ones_like(
-              unique_idx, dtype=total_counts.dtype),
-          unique_idx,
-          num_unique_cluster_idx)
+          array_ops.ones_like(unique_idx, dtype=total_counts.dtype),
+          unique_idx, num_unique_cluster_idx)
       # Locally compute the sum of inputs mapped to each id.
       # For a cluster with old cluster value x, old count n, and with data
       # d_1,...d_k newly assigned to it, we recompute the new value as
@ -507,13 +493,12 @@ class KMeans(object):
           inp, unique_idx, num_unique_cluster_idx)
       # Shape to enable broadcasting count_updates and learning_rate to inp.
       # It extends the shape with 1's to match the rank of inp.
-      broadcast_shape = array_ops.concat(
-          [
-              array_ops.reshape(num_unique_cluster_idx, [1]), array_ops.ones(
+      broadcast_shape = array_ops.concat([
+          array_ops.reshape(num_unique_cluster_idx, [1]),
+          array_ops.ones(
               array_ops.reshape(array_ops.rank(inp) - 1, [1]),
               dtype=dtypes.int32)
-          ],
-          0)
+      ], 0)
       # Subtract k * x, see comment above.
       cluster_center_updates -= math_ops.cast(
           array_ops.reshape(count_updates, broadcast_shape),
@ -524,14 +509,10 @@ class KMeans(object):
       # scale by 1 / (n + k), see comment above.
       cluster_center_updates *= learning_rate
       # Apply the updates.
-      update_counts = state_ops.scatter_add(
-          total_counts,
-          unique_ids,
+      update_counts = state_ops.scatter_add(total_counts, unique_ids,
                                             count_updates)
       update_cluster_centers = state_ops.scatter_add(
-          cluster_centers,
-          unique_ids,
-          cluster_center_updates)
+          cluster_centers, unique_ids, cluster_center_updates)
       update_ops.extend([update_counts, update_cluster_centers])
       return control_flow_ops.group(*update_ops)

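Restating the comments in the hunks above as a single formula (no new behavior is implied): for a cluster center with old value x and old count n that receives k newly assigned points d_1, ..., d_k in the mini batch, the subtract, scale, and `scatter_add` steps compute

```latex
x_{\text{new}} = x + \frac{\sum_{i=1}^{k} d_i - k\,x}{n + k}
               = \frac{n\,x + \sum_{i=1}^{k} d_i}{n + k}.
```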
@ -552,7 +533,7 @@ class KMeans(object):
     cluster_counts = []
     epsilon = constant_op.constant(1e-6, dtype=inputs[0].dtype)
     for inp, cluster_idx in zip(inputs, cluster_idx_list):
-      with ops.colocate_with(inp):
+      with ops.colocate_with(inp, ignore_existing=True):
         cluster_sums.append(
             math_ops.unsorted_segment_sum(inp, cluster_idx, self._num_clusters))
         cluster_counts.append(
@ -561,7 +542,7 @@ class KMeans(object):
                 array_ops.ones(
                     array_ops.reshape(array_ops.shape(inp)[0], [-1])),
                 [-1, 1]), cluster_idx, self._num_clusters))
-    with ops.colocate_with(cluster_centers):
+    with ops.colocate_with(cluster_centers, ignore_existing=True):
       new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast(
           math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon)
       if self._clusters_l2_normalized():
@ -94,6 +94,7 @@ TEST(FfmpegLibTest, TestRoundTripGeneratedWav) {
   }

   std::vector<float> sine_wave;
+  sine_wave.reserve(20000);
   for (int i = 0; i < 20000; ++i) {
     sine_wave.push_back(std::sin(6.28 * 440.0 * i / 20000.0));
   }
@ -494,6 +494,7 @@ class SparseFeatureCrossOp : public OpKernel {
     ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
                        &feature_start_indices);

+    columns.reserve(values_list_in.size());
     for (int i = 0; i < values_list_in.size(); ++i) {
       columns.emplace_back(new SparseTensorColumn<InternalType>(
           values_list_in[i], std::move(feature_counts[i]),
@ -308,6 +308,7 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_rea
 from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat
 from tensorflow.contrib.learn.python.learn.estimators.head import binary_svm_head
 from tensorflow.contrib.learn.python.learn.estimators.head import Head
+from tensorflow.contrib.learn.python.learn.estimators.head import loss_only_head
 from tensorflow.contrib.learn.python.learn.estimators.head import multi_class_head
 from tensorflow.contrib.learn.python.learn.estimators.head import multi_head
 from tensorflow.contrib.learn.python.learn.estimators.head import multi_label_head
@ -429,6 +429,23 @@ def multi_label_head(n_classes,
       loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None)


+def loss_only_head(loss_fn, head_name=None):
+  """Creates a Head that contains only loss terms.
+
+  Loss only head holds additional loss terms to be added to other heads and
+  usually represents additional regularization terms in the objective function.
+
+  Args:
+    loss_fn: a function that takes no argument and returns a list of
+        scalar tensors.
+    head_name: a name for for the head.
+
+  Returns:
+    An instance of `Head` to hold the additional losses.
+  """
+  return _LossOnlyHead(loss_fn, head_name=head_name)
+
+
 def multi_head(heads, loss_weights=None):
   """Creates a MultiHead stemming from same logits/hidden layer.

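A hedged sketch of how the new `loss_only_head` combines with other heads, following the docstring just added above (the regularization term below is a made-up placeholder, not part of the commit):

```python
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib

def regularization_terms():
  # Hypothetical loss_fn: takes no arguments, returns a list of scalar tensors.
  return [tf.constant(0.01)]

head = head_lib.multi_head([
    head_lib.multi_class_head(n_classes=3, head_name="classes"),
    head_lib.loss_only_head(regularization_terms, head_name="regularizer"),
])
```

The multi head then folds the extra scalar(s) into the combined training loss while contributing no predictions or eval metrics, which is what the `_LossOnlyHead` implementation in the next hunk makes explicit.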
@ -1406,6 +1423,80 @@ class _MultiLabelHead(_SingleHead):
     return metrics


+class _LossOnlyHead(Head):
+  """`Head` implementation for additional loss terms.
+
+  This class only holds loss terms unrelated to any other heads (labels),
+  e.g. regularization.
+
+  Common usage:
+  This is oftem combine with other heads in a multi head setup.
+    ```python
+    head = multi_head([
+        head1, head2, loss_only_head('regularizer', regularizer)])
+    ```
+  """
+
+  def __init__(self, loss_fn, head_name=None):
+    self._loss_fn = loss_fn
+    self.head_name = head_name or "loss_only_head"
+
+  @property
+  def logits_dimension(self):
+    return 0
+
+  def create_model_fn_ops(self,
+                          features,
+                          mode,
+                          labels=None,
+                          train_op_fn=None,
+                          logits=None,
+                          logits_input=None,
+                          scope=None):
+    """See `_Head.create_model_fn_ops`.
+
+    Args:
+      features: Not been used.
+      mode: Estimator's `ModeKeys`.
+      labels: Labels `Tensor`, or `dict` of same.
+      train_op_fn: Function that takes a scalar loss and returns an op to
+          optimize with the loss.
+      logits: Not been used.
+      logits_input: Not been used.
+      scope: Optional scope for variable_scope. If provided, will be passed to
+          all heads. Most users will want to set this to `None`, so each head
+          constructs a separate variable_scope according to its `head_name`.
+
+    Returns:
+      A `ModelFnOps` object.
+
+    Raises:
+      ValueError: if `mode` is not recognition.
+    """
+    _check_mode_valid(mode)
+    loss = None
+    train_op = None
+    if mode != model_fn.ModeKeys.INFER:
+      with variable_scope.variable_scope(scope, default_name=self.head_name):
+        loss = self._loss_fn()
+        if isinstance(loss, list):
+          loss = math_ops.add_n(loss)
+        logging_ops.scalar_summary(
+            _summary_key(self.head_name, mkey.LOSS), loss)
+        if mode == model_fn.ModeKeys.TRAIN:
+          if train_op_fn is None:
+            raise ValueError("train_op_fn can not be None in TRAIN mode")
+          with ops.name_scope(None, "train_op", (loss,)):
+            train_op = train_op_fn(loss)
+
+    return model_fn.ModelFnOps(
+        mode=mode,
+        loss=loss,
+        train_op=train_op,
+        predictions={},
+        eval_metric_ops={})
+
+
 class _MultiHead(Head):
   """`Head` implementation for multi objective learning.

@ -1525,6 +1616,9 @@ class _MultiHead(Head):
     if isinstance(logits, dict):
       head_logits_pairs = []
       for head in self._heads:
+        if isinstance(head, _LossOnlyHead):
+          head_logits_pairs.append((head, None))
+        else:
           head_logits_pairs.append((head, logits[head.head_name]))
     else:
       # Split logits for each head.
@ -1606,6 +1700,8 @@ class _MultiHead(Head):
     predictions = {}
     output_alternatives = {}
     for head, m in zip(self._heads, all_model_fn_ops):
+      if isinstance(head, _LossOnlyHead):
+        continue
       head_name = head.head_name
       output_alternatives[head_name] = m.output_alternatives[head_name]
       for k, v in m.predictions.items():
@ -1638,6 +1638,21 @@ class BinarySvmHeadTest(test.TestCase):
     }, model_fn_ops)


+class LossOnlyHead(test.TestCase):
+
+  def testNoPredictionsAndNoMetrics(self):
+    head = head_lib.loss_only_head(lambda: 1, head_name="const")
+    model_fn_ops = head.create_model_fn_ops(
+        features={},
+        mode=model_fn.ModeKeys.TRAIN,
+        train_op_fn=head_lib.no_op_train_fn)
+    self.assertDictEqual(model_fn_ops.predictions, {})
+    self.assertDictEqual(model_fn_ops.eval_metric_ops, {})
+    self.assertIsNotNone(model_fn_ops.loss)
+    with session.Session() as sess:
+      self.assertEqual(1, sess.run(model_fn_ops.loss))
+
+
 class MultiHeadTest(test.TestCase):

   def testInvalidHeads(self):
@ -1672,7 +1687,8 @@ class MultiHeadTest(test.TestCase):
         n_classes=3, label_name="label1", head_name="head1")
     head2 = head_lib.multi_class_head(
         n_classes=4, label_name="label2", head_name="head2")
-    head = head_lib.multi_head((head1, head2))
+    head3 = head_lib.loss_only_head(lambda: 1.0, head_name="const")
+    head = head_lib.multi_head((head1, head2, head3))
     labels = {
         "label1": (1,),
         "label2": (1,)
@ -1691,7 +1707,7 @@ class MultiHeadTest(test.TestCase):
     self.assertIsNone(model_fn_ops.output_alternatives)

     with session.Session() as sess:
-      self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3)
+      self.assertAlmostEqual(3.224, sess.run(model_fn_ops.loss), places=3)

   def testTrain_withHeadWeights(self):
     head1 = head_lib.multi_class_head(
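The updated assertion follows directly from the constant head added in the previous hunk: `loss_only_head(lambda: 1.0, ...)` contributes a flat 1.0 on top of the earlier combined loss of the two classification heads, so

```latex
2.224 + 1.0 = 3.224
```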
@ -871,7 +871,7 @@ def index_table_from_file(vocabulary_file=None,
   ```

   Args:
-    vocabulary_file: The vocabulary filename.
+    vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
     num_oov_buckets: The number of out-of-vocabulary buckets.
     vocab_size: Number of the elements in the vocabulary, if known.
     default_value: The value to use for out-of-vocabulary feature values.
@ -889,8 +889,9 @@ def index_table_from_file(vocabulary_file=None,
     ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
       than zero.
   """
-  if not vocabulary_file:
-    raise ValueError("vocabulary_file must be specified.")
+  if vocabulary_file is None or (
+      isinstance(vocabulary_file, str) and not vocabulary_file):
+    raise ValueError("vocabulary_file must be specified and must not be empty.")
   if num_oov_buckets < 0:
     raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
                      % num_oov_buckets)
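Mirroring the docstring and validation change above, a short sketch of passing the vocabulary filename as a constant scalar tensor (TF 1.x contrib API; `vocab.txt` is a placeholder path):

```python
import tensorflow as tf

vocabulary_file = tf.constant("vocab.txt")  # placeholder path
table = tf.contrib.lookup.index_table_from_file(
    vocabulary_file=vocabulary_file, num_oov_buckets=1)
ids = table.lookup(tf.constant(["salad", "surgery", "tarkus"]))

with tf.Session() as sess:
  sess.run(tf.tables_initializer())
  print(sess.run(ids))
```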
@ -1187,6 +1187,18 @@ class IndexTableFromFile(test.TestCase):
       lookup_ops.tables_initializer().run()
       self.assertAllEqual((1, 2, 3), ids.eval())

+  def test_string_index_table_from_file_tensor_filename(self):
+    vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
+    with self.test_session():
+      vocabulary_file = constant_op.constant(vocabulary_file)
+      table = lookup.index_table_from_file(
+          vocabulary_file=vocabulary_file, num_oov_buckets=1)
+      ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
+
+      self.assertRaises(errors_impl.OpError, ids.eval)
+      lookup_ops.tables_initializer().run()
+      self.assertAllEqual((1, 2, 3), ids.eval())
+
   def test_int32_index_table_from_file(self):
     vocabulary_file = self._createVocabFile(
         "f2i_vocab2.txt", values=("42", "1", "-1000"))
@ -1245,7 +1257,13 @@ class IndexTableFromFile(test.TestCase):
            860),  # 3 + fingerprint("toccata") mod 300.
           ids.eval())

-  def test_index_table_from_file_with_only_oov_buckets(self):
+  def test_index_table_from_file_fails_with_empty_vocabulary_file_name(self):
+    self.assertRaises(
+        ValueError,
+        lookup.index_table_from_file,
+        vocabulary_file="")
+
+  def test_index_table_from_file_fails_with_empty_vocabulary(self):
     self.assertRaises(
         ValueError,
         lookup.index_table_from_file,
@ -23,6 +23,7 @@ See the @{$python/contrib.metrics} guide.
 @@streaming_precision
 @@streaming_precision_at_thresholds
 @@streaming_auc
+@@streaming_curve_points
 @@streaming_recall_at_k
 @@streaming_mean_absolute_error
 @@streaming_mean_iou
@ -76,6 +77,7 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_accuracy
 from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc
 from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat
 from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance
+from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_curve_points
 from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives
 from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives_at_thresholds
 from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positives
@ -733,6 +733,102 @@ def streaming_true_negatives_at_thresholds(
   return values['tn'], update_ops['tn']


+def streaming_curve_points(labels=None,
+                           predictions=None,
+                           weights=None,
+                           num_thresholds=200,
+                           metrics_collections=None,
+                           updates_collections=None,
+                           curve='ROC',
+                           name=None):
+  """Computes curve (ROC or PR) values for a prespecified number of points.
+
+  The `streaming_curve_points` function creates four local variables,
+  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
+  that are used to compute the curve values. To discretize the curve, a linearly
+  spaced set of thresholds is used to compute pairs of recall and precision
+  values.
+
+  For best results, `predictions` should be distributed approximately uniformly
+  in the range [0, 1] and not peaked around 0 or 1.
+
+  For estimation of the metric over a stream of data, the function creates an
+  `update_op` operation that updates these variables.
+
+  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+  Args:
+    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
+      `bool`.
+    predictions: A floating point `Tensor` of arbitrary shape and whose values
+      are in the range `[0, 1]`.
+    weights: Optional `Tensor` whose rank is either 0, or the same rank as
+      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+      be either `1`, or the same as the corresponding `labels` dimension).
+    num_thresholds: The number of thresholds to use when discretizing the roc
+      curve.
+    metrics_collections: An optional list of collections that `auc` should be
+      added to.
+    updates_collections: An optional list of collections that `update_op` should
+      be added to.
+    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
+      'PR' for the Precision-Recall-curve.
+    name: An optional variable_scope name.
+
+  Returns:
+    points: A `Tensor` with shape [num_thresholds, 2] that contains points of
+      the curve.
+    update_op: An operation that increments the `true_positives`,
+      `true_negatives`, `false_positives` and `false_negatives` variables.
+
+  Raises:
+    ValueError: If `predictions` and `labels` have mismatched shapes, or if
+      `weights` is not `None` and its shape doesn't match `predictions`, or if
+      either `metrics_collections` or `updates_collections` are not a list or
+      tuple.
+  """
+  with variable_scope.variable_scope(name, 'curve_points', (labels, predictions,
+                                                            weights)):
+    if curve != 'ROC' and curve != 'PR':
+      raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
+    kepsilon = 1e-7  # to account for floating point imprecisions
+    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
+                  for i in range(num_thresholds - 2)]
+    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
+
+    values, update_ops = _streaming_confusion_matrix_at_thresholds(
+        labels=labels,
+        predictions=predictions,
+        thresholds=thresholds,
+        weights=weights)
+
+    # Add epsilons to avoid dividing by 0.
+    epsilon = 1.0e-6
+
+    def compute_points(tp, fn, tn, fp):
+      """Computes the roc-auc or pr-auc based on confusion counts."""
+      rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
+      if curve == 'ROC':
+        fp_rate = math_ops.div(fp, fp + tn + epsilon)
+        return fp_rate, rec
+      else:  # curve == 'PR'.
+        prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
+        return rec, prec
+
+    xs, ys = compute_points(values['tp'], values['fn'], values['tn'],
+                            values['fp'])
+    points = array_ops.stack([xs, ys], axis=1)
+    update_op = control_flow_ops.group(*update_ops.values())
+
+    if metrics_collections:
+      ops.add_to_collections(metrics_collections, points)
+
+    if updates_collections:
+      ops.add_to_collections(updates_collections, update_op)
+
+    return points, update_op
+
+
 def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
                   metrics_collections=None, updates_collections=None,
                   curve='ROC', name=None):
@ -2372,6 +2468,7 @@ __all__ = [
     'sparse_recall_at_top_k',
     'streaming_accuracy',
     'streaming_auc',
+    'streaming_curve_points',
     'streaming_false_negatives',
     'streaming_false_negatives_at_thresholds',
     'streaming_false_positives',
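A usage sketch for the new streaming metric, assuming the `tf.contrib.metrics.streaming_curve_points` export implied by the `__all__` and import changes above; the inputs reproduce the `testManyValuesPR` case from the tests that follow:

```python
import tensorflow as tf

labels = tf.constant([[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]])
predictions = tf.constant([[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]])

# points has shape [num_thresholds, 2]; update_op refreshes the confusion counts.
points, update_op = tf.contrib.metrics.streaming_curve_points(
    labels=labels, predictions=predictions, num_thresholds=3, curve='PR')

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  # Per the test below, approximately [[1.0, 4/6], [0.75, 1.0], [0.0, 1.0]].
  print(sess.run(points))
```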
@ -1327,6 +1327,99 @@ class StreamingRecallTest(test.TestCase):
       self.assertEqual(0, recall.eval())


+class StreamingCurvePointsTest(test.TestCase):
+
+  def setUp(self):
+    np.random.seed(1)
+    ops.reset_default_graph()
+
+  def testVars(self):
+    metric_ops.streaming_curve_points(
+        predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
+    _assert_local_variables(
+        self,
+        ('curve_points/true_positives:0', 'curve_points/false_negatives:0',
+         'curve_points/false_positives:0', 'curve_points/true_negatives:0'))
+
+  def testMetricsCollection(self):
+    my_collection_name = '__metrics__'
+    points, _ = metric_ops.streaming_curve_points(
+        labels=array_ops.ones((10, 1)),
+        predictions=array_ops.ones((10, 1)),
+        metrics_collections=[my_collection_name])
+    self.assertListEqual(ops.get_collection(my_collection_name), [points])
+
+  def testUpdatesCollection(self):
+    my_collection_name = '__updates__'
+    _, update_op = metric_ops.streaming_curve_points(
+        labels=array_ops.ones((10, 1)),
+        predictions=array_ops.ones((10, 1)),
+        updates_collections=[my_collection_name])
+    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
+
+  def _testValueTensorIsIdempotent(self, curve):
+    predictions = constant_op.constant(
+        np.random.uniform(size=(10, 3)), dtype=dtypes_lib.float32)
+    labels = constant_op.constant(
+        np.random.uniform(high=2, size=(10, 3)), dtype=dtypes_lib.float32)
+
+    points, update_op = metric_ops.streaming_curve_points(
+        labels, predictions=predictions, curve=curve)
+
+    with self.test_session() as sess:
+      sess.run(variables.local_variables_initializer())
+
+      sess.run(update_op)
+      initial_points = points.eval()
+
+      sess.run(update_op)
+      self.assertAllClose(initial_points, points.eval())
+
+  def testValueTensorIsIdempotentROC(self):
+    self._testValueTensorIsIdempotent(curve='ROC')
+
+  def testValueTensorIsIdempotentPR(self):
+    self._testValueTensorIsIdempotent(curve='PR')
+
+  def _testCase(self, labels, predictions, curve, expected_points):
+    with self.test_session() as sess:
+      predictions_tensor = constant_op.constant(
+          predictions, dtype=dtypes_lib.float32)
+      labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.float32)
+      points, update_op = metric_ops.streaming_curve_points(
+          labels=labels_tensor,
+          predictions=predictions_tensor,
+          num_thresholds=3,
+          curve=curve)
+
+      sess.run(variables.local_variables_initializer())
+      sess.run(update_op)
+
+      self.assertAllClose(expected_points, points.eval())
+
+  def testEdgeCasesROC(self):
+    self._testCase([[1]], [[1]], 'ROC', [[0, 1], [0, 1], [0, 0]])
+    self._testCase([[0]], [[0]], 'ROC', [[1, 1], [0, 1], [0, 1]])
+    self._testCase([[0]], [[1]], 'ROC', [[1, 1], [1, 1], [0, 1]])
+    self._testCase([[1]], [[0]], 'ROC', [[0, 1], [0, 0], [0, 0]])
+
+  def testManyValuesROC(self):
+    self._testCase([[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]],
+                   [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], 'ROC',
+                   [[1.0, 1.0], [0.0, 0.75], [0.0, 0.0]])
+
+  def testEdgeCasesPR(self):
+    self._testCase([[1]], [[1]], 'PR', [[1, 1], [1, 1], [0, 1]])
+    self._testCase([[0]], [[0]], 'PR', [[1, 0], [1, 1], [1, 1]])
+    self._testCase([[0]], [[1]], 'PR', [[1, 0], [1, 0], [1, 1]])
+    self._testCase([[1]], [[0]], 'PR', [[1, 1], [0, 1], [0, 1]])
+
+  def testManyValuesPR(self):
+    self._testCase([[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]],
+                   [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], 'PR',
+                   [[1.0, 4.0 / 6.0], [0.75, 1.0], [0.0, 1.0]])
+
+
 class StreamingAUCTest(test.TestCase):

   def setUp(self):
@ -226,8 +226,8 @@ class TestBeamStep(test.TestCase):
 class BeamSearchDecoderTest(test.TestCase):

   def _testDynamicDecodeRNN(self, time_major, has_attention):
-    encoder_sequence_length = [3, 2, 3, 1, 1]
-    decoder_sequence_length = [2, 0, 1, 2, 3]
+    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
+    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
     batch_size = 5
     decoder_max_time = 4
     input_depth = 7
@ -245,6 +245,7 @@ class BeamSearchDecoderTest(test.TestCase):
     batch_size_tensor = constant_op.constant(batch_size)
     embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
     cell = rnn_cell.LSTMCell(cell_depth)
+    initial_state = cell.zero_state(batch_size, dtypes.float32)
     if has_attention:
       inputs = array_ops.placeholder_with_default(
           np.random.randn(batch_size, decoder_max_time,
@ -258,6 +259,8 @@ class BeamSearchDecoderTest(test.TestCase):
           num_units=attention_depth,
           memory=tiled_inputs,
           memory_sequence_length=tiled_sequence_length)
+      initial_state = beam_search_decoder.tile_batch(
+          initial_state, multiplier=beam_width)
       cell = attention_wrapper.AttentionWrapper(
           cell=cell,
           attention_mechanism=attention_mechanism,
@ -265,6 +268,9 @@ class BeamSearchDecoderTest(test.TestCase):
           alignment_history=False)
     cell_state = cell.zero_state(
         dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
+    if has_attention:
+      cell_state = cell_state.clone(
+          cell_state=initial_state)
     bsd = beam_search_decoder.BeamSearchDecoder(
         cell=cell,
         embedding=embedding,
@ -72,27 +72,8 @@ class FinalBeamSearchDecoderOutput(
   pass


-def tile_batch(t, multiplier, name=None):
-  """Tile the batch dimension of tensor t.
-
-  This function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
-  minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape
-  `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries
-  `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated
-  `multiplier` times.
-
-  Args:
-    t: `Tensor` shaped `[batch_size, ...]`.
-    multiplier: Python int.
-    name: Name scope for any created operations.
-
-  Returns:
-    A `Tensor` shaped `[batch_size * multiplier, ...]`.
-
-  Raises:
-    ValueError: if `t` does not have a statically known rank or it's < 1.
-  """
-  with ops.name_scope(name, "tile_batch", [t, multiplier]):
+def _tile_batch(t, multiplier):
+  """Core single-tensor implementation of tile_batch."""
   t = ops.convert_to_tensor(t, name="t")
   shape_t = array_ops.shape(t)
   if t.shape.ndims is None or t.shape.ndims < 1:
@ -110,6 +91,34 @@ def tile_batch(t, multiplier, name=None):
   return tiled


+def tile_batch(t, multiplier, name=None):
+  """Tile the batch dimension of a (possibly nested structure of) tensor(s) t.
+
+  For each tensor t in a (possibly nested structure) of tensors,
+  this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
+  minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape
+  `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries
+  `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated
+  `multiplier` times.
+
+  Args:
+    t: `Tensor` shaped `[batch_size, ...]`.
+    multiplier: Python int.
+    name: Name scope for any created operations.
+
+  Returns:
+    A (possibly nested structure of) `Tensor` shaped
+    `[batch_size * multiplier, ...]`.
+
+  Raises:
+    ValueError: if tensor(s) `t` do not have a statically known rank or
+    the rank is < 1.
+  """
+  flat_t = nest.flatten(t)
+  with ops.name_scope(name, "tile_batch", flat_t + [multiplier]):
+    return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t)
+
+
 def _check_maybe(t):
   if isinstance(t, tensor_array_ops.TensorArray):
     raise TypeError(
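A hedged sketch of what the reworked `tile_batch` enables: tiling a nested RNN state (an `LSTMStateTuple`) for beam search, the same pattern the decoder test above uses for its initial state. Shapes are illustrative, and the `tf.contrib.seq2seq.tile_batch` export path is assumed:

```python
import tensorflow as tf

batch_size, beam_width, num_units = 4, 3, 8
cell = tf.contrib.rnn.LSTMCell(num_units)

# zero_state returns a nested LSTMStateTuple; tile_batch now maps _tile_batch
# over every tensor in the structure instead of accepting a single tensor.
initial_state = cell.zero_state(batch_size, tf.float32)
tiled_state = tf.contrib.seq2seq.tile_batch(initial_state, multiplier=beam_width)

# Each leaf tensor now has leading dimension batch_size * beam_width (12 here).
print(tiled_state.c.shape, tiled_state.h.shape)
```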
@ -270,7 +270,7 @@ class SessionBundleTest : public ::testing::Test {
   // MetaGraphDef.
   // Returns the path of the export.
   // ** Should only be called once per test **
-  string SetupExport(MetaGraphDefTwiddler twiddler) {
+  string SetupExport(const MetaGraphDefTwiddler& twiddler) {
     return SetupExport(twiddler, kVariablesFilename, kMetaGraphDefFilename);
   }
   // SetupExport that allows for the variables and meta_graph_def filenames
@ -62,6 +62,7 @@ licenses(["notice"])  # Apache 2.0

 load(
     "//tensorflow:tensorflow.bzl",
+    "full_path",
     "if_android",
     "if_ios",
     "if_x86",
@ -30,7 +30,11 @@ Device::Device(Env* env, const DeviceAttributes& device_attributes)
   rmgr_ = new ResourceMgr(parsed_name_.job);
 }

-Device::~Device() { delete rmgr_; }
+Device::~Device() {
+  if (rmgr_ != nullptr) {
+    DeleteResourceMgr();
+  }
+}

 // static
 DeviceAttributes Device::BuildDeviceAttributes(
@ -60,7 +60,9 @@ class Device : public DeviceBase {
   const string& name() const { return device_attributes_.name(); }

   // Parsed name of this device
-  const DeviceNameUtils::ParsedName parsed_name() const { return parsed_name_; }
+  const DeviceNameUtils::ParsedName& parsed_name() const {
+    return parsed_name_;
+  }

   // Describes what kind of device this is. This is intended to be
   // human-readable and not computer-parsed, except that two devices
@ -149,6 +151,12 @@ class Device : public DeviceBase {
     return BuildDeviceAttributes(name, device, memory_limit, locality, "");
   }

+ protected:
+  void DeleteResourceMgr() {
+    delete rmgr_;
+    rmgr_ = nullptr;
+  }
+
  private:
   const DeviceAttributes device_attributes_;
   DeviceNameUtils::ParsedName parsed_name_;
@ -53,7 +53,7 @@ Device* DeviceSet::FindDeviceByName(const string& name) const {

 // static
 int DeviceSet::DeviceTypeOrder(const DeviceType& d) {
-  return DeviceFactory::DevicePriority(d.type());
+  return DeviceFactory::DevicePriority(d.type_string());
 }

 static bool DeviceTypeComparator(const DeviceType& a, const DeviceType& b) {
@ -1231,7 +1231,7 @@ Status FunctionDefToBodyHelper(
   GraphConstructorOptions opts;
   opts.allow_internal_ops = true;
   opts.expect_device_spec = false;
-  Status s = ConvertGraphDefToGraph(opts, result.gdef, graph);
+  Status s = ConvertNodeDefsToGraph(opts, result.nodes, graph);
   if (!s.ok()) {
     delete graph;
   } else {
@ -93,7 +93,7 @@ class FunctionTest : public ::testing::Test {
     GraphConstructorOptions opts;
     opts.allow_internal_ops = true;
     opts.expect_device_spec = false;
-    TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g));
+    TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g));

     const int version = g->versions().producer();
     LocalExecutorParams params;
@ -949,7 +949,7 @@ GraphDef Optimize(const std::function<bool(Graph* g)>& pass,
   GraphConstructorOptions opts;
   opts.allow_internal_ops = true;
   opts.expect_device_spec = false;
-  TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g.get()));
+  TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g.get()));
   pass(g.get());
   std::unique_ptr<Graph> g1(new Graph(OpRegistry::Global()));
   CopyGraph(*g, g1.get());
@ -324,6 +324,7 @@ static void BM_AllocationDelayed(int iters, int delay) {
   int size_index = 0;

   std::vector<void*> ptrs;
+  ptrs.reserve(delay);
   for (int i = 0; i < delay; i++) {
     ptrs.push_back(nullptr);
   }
@ -123,10 +123,12 @@ void Benchmark::RunWithArgs(
   }
   // Gets inputs' and outputs' rendezvous keys.
   std::vector<std::pair<string, Tensor>> in;
+  in.reserve(inputs.size());
   for (const auto& p : inputs) {
     in.push_back({GetRendezvousKey(p.first), p.second});
   }
   std::vector<string> out;
+  out.reserve(outputs.size());
   for (const auto& n : outputs) {
     out.push_back(GetRendezvousKey(n));
   }
@ -94,6 +94,7 @@ Status SessionFactory::GetFactory(const SessionOptions& options,
   // TODO(mrry): Consider providing a system-default fallback option
   // in this case.
   std::vector<string> factory_types;
+  factory_types.reserve(candidate_factories.size());
   for (const auto& candidate_factory : candidate_factories) {
     factory_types.push_back(candidate_factory.first);
   }