Merge branch 'master' into expose-symbols-osx

commit 6d7ac549b6

Changed files:

  configure.py (10 changes)
  tensorflow/
    BUILD
    c/
      BUILD
      c_api.cc
      c_api.h
      c_api_function.cc
      c_api_function_test.cc
      c_api_internal.h
      c_api_test.cc
      c_test_util.cc
      c_test_util.h
      eager/
      version_script.lds
    cc/
    compiler/
      tests/
      tf2xla/kernels/
      xla/
        service/
          BUILD
          cpu/
          dfs_hlo_visitor.cc
          dfs_hlo_visitor.h
          dfs_hlo_visitor_with_default.h
          elemental_ir_emitter.cc
          gpu/
          hlo_alias_analysis.cc
          hlo_alias_analysis.h
          hlo_alias_analysis_test.cc
          hlo_buffer.cc
          hlo_buffer.h
          hlo_cost_analysis.cc
          hlo_cost_analysis.h
          hlo_dataflow_analysis.cc
          hlo_dataflow_analysis.h
          hlo_dataflow_analysis_test.cc
          hlo_graph_dumper.cc
          hlo_instruction.cc
          hlo_instruction.h
          hlo_instruction_test.cc
          hlo_ordering_test.cc
          hlo_pass_pipeline.cc
          hlo_rematerialization.cc
          hlo_rematerialization.h
          hlo_value.cc
          hlo_value.h
          hlo_verifier.cc
          reduce_precision_insertion.cc
          tests/
    tools/
    contrib/
      BUILD
      __init__.py
      boosted_trees/
        estimator_batch/
        kernels/
        lib/
          learner/stochastic/
            handlers/
              bias-feature-column-handler_test.cc
              categorical-feature-column-handler_test.cc
              dense-quantized-feature-column-handler_test.cc
              sparse-quantized-feature-column-handler_test.cc
          stats/
          quantiles/
        proto/
        python/
      cmake/
      cudnn_rnn/python/ops/
      data/
      distributions/
      framework/python/ops/
      gan/
configure.py:

@@ -685,10 +685,12 @@ def set_tf_cudnn_version(environ_cp):
     ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
     cudnn_path_from_ldconfig = run_shell([ldconfig_bin, '-p'])
     cudnn_path_from_ldconfig = re.search('.*libcudnn.so .* => (.*)',
-                                         cudnn_path_from_ldconfig).group(1)
-    if os.path.exists('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)):
-      cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig)
-      break
+                                         cudnn_path_from_ldconfig)
+    if cudnn_path_from_ldconfig:
+      cudnn_path_from_ldconfig = cudnn_path_from_ldconfig.group(1)
+      if os.path.exists('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)):
+        cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig)
+        break

     # Reset and Retry
     print(
tensorflow/BUILD:

@@ -296,6 +296,7 @@ filegroup(
         "//tensorflow/contrib/ffmpeg/default:all_files",
         "//tensorflow/contrib/framework:all_files",
         "//tensorflow/contrib/fused_conv:all_files",
+        "//tensorflow/contrib/gan:all_files",
         "//tensorflow/contrib/graph_editor:all_files",
         "//tensorflow/contrib/grid_rnn:all_files",
         "//tensorflow/contrib/hooks:all_files",
@@ -323,6 +324,7 @@ filegroup(
         "//tensorflow/contrib/nn:all_files",
         "//tensorflow/contrib/opt:all_files",
         "//tensorflow/contrib/predictor:all_files",
+        "//tensorflow/contrib/receptive_field:all_files",
         "//tensorflow/contrib/reduce_slice_ops:all_files",
         "//tensorflow/contrib/remote_fused_graph/pylib:all_files",
         "//tensorflow/contrib/resampler:all_files",
@@ -342,6 +344,7 @@ filegroup(
         "//tensorflow/contrib/staging:all_files",
         "//tensorflow/contrib/stat_summarizer:all_files",
         "//tensorflow/contrib/stateless:all_files",
+        "//tensorflow/contrib/summary:all_files",
         "//tensorflow/contrib/tensor_forest:all_files",
         "//tensorflow/contrib/tensor_forest/hybrid:all_files",
         "//tensorflow/contrib/tensor_forest/kernels/v4:all_files",
tensorflow/c/BUILD:

@@ -45,8 +45,13 @@ tf_cuda_library(

 tf_cuda_library(
     name = "c_api",
-    srcs = ["c_api.cc"],
-    hdrs = ["c_api.h"],
+    srcs = [
+        "c_api.cc",
+        "c_api_function.cc",
+    ],
+    hdrs = [
+        "c_api.h",
+    ],
     copts = tf_copts(),
     visibility = ["//visibility:public"],
     deps = select({
@@ -157,6 +162,21 @@ tf_cc_test(
     ],
 )

+tf_cc_test(
+    name = "c_api_function_test",
+    size = "small",
+    srcs = ["c_api_function_test.cc"],
+    deps = [
+        ":c_api",
+        ":c_test_util",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
+
 tf_cc_test(
     name = "while_loop_test",
     size = "small",
tensorflow/c/c_api.cc:

@@ -165,22 +165,6 @@ void deallocate_buffer(void* data, size_t len, void* arg) {
   tensorflow::cpu_allocator()->DeallocateRaw(data);
 }

-Status MessageToBuffer(const tensorflow::protobuf::Message& in,
-                       TF_Buffer* out) {
-  if (out->data != nullptr) {
-    return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
-  }
-  const auto proto_size = in.ByteSizeLong();
-  void* buf = tensorflow::port::Malloc(proto_size);
-  in.SerializeToArray(buf, proto_size);
-  out->data = buf;
-  out->length = proto_size;
-  out->data_deallocator = [](void* data, size_t length) {
-    tensorflow::port::Free(data);
-  };
-  return Status::OK();
-}
-
 }  // namespace

 TF_Tensor::~TF_Tensor() { buffer->Unref(); }
@@ -559,6 +543,27 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
                      dimvec.size(), base, size, DeleteArray, base);
 }

+Status MessageToBuffer(const tensorflow::protobuf::Message& in,
+                       TF_Buffer* out) {
+  if (out->data != nullptr) {
+    return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
+  }
+  const size_t proto_size = in.ByteSizeLong();
+  void* buf = tensorflow::port::Malloc(proto_size);
+  if (buf == nullptr) {
+    return tensorflow::errors::ResourceExhausted(
+        "Failed to allocate memory to serialize message of type '",
+        in.GetTypeName(), "' and size ", proto_size);
+  }
+  in.SerializeToArray(buf, proto_size);
+  out->data = buf;
+  out->length = proto_size;
+  out->data_deallocator = [](void* data, size_t length) {
+    tensorflow::port::Free(data);
+  };
+  return Status::OK();
+}
+
 // Helpers for loading a TensorFlow plugin (a .so file).
 Status LoadLibrary(const char* library_filename, void** result,
                    const void** buf, size_t* len);
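Side note, a hedged caller-side sketch (not part of the diff): now that MessageToBuffer is exported to internal callers, it can serialize any proto into a TF_Buffer. `some_proto` below is a placeholder for an arbitrary tensorflow::protobuf::Message; only the TF_Buffer API is assumed:

    TF_Buffer* buf = TF_NewBuffer();  // starts with data == nullptr, length == 0
    tensorflow::Status s = tensorflow::MessageToBuffer(some_proto, buf);
    if (s.ok()) {
      // buf->data / buf->length now hold the serialized bytes.
    }
    TF_DeleteBuffer(buf);  // invokes data_deallocator, freeing the bytes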
tensorflow/c/c_api.h:

@@ -357,6 +357,14 @@ typedef struct TF_Output {
   int index;  // The index of the output within oper.
 } TF_Output;

+// TF_Function is a grouping of operations with defined inputs and outputs.
+// Once created and added to graphs, functions can be invoked by creating an
+// operation whose operation type matches the function name.
+typedef struct TF_Function TF_Function;
+
+// Function definition options. TODO(iga): Define and implement
+typedef struct TF_FunctionOptions TF_FunctionOptions;
+
 // Sets the shape of the Tensor referenced by `output` in `graph` to
 // the shape described by `dims` and `num_dims`.
 //
@@ -914,6 +922,15 @@ TF_CAPI_EXPORT extern void TF_GraphImportGraphDef(
     TF_Graph* graph, const TF_Buffer* graph_def,
     const TF_ImportGraphDefOptions* options, TF_Status* status);

+// Add `function` to graph `g`. Once `function` is added to `g`,
+// it can be called by creating an operation using the function's name.
+//
+// If successful, status is set to OK and function is added to g.
+// Otherwise, status is set to the encountered error and g is unmodified.
+TF_CAPI_EXPORT extern void TF_GraphAddFunction(TF_Graph* g,
+                                               const TF_Function* function,
+                                               TF_Status* status);
+
 // Note: The following function may fail on very large protos in the future.

 TF_CAPI_EXPORT extern void TF_OperationToNodeDef(TF_Operation* oper,
@@ -1001,6 +1018,105 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny,
                                     TF_Output* x, int nx, TF_Output* dx,
                                     TF_Status* status, TF_Output* dy);

+// Create a TF_Function from a TF_Graph.
+//
+// Params:
+//  fn_body - the graph whose operations (or a subset of them) will be
+//            converted to the TF_Function.
+//  fn_name - the name of the new TF_Function. Should match the operation
+//            name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]* and be distinct
+//            from other operation names (at least those registered in graphs
+//            where this function will be used).
+//            TODO(iga): Allow null in here and have C API come up with
+//            a unique name with high probability (similarly to
+//            _create_hash_str in function.py)
+//  num_opers - `num_opers` contains the number of elements in the `opers` array
+//              or a special value of -1 meaning that no array is given.
+//              The distinction between an empty array of operations and no
+//              array of operations is necessary to distinguish the case of
+//              creating a function with no body (e.g. identity or permutation)
+//              and the case of creating a function whose body contains all
+//              the nodes in the graph (except for the automatic skipping, see
+//              below).
+//  opers - Array of operations to become the body of the function or null.
+//          - If no array is given (`num_opers` = -1), all the
+//            operations in `fn_body` will become part of the function
+//            except operations referenced in `inputs`. These operations
+//            must have a single output (these operations are typically
+//            placeholders created for the sole purpose of representing
+//            an input. We can relax this constraint if there are
+//            compelling use cases).
+//          - If an array is given (`num_opers` >= 0), all operations
+//            in it will become part of the function. In particular, no
+//            automatic skipping of dummy input operations is performed.
+//  ninputs - number of elements in the `inputs` array
+//  inputs - array of TF_Outputs that specify the inputs to the function.
+//           If `ninputs` is zero (the function takes no inputs), `inputs`
+//           can be null. The names used for function inputs are normalized
+//           names of the operations (usually placeholders) pointed to by
+//           `inputs`. These operation names should start with a letter.
+//           Normalization will convert all letters to lowercase and
+//           non-alphanumeric characters to '_' to make resulting names match
+//           the "[a-z][a-z0-9_]*" pattern for operation argument names.
+//           `inputs` cannot contain the same tensor twice.
+//  noutputs - number of elements in the `outputs` array
+//  outputs - array of TF_Outputs that specify the outputs of the function.
+//            If `noutputs` is zero (the function returns no outputs), `outputs`
+//            can be null. `outputs` can contain the same tensor more than once.
+//  output_names - The names of the function's outputs. The `output_names` array
+//                 must either have the same length as `outputs`
+//                 (i.e. `noutputs`) or be null. In the former case,
+//                 the names should match the regular expression for ArgDef
+//                 names - "[a-z][a-z0-9_]*". In the latter case,
+//                 names for outputs will be generated automatically.
+//  opts - various options for the function, e.g. XLA's inlining control.
+//  status - Set to OK on success and an appropriate error on failure.
+//
+// Note that when the same TF_Output is listed as both an input and an output,
+// the corresponding function's output will be equal to this input,
+// instead of the original node's output.
+//
+// Callers must also satisfy the following constraints:
+// - `inputs` cannot refer to TF_Outputs within a control flow context. For
+//   example, one cannot use the output of a "switch" node as input.
+// - No TF_Output of a function (inside any of `inputs`, `outputs`, `fn_body`)
+//   is allowed to have a reference type. Reference types are not exposed
+//   through the C API and are being deprecated.
+// - Every node in the function's body must have all of its inputs (including
+//   control inputs). In other words, for every node in the body, each input
+//   must be either listed in `inputs` or must come from another node in
+//   the body. In particular, it is an error to have a control edge going from
+//   a node outside of the body into a node in the body. This applies to control
+//   edges going from nodes referenced in `inputs` to nodes in the body when
+//   the former nodes are not in the body (automatically skipped or not
+//   included in an explicitly specified body).
+//
+// Returns:
+//  On success, a newly created TF_Function instance. It must be deleted by
+//  calling TF_DeleteFunction.
+//
+//  On failure, null.
+//
+// TODO(iga): Add input_names argument and get output_names working (they are
+// currently ignored)
+TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction(
+    const TF_Graph* fn_body, const char* fn_name, int num_opers,
+    const TF_Operation* const* opers, int ninputs, const TF_Output* inputs,
+    int noutputs, const TF_Output* outputs, const char* const* output_names,
+    const TF_FunctionOptions* opts, TF_Status* status);
+
+// Write out a serialized representation of `func` (as a FunctionDef protocol
+// message) to `output_func_def` (allocated by TF_NewBuffer()).
+// `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer()
+// is called.
+//
+// May fail on very large graphs in the future.
+TF_CAPI_EXPORT extern void TF_FunctionToFunctionDef(TF_Function* func,
+                                                    TF_Buffer* output_func_def,
+                                                    TF_Status* status);
+
+TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function*);
+
 // TODO(josh11b): Register OpDef, available to all operations added
 // to this graph.
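To make the calling convention concrete, a minimal usage sketch (hedged; the handles `fn_body_graph`, `in_op`, `neg_op`, `target_graph`, and `status` are hypothetical, and only functions documented above are used):

    TF_Output inputs[1] = {{in_op, 0}};   // a single-output placeholder
    TF_Output outputs[1] = {{neg_op, 0}};
    TF_Function* fn = TF_GraphToFunction(
        fn_body_graph, "neg_fn",
        /*num_opers=*/-1, /*opers=*/nullptr,  // whole graph as the body;
                                              // input placeholders auto-skipped
        /*ninputs=*/1, inputs, /*noutputs=*/1, outputs,
        /*output_names=*/nullptr,             // names generated automatically
        /*opts=*/nullptr, status);
    if (TF_GetCode(status) == TF_OK) {
      TF_GraphAddFunction(target_graph, fn, status);  // callable as op type "neg_fn"
    }
    TF_DeleteFunction(fn);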
tensorflow/c/c_api_function.cc (new file, +496 lines):
@@ -0,0 +1,496 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/c_api_internal.h"

#include <algorithm>
#include <unordered_map>
#include <unordered_set>

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {
namespace {

// Class that maintains a one-to-one original node name -> new node name
// mapping. We normalize the names used as input and output arguments to match
// regexp "[a-z][a-z0-9_]*" specified in the definition of ArgDef.name.
// Once we rename them, we risk creating a name collision with the other
// node names, so if necessary we add a suffix to make
// names unique. If we have an input named "A" and a node in the function
// body named "a", they will be renamed to "a" and "a_0".
class NodeNameMapping {
 public:
  NodeNameMapping() = default;

  // Normalize the input/output name and make it unique.
  string GetIOName(const string& name);

  // Make the node name unique.
  string Uniquify(const string& name);

  // Look up how a node name was previously normalized/uniquified.
  // Returns empty if name was never seen.
  string Lookup(const string& name) const;

 private:
  string UniquifyHelper(const string& name) const;
  static string Normalize(string name);

  // The normalized/uniquified names already used as
  // input names (in signature), output names (in signature), and node names
  // (in node_def).
  // This is a superset of values in name_mapping_.
  std::unordered_set<string> used_names_;
  // Mapping from original node name from the graph to the normalized
  // and uniquified version of it.
  std::unordered_map<string, string> name_mapping_;
};

string NodeNameMapping::Normalize(string name) {
  // Convert letters to lowercase and non-alphanumeric characters to '_'.
  if (name.empty()) return "unknown";
  const int n = name.size();
  for (int i = 0; i < n; ++i) {
    char c = name[i];
    if (isalnum(c)) {
      if (isupper(c)) {
        name[i] = tolower(c);
      }
    } else {
      name[i] = '_';
    }
  }

  // Find the first letter and start with it.
  int i = 0;
  for (; i < n; ++i) {
    if (isalpha(name[i])) break;
  }

  // Return "unknown" if none of the name's chars were letters.
  return i == n ? "unknown" : name.substr(i);
}

string NodeNameMapping::UniquifyHelper(const string& name) const {
  // If the name hasn't been used yet, use it as-is.
  if (used_names_.find(name) == used_names_.end()) return name;
  // Add a suffix to name to make it unique.
  for (int i = 0;; ++i) {
    const string candidate = strings::StrCat(name, "_", i);
    if (used_names_.find(candidate) == used_names_.end()) return candidate;
  }
}

string NodeNameMapping::GetIOName(const string& name) {
  const string& input_name = UniquifyHelper(Normalize(name));
  // Record that we used this name, but don't add it to name_mapping_
  // since this name is not for a node.
  used_names_.insert(input_name);
  return input_name;
}

string NodeNameMapping::Uniquify(const string& name) {
  const string uniqued = UniquifyHelper(name);
  name_mapping_[name] = uniqued;
  used_names_.insert(uniqued);
  return uniqued;
}

string NodeNameMapping::Lookup(const string& name) const {
  const auto iter = name_mapping_.find(name);
  if (iter == name_mapping_.end()) return string();
  return iter->second;
}
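A quick behavioral sketch of this mapping (hypothetical names, following the class comment above):

    tensorflow::NodeNameMapping m;
    m.GetIOName("A");   // -> "a": normalized; recorded in used_names_ only
    m.Uniquify("a");    // -> "a_0": plain "a" is already taken by the input
    m.Lookup("a");      // -> "a_0": the node rename is remembered
    m.Lookup("b");      // -> "":    never seen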
Status ValidateNoRefOutputs(const Node* node) {
  for (int i = 0; i < node->num_outputs(); ++i) {
    const DataType& dt = node->output_type(i);
    if (IsRefType(dt)) {
      return errors::InvalidArgument("Output ", i, " of node '", node->name(),
                                     "' has a reference "
                                     "type ",
                                     DataTypeString(dt));
    }
  }
  return Status::OK();
}

Status FillFunctionBody(
    const string& fn_name, const NodeNameMapping& node_names,
    const std::vector<const Node*>& body_nodes,
    const std::unordered_map<string, string>& tensor_renaming,
    FunctionDef* fdef) {
  std::vector<const Edge*> in_edges;
  std::vector<const Edge*> control_edges;
  for (const Node* node : body_nodes) {
    NodeDef* node_def = fdef->add_node_def();
    // First, copy the node_def as is. We will patch it next.
    *node_def = node->def();
    if (!node->assigned_device_name().empty()) {
      node_def->set_device(node->assigned_device_name());
    }
    node_def->set_name(node_names.Lookup(node->name()));

    // Input names must be set based on nested names in tensor_renaming.
    // Clear the flat input names we got from the original node_def
    // from the graph.
    node_def->clear_input();

    // Collect regular and control inputs. Regular inputs are indexed
    // by the index at which they come into the `node`. Control inputs
    // don't follow any order.
    in_edges.clear();
    in_edges.resize(node->num_inputs(), nullptr);
    control_edges.clear();
    for (const Edge* edge : node->in_edges()) {
      if (edge->src()->IsSource()) continue;
      if (edge->IsControlEdge()) {
        control_edges.push_back(edge);
      } else {
        in_edges[edge->dst_input()] = edge;
      }
    }

    // Add regular inputs.
    for (size_t i = 0; i < in_edges.size(); ++i) {
      const Edge* edge = in_edges[i];
      string original_input_name;
      if (edge == nullptr) {
        // A backedge might not appear as a regular Edge, but be only present
        // in the node_def. Such edges are referred to as requested_inputs().
        if (i >= node->requested_inputs().size()) {
          return errors::InvalidArgument(
              "Graph to be converted to function appears to be malformed. ",
              "Node ", node->name(), " is missing input edge ", i);
        }
        original_input_name =
            ParseTensorName(node->requested_inputs()[i]).ToString();
      } else {
        original_input_name =
            strings::StrCat(edge->src()->name(), ":", edge->src_output());
      }

      const auto iter = tensor_renaming.find(original_input_name);
      if (iter == tensor_renaming.end()) {
        return errors::InvalidArgument(
            "Input ", i, ", '", original_input_name, "', of node '",
            node->name(), "' in function '", fn_name,
            "' is not available. You might need to include it in inputs "
            "or include its source node in the body");
      }
      node_def->add_input(iter->second);
    }

    // Add control inputs.
    for (const Edge* edge : control_edges) {
      // Add this control input only if the src node is in the body.
      const string normalized = node_names.Lookup(edge->src()->name());
      // If we did not find a name for the source of control edge, this
      // source must be outside of the body. Raise an error.
      if (normalized.empty()) {
        return errors::InvalidArgument(
            "The source of control edge ", edge->DebugString(),
            " is not in the body. Encountered while creating function '",
            fn_name, "'");
      }
      node_def->add_input(strings::StrCat("^", normalized));
    }
  }
  return Status::OK();
}
// Graph to FunctionDef conversion. This code is closely modeled on the Python
// code in third_party/tensorflow/python/framework/function.py.
Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
                          const std::vector<const Node*>& body_nodes,
                          const std::vector<OutputTensor>& inputs,
                          const std::vector<OutputTensor>& outputs,
                          const std::vector<string>& output_names,
                          FunctionDef* fdef) {
  fdef->mutable_signature()->set_name(fn_name);

  // Keep track of names we used and how we normalized them.
  NodeNameMapping node_names;

  // Mapping from original names of tensors (i.e. "<node_name>:<idx>") to the
  // name we used in the function:
  //  - For input tensors:
  //    {flat_tensor_name -> normalized_name_of_src_node}
  //    e.g. {In:3 -> in}
  //  - For tensors produced by nodes in function's body:
  //    {flat_tensor_name -> nested_tensor_name}
  //    e.g. {Add:3 -> add_0:z:1}
  std::unordered_map<string, string> tensor_renaming;

  // Fill inputs in function's signature.
  for (size_t i = 0; i < inputs.size(); ++i) {
    const Node* node = inputs[i].node;
    int idx = inputs[i].index;
    OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg();
    argdef->set_type(node->output_type(idx));
    const string& input_name = node_names.GetIOName(node->name());
    argdef->set_name(input_name);
    tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
  }

  // Fill outputs in function's signature.
  for (size_t i = 0; i < outputs.size(); ++i) {
    const Node* node = outputs[i].node;
    int idx = outputs[i].index;
    OpDef::ArgDef* argdef = fdef->mutable_signature()->add_output_arg();
    argdef->set_type(node->output_type(idx));
    argdef->set_name(node_names.GetIOName(node->name()));
  }

  // Populate tensor_renaming and node_names.
  // Generate the new output names for every node in the function.
  // The NodeDefs in FunctionDefs use a different naming scheme for
  // their inputs than the NodeDefs in a graph (see the comment for
  // FunctionDef.node_def in function.proto). We do the
  // graph tensor name -> function tensor name conversion for every
  // possible input (i.e. every node's outputs) and store the result
  // in tensor_renaming.
  for (const Node* node : body_nodes) {
    // Make sure node_name does not collide with an input or output name.
    const string& node_name = node_names.Uniquify(node->name());
    // For each output_arg in the op_def, the output_ranges
    // map will have [start, end] range of indices that this arg produces
    // among all the output tensors of this op.
    NameRangeMap output_ranges;
    TF_RETURN_IF_ERROR(
        NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges));
    for (const auto& output : output_ranges) {
      const string& output_name = output.first;
      int index_start = output.second.first;
      int index_end = output.second.second;
      for (int i = index_start; i < index_end; ++i) {
        const string& original_name = strings::StrCat(node->name(), ":", i);
        const string& new_name =
            strings::StrCat(node_name, ":", output_name, ":", i - index_start);
        // Record the mapping if this tensor is not already mapped.
        // Tensor can be already mapped if it is used as an input.
        if (tensor_renaming.find(original_name) == tensor_renaming.end()) {
          tensor_renaming[original_name] = new_name;
        }
      }
    }
  }

  TF_RETURN_IF_ERROR(
      FillFunctionBody(fn_name, node_names, body_nodes, tensor_renaming, fdef));

  // Remap return values.
  for (int r = 0; r < fdef->signature().output_arg_size(); ++r) {
    const string& ret_name = fdef->signature().output_arg(r).name();

    // We convert this flat tensor name to the nested value
    // (e.g. `add:z:1`) that we stored in tensor_renaming.
    const string& return_value =
        strings::StrCat(outputs[r].node->name(), ":", outputs[r].index);
    const auto iter = tensor_renaming.find(return_value);
    if (iter == tensor_renaming.end()) {
      return errors::InvalidArgument(
          "TF_Output ", return_value, " is neither in the function body ",
          "nor among function inputs. Encountered while creating function '",
          fn_name, "'");
    }
    (*fdef->mutable_ret())[ret_name] = iter->second;
  }

  return Status::OK();
}
// Converts `ninputs` and `inputs` into `input_tensors` and `input_nodes`, and
// does various checks while doing so. `input_nodes` will contain the same
// information as `input_tensors`, just in a different structure that makes
// the subsequent processing easier. TODO(iga): Simplify this nested structure.
Status ProcessInputs(
    const TF_Graph* fn_body, const char* fn_name, int ninputs,
    const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
    std::unordered_map<const Node*, std::vector<int>>* input_nodes)
    EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
  input_tensors->reserve(ninputs);
  for (int i = 0; i < ninputs; ++i) {
    const Node& node = inputs[i].oper->node;
    int idx = inputs[i].index;

    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        fn_body->graph.IsValidOutputTensor(&node, idx),
        "Encountered while processing input ", i, " into function '", fn_name,
        "'");
    TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(&node),
                                    "Encountered while processing input ", i,
                                    " into function '", fn_name, "'");

    input_tensors->emplace_back(&node, idx);

    const auto& iter = input_nodes->find(&node);
    if (iter == input_nodes->end()) {
      input_nodes->insert({&node, {idx}});
    } else {
      auto& indices = iter->second;
      if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
        return errors::InvalidArgument(
            "TF_Output ", node.name(), ":", idx,
            " appears more than once in the input list");
      }
      indices.push_back(idx);
    }
  }
  return Status::OK();
}
// Converts `noutputs` and `outputs` into `output_tensors` and does various
// checks while doing so.
Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
                      int noutputs, const TF_Output* outputs,
                      std::vector<OutputTensor>* output_tensors)
    EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
  output_tensors->reserve(noutputs);
  for (int i = 0; i < noutputs; ++i) {
    const Node& node = outputs[i].oper->node;
    int idx = outputs[i].index;
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        fn_body->graph.IsValidOutputTensor(&node, idx),
        "Encountered while processing output ", i, " from function '", fn_name,
        "'");
    output_tensors->emplace_back(&node, idx);
  }
  return Status::OK();
}

// Populates `body_nodes` with the nodes that will become the function's body.
// Performs various checks.
Status ComputeBodyNodes(
    const TF_Graph* fn_body, const char* fn_name, int num_opers,
    const TF_Operation* const* opers,
    const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
    std::vector<const Node*>* body_nodes)
    EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
  if (num_opers == -1) {
    for (const Node* node : fn_body->graph.op_nodes()) {
      const auto& iter = input_nodes.find(node);
      if (iter == input_nodes.end()) {
        // This node is not referenced in inputs. Add it to the body.
        TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(node),
                                        "Encountered while creating function '",
                                        fn_name, "'");
        body_nodes->push_back(node);
      } else {
        // This node is referenced in inputs. Currently, we place an
        // artificial restriction and require that when num_opers=-1, such
        // nodes must have a single output.
        if (node->num_outputs() != 1) {
          return errors::InvalidArgument(
              "When `num_opers` is set to -1, nodes referenced in `inputs` "
              "must have a single output. Node ",
              node->name(), " has ", node->num_outputs(),
              " outputs. Encountered while creating function '", fn_name, "'");
        }
      }
    }
  } else {
    body_nodes->reserve(num_opers);
    for (int i = 0; i < num_opers; ++i) {
      const Node* node = &opers[i]->node;
      TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(node),
                                      "Encountered while creating function '",
                                      fn_name, "'");
      body_nodes->push_back(node);
    }
  }
  return Status::OK();
}

}  // anonymous namespace
}  // namespace tensorflow

using tensorflow::Node;
using tensorflow::string;

TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
                                int num_opers, const TF_Operation* const* opers,
                                int ninputs, const TF_Output* inputs,
                                int noutputs, const TF_Output* outputs,
                                const char* const* output_names,
                                const TF_FunctionOptions* opts,
                                TF_Status* status) {
  tensorflow::mutex_lock l(*const_cast<tensorflow::mutex*>(&fn_body->mu));

  // Process inputs.
  std::vector<tensorflow::OutputTensor> input_tensors;
  std::unordered_map<const Node*, std::vector<int>> input_nodes;
  status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs,
                                             &input_tensors, &input_nodes);
  if (!status->status.ok()) return nullptr;

  // Process outputs.
  std::vector<tensorflow::OutputTensor> output_tensors;
  status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs,
                                              outputs, &output_tensors);
  if (!status->status.ok()) return nullptr;

  // Process output names.
  std::vector<string> output_names_vec;
  if (output_names) {
    output_names_vec.reserve(noutputs);
    for (int i = 0; i < noutputs; ++i) {
      output_names_vec.push_back(string(output_names[i]));
    }
  }

  // Compute body nodes.
  std::vector<const Node*> body_nodes;
  status->status = tensorflow::ComputeBodyNodes(
      fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes);
  if (!status->status.ok()) return nullptr;

  // Do the actual function creation.
  TF_Function* tf_function = new TF_Function();
  status->status = tensorflow::GraphToFunctionDef(
      fn_body->graph, fn_name, body_nodes, input_tensors, output_tensors,
      output_names_vec, tf_function->fdef_lib.add_function());
  if (!status->status.ok()) {
    TF_DeleteFunction(tf_function);
    return nullptr;
  }
  return tf_function;
}

void TF_GraphAddFunction(TF_Graph* g, const TF_Function* function,
                         TF_Status* status) {
  tensorflow::mutex_lock l(g->mu);

  // At the moment, we have only one function and no gradients in fdef_lib.
  // This makes the following operation atomic.
  // TODO(iga): Add an atomic version of AddFunctionLibrary when we support
  // gradients
  status->status = g->graph.AddFunctionLibrary(function->fdef_lib);
}

void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def,
                              TF_Status* status) {
  DCHECK_EQ(1, func->fdef_lib.function_size());
  status->status = MessageToBuffer(func->fdef_lib.function(0), output_func_def);
}

void TF_DeleteFunction(TF_Function* function) { delete function; }
tensorflow/c/c_api_function_test.cc (new file, +1039 lines):
(File diff suppressed because it is too large.)
tensorflow/c/c_api_internal.h:

@@ -130,6 +130,11 @@ struct TF_DeviceList {
   std::vector<tensorflow::DeviceAttributes> response;
 };

+struct TF_Function {
+  // Currently contains a single function and no gradients
+  tensorflow::FunctionDefLibrary fdef_lib;
+};
+
 namespace tensorflow {

 class TensorCApi {
@@ -141,7 +146,12 @@ class TensorCApi {
   }
 };

 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);

 TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);

 Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out);

 }  // end namespace tensorflow

 #endif  // TENSORFLOW_C_C_API_INTERNAL_H_
tensorflow/c/c_api_test.cc:

@@ -829,7 +829,7 @@ TEST(CAPI, ShapeInferenceError) {
   TF_Operation* vec3 = Const(vec3_tensor.get(), graph, status, "vec3");
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

-  TF_Operation* add = Add(vec2, vec3, graph, status);
+  TF_Operation* add = AddNoCheck(vec2, vec3, graph, status);
   ASSERT_NE(TF_OK, TF_GetCode(status));
   ASSERT_TRUE(add == nullptr);
tensorflow/c/c_test_util.cc:

@@ -15,7 +15,9 @@ limitations under the License.

 #include "tensorflow/c/c_test_util.h"

+#include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/logging.h"

 using tensorflow::GraphDef;
@@ -36,6 +38,23 @@ TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
   return t;
 }

+TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims,
+                       const int32_t* values) {
+  int64_t num_values = 1;
+  for (int i = 0; i < num_dims; ++i) {
+    num_values *= dims[i];
+  }
+  TF_Tensor* t =
+      TF_AllocateTensor(TF_INT32, dims, num_dims, sizeof(int32_t) * num_values);
+  memcpy(TF_TensorData(t), values, sizeof(int32_t) * num_values);
+  return t;
+}
+
+TF_Tensor* Int32Tensor(const std::vector<int32_t>& values) {
+  int64_t dims = values.size();
+  return Int32Tensor(&dims, 1, values.data());
+}
+
 TF_Tensor* Int32Tensor(int32_t v) {
   const int num_bytes = sizeof(int32_t);
   int32_t* values = new int32_t[1];
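A brief usage sketch of the new overloads (hypothetical values; cleanup shown for completeness):

    int64_t dims[2] = {2, 3};
    int32_t values[6] = {1, 2, 3, 4, 5, 6};
    TF_Tensor* t2d = Int32Tensor(dims, 2, values);  // 2x3 TF_INT32 tensor
    TF_Tensor* t1d = Int32Tensor({7, 8, 9});        // 1-D tensor of length 3
    TF_DeleteTensor(t2d);
    TF_DeleteTensor(t1d);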
@@ -44,19 +63,40 @@ TF_Tensor* Int32Tensor(int32_t v) {
                           &Int32Deallocator, nullptr);
 }

-TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) {
+// All the *Helper methods are used as a workaround for the restriction that
+// one cannot call ASSERT_* methods in non-void-returning functions (when
+// exceptions are disabled during compilation).
+void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
+                       TF_Operation** op) {
   TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
   TF_SetAttrType(desc, "dtype", TF_INT32);
-  return TF_FinishOperation(desc, s);
+  *op = TF_FinishOperation(desc, s);
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  ASSERT_NE(*op, nullptr);
+}
+
+TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) {
+  TF_Operation* op;
+  PlaceholderHelper(graph, s, name, &op);
+  return op;
+}
+
+void ConstHelper(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name,
+                 TF_Operation** op) {
+  TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
+  TF_SetAttrTensor(desc, "value", t, s);
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_SetAttrType(desc, "dtype", TF_TensorType(t));
+  *op = TF_FinishOperation(desc, s);
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  ASSERT_NE(*op, nullptr);
+}

 TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
                     const char* name) {
-  TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
-  TF_SetAttrTensor(desc, "value", t, s);
-  if (TF_GetCode(s) != TF_OK) return nullptr;
-  TF_SetAttrType(desc, "dtype", TF_TensorType(t));
-  return TF_FinishOperation(desc, s);
+  TF_Operation* op;
+  ConstHelper(t, graph, s, name, &op);
+  return op;
 }

 TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
@@ -65,11 +105,39 @@ TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
   return Const(tensor.get(), graph, s, name);
 }

-TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
-                  TF_Status* s, const char* name) {
+void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s,
+               const char* name, TF_Operation** op, bool check) {
   TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
   TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
   TF_AddInputList(desc, add_inputs, 2);
-  return TF_FinishOperation(desc, s);
+  *op = TF_FinishOperation(desc, s);
+  if (check) {
+    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    ASSERT_NE(*op, nullptr);
+  }
+}
+
+TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+                  TF_Status* s, const char* name) {
+  TF_Operation* op;
+  AddHelper(l, r, graph, s, name, &op, true);
+  return op;
+}
+
+TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+                         TF_Status* s, const char* name) {
+  TF_Operation* op;
+  AddHelper(l, r, graph, s, name, &op, false);
+  return op;
+}
+
+TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r,
+                                    TF_Graph* graph, TF_Operation* ctrl_op,
+                                    TF_Status* s, const char* name) {
+  TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
+  TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
+  TF_AddInputList(desc, add_inputs, 2);
+  TF_AddControlInput(desc, ctrl_op);
+  return TF_FinishOperation(desc, s);
 }
@@ -81,11 +149,20 @@ TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
   return TF_FinishOperation(desc, s);
 }

-TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) {
+void NegHelper(TF_Operation* n, TF_Graph* graph, TF_Status* s,
+               TF_Operation** op) {
   TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", "neg");
   TF_Output neg_input = {n, 0};
   TF_AddInput(desc, neg_input);
-  return TF_FinishOperation(desc, s);
+  *op = TF_FinishOperation(desc, s);
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  ASSERT_NE(*op, nullptr);
+}
+
+TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) {
+  TF_Operation* op;
+  NegHelper(n, graph, s, &op);
+  return op;
 }

 TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
@@ -96,6 +173,32 @@ TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
   return TF_FinishOperation(desc, s);
 }

+void Split3Helper(TF_Operation* input, TF_Graph* graph, TF_Status* s,
+                  const char* name, TF_Operation** op) {
+  TF_Operation* zero = ScalarConst(
+      0, graph, s, ::tensorflow::strings::StrCat(name, "_const0").c_str());
+  TF_OperationDescription* desc = TF_NewOperation(graph, "Split", name);
+  TF_AddInput(desc, {zero, 0});
+  TF_AddInput(desc, {input, 0});
+  TF_SetAttrInt(desc, "num_split", 3);
+  TF_SetAttrType(desc, "T", TF_INT32);
+  // Set device to CPU since there is no version of split for int32 on GPU
+  // TODO(iga): Convert all these helpers and tests to use floats because
+  // they are usually available on GPUs. After doing this, remove TF_SetDevice
+  // call in c_api_function_test.cc
+  TF_SetDevice(desc, "/cpu:0");
+  *op = TF_FinishOperation(desc, s);
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  ASSERT_NE(*op, nullptr);
+}
+
+TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s,
+                     const char* name) {
+  TF_Operation* op;
+  Split3Helper(input, graph, s, name, &op);
+  return op;
+}
+
 bool IsPlaceholder(const tensorflow::NodeDef& node_def) {
   if (node_def.op() != "Placeholder" || node_def.name() != "feed") {
     return false;
@@ -196,6 +299,18 @@ bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def) {
   return ret;
 }

+bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def) {
+  TF_Status* s = TF_NewStatus();
+  TF_Buffer* buffer = TF_NewBuffer();
+  TF_FunctionToFunctionDef(func, buffer, s);
+  bool ret = TF_GetCode(s) == TF_OK;
+  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  if (ret) ret = func_def->ParseFromArray(buffer->data, buffer->length);
+  TF_DeleteBuffer(buffer);
+  TF_DeleteStatus(s);
+  return ret;
+}
+
 bool GetAttrValue(TF_Operation* oper, const char* attr_name,
                   tensorflow::AttrValue* attr_value, TF_Status* s) {
   TF_Buffer* buffer = TF_NewBuffer();
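A hedged sketch of how tests can use this helper (`func` assumed to come from TF_GraphToFunction; the expected name is hypothetical):

    tensorflow::FunctionDef fdef;
    ASSERT_TRUE(GetFunctionDef(func, &fdef));
    EXPECT_EQ("neg_fn", fdef.signature().name());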
tensorflow/c/c_test_util.h:

@@ -33,6 +33,13 @@ typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)>
 // Create a tensor with values of type TF_INT8 provided by `values`.
 TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values);

+// Create a tensor with values of type TF_INT32 provided by `values`.
+TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims,
+                       const int32_t* values);
+
+// Create a 1-dimensional tensor with values from `values`.
+TF_Tensor* Int32Tensor(const std::vector<int32_t>& values);
+
 TF_Tensor* Int32Tensor(int32_t v);

 TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
@@ -47,6 +54,13 @@ TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
 TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
                   TF_Status* s, const char* name = "add");

+TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+                         TF_Status* s, const char* name = "add");
+
+TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r,
+                                    TF_Graph* graph, TF_Operation* ctrl_op,
+                                    TF_Status* s, const char* name = "add");
+
 TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
                   const char* name = "add");

@@ -54,6 +68,10 @@ TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s);

 TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s);

+// Split `input` along the first dimension into 3 tensors.
+TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s,
+                     const char* name = "split3");
+
 bool IsPlaceholder(const tensorflow::NodeDef& node_def);

 bool IsScalarConst(const tensorflow::NodeDef& node_def, int v);
@@ -66,6 +84,8 @@ bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def);

 bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def);

+bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def);
+
 bool GetAttrValue(TF_Operation* oper, const char* attr_name,
                   tensorflow::AttrValue* attr_value, TF_Status* s);
tensorflow/c/eager/c_api.cc:

@@ -151,10 +151,11 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
   return TF_SessionListDevices(ctx->session, status);
 }

-TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t) {
-  return new TFE_TensorHandle(
-      tensorflow::TensorCApi::MakeTensor(t->dtype, t->shape, t->buffer),
-      nullptr);
+TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
+  tensorflow::Tensor tensor;
+  status->status = tensorflow::TF_TensorToTensor(t, &tensor);
+  if (!status->status.ok()) return nullptr;
+  return new TFE_TensorHandle(tensor, nullptr);
 }

 void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { delete h; }
@ -20,6 +20,25 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
|
||||
// Macro to control visibility of exported symbols in the shared library (.so,
|
||||
// .dylib, .dll).
|
||||
// This duplicates the TF_EXPORT macro definition in
|
||||
// tensorflow/core/platform/macros.h in order to keep this .h file independent
|
||||
// of any other includes.$a
|
||||
#ifdef SWIG
|
||||
#define TF_CAPI_EXPORT
|
||||
#else
|
||||
#if defined(COMPILER_MSVC)
|
||||
#ifdef TF_COMPILE_LIBRARY
|
||||
#define TF_CAPI_EXPORT __declspec(dllexport)
|
||||
#else
|
||||
#define TF_CAPI_EXPORT __declspec(dllimport)
|
||||
#endif // TF_COMPILE_LIBRARY
|
||||
#else
|
||||
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
|
||||
#endif // COMPILER_MSVC
|
||||
#endif // SWIG
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
@ -30,11 +49,11 @@ extern "C" {
|
||||
// TODO(ashankar): Merge with TF_Session?
|
||||
typedef struct TFE_Context TFE_Context;
|
||||
|
||||
extern TFE_Context* TFE_NewContext(const TF_SessionOptions* opts,
|
||||
TF_Status* status);
|
||||
extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status);
|
||||
extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
|
||||
TF_Status* status);
|
||||
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(const TF_SessionOptions* opts,
|
||||
TF_Status* status);
|
||||
TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status);
|
||||
TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
|
||||
TF_Status* status);
|
||||
|
||||
// A handle to a tensor on a device.
|
||||
//
|
||||
@ -43,14 +62,15 @@ extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
|
||||
// placed in memory of different devices or remote address spaces.
|
||||
typedef struct TFE_TensorHandle TFE_TensorHandle;
|
||||
|
||||
extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t);
|
||||
extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
|
||||
extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
|
||||
extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h);
|
||||
extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index);
|
||||
extern const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h);
|
||||
extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
|
||||
TF_Status* status);
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t,
|
||||
TF_Status* status);
|
||||
TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
|
||||
TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
|
||||
TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h);
|
||||
TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index);
|
||||
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h);
|
||||
TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
|
||||
TF_Status* status);
|
||||
|
||||
// Create a new TFE_TensorHandle with the same contents as 'h' but placed
|
||||
// in the memory of the device name 'device_name'.
|
||||
@ -58,10 +78,10 @@ extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
|
||||
// that shares the underlying buffer. Otherwise, it currently requires at least
|
||||
// one of the source or destination devices to be CPU (i.e., for the source or
|
||||
// destination tensor to be placed in host memory).
|
||||
extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
TFE_Context* ctx,
|
||||
const char* device_name,
|
||||
TF_Status* status);
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
TFE_Context* ctx,
|
||||
const char* device_name,
|
||||
TF_Status* status);
|
||||
|
||||
// Description of the TensorFlow op to execute.
|
||||
//
|
||||
@ -76,49 +96,49 @@ extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
// the additional sanity checks there seem unnecessary;
|
||||
typedef struct TFE_Op TFE_Op;
|
||||
|
||||
extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status);
|
||||
extern void TFE_DeleteOp(TFE_Op* op);
|
||||
TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status);
|
||||
TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op);
|
||||
|
||||
// TODO(ashankar): TFE_OpSetDevice and TFE_Execute should not have a TFE_Context
|
||||
// parameter. Instead, the TFE_Context should be captured when creating the
|
||||
// TFE_Op.
|
||||
extern void TFE_OpSetDevice(TFE_Op* op, TFE_Context* ctx,
|
||||
const char* device_name, TF_Status* status);
|
||||
TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, TFE_Context* ctx,
|
||||
const char* device_name, TF_Status* status);
|
||||
|
||||
extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status);
|
||||
TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status);
|
||||
|
||||
extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
|
||||
unsigned char* is_list, TF_Status* status);
|
||||
TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
|
||||
unsigned char* is_list, TF_Status* status);
|
||||
|
||||
extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name,
|
||||
const char* value);
|
||||
extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value);
|
||||
extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value);
|
||||
extern void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name,
|
||||
unsigned char value);
|
||||
extern void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name,
|
||||
TF_DataType value);
|
||||
TF_CAPI_EXPORT extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name,
|
||||
const char* value);
|
||||
TF_CAPI_EXPORT extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value);
|
||||
TF_CAPI_EXPORT extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value);
TF_CAPI_EXPORT extern void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name,
                                             unsigned char value);
TF_CAPI_EXPORT extern void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name,
                                             TF_DataType value);
// If the number of dimensions is unknown, `num_dims` must be set to
// -1 and `dims` can be null. If a dimension is unknown, the
// corresponding entry in the `dims` array must be -1.
extern void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name,
                               const int64_t* dims, const int num_dims,
                               TF_Status* out_status);
TF_CAPI_EXPORT extern void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name,
                                              const int64_t* dims, const int num_dims,
                                              TF_Status* out_status);

extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
                                    const char** value, int num_values);
extern void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
                                 const int64_t* values, int num_values);
extern void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
                                   const float* values, int num_values);
extern void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
                                  const unsigned char* values, int num_values);
extern void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
                                  const TF_DataType* values, int num_values);
extern void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
                                   const int64_t** dims, const int* num_dims,
                                   int num_values, TF_Status* out_status);
TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
                                                   const char** value, int num_values);
TF_CAPI_EXPORT extern void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
                                                const int64_t* values, int num_values);
TF_CAPI_EXPORT extern void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
                                                  const float* values, int num_values);
TF_CAPI_EXPORT extern void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
                                                 const unsigned char* values, int num_values);
TF_CAPI_EXPORT extern void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
                                                 const TF_DataType* values, int num_values);
TF_CAPI_EXPORT extern void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
                                                  const int64_t** dims, const int* num_dims,
                                                  int num_values, TF_Status* out_status);

// Execute the operation defined by 'op' and return handles to computed
// tensors in 'retvals'.
@ -128,14 +148,14 @@ extern void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
//
// On return, 'num_retvals' will be set to the actual number of outputs
// returned by the operation.
extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals,
                        int* num_retvals, TF_Status* status);
TF_CAPI_EXPORT extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals,
                                       int* num_retvals, TF_Status* status);

// Add a function (serialized FunctionDef protocol buffer) to ctx so
// that it can be invoked using TFE_Execute.
extern void TFE_ContextAddFunctionDef(TFE_Context* ctx,
                                      const char* serialized_function_def,
                                      size_t size, TF_Status* status);
TF_CAPI_EXPORT extern void TFE_ContextAddFunctionDef(TFE_Context* ctx,
                                                     const char* serialized_function_def,
                                                     size_t size, TF_Status* status);

#ifdef __cplusplus
} /* end extern "C" */
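For reference, a minimal sketch of how the newly exported eager symbols compose
end to end. This is not part of the commit: TFE_NewContext and TFE_NewOp are
assumed from elsewhere in c_api.h, 'opts' is a TF_SessionOptions*, and 'm' is a
TFE_TensorHandle* built as in TestMatrixTensorHandle() below.

  TF_Status* status = TF_NewStatus();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
  TFE_OpAddInput(matmul, m, status);
  TFE_OpAddInput(matmul, m, status);
  TFE_OpSetAttrType(matmul, "T", TF_FLOAT);     // exported above
  TFE_OpSetAttrBool(matmul, "transpose_a", 0);  // unsigned char 0/1
  // For shape-valued attrs, num_dims == -1 means "rank unknown" and a -1
  // entry in dims means "dimension unknown", per the comment above.

  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);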
@ -34,7 +34,9 @@ TFE_TensorHandle* TestMatrixTensorHandle() {
  TF_Tensor* t = TF_AllocateTensor(
      TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandle(t);
  TF_Status* status = TF_NewStatus();
  TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  return th;
}
@ -383,7 +385,8 @@ TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
  memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));

  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
      value_handle(TFE_NewTensorHandle(t.get()), TFE_DeleteTensorHandle);
      value_handle(TFE_NewTensorHandle(t.get(), status), TFE_DeleteTensorHandle);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  TFE_OpAddInput(op, value_handle.get(), status);
  if (TF_GetCode(status) != TF_OK) return nullptr;

@ -2,6 +2,7 @@ VERS_1.0 {
  # Export symbols in c_api.h.
  global:
    *TF_*;
    *TFE_*;

  # Hide everything else.
  local:

@ -77,6 +77,10 @@ class SymbolicGradientBuilder {
  Status CallGradFunction(const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs);

  // Returns a list mapping whether each node in the graph is reachable
  // from outputs_. Keyed by node id.
  std::vector<bool> GetReachableNodes();

  const Scope& scope_;
  const ops::GradOpRegistry* registry_;
@ -143,11 +147,36 @@ Status SymbolicGradientBuilder::BackpropAlongEdge(const Output& dst_grad,
  return Status::OK();
}

std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
  std::vector<bool> reachable_nodes(scope_.graph()->num_node_ids(), false);
  std::deque<Node*> queue;
  for (const Output& out : outputs_) {
    if (!reachable_nodes[out.node()->id()]) {
      queue.push_back(out.node());
      reachable_nodes[out.node()->id()] = true;
    }
  }

  while (!queue.empty()) {
    Node* n = queue.front();
    queue.pop_front();
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge()) continue;
      queue.push_back(e->src());
      reachable_nodes[e->src()->id()] = true;
    }
  }
  return reachable_nodes;
}

Status SymbolicGradientBuilder::Initialize() {
  if (outputs_.size() != grad_inputs_.size()) {
    return errors::InvalidArgument(
        "Must specify a gradient input for each output.");
  }
  std::vector<bool> reachable_nodes = GetReachableNodes();
  // TODO(theflofly) Check that inputs_ are reachable from
  // outputs_ using reachable_nodes
  grad_outputs_->clear();
  grad_outputs_->resize(inputs_.size());
  // Populate `output_nodes_` from node ids in `outputs_`.
@ -188,12 +217,15 @@ Status SymbolicGradientBuilder::Initialize() {
    if (output_nodes_.find(n->id()) == output_nodes_.end()) {
      // Internal node: continue BFS along connected outputs.
      for (const Edge* e : n->out_edges()) {
        if (e->IsControlEdge()) continue;
        ++num_expected_backprops;
        // If a node is not reachable from outputs_,
        // we don't expect it to receive a backpropagated gradient.
        // It will not be counted in num_expected_backprops.
        if (e->IsControlEdge() || !reachable_nodes[e->dst()->id()]) continue;
        if (visited.find(e->dst()) == visited.end()) {
          queue.push_back(e->dst());
          visited.insert(e->dst());
        }
        ++num_expected_backprops;
      }
    } else {
      // Output node: stop BFS and update `num_expected_backprops` for
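Concretely, for the UnreachableEdgeGrad tests added below, the graph is

    x --> m1 <-- y --> m2 <-- z

With outputs_ = {m1}, the BFS over in_edges marks {m1, x, y} reachable but
never visits m2, so the y -> m2 edge is excluded from num_expected_backprops
and backprop can terminate instead of waiting on a gradient that never arrives.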
@ -364,6 +364,73 @@ TEST_F(GradientsTest, MultipleNodeOutputGrads) {
      test::AsTensor<int>({60, 61, 62, 63, 66, 66, 66, 67}, {4, 2}));
}

TEST_F(GradientsTest, UnreachableEdgeGradOneOutput) {
  auto x = Variable(scope_test_, {2, 3}, DT_DOUBLE);
  auto x_const = Const(scope_test_, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
  auto x_assign = Assign(scope_test_, x, x_const);

  auto y = Variable(scope_test_, {3, 1}, DT_DOUBLE);
  auto y_const = Const(scope_test_, {{1.0}, {2.0}, {3.0}});
  auto y_assign = Assign(scope_test_, y, y_const);

  auto m1 = MatMul(scope_test_, x, y);

  auto z = Variable(scope_test_, {1, 3}, DT_DOUBLE);
  auto z_const = Const(scope_test_, {{9.0, 10.0, 11.0}});
  auto z_assign = Assign(scope_test_, z, z_const);

  auto m2 = MatMul(scope_test_, y, z);

  auto dm1 = Const(scope_test_, {{0.5}, {0.5}});

  std::vector<Output> grad_outputs;
  TF_ASSERT_OK(
      AddSymbolicGradients(scope_test_, {m1}, {y}, {dm1}, &grad_outputs));

  std::vector<Tensor> outputs;
  test::GetTensors(scope_test_, {x_assign, y_assign, z_assign},
                   {grad_outputs[0]}, &outputs);
  // dL/dy = xT * dm1
  test::ExpectTensorNear<double>(
      outputs[0], test::AsTensor<double>({2.5, 3.5, 4.5}, {3, 1}), 1e-5);
}

TEST_F(GradientsTest, UnreachableEdgeGradTwoOutputs) {
  auto x = Variable(scope_test_, {2, 3}, DT_DOUBLE);
  auto x_const = Const(scope_test_, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
  auto x_assign = Assign(scope_test_, x, x_const);

  auto y = Variable(scope_test_, {3, 1}, DT_DOUBLE);
  auto y_const = Const(scope_test_, {{1.0}, {2.0}, {3.0}});
  auto y_assign = Assign(scope_test_, y, y_const);

  auto m1 = MatMul(scope_test_, x, y);

  auto z = Variable(scope_test_, {1, 3}, DT_DOUBLE);
  auto z_const = Const(scope_test_, {{9.0, 10.0, 11.0}});
  auto z_assign = Assign(scope_test_, z, z_const);

  auto m2 = MatMul(scope_test_, y, z);

  auto dm1 = Const(scope_test_, {{0.5}, {0.5}});
  auto dm2 =
      Const(scope_test_, {{0.5, 0.5, 0.5}, {0.6, 0.7, 0.8}, {0.6, 0.7, 0.9}});

  std::vector<Output> grad_outputs;
  TF_ASSERT_OK(AddSymbolicGradients(scope_test_, {m1, m2}, {y}, {dm1, dm2},
                                    &grad_outputs));

  std::vector<Tensor> outputs;
  test::GetTensors(scope_test_, {x_assign, y_assign, z_assign},
                   {grad_outputs[0]}, &outputs);

  // The gradients from m1 and m2 are summed to compute the gradient
  // w.r.t. y:
  // dL/dy = xT * dm1 + dm2 * zT
  test::ExpectTensorNear<double>(
      outputs[0], test::AsTensor<double>({17.5, 24.7, 26.8}, {3, 1}), 1e-5);
}

// StopGradientSingleOutputMultiEdgeTest tests combinations of valid and
// 'NoGradient' (induced by StopGradient op) returned along multiple edges from
// a single node's output.

@ -36,5 +36,19 @@ void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
  *out = outputs[0];
}

void GetTensors(const Scope& scope, const std::vector<Output>& assign_vars,
                OutputList tensors, std::vector<Tensor>* out) {
  ClientSession session(scope);
  TF_CHECK_OK(session.Run(assign_vars, nullptr));
  TF_CHECK_OK(session.Run(tensors, out));
}

void GetTensor(const Scope& scope, const std::vector<Output>& assign_vars,
               Output tensor, Tensor* out) {
  std::vector<Tensor> outputs;
  GetTensors(scope, assign_vars, {std::move(tensor)}, &outputs);
  *out = outputs[0];
}

}  // end namespace test
}  // end namespace tensorflow

@ -26,9 +26,21 @@ namespace test {
void GetTensors(const Scope& scope, OutputList tensors,
                std::vector<Tensor>* out);

// Computes the outputs listed in 'tensors', returns the tensors in 'out'.
// 'assign_vars' are extra outputs that should be run first,
// e.g. to assign values to variables.
void GetTensors(const Scope& scope, const std::vector<Output>& assign_vars,
                OutputList tensors, std::vector<Tensor>* out);

/// Computes the output 'tensor', returning the resulting tensor in 'out'.
void GetTensor(const Scope& scope, Output tensor, Tensor* out);

// Computes the output 'tensor', returning the resulting tensor in 'out'.
// 'assign_vars' are extra outputs that should be run first,
// e.g. to assign values to variables.
void GetTensor(const Scope& scope, const std::vector<Output>& assign_vars,
               Output tensor, Tensor* out);

}  // namespace test
}  // namespace tensorflow
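A short usage sketch of the new overloads (variable names are hypothetical).
Both runs go through the same ClientSession, so variable state written by
'assign_vars' is visible when 'tensors' are evaluated; that is the point of
the overload.

  // Assign x and y first, then fetch a gradient that depends on them.
  std::vector<Tensor> outputs;
  test::GetTensors(scope, {x_assign, y_assign}, {grad_outputs[0]}, &outputs);

  Tensor grad;
  test::GetTensor(scope, {x_assign, y_assign}, grad_outputs[0], &grad);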
@ -687,6 +687,72 @@ Status MeanGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Mean", MeanGrad);

Status MinOrMaxGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  // The partial derivative for any input along a "reduced" dimension
  // is 1 when it is the min (or max) and 0 everywhere else. So the
  // gradient calculation is identical for both operators.
  //
  // There's a special case for propagating gradients when there are
  // multiple minima (or maxima): we choose to divide the gradient
  // equally among all matching inputs.
  //
  // See the discussion in
  // https://github.com/tensorflow/tensorflow/issues/4886#issuecomment-256836063
  // for details.

  // Running example:
  // input: [[5, 5, 5],
  //         [1, 2, -3]]
  // reduction_indices: [1]
  auto input = op.input(0);
  auto reduction_indices = op.input(1);

  // [2, 3]
  auto input_shape = Shape(scope, input);

  // [2, 1]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, reduction_indices);

  // for op=min (say)
  // output = [5, -3]
  // y = [[5],
  //      [-3]]
  auto y = Reshape(scope, op.output(0), output_shape_kept_dims);

  // reshape([g1, g2], [2, 1]) = [[g1],
  //                              [g2]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // indicators = equal(y, input)
  //            = equal([[5],   [[5, 5, 5],
  //                     [-3]],  [1, 2, -3]])
  //            = [[1, 1, 1],
  //               [0, 0, 1]]
  auto indicators = Cast(scope, Equal(scope, y, input), grad_inputs[0].type());

  // [[3],
  //  [1]]
  auto num_selected = Reshape(scope, Sum(scope, indicators, reduction_indices),
                              output_shape_kept_dims);

  // [[1/3, 1/3, 1/3],
  //  [0, 0, 1]]
  auto scale = Div(scope, indicators, num_selected);

  // [[g1/3, g1/3, g1/3],
  //  [0, 0, g2]]
  grad_outputs->push_back(Mul(scope, scale, grad));

  // Stop propagation along reduction_indices.
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Min", MinOrMaxGrad);
REGISTER_GRADIENT_OP("Max", MinOrMaxGrad);
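A standalone arithmetic check of the running example above (plain C++, not
TensorFlow code): each tied minimum in a row receives an equal share of that
row's upstream gradient, and non-minima receive zero.

  #include <cstdio>

  int main() {
    const double input[2][3] = {{5, 5, 5}, {1, 2, -3}};
    const double y[2] = {5, -3};     // reduce_min along axis 1
    const double g[2] = {0.3, 0.2};  // upstream gradients g1, g2
    for (int i = 0; i < 2; ++i) {
      int num_selected = 0;
      for (int j = 0; j < 3; ++j) num_selected += input[i][j] == y[i];
      for (int j = 0; j < 3; ++j) {
        double dx = (input[i][j] == y[i]) ? g[i] / num_selected : 0.0;
        std::printf("%g ", dx);  // row 0: g1/3 g1/3 g1/3; row 1: 0 0 g2
      }
      std::printf("\n");
    }
  }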
// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,

@ -955,6 +955,55 @@ TEST_F(NaryGradTest, Mean) {
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, Min) {
  TensorShape x_shape({2, 3});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Min(scope_, x, {-1});
  // y's shape is the result of reducing x along axis -1 (i.e., axis 1)
  TensorShape y_shape({2});
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
  RunTest(x, x_init_value, y, y_shape);
}

TEST_F(NaryGradTest, Max) {
  TensorShape x_shape({2, 3});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Max(scope_, x, {-1});
  // y's shape is the result of reducing x along axis -1 (i.e., axis 1)
  TensorShape y_shape({2});
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
  RunTest(x, x_init_value, y, y_shape);
}

TEST_F(NaryGradTest, MinMulti) {
  // Test gradient when there are multiple minima.
  // Note that we cannot directly use a test Tensor with multiple
  // minima, as the numeric estimator will calculate incorrect
  // gradients when perturbing each entry in the Tensor (which then
  // changes how many minima exist).
  // Instead, we use a single input that broadcast-multiplies a larger
  // tensor with equal values, and apply reduce_min to the multiplied
  // result (see the worked identity after MaxMulti below).
  TensorShape x_shape({1});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
  auto y = Min(scope_, all_same, {0});
  // all_same is a [3] shaped tensor; reducing it along dimension 0
  // yields a [1] shaped y
  TensorShape y_shape({1});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, MaxMulti) {
  TensorShape x_shape({1});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
  auto y = Max(scope_, all_same, {0});
  TensorShape y_shape({1});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}
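Why the broadcast trick keeps the numeric and symbolic gradients consistent (a
short worked check, not part of the commit): all three entries of all_same
equal x, so the min is x with all three entries tied. MinOrMaxGrad splits the
upstream gradient g equally across the ties, and Mul's gradient sums the three
shares back up:

  \frac{\partial L}{\partial x}
    = \sum_{j=1}^{3} \frac{g}{3} \cdot \frac{\partial (1 \cdot x)}{\partial x}
    = 3 \cdot \frac{g}{3} = g

Perturbing x moves all three entries together, so the tie count never changes
and the finite-difference estimate agrees.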
TEST_F(NaryGradTest, AddN) {
  TensorShape shape({3, 2, 5});
  std::vector<Output> xs;

@ -52,6 +52,12 @@ class BinaryOpsTest(XLATestCase):

  def testFloatOps(self):
    for dtype in self.float_types:
      self._testBinary(
          lambda x, y: math_ops.approximate_equal(x, y, tolerance=0.0001),
          np.array([[[[-1, 2.00009999], [-3, 4.01]]]], dtype=dtype),
          np.array([[[[-1.001, 2], [-3.00009, 4]]]], dtype=dtype),
          expected=np.array([[[[False, True], [True, False]]]], dtype=dtype))

      self._testBinary(
          gen_math_ops._real_div,
          np.array([3, 3, -1.5, -8, 44], dtype=dtype),
@ -82,6 +88,12 @@ class BinaryOpsTest(XLATestCase):
          dtype(4),
          expected=np.array([[16], [81]], dtype=dtype))

      self._testBinary(
          gen_math_ops._reciprocal_grad,
          np.array([4, -3, -2, 1], dtype=dtype),
          np.array([5, -6, 7, -8], dtype=dtype),
          expected=np.array([-80, 54, -28, 8], dtype=dtype))

      self._testBinary(
          gen_math_ops._sigmoid_grad,
          np.array([4, 3, 2, 1], dtype=dtype),
@ -107,6 +119,13 @@ class BinaryOpsTest(XLATestCase):
          expected=np.array(
              [3.97322869, 2.99258232, 1.99817801, 0.99966466], dtype=dtype))

      self._testBinary(
          gen_nn_ops._softsign_grad,
          np.array([4, 3, 2, 1], dtype=dtype),
          np.array([5, 6, 7, 8], dtype=dtype),
          expected=np.array(
              [0.11111111, 0.06122449, 0.03125, 0.01234568], dtype=dtype))

      self._testBinary(
          gen_math_ops._tanh_grad,
          np.array([4, 3, 2, 1], dtype=dtype),

@ -888,6 +888,16 @@ TEST_F(OpTest, Any) {
  });
}

TEST_F(OpTest, ApproximateEqual) {
  Repeatedly([this]() {
    auto dims = RandomDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ApproximateEqual")
                                             .RandomInput(DT_FLOAT, dims)
                                             .RandomInput(DT_FLOAT, dims)
                                             .Attr("T", DT_FLOAT));
  });
}

TEST_F(OpTest, Asinh) {
  Repeatedly([this]() {
    return ExpectTfAndXlaOutputsAreClose(
@ -1662,11 +1672,9 @@ TEST_F(OpTest, GreaterEqual) {

TEST_F(OpTest, L2Loss) {
  Repeatedly([this]() {
    DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
    // TODO(b/31644876): scalars currently crash.
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("L2Loss")
                                             .RandomInput(type, RandomDims(1))
                                             .Attr("T", type));
    DataType type = DT_FLOAT;
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("L2Loss").RandomInput(type).Attr("T", type));
  });
}

@ -2165,6 +2173,15 @@ TEST_F(OpTest, Reciprocal) {
  });
}

TEST_F(OpTest, ReciprocalGrad) {
  Repeatedly([this]() {
    std::vector<int64> dims = RandomDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReciprocalGrad")
                                             .RandomInput(DT_FLOAT, dims)
                                             .RandomInput(DT_FLOAT, dims)
                                             .Attr("T", DT_FLOAT));
  });
}
TEST_F(OpTest, Relu) {
  Repeatedly([this]() {
    return ExpectTfAndXlaOutputsAreClose(
@ -2250,6 +2267,13 @@ TEST_F(OpTest, ReverseV2) {
  });
}

TEST_F(OpTest, Rint) {
  Repeatedly([this]() {
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Rint").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
  });
}

TEST_F(OpTest, Round) {
  Repeatedly([this]() {
    return ExpectTfAndXlaOutputsAreClose(
@ -2402,6 +2426,23 @@ TEST_F(OpTest, SoftplusGrad) {
  });
}

TEST_F(OpTest, Softsign) {
  Repeatedly([this]() {
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Softsign").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
  });
}

TEST_F(OpTest, SoftsignGrad) {
  Repeatedly([this]() {
    std::vector<int64> dims = RandomDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SoftsignGrad")
                                             .RandomInput(DT_FLOAT, dims)
                                             .RandomInput(DT_FLOAT, dims)
                                             .Attr("T", DT_FLOAT));
  });
}

TEST_F(OpTest, SpaceToBatch) {
  Repeatedly([this]() {
    std::vector<int64> block_dims = RandomDims(4, 4, 0, 5);

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

@ -161,12 +163,17 @@ class UnaryOpsTest(XLATestCase):
          np.array([[-1.7, 1.2]], dtype=dtype),
          expected=np.array([[-2, 1]], dtype=dtype))

      self._assertOpOutputMatchesExpected(
          math_ops.is_finite,
          np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
                   dtype=dtype),
          expected=np.array([[0, 1, 1, 1, 1, 1, 1, 0, 0]], dtype=np.bool))

      # Tests for tf.nn ops.
      self._assertOpOutputMatchesExpected(
          nn_ops.l2_loss, np.array([[[]]], dtype=dtype), expected=dtype(0))

      # TODO(b/31644876): enable this test case when fixed.
      # self._assertOpOutputMatchesExpected(tf.nn.l2_loss, dtype(4), dtype(10))
      self._assertOpOutputMatchesExpected(nn_ops.l2_loss, dtype(4), dtype(8))

      self._assertOpOutputMatchesExpected(
          nn_ops.l2_loss, np.array([[-2, 4]], dtype=dtype), expected=dtype(10))
@ -198,6 +205,12 @@ class UnaryOpsTest(XLATestCase):
          np.array([[1e-14, 1e-15, 0.6]], dtype=dtype),
          expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)))

      self._assertOpOutputMatchesExpected(
          math_ops.rint,
          np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5],
                    [0.5, 1.5, 2.5, 3.5]], dtype=dtype),
          expected=np.array([[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]],
                            dtype=dtype))
      self._assertOpOutputMatchesExpected(
          math_ops.round,
          np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5],
@ -301,6 +314,12 @@ class UnaryOpsTest(XLATestCase):
          np.array([[-2, 0, 8]], dtype=dtype),
          expected=np.array([[0.126928, 0.6931472, 8.0003354]], dtype=dtype))

      self._assertOpOutputMatchesExpected(
          nn_ops.softsign,
          np.array([[-2, -1, 0, 1, 2]], dtype=dtype),
          expected=np.array([[-0.66666669, -0.5, 0, 0.5, 0.66666669]],
                            dtype=dtype))

      self._assertOpOutputMatchesExpected(
          math_ops.is_finite,
          np.array(
@ -335,6 +354,23 @@ class UnaryOpsTest(XLATestCase):
          np.array([[4, 3], [2, 1]], dtype=dtype),
          expected=np.array([[1, 1], [1, 1]], dtype=dtype))

  # TODO(phawkins): these tests fail unless fastmath optimizations
  # are disabled. Use more robust IsInf/IsNaN detection and enable these
  # tests.
  @unittest.skip("test case fails in fast-math mode")
  def testIsInfAndIsNan(self):
    for dtype in self.float_types:
      self._assertOpOutputMatchesExpected(
          math_ops.is_inf,
          np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
                   dtype=dtype),
          expected=np.array([[1, 0, 0, 0, 0, 0, 0, 1, 0]], dtype=np.bool))
      self._assertOpOutputMatchesExpected(
          math_ops.is_nan,
          np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
                   dtype=dtype),
          expected=np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.bool))

  def testLogicalOps(self):
    self._assertOpOutputMatchesExpected(
        math_ops.logical_not,

@ -31,7 +31,6 @@ tf_kernel_library(
        "function_ops.cc",
        "gather_op.cc",
        "identity_op.cc",
        "is_finite_op.cc",
        "l2loss_op.cc",
        "lrn_ops.cc",
        "matmul_op.cc",

@ -102,6 +102,7 @@ XLA_MAKE_BINARY(Mod, b->Rem(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Maximum, b->Max(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Minimum, b->Min(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(RealDiv, b->Div(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(ReciprocalGrad, b->Neg(b->Mul(rhs, b->Mul(lhs, lhs))));
XLA_MAKE_BINARY(
    RsqrtGrad,
    b->Mul(b->Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)),
@ -140,6 +141,11 @@ XLA_MAKE_BINARY(SoftplusGrad,
                b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)),
                                   XlaHelpers::One(b, input_type(1)))));

// softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2
XLA_MAKE_BINARY(SoftsignGrad,
                b->Div(lhs, Square(b, b->Add(XlaHelpers::One(b, input_type(0)),
                                             b->Abs(rhs)))));
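A standalone arithmetic check of the softsign_grad formula above against the
expected values in binary_ops_test.py (plain C++, not part of the commit):

  #include <cmath>
  #include <cstdio>

  int main() {
    const double grads[] = {4, 3, 2, 1};
    const double features[] = {5, 6, 7, 8};
    for (int i = 0; i < 4; ++i) {
      double denom = 1.0 + std::fabs(features[i]);
      // Prints 0.111111, 0.0612245, 0.03125, 0.0123457, matching the test.
      std::printf("%g\n", grads[i] / (denom * denom));
    }
  }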
XLA_MAKE_BINARY(TanhGrad, b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)),
                                             b->Mul(lhs, lhs))));

@ -147,5 +153,24 @@ XLA_MAKE_BINARY(Pow, b->Pow(lhs, rhs, extend_dimensions));

#undef XLA_MAKE_BINARY

class ApproximateEqualOp : public XlaOpKernel {
 public:
  explicit ApproximateEqualOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("tolerance", &tolerance_));
  }

  // Computes abs(x - y) < tolerance element-wise.
  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationBuilder* b = ctx->builder();
    auto result = b->Lt(b->Abs(b->Sub(ctx->Input(0), ctx->Input(1))),
                        XlaHelpers::FloatLiteral(b, input_type(0), tolerance_));
    ctx->SetOutput(0, result);
  }

 private:
  float tolerance_;
};
REGISTER_XLA_OP(Name("ApproximateEqual"), ApproximateEqualOp);

}  // namespace
}  // namespace tensorflow
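A quick check of the absolute-tolerance semantics lowered above, mirroring the
input pairs in binary_ops_test.py (plain C++ sketch, not part of the commit):

  #include <cmath>
  #include <cstdio>

  int main() {
    const double x[] = {-1, 2.00009999, -3, 4.01};
    const double y[] = {-1.001, 2, -3.00009, 4};
    const double tolerance = 0.0001;
    for (int i = 0; i < 4; ++i)  // Prints 0 1 1 0, matching the test.
      std::printf("%d ", std::fabs(x[i] - y[i]) < tolerance);
  }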
@ -1,43 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {
namespace {

class IsFiniteOp : public XlaOpKernel {
 public:
  explicit IsFiniteOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationDataHandle input = ctx->Input(0);
    ctx->SetOutput(0, ctx->builder()->IsFinite(input));
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(IsFiniteOp);
};

REGISTER_XLA_OP(Name("IsFinite"), IsFiniteOp);

}  // anonymous namespace
}  // namespace tensorflow
@ -73,8 +73,12 @@ XLAJIT_MAKE_UNARY(Exp, b->Exp(x));
XLAJIT_MAKE_UNARY(Expm1, b->Sub(b->Exp(x), XlaHelpers::One(b, input_type(0))));

XLAJIT_MAKE_UNARY(Floor, b->Floor(x));
// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
XLAJIT_MAKE_UNARY(Sign, b->Sign(x));
XLAJIT_MAKE_UNARY(IsFinite, b->IsFinite(x));
XLAJIT_MAKE_UNARY(IsInf, b->Eq(b->Abs(x),
                               XlaHelpers::FloatLiteral(
                                   b, input_type(0),
                                   std::numeric_limits<double>::infinity())));
XLAJIT_MAKE_UNARY(IsNan, b->Ne(x, x));
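The x != x and abs(x) == inf lowerings above are exactly the patterns that
fast-math rewrites away, which is why testIsInfAndIsNan is skipped earlier in
this diff. One robust alternative, sketched here as plain C++ outside of XLA
(an illustration of the idea, not what the commit implements), classifies
floats by their IEEE-754 bit pattern:

  #include <cstdint>
  #include <cstdio>
  #include <cstring>
  #include <limits>

  // Exponent bits all ones => inf (zero mantissa) or NaN (nonzero mantissa).
  // Bit tests are immune to fast-math simplification of x != x.
  bool IsNanBits(float x) {
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    return (bits & 0x7f800000u) == 0x7f800000u && (bits & 0x007fffffu) != 0;
  }

  bool IsInfBits(float x) {
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    return (bits & 0x7fffffffu) == 0x7f800000u;
  }

  int main() {
    std::printf("%d %d\n",
                IsNanBits(std::numeric_limits<float>::quiet_NaN()),   // 1
                IsInfBits(std::numeric_limits<float>::infinity()));   // 1
  }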
// Return 1/x
XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x));
XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x));
@ -105,6 +109,12 @@ static xla::ComputationDataHandle Round(xla::ComputationBuilder* b,
                 b->Add(round_val, one), round_val);
}

XLAJIT_MAKE_UNARY(Rint, Round(b, input_type(0), x));
XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x));

XLAJIT_MAKE_UNARY(Rsqrt,
                  b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5)));

// Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2.
static xla::ComputationDataHandle Sigmoid(xla::ComputationBuilder* b,
                                          DataType dtype,
@ -112,16 +122,19 @@ static xla::ComputationDataHandle Sigmoid(xla::ComputationBuilder* b,
  auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5);
  return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x))));
}
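A one-line derivation of the rescaled-tanh identity used above (added for the
reader, not in the source), using tanh(x/2) = (e^x - 1) / (e^x + 1):

  \frac{\tanh(x/2) + 1}{2}
    = \frac{1}{2}\left(\frac{e^{x} - 1}{e^{x} + 1} + 1\right)
    = \frac{e^{x}}{e^{x} + 1}
    = \frac{1}{1 + e^{-x}} = \sigma(x)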
XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x));
XLAJIT_MAKE_UNARY(Rsqrt,
                  b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5)));
XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(b, input_type(0), x));

// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
XLAJIT_MAKE_UNARY(Sign, b->Sign(x));
XLAJIT_MAKE_UNARY(Sinh,
                  b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))),
                         XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
XLAJIT_MAKE_UNARY(Softplus,
                  b->Log(b->Add(b->Exp(x), XlaHelpers::One(b, input_type(0)))));
// softsign(x) = x / (abs(x) + 1)
XLAJIT_MAKE_UNARY(Softsign,
                  b->Div(x,
                         b->Add(b->Abs(x), XlaHelpers::One(b, input_type(0)))));
XLAJIT_MAKE_UNARY(Sqrt,
                  b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
XLAJIT_MAKE_UNARY(Square, b->Mul(x, x));

@ -847,6 +847,7 @@ cc_test(
    srcs = ["hlo_ordering_test.cc"],
    deps = [
        ":hlo",
        ":hlo_dataflow_analysis",
        ":hlo_ordering",
        ":hlo_scheduling",
        "//tensorflow/compiler/xla:shape_util",

@ -241,7 +241,7 @@ Status Executor::Run() {
        completion_queue_.pop_front();
        break;
      }
    } while (1);
    } while (true);
    TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
                        assignment_->GetUniqueTopLevelSlice(instruction));
    void* result_buffer =

@ -24,16 +24,14 @@ limitations under the License.

namespace xla {

Status DfsHloVisitor::HandleElementwiseUnary(HloInstruction* hlo,
                                             HloOpcode opcode) {
Status DfsHloVisitor::HandleElementwiseUnary(HloInstruction* hlo) {
  return Unimplemented("DfsHloVisitor::HandleElementwiseUnary: %s",
                       HloOpcodeString(opcode).c_str());
                       HloOpcodeString(hlo->opcode()).c_str());
}

Status DfsHloVisitor::HandleElementwiseBinary(HloInstruction* hlo,
                                              HloOpcode opcode) {
Status DfsHloVisitor::HandleElementwiseBinary(HloInstruction* hlo) {
  return Unimplemented("DfsHloVisitor::HandleElementwiseBinary: %s",
                       HloOpcodeString(opcode).c_str());
                       HloOpcodeString(hlo->opcode()).c_str());
}

DfsHloVisitor::VisitState DfsHloVisitor::GetVisitState(

@ -63,37 +63,37 @@ class DfsHloVisitor {
  // These routines are self-descriptive, see class comment for usage
  // information.

  virtual Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode);
  virtual Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode);
  virtual Status HandleElementwiseUnary(HloInstruction* hlo);
  virtual Status HandleElementwiseBinary(HloInstruction* hlo);
  virtual Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
                             HloInstruction* arg, HloInstruction* max) = 0;
  virtual Status HandleSelect(HloInstruction* select, HloInstruction* pred,
                              HloInstruction* on_true,
                              HloInstruction* on_false) = 0;
  virtual Status HandleMaximum(HloInstruction* maximum) {
    return HandleElementwiseBinary(maximum, HloOpcode::kMaximum);
    return HandleElementwiseBinary(maximum);
  }
  virtual Status HandleMinimum(HloInstruction* minimum) {
    return HandleElementwiseBinary(minimum, HloOpcode::kMinimum);
    return HandleElementwiseBinary(minimum);
  }
  virtual Status HandleConcatenate(
      HloInstruction* concatenate,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands) = 0;
  virtual Status HandleConvert(HloInstruction* convert) {
    return HandleElementwiseUnary(convert, HloOpcode::kConvert);
    return HandleElementwiseUnary(convert);
  }
  virtual Status HandleCopy(HloInstruction* copy) {
    return HandleElementwiseUnary(copy, HloOpcode::kCopy);
    return HandleElementwiseUnary(copy);
  }
  virtual Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs,
                                HloInstruction* rhs) {
    return HandleElementwiseBinary(multiply, HloOpcode::kMultiply);
    return HandleElementwiseBinary(multiply);
  }
  virtual Status HandleDot(HloInstruction* dot, HloInstruction* lhs,
                           HloInstruction* rhs) = 0;
  virtual Status HandlePower(HloInstruction* power, HloInstruction* lhs,
                             HloInstruction* rhs) {
    return HandleElementwiseBinary(power, HloOpcode::kPower);
    return HandleElementwiseBinary(power);
  }
  virtual Status HandleConvolution(HloInstruction* convolution,
                                   HloInstruction* lhs, HloInstruction* rhs,
@ -101,73 +101,72 @@ class DfsHloVisitor {
  virtual Status HandleCrossReplicaSum(HloInstruction* crs) = 0;
  virtual Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
                               HloInstruction* lhs, HloInstruction* rhs) {
    return HandleElementwiseBinary(compare, opcode);
    return HandleElementwiseBinary(compare);
  }
  virtual Status HandleAdd(HloInstruction* add, HloInstruction* lhs,
                           HloInstruction* rhs) {
    return HandleElementwiseBinary(add, HloOpcode::kAdd);
    return HandleElementwiseBinary(add);
  }
  virtual Status HandleDivide(HloInstruction* divide, HloInstruction* lhs,
                              HloInstruction* rhs) {
    return HandleElementwiseBinary(divide, HloOpcode::kDivide);
    return HandleElementwiseBinary(divide);
  }
  virtual Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs,
                                 HloInstruction* rhs) {
    return HandleElementwiseBinary(remainder, HloOpcode::kRemainder);
    return HandleElementwiseBinary(remainder);
  }
  virtual Status HandleSubtract(HloInstruction* subtract, HloInstruction* lhs,
                                HloInstruction* rhs) {
    return HandleElementwiseBinary(subtract, HloOpcode::kSubtract);
    return HandleElementwiseBinary(subtract);
  }
  virtual Status HandleAbs(HloInstruction* abs, HloInstruction* operand) {
    return HandleElementwiseUnary(abs, HloOpcode::kAbs);
    return HandleElementwiseUnary(abs);
  }
  virtual Status HandleSign(HloInstruction* sign, HloInstruction* operand) {
    return HandleElementwiseUnary(sign, HloOpcode::kSign);
    return HandleElementwiseUnary(sign);
  }
  virtual Status HandleNegate(HloInstruction* negate, HloInstruction* operand) {
    return HandleElementwiseUnary(negate, HloOpcode::kNegate);
    return HandleElementwiseUnary(negate);
  }
  virtual Status HandleExp(HloInstruction* exp, HloInstruction* operand) {
    return HandleElementwiseUnary(exp, HloOpcode::kExp);
    return HandleElementwiseUnary(exp);
  }
  virtual Status HandleFloor(HloInstruction* floor, HloInstruction* operand) {
    return HandleElementwiseUnary(floor, HloOpcode::kFloor);
    return HandleElementwiseUnary(floor);
  }
  virtual Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) {
    return HandleElementwiseUnary(ceil, HloOpcode::kCeil);
    return HandleElementwiseUnary(ceil);
  }
  virtual Status HandleLog(HloInstruction* log, HloInstruction* operand) {
    return HandleElementwiseUnary(log, HloOpcode::kLog);
    return HandleElementwiseUnary(log);
  }
  virtual Status HandleCos(HloInstruction* cos, HloInstruction* operand) {
    return HandleElementwiseUnary(cos, HloOpcode::kCos);
    return HandleElementwiseUnary(cos);
  }
  virtual Status HandleSin(HloInstruction* sin, HloInstruction* operand) {
    return HandleElementwiseUnary(sin, HloOpcode::kSin);
    return HandleElementwiseUnary(sin);
  }
  virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) {
    return HandleElementwiseUnary(tanh, HloOpcode::kTanh);
    return HandleElementwiseUnary(tanh);
  }
  virtual Status HandleIsFinite(HloInstruction* is_finite,
                                HloInstruction* operand) {
    return HandleElementwiseUnary(is_finite, HloOpcode::kIsFinite);
    return HandleElementwiseUnary(is_finite);
  }
  virtual Status HandleLogicalAnd(HloInstruction* logical_and,
                                  HloInstruction* lhs, HloInstruction* rhs) {
    return HandleElementwiseBinary(logical_and, HloOpcode::kLogicalAnd);
    return HandleElementwiseBinary(logical_and);
  }
  virtual Status HandleLogicalNot(HloInstruction* logical_not,
                                  HloInstruction* operand) {
    return HandleElementwiseUnary(logical_not, HloOpcode::kLogicalNot);
    return HandleElementwiseUnary(logical_not);
  }
  virtual Status HandleLogicalOr(HloInstruction* logical_or,
                                 HloInstruction* lhs, HloInstruction* rhs) {
    return HandleElementwiseBinary(logical_or, HloOpcode::kLogicalOr);
    return HandleElementwiseBinary(logical_or);
  }
  virtual Status HandleReducePrecision(HloInstruction* reduce_precision) {
    return HandleElementwiseUnary(reduce_precision,
                                  HloOpcode::kReducePrecision);
    return HandleElementwiseUnary(reduce_precision);
  }

  virtual Status HandleInfeed(HloInstruction* infeed) = 0;

@ -41,12 +41,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor {
  // Default action performed on HloInstruction.
  virtual Status DefaultAction(HloInstruction* hlo_instruction) = 0;

  Status HandleElementwiseUnary(HloInstruction* hlo,
                                HloOpcode opcode) override {
  Status HandleElementwiseUnary(HloInstruction* hlo) override {
    return DefaultAction(hlo);
  }
  Status HandleElementwiseBinary(HloInstruction* hlo,
                                 HloOpcode opcode) override {
  Status HandleElementwiseBinary(HloInstruction* hlo) override {
    return DefaultAction(hlo);
  }
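To illustrate the new handler signature, a minimal visitor built on
DfsHloVisitorWithDefault (a hypothetical class, not part of the commit). After
this change the handlers read the opcode off the instruction itself instead of
taking it as a parameter:

  // Counts elementwise ops in a computation; everything else falls through
  // to DefaultAction, which DfsHloVisitorWithDefault requires us to define.
  class ElementwiseCounter : public DfsHloVisitorWithDefault {
   public:
    Status DefaultAction(HloInstruction* hlo) override { return Status::OK(); }

    Status HandleElementwiseUnary(HloInstruction* hlo) override {
      ++count_;  // hlo->opcode() identifies which unary op this is.
      return Status::OK();
    }
    Status HandleElementwiseBinary(HloInstruction* hlo) override {
      ++count_;
      return Status::OK();
    }

    int64 count() const { return count_; }

   private:
    int64 count_ = 0;
  };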
@ -709,7 +709,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
  } else {
    auto r = ir_builder_->CreateSub(q, p);
    auto leading_zeros = llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(1)},
        llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(true)},
        {param_ir_type}, ir_builder_);
    auto in_block = ir_builder_->GetInsertBlock();

@ -334,7 +334,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder_);

  IrArray::Index input_index(index.size());
  llvm::Value* in_bounds = ir_builder_->getInt1(1);
  llvm::Value* in_bounds = ir_builder_->getInt1(true);
  for (size_t i = 0; i < index.size(); ++i) {
    llvm::Value* stridden_index = ir_builder_->CreateNSWMul(
        index[i], ir_builder_->getInt64(window.dimensions(i).stride()));

@ -389,7 +389,7 @@ StatusOr<string> CompileModuleToPtx(llvm::Module* module,

  // Loop unrolling exposes more opportunities for SROA. Therefore, we run SROA
  // again after the standard optimization passes [http://b/13329423].
  // TODO(jingyue): SROA may further expose more optimization opportunities, such
  // TODO(jingyue): SROA may further expose more optimization opportunities such
  // as more precise alias analysis and more function inlining (SROA may change
  // the inlining cost of a function). For now, running SROA already emits good
  // enough code for the evaluated benchmarks. We may want to run more

@ -37,6 +37,230 @@ namespace xla {
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;

// Data structure used to construct the alias analysis. Thrown away after alias
// analysis is complete. This data structure keeps track of which sets of
// HloValues must be in the same HloBuffer. This is maintained as a map from a
// buffer identifier (BufferNumber) to a set of HloValues.
//
// Initially each value is its own buffer. In MergeAliasedBuffers, sets of
// values which must share the same buffer are merged together. The end result
// is a partitioning of all HloValues into sets where each set needs its own
// HloBuffer. By performing this analysis without constructing HloBuffers on the
// fly, we can after-the-fact construct a vector of contiguously numbered
// HloBuffers after the buffer requirement has been determined.
class BufferValueMap {
 public:
  // A unique identifier for a set of colocated values which must share the same
  // buffer. This is not necessarily the same as the HloBuffer::Id which will
  // ultimately contain the values. The reason is that HloBuffer::Id's are
  // contiguous, while BufferNumbers may not be dense, because buffers may be
  // created and destroyed during the analysis construction process.
  using BufferNumber = int64;

  explicit BufferValueMap(const HloDataflowAnalysis& dataflow)
      : dataflow_(dataflow) {
    buffers_.reserve(dataflow_.values().size());
    value_to_buffer_number_.reserve(dataflow_.values().size());
    for (const HloValue* value : dataflow_.values()) {
      BufferNumber buffer_number = next_buffer_number_++;
      buffers_[buffer_number].insert(value);
      value_to_buffer_number_[value] = buffer_number;
    }
  }

  // Merge together sets of HloValues which must be in the same HloBuffer
  // because of aliasing rules (eg, in-place kWhile instruction).
  void MergeAliasedBuffers() {
    for (const HloValue* value : dataflow_.values()) {
      VLOG(3) << "Merging colocated values, value: " << value->ToShortString();

      // Gather the set of buffers with aliasing rules (eg, kWhile) which this
      // value must be contained in.
      std::vector<BufferNumber> aliased_buffers = ComputeAliasedBuffers(*value);

      BufferNumber current_buffer = value_to_buffer_number_.at(value);
      if (aliased_buffers.empty()) {
        // The buffer containing 'value' aliases no other buffers. If the buffer
        // containing 'value' already only contains 'value', then no change is
        // necessary. If the buffer containing 'value' does contain other
        // values, then remove 'value' from the buffer and create a new buffer
        // containing only 'value'.
        if (buffers_.at(current_buffer).size() == 1) {
          CHECK_EQ(*buffers_.at(current_buffer).begin(), value);
        } else {
          MoveValueToNewBuffer(*value);
        }
      } else {
        // If multiple buffers are aliased merge these buffers together into a
        // single buffer (arbitrarily chosen as the first buffer in the vector).
        if (aliased_buffers.size() > 1) {
          for (int64 i = 1; i < aliased_buffers.size(); ++i) {
            MergeBuffers(/*from=*/aliased_buffers[i],
                         /*to=*/aliased_buffers[0]);
          }
        }
        BufferNumber new_buffer = aliased_buffers[0];
        if (current_buffer != new_buffer) {
          MoveValueToBuffer(*value, new_buffer);
        }
      }
    }
  }

  // Compute and return a sorted vector of all BufferNumbers. Can be used to
  // iterate through all buffers stably.
  std::vector<BufferNumber> ComputeSortedBufferNumbers() const {
    std::vector<BufferNumber> buffer_numbers;
    for (const auto& pair : buffers_) {
      buffer_numbers.push_back(pair.first);
    }
    std::sort(buffer_numbers.begin(), buffer_numbers.end());
    return buffer_numbers;
  }

  // Return a set of all the values in the given buffer.
  const tensorflow::gtl::FlatSet<const HloValue*>& GetValuesInBuffer(
      BufferNumber buffer_number) const {
    return buffers_.at(buffer_number);
  }

 private:
  // Create a new buffer.
  void NewBuffer(const HloValue& value) {
    BufferNumber buffer_number = next_buffer_number_++;
    buffers_[buffer_number].insert(&value);
    value_to_buffer_number_[&value] = buffer_number;
  }

  // Move the given value into a new buffer containing only the value.
  void MoveValueToNewBuffer(const HloValue& value) {
    BufferNumber new_buffer_number = next_buffer_number_++;
    buffers_[new_buffer_number];
    MoveValueToBuffer(value, new_buffer_number);
  }

  // Move the given value into the given buffer.
  void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) {
    BufferNumber old_buffer_number = value_to_buffer_number_.at(&value);
    buffers_.at(old_buffer_number).erase(&value);
    if (buffers_.at(old_buffer_number).empty()) {
      buffers_.erase(old_buffer_number);
    }

    buffers_.at(buffer_number).insert(&value);
    value_to_buffer_number_.at(&value) = buffer_number;
  }

  // Merge the buffer 'from' into the buffer 'to'.
  void MergeBuffers(BufferNumber from, BufferNumber to) {
    auto& from_value_set = buffers_.at(from);
    buffers_.at(to).insert(from_value_set.begin(), from_value_set.end());
    // NOTE: using a union-find algorithm to hold the colocated values might be
    // faster (a sketch follows after this class).
    for (const HloValue* value : from_value_set) {
      value_to_buffer_number_.at(value) = to;
    }
    buffers_.erase(from);
  }

  BufferNumber GetBufferForValue(const HloValue& value) {
    return value_to_buffer_number_.at(&value);
  }

  // Compute and return a vector of buffers that the given value must be
  // contained in due to HLO aliasing rules.
  std::vector<BufferNumber> ComputeAliasedBuffers(const HloValue& value) {
    // Value is init of a while (use is while).
    std::vector<BufferNumber> aliased_buffers;
    for (const HloUse& use : value.uses()) {
      VLOG(1) << "use of value " << value.ToShortString() << ": " << use;
      if (use.instruction->opcode() == HloOpcode::kWhile) {
        // Determine the while value that this shares a buffer with.
        const HloValue& while_value =
            dataflow_.GetUniqueValueAt(use.instruction, use.operand_index);
        aliased_buffers.push_back(GetBufferForValue(while_value));
        VLOG(3) << "  value is init value to a while; must share buffer with "
                   "while value "
                << while_value.ToShortString();
      }
    }

    // Value is a parameter of a while body/condition.
    if (value.defining_instruction()->opcode() == HloOpcode::kParameter) {
      const HloComputation* computation =
          value.defining_instruction()->parent();
      const CallGraphNode& call_graph_node =
          dataflow_.call_graph().GetNode(computation);
      for (const CallSite& callsite : call_graph_node.caller_callsites()) {
        if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
          // Call graph must have been flattened.
          CHECK_EQ(call_graph_node.caller_callsites().size(), 1);

          const HloValue& while_value = dataflow_.GetUniqueValueAt(
              callsite.instruction(), value.defining_index());
          VLOG(3) << "  value is parameter value of the body or condition of a "
                     "while; must share buffer with while value "
                  << while_value.ToShortString();
          aliased_buffers.push_back(GetBufferForValue(while_value));
        }
      }
    }

    // Value is the root of a while body.
    for (const HloPosition& position : value.positions()) {
      const HloComputation* computation = position.instruction->parent();
      const CallGraphNode& call_graph_node =
          dataflow_.call_graph().GetNode(computation);
      if (position.instruction == computation->root_instruction()) {
        for (const CallSite& callsite : call_graph_node.caller_callsites()) {
          if (callsite.instruction()->opcode() == HloOpcode::kWhile &&
              callsite.instruction()->while_body() == computation) {
            // Call graph must have been flattened.
            CHECK_EQ(call_graph_node.caller_callsites().size(), 1);

            const HloValue& while_value = dataflow_.GetUniqueValueAt(
                callsite.instruction(), position.index);
            VLOG(3) << "  value is root of the body computation of a while; "
                       "must share buffer with while value "
                    << while_value.ToShortString();
            aliased_buffers.push_back(GetBufferForValue(while_value));
          }
        }
      }
    }

    // Value is the output of the while instruction itself.
    if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
      VLOG(3) << "  value is output of a while instruction";
      aliased_buffers.push_back(GetBufferForValue(value));
    }

    // Uniquify aliased buffers.
    std::sort(aliased_buffers.begin(), aliased_buffers.end());
    aliased_buffers.erase(
        std::unique(aliased_buffers.begin(), aliased_buffers.end()),
        aliased_buffers.end());

    return aliased_buffers;
  }

  // Dataflow analysis used to construct the buffer map.
  const HloDataflowAnalysis& dataflow_;

  // A map containing the set of values contained in each buffer.
  tensorflow::gtl::FlatMap<BufferNumber,
                           tensorflow::gtl::FlatSet<const HloValue*>>
      buffers_;

  // A map indicating which buffer each value is contained in.
  tensorflow::gtl::FlatMap<const HloValue*, BufferNumber>
      value_to_buffer_number_;

  // The buffer number of the next buffer to be created.
  BufferNumber next_buffer_number_ = 0;
};
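The NOTE in MergeBuffers suggests union-find as a faster container for the
colocated-value sets. A minimal sketch of that alternative (an illustration,
not part of the commit; path compression only, no union by rank):

  #include <numeric>
  #include <vector>

  // Union-find over value ids: Find uses path compression, Union links roots.
  // Merging two buffers becomes near-constant time instead of copying a set.
  class UnionFind {
   public:
    explicit UnionFind(int n) : parent_(n) {
      std::iota(parent_.begin(), parent_.end(), 0);  // Each id is its own set.
    }
    int Find(int x) {
      while (parent_[x] != x) {
        parent_[x] = parent_[parent_[x]];  // Path compression.
        x = parent_[x];
      }
      return x;
    }
    void Union(int a, int b) { parent_[Find(a)] = Find(b); }

   private:
    std::vector<int> parent_;
  };

The trade-off is that enumerating the members of one buffer then requires a
pass over all values, which is why the map-of-sets form is convenient for
constructing the final HloBuffers.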
HloAliasAnalysis::HloAliasAnalysis(HloModule* module) : module_(module) {}
|
||||
|
||||
const HloBuffer& HloAliasAnalysis::GetUniqueBufferAt(
|
||||
@ -99,10 +323,11 @@ bool HloAliasAnalysis::InstructionBuffersAreDistinct(
|
||||
}
|
||||
} else {
|
||||
// It's possible for multiple values at this index to have the same
|
||||
// HloBuffer. This does not result in non-distictness. To account for this
|
||||
// case, add all of the buffers at this index after checking whether each
|
||||
// buffer exists at an earlier index. This is a corner case, however, as
|
||||
// the number of values at an index is almost always one.
|
||||
// HloBuffer. This does not result in non-distictness. To account for
|
||||
// this case, add all of the buffers at this index after checking
|
||||
// whether each buffer exists at an earlier index. This is a corner
|
||||
// case, however, as the number of values at an index is almost always
|
||||
// one.
|
||||
std::vector<const HloBuffer*> buffers_at_this_index;
|
||||
for (const HloValue* value : value_set.values()) {
|
||||
const HloBuffer* buffer = &GetBufferContainingValue(*value);
|
||||
@ -118,15 +343,6 @@ bool HloAliasAnalysis::InstructionBuffersAreDistinct(
|
||||
return true;
|
||||
}
|
||||
|
||||
void HloAliasAnalysis::InitializeBufferSets() {
|
||||
// Initially define a buffer for every HloValue in the module.
|
||||
for (const HloValue& value : dataflow_analysis_->values()) {
|
||||
HloBuffer& buffer = NewHloBuffer();
|
||||
buffer.AddValue(value);
|
||||
value_to_buffer_[&value] = &buffer;
|
||||
}
|
||||
}
|
||||
|
||||
Status HloAliasAnalysis::Verify() const {
|
||||
// Verify consistency between the value_to_buffer_ map and
|
||||
// HloBuffer::values().
|
||||
@ -137,9 +353,8 @@ Status HloAliasAnalysis::Verify() const {
|
||||
value) != buffer.values().end());
|
||||
}
|
||||
|
||||
for (const auto& pair : buffers_) {
|
||||
const HloBuffer::Id id = pair.first;
|
||||
const HloBuffer& buffer = pair.second;
|
||||
for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) {
|
||||
const HloBuffer& buffer = buffers_[id];
|
||||
TF_RET_CHECK(buffer.id() == id);
|
||||
|
||||
HloValue::Id last_value_id = -1;
|
||||
@ -152,116 +367,9 @@ Status HloAliasAnalysis::Verify() const {
|
||||
}
|
||||
}
|
||||
|
||||
if (!buffers_vector_.empty()) {
|
||||
// buffers_vector_ should be a vector of all HloBuffers sorted by id.
|
||||
std::vector<const HloBuffer*> buffers;
|
||||
for (const auto& id_buffer : buffers_) {
|
||||
buffers.push_back(&id_buffer.second);
|
||||
}
|
||||
std::sort(buffers.begin(), buffers.end(), HloBuffer::IdLessThan);
|
||||
TF_RET_CHECK(buffers_vector_ == buffers);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloAliasAnalysis::VerifyAgainstReference() const {
|
||||
TF_RETURN_IF_ERROR(Verify());
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> reference,
|
||||
Run(module_));
|
||||
TF_RETURN_IF_ERROR(reference->Verify());
|
||||
|
||||
VLOG(2) << "This analysis:";
|
||||
XLA_VLOG_LINES(2, ToString());
|
||||
VLOG(2) << "Reference:";
|
||||
XLA_VLOG_LINES(2, reference->ToString());
|
||||
|
||||
// Create map from HloValue in the reference analysis to HloValue in this
|
||||
// analysis and vice versa.
|
||||
tensorflow::gtl::FlatMap<const HloValue*, const HloValue*> reference_to_this;
|
||||
tensorflow::gtl::FlatMap<const HloValue*, const HloValue*> this_to_reference;
|
||||
for (const HloValue& value : dataflow_analysis().values()) {
|
||||
const HloValue& reference_value =
|
||||
reference->dataflow_analysis().GetValueDefinedAt(
|
||||
value.defining_instruction(), value.defining_index());
|
||||
reference_to_this[&reference_value] = &value;
|
||||
this_to_reference[&value] = &reference_value;
|
||||
}
|
||||
|
||||
TF_RET_CHECK(buffers_.size() == reference->buffers_.size())
|
||||
<< "Different number of buffers (" << buffers_.size()
|
||||
<< " != " << reference->buffers_.size() << ")";
|
||||
for (const auto& pair : reference->buffers_) {
|
||||
const HloBuffer& reference_buffer = pair.second;
|
||||
|
||||
// Find the corresponding buffer in the reference by taking the first value
|
||||
// in the buffer, finding the corresponding value in the reference, and then
|
||||
// finding the buffer holding that value.
|
||||
TF_RET_CHECK(!reference_buffer.values().empty());
|
||||
const HloValue* reference_value = reference_buffer.values()[0];
|
||||
const HloValue* value = reference_to_this.at(reference_value);
|
||||
const HloBuffer& buffer = GetBufferContainingValue(*value);
|
||||
|
||||
// The buffer and the reference should have the exact same values. To make
|
||||
// comparison easy, sort the values in the reference buffer identically to
|
||||
// the values in the non-reference buffer (ie, by the corresponding id of
|
||||
// the non-reference value).
|
||||
std::vector<const HloValue*> reference_values = reference_buffer.values();
|
||||
std::sort(reference_values.begin(), reference_values.end(),
|
||||
[&reference_to_this](const HloValue* a, const HloValue* b) {
|
||||
return reference_to_this.at(a)->id() <
|
||||
reference_to_this.at(b)->id();
|
||||
});
|
||||
TF_RET_CHECK(reference_values.size() == buffer.values().size());
|
||||
for (int i = 0; i < buffer.values().size(); ++i) {
|
||||
TF_RET_CHECK(*reference_values[i] == *buffer.values()[i])
|
||||
<< "Buffer:\n " << buffer
|
||||
<< "\ndoes not have the same values as reference buffer:\n "
|
||||
<< reference_buffer;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
HloBuffer& HloAliasAnalysis::NewHloBuffer() {
|
||||
HloBuffer::Id buffer_id = next_buffer_id_++;
|
||||
auto emplaced = buffers_.emplace(std::piecewise_construct,
|
||||
std::forward_as_tuple(buffer_id),
|
||||
std::forward_as_tuple(buffer_id));
|
||||
CHECK(emplaced.second);
|
||||
|
||||
buffers_vector_.clear();
|
||||
|
||||
return emplaced.first->second;
|
||||
}
|
||||
|
||||
void HloAliasAnalysis::MoveValueToNewBuffer(const HloValue& value) {
|
||||
HloBuffer& new_buffer = NewHloBuffer();
|
||||
MoveValueToBuffer(value, &new_buffer);
|
||||
|
||||
VLOG(3) << "Moved value " << value.ToShortString() << " into new buffer "
|
||||
<< new_buffer.id();
|
||||
}
|
||||
|
||||
void HloAliasAnalysis::MoveValueToBuffer(const HloValue& value,
|
||||
HloBuffer* buffer) {
|
||||
HloBuffer& old_buffer = GetBufferContainingValue(value);
|
||||
CHECK_NE(buffer, &old_buffer);
|
||||
VLOG(3) << "Moved value " << value.ToShortString() << " from buffer "
|
||||
<< old_buffer.id() << " into buffer " << buffer->id();
|
||||
old_buffer.RemoveValue(value);
|
||||
if (old_buffer.values().empty()) {
|
||||
VLOG(3) << "Buffer " << old_buffer.id() << " now empty. Removing.";
|
||||
buffers_.erase(old_buffer.id());
|
||||
buffers_vector_.clear();
|
||||
}
|
||||
|
||||
buffer->AddValue(value);
|
||||
value_to_buffer_[&value] = buffer;
|
||||
}
|
||||
|
||||
string HloAliasAnalysis::ToString() const {
|
||||
string out = StrCat("HloAliasAnalysis, module ", module_->name(), "\n");
|
||||
StrAppend(&out, " Buffers at each position:\n");
|
||||
@ -290,10 +398,10 @@ string HloAliasAnalysis::ToString() const {
  }

  StrAppend(&out, " Buffers:\n");
  for (const HloBuffer* buffer : buffers()) {
    StrAppend(&out, " ", buffer->ToString(), "\n");
  for (const HloBuffer& buffer : buffers()) {
    StrAppend(&out, " ", buffer.ToString(), "\n");
    StrAppend(&out, " positions:\n");
    for (const HloPosition& position : buffer->ComputePositions()) {
    for (const HloPosition& position : buffer.ComputePositions()) {
      StrAppend(&out, " ", position.ToString(), "\n");
    }
  }
@ -301,217 +409,6 @@ string HloAliasAnalysis::ToString() const {
  return out;
}

const std::vector<const HloBuffer*>& HloAliasAnalysis::buffers() const {
  if (buffers_vector_.empty()) {
    // Lazily construct vector of buffers.
    buffers_vector_.reserve(buffers_.size());
    for (auto& pair : buffers_) {
      buffers_vector_.push_back(&pair.second);
    }
    std::sort(buffers_vector_.begin(), buffers_vector_.end(),
              HloBuffer::IdLessThan);
  } else {
    CHECK_EQ(buffers_vector_.size(), buffers_.size());
    for (const HloBuffer* buffer : buffers_vector_) {
      DCHECK(ContainsKey(buffers_, buffer->id()));
      DCHECK(&GetBuffer(buffer->id()) == buffer);
    }
  }
  return buffers_vector_;
}

void HloAliasAnalysis::UpdateAtInstructions(
    tensorflow::gtl::ArraySlice<const HloInstruction*> instructions) {
  VLOG(4) << "Updated HLO module:";
  XLA_VLOG_LINES(4, module_->ToString());

  VLOG(3) << "Before update:";
  XLA_VLOG_LINES(3, ToString());

  std::vector<const HloValue*> values_to_update;
  for (const HloInstruction* instruction : instructions) {
    for (auto& pair : dataflow_analysis().GetInstructionValueSet(instruction)) {
      for (const HloValue* value : pair.second.values()) {
        values_to_update.push_back(value);
      }
    }
  }

  UpdateBuffersForValues(values_to_update);

  VLOG(3) << "After update:";
  XLA_VLOG_LINES(3, ToString());
}

void HloAliasAnalysis::UpdateAfterChangingOperand(HloInstruction* instruction,
                                                  HloInstruction* old_operand,
                                                  HloInstruction* new_operand) {
  VLOG(1) << "UpdateAfterChangingOperand(" << instruction->name() << ", "
          << old_operand->name() << " => " << new_operand->name() << ")";

  dataflow_analysis_->UpdateAfterChangingOperand(instruction, old_operand,
                                                 new_operand);
  TF_DCHECK_OK(dataflow_analysis_->VerifyAgainstReference());

  VLOG(4) << "Updated dataflow:";
  XLA_VLOG_LINES(4, dataflow_analysis_->ToString());

  UpdateAtInstructions({instruction, old_operand, new_operand});
}

void HloAliasAnalysis::UpdateAfterChangingRoot(HloInstruction* old_root,
                                               HloInstruction* new_root) {
  VLOG(1) << "UpdateAfterChangingRoot(" << old_root->name() << " => "
          << new_root->name() << ")";

  dataflow_analysis_->UpdateAfterChangingRoot(old_root, new_root);
  TF_DCHECK_OK(dataflow_analysis_->VerifyAgainstReference());

  VLOG(4) << "Updated dataflow:";
  XLA_VLOG_LINES(4, dataflow_analysis_->ToString());

  UpdateAtInstructions({old_root, new_root});
}

std::vector<HloBuffer*> HloAliasAnalysis::ComputeAliasedBuffers(
    const HloValue& value) {
  std::vector<HloBuffer*> aliased_buffers;

  // Value is init of a while (use is while).
  for (const HloUse& use : value.uses()) {
    VLOG(1) << "use of value " << value.ToShortString() << ": " << use;
    if (use.instruction->opcode() == HloOpcode::kWhile) {
      // Determine the while value that this shares a buffer with.
      const HloValue& while_value = dataflow_analysis().GetUniqueValueAt(
          use.instruction, use.operand_index);
      aliased_buffers.push_back(&GetBufferContainingValue(while_value));
      VLOG(3) << " value is init value to a while; must share buffer with "
                 "while value "
              << while_value.ToShortString();
    }
  }

  // Value is a parameter of a while body/condition.
  if (value.defining_instruction()->opcode() == HloOpcode::kParameter) {
    const HloComputation* computation = value.defining_instruction()->parent();
    const CallGraphNode& call_graph_node =
        dataflow_analysis().call_graph().GetNode(computation);
    for (const CallSite& callsite : call_graph_node.caller_callsites()) {
      if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
        // Call graph must have been flattened.
        CHECK_EQ(call_graph_node.caller_callsites().size(), 1);

        const HloValue& while_value = dataflow_analysis().GetUniqueValueAt(
            callsite.instruction(), value.defining_index());
        VLOG(3) << " value is parameter value of the body or condition of a "
                   "while; must share buffer with while value "
                << while_value.ToShortString();
        aliased_buffers.push_back(&GetBufferContainingValue(while_value));
      }
    }
  }

  // Value is the root of a while body.
  for (const HloPosition& position : value.positions()) {
    const HloComputation* computation = position.instruction->parent();
    const CallGraphNode& call_graph_node =
        dataflow_analysis().call_graph().GetNode(computation);
    if (position.instruction == computation->root_instruction()) {
      for (const CallSite& callsite : call_graph_node.caller_callsites()) {
        if (callsite.instruction()->opcode() == HloOpcode::kWhile &&
            callsite.instruction()->while_body() == computation) {
          // Call graph must have been flattened.
          CHECK_EQ(call_graph_node.caller_callsites().size(), 1);

          // If the value appears in the root of a while body, then
          // necessarily the value is defined in the body as well.
          CHECK_EQ(value.defining_instruction()->parent(), computation);

          const HloValue& while_value = dataflow_analysis().GetUniqueValueAt(
              callsite.instruction(), position.index);
          VLOG(3) << " value is root of the body computation of a while; must "
                     "share buffer with while value "
                  << while_value.ToShortString();
          aliased_buffers.push_back(&GetBufferContainingValue(while_value));
        }
      }
    }
  }

  // Value is in the while instruction itself.
  if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
    VLOG(3) << " value is output of a while instruction";
    aliased_buffers.push_back(&GetUniqueBufferAt(value.defining_instruction(),
                                                 value.defining_index()));
  }

  // Uniquify aliased buffers.
  std::sort(aliased_buffers.begin(), aliased_buffers.end(),
            HloBuffer::IdLessThan);
  aliased_buffers.erase(
      std::unique(aliased_buffers.begin(), aliased_buffers.end()),
      aliased_buffers.end());

  return aliased_buffers;
}

// This method recomputes the HloBuffer for each of the given HloValues. The
// method does not necessarily update the HloBuffer of values which share a
// buffer with the given values, but are not explicitly passed in
// 'values'. Therefore, the caller must pass in all values which may require an
// update according to the kind of HLO graph change which occurred: operand
// changed (UpdateAfterChangingOperand), or root of computation changed
// (UpdateAfterChangingRoot).
void HloAliasAnalysis::UpdateBuffersForValues(
    tensorflow::gtl::ArraySlice<const HloValue*> values) {
  for (const HloValue* value : values) {
    VLOG(3) << "Updating buffer for value: " << value->ToShortString();

    // Gather the set of buffers with aliasing rules (eg, kWhile) which this
    // value must be contained in.
    std::vector<HloBuffer*> aliased_buffers = ComputeAliasedBuffers(*value);

    HloBuffer& current_buffer = GetBufferContainingValue(*value);
    if (aliased_buffers.empty()) {
      // The buffer containing 'value' aliases no other buffers. If the buffer
      // containing 'value' already only contains 'value', then no change is
      // necessary. If the buffer containing 'value' does contain other values,
      // then remove 'value' from the buffer and create a new buffer containing
      // only 'value'.
      if (current_buffer.values().size() == 1) {
        CHECK_EQ(current_buffer.values()[0], value);
      } else {
        MoveValueToNewBuffer(*value);
      }
    } else {
      // If multiple buffers are aliased, merge these buffers together into a
      // single buffer (arbitrarily chosen as the first buffer in the vector).
      if (aliased_buffers.size() > 1) {
        for (int64 i = 1; i < aliased_buffers.size(); ++i) {
          // Make a copy of the values vector because MoveValueToBuffer
          // invalidates the values iterator. This could be done more
          // efficiently by moving all values at once.
          std::vector<const HloValue*> values = aliased_buffers[i]->values();
          for (const HloValue* value : values) {
            MoveValueToBuffer(*value, aliased_buffers[0]);
          }
        }
        aliased_buffers.resize(1);
      }

      CHECK_EQ(aliased_buffers.size(), 1);
      HloBuffer* new_buffer = aliased_buffers[0];

      if (&current_buffer != new_buffer) {
        MoveValueToBuffer(*value, new_buffer);
      }
    }

    VLOG(4) << "Analysis after update:";
    XLA_VLOG_LINES(4, ToString());
  }
}
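The merge loop in UpdateBuffersForValues copies a buffer's value vector before moving its values, since MoveValueToBuffer mutates the source buffer under iteration. A compressed stand-alone sketch of the same merge-into-the-first-buffer step, with plain ints standing in for HloValues (an illustration under those assumptions, not the XLA implementation):

#include <algorithm>
#include <iostream>
#include <vector>

// Simplified stand-in: a "buffer" is just a sorted vector of value ids.
using Buffer = std::vector<int>;

// Merge buffers[1..n) into buffers[0], keeping ids sorted and unique.
void MergeIntoFirst(std::vector<Buffer>& buffers) {
  for (size_t i = 1; i < buffers.size(); ++i) {
    // Copy first: in the real code the source container is mutated while
    // values are moved, which would invalidate a live iterator.
    Buffer values = buffers[i];
    for (int v : values) buffers[0].push_back(v);
  }
  std::sort(buffers[0].begin(), buffers[0].end());
  buffers[0].erase(std::unique(buffers[0].begin(), buffers[0].end()),
                   buffers[0].end());
  buffers.resize(1);
}

int main() {
  std::vector<Buffer> buffers = {{1, 4}, {2, 4}, {3}};
  MergeIntoFirst(buffers);
  for (int v : buffers[0]) std::cout << v << " ";  // prints: 1 2 3 4
}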
/* static */
StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
    HloModule* module) {
@ -524,18 +421,28 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
      HloDataflowAnalysis::Run(module, /*ssa_form=*/true,
                               /*bitcast_defines_value=*/false));

  alias_analysis->InitializeBufferSets();
  BufferValueMap buffer_map(alias_analysis->dataflow_analysis());
  buffer_map.MergeAliasedBuffers();

  VLOG(3) << "After initialization:";
  XLA_VLOG_LINES(3, alias_analysis->ToString());

  std::vector<const HloValue*> all_values;
  for (const HloValue& value : alias_analysis->dataflow_analysis().values()) {
    all_values.push_back(&value);
  // Create a vector of HloBuffers, one for each set of values in the
  // BufferValueMap. Create the HloBuffers as a vector of contiguously numbered
  // buffers.
  std::vector<BufferValueMap::BufferNumber> sorted_buffer_numbers =
      buffer_map.ComputeSortedBufferNumbers();
  alias_analysis->buffers_.reserve(sorted_buffer_numbers.size());
  HloBuffer::Id next_id = 0;
  for (BufferValueMap::BufferNumber buffer_number : sorted_buffer_numbers) {
    auto& value_set = buffer_map.GetValuesInBuffer(buffer_number);
    std::vector<const HloValue*> sorted_values(value_set.begin(),
                                               value_set.end());
    std::sort(sorted_values.begin(), sorted_values.end(), HloValue::IdLessThan);
    alias_analysis->buffers_.emplace_back(next_id++, sorted_values);
    for (const HloValue* value : sorted_values) {
      alias_analysis->value_to_buffer_[value] =
          &alias_analysis->buffers_.back();
    }
  }

  alias_analysis->UpdateBuffersForValues(all_values);

  TF_DCHECK_OK(alias_analysis->Verify());

  XLA_VLOG_LINES(1, alias_analysis->ToString());
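The new Run path replaces the incremental buffer-set initialization with a bulk construction: group values by buffer number, then emit contiguously numbered HloBuffers. A simplified sketch of that renumbering step, with hypothetical stand-in types for BufferValueMap (the real classes are richer; this only mirrors the shape of the loop above):

#include <algorithm>
#include <map>
#include <set>
#include <vector>

// Hypothetical stand-ins: a buffer "number" maps to the set of value ids it
// holds.
using BufferNumber = int;
using ValueId = int;

struct Buffer {
  int id;
  std::vector<ValueId> values;
};

std::vector<Buffer> MakeContiguousBuffers(
    const std::map<BufferNumber, std::set<ValueId>>& buffer_map) {
  std::vector<Buffer> buffers;
  buffers.reserve(buffer_map.size());
  int next_id = 0;
  // std::map iteration is already sorted by buffer number, mirroring
  // ComputeSortedBufferNumbers; each value set becomes one contiguously
  // numbered buffer with its values sorted by id.
  for (const auto& pair : buffer_map) {
    std::vector<ValueId> sorted_values(pair.second.begin(), pair.second.end());
    std::sort(sorted_values.begin(), sorted_values.end());
    buffers.push_back(Buffer{next_id++, std::move(sorted_values)});
  }
  return buffers;
}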
@ -74,7 +74,7 @@ class HloAliasAnalysis {

  // Return a vector of all HloBuffers stably sorted by HloBuffer::Id. This
  // vector is lazily computed. Mutating operations on HloAliasAnalysis may
  // invalidate the underlying vector, requiring recomputation.
  const std::vector<const HloBuffer*>& buffers() const;
  const std::vector<HloBuffer>& buffers() const { return buffers_; }

  // Returns the underlying dataflow analysis used by this alias analysis.
  const HloDataflowAnalysis& dataflow_analysis() const {
@ -90,50 +90,13 @@ class HloAliasAnalysis {
  // output of the given instruction.
  bool InstructionBuffersAreDistinct(const HloInstruction* instruction) const;

  // Updates the analysis after the operands of 'instruction' have changed or if
  // 'instruction' has been made the root of a computation. Analysis update is
  // not possible if instructions have been added or removed from the graph.
  void UpdateAfterChangingOperand(HloInstruction* instruction,
                                  HloInstruction* old_operand,
                                  HloInstruction* new_operand);
  void UpdateAfterChangingRoot(HloInstruction* old_root,
                               HloInstruction* new_root);

  // Compare the dataflow analysis against a clean recomputation of the
  // analysis. Returns an error status if there is a mismatch. Useful for
  // verifying the correctness after updates to the analysis.
  Status VerifyAgainstReference() const;

 protected:
  HloAliasAnalysis(HloModule* module);

  // Create a new empty HloBuffer.
  HloBuffer& NewHloBuffer();

  // Move the given value to the given buffer. The value is removed from its
  // current buffer.
  void MoveValueToBuffer(const HloValue& value, HloBuffer* buffer);

  // Move the given value to a newly created buffer. The value is removed from
  // its current buffer.
  void MoveValueToNewBuffer(const HloValue& value);

  // Construct the initial set of buffer sets where an HloBuffer is created for
  // each HloValue in the module.
  void InitializeBufferSets();

  // Compute and return the buffers with aliasing rules (eg, kWhile) which the
  // given value must be contained in.
  std::vector<HloBuffer*> ComputeAliasedBuffers(const HloValue& value);

  // Recompute the HloBuffers for the given values.
  void UpdateBuffersForValues(
      tensorflow::gtl::ArraySlice<const HloValue*> values);

  // Recompute the HloBuffers for all the values which appear in the output of
  // the given instructions.
  void UpdateAtInstructions(
      tensorflow::gtl::ArraySlice<const HloInstruction*> instructions);
  explicit HloAliasAnalysis(HloModule* module);

  // Verify various invariants of the alias analysis.
  Status Verify() const;
@ -143,20 +106,12 @@ class HloAliasAnalysis {
  // The underlying dataflow analysis used by this alias analysis.
  std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;

  // The map of all HloBuffers in the module. We pass around pointers to the
  // mapped HloBuffers, so the underlying container must keep them valid despite
  // mutations touching other map entries.
  std::unordered_map<HloBuffer::Id, HloBuffer> buffers_;

  // A map indicating which buffer a value is contained in.
  tensorflow::gtl::FlatMap<const HloValue*, HloBuffer*> value_to_buffer_;

  // A lazily constructed vector containing all HloBuffers sorted by
  // HloBuffer::Id.
  mutable std::vector<const HloBuffer*> buffers_vector_;

  // The Id to use for the next HloBuffer.
  int64 next_buffer_id_ = 0;
  std::vector<HloBuffer> buffers_;
};

}  // namespace xla
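Since buffers() now returns the stored std::vector<HloBuffer> directly rather than a lazily built vector of pointers, call sites iterate by reference instead of by pointer, which is what the test diff below reflects. A minimal illustration of the two styles, assuming a trivial stand-in Buffer struct:

#include <iostream>
#include <vector>

struct Buffer { int id; };

// Old style: a vector of pointers; elements are dereferenced with '->'.
void PrintByPointer(const std::vector<const Buffer*>& buffers) {
  for (const Buffer* buffer : buffers) std::cout << buffer->id << "\n";
}

// New style: the buffers themselves; elements are accessed with '.'.
void PrintByValue(const std::vector<Buffer>& buffers) {
  for (const Buffer& buffer : buffers) std::cout << buffer.id << "\n";
}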
@ -87,14 +87,13 @@ class HloAliasAnalysisTest : public HloTestBase {
  // constructed.
  bool AnyValuesInSameBufferInterfere() {
    DependencyHloOrdering ordering(module_.get());
    for (const HloBuffer* buffer : analysis_->buffers()) {
      for (const HloValue* value_a : buffer->values()) {
        for (const HloValue* value_b : buffer->values()) {
    for (const HloBuffer& buffer : analysis_->buffers()) {
      for (const HloValue* value_a : buffer.values()) {
        for (const HloValue* value_b : buffer.values()) {
          if (*value_a != *value_b &&
              analysis_->dataflow_analysis().MayInterfere(*value_a, *value_b,
                                                          ordering)) {
              ordering.MayInterfere(*value_a, *value_b)) {
            VLOG(1) << *value_a << " interferes with " << *value_b
                    << " in buffer: " << *buffer;
                    << " in buffer: " << buffer;
            return true;
          }
        }
@ -384,10 +383,7 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) {

  EXPECT_THAT(
      GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0})),
      UnorderedElementsAre(GetValueDefinedAt(xla_while, /*index=*/{0}),
                           GetValueDefinedAt(body_param, /*index=*/{0}),
                           GetValueDefinedAt(cond_param, /*index=*/{0}),
                           GetValueDefinedAt(constant1)));
      UnorderedElementsAre(GetValueDefinedAt(constant1)));
  EXPECT_THAT(
      GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})),
      UnorderedElementsAre(GetValueDefinedAt(constant2),
@ -631,9 +627,9 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) {
  // HloBuffers.
  EXPECT_THAT(
      analysis.buffers(),
      UnorderedElementsAre(&analysis.GetUniqueBufferAt(constant1),
                           &analysis.GetUniqueBufferAt(tuple, /*index=*/{}),
                           &analysis.GetUniqueBufferAt(cond_constant)));
      UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1),
                           analysis.GetUniqueBufferAt(tuple, /*index=*/{}),
                           analysis.GetUniqueBufferAt(cond_constant)));

  // The tuple elements of the while and the three constant inputs should all be
  // smooshed into the same buffer.
@ -820,127 +816,5 @@ TEST_F(HloAliasAnalysisTest, Bitcast) {
            analysis.GetUniqueBufferAt(bitcast));
}

TEST_F(HloAliasAnalysisTest, UpdateAnalysisForWhile) {
  // Test updating alias analysis after modifying a module with an array shaped
  // while:
  //
  // body(F32[] %param):
  //   %negate = Negate(%param)
  //
  // condition(F32[] %param):
  //   return Constant(false)
  //
  // entry:
  //   %constant = Constant(1.0)
  //   %exp = Exp(%constant)
  //   return While(%exp, body, condition)
  //
  auto body_builder = HloComputation::Builder("body");
  auto body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_shape_, HloOpcode::kNegate, body_param));
  HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());

  // Condition computation trivially returns a constant "false".
  auto cond_builder = HloComputation::Builder("condition");
  auto cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
  auto xla_while = builder.AddInstruction(
      HloInstruction::CreateWhile(scalar_shape_, condition, body, exp));
  module_->AddEntryComputation(builder.Build());

  HloAliasAnalysis& analysis = RunAnalysis();

  // Sanity check some alias information.
  EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
            analysis.GetUniqueBufferAt(body_param));
  EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
            analysis.GetUniqueBufferAt(cond_param));
  EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
            analysis.GetUniqueBufferAt(negate));
  EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
            analysis.GetUniqueBufferAt(xla_while));

  // Set the body root to the body_param. Previously it was Negate(body_param).
  body->set_root_instruction(body_param);

  // Prior to updating, verify that the analysis is no longer valid.
  Status verify_status = analysis.VerifyAgainstReference();
  EXPECT_FALSE(verify_status.ok());

  analysis.UpdateAfterChangingRoot(/*old_root=*/negate,
                                   /*new_root*/ body_param);

  // Analysis should be valid after the update.
  TF_ASSERT_OK(analysis.VerifyAgainstReference());

  // The exponential should now pass through the body transparently.
  EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
            analysis.GetUniqueBufferAt(body_param));
  EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
            analysis.GetUniqueBufferAt(cond_param));
  EXPECT_NE(analysis.GetUniqueBufferAt(exp),
            analysis.GetUniqueBufferAt(negate));
  EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
            analysis.GetUniqueBufferAt(xla_while));

  // Now replace the operand of the while with %constant (was %exp).
  TF_ASSERT_OK(exp->ReplaceUseWith(xla_while, constant));
  analysis.UpdateAfterChangingOperand(xla_while, /*old_operand=*/exp,
                                      /*new_operand=*/constant);

  // Analysis should be valid after the update.
  TF_ASSERT_OK(analysis.VerifyAgainstReference());

  EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
            analysis.GetUniqueBufferAt(body_param));
  EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
            analysis.GetUniqueBufferAt(cond_param));
  EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
            analysis.GetUniqueBufferAt(xla_while));
  EXPECT_NE(analysis.GetUniqueBufferAt(constant),
            analysis.GetUniqueBufferAt(exp));
  EXPECT_NE(analysis.GetUniqueBufferAt(constant),
            analysis.GetUniqueBufferAt(negate));

  // And finally make the negate the root of the body again.
  body->set_root_instruction(negate);
  analysis.UpdateAfterChangingRoot(/*old_root=*/body_param,
                                   /*new_root*/ negate);

  // Analysis should be valid after the update.
  TF_ASSERT_OK(analysis.VerifyAgainstReference());

  EXPECT_EQ(analysis.GetUniqueBufferAt(negate),
            analysis.GetUniqueBufferAt(body_param));
  EXPECT_EQ(analysis.GetUniqueBufferAt(negate),
            analysis.GetUniqueBufferAt(cond_param));
  EXPECT_EQ(analysis.GetUniqueBufferAt(negate),
            analysis.GetUniqueBufferAt(xla_while));
  EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
            analysis.GetUniqueBufferAt(negate));

  auto value_of = [&analysis](const HloInstruction* instruction) {
    return &analysis.dataflow_analysis().GetValueDefinedAt(instruction);
  };
  EXPECT_THAT(analysis.GetUniqueBufferAt(negate).values(),
              UnorderedElementsAre(value_of(body_param), value_of(cond_param),
                                   value_of(negate), value_of(constant),
                                   value_of(xla_while)));
}

// Test update tuple element.

}  // namespace
}  // namespace xla
@ -36,22 +36,6 @@ namespace xla {
using ::tensorflow::str_util::Join;
using ::tensorflow::strings::StrCat;

void HloBuffer::AddValue(const HloValue& value) {
  values_.push_back(&value);
  // Sort vector and remove duplicates.
  std::sort(values_.begin(), values_.end(), HloValue::IdLessThan);
  values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual),
                values_.end());
}

void HloBuffer::RemoveValue(const HloValue& value) {
  // The values are sorted, so finding the value could be done in log(n) time
  // with a binary search.
  auto it = std::find(values_.begin(), values_.end(), &value);
  CHECK(it != values_.end());
  values_.erase(it);
}

bool HloBuffer::operator==(const HloBuffer& other) const {
  bool equal = id() == other.id();
  if (equal) {

@ -84,22 +84,15 @@ class HloBuffer {
    return a->id() == b->id();
  }

  HloBuffer(Id id) : id_(id) {}
  HloBuffer(Id id, tensorflow::gtl::ArraySlice<const HloValue*> values)
      : id_(id), values_(values.begin(), values.end()) {}

  // Return the unique identifier for this HloBuffer.
  Id id() const { return id_; }

  // Add a value to the set of values held by this buffer. Also adds the
  // HloPositions of the value to the positions vector of the buffer. If the
  // buffer already contains this value, then this method is a nop.
  void AddValue(const HloValue& value);
  void RemoveValue(const HloValue& value);

  // Return all values contained in this buffer.
  const std::vector<const HloValue*>& values() const { return values_; }

  std::vector<HloPosition> ComputePositions() const;

  // Return the unique HLO value in the buffer. CHECK fails if the buffer does
  // not contain exactly one value.
  const HloValue& GetUniqueValue() const {
@ -107,6 +100,8 @@ class HloBuffer {
    return *values_[0];
  }

  std::vector<HloPosition> ComputePositions() const;

  string ToString() const;

  bool operator==(const HloBuffer& other) const;
@ -118,7 +113,7 @@ class HloBuffer {

  // The set of values contained in this buffer. Vector contains no duplicates
  // and is sorted stably by HloValue::Id.
  std::vector<const HloValue*> values_;
  const std::vector<const HloValue*> values_;
};

std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer);
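HloBuffer::AddValue (now removed) kept values_ sorted and duplicate-free by appending and then re-canonicalizing. A stand-alone sketch of the same sort + unique + erase idiom over plain ints:

#include <algorithm>
#include <iostream>
#include <vector>

// Append then re-canonicalize: adding a value already in the vector is a nop
// after deduplication, matching the documented contract of AddValue.
void AddValue(std::vector<int>& values, int value) {
  values.push_back(value);
  std::sort(values.begin(), values.end());
  values.erase(std::unique(values.begin(), values.end()), values.end());
}

int main() {
  std::vector<int> values = {3, 1};
  AddValue(values, 2);
  AddValue(values, 3);  // duplicate: no effect
  for (int v : values) std::cout << v << " ";  // prints: 1 2 3
}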
@ -118,13 +118,11 @@ Status HloCostAnalysis::HandleElementwiseOp(HloInstruction* hlo_instruction) {
  }
}

Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo,
                                               HloOpcode opcode) {
Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

Status HloCostAnalysis::HandleElementwiseBinary(HloInstruction* hlo,
                                                HloOpcode opcode) {
Status HloCostAnalysis::HandleElementwiseBinary(HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

@ -49,9 +49,8 @@ class HloCostAnalysis : public DfsHloVisitor {
  using ShapeSizeFunction = std::function<int64(const Shape&)>;
  explicit HloCostAnalysis(const ShapeSizeFunction& shape_size);

  Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode) override;
  Status HandleElementwiseBinary(HloInstruction* hlo,
                                 HloOpcode opcode) override;
  Status HandleElementwiseUnary(HloInstruction* hlo) override;
  Status HandleElementwiseBinary(HloInstruction* hlo) override;
  Status HandleConstant(HloInstruction* constant,
                        const Literal& literal) override;
  Status HandleGetTupleElement(HloInstruction* get_tuple_element,
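The elementwise handlers lose their HloOpcode parameter; a handler that still needs the opcode can read it off the instruction itself. A minimal sketch of that pattern with hypothetical stand-in types (not the DfsHloVisitor API):

#include <iostream>

enum class Opcode { kNegate, kAdd };

struct Instruction {
  Opcode opcode() const { return opcode_; }
  Opcode opcode_;
};

// After the signature change, a handler takes only the instruction; the
// opcode, when needed, comes from the instruction itself.
void HandleElementwiseUnary(Instruction* hlo) {
  if (hlo->opcode() == Opcode::kNegate) {
    std::cout << "negate\n";
  }
}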
@ -67,6 +67,22 @@ HloValue& HloDataflowAnalysis::GetValueDefinedAt(
  return GetUniqueValueAt(instruction, index);
}

HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
                                           const ShapeIndex& index,
                                           bool is_phi) {
  const int64 value_id = next_value_id_++;
  auto emplaced = values_.emplace(
      std::piecewise_construct, std::forward_as_tuple(value_id),
      std::forward_as_tuple(value_id, instruction, index, is_phi));
  CHECK(emplaced.second);

  return &emplaced.first->second;
}

void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) {
  values_.erase(value_id);
}

string HloDataflowAnalysis::ToString() const {
  string out = StrCat("HloDataflowAnalysis, module ", module_->name(), "\n");
  StrAppend(&out, " Instruction value sets:\n");
@ -99,22 +115,98 @@ string HloDataflowAnalysis::ToString() const {
    }
  }
  StrAppend(&out, " HloValues:\n");
  for (const HloValue& value : values()) {
    StrAppend(&out, value.ToString(/*indent=*/4));
  }
  StrAppend(&out, " Phi resolutions:\n");
  for (const HloValue& value : values()) {
    if (value.is_phi()) {
      const HloValue* resolved_value = ResolvePhi(value);
      StrAppend(&out, " ", value.ToShortString(), " => ",
                resolved_value == nullptr ? "UNKNOWN"
                                          : resolved_value->ToShortString(),
                "\n");
    }
  for (const HloValue* value : values()) {
    StrAppend(&out, value->ToString(/*indent=*/4));
  }
  return out;
}

bool HloDataflowAnalysis::Phi(
    HloInstruction* instruction,
    tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
  CHECK(ssa_form_);

  for (const InstructionValueSet* input : inputs) {
    DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
  }

  bool changed = false;
  for (auto& pair : GetInstructionValueSet(instruction)) {
    const ShapeIndex& index = pair.first;
    HloValueSet& value_set = pair.second;

    // Positions with phi values should never have more than one value in the
    // value set.
    CHECK_LE(value_set.values().size(), 1);
    const HloValue* current_value =
        value_set.values().size() == 1 ? value_set.values()[0] : nullptr;

    // Construct a vector of unique value IDs of the inputs.
    std::vector<HloValue::Id> input_value_ids;
    for (const InstructionValueSet* input : inputs) {
      for (const HloValue* value : input->element(index).values()) {
        input_value_ids.push_back(value->id());
      }
    }
    std::sort(input_value_ids.begin(), input_value_ids.end());
    input_value_ids.erase(
        std::unique(input_value_ids.begin(), input_value_ids.end()),
        input_value_ids.end());

    // Remove the existing phi value (if it exists). The phi can be its own
    // input, for example, in while body parameters where the body passes
    // through the parameter value.
    bool current_value_defined_here =
        (current_value != nullptr &&
         current_value->defining_instruction() == instruction &&
         current_value->defining_index() == index);
    if (current_value_defined_here) {
      CHECK(current_value->is_phi());
      auto it = std::find(input_value_ids.begin(), input_value_ids.end(),
                          current_value->id());
      if (it != input_value_ids.end()) {
        input_value_ids.erase(it);
      }
    }

    if (input_value_ids.empty()) {
      // A value set which has at least one element should never have its value
      // set reduced to zero elements. During dataflow, value sets can only go
      // from empty to non-empty, not the reverse.
      CHECK_EQ(value_set.values().size(), 0)
          << "Instruction " << instruction->name() << " at index " << index
          << " previously had non-empty value set. Value set: " << value_set;
    } else if (input_value_ids.size() == 1) {
      // Only a single value reaches this point. There should be no phi, and
      // this value set should contain this single value.
      const HloValue& new_value = GetValue(input_value_ids[0]);
      if (current_value == nullptr) {
        value_set.Clear();
        value_set.AddValue(&new_value);
        changed = true;
      } else if (current_value != &new_value) {
        if (current_value_defined_here) {
          // Remove the existing phi.
          DeleteHloValue(current_value->id());
        }
        value_set.Clear();
        value_set.AddValue(&new_value);
        changed = true;
      }
    } else {
      // Multiple distinct values reach this point. A phi value is
      // necessary.
      CHECK_GT(input_value_ids.size(), 1);
      if (current_value == nullptr || !current_value->is_phi()) {
        value_set.Clear();
        value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));
        changed = true;
      }
    }
  }
  return changed;
}

const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const {
  return values_.at(value_id);
}
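Phi() above distinguishes three cases by the number of distinct inputs left after dropping the phi's own id. A condensed stand-alone sketch of that decision, with ints standing in for HloValue ids (a simplification; the real code also creates and deletes phi values in the value set):

#include <algorithm>
#include <vector>

// Returns the contents of a value set after an SSA phi merge of 'inputs',
// given the id currently in the set ('current', or -1 if none) and whether
// that id is a phi defined at this position. 'fresh_phi_id' is the id a
// newly minted phi would get.
std::vector<int> PhiMerge(std::vector<int> inputs, int current,
                          bool current_is_phi_here, int fresh_phi_id) {
  // Deduplicate the input ids.
  std::sort(inputs.begin(), inputs.end());
  inputs.erase(std::unique(inputs.begin(), inputs.end()), inputs.end());
  // A while-body phi can feed itself through the back edge; ignore that input.
  if (current_is_phi_here) {
    inputs.erase(std::remove(inputs.begin(), inputs.end(), current),
                 inputs.end());
  }
  if (inputs.empty()) return {};               // nothing reaches here yet
  if (inputs.size() == 1) return {inputs[0]};  // one input: forward it, no phi
  // Multiple inputs: keep the existing phi, or mint a new one.
  return {current_is_phi_here ? current : fresh_phi_id};
}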
@ -142,129 +234,6 @@ HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) {
  return GetValueSet(position.instruction, position.index);
}

void HloDataflowAnalysis::UpdateAfterChangingOperand(
    HloInstruction* instruction, HloInstruction* old_operand,
    HloInstruction* new_operand) {
  CHECK(std::find(instruction->operands().begin(),
                  instruction->operands().end(),
                  new_operand) != instruction->operands().end());
  VLOG(1) << "UpdateAfterChangingOperand(" << instruction->name() << ", "
          << old_operand->name() << " => " << new_operand->name() << ")";

  std::vector<HloInstruction*> to_update = {instruction};

  // If the instruction calls any computations, then add the parameters of the
  // called computations to capture any changes to the dataflow into the
  // subcomputation introduced by the new operand.
  for (HloComputation* computation : instruction->called_computations()) {
    to_update.insert(to_update.end(),
                     computation->parameter_instructions().begin(),
                     computation->parameter_instructions().end());
  }

  UpdateInstructionsAndPropagate(to_update);

  // The uses of the values in the old and new operand may have changed. Uses
  // of other HloValues are updated in UpdateInstructionsAndPropagate.
  for (auto& pair : GetInstructionValueSet(old_operand)) {
    for (const HloValue* value : pair.second.values()) {
      GetValue(value->id()).RecomputeUses();
    }
  }
  for (auto& pair : GetInstructionValueSet(new_operand)) {
    for (const HloValue* value : pair.second.values()) {
      GetValue(value->id()).RecomputeUses();
    }
  }

  TF_DCHECK_OK(VerifyAgainstReference());
}

void HloDataflowAnalysis::UpdateAfterChangingRoot(HloInstruction* old_root,
                                                  HloInstruction* new_root) {
  VLOG(1) << "UpdateAfterChangingRoot(" << old_root->name() << " => "
          << new_root->name() << ")";

  CHECK_EQ(new_root, new_root->parent()->root_instruction());
  CHECK_EQ(new_root->parent(), old_root->parent());

  std::vector<HloInstruction*> to_update = {old_root, new_root};

  const CallGraphNode& call_graph_node =
      call_graph_->GetNode(new_root->parent());
  for (const CallSite& callsite : call_graph_node.caller_callsites()) {
    if (callsite.instruction()->opcode() == HloOpcode::kCall) {
      to_update.push_back(callsite.instruction());
    } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
      // Add the while itself, and the body and condition parameters.
      to_update.push_back(callsite.instruction());
      to_update.push_back(
          callsite.instruction()->while_body()->parameter_instruction(0));
      to_update.push_back(
          callsite.instruction()->while_condition()->parameter_instruction(0));
    }
  }

  UpdateInstructionsAndPropagate(to_update);

  TF_DCHECK_OK(VerifyAgainstReference());
}

const HloValue* HloDataflowAnalysis::ResolvePhi(const HloValue& phi) const {
  CHECK(phi.is_phi());

  tensorflow::gtl::FlatSet<const HloValue*> visited;
  std::queue<const HloValue*> worklist;
  auto add_to_worklist = [&worklist, &visited](const HloValue* v) {
    if (visited.insert(v).second) {
      // 'v' was not previously in visited.
      worklist.push(v);
    }
  };
  add_to_worklist(&phi);

  const HloValue* resolved_value = nullptr;
  while (!worklist.empty()) {
    const HloValue* value = worklist.front();
    worklist.pop();

    if (!value->is_phi()) {
      if (resolved_value == nullptr) {
        resolved_value = value;
      } else if (resolved_value != value) {
        return nullptr;
      }
    } else {
      for (const HloValue* input : phi_inputs_.at(value)) {
        add_to_worklist(input);
      }
    }
  }
  return resolved_value;
}

void HloDataflowAnalysis::UpdatePhiInputs(
    const HloInstruction* instruction,
    tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
  CHECK(ssa_form_);
  for (auto& pair : GetInstructionValueSet(instruction)) {
    const ShapeIndex& index = pair.first;
    const HloValue& phi_value = GetUniqueValueAt(instruction, index);
    auto& phi_inputs = phi_inputs_.at(&phi_value);
    phi_inputs.clear();
    for (const InstructionValueSet* input : inputs) {
      for (const HloValue* value : input->element(index).values()) {
        // The number of phi inputs is typically 2, and virtually always very
        // small.
        if (std::find(phi_inputs.begin(), phi_inputs.end(), value) ==
            phi_inputs.end()) {
          phi_inputs.push_back(value);
        }
      }
    }
  }
}

bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
  CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
  const InstructionValueSet& operand_set =
@ -380,8 +349,7 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
  }

  if (ssa_form_ && called_from_while) {
    UpdatePhiInputs(parameter, inputs);
    return false;
    return Phi(parameter, inputs);
  } else {
    return GetInstructionValueSet(parameter).AssignUnionOf(inputs);
  }
@ -439,8 +407,7 @@ bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
      &GetInstructionValueSet(xla_while->while_body()->root_instruction()),
      &GetInstructionValueSet(xla_while->operand(0))};
  if (ssa_form_) {
    UpdatePhiInputs(xla_while, inputs);
    return false;
    return Phi(xla_while, inputs);
  } else {
    return GetInstructionValueSet(xla_while).AssignUnionOf(inputs);
  }
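The removed ResolvePhi was a worklist walk over phi inputs that succeeds only when exactly one non-phi value is reachable. The same logic over a hypothetical integer graph (ids assumed non-negative; -1 signals no unique non-phi value):

#include <map>
#include <queue>
#include <set>
#include <vector>

// phi_inputs maps a phi id to its input ids; non-phi ids are absent from the
// map. Returns the unique non-phi id transitively reachable through phis, or
// -1 when there is none or more than one -- the same contract as ResolvePhi,
// with ints standing in for HloValues.
int ResolvePhi(int phi, const std::map<int, std::vector<int>>& phi_inputs) {
  std::set<int> visited;
  std::queue<int> worklist;
  auto add = [&](int v) {
    if (visited.insert(v).second) worklist.push(v);  // enqueue once
  };
  add(phi);
  int resolved = -1;
  while (!worklist.empty()) {
    int value = worklist.front();
    worklist.pop();
    auto it = phi_inputs.find(value);
    if (it == phi_inputs.end()) {  // a non-phi value
      if (resolved == -1) {
        resolved = value;
      } else if (resolved != value) {
        return -1;  // more than one distinct non-phi input
      }
    } else {
      for (int input : it->second) add(input);
    }
  }
  return resolved;
}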
@ -487,38 +454,7 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
    VLOG(3) << "Worklist top: " << instruction->name();
    VLOG(3) << ToString();

    // The updating of the instruction value set below in
    // UpdateInstructionValueSet does not update HloValue::positions(). To
    // perform the positions() update, remove all positions in 'instruction'
    // from the HloValues in 'instruction's value set prior to the update, then
    // after the update add the new positions back in. There is likely a more
    // efficient way of doing this.
    for (auto& pair : GetInstructionValueSet(instruction)) {
      const ShapeIndex& index = pair.first;
      HloValueSet& value_set = pair.second;
      for (const HloValue* value : value_set.values()) {
        if (value->defining_instruction() != instruction) {
          // Use GetValue for a non-const HloValue reference.
          GetValue(value->id()).RemovePosition(instruction, index);
        }
      }
    }

    bool changed = UpdateInstructionValueSet(instruction);

    // Add the positions back in.
    for (auto& pair : GetInstructionValueSet(instruction)) {
      const ShapeIndex& index = pair.first;
      HloValueSet& value_set = pair.second;
      for (const HloValue* value : value_set.values()) {
        if (value->defining_instruction() != instruction) {
          // Use GetValue for a non-const HloValue reference.
          GetValue(value->id()).AddPosition(instruction, index);
        }
      }
    }

    if (!changed) {
    if (!UpdateInstructionValueSet(instruction)) {
      // No change to the instruction's value set.
      VLOG(4) << "No change.";
      continue;
@ -531,12 +467,16 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
    for (HloInstruction* user : instruction->users()) {
      worklist.push(user);

      // If user calls a computation, then the respective parameter(s) of the
      // computation need to be updated.
      // If user sequentially calls a computation, then the respective
      // parameter(s) of the computation need to be updated.
      for (HloComputation* called_computation : user->called_computations()) {
        for (int64 operand_number : user->OperandIndices(instruction)) {
          worklist.push(
              called_computation->parameter_instruction(operand_number));
        const CallGraphNode& call_graph_node =
            call_graph_->GetNode(called_computation);
        if (call_graph_node.context() == CallContext::kSequential) {
          for (int64 operand_number : user->OperandIndices(instruction)) {
            worklist.push(
                called_computation->parameter_instruction(operand_number));
          }
        }
      }
    }
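UpdateInstructionsAndPropagate is a standard worklist fixed point: pop an instruction, recompute its value set, and re-enqueue dependents only when something changed. A generic sketch of that control flow, with ints standing in for instructions and caller-supplied callbacks for the graph (an illustration of the technique, not the XLA code):

#include <functional>
#include <queue>
#include <vector>

// Generic worklist propagation to a fixed point. 'update' recomputes a
// node's state and returns true if it changed; 'users' lists dependents to
// revisit. Terminates when no update changes anything.
void PropagateToFixedPoint(
    std::vector<int> initial,
    const std::function<bool(int)>& update,
    const std::function<std::vector<int>(int)>& users) {
  std::queue<int> worklist;
  for (int n : initial) worklist.push(n);
  while (!worklist.empty()) {
    int node = worklist.front();
    worklist.pop();
    if (!update(node)) continue;  // no change: users need no revisit
    for (int user : users(node)) worklist.push(user);
  }
}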
@ -574,25 +514,10 @@ InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
}

Status HloDataflowAnalysis::InitializeInstructionValueSets() {
  // Gather the values to create before creating them. This is done because we
  // want to allocate the vector of values only once so references to elements
  // are stable.
  struct ValueToCreate {
    HloInstruction* instruction;
    ShapeIndex index;
    bool is_phi;
  };
  std::vector<ValueToCreate> values_to_create;

  for (const std::unique_ptr<HloComputation>& computation :
       module_->computations()) {
    const CallGraphNode& call_graph_node =
        call_graph_->GetNode(computation.get());
    bool called_from_while = std::any_of(
        call_graph_node.caller_callsites().begin(),
        call_graph_node.caller_callsites().end(), [](const CallSite& cs) {
          return cs.instruction()->opcode() == HloOpcode::kWhile;
        });

    for (const std::unique_ptr<HloInstruction>& instruction :
         computation->instructions()) {
@ -603,20 +528,22 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {

      // Lambda to set the value set to define all values in the output of the
      // instruction.
      auto define_all_values = [this, &instruction,
                                &values_to_create](bool is_phi = false) {
      auto define_all_values = [this, &instruction](bool is_phi = false) {
        for (auto& pair : GetInstructionValueSet(instruction.get())) {
          const ShapeIndex& index = pair.first;
          values_to_create.push_back({instruction.get(), index, is_phi});
          HloValue* value =
              NewHloValue(instruction.get(), index, /*is_phi=*/false);
          GetValueSet(instruction.get(), index).AddValue(value);
        }
      };

      // Lambda to set the value set to define only the top-level buffer in the
      // output of the instruction. Any other values flow from the operands of
      // the instruction (or from cross-computation dataflow).
      auto define_top_level_only = [this, &instruction, &values_to_create]() {
        values_to_create.push_back(
            {instruction.get(), /*index=*/{}, /*is_phi=*/false});
      auto define_top_level_only = [this, &instruction]() {
        HloValue* value =
            NewHloValue(instruction.get(), /*index=*/{}, /*is_phi=*/false);
        GetValueSet(instruction.get(), /*index=*/{}).AddValue(value);
      };

      switch (instruction->opcode()) {
@ -626,10 +553,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
          }
          break;
        case HloOpcode::kWhile:
          if (ssa_form_) {
            define_all_values(/*is_phi=*/true);
          }
          break;
        case HloOpcode::kCall:
        case HloOpcode::kGetTupleElement:
          // These instructions define no values. The values in their output
@ -654,10 +577,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
            // values in their output. Otherwise the values of the parameter
            // come from the caller (eg, operands to the kCall instruction).
            define_all_values();
          } else if (call_graph_node.context() == CallContext::kSequential &&
                     called_from_while && ssa_form_) {
            // Parameters of while bodies and conditions are phis.
            define_all_values(/*is_phi=*/true);
          }
          break;
        case HloOpcode::kCopy:
@ -674,164 +593,9 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
    }
  }

  // Reserve the vector ahead of time so references to elements are stable.
  values_.reserve(values_to_create.size());
  for (int64 i = 0; i < values_to_create.size(); ++i) {
    const ValueToCreate& to_create = values_to_create[i];
    values_.emplace_back(/*id=*/i, to_create.instruction, to_create.index,
                         to_create.is_phi);
    const HloValue& value = values_.back();
    GetValueSet(to_create.instruction, to_create.index).AddValue(&value);
    if (value.is_phi()) {
      phi_inputs_[&value] = {};
    }
  }
  return Status::OK();
}

bool HloDataflowAnalysis::IsDefinedBefore(const HloValue& a, const HloValue& b,
                                          const HloOrdering& ordering) const {
  // If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b'
  // is live into the module.
  if (b.defining_instruction()->parent() == module_->entry_computation() &&
      b.defining_instruction()->opcode() == HloOpcode::kParameter) {
    return false;
  }

  // Phi values require special handling. Because XLA does not have a phi
  // instruction, the defining instructions of phi values are
  // placeholders: either the subcomputation parameter (body or condition) or
  // the while instruction. However, the program point where these values are
  // logically defined does not necessarily coincide exactly with the program
  // point of these placeholder instructions. So we explicitly define the
  // following order for phi values:
  //
  //   body/condition parameter phi:
  //     Defined before all values defined in its computation excepting other
  //     phis.
  //
  //   while phi:
  //     Defined after all values defined in the condition or body.
  //
  auto is_body_or_condition_phi = [](const HloValue& v) {
    return v.is_phi() &&
           v.defining_instruction()->opcode() == HloOpcode::kParameter;
  };
  if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) &&
      call_graph_->InstructionIsNestedIn(b.defining_instruction(),
                                         a.defining_instruction()->parent())) {
    return true;
  }
  if (is_body_or_condition_phi(b) &&
      call_graph_->InstructionIsNestedIn(a.defining_instruction(),
                                         b.defining_instruction()->parent())) {
    return false;
  }

  // If 'b' is a while phi and 'a' is in the body or condition, then 'a'
  // executes before 'b'.
  if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile &&
      (call_graph_->InstructionIsNestedIn(
           a.defining_instruction(), b.defining_instruction()->while_body()) ||
       call_graph_->InstructionIsNestedIn(
           a.defining_instruction(),
           b.defining_instruction()->while_condition()))) {
    return true;
  }

  return ordering.ExecutesBefore(a.defining_instruction(),
                                 b.defining_instruction());
}

bool HloDataflowAnalysis::UseIsBeforeValueDefinition(
    const HloUse& use, const HloValue& value,
    const HloOrdering& ordering) const {
  if (ordering.ExecutesBefore(use.instruction, value.defining_instruction())) {
    return true;
  }

  // If the use is at the instruction where the value is defined, then the use
  // is before the def if the instruction allows buffer sharing (in place
  // computation).
  if (use.instruction == value.defining_instruction() &&
      CanShareOperandBufferWithUser(
          use.instruction->mutable_operand(use.operand_number),
          use.operand_index, value.defining_instruction(),
          value.defining_index())) {
    return true;
  }

  // The use at a while is an input to a phi, and logically occurs before
  // values are defined in the body or condition computations.
  if (use.instruction->opcode() == HloOpcode::kWhile) {
    const HloInstruction* xla_while = use.instruction;
    if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
                                           xla_while->while_body()) ||
        call_graph_->InstructionIsNestedIn(value.defining_instruction(),
                                           xla_while->while_condition())) {
      return true;
    }
  }

  // Similarly, if the value is defined at a while, it logically occurs after
  // any uses in the body or condition computations.
  if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
    CHECK(ssa_form_);
    const HloInstruction* xla_while = value.defining_instruction();
    if (call_graph_->InstructionIsNestedIn(use.instruction,
                                           xla_while->while_body()) ||
        call_graph_->InstructionIsNestedIn(use.instruction,
                                           xla_while->while_condition())) {
      return true;
    }
  }
  return false;
}

bool HloDataflowAnalysis::LiveRangeStrictlyBefore(
    const HloValue& a, const HloValue& b, const HloOrdering& ordering) const {
  VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
          << ", b = " << b.ToShortString() << ")";
  if (!IsDefinedBefore(a, b, ordering)) {
    VLOG(4) << "a not defined before b";
    return false;
  }

  // Live-out values from the module can never have ranges strictly before any
  // other value.
  if (a.live_out_of_module()) {
    VLOG(4) << "a is live out of module";
    return false;
  }

  // Live-out values of computations can never have ranges strictly before any
  // other value in the computation (including values nested in
  // subcomputations).
  if (a.live_out_of_computation() &&
      call_graph_->InstructionIsNestedIn(b.defining_instruction(),
                                         a.defining_instruction()->parent())) {
    VLOG(4) << "a is live out of computation containing b";
    return false;
  }

  // All uses of 'a' must be before 'b' is defined.
  for (const HloUse& use : a.uses()) {
    if (!UseIsBeforeValueDefinition(use, b, ordering)) {
      VLOG(4) << "use of a (" << use << ") not before b is defined";
      return false;
    }
  }

  return true;
}

bool HloDataflowAnalysis::MayInterfere(const HloValue& a, const HloValue& b,
                                       const HloOrdering& ordering) const {
  // Buffers without disjoint liveness may interfere.
  return !LiveRangeStrictlyBefore(a, b, ordering) &&
         !LiveRangeStrictlyBefore(b, a, ordering);
}

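The removed MayInterfere reduces interference to two strictly-before checks. A toy illustration of that reduction with integer intervals standing in for live ranges (hugely simplified: the real ranges come from an HloOrdering, not integers):

#include <iostream>

// Liveness stand-in: a value's live range as [def, last_use].
struct Range { int def; int last_use; };

// a's range is strictly before b's if every use of a precedes b's definition.
bool LiveRangeStrictlyBefore(const Range& a, const Range& b) {
  return a.last_use < b.def;
}

// Two values may interfere unless one range is strictly before the other.
bool MayInterfere(const Range& a, const Range& b) {
  return !LiveRangeStrictlyBefore(a, b) && !LiveRangeStrictlyBefore(b, a);
}

int main() {
  std::cout << MayInterfere({0, 2}, {3, 5}) << "\n";  // 0: disjoint ranges
  std::cout << MayInterfere({0, 4}, {3, 5}) << "\n";  // 1: overlapping ranges
}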
/* static */
StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
    HloModule* module, bool ssa_form, bool bitcast_defines_value) {
@ -855,6 +619,33 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
  }
  dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions);

  // Add in positions to all values.
  for (const std::unique_ptr<HloComputation>& computation :
       module->computations()) {
    for (const std::unique_ptr<HloInstruction>& instruction :
         computation->instructions()) {
      for (const auto& pair :
           dataflow_analysis->GetInstructionValueSet(instruction.get())) {
        const ShapeIndex& index = pair.first;
        const HloValueSet& value_set = pair.second;
        for (const HloValue* value : value_set.values()) {
          if (value->defining_instruction() != instruction.get()) {
            dataflow_analysis->GetValue(value->id())
                .AddPosition(instruction.get(), index);
          }
        }
      }
    }
  }

  // Construct vector of values.
  dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size());
  for (auto& pair : dataflow_analysis->values_) {
    dataflow_analysis->values_vector_.push_back(&pair.second);
  }
  std::sort(dataflow_analysis->values_vector_.begin(),
            dataflow_analysis->values_vector_.end(), HloValue::IdLessThan);

  TF_DCHECK_OK(dataflow_analysis->Verify());

  XLA_VLOG_LINES(1, dataflow_analysis->ToString());
@ -865,14 +656,14 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
Status HloDataflowAnalysis::Verify() const {
  // Verify each HloValue appears in the value sets that the value's positions()
  // indicate.
  for (const HloValue& value : values()) {
    for (const HloPosition& position : value.positions()) {
  for (const HloValue* value : values()) {
    for (const HloPosition& position : value->positions()) {
      const HloValueSet& value_set = GetValueSet(position);
      TF_RET_CHECK(std::find(value_set.values().begin(),
                             value_set.values().end(),
                             &value) != value_set.values().end())
                             value) != value_set.values().end())
          << "Value set at position " << position << " does not contain value "
          << value.ToShortString();
          << value->ToShortString();
    }
  }

@ -898,75 +689,4 @@ Status HloDataflowAnalysis::Verify() const {
  return Status::OK();
}

Status HloDataflowAnalysis::VerifyAgainstReference() const {
  TF_RETURN_IF_ERROR(Verify());

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> reference,
                      Run(module_, ssa_form_, bitcast_defines_value_));
  TF_RETURN_IF_ERROR(reference->Verify());

  VLOG(2) << "This analysis:";
  XLA_VLOG_LINES(2, ToString());
  VLOG(2) << "Reference:";
  XLA_VLOG_LINES(2, reference->ToString());

  // Verify value sets in each position are identical.
  for (const auto& computation : module_->computations()) {
    for (const auto& instruction : computation->instructions()) {
      for (const auto& pair : GetInstructionValueSet(instruction.get())) {
        const ShapeIndex& index = pair.first;
        const HloValueSet& value_set = pair.second;
        const HloValueSet& reference_value_set =
            reference->GetValueSet(instruction.get(), index);

        auto value_in_set = [](const HloValue& v, const HloValueSet& vset) {
          return std::find_if(vset.values().begin(), vset.values().end(),
                              [&v](const HloValue* w) { return *w == v; }) !=
                 vset.values().end();
        };

        for (const HloValue* value : value_set.values()) {
          TF_RET_CHECK(value_in_set(*value, reference_value_set))
              << "Value " << value->ToShortString()
              << " does not exist in reference";
        }
        for (const HloValue* reference_value : reference_value_set.values()) {
          TF_RET_CHECK(value_in_set(*reference_value, value_set))
              << "Value " << reference_value->ToShortString()
              << " only exists in reference";
        }
      }
    }
  }

  // Verify all phis resolve identically and uses are identical.
  for (const HloValue& value : values()) {
    const HloValue& reference_value = reference->GetValueDefinedAt(
        value.defining_instruction(), value.defining_index());
    TF_RET_CHECK(value.is_phi() == reference_value.is_phi());
    if (value.is_phi()) {
      const HloValue* resolved_value = ResolvePhi(value);
      const HloValue* reference_resolved_value =
          reference->ResolvePhi(reference_value);
      if (resolved_value == nullptr) {
        TF_RET_CHECK(reference_resolved_value == nullptr);
      } else {
        TF_RET_CHECK(reference_resolved_value != nullptr);
        TF_RET_CHECK(*reference_resolved_value == *resolved_value);
      }
    }

    for (const HloUse& use : value.uses()) {
      TF_RET_CHECK(std::find(reference_value.uses().begin(),
                             reference_value.uses().end(),
                             use) != reference_value.uses().end());
    }
    for (const HloUse& reference_use : reference_value.uses()) {
      TF_RET_CHECK(std::find(value.uses().begin(), value.uses().end(),
                             reference_use) != value.uses().end());
    }
  }
  return Status::OK();
}

}  // namespace xla
@ -88,10 +88,10 @@ class HloDataflowAnalysis {
// given position.
const HloValueSet& GetValueSet(const HloInstruction* instruction,
const ShapeIndex& index = {}) const;
HloValueSet& GetValueSet(const HloInstruction* instruction,
const ShapeIndex& index = {});
const HloValueSet& GetValueSet(const HloPosition& position) const;
HloValueSet& GetValueSet(const HloPosition& position);
HloValueSet& GetValueSet(const HloInstruction* instruction,
const ShapeIndex& index = {});

// Return the unique value in the HloValueSet at the given instruction and
// shape index. CHECKs if the value set does not contain exactly one value.
@ -108,49 +108,11 @@ class HloDataflowAnalysis {
const HloValue& GetValue(HloValue::Id value_id) const;
HloValue& GetValue(HloValue::Id value_id);

// Returns whether the given values interfere assuming the given HLO
// ordering. Two values interfere if they may both be simultaneously live.
bool MayInterfere(const HloValue& a, const HloValue& b,
const HloOrdering& ordering) const;

// Overload which takes HloValue::Ids.
bool MayInterfere(HloValue::Id a, HloValue::Id b,
const HloOrdering& ordering) const {
return MayInterfere(GetValue(a), GetValue(b), ordering);
}
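
Two values may interfere exactly when neither one's live range ends strictly before the other's begins. A minimal member-style sketch of that reduction, using the LiveRangeStrictlyBefore helper declared further down in this class (a sketch only, not the committed implementation):

// Sketch: interference as the negation of live-range separation.
bool MayInterfereSketch(const HloValue& a, const HloValue& b,
                        const HloOrdering& ordering) const {
  // If neither value dies strictly before the other is defined, both may
  // be live at the same time.
  return !LiveRangeStrictlyBefore(a, b, ordering) &&
         !LiveRangeStrictlyBefore(b, a, ordering);
}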

// Return the total number of HloValues.
int64 value_count() const { return values_.size(); }

// Return a vector of all HloValues.
const std::vector<HloValue>& values() const { return values_; }

// Updates the dataflow after changing an operand of
// 'instruction'. Dataflow update is not possible if instructions have been
// added or removed from the graph.
void UpdateAfterChangingOperand(HloInstruction* instruction,
HloInstruction* old_operand,
HloInstruction* new_operand);

// Updates the dataflow after changing the root of a computation from
// 'old_root' to 'new_root'.
void UpdateAfterChangingRoot(HloInstruction* old_root,
HloInstruction* new_root);

// Returns the non-phi HloValue that is the unique (transitive) input to the
// given phi. If no such HloValue exists (there are multiple inputs to the
// phi) then nullptr is returned. This is computed by walking all the inputs
// of the given phi value until non-phi HloValue(s) are encountered.
const HloValue* ResolvePhi(const HloValue& phi) const;
const HloValue* ResolvePhi(const HloInstruction* instruction,
const ShapeIndex& index = {}) const {
return ResolvePhi(GetValueDefinedAt(instruction, index));
}
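
A sketch of the transitive walk the comment above describes, assuming the phi_inputs_ map documented further down supports at(); ResolvePhiSketch is a hypothetical stand-in, not the actual implementation:

const HloValue* ResolvePhiSketch(const HloValue& phi) const {
  std::unordered_set<const HloValue*> visited;
  std::vector<const HloValue*> worklist = {&phi};
  const HloValue* non_phi = nullptr;
  while (!worklist.empty()) {
    const HloValue* value = worklist.back();
    worklist.pop_back();
    if (!visited.insert(value).second) continue;  // Already visited.
    if (!value->is_phi()) {
      // A second distinct non-phi input makes the phi unresolvable.
      if (non_phi != nullptr && non_phi != value) return nullptr;
      non_phi = value;
    } else {
      for (const HloValue* input : phi_inputs_.at(value)) {
        worklist.push_back(input);
      }
    }
  }
  return non_phi;
}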

// Compare the dataflow analysis against a clean recomputation of the
// analysis. Returns an error status if there is a mismatch. Useful for
// verifying the correctness after updates to the analysis.
Status VerifyAgainstReference() const;
// Return a vector of all HloValues stably sorted by HloValue::Id.
const std::vector<const HloValue*>& values() const { return values_vector_; }

// Return the call graph used for computing the dataflow.
const CallGraph& call_graph() const { return *call_graph_; }
@ -161,6 +123,13 @@ class HloDataflowAnalysis {
HloDataflowAnalysis(HloModule* module, bool ssa_form,
bool bitcast_defines_value = false);

// Returns a new HloValue defined at the given instruction and shape index.
HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
bool is_phi = false);

// Delete the HloValue with the given ID.
void DeleteHloValue(HloValue::Id value_id);

// Constructs and initializes the InstructionValueSets of all instructions to
// contain exactly the HloValues defined by each instruction. These values can
// then be propagated throughout the HLO graph by calling
@ -187,10 +156,11 @@ class HloDataflowAnalysis {
void UpdateInstructionsAndPropagate(
tensorflow::gtl::ArraySlice<HloInstruction*> instructions);

// Sets the inputs of the given phi to given value(s).
void UpdatePhiInputs(
const HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);
// Return the result of the SSA Phi function applied to the given inputs at
// the given instruction. If skip_top_level is true, then the top level of the
// value set of 'instruction' is not modified.
bool Phi(HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);
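
In the non-SSA case the join degenerates to a set union of the inputs; SSA form instead installs a phi value when the inputs disagree. A sketch of the union case only, assuming HloValueSet is default-constructible (UnionSketch is illustrative, not part of the interface):

HloValueSet UnionSketch(
    tensorflow::gtl::ArraySlice<const HloValueSet*> inputs) {
  HloValueSet result;
  for (const HloValueSet* input : inputs) {
    for (const HloValue* value : input->values()) {
      result.AddValue(value);  // AddValue deduplicates existing entries.
    }
  }
  return result;
}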

// Updates the positions of the HloValues in the output of the given
// instruction. This should be called after the instruction value set of
@ -203,20 +173,6 @@ class HloDataflowAnalysis {
HloInstruction* instruction, const InstructionValueSet& new_value_set,
const InstructionValueSet* prev_value_set = nullptr);

// Returns true if the live range of the given value 'a' is strictly before
// the live range of value 'b' using the given HLO ordering.
bool LiveRangeStrictlyBefore(const HloValue& a, const HloValue& b,
const HloOrdering& ordering) const;

// Returns whether the value 'a' is defined before the value 'b' under the
// given ordering.
bool IsDefinedBefore(const HloValue& a, const HloValue& b,
const HloOrdering& ordering) const;

// Returns whether the given use is before the given value definition.
bool UseIsBeforeValueDefinition(const HloUse& use, const HloValue& value,
const HloOrdering& ordering) const;
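
These two predicates compose into the live-range check above; a member-style sketch of that composition under the assumption that it uses exactly the declared helpers:

bool LiveRangeStrictlyBeforeSketch(const HloValue& a, const HloValue& b,
                                   const HloOrdering& ordering) const {
  if (!IsDefinedBefore(a, b, ordering)) return false;
  // Every use of 'a' must come before 'b' is even defined.
  for (const HloUse& use : a.uses()) {
    if (!UseIsBeforeValueDefinition(use, b, ordering)) return false;
  }
  return true;
}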

// Verify various invariants of the dataflow analysis.
Status Verify() const;

@ -226,19 +182,19 @@ class HloDataflowAnalysis {

std::unique_ptr<CallGraph> call_graph_;

// Array of all values in the module. This is allocated once at analysis
// construction time so HloValue references are stable. Updates to the
// analysis via UpdateAfterChangingOperand and UpdateAfterChangingRoot do not
// result in the creation or destruction of any HloValues.
std::vector<HloValue> values_;

// Map holding the inputs to each phi value in the module. Used by ResolvePhi.
tensorflow::gtl::FlatMap<const HloValue*,
tensorflow::gtl::InlinedVector<const HloValue*, 2>>
phi_inputs_;
// The map of all HloValues in the module. We pass around pointers to the
// mapped HloValues, so the underlying container must keep them valid despite
// mutations touching other map entries.
std::unordered_map<HloValue::Id, HloValue> values_;
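
A sketch of how NewHloValue can exploit that stability guarantee; the HloValue constructor arguments are assumed to mirror NewHloValue's parameters:

HloValue* NewHloValueSketch(HloInstruction* instruction,
                            const ShapeIndex& index, bool is_phi) {
  const HloValue::Id id = next_value_id_++;
  // std::unordered_map never relocates its mapped values, so this pointer
  // stays valid across later insertions and DeleteHloValue calls on other
  // ids.
  auto result = values_.emplace(
      std::piecewise_construct, std::forward_as_tuple(id),
      std::forward_as_tuple(id, instruction, index, is_phi));
  return &result.first->second;
}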

// A map from instruction to InstructionValueSet.
std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_;

// A vector containing all HloValues sorted by HloValue::Id.
std::vector<const HloValue*> values_vector_;

// The Id to use for the next HloValue.
HloValue::Id next_value_id_ = 0;
};

} // namespace xla
@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"

@ -44,8 +43,8 @@ class HloDataflowAnalysisTest : public HloTestBase,

// Run dataflow analysis on the member module. For convenience returns a
// reference to the generated analysis stored in analysis_.
HloDataflowAnalysis& RunAnalysis(bool ssa_form,
bool bitcast_defines_value = false) {
const HloDataflowAnalysis& RunAnalysis(bool ssa_form,
bool bitcast_defines_value = false) {
analysis_ =
HloDataflowAnalysis::Run(module_.get(), ssa_form, bitcast_defines_value)
.ConsumeValueOrDie();
@ -71,8 +70,8 @@ class HloDataflowAnalysisTest : public HloTestBase,
const HloInstruction* b) {
EXPECT_FALSE(ShapeUtil::IsTuple(a->shape()));
EXPECT_FALSE(ShapeUtil::IsTuple(b->shape()));
return analysis_->MayInterfere(analysis_->GetValueDefinedAt(a),
analysis_->GetValueDefinedAt(b), ordering);
return ordering.MayInterfere(analysis_->GetValueDefinedAt(a),
analysis_->GetValueDefinedAt(b));
}

std::unique_ptr<HloModule> module_;
@ -499,37 +498,26 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) {
EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module());

if (ssa_form) {
// While instruction should define phi values. The value at index {0} is a
// degenerate phi with a single input 'constant1'.
EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi());
EXPECT_EQ(analysis.ResolvePhi(xla_while, /*index=*/{0}),
&analysis.GetValueDefinedAt(constant1));
EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{0}).is_phi());
EXPECT_EQ(analysis.ResolvePhi(body_param, /*index=*/{0}),
&analysis.GetValueDefinedAt(constant1));
EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{0}).is_phi());
EXPECT_EQ(analysis.ResolvePhi(cond_param, /*index=*/{0}),
&analysis.GetValueDefinedAt(constant1));
// Element 0 of the tuple passed through the body so no phi value is
// defined.
EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));

// Element 1 of the tuple should be a phi value.
EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());
EXPECT_EQ(analysis.ResolvePhi(xla_while, /*index=*/{1}), nullptr);
EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi());
EXPECT_EQ(analysis.ResolvePhi(body_param, /*index=*/{1}), nullptr);
EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi());
EXPECT_EQ(analysis.ResolvePhi(cond_param, /*index=*/{1}), nullptr);

EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
UnorderedElementsAre(HloUse{xla_while, 0, {0}}));
EXPECT_THAT(
analysis.GetValueDefinedAt(constant1).uses(),
UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{xla_while, 0, {0}}));

EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0})
.live_out_of_module());
// Constant1 passes through the body and out of the module.
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
.live_out_of_module());

@ -613,20 +601,15 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) {
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

if (ssa_form) {
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while2).live_out_of_module());
EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
} else {
// Element 0 is passed through all the while instructions and out of the
// module.
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while0, /*index=*/{0}),
analysis.GetValueDefinedAt(constant1));
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while1, /*index=*/{0}),
analysis.GetValueDefinedAt(constant1));
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while2, /*index=*/{0}),
analysis.GetValueDefinedAt(constant1));
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
}
// Element 0 is passed through all the while instructions and out of the
// module.
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while0, /*index=*/{0}),
analysis.GetValueDefinedAt(constant1));
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while1, /*index=*/{0}),
analysis.GetValueDefinedAt(constant1));
EXPECT_EQ(analysis.GetUniqueValueAt(xla_while2, /*index=*/{0}),
analysis.GetValueDefinedAt(constant1));
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
}

TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
@ -705,13 +688,18 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
if (ssa_form) {
EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_param, /*index=*/{1}));
EXPECT_TRUE(
analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi());
EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0}));
EXPECT_TRUE(
analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi());

// Element 0 of the nested while is %negate.
EXPECT_FALSE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0}));
EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
// Element 1 is a phi value (join of %add and %constant2).
EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{1}));
EXPECT_TRUE(
analysis.GetValueDefinedAt(nested_while, /*index=*/{1}).is_phi());
@ -724,8 +712,6 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
EXPECT_TRUE(
analysis.GetValueDefinedAt(entry_while, /*index=*/{1}).is_phi());
} else {
EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{1}),
UnorderedElementsAre(analysis.GetValueDefinedAt(add),
analysis.GetValueDefinedAt(constant2)));
@ -1496,256 +1482,6 @@ TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) {
EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log));
}

TEST_P(HloDataflowAnalysisTest, UpdateAnalysisForWhile) {
// Test updating dataflow after modifying a module with an array shaped while:
//
// body(F32[] %param):
// %negate = Negate(%param)
//
// condition(F32[] %param):
// return Constant(false)
//
// entry:
// %constant = Constant(1.0)
// %exp = Exp(%constant)
// return While(%exp, body, condition)
//
auto body_builder = HloComputation::Builder("body");
auto body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
scalar_shape_, HloOpcode::kNegate, body_param));
HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());

// Condition computation trivially returns a constant "false".
auto cond_builder = HloComputation::Builder("condition");
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
cond_builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());

auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
auto xla_while = builder.AddInstruction(
HloInstruction::CreateWhile(scalar_shape_, condition, body, exp));
module_->AddEntryComputation(builder.Build());

bool ssa_form = GetParam();
HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

// Sanity check the initial dataflow analysis before transforming the HLO
// graph.
if (ssa_form) {
EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param));
EXPECT_TRUE(analysis.GetValueDefinedAt(body_param).is_phi());
EXPECT_EQ(analysis.ResolvePhi(body_param), nullptr);

EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param));
EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param).is_phi());
EXPECT_EQ(analysis.ResolvePhi(cond_param), nullptr);

EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module());
EXPECT_FALSE(analysis.GetValueDefinedAt(negate).live_out_of_module());
} else {
EXPECT_THAT(HloValuesAt(body_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(exp),
analysis.GetValueDefinedAt(negate)));
EXPECT_THAT(HloValuesAt(cond_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(exp),
analysis.GetValueDefinedAt(negate)));
EXPECT_THAT(HloValuesAt(xla_while),
UnorderedElementsAre(analysis.GetValueDefinedAt(exp),
analysis.GetValueDefinedAt(negate)));

EXPECT_TRUE(analysis.GetValueDefinedAt(negate).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(exp).live_out_of_module());
}

// Set the body root to the body_param. Previously it was Negate(body_param).
body->set_root_instruction(body_param);

// Prior to updating, verify that the dataflow analysis is no longer valid.
Status verify_status = analysis.VerifyAgainstReference();
EXPECT_FALSE(verify_status.ok());

analysis.UpdateAfterChangingRoot(/*old_root=*/negate,
/*new_root=*/body_param);

// Analysis should be valid after the update.
TF_EXPECT_OK(analysis.VerifyAgainstReference());

if (ssa_form) {
// The phis should now be resolvable as 'exp' is passed through the body
// transparently.
EXPECT_EQ(analysis.ResolvePhi(body_param),
&analysis.GetValueDefinedAt(exp));
EXPECT_EQ(analysis.ResolvePhi(cond_param),
&analysis.GetValueDefinedAt(exp));
EXPECT_EQ(analysis.ResolvePhi(xla_while), &analysis.GetValueDefinedAt(exp));
EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module());
} else {
EXPECT_THAT(HloValuesAt(body_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(exp)));
EXPECT_THAT(HloValuesAt(cond_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(exp)));
EXPECT_THAT(HloValuesAt(xla_while),
UnorderedElementsAre(analysis.GetValueDefinedAt(exp)));
EXPECT_TRUE(analysis.GetValueDefinedAt(exp).live_out_of_module());
}
EXPECT_FALSE(analysis.GetValueDefinedAt(negate).live_out_of_module());

// Now replace the operand of the while with %constant (was %exp).
TF_ASSERT_OK(exp->ReplaceUseWith(xla_while, constant));
analysis.UpdateAfterChangingOperand(xla_while, /*old_operand=*/exp,
/*new_operand=*/constant);

// Verify that the dataflow is correct.
TF_ASSERT_OK(analysis.VerifyAgainstReference());

if (ssa_form) {
// The phis now resolve to 'constant'.
EXPECT_EQ(analysis.ResolvePhi(body_param),
&analysis.GetValueDefinedAt(constant));
EXPECT_EQ(analysis.ResolvePhi(cond_param),
&analysis.GetValueDefinedAt(constant));
EXPECT_EQ(analysis.ResolvePhi(xla_while),
&analysis.GetValueDefinedAt(constant));
} else {
EXPECT_THAT(HloValuesAt(body_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(constant)));
EXPECT_THAT(HloValuesAt(cond_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(constant)));
EXPECT_THAT(HloValuesAt(xla_while),
UnorderedElementsAre(analysis.GetValueDefinedAt(constant)));
EXPECT_TRUE(analysis.GetValueDefinedAt(constant).live_out_of_module());
}

// And finally make the negate the root of the body again.
body->set_root_instruction(negate);
analysis.UpdateAfterChangingRoot(/*old_root=*/body_param,
/*new_root=*/negate);

// Verify that the dataflow is correct.
TF_ASSERT_OK(analysis.VerifyAgainstReference());

if (ssa_form) {
// Phis should no longer be resolvable.
EXPECT_EQ(analysis.ResolvePhi(body_param), nullptr);
EXPECT_EQ(analysis.ResolvePhi(cond_param), nullptr);
EXPECT_EQ(analysis.ResolvePhi(xla_while), nullptr);
} else {
EXPECT_THAT(HloValuesAt(body_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(constant),
analysis.GetValueDefinedAt(negate)));
EXPECT_THAT(HloValuesAt(cond_param),
UnorderedElementsAre(analysis.GetValueDefinedAt(constant),
analysis.GetValueDefinedAt(negate)));
EXPECT_THAT(HloValuesAt(xla_while),
UnorderedElementsAre(analysis.GetValueDefinedAt(constant),
analysis.GetValueDefinedAt(negate)));

EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(negate).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(constant).live_out_of_module());
}

// After the updates, verify that the dataflow is correct.
TF_ASSERT_OK(analysis.VerifyAgainstReference());
}

TEST_P(HloDataflowAnalysisTest, UpdateOfATupleSelect) {
// Test changing the operands of kSelects of a tuple value and updating the
// dataflow.
auto builder = HloComputation::Builder(TestName());
auto pred = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
auto a = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
auto b = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
auto c = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
auto d = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(4.0)));
auto tuple_a = builder.AddInstruction(HloInstruction::CreateTuple({a}));
auto tuple_b = builder.AddInstruction(HloInstruction::CreateTuple({b}));
auto tuple_c = builder.AddInstruction(HloInstruction::CreateTuple({c}));
auto tuple_d = builder.AddInstruction(HloInstruction::CreateTuple({d}));
const Shape tuple_shape = tuple_a->shape();
auto select_aa = builder.AddInstruction(HloInstruction::CreateTernary(
tuple_shape, HloOpcode::kSelect, pred, tuple_a, tuple_a));
auto select_ab = builder.AddInstruction(HloInstruction::CreateTernary(
tuple_shape, HloOpcode::kSelect, pred, tuple_a, tuple_b));
auto select_cd = builder.AddInstruction(HloInstruction::CreateTernary(
tuple_shape, HloOpcode::kSelect, pred, tuple_c, tuple_d));
auto select_abcd = builder.AddInstruction(HloInstruction::CreateTernary(
tuple_shape, HloOpcode::kSelect, pred, select_ab, select_cd));

module_->AddEntryComputation(builder.Build());

bool ssa_form = GetParam();
HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

// Sanity check dataflow before changing the graph and updating.
EXPECT_THAT(HloValuesAt(select_aa, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(a)));
EXPECT_THAT(HloValuesAt(select_ab, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(a),
analysis.GetValueDefinedAt(b)));
EXPECT_THAT(HloValuesAt(select_cd, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(c),
analysis.GetValueDefinedAt(d)));
EXPECT_THAT(HloValuesAt(select_abcd, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(a),
analysis.GetValueDefinedAt(b),
analysis.GetValueDefinedAt(c),
analysis.GetValueDefinedAt(d)));
EXPECT_TRUE(analysis.GetValueDefinedAt(a).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(b).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(c).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(d).live_out_of_module());

// Set the rhs of 'select_aa' to be 'd'.
TF_ASSERT_OK(select_aa->ReplaceOperandWith(2, tuple_d));
analysis.UpdateAfterChangingOperand(select_aa, /*old_operand=*/tuple_a,
/*new_operand=*/tuple_d);

// Verify that the dataflow is correct.
TF_ASSERT_OK(analysis.VerifyAgainstReference());

EXPECT_THAT(HloValuesAt(select_aa, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(a),
analysis.GetValueDefinedAt(d)));

// Set the lhs of 'select_cd' to be 'a'.
TF_ASSERT_OK(select_cd->ReplaceOperandWith(1, tuple_a));
analysis.UpdateAfterChangingOperand(select_cd, /*old_operand=*/tuple_c,
/*new_operand=*/tuple_a);

// Verify that the dataflow is correct.
TF_ASSERT_OK(analysis.VerifyAgainstReference());

EXPECT_THAT(HloValuesAt(select_cd, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(a),
analysis.GetValueDefinedAt(d)));
EXPECT_THAT(HloValuesAt(select_abcd, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(a),
analysis.GetValueDefinedAt(b),
analysis.GetValueDefinedAt(d)));
EXPECT_TRUE(analysis.GetValueDefinedAt(a).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(b).live_out_of_module());
EXPECT_FALSE(analysis.GetValueDefinedAt(c).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(d).live_out_of_module());

// After the updates, verify that the dataflow is correct.
TF_ASSERT_OK(analysis.VerifyAgainstReference());
}

INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation,
HloDataflowAnalysisTest,
::testing::Values(false, true));
@ -561,13 +561,21 @@ tooltip = " ";
}

string comp_body = DumpComputation(subcomp);
string computation =
Printf(computation_fmt, id, style, subcomp_label, comp_body, id);

// Add an edge from the subcomputation to its parent node. If subcomp
// belongs to a fusion node, it's drawn in place of the fusion instruction, so
// there's no need to link those.
if (parent_instr->opcode() != HloOpcode::kFusion) {
if (parent_instr->opcode() == HloOpcode::kFusion) {
// Dump any nested fusion nodes.
for (const auto& subcomp_instr : subcomp->instructions()) {
if (subcomp_instr->opcode() == HloOpcode::kFusion) {
StrAppend(
&comp_body,
DumpSubcomputation(subcomp_instr->fused_instructions_computation(),
subcomp_instr.get()));
}
}
} else {
// Add an edge from the subcomputation to its parent node. If subcomp
// belongs to a fusion node, it's drawn in place of the fusion instruction,
// so there's no need to link those.
edge_ids_.insert(
{{subcomp->root_instruction(), parent_instr}, next_edge_id_++});
const char* edge_fmt =
@ -578,6 +586,9 @@ tooltip = " ";
subcomp->name(), parent_instr->name()));
}

string computation =
Printf(computation_fmt, id, style, subcomp_label, comp_body, id);

return computation;
}

@ -793,13 +793,6 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
}
}

for (HloComputation* computation :
instruction_to_fuse->called_computations()) {
if (std::find(called_computations_.begin(), called_computations_.end(),
computation) == called_computations_.end()) {
called_computations_.push_back(computation);
}
}
VLOG(2) << "New clone:\n" << clone->ToString();
return clone;
}
@ -1669,6 +1662,21 @@ string HloInstruction::ExtendedOpcodeStr() const {

string HloInstruction::ToString(bool compact_operands,
bool include_metadata) const {
string result =
StrCat(name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ",
ExtendedOpcodeStr(), "(", OperandsToString(compact_operands), ")");
for (const string& extra : ExtraAttributesToString()) {
StrAppend(&result, ", ", extra);
}
if (include_metadata &&
(!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
!metadata_.source_file().empty())) {
StrAppend(&result, " # metadata=", metadata_.ShortDebugString());
}
return result;
}

string HloInstruction::OperandsToString(bool compact) const {
string operands;
if (opcode() == HloOpcode::kConstant) {
// For constants, show the actual value in place of an empty operand list.
@ -1697,12 +1705,12 @@ string HloInstruction::ToString(bool compact_operands,
} else {
tensorflow::gtl::ArraySlice<HloInstruction*> slice(operands_);
const int64 kMaxOperandsToShowIfCompact = 4;
if (compact_operands && slice.size() > kMaxOperandsToShowIfCompact) {
if (compact && slice.size() > kMaxOperandsToShowIfCompact) {
slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
}
operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) {
*out += ShapeUtil::HumanStringWithLayout(operand->shape());
if (!compact_operands) {
if (!compact) {
StrAppend(out, " ", operand->name());
}
});
@ -1711,15 +1719,19 @@ string HloInstruction::ToString(bool compact_operands,
StrAppend(&operands, ", ...(+", remaining, ")");
}
}
string extra;
return operands;
}

std::vector<string> HloInstruction::ExtraAttributesToString() const {
std::vector<string> extra;
if (CanHaveDimensionsField()) {
StrAppend(&extra, ", dimensions={", Join(dimensions(), ","), "}");
extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}"));
}
if (window_ != nullptr) {
StrAppend(&extra, ", ", window_util::ToString(*window_));
extra.push_back(window_util::ToString(*window_));
}
if (padding_config_ != nullptr) {
StrAppend(&extra, ", padding=", padding_config_->ShortDebugString());
extra.push_back(StrCat("padding=", padding_config_->ShortDebugString()));
}
if (!slice_starts_.empty() && !slice_limits_.empty()) {
std::vector<string> bounds;
@ -1728,45 +1740,38 @@ string HloInstruction::ToString(bool compact_operands,
bounds.push_back(
StrCat("[", slice_starts_[i], ":", slice_limits_[i], "]"));
}
StrAppend(&extra, ", slice={", Join(bounds, ", "), "}");
extra.push_back(StrCat("slice={", Join(bounds, ", "), "}"));
}

if (convolution_dimension_numbers_ != nullptr) {
StrAppend(&extra, ", ", ConvolutionDimensionNumbersToString());
extra.push_back(ConvolutionDimensionNumbersToString());
}

if (opcode() == HloOpcode::kWhile) {
StrAppend(&extra, ", condition=", while_condition()->name());
StrAppend(&extra, ", body=", while_body()->name());
extra.push_back(StrCat("condition=", while_condition()->name()));
extra.push_back(StrCat("body=", while_body()->name()));
} else if (opcode() == HloOpcode::kSelectAndScatter) {
StrAppend(&extra, ", select=", select()->name());
StrAppend(&extra, ", scatter=", scatter()->name());
extra.push_back(StrCat("select=", select()->name()));
extra.push_back(StrCat("scatter=", scatter()->name()));
} else if (!called_computations().empty()) {
StrAppend(&extra, ", calls=",
Join(called_computations(), ", ",
[](string* out, const HloComputation* computation) {
StrAppend(out, computation->name());
}));
extra.push_back(StrCat(
"calls=", Join(called_computations(), ", ",
[](string* out, const HloComputation* computation) {
StrAppend(out, computation->name());
})));
}

if (opcode() == HloOpcode::kGetTupleElement) {
StrAppend(&extra, ", index=", tuple_index());
extra.push_back(StrCat("index=", tuple_index()));
}
if (!control_successors_.empty()) {
StrAppend(
&extra, ", control-successors=",
extra.push_back(StrCat(
"control-successors=",
Join(control_successors_, ", ", [](string* out, HloInstruction* succ) {
StrAppend(out, succ->name());
}));
})));
}
if (include_metadata &&
(!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
!metadata_.source_file().empty())) {
StrAppend(&extra, " # metadata=", metadata_.ShortDebugString());
}

return StrCat(name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ",
ExtendedOpcodeStr(), "(", operands, ")", extra);
return extra;
}
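
With the operand list and attribute rendering split out, a caller can assemble custom dumps from the same pieces; a sketch using only the two new public helpers (RenderSketch and its formatting choices are illustrative):

string RenderSketch(const HloInstruction& instruction) {
  // Name plus a compact operand list, then each op-specific attribute.
  string result = StrCat(instruction.name(), " = (",
                         instruction.OperandsToString(/*compact=*/true), ")");
  for (const string& attribute : instruction.ExtraAttributesToString()) {
    StrAppend(&result, ", ", attribute);
  }
  return result;
}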

string HloInstruction::ToShortString() const {
@ -548,6 +548,14 @@ class HloInstruction {
string ToString(bool compact_operands = false,
bool include_metadata = true) const;

// Components of the ToString() representation:

// Returns a string representation of the operand list.
string OperandsToString(bool compact) const;

// Returns string representation of op-specific attributes.
std::vector<string> ExtraAttributesToString() const;

string ToStringNoMetadata() const { return ToString(false, false); }

// As ToString, but returns a shorter string.
@ -797,8 +805,7 @@ class HloInstruction {
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> operands);

// Returns the computations this instruction calls (if any). This includes
// computations called by fused instructions inside of a fusion instruction.
// Returns the computations this instruction directly calls (if any).
const std::vector<HloComputation*>& called_computations() const {
return called_computations_;
}
@ -758,16 +758,13 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
auto* fusion = computation->CreateFusionInstruction(
{map_3_y}, HloInstruction::FusionKind::kLoop);
auto* fused_computation = fusion->fused_instructions_computation();
EXPECT_THAT(fusion->called_computations(),
ElementsAre(fused_computation, computation_y));
EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));

fusion->FuseInstruction(map_2_x);
EXPECT_THAT(fusion->called_computations(),
ElementsAre(fused_computation, computation_y, computation_x));
EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));

fusion->FuseInstruction(map_1_x);
EXPECT_THAT(fusion->called_computations(),
ElementsAre(fused_computation, computation_y, computation_x));
EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));
}

TEST_F(HloInstructionTest, ComplexFusionOp) {
@ -19,6 +19,7 @@ limitations under the License.
#include <string>

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
@ -218,6 +219,94 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) {
EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param));
}

TEST_F(HloOrderingTest, ValuesInWhileComputations) {
// Tests the ordering of values (defined by dataflow analysis) in the body and
// condition of a while instruction. HLO code:
//
// body(F32[] %param):
// %negate = Negate(%param)
//
// condition(F32[] %param):
// %convert = Convert<PRED>(%param)
//
// entry:
// %constant = Constant(1.0)
// %while = While(%constant, body, condition)
// %add = Add(%constant, %while)
//
auto module = CreateNewModule();
const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});

auto body_builder = HloComputation::Builder("body");
auto body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape, "body_param"));
auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
scalar_shape, HloOpcode::kNegate, body_param));
HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());

auto cond_builder = HloComputation::Builder("condition");
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape, "cond_param"));
auto convert = cond_builder.AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::MakeShape(xla::PRED, {}), cond_param));
HloComputation* condition =
module->AddEmbeddedComputation(cond_builder.Build());

auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
auto xla_while = builder.AddInstruction(
HloInstruction::CreateWhile(scalar_shape, condition, body, constant));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape, HloOpcode::kAdd, constant, xla_while));
module->AddEntryComputation(builder.Build());

TF_ASSERT_OK_AND_ASSIGN(
auto dataflow, HloDataflowAnalysis::Run(module.get(), /*ssa_form=*/true));
DependencyHloOrdering ordering(module.get());

// Init value is defined before the while, but live range is not before the
// while because of the use of the init value in the add.
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_FALSE(
ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(constant),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(constant),
dataflow->GetValueDefinedAt(xla_while)));

// Any value defined in the body or condition is defined before the while, and
// has a live range strictly before the while.
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(negate),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_TRUE(
ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(negate),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(negate),
dataflow->GetValueDefinedAt(xla_while)));

EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(convert),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_TRUE(
ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(convert),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(convert),
dataflow->GetValueDefinedAt(xla_while)));

// The live range of the while should be before the add.
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(xla_while),
dataflow->GetValueDefinedAt(add)));
ASSERT_EQ(dataflow->GetValueDefinedAt(xla_while).uses().size(), 1);

const HloUse& while_use = dataflow->GetValueDefinedAt(xla_while).uses()[0];
EXPECT_EQ(while_use.instruction, add);
EXPECT_TRUE(ordering.UseIsBeforeValueDefinition(
while_use, dataflow->GetValueDefinedAt(add)));
EXPECT_TRUE(
ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(xla_while),
dataflow->GetValueDefinedAt(add)));
}

} // namespace
} // namespace xla

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"

using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;

namespace xla {

@ -54,11 +55,18 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
<< tensorflow::str_util::Join(disabled_passes, ", ");
}

auto run_invariant_checkers = [this, module]() -> Status {
auto run_invariant_checkers = [this,
module](const string& message) -> Status {
for (auto& invariant_checker : invariant_checkers_) {
VLOG(1) << " Invariant checker " << invariant_checker->name();
TF_ASSIGN_OR_RETURN(bool changed, invariant_checker->Run(module));
TF_RET_CHECK(!changed) << "invariant checkers must not change the graph";
StatusOr<bool> changed_status = invariant_checker->Run(module);
if (!changed_status.ok()) {
return Status(changed_status.status().code(),
StrCat(changed_status.status().error_message(),
"\n\nFailed ", message));
}
TF_RET_CHECK(!changed_status.ValueOrDie())
<< "invariant checkers must not change the graph";
}
return Status::OK();
};
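
The same pattern generalizes to annotating any failing Status with pipeline context, so the invariant-checker error reports which pass it followed; a sketch (WithContext is a hypothetical helper, not added by this change):

Status WithContext(const Status& status, const string& message) {
  if (status.ok()) return status;
  // Keep the original error code; append where in the pipeline it occurred.
  return Status(status.code(),
                StrCat(status.error_message(), "\n\nFailed ", message));
}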
@ -66,6 +74,8 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
string prefix = name().ToString() + ": pipeline start";
bool changed = false;
string message;
TF_RETURN_IF_ERROR(
run_invariant_checkers(StrCat("before running pipeline: ", name())));
for (auto& pass : passes_) {
if (disabled_passes.count(pass->name().ToString()) > 0) {
VLOG(1) << " Skipping HLO pass " << pass->name()
@ -80,14 +90,14 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
StrAppend(&message, prefix, ", before ", pass->name());
DumpModule(*module, message);

TF_RETURN_IF_ERROR(run_invariant_checkers());
TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module));
TF_RETURN_IF_ERROR(
run_invariant_checkers(StrCat("after running pass: ", pass->name())));

changed |= changed_this_pass;
prefix.clear();
StrAppend(&prefix, name(), ": after ", pass->name());
}
TF_RETURN_IF_ERROR(run_invariant_checkers());
DumpModule(*module, prefix + ", pipeline end");
return changed;
}
@ -1202,7 +1202,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(

StatusOr<bool> HloRematerialization::Run(
HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence,
int64 memory_limit_bytes) {
int64 memory_limit_bytes, RematerializationSizes* sizes) {
// The sequence is constructed entirely by this method.
TF_RET_CHECK(sequence->empty());

@ -1248,7 +1248,8 @@ StatusOr<bool> HloRematerialization::Run(
sequence->at(node.computation())));
}
return Status::OK();
}));
},
/*visit_unreachable_nodes=*/false));

// The peak memory usage of the module equals the peak memory use of the entry
// computation plus the output size of the computation. This is because the
@ -1318,13 +1319,20 @@ StatusOr<bool> HloRematerialization::Run(
<< HumanReadableNumBytes(reduced_peak_memory) << " ("
<< reduced_peak_memory << " bytes)";

if (sizes != nullptr) {
sizes->before_bytes = before_peak_memory;
sizes->after_bytes = current_peak_memory;
}

XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString());

if (current_peak_memory > memory_limit_bytes) {
LOG(WARNING) << "Can't reduce memory use below "
<< HumanReadableNumBytes(memory_limit_bytes)
<< " by rematerialization (only reduced to "
<< HumanReadableNumBytes(current_peak_memory) << ")";
LOG(WARNING) << tensorflow::strings::Printf(
"Can't reduce memory use below %s (%lld bytes) by rematerialization; "
"only reduced to %s (%lld bytes)",
HumanReadableNumBytes(memory_limit_bytes).c_str(), memory_limit_bytes,
HumanReadableNumBytes(current_peak_memory).c_str(),
current_peak_memory);
}

return changed;
@ -1333,9 +1341,10 @@ StatusOr<bool> HloRematerialization::Run(
/* static */ StatusOr<bool> HloRematerialization::RematerializeAndSchedule(
const HloRematerialization::ShapeSizeFunction& size_function,
int64 memory_limit_bytes, HloModule* hlo_module,
SequentialHloOrdering::HloModuleSequence* sequence) {
SequentialHloOrdering::HloModuleSequence* sequence,
RematerializationSizes* sizes) {
HloRematerialization remat(size_function);
return remat.Run(hlo_module, sequence, memory_limit_bytes);
return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes);
}

} // namespace xla
@ -28,6 +28,13 @@ class HloRematerialization {
public:
using ShapeSizeFunction = std::function<int64(const Shape&)>;

// Helper struct that communicates the before / after sizes for the
// rematerialization process.
struct RematerializationSizes {
int64 before_bytes;
int64 after_bytes;
};

// Rematerialize HLO instructions in the given module to reduce peak memory
// use below memory_limit_bytes where memory use is defined as the total size
// of all live HLO instruction values. Parameters and constants are included
@ -46,6 +53,9 @@ class HloRematerialization {
// rematerialization. This is the order in which HLO instructions should
// be emitted to minimize memory use.
//
// sizes: Optional outparam that indicates the peak memory usage of the HLO
// module before/after rematerialization.
//
// Returns whether any instructions were rematerialized. If memory use is
// already below the given limit then no instructions are rematerialized and
// false is returned.
@ -55,8 +65,8 @@ class HloRematerialization {
// code generation.
static StatusOr<bool> RematerializeAndSchedule(
const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
HloModule* hlo_module,
SequentialHloOrdering::HloModuleSequence* sequence);
HloModule* hlo_module, SequentialHloOrdering::HloModuleSequence* sequence,
RematerializationSizes* sizes = nullptr);
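
A hypothetical call site showing the optional outparam; the 1 GiB limit, the function name, and the size function are placeholder assumptions:

Status RunRematerializationSketch(
    HloModule* module,
    const HloRematerialization::ShapeSizeFunction& size_function) {
  SequentialHloOrdering::HloModuleSequence sequence;
  HloRematerialization::RematerializationSizes sizes;
  TF_ASSIGN_OR_RETURN(bool changed,
                      HloRematerialization::RematerializeAndSchedule(
                          size_function, /*memory_limit_bytes=*/int64{1} << 30,
                          module, &sequence, &sizes));
  // The struct reports peak memory measured before and after the pass.
  VLOG(1) << "changed=" << changed
          << " peak bytes before=" << sizes.before_bytes
          << " after=" << sizes.after_bytes;
  return Status::OK();
}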

protected:
HloRematerialization(const ShapeSizeFunction& size_function)
@ -69,7 +79,7 @@ class HloRematerialization {
// contains the memory-minimizing order in which to emit the HLO instructions.
StatusOr<bool> Run(HloModule* module,
SequentialHloOrdering::HloModuleSequence* sequence,
int64 memory_limit);
int64 memory_limit, RematerializationSizes* sizes);

// Rematerializes instructions within the given computation. 'order' is the
// order in which the computation's instructions will be emitted in the
@ -159,12 +159,6 @@ void HloValue::AddPosition(HloInstruction* instruction,
for (const HloPosition& position : positions_) {
DCHECK_NE(position, new_position);
}
// The shape of the new position must match existing positions.
if (!positions_.empty()) {
CHECK(
ShapeUtil::Compatible(positions_.front().shape(), new_position.shape()))
<< "front: " << positions_.front() << " new: " << new_position;
}

positions_.push_back(std::move(new_position));

@ -225,6 +225,9 @@ class HloValueSet {
// already exist in the set.
bool AddValue(const HloValue* value);

// Clear all values from the set.
void Clear() { values_.clear(); }

// Return the unique HLO value in the set. CHECKs if the set does not contain
// exactly one value.
const HloValue& GetUniqueValue() const {
@ -32,13 +32,11 @@ class ShapeVerifier : public DfsHloVisitor {
const std::function<int64(const Shape&)>& shape_size_fn)
: shape_size_fn_(shape_size_fn) {}

Status HandleElementwiseUnary(HloInstruction* hlo,
HloOpcode opcode) override {
Status HandleElementwiseUnary(HloInstruction* hlo) override {
return CheckUnaryShape(hlo);
}

Status HandleElementwiseBinary(HloInstruction* hlo,
HloOpcode opcode) override {
Status HandleElementwiseBinary(HloInstruction* hlo) override {
return CheckBinaryShape(hlo);
}
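
Under the new signature a visitor recovers the opcode from the instruction itself rather than from a second parameter; a sketch against DfsHloVisitorWithDefault (the base class and HloOpcodeString are assumed from the surrounding codebase, and the visitor itself is hypothetical):

class LoggingElementwiseVisitor : public DfsHloVisitorWithDefault {
 public:
  Status DefaultAction(HloInstruction*) override { return Status::OK(); }
  Status HandleElementwiseUnary(HloInstruction* hlo) override {
    // The opcode now comes off the instruction, not a parameter.
    VLOG(2) << "elementwise unary: " << HloOpcodeString(hlo->opcode());
    return Status::OK();
  }
};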

@ -282,6 +280,14 @@ class ShapeVerifier : public DfsHloVisitor {
const std::function<int64(const Shape&)> shape_size_fn_;
};

string ComputationsToString(
tensorflow::gtl::ArraySlice<HloComputation*> computations) {
return tensorflow::str_util::Join(
computations, ",", [](string* s, const HloComputation* computation) {
s->append(computation->name());
});
}

} // namespace

StatusOr<bool> HloVerifier::Run(HloModule* module) {
@ -292,6 +298,17 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
for (const auto& instruction : computation->instructions()) {
TF_RET_CHECK(instruction->parent() == computation.get());
if (instruction->opcode() == HloOpcode::kFusion) {
TF_RET_CHECK(
ContainersEqual(instruction->called_computations(),
{instruction->fused_instructions_computation()}))
<< "Fusion HLO calls computations other than the "
"fused_instructions_computation: "
<< instruction->ToString()
<< " instruction->fused_instructions_computation(): "
<< instruction->fused_instructions_computation()->ToString()
<< " instruction->called_computations(): "
<< ComputationsToString(instruction->called_computations());

for (const auto& fused : instruction->fused_instructions()) {
TF_RET_CHECK(fused->parent() ==
instruction->fused_instructions_computation())
@ -122,7 +122,9 @@ StatusOr<bool> ReducePrecisionInsertion::insert_on_inputs(
continue;
}

if (instruction->opcode() == HloOpcode::kFusion) {
if (instruction->opcode() == HloOpcode::kFusion &&
(instruction->fusion_kind() == HloInstruction::FusionKind::kLoop ||
instruction->fusion_kind() == HloInstruction::FusionKind::kInput)) {
// Insert the reduce-precision operation inside the fusion computation,
// after the corresponding parameter instruction.
TF_ASSIGN_OR_RETURN(
@ -171,7 +173,9 @@ StatusOr<bool> ReducePrecisionInsertion::insert_on_outputs(
continue;
}

if (instruction->opcode() == HloOpcode::kFusion) {
if (instruction->opcode() == HloOpcode::kFusion &&
(instruction->fusion_kind() == HloInstruction::FusionKind::kLoop ||
instruction->fusion_kind() == HloInstruction::FusionKind::kOutput)) {
// Insert the reduce-precision operation as the last operation inside
// the fusion computation.
HloInstruction* fusion_root = instruction->fused_expression_root();
@ -215,6 +215,7 @@ cc_library(
],
deps = [
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:test",

@ -17,6 +17,7 @@ limitations under the License.

#include <cstdlib>

#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"

@ -19,6 +19,7 @@ limitations under the License.
#include <string>

#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {

@ -111,6 +111,11 @@ cc_binary(
deps = [
":replay_computation_library",
"//tensorflow/compiler/plugin/executor:plugin_lib",
# TODO: This dependency is a workaround for linking error with clang.
# Without it, linker complains about missing symbols from
# 'xla_device_launch_op'. This dependency should be propagated from
# plugin_lib instead, but no targets other than this break without it.
"//tensorflow/compiler/jit",
],
)

@ -144,7 +144,7 @@ int RealMain(tensorflow::gtl::ArraySlice<char*> args,

int main(int argc, char** argv) {
// Flags
string fake_infeed_shape;
xla::string fake_infeed_shape;
bool use_fake_data = false;
const std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("use_fake_data", &use_fake_data,
@ -28,6 +28,7 @@ py_library(
"//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/fused_conv:fused_conv_py",
"//tensorflow/contrib/gan",
"//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
"//tensorflow/contrib/hooks",
@ -72,6 +73,7 @@ py_library(
"//tensorflow/contrib/staging",
"//tensorflow/contrib/stat_summarizer:stat_summarizer_py",
"//tensorflow/contrib/stateless",
"//tensorflow/contrib/summary:summary_ops",
"//tensorflow/contrib/tensor_forest:init_py",
"//tensorflow/contrib/tensorboard",
"//tensorflow/contrib/testing:testing_py",

@ -31,6 +31,7 @@ from tensorflow.contrib import deprecated
from tensorflow.contrib import distributions
from tensorflow.contrib import factorization
from tensorflow.contrib import framework
from tensorflow.contrib import gan
from tensorflow.contrib import graph_editor
from tensorflow.contrib import grid_rnn
from tensorflow.contrib import image
@ -18,6 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os

from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch
from tensorflow.contrib.decision_trees.proto import generic_tree_model_extensions_pb2
@ -26,18 +29,21 @@ from tensorflow.contrib.learn.python.learn import export_strategy
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader as saved_model_loader
from tensorflow.python.saved_model import tag_constants


def make_custom_export_strategy(name, convert_fn, feature_columns,
def make_custom_export_strategy(name,
                                convert_fn,
                                feature_columns,
                                export_input_fn):
  """Makes custom exporter of GTFlow tree format.

  Args:
    name: A string, for the name of the export strategy.
    convert_fn: A function that converts the tree proto to desired format and
      saves it to the desired location.
      saves it to the desired location. Can be None to skip conversion.
    feature_columns: A list of feature columns.
    export_input_fn: A function that takes no arguments and returns an
      `InputFnOps`.
@ -68,9 +74,22 @@ def make_custom_export_strategy(name, convert_fn, feature_columns,
    dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
    dtec.ParseFromString(dfec_str)
    # Export the result in the same folder as the saved model.
    convert_fn(dtec, sorted_feature_names, len(dense_floats),
               len(sparse_float_indices), len(sparse_int_indices),
               result_dir, eval_result)
    if convert_fn:
      convert_fn(dtec, sorted_feature_names,
                 len(dense_floats),
                 len(sparse_float_indices),
                 len(sparse_int_indices), result_dir, eval_result)
    feature_importances = _get_feature_importances(
        dtec, sorted_feature_names,
        len(dense_floats),
        len(sparse_float_indices), len(sparse_int_indices))
    sorted_by_importance = sorted(
        feature_importances.items(), key=lambda x: -x[1])
    assets_dir = os.path.join(result_dir, "assets.extra")
    gfile.MakeDirs(assets_dir)
    with gfile.GFile(os.path.join(assets_dir, "feature_importances"),
                     "w") as f:
      f.write("\n".join("%s, %f" % (k, v) for k, v in sorted_by_importance))
    return result_dir
  return export_strategy.ExportStrategy(name, export_fn)
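For illustration, a hedged usage sketch of the reworked entry point (not part of this commit; the feature columns and input fn are hypothetical placeholders, and the module is assumed to be imported as custom_export_strategy, as the test file below does):

  # Hypothetical usage: convert_fn=None now skips proto conversion but the
  # assets.extra/feature_importances file is still written on export.
  strategy = custom_export_strategy.make_custom_export_strategy(
      name="tree_export",
      convert_fn=None,
      feature_columns=my_feature_columns,   # hypothetical column list
      export_input_fn=my_export_input_fn)   # hypothetical serving input fn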
@ -157,3 +176,41 @@ def convert_to_universal_format(dtec, sorted_feature_names,
      node.left_child_id.value = split.left_id
      node.right_child_id.value = split.right_id
  return model_and_features


def _get_feature_importances(dtec, feature_names, num_dense_floats,
                             num_sparse_float, num_sparse_int):
  """Export the feature importance per feature column."""
  del num_sparse_int  # Unused.
  sums = collections.defaultdict(lambda: 0)
  for tree_idx in range(len(dtec.trees)):
    tree = dtec.trees[tree_idx]
    for tree_node in tree.nodes:
      node_type = tree_node.WhichOneof("node")
      if node_type == "dense_float_binary_split":
        split = tree_node.dense_float_binary_split
        split_column = feature_names[split.feature_column]
      elif node_type == "sparse_float_binary_split_default_left":
        split = tree_node.sparse_float_binary_split_default_left.split
        split_column = feature_names[split.feature_column + num_dense_floats]
      elif node_type == "sparse_float_binary_split_default_right":
        split = tree_node.sparse_float_binary_split_default_right.split
        split_column = feature_names[split.feature_column + num_dense_floats]
      elif node_type == "categorical_id_binary_split":
        split = tree_node.categorical_id_binary_split
        split_column = feature_names[split.feature_column + num_dense_floats +
                                     num_sparse_float]
      elif node_type == "categorical_id_set_membership_binary_split":
        split = tree_node.categorical_id_set_membership_binary_split
        split_column = feature_names[split.feature_column + num_dense_floats +
                                     num_sparse_float]
      elif node_type == "leaf":
        assert tree_node.node_metadata.gain == 0
        continue
      else:
        raise ValueError("Unexpected split type %s", node_type)
      # Apply shrinkage factor. It is important since it is not always uniform
      # across different trees.
      sums[split_column] += (
          tree_node.node_metadata.gain * dtec.tree_weights[tree_idx])
  return dict(sums)
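To make the index arithmetic in _get_feature_importances concrete, a small sketch using the same three columns as the test below (1 dense float, 1 sparse float, and 1 sparse int column):

  # Sketch: how a split's feature_column index maps into feature_names.
  feature_names = ["feature_b", "feature_a", "feature_d"]
  num_dense_floats, num_sparse_float = 1, 1
  # dense split, feature_column=0        -> feature_names[0]                  == "feature_b"
  # sparse float split, feature_column=0 -> feature_names[0 + num_dense_floats] == "feature_a"
  # categorical split, feature_column=0  -> feature_names[0 + num_dense_floats
  #                                                        + num_sparse_float] == "feature_d"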
@ -27,7 +27,7 @@ from tensorflow.python.platform import googletest

class ConvertModelTest(test_util.TensorFlowTestCase):

  def testConvertModel(self):
  def _make_trees(self):
    dtec_str = """
    trees {
      nodes {
@ -108,8 +108,12 @@ class ConvertModelTest(test_util.TensorFlowTestCase):
    """
    dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
    text_format.Merge(dtec_str, dtec)
    # The feature columns in the order they were added.
    feature_columns = ["feature_b", "feature_a", "feature_d"]
    return dtec, feature_columns

  def testConvertModel(self):
    dtec, feature_columns = self._make_trees()
    # The feature columns in the order they were added.
    out = custom_export_strategy.convert_to_universal_format(
        dtec, feature_columns, 1, 1,
        1)
@ -273,6 +277,16 @@ class ConvertModelTest(test_util.TensorFlowTestCase):
    }"""
    self.assertProtoEquals(expected_tree, out)

  def testFeatureImportance(self):
    dtec, feature_columns = self._make_trees()
    feature_importances = custom_export_strategy._get_feature_importances(
        dtec, feature_columns, 1, 1, 1)
    self.assertItemsEqual(["feature_b", "feature_a", "feature_d"],
                          feature_importances.keys())
    self.assertAlmostEqual(50.0, feature_importances["feature_b"], places=4)
    self.assertAlmostEqual(50.0, feature_importances["feature_a"], places=4)
    self.assertAlmostEqual(50.0, feature_importances["feature_d"], places=4)


if __name__ == "__main__":
  googletest.main()
@ -61,11 +61,19 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
      logits_modifier_function: A modifier function for the logits.
      center_bias: Whether a separate tree should be created for first fitting
        the bias.

    Raises:
      ValueError: If learner_config is not valid.
    """
    head = head_lib.multi_class_head(
        n_classes=n_classes,
        weight_column_name=weight_column_name,
        enable_centered_bias=False)
    if learner_config.num_classes == 0:
      learner_config.num_classes = n_classes
    elif learner_config.num_classes != n_classes:
      raise ValueError("n_classes (%d) doesn't match learner_config (%d)." %
                       (learner_config.num_classes, n_classes))
    super(GradientBoostedDecisionTreeClassifier, self).__init__(
        model_fn=model.model_builder,
        params={
@ -129,6 +137,10 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
        label_dimension=label_dimension,
        weight_column_name=weight_column_name,
        enable_centered_bias=False)
    if label_dimension == 1:
      learner_config.num_classes = 2
    else:
      learner_config.num_classes = label_dimension
    super(GradientBoostedDecisionTreeRegressor, self).__init__(
        model_fn=model.model_builder,
        params={
@ -92,6 +92,7 @@ def model_builder(features, labels, mode, params, config):
      examples_per_layer=examples_per_layer,
      learner_config=learner_config,
      feature_columns=feature_columns,
      logits_dimension=head.logits_dimension,
      features=features)
  with ops.name_scope("gbdt", "gbdt_optimizer"):
    predictions_dict = gbdt_model.predict(mode)
@ -74,7 +74,7 @@ class TreeEnsembleStampTokenOp : public OpKernel {
        decision_tree_ensemble_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &decision_tree_ensemble_resource));
    mutex_lock l(*decision_tree_ensemble_resource->get_mutex());
    tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex());
    core::ScopedUnref unref_me(decision_tree_ensemble_resource);
    Tensor* output_stamp_token_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
@ -95,7 +95,7 @@ class TreeEnsembleSerializeOp : public OpKernel {
        decision_tree_ensemble_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &decision_tree_ensemble_resource));
    mutex_lock l(*decision_tree_ensemble_resource->get_mutex());
    tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex());
    core::ScopedUnref unref_me(decision_tree_ensemble_resource);
    Tensor* output_stamp_token_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),

@ -143,7 +143,7 @@ class GradientTreesPredictionOp : public OpKernel {
    // Release the reference to the resource once we're done using it.
    core::ScopedUnref unref_me(decision_tree_ensemble_resource);
    if (use_locking_) {
      mutex_lock l(*decision_tree_ensemble_resource->get_mutex());
      tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex());
      DoCompute(context, decision_tree_ensemble_resource);
    } else {
      DoCompute(context, decision_tree_ensemble_resource);
@ -334,7 +334,7 @@ class GradientTreesPartitionExamplesOp : public OpKernel {
    // Release the reference to the resource once we're done using it.
    core::ScopedUnref unref_me(decision_tree_ensemble_resource);
    if (use_locking_) {
      mutex_lock l(*decision_tree_ensemble_resource->get_mutex());
      tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex());
      DoCompute(context, decision_tree_ensemble_resource);
    } else {
      DoCompute(context, decision_tree_ensemble_resource);

@ -656,7 +656,8 @@ class GrowTreeEnsembleOp : public OpKernel {
    CHECK(split->split_info.split_node().node_case() != TreeNode::NODE_NOT_SET);
    CHECK(tree_config->nodes(node_id).node_case() == TreeNode::kLeaf)
        << "Unexpected node type to split "
        << tree_config->nodes(node_id).node_case();
        << tree_config->nodes(node_id).node_case() << " for node_id " << node_id
        << ". Tree config: " << tree_config->DebugString();

    // Add left leaf.
    int32 left_id = tree_config->nodes_size();
@ -767,7 +768,7 @@ class TreeEnsembleStatsOp : public OpKernel {
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &decision_tree_ensemble_resource));
    core::ScopedUnref unref_me(decision_tree_ensemble_resource);
    mutex_lock l(*decision_tree_ensemble_resource->get_mutex());
    tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex());

    // Get the stamp token.
    const Tensor* stamp_token_t;
@ -42,6 +42,7 @@ class BiasFeatureColumnHandlerTest : public ::testing::Test {
        example_partitions_({0, 0, 1, 3}) {
    // Set L2 regularization.
    learner_config_.mutable_regularization()->set_l2(2.0f);
    learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);

    // Create handler.
    handler_.reset(new BiasFeatureColumnHandler(kClassId, kSlotId, kBatchSize));

@ -51,7 +51,7 @@ class CategoricalFeatureColumnHandlerTest : public ::testing::Test {
        values_(test::AsTensor<int64>({1, 2, 2, 0}, {4})) {
    // Set L2 regularization.
    learner_config_.mutable_regularization()->set_l2(2.0f);

    learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
    // Create handler.
    handler_.reset(new CategoricalFeatureColumnHandler(
        kClassId, kSlotId, kBatchSize, kFeatureColumn, indices_.matrix<int64>(),

@ -51,7 +51,7 @@ class DenseQuantizedFeatureColumnHandlerTest : public ::testing::Test {
        dense_quantized_values_(test::AsTensor<int32>({1, 1, 0, 1}, {4})) {
    // Set L2 regularization.
    learner_config_.mutable_regularization()->set_l2(2.0f);

    learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
    // Create handler.
    handler_.reset(new DenseQuantizedFeatureColumnHandler(
        kClassId, kSlotId, kBatchSize, kFeatureColumn,

@ -53,7 +53,7 @@ class SparseQuantizedFeatureColumnHandlerTest : public ::testing::Test {
        sparse_quantized_values_(test::AsTensor<int32>({1, 0, 1}, {3})) {
    // Set L2 regularization.
    learner_config_.mutable_regularization()->set_l2(2.0f);

    learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
    // Create handler.
    handler_.reset(new SparseQuantizedFeatureColumnHandler(
        kClassId, kSlotId, kBatchSize, kFeatureColumn,
@ -30,6 +30,7 @@ const double kDelta = 1e-5;

TEST(NodeStatsTest, AlmostZero) {
  LearnerConfig learner_config;
  learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
  NodeStats node_stats(learner_config, GradientStats(1e-8f, 1e-8f));
  EXPECT_EQ(0, node_stats.weight_contribution[0]);
  EXPECT_EQ(0, node_stats.gain);
@ -37,6 +38,7 @@ TEST(NodeStatsTest, AlmostZero) {

TEST(NodeStatsTest, LessThanMinWeightConstraint) {
  LearnerConfig learner_config;
  learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
  learner_config.mutable_constraints()->set_min_node_weight(3.2f);
  NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
  EXPECT_EQ(0, node_stats.weight_contribution[0]);
@ -45,6 +47,7 @@ TEST(NodeStatsTest, LessThanMinWeightConstraint) {

TEST(NodeStatsTest, L1RegSquashed) {
  LearnerConfig learner_config;
  learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
  learner_config.mutable_regularization()->set_l1(10.0f);
  NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
  EXPECT_EQ(0, node_stats.weight_contribution[0]);
@ -53,6 +56,7 @@ TEST(NodeStatsTest, L1RegSquashed) {

TEST(NodeStatsTest, L1RegPos) {
  LearnerConfig learner_config;
  learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
  learner_config.mutable_regularization()->set_l1(5.0f);
  NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
  const float expected_clipped_grad = 7.32f - 5.0f;
@ -66,6 +70,7 @@ TEST(NodeStatsTest, L1RegPos) {

TEST(NodeStatsTest, L1RegNeg) {
  LearnerConfig learner_config;
  learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
  learner_config.mutable_regularization()->set_l1(5.0f);
  NodeStats node_stats(learner_config, GradientStats(-7.32f, 1.63f));
  const float expected_clipped_grad = -7.32f + 5.0f;
@ -79,6 +84,7 @@ TEST(NodeStatsTest, L1RegNeg) {

TEST(NodeStatsTest, L2Reg) {
  LearnerConfig learner_config;
  learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
  learner_config.mutable_regularization()->set_l2(8.0f);
  NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
  const float expected_denom = 1.63f + 8.0f;
@ -91,6 +97,7 @@ TEST(NodeStatsTest, L2Reg) {

TEST(NodeStatsTest, L1L2Reg) {
  LearnerConfig learner_config;
  learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
  learner_config.mutable_regularization()->set_l1(5.0f);
  learner_config.mutable_regularization()->set_l2(8.0f);
  NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
@ -15,6 +15,7 @@
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_

#include <cstring>
#include <vector>

#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h"
@ -34,10 +35,27 @@ class WeightedQuantilesSummary {

  struct SummaryEntry {
    SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min,
                 const WeightType& max)
        : value(v), weight(w), min_rank(min), max_rank(max) {}
                 const WeightType& max) {
      // Explicitly initialize all of memory (including padding from memory
      // alignment) to allow the struct to be msan-resistant "plain old data".
      //
      // POD = http://en.cppreference.com/w/cpp/concept/PODType
      memset(this, 0, sizeof(*this));

    SummaryEntry() : value(0), weight(0), min_rank(0), max_rank(0) {}
      value = v;
      weight = w;
      min_rank = min;
      max_rank = max;
    }

    SummaryEntry() {
      memset(this, 0, sizeof(*this));

      value = 0;
      weight = 0;
      min_rank = 0;
      max_rank = 0;
    }

    bool operator==(const SummaryEntry& other) const {
      return value == other.value && weight == other.weight &&
@ -17,7 +17,7 @@ message TreeRegularizationConfig {

// Tree constraints config.
message TreeConstraintsConfig {
  // Maximum depth of the trees.
  // Maximum depth of the trees. The default value is 6 if not specified.
  uint32 max_tree_depth = 1;

  // Min hessian weight per node.
@ -86,20 +86,22 @@ message LearningRateDropoutDrivenConfig {

message LearnerConfig {
  enum PruningMode {
    PRE_PRUNE = 0;
    POST_PRUNE = 1;
    PRUNING_MODE_UNSPECIFIED = 0;
    PRE_PRUNE = 1;
    POST_PRUNE = 2;
  }

  enum GrowingMode {
    WHOLE_TREE = 0;
    // Layer by layer is only supported by the batch learner.
    LAYER_BY_LAYER = 1;
    GROWING_MODE_UNSPECIFIED = 0;
    WHOLE_TREE = 1;
    LAYER_BY_LAYER = 2;
  }

  enum MultiClassStrategy {
    TREE_PER_CLASS = 0;
    FULL_HESSIAN = 1;
    DIAGONAL_HESSIAN = 2;
    MULTI_CLASS_STRATEGY_UNSPECIFIED = 0;
    TREE_PER_CLASS = 1;
    FULL_HESSIAN = 2;
    DIAGONAL_HESSIAN = 3;
  }

  // Number of classes.
@ -118,16 +120,18 @@ message LearnerConfig {
  // Constraints.
  TreeConstraintsConfig constraints = 5;

  // Pruning.
  // Pruning. POST_PRUNE is the default pruning mode.
  PruningMode pruning_mode = 8;

  // Growing Mode.
  // Growing Mode. LAYER_BY_LAYER is the default growing mode.
  GrowingMode growing_mode = 9;

  // Learning rate.
  // Learning rate. By default we use fixed learning rate of 0.1.
  LearningRateConfig learning_rate_tuner = 6;

  // Multi-class strategy.
  // Multi-class strategy. By default we use TREE_PER_CLASS for binary
  // classification and linear regression. For other cases, we use
  // DIAGONAL_HESSIAN as the default.
  MultiClassStrategy multi_class_strategy = 10;

  // If you want to average the ensembles (for regularization), provide the
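The renumbering above leans on proto3 semantics: an unset enum field reads back as its zero value, so reserving 0 for the *_UNSPECIFIED entry lets callers distinguish "not set" from an explicit choice. A minimal sketch (not part of this commit; learner_pb2 is the generated module used elsewhere in this change):

  # An untouched config now reads back as UNSPECIFIED rather than PRE_PRUNE.
  config = learner_pb2.LearnerConfig()
  assert (config.pruning_mode ==
          learner_pb2.LearnerConfig.PRUNING_MODE_UNSPECIFIED)
  config.pruning_mode = learner_pb2.LearnerConfig.POST_PRUNE  # explicit choice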
@ -344,6 +344,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
    # Prepare learner config.
    learner_config = learner_pb2.LearnerConfig()
    learner_config.num_classes = 2
    learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE

    result, result_no_dropout, dropout_info = (
        prediction_ops.gradient_trees_prediction(
@ -261,6 +261,7 @@ class GradientBoostedDecisionTreeModel(object):
               examples_per_layer,
               learner_config,
               features,
               logits_dimension,
               feature_columns=None):
    """Construct a new GradientBoostedDecisionTreeModel function.

@ -273,8 +274,8 @@ class GradientBoostedDecisionTreeModel(object):
        a tree layer. It can also be a function that computes the number of
        examples based on the depth of the layer that's being built.
      learner_config: A learner config.
      print split, sorted_feature_names[split.feature_column]
      features: `dict` of `Tensor` objects.
      logits_dimension: An int, the dimension of logits.
      feature_columns: A list of feature columns.

    Raises:
@ -289,11 +290,39 @@ class GradientBoostedDecisionTreeModel(object):
    if learner_config.num_classes < 2:
      raise ValueError("Number of classes must be >=2")

    self._logits_dimension = logits_dimension
    self._is_chief = is_chief
    self._num_ps_replicas = num_ps_replicas
    self._ensemble_handle = ensemble_handle
    self._center_bias = center_bias
    self._examples_per_layer = examples_per_layer

    # Fill in the defaults.
    if (learner_config.multi_class_strategy ==
        learner_pb2.LearnerConfig.MULTI_CLASS_STRATEGY_UNSPECIFIED):
      if logits_dimension == 1:
        learner_config.multi_class_strategy = (
            learner_pb2.LearnerConfig.TREE_PER_CLASS)
      else:
        learner_config.multi_class_strategy = (
            learner_pb2.LearnerConfig.DIAGONAL_HESSIAN)

    if (learner_config.growing_mode ==
        learner_pb2.LearnerConfig.GROWING_MODE_UNSPECIFIED):
      learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER

    if (learner_config.pruning_mode ==
        learner_pb2.LearnerConfig.PRUNING_MODE_UNSPECIFIED):
      learner_config.pruning_mode = learner_pb2.LearnerConfig.POST_PRUNE

    if learner_config.constraints.max_tree_depth == 0:
      # Use 6 as the default maximum depth.
      learner_config.constraints.max_tree_depth = 6

    tuner = learner_config.learning_rate_tuner.WhichOneof("tuner")
    if not tuner:
      learner_config.learning_rate_tuner.fixed.learning_rate = 0.1

    self._learner_config = learner_config
    self._feature_columns = feature_columns
    self._learner_config_serialized = learner_config.SerializeToString()
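A hedged sketch of the net effect of the default-filling block above for a single-logit model whose config is otherwise left empty (expected values follow directly from the code; not part of the commit):

  # Expected defaults after construction with logits_dimension=1:
  #   multi_class_strategy -> TREE_PER_CLASS
  #   growing_mode         -> LAYER_BY_LAYER
  #   pruning_mode         -> POST_PRUNE
  #   constraints.max_tree_depth -> 6
  #   learning_rate_tuner.fixed.learning_rate -> 0.1
  learner_config = learner_pb2.LearnerConfig()
  learner_config.num_classes = 2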
@ -378,75 +407,81 @@ class GradientBoostedDecisionTreeModel(object):
                  local_stamp), _refresh_local_ensemble_fn,
          lambda: (control_flow_ops.no_op(), ensemble_stamp))

      # Once updated, Use the the local model for prediction.
      # Once updated, use the local model for prediction.
      with ops.control_dependencies([refresh_local_ensemble]):
        ensemble_stats = training_ops.tree_ensemble_stats(
            local_ensemble_handle, ensemble_stamp)
        apply_dropout, seed = _dropout_params(mode, ensemble_stats)
        # We don't need dropout info - we can always restore it based on the
        # seed.
        predictions, predictions_no_dropout, _ = (
            prediction_ops.gradient_trees_prediction(
                local_ensemble_handle,
                seed,
                self._dense_floats,
                self._sparse_float_indices,
                self._sparse_float_values,
                self._sparse_float_shapes,
                self._sparse_int_indices,
                self._sparse_int_values,
                self._sparse_int_shapes,
                learner_config=self._learner_config_serialized,
                apply_dropout=apply_dropout,
                apply_averaging=apply_averaging,
                use_locking=False,
                center_bias=self._center_bias,
                reduce_dim=self._reduce_dim))
        partition_ids = prediction_ops.gradient_trees_partition_examples(
            local_ensemble_handle,
            self._dense_floats,
            self._sparse_float_indices,
            self._sparse_float_values,
            self._sparse_float_shapes,
            self._sparse_int_indices,
            self._sparse_int_values,
            self._sparse_int_shapes,
            use_locking=False)
        apply_dropout, seed = _dropout_params(mode, ensemble_stats)
        # Make sure ensemble stats run. This will check that the ensemble has
        # the right stamp.
        with ops.control_dependencies(ensemble_stats):
          predictions, predictions_no_dropout, _ = (
              prediction_ops.gradient_trees_prediction(
                  local_ensemble_handle,
                  seed,
                  self._dense_floats,
                  self._sparse_float_indices,
                  self._sparse_float_values,
                  self._sparse_float_shapes,
                  self._sparse_int_indices,
                  self._sparse_int_values,
                  self._sparse_int_shapes,
                  learner_config=self._learner_config_serialized,
                  apply_dropout=apply_dropout,
                  apply_averaging=apply_averaging,
                  use_locking=True,
                  center_bias=self._center_bias,
                  reduce_dim=self._reduce_dim))
          partition_ids = prediction_ops.gradient_trees_partition_examples(
              local_ensemble_handle,
              self._dense_floats,
              self._sparse_float_indices,
              self._sparse_float_values,
              self._sparse_float_shapes,
              self._sparse_int_indices,
              self._sparse_int_values,
              self._sparse_int_shapes,
              use_locking=True)

    else:
      with ops.device(self._ensemble_handle.device):
        ensemble_stats = training_ops.tree_ensemble_stats(
            self._ensemble_handle, ensemble_stamp)
        apply_dropout, seed = _dropout_params(mode, ensemble_stats)
        # We don't need dropout info - we can always restore it based on the
        # seed.
        predictions, predictions_no_dropout, _ = (
            prediction_ops.gradient_trees_prediction(
                self._ensemble_handle,
                seed,
                self._dense_floats,
                self._sparse_float_indices,
                self._sparse_float_values,
                self._sparse_float_shapes,
                self._sparse_int_indices,
                self._sparse_int_values,
                self._sparse_int_shapes,
                learner_config=self._learner_config_serialized,
                apply_dropout=apply_dropout,
                apply_averaging=apply_averaging,
                use_locking=False,
                center_bias=self._center_bias,
                reduce_dim=self._reduce_dim))
        partition_ids = prediction_ops.gradient_trees_partition_examples(
            self._ensemble_handle,
            self._dense_floats,
            self._sparse_float_indices,
            self._sparse_float_values,
            self._sparse_float_shapes,
            self._sparse_int_indices,
            self._sparse_int_values,
            self._sparse_int_shapes,
            use_locking=False)
        apply_dropout, seed = _dropout_params(mode, ensemble_stats)
        # Make sure ensemble stats run. This will check that the ensemble has
        # the right stamp.
        with ops.control_dependencies(ensemble_stats):
          predictions, predictions_no_dropout, _ = (
              prediction_ops.gradient_trees_prediction(
                  self._ensemble_handle,
                  seed,
                  self._dense_floats,
                  self._sparse_float_indices,
                  self._sparse_float_values,
                  self._sparse_float_shapes,
                  self._sparse_int_indices,
                  self._sparse_int_values,
                  self._sparse_int_shapes,
                  learner_config=self._learner_config_serialized,
                  apply_dropout=apply_dropout,
                  apply_averaging=apply_averaging,
                  use_locking=True,
                  center_bias=self._center_bias,
                  reduce_dim=self._reduce_dim))
          partition_ids = prediction_ops.gradient_trees_partition_examples(
              self._ensemble_handle,
              self._dense_floats,
              self._sparse_float_indices,
              self._sparse_float_values,
              self._sparse_float_shapes,
              self._sparse_int_indices,
              self._sparse_int_values,
              self._sparse_int_shapes,
              use_locking=True)

    return _make_predictions_dict(ensemble_stamp, predictions,
                                  predictions_no_dropout, partition_ids,
@ -164,7 +164,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
          ensemble_handle=ensemble_handle,
          examples_per_layer=1,
          learner_config=learner_config,
          features=features)
          logits_dimension=1, features=features)

      predictions = array_ops.constant(
          [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
@ -268,7 +268,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
          ensemble_handle=ensemble_handle,
          examples_per_layer=num_examples_fn,
          learner_config=learner_config,
          features=features)
          logits_dimension=1, features=features)

      predictions = array_ops.constant(
          [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
@ -371,7 +371,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
          ensemble_handle=ensemble_handle,
          examples_per_layer=1,
          learner_config=learner_config,
          features=features)
          logits_dimension=1, features=features)

      predictions = array_ops.constant(
          [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
@ -442,7 +442,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
          ensemble_handle=ensemble_handle,
          examples_per_layer=1,
          learner_config=learner_config,
          features=features)
          logits_dimension=1, features=features)

      predictions = array_ops.constant(
          [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
@ -505,7 +505,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
          ensemble_handle=ensemble_handle,
          examples_per_layer=1,
          learner_config=learner_config,
          features=features)
          logits_dimension=1, features=features)

      predictions = array_ops.constant(
          [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
@ -588,7 +588,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
          ensemble_handle=ensemble_handle,
          examples_per_layer=1,
          learner_config=learner_config,
          features=features)
          logits_dimension=1, features=features)

      # Create predict op.
      mode = model_fn.ModeKeys.EVAL
@ -627,7 +627,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
          ensemble_handle=ensemble_handle,
          examples_per_layer=1,
          learner_config=learner_config,
          features=features)
          logits_dimension=5, features=features)

      predictions = array_ops.constant(
          [[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0],
@ -730,7 +730,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
          ensemble_handle=ensemble_handle,
          examples_per_layer=1,
          learner_config=learner_config,
          features=features)
          logits_dimension=5, features=features)

      predictions = array_ops.constant(
          [[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0],
@ -833,7 +833,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
          ensemble_handle=ensemble_handle,
          examples_per_layer=1,
          learner_config=learner_config,
          features=features)
          logits_dimension=5, features=features)

      batch_size = 3
      predictions = array_ops.constant(
@ -33,6 +33,7 @@ option(tensorflow_BUILD_MORE_PYTHON_TESTS "Build more python unit tests for cont
option(tensorflow_BUILD_SHARED_LIB "Build TensorFlow as a shared library" OFF)
option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if available)" ON)
option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions")
option(tensorflow_ENABLE_SNAPPY_SUPPORT "Enable SNAPPY compression support" ON)

if (NOT WIN32)
  # Threads: defines CMAKE_THREAD_LIBS_INIT and adds -pthread compile option
@ -204,6 +205,12 @@ if(tensorflow_ENABLE_JEMALLOC_SUPPORT)
  list(APPEND tensorflow_EXTERNAL_DEPENDENCIES jemalloc)
  include_directories(${jemalloc_INCLUDE_DIRS})
endif()
if(tensorflow_ENABLE_SNAPPY_SUPPORT)
  include(snappy)
  list(APPEND tensorflow_EXTERNAL_LIBRARIES ${snappy_STATIC_LIBRARIES})
  list(APPEND tensorflow_EXTERNAL_DEPENDENCIES snappy)
  include_directories(${snappy_INCLUDE_DIR})
endif()
if(WIN32)
  list(APPEND tensorflow_EXTERNAL_LIBRARIES wsock32 ws2_32 shlwapi)
endif()
@ -17,7 +17,7 @@ include (ExternalProject)
set(boringssl_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/boringssl/src/boringssl/include)
#set(boringssl_EXTRA_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/boringssl/src)
set(boringssl_URL https://boringssl.googlesource.com/boringssl)
set(boringssl_TAG 17cf2cb1d226b0ba2401304242df7ddd3b6f1ff2)
set(boringssl_TAG ee7aa02)
set(boringssl_BUILD ${CMAKE_BINARY_DIR}/boringssl/src/boringssl-build)
#set(boringssl_LIBRARIES ${boringssl_BUILD}/obj/so/libboringssl.so)
set(boringssl_STATIC_LIBRARIES
tensorflow/contrib/cmake/external/cub.cmake
@ -14,8 +14,8 @@
# ==============================================================================
include (ExternalProject)

set(cub_URL http://mirror.bazel.build/github.com/NVlabs/cub/archive/69ceda618313df8e9cac6659d607b08949455d14.tar.gz)
set(cub_HASH SHA256=87e856522c283b8ea887c3b61d7d5b252d2dd74abac4f1d756d776e721223e82)
set(cub_URL http://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.3.zip)
set(cub_HASH SHA256=b7ead9e291d34ffa8074243541c1380d63be63f88de23de8ee548db573b72ebe)
set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub)
set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub)
set(cub_ARCHIVE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/cub_archive)
tensorflow/contrib/cmake/external/snappy.cmake (new file)
@ -0,0 +1,50 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
include (ExternalProject)

set(snappy_URL https://github.com/google/snappy.git)
set(snappy_TAG "55924d11095df25ab25c405fadfe93d0a46f82eb")
set(snappy_BUILD ${CMAKE_CURRENT_BINARY_DIR}/snappy/src/snappy)
set(snappy_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/snappy/src/snappy)

if(WIN32)
  set(snappy_STATIC_LIBRARIES ${snappy_BUILD}/$(Configuration)/snappy.lib)
else()
  set(snappy_STATIC_LIBRARIES ${snappy_BUILD}/libsnappy.a)
endif()

set(snappy_HEADERS
    "${snappy_INCLUDE_DIR}/snappy.h"
)

ExternalProject_Add(snappy
    PREFIX snappy
    GIT_REPOSITORY ${snappy_URL}
    GIT_TAG ${snappy_TAG}
    DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
    BUILD_IN_SOURCE 1
    INSTALL_COMMAND ""
    LOG_DOWNLOAD ON
    LOG_CONFIGURE ON
    LOG_BUILD ON
    CMAKE_CACHE_ARGS
        -DCMAKE_BUILD_TYPE:STRING=Release
        -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
        -DSNAPPY_BUILD_TESTS:BOOL=OFF
        -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
)

# actually enables snappy in the source code
add_definitions(-DSNAPPY)
@ -18,6 +18,7 @@
set(tf_c_srcs
    "${tensorflow_source_dir}/tensorflow/c/c_api.cc"
    "${tensorflow_source_dir}/tensorflow/c/c_api.h"
    "${tensorflow_source_dir}/tensorflow/c/c_api_function.cc"
    "${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc"
    "${tensorflow_source_dir}/tensorflow/c/eager/c_api.h"
    "${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc"
@ -315,6 +315,7 @@ add_python_module("tensorflow/contrib/framework/ops")
add_python_module("tensorflow/contrib/framework/python")
add_python_module("tensorflow/contrib/framework/python/framework")
add_python_module("tensorflow/contrib/framework/python/ops")
add_python_module("tensorflow/contrib/gan")
add_python_module("tensorflow/contrib/graph_editor")
add_python_module("tensorflow/contrib/graph_editor/examples")
add_python_module("tensorflow/contrib/graph_editor/tests")
@ -240,6 +240,8 @@ if (tensorflow_BUILD_PYTHON_TESTS)
      "${tensorflow_source_dir}/tensorflow/python/training/quantize_training_test.py" # Needs quantization ops to be included in windows.
      "${tensorflow_source_dir}/tensorflow/python/training/supervisor_test.py" # Flaky I/O error on rename.
      "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" # Needs portpicker.
      "${tensorflow_source_dir}/tensorflow/python/training/server_lib_test.py" # Test occasionally deadlocks.

      "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops
      # Broken tensorboard test due to cmake issues.
      "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py"
@ -291,6 +293,8 @@ if (tensorflow_BUILD_PYTHON_TESTS)
      # Failing with TF 1.3 (TODO)
      "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/estimator_test.py"
      "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_test.py"
      # Test should only be run manually
      "${tensorflow_source_dir}/tensorflow/python/kernel_tests/reduction_ops_test_big.py"
  )
  endif()
  list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude})
@ -716,6 +716,482 @@ _cudnn_rnn_common_doc_string = """
"""


def _check_direction(direction):
  if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION):
    raise ValueError("Invalid direction: %s, expect %s or %s" %
                     (direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION))


def _check_rnn_mode(rnn_mode):
  if rnn_mode not in (CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH, CUDNN_RNN_RELU):
    raise ValueError("Invalid rnn_mode: %s, expect one of (%s, %s, %s, %s)" %
                     (rnn_mode, CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH,
                      CUDNN_RNN_RELU))


def _get_seed(seed):
  seed, seed2 = random_seed.get_seed(seed)
  if seed is None and seed2 is None:
    seed, seed2 = 0, 0
  return seed, seed2


def _get_num_params(rnn_mode, num_layers, direction):
  """Return num params for given Cudnn config."""
  if rnn_mode == CUDNN_LSTM:
    num_params_per_layer = 8
  elif rnn_mode == CUDNN_GRU:
    num_params_per_layer = 6
  elif rnn_mode in (CUDNN_RNN_RELU, CUDNN_RNN_TANH):
    num_params_per_layer = 2
  else:
    raise ValueError("Invalid \'rnn_mode\': %s", rnn_mode)
  num_params = num_layers * num_params_per_layer
  if direction != CUDNN_RNN_UNIDIRECTION:
    num_params *= 2
  return num_params

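As a quick check of the helper above: an LSTM has 8 canonical parameter regions per layer per direction, so a 2-layer bidirectional LSTM yields 32. A sketch (not part of the commit):

  assert _get_num_params(CUDNN_LSTM, 2, CUDNN_RNN_BIDIRECTION) == 32  # 8 * 2 layers * 2 directions
  assert _get_num_params(CUDNN_GRU, 1, CUDNN_RNN_UNIDIRECTION) == 6   # 6 * 1 layer
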
def _cudnn_rnn(inputs,
               input_h,
               input_c,
               params,
               is_training,
               rnn_mode,
               input_mode=CUDNN_INPUT_LINEAR_MODE,
               direction=CUDNN_RNN_UNIDIRECTION,
               dropout=0.,
               seed=0,
               name=None):
  """Cudnn RNN.

  Args:
    inputs: the input sequence to the RNN model. A Tensor of shape [?,
      batch_size, input_size].
    input_h: the initial hidden state for h. A Tensor of shape [num_layers,
      batch_size, num_units].
    input_c: the initial hidden state for c. This is only relevant for LSTM.
      A Tensor of the same shape as input_h.
    params: the parameter buffer created for this model.
    is_training: whether this operation will be used in training or inference
    rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh').
    input_mode: indicate whether there is a linear projection between the
      input and the actual computation before the first layer. It could be
      'linear_input', 'skip_input' or 'auto_select'.
      'linear_input' (default) always applies a linear projection of input
      onto RNN hidden state. (standard RNN behavior).
      'skip_input' is only allowed when input_size == num_units;
      'auto_select' implies 'skip_input' when input_size == num_units;
      otherwise, it implies 'linear_input'.
    direction: the direction model that the model operates. Could be either
      'unidirectional' or 'bidirectional'
    dropout: whether to enable dropout. When it is 0, dropout is disabled.
    seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
      for behavior.
    name: name of the operation.
  Returns:
    outputs, output_h, output_c
  """
  _check_rnn_mode(rnn_mode)
  _check_direction(direction)
  seed, seed2 = random_seed.get_seed(seed)
  outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
      input=inputs,
      input_h=input_h,
      input_c=input_c,
      params=params,
      is_training=is_training,
      rnn_mode=rnn_mode,
      input_mode=input_mode,
      direction=direction,
      dropout=dropout,
      seed=seed,
      seed2=seed2,
      name=name)
  return (outputs, output_h, output_c)

def cudnn_lstm(inputs,
               input_h,
               input_c,
               params,
               is_training,
               input_mode=CUDNN_INPUT_LINEAR_MODE,
               direction=CUDNN_RNN_UNIDIRECTION,
               dropout=0.,
               seed=0,
               name=None):
  """Cudnn LSTM.

  Args:
    inputs: the input sequence to the RNN model. A Tensor of shape [?,
      batch_size, input_size].
    input_h: the initial hidden state for h. A Tensor of shape [num_layers,
      batch_size, num_units].
    input_c: the initial hidden state for c. This is only relevant for LSTM.
      A Tensor of the same shape as input_h.
    params: the parameter buffer created for this model.
    is_training: whether this operation will be used in training or inference
    input_mode: indicate whether there is a linear projection between the
      input and the actual computation before the first layer. It could be
      'linear_input', 'skip_input' or 'auto_select'.
      'linear_input' (default) always applies a linear projection of input
      onto RNN hidden state. (standard RNN behavior).
      'skip_input' is only allowed when input_size == num_units;
      'auto_select' implies 'skip_input' when input_size == num_units;
      otherwise, it implies 'linear_input'.
    direction: the direction model that the model operates. Could be either
      'unidirectional' or 'bidirectional'
    dropout: whether to enable dropout. When it is 0, dropout is disabled.
    seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
      for behavior.
    name: name of the operation.
  Returns:
    outputs, output_h, output_c
  """
  return _cudnn_rnn(inputs, input_h, input_c, params, is_training, CUDNN_LSTM,
                    input_mode, direction, dropout, seed, name)

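A hedged usage sketch of the functional form above (shapes follow the docstring; the dimension variables and the opaque params buffer are hypothetical and assumed to be created elsewhere):

  inputs = array_ops.zeros([time_len, batch_size, input_size])    # hypothetical dims
  input_h = array_ops.zeros([num_layers, batch_size, num_units])
  input_c = array_ops.zeros([num_layers, batch_size, num_units])
  outputs, output_h, output_c = cudnn_lstm(
      inputs, input_h, input_c, params, is_training=False)
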
def _cudnn_rnn_no_input_c(inputs,
                          input_h,
                          params,
                          is_training,
                          rnn_mode,
                          input_mode=CUDNN_INPUT_LINEAR_MODE,
                          direction=CUDNN_RNN_UNIDIRECTION,
                          dropout=0.,
                          seed=0,
                          name=None):
  """Cudnn RNN w/o input_c.

  Args:
    inputs: the input sequence to the RNN model. A Tensor of shape [?,
      batch_size, input_size].
    input_h: the initial hidden state for h. A Tensor of shape [num_layers,
      batch_size, num_units].
    params: the parameter buffer created for this model.
    is_training: whether this operation will be used in training or inference
    rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh').
    input_mode: indicate whether there is a linear projection between the
      input and the actual computation before the first layer. It could be
      'linear_input', 'skip_input' or 'auto_select'.
      'linear_input' (default) always applies a linear projection of input
      onto RNN hidden state. (standard RNN behavior).
      'skip_input' is only allowed when input_size == num_units;
      'auto_select' implies 'skip_input' when input_size == num_units;
      otherwise, it implies 'linear_input'.
    direction: the direction model that the model operates. Could be either
      'unidirectional' or 'bidirectional'
    dropout: whether to enable dropout. When it is 0, dropout is disabled.
    seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
      for behavior.
    name: name of the operation.
  Returns:
    outputs, output_h
  """
  input_c = array_ops.constant([], dtype=input_h.dtype)
  outputs, output_h, _ = _cudnn_rnn(inputs, input_h, input_c, params,
                                    is_training, rnn_mode, input_mode,
                                    direction, dropout, seed, name)
  return outputs, output_h


def cudnn_gru(inputs,
              input_h,
              params,
              is_training,
              input_mode=CUDNN_INPUT_LINEAR_MODE,
              direction=CUDNN_RNN_UNIDIRECTION,
              dropout=0.,
              seed=0,
              name=None):
  """Cudnn GRU.

  Args:
    inputs: the input sequence to the RNN model. A Tensor of shape [?,
      batch_size, input_size].
    input_h: the initial hidden state for h. A Tensor of shape [num_layers,
      batch_size, num_units].
    params: the parameter buffer created for this model.
    is_training: whether this operation will be used in training or inference
    input_mode: indicate whether there is a linear projection between the
      input and the actual computation before the first layer. It could be
      'linear_input', 'skip_input' or 'auto_select'.
      'linear_input' (default) always applies a linear projection of input
      onto RNN hidden state. (standard RNN behavior).
      'skip_input' is only allowed when input_size == num_units;
      'auto_select' implies 'skip_input' when input_size == num_units;
      otherwise, it implies 'linear_input'.
    direction: the direction model that the model operates. Could be either
      'unidirectional' or 'bidirectional'
    dropout: whether to enable dropout. When it is 0, dropout is disabled.
    seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
      for behavior.
    name: name of the operation.
  Returns:
    outputs, output_h
  """
  return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, CUDNN_GRU,
                               input_mode, direction, dropout, seed, name)


def cudnn_rnn_relu(inputs,
                   input_h,
                   params,
                   is_training,
                   input_mode=CUDNN_INPUT_LINEAR_MODE,
                   direction=CUDNN_RNN_UNIDIRECTION,
                   dropout=0.,
                   seed=0,
                   name=None):
  """Cudnn RNN Relu.

  Args:
    inputs: the input sequence to the RNN model. A Tensor of shape [?,
      batch_size, input_size].
    input_h: the initial hidden state for h. A Tensor of shape [num_layers,
      batch_size, num_units].
    params: the parameter buffer created for this model.
    is_training: whether this operation will be used in training or inference
    input_mode: indicate whether there is a linear projection between the
      input and the actual computation before the first layer. It could be
      'linear_input', 'skip_input' or 'auto_select'.
      'linear_input' (default) always applies a linear projection of input
      onto RNN hidden state. (standard RNN behavior).
      'skip_input' is only allowed when input_size == num_units;
      'auto_select' implies 'skip_input' when input_size == num_units;
      otherwise, it implies 'linear_input'.
    direction: the direction model that the model operates. Could be either
      'unidirectional' or 'bidirectional'
    dropout: whether to enable dropout. When it is 0, dropout is disabled.
    seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
      for behavior.
    name: name of the operation.
  Returns:
    outputs, output_h
  """
  return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training,
                               CUDNN_RNN_RELU, input_mode, direction, dropout,
                               seed, name)


def cudnn_rnn_tanh(inputs,
                   input_h,
                   params,
                   is_training,
                   input_mode=CUDNN_INPUT_LINEAR_MODE,
                   direction=CUDNN_RNN_UNIDIRECTION,
                   dropout=0.,
                   seed=0,
                   name=None):
  """Cudnn RNN Tanh.

  Args:
    inputs: the input sequence to the RNN model. A Tensor of shape [?,
      batch_size, input_size].
    input_h: the initial hidden state for h. A Tensor of shape [num_layers,
      batch_size, num_units].
    params: the parameter buffer created for this model.
    is_training: whether this operation will be used in training or inference
    input_mode: indicate whether there is a linear projection between the
      input and the actual computation before the first layer. It could be
      'linear_input', 'skip_input' or 'auto_select'.
      'linear_input' (default) always applies a linear projection of input
      onto RNN hidden state. (standard RNN behavior).
      'skip_input' is only allowed when input_size == num_units;
      'auto_select' implies 'skip_input' when input_size == num_units;
      otherwise, it implies 'linear_input'.
    direction: the direction model that the model operates. Could be either
      'unidirectional' or 'bidirectional'
    dropout: whether to enable dropout. When it is 0, dropout is disabled.
    seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
      for behavior.
    name: name of the operation.
  Returns:
    outputs, output_h
  """
  return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training,
                               CUDNN_RNN_TANH, input_mode, direction, dropout,
                               seed, name)

def cudnn_rnn_params_to_canonical(rnn_mode,
                                  num_layers,
                                  num_units,
                                  input_size,
                                  params,
                                  input_mode=CUDNN_INPUT_LINEAR_MODE,
                                  direction=CUDNN_RNN_UNIDIRECTION,
                                  dropout=0,
                                  seed=0,
                                  name=None):
  """Convert cudnn opaque params to canonical.

  Args:
    rnn_mode: a string specifies the mode, under which this RNN model runs.
      Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'.
    num_layers: the number of layers for the RNN model.
    num_units: the number of units within the RNN model.
    input_size: the size of the input, it could be different from the
      num_units.
    params: opaque cudnn params var.
    input_mode: indicate whether there is a linear projection between the
      input and the actual computation before the first layer. It could be
      'linear_input', 'skip_input' or 'auto_select'.
      'linear_input' (default) always applies a linear projection of input
      onto RNN hidden state. (standard RNN behavior).
      'skip_input' is only allowed when input_size == num_units;
      'auto_select' implies 'skip_input' when input_size == num_units;
      otherwise, it implies 'linear_input'.
    direction: the direction model that the model operates. Could be either
      'unidirectional' or 'bidirectional'
    dropout: whether to enable dropout. When it is 0, dropout is disabled.
    seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
      for behavior.
    name: name of the operation.
  Returns:
    weights list and bias list
  Raises:
    ValueError: if rnn_mode or direction is invalid.
  """

  _check_rnn_mode(rnn_mode)
  _check_direction(direction)
  num_params = _get_num_params(rnn_mode, num_layers, direction)
  seed, seed2 = random_seed.get_seed(seed)
  weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
      rnn_mode=rnn_mode,
      num_layers=num_layers,
      num_units=num_units,
      input_size=input_size,
      params=params,
      input_mode=input_mode,
      direction=direction,
      dropout=dropout,
      seed=seed,
      seed2=seed2,
      num_params=num_params,
      name=name)
  return weights, biases


def cudnn_rnn_canonical_to_params(rnn_mode,
                                  num_layers,
                                  num_units,
                                  input_size,
                                  weights,
                                  biases,
                                  input_mode=CUDNN_INPUT_LINEAR_MODE,
                                  direction=CUDNN_RNN_UNIDIRECTION,
                                  dropout=0,
                                  seed=0,
                                  name=None):
  """Converts params from the canonical format to a specific format of cuDNN.

  Args:
    rnn_mode: a string specifies the mode, under which this RNN model runs.
      Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'.
    num_layers: the number of layers for the RNN model.
    num_units: the number of units within the RNN model.
    input_size: the size of the input, it could be different from the
      num_units.
    weights: a Tensor for weight parameters.
    biases: a Tensor for bias parameters.
    input_mode: indicate whether there is a linear projection between the
      input and the actual computation before the first layer. It could be
      'linear_input', 'skip_input' or 'auto_select'.
      'linear_input' (default) always applies a linear projection of input
      onto RNN hidden state. (standard RNN behavior).
      'skip_input' is only allowed when input_size == num_units;
      'auto_select' implies 'skip_input' when input_size == num_units;
      otherwise, it implies 'linear_input'.
    direction: the direction model that the model operates. Could be either
      'unidirectional' or 'bidirectional'
    dropout: whether to enable dropout. When it is 0, dropout is disabled.
    seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
      for behavior.
    name: name of the operation.
  Returns:
    an opaque Cudnn param.
  Raises:
    ValueError: if rnn_mode or direction is invalid.
  """
  _check_rnn_mode(rnn_mode)
  _check_direction(direction)
  seed, seed2 = random_seed.get_seed(seed)
  return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
      rnn_mode=rnn_mode,
      num_layers=num_layers,
      num_units=num_units,
      input_size=input_size,
      weights=weights,
      biases=biases,
      input_mode=input_mode,
      direction=direction,
      dropout=dropout,
      seed=seed,
      seed2=seed2,
      name=name)


def cudnn_opaque_params_size(rnn_mode,
                             num_layers,
                             num_units,
                             input_size,
                             input_mode=CUDNN_INPUT_LINEAR_MODE,
                             direction=CUDNN_RNN_UNIDIRECTION,
                             dtype=dtypes.float32,
                             dropout=0,
                             seed=0,
                             name=None):
  """Returns opaque params size for specific Cudnn config.

  Args:
    rnn_mode: a string specifies the mode, under which this RNN model runs.
      Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'.
    num_layers: the number of layers for the RNN model.
    num_units: the number of units within the RNN model.
    input_size: the size of the input, it could be different from the
      num_units.
    input_mode: indicate whether there is a linear projection between the
      input and the actual computation before the first layer. It could be
      'linear_input', 'skip_input' or 'auto_select'.
      'linear_input' (default) always applies a linear projection of input
      onto RNN hidden state. (standard RNN behavior).
      'skip_input' is only allowed when input_size == num_units;
      'auto_select' implies 'skip_input' when input_size == num_units;
      otherwise, it implies 'linear_input'.
    direction: the direction model that the model operates. Could be either
      'unidirectional' or 'bidirectional'
    dtype: one of tf.float32 or tf.float64.
    dropout: whether to enable dropout. When it is 0, dropout is disabled.
    seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
      for behavior.
    name: name of the operation.
  Returns:
    an int, size of Cudnn opaque params.
  Raises:
    ValueError: if rnn_mode or direction is invalid.
  """
  _check_rnn_mode(rnn_mode)
  _check_direction(direction)
  seed, seed2 = random_seed.get_seed(seed)
  return gen_cudnn_rnn_ops.cudnn_rnn_params_size(
      rnn_mode=rnn_mode,
      num_layers=num_layers,
      num_units=num_units,
      input_size=input_size,
      T=dtype,
      S=dtypes.int32,
      dropout=dropout,
      seed=seed,
      seed2=seed2,
      input_mode=input_mode,
      direction=direction,
      name=name)[0]

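A hedged sketch tying the new functional helpers together: size the opaque buffer, then round-trip it through the canonical converters (variable creation is elided; the layer/unit sizes are arbitrary):

  params_size = cudnn_opaque_params_size(
      rnn_mode=CUDNN_LSTM, num_layers=2, num_units=128, input_size=64)
  # ... create a float32 variable `params` of shape [params_size] ...
  weights, biases = cudnn_rnn_params_to_canonical(
      rnn_mode=CUDNN_LSTM, num_layers=2, num_units=128, input_size=64,
      params=params)
  opaque_again = cudnn_rnn_canonical_to_params(
      rnn_mode=CUDNN_LSTM, num_layers=2, num_units=128, input_size=64,
      weights=weights, biases=biases)
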
class _CudnnRNN(object):
  """Creates an RNN model using the underlying Cudnn implementation.

@@ -761,9 +1237,6 @@ class _CudnnRNN(object):
    Raises:
      ValueError: if direction is invalid.
    """
    if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION):
      raise ValueError("Invalid direction: %s, expect %s or %s",
                       direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION)
    self._num_layers = num_layers
    self._num_units = num_units
    self._input_size = input_size
@@ -772,10 +1245,7 @@ class _CudnnRNN(object):
    self._direction = direction
    self._dtype = dtype
    self._dropout = dropout
    # Get graph and op seed.
    self._seed, self._seed2 = random_seed.get_seed(seed)
    if self._seed is None and self._seed2 is None:
      self._seed, self._seed2 = 0, 0
    self._seed = seed

  @property
  def input_mode(self):
@@ -807,18 +1277,16 @@ class _CudnnRNN(object):
    Returns:
      The calculated parameter buffer size.
    """
    return gen_cudnn_rnn_ops.cudnn_rnn_params_size(
    return cudnn_opaque_params_size(
        rnn_mode=self._rnn_mode,
        num_layers=self._num_layers,
        num_units=self._num_units,
        input_size=self._input_size,
        T=self._dtype,
        S=dtypes.int32,
        dtype=self._dtype,
        dropout=self._dropout,
        seed=self._seed,
        seed2=self._seed2,
        rnn_mode=self._rnn_mode,
        input_mode=self._input_mode,
        direction=self._direction)[0]
        direction=self._direction)

  def __call__(self, input_data, input_h, input_c, params, is_training=True):
    """Runs the forward step for the RNN model.
@@ -837,22 +1305,17 @@ class _CudnnRNN(object):
      output_h: the final state for h.
      output_c: the final state for c. This is only relevant for LSTM.
    """
    if self._rnn_mode != CUDNN_LSTM:
      # For a model that doesn't take input_c, replace it with a dummy tensor.
      input_c = array_ops.constant([], dtype=self._dtype)
    output, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
        input=input_data,
        input_h=input_h,
        input_c=input_c,
        params=params,
        rnn_mode=self._rnn_mode,
    return _cudnn_rnn(
        input_data,
        input_h,
        input_c,
        params,
        is_training,
        self._rnn_mode,
        input_mode=self._input_mode,
        direction=self._direction,
        dropout=self._dropout,
        seed=self._seed,
        seed2=self._seed2,
        is_training=is_training)
    return (output, output_h, output_c)
        seed=self._seed)

  def params_to_canonical(self, params):
    """Converts params from a specific format of cuDNN to the canonical format.
@@ -863,22 +1326,16 @@ class _CudnnRNN(object):
    Returns:
      A function for the specific-to-canonical conversion.
    """
    num_params = self._num_layers * self._NUM_PARAMS_PER_LAYER
    if self._direction != CUDNN_RNN_UNIDIRECTION:
      num_params *= 2
    weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
    return cudnn_rnn_params_to_canonical(
        rnn_mode=self._rnn_mode,
        num_layers=self._num_layers,
        num_units=self._num_units,
        input_size=self._input_size,
        params=params,
        dropout=self._dropout,
        seed=self._seed,
        seed2=self._seed2,
        num_params=num_params,
        rnn_mode=self._rnn_mode,
        input_mode=self._input_mode,
        direction=self._direction)
    return weights, biases
        direction=self._direction,
        dropout=self._dropout,
        seed=self._seed)

  def canonical_to_params(self, weights, biases):
    """Converts params from the canonical format to a specific format of cuDNN.
@@ -890,18 +1347,17 @@ class _CudnnRNN(object):
    Returns:
      A function for the canonical-to-specific conversion.
    """
    return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
    return cudnn_rnn_canonical_to_params(
        rnn_mode=self._rnn_mode,
        num_layers=self._num_layers,
        num_units=self._num_units,
        input_size=self._input_size,
        weights=weights,
        biases=biases,
        dropout=self._dropout,
        seed=self._seed,
        seed2=self._seed2,
        rnn_mode=self._rnn_mode,
        input_mode=self._input_mode,
        direction=self._direction)
        direction=self._direction,
        dropout=self._dropout,
        seed=self._seed)


class CudnnLSTM(_CudnnRNN):
@@ -1036,9 +1492,16 @@ class _CudnnRNNNoInputC(_CudnnRNN):
      output: the output sequence.
      output_h: the final state for h.
    """
    output, output_h, _ = super(_CudnnRNNNoInputC, self).__call__(
        input_data, input_h, None, params, is_training=is_training)
    return (output, output_h)
    return _cudnn_rnn_no_input_c(
        input_data,
        input_h,
        params,
        is_training,
        self._rnn_mode,
        input_mode=self._input_mode,
        direction=self._direction,
        dropout=self._dropout,
        seed=self._seed)


class CudnnGRU(_CudnnRNNNoInputC):

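To see how the refactored pieces fit together, here is a hedged end-to-end sketch of driving a `CudnnLSTM` (constructor arguments as in this diff; the `[seq_len, batch, ...]` shape convention is the usual cudnn_rnn layout and mirrors the existing unit tests):

```python
import tensorflow as tf

# Hedged sketch; requires a CUDA-enabled build with the cudnn_rnn kernels.
lstm = CudnnLSTM(num_layers=2, num_units=128, input_size=64)

# The opaque parameter buffer is sized by the graph itself, so the variable
# shape cannot be validated at construction time.
params = tf.Variable(
    tf.random_uniform([lstm.params_size()]), validate_shape=False)

input_data = tf.zeros([20, 8, 64])   # [seq_len, batch, input_size]
input_h = tf.zeros([2, 8, 128])      # [num_layers, batch, num_units]
input_c = tf.zeros([2, 8, 128])      # LSTM cell state only

output, output_h, output_c = lstm(
    input_data, input_h, input_c, params, is_training=True)
```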
@@ -22,6 +22,7 @@

@@read_batch_features
@@rejection_resample
@@group_by_window
"""

from __future__ import absolute_import
@@ -31,6 +32,7 @@ from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.contrib.data.python.ops.dataset_ops import Dataset
from tensorflow.contrib.data.python.ops.dataset_ops import FixedLengthRecordDataset
from tensorflow.contrib.data.python.ops.dataset_ops import group_by_window
from tensorflow.contrib.data.python.ops.dataset_ops import Iterator
from tensorflow.contrib.data.python.ops.dataset_ops import read_batch_features
from tensorflow.contrib.data.python.ops.dataset_ops import rejection_resample

@@ -37,7 +37,9 @@ class GroupByWindowTest(test.TestCase):
    components = np.random.randint(100, size=(200,)).astype(np.int64)
    iterator = dataset_ops.Iterator.from_dataset(
        dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x)
        .group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), 4))
        .apply(
            dataset_ops.group_by_window,
            args=(lambda x: x % 2, lambda _, xs: xs.batch(4), 4)))
    init_op = iterator.initializer
    get_next = iterator.get_next()

@@ -61,8 +63,9 @@ class GroupByWindowTest(test.TestCase):
    components = np.array(
        [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
    iterator = dataset_ops.Iterator.from_dataset(
        dataset_ops.Dataset.from_tensor_slices(components).repeat(-1)
        .group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4))
        dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
            dataset_ops.group_by_window,
            args=(lambda x: x % 3, lambda _, xs: xs.batch(4), 4)))
    init_op = iterator.initializer
    get_next = iterator.get_next()

@@ -81,8 +84,9 @@ class GroupByWindowTest(test.TestCase):
  def testSmallGroups(self):
    components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
    iterator = dataset_ops.Iterator.from_dataset(
        dataset_ops.Dataset.from_tensor_slices(components)
        .group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), 4))
        dataset_ops.Dataset.from_tensor_slices(components).apply(
            dataset_ops.group_by_window,
            args=(lambda x: x % 2, lambda _, xs: xs.batch(4), 4)))
    init_op = iterator.initializer
    get_next = iterator.get_next()

@@ -108,8 +112,9 @@ class GroupByWindowTest(test.TestCase):

    iterator = dataset_ops.Iterator.from_dataset(
        dataset_ops.Dataset.from_tensor_slices(components)
        .map(lambda x: (x, ops.convert_to_tensor([x * x])))
        .group_by_window(lambda x, _: x % 2, reduce_func, 32))
        .map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply(
            dataset_ops.group_by_window,
            args=(lambda x, _: x % 2, reduce_func, 32)))
    init_op = iterator.initializer
    get_next = iterator.get_next()

@@ -124,17 +129,20 @@ class GroupByWindowTest(test.TestCase):
    def reduce_func(key, window):
      # Apply two different kinds of padding to the input: tight
      # padding, and quantized (to a multiple of 10) padding.
      return dataset_ops.Dataset.zip((window.padded_batch(
          4,
          padded_shapes=tensor_shape.TensorShape([None])), window.padded_batch(
      return dataset_ops.Dataset.zip((
          window.padded_batch(
              4, padded_shapes=tensor_shape.TensorShape([None])),
          window.padded_batch(
              4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),))

    iterator = dataset_ops.Iterator.from_dataset(
        dataset_ops.Dataset.from_tensor_slices(components)
        .map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
        .group_by_window(
            lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
            reduce_func, 4))
        .apply(
            dataset_ops.group_by_window,
            args=
            (lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
             reduce_func, 4)))
    init_op = iterator.initializer
    get_next = iterator.get_next()

@@ -151,10 +159,9 @@ class GroupByWindowTest(test.TestCase):
    self.assertEqual(len(components), sum(counts))


# NOTE(mrry): These tests are based on the tests in
# bucket_ops_test.py. Currently, different batch sizes for each key
# are not supported, although this would be possible to add to
# `Dataset.group_by_window()`.
# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
# Currently, they use a constant batch size, though should be made to use a
# different batch size per key.
class BucketTest(test.TestCase):

  def _dynamicPad(self, bucket, window, window_size):
@@ -168,6 +175,7 @@ class BucketTest(test.TestCase):
                tensor_shape.TensorShape([3])))))

  def testSingleBucket(self):

    def _map_fn(v):
      return (v, array_ops.fill([v], v),
              array_ops.fill([3], string_ops.as_string(v)))
@@ -175,9 +183,10 @@ class BucketTest(test.TestCase):
    input_dataset = (
        dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn))

    bucketed_dataset = input_dataset.group_by_window(
        lambda x, y, z: 0, lambda k, bucket: self._dynamicPad(k, bucket, 32),
        32)
    bucketed_dataset = input_dataset.apply(
        dataset_ops.group_by_window,
        args=(lambda x, y, z: 0,
              lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))

    iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset)
    init_op = iterator.initializer
@@ -201,6 +210,7 @@ class BucketTest(test.TestCase):
    self.assertAllEqual(expected_vec3_str, bucketed_values[2])

  def testEvenOddBuckets(self):

    def _map_fn(v):
      return (v, array_ops.fill([v], v),
              array_ops.fill([3], string_ops.as_string(v)))
@@ -208,9 +218,10 @@ class BucketTest(test.TestCase):
    input_dataset = (
        dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn))

    bucketed_dataset = input_dataset.group_by_window(
        lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
        lambda k, bucket: self._dynamicPad(k, bucket, 32), 32)
    bucketed_dataset = input_dataset.apply(
        dataset_ops.group_by_window,
        args=(lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
              lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))

    iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset)
    init_op = iterator.initializer
@@ -256,25 +267,31 @@ class BucketTest(test.TestCase):
    self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])

  def testEvenOddBucketsFilterOutAllOdd(self):

    def _map_fn(v):
      return {"x": v,
              "y": array_ops.fill([v], v),
              "z": array_ops.fill([3], string_ops.as_string(v))}
      return {
          "x": v,
          "y": array_ops.fill([v], v),
          "z": array_ops.fill([3], string_ops.as_string(v))
      }

    def _dynamic_pad_fn(bucket, window, _):
      return dataset_ops.Dataset.zip(
          (dataset_ops.Dataset.from_tensors(bucket), window.padded_batch(
              32, {"x": tensor_shape.TensorShape([]),
                   "y": tensor_shape.TensorShape([None]),
                   "z": tensor_shape.TensorShape([3])})))
              32, {
                  "x": tensor_shape.TensorShape([]),
                  "y": tensor_shape.TensorShape([None]),
                  "z": tensor_shape.TensorShape([3])
              })))

    input_dataset = (
        dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
        .filter(lambda d: math_ops.equal(d["x"] % 2, 0)))

    bucketed_dataset = input_dataset.group_by_window(
        lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
        lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32)
    bucketed_dataset = input_dataset.apply(
        dataset_ops.group_by_window,
        args=(lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
              lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32))

    iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset)
    init_op = iterator.initializer
@@ -295,6 +312,40 @@ class BucketTest(test.TestCase):
    self.assertAllEqual(
        np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])

  def testDynamicWindowSize(self):
    components = np.arange(100).astype(np.int64)

    # Key fn: even/odd
    # Reduce fn: batches of 5
    # Window size fn: even=5, odd=10

    def window_size_func(key):
      window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64)
      return window_sizes[key]

    dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
        dataset_ops.group_by_window,
        args=(lambda x: x % 2, lambda _, xs: xs.batch(20), None,
              window_size_func))
    iterator = dataset_ops.Iterator.from_dataset(dataset)
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      with self.assertRaises(errors.OutOfRangeError):
        batches = 0
        while True:
          result = sess.run(get_next)
          is_even = all(x % 2 == 0 for x in result)
          is_odd = all(x % 2 == 1 for x in result)
          self.assertTrue(is_even or is_odd)
          expected_batch_size = 5 if is_even else 10
          self.assertEqual(expected_batch_size, result.shape[0])
          batches += 1

      # 50 even elements in windows of 5 plus 50 odd elements in windows of 10
      # yield 10 + 5 = 15 batches in total.
      self.assertEqual(batches, 15)


if __name__ == "__main__":
  test.main()

@@ -19,6 +19,7 @@ from __future__ import print_function

import os
import threading
from collections import namedtuple

import numpy as np

@@ -481,6 +482,40 @@ class MapDatasetTest(test.TestCase):
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testMapNamedtuple(self, count=10):
    # construct dataset of tuples
    labels = dataset_ops.Dataset.range(count)
    images = labels.map(lambda l: -l)
    dataset_tuple = dataset_ops.Dataset.zip((labels, images))

    # convert dataset of tuples to dataset of namedtuples
    Example = namedtuple("Example", ["label", "image"])
    dataset_namedtuple = dataset_tuple.map(Example)

    def preprocess_tuple(label, image):
      image = 2 * image
      return label, image

    def preprocess_namedtuple(example):
      return example._replace(image=2 * example.image)

    # preprocess both datasets
    dataset_tuple = dataset_tuple.map(preprocess_tuple)
    dataset_namedtuple = dataset_namedtuple.map(preprocess_namedtuple)

    next_tuple = dataset_tuple.make_one_shot_iterator().get_next()
    next_namedtuple = dataset_namedtuple.make_one_shot_iterator().get_next()

    # make sure both datasets contain the same data
    with self.test_session() as sess:
      for i in range(count):
        tuple_, namedtuple_ = sess.run([next_tuple, next_namedtuple])
        self.assertEqual(tuple_, namedtuple_)
        self.assertEqual(tuple_, (i, -2 * i))

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_namedtuple)

  def testUseStepContainerInMap(self):
    row = np.arange(6)
    iterator = (

@@ -1199,28 +1199,9 @@ class Dataset(object):
    return DenseToSparseBatchDataset(self, batch_size, row_shape)

  def group_by_window(self, key_func, reduce_func, window_size):
    """Performs a windowed "group-by" operation on this dataset.

    This method maps each consecutive element in this dataset to a key
    using `key_func` and groups the elements by key. It then applies
    `reduce_func` to at most `window_size` elements matching the same
    key. All except the final window for each key will contain
    `window_size` elements; the final window may be smaller.

    Args:
      key_func: A function mapping a nested structure of tensors
        (having shapes and types defined by `self.output_shapes` and
        `self.output_types`) to a scalar `tf.int64` tensor.
      reduce_func: A function mapping a key and a dataset of up to
        `window_size` consecutive elements matching that key to another
        dataset.
      window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
        consecutive elements matching the same key to combine in a single
        batch, which will be passed to `reduce_func`.

    Returns:
      A `Dataset`.
    """
    return GroupByWindowDataset(self, key_func, reduce_func, window_size)
    """See group_by_window()."""
    return self.apply(
        group_by_window, args=(key_func, reduce_func, window_size))

  def map(self,
          map_func,
@@ -1370,6 +1351,43 @@ class Dataset(object):
    """
    return FilterDataset(self, predicate)

  def apply(self, fn, args=(), kwargs={}):  # pylint: disable=dangerous-default-value
    """Apply a function to this dataset.

    `apply` enables chaining of custom `Dataset` transformations.

    For example:

    ```
    dataset.map(
        lambda x: x**2
    ).apply(
        group_by_window, args=(key_func, reduce_func, window_size)
    ).map(
        lambda x: x**3
    )
    ```

    Args:
      fn: A function that takes a `Dataset`, `args`, and `kwargs`, and
        returns a `Dataset`.
      args: A `tuple` or `list` of arguments to be passed to `fn`.
      kwargs: A `dict` of keyword arguments to be passed to `fn`.

    Returns:
      The `Dataset` returned by `fn`.
    """
    if not (isinstance(args, tuple) or isinstance(args, list)):
      raise TypeError("args must be a tuple or list.")
    if not isinstance(kwargs, dict):
      raise TypeError("kwargs must be a dict.")

    dataset = fn(self, *args, **kwargs)

    if not isinstance(dataset, Dataset):
      raise TypeError("fn must return a Dataset.")
    return dataset
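Because `apply` simply threads the dataset through `fn`, any function with the shape `(dataset, *args, **kwargs) -> Dataset` plugs into a chain. A small hedged sketch with a made-up transformation (`_square_and_batch` is illustrative only, not part of this change):

```python
# Hypothetical custom transformation: a plain function from Dataset to Dataset.
def _square_and_batch(dataset, batch_size):
  return dataset.map(lambda x: x * x).batch(batch_size)

squares = Dataset.range(10).apply(_square_and_batch, args=(4,))
# Iterating yields [0, 1, 4, 9], [16, 25, 36, 49], [64, 81].
```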


class TensorDataset(Dataset):
  """A `Dataset` with a single element, viz. a nested structure of tensors."""
@@ -1903,7 +1921,7 @@ class DenseToSparseBatchDataset(Dataset):

def _should_unpack_args(args):
  """Returns `True` if `args` should be `*args` when passed to a callable."""
  return nest.is_sequence(args) and not isinstance(args, dict)
  return type(args) is tuple  # pylint: disable=unidiomatic-typecheck


class _ResourceDataset(Dataset):
@@ -1927,71 +1945,6 @@ class _ResourceDataset(Dataset):
    return self._output_types


class GroupByWindowDataset(Dataset):
  """A `Dataset` that groups its input and performs a windowed reduction."""

  def __init__(self, input_dataset, key_func, reduce_func, window_size):
    """See `Dataset.group_by_window()` for details."""
    super(GroupByWindowDataset, self).__init__()
    self._input_dataset = input_dataset
    self._window_size = window_size

    @function.Defun(*nest.flatten(input_dataset.output_types))
    def tf_key_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
        arg.set_shape(shape)
      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      if _should_unpack_args(nested_args):
        ret = key_func(*nested_args)
      else:
        ret = key_func(nested_args)
      ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
      if ret.dtype != dtypes.int64:
        raise ValueError("`key_func` must return a single tf.int64 tensor.")
      return ret

    self._key_func = tf_key_func
    self._key_func.add_to_graph(ops.get_default_graph())

    @function.Defun(dtypes.int64, dtypes.resource)
    def tf_reduce_func(key, window_dataset_resource):
      """A wrapper for Defun that facilitates shape inference."""
      key.set_shape([])
      window_dataset = _ResourceDataset(window_dataset_resource,
                                        input_dataset.output_types,
                                        input_dataset.output_shapes)
      output_dataset = reduce_func(key, window_dataset)
      if not isinstance(output_dataset, Dataset):
        raise TypeError("`reduce_func` must return a `Dataset` object.")
      self._output_types = output_dataset.output_types
      self._output_shapes = output_dataset.output_shapes
      return output_dataset.make_dataset_resource()

    self._reduce_func = tf_reduce_func
    self._reduce_func.add_to_graph(ops.get_default_graph())

  def make_dataset_resource(self):
    return gen_dataset_ops.group_by_window_dataset(
        self._input_dataset.make_dataset_resource(),
        self._key_func.captured_inputs,
        self._reduce_func.captured_inputs,
        self._window_size,
        key_func=self._key_func,
        reduce_func=self._reduce_func,
        output_types=nest.flatten(self.output_types),
        output_shapes=nest.flatten(self.output_shapes))

  @property
  def output_shapes(self):
    return self._output_shapes

  @property
  def output_types(self):
    return self._output_types


class MapDataset(Dataset):
  """A `Dataset` that maps a function over elements in its input."""

@@ -2151,7 +2104,7 @@ class InterleaveDataset(Dataset):

    nested_args = nest.pack_sequence_as(input_dataset.output_types, args)

    if nest.is_sequence(nested_args):
    if _should_unpack_args(nested_args):
      dataset = map_func(*nested_args)
    else:
      dataset = map_func(nested_args)
@@ -2460,7 +2413,7 @@ def rejection_resample(dataset,
      shapes and types defined by `dataset.output_shapes` and
      `dataset.output_types`) to a scalar `tf.int32` tensor. Values should
      be in `[0, num_classes)`.
    target_dist: A floating point type tensor, shaped `[num_classes].
    target_dist: A floating point type tensor, shaped `[num_classes]`.
    initial_dist: (Optional.) A floating point type tensor, shaped
      `[num_classes]`. If not provided, the true class distribution is
      estimated live in a streaming fashion.
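For the `rejection_resample` signature whose docstring is touched above, a hedged usage sketch; only the arguments documented here are used, and the lambda's arity is my assumption about the dataset's element structure:

```python
# Hedged sketch: resample a labeled dataset toward a uniform two-class mix.
balanced = rejection_resample(
    dataset,                            # a Dataset of (feature, label) pairs
    class_func=lambda _, label: label,  # maps an element to a tf.int32 class
    target_dist=[0.5, 0.5])             # desired class proportions
```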
@@ -2660,3 +2613,149 @@ def _get_file_names(file_pattern, randomize_input):
  if not randomize_input:
    file_names = sorted(file_names)
  return file_names


class GroupByWindowDataset(Dataset):
  """A `Dataset` that groups its input and performs a windowed reduction."""

  def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
    """See `group_by_window()` for details."""
    super(GroupByWindowDataset, self).__init__()

    self._input_dataset = input_dataset

    self._make_key_func(key_func, input_dataset)
    self._make_reduce_func(reduce_func, input_dataset)
    self._make_window_size_func(window_size_func)

  def _make_window_size_func(self, window_size_func):
    """Make wrapping Defun for window_size_func."""

    @function.Defun(dtypes.int64)
    def tf_window_size_func(key):
      key.set_shape([])
      window_size = ops.convert_to_tensor(
          window_size_func(key), dtype=dtypes.int64)
      if window_size.dtype != dtypes.int64:
        raise ValueError(
            "`window_size_func` must return a single tf.int64 tensor.")
      return window_size

    self._window_size_func = tf_window_size_func
    self._window_size_func.add_to_graph(ops.get_default_graph())

  def _make_key_func(self, key_func, input_dataset):
    """Make wrapping Defun for key_func."""

    @function.Defun(*nest.flatten(input_dataset.output_types))
    def tf_key_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
        arg.set_shape(shape)
      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      if _should_unpack_args(nested_args):
        ret = key_func(*nested_args)
      else:
        ret = key_func(nested_args)
      ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
      if ret.dtype != dtypes.int64:
        raise ValueError("`key_func` must return a single tf.int64 tensor.")
      return ret

    self._key_func = tf_key_func
    self._key_func.add_to_graph(ops.get_default_graph())

  def _make_reduce_func(self, reduce_func, input_dataset):
    """Make wrapping Defun for reduce_func."""

    @function.Defun(dtypes.int64, dtypes.resource)
    def tf_reduce_func(key, window_dataset_resource):
      """A wrapper for Defun that facilitates shape inference."""
      key.set_shape([])
      window_dataset = _ResourceDataset(window_dataset_resource,
                                        input_dataset.output_types,
                                        input_dataset.output_shapes)
      output_dataset = reduce_func(key, window_dataset)
      if not isinstance(output_dataset, Dataset):
        raise TypeError("`reduce_func` must return a `Dataset` object.")
      self._output_types = output_dataset.output_types
      self._output_shapes = output_dataset.output_shapes
      return output_dataset.make_dataset_resource()

    self._reduce_func = tf_reduce_func
    self._reduce_func.add_to_graph(ops.get_default_graph())

  @property
  def output_shapes(self):
    return self._output_shapes

  @property
  def output_types(self):
    return self._output_types

  def make_dataset_resource(self):
    return gen_dataset_ops.group_by_window_dataset(
        self._input_dataset.make_dataset_resource(),
        self._key_func.captured_inputs,
        self._reduce_func.captured_inputs,
        self._window_size_func.captured_inputs,
        key_func=self._key_func,
        reduce_func=self._reduce_func,
        window_size_func=self._window_size_func,
        output_types=nest.flatten(self.output_types),
        output_shapes=nest.flatten(self.output_shapes))


def group_by_window(dataset,
                    key_func,
                    reduce_func,
                    window_size=None,
                    window_size_func=None):
  """Performs a windowed "group-by" operation on the given dataset.

  This method maps each consecutive element in the dataset to a key
  using `key_func` and groups the elements by key. It then applies
  `reduce_func` to at most `window_size_func(key)` elements matching the same
  key. All except the final window for each key will contain
  `window_size_func(key)` elements; the final window may be smaller.

  You may provide either a constant `window_size` or a window size determined
  by the key through `window_size_func`.

  Args:
    dataset: A `Dataset`.
    key_func: A function mapping a nested structure of tensors
      (having shapes and types defined by `dataset.output_shapes` and
      `dataset.output_types`) to a scalar `tf.int64` tensor.
    reduce_func: A function mapping a key and a dataset of up to `window_size`
      consecutive elements matching that key to another dataset.
    window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements matching the same key to combine in a single
      batch, which will be passed to `reduce_func`. Mutually exclusive with
      `window_size_func`.
    window_size_func: A function mapping a key to a `tf.int64` scalar
      `tf.Tensor`, representing the number of consecutive elements matching
      the same key to combine in a single batch, which will be passed to
      `reduce_func`. Mutually exclusive with `window_size`.

  Returns:
    A `Dataset`.

  Raises:
    ValueError: if neither or both of {`window_size`, `window_size_func`} are
      passed.
  """
  if (window_size is not None and window_size_func or
      not (window_size is not None or window_size_func)):
    raise ValueError("Must pass either window_size or window_size_func.")

  if window_size is not None:

    def constant_window_func(unused_key):
      return ops.convert_to_tensor(window_size, dtype=dtypes.int64)

    window_size_func = constant_window_func

  assert window_size_func is not None
  return GroupByWindowDataset(dataset, key_func, reduce_func, window_size_func)
|
||||
|
||||
cuda_py_test(
|
||||
name = "sample_stats_test",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["python/kernel_tests/sample_stats_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
|
@ -150,7 +150,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution):
|
||||
`N - 1` dimensions index into a batch of independent distributions and
|
||||
the last dimension represents a vector of probabilities for each
|
||||
class. Only one of `logits` or `probs` should be passed in.
|
||||
dtype: The type of the event samples (default: int32).
|
||||
dtype: The type of the event samples (default: float32).
|
||||
validate_args: Python `bool`, default `False`. When `True` distribution
|
||||
parameters are checked for validity despite possibly degrading runtime
|
||||
performance. When `False` invalid inputs may silently render incorrect
|
||||
@ -388,7 +388,7 @@ class RelaxedOneHotCategorical(
|
||||
dimensions index into a batch of independent distributions and the last
|
||||
dimension represents a vector of probabilities for each class. Only one
|
||||
of `logits` or `probs` should be passed in.
|
||||
dtype: The type of the event samples (default: int32).
|
||||
dtype: The type of the event samples (default: float32).
|
||||
validate_args: Unused in this distribution.
|
||||
allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
|
||||
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
|
||||
|
@ -17,440 +17,16 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_checkpoint_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
ops.NotDifferentiable("GenerateVocabRemapping")
|
||||
ops.NotDifferentiable("LoadAndRemapMatrix")
|
||||
from tensorflow.python.training import checkpoint_ops
|
||||
|
||||
|
||||
def _load_and_remap_matrix(ckpt_path,
|
||||
old_tensor_name,
|
||||
new_row_vocab_offset,
|
||||
num_rows_to_load,
|
||||
new_col_vocab_size,
|
||||
initializer,
|
||||
old_row_vocab_file=None,
|
||||
new_row_vocab_file=None,
|
||||
old_col_vocab_file=None,
|
||||
new_col_vocab_file=None,
|
||||
num_row_oov_buckets=0,
|
||||
num_col_oov_buckets=0,
|
||||
max_rows_in_memory=-1):
|
||||
"""Loads a 2-D (matrix) `Tensor` from checkpoint.
|
||||
|
||||
Generates 1D-remappings for rows and columns using the
|
||||
`GenerateVocabRemapping` op, and initializes any anticipated values with the
|
||||
provided initializer. Then, uses the `LoadAndRemapMatrix` op to create a
|
||||
matrix that loads existing values from the checkpoint, while filling out
|
||||
"missing" values with the newly initialized values. See
|
||||
contrib/framework/ops/checkpoint_ops.cc for more information on the wrapped
|
||||
functionality (LoadAndRemapMatrix). This wrapper can be used to perform only
|
||||
row remapping or only col remapping. If only row remapping is desired,
|
||||
{new,old}_col_vocab_file should be `None`, and vice versa for column
|
||||
remapping.
|
||||
|
||||
NOTE: This only supports div-partitioning the vocabulary on the 1st dimension
|
||||
(row axis) via `new_row_vocab_offset`.
|
||||
|
||||
Args:
|
||||
ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
|
||||
from which the old matrix `Tensor` will be loaded.
|
||||
old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
|
||||
new_row_vocab_offset: A 0-indexed integer representing what line to
|
||||
start reading at in the new row vocabulary. Used for partitioned
|
||||
variables.
|
||||
num_rows_to_load: Number of rows to load for the new vocabulary (note: to
|
||||
support variable partitioning and partial loading, this does not need to
|
||||
be the same as the number of entries in `new_row_vocab_file`).
|
||||
new_col_vocab_size: Number of columns to load - should be the same as the
|
||||
number of entries in `new_col_vocab_file`, since we don't support
|
||||
partitioning along the column axis.
|
||||
initializer: Callable initializer function that accepts a 1-D tensor as the
|
||||
arg to specify the shape of the returned tensor. Used to initialize
|
||||
missing values.
|
||||
old_row_vocab_file: A scalar `Tensor` of type `string` containing the
|
||||
path to the old row vocabulary file. Can be None, which represents no
|
||||
remapping on the row axis.
|
||||
new_row_vocab_file: A scalar `Tensor` of type `string` containing the path
|
||||
to the new row vocabulary file. Can be None, which represents no remapping
|
||||
on the row axis - in which case, `new_row_vocab_offset` and
|
||||
`num_rows_to_load` work under the assumption that the new row vocab is the
|
||||
same as the old row vocab.
|
||||
old_col_vocab_file: A scalar `Tensor` of type `string` containing the
|
||||
path to the old column vocabulary file. Can be None, which represents no
|
||||
remapping on the column axis.
|
||||
new_col_vocab_file: A scalar `Tensor` of type `string` containing the path
|
||||
to the new column vocabulary file. Can be None, which represents no
|
||||
remapping on the column axis - in which case, `new_col_vocab_size` works
|
||||
under the assumption that the new col vocab is the same as the old col
|
||||
vocab.
|
||||
num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows
|
||||
to append. Must be >= 0.
|
||||
num_col_oov_buckets: `int` specifying the number of out-of-vocabulary
|
||||
columns to append. Must be >= 0.
|
||||
max_rows_in_memory: `int` specifying the maximum number of rows to load from
|
||||
the checkpoint at once. If less than or equal to 0, the entire matrix will
|
||||
be loaded into memory. Setting this arg trades increased disk reads for
|
||||
lower memory usage.
|
||||
|
||||
Returns:
|
||||
A Tensor of shape `[num_rows_to_load + num_row_oov_buckets,
|
||||
new_col_vocab_size + num_col_oov_buckets]`, with values loaded from the
|
||||
specified tensor in the checkpoint, and any missing or OOV values
|
||||
initialized with the given `initializer`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `num_row_oov_buckets` or `num_col_oov_buckets` < 0.
|
||||
ValueError: If either `old_row_vocab_file` or `new_row_vocab_file` is
|
||||
provided, while the other is not. Same for `old_col_vocab_file` and
|
||||
`new_col_vocab_file`.
|
||||
ValueError: If neither row vocabs or col vocabs are provided.
|
||||
"""
|
||||
if num_row_oov_buckets < 0:
|
||||
raise ValueError("num_row_oov_buckets must be >= 0, but received %d" %
|
||||
num_row_oov_buckets)
|
||||
if num_col_oov_buckets < 0:
|
||||
raise ValueError("num_col_oov_buckets must be >= 0, but received %d" %
|
||||
num_col_oov_buckets)
|
||||
|
||||
if bool(old_row_vocab_file) != bool(new_row_vocab_file):
|
||||
raise ValueError(
|
||||
"old_row_vocab_file and new_row_vocab_file must both be specified or "
|
||||
"left unspecified. old_row_vocab_file='{}', new_row_vocab_file='{}'".
|
||||
format(old_row_vocab_file, new_row_vocab_file))
|
||||
if bool(old_col_vocab_file) != bool(new_col_vocab_file):
|
||||
raise ValueError(
|
||||
"old_col_vocab_file and new_col_vocab_file must both be specified or "
|
||||
"left unspecified. old_col_vocab_file='{}', new_col_vocab_file='{}'".
|
||||
format(old_col_vocab_file, new_col_vocab_file))
|
||||
|
||||
remap_rows = new_row_vocab_file and old_row_vocab_file
|
||||
remap_cols = new_col_vocab_file and old_col_vocab_file
|
||||
if not (remap_rows or remap_cols):
|
||||
raise ValueError(
|
||||
"Must provide either row or column vocab files. If no remapping is "
|
||||
"necessary, consider using `tf.contrib.framework.init_from_checkpoint` "
|
||||
"instead.")
|
||||
|
||||
num_rows_present = num_rows_to_load
|
||||
if remap_rows:
|
||||
row_remapping, num_rows_present = (
|
||||
gen_checkpoint_ops._generate_vocab_remapping( # pylint: disable=protected-access
|
||||
new_vocab_file=new_row_vocab_file,
|
||||
old_vocab_file=old_row_vocab_file,
|
||||
new_vocab_offset=new_row_vocab_offset,
|
||||
num_new_vocab=num_rows_to_load))
|
||||
else:
|
||||
# Even when the rows are not being reordered, we still need to generate a
|
||||
# remapping to account for initializing partitioned Variables (when
|
||||
# new_row_vocab_offset is non-zero).
|
||||
row_remapping = math_ops.range(
|
||||
new_row_vocab_offset,
|
||||
new_row_vocab_offset + num_rows_to_load,
|
||||
dtype=dtypes.int64)
|
||||
|
||||
col_remapping = []
|
||||
num_cols_present = new_col_vocab_size
|
||||
if remap_cols:
|
||||
col_remapping, num_cols_present = (
|
||||
gen_checkpoint_ops._generate_vocab_remapping( # pylint: disable=protected-access
|
||||
new_vocab_file=new_col_vocab_file,
|
||||
old_vocab_file=old_col_vocab_file,
|
||||
new_vocab_offset=0, # Offset is unused for cols (no partitioning).
|
||||
num_new_vocab=new_col_vocab_size))
|
||||
|
||||
init_vals = initializer([
|
||||
num_rows_to_load * new_col_vocab_size -
|
||||
num_rows_present * num_cols_present, 1
|
||||
])
|
||||
return_tensor = gen_checkpoint_ops._load_and_remap_matrix( # pylint: disable=protected-access
|
||||
ckpt_path=ckpt_path,
|
||||
old_tensor_name=old_tensor_name,
|
||||
row_remapping=row_remapping,
|
||||
col_remapping=col_remapping,
|
||||
initializing_values=init_vals,
|
||||
num_rows=num_rows_to_load,
|
||||
num_cols=new_col_vocab_size,
|
||||
max_rows_in_memory=max_rows_in_memory)
|
||||
|
||||
# Add OOV row(s) and column(s).
|
||||
if num_row_oov_buckets > 0:
|
||||
init_row_oov_val = initializer([num_row_oov_buckets, new_col_vocab_size])
|
||||
init_row_oov_val = ops.convert_to_tensor(init_row_oov_val)
|
||||
return_tensor = array_ops.concat([return_tensor, init_row_oov_val], 0)
|
||||
if num_col_oov_buckets > 0:
|
||||
# We need to add any row OOV to the new column shape.
|
||||
init_col_oov_val = initializer(
|
||||
[num_rows_to_load + num_row_oov_buckets, num_col_oov_buckets])
|
||||
init_col_oov_val = ops.convert_to_tensor(init_col_oov_val)
|
||||
return_tensor = array_ops.concat([return_tensor, init_col_oov_val], 1)
|
||||
|
||||
return return_tensor
|
||||
|
||||
|
||||
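To make the `initializing_values` count above concrete: with `num_rows_to_load=4` and `new_col_vocab_size=5`, if only `num_rows_present=3` rows and `num_cols_present=4` columns are found in the old checkpoint, the op needs `4 * 5 - 3 * 4 = 8` initializer-supplied cells (the OOV blocks are appended afterwards). A hedged sketch of exercising the private helper directly, test-style; all paths below are placeholders, not files from this change:

```python
# Placeholder paths; the real tests resolve these via test.test_src_dir_path().
remapped = _load_and_remap_matrix(
    ckpt_path=['/tmp/bundle_checkpoint'],
    old_tensor_name='some_scope/embeddings',
    new_row_vocab_offset=0,
    num_rows_to_load=4,
    new_col_vocab_size=5,
    initializer=init_ops.zeros_initializer(),
    old_row_vocab_file='/tmp/old_row_vocab.txt',
    new_row_vocab_file='/tmp/new_row_vocab.txt')
```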
def load_and_remap_matrix_initializer(ckpt_path,
                                      old_tensor_name,
                                      new_row_vocab_size,
                                      new_col_vocab_size,
                                      old_row_vocab_file=None,
                                      new_row_vocab_file=None,
                                      old_col_vocab_file=None,
                                      new_col_vocab_file=None,
                                      num_row_oov_buckets=0,
                                      num_col_oov_buckets=0,
                                      initializer=None,
                                      max_rows_in_memory=-1):
  r"""Returns a var initializer for loading and remapping a 2-D (matrix) tensor.

  The returned initializer loads a 2-D (matrix) `Tensor` with name
  `old_tensor_name` from the checkpoint at `ckpt_path`. It will reorder the
  rows/columns according to the specified vocab files and append additional
  out-of-vocabulary rows/columns according to the number of OOV buckets.

  The format of the file at the `{old,new}_{row,col}_vocab_file` path should be
  a text file, with each line containing a single entity within the vocabulary.
  Let the function `line_of(f, "x")` return the 0-indexed line number of the
  entity "x" in file f, and the function `entity_at(f, i)` return the entity at
  line i of file f. Then, row i of the new output matrix will be taken from row
  `line_of(old_row_vocab_file, entity_at(new_row_vocab_file, i))` of the old
  matrix. If any entity in `new_row_vocab_file` is not found in
  `old_row_vocab_file`, that row is considered a "missing" row, and its values
  will be initialized using the `initializer` arg. The same logic also applies
  for the columns.

  For example, assuming that:

  * `old_row_vocab_file` contains "mercury\nvenus\nmars"
  * `new_row_vocab_file` contains "venus\njupiter\nmercury"
  * `old_col_vocab_file` contains "good\nbetter\nbest"
  * `new_col_vocab_file` contains "good\nbest\nfantastic"
  * `initializer` returns the natural numbers `[1, 2, 3, 4, ...]`
  * `w(i, j)` represents the value from row i, column j of the old matrix

  Then the new output matrix will look like:

  `[[w(1, 0), w(1, 2), 1],
    [2, 3, 4],
    [w(0, 0), w(0, 2), 5]]`

  If we further specify that:

  * `num_row_oov_buckets` == 2
  * `num_col_oov_buckets` == 1

  Then the new output matrix will look like:

  `[[w(1, 0), w(1, 2), 1, 12],
    [2, 3, 4, 13],
    [w(0, 0), w(0, 2), 5, 14],
    [6, 7, 8, 15],
    [9, 10, 11, 16]]`

  If `{old,new}_row_vocab_file` are None, we assume that the old and new row
  vocab files are the same, and no row remapping is done. If
  `{old,new}_col_vocab_file` are None, we assume that the old and new column
  vocab files are the same, and no column remapping is done.

  The returned initializer only supports div-partitioning along the row axis.
  It does not support partitioning along the column axis or mod-partitioning.

  NOTE: When this is used to warm-start variables, client code should use
  `tf.lookup.index_table_from_tensor()` like
  contrib/layers/python/layers/feature_column.py does, as opposed to
  `tf.feature_to_id()` - in order to ensure the underlying lookup tables are
  the same.

  Args:
    ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
      from which the old matrix `Tensor` will be loaded.
    old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
    new_row_vocab_size: `int` specifying the number of entries in
      `new_row_vocab_file`. If no row remapping is needed (no row vocab
      provided), this should be equal to the number of rows to load from the
      old matrix (which can theoretically be smaller than the number of rows
      in the old matrix).
    new_col_vocab_size: `int` specifying the number of entries in
      `new_col_vocab_file`. If no column remapping is needed (no column vocab
      provided), this should be equal to the number of columns in the old
      matrix.
    old_row_vocab_file: A scalar `Tensor` of type `string` containing the
      path to the old row vocabulary file. Can be None, which represents no
      remapping on the row axis.
    new_row_vocab_file: A scalar `Tensor` of type `string` containing the path
      to the new row vocabulary file. Can be None, which represents no
      remapping on the row axis.
    old_col_vocab_file: A scalar `Tensor` of type `string` containing the
      path to the old column vocabulary file. Can be None, which represents no
      remapping on the column axis.
    new_col_vocab_file: A scalar `Tensor` of type `string` containing the path
      to the new column vocabulary file. Can be None, which represents no
      remapping on the column axis.
    num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows
      to append. Must be >= 0.
    num_col_oov_buckets: `int` specifying the number of out-of-vocabulary
      columns to append. Must be >= 0.
    initializer: Initializer function to initialize missing values. Accepts a
      1-D tensor as the arg to specify the shape of the returned tensor. If
      `None`, defaults to using `zeros_initializer()`.
    max_rows_in_memory: `int` specifying the maximum number of rows to load
      from the checkpoint at once. If less than or equal to 0, the entire
      matrix will be loaded into memory. Setting this arg trades increased
      disk reads for lower memory usage.

  Returns:
    A variable initializer function that should be used to initialize a
    (potentially partitioned) `Variable` whose complete shape is
    `[new_row_vocab_size + num_row_oov_buckets, new_col_vocab_size +
    num_col_oov_buckets]`.

  Raises:
    TypeError: If `initializer` is specified but not callable.
  """
  if initializer is None:
    # TODO(b/25671353): Consider using sqrt(6/(fan_in + fan_out)) instead, from
    # Glorot and Bengio, 2010.
    initializer = init_ops.zeros_initializer()

  if not callable(initializer):
    raise TypeError(
        "initializer must be callable, instead of being {} of type {}.".format(
            initializer, type(initializer)))

  def _initializer(shape, dtype=dtypes.float32, partition_info=None):
    """Variable initializer.

    Args:
      shape: Shape of `Tensor` to return. Should include OOV on both axes.
      dtype: Must be float32.
      partition_info: variable_scope._PartitionInfo.

    Returns:
      `Tensor` of shape `shape`.

    Raises:
      TypeError: If `dtype` is anything other than float32.
      ValueError: For shape mismatch upon invocation.
    """
    # Sanity checks.
    if dtype != dtypes.float32:
      raise TypeError(
          "Currently, only float32 is supported. Received dtype: {}".format(
              dtype))
    if len(shape) != 2:
      raise ValueError("Expected 2-dim shape, but received: {}".format(shape))
    if shape[0] <= 0:
      raise ValueError(
          "Expected 1st dim of shape to be > 0, but received shape: {}".format(
              shape))
    if shape[1] != (new_col_vocab_size + num_col_oov_buckets):
      raise ValueError(
          "Expected 2nd dim of shape to be new_col_vocab_size ({}) + "
          "num_col_oov_buckets ({}) = {}, but received shape: {}".format(
              new_col_vocab_size, num_col_oov_buckets,
              new_col_vocab_size + num_col_oov_buckets, shape))

    offset = 0
    if partition_info is not None:
      offset = partition_info.single_offset(shape)

    if offset + shape[0] > new_row_vocab_size + num_row_oov_buckets:
      raise ValueError(
          "Trying to initialize {} additional rows after {} rows have already "
          "been initialized, which would exceed expected total row count of "
          "new_row_vocab_size ({}) + num_row_oov_buckets ({}) = {}.".format(
              shape[0], offset, new_row_vocab_size, num_row_oov_buckets,
              new_row_vocab_size + num_row_oov_buckets))

    row_oov_buckets_to_use = min(shape[0],
                                 max(0, offset + shape[0] - new_row_vocab_size))
    num_rows_to_load = shape[0] - row_oov_buckets_to_use

    return _load_and_remap_matrix(
        ckpt_path=ckpt_path,
        old_tensor_name=old_tensor_name,
        new_row_vocab_offset=offset,
        num_rows_to_load=num_rows_to_load,
        new_col_vocab_size=new_col_vocab_size,
        initializer=initializer,
        old_row_vocab_file=old_row_vocab_file,
        new_row_vocab_file=new_row_vocab_file,
        old_col_vocab_file=old_col_vocab_file,
        new_col_vocab_file=new_col_vocab_file,
        num_row_oov_buckets=row_oov_buckets_to_use,
        num_col_oov_buckets=num_col_oov_buckets,
        max_rows_in_memory=max_rows_in_memory)

  return _initializer

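Tying the worked example above to code, a hedged sketch of warm-starting a variable with the returned initializer; the vocab files and checkpoint path are placeholders whose contents match the docstring example:

```python
# Placeholder paths; file contents as in the docstring example above.
loading_initializer = load_and_remap_matrix_initializer(
    ckpt_path=['/tmp/bundle_checkpoint'],
    old_tensor_name='some_scope/embeddings',
    new_row_vocab_size=3,
    new_col_vocab_size=3,
    old_row_vocab_file='/tmp/old_row_vocab.txt',  # mercury\nvenus\nmars
    new_row_vocab_file='/tmp/new_row_vocab.txt',  # venus\njupiter\nmercury
    old_col_vocab_file='/tmp/old_col_vocab.txt',  # good\nbetter\nbest
    new_col_vocab_file='/tmp/new_col_vocab.txt',  # good\nbest\nfantastic
    num_row_oov_buckets=2,
    num_col_oov_buckets=1)

# Complete shape: [3 + 2, 3 + 1] = [5, 4], as the Returns section specifies.
weights = variable_scope.get_variable(
    'remapped_weights', shape=[5, 4], initializer=loading_initializer)
```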
def load_embedding_initializer(ckpt_path,
                               embedding_tensor_name,
                               new_vocab_size,
                               embedding_dim,
                               old_vocab_file,
                               new_vocab_file,
                               num_oov_buckets=0,
                               initializer=None,
                               max_rows_in_memory=-1):
  """Returns a variable initializer for loading pre-trained embeddings.

  Wrapper around `load_and_remap_matrix_initializer()` specialized for loading
  embedding weights and remapping according to the provided vocab files. See
  docs for `load_and_remap_matrix_initializer()` for more details.

  NOTE: Only for use with div-partitioned variables / vocabularies.

  Args:
    ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
      from which the old matrix `Tensor` will be loaded.
    embedding_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
    new_vocab_size: Number of entries in the new vocab.
    embedding_dim: `int` specifying the dimension of the embedding vectors
      from the checkpoint. Must match the number of columns in the old
      embedding matrix.
    old_vocab_file: A scalar `Tensor` of type `string` containing the
      path to the old vocabulary file.
    new_vocab_file: A scalar `Tensor` of type `string` containing the
      path to the new vocabulary file.
    num_oov_buckets: `int` specifying the number of out-of-vocabulary
      buckets to use. Must be >= 0.
    initializer: Initializer function that accepts a 1-D tensor as the arg to
      specify the shape of the returned tensor. If `None`, defaults to using
      `truncated_normal_initializer()`.
    max_rows_in_memory: `int` specifying the maximum number of rows to load
      from the checkpoint at once. If less than or equal to 0, the entire
      matrix will be loaded into memory. Setting this arg trades increased
      disk reads for lower memory usage.

  Returns:
    A variable initializer function.
  """
  if initializer is None:
    # TODO(b/25671353): This should be kept in sync with the stddev used by
    # feature_column.py's _EmbeddingColumn.
    initializer = init_ops.truncated_normal_initializer(
        stddev=1.0 / math.sqrt(embedding_dim))

  return load_and_remap_matrix_initializer(
      ckpt_path=ckpt_path,
      old_tensor_name=embedding_tensor_name,
      new_row_vocab_size=new_vocab_size,
      new_col_vocab_size=embedding_dim,
      old_row_vocab_file=old_vocab_file,
      new_row_vocab_file=new_vocab_file,
      old_col_vocab_file=None,
      new_col_vocab_file=None,
      num_row_oov_buckets=num_oov_buckets,
      num_col_oov_buckets=0,
      initializer=initializer,
      max_rows_in_memory=max_rows_in_memory)
# pylint: disable=protected-access,line-too-long
load_and_remap_matrix_initializer = checkpoint_ops._load_and_remap_matrix_initializer
# pylint: enable=line-too-long
load_embedding_initializer = checkpoint_ops._load_embedding_initializer
# pylint: enable=protected-access


def load_linear_multiclass_bias_initializer(ckpt_path,

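And for the embedding specialization, a hedged sketch with placeholder paths; note how the row vocab maps to the token vocabulary while the columns (embedding dimensions) are left unremapped:

```python
# Placeholder paths; the OOV bucket adds one extra row to the variable shape.
embedding_initializer = load_embedding_initializer(
    ckpt_path=['/tmp/bundle_checkpoint'],
    embedding_tensor_name='some_scope/embeddings',
    new_vocab_size=100,
    embedding_dim=16,
    old_vocab_file='/tmp/vocab_old.txt',
    new_vocab_file='/tmp/vocab_new.txt',
    num_oov_buckets=1)

# Variable shape includes the OOV bucket row: [100 + 1, 16].
embeddings = variable_scope.get_variable(
    'embeddings', shape=[101, 16], initializer=embedding_initializer)
```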
@ -21,7 +21,6 @@ import os
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib import framework as contrib_framework
|
||||
from tensorflow.contrib.framework.python.ops import checkpoint_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -38,250 +37,6 @@ FLAGS = flags.FLAGS
|
||||
_TESTDATA_PATH = 'contrib/framework/testdata'
|
||||
|
||||
|
||||
class LoadAndRemapWrappersTest(test.TestCase):
|
||||
"""Tests for the functionality of the Python wrappers."""
|
||||
|
||||
def setUp(self):
|
||||
self.bundle_file = os.path.join(
|
||||
test.test_src_dir_path(_TESTDATA_PATH), 'bundle_checkpoint')
|
||||
self.new_feature_vocab_file = os.path.join(
|
||||
test.test_src_dir_path(_TESTDATA_PATH), 'bundle_checkpoint_vocab.txt')
|
||||
self.old_feature_vocab_file = os.path.join(
|
||||
test.test_src_dir_path(_TESTDATA_PATH),
|
||||
'bundle_checkpoint_vocab_with_oov.txt')
|
||||
self.new_class_vocab_file = os.path.join(
|
||||
test.test_src_dir_path(_TESTDATA_PATH), 'keyword_new.txt')
|
||||
self.old_class_vocab_file = os.path.join(
|
||||
test.test_src_dir_path(_TESTDATA_PATH), 'keyword.txt')
|
||||
self.init_val = 42
|
||||
|
||||
def _init_val_initializer(shape, dtype=None, partition_info=None):
|
||||
del dtype, partition_info # Unused by this unit-testing initializer.
|
||||
return array_ops.tile(
|
||||
constant_op.constant([[self.init_val]], dtype=dtypes.float32), shape)
|
||||
|
||||
self.initializer = _init_val_initializer
|
||||
|
||||
def test_load_and_remap_matrix(self):
|
||||
"""Tests the end-to-end loading / remapping of weights."""
|
||||
# _load_and_remap_matrix() is the generalized wrapper that takes in row and
|
||||
# column vocabulary files, calls the relevant remappings, and returns the
|
||||
# weight matrix. Take this example to be linear multi-class by providing
|
||||
# both row and column vocabularies.
|
||||
remapped_matrix = checkpoint_ops._load_and_remap_matrix(
|
||||
new_row_vocab_file=self.new_feature_vocab_file,
|
||||
old_row_vocab_file=self.old_feature_vocab_file,
|
||||
num_rows_to_load=4,
|
||||
new_col_vocab_file=self.new_class_vocab_file,
|
||||
old_col_vocab_file=self.old_class_vocab_file,
|
||||
new_col_vocab_size=4,
|
||||
old_tensor_name='some_scope/embeddings',
|
||||
ckpt_path=[self.bundle_file],
|
||||
new_row_vocab_offset=1,
|
||||
initializer=self.initializer,
|
||||
num_row_oov_buckets=1,
|
||||
num_col_oov_buckets=1)
|
||||
|
||||
# [4 in vocab + 1 oov features, 4 in vocab + 1 oov classes]. The offset
|
||||
# means we read
|
||||
expected_remapped_matrix = np.concatenate(
|
||||
[
|
||||
np.reshape([18, 34, 50, self.init_val, self.init_val], [5, 1]),
|
||||
np.reshape([16, 32, 48, self.init_val, self.init_val], [5, 1]),
|
||||
np.reshape([self.init_val] * 5, [5, 1]),
|
||||
np.reshape([17, 33, 49, self.init_val, self.init_val], [5, 1]),
|
||||
np.reshape([self.init_val] * 5, [5, 1])
|
||||
],
|
||||
axis=1)
|
||||
|
||||
with self.test_session():
|
||||
self.assertAllClose(expected_remapped_matrix, remapped_matrix.eval())
|
||||
|
||||
def test_load_and_remap_output_layer_weight_initializer_linear(self):
|
||||
"""Tests for the output layer initializer in the linear multi-class case."""
|
||||
loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
|
||||
new_row_vocab_size=5,
|
||||
new_col_vocab_file=self.new_class_vocab_file,
|
||||
old_col_vocab_file=self.old_class_vocab_file,
|
||||
new_col_vocab_size=4,
|
||||
old_tensor_name='some_scope/embeddings',
|
||||
ckpt_path=[self.bundle_file],
|
||||
new_row_vocab_file=self.new_feature_vocab_file,
|
||||
old_row_vocab_file=self.old_feature_vocab_file,
|
||||
num_row_oov_buckets=1,
|
||||
num_col_oov_buckets=1,
|
||||
initializer=self.initializer))
|
||||
|
||||
expected_remapped_matrix = np.concatenate(
|
||||
[
|
||||
np.reshape([2, 18, 34, 50, self.init_val, self.init_val], [6, 1]),
|
||||
np.reshape([0, 16, 32, 48, self.init_val, self.init_val], [6, 1]),
|
||||
np.reshape([self.init_val] * 6, [6, 1]),
|
||||
np.reshape([1, 17, 33, 49, self.init_val, self.init_val], [6, 1]),
|
||||
np.reshape([self.init_val] * 6, [6, 1])
|
||||
],
|
||||
axis=1)
|
||||
|
||||
# The new weight matrix is of size
|
||||
# [5 feature vocab + 1 feature OOV, 4 class vocab + 1 class OOV]. Use a
|
||||
# partitioned variable to confirm that the offset logic works.
|
||||
remapped_matrix = variable_scope.get_variable(
|
||||
name='linear/obtained_weight_matrix',
|
||||
shape=[6, 5],
|
||||
initializer=loading_initializer,
|
||||
partitioner=partitioned_variables.fixed_size_partitioner(2))
|
||||
|
||||
with self.test_session():
|
||||
variables.global_variables_initializer().run()
|
||||
self.assertAllClose(expected_remapped_matrix,
|
||||
remapped_matrix.as_tensor().eval())
|
||||
|
||||
  def test_load_and_remap_output_layer_weight_initializer_dnn_output(self):
    """Tests for the output layer initializer in the DNN output case."""
    loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
        new_row_vocab_size=5,
        new_col_vocab_file=self.new_class_vocab_file,
        old_col_vocab_file=self.old_class_vocab_file,
        new_col_vocab_size=4,
        old_tensor_name='some_scope/embeddings',
        ckpt_path=[self.bundle_file],
        num_col_oov_buckets=1,
        initializer=self.initializer))

    expected_remapped_matrix = np.concatenate(
        [
            np.reshape([2, 18, 34, 50, 66], [5, 1]),
            np.reshape([0, 16, 32, 48, 64], [5, 1]),
            np.reshape([self.init_val] * 5, [5, 1]),
            np.reshape([1, 17, 33, 49, 65], [5, 1]),
            np.reshape([self.init_val] * 5, [5, 1])
        ],
        axis=1)

    # The new weight matrix is of size
    # [5-sized input layer, 4 class vocab + 1 class OOV].
    remapped_matrix = variable_scope.get_variable(
        name='dnn_output/obtained_weight_matrix',
        shape=[5, 5],
        initializer=loading_initializer,
        partitioner=partitioned_variables.fixed_size_partitioner(2))

    with self.test_session():
      variables.global_variables_initializer().run()
      self.assertAllClose(expected_remapped_matrix,
                          remapped_matrix.as_tensor().eval())

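In this DNN-output case only the class columns are remapped; rows pass through untouched because no row vocab files are given. A self-contained NumPy illustration, again assuming the [5, 16] range(80) checkpoint tensor from the fixture:

import numpy as np

old = np.reshape(np.arange(80.0), [5, 16])
col_map = np.array([2, 0, -1, 1, -1])   # 4 remapped classes + 1 OOV column
init_val = 42.0                          # stand-in for self.init_val
# Gather the remapped columns, then overwrite -1 columns with init_val.
out = np.where(col_map >= 0, old[:, col_map], init_val)
print(out[:, 0])  # -> [2, 18, 34, 50, 66], the test's first column
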
  def test_initializer_with_oov_only_partition(self):
    """Tests for the output layer initializer where one partition is all OOV."""
    loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
        new_row_vocab_size=5,
        new_col_vocab_file=self.new_class_vocab_file,
        old_col_vocab_file=self.old_class_vocab_file,
        new_col_vocab_size=4,
        old_tensor_name='some_scope/embeddings',
        ckpt_path=[self.bundle_file],
        new_row_vocab_file=self.new_feature_vocab_file,
        old_row_vocab_file=self.old_feature_vocab_file,
        num_row_oov_buckets=5,
        num_col_oov_buckets=1,
        initializer=self.initializer))

    expected_remapped_matrix = np.concatenate(
        [
            np.reshape([2, 18, 34, 50] + [self.init_val] * 6, [10, 1]),
            np.reshape([0, 16, 32, 48] + [self.init_val] * 6, [10, 1]),
            np.reshape([self.init_val] * 10, [10, 1]),
            np.reshape([1, 17, 33, 49] + [self.init_val] * 6, [10, 1]),
            np.reshape([self.init_val] * 10, [10, 1]),
        ],
        axis=1)

    # The new weight matrix is of size
    # [5 feature vocab + 5 feature OOV, 4 class vocab + 1 class OOV]. The
    # second partition has only OOV.
    remapped_matrix = variable_scope.get_variable(
        name='linear_all_oov/obtained_weight_matrix',
        shape=[10, 5],
        initializer=loading_initializer,
        partitioner=partitioned_variables.fixed_size_partitioner(2))

    with self.test_session():
      variables.global_variables_initializer().run()
      self.assertAllClose(expected_remapped_matrix,
                          remapped_matrix.as_tensor().eval())

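The point of this test is that the second shard contains nothing from the checkpoint. A tiny sketch of the row accounting under the shapes used above:

# Row accounting for the all-OOV-partition case: 5 vocab rows (4 present
# in the old checkpoint) + 5 OOV buckets, split into two [5, 5] shards.
rows_from_ckpt = 4
total_rows, num_shards = 10, 2
per_shard = total_rows // num_shards
for s in range(num_shards):
  rows = list(range(s * per_shard, (s + 1) * per_shard))
  loaded = [r for r in rows if r < rows_from_ckpt]
  print('shard %d rows %s, loaded from checkpoint: %s' % (s, rows, loaded))
# shard 1 loads nothing -- every entry comes from the fallback initializer.
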
  def test_load_and_remap_linear_multiclass_initializer_default_init(self):
    """Tests where the zeros_initializer default is used for linear."""
    loading_initializer = (contrib_framework.load_and_remap_matrix_initializer(
        new_row_vocab_size=5,
        new_col_vocab_file=self.new_class_vocab_file,
        old_col_vocab_file=self.old_class_vocab_file,
        new_col_vocab_size=4,
        old_tensor_name='some_scope/embeddings',
        ckpt_path=[self.bundle_file],
        new_row_vocab_file=self.new_feature_vocab_file,
        old_row_vocab_file=self.old_feature_vocab_file,
        num_row_oov_buckets=1,
        num_col_oov_buckets=1))

    expected_remapped_matrix = np.concatenate(
        [
            np.reshape([2, 18, 34, 50, 0, 0], [6, 1]),
            np.reshape([0, 16, 32, 48, 0, 0], [6, 1]),
            np.reshape([0] * 6, [6, 1]),
            np.reshape([1, 17, 33, 49, 0, 0], [6, 1]),
            np.reshape([0] * 6, [6, 1])
        ],
        axis=1)

    remapped_matrix = variable_scope.get_variable(
        name='linear_init_fallback/obtained_weight_matrix',
        shape=[6, 5],
        initializer=loading_initializer,
        partitioner=partitioned_variables.fixed_size_partitioner(2))

    with self.test_session():
      variables.global_variables_initializer().run()
      self.assertAllClose(expected_remapped_matrix,
                          remapped_matrix.as_tensor().eval())

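With `initializer` omitted, missing entries fall back to zeros; the expectation above can be reproduced with a vectorized NumPy gather (same assumed [5, 16] range(80) fixture):

import numpy as np

old = np.reshape(np.arange(80.0), [5, 16])
row_map = np.array([0, 1, 2, 3, -1, -1])   # 5 vocab rows (1 new) + 1 OOV
col_map = np.array([2, 0, -1, 1, -1])      # 4 classes + 1 OOV
gathered = old[row_map[:, None], col_map[None, :]]
mask = (row_map[:, None] >= 0) & (col_map[None, :] >= 0)
print(np.where(mask, gathered, 0.0))       # matches expected_remapped_matrix
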
  def test_load_embedding_initializer(self):
    """Tests for the load_embedding_initializer wrapper."""
    embedding_loading_initializer = (
        contrib_framework.load_embedding_initializer(
            new_vocab_file=self.new_feature_vocab_file,
            old_vocab_file=self.old_feature_vocab_file,
            new_vocab_size=5,
            embedding_dim=16,
            embedding_tensor_name='some_scope/embeddings',
            ckpt_path=[self.bundle_file],
            num_oov_buckets=1,
            initializer=self.initializer))

    expected_remapped_embeddings = np.concatenate(
        [
            np.reshape(range(64), [4, 16]),
            np.reshape([self.init_val] * 32, [2, 16]),
        ],
        axis=0)

    # The new weight matrix is of size
    # [5 feature vocab + 1 feature OOV, 16 (embedding dimension)], where the
    # last vocab row (2nd last row) is newly initialized (wasn't found in
    # previous vocab) and the actual last row is OOV and also newly
    # initialized. Use a partitioned variable to confirm that the offset
    # logic works.
    remapped_embeddings = variable_scope.get_variable(
        name='embedding/obtained_embedding_matrix',
        shape=[6, 16],
        initializer=embedding_loading_initializer,
        partitioner=partitioned_variables.fixed_size_partitioner(2))

    with self.test_session():
      variables.global_variables_initializer().run()
      self.assertAllClose(expected_remapped_embeddings,
                          remapped_embeddings.as_tensor().eval())


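The wrapper exercised here plausibly reduces to the general row/column remapper with column remapping turned off; a sketch under that assumption, reusing the test file's `contrib_framework` import (the real contrib implementation may differ in details):

def load_embedding_initializer_sketch(new_vocab_file, old_vocab_file,
                                      new_vocab_size, embedding_dim,
                                      embedding_tensor_name, ckpt_path,
                                      num_oov_buckets=0, initializer=None):
  # Embeddings only remap rows; the columns are embedding dimensions, so
  # no column vocab files and no column OOV buckets are passed through.
  return contrib_framework.load_and_remap_matrix_initializer(
      ckpt_path=ckpt_path,
      old_tensor_name=embedding_tensor_name,
      new_row_vocab_size=new_vocab_size,
      new_col_vocab_size=embedding_dim,
      old_row_vocab_file=old_vocab_file,
      new_row_vocab_file=new_vocab_file,
      old_col_vocab_file=None,
      new_col_vocab_file=None,
      num_row_oov_buckets=num_oov_buckets,
      num_col_oov_buckets=0,
      initializer=initializer)
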
class LoadMulticlassBiasTest(test.TestCase):
  """Tests for the load_linear_multiclass_bias_initializer functionality."""

27
tensorflow/contrib/gan/BUILD
Normal file
@ -0,0 +1,27 @@
package(default_visibility = ["//tensorflow:__subpackages__"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

py_library(
    name = "gan",
    srcs = [
        "__init__.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
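The `all_files` filegroup is what the root `tensorflow/BUILD` aggregates for source tracking, and `deps` is intentionally empty until TFGAN code lands. A hypothetical downstream import once it does (the `tfgan` alias is an assumption, not established by this commit):

# Hypothetical usage once the package has content; at this commit
# tensorflow/contrib/gan only ships an empty __init__.py.
from tensorflow.contrib import gan as tfgan  # alias is an assumption
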
4
tensorflow/contrib/gan/README.md
Normal file
@ -0,0 +1,4 @@
This directory contains the TFGAN project.

This file will have more details as code is added.
Some files were not shown because too many files have changed in this diff.