Merge pull request #5248 from zheng-xq/branch_137452948

Branch 137452948
This commit is contained in:
zheng-xq 2016-10-27 21:38:58 -07:00 committed by GitHub
commit 3737ac321e
419 changed files with 17858 additions and 5780 deletions

View File

@ -129,8 +129,6 @@ filegroup(
"//tensorflow/contrib/tensorboard:all_files",
"//tensorflow/contrib/testing:all_files",
"//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
"//tensorflow/contrib/tfprof/tools/tfprof:all_files",
"//tensorflow/contrib/tfprof/tools/tfprof/internal:all_files",
"//tensorflow/contrib/training:all_files",
"//tensorflow/contrib/util:all_files",
"//tensorflow/core:all_files",
@ -188,6 +186,8 @@ filegroup(
"//tensorflow/tools/proto_text:all_files",
"//tensorflow/tools/quantization:all_files",
"//tensorflow/tools/test:all_files",
"//tensorflow/tools/tfprof:all_files",
"//tensorflow/tools/tfprof/internal:all_files",
"//tensorflow/user_ops:all_files",
"//third_party/hadoop:all_files",
],

View File

@ -430,6 +430,7 @@ tf_cc_test(
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",

View File

@ -34,6 +34,7 @@ cc_library(
":constants",
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core/util/tensor_bundle:naming",
@ -63,7 +64,6 @@ tf_cc_test(
filegroup(
name = "saved_model_half_plus_two",
srcs = glob([
"testdata/half_plus_two/**",
"testdata/half_plus_two_pbtxt/**",
"testdata/half_plus_two_sharded/**",
]),

View File

@ -30,6 +30,9 @@ constexpr char kSavedModelFilenamePb[] = "saved_model.pb";
// SavedModel text format proto filename.
constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt";
// SavedModel legacy init op key.
constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op";
// Directory in which to save the SavedModel variables.
constexpr char kSavedModelVariablesDirectory[] = "variables";

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
@ -83,10 +84,32 @@ Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
return (*session)->Create(meta_graph_def.graph_def());
}
Status Restore(const RunOptions& run_options, const string& export_dir,
const StringPiece restore_op_name,
const StringPiece variable_filename_const_op_name,
Session* session) {
// Builds a rank-0 (scalar) DT_STRING tensor holding |value|.
Tensor CreateStringTensor(const string& value) {
  Tensor result(DT_STRING, TensorShape({}));
  result.scalar<string>()() = value;
  return result;
}
void AddAssetsTensorsToInputs(const StringPiece export_dir,
const std::vector<AssetFileDef>& asset_file_defs,
std::vector<std::pair<string, Tensor>>* inputs) {
if (asset_file_defs.empty()) {
return;
}
for (auto& asset_file_def : asset_file_defs) {
Tensor assets_file_path_tensor = CreateStringTensor(io::JoinPath(
export_dir, kSavedModelAssetsDirectory, asset_file_def.filename()));
inputs->push_back(
{asset_file_def.tensor_info().name(), assets_file_path_tensor});
}
}
Status RunRestore(const RunOptions& run_options, const string& export_dir,
const StringPiece restore_op_name,
const StringPiece variable_filename_const_op_name,
const std::vector<AssetFileDef>& asset_file_defs,
Session* session) {
LOG(INFO) << "Restoring SavedModel bundle.";
// Find path to variables to be restored in export directory.
const string variables_directory =
io::JoinPath(export_dir, kSavedModelVariablesDirectory);
@ -109,11 +132,54 @@ Status Restore(const RunOptions& run_options, const string& export_dir,
std::vector<std::pair<string, Tensor>> inputs = {
{variable_filename_const_op_name.ToString(), variables_path_tensor}};
AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
RunMetadata run_metadata;
return session->Run(run_options, inputs, {}, {restore_op_name.ToString()},
nullptr /* outputs */, &run_metadata);
}
// Runs the optional legacy init op recorded in the MetaGraphDef's
// "legacy_init_op" collection, feeding asset path tensors so the op can
// locate asset files under |export_dir|.
// Returns OK when the collection is absent; returns FailedPrecondition when
// the collection does not contain exactly one op name; otherwise returns the
// status of running the op on |session|.
Status RunLegacyInitOp(const RunOptions& run_options, const string& export_dir,
                       const MetaGraphDef& meta_graph_def,
                       const std::vector<AssetFileDef>& asset_file_defs,
                       Session* session) {
  LOG(INFO) << "Running LegacyInitOp on SavedModel bundle.";
  const auto& collection_def_map = meta_graph_def.collection_def();
  const auto init_op_it = collection_def_map.find(kSavedModelLegacyInitOpKey);
  if (init_op_it != collection_def_map.end()) {
    // Exactly one init op is supported; anything else is a malformed export.
    if (init_op_it->second.node_list().value_size() != 1) {
      return errors::FailedPrecondition(strings::StrCat(
          "Expected exactly one serving init op in : ", export_dir));
    }
    std::vector<std::pair<string, Tensor>> inputs;
    AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
    RunMetadata run_metadata;
    const StringPiece legacy_init_op_name =
        init_op_it->second.node_list().value(0);
    return session->Run(run_options, inputs, {},
                        {legacy_init_op_name.ToString()}, nullptr /* outputs */,
                        &run_metadata);
  }
  return Status::OK();
}
// Extracts the AssetFileDef protos stored in the MetaGraphDef's saved-model
// assets collection into |asset_file_defs|. Returns OK (leaving the output
// untouched) when the collection is absent; returns an error if any entry
// fails to parse as tensorflow.AssetFileDef.
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
                        std::vector<AssetFileDef>* asset_file_defs) {
  const auto& collections = meta_graph_def.collection_def();
  const auto it = collections.find(kSavedModelAssetsKey);
  if (it == collections.end()) {
    return Status::OK();
  }
  for (const auto& packed_asset : it->second.any_list().value()) {
    AssetFileDef asset_file_def;
    TF_RETURN_IF_ERROR(
        ParseAny(packed_asset, &asset_file_def, "tensorflow.AssetFileDef"));
    asset_file_defs->push_back(asset_file_def);
  }
  return Status::OK();
}
Status LoadSavedModelInternal(const SessionOptions& session_options,
const RunOptions& run_options,
const string& export_dir,
@ -134,12 +200,19 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession(
bundle->meta_graph_def, session_options, &bundle->session));
std::vector<AssetFileDef> asset_file_defs;
TF_RETURN_IF_ERROR(
Restore(run_options, export_dir,
bundle->meta_graph_def.saver_def().restore_op_name(),
bundle->meta_graph_def.saver_def().filename_tensor_name(),
bundle->session.get()));
GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
TF_RETURN_IF_ERROR(
RunRestore(run_options, export_dir,
bundle->meta_graph_def.saver_def().restore_op_name(),
bundle->meta_graph_def.saver_def().filename_tensor_name(),
asset_file_defs, bundle->session.get()));
// TODO(sukritiramesh): Add support for a single main op to run upon load,
// which will supersede the legacy_init_op and separate RunRestore.
TF_RETURN_IF_ERROR(RunLegacyInitOp(run_options, export_dir,
bundle->meta_graph_def, asset_file_defs,
bundle->session.get()));
return Status::OK();
}

View File

@ -29,7 +29,6 @@ limitations under the License.
namespace tensorflow {
namespace {
constexpr char kTestDataPb[] = "cc/saved_model/testdata/half_plus_two";
constexpr char kTestDataPbTxt[] = "cc/saved_model/testdata/half_plus_two_pbtxt";
constexpr char kTestDataSharded[] =
"cc/saved_model/testdata/half_plus_two_sharded";
@ -45,12 +44,26 @@ class LoaderTest : public ::testing::Test {
return example.SerializeAsString();
}
// Verifies that the SavedModel's asset file ("foo.txt") exists on disk under
// the export directory's assets subdirectory, and that evaluating
// "filename_tensor:0" in the loaded session yields the asset's filename.
void ValidateAssets(const string& export_dir,
                    const SavedModelBundle& bundle) {
  const string asset_directory =
      io::JoinPath(export_dir, kSavedModelAssetsDirectory);
  const string asset_filename = "foo.txt";
  const string asset_filepath = io::JoinPath(asset_directory, asset_filename);
  EXPECT_TRUE(Env::Default()->FileExists(asset_filepath));

  std::vector<Tensor> path_outputs;
  TF_ASSERT_OK(
      bundle.session->Run({}, {"filename_tensor:0"}, {}, &path_outputs));
  // Exactly one output tensor is expected, holding the scalar filename.
  ASSERT_EQ(1, path_outputs.size());
  test::ExpectTensorEqual<string>(
      test::AsTensor<string>({"foo.txt"}, TensorShape({})), path_outputs[0]);
}
void CheckSavedModelBundle(const string& export_dir,
const SavedModelBundle& bundle) {
const string asset_path =
io::JoinPath(export_dir, kSavedModelAssetsDirectory, "foo.txt");
EXPECT_TRUE(Env::Default()->FileExists(asset_path));
ValidateAssets(export_dir, bundle);
// Retrieve the regression signature from meta graph def.
const auto signature_def_map = bundle.meta_graph_def.signature_def();
const auto signature_def = signature_def_map.at(kRegressMethodName);
@ -151,18 +164,6 @@ TEST_F(LoaderTest, PbtxtFormat) {
CheckSavedModelBundle(export_dir, bundle);
}
// Loads the SavedModel from the kTestDataPb export (single-shard variables)
// and validates the resulting bundle end-to-end.
TEST_F(LoaderTest, SingleShardVariables) {
  SavedModelBundle bundle;
  SessionOptions session_options;
  RunOptions run_options;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPb);
  TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
                              {kSavedModelTagServe}, &bundle));
  CheckSavedModelBundle(export_dir, bundle);
}
TEST_F(LoaderTest, InvalidExportPath) {
SavedModelBundle bundle;
RunOptions run_options;

View File

@ -1 +0,0 @@
asset-file-contents

View File

@ -102,6 +102,24 @@ meta_graphs {
type: "type"
}
}
op {
name: "MergeV2Checkpoints"
input_arg {
name: "checkpoint_prefixes"
type: DT_STRING
}
input_arg {
name: "destination_prefix"
type: DT_STRING
}
attr {
name: "delete_old_dirs"
type: "bool"
default_value {
b: true
}
}
}
op {
name: "Mul"
input_arg {
@ -140,6 +158,35 @@ meta_graphs {
op {
name: "NoOp"
}
op {
name: "Pack"
input_arg {
name: "values"
type_attr: "T"
number_attr: "N"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "N"
type: "int"
has_minimum: true
minimum: 1
}
attr {
name: "T"
type: "type"
}
attr {
name: "axis"
type: "int"
default_value {
i: 0
}
}
}
op {
name: "ParseExample"
input_arg {
@ -267,9 +314,9 @@ meta_graphs {
}
}
op {
name: "SaveSlices"
name: "SaveV2"
input_arg {
name: "filename"
name: "prefix"
type: DT_STRING
}
input_arg {
@ -277,15 +324,15 @@ meta_graphs {
type: DT_STRING
}
input_arg {
name: "shapes_and_slices"
name: "shape_and_slices"
type: DT_STRING
}
input_arg {
name: "data"
type_list_attr: "T"
name: "tensors"
type_list_attr: "dtypes"
}
attr {
name: "T"
name: "dtypes"
type: "list(type)"
has_minimum: true
minimum: 1
@ -311,19 +358,29 @@ meta_graphs {
}
}
op {
name: "ShardedFilespec"
name: "StringJoin"
input_arg {
name: "basename"
name: "inputs"
type: DT_STRING
}
input_arg {
name: "num_shards"
type: DT_INT32
number_attr: "N"
}
output_arg {
name: "filename"
name: "output"
type: DT_STRING
}
attr {
name: "N"
type: "int"
has_minimum: true
minimum: 1
}
attr {
name: "separator"
type: "string"
default_value {
s: ""
}
}
}
op {
name: "Variable"
@ -899,6 +956,244 @@ meta_graphs {
}
}
}
node {
name: "Const"
op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
}
string_val: "/tmp/original/export/assets/foo.txt"
}
}
}
}
node {
name: "filename_tensor/initial_value"
op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
}
string_val: "foo.txt"
}
}
}
}
node {
name: "filename_tensor"
op: "Variable"
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
attr {
key: "container"
value {
s: ""
}
}
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "shape"
value {
shape {
}
}
}
attr {
key: "shared_name"
value {
s: ""
}
}
}
node {
name: "filename_tensor/Assign"
op: "Assign"
input: "filename_tensor"
input: "filename_tensor/initial_value"
attr {
key: "T"
value {
type: DT_STRING
}
}
attr {
key: "_class"
value {
list {
s: "loc:@filename_tensor"
}
}
}
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
attr {
key: "use_locking"
value {
b: true
}
}
attr {
key: "validate_shape"
value {
b: true
}
}
}
node {
name: "filename_tensor/read"
op: "Identity"
input: "filename_tensor"
attr {
key: "T"
value {
type: DT_STRING
}
}
attr {
key: "_class"
value {
list {
s: "loc:@filename_tensor"
}
}
}
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
}
node {
name: "Assign/value"
op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
}
string_val: "foo.txt"
}
}
}
}
node {
name: "Assign"
op: "Assign"
input: "filename_tensor"
input: "Assign/value"
attr {
key: "T"
value {
type: DT_STRING
}
}
attr {
key: "_class"
value {
list {
s: "loc:@filename_tensor"
}
}
}
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
attr {
key: "use_locking"
value {
b: false
}
}
attr {
key: "validate_shape"
value {
b: true
}
}
}
node {
name: "Identity"
op: "Identity"
@ -931,6 +1226,11 @@ meta_graphs {
input: "^a/Assign"
input: "^b/Assign"
}
node {
name: "group_deps"
op: "NoOp"
input: "^Assign"
}
node {
name: "save/Const"
op: "Const"
@ -961,6 +1261,63 @@ meta_graphs {
}
}
}
node {
name: "save/StringJoin/inputs_1"
op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
}
string_val: "_temp_ff2bd25218b646ea9ed224eecdce5e79/part"
}
}
}
}
node {
name: "save/StringJoin"
op: "StringJoin"
input: "save/Const"
input: "save/StringJoin/inputs_1"
attr {
key: "N"
value {
i: 2
}
}
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
attr {
key: "separator"
value {
s: ""
}
}
}
node {
name: "save/num_shards"
op: "Const"
@ -1024,7 +1381,7 @@ meta_graphs {
node {
name: "save/ShardedFilename"
op: "ShardedFilename"
input: "save/Const"
input: "save/StringJoin"
input: "save/ShardedFilename/shard"
input: "save/num_shards"
attr {
@ -1038,7 +1395,7 @@ meta_graphs {
}
}
node {
name: "save/save/tensor_names"
name: "save/SaveV2/tensor_names"
op: "Const"
attr {
key: "_output_shapes"
@ -1075,7 +1432,7 @@ meta_graphs {
}
}
node {
name: "save/save/shapes_and_slices"
name: "save/SaveV2/shape_and_slices"
op: "Const"
attr {
key: "_output_shapes"
@ -1112,15 +1469,15 @@ meta_graphs {
}
}
node {
name: "save/save"
op: "SaveSlices"
name: "save/SaveV2"
op: "SaveV2"
input: "save/ShardedFilename"
input: "save/save/tensor_names"
input: "save/save/shapes_and_slices"
input: "save/SaveV2/tensor_names"
input: "save/SaveV2/shape_and_slices"
input: "a"
input: "b"
attr {
key: "T"
key: "dtypes"
value {
list {
type: DT_FLOAT
@ -1133,7 +1490,7 @@ meta_graphs {
name: "save/control_dependency"
op: "Identity"
input: "save/ShardedFilename"
input: "^save/save"
input: "^save/SaveV2"
attr {
key: "T"
value {
@ -1159,11 +1516,65 @@ meta_graphs {
}
}
node {
name: "save/ShardedFilespec"
op: "ShardedFilespec"
input: "save/Const"
input: "save/num_shards"
name: "save/MergeV2Checkpoints/checkpoint_prefixes"
op: "Pack"
input: "save/ShardedFilename"
input: "^save/control_dependency"
attr {
key: "N"
value {
i: 1
}
}
attr {
key: "T"
value {
type: DT_STRING
}
}
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
}
}
}
}
attr {
key: "axis"
value {
i: 0
}
}
}
node {
name: "save/MergeV2Checkpoints"
op: "MergeV2Checkpoints"
input: "save/MergeV2Checkpoints/checkpoint_prefixes"
input: "save/Const"
attr {
key: "delete_old_dirs"
value {
b: true
}
}
}
node {
name: "save/Identity"
op: "Identity"
input: "save/Const"
input: "^save/control_dependency"
input: "^save/MergeV2Checkpoints"
attr {
key: "T"
value {
type: DT_STRING
}
}
attr {
key: "_output_shapes"
value {
@ -1467,12 +1878,39 @@ meta_graphs {
}
saver_def {
filename_tensor_name: "save/Const:0"
save_tensor_name: "save/ShardedFilespec:0"
save_tensor_name: "save/Identity:0"
restore_op_name: "save/restore_all"
max_to_keep: 5
sharded: true
keep_checkpoint_every_n_hours: 10000.0
version: V1
version: V2
}
collection_def {
key: "asset_filepaths"
value {
node_list {
value: "Const:0"
}
}
}
collection_def {
key: "legacy_init_op"
value {
node_list {
value: "group_deps"
}
}
}
collection_def {
key: "saved_model_assets"
value {
any_list {
value {
type_url: "type.googleapis.com/tensorflow.AssetFileDef"
value: "\n\t\n\007Const:0\022\007foo.txt"
}
}
}
}
collection_def {
key: "trainable_variables"

View File

@ -54,7 +54,8 @@ Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
}
// Signals all run loops to stop and waits for the worker threads to exit.
QueueRunner::~QueueRunner() {
  should_stop_ = true;
  // Cannot run Stop() here because the session might already be closed or
  // destroyed.
  Join();
}
@ -72,6 +73,15 @@ Status QueueRunner::Start(Session* sess) {
return Status::OK();
}
// Requests the run loops to stop and, when a cancel op is configured, runs it
// on |sess| to cancel pending enqueues. Returns the cancel op's status, or OK
// when no cancel op was specified.
Status QueueRunner::Stop(Session* sess) {
  should_stop_ = true;
  if (cancel_op_name_.empty()) {
    return Status::OK();
  }
  return sess->Run({}, {}, {cancel_op_name_}, nullptr);
}
Status QueueRunner::Join() {
thread_pool_.reset();
started_ = false;
@ -80,9 +90,8 @@ Status QueueRunner::Join() {
void QueueRunner::Run(Session* sess, const string& enqueue_op) {
bool decremented = false;
while (!should_stop_) {
std::vector<Tensor> outputs;
auto status = sess->Run({}, {}, {enqueue_op}, &outputs);
while (!should_stop_.load()) {
auto status = sess->Run({}, {}, {enqueue_op}, nullptr);
if (status.ok()) {
continue;
} else if (queue_closed_exception_types_.count(
@ -94,19 +103,25 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) {
// If all enqueue ops have finished, run the close op.
if (runs_ == 0 && !close_op_name_.empty()) {
std::vector<Tensor> outputs;
auto s = sess->Run({}, {}, {close_op_name_}, &outputs);
if (!s.ok()) {
status_ = status;
auto s = sess->Run({}, {}, {close_op_name_}, nullptr);
if (!s.ok() && status_.ok() &&
queue_closed_exception_types_.count(static_cast<int>(s.code())) ==
0) {
status_ = s;
}
}
} else {
mutex_lock l(mu_);
should_stop_ = true;
// Only record the first failure status.
if (status_.ok()) {
status_ = status;
{
mutex_lock l(mu_);
should_stop_ = true;
// Only record the first failure status.
if (status_.ok()) {
status_ = status;
}
}
// Stop the queue runner immediately to propagate the error to
// subsequent queues.
Stop(sess);
}
}

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <unordered_set>
#include <vector>
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
@ -49,6 +50,9 @@ class QueueRunner {
// Starts the queue runner with the given session.
Status Start(Session* sess);
// Requests to stop and runs the cancel op.
Status Stop(Session* sess);
// Joins all the threads. Returns okay if all threads run successfully;
// otherwise returns the first captured failure status.
Status Join();
@ -60,14 +64,14 @@ class QueueRunner {
string queue_name_;
std::vector<string> enqueue_op_names_;
string close_op_name_;
// The cancel op is not being called currently.
string cancel_op_name_;
// code::Code casted to int to avoid a hash function.
std::unordered_set<int> queue_closed_exception_types_;
std::unique_ptr<thread::ThreadPool> thread_pool_;
bool should_stop_;
std::atomic<bool> should_stop_;
std::atomic<bool> started_;
condition_variable wait_to_close_;
mutex mu_;
// TODO(yuefengz): implement c++ coordinator.
int runs_ = 0;

View File

@ -14,8 +14,10 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/cc/training/queue_runner.h"
#include <string>
#include <vector>
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"
@ -23,39 +25,42 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
namespace {
using ::tensorflow::DataType;
using ::tensorflow::error::Code;
using ::tensorflow::GraphDef;
using ::tensorflow::ops::Assign;
using ::tensorflow::ops::Const;
using ::tensorflow::ops::CountUpTo;
using ::tensorflow::ops::FIFOQueue;
using ::tensorflow::ops::InputList;
using ::tensorflow::ops::QueueClose;
using ::tensorflow::ops::QueueDequeue;
using ::tensorflow::ops::QueueEnqueue;
using ::tensorflow::ops::Square;
using ::tensorflow::ops::Variable;
using ::tensorflow::QueueRunner;
using ::tensorflow::QueueRunnerDef;
using ::tensorflow::Scope;
using ::tensorflow::Session;
using ::tensorflow::SessionOptions;
using ::tensorflow::Tensor;
using ::tensorflow::TensorShape;
using error::Code;
using ops::Assign;
using ops::Const;
using ops::CountUpTo;
using ops::FIFOQueue;
using ops::QueueClose;
using ops::QueueDequeue;
using ops::QueueEnqueue;
using ops::Square;
using ops::Variable;
constexpr char kAssignOpName[] = "assign";
constexpr char kCancelOp0[] = "cancel0";
constexpr char kCancelOp1[] = "cancel1";
constexpr char kCloseOp0[] = "close0";
constexpr char kCloseOp1[] = "close1";
constexpr char kCountUpToOpName[] = "count";
constexpr char kDequeueOp0[] = "dequeue0";
constexpr char kDequeueOp1[] = "dequeue1";
constexpr char kEnqueueOp0[] = "enqueue0";
constexpr char kEnqueueOp1[] = "enqueue1";
constexpr char kIllegalOpName1[] = "would fail";
constexpr char kIllegalOpName2[] = "fail again";
constexpr char kQueueName[] = "unit_test";
constexpr char kQueueName0[] = "q0";
constexpr char kQueueName1[] = "q1";
constexpr char kSquareOpName[] = "square";
constexpr char kVarOpName[] = "var";
@ -75,7 +80,7 @@ GraphDef BuildSimpleGraph() {
QueueRunnerDef BuildQueueRunnerDef(
const std::string& queue_name, const std::vector<std::string>& enqueue_ops,
const std::string& close_op,
const std::string& close_op, const std::string& cancel_op,
const std::vector<Code>& queue_closed_error_codes) {
QueueRunnerDef queue_runner_def;
*queue_runner_def.mutable_queue_name() = kQueueName;
@ -83,6 +88,7 @@ QueueRunnerDef BuildQueueRunnerDef(
*queue_runner_def.mutable_enqueue_op_name()->Add() = enqueue_op;
}
*queue_runner_def.mutable_close_op_name() = close_op;
*queue_runner_def.mutable_cancel_op_name() = cancel_op;
for (const auto& error_code : queue_closed_error_codes) {
*queue_runner_def.mutable_queue_closed_exception_types()->Add() =
error_code;
@ -96,8 +102,7 @@ std::unique_ptr<Session> BuildSessionAndInitVariable(
std::unique_ptr<Session> session(NewSession(options));
TF_CHECK_OK(session->Create(graph_def));
std::vector<Tensor> nothing;
TF_CHECK_OK(session->Run({}, {}, {kAssignOpName}, &nothing));
TF_CHECK_OK(session->Run({}, {}, {kAssignOpName}, nullptr));
return session;
}
@ -106,7 +111,7 @@ TEST(QueueRunnerTest, BasicTest) {
auto session = BuildSessionAndInitVariable(graph_def);
QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, {});
kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, "", {});
QueueRunner qr(queue_runner_def);
qr.Start(session.get());
@ -123,7 +128,7 @@ TEST(QueueRunnerTest, QueueClosedCode) {
auto session = BuildSessionAndInitVariable(graph_def);
QueueRunnerDef queue_runner_def =
BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kSquareOpName,
BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kSquareOpName, "",
{Code::OUT_OF_RANGE, Code::CANCELLED});
QueueRunner qr(queue_runner_def);
@ -141,60 +146,167 @@ TEST(QueueRunnerDef, CatchErrorInJoin) {
auto session = BuildSessionAndInitVariable(graph_def);
QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, {});
kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, "", {});
QueueRunner qr(queue_runner_def);
qr.Start(session.get());
EXPECT_EQ(qr.Join().code(), Code::NOT_FOUND);
}
TEST(QueueRunnerTest, RealEnqueueDequeue) {
GraphDef BuildDoubleQueueGraph() {
Scope root = Scope::NewRootScope();
auto q0 = FIFOQueue(root.WithOpName("q0"), {DataType::DT_INT32});
auto q0 = FIFOQueue(root.WithOpName(kQueueName0), {DataType::DT_INT32});
auto ten = Const(root, 10);
auto enqueue0 = QueueEnqueue(root.WithOpName("enqueue0"), q0, {ten});
auto close0 = QueueClose(root.WithOpName("close0"), q0);
auto q1 = FIFOQueue(root.WithOpName("q1"), {DataType::DT_INT32});
auto enqueue0 = QueueEnqueue(root.WithOpName(kEnqueueOp0), q0, {ten});
auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0);
auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0,
QueueClose::CancelPendingEnqueues(true));
auto q1 = FIFOQueue(root.WithOpName(kQueueName1), {DataType::DT_INT32});
auto dequeue0 =
QueueDequeue(root.WithOpName("dequeue0"), q0, {DataType::DT_INT32});
auto enqueue1 = QueueEnqueue(root.WithOpName("enqueue1"), q1, {dequeue0[0]});
QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_INT32});
auto enqueue1 = QueueEnqueue(root.WithOpName(kEnqueueOp1), q1, {dequeue0[0]});
auto dequeue1 =
QueueDequeue(root.WithOpName("dequeue1"), q1, {DataType::DT_INT32});
auto close1 = QueueClose(root.WithOpName("close1"), q1);
QueueDequeue(root.WithOpName(kDequeueOp1), q1, {DataType::DT_INT32});
auto close1 = QueueClose(root.WithOpName(kCloseOp1), q1);
auto cancel1 = QueueClose(root.WithOpName(kCancelOp1), q1,
QueueClose::CancelPendingEnqueues(true));
GraphDef graph_def;
TF_EXPECT_OK(root.ToGraphDef(&graph_def));
return graph_def;
}
TEST(QueueRunnerTest, RealEnqueueDequeue) {
auto graph_def = BuildDoubleQueueGraph();
SessionOptions options;
std::unique_ptr<Session> session(NewSession(options));
TF_CHECK_OK(session->Create(graph_def));
QueueRunnerDef queue_runner_def =
BuildQueueRunnerDef(kQueueName, {"enqueue1"}, "close1", {});
BuildQueueRunnerDef(kQueueName, {kEnqueueOp1}, kCloseOp1, "", {});
QueueRunner qr;
qr.Init(queue_runner_def);
TF_CHECK_OK(qr.Start(session.get()));
std::vector<Tensor> outputs;
TF_EXPECT_OK(session->Run({}, {}, {"enqueue0"}, &outputs));
TF_EXPECT_OK(session->Run({}, {}, {"enqueue0"}, &outputs));
TF_EXPECT_OK(session->Run({}, {}, {"close0"}, &outputs));
TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
// Closing queue 0 would also close the queue runner.
TF_EXPECT_OK(session->Run({}, {}, {kCloseOp0}, nullptr));
TF_EXPECT_OK(qr.Join());
std::vector<Tensor> dq1;
TF_EXPECT_OK(session->Run({}, {"dequeue1"}, {}, &dq1));
TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq1));
EXPECT_EQ(*dq1[0].scalar<int>().data(), 10);
std::vector<Tensor> dq2;
TF_EXPECT_OK(session->Run({}, {"dequeue1"}, {}, &dq2));
TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq2));
EXPECT_EQ(*dq2[0].scalar<int>().data(), 10);
EXPECT_EQ(session->Run({}, {"dequeue1"}, {}, &dq1).code(),
EXPECT_EQ(session->Run({}, {kDequeueOp1}, {}, nullptr).code(),
Code::OUT_OF_RANGE);
}
// Helper scheduled on a separate thread: joins |queue_runner| (expected to
// complete with CANCELLED once the session is closed), records success in
// |join_succeeded|, and signals |join_done|.
void JoinThread(QueueRunner* queue_runner, bool* join_succeeded,
                Notification* join_done) {
  EXPECT_EQ(queue_runner->Join().code(), Code::CANCELLED);
  *join_succeeded = true;
  join_done->Notify();
}
// Verifies that Join() blocks while an enqueue is pending and only returns
// (with CANCELLED, checked in JoinThread) after Session::Close() cancels the
// pending enqueue nodes.
TEST(QueueRunnerTest, SessionCloseCancelPendingEnqueue) {
  auto graph_def = BuildDoubleQueueGraph();

  SessionOptions options;
  std::unique_ptr<Session> session(NewSession(options));
  TF_CHECK_OK(session->Create(graph_def));

  QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
      kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {});
  QueueRunner qr;
  qr.Init(queue_runner_def);
  TF_CHECK_OK(qr.Start(session.get()));

  TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
  std::vector<Tensor> dq1;
  TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq1));
  EXPECT_EQ(*dq1[0].scalar<int>().data(), 10);

  // The expected behavior is the QueueRunner::Join() call is blocked until
  // Session::Close() is called.
  bool join_succeeded = false;
  Notification join_done;
  Env::Default()->SchedClosure(
      std::bind(&JoinThread, &qr, &join_succeeded, &join_done));
  // NOTE(review): a fixed 10-second sleep makes this test slow and the
  // "still blocked" assertion timing-sensitive — consider a handshake
  // instead of a sleep; confirm before changing.
  Env::Default()->SleepForMicroseconds(10000000);
  EXPECT_EQ(join_succeeded, false);

  // Closing the session is required to cancel pending enqueue nodes.
  TF_EXPECT_OK(session->Close());
  join_done.WaitForNotification();
  EXPECT_EQ(join_succeeded, true);
}
// Verifies that Stop() cancels the runner's pending work: after Stop(), a
// dequeue from the downstream queue fails with OUT_OF_RANGE, and Join()
// returns OK because the runner already stopped cleanly.
TEST(QueueRunnerTest, Stop) {
  auto graph_def = BuildDoubleQueueGraph();

  SessionOptions options;
  std::unique_ptr<Session> session(NewSession(options));
  TF_CHECK_OK(session->Create(graph_def));

  QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
      kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {});
  QueueRunner qr;
  qr.Init(queue_runner_def);
  TF_CHECK_OK(qr.Start(session.get()));

  TF_EXPECT_OK(qr.Stop(session.get()));

  TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));

  EXPECT_EQ(session->Run({}, {kDequeueOp1}, {}, nullptr).code(),
            Code::OUT_OF_RANGE);

  // qr is already stopped
  TF_EXPECT_OK(qr.Join());
}
// Verifies that two queue runners feeding chained queues in the same session
// can both be stopped and joined cleanly after a value has flowed through
// both queues.
TEST(QueueRunnerTest, StopTwoQueues) {
  auto graph_def = BuildDoubleQueueGraph();

  SessionOptions options;
  std::unique_ptr<Session> session(NewSession(options));
  TF_CHECK_OK(session->Create(graph_def));

  // OUT_OF_RANGE/CANCELLED are registered as queue-closed codes so that
  // cancellation is treated as a normal shutdown rather than an error.
  QueueRunnerDef queue_runner0 =
      BuildQueueRunnerDef(kQueueName0, {kEnqueueOp0}, kCloseOp0, kCancelOp0,
                          {Code::OUT_OF_RANGE, Code::CANCELLED});
  QueueRunnerDef queue_runner1 =
      BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1,
                          {Code::OUT_OF_RANGE, Code::CANCELLED});
  QueueRunner qr0;
  qr0.Init(queue_runner0);
  TF_CHECK_OK(qr0.Start(session.get()));
  QueueRunner qr1;
  qr1.Init(queue_runner1);
  TF_CHECK_OK(qr1.Start(session.get()));

  std::vector<Tensor> dq;
  TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq));
  EXPECT_EQ(*dq[0].scalar<int>().data(), 10);

  TF_EXPECT_OK(qr0.Stop(session.get()));
  TF_EXPECT_OK(qr1.Stop(session.get()));
  TF_EXPECT_OK(qr0.Join());
  TF_EXPECT_OK(qr1.Join());
}
TEST(QueueRunnerTest, EmptyEnqueueOps) {
QueueRunnerDef queue_runner_def =
BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, {});
BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {});
QueueRunner qr;
EXPECT_EQ(qr.Init(queue_runner_def).code(), Code::INVALID_ARGUMENT);
@ -203,8 +315,8 @@ TEST(QueueRunnerTest, EmptyEnqueueOps) {
TEST(QueueRunnerTest, InitAfterStart) {
GraphDef graph_def = BuildSimpleGraph();
auto session = BuildSessionAndInitVariable(graph_def);
QueueRunnerDef queue_runner_def =
BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kCountUpToOpName, {});
QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
kQueueName, {kCountUpToOpName}, kCountUpToOpName, "", {});
QueueRunner qr;
TF_EXPECT_OK(qr.Init(queue_runner_def));
@ -213,3 +325,4 @@ TEST(QueueRunnerTest, InitAfterStart) {
}
} // namespace
} // namespace tensorflow

View File

@ -1,7 +1,7 @@
include (ExternalProject)
set(gemmlowp_URL http://github.com/google/gemmlowp/archive/c0bacf11fb509a2cbe15a97362a2df067ffd57a2.tar.gz)
set(gemmlowp_HASH SHA256=dc64a38f9927db18748d9024987c9b102115e25bc2be4b76aa8e422b8f83d882)
set(gemmlowp_URL http://github.com/google/gemmlowp/archive/a6f29d8ac48d63293f845f2253eccbf86bc28321.tar.gz)
set(gemmlowp_HASH SHA256=75d40ea8e68b0d1644f052fffe8f14a410b2a73d40ccb859a95c0578d194ec26)
set(gemmlowp_BUILD ${CMAKE_BINARY_DIR}/gemmlowp/src/gemmlowp)
set(gemmlowp_INCLUDE_DIR ${CMAKE_BINARY_DIR}/gemmlowp/src/gemmlowp)

View File

@ -0,0 +1,14 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

View File

@ -89,6 +89,7 @@ if(WIN32)
"${tensorflow_source_dir}/tensorflow/core/kernels/fact_op.cc"
"${tensorflow_source_dir}/tensorflow/core/kernels/immutable_constant_op.cc"
"${tensorflow_source_dir}/tensorflow/core/kernels/immutable_constant_op.h"
"${tensorflow_source_dir}/tensorflow/core/kernels/meta_support.*"
"${tensorflow_source_dir}/tensorflow/core/kernels/sparse_matmul_op.cc"
"${tensorflow_source_dir}/tensorflow/core/kernels/sparse_matmul_op.h"
"${tensorflow_source_dir}/tensorflow/core/kernels/*quantiz*.h"

View File

@ -13,7 +13,10 @@ add_executable(${proto_text}
$<TARGET_OBJECTS:tf_core_lib>
)
target_link_libraries(${proto_text} PUBLIC ${tensorflow_EXTERNAL_LIBRARIES})
target_link_libraries(${proto_text} PUBLIC
${tensorflow_EXTERNAL_LIBRARIES}
tf_protos_cc
)
add_dependencies(${proto_text}
tf_core_lib

View File

@ -36,7 +36,7 @@ cuda_py_tests(
cuda_py_tests(
name = "operator_pd_cholesky_test",
size = "small",
size = "medium",
srcs = ["python/kernel_tests/operator_pd_cholesky_test.py"],
additional_deps = [
":distributions_py",
@ -60,7 +60,7 @@ cuda_py_tests(
cuda_py_tests(
name = "operator_pd_full_test",
size = "small",
size = "medium",
srcs = ["python/kernel_tests/operator_pd_full_test.py"],
additional_deps = [
":distributions_py",
@ -72,7 +72,7 @@ cuda_py_tests(
cuda_py_tests(
name = "operator_pd_identity_test",
size = "small",
size = "medium",
srcs = ["python/kernel_tests/operator_pd_identity_test.py"],
additional_deps = [
":distributions_py",

View File

@ -614,6 +614,67 @@ class SigmoidCenteredBijectorTest(tf.test.TestCase):
atol=0., rtol=1e-7)
class CholeskyOuterProductBijectorTest(tf.test.TestCase):
"""Tests the correctness of the Y = X * X^T transformation."""
def testBijectorMatrix(self):
with self.test_session():
bijector = bijectors.CholeskyOuterProduct(event_ndims=2,
validate_args=True)
self.assertEqual("cholesky_outer_product", bijector.name)
x = [[[1., 0],
[2, 1]],
[[math.sqrt(2.), 0],
[math.sqrt(8.), 1]]]
y = np.matmul(x, np.transpose(x, axes=(0, 2, 1)))
# Fairly easy to compute differentials since we have 2x2.
dx_dy = [[[2.*1, 0, 0],
[2, 1, 0],
[0, 2*2, 2*1]],
[[2*math.sqrt(2.), 0, 0],
[math.sqrt(8.), math.sqrt(2.), 0],
[0, 2*math.sqrt(8.), 2*1]]]
ildj = -np.sum(
np.log(np.asarray(dx_dy).diagonal(offset=0, axis1=1, axis2=2)),
axis=1)
self.assertAllEqual((2, 2, 2), bijector.forward(x).get_shape())
self.assertAllEqual((2, 2, 2), bijector.inverse(y).get_shape())
self.assertAllClose(y, bijector.forward(x).eval())
self.assertAllClose(x, bijector.inverse(y).eval())
self.assertAllClose(ildj,
bijector.inverse_log_det_jacobian(y).eval(),
atol=0., rtol=1e-7)
self.assertAllClose(-bijector.inverse_log_det_jacobian(y).eval(),
bijector.forward_log_det_jacobian(x).eval(),
atol=0., rtol=1e-7)
def testBijectorScalar(self):
with self.test_session():
bijector = bijectors.CholeskyOuterProduct(event_ndims=0,
validate_args=True)
self.assertEqual("cholesky_outer_product", bijector.name)
x = [[[1., 5],
[2, 1]],
[[math.sqrt(2.), 3],
[math.sqrt(8.), 1]]]
y = np.square(x)
ildj = -math.log(2.) - np.log(x)
self.assertAllClose(y, bijector.forward(x).eval())
self.assertAllClose(x, bijector.inverse(y).eval())
self.assertAllClose(ildj,
bijector.inverse_log_det_jacobian(y).eval(),
atol=0., rtol=1e-7)
self.assertAllClose(-bijector.inverse_log_det_jacobian(y).eval(),
bijector.forward_log_det_jacobian(x).eval(),
atol=0., rtol=1e-7)
def testScalarCongruency(self):
with self.test_session():
bijector = bijectors.CholeskyOuterProduct(event_ndims=0,
validate_args=True)
assert_scalar_congruency(bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05)
class ChainBijectorTest(tf.test.TestCase):
"""Tests the correctness of the Y = Chain(bij1, bij2, bij3) transformation."""

View File

@ -41,11 +41,34 @@ class DistributionTest(tf.test.TestCase):
for cls in classes:
for sample_shape in sample_shapes:
param_shapes = cls.param_shapes(sample_shape)
print(param_shapes)
params = dict([(name, tf.random_normal(shape))
for name, shape in param_shapes.items()])
dist = cls(**params)
self.assertAllEqual(sample_shape, tf.shape(dist.sample()).eval())
dist_copy = dist.copy()
self.assertAllEqual(sample_shape,
tf.shape(dist_copy.sample()).eval())
self.assertEqual(dist.parameters, dist_copy.parameters)
def testCopyExtraArgs(self):
with self.test_session():
# Note: we cannot easily test all distributions since each requires
# different initialization arguments. We therefore spot test a few.
normal = dists.Normal(mu=1., sigma=2., validate_args=True)
self.assertEqual(normal.parameters, normal.copy().parameters)
wishart = dists.WishartFull(df=2, scale=[[1., 2], [2, 5]],
validate_args=True)
self.assertEqual(wishart.parameters, wishart.copy().parameters)
def testCopyOverride(self):
with self.test_session():
normal = dists.Normal(mu=1., sigma=2., validate_args=True)
normal_copy = normal.copy(validate_args=False)
base_params = normal.parameters.copy()
copy_params = normal.copy(validate_args=False).parameters.copy()
self.assertNotEqual(base_params.pop("validate_args"),
copy_params.pop("validate_args"))
self.assertEqual(base_params, copy_params)
if __name__ == '__main__':

View File

@ -14,7 +14,7 @@
# ==============================================================================
r"""Bijector Ops.
An API for reversible (bijective) transformations of random variables.
An API for invertible, differentiable transformations of random variables.
## Background
@ -31,6 +31,7 @@ To apply a `Bijector`, use `distributions.TransformedDistribution`.
@@Bijector
@@Chain
@@CholeskyOuterProduct
@@Exp
@@Identity
@@Inline
@ -46,7 +47,9 @@ from __future__ import division
from __future__ import print_function
import abc
import collections
import contextlib
import math
import re
import numpy as np
import six
@ -58,18 +61,112 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
__all__ = [
"Bijector",
"Chain",
"CholeskyOuterProduct",
"Exp",
"Identity",
"Inline",
"Invert",
"ScaleAndShift",
"SigmoidCentered",
"SoftmaxCentered",
"Softplus",
]
class _Mapping(collections.namedtuple("_Mapping",
["x", "y", "ildj", "condition_kwargs"])):
"""Helper class to make it easier to manage caching in `Bijector`."""
def __new__(cls, x=None, y=None, ildj=None, condition_kwargs=None):
"""Custom __new__ so namedtuple items have defaults.
Args:
x: `Tensor`. Forward.
y: `Tensor`. Inverse.
ildj: `Tensor`. Inverse log det Jacobian.
condition_kwargs: Python dictionary. Extra args supplied to
forward/inverse/etc functions.
Returns:
mapping: New instance of _Mapping.
"""
return super(_Mapping, cls).__new__(cls, x, y, ildj, condition_kwargs)
@property
def x_key(self):
"""Returns key used for caching Y=g(X)."""
return (self.x,) + self._deep_tuple(tuple(sorted(
self.condition_kwargs.items())))
@property
def y_key(self):
"""Returns key used for caching X=g^{-1}(Y)."""
return (self.y,) + self._deep_tuple(tuple(sorted(
self.condition_kwargs.items())))
def merge(self, x=None, y=None, ildj=None,
condition_kwargs=None, mapping=None):
"""Returns new _Mapping with args merged with self.
Args:
x: `Tensor`. Forward.
y: `Tensor`. Inverse.
ildj: `Tensor`. Inverse log det Jacobian.
condition_kwargs: Python dictionary. Extra args supplied to
forward/inverse/etc functions.
mapping: Instance of _Mapping to merge. Can only be specified if no other
arg is specified.
Returns:
mapping: New instance of `_Mapping` which has inputs merged with self.
Raises:
ValueError: if mapping and any other arg is not `None`.
"""
if mapping is None:
mapping = _Mapping(x=x, y=y, ildj=ildj,
condition_kwargs=condition_kwargs)
elif not all([arg is None for arg in [x, y, ildj, condition_kwargs]]):
raise ValueError("Cannot specify mapping and individual args.")
return _Mapping(
x=self._merge(self.x, mapping.x),
y=self._merge(self.y, mapping.y),
ildj=self._merge(self.ildj, mapping.ildj),
condition_kwargs=self._merge(self.condition_kwargs,
mapping.condition_kwargs))
def _merge(self, old, new):
"""Helper to merge which handles merging one value."""
if old is None:
return new
elif new is not None and old != new:
raise ValueError("Incompatible values: %s != %s" % (old, new))
return old
def _deep_tuple(self, x):
"""Converts lists of lists to tuples of tuples."""
return (tuple(map(self._deep_tuple, x))
if isinstance(x, (list, tuple)) else x)
@six.add_metaclass(abc.ABCMeta)
class Bijector(object):
"""Interface for transforming a `Distribution` via `TransformedDistribution`.
"""Interface for transforming a `Distribution` sample.
A `Bijector` implements a bijective, differentiable function by transforming
an input `Tensor`. The output `Tensor` shape is constrained by the input
`sample`, `batch`, and `event` shape. A `Bijector` is characterized by three
A `Bijector` implements a
[diffeomorphism](https://en.wikipedia.org/wiki/Diffeomorphism), i.e., a
bijective, differentiable function. A `Bijector` is used by
`TransformedDistribution` but can be generally used for transforming a
`Distribution` generated `Tensor`. A `Bijector` is characterized by three
operations:
1. Forward Evaluation
@ -210,7 +307,8 @@ class Bijector(object):
- The inverse `log o det o Jacobian` can be implemented as the negative of the
forward `log o det o Jacobian`. This is useful if the `inverse` is
implemented as a cache or the inverse Jacobian is computationally more
expensive. The following demonstrates the suggested implementation.
expensive (e.g., `CholeskyOuterProduct` `Bijector`). The following
demonstrates the suggested implementation.
```python
def _inverse_and_log_det_jacobian(self, y):
@ -300,6 +398,11 @@ class Bijector(object):
self._is_constant_jacobian = is_constant_jacobian
self._validate_args = validate_args
self._dtype = dtype
self._from_y = {}
self._from_x = {}
# Using abbreviation ildj for "inverse log det Jacobian."
# This variable is not `None` iff is_constant_jacobian is `True`.
self._constant_ildj = None
if name:
self._name = name
else:
@ -368,7 +471,12 @@ class Bijector(object):
with self._name_scope(name, [x]):
x = ops.convert_to_tensor(x, name="x")
self._maybe_assert_dtype(x)
return self._forward(x, **condition_kwargs)
mapping = self._lookup(x=x, condition_kwargs=condition_kwargs)
if mapping.y is not None:
return mapping.y
mapping = mapping.merge(y=self._forward(x, **condition_kwargs))
self._cache(mapping)
return mapping.y
def _inverse(self, y):
raise NotImplementedError("inverse is not implemented")
@ -393,16 +501,28 @@ class Bijector(object):
with self._name_scope(name, [y]):
y = ops.convert_to_tensor(y, name="y")
self._maybe_assert_dtype(y)
mapping = self._lookup(y=y, condition_kwargs=condition_kwargs)
if mapping.x is not None:
return mapping.x
ildj = None
try:
return self._inverse(y, **condition_kwargs)
x = self._inverse(y, **condition_kwargs)
except NotImplementedError as original_error:
# Since _inverse was not implemented, try to see if it's implemented
# by the _inverse_and_inverse_log_det_jacobian member.
try:
return self._inverse_and_inverse_log_det_jacobian(
y, **condition_kwargs)[0]
x, ildj = self._inverse_and_inverse_log_det_jacobian(
y, **condition_kwargs)
if self._constant_ildj is not None:
ildj = self._constant_ildj # Use the "global" result.
elif self.is_constant_jacobian:
self._constant_ildj = ildj
except NotImplementedError:
raise original_error
x = x if mapping.x is None else mapping.x
mapping = mapping.merge(x=x, ildj=ildj)
self._cache(mapping)
return mapping.x
def _inverse_log_det_jacobian(self, y):
raise NotImplementedError("inverse_log_det_jacobian is not implemented.")
@ -430,18 +550,32 @@ class Bijector(object):
`_inverse_and_inverse_log_det_jacobian` are implemented.
"""
with self._name_scope(name, [y]):
if self._constant_ildj is not None:
return self._constant_ildj
y = ops.convert_to_tensor(y, name="y")
self._maybe_assert_dtype(y)
mapping = self._lookup(y=y, condition_kwargs=condition_kwargs)
if mapping.ildj is not None:
return mapping.ildj
try:
return self._inverse_log_det_jacobian(y, **condition_kwargs)
x = mapping.x
ildj = self._inverse_log_det_jacobian(y, **condition_kwargs)
except NotImplementedError as original_error:
# Since _inverse_log_det_jacobian was not implemented, try to see if
# it's implemented by the _inverse_and_inverse_log_det_jacobian member.
try:
return self._inverse_and_inverse_log_det_jacobian(
y, **condition_kwargs)[1]
x, ildj = self._inverse_and_inverse_log_det_jacobian(
y, **condition_kwargs)
if mapping.x is not None:
x = mapping.x
except NotImplementedError:
raise original_error
if self.is_constant_jacobian:
self._constant_ildj = ildj
x = x if mapping.x is None else mapping.x
mapping = mapping.merge(x=x, ildj=ildj)
self._cache(mapping)
return mapping.ildj
def _inverse_and_inverse_log_det_jacobian(self, y):
raise NotImplementedError(
@ -473,18 +607,30 @@ class Bijector(object):
with self._name_scope(name, [y]):
y = ops.convert_to_tensor(y, name="y")
self._maybe_assert_dtype(y)
mapping = self._lookup(y=y, condition_kwargs=condition_kwargs)
if mapping.x is not None and mapping.ildj is not None:
return mapping.x, mapping.ildj
try:
return self._inverse_and_inverse_log_det_jacobian(
x, ildj = self._inverse_and_inverse_log_det_jacobian(
y, **condition_kwargs)
except NotImplementedError as original_error:
# Since _inverse_and_inverse_log_det_jacobian was not implemented, try
# to see if we can separately use _inverse and
# _inverse_log_det_jacobian members.
try:
return (self._inverse(y, **condition_kwargs),
self._inverse_log_det_jacobian(y, **condition_kwargs))
x = self._inverse(y, **condition_kwargs)
if self._constant_ildj is None:
ildj = self._inverse_log_det_jacobian(y, **condition_kwargs)
except NotImplementedError:
raise original_error
if self._constant_ildj is not None:
ildj = self._constant_ildj # Ignore any ildj we may/not have.
elif self.is_constant_jacobian:
self._constant_ildj = ildj
x = x if mapping.x is None else mapping.x
mapping = mapping.merge(x=x, ildj=ildj)
self._cache(mapping)
return mapping.x, mapping.ildj
def _forward_log_det_jacobian(self, x):
raise NotImplementedError(
@ -509,16 +655,29 @@ class Bijector(object):
nor {`_inverse`, `_inverse_log_det_jacobian`} are implemented.
"""
with self._name_scope(name, [x]):
if self._constant_ildj is not None:
# Need "-1. *" to avoid invalid-unary-operand-type linter warning.
return -1. * self._constant_ildj
x = ops.convert_to_tensor(x, name="x")
self._maybe_assert_dtype(x)
mapping = self._lookup(x=x, condition_kwargs=condition_kwargs)
if mapping.ildj is not None:
return -mapping.ildj
y = None
try:
return self._forward_log_det_jacobian(x, **condition_kwargs)
ildj = -self._forward_log_det_jacobian(x, **condition_kwargs)
except NotImplementedError as original_error:
try:
y = self.inverse(x, **condition_kwargs)
return -self.inverse_log_det_jacobian(y, **condition_kwargs)
y = self.inverse(x, **condition_kwargs) if y is None else y
ildj = self.inverse_log_det_jacobian(y, **condition_kwargs)
except NotImplementedError:
raise original_error
if self.is_constant_jacobian:
self._constant_ildj = ildj
y = y if mapping.y is None else mapping.y
mapping = mapping.merge(y=y, ildj=ildj)
self._cache(mapping)
return -mapping.ildj
@contextlib.contextmanager
def _name_scope(self, name=None, values=None):
@ -534,6 +693,31 @@ class Bijector(object):
raise TypeError("Input had dtype %s but expected %s." %
(self.dtype, x.dtype))
def _cache(self, mapping):
"""Helper which stores mapping info in forward/inverse dicts."""
if self._constant_ildj is not None:
# Fold in ildj if known constant Jacobian.
mapping = mapping.merge(ildj=self._constant_ildj)
# Merging from lookup is an added check that we're not overwriting anything
# which is not None.
mapping = mapping.merge(mapping=self._lookup(
mapping.x, mapping.y, mapping.condition_kwargs))
if mapping.x is None or mapping.y is None:
ValueError("Caching expects both (x,y) to be known, i.e., not None.")
self._from_x[mapping.x_key] = mapping
self._from_y[mapping.y_key] = mapping
def _lookup(self, x=None, y=None, condition_kwargs=None):
"""Helper which retrieves mapping info from forward/inverse dicts."""
mapping = _Mapping(x=x, y=y, condition_kwargs=condition_kwargs)
# Since _cache requires both x,y to be set, we only need to do one cache
# lookup since the mapping is always in both or neither.
if mapping.x is not None:
return self._from_x.get(mapping.x_key, mapping)
if mapping.y is not None:
return self._from_y.get(mapping.y_key, mapping)
return mapping
class Inline(Bijector):
# pylint: disable=line-too-long
@ -547,7 +731,7 @@ class Inline(Bijector):
inverse_fn=tf.log,
inverse_log_det_jacobian_fn=(
lambda y: -tf.reduce_sum(tf.log(y), reduction_indices=-1)),
name="Exp")
name="exp")
```
The above example is equivalent to the `Bijector` `Exp(event_ndims=1)`.
@ -573,8 +757,8 @@ class Inline(Bijector):
log o det o jacobian of the forward transformation.
is_constant_jacobian: `Boolean` indicating that the Jacobian is constant
for all input arguments.
validate_args: `Boolean` indicated whether arguments should be checked for
correctness.
validate_args: `Boolean` indicating whether arguments should be checked
for correctness.
name: `String`, name given to ops managed by this object.
"""
super(Inline, self).__init__(
@ -643,8 +827,8 @@ class Invert(Bijector):
Args:
bijector: Bijector instance.
validate_args: `Boolean` indicated whether arguments should be checked for
correctness.
validate_args: `Boolean` indicating whether arguments should be checked
for correctness.
name: `String`, name given to ops managed by this object.
"""
@ -713,8 +897,8 @@ class Chain(Bijector):
Args:
bijectors: Python list of bijector instances. An empty list makes this
bijector equivalent to the `Identity` bijector.
validate_args: `Boolean` indicated whether arguments should be checked for
correctness.
validate_args: `Boolean` indicating whether arguments should be checked
for correctness.
name: `String`, name given to ops managed by this object. Default: E.g.,
`Chain([Exp(), Softplus()]).name == "chain_of_exp_of_softplus"`.
@ -794,12 +978,9 @@ class Identity(Bijector):
def __init__(self, validate_args=False, name="identity"):
super(Identity, self).__init__(
batch_ndims=0,
event_ndims=0,
is_constant_jacobian=True,
validate_args=validate_args,
name=name)
self._is_constant_jacobian = True
def _forward(self, x):
return x
@ -841,8 +1022,8 @@ class Exp(Bijector):
Args:
event_ndims: Scalar `int32` `Tensor` indicating the number of dimensions
associated with a particular draw from the distribution.
validate_args: `Boolean` indicated whether arguments should be checked for
correctness.
validate_args: `Boolean` indicating whether arguments should be checked
for correctness.
name: `String` name given to ops managed by this object.
"""
@ -923,8 +1104,8 @@ class ScaleAndShift(Bijector):
scale: `Tensor` used to scale input, i.e., `Y = g(X) = scale * X + shift`.
event_ndims: Scalar `int32` `Tensor` indicating the number of dimensions
associated with a particular draw from the distribution.
validate_args: `Boolean` indicated whether arguments should be checked for
correctness.
validate_args: `Boolean` indicating whether arguments should be checked
for correctness.
name: `String` name given to ops managed by this object.
"""
@ -1271,3 +1452,150 @@ class SigmoidCentered(SoftmaxCentered):
def __init__(self, validate_args=False, name="sigmoid_centered"):
super(SigmoidCentered, self).__init__(
validate_args=validate_args, name=name)
class CholeskyOuterProduct(Bijector):
# pylint: disable=line-too-long
"""Bijector which computes Y = g(X) = X X^T where X is a lower-triangular, positive-diagonal matrix.
`event_ndims` must be 0 or 2, i.e., scalar or matrix.
Note: the upper-triangular part of X is ignored (whether or not its zero).
Examples:
```python
bijector.CholeskyOuterProduct(event_ndims=2).forward(x=[[1., 0], [2, 1]])
# Result: [[1, 1], [1, 5]], i.e., x x^T
bijector.SoftmaxCentered(event_ndims=2).inverse(y=[[1., 1], [1, 5]])
# Result: [[1, 0], [2, 1]], i.e., chol(y).
```
"""
# pylint: enable=line-too-long
def __init__(self, event_ndims=2, validate_args=False,
name="cholesky_outer_product"):
"""Instantiates the `CholeskyOuterProduct` bijector.
Args:
event_ndims: `constant` `int32` scalar `Tensor` indicating the number of
dimensions associated with a particular draw from the distribution. Must
be 0 or 2.
validate_args: `Boolean` indicating whether arguments should be checked
for correctness.
name: `String` name given to ops managed by this object.
Raises:
ValueError: if event_ndims is neither 0 or 2.
"""
self._parameters = {}
self._name = name
with self._name_scope("init", values=[event_ndims]):
event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
event_ndims = tensor_util.constant_value(event_ndims)
if event_ndims is None or event_ndims not in [0, 2]:
raise ValueError("`event_ndims` must be a TF constant which is 0 or 2")
self._static_event_ndims = event_ndims
super(CholeskyOuterProduct, self).__init__(
validate_args=validate_args,
name=name)
def _forward(self, x):
if self._static_event_ndims == 0:
return math_ops.square(x)
if self.validate_args:
is_matrix = check_ops.assert_rank_at_least(x, 2)
shape = array_ops.shape(x)
is_square = check_ops.assert_equal(shape[-2], shape[-1])
x = control_flow_ops.with_dependencies([is_matrix, is_square], x)
# For safety, explicitly zero-out the upper triangular part.
x = array_ops.matrix_band_part(x, -1, 0)
return math_ops.batch_matmul(x, x, adj_y=True)
def _inverse_and_inverse_log_det_jacobian(self, y):
x = (math_ops.sqrt(y) if self._static_event_ndims == 0
else linalg_ops.cholesky(y))
return x, -self._forward_log_det_jacobian(x)
def _forward_log_det_jacobian(self, x):
# Let Y be a symmetric, positive definite matrix and write:
# Y = X X^T
# where X is lower-triangular.
#
# Observe that,
# dY[i,j]/dX[a,b]
# = d/dX[a,b] { X[i,:] X[j,:] }
# = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] }
#
# To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is
# symmetric and X is lower-triangular, we need vectors of dimension:
# d = p (p + 1) / 2
# where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e.,
# k = { i (i + 1) / 2 + j i>=j
# { undef i<j
# and assume zero-based indexes. When k is undef, the element is dropped.
# Example:
# j k
# 0 1 2 3 /
# 0 [ 0 . . . ]
# i 1 [ 1 2 . . ]
# 2 [ 3 4 5 . ]
# 3 [ 6 7 8 9 ]
# Write vec[.] to indicate transforming a matrix to vector via k(i,j). (With
# slight abuse: k(i,j)=undef means the element is dropped.)
#
# We now show d vec[Y] / d vec[X] is lower triangular. Assuming both are
# defined, observe that k(i,j) < k(a,b) iff (1) i<a or (2) i=a and j<b.
# In both cases dvec[Y]/dvec[X]@[k(i,j),k(a,b)] = 0 since:
# (1) j<=i<a thus i,j!=a.
# (2) i=a>j thus i,j!=a.
#
# Since the Jacobian is lower-triangular, we need only compute the product
# of diagonal elements:
# d vec[Y] / d vec[X] @[k(i,j), k(i,j)]
# = X[j,j] + I[i=j] X[i,j]
# = 2 X[j,j].
# Since there is a 2 X[j,j] term for every lower-triangular element of X we
# conclude:
# |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.
if self._static_event_ndims == 0:
if self.validate_args:
is_positive = check_ops.assert_positive(
x, message="All elements must be positive.")
x = control_flow_ops.with_dependencies([is_positive], x)
return math.log(2.) + math_ops.log(x)
diag = array_ops.matrix_diag_part(x)
if self.validate_args:
is_matrix = check_ops.assert_rank_at_least(
x, 2, message="Input must be a (batch of) matrix.")
shape = array_ops.shape(x)
is_square = check_ops.assert_equal(
shape[-2], shape[-1],
message="Input must be a (batch of) square matrix.")
# Assuming lower-triangular means we only need check diag>0.
is_positive_definite = check_ops.assert_positive(
diag, message="Input must be positive definite.")
x = control_flow_ops.with_dependencies(
[is_matrix, is_square, is_positive_definite], x)
# Create a column vector equal to: [p, p-1, ..., 2, 1]^T.
if x.get_shape().ndims is None or x.get_shape()[-1].value is None:
p = array_ops.shape(x)[-1]
else:
p = x.get_shape()[-1].value
exponents = array_ops.expand_dims(
math_ops.linspace(math_ops.cast(p, dtype=x.dtype), 1., p),
dim=1)
sum_weighted_log_diag = array_ops.squeeze(
math_ops.batch_matmul(math_ops.log(diag), exponents),
squeeze_dims=-1)
fldj = p * math.log(2.) + sum_weighted_log_diag
if x.get_shape().ndims is not None:
fldj.set_shape(x.get_shape()[:-2])
return fldj

View File

@ -327,12 +327,13 @@ class Distribution(_BaseDistribution):
for i, t in enumerate(graph_parents):
if t is None or not contrib_framework.is_tensor(t):
raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
parameters = parameters or {}
self._dtype = dtype
self._is_continuous = is_continuous
self._is_reparameterized = is_reparameterized
self._allow_nan_stats = allow_nan_stats
self._validate_args = validate_args
self._parameters = parameters or {}
self._parameters = parameters
self._graph_parents = graph_parents
self._name = name or type(self).__name__
@ -434,6 +435,27 @@ class Distribution(_BaseDistribution):
"""Python boolean indicated possibly expensive checks are enabled."""
return self._validate_args
def copy(self, **override_parameters_kwargs):
"""Creates a deep copy of the distribution.
Note: the copy distribution may continue to depend on the original
intialization arguments.
Args:
**override_parameters_kwargs: String/value dictionary of initialization
arguments to override with new values.
Returns:
distribution: A new instance of `type(self)` intitialized from the union
of self.parameters and override_parameters_kwargs, i.e.,
`dict(self.parameters, **override_parameters_kwargs)`.
"""
parameters = dict(self.parameters, **override_parameters_kwargs)
# Python3 leaks "__class__" into `locals()` so we remove if present.
# TODO(b/32376812): Remove this pop.
parameters.pop("__class__", None)
return type(self)(**parameters)
def _batch_shape(self):
raise NotImplementedError("batch_shape is not implemented")

View File

@ -19,7 +19,6 @@ from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import distribution as distributions
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
@ -160,7 +159,6 @@ class TransformedDistribution(distributions.Distribution):
name = name or bijector.name + distribution.name
self._distribution = distribution
self._bijector = bijector
self._inverse_cache = {}
super(TransformedDistribution, self).__init__(
dtype=self._distribution.dtype,
is_continuous=self._distribution.is_continuous,
@ -202,9 +200,7 @@ class TransformedDistribution(distributions.Distribution):
**distribution_kwargs)
# Recall that a bijector is named for its forward transform, i.e.,
# `Y = g(X)`,
y = self.bijector.forward(x, **bijector_kwargs)
self._inverse_cache[y] = x
return y
return self.bijector.forward(x, **bijector_kwargs)
@distribution_util.AppendDocstring(
"""Implements `(log o p o g^{-1})(y) + (log o det o J o g^{-1})(y)`,
@ -216,11 +212,9 @@ class TransformedDistribution(distributions.Distribution):
def _log_prob(self, y, bijector_kwargs=None, distribution_kwargs=None):
bijector_kwargs = bijector_kwargs or {}
distribution_kwargs = distribution_kwargs or {}
x = self._inverse_possibly_from_cache(y, bijector_kwargs)
inverse_log_det_jacobian = self.bijector.inverse_log_det_jacobian(
x, ildj = self.bijector.inverse_and_inverse_log_det_jacobian(
y, **bijector_kwargs)
return (self.distribution.log_prob(x, **distribution_kwargs) +
inverse_log_det_jacobian)
return ildj + self.distribution.log_prob(x, **distribution_kwargs)
@distribution_util.AppendDocstring(
"""Implements `p(g^{-1}(y)) det|J(g^{-1}(y))|`, where `g^{-1}` is the
@ -232,18 +226,16 @@ class TransformedDistribution(distributions.Distribution):
def _prob(self, y, bijector_kwargs=None, distribution_kwargs=None):
bijector_kwargs = bijector_kwargs or {}
distribution_kwargs = distribution_kwargs or {}
x = self._inverse_possibly_from_cache(y, bijector_kwargs)
inverse_det_jacobian = math_ops.exp(self.bijector.inverse_log_det_jacobian(
y, **bijector_kwargs))
return (self.distribution.prob(x, **distribution_kwargs) *
inverse_det_jacobian)
x, ildj = self.bijector.inverse_and_inverse_log_det_jacobian(
y, **bijector_kwargs)
return math_ops.exp(ildj) * self.distribution.prob(x, **distribution_kwargs)
@distribution_util.AppendDocstring(
condition_kwargs_dict=_condition_kwargs_dict)
def _log_cdf(self, y, bijector_kwargs=None, distribution_kwargs=None):
bijector_kwargs = bijector_kwargs or {}
distribution_kwargs = distribution_kwargs or {}
x = self._inverse_possibly_from_cache(y, bijector_kwargs)
x = self.bijector.inverse(y, **bijector_kwargs)
return self.distribution.log_cdf(x, distribution_kwargs)
@distribution_util.AppendDocstring(
@ -251,7 +243,7 @@ class TransformedDistribution(distributions.Distribution):
def _cdf(self, y, bijector_kwargs=None, distribution_kwargs=None):
bijector_kwargs = bijector_kwargs or {}
distribution_kwargs = distribution_kwargs or {}
x = self._inverse_possibly_from_cache(y, bijector_kwargs)
x = self.bijector.inverse(y, **bijector_kwargs)
return self.distribution.cdf(x, **distribution_kwargs)
@distribution_util.AppendDocstring(
@ -260,7 +252,7 @@ class TransformedDistribution(distributions.Distribution):
bijector_kwargs=None, distribution_kwargs=None):
bijector_kwargs = bijector_kwargs or {}
distribution_kwargs = distribution_kwargs or {}
x = self._inverse_possibly_from_cache(y, bijector_kwargs)
x = self.bijector.inverse(y, **bijector_kwargs)
return self.distribution.log_survival_function(x, **distribution_kwargs)
@distribution_util.AppendDocstring(
@ -269,13 +261,5 @@ class TransformedDistribution(distributions.Distribution):
bijector_kwargs=None, distribution_kwargs=None):
bijector_kwargs = bijector_kwargs or {}
distribution_kwargs = distribution_kwargs or {}
x = self._inverse_possibly_from_cache(y, bijector_kwargs)
x = self.bijector.inverse(y, **bijector_kwargs)
return self.distribution.survival_function(x, **distribution_kwargs)
def _inverse_possibly_from_cache(self, y, bijector_kwargs):
"""Return `self._inverse(y)`, possibly using cached value."""
y = ops.convert_to_tensor(y, name="y")
if y in self._inverse_cache:
return self._inverse_cache[y]
else:
return self.bijector.inverse(y, **bijector_kwargs)

View File

@ -327,6 +327,6 @@ if __name__ == '__main__':
default=True,
help='Use fake input data.'
)
FLAGS = parser.parse_args()
FLAGS, unparsed = parser.parse_known_args()
tf.test.main()

View File

@ -243,6 +243,7 @@ class KMeansClustering(estimator.Estimator,
).training_graph()
incr_step = tf.assign_add(tf.contrib.framework.get_global_step(), 1)
self._loss = tf.reduce_sum(losses)
tf.scalar_summary('loss/raw', self._loss)
training_op = with_dependencies([training_op, incr_step], self._loss)
return training_op, self._loss

View File

@ -24,16 +24,20 @@ from tensorflow.contrib import framework as contrib_framework
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as vars_
from tensorflow.python.training import moving_averages
from tensorflow.python.training import optimizer as optimizer_
from tensorflow.python.training import training as train
OPTIMIZER_CLS_NAMES = {
"Adagrad": train.AdagradOptimizer,
"Adam": train.AdamOptimizer,
@ -104,7 +108,11 @@ def optimize_loss(loss,
gradient_multipliers: dict of variables or variable names to floats.
If present, gradients for specified
variables will be multiplied by given constant.
clip_gradients: float or `None`, clips gradients by this value.
clip_gradients: float, callable or `None`. If float, is provided, a global
clipping is applied to prevent the norm of the gradient to exceed this
value. Alternatively, a callable can be provided e.g.: adaptive_clipping.
This callable takes a `list` of `(gradients, variables)` `tuple`s and
returns the same thing with the gradients modified.
learning_rate_decay_fn: function, takes `learning_rate` and `global_step`
`Tensor`s, returns `Tensor`.
Can be used to implement any learning rate decay
@ -132,6 +140,7 @@ def optimize_loss(loss,
* `global_step` is an invalid type or shape.
* `learning_rate` is an invalid type or value.
* `optimizer` is wrong type.
* `clip_gradients' is not float or callable.
* `learning_rate` and `learning_rate_decay_fn` are supplied, but no
`global_step` is available.
"""
@ -224,9 +233,18 @@ def optimize_loss(loss,
if gradient_multipliers is not None:
gradients = _multiply_gradients(gradients, gradient_multipliers)
if "gradient_norm" in summaries:
logging_ops.scalar_summary("global_norm/gradient_norm",
clip_ops.global_norm(zip(*gradients)[0]))
# Optionally clip gradients by global norm.
if clip_gradients is not None:
if isinstance(clip_gradients, float):
gradients = _clip_gradients_by_norm(gradients, clip_gradients)
elif callable(clip_gradients):
gradients = clip_gradients(gradients)
elif clip_gradients is not None:
raise ValueError(
"Unknown type %s for clip_gradients" % type(clip_gradients))
# Add scalar summary for loss.
if "loss" in summaries:
@ -241,11 +259,15 @@ def optimize_loss(loss,
if grad_values is not None:
if "gradients" in summaries:
logging_ops.histogram_summary(variable.name + "/gradients",
logging_ops.histogram_summary("gradients/" + variable.name,
grad_values)
if "gradient_norm" in summaries:
logging_ops.histogram_summary(variable.name + "/gradient_norm",
clip_ops.global_norm([grad_values]))
logging_ops.scalar_summary("gradient_norm/" + variable.name,
clip_ops.global_norm([grad_values]))
if clip_gradients is not None and "gradient_norm" in summaries:
logging_ops.scalar_summary("global_norm/clipped_gradient_norm",
clip_ops.global_norm(zip(*gradients)[0]))
# Create gradient updates.
grad_updates = opt.apply_gradients(gradients,
@ -266,6 +288,101 @@ def _clip_gradients_by_norm(grads_and_vars, clip_gradients):
return list(zip(clipped_gradients, variables))
def _adaptive_max_norm(norm, std_factor, decay, global_step, epsilon, name):
  """Find max_norm given norm and previous average.

  Maintains exponential moving averages of the mean and of the squared mean
  of `log(norm + epsilon)`, and derives an adaptive clipping threshold
  `exp(mean + std_factor * std)` from them.

  Args:
    norm: Scalar `Tensor` with the current global gradient norm (the caller
      passes `clip_ops.global_norm(grads)`).
    std_factor: Multiplier on the standard deviation of `log(norm)` used to
      set the threshold.
    decay: Smoothing factor of the moving averages.
    global_step: Optional step counter. If given, the effective decay is
      `min(decay, n / (n + 1))` so the averages adapt faster early on.
    epsilon: Small constant avoiding `log(0)` and zero variance.
    name: Optional scope name for the created variables and ops.

  Returns:
    A `(max_norms, mean)` tuple: the adaptive threshold and the moving
    average of `log(norm)`.
  """
  with vs.variable_scope(name, "AdaptiveMaxNorm", [norm]):
    log_norm = math_ops.log(norm + epsilon)

    def moving_average(name, value, decay):
      # Exponential moving average of `value`, tracked in a non-trainable
      # variable initialized to zero.
      moving_average_variable = vs.get_variable(
          name, shape=value.get_shape(), dtype=value.dtype,
          initializer=init_ops.zeros_initializer, trainable=False)
      return moving_averages.assign_moving_average(
          moving_average_variable, value, decay)

    # quicker adaptation at the beginning
    if global_step is not None:
      n = math_ops.to_float(global_step)
      decay = math_ops.minimum(decay, n / (n + 1.))

    # update averages
    mean = moving_average("mean", log_norm, decay)
    sq_mean = moving_average("sq_mean", math_ops.square(log_norm), decay)

    variance = sq_mean - math_ops.square(mean)
    std = math_ops.sqrt(math_ops.maximum(epsilon, variance))
    max_norms = math_ops.exp(mean + std_factor*std)
    return max_norms, mean
def adaptive_clipping_fn(std_factor=2.,
                         decay=0.95,
                         static_max_norm=None,
                         global_step=None,
                         report_summary=False,
                         epsilon=1e-8,
                         name=None):
  """Adapt the clipping value using statistics on the norms.

  Implement adaptive gradient clipping as presented in section 3.2.1 of
  https://arxiv.org/abs/1412.1602.

  Keeps a moving average of the mean and std of the log(norm) of the gradient.
  if the norm exceeds `exp(mean + std_factor*std)`, all gradients are rescaled
  such that the global norm becomes `exp(mean)`.

  Args:
    std_factor: Python scalar (or tensor).
      `max_norm = exp(mean + std_factor*std)`
    decay: The smoothing factor of the moving averages.
    static_max_norm: If provided, will threshold the norm to this value as an
      extra safety.
    global_step: Optional global_step. If provided, `decay = decay*n/(n+1)`.
      This provides a quicker adaptation of the mean for the first steps.
    report_summary: If `True`, will add histogram summaries of the `max_norm`.
    epsilon: Small value chosen to avoid zero variance.
    name: The name for this operation is used to scope operations and summaries.

  Returns:
    A function for applying gradient clipping.
  """
  def gradient_clipping(grads_and_vars):
    """Internal function for adaptive clipping.

    Args:
      grads_and_vars: A `list` of `(gradient, variable)` tuples.

    Returns:
      The same pairs with every gradient rescaled by the adaptive factor.
    """
    grads, variables = zip(*grads_and_vars)
    norm = clip_ops.global_norm(grads)
    max_norm, log_mean = _adaptive_max_norm(
        norm, std_factor, decay, global_step, epsilon, name)
    # reports the max gradient norm for debugging
    if report_summary:
      logging_ops.scalar_summary(
          "global_norm/adaptive_max_gradient_norm", max_norm)
    # factor will be 1. if norm is smaller than max_norm
    factor = math_ops.select(norm < max_norm,
                             array_ops.ones_like(norm),
                             math_ops.exp(log_mean) / norm)
    if static_max_norm is not None:
      # Extra safety: never let the clipped norm exceed static_max_norm.
      factor = math_ops.minimum(static_max_norm / norm, factor)
    # apply factor
    clipped_grads = []
    for grad in grads:
      if grad is None:
        # Preserve missing gradients untouched.
        clipped_grads.append(None)
      elif isinstance(grad, ops.IndexedSlices):
        # Sparse gradient: scale only the values, keep indices/shape intact.
        clipped_grads.append(ops.IndexedSlices(
            grad.values * factor, grad.indices, grad.dense_shape))
      else:
        clipped_grads.append(grad * factor)
    return list(zip(clipped_grads, variables))
  return gradient_clipping
def _add_scaled_noise_to_gradients(grads_and_vars, gradient_noise_scale):
"""Adds scaled noise from a 0-mean normal distribution to gradients."""
gradients, variables = zip(*grads_and_vars)

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
@ -179,6 +180,26 @@ class OptimizersTest(tf.test.TestCase):
self.assertAlmostEqual(var_value, 9.98999, 4)
self.assertEqual(global_step_value, 1)
def testAdaptiveGradientClip(self):
  """optimize_loss accepts a callable clipper and creates its state vars."""
  with self.test_session() as session:
    x, var, loss, global_step = _setup_model()
    clip_gradients = tf.contrib.layers.adaptive_clipping_fn()
    train = tf.contrib.layers.optimize_loss(loss,
                                            global_step,
                                            learning_rate=0.1,
                                            optimizer="SGD",
                                            clip_gradients=clip_gradients)
    tf.initialize_all_variables().run()
    session.run(train, feed_dict={x: 5})
    var_value, global_step_value = session.run([var, global_step])
    self.assertAlmostEqual(var_value, 9.8916, 4)
    self.assertEqual(global_step_value, 1)
    # The adaptive clipper should have created exactly two state variables
    # (the moving mean and moving sq_mean) under the optimize_loss scope.
    var_count = 0
    for var in tf.all_variables():
      if var.name.startswith("OptimizeLoss/AdaptiveMaxNorm"):
        var_count += 1
    self.assertEqual(2, var_count)
def testGradientMultiply(self):
with self.test_session() as session:
x, var, loss, global_step = _setup_model()
@ -332,5 +353,70 @@ class OptimizersTest(tf.test.TestCase):
self.assertEqual(update_var_value, 20)
self.assertEqual(global_step_value, 1)
class AdaptiveClipping(tf.test.TestCase):
  """Tests for `tf.contrib.layers.adaptive_clipping_fn`."""

  def testAverages(self):
    """Moving mean/sq_mean of log(norm) converge to the constant true norm."""
    with self.test_session() as session:
      scale = 2.
      grad = tf.ones([3, 4]) * scale
      # log of the global norm of a constant gradient of value `scale`.
      log_norm = np.log(np.sqrt(scale**2 * grad.get_shape().num_elements()))
      grads_and_vars = [(grad, grad)]
      grads_and_vars = tf.contrib.layers.adaptive_clipping_fn(
          decay=0.5)(grads_and_vars)

      # Collect the two moving-average state variables by name.
      var_dict = {}
      for var in tf.all_variables():
        if var.name.startswith("AdaptiveMaxNorm"):
          var_dict[var.name.split(":")[0]] = var
      self.assertEqual(2, len(var_dict))
      moving_mean = var_dict["AdaptiveMaxNorm/mean"]
      moving_sq_mean = var_dict["AdaptiveMaxNorm/sq_mean"]
      tf.initialize_all_variables().run()
      mean, sq_mean = session.run([moving_mean, moving_sq_mean])
      # Both averages start at zero.
      self.assertEqual([0], mean)
      self.assertEqual([0], sq_mean)
      for i in range(20):
        mean, sq_mean, _ = session.run(
            [moving_mean, moving_sq_mean, grads_and_vars[0][0]])
        if i == 0:
          # After a single update the averages are still far from log_norm.
          self.assertLess(mean, 0.9 * log_norm)
          self.assertLess(sq_mean, 0.9 * log_norm**2)

      # After 20 updates with decay 0.5 the averages have converged.
      self.assertAlmostEqual(float(mean), log_norm, places=4)
      self.assertAlmostEqual(float(sq_mean), log_norm**2, places=4)

  def testClip(self):
    """A spike is clipped at first, then accepted as the averages adapt."""
    with self.test_session() as session:
      spike = 1000.
      multiplier = tf.placeholder(tf.float32, [], "multiplier")
      step = tf.placeholder(tf.int32, [], "step")

      grad = tf.ones([3, 4]) * multiplier
      grads_and_vars = [(grad, grad)]
      grads_and_vars = tf.contrib.layers.adaptive_clipping_fn(
          decay=0.9, global_step=step)(grads_and_vars)

      tf.initialize_all_variables().run()

      def run(scale, i):
        # Evaluate the clipped gradient for a given input scale and step.
        return session.run(grads_and_vars[0][0],
                           feed_dict={multiplier: scale, step: i})

      for i in range(20):
        scale = [1., -2.][i % 2]
        clipped_grad = run(scale, i)
        if i > 3:
          # Normal-sized gradients pass through unclipped once warmed up.
          self.assertAllClose(np.ones(clipped_grad.shape)*scale, clipped_grad)

      # assert that the spike will have low influence.
      clipped_grad = run(spike, 20)
      self.assertTrue((clipped_grad < 25.).all())

      # assert that a repeated spike will converge to this new value.
      for i in range(10):
        clipped_grad = run(spike, i + 21)
      self.assertAllClose(np.ones(clipped_grad.shape)*spike, clipped_grad)
if __name__ == "__main__":
tf.test.main()

View File

@ -35,6 +35,6 @@ from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassi
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestEstimator
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestLossMonitor
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestLossHook
from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig
from tensorflow.contrib.learn.python.learn.estimators.svm import SVM

View File

@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.session_bundle import exporter
@ -27,6 +28,8 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
@deprecated('2016-11-30', 'Please write an appropriate function for use with'
' your estimator.')
def classification_signature_fn(examples, unused_features, predictions):
"""Creates classification signature from given examples and predictions.
@ -61,6 +64,7 @@ class Classifier(estimator.Estimator):
CLASS_OUTPUT = 'classes'
PROBABILITY_OUTPUT = 'probabilities'
@deprecated('2016-11-30', 'Please use Estimator directly.')
def __init__(self, model_fn, n_classes, model_dir=None, config=None,
params=None, feature_engineering_fn=None):
"""Constructor for Classifier.

View File

@ -309,7 +309,7 @@ class _DynamicRNNEstimator(estimator.BaseEstimator):
inputs=rnn_outputs,
num_outputs=self._target_column.num_label_columns,
activation_fn=None,
trainable=False)
trainable=True)
return activations, final_state
@abc.abstractmethod

View File

@ -429,7 +429,7 @@ class SingleValueRNNEstimatorTest(tf.test.TestCase):
cell_type = 'basic_rnn'
cell_size = 8
optimizer_type = 'Momentum'
learning_rate = 0.5
learning_rate = 0.1
momentum = 0.9
loss_threshold = 0.1

View File

@ -36,6 +36,7 @@ from tensorflow.contrib import layers
from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.framework import get_graph_from_inputs
from tensorflow.contrib.framework import list_variables
from tensorflow.contrib.framework import load_variable
from tensorflow.contrib.learn.python.learn import evaluable
@ -88,8 +89,11 @@ class ModelFnOps(
collections.namedtuple('ModelFnOps', ['predictions', 'loss', 'training_op',
'default_metrics', 'signature_fn'])):
def __new__(cls, predictions, loss, training_op, default_metrics,
signature_fn, mode):
def __new__(cls, mode, predictions=None, loss=None, training_op=None,
default_metrics=None, signature_fn=None):
# Assert all ops are from the same graph.
get_graph_from_inputs((predictions, loss, training_op))
# Validate training_op.
if training_op is None:
if mode == ModeKeys.TRAIN:
@ -1042,13 +1046,16 @@ class Estimator(BaseEstimator):
if isinstance(model_fn_results, ModelFnOps):
return model_fn_results
else:
# Here model_fn_ops should be a tuple with 3 elements.
if len(model_fn_results) != 3:
raise ValueError('Unrecognized value returned by model_fn, '
'please return ModelFnOps.')
return ModelFnOps(model_fn_results[0], model_fn_results[1],
model_fn_results[2], None, None, mode)
# Here model_fn_ops should be a tuple with 3 elements.
if len(model_fn_results) != 3:
raise ValueError('Unrecognized value returned by model_fn, '
'please return ModelFnOps.')
return ModelFnOps(
mode=mode,
predictions=model_fn_results[0],
loss=model_fn_results[1],
training_op=model_fn_results[2])
def _get_train_ops(self, features, targets):
"""Method that builds model graph and returns trainer ops.

View File

@ -229,20 +229,30 @@ class _Head(object):
else:
train_op = control_flow_ops.group(*additional_train_op)
return estimator.ModelFnOps(None, loss, train_op,
self._default_metric(),
self._create_signature_fn(), mode)
return estimator.ModelFnOps(
mode=estimator.ModeKeys.TRAIN,
loss=loss,
training_op=train_op,
default_metrics=self._default_metric(),
signature_fn=self._create_signature_fn())
if mode == estimator.ModeKeys.INFER:
predictions = self._infer_op(logits, logits_input)
return estimator.ModelFnOps(predictions, None, None,
self._default_metric(),
self._create_signature_fn(), mode)
return estimator.ModelFnOps(
mode=estimator.ModeKeys.INFER,
predictions=self._infer_op(logits, logits_input),
default_metrics=self._default_metric(),
signature_fn=self._create_signature_fn())
if mode == estimator.ModeKeys.EVAL:
predictions, loss = self._eval_op(features, target, logits, logits_input)
return estimator.ModelFnOps(predictions, loss, None,
self._default_metric(),
self._create_signature_fn(), mode)
raise ValueError("mode=%s unrecognized" % str(mode))
return estimator.ModelFnOps(
mode=estimator.ModeKeys.EVAL,
predictions=predictions,
loss=loss,
default_metrics=self._default_metric(),
signature_fn=self._create_signature_fn())
raise ValueError("mode=%s unrecognized." % str(mode))
@abc.abstractmethod
def _training_loss(self, features, target, logits=None, logits_input=None,

View File

@ -17,25 +17,28 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import six
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.learn.python.learn import monitors as mon
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.utils import export
from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.data import data_ops
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import session_run_hook
KEYS_NAME = 'keys'
LOSS_NAME = 'rf_training_loss'
def _assert_float32(tensors):
@ -56,58 +59,124 @@ def _assert_float32(tensors):
raise TypeError('Expected dtype=float32, %s.' % tensor)
class TensorForestLossMonitor(mon.EveryN):
"""Terminates training when training loss stops decreasing."""
class TensorForestLossHook(session_run_hook.SessionRunHook):
"""Monitor to request stop when loss stops decreasing."""
def __init__(self,
early_stopping_rounds,
every_n_steps):
super(TensorForestLossMonitor, self).__init__(every_n_steps=every_n_steps)
def __init__(self, early_stopping_rounds):
self.early_stopping_rounds = early_stopping_rounds
self.min_loss = None
self.min_loss_step = 0
self.last_step = -1
# self.steps records the number of steps for which the loss has been
# non-decreasing
self.steps = 0
def step_begin(self, step):
super(TensorForestLossMonitor, self).step_begin(step)
return [self._loss_op_name]
def before_run(self, run_context):
return session_run_hook.SessionRunArgs(
{'global_step': contrib_framework.get_global_step(),
'current_loss': run_context.session.graph.get_operation_by_name(
LOSS_NAME).outputs[0]})
def set_estimator(self, est):
"""This function gets called in the same graph as _get_train_ops."""
super(TensorForestLossMonitor, self).set_estimator(est)
self._loss_op_name = est.training_loss.name
def after_run(self, run_context, run_values):
current_loss = run_values.results['current_loss']
current_step = run_values.results['global_step']
self.steps += 1
# Guard against the global step going backwards, which might happen
# if we recover from something.
if self.last_step == -1 or self.last_step > current_step:
logging.info('TensorForestLossHook resetting last_step.')
self.last_step = current_step
self.steps = 0
return
def every_n_step_end(self, step, outputs):
super(TensorForestLossMonitor, self).every_n_step_end(step, outputs)
current_loss = outputs[self._loss_op_name]
if self.min_loss is None or current_loss < self.min_loss:
self.min_loss = current_loss
self.min_loss_step = step
return step - self.min_loss_step >= self.early_stopping_rounds
self.steps = 0
if self.steps > self.early_stopping_rounds:
logging.info('TensorForestLossHook requesting stop.')
run_context.request_stop()
class TensorForestEstimator(estimator.BaseEstimator):
def get_model_fn(params, graph_builder_class, device_assigner,
                 weights_name=None, keys_name=None):
  """Return a model function given a way to construct a graph builder.

  Args:
    params: Forest hyperparameters, forwarded to `graph_builder_class`.
    graph_builder_class: Callable that builds the forest graph, e.g.
      `tensor_forest.RandomForestGraphs`.
    device_assigner: Device assigner handed to the graph builder.
    weights_name: Optional feature key holding per-example weights; when
      present, the column is popped from `features` before parsing.
    keys_name: Optional feature key holding example keys; when present, the
      column is popped and echoed back in the inference dict.

  Returns:
    A `(features, targets) -> (predictions, loss, train_op)` model function
    suitable for `estimator.Estimator`.
  """
  def _model_fn(features, targets):
    """Function that returns predictions, training loss, and training op."""
    weights = None
    keys = None
    # Pull weight/key columns out so they are not treated as input features.
    if weights_name and weights_name in features:
      weights = features.pop(weights_name)
    if keys_name and keys_name in features:
      keys = features.pop(keys_name)
    processed_features, spec = data_ops.ParseDataTensorOrDict(features)
    _assert_float32(processed_features)
    if targets is not None:
      targets = data_ops.ParseLabelTensorOrDict(targets)
      _assert_float32(targets)
    graph_builder = graph_builder_class(params, device_assigner=device_assigner)
    inference = {eval_metrics.INFERENCE_PROB_NAME:
                 graph_builder.inference_graph(processed_features,
                                               data_spec=spec)}
    if not params.regression:
      # Classification additionally exposes the argmax class prediction.
      inference[eval_metrics.INFERENCE_PRED_NAME] = math_ops.argmax(
          inference[eval_metrics.INFERENCE_PROB_NAME], 1)
    # NOTE(review): when set, `keys` is presumably a Tensor popped from
    # `features`, so this truthiness test relies on Tensor.__bool__;
    # `if keys is not None:` looks intended -- confirm.
    if keys:
      inference[KEYS_NAME] = keys
    # targets might be None if we're doing prediction (which brings up the
    # question of why we force everything to adhere to a single model_fn).
    training_loss = None
    training_graph = None
    if targets is not None:
      training_loss = graph_builder.training_loss(processed_features, targets,
                                                  data_spec=spec,
                                                  name=LOSS_NAME)
      # Group the forest update with a global-step increment so every run of
      # the train op also advances the step counter.
      training_graph = control_flow_ops.group(
          graph_builder.training_graph(
              processed_features, targets, data_spec=spec,
              input_weights=weights),
          state_ops.assign_add(contrib_framework.get_global_step(), 1))
    # Put weights back in
    if weights is not None:
      features[weights_name] = weights
    return (inference, training_loss, training_graph)
  return _model_fn
class TensorForestEstimator(evaluable.Evaluable, trainable.Trainable):
"""An estimator that can train and evaluate a random forest."""
def __init__(self, params, device_assigner=None, model_dir=None,
graph_builder_class=tensor_forest.RandomForestGraphs,
master='', accuracy_metric=None,
tf_random_seed=None, config=None,
feature_engineering_fn=None):
config=None, weights_name=None, keys_name=None,
feature_engineering_fn=None, early_stopping_rounds=100):
self.params = params.fill()
self.accuracy_metric = (accuracy_metric or
('r2' if self.params.regression else 'accuracy'))
self.data_feeder = None
self.device_assigner = (
device_assigner or tensor_forest.RandomForestDeviceAssigner())
self.graph_builder_class = graph_builder_class
self.training_args = {}
self.construction_args = {}
self._feature_engineering_fn = (
feature_engineering_fn or
(lambda features, targets: (features, targets)))
self.early_stopping_rounds = early_stopping_rounds
self._estimator = estimator.Estimator(
model_fn=get_model_fn(params, graph_builder_class, device_assigner,
weights_name=weights_name, keys_name=keys_name),
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
super(TensorForestEstimator, self).__init__(model_dir=model_dir,
config=config)
def evaluate(
self, x=None, y=None, input_fn=None, feed_fn=None, batch_size=None,
steps=None, metrics=None, name=None):
"""See evaluable.Evaluable."""
return self._estimator.evaluate(
input_fn=input_fn, x=x, y=y, feed_fn=feed_fn,
batch_size=batch_size, steps=steps,
metrics=metrics, name=name)
def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
monitors=None, max_steps=None):
"""See trainable.Trainable."""
if not monitors:
monitors = [TensorForestLossHook(self.early_stopping_rounds)]
self._estimator.fit(input_fn=input_fn, x=x, y=y,
batch_size=batch_size, steps=steps, monitors=monitors,
max_steps=max_steps)
@deprecated_arg_values(
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
@ -135,13 +204,14 @@ class TensorForestEstimator(estimator.BaseEstimator):
Raises:
ValueError: If both or neither of x and input_fn were given.
"""
results = super(TensorForestEstimator, self).predict(
results = self._estimator.predict(
x=x, input_fn=input_fn, batch_size=batch_size, outputs=outputs,
as_iterable=as_iterable)
if as_iterable:
return (r['probabilities'] for r in results)
return (x[eval_metrics.INFERENCE_PROB_NAME] for x in results)
else:
return results['probabilities']
return results[eval_metrics.INFERENCE_PROB_NAME]
@deprecated_arg_values(
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
@ -168,16 +238,16 @@ class TensorForestEstimator(estimator.BaseEstimator):
Numpy array of predicted classes or regression values (or an iterable of
predictions if as_iterable is True).
"""
probabilities = self.predict_proba(
results = self._estimator.predict(
x=x, input_fn=input_fn, batch_size=batch_size, outputs=outputs,
as_iterable=as_iterable)
if self.params.regression:
return probabilities
predict_name = (eval_metrics.INFERENCE_PROB_NAME if self.params.regression
else eval_metrics.INFERENCE_PRED_NAME)
if as_iterable:
return (x[predict_name] for x in results)
else:
if as_iterable:
return (np.argmax(p, axis=0) for p in probabilities)
else:
return np.argmax(probabilities, axis=1)
return results[predict_name]
@deprecated_arg_values(
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
@ -186,100 +256,40 @@ class TensorForestEstimator(estimator.BaseEstimator):
self, x=None, input_fn=None, axis=None, batch_size=None, outputs=None,
as_iterable=True):
"""Same as predict but also returns the example keys."""
results = super(TensorForestEstimator, self).predict(
results = self._estimator.predict(
x=x, input_fn=input_fn, batch_size=batch_size, outputs=outputs,
as_iterable=as_iterable)
if self.params.regression:
if as_iterable:
return ((r['probabilities'], r.get('keys', None)) for r in results)
else:
return results['probabilities'], results.get('keys', None)
predict_name = (eval_metrics.INFERENCE_PROB_NAME if self.params.regression
else eval_metrics.INFERENCE_PRED_NAME)
if as_iterable:
return ((x[predict_name], x.get(KEYS_NAME, None)) for x in results)
else:
if as_iterable:
return ((np.argmax(r['probabilities'], axis=0),
r.get('keys', None)) for r in results)
else:
return np.argmax(results['probabilities'], axis=1), results.get('keys',
None)
def _get_train_ops(self, features, targets):
"""Method that builds model graph and returns trainer ops.
Args:
features: `Tensor` or `dict` of `Tensor` objects.
targets: `Tensor` or `dict` of `Tensor` objects.
Returns:
Tuple of train `Operation` and loss `Tensor`.
"""
features, _, weights, spec = data_ops.ParseDataTensorOrDict(features)
labels = data_ops.ParseLabelTensorOrDict(targets)
features, labels = self._feature_engineering_fn(features, labels)
_assert_float32(features)
_assert_float32(labels)
if weights is not None:
if 'input_weights' in self.training_args:
logging.warning('Replacing input_weights in training_args.')
self.training_args['input_weights'] = weights
graph_builder = self.graph_builder_class(
self.params, device_assigner=self.device_assigner,
**self.construction_args)
epoch = None
if self.data_feeder:
epoch = self.data_feeder.make_epoch_variable()
train = control_flow_ops.group(
graph_builder.training_graph(
features, labels, data_spec=spec, epoch=epoch,
**self.training_args),
state_ops.assign_add(contrib_framework.get_global_step(), 1))
self.training_loss = graph_builder.training_loss(features, targets)
return train, self.training_loss
def _get_predict_ops(self, features):
graph_builder = self.graph_builder_class(
self.params, device_assigner=self.device_assigner, training=False,
**self.construction_args)
features, keys, _, spec = data_ops.ParseDataTensorOrDict(features)
features, _ = self._feature_engineering_fn(features, None)
_assert_float32(features)
output_dict = {
'probabilities': graph_builder.inference_graph(features,
data_spec=spec)}
if keys is not None:
output_dict['keys'] = keys
return output_dict
def _get_eval_ops(self, features, targets, metrics):
features, _, _, spec = data_ops.ParseDataTensorOrDict(features)
labels = data_ops.ParseLabelTensorOrDict(targets)
features, labels = self._feature_engineering_fn(features, labels)
_assert_float32(features)
_assert_float32(labels)
graph_builder = self.graph_builder_class(
self.params, device_assigner=self.device_assigner, training=False,
**self.construction_args)
probabilities = graph_builder.inference_graph(features, data_spec=spec)
# One-hot the labels.
if not self.params.regression:
labels = math_ops.to_int64(array_ops.one_hot(math_ops.to_int64(
array_ops.squeeze(labels)), self.params.num_classes, 1, 0))
if metrics is None:
metrics = {self.accuracy_metric:
eval_metrics.get_metric(self.accuracy_metric)}
result = {}
for name, metric in six.iteritems(metrics):
result[name] = metric(probabilities, labels)
return results[predict_name], results.get(KEYS_NAME, None)
def export(self,
export_dir,
input_fn,
signature_fn=None,
default_batch_size=1):
"""See BaseEstimator.export."""
# Reset model function with basic device assigner.
# Servo doesn't support distributed inference
# but it will try to respect device assignments if they're there.
# pylint: disable=protected-access
orig_model_fn = self._estimator._model_fn
self._estimator._model_fn = get_model_fn(
self.params, self.graph_builder_class,
tensor_forest.RandomForestDeviceAssigner())
result = self._estimator.export(
export_dir=export_dir,
use_deprecated_input_fn=True,
signature_fn=(signature_fn or
(export.regression_signature_fn
if self.params.regression else
export.classification_signature_fn_with_prob)),
default_batch_size=default_batch_size,
prediction_key=eval_metrics.INFERENCE_PROB_NAME)
self._estimator._model_fn = orig_model_fn
# pylint: enable=protected-access
return result

View File

@ -28,14 +28,30 @@ class TensorForestTrainerTests(tf.test.TestCase):
def testClassification(self):
"""Tests multi-class classification using matrix data as input."""
hparams = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
num_trees=3, max_nodes=1000, num_classes=3, num_features=4)
classifier = tf.contrib.learn.TensorForestEstimator(hparams)
num_trees=3, max_nodes=1000, num_classes=3, num_features=4,
split_after_samples=20)
classifier = tf.contrib.learn.TensorForestEstimator(hparams.fill())
iris = tf.contrib.learn.datasets.load_iris()
data = iris.data.astype(np.float32)
target = iris.target.astype(np.float32)
monitors = [tf.contrib.learn.TensorForestLossMonitor(10, 10)]
classifier.fit(x=data, y=target, steps=100, batch_size=50)
classifier.evaluate(x=data, y=target, steps=10)
def testClassificationTrainingLoss(self):
"""Tests multi-class classification using matrix data as input."""
hparams = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
num_trees=3, max_nodes=1000, num_classes=3, num_features=4)
classifier = tf.contrib.learn.TensorForestEstimator(
hparams, graph_builder_class=(
tf.contrib.tensor_forest.python.tensor_forest.TrainingLossForest))
iris = tf.contrib.learn.datasets.load_iris()
data = iris.data.astype(np.float32)
target = iris.target.astype(np.float32)
monitors = [tf.contrib.learn.TensorForestLossHook(10)]
classifier.fit(x=data, y=target, steps=100, monitors=monitors)
classifier.evaluate(x=data, y=target, steps=10)
@ -44,16 +60,15 @@ class TensorForestTrainerTests(tf.test.TestCase):
hparams = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
num_trees=3, max_nodes=1000, num_classes=1, num_features=13,
regression=True)
regression=True, split_after_samples=20)
regressor = tf.contrib.learn.TensorForestEstimator(hparams)
regressor = tf.contrib.learn.TensorForestEstimator(hparams.fill())
boston = tf.contrib.learn.datasets.load_boston()
data = boston.data.astype(np.float32)
target = boston.target.astype(np.float32)
monitors = [tf.contrib.learn.TensorForestLossMonitor(10, 10)]
regressor.fit(x=data, y=target, steps=100, monitors=monitors)
regressor.fit(x=data, y=target, steps=100, batch_size=50)
regressor.evaluate(x=data, y=target, steps=10)

View File

@ -627,7 +627,7 @@ def _eval_results_to_str(eval_results):
def _write_summary_results(output_dir, eval_results, current_global_step):
"""Writes eval results into summary file in given dir."""
logging.info('Saving evaluation summary for %d step: %s', current_global_step,
logging.info('Saving evaluation summary for step %d: %s', current_global_step,
_eval_results_to_str(eval_results))
summary_writer = get_summary_writer(output_dir)
summary = summary_pb2.Summary()

View File

@ -253,6 +253,18 @@ def _get_shared_file_name_queue(file_names, shuffle, num_epochs, name):
def _get_file_names(file_pattern, randomize_input):
"""Parse list of file names from pattern, optionally shuffled.
Args:
file_pattern: File glob pattern, or list of strings.
randomize_input: Whether to shuffle the order of file names.
Returns:
List of file names matching `file_pattern`.
Raises:
ValueError: If `file_pattern` is empty, or pattern matches no files.
"""
if isinstance(file_pattern, list):
file_names = file_pattern
if not file_names:
@ -304,6 +316,36 @@ def _read_keyed_batch_examples_helper(file_pattern,
parse_fn=None,
setup_shared_queue=False,
name=None):
"""Adds operations to read, queue, batch `Example` protos.
Args:
file_pattern: List of files or pattern of file paths containing
`Example` records. See `tf.gfile.Glob` for pattern rules.
batch_size: An int or scalar `Tensor` specifying the batch size to use.
reader: A function or class that returns an object with
`read` method, (filename tensor) -> (example tensor).
randomize_input: Whether the input should be randomized.
num_epochs: Integer specifying the number of times to read through the
dataset. If `None`, cycles through the dataset forever.
NOTE - If specified, creates a variable that must be initialized, so call
`tf.initialize_all_variables()` as shown in the tests.
queue_capacity: Capacity for input queue.
num_threads: The number of threads enqueuing examples.
read_batch_size: An int or scalar `Tensor` specifying the number of
records to read at once
parse_fn: Parsing function, takes `Example` Tensor returns parsed
representation. If `None`, no parsing is done.
setup_shared_queue: Whether to set up a shared queue for file names.
name: Name of resulting op.
Returns:
Returns tuple of:
- `Tensor` of string keys.
- String `Tensor` of batched `Example` proto.
Raises:
ValueError: for invalid inputs.
"""
# Retrieve files to read.
file_names = _get_file_names(file_pattern, randomize_input)
@ -348,10 +390,10 @@ def _read_keyed_batch_examples_helper(file_pattern,
enqueue_many = read_batch_size > 1
if num_epochs is not None:
allow_smaller_final_batch = True
else:
if num_epochs is None:
allow_smaller_final_batch = False
else:
allow_smaller_final_batch = True
# Setup batching queue given list of read example tensors.
if randomize_input:
@ -505,7 +547,6 @@ def _read_keyed_batch_features_shared_queue(file_pattern,
Adding multiple queue runners for the parsed example queue helps maintain
a full queue when the subsequent computations overall are cheaper than
parsing.
parser_num_threads: (Deprecated) The number of threads to parse examples.
parse_fn: Parsing function, takes `Example` Tensor returns parsed
representation. If `None`, no parsing is done.
name: Name of resulting op.

View File

@ -121,7 +121,8 @@ class GraphIOTest(tf.test.TestCase):
batch_size = 17
queue_capacity = 1234
name = "my_batch"
features = {"feature": tf.FixedLenFeature(shape=[0], dtype=tf.float32)}
shape = (0,)
features = {"feature": tf.FixedLenFeature(shape=shape, dtype=tf.float32)}
with tf.Graph().as_default() as g, self.test_session(graph=g) as sess:
features = tf.contrib.learn.io.read_batch_record_features(
@ -132,8 +133,11 @@ class GraphIOTest(tf.test.TestCase):
queue_capacity=queue_capacity,
reader_num_threads=2,
name=name)
self.assertEqual("%s/fifo_queue_1_Dequeue:0" % name,
features["feature"].name)
self.assertTrue(
"feature" in features, "'feature' missing from %s." % features.keys())
feature = features["feature"]
self.assertEqual("%s/fifo_queue_1_Dequeue:0" % name, feature.name)
self.assertAllEqual((batch_size,) + shape, feature.get_shape().as_list())
file_name_queue_name = "%s/file_name_queue" % name
file_names_name = "%s/input" % file_name_queue_name
example_queue_name = "%s/fifo_queue" % name
@ -161,6 +165,7 @@ class GraphIOTest(tf.test.TestCase):
reader=tf.TFRecordReader, randomize_input=True,
num_epochs=1,
queue_capacity=queue_capacity, name=name)
self.assertAllEqual((None,), inputs.get_shape().as_list())
self.assertEqual("%s:1" % name, inputs.name)
file_name_queue_name = "%s/file_name_queue" % name
file_name_queue_limit_name = (
@ -190,6 +195,7 @@ class GraphIOTest(tf.test.TestCase):
_VALID_FILE_PATTERN, batch_size,
reader=tf.TFRecordReader, randomize_input=True,
queue_capacity=queue_capacity, name=name)
self.assertAllEqual((batch_size,), inputs.get_shape().as_list())
self.assertEqual("%s:1" % name, inputs.name)
file_name_queue_name = "%s/file_name_queue" % name
file_names_name = "%s/input" % file_name_queue_name
@ -234,6 +240,7 @@ class GraphIOTest(tf.test.TestCase):
filename, batch_size, reader=tf.TextLineReader,
randomize_input=False, num_epochs=1, queue_capacity=queue_capacity,
name=name)
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
@ -280,10 +287,13 @@ class GraphIOTest(tf.test.TestCase):
features = {"sequence": tf.FixedLenFeature([], tf.string)}
with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
_, result = tf.contrib.learn.read_keyed_batch_features(
keys, result = tf.contrib.learn.read_keyed_batch_features(
filename, batch_size, features, tf.TextLineReader,
randomize_input=False, num_epochs=1, queue_capacity=queue_capacity,
num_enqueue_threads=2, parse_fn=tf.decode_json_example, name=name)
self.assertAllEqual((None,), keys.get_shape().as_list())
self.assertEqual(1, len(result))
self.assertAllEqual((None,), result["sequence"].get_shape().as_list())
session.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(session, coord=coord)
@ -319,6 +329,7 @@ class GraphIOTest(tf.test.TestCase):
filenames, batch_size, reader=tf.TextLineReader,
randomize_input=False, num_epochs=1, queue_capacity=queue_capacity,
name=name)
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
@ -354,7 +365,7 @@ class GraphIOTest(tf.test.TestCase):
name = "my_batch"
with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
_, inputs = _read_keyed_batch_examples_shared_queue(
keys, inputs = _read_keyed_batch_examples_shared_queue(
filenames,
batch_size,
reader=tf.TextLineReader,
@ -362,6 +373,8 @@ class GraphIOTest(tf.test.TestCase):
num_epochs=1,
queue_capacity=queue_capacity,
name=name)
self.assertAllEqual((None,), keys.get_shape().as_list())
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
@ -418,7 +431,7 @@ class GraphIOTest(tf.test.TestCase):
with tf.Graph().as_default() as g1, tf.Session(
server.target, graph=g1) as session:
_, inputs = _read_keyed_batch_examples_shared_queue(
keys, inputs = _read_keyed_batch_examples_shared_queue(
filenames,
batch_size,
reader=tf.TextLineReader,
@ -426,6 +439,8 @@ class GraphIOTest(tf.test.TestCase):
num_epochs=1,
queue_capacity=queue_capacity,
name=name)
self.assertAllEqual((None,), keys.get_shape().as_list())
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run(tf.initialize_local_variables())
# Run the three queues once manually.
@ -443,7 +458,7 @@ class GraphIOTest(tf.test.TestCase):
with tf.Graph().as_default() as g2, tf.Session(
server.target, graph=g2) as session:
_, inputs = _read_keyed_batch_examples_shared_queue(
keys, inputs = _read_keyed_batch_examples_shared_queue(
filenames,
batch_size,
reader=tf.TextLineReader,
@ -451,6 +466,8 @@ class GraphIOTest(tf.test.TestCase):
num_epochs=1,
queue_capacity=queue_capacity,
name=name)
self.assertAllEqual((None,), keys.get_shape().as_list())
self.assertAllEqual((None,), inputs.get_shape().as_list())
# Run the worker and the example queue.
self._run_queue(worker_file_name_queue_name, session)
@ -473,6 +490,7 @@ class GraphIOTest(tf.test.TestCase):
[filename], batch_size, reader=tf.TextLineReader,
randomize_input=False, num_epochs=1, queue_capacity=queue_capacity,
read_batch_size=10, name=name)
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
@ -499,6 +517,8 @@ class GraphIOTest(tf.test.TestCase):
filename, batch_size,
reader=tf.TextLineReader, randomize_input=False,
num_epochs=1, queue_capacity=queue_capacity, name=name)
self.assertAllEqual((None,), keys.get_shape().as_list())
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
@ -537,6 +557,9 @@ class GraphIOTest(tf.test.TestCase):
reader=tf.TextLineReader, randomize_input=False,
num_epochs=1, queue_capacity=queue_capacity,
parse_fn=parse_fn, name=name)
self.assertAllEqual((None,), keys.get_shape().as_list())
self.assertEqual(1, len(inputs))
self.assertAllEqual((None, 1), inputs["age"].get_shape().as_list())
session.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()

View File

@ -24,6 +24,7 @@ from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.session_bundle import exporter
from tensorflow.contrib.session_bundle import gc
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -53,7 +54,7 @@ def _get_saver():
else:
saver = None
if saver is None and variables.all_variables():
saver = tf_saver.Saver()
saver = tf_saver.Saver(write_version=saver_pb2.SaverDef.V1)
ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
return saver

View File

@ -21,6 +21,7 @@ tensorflow/core/kernels/strided_slice_op_inst_4.cc
tensorflow/core/kernels/strided_slice_op_inst_3.cc
tensorflow/core/kernels/strided_slice_op_inst_2.cc
tensorflow/core/kernels/strided_slice_op_inst_1.cc
tensorflow/core/kernels/strided_slice_op_inst_0.cc
tensorflow/core/kernels/strided_slice_op.cc
tensorflow/core/kernels/stack_ops.cc
tensorflow/core/kernels/split_op.cc
@ -142,6 +143,7 @@ tensorflow/core/kernels/avgpooling_op.cc
tensorflow/core/kernels/argmax_op.cc
tensorflow/core/kernels/aggregate_ops.cc
tensorflow/core/kernels/dequantize_op.cc
tensorflow/core/kernels/meta_support.cc
tensorflow/core/kernels/quantization_utils.cc
tensorflow/core/kernels/quantize_down_and_shrink_range.cc
tensorflow/core/kernels/quantize_op.cc
@ -153,6 +155,7 @@ tensorflow/core/kernels/quantized_conv_ops.cc
tensorflow/core/kernels/quantized_matmul_op.cc
tensorflow/core/kernels/quantized_pooling_ops.cc
tensorflow/core/kernels/quantized_reshape_op.cc
tensorflow/core/kernels/requantization_range_op.cc
tensorflow/core/kernels/requantize.cc
tensorflow/core/ops/training_ops.cc
tensorflow/core/ops/string_ops.cc

View File

@ -95,11 +95,6 @@ Certain metrics, such as streaming_mean or streaming_accuracy, can be weighted
via a `weights` argument. The `weights` tensor must be the same size as the
labels and predictions tensors and results in a weighted average of the metric.
Other metrics, such as streaming_recall, streaming_precision, and streaming_auc,
are not well defined with regard to weighted samples. However, a binary
`ignore_mask` argument can be used to ignore certain values at graph executation
time.
## Metric `Ops`
@@streaming_accuracy

View File

@ -23,7 +23,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_args
from tensorflow.contrib.framework import tensor_util
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.metrics.python.ops import confusion_matrix_ops
@ -41,40 +40,6 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
IGNORE_MASK_DATE = '2016-10-19'
IGNORE_MASK_INSTRUCTIONS = (
'`ignore_mask` is being deprecated. Instead use `weights` with values 0.0 '
'and 1.0 to mask values. For example, `weights=tf.logical_not(mask)`.')
def _mask_weights(mask=None, weights=None):
"""Mask a given set of weights.
Elements are included when the corresponding `mask` element is `False`, and
excluded otherwise.
Args:
mask: An optional, `bool` `Tensor`.
weights: An optional `Tensor` whose shape matches `mask` if `mask` is not
`None`.
Returns:
Masked weights if `mask` and `weights` are not `None`, weights equivalent to
`mask` if `weights` is `None`, and otherwise `weights`.
Raises:
ValueError: If `weights` and `mask` are not `None` and have mismatched
shapes.
"""
if mask is not None:
check_ops.assert_type(mask, dtypes.bool)
if weights is None:
weights = array_ops.ones_like(mask, dtype=dtypes.float32)
weights = math_ops.cast(math_ops.logical_not(mask), weights.dtype) * weights
return weights
def _safe_div(numerator, denominator, name):
"""Divides two values, returning 0 if the denominator is <= 0.
@ -516,8 +481,7 @@ def streaming_accuracy(predictions, labels, weights=None,
updates_collections, name or 'accuracy')
@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_precision(predictions, labels, ignore_mask=None, weights=None,
def streaming_precision(predictions, labels, weights=None,
metrics_collections=None, updates_collections=None,
name=None):
"""Computes the precision of the predictions with respect to the labels.
@ -534,14 +498,11 @@ def streaming_precision(predictions, labels, ignore_mask=None, weights=None,
`weights`.
If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Alternatively, if `ignore_mask` is not `None`, then mask values where
`ignore_mask` is `True`.
Args:
predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
labels: The ground truth values, a `bool` `Tensor` whose dimensions must
match `predictions`.
ignore_mask: An optional, `bool` `Tensor` whose shape matches `predictions`.
weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
metrics_collections: An optional list of collections that `precision` should
be added to.
@ -558,9 +519,8 @@ def streaming_precision(predictions, labels, ignore_mask=None, weights=None,
Raises:
ValueError: If `predictions` and `labels` have mismatched shapes, or if
`ignore_mask` is not `None` and its shape doesn't match `predictions`, or
if `weights` is not `None` and its shape doesn't match `predictions`, or
if either `metrics_collections` or `updates_collections` are not a list or
`weights` is not `None` and its shape doesn't match `predictions`, or if
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
with variable_scope.variable_scope(
@ -570,7 +530,6 @@ def streaming_precision(predictions, labels, ignore_mask=None, weights=None,
predictions, labels, weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
weights = _mask_weights(ignore_mask, weights)
true_positives, true_positives_update_op = _streaming_true_positives(
predictions, labels, weights, metrics_collections=None,
updates_collections=None, name=None)
@ -599,8 +558,7 @@ def streaming_precision(predictions, labels, ignore_mask=None, weights=None,
return precision, update_op
@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_recall(predictions, labels, ignore_mask=None, weights=None,
def streaming_recall(predictions, labels, weights=None,
metrics_collections=None, updates_collections=None,
name=None):
"""Computes the recall of the predictions with respect to the labels.
@ -615,14 +573,11 @@ def streaming_recall(predictions, labels, ignore_mask=None, weights=None,
weights each prediction by the corresponding value in `weights`.
If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Alternatively, if `ignore_mask` is not `None`, then mask values where
`ignore_mask` is `True`.
Args:
predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
labels: The ground truth values, a `bool` `Tensor` whose dimensions must
match `predictions`.
ignore_mask: An optional, `bool` `Tensor` whose shape matches `predictions`.
weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
metrics_collections: An optional list of collections that `recall` should
be added to.
@ -639,9 +594,8 @@ def streaming_recall(predictions, labels, ignore_mask=None, weights=None,
Raises:
ValueError: If `predictions` and `labels` have mismatched shapes, or if
`ignore_mask` is not `None` and its shape doesn't match `predictions`, or
if `weights` is not `None` and its shape doesn't match `predictions`, or
if either `metrics_collections` or `updates_collections` are not a list or
`weights` is not `None` and its shape doesn't match `predictions`, or if
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
with variable_scope.variable_scope(name, 'recall', [predictions, labels]):
@ -649,7 +603,6 @@ def streaming_recall(predictions, labels, ignore_mask=None, weights=None,
predictions, labels, weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
weights = _mask_weights(ignore_mask, weights)
true_positives, true_positives_update_op = _streaming_true_positives(
predictions, labels, weights, metrics_collections=None,
updates_collections=None, name=None)
@ -1235,10 +1188,9 @@ def _at_k_name(name, k=None, class_id=None):
@deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, '
'and reshape labels from [batch_size] to [batch_size, 1].')
@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_recall_at_k(predictions, labels, k, ignore_mask=None,
weights=None, metrics_collections=None,
updates_collections=None, name=None):
def streaming_recall_at_k(predictions, labels, k, weights=None,
metrics_collections=None, updates_collections=None,
name=None):
"""Computes the recall@k of the predictions with respect to dense labels.
The `streaming_recall_at_k` function creates two local variables, `total` and
@ -1255,15 +1207,12 @@ def streaming_recall_at_k(predictions, labels, k, ignore_mask=None,
increments `count` with the reduced sum of `weights`.
If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Alternatively, if `ignore_mask` is not `None`, then mask values where
`ignore_mask` is `True`.
Args:
predictions: A floating point tensor of dimension [batch_size, num_classes]
labels: A tensor of dimension [batch_size] whose type is in `int32`,
`int64`.
k: The number of top elements to look at for computing recall.
ignore_mask: An optional, `bool` `Tensor` whose shape matches `predictions`.
weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
metrics_collections: An optional list of collections that `recall_at_k`
should be added to.
@ -1279,26 +1228,23 @@ def streaming_recall_at_k(predictions, labels, k, ignore_mask=None,
Raises:
ValueError: If `predictions` and `labels` have mismatched shapes, or if
`ignore_mask` is not `None` and its shape doesn't match `predictions`, or
if `weights` is not `None` and its shape doesn't match `predictions`, or
if either `metrics_collections` or `updates_collections` are not a list or
`weights` is not `None` and its shape doesn't match `predictions`, or if
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
in_top_k = math_ops.to_float(nn.in_top_k(predictions, labels, k))
return streaming_mean(in_top_k,
_mask_weights(ignore_mask, weights),
weights,
metrics_collections,
updates_collections,
name or _at_k_name('recall', k))
# TODO(ptucker): Validate range of values in labels?
@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_sparse_recall_at_k(predictions,
labels,
k,
class_id=None,
ignore_mask=None,
weights=None,
metrics_collections=None,
updates_collections=None,
@ -1328,8 +1274,6 @@ def streaming_sparse_recall_at_k(predictions,
`false_negative_at_<k>` using these values.
If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Alternatively, if `ignore_mask` is not `None`, then mask values where
`ignore_mask` is `True`.
Args:
predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
@ -1347,8 +1291,6 @@ def streaming_sparse_recall_at_k(predictions,
class_id: Integer class ID for which we want binary metrics. This should be
in range [0, num_classes), where num_classes is the last dimension of
`predictions`. If class_id is outside this range, the method returns NAN.
ignore_mask: An optional, `bool` `Tensor` whose shape is broadcastable to
the the first [D1, ... DN] dimensions of `predictions` and `labels`.
weights: An optional `Tensor` whose shape is broadcastable to the the first
[D1, ... DN] dimensions of `predictions` and `labels`.
metrics_collections: An optional list of collections that values should
@ -1365,16 +1307,14 @@ def streaming_sparse_recall_at_k(predictions,
`recall`.
Raises:
ValueError: If `ignore_mask` is not `None` and its shape doesn't match
`predictions`, or if `weights` is not `None` and its shape doesn't match
`predictions`, or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
ValueError: If `weights` is not `None` and its shape doesn't match
`predictions`, or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
"""
default_name = _at_k_name('recall', k, class_id=class_id)
with ops.name_scope(name, default_name, (predictions, labels)) as scope:
_, top_k_idx = nn.top_k(predictions, k)
top_k_idx = math_ops.to_int64(top_k_idx)
weights = _mask_weights(ignore_mask, weights)
tp, tp_update = _streaming_sparse_true_positive_at_k(
predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
weights=weights)
@ -1396,7 +1336,6 @@ def _streaming_sparse_precision_at_k(top_k_idx,
labels,
k=None,
class_id=None,
ignore_mask=None,
weights=None,
metrics_collections=None,
updates_collections=None,
@ -1423,8 +1362,6 @@ def _streaming_sparse_precision_at_k(top_k_idx,
in range [0, num_classes), where num_classes is the last dimension of
`predictions`. If `class_id` is outside this range, the method returns
NAN.
ignore_mask: An optional, `bool` `Tensor` whose shape is broadcastable to
the the first [D1, ... DN] dimensions of `predictions` and `labels`.
weights: An optional `Tensor` whose shape is broadcastable to the the first
[D1, ... DN] dimensions of `predictions` and `labels`.
metrics_collections: An optional list of collections that values should
@ -1441,13 +1378,11 @@ def _streaming_sparse_precision_at_k(top_k_idx,
`precision`.
Raises:
ValueError: If `ignore_mask` is not `None` and its shape doesn't match
`predictions`, or if `weights` is not `None` and its shape doesn't match
ValueError: If `weights` is not `None` and its shape doesn't match
`predictions`, or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
"""
top_k_idx = math_ops.to_int64(top_k_idx)
weights = _mask_weights(ignore_mask, weights)
tp, tp_update = _streaming_sparse_true_positive_at_k(
predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
weights=weights)
@ -1466,12 +1401,10 @@ def _streaming_sparse_precision_at_k(top_k_idx,
# TODO(ptucker): Validate range of values in labels?
@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_sparse_precision_at_k(predictions,
labels,
k,
class_id=None,
ignore_mask=None,
weights=None,
metrics_collections=None,
updates_collections=None,
@ -1502,8 +1435,6 @@ def streaming_sparse_precision_at_k(predictions,
`false_positive_at_<k>` using these values.
If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Alternatively, if `ignore_mask` is not `None`, then mask values where
`ignore_mask` is `True`.
Args:
predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
@ -1522,8 +1453,6 @@ def streaming_sparse_precision_at_k(predictions,
in range [0, num_classes], where num_classes is the last dimension of
`predictions`. If `class_id` is outside this range, the method returns
NAN.
ignore_mask: An optional, `bool` `Tensor` whose shape is broadcastable to
the the first [D1, ... DN] dimensions of `predictions` and `labels`.
weights: An optional `Tensor` whose shape is broadcastable to the the first
[D1, ... DN] dimensions of `predictions` and `labels`.
metrics_collections: An optional list of collections that values should
@ -1540,21 +1469,19 @@ def streaming_sparse_precision_at_k(predictions,
`precision`.
Raises:
ValueError: If `ignore_mask` is not `None` and its shape doesn't match
`predictions`, or if `weights` is not `None` and its shape doesn't match
ValueError: If `weights` is not `None` and its shape doesn't match
`predictions`, or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
"""
default_name = _at_k_name('precision', k, class_id=class_id)
with ops.name_scope(name, default_name,
(predictions, labels, ignore_mask, weights)) as scope:
(predictions, labels, weights)) as scope:
_, top_k_idx = nn.top_k(predictions, k)
return _streaming_sparse_precision_at_k(
top_k_idx=top_k_idx,
labels=labels,
k=k,
class_id=class_id,
ignore_mask=ignore_mask,
weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections,
@ -1562,11 +1489,9 @@ def streaming_sparse_precision_at_k(predictions,
# TODO(ptucker): Validate range of values in labels?
@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_sparse_precision_at_top_k(top_k_predictions,
labels,
class_id=None,
ignore_mask=None,
weights=None,
metrics_collections=None,
updates_collections=None,
@ -1595,8 +1520,6 @@ def streaming_sparse_precision_at_top_k(top_k_predictions,
`false_positive_at_k` using these values.
If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Alternatively, if `ignore_mask` is not `None`, then mask values where
`ignore_mask` is `True`.
Args:
top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
@ -1614,8 +1537,6 @@ def streaming_sparse_precision_at_top_k(top_k_predictions,
in range [0, num_classes), where num_classes is the last dimension of
`predictions`. If `class_id` is outside this range, the method returns
NAN.
ignore_mask: An optional, `bool` `Tensor` whose shape is broadcastable to
the the first [D1, ... DN] dimensions of `predictions` and `labels`.
weights: An optional `Tensor` whose shape is broadcastable to the the first
[D1, ... DN] dimensions of `predictions` and `labels`.
metrics_collections: An optional list of collections that values should
@ -1632,8 +1553,7 @@ def streaming_sparse_precision_at_top_k(top_k_predictions,
`precision`.
Raises:
ValueError: If `ignore_mask` is not `None` and its shape doesn't match
`predictions`, or if `weights` is not `None` and its shape doesn't match
ValueError: If `weights` is not `None` and its shape doesn't match
`predictions`, or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
ValueError: If `top_k_predictions` has rank < 2.
@ -1641,7 +1561,7 @@ def streaming_sparse_precision_at_top_k(top_k_predictions,
default_name = _at_k_name('precision', class_id=class_id)
with ops.name_scope(
name, default_name,
(top_k_predictions, labels, ignore_mask, weights)) as scope:
(top_k_predictions, labels, weights)) as scope:
rank = array_ops.rank(top_k_predictions)
check_rank_op = control_flow_ops.Assert(
math_ops.greater_equal(rank, 2),
@ -1651,7 +1571,6 @@ def streaming_sparse_precision_at_top_k(top_k_predictions,
top_k_idx=top_k_predictions,
labels=labels,
class_id=class_id,
ignore_mask=ignore_mask,
weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections,
@ -2760,8 +2679,7 @@ def streaming_mean_cosine_distance(predictions, labels, dim, weights=None,
return mean_distance, update_op
@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_percentage_less(values, threshold, ignore_mask=None, weights=None,
def streaming_percentage_less(values, threshold, weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@ -2778,13 +2696,10 @@ def streaming_percentage_less(values, threshold, ignore_mask=None, weights=None,
`percentage`.
If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Alternatively, if `ignore_mask` is not `None`, then mask values where
`ignore_mask` is `True`.
Args:
values: A numeric `Tensor` of arbitrary size.
threshold: A scalar threshold.
ignore_mask: An optional, `bool` `Tensor` whose shape matches `values`.
weights: An optional `Tensor` whose shape is broadcastable to `values`.
metrics_collections: An optional list of collections that the metric
value variable should be added to.
@ -2799,23 +2714,21 @@ def streaming_percentage_less(values, threshold, ignore_mask=None, weights=None,
appropriately.
Raises:
ValueError: If `ignore_mask` is not `None` and its shape doesn't match
`values`, or if `weights` is not `None` and its shape doesn't match
`values`, or if either `metrics_collections` or `updates_collections` are
not a list or tuple.
ValueError: If `weights` is not `None` and its shape doesn't match `values`,
or if either `metrics_collections` or `updates_collections` are not a list
or tuple.
"""
is_below_threshold = math_ops.to_float(math_ops.less(values, threshold))
return streaming_mean(is_below_threshold, _mask_weights(ignore_mask, weights),
return streaming_mean(is_below_threshold,
weights,
metrics_collections,
updates_collections,
name or 'percentage_below_threshold')
@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_mean_iou(predictions,
labels,
num_classes,
ignore_mask=None,
weights=None,
metrics_collections=None,
updates_collections=None,
@ -2834,8 +2747,6 @@ def streaming_mean_iou(predictions,
`update_op` operation that updates these variables and returns the `mean_iou`.
If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Alternatively, if `ignore_mask` is not `None`, then mask values where
`ignore_mask` is `True`.
Args:
predictions: A tensor of prediction results for semantic labels, whose
@ -2846,7 +2757,6 @@ def streaming_mean_iou(predictions,
num_classes: The possible number of labels the prediction task can
have. This value must be provided, since a confusion matrix of
dimension = [num_classes, num_classes] will be allocated.
ignore_mask: An optional, `bool` `Tensor` whose shape matches `predictions`.
weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
metrics_collections: An optional list of collections that `mean_iou`
should be added to.
@ -2860,9 +2770,8 @@ def streaming_mean_iou(predictions,
Raises:
ValueError: If `predictions` and `labels` have mismatched shapes, or if
`ignore_mask` is not `None` and its shape doesn't match `predictions`, or
if `weights` is not `None` and its shape doesn't match `predictions`, or
if either `metrics_collections` or `updates_collections` are not a list or
`weights` is not `None` and its shape doesn't match `predictions`, or if
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
with variable_scope.variable_scope(name, 'mean_iou', [predictions, labels]):
@ -2888,7 +2797,6 @@ def streaming_mean_iou(predictions,
if labels_rank > 1:
labels = array_ops.reshape(labels, [-1])
weights = _mask_weights(ignore_mask, weights)
if weights is not None:
weights_rank = weights.get_shape().ndims
if weights_rank > 1:

View File

@ -671,18 +671,6 @@ class StreamingPrecisionTest(tf.test.TestCase):
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, precision.eval())
def testMasked(self):
predictions = tf.constant([1, 0, 1, 0, 1], shape=(1, 5))
labels = tf.constant([0, 1, 1, 0, 1], shape=(1, 5))
mask = tf.constant([False, False, False, False, True], shape=(1, 5))
precision, update_op = metrics.streaming_precision(
predictions, labels, ignore_mask=mask)
with self.test_session() as sess:
sess.run(tf.initialize_local_variables())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, precision.eval())
def testWeighted1d(self):
predictions = tf.constant([[1, 0, 1, 0], [1, 0, 1, 0]])
labels = tf.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
@ -838,18 +826,6 @@ class StreamingRecallTest(tf.test.TestCase):
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, recall.eval())
def testMasked(self):
predictions = tf.constant([1, 0, 1, 0, 1], shape=(1, 5))
labels = tf.constant([0, 1, 1, 0, 1], shape=(1, 5))
mask = tf.constant([False, False, False, False, True], shape=(1, 5))
recall, update_op = metrics.streaming_recall(
predictions, labels, ignore_mask=mask)
with self.test_session() as sess:
sess.run(tf.initialize_local_variables())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, recall.eval())
def testWeighted1d(self):
predictions = tf.constant([[1, 0, 1, 0], [0, 1, 0, 1]])
labels = tf.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
@ -1737,15 +1713,13 @@ class StreamingRecallAtKTest(tf.test.TestCase):
dtype=tf.float32)
labels = tf.constant(
self._np_labels, shape=(self._batch_size,), dtype=tf.int64)
weights = tf.constant([0, 1, 1, 1], shape=(self._batch_size,),
weights = tf.constant([0, 1, 0, 1], shape=(self._batch_size,),
dtype=tf.float32)
mask = tf.constant([False, False, True, False], shape=(self._batch_size,),
dtype=tf.bool)
recall, update_op = metrics.streaming_recall_at_k(
predictions, labels, k=2, ignore_mask=mask, weights=weights)
predictions, labels, k=2, weights=weights)
sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k(
predictions, tf.reshape(labels, (self._batch_size, 1)), k=2,
ignore_mask=mask, weights=weights)
weights=weights)
with self.test_session() as sess:
sess.run(tf.initialize_local_variables())
@ -1763,16 +1737,13 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
k,
expected,
class_id=None,
ignore_mask=None,
weights=None):
with tf.Graph().as_default() as g, self.test_session(g):
if ignore_mask is not None:
ignore_mask = tf.constant(ignore_mask, tf.bool)
if weights is not None:
weights = tf.constant(weights, tf.float32)
metric, update = metrics.streaming_sparse_precision_at_k(
predictions=tf.constant(predictions, tf.float32), labels=labels,
k=k, class_id=class_id, ignore_mask=ignore_mask, weights=weights)
k=k, class_id=class_id, weights=weights)
# Fails without initialized vars.
self.assertRaises(tf.OpError, metric.eval)
@ -1792,17 +1763,13 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
labels,
expected,
class_id=None,
ignore_mask=None,
weights=None):
with tf.Graph().as_default() as g, self.test_session(g):
if ignore_mask is not None:
ignore_mask = tf.constant(ignore_mask, tf.bool)
if weights is not None:
weights = tf.constant(weights, tf.float32)
metric, update = metrics.streaming_sparse_precision_at_top_k(
top_k_predictions=tf.constant(top_k_predictions, tf.int32),
labels=labels, class_id=class_id, ignore_mask=ignore_mask,
weights=weights)
labels=labels, class_id=class_id, weights=weights)
# Fails without initialized vars.
self.assertRaises(tf.OpError, metric.eval)
@ -1821,11 +1788,8 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
predictions,
labels,
k,
expected,
ignore_mask=None):
expected):
with tf.Graph().as_default() as g, self.test_session(g):
if ignore_mask is not None:
ignore_mask = tf.constant(ignore_mask, tf.bool)
predictions = tf.constant(predictions, tf.float32)
metric = metric_ops.sparse_average_precision_at_k(
predictions, labels, k)
@ -2305,11 +2269,9 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
top_k_predictions, labels, expected=NAN, class_id=class_id,
weights=[[0, 0], [0, 0]])
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=NAN, ignore_mask=[[False], [True]],
weights=[[0], [1]])
predictions, labels, k=5, expected=NAN, weights=[[0], [0]])
self._test_streaming_sparse_precision_at_top_k(
top_k_predictions, labels, expected=NAN,
ignore_mask=[[False], [True]], weights=[[0], [1]])
top_k_predictions, labels, expected=NAN, weights=[[0], [0]])
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=NAN, weights=[[0, 0], [0, 0]])
self._test_streaming_sparse_precision_at_top_k(
@ -2342,34 +2304,34 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
# Class 2: 2 predictions, both correct.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=2.0 / 2.0, class_id=2,
ignore_mask=[[False], [False]], weights=[[1], [0]])
weights=[[1], [0]])
self._test_streaming_sparse_precision_at_top_k(
top_k_predictions, labels, expected=2.0 / 2.0, class_id=2,
ignore_mask=[[False], [False]], weights=[[1], [0]])
weights=[[1], [0]])
# Class 2: 2 predictions, both correct.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=2.0 / 2.0, class_id=2,
ignore_mask=[[False], [False]], weights=[[0], [1]])
weights=[[0], [1]])
self._test_streaming_sparse_precision_at_top_k(
top_k_predictions, labels, expected=2.0 / 2.0, class_id=2,
ignore_mask=[[False], [False]], weights=[[0], [1]])
weights=[[0], [1]])
# Class 7: 1 incorrect prediction.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=0.0 / 1.0, class_id=7,
ignore_mask=[[False], [True]], weights=[[1], [1]])
weights=[[1], [0]])
self._test_streaming_sparse_precision_at_top_k(
top_k_predictions, labels, expected=0.0 / 1.0, class_id=7,
ignore_mask=[[False], [True]], weights=[[1], [1]])
weights=[[1], [0]])
# Class 7: 1 correct prediction.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=1.0 / 1.0, class_id=7,
ignore_mask=[[True], [False]], weights=[[1], [1]])
weights=[[0], [1]])
self._test_streaming_sparse_precision_at_top_k(
top_k_predictions, labels, expected=1.0 / 1.0, class_id=7,
ignore_mask=[[True], [False]], weights=[[1], [1]])
weights=[[0], [1]])
# Class 7: no predictions.
self._test_streaming_sparse_precision_at_k(
@ -2409,17 +2371,13 @@ class StreamingSparseRecallTest(tf.test.TestCase):
k,
expected,
class_id=None,
ignore_mask=None,
weights=None):
with tf.Graph().as_default() as g, self.test_session(g):
if ignore_mask is not None:
ignore_mask = tf.constant(ignore_mask, tf.bool)
if weights is not None:
weights = tf.constant(weights, tf.float32)
metric, update = metrics.streaming_sparse_recall_at_k(
predictions=tf.constant(predictions, tf.float32),
labels=labels, k=k, class_id=class_id, ignore_mask=ignore_mask,
weights=weights)
labels=labels, k=k, class_id=class_id, weights=weights)
# Fails without initialized vars.
self.assertRaises(tf.OpError, metric.eval)
@ -2740,8 +2698,7 @@ class StreamingSparseRecallTest(tf.test.TestCase):
predictions, labels, k=5, expected=NAN, class_id=class_id,
weights=[[0, 0], [0, 0]])
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=NAN, ignore_mask=[[False], [True]],
weights=[[0], [1]])
predictions, labels, k=5, expected=NAN, weights=[[0], [0]])
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=NAN, weights=[[0, 0], [0, 0]])
@ -2764,22 +2721,22 @@ class StreamingSparseRecallTest(tf.test.TestCase):
# Class 2: 2 labels, both correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=2.0 / 2.0, class_id=2,
ignore_mask=[[False], [False]], weights=[[1], [0]])
weights=[[1], [0]])
# Class 2: 2 labels, both correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=2.0 / 2.0, class_id=2,
ignore_mask=[[False], [False]], weights=[[0], [1]])
weights=[[0], [1]])
# Class 7: 1 label, correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=1.0 / 1.0, class_id=7,
ignore_mask=[[True], [False]], weights=[[1], [1]])
weights=[[0], [1]])
# Class 7: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=0.0 / 1.0, class_id=7,
ignore_mask=[[False], [True]], weights=[[1], [1]])
weights=[[1], [0]])
# Class 7: 2 labels, 1 correct.
self._test_streaming_sparse_recall_at_k(
@ -3660,16 +3617,14 @@ class PcntBelowThreshTest(tf.test.TestCase):
def testSomePresentOneUpdate(self):
with self.test_session() as sess:
values = tf.constant([2, 4, 6, 8], shape=(1, 4), dtype=tf.float32)
mask = tf.constant([False, True, False, False], shape=(1, 4),
dtype=tf.bool)
weights = tf.constant([1, 1, 0, 1], shape=(1, 4), dtype=tf.float32)
weights = tf.constant([1, 0, 0, 1], shape=(1, 4), dtype=tf.float32)
pcnt0, update_op0 = metrics.streaming_percentage_less(
values, 100, ignore_mask=mask, weights=weights, name='high')
values, 100, weights=weights, name='high')
pcnt1, update_op1 = metrics.streaming_percentage_less(
values, 7, ignore_mask=mask, weights=weights, name='medium')
values, 7, weights=weights, name='medium')
pcnt2, update_op2 = metrics.streaming_percentage_less(
values, 1, ignore_mask=mask, weights=weights, name='low')
values, 1, weights=weights, name='low')
sess.run(tf.initialize_local_variables())
self.assertListEqual([1.0, 0.5, 0.0],
@ -3712,22 +3667,6 @@ class StreamingMeanIOUTest(tf.test.TestCase):
metrics.streaming_mean_iou(
predictions, labels, num_classes=2)
def testLabelsAndIgnoreMaskOfDifferentSizeRaisesValueError(self):
predictions = tf.ones([10])
labels = tf.ones([10])
ignore_mask = tf.cast(tf.ones([9]), tf.bool)
with self.assertRaises(ValueError):
metrics.streaming_mean_iou(
predictions, labels, num_classes=2, ignore_mask=ignore_mask)
def testIgnoreMaskIsNotBooleanRaisesTypeError(self):
predictions = tf.ones([10])
labels = tf.ones([10])
ignore_mask = tf.ones([10])
with self.assertRaises(TypeError):
metrics.streaming_mean_iou(
predictions, labels, num_classes=2, ignore_mask=ignore_mask)
def testLabelsAndWeightsOfDifferentSizeRaisesValueError(self):
predictions = tf.ones([10])
labels = tf.ones([10])
@ -3810,29 +3749,18 @@ class StreamingMeanIOUTest(tf.test.TestCase):
_enqueue_vector(sess, labels_queue, [1])
labels = labels_queue.dequeue()
# Create the queue that populates the ignore_masks.
ignore_masks_queue = tf.FIFOQueue(6, dtypes=tf.bool, shapes=(1, 1))
_enqueue_vector(sess, ignore_masks_queue, [False])
_enqueue_vector(sess, ignore_masks_queue, [False])
_enqueue_vector(sess, ignore_masks_queue, [False])
_enqueue_vector(sess, ignore_masks_queue, [True])
_enqueue_vector(sess, ignore_masks_queue, [False])
_enqueue_vector(sess, ignore_masks_queue, [False])
ignore_mask = ignore_masks_queue.dequeue()
# Create the queue that populates the weights.
weights_queue = tf.FIFOQueue(6, dtypes=tf.float32, shapes=(1, 1))
_enqueue_vector(sess, weights_queue, [1.0])
_enqueue_vector(sess, weights_queue, [1.0])
_enqueue_vector(sess, weights_queue, [1.0])
_enqueue_vector(sess, weights_queue, [1.0])
_enqueue_vector(sess, weights_queue, [0.0])
_enqueue_vector(sess, weights_queue, [1.0])
_enqueue_vector(sess, weights_queue, [0.0])
weights = weights_queue.dequeue()
miou, update_op = metrics.streaming_mean_iou(
predictions, labels, num_classes, ignore_mask=ignore_mask,
weights=weights)
predictions, labels, num_classes, weights=weights)
sess.run(tf.initialize_local_variables())
for _ in range(6):
@ -3920,13 +3848,12 @@ class StreamingMeanIOUTest(tf.test.TestCase):
labels = tf.concat(0, [tf.constant(0, shape=[3]),
tf.constant(1, shape=[7])])
num_classes = 2
mask = tf.concat(0, [tf.constant(False, shape=[9]),
tf.constant(True, shape=[1])])
weights = tf.concat(0, [tf.constant(0, shape=[1]),
tf.constant(1, shape=[9])])
tf.constant(1, shape=[8]),
tf.constant(0, shape=[1])])
with self.test_session() as sess:
miou, update_op = metrics.streaming_mean_iou(
predictions, labels, num_classes, ignore_mask=mask, weights=weights)
predictions, labels, num_classes, weights=weights)
sess.run(tf.initialize_local_variables())
self.assertAllEqual([[2, 2], [0, 4]], update_op.eval())
desired_miou = np.mean([2./4., 4./6.])

View File

@ -100,7 +100,7 @@ class ExternalOptimizerInterface(object):
accumulated_dims[1:])]
def minimize(self, session=None, feed_dict=None, fetches=None,
step_callback=None, loss_callback=None, grad_callback=None):
step_callback=None, loss_callback=None):
"""Minimize a scalar `Tensor`.
Variables subject to optimization are updated in-place at the end of
@ -113,14 +113,13 @@ class ExternalOptimizerInterface(object):
Args:
session: A `Session` instance.
feed_dict: A feed dict to be passed to calls to `session.run`.
fetches: A list of `Tensor`s to fetch and supply to `loss_callback` and
`grad_callback` as positional arguments.
fetches: A list of `Tensor`s to fetch and supply to `loss_callback`
as positional arguments.
step_callback: A function to be called at each optimization step;
arguments are the current values of all optimization variables
flattened into a single vector.
loss_callback: A function to be called every time the loss and gradients
are computed, with evaluated fetches supplied as positional arguments.
grad_callback: Deprecated.
"""
session = session or ops.get_default_session()
feed_dict = feed_dict or {}
@ -128,9 +127,6 @@ class ExternalOptimizerInterface(object):
loss_callback = loss_callback or (lambda *fetches: None)
step_callback = step_callback or (lambda xk: None)
# TODO(chapelle): Remove grad_callback (b/30590858)
if grad_callback:
logging.warn('grad_callback is deprecated. Please use loss_callback.')
# Construct loss function and associated gradient.
loss_grad_func = self._make_eval_func(

View File

@ -62,7 +62,8 @@ from tensorflow.python.training import saver
class MovingAverageOptimizer(optimizer.Optimizer):
"""Optimizer wrapper that maintains a moving average of parameters."""
def __init__(self, opt, average_decay=0.9999, sequential_update=True):
def __init__(self, opt, average_decay=0.9999, num_updates=None,
sequential_update=True):
"""Construct a new MovingAverageOptimizer.
Args:
@ -70,6 +71,8 @@ class MovingAverageOptimizer(optimizer.Optimizer):
average_decay: Float. Decay to use to maintain the moving averages
of trained variables.
See tf.train.ExponentialMovingAverage for details.
num_updates: Optional count of number of updates applied to variables.
See tf.train.ExponentialMovingAverage for details.
sequential_update: Bool. If False, will compute the moving average at the
same time as the model is updated, potentially doing
benign data races.
@ -77,7 +80,8 @@ class MovingAverageOptimizer(optimizer.Optimizer):
updates.
"""
self._optimizer = opt
self._ema = moving_averages.ExponentialMovingAverage(average_decay)
self._ema = moving_averages.ExponentialMovingAverage(
average_decay, num_updates=num_updates)
self._variable_map = None
self._sequential_update = sequential_update

View File

@ -181,6 +181,24 @@ tf_gen_op_libs(
op_lib_names = ["lstm_ops"],
)
tf_kernel_library(
name = "gru_ops_kernels",
srcs = [
"kernels/blas_gemm.cc",
"kernels/blas_gemm.h",
],
gpu_srcs = [
"kernels/blas_gemm.h",
],
prefix = "kernels/gru_ops",
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:eigen_helpers",
"//third_party/eigen3",
],
)
tf_kernel_library(
name = "lstm_ops_kernels",
srcs = [

View File

@ -37,7 +37,6 @@ perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
namespace functor {
template <typename T>
void TensorCuBlasGemm<T>::operator()(OpKernelContext* ctx,
perftools::gputools::Stream* stream,
bool transa, bool transb, uint64 m,
uint64 n, uint64 k, T alpha, const T* a,
int lda, const T* b, int ldb, T beta, T* c,
@ -52,7 +51,8 @@ void TensorCuBlasGemm<T>::operator()(OpKernelContext* ctx,
auto c_ptr = AsDeviceMemory(c);
bool blas_launch_status =
stream
ctx->op_device_context()
->stream()
->ThenBlasGemm(trans[transa], trans[transb], m, n, k, alpha, a_ptr,
lda, b_ptr, ldb, beta, &c_ptr, ldc)
.ok();

View File

@ -21,22 +21,15 @@ limitations under the License.
#include "tensorflow/core/kernels/eigen_activations.h"
#include "tensorflow/core/platform/types.h"
namespace perftools {
namespace gputools {
class Stream;
} // end namespace gputools
} // end namespace perftools
namespace tensorflow {
class OpKernelContext;
namespace functor {
template <typename T>
struct TensorCuBlasGemm {
void operator()(OpKernelContext* ctx, perftools::gputools::Stream* stream,
bool transa, bool transb, uint64 m, uint64 n, uint64 k,
T alpha, const T* a, int lda, const T* b, int ldb, T beta,
T* c, int ldc);
void operator()(OpKernelContext* ctx, bool transa, bool transb, uint64 m,
uint64 n, uint64 k, T alpha, const T* a, int lda, const T* b,
int ldb, T beta, T* c, int ldc);
};
template <typename Device, typename T, bool USE_CUBLAS>
@ -44,16 +37,15 @@ struct TensorBlasGemm;
template <typename Device, typename T>
struct TensorBlasGemm<Device, T, true /* USE_CUBLAS */> {
static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
const Device& d, bool transa, bool transb, T alpha,
typename TTypes<T>::ConstMatrix a,
static void compute(OpKernelContext* ctx, const Device& d, bool transa,
bool transb, T alpha, typename TTypes<T>::ConstMatrix a,
typename TTypes<T>::ConstMatrix b, T beta,
typename TTypes<T>::Matrix c) {
int64 m = c.dimensions()[0];
int64 n = c.dimensions()[1];
int64 k = transa ? a.dimensions()[0] : a.dimensions()[1];
TensorCuBlasGemm<T>()(ctx, stream, transb, transa, n, m, k, alpha, b.data(),
TensorCuBlasGemm<T>()(ctx, transb, transa, n, m, k, alpha, b.data(),
transb ? k : n, a.data(), transa ? m : k, beta,
c.data(), n);
}
@ -61,9 +53,8 @@ struct TensorBlasGemm<Device, T, true /* USE_CUBLAS */> {
template <typename Device, typename T>
struct TensorBlasGemm<Device, T, false /* USE_CUBLAS */> {
static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
const Device& d, bool transa, bool transb, T alpha,
typename TTypes<T>::ConstMatrix a,
static void compute(OpKernelContext* ctx, const Device& d, bool transa,
bool transb, T alpha, typename TTypes<T>::ConstMatrix a,
typename TTypes<T>::ConstMatrix b, T beta,
typename TTypes<T>::Matrix c) {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;

View File

@ -15,10 +15,6 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/contrib/rnn/kernels/gru_ops.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -151,14 +147,9 @@ class GRUCellBlockOp : public OpKernel {
const Device& device = ctx->eigen_device<Device>();
perftools::gputools::Stream* stream =
std::is_same<Device, GPUDevice>::value
? ctx->op_device_context()->stream()
: nullptr;
functor::GRUBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
ctx, stream, device, x_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
ctx, device, x_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
w_ru_tensor->matrix<T>(), w_c_tensor->matrix<T>(),
b_ru_tensor->vec<T>(), b_c_tensor->vec<T>(), r_u_bar_tensor.matrix<T>(),
r_tensor->matrix<T>(), u_tensor->matrix<T>(), c_tensor->matrix<T>(),
@ -362,14 +353,10 @@ class GRUBlockCellGradOp : public OpKernel {
&d_x_component_2_h_prevr));
const Device& device = ctx->eigen_device<Device>();
perftools::gputools::Stream* stream =
std::is_same<Device, GPUDevice>::value
? ctx->op_device_context()->stream()
: nullptr;
functor::GRUBlockCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
ctx, stream, device, x_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
ctx, device, x_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
w_ru_tensor->matrix<T>(), w_c_tensor->matrix<T>(),
b_ru_tensor->vec<T>(), b_c_tensor->vec<T>(), r_tensor->matrix<T>(),
u_tensor->matrix<T>(), c_tensor->matrix<T>(), d_h_tensor->matrix<T>(),
@ -400,8 +387,8 @@ namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void GRUBlockCellFprop<GPUDevice, T, true>::operator()( \
OpKernelContext* ctx, perftools::gputools::Stream* stream, \
const GPUDevice& d, typename TTypes<T>::ConstMatrix x, \
OpKernelContext* ctx, const GPUDevice& d, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w_ru, \
typename TTypes<T>::ConstMatrix w_c, typename TTypes<T>::ConstVec b_ru, \
@ -430,9 +417,9 @@ namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void GRUBlockCellBprop<GPUDevice, T, true>::operator()( \
OpKernelContext* ctx, perftools::gputools::Stream* stream, \
const GPUDevice& d, typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix h, typename TTypes<T>::ConstMatrix w_ru, \
OpKernelContext* ctx, const GPUDevice& d, \
typename TTypes<T>::ConstMatrix x, typename TTypes<T>::ConstMatrix h, \
typename TTypes<T>::ConstMatrix w_ru, \
typename TTypes<T>::ConstMatrix w_c, typename TTypes<T>::ConstVec b_ru, \
typename TTypes<T>::ConstVec b_c, typename TTypes<T>::ConstMatrix r, \
typename TTypes<T>::ConstMatrix u, typename TTypes<T>::ConstMatrix c, \

View File

@ -21,12 +21,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
namespace perftools {
namespace gputools {
class Stream;
} // end namespace gputools
} // end namespace perftools
namespace tensorflow {
class OpKernelContext;
@ -77,18 +71,15 @@ struct GRUBlockCellFprop : public GRUCell {
const int cell_size)
: GRUCell(batch_size, input_size, cell_size) {}
void operator()(OpKernelContext* ctx, perftools::gputools::Stream* stream,
const Device& d, typename TTypes<T>::ConstMatrix x,
typename TTypes<T>::ConstMatrix h_prev,
typename TTypes<T>::ConstMatrix w_ru,
typename TTypes<T>::ConstMatrix w_c,
typename TTypes<T>::ConstVec b_ru,
typename TTypes<T>::ConstVec b_c,
typename TTypes<T>::Matrix r_u_bar,
typename TTypes<T>::Matrix r, typename TTypes<T>::Matrix u,
typename TTypes<T>::Matrix c, typename TTypes<T>::Matrix h,
typename TTypes<T>::Matrix x_h_prev,
typename TTypes<T>::Matrix x_h_prevr) {
void operator()(
OpKernelContext* ctx, const Device& d, typename TTypes<T>::ConstMatrix x,
typename TTypes<T>::ConstMatrix h_prev,
typename TTypes<T>::ConstMatrix w_ru, typename TTypes<T>::ConstMatrix w_c,
typename TTypes<T>::ConstVec b_ru, typename TTypes<T>::ConstVec b_c,
typename TTypes<T>::Matrix r_u_bar, typename TTypes<T>::Matrix r,
typename TTypes<T>::Matrix u, typename TTypes<T>::Matrix c,
typename TTypes<T>::Matrix h, typename TTypes<T>::Matrix x_h_prev,
typename TTypes<T>::Matrix x_h_prevr) {
// Concat x_h_prev = [x, h_prev].
x_h_prev.slice(x_offsets(), x_extends()).device(d) = x;
x_h_prev.slice(h_offsets(), h_extends()).device(d) = h_prev;
@ -96,9 +87,8 @@ struct GRUBlockCellFprop : public GRUCell {
// r_u_bar = x_h_prev * w_ru + b_ru
typename TTypes<T>::ConstMatrix const_x_h_prev(x_h_prev.data(),
x_h_prev.dimensions());
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(ctx, stream, d, false, false,
T(1), const_x_h_prev, w_ru,
T(0), r_u_bar);
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
ctx, d, false, false, T(1), const_x_h_prev, w_ru, T(0), r_u_bar);
// Creating a bias matrix for adding by broadcasting 'b_ru'
Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({batch_size_, 1});
@ -117,7 +107,7 @@ struct GRUBlockCellFprop : public GRUCell {
typename TTypes<T>::ConstMatrix const_x_h_prevr(x_h_prevr.data(),
x_h_prevr.dimensions());
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
ctx, stream, d, false, false, T(1), const_x_h_prevr, w_c, T(0), c);
ctx, d, false, false, T(1), const_x_h_prevr, w_c, T(0), c);
Eigen::array<Eigen::DenseIndex, 2> b_c_shape({1, b_c.dimensions()[0]});
c.device(d) += (b_c.reshape(b_c_shape).broadcast(broadcast_shape));
@ -135,8 +125,7 @@ struct GRUBlockCellBprop : public GRUCell {
: GRUCell(batch_size, input_size, cell_size) {}
void operator()(
OpKernelContext* ctx, perftools::gputools::Stream* stream,
const Device& d, typename TTypes<T>::ConstMatrix x,
OpKernelContext* ctx, const Device& d, typename TTypes<T>::ConstMatrix x,
typename TTypes<T>::ConstMatrix h_prev,
typename TTypes<T>::ConstMatrix w_ru, typename TTypes<T>::ConstMatrix w_c,
typename TTypes<T>::ConstVec b_ru, typename TTypes<T>::ConstVec b_c,
@ -159,9 +148,9 @@ struct GRUBlockCellBprop : public GRUCell {
// [2nd_component_of_d_x d_h_prevr] = d_c_bar X w_c^T
typename TTypes<T>::ConstMatrix const_d_c_bar(d_c_bar.data(),
d_c_bar.dimensions());
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(ctx, stream, d, false, true,
T(1), const_d_c_bar, w_c,
T(0), d_x_comp2_and_h_prevr);
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(ctx, d, false, true, T(1),
const_d_c_bar, w_c, T(0),
d_x_comp2_and_h_prevr);
d_hr.device(d) = d_x_comp2_and_h_prevr.slice(h_offsets(), h_extends());
d_r_bar.device(d) = (d_hr * h_prev * r) * (r.constant(T(1)) - r);
@ -175,7 +164,7 @@ struct GRUBlockCellBprop : public GRUCell {
typename TTypes<T>::ConstMatrix const_d_r_bar_u_bar(
d_r_bar_u_bar.data(), d_r_bar_u_bar.dimensions());
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
ctx, stream, d, false, true, T(1), const_d_r_bar_u_bar, w_ru, T(0),
ctx, d, false, true, T(1), const_d_r_bar_u_bar, w_ru, T(0),
d_x_comp1_and_h_prev_comp1);
// d_x = d_x_comp1 + d_x_comp2

View File

@ -34,10 +34,6 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@ -164,14 +160,10 @@ class LSTMBlockCellOp : public OpKernel {
&icfo_tensor));
const Device& device = ctx->eigen_device<Device>();
perftools::gputools::Stream* stream =
std::is_same<Device, GPUDevice>::value
? ctx->op_device_context()->stream()
: nullptr;
functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
ctx, stream, device, forget_bias_, cell_clip_, use_peephole_,
ctx, device, forget_bias_, cell_clip_, use_peephole_,
x_tensor->matrix<T>(), cs_prev_tensor->matrix<T>(),
h_prev_tensor->matrix<T>(), w_tensor->matrix<T>(), wci_tensor->vec<T>(),
wcf_tensor->vec<T>(), wco_tensor->vec<T>(), b_tensor->vec<T>(),
@ -196,22 +188,21 @@ REGISTER_KERNEL(float);
#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void LSTMBlockCellFprop<GPUDevice, T, true>::operator()( \
OpKernelContext* ctx, perftools::gputools::Stream* stream, \
const GPUDevice& d, const T forget_bias, const T cell_clip, \
bool use_peephole, typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h); \
\
#define DECLARE_GPU_SPEC(T) \
template <> \
void LSTMBlockCellFprop<GPUDevice, T, true>::operator()( \
OpKernelContext* ctx, const GPUDevice& d, const T forget_bias, \
const T cell_clip, bool use_peephole, typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h); \
\
extern template struct LSTMBlockCellFprop<GPUDevice, T, true>;
DECLARE_GPU_SPEC(float);
@ -445,10 +436,6 @@ class LSTMBlockCellGradOp : public OpKernel {
&di_tensor));
const Device& device = ctx->eigen_device<Device>();
perftools::gputools::Stream* stream =
std::is_same<Device, GPUDevice>::value
? ctx->op_device_context()->stream()
: nullptr;
functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<float>());
functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<float>());
@ -456,7 +443,7 @@ class LSTMBlockCellGradOp : public OpKernel {
functor::LSTMBlockCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
ctx, stream, device, use_peephole_, x_tensor->matrix<T>(),
ctx, device, use_peephole_, x_tensor->matrix<T>(),
cs_prev_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
wco_tensor->vec<T>(), b_tensor->vec<T>(), i_tensor->matrix<T>(),
@ -486,8 +473,7 @@ namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void LSTMBlockCellBprop<GPUDevice, T, true>::operator()( \
OpKernelContext* ctx, perftools::gputools::Stream* stream, \
const GPUDevice& d, bool use_peephole, \
OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
@ -769,10 +755,6 @@ class BlockLSTMOp : public OpKernel {
&icfo_tensor));
const Device& device = ctx->eigen_device<Device>();
perftools::gputools::Stream* stream =
std::is_same<Device, GPUDevice>::value
? ctx->op_device_context()->stream()
: nullptr;
const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
SliceHelper<Device, T> slicer(ctx);
@ -794,7 +776,7 @@ class BlockLSTMOp : public OpKernel {
functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
ctx, stream, device, forget_bias_, cell_clip_, use_peephole_,
ctx, device, forget_bias_, cell_clip_, use_peephole_,
x_tensor.matrix<T>(), cs_prev_tensor2.matrix<T>(),
h_prev_tensor2.matrix<T>(), w_tensor->matrix<T>(),
wci_tensor->vec<T>(), wcf_tensor->vec<T>(), wco_tensor->vec<T>(),
@ -1020,10 +1002,6 @@ class BlockLSTMGradOp : public OpKernel {
const Device& device = ctx->eigen_device<Device>();
perftools::gputools::Stream* stream =
std::is_same<Device, GPUDevice>::value
? ctx->op_device_context()->stream()
: nullptr;
functor::TensorZero<Device, T>()(device, cs_grad_tensor.flat<float>());
functor::TensorZero<Device, T>()(device,
@ -1073,7 +1051,7 @@ class BlockLSTMGradOp : public OpKernel {
Tensor x_grad_tensor = slicer.OutputSlice(x_grad, t, "x_grad");
functor::BlockLSTMBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
ctx, stream, device, use_peephole_, x_tensor.matrix<T>(),
ctx, device, use_peephole_, x_tensor.matrix<T>(),
cs_prev_tensor2.matrix<T>(), h_prev_tensor2.matrix<T>(),
w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
wco_tensor->vec<T>(), b_tensor->vec<T>(), xh_tensor.matrix<T>(),
@ -1134,8 +1112,7 @@ namespace functor {
\
template <> \
void BlockLSTMBprop<GPUDevice, T, true>::operator()( \
OpKernelContext* ctx, perftools::gputools::Stream* stream, \
const GPUDevice& d, bool use_peephole, \
OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \

View File

@ -22,12 +22,6 @@ limitations under the License.
#include "tensorflow/core/kernels/eigen_activations.h"
#include "tensorflow/core/platform/types.h"
namespace perftools {
namespace gputools {
class Stream;
} // end namespace gputools
} // end namespace perftools
namespace tensorflow {
class OpKernelContext;
@ -153,29 +147,26 @@ struct LSTMBlockCellFprop : public LSTMBlockCell {
const int cell_size)
: LSTMBlockCell(batch_size, input_size, cell_size) {}
void operator()(OpKernelContext* ctx, perftools::gputools::Stream* stream,
const Device& d, const T forget_bias, const T cell_clip,
bool use_peephole, typename TTypes<T>::ConstMatrix x,
typename TTypes<T>::ConstMatrix cs_prev,
typename TTypes<T>::ConstMatrix h_prev,
typename TTypes<T>::ConstMatrix w,
typename TTypes<T>::ConstVec wci,
typename TTypes<T>::ConstVec wcf,
typename TTypes<T>::ConstVec wco,
typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,
typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,
typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,
typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,
typename TTypes<T>::Matrix icfo,
typename TTypes<T>::Matrix h) {
void operator()(
OpKernelContext* ctx, const Device& d, const T forget_bias,
const T cell_clip, bool use_peephole, typename TTypes<T>::ConstMatrix x,
typename TTypes<T>::ConstMatrix cs_prev,
typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
typename TTypes<T>::Matrix xh, typename TTypes<T>::Matrix i,
typename TTypes<T>::Matrix cs, typename TTypes<T>::Matrix f,
typename TTypes<T>::Matrix o, typename TTypes<T>::Matrix ci,
typename TTypes<T>::Matrix co, typename TTypes<T>::Matrix icfo,
typename TTypes<T>::Matrix h) {
// Concat xh = [x, h].
xh.slice(xh_x_offsets(), xh_x_extents()).device(d) = x;
xh.slice(xh_h_offsets(), xh_h_extents()).device(d) = h_prev;
// states1 = xh * w + b
typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
ctx, stream, d, false, false, T(1), const_xh, w, T(0), icfo);
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(ctx, d, false, false, T(1),
const_xh, w, T(0), icfo);
Eigen::array<Eigen::DenseIndex, 2> b_shape({1, b.dimensions()[0]});
Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({batch_size_, 1});
icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape);
@ -239,8 +230,8 @@ struct LSTMBlockCellBprop : public LSTMBlockCell {
: LSTMBlockCell(batch_size, input_size, cell_size) {}
void operator()(
OpKernelContext* ctx, perftools::gputools::Stream* stream,
const Device& d, bool use_peephole, typename TTypes<T>::ConstMatrix x,
OpKernelContext* ctx, const Device& d, bool use_peephole,
typename TTypes<T>::ConstMatrix x,
typename TTypes<T>::ConstMatrix cs_prev,
typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
@ -305,8 +296,8 @@ struct BlockLSTMBprop : public LSTMBlockCell {
: LSTMBlockCell(batch_size, input_size, cell_size) {}
void operator()(
OpKernelContext* ctx, perftools::gputools::Stream* stream,
const Device& d, bool use_peephole, typename TTypes<T>::ConstMatrix x,
OpKernelContext* ctx, const Device& d, bool use_peephole,
typename TTypes<T>::ConstMatrix x,
typename TTypes<T>::ConstMatrix cs_prev,
typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
@ -364,7 +355,7 @@ struct BlockLSTMBprop : public LSTMBlockCell {
typename TTypes<T>::ConstMatrix const_dicfo(dicfo.data(),
dicfo.dimensions());
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
ctx, stream, d, false, true, T(1), const_dicfo, w, T(0), xh_grad);
ctx, d, false, true, T(1), const_dicfo, w, T(0), xh_grad);
// xh.
xh.slice(xh_x_offsets(), xh_x_extents()).device(d) = x;
@ -377,7 +368,7 @@ struct BlockLSTMBprop : public LSTMBlockCell {
// w_grad.
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
ctx, stream, d, true, false, T(1), const_xh, const_dicfo, T(1), w_grad);
ctx, d, true, false, T(1), const_xh, const_dicfo, T(1), w_grad);
// b_grad.
b_grad.device(d) += dicfo.sum(Eigen::array<int, 1>({0}));

View File

@ -1005,7 +1005,7 @@ _linear = rnn_cell._linear
class AttentionCellWrapper(rnn_cell.RNNCell):
"""Basic attention cell wrapper.
Implementation based on https://arxiv.org/pdf/1601.06733.pdf.
Implementation based on https://arxiv.org/abs/1409.0473.
"""
def __init__(self, cell, attn_length, attn_size=None, attn_vec_size=None,

View File

@ -51,7 +51,7 @@ from tensorflow.contrib.slim.python.slim.data import parallel_reader
class DatasetDataProvider(data_provider.DataProvider):
def __init__(self, dataset, num_readers=1, shuffle=True, num_epochs=None,
common_queue_capacity=256, common_queue_min=128):
common_queue_capacity=256, common_queue_min=128, seed=None):
"""Creates a DatasetDataProvider.
Args:
@ -64,6 +64,7 @@ class DatasetDataProvider(data_provider.DataProvider):
common_queue_capacity: The capacity of the common queue.
common_queue_min: The minimum number of elements in the common queue after
a dequeue.
seed: The seed to use if shuffling.
"""
_, data = parallel_reader.parallel_read(
dataset.data_sources,
@ -72,7 +73,8 @@ class DatasetDataProvider(data_provider.DataProvider):
num_readers=num_readers,
shuffle=shuffle,
capacity=common_queue_capacity,
min_after_dequeue=common_queue_min)
min_after_dequeue=common_queue_min,
seed=seed)
items = dataset.decoder.list_items()
tensors = dataset.decoder.decode(data, items)

View File

@ -170,7 +170,8 @@ def parallel_read(data_sources,
shuffle=True,
dtypes=None,
capacity=256,
min_after_dequeue=128):
min_after_dequeue=128,
seed=None):
"""Reads multiple records in parallel from data_sources using n readers.
It uses a ParallelReader to read from multiple files in parallel using
@ -199,6 +200,7 @@ def parallel_read(data_sources,
capacity: integer, capacity of the common_queue.
min_after_dequeue: integer, minimum number of records in the common_queue
after dequeue. Needed for a good shuffle.
seed: A seed for RandomShuffleQueue.
Returns:
key, value: a tuple of keys and values from the data_source.
@ -212,7 +214,8 @@ def parallel_read(data_sources,
common_queue = data_flow_ops.RandomShuffleQueue(
capacity=capacity,
min_after_dequeue=min_after_dequeue,
dtypes=dtypes)
dtypes=dtypes,
seed=seed)
else:
common_queue = data_flow_ops.FIFOQueue(capacity=capacity, dtypes=dtypes)

View File

@ -471,7 +471,14 @@ def create_train_op(
'LossTensor is inf or nan')
# Ensure the train_tensor computes grad_updates.
return control_flow_ops.with_dependencies([grad_updates], total_loss)
train_op = control_flow_ops.with_dependencies([grad_updates], total_loss)
# Add the operation used for training to the 'train_op' collection
train_ops = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
if train_op not in train_ops:
train_ops.append(train_op)
return train_op
def _wait_for_step(sess, global_step, step):

View File

@ -301,6 +301,22 @@ class CreateTrainOpTest(tf.test.TestCase):
self.assertAllClose(mean, [0] * 4)
self.assertAllClose(variance, [1] * 4)
def testRecordTrainOpInCollection(self):
with tf.Graph().as_default():
tf.set_random_seed(0)
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
tf_labels = tf.constant(self._labels, dtype=tf.float32)
tf_predictions = LogisticClassifier(tf_inputs)
slim.losses.log_loss(tf_predictions, tf_labels)
total_loss = slim.losses.get_total_loss()
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
train_op = slim.learning.create_train_op(total_loss, optimizer)
# Make sure the training op was recorded in the proper collection
self.assertTrue(train_op in tf.get_collection(tf.GraphKeys.TRAIN_OP))
class TrainTest(tf.test.TestCase):

View File

@ -23,43 +23,54 @@ from tensorflow.contrib.metrics.python.ops import metric_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
def _accuracy(probabilities, targets):
predictions = math_ops.argmax(probabilities, 1)
# undo one-hot
labels = math_ops.argmax(targets, 1)
return metric_ops.streaming_accuracy(predictions, labels)
INFERENCE_PROB_NAME = 'inference'
INFERENCE_PRED_NAME = 'predictions'
def _r2(probabilities, targets):
def _accuracy(predictions, targets, weights=None):
return metric_ops.streaming_accuracy(predictions, targets, weights=weights)
def _r2(probabilities, targets, weights=None):
if targets.get_shape().ndims == 1:
targets = array_ops.expand_dims(targets, -1)
targets = math_ops.to_float(targets)
y_mean = math_ops.reduce_mean(targets, 0)
squares_total = math_ops.reduce_sum(math_ops.square(targets - y_mean), 0)
squares_residuals = math_ops.reduce_sum(math_ops.square(
targets - probabilities), 0)
score = 1 - math_ops.reduce_sum(squares_residuals / squares_total)
return metric_ops.streaming_mean(score)
return metric_ops.streaming_mean(score, weights=weights)
def _sigmoid_entropy(probabilities, targets):
def _squeeze_and_onehot(targets, depth):
targets = array_ops.squeeze(targets, squeeze_dims=[1])
return array_ops.one_hot(math_ops.to_int32(targets), depth)
def _sigmoid_entropy(probabilities, targets, weights=None):
return metric_ops.streaming_mean(losses.sigmoid_cross_entropy(
probabilities, targets))
probabilities, _squeeze_and_onehot(targets,
array_ops.shape(probabilities)[1])),
weights=weights)
def _softmax_entropy(probabilities, targets):
return metric_ops.streaming_mean(losses.softmax_cross_entropy(
probabilities, targets))
def _softmax_entropy(probabilities, targets, weights=None):
return metric_ops.streaming_mean(losses.sparse_softmax_cross_entropy(
probabilities, math_ops.to_int32(targets)),
weights=weights)
def _predictions(probabilities, unused_targets):
return math_ops.argmax(probabilities, 1)
def _predictions(predictions, unused_targets, **unused_kwargs):
return predictions
def _log_loss(probabilities, targets):
# targets doesn't have a shape coming in, log_loss isn't too happy about it.
targets = array_ops.reshape(targets, array_ops.shape(probabilities))
return metric_ops.streaming_mean(losses.log_loss(probabilities, targets))
def _class_log_loss(probabilities, targets, weights=None):
return metric_ops.streaming_mean(
losses.log_loss(probabilities,
_squeeze_and_onehot(targets,
array_ops.shape(probabilities)[1])),
weights=weights)
_EVAL_METRICS = {'sigmoid_entropy': _sigmoid_entropy,
@ -67,9 +78,21 @@ _EVAL_METRICS = {'sigmoid_entropy': _sigmoid_entropy,
'accuracy': _accuracy,
'r2': _r2,
'predictions': _predictions,
'log_loss': _log_loss}
'classification_log_loss': _class_log_loss}
_PREDICTION_KEYS = {'sigmoid_entropy': INFERENCE_PROB_NAME,
'softmax_entropy': INFERENCE_PROB_NAME,
'accuracy': INFERENCE_PRED_NAME,
'r2': INFERENCE_PROB_NAME,
'predictions': INFERENCE_PRED_NAME,
'classification_log_loss': INFERENCE_PROB_NAME}
def get_metric(metric_name):
"""Given a metric name, return the corresponding metric function."""
return _EVAL_METRICS[metric_name]
def get_prediction_key(metric_name):
return _PREDICTION_KEYS[metric_name]

View File

@ -17,10 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import threading
from tensorflow.contrib.learn.python.learn.learn_io import graph_io
from tensorflow.contrib.tensor_forest.python import constants
from tensorflow.python.framework import common_shapes
@ -35,8 +33,6 @@ from tensorflow.python.platform import tf_logging as logging
DATA_OPS_FILE = '_data_ops.so'
EXAMPLE_WEIGHT_NAME = '__weight__'
_data_ops = None
_ops_lock = threading.Lock()
@ -69,68 +65,28 @@ def Load():
def _ParseSparse(data):
"""Concat sparse tensors together.
A common use of sparse tensors is to treat strings as a sparse bit vector
with a large number of features representing the presence of all possible
values. Here we convert these strings to integer indices in a sparse bit
tensor. In order to pack each incoming feature into a single sparse tensor,
we add an offset to the converted indices to indicate that they came from
different features in the source data.
Args:
data: A dict of name -> Tensor.
Returns:
A single sparse tensor with float values and a 1-D input spec Tensor.
A single sparse tensor and a 1-D input spec Tensor.
Raises:
NotImplementedError: Combining dense and sparse tensors is not yet
NotImplementedError: Combining dense and sparse tensors is not
supported.
ValueError: If data contains non-string Tensors.
"""
convert_ops = Load()
# Sparse tensor indices have 63 bits to use for information. We use the
# minimum number of these (MSBs) for the offset, and pack the rest with the
# actual data.
num_features = len(data)
offset_bits = int(math.ceil(math.log(num_features, 2)))
# We condense data to 26 bits, see sparse_values_to_indices.cc
offset_increment = int(math.pow(2, 26 - offset_bits))
offset = 0
sparse_tensors = []
keys = None
weights = None
for k in sorted(data.keys()):
if k == graph_io.KEY_FEATURE_NAME:
keys = data[k]
elif k == EXAMPLE_WEIGHT_NAME:
weights = data[k]
elif isinstance(data[k], ops.SparseTensor):
# TODO(gilberth): Support mixed string/float sparse tensors.
# We currently only support string (categorical) data if we're using
# sparse tensors.
if data[k].dtype != dtypes.string:
raise ValueError('Only sparse tensors of type string are supported.')
sparse_indices = data[k].indices
sparse_values = data[k].values
new_shape = array_ops.concat(
0, [array_ops.slice(data[k].shape, [0], [1]), [offset_increment]])
if not isinstance(data[k], ops.SparseTensor):
raise NotImplementedError(
'Features should be either all sparse or all dense. Use a '
'feature engineering function to convert some of them.')
new_indices, new_values = convert_ops.sparse_values_to_indices(
sparse_indices,
sparse_values,
offset, offset_bits=offset_bits)
sparse_tensors.append(ops.SparseTensor(indices=new_indices,
values=new_values,
shape=new_shape))
else:
# Convert dense to sparse.
raise NotImplementedError('Dense to sparse conversion not implemented.')
return (sparse_ops.sparse_concat(1, sparse_tensors), keys, weights,
[constants.DATA_CATEGORICAL])
data_spec = [
constants.DATA_CATEGORICAL if data[data.keys()[0]].dtype == dtypes.string
else constants.DATA_FLOAT
]
return sparse_ops.sparse_concat(1, data.values()), data_spec
def _ParseDense(data):
@ -143,22 +99,20 @@ def _ParseDense(data):
A tuple of (single dense float Tensor, keys tensor (if exists), data spec).
"""
convert_ops = Load()
data_spec = [constants.DATA_CATEGORICAL if data[k].dtype == dtypes.string else
constants.DATA_FLOAT for k in sorted(data.keys())]
data_spec = [constants.DATA_CATEGORICAL if (data[k].dtype == dtypes.string or
data[k].dtype == dtypes.int32 or
data[k].dtype == dtypes.int64)
else constants.DATA_FLOAT for k in sorted(data.keys())]
data_spec = [constants.DATA_FLOAT] + data_spec
keys = None
weights = None
features = []
for k in sorted(data.keys()):
if k == graph_io.KEY_FEATURE_NAME:
keys = data[k]
elif k == EXAMPLE_WEIGHT_NAME:
weights = data[k]
if data[k].dtype == dtypes.string:
features.append(convert_ops.string_to_float(data[k]))
elif data[k].dtype == dtypes.int64 or data[k].dtype == dtypes.int32:
features.append(math_ops.to_float(data[k]))
else:
features.append(
convert_ops.string_to_float(data[k]) if data[k].dtype == dtypes.string
else data[k])
return array_ops.concat(1, features), keys, weights, data_spec
features.append(data[k])
return array_ops.concat(1, features), data_spec
def ParseDataTensorOrDict(data):
@ -187,8 +141,7 @@ def ParseDataTensorOrDict(data):
else:
return _ParseDense(data)
else:
return (data, None, None,
[constants.DATA_FLOAT] * data.get_shape().as_list()[1])
return (data, [constants.DATA_FLOAT] * data.get_shape().as_list()[1])
def ParseLabelTensorOrDict(labels):

View File

@ -19,7 +19,9 @@ from __future__ import print_function
import math
import random
import sys
from tensorflow.contrib.losses.python.losses import loss_ops
from tensorflow.contrib.tensor_forest.python import constants
from tensorflow.contrib.tensor_forest.python.ops import inference_ops
from tensorflow.contrib.tensor_forest.python.ops import training_ops
@ -429,8 +431,9 @@ class RandomForestGraphs(object):
return math_ops.reduce_mean(math_ops.to_float(array_ops.pack(sizes)))
# pylint: disable=unused-argument
def training_loss(self, features, labels):
return math_ops.neg(self.average_size())
def training_loss(self, features, labels, data_spec=None,
name='training_loss'):
return math_ops.neg(self.average_size(), name=name)
# pylint: disable=unused-argument
def validation_loss(self, features, labels):
@ -456,6 +459,63 @@ class RandomForestGraphs(object):
return ForestStats(tree_stats, self.params)
def one_hot_wrapper(num_classes, loss_fn):
"""Some loss functions take one-hot labels."""
def _loss(probs, targets):
one_hot_labels = array_ops.one_hot(
math_ops.to_int32(targets), num_classes,
on_value=1., off_value=0., dtype=dtypes.float32)
return loss_fn(probs, one_hot_labels)
return _loss
class TrainingLossForest(RandomForestGraphs):
"""Random Forest that uses training loss as the termination criteria."""
def __init__(self, params, loss_fn=None, **kwargs):
"""Initialize.
Args:
params: Like RandomForestGraphs, a ForestHParams object.
loss_fn: A function that takes probabilities and targets and returns
a loss for each example.
**kwargs: Keyword args to pass to superclass (RandomForestGraphs).
"""
self.loss_fn = loss_fn or one_hot_wrapper(params.num_classes,
loss_ops.log_loss)
self._loss = None
super(TrainingLossForest, self).__init__(params, **kwargs)
def _get_loss(self, features, labels, data_spec=None):
"""Constructs, caches, and returns the inference-based loss."""
if self._loss is not None:
return self._loss
def _average_loss():
probs = self.inference_graph(features, data_spec=data_spec)
return math_ops.reduce_sum(self.loss_fn(
probs, labels)) / math_ops.to_float(
array_ops.shape(features)[0])
self._loss = control_flow_ops.cond(
self.average_size() > 0, _average_loss,
lambda: constant_op.constant(sys.maxsize, dtype=dtypes.float32))
return self._loss
def training_graph(self, input_data, input_labels, data_spec=None,
**kwargs):
loss = self._get_loss(input_data, input_labels, data_spec=data_spec)
with ops.control_dependencies([loss.op]):
return super(TrainingLossForest, self).training_graph(
input_data, input_labels, **kwargs)
def training_loss(self, features, labels, data_spec=None,
name='training_loss'):
return array_ops.identity(
self._get_loss(features, labels, data_spec=data_spec), name=name)
class RandomTreeGraphs(object):
"""Builds TF graphs for random tree training and inference."""

View File

@ -12,6 +12,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
"//tensorflow/contrib/tfprof/python/tools/tfprof:model_analyzer",
"//tensorflow/contrib/tfprof/python/tools/tfprof:tfprof_logger",
],
)

View File

@ -20,434 +20,9 @@ and measures system performance.
4. Explore model based on name scope or graph structure.
5. Selectively grouping/filtering/accounting/ordering ops.
### Interfaces
tfprof can be used as CommandLine Interface (CLI) and Python API.
CLI locates in tensorflow/tools/tfprof.
Python API locates in tensorflow/contrib/tfprof.
Tutorial locates in tensorflow/tools/tfprof/README.md
[CLI Tutorials](#cli-tutorials):
It supports interactive mode for exploration and single-shot mode for
scripts. Outputs can be dumped to files or printed in terminal.
Python API Tutorials: Python API is not released yet.
## CLI Tutorials
Tutorials are based on a 32 layers ResNet.
TODO(xpan): Provide graph.pbtxt, model.ckpt, tfprof_log and run_meta download.
### Examples
1) Start `tfprof` command line tool
```shell
# Build the tool.
bazel build -c opt tensorflow/contrib/tfprof/...
# Help information, including detail 'option' instructions.
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof help
#
# The following commands will start tfprof interactive mode.
#
# Profile model shapes and parameters only.
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof \
--graph_path=/graph.pbtxt
#
# Additionally profile checkpoint statistics and values.
# Use '-account_type_regexes _checkpoint_variables' to select
# checkpoint tensors.
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof \
--graph_path=graph.pbtxt \
--checkpoint_path=model.ckpt
#
# Additionally profile ops requested memory and timing.
# See CLI Input Files section on generating run_meta file.
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof \
--graph_path=graph.pbtxt \
--run_meta_path=run_meta \
--checkpoint_path=model.ckpt
#
# tfprof_log is used to define customized op types and float ops.
# Use tfprof_logger.write_op_log() to create tfprof_log.
# See 11) in Examples section on generating tfprof_log file.
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof \
--graph_path=graph.pbtxt \
--run_meta_path=run_meta \
--op_log_path=tfprof_log \
--checkpoint_path=model.ckpt
```
Note that `graph.pbtxt` is an ASCII text format.
2) Press enter to show the default options
```shell
tfprof>
tfprof>
-max_depth 4
-min_bytes 0
-min_micros 0
-min_params 0
-min_float_ops 0
-device_regexes .*
-order_by name
-account_type_regexes Variable
-start_name_regexes .*
-trim_name_regexes
-show_name_regexes .*
-hide_name_regexes IsVariableInitialized_[0-9]+,save\/.*,^zeros[0-9_]*
-account_displayed_op_only false
# supported select fileds. Availability depends on --[run_meta|checkpoint|op_log]_path.
# [bytes|micros|params|float_ops|num_hidden_ops|tensor_value|device|op_types]
-select params
-viz false
-dump_to_file
```
3) I want to see the `BatchNorm`'s gamma value in checkpoint.
```shell
# Requires --graph_path, --checkpoint_path.
tfprof> scope -show_name_regexes unit_1_0.*gamma -select tensor_value -max_depth 5
_TFProfRoot ()
unit_1_0/shared_activation/init_bn/gamma ()
[1.80 2.10 2.06 1.91 2.26 1.86 1.81 1.37 1.78 1.85 1.96 1.54 2.04 2.34 2.22 1.99 ],
unit_1_0/sub2/bn2/gamma ()
[1.57 1.83 1.30 1.25 1.59 1.14 1.26 0.82 1.19 1.10 1.48 1.01 0.82 1.23 1.21 1.14 ],
```
4) I want to see my checkpoint tensors shape and number of parameters.
```shell
# Requires --graph_path, --checkpoint_path.
# Increase -max_depth to see all tensors.
tfprof> scope -account_type_regexes _checkpoint_variables -select params -max_depth 4
_TFProfRoot (--/930.58k params)
global_step (0/0 params)
init/init_conv/DW (3x3x3x16, 432/864 params)
pool_logit/DW (64x10, 640/1.28k params)
pool_logit/DW/Momentum (64x10, 640/640 params)
pool_logit/biases (10, 10/20 params)
pool_logit/biases/Momentum (10, 10/10 params)
unit_last/final_bn/beta (64, 64/128 params)
unit_last/final_bn/gamma (64, 64/128 params)
unit_last/final_bn/moving_mean (64, 64/64 params)
unit_last/final_bn/moving_variance (64, 64/64 params)
```
5) I defined an op named cost to calculate the loss. I want to know what ops
it depends on take a long time to run. Hint: Use the graph command to explore
graph dependencies.
```shell
# Requires --graph_path, --run_meta_path.
tfprof> graph -start_name_regexes cost.* -max_depth 100 -min_micros 10000 -select micros -account_type_regexes .*
_TFProfRoot (0us/3.61sec)
init/init_conv/Conv2D (11.75ms/3.10sec)
random_shuffle_queue_DequeueMany (3.09sec/3.09sec)
unit_1_0/sub2/conv2/Conv2D (74.14ms/3.19sec)
unit_1_3/sub2/conv2/Conv2D (60.75ms/3.34sec)
unit_2_4/sub2/conv2/Conv2D (73.58ms/3.54sec)
unit_3_3/sub2/conv2/Conv2D (10.26ms/3.60sec)
```
6) I want to know the expensive operations during the back propagation.
Hint: tensorflow prepend gradient to your defined name scopes. Use the scope
command to explore based on name scope hierarchies.
```shell
# Requires --graph_path, --run_meta_path.
tfprof> scope -start_name_regexes gradient.* -max_depth 100 -min_micros 20000 -select micros -account_type_regexes .*
_TFProfRoot (0us/2.29sec)
gradients/unit_1_0/sub1/conv1/Conv2D_grad/Conv2DBackpropFilter (54.96ms/54.96ms)
gradients/unit_1_0/sub2/conv2/Conv2D_grad/Conv2DBackpropFilter (83.63ms/83.63ms)
gradients/unit_1_1/sub1/conv1/Conv2D_grad/Conv2DBackpropFilter (99.25ms/99.25ms)
gradients/unit_1_2/sub1/conv1/Conv2D_grad/Conv2DBackpropFilter (95.40ms/95.40ms)
gradients/unit_1_2/sub2/conv2/Conv2D_grad/Conv2DBackpropFilter (99.83ms/99.83ms)
gradients/unit_1_3/sub1/conv1/Conv2D_grad/Conv2DBackpropFilter (95.39ms/95.39ms)
...
```
7) Show the number of float operations in the model.
Note: float operations calculation depends on
1) op.RegisterStatistics. If an op doesnt
have RegisterStatistics defined, its float operations cannot be counted.
2) fully defined shape is also necessary in order to calculate flops.
float operations number is provided by tensorflow::tfprof::OpLog logged from
Python API.
```shell
# Requires --graph_path, --op_log_path.
tfprof> scope -min_float_ops 1 -max_depth 10 -select float_ops -account_type_regexes .*
_TFProfRoot (0/17.63b flops)
gradients/pool_logit/xw_plus_b/MatMul_grad/MatMul (163.84k/163.84k flops)
gradients/pool_logit/xw_plus_b/MatMul_grad/MatMul_1 (163.84k/163.84k flops)
init/init_conv/Conv2D (113.25m/113.25m flops)
pool_logit/xw_plus_b (1.28k/165.12k flops)
pool_logit/xw_plus_b/MatMul (163.84k/163.84k flops)
unit_1_0/sub1/conv1/Conv2D (603.98m/603.98m flops)
unit_1_0/sub2/conv2/Conv2D (603.98m/603.98m flops)
unit_1_1/sub1/conv1/Conv2D (603.98m/603.98m flops)
unit_1_1/sub2/conv2/Conv2D (603.98m/603.98m flops)
...
```
8) Show the number of parameters of all `tf.trainable_variables()` in the model.
```shell
# Requires --graph_path --op_log_path.
# store option for future commands.
tfprof> set -account_type_regexes _trainable_variables
tfprof> scope -max_depth 4 -select params
_TFProfRoot (--/464.15k params)
init/init_conv/DW (3x3x3x16, 432/432 params)
pool_logit/DW (64x10, 640/640 params)
pool_logit/biases (10, 10/10 params)
unit_last/final_bn/beta (64, 64/64 params)
unit_last/final_bn/gamma (64, 64/64 params)
```
Where does “_trainable_variables” come from? It is from the OpLog file
generated by write_op_log() Python API. write_op_log() help users create some
common op types implicitly. Users can define their own op types and log it
through the write_op_log() API.
9) What if Im lazy and dont want to define op type? I have given my ops
well-defined names in my models code. And want to use names to select a group
of ops. Lets try it!
```shell
tfprof> set -account_type_regexes .*
tfprof> scope -show_name_regexes unit_2_1.*DW -max_depth 100 -account_displayed_op_only
_TFProfRoot (0/18.43k params)
unit_2_1/sub1/conv1/DW (3x3x32x32, 9.22k/9.22k params)
unit_2_1/sub2/conv2/DW (3x3x32x32, 9.22k/9.22k params)
```
The above command allows you to filter ops that match specific names.
`-account_displayed_op_only` asks tfprof to only account ops displayed
in terminal. Otherwise, tfprof accounts all ops matched by
`-account_type_regexes` recursively even if they are hidden due to some
options such as -max_depth.
10) TensorFlow has built-in op types. For example, built-in op type `Variable`
seems to include `Variable's` created by your model. However, be careful when
depending on it because TensorFlow creates extra `Variable` ops implicitly and
the implicitly created ops can have the same prefix as the `Variable's` you
defined.
In the following example, extra `Variables` are created and “/Momentum” is
appended to their names. This might cause you “model capacity” calculation
to get wrong.
```shell
tfprof> scope -account_type_regexes Variable -max_depth 4 -select params
_TFProfRoot (--/930.58k params)
global_step (1/1 params)
init/init_conv/DW (3x3x3x16, 432/864 params)
pool_logit/DW (64x10, 640/1.28k params)
pool_logit/DW/Momentum (64x10, 640/640 params)
pool_logit/biases (10, 10/20 params)
pool_logit/biases/Momentum (10, 10/10 params)
unit_last/final_bn/beta (64, 64/128 params)
unit_last/final_bn/gamma (64, 64/128 params)
unit_last/final_bn/moving_mean (64, 64/64 params)
unit_last/final_bn/moving_variance (64, 64/64 params)
```
11) A example of defining extra op type for ops using `OpLog`
First, in Python code, create an `OpLog` proto and add op type
information to it:
```python
op_log = tfprof_log_pb2.OpLog()
entry = op_log.log_entries.add()
entry.name = 'pool_logit/DW'
entry.types.append('pool_logit')
entry = op_log.log_entries.add()
entry.name = 'pool_logit/biases'
# Alternatively:
# var = tf.get_variable(xxx)
# entry.name = var.op.name
entry.types.append('pool_logit')
```
Second, call write_op_log to write the OpLog proto.
```python
tf.tfprof.tfprof_logger.write_op_log(sess.graph, /tmp/my_op_log_dir, op_log)
```
Third, when starting the tfprof tool, specify
"--op_log_path /tmp/my_op_log_dir/op_log"
```shell
tfprof> scope -account_type_regexes pool_logit -max_depth 4 -select params
_TFProfRoot (--/650 params)
pool_logit/DW (64x10, 640/640 params)
pool_logit/biases (10, 10/10 params)
```
Note that when you call
`tf.tfprof.tfprof_logger.write_op_log(...)`, the tool adds all `Variables`
inside `tf.trainable_variables()` to `_trainable_variables`.
12) Run tfprof in one-shot mode and dump result to file.
```shell
# Printed to stdout if --dump_to_file is not set.
tfprof scope --graph_path /cns/ij-d/home/xpan/tfprof/graph.pbtxt \
--max_depth 3 \
--dump_to_file "/tmp/dump"
Reading Files...
Parsing GraphDef...
Preparing Views...
cat /tmp/dump
_TFProfRoot (--/930.58k params)
global_step (0/0 params)
pool_logit/DW (64x10, 640/1.28k params)
pool_logit/biases (10, 10/20 params)
```
13) Analyze how balanced Variable are on parameter servers.
In this tutorial, I'm going to use a seq2seq model, which are split
on several gpus at workers and several parameter servers.
In tfprof, 'device' is an op_type. For example, if op1 and op2 are placed on
gpu0. They share an op_type called 'gpu0'.
```shell
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof \
--graph_path ~/tfprof/textsum/graph.pbtxt \
--run_meta_path ~/tfprof/textsum/run_meta
# Looks like ps task 1 is holding twice more parameters than task 0.
tfprof> scope -select device,params -account_type_regexes .*ps.*task:0.* -max_depth 1
_TFProfRoot (--/25.81m params)
tfprof> scope -select device,params -account_type_regexes .*ps.*task:1.* -max_depth 1
_TFProfRoot (--/58.84m params)
```
### CLI Input Files
tfprof command line inference (CLI) loads dumped files from a tensorflow model.
Convert them into in-memory data structures. To use it, users need to specify
the locations of the dumped files. The following are the dumped files loaded
by tfprof:
<b>--graph_path:</b> GraphDef text file (required). Used to build in-memory
representation of the model. For example, graph.pbtxt written by tf.Supervisor
is a candidate. If you are not using tf.Supervisor, you can easily get GraphDef
using tf.Graph.as_graph_def() or other API.
<b>--run_meta_path:</b> tensorflow::RunMetadata.
Used to get the memory and time consumption of
each op of the model. Users need to enable it. For example, the following code
snippet writes a RunMetadata file:
```python
run_options = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
# Once a while, call it the get the RunMeta.
_ = self._sess.run(..., options=run_options, run_metadata=run_metadata)
with gfile.Open(os.path.join(output_dir, "run_meta"), "w") as f:
f.write(run_metadata.SerializeToString())
```
<b>--op_log_path:</b>
tensorflow::tfprof::OpLog. A proto used to provide extra op information
for ops. By giving a group of ops a type name, users can easily aggregate the
statistics for those ops without accidently missing or including extra ops.
tfprof exposes the following Python API to add op information and logging.
```python
tf.contrib.tfprof.tfprof_logger.write_op_log(graph, log_dir, op_log=None)
```
<b>--checkpoint_path:</b>
TensorFlow checkpoint. It defines _checkpoint_variable op type. It also
provides checkpointed tensors' values.
## Design
### In-memory representation
<b>Scope:</b> This representation organizes ops based on name scope hierarchy,
similar to filesystem hierarchy. Hence, it is essentially a tree data structure.
For example op1 with name “name1/name2” is a child of op2 with name “name1”.
<b>Graph:</b> The representation organizes ops based on op inputs. Hence it is
a graph structure. The graph is a “directed acyclic graph” (hopefully), with
direction from “output to input”. The direction is design this way so that users
can trace from “result” to its “sources”.
### Command line options
tfprofs major goals are to measure system performance and quicly analyze
model architectures. Hence, its commands and options should allow users to achieve
these 2 goals easily.
<b>graph:</b> It is expected that users will mostly use graph representation to
debug system performance. Hence, tfprof supports graph command, which pulls the
graph in-memory representation described above.
<b>scope:</b> It is expected that some users might want to explore their model
statistics using the name scope information they defined in the Python codes.
Hence, tfprof supports “scope” command, which pulls the tree in-memory
representation.
<b>set:</b> It is used to store the options so that user doesnt need to
re-type the same option again and again in the follow up command line. Note that
tfprof has traditional terminals history and auto-complete support.
<b>help:</b> print help information.
<b>Options:</b> Run “tfprof help” to get detailed explanations.
```python
"-max_depth",
"-min_bytes",
"-min_micros",
"-min_params",
"-min_float_ops",
"-order_by",
"-account_type_regexes",
"-start_name_regexes",
"-trim_name_regexes",
"-show_name_regexes",
"-hide_name_regexes",
"-account_displayed_op_only",
"-select",
"-viz", # Only supported for graph command.
"-dump_to_file",
```
A key design is that stats are aggregated from descendants up to ancestors.
`-account_type_regexes` is used to decide which ops stat is accounted. It makes
decision based on op type. Usually set it to `.*` if no extra type information
is added to the ops using OpLog. Intuitively, only accounted ops are displayed.
`-min/max` and `-show/hide/trim/start` options are only used the optionally
displayed or hide ops based on ops name and stats. However, they dont prevent
tfprof from accounting stats of hidden ops. Hence, the stat of a op can be
aggregated by its parent even if it is hidden. `-account_displayed_op_only` is
an option to break this rule. When it is set, only displayed ops are accounted.
Regexes are all comma-separated, for example `-show_name_regexes`
`regex1.*,regex2.*`. It is designed this way because it is convenient and comma
is not expected to show up in op names.
`-order_by` is used to order displayed ops. Displayed ops at the same hierarchy
(notice the indent printed) are sorted according to order_by.
## Future Work
* Load SummaryWriter event logs so that it can show the latest summary value.
* Better sorting and aggregation of outputs. Easier comprehension.
* Currently, shape information is based on `graph.pbtxt`. When the shape
information is incomplete, tfprof ignores it. See if it can use `RunMetadata`
and `Checkpoint` to complete shape information.
Enjoy!

View File

@ -17,5 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer
from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_logger
from tensorflow.python.util.all_util import make_all

View File

@ -3,14 +3,36 @@ licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:public"])
load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
py_library(
name = "model_analyzer",
srcs = ["model_analyzer.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/tfprof/python/tools/tfprof:pywrap_tensorflow_print_model_analysis_lib",
"//tensorflow/contrib/tfprof/python/tools/tfprof:tfprof_logger",
"//tensorflow/tools/tfprof:protos_all_py",
],
)
py_test(
name = "model_analyzer_test",
srcs = ["model_analyzer_test.py"],
srcs_version = "PY2AND3",
deps = [
":model_analyzer",
"//tensorflow:tensorflow_py",
],
)
py_library(
name = "tfprof_logger",
srcs = ["tfprof_logger.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/tfprof/tools/tfprof:protos_all_py",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/tools/tfprof:protos_all_py",
],
)
@ -20,7 +42,34 @@ tf_py_test(
additional_deps = [
":tfprof_logger",
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/tfprof/tools/tfprof:protos_all_py",
"//tensorflow/tools/tfprof:protos_all_py",
],
)
tf_py_wrap_cc(
name = "pywrap_tensorflow_print_model_analysis_lib",
srcs = ["pywrap_tensorflow_print_model_analysis.i"],
swig_includes = [
"//tensorflow/python:lib/core/strings.i",
"//tensorflow/python:platform/base.i",
],
deps = [
"//tensorflow/core:framework_headers_lib",
"//tensorflow/tools/tfprof/internal:print_model_analysis_hdr",
"//util/python:python_headers",
],
)
py_test(
name = "print_model_analysis_test",
srcs = ["print_model_analysis_test.py"],
srcs_version = "PY2AND3",
deps = [
":pywrap_tensorflow_print_model_analysis_lib",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/tools/tfprof:protos_all_py",
],
)

View File

@ -0,0 +1,188 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model Analyzer.
Analyze model, including shape, params, time, memory, structure, etc.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.tfprof.python.tools.tfprof import pywrap_tensorflow_print_model_analysis_lib as print_mdl
from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_logger
from tensorflow.tools.tfprof import tfprof_options_pb2
from tensorflow.tools.tfprof import tfprof_output_pb2
# pylint: disable=bad-whitespace
# pylint: disable=bad-continuation
# Example tfprof_options presets for the print_model_analysis API below. Each
# dict carries the complete key set that print_model_analysis copies into an
# OptionsProto; callers typically start from one preset and override keys.
#
# Show the parameter statistics of trainable variables.
TRAINABLE_VARS_PARAMS_STAT_OPTIONS = {
    'max_depth': 10000,
    'min_bytes': 0,
    'min_micros': 0,
    'min_params': 0,
    'min_float_ops': 0,
    'device_regexes': ['.*'],
    'order_by': 'name',
    # Only account ops tagged as trainable variables by the default OpLog.
    'account_type_regexes': [tfprof_logger.TRAINABLE_VARIABLES],
    'start_name_regexes': ['.*'],
    'trim_name_regexes': [],
    'show_name_regexes': ['.*'],
    'hide_name_regexes': [],
    'account_displayed_op_only': True,
    'select': ['params'],
    'viz': False,
    'dump_to_file': ''
}

# Show the number of float operations.
FLOAT_OPS_OPTIONS = {
    'max_depth': 10000,
    'min_bytes': 0,
    'min_micros': 0,
    'min_params': 0,
    # Hide ops that report zero floating point operations.
    'min_float_ops': 1,
    'device_regexes': ['.*'],
    'order_by': 'float_ops',
    'account_type_regexes': ['.*'],
    'start_name_regexes': ['.*'],
    'trim_name_regexes': [],
    'show_name_regexes': ['.*'],
    'hide_name_regexes': [],
    'account_displayed_op_only': True,
    'select': ['float_ops'],
    'viz': False,
    'dump_to_file': ''
}

# Show number of parameters on parameter server 0.
# It is recommended to provide the `run_meta` argument
# to have complete device placement info.
PRINT_PARAMS_ON_DEVICE = {
    'max_depth': 1,
    'min_bytes': 0,
    'min_micros': 0,
    'min_params': 0,
    'min_float_ops': 0,
    'device_regexes': ['.*'],
    'order_by': 'name',
    'account_type_regexes': ['.*ps.*task:0.*'],
    'start_name_regexes': ['.*'],
    'trim_name_regexes': [],
    'show_name_regexes': ['.*'],
    'hide_name_regexes': [],
    'account_displayed_op_only': False,
    'select': ['device', 'params'],
    'viz': False,
    'dump_to_file': ''
}

# Show the timing stats and memory demands.
PRINT_ALL_TIMING_MEMORY = {
    'max_depth': 10000,
    'min_bytes': 1,  # Only >=1
    'min_micros': 1,  # Only >=1
    'min_params': 0,
    'min_float_ops': 0,
    'device_regexes': ['.*'],
    'order_by': 'name',
    'account_type_regexes': ['.*'],
    'start_name_regexes': ['.*'],
    'trim_name_regexes': [],
    'show_name_regexes': ['.*'],
    'hide_name_regexes': [],
    'account_displayed_op_only': True,
    'select': ['micros', 'bytes'],
    'viz': False,
    'dump_to_file': ''
}
# pylint: enable=bad-whitespace
# pylint: enable=bad-continuation
def print_model_analysis(graph,
                         run_meta=None,
                         op_log=None,
                         tfprof_cmd='scope',
                         tfprof_options=TRAINABLE_VARS_PARAMS_STAT_OPTIONS):
  """Print model statistics.

  Prints the model statistics to stdout and also returns the results in a
  TFProfNode proto. See go/tfprof or run the tfprof tool:
  'bazel run third_party/tensorflow/tools/tfprof help'

  Examples:
    Show the parameter/shape statistics of tf.trainable_variables().
      print_model_analysis(sess.graph).

    Show number of float ops. Only ops with RegisterStatistics defined
    are counted.
      show_float_op_opts = model_analyzer.FLOAT_OPS_OPTIONS
      print_model_analysis(sess.graph, tfprof_options=show_float_op_opts)

  Args:
    graph: tf.Graph.
    run_meta: tensorflow::RunMetadata proto. When provided, also shows valid
        timing and memory information when 'select' option contains
        'micros' and 'bytes'.
    op_log: tensorflow::tfprof::OpLog proto. Users can use this proto to
        group together ops and use an op_type to select the group.
    tfprof_cmd: string. Either 'scope' or 'graph'. 'scope' view organizes
        ops using their name scopes. 'graph' view organizes ops using
        their graph inputs.
    tfprof_options: See 'tfprof help' for details. NOTE: the default is a
        module-level dict shared by all callers; mutate a copy rather than
        the preset itself if you need per-call changes.

  Returns:
    TFProfNode proto. Side effect: a formatted output to stdout.
  """
  # Merge the default OpLog (trainable-variable tags, registered flop stats)
  # with any user-supplied log.
  # pylint: disable=protected-access
  op_log = tfprof_logger._merge_default_with_oplog(graph, op_log, run_meta)
  # pylint: enable=protected-access

  # Copy the option dict into the OptionsProto consumed by the C++ layer.
  opts = tfprof_options_pb2.OptionsProto()
  opts.max_depth = tfprof_options['max_depth']
  opts.min_bytes = tfprof_options['min_bytes']
  opts.min_micros = tfprof_options['min_micros']
  opts.min_params = tfprof_options['min_params']
  opts.min_float_ops = tfprof_options['min_float_ops']
  opts.device_regexes.extend(tfprof_options['device_regexes'])
  opts.order_by = tfprof_options['order_by']
  opts.account_type_regexes.extend(tfprof_options['account_type_regexes'])
  opts.start_name_regexes.extend(tfprof_options['start_name_regexes'])
  opts.trim_name_regexes.extend(tfprof_options['trim_name_regexes'])
  opts.show_name_regexes.extend(tfprof_options['show_name_regexes'])
  opts.hide_name_regexes.extend(tfprof_options['hide_name_regexes'])
  opts.account_displayed_op_only = tfprof_options['account_displayed_op_only']
  opts.select.extend(tfprof_options['select'])
  opts.viz = tfprof_options['viz']
  opts.dump_to_file = tfprof_options['dump_to_file']

  # Absent protos are passed to the C++ layer as empty byte strings.
  run_meta_str = run_meta.SerializeToString() if run_meta else b''
  op_log_str = op_log.SerializeToString() if op_log else b''

  tfprof_node = tfprof_output_pb2.TFProfNode()
  tfprof_node.ParseFromString(
      print_mdl.PrintModelAnalysis(
          graph.as_graph_def().SerializeToString(), run_meta_str, op_log_str,
          tfprof_cmd.encode('utf-8'), opts.SerializeToString()))
  return tfprof_node

View File

@ -0,0 +1,84 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
class PrintModelAnalysisTest(tf.test.TestCase):
  # End-to-end tests for tf.contrib.tfprof.model_analyzer.print_model_analysis.

  def _BuildSmallModel(self):
    # Two stacked conv layers: DW (3x3x3x6 = 162 params) and
    # DW2 (2x2x6x12 = 288 params) -> 450 trainable params total.
    image = tf.zeros([2, 6, 6, 3])
    kernel = tf.get_variable(
        'DW', [3, 3, 3, 6],
        tf.float32,
        initializer=tf.random_normal_initializer(stddev=0.001))
    x = tf.nn.conv2d(image, kernel, [1, 2, 2, 1], padding='SAME')
    kernel = tf.get_variable(
        'DW2', [2, 2, 6, 12],
        tf.float32,
        initializer=tf.random_normal_initializer(stddev=0.001))
    x = tf.nn.conv2d(x, kernel, [1, 2, 2, 1], padding='SAME')
    return x

  def testDumpToFile(self):
    # NOTE(review): this mutates the shared module-level preset dict in place;
    # other tests using the same preset will observe the changed keys.
    opts = tf.contrib.tfprof.model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS
    opts['dump_to_file'] = os.path.join(tf.test.get_temp_dir(), 'dump')

    with tf.Session() as sess, tf.device('/cpu:0'):
      _ = self._BuildSmallModel()
      tf.contrib.tfprof.model_analyzer.print_model_analysis(
          sess.graph, tfprof_options=opts)

      # The dump file should list the parameter counts of both variables.
      with tf.gfile.Open(opts['dump_to_file'], 'r') as f:
        self.assertEqual(u'_TFProfRoot (--/450 params)\n'
                         '  DW (3x3x3x6, 162/162 params)\n'
                         '  DW2 (2x2x6x12, 288/288 params)\n',
                         f.read().decode('utf-8'))

  def testSelectEverything(self):
    opts = tf.contrib.tfprof.model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS
    opts['dump_to_file'] = os.path.join(tf.test.get_temp_dir(), 'dump')
    opts['account_type_regexes'] = ['.*']
    opts['select'] = [
        'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device', 'op_types'
    ]

    with tf.Session() as sess, tf.device('/cpu:0'):
      x = self._BuildSmallModel()
      sess.run(tf.initialize_all_variables())
      # Run once with full tracing so 'bytes' (and timing) stats are recorded.
      run_meta = tf.RunMetadata()
      _ = sess.run(x,
                   options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                   run_metadata=run_meta)

      tf.contrib.tfprof.model_analyzer.print_model_analysis(
          sess.graph, run_meta, tfprof_options=opts)

      with tf.gfile.Open(opts['dump_to_file'], 'r') as f:
        # pylint: disable=line-too-long
        self.assertEqual(
            '_TFProfRoot (0/450 params, 0/10.44k flops, 0B/5.28KB, _kTFScopeParent)\n Conv2D (0/0 params, 5.83k/5.83k flops, 432B/432B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D)\n Conv2D_1 (0/0 params, 4.61k/4.61k flops, 384B/384B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D)\n DW (3x3x3x6, 162/162 params, 0/0 flops, 648B/1.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Variable|_trainable_variables)\n DW/Assign (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Assign)\n DW/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n DW/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Add)\n DW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|RandomStandardNormal)\n DW/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Const)\n DW/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Mul)\n DW/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Const)\n DW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Const)\n DW/read (0/0 params, 0/0 flops, 648B/648B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity)\n DW2 (2x2x6x12, 288/288 params, 0/0 flops, 1.15KB/2.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Variable|_trainable_variables)\n DW2/Assign (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Assign)\n DW2/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n DW2/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Add)\n DW2/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|RandomStandardNormal)\n DW2/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Const)\n DW2/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Mul)\n DW2/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Const)\n DW2/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Const)\n DW2/read (0/0 params, 0/0 flops, 1.15KB/1.15KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity)\n init (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|NoOp)\n zeros (0/0 params, 0/0 flops, 864B/864B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Const)\n',
            f.read().decode('utf-8'))
        # pylint: enable=line-too-long


if __name__ == '__main__':
  tf.test.main()

View File

@ -0,0 +1,238 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""print_model_analysis test."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.contrib.tfprof.python.tools.tfprof import pywrap_tensorflow_print_model_analysis_lib as print_mdl
from tensorflow.tools.tfprof import tfprof_options_pb2
from tensorflow.tools.tfprof import tfprof_output_pb2
# pylint: disable=bad-whitespace
# pylint: disable=bad-continuation
# Options for testPrintModelAnalysis below: account every op, show everything,
# and select only parameter counts. Unlike the presets in model_analyzer.py,
# this dict has no 'dump_to_file' key, so output goes to stdout.
TEST_OPTIONS = {
    'max_depth': 10000,
    'min_bytes': 0,
    'min_micros': 0,
    'min_params': 0,
    'min_float_ops': 0,
    'device_regexes': ['.*'],
    'order_by': 'name',
    'account_type_regexes': ['.*'],
    'start_name_regexes': ['.*'],
    'trim_name_regexes': [],
    'show_name_regexes': ['.*'],
    'hide_name_regexes': [],
    'account_displayed_op_only': True,
    'select': ['params'],
    'viz': False
}
# pylint: enable=bad-whitespace
# pylint: enable=bad-continuation
class PrintModelAnalysisTest(tf.test.TestCase):
  # Tests the SWIG-wrapped PrintModelAnalysis entry point directly, without
  # going through the model_analyzer convenience wrapper.

  def _BuildSmallModel(self):
    # A single conv layer with one 6x6x3x6 kernel: 648 parameters.
    image = tf.zeros([2, 6, 6, 3])
    kernel = tf.get_variable(
        'DW', [6, 6, 3, 6],
        tf.float32,
        initializer=tf.random_normal_initializer(stddev=0.001))
    x = tf.nn.conv2d(image, kernel, [1, 2, 2, 1], padding='SAME')
    return x

  def testPrintModelAnalysis(self):
    # Mirror TEST_OPTIONS into the OptionsProto consumed by the C++ layer.
    opts = tfprof_options_pb2.OptionsProto()
    opts.max_depth = TEST_OPTIONS['max_depth']
    opts.min_bytes = TEST_OPTIONS['min_bytes']
    opts.min_micros = TEST_OPTIONS['min_micros']
    opts.min_params = TEST_OPTIONS['min_params']
    opts.min_float_ops = TEST_OPTIONS['min_float_ops']
    for p in TEST_OPTIONS['device_regexes']:
      opts.device_regexes.append(p)
    opts.order_by = TEST_OPTIONS['order_by']
    for p in TEST_OPTIONS['account_type_regexes']:
      opts.account_type_regexes.append(p)
    for p in TEST_OPTIONS['start_name_regexes']:
      opts.start_name_regexes.append(p)
    for p in TEST_OPTIONS['trim_name_regexes']:
      opts.trim_name_regexes.append(p)
    for p in TEST_OPTIONS['show_name_regexes']:
      opts.show_name_regexes.append(p)
    for p in TEST_OPTIONS['hide_name_regexes']:
      opts.hide_name_regexes.append(p)
    opts.account_displayed_op_only = TEST_OPTIONS['account_displayed_op_only']
    for p in TEST_OPTIONS['select']:
      opts.select.append(p)
    opts.viz = TEST_OPTIONS['viz']

    with tf.Session() as sess, tf.device('/cpu:0'):
      _ = self._BuildSmallModel()
      # Empty byte strings for run_meta and op_log: no timing/memory info and
      # no extra op grouping; 'scope' selects the name-scope view.
      tfprof_pb = tfprof_output_pb2.TFProfNode()
      tfprof_pb.ParseFromString(
          print_mdl.PrintModelAnalysis(sess.graph.as_graph_def(
          ).SerializeToString(), b'', b'', b'scope', opts.SerializeToString()))

      # Expected scope-view tree for the single-conv model built above.
      expected_pb = tfprof_output_pb2.TFProfNode()
      text_format.Merge(r"""name: "_TFProfRoot"
exec_micros: 0
requested_bytes: 0
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 648
children {
name: "Conv2D"
exec_micros: 0
requested_bytes: 0
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 0
device: "/device:CPU:0"
float_ops: 0
total_float_ops: 0
}
children {
name: "DW"
exec_micros: 0
requested_bytes: 0
parameters: 648
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 648
device: "/device:CPU:0"
children {
name: "DW/Assign"
exec_micros: 0
requested_bytes: 0
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 0
device: "/device:CPU:0"
float_ops: 0
total_float_ops: 0
}
children {
name: "DW/Initializer"
exec_micros: 0
requested_bytes: 0
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 0
children {
name: "DW/Initializer/random_normal"
exec_micros: 0
requested_bytes: 0
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 0
device: "/device:CPU:0"
children {
name: "DW/Initializer/random_normal/RandomStandardNormal"
exec_micros: 0
requested_bytes: 0
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 0
device: "/device:CPU:0"
float_ops: 0
total_float_ops: 0
}
children {
name: "DW/Initializer/random_normal/mean"
exec_micros: 0
requested_bytes: 0
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 0
device: "/device:CPU:0"
float_ops: 0
total_float_ops: 0
}
children {
name: "DW/Initializer/random_normal/mul"
exec_micros: 0
requested_bytes: 0
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 0
device: "/device:CPU:0"
float_ops: 0
total_float_ops: 0
}
children {
name: "DW/Initializer/random_normal/shape"
exec_micros: 0
requested_bytes: 0
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 0
device: "/device:CPU:0"
float_ops: 0
total_float_ops: 0
}
children {
name: "DW/Initializer/random_normal/stddev"
exec_micros: 0
requested_bytes: 0
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 0
device: "/device:CPU:0"
float_ops: 0
total_float_ops: 0
}
float_ops: 0
total_float_ops: 0
}
float_ops: 0
total_float_ops: 0
}
children {
name: "DW/read"
exec_micros: 0
requested_bytes: 0
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 0
device: "/device:CPU:0"
float_ops: 0
total_float_ops: 0
}
float_ops: 0
total_float_ops: 0
}
children {
name: "zeros"
exec_micros: 0
requested_bytes: 0
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 0
device: "/device:CPU:0"
float_ops: 0
total_float_ops: 0
}
float_ops: 0
total_float_ops: 0""", expected_pb)
      self.assertEqual(expected_pb, tfprof_pb)


if __name__ == '__main__':
  tf.test.main()

View File

@ -0,0 +1,43 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// SWIG interface exposing tensorflow::tfprof::PrintModelAnalysis to Python.
// The wrapped function takes serialized proto strings and returns a
// serialized TFProfNode string (see print_model_analysis.h).
%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i"

%{
#include "tensorflow/tools/tfprof/internal/print_model_analysis.h"
#include "tensorflow/core/framework/types.h"
%}

// Accept any Python object convertible via _PyObjAs for `const string&`
// parameters (conversion into a stack-local temp, passed by address).
%typemap(typecheck) const string & = char *;
%typemap(in) const string& (string temp) {
  if (!_PyObjAs<string>($input, &temp)) return NULL;
  $1 = &temp;
}
// Return `const string&` results as Python byte strings.
%typemap(out) const string& {
  $result = PyString_FromStringAndSize($1->data(), $1->size());
}
%apply const string & {string &};
%apply const string & {string *};

// Hide everything declared in the header except the single entry point.
%ignoreall
%unignore tensorflow;
%unignore tensorflow::tfprof;
%unignore tensorflow::tfprof::PrintModelAnalysis;
%include "tensorflow/tools/tfprof/internal/print_model_analysis.h"
%unignoreall

View File

@ -24,8 +24,8 @@ import os
import sys
import tensorflow as tf
from tensorflow.contrib.tfprof.tools.tfprof import tfprof_log_pb2
from tensorflow.python.framework import ops
from tensorflow.tools.tfprof import tfprof_log_pb2
TRAINABLE_VARIABLES = '_trainable_variables'
REGISTERED_FLOP_STATS = 'flops'
@ -85,7 +85,7 @@ def _get_logged_ops(graph, run_meta=None):
if node.name not in logged_ops:
entry = tfprof_log_pb2.OpLogEntry()
entry.name = node.name
entry.float_ops = stats.value
entry.float_ops = int(stats.value)
logged_ops[entry.name] = entry
for v in graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):

View File

@ -32,8 +32,9 @@ like to store state in the forward direction across segments of an example.
To resample data with replacement on a per-example basis, use
['rejection_sample'](#rejection_sample) or
['resample_at_rate'](#resample_at_rate). For `rejection_sample`, provide
a boolean Tensor describing whether to accept or reject. For `resample_at_rate`,
providing the desired rate for each example. If you wish to specify relative
a boolean Tensor describing whether to accept or reject. Resulting batch sizes
are always the same. For `resample_at_rate`, provide the desired rate for each
example. Resulting batch sizes may vary. If you wish to specify relative
rates, rather than absolute ones, use ['weighted_resample'](#weighted_resample)
(which also returns the actual resampling rate used for each output example).

View File

@ -16,8 +16,10 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/immutable_constant_op.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@ -45,13 +47,27 @@ class NodeConverter {
const DataType tensor_data_type = tensor_proto.dtype();
const TensorShapeProto tensor_shape = tensor_proto.tensor_shape();
// Check that the tensor type is POD, only these types are supported for
// memmapping.
// DataType enum is explicitly converted to int to avoid errors with passing
// enum type are a parameter type to std::unordered_set.
static std::unordered_set<int> supported_types{
#define TYPE_FOR_SET(type) static_cast<int>(DataTypeToEnum<type>::value),
TF_CALL_POD_TYPES(TYPE_FOR_SET)
#undef ADD_TYPE
};
if (supported_types.count(static_cast<int>(tensor_data_type)) == 0) {
return Status::OK();
}
// Create Tensor from value and write it in memmapped format.
Tensor parsed(tensor_proto.dtype());
if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
tensor_proto.DebugString());
}
if (parsed.TotalBytes() < min_conversion_size_bytes) {
if (parsed.TotalBytes() < static_cast<size_t>(min_conversion_size_bytes)) {
return Status::OK();
}

View File

@ -26,6 +26,15 @@ limitations under the License.
namespace tensorflow {
namespace {
bool GraphHasImmutableConstNodes(const GraphDef& graph_def) {
for (const auto& node : graph_def.node()) {
if (node.op() == "ImmutableConst") {
return true;
}
}
return false;
}
TEST(ConvertGraphdefMemmappedFormatTest, ConvertModel) {
const string dir = testing::TmpDir();
const string filename_pb = io::JoinPath(dir, "graphdef.pb");
@ -69,6 +78,7 @@ TEST(ConvertGraphdefMemmappedFormatTest, ConvertModel) {
TF_ASSERT_OK(ReadBinaryProto(
&memmapped_env, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
&loaded_graph_def));
ASSERT_TRUE(GraphHasImmutableConstNodes(loaded_graph_def));
TF_ASSERT_OK(session->Create(loaded_graph_def)) << "Can't create test graph";
std::vector<Tensor> outputs;
@ -79,5 +89,48 @@ TEST(ConvertGraphdefMemmappedFormatTest, ConvertModel) {
EXPECT_EQ(outputs.front().flat<float>()(2), 2.0f * 3.0f * kTensorHeight);
}
// Verifies that non-POD tensor types (here DT_STRING) are NOT converted to
// memmapped ImmutableConst nodes, and that the resulting package still loads
// and runs.
TEST(ConvertGraphdefMemmappedFormatTest, NotSupportedTypesConvert) {
  // Create a graph with strings.
  const string dir = testing::TmpDir();
  const string filename_pb = io::JoinPath(dir, "string_graphdef.pb");

  constexpr int kTensorWidth = 4000;
  constexpr int kTensorHeight = 100;
  const TensorShape kTestTensorShape({kTensorWidth, kTensorHeight});
  Tensor test_tensor1(DT_STRING, kTestTensorShape);
  test::FillFn<string>(&test_tensor1, [](int) -> string { return "ABC"; });

  Tensor test_tensor2(DT_STRING, kTestTensorShape);
  test::FillFn<string>(&test_tensor2, [](int) -> string { return "XYZ"; });
  auto root = Scope::NewRootScope().ExitOnError();
  ops::Output m = ops::Add(root, test_tensor1, test_tensor2);
  const string result_name = m.node()->name();

  GraphDef graph_def;
  TF_ASSERT_OK(root.ToGraphDef(&graph_def));
  string graph_def_serialized;
  graph_def.SerializeToString(&graph_def_serialized);
  TF_ASSERT_OK(
      WriteStringToFile(Env::Default(), filename_pb, graph_def_serialized));

  const string filename_mmap = io::JoinPath(dir, "string_graphdef.mmap");
  // The 1000-byte threshold is far below the tensors' size, so they would be
  // converted if their type were supported.
  TF_ASSERT_OK(ConvertConstantsToImmutable(filename_pb, filename_mmap, 1000));

  // Create and initialize MemmappedEnv from the converted file.
  MemmappedEnv memmapped_env(Env::Default());
  TF_ASSERT_OK(memmapped_env.InitializeFromFile(filename_mmap));

  // Load the graph and run calculations.
  SessionOptions session_options;
  session_options.env = &memmapped_env;
  std::unique_ptr<Session> session(NewSession(session_options));
  ASSERT_TRUE(session != nullptr) << "Failed to create session";
  GraphDef loaded_graph_def;
  TF_ASSERT_OK(ReadBinaryProto(
      &memmapped_env, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
      &loaded_graph_def));
  // String constants must remain regular Const nodes after conversion.
  ASSERT_FALSE(GraphHasImmutableConstNodes(loaded_graph_def));
}
} // namespace
} // namespace tensorflow

View File

@ -164,6 +164,8 @@ cc_library(
"lib/core/threadpool.h",
"lib/gtl/array_slice.h",
"lib/gtl/cleanup.h",
"lib/gtl/flatmap.h",
"lib/gtl/flatset.h",
"lib/gtl/inlined_vector.h",
"lib/gtl/priority_queue_util.h",
"lib/hash/crc32c.h",
@ -178,7 +180,6 @@ cc_library(
"lib/io/table.h",
"lib/io/table_builder.h",
"lib/io/table_options.h",
"lib/jpeg/jpeg_mem.h",
"lib/math/math_util.h",
"lib/monitoring/collected_metrics.h",
"lib/monitoring/collection_registry.h",
@ -220,6 +221,13 @@ cc_library(
],
)
cc_library(
name = "jpeg",
hdrs = ["lib/jpeg/jpeg_mem.h"],
visibility = ["//visibility:public"],
deps = [":jpeg_internal"],
)
# Test support library needed for all tests
# This is currently public, but may be made internal in the
# future. Try to avoid depending on it.
@ -521,6 +529,7 @@ cc_library(
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:ctc_ops",
"//tensorflow/core/kernels:data_flow",
"//tensorflow/core/kernels:fake_quant_ops",
"//tensorflow/core/kernels:function_ops",
"//tensorflow/core/kernels:image",
"//tensorflow/core/kernels:io",
@ -970,6 +979,7 @@ cc_library(
],
exclude = [
"**/*test*",
"lib/jpeg/**/*",
"platform/**/cuda.h",
"platform/**/stream_executor.h",
"platform/load_library.cc",
@ -986,6 +996,7 @@ cc_library(
],
exclude = [
"**/*test*",
"lib/jpeg/**/*",
"platform/**/cuda.h",
"platform/**/stream_executor.h",
],
@ -1019,7 +1030,6 @@ cc_library(
"lib/io/zlib_compression_options.h",
"lib/io/zlib_inputstream.h",
"lib/io/zlib_outputbuffer.h",
"lib/jpeg/jpeg_handle.h",
"lib/png/png_io.h",
"lib/random/random.h",
"lib/random/random_distributions.h",
@ -1048,6 +1058,26 @@ cc_library(
],
)
cc_library(
name = "jpeg_internal",
srcs = glob(
[
"lib/jpeg/*h",
"lib/jpeg/*.cc",
],
exclude = [
"**/*test*",
],
),
hdrs = ["lib/jpeg/jpeg_handle.h"],
copts = tf_copts(),
linkopts = ["-ldl"],
deps = [
":lib",
"//tensorflow/core/platform/default/build_config:jpeg",
],
)
proto_text_hdrs_and_srcs = tf_generate_proto_text_sources(
name = "proto_text_srcs_all",
srcs = tf_proto_text_protos_relative(),
@ -1149,83 +1179,6 @@ cc_header_only_library(
],
)
filegroup(
name = "framework_headers",
srcs = [
"framework/allocator.h",
"framework/attr_value_util.h",
"framework/bfloat16.h",
"framework/cancellation.h",
"framework/control_flow.h",
"framework/device_base.h",
"framework/function.h",
"framework/kernel_def_builder.h",
"framework/node_def_util.h",
"framework/numeric_types.h",
"framework/op.h",
"framework/op_def_builder.h",
"framework/op_def_util.h",
"framework/op_kernel.h",
"framework/partial_tensor_shape.h",
"framework/register_types.h",
"framework/rendezvous.h",
"framework/selective_registration.h",
"framework/session_state.h",
"framework/shape_inference.h",
"framework/tensor.h",
"framework/tensor_reference.h",
"framework/tensor_shape.h",
"framework/tensor_types.h",
"framework/tracking_allocator.h",
"framework/type_traits.h",
"framework/types.h",
"framework/unique_tensor_references.h",
"lib/core/errors.h",
"lib/core/notification.h",
"lib/core/refcount.h",
"lib/core/status.h",
"lib/core/stringpiece.h",
"lib/core/threadpool.h",
"lib/gtl/array_slice.h",
"lib/gtl/array_slice_internal.h",
"lib/gtl/inlined_vector.h",
"lib/gtl/manual_constructor.h",
"lib/hash/hash.h",
"lib/strings/numbers.h",
"lib/strings/str_util.h",
"lib/strings/strcat.h",
"platform/cpu_info.h",
"platform/default/dynamic_annotations.h",
"platform/default/integral_types.h",
"platform/default/logging.h",
"platform/default/mutex.h",
"platform/default/notification.h",
"platform/default/protobuf.h",
"platform/default/thread_annotations.h",
"platform/dynamic_annotations.h",
"platform/env.h",
"platform/file_statistics.h",
"platform/file_system.h",
"platform/fingerprint.h",
"platform/logging.h",
"platform/macros.h",
"platform/mem.h",
"platform/mutex.h",
"platform/net.h",
"platform/notification.h",
"platform/platform.h",
"platform/prefetch.h",
"platform/protobuf.h",
"platform/strong_hash.h",
"platform/thread_annotations.h",
"platform/types.h",
"public/session.h",
"public/session_options.h",
"public/version.h",
"util/device_name_utils.h",
],
)
tf_cuda_library(
name = "stream_executor",
srcs = tf_additional_stream_executor_srcs(),
@ -1316,7 +1269,7 @@ cc_library(
"platform/regexp.h",
],
visibility = [
"//tensorflow/contrib/tfprof:__subpackages__",
"//tensorflow/tools/tfprof:__subpackages__",
],
deps = [":lib_internal"],
)
@ -1326,11 +1279,13 @@ tf_cuda_library(
srcs = ["common_runtime/direct_session.cc"],
hdrs = ["common_runtime/direct_session.h"],
copts = tf_copts(),
cuda_deps = [
":gpu_tracer",
],
linkstatic = 1,
deps = [
":core_cpu_internal",
":framework",
":gpu_tracer",
":lib",
":lib_internal",
":proto_text",
@ -1496,6 +1451,8 @@ tf_cc_tests(
"lib/gtl/array_slice_test.cc",
"lib/gtl/cleanup_test.cc",
"lib/gtl/edit_distance_test.cc",
"lib/gtl/flatmap_test.cc",
"lib/gtl/flatset_test.cc",
"lib/gtl/inlined_vector_test.cc",
"lib/gtl/int_type_test.cc",
"lib/gtl/iterator_range_test.cc",
@ -1582,6 +1539,8 @@ cc_test(
srcs = ["lib/jpeg/jpeg_mem_unittest.cc"],
data = glob(["lib/jpeg/testdata/*.jpg"]),
deps = [
":jpeg",
":jpeg_internal",
":lib",
":lib_internal",
":test",

View File

@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/gpu/gpu_tracer.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/simple_placer.h"
@ -57,6 +56,10 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_tracer.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
namespace {
@ -453,12 +456,14 @@ Status DirectSession::Run(const RunOptions& run_options,
args.stats_collector = run_state.collector.get();
}
#if GOOGLE_CUDA
std::unique_ptr<GPUTracer> tracer;
if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
tracer.reset(CreateGPUTracer());
// tracer will be NULL on non-GPU platforms.
if (tracer) tracer->Start();
}
#endif // GOOGLE_CUDA
for (const auto& item : executors_and_keys->items) {
item.executor->RunAsync(args, barrier->Get());
@ -468,10 +473,12 @@ Status DirectSession::Run(const RunOptions& run_options,
? run_options.timeout_in_ms()
: operation_timeout_in_ms_);
#if GOOGLE_CUDA
if (tracer) {
tracer->Stop();
tracer->Collect(args.stats_collector);
}
#endif // GOOGLE_CUDA
{
mutex_lock l(run_state.mu_);
@ -840,10 +847,11 @@ Status DirectSession::GetOrCreateExecutors(
std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
std::sort(tn_sorted.begin(), tn_sorted.end());
const string key = strings::StrCat(str_util::Join(inputs_sorted, ","), "->",
str_util::Join(outputs_sorted, ","), "/",
str_util::Join(tn_sorted, ","), "/",
run_state_args->is_partial_run);
const string key = strings::StrCat(
str_util::Join(inputs_sorted, ","), "->",
str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
"/", run_state_args->is_partial_run, "/",
SummarizeDebugTensorWatches(run_state_args->debug_tensor_watches));
// Set the handle.
run_state_args->handle =
@ -938,7 +946,7 @@ Status DirectSession::GetOrCreateExecutors(
partition_graph = iter->second.release();
optimizer.Optimize(lib, options_.env, device, &partition_graph);
// EXPERIMENTAL: tfdb inserts debug nodes (i.e., probes) to the graph
// EXPERIMENTAL: tfdbg inserts debug nodes (i.e., probes) to the graph
if (!run_state_args->debug_tensor_watches.empty()) {
TF_RETURN_IF_ERROR(
DebugNodeInserter::InsertNodes(run_state_args->debug_tensor_watches,

View File

@ -291,7 +291,7 @@ class DirectSession : public Session {
TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);
// EXPERIMENTAL: debugger (tfdb) related
// EXPERIMENTAL: debugger (tfdbg) related
friend class DebugGateway;
};

View File

@ -222,7 +222,7 @@ typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
class ExecutorImpl : public Executor {
public:
ExecutorImpl(const LocalExecutorParams& p, const Graph* g)
: params_(p), graph_(g), initial_pending_counts_(graph_->num_node_ids()) {
: params_(p), graph_(g) {
CHECK(p.create_kernel != nullptr);
CHECK(p.delete_kernel != nullptr);
}
@ -231,6 +231,7 @@ class ExecutorImpl : public Executor {
for (int i = 0; i < graph_->num_node_ids(); i++) {
params_.delete_kernel(nodes_[i].kernel);
}
delete[] frame_local_ids_;
delete[] nodes_;
delete graph_;
}
@ -256,13 +257,39 @@ class ExecutorImpl : public Executor {
private:
friend class ExecutorState;
static void InitializePending(const Graph* graph, PendingCounts* counts);
struct ControlFlowInfo {
std::unordered_map<string, int> frame_name_to_size;
std::vector<string> frame_names;
};
struct FrameInfo {
// The total number of inputs to a frame.
int input_count;
// The total number of input tensors of a frame.
// == sum(nodes[*].num_inputs()) where nodes are the nodes in the frame.
int total_inputs;
// Each frame has its own PendingCounts only for the nodes in the frame.
PendingCounts* pending_counts; // Owned
// The nodes in a frame. Used only for debugging.
std::vector<const Node*>* nodes; // Owned
~FrameInfo() {
delete pending_counts;
delete nodes;
}
};
static Status BuildControlFlowInfo(const Graph* graph,
ControlFlowInfo* cf_info);
void InitializePending(const Graph* graph, const ControlFlowInfo& cf_info);
// Owned.
LocalExecutorParams params_;
const Graph* graph_;
NodeItem* nodes_ = nullptr; // array of size "graph_.num_node_ids()"
int total_input_tensors_ = 0; // == sum(nodes_[*].num_inputs())
int total_output_tensors_ = 0; // == sum(nodes_[*].num_outputs())
// A cached value of params_
@ -271,14 +298,17 @@ class ExecutorImpl : public Executor {
// Root nodes (with no in edges) that should form the initial ready queue
std::vector<const Node*> root_nodes_;
PendingCounts initial_pending_counts_;
// The number of inputs for each frame in this graph. This is static
// information of the graph.
std::unordered_map<string, int> frame_input_count_;
std::vector<AllocatorAttributes> output_attrs_;
// Mapping from frame name to static information about the frame.
// TODO(yuanbyu): We could cache it along with the graph so to avoid
// the overhead of constructing it for each executor instance.
std::unordered_map<string, FrameInfo> frame_info_;
// Mapping from a node's id to its index in the PendingCounts of the
// frame the node belongs to.
int* frame_local_ids_ = nullptr; // Owned
TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl);
};
@ -287,23 +317,31 @@ Status ExecutorImpl::Initialize() {
delete[] nodes_;
nodes_ = new NodeItem[num_nodes];
Status s;
total_input_tensors_ = 0;
total_output_tensors_ = 0;
InitializePending(graph_, &initial_pending_counts_);
// Build the information about frames in this subgraph.
ControlFlowInfo cf_info;
BuildControlFlowInfo(graph_, &cf_info);
// Cache this value so we make this virtual function call once, rather
// that O(# steps * # nodes per step) times.
device_record_tensor_accesses_ =
params_.device->RequiresRecordingAccessedTensors();
for (auto& it : cf_info.frame_name_to_size) {
frame_info_[it.first].nodes = new std::vector<const Node*>;
}
frame_local_ids_ = new int[num_nodes];
std::unordered_map<string, int> frame_count;
// Preprocess every node in the graph to create an instance of op
// kernel for each node;
// kernel for each node.
for (const Node* n : graph_->nodes()) {
const int id = n->id();
const string& frame_name = cf_info.frame_names[id];
FrameInfo& frame_info = frame_info_[frame_name];
// See if this node is a root node, and if so, add to root_nodes_
// See if this node is a root node, and if so, add to root_nodes_.
const int num_in_edges = n->in_edges().size();
if (num_in_edges == 0) {
root_nodes_.push_back(n);
@ -321,18 +359,18 @@ Status ExecutorImpl::Initialize() {
item->inlined_output_type[i] = n->output_type(i);
}
item->input_start = total_input_tensors_;
total_input_tensors_ += n->num_inputs();
item->input_start = frame_info.total_inputs;
frame_info.total_inputs += n->num_inputs();
item->output_attr_start = total_output_tensors_;
total_output_tensors_ += n->num_outputs();
s = params_.create_kernel(n->def(), &item->kernel);
Status s = params_.create_kernel(n->def(), &item->kernel);
if (!s.ok()) {
item->kernel = nullptr;
s = AttachDef(s, n->def());
LOG(ERROR) << "Executor failed to create kernel. " << s;
break;
return s;
}
CHECK(item->kernel);
item->kernel_is_expensive = item->kernel->IsExpensive();
@ -340,14 +378,18 @@ Status ExecutorImpl::Initialize() {
item->is_merge = IsMerge(n);
// Initialize static information about the frames in the graph.
frame_local_ids_[id] = frame_count[frame_name]++;
frame_info.nodes->push_back(n);
if (IsEnter(n)) {
string frame_name;
s = GetNodeAttr(n->def(), "frame_name", &frame_name);
if (!s.ok()) return s;
++frame_input_count_[frame_name];
string enter_name;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "frame_name", &enter_name));
++frame_info_[enter_name].input_count;
}
}
if (!s.ok()) return s;
// Initialize PendingCounts only after frame_local_ids_ is initialized.
InitializePending(graph_, cf_info);
return SetAllocAttrs();
}
@ -533,12 +575,13 @@ class ExecutorState {
typedef gtl::InlinedVector<Entry, 4> EntryVector;
struct IterationState {
explicit IterationState(const ExecutorImpl* impl)
: input_tensors(new Entry[impl->total_input_tensors_]),
explicit IterationState(const PendingCounts* pending_counts,
int total_input_tensors)
: input_tensors(new Entry[total_input_tensors]),
outstanding_ops(0),
outstanding_frame_count(0),
counts_(impl->graph_->num_node_ids()) {
counts_.InitializeFrom(impl->initial_pending_counts_);
counts_(pending_counts->num_nodes()) {
counts_.InitializeFrom(*pending_counts);
}
// The state of an iteration.
@ -668,9 +711,23 @@ class ExecutorState {
// will only "execute" the dead exits of the final iteration.
std::vector<const Node*> dead_exits GUARDED_BY(mu);
// Static information specific to this frame.
PendingCounts* pending_counts = nullptr;
int total_input_tensors = 0;
std::vector<const Node*>* nodes = nullptr;
// Lock ordering: ExecutorState.mu_ < mu.
mutex mu;
void InitializeFrameInfo(const string& enter_name) {
auto it_frame_info = executor->frame_info_.find(enter_name);
DCHECK(it_frame_info != executor->frame_info_.end());
pending_counts = it_frame_info->second.pending_counts;
total_input_tensors = it_frame_info->second.total_inputs;
num_pending_inputs = it_frame_info->second.input_count;
nodes = it_frame_info->second.nodes;
}
inline IterationState* GetIteration(int64 iter)
EXCLUSIVE_LOCKS_REQUIRED(mu) {
int index = iter % iterations.size();
@ -889,13 +946,12 @@ class ExecutorState {
inline void MaybeMarkCompleted(FrameState* frame, int64 iter, int64 id);
// Provide debugging output about an outstanding node in the executor.
void DumpCompletedNodeState(const int node_id, const Entry* input_vector);
void DumpPendingNodeState(const int node_id, const Entry* input_vector,
bool show_nodes_with_no_ready_inputs);
void DumpActiveNodeState(const int node_id, const Entry* input_vector);
// Provide debugging output about an outstanding iteration in the executor.
void DumpIterationState(IterationState* iteration);
void DumpIterationState(const FrameState* frame, IterationState* iteration);
// Provide debugging output of the state of the executor.
void DumpState();
@ -932,16 +988,16 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
num_outstanding_ops_(0) {
// We start the entire execution in iteration 0 of the root frame
// so let us create the root frame and the state for iteration 0.
// Initialize the frame.
// We assume root_frame_->frame_name.empty().
root_frame_ = new FrameState(impl_, 1);
root_frame_->frame_name = "_root"; // assume to be unique
root_frame_->frame_id = 0; // must be 0
// Initialize the first iteration.
root_frame_->iterations.resize(root_frame_->max_parallel_iterations);
IterationState* iter_state = new IterationState(impl);
root_frame_->iterations[0] = iter_state;
root_frame_->InitializeFrameInfo(root_frame_->frame_name);
// Initialize iteration 0.
root_frame_->iterations.resize(root_frame_->max_parallel_iterations);
root_frame_->iterations[0] = new IterationState(
root_frame_->pending_counts, root_frame_->total_input_tensors);
if (vlog_) VLOG(2) << "Create frame: " << root_frame_->frame_name;
outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
}
@ -949,21 +1005,88 @@ ExecutorState::~ExecutorState() {
for (auto name_frame : outstanding_frames_) {
delete name_frame.second;
}
for (auto it : device_context_map_) {
it->Unref();
}
delete slice_reader_cache_;
}
Status ExecutorImpl::BuildControlFlowInfo(const Graph* g,
ControlFlowInfo* cf_info) {
const int num_nodes = g->num_node_ids();
cf_info->frame_names.resize(num_nodes);
std::vector<Node*> parent_nodes;
parent_nodes.resize(num_nodes);
std::vector<bool> visited;
visited.resize(num_nodes);
string frame_name;
std::deque<Node*> ready;
// Initialize with the root nodes.
for (Node* n : g->nodes()) {
if (n->in_edges().empty()) {
visited[n->id()] = true;
++cf_info->frame_name_to_size[frame_name];
ready.push_back(n);
}
}
while (!ready.empty()) {
Node* curr_node = ready.front();
int curr_id = curr_node->id();
ready.pop_front();
Node* parent = nullptr;
if (IsEnter(curr_node)) {
// Enter a child frame.
TF_RETURN_IF_ERROR(
GetNodeAttr(curr_node->def(), "frame_name", &frame_name));
parent = curr_node;
} else if (IsExit(curr_node)) {
// Exit to the parent frame.
parent = parent_nodes[curr_id];
frame_name = cf_info->frame_names[parent->id()];
parent = parent_nodes[parent->id()];
} else {
parent = parent_nodes[curr_id];
frame_name = cf_info->frame_names[curr_id];
}
for (const Edge* out_edge : curr_node->out_edges()) {
Node* out = out_edge->dst();
int out_id = out->id();
// Add to ready queue if not visited.
bool is_visited = visited[out_id];
if (!is_visited) {
ready.push_back(out);
visited[out_id] = true;
// Process the node 'out'.
cf_info->frame_names[out_id] = frame_name;
parent_nodes[out_id] = parent;
++cf_info->frame_name_to_size[frame_name];
}
}
}
return Status::OK();
}
void ExecutorImpl::InitializePending(const Graph* graph,
PendingCounts* counts) {
for (int id = 0; id < graph->num_node_ids(); id++) {
counts->set_initial_count(id, 0, 0); // Make sure everything is initialized
const ControlFlowInfo& cf_info) {
for (auto& it : cf_info.frame_name_to_size) {
PendingCounts* counts = new PendingCounts(it.second);
frame_info_[it.first].pending_counts = counts;
// Make sure everything is initialized
for (int id = 0; id < it.second; id++) {
counts->set_initial_count(id, 0, 0);
}
}
for (const Node* n : graph->nodes()) {
const int id = n->id();
const int pending_id = frame_local_ids_[id];
const int num_in_edges = n->in_edges().size();
int initial_count;
if (IsMerge(n)) {
@ -980,7 +1103,9 @@ void ExecutorImpl::InitializePending(const Graph* graph,
} else {
initial_count = num_in_edges;
}
counts->set_initial_count(id, initial_count, num_in_edges);
const string& name = cf_info.frame_names[id];
PendingCounts* counts = frame_info_[name].pending_counts;
counts->set_initial_count(pending_id, initial_count, num_in_edges);
}
}
@ -1104,8 +1229,9 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
// TODO(misard) Replace with a finer-grain enabling flag once we
// add better optional debugging support.
if (vlog_ && VLOG_IS_ON(1)) {
int pending_id = impl_->frame_local_ids_[id];
mutex_lock l(input_frame->mu);
input_frame->GetIteration(input_iter)->mark_started(id);
input_frame->GetIteration(input_iter)->mark_started(pending_id);
}
// Set the device_context for this node id, if it exists.
@ -1637,12 +1763,13 @@ void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
}
inline void ExecutorState::MaybeMarkCompleted(FrameState* frame, int64 iter,
int64 id) {
int64 node_id) {
// TODO(misard) Replace with a finer-grain enabling flag once we
// add better optional debugging support.
if (vlog_ && VLOG_IS_ON(1)) {
int pending_id = impl_->frame_local_ids_[node_id];
mutex_lock l(frame->mu);
frame->GetIteration(iter)->mark_completed(id);
frame->GetIteration(iter)->mark_completed(pending_id);
}
}
@ -1656,18 +1783,6 @@ const Tensor* ExecutorState::GetTensorValueForDump(const Entry& input) {
}
}
void ExecutorState::DumpCompletedNodeState(const int node_id,
const Entry* input_vector) {
const NodeItem& node_item = impl_->nodes_[node_id];
const Node& node = *node_item.node;
LOG(WARNING) << " Completed Node: " << node.DebugString();
const int input_base = node_item.input_start;
for (int i = 0; i < node.num_inputs(); ++i) {
const Entry& input = input_vector[input_base + i];
CHECK(!GetTensorValueForDump(input)->IsInitialized());
}
}
void ExecutorState::DumpPendingNodeState(
const int node_id, const Entry* input_vector,
const bool show_nodes_with_no_ready_inputs) {
@ -1723,23 +1838,30 @@ void ExecutorState::DumpActiveNodeState(const int node_id,
}
}
void ExecutorState::DumpIterationState(IterationState* iteration) {
void ExecutorState::DumpIterationState(const FrameState* frame,
IterationState* iteration) {
const std::vector<const Node*>* nodes = frame->nodes;
// Dump any waiting nodes that are holding on to tensors.
for (int i = 0; i < impl_->graph_->num_node_ids(); ++i) {
if (iteration->node_state(i) == PendingCounts::PENDING_NOTREADY ||
iteration->node_state(i) == PendingCounts::PENDING_READY) {
DumpPendingNodeState(i, iteration->input_tensors, false);
for (const Node* node : *nodes) {
int node_id = node->id();
int pending_id = impl_->frame_local_ids_[node_id];
if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
DumpPendingNodeState(node_id, iteration->input_tensors, false);
}
}
// Then the active nodes.
for (int i = 0; i < impl_->graph_->num_node_ids(); ++i) {
if (iteration->node_state(i) == PendingCounts::STARTED) {
DumpActiveNodeState(i, iteration->input_tensors);
for (const Node* node : *nodes) {
int node_id = node->id();
int pending_id = impl_->frame_local_ids_[node_id];
if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
DumpActiveNodeState(pending_id, iteration->input_tensors);
}
}
// Show all input tensors in use.
int total_input_tensors = frame->total_input_tensors;
size_t total_bytes = 0;
for (int i = 0; i < impl_->total_input_tensors_; ++i) {
for (int i = 0; i < total_input_tensors; ++i) {
const Entry& input = iteration->input_tensors[i];
const Tensor* tensor = GetTensorValueForDump(input);
if (tensor->IsInitialized()) {
@ -1764,7 +1886,7 @@ void ExecutorState::DumpState() {
mutex_lock frame_lock(frame_state->mu);
for (IterationState* iteration : frame_state->iterations) {
LOG(WARNING) << " Iteration:";
DumpIterationState(iteration);
DumpIterationState(frame_state, iteration);
}
}
dumped_on_error_ = true;
@ -1819,16 +1941,13 @@ void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
temp->frame_id = Hash64(child_name);
temp->parent_frame = frame;
temp->parent_iter = iter;
temp->InitializeFrameInfo(enter_name);
// 'iterations' is a fixed-length circular buffer.
temp->iterations.resize(temp->max_parallel_iterations + 1);
// Initialize the first iteration.
IterationState* iter_state = new IterationState(impl_);
temp->iterations[0] = iter_state;
auto frame_pending = impl_->frame_input_count_.find(enter_name);
DCHECK(frame_pending != impl_->frame_input_count_.end());
temp->num_pending_inputs = frame_pending->second;
// Initialize iteration 0.
temp->iterations[0] =
new IterationState(temp->pending_counts, temp->total_input_tensors);
{
mutex_lock executor_lock(mu_);
@ -1851,33 +1970,40 @@ void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
FrameState* parent_frame = frame->parent_frame;
int64 parent_iter = frame->parent_iter;
if (parent_frame != nullptr) {
const int* pending_ids = impl_->frame_local_ids_;
mutex_lock paranet_frame_lock(parent_frame->mu);
// Propagate all the dead exits to the parent frame.
for (const Node* node : frame->dead_exits) {
auto parent_iter_state = parent_frame->GetIteration(parent_iter);
for (const Edge* e : node->out_edges()) {
const Node* dst_node = e->dst();
const int dst_id = dst_node->id();
const int dst_pending_id = pending_ids[dst_node->id()];
// TODO(yuanbyu): We don't need this if we require the subgraph
// given to an executor not to contain a sink node.
if (dst_node->IsSink()) continue;
bool dst_dead = true;
bool dst_ready = false;
// We know this is a dead input to dst.
if (IsMerge(dst_node)) {
if (e->IsControlEdge()) {
parent_iter_state->decrement_pending(dst_id, 2);
int count = parent_iter_state->pending(dst_id);
dst_dead = (parent_iter_state->dead_count(dst_id) ==
dst_node->num_inputs());
parent_iter_state->decrement_pending(dst_pending_id, 2);
int count = parent_iter_state->pending(dst_pending_id);
int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
dst_dead = (dead_cnt == dst_node->num_inputs());
dst_ready = (count == 0) || ((count == 1) && dst_dead);
} else {
parent_iter_state->increment_dead_count(dst_id);
const int dead_cnt = parent_iter_state->dead_count(dst_id);
parent_iter_state->increment_dead_count(dst_pending_id);
const int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
dst_dead = (dead_cnt == dst_node->num_inputs());
dst_ready = (parent_iter_state->pending(dst_id) == 1) && dst_dead;
dst_ready =
(parent_iter_state->pending(dst_pending_id) == 1) && dst_dead;
}
} else {
parent_iter_state->increment_dead_count(dst_id);
dst_ready = (parent_iter_state->decrement_pending(dst_id, 1) == 0);
parent_iter_state->increment_dead_count(dst_pending_id);
dst_ready =
(parent_iter_state->decrement_pending(dst_pending_id, 1) == 0);
}
if (dst_ready) {
ready->push_back(
@ -1923,12 +2049,18 @@ void ExecutorState::FrameState::ActivateNodes(const Node* node,
const EntryVector& outputs,
TaggedNodeSeq* ready) {
const NodeItem* nodes = executor->nodes_;
const int* pending_ids = executor->frame_local_ids_;
IterationState* iter_state = GetIteration(iter);
for (const Edge* e : node->out_edges()) {
const Node* dst_node = e->dst();
const int dst_id = dst_node->id();
const int dst_pending_id = pending_ids[dst_id];
const int src_slot = e->src_output();
// TODO(yuanbyu): We don't need this if we require the subgraph
// given to an executor not to contain a sink node.
if (dst_node->IsSink()) continue;
bool dst_dead = false;
bool dst_ready = false;
// True iff this input for dst is needed. We only set this input for
@ -1940,15 +2072,16 @@ void ExecutorState::FrameState::ActivateNodes(const Node* node,
// a) a live data input becomes available or b) all data inputs are dead.
// For Merge, pending's LSB is set iff a live data input has arrived.
if (e->IsControlEdge()) {
iter_state->decrement_pending(dst_id, 2);
int count = iter_state->pending(dst_id);
dst_dead = (iter_state->dead_count(dst_id) == dst_node->num_inputs());
iter_state->decrement_pending(dst_pending_id, 2);
int count = iter_state->pending(dst_pending_id);
int dead_cnt = iter_state->dead_count(dst_pending_id);
dst_dead = (dead_cnt == dst_node->num_inputs());
dst_ready = (count == 0) || ((count == 1) && dst_dead);
} else {
if (outputs[src_slot].has_value) {
// This is a live data input.
int count = iter_state->pending(dst_id);
iter_state->mark_live(dst_id);
int count = iter_state->pending(dst_pending_id);
iter_state->mark_live(dst_pending_id);
// Only the first live edge sets the input and (potentially)
// triggers execution. The low bit of count is set if and
// only if no live input has been used yet (mark_live clears
@ -1962,10 +2095,10 @@ void ExecutorState::FrameState::ActivateNodes(const Node* node,
// a dead enter. We need this to handle properly a while loop on
// the untaken branch of a conditional.
// TODO(yuanbyu): This is a bit hacky, but a good solution for now.
iter_state->increment_dead_count(dst_id);
const int dead_cnt = iter_state->dead_count(dst_id);
iter_state->increment_dead_count(dst_pending_id);
const int dead_cnt = iter_state->dead_count(dst_pending_id);
dst_dead = (dead_cnt == dst_node->num_inputs()) || IsEnter(node);
dst_ready = (iter_state->pending(dst_id) == 1) && dst_dead;
dst_ready = (iter_state->pending(dst_pending_id) == 1) && dst_dead;
dst_need_input = false;
}
}
@ -1974,10 +2107,10 @@ void ExecutorState::FrameState::ActivateNodes(const Node* node,
// for all inputs to come in even if we know the node is dead. This
// ensures that all input tensors get cleaned up.
if (is_dead || (!e->IsControlEdge() && !outputs[src_slot].has_value)) {
iter_state->increment_dead_count(dst_id);
iter_state->increment_dead_count(dst_pending_id);
}
dst_dead = iter_state->dead_count(dst_id) > 0;
dst_ready = (iter_state->decrement_pending(dst_id, 1) == 0);
dst_dead = iter_state->dead_count(dst_pending_id) > 0;
dst_ready = (iter_state->decrement_pending(dst_pending_id, 1) == 0);
}
if (dst_need_input) {
@ -2052,7 +2185,8 @@ void ExecutorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) {
int64 next_iter = iteration_count;
// Initialize the next iteration.
IterationState* iter_state = new IterationState(executor);
IterationState* iter_state =
new IterationState(pending_counts, total_input_tensors);
SetIteration(next_iter, iter_state);
num_outstanding_iterations++;
dead_exits.clear();

View File

@ -44,11 +44,7 @@ static const char* const kRetOp = "_Retval";
static const char* const kGradientOp = "SymbolicGradient";
static const char* const kNodeLabel = "Func";
static const char* const kFuncAttr = "f";
// kNoinlineAttr must start with an "_" to avoid collisions with
// user-specified attrs.
static const char* const kNoinlineAttr = "_noinline";
// Old graphs use no "_".
static const char* const kOldNoinlineAttr = "noinline";
static const char* const kNoInlineAttr = "_noinline";
// Represents the index-th output of a node.
struct Endpoint {
@ -168,6 +164,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
Device* device() override { return device_; }
Env* env() override { return env_; }
int graph_def_version() override { return graph_def_version_; }
string DebugString(Handle h) override;
@ -290,6 +287,34 @@ const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
return func_graphs_[h];
}
namespace {
struct CustomCreatorSingleton {
mutex mu;
CustomKernelCreator custom_creator = nullptr;
void Set(CustomKernelCreator cb) {
mutex_lock l(mu);
custom_creator = cb;
}
CustomKernelCreator Get() {
mutex_lock l(mu);
return custom_creator;
}
};
CustomCreatorSingleton* GetCustomCreatorSingleton() {
static CustomCreatorSingleton* ccs = new CustomCreatorSingleton;
return ccs;
}
} // end namespace
void RegisterCustomKernelCreator(CustomKernelCreator cb) {
GetCustomCreatorSingleton()->Set(cb);
}
Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
OpKernel** kernel) {
if (lib_def_->Find(ndef.op()) == nullptr) {
@ -318,8 +343,23 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
output_memory_types.push_back(t == DT_INT32 ? HOST_MEMORY : DEVICE_MEMORY);
}
// Constructs a CallOp kernel for running the instantiated function.
// If a custom kernel creator is given, try that.
CustomKernelCreator custom_creator = GetCustomCreatorSingleton()->Get();
Status s;
if (custom_creator) {
std::unique_ptr<OpKernel> ret;
s = custom_creator(this, ndef, &ret);
if (s.ok()) {
*kernel = ret.release();
return s;
} else {
VLOG(2) << "Custom creator error: " << s;
// Falls through.
s = Status::OK();
}
}
// Constructs a CallOp kernel for running the instantiated function.
auto device_type = DeviceType(device_->attributes().device_type());
OpKernelConstruction construction(
device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
@ -327,7 +367,7 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
fbody->ret_types, output_memory_types, graph_def_version_, &s);
*kernel = new CallOp(handle, &construction);
if (!s.ok()) {
delete kernel;
delete *kernel;
}
return s;
}
@ -887,15 +927,11 @@ static void InlineFunctionBody(Graph* g, Node* caller,
}
// Given a node's NodeDef, returns false iff the node explicitly
// specified _noinline. This gives ExpandInlineFunctions a heuristic to
// decide whether to inline the function.
// `old` is true for GraphDef versions older than 12, when the
// `noinline` attr was renamed to `_noinline` to avoid conflicts with
// user-specified attrs.
bool ShouldInline(const NodeDef& ndef, bool old) {
// specified _noinline. This gives ExpandInlineFunctions a heuristic
// to decide whether to inline the function.
bool ShouldInline(const NodeDef& ndef) {
bool noinline = false;
const char* const attr = old ? kOldNoinlineAttr : kNoinlineAttr;
if (GetNodeAttr(ndef, attr, &noinline).ok()) {
if (GetNodeAttr(ndef, kNoInlineAttr, &noinline).ok()) {
// If the node specifies attribute '_noinline', returns accordingly.
return !noinline;
}
@ -914,7 +950,8 @@ bool ShouldInline(const NodeDef& ndef, bool old) {
// continue and the runtime will error out.
return false;
}
s = GetNodeAttr(AttrSlice(&forward_func_attrs->attr()), attr, &noinline);
s = GetNodeAttr(AttrSlice(&forward_func_attrs->attr()), kNoInlineAttr,
&noinline);
if (!s.ok()) {
// The forward function doesn't specify '_noinline' attr, we should
// be free to decide.
@ -926,11 +963,9 @@ bool ShouldInline(const NodeDef& ndef, bool old) {
bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
std::vector<std::pair<Node*, const FunctionBody*>> candidates;
// Identify old graphs before the 'noinline' attr was renamed '_noinline'.
const bool old_inline_attr = graph->versions().producer() < 12;
for (Node* node : graph->nodes()) {
VLOG(3) << "Expanding " << node->DebugString();
if (!ShouldInline(node->def(), old_inline_attr)) {
if (!ShouldInline(node->def())) {
VLOG(3) << "noinline: " << node->DebugString();
continue;
}

View File

@ -123,6 +123,18 @@ void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false);
// TODO(zhifengc): Asks math expert to say the comment again.
FunctionBody* SymbolicGradient(const FunctionBody& f);
// Registers a customizable kernel creator for a function call.
//
// If 'cb()' returns a non-OK, we still fall back to an executor-based
// interpreter op kernel to execute a function. If 'cb()' returns OK,
// takes ownership of the returned OpKernel.
//
// TODO(zhifengc/phawkins): b/32379046
typedef std::function<Status(FunctionLibraryRuntime*, const NodeDef&,
std::unique_ptr<OpKernel>*)>
CustomKernelCreator;
void RegisterCustomKernelCreator(CustomKernelCreator cb);
} // end namespace tensorflow
#endif // TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_

View File

@ -71,6 +71,7 @@ class PendingCounts {
}
}
inline int num_nodes() const { return num_nodes_; }
NodeState node_state(int id) {
if (IsLarge(id)) {
return NodeStateLarge(id);
@ -185,12 +186,7 @@ class PendingCounts {
// use one byte to hold both the pending and dead count for a node
// where these together can fit in one byte, and we use a hash table
// to handle the rare node ids that need larger counts than this.
// TODO(yuanbyu): We current use O(# of nodes in partition) space
// even for nested iterations where only a small fraction of the
// nodes are involved. This is not efficient if the subgraph for
// the frame is only a small subset of the partition. We should make
// the vector size to be only the size of the frame subgraph.
// Each frame in this subgraph has its own PendingCounts.
// We use 3 bits each for dead_count and pending.
static const int kMaxCountForPackedCounts = 7;

View File

@ -27,6 +27,10 @@ limitations under the License.
namespace tensorflow {
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
ShapeRefiner::ShapeRefiner(const OpRegistryInterface* ops)
: ops_registry_(ops) {}
@ -37,7 +41,7 @@ Status ShapeRefiner::AddNode(const Node* node) {
// from 'input's InferenceContext, and store into a vector
// indexed by 'node's input.
std::vector<Node*> input_nodes(node->num_inputs());
std::vector<shape_inference::ShapeHandle> input_shapes(node->num_inputs());
std::vector<ShapeHandle> input_shapes(node->num_inputs());
for (const Edge* e : node->in_edges()) {
if (e->IsControlEdge()) continue;
@ -49,7 +53,7 @@ Status ShapeRefiner::AddNode(const Node* node) {
node->name(), "' was not previously added to ShapeRefiner.");
}
shape_inference::InferenceContext* c = it->second;
InferenceContext* c = it->second;
DCHECK_GE(e->dst_input(), 0);
input_nodes[e->dst_input()] = input;
input_shapes[e->dst_input()] = c->output(e->src_output());
@ -68,11 +72,13 @@ Status ShapeRefiner::AddNode(const Node* node) {
std::vector<const Tensor*> input_tensors(node->num_inputs());
std::vector<Tensor> real_tensors(node->num_inputs());
std::vector<bool> attempted_materialization(node->num_inputs());
std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
std::vector<ShapeHandle> input_tensors_as_shapes;
// Create the inference context for this node with the existing input shapes.
std::unique_ptr<shape_inference::InferenceContext> c(
new shape_inference::InferenceContext(&node->def(), node->op_def(),
input_shapes, input_tensors));
std::unique_ptr<InferenceContext> c(
new InferenceContext(&node->def(), node->op_def(), input_shapes,
input_tensors, input_tensors_as_shapes));
if (!c->construction_status().ok()) {
return c->construction_status();
}
@ -101,63 +107,44 @@ Status ShapeRefiner::AddNode(const Node* node) {
// subgraph once.
for (int i = 0; i < c->num_inputs(); ++i) {
if (!c->requested_input_tensor(i)) {
continue;
}
// Check if we have not already filled in the requested input,
// and if not, try to materialize the tensors.
if (c->requested_input_tensor(i) && !attempted_materialization[i]) {
if (!attempted_materialization[i]) {
attempted_materialization[i] = true;
const Edge* input_edge;
TF_RETURN_IF_ERROR(node->input_edge(i, &input_edge));
bool is_constant_graph = false;
Graph subgraph(ops_registry_);
// We identify the possibly constant subgraph to evaluate by
// recursively iterating backwards through the inputs to 'node'
// until we either 1) find an already existing input to our subgraph
// (filled in `const_inputs`), 2) Discover our graph is not constant,
// or 3) Hit a root node.
std::vector<std::pair<string, Tensor>> const_inputs;
TF_RETURN_IF_ERROR(ExtractConstantSubgraph(
input_nodes[i], &subgraph, &is_constant_graph, &const_inputs));
if (is_constant_graph) {
const string output_tensor_name = strings::StrCat(
input_nodes[i]->name(), ":", input_edge->src_output());
std::vector<Tensor> outputs;
// NOTE; we should pass in a function library runtime if we want
// to support constant-expression evaluation on functions.
Status s = GraphRunner::Run(&subgraph, nullptr /* function_library */,
Env::Default(), const_inputs,
{output_tensor_name}, &outputs);
// If all kernels in the constant graph are not registered
// in the process, GraphRunner::Run may fail, in which case
// we cannot propagate constants, so this is best-effort.
if (s.ok()) {
real_tensors[i] = outputs[0];
input_tensors[i] = &real_tensors[i];
// We have more concrete information about a shape,
// so re-run shape inference.
rerun_shape_fn = true;
// We memoize (small) constants evaluated so far, so
// ExtractConstantSubgraph can avoid extracting the full
// subgraph. As we build up large graphs, this avoids
// repeated computation of the early parts of a constant
// graph.
if (outputs[0].TotalBytes() <= kMaxTensorSize) {
const_tensor_map_[output_tensor_name] = outputs[0];
}
}
Tensor result;
bool evaluated = false;
TF_RETURN_IF_ERROR(
EvaluateConstantTensorForEdge(node, i, &evaluated, &result));
if (evaluated) {
real_tensors[i] = result;
input_tensors[i] = &real_tensors[i];
// We have more concrete information about a shape,
// so re-run shape inference.
rerun_shape_fn = true;
}
}
if (c->requested_input_tensor_as_partial_shape(i) &&
!attempted_tensor_as_shape_conversion[i]) {
attempted_tensor_as_shape_conversion[i] = true;
if (i >= input_tensors_as_shapes.size()) {
input_tensors_as_shapes.resize(i + 1);
}
ShapeHandle s;
TF_RETURN_IF_ERROR(ConstantPartialShape(c.get(), node, i, &s));
input_tensors_as_shapes[i] = s;
rerun_shape_fn = true;
}
}
if (rerun_shape_fn) {
// We have more information about the shapes on this pass,
// so re-run shape inference.
c->set_input_tensors(input_tensors);
c->set_input_tensors_as_shapes(input_tensors_as_shapes);
TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(c.get()));
}
} while (rerun_shape_fn);
@ -169,7 +156,7 @@ Status ShapeRefiner::AddNode(const Node* node) {
}
Status ShapeRefiner::SetShape(const Node* node, int output_port,
shape_inference::ShapeHandle shape) {
ShapeHandle shape) {
auto c = GetContext(node);
if (c == nullptr) {
return errors::Internal("Could not find context for ", node->name());
@ -182,7 +169,7 @@ Status ShapeRefiner::SetShape(const Node* node, int output_port,
}
// Check compatibility, and merge the shapes.
shape_inference::ShapeHandle existing_shape = c->output(output_port);
ShapeHandle existing_shape = c->output(output_port);
TF_RETURN_IF_ERROR(c->Merge(existing_shape, shape, &shape));
c->set_output(output_port, shape);
@ -196,6 +183,55 @@ Status ShapeRefiner::SetShape(const Node* node, int output_port,
return Status::OK();
}
Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
int dst_idx, bool* evaluated,
Tensor* result) {
*evaluated = false;
const Edge* input_edge;
TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
bool is_constant_graph = false;
Graph subgraph(ops_registry_);
// We identify the possibly constant subgraph to evaluate by
// recursively iterating backwards through the inputs to 'node'
// until we either 1) find an already existing input to our subgraph
// (filled in `const_inputs`), 2) Discover our graph is not constant,
// or 3) Hit a root node.
std::vector<std::pair<string, Tensor>> const_inputs;
TF_RETURN_IF_ERROR(ExtractConstantSubgraph(
input_edge->src(), &subgraph, &is_constant_graph, &const_inputs));
if (!is_constant_graph) {
return Status::OK();
}
const string output_tensor_name =
strings::StrCat(input_edge->src()->name(), ":", input_edge->src_output());
std::vector<Tensor> outputs;
// NOTE; we should pass in a function library runtime if we want
// to support constant-expression evaluation on functions.
Status s = GraphRunner::Run(&subgraph, nullptr /* function_library */,
Env::Default(), const_inputs,
{output_tensor_name}, &outputs);
// If all kernels in the constant graph are not registered
// in the process, GraphRunner::Run may fail, in which case
// we cannot propagate constants, so this is best-effort.
if (s.ok()) {
*result = outputs[0];
*evaluated = true;
// We memoize (small) constants evaluated so far, so
// ExtractConstantSubgraph can avoid extracting the full
// subgraph. As we build up large graphs, this avoids
// repeated computation of the early parts of a constant
// graph.
if (outputs[0].TotalBytes() <= kMaxTensorSize) {
const_tensor_map_[output_tensor_name] = outputs[0];
}
}
return Status::OK();
}
Status ShapeRefiner::ExtractConstantSubgraph(
Node* target_node, Graph* out_graph, bool* is_constant_graph,
std::vector<std::pair<string, Tensor>>* const_inputs) {
@ -308,4 +344,75 @@ Status ShapeRefiner::ExtractConstantSubgraph(
return Status::OK();
}
Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
const Node* node, int dst_idx,
ShapeHandle* result) {
const Edge* input_edge;
TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
InferenceContext* src_context = GetContext(input_edge->src());
if (src_context == nullptr) return errors::Internal("Missing src context");
ShapeHandle src_shape = src_context->output(input_edge->src_output());
TF_RETURN_IF_ERROR(src_context->WithRank(src_shape, 1, &src_shape));
const string& src_op = input_edge->src()->type_string();
if (src_context->Value(src_context->Dim(src_shape, 0)) == 0) {
// Source tensor is a vector of length 0, so the shape it
// represents is as scalar.
*result = target_context->Scalar();
} else if (src_op == "Shape") {
*result = src_context->input(0);
} else if (src_op == "Pack") {
std::vector<DimensionHandle> dims;
// Pack is concatenating its input scalars to form the shape tensor vector.
for (int i = 0; i < src_context->num_inputs(); ++i) {
Tensor scalar;
bool evaluated = false;
TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(input_edge->src(), i,
&evaluated, &scalar));
if (evaluated) {
int64 size;
if (scalar.dtype() == DT_INT32) {
size = scalar.scalar<int32>()();
} else if (scalar.dtype() == DT_INT64) {
size = scalar.scalar<int64>()();
} else {
return errors::InvalidArgument("Pack input must be int32 or int64");
}
dims.push_back(size < 0 ? target_context->UnknownDim()
: target_context->MakeDim(size));
} else {
dims.push_back(target_context->UnknownDim());
}
}
*result = target_context->MakeShape(dims);
} else if (src_op == "Concat") {
*result = target_context->Scalar();
// Concat is concatenating its input shape vectors.
// input 0 is ignored as it is the concat dim and will always be 0.
for (int i = 1; i < src_context->num_inputs(); ++i) {
ShapeHandle sub_result;
TF_RETURN_IF_ERROR(ConstantPartialShape(target_context, input_edge->src(),
i, &sub_result));
if (!target_context->RankKnown(sub_result)) {
// Failed to evaluate. Treat the output as completely unknown.
// TODO(cwhipkey): we could rely on all inputs being the same size, so
// figure that size out and append the right number of unknown dims.
*result = target_context->UnknownShape();
return Status::OK();
}
TF_RETURN_IF_ERROR(
target_context->Concatenate(*result, sub_result, result));
}
} else {
Tensor t;
bool evaluated = false;
TF_RETURN_IF_ERROR(
EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor(
evaluated ? &t : nullptr, src_shape, result));
}
return Status::OK();
}
} // namespace tensorflow

View File

@ -71,6 +71,34 @@ class ShapeRefiner {
Node* node, Graph* out_graph, bool* is_constant_graph,
std::vector<std::pair<string, Tensor>>* const_inputs) TF_MUST_USE_RESULT;
Status EvaluateConstantTensorForEdge(const Node* node, int dst_idx,
bool* evaluated, Tensor* result);
// This function tries to materialize as much information about the 'node''s
// dst_idx input as a statically computable shape, and the result may be
// partially known, depending on what is statically inferable.
//
// This is called when node.input[dst_idx] is a tensor that is used to define
// the shape of some other tensor (e.g., the second argument to Reshape is a
// <shape> tensor, where each element of the shape tensor is a dimension of
// the target tensor). It returns in <result> a shape for that input.
//
// Unlike simply resolving node.input[dst_idx] to a constant and then
// converting that to a shape, this function can return a partial shape. This
// is useful for cases where the shape tensor is only partially defined, such
// as with calls for: reshape(x, shape(y)) where shape(y) is partially
// defined.
//
// The implementation has op implementations for ops commonly called on shape
// tensors, and the implementations are specialized to shape tensors (namely,
// the output is a vector).
//
// <target_context> is used when creating new DimensionHandle and ShapeHandle
// objects.
Status ConstantPartialShape(shape_inference::InferenceContext* target_context,
const Node* node, int dst_idx,
shape_inference::ShapeHandle* result);
const OpRegistryInterface* ops_registry_ = nullptr;
// Stores a map from a node to its InferenceContext.

View File

@ -398,5 +398,347 @@ TEST(ShapeRefinerTest, ConstantValueVisitNodeTwice) {
EXPECT_EQ("[1,4,7]", ctx->DebugString(ctx->output(0)));
}
namespace {
Status TensorAsShapeShapeFn(shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0 /* input_idx */, &out));
c->set_output(0, out);
return Status::OK();
}
// Register ops used by the ConstantValueAsShape* tests.
REGISTER_OP("TensorAsShapeInt32")
.Input("a: int32")
.Output("o: int32")
.SetShapeFn(TensorAsShapeShapeFn);
REGISTER_OP("TensorAsShapeInt64")
.Input("a: int64")
.Output("o: int64")
.SetShapeFn(TensorAsShapeShapeFn);
REGISTER_OP("NonConstScalarInt32")
.Output("o: int32")
.SetIsStateful() // prevents constant folding
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("NonConstScalarInt64")
.Output("o: int64")
.SetIsStateful() // prevents constant folding
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("WithEmptyVectorShape")
.Output("o: int32")
.SetIsStateful() // prevents constant folding
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Vector(0));
return Status::OK();
});
REGISTER_OP("WithPartialShape")
.Output("o: int32")
.SetIsStateful() // prevents constant folding
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(
0, c->MakeShape({1, shape_inference::InferenceContext::kUnknownDim, 3,
shape_inference::InferenceContext::kUnknownDim, 5}));
return Status::OK();
});
REGISTER_OP("WithPartialShape2")
.Output("o: int32")
.SetIsStateful() // prevents constant folding
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(
0,
c->MakeShape({6, shape_inference::InferenceContext::kUnknownDim, 8}));
return Status::OK();
});
REGISTER_OP("WithUnknownShape")
.Output("o: int32")
.SetIsStateful() // prevents constant folding
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->UnknownShape());
return Status::OK();
});
} // namespace
TEST(ShapeRefinerTest, ConstantValueAsShape_EmptyVector) {
Scope root = Scope::NewRootScope();
Node* input;
TF_ASSERT_OK(
NodeBuilder("in", "WithEmptyVectorShape").Finalize(root.graph(), &input));
Node* result;
TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
.Input(input)
.Finalize(root.graph(), &result));
ShapeRefiner m(OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(input));
TF_ASSERT_OK(m.AddNode(result));
shape_inference::InferenceContext* ctx = m.GetContext(result);
EXPECT_EQ("[]", ctx->DebugString(ctx->output(0)));
}
TEST(ShapeRefinerTest, ConstantValueAsShape_Shape) {
for (int pass = 0; pass < 2; ++pass) {
Scope root = Scope::NewRootScope();
Node* input;
TF_ASSERT_OK(
NodeBuilder("in", pass == 0 ? "WithPartialShape" : "WithUnknownShape")
.Finalize(root.graph(), &input));
auto shape = ops::Shape(root, ops::Output(input));
Node* result;
TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
.Input(shape.node())
.Finalize(root.graph(), &result));
ShapeRefiner m(OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(input));
TF_ASSERT_OK(m.AddNode(shape.node()));
TF_ASSERT_OK(m.AddNode(result));
shape_inference::InferenceContext* ctx = m.GetContext(result);
if (pass == 0) {
EXPECT_EQ("[1,?,3,?,5]", ctx->DebugString(ctx->output(0)));
} else {
EXPECT_EQ("?", ctx->DebugString(ctx->output(0)));
}
}
}
TEST(ShapeRefinerTest, ConstantValueAsShape_PackInt32) {
Scope root = Scope::NewRootScope();
Node* scalar_non_const;
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
.Finalize(root.graph(), &scalar_non_const));
ops::InputList inputs{
ops::Input(ops::Const<int32>(root, 10)),
ops::Input(ops::Const<int32>(root, 20)),
ops::Input(ops::Output(scalar_non_const)),
ops::Input(ops::Const<int32>(root, 40)),
};
auto pack = ops::Pack(root, inputs);
TF_ASSERT_OK(root.status());
Node* result;
TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
.Input(pack.node())
.Finalize(root.graph(), &result));
ShapeRefiner m(OpRegistry::Global());
for (auto input : inputs) {
TF_ASSERT_OK(m.AddNode(input.node()));
}
TF_ASSERT_OK(m.AddNode(pack.node()));
TF_ASSERT_OK(m.AddNode(result));
shape_inference::InferenceContext* ctx = m.GetContext(result);
EXPECT_EQ("[10,20,?,40]", ctx->DebugString(ctx->output(0)));
}
TEST(ShapeRefinerTest, ConstantValueAsShape_PackInt64) {
Scope root = Scope::NewRootScope();
Node* scalar_non_const;
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt64")
.Finalize(root.graph(), &scalar_non_const));
ops::InputList inputs{
ops::Input(ops::Const<int64>(root, 10LL)),
ops::Input(ops::Const<int64>(root, 20LL)),
ops::Input(ops::Output(scalar_non_const)),
ops::Input(ops::Const<int64>(root, 1LL << 40)),
};
auto pack = ops::Pack(root, inputs);
TF_ASSERT_OK(root.status());
Node* result;
TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt64")
.Input(pack.node())
.Finalize(root.graph(), &result));
ShapeRefiner m(OpRegistry::Global());
for (const auto& input : inputs) {
TF_ASSERT_OK(m.AddNode(input.node()));
}
TF_ASSERT_OK(m.AddNode(pack.node()));
TF_ASSERT_OK(m.AddNode(result));
shape_inference::InferenceContext* ctx = m.GetContext(result);
EXPECT_EQ("[10,20,?,1099511627776]", ctx->DebugString(ctx->output(0)));
}
TEST(ShapeRefinerTest, ConstantValueAsShape_PackUnknownDim) {
Scope root = Scope::NewRootScope();
ops::InputList inputs{
ops::Input(ops::Const<int64>(root, 10LL)),
ops::Input(ops::Const<int64>(root, -1LL)),
};
auto pack = ops::Pack(root, inputs);
TF_ASSERT_OK(root.status());
Node* result;
TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt64")
.Input(pack.node())
.Finalize(root.graph(), &result));
ShapeRefiner m(OpRegistry::Global());
for (const auto& input : inputs) {
TF_ASSERT_OK(m.AddNode(input.node()));
}
TF_ASSERT_OK(m.AddNode(pack.node()));
TF_ASSERT_OK(m.AddNode(result));
shape_inference::InferenceContext* ctx = m.GetContext(result);
EXPECT_EQ("[10,?]", ctx->DebugString(ctx->output(0)));
}
TEST(ShapeRefinerTest, ConstantValueAsShape_PackInvalidInput) {
Scope root = Scope::NewRootScope();
// Inputs are length 2 vectors instead of scalars.
ops::InputList inputs{
ops::Input(ops::Const<int64>(root, {10LL, 20LL})),
ops::Input(ops::Const<int64>(root, {10LL, 21LL})),
};
auto pack = ops::Pack(root, inputs);
TF_ASSERT_OK(root.status());
Node* result;
TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt64")
.Input(pack.node())
.Finalize(root.graph(), &result));
ShapeRefiner m(OpRegistry::Global());
for (const auto& input : inputs) {
TF_ASSERT_OK(m.AddNode(input.node()));
}
TF_ASSERT_OK(m.AddNode(pack.node()));
EXPECT_TRUE(
StringPiece(m.AddNode(result).error_message()).contains("but is rank 2"));
}
TEST(ShapeRefinerTest, ConstantValueAsShape_Concat) {
Scope root = Scope::NewRootScope();
Graph* g = root.graph();
Node* partial_1;
Node* partial_2;
TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape").Finalize(g, &partial_1));
TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape2").Finalize(g, &partial_2));
auto const_input = ops::Const(root, {9, 10, 11});
ops::OutputList concat_inputs{
ops::Shape(root, ops::Output(partial_1)),
ops::Shape(root, ops::Output(partial_2)), const_input,
};
auto concat_dim = ops::Const(root, 0);
auto concat = ops::Concat(root, concat_dim, concat_inputs);
TF_ASSERT_OK(root.status());
Node* result;
TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
.Input(concat.node())
.Finalize(g, &result));
ShapeRefiner m(OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(partial_1));
TF_ASSERT_OK(m.AddNode(partial_2));
for (const auto& o : concat_inputs) {
TF_ASSERT_OK(m.AddNode(o.node()));
}
TF_ASSERT_OK(m.AddNode(concat_dim.node()));
TF_ASSERT_OK(m.AddNode(concat.node()));
TF_ASSERT_OK(m.AddNode(result));
shape_inference::InferenceContext* ctx = m.GetContext(result);
EXPECT_EQ("[1,?,3,?,5,6,?,8,9,10,11]", ctx->DebugString(ctx->output(0)));
}
TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) {
Scope root = Scope::NewRootScope();
Graph* g = root.graph();
Node* scalar_non_const;
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
.Finalize(root.graph(), &scalar_non_const));
Node* partial_1;
Node* partial_2;
Node* unknown;
TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape").Finalize(g, &partial_1));
TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape2").Finalize(g, &partial_2));
TF_ASSERT_OK(NodeBuilder("in", "WithUnknownShape").Finalize(g, &unknown));
ops::OutputList concat_inputs{
ops::Shape(root, ops::Output(partial_1)),
ops::Shape(root, ops::Output(partial_2)),
ops::Shape(root, ops::Output(unknown)),
};
auto concat_dim = ops::Const(root, 0);
auto concat = ops::Concat(root, concat_dim, concat_inputs);
TF_ASSERT_OK(root.status());
Node* result;
TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
.Input(concat.node())
.Finalize(g, &result));
ShapeRefiner m(OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(partial_1));
TF_ASSERT_OK(m.AddNode(partial_2));
TF_ASSERT_OK(m.AddNode(unknown));
for (const auto& o : concat_inputs) {
TF_ASSERT_OK(m.AddNode(o.node()));
}
TF_ASSERT_OK(m.AddNode(concat_dim.node()));
TF_ASSERT_OK(m.AddNode(concat.node()));
TF_ASSERT_OK(m.AddNode(result));
shape_inference::InferenceContext* ctx = m.GetContext(result);
EXPECT_EQ("?", ctx->DebugString(ctx->output(0)));
}
TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) {
Scope root = Scope::NewRootScope();
Graph* g = root.graph();
Node* scalar_non_const;
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
.Finalize(root.graph(), &scalar_non_const));
Node* partial_1;
Node* partial_2;
TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape").Finalize(g, &partial_1));
TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape2").Finalize(g, &partial_2));
auto const_input = ops::Const(root, {9, -2, 11});
ops::OutputList concat_inputs{
ops::Shape(root, ops::Output(partial_1)),
ops::Shape(root, ops::Output(partial_2)), //
const_input,
};
auto concat_dim = ops::Const(root, 0);
auto concat = ops::Concat(root, concat_dim, concat_inputs);
TF_ASSERT_OK(root.status());
Node* result;
TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
.Input(concat.node())
.Finalize(g, &result));
ShapeRefiner m(OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(partial_1));
TF_ASSERT_OK(m.AddNode(partial_2));
for (const auto& o : concat_inputs) {
TF_ASSERT_OK(m.AddNode(o.node()));
}
TF_ASSERT_OK(m.AddNode(concat_dim.node()));
TF_ASSERT_OK(m.AddNode(concat.node()));
EXPECT_EQ("Invalid value in tensor used for shape: -2",
m.AddNode(result).error_message());
}
} // namespace
} // namespace tensorflow

View File

@ -274,16 +274,6 @@ Status SimpleGraphExecutionState::InitBaseGraph(
return Status::OK();
}
void SimpleGraphExecutionState::UpdateCostsFromStats(const StepStats& ss) {
mutex_lock l(mu_);
costs_.MergeFromStats(node_name_to_cost_id_map_, ss);
}
void SimpleGraphExecutionState::MergeCostsFromGlobal(CostModel* costs) {
mutex_lock l(mu_);
costs->MergeFromGlobal(costs_);
}
Status SimpleGraphExecutionState::BuildGraph(
const BuildGraphOptions& options, std::unique_ptr<SimpleClientGraph>* out) {
VLOG(1) << "BuildGraph";

View File

@ -133,22 +133,6 @@ class SimpleGraphExecutionState {
Status BuildGraph(const BuildGraphOptions& options,
std::unique_ptr<SimpleClientGraph>* out);
// Sums execution statistics in "ss" into the CostModel.
void UpdateCostsFromStats(const StepStats& ss);
Microseconds TimeEstimate(const Node* n) {
mutex_lock l(mu_); // could use reader lock
return costs_.TimeEstimate(n);
}
Bytes SizeEstimate(const Node* n, int output_slot) {
mutex_lock l(mu_); // could use reader lock
return costs_.SizeEstimate(n, output_slot);
}
// Merge the cost model maintained by this graph_execution_state to 'costs'.
void MergeCostsFromGlobal(CostModel* costs);
// The graph returned by BuildGraph may contain only the pruned
// graph, whereas some clients may want access to the full graph.
const Graph* full_graph() {

View File

@ -335,7 +335,9 @@ TEST_F(SessionDebugMinusAXTest, RunSimpleNetworkWithTwoDebugNodesInserted) {
}
TEST_F(SessionDebugMinusAXTest,
RunSimpleNetworkConcurrentlyWithDebugNodesInserted) {
RunSimpleNetworkConcurrentlyWithDifferentDebugTensorWatches) {
// Test concurrent Run() calls on a graph with different debug watches.
Initialize({3, 2, -1, 0});
std::unique_ptr<DirectSession> session(CreateSession());
ASSERT_TRUE(session != nullptr);
@ -351,33 +353,39 @@ TEST_F(SessionDebugMinusAXTest,
mutex mu;
DebugGateway debug_gateway(session.get());
std::vector<Tensor> debug_identity_tensor_vals;
std::unordered_map<string, Tensor> debug_identity_tensor_vals;
const string debug_identity = "DebugIdentity";
const string debug_identity_node_name = DebugNodeInserter::GetDebugNodeName(
const string a_debug_identity_node_name = DebugNodeInserter::GetDebugNodeName(
strings::StrCat(a_, ":", 0), 0, debug_identity);
const string x_debug_identity_node_name = DebugNodeInserter::GetDebugNodeName(
strings::StrCat(x_, ":", 0), 0, debug_identity);
const string y_debug_identity_node_name = DebugNodeInserter::GetDebugNodeName(
strings::StrCat(y_, ":", 0), 0, debug_identity);
Notification callbacks_done;
int comp_callback_count = 0;
int val_callback_count = 0;
debug_gateway.SetNodeCompletionCallback(
[&mu, &callbacks_done, &comp_callback_count, &debug_identity_node_name](
const string& node_name, const bool any_output) {
mutex_lock l(mu);
if (node_name == debug_identity_node_name) {
comp_callback_count++;
}
});
volatile int val_callback_count = 0;
debug_gateway.SetNodeValueCallback(
[this, &mu, &val_callback_count, &debug_identity_node_name,
[this, &mu, &val_callback_count, &a_debug_identity_node_name,
&x_debug_identity_node_name, &y_debug_identity_node_name,
&debug_identity_tensor_vals,
&callbacks_done](const string& node_name, const int output_slot,
const Tensor& tensor_value, const bool is_ref) {
mutex_lock l(mu);
if (node_name == debug_identity_node_name && output_slot == 0) {
if (node_name == a_debug_identity_node_name && output_slot == 0) {
debug_identity_tensor_vals["a"] = tensor_value;
val_callback_count++;
} else if (node_name == x_debug_identity_node_name &&
output_slot == 0) {
// output_slot == 0 carries the debug signal.
debug_identity_tensor_vals.push_back(tensor_value);
debug_identity_tensor_vals["x"] = tensor_value;
val_callback_count++;
} else if (node_name == y_debug_identity_node_name &&
output_slot == 0) {
debug_identity_tensor_vals["y"] = tensor_value;
val_callback_count++;
}
@ -389,19 +397,41 @@ TEST_F(SessionDebugMinusAXTest,
}
});
int run_counter = 0;
mutex run_lock;
// Function to be executed concurrently.
auto fn = [this, &session, output_names, target_nodes, &debug_identity]() {
// Create unique debug tensor watch options for each of the two concurrent
auto fn = [this, &run_lock, &run_counter, &session, output_names,
target_nodes, &debug_identity]() {
// Create unique debug tensor watch options for each of the concurrent
// run calls.
RunOptions run_opts;
run_opts.set_output_partition_graphs(true);
DebugTensorWatch* tensor_watch_opts =
run_opts.add_debug_tensor_watch_opts();
tensor_watch_opts->set_node_name(y_);
tensor_watch_opts->set_output_slot(0);
tensor_watch_opts->add_debug_ops(debug_identity);
{
// Let the concurrent runs watch different tensors.
mutex_lock l(run_lock);
if (run_counter == 0) {
// Let the 1st concurrent run watch a.
tensor_watch_opts->set_node_name(a_);
} else if (run_counter == 1) {
// Let the 2nd concurrent watch x.
tensor_watch_opts->set_node_name(x_);
} else if (run_counter == 2) {
// Let the 3rd concurrent watch y.
tensor_watch_opts->set_node_name(y_);
}
run_counter++;
}
// Run the graph.
RunMetadata run_metadata;
std::vector<std::pair<string, Tensor>> inputs;
@ -436,15 +466,26 @@ TEST_F(SessionDebugMinusAXTest,
{
mutex_lock l(mu);
ASSERT_EQ(kConcurrentRuns, comp_callback_count);
ASSERT_EQ(kConcurrentRuns, val_callback_count);
ASSERT_EQ(kConcurrentRuns, debug_identity_tensor_vals.size());
for (int i = 0; i < kConcurrentRuns; ++i) {
ASSERT_EQ(TensorShape({2, 1}), debug_identity_tensor_vals[i].shape());
auto mat_identity = debug_identity_tensor_vals[i].matrix<float>();
ASSERT_EQ(5.0, mat_identity(0, 0));
ASSERT_EQ(-1.0, mat_identity(1, 0));
}
ASSERT_EQ(TensorShape({2, 2}), debug_identity_tensor_vals["a"].shape());
auto a_mat_identity = debug_identity_tensor_vals["a"].matrix<float>();
ASSERT_EQ(3.0, a_mat_identity(0, 0));
ASSERT_EQ(2.0, a_mat_identity(0, 1));
ASSERT_EQ(-1.0, a_mat_identity(1, 0));
ASSERT_EQ(0.0, a_mat_identity(1, 1));
ASSERT_EQ(TensorShape({2, 1}), debug_identity_tensor_vals["x"].shape());
auto x_mat_identity = debug_identity_tensor_vals["x"].matrix<float>();
ASSERT_EQ(1.0, x_mat_identity(0, 0));
ASSERT_EQ(1.0, x_mat_identity(1, 0));
ASSERT_EQ(TensorShape({2, 1}), debug_identity_tensor_vals["y"].shape());
auto y_mat_identity = debug_identity_tensor_vals["y"].matrix<float>();
ASSERT_EQ(5.0, y_mat_identity(0, 0));
ASSERT_EQ(-1.0, y_mat_identity(1, 0));
}
}
@ -499,25 +540,22 @@ TEST_F(SessionDebugOutputSlotWithoutOngoingEdgeTest,
Notification callbacks_done;
debug_gateway.SetNodeCompletionCallback(
[&mu, &callbacks_done](const string& node_name, const bool any_output) {
mutex_lock l(mu);
if (node_name == "_SINK" && !callbacks_done.HasBeenNotified()) {
callbacks_done.Notify();
}
});
std::vector<Tensor> debug_identity_tensor_vals;
debug_gateway.SetNodeValueCallback(
[this, &mu, &debug_identity_node_name, &debug_identity_tensor_vals](
const string& node_name, const int output_slot,
const Tensor& tensor_value, const bool is_ref) {
mutex_lock l(mu);
debug_gateway.SetNodeValueCallback([this, &mu, &callbacks_done,
&debug_identity_node_name,
&debug_identity_tensor_vals](
const string& node_name, const int output_slot,
const Tensor& tensor_value, const bool is_ref) {
mutex_lock l(mu);
if (node_name == debug_identity_node_name && output_slot == 0) {
debug_identity_tensor_vals.push_back(tensor_value);
}
});
if (node_name == debug_identity_node_name && output_slot == 0) {
debug_identity_tensor_vals.push_back(tensor_value);
if (!callbacks_done.HasBeenNotified()) {
callbacks_done.Notify();
}
}
});
// Add DebugIdentity watch on c:0, which does not have an outgoing edge.
RunOptions run_opts;

View File

@ -24,6 +24,30 @@ limitations under the License.
namespace tensorflow {
const string SummarizeDebugTensorWatches(
const protobuf::RepeatedPtrField<DebugTensorWatch>& watches) {
std::ostringstream oss;
for (const DebugTensorWatch& watch : watches) {
string tensor_name =
strings::StrCat(watch.node_name(), ":", watch.output_slot());
oss << tensor_name << "|";
for (const string& debug_op : watch.debug_ops()) {
oss << debug_op << ",";
}
oss << "@";
for (const string& debug_url : watch.debug_urls()) {
oss << debug_url << ",";
}
oss << ";";
}
return oss.str();
}
// static
Status DebugNodeInserter::InsertNodes(
const protobuf::RepeatedPtrField<DebugTensorWatch>& watches, Graph* graph,

View File

@ -27,6 +27,10 @@ limitations under the License.
namespace tensorflow {
// Returns a summary string for RepeatedPtrFields of DebugTensorWatches.
const string SummarizeDebugTensorWatches(
const protobuf::RepeatedPtrField<DebugTensorWatch>& watches);
class DebugNodeInserter {
public:
// EXPERIMENTAL: Insert special debug ops (e.g., DebugIdentity) to graph for

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/log_memory.h"
@ -207,6 +208,11 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
if (!s.ok()) {
break;
}
unit->graph = subgraph;
unit->build_cost_model = graph_options.build_cost_model();
if (unit->build_cost_model > 0) {
skip_cost_models_ = false;
}
}
return s;
}
@ -319,6 +325,7 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
const ExecutorOpts& opts,
StepStatsCollector* collector,
CostGraphDef* cost_graph,
CancellationManager* cancellation_manager,
const NamedTensors& in, StatusCallback done) {
// Lookup an item. Holds one ref while executing.
@ -348,7 +355,7 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
return;
}
StartParallelExecutors(handle, item, rendezvous, collector,
StartParallelExecutors(handle, item, rendezvous, collector, cost_graph,
cancellation_manager,
[this, item, rendezvous, done](const Status& s) {
done(s);
@ -360,6 +367,7 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
void GraphMgr::StartParallelExecutors(const string& handle, Item* item,
Rendezvous* rendezvous,
StepStatsCollector* collector,
CostGraphDef* cost_graph,
CancellationManager* cancellation_manager,
StatusCallback done) {
const int num_units = item->units.size();
@ -367,7 +375,9 @@ void GraphMgr::StartParallelExecutors(const string& handle, Item* item,
ResourceMgr* step_resource_manager = new ResourceMgr;
// NOTE: Transfer one ref of rendezvous and item.
ExecutorBarrier* barrier = new ExecutorBarrier(
num_units, rendezvous, [step_resource_manager, done](const Status& s) {
num_units, rendezvous, [this, item, collector, cost_graph,
step_resource_manager, done](const Status& s) {
BuildCostModel(item, collector, cost_graph);
done(s);
delete step_resource_manager;
});
@ -393,4 +403,24 @@ void GraphMgr::StartParallelExecutors(const string& handle, Item* item,
}
}
// Assembles cost-model data gathered during a step. Costs are only built
// when a stats collector was active for the step AND at least one execution
// unit requested cost modeling at graph-registration time; otherwise this is
// a no-op. When 'cost_graph' is non-null, the accumulated per-graph costs are
// also serialized into it for the response.
void GraphMgr::BuildCostModel(Item* item, StepStatsCollector* collector,
                              CostGraphDef* cost_graph) {
  if (collector == nullptr || skip_cost_models_) {
    return;
  }
  // Map each device name to the subgraph whose costs should be recorded.
  // Only units that opted in (build_cost_model > 0) participate.
  std::unordered_map<string, const Graph*> graphs_by_device;
  for (const auto& unit : item->units) {
    if (unit.build_cost_model > 0) {
      graphs_by_device[unit.device->name()] = unit.graph;
    }
  }
  collector->BuildCostModel(&cost_model_manager_, graphs_by_device);
  // Export the per-graph cost models into the outgoing CostGraphDef, if any.
  if (cost_graph != nullptr) {
    for (const auto& unit : item->units) {
      cost_model_manager_.AddToCostGraphDef(unit.graph, cost_graph);
    }
  }
}
} // end namespace tensorflow

View File

@ -19,9 +19,11 @@ limitations under the License.
#include <unordered_map>
#include <vector>
#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@ -73,6 +75,7 @@ class GraphMgr {
typedef std::function<void(const Status&)> StatusCallback;
void ExecuteAsync(const string& handle, const int64 step_id,
const ExecutorOpts& opts, StepStatsCollector* collector,
CostGraphDef* cost_graph,
CancellationManager* cancellation_manager,
const NamedTensors& in, StatusCallback done);
@ -89,9 +92,12 @@ class GraphMgr {
typedef GraphMgr ME;
struct ExecutionUnit {
Graph* graph = nullptr;
Device* device = nullptr;
Executor* root = nullptr;
FunctionLibraryRuntime* lib = nullptr;
// Build the cost model if this value is strictly positive.
int64 build_cost_model = 0;
};
struct Item : public core::RefCounted {
@ -117,6 +123,8 @@ class GraphMgr {
// Not owned.
const WorkerEnv* worker_env_;
CostModelManager cost_model_manager_;
// Owned.
mutex mu_;
int64 next_id_ GUARDED_BY(mu_) = 0;
@ -131,9 +139,17 @@ class GraphMgr {
void StartParallelExecutors(const string& handle, Item* item,
Rendezvous* rendezvous,
StepStatsCollector* collector,
CostGraphDef* cost_graph,
CancellationManager* cancellation_manager,
StatusCallback done);
// Don't attempt to process cost models unless explicitely requested for at
// least one of the items.
bool skip_cost_models_ = true;
void BuildCostModel(Item* item, StepStatsCollector* collector,
CostGraphDef* cost_graph);
Status SendInputsToRendezvous(Rendezvous* rendezvous, const NamedTensors& in);
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out);

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/scheduler.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
@ -58,6 +59,7 @@ struct PerStepState {
Microseconds end_micros = Microseconds(0);
std::vector<StepStats> step_stats; // per partition
StepStats rpc_stats; // for RPC layer
CostGraphDef cost_graph;
};
// MasterSession wraps SimpleClientGraph in a reference counted object.
@ -178,7 +180,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
// Post-processing of any runtime statistics gathered during execution.
void ProcessStats(const MasterEnv* env, int64 step_id, PerStepState* pss,
SimpleGraphExecutionState* execution_state,
ProfileHandler* ph, RunStepResponse* resp);
ProfileHandler* ph, const RunStepRequest& req,
RunStepResponse* resp);
void ProcessDeviceStats(ProfileHandler* ph,
const SimpleGraphExecutionState* execution_state,
const DeviceStepStats& ds, bool is_rpc);
@ -480,17 +483,6 @@ class RunManyGraphs {
TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
};
// Returns the sampling period used to decide how often to collect cost
// statistics, given the current step count 'x': every step for the first 10
// steps, then backing off by a factor of 10 per decade of steps, bottoming
// out at once every 1000 steps.
int64 CostFrequency(int64 x) {
  if (x >= 1000) return 1000;  // 0.1%
  if (x >= 100) return 100;    // 1%
  if (x >= 10) return 10;      // 10%
  return 1;                    // 100%
}
Status MasterSession::ReffedClientGraph::RunPartitions(
const MasterEnv* env, int64 step_id, int64 execution_count,
@ -604,6 +596,12 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
if (pss->collect_timeline && calls.get(i)->resp.has_step_stats()) {
pss->step_stats[i].Swap(calls.get(i)->resp.mutable_step_stats());
}
if (pss->collect_costs && calls.get(i)->resp.has_cost_graph()) {
for (int j = 0; j < calls.get(i)->resp.cost_graph().node_size(); ++j) {
resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
calls.get(i)->resp.mutable_cost_graph()->mutable_node(j));
}
}
}
}
return status;
@ -679,7 +677,7 @@ void MasterSession::ReffedClientGraph::CleanupPartitionsAsync(
void MasterSession::ReffedClientGraph::ProcessStats(
const MasterEnv* env, int64 step_id, PerStepState* pss,
SimpleGraphExecutionState* execution_state, ProfileHandler* ph,
RunStepResponse* resp) {
const RunStepRequest& req, RunStepResponse* resp) {
if (!pss->collect_costs && !pss->collect_timeline) return;
// Out-of-band logging data is collected now, during post-processing.
@ -689,9 +687,6 @@ void MasterSession::ReffedClientGraph::ProcessStats(
}
for (size_t i = 0; i < partitions_.size(); ++i) {
const StepStats& ss = pss->step_stats[i];
if (pss->collect_costs) {
execution_state->UpdateCostsFromStats(ss);
}
if (ph) {
for (const auto& ds : ss.dev_stats()) {
ProcessDeviceStats(ph, execution_state, ds, false /*is_rpc*/);
@ -717,7 +712,7 @@ void MasterSession::ReffedClientGraph::ProcessStats(
stats_publisher_->PublishStatsProto(step_stats_proto);
// Copy the stats back, but only for on-demand profiling to avoid slowing
// down calls that trigger the automatic profiling.
if (session_opts_.config.graph_options().timeline_step() <= 0) {
if (req.options().trace_level() == RunOptions::FULL_TRACE) {
resp->mutable_metadata()->mutable_step_stats()->Swap(&step_stats_proto);
}
}
@ -1063,7 +1058,17 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
std::unique_ptr<ProfileHandler> ph;
pss.collect_timeline = req->options().trace_level() == RunOptions::FULL_TRACE;
pss.collect_costs = (0 == (count % CostFrequency(count)));
// Build the cost model every 'build_cost_model_every' steps after skipping an
// initial 'build_cost_model_after' steps.
const int64 build_cost_model_after =
session_opts_.config.graph_options().build_cost_model_after();
const int64 build_cost_model_every =
session_opts_.config.graph_options().build_cost_model();
pss.collect_costs =
build_cost_model_every > 0 &&
((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
ph = rcg->GetProfileHandler(step_id, count, req->options());
if (ph) {
pss.collect_timeline = true;
@ -1078,7 +1083,7 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
// Schedule post-processing and cleanup to be done asynchronously.
rcg->Ref();
rcg->ProcessStats(env_, step_id, &pss, execution_state_.get(), ph.get(),
rcg->ProcessStats(env_, step_id, &pss, execution_state_.get(), ph.get(), *req,
resp);
rcg->CleanupPartitionsAsync(step_id, [rcg](const Status& s) {
if (!s.ok()) {

View File

@ -329,7 +329,8 @@ class GrpcWorkerService : public AsyncServiceInterface {
return;
}
StepStatsCollector* collector = nullptr;
if (call->request.exec_opts().record_timeline()) {
if (call->request.exec_opts().record_timeline() ||
call->request.exec_opts().record_costs()) {
collector = new StepStatsCollector(call->response.mutable_step_stats());
// TODO(mrry,pbar): GPU tracing for distributed steps.
}
@ -345,9 +346,10 @@ class GrpcWorkerService : public AsyncServiceInterface {
cancellation_manager_->RegisterCallback(token,
[cm]() { cm->StartCancel(); });
}
CostGraphDef* cost_graph = call->response.mutable_cost_graph();
env_->graph_mgr->ExecuteAsync(
call->request.graph_handle(), step_id, call->request.exec_opts(),
collector, cm, in,
collector, cost_graph, cm, in,
[this, step_id, call, cm, out, token, collector](Status s) {
if (s.ok()) {
env_->graph_mgr->RecvOutputs(step_id, out);

View File

@ -56,7 +56,7 @@ TEST(CommonShapeFnsTest, NoOutputShapeTest) {
.Input({{"data", 0, DT_FLOAT}})
.Finalize(&def));
InferenceContext c(&def, op_def, {S({}), S({10})}, {});
InferenceContext c(&def, op_def, {S({}), S({10})}, {}, {});
TF_EXPECT_OK(NoOutputs(&c));
EXPECT_EQ(0, c.num_outputs());
}
@ -74,14 +74,14 @@ TEST(CommonShapeFnsTest, ScalarShapeTest) {
NodeDefBuilder("test", "L2Loss").Input("t", 0, DT_FLOAT).Finalize(&def));
{
InferenceContext c(&def, op_def, {S({})}, {});
InferenceContext c(&def, op_def, {S({})}, {}, {});
TF_EXPECT_OK(ScalarShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(0, c.Rank(output));
}
{
InferenceContext c(&def, op_def, {S({1, 23, 4, 4, 2})}, {});
InferenceContext c(&def, op_def, {S({1, 23, 4, 4, 2})}, {}, {});
TF_EXPECT_OK(ScalarShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(0, c.Rank(output));
@ -108,7 +108,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Finalize(&def));
{
InferenceContext c(&def, op_def, {S({2, 3}), S({3, 4})}, {});
InferenceContext c(&def, op_def, {S({2, 3}), S({3, 4})}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@ -117,7 +117,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Unknown inner dimension for one
InferenceContext c(&def, op_def, {S({2, -1}), S({3, 4})}, {});
InferenceContext c(&def, op_def, {S({2, -1}), S({3, 4})}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@ -126,7 +126,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Invalid rank.
InferenceContext c(&def, op_def, {S({2}), S({3, 4})}, {});
InferenceContext c(&def, op_def, {S({2}), S({3, 4})}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@ -136,7 +136,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Unknown outer dimension
InferenceContext c(&def, op_def, {S({2, 3}), S({3, -1})}, {});
InferenceContext c(&def, op_def, {S({2, 3}), S({3, -1})}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@ -145,7 +145,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Inner shapes not compatible
InferenceContext c(&def, op_def, {S({2, 5}), S({3, 4})}, {});
InferenceContext c(&def, op_def, {S({2, 5}), S({3, 4})}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@ -156,7 +156,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Inner shapes not compatible
InferenceContext c(&def, op_def, {S({2, 5, 3}), S({3, 5, 4})}, {});
InferenceContext c(&def, op_def, {S({2, 5, 3}), S({3, 5, 4})}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@ -174,7 +174,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Attr("type", DT_FLOAT)
.Finalize(&def));
InferenceContext c(&def, op_def, {S({3, 2}), S({3, 4})}, {});
InferenceContext c(&def, op_def, {S({3, 2}), S({3, 4})}, {}, {});
auto s = MatMulShape(&c);
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@ -191,7 +191,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Attr("type", DT_FLOAT)
.Finalize(&def));
InferenceContext c(&def, op_def, {S({2, 3}), S({4, 3})}, {});
InferenceContext c(&def, op_def, {S({2, 3}), S({4, 3})}, {}, {});
auto s = MatMulShape(&c);
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@ -215,7 +215,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Finalize(&def));
{
InferenceContext c(&def, op_def, {S({2, 10}), S({10})}, {});
InferenceContext c(&def, op_def, {S({2, 10}), S({10})}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@ -224,7 +224,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Unknown ranks.
InferenceContext c(&def, op_def, {Unknown(), Unknown()}, {});
InferenceContext c(&def, op_def, {Unknown(), Unknown()}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_FALSE(c.RankKnown(output));
@ -232,7 +232,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Rank > 2
InferenceContext c(&def, op_def, {S({4, 3, 4, 2, 15}), S({15})}, {});
InferenceContext c(&def, op_def, {S({4, 3, 4, 2, 15}), S({15})}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[4,3,4,2,15]", c.DebugString(output));
@ -245,7 +245,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(&def, op_def, {S({2, 3, 4, 5}), S({3})}, {});
InferenceContext c(&def, op_def, {S({2, 3, 4, 5}), S({3})}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[2,3,4,5]", c.DebugString(output));
@ -258,7 +258,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {});
InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {},
{});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[8,6,4,2,3,4,5]", c.DebugString(output));
@ -271,7 +272,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(&def, op_def, {S({10, 11, 12}), S({10})}, {});
InferenceContext c(&def, op_def, {S({10, 11, 12}), S({10})}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[10,11,12]", c.DebugString(output));
@ -279,7 +280,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Input rank not high enough
InferenceContext c(&def, op_def, {S({3}), S({3})}, {});
InferenceContext c(&def, op_def, {S({3}), S({3})}, {}, {});
EXPECT_FALSE(BiasAddShape(&c).ok());
}
@ -291,7 +292,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
// NCHW format
InferenceContext c(&def, op_def, {S({2, 3}), S({3})}, {});
InferenceContext c(&def, op_def, {S({2, 3}), S({3})}, {}, {});
EXPECT_FALSE(BiasAddShape(&c).ok());
}
}
@ -310,7 +311,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Finalize(&def));
{
InferenceContext c(&def, op_def, {S({2, 10})}, {});
InferenceContext c(&def, op_def, {S({2, 10})}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@ -318,7 +319,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
{
// Rank > 2
InferenceContext c(&def, op_def, {S({5, 7, 2, 10})}, {});
InferenceContext c(&def, op_def, {S({5, 7, 2, 10})}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@ -330,7 +331,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(&def, op_def, {S({2, 3, 4, 5})}, {});
InferenceContext c(&def, op_def, {S({2, 3, 4, 5})}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
@ -342,7 +343,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5})}, {});
InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5})}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
@ -354,7 +355,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(&def, op_def, {S({10, 11, 12})}, {});
InferenceContext c(&def, op_def, {S({10, 11, 12})}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@ -362,7 +363,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
{
// Input rank not high enough
InferenceContext c(&def, op_def, {S({3})}, {});
InferenceContext c(&def, op_def, {S({3})}, {}, {});
EXPECT_FALSE(BiasAddGradShape(&c).ok());
}
@ -373,7 +374,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
// NCHW format
InferenceContext c(&def, op_def, {S({2, 3})}, {});
InferenceContext c(&def, op_def, {S({2, 3})}, {}, {});
EXPECT_FALSE(BiasAddGradShape(&c).ok());
}
}

View File

@ -400,6 +400,9 @@ class FunctionLibraryRuntime {
// Returns a debug string showing the definition of the function of
// 'handle'.
virtual string DebugString(Handle handle) = 0;
// Returns the graph version number.
virtual int graph_def_version() = 0;
};
// To register a gradient function for a builtin op, one should use

View File

@ -30,9 +30,10 @@ constexpr int64 InferenceContext::kUnknownDim;
InferenceContext::InferenceContext(
const NodeDef* node_def, const OpDef& op_def,
const std::vector<TensorShapeProto>& input_shapes,
const std::vector<const Tensor*>& input_tensors)
const std::vector<const Tensor*>& input_tensors,
const std::vector<ShapeHandle>& input_tensors_as_shapes)
: node_def_(*CHECK_NOTNULL(node_def)) {
PreInputInit(op_def, input_tensors);
PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
if (!construction_status_.ok()) return;
for (const TensorShapeProto& p : input_shapes) {
ShapeHandle shape;
@ -48,9 +49,10 @@ InferenceContext::InferenceContext(
InferenceContext::InferenceContext(
const NodeDef* node_def, const OpDef& op_def,
const std::vector<ShapeHandle>& input_shapes,
const std::vector<const Tensor*>& input_tensors)
const std::vector<const Tensor*>& input_tensors,
const std::vector<ShapeHandle>& input_tensors_as_shapes)
: node_def_(*CHECK_NOTNULL(node_def)) {
PreInputInit(op_def, input_tensors);
PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
if (!construction_status_.ok()) return;
inputs_ = input_shapes;
PostInputInit();
@ -106,8 +108,10 @@ Status InferenceContext::output(StringPiece output_name,
}
void InferenceContext::PreInputInit(
const OpDef& op_def, const std::vector<const Tensor*>& input_tensors) {
const OpDef& op_def, const std::vector<const Tensor*>& input_tensors,
const std::vector<ShapeHandle>& input_tensors_as_shapes) {
input_tensors_ = input_tensors;
input_tensors_as_shapes_ = input_tensors_as_shapes;
construction_status_ =
NameRangesForNode(node_def_, op_def, &input_name_map_, &output_name_map_);
@ -139,6 +143,7 @@ void InferenceContext::PostInputInit() {
CHECK_LE(input_tensors_.size(), inputs_.size());
input_tensors_.resize(inputs_.size());
requested_input_tensor_.resize(inputs_.size());
requested_input_tensor_as_partial_shape_.resize(inputs_.size());
}
bool InferenceContext::FullyDefined(ShapeHandle s) {
@ -470,11 +475,24 @@ Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
ShapeHandle input_shape;
TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape));
const Tensor* t = input_tensor(input_idx);
if (input_idx < input_tensors_as_shapes_.size() &&
input_tensors_as_shapes_[input_idx].IsSet() &&
RankKnown(input_tensors_as_shapes_[input_idx])) {
*out = input_tensors_as_shapes_[input_idx];
return Status::OK();
}
requested_input_tensor_as_partial_shape_[input_idx] = true;
return MakeShapeFromTensor(input_tensor(input_idx), input_shape, out);
}
Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
ShapeHandle tensor_shape,
ShapeHandle* out) {
if (t == nullptr) {
// Shape tensor is not known, but if the shape of the shape tensor is then
// the right number of unknown dims can be created.
DimensionHandle shape_dim = Dim(input_shape, 0);
DimensionHandle shape_dim = Dim(tensor_shape, 0);
if (!ValueKnown(shape_dim)) {
return ReturnUnknownShape(out);
}
@ -493,12 +511,24 @@ Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
if (t->dtype() == DataType::DT_INT32) {
auto flat_t = t->flat<int32>();
for (int i = 0; i < flat_t.size(); ++i) {
dims.push_back(MakeDim(flat_t(i)));
const int32 val = flat_t(i);
if (val < -1) {
return errors::InvalidArgument(
"Invalid value in tensor used for shape: ", val);
}
// -1 will become an unknown dim.
dims.push_back(MakeDim(val));
}
} else if (t->dtype() == DataType::DT_INT64) {
auto flat_t = t->flat<int64>();
for (int i = 0; i < flat_t.size(); ++i) {
dims.push_back(MakeDim(flat_t(i)));
const int64 val = flat_t(i);
if (val < -1) {
return errors::InvalidArgument(
"Invalid value in tensor used for shape: ", val);
}
// -1 will become an unknown dim.
dims.push_back(MakeDim(val));
}
} else {
*out = nullptr;
@ -558,24 +588,27 @@ Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
return Status::OK();
}
Status InferenceContext::Divide(DimensionHandle dividend, int64 divisor,
Status InferenceContext::Divide(DimensionHandle dividend,
DimensionOrConstant divisor,
bool evenly_divisible, DimensionHandle* out) {
if (divisor == 1) {
const int64 divisor_value = Value(divisor);
if (divisor_value == 1) {
*out = dividend;
} else if (!ValueKnown(dividend)) {
} else if (!ValueKnown(dividend) ||
(divisor.dim.IsSet() && !ValueKnown(divisor.dim))) {
*out = UnknownDim();
} else {
const int64 v = Value(dividend);
if (divisor <= 0) {
if (divisor_value <= 0) {
return errors::InvalidArgument("Divisor must be positive but is ",
divisor);
divisor_value);
}
if (evenly_divisible && (v % divisor) != 0) {
if (evenly_divisible && (v % divisor_value) != 0) {
return errors::InvalidArgument(
"Dimension size must be evenly divisible by ", divisor, " but is ",
v);
"Dimension size must be evenly divisible by ", divisor_value,
" but is ", v);
}
*out = MakeDim(v / divisor);
*out = MakeDim(v / divisor_value);
}
return Status::OK();
}

View File

@ -136,17 +136,33 @@ class InferenceContext {
// <input_tensors> is NULL-padded to be the same size as <input_shapes>.
//
// Elements of <input_tensors_as_shapes> are used for when a shape function
// makes a call to MakeShapeFromShapeTensor; in particular, when the
// input_tensors[i] is nullptr but the shape represented by it is partially
// known from analysis of the graph.
// <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
// Values of <input_tensors_as_shapes> do not need to outlive the context.
//
// REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
InferenceContext(const NodeDef* node_def, const OpDef& op_def,
const std::vector<ShapeHandle>& input_shapes,
const std::vector<const Tensor*>& input_tensors);
const std::vector<const Tensor*>& input_tensors,
const std::vector<ShapeHandle>& input_tensors_as_shapes);
// <input_tensors> is NULL-padded to be the same size as <input_shapes>.
//
// Elements of <input_tensors_as_shapes> are used for when a shape function
// makes a call to MakeShapeFromShapeTensor; in particular, when the
// input_tensors[i] is nullptr but the shape represented by it is partially
// known from analysis of the graph.
// <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
// Values of <input_tensors_as_shapes> do not need to outlive the context.
//
// REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
InferenceContext(const NodeDef* node_def, const OpDef& op_def,
const std::vector<TensorShapeProto>& input_shapes,
const std::vector<const Tensor*>& input_tensors);
const std::vector<const Tensor*>& input_tensors,
const std::vector<ShapeHandle>& input_tensors_as_shapes);
~InferenceContext();
@ -180,10 +196,21 @@ class InferenceContext {
return requested_input_tensor_[idx];
}
// Returns true if MakeShapeFromInputTensor was called but the constant
// input_tensor was not present.
bool requested_input_tensor_as_partial_shape(int idx) const {
return requested_input_tensor_as_partial_shape_[idx];
}
void set_input_tensors(const std::vector<const Tensor*>& input_tensors) {
input_tensors_ = input_tensors;
}
void set_input_tensors_as_shapes(
const std::vector<ShapeHandle>& input_tensors_as_shapes) {
input_tensors_as_shapes_ = input_tensors_as_shapes;
}
void set_output(int idx, ShapeHandle shape) { outputs_[idx] = shape; }
Status set_output(StringPiece output_name,
const std::vector<ShapeHandle>& shapes);
@ -336,8 +363,8 @@ class InferenceContext {
// Returns in <out> the result of dividing <dividend> by <divisor>.
// Returns an error if <divisor> is not positive or if <evenly_divisible>
// and <divisor> does not evenly divide <dividend>.
Status Divide(DimensionHandle dividend, int64 divisor, bool evenly_divisible,
DimensionHandle* out);
Status Divide(DimensionHandle dividend, DimensionOrConstant divisor,
bool evenly_divisible, DimensionHandle* out);
// Returns in <out> the sum of <first> and <second>.
Status Add(DimensionHandle first, DimensionOrConstant second,
@ -408,6 +435,15 @@ class InferenceContext {
return Status::OK();
}
// Note that shape functions should usually call MakeShapeFromShapeTensor,
// as it does more analysis to provide partial shapes.
//
// Returns in <out> a new shape whose dimension sizes come from tensor <t>.
// The tensor must be a 1-dimensional int32 or int64 tensor. If <t> is NULL,
// then an unknown shape is returned.
Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape,
ShapeHandle* out);
private:
// Creates and stores shapes for use in InferenceContext.
class ShapeManager {
@ -443,7 +479,8 @@ class InferenceContext {
// Shared initialization across the two constructors. Remove
// once we get rid of one of them.
void PreInputInit(const OpDef& op_def,
const std::vector<const Tensor*>& input_tensors);
const std::vector<const Tensor*>& input_tensors,
const std::vector<ShapeHandle>& input_tensors_as_shapes);
void PostInputInit();
DimensionHandle GetDimension(const DimensionOrConstant& d);
@ -463,11 +500,15 @@ class InferenceContext {
ShapeManager shape_manager_;
// inputs_ and outputs_ refer to values from `shape_manager_`.
// inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
// `shape_manager_`.
std::vector<ShapeHandle> inputs_;
std::vector<const Tensor*> input_tensors_;
std::vector<bool> requested_input_tensor_;
std::vector<ShapeHandle> outputs_;
// Can have fewer elements than inputs_.
std::vector<ShapeHandle> input_tensors_as_shapes_;
std::vector<bool> requested_input_tensor_as_partial_shape_;
const NodeDef& node_def_;
NameRangeMap input_name_map_;

View File

@ -71,7 +71,7 @@ TEST_F(ShapeInferenceTest, InputOutputByName) {
.Attr("N", 3)
.Input(FakeInput(DT_FLOAT))
.Finalize(&def);
InferenceContext c(&def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, {});
InferenceContext c(&def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, {}, {});
EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0))));
EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1))));
@ -107,7 +107,7 @@ static OpDef MakeOpDef(int num_inputs, int num_outputs) {
TEST_F(ShapeInferenceTest, DimensionOrConstant) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 1), {Unknown()}, {});
InferenceContext c(&def, MakeOpDef(1, 1), {Unknown()}, {}, {});
EXPECT_EQ(InferenceContext::kUnknownDim,
c.Value(InferenceContext::kUnknownDim));
EXPECT_EQ(1, c.Value(1));
@ -122,7 +122,7 @@ TEST_F(ShapeInferenceTest, Run) {
NodeDef def;
def.set_name("foo");
def.set_op("foo_op");
InferenceContext c(&def, MakeOpDef(3, 2), {S({1})}, {});
InferenceContext c(&def, MakeOpDef(3, 2), {S({1})}, {}, {});
{
auto fn = [](InferenceContext* c) {
@ -154,7 +154,7 @@ TEST_F(ShapeInferenceTest, Run) {
TEST_F(ShapeInferenceTest, RankAndDimInspection) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 2), {Unknown(), S({1, -1, 3}), S({})},
{});
{}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(2, c.num_outputs());
@ -195,7 +195,7 @@ TEST_F(ShapeInferenceTest, RankAndDimInspection) {
TEST_F(ShapeInferenceTest, NumElements) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 2),
{Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {});
{Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {});
EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(0))));
EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(1))));
@ -208,7 +208,7 @@ TEST_F(ShapeInferenceTest, NumElements) {
TEST_F(ShapeInferenceTest, WithRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {});
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@ -246,7 +246,7 @@ TEST_F(ShapeInferenceTest, WithRank) {
TEST_F(ShapeInferenceTest, WithRankAtMost) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {});
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@ -284,7 +284,7 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) {
TEST_F(ShapeInferenceTest, WithRankAtLeast) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {});
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@ -322,7 +322,7 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) {
TEST_F(ShapeInferenceTest, WithValue) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, -1})}, {});
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, -1})}, {}, {});
auto d0 = c.Dim(c.input(0), 0);
auto d1 = c.Dim(c.input(0), 1);
@ -363,7 +363,7 @@ TEST_F(ShapeInferenceTest, WithValue) {
TEST_F(ShapeInferenceTest, MergeDim) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {});
InferenceContext c(&def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {}, {});
auto d2 = c.Dim(c.input(0), 0);
auto d_unknown = c.Dim(c.input(0), 1);
@ -412,7 +412,7 @@ TEST_F(ShapeInferenceTest, MergeShape) {
InferenceContext c(&def, MakeOpDef(7, 2),
{Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}),
Unknown(), S({1})},
{});
{}, {});
auto s_unknown = c.input(0);
auto s_1_2 = c.input(1);
@ -483,7 +483,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
{
Unknown(), S({-1, 2}), S({1, -1, 3}), S({2, 4}),
},
{});
{}, {});
auto s_unknown = c.input(0);
auto s_u_2 = c.input(1);
@ -536,7 +536,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
TEST_F(ShapeInferenceTest, Subshape) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(2, 2), {S({1, 2, 3, -1, 5}), Unknown()},
{});
{}, {});
ShapeHandle unknown = c.input(1);
ShapeHandle out;
@ -611,7 +611,7 @@ TEST_F(ShapeInferenceTest, Subshape) {
TEST_F(ShapeInferenceTest, Concatenate) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 2),
{S({1, -1, 3}), S({4, 5}), Unknown()}, {});
{S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@ -637,7 +637,7 @@ TEST_F(ShapeInferenceTest, Concatenate) {
TEST_F(ShapeInferenceTest, ReplaceDim) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, {});
InferenceContext c(&def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, {}, {});
auto in = c.input(0);
auto unknown = c.input(1);
@ -668,7 +668,7 @@ TEST_F(ShapeInferenceTest, ReplaceDim) {
TEST_F(ShapeInferenceTest, MakeShape) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {});
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {}, {});
std::vector<DimensionHandle> dims;
auto in0 = c.input(0);
@ -693,7 +693,7 @@ TEST_F(ShapeInferenceTest, MakeShape) {
TEST_F(ShapeInferenceTest, UnknownShape) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(&def, MakeOpDef(0, 2), empty, {});
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
auto u0 = c.UnknownShape();
auto u1 = c.UnknownShape();
@ -705,7 +705,7 @@ TEST_F(ShapeInferenceTest, UnknownShape) {
TEST_F(ShapeInferenceTest, Scalar) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(&def, MakeOpDef(0, 2), empty, {});
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
auto s0 = c.Scalar();
EXPECT_EQ("[]", c.DebugString(s0));
@ -716,7 +716,7 @@ TEST_F(ShapeInferenceTest, Scalar) {
TEST_F(ShapeInferenceTest, Vector) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(&def, MakeOpDef(0, 2), empty, {});
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
auto s0 = c.Vector(1);
EXPECT_EQ("[1]", c.DebugString(s0));
@ -732,7 +732,7 @@ TEST_F(ShapeInferenceTest, Vector) {
TEST_F(ShapeInferenceTest, Matrix) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(&def, MakeOpDef(0, 2), empty, {});
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
auto s0 = c.Matrix(1, 2);
EXPECT_EQ("[1,2]", c.DebugString(s0));
@ -754,7 +754,7 @@ TEST_F(ShapeInferenceTest, Matrix) {
TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
auto create = [&](Tensor* t) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 0), {Unknown()}, {t});
InferenceContext c(&def, MakeOpDef(1, 0), {Unknown()}, {t}, {});
ShapeHandle out;
Status s = c.MakeShapeFromShapeTensor(0, &out);
if (s.ok()) {
@ -774,6 +774,9 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
t = ::tensorflow::test::AsTensor<int64>({3, 2, 1});
EXPECT_EQ("[3,2,1]", create(&t));
t = ::tensorflow::test::AsTensor<int64>({3, -1, 1});
EXPECT_EQ("[3,?,1]", create(&t));
t = ::tensorflow::test::AsTensor<int64>({});
EXPECT_EQ("[]", create(&t));
@ -790,10 +793,20 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
EXPECT_TRUE(StringPiece(create(&t))
.contains("Input tensor must be rank 1, but was rank 2"));
// Test negative values for the dims.
t = ::tensorflow::test::AsTensor<int64>({3, -2, 1});
EXPECT_TRUE(StringPiece(create(&t))
.contains("Invalid value in tensor used for shape: -2"));
// Test negative values for the dims.
t = ::tensorflow::test::AsTensor<int32>({3, -2, 1});
EXPECT_TRUE(StringPiece(create(&t))
.contains("Invalid value in tensor used for shape: -2"));
// Test when the input shape is wrong.
{
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr});
InferenceContext c(&def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr}, {});
ShapeHandle out;
EXPECT_EQ("Shape must be rank 1 but is rank 2",
c.MakeShapeFromShapeTensor(0, &out).error_message());
@ -803,7 +816,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(&def, MakeOpDef(0, 2), empty, {});
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
TensorShapeProto proto;
// With a set unknown rank.
@ -839,7 +852,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
TEST_F(ShapeInferenceTest, MakeDim) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(&def, MakeOpDef(0, 2), empty, {});
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
auto d0 = c.MakeDim(1);
auto d1 = c.MakeDim(1);
@ -853,7 +866,7 @@ TEST_F(ShapeInferenceTest, MakeDim) {
TEST_F(ShapeInferenceTest, UnknownDim) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(&def, MakeOpDef(0, 2), empty, {});
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
auto d0 = c.UnknownDim();
auto d1 = c.UnknownDim();
@ -865,7 +878,7 @@ TEST_F(ShapeInferenceTest, UnknownDim) {
TEST_F(ShapeInferenceTest, UnknownShapeOfRank) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(&def, MakeOpDef(0, 2), empty, {});
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
auto unknown_shape_of_rank_3 = c.UnknownShapeOfRank(3);
EXPECT_EQ("[?,?,?]", c.DebugString(unknown_shape_of_rank_3));
@ -879,7 +892,7 @@ TEST_F(ShapeInferenceTest, InputTensors) {
const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
{&t1, &t2});
{&t1, &t2}, {});
EXPECT_TRUE(c.input_tensor(0) == &t1);
EXPECT_TRUE(c.input_tensor(1) == &t2);
@ -890,7 +903,7 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
Tensor t1 = tensorflow::test::AsScalar<int32>(20);
Tensor t2 = tensorflow::test::AsScalar<int32>(-1);
NodeDef def;
InferenceContext c(&def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2});
InferenceContext c(&def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2}, {});
DimensionHandle d;
EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok());
@ -921,7 +934,7 @@ TEST_F(ShapeInferenceTest, GetAttr) {
.ok());
std::vector<ShapeHandle> empty;
InferenceContext c(&def, op_reg_data.op_def, empty, {});
InferenceContext c(&def, op_reg_data.op_def, empty, {}, {});
string value;
EXPECT_TRUE(c.GetAttr("foo", &value).ok());
EXPECT_EQ("bar", value);
@ -929,11 +942,14 @@ TEST_F(ShapeInferenceTest, GetAttr) {
TEST_F(ShapeInferenceTest, Divide) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1})}, {});
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
auto d_unknown = c.Dim(s, 1);
auto d_1 = c.Dim(s, 2);
auto d_2 = c.Dim(s, 3);
auto d_0 = c.Dim(s, 4);
bool evenly_divisible = true;
// Dividing unknown by non-1 gives new unknown.
@ -947,9 +963,15 @@ TEST_F(ShapeInferenceTest, Divide) {
EXPECT_TRUE(SameHandle(out, d_unknown));
EXPECT_TRUE(c.Divide(d_6, 1, evenly_divisible, &out).ok());
EXPECT_TRUE(SameHandle(out, d_6));
EXPECT_TRUE(c.Divide(d_unknown, d_1, evenly_divisible, &out).ok());
EXPECT_TRUE(SameHandle(out, d_unknown));
EXPECT_TRUE(c.Divide(d_6, d_1, evenly_divisible, &out).ok());
EXPECT_TRUE(SameHandle(out, d_6));
EXPECT_TRUE(c.Divide(d_6, 2, evenly_divisible, &out).ok());
EXPECT_EQ("3", c.DebugString(out));
EXPECT_TRUE(c.Divide(d_6, d_2, evenly_divisible, &out).ok());
EXPECT_EQ("3", c.DebugString(out));
EXPECT_TRUE(
StringPiece(c.Divide(d_6, 5, evenly_divisible, &out).error_message())
@ -958,6 +980,9 @@ TEST_F(ShapeInferenceTest, Divide) {
EXPECT_TRUE(
StringPiece(c.Divide(d_6, 0, evenly_divisible, &out).error_message())
.contains("Divisor must be positive but is 0"));
EXPECT_TRUE(
StringPiece(c.Divide(d_6, d_0, evenly_divisible, &out).error_message())
.contains("Divisor must be positive but is 0"));
EXPECT_TRUE(
StringPiece(c.Divide(d_6, -1, evenly_divisible, &out).error_message())
@ -979,7 +1004,7 @@ TEST_F(ShapeInferenceTest, Divide) {
TEST_F(ShapeInferenceTest, Add) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0})}, {});
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@ -1030,7 +1055,7 @@ TEST_F(ShapeInferenceTest, Add) {
TEST_F(ShapeInferenceTest, Subtract) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {});
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@ -1079,7 +1104,7 @@ TEST_F(ShapeInferenceTest, Subtract) {
TEST_F(ShapeInferenceTest, Multiply) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {});
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@ -1132,7 +1157,7 @@ TEST_F(ShapeInferenceTest, Multiply) {
TEST_F(ShapeInferenceTest, FullyDefined) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(&def, MakeOpDef(0, 2), empty, {});
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
// No rank or missing dimension information should return false.
EXPECT_FALSE(c.FullyDefined(c.UnknownShape()));
@ -1145,7 +1170,7 @@ TEST_F(ShapeInferenceTest, FullyDefined) {
TEST_F(ShapeInferenceTest, Min) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {});
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, {});
auto s = c.input(0);
auto d_1 = c.Dim(s, 0);
@ -1193,7 +1218,7 @@ TEST_F(ShapeInferenceTest, Min) {
TEST_F(ShapeInferenceTest, Max) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1})}, {});
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {});
auto s = c.input(0);
auto d_1 = c.Dim(s, 0);
@ -1231,7 +1256,7 @@ TEST_F(ShapeInferenceTest, Max) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapes) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {Unknown(), Unknown(), Unknown()},
{});
{}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1243,7 +1268,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapes) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownDims) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})},
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})}, {},
{});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1256,7 +1281,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownDims) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidIndicesRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {});
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {},
{});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1269,7 +1295,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidIndicesRank) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidNumElements) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {});
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {},
{});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1282,7 +1309,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidNumElements) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {});
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {},
{});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1295,7 +1323,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidRank) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumIndexElements) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {});
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {},
{});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1307,7 +1336,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumIndexElements) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumValueElements) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {});
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {},
{});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1319,7 +1349,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumValueElements) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownIndexRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {});
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {},
{});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1331,7 +1362,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownIndexRank) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapeRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {});
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {},
{});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1343,7 +1375,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapeRank) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {});
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {},
{});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());

View File

@ -44,7 +44,8 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
}
shape_inference::InferenceContext c(&op.node_def, op_reg_data->op_def,
in_shapes, op.input_tensors);
in_shapes, op.input_tensors,
{} /* input_tensors_as_shapes */);
TF_RETURN_IF_ERROR(c.construction_status());
if (op_reg_data->shape_inference_fn == nullptr) {
return errors::InvalidArgument(

View File

@ -243,6 +243,11 @@ void CostModel::RecordMaxMemorySize(const Node* node, int output_slot,
if (id < 0) return;
Ensure(id);
auto& current_max = max_mem_usage_[id].output_port_mem[output_slot];
// If the memory allocator doesn't track memory usage, let's infer a lower
// bound from the tensor shape and its data type.
if (bytes.value() < 0) {
bytes = MinTensorMemoryUsage(tensor_shape, dtype);
}
if (bytes.value() > current_max.value()) {
current_max = bytes.value();
max_mem_usage_[id].output_port_shape[output_slot] = tensor_shape;
@ -476,4 +481,18 @@ void CostModel::WriteSummaryToLog() const {
}
}
Bytes CostModel::MinTensorMemoryUsage(const TensorShapeProto& tensor_shape,
const DataType& dtype) {
if (tensor_shape.unknown_rank()) {
return Bytes(-1);
}
size_t num_coefficients = 1;
for (const TensorShapeProto::Dim& dim : tensor_shape.dim()) {
// If the dimension is unknown, it has to be at least 1
num_coefficients *= std::max<size_t>(dim.size(), 1);
}
return Bytes(num_coefficients * DataTypeSize(dtype));
}
} // namespace tensorflow

Some files were not shown because too many files have changed in this diff Show More