Merge pull request #12780 from martinwicke/branch_167401527
Branch 167401527
This commit is contained in:
commit
512d3d0868
@ -296,6 +296,7 @@ filegroup(
|
||||
"//tensorflow/contrib/ffmpeg/default:all_files",
|
||||
"//tensorflow/contrib/framework:all_files",
|
||||
"//tensorflow/contrib/fused_conv:all_files",
|
||||
"//tensorflow/contrib/gan:all_files",
|
||||
"//tensorflow/contrib/graph_editor:all_files",
|
||||
"//tensorflow/contrib/grid_rnn:all_files",
|
||||
"//tensorflow/contrib/hooks:all_files",
|
||||
@ -323,6 +324,7 @@ filegroup(
|
||||
"//tensorflow/contrib/nn:all_files",
|
||||
"//tensorflow/contrib/opt:all_files",
|
||||
"//tensorflow/contrib/predictor:all_files",
|
||||
"//tensorflow/contrib/receptive_field:all_files",
|
||||
"//tensorflow/contrib/reduce_slice_ops:all_files",
|
||||
"//tensorflow/contrib/remote_fused_graph/pylib:all_files",
|
||||
"//tensorflow/contrib/resampler:all_files",
|
||||
@ -342,6 +344,7 @@ filegroup(
|
||||
"//tensorflow/contrib/staging:all_files",
|
||||
"//tensorflow/contrib/stat_summarizer:all_files",
|
||||
"//tensorflow/contrib/stateless:all_files",
|
||||
"//tensorflow/contrib/summary:all_files",
|
||||
"//tensorflow/contrib/tensor_forest:all_files",
|
||||
"//tensorflow/contrib/tensor_forest/hybrid:all_files",
|
||||
"//tensorflow/contrib/tensor_forest/kernels/v4:all_files",
|
||||
|
@ -45,8 +45,13 @@ tf_cuda_library(
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api",
|
||||
srcs = ["c_api.cc"],
|
||||
hdrs = ["c_api.h"],
|
||||
srcs = [
|
||||
"c_api.cc",
|
||||
"c_api_function.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"c_api.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
@ -157,6 +162,21 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "c_api_function_test",
|
||||
size = "small",
|
||||
srcs = ["c_api_function_test.cc"],
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_test_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "while_loop_test",
|
||||
size = "small",
|
||||
|
@ -165,22 +165,6 @@ void deallocate_buffer(void* data, size_t len, void* arg) {
|
||||
tensorflow::cpu_allocator()->DeallocateRaw(data);
|
||||
}
|
||||
|
||||
Status MessageToBuffer(const tensorflow::protobuf::Message& in,
|
||||
TF_Buffer* out) {
|
||||
if (out->data != nullptr) {
|
||||
return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
|
||||
}
|
||||
const auto proto_size = in.ByteSizeLong();
|
||||
void* buf = tensorflow::port::Malloc(proto_size);
|
||||
in.SerializeToArray(buf, proto_size);
|
||||
out->data = buf;
|
||||
out->length = proto_size;
|
||||
out->data_deallocator = [](void* data, size_t length) {
|
||||
tensorflow::port::Free(data);
|
||||
};
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TF_Tensor::~TF_Tensor() { buffer->Unref(); }
|
||||
@ -559,6 +543,27 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
||||
dimvec.size(), base, size, DeleteArray, base);
|
||||
}
|
||||
|
||||
Status MessageToBuffer(const tensorflow::protobuf::Message& in,
|
||||
TF_Buffer* out) {
|
||||
if (out->data != nullptr) {
|
||||
return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
|
||||
}
|
||||
const size_t proto_size = in.ByteSizeLong();
|
||||
void* buf = tensorflow::port::Malloc(proto_size);
|
||||
if (buf == nullptr) {
|
||||
return tensorflow::errors::ResourceExhausted(
|
||||
"Failed to allocate memory to serialize message of type '",
|
||||
in.GetTypeName(), "' and size ", proto_size);
|
||||
}
|
||||
in.SerializeToArray(buf, proto_size);
|
||||
out->data = buf;
|
||||
out->length = proto_size;
|
||||
out->data_deallocator = [](void* data, size_t length) {
|
||||
tensorflow::port::Free(data);
|
||||
};
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Helpers for loading a TensorFlow plugin (a .so file).
|
||||
Status LoadLibrary(const char* library_filename, void** result,
|
||||
const void** buf, size_t* len);
|
||||
|
@ -357,6 +357,14 @@ typedef struct TF_Output {
|
||||
int index; // The index of the output within oper.
|
||||
} TF_Output;
|
||||
|
||||
// TF_Function is a grouping of operations with defined inputs and outputs.
|
||||
// Once created and added to graphs, functions can be invoked by creating an
|
||||
// operation whose operation type matches the function name.
|
||||
typedef struct TF_Function TF_Function;
|
||||
|
||||
// Function definition options. TODO(iga): Define and implement
|
||||
typedef struct TF_FunctionOptions TF_FunctionOptions;
|
||||
|
||||
// Sets the shape of the Tensor referenced by `output` in `graph` to
|
||||
// the shape described by `dims` and `num_dims`.
|
||||
//
|
||||
@ -914,6 +922,15 @@ TF_CAPI_EXPORT extern void TF_GraphImportGraphDef(
|
||||
TF_Graph* graph, const TF_Buffer* graph_def,
|
||||
const TF_ImportGraphDefOptions* options, TF_Status* status);
|
||||
|
||||
// Add `function` to graph `g`. Once `function` is added to `g`,
|
||||
// it can be called by creating an operation using the function's name.
|
||||
//
|
||||
// If successful, status is set to OK and function is added to g
|
||||
// Otherwise, status is set to the encountered error and g is unmodified
|
||||
TF_CAPI_EXPORT extern void TF_GraphAddFunction(TF_Graph* g,
|
||||
const TF_Function* function,
|
||||
TF_Status* status);
|
||||
|
||||
// Note: The following function may fail on very large protos in the future.
|
||||
|
||||
TF_CAPI_EXPORT extern void TF_OperationToNodeDef(TF_Operation* oper,
|
||||
@ -1001,6 +1018,105 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny,
|
||||
TF_Output* x, int nx, TF_Output* dx,
|
||||
TF_Status* status, TF_Output* dy);
|
||||
|
||||
// Create a TF_Function from a TF_Graph
|
||||
//
|
||||
// Params:
|
||||
// fn_body - the graph whose operations (or subset of whose operations) will be
|
||||
// converted to TF_Function.
|
||||
// fn_name - the name of the new TF_Function. Should match the operation
|
||||
// name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]* and be distinct
|
||||
// from other operation names (at least those registered in graphs
|
||||
// where this function will be used).
|
||||
// TODO(iga): Allow null in here and have C API come up with
|
||||
// a unique name with high probability (similarly to
|
||||
// _create_hash_str in function.py)
|
||||
// num_opers - `num_opers` contains the number of elements in the `opers` array
|
||||
// or a special value of -1 meaning that no array is given.
|
||||
// The distinction between an empty array of operations and no
|
||||
// array of operations is necessary to distinguish the case of
|
||||
// creating a function with no body (e.g. identity or permutation)
|
||||
// and the case of creating a function whose body contains all
|
||||
// the nodes in the graph (except for the automatic skipping, see
|
||||
// below).
|
||||
// opers - Array of operations to become the body of the function or null.
|
||||
// - If no array is given (`num_opers` = -1), all the
|
||||
// operations in `fn_body` will become part of the function
|
||||
// except operations referenced in `inputs`. These operations
|
||||
// must have a single output (these operations are typically
|
||||
// placeholders created for the sole purpose of representing
|
||||
// an input. We can relax this constraint if there are
|
||||
// compelling use cases).
|
||||
// - If an array is given (`num_opers` >= 0), all operations
|
||||
// in it will become part of the function. In particular, no
|
||||
// automatic skipping of dummy input operations is performed.
|
||||
// ninputs - number of elements in `inputs` array
|
||||
// inputs - array of TF_Outputs that specify the inputs to the function.
|
||||
// If `ninputs` is zero (the function takes no inputs), `inputs`
|
||||
// can be null. The names used for function inputs are normalized
|
||||
// names of the operations (usually placeholders) pointed to by
|
||||
// `inputs`. These operation names should start with a letter.
|
||||
// Normalization will convert all letters to lowercase and
|
||||
// non-alphanumeric characters to '_' to make resulting names match
|
||||
// the "[a-z][a-z0-9_]*" pattern for operation argument names.
|
||||
// `inputs` cannot contain the same tensor twice.
|
||||
// noutputs - number of elements in `outputs` array
|
||||
// outputs - array of TF_Outputs that specify the outputs of the function.
|
||||
// If `noutputs` is zero (the function returns no outputs), `outputs`
|
||||
// can be null. `outputs` can contain the same tensor more than once.
|
||||
// output_names - The names of the function's outputs. `output_names` array
|
||||
// must either have the same length as `outputs`
|
||||
// (i.e. `noutputs`) or be null. In the former case,
|
||||
// the names should match the regular expression for ArgDef
|
||||
// names - "[a-z][a-z0-9_]*". In the latter case,
|
||||
// names for outputs will be generated automatically.
|
||||
// opts - various options for the function, e.g. XLA's inlining control.
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
//
|
||||
// Note that when the same TF_Output is listed as both an input and an output,
|
||||
// the corresponding function's output will equal to this input,
|
||||
// instead of the original node's output.
|
||||
//
|
||||
// Callers must also satisfy the following constraints:
|
||||
// - `inputs` cannot refer to TF_Outputs within a control flow context. For
|
||||
// example, one cannot use the output of "switch" node as input.
|
||||
// - No TF_Output of a function (inside any of `inputs`, `outputs`, `fn_body`)
|
||||
// is allowed to have a reference type. Reference types are not exposed
|
||||
// through C API and are being deprecated.
|
||||
// - Every node in the function's body must have all of its inputs (including
|
||||
// control inputs). In other words, for every node in the body, each input
|
||||
// must be either listed in `inputs` or must come from another node in
|
||||
// the body. In particular, it is an error to have a control edge going from
|
||||
// a node outside of the body into a node in the body. This applies to control
|
||||
// edges going from nodes referenced in `inputs` to nodes in the body when
|
||||
// the former nodes are not in the body (automatically skipped or not
|
||||
// included in explicitly specified body).
|
||||
//
|
||||
// Returns:
|
||||
// On successful, a newly created TF_Function instance. It must be deleted by
|
||||
// calling TF_DeleteFunction.
|
||||
//
|
||||
// On failure, null.
|
||||
//
|
||||
// TODO(iga): Add input_names argument and get output_names working (they are
|
||||
// currently ignored)
|
||||
TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction(
|
||||
const TF_Graph* fn_body, const char* fn_name, int num_opers,
|
||||
const TF_Operation* const* opers, int ninputs, const TF_Output* inputs,
|
||||
int noutputs, const TF_Output* outputs, const char* const* output_names,
|
||||
const TF_FunctionOptions* opts, TF_Status* status);
|
||||
|
||||
// Write out a serialized representation of `func` (as a FunctionDef protocol
|
||||
// message) to `output_func_def` (allocated by TF_NewBuffer()).
|
||||
// `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer()
|
||||
// is called.
|
||||
//
|
||||
// May fail on very large graphs in the future.
|
||||
TF_CAPI_EXPORT extern void TF_FunctionToFunctionDef(TF_Function* func,
|
||||
TF_Buffer* output_func_def,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function*);
|
||||
|
||||
// TODO(josh11b): Register OpDef, available to all operations added
|
||||
// to this graph.
|
||||
|
||||
|
496
tensorflow/c/c_api_function.cc
Normal file
496
tensorflow/c/c_api_function.cc
Normal file
@ -0,0 +1,496 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Class that maintains a one-to-one original node name -> new node name
|
||||
// mapping. We normalize the names used as input and output arguments to match
|
||||
// regexp "[a-z][a-z0-9_]*" specified in definition of ArgDef.name.
|
||||
// Once we rename them, we risk creating a name collision with the other
|
||||
// node names, so if necessary we add a suffix to make
|
||||
// names unique. If we have an input named "A" and a node in the function
|
||||
// body named "a", they will be renamed to "a" and "a_0".
|
||||
class NodeNameMapping {
|
||||
public:
|
||||
NodeNameMapping() = default;
|
||||
|
||||
// Normalize the input/output name and make it unique.
|
||||
string GetIOName(const string& name);
|
||||
|
||||
// Make the node name unique.
|
||||
string Uniquify(const string& name);
|
||||
|
||||
// Look up how a node name was previously normalized/uniquified.
|
||||
// Returns empty if name was never seen.
|
||||
string Lookup(const string& name) const;
|
||||
|
||||
private:
|
||||
string UniquifyHelper(const string& name) const;
|
||||
static string Normalize(string name);
|
||||
|
||||
// The normalized/uniquified names already used as
|
||||
// input names (in signature), output names (in signature), and node names
|
||||
// (in node_def).
|
||||
// This is a superset of values in name_mapping_.
|
||||
std::unordered_set<string> used_names_;
|
||||
// Mapping from original node name from the graph to the normalized
|
||||
// and uniqified version of it.
|
||||
std::unordered_map<string, string> name_mapping_;
|
||||
};
|
||||
|
||||
string NodeNameMapping::Normalize(string name) {
|
||||
// Convert letters to lowercase and non-alphanumeric characters to '_'.
|
||||
if (name.empty()) return "unknown";
|
||||
const int n = name.size();
|
||||
for (int i = 0; i < n; ++i) {
|
||||
char c = name[i];
|
||||
if (isalnum(c)) {
|
||||
if (isupper(c)) {
|
||||
name[i] = tolower(c);
|
||||
}
|
||||
} else {
|
||||
name[i] = '_';
|
||||
}
|
||||
}
|
||||
|
||||
// Find the first letter and start with it.
|
||||
int i = 0;
|
||||
for (; i < n; ++i) {
|
||||
if (isalpha(name[i])) break;
|
||||
}
|
||||
|
||||
// Return "unknown" if none of the name's chars were letters.
|
||||
return i == n ? "unknown" : name.substr(i);
|
||||
}
|
||||
|
||||
string NodeNameMapping::UniquifyHelper(const string& name) const {
|
||||
// If the name hasn't been used yet, use it as-is.
|
||||
if (used_names_.find(name) == used_names_.end()) return name;
|
||||
// Add a suffix to name to make it unique.
|
||||
for (int i = 0;; ++i) {
|
||||
const string candidate = strings::StrCat(name, "_", i);
|
||||
if (used_names_.find(candidate) == used_names_.end()) return candidate;
|
||||
}
|
||||
}
|
||||
|
||||
string NodeNameMapping::GetIOName(const string& name) {
|
||||
const string& input_name = UniquifyHelper(Normalize(name));
|
||||
// Record that we used this name, but don't add it to name_mapping_
|
||||
// since this name is not for a node.
|
||||
used_names_.insert(input_name);
|
||||
return input_name;
|
||||
}
|
||||
|
||||
string NodeNameMapping::Uniquify(const string& name) {
|
||||
const string uniqued = UniquifyHelper(name);
|
||||
name_mapping_[name] = uniqued;
|
||||
used_names_.insert(uniqued);
|
||||
return uniqued;
|
||||
}
|
||||
|
||||
string NodeNameMapping::Lookup(const string& name) const {
|
||||
const auto iter = name_mapping_.find(name);
|
||||
if (iter == name_mapping_.end()) return string();
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
Status ValidateNoRefOutputs(const Node* node) {
|
||||
for (int i = 0; i < node->num_outputs(); ++i) {
|
||||
const DataType& dt = node->output_type(i);
|
||||
if (IsRefType(dt)) {
|
||||
return errors::InvalidArgument("Output ", i, " of node '", node->name(),
|
||||
"' has a reference "
|
||||
"type ",
|
||||
DataTypeString(dt));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FillFunctionBody(
|
||||
const string& fn_name, const NodeNameMapping& node_names,
|
||||
const std::vector<const Node*>& body_nodes,
|
||||
const std::unordered_map<string, string>& tensor_renaming,
|
||||
FunctionDef* fdef) {
|
||||
std::vector<const Edge*> in_edges;
|
||||
std::vector<const Edge*> control_edges;
|
||||
for (const Node* node : body_nodes) {
|
||||
NodeDef* node_def = fdef->add_node_def();
|
||||
// First, copy the node_def as is. We will patch it next.
|
||||
*node_def = node->def();
|
||||
if (!node->assigned_device_name().empty()) {
|
||||
node_def->set_device(node->assigned_device_name());
|
||||
}
|
||||
node_def->set_name(node_names.Lookup(node->name()));
|
||||
|
||||
// Input names must be set based on nested names in tensor_renaming.
|
||||
// Clear the flat input names we got from the original node_def
|
||||
// from the graph.
|
||||
node_def->clear_input();
|
||||
|
||||
// Collect regular and control inputs. Regular inputs are indexed
|
||||
// by the index at which they come into the `node`. Control inputs
|
||||
// don't follow any order.
|
||||
in_edges.clear();
|
||||
in_edges.resize(node->num_inputs(), nullptr);
|
||||
control_edges.clear();
|
||||
for (const Edge* edge : node->in_edges()) {
|
||||
if (edge->src()->IsSource()) continue;
|
||||
if (edge->IsControlEdge()) {
|
||||
control_edges.push_back(edge);
|
||||
} else {
|
||||
in_edges[edge->dst_input()] = edge;
|
||||
}
|
||||
}
|
||||
|
||||
// Add regular inputs.
|
||||
for (size_t i = 0; i < in_edges.size(); ++i) {
|
||||
const Edge* edge = in_edges[i];
|
||||
string original_input_name;
|
||||
if (edge == nullptr) {
|
||||
// A backedge might not appear as a regular Edge, but be only present
|
||||
// in the node_def. Such edges are referred to as requested_inputs().
|
||||
if (i >= node->requested_inputs().size()) {
|
||||
return errors::InvalidArgument(
|
||||
"Graph to be converted to function appears to be malformed. ",
|
||||
"Node ", node->name(), " is missing input edge ", i);
|
||||
}
|
||||
original_input_name =
|
||||
ParseTensorName(node->requested_inputs()[i]).ToString();
|
||||
} else {
|
||||
original_input_name =
|
||||
strings::StrCat(edge->src()->name(), ":", edge->src_output());
|
||||
}
|
||||
|
||||
const auto iter = tensor_renaming.find(original_input_name);
|
||||
if (iter == tensor_renaming.end()) {
|
||||
return errors::InvalidArgument(
|
||||
"Input ", i, ", '", original_input_name, "', of node '",
|
||||
node->name(), "' in function '", fn_name,
|
||||
"' is not available. You might need to include it in inputs "
|
||||
"or include its source node in the body");
|
||||
}
|
||||
node_def->add_input(iter->second);
|
||||
}
|
||||
|
||||
// Add control inputs.
|
||||
for (const Edge* edge : control_edges) {
|
||||
// Add this control input only if the src node is in the body.
|
||||
const string normalized = node_names.Lookup(edge->src()->name());
|
||||
// If we did not find a name for the source of control edge, this
|
||||
// source must be outside of the body. Raise an error.
|
||||
if (normalized.empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"The source of control edge ", edge->DebugString(),
|
||||
" is not in the body. Encountered while creating function '",
|
||||
fn_name, "'");
|
||||
}
|
||||
node_def->add_input(strings::StrCat("^", normalized));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Graph to FunctionDef conversion. This code is closely modeled on the Python
|
||||
// code in third_party/tensorflow/python/framework/function.py.
|
||||
Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
|
||||
const std::vector<const Node*>& body_nodes,
|
||||
const std::vector<OutputTensor>& inputs,
|
||||
const std::vector<OutputTensor>& outputs,
|
||||
const std::vector<string>& output_names,
|
||||
FunctionDef* fdef) {
|
||||
fdef->mutable_signature()->set_name(fn_name);
|
||||
|
||||
// Keep track of names we used and how we normalized them.
|
||||
NodeNameMapping node_names;
|
||||
|
||||
// Mapping from original names of tensors (i.e. "<node_name>:<idx>") to the
|
||||
// name we used in the function:
|
||||
// - For input tensors:
|
||||
// {flat_tensor_name -> normalized_name_of_src_node}
|
||||
// e.g. {In:3 -> in}
|
||||
// - For tensors produced by nodes in function's body:
|
||||
// {flat_tensor_name -> nested_tensor_name}
|
||||
// e.g. {Add:3 -> add_0:z:1}
|
||||
std::unordered_map<string, string> tensor_renaming;
|
||||
|
||||
// Fill inputs in function's signature.
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const Node* node = inputs[i].node;
|
||||
int idx = inputs[i].index;
|
||||
OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg();
|
||||
argdef->set_type(node->output_type(idx));
|
||||
const string& input_name = node_names.GetIOName(node->name());
|
||||
argdef->set_name(input_name);
|
||||
tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
|
||||
}
|
||||
|
||||
// Fill outputs in function's signature.
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
const Node* node = outputs[i].node;
|
||||
int idx = outputs[i].index;
|
||||
OpDef::ArgDef* argdef = fdef->mutable_signature()->add_output_arg();
|
||||
argdef->set_type(node->output_type(idx));
|
||||
argdef->set_name(node_names.GetIOName(node->name()));
|
||||
}
|
||||
|
||||
// Populate tensor_renaming and node_names.
|
||||
// Generate the new output names for every node in the function.
|
||||
// The NodeDefs in FunctionDefs use a different naming scheme for
|
||||
// their inputs than the NodeDefs in a graph (see the comment for
|
||||
// FunctionDef.node_def in function.proto). We do the
|
||||
// graph tensor name -> function tensor name conversion for every
|
||||
// possible input (i.e. every node's outputs) and store the result
|
||||
// in tensor_renaming.
|
||||
for (const Node* node : body_nodes) {
|
||||
// Make sure node_name does not collide with an input or output name.
|
||||
const string& node_name = node_names.Uniquify(node->name());
|
||||
// For each output_arg in the op_def, the output_ranges
|
||||
// map will have [start, end] range of indices that this arg produces
|
||||
// among all the output tensors of this op.
|
||||
NameRangeMap output_ranges;
|
||||
TF_RETURN_IF_ERROR(
|
||||
NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges));
|
||||
for (const auto& output : output_ranges) {
|
||||
const string& output_name = output.first;
|
||||
int index_start = output.second.first;
|
||||
int index_end = output.second.second;
|
||||
for (int i = index_start; i < index_end; ++i) {
|
||||
const string& original_name = strings::StrCat(node->name(), ":", i);
|
||||
const string& new_name =
|
||||
strings::StrCat(node_name, ":", output_name, ":", i - index_start);
|
||||
// Record the mapping if this tensor is not already mapped.
|
||||
// Tensor can be already mapped if it is used as an input.
|
||||
if (tensor_renaming.find(original_name) == tensor_renaming.end()) {
|
||||
tensor_renaming[original_name] = new_name;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
FillFunctionBody(fn_name, node_names, body_nodes, tensor_renaming, fdef));
|
||||
|
||||
// Remap return values.
|
||||
for (int r = 0; r < fdef->signature().output_arg_size(); ++r) {
|
||||
const string& ret_name = fdef->signature().output_arg(r).name();
|
||||
|
||||
// We convert this flat tensor name to the nested value
|
||||
// (e.g. `add:z:1`) that we stored in tensor_renaming.
|
||||
const string& return_value =
|
||||
strings::StrCat(outputs[r].node->name(), ":", outputs[r].index);
|
||||
const auto iter = tensor_renaming.find(return_value);
|
||||
if (iter == tensor_renaming.end()) {
|
||||
return errors::InvalidArgument(
|
||||
"TF_Output ", return_value, " is neither in the function body ",
|
||||
"nor among function inputs. Encountered while creating function '",
|
||||
fn_name, "'");
|
||||
}
|
||||
(*fdef->mutable_ret())[ret_name] = iter->second;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Converts `ninputs` and `inputs` into `inputs_tensors` and `input_nodes` and
|
||||
// does various checks while doing so. `input_nodes` will contain the same
|
||||
// information as input_tensors just in a different structure to make
|
||||
// following processing easier. TODO(iga): Simplify this nested structure.
|
||||
Status ProcessInputs(
|
||||
const TF_Graph* fn_body, const char* fn_name, int ninputs,
|
||||
const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
|
||||
std::unordered_map<const Node*, std::vector<int>>* input_nodes)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
||||
input_tensors->reserve(ninputs);
|
||||
for (int i = 0; i < ninputs; ++i) {
|
||||
const Node& node = inputs[i].oper->node;
|
||||
int idx = inputs[i].index;
|
||||
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
fn_body->graph.IsValidOutputTensor(&node, idx),
|
||||
"Encountered while processing input ", i, " into function '", fn_name,
|
||||
"'");
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(&node),
|
||||
"Encountered while processing input ", i,
|
||||
" into function '", fn_name, "'");
|
||||
|
||||
input_tensors->emplace_back(&node, idx);
|
||||
|
||||
const auto& iter = input_nodes->find(&node);
|
||||
if (iter == input_nodes->end()) {
|
||||
input_nodes->insert({&node, {idx}});
|
||||
} else {
|
||||
auto& indices = iter->second;
|
||||
if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
|
||||
return errors::InvalidArgument(
|
||||
"TF_Output ", node.name(), ":", idx,
|
||||
" appears more than once in the input list");
|
||||
}
|
||||
indices.push_back(idx);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Converts `noutputs` and `outputs` into `outputs_tensors` and does various
|
||||
// checks while doing so.
|
||||
Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
|
||||
int noutputs, const TF_Output* outputs,
|
||||
std::vector<OutputTensor>* output_tensors)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
||||
output_tensors->reserve(noutputs);
|
||||
for (int i = 0; i < noutputs; ++i) {
|
||||
const Node& node = outputs[i].oper->node;
|
||||
int idx = outputs[i].index;
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
fn_body->graph.IsValidOutputTensor(&node, idx),
|
||||
"Encountered while processing output ", i, " from function '", fn_name,
|
||||
"'");
|
||||
output_tensors->emplace_back(&node, idx);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Populates `body_nodes` with the nodes that will become function's body.
|
||||
// Performs various checks.
|
||||
Status ComputeBodyNodes(
|
||||
const TF_Graph* fn_body, const char* fn_name, int num_opers,
|
||||
const TF_Operation* const* opers,
|
||||
const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
|
||||
std::vector<const Node*>* body_nodes)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
||||
if (num_opers == -1) {
|
||||
for (const Node* node : fn_body->graph.op_nodes()) {
|
||||
const auto& iter = input_nodes.find(node);
|
||||
if (iter == input_nodes.end()) {
|
||||
// This node is not referenced in inputs. Add it to the body.
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(node),
|
||||
"Encountered while creating function '",
|
||||
fn_name, "'");
|
||||
body_nodes->push_back(node);
|
||||
} else {
|
||||
// This node is referenced in inputs. Currently, we place an
|
||||
// artificial restriction and require that when num_opers=-1, such
|
||||
// nodes must have a single output.
|
||||
if (node->num_outputs() != 1) {
|
||||
return errors::InvalidArgument(
|
||||
"When `num_opers` is set to -1, nodes referenced in `inputs` "
|
||||
"must have a single output. Node ",
|
||||
node->name(), " has ", node->num_outputs(),
|
||||
" outputs. Encountered while creating function '", fn_name, "'");
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
body_nodes->reserve(num_opers);
|
||||
for (int i = 0; i < num_opers; ++i) {
|
||||
const Node* node = &opers[i]->node;
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(node),
|
||||
"Encountered while creating function '",
|
||||
fn_name, "'");
|
||||
body_nodes->push_back(node);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
using tensorflow::Node;
|
||||
using tensorflow::string;
|
||||
|
||||
TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
|
||||
int num_opers, const TF_Operation* const* opers,
|
||||
int ninputs, const TF_Output* inputs,
|
||||
int noutputs, const TF_Output* outputs,
|
||||
const char* const* output_names,
|
||||
const TF_FunctionOptions* opts,
|
||||
TF_Status* status) {
|
||||
tensorflow::mutex_lock l(*const_cast<tensorflow::mutex*>(&fn_body->mu));
|
||||
|
||||
// Process inputs.
|
||||
std::vector<tensorflow::OutputTensor> input_tensors;
|
||||
std::unordered_map<const Node*, std::vector<int>> input_nodes;
|
||||
status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs,
|
||||
&input_tensors, &input_nodes);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
// Process outputs.
|
||||
std::vector<tensorflow::OutputTensor> output_tensors;
|
||||
status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs,
|
||||
outputs, &output_tensors);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
// Process output names.
|
||||
std::vector<string> output_names_vec;
|
||||
if (output_names) {
|
||||
output_names_vec.reserve(noutputs);
|
||||
for (int i = 0; i < noutputs; ++i) {
|
||||
output_names_vec.push_back(string(output_names[i]));
|
||||
}
|
||||
}
|
||||
|
||||
// Compute body nodes.
|
||||
std::vector<const Node*> body_nodes;
|
||||
status->status = tensorflow::ComputeBodyNodes(
|
||||
fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
// Do the actual function creation.
|
||||
TF_Function* tf_function = new TF_Function();
|
||||
status->status = tensorflow::GraphToFunctionDef(
|
||||
fn_body->graph, fn_name, body_nodes, input_tensors, output_tensors,
|
||||
output_names_vec, tf_function->fdef_lib.add_function());
|
||||
if (!status->status.ok()) {
|
||||
TF_DeleteFunction(tf_function);
|
||||
return nullptr;
|
||||
}
|
||||
return tf_function;
|
||||
}
|
||||
|
||||
void TF_GraphAddFunction(TF_Graph* g, const TF_Function* function,
|
||||
TF_Status* status) {
|
||||
tensorflow::mutex_lock l(g->mu);
|
||||
|
||||
// At the moment, we have only one function and no gradients in fdef_lib.
|
||||
// This makes the following operation atomic.
|
||||
// TODO(iga): Add an atomic version of AddFunctionLibrary when we support
|
||||
// gradients
|
||||
status->status = g->graph.AddFunctionLibrary(function->fdef_lib);
|
||||
}
|
||||
|
||||
void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def,
|
||||
TF_Status* status) {
|
||||
DCHECK_EQ(1, func->fdef_lib.function_size());
|
||||
status->status = MessageToBuffer(func->fdef_lib.function(0), output_func_def);
|
||||
}
|
||||
|
||||
void TF_DeleteFunction(TF_Function* function) { delete function; }
|
1039
tensorflow/c/c_api_function_test.cc
Normal file
1039
tensorflow/c/c_api_function_test.cc
Normal file
File diff suppressed because it is too large
Load Diff
@ -130,6 +130,11 @@ struct TF_DeviceList {
|
||||
std::vector<tensorflow::DeviceAttributes> response;
|
||||
};
|
||||
|
||||
struct TF_Function {
|
||||
// Currently contains a single function and no gradients
|
||||
tensorflow::FunctionDefLibrary fdef_lib;
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TensorCApi {
|
||||
@ -142,6 +147,9 @@ class TensorCApi {
|
||||
};
|
||||
|
||||
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
|
||||
|
||||
Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out);
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_C_API_INTERNAL_H_
|
||||
|
@ -829,7 +829,7 @@ TEST(CAPI, ShapeInferenceError) {
|
||||
TF_Operation* vec3 = Const(vec3_tensor.get(), graph, status, "vec3");
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TF_Operation* add = Add(vec2, vec3, graph, status);
|
||||
TF_Operation* add = AddNoCheck(vec2, vec3, graph, status);
|
||||
ASSERT_NE(TF_OK, TF_GetCode(status));
|
||||
ASSERT_TRUE(add == nullptr);
|
||||
|
||||
|
@ -15,7 +15,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/c_test_util.h"
|
||||
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
using tensorflow::GraphDef;
|
||||
@ -36,6 +38,23 @@ TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
|
||||
return t;
|
||||
}
|
||||
|
||||
TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims,
|
||||
const int32_t* values) {
|
||||
int64_t num_values = 1;
|
||||
for (int i = 0; i < num_dims; ++i) {
|
||||
num_values *= dims[i];
|
||||
}
|
||||
TF_Tensor* t =
|
||||
TF_AllocateTensor(TF_INT32, dims, num_dims, sizeof(int32_t) * num_values);
|
||||
memcpy(TF_TensorData(t), values, sizeof(int32_t) * num_values);
|
||||
return t;
|
||||
}
|
||||
|
||||
TF_Tensor* Int32Tensor(const std::vector<int32_t>& values) {
|
||||
int64_t dims = values.size();
|
||||
return Int32Tensor(&dims, 1, values.data());
|
||||
}
|
||||
|
||||
TF_Tensor* Int32Tensor(int32_t v) {
|
||||
const int num_bytes = sizeof(int32_t);
|
||||
int32_t* values = new int32_t[1];
|
||||
@ -44,19 +63,40 @@ TF_Tensor* Int32Tensor(int32_t v) {
|
||||
&Int32Deallocator, nullptr);
|
||||
}
|
||||
|
||||
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) {
|
||||
// All the *Helper methods are used as a workaround for the restrictions that
|
||||
// one cannot call ASSERT_* methods in non-void-returning functions (when
|
||||
// exceptions are disabled during compilation)
|
||||
void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
|
||||
TF_Operation** op) {
|
||||
TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
|
||||
TF_SetAttrType(desc, "dtype", TF_INT32);
|
||||
return TF_FinishOperation(desc, s);
|
||||
*op = TF_FinishOperation(desc, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
ASSERT_NE(*op, nullptr);
|
||||
}
|
||||
|
||||
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) {
|
||||
TF_Operation* op;
|
||||
PlaceholderHelper(graph, s, name, &op);
|
||||
return op;
|
||||
}
|
||||
|
||||
void ConstHelper(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name,
|
||||
TF_Operation** op) {
|
||||
TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
|
||||
TF_SetAttrTensor(desc, "value", t, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_SetAttrType(desc, "dtype", TF_TensorType(t));
|
||||
*op = TF_FinishOperation(desc, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
ASSERT_NE(*op, nullptr);
|
||||
}
|
||||
|
||||
TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
|
||||
const char* name) {
|
||||
TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
|
||||
TF_SetAttrTensor(desc, "value", t, s);
|
||||
if (TF_GetCode(s) != TF_OK) return nullptr;
|
||||
TF_SetAttrType(desc, "dtype", TF_TensorType(t));
|
||||
return TF_FinishOperation(desc, s);
|
||||
TF_Operation* op;
|
||||
ConstHelper(t, graph, s, name, &op);
|
||||
return op;
|
||||
}
|
||||
|
||||
TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
|
||||
@ -65,11 +105,39 @@ TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
|
||||
return Const(tensor.get(), graph, s, name);
|
||||
}
|
||||
|
||||
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
|
||||
TF_Status* s, const char* name) {
|
||||
void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s,
|
||||
const char* name, TF_Operation** op, bool check) {
|
||||
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
|
||||
TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
|
||||
TF_AddInputList(desc, add_inputs, 2);
|
||||
*op = TF_FinishOperation(desc, s);
|
||||
if (check) {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
ASSERT_NE(*op, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
|
||||
TF_Status* s, const char* name) {
|
||||
TF_Operation* op;
|
||||
AddHelper(l, r, graph, s, name, &op, true);
|
||||
return op;
|
||||
}
|
||||
|
||||
TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
|
||||
TF_Status* s, const char* name) {
|
||||
TF_Operation* op;
|
||||
AddHelper(l, r, graph, s, name, &op, false);
|
||||
return op;
|
||||
}
|
||||
|
||||
TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r,
|
||||
TF_Graph* graph, TF_Operation* ctrl_op,
|
||||
TF_Status* s, const char* name) {
|
||||
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
|
||||
TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
|
||||
TF_AddInputList(desc, add_inputs, 2);
|
||||
TF_AddControlInput(desc, ctrl_op);
|
||||
return TF_FinishOperation(desc, s);
|
||||
}
|
||||
|
||||
@ -81,11 +149,20 @@ TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
|
||||
return TF_FinishOperation(desc, s);
|
||||
}
|
||||
|
||||
TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) {
|
||||
void NegHelper(TF_Operation* n, TF_Graph* graph, TF_Status* s,
|
||||
TF_Operation** op) {
|
||||
TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", "neg");
|
||||
TF_Output neg_input = {n, 0};
|
||||
TF_AddInput(desc, neg_input);
|
||||
return TF_FinishOperation(desc, s);
|
||||
*op = TF_FinishOperation(desc, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
ASSERT_NE(*op, nullptr);
|
||||
}
|
||||
|
||||
TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) {
|
||||
TF_Operation* op;
|
||||
NegHelper(n, graph, s, &op);
|
||||
return op;
|
||||
}
|
||||
|
||||
TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
|
||||
@ -96,6 +173,32 @@ TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
|
||||
return TF_FinishOperation(desc, s);
|
||||
}
|
||||
|
||||
void Split3Helper(TF_Operation* input, TF_Graph* graph, TF_Status* s,
|
||||
const char* name, TF_Operation** op) {
|
||||
TF_Operation* zero = ScalarConst(
|
||||
0, graph, s, ::tensorflow::strings::StrCat(name, "_const0").c_str());
|
||||
TF_OperationDescription* desc = TF_NewOperation(graph, "Split", name);
|
||||
TF_AddInput(desc, {zero, 0});
|
||||
TF_AddInput(desc, {input, 0});
|
||||
TF_SetAttrInt(desc, "num_split", 3);
|
||||
TF_SetAttrType(desc, "T", TF_INT32);
|
||||
// Set device to CPU since there is no version of split for int32 on GPU
|
||||
// TODO(iga): Convert all these helpers and tests to use floats because
|
||||
// they are usually available on GPUs. After doing this, remove TF_SetDevice
|
||||
// call in c_api_function_test.cc
|
||||
TF_SetDevice(desc, "/cpu:0");
|
||||
*op = TF_FinishOperation(desc, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
ASSERT_NE(*op, nullptr);
|
||||
}
|
||||
|
||||
TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s,
|
||||
const char* name) {
|
||||
TF_Operation* op;
|
||||
Split3Helper(input, graph, s, name, &op);
|
||||
return op;
|
||||
}
|
||||
|
||||
bool IsPlaceholder(const tensorflow::NodeDef& node_def) {
|
||||
if (node_def.op() != "Placeholder" || node_def.name() != "feed") {
|
||||
return false;
|
||||
@ -196,6 +299,18 @@ bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def) {
|
||||
TF_Status* s = TF_NewStatus();
|
||||
TF_Buffer* buffer = TF_NewBuffer();
|
||||
TF_FunctionToFunctionDef(func, buffer, s);
|
||||
bool ret = TF_GetCode(s) == TF_OK;
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
if (ret) ret = func_def->ParseFromArray(buffer->data, buffer->length);
|
||||
TF_DeleteBuffer(buffer);
|
||||
TF_DeleteStatus(s);
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool GetAttrValue(TF_Operation* oper, const char* attr_name,
|
||||
tensorflow::AttrValue* attr_value, TF_Status* s) {
|
||||
TF_Buffer* buffer = TF_NewBuffer();
|
||||
|
@ -33,6 +33,13 @@ typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)>
|
||||
// Create a tensor with values of type TF_INT8 provided by `values`.
|
||||
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values);
|
||||
|
||||
// Create a tensor with values of type TF_INT32 provided by `values`.
|
||||
TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims,
|
||||
const int32_t* values);
|
||||
|
||||
// Create 1 dimensional tensor with values from `values`
|
||||
TF_Tensor* Int32Tensor(const std::vector<int32_t>& values);
|
||||
|
||||
TF_Tensor* Int32Tensor(int32_t v);
|
||||
|
||||
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
|
||||
@ -47,6 +54,13 @@ TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
|
||||
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
|
||||
TF_Status* s, const char* name = "add");
|
||||
|
||||
TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
|
||||
TF_Status* s, const char* name = "add");
|
||||
|
||||
TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r,
|
||||
TF_Graph* graph, TF_Operation* ctrl_op,
|
||||
TF_Status* s, const char* name = "add");
|
||||
|
||||
TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
|
||||
const char* name = "add");
|
||||
|
||||
@ -54,6 +68,10 @@ TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s);
|
||||
|
||||
TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s);
|
||||
|
||||
// Split `input` along the first dimention into 3 tensors
|
||||
TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s,
|
||||
const char* name = "split3");
|
||||
|
||||
bool IsPlaceholder(const tensorflow::NodeDef& node_def);
|
||||
|
||||
bool IsScalarConst(const tensorflow::NodeDef& node_def, int v);
|
||||
@ -66,6 +84,8 @@ bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def);
|
||||
|
||||
bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def);
|
||||
|
||||
bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def);
|
||||
|
||||
bool GetAttrValue(TF_Operation* oper, const char* attr_name,
|
||||
tensorflow::AttrValue* attr_value, TF_Status* s);
|
||||
|
||||
|
@ -687,6 +687,72 @@ Status MeanGrad(const Scope& scope, const Operation& op,
|
||||
}
|
||||
REGISTER_GRADIENT_OP("Mean", MeanGrad);
|
||||
|
||||
Status MinOrMaxGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
// The partial derivative for any input along a "reduced" dimension
|
||||
// is 1 when it is the min (or max) and 0 everywhere else. So the
|
||||
// gradient calculation is identical for both operators.
|
||||
//
|
||||
// There's a special case for propagating gradients when there are
|
||||
// multiple minima (or maxima) - we choose to divide the gradient
|
||||
// equally among all matching inputs.
|
||||
//
|
||||
// Please note this comment
|
||||
// https://github.com/tensorflow/tensorflow/issues/4886#issuecomment-256836063
|
||||
// for details.
|
||||
|
||||
// Running example:
|
||||
// input: [[5, 5, 5],
|
||||
// [1, 2, -3]]
|
||||
// reduction_indices: [1]
|
||||
auto input = op.input(0);
|
||||
auto reduction_indices = op.input(1);
|
||||
|
||||
// [2, 3]
|
||||
auto input_shape = Shape(scope, input);
|
||||
|
||||
// [2, 1]
|
||||
auto output_shape_kept_dims =
|
||||
ReducedShapeHelper(scope, input_shape, reduction_indices);
|
||||
|
||||
// for op=min (say)
|
||||
// output = [5, -3]
|
||||
// y = [[5],
|
||||
// [-3]]
|
||||
auto y = Reshape(scope, op.output(0), output_shape_kept_dims);
|
||||
|
||||
// reshape([g1, g2], [2, 1]) = [[g1],
|
||||
// [g2]]
|
||||
auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);
|
||||
|
||||
// indicators = equal(y, input)
|
||||
// = equal([[5], [[5, 5, 5],
|
||||
// [-3]], [1, 2, -3]])
|
||||
// = [[1, 1, 1],
|
||||
// [0, 0, 1]]
|
||||
auto indicators = Cast(scope, Equal(scope, y, input), grad_inputs[0].type());
|
||||
|
||||
// [[3],
|
||||
// [1]]
|
||||
auto num_selected = Reshape(scope, Sum(scope, indicators, reduction_indices),
|
||||
output_shape_kept_dims);
|
||||
|
||||
// [[1/3, 1/3, 1/3],
|
||||
// [0, 0, 1]]
|
||||
auto scale = Div(scope, indicators, num_selected);
|
||||
|
||||
// [[g1/3, g1/3, g1/3],
|
||||
// [0, 0, g2]]
|
||||
grad_outputs->push_back(Mul(scope, scale, grad));
|
||||
|
||||
// Stop propagation along reduction_indices
|
||||
grad_outputs->push_back(NoGradient());
|
||||
return scope.status();
|
||||
}
|
||||
REGISTER_GRADIENT_OP("Min", MinOrMaxGrad);
|
||||
REGISTER_GRADIENT_OP("Max", MinOrMaxGrad);
|
||||
|
||||
// MatMulGrad helper function used to compute two MatMul operations
|
||||
// based on input matrix transposition combinations.
|
||||
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
|
||||
|
@ -955,6 +955,55 @@ TEST_F(NaryGradTest, Mean) {
|
||||
RunTest({x}, {x_shape}, {y}, {y_shape});
|
||||
}
|
||||
|
||||
TEST_F(NaryGradTest, Min) {
|
||||
TensorShape x_shape({2, 3});
|
||||
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
|
||||
auto y = Min(scope_, x, {-1});
|
||||
// y's shape is the result of reducing x along axes -1 (= 1)
|
||||
TensorShape y_shape({2});
|
||||
Tensor x_init_value =
|
||||
test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
|
||||
RunTest(x, x_init_value, y, y_shape);
|
||||
}
|
||||
|
||||
TEST_F(NaryGradTest, Max) {
|
||||
TensorShape x_shape({2, 3});
|
||||
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
|
||||
auto y = Max(scope_, x, {-1});
|
||||
// y's shape is the result of reducing x along axes -1 (= 1)
|
||||
TensorShape y_shape({2});
|
||||
Tensor x_init_value =
|
||||
test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
|
||||
RunTest(x, x_init_value, y, y_shape);
|
||||
}
|
||||
|
||||
TEST_F(NaryGradTest, MinMulti) {
|
||||
// Test gradient when there are multiple minima.
|
||||
// Note that we cannot directly use a test Tensor with multiple
|
||||
// minima, as the numeric estimator will calculate incorrect
|
||||
// gradients when perturbing each entry in the Tensor (which then
|
||||
// changes how many minima exist.)
|
||||
// Instead, we use a single input that broadcast-multiplies a larger
|
||||
// tensor with equal values, and apply reduce_min to the multiplied
|
||||
// result.
|
||||
TensorShape x_shape({1});
|
||||
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
|
||||
auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
|
||||
auto y = Min(scope_, all_same, {0});
|
||||
// y is a [3] shaped tensor reduced along dimension 0, so it is [1] shaped
|
||||
TensorShape y_shape({1});
|
||||
RunTest({x}, {x_shape}, {y}, {y_shape});
|
||||
}
|
||||
|
||||
TEST_F(NaryGradTest, MaxMulti) {
|
||||
TensorShape x_shape({1});
|
||||
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
|
||||
auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
|
||||
auto y = Max(scope_, all_same, {0});
|
||||
TensorShape y_shape({1});
|
||||
RunTest({x}, {x_shape}, {y}, {y_shape});
|
||||
}
|
||||
|
||||
TEST_F(NaryGradTest, AddN) {
|
||||
TensorShape shape({3, 2, 5});
|
||||
std::vector<Output> xs;
|
||||
|
@ -52,6 +52,12 @@ class BinaryOpsTest(XLATestCase):
|
||||
|
||||
def testFloatOps(self):
|
||||
for dtype in self.float_types:
|
||||
self._testBinary(
|
||||
lambda x, y: math_ops.approximate_equal(x, y, tolerance=0.0001),
|
||||
np.array([[[[-1, 2.00009999], [-3, 4.01]]]], dtype=dtype),
|
||||
np.array([[[[-1.001, 2], [-3.00009, 4]]]], dtype=dtype),
|
||||
expected=np.array([[[[False, True], [True, False]]]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
gen_math_ops._real_div,
|
||||
np.array([3, 3, -1.5, -8, 44], dtype=dtype),
|
||||
@ -82,6 +88,12 @@ class BinaryOpsTest(XLATestCase):
|
||||
dtype(4),
|
||||
expected=np.array([[16], [81]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
gen_math_ops._reciprocal_grad,
|
||||
np.array([4, -3, -2, 1], dtype=dtype),
|
||||
np.array([5, -6, 7, -8], dtype=dtype),
|
||||
expected=np.array([-80, 54, -28, 8], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
gen_math_ops._sigmoid_grad,
|
||||
np.array([4, 3, 2, 1], dtype=dtype),
|
||||
@ -107,6 +119,13 @@ class BinaryOpsTest(XLATestCase):
|
||||
expected=np.array(
|
||||
[3.97322869, 2.99258232, 1.99817801, 0.99966466], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
gen_nn_ops._softsign_grad,
|
||||
np.array([4, 3, 2, 1], dtype=dtype),
|
||||
np.array([5, 6, 7, 8], dtype=dtype),
|
||||
expected=np.array(
|
||||
[0.11111111, 0.06122449, 0.03125, 0.01234568], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
gen_math_ops._tanh_grad,
|
||||
np.array([4, 3, 2, 1], dtype=dtype),
|
||||
|
@ -888,6 +888,16 @@ TEST_F(OpTest, Any) {
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, ApproximateEqual) {
|
||||
Repeatedly([this]() {
|
||||
auto dims = RandomDims();
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ApproximateEqual")
|
||||
.RandomInput(DT_FLOAT, dims)
|
||||
.RandomInput(DT_FLOAT, dims)
|
||||
.Attr("T", DT_FLOAT));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, Asinh) {
|
||||
Repeatedly([this]() {
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
@ -1662,11 +1672,9 @@ TEST_F(OpTest, GreaterEqual) {
|
||||
|
||||
TEST_F(OpTest, L2Loss) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
// TODO(b/31644876): scalars currently crash.
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("L2Loss")
|
||||
.RandomInput(type, RandomDims(1))
|
||||
.Attr("T", type));
|
||||
DataType type = DT_FLOAT;
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("L2Loss").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
@ -2165,6 +2173,15 @@ TEST_F(OpTest, Reciprocal) {
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, ReciprocalGrad) {
|
||||
Repeatedly([this]() {
|
||||
std::vector<int64> dims = RandomDims();
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReciprocalGrad")
|
||||
.RandomInput(DT_FLOAT, dims)
|
||||
.RandomInput(DT_FLOAT, dims)
|
||||
.Attr("T", DT_FLOAT));
|
||||
});
|
||||
}
|
||||
TEST_F(OpTest, Relu) {
|
||||
Repeatedly([this]() {
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
@ -2250,6 +2267,13 @@ TEST_F(OpTest, ReverseV2) {
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, Rint) {
|
||||
Repeatedly([this]() {
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Rint").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, Round) {
|
||||
Repeatedly([this]() {
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
@ -2402,6 +2426,23 @@ TEST_F(OpTest, SoftplusGrad) {
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, Softsign) {
|
||||
Repeatedly([this]() {
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Softsign").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, SoftsignGrad) {
|
||||
Repeatedly([this]() {
|
||||
std::vector<int64> dims = RandomDims();
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SoftsignGrad")
|
||||
.RandomInput(DT_FLOAT, dims)
|
||||
.RandomInput(DT_FLOAT, dims)
|
||||
.Attr("T", DT_FLOAT));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, SpaceToBatch) {
|
||||
Repeatedly([this]() {
|
||||
std::vector<int64> block_dims = RandomDims(4, 4, 0, 5);
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
@ -161,12 +163,17 @@ class UnaryOpsTest(XLATestCase):
|
||||
np.array([[-1.7, 1.2]], dtype=dtype),
|
||||
expected=np.array([[-2, 1]], dtype=dtype))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.is_finite,
|
||||
np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
|
||||
dtype=dtype),
|
||||
expected=np.array([[0, 1, 1, 1, 1, 1, 1, 0, 0]], dtype=np.bool))
|
||||
|
||||
# Tests for tf.nn ops.
|
||||
self._assertOpOutputMatchesExpected(
|
||||
nn_ops.l2_loss, np.array([[[]]], dtype=dtype), expected=dtype(0))
|
||||
|
||||
# TODO(b/31644876): enable this test case when fixed.
|
||||
# self._assertOpOutputMatchesExpected(tf.nn.l2_loss, dtype(4), dtype(10))
|
||||
self._assertOpOutputMatchesExpected(nn_ops.l2_loss, dtype(4), dtype(8))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
nn_ops.l2_loss, np.array([[-2, 4]], dtype=dtype), expected=dtype(10))
|
||||
@ -198,6 +205,12 @@ class UnaryOpsTest(XLATestCase):
|
||||
np.array([[1e-14, 1e-15, 0.6]], dtype=dtype),
|
||||
expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.rint,
|
||||
np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5],
|
||||
[0.5, 1.5, 2.5, 3.5]], dtype=dtype),
|
||||
expected=np.array([[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]],
|
||||
dtype=dtype))
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.round,
|
||||
np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5],
|
||||
@ -301,6 +314,12 @@ class UnaryOpsTest(XLATestCase):
|
||||
np.array([[-2, 0, 8]], dtype=dtype),
|
||||
expected=np.array([[0.126928, 0.6931472, 8.0003354]], dtype=dtype))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
nn_ops.softsign,
|
||||
np.array([[-2, -1, 0, 1, 2]], dtype=dtype),
|
||||
expected=np.array([[-0.66666669, -0.5, 0, 0.5, 0.66666669]],
|
||||
dtype=dtype))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.is_finite,
|
||||
np.array(
|
||||
@ -335,6 +354,23 @@ class UnaryOpsTest(XLATestCase):
|
||||
np.array([[4, 3], [2, 1]], dtype=dtype),
|
||||
expected=np.array([[1, 1], [1, 1]], dtype=dtype))
|
||||
|
||||
# TODO(phawkins): these tests fail unless fastmath optimizations
|
||||
# are disabled. Use more robust IsInf/IsNaN detection and enable these
|
||||
# tests.
|
||||
@unittest.skip("test case fails in fast-math mode")
|
||||
def testIsInfAndIsNan(self):
|
||||
for dtype in self.float_types:
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.is_inf,
|
||||
np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
|
||||
dtype=dtype),
|
||||
expected=np.array([[1, 0, 0, 0, 0, 0, 0, 1, 0]], dtype=np.bool))
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.is_nan,
|
||||
np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
|
||||
dtype=dtype),
|
||||
expected=np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.bool))
|
||||
|
||||
def testLogicalOps(self):
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.logical_not,
|
||||
|
@ -31,7 +31,6 @@ tf_kernel_library(
|
||||
"function_ops.cc",
|
||||
"gather_op.cc",
|
||||
"identity_op.cc",
|
||||
"is_finite_op.cc",
|
||||
"l2loss_op.cc",
|
||||
"lrn_ops.cc",
|
||||
"matmul_op.cc",
|
||||
|
@ -102,6 +102,7 @@ XLA_MAKE_BINARY(Mod, b->Rem(lhs, rhs, extend_dimensions));
|
||||
XLA_MAKE_BINARY(Maximum, b->Max(lhs, rhs, extend_dimensions));
|
||||
XLA_MAKE_BINARY(Minimum, b->Min(lhs, rhs, extend_dimensions));
|
||||
XLA_MAKE_BINARY(RealDiv, b->Div(lhs, rhs, extend_dimensions));
|
||||
XLA_MAKE_BINARY(ReciprocalGrad, b->Neg(b->Mul(rhs, b->Mul(lhs, lhs))));
|
||||
XLA_MAKE_BINARY(
|
||||
RsqrtGrad,
|
||||
b->Mul(b->Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)),
|
||||
@ -140,6 +141,11 @@ XLA_MAKE_BINARY(SoftplusGrad,
|
||||
b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)),
|
||||
XlaHelpers::One(b, input_type(1)))));
|
||||
|
||||
// softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2
|
||||
XLA_MAKE_BINARY(SoftsignGrad,
|
||||
b->Div(lhs, Square(b, b->Add(XlaHelpers::One(b, input_type(0)),
|
||||
b->Abs(rhs)))));
|
||||
|
||||
XLA_MAKE_BINARY(TanhGrad, b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)),
|
||||
b->Mul(lhs, lhs))));
|
||||
|
||||
@ -147,5 +153,24 @@ XLA_MAKE_BINARY(Pow, b->Pow(lhs, rhs, extend_dimensions));
|
||||
|
||||
#undef XLA_MAKE_BINARY
|
||||
|
||||
class ApproximateEqualOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit ApproximateEqualOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("tolerance", &tolerance_));
|
||||
}
|
||||
|
||||
// Computes the max of the scalar input x and 0.
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::ComputationBuilder* b = ctx->builder();
|
||||
auto result = b->Lt(b->Abs(b->Sub(ctx->Input(0), ctx->Input(1))),
|
||||
XlaHelpers::FloatLiteral(b, input_type(0), tolerance_));
|
||||
ctx->SetOutput(0, result);
|
||||
}
|
||||
|
||||
private:
|
||||
float tolerance_;
|
||||
};
|
||||
REGISTER_XLA_OP(Name("ApproximateEqual"), ApproximateEqualOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -1,43 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/bcast.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class IsFiniteOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit IsFiniteOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::ComputationDataHandle input = ctx->Input(0);
|
||||
ctx->SetOutput(0, ctx->builder()->IsFinite(input));
|
||||
}
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(IsFiniteOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("IsFinite"), IsFiniteOp);
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace tensorflow
|
@ -73,8 +73,12 @@ XLAJIT_MAKE_UNARY(Exp, b->Exp(x));
|
||||
XLAJIT_MAKE_UNARY(Expm1, b->Sub(b->Exp(x), XlaHelpers::One(b, input_type(0))));
|
||||
|
||||
XLAJIT_MAKE_UNARY(Floor, b->Floor(x));
|
||||
// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
|
||||
XLAJIT_MAKE_UNARY(Sign, b->Sign(x));
|
||||
XLAJIT_MAKE_UNARY(IsFinite, b->IsFinite(x));
|
||||
XLAJIT_MAKE_UNARY(IsInf, b->Eq(b->Abs(x),
|
||||
XlaHelpers::FloatLiteral(
|
||||
b, input_type(0),
|
||||
std::numeric_limits<double>::infinity())));
|
||||
XLAJIT_MAKE_UNARY(IsNan, b->Ne(x, x));
|
||||
// Return 1/x
|
||||
XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x));
|
||||
XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x));
|
||||
@ -105,6 +109,12 @@ static xla::ComputationDataHandle Round(xla::ComputationBuilder* b,
|
||||
b->Add(round_val, one), round_val);
|
||||
}
|
||||
|
||||
XLAJIT_MAKE_UNARY(Rint, Round(b, input_type(0), x));
|
||||
XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x));
|
||||
|
||||
XLAJIT_MAKE_UNARY(Rsqrt,
|
||||
b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5)));
|
||||
|
||||
// Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2.
|
||||
static xla::ComputationDataHandle Sigmoid(xla::ComputationBuilder* b,
|
||||
DataType dtype,
|
||||
@ -112,16 +122,19 @@ static xla::ComputationDataHandle Sigmoid(xla::ComputationBuilder* b,
|
||||
auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5);
|
||||
return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x))));
|
||||
}
|
||||
|
||||
XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x));
|
||||
XLAJIT_MAKE_UNARY(Rsqrt,
|
||||
b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5)));
|
||||
XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(b, input_type(0), x));
|
||||
|
||||
// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
|
||||
XLAJIT_MAKE_UNARY(Sign, b->Sign(x));
|
||||
XLAJIT_MAKE_UNARY(Sinh,
|
||||
b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))),
|
||||
XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
|
||||
XLAJIT_MAKE_UNARY(Softplus,
|
||||
b->Log(b->Add(b->Exp(x), XlaHelpers::One(b, input_type(0)))));
|
||||
// softsign(x) = x / (abs(x) + 1)
|
||||
XLAJIT_MAKE_UNARY(Softsign,
|
||||
b->Div(x,
|
||||
b->Add(b->Abs(x), XlaHelpers::One(b, input_type(0)))));
|
||||
XLAJIT_MAKE_UNARY(Sqrt,
|
||||
b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
|
||||
XLAJIT_MAKE_UNARY(Square, b->Mul(x, x));
|
||||
|
@ -847,6 +847,7 @@ cc_test(
|
||||
srcs = ["hlo_ordering_test.cc"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_dataflow_analysis",
|
||||
":hlo_ordering",
|
||||
":hlo_scheduling",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
|
@ -241,7 +241,7 @@ Status Executor::Run() {
|
||||
completion_queue_.pop_front();
|
||||
break;
|
||||
}
|
||||
} while (1);
|
||||
} while (true);
|
||||
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
|
||||
assignment_->GetUniqueTopLevelSlice(instruction));
|
||||
void* result_buffer =
|
||||
|
@ -24,16 +24,14 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
Status DfsHloVisitor::HandleElementwiseUnary(HloInstruction* hlo,
|
||||
HloOpcode opcode) {
|
||||
Status DfsHloVisitor::HandleElementwiseUnary(HloInstruction* hlo) {
|
||||
return Unimplemented("DfsHloVisitor::HandleElementwiseUnary: %s",
|
||||
HloOpcodeString(opcode).c_str());
|
||||
HloOpcodeString(hlo->opcode()).c_str());
|
||||
}
|
||||
|
||||
Status DfsHloVisitor::HandleElementwiseBinary(HloInstruction* hlo,
|
||||
HloOpcode opcode) {
|
||||
Status DfsHloVisitor::HandleElementwiseBinary(HloInstruction* hlo) {
|
||||
return Unimplemented("DfsHloVisitor::HandleElementwiseBinary: %s",
|
||||
HloOpcodeString(opcode).c_str());
|
||||
HloOpcodeString(hlo->opcode()).c_str());
|
||||
}
|
||||
|
||||
DfsHloVisitor::VisitState DfsHloVisitor::GetVisitState(
|
||||
|
@ -63,37 +63,37 @@ class DfsHloVisitor {
|
||||
// These routines are self-descriptive, see class comment for usage
|
||||
// information.
|
||||
|
||||
virtual Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode);
|
||||
virtual Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode);
|
||||
virtual Status HandleElementwiseUnary(HloInstruction* hlo);
|
||||
virtual Status HandleElementwiseBinary(HloInstruction* hlo);
|
||||
virtual Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
|
||||
HloInstruction* arg, HloInstruction* max) = 0;
|
||||
virtual Status HandleSelect(HloInstruction* select, HloInstruction* pred,
|
||||
HloInstruction* on_true,
|
||||
HloInstruction* on_false) = 0;
|
||||
virtual Status HandleMaximum(HloInstruction* maximum) {
|
||||
return HandleElementwiseBinary(maximum, HloOpcode::kMaximum);
|
||||
return HandleElementwiseBinary(maximum);
|
||||
}
|
||||
virtual Status HandleMinimum(HloInstruction* minimum) {
|
||||
return HandleElementwiseBinary(minimum, HloOpcode::kMinimum);
|
||||
return HandleElementwiseBinary(minimum);
|
||||
}
|
||||
virtual Status HandleConcatenate(
|
||||
HloInstruction* concatenate,
|
||||
tensorflow::gtl::ArraySlice<HloInstruction*> operands) = 0;
|
||||
virtual Status HandleConvert(HloInstruction* convert) {
|
||||
return HandleElementwiseUnary(convert, HloOpcode::kConvert);
|
||||
return HandleElementwiseUnary(convert);
|
||||
}
|
||||
virtual Status HandleCopy(HloInstruction* copy) {
|
||||
return HandleElementwiseUnary(copy, HloOpcode::kCopy);
|
||||
return HandleElementwiseUnary(copy);
|
||||
}
|
||||
virtual Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs,
|
||||
HloInstruction* rhs) {
|
||||
return HandleElementwiseBinary(multiply, HloOpcode::kMultiply);
|
||||
return HandleElementwiseBinary(multiply);
|
||||
}
|
||||
virtual Status HandleDot(HloInstruction* dot, HloInstruction* lhs,
|
||||
HloInstruction* rhs) = 0;
|
||||
virtual Status HandlePower(HloInstruction* power, HloInstruction* lhs,
|
||||
HloInstruction* rhs) {
|
||||
return HandleElementwiseBinary(power, HloOpcode::kPower);
|
||||
return HandleElementwiseBinary(power);
|
||||
}
|
||||
virtual Status HandleConvolution(HloInstruction* convolution,
|
||||
HloInstruction* lhs, HloInstruction* rhs,
|
||||
@ -101,73 +101,72 @@ class DfsHloVisitor {
|
||||
virtual Status HandleCrossReplicaSum(HloInstruction* crs) = 0;
|
||||
virtual Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
|
||||
HloInstruction* lhs, HloInstruction* rhs) {
|
||||
return HandleElementwiseBinary(compare, opcode);
|
||||
return HandleElementwiseBinary(compare);
|
||||
}
|
||||
virtual Status HandleAdd(HloInstruction* add, HloInstruction* lhs,
|
||||
HloInstruction* rhs) {
|
||||
return HandleElementwiseBinary(add, HloOpcode::kAdd);
|
||||
return HandleElementwiseBinary(add);
|
||||
}
|
||||
virtual Status HandleDivide(HloInstruction* divide, HloInstruction* lhs,
|
||||
HloInstruction* rhs) {
|
||||
return HandleElementwiseBinary(divide, HloOpcode::kDivide);
|
||||
return HandleElementwiseBinary(divide);
|
||||
}
|
||||
virtual Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs,
|
||||
HloInstruction* rhs) {
|
||||
return HandleElementwiseBinary(remainder, HloOpcode::kRemainder);
|
||||
return HandleElementwiseBinary(remainder);
|
||||
}
|
||||
virtual Status HandleSubtract(HloInstruction* subtract, HloInstruction* lhs,
|
||||
HloInstruction* rhs) {
|
||||
return HandleElementwiseBinary(subtract, HloOpcode::kSubtract);
|
||||
return HandleElementwiseBinary(subtract);
|
||||
}
|
||||
virtual Status HandleAbs(HloInstruction* abs, HloInstruction* operand) {
|
||||
return HandleElementwiseUnary(abs, HloOpcode::kAbs);
|
||||
return HandleElementwiseUnary(abs);
|
||||
}
|
||||
virtual Status HandleSign(HloInstruction* sign, HloInstruction* operand) {
|
||||
return HandleElementwiseUnary(sign, HloOpcode::kSign);
|
||||
return HandleElementwiseUnary(sign);
|
||||
}
|
||||
virtual Status HandleNegate(HloInstruction* negate, HloInstruction* operand) {
|
||||
return HandleElementwiseUnary(negate, HloOpcode::kNegate);
|
||||
return HandleElementwiseUnary(negate);
|
||||
}
|
||||
virtual Status HandleExp(HloInstruction* exp, HloInstruction* operand) {
|
||||
return HandleElementwiseUnary(exp, HloOpcode::kExp);
|
||||
return HandleElementwiseUnary(exp);
|
||||
}
|
||||
virtual Status HandleFloor(HloInstruction* floor, HloInstruction* operand) {
|
||||
return HandleElementwiseUnary(floor, HloOpcode::kFloor);
|
||||
return HandleElementwiseUnary(floor);
|
||||
}
|
||||
virtual Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) {
|
||||
return HandleElementwiseUnary(ceil, HloOpcode::kCeil);
|
||||
return HandleElementwiseUnary(ceil);
|
||||
}
|
||||
virtual Status HandleLog(HloInstruction* log, HloInstruction* operand) {
|
||||
return HandleElementwiseUnary(log, HloOpcode::kLog);
|
||||
return HandleElementwiseUnary(log);
|
||||
}
|
||||
virtual Status HandleCos(HloInstruction* cos, HloInstruction* operand) {
|
||||
return HandleElementwiseUnary(cos, HloOpcode::kCos);
|
||||
return HandleElementwiseUnary(cos);
|
||||
}
|
||||
virtual Status HandleSin(HloInstruction* sin, HloInstruction* operand) {
|
||||
return HandleElementwiseUnary(sin, HloOpcode::kSin);
|
||||
return HandleElementwiseUnary(sin);
|
||||
}
|
||||
virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) {
|
||||
return HandleElementwiseUnary(tanh, HloOpcode::kTanh);
|
||||
return HandleElementwiseUnary(tanh);
|
||||
}
|
||||
virtual Status HandleIsFinite(HloInstruction* is_finite,
|
||||
HloInstruction* operand) {
|
||||
return HandleElementwiseUnary(is_finite, HloOpcode::kIsFinite);
|
||||
return HandleElementwiseUnary(is_finite);
|
||||
}
|
||||
virtual Status HandleLogicalAnd(HloInstruction* logical_and,
|
||||
HloInstruction* lhs, HloInstruction* rhs) {
|
||||
return HandleElementwiseBinary(logical_and, HloOpcode::kLogicalAnd);
|
||||
return HandleElementwiseBinary(logical_and);
|
||||
}
|
||||
virtual Status HandleLogicalNot(HloInstruction* logical_not,
|
||||
HloInstruction* operand) {
|
||||
return HandleElementwiseUnary(logical_not, HloOpcode::kLogicalNot);
|
||||
return HandleElementwiseUnary(logical_not);
|
||||
}
|
||||
virtual Status HandleLogicalOr(HloInstruction* logical_or,
|
||||
HloInstruction* lhs, HloInstruction* rhs) {
|
||||
return HandleElementwiseBinary(logical_or, HloOpcode::kLogicalOr);
|
||||
return HandleElementwiseBinary(logical_or);
|
||||
}
|
||||
virtual Status HandleReducePrecision(HloInstruction* reduce_precision) {
|
||||
return HandleElementwiseUnary(reduce_precision,
|
||||
HloOpcode::kReducePrecision);
|
||||
return HandleElementwiseUnary(reduce_precision);
|
||||
}
|
||||
|
||||
virtual Status HandleInfeed(HloInstruction* infeed) = 0;
|
||||
|
@ -41,12 +41,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor {
|
||||
// Default action performed on HloInstruction.
|
||||
virtual Status DefaultAction(HloInstruction* hlo_instruction) = 0;
|
||||
|
||||
Status HandleElementwiseUnary(HloInstruction* hlo,
|
||||
HloOpcode opcode) override {
|
||||
Status HandleElementwiseUnary(HloInstruction* hlo) override {
|
||||
return DefaultAction(hlo);
|
||||
}
|
||||
Status HandleElementwiseBinary(HloInstruction* hlo,
|
||||
HloOpcode opcode) override {
|
||||
Status HandleElementwiseBinary(HloInstruction* hlo) override {
|
||||
return DefaultAction(hlo);
|
||||
}
|
||||
|
||||
|
@ -709,7 +709,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
|
||||
} else {
|
||||
auto r = ir_builder_->CreateSub(q, p);
|
||||
auto leading_zeros = llvm_ir::EmitCallToIntrinsic(
|
||||
llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(1)},
|
||||
llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(true)},
|
||||
{param_ir_type}, ir_builder_);
|
||||
auto in_block = ir_builder_->GetInsertBlock();
|
||||
|
||||
|
@ -334,7 +334,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
|
||||
SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder_);
|
||||
|
||||
IrArray::Index input_index(index.size());
|
||||
llvm::Value* in_bounds = ir_builder_->getInt1(1);
|
||||
llvm::Value* in_bounds = ir_builder_->getInt1(true);
|
||||
for (size_t i = 0; i < index.size(); ++i) {
|
||||
llvm::Value* stridden_index = ir_builder_->CreateNSWMul(
|
||||
index[i], ir_builder_->getInt64(window.dimensions(i).stride()));
|
||||
|
@ -389,7 +389,7 @@ StatusOr<string> CompileModuleToPtx(llvm::Module* module,
|
||||
|
||||
// Loop unrolling exposes more opportunities for SROA. Therefore, we run SROA
|
||||
// again after the standard optimization passes [http://b/13329423].
|
||||
// TODO(jingyue): SROA may further expose more optimization opportunities, such
|
||||
// TODO(jingyue): SROA may further expose more optimization opportunities such
|
||||
// as more precise alias analysis and more function inlining (SROA may change
|
||||
// the inlining cost of a function). For now, running SROA already emits good
|
||||
// enough code for the evaluated benchmarks. We may want to run more
|
||||
|
@ -37,6 +37,230 @@ namespace xla {
|
||||
using ::tensorflow::strings::StrAppend;
|
||||
using ::tensorflow::strings::StrCat;
|
||||
|
||||
// Data structure used to construct the alias analysis. Thrown away after alias
|
||||
// analysis is complete. This data structure keeps track of which sets of
|
||||
// HloValues must be in the same HloBuffer. This is maintained as a map from a
|
||||
// buffer identifier (BufferNumber) to set of HLoValues.
|
||||
//
|
||||
// Initially each value is its own buffer. In MergeAliasedBuffers, sets of
|
||||
// values which must share the same buffer are merged together. The end result
|
||||
// is a partitioning of all HloValues into sets where each set needs its own
|
||||
// HloBuffer. By performing this analysis without constructing HloBuffers on the
|
||||
// fly, we can after-the-fact construct a vector of contiguously numbered
|
||||
// HloBuffers after the buffer requirement has been determined.
|
||||
class BufferValueMap {
|
||||
public:
|
||||
// A unique identifier for a set of colocated values which must share the same
|
||||
// buffer. This is not necessarily the same as the HloBuffer::Id which will
|
||||
// ultimately contain the values. The reason is that HloBuffer::Id's are
|
||||
// contiguous, while BufferNumbers may not be. BufferNumbers may not be
|
||||
// dense because buffers may be created and destroyed during the analysis
|
||||
// construction process.
|
||||
using BufferNumber = int64;
|
||||
|
||||
explicit BufferValueMap(const HloDataflowAnalysis& dataflow)
|
||||
: dataflow_(dataflow) {
|
||||
buffers_.reserve(dataflow_.values().size());
|
||||
value_to_buffer_number_.reserve(dataflow_.values().size());
|
||||
for (const HloValue* value : dataflow_.values()) {
|
||||
BufferNumber buffer_number = next_buffer_number_++;
|
||||
buffers_[buffer_number].insert(value);
|
||||
value_to_buffer_number_[value] = buffer_number;
|
||||
}
|
||||
}
|
||||
|
||||
// Merge together sets of HloValues which must be in the same HloBuffer
|
||||
// because of aliasing rules (eg, in-place kWhile instruction).
|
||||
void MergeAliasedBuffers() {
|
||||
for (const HloValue* value : dataflow_.values()) {
|
||||
VLOG(3) << "Merging colocated values, value: " << value->ToShortString();
|
||||
|
||||
// Gather the set of buffers with aliasing rules (eg, kWhile) which this
|
||||
// value must be contained in.
|
||||
std::vector<BufferNumber> aliased_buffers = ComputeAliasedBuffers(*value);
|
||||
|
||||
BufferNumber current_buffer = value_to_buffer_number_.at(value);
|
||||
if (aliased_buffers.empty()) {
|
||||
// The buffer containing 'value' aliases no other buffers. If the buffer
|
||||
// containing 'value' already only contains 'value', then no change is
|
||||
// necessary. If the buffer containing 'value' does contain other
|
||||
// values, then remove 'value' from the buffer and create a new buffer
|
||||
// containing only 'value'
|
||||
if (buffers_.at(current_buffer).size() == 1) {
|
||||
CHECK_EQ(*buffers_.at(current_buffer).begin(), value);
|
||||
} else {
|
||||
MoveValueToNewBuffer(*value);
|
||||
}
|
||||
} else {
|
||||
// If multiple buffers are aliased merge these buffers together into a
|
||||
// single buffer (arbitrarily chosen as the first buffer in the vector).
|
||||
if (aliased_buffers.size() > 1) {
|
||||
for (int64 i = 1; i < aliased_buffers.size(); ++i) {
|
||||
MergeBuffers(/*from=*/aliased_buffers[i],
|
||||
/*to=*/aliased_buffers[0]);
|
||||
}
|
||||
}
|
||||
BufferNumber new_buffer = aliased_buffers[0];
|
||||
if (current_buffer != new_buffer) {
|
||||
MoveValueToBuffer(*value, new_buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Compute and return a sorted vector of all BufferNumbers. Can be used to
|
||||
// iterate through all buffers stabily.
|
||||
std::vector<BufferNumber> ComputeSortedBufferNumbers() const {
|
||||
std::vector<BufferNumber> buffer_numbers;
|
||||
for (const auto& pair : buffers_) {
|
||||
buffer_numbers.push_back(pair.first);
|
||||
}
|
||||
std::sort(buffer_numbers.begin(), buffer_numbers.end());
|
||||
return buffer_numbers;
|
||||
}
|
||||
|
||||
// Return a set of all the values in the given buffer.
|
||||
const tensorflow::gtl::FlatSet<const HloValue*>& GetValuesInBuffer(
|
||||
BufferNumber buffer_number) const {
|
||||
return buffers_.at(buffer_number);
|
||||
}
|
||||
|
||||
private:
|
||||
// Create a new buffer.
|
||||
void NewBuffer(const HloValue& value) {
|
||||
BufferNumber buffer_number = next_buffer_number_++;
|
||||
buffers_[buffer_number].insert(&value);
|
||||
value_to_buffer_number_[&value] = buffer_number;
|
||||
}
|
||||
|
||||
// Move the given value into a new buffer containing only the value.
|
||||
void MoveValueToNewBuffer(const HloValue& value) {
|
||||
BufferNumber new_buffer_number = next_buffer_number_++;
|
||||
buffers_[new_buffer_number];
|
||||
MoveValueToBuffer(value, new_buffer_number);
|
||||
}
|
||||
|
||||
// Move the given value into the given buffer.
|
||||
void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) {
|
||||
BufferNumber old_buffer_number = value_to_buffer_number_.at(&value);
|
||||
buffers_.at(old_buffer_number).erase(&value);
|
||||
if (buffers_.at(old_buffer_number).empty()) {
|
||||
buffers_.erase(old_buffer_number);
|
||||
}
|
||||
|
||||
buffers_.at(buffer_number).insert(&value);
|
||||
value_to_buffer_number_.at(&value) = buffer_number;
|
||||
}
|
||||
|
||||
// Merge the buffer 'from' into the buffer 'to'.
|
||||
void MergeBuffers(BufferNumber from, BufferNumber to) {
|
||||
auto& from_value_set = buffers_.at(from);
|
||||
buffers_.at(to).insert(from_value_set.begin(), from_value_set.end());
|
||||
// NOTE: using a union-find algorithm to hold the colocated values might be
|
||||
// faster.
|
||||
for (const HloValue* value : from_value_set) {
|
||||
value_to_buffer_number_.at(value) = to;
|
||||
}
|
||||
buffers_.erase(from);
|
||||
}
|
||||
|
||||
BufferNumber GetBufferForValue(const HloValue& value) {
|
||||
return value_to_buffer_number_.at(&value);
|
||||
}
|
||||
|
||||
// Compute and return a vector of buffers that the given value must be
|
||||
// contained in due to HLO aliasing rules.
|
||||
std::vector<BufferNumber> ComputeAliasedBuffers(const HloValue& value) {
|
||||
// Value is init of a while (use is while).
|
||||
std::vector<BufferNumber> aliased_buffers;
|
||||
for (const HloUse& use : value.uses()) {
|
||||
VLOG(1) << "use of value " << value.ToShortString() << ": " << use;
|
||||
if (use.instruction->opcode() == HloOpcode::kWhile) {
|
||||
// Determine the while value that this shares a buffer with.
|
||||
const HloValue& while_value =
|
||||
dataflow_.GetUniqueValueAt(use.instruction, use.operand_index);
|
||||
aliased_buffers.push_back(GetBufferForValue(while_value));
|
||||
VLOG(3) << " value is init value to a while; must share buffer with "
|
||||
"while value "
|
||||
<< while_value.ToShortString();
|
||||
}
|
||||
}
|
||||
|
||||
// Value is a parameter of a while body/condition.
|
||||
if (value.defining_instruction()->opcode() == HloOpcode::kParameter) {
|
||||
const HloComputation* computation =
|
||||
value.defining_instruction()->parent();
|
||||
const CallGraphNode& call_graph_node =
|
||||
dataflow_.call_graph().GetNode(computation);
|
||||
for (const CallSite& callsite : call_graph_node.caller_callsites()) {
|
||||
if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
|
||||
// Call graph must have been flattened.
|
||||
CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
|
||||
|
||||
const HloValue& while_value = dataflow_.GetUniqueValueAt(
|
||||
callsite.instruction(), value.defining_index());
|
||||
VLOG(3) << " value is parameter value of the body or condition of a "
|
||||
"while; must share buffer with while value "
|
||||
<< while_value.ToShortString();
|
||||
aliased_buffers.push_back(GetBufferForValue(while_value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Value is the root of a while body.
|
||||
for (const HloPosition& position : value.positions()) {
|
||||
const HloComputation* computation = position.instruction->parent();
|
||||
const CallGraphNode& call_graph_node =
|
||||
dataflow_.call_graph().GetNode(computation);
|
||||
if (position.instruction == computation->root_instruction()) {
|
||||
for (const CallSite& callsite : call_graph_node.caller_callsites()) {
|
||||
if (callsite.instruction()->opcode() == HloOpcode::kWhile &&
|
||||
callsite.instruction()->while_body() == computation) {
|
||||
// Call graph must have been flattened.
|
||||
CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
|
||||
|
||||
const HloValue& while_value = dataflow_.GetUniqueValueAt(
|
||||
callsite.instruction(), position.index);
|
||||
VLOG(3) << " value is root the body computation of a while; must "
|
||||
"share buffer with while value "
|
||||
<< while_value.ToShortString();
|
||||
aliased_buffers.push_back(GetBufferForValue(while_value));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Value is the output of the while instruction itself.
|
||||
if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
|
||||
VLOG(3) << " value is output of a while instruction";
|
||||
aliased_buffers.push_back(GetBufferForValue(value));
|
||||
}
|
||||
|
||||
// Uniquify aliased buffers.
|
||||
std::sort(aliased_buffers.begin(), aliased_buffers.end());
|
||||
aliased_buffers.erase(
|
||||
std::unique(aliased_buffers.begin(), aliased_buffers.end()),
|
||||
aliased_buffers.end());
|
||||
|
||||
return aliased_buffers;
|
||||
}
|
||||
|
||||
// Dataflow analysis used to construct the buffer map.
|
||||
const HloDataflowAnalysis& dataflow_;
|
||||
|
||||
// A map containing the set of values contained in each buffer.
|
||||
tensorflow::gtl::FlatMap<BufferNumber,
|
||||
tensorflow::gtl::FlatSet<const HloValue*>>
|
||||
buffers_;
|
||||
|
||||
// A map indicating which buffer each value is contained in.
|
||||
tensorflow::gtl::FlatMap<const HloValue*, BufferNumber>
|
||||
value_to_buffer_number_;
|
||||
|
||||
// The buffer number of the next buffer to be created.
|
||||
BufferNumber next_buffer_number_ = 0;
|
||||
};
|
||||
|
||||
HloAliasAnalysis::HloAliasAnalysis(HloModule* module) : module_(module) {}
|
||||
|
||||
const HloBuffer& HloAliasAnalysis::GetUniqueBufferAt(
|
||||
@ -99,10 +323,11 @@ bool HloAliasAnalysis::InstructionBuffersAreDistinct(
|
||||
}
|
||||
} else {
|
||||
// It's possible for multiple values at this index to have the same
|
||||
// HloBuffer. This does not result in non-distictness. To account for this
|
||||
// case, add all of the buffers at this index after checking whether each
|
||||
// buffer exists at an earlier index. This is a corner case, however, as
|
||||
// the number of values at an index is almost always one.
|
||||
// HloBuffer. This does not result in non-distictness. To account for
|
||||
// this case, add all of the buffers at this index after checking
|
||||
// whether each buffer exists at an earlier index. This is a corner
|
||||
// case, however, as the number of values at an index is almost always
|
||||
// one.
|
||||
std::vector<const HloBuffer*> buffers_at_this_index;
|
||||
for (const HloValue* value : value_set.values()) {
|
||||
const HloBuffer* buffer = &GetBufferContainingValue(*value);
|
||||
@ -118,15 +343,6 @@ bool HloAliasAnalysis::InstructionBuffersAreDistinct(
|
||||
return true;
|
||||
}
|
||||
|
||||
void HloAliasAnalysis::InitializeBufferSets() {
|
||||
// Initially define a buffer for every HloValue in the module.
|
||||
for (const HloValue& value : dataflow_analysis_->values()) {
|
||||
HloBuffer& buffer = NewHloBuffer();
|
||||
buffer.AddValue(value);
|
||||
value_to_buffer_[&value] = &buffer;
|
||||
}
|
||||
}
|
||||
|
||||
Status HloAliasAnalysis::Verify() const {
|
||||
// Verify consistency between the value_to_buffer_ map and
|
||||
// HloBuffer::values().
|
||||
@ -137,9 +353,8 @@ Status HloAliasAnalysis::Verify() const {
|
||||
value) != buffer.values().end());
|
||||
}
|
||||
|
||||
for (const auto& pair : buffers_) {
|
||||
const HloBuffer::Id id = pair.first;
|
||||
const HloBuffer& buffer = pair.second;
|
||||
for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) {
|
||||
const HloBuffer& buffer = buffers_[id];
|
||||
TF_RET_CHECK(buffer.id() == id);
|
||||
|
||||
HloValue::Id last_value_id = -1;
|
||||
@ -152,116 +367,9 @@ Status HloAliasAnalysis::Verify() const {
|
||||
}
|
||||
}
|
||||
|
||||
if (!buffers_vector_.empty()) {
|
||||
// buffers_vector_ should be a vector of all HloBuffers sorted by id.
|
||||
std::vector<const HloBuffer*> buffers;
|
||||
for (const auto& id_buffer : buffers_) {
|
||||
buffers.push_back(&id_buffer.second);
|
||||
}
|
||||
std::sort(buffers.begin(), buffers.end(), HloBuffer::IdLessThan);
|
||||
TF_RET_CHECK(buffers_vector_ == buffers);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloAliasAnalysis::VerifyAgainstReference() const {
|
||||
TF_RETURN_IF_ERROR(Verify());
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> reference,
|
||||
Run(module_));
|
||||
TF_RETURN_IF_ERROR(reference->Verify());
|
||||
|
||||
VLOG(2) << "This analysis:";
|
||||
XLA_VLOG_LINES(2, ToString());
|
||||
VLOG(2) << "Reference:";
|
||||
XLA_VLOG_LINES(2, reference->ToString());
|
||||
|
||||
// Create map from HloValue in the reference analysis to HloValue in this
|
||||
// analysis and vice versa.
|
||||
tensorflow::gtl::FlatMap<const HloValue*, const HloValue*> reference_to_this;
|
||||
tensorflow::gtl::FlatMap<const HloValue*, const HloValue*> this_to_reference;
|
||||
for (const HloValue& value : dataflow_analysis().values()) {
|
||||
const HloValue& reference_value =
|
||||
reference->dataflow_analysis().GetValueDefinedAt(
|
||||
value.defining_instruction(), value.defining_index());
|
||||
reference_to_this[&reference_value] = &value;
|
||||
this_to_reference[&value] = &reference_value;
|
||||
}
|
||||
|
||||
TF_RET_CHECK(buffers_.size() == reference->buffers_.size())
|
||||
<< "Different number of buffers (" << buffers_.size()
|
||||
<< " != " << reference->buffers_.size() << ")";
|
||||
for (const auto& pair : reference->buffers_) {
|
||||
const HloBuffer& reference_buffer = pair.second;
|
||||
|
||||
// Find the corresponding buffer in the reference by taking the first value
|
||||
// in the buffer, finding the corresponding value in the reference, and then
|
||||
// finding the buffer holding that value.
|
||||
TF_RET_CHECK(!reference_buffer.values().empty());
|
||||
const HloValue* reference_value = reference_buffer.values()[0];
|
||||
const HloValue* value = reference_to_this.at(reference_value);
|
||||
const HloBuffer& buffer = GetBufferContainingValue(*value);
|
||||
|
||||
// The buffer and the reference should have the exact same values. To make
|
||||
// comparison easy, sort the values in the reference buffer identically to
|
||||
// the values in the non-reference buffer (ie, by the corresponding id of
|
||||
// the non-reference value).
|
||||
std::vector<const HloValue*> reference_values = reference_buffer.values();
|
||||
std::sort(reference_values.begin(), reference_values.end(),
|
||||
[&reference_to_this](const HloValue* a, const HloValue* b) {
|
||||
return reference_to_this.at(a)->id() <
|
||||
reference_to_this.at(b)->id();
|
||||
});
|
||||
TF_RET_CHECK(reference_values.size() == buffer.values().size());
|
||||
for (int i = 0; i < buffer.values().size(); ++i) {
|
||||
TF_RET_CHECK(*reference_values[i] == *buffer.values()[i])
|
||||
<< "Buffer:\n " << buffer
|
||||
<< "\ndoes not have the same values as reference buffer:\n "
|
||||
<< reference_buffer;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
HloBuffer& HloAliasAnalysis::NewHloBuffer() {
|
||||
HloBuffer::Id buffer_id = next_buffer_id_++;
|
||||
auto emplaced = buffers_.emplace(std::piecewise_construct,
|
||||
std::forward_as_tuple(buffer_id),
|
||||
std::forward_as_tuple(buffer_id));
|
||||
CHECK(emplaced.second);
|
||||
|
||||
buffers_vector_.clear();
|
||||
|
||||
return emplaced.first->second;
|
||||
}
|
||||
|
||||
void HloAliasAnalysis::MoveValueToNewBuffer(const HloValue& value) {
|
||||
HloBuffer& new_buffer = NewHloBuffer();
|
||||
MoveValueToBuffer(value, &new_buffer);
|
||||
|
||||
VLOG(3) << "Moved value " << value.ToShortString() << " into new buffer "
|
||||
<< new_buffer.id();
|
||||
}
|
||||
|
||||
void HloAliasAnalysis::MoveValueToBuffer(const HloValue& value,
|
||||
HloBuffer* buffer) {
|
||||
HloBuffer& old_buffer = GetBufferContainingValue(value);
|
||||
CHECK_NE(buffer, &old_buffer);
|
||||
VLOG(3) << "Moved value " << value.ToShortString() << " from buffer "
|
||||
<< old_buffer.id() << " into buffer " << buffer->id();
|
||||
old_buffer.RemoveValue(value);
|
||||
if (old_buffer.values().empty()) {
|
||||
VLOG(3) << "Buffer " << old_buffer.id() << " now empty. Removing.";
|
||||
buffers_.erase(old_buffer.id());
|
||||
buffers_vector_.clear();
|
||||
}
|
||||
|
||||
buffer->AddValue(value);
|
||||
value_to_buffer_[&value] = buffer;
|
||||
}
|
||||
|
||||
string HloAliasAnalysis::ToString() const {
|
||||
string out = StrCat("HloAliasAnalysis, module ", module_->name(), "\n");
|
||||
StrAppend(&out, " Buffers at each position:\n");
|
||||
@ -290,10 +398,10 @@ string HloAliasAnalysis::ToString() const {
|
||||
}
|
||||
|
||||
StrAppend(&out, " Buffers:\n");
|
||||
for (const HloBuffer* buffer : buffers()) {
|
||||
StrAppend(&out, " ", buffer->ToString(), "\n");
|
||||
for (const HloBuffer& buffer : buffers()) {
|
||||
StrAppend(&out, " ", buffer.ToString(), "\n");
|
||||
StrAppend(&out, " positions:\n");
|
||||
for (const HloPosition& position : buffer->ComputePositions()) {
|
||||
for (const HloPosition& position : buffer.ComputePositions()) {
|
||||
StrAppend(&out, " ", position.ToString(), "\n");
|
||||
}
|
||||
}
|
||||
@ -301,217 +409,6 @@ string HloAliasAnalysis::ToString() const {
|
||||
return out;
|
||||
}
|
||||
|
||||
const std::vector<const HloBuffer*>& HloAliasAnalysis::buffers() const {
|
||||
if (buffers_vector_.empty()) {
|
||||
// Lazily construct vector of buffers.
|
||||
buffers_vector_.reserve(buffers_.size());
|
||||
for (auto& pair : buffers_) {
|
||||
buffers_vector_.push_back(&pair.second);
|
||||
}
|
||||
std::sort(buffers_vector_.begin(), buffers_vector_.end(),
|
||||
HloBuffer::IdLessThan);
|
||||
} else {
|
||||
CHECK_EQ(buffers_vector_.size(), buffers_.size());
|
||||
for (const HloBuffer* buffer : buffers_vector_) {
|
||||
DCHECK(ContainsKey(buffers_, buffer->id()));
|
||||
DCHECK(&GetBuffer(buffer->id()) == buffer);
|
||||
}
|
||||
}
|
||||
return buffers_vector_;
|
||||
}
|
||||
|
||||
void HloAliasAnalysis::UpdateAtInstructions(
|
||||
tensorflow::gtl::ArraySlice<const HloInstruction*> instructions) {
|
||||
VLOG(4) << "Updated HLO module:";
|
||||
XLA_VLOG_LINES(4, module_->ToString());
|
||||
|
||||
VLOG(3) << "Before update:";
|
||||
XLA_VLOG_LINES(3, ToString());
|
||||
|
||||
std::vector<const HloValue*> values_to_update;
|
||||
for (const HloInstruction* instruction : instructions) {
|
||||
for (auto& pair : dataflow_analysis().GetInstructionValueSet(instruction)) {
|
||||
for (const HloValue* value : pair.second.values()) {
|
||||
values_to_update.push_back(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
UpdateBuffersForValues(values_to_update);
|
||||
|
||||
VLOG(3) << "After update:";
|
||||
XLA_VLOG_LINES(3, ToString());
|
||||
}
|
||||
|
||||
void HloAliasAnalysis::UpdateAfterChangingOperand(HloInstruction* instruction,
|
||||
HloInstruction* old_operand,
|
||||
HloInstruction* new_operand) {
|
||||
VLOG(1) << "UpdateAfterChangingOperand(" << instruction->name() << ", "
|
||||
<< old_operand->name() << " => " << new_operand->name() << ")";
|
||||
|
||||
dataflow_analysis_->UpdateAfterChangingOperand(instruction, old_operand,
|
||||
new_operand);
|
||||
TF_DCHECK_OK(dataflow_analysis_->VerifyAgainstReference());
|
||||
|
||||
VLOG(4) << "Updated dataflow:";
|
||||
XLA_VLOG_LINES(4, dataflow_analysis_->ToString());
|
||||
|
||||
UpdateAtInstructions({instruction, old_operand, new_operand});
|
||||
}
|
||||
|
||||
void HloAliasAnalysis::UpdateAfterChangingRoot(HloInstruction* old_root,
|
||||
HloInstruction* new_root) {
|
||||
VLOG(1) << "UpdateAfterChangingRoot(" << old_root->name() << " => "
|
||||
<< new_root->name() << ")";
|
||||
|
||||
dataflow_analysis_->UpdateAfterChangingRoot(old_root, new_root);
|
||||
TF_DCHECK_OK(dataflow_analysis_->VerifyAgainstReference());
|
||||
|
||||
VLOG(4) << "Updated dataflow:";
|
||||
XLA_VLOG_LINES(4, dataflow_analysis_->ToString());
|
||||
|
||||
UpdateAtInstructions({old_root, new_root});
|
||||
}
|
||||
|
||||
std::vector<HloBuffer*> HloAliasAnalysis::ComputeAliasedBuffers(
|
||||
const HloValue& value) {
|
||||
std::vector<HloBuffer*> aliased_buffers;
|
||||
|
||||
// Value is init of a while (use is while).
|
||||
for (const HloUse& use : value.uses()) {
|
||||
VLOG(1) << "use of value " << value.ToShortString() << ": " << use;
|
||||
if (use.instruction->opcode() == HloOpcode::kWhile) {
|
||||
// Determine the while value that this shares a buffer with.
|
||||
const HloValue& while_value = dataflow_analysis().GetUniqueValueAt(
|
||||
use.instruction, use.operand_index);
|
||||
aliased_buffers.push_back(&GetBufferContainingValue(while_value));
|
||||
VLOG(3) << " value is init value to a while; must share buffer with "
|
||||
"while value "
|
||||
<< while_value.ToShortString();
|
||||
}
|
||||
}
|
||||
|
||||
// Value is a parameter of a while body/condition.
|
||||
if (value.defining_instruction()->opcode() == HloOpcode::kParameter) {
|
||||
const HloComputation* computation = value.defining_instruction()->parent();
|
||||
const CallGraphNode& call_graph_node =
|
||||
dataflow_analysis().call_graph().GetNode(computation);
|
||||
for (const CallSite& callsite : call_graph_node.caller_callsites()) {
|
||||
if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
|
||||
// Call graph must have been flattened.
|
||||
CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
|
||||
|
||||
const HloValue& while_value = dataflow_analysis().GetUniqueValueAt(
|
||||
callsite.instruction(), value.defining_index());
|
||||
VLOG(3) << " value is parameter value of the body or condition of a "
|
||||
"while; must share buffer with while value "
|
||||
<< while_value.ToShortString();
|
||||
aliased_buffers.push_back(&GetBufferContainingValue(while_value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Value is the root of a while body.
|
||||
for (const HloPosition& position : value.positions()) {
|
||||
const HloComputation* computation = position.instruction->parent();
|
||||
const CallGraphNode& call_graph_node =
|
||||
dataflow_analysis().call_graph().GetNode(computation);
|
||||
if (position.instruction == computation->root_instruction()) {
|
||||
for (const CallSite& callsite : call_graph_node.caller_callsites()) {
|
||||
if (callsite.instruction()->opcode() == HloOpcode::kWhile &&
|
||||
callsite.instruction()->while_body() == computation) {
|
||||
// Call graph must have been flattened.
|
||||
CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
|
||||
|
||||
// If the value appears in the root of a while body, then
|
||||
// necessarily the value is defined in the body as well.
|
||||
CHECK_EQ(value.defining_instruction()->parent(), computation);
|
||||
|
||||
const HloValue& while_value = dataflow_analysis().GetUniqueValueAt(
|
||||
callsite.instruction(), position.index);
|
||||
VLOG(3) << " value is root the body computation of a while; must "
|
||||
"share buffer with while value "
|
||||
<< while_value.ToShortString();
|
||||
aliased_buffers.push_back(&GetBufferContainingValue(while_value));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Value is in the while instruction itself.
|
||||
if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
|
||||
VLOG(3) << " value is output of a while instruction";
|
||||
aliased_buffers.push_back(&GetUniqueBufferAt(value.defining_instruction(),
|
||||
value.defining_index()));
|
||||
}
|
||||
|
||||
// Uniquify aliased buffers.
|
||||
std::sort(aliased_buffers.begin(), aliased_buffers.end(),
|
||||
HloBuffer::IdLessThan);
|
||||
aliased_buffers.erase(
|
||||
std::unique(aliased_buffers.begin(), aliased_buffers.end()),
|
||||
aliased_buffers.end());
|
||||
|
||||
return aliased_buffers;
|
||||
}
|
||||
|
||||
// This method recomputes the HloBuffer for each of the given HloValues. The
|
||||
// method does not necessarily update the HloBuffer of values which share a
|
||||
// buffer with the given values, but are not explicitly passed in
|
||||
// 'values'. Therefore, the caller must pass in all values which may require an
|
||||
// update according to the kind of HLO graph change which occurred: operand
|
||||
// changed (UpdateAfterChangingOperand), or root of computation changed
|
||||
// (UpdateAfterChangingRoot).
|
||||
void HloAliasAnalysis::UpdateBuffersForValues(
|
||||
tensorflow::gtl::ArraySlice<const HloValue*> values) {
|
||||
for (const HloValue* value : values) {
|
||||
VLOG(3) << "Updating buffer for value: " << value->ToShortString();
|
||||
|
||||
// Gather the set of buffer with aliasing rules (eg, kWhile) which this
|
||||
// value must be contained in due.
|
||||
std::vector<HloBuffer*> aliased_buffers = ComputeAliasedBuffers(*value);
|
||||
|
||||
HloBuffer& current_buffer = GetBufferContainingValue(*value);
|
||||
if (aliased_buffers.empty()) {
|
||||
// The buffer containing 'value' aliases no other buffers. If the buffer
|
||||
// containing 'value' already only contains 'value', then no change is
|
||||
// necessary. If the buffer containing 'value' does contain other values,
|
||||
// then remove 'value' from the buffer and create a new buffer containing
|
||||
// only 'value'
|
||||
if (current_buffer.values().size() == 1) {
|
||||
CHECK_EQ(current_buffer.values()[0], value);
|
||||
} else {
|
||||
MoveValueToNewBuffer(*value);
|
||||
}
|
||||
} else {
|
||||
// If multiple buffers are aliased merge these buffers together into a
|
||||
// single buffer (arbitrarily chosen as the first buffer in the vector).
|
||||
if (aliased_buffers.size() > 1) {
|
||||
for (int64 i = 1; i < aliased_buffers.size(); ++i) {
|
||||
// Make copy of values vector because MoveValueToBuffer invalidates
|
||||
// the values iterator. The could be done more efficiently by moving
|
||||
// all values and once.
|
||||
std::vector<const HloValue*> values = aliased_buffers[i]->values();
|
||||
for (const HloValue* value : values) {
|
||||
MoveValueToBuffer(*value, aliased_buffers[0]);
|
||||
}
|
||||
}
|
||||
aliased_buffers.resize(1);
|
||||
}
|
||||
|
||||
CHECK_EQ(aliased_buffers.size(), 1);
|
||||
HloBuffer* new_buffer = aliased_buffers[0];
|
||||
|
||||
if (¤t_buffer != new_buffer) {
|
||||
MoveValueToBuffer(*value, new_buffer);
|
||||
}
|
||||
}
|
||||
|
||||
VLOG(4) << "Analysis after update:";
|
||||
XLA_VLOG_LINES(4, ToString());
|
||||
}
|
||||
}
|
||||
|
||||
/* static */
|
||||
StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
|
||||
HloModule* module) {
|
||||
@ -524,18 +421,28 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
|
||||
HloDataflowAnalysis::Run(module, /*ssa_form=*/true,
|
||||
/*bitcast_defines_value=*/false));
|
||||
|
||||
alias_analysis->InitializeBufferSets();
|
||||
BufferValueMap buffer_map(alias_analysis->dataflow_analysis());
|
||||
buffer_map.MergeAliasedBuffers();
|
||||
|
||||
VLOG(3) << "After initialization:";
|
||||
XLA_VLOG_LINES(3, alias_analysis->ToString());
|
||||
|
||||
std::vector<const HloValue*> all_values;
|
||||
for (const HloValue& value : alias_analysis->dataflow_analysis().values()) {
|
||||
all_values.push_back(&value);
|
||||
// Create a vector of HloBuffers, one for each set of values in the
|
||||
// BufferValueMap. Create the HloBuffers as a vector of contiguously numbered
|
||||
// buffers.
|
||||
std::vector<BufferValueMap::BufferNumber> sorted_buffer_numbers =
|
||||
buffer_map.ComputeSortedBufferNumbers();
|
||||
alias_analysis->buffers_.reserve(sorted_buffer_numbers.size());
|
||||
HloBuffer::Id next_id = 0;
|
||||
for (BufferValueMap::BufferNumber buffer_number : sorted_buffer_numbers) {
|
||||
auto& value_set = buffer_map.GetValuesInBuffer(buffer_number);
|
||||
std::vector<const HloValue*> sorted_values(value_set.begin(),
|
||||
value_set.end());
|
||||
std::sort(sorted_values.begin(), sorted_values.end(), HloValue::IdLessThan);
|
||||
alias_analysis->buffers_.emplace_back(next_id++, sorted_values);
|
||||
for (const HloValue* value : sorted_values) {
|
||||
alias_analysis->value_to_buffer_[value] =
|
||||
&alias_analysis->buffers_.back();
|
||||
}
|
||||
}
|
||||
|
||||
alias_analysis->UpdateBuffersForValues(all_values);
|
||||
|
||||
TF_DCHECK_OK(alias_analysis->Verify());
|
||||
|
||||
XLA_VLOG_LINES(1, alias_analysis->ToString());
|
||||
|
@ -74,7 +74,7 @@ class HloAliasAnalysis {
|
||||
// Return a vector of all HloBuffers stabily sorted by HloBuffer::Id. This
|
||||
// vector is lazily computed. Mutating operations on HloAliasAnalysis may
|
||||
// invalidate the underlying vector requiring recomputation.
|
||||
const std::vector<const HloBuffer*>& buffers() const;
|
||||
const std::vector<HloBuffer>& buffers() const { return buffers_; }
|
||||
|
||||
// Returns the underlying dataflow analysis used by this alias analysis.
|
||||
const HloDataflowAnalysis& dataflow_analysis() const {
|
||||
@ -90,50 +90,13 @@ class HloAliasAnalysis {
|
||||
// output of the given instruction.
|
||||
bool InstructionBuffersAreDistinct(const HloInstruction* instruction) const;
|
||||
|
||||
// Updates the analysis after the operands of 'instruction' have changed or if
|
||||
// 'instruction' has been made the root of a computation. Analysis update is
|
||||
// not possible if instructions have been added or removed from the graph.
|
||||
void UpdateAfterChangingOperand(HloInstruction* instruction,
|
||||
HloInstruction* old_operand,
|
||||
HloInstruction* new_operand);
|
||||
void UpdateAfterChangingRoot(HloInstruction* old_root,
|
||||
HloInstruction* new_root);
|
||||
|
||||
// Compare the dataflow analysis against a clean recomputation of the
|
||||
// analysis. Returns an error status if there is a mismatch. Useful for
|
||||
// verifying the correctness after updates to the analysis.
|
||||
Status VerifyAgainstReference() const;
|
||||
|
||||
protected:
|
||||
HloAliasAnalysis(HloModule* module);
|
||||
|
||||
// Create a new empty HloBuffer.
|
||||
HloBuffer& NewHloBuffer();
|
||||
|
||||
// Move the given value to the given buffer. The value is removed from it's
|
||||
// current buffer.
|
||||
void MoveValueToBuffer(const HloValue& value, HloBuffer* buffer);
|
||||
|
||||
// Move the given value to a newly created buffer. The value is removed from
|
||||
// it's current buffer.
|
||||
void MoveValueToNewBuffer(const HloValue& value);
|
||||
|
||||
// Construct the initial set of buffer sets where an HloBuffer is created for
|
||||
// each HloValue in the module.
|
||||
void InitializeBufferSets();
|
||||
|
||||
// Compute and return the buffers with aliasing rules (eg, kWhile) which the
|
||||
// given value must be contained in.
|
||||
std::vector<HloBuffer*> ComputeAliasedBuffers(const HloValue& value);
|
||||
|
||||
// Recompute the HloBuffers for the given values.
|
||||
void UpdateBuffersForValues(
|
||||
tensorflow::gtl::ArraySlice<const HloValue*> values);
|
||||
|
||||
// Recompute the HloBuffers for all the values which appear in the output of
|
||||
// the given instructions.
|
||||
void UpdateAtInstructions(
|
||||
tensorflow::gtl::ArraySlice<const HloInstruction*> instructions);
|
||||
explicit HloAliasAnalysis(HloModule* module);
|
||||
|
||||
// Verify various invariants of the alias analysis.
|
||||
Status Verify() const;
|
||||
@ -143,20 +106,12 @@ class HloAliasAnalysis {
|
||||
// The underlying dataflow analysis used by this alias analysis.
|
||||
std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
|
||||
|
||||
// The map of all HloBuffers in the module. We pass around pointers to the
|
||||
// mapped HloBuffers, so the underlying container must keep them valid despite
|
||||
// mutations touching other map entries.
|
||||
std::unordered_map<HloBuffer::Id, HloBuffer> buffers_;
|
||||
|
||||
// A map indicating which buffer a value is contained in.
|
||||
tensorflow::gtl::FlatMap<const HloValue*, HloBuffer*> value_to_buffer_;
|
||||
|
||||
// A lazily constructed vector containing all HloBuffers sorted by
|
||||
// HloBuffer::Id.
|
||||
mutable std::vector<const HloBuffer*> buffers_vector_;
|
||||
|
||||
// The Id to use for the next HloBuffer.
|
||||
int64 next_buffer_id_ = 0;
|
||||
std::vector<HloBuffer> buffers_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -87,14 +87,13 @@ class HloAliasAnalysisTest : public HloTestBase {
|
||||
// constructed.
|
||||
bool AnyValuesInSameBufferInterfere() {
|
||||
DependencyHloOrdering ordering(module_.get());
|
||||
for (const HloBuffer* buffer : analysis_->buffers()) {
|
||||
for (const HloValue* value_a : buffer->values()) {
|
||||
for (const HloValue* value_b : buffer->values()) {
|
||||
for (const HloBuffer& buffer : analysis_->buffers()) {
|
||||
for (const HloValue* value_a : buffer.values()) {
|
||||
for (const HloValue* value_b : buffer.values()) {
|
||||
if (*value_a != *value_b &&
|
||||
analysis_->dataflow_analysis().MayInterfere(*value_a, *value_b,
|
||||
ordering)) {
|
||||
ordering.MayInterfere(*value_a, *value_b)) {
|
||||
VLOG(1) << *value_a << " interferes with " << *value_b
|
||||
<< " in buffer: " << *buffer;
|
||||
<< " in buffer: " << buffer;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@ -384,10 +383,7 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) {
|
||||
|
||||
EXPECT_THAT(
|
||||
GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0})),
|
||||
UnorderedElementsAre(GetValueDefinedAt(xla_while, /*index=*/{0}),
|
||||
GetValueDefinedAt(body_param, /*index=*/{0}),
|
||||
GetValueDefinedAt(cond_param, /*index=*/{0}),
|
||||
GetValueDefinedAt(constant1)));
|
||||
UnorderedElementsAre(GetValueDefinedAt(constant1)));
|
||||
EXPECT_THAT(
|
||||
GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})),
|
||||
UnorderedElementsAre(GetValueDefinedAt(constant2),
|
||||
@ -631,9 +627,9 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) {
|
||||
// HloBuffers.
|
||||
EXPECT_THAT(
|
||||
analysis.buffers(),
|
||||
UnorderedElementsAre(&analysis.GetUniqueBufferAt(constant1),
|
||||
&analysis.GetUniqueBufferAt(tuple, /*index=*/{}),
|
||||
&analysis.GetUniqueBufferAt(cond_constant)));
|
||||
UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1),
|
||||
analysis.GetUniqueBufferAt(tuple, /*index=*/{}),
|
||||
analysis.GetUniqueBufferAt(cond_constant)));
|
||||
|
||||
// The tuple elements of the while and the three constant inputs should all be
|
||||
// smooshed into the same buffer.
|
||||
@ -820,127 +816,5 @@ TEST_F(HloAliasAnalysisTest, Bitcast) {
|
||||
analysis.GetUniqueBufferAt(bitcast));
|
||||
}
|
||||
|
||||
TEST_F(HloAliasAnalysisTest, UpdateAnalysisForWhile) {
|
||||
// Test updating alias analysis after modifying a module with an array shaped
|
||||
// while:
|
||||
//
|
||||
// body(F32[] %param):
|
||||
// %negate = Negate(%param)
|
||||
//
|
||||
// condition(F32[] %param):
|
||||
// return Constant(false)
|
||||
//
|
||||
// entry:
|
||||
// %constant = Constant(1.0)
|
||||
// %exp = Exp(%constant)
|
||||
// return While(%exp, body, condition)
|
||||
//
|
||||
auto body_builder = HloComputation::Builder("body");
|
||||
auto body_param = body_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
|
||||
auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
scalar_shape_, HloOpcode::kNegate, body_param));
|
||||
HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
|
||||
|
||||
// Condition computation trivially returns a constant "false".
|
||||
auto cond_builder = HloComputation::Builder("condition");
|
||||
auto cond_param = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
|
||||
cond_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloComputation* condition =
|
||||
module_->AddEmbeddedComputation(cond_builder.Build());
|
||||
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
auto exp = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
|
||||
auto xla_while = builder.AddInstruction(
|
||||
HloInstruction::CreateWhile(scalar_shape_, condition, body, exp));
|
||||
module_->AddEntryComputation(builder.Build());
|
||||
|
||||
HloAliasAnalysis& analysis = RunAnalysis();
|
||||
|
||||
// Sanity check some alias information.
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
|
||||
analysis.GetUniqueBufferAt(body_param));
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
|
||||
analysis.GetUniqueBufferAt(cond_param));
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
|
||||
analysis.GetUniqueBufferAt(negate));
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
|
||||
analysis.GetUniqueBufferAt(xla_while));
|
||||
|
||||
// Set the body root to the body_param. Previously it was Negate(body_param).
|
||||
body->set_root_instruction(body_param);
|
||||
|
||||
// Prior to updating, verify that the analysis is no longer valid.
|
||||
Status verify_status = analysis.VerifyAgainstReference();
|
||||
EXPECT_FALSE(verify_status.ok());
|
||||
|
||||
analysis.UpdateAfterChangingRoot(/*old_root=*/negate,
|
||||
/*new_root*/ body_param);
|
||||
|
||||
// Analysis should be valid after the update.
|
||||
TF_ASSERT_OK(analysis.VerifyAgainstReference());
|
||||
|
||||
// The exponential should now pass through the body transparently.
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
|
||||
analysis.GetUniqueBufferAt(body_param));
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
|
||||
analysis.GetUniqueBufferAt(cond_param));
|
||||
EXPECT_NE(analysis.GetUniqueBufferAt(exp),
|
||||
analysis.GetUniqueBufferAt(negate));
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
|
||||
analysis.GetUniqueBufferAt(xla_while));
|
||||
|
||||
// Now replace the operand of the while with %constant (was %exp).
|
||||
TF_ASSERT_OK(exp->ReplaceUseWith(xla_while, constant));
|
||||
analysis.UpdateAfterChangingOperand(xla_while, /*old_operand=*/exp,
|
||||
/*new_operand=*/constant);
|
||||
|
||||
// Analysis should be valid after the update.
|
||||
TF_ASSERT_OK(analysis.VerifyAgainstReference());
|
||||
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
|
||||
analysis.GetUniqueBufferAt(body_param));
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
|
||||
analysis.GetUniqueBufferAt(cond_param));
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
|
||||
analysis.GetUniqueBufferAt(xla_while));
|
||||
EXPECT_NE(analysis.GetUniqueBufferAt(constant),
|
||||
analysis.GetUniqueBufferAt(exp));
|
||||
EXPECT_NE(analysis.GetUniqueBufferAt(constant),
|
||||
analysis.GetUniqueBufferAt(negate));
|
||||
|
||||
// And finally make the negate the root of the body again.
|
||||
body->set_root_instruction(negate);
|
||||
analysis.UpdateAfterChangingRoot(/*old_root=*/body_param,
|
||||
/*new_root*/ negate);
|
||||
|
||||
// Analysis should be valid after the update.
|
||||
TF_ASSERT_OK(analysis.VerifyAgainstReference());
|
||||
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(negate),
|
||||
analysis.GetUniqueBufferAt(body_param));
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(negate),
|
||||
analysis.GetUniqueBufferAt(cond_param));
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(negate),
|
||||
analysis.GetUniqueBufferAt(xla_while));
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
|
||||
analysis.GetUniqueBufferAt(negate));
|
||||
|
||||
auto value_of = [&analysis](const HloInstruction* instruction) {
|
||||
return &analysis.dataflow_analysis().GetValueDefinedAt(instruction);
|
||||
};
|
||||
EXPECT_THAT(analysis.GetUniqueBufferAt(negate).values(),
|
||||
UnorderedElementsAre(value_of(body_param), value_of(cond_param),
|
||||
value_of(negate), value_of(constant),
|
||||
value_of(xla_while)));
|
||||
}
|
||||
|
||||
// Test update tuple element.
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -36,22 +36,6 @@ namespace xla {
|
||||
using ::tensorflow::str_util::Join;
|
||||
using ::tensorflow::strings::StrCat;
|
||||
|
||||
void HloBuffer::AddValue(const HloValue& value) {
|
||||
values_.push_back(&value);
|
||||
// Sort vector and remove duplicates.
|
||||
std::sort(values_.begin(), values_.end(), HloValue::IdLessThan);
|
||||
values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual),
|
||||
values_.end());
|
||||
}
|
||||
|
||||
void HloBuffer::RemoveValue(const HloValue& value) {
|
||||
// The values are sorted, so finding the value could be done in log(n) time
|
||||
// with a binary search.
|
||||
auto it = std::find(values_.begin(), values_.end(), &value);
|
||||
CHECK(it != values_.end());
|
||||
values_.erase(it);
|
||||
}
|
||||
|
||||
bool HloBuffer::operator==(const HloBuffer& other) const {
|
||||
bool equal = id() == other.id();
|
||||
if (equal) {
|
||||
|
@ -84,22 +84,15 @@ class HloBuffer {
|
||||
return a->id() == b->id();
|
||||
}
|
||||
|
||||
HloBuffer(Id id) : id_(id) {}
|
||||
HloBuffer(Id id, tensorflow::gtl::ArraySlice<const HloValue*> values)
|
||||
: id_(id), values_(values.begin(), values.end()) {}
|
||||
|
||||
// Return the unique identifier for this HloBuffer.
|
||||
Id id() const { return id_; }
|
||||
|
||||
// Add a value to the set of values held by this buffer. Also adds the
|
||||
// HloPositions of the value to the positions vector of the buffer. If the
|
||||
// buffer already contains this value, then this method is a nop.
|
||||
void AddValue(const HloValue& value);
|
||||
void RemoveValue(const HloValue& value);
|
||||
|
||||
// Return all values contained in this buffer.
|
||||
const std::vector<const HloValue*>& values() const { return values_; }
|
||||
|
||||
std::vector<HloPosition> ComputePositions() const;
|
||||
|
||||
// Return the unique HLO value in the buffer. CHECK fails if the buffer does
|
||||
// not contain exactly one value.
|
||||
const HloValue& GetUniqueValue() const {
|
||||
@ -107,6 +100,8 @@ class HloBuffer {
|
||||
return *values_[0];
|
||||
}
|
||||
|
||||
std::vector<HloPosition> ComputePositions() const;
|
||||
|
||||
string ToString() const;
|
||||
|
||||
bool operator==(const HloBuffer& other) const;
|
||||
@ -118,7 +113,7 @@ class HloBuffer {
|
||||
|
||||
// The set of values contained in this buffer. Vector contains no duplicates
|
||||
// and is sorted stably by HloValue::Id.
|
||||
std::vector<const HloValue*> values_;
|
||||
const std::vector<const HloValue*> values_;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer);
|
||||
|
@ -118,13 +118,11 @@ Status HloCostAnalysis::HandleElementwiseOp(HloInstruction* hlo_instruction) {
|
||||
}
|
||||
}
|
||||
|
||||
Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo,
|
||||
HloOpcode opcode) {
|
||||
Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo) {
|
||||
return HandleElementwiseOp(hlo);
|
||||
}
|
||||
|
||||
Status HloCostAnalysis::HandleElementwiseBinary(HloInstruction* hlo,
|
||||
HloOpcode opcode) {
|
||||
Status HloCostAnalysis::HandleElementwiseBinary(HloInstruction* hlo) {
|
||||
return HandleElementwiseOp(hlo);
|
||||
}
|
||||
|
||||
|
@ -49,9 +49,8 @@ class HloCostAnalysis : public DfsHloVisitor {
|
||||
using ShapeSizeFunction = std::function<int64(const Shape&)>;
|
||||
explicit HloCostAnalysis(const ShapeSizeFunction& shape_size);
|
||||
|
||||
Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode) override;
|
||||
Status HandleElementwiseBinary(HloInstruction* hlo,
|
||||
HloOpcode opcode) override;
|
||||
Status HandleElementwiseUnary(HloInstruction* hlo) override;
|
||||
Status HandleElementwiseBinary(HloInstruction* hlo) override;
|
||||
Status HandleConstant(HloInstruction* constant,
|
||||
const Literal& literal) override;
|
||||
Status HandleGetTupleElement(HloInstruction* get_tuple_element,
|
||||
|
@ -67,6 +67,22 @@ HloValue& HloDataflowAnalysis::GetValueDefinedAt(
|
||||
return GetUniqueValueAt(instruction, index);
|
||||
}
|
||||
|
||||
HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
|
||||
const ShapeIndex& index,
|
||||
bool is_phi) {
|
||||
const int64 value_id = next_value_id_++;
|
||||
auto emplaced = values_.emplace(
|
||||
std::piecewise_construct, std::forward_as_tuple(value_id),
|
||||
std::forward_as_tuple(value_id, instruction, index, is_phi));
|
||||
CHECK(emplaced.second);
|
||||
|
||||
return &emplaced.first->second;
|
||||
}
|
||||
|
||||
void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) {
|
||||
values_.erase(value_id);
|
||||
}
|
||||
|
||||
string HloDataflowAnalysis::ToString() const {
|
||||
string out = StrCat("HloDataflowAnalysis, module ", module_->name(), "\n");
|
||||
StrAppend(&out, " Instruction value sets:\n");
|
||||
@ -99,22 +115,98 @@ string HloDataflowAnalysis::ToString() const {
|
||||
}
|
||||
}
|
||||
StrAppend(&out, " HloValues:\n");
|
||||
for (const HloValue& value : values()) {
|
||||
StrAppend(&out, value.ToString(/*indent=*/4));
|
||||
}
|
||||
StrAppend(&out, " Phi resolutions:\n");
|
||||
for (const HloValue& value : values()) {
|
||||
if (value.is_phi()) {
|
||||
const HloValue* resolved_value = ResolvePhi(value);
|
||||
StrAppend(&out, " ", value.ToShortString(), " => ",
|
||||
resolved_value == nullptr ? "UNKNOWN"
|
||||
: resolved_value->ToShortString(),
|
||||
"\n");
|
||||
}
|
||||
for (const HloValue* value : values()) {
|
||||
StrAppend(&out, value->ToString(/*indent=*/4));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
bool HloDataflowAnalysis::Phi(
|
||||
HloInstruction* instruction,
|
||||
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
|
||||
CHECK(ssa_form_);
|
||||
|
||||
for (const InstructionValueSet* input : inputs) {
|
||||
DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
for (auto& pair : GetInstructionValueSet(instruction)) {
|
||||
const ShapeIndex& index = pair.first;
|
||||
HloValueSet& value_set = pair.second;
|
||||
|
||||
// Positions with phi values should never have more than one value in the
|
||||
// value set.
|
||||
CHECK_LE(value_set.values().size(), 1);
|
||||
const HloValue* current_value =
|
||||
value_set.values().size() == 1 ? value_set.values()[0] : nullptr;
|
||||
|
||||
// Construct a vector of unique value IDs of the inputs.
|
||||
std::vector<HloValue::Id> input_value_ids;
|
||||
for (const InstructionValueSet* input : inputs) {
|
||||
for (const HloValue* value : input->element(index).values()) {
|
||||
input_value_ids.push_back(value->id());
|
||||
}
|
||||
}
|
||||
std::sort(input_value_ids.begin(), input_value_ids.end());
|
||||
input_value_ids.erase(
|
||||
std::unique(input_value_ids.begin(), input_value_ids.end()),
|
||||
input_value_ids.end());
|
||||
|
||||
// Remove the existing phi value (if it exists). The phi can be its own
|
||||
// input, for example, in while body parameters where the body passes
|
||||
// through the parameter value.
|
||||
bool current_value_defined_here =
|
||||
(current_value != nullptr &&
|
||||
current_value->defining_instruction() == instruction &&
|
||||
current_value->defining_index() == index);
|
||||
if (current_value_defined_here) {
|
||||
CHECK(current_value->is_phi());
|
||||
auto it = std::find(input_value_ids.begin(), input_value_ids.end(),
|
||||
current_value->id());
|
||||
if (it != input_value_ids.end()) {
|
||||
input_value_ids.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
if (input_value_ids.empty()) {
|
||||
// A value set which has at least one element should never have its value
|
||||
// set reduced to zero elements. During dataflow value sets only can go
|
||||
// from empty to non-empty, not the reverse.
|
||||
CHECK_EQ(value_set.values().size(), 0)
|
||||
<< "Instruction " << instruction->name() << " at index " << index
|
||||
<< " previously had non-empty value set. Value set: " << value_set;
|
||||
} else if (input_value_ids.size() == 1) {
|
||||
// Only a single value reaches this point. There should be no phi, and
|
||||
// this value set should contain this single value.
|
||||
const HloValue& new_value = GetValue(input_value_ids[0]);
|
||||
if (current_value == nullptr) {
|
||||
value_set.Clear();
|
||||
value_set.AddValue(&new_value);
|
||||
changed = true;
|
||||
} else if (current_value != &new_value) {
|
||||
if (current_value_defined_here) {
|
||||
// Remove the existing phi.
|
||||
DeleteHloValue(current_value->id());
|
||||
}
|
||||
value_set.Clear();
|
||||
value_set.AddValue(&new_value);
|
||||
changed = true;
|
||||
}
|
||||
} else {
|
||||
// Multiple distinct values reach this point. A phi value is
|
||||
// necessary.
|
||||
CHECK_GT(input_value_ids.size(), 1);
|
||||
if (current_value == nullptr || !current_value->is_phi()) {
|
||||
value_set.Clear();
|
||||
value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const {
|
||||
return values_.at(value_id);
|
||||
}
|
||||
@ -142,129 +234,6 @@ HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) {
|
||||
return GetValueSet(position.instruction, position.index);
|
||||
}
|
||||
|
||||
void HloDataflowAnalysis::UpdateAfterChangingOperand(
|
||||
HloInstruction* instruction, HloInstruction* old_operand,
|
||||
HloInstruction* new_operand) {
|
||||
CHECK(std::find(instruction->operands().begin(),
|
||||
instruction->operands().end(),
|
||||
new_operand) != instruction->operands().end());
|
||||
VLOG(1) << "UpdateAfterChangingOperand(" << instruction->name() << ", "
|
||||
<< old_operand->name() << " => " << new_operand->name() << ")";
|
||||
|
||||
std::vector<HloInstruction*> to_update = {instruction};
|
||||
|
||||
// If the instruction calls any computations then add the parameters of called
|
||||
// computation to capture any changes to the dataflow into the subcomputation
|
||||
// introduced by the new operand.
|
||||
for (HloComputation* computation : instruction->called_computations()) {
|
||||
to_update.insert(to_update.end(),
|
||||
computation->parameter_instructions().begin(),
|
||||
computation->parameter_instructions().end());
|
||||
}
|
||||
|
||||
UpdateInstructionsAndPropagate(to_update);
|
||||
|
||||
// The uses of the values in the old and new operand may have changed. Uses of
|
||||
// other HloValues are updated in UpdateInstructionsAndPropagate.
|
||||
for (auto& pair : GetInstructionValueSet(old_operand)) {
|
||||
for (const HloValue* value : pair.second.values()) {
|
||||
GetValue(value->id()).RecomputeUses();
|
||||
}
|
||||
}
|
||||
for (auto& pair : GetInstructionValueSet(new_operand)) {
|
||||
for (const HloValue* value : pair.second.values()) {
|
||||
GetValue(value->id()).RecomputeUses();
|
||||
}
|
||||
}
|
||||
|
||||
TF_DCHECK_OK(VerifyAgainstReference());
|
||||
}
|
||||
|
||||
void HloDataflowAnalysis::UpdateAfterChangingRoot(HloInstruction* old_root,
|
||||
HloInstruction* new_root) {
|
||||
VLOG(1) << "UpdateAfterChangingRoot(" << old_root->name() << " => "
|
||||
<< new_root->name() << ")";
|
||||
|
||||
CHECK_EQ(new_root, new_root->parent()->root_instruction());
|
||||
CHECK_EQ(new_root->parent(), old_root->parent());
|
||||
|
||||
std::vector<HloInstruction*> to_update = {old_root, new_root};
|
||||
|
||||
const CallGraphNode& call_graph_node =
|
||||
call_graph_->GetNode(new_root->parent());
|
||||
for (const CallSite& callsite : call_graph_node.caller_callsites()) {
|
||||
if (callsite.instruction()->opcode() == HloOpcode::kCall) {
|
||||
to_update.push_back(callsite.instruction());
|
||||
} else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
|
||||
// Add the while itself, and the body and condition parameters.
|
||||
to_update.push_back(callsite.instruction());
|
||||
to_update.push_back(
|
||||
callsite.instruction()->while_body()->parameter_instruction(0));
|
||||
to_update.push_back(
|
||||
callsite.instruction()->while_condition()->parameter_instruction(0));
|
||||
}
|
||||
}
|
||||
|
||||
UpdateInstructionsAndPropagate(to_update);
|
||||
|
||||
TF_DCHECK_OK(VerifyAgainstReference());
|
||||
}
|
||||
|
||||
const HloValue* HloDataflowAnalysis::ResolvePhi(const HloValue& phi) const {
|
||||
CHECK(phi.is_phi());
|
||||
|
||||
tensorflow::gtl::FlatSet<const HloValue*> visited;
|
||||
std::queue<const HloValue*> worklist;
|
||||
auto add_to_worklist = [&worklist, &visited](const HloValue* v) {
|
||||
if (visited.insert(v).second) {
|
||||
// 'v' was not previously in visited.
|
||||
worklist.push(v);
|
||||
}
|
||||
};
|
||||
add_to_worklist(&phi);
|
||||
|
||||
const HloValue* resolved_value = nullptr;
|
||||
while (!worklist.empty()) {
|
||||
const HloValue* value = worklist.front();
|
||||
worklist.pop();
|
||||
|
||||
if (!value->is_phi()) {
|
||||
if (resolved_value == nullptr) {
|
||||
resolved_value = value;
|
||||
} else if (resolved_value != value) {
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
for (const HloValue* input : phi_inputs_.at(value)) {
|
||||
add_to_worklist(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
return resolved_value;
|
||||
}
|
||||
|
||||
void HloDataflowAnalysis::UpdatePhiInputs(
|
||||
const HloInstruction* instruction,
|
||||
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
|
||||
CHECK(ssa_form_);
|
||||
for (auto& pair : GetInstructionValueSet(instruction)) {
|
||||
const ShapeIndex& index = pair.first;
|
||||
const HloValue& phi_value = GetUniqueValueAt(instruction, index);
|
||||
auto& phi_inputs = phi_inputs_.at(&phi_value);
|
||||
phi_inputs.clear();
|
||||
for (const InstructionValueSet* input : inputs) {
|
||||
for (const HloValue* value : input->element(index).values()) {
|
||||
// The number of phi inputs is typically 2, and virtually always very
|
||||
// small.
|
||||
if (std::find(phi_inputs.begin(), phi_inputs.end(), value) ==
|
||||
phi_inputs.end()) {
|
||||
phi_inputs.push_back(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
|
||||
CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
|
||||
const InstructionValueSet& operand_set =
|
||||
@ -380,8 +349,7 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
|
||||
}
|
||||
|
||||
if (ssa_form_ && called_from_while) {
|
||||
UpdatePhiInputs(parameter, inputs);
|
||||
return false;
|
||||
return Phi(parameter, inputs);
|
||||
} else {
|
||||
return GetInstructionValueSet(parameter).AssignUnionOf(inputs);
|
||||
}
|
||||
@ -439,8 +407,7 @@ bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
|
||||
&GetInstructionValueSet(xla_while->while_body()->root_instruction()),
|
||||
&GetInstructionValueSet(xla_while->operand(0))};
|
||||
if (ssa_form_) {
|
||||
UpdatePhiInputs(xla_while, inputs);
|
||||
return false;
|
||||
return Phi(xla_while, inputs);
|
||||
} else {
|
||||
return GetInstructionValueSet(xla_while).AssignUnionOf(inputs);
|
||||
}
|
||||
@ -487,38 +454,7 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
|
||||
VLOG(3) << "Worklist top: " << instruction->name();
|
||||
VLOG(3) << ToString();
|
||||
|
||||
// The updating of the instruction value set below in
|
||||
// UpdateInstructionValueSet does not update HloValue::positions(). To
|
||||
// perform the positions() update remove all positions in 'instruction' from
|
||||
// the HloValues in 'instruction's value set prior to the update, then after
|
||||
// the update add the new positions back in. There is likely a more
|
||||
// efficient way of doing this.
|
||||
for (auto& pair : GetInstructionValueSet(instruction)) {
|
||||
const ShapeIndex& index = pair.first;
|
||||
HloValueSet& value_set = pair.second;
|
||||
for (const HloValue* value : value_set.values()) {
|
||||
if (value->defining_instruction() != instruction) {
|
||||
// Use GetValue for a non-const HloValue reference.
|
||||
GetValue(value->id()).RemovePosition(instruction, index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool changed = UpdateInstructionValueSet(instruction);
|
||||
|
||||
// Add the positions back in.
|
||||
for (auto& pair : GetInstructionValueSet(instruction)) {
|
||||
const ShapeIndex& index = pair.first;
|
||||
HloValueSet& value_set = pair.second;
|
||||
for (const HloValue* value : value_set.values()) {
|
||||
if (value->defining_instruction() != instruction) {
|
||||
// Use GetValue for a non-const HloValue reference.
|
||||
GetValue(value->id()).AddPosition(instruction, index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!changed) {
|
||||
if (!UpdateInstructionValueSet(instruction)) {
|
||||
// No change to the instruction's value set.
|
||||
VLOG(4) << "No change.";
|
||||
continue;
|
||||
@ -531,12 +467,16 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
|
||||
for (HloInstruction* user : instruction->users()) {
|
||||
worklist.push(user);
|
||||
|
||||
// If user calls a computation, then the respective parameter(s) of the
|
||||
// computation need to be updated.
|
||||
// If user sequentially calls a computation, then the respective
|
||||
// parameter(s) of the computation need to be updated.
|
||||
for (HloComputation* called_computation : user->called_computations()) {
|
||||
for (int64 operand_number : user->OperandIndices(instruction)) {
|
||||
worklist.push(
|
||||
called_computation->parameter_instruction(operand_number));
|
||||
const CallGraphNode& call_graph_node =
|
||||
call_graph_->GetNode(called_computation);
|
||||
if (call_graph_node.context() == CallContext::kSequential) {
|
||||
for (int64 operand_number : user->OperandIndices(instruction)) {
|
||||
worklist.push(
|
||||
called_computation->parameter_instruction(operand_number));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -574,25 +514,10 @@ InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
|
||||
}
|
||||
|
||||
Status HloDataflowAnalysis::InitializeInstructionValueSets() {
|
||||
// Gather the values to create before creating them. This is done because we
|
||||
// want to allocate the vector of values only once so references to elements
|
||||
// are stable.
|
||||
struct ValueToCreate {
|
||||
HloInstruction* instruction;
|
||||
ShapeIndex index;
|
||||
bool is_phi;
|
||||
};
|
||||
std::vector<ValueToCreate> values_to_create;
|
||||
|
||||
for (const std::unique_ptr<HloComputation>& computation :
|
||||
module_->computations()) {
|
||||
const CallGraphNode& call_graph_node =
|
||||
call_graph_->GetNode(computation.get());
|
||||
bool called_from_while = std::any_of(
|
||||
call_graph_node.caller_callsites().begin(),
|
||||
call_graph_node.caller_callsites().end(), [](const CallSite& cs) {
|
||||
return cs.instruction()->opcode() == HloOpcode::kWhile;
|
||||
});
|
||||
|
||||
for (const std::unique_ptr<HloInstruction>& instruction :
|
||||
computation->instructions()) {
|
||||
@ -603,20 +528,22 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
|
||||
|
||||
// Lambda to set the value set to define all values in the output of the
|
||||
// instruction.
|
||||
auto define_all_values = [this, &instruction,
|
||||
&values_to_create](bool is_phi = false) {
|
||||
auto define_all_values = [this, &instruction](bool is_phi = false) {
|
||||
for (auto& pair : GetInstructionValueSet(instruction.get())) {
|
||||
const ShapeIndex& index = pair.first;
|
||||
values_to_create.push_back({instruction.get(), index, is_phi});
|
||||
HloValue* value =
|
||||
NewHloValue(instruction.get(), index, /*is_phi=*/false);
|
||||
GetValueSet(instruction.get(), index).AddValue(value);
|
||||
}
|
||||
};
|
||||
|
||||
// Lambda to set the value set to define only the top-level buffer in the
|
||||
// output of the instruction. Any other values flow from the operands of
|
||||
// the instruction (or from cross-computation dataflow).
|
||||
auto define_top_level_only = [this, &instruction, &values_to_create]() {
|
||||
values_to_create.push_back(
|
||||
{instruction.get(), /*index=*/{}, /*is_phi=*/false});
|
||||
auto define_top_level_only = [this, &instruction]() {
|
||||
HloValue* value =
|
||||
NewHloValue(instruction.get(), /*index=*/{}, /*is_phi=*/false);
|
||||
GetValueSet(instruction.get(), /*index=*/{}).AddValue(value);
|
||||
};
|
||||
|
||||
switch (instruction->opcode()) {
|
||||
@ -626,10 +553,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
|
||||
}
|
||||
break;
|
||||
case HloOpcode::kWhile:
|
||||
if (ssa_form_) {
|
||||
define_all_values(/*is_phi=*/true);
|
||||
}
|
||||
break;
|
||||
case HloOpcode::kCall:
|
||||
case HloOpcode::kGetTupleElement:
|
||||
// These instructions define no values. The values in their output
|
||||
@ -654,10 +577,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
|
||||
// values in their output. Otherwise the values of the parameter
|
||||
// come from the caller (eg, operands to the kCall instruction).
|
||||
define_all_values();
|
||||
} else if (call_graph_node.context() == CallContext::kSequential &&
|
||||
called_from_while && ssa_form_) {
|
||||
// Parameters of while bodies and conditions are phis.
|
||||
define_all_values(/*is_phi=*/true);
|
||||
}
|
||||
break;
|
||||
case HloOpcode::kCopy:
|
||||
@ -674,164 +593,9 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
|
||||
}
|
||||
}
|
||||
|
||||
// Reserve the vector ahead of time so references to elements are stable.
|
||||
values_.reserve(values_to_create.size());
|
||||
for (int64 i = 0; i < values_to_create.size(); ++i) {
|
||||
const ValueToCreate& to_create = values_to_create[i];
|
||||
values_.emplace_back(/*id=*/i, to_create.instruction, to_create.index,
|
||||
to_create.is_phi);
|
||||
const HloValue& value = values_.back();
|
||||
GetValueSet(to_create.instruction, to_create.index).AddValue(&value);
|
||||
if (value.is_phi()) {
|
||||
phi_inputs_[&value] = {};
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool HloDataflowAnalysis::IsDefinedBefore(const HloValue& a, const HloValue& b,
|
||||
const HloOrdering& ordering) const {
|
||||
// If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b'
|
||||
// is live into the module.
|
||||
if (b.defining_instruction()->parent() == module_->entry_computation() &&
|
||||
b.defining_instruction()->opcode() == HloOpcode::kParameter) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Phi values require special handling. Because XLA does not have a phi
|
||||
// instruction, the definition instruction of the phis values are
|
||||
// placeholders: either the subcomputation parameter (body or condition) or
|
||||
// the while instruction. However, the program point where these values are
|
||||
// logically defined does not necessarily coincide exactly with program point
|
||||
// of these place-holder instructions. So we explicitly define the following
|
||||
// order for phi values:
|
||||
//
|
||||
// body/condition parameter phi:
|
||||
// Defined before all values defined in its computation excepting other
|
||||
// phis.
|
||||
//
|
||||
// while phi:
|
||||
// defined after all values defined in the condition or body.
|
||||
//
|
||||
auto is_body_or_condition_phi = [](const HloValue& v) {
|
||||
return v.is_phi() &&
|
||||
v.defining_instruction()->opcode() == HloOpcode::kParameter;
|
||||
};
|
||||
if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) &&
|
||||
call_graph_->InstructionIsNestedIn(b.defining_instruction(),
|
||||
a.defining_instruction()->parent())) {
|
||||
return true;
|
||||
}
|
||||
if (is_body_or_condition_phi(b) &&
|
||||
call_graph_->InstructionIsNestedIn(a.defining_instruction(),
|
||||
b.defining_instruction()->parent())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If 'b' is a while phi and 'a' is in the body or condition, then 'a'
|
||||
// executes before 'b'.
|
||||
if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile &&
|
||||
(call_graph_->InstructionIsNestedIn(
|
||||
a.defining_instruction(), b.defining_instruction()->while_body()) ||
|
||||
call_graph_->InstructionIsNestedIn(
|
||||
a.defining_instruction(),
|
||||
b.defining_instruction()->while_condition()))) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return ordering.ExecutesBefore(a.defining_instruction(),
|
||||
b.defining_instruction());
|
||||
}
|
||||
|
||||
bool HloDataflowAnalysis::UseIsBeforeValueDefinition(
|
||||
const HloUse& use, const HloValue& value,
|
||||
const HloOrdering& ordering) const {
|
||||
if (ordering.ExecutesBefore(use.instruction, value.defining_instruction())) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// If the use is at the instruction where the value is defined, then the use
|
||||
// is before the def if the instruction allows buffer sharing (in place
|
||||
// computation).
|
||||
if (use.instruction == value.defining_instruction() &&
|
||||
CanShareOperandBufferWithUser(
|
||||
use.instruction->mutable_operand(use.operand_number),
|
||||
use.operand_index, value.defining_instruction(),
|
||||
value.defining_index())) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// The use at a while is an input to a phi, and logically occurs before values
|
||||
// are defined in the body or condition computations.
|
||||
if (use.instruction->opcode() == HloOpcode::kWhile) {
|
||||
const HloInstruction* xla_while = use.instruction;
|
||||
if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
|
||||
xla_while->while_body()) ||
|
||||
call_graph_->InstructionIsNestedIn(value.defining_instruction(),
|
||||
xla_while->while_condition())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Similarly if the value is defined at a while, it logically occurs after any
|
||||
// uses in the body or condition computations.
|
||||
if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
|
||||
CHECK(ssa_form_);
|
||||
const HloInstruction* xla_while = value.defining_instruction();
|
||||
if (call_graph_->InstructionIsNestedIn(use.instruction,
|
||||
xla_while->while_body()) ||
|
||||
call_graph_->InstructionIsNestedIn(use.instruction,
|
||||
xla_while->while_condition())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool HloDataflowAnalysis::LiveRangeStrictlyBefore(
|
||||
const HloValue& a, const HloValue& b, const HloOrdering& ordering) const {
|
||||
VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
|
||||
<< ", b = " << b.ToShortString() << ")";
|
||||
if (!IsDefinedBefore(a, b, ordering)) {
|
||||
VLOG(4) << "a not defined before b";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Live-out values from the module can never have ranges strictly before any
|
||||
// other value.
|
||||
if (a.live_out_of_module()) {
|
||||
VLOG(4) << "a is live out of module";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Live-out values of computations can never have ranges strictly before any
|
||||
// other value in the computation (including values nested in
|
||||
// subcomputations).
|
||||
if (a.live_out_of_computation() &&
|
||||
call_graph_->InstructionIsNestedIn(b.defining_instruction(),
|
||||
a.defining_instruction()->parent())) {
|
||||
VLOG(4) << "a is live out of computation containing b";
|
||||
return false;
|
||||
}
|
||||
|
||||
// All uses of 'a' must be before 'b' is defined.
|
||||
for (const HloUse& use : a.uses()) {
|
||||
if (!UseIsBeforeValueDefinition(use, b, ordering)) {
|
||||
VLOG(4) << "use of a (" << use << ") not before b is defined";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HloDataflowAnalysis::MayInterfere(const HloValue& a, const HloValue& b,
|
||||
const HloOrdering& ordering) const {
|
||||
// Buffers without disjoint liveness may interfere.
|
||||
return !LiveRangeStrictlyBefore(a, b, ordering) &&
|
||||
!LiveRangeStrictlyBefore(b, a, ordering);
|
||||
}
|
||||
|
||||
/* static */
|
||||
StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
|
||||
HloModule* module, bool ssa_form, bool bitcast_defines_value) {
|
||||
@ -855,6 +619,33 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
|
||||
}
|
||||
dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions);
|
||||
|
||||
// Add in positions to all values.
|
||||
for (const std::unique_ptr<HloComputation>& computation :
|
||||
module->computations()) {
|
||||
for (const std::unique_ptr<HloInstruction>& instruction :
|
||||
computation->instructions()) {
|
||||
for (const auto& pair :
|
||||
dataflow_analysis->GetInstructionValueSet(instruction.get())) {
|
||||
const ShapeIndex& index = pair.first;
|
||||
const HloValueSet& value_set = pair.second;
|
||||
for (const HloValue* value : value_set.values()) {
|
||||
if (value->defining_instruction() != instruction.get()) {
|
||||
dataflow_analysis->GetValue(value->id())
|
||||
.AddPosition(instruction.get(), index);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Construct vector of values.
|
||||
dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size());
|
||||
for (auto& pair : dataflow_analysis->values_) {
|
||||
dataflow_analysis->values_vector_.push_back(&pair.second);
|
||||
}
|
||||
std::sort(dataflow_analysis->values_vector_.begin(),
|
||||
dataflow_analysis->values_vector_.end(), HloValue::IdLessThan);
|
||||
|
||||
TF_DCHECK_OK(dataflow_analysis->Verify());
|
||||
|
||||
XLA_VLOG_LINES(1, dataflow_analysis->ToString());
|
||||
@ -865,14 +656,14 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
|
||||
Status HloDataflowAnalysis::Verify() const {
|
||||
// Verify each HloValue appears in the value sets that the value's positions()
|
||||
// indicate.
|
||||
for (const HloValue& value : values()) {
|
||||
for (const HloPosition& position : value.positions()) {
|
||||
for (const HloValue* value : values()) {
|
||||
for (const HloPosition& position : value->positions()) {
|
||||
const HloValueSet& value_set = GetValueSet(position);
|
||||
TF_RET_CHECK(std::find(value_set.values().begin(),
|
||||
value_set.values().end(),
|
||||
&value) != value_set.values().end())
|
||||
value) != value_set.values().end())
|
||||
<< "Value set at position " << position << " does not contain value "
|
||||
<< value.ToShortString();
|
||||
<< value->ToShortString();
|
||||
}
|
||||
}
|
||||
|
||||
@ -898,75 +689,4 @@ Status HloDataflowAnalysis::Verify() const {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloDataflowAnalysis::VerifyAgainstReference() const {
|
||||
TF_RETURN_IF_ERROR(Verify());
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> reference,
|
||||
Run(module_, ssa_form_, bitcast_defines_value_));
|
||||
TF_RETURN_IF_ERROR(reference->Verify());
|
||||
|
||||
VLOG(2) << "This analysis:";
|
||||
XLA_VLOG_LINES(2, ToString());
|
||||
VLOG(2) << "Reference:";
|
||||
XLA_VLOG_LINES(2, reference->ToString());
|
||||
|
||||
// Verify value sets in each position are identical.
|
||||
for (const auto& computation : module_->computations()) {
|
||||
for (const auto& instruction : computation->instructions()) {
|
||||
for (const auto& pair : GetInstructionValueSet(instruction.get())) {
|
||||
const ShapeIndex& index = pair.first;
|
||||
const HloValueSet& value_set = pair.second;
|
||||
const HloValueSet& reference_value_set =
|
||||
reference->GetValueSet(instruction.get(), index);
|
||||
|
||||
auto value_in_set = [](const HloValue& v, const HloValueSet& vset) {
|
||||
return std::find_if(vset.values().begin(), vset.values().end(),
|
||||
[&v](const HloValue* w) { return *w == v; }) !=
|
||||
vset.values().end();
|
||||
};
|
||||
|
||||
for (const HloValue* value : value_set.values()) {
|
||||
TF_RET_CHECK(value_in_set(*value, reference_value_set))
|
||||
<< "Value " << value->ToShortString()
|
||||
<< " does not exist in reference";
|
||||
}
|
||||
for (const HloValue* reference_value : reference_value_set.values()) {
|
||||
TF_RET_CHECK(value_in_set(*reference_value, value_set))
|
||||
<< "Value " << reference_value->ToShortString()
|
||||
<< " only exists in reference";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify all phis resolve identically and uses are identical.
|
||||
for (const HloValue& value : values()) {
|
||||
const HloValue& reference_value = reference->GetValueDefinedAt(
|
||||
value.defining_instruction(), value.defining_index());
|
||||
TF_RET_CHECK(value.is_phi() == reference_value.is_phi());
|
||||
if (value.is_phi()) {
|
||||
const HloValue* resolved_value = ResolvePhi(value);
|
||||
const HloValue* reference_resolved_value =
|
||||
reference->ResolvePhi(reference_value);
|
||||
if (resolved_value == nullptr) {
|
||||
TF_RET_CHECK(reference_resolved_value == nullptr);
|
||||
} else {
|
||||
TF_RET_CHECK(reference_resolved_value != nullptr);
|
||||
TF_RET_CHECK(*reference_resolved_value == *resolved_value);
|
||||
}
|
||||
}
|
||||
|
||||
for (const HloUse& use : value.uses()) {
|
||||
TF_RET_CHECK(std::find(reference_value.uses().begin(),
|
||||
reference_value.uses().end(),
|
||||
use) != reference_value.uses().end());
|
||||
}
|
||||
for (const HloUse& reference_use : reference_value.uses()) {
|
||||
TF_RET_CHECK(std::find(value.uses().begin(), value.uses().end(),
|
||||
reference_use) != value.uses().end());
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -88,10 +88,10 @@ class HloDataflowAnalysis {
|
||||
// given position.
|
||||
const HloValueSet& GetValueSet(const HloInstruction* instruction,
|
||||
const ShapeIndex& index = {}) const;
|
||||
HloValueSet& GetValueSet(const HloInstruction* instruction,
|
||||
const ShapeIndex& index = {});
|
||||
const HloValueSet& GetValueSet(const HloPosition& position) const;
|
||||
HloValueSet& GetValueSet(const HloPosition& position);
|
||||
HloValueSet& GetValueSet(const HloInstruction* instruction,
|
||||
const ShapeIndex& index = {});
|
||||
|
||||
// Return the unique value in the HloValueSet at the given instruction and
|
||||
// shape index. CHECKs if the value set does not contain a exactly one value.
|
||||
@ -108,49 +108,11 @@ class HloDataflowAnalysis {
|
||||
const HloValue& GetValue(HloValue::Id value_id) const;
|
||||
HloValue& GetValue(HloValue::Id value_id);
|
||||
|
||||
// Returns whether the given values interfere assuming the given HLO
|
||||
// ordering. Two values interfere if they may both be simultaneously live.
|
||||
bool MayInterfere(const HloValue& a, const HloValue& b,
|
||||
const HloOrdering& ordering) const;
|
||||
|
||||
// Overload which takes HloValue:Ids.
|
||||
bool MayInterfere(HloValue::Id a, HloValue::Id b,
|
||||
const HloOrdering& ordering) const {
|
||||
return MayInterfere(GetValue(a), GetValue(b), ordering);
|
||||
}
|
||||
|
||||
// Return the total number of HloValues.
|
||||
int64 value_count() const { return values_.size(); }
|
||||
|
||||
// Return a vector of all HloValues.
|
||||
const std::vector<HloValue>& values() const { return values_; }
|
||||
|
||||
// Updates the dataflow after the changing an operand of
|
||||
// 'instruction'. Dataflow update is not possible if instructions have been
|
||||
// added or removed from the graph.
|
||||
void UpdateAfterChangingOperand(HloInstruction* instruction,
|
||||
HloInstruction* old_operand,
|
||||
HloInstruction* new_operand);
|
||||
|
||||
// Updates the dataflow after the changing the root of a computation from
|
||||
// 'old_root' to 'new_root'.
|
||||
void UpdateAfterChangingRoot(HloInstruction* old_root,
|
||||
HloInstruction* new_root);
|
||||
|
||||
// Returns the non-phi HloValue that is the unique (transitive) input to the
|
||||
// given phi. If no such HloValue exists (there are multiple inputs to the
|
||||
// phi) then nullptr is returned. This is computed by all walking the inputs
|
||||
// of the given phi value until non-phi HloValue(s) are encountered.
|
||||
const HloValue* ResolvePhi(const HloValue& phi) const;
|
||||
const HloValue* ResolvePhi(const HloInstruction* instruction,
|
||||
const ShapeIndex& index = {}) const {
|
||||
return ResolvePhi(GetValueDefinedAt(instruction, index));
|
||||
}
|
||||
|
||||
// Compare the dataflow analysis against a clean recomputation of the
|
||||
// analysis. Returns an error status if there is a mismatch. Useful for
|
||||
// verifying the correctness after updates to the analysis.
|
||||
Status VerifyAgainstReference() const;
|
||||
// Return a vector of all HloValues stabily sorted by HloValue::Id.
|
||||
const std::vector<const HloValue*>& values() const { return values_vector_; }
|
||||
|
||||
// Return the call graph used for computing the dataflow.
|
||||
const CallGraph& call_graph() const { return *call_graph_; }
|
||||
@ -161,6 +123,13 @@ class HloDataflowAnalysis {
|
||||
HloDataflowAnalysis(HloModule* module, bool ssa_form,
|
||||
bool bitcast_defines_value = false);
|
||||
|
||||
// Returns a new HloValue defined at the given instruction and shape index.
|
||||
HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
|
||||
bool is_phi = false);
|
||||
|
||||
// Delete the HloValue with the given ID.
|
||||
void DeleteHloValue(HloValue::Id value_id);
|
||||
|
||||
// Constructs and initializes the InstructionValueSets of all instructions to
|
||||
// contain exactly the HloValues defined by each instruction. These values can
|
||||
// then propagated throughout the HLO graph by calling
|
||||
@ -187,10 +156,11 @@ class HloDataflowAnalysis {
|
||||
void UpdateInstructionsAndPropagate(
|
||||
tensorflow::gtl::ArraySlice<HloInstruction*> instructions);
|
||||
|
||||
// Sets the inputs of the given phi to given value(s).
|
||||
void UpdatePhiInputs(
|
||||
const HloInstruction* instruction,
|
||||
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);
|
||||
// Return the result of the SSA Phi function applied to the given inputs at
|
||||
// the given instruction. If skip_top_level is true, then the top level of the
|
||||
// value set of 'instruction' is not modified.
|
||||
bool Phi(HloInstruction* instruction,
|
||||
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);
|
||||
|
||||
// Updates the positions of the HloValues in the output of the given
|
||||
// instruction. This should be called after the instruction value set of
|
||||
@ -203,20 +173,6 @@ class HloDataflowAnalysis {
|
||||
HloInstruction* instruction, const InstructionValueSet& new_value_set,
|
||||
const InstructionValueSet* prev_value_set = nullptr);
|
||||
|
||||
// Returns true if the live range of the given value 'a' is strictly before
|
||||
// the live range of value 'b' using the given HLO ordering.
|
||||
bool LiveRangeStrictlyBefore(const HloValue& a, const HloValue& b,
|
||||
const HloOrdering& ordering) const;
|
||||
|
||||
// Returns whether the value 'a' is defined before the value 'b' under the
|
||||
// given ordering.
|
||||
bool IsDefinedBefore(const HloValue& a, const HloValue& b,
|
||||
const HloOrdering& ordering) const;
|
||||
|
||||
// Returns whether the given use is before the given value definition.
|
||||
bool UseIsBeforeValueDefinition(const HloUse& use, const HloValue& value,
|
||||
const HloOrdering& ordering) const;
|
||||
|
||||
// Verify various invariants of the dataflow analysis.
|
||||
Status Verify() const;
|
||||
|
||||
@ -226,19 +182,19 @@ class HloDataflowAnalysis {
|
||||
|
||||
std::unique_ptr<CallGraph> call_graph_;
|
||||
|
||||
// Array of all values in the module. This is allocated once at analysis
|
||||
// construction time so HloValue references are stable. Updates to the
|
||||
// analysis via UpdateAfterChangingOperand and UpdateAfterChangingRoot do not
|
||||
// result in the creation or destruction of any HloValues.
|
||||
std::vector<HloValue> values_;
|
||||
|
||||
// Map hold the inputs to each phi value in the module. Used by ResolvePhi.
|
||||
tensorflow::gtl::FlatMap<const HloValue*,
|
||||
tensorflow::gtl::InlinedVector<const HloValue*, 2>>
|
||||
phi_inputs_;
|
||||
// The map of all HloValues in the module. We pass around pointers to the
|
||||
// mapped HloValues, so the underlying container must keep them valid despite
|
||||
// mutations touching other map entries.
|
||||
std::unordered_map<HloValue::Id, HloValue> values_;
|
||||
|
||||
// A map from instruction to InstructionValueSet.
|
||||
std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_;
|
||||
|
||||
// A vector containing all HloValues sorted by HloValue::Id.
|
||||
std::vector<const HloValue*> values_vector_;
|
||||
|
||||
// The Id to use for the next HloValue.
|
||||
HloValue::Id next_value_id_ = 0;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -26,7 +26,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
@ -44,8 +43,8 @@ class HloDataflowAnalysisTest : public HloTestBase,
|
||||
|
||||
// Run dataflow analysis on the member module. For convenience returns a
|
||||
// reference to the generated analysis stored in analysis_.
|
||||
HloDataflowAnalysis& RunAnalysis(bool ssa_form,
|
||||
bool bitcast_defines_value = false) {
|
||||
const HloDataflowAnalysis& RunAnalysis(bool ssa_form,
|
||||
bool bitcast_defines_value = false) {
|
||||
analysis_ =
|
||||
HloDataflowAnalysis::Run(module_.get(), ssa_form, bitcast_defines_value)
|
||||
.ConsumeValueOrDie();
|
||||
@ -71,8 +70,8 @@ class HloDataflowAnalysisTest : public HloTestBase,
|
||||
const HloInstruction* b) {
|
||||
EXPECT_FALSE(ShapeUtil::IsTuple(a->shape()));
|
||||
EXPECT_FALSE(ShapeUtil::IsTuple(b->shape()));
|
||||
return analysis_->MayInterfere(analysis_->GetValueDefinedAt(a),
|
||||
analysis_->GetValueDefinedAt(b), ordering);
|
||||
return ordering.MayInterfere(analysis_->GetValueDefinedAt(a),
|
||||
analysis_->GetValueDefinedAt(b));
|
||||
}
|
||||
|
||||
std::unique_ptr<HloModule> module_;
|
||||
@ -499,37 +498,26 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) {
|
||||
EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module());
|
||||
|
||||
if (ssa_form) {
|
||||
// While instruction should define phi values. The value at index {0} is a
|
||||
// degenerate phi with a single input 'constant1'.
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi());
|
||||
EXPECT_EQ(analysis.ResolvePhi(xla_while, /*index=*/{0}),
|
||||
&analysis.GetValueDefinedAt(constant1));
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{0}).is_phi());
|
||||
EXPECT_EQ(analysis.ResolvePhi(body_param, /*index=*/{0}),
|
||||
&analysis.GetValueDefinedAt(constant1));
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{0}).is_phi());
|
||||
EXPECT_EQ(analysis.ResolvePhi(cond_param, /*index=*/{0}),
|
||||
&analysis.GetValueDefinedAt(constant1));
|
||||
// Element 0 of the tuple passed through the body so no phi value is
|
||||
// defined.
|
||||
EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
|
||||
EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
|
||||
EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
|
||||
|
||||
// Element 1 of the tuple should be a phi value.
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());
|
||||
EXPECT_EQ(analysis.ResolvePhi(xla_while, /*index=*/{1}), nullptr);
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi());
|
||||
EXPECT_EQ(analysis.ResolvePhi(body_param, /*index=*/{1}), nullptr);
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi());
|
||||
EXPECT_EQ(analysis.ResolvePhi(cond_param, /*index=*/{1}), nullptr);
|
||||
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
|
||||
UnorderedElementsAre(HloUse{xla_while, 0, {0}}));
|
||||
EXPECT_THAT(
|
||||
analysis.GetValueDefinedAt(constant1).uses(),
|
||||
UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{xla_while, 0, {0}}));
|
||||
|
||||
EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0})
|
||||
.live_out_of_module());
|
||||
// Constant1 passes through the body and out of the module.
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
|
||||
.live_out_of_module());
|
||||
|
||||
@ -613,20 +601,15 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) {
|
||||
bool ssa_form = GetParam();
|
||||
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
|
||||
|
||||
if (ssa_form) {
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while2).live_out_of_module());
|
||||
EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
|
||||
} else {
|
||||
// Element 0 is passed through all the while instructions and out of the
|
||||
// module.
|
||||
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while0, /*index=*/{0}),
|
||||
analysis.GetValueDefinedAt(constant1));
|
||||
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while1, /*index=*/{0}),
|
||||
analysis.GetValueDefinedAt(constant1));
|
||||
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while2, /*index=*/{0}),
|
||||
analysis.GetValueDefinedAt(constant1));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
|
||||
}
|
||||
// Element 0 is passed through all the while instructions and out of the
|
||||
// module..
|
||||
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while0, /*index=*/{0}),
|
||||
analysis.GetValueDefinedAt(constant1));
|
||||
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while1, /*index=*/{0}),
|
||||
analysis.GetValueDefinedAt(constant1));
|
||||
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while2, /*index=*/{0}),
|
||||
analysis.GetValueDefinedAt(constant1));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
|
||||
}
|
||||
|
||||
TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
|
||||
@ -705,13 +688,18 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
|
||||
bool ssa_form = GetParam();
|
||||
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
|
||||
|
||||
EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
|
||||
if (ssa_form) {
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_param, /*index=*/{1}));
|
||||
EXPECT_TRUE(
|
||||
analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi());
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0}));
|
||||
EXPECT_TRUE(
|
||||
analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi());
|
||||
|
||||
// Element 0 of the nested while is %negate.
|
||||
EXPECT_FALSE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0}));
|
||||
EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
|
||||
// Element 1 is a phi value (join of %add and %constant2).
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{1}));
|
||||
EXPECT_TRUE(
|
||||
analysis.GetValueDefinedAt(nested_while, /*index=*/{1}).is_phi());
|
||||
@ -724,8 +712,6 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
|
||||
EXPECT_TRUE(
|
||||
analysis.GetValueDefinedAt(entry_while, /*index=*/{1}).is_phi());
|
||||
} else {
|
||||
EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
|
||||
EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{1}),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(add),
|
||||
analysis.GetValueDefinedAt(constant2)));
|
||||
@ -1496,256 +1482,6 @@ TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) {
|
||||
EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log));
|
||||
}
|
||||
|
||||
TEST_P(HloDataflowAnalysisTest, UpdateAnalysisForWhile) {
|
||||
// Test updating dataflow after modifying a module with an array shaped while:
|
||||
//
|
||||
// body(F32[] %param):
|
||||
// %negate = Negate(%param)
|
||||
//
|
||||
// condition(F32[] %param):
|
||||
// return Constant(false)
|
||||
//
|
||||
// entry:
|
||||
// %constant = Constant(1.0)
|
||||
// %exp = Exp(%constant)
|
||||
// return While(%exp, body, condition)
|
||||
//
|
||||
auto body_builder = HloComputation::Builder("body");
|
||||
auto body_param = body_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
|
||||
auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
scalar_shape_, HloOpcode::kNegate, body_param));
|
||||
HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
|
||||
|
||||
// Condition computation trivially returns a constant "false".
|
||||
auto cond_builder = HloComputation::Builder("condition");
|
||||
auto cond_param = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
|
||||
cond_builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloComputation* condition =
|
||||
module_->AddEmbeddedComputation(cond_builder.Build());
|
||||
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
auto exp = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
|
||||
auto xla_while = builder.AddInstruction(
|
||||
HloInstruction::CreateWhile(scalar_shape_, condition, body, exp));
|
||||
module_->AddEntryComputation(builder.Build());
|
||||
|
||||
bool ssa_form = GetParam();
|
||||
HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
|
||||
|
||||
// Sanity check the initial dataflow analysis before transforming the HLO
|
||||
// graph.
|
||||
if (ssa_form) {
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(body_param).is_phi());
|
||||
EXPECT_EQ(analysis.ResolvePhi(body_param), nullptr);
|
||||
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param).is_phi());
|
||||
EXPECT_EQ(analysis.ResolvePhi(cond_param), nullptr);
|
||||
|
||||
EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module());
|
||||
EXPECT_FALSE(analysis.GetValueDefinedAt(negate).live_out_of_module());
|
||||
} else {
|
||||
EXPECT_THAT(HloValuesAt(body_param),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(exp),
|
||||
analysis.GetValueDefinedAt(negate)));
|
||||
EXPECT_THAT(HloValuesAt(cond_param),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(exp),
|
||||
analysis.GetValueDefinedAt(negate)));
|
||||
EXPECT_THAT(HloValuesAt(xla_while),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(exp),
|
||||
analysis.GetValueDefinedAt(negate)));
|
||||
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(negate).live_out_of_module());
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(exp).live_out_of_module());
|
||||
}
|
||||
|
||||
// Set the body root to the body_param. Previously it was Negate(body_param).
|
||||
body->set_root_instruction(body_param);
|
||||
|
||||
// Prior to updating, verify that the dataflow analysis is no longer valid.
|
||||
Status verify_status = analysis.VerifyAgainstReference();
|
||||
EXPECT_FALSE(verify_status.ok());
|
||||
|
||||
analysis.UpdateAfterChangingRoot(/*old_root=*/negate,
|
||||
/*new_root=*/body_param);
|
||||
|
||||
// Analysis should be valid after the update.
|
||||
TF_EXPECT_OK(analysis.VerifyAgainstReference());
|
||||
|
||||
if (ssa_form) {
|
||||
// The phis should now be resolvable as 'exp' is passed through the body
|
||||
// transparently.
|
||||
EXPECT_EQ(analysis.ResolvePhi(body_param),
|
||||
&analysis.GetValueDefinedAt(exp));
|
||||
EXPECT_EQ(analysis.ResolvePhi(cond_param),
|
||||
&analysis.GetValueDefinedAt(exp));
|
||||
EXPECT_EQ(analysis.ResolvePhi(xla_while), &analysis.GetValueDefinedAt(exp));
|
||||
EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module());
|
||||
} else {
|
||||
EXPECT_THAT(HloValuesAt(body_param),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(exp)));
|
||||
EXPECT_THAT(HloValuesAt(cond_param),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(exp)));
|
||||
EXPECT_THAT(HloValuesAt(xla_while),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(exp)));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(exp).live_out_of_module());
|
||||
}
|
||||
EXPECT_FALSE(analysis.GetValueDefinedAt(negate).live_out_of_module());
|
||||
|
||||
// Now replace the operand of the while with %constant (was %exp).
|
||||
TF_ASSERT_OK(exp->ReplaceUseWith(xla_while, constant));
|
||||
analysis.UpdateAfterChangingOperand(xla_while, /*old_operand=*/exp,
|
||||
/*new_operand=*/constant);
|
||||
|
||||
// Verify that the dataflow is correct.
|
||||
TF_ASSERT_OK(analysis.VerifyAgainstReference());
|
||||
|
||||
if (ssa_form) {
|
||||
// The phis now resolve to 'constant'.
|
||||
EXPECT_EQ(analysis.ResolvePhi(body_param),
|
||||
&analysis.GetValueDefinedAt(constant));
|
||||
EXPECT_EQ(analysis.ResolvePhi(cond_param),
|
||||
&analysis.GetValueDefinedAt(constant));
|
||||
EXPECT_EQ(analysis.ResolvePhi(xla_while),
|
||||
&analysis.GetValueDefinedAt(constant));
|
||||
} else {
|
||||
EXPECT_THAT(HloValuesAt(body_param),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(constant)));
|
||||
EXPECT_THAT(HloValuesAt(cond_param),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(constant)));
|
||||
EXPECT_THAT(HloValuesAt(xla_while),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(constant)));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(constant).live_out_of_module());
|
||||
}
|
||||
|
||||
// And finally make the negate the root of the body again.
|
||||
body->set_root_instruction(negate);
|
||||
analysis.UpdateAfterChangingRoot(/*old_root=*/body_param,
|
||||
/*new_root=*/negate);
|
||||
|
||||
// Verify that the dataflow is correct.
|
||||
TF_ASSERT_OK(analysis.VerifyAgainstReference());
|
||||
|
||||
if (ssa_form) {
|
||||
// Phis should no longer be resolvable.
|
||||
EXPECT_EQ(analysis.ResolvePhi(body_param), nullptr);
|
||||
EXPECT_EQ(analysis.ResolvePhi(cond_param), nullptr);
|
||||
EXPECT_EQ(analysis.ResolvePhi(xla_while), nullptr);
|
||||
} else {
|
||||
EXPECT_THAT(HloValuesAt(body_param),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(constant),
|
||||
analysis.GetValueDefinedAt(negate)));
|
||||
EXPECT_THAT(HloValuesAt(cond_param),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(constant),
|
||||
analysis.GetValueDefinedAt(negate)));
|
||||
EXPECT_THAT(HloValuesAt(xla_while),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(constant),
|
||||
analysis.GetValueDefinedAt(negate)));
|
||||
|
||||
EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module());
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(negate).live_out_of_module());
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(constant).live_out_of_module());
|
||||
}
|
||||
|
||||
// After the updates, verify that the dataflow is correct.
|
||||
TF_ASSERT_OK(analysis.VerifyAgainstReference());
|
||||
}
|
||||
|
||||
TEST_P(HloDataflowAnalysisTest, UpdateOfATupleSelect) {
|
||||
// Test changing the operands of kSelects of a tuple value and updating the
|
||||
// dataflow.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto pred = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
auto a = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
auto b = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
|
||||
auto c = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
|
||||
auto d = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(4.0)));
|
||||
auto tuple_a = builder.AddInstruction(HloInstruction::CreateTuple({a}));
|
||||
auto tuple_b = builder.AddInstruction(HloInstruction::CreateTuple({b}));
|
||||
auto tuple_c = builder.AddInstruction(HloInstruction::CreateTuple({c}));
|
||||
auto tuple_d = builder.AddInstruction(HloInstruction::CreateTuple({d}));
|
||||
const Shape tuple_shape = tuple_a->shape();
|
||||
auto select_aa = builder.AddInstruction(HloInstruction::CreateTernary(
|
||||
tuple_shape, HloOpcode::kSelect, pred, tuple_a, tuple_a));
|
||||
auto select_ab = builder.AddInstruction(HloInstruction::CreateTernary(
|
||||
tuple_shape, HloOpcode::kSelect, pred, tuple_a, tuple_b));
|
||||
auto select_cd = builder.AddInstruction(HloInstruction::CreateTernary(
|
||||
tuple_shape, HloOpcode::kSelect, pred, tuple_c, tuple_d));
|
||||
auto select_abcd = builder.AddInstruction(HloInstruction::CreateTernary(
|
||||
tuple_shape, HloOpcode::kSelect, pred, select_ab, select_cd));
|
||||
|
||||
module_->AddEntryComputation(builder.Build());
|
||||
|
||||
bool ssa_form = GetParam();
|
||||
HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
|
||||
|
||||
// Sanity check dataflow before changing the graph and updating.
|
||||
EXPECT_THAT(HloValuesAt(select_aa, /*index=*/{0}),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(a)));
|
||||
EXPECT_THAT(HloValuesAt(select_ab, /*index=*/{0}),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(a),
|
||||
analysis.GetValueDefinedAt(b)));
|
||||
EXPECT_THAT(HloValuesAt(select_cd, /*index=*/{0}),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(c),
|
||||
analysis.GetValueDefinedAt(d)));
|
||||
EXPECT_THAT(HloValuesAt(select_abcd, /*index=*/{0}),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(a),
|
||||
analysis.GetValueDefinedAt(b),
|
||||
analysis.GetValueDefinedAt(c),
|
||||
analysis.GetValueDefinedAt(d)));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(a).live_out_of_module());
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(b).live_out_of_module());
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(c).live_out_of_module());
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(d).live_out_of_module());
|
||||
|
||||
// Set the rhs of 'select_aa' to be 'd'.
|
||||
TF_ASSERT_OK(select_aa->ReplaceOperandWith(2, tuple_d));
|
||||
analysis.UpdateAfterChangingOperand(select_aa, /*old_operand=*/tuple_a,
|
||||
/*new_operand=*/tuple_d);
|
||||
|
||||
// Verify that the dataflow is correct.
|
||||
TF_ASSERT_OK(analysis.VerifyAgainstReference());
|
||||
|
||||
EXPECT_THAT(HloValuesAt(select_aa, /*index=*/{0}),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(a),
|
||||
analysis.GetValueDefinedAt(d)));
|
||||
|
||||
// Set the lhs of 'select_cd' to be 'a'.
|
||||
TF_ASSERT_OK(select_cd->ReplaceOperandWith(1, tuple_a));
|
||||
analysis.UpdateAfterChangingOperand(select_cd, /*old_operand=*/tuple_c,
|
||||
/*new_operand=*/tuple_a);
|
||||
|
||||
// Verify that the dataflow is correct.
|
||||
TF_ASSERT_OK(analysis.VerifyAgainstReference());
|
||||
|
||||
EXPECT_THAT(HloValuesAt(select_cd, /*index=*/{0}),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(a),
|
||||
analysis.GetValueDefinedAt(d)));
|
||||
EXPECT_THAT(HloValuesAt(select_abcd, /*index=*/{0}),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(a),
|
||||
analysis.GetValueDefinedAt(b),
|
||||
analysis.GetValueDefinedAt(d)));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(a).live_out_of_module());
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(b).live_out_of_module());
|
||||
EXPECT_FALSE(analysis.GetValueDefinedAt(c).live_out_of_module());
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(d).live_out_of_module());
|
||||
|
||||
// After the updates, verify that the dataflow is correct.
|
||||
TF_ASSERT_OK(analysis.VerifyAgainstReference());
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation,
|
||||
HloDataflowAnalysisTest,
|
||||
::testing::Values(false, true));
|
||||
|
@ -561,13 +561,21 @@ tooltip = " ";
|
||||
}
|
||||
|
||||
string comp_body = DumpComputation(subcomp);
|
||||
string computation =
|
||||
Printf(computation_fmt, id, style, subcomp_label, comp_body, id);
|
||||
|
||||
// Add an edge from the subcomputation to its parent node. If subcomp
|
||||
// belongs to a fusion node, it's drawn in place of the fusion instruction, so
|
||||
// there's no need to link those.
|
||||
if (parent_instr->opcode() != HloOpcode::kFusion) {
|
||||
if (parent_instr->opcode() == HloOpcode::kFusion) {
|
||||
// Dump any nested fusion nodes.
|
||||
for (const auto& subcomp_instr : subcomp->instructions()) {
|
||||
if (subcomp_instr->opcode() == HloOpcode::kFusion) {
|
||||
StrAppend(
|
||||
&comp_body,
|
||||
DumpSubcomputation(subcomp_instr->fused_instructions_computation(),
|
||||
subcomp_instr.get()));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Add an edge from the subcomputation to its parent node. If subcomp
|
||||
// belongs to a fusion node, it's drawn in place of the fusion instruction,
|
||||
// so there's no need to link those.
|
||||
edge_ids_.insert(
|
||||
{{subcomp->root_instruction(), parent_instr}, next_edge_id_++});
|
||||
const char* edge_fmt =
|
||||
@ -578,6 +586,9 @@ tooltip = " ";
|
||||
subcomp->name(), parent_instr->name()));
|
||||
}
|
||||
|
||||
string computation =
|
||||
Printf(computation_fmt, id, style, subcomp_label, comp_body, id);
|
||||
|
||||
return computation;
|
||||
}
|
||||
|
||||
|
@ -793,13 +793,6 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
|
||||
}
|
||||
}
|
||||
|
||||
for (HloComputation* computation :
|
||||
instruction_to_fuse->called_computations()) {
|
||||
if (std::find(called_computations_.begin(), called_computations_.end(),
|
||||
computation) == called_computations_.end()) {
|
||||
called_computations_.push_back(computation);
|
||||
}
|
||||
}
|
||||
VLOG(2) << "New clone:\n" << clone->ToString();
|
||||
return clone;
|
||||
}
|
||||
|
@ -797,8 +797,7 @@ class HloInstruction {
|
||||
const Shape& shape,
|
||||
tensorflow::gtl::ArraySlice<HloInstruction*> operands);
|
||||
|
||||
// Returns the computations this instruction calls (if any). This includes
|
||||
// computations called by fused instructions inside of a fusion instruction.
|
||||
// Returns the computations this instruction directly calls (if any).
|
||||
const std::vector<HloComputation*>& called_computations() const {
|
||||
return called_computations_;
|
||||
}
|
||||
|
@ -758,16 +758,13 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
|
||||
auto* fusion = computation->CreateFusionInstruction(
|
||||
{map_3_y}, HloInstruction::FusionKind::kLoop);
|
||||
auto* fused_computation = fusion->fused_instructions_computation();
|
||||
EXPECT_THAT(fusion->called_computations(),
|
||||
ElementsAre(fused_computation, computation_y));
|
||||
EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));
|
||||
|
||||
fusion->FuseInstruction(map_2_x);
|
||||
EXPECT_THAT(fusion->called_computations(),
|
||||
ElementsAre(fused_computation, computation_y, computation_x));
|
||||
EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));
|
||||
|
||||
fusion->FuseInstruction(map_1_x);
|
||||
EXPECT_THAT(fusion->called_computations(),
|
||||
ElementsAre(fused_computation, computation_y, computation_x));
|
||||
EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));
|
||||
}
|
||||
|
||||
TEST_F(HloInstructionTest, ComplexFusionOp) {
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
|
||||
@ -218,6 +219,94 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) {
|
||||
EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param));
|
||||
}
|
||||
|
||||
TEST_F(HloOrderingTest, ValuesInWhileComputations) {
|
||||
// Tests the ordering of values (defined by dataflow analysis) in the body and
|
||||
// condition of a while instruction. HLO code:
|
||||
//
|
||||
// body(F32[]) %param):
|
||||
// %negate = Negate(%param)
|
||||
//
|
||||
// condition(F32[] %param):
|
||||
// %convert = Convert<PRED>(%param)
|
||||
//
|
||||
// entry:
|
||||
// %constant = Constant(1.0)
|
||||
// %while = While(%constant, body, condition)
|
||||
// %add = Add(%constant, %while)
|
||||
//
|
||||
auto module = CreateNewModule();
|
||||
const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
|
||||
|
||||
auto body_builder = HloComputation::Builder("body");
|
||||
auto body_param = body_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, scalar_shape, "body_param"));
|
||||
auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
scalar_shape, HloOpcode::kNegate, body_param));
|
||||
HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
|
||||
|
||||
auto cond_builder = HloComputation::Builder("condition");
|
||||
auto cond_param = cond_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, scalar_shape, "cond_param"));
|
||||
auto convert = cond_builder.AddInstruction(HloInstruction::CreateConvert(
|
||||
ShapeUtil::MakeShape(xla::PRED, {}), cond_param));
|
||||
HloComputation* condition =
|
||||
module->AddEmbeddedComputation(cond_builder.Build());
|
||||
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
|
||||
auto xla_while = builder.AddInstruction(
|
||||
HloInstruction::CreateWhile(scalar_shape, condition, body, constant));
|
||||
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
scalar_shape, HloOpcode::kAdd, constant, xla_while));
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto dataflow, HloDataflowAnalysis::Run(module.get(), /*ssa_form=*/true));
|
||||
DependencyHloOrdering ordering(module.get());
|
||||
|
||||
// Init value is defined before the while, but live range is not before the
|
||||
// while because of the use of the init value in the add.
|
||||
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant),
|
||||
dataflow->GetValueDefinedAt(xla_while)));
|
||||
EXPECT_FALSE(
|
||||
ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(constant),
|
||||
dataflow->GetValueDefinedAt(xla_while)));
|
||||
EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(constant),
|
||||
dataflow->GetValueDefinedAt(xla_while)));
|
||||
|
||||
// Any value defined in the body or condition is defined before the while, and
|
||||
// has a live range strictly before the while.
|
||||
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(negate),
|
||||
dataflow->GetValueDefinedAt(xla_while)));
|
||||
EXPECT_TRUE(
|
||||
ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(negate),
|
||||
dataflow->GetValueDefinedAt(xla_while)));
|
||||
EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(negate),
|
||||
dataflow->GetValueDefinedAt(xla_while)));
|
||||
|
||||
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(convert),
|
||||
dataflow->GetValueDefinedAt(xla_while)));
|
||||
EXPECT_TRUE(
|
||||
ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(convert),
|
||||
dataflow->GetValueDefinedAt(xla_while)));
|
||||
EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(convert),
|
||||
dataflow->GetValueDefinedAt(xla_while)));
|
||||
|
||||
// The live range of the while should be before the add.
|
||||
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(xla_while),
|
||||
dataflow->GetValueDefinedAt(add)));
|
||||
ASSERT_EQ(dataflow->GetValueDefinedAt(xla_while).uses().size(), 1);
|
||||
|
||||
const HloUse& while_use = dataflow->GetValueDefinedAt(xla_while).uses()[0];
|
||||
EXPECT_EQ(while_use.instruction, add);
|
||||
EXPECT_TRUE(ordering.UseIsBeforeValueDefinition(
|
||||
while_use, dataflow->GetValueDefinedAt(add)));
|
||||
EXPECT_TRUE(
|
||||
ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(xla_while),
|
||||
dataflow->GetValueDefinedAt(add)));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
||||
|
@ -1248,7 +1248,8 @@ StatusOr<bool> HloRematerialization::Run(
|
||||
sequence->at(node.computation())));
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
},
|
||||
/*visit_unreachable_nodes=*/false));
|
||||
|
||||
// The peak memory usage of the module equals the peak memory use of the entry
|
||||
// computation plus the output size of the computation. This is because the
|
||||
|
@ -159,12 +159,6 @@ void HloValue::AddPosition(HloInstruction* instruction,
|
||||
for (const HloPosition& position : positions_) {
|
||||
DCHECK_NE(position, new_position);
|
||||
}
|
||||
// The shape of the new position must match existing positions.
|
||||
if (!positions_.empty()) {
|
||||
CHECK(
|
||||
ShapeUtil::Compatible(positions_.front().shape(), new_position.shape()))
|
||||
<< "front: " << positions_.front() << " new: " << new_position;
|
||||
}
|
||||
|
||||
positions_.push_back(std::move(new_position));
|
||||
|
||||
|
@ -225,6 +225,9 @@ class HloValueSet {
|
||||
// already exist in the set.
|
||||
bool AddValue(const HloValue* value);
|
||||
|
||||
// Clear all values from the set.
|
||||
void Clear() { values_.clear(); }
|
||||
|
||||
// Return the unique HLO value in the set. CHECKs if the set does not contain
|
||||
// exactly one value.
|
||||
const HloValue& GetUniqueValue() const {
|
||||
|
@ -32,13 +32,11 @@ class ShapeVerifier : public DfsHloVisitor {
|
||||
const std::function<int64(const Shape&)>& shape_size_fn)
|
||||
: shape_size_fn_(shape_size_fn) {}
|
||||
|
||||
Status HandleElementwiseUnary(HloInstruction* hlo,
|
||||
HloOpcode opcode) override {
|
||||
Status HandleElementwiseUnary(HloInstruction* hlo) override {
|
||||
return CheckUnaryShape(hlo);
|
||||
}
|
||||
|
||||
Status HandleElementwiseBinary(HloInstruction* hlo,
|
||||
HloOpcode opcode) override {
|
||||
Status HandleElementwiseBinary(HloInstruction* hlo) override {
|
||||
return CheckBinaryShape(hlo);
|
||||
}
|
||||
|
||||
@ -282,6 +280,14 @@ class ShapeVerifier : public DfsHloVisitor {
|
||||
const std::function<int64(const Shape&)> shape_size_fn_;
|
||||
};
|
||||
|
||||
string ComputationsToString(
|
||||
tensorflow::gtl::ArraySlice<HloComputation*> computations) {
|
||||
return tensorflow::str_util::Join(
|
||||
computations, ",", [](string* s, const HloComputation* computation) {
|
||||
s->append(computation->name());
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<bool> HloVerifier::Run(HloModule* module) {
|
||||
@ -292,6 +298,17 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
|
||||
for (const auto& instruction : computation->instructions()) {
|
||||
TF_RET_CHECK(instruction->parent() == computation.get());
|
||||
if (instruction->opcode() == HloOpcode::kFusion) {
|
||||
TF_RET_CHECK(
|
||||
ContainersEqual(instruction->called_computations(),
|
||||
{instruction->fused_instructions_computation()}))
|
||||
<< "Fusion HLO calls computations other than the "
|
||||
"fused_instructions_computation: "
|
||||
<< instruction->ToString()
|
||||
<< " instruction->fused_instructions_computation(): "
|
||||
<< instruction->fused_instructions_computation()->ToString()
|
||||
<< " instruction->called_computations(): "
|
||||
<< ComputationsToString(instruction->called_computations());
|
||||
|
||||
for (const auto& fused : instruction->fused_instructions()) {
|
||||
TF_RET_CHECK(fused->parent() ==
|
||||
instruction->fused_instructions_computation())
|
||||
|
@ -122,7 +122,8 @@ StatusOr<bool> ReducePrecisionInsertion::insert_on_inputs(
|
||||
continue;
|
||||
}
|
||||
|
||||
if (instruction->opcode() == HloOpcode::kFusion) {
|
||||
if (instruction->opcode() == HloOpcode::kFusion &&
|
||||
instruction->fusion_kind() == HloInstruction::FusionKind::kLoop) {
|
||||
// Insert the reduce-precision operation inside the fusion computation,
|
||||
// after the corresponding parameter instruction.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
@ -171,7 +172,8 @@ StatusOr<bool> ReducePrecisionInsertion::insert_on_outputs(
|
||||
continue;
|
||||
}
|
||||
|
||||
if (instruction->opcode() == HloOpcode::kFusion) {
|
||||
if (instruction->opcode() == HloOpcode::kFusion &&
|
||||
instruction->fusion_kind() == HloInstruction::FusionKind::kLoop) {
|
||||
// Insert the reduce-precision operation as the last operation inside
|
||||
// the fusion computation.
|
||||
HloInstruction* fusion_root = instruction->fused_expression_root();
|
||||
|
@ -28,6 +28,7 @@ py_library(
|
||||
"//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
|
||||
"//tensorflow/contrib/framework:framework_py",
|
||||
"//tensorflow/contrib/fused_conv:fused_conv_py",
|
||||
"//tensorflow/contrib/gan",
|
||||
"//tensorflow/contrib/graph_editor:graph_editor_py",
|
||||
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
|
||||
"//tensorflow/contrib/hooks",
|
||||
@ -72,6 +73,7 @@ py_library(
|
||||
"//tensorflow/contrib/staging",
|
||||
"//tensorflow/contrib/stat_summarizer:stat_summarizer_py",
|
||||
"//tensorflow/contrib/stateless",
|
||||
"//tensorflow/contrib/summary:summary_ops",
|
||||
"//tensorflow/contrib/tensor_forest:init_py",
|
||||
"//tensorflow/contrib/tensorboard",
|
||||
"//tensorflow/contrib/testing:testing_py",
|
||||
|
@ -31,6 +31,7 @@ from tensorflow.contrib import deprecated
|
||||
from tensorflow.contrib import distributions
|
||||
from tensorflow.contrib import factorization
|
||||
from tensorflow.contrib import framework
|
||||
from tensorflow.contrib import gan
|
||||
from tensorflow.contrib import graph_editor
|
||||
from tensorflow.contrib import grid_rnn
|
||||
from tensorflow.contrib import image
|
||||
|
@ -18,6 +18,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import os
|
||||
|
||||
from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
|
||||
from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch
|
||||
from tensorflow.contrib.decision_trees.proto import generic_tree_model_extensions_pb2
|
||||
@ -26,18 +29,21 @@ from tensorflow.contrib.learn.python.learn import export_strategy
|
||||
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
|
||||
from tensorflow.python.client import session as tf_session
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.saved_model import loader as saved_model_loader
|
||||
from tensorflow.python.saved_model import tag_constants
|
||||
|
||||
|
||||
def make_custom_export_strategy(name, convert_fn, feature_columns,
|
||||
def make_custom_export_strategy(name,
|
||||
convert_fn,
|
||||
feature_columns,
|
||||
export_input_fn):
|
||||
"""Makes custom exporter of GTFlow tree format.
|
||||
|
||||
Args:
|
||||
name: A string, for the name of the export strategy.
|
||||
convert_fn: A function that converts the tree proto to desired format and
|
||||
saves it to the desired location.
|
||||
saves it to the desired location. Can be None to skip conversion.
|
||||
feature_columns: A list of feature columns.
|
||||
export_input_fn: A function that takes no arguments and returns an
|
||||
`InputFnOps`.
|
||||
@ -68,9 +74,22 @@ def make_custom_export_strategy(name, convert_fn, feature_columns,
|
||||
dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
|
||||
dtec.ParseFromString(dfec_str)
|
||||
# Export the result in the same folder as the saved model.
|
||||
convert_fn(dtec, sorted_feature_names, len(dense_floats),
|
||||
len(sparse_float_indices), len(sparse_int_indices),
|
||||
result_dir, eval_result)
|
||||
if convert_fn:
|
||||
convert_fn(dtec, sorted_feature_names,
|
||||
len(dense_floats),
|
||||
len(sparse_float_indices),
|
||||
len(sparse_int_indices), result_dir, eval_result)
|
||||
feature_importances = _get_feature_importances(
|
||||
dtec, sorted_feature_names,
|
||||
len(dense_floats),
|
||||
len(sparse_float_indices), len(sparse_int_indices))
|
||||
sorted_by_importance = sorted(
|
||||
feature_importances.items(), key=lambda x: -x[1])
|
||||
assets_dir = os.path.join(result_dir, "assets.extra")
|
||||
gfile.MakeDirs(assets_dir)
|
||||
with gfile.GFile(os.path.join(assets_dir, "feature_importances"),
|
||||
"w") as f:
|
||||
f.write("\n".join("%s, %f" % (k, v) for k, v in sorted_by_importance))
|
||||
return result_dir
|
||||
return export_strategy.ExportStrategy(name, export_fn)
|
||||
|
||||
@ -157,3 +176,41 @@ def convert_to_universal_format(dtec, sorted_feature_names,
|
||||
node.left_child_id.value = split.left_id
|
||||
node.right_child_id.value = split.right_id
|
||||
return model_and_features
|
||||
|
||||
|
||||
def _get_feature_importances(dtec, feature_names, num_dense_floats,
|
||||
num_sparse_float, num_sparse_int):
|
||||
"""Export the feature importance per feature column."""
|
||||
del num_sparse_int # Unused.
|
||||
sums = collections.defaultdict(lambda: 0)
|
||||
for tree_idx in range(len(dtec.trees)):
|
||||
tree = dtec.trees[tree_idx]
|
||||
for tree_node in tree.nodes:
|
||||
node_type = tree_node.WhichOneof("node")
|
||||
if node_type == "dense_float_binary_split":
|
||||
split = tree_node.dense_float_binary_split
|
||||
split_column = feature_names[split.feature_column]
|
||||
elif node_type == "sparse_float_binary_split_default_left":
|
||||
split = tree_node.sparse_float_binary_split_default_left.split
|
||||
split_column = feature_names[split.feature_column + num_dense_floats]
|
||||
elif node_type == "sparse_float_binary_split_default_right":
|
||||
split = tree_node.sparse_float_binary_split_default_right.split
|
||||
split_column = feature_names[split.feature_column + num_dense_floats]
|
||||
elif node_type == "categorical_id_binary_split":
|
||||
split = tree_node.categorical_id_binary_split
|
||||
split_column = feature_names[split.feature_column + num_dense_floats +
|
||||
num_sparse_float]
|
||||
elif node_type == "categorical_id_set_membership_binary_split":
|
||||
split = tree_node.categorical_id_set_membership_binary_split
|
||||
split_column = feature_names[split.feature_column + num_dense_floats +
|
||||
num_sparse_float]
|
||||
elif node_type == "leaf":
|
||||
assert tree_node.node_metadata.gain == 0
|
||||
continue
|
||||
else:
|
||||
raise ValueError("Unexpected split type %s", node_type)
|
||||
# Apply shrinkage factor. It is important since it is not always uniform
|
||||
# across different trees.
|
||||
sums[split_column] += (
|
||||
tree_node.node_metadata.gain * dtec.tree_weights[tree_idx])
|
||||
return dict(sums)
|
||||
|
@ -27,7 +27,7 @@ from tensorflow.python.platform import googletest
|
||||
|
||||
class ConvertModelTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testConvertModel(self):
|
||||
def _make_trees(self):
|
||||
dtec_str = """
|
||||
trees {
|
||||
nodes {
|
||||
@ -108,8 +108,12 @@ class ConvertModelTest(test_util.TensorFlowTestCase):
|
||||
"""
|
||||
dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
|
||||
text_format.Merge(dtec_str, dtec)
|
||||
# The feature columns in the order they were added.
|
||||
feature_columns = ["feature_b", "feature_a", "feature_d"]
|
||||
return dtec, feature_columns
|
||||
|
||||
def testConvertModel(self):
|
||||
dtec, feature_columns = self._make_trees()
|
||||
# The feature columns in the order they were added.
|
||||
out = custom_export_strategy.convert_to_universal_format(
|
||||
dtec, feature_columns, 1, 1,
|
||||
1)
|
||||
@ -273,6 +277,16 @@ class ConvertModelTest(test_util.TensorFlowTestCase):
|
||||
}"""
|
||||
self.assertProtoEquals(expected_tree, out)
|
||||
|
||||
def testFeatureImportance(self):
|
||||
dtec, feature_columns = self._make_trees()
|
||||
feature_importances = custom_export_strategy._get_feature_importances(
|
||||
dtec, feature_columns, 1, 1, 1)
|
||||
self.assertItemsEqual(["feature_b", "feature_a", "feature_d"],
|
||||
feature_importances.keys())
|
||||
self.assertAlmostEqual(50.0, feature_importances["feature_b"], places=4)
|
||||
self.assertAlmostEqual(50.0, feature_importances["feature_a"], places=4)
|
||||
self.assertAlmostEqual(50.0, feature_importances["feature_d"], places=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
@ -61,11 +61,19 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
|
||||
logits_modifier_function: A modifier function for the logits.
|
||||
center_bias: Whether a separate tree should be created for first fitting
|
||||
the bias.
|
||||
|
||||
Raises:
|
||||
ValueError: If learner_config is not valid.
|
||||
"""
|
||||
head = head_lib.multi_class_head(
|
||||
n_classes=n_classes,
|
||||
weight_column_name=weight_column_name,
|
||||
enable_centered_bias=False)
|
||||
if learner_config.num_classes == 0:
|
||||
learner_config.num_classes = n_classes
|
||||
elif learner_config.num_classes != n_classes:
|
||||
raise ValueError("n_classes (%d) doesn't match learner_config (%d)." %
|
||||
(learner_config.num_classes, n_classes))
|
||||
super(GradientBoostedDecisionTreeClassifier, self).__init__(
|
||||
model_fn=model.model_builder,
|
||||
params={
|
||||
@ -129,6 +137,10 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
|
||||
label_dimension=label_dimension,
|
||||
weight_column_name=weight_column_name,
|
||||
enable_centered_bias=False)
|
||||
if label_dimension == 1:
|
||||
learner_config.num_classes = 2
|
||||
else:
|
||||
learner_config.num_classes = label_dimension
|
||||
super(GradientBoostedDecisionTreeRegressor, self).__init__(
|
||||
model_fn=model.model_builder,
|
||||
params={
|
||||
|
@ -92,6 +92,7 @@ def model_builder(features, labels, mode, params, config):
|
||||
examples_per_layer=examples_per_layer,
|
||||
learner_config=learner_config,
|
||||
feature_columns=feature_columns,
|
||||
logits_dimension=head.logits_dimension,
|
||||
features=features)
|
||||
with ops.name_scope("gbdt", "gbdt_optimizer"):
|
||||
predictions_dict = gbdt_model.predict(mode)
|
||||
|
@ -74,7 +74,7 @@ class TreeEnsembleStampTokenOp : public OpKernel {
|
||||
decision_tree_ensemble_resource;
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
&decision_tree_ensemble_resource));
|
||||
mutex_lock l(*decision_tree_ensemble_resource->get_mutex());
|
||||
tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex());
|
||||
core::ScopedUnref unref_me(decision_tree_ensemble_resource);
|
||||
Tensor* output_stamp_token_t = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
|
||||
@ -95,7 +95,7 @@ class TreeEnsembleSerializeOp : public OpKernel {
|
||||
decision_tree_ensemble_resource;
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
&decision_tree_ensemble_resource));
|
||||
mutex_lock l(*decision_tree_ensemble_resource->get_mutex());
|
||||
tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex());
|
||||
core::ScopedUnref unref_me(decision_tree_ensemble_resource);
|
||||
Tensor* output_stamp_token_t = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
|
||||
|
@ -143,7 +143,7 @@ class GradientTreesPredictionOp : public OpKernel {
|
||||
// Release the reference to the resource once we're done using it.
|
||||
core::ScopedUnref unref_me(decision_tree_ensemble_resource);
|
||||
if (use_locking_) {
|
||||
mutex_lock l(*decision_tree_ensemble_resource->get_mutex());
|
||||
tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex());
|
||||
DoCompute(context, decision_tree_ensemble_resource);
|
||||
} else {
|
||||
DoCompute(context, decision_tree_ensemble_resource);
|
||||
@ -334,7 +334,7 @@ class GradientTreesPartitionExamplesOp : public OpKernel {
|
||||
// Release the reference to the resource once we're done using it.
|
||||
core::ScopedUnref unref_me(decision_tree_ensemble_resource);
|
||||
if (use_locking_) {
|
||||
mutex_lock l(*decision_tree_ensemble_resource->get_mutex());
|
||||
tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex());
|
||||
DoCompute(context, decision_tree_ensemble_resource);
|
||||
} else {
|
||||
DoCompute(context, decision_tree_ensemble_resource);
|
||||
|
@ -656,7 +656,8 @@ class GrowTreeEnsembleOp : public OpKernel {
|
||||
CHECK(split->split_info.split_node().node_case() != TreeNode::NODE_NOT_SET);
|
||||
CHECK(tree_config->nodes(node_id).node_case() == TreeNode::kLeaf)
|
||||
<< "Unexpected node type to split "
|
||||
<< tree_config->nodes(node_id).node_case();
|
||||
<< tree_config->nodes(node_id).node_case() << " for node_id " << node_id
|
||||
<< ". Tree config: " << tree_config->DebugString();
|
||||
|
||||
// Add left leaf.
|
||||
int32 left_id = tree_config->nodes_size();
|
||||
@ -767,7 +768,7 @@ class TreeEnsembleStatsOp : public OpKernel {
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
&decision_tree_ensemble_resource));
|
||||
core::ScopedUnref unref_me(decision_tree_ensemble_resource);
|
||||
mutex_lock l(*decision_tree_ensemble_resource->get_mutex());
|
||||
tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex());
|
||||
|
||||
// Get the stamp token.
|
||||
const Tensor* stamp_token_t;
|
||||
|
@ -42,6 +42,7 @@ class BiasFeatureColumnHandlerTest : public ::testing::Test {
|
||||
example_partitions_({0, 0, 1, 3}) {
|
||||
// Set L2 regularization.
|
||||
learner_config_.mutable_regularization()->set_l2(2.0f);
|
||||
learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
|
||||
|
||||
// Create handler.
|
||||
handler_.reset(new BiasFeatureColumnHandler(kClassId, kSlotId, kBatchSize));
|
||||
|
@ -51,7 +51,7 @@ class CategoricalFeatureColumnHandlerTest : public ::testing::Test {
|
||||
values_(test::AsTensor<int64>({1, 2, 2, 0}, {4})) {
|
||||
// Set L2 regularization.
|
||||
learner_config_.mutable_regularization()->set_l2(2.0f);
|
||||
|
||||
learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
|
||||
// Create handler.
|
||||
handler_.reset(new CategoricalFeatureColumnHandler(
|
||||
kClassId, kSlotId, kBatchSize, kFeatureColumn, indices_.matrix<int64>(),
|
||||
|
@ -51,7 +51,7 @@ class DenseQuantizedFeatureColumnHandlerTest : public ::testing::Test {
|
||||
dense_quantized_values_(test::AsTensor<int32>({1, 1, 0, 1}, {4})) {
|
||||
// Set L2 regularization.
|
||||
learner_config_.mutable_regularization()->set_l2(2.0f);
|
||||
|
||||
learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
|
||||
// Create handler.
|
||||
handler_.reset(new DenseQuantizedFeatureColumnHandler(
|
||||
kClassId, kSlotId, kBatchSize, kFeatureColumn,
|
||||
|
@ -53,7 +53,7 @@ class SparseQuantizedFeatureColumnHandlerTest : public ::testing::Test {
|
||||
sparse_quantized_values_(test::AsTensor<int32>({1, 0, 1}, {3})) {
|
||||
// Set L2 regularization.
|
||||
learner_config_.mutable_regularization()->set_l2(2.0f);
|
||||
|
||||
learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
|
||||
// Create handler.
|
||||
handler_.reset(new SparseQuantizedFeatureColumnHandler(
|
||||
kClassId, kSlotId, kBatchSize, kFeatureColumn,
|
||||
|
@ -30,6 +30,7 @@ const double kDelta = 1e-5;
|
||||
|
||||
TEST(NodeStatsTest, AlmostZero) {
|
||||
LearnerConfig learner_config;
|
||||
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
|
||||
NodeStats node_stats(learner_config, GradientStats(1e-8f, 1e-8f));
|
||||
EXPECT_EQ(0, node_stats.weight_contribution[0]);
|
||||
EXPECT_EQ(0, node_stats.gain);
|
||||
@ -37,6 +38,7 @@ TEST(NodeStatsTest, AlmostZero) {
|
||||
|
||||
TEST(NodeStatsTest, LessThanMinWeightConstraint) {
|
||||
LearnerConfig learner_config;
|
||||
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
|
||||
learner_config.mutable_constraints()->set_min_node_weight(3.2f);
|
||||
NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
|
||||
EXPECT_EQ(0, node_stats.weight_contribution[0]);
|
||||
@ -45,6 +47,7 @@ TEST(NodeStatsTest, LessThanMinWeightConstraint) {
|
||||
|
||||
TEST(NodeStatsTest, L1RegSquashed) {
|
||||
LearnerConfig learner_config;
|
||||
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
|
||||
learner_config.mutable_regularization()->set_l1(10.0f);
|
||||
NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
|
||||
EXPECT_EQ(0, node_stats.weight_contribution[0]);
|
||||
@ -53,6 +56,7 @@ TEST(NodeStatsTest, L1RegSquashed) {
|
||||
|
||||
TEST(NodeStatsTest, L1RegPos) {
|
||||
LearnerConfig learner_config;
|
||||
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
|
||||
learner_config.mutable_regularization()->set_l1(5.0f);
|
||||
NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
|
||||
const float expected_clipped_grad = 7.32f - 5.0f;
|
||||
@ -66,6 +70,7 @@ TEST(NodeStatsTest, L1RegPos) {
|
||||
|
||||
TEST(NodeStatsTest, L1RegNeg) {
|
||||
LearnerConfig learner_config;
|
||||
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
|
||||
learner_config.mutable_regularization()->set_l1(5.0f);
|
||||
NodeStats node_stats(learner_config, GradientStats(-7.32f, 1.63f));
|
||||
const float expected_clipped_grad = -7.32f + 5.0f;
|
||||
@ -79,6 +84,7 @@ TEST(NodeStatsTest, L1RegNeg) {
|
||||
|
||||
TEST(NodeStatsTest, L2Reg) {
|
||||
LearnerConfig learner_config;
|
||||
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
|
||||
learner_config.mutable_regularization()->set_l2(8.0f);
|
||||
NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
|
||||
const float expected_denom = 1.63f + 8.0f;
|
||||
@ -91,6 +97,7 @@ TEST(NodeStatsTest, L2Reg) {
|
||||
|
||||
TEST(NodeStatsTest, L1L2Reg) {
|
||||
LearnerConfig learner_config;
|
||||
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
|
||||
learner_config.mutable_regularization()->set_l1(5.0f);
|
||||
learner_config.mutable_regularization()->set_l2(8.0f);
|
||||
NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
|
||||
|
@ -15,6 +15,7 @@
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
|
||||
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h"
|
||||
@ -34,10 +35,27 @@ class WeightedQuantilesSummary {
|
||||
|
||||
struct SummaryEntry {
|
||||
SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min,
|
||||
const WeightType& max)
|
||||
: value(v), weight(w), min_rank(min), max_rank(max) {}
|
||||
const WeightType& max) {
|
||||
// Explicitely initialize all of memory (including padding from memory
|
||||
// alignment) to allow the struct to be msan-resistant "plain old data".
|
||||
//
|
||||
// POD = http://en.cppreference.com/w/cpp/concept/PODType
|
||||
memset(this, 0, sizeof(*this));
|
||||
|
||||
SummaryEntry() : value(0), weight(0), min_rank(0), max_rank(0) {}
|
||||
value = v;
|
||||
weight = w;
|
||||
min_rank = min;
|
||||
max_rank = max;
|
||||
}
|
||||
|
||||
SummaryEntry() {
|
||||
memset(this, 0, sizeof(*this));
|
||||
|
||||
value = 0;
|
||||
weight = 0;
|
||||
min_rank = 0;
|
||||
max_rank = 0;
|
||||
}
|
||||
|
||||
bool operator==(const SummaryEntry& other) const {
|
||||
return value == other.value && weight == other.weight &&
|
||||
|
@ -17,7 +17,7 @@ message TreeRegularizationConfig {
|
||||
|
||||
// Tree constraints config.
|
||||
message TreeConstraintsConfig {
|
||||
// Maximum depth of the trees.
|
||||
// Maximum depth of the trees. The default value is 6 if not specified.
|
||||
uint32 max_tree_depth = 1;
|
||||
|
||||
// Min hessian weight per node.
|
||||
@ -86,20 +86,22 @@ message LearningRateDropoutDrivenConfig {
|
||||
|
||||
message LearnerConfig {
|
||||
enum PruningMode {
|
||||
PRE_PRUNE = 0;
|
||||
POST_PRUNE = 1;
|
||||
PRUNING_MODE_UNSPECIFIED = 0;
|
||||
PRE_PRUNE = 1;
|
||||
POST_PRUNE = 2;
|
||||
}
|
||||
|
||||
enum GrowingMode {
|
||||
WHOLE_TREE = 0;
|
||||
// Layer by layer is only supported by the batch learner.
|
||||
LAYER_BY_LAYER = 1;
|
||||
GROWING_MODE_UNSPECIFIED = 0;
|
||||
WHOLE_TREE = 1;
|
||||
LAYER_BY_LAYER = 2;
|
||||
}
|
||||
|
||||
enum MultiClassStrategy {
|
||||
TREE_PER_CLASS = 0;
|
||||
FULL_HESSIAN = 1;
|
||||
DIAGONAL_HESSIAN = 2;
|
||||
MULTI_CLASS_STRATEGY_UNSPECIFIED = 0;
|
||||
TREE_PER_CLASS = 1;
|
||||
FULL_HESSIAN = 2;
|
||||
DIAGONAL_HESSIAN = 3;
|
||||
}
|
||||
|
||||
// Number of classes.
|
||||
@ -118,16 +120,18 @@ message LearnerConfig {
|
||||
// Constraints.
|
||||
TreeConstraintsConfig constraints = 5;
|
||||
|
||||
// Pruning.
|
||||
// Pruning. POST_PRUNE is the default pruning mode.
|
||||
PruningMode pruning_mode = 8;
|
||||
|
||||
// Growing Mode.
|
||||
// Growing Mode. LAYER_BY_LAYER is the default growing mode.
|
||||
GrowingMode growing_mode = 9;
|
||||
|
||||
// Learning rate.
|
||||
// Learning rate. By default we use fixed learning rate of 0.1.
|
||||
LearningRateConfig learning_rate_tuner = 6;
|
||||
|
||||
// Multi-class strategy.
|
||||
// Multi-class strategy. By default we use TREE_PER_CLASS for binary
|
||||
// classification and linear regression. For other cases, we use
|
||||
// DIAGONAL_HESSIAN as the default.
|
||||
MultiClassStrategy multi_class_strategy = 10;
|
||||
|
||||
// If you want to average the ensembles (for regularization), provide the
|
||||
|
@ -344,6 +344,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
|
||||
# Prepare learner config.
|
||||
learner_config = learner_pb2.LearnerConfig()
|
||||
learner_config.num_classes = 2
|
||||
learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE
|
||||
|
||||
result, result_no_dropout, dropout_info = (
|
||||
prediction_ops.gradient_trees_prediction(
|
||||
|
@ -261,6 +261,7 @@ class GradientBoostedDecisionTreeModel(object):
|
||||
examples_per_layer,
|
||||
learner_config,
|
||||
features,
|
||||
logits_dimension,
|
||||
feature_columns=None):
|
||||
"""Construct a new GradientBoostedDecisionTreeModel function.
|
||||
|
||||
@ -273,8 +274,8 @@ class GradientBoostedDecisionTreeModel(object):
|
||||
a tree layer. It can also be a function that computes the number of
|
||||
examples based on the depth of the layer that's being built.
|
||||
learner_config: A learner config.
|
||||
print split, sorted_feature_names[split.feature_column]
|
||||
features: `dict` of `Tensor` objects.
|
||||
logits_dimension: An int, the dimension of logits.
|
||||
feature_columns: A list of feature columns.
|
||||
|
||||
Raises:
|
||||
@ -289,11 +290,39 @@ class GradientBoostedDecisionTreeModel(object):
|
||||
if learner_config.num_classes < 2:
|
||||
raise ValueError("Number of classes must be >=2")
|
||||
|
||||
self._logits_dimension = logits_dimension
|
||||
self._is_chief = is_chief
|
||||
self._num_ps_replicas = num_ps_replicas
|
||||
self._ensemble_handle = ensemble_handle
|
||||
self._center_bias = center_bias
|
||||
self._examples_per_layer = examples_per_layer
|
||||
|
||||
# Fill in the defaults.
|
||||
if (learner_config.multi_class_strategy ==
|
||||
learner_pb2.LearnerConfig.MULTI_CLASS_STRATEGY_UNSPECIFIED):
|
||||
if logits_dimension == 1:
|
||||
learner_config.multi_class_strategy = (
|
||||
learner_pb2.LearnerConfig.TREE_PER_CLASS)
|
||||
else:
|
||||
learner_config.multi_class_strategy = (
|
||||
learner_pb2.LearnerConfig.DIAGONAL_HESSIAN)
|
||||
|
||||
if (learner_config.growing_mode ==
|
||||
learner_pb2.LearnerConfig.GROWING_MODE_UNSPECIFIED):
|
||||
learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER
|
||||
|
||||
if (learner_config.pruning_mode ==
|
||||
learner_pb2.LearnerConfig.PRUNING_MODE_UNSPECIFIED):
|
||||
learner_config.pruning_mode = learner_pb2.LearnerConfig.POST_PRUNE
|
||||
|
||||
if learner_config.constraints.max_tree_depth == 0:
|
||||
# Use 6 as the default maximum depth.
|
||||
learner_config.constraints.max_tree_depth = 6
|
||||
|
||||
tuner = learner_config.learning_rate_tuner.WhichOneof("tuner")
|
||||
if not tuner:
|
||||
learner_config.learning_rate_tuner.fixed.learning_rate = 0.1
|
||||
|
||||
self._learner_config = learner_config
|
||||
self._feature_columns = feature_columns
|
||||
self._learner_config_serialized = learner_config.SerializeToString()
|
||||
@ -378,75 +407,81 @@ class GradientBoostedDecisionTreeModel(object):
|
||||
local_stamp), _refresh_local_ensemble_fn,
|
||||
lambda: (control_flow_ops.no_op(), ensemble_stamp))
|
||||
|
||||
# Once updated, Use the the local model for prediction.
|
||||
# Once updated, use the local model for prediction.
|
||||
with ops.control_dependencies([refresh_local_ensemble]):
|
||||
ensemble_stats = training_ops.tree_ensemble_stats(
|
||||
local_ensemble_handle, ensemble_stamp)
|
||||
apply_dropout, seed = _dropout_params(mode, ensemble_stats)
|
||||
# We don't need dropout info - we can always restore it based on the
|
||||
# seed.
|
||||
predictions, predictions_no_dropout, _ = (
|
||||
prediction_ops.gradient_trees_prediction(
|
||||
local_ensemble_handle,
|
||||
seed,
|
||||
self._dense_floats,
|
||||
self._sparse_float_indices,
|
||||
self._sparse_float_values,
|
||||
self._sparse_float_shapes,
|
||||
self._sparse_int_indices,
|
||||
self._sparse_int_values,
|
||||
self._sparse_int_shapes,
|
||||
learner_config=self._learner_config_serialized,
|
||||
apply_dropout=apply_dropout,
|
||||
apply_averaging=apply_averaging,
|
||||
use_locking=False,
|
||||
center_bias=self._center_bias,
|
||||
reduce_dim=self._reduce_dim))
|
||||
partition_ids = prediction_ops.gradient_trees_partition_examples(
|
||||
local_ensemble_handle,
|
||||
self._dense_floats,
|
||||
self._sparse_float_indices,
|
||||
self._sparse_float_values,
|
||||
self._sparse_float_shapes,
|
||||
self._sparse_int_indices,
|
||||
self._sparse_int_values,
|
||||
self._sparse_int_shapes,
|
||||
use_locking=False)
|
||||
apply_dropout, seed = _dropout_params(mode, ensemble_stats)
|
||||
# Make sure ensemble stats run. This will check that the ensemble has
|
||||
# the right stamp.
|
||||
with ops.control_dependencies(ensemble_stats):
|
||||
predictions, predictions_no_dropout, _ = (
|
||||
prediction_ops.gradient_trees_prediction(
|
||||
local_ensemble_handle,
|
||||
seed,
|
||||
self._dense_floats,
|
||||
self._sparse_float_indices,
|
||||
self._sparse_float_values,
|
||||
self._sparse_float_shapes,
|
||||
self._sparse_int_indices,
|
||||
self._sparse_int_values,
|
||||
self._sparse_int_shapes,
|
||||
learner_config=self._learner_config_serialized,
|
||||
apply_dropout=apply_dropout,
|
||||
apply_averaging=apply_averaging,
|
||||
use_locking=True,
|
||||
center_bias=self._center_bias,
|
||||
reduce_dim=self._reduce_dim))
|
||||
partition_ids = prediction_ops.gradient_trees_partition_examples(
|
||||
local_ensemble_handle,
|
||||
self._dense_floats,
|
||||
self._sparse_float_indices,
|
||||
self._sparse_float_values,
|
||||
self._sparse_float_shapes,
|
||||
self._sparse_int_indices,
|
||||
self._sparse_int_values,
|
||||
self._sparse_int_shapes,
|
||||
use_locking=True)
|
||||
|
||||
else:
|
||||
with ops.device(self._ensemble_handle.device):
|
||||
ensemble_stats = training_ops.tree_ensemble_stats(
|
||||
self._ensemble_handle, ensemble_stamp)
|
||||
apply_dropout, seed = _dropout_params(mode, ensemble_stats)
|
||||
# We don't need dropout info - we can always restore it based on the
|
||||
# seed.
|
||||
predictions, predictions_no_dropout, _ = (
|
||||
prediction_ops.gradient_trees_prediction(
|
||||
self._ensemble_handle,
|
||||
seed,
|
||||
self._dense_floats,
|
||||
self._sparse_float_indices,
|
||||
self._sparse_float_values,
|
||||
self._sparse_float_shapes,
|
||||
self._sparse_int_indices,
|
||||
self._sparse_int_values,
|
||||
self._sparse_int_shapes,
|
||||
learner_config=self._learner_config_serialized,
|
||||
apply_dropout=apply_dropout,
|
||||
apply_averaging=apply_averaging,
|
||||
use_locking=False,
|
||||
center_bias=self._center_bias,
|
||||
reduce_dim=self._reduce_dim))
|
||||
partition_ids = prediction_ops.gradient_trees_partition_examples(
|
||||
self._ensemble_handle,
|
||||
self._dense_floats,
|
||||
self._sparse_float_indices,
|
||||
self._sparse_float_values,
|
||||
self._sparse_float_shapes,
|
||||
self._sparse_int_indices,
|
||||
self._sparse_int_values,
|
||||
self._sparse_int_shapes,
|
||||
use_locking=False)
|
||||
apply_dropout, seed = _dropout_params(mode, ensemble_stats)
|
||||
# Make sure ensemble stats run. This will check that the ensemble has
|
||||
# the right stamp.
|
||||
with ops.control_dependencies(ensemble_stats):
|
||||
predictions, predictions_no_dropout, _ = (
|
||||
prediction_ops.gradient_trees_prediction(
|
||||
self._ensemble_handle,
|
||||
seed,
|
||||
self._dense_floats,
|
||||
self._sparse_float_indices,
|
||||
self._sparse_float_values,
|
||||
self._sparse_float_shapes,
|
||||
self._sparse_int_indices,
|
||||
self._sparse_int_values,
|
||||
self._sparse_int_shapes,
|
||||
learner_config=self._learner_config_serialized,
|
||||
apply_dropout=apply_dropout,
|
||||
apply_averaging=apply_averaging,
|
||||
use_locking=True,
|
||||
center_bias=self._center_bias,
|
||||
reduce_dim=self._reduce_dim))
|
||||
partition_ids = prediction_ops.gradient_trees_partition_examples(
|
||||
self._ensemble_handle,
|
||||
self._dense_floats,
|
||||
self._sparse_float_indices,
|
||||
self._sparse_float_values,
|
||||
self._sparse_float_shapes,
|
||||
self._sparse_int_indices,
|
||||
self._sparse_int_values,
|
||||
self._sparse_int_shapes,
|
||||
use_locking=True)
|
||||
|
||||
return _make_predictions_dict(ensemble_stamp, predictions,
|
||||
predictions_no_dropout, partition_ids,
|
||||
|
@ -164,7 +164,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
|
||||
ensemble_handle=ensemble_handle,
|
||||
examples_per_layer=1,
|
||||
learner_config=learner_config,
|
||||
features=features)
|
||||
logits_dimension=1, features=features)
|
||||
|
||||
predictions = array_ops.constant(
|
||||
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
|
||||
@ -268,7 +268,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
|
||||
ensemble_handle=ensemble_handle,
|
||||
examples_per_layer=num_examples_fn,
|
||||
learner_config=learner_config,
|
||||
features=features)
|
||||
logits_dimension=1, features=features)
|
||||
|
||||
predictions = array_ops.constant(
|
||||
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
|
||||
@ -371,7 +371,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
|
||||
ensemble_handle=ensemble_handle,
|
||||
examples_per_layer=1,
|
||||
learner_config=learner_config,
|
||||
features=features)
|
||||
logits_dimension=1, features=features)
|
||||
|
||||
predictions = array_ops.constant(
|
||||
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
|
||||
@ -442,7 +442,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
|
||||
ensemble_handle=ensemble_handle,
|
||||
examples_per_layer=1,
|
||||
learner_config=learner_config,
|
||||
features=features)
|
||||
logits_dimension=1, features=features)
|
||||
|
||||
predictions = array_ops.constant(
|
||||
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
|
||||
@ -505,7 +505,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
|
||||
ensemble_handle=ensemble_handle,
|
||||
examples_per_layer=1,
|
||||
learner_config=learner_config,
|
||||
features=features)
|
||||
logits_dimension=1, features=features)
|
||||
|
||||
predictions = array_ops.constant(
|
||||
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
|
||||
@ -588,7 +588,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
|
||||
ensemble_handle=ensemble_handle,
|
||||
examples_per_layer=1,
|
||||
learner_config=learner_config,
|
||||
features=features)
|
||||
logits_dimension=1, features=features)
|
||||
|
||||
# Create predict op.
|
||||
mode = model_fn.ModeKeys.EVAL
|
||||
@ -627,7 +627,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
|
||||
ensemble_handle=ensemble_handle,
|
||||
examples_per_layer=1,
|
||||
learner_config=learner_config,
|
||||
features=features)
|
||||
logits_dimension=5, features=features)
|
||||
|
||||
predictions = array_ops.constant(
|
||||
[[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0],
|
||||
@ -730,7 +730,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
|
||||
ensemble_handle=ensemble_handle,
|
||||
examples_per_layer=1,
|
||||
learner_config=learner_config,
|
||||
features=features)
|
||||
logits_dimension=5, features=features)
|
||||
|
||||
predictions = array_ops.constant(
|
||||
[[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0],
|
||||
@ -833,7 +833,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
|
||||
ensemble_handle=ensemble_handle,
|
||||
examples_per_layer=1,
|
||||
learner_config=learner_config,
|
||||
features=features)
|
||||
logits_dimension=5, features=features)
|
||||
|
||||
batch_size = 3
|
||||
predictions = array_ops.constant(
|
||||
|
4
tensorflow/contrib/cmake/external/cub.cmake
vendored
4
tensorflow/contrib/cmake/external/cub.cmake
vendored
@ -14,8 +14,8 @@
|
||||
# ==============================================================================
|
||||
include (ExternalProject)
|
||||
|
||||
set(cub_URL http://mirror.bazel.build/github.com/NVlabs/cub/archive/69ceda618313df8e9cac6659d607b08949455d14.tar.gz)
|
||||
set(cub_HASH SHA256=87e856522c283b8ea887c3b61d7d5b252d2dd74abac4f1d756d776e721223e82)
|
||||
set(cub_URL http://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.3.zip)
|
||||
set(cub_HASH SHA256=b7ead9e291d34ffa8074243541c1380d63be63f88de23de8ee548db573b72ebe)
|
||||
set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub)
|
||||
set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub)
|
||||
set(cub_ARCHIVE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/cub_archive)
|
||||
|
@ -18,6 +18,7 @@
|
||||
set(tf_c_srcs
|
||||
"${tensorflow_source_dir}/tensorflow/c/c_api.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/c/c_api.h"
|
||||
"${tensorflow_source_dir}/tensorflow/c/c_api_function.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/c/eager/c_api.h"
|
||||
"${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc"
|
||||
|
@ -315,6 +315,7 @@ add_python_module("tensorflow/contrib/framework/ops")
|
||||
add_python_module("tensorflow/contrib/framework/python")
|
||||
add_python_module("tensorflow/contrib/framework/python/framework")
|
||||
add_python_module("tensorflow/contrib/framework/python/ops")
|
||||
add_python_module("tensorflow/contrib/gan")
|
||||
add_python_module("tensorflow/contrib/graph_editor")
|
||||
add_python_module("tensorflow/contrib/graph_editor/examples")
|
||||
add_python_module("tensorflow/contrib/graph_editor/tests")
|
||||
|
@ -291,6 +291,8 @@ if (tensorflow_BUILD_PYTHON_TESTS)
|
||||
# Failing with TF 1.3 (TODO)
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/estimator_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_test.py"
|
||||
# Test should only be run manually
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/reduction_ops_test_big.py"
|
||||
)
|
||||
endif()
|
||||
list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude})
|
||||
|
@ -716,6 +716,482 @@ _cudnn_rnn_common_doc_string = """
|
||||
"""
|
||||
|
||||
|
||||
def _check_direction(direction):
|
||||
if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION):
|
||||
raise ValueError("Invalid direction: %s, expect %s or %s" %
|
||||
(direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION))
|
||||
|
||||
|
||||
def _check_rnn_mode(rnn_mode):
|
||||
if rnn_mode not in (CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH, CUDNN_RNN_RELU):
|
||||
raise ValueError("Invalid rnn_mode: %s, expect one of (%s, %s, %s, %s)" %
|
||||
(rnn_mode, CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH,
|
||||
CUDNN_RNN_RELU))
|
||||
|
||||
|
||||
def _get_seed(seed):
|
||||
seed, seed2 = random_seed.get_seed(seed)
|
||||
if seed is None and seed2 is None:
|
||||
seed, seed2 = 0, 0
|
||||
return seed, seed2
|
||||
|
||||
|
||||
def _get_num_params(rnn_mode, num_layers, direction):
|
||||
"""Return num params for given Cudnn config."""
|
||||
if rnn_mode == CUDNN_LSTM:
|
||||
num_params_per_layer = 8
|
||||
elif rnn_mode == CUDNN_GRU:
|
||||
num_params_per_layer = 6
|
||||
elif rnn_mode in (CUDNN_RNN_RELU, CUDNN_RNN_TANH):
|
||||
num_params_per_layer = 2
|
||||
else:
|
||||
raise ValueError("Invalid \'rnn_mode\': %s", rnn_mode)
|
||||
num_params = num_layers * num_params_per_layer
|
||||
if direction != CUDNN_RNN_UNIDIRECTION:
|
||||
num_params *= 2
|
||||
return num_params
|
||||
|
||||
|
||||
def _cudnn_rnn(inputs,
|
||||
input_h,
|
||||
input_c,
|
||||
params,
|
||||
is_training,
|
||||
rnn_mode,
|
||||
input_mode=CUDNN_INPUT_LINEAR_MODE,
|
||||
direction=CUDNN_RNN_UNIDIRECTION,
|
||||
dropout=0.,
|
||||
seed=0,
|
||||
name=None):
|
||||
"""Cudnn RNN.
|
||||
|
||||
Args:
|
||||
inputs: the input sequence to the RNN model. A Tensor of shape [?,
|
||||
batch_size, input_size].
|
||||
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
|
||||
batch_size, num_units].
|
||||
input_c: the initial hidden state for c. This is only relevant for LSTM.
|
||||
A Tensor of the same shape as input_h.
|
||||
params: the parameter buffer created for this model.
|
||||
is_training: whether this operation will be used in training or inference
|
||||
rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh').
|
||||
input_mode: indicate whether there is a linear projection between the
|
||||
input and the actual computation before the first layer. It could be
|
||||
'linear_input', 'skip_input' or 'auto_select'.
|
||||
'linear_input' (default) always applies a linear projection of input
|
||||
onto RNN hidden state. (standard RNN behavior).
|
||||
'skip_input' is only allowed when input_size == num_units;
|
||||
'auto_select' implies 'skip_input' when input_size == num_units;
|
||||
otherwise, it implies 'linear_input'.
|
||||
direction: the direction model that the model operates. Could be either
|
||||
'unidirectional' or 'bidirectional'
|
||||
dropout: whether to enable dropout. With it is 0, dropout is disabled.
|
||||
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
|
||||
for behavior.
|
||||
name: name of the operation.
|
||||
Returns:
|
||||
outputs, output_h, output_c
|
||||
"""
|
||||
_check_rnn_mode(rnn_mode)
|
||||
_check_direction(direction)
|
||||
seed, seed2 = random_seed.get_seed(seed)
|
||||
outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
|
||||
input=inputs,
|
||||
input_h=input_h,
|
||||
input_c=input_c,
|
||||
params=params,
|
||||
is_training=is_training,
|
||||
rnn_mode=rnn_mode,
|
||||
input_mode=input_mode,
|
||||
direction=direction,
|
||||
dropout=dropout,
|
||||
seed=seed,
|
||||
seed2=seed2,
|
||||
name=name)
|
||||
return (outputs, output_h, output_c)
|
||||
|
||||
|
||||
def cudnn_lstm(inputs,
|
||||
input_h,
|
||||
input_c,
|
||||
params,
|
||||
is_training,
|
||||
input_mode=CUDNN_INPUT_LINEAR_MODE,
|
||||
direction=CUDNN_RNN_UNIDIRECTION,
|
||||
dropout=0.,
|
||||
seed=0,
|
||||
name=None):
|
||||
"""Cudnn LSTM.
|
||||
|
||||
Args:
|
||||
inputs: the input sequence to the RNN model. A Tensor of shape [?,
|
||||
batch_size, input_size].
|
||||
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
|
||||
batch_size, num_units].
|
||||
input_c: the initial hidden state for c. This is only relevant for LSTM.
|
||||
A Tensor of the same shape as input_h.
|
||||
params: the parameter buffer created for this model.
|
||||
is_training: whether this operation will be used in training or inference
|
||||
input_mode: indicate whether there is a linear projection between the
|
||||
input and the actual computation before the first layer. It could be
|
||||
'linear_input', 'skip_input' or 'auto_select'.
|
||||
'linear_input' (default) always applies a linear projection of input
|
||||
onto RNN hidden state. (standard RNN behavior).
|
||||
'skip_input' is only allowed when input_size == num_units;
|
||||
'auto_select' implies 'skip_input' when input_size == num_units;
|
||||
otherwise, it implies 'linear_input'.
|
||||
direction: the direction model that the model operates. Could be either
|
||||
'unidirectional' or 'bidirectional'
|
||||
dropout: whether to enable dropout. With it is 0, dropout is disabled.
|
||||
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
|
||||
for behavior.
|
||||
name: name of the operation.
|
||||
Returns:
|
||||
outputs, output_h, output_c
|
||||
"""
|
||||
return _cudnn_rnn(inputs, input_h, input_c, params, is_training, CUDNN_LSTM,
|
||||
input_mode, direction, dropout, seed, name)
|
||||
|
||||
|
||||
def _cudnn_rnn_no_input_c(inputs,
|
||||
input_h,
|
||||
params,
|
||||
is_training,
|
||||
rnn_mode,
|
||||
input_mode=CUDNN_INPUT_LINEAR_MODE,
|
||||
direction=CUDNN_RNN_UNIDIRECTION,
|
||||
dropout=0.,
|
||||
seed=0,
|
||||
name=None):
|
||||
"""Cudnn RNN w/o input_c.
|
||||
|
||||
Args:
|
||||
inputs: the input sequence to the RNN model. A Tensor of shape [?,
|
||||
batch_size, input_size].
|
||||
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
|
||||
batch_size, num_units].
|
||||
params: the parameter buffer created for this model.
|
||||
is_training: whether this operation will be used in training or inference
|
||||
rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh').
|
||||
input_mode: indicate whether there is a linear projection between the
|
||||
input and the actual computation before the first layer. It could be
|
||||
'linear_input', 'skip_input' or 'auto_select'.
|
||||
'linear_input' (default) always applies a linear projection of input
|
||||
onto RNN hidden state. (standard RNN behavior).
|
||||
'skip_input' is only allowed when input_size == num_units;
|
||||
'auto_select' implies 'skip_input' when input_size == num_units;
|
||||
otherwise, it implies 'linear_input'.
|
||||
direction: the direction model that the model operates. Could be either
|
||||
'unidirectional' or 'bidirectional'
|
||||
dropout: whether to enable dropout. With it is 0, dropout is disabled.
|
||||
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
|
||||
for behavior.
|
||||
name: name of the operation.
|
||||
Returns:
|
||||
outputs, output_h
|
||||
"""
|
||||
input_c = array_ops.constant([], dtype=input_h.dtype)
|
||||
outputs, output_h, _ = _cudnn_rnn(inputs, input_h, input_c, params,
|
||||
is_training, rnn_mode, input_mode,
|
||||
direction, dropout, seed, name)
|
||||
return outputs, output_h
|
||||
|
||||
|
||||
def cudnn_gru(inputs,
|
||||
input_h,
|
||||
params,
|
||||
is_training,
|
||||
input_mode=CUDNN_INPUT_LINEAR_MODE,
|
||||
direction=CUDNN_RNN_UNIDIRECTION,
|
||||
dropout=0.,
|
||||
seed=0,
|
||||
name=None):
|
||||
"""Cudnn GRU.
|
||||
|
||||
Args:
|
||||
inputs: the input sequence to the RNN model. A Tensor of shape [?,
|
||||
batch_size, input_size].
|
||||
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
|
||||
batch_size, num_units].
|
||||
params: the parameter buffer created for this model.
|
||||
is_training: whether this operation will be used in training or inference
|
||||
input_mode: indicate whether there is a linear projection between the
|
||||
input and the actual computation before the first layer. It could be
|
||||
'linear_input', 'skip_input' or 'auto_select'.
|
||||
'linear_input' (default) always applies a linear projection of input
|
||||
onto RNN hidden state. (standard RNN behavior).
|
||||
'skip_input' is only allowed when input_size == num_units;
|
||||
'auto_select' implies 'skip_input' when input_size == num_units;
|
||||
otherwise, it implies 'linear_input'.
|
||||
direction: the direction model that the model operates. Could be either
|
||||
'unidirectional' or 'bidirectional'
|
||||
dropout: whether to enable dropout. With it is 0, dropout is disabled.
|
||||
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
|
||||
for behavior.
|
||||
name: name of the operation.
|
||||
Returns:
|
||||
outputs, output_h
|
||||
"""
|
||||
return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, CUDNN_GRU,
|
||||
input_mode, direction, dropout, seed, name)
|
||||
|
||||
|
||||
def cudnn_rnn_relu(inputs,
|
||||
input_h,
|
||||
params,
|
||||
is_training,
|
||||
input_mode=CUDNN_INPUT_LINEAR_MODE,
|
||||
direction=CUDNN_RNN_UNIDIRECTION,
|
||||
dropout=0.,
|
||||
seed=0,
|
||||
name=None):
|
||||
"""Cudnn RNN Relu.
|
||||
|
||||
Args:
|
||||
inputs: the input sequence to the RNN model. A Tensor of shape [?,
|
||||
batch_size, input_size].
|
||||
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
|
||||
batch_size, num_units].
|
||||
params: the parameter buffer created for this model.
|
||||
is_training: whether this operation will be used in training or inference
|
||||
input_mode: indicate whether there is a linear projection between the
|
||||
input and the actual computation before the first layer. It could be
|
||||
'linear_input', 'skip_input' or 'auto_select'.
|
||||
'linear_input' (default) always applies a linear projection of input
|
||||
onto RNN hidden state. (standard RNN behavior).
|
||||
'skip_input' is only allowed when input_size == num_units;
|
||||
'auto_select' implies 'skip_input' when input_size == num_units;
|
||||
otherwise, it implies 'linear_input'.
|
||||
direction: the direction model that the model operates. Could be either
|
||||
'unidirectional' or 'bidirectional'
|
||||
dropout: whether to enable dropout. With it is 0, dropout is disabled.
|
||||
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
|
||||
for behavior.
|
||||
name: name of the operation.
|
||||
Returns:
|
||||
outputs, output_h
|
||||
"""
|
||||
return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training,
|
||||
CUDNN_RNN_RELU, input_mode, direction, dropout,
|
||||
seed, name)
|
||||
|
||||
|
||||
def cudnn_rnn_tanh(inputs,
|
||||
input_h,
|
||||
params,
|
||||
is_training,
|
||||
input_mode=CUDNN_INPUT_LINEAR_MODE,
|
||||
direction=CUDNN_RNN_UNIDIRECTION,
|
||||
dropout=0.,
|
||||
seed=0,
|
||||
name=None):
|
||||
"""Cudnn RNN Tanh.
|
||||
|
||||
Args:
|
||||
inputs: the input sequence to the RNN model. A Tensor of shape [?,
|
||||
batch_size, input_size].
|
||||
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
|
||||
batch_size, num_units].
|
||||
params: the parameter buffer created for this model.
|
||||
is_training: whether this operation will be used in training or inference
|
||||
input_mode: indicate whether there is a linear projection between the
|
||||
input and the actual computation before the first layer. It could be
|
||||
'linear_input', 'skip_input' or 'auto_select'.
|
||||
'linear_input' (default) always applies a linear projection of input
|
||||
onto RNN hidden state. (standard RNN behavior).
|
||||
'skip_input' is only allowed when input_size == num_units;
|
||||
'auto_select' implies 'skip_input' when input_size == num_units;
|
||||
otherwise, it implies 'linear_input'.
|
||||
direction: the direction model that the model operates. Could be either
|
||||
'unidirectional' or 'bidirectional'
|
||||
dropout: whether to enable dropout. With it is 0, dropout is disabled.
|
||||
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
|
||||
for behavior.
|
||||
name: name of the operation.
|
||||
Returns:
|
||||
outputs, output_h
|
||||
"""
|
||||
return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training,
|
||||
CUDNN_RNN_TANH, input_mode, direction, dropout,
|
||||
seed, name)
|
||||
|
||||
|
||||
def cudnn_rnn_params_to_canonical(rnn_mode,
|
||||
num_layers,
|
||||
num_units,
|
||||
input_size,
|
||||
params,
|
||||
input_mode=CUDNN_INPUT_LINEAR_MODE,
|
||||
direction=CUDNN_RNN_UNIDIRECTION,
|
||||
dropout=0,
|
||||
seed=0,
|
||||
name=None):
|
||||
"""Convert cudnn opaque params to canonical.
|
||||
|
||||
Args:
|
||||
rnn_mode: a string specifies the mode, under which this RNN model runs.
|
||||
Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'.
|
||||
num_layers: the number of layers for the RNN model.
|
||||
num_units: the number of units within the RNN model.
|
||||
input_size: the size of the input, it could be different from the
|
||||
num_units.
|
||||
params: opaque cudnn params var.
|
||||
input_mode: indicate whether there is a linear projection between the
|
||||
input and the actual computation before the first layer. It could be
|
||||
'linear_input', 'skip_input' or 'auto_select'.
|
||||
'linear_input' (default) always applies a linear projection of input
|
||||
onto RNN hidden state. (standard RNN behavior).
|
||||
'skip_input' is only allowed when input_size == num_units;
|
||||
'auto_select' implies 'skip_input' when input_size == num_units;
|
||||
otherwise, it implies 'linear_input'.
|
||||
direction: the direction model that the model operates. Could be either
|
||||
'unidirectional' or 'bidirectional'
|
||||
dropout: whether to enable dropout. With it is 0, dropout is disabled.
|
||||
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
|
||||
for behavior.
|
||||
name: name of the operation.
|
||||
Returns:
|
||||
weights list and bias list
|
||||
Raises:
|
||||
ValueError: if rnn_mode or direction is invalid.
|
||||
"""
|
||||
|
||||
_check_rnn_mode(rnn_mode)
|
||||
_check_direction(direction)
|
||||
num_params = _get_num_params(rnn_mode, num_layers, direction)
|
||||
seed, seed2 = random_seed.get_seed(seed)
|
||||
weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
|
||||
rnn_mode=rnn_mode,
|
||||
num_layers=num_layers,
|
||||
num_units=num_units,
|
||||
input_size=input_size,
|
||||
params=params,
|
||||
input_mode=input_mode,
|
||||
direction=direction,
|
||||
dropout=dropout,
|
||||
seed=seed,
|
||||
seed2=seed2,
|
||||
num_params=num_params,
|
||||
name=name)
|
||||
return weights, biases
|
||||
|
||||
|
||||
def cudnn_rnn_canonical_to_params(rnn_mode,
|
||||
num_layers,
|
||||
num_units,
|
||||
input_size,
|
||||
weights,
|
||||
biases,
|
||||
input_mode=CUDNN_INPUT_LINEAR_MODE,
|
||||
direction=CUDNN_RNN_UNIDIRECTION,
|
||||
dropout=0,
|
||||
seed=0,
|
||||
name=None):
|
||||
"""Converts params from the canonical format to a specific format of cuDNN.
|
||||
|
||||
Args:
|
||||
rnn_mode: a string specifies the mode, under which this RNN model runs.
|
||||
Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'.
|
||||
num_layers: the number of layers for the RNN model.
|
||||
num_units: the number of units within the RNN model.
|
||||
input_size: the size of the input, it could be different from the
|
||||
num_units.
|
||||
weights: a Tensor for weight parameters.
|
||||
biases: a Tensor for bias parameters.
|
||||
input_mode: indicate whether there is a linear projection between the
|
||||
input and the actual computation before the first layer. It could be
|
||||
'linear_input', 'skip_input' or 'auto_select'.
|
||||
'linear_input' (default) always applies a linear projection of input
|
||||
onto RNN hidden state. (standard RNN behavior).
|
||||
'skip_input' is only allowed when input_size == num_units;
|
||||
'auto_select' implies 'skip_input' when input_size == num_units;
|
||||
otherwise, it implies 'linear_input'.
|
||||
direction: the direction model that the model operates. Could be either
|
||||
'unidirectional' or 'bidirectional'
|
||||
dropout: whether to enable dropout. With it is 0, dropout is disabled.
|
||||
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
|
||||
for behavior.
|
||||
name: name of the operation.
|
||||
Returns:
|
||||
an opaque Cudnn param.
|
||||
Raises:
|
||||
ValueError: if rnn_mode or direction is invalid.
|
||||
"""
|
||||
_check_rnn_mode(rnn_mode)
|
||||
_check_direction(direction)
|
||||
seed, seed2 = random_seed.get_seed(seed)
|
||||
return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
|
||||
rnn_mode=rnn_mode,
|
||||
num_layers=num_layers,
|
||||
num_units=num_units,
|
||||
input_size=input_size,
|
||||
weights=weights,
|
||||
biases=biases,
|
||||
input_mode=input_mode,
|
||||
direction=direction,
|
||||
dropout=dropout,
|
||||
seed=seed,
|
||||
seed2=seed2,
|
||||
name=name)
|
||||
|
||||
|
||||
def cudnn_opaque_params_size(rnn_mode,
|
||||
num_layers,
|
||||
num_units,
|
||||
input_size,
|
||||
input_mode=CUDNN_INPUT_LINEAR_MODE,
|
||||
direction=CUDNN_RNN_UNIDIRECTION,
|
||||
dtype=dtypes.float32,
|
||||
dropout=0,
|
||||
seed=0,
|
||||
name=None):
|
||||
"""Returns opaque params size for specific Cudnn config.
|
||||
|
||||
Args:
|
||||
rnn_mode: a string specifies the mode, under which this RNN model runs.
|
||||
Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'.
|
||||
num_layers: the number of layers for the RNN model.
|
||||
num_units: the number of units within the RNN model.
|
||||
input_size: the size of the input, it could be different from the
|
||||
num_units.
|
||||
input_mode: indicate whether there is a linear projection between the
|
||||
input and the actual computation before the first layer. It could be
|
||||
'linear_input', 'skip_input' or 'auto_select'.
|
||||
'linear_input' (default) always applies a linear projection of input
|
||||
onto RNN hidden state. (standard RNN behavior).
|
||||
'skip_input' is only allowed when input_size == num_units;
|
||||
'auto_select' implies 'skip_input' when input_size == num_units;
|
||||
otherwise, it implies 'linear_input'.
|
||||
direction: the direction model that the model operates. Could be either
|
||||
'unidirectional' or 'bidirectional'
|
||||
dtype: one of tf.float32 or tf.float64.
|
||||
dropout: whether to enable dropout. With it is 0, dropout is disabled.
|
||||
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
|
||||
for behavior.
|
||||
name: name of the operation.
|
||||
Returns:
|
||||
a int, size of Cudnn opaque params.
|
||||
Raises:
|
||||
ValueError: if rnn_mode or direction is invalid.
|
||||
"""
|
||||
_check_rnn_mode(rnn_mode)
|
||||
_check_direction(direction)
|
||||
seed, seed2 = random_seed.get_seed(seed)
|
||||
return gen_cudnn_rnn_ops.cudnn_rnn_params_size(
|
||||
rnn_mode=rnn_mode,
|
||||
num_layers=num_layers,
|
||||
num_units=num_units,
|
||||
input_size=input_size,
|
||||
T=dtype,
|
||||
S=dtypes.int32,
|
||||
dropout=dropout,
|
||||
seed=seed,
|
||||
seed2=seed2,
|
||||
input_mode=input_mode,
|
||||
direction=direction,
|
||||
name=name)[0]
|
||||
|
||||
|
||||
class _CudnnRNN(object):
|
||||
"""Creates an RNN model using the underlying Cudnn implementation.
|
||||
|
||||
@ -761,9 +1237,6 @@ class _CudnnRNN(object):
|
||||
Raises:
|
||||
ValueError: if direction is invalid.
|
||||
"""
|
||||
if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION):
|
||||
raise ValueError("Invalid direction: %s, expect %s or %s",
|
||||
direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION)
|
||||
self._num_layers = num_layers
|
||||
self._num_units = num_units
|
||||
self._input_size = input_size
|
||||
@ -772,10 +1245,7 @@ class _CudnnRNN(object):
|
||||
self._direction = direction
|
||||
self._dtype = dtype
|
||||
self._dropout = dropout
|
||||
# get graph and op seed.
|
||||
self._seed, self._seed2 = random_seed.get_seed(seed)
|
||||
if self._seed is None and self._seed2 is None:
|
||||
self._seed, self._seed2 = 0, 0
|
||||
self._seed = seed
|
||||
|
||||
@property
|
||||
def input_mode(self):
|
||||
@ -807,18 +1277,16 @@ class _CudnnRNN(object):
|
||||
Returns:
|
||||
The calculated parameter buffer size.
|
||||
"""
|
||||
return gen_cudnn_rnn_ops.cudnn_rnn_params_size(
|
||||
return cudnn_opaque_params_size(
|
||||
rnn_mode=self._rnn_mode,
|
||||
num_layers=self._num_layers,
|
||||
num_units=self._num_units,
|
||||
input_size=self._input_size,
|
||||
T=self._dtype,
|
||||
S=dtypes.int32,
|
||||
dtype=self._dtype,
|
||||
dropout=self._dropout,
|
||||
seed=self._seed,
|
||||
seed2=self._seed2,
|
||||
rnn_mode=self._rnn_mode,
|
||||
input_mode=self._input_mode,
|
||||
direction=self._direction)[0]
|
||||
direction=self._direction)
|
||||
|
||||
def __call__(self, input_data, input_h, input_c, params, is_training=True):
|
||||
"""Runs the forward step for the RNN model.
|
||||
@ -837,22 +1305,17 @@ class _CudnnRNN(object):
|
||||
output_h: the final state for h.
|
||||
output_c: the final state for c. This is only relevant for LSTM.
|
||||
"""
|
||||
if self._rnn_mode != CUDNN_LSTM:
|
||||
# For model that doesn't take input_c, replace with a dummy tensor.
|
||||
input_c = array_ops.constant([], dtype=self._dtype)
|
||||
output, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
|
||||
input=input_data,
|
||||
input_h=input_h,
|
||||
input_c=input_c,
|
||||
params=params,
|
||||
rnn_mode=self._rnn_mode,
|
||||
return _cudnn_rnn(
|
||||
input_data,
|
||||
input_h,
|
||||
input_c,
|
||||
params,
|
||||
is_training,
|
||||
self._rnn_mode,
|
||||
input_mode=self._input_mode,
|
||||
direction=self._direction,
|
||||
dropout=self._dropout,
|
||||
seed=self._seed,
|
||||
seed2=self._seed2,
|
||||
is_training=is_training)
|
||||
return (output, output_h, output_c)
|
||||
seed=self._seed)
|
||||
|
||||
def params_to_canonical(self, params):
|
||||
"""Converts params from a specific format of cuDNN to the canonical format.
|
||||
@ -863,22 +1326,16 @@ class _CudnnRNN(object):
|
||||
Returns:
|
||||
A function for the specific-to-canonical conversion.
|
||||
"""
|
||||
num_params = self._num_layers * self._NUM_PARAMS_PER_LAYER
|
||||
if self._direction != CUDNN_RNN_UNIDIRECTION:
|
||||
num_params *= 2
|
||||
weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
|
||||
return cudnn_rnn_params_to_canonical(
|
||||
rnn_mode=self._rnn_mode,
|
||||
num_layers=self._num_layers,
|
||||
num_units=self._num_units,
|
||||
input_size=self._input_size,
|
||||
params=params,
|
||||
dropout=self._dropout,
|
||||
seed=self._seed,
|
||||
seed2=self._seed2,
|
||||
num_params=num_params,
|
||||
rnn_mode=self._rnn_mode,
|
||||
input_mode=self._input_mode,
|
||||
direction=self._direction)
|
||||
return weights, biases
|
||||
direction=self._direction,
|
||||
dropout=self._dropout,
|
||||
seed=self._seed)
|
||||
|
||||
def canonical_to_params(self, weights, biases):
|
||||
"""Converts params from the canonical format to a specific format of cuDNN.
|
||||
@ -890,18 +1347,17 @@ class _CudnnRNN(object):
|
||||
Returns:
|
||||
A function for the canonical-to-params-to-specific conversion..
|
||||
"""
|
||||
return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
|
||||
return cudnn_rnn_canonical_to_params(
|
||||
rnn_mode=self._rnn_mode,
|
||||
num_layers=self._num_layers,
|
||||
num_units=self._num_units,
|
||||
input_size=self._input_size,
|
||||
weights=weights,
|
||||
biases=biases,
|
||||
dropout=self._dropout,
|
||||
seed=self._seed,
|
||||
seed2=self._seed2,
|
||||
rnn_mode=self._rnn_mode,
|
||||
input_mode=self._input_mode,
|
||||
direction=self._direction)
|
||||
direction=self._direction,
|
||||
dropout=self._dropout,
|
||||
seed=self._seed)
|
||||
|
||||
|
||||
class CudnnLSTM(_CudnnRNN):
|
||||
@ -1036,9 +1492,16 @@ class _CudnnRNNNoInputC(_CudnnRNN):
|
||||
output: the output sequuence.
|
||||
output_h: the final state for h.
|
||||
"""
|
||||
output, output_h, _ = super(_CudnnRNNNoInputC, self).__call__(
|
||||
input_data, input_h, None, params, is_training=is_training)
|
||||
return (output, output_h)
|
||||
return _cudnn_rnn_no_input_c(
|
||||
input_data,
|
||||
input_h,
|
||||
params,
|
||||
is_training,
|
||||
self._rnn_mode,
|
||||
input_mode=self._input_mode,
|
||||
direction=self._direction,
|
||||
dropout=self._dropout,
|
||||
seed=self._seed)
|
||||
|
||||
|
||||
class CudnnGRU(_CudnnRNNNoInputC):
|
||||
|
@ -22,6 +22,7 @@
|
||||
|
||||
@@read_batch_features
|
||||
@@rejection_resample
|
||||
@@group_by_window
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
@ -31,6 +32,7 @@ from __future__ import print_function
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.contrib.data.python.ops.dataset_ops import Dataset
|
||||
from tensorflow.contrib.data.python.ops.dataset_ops import FixedLengthRecordDataset
|
||||
from tensorflow.contrib.data.python.ops.dataset_ops import group_by_window
|
||||
from tensorflow.contrib.data.python.ops.dataset_ops import Iterator
|
||||
from tensorflow.contrib.data.python.ops.dataset_ops import read_batch_features
|
||||
from tensorflow.contrib.data.python.ops.dataset_ops import rejection_resample
|
||||
|
@ -37,7 +37,9 @@ class GroupByWindowTest(test.TestCase):
|
||||
components = np.random.randint(100, size=(200,)).astype(np.int64)
|
||||
iterator = dataset_ops.Iterator.from_dataset(
|
||||
dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x)
|
||||
.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), 4))
|
||||
.apply(
|
||||
dataset_ops.group_by_window,
|
||||
args=(lambda x: x % 2, lambda _, xs: xs.batch(4), 4)))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
@ -61,8 +63,9 @@ class GroupByWindowTest(test.TestCase):
|
||||
components = np.array(
|
||||
[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
|
||||
iterator = dataset_ops.Iterator.from_dataset(
|
||||
dataset_ops.Dataset.from_tensor_slices(components).repeat(-1)
|
||||
.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4))
|
||||
dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
|
||||
dataset_ops.group_by_window,
|
||||
args=(lambda x: x % 3, lambda _, xs: xs.batch(4), 4)))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
@ -81,8 +84,9 @@ class GroupByWindowTest(test.TestCase):
|
||||
def testSmallGroups(self):
|
||||
components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
|
||||
iterator = dataset_ops.Iterator.from_dataset(
|
||||
dataset_ops.Dataset.from_tensor_slices(components)
|
||||
.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), 4))
|
||||
dataset_ops.Dataset.from_tensor_slices(components).apply(
|
||||
dataset_ops.group_by_window,
|
||||
args=(lambda x: x % 2, lambda _, xs: xs.batch(4), 4)))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
@ -108,8 +112,9 @@ class GroupByWindowTest(test.TestCase):
|
||||
|
||||
iterator = dataset_ops.Iterator.from_dataset(
|
||||
dataset_ops.Dataset.from_tensor_slices(components)
|
||||
.map(lambda x: (x, ops.convert_to_tensor([x * x])))
|
||||
.group_by_window(lambda x, _: x % 2, reduce_func, 32))
|
||||
.map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply(
|
||||
dataset_ops.group_by_window,
|
||||
args=(lambda x, _: x % 2, reduce_func, 32)))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
@ -124,17 +129,20 @@ class GroupByWindowTest(test.TestCase):
|
||||
def reduce_func(key, window):
|
||||
# Apply two different kinds of padding to the input: tight
|
||||
# padding, and quantized (to a multiple of 10) padding.
|
||||
return dataset_ops.Dataset.zip((window.padded_batch(
|
||||
4,
|
||||
padded_shapes=tensor_shape.TensorShape([None])), window.padded_batch(
|
||||
return dataset_ops.Dataset.zip((
|
||||
window.padded_batch(
|
||||
4, padded_shapes=tensor_shape.TensorShape([None])),
|
||||
window.padded_batch(
|
||||
4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),))
|
||||
|
||||
iterator = dataset_ops.Iterator.from_dataset(
|
||||
dataset_ops.Dataset.from_tensor_slices(components)
|
||||
.map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
|
||||
.group_by_window(
|
||||
lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
|
||||
reduce_func, 4))
|
||||
.apply(
|
||||
dataset_ops.group_by_window,
|
||||
args=
|
||||
(lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
|
||||
reduce_func, 4)))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
@ -151,10 +159,9 @@ class GroupByWindowTest(test.TestCase):
|
||||
self.assertEqual(len(components), sum(counts))
|
||||
|
||||
|
||||
# NOTE(mrry): These tests are based on the tests in
|
||||
# bucket_ops_test.py. Currently, different batch sizes for each key
|
||||
# are not supported, although this would be possible to add to
|
||||
# `Dataset.group_by_window()`.
|
||||
# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
|
||||
# Currently, they use a constant batch size, though should be made to use a
|
||||
# different batch size per key.
|
||||
class BucketTest(test.TestCase):
|
||||
|
||||
def _dynamicPad(self, bucket, window, window_size):
|
||||
@ -168,6 +175,7 @@ class BucketTest(test.TestCase):
|
||||
tensor_shape.TensorShape([3])))))
|
||||
|
||||
def testSingleBucket(self):
|
||||
|
||||
def _map_fn(v):
|
||||
return (v, array_ops.fill([v], v),
|
||||
array_ops.fill([3], string_ops.as_string(v)))
|
||||
@ -175,9 +183,10 @@ class BucketTest(test.TestCase):
|
||||
input_dataset = (
|
||||
dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn))
|
||||
|
||||
bucketed_dataset = input_dataset.group_by_window(
|
||||
lambda x, y, z: 0, lambda k, bucket: self._dynamicPad(k, bucket, 32),
|
||||
32)
|
||||
bucketed_dataset = input_dataset.apply(
|
||||
dataset_ops.group_by_window,
|
||||
args=(lambda x, y, z: 0,
|
||||
lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
|
||||
|
||||
iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset)
|
||||
init_op = iterator.initializer
|
||||
@ -201,6 +210,7 @@ class BucketTest(test.TestCase):
|
||||
self.assertAllEqual(expected_vec3_str, bucketed_values[2])
|
||||
|
||||
def testEvenOddBuckets(self):
|
||||
|
||||
def _map_fn(v):
|
||||
return (v, array_ops.fill([v], v),
|
||||
array_ops.fill([3], string_ops.as_string(v)))
|
||||
@ -208,9 +218,10 @@ class BucketTest(test.TestCase):
|
||||
input_dataset = (
|
||||
dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn))
|
||||
|
||||
bucketed_dataset = input_dataset.group_by_window(
|
||||
lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
|
||||
lambda k, bucket: self._dynamicPad(k, bucket, 32), 32)
|
||||
bucketed_dataset = input_dataset.apply(
|
||||
dataset_ops.group_by_window,
|
||||
args=(lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
|
||||
lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
|
||||
|
||||
iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset)
|
||||
init_op = iterator.initializer
|
||||
@ -256,25 +267,31 @@ class BucketTest(test.TestCase):
|
||||
self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
|
||||
|
||||
def testEvenOddBucketsFilterOutAllOdd(self):
|
||||
|
||||
def _map_fn(v):
|
||||
return {"x": v,
|
||||
"y": array_ops.fill([v], v),
|
||||
"z": array_ops.fill([3], string_ops.as_string(v))}
|
||||
return {
|
||||
"x": v,
|
||||
"y": array_ops.fill([v], v),
|
||||
"z": array_ops.fill([3], string_ops.as_string(v))
|
||||
}
|
||||
|
||||
def _dynamic_pad_fn(bucket, window, _):
|
||||
return dataset_ops.Dataset.zip(
|
||||
(dataset_ops.Dataset.from_tensors(bucket), window.padded_batch(
|
||||
32, {"x": tensor_shape.TensorShape([]),
|
||||
"y": tensor_shape.TensorShape([None]),
|
||||
"z": tensor_shape.TensorShape([3])})))
|
||||
32, {
|
||||
"x": tensor_shape.TensorShape([]),
|
||||
"y": tensor_shape.TensorShape([None]),
|
||||
"z": tensor_shape.TensorShape([3])
|
||||
})))
|
||||
|
||||
input_dataset = (
|
||||
dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
|
||||
.filter(lambda d: math_ops.equal(d["x"] % 2, 0)))
|
||||
|
||||
bucketed_dataset = input_dataset.group_by_window(
|
||||
lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
|
||||
lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32)
|
||||
bucketed_dataset = input_dataset.apply(
|
||||
dataset_ops.group_by_window,
|
||||
args=(lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
|
||||
lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32))
|
||||
|
||||
iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset)
|
||||
init_op = iterator.initializer
|
||||
@ -295,6 +312,40 @@ class BucketTest(test.TestCase):
|
||||
self.assertAllEqual(
|
||||
np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
|
||||
|
||||
def testDynamicWindowSize(self):
|
||||
components = np.arange(100).astype(np.int64)
|
||||
|
||||
# Key fn: even/odd
|
||||
# Reduce fn: batches of 5
|
||||
# Window size fn: even=5, odd=10
|
||||
|
||||
def window_size_func(key):
|
||||
window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64)
|
||||
return window_sizes[key]
|
||||
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
|
||||
dataset_ops.group_by_window,
|
||||
args=(lambda x: x % 2, lambda _, xs: xs.batch(20), None,
|
||||
window_size_func))
|
||||
iterator = dataset_ops.Iterator.from_dataset(dataset)
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
batches = 0
|
||||
while True:
|
||||
result = sess.run(get_next)
|
||||
is_even = all(x % 2 == 0 for x in result)
|
||||
is_odd = all(x % 2 == 1 for x in result)
|
||||
self.assertTrue(is_even or is_odd)
|
||||
expected_batch_size = 5 if is_even else 10
|
||||
self.assertEqual(expected_batch_size, result.shape[0])
|
||||
batches += 1
|
||||
|
||||
self.assertEqual(batches, 15)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -1199,28 +1199,9 @@ class Dataset(object):
|
||||
return DenseToSparseBatchDataset(self, batch_size, row_shape)
|
||||
|
||||
def group_by_window(self, key_func, reduce_func, window_size):
|
||||
"""Performs a windowed "group-by" operation on this dataset.
|
||||
|
||||
This method maps each consecutive element in this dataset to a key
|
||||
using `key_func` and groups the elements by key. It then applies
|
||||
`reduce_func` to at most `window_size` elements matching the same
|
||||
key. All execpt the final window for each key will contain
|
||||
`window_size` elements; the final window may be smaller.
|
||||
|
||||
Args:
|
||||
key_func: A function mapping a nested structure of tensors
|
||||
(having shapes and types defined by `self.output_shapes` and
|
||||
`self.output_types`) to a scalar `tf.int64` tensor.
|
||||
reduce_func: A function mapping a key and a dataset of up to `batch_size`
|
||||
consecutive elements matching that key to another dataset.
|
||||
window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
|
||||
consecutive elements matching the same key to combine in a single
|
||||
batch, which will be passed to `reduce_func`.
|
||||
|
||||
Returns:
|
||||
A `Dataset`.
|
||||
"""
|
||||
return GroupByWindowDataset(self, key_func, reduce_func, window_size)
|
||||
"""See group_by_window()."""
|
||||
return self.apply(
|
||||
group_by_window, args=(key_func, reduce_func, window_size))
|
||||
|
||||
def map(self,
|
||||
map_func,
|
||||
@ -1370,6 +1351,43 @@ class Dataset(object):
|
||||
"""
|
||||
return FilterDataset(self, predicate)
|
||||
|
||||
def apply(self, fn, args=(), kwargs={}): # pylint: disable=dangerous-default-value
|
||||
"""Apply a function to this dataset.
|
||||
|
||||
`apply` enables chaining of custom `Dataset` transformations.
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
dataset.map(
|
||||
lambda x: x**2
|
||||
).apply(
|
||||
group_by_window, args=(key_func, reduce_func, window_size)
|
||||
).map(
|
||||
lambda x: x**3
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
fn: A function that takes a `Dataset`, `args`, and `kwargs`, and
|
||||
returns a `Dataset`.
|
||||
args: A `tuple` or `list` of arguments to be passed to `fn`.
|
||||
kwargs: A `dict` of keyword arguments to be passed to `fn`.
|
||||
|
||||
Returns:
|
||||
The `Dataset` returned by `fn`.
|
||||
"""
|
||||
if not (isinstance(args, tuple) or isinstance(args, list)):
|
||||
raise TypeError("args must be a tuple or list.")
|
||||
if not isinstance(kwargs, dict):
|
||||
raise TypeError("kwargs must be a dict.")
|
||||
|
||||
dataset = fn(self, *args, **kwargs)
|
||||
|
||||
if not isinstance(dataset, Dataset):
|
||||
raise TypeError("fn must return a Dataset.")
|
||||
return dataset
|
||||
|
||||
|
||||
class TensorDataset(Dataset):
|
||||
"""A `Dataset` with a single element, viz. a nested structure of tensors."""
|
||||
@ -1927,71 +1945,6 @@ class _ResourceDataset(Dataset):
|
||||
return self._output_types
|
||||
|
||||
|
||||
class GroupByWindowDataset(Dataset):
|
||||
"""A `Dataset` that groups its input and performs a windowed reduction."""
|
||||
|
||||
def __init__(self, input_dataset, key_func, reduce_func, window_size):
|
||||
"""See `Dataset.group_by_window()` for details."""
|
||||
super(GroupByWindowDataset, self).__init__()
|
||||
self._input_dataset = input_dataset
|
||||
self._window_size = window_size
|
||||
|
||||
@function.Defun(*nest.flatten(input_dataset.output_types))
|
||||
def tf_key_func(*args):
|
||||
"""A wrapper for Defun that facilitates shape inference."""
|
||||
# Pass in shape information from the input_dataset.
|
||||
for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
|
||||
arg.set_shape(shape)
|
||||
nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
|
||||
if _should_unpack_args(nested_args):
|
||||
ret = key_func(*nested_args)
|
||||
else:
|
||||
ret = key_func(nested_args)
|
||||
ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
|
||||
if ret.dtype != dtypes.int64:
|
||||
raise ValueError("`key_func` must return a single tf.int64 tensor.")
|
||||
return ret
|
||||
|
||||
self._key_func = tf_key_func
|
||||
self._key_func.add_to_graph(ops.get_default_graph())
|
||||
|
||||
@function.Defun(dtypes.int64, dtypes.resource)
|
||||
def tf_reduce_func(key, window_dataset_resource):
|
||||
"""A wrapper for Defun that facilitates shape inference."""
|
||||
key.set_shape([])
|
||||
window_dataset = _ResourceDataset(window_dataset_resource,
|
||||
input_dataset.output_types,
|
||||
input_dataset.output_shapes)
|
||||
output_dataset = reduce_func(key, window_dataset)
|
||||
if not isinstance(output_dataset, Dataset):
|
||||
raise TypeError("`reduce_func` must return a `Dataset` object.")
|
||||
self._output_types = output_dataset.output_types
|
||||
self._output_shapes = output_dataset.output_shapes
|
||||
return output_dataset.make_dataset_resource()
|
||||
|
||||
self._reduce_func = tf_reduce_func
|
||||
self._reduce_func.add_to_graph(ops.get_default_graph())
|
||||
|
||||
def make_dataset_resource(self):
|
||||
return gen_dataset_ops.group_by_window_dataset(
|
||||
self._input_dataset.make_dataset_resource(),
|
||||
self._key_func.captured_inputs,
|
||||
self._reduce_func.captured_inputs,
|
||||
self._window_size,
|
||||
key_func=self._key_func,
|
||||
reduce_func=self._reduce_func,
|
||||
output_types=nest.flatten(self.output_types),
|
||||
output_shapes=nest.flatten(self.output_shapes))
|
||||
|
||||
@property
|
||||
def output_shapes(self):
|
||||
return self._output_shapes
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return self._output_types
|
||||
|
||||
|
||||
class MapDataset(Dataset):
|
||||
"""A `Dataset` that maps a function over elements in its input."""
|
||||
|
||||
@ -2660,3 +2613,149 @@ def _get_file_names(file_pattern, randomize_input):
|
||||
if not randomize_input:
|
||||
file_names = sorted(file_names)
|
||||
return file_names
|
||||
|
||||
|
||||
class GroupByWindowDataset(Dataset):
|
||||
"""A `Dataset` that groups its input and performs a windowed reduction."""
|
||||
|
||||
def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
|
||||
"""See `group_by_window()` for details."""
|
||||
super(GroupByWindowDataset, self).__init__()
|
||||
|
||||
self._input_dataset = input_dataset
|
||||
|
||||
self._make_key_func(key_func, input_dataset)
|
||||
self._make_reduce_func(reduce_func, input_dataset)
|
||||
self._make_window_size_func(window_size_func)
|
||||
|
||||
def _make_window_size_func(self, window_size_func):
|
||||
"""Make wrapping Defun for window_size_func."""
|
||||
|
||||
@function.Defun(dtypes.int64)
|
||||
def tf_window_size_func(key):
|
||||
key.set_shape([])
|
||||
window_size = ops.convert_to_tensor(
|
||||
window_size_func(key), dtype=dtypes.int64)
|
||||
if window_size.dtype != dtypes.int64:
|
||||
raise ValueError(
|
||||
"`window_size_func` must return a single tf.int64 tensor.")
|
||||
return window_size
|
||||
|
||||
self._window_size_func = tf_window_size_func
|
||||
self._window_size_func.add_to_graph(ops.get_default_graph())
|
||||
|
||||
def _make_key_func(self, key_func, input_dataset):
|
||||
"""Make wrapping Defun for key_func."""
|
||||
|
||||
@function.Defun(*nest.flatten(input_dataset.output_types))
|
||||
def tf_key_func(*args):
|
||||
"""A wrapper for Defun that facilitates shape inference."""
|
||||
# Pass in shape information from the input_dataset.
|
||||
for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
|
||||
arg.set_shape(shape)
|
||||
nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
|
||||
if _should_unpack_args(nested_args):
|
||||
ret = key_func(*nested_args)
|
||||
else:
|
||||
ret = key_func(nested_args)
|
||||
ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
|
||||
if ret.dtype != dtypes.int64:
|
||||
raise ValueError("`key_func` must return a single tf.int64 tensor.")
|
||||
return ret
|
||||
|
||||
self._key_func = tf_key_func
|
||||
self._key_func.add_to_graph(ops.get_default_graph())
|
||||
|
||||
def _make_reduce_func(self, reduce_func, input_dataset):
|
||||
"""Make wrapping Defun for reduce_func."""
|
||||
|
||||
@function.Defun(dtypes.int64, dtypes.resource)
|
||||
def tf_reduce_func(key, window_dataset_resource):
|
||||
"""A wrapper for Defun that facilitates shape inference."""
|
||||
key.set_shape([])
|
||||
window_dataset = _ResourceDataset(window_dataset_resource,
|
||||
input_dataset.output_types,
|
||||
input_dataset.output_shapes)
|
||||
output_dataset = reduce_func(key, window_dataset)
|
||||
if not isinstance(output_dataset, Dataset):
|
||||
raise TypeError("`reduce_func` must return a `Dataset` object.")
|
||||
self._output_types = output_dataset.output_types
|
||||
self._output_shapes = output_dataset.output_shapes
|
||||
return output_dataset.make_dataset_resource()
|
||||
|
||||
self._reduce_func = tf_reduce_func
|
||||
self._reduce_func.add_to_graph(ops.get_default_graph())
|
||||
|
||||
@property
|
||||
def output_shapes(self):
|
||||
return self._output_shapes
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return self._output_types
|
||||
|
||||
def make_dataset_resource(self):
|
||||
return gen_dataset_ops.group_by_window_dataset(
|
||||
self._input_dataset.make_dataset_resource(),
|
||||
self._key_func.captured_inputs,
|
||||
self._reduce_func.captured_inputs,
|
||||
self._window_size_func.captured_inputs,
|
||||
key_func=self._key_func,
|
||||
reduce_func=self._reduce_func,
|
||||
window_size_func=self._window_size_func,
|
||||
output_types=nest.flatten(self.output_types),
|
||||
output_shapes=nest.flatten(self.output_shapes))
|
||||
|
||||
|
||||
def group_by_window(dataset,
|
||||
key_func,
|
||||
reduce_func,
|
||||
window_size=None,
|
||||
window_size_func=None):
|
||||
"""Performs a windowed "group-by" operation on this dataset.
|
||||
|
||||
This method maps each consecutive element in this dataset to a key
|
||||
using `key_func` and groups the elements by key. It then applies
|
||||
`reduce_func` to at most `window_size_func(key)` elements matching the same
|
||||
key. All execpt the final window for each key will contain
|
||||
`window_size_func(key)` elements; the final window may be smaller.
|
||||
|
||||
You may provide either a constant `window_size` or a window size determined by
|
||||
the key through `window_size_func`.
|
||||
|
||||
Args:
|
||||
dataset: A `Dataset`.
|
||||
key_func: A function mapping a nested structure of tensors
|
||||
(having shapes and types defined by `self.output_shapes` and
|
||||
`self.output_types`) to a scalar `tf.int64` tensor.
|
||||
reduce_func: A function mapping a key and a dataset of up to `batch_size`
|
||||
consecutive elements matching that key to another dataset.
|
||||
window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
|
||||
consecutive elements matching the same key to combine in a single
|
||||
batch, which will be passed to `reduce_func`. Mutually exclusive with
|
||||
`window_size_func`.
|
||||
window_size_func: A function mapping a key to a `tf.int64` scalar
|
||||
`tf.Tensor`, representing the number of consecutive elements matching
|
||||
the same key to combine in a single batch, which will be passed to
|
||||
`reduce_func`. Mutually exclusive with `window_size`.
|
||||
|
||||
Returns:
|
||||
A `Dataset`.
|
||||
|
||||
Raises:
|
||||
ValueError: if neither or both of {`window_size`, `window_size_func`} are
|
||||
passed.
|
||||
"""
|
||||
if (window_size is not None and window_size_func or
|
||||
not (window_size is not None or window_size_func)):
|
||||
raise ValueError("Must pass either window_size or window_size_func.")
|
||||
|
||||
if window_size is not None:
|
||||
|
||||
def constant_window_func(unused_key):
|
||||
return ops.convert_to_tensor(window_size, dtype=dtypes.int64)
|
||||
|
||||
window_size_func = constant_window_func
|
||||
|
||||
assert window_size_func is not None
|
||||
return GroupByWindowDataset(dataset, key_func, reduce_func, window_size_func)
|
||||
|
@ -341,7 +341,7 @@ cuda_py_test(
|
||||
|
||||
cuda_py_test(
|
||||
name = "sample_stats_test",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["python/kernel_tests/sample_stats_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
|
@ -17,440 +17,16 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_checkpoint_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
ops.NotDifferentiable("GenerateVocabRemapping")
|
||||
ops.NotDifferentiable("LoadAndRemapMatrix")
|
||||
from tensorflow.python.training import checkpoint_ops
|
||||
|
||||
|
||||
def _load_and_remap_matrix(ckpt_path,
|
||||
old_tensor_name,
|
||||
new_row_vocab_offset,
|
||||
num_rows_to_load,
|
||||
new_col_vocab_size,
|
||||
initializer,
|
||||
old_row_vocab_file=None,
|
||||
new_row_vocab_file=None,
|
||||
old_col_vocab_file=None,
|
||||
new_col_vocab_file=None,
|
||||
num_row_oov_buckets=0,
|
||||
num_col_oov_buckets=0,
|
||||
max_rows_in_memory=-1):
|
||||
"""Loads a 2-D (matrix) `Tensor` from checkpoint.
|
||||
|
||||
Generates 1D-remappings for rows and columns using the
|
||||
`GenerateVocabRemapping` op, and initializes any anticipated values with the
|
||||
provided initializer. Then, uses the `LoadAndRemapMatrix` op to create a
|
||||
matrix that loads existing values from the checkpoint, while filling out
|
||||
"missing" values with the newly initialized values. See
|
||||
contrib/framework/ops/checkpoint_ops.cc for more information on the wrapped
|
||||
functionality (LoadAndRemapMatrix). This wrapper can be used to perform only
|
||||
row remapping or only col remapping. If only row remapping is desired,
|
||||
{new,old}_col_vocab_file should be `None`, and vice versa for column
|
||||
remapping.
|
||||
|
||||
NOTE: This only supports div-partitioning the vocabulary on the 1st dimension
|
||||
(row axis) via `new_row_vocab_offset`.
|
||||
|
||||
Args:
|
||||
ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
|
||||
from which the old matrix `Tensor` will be loaded.
|
||||
old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
|
||||
new_row_vocab_offset: A 0-indexed integer representing what line to
|
||||
start reading at in the new row vocabulary. Used for partitioned
|
||||
variables.
|
||||
num_rows_to_load: Number of rows to load for the new vocabulary (note: to
|
||||
support variable partitioning and partial loading, this does not need to
|
||||
be the same as the number of entries in `new_row_vocab_file`).
|
||||
new_col_vocab_size: Number of columns to load - should be the same as the
|
||||
number of entries in `new_col_vocab_file`, since we don't support
|
||||
partitioning along the column axis.
|
||||
initializer: Callable initializer function that accepts a 1-D tensor as the
|
||||
arg to specify the shape of the returned tensor. Used to initialize
|
||||
missing values.
|
||||
old_row_vocab_file: A scalar `Tensor` of type `string` containing the
|
||||
path to the old row vocabulary file. Can be None, which represents no
|
||||
remapping on the row axis.
|
||||
new_row_vocab_file: A scalar `Tensor` of type `string` containing the path
|
||||
to the new row vocabulary file. Can be None, which represents no remapping
|
||||
on the row axis - in which case, `new_row_vocab_offset` and
|
||||
`num_rows_to_load` work under the assumption that the new row vocab is the
|
||||
same as the old row vocab.
|
||||
old_col_vocab_file: A scalar `Tensor` of type `string` containing the
|
||||
path to the old column vocabulary file. Can be None, which represents no
|
||||
remapping on the column axis.
|
||||
new_col_vocab_file: A scalar `Tensor` of type `string` containing the path
|
||||
to the new column vocabulary file. Can be None, which represents no
|
||||
remapping on the column axis - in which case, `new_col_vocab_size` works
|
||||
under the assumption that the new col vocab is the same as the old col
|
||||
vocab.
|
||||
num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows
|
||||
to append. Must be >= 0.
|
||||
num_col_oov_buckets: `int` specifying the number of out-of-vocabulary
|
||||
columns to append. Must be >= 0.
|
||||
max_rows_in_memory: `int` specifying the maximum number of rows to load from
|
||||
the checkpoint at once. If less than or equal to 0, the entire matrix will
|
||||
be loaded into memory. Setting this arg trades increased disk reads for
|
||||
lower memory usage.
|
||||
|
||||
Returns:
|
||||
A Tensor of shape `[num_rows_to_load + num_row_oov_buckets,
|
||||
new_col_vocab_size + num_col_oov_buckets]`, with values loaded from the
|
||||
specified tensor in the checkpoint, and any missing or OOV values
|
||||
initialized with the given `initializer`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `num_row_oov_buckets` or `num_col_oov_buckets` < 0.
|
||||
ValueError: If either `old_row_vocab_file` or `new_row_vocab_file` is
|
||||
provided, while the other is not. Same for `old_col_vocab_file` and
|
||||
`new_col_vocab_file`.
|
||||
ValueError: If neither row vocabs or col vocabs are provided.
|
||||
"""
|
||||
if num_row_oov_buckets < 0:
|
||||
raise ValueError("num_row_oov_buckets must be >= 0, but received %d" %
|
||||
num_row_oov_buckets)
|
||||
if num_col_oov_buckets < 0:
|
||||
raise ValueError("num_col_oov_buckets must be >= 0, but received %d" %
|
||||
num_col_oov_buckets)
|
||||
|
||||
if bool(old_row_vocab_file) != bool(new_row_vocab_file):
|
||||
raise ValueError(
|
||||
"old_row_vocab_file and new_row_vocab_file must both be specified or "
|
||||
"left unspecified. old_row_vocab_file='{}', new_row_vocab_file='{}'".
|
||||
format(old_row_vocab_file, new_row_vocab_file))
|
||||
if bool(old_col_vocab_file) != bool(new_col_vocab_file):
|
||||
raise ValueError(
|
||||
"old_col_vocab_file and new_col_vocab_file must both be specified or "
|
||||
"left unspecified. old_col_vocab_file='{}', new_col_vocab_file='{}'".
|
||||
format(old_col_vocab_file, new_col_vocab_file))
|
||||
|
||||
remap_rows = new_row_vocab_file and old_row_vocab_file
|
||||
remap_cols = new_col_vocab_file and old_col_vocab_file
|
||||
if not (remap_rows or remap_cols):
|
||||
raise ValueError(
|
||||
"Must provide either row or column vocab files. If no remapping is "
|
||||
"necessary, consider using `tf.contrib.framework.init_from_checkpoint` "
|
||||
"instead.")
|
||||
|
||||
num_rows_present = num_rows_to_load
|
||||
if remap_rows:
|
||||
row_remapping, num_rows_present = (
|
||||
gen_checkpoint_ops._generate_vocab_remapping( # pylint: disable=protected-access
|
||||
new_vocab_file=new_row_vocab_file,
|
||||
old_vocab_file=old_row_vocab_file,
|
||||
new_vocab_offset=new_row_vocab_offset,
|
||||
num_new_vocab=num_rows_to_load))
|
||||
else:
|
||||
# Even when the rows are not being reordered, we still need to generate a
|
||||
# remapping to account for initializing partitioned Variables (when
|
||||
# new_row_vocab_offset is non-zero).
|
||||
row_remapping = math_ops.range(
|
||||
new_row_vocab_offset,
|
||||
new_row_vocab_offset + num_rows_to_load,
|
||||
dtype=dtypes.int64)
|
||||
|
||||
col_remapping = []
|
||||
num_cols_present = new_col_vocab_size
|
||||
if remap_cols:
|
||||
col_remapping, num_cols_present = (
|
||||
gen_checkpoint_ops._generate_vocab_remapping( # pylint: disable=protected-access
|
||||
new_vocab_file=new_col_vocab_file,
|
||||
old_vocab_file=old_col_vocab_file,
|
||||
new_vocab_offset=0, # Offset is unused for cols (no partitioning).
|
||||
num_new_vocab=new_col_vocab_size))
|
||||
|
||||
init_vals = initializer([
|
||||
num_rows_to_load * new_col_vocab_size -
|
||||
num_rows_present * num_cols_present, 1
|
||||
])
|
||||
return_tensor = gen_checkpoint_ops._load_and_remap_matrix( # pylint: disable=protected-access
|
||||
ckpt_path=ckpt_path,
|
||||
old_tensor_name=old_tensor_name,
|
||||
row_remapping=row_remapping,
|
||||
col_remapping=col_remapping,
|
||||
initializing_values=init_vals,
|
||||
num_rows=num_rows_to_load,
|
||||
num_cols=new_col_vocab_size,
|
||||
max_rows_in_memory=max_rows_in_memory)
|
||||
|
||||
# Add OOV row(s) and column(s).
|
||||
if num_row_oov_buckets > 0:
|
||||
init_row_oov_val = initializer([num_row_oov_buckets, new_col_vocab_size])
|
||||
init_row_oov_val = ops.convert_to_tensor(init_row_oov_val)
|
||||
return_tensor = array_ops.concat([return_tensor, init_row_oov_val], 0)
|
||||
if num_col_oov_buckets > 0:
|
||||
# We need to add any row OOV to the new column shape.
|
||||
init_col_oov_val = initializer(
|
||||
[num_rows_to_load + num_row_oov_buckets, num_col_oov_buckets])
|
||||
init_col_oov_val = ops.convert_to_tensor(init_col_oov_val)
|
||||
return_tensor = array_ops.concat([return_tensor, init_col_oov_val], 1)
|
||||
|
||||
return return_tensor
|
||||
|
||||
|
||||
def load_and_remap_matrix_initializer(ckpt_path,
|
||||
old_tensor_name,
|
||||
new_row_vocab_size,
|
||||
new_col_vocab_size,
|
||||
old_row_vocab_file=None,
|
||||
new_row_vocab_file=None,
|
||||
old_col_vocab_file=None,
|
||||
new_col_vocab_file=None,
|
||||
num_row_oov_buckets=0,
|
||||
num_col_oov_buckets=0,
|
||||
initializer=None,
|
||||
max_rows_in_memory=-1):
|
||||
r"""Returns a var initializer for loading and remapping a 2-D (matrix) tensor.
|
||||
|
||||
The returned initializer loads a 2-D (matrix) `Tensor` with name
|
||||
`old_tensor_name` from the checkpoint at `ckpt_path`. It will reorder the
|
||||
rows/columns according to the specified vocab files and append additional
|
||||
out-of-vocabulary rows/columns according to the number of OOV buckets.
|
||||
|
||||
The format of the file at the `{old,new}_{row,col}_vocab_file` path should be
|
||||
a text file, with each line containing a single entity within the vocabulary.
|
||||
Let the function `line_of(f, "x")` return the 0-indexed line number of the
|
||||
entity "x" in file f, and the function `entity_at(f, i)` return the entity at
|
||||
line i of file f. Then, row i of the new output matrix will be taken from row
|
||||
`line_of(old_row_vocab_file, entity_at(new_row_vocab_file, i))` of the old
|
||||
matrix. If any entity in `new_row_vocab_file` is not found in
|
||||
`old_row_vocab_file`, that row is considered a "missing" row, and its values
|
||||
will be initialized using the `initializer` arg. The same logic also applies
|
||||
for the columns.
|
||||
|
||||
For example, assuming that:
|
||||
|
||||
* `old_row_vocab_file` contains "mercury\nvenus\nmars"
|
||||
* `new_row_vocab_file` contains "venus\njupiter\nmercury"
|
||||
* `old_col_vocab_file` contains "good\nbetter\nbest"
|
||||
* `new_col_vocab_file` contains "good\nbest\nfantastic"
|
||||
* `initializer` returns the natural numbers `[1, 2, 3, 4, ...]`
|
||||
* `w(i, j)` represents the value from row i, column j of the old matrix
|
||||
|
||||
Then the new output matrix will look like:
|
||||
|
||||
`[[w(1, 0), w(1, 2), 1],
|
||||
[2, 3, 4],
|
||||
[w(0, 0), w(0, 2), 5]]`
|
||||
|
||||
If we further specify that:
|
||||
|
||||
* `num_row_oov_buckets` == 2
|
||||
* `num_col_oov_buckets` == 1
|
||||
|
||||
Then the new output matrix will look like:
|
||||
|
||||
`[[w(1, 0), w(1, 2), 1, 12],
|
||||
[2, 3, 4, 13],
|
||||
[w(0, 0), w(0, 2), 5, 14],
|
||||
[6, 7, 8, 15],
|
||||
[9, 10, 11, 16]]`
|
||||
|
||||
If `{old,new}_row_vocab_file` are None, we assume that the old and new row
|
||||
vocab files are the same, and no row remapping is done. If
|
||||
`{old,new}_col_vocab_file` are None, we assume that the old and new column
|
||||
vocab files are the same, and no column remapping is done.
|
||||
|
||||
The returned initializer only supports div-partitioning along the row axis. It
|
||||
does not support partitioning along the column axis or mod-partitioning.
|
||||
|
||||
NOTE: When this is used to warm-start variables, client code should use
|
||||
`tf.lookup.index_table_from_tensor()` like
|
||||
contrib/layers/python/layers/feature_column.py does, as opposed to
|
||||
`tf.feature_to_id()` - in order to ensure the underlying lookup tables are the
|
||||
same.
|
||||
|
||||
Args:
|
||||
ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
|
||||
from which the old matrix `Tensor` will be loaded.
|
||||
old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
|
||||
new_row_vocab_size: `int` specifying the number of entries in
|
||||
`new_row_vocab_file`. If no row remapping is needed (no row vocab
|
||||
provided), this should be equal to the number of rows to load from the old
|
||||
matrix (which can theoretically be smaller than the number of rows in the
|
||||
old matrix).
|
||||
new_col_vocab_size: `int` specifying the number of entries in
|
||||
`new_col_vocab_file`. If no column remapping is needed (no column vocab
|
||||
provided), this should be equal to the number of columns in the old
|
||||
matrix.
|
||||
old_row_vocab_file: A scalar `Tensor` of type `string` containing the
|
||||
path to the old row vocabulary file. Can be None, which represents no
|
||||
remapping on the row axis.
|
||||
new_row_vocab_file: A scalar `Tensor` of type `string` containing the path
|
||||
to the new row vocabulary file. Can be None, which represents no remapping
|
||||
on the row axis.
|
||||
old_col_vocab_file: A scalar `Tensor` of type `string` containing the
|
||||
path to the old column vocabulary file. Can be None, which represents no
|
||||
remapping on the column axis.
|
||||
new_col_vocab_file: A scalar `Tensor` of type `string` containing the path
|
||||
to the new column vocabulary file. Can be None, which represents no
|
||||
remapping on the column axis.
|
||||
num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows
|
||||
to append. Must be >= 0.
|
||||
num_col_oov_buckets: `int` specifying the number of out-of-vocabulary
|
||||
columns to append. Must be >= 0.
|
||||
initializer: Initializer function to initialize missing values. Accepts a
|
||||
1-D tensor as the arg to specify the shape of the returned tensor. If
|
||||
`None`, defaults to using `zeros_initializer()`.
|
||||
max_rows_in_memory: `int` specifying the maximum number of rows to load from
|
||||
the checkpoint at once. If less than or equal to 0, the entire matrix will
|
||||
be loaded into memory. Setting this arg trades increased disk reads for
|
||||
lower memory usage.
|
||||
|
||||
Returns:
|
||||
A variable initializer function that should be used to initialize a
|
||||
(potentially partitioned) `Variable` whose complete shape is
|
||||
`[new_row_vocab_size + num_row_oov_buckets, new_col_vocab_size +
|
||||
num_col_oov_buckets]`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `initializer` is specified but not callable.
|
||||
"""
|
||||
if initializer is None:
|
||||
# TODO(b/25671353): Consider using sqrt(6/(fan_in + fan_out)) instead, from
|
||||
# Glorot and Bengio, 2010.
|
||||
initializer = init_ops.zeros_initializer()
|
||||
|
||||
if not callable(initializer):
|
||||
raise TypeError(
|
||||
"initializer must be callable, instead of being {} of type {}.".format(
|
||||
initializer, type(initializer)))
|
||||
|
||||
def _initializer(shape, dtype=dtypes.float32, partition_info=None):
|
||||
"""Variable initializer.
|
||||
|
||||
Args:
|
||||
shape: Shape of `Tensor` to return. Should include OOV on both axes.
|
||||
dtype: Must be float32.
|
||||
partition_info: variable_scope._PartitionInfo.
|
||||
|
||||
Returns:
|
||||
`Tensor` of shape `shape`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `dtype` is anything other than float32.
|
||||
ValueError: For shape mismatch upon invocation.
|
||||
"""
|
||||
# Sanity checks.
|
||||
if dtype != dtypes.float32:
|
||||
raise TypeError(
|
||||
"Currently, only float32 is supported. Received dtype: {}".format(
|
||||
dtype))
|
||||
if len(shape) != 2:
|
||||
raise ValueError("Expected 2-dim shape, but received: {}".format(shape))
|
||||
if shape[0] <= 0:
|
||||
raise ValueError(
|
||||
"Expected 1st dim of shape to be > 0, but received shape: {}".format(
|
||||
shape))
|
||||
if shape[1] != (new_col_vocab_size + num_col_oov_buckets):
|
||||
raise ValueError(
|
||||
"Expected 2nd dim of shape to be new_col_vocab_size ({}) + "
|
||||
"num_col_oov_buckets ({}) = {}, but received shape: {}".format(
|
||||
new_col_vocab_size, num_col_oov_buckets,
|
||||
new_col_vocab_size + num_col_oov_buckets, shape))
|
||||
|
||||
offset = 0
|
||||
if partition_info is not None:
|
||||
offset = partition_info.single_offset(shape)
|
||||
|
||||
if offset + shape[0] > new_row_vocab_size + num_row_oov_buckets:
|
||||
raise ValueError(
|
||||
"Trying to initialize {} additional rows after {} rows have already "
|
||||
"been initialized, which would exceed expected total row count of "
|
||||
"new_row_vocab_size ({}) + num_row_oov_buckets ({}) = {}.".format(
|
||||
shape[0], offset, new_row_vocab_size, num_row_oov_buckets,
|
||||
new_row_vocab_size + num_row_oov_buckets))
|
||||
|
||||
row_oov_buckets_to_use = min(shape[0],
|
||||
max(0, offset + shape[0] - new_row_vocab_size))
|
||||
num_rows_to_load = shape[0] - row_oov_buckets_to_use
|
||||
|
||||
return _load_and_remap_matrix(
|
||||
ckpt_path=ckpt_path,
|
||||
old_tensor_name=old_tensor_name,
|
||||
new_row_vocab_offset=offset,
|
||||
num_rows_to_load=num_rows_to_load,
|
||||
new_col_vocab_size=new_col_vocab_size,
|
||||
initializer=initializer,
|
||||
old_row_vocab_file=old_row_vocab_file,
|
||||
new_row_vocab_file=new_row_vocab_file,
|
||||
old_col_vocab_file=old_col_vocab_file,
|
||||
new_col_vocab_file=new_col_vocab_file,
|
||||
num_row_oov_buckets=row_oov_buckets_to_use,
|
||||
num_col_oov_buckets=num_col_oov_buckets,
|
||||
max_rows_in_memory=max_rows_in_memory)
|
||||
|
||||
return _initializer
|
||||
|
||||
|
||||
def load_embedding_initializer(ckpt_path,
|
||||
embedding_tensor_name,
|
||||
new_vocab_size,
|
||||
embedding_dim,
|
||||
old_vocab_file,
|
||||
new_vocab_file,
|
||||
num_oov_buckets=0,
|
||||
initializer=None,
|
||||
max_rows_in_memory=-1):
|
||||
"""Returns a variable initializer for loading pre-trained embeddings.
|
||||
|
||||
Wrapper around `load_and_remap_matrix_initializer()` specialized for loading
|
||||
embedding weights and remapping according to the provided vocab files. See
|
||||
docs for `load_and_remap_matrix_initializer()` for more details.
|
||||
|
||||
NOTE: Only for use with div-partitioned variables / vocabularies.
|
||||
|
||||
Args:
|
||||
ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
|
||||
from which the old matrix `Tensor` will be loaded.
|
||||
embedding_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
|
||||
new_vocab_size: Number of entries in the new vocab.
|
||||
embedding_dim: `int` specifying the dimension of the embedding vectors from
|
||||
the checkpoint. Must match the number of columns in the old embedding
|
||||
matrix.
|
||||
old_vocab_file: A scalar `Tensor` of type `string` containing the
|
||||
path to the old vocabulary file.
|
||||
new_vocab_file: A scalar `Tensor` of type `string` containing the
|
||||
path to the new vocabulary file.
|
||||
num_oov_buckets: `int` specifying the number of out-of-vocabulary
|
||||
buckets to use. Must be >= 0.
|
||||
initializer: Initializer function that accepts a 1-D tensor as the arg to
|
||||
specify the shape of the returned tensor. If `None`, defaults to using
|
||||
`truncated_normal_initializer()`.
|
||||
max_rows_in_memory: `int` specifying the maximum number of rows to load from
|
||||
the checkpoint at once. If less than or equal to 0, the entire matrix will
|
||||
be loaded into memory. Setting this arg trades increased disk reads for
|
||||
lower memory usage.
|
||||
|
||||
Returns:
|
||||
A variable initializer function.
|
||||
"""
|
||||
if initializer is None:
|
||||
# TODO(b/25671353): This should be kept in sync with the stddev used by
|
||||
# feature_column.py's _EmbeddingColumn.
|
||||
initializer = init_ops.truncated_normal_initializer(
|
||||
stddev=1.0 / math.sqrt(embedding_dim))
|
||||
|
||||
return load_and_remap_matrix_initializer(
|
||||
ckpt_path=ckpt_path,
|
||||
old_tensor_name=embedding_tensor_name,
|
||||
new_row_vocab_size=new_vocab_size,
|
||||
new_col_vocab_size=embedding_dim,
|
||||
old_row_vocab_file=old_vocab_file,
|
||||
new_row_vocab_file=new_vocab_file,
|
||||
old_col_vocab_file=None,
|
||||
new_col_vocab_file=None,
|
||||
num_row_oov_buckets=num_oov_buckets,
|
||||
num_col_oov_buckets=0,
|
||||
initializer=initializer,
|
||||
max_rows_in_memory=max_rows_in_memory)
|
||||
# pylint: disable=protected-access,line-too-long
|
||||
load_and_remap_matrix_initializer = checkpoint_ops._load_and_remap_matrix_initializer
|
||||
# pylint: enable=line-too-long
|
||||
load_embedding_initializer = checkpoint_ops._load_embedding_initializer
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
def load_linear_multiclass_bias_initializer(ckpt_path,
|
||||
|
@ -21,7 +21,6 @@ import os
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib import framework as contrib_framework
|
||||
from tensorflow.contrib.framework.python.ops import checkpoint_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -38,250 +37,6 @@ FLAGS = flags.FLAGS
|
||||
_TESTDATA_PATH = 'contrib/framework/testdata'
|
||||
|
||||
|
||||
class LoadAndRemapWrappersTest(test.TestCase):
|
||||
"""Tests for the functionality of the Python wrappers."""
|
||||
|
||||
def setUp(self):
|
||||
self.bundle_file = os.path.join(
|
||||
test.test_src_dir_path(_TESTDATA_PATH), 'bundle_checkpoint')
|
||||
self.new_feature_vocab_file = os.path.join(
|
||||
test.test_src_dir_path(_TESTDATA_PATH), 'bundle_checkpoint_vocab.txt')
|
||||
self.old_feature_vocab_file = os.path.join(
|
||||
test.test_src_dir_path(_TESTDATA_PATH),
|
||||
'bundle_checkpoint_vocab_with_oov.txt')
|
||||
self.new_class_vocab_file = os.path.join(
|
||||
test.test_src_dir_path(_TESTDATA_PATH), 'keyword_new.txt')
|
||||
self.old_class_vocab_file = os.path.join(
|
||||
test.test_src_dir_path(_TESTDATA_PATH), 'keyword.txt')
|
||||
self.init_val = 42
|
||||
|
||||
def _init_val_initializer(shape, dtype=None, partition_info=None):
|
||||
del dtype, partition_info # Unused by this unit-testing initializer.
|
||||
return array_ops.tile(
|
||||
constant_op.constant([[self.init_val]], dtype=dtypes.float32), shape)
|
||||
|
||||
self.initializer = _init_val_initializer
|
||||
|
||||
def test_load_and_remap_matrix(self):
|
||||
"""Tests the end-to-end loading / remapping of weights."""
|
||||
# _load_and_remap_matrix() is the generalized wrapper that takes in row and
|
||||
# column vocabulary files, calls the relevant remappings, and returns the
|
||||
# weight matrix. Take this example to be linear multi-class by providing
|
||||
# both row and column vocabularies.
|
||||
remapped_matrix = checkpoint_ops._load_and_remap_matrix(
|
||||
new_row_vocab_file=self.new_feature_vocab_file,
|
||||
old_row_vocab_file=self.old_feature_vocab_file,
|
||||
num_rows_to_load=4,
|
||||
new_col_vocab_file=self.new_class_vocab_file,
|
||||
old_col_vocab_file=self.old_class_vocab_file,
|
||||
new_col_vocab_size=4,
|
||||
old_tensor_name='some_scope/embeddings',
|
||||
ckpt_path=[self.bundle_file],
|
||||
new_row_vocab_offset=1,
|
||||
initializer=self.initializer,
|
||||
num_row_oov_buckets=1,
|
||||
num_col_oov_buckets=1)
|
||||
|
||||
# [4 in vocab + 1 oov features, 4 in vocab + 1 oov classes]. The offset
|
||||
# means we read
|
||||
expected_remapped_matrix = np.concatenate(
|
||||
[
|
||||
np.reshape([18, 34, 50, self.init_val, self.init_val], [5, 1]),
|
||||
np.reshape([16, 32, 48, self.init_val, self.init_val], [5, 1]),
|
||||
np.reshape([self.init_val] * 5, [5, 1]),
|
||||
np.reshape([17, 33, 49, self.init_val, self.init_val], [5, 1]),
|
||||
np.reshape([self.init_val] * 5, [5, 1])
|
||||
],
|
||||
axis=1)
|
||||
|
||||
with self.test_session():
|
||||
self.assertAllClose(expected_remapped_matrix, remapped_matrix.eval())
|
||||
|
||||
def test_load_and_remap_output_layer_weight_initializer_linear(self):
|
||||
"""Tests for the output layer initializer in the linear multi-class case."""
|
||||
loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
|
||||
new_row_vocab_size=5,
|
||||
new_col_vocab_file=self.new_class_vocab_file,
|
||||
old_col_vocab_file=self.old_class_vocab_file,
|
||||
new_col_vocab_size=4,
|
||||
old_tensor_name='some_scope/embeddings',
|
||||
ckpt_path=[self.bundle_file],
|
||||
new_row_vocab_file=self.new_feature_vocab_file,
|
||||
old_row_vocab_file=self.old_feature_vocab_file,
|
||||
num_row_oov_buckets=1,
|
||||
num_col_oov_buckets=1,
|
||||
initializer=self.initializer))
|
||||
|
||||
expected_remapped_matrix = np.concatenate(
|
||||
[
|
||||
np.reshape([2, 18, 34, 50, self.init_val, self.init_val], [6, 1]),
|
||||
np.reshape([0, 16, 32, 48, self.init_val, self.init_val], [6, 1]),
|
||||
np.reshape([self.init_val] * 6, [6, 1]),
|
||||
np.reshape([1, 17, 33, 49, self.init_val, self.init_val], [6, 1]),
|
||||
np.reshape([self.init_val] * 6, [6, 1])
|
||||
],
|
||||
axis=1)
|
||||
|
||||
# The new weight matrix is of size
|
||||
# [5 feature vocab + 1 feature OOV, 4 class vocab + 1 class OOV]. Use a
|
||||
# partitioned variable to confirm that the offset logic works.
|
||||
remapped_matrix = variable_scope.get_variable(
|
||||
name='linear/obtained_weight_matrix',
|
||||
shape=[6, 5],
|
||||
initializer=loading_initializer,
|
||||
partitioner=partitioned_variables.fixed_size_partitioner(2))
|
||||
|
||||
with self.test_session():
|
||||
variables.global_variables_initializer().run()
|
||||
self.assertAllClose(expected_remapped_matrix,
|
||||
remapped_matrix.as_tensor().eval())
|
||||
|
||||
def test_load_and_remap_output_layer_weight_initializer_dnn_output(self):
|
||||
"""Tests for the output layer initializer in the DNN output case."""
|
||||
loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
|
||||
new_row_vocab_size=5,
|
||||
new_col_vocab_file=self.new_class_vocab_file,
|
||||
old_col_vocab_file=self.old_class_vocab_file,
|
||||
new_col_vocab_size=4,
|
||||
old_tensor_name='some_scope/embeddings',
|
||||
ckpt_path=[self.bundle_file],
|
||||
num_col_oov_buckets=1,
|
||||
initializer=self.initializer))
|
||||
|
||||
expected_remapped_matrix = np.concatenate(
|
||||
[
|
||||
np.reshape([2, 18, 34, 50, 66], [5, 1]),
|
||||
np.reshape([0, 16, 32, 48, 64], [5, 1]),
|
||||
np.reshape([self.init_val] * 5, [5, 1]),
|
||||
np.reshape([1, 17, 33, 49, 65], [5, 1]),
|
||||
np.reshape([self.init_val] * 5, [5, 1])
|
||||
],
|
||||
axis=1)
|
||||
|
||||
# The new weight matrix is of size
|
||||
# [5-sized input layer, 4 class vocab + 1 class OOV].
|
||||
remapped_matrix = variable_scope.get_variable(
|
||||
name='dnn_output/obtained_weight_matrix',
|
||||
shape=[5, 5],
|
||||
initializer=loading_initializer,
|
||||
partitioner=partitioned_variables.fixed_size_partitioner(2))
|
||||
|
||||
with self.test_session():
|
||||
variables.global_variables_initializer().run()
|
||||
self.assertAllClose(expected_remapped_matrix,
|
||||
remapped_matrix.as_tensor().eval())
|
||||
|
||||
def test_initializer_with_oov_only_partition(self):
|
||||
"""Tests for the output layer initializer where one partition is all OOV."""
|
||||
loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
|
||||
new_row_vocab_size=5,
|
||||
new_col_vocab_file=self.new_class_vocab_file,
|
||||
old_col_vocab_file=self.old_class_vocab_file,
|
||||
new_col_vocab_size=4,
|
||||
old_tensor_name='some_scope/embeddings',
|
||||
ckpt_path=[self.bundle_file],
|
||||
new_row_vocab_file=self.new_feature_vocab_file,
|
||||
old_row_vocab_file=self.old_feature_vocab_file,
|
||||
num_row_oov_buckets=5,
|
||||
num_col_oov_buckets=1,
|
||||
initializer=self.initializer))
|
||||
|
||||
expected_remapped_matrix = np.concatenate(
|
||||
[
|
||||
np.reshape([2, 18, 34, 50] + [self.init_val] * 6, [10, 1]),
|
||||
np.reshape([0, 16, 32, 48] + [self.init_val] * 6, [10, 1]),
|
||||
np.reshape([self.init_val] * 10, [10, 1]),
|
||||
np.reshape([1, 17, 33, 49] + [self.init_val] * 6, [10, 1]),
|
||||
np.reshape([self.init_val] * 10, [10, 1]),
|
||||
],
|
||||
axis=1)
|
||||
|
||||
# The new weight matrix is of size
|
||||
# [5 feature vocab + 5 feature OOV, 4 class vocab + 1 class OOV]. The
|
||||
# second partition has only OOV.
|
||||
remapped_matrix = variable_scope.get_variable(
|
||||
name='linear_all_oov/obtained_weight_matrix',
|
||||
shape=[10, 5],
|
||||
initializer=loading_initializer,
|
||||
partitioner=partitioned_variables.fixed_size_partitioner(2))
|
||||
|
||||
with self.test_session():
|
||||
variables.global_variables_initializer().run()
|
||||
self.assertAllClose(expected_remapped_matrix,
|
||||
remapped_matrix.as_tensor().eval())
|
||||
|
||||
def test_load_and_remap_linear_multiclass_initializer_default_init(self):
|
||||
"""Tests where the zeros_initializer default is used for linear."""
|
||||
loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
|
||||
new_row_vocab_size=5,
|
||||
new_col_vocab_file=self.new_class_vocab_file,
|
||||
old_col_vocab_file=self.old_class_vocab_file,
|
||||
new_col_vocab_size=4,
|
||||
old_tensor_name='some_scope/embeddings',
|
||||
ckpt_path=[self.bundle_file],
|
||||
new_row_vocab_file=self.new_feature_vocab_file,
|
||||
old_row_vocab_file=self.old_feature_vocab_file,
|
||||
num_row_oov_buckets=1,
|
||||
num_col_oov_buckets=1))
|
||||
|
||||
expected_remapped_matrix = np.concatenate(
|
||||
[
|
||||
np.reshape([2, 18, 34, 50, 0, 0], [6, 1]),
|
||||
np.reshape([0, 16, 32, 48, 0, 0], [6, 1]),
|
||||
np.reshape([0] * 6, [6, 1]),
|
||||
np.reshape([1, 17, 33, 49, 0, 0], [6, 1]),
|
||||
np.reshape([0] * 6, [6, 1])
|
||||
],
|
||||
axis=1)
|
||||
|
||||
remapped_matrix = variable_scope.get_variable(
|
||||
name='linear_init_fallback/obtained_weight_matrix',
|
||||
shape=[6, 5],
|
||||
initializer=loading_initializer,
|
||||
partitioner=partitioned_variables.fixed_size_partitioner(2))
|
||||
|
||||
with self.test_session():
|
||||
variables.global_variables_initializer().run()
|
||||
self.assertAllClose(expected_remapped_matrix,
|
||||
remapped_matrix.as_tensor().eval())
|
||||
|
||||
def test_load_embedding_initializer(self):
|
||||
"""Tests for the load_embedding_initializer wrapper."""
|
||||
embedding_loading_initializer = (
|
||||
contrib_framework.load_embedding_initializer(
|
||||
new_vocab_file=self.new_feature_vocab_file,
|
||||
old_vocab_file=self.old_feature_vocab_file,
|
||||
new_vocab_size=5,
|
||||
embedding_dim=16,
|
||||
embedding_tensor_name='some_scope/embeddings',
|
||||
ckpt_path=[self.bundle_file],
|
||||
num_oov_buckets=1,
|
||||
initializer=self.initializer))
|
||||
|
||||
expected_remapped_embeddings = np.concatenate(
|
||||
[
|
||||
np.reshape(range(64), [4, 16]),
|
||||
np.reshape([self.init_val] * 32, [2, 16]),
|
||||
],
|
||||
axis=0)
|
||||
|
||||
# The new weight matrix is of size
|
||||
# [5 feature vocab + 1 feature OOV, 16 (embedding dimension)], where the
|
||||
# last vocab row (2nd last row) is newly initialized (wasn't found in
|
||||
# previous vocab) and the actual last row is OOV and also newly initialized.
|
||||
# Use a partitioned variable to confirm that the offset logic works.
|
||||
remapped_embeddings = variable_scope.get_variable(
|
||||
name='embedding/obtained_embedding_matrix',
|
||||
shape=[6, 16],
|
||||
initializer=embedding_loading_initializer,
|
||||
partitioner=partitioned_variables.fixed_size_partitioner(2))
|
||||
|
||||
with self.test_session():
|
||||
variables.global_variables_initializer().run()
|
||||
self.assertAllClose(expected_remapped_embeddings,
|
||||
remapped_embeddings.as_tensor().eval())
|
||||
|
||||
|
||||
class LoadMulticlassBiasTest(test.TestCase):
|
||||
"""Tests for the load_linear_multiclass_bias_initializer functionality."""
|
||||
|
||||
|
27
tensorflow/contrib/gan/BUILD
Normal file
27
tensorflow/contrib/gan/BUILD
Normal file
@ -0,0 +1,27 @@
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
py_library(
|
||||
name = "gan",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
4
tensorflow/contrib/gan/README.md
Normal file
4
tensorflow/contrib/gan/README.md
Normal file
@ -0,0 +1,4 @@
|
||||
This directory contains the TFGAN project.
|
||||
|
||||
This file will have more details as code is added.
|
||||
|
19
tensorflow/contrib/gan/__init__.py
Normal file
19
tensorflow/contrib/gan/__init__.py
Normal file
@ -0,0 +1,19 @@
|
||||
# Copyright 2017 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""TFGAN grouped API."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
@ -62,6 +62,7 @@ tf_cuda_library(
|
||||
}),
|
||||
deps = [
|
||||
":gdr_proto_cc",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:gpu_runtime",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -121,12 +121,9 @@ tf_gen_op_wrapper_py(
|
||||
|
||||
cc_library(
|
||||
name = "image_ops_cc",
|
||||
srcs = [
|
||||
"ops/image_ops.cc",
|
||||
],
|
||||
srcs = ["ops/image_ops.cc"],
|
||||
deps = [
|
||||
":image_ops_kernels",
|
||||
"//tensorflow/core",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -551,6 +551,7 @@ py_test(
|
||||
size = "small",
|
||||
srcs = ["python/keras/utils/io_utils_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["notsan"],
|
||||
deps = [
|
||||
":keras",
|
||||
"//tensorflow/python:client_testlib",
|
||||
|
@ -57,43 +57,44 @@ class TestIOUtils(test.TestCase):
|
||||
h5_path = os.path.join(temp_dir, 'test.h5')
|
||||
create_dataset(h5_path)
|
||||
|
||||
# Instantiating HDF5Matrix for the training set,
|
||||
# which is a slice of the first 150 elements
|
||||
x_train = keras.utils.io_utils.HDF5Matrix(
|
||||
h5_path, 'my_data', start=0, end=150)
|
||||
y_train = keras.utils.io_utils.HDF5Matrix(
|
||||
h5_path, 'my_labels', start=0, end=150)
|
||||
with self.test_session():
|
||||
# Instantiating HDF5Matrix for the training set,
|
||||
# which is a slice of the first 150 elements
|
||||
x_train = keras.utils.io_utils.HDF5Matrix(
|
||||
h5_path, 'my_data', start=0, end=150)
|
||||
y_train = keras.utils.io_utils.HDF5Matrix(
|
||||
h5_path, 'my_labels', start=0, end=150)
|
||||
|
||||
# Likewise for the test set
|
||||
x_test = keras.utils.io_utils.HDF5Matrix(
|
||||
h5_path, 'my_data', start=150, end=200)
|
||||
y_test = keras.utils.io_utils.HDF5Matrix(
|
||||
h5_path, 'my_labels', start=150, end=200)
|
||||
# Likewise for the test set
|
||||
x_test = keras.utils.io_utils.HDF5Matrix(
|
||||
h5_path, 'my_data', start=150, end=200)
|
||||
y_test = keras.utils.io_utils.HDF5Matrix(
|
||||
h5_path, 'my_labels', start=150, end=200)
|
||||
|
||||
# HDF5Matrix behave more or less like Numpy matrices
|
||||
# with regard to indexing
|
||||
self.assertEqual(y_train.shape, (150, 1))
|
||||
# But they do not support negative indices, so don't try print(x_train[-1])
|
||||
# HDF5Matrix behave more or less like Numpy matrices
|
||||
# with regard to indexing
|
||||
self.assertEqual(y_train.shape, (150, 1))
|
||||
# But they don't support negative indices, so don't try print(x_train[-1])
|
||||
|
||||
self.assertEqual(y_train.dtype, np.dtype('i'))
|
||||
self.assertEqual(y_train.ndim, 2)
|
||||
self.assertEqual(y_train.size, 150)
|
||||
self.assertEqual(y_train.dtype, np.dtype('i'))
|
||||
self.assertEqual(y_train.ndim, 2)
|
||||
self.assertEqual(y_train.size, 150)
|
||||
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(64, input_shape=(10,), activation='relu'))
|
||||
model.add(keras.layers.Dense(1, activation='sigmoid'))
|
||||
model.compile(loss='binary_crossentropy', optimizer='sgd')
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(64, input_shape=(10,), activation='relu'))
|
||||
model.add(keras.layers.Dense(1, activation='sigmoid'))
|
||||
model.compile(loss='binary_crossentropy', optimizer='sgd')
|
||||
|
||||
# Note: you have to use shuffle='batch' or False with HDF5Matrix
|
||||
model.fit(x_train, y_train, batch_size=32, shuffle='batch', verbose=False)
|
||||
# test that evalutation and prediction
|
||||
# don't crash and return reasonable results
|
||||
out_pred = model.predict(x_test, batch_size=32, verbose=False)
|
||||
out_eval = model.evaluate(x_test, y_test, batch_size=32, verbose=False)
|
||||
# Note: you have to use shuffle='batch' or False with HDF5Matrix
|
||||
model.fit(x_train, y_train, batch_size=32, shuffle='batch', verbose=False)
|
||||
# test that evalutation and prediction
|
||||
# don't crash and return reasonable results
|
||||
out_pred = model.predict(x_test, batch_size=32, verbose=False)
|
||||
out_eval = model.evaluate(x_test, y_test, batch_size=32, verbose=False)
|
||||
|
||||
self.assertEqual(out_pred.shape, (50, 1))
|
||||
self.assertEqual(out_eval.shape, ())
|
||||
self.assertGreater(out_eval, 0)
|
||||
self.assertEqual(out_pred.shape, (50, 1))
|
||||
self.assertEqual(out_eval.shape, ())
|
||||
self.assertGreater(out_eval, 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -28,7 +28,6 @@ import six
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
@ -44,7 +43,7 @@ def _get_in_out_shape(x_shape, y_shape, n_classes, batch_size=None):
|
||||
x_is_dict, y_is_dict = isinstance(
|
||||
x_shape, dict), y_shape is not None and isinstance(y_shape, dict)
|
||||
if y_is_dict and n_classes is not None:
|
||||
assert (isinstance(n_classes, dict))
|
||||
assert isinstance(n_classes, dict)
|
||||
|
||||
if batch_size is None:
|
||||
batch_size = list(x_shape.values())[0][0] if x_is_dict else x_shape[0]
|
||||
@ -322,10 +321,12 @@ class DataFeeder(object):
|
||||
|
||||
self._x = dict([(k, check_array(v, v.dtype)) for k, v in list(x.items())
|
||||
]) if x_is_dict else check_array(x, x.dtype)
|
||||
self._y = None if y is None else \
|
||||
dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) if y_is_dict else check_array(y, y.dtype)
|
||||
self._y = None if y is None else (
|
||||
dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())])
|
||||
if y_is_dict else check_array(y, y.dtype))
|
||||
|
||||
# self.n_classes is not None means we're converting raw target indices to one-hot.
|
||||
# self.n_classes is not None means we're converting raw target indices
|
||||
# to one-hot.
|
||||
if n_classes is not None:
|
||||
if not y_is_dict:
|
||||
y_dtype = (np.int64
|
||||
@ -344,12 +345,15 @@ class DataFeeder(object):
|
||||
x_shape, y_shape, n_classes, batch_size)
|
||||
|
||||
# Input dtype matches dtype of x.
|
||||
self._input_dtype = dict([(k, _check_dtype(v.dtype)) for k, v in list(self._x.items())]) if x_is_dict \
|
||||
else _check_dtype(self._x.dtype)
|
||||
self._input_dtype = (
|
||||
dict([(k, _check_dtype(v.dtype)) for k, v in list(self._x.items())])
|
||||
if x_is_dict else _check_dtype(self._x.dtype))
|
||||
|
||||
# note: self._output_dtype = np.float32 when y is None
|
||||
self._output_dtype = dict([(k, _check_dtype(v.dtype)) for k, v in list(self._y.items())]) if y_is_dict \
|
||||
else _check_dtype(self._y.dtype) if y is not None else np.float32
|
||||
# self._output_dtype == np.float32 when y is None
|
||||
self._output_dtype = (
|
||||
dict([(k, _check_dtype(v.dtype)) for k, v in list(self._y.items())])
|
||||
if y_is_dict else (
|
||||
_check_dtype(self._y.dtype) if y is not None else np.float32))
|
||||
|
||||
# self.n_classes is None means we're passing in raw target indices
|
||||
if n_classes is not None and y_is_dict:
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Utilities supporting export to SavedModel.
|
||||
|
||||
Some contents of this file are moved to tensorflow/python/estimator/export.py:
|
||||
@ -39,6 +38,7 @@ import time
|
||||
from tensorflow.contrib.layers.python.layers import feature_column
|
||||
from tensorflow.contrib.learn.python.learn import export_strategy
|
||||
from tensorflow.contrib.learn.python.learn.estimators import constants
|
||||
from tensorflow.contrib.learn.python.learn.estimators import metric_key
|
||||
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||
from tensorflow.contrib.learn.python.learn.utils import gc
|
||||
from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
|
||||
@ -75,8 +75,8 @@ FEATURES_INPUT_ALTERNATIVE_KEY = 'features_input_alternative'
|
||||
_FALLBACK_DEFAULT_OUTPUT_ALTERNATIVE_KEY = 'default_output_alternative'
|
||||
|
||||
|
||||
def build_standardized_signature_def(
|
||||
input_tensors, output_tensors, problem_type):
|
||||
def build_standardized_signature_def(input_tensors, output_tensors,
|
||||
problem_type):
|
||||
"""Build a SignatureDef using problem type and input and output Tensors.
|
||||
|
||||
Note that this delegates the actual creation of the signatures to methods in
|
||||
@ -116,8 +116,8 @@ def build_standardized_signature_def(
|
||||
(_, predictions), = output_tensors.items()
|
||||
return signature_def_utils.regression_signature_def(examples, predictions)
|
||||
else:
|
||||
return signature_def_utils.predict_signature_def(
|
||||
input_tensors, output_tensors)
|
||||
return signature_def_utils.predict_signature_def(input_tensors,
|
||||
output_tensors)
|
||||
|
||||
|
||||
def _get_classification_scores(output_tensors):
|
||||
@ -139,17 +139,15 @@ def _is_classification_problem(problem_type, input_tensors, output_tensors):
|
||||
classes = _get_classification_classes(output_tensors)
|
||||
scores = _get_classification_scores(output_tensors)
|
||||
return ((problem_type == constants.ProblemType.CLASSIFICATION or
|
||||
problem_type == constants.ProblemType.LOGISTIC_REGRESSION)
|
||||
and len(input_tensors) == 1
|
||||
and (classes is not None or
|
||||
scores is not None or
|
||||
len(output_tensors) == 1))
|
||||
problem_type == constants.ProblemType.LOGISTIC_REGRESSION) and
|
||||
len(input_tensors) == 1 and
|
||||
(classes is not None or scores is not None or
|
||||
len(output_tensors) == 1))
|
||||
|
||||
|
||||
def _is_regression_problem(problem_type, input_tensors, output_tensors):
|
||||
return (problem_type == constants.ProblemType.LINEAR_REGRESSION
|
||||
and len(input_tensors) == 1
|
||||
and len(output_tensors) == 1)
|
||||
return (problem_type == constants.ProblemType.LINEAR_REGRESSION and
|
||||
len(input_tensors) == 1 and len(output_tensors) == 1)
|
||||
|
||||
|
||||
def get_input_alternatives(input_ops):
|
||||
@ -177,9 +175,7 @@ def get_input_alternatives(input_ops):
|
||||
return input_alternatives, features
|
||||
|
||||
|
||||
def get_output_alternatives(
|
||||
model_fn_ops,
|
||||
default_output_alternative_key=None):
|
||||
def get_output_alternatives(model_fn_ops, default_output_alternative_key=None):
|
||||
"""Obtain all output alternatives using the model_fn output and heuristics.
|
||||
|
||||
Args:
|
||||
@ -218,8 +214,10 @@ def get_output_alternatives(
|
||||
default_outputs = {prediction_key.PredictionKey.GENERIC: default_outputs}
|
||||
actual_default_output_alternative_key = (
|
||||
_FALLBACK_DEFAULT_OUTPUT_ALTERNATIVE_KEY)
|
||||
output_alternatives = {actual_default_output_alternative_key:
|
||||
(default_problem_type, default_outputs)}
|
||||
output_alternatives = {
|
||||
actual_default_output_alternative_key: (default_problem_type,
|
||||
default_outputs)
|
||||
}
|
||||
return output_alternatives, actual_default_output_alternative_key
|
||||
|
||||
if default_output_alternative_key:
|
||||
@ -246,13 +244,12 @@ def build_all_signature_defs(input_alternatives, output_alternatives,
|
||||
actual_default_output_alternative_key):
|
||||
"""Build `SignatureDef`s from all pairs of input and output alternatives."""
|
||||
|
||||
signature_def_map = {
|
||||
('%s:%s' % (input_key, output_key or 'None')):
|
||||
build_standardized_signature_def(
|
||||
inputs, outputs, problem_type)
|
||||
for input_key, inputs in input_alternatives.items()
|
||||
for output_key, (problem_type, outputs)
|
||||
in output_alternatives.items()}
|
||||
signature_def_map = {('%s:%s' % (input_key, output_key or 'None')):
|
||||
build_standardized_signature_def(inputs, outputs,
|
||||
problem_type)
|
||||
for input_key, inputs in input_alternatives.items()
|
||||
for output_key, (problem_type,
|
||||
outputs) in output_alternatives.items()}
|
||||
|
||||
# Add the default SignatureDef
|
||||
default_inputs = input_alternatives.get(DEFAULT_INPUT_ALTERNATIVE_KEY)
|
||||
@ -263,8 +260,8 @@ def build_all_signature_defs(input_alternatives, output_alternatives,
|
||||
(default_problem_type, default_outputs) = (
|
||||
output_alternatives[actual_default_output_alternative_key])
|
||||
signature_def_map[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
|
||||
build_standardized_signature_def(
|
||||
default_inputs, default_outputs, default_problem_type))
|
||||
build_standardized_signature_def(default_inputs, default_outputs,
|
||||
default_problem_type))
|
||||
|
||||
return signature_def_map
|
||||
|
||||
@ -308,9 +305,8 @@ def get_timestamped_export_dir(export_dir_base):
|
||||
return export_dir
|
||||
time.sleep(1)
|
||||
attempts += 1
|
||||
logging.warn(
|
||||
'Export directory {} already exists; retrying (attempt {}/{})'.format(
|
||||
export_dir, attempts, MAX_DIRECTORY_CREATION_ATTEMPTS))
|
||||
logging.warn('Export directory {} already exists; retrying (attempt {}/{})'.
|
||||
format(export_dir, attempts, MAX_DIRECTORY_CREATION_ATTEMPTS))
|
||||
raise RuntimeError('Failed to obtain a unique export directory name after '
|
||||
'{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS))
|
||||
|
||||
@ -330,8 +326,7 @@ def get_temp_export_dir(timestamped_export_dir):
|
||||
"""
|
||||
(dirname, basename) = os.path.split(timestamped_export_dir)
|
||||
temp_export_dir = os.path.join(
|
||||
compat.as_bytes(dirname),
|
||||
compat.as_bytes('temp-{}'.format(basename)))
|
||||
compat.as_bytes(dirname), compat.as_bytes('temp-{}'.format(basename)))
|
||||
return temp_export_dir
|
||||
|
||||
|
||||
@ -357,8 +352,8 @@ def get_most_recent_export(export_dir_base):
|
||||
A gc.Path, with is just a namedtuple of (path, export_version).
|
||||
"""
|
||||
select_filter = gc.largest_export_versions(1)
|
||||
results = select_filter(gc.get_paths(export_dir_base,
|
||||
parser=_export_version_parser))
|
||||
results = select_filter(
|
||||
gc.get_paths(export_dir_base, parser=_export_version_parser))
|
||||
return next(iter(results or []), None)
|
||||
|
||||
|
||||
@ -378,8 +373,8 @@ def garbage_collect_exports(export_dir_base, exports_to_keep):
|
||||
|
||||
keep_filter = gc.largest_export_versions(exports_to_keep)
|
||||
delete_filter = gc.negation(keep_filter)
|
||||
for p in delete_filter(gc.get_paths(export_dir_base,
|
||||
parser=_export_version_parser)):
|
||||
for p in delete_filter(
|
||||
gc.get_paths(export_dir_base, parser=_export_version_parser)):
|
||||
try:
|
||||
gfile.DeleteRecursively(p.path)
|
||||
except errors_impl.NotFoundError as e:
|
||||
@ -416,10 +411,7 @@ def make_export_strategy(serving_input_fn,
|
||||
An ExportStrategy that can be passed to the Experiment constructor.
|
||||
"""
|
||||
|
||||
def export_fn(estimator,
|
||||
export_dir_base,
|
||||
checkpoint_path=None
|
||||
):
|
||||
def export_fn(estimator, export_dir_base, checkpoint_path=None):
|
||||
"""Exports the given Estimator as a SavedModel.
|
||||
|
||||
Args:
|
||||
@ -512,3 +504,128 @@ def make_parsing_export_strategy(feature_columns,
|
||||
assets_extra=assets_extra,
|
||||
as_text=as_text,
|
||||
exports_to_keep=exports_to_keep)
|
||||
|
||||
|
||||
def _default_compare_fn(curr_best_eval_result, cand_eval_result):
|
||||
"""Compares two evaluation results and returns true if the 2nd one is better.
|
||||
|
||||
Both evaluation results should have the values for MetricKey.LOSS, which are
|
||||
used for comparison.
|
||||
|
||||
Args:
|
||||
curr_best_eval_result: current best eval metrics.
|
||||
cand_eval_result: candidate eval metrics.
|
||||
|
||||
Returns:
|
||||
True if cand_eval_result is better.
|
||||
|
||||
Raises:
|
||||
ValueError: If input eval result is None or no loss is available.
|
||||
"""
|
||||
default_key = metric_key.MetricKey.LOSS
|
||||
if not curr_best_eval_result or default_key not in curr_best_eval_result:
|
||||
raise ValueError(
|
||||
'curr_best_eval_result cannot be empty or no loss is found in it.')
|
||||
|
||||
if not cand_eval_result or default_key not in cand_eval_result:
|
||||
raise ValueError(
|
||||
'cand_eval_result cannot be empty or no loss is found in it.')
|
||||
|
||||
return curr_best_eval_result[default_key] > cand_eval_result[default_key]
|
||||
|
||||
|
||||
class BestModelSelector(object):
|
||||
"""A helper that keeps track of export selection candidates."""
|
||||
|
||||
def __init__(self, compare_fn=None):
|
||||
"""Constructor of this class.
|
||||
|
||||
Args:
|
||||
compare_fn: a function that returns true if the candidate is better than
|
||||
the current best model.
|
||||
"""
|
||||
self._best_eval_result = None
|
||||
self._compare_fn = compare_fn or _default_compare_fn
|
||||
|
||||
def update(self, checkpoint_path, eval_result):
|
||||
"""Records a given checkpoint and exports if this is the best model.
|
||||
|
||||
Args:
|
||||
checkpoint_path: the checkpoint path to export.
|
||||
eval_result: a dictionary which is usually generated in evaluation runs.
|
||||
By default, eval_results contains 'loss' field.
|
||||
|
||||
Returns:
|
||||
A string representing the path to the checkpoint to be exported.
|
||||
A dictionary of the same type of eval_result.
|
||||
|
||||
Raises:
|
||||
ValueError: if checkpoint path is empty.
|
||||
ValueError: if eval_results is None object.
|
||||
"""
|
||||
if not checkpoint_path:
|
||||
raise ValueError('Checkpoint path is empty.')
|
||||
if eval_result is None:
|
||||
raise ValueError('%s has empty evaluation results.', checkpoint_path)
|
||||
|
||||
if (self._best_eval_result is None or
|
||||
self._compare_fn(self._best_eval_result, eval_result)):
|
||||
self._best_eval_result = eval_result
|
||||
return checkpoint_path, eval_result
|
||||
else:
|
||||
return '', None
|
||||
|
||||
|
||||
def make_best_model_export_strategy(serving_input_fn,
|
||||
exports_to_keep=1,
|
||||
compare_fn=None,
|
||||
default_output_alternative_key=None):
|
||||
"""Creates an custom ExportStrategy for use with tf.contrib.learn.Experiment.
|
||||
|
||||
Args:
|
||||
serving_input_fn: a function that takes no arguments and returns an
|
||||
`InputFnOps`.
|
||||
exports_to_keep: an integer indicating how many historical best models need
|
||||
to be preserved.
|
||||
compare_fn: a function that select the 'best' candidate from a dictionary
|
||||
of evaluation result keyed by corresponding checkpoint path.
|
||||
default_output_alternative_key: the key for default serving signature for
|
||||
multi-headed inference graphs.
|
||||
|
||||
Returns:
|
||||
An ExportStrategy that can be passed to the Experiment constructor.
|
||||
"""
|
||||
best_model_export_strategy = make_export_strategy(
|
||||
serving_input_fn,
|
||||
exports_to_keep=exports_to_keep,
|
||||
default_output_alternative_key=default_output_alternative_key)
|
||||
|
||||
best_model_selector = BestModelSelector(compare_fn)
|
||||
|
||||
def export_fn(estimator, export_dir_base, checkpoint_path, eval_result=None):
|
||||
"""Exports the given Estimator as a SavedModel.
|
||||
|
||||
Args:
|
||||
estimator: the Estimator to export.
|
||||
export_dir_base: A string containing a directory to write the exported
|
||||
graph and checkpoints.
|
||||
checkpoint_path: The checkpoint path to export. If None (the default),
|
||||
the most recent checkpoint found within the model directory is chosen.
|
||||
eval_result: placehold args matching the call signature of ExportStrategy.
|
||||
|
||||
Returns:
|
||||
The string path to the exported directory.
|
||||
"""
|
||||
|
||||
export_checkpoint_path, export_eval_result = best_model_selector.update(
|
||||
checkpoint_path, eval_result)
|
||||
|
||||
if export_checkpoint_path and export_eval_result is not None:
|
||||
checkpoint_base = os.path.basename(export_checkpoint_path)
|
||||
export_dir = os.path.join(export_dir_base, checkpoint_base)
|
||||
return best_model_export_strategy.export(
|
||||
estimator, export_dir, export_checkpoint_path, export_eval_result)
|
||||
else:
|
||||
return ''
|
||||
|
||||
return export_strategy.ExportStrategy('best_model', export_fn)
|
||||
|
@ -24,6 +24,7 @@ import time
|
||||
from tensorflow.contrib.layers.python.layers import feature_column as fc
|
||||
from tensorflow.contrib.learn.python.learn import export_strategy as export_strategy_lib
|
||||
from tensorflow.contrib.learn.python.learn.estimators import constants
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator as core_estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators import model_fn
|
||||
from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
|
||||
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
|
||||
@ -40,18 +41,43 @@ from tensorflow.python.saved_model import signature_def_utils
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
class TestEstimator(core_estimator.Estimator):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(TestEstimator, self).__init__(*args, **kwargs)
|
||||
self.last_exported_checkpoint = ""
|
||||
self.last_exported_dir = ""
|
||||
|
||||
# @Override
|
||||
def export_savedmodel(self,
|
||||
export_dir,
|
||||
serving_input_fn,
|
||||
default_output_alternative_key=None,
|
||||
assets_extra=None,
|
||||
as_text=False,
|
||||
checkpoint_path=None):
|
||||
|
||||
if not os.path.exists(export_dir):
|
||||
os.makedirs(export_dir)
|
||||
|
||||
open(os.path.join(export_dir, "placeholder.txt"), "a").close()
|
||||
|
||||
self.last_exported_checkpoint = checkpoint_path
|
||||
self.last_exported_dir = export_dir
|
||||
|
||||
return export_dir
|
||||
|
||||
|
||||
class SavedModelExportUtilsTest(test.TestCase):
|
||||
|
||||
def test_build_standardized_signature_def_regression(self):
|
||||
input_tensors = {
|
||||
"input-1":
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, 1, name="input-tensor-1")
|
||||
array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1")
|
||||
}
|
||||
output_tensors = {
|
||||
"output-1":
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, 1, name="output-tensor-1")
|
||||
array_ops.placeholder(dtypes.float32, 1, name="output-tensor-1")
|
||||
}
|
||||
problem_type = constants.ProblemType.LINEAR_REGRESSION
|
||||
actual_signature_def = (
|
||||
@ -61,10 +87,9 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
shape = tensor_shape_pb2.TensorShapeProto(
|
||||
dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
|
||||
dtype = types_pb2.DataType.Value("DT_FLOAT")
|
||||
expected_signature_def.inputs[
|
||||
signature_constants.REGRESS_INPUTS].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="input-tensor-1:0", dtype=dtype, tensor_shape=shape))
|
||||
expected_signature_def.inputs[signature_constants.REGRESS_INPUTS].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="input-tensor-1:0", dtype=dtype, tensor_shape=shape))
|
||||
expected_signature_def.outputs[
|
||||
signature_constants.REGRESS_OUTPUTS].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
@ -77,13 +102,11 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
"""Tests classification with one output tensor."""
|
||||
input_tensors = {
|
||||
"input-1":
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, 1, name="input-tensor-1")
|
||||
array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1")
|
||||
}
|
||||
output_tensors = {
|
||||
"output-1":
|
||||
array_ops.placeholder(
|
||||
dtypes.string, 1, name="output-tensor-1")
|
||||
array_ops.placeholder(dtypes.string, 1, name="output-tensor-1")
|
||||
}
|
||||
problem_type = constants.ProblemType.CLASSIFICATION
|
||||
actual_signature_def = (
|
||||
@ -94,14 +117,14 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
|
||||
dtype_float = types_pb2.DataType.Value("DT_FLOAT")
|
||||
dtype_string = types_pb2.DataType.Value("DT_STRING")
|
||||
expected_signature_def.inputs[
|
||||
signature_constants.CLASSIFY_INPUTS].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
|
||||
expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
|
||||
expected_signature_def.outputs[
|
||||
signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="output-tensor-1:0", dtype=dtype_string,
|
||||
name="output-tensor-1:0",
|
||||
dtype=dtype_string,
|
||||
tensor_shape=shape))
|
||||
|
||||
expected_signature_def.method_name = (
|
||||
@ -112,8 +135,7 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
"""Tests multiple output tensors that include classes and probabilities."""
|
||||
input_tensors = {
|
||||
"input-1":
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, 1, name="input-tensor-1")
|
||||
array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1")
|
||||
}
|
||||
output_tensors = {
|
||||
"classes":
|
||||
@ -136,19 +158,20 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
|
||||
dtype_float = types_pb2.DataType.Value("DT_FLOAT")
|
||||
dtype_string = types_pb2.DataType.Value("DT_STRING")
|
||||
expected_signature_def.inputs[
|
||||
signature_constants.CLASSIFY_INPUTS].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
|
||||
expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
|
||||
expected_signature_def.outputs[
|
||||
signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="output-tensor-classes:0", dtype=dtype_string,
|
||||
name="output-tensor-classes:0",
|
||||
dtype=dtype_string,
|
||||
tensor_shape=shape))
|
||||
expected_signature_def.outputs[
|
||||
signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="output-tensor-proba:0", dtype=dtype_float,
|
||||
name="output-tensor-proba:0",
|
||||
dtype=dtype_float,
|
||||
tensor_shape=shape))
|
||||
|
||||
expected_signature_def.method_name = (
|
||||
@ -159,8 +182,7 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
"""Tests multiple output tensors that include classes and scores."""
|
||||
input_tensors = {
|
||||
"input-1":
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, 1, name="input-tensor-1")
|
||||
array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1")
|
||||
}
|
||||
output_tensors = {
|
||||
"classes":
|
||||
@ -182,19 +204,20 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
|
||||
dtype_float = types_pb2.DataType.Value("DT_FLOAT")
|
||||
dtype_string = types_pb2.DataType.Value("DT_STRING")
|
||||
expected_signature_def.inputs[
|
||||
signature_constants.CLASSIFY_INPUTS].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
|
||||
expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
|
||||
expected_signature_def.outputs[
|
||||
signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="output-tensor-classes:0", dtype=dtype_string,
|
||||
name="output-tensor-classes:0",
|
||||
dtype=dtype_string,
|
||||
tensor_shape=shape))
|
||||
expected_signature_def.outputs[
|
||||
signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="output-tensor-scores:0", dtype=dtype_float,
|
||||
name="output-tensor-scores:0",
|
||||
dtype=dtype_float,
|
||||
tensor_shape=shape))
|
||||
|
||||
expected_signature_def.method_name = (
|
||||
@ -205,8 +228,7 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
"""Tests classification without classes tensor."""
|
||||
input_tensors = {
|
||||
"input-1":
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, 1, name="input-tensor-1")
|
||||
array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1")
|
||||
}
|
||||
output_tensors = {
|
||||
"probabilities":
|
||||
@ -224,14 +246,14 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
shape = tensor_shape_pb2.TensorShapeProto(
|
||||
dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
|
||||
dtype_float = types_pb2.DataType.Value("DT_FLOAT")
|
||||
expected_signature_def.inputs[
|
||||
signature_constants.CLASSIFY_INPUTS].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
|
||||
expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
|
||||
expected_signature_def.outputs[
|
||||
signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="output-tensor-proba:0", dtype=dtype_float,
|
||||
name="output-tensor-proba:0",
|
||||
dtype=dtype_float,
|
||||
tensor_shape=shape))
|
||||
|
||||
expected_signature_def.method_name = (
|
||||
@ -246,8 +268,7 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
"""
|
||||
input_tensors = {
|
||||
"input-1":
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, 1, name="input-tensor-1")
|
||||
array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1")
|
||||
}
|
||||
output_tensors = {
|
||||
"classes":
|
||||
@ -268,14 +289,14 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
shape = tensor_shape_pb2.TensorShapeProto(
|
||||
dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
|
||||
dtype_float = types_pb2.DataType.Value("DT_FLOAT")
|
||||
expected_signature_def.inputs[
|
||||
signature_constants.CLASSIFY_INPUTS].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
|
||||
expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
|
||||
expected_signature_def.outputs[
|
||||
signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="output-tensor-scores:0", dtype=dtype_float,
|
||||
name="output-tensor-scores:0",
|
||||
dtype=dtype_float,
|
||||
tensor_shape=shape))
|
||||
|
||||
expected_signature_def.method_name = (
|
||||
@ -290,8 +311,7 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
"""
|
||||
input_tensors = {
|
||||
"input-1":
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, 1, name="input-tensor-1")
|
||||
array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1")
|
||||
}
|
||||
output_tensors = {
|
||||
"classes":
|
||||
@ -310,17 +330,18 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
|
||||
dtype_int64 = types_pb2.DataType.Value("DT_INT64")
|
||||
dtype_float = types_pb2.DataType.Value("DT_FLOAT")
|
||||
expected_signature_def.inputs[
|
||||
"input-1"].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
|
||||
expected_signature_def.inputs["input-1"].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
|
||||
expected_signature_def.outputs["classes"].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="output-tensor-classes:0", dtype=dtype_int64,
|
||||
name="output-tensor-classes:0",
|
||||
dtype=dtype_int64,
|
||||
tensor_shape=shape))
|
||||
expected_signature_def.outputs["logits"].CopyFrom(
|
||||
meta_graph_pb2.TensorInfo(
|
||||
name="output-tensor-logits:0", dtype=dtype_float,
|
||||
name="output-tensor-logits:0",
|
||||
dtype=dtype_float,
|
||||
tensor_shape=shape))
|
||||
|
||||
expected_signature_def.method_name = (
|
||||
@ -379,8 +400,9 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
def test_get_output_alternatives_single_no_default(self):
|
||||
prediction_tensor = constant_op.constant(["bogus"])
|
||||
provided_output_alternatives = {
|
||||
"head-1": (constants.ProblemType.LINEAR_REGRESSION,
|
||||
{"output": prediction_tensor}),
|
||||
"head-1": (constants.ProblemType.LINEAR_REGRESSION, {
|
||||
"output": prediction_tensor
|
||||
}),
|
||||
}
|
||||
model_fn_ops = model_fn.ModelFnOps(
|
||||
model_fn.ModeKeys.INFER,
|
||||
@ -390,10 +412,11 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
output_alternatives, _ = saved_model_export_utils.get_output_alternatives(
|
||||
model_fn_ops)
|
||||
|
||||
self.assertEqual({"head-1":
|
||||
(constants.ProblemType.LINEAR_REGRESSION,
|
||||
{"output": prediction_tensor})},
|
||||
output_alternatives)
|
||||
self.assertEqual({
|
||||
"head-1": (constants.ProblemType.LINEAR_REGRESSION, {
|
||||
"output": prediction_tensor
|
||||
})
|
||||
}, output_alternatives)
|
||||
|
||||
def test_get_output_alternatives_multi_no_default(self):
|
||||
provided_output_alternatives = {
|
||||
@ -424,10 +447,11 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
output_alternatives, _ = saved_model_export_utils.get_output_alternatives(
|
||||
model_fn_ops)
|
||||
|
||||
self.assertEqual(
|
||||
{"default_output_alternative": (constants.ProblemType.UNSPECIFIED, {
|
||||
"some_output": prediction_tensor})},
|
||||
output_alternatives)
|
||||
self.assertEqual({
|
||||
"default_output_alternative": (constants.ProblemType.UNSPECIFIED, {
|
||||
"some_output": prediction_tensor
|
||||
})
|
||||
}, output_alternatives)
|
||||
|
||||
def test_get_output_alternatives_empty_provided_with_default(self):
|
||||
prediction_tensor = constant_op.constant(["bogus"])
|
||||
@ -452,10 +476,11 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
output_alternatives, _ = saved_model_export_utils.get_output_alternatives(
|
||||
model_fn_ops)
|
||||
|
||||
self.assertEqual(
|
||||
{"default_output_alternative": (constants.ProblemType.UNSPECIFIED, {
|
||||
"some_output": prediction_tensor})},
|
||||
output_alternatives)
|
||||
self.assertEqual({
|
||||
"default_output_alternative": (constants.ProblemType.UNSPECIFIED, {
|
||||
"some_output": prediction_tensor
|
||||
})
|
||||
}, output_alternatives)
|
||||
|
||||
def test_get_output_alternatives_implicit_single(self):
|
||||
prediction_tensor = constant_op.constant(["bogus"])
|
||||
@ -506,14 +531,14 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
|
||||
expected_signature_defs = {
|
||||
"serving_default":
|
||||
signature_def_utils.regression_signature_def(input_example,
|
||||
output_1),
|
||||
signature_def_utils.regression_signature_def(
|
||||
input_example, output_1),
|
||||
"default_input_alternative:head-1":
|
||||
signature_def_utils.regression_signature_def(input_example,
|
||||
output_1),
|
||||
signature_def_utils.regression_signature_def(
|
||||
input_example, output_1),
|
||||
"default_input_alternative:head-2":
|
||||
signature_def_utils.classification_signature_def(input_example,
|
||||
output_2, None),
|
||||
signature_def_utils.classification_signature_def(
|
||||
input_example, output_2, None),
|
||||
"default_input_alternative:head-3":
|
||||
signature_def_utils.predict_signature_def({
|
||||
"default input": input_example
|
||||
@ -624,17 +649,20 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
(most_recent_export_dir, most_recent_export_version) = (
|
||||
saved_model_export_utils.get_most_recent_export(export_dir_base))
|
||||
|
||||
self.assertEqual(compat.as_bytes(export_dir_4),
|
||||
compat.as_bytes(most_recent_export_dir))
|
||||
self.assertEqual(compat.as_bytes(export_dir_4),
|
||||
os.path.join(compat.as_bytes(export_dir_base),
|
||||
compat.as_bytes(
|
||||
str(most_recent_export_version))))
|
||||
self.assertEqual(
|
||||
compat.as_bytes(export_dir_4), compat.as_bytes(most_recent_export_dir))
|
||||
self.assertEqual(
|
||||
compat.as_bytes(export_dir_4),
|
||||
os.path.join(
|
||||
compat.as_bytes(export_dir_base),
|
||||
compat.as_bytes(str(most_recent_export_version))))
|
||||
|
||||
def test_make_export_strategy(self):
|
||||
"""Only tests that an ExportStrategy instance is created."""
|
||||
|
||||
def _serving_input_fn():
|
||||
return array_ops.constant([1]), None
|
||||
|
||||
export_strategy = saved_model_export_utils.make_export_strategy(
|
||||
serving_input_fn=_serving_input_fn,
|
||||
default_output_alternative_key="default",
|
||||
@ -655,14 +683,61 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
real_valued_col1 = fc.real_valued_column("real_valued_column1")
|
||||
bucketized_col1 = fc.bucketized_column(
|
||||
fc.real_valued_column("real_valued_column_for_bucketization1"), [0, 4])
|
||||
feature_columns = [sparse_col, embedding_col, real_valued_col1,
|
||||
bucketized_col1]
|
||||
feature_columns = [
|
||||
sparse_col, embedding_col, real_valued_col1, bucketized_col1
|
||||
]
|
||||
|
||||
export_strategy = saved_model_export_utils.make_parsing_export_strategy(
|
||||
feature_columns=feature_columns)
|
||||
self.assertTrue(
|
||||
isinstance(export_strategy, export_strategy_lib.ExportStrategy))
|
||||
|
||||
def test_make_best_model_export_strategy(self):
|
||||
export_dir_base = tempfile.mkdtemp() + "export/"
|
||||
gfile.MkDir(export_dir_base)
|
||||
|
||||
test_estimator = TestEstimator()
|
||||
export_strategy = saved_model_export_utils.make_best_model_export_strategy(
|
||||
serving_input_fn=None, exports_to_keep=3, compare_fn=None)
|
||||
|
||||
self.assertNotEqual("",
|
||||
export_strategy.export(test_estimator, export_dir_base,
|
||||
"fake_ckpt_0", {"loss": 100}))
|
||||
self.assertNotEqual("", test_estimator.last_exported_dir)
|
||||
self.assertNotEqual("", test_estimator.last_exported_checkpoint)
|
||||
|
||||
self.assertEqual("",
|
||||
export_strategy.export(test_estimator, export_dir_base,
|
||||
"fake_ckpt_1", {"loss": 101}))
|
||||
self.assertEqual(test_estimator.last_exported_dir,
|
||||
os.path.join(export_dir_base, "fake_ckpt_0"))
|
||||
|
||||
self.assertNotEqual("",
|
||||
export_strategy.export(test_estimator, export_dir_base,
|
||||
"fake_ckpt_2", {"loss": 10}))
|
||||
self.assertEqual(test_estimator.last_exported_dir,
|
||||
os.path.join(export_dir_base, "fake_ckpt_2"))
|
||||
|
||||
self.assertEqual("",
|
||||
export_strategy.export(test_estimator, export_dir_base,
|
||||
"fake_ckpt_3", {"loss": 20}))
|
||||
self.assertEqual(test_estimator.last_exported_dir,
|
||||
os.path.join(export_dir_base, "fake_ckpt_2"))
|
||||
|
||||
def test_make_best_model_export_strategy_exceptions(self):
|
||||
export_dir_base = tempfile.mkdtemp() + "export/"
|
||||
|
||||
test_estimator = TestEstimator()
|
||||
export_strategy = saved_model_export_utils.make_best_model_export_strategy(
|
||||
serving_input_fn=None, exports_to_keep=3, compare_fn=None)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
export_strategy.export(test_estimator, export_dir_base, "", {"loss": 200})
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
export_strategy.export(test_estimator, export_dir_base, "fake_ckpt_1",
|
||||
None)
|
||||
|
||||
|
||||
def _create_test_export_dir(export_dir_base):
|
||||
export_dir = saved_model_export_utils.get_timestamped_export_dir(
|
||||
|
71
tensorflow/contrib/receptive_field/BUILD
Normal file
71
tensorflow/contrib/receptive_field/BUILD
Normal file
@ -0,0 +1,71 @@
|
||||
# Description:
|
||||
# Contains modules to compute receptive field parameters for CNN models.
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
# Transitive dependencies of this target will be included in the pip package.
|
||||
py_library(
|
||||
name = "receptive_field_pip",
|
||||
deps = [
|
||||
":graph_compute_order_py",
|
||||
":receptive_field_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "graph_compute_order_py",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"python/util/graph_compute_order.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "receptive_field_py",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"python/util/receptive_field.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":graph_compute_order_py",
|
||||
"//tensorflow/contrib/util:util_py",
|
||||
"//tensorflow/python:platform",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "receptive_field_test",
|
||||
srcs = ["python/util/receptive_field_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":receptive_field_py",
|
||||
"//tensorflow/contrib/framework:framework_py",
|
||||
"//tensorflow/contrib/slim",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:nn",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
165
tensorflow/contrib/receptive_field/README.md
Normal file
165
tensorflow/contrib/receptive_field/README.md
Normal file
@ -0,0 +1,165 @@
|
||||
# Receptive field computation for convnets
|
||||
|
||||
This library enables you to easily compute the receptive field parameters of
|
||||
your favorite convnet. You can use it to understand how big of an input image
|
||||
region your output features depend on. Better yet, using the parameters computed
|
||||
by the library, you can easily find the exact image region which is used to
|
||||
compute each convnet feature.
|
||||
|
||||
## Basic usage
|
||||
|
||||
The main function to be called is `compute_receptive_field_from_graph_def`,
|
||||
which will return the receptive field, effective stride and effective padding
|
||||
for both horizontal and vertical directions.
|
||||
|
||||
For example, if your model is constructed using the function
|
||||
`my_model_construction()`, you can use the library as follows:
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib import receptive_field
|
||||
|
||||
# Construct graph.
|
||||
g = tf.Graph()
|
||||
with g.as_default():
|
||||
images = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='input_image')
|
||||
my_model_construction(images)
|
||||
|
||||
# Compute receptive field parameters.
|
||||
rf_x, rf_y, eff_stride_x, eff_stride_y, eff_pad_x, eff_pad_y = \
|
||||
receptive_field.compute_receptive_field_from_graph_def( \
|
||||
g.as_graph_def(), 'input_image', 'my_output_endpoint')
|
||||
```
|
||||
|
||||
Here's a simple example of computing the receptive field parameters for
|
||||
Inception-Resnet-v2. To get this to work, be sure to checkout
|
||||
[tensorflow/models](https://github.com/tensorflow/models), so that the Inception
|
||||
models are available to you. This can be done in three simple commands:
|
||||
|
||||
```sh
|
||||
git clone https://github.com/tensorflow/models
|
||||
cd models/slim
|
||||
sudo python setup.py install_lib
|
||||
```
|
||||
|
||||
You can then compute the receptive field parameters for Inception-Resnet-v2 as:
|
||||
|
||||
```python
|
||||
from nets import inception
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib import receptive_field
|
||||
|
||||
# Construct graph.
|
||||
g = tf.Graph()
|
||||
with g.as_default():
|
||||
images = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='input_image')
|
||||
inception.inception_resnet_v2_base(images)
|
||||
|
||||
# Compute receptive field parameters.
|
||||
rf_x, rf_y, eff_stride_x, eff_stride_y, eff_pad_x, eff_pad_y = \
|
||||
receptive_field.compute_receptive_field_from_graph_def( \
|
||||
g.as_graph_def(), 'input_image', 'InceptionResnetV2/Conv2d_7b_1x1/Relu')
|
||||
```
|
||||
|
||||
This will give you `rf_x = rf_y = 3039`, `eff_stride_x = eff_stride_y = 32`, and
|
||||
`eff_pad_x = eff_pad_y = 1482`. This means that each feature that is output at
|
||||
the node `'InceptionResnetV2/Conv2d_7b_1x1/Relu'` is computed from a region
|
||||
which is of size `3039x3039`. Further, by using the expressions
|
||||
|
||||
```python
|
||||
center_x = -eff_pad_x + feature_x*eff_stride_x + (rf_x - 1)/2
|
||||
center_y = -eff_pad_y + feature_y*eff_stride_y + (rf_y - 1)/2
|
||||
```
|
||||
|
||||
one can compute the center of the region in the input image that is used to
|
||||
compute the output feature at position `[feature_x, feature_y]`. For example,
|
||||
the feature at position `[0, 2]` at the output of the layer
|
||||
`'InceptionResnetV2/Conv2d_7b_1x1/Relu'` is centered in the original image in
|
||||
the position `[37, 101]`.
|
||||
|
||||
TODO: include link to derivations and definitions of different parameters.
|
||||
|
||||
## Receptive field benchmark
|
||||
|
||||
As you might expect, it is straightforward to run this library on the popular
|
||||
convnets, and gather their receptive fields. We provide a python script which
|
||||
does exactly that, available under `python/util/examples/rf_benchmark.py`.
|
||||
|
||||
To get this to work, be sure to checkout
|
||||
[tensorflow/models](https://github.com/tensorflow/models) (see the 3-command
|
||||
instructions for this above). Then, simply:
|
||||
|
||||
```sh
|
||||
cd python/util/examples
|
||||
python rf_benchmark.py --csv_path /tmp/rf_benchmark_results.csv
|
||||
```
|
||||
|
||||
The script will write to stdout the receptive field parameters for many variants
|
||||
of several popular convnets: AlexNet, VGG, ResNet, Inception, Mobilenet. They
|
||||
are also written to the file `/tmp/rf_benchmark_results.csv`.
|
||||
|
||||
TODO: include here a plot for receptive field sizes of different convnets.
|
||||
|
||||
TODO: include table/link to pre-computed RF parameters.
|
||||
|
||||
## Compute RF parameters from a graph pbtxt
|
||||
|
||||
We also provide a utility to compute the receptive field parameters directly
|
||||
from a graph protobuf file.
|
||||
|
||||
Have a `graph.pbtxt` file and want to compute its receptive field parameters? We
|
||||
got you covered. The only prerequisite is to install
|
||||
[google/protobuf](https://github.com/google/protobuf), which you probably
|
||||
already have if you're using tensorflow (otherwise, follow installation
|
||||
instructions [here](https://github.com/google/protobuf/tree/master/python)).
|
||||
|
||||
This should work:
|
||||
|
||||
```sh
|
||||
cd python/util/examples
|
||||
python compute_rf.py \
|
||||
--graph_path /path/to/graph.pbtxt \
|
||||
--output_path /path/to/output/rf_info.txt \
|
||||
--input_node my_input_node \
|
||||
--output_node my_output_node
|
||||
```
|
||||
|
||||
Don't know how to generate a graph protobuf file? Take a look at the
|
||||
`write_inception_resnet_v2_graph.py` script, which shows how to save it for the
|
||||
Inception-Resnet-v2 model:
|
||||
|
||||
```sh
|
||||
cd python/util/examples
|
||||
python write_inception_resnet_v2_graph.py --graph_dir /tmp --graph_filename graph.pbtxt
|
||||
```
|
||||
|
||||
This will write the Inception-Resnet-v2 graph protobuf to `/tmp/graph.pbtxt`.
|
||||
|
||||
For completeness, here's how you would use this file to get the receptive field
|
||||
parameters of the Inception-Resnet-v2 model:
|
||||
|
||||
```sh
|
||||
cd python/util/examples
|
||||
python compute_rf.py \
|
||||
--graph_path /tmp/graph.pbtxt \
|
||||
--output_path /tmp/rf_info.txt \
|
||||
--input_node input_image \
|
||||
--output_node InceptionResnetV2/Conv2d_7b_1x1/Relu
|
||||
```
|
||||
|
||||
This will write the receptive field parameters of the model to
|
||||
`/tmp/rf_info.txt`, which will look like:
|
||||
|
||||
```sh
|
||||
Receptive field size (horizontal) = 3039
|
||||
Receptive field size (vertical) = 3039
|
||||
Effective stride (horizontal) = 32
|
||||
Effective stride (vertical) = 32
|
||||
Effective padding (horizontal) = 1482
|
||||
Effective padding (vertical) = 1482
|
||||
```
|
||||
|
||||
## Authors
|
||||
|
||||
André Araujo (github id: andrefaraujo) and Mark Sandler (github id:
|
||||
marksandler)
|
23
tensorflow/contrib/receptive_field/__init__.py
Normal file
23
tensorflow/contrib/receptive_field/__init__.py
Normal file
@ -0,0 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Module to compute receptive field parameters for CNN tensorflow models."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.contrib.receptive_field.python.util.graph_compute_order import get_compute_order
|
||||
from tensorflow.contrib.receptive_field.python.util.receptive_field import compute_receptive_field_from_graph_def
|
||||
# pylint: enable=unused-import
|
19
tensorflow/contrib/receptive_field/python/__init__.py
Normal file
19
tensorflow/contrib/receptive_field/python/__init__.py
Normal file
@ -0,0 +1,19 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Module to compute receptive field parameters for CNN tensorflow models."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
@ -0,0 +1,94 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Computes Receptive Field (RF) information given a graph protobuf.
|
||||
|
||||
For an example of usage, see accompanying file compute_rf.sh
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from google.protobuf import text_format
|
||||
|
||||
from tensorflow.contrib import receptive_field
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.python.platform import app
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
cmd_args = None
|
||||
|
||||
|
||||
def _load_graphdef(path):
  """Helper function to load GraphDef from file.

  Args:
    path: Path to pbtxt file.

  Returns:
    graph_def: A GraphDef object.
  """
  graph_def = graph_pb2.GraphDef()
  # Use a context manager so the file handle is always closed; the original
  # `gfile.Open(path).read()` left the handle open until garbage collection.
  with gfile.Open(path) as f:
    pbstr = f.read()
  text_format.Parse(pbstr, graph_def)
  return graph_def
|
||||
|
||||
|
||||
def main(unused_argv):
  """Loads a graph, computes its RF parameters and writes them to a file.

  Reads the pbtxt graph from --graph_path, computes the receptive field
  parameters between --input_node and --output_node, logs them, and writes
  them to --output_path.
  """
  graph_def = _load_graphdef(cmd_args.graph_path)

  (receptive_field_x, receptive_field_y, effective_stride_x, effective_stride_y,
   effective_padding_x, effective_padding_y
  ) = receptive_field.compute_receptive_field_from_graph_def(
      graph_def, cmd_args.input_node, cmd_args.output_node)

  logging.info('Receptive field size (horizontal) = %s', receptive_field_x)
  logging.info('Receptive field size (vertical) = %s', receptive_field_y)
  logging.info('Effective stride (horizontal) = %s', effective_stride_x)
  logging.info('Effective stride (vertical) = %s', effective_stride_y)
  logging.info('Effective padding (horizontal) = %s', effective_padding_x)
  logging.info('Effective padding (vertical) = %s', effective_padding_y)

  # The context manager guarantees the output file is closed even if a write
  # fails; the original relied on an explicit close() skipped on error. The
  # redundant "'%s' % path" formatting is also dropped.
  with gfile.GFile(cmd_args.output_path, 'w') as f:
    f.write('Receptive field size (horizontal) = %s\n' % receptive_field_x)
    f.write('Receptive field size (vertical) = %s\n' % receptive_field_y)
    f.write('Effective stride (horizontal) = %s\n' % effective_stride_x)
    f.write('Effective stride (vertical) = %s\n' % effective_stride_y)
    f.write('Effective padding (horizontal) = %s\n' % effective_padding_x)
    f.write('Effective padding (vertical) = %s\n' % effective_padding_y)
|
||||
|
||||
|
||||
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  # Allow boolean flags to be passed as --flag=true / --flag=false.
  parser.register('type', 'bool', lambda v: v.lower() == 'true')
  parser.add_argument(
      '--graph_path', type=str, default='', help='Graph path (pbtxt format).')
  parser.add_argument(
      '--output_path',
      type=str,
      default='',
      help='Path to output text file where RF information will be written to.')
  parser.add_argument(
      '--input_node', type=str, default='', help='Name of input node.')
  parser.add_argument(
      '--output_node', type=str, default='', help='Name of output node.')
  # Flags this parser does not recognize are forwarded to app.run() so that
  # TensorFlow's own flags continue to work.
  cmd_args, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
@ -0,0 +1,460 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Computes Receptive Field (RF) information for different models.
|
||||
|
||||
The receptive field (and related parameters) for the different models are
|
||||
printed to stdout, and may also optionally be written to a CSV file.
|
||||
|
||||
For an example of usage, see rf_benchmark.sh
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import sys
|
||||
|
||||
from nets import alexnet
|
||||
from nets import inception
|
||||
from nets import mobilenet_v1
|
||||
from nets import resnet_v1
|
||||
from nets import resnet_v2
|
||||
from nets import vgg
|
||||
from tensorflow.contrib import framework
|
||||
from tensorflow.contrib import receptive_field
|
||||
from tensorflow.contrib import slim
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import app
|
||||
|
||||
cmd_args = None
|
||||
|
||||
# Input node name for all architectures.
|
||||
_INPUT_NODE = 'input_image'
|
||||
|
||||
# Variants of different network architectures.
|
||||
|
||||
# - resnet: different versions and sizes.
|
||||
_SUPPORTED_RESNET_VARIANTS = [
|
||||
'resnet_v1_50', 'resnet_v1_101', 'resnet_v1_152', 'resnet_v1_200',
|
||||
'resnet_v2_50', 'resnet_v2_101', 'resnet_v2_152', 'resnet_v2_200'
|
||||
]
|
||||
|
||||
# - inception_resnet_v2: default, and version with SAME padding.
|
||||
_SUPPORTED_INCEPTIONRESNETV2_VARIANTS = [
|
||||
'inception_resnet_v2', 'inception_resnet_v2-same'
|
||||
]
|
||||
|
||||
# - inception_v2: default, and version with no separable conv.
|
||||
_SUPPORTED_INCEPTIONV2_VARIANTS = [
|
||||
'inception_v2', 'inception_v2-no-separable-conv'
|
||||
]
|
||||
|
||||
# - inception_v3: default version.
|
||||
_SUPPORTED_INCEPTIONV3_VARIANTS = ['inception_v3']
|
||||
|
||||
# - inception_v4: default version.
|
||||
_SUPPORTED_INCEPTIONV4_VARIANTS = ['inception_v4']
|
||||
|
||||
# - alexnet_v2: default version.
|
||||
_SUPPORTED_ALEXNETV2_VARIANTS = ['alexnet_v2']
|
||||
|
||||
# - vgg: vgg_a (with 11 layers) and vgg_16 (version D).
|
||||
_SUPPORTED_VGG_VARIANTS = ['vgg_a', 'vgg_16']
|
||||
|
||||
# - mobilenet_v1: 100% and 75%.
|
||||
_SUPPORTED_MOBILENETV1_VARIANTS = ['mobilenet_v1', 'mobilenet_v1_075']
|
||||
|
||||
|
||||
def _construct_model(model_type='resnet_v1_50'):
  """Constructs model for the desired type of CNN.

  Args:
    model_type: Type of model to be used.

  Returns:
    end_points: A dictionary from components of the network to the
      corresponding activations.

  Raises:
    ValueError: If the model_type is not supported.
  """
  # Placeholder input.
  images = array_ops.placeholder(
      dtypes.float32, shape=(1, None, None, 3), name=_INPUT_NODE)

  # Table mapping each supported model type to a zero-argument builder that
  # returns (net, end_points) for the placeholder input.
  builders = {
      'inception_resnet_v2':
          lambda: inception.inception_resnet_v2_base(images),
      'inception_resnet_v2-same':
          lambda: inception.inception_resnet_v2_base(
              images, align_feature_maps=True),
      'inception_v2':
          lambda: inception.inception_v2_base(images),
      'inception_v2-no-separable-conv':
          lambda: inception.inception_v2_base(
              images, use_separable_conv=False),
      'inception_v3':
          lambda: inception.inception_v3_base(images),
      'inception_v4':
          lambda: inception.inception_v4_base(images),
      'alexnet_v2':
          lambda: alexnet.alexnet_v2(images),
      'vgg_a':
          lambda: vgg.vgg_a(images),
      'vgg_16':
          lambda: vgg.vgg_16(images),
      'mobilenet_v1':
          lambda: mobilenet_v1.mobilenet_v1_base(images),
      'mobilenet_v1_075':
          lambda: mobilenet_v1.mobilenet_v1_base(
              images, depth_multiplier=0.75),
      'resnet_v1_50':
          lambda: resnet_v1.resnet_v1_50(
              images, num_classes=None, is_training=False, global_pool=False),
      'resnet_v1_101':
          lambda: resnet_v1.resnet_v1_101(
              images, num_classes=None, is_training=False, global_pool=False),
      'resnet_v1_152':
          lambda: resnet_v1.resnet_v1_152(
              images, num_classes=None, is_training=False, global_pool=False),
      'resnet_v1_200':
          lambda: resnet_v1.resnet_v1_200(
              images, num_classes=None, is_training=False, global_pool=False),
      'resnet_v2_50':
          lambda: resnet_v2.resnet_v2_50(
              images, num_classes=None, is_training=False, global_pool=False),
      'resnet_v2_101':
          lambda: resnet_v2.resnet_v2_101(
              images, num_classes=None, is_training=False, global_pool=False),
      'resnet_v2_152':
          lambda: resnet_v2.resnet_v2_152(
              images, num_classes=None, is_training=False, global_pool=False),
      'resnet_v2_200':
          lambda: resnet_v2.resnet_v2_200(
              images, num_classes=None, is_training=False, global_pool=False),
  }

  if model_type not in builders:
    raise ValueError('Unsupported model_type %s.' % model_type)
  _, end_points = builders[model_type]()
  return end_points
|
||||
|
||||
|
||||
def _get_desired_end_point_keys(model_type='resnet_v1_50'):
  """Gets list of desired end point keys for a type of CNN.

  Args:
    model_type: Type of model to be used.

  Returns:
    desired_end_point_types: A list containing the desired end-points.

  Raises:
    ValueError: If the model_type is not supported.
  """
  if model_type in _SUPPORTED_RESNET_VARIANTS:
    blocks = ['block1', 'block2', 'block3', 'block4']
    return ['%s/%s' % (model_type, b) for b in blocks]

  if model_type in _SUPPORTED_INCEPTIONRESNETV2_VARIANTS:
    return [
        'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'MaxPool_3a_3x3',
        'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3', 'Mixed_5b',
        'Mixed_6a', 'PreAuxLogits', 'Mixed_7a', 'Conv2d_7b_1x1'
    ]

  if model_type in _SUPPORTED_INCEPTIONV2_VARIANTS:
    return [
        'Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 'Conv2d_2c_3x3',
        'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 'Mixed_4a', 'Mixed_4b',
        'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c'
    ]

  if model_type in _SUPPORTED_INCEPTIONV3_VARIANTS:
    return [
        'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'MaxPool_3a_3x3',
        'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3', 'Mixed_5b',
        'Mixed_5c', 'Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d',
        'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c'
    ]

  if model_type in _SUPPORTED_INCEPTIONV4_VARIANTS:
    return [
        'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'Mixed_3a',
        'Mixed_4a', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 'Mixed_5e',
        'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e', 'Mixed_6f',
        'Mixed_6g', 'Mixed_6h', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c', 'Mixed_7d'
    ]

  if model_type in _SUPPORTED_ALEXNETV2_VARIANTS:
    ep = ['conv1', 'pool1', 'conv2', 'conv3', 'conv4', 'conv5', 'pool5']
    return ['%s/%s' % (model_type, e) for e in ep]

  if model_type in _SUPPORTED_VGG_VARIANTS:
    ep = [
        'conv1/conv1_1', 'pool1', 'conv2/conv2_1', 'pool2', 'conv3/conv3_1',
        'conv3/conv3_2', 'pool3', 'conv4/conv4_1', 'conv4/conv4_2', 'pool4',
        'conv5/conv5_1', 'conv5/conv5_2', 'pool5'
    ]
    return ['%s/%s' % (model_type, e) for e in ep]

  if model_type in _SUPPORTED_MOBILENETV1_VARIANTS:
    return [
        'Conv2d_0', 'Conv2d_1_pointwise', 'Conv2d_2_pointwise',
        'Conv2d_3_pointwise', 'Conv2d_4_pointwise', 'Conv2d_5_pointwise',
        'Conv2d_6_pointwise', 'Conv2d_7_pointwise', 'Conv2d_8_pointwise',
        'Conv2d_9_pointwise', 'Conv2d_10_pointwise', 'Conv2d_11_pointwise',
        'Conv2d_12_pointwise', 'Conv2d_13_pointwise'
    ]

  raise ValueError('Unsupported model_type %s.' % model_type)
|
||||
|
||||
|
||||
def _model_graph_def(model_type='resnet_v1_50', arg_sc=None):
  """Constructs a model graph, returning GraphDef and end-points.

  Args:
    model_type: Type of model to be used.
    arg_sc: Optional arg scope to use in constructing the graph.

  Returns:
    graph_def: GraphDef of constructed graph.
    end_points: A dictionary from components of the network to the
      corresponding activations.
  """
  # Default to an empty arg scope when none is supplied.
  if arg_sc is None:
    arg_sc = {}
  graph = ops.Graph()
  with graph.as_default():
    with framework.arg_scope(arg_sc):
      end_points = _construct_model(model_type)
  return graph.as_graph_def(), end_points
|
||||
|
||||
|
||||
def _model_rf(graphdef,
              end_points,
              desired_end_point_keys,
              model_type='resnet_v1_50',
              csv_writer=None):
  """Computes receptive field information for a given CNN model.

  The information will be printed to stdout. If the RF parameters are the same
  for the horizontal and vertical directions, it will be printed only once.
  Otherwise, they are printed once for the horizontal and once for the vertical
  directions.

  Args:
    graphdef: GraphDef of given model.
    end_points: A dictionary from components of the model to the corresponding
      activations.
    desired_end_point_keys: List of desired end points for which receptive
      field information will be computed.
    model_type: Type of model to be used, used only for printing purposes.
    csv_writer: A CSV writer for RF parameters, which is used if it is not
      None.
  """
  for desired_end_point_key in desired_end_point_keys:
    print('- %s:' % desired_end_point_key)
    output_node_with_colon = end_points[desired_end_point_key].name
    # Strip the tensor's output-index suffix (e.g. ':0') to obtain the node
    # name. split() is used instead of the original rfind(): when a name
    # unexpectedly had no colon, rfind() returned -1 and the slice silently
    # dropped the last character of the name.
    output_node = output_node_with_colon.split(':')[0]
    (receptive_field_x, receptive_field_y, effective_stride_x,
     effective_stride_y, effective_padding_x, effective_padding_y
    ) = receptive_field.compute_receptive_field_from_graph_def(
        graphdef, _INPUT_NODE, output_node)
    # If values are the same in horizontal/vertical directions, just report one
    # of them. Otherwise, report both.
    if (receptive_field_x == receptive_field_y) and (
        effective_stride_x == effective_stride_y) and (
            effective_padding_x == effective_padding_y):
      print('Receptive field size = %5s, effective stride = %5s, effective '
            'padding = %5s' % (str(receptive_field_x), str(effective_stride_x),
                               str(effective_padding_x)))
    else:
      print('Receptive field size: horizontal = %5s, vertical = %5s. '
            'Effective stride: horizontal = %5s, vertical = %5s. Effective '
            'padding: horizontal = %5s, vertical = %5s' %
            (str(receptive_field_x), str(receptive_field_y),
             str(effective_stride_x), str(effective_stride_y),
             str(effective_padding_x), str(effective_padding_y)))
    # Optionally append one CSV row per end-point.
    if csv_writer is not None:
      csv_writer.writerow({
          'CNN': model_type,
          'end_point': desired_end_point_key,
          'RF size hor': str(receptive_field_x),
          'RF size ver': str(receptive_field_y),
          'effective stride hor': str(effective_stride_x),
          'effective stride ver': str(effective_stride_y),
          'effective padding hor': str(effective_padding_x),
          'effective padding ver': str(effective_padding_y)
      })
|
||||
|
||||
|
||||
def _process_model_rf(model_type='resnet_v1_50', csv_writer=None, arg_sc=None):
  """Constructs model graph and desired end-points, and computes RF.

  The computed RF parameters are printed to stdout by the _model_rf function.

  Args:
    model_type: Type of model to be used.
    csv_writer: A CSV writer for RF parameters, which is used if it is not
      None.
    arg_sc: Optional arg scope to use in constructing the graph.
  """
  print('********************%s' % model_type)
  graphdef, end_points = _model_graph_def(model_type, arg_sc)
  end_point_keys = _get_desired_end_point_keys(model_type)
  _model_rf(graphdef, end_points, end_point_keys, model_type, csv_writer)
|
||||
|
||||
|
||||
def _resnet_rf(csv_writer=None):
  """Computes RF and associated parameters for resnet models.

  The computed values are written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not
      None.
  """
  for variant in _SUPPORTED_RESNET_VARIANTS:
    # The same resnet_v1 arg scope is applied to both v1 and v2 variants,
    # matching the original behavior.
    _process_model_rf(variant, csv_writer, resnet_v1.resnet_arg_scope())
|
||||
|
||||
|
||||
def _inception_resnet_v2_rf(csv_writer=None):
  """Computes RF and associated parameters for the inception_resnet_v2 model.

  The computed values are written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not
      None.
  """
  for variant in _SUPPORTED_INCEPTIONRESNETV2_VARIANTS:
    _process_model_rf(variant, csv_writer)
|
||||
|
||||
|
||||
def _inception_v2_rf(csv_writer=None):
  """Computes RF and associated parameters for the inception_v2 model.

  The computed values are written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not
      None.
  """
  for variant in _SUPPORTED_INCEPTIONV2_VARIANTS:
    _process_model_rf(variant, csv_writer)
|
||||
|
||||
|
||||
def _inception_v3_rf(csv_writer=None):
  """Computes RF and associated parameters for the inception_v3 model.

  The computed values are written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not
      None.
  """
  for variant in _SUPPORTED_INCEPTIONV3_VARIANTS:
    _process_model_rf(variant, csv_writer)
|
||||
|
||||
|
||||
def _inception_v4_rf(csv_writer=None):
  """Computes RF and associated parameters for the inception_v4 model.

  The computed values are written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not
      None.
  """
  for variant in _SUPPORTED_INCEPTIONV4_VARIANTS:
    _process_model_rf(variant, csv_writer)
|
||||
|
||||
|
||||
def _alexnet_v2_rf(csv_writer=None):
  """Computes RF and associated parameters for the alexnet_v2 model.

  The computed values are written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not
      None.
  """
  for variant in _SUPPORTED_ALEXNETV2_VARIANTS:
    _process_model_rf(variant, csv_writer)
|
||||
|
||||
|
||||
def _vgg_rf(csv_writer=None):
  """Computes RF and associated parameters for the vgg model.

  The computed values are written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not
      None.
  """
  for variant in _SUPPORTED_VGG_VARIANTS:
    _process_model_rf(variant, csv_writer)
|
||||
|
||||
|
||||
def _mobilenet_v1_rf(csv_writer=None):
  """Computes RF and associated parameters for the mobilenet_v1 model.

  The computed values are written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not
      None.
  """
  for variant in _SUPPORTED_MOBILENETV1_VARIANTS:
    # Batch norm and dropout are put in inference mode so the constructed
    # graph matches what would be deployed.
    with slim.arg_scope([slim.batch_norm, slim.dropout],
                        is_training=False) as scope:
      _process_model_rf(variant, csv_writer, scope)
|
||||
|
||||
|
||||
def main(unused_argv):
  """Computes RF parameters for all supported architectures.

  Results are printed to stdout and, if --csv_path is non-empty, also written
  to a CSV file at that path.
  """
  csv_file = None
  try:
    # Configure CSV file which will be written, if desired.
    if cmd_args.csv_path:
      csv_file = open(cmd_args.csv_path, 'w')
      field_names = [
          'CNN', 'end_point', 'RF size hor', 'RF size ver',
          'effective stride hor', 'effective stride ver',
          'effective padding hor', 'effective padding ver'
      ]
      rf_writer = csv.DictWriter(csv_file, fieldnames=field_names)
      rf_writer.writeheader()
    else:
      rf_writer = None

    # Compute RF parameters for each network architecture.
    _alexnet_v2_rf(rf_writer)
    _vgg_rf(rf_writer)
    _inception_v2_rf(rf_writer)
    _inception_v3_rf(rf_writer)
    _inception_v4_rf(rf_writer)
    _inception_resnet_v2_rf(rf_writer)
    _mobilenet_v1_rf(rf_writer)
    _resnet_rf(rf_writer)
  finally:
    # Close the CSV file even if one of the computations raises; the original
    # only closed it on the success path.
    if csv_file is not None:
      csv_file.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  # Allow boolean flags to be passed as --flag=true / --flag=false.
  parser.register('type', 'bool', lambda v: v.lower() == 'true')
  parser.add_argument(
      '--csv_path',
      type=str,
      default='',
      # Fixed a typo in the user-visible help text: "parameters.If" was
      # missing a space after the period.
      help="""\
      Path to CSV file that will be written with RF parameters. If empty, no
      file will be written.\
      """)
  # Unrecognized flags are forwarded to app.run() so TensorFlow's own flags
  # continue to work.
  cmd_args, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
@ -0,0 +1,61 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Simple script to write Inception-ResNet-v2 model to graph file.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from nets import inception
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import graph_io
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import app
|
||||
|
||||
cmd_args = None
|
||||
|
||||
|
||||
def main(unused_argv):
  """Builds the Inception-ResNet-v2 base model and writes its GraphDef.

  The graph is written in pbtxt form to --graph_dir/--graph_filename.
  """
  # Model definition: a variable-sized single-image placeholder feeding the
  # Inception-ResNet-v2 base network.
  graph = ops.Graph()
  with graph.as_default():
    input_image = array_ops.placeholder(
        dtypes.float32, shape=(1, None, None, 3), name='input_image')
    inception.inception_resnet_v2_base(input_image)

  graph_io.write_graph(graph.as_graph_def(), cmd_args.graph_dir,
                       cmd_args.graph_filename)
|
||||
|
||||
|
||||
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  # Allow boolean flags to be passed as --flag=true / --flag=false.
  parser.register('type', 'bool', lambda v: v.lower() == 'true')
  parser.add_argument(
      '--graph_dir',
      type=str,
      default='/tmp',
      help='Directory where graph will be saved.')
  parser.add_argument(
      '--graph_filename',
      type=str,
      default='graph.pbtxt',
      help='Filename of graph that will be saved.')
  # Unrecognized flags are forwarded to app.run() so TensorFlow's own flags
  # continue to work.
  cmd_args, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
@ -0,0 +1,88 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Library to compute order of computations in a graph.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
|
||||
class GraphDefHelper(object):
  """Helper class to collect node names and definitions.

  Example:
    b = GraphDefHelper(graph_def)
    # Prints node that produces given output.
    print b.output_of['conv/foo/bar']
  """

  def __init__(self, gd):
    # Index every node definition in the graph by its name.
    self.output_of = {node.name: node for node in gd.node}
|
||||
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
_NodeEntry = collections.namedtuple('NodeEntry', field_names=['order', 'node'])
|
||||
|
||||
|
||||
def _get_computed_nodes(g, output, seen):
  """Traverses the graph in topological order.

  Recursively computes, for `output`, one plus the maximum order of any of its
  inputs; names not produced by any node in the graph (e.g. graph inputs) get
  order 0. Results are memoized in `seen`, which is also the function's
  output channel for callers that want every visited node.

  NOTE(review): recursion depth grows with the longest dependency chain in
  the graph; very deep graphs could hit Python's recursion limit — confirm
  acceptable for target models.

  Args:
    g: GraphDefHelper object.
    output: current node.
    seen: map of nodes we've already traversed.
  Returns:
    order in topological sort for 'output'.
  """
  # Memoized: this node (or external input) was already assigned an order.
  if output in seen:
    return seen[output].order
  node_def = g.output_of.get(output, None)
  if node_def is None:
    # Name is not produced by any node in this graph; treat as order 0.
    seen[output] = _NodeEntry(0, None)
    return 0

  r = 0
  for each in node_def.input:
    # Parses name of input node.
    # '^name' denotes a control-dependency input; strip the marker.
    if each.startswith('^'):
      each = each[1:]
    # Drop the output-index suffix (e.g. ':1') to get the producing node name.
    each = each.split(':')[0]
    # Recursively computes ordering.
    new_v = _get_computed_nodes(g, each, seen)
    r = max(r, new_v + 1)

  seen[output] = _NodeEntry(r, node_def)

  return seen[output].order
|
||||
|
||||
|
||||
def get_compute_order(graph_def):
  """Computes order of computation for a given graph.

  Args:
    graph_def: GraphDef object.
  Returns:
    map: name -> _NodeEntry {order, node}, where `order` is the node's depth
      in topological order and `node` is its NodeDef (None for names not
      produced by any node in the graph).
  """
  helper = GraphDefHelper(graph_def)
  # A plain dict is the memoization table. The previous
  # collections.defaultdict(_NodeEntry) was misleading and latently broken:
  # the factory would raise TypeError if ever invoked (the namedtuple
  # requires its two field arguments), and _get_computed_nodes always
  # populates entries explicitly, so the default factory was never used.
  seen = {}
  for node in graph_def.node:
    _get_computed_nodes(helper, node.name, seen)
  return seen
|
@ -0,0 +1,485 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Functions to compute receptive field of a fully-convolutional network.
|
||||
|
||||
Please refer to the following g3doc for detailed explanation on how this
|
||||
computation is performed, and why it is important:
|
||||
g3doc/photos/vision/features/delf/g3doc/rf_computation.md
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
from tensorflow.contrib.receptive_field.python.util import graph_compute_order
|
||||
from tensorflow.contrib.util import make_ndarray
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
# White-listed layer operations, which do not affect the receptive field
# computation.
_UNCHANGED_RF_LAYER_OPS = [
    "Softplus", "Relu", "BiasAdd", "Mul", "Add", "Const", "Identity",
    "VariableV2", "Sub", "Rsqrt", "ConcatV2"
]

# Different ways in which padding modes may be spelled.
# Both str and bytes spellings are listed because proto string attrs may be
# surfaced as bytes depending on the Python version.
_VALID_PADDING = ["VALID", b"VALID"]
_SAME_PADDING = ["SAME", b"SAME"]
|
||||
|
||||
|
||||
def _stride_size(node):
  """Computes stride size given a TF node.

  Args:
    node: Tensorflow node (NodeDef proto).

  Returns:
    stride_x: Stride size for horizontal direction (integer).
    stride_y: Stride size for vertical direction (integer).
  """
  strides = node.attr["strides"]
  logging.vlog(4, "strides_attr = %s", strides)
  # Index 1 of the strides list holds the vertical (y) stride and index 2
  # the horizontal (x) stride.
  return strides.list.i[2], strides.list.i[1]
|
||||
|
||||
|
||||
def _conv_kernel_size(node, name_to_order_node):
  """Computes kernel size given a TF convolution or pooling node.

  Args:
    node: Tensorflow node (NodeDef proto).
    name_to_order_node: Map from name to {order, node}. Output of
      graph_compute_order.get_compute_order().

  Returns:
    kernel_size_x: Kernel size for horizontal direction (integer).
    kernel_size_y: Kernel size for vertical direction (integer).

  Raises:
    ValueError: If the weight layer node is invalid.
  """
  # The second input of a conv node is expected to be the "<var>/read" op of
  # the weight variable; strip the suffix to find the variable itself.
  weights_read_name = node.input[1]
  if not weights_read_name.endswith("/read"):
    raise ValueError(
        "Weight layer's name input to conv layer does not end with '/read'")
  weights_var_name = weights_read_name[:-5]
  weights_node = name_to_order_node[weights_var_name].node
  if weights_node.op != "VariableV2":
    raise ValueError("Weight layer is not of type VariableV2")
  shape = weights_node.attr["shape"]
  logging.vlog(4, "weight shape = %s", shape)
  # Dim 0 of the weight shape is the vertical kernel extent, dim 1 the
  # horizontal one.
  return shape.shape.dim[1].size, shape.shape.dim[0].size
|
||||
|
||||
|
||||
def _padding_size_conv_pool(node, kernel_size, stride):
  """Computes padding size given a TF convolution or pooling node.

  Args:
    node: Tensorflow node (NodeDef proto).
    kernel_size: Kernel size of node (integer).
    stride: Stride size of node (integer).

  Returns:
    padding: Padding size (integer), or None if it depends on the input size.

  Raises:
    ValueError: If padding is invalid.
  """
  # We must carefully consider the different TF padding modes. The padding
  # depends on kernel size, and may depend on input size. If it depends on
  # input size, None is returned (and a warning is logged).
  padding_attr = node.attr["padding"]
  logging.vlog(4, "padding_attr = %s", padding_attr)
  mode = padding_attr.s
  if mode in _VALID_PADDING:
    return 0
  if mode not in _SAME_PADDING:
    raise ValueError("Invalid padding operation %s" % mode)
  # SAME padding from here on.
  if kernel_size == 1:
    return 0
  if stride == 1 or (stride == 2 and kernel_size % 2 == 0):
    # In these configurations the SAME padding is a fixed function of the
    # kernel size, independent of the input dimensions.
    return int(math.floor((float(kernel_size) - 1) / 2))
  logging.warning(
      "Padding depends on input size, which means that the effective "
      "padding may be different depending on the input image "
      "dimensionality. In this case, alignment check will be skipped.")
  return None
|
||||
|
||||
|
||||
def _pool_kernel_size(node):
|
||||
"""Computes kernel size given a TF pooling node.
|
||||
|
||||
Args:
|
||||
node: Tensorflow node (NodeDef proto).
|
||||
|
||||
Returns:
|
||||
kernel_size_x: Kernel size for horizontal direction (integer).
|
||||
kernel_size_y: Kernel size for vertical direction (integer).
|
||||
|
||||
Raises:
|
||||
ValueError: If pooling is invalid.
|
||||
"""
|
||||
ksize = node.attr["ksize"]
|
||||
kernel_size_y = ksize.list.i[1]
|
||||
kernel_size_x = ksize.list.i[2]
|
||||
if ksize.list.i[0] != 1:
|
||||
raise ValueError("pool ksize for first dim is not 1")
|
||||
if ksize.list.i[3] != 1:
|
||||
raise ValueError("pool ksize for last dim is not 1")
|
||||
return kernel_size_x, kernel_size_y
|
||||
|
||||
|
||||
def _padding_size_pad_layer(node, name_to_order_node):
  """Computes padding size given a TF padding node.

  Args:
    node: Tensorflow node (NodeDef proto).
    name_to_order_node: Map from name to {order, node}. Output of
      graph_compute_order.get_compute_order().

  Returns:
    padding_x: Padding size for horizontal direction (integer).
    padding_y: Padding size for vertical direction (integer).

  Raises:
    ValueError: If padding layer is invalid.
  """
  # The second input of a Pad op is expected to be a Const node named
  # "<...>/paddings" holding the per-dimension padding amounts.
  paddings_name = node.input[1]
  if not paddings_name.endswith("/paddings"):
    raise ValueError("Padding layer name does not end with '/paddings'")
  paddings_node = name_to_order_node[paddings_name].node
  if paddings_node.op != "Const":
    raise ValueError("Padding op is not Const")
  paddings = make_ndarray(paddings_node.attr["value"].tensor)
  # Only the "before" padding (column 0) is read; rows 1 and 2 are the
  # spatial dimensions, rows 0 and 3 (batch/depth) must be unpadded.
  if paddings[0][0] != 0:
    raise ValueError("padding is not zero for first tensor dim")
  if paddings[3][0] != 0:
    raise ValueError("padding is not zero for last tensor dim")
  return paddings[2][0], paddings[1][0]
|
||||
|
||||
|
||||
def _get_layer_params(node, name_to_order_node):
  """Gets layer parameters relevant for RF computation.

  Currently, only these nodes are supported:
  - Conv2D
  - DepthwiseConv2dNative
  - Pad
  - MaxPool
  - AvgPool
  - all nodes listed in _UNCHANGED_RF_LAYER_OPS

  Args:
    node: Tensorflow node (NodeDef proto).
    name_to_order_node: Map from name to {order, node}. Output of
      graph_compute_order.get_compute_order().

  Returns:
    kernel_size_x: Kernel size for horizontal direction (integer).
    kernel_size_y: Kernel size for vertical direction (integer).
    stride_x: Stride size for horizontal direction (integer).
    stride_y: Stride size for vertical direction (integer).
    padding_x: Padding size for horizontal direction (integer).
    padding_y: Padding size for vertical direction (integer).

  Raises:
    ValueError: If layer op is unknown.
  """
  logging.vlog(3, "node.op = %s", node.op)
  logging.vlog(4, "node = %s", node)
  if node.op == "Conv2D" or node.op == "DepthwiseConv2dNative":
    # Convolutions: kernel comes from the weight variable's shape.
    stride_x, stride_y = _stride_size(node)
    kernel_size_x, kernel_size_y = _conv_kernel_size(node, name_to_order_node)
    # Compute the padding for this node separately for each direction.
    padding_x = _padding_size_conv_pool(node, kernel_size_x, stride_x)
    padding_y = _padding_size_conv_pool(node, kernel_size_y, stride_y)
  elif node.op == "Pad":
    # Kernel and stride are simply 1 in this case.
    kernel_size_x = 1
    kernel_size_y = 1
    stride_x = 1
    stride_y = 1
    # Padding comes from the Pad op's constant input.
    padding_x, padding_y = _padding_size_pad_layer(node, name_to_order_node)
  elif node.op == "MaxPool" or node.op == "AvgPool":
    # Pooling: kernel comes from the node's "ksize" attribute.
    stride_x, stride_y = _stride_size(node)
    kernel_size_x, kernel_size_y = _pool_kernel_size(node)
    # Compute the padding for this node separately for each direction.
    padding_x = _padding_size_conv_pool(node, kernel_size_x, stride_x)
    padding_y = _padding_size_conv_pool(node, kernel_size_y, stride_y)
  elif node.op in _UNCHANGED_RF_LAYER_OPS:
    # These nodes do not modify the RF parameters.
    kernel_size_x = 1
    kernel_size_y = 1
    stride_x = 1
    stride_y = 1
    padding_x = 0
    padding_y = 0
  else:
    raise ValueError("Unknown layer op: %s" % node.op)
  return kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y
|
||||
|
||||
|
||||
def _reverse_sort_by_order(name_to_order_node):
|
||||
"""Sorts map of name_to_order_node nodes in reverse order.
|
||||
|
||||
The output is such that the nodes in name_to_order_node are sorted in
|
||||
descending order of the "order" field.
|
||||
|
||||
Args:
|
||||
name_to_order_node: Map from name to {order, node}. Output of
|
||||
graph_compute_order.get_compute_order().
|
||||
|
||||
Returns:
|
||||
sorted_name_to_order_node: Sorted version of the input, in descending order.
|
||||
"""
|
||||
return sorted(name_to_order_node.items(), key=lambda x: -x[1].order)
|
||||
|
||||
|
||||
def _get_rf_size_node_input(stride, kernel_size, rf_size_output):
|
||||
"""Computes RF size at the input of a given layer.
|
||||
|
||||
Args:
|
||||
stride: Stride of given layer (integer).
|
||||
kernel_size: Kernel size of given layer (integer).
|
||||
rf_size_output: RF size at output of given layer (integer).
|
||||
|
||||
Returns:
|
||||
rf_size_input: RF size at input of given layer (integer).
|
||||
"""
|
||||
return stride * rf_size_output + kernel_size - stride
|
||||
|
||||
|
||||
def _get_effective_stride_node_input(stride, effective_stride_output):
|
||||
"""Computes effective stride at the input of a given layer.
|
||||
|
||||
Args:
|
||||
stride: Stride of given layer (integer).
|
||||
effective_stride_output: Effective stride at output of given layer
|
||||
(integer).
|
||||
|
||||
Returns:
|
||||
effective_stride_input: Effective stride at input of given layer
|
||||
(integer).
|
||||
"""
|
||||
return stride * effective_stride_output
|
||||
|
||||
|
||||
def _get_effective_padding_node_input(stride, padding,
|
||||
effective_padding_output):
|
||||
"""Computes effective padding at the input of a given layer.
|
||||
|
||||
Args:
|
||||
stride: Stride of given layer (integer).
|
||||
padding: Padding of given layer (integer).
|
||||
effective_padding_output: Effective padding at output of given layer
|
||||
(integer).
|
||||
|
||||
Returns:
|
||||
effective_padding_input: Effective padding at input of given layer
|
||||
(integer).
|
||||
"""
|
||||
return stride * effective_padding_output + padding
|
||||
|
||||
|
||||
def compute_receptive_field_from_graph_def(graph_def, input_node, output_node):
  """Computes receptive field (RF) parameters from a GraphDef object.

  Args:
    graph_def: GraphDef object.
    input_node: Name of the input node from graph.
    output_node: Name of the output node from graph.

  Returns:
    rf_size_x: Receptive field size of network in the horizontal direction, with
      respect to specified input and output.
    rf_size_y: Receptive field size of network in the vertical direction, with
      respect to specified input and output.
    effective_stride_x: Effective stride of network in the horizontal direction,
      with respect to specified input and output.
    effective_stride_y: Effective stride of network in the vertical direction,
      with respect to specified input and output.
    effective_padding_x: Effective padding of network in the horizontal
      direction, with respect to specified input and output.
    effective_padding_y: Effective padding of network in the vertical
      direction, with respect to specified input and output.

  Raises:
    ValueError: If network is not aligned or if either input or output nodes
      cannot be found. For network criterion alignment, see
      photos/vision/features/delf/g3doc/rf_computation.md
  """
  # Computes order of computation for a given graph.
  name_to_order_node = graph_compute_order.get_compute_order(
      graph_def=graph_def)

  # Sort in reverse topological order.
  order = _reverse_sort_by_order(name_to_order_node)

  # Dictionaries to keep track of receptive field, effective stride and
  # effective padding of different nodes.
  rf_sizes_x = {}
  rf_sizes_y = {}
  effective_strides_x = {}
  effective_strides_y = {}
  effective_paddings_x = {}
  effective_paddings_y = {}

  # Initialize dicts for output_node: the RF of the output w.r.t. itself is a
  # single pixel with unit stride and no padding.
  rf_sizes_x[output_node] = 1
  rf_sizes_y[output_node] = 1
  effective_strides_x[output_node] = 1
  effective_strides_y[output_node] = 1
  effective_paddings_x[output_node] = 0
  effective_paddings_y[output_node] = 0

  # Flag to denote if we found output node yet. If we have not, we skip nodes
  # until the output node is found.
  found_output_node = False

  # Flag to denote if padding is undefined. This happens when SAME padding mode
  # is used in conjunction with stride and kernel sizes which make it such that
  # the padding to be applied would depend on the input size. In this case,
  # alignment checks are skipped, and the effective padding is None.
  undefined_padding = False

  # Walk the graph from the output back towards the input, propagating RF
  # parameters to each visited node's inputs.
  # Each entry of `order` is (name, entry); the entry unpacks as
  # (order, node) — presumably _NodeEntry is a 2-tuple; see
  # graph_compute_order.
  for _, (o, node) in order:
    if node:
      logging.vlog(3, "%10d %-100s %-20s" % (o, node.name[:90], node.op))
    else:
      continue

    # When we find input node, we can stop.
    if node.name == input_node:
      break

    # Loop until we find the output node. All nodes before finding the output
    # one are irrelevant, so they can be skipped.
    if not found_output_node:
      if node.name == output_node:
        found_output_node = True

    if found_output_node:
      if node.name not in rf_sizes_x:
        assert node.name not in rf_sizes_y, ("Node %s is in rf_sizes_y, but "
                                             "not in rf_sizes_x" % node.name)
        # In this case, node is not relevant since it's not part of the
        # computation we're interested in.
        logging.vlog(3, "Irrelevant node %s, skipping it...", node.name)
        continue

      # Get params for this layer.
      kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y = (
          _get_layer_params(node, name_to_order_node))
      logging.vlog(3, "kernel_size_x = %s, kernel_size_y = %s, "
                   "stride_x = %s, stride_y = %s, "
                   "padding_x = %s, padding_y = %s" %
                   (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x,
                    padding_y))
      # Once any layer has input-dependent padding, all upstream paddings
      # become undefined (the flag is sticky for the rest of the walk).
      if padding_x is None or padding_y is None:
        undefined_padding = True

      # Get parameters at input of this layer which may or may not be propagated
      # to the input layers.
      rf_size_input_x = _get_rf_size_node_input(stride_x, kernel_size_x,
                                                rf_sizes_x[node.name])
      rf_size_input_y = _get_rf_size_node_input(stride_y, kernel_size_y,
                                                rf_sizes_y[node.name])
      effective_stride_input_x = _get_effective_stride_node_input(
          stride_x, effective_strides_x[node.name])
      effective_stride_input_y = _get_effective_stride_node_input(
          stride_y, effective_strides_y[node.name])
      if not undefined_padding:
        effective_padding_input_x = _get_effective_padding_node_input(
            stride_x, padding_x, effective_paddings_x[node.name])
        effective_padding_input_y = _get_effective_padding_node_input(
            stride_y, padding_y, effective_paddings_y[node.name])
      else:
        effective_padding_input_x = None
        effective_padding_input_y = None

      # Loop over this node's inputs and potentially propagate information down.
      for inp_name in node.input:
        logging.vlog(4, "inp_name = %s", inp_name)
        inp_node = name_to_order_node[inp_name].node
        logging.vlog(4, "inp_node = \n%s", inp_node)
        if inp_node.name in rf_sizes_x:
          assert inp_node.name in rf_sizes_y, (
              "Node %s is in rf_sizes_x, but "
              "not in rf_sizes_y" % inp_node.name)
          # This node was already discovered through a previous path, so we need
          # to make sure that graph is aligned. This alignment check is skipped
          # if the padding is not defined, since in this case alignment cannot
          # be checked.
          if not undefined_padding:
            if effective_strides_x[inp_node.name] != effective_stride_input_x:
              raise ValueError(
                  "Graph is not aligned since effective stride from different "
                  "paths is different in horizontal direction")
            if effective_strides_y[inp_node.name] != effective_stride_input_y:
              raise ValueError(
                  "Graph is not aligned since effective stride from different "
                  "paths is different in vertical direction")
            # Alignment requires the RF center shift ((rf - 1) / 2 - padding)
            # to agree across all paths reaching this node.
            if (rf_sizes_x[inp_node.name] - 1
               ) / 2 - effective_paddings_x[inp_node.name] != (
                   rf_size_input_x - 1) / 2 - effective_padding_input_x:
              raise ValueError(
                  "Graph is not aligned since center shift from different "
                  "paths is different in horizontal direction")
            if (rf_sizes_y[inp_node.name] - 1
               ) / 2 - effective_paddings_y[inp_node.name] != (
                   rf_size_input_y - 1) / 2 - effective_padding_input_y:
              raise ValueError(
                  "Graph is not aligned since center shift from different "
                  "paths is different in vertical direction")
          # Keep track of path with largest RF, for both directions.
          if rf_sizes_x[inp_node.name] < rf_size_input_x:
            rf_sizes_x[inp_node.name] = rf_size_input_x
            effective_strides_x[inp_node.name] = effective_stride_input_x
            effective_paddings_x[inp_node.name] = effective_padding_input_x
          if rf_sizes_y[inp_node.name] < rf_size_input_y:
            rf_sizes_y[inp_node.name] = rf_size_input_y
            effective_strides_y[inp_node.name] = effective_stride_input_y
            effective_paddings_y[inp_node.name] = effective_padding_input_y
        else:
          assert inp_node.name not in rf_sizes_y, (
              "Node %s is in rf_sizes_y, but "
              "not in rf_sizes_x" % inp_node.name)
          # In this case, it is the first time we encounter this node. So we
          # propagate the RF parameters.
          rf_sizes_x[inp_node.name] = rf_size_input_x
          rf_sizes_y[inp_node.name] = rf_size_input_y
          effective_strides_x[inp_node.name] = effective_stride_input_x
          effective_strides_y[inp_node.name] = effective_stride_input_y
          effective_paddings_x[inp_node.name] = effective_padding_input_x
          effective_paddings_y[inp_node.name] = effective_padding_input_y

  if not found_output_node:
    raise ValueError("Output node was not found")
  if input_node not in rf_sizes_x:
    raise ValueError("Input node was not found")
  return (rf_sizes_x[input_node], rf_sizes_y[input_node],
          effective_strides_x[input_node], effective_strides_y[input_node],
          effective_paddings_x[input_node], effective_paddings_y[input_node])
|
@ -0,0 +1,225 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for receptive_fields module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib import slim
|
||||
from tensorflow.contrib.receptive_field.python.util import receptive_field
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def create_test_network_1():
  """Aligned network for test.

  The graph corresponds to the example from the second figure in
  go/cnn-rf-computation#arbitrary-computation-graphs

  Returns:
    g: Tensorflow graph object (Graph proto).
  """
  g = ops.Graph()
  with g.as_default():
    # An 8x8 test image.
    x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
    # Left branch.
    l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID')
    # Right branch: the explicit pad before the conv keeps this branch's RF
    # centered so both branches stay aligned at the addition.
    l2_pad = array_ops.pad(x, [[0, 0], [1, 0], [1, 0], [0, 0]])
    l2 = slim.conv2d(l2_pad, 1, [3, 3], stride=2, scope='L2', padding='VALID')
    l3 = slim.conv2d(l2, 1, [1, 1], stride=2, scope='L3', padding='VALID')
    # Addition.
    nn.relu(l1 + l3, name='output')
  return g
|
||||
|
||||
|
||||
def create_test_network_2():
  """Aligned network for test.

  The graph corresponds to a variation to the example from the second figure in
  go/cnn-rf-computation#arbitrary-computation-graphs. Layers 2 and 3 are changed
  to max-pooling operations. Since the functionality is the same as convolution,
  the network is aligned and the receptive field size is the same as from the
  network created using create_test_network_1().

  Returns:
    g: Tensorflow graph object (Graph proto).
  """
  g = ops.Graph()
  with g.as_default():
    # An 8x8 test image.
    x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
    # Left branch.
    l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID')
    # Right branch: same structure as network 1, but with pooling instead of
    # convolution (identical RF geometry).
    l2_pad = array_ops.pad(x, [[0, 0], [1, 0], [1, 0], [0, 0]])
    l2 = slim.max_pool2d(l2_pad, [3, 3], stride=2, scope='L2', padding='VALID')
    l3 = slim.max_pool2d(l2, [1, 1], stride=2, scope='L3', padding='VALID')
    # Addition.
    nn.relu(l1 + l3, name='output')
  return g
|
||||
|
||||
|
||||
def create_test_network_3():
  """Misaligned network for test.

  The graph corresponds to the example from the first figure in
  go/cnn-rf-computation#arbitrary-computation-graphs

  Returns:
    g: Tensorflow graph object (Graph proto).
  """
  g = ops.Graph()
  with g.as_default():
    # An 8x8 test image.
    x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
    # Left branch: asymmetric padding makes the two branches misaligned at the
    # addition, so RF computation is expected to raise.
    l1_pad = array_ops.pad(x, [[0, 0], [2, 1], [2, 1], [0, 0]])
    l1 = slim.conv2d(l1_pad, 1, [5, 5], stride=2, scope='L1', padding='VALID')
    # Right branch.
    l2 = slim.conv2d(x, 1, [3, 3], stride=1, scope='L2', padding='VALID')
    l3 = slim.conv2d(l2, 1, [3, 3], stride=1, scope='L3', padding='VALID')
    # Addition.
    nn.relu(l1 + l3, name='output')
  return g
|
||||
|
||||
|
||||
def create_test_network_4():
  """Misaligned network for test.

  The graph corresponds to a variation from the example from the second figure
  in go/cnn-rf-computation#arbitrary-computation-graphs. Layer 2 uses 'SAME'
  padding, which makes its padding dependent on the input image dimensionality.
  In this case, the effective padding will be undetermined, and the utility is
  not able to check the network alignment.

  Returns:
    g: Tensorflow graph object (Graph proto).
  """
  g = ops.Graph()
  with g.as_default():
    # An 8x8 test image.
    x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
    # Left branch.
    l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID')
    # Right branch: SAME padding with stride 2 and odd kernel makes the
    # effective padding input-dependent (None), skipping alignment checks.
    l2 = slim.conv2d(x, 1, [3, 3], stride=2, scope='L2', padding='SAME')
    l3 = slim.conv2d(l2, 1, [1, 1], stride=2, scope='L3', padding='VALID')
    # Addition.
    nn.relu(l1 + l3, name='output')
  return g
|
||||
|
||||
|
||||
def create_test_network_5():
  """Single-path network for testing non-square kernels.

  The graph is similar to the right branch of the graph from
  create_test_network_1(), except that the kernel sizes are changed to be
  non-square.

  Returns:
    g: Tensorflow graph object (Graph proto).
  """
  g = ops.Graph()
  with g.as_default():
    # An 8x8 test image.
    x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
    # Two convolutional layers, where the first one has non-square kernel.
    l1 = slim.conv2d(x, 1, [3, 5], stride=2, scope='L1', padding='VALID')
    l2 = slim.conv2d(l1, 1, [3, 1], stride=2, scope='L2', padding='VALID')
    # ReLU.
    nn.relu(l2, name='output')
  return g
|
||||
|
||||
|
||||
class RfUtilsTest(test.TestCase):
  # Each test builds a small graph, extracts its GraphDef and checks the six
  # receptive-field parameters returned by
  # receptive_field.compute_receptive_field_from_graph_def.

  def testComputeRFFromGraphDefAligned(self):
    # Aligned two-branch network: expect RF 3x3, stride 4, padding 1.
    graph_def = create_test_network_1().as_graph_def()
    input_node = 'input_image'
    output_node = 'output'
    (receptive_field_x, receptive_field_y, effective_stride_x,
     effective_stride_y, effective_padding_x, effective_padding_y) = (
         receptive_field.compute_receptive_field_from_graph_def(
             graph_def, input_node, output_node))
    self.assertEqual(receptive_field_x, 3)
    self.assertEqual(receptive_field_y, 3)
    self.assertEqual(effective_stride_x, 4)
    self.assertEqual(effective_stride_y, 4)
    self.assertEqual(effective_padding_x, 1)
    self.assertEqual(effective_padding_y, 1)

  def testComputeRFFromGraphDefAligned2(self):
    # Pooling variant of network 1; RF geometry must be identical.
    graph_def = create_test_network_2().as_graph_def()
    input_node = 'input_image'
    output_node = 'output'
    (receptive_field_x, receptive_field_y, effective_stride_x,
     effective_stride_y, effective_padding_x, effective_padding_y) = (
         receptive_field.compute_receptive_field_from_graph_def(
             graph_def, input_node, output_node))
    self.assertEqual(receptive_field_x, 3)
    self.assertEqual(receptive_field_y, 3)
    self.assertEqual(effective_stride_x, 4)
    self.assertEqual(effective_stride_y, 4)
    self.assertEqual(effective_padding_x, 1)
    self.assertEqual(effective_padding_y, 1)

  def testComputeRFFromGraphDefUnaligned(self):
    # Misaligned branches must be detected and raise ValueError.
    graph_def = create_test_network_3().as_graph_def()
    input_node = 'input_image'
    output_node = 'output'
    with self.assertRaises(ValueError):
      receptive_field.compute_receptive_field_from_graph_def(
          graph_def, input_node, output_node)

  def testComputeRFFromGraphDefUnaligned2(self):
    # SAME padding makes effective padding input-dependent: paddings are None
    # and the alignment check is skipped (no exception).
    graph_def = create_test_network_4().as_graph_def()
    input_node = 'input_image'
    output_node = 'output'
    (receptive_field_x, receptive_field_y, effective_stride_x,
     effective_stride_y, effective_padding_x, effective_padding_y) = (
         receptive_field.compute_receptive_field_from_graph_def(
             graph_def, input_node, output_node))
    self.assertEqual(receptive_field_x, 3)
    self.assertEqual(receptive_field_y, 3)
    self.assertEqual(effective_stride_x, 4)
    self.assertEqual(effective_stride_y, 4)
    self.assertEqual(effective_padding_x, None)
    self.assertEqual(effective_padding_y, None)

  def testComputeRFFromGraphDefNonSquareRF(self):
    # Non-square kernels must yield different horizontal/vertical RF sizes.
    graph_def = create_test_network_5().as_graph_def()
    input_node = 'input_image'
    output_node = 'output'
    (receptive_field_x, receptive_field_y, effective_stride_x,
     effective_stride_y, effective_padding_x, effective_padding_y) = (
         receptive_field.compute_receptive_field_from_graph_def(
             graph_def, input_node, output_node))
    self.assertEqual(receptive_field_x, 5)
    self.assertEqual(receptive_field_y, 7)
    self.assertEqual(effective_stride_x, 4)
    self.assertEqual(effective_stride_y, 4)
    self.assertEqual(effective_padding_x, 0)
    self.assertEqual(effective_padding_y, 0)
|
||||
|
||||
|
||||
# Standard test entry point: runs every test case in this module.
if __name__ == '__main__':
  test.main()
|
59
tensorflow/contrib/summary/BUILD
Normal file
59
tensorflow/contrib/summary/BUILD
Normal file
@ -0,0 +1,59 @@
|
||||
licenses(["notice"])  # Apache 2.0

exports_files([
    "LICENSE",
])

load(
    "//tensorflow:tensorflow.bzl",
    "py_test",
    "tf_gen_op_wrapper_py",
)

# Generates the Python wrappers for the registered summary ops.
tf_gen_op_wrapper_py(
    name = "gen_summary_ops",
    out = "gen_summary_ops.py",
    deps = ["//tensorflow/core:summary_ops_op_lib"],
)

py_test(
    name = "summary_ops_test",
    srcs = ["summary_ops_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":summary_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform",
        "//tensorflow/python:training",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:test",
    ],
)

# Library exposing the contrib summary ops inside TensorFlow.
py_library(
    name = "summary_ops",
    srcs = ["summary_ops.py"],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:internal"],
    deps = [
        ":gen_summary_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:summary_op_util",
        "//tensorflow/python:training",
        "//tensorflow/python/eager:context",
    ],
)

# Source aggregation target consumed by the top-level //tensorflow:all_files.
filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
|
159
tensorflow/contrib/summary/summary_ops.py
Normal file
159
tensorflow/contrib/summary/summary_ops.py
Normal file
@ -0,0 +1,159 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Operations to emit summaries."""
|
||||
|
||||
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.summary import gen_summary_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_op_util
from tensorflow.python.training import training_util
|
||||
|
||||
|
||||
# Name for a collection which is expected to have at most a single boolean
# Tensor. If this tensor is True the summary ops will record summaries.
# The record_summaries_* helpers below mutate this collection in place.
_SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries"
|
||||
|
||||
|
||||
def should_record_summaries():
|
||||
"""Returns boolean Tensor which is true if summaries should be recorded."""
|
||||
should_record_collection = ops.get_collection(_SHOULD_RECORD_SUMMARIES_NAME)
|
||||
if not should_record_collection:
|
||||
return constant_op.constant(False)
|
||||
if len(should_record_collection) != 1:
|
||||
raise ValueError(
|
||||
"More than one tensor specified for whether summaries "
|
||||
"should be recorded: %s" % should_record_collection)
|
||||
return should_record_collection[0]
|
||||
|
||||
|
||||
# TODO(apassos) consider how to handle local step here.
|
||||
def record_summaries_every_n_global_steps(n):
|
||||
"""Sets the should_record_summaries Tensor to true if global_step % n == 0."""
|
||||
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
|
||||
collection_ref[:] = [training_util.get_global_step() % n == 0]
|
||||
|
||||
|
||||
def always_record_summaries():
|
||||
"""Sets the should_record_summaries Tensor to always true."""
|
||||
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
|
||||
collection_ref[:] = [constant_op.constant(True)]
|
||||
|
||||
|
||||
def never_record_summaries():
|
||||
"""Sets the should_record_summaries Tensor to always false."""
|
||||
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
|
||||
collection_ref[:] = [constant_op.constant(False)]
|
||||
|
||||
|
||||
def create_summary_file_writer(logdir,
|
||||
max_queue=None,
|
||||
flush_secs=None,
|
||||
filename_suffix=None):
|
||||
"""Creates a summary file writer in the current context."""
|
||||
if max_queue is None:
|
||||
max_queue = constant_op.constant(10)
|
||||
if flush_secs is None:
|
||||
flush_secs = constant_op.constant(120)
|
||||
if filename_suffix is None:
|
||||
filename_suffix = constant_op.constant("")
|
||||
resource = gen_summary_ops.summary_writer()
|
||||
gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue,
|
||||
flush_secs, filename_suffix)
|
||||
context.context().summary_writer_resource = resource
|
||||
|
||||
|
||||
def _nothing():
|
||||
"""Convenient else branch for when summaries do not record."""
|
||||
return
|
||||
|
||||
|
||||
def generic(name, tensor, metadata, family=None):
|
||||
"""Writes a tensor summary if possible."""
|
||||
|
||||
def record():
|
||||
with summary_op_util.summary_scope(
|
||||
name, family, values=[tensor]) as (tag, scope):
|
||||
gen_summary_ops.write_summary(context.context().summary_writer_resource,
|
||||
training_util.get_global_step(), tensor,
|
||||
tag, metadata, name=scope)
|
||||
return control_flow_ops.cond(should_record_summaries(), record, _nothing)
|
||||
|
||||
|
||||
def scalar(name, tensor, family=None):
|
||||
"""Writes a scalar summary if possible."""
|
||||
|
||||
def record():
|
||||
with summary_op_util.summary_scope(
|
||||
name, family, values=[tensor]) as (tag, scope):
|
||||
gen_summary_ops.write_scalar_summary(
|
||||
context.context().summary_writer_resource,
|
||||
training_util.get_global_step(), tag, tensor, name=scope)
|
||||
|
||||
return control_flow_ops.cond(should_record_summaries(), record, _nothing)
|
||||
|
||||
|
||||
def histogram(name, tensor, family=None):
|
||||
"""Writes a histogram summary if possible."""
|
||||
|
||||
def record():
|
||||
with summary_op_util.summary_scope(
|
||||
name, family, values=[tensor]) as (tag, scope):
|
||||
gen_summary_ops.write_histogram_summary(
|
||||
context.context().summary_writer_resource,
|
||||
training_util.get_global_step(), tag, tensor, name=scope)
|
||||
|
||||
return control_flow_ops.cond(should_record_summaries(), record, _nothing)
|
||||
|
||||
|
||||
def image(name, tensor, bad_color=None, max_images=3, family=None):
|
||||
"""Writes an image summary if possible."""
|
||||
|
||||
def record():
|
||||
if bad_color is None:
|
||||
bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8)
|
||||
with summary_op_util.summary_scope(
|
||||
name, family, values=[tensor]) as (tag, scope):
|
||||
gen_summary_ops.write_image_summary(
|
||||
context.context().summary_writer_resource,
|
||||
training_util.get_global_step(), tag, tensor, bad_color_, max_images,
|
||||
name=scope)
|
||||
|
||||
return control_flow_ops.cond(should_record_summaries(), record, _nothing)
|
||||
|
||||
|
||||
def audio(name, tensor, sample_rate, max_outputs, family=None):
|
||||
"""Writes an audio summary if possible."""
|
||||
|
||||
def record():
|
||||
with summary_op_util.summary_scope(
|
||||
name, family, values=[tensor]) as (tag, scope):
|
||||
gen_summary_ops.write_audio_summary(
|
||||
context.context().summary_writer_resource,
|
||||
training_util.get_global_step(),
|
||||
tag,
|
||||
tensor,
|
||||
sample_rate=sample_rate,
|
||||
max_outputs=max_outputs,
|
||||
name=scope)
|
||||
|
||||
return control_flow_ops.cond(should_record_summaries(), record, _nothing)
|
52
tensorflow/contrib/summary/summary_ops_test.py
Normal file
52
tensorflow/contrib/summary/summary_ops_test.py
Normal file
@ -0,0 +1,52 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tempfile
|
||||
|
||||
from tensorflow.contrib.summary import summary_ops
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
|
||||
class TargetTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testShouldRecordSummary(self):
|
||||
self.assertFalse(summary_ops.should_record_summaries().numpy())
|
||||
summary_ops.always_record_summaries()
|
||||
self.assertTrue(summary_ops.should_record_summaries().numpy())
|
||||
|
||||
def testSummaryOps(self):
|
||||
training_util.get_or_create_global_step()
|
||||
logdir = tempfile.mkdtemp()
|
||||
summary_ops.create_summary_file_writer(logdir, max_queue=0)
|
||||
summary_ops.always_record_summaries()
|
||||
summary_ops.generic('tensor', 1, '')
|
||||
summary_ops.scalar('scalar', 2.0)
|
||||
summary_ops.histogram('histogram', [1.0])
|
||||
summary_ops.image('image', [[[[1.0]]]])
|
||||
summary_ops.audio('audio', [[1.0]], 1.0, 1)
|
||||
# The working condition of the ops is tested in the C++ test so we just
|
||||
# test here that we're calling them correctly.
|
||||
self.assertTrue(gfile.Exists(logdir))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user