Merge pull request #12780 from martinwicke/branch_167401527

Branch 167401527
Shanqing Cai 2017-09-03 18:02:50 -04:00 committed by GitHub
commit 512d3d0868
202 changed files with 12054 additions and 3197 deletions

View File

@ -296,6 +296,7 @@ filegroup(
"//tensorflow/contrib/ffmpeg/default:all_files",
"//tensorflow/contrib/framework:all_files",
"//tensorflow/contrib/fused_conv:all_files",
"//tensorflow/contrib/gan:all_files",
"//tensorflow/contrib/graph_editor:all_files",
"//tensorflow/contrib/grid_rnn:all_files",
"//tensorflow/contrib/hooks:all_files",
@ -323,6 +324,7 @@ filegroup(
"//tensorflow/contrib/nn:all_files",
"//tensorflow/contrib/opt:all_files",
"//tensorflow/contrib/predictor:all_files",
"//tensorflow/contrib/receptive_field:all_files",
"//tensorflow/contrib/reduce_slice_ops:all_files",
"//tensorflow/contrib/remote_fused_graph/pylib:all_files",
"//tensorflow/contrib/resampler:all_files",
@ -342,6 +344,7 @@ filegroup(
"//tensorflow/contrib/staging:all_files",
"//tensorflow/contrib/stat_summarizer:all_files",
"//tensorflow/contrib/stateless:all_files",
"//tensorflow/contrib/summary:all_files",
"//tensorflow/contrib/tensor_forest:all_files",
"//tensorflow/contrib/tensor_forest/hybrid:all_files",
"//tensorflow/contrib/tensor_forest/kernels/v4:all_files",

View File

@ -45,8 +45,13 @@ tf_cuda_library(
tf_cuda_library(
name = "c_api",
srcs = ["c_api.cc"],
hdrs = ["c_api.h"],
srcs = [
"c_api.cc",
"c_api_function.cc",
],
hdrs = [
"c_api.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = select({
@ -157,6 +162,21 @@ tf_cc_test(
],
)
tf_cc_test(
name = "c_api_function_test",
size = "small",
srcs = ["c_api_function_test.cc"],
deps = [
":c_api",
":c_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cc_test(
name = "while_loop_test",
size = "small",

View File

@ -165,22 +165,6 @@ void deallocate_buffer(void* data, size_t len, void* arg) {
tensorflow::cpu_allocator()->DeallocateRaw(data);
}
Status MessageToBuffer(const tensorflow::protobuf::Message& in,
TF_Buffer* out) {
if (out->data != nullptr) {
return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
}
const auto proto_size = in.ByteSizeLong();
void* buf = tensorflow::port::Malloc(proto_size);
in.SerializeToArray(buf, proto_size);
out->data = buf;
out->length = proto_size;
out->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
return Status::OK();
}
} // namespace
TF_Tensor::~TF_Tensor() { buffer->Unref(); }
@ -559,6 +543,27 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
dimvec.size(), base, size, DeleteArray, base);
}
Status MessageToBuffer(const tensorflow::protobuf::Message& in,
TF_Buffer* out) {
if (out->data != nullptr) {
return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
}
const size_t proto_size = in.ByteSizeLong();
void* buf = tensorflow::port::Malloc(proto_size);
if (buf == nullptr) {
return tensorflow::errors::ResourceExhausted(
"Failed to allocate memory to serialize message of type '",
in.GetTypeName(), "' and size ", proto_size);
}
in.SerializeToArray(buf, proto_size);
out->data = buf;
out->length = proto_size;
out->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
return Status::OK();
}
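A minimal sketch of the caller-side contract (illustrative, not part of this diff; `some_proto` stands in for any protobuf message an internal caller wants to serialize):
// Illustrative use of MessageToBuffer from internal C API code.
TF_Buffer* buf = TF_NewBuffer();  // starts empty: data == nullptr
tensorflow::Status s = tensorflow::MessageToBuffer(some_proto, buf);
if (s.ok()) {
  // buf->data now owns a serialized copy of the message, and
  // buf->data_deallocator knows how to free it, so TF_DeleteBuffer
  // below releases everything.
}
TF_DeleteBuffer(buf);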
// Helpers for loading a TensorFlow plugin (a .so file).
Status LoadLibrary(const char* library_filename, void** result,
const void** buf, size_t* len);

View File

@ -357,6 +357,14 @@ typedef struct TF_Output {
int index; // The index of the output within oper.
} TF_Output;
// TF_Function is a grouping of operations with defined inputs and outputs.
// Once created and added to graphs, functions can be invoked by creating an
// operation whose operation type matches the function name.
typedef struct TF_Function TF_Function;
// Function definition options. TODO(iga): Define and implement
typedef struct TF_FunctionOptions TF_FunctionOptions;
// Sets the shape of the Tensor referenced by `output` in `graph` to
// the shape described by `dims` and `num_dims`.
//
@ -914,6 +922,15 @@ TF_CAPI_EXPORT extern void TF_GraphImportGraphDef(
TF_Graph* graph, const TF_Buffer* graph_def,
const TF_ImportGraphDefOptions* options, TF_Status* status);
// Add `function` to graph `g`. Once `function` is added to `g`,
// it can be called by creating an operation using the function's name.
//
// If successful, status is set to OK and `function` is added to `g`.
// Otherwise, status is set to the encountered error and `g` is unmodified.
TF_CAPI_EXPORT extern void TF_GraphAddFunction(TF_Graph* g,
const TF_Function* function,
TF_Status* status);
// Note: The following function may fail on very large protos in the future.
TF_CAPI_EXPORT extern void TF_OperationToNodeDef(TF_Operation* oper,
@ -1001,6 +1018,105 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny,
TF_Output* x, int nx, TF_Output* dx,
TF_Status* status, TF_Output* dy);
// Create a TF_Function from a TF_Graph
//
// Params:
// fn_body - the graph whose operations (or a subset of them) will be
// converted to TF_Function.
// fn_name - the name of the new TF_Function. Should match the operation
// name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]* and be distinct
// from other operation names (at least those registered in graphs
// where this function will be used).
// TODO(iga): Allow null in here and have C API come up with
// a unique name with high probability (similarly to
// _create_hash_str in function.py)
// num_opers - the number of elements in the `opers` array, or the special
// value of -1 meaning that no array is given.
// The distinction between an empty array of operations and no
// array of operations is necessary to distinguish the case of
// creating a function with no body (e.g. identity or permutation)
// and the case of creating a function whose body contains all
// the nodes in the graph (except for the automatic skipping, see
// below).
// opers - Array of operations to become the body of the function or null.
// - If no array is given (`num_opers` = -1), all the
// operations in `fn_body` will become part of the function
// except operations referenced in `inputs`. These operations
// must each have a single output (they are typically placeholders
// created for the sole purpose of representing an input; we can
// relax this constraint if there are compelling use cases).
// - If an array is given (`num_opers` >= 0), all operations
// in it will become part of the function. In particular, no
// automatic skipping of dummy input operations is performed.
// ninputs - number of elements in `inputs` array
// inputs - array of TF_Outputs that specify the inputs to the function.
// If `ninputs` is zero (the function takes no inputs), `inputs`
// can be null. The names used for function inputs are normalized
// names of the operations (usually placeholders) pointed to by
// `inputs`. These operation names should start with a letter.
// Normalization will convert all letters to lowercase and
// non-alphanumeric characters to '_' to make resulting names match
// the "[a-z][a-z0-9_]*" pattern for operation argument names.
// `inputs` cannot contain the same tensor twice.
// noutputs - number of elements in `outputs` array
// outputs - array of TF_Outputs that specify the outputs of the function.
// If `noutputs` is zero (the function returns no outputs), `outputs`
// can be null. `outputs` can contain the same tensor more than once.
// output_names - The names of the function's outputs. `output_names` array
// must either have the same length as `outputs`
// (i.e. `noutputs`) or be null. In the former case,
// the names should match the regular expression for ArgDef
// names - "[a-z][a-z0-9_]*". In the latter case,
// names for outputs will be generated automatically.
// opts - various options for the function, e.g. XLA's inlining control.
// status - Set to OK on success and an appropriate error on failure.
//
// Note that when the same TF_Output is listed as both an input and an output,
// the corresponding function output will be equal to this input,
// rather than to the original node's output.
//
// Callers must also satisfy the following constraints:
// - `inputs` cannot refer to TF_Outputs within a control flow context. For
// example, one cannot use the output of a "switch" node as an input.
// - No TF_Output of a function (inside any of `inputs`, `outputs`, `fn_body`)
// is allowed to have a reference type. Reference types are not exposed
// through C API and are being deprecated.
// - Every node in the function's body must have all of its inputs (including
// control inputs). In other words, for every node in the body, each input
// must be either listed in `inputs` or must come from another node in
// the body. In particular, it is an error to have a control edge going from
// a node outside of the body into a node in the body. This applies to control
// edges going from nodes referenced in `inputs` to nodes in the body when
// the former nodes are not in the body (either automatically skipped or
// not included in the explicitly specified body).
//
// Returns:
// On success, a newly created TF_Function instance. It must be deleted by
// calling TF_DeleteFunction.
//
// On failure, null.
//
// TODO(iga): Add input_names argument and get output_names working (they are
// currently ignored)
TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction(
const TF_Graph* fn_body, const char* fn_name, int num_opers,
const TF_Operation* const* opers, int ninputs, const TF_Output* inputs,
int noutputs, const TF_Output* outputs, const char* const* output_names,
const TF_FunctionOptions* opts, TF_Status* status);
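A minimal usage sketch for TF_GraphToFunction and TF_GraphAddFunction (illustrative, not part of this change; it assumes `graph` already holds a single-output placeholder `in_op` feeding an operation `out_op`, and that `other_graph` is the graph where the function will be called):
TF_Status* status = TF_NewStatus();
TF_Output func_inputs[] = {{in_op, 0}};
TF_Output func_outputs[] = {{out_op, 0}};
// num_opers = -1: use the whole graph as the body; `in_op` is skipped
// automatically because it is referenced in `func_inputs`.
TF_Function* func = TF_GraphToFunction(
    graph, "my_func", /*num_opers=*/-1, /*opers=*/nullptr,
    /*ninputs=*/1, func_inputs, /*noutputs=*/1, func_outputs,
    /*output_names=*/nullptr, /*opts=*/nullptr, status);
if (TF_GetCode(status) == TF_OK) {
  // The function can now be invoked from `other_graph` by creating an
  // operation whose op type is "my_func".
  TF_GraphAddFunction(other_graph, func, status);
}
TF_DeleteFunction(func);
TF_DeleteStatus(status);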
// Write out a serialized representation of `func` (as a FunctionDef protocol
// message) to `output_func_def` (allocated by TF_NewBuffer()).
// `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer()
// is called.
//
// May fail on very large graphs in the future.
TF_CAPI_EXPORT extern void TF_FunctionToFunctionDef(TF_Function* func,
TF_Buffer* output_func_def,
TF_Status* status);
TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function*);
// TODO(josh11b): Register OpDef, available to all operations added
// to this graph.

View File

@ -0,0 +1,496 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/c_api_internal.h"
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace {
// Class that maintains a one-to-one original node name -> new node name
// mapping. We normalize the names used as input and output arguments to match
// regexp "[a-z][a-z0-9_]*" specified in definition of ArgDef.name.
// Once we rename them, we risk creating a name collision with the other
// node names, so if necessary we add a suffix to make
// names unique. If we have an input named "A" and a node in the function
// body named "a", they will be renamed to "a" and "a_0".
class NodeNameMapping {
public:
NodeNameMapping() = default;
// Normalize the input/output name and make it unique.
string GetIOName(const string& name);
// Make the node name unique.
string Uniquify(const string& name);
// Look up how a node name was previously normalized/uniquified.
// Returns the empty string if the name was never seen.
string Lookup(const string& name) const;
private:
string UniquifyHelper(const string& name) const;
static string Normalize(string name);
// The normalized/uniquified names already used as
// input names (in signature), output names (in signature), and node names
// (in node_def).
// This is a superset of values in name_mapping_.
std::unordered_set<string> used_names_;
// Mapping from the original node name in the graph to its normalized
// and uniquified version.
std::unordered_map<string, string> name_mapping_;
};
string NodeNameMapping::Normalize(string name) {
// Convert letters to lowercase and non-alphanumeric characters to '_'.
if (name.empty()) return "unknown";
const int n = name.size();
for (int i = 0; i < n; ++i) {
char c = name[i];
if (isalnum(c)) {
if (isupper(c)) {
name[i] = tolower(c);
}
} else {
name[i] = '_';
}
}
// Find the first letter and start with it.
int i = 0;
for (; i < n; ++i) {
if (isalpha(name[i])) break;
}
// Return "unknown" if none of the name's chars were letters.
return i == n ? "unknown" : name.substr(i);
}
string NodeNameMapping::UniquifyHelper(const string& name) const {
// If the name hasn't been used yet, use it as-is.
if (used_names_.find(name) == used_names_.end()) return name;
// Add a suffix to name to make it unique.
for (int i = 0;; ++i) {
const string candidate = strings::StrCat(name, "_", i);
if (used_names_.find(candidate) == used_names_.end()) return candidate;
}
}
string NodeNameMapping::GetIOName(const string& name) {
const string& input_name = UniquifyHelper(Normalize(name));
// Record that we used this name, but don't add it to name_mapping_
// since this name is not for a node.
used_names_.insert(input_name);
return input_name;
}
string NodeNameMapping::Uniquify(const string& name) {
const string uniqued = UniquifyHelper(name);
name_mapping_[name] = uniqued;
used_names_.insert(uniqued);
return uniqued;
}
string NodeNameMapping::Lookup(const string& name) const {
const auto iter = name_mapping_.find(name);
if (iter == name_mapping_.end()) return string();
return iter->second;
}
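The intended behavior, restated as an illustrative sequence (derived from the class comment above, not from the source):
NodeNameMapping m;
m.GetIOName("A");    // -> "a": normalized and reserved, but not in name_mapping_
m.Uniquify("a");     // -> "a_0": "a" is already taken by the input name
m.Uniquify("B");     // -> "B": node names are uniquified but not normalized
m.Lookup("a");       // -> "a_0"
m.Lookup("unseen");  // -> "": empty string for a name that was never mapped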
Status ValidateNoRefOutputs(const Node* node) {
for (int i = 0; i < node->num_outputs(); ++i) {
const DataType& dt = node->output_type(i);
if (IsRefType(dt)) {
return errors::InvalidArgument("Output ", i, " of node '", node->name(),
"' has a reference "
"type ",
DataTypeString(dt));
}
}
return Status::OK();
}
Status FillFunctionBody(
const string& fn_name, const NodeNameMapping& node_names,
const std::vector<const Node*>& body_nodes,
const std::unordered_map<string, string>& tensor_renaming,
FunctionDef* fdef) {
std::vector<const Edge*> in_edges;
std::vector<const Edge*> control_edges;
for (const Node* node : body_nodes) {
NodeDef* node_def = fdef->add_node_def();
// First, copy the node_def as is. We will patch it next.
*node_def = node->def();
if (!node->assigned_device_name().empty()) {
node_def->set_device(node->assigned_device_name());
}
node_def->set_name(node_names.Lookup(node->name()));
// Input names must be set based on nested names in tensor_renaming.
// Clear the flat input names we got from the original node_def
// from the graph.
node_def->clear_input();
// Collect regular and control inputs. Regular inputs are indexed
// by the index at which they come into the `node`. Control inputs
// don't follow any order.
in_edges.clear();
in_edges.resize(node->num_inputs(), nullptr);
control_edges.clear();
for (const Edge* edge : node->in_edges()) {
if (edge->src()->IsSource()) continue;
if (edge->IsControlEdge()) {
control_edges.push_back(edge);
} else {
in_edges[edge->dst_input()] = edge;
}
}
// Add regular inputs.
for (size_t i = 0; i < in_edges.size(); ++i) {
const Edge* edge = in_edges[i];
string original_input_name;
if (edge == nullptr) {
// A backedge might not appear as a regular Edge, but may be present
// only in the node_def. Such edges are referred to as requested_inputs().
if (i >= node->requested_inputs().size()) {
return errors::InvalidArgument(
"Graph to be converted to function appears to be malformed. ",
"Node ", node->name(), " is missing input edge ", i);
}
original_input_name =
ParseTensorName(node->requested_inputs()[i]).ToString();
} else {
original_input_name =
strings::StrCat(edge->src()->name(), ":", edge->src_output());
}
const auto iter = tensor_renaming.find(original_input_name);
if (iter == tensor_renaming.end()) {
return errors::InvalidArgument(
"Input ", i, ", '", original_input_name, "', of node '",
node->name(), "' in function '", fn_name,
"' is not available. You might need to include it in inputs "
"or include its source node in the body");
}
node_def->add_input(iter->second);
}
// Add control inputs.
for (const Edge* edge : control_edges) {
// Add this control input only if the src node is in the body.
const string normalized = node_names.Lookup(edge->src()->name());
// If we did not find a name for the source of the control edge, the
// source must be outside of the body. Raise an error.
if (normalized.empty()) {
return errors::InvalidArgument(
"The source of control edge ", edge->DebugString(),
" is not in the body. Encountered while creating function '",
fn_name, "'");
}
node_def->add_input(strings::StrCat("^", normalized));
}
}
return Status::OK();
}
// Graph to FunctionDef conversion. This code is closely modeled on the Python
// code in third_party/tensorflow/python/framework/function.py.
Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
const std::vector<const Node*>& body_nodes,
const std::vector<OutputTensor>& inputs,
const std::vector<OutputTensor>& outputs,
const std::vector<string>& output_names,
FunctionDef* fdef) {
fdef->mutable_signature()->set_name(fn_name);
// Keep track of names we used and how we normalized them.
NodeNameMapping node_names;
// Mapping from original names of tensors (i.e. "<node_name>:<idx>") to the
// name we used in the function:
// - For input tensors:
// {flat_tensor_name -> normalized_name_of_src_node}
// e.g. {In:3 -> in}
// - For tensors produced by nodes in function's body:
// {flat_tensor_name -> nested_tensor_name}
// e.g. {Add:3 -> add_0:z:1}
std::unordered_map<string, string> tensor_renaming;
// Fill inputs in function's signature.
for (size_t i = 0; i < inputs.size(); ++i) {
const Node* node = inputs[i].node;
int idx = inputs[i].index;
OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg();
argdef->set_type(node->output_type(idx));
const string& input_name = node_names.GetIOName(node->name());
argdef->set_name(input_name);
tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
}
// Fill outputs in function's signature.
for (size_t i = 0; i < outputs.size(); ++i) {
const Node* node = outputs[i].node;
int idx = outputs[i].index;
OpDef::ArgDef* argdef = fdef->mutable_signature()->add_output_arg();
argdef->set_type(node->output_type(idx));
argdef->set_name(node_names.GetIOName(node->name()));
}
// Populate tensor_renaming and node_names.
// Generate the new output names for every node in the function.
// The NodeDefs in FunctionDefs use a different naming scheme for
// their inputs than the NodeDefs in a graph (see the comment for
// FunctionDef.node_def in function.proto). We do the
// graph tensor name -> function tensor name conversion for every
// possible input (i.e. every node's outputs) and store the result
// in tensor_renaming.
for (const Node* node : body_nodes) {
// Make sure node_name does not collide with an input or output name.
const string& node_name = node_names.Uniquify(node->name());
// For each output_arg in the op_def, the output_ranges
// map will have the [start, end) range of indices that this arg produces
// among all the output tensors of this op.
NameRangeMap output_ranges;
TF_RETURN_IF_ERROR(
NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges));
for (const auto& output : output_ranges) {
const string& output_name = output.first;
int index_start = output.second.first;
int index_end = output.second.second;
for (int i = index_start; i < index_end; ++i) {
const string& original_name = strings::StrCat(node->name(), ":", i);
const string& new_name =
strings::StrCat(node_name, ":", output_name, ":", i - index_start);
// Record the mapping if this tensor is not already mapped.
// Tensor can be already mapped if it is used as an input.
if (tensor_renaming.find(original_name) == tensor_renaming.end()) {
tensor_renaming[original_name] = new_name;
}
}
}
}
TF_RETURN_IF_ERROR(
FillFunctionBody(fn_name, node_names, body_nodes, tensor_renaming, fdef));
// Remap return values.
for (int r = 0; r < fdef->signature().output_arg_size(); ++r) {
const string& ret_name = fdef->signature().output_arg(r).name();
// We convert this flat tensor name to the nested value
// (e.g. `add:z:1`) that we stored in tensor_renaming.
const string& return_value =
strings::StrCat(outputs[r].node->name(), ":", outputs[r].index);
const auto iter = tensor_renaming.find(return_value);
if (iter == tensor_renaming.end()) {
return errors::InvalidArgument(
"TF_Output ", return_value, " is neither in the function body ",
"nor among function inputs. Encountered while creating function '",
fn_name, "'");
}
(*fdef->mutable_ret())[ret_name] = iter->second;
}
return Status::OK();
}
// Converts `ninputs` and `inputs` into `input_tensors` and `input_nodes` and
// does various checks while doing so. `input_nodes` will contain the same
// information as `input_tensors`, just in a different structure that makes
// the subsequent processing easier. TODO(iga): Simplify this nested structure.
Status ProcessInputs(
const TF_Graph* fn_body, const char* fn_name, int ninputs,
const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
std::unordered_map<const Node*, std::vector<int>>* input_nodes)
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
input_tensors->reserve(ninputs);
for (int i = 0; i < ninputs; ++i) {
const Node& node = inputs[i].oper->node;
int idx = inputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
fn_body->graph.IsValidOutputTensor(&node, idx),
"Encountered while processing input ", i, " into function '", fn_name,
"'");
TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(&node),
"Encountered while processing input ", i,
" into function '", fn_name, "'");
input_tensors->emplace_back(&node, idx);
const auto& iter = input_nodes->find(&node);
if (iter == input_nodes->end()) {
input_nodes->insert({&node, {idx}});
} else {
auto& indices = iter->second;
if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
return errors::InvalidArgument(
"TF_Output ", node.name(), ":", idx,
" appears more than once in the input list");
}
indices.push_back(idx);
}
}
return Status::OK();
}
// Converts `noutputs` and `outputs` into `output_tensors` and does various
// checks while doing so.
Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
int noutputs, const TF_Output* outputs,
std::vector<OutputTensor>* output_tensors)
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
output_tensors->reserve(noutputs);
for (int i = 0; i < noutputs; ++i) {
const Node& node = outputs[i].oper->node;
int idx = outputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
fn_body->graph.IsValidOutputTensor(&node, idx),
"Encountered while processing output ", i, " from function '", fn_name,
"'");
output_tensors->emplace_back(&node, idx);
}
return Status::OK();
}
// Populates `body_nodes` with the nodes that will become the function's body.
// Performs various checks.
Status ComputeBodyNodes(
const TF_Graph* fn_body, const char* fn_name, int num_opers,
const TF_Operation* const* opers,
const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
std::vector<const Node*>* body_nodes)
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
if (num_opers == -1) {
for (const Node* node : fn_body->graph.op_nodes()) {
const auto& iter = input_nodes.find(node);
if (iter == input_nodes.end()) {
// This node is not referenced in inputs. Add it to the body.
TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(node),
"Encountered while creating function '",
fn_name, "'");
body_nodes->push_back(node);
} else {
// This node is referenced in inputs. Currently, we place an
// artificial restriction and require that when num_opers=-1, such
// nodes must have a single output.
if (node->num_outputs() != 1) {
return errors::InvalidArgument(
"When `num_opers` is set to -1, nodes referenced in `inputs` "
"must have a single output. Node ",
node->name(), " has ", node->num_outputs(),
" outputs. Encountered while creating function '", fn_name, "'");
}
}
}
} else {
body_nodes->reserve(num_opers);
for (int i = 0; i < num_opers; ++i) {
const Node* node = &opers[i]->node;
TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(node),
"Encountered while creating function '",
fn_name, "'");
body_nodes->push_back(node);
}
}
return Status::OK();
}
} // anonymous namespace
} // namespace tensorflow
using tensorflow::Node;
using tensorflow::string;
TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
int num_opers, const TF_Operation* const* opers,
int ninputs, const TF_Output* inputs,
int noutputs, const TF_Output* outputs,
const char* const* output_names,
const TF_FunctionOptions* opts,
TF_Status* status) {
tensorflow::mutex_lock l(*const_cast<tensorflow::mutex*>(&fn_body->mu));
// Process inputs.
std::vector<tensorflow::OutputTensor> input_tensors;
std::unordered_map<const Node*, std::vector<int>> input_nodes;
status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs,
&input_tensors, &input_nodes);
if (!status->status.ok()) return nullptr;
// Process outputs.
std::vector<tensorflow::OutputTensor> output_tensors;
status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs,
outputs, &output_tensors);
if (!status->status.ok()) return nullptr;
// Process output names.
std::vector<string> output_names_vec;
if (output_names) {
output_names_vec.reserve(noutputs);
for (int i = 0; i < noutputs; ++i) {
output_names_vec.push_back(string(output_names[i]));
}
}
// Compute body nodes.
std::vector<const Node*> body_nodes;
status->status = tensorflow::ComputeBodyNodes(
fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes);
if (!status->status.ok()) return nullptr;
// Do the actual function creation.
TF_Function* tf_function = new TF_Function();
status->status = tensorflow::GraphToFunctionDef(
fn_body->graph, fn_name, body_nodes, input_tensors, output_tensors,
output_names_vec, tf_function->fdef_lib.add_function());
if (!status->status.ok()) {
TF_DeleteFunction(tf_function);
return nullptr;
}
return tf_function;
}
void TF_GraphAddFunction(TF_Graph* g, const TF_Function* function,
TF_Status* status) {
tensorflow::mutex_lock l(g->mu);
// At the moment, we have only one function and no gradients in fdef_lib.
// This makes the following operation atomic.
// TODO(iga): Add an atomic version of AddFunctionLibrary when we support
// gradients.
status->status = g->graph.AddFunctionLibrary(function->fdef_lib);
}
void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def,
TF_Status* status) {
DCHECK_EQ(1, func->fdef_lib.function_size());
status->status = MessageToBuffer(func->fdef_lib.function(0), output_func_def);
}
void TF_DeleteFunction(TF_Function* function) { delete function; }

File diff suppressed because it is too large.

View File

@ -130,6 +130,11 @@ struct TF_DeviceList {
std::vector<tensorflow::DeviceAttributes> response;
};
struct TF_Function {
// Currently contains a single function and no gradients
tensorflow::FunctionDefLibrary fdef_lib;
};
namespace tensorflow {
class TensorCApi {
@ -142,6 +147,9 @@ class TensorCApi {
};
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out);
} // end namespace tensorflow
#endif // TENSORFLOW_C_C_API_INTERNAL_H_

View File

@ -829,7 +829,7 @@ TEST(CAPI, ShapeInferenceError) {
TF_Operation* vec3 = Const(vec3_tensor.get(), graph, status, "vec3");
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Operation* add = Add(vec2, vec3, graph, status);
TF_Operation* add = AddNoCheck(vec2, vec3, graph, status);
ASSERT_NE(TF_OK, TF_GetCode(status));
ASSERT_TRUE(add == nullptr);

View File

@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
using tensorflow::GraphDef;
@ -36,6 +38,23 @@ TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
return t;
}
TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims,
const int32_t* values) {
int64_t num_values = 1;
for (int i = 0; i < num_dims; ++i) {
num_values *= dims[i];
}
TF_Tensor* t =
TF_AllocateTensor(TF_INT32, dims, num_dims, sizeof(int32_t) * num_values);
memcpy(TF_TensorData(t), values, sizeof(int32_t) * num_values);
return t;
}
TF_Tensor* Int32Tensor(const std::vector<int32_t>& values) {
int64_t dims = values.size();
return Int32Tensor(&dims, 1, values.data());
}
TF_Tensor* Int32Tensor(int32_t v) {
const int num_bytes = sizeof(int32_t);
int32_t* values = new int32_t[1];
@ -44,19 +63,40 @@ TF_Tensor* Int32Tensor(int32_t v) {
&Int32Deallocator, nullptr);
}
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) {
// All the *Helper methods are a workaround for the restriction that
// ASSERT_* macros cannot be used in non-void-returning functions (when
// exceptions are disabled during compilation).
void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
TF_SetAttrType(desc, "dtype", TF_INT32);
return TF_FinishOperation(desc, s);
*op = TF_FinishOperation(desc, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) {
TF_Operation* op;
PlaceholderHelper(graph, s, name, &op);
return op;
}
void ConstHelper(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name,
TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
TF_SetAttrTensor(desc, "value", t, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_SetAttrType(desc, "dtype", TF_TensorType(t));
*op = TF_FinishOperation(desc, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
const char* name) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
TF_SetAttrTensor(desc, "value", t, s);
if (TF_GetCode(s) != TF_OK) return nullptr;
TF_SetAttrType(desc, "dtype", TF_TensorType(t));
return TF_FinishOperation(desc, s);
TF_Operation* op;
ConstHelper(t, graph, s, name, &op);
return op;
}
TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
@ -65,11 +105,39 @@ TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
return Const(tensor.get(), graph, s, name);
}
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name) {
void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s,
const char* name, TF_Operation** op, bool check) {
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
TF_AddInputList(desc, add_inputs, 2);
*op = TF_FinishOperation(desc, s);
if (check) {
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
}
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name) {
TF_Operation* op;
AddHelper(l, r, graph, s, name, &op, true);
return op;
}
TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name) {
TF_Operation* op;
AddHelper(l, r, graph, s, name, &op, false);
return op;
}
TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r,
TF_Graph* graph, TF_Operation* ctrl_op,
TF_Status* s, const char* name) {
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
TF_AddInputList(desc, add_inputs, 2);
TF_AddControlInput(desc, ctrl_op);
return TF_FinishOperation(desc, s);
}
@ -81,11 +149,20 @@ TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
return TF_FinishOperation(desc, s);
}
TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) {
void NegHelper(TF_Operation* n, TF_Graph* graph, TF_Status* s,
TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", "neg");
TF_Output neg_input = {n, 0};
TF_AddInput(desc, neg_input);
return TF_FinishOperation(desc, s);
*op = TF_FinishOperation(desc, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) {
TF_Operation* op;
NegHelper(n, graph, s, &op);
return op;
}
TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
@ -96,6 +173,32 @@ TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
return TF_FinishOperation(desc, s);
}
void Split3Helper(TF_Operation* input, TF_Graph* graph, TF_Status* s,
const char* name, TF_Operation** op) {
TF_Operation* zero = ScalarConst(
0, graph, s, ::tensorflow::strings::StrCat(name, "_const0").c_str());
TF_OperationDescription* desc = TF_NewOperation(graph, "Split", name);
TF_AddInput(desc, {zero, 0});
TF_AddInput(desc, {input, 0});
TF_SetAttrInt(desc, "num_split", 3);
TF_SetAttrType(desc, "T", TF_INT32);
// Set device to CPU because there is no version of Split for int32 on GPU.
// TODO(iga): Convert all these helpers and tests to use floats because
// they are usually available on GPUs. After doing this, remove the
// TF_SetDevice call in c_api_function_test.cc.
TF_SetDevice(desc, "/cpu:0");
*op = TF_FinishOperation(desc, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s,
const char* name) {
TF_Operation* op;
Split3Helper(input, graph, s, name, &op);
return op;
}
bool IsPlaceholder(const tensorflow::NodeDef& node_def) {
if (node_def.op() != "Placeholder" || node_def.name() != "feed") {
return false;
@ -196,6 +299,18 @@ bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def) {
return ret;
}
bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def) {
TF_Status* s = TF_NewStatus();
TF_Buffer* buffer = TF_NewBuffer();
TF_FunctionToFunctionDef(func, buffer, s);
bool ret = TF_GetCode(s) == TF_OK;
EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
if (ret) ret = func_def->ParseFromArray(buffer->data, buffer->length);
TF_DeleteBuffer(buffer);
TF_DeleteStatus(s);
return ret;
}
bool GetAttrValue(TF_Operation* oper, const char* attr_name,
tensorflow::AttrValue* attr_value, TF_Status* s) {
TF_Buffer* buffer = TF_NewBuffer();

View File

@ -33,6 +33,13 @@ typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)>
// Create a tensor with values of type TF_INT8 provided by `values`.
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values);
// Create a tensor with values of type TF_INT32 provided by `values`.
TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims,
const int32_t* values);
// Create a 1-dimensional tensor with values from `values`.
TF_Tensor* Int32Tensor(const std::vector<int32_t>& values);
TF_Tensor* Int32Tensor(int32_t v);
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
@ -47,6 +54,13 @@ TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name = "add");
TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name = "add");
TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r,
TF_Graph* graph, TF_Operation* ctrl_op,
TF_Status* s, const char* name = "add");
TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
const char* name = "add");
@ -54,6 +68,10 @@ TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s);
TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s);
// Split `input` along the first dimension into 3 tensors.
TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s,
const char* name = "split3");
bool IsPlaceholder(const tensorflow::NodeDef& node_def);
bool IsScalarConst(const tensorflow::NodeDef& node_def, int v);
@ -66,6 +84,8 @@ bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def);
bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def);
bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def);
bool GetAttrValue(TF_Operation* oper, const char* attr_name,
tensorflow::AttrValue* attr_value, TF_Status* s);

View File

@ -687,6 +687,72 @@ Status MeanGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Mean", MeanGrad);
Status MinOrMaxGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
// The partial derivative for any input along a "reduced" dimension
// is 1 when it is the min (or max) and 0 everywhere else. So the
// gradient calculation is identical for both operators.
//
// There's a special case for propagating gradients when there are
// multiple minima (or maxima) - we choose to divide the gradient
// equally among all matching inputs.
//
// See this comment
// https://github.com/tensorflow/tensorflow/issues/4886#issuecomment-256836063
// for details.
// Running example:
// input: [[5, 5, 5],
// [1, 2, -3]]
// reduction_indices: [1]
auto input = op.input(0);
auto reduction_indices = op.input(1);
// [2, 3]
auto input_shape = Shape(scope, input);
// [2, 1]
auto output_shape_kept_dims =
ReducedShapeHelper(scope, input_shape, reduction_indices);
// for op=min (say)
// output = [5, -3]
// y = [[5],
// [-3]]
auto y = Reshape(scope, op.output(0), output_shape_kept_dims);
// reshape([g1, g2], [2, 1]) = [[g1],
// [g2]]
auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);
// indicators = equal(y, input)
// = equal([[5], [[5, 5, 5],
// [-3]], [1, 2, -3]])
// = [[1, 1, 1],
// [0, 0, 1]]
auto indicators = Cast(scope, Equal(scope, y, input), grad_inputs[0].type());
// [[3],
// [1]]
auto num_selected = Reshape(scope, Sum(scope, indicators, reduction_indices),
output_shape_kept_dims);
// [[1/3, 1/3, 1/3],
// [0, 0, 1]]
auto scale = Div(scope, indicators, num_selected);
// [[g1/3, g1/3, g1/3],
// [0, 0, g2]]
grad_outputs->push_back(Mul(scope, scale, grad));
// Stop propagation along reduction_indices
grad_outputs->push_back(NoGradient());
return scope.status();
}
REGISTER_GRADIENT_OP("Min", MinOrMaxGrad);
REGISTER_GRADIENT_OP("Max", MinOrMaxGrad);
// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,

View File

@ -955,6 +955,55 @@ TEST_F(NaryGradTest, Mean) {
RunTest({x}, {x_shape}, {y}, {y_shape});
}
TEST_F(NaryGradTest, Min) {
TensorShape x_shape({2, 3});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
auto y = Min(scope_, x, {-1});
// y's shape is the result of reducing x along axis -1 (= 1)
TensorShape y_shape({2});
Tensor x_init_value =
test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
RunTest(x, x_init_value, y, y_shape);
}
TEST_F(NaryGradTest, Max) {
TensorShape x_shape({2, 3});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
auto y = Max(scope_, x, {-1});
// y's shape is the result of reducing x along axis -1 (= 1)
TensorShape y_shape({2});
Tensor x_init_value =
test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
RunTest(x, x_init_value, y, y_shape);
}
TEST_F(NaryGradTest, MinMulti) {
// Test gradient when there are multiple minima.
// Note that we cannot directly use a test Tensor with multiple
// minima, as the numeric estimator will calculate incorrect
// gradients when perturbing each entry in the Tensor (which then
// changes how many minima exist).
// Instead, we use a single input that broadcast-multiplies a larger
// tensor with equal values, and apply reduce_min to the multiplied
// result.
TensorShape x_shape({1});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
auto y = Min(scope_, all_same, {0});
// y is a [3] shaped tensor reduced along dimension 0, so it is [1] shaped
TensorShape y_shape({1});
RunTest({x}, {x_shape}, {y}, {y_shape});
}
TEST_F(NaryGradTest, MaxMulti) {
TensorShape x_shape({1});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
auto y = Max(scope_, all_same, {0});
TensorShape y_shape({1});
RunTest({x}, {x_shape}, {y}, {y_shape});
}
TEST_F(NaryGradTest, AddN) {
TensorShape shape({3, 2, 5});
std::vector<Output> xs;

View File

@ -52,6 +52,12 @@ class BinaryOpsTest(XLATestCase):
def testFloatOps(self):
for dtype in self.float_types:
self._testBinary(
lambda x, y: math_ops.approximate_equal(x, y, tolerance=0.0001),
np.array([[[[-1, 2.00009999], [-3, 4.01]]]], dtype=dtype),
np.array([[[[-1.001, 2], [-3.00009, 4]]]], dtype=dtype),
expected=np.array([[[[False, True], [True, False]]]], dtype=dtype))
self._testBinary(
gen_math_ops._real_div,
np.array([3, 3, -1.5, -8, 44], dtype=dtype),
@ -82,6 +88,12 @@ class BinaryOpsTest(XLATestCase):
dtype(4),
expected=np.array([[16], [81]], dtype=dtype))
self._testBinary(
gen_math_ops._reciprocal_grad,
np.array([4, -3, -2, 1], dtype=dtype),
np.array([5, -6, 7, -8], dtype=dtype),
expected=np.array([-80, 54, -28, 8], dtype=dtype))
self._testBinary(
gen_math_ops._sigmoid_grad,
np.array([4, 3, 2, 1], dtype=dtype),
@ -107,6 +119,13 @@ class BinaryOpsTest(XLATestCase):
expected=np.array(
[3.97322869, 2.99258232, 1.99817801, 0.99966466], dtype=dtype))
self._testBinary(
gen_nn_ops._softsign_grad,
np.array([4, 3, 2, 1], dtype=dtype),
np.array([5, 6, 7, 8], dtype=dtype),
expected=np.array(
[0.11111111, 0.06122449, 0.03125, 0.01234568], dtype=dtype))
self._testBinary(
gen_math_ops._tanh_grad,
np.array([4, 3, 2, 1], dtype=dtype),

View File

@ -888,6 +888,16 @@ TEST_F(OpTest, Any) {
});
}
TEST_F(OpTest, ApproximateEqual) {
Repeatedly([this]() {
auto dims = RandomDims();
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ApproximateEqual")
.RandomInput(DT_FLOAT, dims)
.RandomInput(DT_FLOAT, dims)
.Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, Asinh) {
Repeatedly([this]() {
return ExpectTfAndXlaOutputsAreClose(
@ -1662,11 +1672,9 @@ TEST_F(OpTest, GreaterEqual) {
TEST_F(OpTest, L2Loss) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
// TODO(b/31644876): scalars currently crash.
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("L2Loss")
.RandomInput(type, RandomDims(1))
.Attr("T", type));
DataType type = DT_FLOAT;
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("L2Loss").RandomInput(type).Attr("T", type));
});
}
@ -2165,6 +2173,15 @@ TEST_F(OpTest, Reciprocal) {
});
}
TEST_F(OpTest, ReciprocalGrad) {
Repeatedly([this]() {
std::vector<int64> dims = RandomDims();
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReciprocalGrad")
.RandomInput(DT_FLOAT, dims)
.RandomInput(DT_FLOAT, dims)
.Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, Relu) {
Repeatedly([this]() {
return ExpectTfAndXlaOutputsAreClose(
@ -2250,6 +2267,13 @@ TEST_F(OpTest, ReverseV2) {
});
}
TEST_F(OpTest, Rint) {
Repeatedly([this]() {
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Rint").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, Round) {
Repeatedly([this]() {
return ExpectTfAndXlaOutputsAreClose(
@ -2402,6 +2426,23 @@ TEST_F(OpTest, SoftplusGrad) {
});
}
TEST_F(OpTest, Softsign) {
Repeatedly([this]() {
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Softsign").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, SoftsignGrad) {
Repeatedly([this]() {
std::vector<int64> dims = RandomDims();
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SoftsignGrad")
.RandomInput(DT_FLOAT, dims)
.RandomInput(DT_FLOAT, dims)
.Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, SpaceToBatch) {
Repeatedly([this]() {
std::vector<int64> block_dims = RandomDims(4, 4, 0, 5);

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@ -161,12 +163,17 @@ class UnaryOpsTest(XLATestCase):
np.array([[-1.7, 1.2]], dtype=dtype),
expected=np.array([[-2, 1]], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.is_finite,
np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
dtype=dtype),
expected=np.array([[0, 1, 1, 1, 1, 1, 1, 0, 0]], dtype=np.bool))
# Tests for tf.nn ops.
self._assertOpOutputMatchesExpected(
nn_ops.l2_loss, np.array([[[]]], dtype=dtype), expected=dtype(0))
# TODO(b/31644876): enable this test case when fixed.
# self._assertOpOutputMatchesExpected(tf.nn.l2_loss, dtype(4), dtype(10))
self._assertOpOutputMatchesExpected(nn_ops.l2_loss, dtype(4), dtype(8))
self._assertOpOutputMatchesExpected(
nn_ops.l2_loss, np.array([[-2, 4]], dtype=dtype), expected=dtype(10))
@ -198,6 +205,12 @@ class UnaryOpsTest(XLATestCase):
np.array([[1e-14, 1e-15, 0.6]], dtype=dtype),
expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)))
self._assertOpOutputMatchesExpected(
math_ops.rint,
np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5],
[0.5, 1.5, 2.5, 3.5]], dtype=dtype),
expected=np.array([[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]],
dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.round,
np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5],
@ -301,6 +314,12 @@ class UnaryOpsTest(XLATestCase):
np.array([[-2, 0, 8]], dtype=dtype),
expected=np.array([[0.126928, 0.6931472, 8.0003354]], dtype=dtype))
self._assertOpOutputMatchesExpected(
nn_ops.softsign,
np.array([[-2, -1, 0, 1, 2]], dtype=dtype),
expected=np.array([[-0.66666669, -0.5, 0, 0.5, 0.66666669]],
dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.is_finite,
np.array(
@ -335,6 +354,23 @@ class UnaryOpsTest(XLATestCase):
np.array([[4, 3], [2, 1]], dtype=dtype),
expected=np.array([[1, 1], [1, 1]], dtype=dtype))
# TODO(phawkins): these tests fail unless fastmath optimizations
# are disabled. Use more robust IsInf/IsNaN detection and enable these
# tests.
@unittest.skip("test case fails in fast-math mode")
def testIsInfAndIsNan(self):
for dtype in self.float_types:
self._assertOpOutputMatchesExpected(
math_ops.is_inf,
np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
dtype=dtype),
expected=np.array([[1, 0, 0, 0, 0, 0, 0, 1, 0]], dtype=np.bool))
self._assertOpOutputMatchesExpected(
math_ops.is_nan,
np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
dtype=dtype),
expected=np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.bool))
def testLogicalOps(self):
self._assertOpOutputMatchesExpected(
math_ops.logical_not,

View File

@ -31,7 +31,6 @@ tf_kernel_library(
"function_ops.cc",
"gather_op.cc",
"identity_op.cc",
"is_finite_op.cc",
"l2loss_op.cc",
"lrn_ops.cc",
"matmul_op.cc",

View File

@ -102,6 +102,7 @@ XLA_MAKE_BINARY(Mod, b->Rem(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Maximum, b->Max(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Minimum, b->Min(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(RealDiv, b->Div(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(ReciprocalGrad, b->Neg(b->Mul(rhs, b->Mul(lhs, lhs))));
XLA_MAKE_BINARY(
RsqrtGrad,
b->Mul(b->Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)),
@ -140,6 +141,11 @@ XLA_MAKE_BINARY(SoftplusGrad,
b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)),
XlaHelpers::One(b, input_type(1)))));
// softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2
XLA_MAKE_BINARY(SoftsignGrad,
b->Div(lhs, Square(b, b->Add(XlaHelpers::One(b, input_type(0)),
b->Abs(rhs)))));
XLA_MAKE_BINARY(TanhGrad, b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)),
b->Mul(lhs, lhs))));
@ -147,5 +153,24 @@ XLA_MAKE_BINARY(Pow, b->Pow(lhs, rhs, extend_dimensions));
#undef XLA_MAKE_BINARY
class ApproximateEqualOp : public XlaOpKernel {
public:
explicit ApproximateEqualOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("tolerance", &tolerance_));
}
// Computes elementwise |input0 - input1| < tolerance.
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationBuilder* b = ctx->builder();
auto result = b->Lt(b->Abs(b->Sub(ctx->Input(0), ctx->Input(1))),
XlaHelpers::FloatLiteral(b, input_type(0), tolerance_));
ctx->SetOutput(0, result);
}
private:
float tolerance_;
};
REGISTER_XLA_OP(Name("ApproximateEqual"), ApproximateEqualOp);
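Reference semantics for the op above, as a plain C++ sketch (illustrative and self-contained, not TensorFlow code; it mirrors the Lt/Abs/Sub lowering elementwise over equally sized inputs):
#include <cmath>
#include <cstddef>
#include <vector>

// Elementwise |x - y| < tolerance, matching the XLA lowering above.
std::vector<bool> ApproximateEqualRef(const std::vector<float>& x,
                                      const std::vector<float>& y,
                                      float tolerance) {
  std::vector<bool> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    out[i] = std::fabs(x[i] - y[i]) < tolerance;
  }
  return out;
}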
} // namespace
} // namespace tensorflow

View File

@ -1,43 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
namespace {
class IsFiniteOp : public XlaOpKernel {
public:
explicit IsFiniteOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationDataHandle input = ctx->Input(0);
ctx->SetOutput(0, ctx->builder()->IsFinite(input));
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(IsFiniteOp);
};
REGISTER_XLA_OP(Name("IsFinite"), IsFiniteOp);
} // anonymous namespace
} // namespace tensorflow

View File

@ -73,8 +73,12 @@ XLAJIT_MAKE_UNARY(Exp, b->Exp(x));
XLAJIT_MAKE_UNARY(Expm1, b->Sub(b->Exp(x), XlaHelpers::One(b, input_type(0))));
XLAJIT_MAKE_UNARY(Floor, b->Floor(x));
// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
XLAJIT_MAKE_UNARY(Sign, b->Sign(x));
XLAJIT_MAKE_UNARY(IsFinite, b->IsFinite(x));
XLAJIT_MAKE_UNARY(IsInf, b->Eq(b->Abs(x),
XlaHelpers::FloatLiteral(
b, input_type(0),
std::numeric_limits<double>::infinity())));
XLAJIT_MAKE_UNARY(IsNan, b->Ne(x, x));
// Returns 1/x.
XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x));
XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x));
@ -105,6 +109,12 @@ static xla::ComputationDataHandle Round(xla::ComputationBuilder* b,
b->Add(round_val, one), round_val);
}
XLAJIT_MAKE_UNARY(Rint, Round(b, input_type(0), x));
XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x));
XLAJIT_MAKE_UNARY(Rsqrt,
b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5)));
// Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2.
static xla::ComputationDataHandle Sigmoid(xla::ComputationBuilder* b,
DataType dtype,
@ -112,16 +122,19 @@ static xla::ComputationDataHandle Sigmoid(xla::ComputationBuilder* b,
auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5);
return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x))));
}
XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x));
XLAJIT_MAKE_UNARY(Rsqrt,
b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5)));
XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(b, input_type(0), x));
// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
XLAJIT_MAKE_UNARY(Sign, b->Sign(x));
XLAJIT_MAKE_UNARY(Sinh,
b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))),
XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
XLAJIT_MAKE_UNARY(Softplus,
b->Log(b->Add(b->Exp(x), XlaHelpers::One(b, input_type(0)))));
// softsign(x) = x / (abs(x) + 1)
XLAJIT_MAKE_UNARY(Softsign,
b->Div(x,
b->Add(b->Abs(x), XlaHelpers::One(b, input_type(0)))));
XLAJIT_MAKE_UNARY(Sqrt,
b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
XLAJIT_MAKE_UNARY(Square, b->Mul(x, x));

View File

@ -847,6 +847,7 @@ cc_test(
srcs = ["hlo_ordering_test.cc"],
deps = [
":hlo",
":hlo_dataflow_analysis",
":hlo_ordering",
":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",

View File

@ -241,7 +241,7 @@ Status Executor::Run() {
completion_queue_.pop_front();
break;
}
} while (1);
} while (true);
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
assignment_->GetUniqueTopLevelSlice(instruction));
void* result_buffer =

View File

@ -24,16 +24,14 @@ limitations under the License.
namespace xla {
Status DfsHloVisitor::HandleElementwiseUnary(HloInstruction* hlo,
HloOpcode opcode) {
Status DfsHloVisitor::HandleElementwiseUnary(HloInstruction* hlo) {
return Unimplemented("DfsHloVisitor::HandleElementwiseUnary: %s",
HloOpcodeString(opcode).c_str());
HloOpcodeString(hlo->opcode()).c_str());
}
Status DfsHloVisitor::HandleElementwiseBinary(HloInstruction* hlo,
HloOpcode opcode) {
Status DfsHloVisitor::HandleElementwiseBinary(HloInstruction* hlo) {
return Unimplemented("DfsHloVisitor::HandleElementwiseBinary: %s",
HloOpcodeString(opcode).c_str());
HloOpcodeString(hlo->opcode()).c_str());
}
DfsHloVisitor::VisitState DfsHloVisitor::GetVisitState(

View File

@ -63,37 +63,37 @@ class DfsHloVisitor {
// These routines are self-descriptive; see the class comment for usage
// information.
virtual Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode);
virtual Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode);
virtual Status HandleElementwiseUnary(HloInstruction* hlo);
virtual Status HandleElementwiseBinary(HloInstruction* hlo);
virtual Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
HloInstruction* arg, HloInstruction* max) = 0;
virtual Status HandleSelect(HloInstruction* select, HloInstruction* pred,
HloInstruction* on_true,
HloInstruction* on_false) = 0;
virtual Status HandleMaximum(HloInstruction* maximum) {
return HandleElementwiseBinary(maximum, HloOpcode::kMaximum);
return HandleElementwiseBinary(maximum);
}
virtual Status HandleMinimum(HloInstruction* minimum) {
return HandleElementwiseBinary(minimum, HloOpcode::kMinimum);
return HandleElementwiseBinary(minimum);
}
virtual Status HandleConcatenate(
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) = 0;
virtual Status HandleConvert(HloInstruction* convert) {
return HandleElementwiseUnary(convert, HloOpcode::kConvert);
return HandleElementwiseUnary(convert);
}
virtual Status HandleCopy(HloInstruction* copy) {
return HandleElementwiseUnary(copy, HloOpcode::kCopy);
return HandleElementwiseUnary(copy);
}
virtual Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs,
HloInstruction* rhs) {
return HandleElementwiseBinary(multiply, HloOpcode::kMultiply);
return HandleElementwiseBinary(multiply);
}
virtual Status HandleDot(HloInstruction* dot, HloInstruction* lhs,
HloInstruction* rhs) = 0;
virtual Status HandlePower(HloInstruction* power, HloInstruction* lhs,
HloInstruction* rhs) {
return HandleElementwiseBinary(power, HloOpcode::kPower);
return HandleElementwiseBinary(power);
}
virtual Status HandleConvolution(HloInstruction* convolution,
HloInstruction* lhs, HloInstruction* rhs,
@ -101,73 +101,72 @@ class DfsHloVisitor {
virtual Status HandleCrossReplicaSum(HloInstruction* crs) = 0;
virtual Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
HloInstruction* lhs, HloInstruction* rhs) {
return HandleElementwiseBinary(compare, opcode);
return HandleElementwiseBinary(compare);
}
virtual Status HandleAdd(HloInstruction* add, HloInstruction* lhs,
HloInstruction* rhs) {
return HandleElementwiseBinary(add, HloOpcode::kAdd);
return HandleElementwiseBinary(add);
}
virtual Status HandleDivide(HloInstruction* divide, HloInstruction* lhs,
HloInstruction* rhs) {
return HandleElementwiseBinary(divide, HloOpcode::kDivide);
return HandleElementwiseBinary(divide);
}
virtual Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs,
HloInstruction* rhs) {
return HandleElementwiseBinary(remainder, HloOpcode::kRemainder);
return HandleElementwiseBinary(remainder);
}
virtual Status HandleSubtract(HloInstruction* subtract, HloInstruction* lhs,
HloInstruction* rhs) {
return HandleElementwiseBinary(subtract, HloOpcode::kSubtract);
return HandleElementwiseBinary(subtract);
}
virtual Status HandleAbs(HloInstruction* abs, HloInstruction* operand) {
return HandleElementwiseUnary(abs, HloOpcode::kAbs);
return HandleElementwiseUnary(abs);
}
virtual Status HandleSign(HloInstruction* sign, HloInstruction* operand) {
return HandleElementwiseUnary(sign, HloOpcode::kSign);
return HandleElementwiseUnary(sign);
}
virtual Status HandleNegate(HloInstruction* negate, HloInstruction* operand) {
return HandleElementwiseUnary(negate, HloOpcode::kNegate);
return HandleElementwiseUnary(negate);
}
virtual Status HandleExp(HloInstruction* exp, HloInstruction* operand) {
return HandleElementwiseUnary(exp, HloOpcode::kExp);
return HandleElementwiseUnary(exp);
}
virtual Status HandleFloor(HloInstruction* floor, HloInstruction* operand) {
return HandleElementwiseUnary(floor, HloOpcode::kFloor);
return HandleElementwiseUnary(floor);
}
virtual Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) {
return HandleElementwiseUnary(ceil, HloOpcode::kCeil);
return HandleElementwiseUnary(ceil);
}
virtual Status HandleLog(HloInstruction* log, HloInstruction* operand) {
return HandleElementwiseUnary(log, HloOpcode::kLog);
return HandleElementwiseUnary(log);
}
virtual Status HandleCos(HloInstruction* cos, HloInstruction* operand) {
return HandleElementwiseUnary(cos, HloOpcode::kCos);
return HandleElementwiseUnary(cos);
}
virtual Status HandleSin(HloInstruction* sin, HloInstruction* operand) {
return HandleElementwiseUnary(sin, HloOpcode::kSin);
return HandleElementwiseUnary(sin);
}
virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) {
return HandleElementwiseUnary(tanh, HloOpcode::kTanh);
return HandleElementwiseUnary(tanh);
}
virtual Status HandleIsFinite(HloInstruction* is_finite,
HloInstruction* operand) {
return HandleElementwiseUnary(is_finite, HloOpcode::kIsFinite);
return HandleElementwiseUnary(is_finite);
}
virtual Status HandleLogicalAnd(HloInstruction* logical_and,
HloInstruction* lhs, HloInstruction* rhs) {
return HandleElementwiseBinary(logical_and, HloOpcode::kLogicalAnd);
return HandleElementwiseBinary(logical_and);
}
virtual Status HandleLogicalNot(HloInstruction* logical_not,
HloInstruction* operand) {
return HandleElementwiseUnary(logical_not, HloOpcode::kLogicalNot);
return HandleElementwiseUnary(logical_not);
}
virtual Status HandleLogicalOr(HloInstruction* logical_or,
HloInstruction* lhs, HloInstruction* rhs) {
return HandleElementwiseBinary(logical_or, HloOpcode::kLogicalOr);
return HandleElementwiseBinary(logical_or);
}
virtual Status HandleReducePrecision(HloInstruction* reduce_precision) {
return HandleElementwiseUnary(reduce_precision,
HloOpcode::kReducePrecision);
return HandleElementwiseUnary(reduce_precision);
}
virtual Status HandleInfeed(HloInstruction* infeed) = 0;

View File

@ -41,12 +41,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor {
// Default action performed on HloInstruction.
virtual Status DefaultAction(HloInstruction* hlo_instruction) = 0;
Status HandleElementwiseUnary(HloInstruction* hlo,
HloOpcode opcode) override {
Status HandleElementwiseUnary(HloInstruction* hlo) override {
return DefaultAction(hlo);
}
Status HandleElementwiseBinary(HloInstruction* hlo,
HloOpcode opcode) override {
Status HandleElementwiseBinary(HloInstruction* hlo) override {
return DefaultAction(hlo);
}
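For a caller's-eye view of the new signatures above: a visitor no longer receives the opcode as a parameter and instead queries the instruction itself. A minimal sketch (not part of this commit; the class name and counter are hypothetical):
class OpcodeCountingVisitor : public xla::DfsHloVisitorWithDefault {
 public:
  xla::Status DefaultAction(xla::HloInstruction* hlo) override {
    return xla::Status::OK();
  }
  xla::Status HandleElementwiseUnary(xla::HloInstruction* hlo) override {
    // The opcode is no longer a parameter; recover it from the instruction.
    unary_counts_[hlo->opcode()]++;
    return xla::Status::OK();
  }
 private:
  std::map<xla::HloOpcode, int> unary_counts_;
};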

View File

@ -709,7 +709,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
} else {
auto r = ir_builder_->CreateSub(q, p);
auto leading_zeros = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(1)},
llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(true)},
{param_ir_type}, ir_builder_);
auto in_block = ir_builder_->GetInsertBlock();
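A note on the getInt1 change above: LLVM declares the intrinsic as
  declare iN @llvm.ctlz.iN(iN <src>, i1 <is_zero_undef>)
so the second operand is a boolean flag; getInt1(true) and getInt1(1) produce the same i1 constant, and the new spelling only makes the flag's meaning explicit. The same getInt1(1) -> getInt1(true) cleanup appears again below for the in_bounds flag in the GPU elemental IR emitter.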

View File

@ -334,7 +334,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder_);
IrArray::Index input_index(index.size());
llvm::Value* in_bounds = ir_builder_->getInt1(1);
llvm::Value* in_bounds = ir_builder_->getInt1(true);
for (size_t i = 0; i < index.size(); ++i) {
llvm::Value* stridden_index = ir_builder_->CreateNSWMul(
index[i], ir_builder_->getInt64(window.dimensions(i).stride()));

View File

@ -389,7 +389,7 @@ StatusOr<string> CompileModuleToPtx(llvm::Module* module,
// Loop unrolling exposes more opportunities for SROA. Therefore, we run SROA
// again after the standard optimization passes [http://b/13329423].
// TODO(jingyue): SROA may further expose more optimization opportunities, such
// TODO(jingyue): SROA may further expose more optimization opportunities such
// as more precise alias analysis and more function inlining (SROA may change
// the inlining cost of a function). For now, running SROA already emits good
// enough code for the evaluated benchmarks. We may want to run more

View File

@ -37,6 +37,230 @@ namespace xla {
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
// Data structure used to construct the alias analysis. Thrown away after alias
// analysis is complete. This data structure keeps track of which sets of
// HloValues must be in the same HloBuffer. This is maintained as a map from a
// buffer identifier (BufferNumber) to a set of HloValues.
//
// Initially each value is its own buffer. In MergeAliasedBuffers, sets of
// values which must share the same buffer are merged together. The end result
// is a partitioning of all HloValues into sets where each set needs its own
// HloBuffer. By performing this analysis without constructing HloBuffers on
// the fly, we can construct the vector of contiguously numbered HloBuffers
// after the fact, once the buffer requirements have been determined.
class BufferValueMap {
public:
// A unique identifier for a set of colocated values which must share the same
// buffer. This is not necessarily the same as the id of the HloBuffer which
// will ultimately contain the values. The reason is that HloBuffer::Id's are
// contiguous, while BufferNumbers need not be dense: buffers may be
// created and destroyed during the analysis
// construction process.
using BufferNumber = int64;
explicit BufferValueMap(const HloDataflowAnalysis& dataflow)
: dataflow_(dataflow) {
buffers_.reserve(dataflow_.values().size());
value_to_buffer_number_.reserve(dataflow_.values().size());
for (const HloValue* value : dataflow_.values()) {
BufferNumber buffer_number = next_buffer_number_++;
buffers_[buffer_number].insert(value);
value_to_buffer_number_[value] = buffer_number;
}
}
// Merge together sets of HloValues which must be in the same HloBuffer
// because of aliasing rules (eg, in-place kWhile instruction).
void MergeAliasedBuffers() {
for (const HloValue* value : dataflow_.values()) {
VLOG(3) << "Merging colocated values, value: " << value->ToShortString();
// Gather the set of buffers with aliasing rules (eg, kWhile) which this
// value must be contained in.
std::vector<BufferNumber> aliased_buffers = ComputeAliasedBuffers(*value);
BufferNumber current_buffer = value_to_buffer_number_.at(value);
if (aliased_buffers.empty()) {
// The buffer containing 'value' aliases no other buffers. If the buffer
// containing 'value' already only contains 'value', then no change is
// necessary. If the buffer containing 'value' does contain other
// values, then remove 'value' from the buffer and create a new buffer
// containing only 'value'.
if (buffers_.at(current_buffer).size() == 1) {
CHECK_EQ(*buffers_.at(current_buffer).begin(), value);
} else {
MoveValueToNewBuffer(*value);
}
} else {
// If multiple buffers are aliased merge these buffers together into a
// single buffer (arbitrarily chosen as the first buffer in the vector).
if (aliased_buffers.size() > 1) {
for (int64 i = 1; i < aliased_buffers.size(); ++i) {
MergeBuffers(/*from=*/aliased_buffers[i],
/*to=*/aliased_buffers[0]);
}
}
BufferNumber new_buffer = aliased_buffers[0];
if (current_buffer != new_buffer) {
MoveValueToBuffer(*value, new_buffer);
}
}
}
}
// Compute and return a sorted vector of all BufferNumbers. Can be used to
// iterate through all buffers stably.
std::vector<BufferNumber> ComputeSortedBufferNumbers() const {
std::vector<BufferNumber> buffer_numbers;
for (const auto& pair : buffers_) {
buffer_numbers.push_back(pair.first);
}
std::sort(buffer_numbers.begin(), buffer_numbers.end());
return buffer_numbers;
}
// Return a set of all the values in the given buffer.
const tensorflow::gtl::FlatSet<const HloValue*>& GetValuesInBuffer(
BufferNumber buffer_number) const {
return buffers_.at(buffer_number);
}
private:
// Create a new buffer.
void NewBuffer(const HloValue& value) {
BufferNumber buffer_number = next_buffer_number_++;
buffers_[buffer_number].insert(&value);
value_to_buffer_number_[&value] = buffer_number;
}
// Move the given value into a new buffer containing only the value.
void MoveValueToNewBuffer(const HloValue& value) {
BufferNumber new_buffer_number = next_buffer_number_++;
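// Default-constructing the map entry here creates the new (empty) buffer so
// that MoveValueToBuffer below has a destination set to insert into.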
buffers_[new_buffer_number];
MoveValueToBuffer(value, new_buffer_number);
}
// Move the given value into the given buffer.
void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) {
BufferNumber old_buffer_number = value_to_buffer_number_.at(&value);
buffers_.at(old_buffer_number).erase(&value);
if (buffers_.at(old_buffer_number).empty()) {
buffers_.erase(old_buffer_number);
}
buffers_.at(buffer_number).insert(&value);
value_to_buffer_number_.at(&value) = buffer_number;
}
// Merge the buffer 'from' into the buffer 'to'.
void MergeBuffers(BufferNumber from, BufferNumber to) {
auto& from_value_set = buffers_.at(from);
buffers_.at(to).insert(from_value_set.begin(), from_value_set.end());
// NOTE: using a union-find algorithm to hold the colocated values might be
// faster.
for (const HloValue* value : from_value_set) {
value_to_buffer_number_.at(value) = to;
}
buffers_.erase(from);
}
BufferNumber GetBufferForValue(const HloValue& value) {
return value_to_buffer_number_.at(&value);
}
// Compute and return a vector of buffers that the given value must be
// contained in due to HLO aliasing rules.
std::vector<BufferNumber> ComputeAliasedBuffers(const HloValue& value) {
// Value is init of a while (use is while).
std::vector<BufferNumber> aliased_buffers;
for (const HloUse& use : value.uses()) {
VLOG(1) << "use of value " << value.ToShortString() << ": " << use;
if (use.instruction->opcode() == HloOpcode::kWhile) {
// Determine the while value that this shares a buffer with.
const HloValue& while_value =
dataflow_.GetUniqueValueAt(use.instruction, use.operand_index);
aliased_buffers.push_back(GetBufferForValue(while_value));
VLOG(3) << " value is init value to a while; must share buffer with "
"while value "
<< while_value.ToShortString();
}
}
// Value is a parameter of a while body/condition.
if (value.defining_instruction()->opcode() == HloOpcode::kParameter) {
const HloComputation* computation =
value.defining_instruction()->parent();
const CallGraphNode& call_graph_node =
dataflow_.call_graph().GetNode(computation);
for (const CallSite& callsite : call_graph_node.caller_callsites()) {
if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
// Call graph must have been flattened.
CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
const HloValue& while_value = dataflow_.GetUniqueValueAt(
callsite.instruction(), value.defining_index());
VLOG(3) << " value is parameter value of the body or condition of a "
"while; must share buffer with while value "
<< while_value.ToShortString();
aliased_buffers.push_back(GetBufferForValue(while_value));
}
}
}
// Value is the root of a while body.
for (const HloPosition& position : value.positions()) {
const HloComputation* computation = position.instruction->parent();
const CallGraphNode& call_graph_node =
dataflow_.call_graph().GetNode(computation);
if (position.instruction == computation->root_instruction()) {
for (const CallSite& callsite : call_graph_node.caller_callsites()) {
if (callsite.instruction()->opcode() == HloOpcode::kWhile &&
callsite.instruction()->while_body() == computation) {
// Call graph must have been flattened.
CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
const HloValue& while_value = dataflow_.GetUniqueValueAt(
callsite.instruction(), position.index);
VLOG(3) << " value is root the body computation of a while; must "
"share buffer with while value "
<< while_value.ToShortString();
aliased_buffers.push_back(GetBufferForValue(while_value));
}
}
}
}
// Value is the output of the while instruction itself.
if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
VLOG(3) << " value is output of a while instruction";
aliased_buffers.push_back(GetBufferForValue(value));
}
// Uniquify aliased buffers.
std::sort(aliased_buffers.begin(), aliased_buffers.end());
aliased_buffers.erase(
std::unique(aliased_buffers.begin(), aliased_buffers.end()),
aliased_buffers.end());
return aliased_buffers;
}
// Dataflow analysis used to construct the buffer map.
const HloDataflowAnalysis& dataflow_;
// A map containing the set of values contained in each buffer.
tensorflow::gtl::FlatMap<BufferNumber,
tensorflow::gtl::FlatSet<const HloValue*>>
buffers_;
// A map indicating which buffer each value is contained in.
tensorflow::gtl::FlatMap<const HloValue*, BufferNumber>
value_to_buffer_number_;
// The buffer number of the next buffer to be created.
BufferNumber next_buffer_number_ = 0;
};
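To make the aliasing rules concrete, a hedged walk-through (instruction names hypothetical) of what MergeAliasedBuffers computes for a simple while loop:
// %init  = Exp(...)                      // operand of the while
// %while = While(%init, body, condition) // body param %bparam, body root %broot
//
// ComputeAliasedBuffers ties together the values of %init (init value of a
// while), %bparam (parameter of the body/condition), %broot (root of the
// body), and %while (output of the while instruction). MergeAliasedBuffers
// therefore unions their initial singleton buffers into one BufferNumber,
// which Run() below materializes as a single contiguously numbered HloBuffer
// holding all four HloValues.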
HloAliasAnalysis::HloAliasAnalysis(HloModule* module) : module_(module) {}
const HloBuffer& HloAliasAnalysis::GetUniqueBufferAt(
@ -99,10 +323,11 @@ bool HloAliasAnalysis::InstructionBuffersAreDistinct(
}
} else {
// It's possible for multiple values at this index to have the same
// HloBuffer. This does not result in non-distinctness. To account for this
// case, add all of the buffers at this index after checking whether each
// buffer exists at an earlier index. This is a corner case, however, as
// the number of values at an index is almost always one.
// HloBuffer. This does not result in non-distinctness. To account for
// this case, add all of the buffers at this index after checking
// whether each buffer exists at an earlier index. This is a corner
// case, however, as the number of values at an index is almost always
// one.
std::vector<const HloBuffer*> buffers_at_this_index;
for (const HloValue* value : value_set.values()) {
const HloBuffer* buffer = &GetBufferContainingValue(*value);
@ -118,15 +343,6 @@ bool HloAliasAnalysis::InstructionBuffersAreDistinct(
return true;
}
void HloAliasAnalysis::InitializeBufferSets() {
// Initially define a buffer for every HloValue in the module.
for (const HloValue& value : dataflow_analysis_->values()) {
HloBuffer& buffer = NewHloBuffer();
buffer.AddValue(value);
value_to_buffer_[&value] = &buffer;
}
}
Status HloAliasAnalysis::Verify() const {
// Verify consistency between the value_to_buffer_ map and
// HloBuffer::values().
@ -137,9 +353,8 @@ Status HloAliasAnalysis::Verify() const {
value) != buffer.values().end());
}
for (const auto& pair : buffers_) {
const HloBuffer::Id id = pair.first;
const HloBuffer& buffer = pair.second;
for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) {
const HloBuffer& buffer = buffers_[id];
TF_RET_CHECK(buffer.id() == id);
HloValue::Id last_value_id = -1;
@ -152,116 +367,9 @@ Status HloAliasAnalysis::Verify() const {
}
}
if (!buffers_vector_.empty()) {
// buffers_vector_ should be a vector of all HloBuffers sorted by id.
std::vector<const HloBuffer*> buffers;
for (const auto& id_buffer : buffers_) {
buffers.push_back(&id_buffer.second);
}
std::sort(buffers.begin(), buffers.end(), HloBuffer::IdLessThan);
TF_RET_CHECK(buffers_vector_ == buffers);
}
return Status::OK();
}
Status HloAliasAnalysis::VerifyAgainstReference() const {
TF_RETURN_IF_ERROR(Verify());
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> reference,
Run(module_));
TF_RETURN_IF_ERROR(reference->Verify());
VLOG(2) << "This analysis:";
XLA_VLOG_LINES(2, ToString());
VLOG(2) << "Reference:";
XLA_VLOG_LINES(2, reference->ToString());
// Create map from HloValue in the reference analysis to HloValue in this
// analysis and vice versa.
tensorflow::gtl::FlatMap<const HloValue*, const HloValue*> reference_to_this;
tensorflow::gtl::FlatMap<const HloValue*, const HloValue*> this_to_reference;
for (const HloValue& value : dataflow_analysis().values()) {
const HloValue& reference_value =
reference->dataflow_analysis().GetValueDefinedAt(
value.defining_instruction(), value.defining_index());
reference_to_this[&reference_value] = &value;
this_to_reference[&value] = &reference_value;
}
TF_RET_CHECK(buffers_.size() == reference->buffers_.size())
<< "Different number of buffers (" << buffers_.size()
<< " != " << reference->buffers_.size() << ")";
for (const auto& pair : reference->buffers_) {
const HloBuffer& reference_buffer = pair.second;
// Find the corresponding buffer in the reference by taking the first value
// in the buffer, finding the corresponding value in the reference, and then
// finding the buffer holding that value.
TF_RET_CHECK(!reference_buffer.values().empty());
const HloValue* reference_value = reference_buffer.values()[0];
const HloValue* value = reference_to_this.at(reference_value);
const HloBuffer& buffer = GetBufferContainingValue(*value);
// The buffer and the reference should have the exact same values. To make
// comparison easy, sort the values in the reference buffer identically to
// the values in the non-reference buffer (ie, by the corresponding id of
// the non-reference value).
std::vector<const HloValue*> reference_values = reference_buffer.values();
std::sort(reference_values.begin(), reference_values.end(),
[&reference_to_this](const HloValue* a, const HloValue* b) {
return reference_to_this.at(a)->id() <
reference_to_this.at(b)->id();
});
TF_RET_CHECK(reference_values.size() == buffer.values().size());
for (int i = 0; i < buffer.values().size(); ++i) {
TF_RET_CHECK(*reference_values[i] == *buffer.values()[i])
<< "Buffer:\n " << buffer
<< "\ndoes not have the same values as reference buffer:\n "
<< reference_buffer;
}
}
return Status::OK();
}
HloBuffer& HloAliasAnalysis::NewHloBuffer() {
HloBuffer::Id buffer_id = next_buffer_id_++;
auto emplaced = buffers_.emplace(std::piecewise_construct,
std::forward_as_tuple(buffer_id),
std::forward_as_tuple(buffer_id));
CHECK(emplaced.second);
buffers_vector_.clear();
return emplaced.first->second;
}
void HloAliasAnalysis::MoveValueToNewBuffer(const HloValue& value) {
HloBuffer& new_buffer = NewHloBuffer();
MoveValueToBuffer(value, &new_buffer);
VLOG(3) << "Moved value " << value.ToShortString() << " into new buffer "
<< new_buffer.id();
}
void HloAliasAnalysis::MoveValueToBuffer(const HloValue& value,
HloBuffer* buffer) {
HloBuffer& old_buffer = GetBufferContainingValue(value);
CHECK_NE(buffer, &old_buffer);
VLOG(3) << "Moved value " << value.ToShortString() << " from buffer "
<< old_buffer.id() << " into buffer " << buffer->id();
old_buffer.RemoveValue(value);
if (old_buffer.values().empty()) {
VLOG(3) << "Buffer " << old_buffer.id() << " now empty. Removing.";
buffers_.erase(old_buffer.id());
buffers_vector_.clear();
}
buffer->AddValue(value);
value_to_buffer_[&value] = buffer;
}
string HloAliasAnalysis::ToString() const {
string out = StrCat("HloAliasAnalysis, module ", module_->name(), "\n");
StrAppend(&out, " Buffers at each position:\n");
@ -290,10 +398,10 @@ string HloAliasAnalysis::ToString() const {
}
StrAppend(&out, " Buffers:\n");
for (const HloBuffer* buffer : buffers()) {
StrAppend(&out, " ", buffer->ToString(), "\n");
for (const HloBuffer& buffer : buffers()) {
StrAppend(&out, " ", buffer.ToString(), "\n");
StrAppend(&out, " positions:\n");
for (const HloPosition& position : buffer->ComputePositions()) {
for (const HloPosition& position : buffer.ComputePositions()) {
StrAppend(&out, " ", position.ToString(), "\n");
}
}
@ -301,217 +409,6 @@ string HloAliasAnalysis::ToString() const {
return out;
}
const std::vector<const HloBuffer*>& HloAliasAnalysis::buffers() const {
if (buffers_vector_.empty()) {
// Lazily construct vector of buffers.
buffers_vector_.reserve(buffers_.size());
for (auto& pair : buffers_) {
buffers_vector_.push_back(&pair.second);
}
std::sort(buffers_vector_.begin(), buffers_vector_.end(),
HloBuffer::IdLessThan);
} else {
CHECK_EQ(buffers_vector_.size(), buffers_.size());
for (const HloBuffer* buffer : buffers_vector_) {
DCHECK(ContainsKey(buffers_, buffer->id()));
DCHECK(&GetBuffer(buffer->id()) == buffer);
}
}
return buffers_vector_;
}
void HloAliasAnalysis::UpdateAtInstructions(
tensorflow::gtl::ArraySlice<const HloInstruction*> instructions) {
VLOG(4) << "Updated HLO module:";
XLA_VLOG_LINES(4, module_->ToString());
VLOG(3) << "Before update:";
XLA_VLOG_LINES(3, ToString());
std::vector<const HloValue*> values_to_update;
for (const HloInstruction* instruction : instructions) {
for (auto& pair : dataflow_analysis().GetInstructionValueSet(instruction)) {
for (const HloValue* value : pair.second.values()) {
values_to_update.push_back(value);
}
}
}
UpdateBuffersForValues(values_to_update);
VLOG(3) << "After update:";
XLA_VLOG_LINES(3, ToString());
}
void HloAliasAnalysis::UpdateAfterChangingOperand(HloInstruction* instruction,
HloInstruction* old_operand,
HloInstruction* new_operand) {
VLOG(1) << "UpdateAfterChangingOperand(" << instruction->name() << ", "
<< old_operand->name() << " => " << new_operand->name() << ")";
dataflow_analysis_->UpdateAfterChangingOperand(instruction, old_operand,
new_operand);
TF_DCHECK_OK(dataflow_analysis_->VerifyAgainstReference());
VLOG(4) << "Updated dataflow:";
XLA_VLOG_LINES(4, dataflow_analysis_->ToString());
UpdateAtInstructions({instruction, old_operand, new_operand});
}
void HloAliasAnalysis::UpdateAfterChangingRoot(HloInstruction* old_root,
HloInstruction* new_root) {
VLOG(1) << "UpdateAfterChangingRoot(" << old_root->name() << " => "
<< new_root->name() << ")";
dataflow_analysis_->UpdateAfterChangingRoot(old_root, new_root);
TF_DCHECK_OK(dataflow_analysis_->VerifyAgainstReference());
VLOG(4) << "Updated dataflow:";
XLA_VLOG_LINES(4, dataflow_analysis_->ToString());
UpdateAtInstructions({old_root, new_root});
}
std::vector<HloBuffer*> HloAliasAnalysis::ComputeAliasedBuffers(
const HloValue& value) {
std::vector<HloBuffer*> aliased_buffers;
// Value is init of a while (use is while).
for (const HloUse& use : value.uses()) {
VLOG(1) << "use of value " << value.ToShortString() << ": " << use;
if (use.instruction->opcode() == HloOpcode::kWhile) {
// Determine the while value that this shares a buffer with.
const HloValue& while_value = dataflow_analysis().GetUniqueValueAt(
use.instruction, use.operand_index);
aliased_buffers.push_back(&GetBufferContainingValue(while_value));
VLOG(3) << " value is init value to a while; must share buffer with "
"while value "
<< while_value.ToShortString();
}
}
// Value is a parameter of a while body/condition.
if (value.defining_instruction()->opcode() == HloOpcode::kParameter) {
const HloComputation* computation = value.defining_instruction()->parent();
const CallGraphNode& call_graph_node =
dataflow_analysis().call_graph().GetNode(computation);
for (const CallSite& callsite : call_graph_node.caller_callsites()) {
if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
// Call graph must have been flattened.
CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
const HloValue& while_value = dataflow_analysis().GetUniqueValueAt(
callsite.instruction(), value.defining_index());
VLOG(3) << " value is parameter value of the body or condition of a "
"while; must share buffer with while value "
<< while_value.ToShortString();
aliased_buffers.push_back(&GetBufferContainingValue(while_value));
}
}
}
// Value is the root of a while body.
for (const HloPosition& position : value.positions()) {
const HloComputation* computation = position.instruction->parent();
const CallGraphNode& call_graph_node =
dataflow_analysis().call_graph().GetNode(computation);
if (position.instruction == computation->root_instruction()) {
for (const CallSite& callsite : call_graph_node.caller_callsites()) {
if (callsite.instruction()->opcode() == HloOpcode::kWhile &&
callsite.instruction()->while_body() == computation) {
// Call graph must have been flattened.
CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
// If the value appears in the root of a while body, then
// necessarily the value is defined in the body as well.
CHECK_EQ(value.defining_instruction()->parent(), computation);
const HloValue& while_value = dataflow_analysis().GetUniqueValueAt(
callsite.instruction(), position.index);
VLOG(3) << " value is root the body computation of a while; must "
"share buffer with while value "
<< while_value.ToShortString();
aliased_buffers.push_back(&GetBufferContainingValue(while_value));
}
}
}
}
// Value is in the while instruction itself.
if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
VLOG(3) << " value is output of a while instruction";
aliased_buffers.push_back(&GetUniqueBufferAt(value.defining_instruction(),
value.defining_index()));
}
// Uniquify aliased buffers.
std::sort(aliased_buffers.begin(), aliased_buffers.end(),
HloBuffer::IdLessThan);
aliased_buffers.erase(
std::unique(aliased_buffers.begin(), aliased_buffers.end()),
aliased_buffers.end());
return aliased_buffers;
}
// This method recomputes the HloBuffer for each of the given HloValues. The
// method does not necessarily update the HloBuffer of values which share a
// buffer with the given values, but are not explicitly passed in
// 'values'. Therefore, the caller must pass in all values which may require an
// update according to the kind of HLO graph change which occurred: operand
// changed (UpdateAfterChangingOperand), or root of computation changed
// (UpdateAfterChangingRoot).
void HloAliasAnalysis::UpdateBuffersForValues(
tensorflow::gtl::ArraySlice<const HloValue*> values) {
for (const HloValue* value : values) {
VLOG(3) << "Updating buffer for value: " << value->ToShortString();
// Gather the set of buffers with aliasing rules (eg, kWhile) which this
// value must be contained in.
std::vector<HloBuffer*> aliased_buffers = ComputeAliasedBuffers(*value);
HloBuffer& current_buffer = GetBufferContainingValue(*value);
if (aliased_buffers.empty()) {
// The buffer containing 'value' aliases no other buffers. If the buffer
// containing 'value' already only contains 'value', then no change is
// necessary. If the buffer containing 'value' does contain other values,
// then remove 'value' from the buffer and create a new buffer containing
// only 'value'.
if (current_buffer.values().size() == 1) {
CHECK_EQ(current_buffer.values()[0], value);
} else {
MoveValueToNewBuffer(*value);
}
} else {
// If multiple buffers are aliased merge these buffers together into a
// single buffer (arbitrarily chosen as the first buffer in the vector).
if (aliased_buffers.size() > 1) {
for (int64 i = 1; i < aliased_buffers.size(); ++i) {
// Make a copy of the values vector because MoveValueToBuffer invalidates
// the values iterator. This could be done more efficiently by moving
// all values at once.
std::vector<const HloValue*> values = aliased_buffers[i]->values();
for (const HloValue* value : values) {
MoveValueToBuffer(*value, aliased_buffers[0]);
}
}
aliased_buffers.resize(1);
}
CHECK_EQ(aliased_buffers.size(), 1);
HloBuffer* new_buffer = aliased_buffers[0];
if (&current_buffer != new_buffer) {
MoveValueToBuffer(*value, new_buffer);
}
}
VLOG(4) << "Analysis after update:";
XLA_VLOG_LINES(4, ToString());
}
}
/* static */
StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
HloModule* module) {
@ -524,18 +421,28 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
HloDataflowAnalysis::Run(module, /*ssa_form=*/true,
/*bitcast_defines_value=*/false));
alias_analysis->InitializeBufferSets();
BufferValueMap buffer_map(alias_analysis->dataflow_analysis());
buffer_map.MergeAliasedBuffers();
VLOG(3) << "After initialization:";
XLA_VLOG_LINES(3, alias_analysis->ToString());
std::vector<const HloValue*> all_values;
for (const HloValue& value : alias_analysis->dataflow_analysis().values()) {
all_values.push_back(&value);
// Create a vector of HloBuffers, one for each set of values in the
// BufferValueMap. Create the HloBuffers as a vector of contiguously numbered
// buffers.
std::vector<BufferValueMap::BufferNumber> sorted_buffer_numbers =
buffer_map.ComputeSortedBufferNumbers();
alias_analysis->buffers_.reserve(sorted_buffer_numbers.size());
HloBuffer::Id next_id = 0;
for (BufferValueMap::BufferNumber buffer_number : sorted_buffer_numbers) {
auto& value_set = buffer_map.GetValuesInBuffer(buffer_number);
std::vector<const HloValue*> sorted_values(value_set.begin(),
value_set.end());
std::sort(sorted_values.begin(), sorted_values.end(), HloValue::IdLessThan);
alias_analysis->buffers_.emplace_back(next_id++, sorted_values);
for (const HloValue* value : sorted_values) {
alias_analysis->value_to_buffer_[value] =
&alias_analysis->buffers_.back();
}
}
alias_analysis->UpdateBuffersForValues(all_values);
TF_DCHECK_OK(alias_analysis->Verify());
XLA_VLOG_LINES(1, alias_analysis->ToString());
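A minimal usage sketch of the resulting API (assuming an existing HloModule* module; error handling elided):
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> analysis,
                    HloAliasAnalysis::Run(module));
for (const HloBuffer& buffer : analysis->buffers()) {
  // buffers() now returns the underlying vector directly; ids are contiguous
  // and start at zero.
  VLOG(1) << buffer.ToString();
}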

View File

@ -74,7 +74,7 @@ class HloAliasAnalysis {
// Return a vector of all HloBuffers stably sorted by HloBuffer::Id. This
// vector is lazily computed. Mutating operations on HloAliasAnalysis may
// invalidate the underlying vector requiring recomputation.
const std::vector<const HloBuffer*>& buffers() const;
const std::vector<HloBuffer>& buffers() const { return buffers_; }
// Returns the underlying dataflow analysis used by this alias analysis.
const HloDataflowAnalysis& dataflow_analysis() const {
@ -90,50 +90,13 @@ class HloAliasAnalysis {
// output of the given instruction.
bool InstructionBuffersAreDistinct(const HloInstruction* instruction) const;
// Updates the analysis after the operands of 'instruction' have changed or if
// 'instruction' has been made the root of a computation. Analysis update is
// not possible if instructions have been added or removed from the graph.
void UpdateAfterChangingOperand(HloInstruction* instruction,
HloInstruction* old_operand,
HloInstruction* new_operand);
void UpdateAfterChangingRoot(HloInstruction* old_root,
HloInstruction* new_root);
// Compare the dataflow analysis against a clean recomputation of the
// analysis. Returns an error status if there is a mismatch. Useful for
// verifying the correctness after updates to the analysis.
Status VerifyAgainstReference() const;
protected:
HloAliasAnalysis(HloModule* module);
// Create a new empty HloBuffer.
HloBuffer& NewHloBuffer();
// Move the given value to the given buffer. The value is removed from its
// current buffer.
void MoveValueToBuffer(const HloValue& value, HloBuffer* buffer);
// Move the given value to a newly created buffer. The value is removed from
// its current buffer.
void MoveValueToNewBuffer(const HloValue& value);
// Construct the initial set of buffer sets where an HloBuffer is created for
// each HloValue in the module.
void InitializeBufferSets();
// Compute and return the buffers with aliasing rules (eg, kWhile) which the
// given value must be contained in.
std::vector<HloBuffer*> ComputeAliasedBuffers(const HloValue& value);
// Recompute the HloBuffers for the given values.
void UpdateBuffersForValues(
tensorflow::gtl::ArraySlice<const HloValue*> values);
// Recompute the HloBuffers for all the values which appear in the output of
// the given instructions.
void UpdateAtInstructions(
tensorflow::gtl::ArraySlice<const HloInstruction*> instructions);
explicit HloAliasAnalysis(HloModule* module);
// Verify various invariants of the alias analysis.
Status Verify() const;
@ -143,20 +106,12 @@ class HloAliasAnalysis {
// The underlying dataflow analysis used by this alias analysis.
std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
// The map of all HloBuffers in the module. We pass around pointers to the
// mapped HloBuffers, so the underlying container must keep them valid despite
// mutations touching other map entries.
std::unordered_map<HloBuffer::Id, HloBuffer> buffers_;
// A map indicating which buffer a value is contained in.
tensorflow::gtl::FlatMap<const HloValue*, HloBuffer*> value_to_buffer_;
// A lazily constructed vector containing all HloBuffers sorted by
// HloBuffer::Id.
mutable std::vector<const HloBuffer*> buffers_vector_;
// The Id to use for the next HloBuffer.
int64 next_buffer_id_ = 0;
std::vector<HloBuffer> buffers_;
};
} // namespace xla

View File

@ -87,14 +87,13 @@ class HloAliasAnalysisTest : public HloTestBase {
// constructed.
bool AnyValuesInSameBufferInterfere() {
DependencyHloOrdering ordering(module_.get());
for (const HloBuffer* buffer : analysis_->buffers()) {
for (const HloValue* value_a : buffer->values()) {
for (const HloValue* value_b : buffer->values()) {
for (const HloBuffer& buffer : analysis_->buffers()) {
for (const HloValue* value_a : buffer.values()) {
for (const HloValue* value_b : buffer.values()) {
if (*value_a != *value_b &&
analysis_->dataflow_analysis().MayInterfere(*value_a, *value_b,
ordering)) {
ordering.MayInterfere(*value_a, *value_b)) {
VLOG(1) << *value_a << " interferes with " << *value_b
<< " in buffer: " << *buffer;
<< " in buffer: " << buffer;
return true;
}
}
@ -384,10 +383,7 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) {
EXPECT_THAT(
GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0})),
UnorderedElementsAre(GetValueDefinedAt(xla_while, /*index=*/{0}),
GetValueDefinedAt(body_param, /*index=*/{0}),
GetValueDefinedAt(cond_param, /*index=*/{0}),
GetValueDefinedAt(constant1)));
UnorderedElementsAre(GetValueDefinedAt(constant1)));
EXPECT_THAT(
GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})),
UnorderedElementsAre(GetValueDefinedAt(constant2),
@ -631,9 +627,9 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) {
// HloBuffers.
EXPECT_THAT(
analysis.buffers(),
UnorderedElementsAre(&analysis.GetUniqueBufferAt(constant1),
&analysis.GetUniqueBufferAt(tuple, /*index=*/{}),
&analysis.GetUniqueBufferAt(cond_constant)));
UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1),
analysis.GetUniqueBufferAt(tuple, /*index=*/{}),
analysis.GetUniqueBufferAt(cond_constant)));
// The tuple elements of the while and the three constant inputs should all be
// smooshed into the same buffer.
@ -820,127 +816,5 @@ TEST_F(HloAliasAnalysisTest, Bitcast) {
analysis.GetUniqueBufferAt(bitcast));
}
TEST_F(HloAliasAnalysisTest, UpdateAnalysisForWhile) {
// Test updating alias analysis after modifying a module with an array-shaped
// while:
//
// body(F32[] %param):
// %negate = Negate(%param)
//
// condition(F32[] %param):
// return Constant(false)
//
// entry:
// %constant = Constant(1.0)
// %exp = Exp(%constant)
// return While(%exp, body, condition)
//
auto body_builder = HloComputation::Builder("body");
auto body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
scalar_shape_, HloOpcode::kNegate, body_param));
HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
// Condition computation trivially returns a constant "false".
auto cond_builder = HloComputation::Builder("condition");
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
cond_builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
auto xla_while = builder.AddInstruction(
HloInstruction::CreateWhile(scalar_shape_, condition, body, exp));
module_->AddEntryComputation(builder.Build());
HloAliasAnalysis& analysis = RunAnalysis();
// Sanity check some alias information.
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
analysis.GetUniqueBufferAt(body_param));
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
analysis.GetUniqueBufferAt(cond_param));
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
analysis.GetUniqueBufferAt(negate));
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
analysis.GetUniqueBufferAt(xla_while));
// Set the body root to the body_param. Previously it was Negate(body_param).
body->set_root_instruction(body_param);
// Prior to updating, verify that the analysis is no longer valid.
Status verify_status = analysis.VerifyAgainstReference();
EXPECT_FALSE(verify_status.ok());
analysis.UpdateAfterChangingRoot(/*old_root=*/negate,
/*new_root=*/body_param);
// Analysis should be valid after the update.
TF_ASSERT_OK(analysis.VerifyAgainstReference());
// The exponential should now pass through the body transparently.
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
analysis.GetUniqueBufferAt(body_param));
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
analysis.GetUniqueBufferAt(cond_param));
EXPECT_NE(analysis.GetUniqueBufferAt(exp),
analysis.GetUniqueBufferAt(negate));
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
analysis.GetUniqueBufferAt(xla_while));
// Now replace the operand of the while with %constant (was %exp).
TF_ASSERT_OK(exp->ReplaceUseWith(xla_while, constant));
analysis.UpdateAfterChangingOperand(xla_while, /*old_operand=*/exp,
/*new_operand=*/constant);
// Analysis should be valid after the update.
TF_ASSERT_OK(analysis.VerifyAgainstReference());
EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
analysis.GetUniqueBufferAt(body_param));
EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
analysis.GetUniqueBufferAt(cond_param));
EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
analysis.GetUniqueBufferAt(xla_while));
EXPECT_NE(analysis.GetUniqueBufferAt(constant),
analysis.GetUniqueBufferAt(exp));
EXPECT_NE(analysis.GetUniqueBufferAt(constant),
analysis.GetUniqueBufferAt(negate));
// And finally make the negate the root of the body again.
body->set_root_instruction(negate);
analysis.UpdateAfterChangingRoot(/*old_root=*/body_param,
/*new_root=*/negate);
// Analysis should be valid after the update.
TF_ASSERT_OK(analysis.VerifyAgainstReference());
EXPECT_EQ(analysis.GetUniqueBufferAt(negate),
analysis.GetUniqueBufferAt(body_param));
EXPECT_EQ(analysis.GetUniqueBufferAt(negate),
analysis.GetUniqueBufferAt(cond_param));
EXPECT_EQ(analysis.GetUniqueBufferAt(negate),
analysis.GetUniqueBufferAt(xla_while));
EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
analysis.GetUniqueBufferAt(negate));
auto value_of = [&analysis](const HloInstruction* instruction) {
return &analysis.dataflow_analysis().GetValueDefinedAt(instruction);
};
EXPECT_THAT(analysis.GetUniqueBufferAt(negate).values(),
UnorderedElementsAre(value_of(body_param), value_of(cond_param),
value_of(negate), value_of(constant),
value_of(xla_while)));
}
// Test update tuple element.
} // namespace
} // namespace xla

View File

@ -36,22 +36,6 @@ namespace xla {
using ::tensorflow::str_util::Join;
using ::tensorflow::strings::StrCat;
void HloBuffer::AddValue(const HloValue& value) {
values_.push_back(&value);
// Sort vector and remove duplicates.
std::sort(values_.begin(), values_.end(), HloValue::IdLessThan);
values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual),
values_.end());
}
void HloBuffer::RemoveValue(const HloValue& value) {
// The values are sorted, so finding the value could be done in log(n) time
// with a binary search.
auto it = std::find(values_.begin(), values_.end(), &value);
CHECK(it != values_.end());
values_.erase(it);
}
bool HloBuffer::operator==(const HloBuffer& other) const {
bool equal = id() == other.id();
if (equal) {

View File

@ -84,22 +84,15 @@ class HloBuffer {
return a->id() == b->id();
}
HloBuffer(Id id) : id_(id) {}
HloBuffer(Id id, tensorflow::gtl::ArraySlice<const HloValue*> values)
: id_(id), values_(values.begin(), values.end()) {}
// Return the unique identifier for this HloBuffer.
Id id() const { return id_; }
// Add a value to the set of values held by this buffer. Also adds the
// HloPositions of the value to the positions vector of the buffer. If the
// buffer already contains this value, then this method is a nop.
void AddValue(const HloValue& value);
void RemoveValue(const HloValue& value);
// Return all values contained in this buffer.
const std::vector<const HloValue*>& values() const { return values_; }
std::vector<HloPosition> ComputePositions() const;
// Return the unique HLO value in the buffer. CHECK fails if the buffer does
// not contain exactly one value.
const HloValue& GetUniqueValue() const {
@ -107,6 +100,8 @@ class HloBuffer {
return *values_[0];
}
std::vector<HloPosition> ComputePositions() const;
string ToString() const;
bool operator==(const HloBuffer& other) const;
@ -118,7 +113,7 @@ class HloBuffer {
// The set of values contained in this buffer. Vector contains no duplicates
// and is sorted stably by HloValue::Id.
std::vector<const HloValue*> values_;
const std::vector<const HloValue*> values_;
};
std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer);
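The net effect of the changes to this class is that an HloBuffer is now immutable: AddValue/RemoveValue are gone and the sorted, duplicate-free value list is fixed by the new constructor. A construction sketch (value names hypothetical; HloAliasAnalysis::Run above is the real call site):
// std::vector<const HloValue*> sorted_values = {&value_a, &value_b};
// HloBuffer buffer(/*id=*/0, sorted_values);  // values sorted by HloValue::Id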

View File

@ -118,13 +118,11 @@ Status HloCostAnalysis::HandleElementwiseOp(HloInstruction* hlo_instruction) {
}
}
Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo,
HloOpcode opcode) {
Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo) {
return HandleElementwiseOp(hlo);
}
Status HloCostAnalysis::HandleElementwiseBinary(HloInstruction* hlo,
HloOpcode opcode) {
Status HloCostAnalysis::HandleElementwiseBinary(HloInstruction* hlo) {
return HandleElementwiseOp(hlo);
}

View File

@ -49,9 +49,8 @@ class HloCostAnalysis : public DfsHloVisitor {
using ShapeSizeFunction = std::function<int64(const Shape&)>;
explicit HloCostAnalysis(const ShapeSizeFunction& shape_size);
Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode) override;
Status HandleElementwiseBinary(HloInstruction* hlo,
HloOpcode opcode) override;
Status HandleElementwiseUnary(HloInstruction* hlo) override;
Status HandleElementwiseBinary(HloInstruction* hlo) override;
Status HandleConstant(HloInstruction* constant,
const Literal& literal) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element,

View File

@ -67,6 +67,22 @@ HloValue& HloDataflowAnalysis::GetValueDefinedAt(
return GetUniqueValueAt(instruction, index);
}
HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
const ShapeIndex& index,
bool is_phi) {
const int64 value_id = next_value_id_++;
auto emplaced = values_.emplace(
std::piecewise_construct, std::forward_as_tuple(value_id),
std::forward_as_tuple(value_id, instruction, index, is_phi));
CHECK(emplaced.second);
return &emplaced.first->second;
}
void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) {
values_.erase(value_id);
}
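Why on-demand creation: HloValues can now be created and deleted while the phi computation below runs to a fixed point, so they live in a container keyed by id and are handed out by pointer (an assumption here: the map keeps them at stable addresses, e.g. a node-based std::unordered_map). An illustrative lifecycle (names hypothetical):
// HloValue* phi = NewHloValue(instruction, /*index=*/{}, /*is_phi=*/true);
// ... Phi() later finds only one distinct non-self input ...
// DeleteHloValue(phi->id());  // the value set collapses to that input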
string HloDataflowAnalysis::ToString() const {
string out = StrCat("HloDataflowAnalysis, module ", module_->name(), "\n");
StrAppend(&out, " Instruction value sets:\n");
@ -99,22 +115,98 @@ string HloDataflowAnalysis::ToString() const {
}
}
StrAppend(&out, " HloValues:\n");
for (const HloValue& value : values()) {
StrAppend(&out, value.ToString(/*indent=*/4));
}
StrAppend(&out, " Phi resolutions:\n");
for (const HloValue& value : values()) {
if (value.is_phi()) {
const HloValue* resolved_value = ResolvePhi(value);
StrAppend(&out, " ", value.ToShortString(), " => ",
resolved_value == nullptr ? "UNKNOWN"
: resolved_value->ToShortString(),
"\n");
}
for (const HloValue* value : values()) {
StrAppend(&out, value->ToString(/*indent=*/4));
}
return out;
}
bool HloDataflowAnalysis::Phi(
HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
CHECK(ssa_form_);
for (const InstructionValueSet* input : inputs) {
DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
}
bool changed = false;
for (auto& pair : GetInstructionValueSet(instruction)) {
const ShapeIndex& index = pair.first;
HloValueSet& value_set = pair.second;
// Positions with phi values should never have more than one value in the
// value set.
CHECK_LE(value_set.values().size(), 1);
const HloValue* current_value =
value_set.values().size() == 1 ? value_set.values()[0] : nullptr;
// Construct a vector of unique value IDs of the inputs.
std::vector<HloValue::Id> input_value_ids;
for (const InstructionValueSet* input : inputs) {
for (const HloValue* value : input->element(index).values()) {
input_value_ids.push_back(value->id());
}
}
std::sort(input_value_ids.begin(), input_value_ids.end());
input_value_ids.erase(
std::unique(input_value_ids.begin(), input_value_ids.end()),
input_value_ids.end());
// Remove the existing phi value (if it exists). The phi can be its own
// input, for example, in while body parameters where the body passes
// through the parameter value.
bool current_value_defined_here =
(current_value != nullptr &&
current_value->defining_instruction() == instruction &&
current_value->defining_index() == index);
if (current_value_defined_here) {
CHECK(current_value->is_phi());
auto it = std::find(input_value_ids.begin(), input_value_ids.end(),
current_value->id());
if (it != input_value_ids.end()) {
input_value_ids.erase(it);
}
}
if (input_value_ids.empty()) {
// A value set which has at least one element should never have its value
// set reduced to zero elements. During dataflow, value sets can only go
// from empty to non-empty, not the reverse.
CHECK_EQ(value_set.values().size(), 0)
<< "Instruction " << instruction->name() << " at index " << index
<< " previously had non-empty value set. Value set: " << value_set;
} else if (input_value_ids.size() == 1) {
// Only a single value reaches this point. There should be no phi, and
// this value set should contain this single value.
const HloValue& new_value = GetValue(input_value_ids[0]);
if (current_value == nullptr) {
value_set.Clear();
value_set.AddValue(&new_value);
changed = true;
} else if (current_value != &new_value) {
if (current_value_defined_here) {
// Remove the existing phi.
DeleteHloValue(current_value->id());
}
value_set.Clear();
value_set.AddValue(&new_value);
changed = true;
}
} else {
// Multiple distinct values reach this point. A phi value is
// necessary.
CHECK_GT(input_value_ids.size(), 1);
if (current_value == nullptr || !current_value->is_phi()) {
value_set.Clear();
value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));
changed = true;
}
}
}
return changed;
}
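A worked illustration of Phi() for a while loop in SSA form (names hypothetical):
// The body parameter's inputs are {init_value, body_root_value}.
// - If the body forwards its parameter unchanged, body_root_value is the phi
//   itself; after the self-input is dropped one input remains, so the phi is
//   deleted and the value set collapses to {init_value}.
// - If the body computes something new, two distinct inputs remain, so a phi
//   value is kept (created on first visit via NewHloValue).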
const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const {
return values_.at(value_id);
}
@ -142,129 +234,6 @@ HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) {
return GetValueSet(position.instruction, position.index);
}
void HloDataflowAnalysis::UpdateAfterChangingOperand(
HloInstruction* instruction, HloInstruction* old_operand,
HloInstruction* new_operand) {
CHECK(std::find(instruction->operands().begin(),
instruction->operands().end(),
new_operand) != instruction->operands().end());
VLOG(1) << "UpdateAfterChangingOperand(" << instruction->name() << ", "
<< old_operand->name() << " => " << new_operand->name() << ")";
std::vector<HloInstruction*> to_update = {instruction};
// If the instruction calls any computations, then add the parameters of the
// called computations to capture any changes to the subcomputations'
// dataflow introduced by the new operand.
for (HloComputation* computation : instruction->called_computations()) {
to_update.insert(to_update.end(),
computation->parameter_instructions().begin(),
computation->parameter_instructions().end());
}
UpdateInstructionsAndPropagate(to_update);
// The uses of the values in the old and new operand may have changed. Uses of
// other HloValues are updated in UpdateInstructionsAndPropagate.
for (auto& pair : GetInstructionValueSet(old_operand)) {
for (const HloValue* value : pair.second.values()) {
GetValue(value->id()).RecomputeUses();
}
}
for (auto& pair : GetInstructionValueSet(new_operand)) {
for (const HloValue* value : pair.second.values()) {
GetValue(value->id()).RecomputeUses();
}
}
TF_DCHECK_OK(VerifyAgainstReference());
}
void HloDataflowAnalysis::UpdateAfterChangingRoot(HloInstruction* old_root,
HloInstruction* new_root) {
VLOG(1) << "UpdateAfterChangingRoot(" << old_root->name() << " => "
<< new_root->name() << ")";
CHECK_EQ(new_root, new_root->parent()->root_instruction());
CHECK_EQ(new_root->parent(), old_root->parent());
std::vector<HloInstruction*> to_update = {old_root, new_root};
const CallGraphNode& call_graph_node =
call_graph_->GetNode(new_root->parent());
for (const CallSite& callsite : call_graph_node.caller_callsites()) {
if (callsite.instruction()->opcode() == HloOpcode::kCall) {
to_update.push_back(callsite.instruction());
} else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
// Add the while itself, and the body and condition parameters.
to_update.push_back(callsite.instruction());
to_update.push_back(
callsite.instruction()->while_body()->parameter_instruction(0));
to_update.push_back(
callsite.instruction()->while_condition()->parameter_instruction(0));
}
}
UpdateInstructionsAndPropagate(to_update);
TF_DCHECK_OK(VerifyAgainstReference());
}
const HloValue* HloDataflowAnalysis::ResolvePhi(const HloValue& phi) const {
CHECK(phi.is_phi());
tensorflow::gtl::FlatSet<const HloValue*> visited;
std::queue<const HloValue*> worklist;
auto add_to_worklist = [&worklist, &visited](const HloValue* v) {
if (visited.insert(v).second) {
// 'v' was not previously in visited.
worklist.push(v);
}
};
add_to_worklist(&phi);
const HloValue* resolved_value = nullptr;
while (!worklist.empty()) {
const HloValue* value = worklist.front();
worklist.pop();
if (!value->is_phi()) {
if (resolved_value == nullptr) {
resolved_value = value;
} else if (resolved_value != value) {
return nullptr;
}
} else {
for (const HloValue* input : phi_inputs_.at(value)) {
add_to_worklist(input);
}
}
}
return resolved_value;
}
void HloDataflowAnalysis::UpdatePhiInputs(
const HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
CHECK(ssa_form_);
for (auto& pair : GetInstructionValueSet(instruction)) {
const ShapeIndex& index = pair.first;
const HloValue& phi_value = GetUniqueValueAt(instruction, index);
auto& phi_inputs = phi_inputs_.at(&phi_value);
phi_inputs.clear();
for (const InstructionValueSet* input : inputs) {
for (const HloValue* value : input->element(index).values()) {
// The number of phi inputs is typically 2, and virtually always very
// small.
if (std::find(phi_inputs.begin(), phi_inputs.end(), value) ==
phi_inputs.end()) {
phi_inputs.push_back(value);
}
}
}
}
}
bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
const InstructionValueSet& operand_set =
@ -380,8 +349,7 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
}
if (ssa_form_ && called_from_while) {
UpdatePhiInputs(parameter, inputs);
return false;
return Phi(parameter, inputs);
} else {
return GetInstructionValueSet(parameter).AssignUnionOf(inputs);
}
@ -439,8 +407,7 @@ bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
&GetInstructionValueSet(xla_while->while_body()->root_instruction()),
&GetInstructionValueSet(xla_while->operand(0))};
if (ssa_form_) {
UpdatePhiInputs(xla_while, inputs);
return false;
return Phi(xla_while, inputs);
} else {
return GetInstructionValueSet(xla_while).AssignUnionOf(inputs);
}
@ -487,38 +454,7 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
VLOG(3) << "Worklist top: " << instruction->name();
VLOG(3) << ToString();
// The updating of the instruction value set below in
// UpdateInstructionValueSet does not update HloValue::positions(). To
// perform the positions() update, remove all positions in 'instruction' from
// the HloValues in 'instruction's value set prior to the update, then after
// the update add the new positions back in. There is likely a more
// efficient way of doing this.
for (auto& pair : GetInstructionValueSet(instruction)) {
const ShapeIndex& index = pair.first;
HloValueSet& value_set = pair.second;
for (const HloValue* value : value_set.values()) {
if (value->defining_instruction() != instruction) {
// Use GetValue for a non-const HloValue reference.
GetValue(value->id()).RemovePosition(instruction, index);
}
}
}
bool changed = UpdateInstructionValueSet(instruction);
// Add the positions back in.
for (auto& pair : GetInstructionValueSet(instruction)) {
const ShapeIndex& index = pair.first;
HloValueSet& value_set = pair.second;
for (const HloValue* value : value_set.values()) {
if (value->defining_instruction() != instruction) {
// Use GetValue for a non-const HloValue reference.
GetValue(value->id()).AddPosition(instruction, index);
}
}
}
if (!changed) {
if (!UpdateInstructionValueSet(instruction)) {
// No change to the instruction's value set.
VLOG(4) << "No change.";
continue;
@ -531,12 +467,16 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
for (HloInstruction* user : instruction->users()) {
worklist.push(user);
// If user calls a computation, then the respective parameter(s) of the
// computation need to be updated.
// If user sequentially calls a computation, then the respective
// parameter(s) of the computation need to be updated.
for (HloComputation* called_computation : user->called_computations()) {
for (int64 operand_number : user->OperandIndices(instruction)) {
worklist.push(
called_computation->parameter_instruction(operand_number));
const CallGraphNode& call_graph_node =
call_graph_->GetNode(called_computation);
if (call_graph_node.context() == CallContext::kSequential) {
for (int64 operand_number : user->OperandIndices(instruction)) {
worklist.push(
called_computation->parameter_instruction(operand_number));
}
}
}
}
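The guard added here leans on CallGraph's call contexts (an assumption based on the CallContext enum used above):
// kSequential: kCall and kWhile body/condition -- parameters forward the
//              caller's values, so they must be re-pushed onto the worklist.
// kParallel:   e.g. kMap -- parameters define fresh values (see
//              InitializeInstructionValueSets below), so no update is needed.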
@ -574,25 +514,10 @@ InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
}
Status HloDataflowAnalysis::InitializeInstructionValueSets() {
// Gather the values to create before creating them. This is done because we
// want to allocate the vector of values only once so references to elements
// are stable.
struct ValueToCreate {
HloInstruction* instruction;
ShapeIndex index;
bool is_phi;
};
std::vector<ValueToCreate> values_to_create;
for (const std::unique_ptr<HloComputation>& computation :
module_->computations()) {
const CallGraphNode& call_graph_node =
call_graph_->GetNode(computation.get());
bool called_from_while = std::any_of(
call_graph_node.caller_callsites().begin(),
call_graph_node.caller_callsites().end(), [](const CallSite& cs) {
return cs.instruction()->opcode() == HloOpcode::kWhile;
});
for (const std::unique_ptr<HloInstruction>& instruction :
computation->instructions()) {
@ -603,20 +528,22 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
// Lambda to set the value set to define all values in the output of the
// instruction.
auto define_all_values = [this, &instruction,
&values_to_create](bool is_phi = false) {
auto define_all_values = [this, &instruction](bool is_phi = false) {
for (auto& pair : GetInstructionValueSet(instruction.get())) {
const ShapeIndex& index = pair.first;
values_to_create.push_back({instruction.get(), index, is_phi});
HloValue* value =
NewHloValue(instruction.get(), index, /*is_phi=*/false);
GetValueSet(instruction.get(), index).AddValue(value);
}
};
// Lambda to set the value set to define only the top-level buffer in the
// output of the instruction. Any other values flow from the operands of
// the instruction (or from cross-computation dataflow).
auto define_top_level_only = [this, &instruction, &values_to_create]() {
values_to_create.push_back(
{instruction.get(), /*index=*/{}, /*is_phi=*/false});
auto define_top_level_only = [this, &instruction]() {
HloValue* value =
NewHloValue(instruction.get(), /*index=*/{}, /*is_phi=*/false);
GetValueSet(instruction.get(), /*index=*/{}).AddValue(value);
};
switch (instruction->opcode()) {
@ -626,10 +553,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
}
break;
case HloOpcode::kWhile:
if (ssa_form_) {
define_all_values(/*is_phi=*/true);
}
break;
case HloOpcode::kCall:
case HloOpcode::kGetTupleElement:
// These instructions define no values. The values in their output
@ -654,10 +577,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
// values in their output. Otherwise the values of the parameter
// come from the caller (eg, operands to the kCall instruction).
define_all_values();
} else if (call_graph_node.context() == CallContext::kSequential &&
called_from_while && ssa_form_) {
// Parameters of while bodies and conditions are phis.
define_all_values(/*is_phi=*/true);
}
break;
case HloOpcode::kCopy:
@ -674,164 +593,9 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
}
}
// Reserve the vector ahead of time so references to elements are stable.
values_.reserve(values_to_create.size());
for (int64 i = 0; i < values_to_create.size(); ++i) {
const ValueToCreate& to_create = values_to_create[i];
values_.emplace_back(/*id=*/i, to_create.instruction, to_create.index,
to_create.is_phi);
const HloValue& value = values_.back();
GetValueSet(to_create.instruction, to_create.index).AddValue(&value);
if (value.is_phi()) {
phi_inputs_[&value] = {};
}
}
return Status::OK();
}
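The deleted two-pass construction above existed purely for pointer stability: std::vector::push_back may reallocate and invalidate every outstanding HloValue reference, so the old code counted values first and reserved once. The new code sidesteps the problem by storing values in a node-based map instead. A small self-contained illustration of both points:

#include <unordered_map>
#include <vector>

struct Value { int id; };

int main() {
  // Hazard: growing a vector may reallocate, leaving 'first' dangling.
  std::vector<Value> vec;
  vec.push_back({0});
  Value* first = &vec[0];
  vec.push_back({1});  // 'first' is now potentially invalid (never deref it).
  (void)first;

  // Fix used by this change: unordered_map never moves mapped values when
  // other entries are inserted or erased, so addresses stay valid.
  std::unordered_map<int, Value> map;
  Value* stable = &map.emplace(0, Value{0}).first->second;
  map.emplace(1, Value{1});  // 'stable' remains valid.
  (void)stable;
  return 0;
}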
bool HloDataflowAnalysis::IsDefinedBefore(const HloValue& a, const HloValue& b,
const HloOrdering& ordering) const {
// If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b'
// is live into the module.
if (b.defining_instruction()->parent() == module_->entry_computation() &&
b.defining_instruction()->opcode() == HloOpcode::kParameter) {
return false;
}
// Phi values require special handling. Because XLA does not have a phi
// instruction, the defining instruction of a phi value is a placeholder:
// either the subcomputation parameter (body or condition) or the while
// instruction. However, the program point where these values are logically
// defined does not necessarily coincide with the program point of these
// placeholder instructions. So we explicitly define the following order for
// phi values:
//
// body/condition parameter phi:
// Defined before all values defined in its computation except other
// phis.
//
// while phi:
// Defined after all values defined in the condition or body.
//
auto is_body_or_condition_phi = [](const HloValue& v) {
return v.is_phi() &&
v.defining_instruction()->opcode() == HloOpcode::kParameter;
};
if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) &&
call_graph_->InstructionIsNestedIn(b.defining_instruction(),
a.defining_instruction()->parent())) {
return true;
}
if (is_body_or_condition_phi(b) &&
call_graph_->InstructionIsNestedIn(a.defining_instruction(),
b.defining_instruction()->parent())) {
return false;
}
// If 'b' is a while phi and 'a' is in the body or condition, then 'a'
// executes before 'b'.
if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile &&
(call_graph_->InstructionIsNestedIn(
a.defining_instruction(), b.defining_instruction()->while_body()) ||
call_graph_->InstructionIsNestedIn(
a.defining_instruction(),
b.defining_instruction()->while_condition()))) {
return true;
}
return ordering.ExecutesBefore(a.defining_instruction(),
b.defining_instruction());
}
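Concretely, for a while loop in SSA form these rules impose: body/condition parameter phi, then values computed inside the body/condition, then the while phi itself, with everything else deferred to ordering.ExecutesBefore. A compact standalone encoding of just that three-way rule, using a hypothetical classification enum in place of the real HloValue queries:

// Hypothetical classification of a value relative to a single while loop.
enum class PhiKind {
  kParamPhi,   // phi at the body/condition parameter
  kBodyValue,  // non-phi value defined inside the body or condition
  kWhilePhi,   // phi at the while instruction itself
};

// Returns true when 'a' precedes 'b' under the special phi rules above;
// a false return means the caller falls back to ExecutesBefore.
bool PhiOrderedBefore(PhiKind a, PhiKind b) {
  if (a == PhiKind::kParamPhi && b != PhiKind::kParamPhi) return true;
  if (b == PhiKind::kWhilePhi && a != PhiKind::kWhilePhi) return true;
  return false;
}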
bool HloDataflowAnalysis::UseIsBeforeValueDefinition(
const HloUse& use, const HloValue& value,
const HloOrdering& ordering) const {
if (ordering.ExecutesBefore(use.instruction, value.defining_instruction())) {
return true;
}
// If the use is at the instruction where the value is defined, then the use
// is before the def if the instruction allows buffer sharing (in-place
// computation).
if (use.instruction == value.defining_instruction() &&
CanShareOperandBufferWithUser(
use.instruction->mutable_operand(use.operand_number),
use.operand_index, value.defining_instruction(),
value.defining_index())) {
return true;
}
// The use at a while is an input to a phi, and logically occurs before values
// are defined in the body or condition computations.
if (use.instruction->opcode() == HloOpcode::kWhile) {
const HloInstruction* xla_while = use.instruction;
if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
xla_while->while_body()) ||
call_graph_->InstructionIsNestedIn(value.defining_instruction(),
xla_while->while_condition())) {
return true;
}
}
// Similarly if the value is defined at a while, it logically occurs after any
// uses in the body or condition computations.
if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
CHECK(ssa_form_);
const HloInstruction* xla_while = value.defining_instruction();
if (call_graph_->InstructionIsNestedIn(use.instruction,
xla_while->while_body()) ||
call_graph_->InstructionIsNestedIn(use.instruction,
xla_while->while_condition())) {
return true;
}
}
return false;
}
bool HloDataflowAnalysis::LiveRangeStrictlyBefore(
const HloValue& a, const HloValue& b, const HloOrdering& ordering) const {
VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
<< ", b = " << b.ToShortString() << ")";
if (!IsDefinedBefore(a, b, ordering)) {
VLOG(4) << "a not defined before b";
return false;
}
// Live-out values from the module can never have ranges strictly before any
// other value.
if (a.live_out_of_module()) {
VLOG(4) << "a is live out of module";
return false;
}
// Live-out values of computations can never have ranges strictly before any
// other value in the computation (including values nested in
// subcomputations).
if (a.live_out_of_computation() &&
call_graph_->InstructionIsNestedIn(b.defining_instruction(),
a.defining_instruction()->parent())) {
VLOG(4) << "a is live out of computation containing b";
return false;
}
// All uses of 'a' must be before 'b' is defined.
for (const HloUse& use : a.uses()) {
if (!UseIsBeforeValueDefinition(use, b, ordering)) {
VLOG(4) << "use of a (" << use << ") not before b is defined";
return false;
}
}
return true;
}
bool HloDataflowAnalysis::MayInterfere(const HloValue& a, const HloValue& b,
const HloOrdering& ordering) const {
// Buffers without disjoint liveness may interfere.
return !LiveRangeStrictlyBefore(a, b, ordering) &&
!LiveRangeStrictlyBefore(b, a, ordering);
}
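Read together, the two functions above say: the live range of 'a' is strictly before that of 'b' iff 'a' is defined first, 'a' is not live out, and every use of 'a' happens before 'b' is defined; interference is the absence of that relation in both directions. A simplified standalone model with integer timestamps standing in for the HLO ordering queries (it drops the in-place-sharing and while special cases):

#include <algorithm>
#include <vector>

// Hypothetical live range: integer times stand in for the HLO ordering.
struct LiveRange {
  int def_time;
  std::vector<int> use_times;
  bool live_out_of_module;
};

bool LiveRangeStrictlyBefore(const LiveRange& a, const LiveRange& b) {
  if (a.live_out_of_module || a.def_time >= b.def_time) return false;
  // Every use of 'a' must occur before 'b' is defined.
  return std::all_of(a.use_times.begin(), a.use_times.end(),
                     [&](int t) { return t < b.def_time; });
}

bool MayInterfere(const LiveRange& a, const LiveRange& b) {
  return !LiveRangeStrictlyBefore(a, b) && !LiveRangeStrictlyBefore(b, a);
}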
/* static */
StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
HloModule* module, bool ssa_form, bool bitcast_defines_value) {
@ -855,6 +619,33 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
}
dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions);
// Add in positions to all values.
for (const std::unique_ptr<HloComputation>& computation :
module->computations()) {
for (const std::unique_ptr<HloInstruction>& instruction :
computation->instructions()) {
for (const auto& pair :
dataflow_analysis->GetInstructionValueSet(instruction.get())) {
const ShapeIndex& index = pair.first;
const HloValueSet& value_set = pair.second;
for (const HloValue* value : value_set.values()) {
if (value->defining_instruction() != instruction.get()) {
dataflow_analysis->GetValue(value->id())
.AddPosition(instruction.get(), index);
}
}
}
}
}
// Construct vector of values.
dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size());
for (auto& pair : dataflow_analysis->values_) {
dataflow_analysis->values_vector_.push_back(&pair.second);
}
std::sort(dataflow_analysis->values_vector_.begin(),
dataflow_analysis->values_vector_.end(), HloValue::IdLessThan);
TF_DCHECK_OK(dataflow_analysis->Verify());
XLA_VLOG_LINES(1, dataflow_analysis->ToString());
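The "construct vector of values" step above exists because the value map gives stable storage but an unspecified iteration order; Run() therefore materializes a pointer vector sorted by id so that values() is deterministic. The idiom in isolation, with a hypothetical Value type:

#include <algorithm>
#include <unordered_map>
#include <vector>

struct Value { int id; };

// Deterministic view over a node-based map: collect pointers, sort by id.
std::vector<const Value*> SortedView(
    const std::unordered_map<int, Value>& values) {
  std::vector<const Value*> view;
  view.reserve(values.size());
  for (const auto& pair : values) view.push_back(&pair.second);
  std::sort(view.begin(), view.end(),
            [](const Value* a, const Value* b) { return a->id < b->id; });
  return view;
}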
@ -865,14 +656,14 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
Status HloDataflowAnalysis::Verify() const {
// Verify each HloValue appears in the value sets that the value's positions()
// indicate.
for (const HloValue* value : values()) {
for (const HloPosition& position : value->positions()) {
const HloValueSet& value_set = GetValueSet(position);
TF_RET_CHECK(std::find(value_set.values().begin(),
value_set.values().end(),
value) != value_set.values().end())
<< "Value set at position " << position << " does not contain value "
<< value->ToShortString();
}
}
@ -898,75 +689,4 @@ Status HloDataflowAnalysis::Verify() const {
return Status::OK();
}
Status HloDataflowAnalysis::VerifyAgainstReference() const {
TF_RETURN_IF_ERROR(Verify());
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> reference,
Run(module_, ssa_form_, bitcast_defines_value_));
TF_RETURN_IF_ERROR(reference->Verify());
VLOG(2) << "This analysis:";
XLA_VLOG_LINES(2, ToString());
VLOG(2) << "Reference:";
XLA_VLOG_LINES(2, reference->ToString());
// Verify value sets in each position are identical.
for (const auto& computation : module_->computations()) {
for (const auto& instruction : computation->instructions()) {
for (const auto& pair : GetInstructionValueSet(instruction.get())) {
const ShapeIndex& index = pair.first;
const HloValueSet& value_set = pair.second;
const HloValueSet& reference_value_set =
reference->GetValueSet(instruction.get(), index);
auto value_in_set = [](const HloValue& v, const HloValueSet& vset) {
return std::find_if(vset.values().begin(), vset.values().end(),
[&v](const HloValue* w) { return *w == v; }) !=
vset.values().end();
};
for (const HloValue* value : value_set.values()) {
TF_RET_CHECK(value_in_set(*value, reference_value_set))
<< "Value " << value->ToShortString()
<< " does not exist in reference";
}
for (const HloValue* reference_value : reference_value_set.values()) {
TF_RET_CHECK(value_in_set(*reference_value, value_set))
<< "Value " << reference_value->ToShortString()
<< " only exists in reference";
}
}
}
}
// Verify all phis resolve identically and uses are identical.
for (const HloValue& value : values()) {
const HloValue& reference_value = reference->GetValueDefinedAt(
value.defining_instruction(), value.defining_index());
TF_RET_CHECK(value.is_phi() == reference_value.is_phi());
if (value.is_phi()) {
const HloValue* resolved_value = ResolvePhi(value);
const HloValue* reference_resolved_value =
reference->ResolvePhi(reference_value);
if (resolved_value == nullptr) {
TF_RET_CHECK(reference_resolved_value == nullptr);
} else {
TF_RET_CHECK(reference_resolved_value != nullptr);
TF_RET_CHECK(*reference_resolved_value == *resolved_value);
}
}
for (const HloUse& use : value.uses()) {
TF_RET_CHECK(std::find(reference_value.uses().begin(),
reference_value.uses().end(),
use) != reference_value.uses().end());
}
for (const HloUse& reference_use : reference_value.uses()) {
TF_RET_CHECK(std::find(value.uses().begin(), value.uses().end(),
reference_use) != value.uses().end());
}
}
return Status::OK();
}
} // namespace xla

View File

@ -88,10 +88,10 @@ class HloDataflowAnalysis {
// given position.
const HloValueSet& GetValueSet(const HloInstruction* instruction,
const ShapeIndex& index = {}) const;
const HloValueSet& GetValueSet(const HloPosition& position) const;
HloValueSet& GetValueSet(const HloPosition& position);
HloValueSet& GetValueSet(const HloInstruction* instruction,
const ShapeIndex& index = {});
// Return the unique value in the HloValueSet at the given instruction and
// shape index. CHECKs if the value set does not contain exactly one value.
@ -108,49 +108,11 @@ class HloDataflowAnalysis {
const HloValue& GetValue(HloValue::Id value_id) const;
HloValue& GetValue(HloValue::Id value_id);
// Returns whether the given values interfere assuming the given HLO
// ordering. Two values interfere if they may both be simultaneously live.
bool MayInterfere(const HloValue& a, const HloValue& b,
const HloOrdering& ordering) const;
// Overload which takes HloValue:Ids.
bool MayInterfere(HloValue::Id a, HloValue::Id b,
const HloOrdering& ordering) const {
return MayInterfere(GetValue(a), GetValue(b), ordering);
}
// Return the total number of HloValues.
int64 value_count() const { return values_.size(); }
// Return a vector of all HloValues.
const std::vector<HloValue>& values() const { return values_; }
// Updates the dataflow after changing an operand of
// 'instruction'. Dataflow update is not possible if instructions have been
// added or removed from the graph.
void UpdateAfterChangingOperand(HloInstruction* instruction,
HloInstruction* old_operand,
HloInstruction* new_operand);
// Updates the dataflow after changing the root of a computation from
// 'old_root' to 'new_root'.
void UpdateAfterChangingRoot(HloInstruction* old_root,
HloInstruction* new_root);
// Returns the non-phi HloValue that is the unique (transitive) input to the
// given phi. If no such HloValue exists (there are multiple inputs to the
// phi) then nullptr is returned. This is computed by walking the inputs
// of the given phi value until non-phi HloValue(s) are encountered.
const HloValue* ResolvePhi(const HloValue& phi) const;
const HloValue* ResolvePhi(const HloInstruction* instruction,
const ShapeIndex& index = {}) const {
return ResolvePhi(GetValueDefinedAt(instruction, index));
}
// Compare the dataflow analysis against a clean recomputation of the
// analysis. Returns an error status if there is a mismatch. Useful for
// verifying the correctness after updates to the analysis.
Status VerifyAgainstReference() const;
// Return a vector of all HloValues stably sorted by HloValue::Id.
const std::vector<const HloValue*>& values() const { return values_vector_; }
// Return the call graph used for computing the dataflow.
const CallGraph& call_graph() const { return *call_graph_; }
@ -161,6 +123,13 @@ class HloDataflowAnalysis {
HloDataflowAnalysis(HloModule* module, bool ssa_form,
bool bitcast_defines_value = false);
// Returns a new HloValue defined at the given instruction and shape index.
HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
bool is_phi = false);
// Delete the HloValue with the given ID.
void DeleteHloValue(HloValue::Id value_id);
// Constructs and initializes the InstructionValueSets of all instructions to
// contain exactly the HloValues defined by each instruction. These values can
// then be propagated throughout the HLO graph by calling
@ -187,10 +156,11 @@ class HloDataflowAnalysis {
void UpdateInstructionsAndPropagate(
tensorflow::gtl::ArraySlice<HloInstruction*> instructions);
// Sets the inputs of the given phi to the given value(s).
void UpdatePhiInputs(
const HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);
// Return the result of the SSA Phi function applied to the given inputs at
// the given instruction. If skip_top_level is true, then the top level of the
// value set of 'instruction' is not modified.
bool Phi(HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);
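The Phi hook declared above implements the usual SSA join: when every input carries one and the same value it passes through unchanged, and only a join of genuinely distinct values materializes a phi. A hedged sketch of that decision rule (a guess at the shape, not the verbatim implementation), with a plain set of value ids standing in for HloValueSet:

#include <set>

using ValueIdSet = std::set<int>;  // hypothetical HloValueSet stand-in

constexpr int kFreshPhiId = -1;    // sentinel for a newly minted phi value

// SSA join of two inputs: a unique common value passes through; any
// disagreement produces a phi value instead.
ValueIdSet PhiJoin(const ValueIdSet& a, const ValueIdSet& b) {
  if (a == b && a.size() == 1) return a;  // degenerate phi: pass-through
  return ValueIdSet{kFreshPhiId};
}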
// Updates the positions of the HloValues in the output of the given
// instruction. This should be called after the instruction value set of
@ -203,20 +173,6 @@ class HloDataflowAnalysis {
HloInstruction* instruction, const InstructionValueSet& new_value_set,
const InstructionValueSet* prev_value_set = nullptr);
// Returns true if the live range of the given value 'a' is strictly before
// the live range of value 'b' using the given HLO ordering.
bool LiveRangeStrictlyBefore(const HloValue& a, const HloValue& b,
const HloOrdering& ordering) const;
// Returns whether the value 'a' is defined before the value 'b' under the
// given ordering.
bool IsDefinedBefore(const HloValue& a, const HloValue& b,
const HloOrdering& ordering) const;
// Returns whether the given use is before the given value definition.
bool UseIsBeforeValueDefinition(const HloUse& use, const HloValue& value,
const HloOrdering& ordering) const;
// Verify various invariants of the dataflow analysis.
Status Verify() const;
@ -226,19 +182,19 @@ class HloDataflowAnalysis {
std::unique_ptr<CallGraph> call_graph_;
// Array of all values in the module. This is allocated once at analysis
// construction time so HloValue references are stable. Updates to the
// analysis via UpdateAfterChangingOperand and UpdateAfterChangingRoot do not
// result in the creation or destruction of any HloValues.
std::vector<HloValue> values_;
// Map holding the inputs to each phi value in the module. Used by ResolvePhi.
tensorflow::gtl::FlatMap<const HloValue*,
tensorflow::gtl::InlinedVector<const HloValue*, 2>>
phi_inputs_;
// The map of all HloValues in the module. We pass around pointers to the
// mapped HloValues, so the underlying container must keep them valid despite
// mutations touching other map entries.
std::unordered_map<HloValue::Id, HloValue> values_;
// A map from instruction to InstructionValueSet.
std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_;
// A vector containing all HloValues sorted by HloValue::Id.
std::vector<const HloValue*> values_vector_;
// The Id to use for the next HloValue.
HloValue::Id next_value_id_ = 0;
};
} // namespace xla
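Given the members above (an id-keyed unordered_map plus next_value_id_), NewHloValue presumably just mints the next id, emplaces into values_, and returns the mapped address, which the node-based map keeps valid. A hedged sketch of that shape (an assumption about the implementation, using demo types rather than the real HloValue):

#include <unordered_map>

struct DemoValue {
  int id;
  bool is_phi;
};

class DemoAnalysis {
 public:
  // Mint a fresh value; the returned pointer stays valid for the life of
  // the map entry because unordered_map never relocates mapped values.
  DemoValue* NewValue(bool is_phi) {
    const int id = next_value_id_++;
    auto emplaced = values_.emplace(id, DemoValue{id, is_phi});
    return &emplaced.first->second;
  }
  void DeleteValue(int id) { values_.erase(id); }

 private:
  std::unordered_map<int, DemoValue> values_;
  int next_value_id_ = 0;
};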

View File

@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@ -44,8 +43,8 @@ class HloDataflowAnalysisTest : public HloTestBase,
// Run dataflow analysis on the member module. For convenience returns a
// reference to the generated analysis stored in analysis_.
const HloDataflowAnalysis& RunAnalysis(bool ssa_form,
bool bitcast_defines_value = false) {
analysis_ =
HloDataflowAnalysis::Run(module_.get(), ssa_form, bitcast_defines_value)
.ConsumeValueOrDie();
@ -71,8 +70,8 @@ class HloDataflowAnalysisTest : public HloTestBase,
const HloInstruction* b) {
EXPECT_FALSE(ShapeUtil::IsTuple(a->shape()));
EXPECT_FALSE(ShapeUtil::IsTuple(b->shape()));
return ordering.MayInterfere(analysis_->GetValueDefinedAt(a),
analysis_->GetValueDefinedAt(b));
}
std::unique_ptr<HloModule> module_;
@ -499,37 +498,26 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) {
EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module());
if (ssa_form) {
// While instruction should define phi values. The value at index {0} is a
// degenerate phi with a single input 'constant1'.
EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi());
EXPECT_EQ(analysis.ResolvePhi(xla_while, /*index=*/{0}),
&analysis.GetValueDefinedAt(constant1));
EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{0}).is_phi());
EXPECT_EQ(analysis.ResolvePhi(body_param, /*index=*/{0}),
&analysis.GetValueDefinedAt(constant1));
EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{0}).is_phi());
EXPECT_EQ(analysis.ResolvePhi(cond_param, /*index=*/{0}),
&analysis.GetValueDefinedAt(constant1));
// Element 0 of the tuple is passed through the body, so no phi value is
// defined.
EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
// Element 1 of the tuple should be a phi value.
EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());
EXPECT_EQ(analysis.ResolvePhi(xla_while, /*index=*/{1}), nullptr);
EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi());
EXPECT_EQ(analysis.ResolvePhi(body_param, /*index=*/{1}), nullptr);
EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi());
EXPECT_EQ(analysis.ResolvePhi(cond_param, /*index=*/{1}), nullptr);
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
UnorderedElementsAre(HloUse{xla_while, 0, {0}}));
EXPECT_THAT(
analysis.GetValueDefinedAt(constant1).uses(),
UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{xla_while, 0, {0}}));
EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0})
.live_out_of_module());
// Constant1 passes through the body and out of the module.
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
.live_out_of_module());
@ -613,20 +601,15 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) {
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
if (ssa_form) {
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while2).live_out_of_module());
EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
} else {
// Element 0 is passed through all the while instructions and out of the
// module.
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while0, /*index=*/{0}),
analysis.GetValueDefinedAt(constant1));
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while1, /*index=*/{0}),
analysis.GetValueDefinedAt(constant1));
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while2, /*index=*/{0}),
analysis.GetValueDefinedAt(constant1));
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
}
// Element 0 is passed through all the while instructions and out of the
// module.
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while0, /*index=*/{0}),
analysis.GetValueDefinedAt(constant1));
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while1, /*index=*/{0}),
analysis.GetValueDefinedAt(constant1));
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while2, /*index=*/{0}),
analysis.GetValueDefinedAt(constant1));
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
}
TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
@ -705,13 +688,18 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
if (ssa_form) {
EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_param, /*index=*/{1}));
EXPECT_TRUE(
analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi());
// Element 0 of the nested while is %negate.
EXPECT_FALSE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0}));
EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
// Element 1 is a phi value (join of %add and %constant2).
EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{1}));
EXPECT_TRUE(
analysis.GetValueDefinedAt(nested_while, /*index=*/{1}).is_phi());
@ -724,8 +712,6 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
EXPECT_TRUE(
analysis.GetValueDefinedAt(entry_while, /*index=*/{1}).is_phi());
} else {
EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{1}),
UnorderedElementsAre(analysis.GetValueDefinedAt(add),
analysis.GetValueDefinedAt(constant2)));
@ -1496,256 +1482,6 @@ TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) {
EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log));
}
TEST_P(HloDataflowAnalysisTest, UpdateAnalysisForWhile) {
// Test updating dataflow after modifying a module with an array-shaped while:
//
// body(F32[] %param):
// %negate = Negate(%param)
//
// condition(F32[] %param):
// return Constant(false)
//
// entry:
// %constant = Constant(1.0)
// %exp = Exp(%constant)
// return While(%exp, body, condition)
//
auto body_builder = HloComputation::Builder("body");
auto body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
scalar_shape_, HloOpcode::kNegate, body_param));
HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
// Condition computation trivially returns a constant "false".
auto cond_builder = HloComputation::Builder("condition");
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
cond_builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
auto xla_while = builder.AddInstruction(
HloInstruction::CreateWhile(scalar_shape_, condition, body, exp));
module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
// Sanity check the initial dataflow analysis before transforming the HLO
// graph.
if (ssa_form) {
EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param));
EXPECT_TRUE(analysis.GetValueDefinedAt(body_param).is_phi());
EXPECT_EQ(analysis.ResolvePhi(body_param), nullptr);
EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param));
EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param).is_phi());
EXPECT_EQ(analysis.ResolvePhi(cond_param), nullptr);
EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module());
EXPECT_FALSE(analysis.GetValueDefinedAt(negate).live_out_of_module());
} else {
EXPECT_THAT(HloValuesAt(body_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(exp),
analysis.GetValueDefinedAt(negate)));
EXPECT_THAT(HloValuesAt(cond_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(exp),
analysis.GetValueDefinedAt(negate)));
EXPECT_THAT(HloValuesAt(xla_while),
UnorderedElementsAre(analysis.GetValueDefinedAt(exp),
analysis.GetValueDefinedAt(negate)));
EXPECT_TRUE(analysis.GetValueDefinedAt(negate).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(exp).live_out_of_module());
}
// Set the body root to the body_param. Previously it was Negate(body_param).
body->set_root_instruction(body_param);
// Prior to updating, verify that the dataflow analysis is no longer valid.
Status verify_status = analysis.VerifyAgainstReference();
EXPECT_FALSE(verify_status.ok());
analysis.UpdateAfterChangingRoot(/*old_root=*/negate,
/*new_root=*/body_param);
// Analysis should be valid after the update.
TF_EXPECT_OK(analysis.VerifyAgainstReference());
if (ssa_form) {
// The phis should now be resolvable as 'exp' is passed through the body
// transparently.
EXPECT_EQ(analysis.ResolvePhi(body_param),
&analysis.GetValueDefinedAt(exp));
EXPECT_EQ(analysis.ResolvePhi(cond_param),
&analysis.GetValueDefinedAt(exp));
EXPECT_EQ(analysis.ResolvePhi(xla_while), &analysis.GetValueDefinedAt(exp));
EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module());
} else {
EXPECT_THAT(HloValuesAt(body_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(exp)));
EXPECT_THAT(HloValuesAt(cond_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(exp)));
EXPECT_THAT(HloValuesAt(xla_while),
UnorderedElementsAre(analysis.GetValueDefinedAt(exp)));
EXPECT_TRUE(analysis.GetValueDefinedAt(exp).live_out_of_module());
}
EXPECT_FALSE(analysis.GetValueDefinedAt(negate).live_out_of_module());
// Now replace the operand of the while with %constant (was %exp).
TF_ASSERT_OK(exp->ReplaceUseWith(xla_while, constant));
analysis.UpdateAfterChangingOperand(xla_while, /*old_operand=*/exp,
/*new_operand=*/constant);
// Verify that the dataflow is correct.
TF_ASSERT_OK(analysis.VerifyAgainstReference());
if (ssa_form) {
// The phis now resolve to 'constant'.
EXPECT_EQ(analysis.ResolvePhi(body_param),
&analysis.GetValueDefinedAt(constant));
EXPECT_EQ(analysis.ResolvePhi(cond_param),
&analysis.GetValueDefinedAt(constant));
EXPECT_EQ(analysis.ResolvePhi(xla_while),
&analysis.GetValueDefinedAt(constant));
} else {
EXPECT_THAT(HloValuesAt(body_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(constant)));
EXPECT_THAT(HloValuesAt(cond_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(constant)));
EXPECT_THAT(HloValuesAt(xla_while),
UnorderedElementsAre(analysis.GetValueDefinedAt(constant)));
EXPECT_TRUE(analysis.GetValueDefinedAt(constant).live_out_of_module());
}
// And finally make the negate the root of the body again.
body->set_root_instruction(negate);
analysis.UpdateAfterChangingRoot(/*old_root=*/body_param,
/*new_root=*/negate);
// Verify that the dataflow is correct.
TF_ASSERT_OK(analysis.VerifyAgainstReference());
if (ssa_form) {
// Phis should no longer be resolvable.
EXPECT_EQ(analysis.ResolvePhi(body_param), nullptr);
EXPECT_EQ(analysis.ResolvePhi(cond_param), nullptr);
EXPECT_EQ(analysis.ResolvePhi(xla_while), nullptr);
} else {
EXPECT_THAT(HloValuesAt(body_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(constant),
analysis.GetValueDefinedAt(negate)));
EXPECT_THAT(HloValuesAt(cond_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(constant),
analysis.GetValueDefinedAt(negate)));
EXPECT_THAT(HloValuesAt(xla_while),
UnorderedElementsAre(analysis.GetValueDefinedAt(constant),
analysis.GetValueDefinedAt(negate)));
EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(negate).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(constant).live_out_of_module());
}
// After the updates, verify that the dataflow is correct.
TF_ASSERT_OK(analysis.VerifyAgainstReference());
}
TEST_P(HloDataflowAnalysisTest, UpdateOfATupleSelect) {
// Test changing the operands of kSelects of a tuple value and updating the
// dataflow.
auto builder = HloComputation::Builder(TestName());
auto pred = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
auto a = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
auto b = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
auto c = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
auto d = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(4.0)));
auto tuple_a = builder.AddInstruction(HloInstruction::CreateTuple({a}));
auto tuple_b = builder.AddInstruction(HloInstruction::CreateTuple({b}));
auto tuple_c = builder.AddInstruction(HloInstruction::CreateTuple({c}));
auto tuple_d = builder.AddInstruction(HloInstruction::CreateTuple({d}));
const Shape tuple_shape = tuple_a->shape();
auto select_aa = builder.AddInstruction(HloInstruction::CreateTernary(
tuple_shape, HloOpcode::kSelect, pred, tuple_a, tuple_a));
auto select_ab = builder.AddInstruction(HloInstruction::CreateTernary(
tuple_shape, HloOpcode::kSelect, pred, tuple_a, tuple_b));
auto select_cd = builder.AddInstruction(HloInstruction::CreateTernary(
tuple_shape, HloOpcode::kSelect, pred, tuple_c, tuple_d));
auto select_abcd = builder.AddInstruction(HloInstruction::CreateTernary(
tuple_shape, HloOpcode::kSelect, pred, select_ab, select_cd));
module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
// Sanity check dataflow before changing the graph and updating.
EXPECT_THAT(HloValuesAt(select_aa, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(a)));
EXPECT_THAT(HloValuesAt(select_ab, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(a),
analysis.GetValueDefinedAt(b)));
EXPECT_THAT(HloValuesAt(select_cd, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(c),
analysis.GetValueDefinedAt(d)));
EXPECT_THAT(HloValuesAt(select_abcd, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(a),
analysis.GetValueDefinedAt(b),
analysis.GetValueDefinedAt(c),
analysis.GetValueDefinedAt(d)));
EXPECT_TRUE(analysis.GetValueDefinedAt(a).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(b).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(c).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(d).live_out_of_module());
// Set the rhs of 'select_aa' to be 'd'.
TF_ASSERT_OK(select_aa->ReplaceOperandWith(2, tuple_d));
analysis.UpdateAfterChangingOperand(select_aa, /*old_operand=*/tuple_a,
/*new_operand=*/tuple_d);
// Verify that the dataflow is correct.
TF_ASSERT_OK(analysis.VerifyAgainstReference());
EXPECT_THAT(HloValuesAt(select_aa, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(a),
analysis.GetValueDefinedAt(d)));
// Set the lhs of 'select_cd' to be 'a'.
TF_ASSERT_OK(select_cd->ReplaceOperandWith(1, tuple_a));
analysis.UpdateAfterChangingOperand(select_cd, /*old_operand=*/tuple_c,
/*new_operand=*/tuple_a);
// Verify that the dataflow is correct.
TF_ASSERT_OK(analysis.VerifyAgainstReference());
EXPECT_THAT(HloValuesAt(select_cd, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(a),
analysis.GetValueDefinedAt(d)));
EXPECT_THAT(HloValuesAt(select_abcd, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(a),
analysis.GetValueDefinedAt(b),
analysis.GetValueDefinedAt(d)));
EXPECT_TRUE(analysis.GetValueDefinedAt(a).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(b).live_out_of_module());
EXPECT_FALSE(analysis.GetValueDefinedAt(c).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(d).live_out_of_module());
// After the updates, verify that the dataflow is correct.
TF_ASSERT_OK(analysis.VerifyAgainstReference());
}
INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation,
HloDataflowAnalysisTest,
::testing::Values(false, true));

View File

@ -561,13 +561,21 @@ tooltip = " ";
}
string comp_body = DumpComputation(subcomp);
if (parent_instr->opcode() == HloOpcode::kFusion) {
// Dump any nested fusion nodes.
for (const auto& subcomp_instr : subcomp->instructions()) {
if (subcomp_instr->opcode() == HloOpcode::kFusion) {
StrAppend(
&comp_body,
DumpSubcomputation(subcomp_instr->fused_instructions_computation(),
subcomp_instr.get()));
}
}
} else {
// Add an edge from the subcomputation to its parent node. If subcomp
// belongs to a fusion node, it's drawn in place of the fusion instruction,
// so there's no need to link those.
edge_ids_.insert(
{{subcomp->root_instruction(), parent_instr}, next_edge_id_++});
const char* edge_fmt =
@ -578,6 +586,9 @@ tooltip = " ";
subcomp->name(), parent_instr->name()));
}
string computation =
Printf(computation_fmt, id, style, subcomp_label, comp_body, id);
return computation;
}
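The restructuring above also makes the dumper recursive: a fusion computation is drawn in place of its fusion instruction, and since it may itself contain fusion instructions, their computations must be appended to the same body. A toy standalone version of that recursion with hypothetical types:

#include <string>
#include <vector>

struct Comp;

struct Instr {
  bool is_fusion = false;
  const Comp* fused = nullptr;  // computation drawn in place of this instr
};

struct Comp {
  std::string name;
  std::vector<Instr> instructions;
};

// Emit this computation, then recurse into any nested fusion computations
// so they land inside the same cluster body.
void Dump(const Comp& comp, std::string* out) {
  out->append("subgraph cluster_" + comp.name + " { ... }\n");
  for (const Instr& instr : comp.instructions) {
    if (instr.is_fusion && instr.fused != nullptr) Dump(*instr.fused, out);
  }
}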

View File

@ -793,13 +793,6 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
}
}
for (HloComputation* computation :
instruction_to_fuse->called_computations()) {
if (std::find(called_computations_.begin(), called_computations_.end(),
computation) == called_computations_.end()) {
called_computations_.push_back(computation);
}
}
VLOG(2) << "New clone:\n" << clone->ToString();
return clone;
}

View File

@ -797,8 +797,7 @@ class HloInstruction {
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> operands);
// Returns the computations this instruction directly calls (if any).
const std::vector<HloComputation*>& called_computations() const {
return called_computations_;
}

View File

@ -758,16 +758,13 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
auto* fusion = computation->CreateFusionInstruction(
{map_3_y}, HloInstruction::FusionKind::kLoop);
auto* fused_computation = fusion->fused_instructions_computation();
EXPECT_THAT(fusion->called_computations(),
ElementsAre(fused_computation, computation_y));
EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));
fusion->FuseInstruction(map_2_x);
EXPECT_THAT(fusion->called_computations(),
ElementsAre(fused_computation, computation_y, computation_x));
EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));
fusion->FuseInstruction(map_1_x);
EXPECT_THAT(fusion->called_computations(),
ElementsAre(fused_computation, computation_y, computation_x));
EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));
}
TEST_F(HloInstructionTest, ComplexFusionOp) {

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <string>
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
@ -218,6 +219,94 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) {
EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param));
}
TEST_F(HloOrderingTest, ValuesInWhileComputations) {
// Tests the ordering of values (defined by dataflow analysis) in the body and
// condition of a while instruction. HLO code:
//
// body(F32[] %param):
// %negate = Negate(%param)
//
// condition(F32[] %param):
// %convert = Convert<PRED>(%param)
//
// entry:
// %constant = Constant(1.0)
// %while = While(%constant, body, condition)
// %add = Add(%constant, %while)
//
auto module = CreateNewModule();
const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
auto body_builder = HloComputation::Builder("body");
auto body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape, "body_param"));
auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
scalar_shape, HloOpcode::kNegate, body_param));
HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
auto cond_builder = HloComputation::Builder("condition");
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape, "cond_param"));
auto convert = cond_builder.AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::MakeShape(xla::PRED, {}), cond_param));
HloComputation* condition =
module->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
auto xla_while = builder.AddInstruction(
HloInstruction::CreateWhile(scalar_shape, condition, body, constant));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape, HloOpcode::kAdd, constant, xla_while));
module->AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(
auto dataflow, HloDataflowAnalysis::Run(module.get(), /*ssa_form=*/true));
DependencyHloOrdering ordering(module.get());
// Init value is defined before the while, but its live range is not strictly
// before the while because the init value is also used by the add.
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_FALSE(
ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(constant),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(constant),
dataflow->GetValueDefinedAt(xla_while)));
// Any value defined in the body or condition is defined before the while, and
// has a live range strictly before the while.
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(negate),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_TRUE(
ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(negate),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(negate),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(convert),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_TRUE(
ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(convert),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(convert),
dataflow->GetValueDefinedAt(xla_while)));
// The live range of the while should be before the add.
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(xla_while),
dataflow->GetValueDefinedAt(add)));
ASSERT_EQ(dataflow->GetValueDefinedAt(xla_while).uses().size(), 1);
const HloUse& while_use = dataflow->GetValueDefinedAt(xla_while).uses()[0];
EXPECT_EQ(while_use.instruction, add);
EXPECT_TRUE(ordering.UseIsBeforeValueDefinition(
while_use, dataflow->GetValueDefinedAt(add)));
EXPECT_TRUE(
ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(xla_while),
dataflow->GetValueDefinedAt(add)));
}
} // namespace
} // namespace xla

View File

@ -1248,7 +1248,8 @@ StatusOr<bool> HloRematerialization::Run(
sequence->at(node.computation())));
}
return Status::OK();
}));
},
/*visit_unreachable_nodes=*/false));
// The peak memory usage of the module equals the peak memory use of the entry
// computation plus the output size of the computation. This is because the

View File

@ -159,12 +159,6 @@ void HloValue::AddPosition(HloInstruction* instruction,
for (const HloPosition& position : positions_) {
DCHECK_NE(position, new_position);
}
// The shape of the new position must match existing positions.
if (!positions_.empty()) {
CHECK(
ShapeUtil::Compatible(positions_.front().shape(), new_position.shape()))
<< "front: " << positions_.front() << " new: " << new_position;
}
positions_.push_back(std::move(new_position));

View File

@ -225,6 +225,9 @@ class HloValueSet {
// already exist in the set.
bool AddValue(const HloValue* value);
// Clear all values from the set.
void Clear() { values_.clear(); }
// Return the unique HLO value in the set. CHECKs if the set does not contain
// exactly one value.
const HloValue& GetUniqueValue() const {

View File

@ -32,13 +32,11 @@ class ShapeVerifier : public DfsHloVisitor {
const std::function<int64(const Shape&)>& shape_size_fn)
: shape_size_fn_(shape_size_fn) {}
Status HandleElementwiseUnary(HloInstruction* hlo) override {
return CheckUnaryShape(hlo);
}
Status HandleElementwiseBinary(HloInstruction* hlo) override {
return CheckBinaryShape(hlo);
}
@ -282,6 +280,14 @@ class ShapeVerifier : public DfsHloVisitor {
const std::function<int64(const Shape&)> shape_size_fn_;
};
string ComputationsToString(
tensorflow::gtl::ArraySlice<HloComputation*> computations) {
return tensorflow::str_util::Join(
computations, ",", [](string* s, const HloComputation* computation) {
s->append(computation->name());
});
}
} // namespace
StatusOr<bool> HloVerifier::Run(HloModule* module) {
@ -292,6 +298,17 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
for (const auto& instruction : computation->instructions()) {
TF_RET_CHECK(instruction->parent() == computation.get());
if (instruction->opcode() == HloOpcode::kFusion) {
TF_RET_CHECK(
ContainersEqual(instruction->called_computations(),
{instruction->fused_instructions_computation()}))
<< "Fusion HLO calls computations other than the "
"fused_instructions_computation: "
<< instruction->ToString()
<< " instruction->fused_instructions_computation(): "
<< instruction->fused_instructions_computation()->ToString()
<< " instruction->called_computations(): "
<< ComputationsToString(instruction->called_computations());
for (const auto& fused : instruction->fused_instructions()) {
TF_RET_CHECK(fused->parent() ==
instruction->fused_instructions_computation())

View File

@ -122,7 +122,8 @@ StatusOr<bool> ReducePrecisionInsertion::insert_on_inputs(
continue;
}
if (instruction->opcode() == HloOpcode::kFusion &&
instruction->fusion_kind() == HloInstruction::FusionKind::kLoop) {
// Insert the reduce-precision operation inside the fusion computation,
// after the corresponding parameter instruction.
TF_ASSIGN_OR_RETURN(
@ -171,7 +172,8 @@ StatusOr<bool> ReducePrecisionInsertion::insert_on_outputs(
continue;
}
if (instruction->opcode() == HloOpcode::kFusion &&
instruction->fusion_kind() == HloInstruction::FusionKind::kLoop) {
// Insert the reduce-precision operation as the last operation inside
// the fusion computation.
HloInstruction* fusion_root = instruction->fused_expression_root();

View File

@ -28,6 +28,7 @@ py_library(
"//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/fused_conv:fused_conv_py",
"//tensorflow/contrib/gan",
"//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
"//tensorflow/contrib/hooks",
@ -72,6 +73,7 @@ py_library(
"//tensorflow/contrib/staging",
"//tensorflow/contrib/stat_summarizer:stat_summarizer_py",
"//tensorflow/contrib/stateless",
"//tensorflow/contrib/summary:summary_ops",
"//tensorflow/contrib/tensor_forest:init_py",
"//tensorflow/contrib/tensorboard",
"//tensorflow/contrib/testing:testing_py",

View File

@ -31,6 +31,7 @@ from tensorflow.contrib import deprecated
from tensorflow.contrib import distributions
from tensorflow.contrib import factorization
from tensorflow.contrib import framework
from tensorflow.contrib import gan
from tensorflow.contrib import graph_editor
from tensorflow.contrib import grid_rnn
from tensorflow.contrib import image

View File

@ -18,6 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch
from tensorflow.contrib.decision_trees.proto import generic_tree_model_extensions_pb2
@ -26,18 +29,21 @@ from tensorflow.contrib.learn.python.learn import export_strategy
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader as saved_model_loader
from tensorflow.python.saved_model import tag_constants
def make_custom_export_strategy(name,
convert_fn,
feature_columns,
export_input_fn):
"""Makes custom exporter of GTFlow tree format.
Args:
name: A string, for the name of the export strategy.
convert_fn: A function that converts the tree proto to desired format and
saves it to the desired location. Can be None to skip conversion.
feature_columns: A list of feature columns.
export_input_fn: A function that takes no arguments and returns an
`InputFnOps`.
@ -68,9 +74,22 @@ def make_custom_export_strategy(name, convert_fn, feature_columns,
dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
dtec.ParseFromString(dfec_str)
# Export the result in the same folder as the saved model.
if convert_fn:
convert_fn(dtec, sorted_feature_names,
len(dense_floats),
len(sparse_float_indices),
len(sparse_int_indices), result_dir, eval_result)
feature_importances = _get_feature_importances(
dtec, sorted_feature_names,
len(dense_floats),
len(sparse_float_indices), len(sparse_int_indices))
sorted_by_importance = sorted(
feature_importances.items(), key=lambda x: -x[1])
assets_dir = os.path.join(result_dir, "assets.extra")
gfile.MakeDirs(assets_dir)
with gfile.GFile(os.path.join(assets_dir, "feature_importances"),
"w") as f:
f.write("\n".join("%s, %f" % (k, v) for k, v in sorted_by_importance))
return result_dir
return export_strategy.ExportStrategy(name, export_fn)
@ -157,3 +176,41 @@ def convert_to_universal_format(dtec, sorted_feature_names,
node.left_child_id.value = split.left_id
node.right_child_id.value = split.right_id
return model_and_features
def _get_feature_importances(dtec, feature_names, num_dense_floats,
num_sparse_float, num_sparse_int):
"""Export the feature importance per feature column."""
del num_sparse_int # Unused.
sums = collections.defaultdict(lambda: 0)
for tree_idx in range(len(dtec.trees)):
tree = dtec.trees[tree_idx]
for tree_node in tree.nodes:
node_type = tree_node.WhichOneof("node")
if node_type == "dense_float_binary_split":
split = tree_node.dense_float_binary_split
split_column = feature_names[split.feature_column]
elif node_type == "sparse_float_binary_split_default_left":
split = tree_node.sparse_float_binary_split_default_left.split
split_column = feature_names[split.feature_column + num_dense_floats]
elif node_type == "sparse_float_binary_split_default_right":
split = tree_node.sparse_float_binary_split_default_right.split
split_column = feature_names[split.feature_column + num_dense_floats]
elif node_type == "categorical_id_binary_split":
split = tree_node.categorical_id_binary_split
split_column = feature_names[split.feature_column + num_dense_floats +
num_sparse_float]
elif node_type == "categorical_id_set_membership_binary_split":
split = tree_node.categorical_id_set_membership_binary_split
split_column = feature_names[split.feature_column + num_dense_floats +
num_sparse_float]
elif node_type == "leaf":
assert tree_node.node_metadata.gain == 0
continue
else:
raise ValueError("Unexpected split type %s", node_type)
# Apply the shrinkage factor (the tree weight). This is important since
# it is not necessarily uniform across different trees.
sums[split_column] += (
tree_node.node_metadata.gain * dtec.tree_weights[tree_idx])
return dict(sums)
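The accumulation above boils down to a gain-weighted sum per feature column: every non-leaf split contributes its gain times the weight of the tree it sits in, and the tree weight matters because shrinkage makes weights non-uniform across trees. The same bookkeeping as a self-contained C++ sketch (hypothetical Split/Tree stand-ins for the proto types):

#include <map>
#include <string>
#include <vector>

struct Split { std::string column; double gain; };
struct Tree  { std::vector<Split> splits; double weight; };

// Feature importance: per-column sum of split gains, scaled per tree.
std::map<std::string, double> FeatureImportances(
    const std::vector<Tree>& trees) {
  std::map<std::string, double> sums;
  for (const Tree& tree : trees) {
    for (const Split& split : tree.splits) {
      sums[split.column] += split.gain * tree.weight;
    }
  }
  return sums;
}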

View File

@ -27,7 +27,7 @@ from tensorflow.python.platform import googletest
class ConvertModelTest(test_util.TensorFlowTestCase):
def _make_trees(self):
dtec_str = """
trees {
nodes {
@ -108,8 +108,12 @@ class ConvertModelTest(test_util.TensorFlowTestCase):
"""
dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge(dtec_str, dtec)
# The feature columns in the order they were added.
feature_columns = ["feature_b", "feature_a", "feature_d"]
return dtec, feature_columns
def testConvertModel(self):
dtec, feature_columns = self._make_trees()
out = custom_export_strategy.convert_to_universal_format(
dtec, feature_columns, 1, 1,
1)
@ -273,6 +277,16 @@ class ConvertModelTest(test_util.TensorFlowTestCase):
}"""
self.assertProtoEquals(expected_tree, out)
def testFeatureImportance(self):
dtec, feature_columns = self._make_trees()
feature_importances = custom_export_strategy._get_feature_importances(
dtec, feature_columns, 1, 1, 1)
self.assertItemsEqual(["feature_b", "feature_a", "feature_d"],
feature_importances.keys())
self.assertAlmostEqual(50.0, feature_importances["feature_b"], places=4)
self.assertAlmostEqual(50.0, feature_importances["feature_a"], places=4)
self.assertAlmostEqual(50.0, feature_importances["feature_d"], places=4)
if __name__ == "__main__":
googletest.main()

View File

@ -61,11 +61,19 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
logits_modifier_function: A modifier function for the logits.
center_bias: Whether a separate tree should be created for first fitting
the bias.
Raises:
ValueError: If learner_config is not valid.
"""
head = head_lib.multi_class_head(
n_classes=n_classes,
weight_column_name=weight_column_name,
enable_centered_bias=False)
if learner_config.num_classes == 0:
learner_config.num_classes = n_classes
elif learner_config.num_classes != n_classes:
raise ValueError("n_classes (%d) doesn't match learner_config (%d)." %
(n_classes, learner_config.num_classes))
super(GradientBoostedDecisionTreeClassifier, self).__init__(
model_fn=model.model_builder,
params={
@ -129,6 +137,10 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
label_dimension=label_dimension,
weight_column_name=weight_column_name,
enable_centered_bias=False)
if label_dimension == 1:
learner_config.num_classes = 2
else:
learner_config.num_classes = label_dimension
super(GradientBoostedDecisionTreeRegressor, self).__init__(
model_fn=model.model_builder,
params={

View File

@ -92,6 +92,7 @@ def model_builder(features, labels, mode, params, config):
examples_per_layer=examples_per_layer,
learner_config=learner_config,
feature_columns=feature_columns,
logits_dimension=head.logits_dimension,
features=features)
with ops.name_scope("gbdt", "gbdt_optimizer"):
predictions_dict = gbdt_model.predict(mode)

View File

@ -74,7 +74,7 @@ class TreeEnsembleStampTokenOp : public OpKernel {
decision_tree_ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_ensemble_resource));
tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_ensemble_resource);
Tensor* output_stamp_token_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
@ -95,7 +95,7 @@ class TreeEnsembleSerializeOp : public OpKernel {
decision_tree_ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_ensemble_resource));
tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_ensemble_resource);
Tensor* output_stamp_token_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),

View File

@ -143,7 +143,7 @@ class GradientTreesPredictionOp : public OpKernel {
// Release the reference to the resource once we're done using it.
core::ScopedUnref unref_me(decision_tree_ensemble_resource);
if (use_locking_) {
tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex());
DoCompute(context, decision_tree_ensemble_resource);
} else {
DoCompute(context, decision_tree_ensemble_resource);
@ -334,7 +334,7 @@ class GradientTreesPartitionExamplesOp : public OpKernel {
// Release the reference to the resource once we're done using it.
core::ScopedUnref unref_me(decision_tree_ensemble_resource);
if (use_locking_) {
tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex());
DoCompute(context, decision_tree_ensemble_resource);
} else {
DoCompute(context, decision_tree_ensemble_resource);

View File

@ -656,7 +656,8 @@ class GrowTreeEnsembleOp : public OpKernel {
CHECK(split->split_info.split_node().node_case() != TreeNode::NODE_NOT_SET);
CHECK(tree_config->nodes(node_id).node_case() == TreeNode::kLeaf)
<< "Unexpected node type to split "
<< tree_config->nodes(node_id).node_case() << " for node_id " << node_id
<< ". Tree config: " << tree_config->DebugString();
// Add left leaf.
int32 left_id = tree_config->nodes_size();
@ -767,7 +768,7 @@ class TreeEnsembleStatsOp : public OpKernel {
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_ensemble_resource));
core::ScopedUnref unref_me(decision_tree_ensemble_resource);
tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex());
// Get the stamp token.
const Tensor* stamp_token_t;

View File

@ -42,6 +42,7 @@ class BiasFeatureColumnHandlerTest : public ::testing::Test {
example_partitions_({0, 0, 1, 3}) {
// Set L2 regularization.
learner_config_.mutable_regularization()->set_l2(2.0f);
learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
// Create handler.
handler_.reset(new BiasFeatureColumnHandler(kClassId, kSlotId, kBatchSize));

View File

@ -51,7 +51,7 @@ class CategoricalFeatureColumnHandlerTest : public ::testing::Test {
values_(test::AsTensor<int64>({1, 2, 2, 0}, {4})) {
// Set L2 regularization.
learner_config_.mutable_regularization()->set_l2(2.0f);
learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
// Create handler.
handler_.reset(new CategoricalFeatureColumnHandler(
kClassId, kSlotId, kBatchSize, kFeatureColumn, indices_.matrix<int64>(),

View File

@ -51,7 +51,7 @@ class DenseQuantizedFeatureColumnHandlerTest : public ::testing::Test {
dense_quantized_values_(test::AsTensor<int32>({1, 1, 0, 1}, {4})) {
// Set L2 regularization.
learner_config_.mutable_regularization()->set_l2(2.0f);
learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
// Create handler.
handler_.reset(new DenseQuantizedFeatureColumnHandler(
kClassId, kSlotId, kBatchSize, kFeatureColumn,

View File

@ -53,7 +53,7 @@ class SparseQuantizedFeatureColumnHandlerTest : public ::testing::Test {
sparse_quantized_values_(test::AsTensor<int32>({1, 0, 1}, {3})) {
// Set L2 regularization.
learner_config_.mutable_regularization()->set_l2(2.0f);
learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
// Create handler.
handler_.reset(new SparseQuantizedFeatureColumnHandler(
kClassId, kSlotId, kBatchSize, kFeatureColumn,

View File

@ -30,6 +30,7 @@ const double kDelta = 1e-5;
TEST(NodeStatsTest, AlmostZero) {
LearnerConfig learner_config;
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
NodeStats node_stats(learner_config, GradientStats(1e-8f, 1e-8f));
EXPECT_EQ(0, node_stats.weight_contribution[0]);
EXPECT_EQ(0, node_stats.gain);
@ -37,6 +38,7 @@ TEST(NodeStatsTest, AlmostZero) {
TEST(NodeStatsTest, LessThanMinWeightConstraint) {
LearnerConfig learner_config;
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
learner_config.mutable_constraints()->set_min_node_weight(3.2f);
NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
EXPECT_EQ(0, node_stats.weight_contribution[0]);
@ -45,6 +47,7 @@ TEST(NodeStatsTest, LessThanMinWeightConstraint) {
TEST(NodeStatsTest, L1RegSquashed) {
LearnerConfig learner_config;
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
learner_config.mutable_regularization()->set_l1(10.0f);
NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
EXPECT_EQ(0, node_stats.weight_contribution[0]);
@ -53,6 +56,7 @@ TEST(NodeStatsTest, L1RegSquashed) {
TEST(NodeStatsTest, L1RegPos) {
LearnerConfig learner_config;
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
learner_config.mutable_regularization()->set_l1(5.0f);
NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
const float expected_clipped_grad = 7.32f - 5.0f;
@ -66,6 +70,7 @@ TEST(NodeStatsTest, L1RegPos) {
TEST(NodeStatsTest, L1RegNeg) {
LearnerConfig learner_config;
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
learner_config.mutable_regularization()->set_l1(5.0f);
NodeStats node_stats(learner_config, GradientStats(-7.32f, 1.63f));
const float expected_clipped_grad = -7.32f + 5.0f;
@ -79,6 +84,7 @@ TEST(NodeStatsTest, L1RegNeg) {
TEST(NodeStatsTest, L2Reg) {
LearnerConfig learner_config;
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
learner_config.mutable_regularization()->set_l2(8.0f);
NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
const float expected_denom = 1.63f + 8.0f;
@ -91,6 +97,7 @@ TEST(NodeStatsTest, L2Reg) {
TEST(NodeStatsTest, L1L2Reg) {
LearnerConfig learner_config;
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
learner_config.mutable_regularization()->set_l1(5.0f);
learner_config.mutable_regularization()->set_l2(8.0f);
NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));

View File

@ -15,6 +15,7 @@
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
#include <cstring>
#include <vector>
#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h"
@ -34,10 +35,27 @@ class WeightedQuantilesSummary {
struct SummaryEntry {
SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min,
const WeightType& max)
: value(v), weight(w), min_rank(min), max_rank(max) {}
const WeightType& max) {
// Explicitly initialize all of memory (including padding from memory
// alignment) to allow the struct to be msan-resistant "plain old data".
//
// POD = http://en.cppreference.com/w/cpp/concept/PODType
memset(this, 0, sizeof(*this));
value = v;
weight = w;
min_rank = min;
max_rank = max;
}
SummaryEntry() : value(0), weight(0), min_rank(0), max_rank(0) {}
SummaryEntry() {
memset(this, 0, sizeof(*this));
value = 0;
weight = 0;
min_rank = 0;
max_rank = 0;
}
bool operator==(const SummaryEntry& other) const {
return value == other.value && weight == other.weight &&

View File

@ -17,7 +17,7 @@ message TreeRegularizationConfig {
// Tree constraints config.
message TreeConstraintsConfig {
// Maximum depth of the trees.
// Maximum depth of the trees. The default value is 6 if not specified.
uint32 max_tree_depth = 1;
// Min hessian weight per node.
@ -86,20 +86,22 @@ message LearningRateDropoutDrivenConfig {
message LearnerConfig {
enum PruningMode {
PRE_PRUNE = 0;
POST_PRUNE = 1;
PRUNING_MODE_UNSPECIFIED = 0;
PRE_PRUNE = 1;
POST_PRUNE = 2;
}
enum GrowingMode {
WHOLE_TREE = 0;
// Layer by layer is only supported by the batch learner.
LAYER_BY_LAYER = 1;
GROWING_MODE_UNSPECIFIED = 0;
WHOLE_TREE = 1;
LAYER_BY_LAYER = 2;
}
enum MultiClassStrategy {
TREE_PER_CLASS = 0;
FULL_HESSIAN = 1;
DIAGONAL_HESSIAN = 2;
MULTI_CLASS_STRATEGY_UNSPECIFIED = 0;
TREE_PER_CLASS = 1;
FULL_HESSIAN = 2;
DIAGONAL_HESSIAN = 3;
}
// Number of classes.
@ -118,16 +120,18 @@ message LearnerConfig {
// Constraints.
TreeConstraintsConfig constraints = 5;
// Pruning.
// Pruning. POST_PRUNE is the default pruning mode.
PruningMode pruning_mode = 8;
// Growing Mode.
// Growing Mode. LAYER_BY_LAYER is the default growing mode.
GrowingMode growing_mode = 9;
// Learning rate.
// Learning rate. By default we use fixed learning rate of 0.1.
LearningRateConfig learning_rate_tuner = 6;
// Multi-class strategy.
// Multi-class strategy. By default we use TREE_PER_CLASS for binary
// classification and linear regression. For other cases, we use
// DIAGONAL_HESSIAN as the default.
MultiClassStrategy multi_class_strategy = 10;
// If you want to average the ensembles (for regularization), provide the

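In proto3, an unset enum field reads back as 0, so reserving 0 for the new *_UNSPECIFIED values lets consumers distinguish "never set" from an explicit choice and substitute the documented defaults. A minimal sketch of that pattern, assuming the generated bindings are importable as learner_pb2 (as in the Python changes below):

```
from tensorflow.contrib.boosted_trees.proto import learner_pb2

config = learner_pb2.LearnerConfig()
# An unset proto3 enum field is 0, i.e. the new *_UNSPECIFIED sentinel.
assert config.pruning_mode == learner_pb2.LearnerConfig.PRUNING_MODE_UNSPECIFIED
if config.pruning_mode == learner_pb2.LearnerConfig.PRUNING_MODE_UNSPECIFIED:
  config.pruning_mode = learner_pb2.LearnerConfig.POST_PRUNE  # documented default
```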
View File

@ -344,6 +344,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
# Prepare learner config.
learner_config = learner_pb2.LearnerConfig()
learner_config.num_classes = 2
learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE
result, result_no_dropout, dropout_info = (
prediction_ops.gradient_trees_prediction(

View File

@ -261,6 +261,7 @@ class GradientBoostedDecisionTreeModel(object):
examples_per_layer,
learner_config,
features,
logits_dimension,
feature_columns=None):
"""Construct a new GradientBoostedDecisionTreeModel function.
@ -273,8 +274,8 @@ class GradientBoostedDecisionTreeModel(object):
a tree layer. It can also be a function that computes the number of
examples based on the depth of the layer that's being built.
learner_config: A learner config.
features: `dict` of `Tensor` objects.
logits_dimension: An int, the dimension of logits.
feature_columns: A list of feature columns.
Raises:
@ -289,11 +290,39 @@ class GradientBoostedDecisionTreeModel(object):
if learner_config.num_classes < 2:
raise ValueError("Number of classes must be >=2")
self._logits_dimension = logits_dimension
self._is_chief = is_chief
self._num_ps_replicas = num_ps_replicas
self._ensemble_handle = ensemble_handle
self._center_bias = center_bias
self._examples_per_layer = examples_per_layer
# Fill in the defaults.
if (learner_config.multi_class_strategy ==
learner_pb2.LearnerConfig.MULTI_CLASS_STRATEGY_UNSPECIFIED):
if logits_dimension == 1:
learner_config.multi_class_strategy = (
learner_pb2.LearnerConfig.TREE_PER_CLASS)
else:
learner_config.multi_class_strategy = (
learner_pb2.LearnerConfig.DIAGONAL_HESSIAN)
if (learner_config.growing_mode ==
learner_pb2.LearnerConfig.GROWING_MODE_UNSPECIFIED):
learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER
if (learner_config.pruning_mode ==
learner_pb2.LearnerConfig.PRUNING_MODE_UNSPECIFIED):
learner_config.pruning_mode = learner_pb2.LearnerConfig.POST_PRUNE
if learner_config.constraints.max_tree_depth == 0:
# Use 6 as the default maximum depth.
learner_config.constraints.max_tree_depth = 6
tuner = learner_config.learning_rate_tuner.WhichOneof("tuner")
if not tuner:
learner_config.learning_rate_tuner.fixed.learning_rate = 0.1
self._learner_config = learner_config
self._feature_columns = feature_columns
self._learner_config_serialized = learner_config.SerializeToString()
@ -378,75 +407,81 @@ class GradientBoostedDecisionTreeModel(object):
local_stamp), _refresh_local_ensemble_fn,
lambda: (control_flow_ops.no_op(), ensemble_stamp))
# Once updated, Use the the local model for prediction.
# Once updated, use the local model for prediction.
with ops.control_dependencies([refresh_local_ensemble]):
ensemble_stats = training_ops.tree_ensemble_stats(
local_ensemble_handle, ensemble_stamp)
apply_dropout, seed = _dropout_params(mode, ensemble_stats)
# We don't need dropout info - we can always restore it based on the
# seed.
predictions, predictions_no_dropout, _ = (
prediction_ops.gradient_trees_prediction(
local_ensemble_handle,
seed,
self._dense_floats,
self._sparse_float_indices,
self._sparse_float_values,
self._sparse_float_shapes,
self._sparse_int_indices,
self._sparse_int_values,
self._sparse_int_shapes,
learner_config=self._learner_config_serialized,
apply_dropout=apply_dropout,
apply_averaging=apply_averaging,
use_locking=False,
center_bias=self._center_bias,
reduce_dim=self._reduce_dim))
partition_ids = prediction_ops.gradient_trees_partition_examples(
local_ensemble_handle,
self._dense_floats,
self._sparse_float_indices,
self._sparse_float_values,
self._sparse_float_shapes,
self._sparse_int_indices,
self._sparse_int_values,
self._sparse_int_shapes,
use_locking=False)
apply_dropout, seed = _dropout_params(mode, ensemble_stats)
# Make sure ensemble stats run. This will check that the ensemble has
# the right stamp.
with ops.control_dependencies(ensemble_stats):
predictions, predictions_no_dropout, _ = (
prediction_ops.gradient_trees_prediction(
local_ensemble_handle,
seed,
self._dense_floats,
self._sparse_float_indices,
self._sparse_float_values,
self._sparse_float_shapes,
self._sparse_int_indices,
self._sparse_int_values,
self._sparse_int_shapes,
learner_config=self._learner_config_serialized,
apply_dropout=apply_dropout,
apply_averaging=apply_averaging,
use_locking=True,
center_bias=self._center_bias,
reduce_dim=self._reduce_dim))
partition_ids = prediction_ops.gradient_trees_partition_examples(
local_ensemble_handle,
self._dense_floats,
self._sparse_float_indices,
self._sparse_float_values,
self._sparse_float_shapes,
self._sparse_int_indices,
self._sparse_int_values,
self._sparse_int_shapes,
use_locking=True)
else:
with ops.device(self._ensemble_handle.device):
ensemble_stats = training_ops.tree_ensemble_stats(
self._ensemble_handle, ensemble_stamp)
apply_dropout, seed = _dropout_params(mode, ensemble_stats)
# We don't need dropout info - we can always restore it based on the
# seed.
predictions, predictions_no_dropout, _ = (
prediction_ops.gradient_trees_prediction(
self._ensemble_handle,
seed,
self._dense_floats,
self._sparse_float_indices,
self._sparse_float_values,
self._sparse_float_shapes,
self._sparse_int_indices,
self._sparse_int_values,
self._sparse_int_shapes,
learner_config=self._learner_config_serialized,
apply_dropout=apply_dropout,
apply_averaging=apply_averaging,
use_locking=False,
center_bias=self._center_bias,
reduce_dim=self._reduce_dim))
partition_ids = prediction_ops.gradient_trees_partition_examples(
self._ensemble_handle,
self._dense_floats,
self._sparse_float_indices,
self._sparse_float_values,
self._sparse_float_shapes,
self._sparse_int_indices,
self._sparse_int_values,
self._sparse_int_shapes,
use_locking=False)
apply_dropout, seed = _dropout_params(mode, ensemble_stats)
# Make sure ensemble stats run. This will check that the ensemble has
# the right stamp.
with ops.control_dependencies(ensemble_stats):
predictions, predictions_no_dropout, _ = (
prediction_ops.gradient_trees_prediction(
self._ensemble_handle,
seed,
self._dense_floats,
self._sparse_float_indices,
self._sparse_float_values,
self._sparse_float_shapes,
self._sparse_int_indices,
self._sparse_int_values,
self._sparse_int_shapes,
learner_config=self._learner_config_serialized,
apply_dropout=apply_dropout,
apply_averaging=apply_averaging,
use_locking=True,
center_bias=self._center_bias,
reduce_dim=self._reduce_dim))
partition_ids = prediction_ops.gradient_trees_partition_examples(
self._ensemble_handle,
self._dense_floats,
self._sparse_float_indices,
self._sparse_float_values,
self._sparse_float_shapes,
self._sparse_int_indices,
self._sparse_int_values,
self._sparse_int_shapes,
use_locking=True)
return _make_predictions_dict(ensemble_stamp, predictions,
predictions_no_dropout, partition_ids,

View File

@ -164,7 +164,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
features=features)
logits_dimension=1, features=features)
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
@ -268,7 +268,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=num_examples_fn,
learner_config=learner_config,
features=features)
logits_dimension=1, features=features)
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
@ -371,7 +371,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
features=features)
logits_dimension=1, features=features)
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
@ -442,7 +442,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
features=features)
logits_dimension=1, features=features)
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
@ -505,7 +505,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
features=features)
logits_dimension=1, features=features)
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
@ -588,7 +588,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
features=features)
logits_dimension=1, features=features)
# Create predict op.
mode = model_fn.ModeKeys.EVAL
@ -627,7 +627,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
features=features)
logits_dimension=5, features=features)
predictions = array_ops.constant(
[[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0],
@ -730,7 +730,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
features=features)
logits_dimension=5, features=features)
predictions = array_ops.constant(
[[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0],
@ -833,7 +833,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
features=features)
logits_dimension=5, features=features)
batch_size = 3
predictions = array_ops.constant(

View File

@ -14,8 +14,8 @@
# ==============================================================================
include (ExternalProject)
set(cub_URL http://mirror.bazel.build/github.com/NVlabs/cub/archive/69ceda618313df8e9cac6659d607b08949455d14.tar.gz)
set(cub_HASH SHA256=87e856522c283b8ea887c3b61d7d5b252d2dd74abac4f1d756d776e721223e82)
set(cub_URL http://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.3.zip)
set(cub_HASH SHA256=b7ead9e291d34ffa8074243541c1380d63be63f88de23de8ee548db573b72ebe)
set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub)
set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub)
set(cub_ARCHIVE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/cub_archive)

View File

@ -18,6 +18,7 @@
set(tf_c_srcs
"${tensorflow_source_dir}/tensorflow/c/c_api.cc"
"${tensorflow_source_dir}/tensorflow/c/c_api.h"
"${tensorflow_source_dir}/tensorflow/c/c_api_function.cc"
"${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc"
"${tensorflow_source_dir}/tensorflow/c/eager/c_api.h"
"${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc"

View File

@ -315,6 +315,7 @@ add_python_module("tensorflow/contrib/framework/ops")
add_python_module("tensorflow/contrib/framework/python")
add_python_module("tensorflow/contrib/framework/python/framework")
add_python_module("tensorflow/contrib/framework/python/ops")
add_python_module("tensorflow/contrib/gan")
add_python_module("tensorflow/contrib/graph_editor")
add_python_module("tensorflow/contrib/graph_editor/examples")
add_python_module("tensorflow/contrib/graph_editor/tests")

View File

@ -291,6 +291,8 @@ if (tensorflow_BUILD_PYTHON_TESTS)
# Failing with TF 1.3 (TODO)
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/estimator_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_test.py"
# Test should only be run manually
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/reduction_ops_test_big.py"
)
endif()
list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude})

View File

@ -716,6 +716,482 @@ _cudnn_rnn_common_doc_string = """
"""
def _check_direction(direction):
if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION):
raise ValueError("Invalid direction: %s, expect %s or %s" %
(direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION))
def _check_rnn_mode(rnn_mode):
if rnn_mode not in (CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH, CUDNN_RNN_RELU):
raise ValueError("Invalid rnn_mode: %s, expect one of (%s, %s, %s, %s)" %
(rnn_mode, CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH,
CUDNN_RNN_RELU))
def _get_seed(seed):
seed, seed2 = random_seed.get_seed(seed)
if seed is None and seed2 is None:
seed, seed2 = 0, 0
return seed, seed2
def _get_num_params(rnn_mode, num_layers, direction):
"""Return num params for given Cudnn config."""
if rnn_mode == CUDNN_LSTM:
num_params_per_layer = 8
elif rnn_mode == CUDNN_GRU:
num_params_per_layer = 6
elif rnn_mode in (CUDNN_RNN_RELU, CUDNN_RNN_TANH):
num_params_per_layer = 2
else:
raise ValueError("Invalid \'rnn_mode\': %s", rnn_mode)
num_params = num_layers * num_params_per_layer
if direction != CUDNN_RNN_UNIDIRECTION:
num_params *= 2
return num_params
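As a quick check of the accounting in _get_num_params: cuDNN allocates 8 weight/bias regions per layer for LSTM, 6 for GRU, and 2 for the plain RNNs, doubled for bidirectional models. A hedged sketch using the constants defined in this module:

```
# A 2-layer bidirectional LSTM: 8 regions per layer per direction.
assert _get_num_params(CUDNN_LSTM, num_layers=2,
                       direction=CUDNN_RNN_BIDIRECTION) == 8 * 2 * 2
```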
def _cudnn_rnn(inputs,
input_h,
input_c,
params,
is_training,
rnn_mode,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0.,
seed=0,
name=None):
"""Cudnn RNN.
Args:
inputs: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
input_c: the initial hidden state for c. This is only relevant for LSTM.
A Tensor of the same shape as input_h.
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference
rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh').
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
'linear_input' (default) always applies a linear projection of input
onto RNN hidden state. (standard RNN behavior).
'skip_input' is only allowed when input_size == num_units;
'auto_select' implies 'skip_input' when input_size == num_units;
otherwise, it implies 'linear_input'.
direction: the direction mode in which the model operates. Could be either
'unidirectional' or 'bidirectional'.
dropout: the dropout probability. When it is 0, dropout is disabled.
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
for behavior.
name: name of the operation.
Returns:
outputs, output_h, output_c
"""
_check_rnn_mode(rnn_mode)
_check_direction(direction)
seed, seed2 = random_seed.get_seed(seed)
outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
input=inputs,
input_h=input_h,
input_c=input_c,
params=params,
is_training=is_training,
rnn_mode=rnn_mode,
input_mode=input_mode,
direction=direction,
dropout=dropout,
seed=seed,
seed2=seed2,
name=name)
return (outputs, output_h, output_c)
def cudnn_lstm(inputs,
input_h,
input_c,
params,
is_training,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0.,
seed=0,
name=None):
"""Cudnn LSTM.
Args:
inputs: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
input_c: the initial hidden state for c. This is only relevant for LSTM.
A Tensor of the same shape as input_h.
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
'linear_input' (default) always applies a linear projection of input
onto RNN hidden state. (standard RNN behavior).
'skip_input' is only allowed when input_size == num_units;
'auto_select' implies 'skip_input' when input_size == num_units;
otherwise, it implies 'linear_input'.
direction: the direction mode in which the model operates. Could be either
'unidirectional' or 'bidirectional'.
dropout: the dropout probability. When it is 0, dropout is disabled.
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
for behavior.
name: name of the operation.
Returns:
outputs, output_h, output_c
"""
return _cudnn_rnn(inputs, input_h, input_c, params, is_training, CUDNN_LSTM,
input_mode, direction, dropout, seed, name)
def _cudnn_rnn_no_input_c(inputs,
input_h,
params,
is_training,
rnn_mode,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0.,
seed=0,
name=None):
"""Cudnn RNN w/o input_c.
Args:
inputs: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference
rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh').
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
'linear_input' (default) always applies a linear projection of input
onto RNN hidden state. (standard RNN behavior).
'skip_input' is only allowed when input_size == num_units;
'auto_select' implies 'skip_input' when input_size == num_units;
otherwise, it implies 'linear_input'.
direction: the direction mode in which the model operates. Could be either
'unidirectional' or 'bidirectional'.
dropout: the dropout probability. When it is 0, dropout is disabled.
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
for behavior.
name: name of the operation.
Returns:
outputs, output_h
"""
input_c = array_ops.constant([], dtype=input_h.dtype)
outputs, output_h, _ = _cudnn_rnn(inputs, input_h, input_c, params,
is_training, rnn_mode, input_mode,
direction, dropout, seed, name)
return outputs, output_h
def cudnn_gru(inputs,
input_h,
params,
is_training,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0.,
seed=0,
name=None):
"""Cudnn GRU.
Args:
inputs: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
'linear_input' (default) always applies a linear projection of input
onto RNN hidden state. (standard RNN behavior).
'skip_input' is only allowed when input_size == num_units;
'auto_select' implies 'skip_input' when input_size == num_units;
otherwise, it implies 'linear_input'.
direction: the direction mode in which the model operates. Could be either
'unidirectional' or 'bidirectional'.
dropout: the dropout probability. When it is 0, dropout is disabled.
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
for behavior.
name: name of the operation.
Returns:
outputs, output_h
"""
return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, CUDNN_GRU,
input_mode, direction, dropout, seed, name)
def cudnn_rnn_relu(inputs,
input_h,
params,
is_training,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0.,
seed=0,
name=None):
"""Cudnn RNN Relu.
Args:
inputs: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
'linear_input' (default) always applies a linear projection of input
onto RNN hidden state. (standard RNN behavior).
'skip_input' is only allowed when input_size == num_units;
'auto_select' implies 'skip_input' when input_size == num_units;
otherwise, it implies 'linear_input'.
direction: the direction mode in which the model operates. Could be either
'unidirectional' or 'bidirectional'.
dropout: the dropout probability. When it is 0, dropout is disabled.
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
for behavior.
name: name of the operation.
Returns:
outputs, output_h
"""
return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training,
CUDNN_RNN_RELU, input_mode, direction, dropout,
seed, name)
def cudnn_rnn_tanh(inputs,
input_h,
params,
is_training,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0.,
seed=0,
name=None):
"""Cudnn RNN Tanh.
Args:
inputs: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
'linear_input' (default) always applies a linear projection of input
onto RNN hidden state. (standard RNN behavior).
'skip_input' is only allowed when input_size == num_units;
'auto_select' implies 'skip_input' when input_size == num_units;
otherwise, it implies 'linear_input'.
direction: the direction mode in which the model operates. Could be either
'unidirectional' or 'bidirectional'.
dropout: the dropout probability. When it is 0, dropout is disabled.
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
for behavior.
name: name of the operation.
Returns:
outputs, output_h
"""
return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training,
CUDNN_RNN_TANH, input_mode, direction, dropout,
seed, name)
def cudnn_rnn_params_to_canonical(rnn_mode,
num_layers,
num_units,
input_size,
params,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0,
seed=0,
name=None):
"""Convert cudnn opaque params to canonical.
Args:
rnn_mode: a string that specifies the mode under which this RNN model runs.
Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'.
num_layers: the number of layers for the RNN model.
num_units: the number of units within the RNN model.
input_size: the size of the input; it could be different from
num_units.
params: opaque cudnn params var.
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
'linear_input' (default) always applies a linear projection of input
onto RNN hidden state. (standard RNN behavior).
'skip_input' is only allowed when input_size == num_units;
'auto_select' implies 'skip_input' when input_size == num_units;
otherwise, it implies 'linear_input'.
direction: the direction mode in which the model operates. Could be either
'unidirectional' or 'bidirectional'.
dropout: the dropout probability. When it is 0, dropout is disabled.
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
for behavior.
name: name of the operation.
Returns:
weights list and bias list
Raises:
ValueError: if rnn_mode or direction is invalid.
"""
_check_rnn_mode(rnn_mode)
_check_direction(direction)
num_params = _get_num_params(rnn_mode, num_layers, direction)
seed, seed2 = random_seed.get_seed(seed)
weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
rnn_mode=rnn_mode,
num_layers=num_layers,
num_units=num_units,
input_size=input_size,
params=params,
input_mode=input_mode,
direction=direction,
dropout=dropout,
seed=seed,
seed2=seed2,
num_params=num_params,
name=name)
return weights, biases
def cudnn_rnn_canonical_to_params(rnn_mode,
num_layers,
num_units,
input_size,
weights,
biases,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0,
seed=0,
name=None):
"""Converts params from the canonical format to a specific format of cuDNN.
Args:
rnn_mode: a string that specifies the mode under which this RNN model runs.
Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'.
num_layers: the number of layers for the RNN model.
num_units: the number of units within the RNN model.
input_size: the size of the input; it could be different from
num_units.
weights: a Tensor for weight parameters.
biases: a Tensor for bias parameters.
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
'linear_input' (default) always applies a linear projection of input
onto RNN hidden state. (standard RNN behavior).
'skip_input' is only allowed when input_size == num_units;
'auto_select' implies 'skip_input' when input_size == num_units;
otherwise, it implies 'linear_input'.
direction: the direction mode in which the model operates. Could be either
'unidirectional' or 'bidirectional'.
dropout: the dropout probability. When it is 0, dropout is disabled.
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
for behavior.
name: name of the operation.
Returns:
an opaque Cudnn param buffer.
Raises:
ValueError: if rnn_mode or direction is invalid.
"""
_check_rnn_mode(rnn_mode)
_check_direction(direction)
seed, seed2 = random_seed.get_seed(seed)
return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
rnn_mode=rnn_mode,
num_layers=num_layers,
num_units=num_units,
input_size=input_size,
weights=weights,
biases=biases,
input_mode=input_mode,
direction=direction,
dropout=dropout,
seed=seed,
seed2=seed2,
name=name)
def cudnn_opaque_params_size(rnn_mode,
num_layers,
num_units,
input_size,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dtype=dtypes.float32,
dropout=0,
seed=0,
name=None):
"""Returns opaque params size for specific Cudnn config.
Args:
rnn_mode: a string that specifies the mode under which this RNN model runs.
Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'.
num_layers: the number of layers for the RNN model.
num_units: the number of units within the RNN model.
input_size: the size of the input; it could be different from
num_units.
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
'linear_input' (default) always applies a linear projection of input
onto RNN hidden state. (standard RNN behavior).
'skip_input' is only allowed when input_size == num_units;
'auto_select' implies 'skip_input' when input_size == num_units;
otherwise, it implies 'linear_input'.
direction: the direction mode in which the model operates. Could be either
'unidirectional' or 'bidirectional'.
dtype: one of tf.float32 or tf.float64.
dropout: the dropout probability. When it is 0, dropout is disabled.
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
for behavior.
name: name of the operation.
Returns:
an int, the size of the Cudnn opaque params.
Raises:
ValueError: if rnn_mode or direction is invalid.
"""
_check_rnn_mode(rnn_mode)
_check_direction(direction)
seed, seed2 = random_seed.get_seed(seed)
return gen_cudnn_rnn_ops.cudnn_rnn_params_size(
rnn_mode=rnn_mode,
num_layers=num_layers,
num_units=num_units,
input_size=input_size,
T=dtype,
S=dtypes.int32,
dropout=dropout,
seed=seed,
seed2=seed2,
input_mode=input_mode,
direction=direction,
name=name)[0]
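Tying the new functional API together, here is a hedged end-to-end sketch (the shapes and the variable setup are illustrative assumptions, and a CUDA-enabled build is required):

```
import tensorflow as tf

# [seq_len, batch_size, input_size] inputs for a 1-layer, 16-unit LSTM.
inputs = tf.zeros([10, 4, 8])
input_h = tf.zeros([1, 4, 16])
input_c = tf.zeros([1, 4, 16])
# The opaque buffer size is a graph-time tensor, hence validate_shape=False.
params_size = cudnn_opaque_params_size(
    rnn_mode=CUDNN_LSTM, num_layers=1, num_units=16, input_size=8)
params = tf.Variable(tf.zeros([params_size]), validate_shape=False)
outputs, output_h, output_c = cudnn_lstm(
    inputs, input_h, input_c, params, is_training=False)
weights, biases = cudnn_rnn_params_to_canonical(
    rnn_mode=CUDNN_LSTM, num_layers=1, num_units=16, input_size=8,
    params=params)
```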
class _CudnnRNN(object):
"""Creates an RNN model using the underlying Cudnn implementation.
@ -761,9 +1237,6 @@ class _CudnnRNN(object):
Raises:
ValueError: if direction is invalid.
"""
if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION):
raise ValueError("Invalid direction: %s, expect %s or %s",
direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION)
self._num_layers = num_layers
self._num_units = num_units
self._input_size = input_size
@ -772,10 +1245,7 @@ class _CudnnRNN(object):
self._direction = direction
self._dtype = dtype
self._dropout = dropout
# get graph and op seed.
self._seed, self._seed2 = random_seed.get_seed(seed)
if self._seed is None and self._seed2 is None:
self._seed, self._seed2 = 0, 0
self._seed = seed
@property
def input_mode(self):
@ -807,18 +1277,16 @@ class _CudnnRNN(object):
Returns:
The calculated parameter buffer size.
"""
return gen_cudnn_rnn_ops.cudnn_rnn_params_size(
return cudnn_opaque_params_size(
rnn_mode=self._rnn_mode,
num_layers=self._num_layers,
num_units=self._num_units,
input_size=self._input_size,
T=self._dtype,
S=dtypes.int32,
dtype=self._dtype,
dropout=self._dropout,
seed=self._seed,
seed2=self._seed2,
rnn_mode=self._rnn_mode,
input_mode=self._input_mode,
direction=self._direction)[0]
direction=self._direction)
def __call__(self, input_data, input_h, input_c, params, is_training=True):
"""Runs the forward step for the RNN model.
@ -837,22 +1305,17 @@ class _CudnnRNN(object):
output_h: the final state for h.
output_c: the final state for c. This is only relevant for LSTM.
"""
if self._rnn_mode != CUDNN_LSTM:
# For model that doesn't take input_c, replace with a dummy tensor.
input_c = array_ops.constant([], dtype=self._dtype)
output, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
input=input_data,
input_h=input_h,
input_c=input_c,
params=params,
rnn_mode=self._rnn_mode,
return _cudnn_rnn(
input_data,
input_h,
input_c,
params,
is_training,
self._rnn_mode,
input_mode=self._input_mode,
direction=self._direction,
dropout=self._dropout,
seed=self._seed,
seed2=self._seed2,
is_training=is_training)
return (output, output_h, output_c)
seed=self._seed)
def params_to_canonical(self, params):
"""Converts params from a specific format of cuDNN to the canonical format.
@ -863,22 +1326,16 @@ class _CudnnRNN(object):
Returns:
The weights and biases, converted to the canonical format.
"""
num_params = self._num_layers * self._NUM_PARAMS_PER_LAYER
if self._direction != CUDNN_RNN_UNIDIRECTION:
num_params *= 2
weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
return cudnn_rnn_params_to_canonical(
rnn_mode=self._rnn_mode,
num_layers=self._num_layers,
num_units=self._num_units,
input_size=self._input_size,
params=params,
dropout=self._dropout,
seed=self._seed,
seed2=self._seed2,
num_params=num_params,
rnn_mode=self._rnn_mode,
input_mode=self._input_mode,
direction=self._direction)
return weights, biases
direction=self._direction,
dropout=self._dropout,
seed=self._seed)
def canonical_to_params(self, weights, biases):
"""Converts params from the canonical format to a specific format of cuDNN.
@ -890,18 +1347,17 @@ class _CudnnRNN(object):
Returns:
The opaque params buffer in the cuDNN-specific format.
"""
return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
return cudnn_rnn_canonical_to_params(
rnn_mode=self._rnn_mode,
num_layers=self._num_layers,
num_units=self._num_units,
input_size=self._input_size,
weights=weights,
biases=biases,
dropout=self._dropout,
seed=self._seed,
seed2=self._seed2,
rnn_mode=self._rnn_mode,
input_mode=self._input_mode,
direction=self._direction)
direction=self._direction,
dropout=self._dropout,
seed=self._seed)
class CudnnLSTM(_CudnnRNN):
@ -1036,9 +1492,16 @@ class _CudnnRNNNoInputC(_CudnnRNN):
output: the output sequence.
output_h: the final state for h.
"""
output, output_h, _ = super(_CudnnRNNNoInputC, self).__call__(
input_data, input_h, None, params, is_training=is_training)
return (output, output_h)
return _cudnn_rnn_no_input_c(
input_data,
input_h,
params,
is_training,
self._rnn_mode,
input_mode=self._input_mode,
direction=self._direction,
dropout=self._dropout,
seed=self._seed)
class CudnnGRU(_CudnnRNNNoInputC):

View File

@ -22,6 +22,7 @@
@@read_batch_features
@@rejection_resample
@@group_by_window
"""
from __future__ import absolute_import
@ -31,6 +32,7 @@ from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.contrib.data.python.ops.dataset_ops import Dataset
from tensorflow.contrib.data.python.ops.dataset_ops import FixedLengthRecordDataset
from tensorflow.contrib.data.python.ops.dataset_ops import group_by_window
from tensorflow.contrib.data.python.ops.dataset_ops import Iterator
from tensorflow.contrib.data.python.ops.dataset_ops import read_batch_features
from tensorflow.contrib.data.python.ops.dataset_ops import rejection_resample

View File

@ -37,7 +37,9 @@ class GroupByWindowTest(test.TestCase):
components = np.random.randint(100, size=(200,)).astype(np.int64)
iterator = dataset_ops.Iterator.from_dataset(
dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x)
.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), 4))
.apply(
dataset_ops.group_by_window,
args=(lambda x: x % 2, lambda _, xs: xs.batch(4), 4)))
init_op = iterator.initializer
get_next = iterator.get_next()
@ -61,8 +63,9 @@ class GroupByWindowTest(test.TestCase):
components = np.array(
[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
iterator = dataset_ops.Iterator.from_dataset(
dataset_ops.Dataset.from_tensor_slices(components).repeat(-1)
.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4))
dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
dataset_ops.group_by_window,
args=(lambda x: x % 3, lambda _, xs: xs.batch(4), 4)))
init_op = iterator.initializer
get_next = iterator.get_next()
@ -81,8 +84,9 @@ class GroupByWindowTest(test.TestCase):
def testSmallGroups(self):
components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
iterator = dataset_ops.Iterator.from_dataset(
dataset_ops.Dataset.from_tensor_slices(components)
.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), 4))
dataset_ops.Dataset.from_tensor_slices(components).apply(
dataset_ops.group_by_window,
args=(lambda x: x % 2, lambda _, xs: xs.batch(4), 4)))
init_op = iterator.initializer
get_next = iterator.get_next()
@ -108,8 +112,9 @@ class GroupByWindowTest(test.TestCase):
iterator = dataset_ops.Iterator.from_dataset(
dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: (x, ops.convert_to_tensor([x * x])))
.group_by_window(lambda x, _: x % 2, reduce_func, 32))
.map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply(
dataset_ops.group_by_window,
args=(lambda x, _: x % 2, reduce_func, 32)))
init_op = iterator.initializer
get_next = iterator.get_next()
@ -124,17 +129,20 @@ class GroupByWindowTest(test.TestCase):
def reduce_func(key, window):
# Apply two different kinds of padding to the input: tight
# padding, and quantized (to a multiple of 10) padding.
return dataset_ops.Dataset.zip((window.padded_batch(
4,
padded_shapes=tensor_shape.TensorShape([None])), window.padded_batch(
return dataset_ops.Dataset.zip((
window.padded_batch(
4, padded_shapes=tensor_shape.TensorShape([None])),
window.padded_batch(
4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),))
iterator = dataset_ops.Iterator.from_dataset(
dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
.group_by_window(
lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
reduce_func, 4))
.apply(
dataset_ops.group_by_window,
args=
(lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
reduce_func, 4)))
init_op = iterator.initializer
get_next = iterator.get_next()
@ -151,10 +159,9 @@ class GroupByWindowTest(test.TestCase):
self.assertEqual(len(components), sum(counts))
# NOTE(mrry): These tests are based on the tests in
# bucket_ops_test.py. Currently, different batch sizes for each key
# are not supported, although this would be possible to add to
# `Dataset.group_by_window()`.
# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
# Currently, they use a constant batch size, though should be made to use a
# different batch size per key.
class BucketTest(test.TestCase):
def _dynamicPad(self, bucket, window, window_size):
@ -168,6 +175,7 @@ class BucketTest(test.TestCase):
tensor_shape.TensorShape([3])))))
def testSingleBucket(self):
def _map_fn(v):
return (v, array_ops.fill([v], v),
array_ops.fill([3], string_ops.as_string(v)))
@ -175,9 +183,10 @@ class BucketTest(test.TestCase):
input_dataset = (
dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn))
bucketed_dataset = input_dataset.group_by_window(
lambda x, y, z: 0, lambda k, bucket: self._dynamicPad(k, bucket, 32),
32)
bucketed_dataset = input_dataset.apply(
dataset_ops.group_by_window,
args=(lambda x, y, z: 0,
lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset)
init_op = iterator.initializer
@ -201,6 +210,7 @@ class BucketTest(test.TestCase):
self.assertAllEqual(expected_vec3_str, bucketed_values[2])
def testEvenOddBuckets(self):
def _map_fn(v):
return (v, array_ops.fill([v], v),
array_ops.fill([3], string_ops.as_string(v)))
@ -208,9 +218,10 @@ class BucketTest(test.TestCase):
input_dataset = (
dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn))
bucketed_dataset = input_dataset.group_by_window(
lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
lambda k, bucket: self._dynamicPad(k, bucket, 32), 32)
bucketed_dataset = input_dataset.apply(
dataset_ops.group_by_window,
args=(lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset)
init_op = iterator.initializer
@ -256,25 +267,31 @@ class BucketTest(test.TestCase):
self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
def testEvenOddBucketsFilterOutAllOdd(self):
def _map_fn(v):
return {"x": v,
"y": array_ops.fill([v], v),
"z": array_ops.fill([3], string_ops.as_string(v))}
return {
"x": v,
"y": array_ops.fill([v], v),
"z": array_ops.fill([3], string_ops.as_string(v))
}
def _dynamic_pad_fn(bucket, window, _):
return dataset_ops.Dataset.zip(
(dataset_ops.Dataset.from_tensors(bucket), window.padded_batch(
32, {"x": tensor_shape.TensorShape([]),
"y": tensor_shape.TensorShape([None]),
"z": tensor_shape.TensorShape([3])})))
32, {
"x": tensor_shape.TensorShape([]),
"y": tensor_shape.TensorShape([None]),
"z": tensor_shape.TensorShape([3])
})))
input_dataset = (
dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
.filter(lambda d: math_ops.equal(d["x"] % 2, 0)))
bucketed_dataset = input_dataset.group_by_window(
lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32)
bucketed_dataset = input_dataset.apply(
dataset_ops.group_by_window,
args=(lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32))
iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset)
init_op = iterator.initializer
@ -295,6 +312,40 @@ class BucketTest(test.TestCase):
self.assertAllEqual(
np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
def testDynamicWindowSize(self):
components = np.arange(100).astype(np.int64)
# Key fn: even/odd
# Reduce fn: batches of 5
# Window size fn: even=5, odd=10
def window_size_func(key):
window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64)
return window_sizes[key]
dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
dataset_ops.group_by_window,
args=(lambda x: x % 2, lambda _, xs: xs.batch(20), None,
window_size_func))
iterator = dataset_ops.Iterator.from_dataset(dataset)
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.OutOfRangeError):
batches = 0
while True:
result = sess.run(get_next)
is_even = all(x % 2 == 0 for x in result)
is_odd = all(x % 2 == 1 for x in result)
self.assertTrue(is_even or is_odd)
expected_batch_size = 5 if is_even else 10
self.assertEqual(expected_batch_size, result.shape[0])
batches += 1
self.assertEqual(batches, 15)
if __name__ == "__main__":
test.main()

View File

@ -1199,28 +1199,9 @@ class Dataset(object):
return DenseToSparseBatchDataset(self, batch_size, row_shape)
def group_by_window(self, key_func, reduce_func, window_size):
"""Performs a windowed "group-by" operation on this dataset.
This method maps each consecutive element in this dataset to a key
using `key_func` and groups the elements by key. It then applies
`reduce_func` to at most `window_size` elements matching the same
key. All except the final window for each key will contain
`window_size` elements; the final window may be smaller.
Args:
key_func: A function mapping a nested structure of tensors
(having shapes and types defined by `self.output_shapes` and
`self.output_types`) to a scalar `tf.int64` tensor.
reduce_func: A function mapping a key and a dataset of up to `batch_size`
consecutive elements matching that key to another dataset.
window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
consecutive elements matching the same key to combine in a single
batch, which will be passed to `reduce_func`.
Returns:
A `Dataset`.
"""
return GroupByWindowDataset(self, key_func, reduce_func, window_size)
"""See group_by_window()."""
return self.apply(
group_by_window, args=(key_func, reduce_func, window_size))
def map(self,
map_func,
@ -1370,6 +1351,43 @@ class Dataset(object):
"""
return FilterDataset(self, predicate)
def apply(self, fn, args=(), kwargs={}): # pylint: disable=dangerous-default-value
"""Apply a function to this dataset.
`apply` enables chaining of custom `Dataset` transformations.
For example:
```
dataset.map(
lambda x: x**2
).apply(
group_by_window, args=(key_func, reduce_func, window_size)
).map(
lambda x: x**3
)
```
Args:
fn: A function that takes a `Dataset`, `args`, and `kwargs`, and
returns a `Dataset`.
args: A `tuple` or `list` of arguments to be passed to `fn`.
kwargs: A `dict` of keyword arguments to be passed to `fn`.
Returns:
The `Dataset` returned by `fn`.
"""
if not (isinstance(args, tuple) or isinstance(args, list)):
raise TypeError("args must be a tuple or list.")
if not isinstance(kwargs, dict):
raise TypeError("kwargs must be a dict.")
dataset = fn(self, *args, **kwargs)
if not isinstance(dataset, Dataset):
raise TypeError("fn must return a Dataset.")
return dataset
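`apply` deliberately accepts any callable with the (Dataset, *args, **kwargs) -> Dataset contract, so user-defined transformations chain the same way as the built-ins. A minimal sketch with a hypothetical helper:

```
def repeat_twice(dataset):
  # Any function taking a Dataset (plus optional args) and returning a
  # Dataset can be passed to `apply`. This helper is illustrative only.
  return dataset.repeat(2)

doubled = Dataset.from_tensor_slices([1, 2, 3]).apply(repeat_twice)
```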
class TensorDataset(Dataset):
"""A `Dataset` with a single element, viz. a nested structure of tensors."""
@ -1927,71 +1945,6 @@ class _ResourceDataset(Dataset):
return self._output_types
class GroupByWindowDataset(Dataset):
"""A `Dataset` that groups its input and performs a windowed reduction."""
def __init__(self, input_dataset, key_func, reduce_func, window_size):
"""See `Dataset.group_by_window()` for details."""
super(GroupByWindowDataset, self).__init__()
self._input_dataset = input_dataset
self._window_size = window_size
@function.Defun(*nest.flatten(input_dataset.output_types))
def tf_key_func(*args):
"""A wrapper for Defun that facilitates shape inference."""
# Pass in shape information from the input_dataset.
for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
arg.set_shape(shape)
nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
if _should_unpack_args(nested_args):
ret = key_func(*nested_args)
else:
ret = key_func(nested_args)
ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
if ret.dtype != dtypes.int64:
raise ValueError("`key_func` must return a single tf.int64 tensor.")
return ret
self._key_func = tf_key_func
self._key_func.add_to_graph(ops.get_default_graph())
@function.Defun(dtypes.int64, dtypes.resource)
def tf_reduce_func(key, window_dataset_resource):
"""A wrapper for Defun that facilitates shape inference."""
key.set_shape([])
window_dataset = _ResourceDataset(window_dataset_resource,
input_dataset.output_types,
input_dataset.output_shapes)
output_dataset = reduce_func(key, window_dataset)
if not isinstance(output_dataset, Dataset):
raise TypeError("`reduce_func` must return a `Dataset` object.")
self._output_types = output_dataset.output_types
self._output_shapes = output_dataset.output_shapes
return output_dataset.make_dataset_resource()
self._reduce_func = tf_reduce_func
self._reduce_func.add_to_graph(ops.get_default_graph())
def make_dataset_resource(self):
return gen_dataset_ops.group_by_window_dataset(
self._input_dataset.make_dataset_resource(),
self._key_func.captured_inputs,
self._reduce_func.captured_inputs,
self._window_size,
key_func=self._key_func,
reduce_func=self._reduce_func,
output_types=nest.flatten(self.output_types),
output_shapes=nest.flatten(self.output_shapes))
@property
def output_shapes(self):
return self._output_shapes
@property
def output_types(self):
return self._output_types
class MapDataset(Dataset):
"""A `Dataset` that maps a function over elements in its input."""
@ -2660,3 +2613,149 @@ def _get_file_names(file_pattern, randomize_input):
if not randomize_input:
file_names = sorted(file_names)
return file_names
class GroupByWindowDataset(Dataset):
"""A `Dataset` that groups its input and performs a windowed reduction."""
def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
"""See `group_by_window()` for details."""
super(GroupByWindowDataset, self).__init__()
self._input_dataset = input_dataset
self._make_key_func(key_func, input_dataset)
self._make_reduce_func(reduce_func, input_dataset)
self._make_window_size_func(window_size_func)
def _make_window_size_func(self, window_size_func):
"""Make wrapping Defun for window_size_func."""
@function.Defun(dtypes.int64)
def tf_window_size_func(key):
key.set_shape([])
window_size = ops.convert_to_tensor(
window_size_func(key), dtype=dtypes.int64)
if window_size.dtype != dtypes.int64:
raise ValueError(
"`window_size_func` must return a single tf.int64 tensor.")
return window_size
self._window_size_func = tf_window_size_func
self._window_size_func.add_to_graph(ops.get_default_graph())
def _make_key_func(self, key_func, input_dataset):
"""Make wrapping Defun for key_func."""
@function.Defun(*nest.flatten(input_dataset.output_types))
def tf_key_func(*args):
"""A wrapper for Defun that facilitates shape inference."""
# Pass in shape information from the input_dataset.
for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
arg.set_shape(shape)
nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
if _should_unpack_args(nested_args):
ret = key_func(*nested_args)
else:
ret = key_func(nested_args)
ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
if ret.dtype != dtypes.int64:
raise ValueError("`key_func` must return a single tf.int64 tensor.")
return ret
self._key_func = tf_key_func
self._key_func.add_to_graph(ops.get_default_graph())
def _make_reduce_func(self, reduce_func, input_dataset):
"""Make wrapping Defun for reduce_func."""
@function.Defun(dtypes.int64, dtypes.resource)
def tf_reduce_func(key, window_dataset_resource):
"""A wrapper for Defun that facilitates shape inference."""
key.set_shape([])
window_dataset = _ResourceDataset(window_dataset_resource,
input_dataset.output_types,
input_dataset.output_shapes)
output_dataset = reduce_func(key, window_dataset)
if not isinstance(output_dataset, Dataset):
raise TypeError("`reduce_func` must return a `Dataset` object.")
self._output_types = output_dataset.output_types
self._output_shapes = output_dataset.output_shapes
return output_dataset.make_dataset_resource()
self._reduce_func = tf_reduce_func
self._reduce_func.add_to_graph(ops.get_default_graph())
@property
def output_shapes(self):
return self._output_shapes
@property
def output_types(self):
return self._output_types
def make_dataset_resource(self):
return gen_dataset_ops.group_by_window_dataset(
self._input_dataset.make_dataset_resource(),
self._key_func.captured_inputs,
self._reduce_func.captured_inputs,
self._window_size_func.captured_inputs,
key_func=self._key_func,
reduce_func=self._reduce_func,
window_size_func=self._window_size_func,
output_types=nest.flatten(self.output_types),
output_shapes=nest.flatten(self.output_shapes))
def group_by_window(dataset,
key_func,
reduce_func,
window_size=None,
window_size_func=None):
"""Performs a windowed "group-by" operation on this dataset.
This transformation maps each consecutive element in the given dataset to a
key using `key_func` and groups the elements by key. It then applies
`reduce_func` to at most `window_size_func(key)` elements matching the same
key. All except the final window for each key will contain
`window_size_func(key)` elements; the final window may be smaller.
You may provide either a constant `window_size` or a window size determined by
the key through `window_size_func`.
Args:
dataset: A `Dataset`.
key_func: A function mapping a nested structure of tensors
(having shapes and types defined by `self.output_shapes` and
`self.output_types`) to a scalar `tf.int64` tensor.
reduce_func: A function mapping a key and a dataset of up to `window_size`
consecutive elements matching that key to another dataset.
window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
consecutive elements matching the same key to combine in a single
batch, which will be passed to `reduce_func`. Mutually exclusive with
`window_size_func`.
window_size_func: A function mapping a key to a `tf.int64` scalar
`tf.Tensor`, representing the number of consecutive elements matching
the same key to combine in a single batch, which will be passed to
`reduce_func`. Mutually exclusive with `window_size`.
Returns:
A `Dataset`.
Raises:
ValueError: if neither or both of {`window_size`, `window_size_func`} are
passed.
"""
if (window_size is not None and window_size_func or
not (window_size is not None or window_size_func)):
raise ValueError("Must pass either window_size or window_size_func.")
if window_size is not None:
def constant_window_func(unused_key):
return ops.convert_to_tensor(window_size, dtype=dtypes.int64)
window_size_func = constant_window_func
assert window_size_func is not None
return GroupByWindowDataset(dataset, key_func, reduce_func, window_size_func)
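A short usage sketch of the constant-size path, mirroring the updated tests (np and dataset_ops as imported there):

```
import numpy as np

components = np.arange(10, dtype=np.int64)
# Group elements by parity and emit batches of up to 4 per key.
dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
    dataset_ops.group_by_window,
    args=(lambda x: x % 2, lambda key, window: window.batch(4), 4))
```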

View File

@ -341,7 +341,7 @@ cuda_py_test(
cuda_py_test(
name = "sample_stats_test",
size = "small",
size = "medium",
srcs = ["python/kernel_tests/sample_stats_test.py"],
additional_deps = [
":distributions_py",

View File

@ -17,440 +17,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_checkpoint_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
ops.NotDifferentiable("GenerateVocabRemapping")
ops.NotDifferentiable("LoadAndRemapMatrix")
from tensorflow.python.training import checkpoint_ops
def _load_and_remap_matrix(ckpt_path,
old_tensor_name,
new_row_vocab_offset,
num_rows_to_load,
new_col_vocab_size,
initializer,
old_row_vocab_file=None,
new_row_vocab_file=None,
old_col_vocab_file=None,
new_col_vocab_file=None,
num_row_oov_buckets=0,
num_col_oov_buckets=0,
max_rows_in_memory=-1):
"""Loads a 2-D (matrix) `Tensor` from checkpoint.
Generates 1D-remappings for rows and columns using the
`GenerateVocabRemapping` op, and initializes any anticipated values with the
provided initializer. Then, uses the `LoadAndRemapMatrix` op to create a
matrix that loads existing values from the checkpoint, while filling out
"missing" values with the newly initialized values. See
contrib/framework/ops/checkpoint_ops.cc for more information on the wrapped
functionality (LoadAndRemapMatrix). This wrapper can be used to perform only
row remapping or only col remapping. If only row remapping is desired,
{new,old}_col_vocab_file should be `None`, and vice versa for column
remapping.
NOTE: This only supports div-partitioning the vocabulary on the 1st dimension
(row axis) via `new_row_vocab_offset`.
Args:
ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
from which the old matrix `Tensor` will be loaded.
old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
new_row_vocab_offset: A 0-indexed integer representing what line to
start reading at in the new row vocabulary. Used for partitioned
variables.
num_rows_to_load: Number of rows to load for the new vocabulary (note: to
support variable partitioning and partial loading, this does not need to
be the same as the number of entries in `new_row_vocab_file`).
new_col_vocab_size: Number of columns to load - should be the same as the
number of entries in `new_col_vocab_file`, since we don't support
partitioning along the column axis.
initializer: Callable initializer function that accepts a 1-D tensor as the
arg to specify the shape of the returned tensor. Used to initialize
missing values.
old_row_vocab_file: A scalar `Tensor` of type `string` containing the
path to the old row vocabulary file. Can be None, which represents no
remapping on the row axis.
new_row_vocab_file: A scalar `Tensor` of type `string` containing the path
to the new row vocabulary file. Can be None, which represents no remapping
on the row axis - in which case, `new_row_vocab_offset` and
`num_rows_to_load` work under the assumption that the new row vocab is the
same as the old row vocab.
old_col_vocab_file: A scalar `Tensor` of type `string` containing the
path to the old column vocabulary file. Can be None, which represents no
remapping on the column axis.
new_col_vocab_file: A scalar `Tensor` of type `string` containing the path
to the new column vocabulary file. Can be None, which represents no
remapping on the column axis - in which case, `new_col_vocab_size` works
under the assumption that the new col vocab is the same as the old col
vocab.
num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows
to append. Must be >= 0.
num_col_oov_buckets: `int` specifying the number of out-of-vocabulary
columns to append. Must be >= 0.
max_rows_in_memory: `int` specifying the maximum number of rows to load from
the checkpoint at once. If less than or equal to 0, the entire matrix will
be loaded into memory. Setting this arg trades increased disk reads for
lower memory usage.
Returns:
A Tensor of shape `[num_rows_to_load + num_row_oov_buckets,
new_col_vocab_size + num_col_oov_buckets]`, with values loaded from the
specified tensor in the checkpoint, and any missing or OOV values
initialized with the given `initializer`.
Raises:
ValueError: If `num_row_oov_buckets` or `num_col_oov_buckets` < 0.
ValueError: If either `old_row_vocab_file` or `new_row_vocab_file` is
provided, while the other is not. Same for `old_col_vocab_file` and
`new_col_vocab_file`.
ValueError: If neither row vocabs nor col vocabs are provided.
"""
if num_row_oov_buckets < 0:
raise ValueError("num_row_oov_buckets must be >= 0, but received %d" %
num_row_oov_buckets)
if num_col_oov_buckets < 0:
raise ValueError("num_col_oov_buckets must be >= 0, but received %d" %
num_col_oov_buckets)
if bool(old_row_vocab_file) != bool(new_row_vocab_file):
raise ValueError(
"old_row_vocab_file and new_row_vocab_file must both be specified or "
"left unspecified. old_row_vocab_file='{}', new_row_vocab_file='{}'".
format(old_row_vocab_file, new_row_vocab_file))
if bool(old_col_vocab_file) != bool(new_col_vocab_file):
raise ValueError(
"old_col_vocab_file and new_col_vocab_file must both be specified or "
"left unspecified. old_col_vocab_file='{}', new_col_vocab_file='{}'".
format(old_col_vocab_file, new_col_vocab_file))
remap_rows = new_row_vocab_file and old_row_vocab_file
remap_cols = new_col_vocab_file and old_col_vocab_file
if not (remap_rows or remap_cols):
raise ValueError(
"Must provide either row or column vocab files. If no remapping is "
"necessary, consider using `tf.contrib.framework.init_from_checkpoint` "
"instead.")
num_rows_present = num_rows_to_load
if remap_rows:
row_remapping, num_rows_present = (
gen_checkpoint_ops._generate_vocab_remapping( # pylint: disable=protected-access
new_vocab_file=new_row_vocab_file,
old_vocab_file=old_row_vocab_file,
new_vocab_offset=new_row_vocab_offset,
num_new_vocab=num_rows_to_load))
else:
# Even when the rows are not being reordered, we still need to generate a
# remapping to account for initializing partitioned Variables (when
# new_row_vocab_offset is non-zero).
row_remapping = math_ops.range(
new_row_vocab_offset,
new_row_vocab_offset + num_rows_to_load,
dtype=dtypes.int64)
col_remapping = []
num_cols_present = new_col_vocab_size
if remap_cols:
col_remapping, num_cols_present = (
gen_checkpoint_ops._generate_vocab_remapping( # pylint: disable=protected-access
new_vocab_file=new_col_vocab_file,
old_vocab_file=old_col_vocab_file,
new_vocab_offset=0, # Offset is unused for cols (no partitioning).
num_new_vocab=new_col_vocab_size))
init_vals = initializer([
num_rows_to_load * new_col_vocab_size -
num_rows_present * num_cols_present, 1
])
return_tensor = gen_checkpoint_ops._load_and_remap_matrix( # pylint: disable=protected-access
ckpt_path=ckpt_path,
old_tensor_name=old_tensor_name,
row_remapping=row_remapping,
col_remapping=col_remapping,
initializing_values=init_vals,
num_rows=num_rows_to_load,
num_cols=new_col_vocab_size,
max_rows_in_memory=max_rows_in_memory)
# Add OOV row(s) and column(s).
if num_row_oov_buckets > 0:
init_row_oov_val = initializer([num_row_oov_buckets, new_col_vocab_size])
init_row_oov_val = ops.convert_to_tensor(init_row_oov_val)
return_tensor = array_ops.concat([return_tensor, init_row_oov_val], 0)
if num_col_oov_buckets > 0:
# We need to add any row OOV to the new column shape.
init_col_oov_val = initializer(
[num_rows_to_load + num_row_oov_buckets, num_col_oov_buckets])
init_col_oov_val = ops.convert_to_tensor(init_col_oov_val)
return_tensor = array_ops.concat([return_tensor, init_col_oov_val], 1)
return return_tensor
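# Quick numeric check of the missing-value count computed for `init_vals`
# above, under illustrative sizes: loading a 4x3 block where only 2 rows and
# 2 cols exist in the old checkpoint leaves 12 - 4 = 8 cells that must come
# from `initializer`.
num_rows_to_load, new_col_vocab_size = 4, 3
num_rows_present, num_cols_present = 2, 2
assert (num_rows_to_load * new_col_vocab_size -
        num_rows_present * num_cols_present) == 8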
def load_and_remap_matrix_initializer(ckpt_path,
old_tensor_name,
new_row_vocab_size,
new_col_vocab_size,
old_row_vocab_file=None,
new_row_vocab_file=None,
old_col_vocab_file=None,
new_col_vocab_file=None,
num_row_oov_buckets=0,
num_col_oov_buckets=0,
initializer=None,
max_rows_in_memory=-1):
r"""Returns a var initializer for loading and remapping a 2-D (matrix) tensor.
The returned initializer loads a 2-D (matrix) `Tensor` with name
`old_tensor_name` from the checkpoint at `ckpt_path`. It will reorder the
rows/columns according to the specified vocab files and append additional
out-of-vocabulary rows/columns according to the number of OOV buckets.
The format of the file at the `{old,new}_{row,col}_vocab_file` path should be
a text file, with each line containing a single entity within the vocabulary.
Let the function `line_of(f, "x")` return the 0-indexed line number of the
entity "x" in file f, and the function `entity_at(f, i)` return the entity at
line i of file f. Then, row i of the new output matrix will be taken from row
`line_of(old_row_vocab_file, entity_at(new_row_vocab_file, i))` of the old
matrix. If any entity in `new_row_vocab_file` is not found in
`old_row_vocab_file`, that row is considered a "missing" row, and its values
will be initialized using the `initializer` arg. The same logic also applies
for the columns.
For example, assuming that:
* `old_row_vocab_file` contains "mercury\nvenus\nmars"
* `new_row_vocab_file` contains "venus\njupiter\nmercury"
* `old_col_vocab_file` contains "good\nbetter\nbest"
* `new_col_vocab_file` contains "good\nbest\nfantastic"
* `initializer` returns the natural numbers `[1, 2, 3, 4, ...]`
* `w(i, j)` represents the value from row i, column j of the old matrix
Then the new output matrix will look like:
`[[w(1, 0), w(1, 2), 1],
[2, 3, 4],
[w(0, 0), w(0, 2), 5]]`
If we further specify that:
* `num_row_oov_buckets` == 2
* `num_col_oov_buckets` == 1
Then the new output matrix will look like:
`[[w(1, 0), w(1, 2), 1, 12],
[2, 3, 4, 13],
[w(0, 0), w(0, 2), 5, 14],
[6, 7, 8, 15],
[9, 10, 11, 16]]`
If `{old,new}_row_vocab_file` are None, we assume that the old and new row
vocab files are the same, and no row remapping is done. If
`{old,new}_col_vocab_file` are None, we assume that the old and new column
vocab files are the same, and no column remapping is done.
The returned initializer only supports div-partitioning along the row axis. It
does not support partitioning along the column axis or mod-partitioning.
NOTE: When this is used to warm-start variables, client code should use
`tf.lookup.index_table_from_tensor()` like
contrib/layers/python/layers/feature_column.py does, as opposed to
`tf.feature_to_id()` - in order to ensure the underlying lookup tables are the
same.
Args:
ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
from which the old matrix `Tensor` will be loaded.
old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
new_row_vocab_size: `int` specifying the number of entries in
`new_row_vocab_file`. If no row remapping is needed (no row vocab
provided), this should be equal to the number of rows to load from the old
matrix (which can theoretically be smaller than the number of rows in the
old matrix).
new_col_vocab_size: `int` specifying the number of entries in
`new_col_vocab_file`. If no column remapping is needed (no column vocab
provided), this should be equal to the number of columns in the old
matrix.
old_row_vocab_file: A scalar `Tensor` of type `string` containing the
path to the old row vocabulary file. Can be None, which represents no
remapping on the row axis.
new_row_vocab_file: A scalar `Tensor` of type `string` containing the path
to the new row vocabulary file. Can be None, which represents no remapping
on the row axis.
old_col_vocab_file: A scalar `Tensor` of type `string` containing the
path to the old column vocabulary file. Can be None, which represents no
remapping on the column axis.
new_col_vocab_file: A scalar `Tensor` of type `string` containing the path
to the new column vocabulary file. Can be None, which represents no
remapping on the column axis.
num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows
to append. Must be >= 0.
num_col_oov_buckets: `int` specifying the number of out-of-vocabulary
columns to append. Must be >= 0.
initializer: Initializer function to initialize missing values. Accepts a
1-D tensor as the arg to specify the shape of the returned tensor. If
`None`, defaults to using `zeros_initializer()`.
max_rows_in_memory: `int` specifying the maximum number of rows to load from
the checkpoint at once. If less than or equal to 0, the entire matrix will
be loaded into memory. Setting this arg trades increased disk reads for
lower memory usage.
Returns:
A variable initializer function that should be used to initialize a
(potentially partitioned) `Variable` whose complete shape is
`[new_row_vocab_size + num_row_oov_buckets, new_col_vocab_size +
num_col_oov_buckets]`.
Raises:
TypeError: If `initializer` is specified but not callable.
"""
if initializer is None:
# TODO(b/25671353): Consider using sqrt(6/(fan_in + fan_out)) instead, from
# Glorot and Bengio, 2010.
initializer = init_ops.zeros_initializer()
if not callable(initializer):
raise TypeError(
"initializer must be callable, instead of being {} of type {}.".format(
initializer, type(initializer)))
def _initializer(shape, dtype=dtypes.float32, partition_info=None):
"""Variable initializer.
Args:
shape: Shape of `Tensor` to return. Should include OOV on both axes.
dtype: Must be float32.
partition_info: variable_scope._PartitionInfo.
Returns:
`Tensor` of shape `shape`.
Raises:
TypeError: If `dtype` is anything other than float32.
ValueError: For shape mismatch upon invocation.
"""
# Sanity checks.
if dtype != dtypes.float32:
raise TypeError(
"Currently, only float32 is supported. Received dtype: {}".format(
dtype))
if len(shape) != 2:
raise ValueError("Expected 2-dim shape, but received: {}".format(shape))
if shape[0] <= 0:
raise ValueError(
"Expected 1st dim of shape to be > 0, but received shape: {}".format(
shape))
if shape[1] != (new_col_vocab_size + num_col_oov_buckets):
raise ValueError(
"Expected 2nd dim of shape to be new_col_vocab_size ({}) + "
"num_col_oov_buckets ({}) = {}, but received shape: {}".format(
new_col_vocab_size, num_col_oov_buckets,
new_col_vocab_size + num_col_oov_buckets, shape))
offset = 0
if partition_info is not None:
offset = partition_info.single_offset(shape)
if offset + shape[0] > new_row_vocab_size + num_row_oov_buckets:
raise ValueError(
"Trying to initialize {} additional rows after {} rows have already "
"been initialized, which would exceed expected total row count of "
"new_row_vocab_size ({}) + num_row_oov_buckets ({}) = {}.".format(
shape[0], offset, new_row_vocab_size, num_row_oov_buckets,
new_row_vocab_size + num_row_oov_buckets))
row_oov_buckets_to_use = min(shape[0],
max(0, offset + shape[0] - new_row_vocab_size))
num_rows_to_load = shape[0] - row_oov_buckets_to_use
return _load_and_remap_matrix(
ckpt_path=ckpt_path,
old_tensor_name=old_tensor_name,
new_row_vocab_offset=offset,
num_rows_to_load=num_rows_to_load,
new_col_vocab_size=new_col_vocab_size,
initializer=initializer,
old_row_vocab_file=old_row_vocab_file,
new_row_vocab_file=new_row_vocab_file,
old_col_vocab_file=old_col_vocab_file,
new_col_vocab_file=new_col_vocab_file,
num_row_oov_buckets=row_oov_buckets_to_use,
num_col_oov_buckets=num_col_oov_buckets,
max_rows_in_memory=max_rows_in_memory)
return _initializer
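# Hypothetical usage sketch: the checkpoint path, tensor name, and vocab
# files below are placeholders. The variable's complete shape must be
# [new_row_vocab_size + num_row_oov_buckets,
#  new_col_vocab_size + num_col_oov_buckets] = [101, 10] here.
init = load_and_remap_matrix_initializer(
    ckpt_path='/tmp/model.ckpt',
    old_tensor_name='layer0/weights',
    new_row_vocab_size=100,
    new_col_vocab_size=10,
    old_row_vocab_file='/tmp/old_rows.txt',
    new_row_vocab_file='/tmp/new_rows.txt',
    num_row_oov_buckets=1)
weights = tf.get_variable(  # assumes `import tensorflow as tf` in caller code
    'remapped_weights', shape=[101, 10], initializer=init)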
def load_embedding_initializer(ckpt_path,
embedding_tensor_name,
new_vocab_size,
embedding_dim,
old_vocab_file,
new_vocab_file,
num_oov_buckets=0,
initializer=None,
max_rows_in_memory=-1):
"""Returns a variable initializer for loading pre-trained embeddings.
Wrapper around `load_and_remap_matrix_initializer()` specialized for loading
embedding weights and remapping according to the provided vocab files. See
docs for `load_and_remap_matrix_initializer()` for more details.
NOTE: Only for use with div-partitioned variables / vocabularies.
Args:
ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
from which the old matrix `Tensor` will be loaded.
embedding_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
new_vocab_size: Number of entries in the new vocab.
embedding_dim: `int` specifying the dimension of the embedding vectors from
the checkpoint. Must match the number of columns in the old embedding
matrix.
old_vocab_file: A scalar `Tensor` of type `string` containing the
path to the old vocabulary file.
new_vocab_file: A scalar `Tensor` of type `string` containing the
path to the new vocabulary file.
num_oov_buckets: `int` specifying the number of out-of-vocabulary
buckets to use. Must be >= 0.
initializer: Initializer function that accepts a 1-D tensor as the arg to
specify the shape of the returned tensor. If `None`, defaults to using
`truncated_normal_initializer()`.
max_rows_in_memory: `int` specifying the maximum number of rows to load from
the checkpoint at once. If less than or equal to 0, the entire matrix will
be loaded into memory. Setting this arg trades increased disk reads for
lower memory usage.
Returns:
A variable initializer function.
"""
if initializer is None:
# TODO(b/25671353): This should be kept in sync with the stddev used by
# feature_column.py's _EmbeddingColumn.
initializer = init_ops.truncated_normal_initializer(
stddev=1.0 / math.sqrt(embedding_dim))
return load_and_remap_matrix_initializer(
ckpt_path=ckpt_path,
old_tensor_name=embedding_tensor_name,
new_row_vocab_size=new_vocab_size,
new_col_vocab_size=embedding_dim,
old_row_vocab_file=old_vocab_file,
new_row_vocab_file=new_vocab_file,
old_col_vocab_file=None,
new_col_vocab_file=None,
num_row_oov_buckets=num_oov_buckets,
num_col_oov_buckets=0,
initializer=initializer,
max_rows_in_memory=max_rows_in_memory)
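# Matching hypothetical sketch for the embedding specialization; all file
# names are placeholders. Warm-starts a [new_vocab_size + num_oov_buckets,
# embedding_dim] = [10001, 64] embedding variable from an old checkpoint.
emb_init = load_embedding_initializer(
    ckpt_path='/tmp/model.ckpt',
    embedding_tensor_name='embeddings/weights',
    new_vocab_size=10000,
    embedding_dim=64,
    old_vocab_file='/tmp/old_vocab.txt',
    new_vocab_file='/tmp/new_vocab.txt',
    num_oov_buckets=1)
embeddings = tf.get_variable(  # assumes `import tensorflow as tf` in caller
    'embeddings', shape=[10001, 64], initializer=emb_init)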
# pylint: disable=protected-access,line-too-long
load_and_remap_matrix_initializer = checkpoint_ops._load_and_remap_matrix_initializer
# pylint: enable=line-too-long
load_embedding_initializer = checkpoint_ops._load_embedding_initializer
# pylint: enable=protected-access
def load_linear_multiclass_bias_initializer(ckpt_path,

View File

@ -21,7 +21,6 @@ import os
import numpy as np
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib.framework.python.ops import checkpoint_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -38,250 +37,6 @@ FLAGS = flags.FLAGS
_TESTDATA_PATH = 'contrib/framework/testdata'
class LoadAndRemapWrappersTest(test.TestCase):
"""Tests for the functionality of the Python wrappers."""
def setUp(self):
self.bundle_file = os.path.join(
test.test_src_dir_path(_TESTDATA_PATH), 'bundle_checkpoint')
self.new_feature_vocab_file = os.path.join(
test.test_src_dir_path(_TESTDATA_PATH), 'bundle_checkpoint_vocab.txt')
self.old_feature_vocab_file = os.path.join(
test.test_src_dir_path(_TESTDATA_PATH),
'bundle_checkpoint_vocab_with_oov.txt')
self.new_class_vocab_file = os.path.join(
test.test_src_dir_path(_TESTDATA_PATH), 'keyword_new.txt')
self.old_class_vocab_file = os.path.join(
test.test_src_dir_path(_TESTDATA_PATH), 'keyword.txt')
self.init_val = 42
def _init_val_initializer(shape, dtype=None, partition_info=None):
del dtype, partition_info # Unused by this unit-testing initializer.
return array_ops.tile(
constant_op.constant([[self.init_val]], dtype=dtypes.float32), shape)
self.initializer = _init_val_initializer
def test_load_and_remap_matrix(self):
"""Tests the end-to-end loading / remapping of weights."""
# _load_and_remap_matrix() is the generalized wrapper that takes in row and
# column vocabulary files, calls the relevant remappings, and returns the
# weight matrix. Take this example to be linear multi-class by providing
# both row and column vocabularies.
remapped_matrix = checkpoint_ops._load_and_remap_matrix(
new_row_vocab_file=self.new_feature_vocab_file,
old_row_vocab_file=self.old_feature_vocab_file,
num_rows_to_load=4,
new_col_vocab_file=self.new_class_vocab_file,
old_col_vocab_file=self.old_class_vocab_file,
new_col_vocab_size=4,
old_tensor_name='some_scope/embeddings',
ckpt_path=[self.bundle_file],
new_row_vocab_offset=1,
initializer=self.initializer,
num_row_oov_buckets=1,
num_col_oov_buckets=1)
# [4 in vocab + 1 oov features, 4 in vocab + 1 oov classes]. The offset
# means we read the 4 loaded rows starting from index 1 of the new row vocab.
expected_remapped_matrix = np.concatenate(
[
np.reshape([18, 34, 50, self.init_val, self.init_val], [5, 1]),
np.reshape([16, 32, 48, self.init_val, self.init_val], [5, 1]),
np.reshape([self.init_val] * 5, [5, 1]),
np.reshape([17, 33, 49, self.init_val, self.init_val], [5, 1]),
np.reshape([self.init_val] * 5, [5, 1])
],
axis=1)
with self.test_session():
self.assertAllClose(expected_remapped_matrix, remapped_matrix.eval())
def test_load_and_remap_output_layer_weight_initializer_linear(self):
"""Tests for the output layer initializer in the linear multi-class case."""
loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
new_row_vocab_size=5,
new_col_vocab_file=self.new_class_vocab_file,
old_col_vocab_file=self.old_class_vocab_file,
new_col_vocab_size=4,
old_tensor_name='some_scope/embeddings',
ckpt_path=[self.bundle_file],
new_row_vocab_file=self.new_feature_vocab_file,
old_row_vocab_file=self.old_feature_vocab_file,
num_row_oov_buckets=1,
num_col_oov_buckets=1,
initializer=self.initializer))
expected_remapped_matrix = np.concatenate(
[
np.reshape([2, 18, 34, 50, self.init_val, self.init_val], [6, 1]),
np.reshape([0, 16, 32, 48, self.init_val, self.init_val], [6, 1]),
np.reshape([self.init_val] * 6, [6, 1]),
np.reshape([1, 17, 33, 49, self.init_val, self.init_val], [6, 1]),
np.reshape([self.init_val] * 6, [6, 1])
],
axis=1)
# The new weight matrix is of size
# [5 feature vocab + 1 feature OOV, 4 class vocab + 1 class OOV]. Use a
# partitioned variable to confirm that the offset logic works.
remapped_matrix = variable_scope.get_variable(
name='linear/obtained_weight_matrix',
shape=[6, 5],
initializer=loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
with self.test_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_matrix,
remapped_matrix.as_tensor().eval())
def test_load_and_remap_output_layer_weight_initializer_dnn_output(self):
"""Tests for the output layer initializer in the DNN output case."""
loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
new_row_vocab_size=5,
new_col_vocab_file=self.new_class_vocab_file,
old_col_vocab_file=self.old_class_vocab_file,
new_col_vocab_size=4,
old_tensor_name='some_scope/embeddings',
ckpt_path=[self.bundle_file],
num_col_oov_buckets=1,
initializer=self.initializer))
expected_remapped_matrix = np.concatenate(
[
np.reshape([2, 18, 34, 50, 66], [5, 1]),
np.reshape([0, 16, 32, 48, 64], [5, 1]),
np.reshape([self.init_val] * 5, [5, 1]),
np.reshape([1, 17, 33, 49, 65], [5, 1]),
np.reshape([self.init_val] * 5, [5, 1])
],
axis=1)
# The new weight matrix is of size
# [5-sized input layer, 4 class vocab + 1 class OOV].
remapped_matrix = variable_scope.get_variable(
name='dnn_output/obtained_weight_matrix',
shape=[5, 5],
initializer=loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
with self.test_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_matrix,
remapped_matrix.as_tensor().eval())
def test_initializer_with_oov_only_partition(self):
"""Tests for the output layer initializer where one partition is all OOV."""
loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
new_row_vocab_size=5,
new_col_vocab_file=self.new_class_vocab_file,
old_col_vocab_file=self.old_class_vocab_file,
new_col_vocab_size=4,
old_tensor_name='some_scope/embeddings',
ckpt_path=[self.bundle_file],
new_row_vocab_file=self.new_feature_vocab_file,
old_row_vocab_file=self.old_feature_vocab_file,
num_row_oov_buckets=5,
num_col_oov_buckets=1,
initializer=self.initializer))
expected_remapped_matrix = np.concatenate(
[
np.reshape([2, 18, 34, 50] + [self.init_val] * 6, [10, 1]),
np.reshape([0, 16, 32, 48] + [self.init_val] * 6, [10, 1]),
np.reshape([self.init_val] * 10, [10, 1]),
np.reshape([1, 17, 33, 49] + [self.init_val] * 6, [10, 1]),
np.reshape([self.init_val] * 10, [10, 1]),
],
axis=1)
# The new weight matrix is of size
# [5 feature vocab + 5 feature OOV, 4 class vocab + 1 class OOV]. The
# second partition has only OOV.
remapped_matrix = variable_scope.get_variable(
name='linear_all_oov/obtained_weight_matrix',
shape=[10, 5],
initializer=loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
with self.test_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_matrix,
remapped_matrix.as_tensor().eval())
def test_load_and_remap_linear_multiclass_initializer_default_init(self):
"""Tests where the zeros_initializer default is used for linear."""
loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
new_row_vocab_size=5,
new_col_vocab_file=self.new_class_vocab_file,
old_col_vocab_file=self.old_class_vocab_file,
new_col_vocab_size=4,
old_tensor_name='some_scope/embeddings',
ckpt_path=[self.bundle_file],
new_row_vocab_file=self.new_feature_vocab_file,
old_row_vocab_file=self.old_feature_vocab_file,
num_row_oov_buckets=1,
num_col_oov_buckets=1))
expected_remapped_matrix = np.concatenate(
[
np.reshape([2, 18, 34, 50, 0, 0], [6, 1]),
np.reshape([0, 16, 32, 48, 0, 0], [6, 1]),
np.reshape([0] * 6, [6, 1]),
np.reshape([1, 17, 33, 49, 0, 0], [6, 1]),
np.reshape([0] * 6, [6, 1])
],
axis=1)
remapped_matrix = variable_scope.get_variable(
name='linear_init_fallback/obtained_weight_matrix',
shape=[6, 5],
initializer=loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
with self.test_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_matrix,
remapped_matrix.as_tensor().eval())
def test_load_embedding_initializer(self):
"""Tests for the load_embedding_initializer wrapper."""
embedding_loading_initializer = (
contrib_framework.load_embedding_initializer(
new_vocab_file=self.new_feature_vocab_file,
old_vocab_file=self.old_feature_vocab_file,
new_vocab_size=5,
embedding_dim=16,
embedding_tensor_name='some_scope/embeddings',
ckpt_path=[self.bundle_file],
num_oov_buckets=1,
initializer=self.initializer))
expected_remapped_embeddings = np.concatenate(
[
np.reshape(range(64), [4, 16]),
np.reshape([self.init_val] * 32, [2, 16]),
],
axis=0)
# The new weight matrix is of size
# [5 feature vocab + 1 feature OOV, 16 (embedding dimension)], where the
# last vocab row (2nd last row) is newly initialized (wasn't found in
# previous vocab) and the actual last row is OOV and also newly initialized.
# Use a partitioned variable to confirm that the offset logic works.
remapped_embeddings = variable_scope.get_variable(
name='embedding/obtained_embedding_matrix',
shape=[6, 16],
initializer=embedding_loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
with self.test_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_embeddings,
remapped_embeddings.as_tensor().eval())
class LoadMulticlassBiasTest(test.TestCase):
"""Tests for the load_linear_multiclass_bias_initializer functionality."""

View File

@ -0,0 +1,27 @@
package(default_visibility = ["//tensorflow:__subpackages__"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
py_library(
name = "gan",
srcs = [
"__init__.py",
],
srcs_version = "PY2AND3",
deps = [
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,4 @@
This directory contains the TFGAN project.
This file will have more details as code is added.

View File

@ -0,0 +1,19 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFGAN grouped API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

View File

@ -62,6 +62,7 @@ tf_cuda_library(
}),
deps = [
":gdr_proto_cc",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_runtime",
"//tensorflow/core:lib",

View File

@ -121,12 +121,9 @@ tf_gen_op_wrapper_py(
cc_library(
name = "image_ops_cc",
srcs = [
"ops/image_ops.cc",
],
srcs = ["ops/image_ops.cc"],
deps = [
":image_ops_kernels",
"//tensorflow/core",
"//tensorflow/core:framework",
],
alwayslink = 1,

View File

@ -551,6 +551,7 @@ py_test(
size = "small",
srcs = ["python/keras/utils/io_utils_test.py"],
srcs_version = "PY2AND3",
tags = ["notsan"],
deps = [
":keras",
"//tensorflow/python:client_testlib",

View File

@ -57,43 +57,44 @@ class TestIOUtils(test.TestCase):
h5_path = os.path.join(temp_dir, 'test.h5')
create_dataset(h5_path)
# Instantiating HDF5Matrix for the training set,
# which is a slice of the first 150 elements
x_train = keras.utils.io_utils.HDF5Matrix(
h5_path, 'my_data', start=0, end=150)
y_train = keras.utils.io_utils.HDF5Matrix(
h5_path, 'my_labels', start=0, end=150)
with self.test_session():
# Instantiating HDF5Matrix for the training set,
# which is a slice of the first 150 elements
x_train = keras.utils.io_utils.HDF5Matrix(
h5_path, 'my_data', start=0, end=150)
y_train = keras.utils.io_utils.HDF5Matrix(
h5_path, 'my_labels', start=0, end=150)
# Likewise for the test set
x_test = keras.utils.io_utils.HDF5Matrix(
h5_path, 'my_data', start=150, end=200)
y_test = keras.utils.io_utils.HDF5Matrix(
h5_path, 'my_labels', start=150, end=200)
# Likewise for the test set
x_test = keras.utils.io_utils.HDF5Matrix(
h5_path, 'my_data', start=150, end=200)
y_test = keras.utils.io_utils.HDF5Matrix(
h5_path, 'my_labels', start=150, end=200)
# HDF5Matrix behave more or less like Numpy matrices
# with regard to indexing
self.assertEqual(y_train.shape, (150, 1))
# But they do not support negative indices, so don't try print(x_train[-1])
# HDF5Matrix behave more or less like Numpy matrices
# with regard to indexing
self.assertEqual(y_train.shape, (150, 1))
# But they don't support negative indices, so don't try print(x_train[-1])
self.assertEqual(y_train.dtype, np.dtype('i'))
self.assertEqual(y_train.ndim, 2)
self.assertEqual(y_train.size, 150)
self.assertEqual(y_train.dtype, np.dtype('i'))
self.assertEqual(y_train.ndim, 2)
self.assertEqual(y_train.size, 150)
model = keras.models.Sequential()
model.add(keras.layers.Dense(64, input_shape=(10,), activation='relu'))
model.add(keras.layers.Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='sgd')
model = keras.models.Sequential()
model.add(keras.layers.Dense(64, input_shape=(10,), activation='relu'))
model.add(keras.layers.Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='sgd')
# Note: you have to use shuffle='batch' or False with HDF5Matrix
model.fit(x_train, y_train, batch_size=32, shuffle='batch', verbose=False)
# test that evaluation and prediction
# don't crash and return reasonable results
out_pred = model.predict(x_test, batch_size=32, verbose=False)
out_eval = model.evaluate(x_test, y_test, batch_size=32, verbose=False)
# Note: you have to use shuffle='batch' or False with HDF5Matrix
model.fit(x_train, y_train, batch_size=32, shuffle='batch', verbose=False)
# test that evaluation and prediction
# don't crash and return reasonable results
out_pred = model.predict(x_test, batch_size=32, verbose=False)
out_eval = model.evaluate(x_test, y_test, batch_size=32, verbose=False)
self.assertEqual(out_pred.shape, (50, 1))
self.assertEqual(out_eval.shape, ())
self.assertGreater(out_eval, 0)
self.assertEqual(out_pred.shape, (50, 1))
self.assertEqual(out_eval.shape, ())
self.assertGreater(out_eval, 0)
if __name__ == '__main__':

View File

@ -28,7 +28,6 @@ import six
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
@ -44,7 +43,7 @@ def _get_in_out_shape(x_shape, y_shape, n_classes, batch_size=None):
x_is_dict, y_is_dict = isinstance(
x_shape, dict), y_shape is not None and isinstance(y_shape, dict)
if y_is_dict and n_classes is not None:
assert (isinstance(n_classes, dict))
assert isinstance(n_classes, dict)
if batch_size is None:
batch_size = list(x_shape.values())[0][0] if x_is_dict else x_shape[0]
@ -322,10 +321,12 @@ class DataFeeder(object):
self._x = dict([(k, check_array(v, v.dtype)) for k, v in list(x.items())
]) if x_is_dict else check_array(x, x.dtype)
self._y = None if y is None else \
dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) if y_is_dict else check_array(y, y.dtype)
self._y = None if y is None else (
dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())])
if y_is_dict else check_array(y, y.dtype))
# self.n_classes is not None means we're converting raw target indices to one-hot.
# self.n_classes is not None means we're converting raw target indices
# to one-hot.
if n_classes is not None:
if not y_is_dict:
y_dtype = (np.int64
@ -344,12 +345,15 @@ class DataFeeder(object):
x_shape, y_shape, n_classes, batch_size)
# Input dtype matches dtype of x.
self._input_dtype = dict([(k, _check_dtype(v.dtype)) for k, v in list(self._x.items())]) if x_is_dict \
else _check_dtype(self._x.dtype)
self._input_dtype = (
dict([(k, _check_dtype(v.dtype)) for k, v in list(self._x.items())])
if x_is_dict else _check_dtype(self._x.dtype))
# note: self._output_dtype = np.float32 when y is None
self._output_dtype = dict([(k, _check_dtype(v.dtype)) for k, v in list(self._y.items())]) if y_is_dict \
else _check_dtype(self._y.dtype) if y is not None else np.float32
# self._output_dtype == np.float32 when y is None
self._output_dtype = (
dict([(k, _check_dtype(v.dtype)) for k, v in list(self._y.items())])
if y_is_dict else (
_check_dtype(self._y.dtype) if y is not None else np.float32))
# self.n_classes is None means we're passing in raw target indices
if n_classes is not None and y_is_dict:

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities supporting export to SavedModel.
Some contents of this file are moved to tensorflow/python/estimator/export.py:
@ -39,6 +38,7 @@ import time
from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.learn.python.learn import export_strategy
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import metric_key
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.learn.python.learn.utils import gc
from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
@ -75,8 +75,8 @@ FEATURES_INPUT_ALTERNATIVE_KEY = 'features_input_alternative'
_FALLBACK_DEFAULT_OUTPUT_ALTERNATIVE_KEY = 'default_output_alternative'
def build_standardized_signature_def(
input_tensors, output_tensors, problem_type):
def build_standardized_signature_def(input_tensors, output_tensors,
problem_type):
"""Build a SignatureDef using problem type and input and output Tensors.
Note that this delegates the actual creation of the signatures to methods in
@ -116,8 +116,8 @@ def build_standardized_signature_def(
(_, predictions), = output_tensors.items()
return signature_def_utils.regression_signature_def(examples, predictions)
else:
return signature_def_utils.predict_signature_def(
input_tensors, output_tensors)
return signature_def_utils.predict_signature_def(input_tensors,
output_tensors)
def _get_classification_scores(output_tensors):
@ -139,17 +139,15 @@ def _is_classification_problem(problem_type, input_tensors, output_tensors):
classes = _get_classification_classes(output_tensors)
scores = _get_classification_scores(output_tensors)
return ((problem_type == constants.ProblemType.CLASSIFICATION or
problem_type == constants.ProblemType.LOGISTIC_REGRESSION)
and len(input_tensors) == 1
and (classes is not None or
scores is not None or
len(output_tensors) == 1))
problem_type == constants.ProblemType.LOGISTIC_REGRESSION) and
len(input_tensors) == 1 and
(classes is not None or scores is not None or
len(output_tensors) == 1))
def _is_regression_problem(problem_type, input_tensors, output_tensors):
return (problem_type == constants.ProblemType.LINEAR_REGRESSION
and len(input_tensors) == 1
and len(output_tensors) == 1)
return (problem_type == constants.ProblemType.LINEAR_REGRESSION and
len(input_tensors) == 1 and len(output_tensors) == 1)
def get_input_alternatives(input_ops):
@ -177,9 +175,7 @@ def get_input_alternatives(input_ops):
return input_alternatives, features
def get_output_alternatives(
model_fn_ops,
default_output_alternative_key=None):
def get_output_alternatives(model_fn_ops, default_output_alternative_key=None):
"""Obtain all output alternatives using the model_fn output and heuristics.
Args:
@ -218,8 +214,10 @@ def get_output_alternatives(
default_outputs = {prediction_key.PredictionKey.GENERIC: default_outputs}
actual_default_output_alternative_key = (
_FALLBACK_DEFAULT_OUTPUT_ALTERNATIVE_KEY)
output_alternatives = {actual_default_output_alternative_key:
(default_problem_type, default_outputs)}
output_alternatives = {
actual_default_output_alternative_key: (default_problem_type,
default_outputs)
}
return output_alternatives, actual_default_output_alternative_key
if default_output_alternative_key:
@ -246,13 +244,12 @@ def build_all_signature_defs(input_alternatives, output_alternatives,
actual_default_output_alternative_key):
"""Build `SignatureDef`s from all pairs of input and output alternatives."""
signature_def_map = {
('%s:%s' % (input_key, output_key or 'None')):
build_standardized_signature_def(
inputs, outputs, problem_type)
for input_key, inputs in input_alternatives.items()
for output_key, (problem_type, outputs)
in output_alternatives.items()}
signature_def_map = {('%s:%s' % (input_key, output_key or 'None')):
build_standardized_signature_def(inputs, outputs,
problem_type)
for input_key, inputs in input_alternatives.items()
for output_key, (problem_type,
outputs) in output_alternatives.items()}
# Add the default SignatureDef
default_inputs = input_alternatives.get(DEFAULT_INPUT_ALTERNATIVE_KEY)
@ -263,8 +260,8 @@ def build_all_signature_defs(input_alternatives, output_alternatives,
(default_problem_type, default_outputs) = (
output_alternatives[actual_default_output_alternative_key])
signature_def_map[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
build_standardized_signature_def(
default_inputs, default_outputs, default_problem_type))
build_standardized_signature_def(default_inputs, default_outputs,
default_problem_type))
return signature_def_map
@ -308,9 +305,8 @@ def get_timestamped_export_dir(export_dir_base):
return export_dir
time.sleep(1)
attempts += 1
logging.warn(
'Export directory {} already exists; retrying (attempt {}/{})'.format(
export_dir, attempts, MAX_DIRECTORY_CREATION_ATTEMPTS))
logging.warn('Export directory {} already exists; retrying (attempt {}/{})'.
format(export_dir, attempts, MAX_DIRECTORY_CREATION_ATTEMPTS))
raise RuntimeError('Failed to obtain a unique export directory name after '
'{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS))
@ -330,8 +326,7 @@ def get_temp_export_dir(timestamped_export_dir):
"""
(dirname, basename) = os.path.split(timestamped_export_dir)
temp_export_dir = os.path.join(
compat.as_bytes(dirname),
compat.as_bytes('temp-{}'.format(basename)))
compat.as_bytes(dirname), compat.as_bytes('temp-{}'.format(basename)))
return temp_export_dir
@ -357,8 +352,8 @@ def get_most_recent_export(export_dir_base):
A gc.Path, with is just a namedtuple of (path, export_version).
"""
select_filter = gc.largest_export_versions(1)
results = select_filter(gc.get_paths(export_dir_base,
parser=_export_version_parser))
results = select_filter(
gc.get_paths(export_dir_base, parser=_export_version_parser))
return next(iter(results or []), None)
@ -378,8 +373,8 @@ def garbage_collect_exports(export_dir_base, exports_to_keep):
keep_filter = gc.largest_export_versions(exports_to_keep)
delete_filter = gc.negation(keep_filter)
for p in delete_filter(gc.get_paths(export_dir_base,
parser=_export_version_parser)):
for p in delete_filter(
gc.get_paths(export_dir_base, parser=_export_version_parser)):
try:
gfile.DeleteRecursively(p.path)
except errors_impl.NotFoundError as e:
@ -416,10 +411,7 @@ def make_export_strategy(serving_input_fn,
An ExportStrategy that can be passed to the Experiment constructor.
"""
def export_fn(estimator,
export_dir_base,
checkpoint_path=None
):
def export_fn(estimator, export_dir_base, checkpoint_path=None):
"""Exports the given Estimator as a SavedModel.
Args:
@ -512,3 +504,128 @@ def make_parsing_export_strategy(feature_columns,
assets_extra=assets_extra,
as_text=as_text,
exports_to_keep=exports_to_keep)
def _default_compare_fn(curr_best_eval_result, cand_eval_result):
"""Compares two evaluation results and returns true if the 2nd one is better.
Both evaluation results should have the values for MetricKey.LOSS, which are
used for comparison.
Args:
curr_best_eval_result: current best eval metrics.
cand_eval_result: candidate eval metrics.
Returns:
True if cand_eval_result is better.
Raises:
ValueError: If input eval result is None or no loss is available.
"""
default_key = metric_key.MetricKey.LOSS
if not curr_best_eval_result or default_key not in curr_best_eval_result:
raise ValueError(
'curr_best_eval_result cannot be empty or no loss is found in it.')
if not cand_eval_result or default_key not in cand_eval_result:
raise ValueError(
'cand_eval_result cannot be empty or no loss is found in it.')
return curr_best_eval_result[default_key] > cand_eval_result[default_key]
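# Sanity check of the semantics above: a candidate with strictly lower loss
# is "better". Assumes MetricKey.LOSS resolves to 'loss'.
assert _default_compare_fn({'loss': 1.0}, {'loss': 0.5})      # candidate wins
assert not _default_compare_fn({'loss': 0.5}, {'loss': 1.0})  # best retained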
class BestModelSelector(object):
"""A helper that keeps track of export selection candidates."""
def __init__(self, compare_fn=None):
"""Constructor of this class.
Args:
compare_fn: a function that returns true if the candidate is better than
the current best model.
"""
self._best_eval_result = None
self._compare_fn = compare_fn or _default_compare_fn
def update(self, checkpoint_path, eval_result):
"""Records a given checkpoint and exports if this is the best model.
Args:
checkpoint_path: the checkpoint path to export.
eval_result: a dictionary which is usually generated in evaluation runs.
By default, eval_result contains the 'loss' field.
Returns:
A string representing the path to the checkpoint to be exported.
A dictionary of the same type as eval_result.
Raises:
ValueError: if checkpoint path is empty.
ValueError: if eval_result is None.
"""
if not checkpoint_path:
raise ValueError('Checkpoint path is empty.')
if eval_result is None:
raise ValueError('%s has empty evaluation results.' % checkpoint_path)
if (self._best_eval_result is None or
self._compare_fn(self._best_eval_result, eval_result)):
self._best_eval_result = eval_result
return checkpoint_path, eval_result
else:
return '', None
def make_best_model_export_strategy(serving_input_fn,
exports_to_keep=1,
compare_fn=None,
default_output_alternative_key=None):
"""Creates an custom ExportStrategy for use with tf.contrib.learn.Experiment.
Args:
serving_input_fn: a function that takes no arguments and returns an
`InputFnOps`.
exports_to_keep: an integer indicating how many historical best models need
to be preserved.
compare_fn: a function that selects the 'best' candidate from a dictionary
of evaluation results keyed by the corresponding checkpoint path.
default_output_alternative_key: the key for default serving signature for
multi-headed inference graphs.
Returns:
An ExportStrategy that can be passed to the Experiment constructor.
"""
best_model_export_strategy = make_export_strategy(
serving_input_fn,
exports_to_keep=exports_to_keep,
default_output_alternative_key=default_output_alternative_key)
best_model_selector = BestModelSelector(compare_fn)
def export_fn(estimator, export_dir_base, checkpoint_path, eval_result=None):
"""Exports the given Estimator as a SavedModel.
Args:
estimator: the Estimator to export.
export_dir_base: A string containing a directory to write the exported
graph and checkpoints.
checkpoint_path: The checkpoint path to export. If None (the default),
the most recent checkpoint found within the model directory is chosen.
eval_result: placeholder arg matching the call signature of ExportStrategy.
Returns:
The string path to the exported directory.
"""
export_checkpoint_path, export_eval_result = best_model_selector.update(
checkpoint_path, eval_result)
if export_checkpoint_path and export_eval_result is not None:
checkpoint_base = os.path.basename(export_checkpoint_path)
export_dir = os.path.join(export_dir_base, checkpoint_base)
return best_model_export_strategy.export(
estimator, export_dir, export_checkpoint_path, export_eval_result)
else:
return ''
return export_strategy.ExportStrategy('best_model', export_fn)
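# Hypothetical driver mirroring the call pattern in the tests below
# (estimator, serving_input_fn, paths, and losses are illustrative): only
# checkpoints that improve on the best loss so far produce a non-empty
# export directory.
strategy = make_best_model_export_strategy(
    serving_input_fn=my_serving_input_fn,  # assumed to return an InputFnOps
    exports_to_keep=3)
first = strategy.export(estimator, '/tmp/exports', 'ckpt-1', {'loss': 1.0})
worse = strategy.export(estimator, '/tmp/exports', 'ckpt-2', {'loss': 2.0})
# `first` is a real export path; `worse` is '' because the loss regressed.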

View File

@ -24,6 +24,7 @@ import time
from tensorflow.contrib.layers.python.layers import feature_column as fc
from tensorflow.contrib.learn.python.learn import export_strategy as export_strategy_lib
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator as core_estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
@ -40,18 +41,43 @@ from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.util import compat
class TestEstimator(core_estimator.Estimator):
def __init__(self, *args, **kwargs):
super(TestEstimator, self).__init__(*args, **kwargs)
self.last_exported_checkpoint = ""
self.last_exported_dir = ""
# @Override
def export_savedmodel(self,
export_dir,
serving_input_fn,
default_output_alternative_key=None,
assets_extra=None,
as_text=False,
checkpoint_path=None):
if not os.path.exists(export_dir):
os.makedirs(export_dir)
open(os.path.join(export_dir, "placeholder.txt"), "a").close()
self.last_exported_checkpoint = checkpoint_path
self.last_exported_dir = export_dir
return export_dir
class SavedModelExportUtilsTest(test.TestCase):
def test_build_standardized_signature_def_regression(self):
input_tensors = {
"input-1":
array_ops.placeholder(
dtypes.float32, 1, name="input-tensor-1")
array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1")
}
output_tensors = {
"output-1":
array_ops.placeholder(
dtypes.float32, 1, name="output-tensor-1")
array_ops.placeholder(dtypes.float32, 1, name="output-tensor-1")
}
problem_type = constants.ProblemType.LINEAR_REGRESSION
actual_signature_def = (
@ -61,10 +87,9 @@ class SavedModelExportUtilsTest(test.TestCase):
shape = tensor_shape_pb2.TensorShapeProto(
dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
dtype = types_pb2.DataType.Value("DT_FLOAT")
expected_signature_def.inputs[
signature_constants.REGRESS_INPUTS].CopyFrom(
meta_graph_pb2.TensorInfo(
name="input-tensor-1:0", dtype=dtype, tensor_shape=shape))
expected_signature_def.inputs[signature_constants.REGRESS_INPUTS].CopyFrom(
meta_graph_pb2.TensorInfo(
name="input-tensor-1:0", dtype=dtype, tensor_shape=shape))
expected_signature_def.outputs[
signature_constants.REGRESS_OUTPUTS].CopyFrom(
meta_graph_pb2.TensorInfo(
@ -77,13 +102,11 @@ class SavedModelExportUtilsTest(test.TestCase):
"""Tests classification with one output tensor."""
input_tensors = {
"input-1":
array_ops.placeholder(
dtypes.float32, 1, name="input-tensor-1")
array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1")
}
output_tensors = {
"output-1":
array_ops.placeholder(
dtypes.string, 1, name="output-tensor-1")
array_ops.placeholder(dtypes.string, 1, name="output-tensor-1")
}
problem_type = constants.ProblemType.CLASSIFICATION
actual_signature_def = (
@ -94,14 +117,14 @@ class SavedModelExportUtilsTest(test.TestCase):
dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
dtype_float = types_pb2.DataType.Value("DT_FLOAT")
dtype_string = types_pb2.DataType.Value("DT_STRING")
expected_signature_def.inputs[
signature_constants.CLASSIFY_INPUTS].CopyFrom(
meta_graph_pb2.TensorInfo(
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom(
meta_graph_pb2.TensorInfo(
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
expected_signature_def.outputs[
signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom(
meta_graph_pb2.TensorInfo(
name="output-tensor-1:0", dtype=dtype_string,
name="output-tensor-1:0",
dtype=dtype_string,
tensor_shape=shape))
expected_signature_def.method_name = (
@ -112,8 +135,7 @@ class SavedModelExportUtilsTest(test.TestCase):
"""Tests multiple output tensors that include classes and probabilities."""
input_tensors = {
"input-1":
array_ops.placeholder(
dtypes.float32, 1, name="input-tensor-1")
array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1")
}
output_tensors = {
"classes":
@ -136,19 +158,20 @@ class SavedModelExportUtilsTest(test.TestCase):
dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
dtype_float = types_pb2.DataType.Value("DT_FLOAT")
dtype_string = types_pb2.DataType.Value("DT_STRING")
expected_signature_def.inputs[
signature_constants.CLASSIFY_INPUTS].CopyFrom(
meta_graph_pb2.TensorInfo(
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom(
meta_graph_pb2.TensorInfo(
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
expected_signature_def.outputs[
signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom(
meta_graph_pb2.TensorInfo(
name="output-tensor-classes:0", dtype=dtype_string,
name="output-tensor-classes:0",
dtype=dtype_string,
tensor_shape=shape))
expected_signature_def.outputs[
signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom(
meta_graph_pb2.TensorInfo(
name="output-tensor-proba:0", dtype=dtype_float,
name="output-tensor-proba:0",
dtype=dtype_float,
tensor_shape=shape))
expected_signature_def.method_name = (
@ -159,8 +182,7 @@ class SavedModelExportUtilsTest(test.TestCase):
"""Tests multiple output tensors that include classes and scores."""
input_tensors = {
"input-1":
array_ops.placeholder(
dtypes.float32, 1, name="input-tensor-1")
array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1")
}
output_tensors = {
"classes":
@ -182,19 +204,20 @@ class SavedModelExportUtilsTest(test.TestCase):
dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
dtype_float = types_pb2.DataType.Value("DT_FLOAT")
dtype_string = types_pb2.DataType.Value("DT_STRING")
expected_signature_def.inputs[
signature_constants.CLASSIFY_INPUTS].CopyFrom(
meta_graph_pb2.TensorInfo(
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom(
meta_graph_pb2.TensorInfo(
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
expected_signature_def.outputs[
signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom(
meta_graph_pb2.TensorInfo(
name="output-tensor-classes:0", dtype=dtype_string,
name="output-tensor-classes:0",
dtype=dtype_string,
tensor_shape=shape))
expected_signature_def.outputs[
signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom(
meta_graph_pb2.TensorInfo(
name="output-tensor-scores:0", dtype=dtype_float,
name="output-tensor-scores:0",
dtype=dtype_float,
tensor_shape=shape))
expected_signature_def.method_name = (
@ -205,8 +228,7 @@ class SavedModelExportUtilsTest(test.TestCase):
"""Tests classification without classes tensor."""
input_tensors = {
"input-1":
array_ops.placeholder(
dtypes.float32, 1, name="input-tensor-1")
array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1")
}
output_tensors = {
"probabilities":
@ -224,14 +246,14 @@ class SavedModelExportUtilsTest(test.TestCase):
shape = tensor_shape_pb2.TensorShapeProto(
dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
dtype_float = types_pb2.DataType.Value("DT_FLOAT")
expected_signature_def.inputs[
signature_constants.CLASSIFY_INPUTS].CopyFrom(
meta_graph_pb2.TensorInfo(
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom(
meta_graph_pb2.TensorInfo(
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
expected_signature_def.outputs[
signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom(
meta_graph_pb2.TensorInfo(
name="output-tensor-proba:0", dtype=dtype_float,
name="output-tensor-proba:0",
dtype=dtype_float,
tensor_shape=shape))
expected_signature_def.method_name = (
@ -246,8 +268,7 @@ class SavedModelExportUtilsTest(test.TestCase):
"""
input_tensors = {
"input-1":
array_ops.placeholder(
dtypes.float32, 1, name="input-tensor-1")
array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1")
}
output_tensors = {
"classes":
@ -268,14 +289,14 @@ class SavedModelExportUtilsTest(test.TestCase):
shape = tensor_shape_pb2.TensorShapeProto(
dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
dtype_float = types_pb2.DataType.Value("DT_FLOAT")
expected_signature_def.inputs[
signature_constants.CLASSIFY_INPUTS].CopyFrom(
meta_graph_pb2.TensorInfo(
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom(
meta_graph_pb2.TensorInfo(
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
expected_signature_def.outputs[
signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom(
meta_graph_pb2.TensorInfo(
name="output-tensor-scores:0", dtype=dtype_float,
name="output-tensor-scores:0",
dtype=dtype_float,
tensor_shape=shape))
expected_signature_def.method_name = (
@ -290,8 +311,7 @@ class SavedModelExportUtilsTest(test.TestCase):
"""
input_tensors = {
"input-1":
array_ops.placeholder(
dtypes.float32, 1, name="input-tensor-1")
array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1")
}
output_tensors = {
"classes":
@ -310,17 +330,18 @@ class SavedModelExportUtilsTest(test.TestCase):
dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
dtype_int64 = types_pb2.DataType.Value("DT_INT64")
dtype_float = types_pb2.DataType.Value("DT_FLOAT")
expected_signature_def.inputs[
"input-1"].CopyFrom(
meta_graph_pb2.TensorInfo(
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
expected_signature_def.inputs["input-1"].CopyFrom(
meta_graph_pb2.TensorInfo(
name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape))
expected_signature_def.outputs["classes"].CopyFrom(
meta_graph_pb2.TensorInfo(
name="output-tensor-classes:0", dtype=dtype_int64,
name="output-tensor-classes:0",
dtype=dtype_int64,
tensor_shape=shape))
expected_signature_def.outputs["logits"].CopyFrom(
meta_graph_pb2.TensorInfo(
name="output-tensor-logits:0", dtype=dtype_float,
name="output-tensor-logits:0",
dtype=dtype_float,
tensor_shape=shape))
expected_signature_def.method_name = (
@ -379,8 +400,9 @@ class SavedModelExportUtilsTest(test.TestCase):
def test_get_output_alternatives_single_no_default(self):
prediction_tensor = constant_op.constant(["bogus"])
provided_output_alternatives = {
"head-1": (constants.ProblemType.LINEAR_REGRESSION,
{"output": prediction_tensor}),
"head-1": (constants.ProblemType.LINEAR_REGRESSION, {
"output": prediction_tensor
}),
}
model_fn_ops = model_fn.ModelFnOps(
model_fn.ModeKeys.INFER,
@ -390,10 +412,11 @@ class SavedModelExportUtilsTest(test.TestCase):
output_alternatives, _ = saved_model_export_utils.get_output_alternatives(
model_fn_ops)
self.assertEqual({"head-1":
(constants.ProblemType.LINEAR_REGRESSION,
{"output": prediction_tensor})},
output_alternatives)
self.assertEqual({
"head-1": (constants.ProblemType.LINEAR_REGRESSION, {
"output": prediction_tensor
})
}, output_alternatives)
def test_get_output_alternatives_multi_no_default(self):
provided_output_alternatives = {
@ -424,10 +447,11 @@ class SavedModelExportUtilsTest(test.TestCase):
output_alternatives, _ = saved_model_export_utils.get_output_alternatives(
model_fn_ops)
self.assertEqual(
{"default_output_alternative": (constants.ProblemType.UNSPECIFIED, {
"some_output": prediction_tensor})},
output_alternatives)
self.assertEqual({
"default_output_alternative": (constants.ProblemType.UNSPECIFIED, {
"some_output": prediction_tensor
})
}, output_alternatives)
def test_get_output_alternatives_empty_provided_with_default(self):
prediction_tensor = constant_op.constant(["bogus"])
@ -452,10 +476,11 @@ class SavedModelExportUtilsTest(test.TestCase):
output_alternatives, _ = saved_model_export_utils.get_output_alternatives(
model_fn_ops)
self.assertEqual(
{"default_output_alternative": (constants.ProblemType.UNSPECIFIED, {
"some_output": prediction_tensor})},
output_alternatives)
self.assertEqual({
"default_output_alternative": (constants.ProblemType.UNSPECIFIED, {
"some_output": prediction_tensor
})
}, output_alternatives)
def test_get_output_alternatives_implicit_single(self):
prediction_tensor = constant_op.constant(["bogus"])
@ -506,14 +531,14 @@ class SavedModelExportUtilsTest(test.TestCase):
expected_signature_defs = {
"serving_default":
signature_def_utils.regression_signature_def(input_example,
output_1),
signature_def_utils.regression_signature_def(
input_example, output_1),
"default_input_alternative:head-1":
signature_def_utils.regression_signature_def(input_example,
output_1),
signature_def_utils.regression_signature_def(
input_example, output_1),
"default_input_alternative:head-2":
signature_def_utils.classification_signature_def(input_example,
output_2, None),
signature_def_utils.classification_signature_def(
input_example, output_2, None),
"default_input_alternative:head-3":
signature_def_utils.predict_signature_def({
"default input": input_example
@ -624,17 +649,20 @@ class SavedModelExportUtilsTest(test.TestCase):
(most_recent_export_dir, most_recent_export_version) = (
saved_model_export_utils.get_most_recent_export(export_dir_base))
self.assertEqual(compat.as_bytes(export_dir_4),
compat.as_bytes(most_recent_export_dir))
self.assertEqual(compat.as_bytes(export_dir_4),
os.path.join(compat.as_bytes(export_dir_base),
compat.as_bytes(
str(most_recent_export_version))))
self.assertEqual(
compat.as_bytes(export_dir_4), compat.as_bytes(most_recent_export_dir))
self.assertEqual(
compat.as_bytes(export_dir_4),
os.path.join(
compat.as_bytes(export_dir_base),
compat.as_bytes(str(most_recent_export_version))))
def test_make_export_strategy(self):
"""Only tests that an ExportStrategy instance is created."""
def _serving_input_fn():
return array_ops.constant([1]), None
export_strategy = saved_model_export_utils.make_export_strategy(
serving_input_fn=_serving_input_fn,
default_output_alternative_key="default",
@ -655,14 +683,61 @@ class SavedModelExportUtilsTest(test.TestCase):
real_valued_col1 = fc.real_valued_column("real_valued_column1")
bucketized_col1 = fc.bucketized_column(
fc.real_valued_column("real_valued_column_for_bucketization1"), [0, 4])
feature_columns = [sparse_col, embedding_col, real_valued_col1,
bucketized_col1]
feature_columns = [
sparse_col, embedding_col, real_valued_col1, bucketized_col1
]
export_strategy = saved_model_export_utils.make_parsing_export_strategy(
feature_columns=feature_columns)
self.assertTrue(
isinstance(export_strategy, export_strategy_lib.ExportStrategy))
def test_make_best_model_export_strategy(self):
export_dir_base = tempfile.mkdtemp() + "export/"
gfile.MkDir(export_dir_base)
test_estimator = TestEstimator()
export_strategy = saved_model_export_utils.make_best_model_export_strategy(
serving_input_fn=None, exports_to_keep=3, compare_fn=None)
self.assertNotEqual("",
export_strategy.export(test_estimator, export_dir_base,
"fake_ckpt_0", {"loss": 100}))
self.assertNotEqual("", test_estimator.last_exported_dir)
self.assertNotEqual("", test_estimator.last_exported_checkpoint)
self.assertEqual("",
export_strategy.export(test_estimator, export_dir_base,
"fake_ckpt_1", {"loss": 101}))
self.assertEqual(test_estimator.last_exported_dir,
os.path.join(export_dir_base, "fake_ckpt_0"))
self.assertNotEqual("",
export_strategy.export(test_estimator, export_dir_base,
"fake_ckpt_2", {"loss": 10}))
self.assertEqual(test_estimator.last_exported_dir,
os.path.join(export_dir_base, "fake_ckpt_2"))
self.assertEqual("",
export_strategy.export(test_estimator, export_dir_base,
"fake_ckpt_3", {"loss": 20}))
self.assertEqual(test_estimator.last_exported_dir,
os.path.join(export_dir_base, "fake_ckpt_2"))
def test_make_best_model_export_strategy_exceptions(self):
export_dir_base = tempfile.mkdtemp() + "export/"
test_estimator = TestEstimator()
export_strategy = saved_model_export_utils.make_best_model_export_strategy(
serving_input_fn=None, exports_to_keep=3, compare_fn=None)
with self.assertRaises(ValueError):
export_strategy.export(test_estimator, export_dir_base, "", {"loss": 200})
with self.assertRaises(ValueError):
export_strategy.export(test_estimator, export_dir_base, "fake_ckpt_1",
None)
def _create_test_export_dir(export_dir_base):
export_dir = saved_model_export_utils.get_timestamped_export_dir(

View File

@ -0,0 +1,71 @@
# Description:
# Contains modules to compute receptive field parameters for CNN models.
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
# Transitive dependencies of this target will be included in the pip package.
py_library(
name = "receptive_field_pip",
deps = [
":graph_compute_order_py",
":receptive_field_py",
],
)
py_library(
name = "graph_compute_order_py",
srcs = [
"__init__.py",
"python/util/graph_compute_order.py",
],
srcs_version = "PY2AND3",
)
py_library(
name = "receptive_field_py",
srcs = [
"__init__.py",
"python/util/receptive_field.py",
],
srcs_version = "PY2AND3",
deps = [
":graph_compute_order_py",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:platform",
],
)
py_test(
name = "receptive_field_test",
srcs = ["python/util/receptive_field_test.py"],
srcs_version = "PY2AND3",
deps = [
":receptive_field_py",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/slim",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:nn",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,165 @@
# Receptive field computation for convnets
This library enables you to easily compute the receptive field parameters of
your favorite convnet. You can use it to understand how large an input image
region each output feature depends on. Better yet, using the parameters computed
by the library, you can easily find the exact image region which is used to
compute each convnet feature.
## Basic usage
The main function to be called is `compute_receptive_field_from_graph_def`,
which will return the receptive field, effective stride and effective padding
for both horizontal and vertical directions.
For example, if your model is constructed using the function
`my_model_construction()`, you can use the library as follows:
```python
import tensorflow as tf
from tensorflow.contrib import receptive_field
# Construct graph.
g = tf.Graph()
with g.as_default():
images = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='input_image')
my_model_construction(images)
# Compute receptive field parameters.
rf_x, rf_y, eff_stride_x, eff_stride_y, eff_pad_x, eff_pad_y = \
receptive_field.compute_receptive_field_from_graph_def( \
g.as_graph_def(), 'input_image', 'my_output_endpoint')
```
Here's a simple example of computing the receptive field parameters for
Inception-Resnet-v2. To get this to work, be sure to check out
[tensorflow/models](https://github.com/tensorflow/models), so that the Inception
models are available to you. This can be done in three simple commands:
```sh
git clone https://github.com/tensorflow/models
cd models/slim
sudo python setup.py install_lib
```
You can then compute the receptive field parameters for Inception-Resnet-v2 as:
```python
from nets import inception
import tensorflow as tf
from tensorflow.contrib import receptive_field
# Construct graph.
g = tf.Graph()
with g.as_default():
images = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='input_image')
inception.inception_resnet_v2_base(images)
# Compute receptive field parameters.
rf_x, rf_y, eff_stride_x, eff_stride_y, eff_pad_x, eff_pad_y = \
receptive_field.compute_receptive_field_from_graph_def( \
g.as_graph_def(), 'input_image', 'InceptionResnetV2/Conv2d_7b_1x1/Relu')
```
This will give you `rf_x = rf_y = 3039`, `eff_stride_x = eff_stride_y = 32`, and
`eff_pad_x = eff_pad_y = 1482`. This means that each feature that is output at
the node `'InceptionResnetV2/Conv2d_7b_1x1/Relu'` is computed from a region
which is of size `3039x3039`. Further, by using the expressions
```python
center_x = -eff_pad_x + feature_x*eff_stride_x + (rf_x - 1)/2
center_y = -eff_pad_y + feature_y*eff_stride_y + (rf_y - 1)/2
```
one can compute the center of the region in the input image that is used to
compute the output feature at position `[feature_x, feature_y]`. For example,
the feature at position `[0, 2]` at the output of the layer
`'InceptionResnetV2/Conv2d_7b_1x1/Relu'` is centered in the original image in
the position `[37, 101]`.
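To make this concrete, here is a minimal Python sketch that plugs the
Inception-Resnet-v2 values quoted above into these expressions (the
`feature_center` helper is ours, just for illustration):

```python
# Receptive field parameters quoted above for Inception-Resnet-v2.
rf_x = rf_y = 3039
eff_stride_x = eff_stride_y = 32
eff_pad_x = eff_pad_y = 1482

def feature_center(feature_x, feature_y):
  """Input-image center of the feature at position [feature_x, feature_y]."""
  center_x = -eff_pad_x + feature_x * eff_stride_x + (rf_x - 1) // 2
  center_y = -eff_pad_y + feature_y * eff_stride_y + (rf_y - 1) // 2
  return center_x, center_y

print(feature_center(0, 2))  # Prints (37, 101), matching the example above.
```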
TODO: include link to derivations and definitions of different parameters.
## Receptive field benchmark
As you might expect, it is straightforward to run this library on the popular
convnets and gather their receptive fields. We provide a python script which
does exactly that, available under `python/util/examples/rf_benchmark.py`.
To get this to work, be sure to check out
[tensorflow/models](https://github.com/tensorflow/models) (see the 3-command
instructions for this above). Then, simply:
```sh
cd python/util/examples
python rf_benchmark.py --csv_path /tmp/rf_benchmark_results.csv
```
The script will write to stdout the receptive field parameters for many variants
of several popular convnets: AlexNet, VGG, ResNet, Inception, Mobilenet. They
are also written to the file `/tmp/rf_benchmark_results.csv`.
TODO: include here a plot for receptive field sizes of different convnets.
TODO: include table/link to pre-computed RF parameters.
## Compute RF parameters from a graph pbtxt
We also provide a utility to compute the receptive field parameters directly
from a graph protobuf file.
Have a `graph.pbtxt` file and want to compute its receptive field parameters?
We've got you covered. The only prerequisite is to install
[google/protobuf](https://github.com/google/protobuf), which you probably
already have if you're using tensorflow (otherwise, follow installation
instructions [here](https://github.com/google/protobuf/tree/master/python)).
This should work:
```sh
cd python/util/examples
python compute_rf.py \
--graph_path /path/to/graph.pbtxt \
--output_path /path/to/output/rf_info.txt \
--input_node my_input_node \
--output_node my_output_node
```
Don't know how to generate a graph protobuf file? Take a look at the
`write_inception_resnet_v2_graph.py` script, which shows how to save it for the
Inception-Resnet-v2 model:
```sh
cd python/util/examples
python write_inception_resnet_v2_graph.py --graph_dir /tmp --graph_filename graph.pbtxt
```
This will write the Inception-Resnet-v2 graph protobuf to `/tmp/graph.pbtxt`.
For completeness, here's how you would use this file to get the receptive field
parameters of the Inception-Resnet-v2 model:
```sh
cd python/util/examples
python compute_rf.py \
--graph_path /tmp/graph.pbtxt \
--output_path /tmp/rf_info.txt \
--input_node input_image \
--output_node InceptionResnetV2/Conv2d_7b_1x1/Relu
```
This will write the receptive field parameters of the model to
`/tmp/rf_info.txt`, which will look like:
```sh
Receptive field size (horizontal) = 3039
Receptive field size (vertical) = 3039
Effective stride (horizontal) = 32
Effective stride (vertical) = 32
Effective padding (horizontal) = 1482
Effective padding (vertical) = 1482
```
## Authors
André Araujo (github id: andrefaraujo) and Mark Sandler (github id:
marksandler)

View File

@ -0,0 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module to compute receptive field parameters for CNN tensorflow models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.contrib.receptive_field.python.util.graph_compute_order import get_compute_order
from tensorflow.contrib.receptive_field.python.util.receptive_field import compute_receptive_field_from_graph_def
# pylint: enable=unused-import

View File

@ -0,0 +1,19 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module to compute receptive field parameters for CNN tensorflow models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

View File

@ -0,0 +1,94 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Computes Receptive Field (RF) information given a graph protobuf.
For an example of usage, see accompanying file compute_rf.sh
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from google.protobuf import text_format
from tensorflow.contrib import receptive_field
from tensorflow.core.framework import graph_pb2
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
cmd_args = None
def _load_graphdef(path):
"""Helper function to load GraphDef from file.
Args:
path: Path to pbtxt file.
Returns:
graph_def: A GraphDef object.
"""
graph_def = graph_pb2.GraphDef()
pbstr = gfile.Open(path).read()
text_format.Parse(pbstr, graph_def)
return graph_def
def main(unused_argv):
graph_def = _load_graphdef(cmd_args.graph_path)
(receptive_field_x, receptive_field_y, effective_stride_x, effective_stride_y,
effective_padding_x, effective_padding_y
) = receptive_field.compute_receptive_field_from_graph_def(
graph_def, cmd_args.input_node, cmd_args.output_node)
logging.info('Receptive field size (horizontal) = %s', receptive_field_x)
logging.info('Receptive field size (vertical) = %s', receptive_field_y)
logging.info('Effective stride (horizontal) = %s', effective_stride_x)
logging.info('Effective stride (vertical) = %s', effective_stride_y)
logging.info('Effective padding (horizontal) = %s', effective_padding_x)
logging.info('Effective padding (vertical) = %s', effective_padding_y)
f = gfile.GFile('%s' % cmd_args.output_path, 'w')
f.write('Receptive field size (horizontal) = %s\n' % receptive_field_x)
f.write('Receptive field size (vertical) = %s\n' % receptive_field_y)
f.write('Effective stride (horizontal) = %s\n' % effective_stride_x)
f.write('Effective stride (vertical) = %s\n' % effective_stride_y)
f.write('Effective padding (horizontal) = %s\n' % effective_padding_x)
f.write('Effective padding (vertical) = %s\n' % effective_padding_y)
f.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.register('type', 'bool', lambda v: v.lower() == 'true')
parser.add_argument(
'--graph_path', type=str, default='', help='Graph path (pbtxt format).')
parser.add_argument(
'--output_path',
type=str,
default='',
help='Path to output text file where RF information will be written to.')
parser.add_argument(
'--input_node', type=str, default='', help='Name of input node.')
parser.add_argument(
'--output_node', type=str, default='', help='Name of output node.')
cmd_args, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -0,0 +1,460 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Computes Receptive Field (RF) information for different models.
The receptive field (and related parameters) for the different models are
printed to stdout, and may also optionally be written to a CSV file.
For an example of usage, see rf_benchmark.sh
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import csv
import sys
from nets import alexnet
from nets import inception
from nets import mobilenet_v1
from nets import resnet_v1
from nets import resnet_v2
from nets import vgg
from tensorflow.contrib import framework
from tensorflow.contrib import receptive_field
from tensorflow.contrib import slim
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import app
cmd_args = None
# Input node name for all architectures.
_INPUT_NODE = 'input_image'
# Variants of different network architectures.
# - resnet: different versions and sizes.
_SUPPORTED_RESNET_VARIANTS = [
'resnet_v1_50', 'resnet_v1_101', 'resnet_v1_152', 'resnet_v1_200',
'resnet_v2_50', 'resnet_v2_101', 'resnet_v2_152', 'resnet_v2_200'
]
# - inception_resnet_v2: default, and version with SAME padding.
_SUPPORTED_INCEPTIONRESNETV2_VARIANTS = [
'inception_resnet_v2', 'inception_resnet_v2-same'
]
# - inception_v2: default, and version with no separable conv.
_SUPPORTED_INCEPTIONV2_VARIANTS = [
'inception_v2', 'inception_v2-no-separable-conv'
]
# - inception_v3: default version.
_SUPPORTED_INCEPTIONV3_VARIANTS = ['inception_v3']
# - inception_v4: default version.
_SUPPORTED_INCEPTIONV4_VARIANTS = ['inception_v4']
# - alexnet_v2: default version.
_SUPPORTED_ALEXNETV2_VARIANTS = ['alexnet_v2']
# - vgg: vgg_a (with 11 layers) and vgg_16 (version D).
_SUPPORTED_VGG_VARIANTS = ['vgg_a', 'vgg_16']
# - mobilenet_v1: 100% and 75%.
_SUPPORTED_MOBILENETV1_VARIANTS = ['mobilenet_v1', 'mobilenet_v1_075']
def _construct_model(model_type='resnet_v1_50'):
"""Constructs model for the desired type of CNN.
Args:
model_type: Type of model to be used.
Returns:
end_points: A dictionary from components of the network to the corresponding
activations.
Raises:
ValueError: If the model_type is not supported.
"""
# Placeholder input.
images = array_ops.placeholder(
dtypes.float32, shape=(1, None, None, 3), name=_INPUT_NODE)
# Construct model.
if model_type == 'inception_resnet_v2':
_, end_points = inception.inception_resnet_v2_base(images)
elif model_type == 'inception_resnet_v2-same':
_, end_points = inception.inception_resnet_v2_base(
images, align_feature_maps=True)
elif model_type == 'inception_v2':
_, end_points = inception.inception_v2_base(images)
elif model_type == 'inception_v2-no-separable-conv':
_, end_points = inception.inception_v2_base(
images, use_separable_conv=False)
elif model_type == 'inception_v3':
_, end_points = inception.inception_v3_base(images)
elif model_type == 'inception_v4':
_, end_points = inception.inception_v4_base(images)
elif model_type == 'alexnet_v2':
_, end_points = alexnet.alexnet_v2(images)
elif model_type == 'vgg_a':
_, end_points = vgg.vgg_a(images)
elif model_type == 'vgg_16':
_, end_points = vgg.vgg_16(images)
elif model_type == 'mobilenet_v1':
_, end_points = mobilenet_v1.mobilenet_v1_base(images)
elif model_type == 'mobilenet_v1_075':
_, end_points = mobilenet_v1.mobilenet_v1_base(
images, depth_multiplier=0.75)
elif model_type == 'resnet_v1_50':
_, end_points = resnet_v1.resnet_v1_50(
images, num_classes=None, is_training=False, global_pool=False)
elif model_type == 'resnet_v1_101':
_, end_points = resnet_v1.resnet_v1_101(
images, num_classes=None, is_training=False, global_pool=False)
elif model_type == 'resnet_v1_152':
_, end_points = resnet_v1.resnet_v1_152(
images, num_classes=None, is_training=False, global_pool=False)
elif model_type == 'resnet_v1_200':
_, end_points = resnet_v1.resnet_v1_200(
images, num_classes=None, is_training=False, global_pool=False)
elif model_type == 'resnet_v2_50':
_, end_points = resnet_v2.resnet_v2_50(
images, num_classes=None, is_training=False, global_pool=False)
elif model_type == 'resnet_v2_101':
_, end_points = resnet_v2.resnet_v2_101(
images, num_classes=None, is_training=False, global_pool=False)
elif model_type == 'resnet_v2_152':
_, end_points = resnet_v2.resnet_v2_152(
images, num_classes=None, is_training=False, global_pool=False)
elif model_type == 'resnet_v2_200':
_, end_points = resnet_v2.resnet_v2_200(
images, num_classes=None, is_training=False, global_pool=False)
else:
raise ValueError('Unsupported model_type %s.' % model_type)
return end_points
def _get_desired_end_point_keys(model_type='resnet_v1_50'):
"""Gets list of desired end point keys for a type of CNN.
Args:
model_type: Type of model to be used.
Returns:
desired_end_point_types: A list containing the desired end-points.
Raises:
ValueError: If the model_type is not supported.
"""
if model_type in _SUPPORTED_RESNET_VARIANTS:
blocks = ['block1', 'block2', 'block3', 'block4']
desired_end_point_keys = ['%s/%s' % (model_type, i) for i in blocks]
elif model_type in _SUPPORTED_INCEPTIONRESNETV2_VARIANTS:
desired_end_point_keys = [
'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'MaxPool_3a_3x3',
'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3', 'Mixed_5b',
'Mixed_6a', 'PreAuxLogits', 'Mixed_7a', 'Conv2d_7b_1x1'
]
elif model_type in _SUPPORTED_INCEPTIONV2_VARIANTS:
desired_end_point_keys = [
'Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 'Conv2d_2c_3x3',
'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 'Mixed_4a', 'Mixed_4b',
'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c'
]
elif model_type in _SUPPORTED_INCEPTIONV3_VARIANTS:
desired_end_point_keys = [
'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'MaxPool_3a_3x3',
'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3', 'Mixed_5b',
'Mixed_5c', 'Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d',
'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c'
]
elif model_type in _SUPPORTED_INCEPTIONV4_VARIANTS:
desired_end_point_keys = [
'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'Mixed_3a',
'Mixed_4a', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 'Mixed_5e',
'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e', 'Mixed_6f',
'Mixed_6g', 'Mixed_6h', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c', 'Mixed_7d'
]
elif model_type in _SUPPORTED_ALEXNETV2_VARIANTS:
ep = ['conv1', 'pool1', 'conv2', 'conv3', 'conv4', 'conv5', 'pool5']
desired_end_point_keys = ['%s/%s' % (model_type, i) for i in ep]
elif model_type in _SUPPORTED_VGG_VARIANTS:
ep = [
'conv1/conv1_1', 'pool1', 'conv2/conv2_1', 'pool2', 'conv3/conv3_1',
'conv3/conv3_2', 'pool3', 'conv4/conv4_1', 'conv4/conv4_2', 'pool4',
'conv5/conv5_1', 'conv5/conv5_2', 'pool5'
]
desired_end_point_keys = ['%s/%s' % (model_type, i) for i in ep]
elif model_type in _SUPPORTED_MOBILENETV1_VARIANTS:
desired_end_point_keys = [
'Conv2d_0', 'Conv2d_1_pointwise', 'Conv2d_2_pointwise',
'Conv2d_3_pointwise', 'Conv2d_4_pointwise', 'Conv2d_5_pointwise',
'Conv2d_6_pointwise', 'Conv2d_7_pointwise', 'Conv2d_8_pointwise',
'Conv2d_9_pointwise', 'Conv2d_10_pointwise', 'Conv2d_11_pointwise',
'Conv2d_12_pointwise', 'Conv2d_13_pointwise'
]
else:
raise ValueError('Unsupported model_type %s.' % model_type)
return desired_end_point_keys
def _model_graph_def(model_type='resnet_v1_50', arg_sc=None):
"""Constructs a model graph, returning GraphDef and end-points.
Args:
model_type: Type of model to be used.
arg_sc: Optional arg scope to use in constructing the graph.
Returns:
graph_def: GraphDef of constructed graph.
end_points: A dictionary from components of the network to the corresponding
activations.
"""
if arg_sc is None:
arg_sc = {}
g = ops.Graph()
with g.as_default():
with framework.arg_scope(arg_sc):
end_points = _construct_model(model_type)
return g.as_graph_def(), end_points
def _model_rf(graphdef,
end_points,
desired_end_point_keys,
model_type='resnet_v1_50',
csv_writer=None):
"""Computes receptive field information for a given CNN model.
The information will be printed to stdout. If the RF parameters are the same
for the horizontal and vertical directions, it will be printed only once.
Otherwise, they are printed once for the horizontal and once for the vertical
directions.
Args:
graphdef: GraphDef of given model.
end_points: A dictionary from components of the model to the corresponding
activations.
desired_end_point_keys: List of desired end points for which receptive field
information will be computed.
model_type: Type of model to be used, used only for printing purposes.
csv_writer: A CSV writer for RF parameters, which is used if it is not None.
"""
for desired_end_point_key in desired_end_point_keys:
print('- %s:' % desired_end_point_key)
output_node_with_colon = end_points[desired_end_point_key].name
pos = output_node_with_colon.rfind(':')
output_node = output_node_with_colon[:pos]
(receptive_field_x, receptive_field_y, effective_stride_x,
effective_stride_y, effective_padding_x, effective_padding_y
) = receptive_field.compute_receptive_field_from_graph_def(
graphdef, _INPUT_NODE, output_node)
# If values are the same in horizontal/vertical directions, just report one
# of them. Otherwise, report both.
if (receptive_field_x == receptive_field_y) and (
effective_stride_x == effective_stride_y) and (
effective_padding_x == effective_padding_y):
print('Receptive field size = %5s, effective stride = %5s, effective '
'padding = %5s' % (str(receptive_field_x), str(effective_stride_x),
str(effective_padding_x)))
else:
print('Receptive field size: horizontal = %5s, vertical = %5s. '
'Effective stride: horizontal = %5s, vertical = %5s. Effective '
'padding: horizontal = %5s, vertical = %5s' %
(str(receptive_field_x), str(receptive_field_y),
str(effective_stride_x), str(effective_stride_y),
str(effective_padding_x), str(effective_padding_y)))
if csv_writer is not None:
csv_writer.writerow({
'CNN': model_type,
'end_point': desired_end_point_key,
'RF size hor': str(receptive_field_x),
'RF size ver': str(receptive_field_y),
'effective stride hor': str(effective_stride_x),
'effective stride ver': str(effective_stride_y),
'effective padding hor': str(effective_padding_x),
'effective padding ver': str(effective_padding_y)
})
def _process_model_rf(model_type='resnet_v1_50', csv_writer=None, arg_sc=None):
"""Contructs model graph and desired end-points, and compute RF.
The computed RF parameters are printed to stdout by the _model_rf function.
Args:
model_type: Type of model to be used.
csv_writer: A CSV writer for RF parameters, which is used if it is not None.
arg_sc: Optional arg scope to use in constructing the graph.
"""
print('********************%s' % model_type)
graphdef, end_points = _model_graph_def(model_type, arg_sc)
desired_end_point_keys = _get_desired_end_point_keys(model_type)
_model_rf(graphdef, end_points, desired_end_point_keys, model_type,
csv_writer)
def _resnet_rf(csv_writer=None):
"""Computes RF and associated parameters for resnet models.
The computed values are written to stdout.
Args:
csv_writer: A CSV writer for RF parameters, which is used if it is not None.
"""
for model_type in _SUPPORTED_RESNET_VARIANTS:
arg_sc = resnet_v1.resnet_arg_scope()
_process_model_rf(model_type, csv_writer, arg_sc)
def _inception_resnet_v2_rf(csv_writer=None):
"""Computes RF and associated parameters for the inception_resnet_v2 model.
The computed values are written to stdout.
Args:
csv_writer: A CSV writer for RF parameters, which is used if it is not None.
"""
for model_type in _SUPPORTED_INCEPTIONRESNETV2_VARIANTS:
_process_model_rf(model_type, csv_writer)
def _inception_v2_rf(csv_writer=None):
"""Computes RF and associated parameters for the inception_v2 model.
The computed values are written to stdout.
Args:
csv_writer: A CSV writer for RF parameters, which is used if it is not None.
"""
for model_type in _SUPPORTED_INCEPTIONV2_VARIANTS:
_process_model_rf(model_type, csv_writer)
def _inception_v3_rf(csv_writer=None):
"""Computes RF and associated parameters for the inception_v3 model.
The computed values are written to stdout.
Args:
csv_writer: A CSV writer for RF parameters, which is used if it is not None.
"""
for model_type in _SUPPORTED_INCEPTIONV3_VARIANTS:
_process_model_rf(model_type, csv_writer)
def _inception_v4_rf(csv_writer=None):
"""Computes RF and associated parameters for the inception_v4 model.
The computed values are written to stdout.
Args:
csv_writer: A CSV writer for RF parameters, which is used if it is not None.
"""
for model_type in _SUPPORTED_INCEPTIONV4_VARIANTS:
_process_model_rf(model_type, csv_writer)
def _alexnet_v2_rf(csv_writer=None):
"""Computes RF and associated parameters for the alexnet_v2 model.
The computed values are written to stdout.
Args:
csv_writer: A CSV writer for RF parameters, which is used if it is not None.
"""
for model_type in _SUPPORTED_ALEXNETV2_VARIANTS:
_process_model_rf(model_type, csv_writer)
def _vgg_rf(csv_writer=None):
"""Computes RF and associated parameters for the vgg model.
The computed values are written to stdout.
Args:
csv_writer: A CSV writer for RF parameters, which is used if it is not None.
"""
for model_type in _SUPPORTED_VGG_VARIANTS:
_process_model_rf(model_type, csv_writer)
def _mobilenet_v1_rf(csv_writer=None):
"""Computes RF and associated parameters for the mobilenet_v1 model.
The computed values are written to stdout.
Args:
csv_writer: A CSV writer for RF parameters, which is used if it is not None.
"""
for model_type in _SUPPORTED_MOBILENETV1_VARIANTS:
with slim.arg_scope(
[slim.batch_norm, slim.dropout], is_training=False) as arg_sc:
_process_model_rf(model_type, csv_writer, arg_sc)
def main(unused_argv):
# Configure CSV file which will be written, if desired.
if cmd_args.csv_path:
csv_file = open(cmd_args.csv_path, 'w')
field_names = [
'CNN', 'end_point', 'RF size hor', 'RF size ver',
'effective stride hor', 'effective stride ver', 'effective padding hor',
'effective padding ver'
]
rf_writer = csv.DictWriter(csv_file, fieldnames=field_names)
rf_writer.writeheader()
else:
rf_writer = None
# Compute RF parameters for each network architecture.
_alexnet_v2_rf(rf_writer)
_vgg_rf(rf_writer)
_inception_v2_rf(rf_writer)
_inception_v3_rf(rf_writer)
_inception_v4_rf(rf_writer)
_inception_resnet_v2_rf(rf_writer)
_mobilenet_v1_rf(rf_writer)
_resnet_rf(rf_writer)
# Close CSV file, if it was opened.
if cmd_args.csv_path:
csv_file.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.register('type', 'bool', lambda v: v.lower() == 'true')
parser.add_argument(
'--csv_path',
type=str,
default='',
help="""\
Path to CSV file that will be written with RF parameters. If empty, no
file will be written.\
""")
cmd_args, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -0,0 +1,61 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple script to write Inception-ResNet-v2 model to graph file.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from nets import inception
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import app
cmd_args = None
def main(unused_argv):
# Model definition.
g = ops.Graph()
with g.as_default():
images = array_ops.placeholder(
dtypes.float32, shape=(1, None, None, 3), name='input_image')
inception.inception_resnet_v2_base(images)
graph_io.write_graph(g.as_graph_def(), cmd_args.graph_dir,
cmd_args.graph_filename)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.register('type', 'bool', lambda v: v.lower() == 'true')
parser.add_argument(
'--graph_dir',
type=str,
default='/tmp',
help='Directory where graph will be saved.')
parser.add_argument(
'--graph_filename',
type=str,
default='graph.pbtxt',
help='Filename of graph that will be saved.')
cmd_args, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -0,0 +1,88 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library to compute order of computations in a graph.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
class GraphDefHelper(object):
"""Helper class to collect node names and definitions.
Example:
b = GraphDefHelper(graph_def)
# Prints node that produces given output.
print(b.output_of['conv/foo/bar'])
"""
def __init__(self, gd):
self.output_of = {}
for each in gd.node:
self.output_of[each.name] = each
# pylint: disable=invalid-name
_NodeEntry = collections.namedtuple('NodeEntry', field_names=['order', 'node'])
def _get_computed_nodes(g, output, seen):
"""Traverses the graph in topological order.
Args:
g: GraphDefHelper object.
output: current node.
seen: map of nodes we've already traversed.
Returns:
order in topological sort for 'output'.
"""
if output in seen:
return seen[output].order
node_def = g.output_of.get(output, None)
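# Names that no node in the graph produces (e.g., dangling inputs) anchor
# the recursion at order 0.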
if node_def is None:
seen[output] = _NodeEntry(0, None)
return 0
r = 0
for each in node_def.input:
# Parses name of input node.
if each.startswith('^'):
each = each[1:]
each = each.split(':')[0]
# Recursively computes ordering.
new_v = _get_computed_nodes(g, each, seen)
r = max(r, new_v + 1)
seen[output] = _NodeEntry(r, node_def)
return seen[output].order
def get_compute_order(graph_def):
"""Computes order of computation for a given graph.
Args:
graph_def: GraphDef object.
Returns:
map: name -> {order, node}
"""
helper = GraphDefHelper(graph_def)
seen = collections.defaultdict(_NodeEntry)
for each in graph_def.node:
_get_computed_nodes(helper, each.name, seen)
return seen

View File

@ -0,0 +1,485 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to compute receptive field of a fully-convolutional network.
Please refer to the following g3doc for detailed explanation on how this
computation is performed, and why it is important:
g3doc/photos/vision/features/delf/g3doc/rf_computation.md
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from tensorflow.contrib.receptive_field.python.util import graph_compute_order
from tensorflow.contrib.util import make_ndarray
from tensorflow.python.platform import tf_logging as logging
# White-listed layer operations, which do not affect the receptive field
# computation.
_UNCHANGED_RF_LAYER_OPS = [
"Softplus", "Relu", "BiasAdd", "Mul", "Add", "Const", "Identity",
"VariableV2", "Sub", "Rsqrt", "ConcatV2"
]
# Different ways in which padding modes may be spelled.
_VALID_PADDING = ["VALID", b"VALID"]
_SAME_PADDING = ["SAME", b"SAME"]
def _stride_size(node):
"""Computes stride size given a TF node.
Args:
node: Tensorflow node (NodeDef proto).
Returns:
stride_x: Stride size for horizontal direction (integer).
stride_y: Stride size for vertical direction (integer).
"""
strides_attr = node.attr["strides"]
logging.vlog(4, "strides_attr = %s", strides_attr)
stride_y = strides_attr.list.i[1]
stride_x = strides_attr.list.i[2]
return stride_x, stride_y
def _conv_kernel_size(node, name_to_order_node):
"""Computes kernel size given a TF convolution or pooling node.
Args:
node: Tensorflow node (NodeDef proto).
name_to_order_node: Map from name to {order, node}. Output of
graph_compute_order.get_compute_order().
Returns:
kernel_size_x: Kernel size for horizontal direction (integer).
kernel_size_y: Kernel size for vertical direction (integer).
Raises:
ValueError: If the weight layer node is invalid.
"""
weights_layer_read_name = node.input[1]
if not weights_layer_read_name.endswith("/read"):
raise ValueError(
"Weight layer's name input to conv layer does not end with '/read'")
weights_layer_param_name = weights_layer_read_name[:-5]
weights_node = name_to_order_node[weights_layer_param_name].node
if weights_node.op != "VariableV2":
raise ValueError("Weight layer is not of type VariableV2")
shape = weights_node.attr["shape"]
logging.vlog(4, "weight shape = %s", shape)
kernel_size_y = shape.shape.dim[0].size
kernel_size_x = shape.shape.dim[1].size
return kernel_size_x, kernel_size_y
def _padding_size_conv_pool(node, kernel_size, stride):
"""Computes padding size given a TF convolution or pooling node.
Args:
node: Tensorflow node (NodeDef proto).
kernel_size: Kernel size of node (integer).
stride: Stride size of node (integer).
Returns:
padding: Padding size (integer).
Raises:
ValueError: If padding is invalid.
"""
# In this case, we need to carefully consider the different TF padding modes.
# The padding depends on kernel size, and may depend on input size. If it
# depends on input size, we raise an exception.
padding_attr = node.attr["padding"]
logging.vlog(4, "padding_attr = %s", padding_attr)
if padding_attr.s in _VALID_PADDING:
padding = 0
elif padding_attr.s in _SAME_PADDING:
if kernel_size == 1:
padding = 0
elif stride == 1:
padding = int(math.floor((float(kernel_size) - 1) / 2))
elif stride == 2 and kernel_size % 2 == 0:
padding = int(math.floor((float(kernel_size) - 1) / 2))
else:
padding = None
logging.warning(
"Padding depends on input size, which means that the effective "
"padding may be different depending on the input image "
"dimensionality. In this case, alignment check will be skipped.")
else:
raise ValueError("Invalid padding operation %s" % padding_attr.s)
return padding
def _pool_kernel_size(node):
"""Computes kernel size given a TF pooling node.
Args:
node: Tensorflow node (NodeDef proto).
Returns:
kernel_size_x: Kernel size for horizontal direction (integer).
kernel_size_y: Kernel size for vertical direction (integer).
Raises:
ValueError: If pooling is invalid.
"""
ksize = node.attr["ksize"]
kernel_size_y = ksize.list.i[1]
kernel_size_x = ksize.list.i[2]
if ksize.list.i[0] != 1:
raise ValueError("pool ksize for first dim is not 1")
if ksize.list.i[3] != 1:
raise ValueError("pool ksize for last dim is not 1")
return kernel_size_x, kernel_size_y
def _padding_size_pad_layer(node, name_to_order_node):
"""Computes padding size given a TF padding node.
Args:
node: Tensorflow node (NodeDef proto).
name_to_order_node: Map from name to {order, node}. Output of
graph_compute_order.get_compute_order().
Returns:
padding_x: Padding size for horizontal direction (integer).
padding_y: Padding size for vertical direction (integer).
Raises:
ValueError: If padding layer is invalid.
"""
paddings_layer_name = node.input[1]
if not paddings_layer_name.endswith("/paddings"):
raise ValueError("Padding layer name does not end with '/paddings'")
paddings_node = name_to_order_node[paddings_layer_name].node
if paddings_node.op != "Const":
raise ValueError("Padding op is not Const")
value = paddings_node.attr["value"]
t = make_ndarray(value.tensor)
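# The paddings tensor has shape [4, 2] in NHWC order; only the 'before'
# padding (column 0) affects the receptive field computation, so that is
# what we extract here.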
padding_y = t[1][0]
padding_x = t[2][0]
if t[0][0] != 0:
raise ValueError("padding is not zero for first tensor dim")
if t[3][0] != 0:
raise ValueError("padding is not zero for last tensor dim")
return padding_x, padding_y
def _get_layer_params(node, name_to_order_node):
"""Gets layer parameters relevant for RF computation.
Currently, only these nodes are supported:
- Conv2D
- DepthwiseConv2dNative
- Pad
- MaxPool
- AvgPool
- all nodes listed in _UNCHANGED_RF_LAYER_OPS
Args:
node: Tensorflow node (NodeDef proto).
name_to_order_node: Map from name to {order, node}. Output of
graph_compute_order.get_compute_order().
Returns:
kernel_size_x: Kernel size for horizontal direction (integer).
kernel_size_y: Kernel size for vertical direction (integer).
stride_x: Stride size for horizontal direction (integer).
stride_y: Stride size for vertical direction (integer).
padding_x: Padding size for horizontal direction (integer).
padding_y: Padding size for vertical direction (integer).
Raises:
ValueError: If layer op is unknown.
"""
logging.vlog(3, "node.op = %s", node.op)
logging.vlog(4, "node = %s", node)
if node.op == "Conv2D" or node.op == "DepthwiseConv2dNative":
stride_x, stride_y = _stride_size(node)
kernel_size_x, kernel_size_y = _conv_kernel_size(node, name_to_order_node)
# Compute the padding for this node separately for each direction.
padding_x = _padding_size_conv_pool(node, kernel_size_x, stride_x)
padding_y = _padding_size_conv_pool(node, kernel_size_y, stride_y)
elif node.op == "Pad":
# Kernel and stride are simply 1 in this case.
kernel_size_x = 1
kernel_size_y = 1
stride_x = 1
stride_y = 1
padding_x, padding_y = _padding_size_pad_layer(node, name_to_order_node)
elif node.op == "MaxPool" or node.op == "AvgPool":
stride_x, stride_y = _stride_size(node)
kernel_size_x, kernel_size_y = _pool_kernel_size(node)
# Compute the padding for this node separately for each direction.
padding_x = _padding_size_conv_pool(node, kernel_size_x, stride_x)
padding_y = _padding_size_conv_pool(node, kernel_size_y, stride_y)
elif node.op in _UNCHANGED_RF_LAYER_OPS:
# These nodes do not modify the RF parameters.
kernel_size_x = 1
kernel_size_y = 1
stride_x = 1
stride_y = 1
padding_x = 0
padding_y = 0
else:
raise ValueError("Unknown layer op: %s" % node.op)
return kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y
def _reverse_sort_by_order(name_to_order_node):
"""Sorts map of name_to_order_node nodes in reverse order.
The output is such that the nodes in name_to_order_node are sorted in
descending order of the "order" field.
Args:
name_to_order_node: Map from name to {order, node}. Output of
graph_compute_order.get_compute_order().
Returns:
sorted_name_to_order_node: Sorted version of the input, in descending order.
"""
return sorted(name_to_order_node.items(), key=lambda x: -x[1].order)
def _get_rf_size_node_input(stride, kernel_size, rf_size_output):
"""Computes RF size at the input of a given layer.
Args:
stride: Stride of given layer (integer).
kernel_size: Kernel size of given layer (integer).
rf_size_output: RF size at output of given layer (integer).
Returns:
rf_size_input: RF size at input of given layer (integer).
"""
return stride * rf_size_output + kernel_size - stride
def _get_effective_stride_node_input(stride, effective_stride_output):
"""Computes effective stride at the input of a given layer.
Args:
stride: Stride of given layer (integer).
effective_stride_output: Effective stride at output of given layer
(integer).
Returns:
effective_stride_input: Effective stride at input of given layer
(integer).
"""
return stride * effective_stride_output
def _get_effective_padding_node_input(stride, padding,
effective_padding_output):
"""Computes effective padding at the input of a given layer.
Args:
stride: Stride of given layer (integer).
padding: Padding of given layer (integer).
effective_padding_output: Effective padding at output of given layer
(integer).
Returns:
effective_padding_input: Effective padding at input of given layer
(integer).
"""
return stride * effective_padding_output + padding
def compute_receptive_field_from_graph_def(graph_def, input_node, output_node):
"""Computes receptive field (RF) parameters from a GraphDef object.
Args:
graph_def: GraphDef object.
input_node: Name of the input node from graph.
output_node: Name of the output node from graph.
Returns:
rf_size_x: Receptive field size of network in the horizontal direction, with
respect to specified input and output.
rf_size_y: Receptive field size of network in the vertical direction, with
respect to specified input and output.
effective_stride_x: Effective stride of network in the horizontal direction,
with respect to specified input and output.
effective_stride_y: Effective stride of network in the vertical direction,
with respect to specified input and output.
effective_padding_x: Effective padding of network in the horizontal
direction, with respect to specified input and output.
effective_padding_y: Effective padding of network in the vertical
direction, with respect to specified input and output.
Raises:
ValueError: If network is not aligned or if either input or output nodes
cannot be found. For network criterion alignment, see
photos/vision/features/delf/g3doc/rf_computation.md
"""
# Computes order of computation for a given graph.
name_to_order_node = graph_compute_order.get_compute_order(
graph_def=graph_def)
# Sort in reverse topological order.
order = _reverse_sort_by_order(name_to_order_node)
# Dictionaries to keep track of receptive field, effective stride and
# effective padding of different nodes.
rf_sizes_x = {}
rf_sizes_y = {}
effective_strides_x = {}
effective_strides_y = {}
effective_paddings_x = {}
effective_paddings_y = {}
# Initialize dicts for output_node.
rf_sizes_x[output_node] = 1
rf_sizes_y[output_node] = 1
effective_strides_x[output_node] = 1
effective_strides_y[output_node] = 1
effective_paddings_x[output_node] = 0
effective_paddings_y[output_node] = 0
# Flag to denote if we found output node yet. If we have not, we skip nodes
# until the output node is found.
found_output_node = False
# Flag to denote if padding is undefined. This happens when SAME padding mode
# is used in conjunction with stride and kernel sizes which make it such that
# the padding to be applied would depend on the input size. In this case,
# alignment checks are skipped, and the effective padding is None.
undefined_padding = False
for _, (o, node) in order:
if node:
logging.vlog(3, "%10d %-100s %-20s" % (o, node.name[:90], node.op))
else:
continue
# When we find input node, we can stop.
if node.name == input_node:
break
# Loop until we find the output node. All nodes before finding the output
# one are irrelevant, so they can be skipped.
if not found_output_node:
if node.name == output_node:
found_output_node = True
if found_output_node:
if node.name not in rf_sizes_x:
assert node.name not in rf_sizes_y, ("Node %s is in rf_sizes_y, but "
"not in rf_sizes_x" % node.name)
# In this case, node is not relevant since it's not part of the
# computation we're interested in.
logging.vlog(3, "Irrelevant node %s, skipping it...", node.name)
continue
# Get params for this layer.
kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y = (
_get_layer_params(node, name_to_order_node))
logging.vlog(3, "kernel_size_x = %s, kernel_size_y = %s, "
"stride_x = %s, stride_y = %s, "
"padding_x = %s, padding_y = %s" %
(kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x,
padding_y))
if padding_x is None or padding_y is None:
undefined_padding = True
# Get parameters at input of this layer which may or may not be propagated
# to the input layers.
rf_size_input_x = _get_rf_size_node_input(stride_x, kernel_size_x,
rf_sizes_x[node.name])
rf_size_input_y = _get_rf_size_node_input(stride_y, kernel_size_y,
rf_sizes_y[node.name])
effective_stride_input_x = _get_effective_stride_node_input(
stride_x, effective_strides_x[node.name])
effective_stride_input_y = _get_effective_stride_node_input(
stride_y, effective_strides_y[node.name])
if not undefined_padding:
effective_padding_input_x = _get_effective_padding_node_input(
stride_x, padding_x, effective_paddings_x[node.name])
effective_padding_input_y = _get_effective_padding_node_input(
stride_y, padding_y, effective_paddings_y[node.name])
else:
effective_padding_input_x = None
effective_padding_input_y = None
# Loop over this node's inputs and potentially propagate information down.
for inp_name in node.input:
logging.vlog(4, "inp_name = %s", inp_name)
inp_node = name_to_order_node[inp_name].node
logging.vlog(4, "inp_node = \n%s", inp_node)
if inp_node.name in rf_sizes_x:
assert inp_node.name in rf_sizes_y, (
"Node %s is in rf_sizes_x, but "
"not in rf_sizes_y" % inp_node.name)
# This node was already discovered through a previous path, so we need
# to make sure that graph is aligned. This alignment check is skipped
# if the padding is not defined, since in this case alignment cannot
# be checked.
if not undefined_padding:
if effective_strides_x[inp_node.name] != effective_stride_input_x:
raise ValueError(
"Graph is not aligned since effective stride from different "
"paths is different in horizontal direction")
if effective_strides_y[inp_node.name] != effective_stride_input_y:
raise ValueError(
"Graph is not aligned since effective stride from different "
"paths is different in vertical direction")
if (rf_sizes_x[inp_node.name] - 1
) / 2 - effective_paddings_x[inp_node.name] != (
rf_size_input_x - 1) / 2 - effective_padding_input_x:
raise ValueError(
"Graph is not aligned since center shift from different "
"paths is different in horizontal direction")
if (rf_sizes_y[inp_node.name] - 1
) / 2 - effective_paddings_y[inp_node.name] != (
rf_size_input_y - 1) / 2 - effective_padding_input_y:
raise ValueError(
"Graph is not aligned since center shift from different "
"paths is different in vertical direction")
# Keep track of path with largest RF, for both directions.
if rf_sizes_x[inp_node.name] < rf_size_input_x:
rf_sizes_x[inp_node.name] = rf_size_input_x
effective_strides_x[inp_node.name] = effective_stride_input_x
effective_paddings_x[inp_node.name] = effective_padding_input_x
if rf_sizes_y[inp_node.name] < rf_size_input_y:
rf_sizes_y[inp_node.name] = rf_size_input_y
effective_strides_y[inp_node.name] = effective_stride_input_y
effective_paddings_y[inp_node.name] = effective_padding_input_y
else:
assert inp_node.name not in rf_sizes_y, (
"Node %s is in rf_sizes_y, but "
"not in rf_sizes_x" % inp_node.name)
# In this case, it is the first time we encounter this node. So we
# propagate the RF parameters.
rf_sizes_x[inp_node.name] = rf_size_input_x
rf_sizes_y[inp_node.name] = rf_size_input_y
effective_strides_x[inp_node.name] = effective_stride_input_x
effective_strides_y[inp_node.name] = effective_stride_input_y
effective_paddings_x[inp_node.name] = effective_padding_input_x
effective_paddings_y[inp_node.name] = effective_padding_input_y
if not found_output_node:
raise ValueError("Output node was not found")
if input_node not in rf_sizes_x:
raise ValueError("Input node was not found")
return (rf_sizes_x[input_node], rf_sizes_y[input_node],
effective_strides_x[input_node], effective_strides_y[input_node],
effective_paddings_x[input_node], effective_paddings_y[input_node])

View File

@ -0,0 +1,225 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for receptive_fields module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib import slim
from tensorflow.contrib.receptive_field.python.util import receptive_field
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn
from tensorflow.python.platform import test
def create_test_network_1():
"""Aligned network for test.
The graph corresponds to the example from the second figure in
go/cnn-rf-computation#arbitrary-computation-graphs
Returns:
g: Tensorflow graph object (Graph proto).
"""
g = ops.Graph()
with g.as_default():
# An 8x8 test image.
x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
# Left branch.
l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID')
# Right branch.
l2_pad = array_ops.pad(x, [[0, 0], [1, 0], [1, 0], [0, 0]])
l2 = slim.conv2d(l2_pad, 1, [3, 3], stride=2, scope='L2', padding='VALID')
l3 = slim.conv2d(l2, 1, [1, 1], stride=2, scope='L3', padding='VALID')
# Addition.
nn.relu(l1 + l3, name='output')
return g
def create_test_network_2():
"""Aligned network for test.
The graph corresponds to a variation to the example from the second figure in
go/cnn-rf-computation#arbitrary-computation-graphs. Layers 2 and 3 are changed
to max-pooling operations. Since the functionality is the same as convolution,
the network is aligned and the receptive field size is the same as from the
network created using create_test_network_1().
Returns:
g: Tensorflow graph object (Graph proto).
"""
g = ops.Graph()
with g.as_default():
# An 8x8 test image.
x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
# Left branch.
l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID')
# Right branch.
l2_pad = array_ops.pad(x, [[0, 0], [1, 0], [1, 0], [0, 0]])
l2 = slim.max_pool2d(l2_pad, [3, 3], stride=2, scope='L2', padding='VALID')
l3 = slim.max_pool2d(l2, [1, 1], stride=2, scope='L3', padding='VALID')
# Addition.
nn.relu(l1 + l3, name='output')
return g
def create_test_network_3():
"""Misaligned network for test.
The graph corresponds to the example from the first figure in
go/cnn-rf-computation#arbitrary-computation-graphs
Returns:
g: Tensorflow graph object (Graph proto).
"""
g = ops.Graph()
with g.as_default():
# An 8x8 test image.
x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
# Left branch.
l1_pad = array_ops.pad(x, [[0, 0], [2, 1], [2, 1], [0, 0]])
l1 = slim.conv2d(l1_pad, 1, [5, 5], stride=2, scope='L1', padding='VALID')
# Right branch.
l2 = slim.conv2d(x, 1, [3, 3], stride=1, scope='L2', padding='VALID')
l3 = slim.conv2d(l2, 1, [3, 3], stride=1, scope='L3', padding='VALID')
# Addition.
nn.relu(l1 + l3, name='output')
return g
def create_test_network_4():
"""Misaligned network for test.
  The graph is a variation of the example from the second figure in
  go/cnn-rf-computation#arbitrary-computation-graphs: layer 2 uses 'SAME'
  padding, which makes the amount of padding depend on the input image
  dimensions. In this case the effective padding cannot be determined, so the
  utility cannot check the network alignment.

  Returns:
    g: TensorFlow graph object.
"""
g = ops.Graph()
with g.as_default():
# An 8x8 test image.
x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
# Left branch.
l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID')
# Right branch.
l2 = slim.conv2d(x, 1, [3, 3], stride=2, scope='L2', padding='SAME')
l3 = slim.conv2d(l2, 1, [1, 1], stride=2, scope='L3', padding='VALID')
# Addition.
nn.relu(l1 + l3, name='output')
return g
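# For create_test_network_4(), kernel sizes and strides still compose even
# though alignment cannot be verified, so testComputeRFFromGraphDefUnaligned2
# expects receptive field 3x3 and stride 4, with effective padding None.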
def create_test_network_5():
"""Single-path network for testing non-square kernels.
The graph is similar to the right branch of the graph from
create_test_network_1(), except that the kernel sizes are changed to be
  non-square.

  Returns:
    g: TensorFlow graph object.
"""
g = ops.Graph()
with g.as_default():
# An 8x8 test image.
x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image')
# Two convolutional layers, where the first one has non-square kernel.
l1 = slim.conv2d(x, 1, [3, 5], stride=2, scope='L1', padding='VALID')
l2 = slim.conv2d(l1, 1, [3, 1], stride=2, scope='L2', padding='VALID')
# ReLU.
nn.relu(l2, name='output')
return g
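# A hand derivation for create_test_network_5(), matching
# testComputeRFFromGraphDefNonSquareRF (slim kernel sizes are [height, width],
# so x is the width axis and y the height axis):
#   rf_x = 5 + (1 - 1) * 2 = 5,  rf_y = 3 + (3 - 1) * 2 = 7
#   effective stride = 2 * 2 = 4 in both axes; padding 0 (all convs are VALID).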
class RfUtilsTest(test.TestCase):
def testComputeRFFromGraphDefAligned(self):
graph_def = create_test_network_1().as_graph_def()
input_node = 'input_image'
output_node = 'output'
(receptive_field_x, receptive_field_y, effective_stride_x,
effective_stride_y, effective_padding_x, effective_padding_y) = (
receptive_field.compute_receptive_field_from_graph_def(
graph_def, input_node, output_node))
self.assertEqual(receptive_field_x, 3)
self.assertEqual(receptive_field_y, 3)
self.assertEqual(effective_stride_x, 4)
self.assertEqual(effective_stride_y, 4)
self.assertEqual(effective_padding_x, 1)
self.assertEqual(effective_padding_y, 1)
def testComputeRFFromGraphDefAligned2(self):
graph_def = create_test_network_2().as_graph_def()
input_node = 'input_image'
output_node = 'output'
(receptive_field_x, receptive_field_y, effective_stride_x,
effective_stride_y, effective_padding_x, effective_padding_y) = (
receptive_field.compute_receptive_field_from_graph_def(
graph_def, input_node, output_node))
self.assertEqual(receptive_field_x, 3)
self.assertEqual(receptive_field_y, 3)
self.assertEqual(effective_stride_x, 4)
self.assertEqual(effective_stride_y, 4)
self.assertEqual(effective_padding_x, 1)
self.assertEqual(effective_padding_y, 1)
def testComputeRFFromGraphDefUnaligned(self):
graph_def = create_test_network_3().as_graph_def()
input_node = 'input_image'
output_node = 'output'
with self.assertRaises(ValueError):
receptive_field.compute_receptive_field_from_graph_def(
graph_def, input_node, output_node)
def testComputeRFFromGraphDefUnaligned2(self):
graph_def = create_test_network_4().as_graph_def()
input_node = 'input_image'
output_node = 'output'
(receptive_field_x, receptive_field_y, effective_stride_x,
effective_stride_y, effective_padding_x, effective_padding_y) = (
receptive_field.compute_receptive_field_from_graph_def(
graph_def, input_node, output_node))
self.assertEqual(receptive_field_x, 3)
self.assertEqual(receptive_field_y, 3)
self.assertEqual(effective_stride_x, 4)
self.assertEqual(effective_stride_y, 4)
self.assertEqual(effective_padding_x, None)
self.assertEqual(effective_padding_y, None)
def testComputeRFFromGraphDefNonSquareRF(self):
graph_def = create_test_network_5().as_graph_def()
input_node = 'input_image'
output_node = 'output'
(receptive_field_x, receptive_field_y, effective_stride_x,
effective_stride_y, effective_padding_x, effective_padding_y) = (
receptive_field.compute_receptive_field_from_graph_def(
graph_def, input_node, output_node))
self.assertEqual(receptive_field_x, 5)
self.assertEqual(receptive_field_y, 7)
self.assertEqual(effective_stride_x, 4)
self.assertEqual(effective_stride_y, 4)
self.assertEqual(effective_padding_x, 0)
self.assertEqual(effective_padding_y, 0)
if __name__ == '__main__':
test.main()

View File

@@ -0,0 +1,59 @@
licenses(["notice"]) # Apache 2.0
exports_files([
"LICENSE",
])
load(
"//tensorflow:tensorflow.bzl",
"py_test",
"tf_gen_op_wrapper_py",
)
tf_gen_op_wrapper_py(
name = "gen_summary_ops",
out = "gen_summary_ops.py",
deps = ["//tensorflow/core:summary_ops_op_lib"],
)
py_test(
name = "summary_ops_test",
srcs = ["summary_ops_test.py"],
srcs_version = "PY2AND3",
deps = [
":summary_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform",
"//tensorflow/python:training",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
],
)
py_library(
name = "summary_ops",
srcs = ["summary_ops.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
":gen_summary_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:summary_op_util",
"//tensorflow/python:training",
"//tensorflow/python/eager:context",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@@ -0,0 +1,159 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations to emit summaries."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.summary import gen_summary_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import summary_op_util
from tensorflow.python.training import training_util
# Name for a collection which is expected to have at most a single boolean
# Tensor. If this tensor is True the summary ops will record summaries.
_SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries"
def should_record_summaries():
"""Returns boolean Tensor which is true if summaries should be recorded."""
should_record_collection = ops.get_collection(_SHOULD_RECORD_SUMMARIES_NAME)
if not should_record_collection:
return constant_op.constant(False)
if len(should_record_collection) != 1:
raise ValueError(
"More than one tensor specified for whether summaries "
"should be recorded: %s" % should_record_collection)
return should_record_collection[0]
# TODO(apassos) consider how to handle local step here.
def record_summaries_every_n_global_steps(n):
"""Sets the should_record_summaries Tensor to true if global_step % n == 0."""
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
collection_ref[:] = [training_util.get_global_step() % n == 0]
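# For example, record_summaries_every_n_global_steps(100) leaves a Tensor in
# the collection that evaluates to True only when global_step % 100 == 0, so
# summary ops built afterwards record once every 100 steps.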
def always_record_summaries():
"""Sets the should_record_summaries Tensor to always true."""
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
collection_ref[:] = [constant_op.constant(True)]
def never_record_summaries():
"""Sets the should_record_summaries Tensor to always false."""
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
collection_ref[:] = [constant_op.constant(False)]
def create_summary_file_writer(logdir,
max_queue=None,
flush_secs=None,
filename_suffix=None):
"""Creates a summary file writer in the current context."""
if max_queue is None:
max_queue = constant_op.constant(10)
if flush_secs is None:
flush_secs = constant_op.constant(120)
if filename_suffix is None:
filename_suffix = constant_op.constant("")
resource = gen_summary_ops.summary_writer()
gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue,
flush_secs, filename_suffix)
context.context().summary_writer_resource = resource
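# A minimal end-to-end sketch, mirroring summary_ops_test.py in this change
# rather than any public API contract (the logdir path is a stand-in):
#
#   training_util.get_or_create_global_step()
#   create_summary_file_writer("/tmp/logdir", max_queue=0)
#   always_record_summaries()
#   scalar("accuracy", 0.9)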
def _nothing():
"""Convenient else branch for when summaries do not record."""
return
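# Each summary op below follows the same shape: a `record` closure writes
# through the current context's summary writer resource at the current global
# step, and control_flow_ops.cond gates it on should_record_summaries(), with
# _nothing as the no-op branch.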
def generic(name, tensor, metadata, family=None):
"""Writes a tensor summary if possible."""
def record():
with summary_op_util.summary_scope(
name, family, values=[tensor]) as (tag, scope):
gen_summary_ops.write_summary(context.context().summary_writer_resource,
training_util.get_global_step(), tensor,
tag, metadata, name=scope)
return control_flow_ops.cond(should_record_summaries(), record, _nothing)
def scalar(name, tensor, family=None):
"""Writes a scalar summary if possible."""
def record():
with summary_op_util.summary_scope(
name, family, values=[tensor]) as (tag, scope):
gen_summary_ops.write_scalar_summary(
context.context().summary_writer_resource,
training_util.get_global_step(), tag, tensor, name=scope)
return control_flow_ops.cond(should_record_summaries(), record, _nothing)
def histogram(name, tensor, family=None):
"""Writes a histogram summary if possible."""
def record():
with summary_op_util.summary_scope(
name, family, values=[tensor]) as (tag, scope):
gen_summary_ops.write_histogram_summary(
context.context().summary_writer_resource,
training_util.get_global_step(), tag, tensor, name=scope)
return control_flow_ops.cond(should_record_summaries(), record, _nothing)
def image(name, tensor, bad_color=None, max_images=3, family=None):
"""Writes an image summary if possible."""
def record():
    # Default the bad-pixel color only when the caller did not pass one;
    # otherwise `bad_color_` would be unbound below.
    if bad_color is None:
      bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8)
    else:
      bad_color_ = bad_color
with summary_op_util.summary_scope(
name, family, values=[tensor]) as (tag, scope):
gen_summary_ops.write_image_summary(
context.context().summary_writer_resource,
training_util.get_global_step(), tag, tensor, bad_color_, max_images,
name=scope)
return control_flow_ops.cond(should_record_summaries(), record, _nothing)
def audio(name, tensor, sample_rate, max_outputs, family=None):
"""Writes an audio summary if possible."""
def record():
with summary_op_util.summary_scope(
name, family, values=[tensor]) as (tag, scope):
gen_summary_ops.write_audio_summary(
context.context().summary_writer_resource,
training_util.get_global_step(),
tag,
tensor,
sample_rate=sample_rate,
max_outputs=max_outputs,
name=scope)
return control_flow_ops.cond(should_record_summaries(), record, _nothing)

View File

@@ -0,0 +1,52 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tempfile
from tensorflow.contrib.summary import summary_ops
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
from tensorflow.python.platform import gfile
from tensorflow.python.training import training_util
class TargetTest(test_util.TensorFlowTestCase):
def testShouldRecordSummary(self):
self.assertFalse(summary_ops.should_record_summaries().numpy())
summary_ops.always_record_summaries()
self.assertTrue(summary_ops.should_record_summaries().numpy())
def testSummaryOps(self):
training_util.get_or_create_global_step()
logdir = tempfile.mkdtemp()
summary_ops.create_summary_file_writer(logdir, max_queue=0)
summary_ops.always_record_summaries()
summary_ops.generic('tensor', 1, '')
summary_ops.scalar('scalar', 2.0)
summary_ops.histogram('histogram', [1.0])
summary_ops.image('image', [[[[1.0]]]])
summary_ops.audio('audio', [[1.0]], 1.0, 1)
    # The actual behavior of the ops is covered by the C++ tests; here we only
    # verify that we are calling them correctly.
self.assertTrue(gfile.Exists(logdir))
if __name__ == '__main__':
test.main()

Some files were not shown because too many files have changed in this diff.