Merge branch 'master' into expose-symbols-osx

Fritz Obermeyer 2017-09-06 15:26:58 -07:00
commit 6d7ac549b6
315 changed files with 15400 additions and 5266 deletions

View File

@ -685,10 +685,12 @@ def set_tf_cudnn_version(environ_cp):
ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
cudnn_path_from_ldconfig = run_shell([ldconfig_bin, '-p'])
cudnn_path_from_ldconfig = re.search('.*libcudnn.so .* => (.*)',
cudnn_path_from_ldconfig).group(1)
if os.path.exists('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)):
cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig)
break
cudnn_path_from_ldconfig)
if cudnn_path_from_ldconfig:
cudnn_path_from_ldconfig = cudnn_path_from_ldconfig.group(1)
if os.path.exists('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)):
cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig)
break
# Reset and Retry
print(

View File

@ -296,6 +296,7 @@ filegroup(
"//tensorflow/contrib/ffmpeg/default:all_files",
"//tensorflow/contrib/framework:all_files",
"//tensorflow/contrib/fused_conv:all_files",
"//tensorflow/contrib/gan:all_files",
"//tensorflow/contrib/graph_editor:all_files",
"//tensorflow/contrib/grid_rnn:all_files",
"//tensorflow/contrib/hooks:all_files",
@ -323,6 +324,7 @@ filegroup(
"//tensorflow/contrib/nn:all_files",
"//tensorflow/contrib/opt:all_files",
"//tensorflow/contrib/predictor:all_files",
"//tensorflow/contrib/receptive_field:all_files",
"//tensorflow/contrib/reduce_slice_ops:all_files",
"//tensorflow/contrib/remote_fused_graph/pylib:all_files",
"//tensorflow/contrib/resampler:all_files",
@ -342,6 +344,7 @@ filegroup(
"//tensorflow/contrib/staging:all_files",
"//tensorflow/contrib/stat_summarizer:all_files",
"//tensorflow/contrib/stateless:all_files",
"//tensorflow/contrib/summary:all_files",
"//tensorflow/contrib/tensor_forest:all_files",
"//tensorflow/contrib/tensor_forest/hybrid:all_files",
"//tensorflow/contrib/tensor_forest/kernels/v4:all_files",

View File

@ -45,8 +45,13 @@ tf_cuda_library(
tf_cuda_library(
name = "c_api",
srcs = ["c_api.cc"],
hdrs = ["c_api.h"],
srcs = [
"c_api.cc",
"c_api_function.cc",
],
hdrs = [
"c_api.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = select({
@ -157,6 +162,21 @@ tf_cc_test(
],
)
tf_cc_test(
name = "c_api_function_test",
size = "small",
srcs = ["c_api_function_test.cc"],
deps = [
":c_api",
":c_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cc_test(
name = "while_loop_test",
size = "small",

View File

@ -165,22 +165,6 @@ void deallocate_buffer(void* data, size_t len, void* arg) {
tensorflow::cpu_allocator()->DeallocateRaw(data);
}
Status MessageToBuffer(const tensorflow::protobuf::Message& in,
TF_Buffer* out) {
if (out->data != nullptr) {
return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
}
const auto proto_size = in.ByteSizeLong();
void* buf = tensorflow::port::Malloc(proto_size);
in.SerializeToArray(buf, proto_size);
out->data = buf;
out->length = proto_size;
out->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
return Status::OK();
}
} // namespace
TF_Tensor::~TF_Tensor() { buffer->Unref(); }
@ -559,6 +543,27 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
dimvec.size(), base, size, DeleteArray, base);
}
Status MessageToBuffer(const tensorflow::protobuf::Message& in,
TF_Buffer* out) {
if (out->data != nullptr) {
return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
}
const size_t proto_size = in.ByteSizeLong();
void* buf = tensorflow::port::Malloc(proto_size);
if (buf == nullptr) {
return tensorflow::errors::ResourceExhausted(
"Failed to allocate memory to serialize message of type '",
in.GetTypeName(), "' and size ", proto_size);
}
in.SerializeToArray(buf, proto_size);
out->data = buf;
out->length = proto_size;
out->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
return Status::OK();
}
// Helpers for loading a TensorFlow plugin (a .so file).
Status LoadLibrary(const char* library_filename, void** result,
const void** buf, size_t* len);

View File

@ -357,6 +357,14 @@ typedef struct TF_Output {
int index; // The index of the output within oper.
} TF_Output;
// TF_Function is a grouping of operations with defined inputs and outputs.
// Once created and added to graphs, functions can be invoked by creating an
// operation whose operation type matches the function name.
typedef struct TF_Function TF_Function;
// Function definition options. TODO(iga): Define and implement
typedef struct TF_FunctionOptions TF_FunctionOptions;
// Sets the shape of the Tensor referenced by `output` in `graph` to
// the shape described by `dims` and `num_dims`.
//
@ -914,6 +922,15 @@ TF_CAPI_EXPORT extern void TF_GraphImportGraphDef(
TF_Graph* graph, const TF_Buffer* graph_def,
const TF_ImportGraphDefOptions* options, TF_Status* status);
// Add `function` to graph `g`. Once `function` is added to `g`,
// it can be called by creating an operation using the function's name.
//
// If successful, status is set to OK and function is added to g.
// Otherwise, status is set to the encountered error and g is unmodified.
TF_CAPI_EXPORT extern void TF_GraphAddFunction(TF_Graph* g,
const TF_Function* function,
TF_Status* status);
// Note: The following function may fail on very large protos in the future.
TF_CAPI_EXPORT extern void TF_OperationToNodeDef(TF_Operation* oper,
@ -1001,6 +1018,105 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny,
TF_Output* x, int nx, TF_Output* dx,
TF_Status* status, TF_Output* dy);
// Create a TF_Function from a TF_Graph
//
// Params:
// fn_body - the graph whose operations (or subset of whose operations) will be
// converted to TF_Function.
// fn_name - the name of the new TF_Function. Should match the operation
// name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]* and be distinct
// from other operation names (at least those registered in graphs
// where this function will be used).
// TODO(iga): Allow null in here and have C API come up with
// a unique name with high probability (similarly to
// _create_hash_str in function.py)
// num_opers - `num_opers` contains the number of elements in the `opers` array
// or a special value of -1 meaning that no array is given.
// The distinction between an empty array of operations and no
// array of operations is necessary to distinguish the case of
// creating a function with no body (e.g. identity or permutation)
// and the case of creating a function whose body contains all
// the nodes in the graph (except for the automatic skipping, see
// below).
// opers - Array of operations to become the body of the function or null.
// - If no array is given (`num_opers` = -1), all the
// operations in `fn_body` will become part of the function
// except operations referenced in `inputs`. These operations
// must have a single output (these operations are typically
// placeholders created for the sole purpose of representing
// an input. We can relax this constraint if there are
// compelling use cases).
// - If an array is given (`num_opers` >= 0), all operations
// in it will become part of the function. In particular, no
// automatic skipping of dummy input operations is performed.
// ninputs - number of elements in `inputs` array
// inputs - array of TF_Outputs that specify the inputs to the function.
// If `ninputs` is zero (the function takes no inputs), `inputs`
// can be null. The names used for function inputs are normalized
// names of the operations (usually placeholders) pointed to by
// `inputs`. These operation names should start with a letter.
// Normalization will convert all letters to lowercase and
// non-alphanumeric characters to '_' to make resulting names match
// the "[a-z][a-z0-9_]*" pattern for operation argument names.
// `inputs` cannot contain the same tensor twice.
// noutputs - number of elements in `outputs` array
// outputs - array of TF_Outputs that specify the outputs of the function.
// If `noutputs` is zero (the function returns no outputs), `outputs`
// can be null. `outputs` can contain the same tensor more than once.
// output_names - The names of the function's outputs. `output_names` array
// must either have the same length as `outputs`
// (i.e. `noutputs`) or be null. In the former case,
// the names should match the regular expression for ArgDef
// names - "[a-z][a-z0-9_]*". In the latter case,
// names for outputs will be generated automatically.
// opts - various options for the function, e.g. XLA's inlining control.
// status - Set to OK on success and an appropriate error on failure.
//
// Note that when the same TF_Output is listed as both an input and an output,
// the corresponding function's output will be equal to this input,
// instead of the original node's output.
//
// Callers must also satisfy the following constraints:
// - `inputs` cannot refer to TF_Outputs within a control flow context. For
// example, one cannot use the output of "switch" node as input.
// - No TF_Output of a function (inside any of `inputs`, `outputs`, `fn_body`)
// is allowed to have a reference type. Reference types are not exposed
// through C API and are being deprecated.
// - Every node in the function's body must have all of its inputs (including
// control inputs). In other words, for every node in the body, each input
// must be either listed in `inputs` or must come from another node in
// the body. In particular, it is an error to have a control edge going from
// a node outside of the body into a node in the body. This applies to control
// edges going from nodes referenced in `inputs` to nodes in the body when
// the former nodes are not in the body (automatically skipped or not
// included in explicitly specified body).
//
// Returns:
// On success, a newly created TF_Function instance. It must be deleted by
// calling TF_DeleteFunction.
//
// On failure, null.
//
// TODO(iga): Add input_names argument and get output_names working (they are
// currently ignored)
TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction(
const TF_Graph* fn_body, const char* fn_name, int num_opers,
const TF_Operation* const* opers, int ninputs, const TF_Output* inputs,
int noutputs, const TF_Output* outputs, const char* const* output_names,
const TF_FunctionOptions* opts, TF_Status* status);
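For reference, a minimal usage sketch of the API documented above (not part of this commit; error checking is abbreviated and tensorflow/c/c_api.h is assumed to be included). It builds a one-op graph and converts it into a function whose single input is the placeholder and whose single output is the Neg result:

void ExampleGraphToFunction() {
  TF_Status* status = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // Placeholder "feed": becomes the function's input argument.
  TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", "feed");
  TF_SetAttrType(desc, "dtype", TF_INT32);
  TF_Operation* feed = TF_FinishOperation(desc, status);

  // Neg "neg" consuming the placeholder: becomes the body and the output.
  desc = TF_NewOperation(graph, "Neg", "neg");
  TF_Output neg_input = {feed, 0};
  TF_AddInput(desc, neg_input);
  TF_Operation* neg = TF_FinishOperation(desc, status);

  // num_opers = -1: every op except those referenced in `inputs` (here, the
  // placeholder) becomes part of the function body.
  TF_Output inputs[] = {{feed, 0}};
  TF_Output outputs[] = {{neg, 0}};
  TF_Function* fn = TF_GraphToFunction(
      graph, "my_neg", /*num_opers=*/-1, /*opers=*/nullptr, 1, inputs, 1,
      outputs, /*output_names=*/nullptr, /*opts=*/nullptr, status);
  if (TF_GetCode(status) == TF_OK) {
    // `fn` could now be added to another graph with TF_GraphAddFunction and
    // invoked there by creating an operation of type "my_neg".
    TF_DeleteFunction(fn);
  }
  TF_DeleteGraph(graph);
  TF_DeleteStatus(status);
}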
// Write out a serialized representation of `func` (as a FunctionDef protocol
// message) to `output_func_def` (allocated by TF_NewBuffer()).
// `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer()
// is called.
//
// May fail on very large graphs in the future.
TF_CAPI_EXPORT extern void TF_FunctionToFunctionDef(TF_Function* func,
TF_Buffer* output_func_def,
TF_Status* status);
TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function*);
// TODO(josh11b): Register OpDef, available to all operations added
// to this graph.

View File

@ -0,0 +1,496 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/c_api_internal.h"
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace {
// Class that maintains a one-to-one original node name -> new node name
// mapping. We normalize the names used as input and output arguments to match
// regexp "[a-z][a-z0-9_]*" specified in definition of ArgDef.name.
// Once we rename them, we risk creating a name collision with the other
// node names, so if necessary we add a suffix to make
// names unique. If we have an input named "A" and a node in the function
// body named "a", they will be renamed to "a" and "a_0".
class NodeNameMapping {
public:
NodeNameMapping() = default;
// Normalize the input/output name and make it unique.
string GetIOName(const string& name);
// Make the node name unique.
string Uniquify(const string& name);
// Look up how a node name was previously normalized/uniquified.
// Returns empty if name was never seen.
string Lookup(const string& name) const;
private:
string UniquifyHelper(const string& name) const;
static string Normalize(string name);
// The normalized/uniquified names already used as
// input names (in signature), output names (in signature), and node names
// (in node_def).
// This is a superset of values in name_mapping_.
std::unordered_set<string> used_names_;
// Mapping from the original node name in the graph to its normalized
// and uniquified version.
std::unordered_map<string, string> name_mapping_;
};
string NodeNameMapping::Normalize(string name) {
// Convert letters to lowercase and non-alphanumeric characters to '_'.
if (name.empty()) return "unknown";
const int n = name.size();
for (int i = 0; i < n; ++i) {
char c = name[i];
if (isalnum(c)) {
if (isupper(c)) {
name[i] = tolower(c);
}
} else {
name[i] = '_';
}
}
// Find the first letter and start with it.
int i = 0;
for (; i < n; ++i) {
if (isalpha(name[i])) break;
}
// Return "unknown" if none of the name's chars were letters.
return i == n ? "unknown" : name.substr(i);
}
string NodeNameMapping::UniquifyHelper(const string& name) const {
// If the name hasn't been used yet, use it as-is.
if (used_names_.find(name) == used_names_.end()) return name;
// Add a suffix to name to make it unique.
for (int i = 0;; ++i) {
const string candidate = strings::StrCat(name, "_", i);
if (used_names_.find(candidate) == used_names_.end()) return candidate;
}
}
string NodeNameMapping::GetIOName(const string& name) {
const string& input_name = UniquifyHelper(Normalize(name));
// Record that we used this name, but don't add it to name_mapping_
// since this name is not for a node.
used_names_.insert(input_name);
return input_name;
}
string NodeNameMapping::Uniquify(const string& name) {
const string uniqued = UniquifyHelper(name);
name_mapping_[name] = uniqued;
used_names_.insert(uniqued);
return uniqued;
}
string NodeNameMapping::Lookup(const string& name) const {
const auto iter = name_mapping_.find(name);
if (iter == name_mapping_.end()) return string();
return iter->second;
}
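The renaming semantics described in the class comment can be illustrated with a short, hypothetical sequence of calls (the class lives in an anonymous namespace, so this is only a sketch of its behavior, not code that compiles against it):

// An input arg named "A" and a body node named "a" end up as "a" and "a_0".
NodeNameMapping mapping;
string arg_name = mapping.GetIOName("A");   // "A" is normalized to "a" and marked used.
string node_name = mapping.Uniquify("a");   // "a" is taken, so the node becomes "a_0".
string mapped = mapping.Lookup("a");        // Lookup of the original node name yields "a_0".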
Status ValidateNoRefOutputs(const Node* node) {
for (int i = 0; i < node->num_outputs(); ++i) {
const DataType& dt = node->output_type(i);
if (IsRefType(dt)) {
return errors::InvalidArgument("Output ", i, " of node '", node->name(),
"' has a reference "
"type ",
DataTypeString(dt));
}
}
return Status::OK();
}
Status FillFunctionBody(
const string& fn_name, const NodeNameMapping& node_names,
const std::vector<const Node*>& body_nodes,
const std::unordered_map<string, string>& tensor_renaming,
FunctionDef* fdef) {
std::vector<const Edge*> in_edges;
std::vector<const Edge*> control_edges;
for (const Node* node : body_nodes) {
NodeDef* node_def = fdef->add_node_def();
// First, copy the node_def as is. We will patch it next.
*node_def = node->def();
if (!node->assigned_device_name().empty()) {
node_def->set_device(node->assigned_device_name());
}
node_def->set_name(node_names.Lookup(node->name()));
// Input names must be set based on nested names in tensor_renaming.
// Clear the flat input names we got from the original node_def
// from the graph.
node_def->clear_input();
// Collect regular and control inputs. Regular inputs are indexed
// by the index at which they come into the `node`. Control inputs
// don't follow any order.
in_edges.clear();
in_edges.resize(node->num_inputs(), nullptr);
control_edges.clear();
for (const Edge* edge : node->in_edges()) {
if (edge->src()->IsSource()) continue;
if (edge->IsControlEdge()) {
control_edges.push_back(edge);
} else {
in_edges[edge->dst_input()] = edge;
}
}
// Add regular inputs.
for (size_t i = 0; i < in_edges.size(); ++i) {
const Edge* edge = in_edges[i];
string original_input_name;
if (edge == nullptr) {
// A backedge might not appear as a regular Edge, but be only present
// in the node_def. Such edges are referred to as requested_inputs().
if (i >= node->requested_inputs().size()) {
return errors::InvalidArgument(
"Graph to be converted to function appears to be malformed. ",
"Node ", node->name(), " is missing input edge ", i);
}
original_input_name =
ParseTensorName(node->requested_inputs()[i]).ToString();
} else {
original_input_name =
strings::StrCat(edge->src()->name(), ":", edge->src_output());
}
const auto iter = tensor_renaming.find(original_input_name);
if (iter == tensor_renaming.end()) {
return errors::InvalidArgument(
"Input ", i, ", '", original_input_name, "', of node '",
node->name(), "' in function '", fn_name,
"' is not available. You might need to include it in inputs "
"or include its source node in the body");
}
node_def->add_input(iter->second);
}
// Add control inputs.
for (const Edge* edge : control_edges) {
// Add this control input only if the src node is in the body.
const string normalized = node_names.Lookup(edge->src()->name());
// If we did not find a name for the source of control edge, this
// source must be outside of the body. Raise an error.
if (normalized.empty()) {
return errors::InvalidArgument(
"The source of control edge ", edge->DebugString(),
" is not in the body. Encountered while creating function '",
fn_name, "'");
}
node_def->add_input(strings::StrCat("^", normalized));
}
}
return Status::OK();
}
// Graph to FunctionDef conversion. This code is closely modeled on the Python
// code in third_party/tensorflow/python/framework/function.py.
Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
const std::vector<const Node*>& body_nodes,
const std::vector<OutputTensor>& inputs,
const std::vector<OutputTensor>& outputs,
const std::vector<string>& output_names,
FunctionDef* fdef) {
fdef->mutable_signature()->set_name(fn_name);
// Keep track of names we used and how we normalized them.
NodeNameMapping node_names;
// Mapping from original names of tensors (i.e. "<node_name>:<idx>") to the
// name we used in the function:
// - For input tensors:
// {flat_tensor_name -> normalized_name_of_src_node}
// e.g. {In:3 -> in}
// - For tensors produced by nodes in function's body:
// {flat_tensor_name -> nested_tensor_name}
// e.g. {Add:3 -> add_0:z:1}
std::unordered_map<string, string> tensor_renaming;
// Fill inputs in function's signature.
for (size_t i = 0; i < inputs.size(); ++i) {
const Node* node = inputs[i].node;
int idx = inputs[i].index;
OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg();
argdef->set_type(node->output_type(idx));
const string& input_name = node_names.GetIOName(node->name());
argdef->set_name(input_name);
tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
}
// Fill outputs in function's signature.
for (size_t i = 0; i < outputs.size(); ++i) {
const Node* node = outputs[i].node;
int idx = outputs[i].index;
OpDef::ArgDef* argdef = fdef->mutable_signature()->add_output_arg();
argdef->set_type(node->output_type(idx));
argdef->set_name(node_names.GetIOName(node->name()));
}
// Populate tensor_renaming and node_names.
// Generate the new output names for every node in the function.
// The NodeDefs in FunctionDefs use a different naming scheme for
// their inputs than the NodeDefs in a graph (see the comment for
// FunctionDef.node_def in function.proto). We do the
// graph tensor name -> function tensor name conversion for every
// possible input (i.e. every node's outputs) and store the result
// in tensor_renaming.
for (const Node* node : body_nodes) {
// Make sure node_name does not collide with an input or output name.
const string& node_name = node_names.Uniquify(node->name());
// For each output_arg in the op_def, the output_ranges
// map will have [start, end] range of indices that this arg produces
// among all the output tensors of this op.
NameRangeMap output_ranges;
TF_RETURN_IF_ERROR(
NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges));
for (const auto& output : output_ranges) {
const string& output_name = output.first;
int index_start = output.second.first;
int index_end = output.second.second;
for (int i = index_start; i < index_end; ++i) {
const string& original_name = strings::StrCat(node->name(), ":", i);
const string& new_name =
strings::StrCat(node_name, ":", output_name, ":", i - index_start);
// Record the mapping if this tensor is not already mapped.
// Tensor can be already mapped if it is used as an input.
if (tensor_renaming.find(original_name) == tensor_renaming.end()) {
tensor_renaming[original_name] = new_name;
}
}
}
}
TF_RETURN_IF_ERROR(
FillFunctionBody(fn_name, node_names, body_nodes, tensor_renaming, fdef));
// Remap return values.
for (int r = 0; r < fdef->signature().output_arg_size(); ++r) {
const string& ret_name = fdef->signature().output_arg(r).name();
// We convert this flat tensor name to the nested value
// (e.g. `add:z:1`) that we stored in tensor_renaming.
const string& return_value =
strings::StrCat(outputs[r].node->name(), ":", outputs[r].index);
const auto iter = tensor_renaming.find(return_value);
if (iter == tensor_renaming.end()) {
return errors::InvalidArgument(
"TF_Output ", return_value, " is neither in the function body ",
"nor among function inputs. Encountered while creating function '",
fn_name, "'");
}
(*fdef->mutable_ret())[ret_name] = iter->second;
}
return Status::OK();
}
// Converts `ninputs` and `inputs` into `inputs_tensors` and `input_nodes` and
// does various checks while doing so. `input_nodes` will contain the same
// information as input_tensors just in a different structure to make
// following processing easier. TODO(iga): Simplify this nested structure.
Status ProcessInputs(
const TF_Graph* fn_body, const char* fn_name, int ninputs,
const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
std::unordered_map<const Node*, std::vector<int>>* input_nodes)
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
input_tensors->reserve(ninputs);
for (int i = 0; i < ninputs; ++i) {
const Node& node = inputs[i].oper->node;
int idx = inputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
fn_body->graph.IsValidOutputTensor(&node, idx),
"Encountered while processing input ", i, " into function '", fn_name,
"'");
TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(&node),
"Encountered while processing input ", i,
" into function '", fn_name, "'");
input_tensors->emplace_back(&node, idx);
const auto& iter = input_nodes->find(&node);
if (iter == input_nodes->end()) {
input_nodes->insert({&node, {idx}});
} else {
auto& indices = iter->second;
if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
return errors::InvalidArgument(
"TF_Output ", node.name(), ":", idx,
" appears more than once in the input list");
}
indices.push_back(idx);
}
}
return Status::OK();
}
// Converts `noutputs` and `outputs` into `outputs_tensors` and does various
// checks while doing so.
Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
int noutputs, const TF_Output* outputs,
std::vector<OutputTensor>* output_tensors)
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
output_tensors->reserve(noutputs);
for (int i = 0; i < noutputs; ++i) {
const Node& node = outputs[i].oper->node;
int idx = outputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
fn_body->graph.IsValidOutputTensor(&node, idx),
"Encountered while processing output ", i, " from function '", fn_name,
"'");
output_tensors->emplace_back(&node, idx);
}
return Status::OK();
}
// Populates `body_nodes` with the nodes that will become function's body.
// Performs various checks.
Status ComputeBodyNodes(
const TF_Graph* fn_body, const char* fn_name, int num_opers,
const TF_Operation* const* opers,
const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
std::vector<const Node*>* body_nodes)
EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
if (num_opers == -1) {
for (const Node* node : fn_body->graph.op_nodes()) {
const auto& iter = input_nodes.find(node);
if (iter == input_nodes.end()) {
// This node is not referenced in inputs. Add it to the body.
TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(node),
"Encountered while creating function '",
fn_name, "'");
body_nodes->push_back(node);
} else {
// This node is referenced in inputs. Currently, we place an
// artificial restriction and require that when num_opers=-1, such
// nodes must have a single output.
if (node->num_outputs() != 1) {
return errors::InvalidArgument(
"When `num_opers` is set to -1, nodes referenced in `inputs` "
"must have a single output. Node ",
node->name(), " has ", node->num_outputs(),
" outputs. Encountered while creating function '", fn_name, "'");
}
}
}
} else {
body_nodes->reserve(num_opers);
for (int i = 0; i < num_opers; ++i) {
const Node* node = &opers[i]->node;
TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(node),
"Encountered while creating function '",
fn_name, "'");
body_nodes->push_back(node);
}
}
return Status::OK();
}
} // anonymous namespace
} // namespace tensorflow
using tensorflow::Node;
using tensorflow::string;
TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
int num_opers, const TF_Operation* const* opers,
int ninputs, const TF_Output* inputs,
int noutputs, const TF_Output* outputs,
const char* const* output_names,
const TF_FunctionOptions* opts,
TF_Status* status) {
tensorflow::mutex_lock l(*const_cast<tensorflow::mutex*>(&fn_body->mu));
// Process inputs.
std::vector<tensorflow::OutputTensor> input_tensors;
std::unordered_map<const Node*, std::vector<int>> input_nodes;
status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs,
&input_tensors, &input_nodes);
if (!status->status.ok()) return nullptr;
// Process outputs.
std::vector<tensorflow::OutputTensor> output_tensors;
status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs,
outputs, &output_tensors);
if (!status->status.ok()) return nullptr;
// Process output names.
std::vector<string> output_names_vec;
if (output_names) {
output_names_vec.reserve(noutputs);
for (int i = 0; i < noutputs; ++i) {
output_names_vec.push_back(string(output_names[i]));
}
}
// Compute body nodes.
std::vector<const Node*> body_nodes;
status->status = tensorflow::ComputeBodyNodes(
fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes);
if (!status->status.ok()) return nullptr;
// Do the actual function creation.
TF_Function* tf_function = new TF_Function();
status->status = tensorflow::GraphToFunctionDef(
fn_body->graph, fn_name, body_nodes, input_tensors, output_tensors,
output_names_vec, tf_function->fdef_lib.add_function());
if (!status->status.ok()) {
TF_DeleteFunction(tf_function);
return nullptr;
}
return tf_function;
}
void TF_GraphAddFunction(TF_Graph* g, const TF_Function* function,
TF_Status* status) {
tensorflow::mutex_lock l(g->mu);
// At the moment, we have only one function and no gradients in fdef_lib.
// This makes the following operation atomic.
// TODO(iga): Add an atomic version of AddFunctionLibrary when we support
// gradients
status->status = g->graph.AddFunctionLibrary(function->fdef_lib);
}
void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def,
TF_Status* status) {
DCHECK_EQ(1, func->fdef_lib.function_size());
status->status = MessageToBuffer(func->fdef_lib.function(0), output_func_def);
}
void TF_DeleteFunction(TF_Function* function) { delete function; }

File diff suppressed because it is too large.

View File

@ -130,6 +130,11 @@ struct TF_DeviceList {
std::vector<tensorflow::DeviceAttributes> response;
};
struct TF_Function {
// Currently contains a single function and no gradients
tensorflow::FunctionDefLibrary fdef_lib;
};
namespace tensorflow {
class TensorCApi {
@ -141,7 +146,12 @@ class TensorCApi {
}
};
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out);
} // end namespace tensorflow
#endif // TENSORFLOW_C_C_API_INTERNAL_H_

View File

@ -829,7 +829,7 @@ TEST(CAPI, ShapeInferenceError) {
TF_Operation* vec3 = Const(vec3_tensor.get(), graph, status, "vec3");
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Operation* add = Add(vec2, vec3, graph, status);
TF_Operation* add = AddNoCheck(vec2, vec3, graph, status);
ASSERT_NE(TF_OK, TF_GetCode(status));
ASSERT_TRUE(add == nullptr);

View File

@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
using tensorflow::GraphDef;
@ -36,6 +38,23 @@ TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
return t;
}
TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims,
const int32_t* values) {
int64_t num_values = 1;
for (int i = 0; i < num_dims; ++i) {
num_values *= dims[i];
}
TF_Tensor* t =
TF_AllocateTensor(TF_INT32, dims, num_dims, sizeof(int32_t) * num_values);
memcpy(TF_TensorData(t), values, sizeof(int32_t) * num_values);
return t;
}
TF_Tensor* Int32Tensor(const std::vector<int32_t>& values) {
int64_t dims = values.size();
return Int32Tensor(&dims, 1, values.data());
}
TF_Tensor* Int32Tensor(int32_t v) {
const int num_bytes = sizeof(int32_t);
int32_t* values = new int32_t[1];
@ -44,19 +63,40 @@ TF_Tensor* Int32Tensor(int32_t v) {
&Int32Deallocator, nullptr);
}
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) {
// All the *Helper methods are used as a workaround for the restriction that
// one cannot call ASSERT_* methods in non-void-returning functions (when
// exceptions are disabled during compilation).
void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
TF_SetAttrType(desc, "dtype", TF_INT32);
return TF_FinishOperation(desc, s);
*op = TF_FinishOperation(desc, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) {
TF_Operation* op;
PlaceholderHelper(graph, s, name, &op);
return op;
}
void ConstHelper(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name,
TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
TF_SetAttrTensor(desc, "value", t, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_SetAttrType(desc, "dtype", TF_TensorType(t));
*op = TF_FinishOperation(desc, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
const char* name) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
TF_SetAttrTensor(desc, "value", t, s);
if (TF_GetCode(s) != TF_OK) return nullptr;
TF_SetAttrType(desc, "dtype", TF_TensorType(t));
return TF_FinishOperation(desc, s);
TF_Operation* op;
ConstHelper(t, graph, s, name, &op);
return op;
}
TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
@ -65,11 +105,39 @@ TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
return Const(tensor.get(), graph, s, name);
}
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name) {
void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s,
const char* name, TF_Operation** op, bool check) {
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
TF_AddInputList(desc, add_inputs, 2);
*op = TF_FinishOperation(desc, s);
if (check) {
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
}
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name) {
TF_Operation* op;
AddHelper(l, r, graph, s, name, &op, true);
return op;
}
TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name) {
TF_Operation* op;
AddHelper(l, r, graph, s, name, &op, false);
return op;
}
TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r,
TF_Graph* graph, TF_Operation* ctrl_op,
TF_Status* s, const char* name) {
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
TF_AddInputList(desc, add_inputs, 2);
TF_AddControlInput(desc, ctrl_op);
return TF_FinishOperation(desc, s);
}
@ -81,11 +149,20 @@ TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
return TF_FinishOperation(desc, s);
}
TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) {
void NegHelper(TF_Operation* n, TF_Graph* graph, TF_Status* s,
TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", "neg");
TF_Output neg_input = {n, 0};
TF_AddInput(desc, neg_input);
return TF_FinishOperation(desc, s);
*op = TF_FinishOperation(desc, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) {
TF_Operation* op;
NegHelper(n, graph, s, &op);
return op;
}
TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
@ -96,6 +173,32 @@ TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
return TF_FinishOperation(desc, s);
}
void Split3Helper(TF_Operation* input, TF_Graph* graph, TF_Status* s,
const char* name, TF_Operation** op) {
TF_Operation* zero = ScalarConst(
0, graph, s, ::tensorflow::strings::StrCat(name, "_const0").c_str());
TF_OperationDescription* desc = TF_NewOperation(graph, "Split", name);
TF_AddInput(desc, {zero, 0});
TF_AddInput(desc, {input, 0});
TF_SetAttrInt(desc, "num_split", 3);
TF_SetAttrType(desc, "T", TF_INT32);
// Set device to CPU since there is no version of split for int32 on GPU
// TODO(iga): Convert all these helpers and tests to use floats because
// they are usually available on GPUs. After doing this, remove TF_SetDevice
// call in c_api_function_test.cc
TF_SetDevice(desc, "/cpu:0");
*op = TF_FinishOperation(desc, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s,
const char* name) {
TF_Operation* op;
Split3Helper(input, graph, s, name, &op);
return op;
}
bool IsPlaceholder(const tensorflow::NodeDef& node_def) {
if (node_def.op() != "Placeholder" || node_def.name() != "feed") {
return false;
@ -196,6 +299,18 @@ bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def) {
return ret;
}
bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def) {
TF_Status* s = TF_NewStatus();
TF_Buffer* buffer = TF_NewBuffer();
TF_FunctionToFunctionDef(func, buffer, s);
bool ret = TF_GetCode(s) == TF_OK;
EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
if (ret) ret = func_def->ParseFromArray(buffer->data, buffer->length);
TF_DeleteBuffer(buffer);
TF_DeleteStatus(s);
return ret;
}
bool GetAttrValue(TF_Operation* oper, const char* attr_name,
tensorflow::AttrValue* attr_value, TF_Status* s) {
TF_Buffer* buffer = TF_NewBuffer();

View File

@ -33,6 +33,13 @@ typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)>
// Create a tensor with values of type TF_INT8 provided by `values`.
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values);
// Create a tensor with values of type TF_INT32 provided by `values`.
TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims,
const int32_t* values);
// Create a 1-dimensional tensor with values from `values`.
TF_Tensor* Int32Tensor(const std::vector<int32_t>& values);
TF_Tensor* Int32Tensor(int32_t v);
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
@ -47,6 +54,13 @@ TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name = "add");
TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name = "add");
TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r,
TF_Graph* graph, TF_Operation* ctrl_op,
TF_Status* s, const char* name = "add");
TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
const char* name = "add");
@ -54,6 +68,10 @@ TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s);
TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s);
// Split `input` along the first dimension into 3 tensors.
TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s,
const char* name = "split3");
bool IsPlaceholder(const tensorflow::NodeDef& node_def);
bool IsScalarConst(const tensorflow::NodeDef& node_def, int v);
@ -66,6 +84,8 @@ bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def);
bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def);
bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def);
bool GetAttrValue(TF_Operation* oper, const char* attr_name,
tensorflow::AttrValue* attr_value, TF_Status* s);

View File

@ -151,10 +151,11 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
return TF_SessionListDevices(ctx->session, status);
}
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t) {
return new TFE_TensorHandle(
tensorflow::TensorCApi::MakeTensor(t->dtype, t->shape, t->buffer),
nullptr);
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
tensorflow::Tensor tensor;
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
if (!status->status.ok()) return nullptr;
return new TFE_TensorHandle(tensor, nullptr);
}
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { delete h; }

View File

@ -20,6 +20,25 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
// Macro to control visibility of exported symbols in the shared library (.so,
// .dylib, .dll).
// This duplicates the TF_EXPORT macro definition in
// tensorflow/core/platform/macros.h in order to keep this .h file independent
// of any other includes.
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
#if defined(COMPILER_MSVC)
#ifdef TF_COMPILE_LIBRARY
#define TF_CAPI_EXPORT __declspec(dllexport)
#else
#define TF_CAPI_EXPORT __declspec(dllimport)
#endif // TF_COMPILE_LIBRARY
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif // COMPILER_MSVC
#endif // SWIG
#ifdef __cplusplus
extern "C" {
#endif
@ -30,11 +49,11 @@ extern "C" {
// TODO(ashankar): Merge with TF_Session?
typedef struct TFE_Context TFE_Context;
extern TFE_Context* TFE_NewContext(const TF_SessionOptions* opts,
TF_Status* status);
extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status);
extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
TF_Status* status);
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(const TF_SessionOptions* opts,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status);
TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
TF_Status* status);
// A handle to a tensor on a device.
//
@ -43,14 +62,15 @@ extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
// placed in memory of different devices or remote address spaces.
typedef struct TFE_TensorHandle TFE_TensorHandle;
extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t);
extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h);
extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index);
extern const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h);
extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
TF_Status* status);
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h);
TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index);
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h);
TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
TF_Status* status);
// Create a new TFE_TensorHandle with the same contents as 'h' but placed
// in the memory of the device name 'device_name'.
@ -58,10 +78,10 @@ extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
// that shares the underlying buffer. Otherwise, it currently requires at least
// one of the source or destination devices to be CPU (i.e., for the source or
// destination tensor to be placed in host memory).
extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TFE_Context* ctx,
const char* device_name,
TF_Status* status);
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TFE_Context* ctx,
const char* device_name,
TF_Status* status);
// Description of the TensorFlow op to execute.
//
@ -76,49 +96,49 @@ extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
// the additional sanity checks there seem unnecessary;
typedef struct TFE_Op TFE_Op;
extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status);
extern void TFE_DeleteOp(TFE_Op* op);
TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op);
// TODO(ashankar): TFE_OpSetDevice and TFE_Execute should not have a TFE_Context
// parameter. Instead, the TFE_Context should be captured when creating the
// TFE_Op.
extern void TFE_OpSetDevice(TFE_Op* op, TFE_Context* ctx,
const char* device_name, TF_Status* status);
TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, TFE_Context* ctx,
const char* device_name, TF_Status* status);
extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status);
TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status);
extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
unsigned char* is_list, TF_Status* status);
TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
unsigned char* is_list, TF_Status* status);
extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name,
const char* value);
extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value);
extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value);
extern void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name,
unsigned char value);
extern void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name,
TF_DataType value);
TF_CAPI_EXPORT extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name,
const char* value);
TF_CAPI_EXPORT extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value);
TF_CAPI_EXPORT extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value);
TF_CAPI_EXPORT extern void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name,
unsigned char value);
TF_CAPI_EXPORT extern void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name,
TF_DataType value);
// If the number of dimensions is unknown, `num_dims` must be set to
// -1 and `dims` can be null. If a dimension is unknown, the
// corresponding entry in the `dims` array must be -1.
extern void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name,
const int64_t* dims, const int num_dims,
TF_Status* out_status);
TF_CAPI_EXPORT extern void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name,
const int64_t* dims, const int num_dims,
TF_Status* out_status);
extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
const char** value, int num_values);
extern void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
const int64_t* values, int num_values);
extern void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
const float* values, int num_values);
extern void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
const unsigned char* values, int num_values);
extern void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
const TF_DataType* values, int num_values);
extern void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
const int64_t** dims, const int* num_dims,
int num_values, TF_Status* out_status);
TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
const char** value, int num_values);
TF_CAPI_EXPORT extern void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
const int64_t* values, int num_values);
TF_CAPI_EXPORT extern void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
const float* values, int num_values);
TF_CAPI_EXPORT extern void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
const unsigned char* values, int num_values);
TF_CAPI_EXPORT extern void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
const TF_DataType* values, int num_values);
TF_CAPI_EXPORT extern void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
const int64_t** dims, const int* num_dims,
int num_values, TF_Status* out_status);
// Execute the operation defined by 'op' and return handles to computed
// tensors in 'retvals'.
@ -128,14 +148,14 @@ extern void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
//
// On return, 'num_retvals' will be set to the actual number of outputs
// returned by the operation.
extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals,
int* num_retvals, TF_Status* status);
TF_CAPI_EXPORT extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals,
int* num_retvals, TF_Status* status);
// Add a function (serialized FunctionDef protocol buffer) to ctx so
// that it can be invoked using TFE_Execute.
extern void TFE_ContextAddFunctionDef(TFE_Context* ctx,
const char* serialized_function_def,
size_t size, TF_Status* status);
TF_CAPI_EXPORT extern void TFE_ContextAddFunctionDef(TFE_Context* ctx,
const char* serialized_function_def,
size_t size, TF_Status* status);
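For reference, a minimal eager-execution sketch using the declarations above (not part of this commit; error checking is abbreviated, tensorflow/c/eager/c_api.h and <string.h> are assumed to be included, and the MatMul attribute names are taken from the standard op definition):

void ExampleEagerMatMul() {
  TF_Status* status = TF_NewStatus();
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TF_DeleteSessionOptions(opts);

  // Build a 2x2 float tensor and wrap it in a handle (note the new
  // TF_Status parameter of TFE_NewTensorHandle introduced above).
  float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t dims[] = {2, 2};
  TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, sizeof(data));
  memcpy(TF_TensorData(t), data, sizeof(data));
  TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
  TF_DeleteTensor(t);

  // Describe and run MatMul(h, h).
  TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
  TFE_OpAddInput(matmul, h, status);
  TFE_OpAddInput(matmul, h, status);
  TFE_OpSetAttrType(matmul, "T", TF_FLOAT);
  TFE_OpSetAttrBool(matmul, "transpose_a", 0);
  TFE_OpSetAttrBool(matmul, "transpose_b", 0);
  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(matmul, retvals, &num_retvals, status);
  TFE_DeleteOp(matmul);

  // Copy the result back into a TF_Tensor to inspect it.
  TF_Tensor* result = TFE_TensorHandleResolve(retvals[0], status);
  TF_DeleteTensor(result);
  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteTensorHandle(h);
  TFE_DeleteContext(ctx, status);
  TF_DeleteStatus(status);
}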
#ifdef __cplusplus
} /* end extern "C" */

View File

@ -34,7 +34,9 @@ TFE_TensorHandle* TestMatrixTensorHandle() {
TF_Tensor* t = TF_AllocateTensor(
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandle(t);
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
return th;
}
@ -383,7 +385,8 @@ TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
value_handle(TFE_NewTensorHandle(t.get()), TFE_DeleteTensorHandle);
value_handle(TFE_NewTensorHandle(t.get(), status), TFE_DeleteTensorHandle);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op, value_handle.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;

View File

@ -2,6 +2,7 @@ VERS_1.0 {
# Export symbols in c_api.h.
global:
*TF_*;
*TFE_*;
# Hide everything else.
local:

View File

@ -77,6 +77,10 @@ class SymbolicGradientBuilder {
Status CallGradFunction(const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs);
// Returns a list indicating whether each node in the graph is reachable
// from outputs_. Keyed by node id.
std::vector<bool> GetReachableNodes();
const Scope& scope_;
const ops::GradOpRegistry* registry_;
@ -143,11 +147,36 @@ Status SymbolicGradientBuilder::BackpropAlongEdge(const Output& dst_grad,
return Status::OK();
}
std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
std::vector<bool> reachable_nodes(scope_.graph()->num_node_ids(), false);
std::deque<Node*> queue;
for (const Output& out : outputs_) {
if (!reachable_nodes[out.node()->id()]) {
queue.push_back(out.node());
reachable_nodes[out.node()->id()] = true;
}
}
while (!queue.empty()) {
Node* n = queue.front();
queue.pop_front();
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) continue;
queue.push_back(e->src());
reachable_nodes[e->src()->id()] = true;
}
}
return reachable_nodes;
}
Status SymbolicGradientBuilder::Initialize() {
if (outputs_.size() != grad_inputs_.size()) {
return errors::InvalidArgument(
"Must specify a gradient input for each output.");
}
std::vector<bool> reachable_nodes = GetReachableNodes();
// TODO(theflofly) Check that inputs_ are reachable from
// outputs_ using reachable_nodes
grad_outputs_->clear();
grad_outputs_->resize(inputs_.size());
// Populate `output_nodes_` from node ids in `outputs_`.
@ -188,12 +217,15 @@ Status SymbolicGradientBuilder::Initialize() {
if (output_nodes_.find(n->id()) == output_nodes_.end()) {
// Internal node: continue BFS along connected outputs.
for (const Edge* e : n->out_edges()) {
if (e->IsControlEdge()) continue;
++num_expected_backprops;
// If a node is not reachable from outputs_,
// we don't expect it to receive a backpropagated gradient.
// It will not be counted in num_expected_backprops.
if (e->IsControlEdge() || !reachable_nodes[e->dst()->id()]) continue;
if (visited.find(e->dst()) == visited.end()) {
queue.push_back(e->dst());
visited.insert(e->dst());
}
++num_expected_backprops;
}
} else {
// Output node: stop BFS and update `num_expected_backprops` for

View File

@ -364,6 +364,73 @@ TEST_F(GradientsTest, MultipleNodeOutputGrads) {
test::AsTensor<int>({60, 61, 62, 63, 66, 66, 66, 67}, {4, 2}));
}
TEST_F(GradientsTest, UnreachableEdgeGradOneOutput) {
auto x = Variable(scope_test_, {2, 3}, DT_DOUBLE);
auto x_const = Const(scope_test_, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
auto x_assign = Assign(scope_test_, x, x_const);
auto y = Variable(scope_test_, {3, 1}, DT_DOUBLE);
auto y_const = Const(scope_test_, {{1.0}, {2.0}, {3.0}});
auto y_assign = Assign(scope_test_, y, y_const);
auto m1 = MatMul(scope_test_, x, y);
auto z = Variable(scope_test_, {1, 3}, DT_DOUBLE);
auto z_const = Const(scope_test_, {{9.0, 10.0, 11.0}});
auto z_assign = Assign(scope_test_, z, z_const);
auto m2 = MatMul(scope_test_, y, z);
auto dm1 = Const(scope_test_, {{0.5}, {0.5}});
std::vector<Output> grad_outputs;
TF_ASSERT_OK(
AddSymbolicGradients(scope_test_, {m1}, {y}, {dm1}, &grad_outputs));
std::vector<Tensor> outputs;
test::GetTensors(scope_test_, {x_assign, y_assign, z_assign},
{grad_outputs[0]}, &outputs);
// dz/dy = xT * dm1
test::ExpectTensorNear<double>(
outputs[0], test::AsTensor<double>({2.5, 3.5, 4.5}, {3, 1}), 1e-5);
}
TEST_F(GradientsTest, UnreachableEdgeGradTwoOutputs) {
auto x = Variable(scope_test_, {2, 3}, DT_DOUBLE);
auto x_const = Const(scope_test_, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
auto x_assign = Assign(scope_test_, x, x_const);
auto y = Variable(scope_test_, {3, 1}, DT_DOUBLE);
auto y_const = Const(scope_test_, {{1.0}, {2.0}, {3.0}});
auto y_assign = Assign(scope_test_, y, y_const);
auto m1 = MatMul(scope_test_, x, y);
auto z = Variable(scope_test_, {1, 3}, DT_DOUBLE);
auto z_const = Const(scope_test_, {{9.0, 10.0, 11.0}});
auto z_assign = Assign(scope_test_, z, z_const);
auto m2 = MatMul(scope_test_, y, z);
auto dm1 = Const(scope_test_, {{0.5}, {0.5}});
auto dm2 =
Const(scope_test_, {{0.5, 0.5, 0.5}, {0.6, 0.7, 0.8}, {0.6, 0.7, 0.9}});
std::vector<Output> grad_outputs;
TF_ASSERT_OK(AddSymbolicGradients(scope_test_, {m1, m2}, {y}, {dm1, dm2},
&grad_outputs));
std::vector<Tensor> outputs;
test::GetTensors(scope_test_, {x_assign, y_assign, z_assign},
{grad_outputs[0]}, &outputs);
// the gradients from m1 and m2 will be summed to compute the gradient
// w.r.t y
// dz/dy = xT * dm1 + dm2 * zT
test::ExpectTensorNear<double>(
outputs[0], test::AsTensor<double>({17.5, 24.7, 26.8}, {3, 1}), 1e-5);
}
// StopGradientSingleOutputMultiEdgeTest tests combinations of valid and
// 'NoGradient' (induced by StopGradient op) returned along multiple edges from
// a single nodes output.

View File

@ -36,5 +36,19 @@ void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
*out = outputs[0];
}
void GetTensors(const Scope& scope, const std::vector<Output>& assign_vars,
OutputList tensors, std::vector<Tensor>* out) {
ClientSession session(scope);
TF_CHECK_OK(session.Run(assign_vars, nullptr));
TF_CHECK_OK(session.Run(tensors, out));
}
void GetTensor(const Scope& scope, const std::vector<Output>& assign_vars,
Output tensor, Tensor* out) {
std::vector<Tensor> outputs;
GetTensors(scope, assign_vars, {std::move(tensor)}, &outputs);
*out = outputs[0];
}
} // end namespace test
} // end namespace tensorflow

View File

@ -26,9 +26,21 @@ namespace test {
void GetTensors(const Scope& scope, OutputList tensors,
std::vector<Tensor>* out);
// Computes the outputs listed in 'tensors', returns the tensors in 'out'.
// assign_vars are extra outputs that should be run
// e.g. to assign values to variables.
void GetTensors(const Scope& scope, const std::vector<Output>& assign_vars,
OutputList tensors, std::vector<Tensor>* out);
/// Computes the output 'tensor', returning the resulting tensor in 'out'.
void GetTensor(const Scope& scope, Output tensor, Tensor* out);
// Computes the output 'tensor', returning the resulting tensor in 'out'.
// assign_vars are extra outputs that should be run
// e.g. to assign values to variables.
void GetTensor(const Scope& scope, const std::vector<Output>& assign_vars,
Output tensor, Tensor* out);
} // namespace test
} // namespace tensorflow

View File

@ -687,6 +687,72 @@ Status MeanGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Mean", MeanGrad);
Status MinOrMaxGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
// The partial derivative for any input along a "reduced" dimension
// is 1 when it is the min (or max) and 0 everywhere else. So the
// gradient calculation is identical for both operators.
//
// There's a special case for propagating gradients when there are
// multiple minima (or maxima) - we choose to divide the gradient
// equally among all matching inputs.
//
// Please note this comment
// https://github.com/tensorflow/tensorflow/issues/4886#issuecomment-256836063
// for details.
// Running example:
// input: [[5, 5, 5],
// [1, 2, -3]]
// reduction_indices: [1]
auto input = op.input(0);
auto reduction_indices = op.input(1);
// [2, 3]
auto input_shape = Shape(scope, input);
// [2, 1]
auto output_shape_kept_dims =
ReducedShapeHelper(scope, input_shape, reduction_indices);
// for op=min (say)
// output = [5, -3]
// y = [[5],
// [-3]]
auto y = Reshape(scope, op.output(0), output_shape_kept_dims);
// reshape([g1, g2], [2, 1]) = [[g1],
// [g2]]
auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);
// indicators = equal(y, input)
// = equal([[5], [[5, 5, 5],
// [-3]], [1, 2, -3]])
// = [[1, 1, 1],
// [0, 0, 1]]
auto indicators = Cast(scope, Equal(scope, y, input), grad_inputs[0].type());
// [[3],
// [1]]
auto num_selected = Reshape(scope, Sum(scope, indicators, reduction_indices),
output_shape_kept_dims);
// [[1/3, 1/3, 1/3],
// [0, 0, 1]]
auto scale = Div(scope, indicators, num_selected);
// [[g1/3, g1/3, g1/3],
// [0, 0, g2]]
grad_outputs->push_back(Mul(scope, scale, grad));
// Stop propagation along reduction_indices
grad_outputs->push_back(NoGradient());
return scope.status();
}
REGISTER_GRADIENT_OP("Min", MinOrMaxGrad);
REGISTER_GRADIENT_OP("Max", MinOrMaxGrad);
// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,

View File

@ -955,6 +955,55 @@ TEST_F(NaryGradTest, Mean) {
RunTest({x}, {x_shape}, {y}, {y_shape});
}
TEST_F(NaryGradTest, Min) {
TensorShape x_shape({2, 3});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
auto y = Min(scope_, x, {-1});
// y's shape is the result of reducing x along axes -1 (= 1)
TensorShape y_shape({2});
Tensor x_init_value =
test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
RunTest(x, x_init_value, y, y_shape);
}
TEST_F(NaryGradTest, Max) {
TensorShape x_shape({2, 3});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
auto y = Max(scope_, x, {-1});
// y's shape is the result of reducing x along axes -1 (= 1)
TensorShape y_shape({2});
Tensor x_init_value =
test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
RunTest(x, x_init_value, y, y_shape);
}
TEST_F(NaryGradTest, MinMulti) {
// Test gradient when there are multiple minima.
// Note that we cannot directly use a test Tensor with multiple
// minima, as the numeric estimator will calculate incorrect
// gradients when perturbing each entry in the Tensor (which then
// changes how many minima exist.)
// Instead, we use a single input that broadcast-multiplies a larger
// tensor with equal values, and apply reduce_min to the multiplied
// result.
TensorShape x_shape({1});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
auto y = Min(scope_, all_same, {0});
// y is a [3] shaped tensor reduced along dimension 0, so it is [1] shaped
TensorShape y_shape({1});
RunTest({x}, {x_shape}, {y}, {y_shape});
}
TEST_F(NaryGradTest, MaxMulti) {
TensorShape x_shape({1});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
auto y = Max(scope_, all_same, {0});
TensorShape y_shape({1});
RunTest({x}, {x_shape}, {y}, {y_shape});
}
TEST_F(NaryGradTest, AddN) {
TensorShape shape({3, 2, 5});
std::vector<Output> xs;

View File

@ -52,6 +52,12 @@ class BinaryOpsTest(XLATestCase):
def testFloatOps(self):
for dtype in self.float_types:
self._testBinary(
lambda x, y: math_ops.approximate_equal(x, y, tolerance=0.0001),
np.array([[[[-1, 2.00009999], [-3, 4.01]]]], dtype=dtype),
np.array([[[[-1.001, 2], [-3.00009, 4]]]], dtype=dtype),
expected=np.array([[[[False, True], [True, False]]]], dtype=dtype))
self._testBinary(
gen_math_ops._real_div,
np.array([3, 3, -1.5, -8, 44], dtype=dtype),
@ -82,6 +88,12 @@ class BinaryOpsTest(XLATestCase):
dtype(4),
expected=np.array([[16], [81]], dtype=dtype))
self._testBinary(
gen_math_ops._reciprocal_grad,
np.array([4, -3, -2, 1], dtype=dtype),
np.array([5, -6, 7, -8], dtype=dtype),
expected=np.array([-80, 54, -28, 8], dtype=dtype))
self._testBinary(
gen_math_ops._sigmoid_grad,
np.array([4, 3, 2, 1], dtype=dtype),
@ -107,6 +119,13 @@ class BinaryOpsTest(XLATestCase):
expected=np.array(
[3.97322869, 2.99258232, 1.99817801, 0.99966466], dtype=dtype))
self._testBinary(
gen_nn_ops._softsign_grad,
np.array([4, 3, 2, 1], dtype=dtype),
np.array([5, 6, 7, 8], dtype=dtype),
expected=np.array(
[0.11111111, 0.06122449, 0.03125, 0.01234568], dtype=dtype))
self._testBinary(
gen_math_ops._tanh_grad,
np.array([4, 3, 2, 1], dtype=dtype),

View File

@ -888,6 +888,16 @@ TEST_F(OpTest, Any) {
});
}
TEST_F(OpTest, ApproximateEqual) {
Repeatedly([this]() {
auto dims = RandomDims();
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ApproximateEqual")
.RandomInput(DT_FLOAT, dims)
.RandomInput(DT_FLOAT, dims)
.Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, Asinh) {
Repeatedly([this]() {
return ExpectTfAndXlaOutputsAreClose(
@ -1662,11 +1672,9 @@ TEST_F(OpTest, GreaterEqual) {
TEST_F(OpTest, L2Loss) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
// TODO(b/31644876): scalars currently crash.
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("L2Loss")
.RandomInput(type, RandomDims(1))
.Attr("T", type));
DataType type = DT_FLOAT;
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("L2Loss").RandomInput(type).Attr("T", type));
});
}
@ -2165,6 +2173,15 @@ TEST_F(OpTest, Reciprocal) {
});
}
TEST_F(OpTest, ReciprocalGrad) {
Repeatedly([this]() {
std::vector<int64> dims = RandomDims();
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReciprocalGrad")
.RandomInput(DT_FLOAT, dims)
.RandomInput(DT_FLOAT, dims)
.Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, Relu) {
Repeatedly([this]() {
return ExpectTfAndXlaOutputsAreClose(
@ -2250,6 +2267,13 @@ TEST_F(OpTest, ReverseV2) {
});
}
TEST_F(OpTest, Rint) {
Repeatedly([this]() {
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Rint").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, Round) {
Repeatedly([this]() {
return ExpectTfAndXlaOutputsAreClose(
@ -2402,6 +2426,23 @@ TEST_F(OpTest, SoftplusGrad) {
});
}
TEST_F(OpTest, Softsign) {
Repeatedly([this]() {
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Softsign").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, SoftsignGrad) {
Repeatedly([this]() {
std::vector<int64> dims = RandomDims();
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SoftsignGrad")
.RandomInput(DT_FLOAT, dims)
.RandomInput(DT_FLOAT, dims)
.Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, SpaceToBatch) {
Repeatedly([this]() {
std::vector<int64> block_dims = RandomDims(4, 4, 0, 5);

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@ -161,12 +163,17 @@ class UnaryOpsTest(XLATestCase):
np.array([[-1.7, 1.2]], dtype=dtype),
expected=np.array([[-2, 1]], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.is_finite,
np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
dtype=dtype),
expected=np.array([[0, 1, 1, 1, 1, 1, 1, 0, 0]], dtype=np.bool))
# Tests for tf.nn ops.
self._assertOpOutputMatchesExpected(
nn_ops.l2_loss, np.array([[[]]], dtype=dtype), expected=dtype(0))
# TODO(b/31644876): enable this test case when fixed.
# self._assertOpOutputMatchesExpected(tf.nn.l2_loss, dtype(4), dtype(10))
self._assertOpOutputMatchesExpected(nn_ops.l2_loss, dtype(4), dtype(8))
self._assertOpOutputMatchesExpected(
nn_ops.l2_loss, np.array([[-2, 4]], dtype=dtype), expected=dtype(10))
@ -198,6 +205,12 @@ class UnaryOpsTest(XLATestCase):
np.array([[1e-14, 1e-15, 0.6]], dtype=dtype),
expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)))
self._assertOpOutputMatchesExpected(
math_ops.rint,
np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5],
[0.5, 1.5, 2.5, 3.5]], dtype=dtype),
expected=np.array([[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]],
dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.round,
np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5],
@ -301,6 +314,12 @@ class UnaryOpsTest(XLATestCase):
np.array([[-2, 0, 8]], dtype=dtype),
expected=np.array([[0.126928, 0.6931472, 8.0003354]], dtype=dtype))
self._assertOpOutputMatchesExpected(
nn_ops.softsign,
np.array([[-2, -1, 0, 1, 2]], dtype=dtype),
expected=np.array([[-0.66666669, -0.5, 0, 0.5, 0.66666669]],
dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.is_finite,
np.array(
@ -335,6 +354,23 @@ class UnaryOpsTest(XLATestCase):
np.array([[4, 3], [2, 1]], dtype=dtype),
expected=np.array([[1, 1], [1, 1]], dtype=dtype))
# TODO(phawkins): these tests fail unless fastmath optimizations
# are disabled. Use more robust IsInf/IsNaN detection and enable these
# tests.
@unittest.skip("test case fails in fast-math mode")
def testIsInfAndIsNan(self):
for dtype in self.float_types:
self._assertOpOutputMatchesExpected(
math_ops.is_inf,
np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
dtype=dtype),
expected=np.array([[1, 0, 0, 0, 0, 0, 0, 1, 0]], dtype=np.bool))
self._assertOpOutputMatchesExpected(
math_ops.is_nan,
np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
dtype=dtype),
expected=np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.bool))
def testLogicalOps(self):
self._assertOpOutputMatchesExpected(
math_ops.logical_not,

View File

@ -31,7 +31,6 @@ tf_kernel_library(
"function_ops.cc",
"gather_op.cc",
"identity_op.cc",
"is_finite_op.cc",
"l2loss_op.cc",
"lrn_ops.cc",
"matmul_op.cc",

View File

@ -102,6 +102,7 @@ XLA_MAKE_BINARY(Mod, b->Rem(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Maximum, b->Max(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Minimum, b->Min(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(RealDiv, b->Div(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(ReciprocalGrad, b->Neg(b->Mul(rhs, b->Mul(lhs, lhs))));
XLA_MAKE_BINARY(
RsqrtGrad,
b->Mul(b->Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)),
@ -140,6 +141,11 @@ XLA_MAKE_BINARY(SoftplusGrad,
b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)),
XlaHelpers::One(b, input_type(1)))));
// softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2
XLA_MAKE_BINARY(SoftsignGrad,
b->Div(lhs, Square(b, b->Add(XlaHelpers::One(b, input_type(0)),
b->Abs(rhs)))));
XLA_MAKE_BINARY(TanhGrad, b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)),
b->Mul(lhs, lhs))));
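The two new binary kernels can be sanity-checked against the values used in binary_ops_test.py with a standalone sketch (plain C++, no XLA dependency): ReciprocalGrad computes -(dy * y^2) and SoftsignGrad computes gradients / (1 + |features|)^2.

#include <cmath>
#include <cstdio>

int main() {
  // ReciprocalGrad: lhs = y (the Reciprocal output), rhs = dy; result = -(dy * y * y).
  const float y[4] = {4, -3, -2, 1}, dy[4] = {5, -6, 7, -8};
  for (int i = 0; i < 4; ++i)
    std::printf("%g ", -(dy[i] * y[i] * y[i]));  // -80 54 -28 8
  std::printf("\n");
  // SoftsignGrad: lhs = gradients, rhs = features; result = g / (1 + |f|)^2.
  const float g[4] = {4, 3, 2, 1}, f[4] = {5, 6, 7, 8};
  for (int i = 0; i < 4; ++i) {
    const float d = 1.0f + std::fabs(f[i]);
    std::printf("%g ", g[i] / (d * d));          // 0.111111 0.0612245 0.03125 0.0123457
  }
  std::printf("\n");
}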
@ -147,5 +153,24 @@ XLA_MAKE_BINARY(Pow, b->Pow(lhs, rhs, extend_dimensions));
#undef XLA_MAKE_BINARY
class ApproximateEqualOp : public XlaOpKernel {
public:
explicit ApproximateEqualOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("tolerance", &tolerance_));
}
// Computes the element-wise test |x - y| < tolerance, returning a boolean tensor.
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationBuilder* b = ctx->builder();
auto result = b->Lt(b->Abs(b->Sub(ctx->Input(0), ctx->Input(1))),
XlaHelpers::FloatLiteral(b, input_type(0), tolerance_));
ctx->SetOutput(0, result);
}
private:
float tolerance_;
};
REGISTER_XLA_OP(Name("ApproximateEqual"), ApproximateEqualOp);
} // namespace
} // namespace tensorflow
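The kernel reduces to an elementwise |x - y| < tolerance comparison. A minimal sketch (plain C++, not part of this change) reproduces the [False, True, True, False] expected by the Python test above, with the same 0.0001 tolerance:

#include <cmath>
#include <cstdio>

int main() {
  const float x[4] = {-1.0f, 2.00009999f, -3.0f, 4.01f};
  const float y[4] = {-1.001f, 2.0f, -3.00009f, 4.0f};
  const float tolerance = 0.0001f;
  for (int i = 0; i < 4; ++i) {
    const bool approx_equal = std::fabs(x[i] - y[i]) < tolerance;
    std::printf("%s ", approx_equal ? "True" : "False");
  }
  std::printf("\n");  // False True True False
}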

View File

@ -1,43 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
namespace {
class IsFiniteOp : public XlaOpKernel {
public:
explicit IsFiniteOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationDataHandle input = ctx->Input(0);
ctx->SetOutput(0, ctx->builder()->IsFinite(input));
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(IsFiniteOp);
};
REGISTER_XLA_OP(Name("IsFinite"), IsFiniteOp);
} // anonymous namespace
} // namespace tensorflow

View File

@ -73,8 +73,12 @@ XLAJIT_MAKE_UNARY(Exp, b->Exp(x));
XLAJIT_MAKE_UNARY(Expm1, b->Sub(b->Exp(x), XlaHelpers::One(b, input_type(0))));
XLAJIT_MAKE_UNARY(Floor, b->Floor(x));
// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
XLAJIT_MAKE_UNARY(Sign, b->Sign(x));
XLAJIT_MAKE_UNARY(IsFinite, b->IsFinite(x));
XLAJIT_MAKE_UNARY(IsInf, b->Eq(b->Abs(x),
XlaHelpers::FloatLiteral(
b, input_type(0),
std::numeric_limits<double>::infinity())));
XLAJIT_MAKE_UNARY(IsNan, b->Ne(x, x));
// Return 1/x
XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x));
XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x));
@ -105,6 +109,12 @@ static xla::ComputationDataHandle Round(xla::ComputationBuilder* b,
b->Add(round_val, one), round_val);
}
XLAJIT_MAKE_UNARY(Rint, Round(b, input_type(0), x));
XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x));
XLAJIT_MAKE_UNARY(Rsqrt,
b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5)));
// Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2.
static xla::ComputationDataHandle Sigmoid(xla::ComputationBuilder* b,
DataType dtype,
@ -112,16 +122,19 @@ static xla::ComputationDataHandle Sigmoid(xla::ComputationBuilder* b,
auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5);
return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x))));
}
XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x));
XLAJIT_MAKE_UNARY(Rsqrt,
b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5)));
XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(b, input_type(0), x));
// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
XLAJIT_MAKE_UNARY(Sign, b->Sign(x));
XLAJIT_MAKE_UNARY(Sinh,
b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))),
XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
XLAJIT_MAKE_UNARY(Softplus,
b->Log(b->Add(b->Exp(x), XlaHelpers::One(b, input_type(0)))));
// softsign(x) = x / (abs(x) + 1)
XLAJIT_MAKE_UNARY(Softsign,
b->Div(x,
b->Add(b->Abs(x), XlaHelpers::One(b, input_type(0)))));
XLAJIT_MAKE_UNARY(Sqrt,
b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
XLAJIT_MAKE_UNARY(Square, b->Mul(x, x));
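The new unary lowerings rest on simple floating-point identities: |x| == inf for IsInf, x != x for IsNan, sigmoid(x) == 0.5 + 0.5 * tanh(x / 2), and softsign(x) == x / (|x| + 1) as in the comment above. A standalone check (plain C++, not part of this change); the x != x test is exactly what breaks under fast-math, which is why the corresponding Python test case is skipped:

#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  const float inf = std::numeric_limits<float>::infinity();
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const float xs[5] = {-inf, -2.0f, 0.5f, inf, nan};
  for (float x : xs) {
    const bool is_inf = std::fabs(x) == inf;    // IsInf lowering
    const bool is_nan = x != x;                 // IsNan lowering; folded away by fast-math
    const bool is_finite = !is_inf && !is_nan;  // agrees with IsFinite
    // Sigmoid as a rescaled tanh versus the direct definition.
    const float sig_tanh = 0.5f + 0.5f * std::tanh(0.5f * x);
    const float sig_exp = 1.0f / (1.0f + std::exp(-x));
    std::printf("x=%g inf=%d nan=%d finite=%d sigmoid=%g vs %g\n",
                x, is_inf, is_nan, is_finite, sig_tanh, sig_exp);
  }
}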

View File

@ -847,6 +847,7 @@ cc_test(
srcs = ["hlo_ordering_test.cc"],
deps = [
":hlo",
":hlo_dataflow_analysis",
":hlo_ordering",
":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",

View File

@ -241,7 +241,7 @@ Status Executor::Run() {
completion_queue_.pop_front();
break;
}
} while (1);
} while (true);
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
assignment_->GetUniqueTopLevelSlice(instruction));
void* result_buffer =

View File

@ -24,16 +24,14 @@ limitations under the License.
namespace xla {
Status DfsHloVisitor::HandleElementwiseUnary(HloInstruction* hlo,
HloOpcode opcode) {
Status DfsHloVisitor::HandleElementwiseUnary(HloInstruction* hlo) {
return Unimplemented("DfsHloVisitor::HandleElementwiseUnary: %s",
HloOpcodeString(opcode).c_str());
HloOpcodeString(hlo->opcode()).c_str());
}
Status DfsHloVisitor::HandleElementwiseBinary(HloInstruction* hlo,
HloOpcode opcode) {
Status DfsHloVisitor::HandleElementwiseBinary(HloInstruction* hlo) {
return Unimplemented("DfsHloVisitor::HandleElementwiseBinary: %s",
HloOpcodeString(opcode).c_str());
HloOpcodeString(hlo->opcode()).c_str());
}
DfsHloVisitor::VisitState DfsHloVisitor::GetVisitState(

View File

@ -63,37 +63,37 @@ class DfsHloVisitor {
// These routines are self-descriptive, see class comment for usage
// information.
virtual Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode);
virtual Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode);
virtual Status HandleElementwiseUnary(HloInstruction* hlo);
virtual Status HandleElementwiseBinary(HloInstruction* hlo);
virtual Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
HloInstruction* arg, HloInstruction* max) = 0;
virtual Status HandleSelect(HloInstruction* select, HloInstruction* pred,
HloInstruction* on_true,
HloInstruction* on_false) = 0;
virtual Status HandleMaximum(HloInstruction* maximum) {
return HandleElementwiseBinary(maximum, HloOpcode::kMaximum);
return HandleElementwiseBinary(maximum);
}
virtual Status HandleMinimum(HloInstruction* minimum) {
return HandleElementwiseBinary(minimum, HloOpcode::kMinimum);
return HandleElementwiseBinary(minimum);
}
virtual Status HandleConcatenate(
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) = 0;
virtual Status HandleConvert(HloInstruction* convert) {
return HandleElementwiseUnary(convert, HloOpcode::kConvert);
return HandleElementwiseUnary(convert);
}
virtual Status HandleCopy(HloInstruction* copy) {
return HandleElementwiseUnary(copy, HloOpcode::kCopy);
return HandleElementwiseUnary(copy);
}
virtual Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs,
HloInstruction* rhs) {
return HandleElementwiseBinary(multiply, HloOpcode::kMultiply);
return HandleElementwiseBinary(multiply);
}
virtual Status HandleDot(HloInstruction* dot, HloInstruction* lhs,
HloInstruction* rhs) = 0;
virtual Status HandlePower(HloInstruction* power, HloInstruction* lhs,
HloInstruction* rhs) {
return HandleElementwiseBinary(power, HloOpcode::kPower);
return HandleElementwiseBinary(power);
}
virtual Status HandleConvolution(HloInstruction* convolution,
HloInstruction* lhs, HloInstruction* rhs,
@ -101,73 +101,72 @@ class DfsHloVisitor {
virtual Status HandleCrossReplicaSum(HloInstruction* crs) = 0;
virtual Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
HloInstruction* lhs, HloInstruction* rhs) {
return HandleElementwiseBinary(compare, opcode);
return HandleElementwiseBinary(compare);
}
virtual Status HandleAdd(HloInstruction* add, HloInstruction* lhs,
HloInstruction* rhs) {
return HandleElementwiseBinary(add, HloOpcode::kAdd);
return HandleElementwiseBinary(add);
}
virtual Status HandleDivide(HloInstruction* divide, HloInstruction* lhs,
HloInstruction* rhs) {
return HandleElementwiseBinary(divide, HloOpcode::kDivide);
return HandleElementwiseBinary(divide);
}
virtual Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs,
HloInstruction* rhs) {
return HandleElementwiseBinary(remainder, HloOpcode::kRemainder);
return HandleElementwiseBinary(remainder);
}
virtual Status HandleSubtract(HloInstruction* subtract, HloInstruction* lhs,
HloInstruction* rhs) {
return HandleElementwiseBinary(subtract, HloOpcode::kSubtract);
return HandleElementwiseBinary(subtract);
}
virtual Status HandleAbs(HloInstruction* abs, HloInstruction* operand) {
return HandleElementwiseUnary(abs, HloOpcode::kAbs);
return HandleElementwiseUnary(abs);
}
virtual Status HandleSign(HloInstruction* sign, HloInstruction* operand) {
return HandleElementwiseUnary(sign, HloOpcode::kSign);
return HandleElementwiseUnary(sign);
}
virtual Status HandleNegate(HloInstruction* negate, HloInstruction* operand) {
return HandleElementwiseUnary(negate, HloOpcode::kNegate);
return HandleElementwiseUnary(negate);
}
virtual Status HandleExp(HloInstruction* exp, HloInstruction* operand) {
return HandleElementwiseUnary(exp, HloOpcode::kExp);
return HandleElementwiseUnary(exp);
}
virtual Status HandleFloor(HloInstruction* floor, HloInstruction* operand) {
return HandleElementwiseUnary(floor, HloOpcode::kFloor);
return HandleElementwiseUnary(floor);
}
virtual Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) {
return HandleElementwiseUnary(ceil, HloOpcode::kCeil);
return HandleElementwiseUnary(ceil);
}
virtual Status HandleLog(HloInstruction* log, HloInstruction* operand) {
return HandleElementwiseUnary(log, HloOpcode::kLog);
return HandleElementwiseUnary(log);
}
virtual Status HandleCos(HloInstruction* cos, HloInstruction* operand) {
return HandleElementwiseUnary(cos, HloOpcode::kCos);
return HandleElementwiseUnary(cos);
}
virtual Status HandleSin(HloInstruction* sin, HloInstruction* operand) {
return HandleElementwiseUnary(sin, HloOpcode::kSin);
return HandleElementwiseUnary(sin);
}
virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) {
return HandleElementwiseUnary(tanh, HloOpcode::kTanh);
return HandleElementwiseUnary(tanh);
}
virtual Status HandleIsFinite(HloInstruction* is_finite,
HloInstruction* operand) {
return HandleElementwiseUnary(is_finite, HloOpcode::kIsFinite);
return HandleElementwiseUnary(is_finite);
}
virtual Status HandleLogicalAnd(HloInstruction* logical_and,
HloInstruction* lhs, HloInstruction* rhs) {
return HandleElementwiseBinary(logical_and, HloOpcode::kLogicalAnd);
return HandleElementwiseBinary(logical_and);
}
virtual Status HandleLogicalNot(HloInstruction* logical_not,
HloInstruction* operand) {
return HandleElementwiseUnary(logical_not, HloOpcode::kLogicalNot);
return HandleElementwiseUnary(logical_not);
}
virtual Status HandleLogicalOr(HloInstruction* logical_or,
HloInstruction* lhs, HloInstruction* rhs) {
return HandleElementwiseBinary(logical_or, HloOpcode::kLogicalOr);
return HandleElementwiseBinary(logical_or);
}
virtual Status HandleReducePrecision(HloInstruction* reduce_precision) {
return HandleElementwiseUnary(reduce_precision,
HloOpcode::kReducePrecision);
return HandleElementwiseUnary(reduce_precision);
}
virtual Status HandleInfeed(HloInstruction* infeed) = 0;

View File

@ -41,12 +41,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor {
// Default action performed on HloInstruction.
virtual Status DefaultAction(HloInstruction* hlo_instruction) = 0;
Status HandleElementwiseUnary(HloInstruction* hlo,
HloOpcode opcode) override {
Status HandleElementwiseUnary(HloInstruction* hlo) override {
return DefaultAction(hlo);
}
Status HandleElementwiseBinary(HloInstruction* hlo,
HloOpcode opcode) override {
Status HandleElementwiseBinary(HloInstruction* hlo) override {
return DefaultAction(hlo);
}
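The refactor works because a handler can always recover the opcode from the instruction itself, so passing it separately was redundant. A toy sketch of the visitor-with-default pattern (hypothetical Instr/Opcode types, not the real XLA classes):

#include <cstdio>

enum class Opcode { kAbs, kAdd };
struct Instr { Opcode opcode; };  // stand-in for HloInstruction

class Visitor {
 public:
  virtual ~Visitor() = default;
  // No opcode parameter: handlers query instr->opcode when they need it.
  virtual void HandleElementwiseUnary(Instr* instr) {
    std::printf("unimplemented unary, opcode %d\n",
                static_cast<int>(instr->opcode));
  }
  virtual void HandleAbs(Instr* abs) { HandleElementwiseUnary(abs); }
};

class VisitorWithDefault : public Visitor {
 public:
  virtual void DefaultAction(Instr* instr) = 0;
  void HandleElementwiseUnary(Instr* instr) override { DefaultAction(instr); }
};

class CountingVisitor : public VisitorWithDefault {
 public:
  void DefaultAction(Instr*) override { ++count; }
  int count = 0;
};

int main() {
  Instr abs{Opcode::kAbs};
  CountingVisitor v;
  v.HandleAbs(&abs);             // falls through to DefaultAction
  std::printf("%d\n", v.count);  // 1
}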

View File

@ -709,7 +709,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
} else {
auto r = ir_builder_->CreateSub(q, p);
auto leading_zeros = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(1)},
llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(true)},
{param_ir_type}, ir_builder_);
auto in_block = ir_builder_->GetInsertBlock();

View File

@ -334,7 +334,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder_);
IrArray::Index input_index(index.size());
llvm::Value* in_bounds = ir_builder_->getInt1(1);
llvm::Value* in_bounds = ir_builder_->getInt1(true);
for (size_t i = 0; i < index.size(); ++i) {
llvm::Value* stridden_index = ir_builder_->CreateNSWMul(
index[i], ir_builder_->getInt64(window.dimensions(i).stride()));

View File

@ -389,7 +389,7 @@ StatusOr<string> CompileModuleToPtx(llvm::Module* module,
// Loop unrolling exposes more opportunities for SROA. Therefore, we run SROA
// again after the standard optimization passes [http://b/13329423].
// TODO(jingyue): SROA may further expose more optimization opportunities, such
// TODO(jingyue): SROA may further expose more optimization opportunities such
// as more precise alias analysis and more function inlining (SROA may change
// the inlining cost of a function). For now, running SROA already emits good
// enough code for the evaluated benchmarks. We may want to run more

View File

@ -37,6 +37,230 @@ namespace xla {
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
// Data structure used to construct the alias analysis. Thrown away after alias
// analysis is complete. This data structure keeps track of which sets of
// HloValues must be in the same HloBuffer. This is maintained as a map from a
// buffer identifier (BufferNumber) to a set of HloValues.
//
// Initially each value is its own buffer. In MergeAliasedBuffers, sets of
// values which must share the same buffer are merged together. The end result
// is a partitioning of all HloValues into sets where each set needs its own
// HloBuffer. By performing this analysis without constructing HloBuffers on the
// fly, we can construct a vector of contiguously numbered HloBuffers once the
// final partitioning is known.
class BufferValueMap {
public:
// A unique identifier for a set of colocated values which must share the same
// buffer. This is not necessarily the same as the HloBuffer::Id which will
// ultimately contain the values. The reason is that HloBuffer::Id's are
// contiguous, while BufferNumbers may not be. BufferNumbers may not be
// dense because buffers may be created and destroyed during the analysis
// construction process.
using BufferNumber = int64;
explicit BufferValueMap(const HloDataflowAnalysis& dataflow)
: dataflow_(dataflow) {
buffers_.reserve(dataflow_.values().size());
value_to_buffer_number_.reserve(dataflow_.values().size());
for (const HloValue* value : dataflow_.values()) {
BufferNumber buffer_number = next_buffer_number_++;
buffers_[buffer_number].insert(value);
value_to_buffer_number_[value] = buffer_number;
}
}
// Merge together sets of HloValues which must be in the same HloBuffer
// because of aliasing rules (eg, in-place kWhile instruction).
void MergeAliasedBuffers() {
for (const HloValue* value : dataflow_.values()) {
VLOG(3) << "Merging colocated values, value: " << value->ToShortString();
// Gather the set of buffers with aliasing rules (eg, kWhile) which this
// value must be contained in.
std::vector<BufferNumber> aliased_buffers = ComputeAliasedBuffers(*value);
BufferNumber current_buffer = value_to_buffer_number_.at(value);
if (aliased_buffers.empty()) {
// The buffer containing 'value' aliases no other buffers. If the buffer
// containing 'value' already only contains 'value', then no change is
// necessary. If the buffer containing 'value' does contain other
// values, then remove 'value' from the buffer and create a new buffer
// containing only 'value'
if (buffers_.at(current_buffer).size() == 1) {
CHECK_EQ(*buffers_.at(current_buffer).begin(), value);
} else {
MoveValueToNewBuffer(*value);
}
} else {
// If multiple buffers are aliased merge these buffers together into a
// single buffer (arbitrarily chosen as the first buffer in the vector).
if (aliased_buffers.size() > 1) {
for (int64 i = 1; i < aliased_buffers.size(); ++i) {
MergeBuffers(/*from=*/aliased_buffers[i],
/*to=*/aliased_buffers[0]);
}
}
BufferNumber new_buffer = aliased_buffers[0];
if (current_buffer != new_buffer) {
MoveValueToBuffer(*value, new_buffer);
}
}
}
}
// Compute and return a sorted vector of all BufferNumbers. Can be used to
// iterate through all buffers stably.
std::vector<BufferNumber> ComputeSortedBufferNumbers() const {
std::vector<BufferNumber> buffer_numbers;
for (const auto& pair : buffers_) {
buffer_numbers.push_back(pair.first);
}
std::sort(buffer_numbers.begin(), buffer_numbers.end());
return buffer_numbers;
}
// Return a set of all the values in the given buffer.
const tensorflow::gtl::FlatSet<const HloValue*>& GetValuesInBuffer(
BufferNumber buffer_number) const {
return buffers_.at(buffer_number);
}
private:
// Create a new buffer.
void NewBuffer(const HloValue& value) {
BufferNumber buffer_number = next_buffer_number_++;
buffers_[buffer_number].insert(&value);
value_to_buffer_number_[&value] = buffer_number;
}
// Move the given value into a new buffer containing only the value.
void MoveValueToNewBuffer(const HloValue& value) {
BufferNumber new_buffer_number = next_buffer_number_++;
buffers_[new_buffer_number];
MoveValueToBuffer(value, new_buffer_number);
}
// Move the given value into the given buffer.
void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) {
BufferNumber old_buffer_number = value_to_buffer_number_.at(&value);
buffers_.at(old_buffer_number).erase(&value);
if (buffers_.at(old_buffer_number).empty()) {
buffers_.erase(old_buffer_number);
}
buffers_.at(buffer_number).insert(&value);
value_to_buffer_number_.at(&value) = buffer_number;
}
// Merge the buffer 'from' into the buffer 'to'.
void MergeBuffers(BufferNumber from, BufferNumber to) {
auto& from_value_set = buffers_.at(from);
buffers_.at(to).insert(from_value_set.begin(), from_value_set.end());
// NOTE: using a union-find algorithm to hold the colocated values might be
// faster.
for (const HloValue* value : from_value_set) {
value_to_buffer_number_.at(value) = to;
}
buffers_.erase(from);
}
BufferNumber GetBufferForValue(const HloValue& value) {
return value_to_buffer_number_.at(&value);
}
// Compute and return a vector of buffers that the given value must be
// contained in due to HLO aliasing rules.
std::vector<BufferNumber> ComputeAliasedBuffers(const HloValue& value) {
// Value is init of a while (use is while).
std::vector<BufferNumber> aliased_buffers;
for (const HloUse& use : value.uses()) {
VLOG(1) << "use of value " << value.ToShortString() << ": " << use;
if (use.instruction->opcode() == HloOpcode::kWhile) {
// Determine the while value that this shares a buffer with.
const HloValue& while_value =
dataflow_.GetUniqueValueAt(use.instruction, use.operand_index);
aliased_buffers.push_back(GetBufferForValue(while_value));
VLOG(3) << " value is init value to a while; must share buffer with "
"while value "
<< while_value.ToShortString();
}
}
// Value is a parameter of a while body/condition.
if (value.defining_instruction()->opcode() == HloOpcode::kParameter) {
const HloComputation* computation =
value.defining_instruction()->parent();
const CallGraphNode& call_graph_node =
dataflow_.call_graph().GetNode(computation);
for (const CallSite& callsite : call_graph_node.caller_callsites()) {
if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
// Call graph must have been flattened.
CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
const HloValue& while_value = dataflow_.GetUniqueValueAt(
callsite.instruction(), value.defining_index());
VLOG(3) << " value is parameter value of the body or condition of a "
"while; must share buffer with while value "
<< while_value.ToShortString();
aliased_buffers.push_back(GetBufferForValue(while_value));
}
}
}
// Value is the root of a while body.
for (const HloPosition& position : value.positions()) {
const HloComputation* computation = position.instruction->parent();
const CallGraphNode& call_graph_node =
dataflow_.call_graph().GetNode(computation);
if (position.instruction == computation->root_instruction()) {
for (const CallSite& callsite : call_graph_node.caller_callsites()) {
if (callsite.instruction()->opcode() == HloOpcode::kWhile &&
callsite.instruction()->while_body() == computation) {
// Call graph must have been flattened.
CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
const HloValue& while_value = dataflow_.GetUniqueValueAt(
callsite.instruction(), position.index);
VLOG(3) << " value is root of the body computation of a while; must "
"share buffer with while value "
<< while_value.ToShortString();
aliased_buffers.push_back(GetBufferForValue(while_value));
}
}
}
}
// Value is the output of the while instruction itself.
if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
VLOG(3) << " value is output of a while instruction";
aliased_buffers.push_back(GetBufferForValue(value));
}
// Uniquify aliased buffers.
std::sort(aliased_buffers.begin(), aliased_buffers.end());
aliased_buffers.erase(
std::unique(aliased_buffers.begin(), aliased_buffers.end()),
aliased_buffers.end());
return aliased_buffers;
}
// Dataflow analysis used to construct the buffer map.
const HloDataflowAnalysis& dataflow_;
// A map containing the set of values contained in each buffer.
tensorflow::gtl::FlatMap<BufferNumber,
tensorflow::gtl::FlatSet<const HloValue*>>
buffers_;
// A map indicating which buffer each value is contained in.
tensorflow::gtl::FlatMap<const HloValue*, BufferNumber>
value_to_buffer_number_;
// The buffer number of the next buffer to be created.
BufferNumber next_buffer_number_ = 0;
};
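A minimal sketch of the same bookkeeping on toy ids (plain C++, not part of this change): every value starts in its own buffer, aliased buffers are merged, and the value-to-buffer map is kept in sync. Buffer numbers become sparse, which is why the final HloBuffers are renumbered afterwards.

#include <cstdio>
#include <map>
#include <set>

int main() {
  using ValueId = int;
  using BufferNumber = int;
  std::map<BufferNumber, std::set<ValueId>> buffers;
  std::map<ValueId, BufferNumber> value_to_buffer;
  // Initially, each value lives in its own buffer.
  for (ValueId v = 0; v < 4; ++v) {
    buffers[v].insert(v);
    value_to_buffer[v] = v;
  }
  // Merge buffer 'from' into buffer 'to' (e.g. values aliased by a kWhile).
  auto merge = [&](BufferNumber from, BufferNumber to) {
    for (ValueId v : buffers.at(from)) value_to_buffer[v] = to;
    buffers.at(to).insert(buffers.at(from).begin(), buffers.at(from).end());
    buffers.erase(from);
  };
  merge(2, 0);
  merge(3, 0);
  // Buffer numbers are now sparse {0, 1}; values 0, 2 and 3 share buffer 0.
  for (const auto& pair : buffers)
    std::printf("buffer %d holds %zu values\n", pair.first, pair.second.size());
}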
HloAliasAnalysis::HloAliasAnalysis(HloModule* module) : module_(module) {}
const HloBuffer& HloAliasAnalysis::GetUniqueBufferAt(
@ -99,10 +323,11 @@ bool HloAliasAnalysis::InstructionBuffersAreDistinct(
}
} else {
// It's possible for multiple values at this index to have the same
// HloBuffer. This does not result in non-distinctness. To account for this
// case, add all of the buffers at this index after checking whether each
// buffer exists at an earlier index. This is a corner case, however, as
// the number of values at an index is almost always one.
// HloBuffer. This does not result in non-distinctness. To account for
// this case, add all of the buffers at this index after checking
// whether each buffer exists at an earlier index. This is a corner
// case, however, as the number of values at an index is almost always
// one.
std::vector<const HloBuffer*> buffers_at_this_index;
for (const HloValue* value : value_set.values()) {
const HloBuffer* buffer = &GetBufferContainingValue(*value);
@ -118,15 +343,6 @@ bool HloAliasAnalysis::InstructionBuffersAreDistinct(
return true;
}
void HloAliasAnalysis::InitializeBufferSets() {
// Initially define a buffer for every HloValue in the module.
for (const HloValue& value : dataflow_analysis_->values()) {
HloBuffer& buffer = NewHloBuffer();
buffer.AddValue(value);
value_to_buffer_[&value] = &buffer;
}
}
Status HloAliasAnalysis::Verify() const {
// Verify consistency between the value_to_buffer_ map and
// HloBuffer::values().
@ -137,9 +353,8 @@ Status HloAliasAnalysis::Verify() const {
value) != buffer.values().end());
}
for (const auto& pair : buffers_) {
const HloBuffer::Id id = pair.first;
const HloBuffer& buffer = pair.second;
for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) {
const HloBuffer& buffer = buffers_[id];
TF_RET_CHECK(buffer.id() == id);
HloValue::Id last_value_id = -1;
@ -152,116 +367,9 @@ Status HloAliasAnalysis::Verify() const {
}
}
if (!buffers_vector_.empty()) {
// buffers_vector_ should be a vector of all HloBuffers sorted by id.
std::vector<const HloBuffer*> buffers;
for (const auto& id_buffer : buffers_) {
buffers.push_back(&id_buffer.second);
}
std::sort(buffers.begin(), buffers.end(), HloBuffer::IdLessThan);
TF_RET_CHECK(buffers_vector_ == buffers);
}
return Status::OK();
}
Status HloAliasAnalysis::VerifyAgainstReference() const {
TF_RETURN_IF_ERROR(Verify());
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> reference,
Run(module_));
TF_RETURN_IF_ERROR(reference->Verify());
VLOG(2) << "This analysis:";
XLA_VLOG_LINES(2, ToString());
VLOG(2) << "Reference:";
XLA_VLOG_LINES(2, reference->ToString());
// Create map from HloValue in the reference analysis to HloValue in this
// analysis and vice versa.
tensorflow::gtl::FlatMap<const HloValue*, const HloValue*> reference_to_this;
tensorflow::gtl::FlatMap<const HloValue*, const HloValue*> this_to_reference;
for (const HloValue& value : dataflow_analysis().values()) {
const HloValue& reference_value =
reference->dataflow_analysis().GetValueDefinedAt(
value.defining_instruction(), value.defining_index());
reference_to_this[&reference_value] = &value;
this_to_reference[&value] = &reference_value;
}
TF_RET_CHECK(buffers_.size() == reference->buffers_.size())
<< "Different number of buffers (" << buffers_.size()
<< " != " << reference->buffers_.size() << ")";
for (const auto& pair : reference->buffers_) {
const HloBuffer& reference_buffer = pair.second;
// Find the corresponding buffer in the reference by taking the first value
// in the buffer, finding the corresponding value in the reference, and then
// finding the buffer holding that value.
TF_RET_CHECK(!reference_buffer.values().empty());
const HloValue* reference_value = reference_buffer.values()[0];
const HloValue* value = reference_to_this.at(reference_value);
const HloBuffer& buffer = GetBufferContainingValue(*value);
// The buffer and the reference should have the exact same values. To make
// comparison easy, sort the values in the reference buffer identically to
// the values in the non-reference buffer (ie, by the corresponding id of
// the non-reference value).
std::vector<const HloValue*> reference_values = reference_buffer.values();
std::sort(reference_values.begin(), reference_values.end(),
[&reference_to_this](const HloValue* a, const HloValue* b) {
return reference_to_this.at(a)->id() <
reference_to_this.at(b)->id();
});
TF_RET_CHECK(reference_values.size() == buffer.values().size());
for (int i = 0; i < buffer.values().size(); ++i) {
TF_RET_CHECK(*reference_values[i] == *buffer.values()[i])
<< "Buffer:\n " << buffer
<< "\ndoes not have the same values as reference buffer:\n "
<< reference_buffer;
}
}
return Status::OK();
}
HloBuffer& HloAliasAnalysis::NewHloBuffer() {
HloBuffer::Id buffer_id = next_buffer_id_++;
auto emplaced = buffers_.emplace(std::piecewise_construct,
std::forward_as_tuple(buffer_id),
std::forward_as_tuple(buffer_id));
CHECK(emplaced.second);
buffers_vector_.clear();
return emplaced.first->second;
}
void HloAliasAnalysis::MoveValueToNewBuffer(const HloValue& value) {
HloBuffer& new_buffer = NewHloBuffer();
MoveValueToBuffer(value, &new_buffer);
VLOG(3) << "Moved value " << value.ToShortString() << " into new buffer "
<< new_buffer.id();
}
void HloAliasAnalysis::MoveValueToBuffer(const HloValue& value,
HloBuffer* buffer) {
HloBuffer& old_buffer = GetBufferContainingValue(value);
CHECK_NE(buffer, &old_buffer);
VLOG(3) << "Moved value " << value.ToShortString() << " from buffer "
<< old_buffer.id() << " into buffer " << buffer->id();
old_buffer.RemoveValue(value);
if (old_buffer.values().empty()) {
VLOG(3) << "Buffer " << old_buffer.id() << " now empty. Removing.";
buffers_.erase(old_buffer.id());
buffers_vector_.clear();
}
buffer->AddValue(value);
value_to_buffer_[&value] = buffer;
}
string HloAliasAnalysis::ToString() const {
string out = StrCat("HloAliasAnalysis, module ", module_->name(), "\n");
StrAppend(&out, " Buffers at each position:\n");
@ -290,10 +398,10 @@ string HloAliasAnalysis::ToString() const {
}
StrAppend(&out, " Buffers:\n");
for (const HloBuffer* buffer : buffers()) {
StrAppend(&out, " ", buffer->ToString(), "\n");
for (const HloBuffer& buffer : buffers()) {
StrAppend(&out, " ", buffer.ToString(), "\n");
StrAppend(&out, " positions:\n");
for (const HloPosition& position : buffer->ComputePositions()) {
for (const HloPosition& position : buffer.ComputePositions()) {
StrAppend(&out, " ", position.ToString(), "\n");
}
}
@ -301,217 +409,6 @@ string HloAliasAnalysis::ToString() const {
return out;
}
const std::vector<const HloBuffer*>& HloAliasAnalysis::buffers() const {
if (buffers_vector_.empty()) {
// Lazily construct vector of buffers.
buffers_vector_.reserve(buffers_.size());
for (auto& pair : buffers_) {
buffers_vector_.push_back(&pair.second);
}
std::sort(buffers_vector_.begin(), buffers_vector_.end(),
HloBuffer::IdLessThan);
} else {
CHECK_EQ(buffers_vector_.size(), buffers_.size());
for (const HloBuffer* buffer : buffers_vector_) {
DCHECK(ContainsKey(buffers_, buffer->id()));
DCHECK(&GetBuffer(buffer->id()) == buffer);
}
}
return buffers_vector_;
}
void HloAliasAnalysis::UpdateAtInstructions(
tensorflow::gtl::ArraySlice<const HloInstruction*> instructions) {
VLOG(4) << "Updated HLO module:";
XLA_VLOG_LINES(4, module_->ToString());
VLOG(3) << "Before update:";
XLA_VLOG_LINES(3, ToString());
std::vector<const HloValue*> values_to_update;
for (const HloInstruction* instruction : instructions) {
for (auto& pair : dataflow_analysis().GetInstructionValueSet(instruction)) {
for (const HloValue* value : pair.second.values()) {
values_to_update.push_back(value);
}
}
}
UpdateBuffersForValues(values_to_update);
VLOG(3) << "After update:";
XLA_VLOG_LINES(3, ToString());
}
void HloAliasAnalysis::UpdateAfterChangingOperand(HloInstruction* instruction,
HloInstruction* old_operand,
HloInstruction* new_operand) {
VLOG(1) << "UpdateAfterChangingOperand(" << instruction->name() << ", "
<< old_operand->name() << " => " << new_operand->name() << ")";
dataflow_analysis_->UpdateAfterChangingOperand(instruction, old_operand,
new_operand);
TF_DCHECK_OK(dataflow_analysis_->VerifyAgainstReference());
VLOG(4) << "Updated dataflow:";
XLA_VLOG_LINES(4, dataflow_analysis_->ToString());
UpdateAtInstructions({instruction, old_operand, new_operand});
}
void HloAliasAnalysis::UpdateAfterChangingRoot(HloInstruction* old_root,
HloInstruction* new_root) {
VLOG(1) << "UpdateAfterChangingRoot(" << old_root->name() << " => "
<< new_root->name() << ")";
dataflow_analysis_->UpdateAfterChangingRoot(old_root, new_root);
TF_DCHECK_OK(dataflow_analysis_->VerifyAgainstReference());
VLOG(4) << "Updated dataflow:";
XLA_VLOG_LINES(4, dataflow_analysis_->ToString());
UpdateAtInstructions({old_root, new_root});
}
std::vector<HloBuffer*> HloAliasAnalysis::ComputeAliasedBuffers(
const HloValue& value) {
std::vector<HloBuffer*> aliased_buffers;
// Value is init of a while (use is while).
for (const HloUse& use : value.uses()) {
VLOG(1) << "use of value " << value.ToShortString() << ": " << use;
if (use.instruction->opcode() == HloOpcode::kWhile) {
// Determine the while value that this shares a buffer with.
const HloValue& while_value = dataflow_analysis().GetUniqueValueAt(
use.instruction, use.operand_index);
aliased_buffers.push_back(&GetBufferContainingValue(while_value));
VLOG(3) << " value is init value to a while; must share buffer with "
"while value "
<< while_value.ToShortString();
}
}
// Value is a parameter of a while body/condition.
if (value.defining_instruction()->opcode() == HloOpcode::kParameter) {
const HloComputation* computation = value.defining_instruction()->parent();
const CallGraphNode& call_graph_node =
dataflow_analysis().call_graph().GetNode(computation);
for (const CallSite& callsite : call_graph_node.caller_callsites()) {
if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
// Call graph must have been flattened.
CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
const HloValue& while_value = dataflow_analysis().GetUniqueValueAt(
callsite.instruction(), value.defining_index());
VLOG(3) << " value is parameter value of the body or condition of a "
"while; must share buffer with while value "
<< while_value.ToShortString();
aliased_buffers.push_back(&GetBufferContainingValue(while_value));
}
}
}
// Value is the root of a while body.
for (const HloPosition& position : value.positions()) {
const HloComputation* computation = position.instruction->parent();
const CallGraphNode& call_graph_node =
dataflow_analysis().call_graph().GetNode(computation);
if (position.instruction == computation->root_instruction()) {
for (const CallSite& callsite : call_graph_node.caller_callsites()) {
if (callsite.instruction()->opcode() == HloOpcode::kWhile &&
callsite.instruction()->while_body() == computation) {
// Call graph must have been flattened.
CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
// If the value appears in the root of a while body, then
// necessarily the value is defined in the body as well.
CHECK_EQ(value.defining_instruction()->parent(), computation);
const HloValue& while_value = dataflow_analysis().GetUniqueValueAt(
callsite.instruction(), position.index);
VLOG(3) << " value is root of the body computation of a while; must "
"share buffer with while value "
<< while_value.ToShortString();
aliased_buffers.push_back(&GetBufferContainingValue(while_value));
}
}
}
}
// Value is in the while instruction itself.
if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
VLOG(3) << " value is output of a while instruction";
aliased_buffers.push_back(&GetUniqueBufferAt(value.defining_instruction(),
value.defining_index()));
}
// Uniquify aliased buffers.
std::sort(aliased_buffers.begin(), aliased_buffers.end(),
HloBuffer::IdLessThan);
aliased_buffers.erase(
std::unique(aliased_buffers.begin(), aliased_buffers.end()),
aliased_buffers.end());
return aliased_buffers;
}
// This method recomputes the HloBuffer for each of the given HloValues. The
// method does not necessarily update the HloBuffer of values which share a
// buffer with the given values, but are not explicitly passed in
// 'values'. Therefore, the caller must pass in all values which may require an
// update according to the kind of HLO graph change which occurred: operand
// changed (UpdateAfterChangingOperand), or root of computation changed
// (UpdateAfterChangingRoot).
void HloAliasAnalysis::UpdateBuffersForValues(
tensorflow::gtl::ArraySlice<const HloValue*> values) {
for (const HloValue* value : values) {
VLOG(3) << "Updating buffer for value: " << value->ToShortString();
// Gather the set of buffers with aliasing rules (eg, kWhile) which this
// value must be contained in.
std::vector<HloBuffer*> aliased_buffers = ComputeAliasedBuffers(*value);
HloBuffer& current_buffer = GetBufferContainingValue(*value);
if (aliased_buffers.empty()) {
// The buffer containing 'value' aliases no other buffers. If the buffer
// containing 'value' already only contains 'value', then no change is
// necessary. If the buffer containing 'value' does contain other values,
// then remove 'value' from the buffer and create a new buffer containing
// only 'value'
if (current_buffer.values().size() == 1) {
CHECK_EQ(current_buffer.values()[0], value);
} else {
MoveValueToNewBuffer(*value);
}
} else {
// If multiple buffers are aliased merge these buffers together into a
// single buffer (arbitrarily chosen as the first buffer in the vector).
if (aliased_buffers.size() > 1) {
for (int64 i = 1; i < aliased_buffers.size(); ++i) {
// Make copy of values vector because MoveValueToBuffer invalidates
// the values iterator. This could be done more efficiently by moving
// all values at once.
std::vector<const HloValue*> values = aliased_buffers[i]->values();
for (const HloValue* value : values) {
MoveValueToBuffer(*value, aliased_buffers[0]);
}
}
aliased_buffers.resize(1);
}
CHECK_EQ(aliased_buffers.size(), 1);
HloBuffer* new_buffer = aliased_buffers[0];
if (&current_buffer != new_buffer) {
MoveValueToBuffer(*value, new_buffer);
}
}
VLOG(4) << "Analysis after update:";
XLA_VLOG_LINES(4, ToString());
}
}
/* static */
StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
HloModule* module) {
@ -524,18 +421,28 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
HloDataflowAnalysis::Run(module, /*ssa_form=*/true,
/*bitcast_defines_value=*/false));
alias_analysis->InitializeBufferSets();
BufferValueMap buffer_map(alias_analysis->dataflow_analysis());
buffer_map.MergeAliasedBuffers();
VLOG(3) << "After initialization:";
XLA_VLOG_LINES(3, alias_analysis->ToString());
std::vector<const HloValue*> all_values;
for (const HloValue& value : alias_analysis->dataflow_analysis().values()) {
all_values.push_back(&value);
// Create a vector of HloBuffers, one for each set of values in the
// BufferValueMap. Create the HloBuffers as a vector of contiguously numbered
// buffers.
std::vector<BufferValueMap::BufferNumber> sorted_buffer_numbers =
buffer_map.ComputeSortedBufferNumbers();
alias_analysis->buffers_.reserve(sorted_buffer_numbers.size());
HloBuffer::Id next_id = 0;
for (BufferValueMap::BufferNumber buffer_number : sorted_buffer_numbers) {
auto& value_set = buffer_map.GetValuesInBuffer(buffer_number);
std::vector<const HloValue*> sorted_values(value_set.begin(),
value_set.end());
std::sort(sorted_values.begin(), sorted_values.end(), HloValue::IdLessThan);
alias_analysis->buffers_.emplace_back(next_id++, sorted_values);
for (const HloValue* value : sorted_values) {
alias_analysis->value_to_buffer_[value] =
&alias_analysis->buffers_.back();
}
}
alias_analysis->UpdateBuffersForValues(all_values);
TF_DCHECK_OK(alias_analysis->Verify());
XLA_VLOG_LINES(1, alias_analysis->ToString());
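The renumbering step above (sparse BufferNumbers into contiguously numbered buffers with id-sorted values) amounts to the following, sketched on toy data (plain C++, not part of this change):

#include <algorithm>
#include <cstdio>
#include <map>
#include <set>
#include <vector>

struct ToyBuffer {             // stand-in for HloBuffer
  int id;
  std::vector<int> value_ids;  // sorted by value id
};

int main() {
  // Sparse buffer numbers left over from the merging phase.
  std::map<int, std::set<int>> buffer_map = {{0, {3, 0, 2}}, {5, {1}}};
  std::vector<ToyBuffer> buffers;
  std::map<int, const ToyBuffer*> value_to_buffer;
  int next_id = 0;
  for (const auto& pair : buffer_map) {  // std::map iterates in sorted key order
    std::vector<int> values(pair.second.begin(), pair.second.end());
    std::sort(values.begin(), values.end());
    buffers.push_back(ToyBuffer{next_id++, values});
  }
  // Take pointers only after the vector stops growing (or reserve() up front,
  // as the real code does) so they are not invalidated by reallocation.
  for (const ToyBuffer& b : buffers)
    for (int v : b.value_ids) value_to_buffer[v] = &b;
  for (const ToyBuffer& b : buffers)
    std::printf("buffer %d: %zu values\n", b.id, b.value_ids.size());
}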

View File

@ -74,7 +74,7 @@ class HloAliasAnalysis {
// Return a vector of all HloBuffers stably sorted by HloBuffer::Id. This
// vector is lazily computed. Mutating operations on HloAliasAnalysis may
// invalidate the underlying vector requiring recomputation.
const std::vector<const HloBuffer*>& buffers() const;
const std::vector<HloBuffer>& buffers() const { return buffers_; }
// Returns the underlying dataflow analysis used by this alias analysis.
const HloDataflowAnalysis& dataflow_analysis() const {
@ -90,50 +90,13 @@ class HloAliasAnalysis {
// output of the given instruction.
bool InstructionBuffersAreDistinct(const HloInstruction* instruction) const;
// Updates the analysis after the operands of 'instruction' have changed or if
// 'instruction' has been made the root of a computation. Analysis update is
// not possible if instructions have been added or removed from the graph.
void UpdateAfterChangingOperand(HloInstruction* instruction,
HloInstruction* old_operand,
HloInstruction* new_operand);
void UpdateAfterChangingRoot(HloInstruction* old_root,
HloInstruction* new_root);
// Compare the dataflow analysis against a clean recomputation of the
// analysis. Returns an error status if there is a mismatch. Useful for
// verifying the correctness after updates to the analysis.
Status VerifyAgainstReference() const;
protected:
HloAliasAnalysis(HloModule* module);
// Create a new empty HloBuffer.
HloBuffer& NewHloBuffer();
// Move the given value to the given buffer. The value is removed from its
// current buffer.
void MoveValueToBuffer(const HloValue& value, HloBuffer* buffer);
// Move the given value to a newly created buffer. The value is removed from
// its current buffer.
void MoveValueToNewBuffer(const HloValue& value);
// Construct the initial set of buffer sets where an HloBuffer is created for
// each HloValue in the module.
void InitializeBufferSets();
// Compute and return the buffers with aliasing rules (eg, kWhile) which the
// given value must be contained in.
std::vector<HloBuffer*> ComputeAliasedBuffers(const HloValue& value);
// Recompute the HloBuffers for the given values.
void UpdateBuffersForValues(
tensorflow::gtl::ArraySlice<const HloValue*> values);
// Recompute the HloBuffers for all the values which appear in the output of
// the given instructions.
void UpdateAtInstructions(
tensorflow::gtl::ArraySlice<const HloInstruction*> instructions);
explicit HloAliasAnalysis(HloModule* module);
// Verify various invariants of the alias analysis.
Status Verify() const;
@ -143,20 +106,12 @@ class HloAliasAnalysis {
// The underlying dataflow analysis used by this alias analysis.
std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
// The map of all HloBuffers in the module. We pass around pointers to the
// mapped HloBuffers, so the underlying container must keep them valid despite
// mutations touching other map entries.
std::unordered_map<HloBuffer::Id, HloBuffer> buffers_;
// A map indicating which buffer a value is contained in.
tensorflow::gtl::FlatMap<const HloValue*, HloBuffer*> value_to_buffer_;
// A lazily constructed vector containing all HloBuffers sorted by
// HloBuffer::Id.
mutable std::vector<const HloBuffer*> buffers_vector_;
// The Id to use for the next HloBuffer.
int64 next_buffer_id_ = 0;
std::vector<HloBuffer> buffers_;
};
} // namespace xla

View File

@ -87,14 +87,13 @@ class HloAliasAnalysisTest : public HloTestBase {
// constructed.
bool AnyValuesInSameBufferInterfere() {
DependencyHloOrdering ordering(module_.get());
for (const HloBuffer* buffer : analysis_->buffers()) {
for (const HloValue* value_a : buffer->values()) {
for (const HloValue* value_b : buffer->values()) {
for (const HloBuffer& buffer : analysis_->buffers()) {
for (const HloValue* value_a : buffer.values()) {
for (const HloValue* value_b : buffer.values()) {
if (*value_a != *value_b &&
analysis_->dataflow_analysis().MayInterfere(*value_a, *value_b,
ordering)) {
ordering.MayInterfere(*value_a, *value_b)) {
VLOG(1) << *value_a << " interferes with " << *value_b
<< " in buffer: " << *buffer;
<< " in buffer: " << buffer;
return true;
}
}
@ -384,10 +383,7 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) {
EXPECT_THAT(
GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0})),
UnorderedElementsAre(GetValueDefinedAt(xla_while, /*index=*/{0}),
GetValueDefinedAt(body_param, /*index=*/{0}),
GetValueDefinedAt(cond_param, /*index=*/{0}),
GetValueDefinedAt(constant1)));
UnorderedElementsAre(GetValueDefinedAt(constant1)));
EXPECT_THAT(
GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})),
UnorderedElementsAre(GetValueDefinedAt(constant2),
@ -631,9 +627,9 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) {
// HloBuffers.
EXPECT_THAT(
analysis.buffers(),
UnorderedElementsAre(&analysis.GetUniqueBufferAt(constant1),
&analysis.GetUniqueBufferAt(tuple, /*index=*/{}),
&analysis.GetUniqueBufferAt(cond_constant)));
UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1),
analysis.GetUniqueBufferAt(tuple, /*index=*/{}),
analysis.GetUniqueBufferAt(cond_constant)));
// The tuple elements of the while and the three constant inputs should all be
// smooshed into the same buffer.
@ -820,127 +816,5 @@ TEST_F(HloAliasAnalysisTest, Bitcast) {
analysis.GetUniqueBufferAt(bitcast));
}
TEST_F(HloAliasAnalysisTest, UpdateAnalysisForWhile) {
// Test updating alias analysis after modifying a module with an array shaped
// while:
//
// body(F32[] %param):
// %negate = Negate(%param)
//
// condition(F32[] %param):
// return Constant(false)
//
// entry:
// %constant = Constant(1.0)
// %exp = Exp(%constant)
// return While(%exp, body, condition)
//
auto body_builder = HloComputation::Builder("body");
auto body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
scalar_shape_, HloOpcode::kNegate, body_param));
HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
// Condition computation trivially returns a constant "false".
auto cond_builder = HloComputation::Builder("condition");
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
cond_builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
auto xla_while = builder.AddInstruction(
HloInstruction::CreateWhile(scalar_shape_, condition, body, exp));
module_->AddEntryComputation(builder.Build());
HloAliasAnalysis& analysis = RunAnalysis();
// Sanity check some alias information.
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
analysis.GetUniqueBufferAt(body_param));
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
analysis.GetUniqueBufferAt(cond_param));
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
analysis.GetUniqueBufferAt(negate));
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
analysis.GetUniqueBufferAt(xla_while));
// Set the body root to the body_param. Previously it was Negate(body_param).
body->set_root_instruction(body_param);
// Prior to updating, verify that the analysis is no longer valid.
Status verify_status = analysis.VerifyAgainstReference();
EXPECT_FALSE(verify_status.ok());
analysis.UpdateAfterChangingRoot(/*old_root=*/negate,
/*new_root*/ body_param);
// Analysis should be valid after the update.
TF_ASSERT_OK(analysis.VerifyAgainstReference());
// The exponential should now pass through the body transparently.
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
analysis.GetUniqueBufferAt(body_param));
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
analysis.GetUniqueBufferAt(cond_param));
EXPECT_NE(analysis.GetUniqueBufferAt(exp),
analysis.GetUniqueBufferAt(negate));
EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
analysis.GetUniqueBufferAt(xla_while));
// Now replace the operand of the while with %constant (was %exp).
TF_ASSERT_OK(exp->ReplaceUseWith(xla_while, constant));
analysis.UpdateAfterChangingOperand(xla_while, /*old_operand=*/exp,
/*new_operand=*/constant);
// Analysis should be valid after the update.
TF_ASSERT_OK(analysis.VerifyAgainstReference());
EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
analysis.GetUniqueBufferAt(body_param));
EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
analysis.GetUniqueBufferAt(cond_param));
EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
analysis.GetUniqueBufferAt(xla_while));
EXPECT_NE(analysis.GetUniqueBufferAt(constant),
analysis.GetUniqueBufferAt(exp));
EXPECT_NE(analysis.GetUniqueBufferAt(constant),
analysis.GetUniqueBufferAt(negate));
// And finally make the negate the root of the body again.
body->set_root_instruction(negate);
analysis.UpdateAfterChangingRoot(/*old_root=*/body_param,
/*new_root*/ negate);
// Analysis should be valid after the update.
TF_ASSERT_OK(analysis.VerifyAgainstReference());
EXPECT_EQ(analysis.GetUniqueBufferAt(negate),
analysis.GetUniqueBufferAt(body_param));
EXPECT_EQ(analysis.GetUniqueBufferAt(negate),
analysis.GetUniqueBufferAt(cond_param));
EXPECT_EQ(analysis.GetUniqueBufferAt(negate),
analysis.GetUniqueBufferAt(xla_while));
EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
analysis.GetUniqueBufferAt(negate));
auto value_of = [&analysis](const HloInstruction* instruction) {
return &analysis.dataflow_analysis().GetValueDefinedAt(instruction);
};
EXPECT_THAT(analysis.GetUniqueBufferAt(negate).values(),
UnorderedElementsAre(value_of(body_param), value_of(cond_param),
value_of(negate), value_of(constant),
value_of(xla_while)));
}
// Test update tuple element.
} // namespace
} // namespace xla

View File

@ -36,22 +36,6 @@ namespace xla {
using ::tensorflow::str_util::Join;
using ::tensorflow::strings::StrCat;
void HloBuffer::AddValue(const HloValue& value) {
values_.push_back(&value);
// Sort vector and remove duplicates.
std::sort(values_.begin(), values_.end(), HloValue::IdLessThan);
values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual),
values_.end());
}
void HloBuffer::RemoveValue(const HloValue& value) {
// The values are sorted, so finding the value could be done in log(n) time
// with a binary search.
auto it = std::find(values_.begin(), values_.end(), &value);
CHECK(it != values_.end());
values_.erase(it);
}
bool HloBuffer::operator==(const HloBuffer& other) const {
bool equal = id() == other.id();
if (equal) {

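The comment above notes that RemoveValue could run in O(log n) because values_ stays sorted and duplicate-free. A minimal standalone sketch of that sorted-insert and binary-search idea, with plain ints standing in for HloValue ids and helper names chosen here purely for illustration:

#include <algorithm>
#include <cassert>
#include <vector>

// Keep a sorted, duplicate-free vector of ids, mirroring the invariant that
// HloBuffer::values_ maintains with sort + unique.
void AddValueId(std::vector<int>& ids, int id) {
  auto it = std::lower_bound(ids.begin(), ids.end(), id);
  if (it == ids.end() || *it != id) ids.insert(it, id);  // No duplicates.
}

void RemoveValueId(std::vector<int>& ids, int id) {
  auto it = std::lower_bound(ids.begin(), ids.end(), id);  // O(log n) find.
  assert(it != ids.end() && *it == id);
  ids.erase(it);
}

int main() {
  std::vector<int> ids;
  AddValueId(ids, 3);
  AddValueId(ids, 1);
  AddValueId(ids, 3);  // Duplicate insert is a no-op.
  RemoveValueId(ids, 1);
  assert(ids == std::vector<int>({3}));
}
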
View File

@ -84,22 +84,15 @@ class HloBuffer {
return a->id() == b->id();
}
HloBuffer(Id id) : id_(id) {}
HloBuffer(Id id, tensorflow::gtl::ArraySlice<const HloValue*> values)
: id_(id), values_(values.begin(), values.end()) {}
// Return the unique identifier for this HloBuffer.
Id id() const { return id_; }
// Add a value to the set of values held by this buffer. Also adds the
// HloPositions of the value to the positions vector of the buffer. If the
// buffer already contains this value, then this method is a nop.
void AddValue(const HloValue& value);
void RemoveValue(const HloValue& value);
// Return all values contained in this buffer.
const std::vector<const HloValue*>& values() const { return values_; }
std::vector<HloPosition> ComputePositions() const;
// Return the unique HLO value in the buffer. CHECK fails if the buffer does
// not contain exactly one value.
const HloValue& GetUniqueValue() const {
@ -107,6 +100,8 @@ class HloBuffer {
return *values_[0];
}
std::vector<HloPosition> ComputePositions() const;
string ToString() const;
bool operator==(const HloBuffer& other) const;
@ -118,7 +113,7 @@ class HloBuffer {
// The set of values contained in this buffer. Vector contains no duplicates
// and is sorted stably by HloValue::Id.
std::vector<const HloValue*> values_;
const std::vector<const HloValue*> values_;
};
std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer);

View File

@ -118,13 +118,11 @@ Status HloCostAnalysis::HandleElementwiseOp(HloInstruction* hlo_instruction) {
}
}
Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo,
HloOpcode opcode) {
Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo) {
return HandleElementwiseOp(hlo);
}
Status HloCostAnalysis::HandleElementwiseBinary(HloInstruction* hlo,
HloOpcode opcode) {
Status HloCostAnalysis::HandleElementwiseBinary(HloInstruction* hlo) {
return HandleElementwiseOp(hlo);
}

View File

@ -49,9 +49,8 @@ class HloCostAnalysis : public DfsHloVisitor {
using ShapeSizeFunction = std::function<int64(const Shape&)>;
explicit HloCostAnalysis(const ShapeSizeFunction& shape_size);
Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode) override;
Status HandleElementwiseBinary(HloInstruction* hlo,
HloOpcode opcode) override;
Status HandleElementwiseUnary(HloInstruction* hlo) override;
Status HandleElementwiseBinary(HloInstruction* hlo) override;
Status HandleConstant(HloInstruction* constant,
const Literal& literal) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element,

View File

@ -67,6 +67,22 @@ HloValue& HloDataflowAnalysis::GetValueDefinedAt(
return GetUniqueValueAt(instruction, index);
}
HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
const ShapeIndex& index,
bool is_phi) {
const int64 value_id = next_value_id_++;
auto emplaced = values_.emplace(
std::piecewise_construct, std::forward_as_tuple(value_id),
std::forward_as_tuple(value_id, instruction, index, is_phi));
CHECK(emplaced.second);
return &emplaced.first->second;
}
void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) {
values_.erase(value_id);
}
string HloDataflowAnalysis::ToString() const {
string out = StrCat("HloDataflowAnalysis, module ", module_->name(), "\n");
StrAppend(&out, " Instruction value sets:\n");
@ -99,22 +115,98 @@ string HloDataflowAnalysis::ToString() const {
}
}
StrAppend(&out, " HloValues:\n");
for (const HloValue& value : values()) {
StrAppend(&out, value.ToString(/*indent=*/4));
}
StrAppend(&out, " Phi resolutions:\n");
for (const HloValue& value : values()) {
if (value.is_phi()) {
const HloValue* resolved_value = ResolvePhi(value);
StrAppend(&out, " ", value.ToShortString(), " => ",
resolved_value == nullptr ? "UNKNOWN"
: resolved_value->ToShortString(),
"\n");
}
for (const HloValue* value : values()) {
StrAppend(&out, value->ToString(/*indent=*/4));
}
return out;
}
bool HloDataflowAnalysis::Phi(
HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
CHECK(ssa_form_);
for (const InstructionValueSet* input : inputs) {
DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
}
bool changed = false;
for (auto& pair : GetInstructionValueSet(instruction)) {
const ShapeIndex& index = pair.first;
HloValueSet& value_set = pair.second;
// Positions with phi values should never have more than one value in the
// value set.
CHECK_LE(value_set.values().size(), 1);
const HloValue* current_value =
value_set.values().size() == 1 ? value_set.values()[0] : nullptr;
// Construct a vector of unique value IDs of the inputs.
std::vector<HloValue::Id> input_value_ids;
for (const InstructionValueSet* input : inputs) {
for (const HloValue* value : input->element(index).values()) {
input_value_ids.push_back(value->id());
}
}
std::sort(input_value_ids.begin(), input_value_ids.end());
input_value_ids.erase(
std::unique(input_value_ids.begin(), input_value_ids.end()),
input_value_ids.end());
// Remove the existing phi value (if it exists). The phi can be its own
// input, for example, in while body parameters where the body passes
// through the parameter value.
bool current_value_defined_here =
(current_value != nullptr &&
current_value->defining_instruction() == instruction &&
current_value->defining_index() == index);
if (current_value_defined_here) {
CHECK(current_value->is_phi());
auto it = std::find(input_value_ids.begin(), input_value_ids.end(),
current_value->id());
if (it != input_value_ids.end()) {
input_value_ids.erase(it);
}
}
if (input_value_ids.empty()) {
// A value set which has at least one element should never have its value
// set reduced to zero elements. During dataflow, value sets can only go
// from empty to non-empty, not the reverse.
CHECK_EQ(value_set.values().size(), 0)
<< "Instruction " << instruction->name() << " at index " << index
<< " previously had non-empty value set. Value set: " << value_set;
} else if (input_value_ids.size() == 1) {
// Only a single value reaches this point. There should be no phi, and
// this value set should contain this single value.
const HloValue& new_value = GetValue(input_value_ids[0]);
if (current_value == nullptr) {
value_set.Clear();
value_set.AddValue(&new_value);
changed = true;
} else if (current_value != &new_value) {
if (current_value_defined_here) {
// Remove the existing phi.
DeleteHloValue(current_value->id());
}
value_set.Clear();
value_set.AddValue(&new_value);
changed = true;
}
} else {
// Multiple distinct values reach this point. A phi value is
// necessary.
CHECK_GT(input_value_ids.size(), 1);
if (current_value == nullptr || !current_value->is_phi()) {
value_set.Clear();
value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));
changed = true;
}
}
}
return changed;
}
const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const {
return values_.at(value_id);
}
@ -142,129 +234,6 @@ HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) {
return GetValueSet(position.instruction, position.index);
}
void HloDataflowAnalysis::UpdateAfterChangingOperand(
HloInstruction* instruction, HloInstruction* old_operand,
HloInstruction* new_operand) {
CHECK(std::find(instruction->operands().begin(),
instruction->operands().end(),
new_operand) != instruction->operands().end());
VLOG(1) << "UpdateAfterChangingOperand(" << instruction->name() << ", "
<< old_operand->name() << " => " << new_operand->name() << ")";
std::vector<HloInstruction*> to_update = {instruction};
// If the instruction calls any computations, then add the parameters of the
// called computations to capture any changes to the dataflow into the subcomputation
// introduced by the new operand.
for (HloComputation* computation : instruction->called_computations()) {
to_update.insert(to_update.end(),
computation->parameter_instructions().begin(),
computation->parameter_instructions().end());
}
UpdateInstructionsAndPropagate(to_update);
// The uses of the values in the old and new operand may have changed. Uses of
// other HloValues are updated in UpdateInstructionsAndPropagate.
for (auto& pair : GetInstructionValueSet(old_operand)) {
for (const HloValue* value : pair.second.values()) {
GetValue(value->id()).RecomputeUses();
}
}
for (auto& pair : GetInstructionValueSet(new_operand)) {
for (const HloValue* value : pair.second.values()) {
GetValue(value->id()).RecomputeUses();
}
}
TF_DCHECK_OK(VerifyAgainstReference());
}
void HloDataflowAnalysis::UpdateAfterChangingRoot(HloInstruction* old_root,
HloInstruction* new_root) {
VLOG(1) << "UpdateAfterChangingRoot(" << old_root->name() << " => "
<< new_root->name() << ")";
CHECK_EQ(new_root, new_root->parent()->root_instruction());
CHECK_EQ(new_root->parent(), old_root->parent());
std::vector<HloInstruction*> to_update = {old_root, new_root};
const CallGraphNode& call_graph_node =
call_graph_->GetNode(new_root->parent());
for (const CallSite& callsite : call_graph_node.caller_callsites()) {
if (callsite.instruction()->opcode() == HloOpcode::kCall) {
to_update.push_back(callsite.instruction());
} else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
// Add the while itself, and the body and condition parameters.
to_update.push_back(callsite.instruction());
to_update.push_back(
callsite.instruction()->while_body()->parameter_instruction(0));
to_update.push_back(
callsite.instruction()->while_condition()->parameter_instruction(0));
}
}
UpdateInstructionsAndPropagate(to_update);
TF_DCHECK_OK(VerifyAgainstReference());
}
const HloValue* HloDataflowAnalysis::ResolvePhi(const HloValue& phi) const {
CHECK(phi.is_phi());
tensorflow::gtl::FlatSet<const HloValue*> visited;
std::queue<const HloValue*> worklist;
auto add_to_worklist = [&worklist, &visited](const HloValue* v) {
if (visited.insert(v).second) {
// 'v' was not previously in visited.
worklist.push(v);
}
};
add_to_worklist(&phi);
const HloValue* resolved_value = nullptr;
while (!worklist.empty()) {
const HloValue* value = worklist.front();
worklist.pop();
if (!value->is_phi()) {
if (resolved_value == nullptr) {
resolved_value = value;
} else if (resolved_value != value) {
return nullptr;
}
} else {
for (const HloValue* input : phi_inputs_.at(value)) {
add_to_worklist(input);
}
}
}
return resolved_value;
}
void HloDataflowAnalysis::UpdatePhiInputs(
const HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
CHECK(ssa_form_);
for (auto& pair : GetInstructionValueSet(instruction)) {
const ShapeIndex& index = pair.first;
const HloValue& phi_value = GetUniqueValueAt(instruction, index);
auto& phi_inputs = phi_inputs_.at(&phi_value);
phi_inputs.clear();
for (const InstructionValueSet* input : inputs) {
for (const HloValue* value : input->element(index).values()) {
// The number of phi inputs is typically 2, and virtually always very
// small.
if (std::find(phi_inputs.begin(), phi_inputs.end(), value) ==
phi_inputs.end()) {
phi_inputs.push_back(value);
}
}
}
}
}
bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
const InstructionValueSet& operand_set =
@ -380,8 +349,7 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
}
if (ssa_form_ && called_from_while) {
UpdatePhiInputs(parameter, inputs);
return false;
return Phi(parameter, inputs);
} else {
return GetInstructionValueSet(parameter).AssignUnionOf(inputs);
}
@ -439,8 +407,7 @@ bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
&GetInstructionValueSet(xla_while->while_body()->root_instruction()),
&GetInstructionValueSet(xla_while->operand(0))};
if (ssa_form_) {
UpdatePhiInputs(xla_while, inputs);
return false;
return Phi(xla_while, inputs);
} else {
return GetInstructionValueSet(xla_while).AssignUnionOf(inputs);
}
@ -487,38 +454,7 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
VLOG(3) << "Worklist top: " << instruction->name();
VLOG(3) << ToString();
// The updating of the instruction value set below in
// UpdateInstructionValueSet does not update HloValue::positions(). To
// perform the positions() update remove all positions in 'instruction' from
// the HloValues in 'instruction's value set prior to the update, then after
// the update add the new positions back in. There is likely a more
// efficient way of doing this.
for (auto& pair : GetInstructionValueSet(instruction)) {
const ShapeIndex& index = pair.first;
HloValueSet& value_set = pair.second;
for (const HloValue* value : value_set.values()) {
if (value->defining_instruction() != instruction) {
// Use GetValue for a non-const HloValue reference.
GetValue(value->id()).RemovePosition(instruction, index);
}
}
}
bool changed = UpdateInstructionValueSet(instruction);
// Add the positions back in.
for (auto& pair : GetInstructionValueSet(instruction)) {
const ShapeIndex& index = pair.first;
HloValueSet& value_set = pair.second;
for (const HloValue* value : value_set.values()) {
if (value->defining_instruction() != instruction) {
// Use GetValue for a non-const HloValue reference.
GetValue(value->id()).AddPosition(instruction, index);
}
}
}
if (!changed) {
if (!UpdateInstructionValueSet(instruction)) {
// No change to the instruction's value set.
VLOG(4) << "No change.";
continue;
@ -531,12 +467,16 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
for (HloInstruction* user : instruction->users()) {
worklist.push(user);
// If user calls a computation, then the respective parameter(s) of the
// computation need to be updated.
// If user sequentially calls a computation, then the respective
// parameter(s) of the computation need to be updated.
for (HloComputation* called_computation : user->called_computations()) {
for (int64 operand_number : user->OperandIndices(instruction)) {
worklist.push(
called_computation->parameter_instruction(operand_number));
const CallGraphNode& call_graph_node =
call_graph_->GetNode(called_computation);
if (call_graph_node.context() == CallContext::kSequential) {
for (int64 operand_number : user->OperandIndices(instruction)) {
worklist.push(
called_computation->parameter_instruction(operand_number));
}
}
}
}
@ -574,25 +514,10 @@ InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
}
Status HloDataflowAnalysis::InitializeInstructionValueSets() {
// Gather the values to create before creating them. This is done because we
// want to allocate the vector of values only once so references to elements
// are stable.
struct ValueToCreate {
HloInstruction* instruction;
ShapeIndex index;
bool is_phi;
};
std::vector<ValueToCreate> values_to_create;
for (const std::unique_ptr<HloComputation>& computation :
module_->computations()) {
const CallGraphNode& call_graph_node =
call_graph_->GetNode(computation.get());
bool called_from_while = std::any_of(
call_graph_node.caller_callsites().begin(),
call_graph_node.caller_callsites().end(), [](const CallSite& cs) {
return cs.instruction()->opcode() == HloOpcode::kWhile;
});
for (const std::unique_ptr<HloInstruction>& instruction :
computation->instructions()) {
@ -603,20 +528,22 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
// Lambda to set the value set to define all values in the output of the
// instruction.
auto define_all_values = [this, &instruction,
&values_to_create](bool is_phi = false) {
auto define_all_values = [this, &instruction](bool is_phi = false) {
for (auto& pair : GetInstructionValueSet(instruction.get())) {
const ShapeIndex& index = pair.first;
values_to_create.push_back({instruction.get(), index, is_phi});
HloValue* value =
NewHloValue(instruction.get(), index, /*is_phi=*/false);
GetValueSet(instruction.get(), index).AddValue(value);
}
};
// Lambda to set the value set to define only the top-level buffer in the
// output of the instruction. Any other values flow from the operands of
// the instruction (or from cross-computation dataflow).
auto define_top_level_only = [this, &instruction, &values_to_create]() {
values_to_create.push_back(
{instruction.get(), /*index=*/{}, /*is_phi=*/false});
auto define_top_level_only = [this, &instruction]() {
HloValue* value =
NewHloValue(instruction.get(), /*index=*/{}, /*is_phi=*/false);
GetValueSet(instruction.get(), /*index=*/{}).AddValue(value);
};
switch (instruction->opcode()) {
@ -626,10 +553,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
}
break;
case HloOpcode::kWhile:
if (ssa_form_) {
define_all_values(/*is_phi=*/true);
}
break;
case HloOpcode::kCall:
case HloOpcode::kGetTupleElement:
// These instructions define no values. The values in their output
@ -654,10 +577,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
// values in their output. Otherwise the values of the parameter
// come from the caller (eg, operands to the kCall instruction).
define_all_values();
} else if (call_graph_node.context() == CallContext::kSequential &&
called_from_while && ssa_form_) {
// Parameters of while bodies and conditions are phis.
define_all_values(/*is_phi=*/true);
}
break;
case HloOpcode::kCopy:
@ -674,164 +593,9 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
}
}
// Reserve the vector ahead of time so references to elements are stable.
values_.reserve(values_to_create.size());
for (int64 i = 0; i < values_to_create.size(); ++i) {
const ValueToCreate& to_create = values_to_create[i];
values_.emplace_back(/*id=*/i, to_create.instruction, to_create.index,
to_create.is_phi);
const HloValue& value = values_.back();
GetValueSet(to_create.instruction, to_create.index).AddValue(&value);
if (value.is_phi()) {
phi_inputs_[&value] = {};
}
}
return Status::OK();
}
bool HloDataflowAnalysis::IsDefinedBefore(const HloValue& a, const HloValue& b,
const HloOrdering& ordering) const {
// If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b'
// is live into the module.
if (b.defining_instruction()->parent() == module_->entry_computation() &&
b.defining_instruction()->opcode() == HloOpcode::kParameter) {
return false;
}
// Phi values require special handling. Because XLA does not have a phi
// instruction, the defining instructions of phi values are
// placeholders: either the subcomputation parameter (body or condition) or
// the while instruction. However, the program point where these values are
// logically defined does not necessarily coincide exactly with the program
// point of these placeholder instructions. So we explicitly define the following
// order for phi values:
//
// body/condition parameter phi:
// Defined before all values defined in its computation excepting other
// phis.
//
// while phi:
// defined after all values defined in the condition or body.
//
auto is_body_or_condition_phi = [](const HloValue& v) {
return v.is_phi() &&
v.defining_instruction()->opcode() == HloOpcode::kParameter;
};
if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) &&
call_graph_->InstructionIsNestedIn(b.defining_instruction(),
a.defining_instruction()->parent())) {
return true;
}
if (is_body_or_condition_phi(b) &&
call_graph_->InstructionIsNestedIn(a.defining_instruction(),
b.defining_instruction()->parent())) {
return false;
}
// If 'b' is a while phi and 'a' is in the body or condition, then 'a'
// executes before 'b'.
if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile &&
(call_graph_->InstructionIsNestedIn(
a.defining_instruction(), b.defining_instruction()->while_body()) ||
call_graph_->InstructionIsNestedIn(
a.defining_instruction(),
b.defining_instruction()->while_condition()))) {
return true;
}
return ordering.ExecutesBefore(a.defining_instruction(),
b.defining_instruction());
}
bool HloDataflowAnalysis::UseIsBeforeValueDefinition(
const HloUse& use, const HloValue& value,
const HloOrdering& ordering) const {
if (ordering.ExecutesBefore(use.instruction, value.defining_instruction())) {
return true;
}
// If the use is at the instruction where the value is defined, then the use
// is before the def if the instruction allows buffer sharing (in place
// computation).
if (use.instruction == value.defining_instruction() &&
CanShareOperandBufferWithUser(
use.instruction->mutable_operand(use.operand_number),
use.operand_index, value.defining_instruction(),
value.defining_index())) {
return true;
}
// The use at a while is an input to a phi, and logically occurs before values
// are defined in the body or condition computations.
if (use.instruction->opcode() == HloOpcode::kWhile) {
const HloInstruction* xla_while = use.instruction;
if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
xla_while->while_body()) ||
call_graph_->InstructionIsNestedIn(value.defining_instruction(),
xla_while->while_condition())) {
return true;
}
}
// Similarly if the value is defined at a while, it logically occurs after any
// uses in the body or condition computations.
if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
CHECK(ssa_form_);
const HloInstruction* xla_while = value.defining_instruction();
if (call_graph_->InstructionIsNestedIn(use.instruction,
xla_while->while_body()) ||
call_graph_->InstructionIsNestedIn(use.instruction,
xla_while->while_condition())) {
return true;
}
}
return false;
}
bool HloDataflowAnalysis::LiveRangeStrictlyBefore(
const HloValue& a, const HloValue& b, const HloOrdering& ordering) const {
VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
<< ", b = " << b.ToShortString() << ")";
if (!IsDefinedBefore(a, b, ordering)) {
VLOG(4) << "a not defined before b";
return false;
}
// Live-out values from the module can never have ranges strictly before any
// other value.
if (a.live_out_of_module()) {
VLOG(4) << "a is live out of module";
return false;
}
// Live-out values of computations can never have ranges strictly before any
// other value in the computation (including values nested in
// subcomputations).
if (a.live_out_of_computation() &&
call_graph_->InstructionIsNestedIn(b.defining_instruction(),
a.defining_instruction()->parent())) {
VLOG(4) << "a is live out of computation containing b";
return false;
}
// All uses of 'a' must be before 'b' is defined.
for (const HloUse& use : a.uses()) {
if (!UseIsBeforeValueDefinition(use, b, ordering)) {
VLOG(4) << "use of a (" << use << ") not before b is defined";
return false;
}
}
return true;
}
bool HloDataflowAnalysis::MayInterfere(const HloValue& a, const HloValue& b,
const HloOrdering& ordering) const {
// Buffers without disjoint liveness may interfere.
return !LiveRangeStrictlyBefore(a, b, ordering) &&
!LiveRangeStrictlyBefore(b, a, ordering);
}
/* static */
StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
HloModule* module, bool ssa_form, bool bitcast_defines_value) {
@ -855,6 +619,33 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
}
dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions);
// Add in positions to all values.
for (const std::unique_ptr<HloComputation>& computation :
module->computations()) {
for (const std::unique_ptr<HloInstruction>& instruction :
computation->instructions()) {
for (const auto& pair :
dataflow_analysis->GetInstructionValueSet(instruction.get())) {
const ShapeIndex& index = pair.first;
const HloValueSet& value_set = pair.second;
for (const HloValue* value : value_set.values()) {
if (value->defining_instruction() != instruction.get()) {
dataflow_analysis->GetValue(value->id())
.AddPosition(instruction.get(), index);
}
}
}
}
}
// Construct vector of values.
dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size());
for (auto& pair : dataflow_analysis->values_) {
dataflow_analysis->values_vector_.push_back(&pair.second);
}
std::sort(dataflow_analysis->values_vector_.begin(),
dataflow_analysis->values_vector_.end(), HloValue::IdLessThan);
TF_DCHECK_OK(dataflow_analysis->Verify());
XLA_VLOG_LINES(1, dataflow_analysis->ToString());
@ -865,14 +656,14 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
Status HloDataflowAnalysis::Verify() const {
// Verify each HloValue appears in the value sets that the value's positions()
// indicate.
for (const HloValue& value : values()) {
for (const HloPosition& position : value.positions()) {
for (const HloValue* value : values()) {
for (const HloPosition& position : value->positions()) {
const HloValueSet& value_set = GetValueSet(position);
TF_RET_CHECK(std::find(value_set.values().begin(),
value_set.values().end(),
&value) != value_set.values().end())
value) != value_set.values().end())
<< "Value set at position " << position << " does not contain value "
<< value.ToShortString();
<< value->ToShortString();
}
}
@ -898,75 +689,4 @@ Status HloDataflowAnalysis::Verify() const {
return Status::OK();
}
Status HloDataflowAnalysis::VerifyAgainstReference() const {
TF_RETURN_IF_ERROR(Verify());
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> reference,
Run(module_, ssa_form_, bitcast_defines_value_));
TF_RETURN_IF_ERROR(reference->Verify());
VLOG(2) << "This analysis:";
XLA_VLOG_LINES(2, ToString());
VLOG(2) << "Reference:";
XLA_VLOG_LINES(2, reference->ToString());
// Verify value sets in each position are identical.
for (const auto& computation : module_->computations()) {
for (const auto& instruction : computation->instructions()) {
for (const auto& pair : GetInstructionValueSet(instruction.get())) {
const ShapeIndex& index = pair.first;
const HloValueSet& value_set = pair.second;
const HloValueSet& reference_value_set =
reference->GetValueSet(instruction.get(), index);
auto value_in_set = [](const HloValue& v, const HloValueSet& vset) {
return std::find_if(vset.values().begin(), vset.values().end(),
[&v](const HloValue* w) { return *w == v; }) !=
vset.values().end();
};
for (const HloValue* value : value_set.values()) {
TF_RET_CHECK(value_in_set(*value, reference_value_set))
<< "Value " << value->ToShortString()
<< " does not exist in reference";
}
for (const HloValue* reference_value : reference_value_set.values()) {
TF_RET_CHECK(value_in_set(*reference_value, value_set))
<< "Value " << reference_value->ToShortString()
<< " only exists in reference";
}
}
}
}
// Verify all phis resolve identically and uses are identical.
for (const HloValue& value : values()) {
const HloValue& reference_value = reference->GetValueDefinedAt(
value.defining_instruction(), value.defining_index());
TF_RET_CHECK(value.is_phi() == reference_value.is_phi());
if (value.is_phi()) {
const HloValue* resolved_value = ResolvePhi(value);
const HloValue* reference_resolved_value =
reference->ResolvePhi(reference_value);
if (resolved_value == nullptr) {
TF_RET_CHECK(reference_resolved_value == nullptr);
} else {
TF_RET_CHECK(reference_resolved_value != nullptr);
TF_RET_CHECK(*reference_resolved_value == *resolved_value);
}
}
for (const HloUse& use : value.uses()) {
TF_RET_CHECK(std::find(reference_value.uses().begin(),
reference_value.uses().end(),
use) != reference_value.uses().end());
}
for (const HloUse& reference_use : reference_value.uses()) {
TF_RET_CHECK(std::find(value.uses().begin(), value.uses().end(),
reference_use) != value.uses().end());
}
}
return Status::OK();
}
} // namespace xla

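The Phi() routine above reduces to a join over the unique value ids reaching a position: zero inputs keep the set empty, one non-self input passes through, and more than one requires a phi value. A small self-contained sketch of just that decision, with plain ints standing in for HloValue::Id and names (Decide, Join) invented here for illustration:

#include <algorithm>
#include <iostream>
#include <vector>

enum class Join { kEmpty, kPassThrough, kNeedsPhi };

// Decide what the join at one position should be, given the ids of the values
// flowing in from the inputs. self_phi_id is the id of the position's existing
// phi value, if any; a phi feeding itself does not force a new phi.
Join Decide(std::vector<int> input_value_ids, int self_phi_id = -1) {
  std::sort(input_value_ids.begin(), input_value_ids.end());
  input_value_ids.erase(
      std::unique(input_value_ids.begin(), input_value_ids.end()),
      input_value_ids.end());
  auto self = std::find(input_value_ids.begin(), input_value_ids.end(),
                        self_phi_id);
  if (self != input_value_ids.end()) input_value_ids.erase(self);
  if (input_value_ids.empty()) return Join::kEmpty;
  return input_value_ids.size() == 1 ? Join::kPassThrough : Join::kNeedsPhi;
}

int main() {
  // One unique input: the value passes through, no phi is created.
  std::cout << (Decide({4, 4}) == Join::kPassThrough) << "\n";
  // Two distinct inputs: a phi value is required.
  std::cout << (Decide({4, 7}) == Join::kNeedsPhi) << "\n";
  // A phi that is one of its own inputs still resolves to the other value.
  std::cout << (Decide({4, 9}, /*self_phi_id=*/9) == Join::kPassThrough) << "\n";
}
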
View File

@ -88,10 +88,10 @@ class HloDataflowAnalysis {
// given position.
const HloValueSet& GetValueSet(const HloInstruction* instruction,
const ShapeIndex& index = {}) const;
HloValueSet& GetValueSet(const HloInstruction* instruction,
const ShapeIndex& index = {});
const HloValueSet& GetValueSet(const HloPosition& position) const;
HloValueSet& GetValueSet(const HloPosition& position);
HloValueSet& GetValueSet(const HloInstruction* instruction,
const ShapeIndex& index = {});
// Return the unique value in the HloValueSet at the given instruction and
// shape index. CHECKs if the value set does not contain exactly one value.
@ -108,49 +108,11 @@ class HloDataflowAnalysis {
const HloValue& GetValue(HloValue::Id value_id) const;
HloValue& GetValue(HloValue::Id value_id);
// Returns whether the given values interfere assuming the given HLO
// ordering. Two values interfere if they may both be simultaneously live.
bool MayInterfere(const HloValue& a, const HloValue& b,
const HloOrdering& ordering) const;
// Overload which takes HloValue::Ids.
bool MayInterfere(HloValue::Id a, HloValue::Id b,
const HloOrdering& ordering) const {
return MayInterfere(GetValue(a), GetValue(b), ordering);
}
// Return the total number of HloValues.
int64 value_count() const { return values_.size(); }
// Return a vector of all HloValues.
const std::vector<HloValue>& values() const { return values_; }
// Updates the dataflow after changing an operand of
// 'instruction'. Dataflow update is not possible if instructions have been
// added or removed from the graph.
void UpdateAfterChangingOperand(HloInstruction* instruction,
HloInstruction* old_operand,
HloInstruction* new_operand);
// Updates the dataflow after changing the root of a computation from
// 'old_root' to 'new_root'.
void UpdateAfterChangingRoot(HloInstruction* old_root,
HloInstruction* new_root);
// Returns the non-phi HloValue that is the unique (transitive) input to the
// given phi. If no such HloValue exists (there are multiple inputs to the
// phi) then nullptr is returned. This is computed by walking the inputs
// of the given phi value until non-phi HloValue(s) are encountered.
const HloValue* ResolvePhi(const HloValue& phi) const;
const HloValue* ResolvePhi(const HloInstruction* instruction,
const ShapeIndex& index = {}) const {
return ResolvePhi(GetValueDefinedAt(instruction, index));
}
// Compare the dataflow analysis against a clean recomputation of the
// analysis. Returns an error status if there is a mismatch. Useful for
// verifying the correctness after updates to the analysis.
Status VerifyAgainstReference() const;
// Return a vector of all HloValues stably sorted by HloValue::Id.
const std::vector<const HloValue*>& values() const { return values_vector_; }
// Return the call graph used for computing the dataflow.
const CallGraph& call_graph() const { return *call_graph_; }
@ -161,6 +123,13 @@ class HloDataflowAnalysis {
HloDataflowAnalysis(HloModule* module, bool ssa_form,
bool bitcast_defines_value = false);
// Returns a new HloValue defined at the given instruction and shape index.
HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
bool is_phi = false);
// Delete the HloValue with the given ID.
void DeleteHloValue(HloValue::Id value_id);
// Constructs and initializes the InstructionValueSets of all instructions to
// contain exactly the HloValues defined by each instruction. These values can
// then be propagated throughout the HLO graph by calling
@ -187,10 +156,11 @@ class HloDataflowAnalysis {
void UpdateInstructionsAndPropagate(
tensorflow::gtl::ArraySlice<HloInstruction*> instructions);
// Sets the inputs of the given phi to given value(s).
void UpdatePhiInputs(
const HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);
// Return the result of the SSA Phi function applied to the given inputs at
// the given instruction. If skip_top_level is true, then the top level of the
// value set of 'instruction' is not modified.
bool Phi(HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);
// Updates the positions of the HloValues in the output of the given
// instruction. This should be called after the instruction value set of
@ -203,20 +173,6 @@ class HloDataflowAnalysis {
HloInstruction* instruction, const InstructionValueSet& new_value_set,
const InstructionValueSet* prev_value_set = nullptr);
// Returns true if the live range of the given value 'a' is strictly before
// the live range of value 'b' using the given HLO ordering.
bool LiveRangeStrictlyBefore(const HloValue& a, const HloValue& b,
const HloOrdering& ordering) const;
// Returns whether the value 'a' is defined before the value 'b' under the
// given ordering.
bool IsDefinedBefore(const HloValue& a, const HloValue& b,
const HloOrdering& ordering) const;
// Returns whether the given use is before the given value definition.
bool UseIsBeforeValueDefinition(const HloUse& use, const HloValue& value,
const HloOrdering& ordering) const;
// Verify various invariants of the dataflow analysis.
Status Verify() const;
@ -226,19 +182,19 @@ class HloDataflowAnalysis {
std::unique_ptr<CallGraph> call_graph_;
// Array of all values in the module. This is allocated once at analysis
// construction time so HloValue references are stable. Updates to the
// analysis via UpdateAfterChangingOperand and UpdateAfterChangingRoot do not
// result in the creation or destruction of any HloValues.
std::vector<HloValue> values_;
// Map holding the inputs to each phi value in the module. Used by ResolvePhi.
tensorflow::gtl::FlatMap<const HloValue*,
tensorflow::gtl::InlinedVector<const HloValue*, 2>>
phi_inputs_;
// The map of all HloValues in the module. We pass around pointers to the
// mapped HloValues, so the underlying container must keep them valid despite
// mutations touching other map entries.
std::unordered_map<HloValue::Id, HloValue> values_;
// A map from instruction to InstructionValueSet.
std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_;
// A vector containing all HloValues sorted by HloValue::Id.
std::vector<const HloValue*> values_vector_;
// The Id to use for the next HloValue.
HloValue::Id next_value_id_ = 0;
};
} // namespace xla

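The switch from std::vector<HloValue> to an unordered_map keyed by id is what makes the "underlying container must keep them valid" comment above hold: the standard guarantees that inserting into an unordered_map never invalidates references or pointers to existing elements, whereas a growing vector may reallocate and move them. A tiny standalone illustration (std::string stands in for HloValue):

#include <cassert>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<int, std::string> values;
  values.emplace(0, "constant");
  const std::string* kept = &values.at(0);  // Pointer into the map.
  for (int id = 1; id < 10000; ++id) {
    values.emplace(id, "value");  // May rehash, but elements never move.
  }
  assert(kept == &values.at(0));  // Still valid after many insertions.
}
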
View File

@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@ -44,8 +43,8 @@ class HloDataflowAnalysisTest : public HloTestBase,
// Run dataflow analysis on the member module. For convenience returns a
// reference to the generated analysis stored in analysis_.
HloDataflowAnalysis& RunAnalysis(bool ssa_form,
bool bitcast_defines_value = false) {
const HloDataflowAnalysis& RunAnalysis(bool ssa_form,
bool bitcast_defines_value = false) {
analysis_ =
HloDataflowAnalysis::Run(module_.get(), ssa_form, bitcast_defines_value)
.ConsumeValueOrDie();
@ -71,8 +70,8 @@ class HloDataflowAnalysisTest : public HloTestBase,
const HloInstruction* b) {
EXPECT_FALSE(ShapeUtil::IsTuple(a->shape()));
EXPECT_FALSE(ShapeUtil::IsTuple(b->shape()));
return analysis_->MayInterfere(analysis_->GetValueDefinedAt(a),
analysis_->GetValueDefinedAt(b), ordering);
return ordering.MayInterfere(analysis_->GetValueDefinedAt(a),
analysis_->GetValueDefinedAt(b));
}
std::unique_ptr<HloModule> module_;
@ -499,37 +498,26 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) {
EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module());
if (ssa_form) {
// While instruction should define phi values. The value at index {0} is a
// degenerate phi with a single input 'constant1'.
EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi());
EXPECT_EQ(analysis.ResolvePhi(xla_while, /*index=*/{0}),
&analysis.GetValueDefinedAt(constant1));
EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{0}).is_phi());
EXPECT_EQ(analysis.ResolvePhi(body_param, /*index=*/{0}),
&analysis.GetValueDefinedAt(constant1));
EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{0}).is_phi());
EXPECT_EQ(analysis.ResolvePhi(cond_param, /*index=*/{0}),
&analysis.GetValueDefinedAt(constant1));
// Element 0 of the tuple is passed through the body, so no phi value is
// defined.
EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
// Element 1 of the tuple should be a phi value.
EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());
EXPECT_EQ(analysis.ResolvePhi(xla_while, /*index=*/{1}), nullptr);
EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi());
EXPECT_EQ(analysis.ResolvePhi(body_param, /*index=*/{1}), nullptr);
EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi());
EXPECT_EQ(analysis.ResolvePhi(cond_param, /*index=*/{1}), nullptr);
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
UnorderedElementsAre(HloUse{xla_while, 0, {0}}));
EXPECT_THAT(
analysis.GetValueDefinedAt(constant1).uses(),
UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{xla_while, 0, {0}}));
EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0})
.live_out_of_module());
// Constant1 passes through the body and out of the module.
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
.live_out_of_module());
@ -613,20 +601,15 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) {
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
if (ssa_form) {
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while2).live_out_of_module());
EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
} else {
// Element 0 is passed through all the while instructions and out of the
// module.
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while0, /*index=*/{0}),
analysis.GetValueDefinedAt(constant1));
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while1, /*index=*/{0}),
analysis.GetValueDefinedAt(constant1));
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while2, /*index=*/{0}),
analysis.GetValueDefinedAt(constant1));
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
}
// Element 0 is passed through all the while instructions and out of the
// module.
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while0, /*index=*/{0}),
analysis.GetValueDefinedAt(constant1));
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while1, /*index=*/{0}),
analysis.GetValueDefinedAt(constant1));
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while2, /*index=*/{0}),
analysis.GetValueDefinedAt(constant1));
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
}
TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
@ -705,13 +688,18 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
if (ssa_form) {
EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_param, /*index=*/{1}));
EXPECT_TRUE(
analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi());
EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0}));
EXPECT_TRUE(
analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi());
// Element 0 of the nested while is %negate.
EXPECT_FALSE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0}));
EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
// Element 1 is a phi value (join of %add and %constant2).
EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{1}));
EXPECT_TRUE(
analysis.GetValueDefinedAt(nested_while, /*index=*/{1}).is_phi());
@ -724,8 +712,6 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
EXPECT_TRUE(
analysis.GetValueDefinedAt(entry_while, /*index=*/{1}).is_phi());
} else {
EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{1}),
UnorderedElementsAre(analysis.GetValueDefinedAt(add),
analysis.GetValueDefinedAt(constant2)));
@ -1496,256 +1482,6 @@ TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) {
EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log));
}
TEST_P(HloDataflowAnalysisTest, UpdateAnalysisForWhile) {
// Test updating dataflow after modifying a module with an array shaped while:
//
// body(F32[] %param):
// %negate = Negate(%param)
//
// condition(F32[] %param):
// return Constant(false)
//
// entry:
// %constant = Constant(1.0)
// %exp = Exp(%constant)
// return While(%exp, body, condition)
//
auto body_builder = HloComputation::Builder("body");
auto body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
scalar_shape_, HloOpcode::kNegate, body_param));
HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
// Condition computation trivially returns a constant "false".
auto cond_builder = HloComputation::Builder("condition");
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
cond_builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
auto xla_while = builder.AddInstruction(
HloInstruction::CreateWhile(scalar_shape_, condition, body, exp));
module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
// Sanity check the initial dataflow analysis before transforming the HLO
// graph.
if (ssa_form) {
EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param));
EXPECT_TRUE(analysis.GetValueDefinedAt(body_param).is_phi());
EXPECT_EQ(analysis.ResolvePhi(body_param), nullptr);
EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param));
EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param).is_phi());
EXPECT_EQ(analysis.ResolvePhi(cond_param), nullptr);
EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module());
EXPECT_FALSE(analysis.GetValueDefinedAt(negate).live_out_of_module());
} else {
EXPECT_THAT(HloValuesAt(body_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(exp),
analysis.GetValueDefinedAt(negate)));
EXPECT_THAT(HloValuesAt(cond_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(exp),
analysis.GetValueDefinedAt(negate)));
EXPECT_THAT(HloValuesAt(xla_while),
UnorderedElementsAre(analysis.GetValueDefinedAt(exp),
analysis.GetValueDefinedAt(negate)));
EXPECT_TRUE(analysis.GetValueDefinedAt(negate).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(exp).live_out_of_module());
}
// Set the body root to the body_param. Previously it was Negate(body_param).
body->set_root_instruction(body_param);
// Prior to updating, verify that the dataflow analysis is no longer valid.
Status verify_status = analysis.VerifyAgainstReference();
EXPECT_FALSE(verify_status.ok());
analysis.UpdateAfterChangingRoot(/*old_root=*/negate,
/*new_root=*/body_param);
// Analysis should be valid after the update.
TF_EXPECT_OK(analysis.VerifyAgainstReference());
if (ssa_form) {
// The phis should now be resolvable as 'exp' is passed through the body
// transparently.
EXPECT_EQ(analysis.ResolvePhi(body_param),
&analysis.GetValueDefinedAt(exp));
EXPECT_EQ(analysis.ResolvePhi(cond_param),
&analysis.GetValueDefinedAt(exp));
EXPECT_EQ(analysis.ResolvePhi(xla_while), &analysis.GetValueDefinedAt(exp));
EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module());
} else {
EXPECT_THAT(HloValuesAt(body_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(exp)));
EXPECT_THAT(HloValuesAt(cond_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(exp)));
EXPECT_THAT(HloValuesAt(xla_while),
UnorderedElementsAre(analysis.GetValueDefinedAt(exp)));
EXPECT_TRUE(analysis.GetValueDefinedAt(exp).live_out_of_module());
}
EXPECT_FALSE(analysis.GetValueDefinedAt(negate).live_out_of_module());
// Now replace the operand of the while with %constant (was %exp).
TF_ASSERT_OK(exp->ReplaceUseWith(xla_while, constant));
analysis.UpdateAfterChangingOperand(xla_while, /*old_operand=*/exp,
/*new_operand=*/constant);
// Verify that the dataflow is correct.
TF_ASSERT_OK(analysis.VerifyAgainstReference());
if (ssa_form) {
// The phis now resolve to 'constant'.
EXPECT_EQ(analysis.ResolvePhi(body_param),
&analysis.GetValueDefinedAt(constant));
EXPECT_EQ(analysis.ResolvePhi(cond_param),
&analysis.GetValueDefinedAt(constant));
EXPECT_EQ(analysis.ResolvePhi(xla_while),
&analysis.GetValueDefinedAt(constant));
} else {
EXPECT_THAT(HloValuesAt(body_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(constant)));
EXPECT_THAT(HloValuesAt(cond_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(constant)));
EXPECT_THAT(HloValuesAt(xla_while),
UnorderedElementsAre(analysis.GetValueDefinedAt(constant)));
EXPECT_TRUE(analysis.GetValueDefinedAt(constant).live_out_of_module());
}
// And finally make the negate the root of the body again.
body->set_root_instruction(negate);
analysis.UpdateAfterChangingRoot(/*old_root=*/body_param,
/*new_root=*/negate);
// Verify that the dataflow is correct.
TF_ASSERT_OK(analysis.VerifyAgainstReference());
if (ssa_form) {
// Phis should no longer be resolvable.
EXPECT_EQ(analysis.ResolvePhi(body_param), nullptr);
EXPECT_EQ(analysis.ResolvePhi(cond_param), nullptr);
EXPECT_EQ(analysis.ResolvePhi(xla_while), nullptr);
} else {
EXPECT_THAT(HloValuesAt(body_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(constant),
analysis.GetValueDefinedAt(negate)));
EXPECT_THAT(HloValuesAt(cond_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(constant),
analysis.GetValueDefinedAt(negate)));
EXPECT_THAT(HloValuesAt(xla_while),
UnorderedElementsAre(analysis.GetValueDefinedAt(constant),
analysis.GetValueDefinedAt(negate)));
EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(negate).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(constant).live_out_of_module());
}
// After the updates, verify that the dataflow is correct.
TF_ASSERT_OK(analysis.VerifyAgainstReference());
}
TEST_P(HloDataflowAnalysisTest, UpdateOfATupleSelect) {
// Test changing the operands of kSelects of a tuple value and updating the
// dataflow.
auto builder = HloComputation::Builder(TestName());
auto pred = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
auto a = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
auto b = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
auto c = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
auto d = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(4.0)));
auto tuple_a = builder.AddInstruction(HloInstruction::CreateTuple({a}));
auto tuple_b = builder.AddInstruction(HloInstruction::CreateTuple({b}));
auto tuple_c = builder.AddInstruction(HloInstruction::CreateTuple({c}));
auto tuple_d = builder.AddInstruction(HloInstruction::CreateTuple({d}));
const Shape tuple_shape = tuple_a->shape();
auto select_aa = builder.AddInstruction(HloInstruction::CreateTernary(
tuple_shape, HloOpcode::kSelect, pred, tuple_a, tuple_a));
auto select_ab = builder.AddInstruction(HloInstruction::CreateTernary(
tuple_shape, HloOpcode::kSelect, pred, tuple_a, tuple_b));
auto select_cd = builder.AddInstruction(HloInstruction::CreateTernary(
tuple_shape, HloOpcode::kSelect, pred, tuple_c, tuple_d));
auto select_abcd = builder.AddInstruction(HloInstruction::CreateTernary(
tuple_shape, HloOpcode::kSelect, pred, select_ab, select_cd));
module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
// Sanity check dataflow before changing the graph and updating.
EXPECT_THAT(HloValuesAt(select_aa, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(a)));
EXPECT_THAT(HloValuesAt(select_ab, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(a),
analysis.GetValueDefinedAt(b)));
EXPECT_THAT(HloValuesAt(select_cd, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(c),
analysis.GetValueDefinedAt(d)));
EXPECT_THAT(HloValuesAt(select_abcd, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(a),
analysis.GetValueDefinedAt(b),
analysis.GetValueDefinedAt(c),
analysis.GetValueDefinedAt(d)));
EXPECT_TRUE(analysis.GetValueDefinedAt(a).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(b).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(c).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(d).live_out_of_module());
// Set the rhs of 'select_aa' to be 'd'.
TF_ASSERT_OK(select_aa->ReplaceOperandWith(2, tuple_d));
analysis.UpdateAfterChangingOperand(select_aa, /*old_operand=*/tuple_a,
/*new_operand=*/tuple_d);
// Verify that the dataflow is correct.
TF_ASSERT_OK(analysis.VerifyAgainstReference());
EXPECT_THAT(HloValuesAt(select_aa, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(a),
analysis.GetValueDefinedAt(d)));
// Set the lhs of 'select_cd' to be 'a'.
TF_ASSERT_OK(select_cd->ReplaceOperandWith(1, tuple_a));
analysis.UpdateAfterChangingOperand(select_cd, /*old_operand=*/tuple_c,
/*new_operand=*/tuple_a);
// Verify that the dataflow is correct.
TF_ASSERT_OK(analysis.VerifyAgainstReference());
EXPECT_THAT(HloValuesAt(select_cd, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(a),
analysis.GetValueDefinedAt(d)));
EXPECT_THAT(HloValuesAt(select_abcd, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(a),
analysis.GetValueDefinedAt(b),
analysis.GetValueDefinedAt(d)));
EXPECT_TRUE(analysis.GetValueDefinedAt(a).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(b).live_out_of_module());
EXPECT_FALSE(analysis.GetValueDefinedAt(c).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(d).live_out_of_module());
// After the updates, verify that the dataflow is correct.
TF_ASSERT_OK(analysis.VerifyAgainstReference());
}
INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation,
HloDataflowAnalysisTest,
::testing::Values(false, true));

View File

@ -561,13 +561,21 @@ tooltip = " ";
}
string comp_body = DumpComputation(subcomp);
string computation =
Printf(computation_fmt, id, style, subcomp_label, comp_body, id);
// Add an edge from the subcomputation to its parent node. If subcomp
// belongs to a fusion node, it's drawn in place of the fusion instruction, so
// there's no need to link those.
if (parent_instr->opcode() != HloOpcode::kFusion) {
if (parent_instr->opcode() == HloOpcode::kFusion) {
// Dump any nested fusion nodes.
for (const auto& subcomp_instr : subcomp->instructions()) {
if (subcomp_instr->opcode() == HloOpcode::kFusion) {
StrAppend(
&comp_body,
DumpSubcomputation(subcomp_instr->fused_instructions_computation(),
subcomp_instr.get()));
}
}
} else {
// Add an edge from the subcomputation to its parent node. If subcomp
// belongs to a fusion node, it's drawn in place of the fusion instruction,
// so there's no need to link those.
edge_ids_.insert(
{{subcomp->root_instruction(), parent_instr}, next_edge_id_++});
const char* edge_fmt =
@ -578,6 +586,9 @@ tooltip = " ";
subcomp->name(), parent_instr->name()));
}
string computation =
Printf(computation_fmt, id, style, subcomp_label, comp_body, id);
return computation;
}

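The change above makes the dumper recurse into fusion nodes so a nested fused computation is drawn inside its parent's body rather than linked by an edge. A minimal sketch of that recursion pattern, with a made-up Node type standing in for HLO instructions and computations:

#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<Node> fused;  // Non-empty only for fusion nodes.
};

// Emit a node and, recursively, any fused computations nested inside it.
std::string Dump(const Node& node, int depth = 0) {
  std::string body(depth * 2, ' ');
  body += node.name + "\n";
  for (const Node& inner : node.fused) {
    body += Dump(inner, depth + 1);  // Nested fusion drawn inside its parent.
  }
  return body;
}

int main() {
  Node root{"fusion", {{"add", {}}, {"fusion.1", {{"multiply", {}}}}}};
  std::cout << Dump(root);
}
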
View File

@ -793,13 +793,6 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
}
}
for (HloComputation* computation :
instruction_to_fuse->called_computations()) {
if (std::find(called_computations_.begin(), called_computations_.end(),
computation) == called_computations_.end()) {
called_computations_.push_back(computation);
}
}
VLOG(2) << "New clone:\n" << clone->ToString();
return clone;
}
@ -1669,6 +1662,21 @@ string HloInstruction::ExtendedOpcodeStr() const {
string HloInstruction::ToString(bool compact_operands,
bool include_metadata) const {
string result =
StrCat(name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ",
ExtendedOpcodeStr(), "(", OperandsToString(compact_operands), ")");
for (const string& extra : ExtraAttributesToString()) {
StrAppend(&result, ", ", extra);
}
if (include_metadata &&
(!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
!metadata_.source_file().empty())) {
StrAppend(&result, " # metadata=", metadata_.ShortDebugString());
}
return result;
}
string HloInstruction::OperandsToString(bool compact) const {
string operands;
if (opcode() == HloOpcode::kConstant) {
// For constants, show the actual value in place of an empty operand list.
@ -1697,12 +1705,12 @@ string HloInstruction::ToString(bool compact_operands,
} else {
tensorflow::gtl::ArraySlice<HloInstruction*> slice(operands_);
const int64 kMaxOperandsToShowIfCompact = 4;
if (compact_operands && slice.size() > kMaxOperandsToShowIfCompact) {
if (compact && slice.size() > kMaxOperandsToShowIfCompact) {
slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
}
operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) {
*out += ShapeUtil::HumanStringWithLayout(operand->shape());
if (!compact_operands) {
if (!compact) {
StrAppend(out, " ", operand->name());
}
});
@ -1711,15 +1719,19 @@ string HloInstruction::ToString(bool compact_operands,
StrAppend(&operands, ", ...(+", remaining, ")");
}
}
string extra;
return operands;
}
std::vector<string> HloInstruction::ExtraAttributesToString() const {
std::vector<string> extra;
if (CanHaveDimensionsField()) {
StrAppend(&extra, ", dimensions={", Join(dimensions(), ","), "}");
extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}"));
}
if (window_ != nullptr) {
StrAppend(&extra, ", ", window_util::ToString(*window_));
extra.push_back(window_util::ToString(*window_));
}
if (padding_config_ != nullptr) {
StrAppend(&extra, ", padding=", padding_config_->ShortDebugString());
extra.push_back(StrCat("padding=", padding_config_->ShortDebugString()));
}
if (!slice_starts_.empty() && !slice_limits_.empty()) {
std::vector<string> bounds;
@ -1728,45 +1740,38 @@ string HloInstruction::ToString(bool compact_operands,
bounds.push_back(
StrCat("[", slice_starts_[i], ":", slice_limits_[i], "]"));
}
StrAppend(&extra, ", slice={", Join(bounds, ", "), "}");
extra.push_back(StrCat("slice={", Join(bounds, ", "), "}"));
}
if (convolution_dimension_numbers_ != nullptr) {
StrAppend(&extra, ", ", ConvolutionDimensionNumbersToString());
extra.push_back(ConvolutionDimensionNumbersToString());
}
if (opcode() == HloOpcode::kWhile) {
StrAppend(&extra, ", condition=", while_condition()->name());
StrAppend(&extra, ", body=", while_body()->name());
extra.push_back(StrCat("condition=", while_condition()->name()));
extra.push_back(StrCat("body=", while_body()->name()));
} else if (opcode() == HloOpcode::kSelectAndScatter) {
StrAppend(&extra, ", select=", select()->name());
StrAppend(&extra, ", scatter=", scatter()->name());
extra.push_back(StrCat("select=", select()->name()));
extra.push_back(StrCat("scatter=", scatter()->name()));
} else if (!called_computations().empty()) {
StrAppend(&extra, ", calls=",
Join(called_computations(), ", ",
[](string* out, const HloComputation* computation) {
StrAppend(out, computation->name());
}));
extra.push_back(StrCat(
"calls=", Join(called_computations(), ", ",
[](string* out, const HloComputation* computation) {
StrAppend(out, computation->name());
})));
}
if (opcode() == HloOpcode::kGetTupleElement) {
StrAppend(&extra, ", index=", tuple_index());
extra.push_back(StrCat("index=", tuple_index()));
}
if (!control_successors_.empty()) {
StrAppend(
&extra, ", control-successors=",
extra.push_back(StrCat(
"control-successors=",
Join(control_successors_, ", ", [](string* out, HloInstruction* succ) {
StrAppend(out, succ->name());
}));
})));
}
if (include_metadata &&
(!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
!metadata_.source_file().empty())) {
StrAppend(&extra, " # metadata=", metadata_.ShortDebugString());
}
return StrCat(name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ",
ExtendedOpcodeStr(), "(", operands, ")", extra);
return extra;
}
string HloInstruction::ToShortString() const {

View File

@ -548,6 +548,14 @@ class HloInstruction {
string ToString(bool compact_operands = false,
bool include_metadata = true) const;
// Components of the ToString() representation:
// Returns a string representation of the operand list.
string OperandsToString(bool compact) const;
// Returns string representation of op-specific attributes.
std::vector<string> ExtraAttributesToString() const;
string ToStringNoMetadata() const { return ToString(false, false); }
// As ToString, but returns a shorter string.
@ -797,8 +805,7 @@ class HloInstruction {
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> operands);
// Returns the computations this instruction calls (if any). This includes
// computations called by fused instructions inside of a fusion instruction.
// Returns the computations this instruction directly calls (if any).
const std::vector<HloComputation*>& called_computations() const {
return called_computations_;
}

View File

@ -758,16 +758,13 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
auto* fusion = computation->CreateFusionInstruction(
{map_3_y}, HloInstruction::FusionKind::kLoop);
auto* fused_computation = fusion->fused_instructions_computation();
EXPECT_THAT(fusion->called_computations(),
ElementsAre(fused_computation, computation_y));
EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));
fusion->FuseInstruction(map_2_x);
EXPECT_THAT(fusion->called_computations(),
ElementsAre(fused_computation, computation_y, computation_x));
EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));
fusion->FuseInstruction(map_1_x);
EXPECT_THAT(fusion->called_computations(),
ElementsAre(fused_computation, computation_y, computation_x));
EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));
}
TEST_F(HloInstructionTest, ComplexFusionOp) {

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <string>
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
@ -218,6 +219,94 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) {
EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param));
}
TEST_F(HloOrderingTest, ValuesInWhileComputations) {
// Tests the ordering of values (defined by dataflow analysis) in the body and
// condition of a while instruction. HLO code:
//
  // body(F32[] %param):
// %negate = Negate(%param)
//
// condition(F32[] %param):
// %convert = Convert<PRED>(%param)
//
// entry:
// %constant = Constant(1.0)
// %while = While(%constant, body, condition)
// %add = Add(%constant, %while)
//
auto module = CreateNewModule();
const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
auto body_builder = HloComputation::Builder("body");
auto body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape, "body_param"));
auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
scalar_shape, HloOpcode::kNegate, body_param));
HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
auto cond_builder = HloComputation::Builder("condition");
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape, "cond_param"));
auto convert = cond_builder.AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::MakeShape(xla::PRED, {}), cond_param));
HloComputation* condition =
module->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
auto xla_while = builder.AddInstruction(
HloInstruction::CreateWhile(scalar_shape, condition, body, constant));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape, HloOpcode::kAdd, constant, xla_while));
module->AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(
auto dataflow, HloDataflowAnalysis::Run(module.get(), /*ssa_form=*/true));
DependencyHloOrdering ordering(module.get());
  // The init value is defined before the while, but its live range is not
  // strictly before the while because the init value is also used by the add.
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_FALSE(
ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(constant),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(constant),
dataflow->GetValueDefinedAt(xla_while)));
// Any value defined in the body or condition is defined before the while, and
// has a live range strictly before the while.
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(negate),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_TRUE(
ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(negate),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(negate),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(convert),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_TRUE(
ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(convert),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(convert),
dataflow->GetValueDefinedAt(xla_while)));
// The live range of the while should be before the add.
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(xla_while),
dataflow->GetValueDefinedAt(add)));
ASSERT_EQ(dataflow->GetValueDefinedAt(xla_while).uses().size(), 1);
const HloUse& while_use = dataflow->GetValueDefinedAt(xla_while).uses()[0];
EXPECT_EQ(while_use.instruction, add);
EXPECT_TRUE(ordering.UseIsBeforeValueDefinition(
while_use, dataflow->GetValueDefinedAt(add)));
EXPECT_TRUE(
ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(xla_while),
dataflow->GetValueDefinedAt(add)));
}
} // namespace
} // namespace xla

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
namespace xla {
@ -54,11 +55,18 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
<< tensorflow::str_util::Join(disabled_passes, ", ");
}
auto run_invariant_checkers = [this, module]() -> Status {
auto run_invariant_checkers = [this,
module](const string& message) -> Status {
for (auto& invariant_checker : invariant_checkers_) {
VLOG(1) << " Invariant checker " << invariant_checker->name();
TF_ASSIGN_OR_RETURN(bool changed, invariant_checker->Run(module));
TF_RET_CHECK(!changed) << "invariant checkers must not change the graph";
StatusOr<bool> changed_status = invariant_checker->Run(module);
if (!changed_status.ok()) {
return Status(changed_status.status().code(),
StrCat(changed_status.status().error_message(),
"\n\nFailed ", message));
}
TF_RET_CHECK(!changed_status.ValueOrDie())
<< "invariant checkers must not change the graph";
}
return Status::OK();
};
@ -66,6 +74,8 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
string prefix = name().ToString() + ": pipeline start";
bool changed = false;
string message;
TF_RETURN_IF_ERROR(
run_invariant_checkers(StrCat("before running pipeline: ", name())));
for (auto& pass : passes_) {
if (disabled_passes.count(pass->name().ToString()) > 0) {
VLOG(1) << " Skipping HLO pass " << pass->name()
@ -80,14 +90,14 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
StrAppend(&message, prefix, ", before ", pass->name());
DumpModule(*module, message);
TF_RETURN_IF_ERROR(run_invariant_checkers());
TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module));
TF_RETURN_IF_ERROR(
run_invariant_checkers(StrCat("after running pass: ", pass->name())));
changed |= changed_this_pass;
prefix.clear();
StrAppend(&prefix, name(), ": after ", pass->name());
}
TF_RETURN_IF_ERROR(run_invariant_checkers());
DumpModule(*module, prefix + ", pipeline end");
return changed;
}

View File

@ -1202,7 +1202,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
StatusOr<bool> HloRematerialization::Run(
HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence,
int64 memory_limit_bytes) {
int64 memory_limit_bytes, RematerializationSizes* sizes) {
// The sequence is constructed entirely by this method.
TF_RET_CHECK(sequence->empty());
@ -1248,7 +1248,8 @@ StatusOr<bool> HloRematerialization::Run(
sequence->at(node.computation())));
}
return Status::OK();
}));
},
/*visit_unreachable_nodes=*/false));
// The peak memory usage of the module equals the peak memory use of the entry
// computation plus the output size of the computation. This is because the
@ -1318,13 +1319,20 @@ StatusOr<bool> HloRematerialization::Run(
<< HumanReadableNumBytes(reduced_peak_memory) << " ("
<< reduced_peak_memory << " bytes)";
if (sizes != nullptr) {
sizes->before_bytes = before_peak_memory;
sizes->after_bytes = current_peak_memory;
}
XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString());
if (current_peak_memory > memory_limit_bytes) {
LOG(WARNING) << "Can't reduce memory use below "
<< HumanReadableNumBytes(memory_limit_bytes)
<< " by rematerialization (only reduced to "
<< HumanReadableNumBytes(current_peak_memory) << ")";
LOG(WARNING) << tensorflow::strings::Printf(
"Can't reduce memory use below %s (%lld bytes) by rematerialization; "
"only reduced to %s (%lld bytes)",
HumanReadableNumBytes(memory_limit_bytes).c_str(), memory_limit_bytes,
HumanReadableNumBytes(current_peak_memory).c_str(),
current_peak_memory);
}
return changed;
@ -1333,9 +1341,10 @@ StatusOr<bool> HloRematerialization::Run(
/* static */ StatusOr<bool> HloRematerialization::RematerializeAndSchedule(
const HloRematerialization::ShapeSizeFunction& size_function,
int64 memory_limit_bytes, HloModule* hlo_module,
SequentialHloOrdering::HloModuleSequence* sequence) {
SequentialHloOrdering::HloModuleSequence* sequence,
RematerializationSizes* sizes) {
HloRematerialization remat(size_function);
return remat.Run(hlo_module, sequence, memory_limit_bytes);
return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes);
}
} // namespace xla

View File

@ -28,6 +28,13 @@ class HloRematerialization {
public:
using ShapeSizeFunction = std::function<int64(const Shape&)>;
// Helper struct that communicates the before / after sizes for the
// rematerialization process.
struct RematerializationSizes {
int64 before_bytes;
int64 after_bytes;
};
// Rematerialize HLO instructions in the given module to reduce peak memory
// use below memory_limit_bytes where memory use is defined as the total size
// of all live HLO instruction values. Parameters and constants are included
@ -46,6 +53,9 @@ class HloRematerialization {
// rematerialization. This is the order in which HLO instructions should
// be emitted to minimize memory use.
//
// sizes: Optional outparam that indicates the peak memory usage of the HLO
// module before/after rematerialization.
//
// Returns whether any instructions were rematerialized. If memory use is
// already below the given limit then no instructions are rematerialized and
// false is returned.
@ -55,8 +65,8 @@ class HloRematerialization {
// code generation.
static StatusOr<bool> RematerializeAndSchedule(
const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
HloModule* hlo_module,
SequentialHloOrdering::HloModuleSequence* sequence);
HloModule* hlo_module, SequentialHloOrdering::HloModuleSequence* sequence,
RematerializationSizes* sizes = nullptr);
protected:
HloRematerialization(const ShapeSizeFunction& size_function)
@ -69,7 +79,7 @@ class HloRematerialization {
// contains the memory-minimizing order in which to emit the HLO instructions.
StatusOr<bool> Run(HloModule* module,
SequentialHloOrdering::HloModuleSequence* sequence,
int64 memory_limit);
int64 memory_limit, RematerializationSizes* sizes);
// Rematerializes instructions within the given computation. 'order' is the
// order in which the computation's instructions will be emitted in the

View File

@ -159,12 +159,6 @@ void HloValue::AddPosition(HloInstruction* instruction,
for (const HloPosition& position : positions_) {
DCHECK_NE(position, new_position);
}
// The shape of the new position must match existing positions.
if (!positions_.empty()) {
CHECK(
ShapeUtil::Compatible(positions_.front().shape(), new_position.shape()))
<< "front: " << positions_.front() << " new: " << new_position;
}
positions_.push_back(std::move(new_position));

View File

@ -225,6 +225,9 @@ class HloValueSet {
// already exist in the set.
bool AddValue(const HloValue* value);
// Clear all values from the set.
void Clear() { values_.clear(); }
// Return the unique HLO value in the set. CHECKs if the set does not contain
// exactly one value.
const HloValue& GetUniqueValue() const {

View File

@ -32,13 +32,11 @@ class ShapeVerifier : public DfsHloVisitor {
const std::function<int64(const Shape&)>& shape_size_fn)
: shape_size_fn_(shape_size_fn) {}
Status HandleElementwiseUnary(HloInstruction* hlo,
HloOpcode opcode) override {
Status HandleElementwiseUnary(HloInstruction* hlo) override {
return CheckUnaryShape(hlo);
}
Status HandleElementwiseBinary(HloInstruction* hlo,
HloOpcode opcode) override {
Status HandleElementwiseBinary(HloInstruction* hlo) override {
return CheckBinaryShape(hlo);
}
@ -282,6 +280,14 @@ class ShapeVerifier : public DfsHloVisitor {
const std::function<int64(const Shape&)> shape_size_fn_;
};
string ComputationsToString(
tensorflow::gtl::ArraySlice<HloComputation*> computations) {
return tensorflow::str_util::Join(
computations, ",", [](string* s, const HloComputation* computation) {
s->append(computation->name());
});
}
} // namespace
StatusOr<bool> HloVerifier::Run(HloModule* module) {
@ -292,6 +298,17 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
for (const auto& instruction : computation->instructions()) {
TF_RET_CHECK(instruction->parent() == computation.get());
if (instruction->opcode() == HloOpcode::kFusion) {
TF_RET_CHECK(
ContainersEqual(instruction->called_computations(),
{instruction->fused_instructions_computation()}))
<< "Fusion HLO calls computations other than the "
"fused_instructions_computation: "
<< instruction->ToString()
<< " instruction->fused_instructions_computation(): "
<< instruction->fused_instructions_computation()->ToString()
<< " instruction->called_computations(): "
<< ComputationsToString(instruction->called_computations());
for (const auto& fused : instruction->fused_instructions()) {
TF_RET_CHECK(fused->parent() ==
instruction->fused_instructions_computation())

View File

@ -122,7 +122,9 @@ StatusOr<bool> ReducePrecisionInsertion::insert_on_inputs(
continue;
}
if (instruction->opcode() == HloOpcode::kFusion) {
if (instruction->opcode() == HloOpcode::kFusion &&
(instruction->fusion_kind() == HloInstruction::FusionKind::kLoop ||
instruction->fusion_kind() == HloInstruction::FusionKind::kInput)) {
// Insert the reduce-precision operation inside the fusion computation,
// after the corresponding parameter instruction.
TF_ASSIGN_OR_RETURN(
@ -171,7 +173,9 @@ StatusOr<bool> ReducePrecisionInsertion::insert_on_outputs(
continue;
}
if (instruction->opcode() == HloOpcode::kFusion) {
if (instruction->opcode() == HloOpcode::kFusion &&
(instruction->fusion_kind() == HloInstruction::FusionKind::kLoop ||
instruction->fusion_kind() == HloInstruction::FusionKind::kOutput)) {
// Insert the reduce-precision operation as the last operation inside
// the fusion computation.
HloInstruction* fusion_root = instruction->fused_expression_root();

View File

@ -215,6 +215,7 @@ cc_library(
],
deps = [
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:test",

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <cstdlib>
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <string>
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {

View File

@ -111,6 +111,11 @@ cc_binary(
deps = [
":replay_computation_library",
"//tensorflow/compiler/plugin/executor:plugin_lib",
        # TODO: This dependency is a workaround for a linking error with clang.
        # Without it, the linker complains about missing symbols from
        # 'xla_device_launch_op'. This dependency should be propagated from
        # plugin_lib instead, but no target other than this one breaks without it.
"//tensorflow/compiler/jit",
],
)

View File

@ -144,7 +144,7 @@ int RealMain(tensorflow::gtl::ArraySlice<char*> args,
int main(int argc, char** argv) {
// Flags
string fake_infeed_shape;
xla::string fake_infeed_shape;
bool use_fake_data = false;
const std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("use_fake_data", &use_fake_data,

View File

@ -28,6 +28,7 @@ py_library(
"//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/fused_conv:fused_conv_py",
"//tensorflow/contrib/gan",
"//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
"//tensorflow/contrib/hooks",
@ -72,6 +73,7 @@ py_library(
"//tensorflow/contrib/staging",
"//tensorflow/contrib/stat_summarizer:stat_summarizer_py",
"//tensorflow/contrib/stateless",
"//tensorflow/contrib/summary:summary_ops",
"//tensorflow/contrib/tensor_forest:init_py",
"//tensorflow/contrib/tensorboard",
"//tensorflow/contrib/testing:testing_py",

View File

@ -31,6 +31,7 @@ from tensorflow.contrib import deprecated
from tensorflow.contrib import distributions
from tensorflow.contrib import factorization
from tensorflow.contrib import framework
from tensorflow.contrib import gan
from tensorflow.contrib import graph_editor
from tensorflow.contrib import grid_rnn
from tensorflow.contrib import image

View File

@ -18,6 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch
from tensorflow.contrib.decision_trees.proto import generic_tree_model_extensions_pb2
@ -26,18 +29,21 @@ from tensorflow.contrib.learn.python.learn import export_strategy
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader as saved_model_loader
from tensorflow.python.saved_model import tag_constants
def make_custom_export_strategy(name, convert_fn, feature_columns,
def make_custom_export_strategy(name,
convert_fn,
feature_columns,
export_input_fn):
"""Makes custom exporter of GTFlow tree format.
Args:
name: A string, for the name of the export strategy.
convert_fn: A function that converts the tree proto to desired format and
saves it to the desired location.
saves it to the desired location. Can be None to skip conversion.
feature_columns: A list of feature columns.
export_input_fn: A function that takes no arguments and returns an
`InputFnOps`.
@ -68,9 +74,22 @@ def make_custom_export_strategy(name, convert_fn, feature_columns,
dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
dtec.ParseFromString(dfec_str)
# Export the result in the same folder as the saved model.
convert_fn(dtec, sorted_feature_names, len(dense_floats),
len(sparse_float_indices), len(sparse_int_indices),
result_dir, eval_result)
if convert_fn:
convert_fn(dtec, sorted_feature_names,
len(dense_floats),
len(sparse_float_indices),
len(sparse_int_indices), result_dir, eval_result)
feature_importances = _get_feature_importances(
dtec, sorted_feature_names,
len(dense_floats),
len(sparse_float_indices), len(sparse_int_indices))
sorted_by_importance = sorted(
feature_importances.items(), key=lambda x: -x[1])
assets_dir = os.path.join(result_dir, "assets.extra")
gfile.MakeDirs(assets_dir)
with gfile.GFile(os.path.join(assets_dir, "feature_importances"),
"w") as f:
f.write("\n".join("%s, %f" % (k, v) for k, v in sorted_by_importance))
return result_dir
return export_strategy.ExportStrategy(name, export_fn)
@ -157,3 +176,41 @@ def convert_to_universal_format(dtec, sorted_feature_names,
node.left_child_id.value = split.left_id
node.right_child_id.value = split.right_id
return model_and_features
def _get_feature_importances(dtec, feature_names, num_dense_floats,
num_sparse_float, num_sparse_int):
"""Export the feature importance per feature column."""
del num_sparse_int # Unused.
sums = collections.defaultdict(lambda: 0)
for tree_idx in range(len(dtec.trees)):
tree = dtec.trees[tree_idx]
for tree_node in tree.nodes:
node_type = tree_node.WhichOneof("node")
if node_type == "dense_float_binary_split":
split = tree_node.dense_float_binary_split
split_column = feature_names[split.feature_column]
elif node_type == "sparse_float_binary_split_default_left":
split = tree_node.sparse_float_binary_split_default_left.split
split_column = feature_names[split.feature_column + num_dense_floats]
elif node_type == "sparse_float_binary_split_default_right":
split = tree_node.sparse_float_binary_split_default_right.split
split_column = feature_names[split.feature_column + num_dense_floats]
elif node_type == "categorical_id_binary_split":
split = tree_node.categorical_id_binary_split
split_column = feature_names[split.feature_column + num_dense_floats +
num_sparse_float]
elif node_type == "categorical_id_set_membership_binary_split":
split = tree_node.categorical_id_set_membership_binary_split
split_column = feature_names[split.feature_column + num_dense_floats +
num_sparse_float]
elif node_type == "leaf":
assert tree_node.node_metadata.gain == 0
continue
else:
      raise ValueError("Unexpected split type %s" % node_type)
# Apply shrinkage factor. It is important since it is not always uniform
# across different trees.
sums[split_column] += (
tree_node.node_metadata.gain * dtec.tree_weights[tree_idx])
return dict(sums)
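# A small, self-contained sketch (illustrative only, not part of this change)
# of the accumulation performed above: per-split gains are weighted by the
# owning tree's shrinkage weight before being summed per feature column. The
# trees/weights below are made-up stand-ins for dtec.trees / dtec.tree_weights.
import collections
example_trees = [
    (0.5, [("feature_a", 10.0), ("feature_b", 4.0)]),
    (0.25, [("feature_a", 5.0)]),
]  # (tree_weight, [(feature_column, gain), ...])
example_sums = collections.defaultdict(float)
for tree_weight, splits in example_trees:
  for split_column, gain in splits:
    # Same shrinkage weighting as above: gain scaled by the tree's weight.
    example_sums[split_column] += gain * tree_weight
assert example_sums == {"feature_a": 6.25, "feature_b": 2.0}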

View File

@ -27,7 +27,7 @@ from tensorflow.python.platform import googletest
class ConvertModelTest(test_util.TensorFlowTestCase):
def testConvertModel(self):
def _make_trees(self):
dtec_str = """
trees {
nodes {
@ -108,8 +108,12 @@ class ConvertModelTest(test_util.TensorFlowTestCase):
"""
dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge(dtec_str, dtec)
# The feature columns in the order they were added.
feature_columns = ["feature_b", "feature_a", "feature_d"]
return dtec, feature_columns
def testConvertModel(self):
dtec, feature_columns = self._make_trees()
# The feature columns in the order they were added.
out = custom_export_strategy.convert_to_universal_format(
dtec, feature_columns, 1, 1,
1)
@ -273,6 +277,16 @@ class ConvertModelTest(test_util.TensorFlowTestCase):
}"""
self.assertProtoEquals(expected_tree, out)
def testFeatureImportance(self):
dtec, feature_columns = self._make_trees()
feature_importances = custom_export_strategy._get_feature_importances(
dtec, feature_columns, 1, 1, 1)
self.assertItemsEqual(["feature_b", "feature_a", "feature_d"],
feature_importances.keys())
self.assertAlmostEqual(50.0, feature_importances["feature_b"], places=4)
self.assertAlmostEqual(50.0, feature_importances["feature_a"], places=4)
self.assertAlmostEqual(50.0, feature_importances["feature_d"], places=4)
if __name__ == "__main__":
googletest.main()

View File

@ -61,11 +61,19 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
logits_modifier_function: A modifier function for the logits.
center_bias: Whether a separate tree should be created for first fitting
the bias.
Raises:
ValueError: If learner_config is not valid.
"""
head = head_lib.multi_class_head(
n_classes=n_classes,
weight_column_name=weight_column_name,
enable_centered_bias=False)
if learner_config.num_classes == 0:
learner_config.num_classes = n_classes
elif learner_config.num_classes != n_classes:
raise ValueError("n_classes (%d) doesn't match learner_config (%d)." %
(learner_config.num_classes, n_classes))
super(GradientBoostedDecisionTreeClassifier, self).__init__(
model_fn=model.model_builder,
params={
@ -129,6 +137,10 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
label_dimension=label_dimension,
weight_column_name=weight_column_name,
enable_centered_bias=False)
if label_dimension == 1:
learner_config.num_classes = 2
else:
learner_config.num_classes = label_dimension
super(GradientBoostedDecisionTreeRegressor, self).__init__(
model_fn=model.model_builder,
params={

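A self-contained sketch (illustrative, not part of this change) of the reconciliation the classifier constructor above performs between n_classes and learner_config.num_classes, where 0 means the proto field was left unset:
def _reconcile_num_classes(config_num_classes, n_classes):
  # Mirrors the check above: 0 means "unset", otherwise the values must agree.
  if config_num_classes == 0:
    return n_classes
  if config_num_classes != n_classes:
    raise ValueError("n_classes (%d) doesn't match learner_config (%d)." %
                     (n_classes, config_num_classes))
  return config_num_classes
assert _reconcile_num_classes(0, 3) == 3  # unset config inherits n_classes
assert _reconcile_num_classes(3, 3) == 3  # matching values pass through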
View File

@ -92,6 +92,7 @@ def model_builder(features, labels, mode, params, config):
examples_per_layer=examples_per_layer,
learner_config=learner_config,
feature_columns=feature_columns,
logits_dimension=head.logits_dimension,
features=features)
with ops.name_scope("gbdt", "gbdt_optimizer"):
predictions_dict = gbdt_model.predict(mode)

View File

@ -74,7 +74,7 @@ class TreeEnsembleStampTokenOp : public OpKernel {
decision_tree_ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_ensemble_resource));
mutex_lock l(*decision_tree_ensemble_resource->get_mutex());
tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_ensemble_resource);
Tensor* output_stamp_token_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
@ -95,7 +95,7 @@ class TreeEnsembleSerializeOp : public OpKernel {
decision_tree_ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_ensemble_resource));
mutex_lock l(*decision_tree_ensemble_resource->get_mutex());
tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_ensemble_resource);
Tensor* output_stamp_token_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),

View File

@ -143,7 +143,7 @@ class GradientTreesPredictionOp : public OpKernel {
// Release the reference to the resource once we're done using it.
core::ScopedUnref unref_me(decision_tree_ensemble_resource);
if (use_locking_) {
mutex_lock l(*decision_tree_ensemble_resource->get_mutex());
tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex());
DoCompute(context, decision_tree_ensemble_resource);
} else {
DoCompute(context, decision_tree_ensemble_resource);
@ -334,7 +334,7 @@ class GradientTreesPartitionExamplesOp : public OpKernel {
// Release the reference to the resource once we're done using it.
core::ScopedUnref unref_me(decision_tree_ensemble_resource);
if (use_locking_) {
mutex_lock l(*decision_tree_ensemble_resource->get_mutex());
tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex());
DoCompute(context, decision_tree_ensemble_resource);
} else {
DoCompute(context, decision_tree_ensemble_resource);

View File

@ -656,7 +656,8 @@ class GrowTreeEnsembleOp : public OpKernel {
CHECK(split->split_info.split_node().node_case() != TreeNode::NODE_NOT_SET);
CHECK(tree_config->nodes(node_id).node_case() == TreeNode::kLeaf)
<< "Unexpected node type to split "
<< tree_config->nodes(node_id).node_case();
<< tree_config->nodes(node_id).node_case() << " for node_id " << node_id
<< ". Tree config: " << tree_config->DebugString();
// Add left leaf.
int32 left_id = tree_config->nodes_size();
@ -767,7 +768,7 @@ class TreeEnsembleStatsOp : public OpKernel {
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&decision_tree_ensemble_resource));
core::ScopedUnref unref_me(decision_tree_ensemble_resource);
mutex_lock l(*decision_tree_ensemble_resource->get_mutex());
tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex());
// Get the stamp token.
const Tensor* stamp_token_t;

View File

@ -42,6 +42,7 @@ class BiasFeatureColumnHandlerTest : public ::testing::Test {
example_partitions_({0, 0, 1, 3}) {
// Set L2 regularization.
learner_config_.mutable_regularization()->set_l2(2.0f);
learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
// Create handler.
handler_.reset(new BiasFeatureColumnHandler(kClassId, kSlotId, kBatchSize));

View File

@ -51,7 +51,7 @@ class CategoricalFeatureColumnHandlerTest : public ::testing::Test {
values_(test::AsTensor<int64>({1, 2, 2, 0}, {4})) {
// Set L2 regularization.
learner_config_.mutable_regularization()->set_l2(2.0f);
learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
// Create handler.
handler_.reset(new CategoricalFeatureColumnHandler(
kClassId, kSlotId, kBatchSize, kFeatureColumn, indices_.matrix<int64>(),

View File

@ -51,7 +51,7 @@ class DenseQuantizedFeatureColumnHandlerTest : public ::testing::Test {
dense_quantized_values_(test::AsTensor<int32>({1, 1, 0, 1}, {4})) {
// Set L2 regularization.
learner_config_.mutable_regularization()->set_l2(2.0f);
learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
// Create handler.
handler_.reset(new DenseQuantizedFeatureColumnHandler(
kClassId, kSlotId, kBatchSize, kFeatureColumn,

View File

@ -53,7 +53,7 @@ class SparseQuantizedFeatureColumnHandlerTest : public ::testing::Test {
sparse_quantized_values_(test::AsTensor<int32>({1, 0, 1}, {3})) {
// Set L2 regularization.
learner_config_.mutable_regularization()->set_l2(2.0f);
learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
// Create handler.
handler_.reset(new SparseQuantizedFeatureColumnHandler(
kClassId, kSlotId, kBatchSize, kFeatureColumn,

View File

@ -30,6 +30,7 @@ const double kDelta = 1e-5;
TEST(NodeStatsTest, AlmostZero) {
LearnerConfig learner_config;
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
NodeStats node_stats(learner_config, GradientStats(1e-8f, 1e-8f));
EXPECT_EQ(0, node_stats.weight_contribution[0]);
EXPECT_EQ(0, node_stats.gain);
@ -37,6 +38,7 @@ TEST(NodeStatsTest, AlmostZero) {
TEST(NodeStatsTest, LessThanMinWeightConstraint) {
LearnerConfig learner_config;
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
learner_config.mutable_constraints()->set_min_node_weight(3.2f);
NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
EXPECT_EQ(0, node_stats.weight_contribution[0]);
@ -45,6 +47,7 @@ TEST(NodeStatsTest, LessThanMinWeightConstraint) {
TEST(NodeStatsTest, L1RegSquashed) {
LearnerConfig learner_config;
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
learner_config.mutable_regularization()->set_l1(10.0f);
NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
EXPECT_EQ(0, node_stats.weight_contribution[0]);
@ -53,6 +56,7 @@ TEST(NodeStatsTest, L1RegSquashed) {
TEST(NodeStatsTest, L1RegPos) {
LearnerConfig learner_config;
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
learner_config.mutable_regularization()->set_l1(5.0f);
NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
const float expected_clipped_grad = 7.32f - 5.0f;
@ -66,6 +70,7 @@ TEST(NodeStatsTest, L1RegPos) {
TEST(NodeStatsTest, L1RegNeg) {
LearnerConfig learner_config;
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
learner_config.mutable_regularization()->set_l1(5.0f);
NodeStats node_stats(learner_config, GradientStats(-7.32f, 1.63f));
const float expected_clipped_grad = -7.32f + 5.0f;
@ -79,6 +84,7 @@ TEST(NodeStatsTest, L1RegNeg) {
TEST(NodeStatsTest, L2Reg) {
LearnerConfig learner_config;
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
learner_config.mutable_regularization()->set_l2(8.0f);
NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
const float expected_denom = 1.63f + 8.0f;
@ -91,6 +97,7 @@ TEST(NodeStatsTest, L2Reg) {
TEST(NodeStatsTest, L1L2Reg) {
LearnerConfig learner_config;
learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
learner_config.mutable_regularization()->set_l1(5.0f);
learner_config.mutable_regularization()->set_l2(8.0f);
NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));

View File

@ -15,6 +15,7 @@
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
#include <cstring>
#include <vector>
#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h"
@ -34,10 +35,27 @@ class WeightedQuantilesSummary {
struct SummaryEntry {
SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min,
const WeightType& max)
: value(v), weight(w), min_rank(min), max_rank(max) {}
    SummaryEntry() : value(0), weight(0), min_rank(0), max_rank(0) {}
                 const WeightType& max) {
      // Explicitly initialize all of memory (including padding from memory
      // alignment) to allow the struct to be msan-resistant "plain old data".
      //
      // POD = http://en.cppreference.com/w/cpp/concept/PODType
      memset(this, 0, sizeof(*this));
      value = v;
      weight = w;
      min_rank = min;
      max_rank = max;
    }
SummaryEntry() {
memset(this, 0, sizeof(*this));
value = 0;
weight = 0;
min_rank = 0;
max_rank = 0;
}
bool operator==(const SummaryEntry& other) const {
return value == other.value && weight == other.weight &&

View File

@ -17,7 +17,7 @@ message TreeRegularizationConfig {
// Tree constraints config.
message TreeConstraintsConfig {
// Maximum depth of the trees.
// Maximum depth of the trees. The default value is 6 if not specified.
uint32 max_tree_depth = 1;
// Min hessian weight per node.
@ -86,20 +86,22 @@ message LearningRateDropoutDrivenConfig {
message LearnerConfig {
enum PruningMode {
PRE_PRUNE = 0;
POST_PRUNE = 1;
PRUNING_MODE_UNSPECIFIED = 0;
PRE_PRUNE = 1;
POST_PRUNE = 2;
}
enum GrowingMode {
WHOLE_TREE = 0;
// Layer by layer is only supported by the batch learner.
LAYER_BY_LAYER = 1;
GROWING_MODE_UNSPECIFIED = 0;
WHOLE_TREE = 1;
LAYER_BY_LAYER = 2;
}
enum MultiClassStrategy {
TREE_PER_CLASS = 0;
FULL_HESSIAN = 1;
DIAGONAL_HESSIAN = 2;
MULTI_CLASS_STRATEGY_UNSPECIFIED = 0;
TREE_PER_CLASS = 1;
FULL_HESSIAN = 2;
DIAGONAL_HESSIAN = 3;
}
// Number of classes.
@ -118,16 +120,18 @@ message LearnerConfig {
// Constraints.
TreeConstraintsConfig constraints = 5;
// Pruning.
// Pruning. POST_PRUNE is the default pruning mode.
PruningMode pruning_mode = 8;
// Growing Mode.
// Growing Mode. LAYER_BY_LAYER is the default growing mode.
GrowingMode growing_mode = 9;
// Learning rate.
// Learning rate. By default we use fixed learning rate of 0.1.
LearningRateConfig learning_rate_tuner = 6;
// Multi-class strategy.
// Multi-class strategy. By default we use TREE_PER_CLASS for binary
// classification and linear regression. For other cases, we use
// DIAGONAL_HESSIAN as the default.
MultiClassStrategy multi_class_strategy = 10;
// If you want to average the ensembles (for regularization), provide the

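A minimal sketch (illustrative, not part of this change) of why the new *_UNSPECIFIED = 0 entries matter: with proto3 semantics an unset enum field reads back as 0, so reserving 0 for "unspecified" lets callers tell "never set" apart from a deliberate choice such as PRE_PRUNE (which used to be 0). This assumes the generated learner_pb2 module is importable:
from tensorflow.contrib.boosted_trees.proto import learner_pb2
config = learner_pb2.LearnerConfig()
# An untouched field reads as the 0-valued sentinel, i.e. "not set by the user".
assert config.pruning_mode == learner_pb2.LearnerConfig.PRUNING_MODE_UNSPECIFIED
config.pruning_mode = learner_pb2.LearnerConfig.POST_PRUNE  # an explicit choice
assert config.pruning_mode == learner_pb2.LearnerConfig.POST_PRUNE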
View File

@ -344,6 +344,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
# Prepare learner config.
learner_config = learner_pb2.LearnerConfig()
learner_config.num_classes = 2
learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE
result, result_no_dropout, dropout_info = (
prediction_ops.gradient_trees_prediction(

View File

@ -261,6 +261,7 @@ class GradientBoostedDecisionTreeModel(object):
examples_per_layer,
learner_config,
features,
logits_dimension,
feature_columns=None):
"""Construct a new GradientBoostedDecisionTreeModel function.
@ -273,8 +274,8 @@ class GradientBoostedDecisionTreeModel(object):
a tree layer. It can also be a function that computes the number of
examples based on the depth of the layer that's being built.
learner_config: A learner config.
features: `dict` of `Tensor` objects.
logits_dimension: An int, the dimension of logits.
feature_columns: A list of feature columns.
Raises:
@ -289,11 +290,39 @@ class GradientBoostedDecisionTreeModel(object):
if learner_config.num_classes < 2:
raise ValueError("Number of classes must be >=2")
self._logits_dimension = logits_dimension
self._is_chief = is_chief
self._num_ps_replicas = num_ps_replicas
self._ensemble_handle = ensemble_handle
self._center_bias = center_bias
self._examples_per_layer = examples_per_layer
# Fill in the defaults.
if (learner_config.multi_class_strategy ==
learner_pb2.LearnerConfig.MULTI_CLASS_STRATEGY_UNSPECIFIED):
if logits_dimension == 1:
learner_config.multi_class_strategy = (
learner_pb2.LearnerConfig.TREE_PER_CLASS)
else:
learner_config.multi_class_strategy = (
learner_pb2.LearnerConfig.DIAGONAL_HESSIAN)
if (learner_config.growing_mode ==
learner_pb2.LearnerConfig.GROWING_MODE_UNSPECIFIED):
learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER
if (learner_config.pruning_mode ==
learner_pb2.LearnerConfig.PRUNING_MODE_UNSPECIFIED):
learner_config.pruning_mode = learner_pb2.LearnerConfig.POST_PRUNE
if learner_config.constraints.max_tree_depth == 0:
# Use 6 as the default maximum depth.
learner_config.constraints.max_tree_depth = 6
tuner = learner_config.learning_rate_tuner.WhichOneof("tuner")
if not tuner:
learner_config.learning_rate_tuner.fixed.learning_rate = 0.1
self._learner_config = learner_config
self._feature_columns = feature_columns
self._learner_config_serialized = learner_config.SerializeToString()
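    # A short sketch (illustrative, not part of this change) of the oneof check
    # used above for the learning-rate default: WhichOneof("tuner") returns
    # None until one of the tuner variants has been set, which is how an unset
    # learning rate is detected before falling back to a fixed rate of 0.1.
    #   example_config = learner_pb2.LearnerConfig()
    #   assert example_config.learning_rate_tuner.WhichOneof("tuner") is None
    #   example_config.learning_rate_tuner.fixed.learning_rate = 0.1
    #   assert example_config.learning_rate_tuner.WhichOneof("tuner") == "fixed"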
@ -378,75 +407,81 @@ class GradientBoostedDecisionTreeModel(object):
local_stamp), _refresh_local_ensemble_fn,
lambda: (control_flow_ops.no_op(), ensemble_stamp))
# Once updated, Use the the local model for prediction.
# Once updated, use the local model for prediction.
with ops.control_dependencies([refresh_local_ensemble]):
ensemble_stats = training_ops.tree_ensemble_stats(
local_ensemble_handle, ensemble_stamp)
apply_dropout, seed = _dropout_params(mode, ensemble_stats)
# We don't need dropout info - we can always restore it based on the
# seed.
predictions, predictions_no_dropout, _ = (
prediction_ops.gradient_trees_prediction(
local_ensemble_handle,
seed,
self._dense_floats,
self._sparse_float_indices,
self._sparse_float_values,
self._sparse_float_shapes,
self._sparse_int_indices,
self._sparse_int_values,
self._sparse_int_shapes,
learner_config=self._learner_config_serialized,
apply_dropout=apply_dropout,
apply_averaging=apply_averaging,
use_locking=False,
center_bias=self._center_bias,
reduce_dim=self._reduce_dim))
partition_ids = prediction_ops.gradient_trees_partition_examples(
local_ensemble_handle,
self._dense_floats,
self._sparse_float_indices,
self._sparse_float_values,
self._sparse_float_shapes,
self._sparse_int_indices,
self._sparse_int_values,
self._sparse_int_shapes,
use_locking=False)
apply_dropout, seed = _dropout_params(mode, ensemble_stats)
# Make sure ensemble stats run. This will check that the ensemble has
# the right stamp.
with ops.control_dependencies(ensemble_stats):
predictions, predictions_no_dropout, _ = (
prediction_ops.gradient_trees_prediction(
local_ensemble_handle,
seed,
self._dense_floats,
self._sparse_float_indices,
self._sparse_float_values,
self._sparse_float_shapes,
self._sparse_int_indices,
self._sparse_int_values,
self._sparse_int_shapes,
learner_config=self._learner_config_serialized,
apply_dropout=apply_dropout,
apply_averaging=apply_averaging,
use_locking=True,
center_bias=self._center_bias,
reduce_dim=self._reduce_dim))
partition_ids = prediction_ops.gradient_trees_partition_examples(
local_ensemble_handle,
self._dense_floats,
self._sparse_float_indices,
self._sparse_float_values,
self._sparse_float_shapes,
self._sparse_int_indices,
self._sparse_int_values,
self._sparse_int_shapes,
use_locking=True)
else:
with ops.device(self._ensemble_handle.device):
ensemble_stats = training_ops.tree_ensemble_stats(
self._ensemble_handle, ensemble_stamp)
apply_dropout, seed = _dropout_params(mode, ensemble_stats)
# We don't need dropout info - we can always restore it based on the
# seed.
predictions, predictions_no_dropout, _ = (
prediction_ops.gradient_trees_prediction(
self._ensemble_handle,
seed,
self._dense_floats,
self._sparse_float_indices,
self._sparse_float_values,
self._sparse_float_shapes,
self._sparse_int_indices,
self._sparse_int_values,
self._sparse_int_shapes,
learner_config=self._learner_config_serialized,
apply_dropout=apply_dropout,
apply_averaging=apply_averaging,
use_locking=False,
center_bias=self._center_bias,
reduce_dim=self._reduce_dim))
partition_ids = prediction_ops.gradient_trees_partition_examples(
self._ensemble_handle,
self._dense_floats,
self._sparse_float_indices,
self._sparse_float_values,
self._sparse_float_shapes,
self._sparse_int_indices,
self._sparse_int_values,
self._sparse_int_shapes,
use_locking=False)
apply_dropout, seed = _dropout_params(mode, ensemble_stats)
# Make sure ensemble stats run. This will check that the ensemble has
# the right stamp.
with ops.control_dependencies(ensemble_stats):
predictions, predictions_no_dropout, _ = (
prediction_ops.gradient_trees_prediction(
self._ensemble_handle,
seed,
self._dense_floats,
self._sparse_float_indices,
self._sparse_float_values,
self._sparse_float_shapes,
self._sparse_int_indices,
self._sparse_int_values,
self._sparse_int_shapes,
learner_config=self._learner_config_serialized,
apply_dropout=apply_dropout,
apply_averaging=apply_averaging,
use_locking=True,
center_bias=self._center_bias,
reduce_dim=self._reduce_dim))
partition_ids = prediction_ops.gradient_trees_partition_examples(
self._ensemble_handle,
self._dense_floats,
self._sparse_float_indices,
self._sparse_float_values,
self._sparse_float_shapes,
self._sparse_int_indices,
self._sparse_int_values,
self._sparse_int_shapes,
use_locking=True)
return _make_predictions_dict(ensemble_stamp, predictions,
predictions_no_dropout, partition_ids,

View File

@ -164,7 +164,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
features=features)
logits_dimension=1, features=features)
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
@ -268,7 +268,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=num_examples_fn,
learner_config=learner_config,
features=features)
logits_dimension=1, features=features)
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
@ -371,7 +371,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
features=features)
logits_dimension=1, features=features)
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
@ -442,7 +442,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
features=features)
logits_dimension=1, features=features)
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
@ -505,7 +505,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
features=features)
logits_dimension=1, features=features)
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
@ -588,7 +588,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
features=features)
logits_dimension=1, features=features)
# Create predict op.
mode = model_fn.ModeKeys.EVAL
@ -627,7 +627,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
features=features)
logits_dimension=5, features=features)
predictions = array_ops.constant(
[[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0],
@ -730,7 +730,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
features=features)
logits_dimension=5, features=features)
predictions = array_ops.constant(
[[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0],
@ -833,7 +833,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
features=features)
logits_dimension=5, features=features)
batch_size = 3
predictions = array_ops.constant(

View File

@ -33,6 +33,7 @@ option(tensorflow_BUILD_MORE_PYTHON_TESTS "Build more python unit tests for cont
option(tensorflow_BUILD_SHARED_LIB "Build TensorFlow as a shared library" OFF)
option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if available)" ON)
option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions")
option(tensorflow_ENABLE_SNAPPY_SUPPORT "Enable SNAPPY compression support" ON)
if (NOT WIN32)
# Threads: defines CMAKE_THREAD_LIBS_INIT and adds -pthread compile option
@ -204,6 +205,12 @@ if(tensorflow_ENABLE_JEMALLOC_SUPPORT)
list(APPEND tensorflow_EXTERNAL_DEPENDENCIES jemalloc)
include_directories(${jemalloc_INCLUDE_DIRS})
endif()
if(tensorflow_ENABLE_SNAPPY_SUPPORT)
include(snappy)
list(APPEND tensorflow_EXTERNAL_LIBRARIES ${snappy_STATIC_LIBRARIES})
list(APPEND tensorflow_EXTERNAL_DEPENDENCIES snappy)
include_directories(${snappy_INCLUDE_DIR})
endif()
if(WIN32)
list(APPEND tensorflow_EXTERNAL_LIBRARIES wsock32 ws2_32 shlwapi)
endif()

View File

@ -17,7 +17,7 @@ include (ExternalProject)
set(boringssl_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/boringssl/src/boringssl/include)
#set(boringssl_EXTRA_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/boringssl/src)
set(boringssl_URL https://boringssl.googlesource.com/boringssl)
set(boringssl_TAG 17cf2cb1d226b0ba2401304242df7ddd3b6f1ff2)
set(boringssl_TAG ee7aa02)
set(boringssl_BUILD ${CMAKE_BINARY_DIR}/boringssl/src/boringssl-build)
#set(boringssl_LIBRARIES ${boringssl_BUILD}/obj/so/libboringssl.so)
set(boringssl_STATIC_LIBRARIES

View File

@ -14,8 +14,8 @@
# ==============================================================================
include (ExternalProject)
set(cub_URL http://mirror.bazel.build/github.com/NVlabs/cub/archive/69ceda618313df8e9cac6659d607b08949455d14.tar.gz)
set(cub_HASH SHA256=87e856522c283b8ea887c3b61d7d5b252d2dd74abac4f1d756d776e721223e82)
set(cub_URL http://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.3.zip)
set(cub_HASH SHA256=b7ead9e291d34ffa8074243541c1380d63be63f88de23de8ee548db573b72ebe)
set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub)
set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub)
set(cub_ARCHIVE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/cub_archive)

View File

@ -0,0 +1,50 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
include (ExternalProject)
set(snappy_URL https://github.com/google/snappy.git)
set(snappy_TAG "55924d11095df25ab25c405fadfe93d0a46f82eb")
set(snappy_BUILD ${CMAKE_CURRENT_BINARY_DIR}/snappy/src/snappy)
set(snappy_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/snappy/src/snappy)
if(WIN32)
set(snappy_STATIC_LIBRARIES ${snappy_BUILD}/$(Configuration)/snappy.lib)
else()
set(snappy_STATIC_LIBRARIES ${snappy_BUILD}/libsnappy.a)
endif()
set(snappy_HEADERS
"${snappy_INCLUDE_DIR}/snappy.h"
)
ExternalProject_Add(snappy
PREFIX snappy
GIT_REPOSITORY ${snappy_URL}
GIT_TAG ${snappy_TAG}
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
BUILD_IN_SOURCE 1
INSTALL_COMMAND ""
LOG_DOWNLOAD ON
LOG_CONFIGURE ON
LOG_BUILD ON
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=Release
-DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
-DSNAPPY_BUILD_TESTS:BOOL=OFF
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
)
# actually enables snappy in the source code
add_definitions(-DSNAPPY)

View File

@ -18,6 +18,7 @@
set(tf_c_srcs
"${tensorflow_source_dir}/tensorflow/c/c_api.cc"
"${tensorflow_source_dir}/tensorflow/c/c_api.h"
"${tensorflow_source_dir}/tensorflow/c/c_api_function.cc"
"${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc"
"${tensorflow_source_dir}/tensorflow/c/eager/c_api.h"
"${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc"

View File

@ -315,6 +315,7 @@ add_python_module("tensorflow/contrib/framework/ops")
add_python_module("tensorflow/contrib/framework/python")
add_python_module("tensorflow/contrib/framework/python/framework")
add_python_module("tensorflow/contrib/framework/python/ops")
add_python_module("tensorflow/contrib/gan")
add_python_module("tensorflow/contrib/graph_editor")
add_python_module("tensorflow/contrib/graph_editor/examples")
add_python_module("tensorflow/contrib/graph_editor/tests")

View File

@ -240,6 +240,8 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/python/training/quantize_training_test.py" # Needs quantization ops to be included in windows.
"${tensorflow_source_dir}/tensorflow/python/training/supervisor_test.py" # Flaky I/O error on rename.
"${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" # Needs portpicker.
"${tensorflow_source_dir}/tensorflow/python/training/server_lib_test.py" # Test occasionally deadlocks.
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops
# Broken tensorboard test due to cmake issues.
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py"
@ -291,6 +293,8 @@ if (tensorflow_BUILD_PYTHON_TESTS)
# Failing with TF 1.3 (TODO)
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/estimator_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_test.py"
# Test should only be run manually
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/reduction_ops_test_big.py"
)
endif()
list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude})

View File

@ -716,6 +716,482 @@ _cudnn_rnn_common_doc_string = """
"""
def _check_direction(direction):
if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION):
raise ValueError("Invalid direction: %s, expect %s or %s" %
(direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION))
def _check_rnn_mode(rnn_mode):
if rnn_mode not in (CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH, CUDNN_RNN_RELU):
raise ValueError("Invalid rnn_mode: %s, expect one of (%s, %s, %s, %s)" %
(rnn_mode, CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH,
CUDNN_RNN_RELU))
def _get_seed(seed):
seed, seed2 = random_seed.get_seed(seed)
if seed is None and seed2 is None:
seed, seed2 = 0, 0
return seed, seed2
def _get_num_params(rnn_mode, num_layers, direction):
"""Return num params for given Cudnn config."""
if rnn_mode == CUDNN_LSTM:
num_params_per_layer = 8
elif rnn_mode == CUDNN_GRU:
num_params_per_layer = 6
elif rnn_mode in (CUDNN_RNN_RELU, CUDNN_RNN_TANH):
num_params_per_layer = 2
else:
    raise ValueError("Invalid 'rnn_mode': %s" % rnn_mode)
num_params = num_layers * num_params_per_layer
if direction != CUDNN_RNN_UNIDIRECTION:
num_params *= 2
return num_params
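# A worked example (illustrative, not part of this change) of the counting
# above: 8 parameter regions per layer for LSTM, 6 for GRU, 2 for a vanilla
# RNN, doubled when the model is bidirectional.
#   assert _get_num_params(CUDNN_LSTM, num_layers=2,
#                          direction=CUDNN_RNN_BIDIRECTION) == 32  # 8 * 2 * 2
#   assert _get_num_params(CUDNN_GRU, num_layers=3,
#                          direction=CUDNN_RNN_UNIDIRECTION) == 18  # 6 * 3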
def _cudnn_rnn(inputs,
input_h,
input_c,
params,
is_training,
rnn_mode,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0.,
seed=0,
name=None):
"""Cudnn RNN.
Args:
inputs: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
input_c: the initial hidden state for c. This is only relevant for LSTM.
A Tensor of the same shape as input_h.
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference
rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh').
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
'linear_input' (default) always applies a linear projection of input
onto RNN hidden state. (standard RNN behavior).
'skip_input' is only allowed when input_size == num_units;
'auto_select' implies 'skip_input' when input_size == num_units;
otherwise, it implies 'linear_input'.
    direction: the direction the model operates in. Can be either
      'unidirectional' or 'bidirectional'.
    dropout: whether to enable dropout. When set to 0, dropout is disabled.
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
for behavior.
name: name of the operation.
Returns:
outputs, output_h, output_c
"""
_check_rnn_mode(rnn_mode)
_check_direction(direction)
seed, seed2 = random_seed.get_seed(seed)
outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
input=inputs,
input_h=input_h,
input_c=input_c,
params=params,
is_training=is_training,
rnn_mode=rnn_mode,
input_mode=input_mode,
direction=direction,
dropout=dropout,
seed=seed,
seed2=seed2,
name=name)
return (outputs, output_h, output_c)
def cudnn_lstm(inputs,
input_h,
input_c,
params,
is_training,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0.,
seed=0,
name=None):
"""Cudnn LSTM.
Args:
inputs: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
input_c: the initial hidden state for c. This is only relevant for LSTM.
A Tensor of the same shape as input_h.
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
'linear_input' (default) always applies a linear projection of input
onto RNN hidden state. (standard RNN behavior).
'skip_input' is only allowed when input_size == num_units;
'auto_select' implies 'skip_input' when input_size == num_units;
otherwise, it implies 'linear_input'.
    direction: the direction the model operates in. Can be either
      'unidirectional' or 'bidirectional'.
    dropout: whether to enable dropout. When set to 0, dropout is disabled.
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
for behavior.
name: name of the operation.
Returns:
outputs, output_h, output_c
"""
return _cudnn_rnn(inputs, input_h, input_c, params, is_training, CUDNN_LSTM,
input_mode, direction, dropout, seed, name)
def _cudnn_rnn_no_input_c(inputs,
input_h,
params,
is_training,
rnn_mode,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0.,
seed=0,
name=None):
"""Cudnn RNN w/o input_c.
Args:
inputs: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference
rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh').
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
'linear_input' (default) always applies a linear projection of input
onto RNN hidden state. (standard RNN behavior).
'skip_input' is only allowed when input_size == num_units;
'auto_select' implies 'skip_input' when input_size == num_units;
otherwise, it implies 'linear_input'.
    direction: the direction the model operates in. Can be either
      'unidirectional' or 'bidirectional'.
    dropout: whether to enable dropout. When set to 0, dropout is disabled.
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
for behavior.
name: name of the operation.
Returns:
outputs, output_h
"""
input_c = array_ops.constant([], dtype=input_h.dtype)
outputs, output_h, _ = _cudnn_rnn(inputs, input_h, input_c, params,
is_training, rnn_mode, input_mode,
direction, dropout, seed, name)
return outputs, output_h
def cudnn_gru(inputs,
input_h,
params,
is_training,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0.,
seed=0,
name=None):
"""Cudnn GRU.
Args:
inputs: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference.
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
'linear_input' (default) always applies a linear projection of input
onto RNN hidden state. (standard RNN behavior).
'skip_input' is only allowed when input_size == num_units;
'auto_select' implies 'skip_input' when input_size == num_units;
otherwise, it implies 'linear_input'.
direction: the direction the model operates in. Can be either
'unidirectional' or 'bidirectional'.
dropout: the dropout probability. When set to 0, dropout is disabled.
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
for behavior.
name: name of the operation.
Returns:
outputs, output_h
"""
return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, CUDNN_GRU,
input_mode, direction, dropout, seed, name)
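For contrast, a GRU sketch under the same illustrative shapes: no cell state is passed or returned, and `gru_params` is assumed to be an opaque buffer built for a GRU model.
```
import tensorflow as tf

outputs, output_h = cudnn_gru(
    tf.random_normal([8, 4, 16]),  # [time_len, batch_size, input_size]
    tf.zeros([1, 4, 32]),          # [num_layers, batch_size, num_units]
    gru_params,
    is_training=False)
```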
def cudnn_rnn_relu(inputs,
input_h,
params,
is_training,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0.,
seed=0,
name=None):
"""Cudnn RNN Relu.
Args:
inputs: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference.
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
'linear_input' (default) always applies a linear projection of input
onto RNN hidden state. (standard RNN behavior).
'skip_input' is only allowed when input_size == num_units;
'auto_select' implies 'skip_input' when input_size == num_units;
otherwise, it implies 'linear_input'.
direction: the direction the model operates in. Can be either
'unidirectional' or 'bidirectional'.
dropout: the dropout probability. When set to 0, dropout is disabled.
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
for behavior.
name: name of the operation.
Returns:
outputs, output_h
"""
return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training,
CUDNN_RNN_RELU, input_mode, direction, dropout,
seed, name)
def cudnn_rnn_tanh(inputs,
input_h,
params,
is_training,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0.,
seed=0,
name=None):
"""Cudnn RNN Tanh.
Args:
inputs: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference.
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
'linear_input' (default) always applies a linear projection of input
onto RNN hidden state. (standard RNN behavior).
'skip_input' is only allowed when input_size == num_units;
'auto_select' implies 'skip_input' when input_size == num_units;
otherwise, it implies 'linear_input'.
direction: the direction the model operates in. Can be either
'unidirectional' or 'bidirectional'.
dropout: the dropout probability. When set to 0, dropout is disabled.
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
for behavior.
name: name of the operation.
Returns:
outputs, output_h
"""
return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training,
CUDNN_RNN_TANH, input_mode, direction, dropout,
seed, name)
def cudnn_rnn_params_to_canonical(rnn_mode,
num_layers,
num_units,
input_size,
params,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0,
seed=0,
name=None):
"""Convert cudnn opaque params to canonical.
Args:
rnn_mode: a string that specifies the mode under which this RNN model runs.
Can be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'.
num_layers: the number of layers for the RNN model.
num_units: the number of units within the RNN model.
input_size: the size of the input; it may differ from num_units.
params: opaque cudnn params var.
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
'linear_input' (default) always applies a linear projection of input
onto RNN hidden state. (standard RNN behavior).
'skip_input' is only allowed when input_size == num_units;
'auto_select' implies 'skip_input' when input_size == num_units;
otherwise, it implies 'linear_input'.
direction: the direction the model operates in. Can be either
'unidirectional' or 'bidirectional'.
dropout: the dropout probability. When set to 0, dropout is disabled.
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
for behavior.
name: name of the operation.
Returns:
weights list and bias list
Raises:
ValueError: if rnn_mode or direction is invalid.
"""
_check_rnn_mode(rnn_mode)
_check_direction(direction)
num_params = _get_num_params(rnn_mode, num_layers, direction)
seed, seed2 = random_seed.get_seed(seed)
weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
rnn_mode=rnn_mode,
num_layers=num_layers,
num_units=num_units,
input_size=input_size,
params=params,
input_mode=input_mode,
direction=direction,
dropout=dropout,
seed=seed,
seed2=seed2,
num_params=num_params,
name=name)
return weights, biases
def cudnn_rnn_canonical_to_params(rnn_mode,
num_layers,
num_units,
input_size,
weights,
biases,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0,
seed=0,
name=None):
"""Converts params from the canonical format to a specific format of cuDNN.
Args:
rnn_mode: a string that specifies the mode under which this RNN model runs.
Can be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'.
num_layers: the number of layers for the RNN model.
num_units: the number of units within the RNN model.
input_size: the size of the input; it may differ from num_units.
weights: a Tensor for weight parameters.
biases: a Tensor for bias parameters.
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
'linear_input' (default) always applies a linear projection of input
onto RNN hidden state. (standard RNN behavior).
'skip_input' is only allowed when input_size == num_units;
'auto_select' implies 'skip_input' when input_size == num_units;
otherwise, it implies 'linear_input'.
direction: the direction the model operates in. Can be either
'unidirectional' or 'bidirectional'.
dropout: the dropout probability. When set to 0, dropout is disabled.
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
for behavior.
name: name of the operation.
Returns:
an opaque Cudnn param.
Raises:
ValueError: if rnn_mode or direction is invalid.
"""
_check_rnn_mode(rnn_mode)
_check_direction(direction)
seed, seed2 = random_seed.get_seed(seed)
return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
rnn_mode=rnn_mode,
num_layers=num_layers,
num_units=num_units,
input_size=input_size,
weights=weights,
biases=biases,
input_mode=input_mode,
direction=direction,
dropout=dropout,
seed=seed,
seed2=seed2,
name=name)
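A round-trip sketch for the two conversion helpers, assuming `opaque_params` is an existing buffer for a single-layer, unidirectional LSTM with num_units=32 and input_size=16 (all values are illustrative):
```
weights, biases = cudnn_rnn_params_to_canonical(
    rnn_mode=CUDNN_LSTM, num_layers=1, num_units=32, input_size=16,
    params=opaque_params)
# `weights` and `biases` are lists of per-layer, per-gate tensors that can be
# inspected or saved, then packed back into an opaque buffer.
repacked_params = cudnn_rnn_canonical_to_params(
    rnn_mode=CUDNN_LSTM, num_layers=1, num_units=32, input_size=16,
    weights=weights, biases=biases)
```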
def cudnn_opaque_params_size(rnn_mode,
num_layers,
num_units,
input_size,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dtype=dtypes.float32,
dropout=0,
seed=0,
name=None):
"""Returns opaque params size for specific Cudnn config.
Args:
rnn_mode: a string that specifies the mode under which this RNN model runs.
Can be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'.
num_layers: the number of layers for the RNN model.
num_units: the number of units within the RNN model.
input_size: the size of the input; it may differ from num_units.
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
'linear_input' (default) always applies a linear projection of input
onto RNN hidden state. (standard RNN behavior).
'skip_input' is only allowed when input_size == num_units;
'auto_select' implies 'skip_input' when input_size == num_units;
otherwise, it implies 'linear_input'.
direction: the direction the model operates in. Can be either
'unidirectional' or 'bidirectional'.
dtype: one of tf.float32 or tf.float64.
dropout: the dropout probability. When set to 0, dropout is disabled.
seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
for behavior.
name: name of the operation.
Returns:
an int, the size of the Cudnn opaque params.
Raises:
ValueError: if rnn_mode or direction is invalid.
"""
_check_rnn_mode(rnn_mode)
_check_direction(direction)
seed, seed2 = random_seed.get_seed(seed)
return gen_cudnn_rnn_ops.cudnn_rnn_params_size(
rnn_mode=rnn_mode,
num_layers=num_layers,
num_units=num_units,
input_size=input_size,
T=dtype,
S=dtypes.int32,
dropout=dropout,
seed=seed,
seed2=seed2,
input_mode=input_mode,
direction=direction,
name=name)[0]
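A sketch of allocating the opaque buffer: because the size is only known when the graph runs, the variable is built from a tensor initializer with `validate_shape=False`. The hyper-parameters and variable name are illustrative assumptions.
```
import tensorflow as tf

params_size_t = cudnn_opaque_params_size(
    rnn_mode=CUDNN_LSTM, num_layers=1, num_units=32, input_size=16)
opaque_params = tf.get_variable(
    "lstm_opaque_params",
    initializer=tf.random_uniform([params_size_t], -0.1, 0.1),
    validate_shape=False)
```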
class _CudnnRNN(object):
"""Creates an RNN model using the underlying Cudnn implementation.
@ -761,9 +1237,6 @@ class _CudnnRNN(object):
Raises:
ValueError: if direction is invalid.
"""
if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION):
raise ValueError("Invalid direction: %s, expect %s or %s",
direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION)
self._num_layers = num_layers
self._num_units = num_units
self._input_size = input_size
@ -772,10 +1245,7 @@ class _CudnnRNN(object):
self._direction = direction
self._dtype = dtype
self._dropout = dropout
# get graph and op seed.
self._seed, self._seed2 = random_seed.get_seed(seed)
if self._seed is None and self._seed2 is None:
self._seed, self._seed2 = 0, 0
self._seed = seed
@property
def input_mode(self):
@ -807,18 +1277,16 @@ class _CudnnRNN(object):
Returns:
The calculated parameter buffer size.
"""
return gen_cudnn_rnn_ops.cudnn_rnn_params_size(
return cudnn_opaque_params_size(
rnn_mode=self._rnn_mode,
num_layers=self._num_layers,
num_units=self._num_units,
input_size=self._input_size,
T=self._dtype,
S=dtypes.int32,
dtype=self._dtype,
dropout=self._dropout,
seed=self._seed,
seed2=self._seed2,
rnn_mode=self._rnn_mode,
input_mode=self._input_mode,
direction=self._direction)[0]
direction=self._direction)
def __call__(self, input_data, input_h, input_c, params, is_training=True):
"""Runs the forward step for the RNN model.
@ -837,22 +1305,17 @@ class _CudnnRNN(object):
output_h: the final state for h.
output_c: the final state for c. This is only relevant for LSTM.
"""
if self._rnn_mode != CUDNN_LSTM:
# For model that doesn't take input_c, replace with a dummy tensor.
input_c = array_ops.constant([], dtype=self._dtype)
output, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
input=input_data,
input_h=input_h,
input_c=input_c,
params=params,
rnn_mode=self._rnn_mode,
return _cudnn_rnn(
input_data,
input_h,
input_c,
params,
is_training,
self._rnn_mode,
input_mode=self._input_mode,
direction=self._direction,
dropout=self._dropout,
seed=self._seed,
seed2=self._seed2,
is_training=is_training)
return (output, output_h, output_c)
seed=self._seed)
def params_to_canonical(self, params):
"""Converts params from a specific format of cuDNN to the canonical format.
@ -863,22 +1326,16 @@ class _CudnnRNN(object):
Returns:
The weights list and bias list in the canonical format.
"""
num_params = self._num_layers * self._NUM_PARAMS_PER_LAYER
if self._direction != CUDNN_RNN_UNIDIRECTION:
num_params *= 2
weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
return cudnn_rnn_params_to_canonical(
rnn_mode=self._rnn_mode,
num_layers=self._num_layers,
num_units=self._num_units,
input_size=self._input_size,
params=params,
dropout=self._dropout,
seed=self._seed,
seed2=self._seed2,
num_params=num_params,
rnn_mode=self._rnn_mode,
input_mode=self._input_mode,
direction=self._direction)
return weights, biases
direction=self._direction,
dropout=self._dropout,
seed=self._seed)
def canonical_to_params(self, weights, biases):
"""Converts params from the canonical format to a specific format of cuDNN.
@ -890,18 +1347,17 @@ class _CudnnRNN(object):
Returns:
An opaque Cudnn param.
"""
return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
return cudnn_rnn_canonical_to_params(
rnn_mode=self._rnn_mode,
num_layers=self._num_layers,
num_units=self._num_units,
input_size=self._input_size,
weights=weights,
biases=biases,
dropout=self._dropout,
seed=self._seed,
seed2=self._seed2,
rnn_mode=self._rnn_mode,
input_mode=self._input_mode,
direction=self._direction)
direction=self._direction,
dropout=self._dropout,
seed=self._seed)
class CudnnLSTM(_CudnnRNN):
@ -1036,9 +1492,16 @@ class _CudnnRNNNoInputC(_CudnnRNN):
output: the output sequence.
output_h: the final state for h.
"""
output, output_h, _ = super(_CudnnRNNNoInputC, self).__call__(
input_data, input_h, None, params, is_training=is_training)
return (output, output_h)
return _cudnn_rnn_no_input_c(
input_data,
input_h,
params,
is_training,
self._rnn_mode,
input_mode=self._input_mode,
direction=self._direction,
dropout=self._dropout,
seed=self._seed)
class CudnnGRU(_CudnnRNNNoInputC):

View File

@ -22,6 +22,7 @@
@@read_batch_features
@@rejection_resample
@@group_by_window
"""
from __future__ import absolute_import
@ -31,6 +32,7 @@ from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.contrib.data.python.ops.dataset_ops import Dataset
from tensorflow.contrib.data.python.ops.dataset_ops import FixedLengthRecordDataset
from tensorflow.contrib.data.python.ops.dataset_ops import group_by_window
from tensorflow.contrib.data.python.ops.dataset_ops import Iterator
from tensorflow.contrib.data.python.ops.dataset_ops import read_batch_features
from tensorflow.contrib.data.python.ops.dataset_ops import rejection_resample

View File

@ -37,7 +37,9 @@ class GroupByWindowTest(test.TestCase):
components = np.random.randint(100, size=(200,)).astype(np.int64)
iterator = dataset_ops.Iterator.from_dataset(
dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x)
.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), 4))
.apply(
dataset_ops.group_by_window,
args=(lambda x: x % 2, lambda _, xs: xs.batch(4), 4)))
init_op = iterator.initializer
get_next = iterator.get_next()
@ -61,8 +63,9 @@ class GroupByWindowTest(test.TestCase):
components = np.array(
[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
iterator = dataset_ops.Iterator.from_dataset(
dataset_ops.Dataset.from_tensor_slices(components).repeat(-1)
.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4))
dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
dataset_ops.group_by_window,
args=(lambda x: x % 3, lambda _, xs: xs.batch(4), 4)))
init_op = iterator.initializer
get_next = iterator.get_next()
@ -81,8 +84,9 @@ class GroupByWindowTest(test.TestCase):
def testSmallGroups(self):
components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
iterator = dataset_ops.Iterator.from_dataset(
dataset_ops.Dataset.from_tensor_slices(components)
.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), 4))
dataset_ops.Dataset.from_tensor_slices(components).apply(
dataset_ops.group_by_window,
args=(lambda x: x % 2, lambda _, xs: xs.batch(4), 4)))
init_op = iterator.initializer
get_next = iterator.get_next()
@ -108,8 +112,9 @@ class GroupByWindowTest(test.TestCase):
iterator = dataset_ops.Iterator.from_dataset(
dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: (x, ops.convert_to_tensor([x * x])))
.group_by_window(lambda x, _: x % 2, reduce_func, 32))
.map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply(
dataset_ops.group_by_window,
args=(lambda x, _: x % 2, reduce_func, 32)))
init_op = iterator.initializer
get_next = iterator.get_next()
@ -124,17 +129,20 @@ class GroupByWindowTest(test.TestCase):
def reduce_func(key, window):
# Apply two different kinds of padding to the input: tight
# padding, and quantized (to a multiple of 10) padding.
return dataset_ops.Dataset.zip((window.padded_batch(
4,
padded_shapes=tensor_shape.TensorShape([None])), window.padded_batch(
return dataset_ops.Dataset.zip((
window.padded_batch(
4, padded_shapes=tensor_shape.TensorShape([None])),
window.padded_batch(
4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),))
iterator = dataset_ops.Iterator.from_dataset(
dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
.group_by_window(
lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
reduce_func, 4))
.apply(
dataset_ops.group_by_window,
args=
(lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
reduce_func, 4)))
init_op = iterator.initializer
get_next = iterator.get_next()
@ -151,10 +159,9 @@ class GroupByWindowTest(test.TestCase):
self.assertEqual(len(components), sum(counts))
# NOTE(mrry): These tests are based on the tests in
# bucket_ops_test.py. Currently, different batch sizes for each key
# are not supported, although this would be possible to add to
# `Dataset.group_by_window()`.
# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
# Currently, they use a constant batch size, though should be made to use a
# different batch size per key.
class BucketTest(test.TestCase):
def _dynamicPad(self, bucket, window, window_size):
@ -168,6 +175,7 @@ class BucketTest(test.TestCase):
tensor_shape.TensorShape([3])))))
def testSingleBucket(self):
def _map_fn(v):
return (v, array_ops.fill([v], v),
array_ops.fill([3], string_ops.as_string(v)))
@ -175,9 +183,10 @@ class BucketTest(test.TestCase):
input_dataset = (
dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn))
bucketed_dataset = input_dataset.group_by_window(
lambda x, y, z: 0, lambda k, bucket: self._dynamicPad(k, bucket, 32),
32)
bucketed_dataset = input_dataset.apply(
dataset_ops.group_by_window,
args=(lambda x, y, z: 0,
lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset)
init_op = iterator.initializer
@ -201,6 +210,7 @@ class BucketTest(test.TestCase):
self.assertAllEqual(expected_vec3_str, bucketed_values[2])
def testEvenOddBuckets(self):
def _map_fn(v):
return (v, array_ops.fill([v], v),
array_ops.fill([3], string_ops.as_string(v)))
@ -208,9 +218,10 @@ class BucketTest(test.TestCase):
input_dataset = (
dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn))
bucketed_dataset = input_dataset.group_by_window(
lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
lambda k, bucket: self._dynamicPad(k, bucket, 32), 32)
bucketed_dataset = input_dataset.apply(
dataset_ops.group_by_window,
args=(lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset)
init_op = iterator.initializer
@ -256,25 +267,31 @@ class BucketTest(test.TestCase):
self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
def testEvenOddBucketsFilterOutAllOdd(self):
def _map_fn(v):
return {"x": v,
"y": array_ops.fill([v], v),
"z": array_ops.fill([3], string_ops.as_string(v))}
return {
"x": v,
"y": array_ops.fill([v], v),
"z": array_ops.fill([3], string_ops.as_string(v))
}
def _dynamic_pad_fn(bucket, window, _):
return dataset_ops.Dataset.zip(
(dataset_ops.Dataset.from_tensors(bucket), window.padded_batch(
32, {"x": tensor_shape.TensorShape([]),
"y": tensor_shape.TensorShape([None]),
"z": tensor_shape.TensorShape([3])})))
32, {
"x": tensor_shape.TensorShape([]),
"y": tensor_shape.TensorShape([None]),
"z": tensor_shape.TensorShape([3])
})))
input_dataset = (
dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
.filter(lambda d: math_ops.equal(d["x"] % 2, 0)))
bucketed_dataset = input_dataset.group_by_window(
lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32)
bucketed_dataset = input_dataset.apply(
dataset_ops.group_by_window,
args=(lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32))
iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset)
init_op = iterator.initializer
@ -295,6 +312,40 @@ class BucketTest(test.TestCase):
self.assertAllEqual(
np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
def testDynamicWindowSize(self):
components = np.arange(100).astype(np.int64)
# Key fn: even/odd
# Reduce fn: batches of 5
# Window size fn: even=5, odd=10
def window_size_func(key):
window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64)
return window_sizes[key]
dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
dataset_ops.group_by_window,
args=(lambda x: x % 2, lambda _, xs: xs.batch(20), None,
window_size_func))
iterator = dataset_ops.Iterator.from_dataset(dataset)
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.OutOfRangeError):
batches = 0
while True:
result = sess.run(get_next)
is_even = all(x % 2 == 0 for x in result)
is_odd = all(x % 2 == 1 for x in result)
self.assertTrue(is_even or is_odd)
expected_batch_size = 5 if is_even else 10
self.assertEqual(expected_batch_size, result.shape[0])
batches += 1
self.assertEqual(batches, 15)
if __name__ == "__main__":
test.main()

View File

@ -19,6 +19,7 @@ from __future__ import print_function
import os
import threading
from collections import namedtuple
import numpy as np
@ -481,6 +482,40 @@ class MapDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testMapNamedtuple(self, count=10):
# construct dataset of tuples
labels = dataset_ops.Dataset.range(count)
images = labels.map(lambda l: -l)
dataset_tuple = dataset_ops.Dataset.zip((labels, images))
# convert dataset of tuples to dataset of namedtuples
Example = namedtuple("Example", ["label", "image"])
dataset_namedtuple = dataset_tuple.map(Example)
def preprocess_tuple(label, image):
image = 2 * image
return label, image
def preprocess_namedtuple(example):
return example._replace(image=2 * example.image)
# preprocess both datasets
dataset_tuple = dataset_tuple.map(preprocess_tuple)
dataset_namedtuple = dataset_namedtuple.map(preprocess_namedtuple)
next_tuple = dataset_tuple.make_one_shot_iterator().get_next()
next_namedtuple = dataset_namedtuple.make_one_shot_iterator().get_next()
# make sure both datasets contain the same data
with self.test_session() as sess:
for i in range(count):
tuple_, namedtuple_ = sess.run([next_tuple, next_namedtuple])
self.assertEqual(tuple_, namedtuple_)
self.assertEqual(tuple_, (i, -2 * i))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_namedtuple)
def testUseStepContainerInMap(self):
row = np.arange(6)
iterator = (

View File

@ -1199,28 +1199,9 @@ class Dataset(object):
return DenseToSparseBatchDataset(self, batch_size, row_shape)
def group_by_window(self, key_func, reduce_func, window_size):
"""Performs a windowed "group-by" operation on this dataset.
This method maps each consecutive element in this dataset to a key
using `key_func` and groups the elements by key. It then applies
`reduce_func` to at most `window_size` elements matching the same
key. All except the final window for each key will contain
`window_size` elements; the final window may be smaller.
Args:
key_func: A function mapping a nested structure of tensors
(having shapes and types defined by `self.output_shapes` and
`self.output_types`) to a scalar `tf.int64` tensor.
reduce_func: A function mapping a key and a dataset of up to `batch_size`
consecutive elements matching that key to another dataset.
window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
consecutive elements matching the same key to combine in a single
batch, which will be passed to `reduce_func`.
Returns:
A `Dataset`.
"""
return GroupByWindowDataset(self, key_func, reduce_func, window_size)
"""See group_by_window()."""
return self.apply(
group_by_window, args=(key_func, reduce_func, window_size))
def map(self,
map_func,
@ -1370,6 +1351,43 @@ class Dataset(object):
"""
return FilterDataset(self, predicate)
def apply(self, fn, args=(), kwargs={}): # pylint: disable=dangerous-default-value
"""Apply a function to this dataset.
`apply` enables chaining of custom `Dataset` transformations.
For example:
```
dataset.map(
lambda x: x**2
).apply(
group_by_window, args=(key_func, reduce_func, window_size)
).map(
lambda x: x**3
)
```
Args:
fn: A function that takes a `Dataset`, `args`, and `kwargs`, and
returns a `Dataset`.
args: A `tuple` or `list` of arguments to be passed to `fn`.
kwargs: A `dict` of keyword arguments to be passed to `fn`.
Returns:
The `Dataset` returned by `fn`.
"""
if not (isinstance(args, tuple) or isinstance(args, list)):
raise TypeError("args must be a tuple or list.")
if not isinstance(kwargs, dict):
raise TypeError("kwargs must be a dict.")
dataset = fn(self, *args, **kwargs)
if not isinstance(dataset, Dataset):
raise TypeError("fn must return a Dataset.")
return dataset
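A small sketch of a user-defined transformation chained through `apply()`; the helper `_skip_and_batch` and the constants are illustrative only.
```
def _skip_and_batch(dataset, skip_count, batch_size):
  return dataset.skip(skip_count).batch(batch_size)

dataset = Dataset.range(100).apply(_skip_and_batch, args=(5, 10))
```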
class TensorDataset(Dataset):
"""A `Dataset` with a single element, viz. a nested structure of tensors."""
@ -1903,7 +1921,7 @@ class DenseToSparseBatchDataset(Dataset):
def _should_unpack_args(args):
"""Returns `True` if `args` should be `*args` when passed to a callable."""
return nest.is_sequence(args) and not isinstance(args, dict)
return type(args) is tuple # pylint: disable=unidiomatic-typecheck
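A sketch of what the stricter check changes: only a plain tuple is unpacked into positional arguments, so a namedtuple (which the old `nest.is_sequence` test also accepted) is now passed to `map_func` as a single value, as exercised by `testMapNamedtuple` above.
```
import collections

Example = collections.namedtuple("Example", ["label", "image"])
type((1, 2)) is tuple          # True  -> map_func(*nested_args)
type(Example(1, 2)) is tuple   # False -> map_func(nested_args)
```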
class _ResourceDataset(Dataset):
@ -1927,71 +1945,6 @@ class _ResourceDataset(Dataset):
return self._output_types
class GroupByWindowDataset(Dataset):
"""A `Dataset` that groups its input and performs a windowed reduction."""
def __init__(self, input_dataset, key_func, reduce_func, window_size):
"""See `Dataset.group_by_window()` for details."""
super(GroupByWindowDataset, self).__init__()
self._input_dataset = input_dataset
self._window_size = window_size
@function.Defun(*nest.flatten(input_dataset.output_types))
def tf_key_func(*args):
"""A wrapper for Defun that facilitates shape inference."""
# Pass in shape information from the input_dataset.
for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
arg.set_shape(shape)
nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
if _should_unpack_args(nested_args):
ret = key_func(*nested_args)
else:
ret = key_func(nested_args)
ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
if ret.dtype != dtypes.int64:
raise ValueError("`key_func` must return a single tf.int64 tensor.")
return ret
self._key_func = tf_key_func
self._key_func.add_to_graph(ops.get_default_graph())
@function.Defun(dtypes.int64, dtypes.resource)
def tf_reduce_func(key, window_dataset_resource):
"""A wrapper for Defun that facilitates shape inference."""
key.set_shape([])
window_dataset = _ResourceDataset(window_dataset_resource,
input_dataset.output_types,
input_dataset.output_shapes)
output_dataset = reduce_func(key, window_dataset)
if not isinstance(output_dataset, Dataset):
raise TypeError("`reduce_func` must return a `Dataset` object.")
self._output_types = output_dataset.output_types
self._output_shapes = output_dataset.output_shapes
return output_dataset.make_dataset_resource()
self._reduce_func = tf_reduce_func
self._reduce_func.add_to_graph(ops.get_default_graph())
def make_dataset_resource(self):
return gen_dataset_ops.group_by_window_dataset(
self._input_dataset.make_dataset_resource(),
self._key_func.captured_inputs,
self._reduce_func.captured_inputs,
self._window_size,
key_func=self._key_func,
reduce_func=self._reduce_func,
output_types=nest.flatten(self.output_types),
output_shapes=nest.flatten(self.output_shapes))
@property
def output_shapes(self):
return self._output_shapes
@property
def output_types(self):
return self._output_types
class MapDataset(Dataset):
"""A `Dataset` that maps a function over elements in its input."""
@ -2151,7 +2104,7 @@ class InterleaveDataset(Dataset):
nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
if nest.is_sequence(nested_args):
if _should_unpack_args(nested_args):
dataset = map_func(*nested_args)
else:
dataset = map_func(nested_args)
@ -2460,7 +2413,7 @@ def rejection_resample(dataset,
shapes and types defined by `dataset.output_shapes` and
`dataset.output_types`) to a scalar `tf.int32` tensor. Values should
be in `[0, num_classes)`.
target_dist: A floating point type tensor, shaped `[num_classes].
target_dist: A floating point type tensor, shaped `[num_classes]`.
initial_dist: (Optional.) A floating point type tensor, shaped
`[num_classes]`. If not provided, the true class distribution is
estimated live in a streaming fashion.
@ -2660,3 +2613,149 @@ def _get_file_names(file_pattern, randomize_input):
if not randomize_input:
file_names = sorted(file_names)
return file_names
class GroupByWindowDataset(Dataset):
"""A `Dataset` that groups its input and performs a windowed reduction."""
def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
"""See `group_by_window()` for details."""
super(GroupByWindowDataset, self).__init__()
self._input_dataset = input_dataset
self._make_key_func(key_func, input_dataset)
self._make_reduce_func(reduce_func, input_dataset)
self._make_window_size_func(window_size_func)
def _make_window_size_func(self, window_size_func):
"""Make wrapping Defun for window_size_func."""
@function.Defun(dtypes.int64)
def tf_window_size_func(key):
key.set_shape([])
window_size = ops.convert_to_tensor(
window_size_func(key), dtype=dtypes.int64)
if window_size.dtype != dtypes.int64:
raise ValueError(
"`window_size_func` must return a single tf.int64 tensor.")
return window_size
self._window_size_func = tf_window_size_func
self._window_size_func.add_to_graph(ops.get_default_graph())
def _make_key_func(self, key_func, input_dataset):
"""Make wrapping Defun for key_func."""
@function.Defun(*nest.flatten(input_dataset.output_types))
def tf_key_func(*args):
"""A wrapper for Defun that facilitates shape inference."""
# Pass in shape information from the input_dataset.
for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
arg.set_shape(shape)
nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
if _should_unpack_args(nested_args):
ret = key_func(*nested_args)
else:
ret = key_func(nested_args)
ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
if ret.dtype != dtypes.int64:
raise ValueError("`key_func` must return a single tf.int64 tensor.")
return ret
self._key_func = tf_key_func
self._key_func.add_to_graph(ops.get_default_graph())
def _make_reduce_func(self, reduce_func, input_dataset):
"""Make wrapping Defun for reduce_func."""
@function.Defun(dtypes.int64, dtypes.resource)
def tf_reduce_func(key, window_dataset_resource):
"""A wrapper for Defun that facilitates shape inference."""
key.set_shape([])
window_dataset = _ResourceDataset(window_dataset_resource,
input_dataset.output_types,
input_dataset.output_shapes)
output_dataset = reduce_func(key, window_dataset)
if not isinstance(output_dataset, Dataset):
raise TypeError("`reduce_func` must return a `Dataset` object.")
self._output_types = output_dataset.output_types
self._output_shapes = output_dataset.output_shapes
return output_dataset.make_dataset_resource()
self._reduce_func = tf_reduce_func
self._reduce_func.add_to_graph(ops.get_default_graph())
@property
def output_shapes(self):
return self._output_shapes
@property
def output_types(self):
return self._output_types
def make_dataset_resource(self):
return gen_dataset_ops.group_by_window_dataset(
self._input_dataset.make_dataset_resource(),
self._key_func.captured_inputs,
self._reduce_func.captured_inputs,
self._window_size_func.captured_inputs,
key_func=self._key_func,
reduce_func=self._reduce_func,
window_size_func=self._window_size_func,
output_types=nest.flatten(self.output_types),
output_shapes=nest.flatten(self.output_shapes))
def group_by_window(dataset,
key_func,
reduce_func,
window_size=None,
window_size_func=None):
"""Performs a windowed "group-by" operation on this dataset.
This method maps each consecutive element in this dataset to a key
using `key_func` and groups the elements by key. It then applies
`reduce_func` to at most `window_size_func(key)` elements matching the same
key. All execpt the final window for each key will contain
`window_size_func(key)` elements; the final window may be smaller.
You may provide either a constant `window_size` or a window size determined by
the key through `window_size_func`.
Args:
dataset: A `Dataset`.
key_func: A function mapping a nested structure of tensors
(having shapes and types defined by `self.output_shapes` and
`self.output_types`) to a scalar `tf.int64` tensor.
reduce_func: A function mapping a key and a dataset of up to
`window_size_func(key)` consecutive elements matching that key to another
dataset.
window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
consecutive elements matching the same key to combine in a single
batch, which will be passed to `reduce_func`. Mutually exclusive with
`window_size_func`.
window_size_func: A function mapping a key to a `tf.int64` scalar
`tf.Tensor`, representing the number of consecutive elements matching
the same key to combine in a single batch, which will be passed to
`reduce_func`. Mutually exclusive with `window_size`.
Returns:
A `Dataset`.
Raises:
ValueError: if neither or both of {`window_size`, `window_size_func`} are
passed.
"""
if (window_size is not None and window_size_func or
not (window_size is not None or window_size_func)):
raise ValueError("Must pass either window_size or window_size_func.")
if window_size is not None:
def constant_window_func(unused_key):
return ops.convert_to_tensor(window_size, dtype=dtypes.int64)
window_size_func = constant_window_func
assert window_size_func is not None
return GroupByWindowDataset(dataset, key_func, reduce_func, window_size_func)
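A usage sketch for the transformation: bucket integers by parity and let the window size depend on the key (even keys produce windows of 5, odd keys windows of 10). All constants are illustrative.
```
def _key_fn(x):
  return x % 2

def _reduce_fn(key, window):
  return window.batch(20)

def _window_size_fn(key):
  return 5 + 5 * key  # even key (0) -> 5, odd key (1) -> 10

dataset = Dataset.range(100).apply(
    group_by_window, args=(_key_fn, _reduce_fn, None, _window_size_fn))
```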

View File

@ -341,7 +341,7 @@ cuda_py_test(
cuda_py_test(
name = "sample_stats_test",
size = "small",
size = "medium",
srcs = ["python/kernel_tests/sample_stats_test.py"],
additional_deps = [
":distributions_py",

View File

@ -150,7 +150,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution):
`N - 1` dimensions index into a batch of independent distributions and
the last dimension represents a vector of probabilities for each
class. Only one of `logits` or `probs` should be passed in.
dtype: The type of the event samples (default: int32).
dtype: The type of the event samples (default: float32).
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
performance. When `False` invalid inputs may silently render incorrect
@ -388,7 +388,7 @@ class RelaxedOneHotCategorical(
dimensions index into a batch of independent distributions and the last
dimension represents a vector of probabilities for each class. Only one
of `logits` or `probs` should be passed in.
dtype: The type of the event samples (default: int32).
dtype: The type of the event samples (default: float32).
validate_args: Unused in this distribution.
allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any

View File

@ -17,440 +17,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_checkpoint_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
ops.NotDifferentiable("GenerateVocabRemapping")
ops.NotDifferentiable("LoadAndRemapMatrix")
from tensorflow.python.training import checkpoint_ops
def _load_and_remap_matrix(ckpt_path,
old_tensor_name,
new_row_vocab_offset,
num_rows_to_load,
new_col_vocab_size,
initializer,
old_row_vocab_file=None,
new_row_vocab_file=None,
old_col_vocab_file=None,
new_col_vocab_file=None,
num_row_oov_buckets=0,
num_col_oov_buckets=0,
max_rows_in_memory=-1):
"""Loads a 2-D (matrix) `Tensor` from checkpoint.
Generates 1D-remappings for rows and columns using the
`GenerateVocabRemapping` op, and initializes any anticipated values with the
provided initializer. Then, uses the `LoadAndRemapMatrix` op to create a
matrix that loads existing values from the checkpoint, while filling out
"missing" values with the newly initialized values. See
contrib/framework/ops/checkpoint_ops.cc for more information on the wrapped
functionality (LoadAndRemapMatrix). This wrapper can be used to perform only
row remapping or only col remapping. If only row remapping is desired,
{new,old}_col_vocab_file should be `None`, and vice versa for column
remapping.
NOTE: This only supports div-partitioning the vocabulary on the 1st dimension
(row axis) via `new_row_vocab_offset`.
Args:
ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
from which the old matrix `Tensor` will be loaded.
old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
new_row_vocab_offset: A 0-indexed integer representing what line to
start reading at in the new row vocabulary. Used for partitioned
variables.
num_rows_to_load: Number of rows to load for the new vocabulary (note: to
support variable partitioning and partial loading, this does not need to
be the same as the number of entries in `new_row_vocab_file`).
new_col_vocab_size: Number of columns to load - should be the same as the
number of entries in `new_col_vocab_file`, since we don't support
partitioning along the column axis.
initializer: Callable initializer function that accepts a 1-D tensor as the
arg to specify the shape of the returned tensor. Used to initialize
missing values.
old_row_vocab_file: A scalar `Tensor` of type `string` containing the
path to the old row vocabulary file. Can be None, which represents no
remapping on the row axis.
new_row_vocab_file: A scalar `Tensor` of type `string` containing the path
to the new row vocabulary file. Can be None, which represents no remapping
on the row axis - in which case, `new_row_vocab_offset` and
`num_rows_to_load` work under the assumption that the new row vocab is the
same as the old row vocab.
old_col_vocab_file: A scalar `Tensor` of type `string` containing the
path to the old column vocabulary file. Can be None, which represents no
remapping on the column axis.
new_col_vocab_file: A scalar `Tensor` of type `string` containing the path
to the new column vocabulary file. Can be None, which represents no
remapping on the column axis - in which case, `new_col_vocab_size` works
under the assumption that the new col vocab is the same as the old col
vocab.
num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows
to append. Must be >= 0.
num_col_oov_buckets: `int` specifying the number of out-of-vocabulary
columns to append. Must be >= 0.
max_rows_in_memory: `int` specifying the maximum number of rows to load from
the checkpoint at once. If less than or equal to 0, the entire matrix will
be loaded into memory. Setting this arg trades increased disk reads for
lower memory usage.
Returns:
A Tensor of shape `[num_rows_to_load + num_row_oov_buckets,
new_col_vocab_size + num_col_oov_buckets]`, with values loaded from the
specified tensor in the checkpoint, and any missing or OOV values
initialized with the given `initializer`.
Raises:
ValueError: If `num_row_oov_buckets` or `num_col_oov_buckets` < 0.
ValueError: If either `old_row_vocab_file` or `new_row_vocab_file` is
provided, while the other is not. Same for `old_col_vocab_file` and
`new_col_vocab_file`.
ValueError: If neither row vocabs or col vocabs are provided.
"""
if num_row_oov_buckets < 0:
raise ValueError("num_row_oov_buckets must be >= 0, but received %d" %
num_row_oov_buckets)
if num_col_oov_buckets < 0:
raise ValueError("num_col_oov_buckets must be >= 0, but received %d" %
num_col_oov_buckets)
if bool(old_row_vocab_file) != bool(new_row_vocab_file):
raise ValueError(
"old_row_vocab_file and new_row_vocab_file must both be specified or "
"left unspecified. old_row_vocab_file='{}', new_row_vocab_file='{}'".
format(old_row_vocab_file, new_row_vocab_file))
if bool(old_col_vocab_file) != bool(new_col_vocab_file):
raise ValueError(
"old_col_vocab_file and new_col_vocab_file must both be specified or "
"left unspecified. old_col_vocab_file='{}', new_col_vocab_file='{}'".
format(old_col_vocab_file, new_col_vocab_file))
remap_rows = new_row_vocab_file and old_row_vocab_file
remap_cols = new_col_vocab_file and old_col_vocab_file
if not (remap_rows or remap_cols):
raise ValueError(
"Must provide either row or column vocab files. If no remapping is "
"necessary, consider using `tf.contrib.framework.init_from_checkpoint` "
"instead.")
num_rows_present = num_rows_to_load
if remap_rows:
row_remapping, num_rows_present = (
gen_checkpoint_ops._generate_vocab_remapping( # pylint: disable=protected-access
new_vocab_file=new_row_vocab_file,
old_vocab_file=old_row_vocab_file,
new_vocab_offset=new_row_vocab_offset,
num_new_vocab=num_rows_to_load))
else:
# Even when the rows are not being reordered, we still need to generate a
# remapping to account for initializing partitioned Variables (when
# new_row_vocab_offset is non-zero).
row_remapping = math_ops.range(
new_row_vocab_offset,
new_row_vocab_offset + num_rows_to_load,
dtype=dtypes.int64)
col_remapping = []
num_cols_present = new_col_vocab_size
if remap_cols:
col_remapping, num_cols_present = (
gen_checkpoint_ops._generate_vocab_remapping( # pylint: disable=protected-access
new_vocab_file=new_col_vocab_file,
old_vocab_file=old_col_vocab_file,
new_vocab_offset=0, # Offset is unused for cols (no partitioning).
num_new_vocab=new_col_vocab_size))
init_vals = initializer([
num_rows_to_load * new_col_vocab_size -
num_rows_present * num_cols_present, 1
])
return_tensor = gen_checkpoint_ops._load_and_remap_matrix( # pylint: disable=protected-access
ckpt_path=ckpt_path,
old_tensor_name=old_tensor_name,
row_remapping=row_remapping,
col_remapping=col_remapping,
initializing_values=init_vals,
num_rows=num_rows_to_load,
num_cols=new_col_vocab_size,
max_rows_in_memory=max_rows_in_memory)
# Add OOV row(s) and column(s).
if num_row_oov_buckets > 0:
init_row_oov_val = initializer([num_row_oov_buckets, new_col_vocab_size])
init_row_oov_val = ops.convert_to_tensor(init_row_oov_val)
return_tensor = array_ops.concat([return_tensor, init_row_oov_val], 0)
if num_col_oov_buckets > 0:
# We need to add any row OOV to the new column shape.
init_col_oov_val = initializer(
[num_rows_to_load + num_row_oov_buckets, num_col_oov_buckets])
init_col_oov_val = ops.convert_to_tensor(init_col_oov_val)
return_tensor = array_ops.concat([return_tensor, init_col_oov_val], 1)
return return_tensor
def load_and_remap_matrix_initializer(ckpt_path,
old_tensor_name,
new_row_vocab_size,
new_col_vocab_size,
old_row_vocab_file=None,
new_row_vocab_file=None,
old_col_vocab_file=None,
new_col_vocab_file=None,
num_row_oov_buckets=0,
num_col_oov_buckets=0,
initializer=None,
max_rows_in_memory=-1):
r"""Returns a var initializer for loading and remapping a 2-D (matrix) tensor.
The returned initializer loads a 2-D (matrix) `Tensor` with name
`old_tensor_name` from the checkpoint at `ckpt_path`. It will reorder the
rows/columns according to the specified vocab files and append additional
out-of-vocabulary rows/columns according to the number of OOV buckets.
The format of the file at the `{old,new}_{row,col}_vocab_file` path should be
a text file, with each line containing a single entity within the vocabulary.
Let the function `line_of(f, "x")` return the 0-indexed line number of the
entity "x" in file f, and the function `entity_at(f, i)` return the entity at
line i of file f. Then, row i of the new output matrix will be taken from row
`line_of(old_row_vocab_file, entity_at(new_row_vocab_file, i))` of the old
matrix. If any entity in `new_row_vocab_file` is not found in
`old_row_vocab_file`, that row is considered a "missing" row, and its values
will be initialized using the `initializer` arg. The same logic also applies
for the columns.
For example, assuming that:
* `old_row_vocab_file` contains "mercury\nvenus\nmars"
* `new_row_vocab_file` contains "venus\njupiter\nmercury"
* `old_col_vocab_file` contains "good\nbetter\nbest"
* `new_col_vocab_file` contains "good\nbest\nfantastic"
* `initializer` returns the natural numbers `[1, 2, 3, 4, ...]`
* `w(i, j)` represents the value from row i, column j of the old matrix
Then the new output matrix will look like:
`[[w(1, 0), w(1, 2), 1],
[2, 3, 4],
[w(0, 0), w(0, 2), 5]]`
If we further specify that:
* `num_row_oov_buckets` == 2
* `num_col_oov_buckets` == 1
Then the new output matrix will look like:
`[[w(1, 0), w(1, 2), 1, 12],
[2, 3, 4, 13],
[w(0, 0), w(0, 2), 5, 14],
[6, 7, 8, 15],
[9, 10, 11, 16]]`
If `{old,new}_row_vocab_file` are None, we assume that the old and new row
vocab files are the same, and no row remapping is done. If
`{old,new}_col_vocab_file` are None, we assume that the old and new column
vocab files are the same, and no column remapping is done.
The returned initializer only supports div-partitioning along the row axis. It
does not support partitioning along the column axis or mod-partitioning.
NOTE: When this is used to warm-start variables, client code should use
`tf.lookup.index_table_from_tensor()` like
contrib/layers/python/layers/feature_column.py does, as opposed to
`tf.feature_to_id()` - in order to ensure the underlying lookup tables are the
same.
Args:
ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
from which the old matrix `Tensor` will be loaded.
old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
new_row_vocab_size: `int` specifying the number of entries in
`new_row_vocab_file`. If no row remapping is needed (no row vocab
provided), this should be equal to the number of rows to load from the old
matrix (which can theoretically be smaller than the number of rows in the
old matrix).
new_col_vocab_size: `int` specifying the number of entries in
`new_col_vocab_file`. If no column remapping is needed (no column vocab
provided), this should be equal to the number of columns in the old
matrix.
old_row_vocab_file: A scalar `Tensor` of type `string` containing the
path to the old row vocabulary file. Can be None, which represents no
remapping on the row axis.
new_row_vocab_file: A scalar `Tensor` of type `string` containing the path
to the new row vocabulary file. Can be None, which represents no remapping
on the row axis.
old_col_vocab_file: A scalar `Tensor` of type `string` containing the
path to the old column vocabulary file. Can be None, which represents no
remapping on the column axis.
new_col_vocab_file: A scalar `Tensor` of type `string` containing the path
to the new column vocabulary file. Can be None, which represents no
remapping on the column axis.
num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows
to append. Must be >= 0.
num_col_oov_buckets: `int` specifying the number of out-of-vocabulary
columns to append. Must be >= 0.
initializer: Initializer function to initialize missing values. Accepts a
1-D tensor as the arg to specify the shape of the returned tensor. If
`None`, defaults to using `zeros_initializer()`.
max_rows_in_memory: `int` specifying the maximum number of rows to load from
the checkpoint at once. If less than or equal to 0, the entire matrix will
be loaded into memory. Setting this arg trades increased disk reads for
lower memory usage.
Returns:
A variable initializer function that should be used to initialize a
(potentially partitioned) `Variable` whose complete shape is
`[new_row_vocab_size + num_row_oov_buckets, new_col_vocab_size +
num_col_oov_buckets]`.
Raises:
TypeError: If `initializer` is specified but not callable.
"""
if initializer is None:
# TODO(b/25671353): Consider using sqrt(6/(fan_in + fan_out)) instead, from
# Glorot and Bengio, 2010.
initializer = init_ops.zeros_initializer()
if not callable(initializer):
raise TypeError(
"initializer must be callable, instead of being {} of type {}.".format(
initializer, type(initializer)))
def _initializer(shape, dtype=dtypes.float32, partition_info=None):
"""Variable initializer.
Args:
shape: Shape of `Tensor` to return. Should include OOV on both axes.
dtype: Must be float32.
partition_info: variable_scope._PartitionInfo.
Returns:
`Tensor` of shape `shape`.
Raises:
TypeError: If `dtype` is anything other than float32.
ValueError: For shape mismatch upon invocation.
"""
# Sanity checks.
if dtype != dtypes.float32:
raise TypeError(
"Currently, only float32 is supported. Received dtype: {}".format(
dtype))
if len(shape) != 2:
raise ValueError("Expected 2-dim shape, but received: {}".format(shape))
if shape[0] <= 0:
raise ValueError(
"Expected 1st dim of shape to be > 0, but received shape: {}".format(
shape))
if shape[1] != (new_col_vocab_size + num_col_oov_buckets):
raise ValueError(
"Expected 2nd dim of shape to be new_col_vocab_size ({}) + "
"num_col_oov_buckets ({}) = {}, but received shape: {}".format(
new_col_vocab_size, num_col_oov_buckets,
new_col_vocab_size + num_col_oov_buckets, shape))
offset = 0
if partition_info is not None:
offset = partition_info.single_offset(shape)
if offset + shape[0] > new_row_vocab_size + num_row_oov_buckets:
raise ValueError(
"Trying to initialize {} additional rows after {} rows have already "
"been initialized, which would exceed expected total row count of "
"new_row_vocab_size ({}) + num_row_oov_buckets ({}) = {}.".format(
shape[0], offset, new_row_vocab_size, num_row_oov_buckets,
new_row_vocab_size + num_row_oov_buckets))
row_oov_buckets_to_use = min(shape[0],
max(0, offset + shape[0] - new_row_vocab_size))
num_rows_to_load = shape[0] - row_oov_buckets_to_use
return _load_and_remap_matrix(
ckpt_path=ckpt_path,
old_tensor_name=old_tensor_name,
new_row_vocab_offset=offset,
num_rows_to_load=num_rows_to_load,
new_col_vocab_size=new_col_vocab_size,
initializer=initializer,
old_row_vocab_file=old_row_vocab_file,
new_row_vocab_file=new_row_vocab_file,
old_col_vocab_file=old_col_vocab_file,
new_col_vocab_file=new_col_vocab_file,
num_row_oov_buckets=row_oov_buckets_to_use,
num_col_oov_buckets=num_col_oov_buckets,
max_rows_in_memory=max_rows_in_memory)
return _initializer
def load_embedding_initializer(ckpt_path,
embedding_tensor_name,
new_vocab_size,
embedding_dim,
old_vocab_file,
new_vocab_file,
num_oov_buckets=0,
initializer=None,
max_rows_in_memory=-1):
"""Returns a variable initializer for loading pre-trained embeddings.
Wrapper around `load_and_remap_matrix_initializer()` specialized for loading
embedding weights and remapping according to the provided vocab files. See
docs for `load_and_remap_matrix_initializer()` for more details.
NOTE: Only for use with div-partitioned variables / vocabularies.
Args:
ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
from which the old matrix `Tensor` will be loaded.
embedding_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
new_vocab_size: Number of entries in the new vocab.
embedding_dim: `int` specifying the dimension of the embedding vectors from
the checkpoint. Must match the number of columns in the old embedding
matrix.
old_vocab_file: A scalar `Tensor` of type `string` containing the
path to the old vocabulary file.
new_vocab_file: A scalar `Tensor` of type `string` containing the
path to the new vocabulary file.
num_oov_buckets: `int` specifying the number of out-of-vocabulary
buckets to use. Must be >= 0.
initializer: Initializer function that accepts a 1-D tensor as the arg to
specify the shape of the returned tensor. If `None`, defaults to using
`truncated_normal_initializer()`.
max_rows_in_memory: `int` specifying the maximum number of rows to load from
the checkpoint at once. If less than or equal to 0, the entire matrix will
be loaded into memory. Setting this arg trades increased disk reads for
lower memory usage.
Returns:
A variable initializer function.
"""
if initializer is None:
# TODO(b/25671353): This should be kept in sync with the stddev used by
# feature_column.py's _EmbeddingColumn.
initializer = init_ops.truncated_normal_initializer(
stddev=1.0 / math.sqrt(embedding_dim))
return load_and_remap_matrix_initializer(
ckpt_path=ckpt_path,
old_tensor_name=embedding_tensor_name,
new_row_vocab_size=new_vocab_size,
new_col_vocab_size=embedding_dim,
old_row_vocab_file=old_vocab_file,
new_row_vocab_file=new_vocab_file,
old_col_vocab_file=None,
new_col_vocab_file=None,
num_row_oov_buckets=num_oov_buckets,
num_col_oov_buckets=0,
initializer=initializer,
max_rows_in_memory=max_rows_in_memory)
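A warm-start sketch using the embedding wrapper above; every path, size, and variable name here is an assumption for illustration.
```
import tensorflow as tf

embedding_init = load_embedding_initializer(
    ckpt_path="/tmp/old_model/model.ckpt",            # assumed checkpoint path
    embedding_tensor_name="old_scope/embedding_weights",  # assumed tensor name
    new_vocab_size=10000,
    embedding_dim=64,
    old_vocab_file="/tmp/vocab_old.txt",
    new_vocab_file="/tmp/vocab_new.txt",
    num_oov_buckets=1)
embeddings = tf.get_variable(
    "words_embedding",
    shape=[10000 + 1, 64],  # new vocab + OOV buckets, embedding_dim
    initializer=embedding_init,
    partitioner=tf.fixed_size_partitioner(2))
```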
# pylint: disable=protected-access,line-too-long
load_and_remap_matrix_initializer = checkpoint_ops._load_and_remap_matrix_initializer
# pylint: enable=line-too-long
load_embedding_initializer = checkpoint_ops._load_embedding_initializer
# pylint: enable=protected-access
def load_linear_multiclass_bias_initializer(ckpt_path,

View File

@ -21,7 +21,6 @@ import os
import numpy as np
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib.framework.python.ops import checkpoint_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -38,250 +37,6 @@ FLAGS = flags.FLAGS
_TESTDATA_PATH = 'contrib/framework/testdata'
class LoadAndRemapWrappersTest(test.TestCase):
"""Tests for the functionality of the Python wrappers."""
def setUp(self):
self.bundle_file = os.path.join(
test.test_src_dir_path(_TESTDATA_PATH), 'bundle_checkpoint')
self.new_feature_vocab_file = os.path.join(
test.test_src_dir_path(_TESTDATA_PATH), 'bundle_checkpoint_vocab.txt')
self.old_feature_vocab_file = os.path.join(
test.test_src_dir_path(_TESTDATA_PATH),
'bundle_checkpoint_vocab_with_oov.txt')
self.new_class_vocab_file = os.path.join(
test.test_src_dir_path(_TESTDATA_PATH), 'keyword_new.txt')
self.old_class_vocab_file = os.path.join(
test.test_src_dir_path(_TESTDATA_PATH), 'keyword.txt')
self.init_val = 42
def _init_val_initializer(shape, dtype=None, partition_info=None):
del dtype, partition_info # Unused by this unit-testing initializer.
return array_ops.tile(
constant_op.constant([[self.init_val]], dtype=dtypes.float32), shape)
self.initializer = _init_val_initializer
def test_load_and_remap_matrix(self):
"""Tests the end-to-end loading / remapping of weights."""
# _load_and_remap_matrix() is the generalized wrapper that takes in row and
# column vocabulary files, calls the relevant remappings, and returns the
# weight matrix. Take this example to be linear multi-class by providing
# both row and column vocabularies.
remapped_matrix = checkpoint_ops._load_and_remap_matrix(
new_row_vocab_file=self.new_feature_vocab_file,
old_row_vocab_file=self.old_feature_vocab_file,
num_rows_to_load=4,
new_col_vocab_file=self.new_class_vocab_file,
old_col_vocab_file=self.old_class_vocab_file,
new_col_vocab_size=4,
old_tensor_name='some_scope/embeddings',
ckpt_path=[self.bundle_file],
new_row_vocab_offset=1,
initializer=self.initializer,
num_row_oov_buckets=1,
num_col_oov_buckets=1)
# [4 in vocab + 1 oov features, 4 in vocab + 1 oov classes]. The offset
# means we read
expected_remapped_matrix = np.concatenate(
[
np.reshape([18, 34, 50, self.init_val, self.init_val], [5, 1]),
np.reshape([16, 32, 48, self.init_val, self.init_val], [5, 1]),
np.reshape([self.init_val] * 5, [5, 1]),
np.reshape([17, 33, 49, self.init_val, self.init_val], [5, 1]),
np.reshape([self.init_val] * 5, [5, 1])
],
axis=1)
with self.test_session():
self.assertAllClose(expected_remapped_matrix, remapped_matrix.eval())
def test_load_and_remap_output_layer_weight_initializer_linear(self):
"""Tests for the output layer initializer in the linear multi-class case."""
loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
new_row_vocab_size=5,
new_col_vocab_file=self.new_class_vocab_file,
old_col_vocab_file=self.old_class_vocab_file,
new_col_vocab_size=4,
old_tensor_name='some_scope/embeddings',
ckpt_path=[self.bundle_file],
new_row_vocab_file=self.new_feature_vocab_file,
old_row_vocab_file=self.old_feature_vocab_file,
num_row_oov_buckets=1,
num_col_oov_buckets=1,
initializer=self.initializer))
expected_remapped_matrix = np.concatenate(
[
np.reshape([2, 18, 34, 50, self.init_val, self.init_val], [6, 1]),
np.reshape([0, 16, 32, 48, self.init_val, self.init_val], [6, 1]),
np.reshape([self.init_val] * 6, [6, 1]),
np.reshape([1, 17, 33, 49, self.init_val, self.init_val], [6, 1]),
np.reshape([self.init_val] * 6, [6, 1])
],
axis=1)
# The new weight matrix is of size
# [5 feature vocab + 1 feature OOV, 4 class vocab + 1 class OOV]. Use a
# partitioned variable to confirm that the offset logic works.
remapped_matrix = variable_scope.get_variable(
name='linear/obtained_weight_matrix',
shape=[6, 5],
initializer=loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
with self.test_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_matrix,
remapped_matrix.as_tensor().eval())
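For readers unfamiliar with how the partitioned variable exercises the offset logic: under fixed_size_partitioner the initializer is invoked once per slice and receives that slice's position via partition_info. A minimal TF 1.x sketch (the variable name and shape are arbitrary):
import tensorflow as tf

def logging_initializer(shape, dtype=None, partition_info=None):
  # partition_info.var_offset is the slice's offset within the full variable;
  # load_and_remap_matrix_initializer uses it to pick which rows each slice
  # should receive.
  offset = partition_info.var_offset if partition_info is not None else None
  print('initializing slice of shape', shape, 'at offset', offset)
  return tf.zeros(shape, dtype=dtype or tf.float32)

w = tf.get_variable(
    'sketch/weight_matrix',
    shape=[6, 5],
    initializer=logging_initializer,
    partitioner=tf.fixed_size_partitioner(2))
# Prints two slices of shape [3, 5], at offsets [0, 0] and [3, 0].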
def test_load_and_remap_output_layer_weight_initializer_dnn_output(self):
"""Tests for the output layer initializer in the DNN output case."""
loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
new_row_vocab_size=5,
new_col_vocab_file=self.new_class_vocab_file,
old_col_vocab_file=self.old_class_vocab_file,
new_col_vocab_size=4,
old_tensor_name='some_scope/embeddings',
ckpt_path=[self.bundle_file],
num_col_oov_buckets=1,
initializer=self.initializer))
expected_remapped_matrix = np.concatenate(
[
np.reshape([2, 18, 34, 50, 66], [5, 1]),
np.reshape([0, 16, 32, 48, 64], [5, 1]),
np.reshape([self.init_val] * 5, [5, 1]),
np.reshape([1, 17, 33, 49, 65], [5, 1]),
np.reshape([self.init_val] * 5, [5, 1])
],
axis=1)
# The new weight matrix is of size
# [5-sized input layer, 4 class vocab + 1 class OOV].
remapped_matrix = variable_scope.get_variable(
name='dnn_output/obtained_weight_matrix',
shape=[5, 5],
initializer=loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
with self.test_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_matrix,
remapped_matrix.as_tensor().eval())
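Under the same (inferred) assumption that the old tensor is np.arange(80).reshape(5, 16), the DNN-output case keeps an identity row mapping and only remaps columns, which is a quick way to sanity-check the expected values:
import numpy as np

old = np.arange(80, dtype=np.float32).reshape(5, 16)
col_map = [2, 0, -1, 1, -1]  # same class remapping as in the linear case
dnn_expected = np.full((5, 5), 42.0, dtype=np.float32)
for j, old_j in enumerate(col_map):
  if old_j >= 0:
    dnn_expected[:, j] = old[:, old_j]
# Column 0 becomes [2, 18, 34, 50, 66], matching expected_remapped_matrix.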
def test_initializer_with_oov_only_partition(self):
"""Tests for the output layer initializer where one partition is all OOV."""
loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
new_row_vocab_size=5,
new_col_vocab_file=self.new_class_vocab_file,
old_col_vocab_file=self.old_class_vocab_file,
new_col_vocab_size=4,
old_tensor_name='some_scope/embeddings',
ckpt_path=[self.bundle_file],
new_row_vocab_file=self.new_feature_vocab_file,
old_row_vocab_file=self.old_feature_vocab_file,
num_row_oov_buckets=5,
num_col_oov_buckets=1,
initializer=self.initializer))
expected_remapped_matrix = np.concatenate(
[
np.reshape([2, 18, 34, 50] + [self.init_val] * 6, [10, 1]),
np.reshape([0, 16, 32, 48] + [self.init_val] * 6, [10, 1]),
np.reshape([self.init_val] * 10, [10, 1]),
np.reshape([1, 17, 33, 49] + [self.init_val] * 6, [10, 1]),
np.reshape([self.init_val] * 10, [10, 1]),
],
axis=1)
# The new weight matrix is of size
# [5 feature vocab + 5 feature OOV, 4 class vocab + 1 class OOV]. The
# second partition has only OOV.
remapped_matrix = variable_scope.get_variable(
name='linear_all_oov/obtained_weight_matrix',
shape=[10, 5],
initializer=loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
with self.test_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_matrix,
remapped_matrix.as_tensor().eval())
def test_load_and_remap_linear_multiclass_initializer_default_init(self):
"""Tests where the zeros_initializer default is used for linear."""
loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
new_row_vocab_size=5,
new_col_vocab_file=self.new_class_vocab_file,
old_col_vocab_file=self.old_class_vocab_file,
new_col_vocab_size=4,
old_tensor_name='some_scope/embeddings',
ckpt_path=[self.bundle_file],
new_row_vocab_file=self.new_feature_vocab_file,
old_row_vocab_file=self.old_feature_vocab_file,
num_row_oov_buckets=1,
num_col_oov_buckets=1))
expected_remapped_matrix = np.concatenate(
[
np.reshape([2, 18, 34, 50, 0, 0], [6, 1]),
np.reshape([0, 16, 32, 48, 0, 0], [6, 1]),
np.reshape([0] * 6, [6, 1]),
np.reshape([1, 17, 33, 49, 0, 0], [6, 1]),
np.reshape([0] * 6, [6, 1])
],
axis=1)
remapped_matrix = variable_scope.get_variable(
name='linear_init_fallback/obtained_weight_matrix',
shape=[6, 5],
initializer=loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
with self.test_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_matrix,
remapped_matrix.as_tensor().eval())
def test_load_embedding_initializer(self):
"""Tests for the load_embedding_initializer wrapper."""
embedding_loading_initializer = (
contrib_framework.load_embedding_initializer(
new_vocab_file=self.new_feature_vocab_file,
old_vocab_file=self.old_feature_vocab_file,
new_vocab_size=5,
embedding_dim=16,
embedding_tensor_name='some_scope/embeddings',
ckpt_path=[self.bundle_file],
num_oov_buckets=1,
initializer=self.initializer))
expected_remapped_embeddings = np.concatenate(
[
np.reshape(range(64), [4, 16]),
np.reshape([self.init_val] * 32, [2, 16]),
],
axis=0)
# The new weight matrix is of size
# [5 feature vocab + 1 feature OOV, 16 (embedding dimension)], where the
# last vocab row (2nd last row) is newly initialized (wasn't found in
# previous vocab) and the actual last row is OOV and also newly initialized.
# Use a partitioned variable to confirm that the offset logic works.
remapped_embeddings = variable_scope.get_variable(
name='embedding/obtained_embedding_matrix',
shape=[6, 16],
initializer=embedding_loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
with self.test_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_embeddings,
remapped_embeddings.as_tensor().eval())
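The row remapping that decides which embedding rows are loaded versus newly initialized comes from comparing the two vocab files (the real kernel is the GenerateVocabRemapping op). A simplified pure-Python sketch with made-up tokens:
def vocab_remapping(new_vocab, old_vocab):
  """Maps each new-vocab token to its old-vocab row, or -1 if absent."""
  old_index = {token: i for i, token in enumerate(old_vocab)}
  return [old_index.get(token, -1) for token in new_vocab]

# Rows mapped to -1, plus any OOV buckets appended at the end, are the ones
# the initializer fills (self.initializer in the tests above).
print(vocab_remapping(['cat', 'dog', 'axolotl'], ['dog', 'cat', 'fish']))  # [1, 0, -1]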
class LoadMulticlassBiasTest(test.TestCase):
"""Tests for the load_linear_multiclass_bias_initializer functionality."""

View File

@ -0,0 +1,27 @@
package(default_visibility = ["//tensorflow:__subpackages__"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
py_library(
name = "gan",
srcs = [
"__init__.py",
],
srcs_version = "PY2AND3",
deps = [
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,4 @@
This directory contains the TFGAN project.
This file will have more details as code is added.

Some files were not shown because too many files have changed in this diff.