Merge pull request #5248 from zheng-xq/branch_137452948
Branch 137452948
Commit: 3737ac321e
@@ -129,8 +129,6 @@ filegroup(
        "//tensorflow/contrib/tensorboard:all_files",
        "//tensorflow/contrib/testing:all_files",
        "//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
        "//tensorflow/contrib/tfprof/tools/tfprof:all_files",
        "//tensorflow/contrib/tfprof/tools/tfprof/internal:all_files",
        "//tensorflow/contrib/training:all_files",
        "//tensorflow/contrib/util:all_files",
        "//tensorflow/core:all_files",
@@ -188,6 +186,8 @@ filegroup(
        "//tensorflow/tools/proto_text:all_files",
        "//tensorflow/tools/quantization:all_files",
        "//tensorflow/tools/test:all_files",
        "//tensorflow/tools/tfprof:all_files",
        "//tensorflow/tools/tfprof/internal:all_files",
        "//tensorflow/user_ops:all_files",
        "//third_party/hadoop:all_files",
    ],
@@ -430,6 +430,7 @@ tf_cc_test(
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:tensorflow",
        "//tensorflow/core:test",

@@ -34,6 +34,7 @@ cc_library(
        ":constants",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:tensorflow",
        "//tensorflow/core/util/tensor_bundle:naming",
@@ -63,7 +64,6 @@ tf_cc_test(
filegroup(
    name = "saved_model_half_plus_two",
    srcs = glob([
        "testdata/half_plus_two/**",
        "testdata/half_plus_two_pbtxt/**",
        "testdata/half_plus_two_sharded/**",
    ]),
@@ -30,6 +30,9 @@ constexpr char kSavedModelFilenamePb[] = "saved_model.pb";
// SavedModel text format proto filename.
constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt";

// SavedModel legacy init op key.
constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op";

// Directory in which to save the SavedModel variables.
constexpr char kSavedModelVariablesDirectory[] = "variables";
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
@@ -83,10 +84,32 @@ Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
  return (*session)->Create(meta_graph_def.graph_def());
}

Status Restore(const RunOptions& run_options, const string& export_dir,
               const StringPiece restore_op_name,
               const StringPiece variable_filename_const_op_name,
               Session* session) {
Tensor CreateStringTensor(const string& value) {
  Tensor tensor(DT_STRING, TensorShape({}));
  tensor.scalar<string>()() = value;
  return tensor;
}

void AddAssetsTensorsToInputs(const StringPiece export_dir,
                              const std::vector<AssetFileDef>& asset_file_defs,
                              std::vector<std::pair<string, Tensor>>* inputs) {
  if (asset_file_defs.empty()) {
    return;
  }
  for (auto& asset_file_def : asset_file_defs) {
    Tensor assets_file_path_tensor = CreateStringTensor(io::JoinPath(
        export_dir, kSavedModelAssetsDirectory, asset_file_def.filename()));
    inputs->push_back(
        {asset_file_def.tensor_info().name(), assets_file_path_tensor});
  }
}

Status RunRestore(const RunOptions& run_options, const string& export_dir,
                  const StringPiece restore_op_name,
                  const StringPiece variable_filename_const_op_name,
                  const std::vector<AssetFileDef>& asset_file_defs,
                  Session* session) {
  LOG(INFO) << "Restoring SavedModel bundle.";
  // Find path to variables to be restored in export directory.
  const string variables_directory =
      io::JoinPath(export_dir, kSavedModelVariablesDirectory);
@@ -109,11 +132,54 @@ Status Restore(const RunOptions& run_options, const string& export_dir,
  std::vector<std::pair<string, Tensor>> inputs = {
      {variable_filename_const_op_name.ToString(), variables_path_tensor}};

  AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);

  RunMetadata run_metadata;
  return session->Run(run_options, inputs, {}, {restore_op_name.ToString()},
                      nullptr /* outputs */, &run_metadata);
}

Status RunLegacyInitOp(const RunOptions& run_options, const string& export_dir,
                       const MetaGraphDef& meta_graph_def,
                       const std::vector<AssetFileDef>& asset_file_defs,
                       Session* session) {
  LOG(INFO) << "Running LegacyInitOp on SavedModel bundle.";
  const auto& collection_def_map = meta_graph_def.collection_def();
  const auto init_op_it = collection_def_map.find(kSavedModelLegacyInitOpKey);
  if (init_op_it != collection_def_map.end()) {
    if (init_op_it->second.node_list().value_size() != 1) {
      return errors::FailedPrecondition(strings::StrCat(
          "Expected exactly one serving init op in : ", export_dir));
    }
    std::vector<std::pair<string, Tensor>> inputs;
    AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
    RunMetadata run_metadata;
    const StringPiece legacy_init_op_name =
        init_op_it->second.node_list().value(0);
    return session->Run(run_options, inputs, {},
                        {legacy_init_op_name.ToString()}, nullptr /* outputs */,
                        &run_metadata);
  }
  return Status::OK();
}

Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
                        std::vector<AssetFileDef>* asset_file_defs) {
  const auto& collection_def_map = meta_graph_def.collection_def();
  const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
  if (assets_it == collection_def_map.end()) {
    return Status::OK();
  }
  const auto& any_assets = assets_it->second.any_list().value();
  for (const auto& any_asset : any_assets) {
    AssetFileDef asset_file_def;
    TF_RETURN_IF_ERROR(
        ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
    asset_file_defs->push_back(asset_file_def);
  }
  return Status::OK();
}

Status LoadSavedModelInternal(const SessionOptions& session_options,
                              const RunOptions& run_options,
                              const string& export_dir,
@@ -134,12 +200,19 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
  TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession(
      bundle->meta_graph_def, session_options, &bundle->session));

  std::vector<AssetFileDef> asset_file_defs;
  TF_RETURN_IF_ERROR(
      Restore(run_options, export_dir,
              bundle->meta_graph_def.saver_def().restore_op_name(),
              bundle->meta_graph_def.saver_def().filename_tensor_name(),
              bundle->session.get()));

      GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
  TF_RETURN_IF_ERROR(
      RunRestore(run_options, export_dir,
                 bundle->meta_graph_def.saver_def().restore_op_name(),
                 bundle->meta_graph_def.saver_def().filename_tensor_name(),
                 asset_file_defs, bundle->session.get()));
  // TODO(sukritiramesh): Add support for a single main op to run upon load,
  // which will supersede the legacy_init_op and separate RunRestore.
  TF_RETURN_IF_ERROR(RunLegacyInitOp(run_options, export_dir,
                                     bundle->meta_graph_def, asset_file_defs,
                                     bundle->session.get()));
  return Status::OK();
}
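For reference, the loader entry point exercised by these changes is driven as in the sketch below. It is modeled on the loader_test.cc changes that follow; the include paths, the helper name `LoadAndQueryAssets`, and the fetch of "filename_tensor:0" are illustrative assumptions rather than part of this diff.

```c++
// Sketch only: load a SavedModel from `export_dir` with the "serve" tag and
// run the restored session. Include paths are assumptions.
#include <string>
#include <vector>

#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"

tensorflow::Status LoadAndQueryAssets(const std::string& export_dir) {
  tensorflow::SessionOptions session_options;
  tensorflow::RunOptions run_options;
  tensorflow::SavedModelBundle bundle;
  // Loads the MetaGraphDef, restores variables via RunRestore, feeds asset
  // file paths, and runs the legacy_init_op collection entry when present.
  tensorflow::Status status = tensorflow::LoadSavedModel(
      session_options, run_options, export_dir,
      {tensorflow::kSavedModelTagServe}, &bundle);
  if (!status.ok()) return status;
  std::vector<tensorflow::Tensor> outputs;
  // The asset filename tensor initialized at load time can then be fetched.
  return bundle.session->Run({}, {"filename_tensor:0"}, {}, &outputs);
}
```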
@@ -29,7 +29,6 @@ limitations under the License.
namespace tensorflow {
namespace {

constexpr char kTestDataPb[] = "cc/saved_model/testdata/half_plus_two";
constexpr char kTestDataPbTxt[] = "cc/saved_model/testdata/half_plus_two_pbtxt";
constexpr char kTestDataSharded[] =
    "cc/saved_model/testdata/half_plus_two_sharded";
@@ -45,12 +44,26 @@ class LoaderTest : public ::testing::Test {
    return example.SerializeAsString();
  }

  void ValidateAssets(const string& export_dir,
                      const SavedModelBundle& bundle) {
    const string asset_directory =
        io::JoinPath(export_dir, kSavedModelAssetsDirectory);
    const string asset_filename = "foo.txt";
    const string asset_filepath = io::JoinPath(asset_directory, asset_filename);
    EXPECT_TRUE(Env::Default()->FileExists(asset_filepath));

    std::vector<Tensor> path_outputs;
    TF_ASSERT_OK(
        bundle.session->Run({}, {"filename_tensor:0"}, {}, &path_outputs));
    ASSERT_EQ(1, path_outputs.size());

    test::ExpectTensorEqual<string>(
        test::AsTensor<string>({"foo.txt"}, TensorShape({})), path_outputs[0]);
  }

  void CheckSavedModelBundle(const string& export_dir,
                             const SavedModelBundle& bundle) {
    const string asset_path =
        io::JoinPath(export_dir, kSavedModelAssetsDirectory, "foo.txt");
    EXPECT_TRUE(Env::Default()->FileExists(asset_path));

    ValidateAssets(export_dir, bundle);
    // Retrieve the regression signature from meta graph def.
    const auto signature_def_map = bundle.meta_graph_def.signature_def();
    const auto signature_def = signature_def_map.at(kRegressMethodName);
@@ -151,18 +164,6 @@ TEST_F(LoaderTest, PbtxtFormat) {
  CheckSavedModelBundle(export_dir, bundle);
}

TEST_F(LoaderTest, SingleShardVariables) {
  SavedModelBundle bundle;
  SessionOptions session_options;
  RunOptions run_options;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPb);
  TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
                              {kSavedModelTagServe}, &bundle));
  CheckSavedModelBundle(export_dir, bundle);
}

TEST_F(LoaderTest, InvalidExportPath) {
  SavedModelBundle bundle;
  RunOptions run_options;
@@ -1 +0,0 @@
asset-file-contents
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -102,6 +102,24 @@ meta_graphs {
|
||||
type: "type"
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "MergeV2Checkpoints"
|
||||
input_arg {
|
||||
name: "checkpoint_prefixes"
|
||||
type: DT_STRING
|
||||
}
|
||||
input_arg {
|
||||
name: "destination_prefix"
|
||||
type: DT_STRING
|
||||
}
|
||||
attr {
|
||||
name: "delete_old_dirs"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "Mul"
|
||||
input_arg {
|
||||
@ -140,6 +158,35 @@ meta_graphs {
|
||||
op {
|
||||
name: "NoOp"
|
||||
}
|
||||
op {
|
||||
name: "Pack"
|
||||
input_arg {
|
||||
name: "values"
|
||||
type_attr: "T"
|
||||
number_attr: "N"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "N"
|
||||
type: "int"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
}
|
||||
attr {
|
||||
name: "axis"
|
||||
type: "int"
|
||||
default_value {
|
||||
i: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ParseExample"
|
||||
input_arg {
|
||||
@ -267,9 +314,9 @@ meta_graphs {
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "SaveSlices"
|
||||
name: "SaveV2"
|
||||
input_arg {
|
||||
name: "filename"
|
||||
name: "prefix"
|
||||
type: DT_STRING
|
||||
}
|
||||
input_arg {
|
||||
@ -277,15 +324,15 @@ meta_graphs {
|
||||
type: DT_STRING
|
||||
}
|
||||
input_arg {
|
||||
name: "shapes_and_slices"
|
||||
name: "shape_and_slices"
|
||||
type: DT_STRING
|
||||
}
|
||||
input_arg {
|
||||
name: "data"
|
||||
type_list_attr: "T"
|
||||
name: "tensors"
|
||||
type_list_attr: "dtypes"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
name: "dtypes"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
@ -311,19 +358,29 @@ meta_graphs {
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ShardedFilespec"
|
||||
name: "StringJoin"
|
||||
input_arg {
|
||||
name: "basename"
|
||||
name: "inputs"
|
||||
type: DT_STRING
|
||||
}
|
||||
input_arg {
|
||||
name: "num_shards"
|
||||
type: DT_INT32
|
||||
number_attr: "N"
|
||||
}
|
||||
output_arg {
|
||||
name: "filename"
|
||||
name: "output"
|
||||
type: DT_STRING
|
||||
}
|
||||
attr {
|
||||
name: "N"
|
||||
type: "int"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "separator"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "Variable"
|
||||
@ -899,6 +956,244 @@ meta_graphs {
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Const"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_STRING
|
||||
tensor_shape {
|
||||
}
|
||||
string_val: "/tmp/original/export/assets/foo.txt"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "filename_tensor/initial_value"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_STRING
|
||||
tensor_shape {
|
||||
}
|
||||
string_val: "foo.txt"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "filename_tensor"
|
||||
op: "Variable"
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "container"
|
||||
value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shared_name"
|
||||
value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "filename_tensor/Assign"
|
||||
op: "Assign"
|
||||
input: "filename_tensor"
|
||||
input: "filename_tensor/initial_value"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@filename_tensor"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "use_locking"
|
||||
value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "validate_shape"
|
||||
value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "filename_tensor/read"
|
||||
op: "Identity"
|
||||
input: "filename_tensor"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@filename_tensor"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Assign/value"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_STRING
|
||||
tensor_shape {
|
||||
}
|
||||
string_val: "foo.txt"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Assign"
|
||||
op: "Assign"
|
||||
input: "filename_tensor"
|
||||
input: "Assign/value"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@filename_tensor"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "use_locking"
|
||||
value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "validate_shape"
|
||||
value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Identity"
|
||||
op: "Identity"
|
||||
@ -931,6 +1226,11 @@ meta_graphs {
|
||||
input: "^a/Assign"
|
||||
input: "^b/Assign"
|
||||
}
|
||||
node {
|
||||
name: "group_deps"
|
||||
op: "NoOp"
|
||||
input: "^Assign"
|
||||
}
|
||||
node {
|
||||
name: "save/Const"
|
||||
op: "Const"
|
||||
@ -961,6 +1261,63 @@ meta_graphs {
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "save/StringJoin/inputs_1"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_STRING
|
||||
tensor_shape {
|
||||
}
|
||||
string_val: "_temp_ff2bd25218b646ea9ed224eecdce5e79/part"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "save/StringJoin"
|
||||
op: "StringJoin"
|
||||
input: "save/Const"
|
||||
input: "save/StringJoin/inputs_1"
|
||||
attr {
|
||||
key: "N"
|
||||
value {
|
||||
i: 2
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "separator"
|
||||
value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "save/num_shards"
|
||||
op: "Const"
|
||||
@ -1024,7 +1381,7 @@ meta_graphs {
|
||||
node {
|
||||
name: "save/ShardedFilename"
|
||||
op: "ShardedFilename"
|
||||
input: "save/Const"
|
||||
input: "save/StringJoin"
|
||||
input: "save/ShardedFilename/shard"
|
||||
input: "save/num_shards"
|
||||
attr {
|
||||
@ -1038,7 +1395,7 @@ meta_graphs {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "save/save/tensor_names"
|
||||
name: "save/SaveV2/tensor_names"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
@ -1075,7 +1432,7 @@ meta_graphs {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "save/save/shapes_and_slices"
|
||||
name: "save/SaveV2/shape_and_slices"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
@ -1112,15 +1469,15 @@ meta_graphs {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "save/save"
|
||||
op: "SaveSlices"
|
||||
name: "save/SaveV2"
|
||||
op: "SaveV2"
|
||||
input: "save/ShardedFilename"
|
||||
input: "save/save/tensor_names"
|
||||
input: "save/save/shapes_and_slices"
|
||||
input: "save/SaveV2/tensor_names"
|
||||
input: "save/SaveV2/shape_and_slices"
|
||||
input: "a"
|
||||
input: "b"
|
||||
attr {
|
||||
key: "T"
|
||||
key: "dtypes"
|
||||
value {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
@ -1133,7 +1490,7 @@ meta_graphs {
|
||||
name: "save/control_dependency"
|
||||
op: "Identity"
|
||||
input: "save/ShardedFilename"
|
||||
input: "^save/save"
|
||||
input: "^save/SaveV2"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
@ -1159,11 +1516,65 @@ meta_graphs {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "save/ShardedFilespec"
|
||||
op: "ShardedFilespec"
|
||||
input: "save/Const"
|
||||
input: "save/num_shards"
|
||||
name: "save/MergeV2Checkpoints/checkpoint_prefixes"
|
||||
op: "Pack"
|
||||
input: "save/ShardedFilename"
|
||||
input: "^save/control_dependency"
|
||||
attr {
|
||||
key: "N"
|
||||
value {
|
||||
i: 1
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "axis"
|
||||
value {
|
||||
i: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "save/MergeV2Checkpoints"
|
||||
op: "MergeV2Checkpoints"
|
||||
input: "save/MergeV2Checkpoints/checkpoint_prefixes"
|
||||
input: "save/Const"
|
||||
attr {
|
||||
key: "delete_old_dirs"
|
||||
value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "save/Identity"
|
||||
op: "Identity"
|
||||
input: "save/Const"
|
||||
input: "^save/control_dependency"
|
||||
input: "^save/MergeV2Checkpoints"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
@ -1467,12 +1878,39 @@ meta_graphs {
|
||||
}
|
||||
saver_def {
|
||||
filename_tensor_name: "save/Const:0"
|
||||
save_tensor_name: "save/ShardedFilespec:0"
|
||||
save_tensor_name: "save/Identity:0"
|
||||
restore_op_name: "save/restore_all"
|
||||
max_to_keep: 5
|
||||
sharded: true
|
||||
keep_checkpoint_every_n_hours: 10000.0
|
||||
version: V1
|
||||
version: V2
|
||||
}
|
||||
collection_def {
|
||||
key: "asset_filepaths"
|
||||
value {
|
||||
node_list {
|
||||
value: "Const:0"
|
||||
}
|
||||
}
|
||||
}
|
||||
collection_def {
|
||||
key: "legacy_init_op"
|
||||
value {
|
||||
node_list {
|
||||
value: "group_deps"
|
||||
}
|
||||
}
|
||||
}
|
||||
collection_def {
|
||||
key: "saved_model_assets"
|
||||
value {
|
||||
any_list {
|
||||
value {
|
||||
type_url: "type.googleapis.com/tensorflow.AssetFileDef"
|
||||
value: "\n\t\n\007Const:0\022\007foo.txt"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
collection_def {
|
||||
key: "trainable_variables"
|
||||
|
Binary file not shown.
@@ -54,7 +54,8 @@ Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
}

QueueRunner::~QueueRunner() {
  should_stop_ = true;
  // Cannot run Stop() here because the session might already be closed or
  // destroyed.
  Join();
}

@@ -72,6 +73,15 @@ Status QueueRunner::Start(Session* sess) {
  return Status::OK();
}

Status QueueRunner::Stop(Session* sess) {
  should_stop_ = true;
  if (cancel_op_name_.empty()) {
    return Status::OK();
  } else {
    return sess->Run({}, {}, {cancel_op_name_}, nullptr);
  }
}

Status QueueRunner::Join() {
  thread_pool_.reset();
  started_ = false;
@@ -80,9 +90,8 @@ Status QueueRunner::Join() {

void QueueRunner::Run(Session* sess, const string& enqueue_op) {
  bool decremented = false;
  while (!should_stop_) {
    std::vector<Tensor> outputs;
    auto status = sess->Run({}, {}, {enqueue_op}, &outputs);
  while (!should_stop_.load()) {
    auto status = sess->Run({}, {}, {enqueue_op}, nullptr);
    if (status.ok()) {
      continue;
    } else if (queue_closed_exception_types_.count(
@@ -94,19 +103,25 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) {

      // If all enqueue ops have finished, run the close op.
      if (runs_ == 0 && !close_op_name_.empty()) {
        std::vector<Tensor> outputs;
        auto s = sess->Run({}, {}, {close_op_name_}, &outputs);
        if (!s.ok()) {
          status_ = status;
        auto s = sess->Run({}, {}, {close_op_name_}, nullptr);
        if (!s.ok() && status_.ok() &&
            queue_closed_exception_types_.count(static_cast<int>(s.code())) ==
                0) {
          status_ = s;
        }
      }
    } else {
      mutex_lock l(mu_);
      should_stop_ = true;
      // Only record the first failure status.
      if (status_.ok()) {
        status_ = status;
      {
        mutex_lock l(mu_);
        should_stop_ = true;
        // Only record the first failure status.
        if (status_.ok()) {
          status_ = status;
        }
      }
      // Stop the queue runner immediately to propagate the error to
      // subsequent queues.
      Stop(sess);
    }
  }

@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
@@ -49,6 +50,9 @@ class QueueRunner {
  // Starts the queue runner with the given session.
  Status Start(Session* sess);

  // Requests to stop and runs the cancel op.
  Status Stop(Session* sess);

  // Joins all the threads. Returns okay if all threads run successfully;
  // otherwise returns the first captured failure status.
  Status Join();
@@ -60,14 +64,14 @@ class QueueRunner {
  string queue_name_;
  std::vector<string> enqueue_op_names_;
  string close_op_name_;
  // The cancel op is not being called currently.
  string cancel_op_name_;
  // code::Code casted to int to avoid a hash function.
  std::unordered_set<int> queue_closed_exception_types_;

  std::unique_ptr<thread::ThreadPool> thread_pool_;
  bool should_stop_;
  std::atomic<bool> should_stop_;
  std::atomic<bool> started_;
  condition_variable wait_to_close_;
  mutex mu_;
  // TODO(yuefengz): implement c++ coordinator.
  int runs_ = 0;
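Taken together with the header changes above, the intended lifecycle is Start, then the new Stop, then Join. A minimal sketch, modeled on the queue_runner_test.cc changes further down (the helper name is illustrative; the calls mirror the test code):

```c++
// Sketch only: exercise the new Stop() entry point between Start() and Join().
#include "tensorflow/cc/training/queue_runner.h"

tensorflow::Status RunAndStop(tensorflow::Session* session,
                              const tensorflow::QueueRunnerDef& def) {
  tensorflow::QueueRunner qr;
  tensorflow::Status s = qr.Init(def);
  if (!s.ok()) return s;
  s = qr.Start(session);  // spawns one thread per enqueue op
  if (!s.ok()) return s;
  // ... the runner threads keep feeding the queue here ...
  s = qr.Stop(session);   // sets should_stop_ and runs the cancel op, if any
  if (!s.ok()) return s;
  return qr.Join();       // returns the first captured failure status
}
```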
@ -14,8 +14,10 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/training/queue_runner.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
@ -23,39 +25,42 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/queue_runner.pb.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
using ::tensorflow::DataType;
|
||||
using ::tensorflow::error::Code;
|
||||
using ::tensorflow::GraphDef;
|
||||
using ::tensorflow::ops::Assign;
|
||||
using ::tensorflow::ops::Const;
|
||||
using ::tensorflow::ops::CountUpTo;
|
||||
using ::tensorflow::ops::FIFOQueue;
|
||||
using ::tensorflow::ops::InputList;
|
||||
using ::tensorflow::ops::QueueClose;
|
||||
using ::tensorflow::ops::QueueDequeue;
|
||||
using ::tensorflow::ops::QueueEnqueue;
|
||||
using ::tensorflow::ops::Square;
|
||||
using ::tensorflow::ops::Variable;
|
||||
using ::tensorflow::QueueRunner;
|
||||
using ::tensorflow::QueueRunnerDef;
|
||||
using ::tensorflow::Scope;
|
||||
using ::tensorflow::Session;
|
||||
using ::tensorflow::SessionOptions;
|
||||
using ::tensorflow::Tensor;
|
||||
using ::tensorflow::TensorShape;
|
||||
using error::Code;
|
||||
using ops::Assign;
|
||||
using ops::Const;
|
||||
using ops::CountUpTo;
|
||||
using ops::FIFOQueue;
|
||||
using ops::QueueClose;
|
||||
using ops::QueueDequeue;
|
||||
using ops::QueueEnqueue;
|
||||
using ops::Square;
|
||||
using ops::Variable;
|
||||
|
||||
constexpr char kAssignOpName[] = "assign";
|
||||
constexpr char kCancelOp0[] = "cancel0";
|
||||
constexpr char kCancelOp1[] = "cancel1";
|
||||
constexpr char kCloseOp0[] = "close0";
|
||||
constexpr char kCloseOp1[] = "close1";
|
||||
constexpr char kCountUpToOpName[] = "count";
|
||||
constexpr char kDequeueOp0[] = "dequeue0";
|
||||
constexpr char kDequeueOp1[] = "dequeue1";
|
||||
constexpr char kEnqueueOp0[] = "enqueue0";
|
||||
constexpr char kEnqueueOp1[] = "enqueue1";
|
||||
constexpr char kIllegalOpName1[] = "would fail";
|
||||
constexpr char kIllegalOpName2[] = "fail again";
|
||||
constexpr char kQueueName[] = "unit_test";
|
||||
constexpr char kQueueName0[] = "q0";
|
||||
constexpr char kQueueName1[] = "q1";
|
||||
constexpr char kSquareOpName[] = "square";
|
||||
constexpr char kVarOpName[] = "var";
|
||||
|
||||
@ -75,7 +80,7 @@ GraphDef BuildSimpleGraph() {
|
||||
|
||||
QueueRunnerDef BuildQueueRunnerDef(
|
||||
const std::string& queue_name, const std::vector<std::string>& enqueue_ops,
|
||||
const std::string& close_op,
|
||||
const std::string& close_op, const std::string& cancel_op,
|
||||
const std::vector<Code>& queue_closed_error_codes) {
|
||||
QueueRunnerDef queue_runner_def;
|
||||
*queue_runner_def.mutable_queue_name() = kQueueName;
|
||||
@ -83,6 +88,7 @@ QueueRunnerDef BuildQueueRunnerDef(
|
||||
*queue_runner_def.mutable_enqueue_op_name()->Add() = enqueue_op;
|
||||
}
|
||||
*queue_runner_def.mutable_close_op_name() = close_op;
|
||||
*queue_runner_def.mutable_cancel_op_name() = cancel_op;
|
||||
for (const auto& error_code : queue_closed_error_codes) {
|
||||
*queue_runner_def.mutable_queue_closed_exception_types()->Add() =
|
||||
error_code;
|
||||
@ -96,8 +102,7 @@ std::unique_ptr<Session> BuildSessionAndInitVariable(
|
||||
std::unique_ptr<Session> session(NewSession(options));
|
||||
TF_CHECK_OK(session->Create(graph_def));
|
||||
|
||||
std::vector<Tensor> nothing;
|
||||
TF_CHECK_OK(session->Run({}, {}, {kAssignOpName}, ¬hing));
|
||||
TF_CHECK_OK(session->Run({}, {}, {kAssignOpName}, nullptr));
|
||||
return session;
|
||||
}
|
||||
|
||||
@ -106,7 +111,7 @@ TEST(QueueRunnerTest, BasicTest) {
|
||||
auto session = BuildSessionAndInitVariable(graph_def);
|
||||
|
||||
QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
|
||||
kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, {});
|
||||
kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, "", {});
|
||||
|
||||
QueueRunner qr(queue_runner_def);
|
||||
qr.Start(session.get());
|
||||
@ -123,7 +128,7 @@ TEST(QueueRunnerTest, QueueClosedCode) {
|
||||
auto session = BuildSessionAndInitVariable(graph_def);
|
||||
|
||||
QueueRunnerDef queue_runner_def =
|
||||
BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kSquareOpName,
|
||||
BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kSquareOpName, "",
|
||||
{Code::OUT_OF_RANGE, Code::CANCELLED});
|
||||
|
||||
QueueRunner qr(queue_runner_def);
|
||||
@ -141,60 +146,167 @@ TEST(QueueRunnerDef, CatchErrorInJoin) {
|
||||
auto session = BuildSessionAndInitVariable(graph_def);
|
||||
|
||||
QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
|
||||
kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, {});
|
||||
kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, "", {});
|
||||
|
||||
QueueRunner qr(queue_runner_def);
|
||||
qr.Start(session.get());
|
||||
EXPECT_EQ(qr.Join().code(), Code::NOT_FOUND);
|
||||
}
|
||||
|
||||
TEST(QueueRunnerTest, RealEnqueueDequeue) {
|
||||
GraphDef BuildDoubleQueueGraph() {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto q0 = FIFOQueue(root.WithOpName("q0"), {DataType::DT_INT32});
|
||||
auto q0 = FIFOQueue(root.WithOpName(kQueueName0), {DataType::DT_INT32});
|
||||
auto ten = Const(root, 10);
|
||||
auto enqueue0 = QueueEnqueue(root.WithOpName("enqueue0"), q0, {ten});
|
||||
auto close0 = QueueClose(root.WithOpName("close0"), q0);
|
||||
auto q1 = FIFOQueue(root.WithOpName("q1"), {DataType::DT_INT32});
|
||||
auto enqueue0 = QueueEnqueue(root.WithOpName(kEnqueueOp0), q0, {ten});
|
||||
auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0);
|
||||
auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0,
|
||||
QueueClose::CancelPendingEnqueues(true));
|
||||
auto q1 = FIFOQueue(root.WithOpName(kQueueName1), {DataType::DT_INT32});
|
||||
auto dequeue0 =
|
||||
QueueDequeue(root.WithOpName("dequeue0"), q0, {DataType::DT_INT32});
|
||||
auto enqueue1 = QueueEnqueue(root.WithOpName("enqueue1"), q1, {dequeue0[0]});
|
||||
QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_INT32});
|
||||
auto enqueue1 = QueueEnqueue(root.WithOpName(kEnqueueOp1), q1, {dequeue0[0]});
|
||||
auto dequeue1 =
|
||||
QueueDequeue(root.WithOpName("dequeue1"), q1, {DataType::DT_INT32});
|
||||
auto close1 = QueueClose(root.WithOpName("close1"), q1);
|
||||
QueueDequeue(root.WithOpName(kDequeueOp1), q1, {DataType::DT_INT32});
|
||||
auto close1 = QueueClose(root.WithOpName(kCloseOp1), q1);
|
||||
auto cancel1 = QueueClose(root.WithOpName(kCancelOp1), q1,
|
||||
QueueClose::CancelPendingEnqueues(true));
|
||||
|
||||
GraphDef graph_def;
|
||||
TF_EXPECT_OK(root.ToGraphDef(&graph_def));
|
||||
return graph_def;
|
||||
}
|
||||
|
||||
TEST(QueueRunnerTest, RealEnqueueDequeue) {
|
||||
auto graph_def = BuildDoubleQueueGraph();
|
||||
|
||||
SessionOptions options;
|
||||
std::unique_ptr<Session> session(NewSession(options));
|
||||
TF_CHECK_OK(session->Create(graph_def));
|
||||
|
||||
QueueRunnerDef queue_runner_def =
|
||||
BuildQueueRunnerDef(kQueueName, {"enqueue1"}, "close1", {});
|
||||
BuildQueueRunnerDef(kQueueName, {kEnqueueOp1}, kCloseOp1, "", {});
|
||||
QueueRunner qr;
|
||||
qr.Init(queue_runner_def);
|
||||
TF_CHECK_OK(qr.Start(session.get()));
|
||||
|
||||
std::vector<Tensor> outputs;
|
||||
TF_EXPECT_OK(session->Run({}, {}, {"enqueue0"}, &outputs));
|
||||
TF_EXPECT_OK(session->Run({}, {}, {"enqueue0"}, &outputs));
|
||||
TF_EXPECT_OK(session->Run({}, {}, {"close0"}, &outputs));
|
||||
TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
|
||||
TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
|
||||
// Closing queue 0 would also close the queue runner.
|
||||
TF_EXPECT_OK(session->Run({}, {}, {kCloseOp0}, nullptr));
|
||||
|
||||
TF_EXPECT_OK(qr.Join());
|
||||
std::vector<Tensor> dq1;
|
||||
TF_EXPECT_OK(session->Run({}, {"dequeue1"}, {}, &dq1));
|
||||
TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq1));
|
||||
EXPECT_EQ(*dq1[0].scalar<int>().data(), 10);
|
||||
std::vector<Tensor> dq2;
|
||||
TF_EXPECT_OK(session->Run({}, {"dequeue1"}, {}, &dq2));
|
||||
TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq2));
|
||||
EXPECT_EQ(*dq2[0].scalar<int>().data(), 10);
|
||||
|
||||
EXPECT_EQ(session->Run({}, {"dequeue1"}, {}, &dq1).code(),
|
||||
EXPECT_EQ(session->Run({}, {kDequeueOp1}, {}, nullptr).code(),
|
||||
Code::OUT_OF_RANGE);
|
||||
}
|
||||
|
||||
void JoinThread(QueueRunner* queue_runner, bool* join_succeeded,
|
||||
Notification* join_done) {
|
||||
EXPECT_EQ(queue_runner->Join().code(), Code::CANCELLED);
|
||||
*join_succeeded = true;
|
||||
join_done->Notify();
|
||||
}
|
||||
|
||||
TEST(QueueRunnerTest, SessionCloseCancelPendingEnqueue) {
|
||||
auto graph_def = BuildDoubleQueueGraph();
|
||||
|
||||
SessionOptions options;
|
||||
std::unique_ptr<Session> session(NewSession(options));
|
||||
TF_CHECK_OK(session->Create(graph_def));
|
||||
|
||||
QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
|
||||
kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {});
|
||||
QueueRunner qr;
|
||||
qr.Init(queue_runner_def);
|
||||
TF_CHECK_OK(qr.Start(session.get()));
|
||||
|
||||
TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
|
||||
|
||||
std::vector<Tensor> dq1;
|
||||
TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq1));
|
||||
EXPECT_EQ(*dq1[0].scalar<int>().data(), 10);
|
||||
|
||||
// The expected behavior is the QueueRunner::Join() call is blocked until
|
||||
// Session::Close() is called.
|
||||
bool join_succeeded = false;
|
||||
Notification join_done;
|
||||
Env::Default()->SchedClosure(
|
||||
std::bind(&JoinThread, &qr, &join_succeeded, &join_done));
|
||||
|
||||
Env::Default()->SleepForMicroseconds(10000000);
|
||||
EXPECT_EQ(join_succeeded, false);
|
||||
|
||||
// Closing the session is required to cancel pending enqueue nodes.
|
||||
TF_EXPECT_OK(session->Close());
|
||||
|
||||
join_done.WaitForNotification();
|
||||
EXPECT_EQ(join_succeeded, true);
|
||||
}
|
||||
|
||||
TEST(QueueRunnerTest, Stop) {
|
||||
auto graph_def = BuildDoubleQueueGraph();
|
||||
|
||||
SessionOptions options;
|
||||
std::unique_ptr<Session> session(NewSession(options));
|
||||
TF_CHECK_OK(session->Create(graph_def));
|
||||
|
||||
QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
|
||||
kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {});
|
||||
QueueRunner qr;
|
||||
qr.Init(queue_runner_def);
|
||||
TF_CHECK_OK(qr.Start(session.get()));
|
||||
|
||||
TF_EXPECT_OK(qr.Stop(session.get()));
|
||||
|
||||
TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
|
||||
|
||||
EXPECT_EQ(session->Run({}, {kDequeueOp1}, {}, nullptr).code(),
|
||||
Code::OUT_OF_RANGE);
|
||||
|
||||
// qr is already stopped
|
||||
TF_EXPECT_OK(qr.Join());
|
||||
}
|
||||
|
||||
TEST(QueueRunnerTest, StopTwoQueues) {
|
||||
auto graph_def = BuildDoubleQueueGraph();
|
||||
|
||||
SessionOptions options;
|
||||
std::unique_ptr<Session> session(NewSession(options));
|
||||
TF_CHECK_OK(session->Create(graph_def));
|
||||
|
||||
QueueRunnerDef queue_runner0 =
|
||||
BuildQueueRunnerDef(kQueueName0, {kEnqueueOp0}, kCloseOp0, kCancelOp0,
|
||||
{Code::OUT_OF_RANGE, Code::CANCELLED});
|
||||
QueueRunnerDef queue_runner1 =
|
||||
BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1,
|
||||
{Code::OUT_OF_RANGE, Code::CANCELLED});
|
||||
QueueRunner qr0;
|
||||
qr0.Init(queue_runner0);
|
||||
TF_CHECK_OK(qr0.Start(session.get()));
|
||||
QueueRunner qr1;
|
||||
qr1.Init(queue_runner1);
|
||||
TF_CHECK_OK(qr1.Start(session.get()));
|
||||
|
||||
std::vector<Tensor> dq;
|
||||
TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq));
|
||||
EXPECT_EQ(*dq[0].scalar<int>().data(), 10);
|
||||
|
||||
TF_EXPECT_OK(qr0.Stop(session.get()));
|
||||
TF_EXPECT_OK(qr1.Stop(session.get()));
|
||||
|
||||
TF_EXPECT_OK(qr0.Join());
|
||||
TF_EXPECT_OK(qr1.Join());
|
||||
}
|
||||
|
||||
TEST(QueueRunnerTest, EmptyEnqueueOps) {
|
||||
QueueRunnerDef queue_runner_def =
|
||||
BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, {});
|
||||
BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {});
|
||||
|
||||
QueueRunner qr;
|
||||
EXPECT_EQ(qr.Init(queue_runner_def).code(), Code::INVALID_ARGUMENT);
|
||||
@ -203,8 +315,8 @@ TEST(QueueRunnerTest, EmptyEnqueueOps) {
|
||||
TEST(QueueRunnerTest, InitAfterStart) {
|
||||
GraphDef graph_def = BuildSimpleGraph();
|
||||
auto session = BuildSessionAndInitVariable(graph_def);
|
||||
QueueRunnerDef queue_runner_def =
|
||||
BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kCountUpToOpName, {});
|
||||
QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
|
||||
kQueueName, {kCountUpToOpName}, kCountUpToOpName, "", {});
|
||||
|
||||
QueueRunner qr;
|
||||
TF_EXPECT_OK(qr.Init(queue_runner_def));
|
||||
@ -213,3 +325,4 @@ TEST(QueueRunnerTest, InitAfterStart) {
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@@ -1,7 +1,7 @@
include (ExternalProject)

set(gemmlowp_URL http://github.com/google/gemmlowp/archive/c0bacf11fb509a2cbe15a97362a2df067ffd57a2.tar.gz)
set(gemmlowp_HASH SHA256=dc64a38f9927db18748d9024987c9b102115e25bc2be4b76aa8e422b8f83d882)
set(gemmlowp_URL http://github.com/google/gemmlowp/archive/a6f29d8ac48d63293f845f2253eccbf86bc28321.tar.gz)
set(gemmlowp_HASH SHA256=75d40ea8e68b0d1644f052fffe8f14a410b2a73d40ccb859a95c0578d194ec26)
set(gemmlowp_BUILD ${CMAKE_BINARY_DIR}/gemmlowp/src/gemmlowp)
set(gemmlowp_INCLUDE_DIR ${CMAKE_BINARY_DIR}/gemmlowp/src/gemmlowp)
@@ -0,0 +1,14 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
@@ -89,6 +89,7 @@ if(WIN32)
    "${tensorflow_source_dir}/tensorflow/core/kernels/fact_op.cc"
    "${tensorflow_source_dir}/tensorflow/core/kernels/immutable_constant_op.cc"
    "${tensorflow_source_dir}/tensorflow/core/kernels/immutable_constant_op.h"
    "${tensorflow_source_dir}/tensorflow/core/kernels/meta_support.*"
    "${tensorflow_source_dir}/tensorflow/core/kernels/sparse_matmul_op.cc"
    "${tensorflow_source_dir}/tensorflow/core/kernels/sparse_matmul_op.h"
    "${tensorflow_source_dir}/tensorflow/core/kernels/*quantiz*.h"
@@ -13,7 +13,10 @@ add_executable(${proto_text}
    $<TARGET_OBJECTS:tf_core_lib>
)

target_link_libraries(${proto_text} PUBLIC ${tensorflow_EXTERNAL_LIBRARIES})
target_link_libraries(${proto_text} PUBLIC
    ${tensorflow_EXTERNAL_LIBRARIES}
    tf_protos_cc
)

add_dependencies(${proto_text}
    tf_core_lib
@@ -36,7 +36,7 @@ cuda_py_tests(

cuda_py_tests(
    name = "operator_pd_cholesky_test",
    size = "small",
    size = "medium",
    srcs = ["python/kernel_tests/operator_pd_cholesky_test.py"],
    additional_deps = [
        ":distributions_py",
@@ -60,7 +60,7 @@ cuda_py_tests(

cuda_py_tests(
    name = "operator_pd_full_test",
    size = "small",
    size = "medium",
    srcs = ["python/kernel_tests/operator_pd_full_test.py"],
    additional_deps = [
        ":distributions_py",
@@ -72,7 +72,7 @@ cuda_py_tests(

cuda_py_tests(
    name = "operator_pd_identity_test",
    size = "small",
    size = "medium",
    srcs = ["python/kernel_tests/operator_pd_identity_test.py"],
    additional_deps = [
        ":distributions_py",
@ -614,6 +614,67 @@ class SigmoidCenteredBijectorTest(tf.test.TestCase):
|
||||
atol=0., rtol=1e-7)
|
||||
|
||||
|
||||
class CholeskyOuterProductBijectorTest(tf.test.TestCase):
|
||||
"""Tests the correctness of the Y = X * X^T transformation."""
|
||||
|
||||
def testBijectorMatrix(self):
|
||||
with self.test_session():
|
||||
bijector = bijectors.CholeskyOuterProduct(event_ndims=2,
|
||||
validate_args=True)
|
||||
self.assertEqual("cholesky_outer_product", bijector.name)
|
||||
x = [[[1., 0],
|
||||
[2, 1]],
|
||||
[[math.sqrt(2.), 0],
|
||||
[math.sqrt(8.), 1]]]
|
||||
y = np.matmul(x, np.transpose(x, axes=(0, 2, 1)))
|
||||
# Fairly easy to compute differentials since we have 2x2.
|
||||
dx_dy = [[[2.*1, 0, 0],
|
||||
[2, 1, 0],
|
||||
[0, 2*2, 2*1]],
|
||||
[[2*math.sqrt(2.), 0, 0],
|
||||
[math.sqrt(8.), math.sqrt(2.), 0],
|
||||
[0, 2*math.sqrt(8.), 2*1]]]
|
||||
ildj = -np.sum(
|
||||
np.log(np.asarray(dx_dy).diagonal(offset=0, axis1=1, axis2=2)),
|
||||
axis=1)
|
||||
self.assertAllEqual((2, 2, 2), bijector.forward(x).get_shape())
|
||||
self.assertAllEqual((2, 2, 2), bijector.inverse(y).get_shape())
|
||||
self.assertAllClose(y, bijector.forward(x).eval())
|
||||
self.assertAllClose(x, bijector.inverse(y).eval())
|
||||
self.assertAllClose(ildj,
|
||||
bijector.inverse_log_det_jacobian(y).eval(),
|
||||
atol=0., rtol=1e-7)
|
||||
self.assertAllClose(-bijector.inverse_log_det_jacobian(y).eval(),
|
||||
bijector.forward_log_det_jacobian(x).eval(),
|
||||
atol=0., rtol=1e-7)
|
||||
|
||||
def testBijectorScalar(self):
|
||||
with self.test_session():
|
||||
bijector = bijectors.CholeskyOuterProduct(event_ndims=0,
|
||||
validate_args=True)
|
||||
self.assertEqual("cholesky_outer_product", bijector.name)
|
||||
x = [[[1., 5],
|
||||
[2, 1]],
|
||||
[[math.sqrt(2.), 3],
|
||||
[math.sqrt(8.), 1]]]
|
||||
y = np.square(x)
|
||||
ildj = -math.log(2.) - np.log(x)
|
||||
self.assertAllClose(y, bijector.forward(x).eval())
|
||||
self.assertAllClose(x, bijector.inverse(y).eval())
|
||||
self.assertAllClose(ildj,
|
||||
bijector.inverse_log_det_jacobian(y).eval(),
|
||||
atol=0., rtol=1e-7)
|
||||
self.assertAllClose(-bijector.inverse_log_det_jacobian(y).eval(),
|
||||
bijector.forward_log_det_jacobian(x).eval(),
|
||||
atol=0., rtol=1e-7)
|
||||
|
||||
def testScalarCongruency(self):
|
||||
with self.test_session():
|
||||
bijector = bijectors.CholeskyOuterProduct(event_ndims=0,
|
||||
validate_args=True)
|
||||
assert_scalar_congruency(bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05)
|
||||
|
||||
|
||||
class ChainBijectorTest(tf.test.TestCase):
|
||||
"""Tests the correctness of the Y = Chain(bij1, bij2, bij3) transformation."""
|
||||
|
||||
|
@@ -41,11 +41,34 @@ class DistributionTest(tf.test.TestCase):
    for cls in classes:
      for sample_shape in sample_shapes:
        param_shapes = cls.param_shapes(sample_shape)
        print(param_shapes)
        params = dict([(name, tf.random_normal(shape))
                       for name, shape in param_shapes.items()])
        dist = cls(**params)
        self.assertAllEqual(sample_shape, tf.shape(dist.sample()).eval())
        dist_copy = dist.copy()
        self.assertAllEqual(sample_shape,
                            tf.shape(dist_copy.sample()).eval())
        self.assertEqual(dist.parameters, dist_copy.parameters)

  def testCopyExtraArgs(self):
    with self.test_session():
      # Note: we cannot easily test all distributions since each requires
      # different initialization arguments. We therefore spot test a few.
      normal = dists.Normal(mu=1., sigma=2., validate_args=True)
      self.assertEqual(normal.parameters, normal.copy().parameters)
      wishart = dists.WishartFull(df=2, scale=[[1., 2], [2, 5]],
                                  validate_args=True)
      self.assertEqual(wishart.parameters, wishart.copy().parameters)

  def testCopyOverride(self):
    with self.test_session():
      normal = dists.Normal(mu=1., sigma=2., validate_args=True)
      normal_copy = normal.copy(validate_args=False)
      base_params = normal.parameters.copy()
      copy_params = normal.copy(validate_args=False).parameters.copy()
      self.assertNotEqual(base_params.pop("validate_args"),
                          copy_params.pop("validate_args"))
      self.assertEqual(base_params, copy_params)


if __name__ == '__main__':
@ -14,7 +14,7 @@
|
||||
# ==============================================================================
|
||||
r"""Bijector Ops.
|
||||
|
||||
An API for reversible (bijective) transformations of random variables.
|
||||
An API for invertible, differentiable transformations of random variables.
|
||||
|
||||
## Background
|
||||
|
||||
@ -31,6 +31,7 @@ To apply a `Bijector`, use `distributions.TransformedDistribution`.
|
||||
|
||||
@@Bijector
|
||||
@@Chain
|
||||
@@CholeskyOuterProduct
|
||||
@@Exp
|
||||
@@Identity
|
||||
@@Inline
|
||||
@ -46,7 +47,9 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import collections
|
||||
import contextlib
|
||||
import math
|
||||
import re
|
||||
import numpy as np
|
||||
import six
|
||||
@ -58,18 +61,112 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
|
||||
__all__ = [
|
||||
"Bijector",
|
||||
"Chain",
|
||||
"CholeskyOuterProduct",
|
||||
"Exp",
|
||||
"Identity",
|
||||
"Inline",
|
||||
"Invert",
|
||||
"ScaleAndShift",
|
||||
"SigmoidCentered",
|
||||
"SoftmaxCentered",
|
||||
"Softplus",
|
||||
]
|
||||
|
||||
|
||||
class _Mapping(collections.namedtuple("_Mapping",
|
||||
["x", "y", "ildj", "condition_kwargs"])):
|
||||
"""Helper class to make it easier to manage caching in `Bijector`."""
|
||||
|
||||
def __new__(cls, x=None, y=None, ildj=None, condition_kwargs=None):
|
||||
"""Custom __new__ so namedtuple items have defaults.
|
||||
|
||||
Args:
|
||||
x: `Tensor`. Forward.
|
||||
y: `Tensor`. Inverse.
|
||||
ildj: `Tensor`. Inverse log det Jacobian.
|
||||
condition_kwargs: Python dictionary. Extra args supplied to
|
||||
forward/inverse/etc functions.
|
||||
|
||||
Returns:
|
||||
mapping: New instance of _Mapping.
|
||||
"""
|
||||
return super(_Mapping, cls).__new__(cls, x, y, ildj, condition_kwargs)
|
||||
|
||||
@property
|
||||
def x_key(self):
|
||||
"""Returns key used for caching Y=g(X)."""
|
||||
return (self.x,) + self._deep_tuple(tuple(sorted(
|
||||
self.condition_kwargs.items())))
|
||||
|
||||
@property
|
||||
def y_key(self):
|
||||
"""Returns key used for caching X=g^{-1}(Y)."""
|
||||
return (self.y,) + self._deep_tuple(tuple(sorted(
|
||||
self.condition_kwargs.items())))
|
||||
|
||||
def merge(self, x=None, y=None, ildj=None,
|
||||
condition_kwargs=None, mapping=None):
|
||||
"""Returns new _Mapping with args merged with self.
|
||||
|
||||
Args:
|
||||
x: `Tensor`. Forward.
|
||||
y: `Tensor`. Inverse.
|
||||
ildj: `Tensor`. Inverse log det Jacobian.
|
||||
condition_kwargs: Python dictionary. Extra args supplied to
|
||||
forward/inverse/etc functions.
|
||||
mapping: Instance of _Mapping to merge. Can only be specified if no other
|
||||
arg is specified.
|
||||
|
||||
Returns:
|
||||
mapping: New instance of `_Mapping` which has inputs merged with self.
|
||||
|
||||
Raises:
|
||||
ValueError: if mapping and any other arg is not `None`.
|
||||
"""
|
||||
if mapping is None:
|
||||
mapping = _Mapping(x=x, y=y, ildj=ildj,
|
||||
condition_kwargs=condition_kwargs)
|
||||
elif not all([arg is None for arg in [x, y, ildj, condition_kwargs]]):
|
||||
raise ValueError("Cannot specify mapping and individual args.")
|
||||
return _Mapping(
|
||||
x=self._merge(self.x, mapping.x),
|
||||
y=self._merge(self.y, mapping.y),
|
||||
ildj=self._merge(self.ildj, mapping.ildj),
|
||||
condition_kwargs=self._merge(self.condition_kwargs,
|
||||
mapping.condition_kwargs))
|
||||
|
||||
def _merge(self, old, new):
|
||||
"""Helper to merge which handles merging one value."""
|
||||
if old is None:
|
||||
return new
|
||||
elif new is not None and old != new:
|
||||
raise ValueError("Incompatible values: %s != %s" % (old, new))
|
||||
return old
|
||||
|
||||
def _deep_tuple(self, x):
|
||||
"""Converts lists of lists to tuples of tuples."""
|
||||
return (tuple(map(self._deep_tuple, x))
|
||||
if isinstance(x, (list, tuple)) else x)
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class Bijector(object):
|
||||
"""Interface for transforming a `Distribution` via `TransformedDistribution`.
|
||||
"""Interface for transforming a `Distribution` sample.
|
||||
|
||||
A `Bijector` implements a bijective, differentiable function by transforming
|
||||
an input `Tensor`. The output `Tensor` shape is constrained by the input
|
||||
`sample`, `batch`, and `event` shape. A `Bijector` is characterized by three
|
||||
A `Bijector` implements a
|
||||
[diffeomorphism](https://en.wikipedia.org/wiki/Diffeomorphism), i.e., a
|
||||
bijective, differentiable function. A `Bijector` is used by
|
||||
`TransformedDistribution` but can be generally used for transforming a
|
||||
`Distribution` generated `Tensor`. A `Bijector` is characterized by three
|
||||
operations:
|
||||
|
||||
1. Forward Evaluation
|
||||
@ -210,7 +307,8 @@ class Bijector(object):
|
||||
- The inverse `log o det o Jacobian` can be implemented as the negative of the
|
||||
forward `log o det o Jacobian`. This is useful if the `inverse` is
|
||||
implemented as a cache or the inverse Jacobian is computationally more
|
||||
expensive. The following demonstrates the suggested implementation.
|
||||
expensive (e.g., `CholeskyOuterProduct` `Bijector`). The following
|
||||
demonstrates the suggested implementation.
|
||||
|
||||
```python
|
||||
def _inverse_and_log_det_jacobian(self, y):
|
||||
@ -300,6 +398,11 @@ class Bijector(object):
|
||||
self._is_constant_jacobian = is_constant_jacobian
|
||||
self._validate_args = validate_args
|
||||
self._dtype = dtype
|
||||
self._from_y = {}
|
||||
self._from_x = {}
|
||||
# Using abbreviation ildj for "inverse log det Jacobian."
|
||||
# This variable is not `None` iff is_constant_jacobian is `True`.
|
||||
self._constant_ildj = None
|
||||
if name:
|
||||
self._name = name
|
||||
else:
|
||||
@ -368,7 +471,12 @@ class Bijector(object):
|
||||
with self._name_scope(name, [x]):
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
self._maybe_assert_dtype(x)
|
||||
return self._forward(x, **condition_kwargs)
|
||||
mapping = self._lookup(x=x, condition_kwargs=condition_kwargs)
|
||||
if mapping.y is not None:
|
||||
return mapping.y
|
||||
mapping = mapping.merge(y=self._forward(x, **condition_kwargs))
|
||||
self._cache(mapping)
|
||||
return mapping.y
|
||||
|
||||
def _inverse(self, y):
|
||||
raise NotImplementedError("inverse is not implemented")
|
||||
@ -393,16 +501,28 @@ class Bijector(object):
|
||||
with self._name_scope(name, [y]):
|
||||
y = ops.convert_to_tensor(y, name="y")
|
||||
self._maybe_assert_dtype(y)
|
||||
mapping = self._lookup(y=y, condition_kwargs=condition_kwargs)
|
||||
if mapping.x is not None:
|
||||
return mapping.x
|
||||
ildj = None
|
||||
try:
|
||||
return self._inverse(y, **condition_kwargs)
|
||||
x = self._inverse(y, **condition_kwargs)
|
||||
except NotImplementedError as original_error:
|
||||
# Since _inverse was not implemented, try to see if it's implemented
|
||||
# by the _inverse_and_inverse_log_det_jacobian member.
|
||||
try:
|
||||
return self._inverse_and_inverse_log_det_jacobian(
|
||||
y, **condition_kwargs)[0]
|
||||
x, ildj = self._inverse_and_inverse_log_det_jacobian(
|
||||
y, **condition_kwargs)
|
||||
if self._constant_ildj is not None:
|
||||
ildj = self._constant_ildj # Use the "global" result.
|
||||
elif self.is_constant_jacobian:
|
||||
self._constant_ildj = ildj
|
||||
except NotImplementedError:
|
||||
raise original_error
|
||||
x = x if mapping.x is None else mapping.x
|
||||
mapping = mapping.merge(x=x, ildj=ildj)
|
||||
self._cache(mapping)
|
||||
return mapping.x
|
||||
|
||||
def _inverse_log_det_jacobian(self, y):
|
||||
raise NotImplementedError("inverse_log_det_jacobian is not implemented.")
|
||||
@ -430,18 +550,32 @@ class Bijector(object):
|
||||
`_inverse_and_inverse_log_det_jacobian` are implemented.
|
||||
"""
|
||||
with self._name_scope(name, [y]):
|
||||
if self._constant_ildj is not None:
|
||||
return self._constant_ildj
|
||||
y = ops.convert_to_tensor(y, name="y")
|
||||
self._maybe_assert_dtype(y)
|
||||
mapping = self._lookup(y=y, condition_kwargs=condition_kwargs)
|
||||
if mapping.ildj is not None:
|
||||
return mapping.ildj
|
||||
try:
|
||||
return self._inverse_log_det_jacobian(y, **condition_kwargs)
|
||||
x = mapping.x
|
||||
ildj = self._inverse_log_det_jacobian(y, **condition_kwargs)
|
||||
except NotImplementedError as original_error:
|
||||
# Since _inverse_log_det_jacobian was not implemented, try to see if
|
||||
# it's implemented by the _inverse_and_inverse_log_det_jacobian member.
|
||||
try:
|
||||
return self._inverse_and_inverse_log_det_jacobian(
|
||||
y, **condition_kwargs)[1]
|
||||
x, ildj = self._inverse_and_inverse_log_det_jacobian(
|
||||
y, **condition_kwargs)
|
||||
if mapping.x is not None:
|
||||
x = mapping.x
|
||||
except NotImplementedError:
|
||||
raise original_error
|
||||
if self.is_constant_jacobian:
|
||||
self._constant_ildj = ildj
|
||||
x = x if mapping.x is None else mapping.x
|
||||
mapping = mapping.merge(x=x, ildj=ildj)
|
||||
self._cache(mapping)
|
||||
return mapping.ildj
|
||||
|
||||
def _inverse_and_inverse_log_det_jacobian(self, y):
|
||||
raise NotImplementedError(
|
||||
@ -473,18 +607,30 @@ class Bijector(object):
|
||||
with self._name_scope(name, [y]):
|
||||
y = ops.convert_to_tensor(y, name="y")
|
||||
self._maybe_assert_dtype(y)
|
||||
mapping = self._lookup(y=y, condition_kwargs=condition_kwargs)
|
||||
if mapping.x is not None and mapping.ildj is not None:
|
||||
return mapping.x, mapping.ildj
|
||||
try:
|
||||
return self._inverse_and_inverse_log_det_jacobian(
|
||||
x, ildj = self._inverse_and_inverse_log_det_jacobian(
|
||||
y, **condition_kwargs)
|
||||
except NotImplementedError as original_error:
|
||||
# Since _inverse_and_inverse_log_det_jacobian was not implemented, try
|
||||
# to see if we can separately use _inverse and
|
||||
# _inverse_log_det_jacobian members.
|
||||
try:
|
||||
return (self._inverse(y, **condition_kwargs),
|
||||
self._inverse_log_det_jacobian(y, **condition_kwargs))
|
||||
x = self._inverse(y, **condition_kwargs)
|
||||
if self._constant_ildj is None:
|
||||
ildj = self._inverse_log_det_jacobian(y, **condition_kwargs)
|
||||
except NotImplementedError:
|
||||
raise original_error
|
||||
if self._constant_ildj is not None:
|
||||
ildj = self._constant_ildj # Ignore any ildj we may/not have.
|
||||
elif self.is_constant_jacobian:
|
||||
self._constant_ildj = ildj
|
||||
x = x if mapping.x is None else mapping.x
|
||||
mapping = mapping.merge(x=x, ildj=ildj)
|
||||
self._cache(mapping)
|
||||
return mapping.x, mapping.ildj
|
||||
|
||||
def _forward_log_det_jacobian(self, x):
|
||||
raise NotImplementedError(
|
||||
@ -509,16 +655,29 @@ class Bijector(object):
|
||||
nor {`_inverse`, `_inverse_log_det_jacobian`} are implemented.
|
||||
"""
|
||||
with self._name_scope(name, [x]):
|
||||
if self._constant_ildj is not None:
|
||||
# Need "-1. *" to avoid invalid-unary-operand-type linter warning.
|
||||
return -1. * self._constant_ildj
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
self._maybe_assert_dtype(x)
|
||||
mapping = self._lookup(x=x, condition_kwargs=condition_kwargs)
|
||||
if mapping.ildj is not None:
|
||||
return -mapping.ildj
|
||||
y = None
|
||||
try:
|
||||
return self._forward_log_det_jacobian(x, **condition_kwargs)
|
||||
ildj = -self._forward_log_det_jacobian(x, **condition_kwargs)
|
||||
except NotImplementedError as original_error:
|
||||
try:
|
||||
y = self.inverse(x, **condition_kwargs)
|
||||
return -self.inverse_log_det_jacobian(y, **condition_kwargs)
|
||||
y = self.inverse(x, **condition_kwargs) if y is None else y
|
||||
ildj = self.inverse_log_det_jacobian(y, **condition_kwargs)
|
||||
except NotImplementedError:
|
||||
raise original_error
|
||||
if self.is_constant_jacobian:
|
||||
self._constant_ildj = ildj
|
||||
y = y if mapping.y is None else mapping.y
|
||||
mapping = mapping.merge(y=y, ildj=ildj)
|
||||
self._cache(mapping)
|
||||
return -mapping.ildj
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _name_scope(self, name=None, values=None):
|
||||
@ -534,6 +693,31 @@ class Bijector(object):
|
||||
raise TypeError("Input had dtype %s but expected %s." %
|
||||
(self.dtype, x.dtype))
|
||||
|
||||
def _cache(self, mapping):
|
||||
"""Helper which stores mapping info in forward/inverse dicts."""
|
||||
if self._constant_ildj is not None:
|
||||
# Fold in ildj if known constant Jacobian.
|
||||
mapping = mapping.merge(ildj=self._constant_ildj)
|
||||
# Merging from lookup is an added check that we're not overwriting anything
|
||||
# which is not None.
|
||||
mapping = mapping.merge(mapping=self._lookup(
|
||||
mapping.x, mapping.y, mapping.condition_kwargs))
|
||||
if mapping.x is None or mapping.y is None:
|
||||
ValueError("Caching expects both (x,y) to be known, i.e., not None.")
|
||||
self._from_x[mapping.x_key] = mapping
|
||||
self._from_y[mapping.y_key] = mapping
|
||||
|
||||
def _lookup(self, x=None, y=None, condition_kwargs=None):
|
||||
"""Helper which retrieves mapping info from forward/inverse dicts."""
|
||||
mapping = _Mapping(x=x, y=y, condition_kwargs=condition_kwargs)
|
||||
# Since _cache requires both x,y to be set, we only need to do one cache
|
||||
# lookup since the mapping is always in both or neither.
|
||||
if mapping.x is not None:
|
||||
return self._from_x.get(mapping.x_key, mapping)
|
||||
if mapping.y is not None:
|
||||
return self._from_y.get(mapping.y_key, mapping)
|
||||
return mapping
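The forward/inverse caching above is easier to follow with a toy example. Below is a minimal sketch in plain Python (illustrative only, not the TensorFlow implementation; hashing of tensors and `condition_kwargs` is elided):

```python
import math

# Minimal sketch of the forward/inverse caching idea (not the TF code).
# Assumes hashable inputs; the real Bijector keys on (tensor, kwargs) pairs.
class TinyCachingBijector(object):

  def __init__(self, fwd, inv):
    self._fwd, self._inv = fwd, inv
    self._from_x, self._from_y = {}, {}

  def forward(self, x):
    if x in self._from_x:          # cache hit: reuse the known y
      return self._from_x[x]
    y = self._fwd(x)
    self._from_x[x], self._from_y[y] = y, x
    return y

  def inverse(self, y):
    if y in self._from_y:          # forward already stored this pair,
      return self._from_y[y]       # so no inverse computation is needed
    x = self._inv(y)
    self._from_x[x], self._from_y[y] = y, x
    return x

exp_bijector = TinyCachingBijector(fwd=math.exp, inv=math.log)
y = exp_bijector.forward(1.5)
assert exp_bijector.inverse(y) == 1.5  # served from the cache, not math.log
```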
|
||||
|
||||
|
||||
class Inline(Bijector):
|
||||
# pylint: disable=line-too-long
|
||||
@ -547,7 +731,7 @@ class Inline(Bijector):
|
||||
inverse_fn=tf.log,
|
||||
inverse_log_det_jacobian_fn=(
|
||||
lambda y: -tf.reduce_sum(tf.log(y), reduction_indices=-1)),
|
||||
name="Exp")
|
||||
name="exp")
|
||||
```
|
||||
|
||||
The above example is equivalent to the `Bijector` `Exp(event_ndims=1)`.
|
||||
@ -573,8 +757,8 @@ class Inline(Bijector):
|
||||
log o det o jacobian of the forward transformation.
|
||||
is_constant_jacobian: `Boolean` indicating that the Jacobian is constant
|
||||
for all input arguments.
|
||||
validate_args: `Boolean` indicated whether arguments should be checked for
|
||||
correctness.
|
||||
validate_args: `Boolean` indicating whether arguments should be checked
|
||||
for correctness.
|
||||
name: `String`, name given to ops managed by this object.
|
||||
"""
|
||||
super(Inline, self).__init__(
|
||||
@ -643,8 +827,8 @@ class Invert(Bijector):
|
||||
|
||||
Args:
|
||||
bijector: Bijector instance.
|
||||
validate_args: `Boolean` indicated whether arguments should be checked for
|
||||
correctness.
|
||||
validate_args: `Boolean` indicating whether arguments should be checked
|
||||
for correctness.
|
||||
name: `String`, name given to ops managed by this object.
|
||||
"""
|
||||
|
||||
@ -713,8 +897,8 @@ class Chain(Bijector):
|
||||
Args:
|
||||
bijectors: Python list of bijector instances. An empty list makes this
|
||||
bijector equivalent to the `Identity` bijector.
|
||||
validate_args: `Boolean` indicated whether arguments should be checked for
|
||||
correctness.
|
||||
validate_args: `Boolean` indicating whether arguments should be checked
|
||||
for correctness.
|
||||
name: `String`, name given to ops managed by this object. Default: E.g.,
|
||||
`Chain([Exp(), Softplus()]).name == "chain_of_exp_of_softplus"`.
|
||||
|
||||
@ -794,12 +978,9 @@ class Identity(Bijector):
|
||||
|
||||
def __init__(self, validate_args=False, name="identity"):
|
||||
super(Identity, self).__init__(
|
||||
batch_ndims=0,
|
||||
event_ndims=0,
|
||||
is_constant_jacobian=True,
|
||||
validate_args=validate_args,
|
||||
name=name)
|
||||
self._is_constant_jacobian = True
|
||||
|
||||
def _forward(self, x):
|
||||
return x
|
||||
@ -841,8 +1022,8 @@ class Exp(Bijector):
|
||||
Args:
|
||||
event_ndims: Scalar `int32` `Tensor` indicating the number of dimensions
|
||||
associated with a particular draw from the distribution.
|
||||
validate_args: `Boolean` indicated whether arguments should be checked for
|
||||
correctness.
|
||||
validate_args: `Boolean` indicating whether arguments should be checked
|
||||
for correctness.
|
||||
name: `String` name given to ops managed by this object.
|
||||
"""
|
||||
|
||||
@ -923,8 +1104,8 @@ class ScaleAndShift(Bijector):
|
||||
scale: `Tensor` used to scale input, i.e., `Y = g(X) = scale * X + shift`.
|
||||
event_ndims: Scalar `int32` `Tensor` indicating the number of dimensions
|
||||
associated with a particular draw from the distribution.
|
||||
validate_args: `Boolean` indicated whether arguments should be checked for
|
||||
correctness.
|
||||
validate_args: `Boolean` indicating whether arguments should be checked
|
||||
for correctness.
|
||||
name: `String` name given to ops managed by this object.
|
||||
"""
|
||||
|
||||
@ -1271,3 +1452,150 @@ class SigmoidCentered(SoftmaxCentered):
|
||||
def __init__(self, validate_args=False, name="sigmoid_centered"):
|
||||
super(SigmoidCentered, self).__init__(
|
||||
validate_args=validate_args, name=name)
|
||||
|
||||
|
||||
class CholeskyOuterProduct(Bijector):
|
||||
# pylint: disable=line-too-long
|
||||
"""Bijector which computes Y = g(X) = X X^T where X is a lower-triangular, positive-diagonal matrix.
|
||||
|
||||
`event_ndims` must be 0 or 2, i.e., scalar or matrix.
|
||||
|
||||
Note: the upper-triangular part of X is ignored (whether or not it is zero).
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
bijector.CholeskyOuterProduct(event_ndims=2).forward(x=[[1., 0], [2, 1]])
|
||||
# Result: [[1, 1], [1, 5]], i.e., x x^T
|
||||
|
||||
bijector.CholeskyOuterProduct(event_ndims=2).inverse(y=[[1., 1], [1, 5]])
|
||||
# Result: [[1, 0], [2, 1]], i.e., chol(y).
|
||||
```
|
||||
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
def __init__(self, event_ndims=2, validate_args=False,
|
||||
name="cholesky_outer_product"):
|
||||
"""Instantiates the `CholeskyOuterProduct` bijector.
|
||||
|
||||
Args:
|
||||
event_ndims: `constant` `int32` scalar `Tensor` indicating the number of
|
||||
dimensions associated with a particular draw from the distribution. Must
|
||||
be 0 or 2.
|
||||
validate_args: `Boolean` indicating whether arguments should be checked
|
||||
for correctness.
|
||||
name: `String` name given to ops managed by this object.
|
||||
|
||||
Raises:
|
||||
ValueError: if event_ndims is neither 0 nor 2.
|
||||
"""
|
||||
self._parameters = {}
|
||||
self._name = name
|
||||
with self._name_scope("init", values=[event_ndims]):
|
||||
event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
|
||||
event_ndims = tensor_util.constant_value(event_ndims)
|
||||
if event_ndims is None or event_ndims not in [0, 2]:
|
||||
raise ValueError("`event_ndims` must be a TF constant which is 0 or 2")
|
||||
self._static_event_ndims = event_ndims
|
||||
super(CholeskyOuterProduct, self).__init__(
|
||||
validate_args=validate_args,
|
||||
name=name)
|
||||
|
||||
def _forward(self, x):
|
||||
if self._static_event_ndims == 0:
|
||||
return math_ops.square(x)
|
||||
if self.validate_args:
|
||||
is_matrix = check_ops.assert_rank_at_least(x, 2)
|
||||
shape = array_ops.shape(x)
|
||||
is_square = check_ops.assert_equal(shape[-2], shape[-1])
|
||||
x = control_flow_ops.with_dependencies([is_matrix, is_square], x)
|
||||
# For safety, explicitly zero-out the upper triangular part.
|
||||
x = array_ops.matrix_band_part(x, -1, 0)
|
||||
return math_ops.batch_matmul(x, x, adj_y=True)
|
||||
|
||||
def _inverse_and_inverse_log_det_jacobian(self, y):
|
||||
x = (math_ops.sqrt(y) if self._static_event_ndims == 0
|
||||
else linalg_ops.cholesky(y))
|
||||
return x, -self._forward_log_det_jacobian(x)
|
||||
|
||||
def _forward_log_det_jacobian(self, x):
|
||||
# Let Y be a symmetric, positive definite matrix and write:
|
||||
# Y = X X^T
|
||||
# where X is lower-triangular.
|
||||
#
|
||||
# Observe that,
|
||||
# dY[i,j]/dX[a,b]
|
||||
# = d/dX[a,b] { X[i,:] X[j,:] }
|
||||
# = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] }
|
||||
#
|
||||
# To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is
|
||||
# symmetric and X is lower-triangular, we need vectors of dimension:
|
||||
# d = p (p + 1) / 2
|
||||
# where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e.,
|
||||
# k = { i (i + 1) / 2 + j i>=j
|
||||
# { undef i<j
|
||||
# and assume zero-based indexes. When k is undef, the element is dropped.
|
||||
# Example:
|
||||
# j k
|
||||
# 0 1 2 3 /
|
||||
# 0 [ 0 . . . ]
|
||||
# i 1 [ 1 2 . . ]
|
||||
# 2 [ 3 4 5 . ]
|
||||
# 3 [ 6 7 8 9 ]
|
||||
# Write vec[.] to indicate transforming a matrix to vector via k(i,j). (With
|
||||
# slight abuse: k(i,j)=undef means the element is dropped.)
|
||||
#
|
||||
# We now show d vec[Y] / d vec[X] is lower triangular. Assuming both are
|
||||
# defined, observe that k(i,j) < k(a,b) iff (1) i<a or (2) i=a and j<b.
|
||||
# In both cases dvec[Y]/dvec[X]@[k(i,j),k(a,b)] = 0 since:
|
||||
# (1) j<=i<a thus i,j!=a.
|
||||
# (2) i=a>j thus i,j!=a.
|
||||
#
|
||||
# Since the Jacobian is lower-triangular, we need only compute the product
|
||||
# of diagonal elements:
|
||||
# d vec[Y] / d vec[X] @[k(i,j), k(i,j)]
|
||||
# = X[j,j] + I[i=j] X[i,j]
|
||||
# = 2 X[j,j].
|
||||
# Since there is a 2 X[j,j] term for every lower-triangular element of X we
|
||||
# conclude:
|
||||
# |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.
|
||||
if self._static_event_ndims == 0:
|
||||
if self.validate_args:
|
||||
is_positive = check_ops.assert_positive(
|
||||
x, message="All elements must be positive.")
|
||||
x = control_flow_ops.with_dependencies([is_positive], x)
|
||||
return math.log(2.) + math_ops.log(x)
|
||||
|
||||
diag = array_ops.matrix_diag_part(x)
|
||||
if self.validate_args:
|
||||
is_matrix = check_ops.assert_rank_at_least(
|
||||
x, 2, message="Input must be a (batch of) matrix.")
|
||||
shape = array_ops.shape(x)
|
||||
is_square = check_ops.assert_equal(
|
||||
shape[-2], shape[-1],
|
||||
message="Input must be a (batch of) square matrix.")
|
||||
# Assuming lower-triangular means we only need check diag>0.
|
||||
is_positive_definite = check_ops.assert_positive(
|
||||
diag, message="Input must be positive definite.")
|
||||
x = control_flow_ops.with_dependencies(
|
||||
[is_matrix, is_square, is_positive_definite], x)
|
||||
|
||||
# Create a column vector equal to: [p, p-1, ..., 2, 1]^T.
|
||||
if x.get_shape().ndims is None or x.get_shape()[-1].value is None:
|
||||
p = array_ops.shape(x)[-1]
|
||||
else:
|
||||
p = x.get_shape()[-1].value
|
||||
exponents = array_ops.expand_dims(
|
||||
math_ops.linspace(math_ops.cast(p, dtype=x.dtype), 1., p),
|
||||
dim=1)
|
||||
|
||||
sum_weighted_log_diag = array_ops.squeeze(
|
||||
math_ops.batch_matmul(math_ops.log(diag), exponents),
|
||||
squeeze_dims=-1)
|
||||
fldj = p * math.log(2.) + sum_weighted_log_diag
|
||||
|
||||
if x.get_shape().ndims is not None:
|
||||
fldj.set_shape(x.get_shape()[:-2])
|
||||
|
||||
return fldj
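The log-determinant formula derived in the comment block above, `|Jac| = 2^p * prod_{j=0}^{p-1} X[j,j]^(p-j)`, can be checked numerically. A small NumPy-only sketch (illustrative, not part of the change):

```python
import numpy as np

p = 3
L = np.tril(np.random.rand(p, p)) + np.eye(p)  # lower-triangular, positive diagonal

# Closed form from the derivation above (log scale, zero-based j).
closed_form = p * np.log(2.) + sum((p - j) * np.log(L[j, j]) for j in range(p))

# Brute force: finite-difference Jacobian of vec[Y] w.r.t. vec[X] for Y = X X^T,
# where vec[.] keeps only the lower triangle (row-major, as in the comment).
idx = [(i, j) for i in range(p) for j in range(i + 1)]

def vec_lower(m):
  return np.array([m[i, j] for (i, j) in idx])

eps = 1e-6
jac = np.zeros((len(idx), len(idx)))
y0 = vec_lower(L.dot(L.T))
for col, (a, b) in enumerate(idx):
  bumped = L.copy()
  bumped[a, b] += eps
  jac[:, col] = (vec_lower(bumped.dot(bumped.T)) - y0) / eps
brute_force = np.log(abs(np.linalg.det(jac)))

print(closed_form, brute_force)  # the two log-determinants should nearly agree
```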
|
||||
|
@ -327,12 +327,13 @@ class Distribution(_BaseDistribution):
|
||||
for i, t in enumerate(graph_parents):
|
||||
if t is None or not contrib_framework.is_tensor(t):
|
||||
raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
|
||||
parameters = parameters or {}
|
||||
self._dtype = dtype
|
||||
self._is_continuous = is_continuous
|
||||
self._is_reparameterized = is_reparameterized
|
||||
self._allow_nan_stats = allow_nan_stats
|
||||
self._validate_args = validate_args
|
||||
self._parameters = parameters or {}
|
||||
self._parameters = parameters
|
||||
self._graph_parents = graph_parents
|
||||
self._name = name or type(self).__name__
|
||||
|
||||
@ -434,6 +435,27 @@ class Distribution(_BaseDistribution):
|
||||
"""Python boolean indicated possibly expensive checks are enabled."""
|
||||
return self._validate_args
|
||||
|
||||
def copy(self, **override_parameters_kwargs):
|
||||
"""Creates a deep copy of the distribution.
|
||||
|
||||
Note: the copied distribution may continue to depend on the original
|
||||
initialization arguments.
|
||||
|
||||
Args:
|
||||
**override_parameters_kwargs: String/value dictionary of initialization
|
||||
arguments to override with new values.
|
||||
|
||||
Returns:
|
||||
distribution: A new instance of `type(self)` initialized from the union
|
||||
of self.parameters and override_parameters_kwargs, i.e.,
|
||||
`dict(self.parameters, **override_parameters_kwargs)`.
|
||||
"""
|
||||
parameters = dict(self.parameters, **override_parameters_kwargs)
|
||||
# Python3 leaks "__class__" into `locals()` so we remove if present.
|
||||
# TODO(b/32376812): Remove this pop.
|
||||
parameters.pop("__class__", None)
|
||||
return type(self)(**parameters)
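A brief usage sketch of the new `copy` method; `Normal(mu, sigma)` is only an illustrative stand-in for any distribution whose constructor kwargs are recorded in `parameters`:

```python
import tensorflow as tf

# Illustrative only: any distribution whose constructor kwargs are captured in
# `parameters` behaves the same way.
normal = tf.contrib.distributions.Normal(mu=0., sigma=1.)
wider = normal.copy(sigma=2.)
# `wider` is a new Normal with mu=0. and sigma=2., i.e. constructed from
# dict(normal.parameters, sigma=2.).
```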
|
||||
|
||||
def _batch_shape(self):
|
||||
raise NotImplementedError("batch_shape is not implemented")
|
||||
|
||||
|
@ -19,7 +19,6 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.distributions.python.ops import distribution as distributions
|
||||
from tensorflow.contrib.distributions.python.ops import distribution_util
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
|
||||
@ -160,7 +159,6 @@ class TransformedDistribution(distributions.Distribution):
|
||||
name = name or bijector.name + distribution.name
|
||||
self._distribution = distribution
|
||||
self._bijector = bijector
|
||||
self._inverse_cache = {}
|
||||
super(TransformedDistribution, self).__init__(
|
||||
dtype=self._distribution.dtype,
|
||||
is_continuous=self._distribution.is_continuous,
|
||||
@ -202,9 +200,7 @@ class TransformedDistribution(distributions.Distribution):
|
||||
**distribution_kwargs)
|
||||
# Recall that a bijector is named for its forward transform, i.e.,
|
||||
# `Y = g(X)`,
|
||||
y = self.bijector.forward(x, **bijector_kwargs)
|
||||
self._inverse_cache[y] = x
|
||||
return y
|
||||
return self.bijector.forward(x, **bijector_kwargs)
|
||||
|
||||
@distribution_util.AppendDocstring(
|
||||
"""Implements `(log o p o g^{-1})(y) + (log o det o J o g^{-1})(y)`,
|
||||
@ -216,11 +212,9 @@ class TransformedDistribution(distributions.Distribution):
|
||||
def _log_prob(self, y, bijector_kwargs=None, distribution_kwargs=None):
|
||||
bijector_kwargs = bijector_kwargs or {}
|
||||
distribution_kwargs = distribution_kwargs or {}
|
||||
x = self._inverse_possibly_from_cache(y, bijector_kwargs)
|
||||
inverse_log_det_jacobian = self.bijector.inverse_log_det_jacobian(
|
||||
x, ildj = self.bijector.inverse_and_inverse_log_det_jacobian(
|
||||
y, **bijector_kwargs)
|
||||
return (self.distribution.log_prob(x, **distribution_kwargs) +
|
||||
inverse_log_det_jacobian)
|
||||
return ildj + self.distribution.log_prob(x, **distribution_kwargs)
|
||||
|
||||
@distribution_util.AppendDocstring(
|
||||
"""Implements `p(g^{-1}(y)) det|J(g^{-1}(y))|`, where `g^{-1}` is the
|
||||
@ -232,18 +226,16 @@ class TransformedDistribution(distributions.Distribution):
|
||||
def _prob(self, y, bijector_kwargs=None, distribution_kwargs=None):
|
||||
bijector_kwargs = bijector_kwargs or {}
|
||||
distribution_kwargs = distribution_kwargs or {}
|
||||
x = self._inverse_possibly_from_cache(y, bijector_kwargs)
|
||||
inverse_det_jacobian = math_ops.exp(self.bijector.inverse_log_det_jacobian(
|
||||
y, **bijector_kwargs))
|
||||
return (self.distribution.prob(x, **distribution_kwargs) *
|
||||
inverse_det_jacobian)
|
||||
x, ildj = self.bijector.inverse_and_inverse_log_det_jacobian(
|
||||
y, **bijector_kwargs)
|
||||
return math_ops.exp(ildj) * self.distribution.prob(x, **distribution_kwargs)
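For a concrete instance of the change-of-variables rule used in `_log_prob`/`_prob` above, take the Exp bijector applied to a standard normal, which yields a log-normal. A NumPy-only sketch (illustrative assumptions: scalar `y`, standard normal base distribution):

```python
import numpy as np

def normal_logpdf(x):
  return -0.5 * (np.log(2. * np.pi) + x**2)  # standard normal log-density

y = 2.0
x = np.log(y)                  # g^{-1}(y) for the Exp bijector
ildj = -np.log(y)              # log |d g^{-1}(y) / dy| = log(1 / y)
log_prob_y = normal_logpdf(x) + ildj

# The same density written directly as the log-normal log-pdf, for comparison.
lognormal_logpdf = -np.log(y) - 0.5 * (np.log(2. * np.pi) + np.log(y)**2)
assert np.isclose(log_prob_y, lognormal_logpdf)
```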
|
||||
|
||||
@distribution_util.AppendDocstring(
|
||||
condition_kwargs_dict=_condition_kwargs_dict)
|
||||
def _log_cdf(self, y, bijector_kwargs=None, distribution_kwargs=None):
|
||||
bijector_kwargs = bijector_kwargs or {}
|
||||
distribution_kwargs = distribution_kwargs or {}
|
||||
x = self._inverse_possibly_from_cache(y, bijector_kwargs)
|
||||
x = self.bijector.inverse(y, **bijector_kwargs)
|
||||
return self.distribution.log_cdf(x, distribution_kwargs)
|
||||
|
||||
@distribution_util.AppendDocstring(
|
||||
@ -251,7 +243,7 @@ class TransformedDistribution(distributions.Distribution):
|
||||
def _cdf(self, y, bijector_kwargs=None, distribution_kwargs=None):
|
||||
bijector_kwargs = bijector_kwargs or {}
|
||||
distribution_kwargs = distribution_kwargs or {}
|
||||
x = self._inverse_possibly_from_cache(y, bijector_kwargs)
|
||||
x = self.bijector.inverse(y, **bijector_kwargs)
|
||||
return self.distribution.cdf(x, **distribution_kwargs)
|
||||
|
||||
@distribution_util.AppendDocstring(
|
||||
@ -260,7 +252,7 @@ class TransformedDistribution(distributions.Distribution):
|
||||
bijector_kwargs=None, distribution_kwargs=None):
|
||||
bijector_kwargs = bijector_kwargs or {}
|
||||
distribution_kwargs = distribution_kwargs or {}
|
||||
x = self._inverse_possibly_from_cache(y, bijector_kwargs)
|
||||
x = self.bijector.inverse(y, **bijector_kwargs)
|
||||
return self.distribution.log_survival_function(x, **distribution_kwargs)
|
||||
|
||||
@distribution_util.AppendDocstring(
|
||||
@ -269,13 +261,5 @@ class TransformedDistribution(distributions.Distribution):
|
||||
bijector_kwargs=None, distribution_kwargs=None):
|
||||
bijector_kwargs = bijector_kwargs or {}
|
||||
distribution_kwargs = distribution_kwargs or {}
|
||||
x = self._inverse_possibly_from_cache(y, bijector_kwargs)
|
||||
x = self.bijector.inverse(y, **bijector_kwargs)
|
||||
return self.distribution.survival_function(x, **distribution_kwargs)
|
||||
|
||||
def _inverse_possibly_from_cache(self, y, bijector_kwargs):
|
||||
"""Return `self._inverse(y)`, possibly using cached value."""
|
||||
y = ops.convert_to_tensor(y, name="y")
|
||||
if y in self._inverse_cache:
|
||||
return self._inverse_cache[y]
|
||||
else:
|
||||
return self.bijector.inverse(y, **bijector_kwargs)
|
||||
|
@ -327,6 +327,6 @@ if __name__ == '__main__':
|
||||
default=True,
|
||||
help='Use fake input data.'
|
||||
)
|
||||
FLAGS = parser.parse_args()
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
|
||||
tf.test.main()
|
||||
|
@ -243,6 +243,7 @@ class KMeansClustering(estimator.Estimator,
|
||||
).training_graph()
|
||||
incr_step = tf.assign_add(tf.contrib.framework.get_global_step(), 1)
|
||||
self._loss = tf.reduce_sum(losses)
|
||||
tf.scalar_summary('loss/raw', self._loss)
|
||||
training_op = with_dependencies([training_op, incr_step], self._loss)
|
||||
return training_op, self._loss
|
||||
|
||||
|
@ -24,16 +24,20 @@ from tensorflow.contrib import framework as contrib_framework
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import clip_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.ops import variables as vars_
|
||||
from tensorflow.python.training import moving_averages
|
||||
from tensorflow.python.training import optimizer as optimizer_
|
||||
from tensorflow.python.training import training as train
|
||||
|
||||
|
||||
OPTIMIZER_CLS_NAMES = {
|
||||
"Adagrad": train.AdagradOptimizer,
|
||||
"Adam": train.AdamOptimizer,
|
||||
@ -104,7 +108,11 @@ def optimize_loss(loss,
|
||||
gradient_multipliers: dict of variables or variable names to floats.
|
||||
If present, gradients for specified
|
||||
variables will be multiplied by given constant.
|
||||
clip_gradients: float or `None`, clips gradients by this value.
|
||||
clip_gradients: float, callable or `None`. If a float is provided, a global
|
||||
clipping is applied to prevent the norm of the gradients from exceeding this
|
||||
value. Alternatively, a callable can be provided, e.g. `adaptive_clipping_fn`.
|
||||
This callable takes a `list` of `(gradients, variables)` `tuple`s and
|
||||
returns the same thing with the gradients modified.
|
||||
learning_rate_decay_fn: function, takes `learning_rate` and `global_step`
|
||||
`Tensor`s, returns `Tensor`.
|
||||
Can be used to implement any learning rate decay
|
||||
@ -132,6 +140,7 @@ def optimize_loss(loss,
|
||||
* `global_step` is an invalid type or shape.
|
||||
* `learning_rate` is an invalid type or value.
|
||||
* `optimizer` is wrong type.
|
||||
* `clip_gradients` is not float or callable.
|
||||
* `learning_rate` and `learning_rate_decay_fn` are supplied, but no
|
||||
`global_step` is available.
|
||||
"""
|
||||
@ -224,9 +233,18 @@ def optimize_loss(loss,
|
||||
if gradient_multipliers is not None:
|
||||
gradients = _multiply_gradients(gradients, gradient_multipliers)
|
||||
|
||||
if "gradient_norm" in summaries:
|
||||
logging_ops.scalar_summary("global_norm/gradient_norm",
|
||||
clip_ops.global_norm(zip(*gradients)[0]))
|
||||
|
||||
# Optionally clip gradients by global norm.
|
||||
if clip_gradients is not None:
|
||||
if isinstance(clip_gradients, float):
|
||||
gradients = _clip_gradients_by_norm(gradients, clip_gradients)
|
||||
elif callable(clip_gradients):
|
||||
gradients = clip_gradients(gradients)
|
||||
elif clip_gradients is not None:
|
||||
raise ValueError(
|
||||
"Unknown type %s for clip_gradients" % type(clip_gradients))
|
||||
|
||||
# Add scalar summary for loss.
|
||||
if "loss" in summaries:
|
||||
@ -241,11 +259,15 @@ def optimize_loss(loss,
|
||||
|
||||
if grad_values is not None:
|
||||
if "gradients" in summaries:
|
||||
logging_ops.histogram_summary(variable.name + "/gradients",
|
||||
logging_ops.histogram_summary("gradients/" + variable.name,
|
||||
grad_values)
|
||||
if "gradient_norm" in summaries:
|
||||
logging_ops.histogram_summary(variable.name + "/gradient_norm",
|
||||
clip_ops.global_norm([grad_values]))
|
||||
logging_ops.scalar_summary("gradient_norm/" + variable.name,
|
||||
clip_ops.global_norm([grad_values]))
|
||||
|
||||
if clip_gradients is not None and "gradient_norm" in summaries:
|
||||
logging_ops.scalar_summary("global_norm/clipped_gradient_norm",
|
||||
clip_ops.global_norm(zip(*gradients)[0]))
|
||||
|
||||
# Create gradient updates.
|
||||
grad_updates = opt.apply_gradients(gradients,
|
||||
@ -266,6 +288,101 @@ def _clip_gradients_by_norm(grads_and_vars, clip_gradients):
|
||||
return list(zip(clipped_gradients, variables))
|
||||
|
||||
|
||||
def _adaptive_max_norm(norm, std_factor, decay, global_step, epsilon, name):
|
||||
"""Find max_norm given norm and previous average."""
|
||||
with vs.variable_scope(name, "AdaptiveMaxNorm", [norm]):
|
||||
log_norm = math_ops.log(norm + epsilon)
|
||||
|
||||
def moving_average(name, value, decay):
|
||||
moving_average_variable = vs.get_variable(
|
||||
name, shape=value.get_shape(), dtype=value.dtype,
|
||||
initializer=init_ops.zeros_initializer, trainable=False)
|
||||
return moving_averages.assign_moving_average(
|
||||
moving_average_variable, value, decay)
|
||||
|
||||
# quicker adaptation at the beginning
|
||||
if global_step is not None:
|
||||
n = math_ops.to_float(global_step)
|
||||
decay = math_ops.minimum(decay, n / (n + 1.))
|
||||
|
||||
# update averages
|
||||
mean = moving_average("mean", log_norm, decay)
|
||||
sq_mean = moving_average("sq_mean", math_ops.square(log_norm), decay)
|
||||
|
||||
variance = sq_mean - math_ops.square(mean)
|
||||
std = math_ops.sqrt(math_ops.maximum(epsilon, variance))
|
||||
max_norms = math_ops.exp(mean + std_factor*std)
|
||||
return max_norms, mean
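The update performed by `_adaptive_max_norm` is easy to restate outside the graph. A NumPy sketch of a single step (variable names here are illustrative, not the TensorFlow ones; the quicker-adaptation tweak driven by `global_step` is omitted):

```python
import numpy as np

def adaptive_max_norm_step(norm, mean, sq_mean, decay=0.95,
                           std_factor=2., epsilon=1e-8):
  """One moving-average update on log(norm), then the clipping threshold."""
  log_norm = np.log(norm + epsilon)
  mean = decay * mean + (1. - decay) * log_norm            # moving mean
  sq_mean = decay * sq_mean + (1. - decay) * log_norm**2   # moving second moment
  std = np.sqrt(max(epsilon, sq_mean - mean**2))
  max_norm = np.exp(mean + std_factor * std)
  return max_norm, mean, sq_mean

max_norm, m, sq = adaptive_max_norm_step(norm=3.0, mean=0.0, sq_mean=0.0)
```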
|
||||
|
||||
|
||||
def adaptive_clipping_fn(std_factor=2.,
|
||||
decay=0.95,
|
||||
static_max_norm=None,
|
||||
global_step=None,
|
||||
report_summary=False,
|
||||
epsilon=1e-8,
|
||||
name=None):
|
||||
"""Adapt the clipping value using statistics on the norms.
|
||||
|
||||
Implement adaptive gradient as presented in section 3.2.1 of
|
||||
https://arxiv.org/abs/1412.1602.
|
||||
|
||||
Keeps a moving average of the mean and std of the log(norm) of the gradient.
|
||||
If the norm exceeds `exp(mean + std_factor*std)`, all gradients are rescaled
|
||||
such that the global norm becomes `exp(mean)`.
|
||||
|
||||
Args:
|
||||
std_factor: Python scalar (or tensor).
|
||||
`max_norm = exp(mean + std_factor*std)`
|
||||
decay: The smoothing factor of the moving averages.
|
||||
static_max_norm: If provided, will threshold the norm to this value as an
|
||||
extra safety.
|
||||
global_step: Optional global_step. If provided, `decay = decay*n/(n+1)`.
|
||||
This provides a quicker adaptation of the mean for the first steps.
|
||||
report_summary: If `True`, will add a scalar summary of the `max_norm`.
|
||||
epsilon: Small value chosen to avoid zero variance.
|
||||
name: The name for this operation, used to scope operations and summaries.
|
||||
|
||||
Returns:
|
||||
A function for applying gradient clipping.
|
||||
"""
|
||||
def gradient_clipping(grads_and_vars):
|
||||
"""Internal function for adaptive clipping."""
|
||||
grads, variables = zip(*grads_and_vars)
|
||||
|
||||
norm = clip_ops.global_norm(grads)
|
||||
|
||||
max_norm, log_mean = _adaptive_max_norm(
|
||||
norm, std_factor, decay, global_step, epsilon, name)
|
||||
|
||||
# reports the max gradient norm for debugging
|
||||
if report_summary:
|
||||
logging_ops.scalar_summary(
|
||||
"global_norm/adaptive_max_gradient_norm", max_norm)
|
||||
|
||||
# factor will be 1. if norm is smaller than max_norm
|
||||
factor = math_ops.select(norm < max_norm,
|
||||
array_ops.ones_like(norm),
|
||||
math_ops.exp(log_mean) / norm)
|
||||
|
||||
if static_max_norm is not None:
|
||||
factor = math_ops.minimum(static_max_norm / norm, factor)
|
||||
|
||||
# apply factor
|
||||
clipped_grads = []
|
||||
for grad in grads:
|
||||
if grad is None:
|
||||
clipped_grads.append(None)
|
||||
elif isinstance(grad, ops.IndexedSlices):
|
||||
clipped_grads.append(ops.IndexedSlices(
|
||||
grad.values * factor, grad.indices, grad.dense_shape))
|
||||
else:
|
||||
clipped_grads.append(grad * factor)
|
||||
|
||||
return list(zip(clipped_grads, variables))
|
||||
return gradient_clipping
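A hedged usage sketch of plugging the adaptive clipper into `optimize_loss`, mirroring the test added later in this change (`loss` and `global_step` are assumed to already exist in the graph):

```python
import tensorflow as tf

# Assumes `loss` and `global_step` are defined elsewhere in the graph.
clip_fn = tf.contrib.layers.adaptive_clipping_fn(std_factor=2., decay=0.95,
                                                 global_step=global_step)
train_op = tf.contrib.layers.optimize_loss(loss,
                                           global_step,
                                           learning_rate=0.1,
                                           optimizer="SGD",
                                           clip_gradients=clip_fn)
```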
|
||||
|
||||
|
||||
def _add_scaled_noise_to_gradients(grads_and_vars, gradient_noise_scale):
|
||||
"""Adds scaled noise from a 0-mean normal distribution to gradients."""
|
||||
gradients, variables = zip(*grads_and_vars)
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
@ -179,6 +180,26 @@ class OptimizersTest(tf.test.TestCase):
|
||||
self.assertAlmostEqual(var_value, 9.98999, 4)
|
||||
self.assertEqual(global_step_value, 1)
|
||||
|
||||
def testAdaptiveGradientClip(self):
|
||||
with self.test_session() as session:
|
||||
x, var, loss, global_step = _setup_model()
|
||||
clip_gradients = tf.contrib.layers.adaptive_clipping_fn()
|
||||
train = tf.contrib.layers.optimize_loss(loss,
|
||||
global_step,
|
||||
learning_rate=0.1,
|
||||
optimizer="SGD",
|
||||
clip_gradients=clip_gradients)
|
||||
tf.initialize_all_variables().run()
|
||||
session.run(train, feed_dict={x: 5})
|
||||
var_value, global_step_value = session.run([var, global_step])
|
||||
self.assertAlmostEqual(var_value, 9.8916, 4)
|
||||
self.assertEqual(global_step_value, 1)
|
||||
var_count = 0
|
||||
for var in tf.all_variables():
|
||||
if var.name.startswith("OptimizeLoss/AdaptiveMaxNorm"):
|
||||
var_count += 1
|
||||
self.assertEqual(2, var_count)
|
||||
|
||||
def testGradientMultiply(self):
|
||||
with self.test_session() as session:
|
||||
x, var, loss, global_step = _setup_model()
|
||||
@ -332,5 +353,70 @@ class OptimizersTest(tf.test.TestCase):
|
||||
self.assertEqual(update_var_value, 20)
|
||||
self.assertEqual(global_step_value, 1)
|
||||
|
||||
|
||||
class AdaptiveClipping(tf.test.TestCase):
|
||||
|
||||
def testAverages(self):
|
||||
with self.test_session() as session:
|
||||
scale = 2.
|
||||
grad = tf.ones([3, 4]) * scale
|
||||
log_norm = np.log(np.sqrt(scale**2 * grad.get_shape().num_elements()))
|
||||
grads_and_vars = [(grad, grad)]
|
||||
grads_and_vars = tf.contrib.layers.adaptive_clipping_fn(
|
||||
decay=0.5)(grads_and_vars)
|
||||
|
||||
var_dict = {}
|
||||
for var in tf.all_variables():
|
||||
if var.name.startswith("AdaptiveMaxNorm"):
|
||||
var_dict[var.name.split(":")[0]] = var
|
||||
self.assertEqual(2, len(var_dict))
|
||||
moving_mean = var_dict["AdaptiveMaxNorm/mean"]
|
||||
moving_sq_mean = var_dict["AdaptiveMaxNorm/sq_mean"]
|
||||
tf.initialize_all_variables().run()
|
||||
mean, sq_mean = session.run([moving_mean, moving_sq_mean])
|
||||
self.assertEqual([0], mean)
|
||||
self.assertEqual([0], sq_mean)
|
||||
for i in range(20):
|
||||
mean, sq_mean, _ = session.run(
|
||||
[moving_mean, moving_sq_mean, grads_and_vars[0][0]])
|
||||
if i == 0:
|
||||
self.assertLess(mean, 0.9 * log_norm)
|
||||
self.assertLess(sq_mean, 0.9 * log_norm**2)
|
||||
|
||||
self.assertAlmostEqual(float(mean), log_norm, places=4)
|
||||
self.assertAlmostEqual(float(sq_mean), log_norm**2, places=4)
|
||||
|
||||
def testClip(self):
|
||||
with self.test_session() as session:
|
||||
spike = 1000.
|
||||
multiplier = tf.placeholder(tf.float32, [], "multiplier")
|
||||
step = tf.placeholder(tf.int32, [], "step")
|
||||
|
||||
grad = tf.ones([3, 4]) * multiplier
|
||||
grads_and_vars = [(grad, grad)]
|
||||
grads_and_vars = tf.contrib.layers.adaptive_clipping_fn(
|
||||
decay=0.9, global_step=step)(grads_and_vars)
|
||||
|
||||
tf.initialize_all_variables().run()
|
||||
def run(scale, i):
|
||||
return session.run(grads_and_vars[0][0],
|
||||
feed_dict={multiplier: scale, step: i})
|
||||
|
||||
for i in range(20):
|
||||
scale = [1., -2.][i % 2]
|
||||
clipped_grad = run(scale, i)
|
||||
if i > 3:
|
||||
self.assertAllClose(np.ones(clipped_grad.shape)*scale, clipped_grad)
|
||||
|
||||
# assert that the spike will have low influence.
|
||||
clipped_grad = run(spike, 20)
|
||||
self.assertTrue((clipped_grad < 25.).all())
|
||||
|
||||
# assert that a repeated spike will converge to this new value.
|
||||
for i in range(10):
|
||||
clipped_grad = run(spike, i + 21)
|
||||
|
||||
self.assertAllClose(np.ones(clipped_grad.shape)*spike, clipped_grad)
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
|
@ -35,6 +35,6 @@ from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassi
|
||||
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
|
||||
from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
|
||||
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestEstimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestLossMonitor
|
||||
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestLossHook
|
||||
from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig
|
||||
from tensorflow.contrib.learn.python.learn.estimators.svm import SVM
|
||||
|
@ -20,6 +20,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib import metrics as metrics_lib
|
||||
from tensorflow.contrib.framework import deprecated
|
||||
from tensorflow.contrib.framework import deprecated_arg_values
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.session_bundle import exporter
|
||||
@ -27,6 +28,8 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn
|
||||
|
||||
|
||||
@deprecated('2016-11-30', 'Please write an appropriate function for use with'
|
||||
' your estimator.')
|
||||
def classification_signature_fn(examples, unused_features, predictions):
|
||||
"""Creates classification signature from given examples and predictions.
|
||||
|
||||
@ -61,6 +64,7 @@ class Classifier(estimator.Estimator):
|
||||
CLASS_OUTPUT = 'classes'
|
||||
PROBABILITY_OUTPUT = 'probabilities'
|
||||
|
||||
@deprecated('2016-11-30', 'Please use Estimator directly.')
|
||||
def __init__(self, model_fn, n_classes, model_dir=None, config=None,
|
||||
params=None, feature_engineering_fn=None):
|
||||
"""Constructor for Classifier.
|
||||
|
@ -309,7 +309,7 @@ class _DynamicRNNEstimator(estimator.BaseEstimator):
|
||||
inputs=rnn_outputs,
|
||||
num_outputs=self._target_column.num_label_columns,
|
||||
activation_fn=None,
|
||||
trainable=False)
|
||||
trainable=True)
|
||||
return activations, final_state
|
||||
|
||||
@abc.abstractmethod
|
||||
|
@ -429,7 +429,7 @@ class SingleValueRNNEstimatorTest(tf.test.TestCase):
|
||||
cell_type = 'basic_rnn'
|
||||
cell_size = 8
|
||||
optimizer_type = 'Momentum'
|
||||
learning_rate = 0.5
|
||||
learning_rate = 0.1
|
||||
momentum = 0.9
|
||||
loss_threshold = 0.1
|
||||
|
||||
|
@ -36,6 +36,7 @@ from tensorflow.contrib import layers
|
||||
from tensorflow.contrib import metrics as metrics_lib
|
||||
from tensorflow.contrib.framework import deprecated
|
||||
from tensorflow.contrib.framework import deprecated_arg_values
|
||||
from tensorflow.contrib.framework import get_graph_from_inputs
|
||||
from tensorflow.contrib.framework import list_variables
|
||||
from tensorflow.contrib.framework import load_variable
|
||||
from tensorflow.contrib.learn.python.learn import evaluable
|
||||
@ -88,8 +89,11 @@ class ModelFnOps(
|
||||
collections.namedtuple('ModelFnOps', ['predictions', 'loss', 'training_op',
|
||||
'default_metrics', 'signature_fn'])):
|
||||
|
||||
def __new__(cls, predictions, loss, training_op, default_metrics,
|
||||
signature_fn, mode):
|
||||
def __new__(cls, mode, predictions=None, loss=None, training_op=None,
|
||||
default_metrics=None, signature_fn=None):
|
||||
# Assert all ops are from the same graph.
|
||||
get_graph_from_inputs((predictions, loss, training_op))
|
||||
|
||||
# Validate training_op.
|
||||
if training_op is None:
|
||||
if mode == ModeKeys.TRAIN:
|
||||
@ -1042,13 +1046,16 @@ class Estimator(BaseEstimator):
|
||||
|
||||
if isinstance(model_fn_results, ModelFnOps):
|
||||
return model_fn_results
|
||||
else:
|
||||
# Here model_fn_ops should be a tuple with 3 elements.
|
||||
if len(model_fn_results) != 3:
|
||||
raise ValueError('Unrecognized value returned by model_fn, '
|
||||
'please return ModelFnOps.')
|
||||
return ModelFnOps(model_fn_results[0], model_fn_results[1],
|
||||
model_fn_results[2], None, None, mode)
|
||||
|
||||
# Here model_fn_ops should be a tuple with 3 elements.
|
||||
if len(model_fn_results) != 3:
|
||||
raise ValueError('Unrecognized value returned by model_fn, '
|
||||
'please return ModelFnOps.')
|
||||
return ModelFnOps(
|
||||
mode=mode,
|
||||
predictions=model_fn_results[0],
|
||||
loss=model_fn_results[1],
|
||||
training_op=model_fn_results[2])
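Given the keyword-only `ModelFnOps` constructor, a sketch of a `model_fn` returning the namedtuple directly instead of a 3-tuple (the `mode` argument and the `build_graph` helper are illustrative assumptions, not part of this change):

```python
from tensorflow.contrib.learn.python.learn.estimators import estimator as estimator_lib

def my_model_fn(features, targets, mode):
  # `build_graph` is a hypothetical helper standing in for real model code that
  # produces prediction, loss and training ops from the inputs.
  predictions, loss, train_op = build_graph(features, targets)
  return estimator_lib.ModelFnOps(
      mode=mode,
      predictions=predictions,
      loss=loss,
      training_op=train_op)
```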
|
||||
|
||||
def _get_train_ops(self, features, targets):
|
||||
"""Method that builds model graph and returns trainer ops.
|
||||
|
@ -229,20 +229,30 @@ class _Head(object):
|
||||
else:
|
||||
train_op = control_flow_ops.group(*additional_train_op)
|
||||
|
||||
return estimator.ModelFnOps(None, loss, train_op,
|
||||
self._default_metric(),
|
||||
self._create_signature_fn(), mode)
|
||||
return estimator.ModelFnOps(
|
||||
mode=estimator.ModeKeys.TRAIN,
|
||||
loss=loss,
|
||||
training_op=train_op,
|
||||
default_metrics=self._default_metric(),
|
||||
signature_fn=self._create_signature_fn())
|
||||
|
||||
if mode == estimator.ModeKeys.INFER:
|
||||
predictions = self._infer_op(logits, logits_input)
|
||||
return estimator.ModelFnOps(predictions, None, None,
|
||||
self._default_metric(),
|
||||
self._create_signature_fn(), mode)
|
||||
return estimator.ModelFnOps(
|
||||
mode=estimator.ModeKeys.INFER,
|
||||
predictions=self._infer_op(logits, logits_input),
|
||||
default_metrics=self._default_metric(),
|
||||
signature_fn=self._create_signature_fn())
|
||||
|
||||
if mode == estimator.ModeKeys.EVAL:
|
||||
predictions, loss = self._eval_op(features, target, logits, logits_input)
|
||||
return estimator.ModelFnOps(predictions, loss, None,
|
||||
self._default_metric(),
|
||||
self._create_signature_fn(), mode)
|
||||
raise ValueError("mode=%s unrecognized" % str(mode))
|
||||
return estimator.ModelFnOps(
|
||||
mode=estimator.ModeKeys.EVAL,
|
||||
predictions=predictions,
|
||||
loss=loss,
|
||||
default_metrics=self._default_metric(),
|
||||
signature_fn=self._create_signature_fn())
|
||||
|
||||
raise ValueError("mode=%s unrecognized." % str(mode))
|
||||
|
||||
@abc.abstractmethod
|
||||
def _training_loss(self, features, target, logits=None, logits_input=None,
|
||||
|
@ -17,25 +17,28 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.contrib import framework as contrib_framework
|
||||
from tensorflow.contrib.framework import deprecated_arg_values
|
||||
from tensorflow.contrib.learn.python.learn import monitors as mon
|
||||
from tensorflow.contrib.learn.python.learn import evaluable
|
||||
from tensorflow.contrib.learn.python.learn import trainable
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn.utils import export
|
||||
|
||||
from tensorflow.contrib.tensor_forest.client import eval_metrics
|
||||
from tensorflow.contrib.tensor_forest.data import data_ops
|
||||
from tensorflow.contrib.tensor_forest.python import tensor_forest
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import session_run_hook
|
||||
|
||||
|
||||
KEYS_NAME = 'keys'
|
||||
LOSS_NAME = 'rf_training_loss'
|
||||
|
||||
|
||||
def _assert_float32(tensors):
|
||||
@ -56,58 +59,124 @@ def _assert_float32(tensors):
|
||||
raise TypeError('Expected dtype=float32, %s.' % tensor)
|
||||
|
||||
|
||||
class TensorForestLossMonitor(mon.EveryN):
|
||||
"""Terminates training when training loss stops decreasing."""
|
||||
class TensorForestLossHook(session_run_hook.SessionRunHook):
|
||||
"""Monitor to request stop when loss stops decreasing."""
|
||||
|
||||
def __init__(self,
|
||||
early_stopping_rounds,
|
||||
every_n_steps):
|
||||
super(TensorForestLossMonitor, self).__init__(every_n_steps=every_n_steps)
|
||||
def __init__(self, early_stopping_rounds):
|
||||
self.early_stopping_rounds = early_stopping_rounds
|
||||
self.min_loss = None
|
||||
self.min_loss_step = 0
|
||||
self.last_step = -1
|
||||
# self.steps records the number of steps for which the loss has been
|
||||
# non-decreasing
|
||||
self.steps = 0
|
||||
|
||||
def step_begin(self, step):
|
||||
super(TensorForestLossMonitor, self).step_begin(step)
|
||||
return [self._loss_op_name]
|
||||
def before_run(self, run_context):
|
||||
return session_run_hook.SessionRunArgs(
|
||||
{'global_step': contrib_framework.get_global_step(),
|
||||
'current_loss': run_context.session.graph.get_operation_by_name(
|
||||
LOSS_NAME).outputs[0]})
|
||||
|
||||
def set_estimator(self, est):
|
||||
"""This function gets called in the same graph as _get_train_ops."""
|
||||
super(TensorForestLossMonitor, self).set_estimator(est)
|
||||
self._loss_op_name = est.training_loss.name
|
||||
def after_run(self, run_context, run_values):
|
||||
current_loss = run_values.results['current_loss']
|
||||
current_step = run_values.results['global_step']
|
||||
self.steps += 1
|
||||
# Guard against the global step going backwards, which might happen
|
||||
# if we recover from something.
|
||||
if self.last_step == -1 or self.last_step > current_step:
|
||||
logging.info('TensorForestLossHook resetting last_step.')
|
||||
self.last_step = current_step
|
||||
self.steps = 0
|
||||
return
|
||||
|
||||
def every_n_step_end(self, step, outputs):
|
||||
super(TensorForestLossMonitor, self).every_n_step_end(step, outputs)
|
||||
current_loss = outputs[self._loss_op_name]
|
||||
if self.min_loss is None or current_loss < self.min_loss:
|
||||
self.min_loss = current_loss
|
||||
self.min_loss_step = step
|
||||
return step - self.min_loss_step >= self.early_stopping_rounds
|
||||
self.steps = 0
|
||||
if self.steps > self.early_stopping_rounds:
|
||||
logging.info('TensorForestLossHook requesting stop.')
|
||||
run_context.request_stop()
|
||||
|
||||
|
||||
class TensorForestEstimator(estimator.BaseEstimator):
|
||||
def get_model_fn(params, graph_builder_class, device_assigner,
|
||||
weights_name=None, keys_name=None):
|
||||
"""Return a model function given a way to construct a graph builder."""
|
||||
def _model_fn(features, targets):
|
||||
"""Function that returns predictions, training loss, and training op."""
|
||||
weights = None
|
||||
keys = None
|
||||
if weights_name and weights_name in features:
|
||||
weights = features.pop(weights_name)
|
||||
if keys_name and keys_name in features:
|
||||
keys = features.pop(keys_name)
|
||||
processed_features, spec = data_ops.ParseDataTensorOrDict(features)
|
||||
_assert_float32(processed_features)
|
||||
if targets is not None:
|
||||
targets = data_ops.ParseLabelTensorOrDict(targets)
|
||||
_assert_float32(targets)
|
||||
|
||||
graph_builder = graph_builder_class(params, device_assigner=device_assigner)
|
||||
inference = {eval_metrics.INFERENCE_PROB_NAME:
|
||||
graph_builder.inference_graph(processed_features,
|
||||
data_spec=spec)}
|
||||
if not params.regression:
|
||||
inference[eval_metrics.INFERENCE_PRED_NAME] = math_ops.argmax(
|
||||
inference[eval_metrics.INFERENCE_PROB_NAME], 1)
|
||||
if keys:
|
||||
inference[KEYS_NAME] = keys
|
||||
|
||||
# targets might be None if we're doing prediction (which brings up the
|
||||
# question of why we force everything to adhere to a single model_fn).
|
||||
training_loss = None
|
||||
training_graph = None
|
||||
if targets is not None:
|
||||
training_loss = graph_builder.training_loss(processed_features, targets,
|
||||
data_spec=spec,
|
||||
name=LOSS_NAME)
|
||||
training_graph = control_flow_ops.group(
|
||||
graph_builder.training_graph(
|
||||
processed_features, targets, data_spec=spec,
|
||||
input_weights=weights),
|
||||
state_ops.assign_add(contrib_framework.get_global_step(), 1))
|
||||
# Put weights back in
|
||||
if weights is not None:
|
||||
features[weights_name] = weights
|
||||
return (inference, training_loss, training_graph)
|
||||
return _model_fn
|
||||
|
||||
|
||||
class TensorForestEstimator(evaluable.Evaluable, trainable.Trainable):
|
||||
"""An estimator that can train and evaluate a random forest."""
|
||||
|
||||
def __init__(self, params, device_assigner=None, model_dir=None,
|
||||
graph_builder_class=tensor_forest.RandomForestGraphs,
|
||||
master='', accuracy_metric=None,
|
||||
tf_random_seed=None, config=None,
|
||||
feature_engineering_fn=None):
|
||||
config=None, weights_name=None, keys_name=None,
|
||||
feature_engineering_fn=None, early_stopping_rounds=100):
|
||||
self.params = params.fill()
|
||||
self.accuracy_metric = (accuracy_metric or
|
||||
('r2' if self.params.regression else 'accuracy'))
|
||||
self.data_feeder = None
|
||||
self.device_assigner = (
|
||||
device_assigner or tensor_forest.RandomForestDeviceAssigner())
|
||||
self.graph_builder_class = graph_builder_class
|
||||
self.training_args = {}
|
||||
self.construction_args = {}
|
||||
self._feature_engineering_fn = (
|
||||
feature_engineering_fn or
|
||||
(lambda features, targets: (features, targets)))
|
||||
self.early_stopping_rounds = early_stopping_rounds
|
||||
self._estimator = estimator.Estimator(
|
||||
model_fn=get_model_fn(params, graph_builder_class, device_assigner,
|
||||
weights_name=weights_name, keys_name=keys_name),
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
feature_engineering_fn=feature_engineering_fn)
|
||||
|
||||
super(TensorForestEstimator, self).__init__(model_dir=model_dir,
|
||||
config=config)
|
||||
def evaluate(
|
||||
self, x=None, y=None, input_fn=None, feed_fn=None, batch_size=None,
|
||||
steps=None, metrics=None, name=None):
|
||||
"""See evaluable.Evaluable."""
|
||||
return self._estimator.evaluate(
|
||||
input_fn=input_fn, x=x, y=y, feed_fn=feed_fn,
|
||||
batch_size=batch_size, steps=steps,
|
||||
metrics=metrics, name=name)
|
||||
|
||||
def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
|
||||
monitors=None, max_steps=None):
|
||||
"""See trainable.Trainable."""
|
||||
if not monitors:
|
||||
monitors = [TensorForestLossHook(self.early_stopping_rounds)]
|
||||
self._estimator.fit(input_fn=input_fn, x=x, y=y,
|
||||
batch_size=batch_size, steps=steps, monitors=monitors,
|
||||
max_steps=max_steps)
|
||||
|
||||
@deprecated_arg_values(
|
||||
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
|
||||
@ -135,13 +204,14 @@ class TensorForestEstimator(estimator.BaseEstimator):
|
||||
Raises:
|
||||
ValueError: If both or neither of x and input_fn were given.
|
||||
"""
|
||||
results = super(TensorForestEstimator, self).predict(
|
||||
results = self._estimator.predict(
|
||||
x=x, input_fn=input_fn, batch_size=batch_size, outputs=outputs,
|
||||
as_iterable=as_iterable)
|
||||
|
||||
if as_iterable:
|
||||
return (r['probabilities'] for r in results)
|
||||
return (x[eval_metrics.INFERENCE_PROB_NAME] for x in results)
|
||||
else:
|
||||
return results['probabilities']
|
||||
return results[eval_metrics.INFERENCE_PROB_NAME]
|
||||
|
||||
@deprecated_arg_values(
|
||||
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
|
||||
@ -168,16 +238,16 @@ class TensorForestEstimator(estimator.BaseEstimator):
|
||||
Numpy array of predicted classes or regression values (or an iterable of
|
||||
predictions if as_iterable is True).
|
||||
"""
|
||||
probabilities = self.predict_proba(
|
||||
results = self._estimator.predict(
|
||||
x=x, input_fn=input_fn, batch_size=batch_size, outputs=outputs,
|
||||
as_iterable=as_iterable)
|
||||
if self.params.regression:
|
||||
return probabilities
|
||||
|
||||
predict_name = (eval_metrics.INFERENCE_PROB_NAME if self.params.regression
|
||||
else eval_metrics.INFERENCE_PRED_NAME)
|
||||
if as_iterable:
|
||||
return (x[predict_name] for x in results)
|
||||
else:
|
||||
if as_iterable:
|
||||
return (np.argmax(p, axis=0) for p in probabilities)
|
||||
else:
|
||||
return np.argmax(probabilities, axis=1)
|
||||
return results[predict_name]
|
||||
|
||||
@deprecated_arg_values(
|
||||
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
|
||||
@ -186,100 +256,40 @@ class TensorForestEstimator(estimator.BaseEstimator):
|
||||
self, x=None, input_fn=None, axis=None, batch_size=None, outputs=None,
|
||||
as_iterable=True):
|
||||
"""Same as predict but also returns the example keys."""
|
||||
results = super(TensorForestEstimator, self).predict(
|
||||
results = self._estimator.predict(
|
||||
x=x, input_fn=input_fn, batch_size=batch_size, outputs=outputs,
|
||||
as_iterable=as_iterable)
|
||||
if self.params.regression:
|
||||
if as_iterable:
|
||||
return ((r['probabilities'], r.get('keys', None)) for r in results)
|
||||
else:
|
||||
return results['probabilities'], results.get('keys', None)
|
||||
|
||||
predict_name = (eval_metrics.INFERENCE_PROB_NAME if self.params.regression
|
||||
else eval_metrics.INFERENCE_PRED_NAME)
|
||||
if as_iterable:
|
||||
return ((x[predict_name], x.get(KEYS_NAME, None)) for x in results)
|
||||
else:
|
||||
if as_iterable:
|
||||
return ((np.argmax(r['probabilities'], axis=0),
|
||||
r.get('keys', None)) for r in results)
|
||||
|
||||
else:
|
||||
return np.argmax(results['probabilities'], axis=1), results.get('keys',
|
||||
None)
|
||||
|
||||
def _get_train_ops(self, features, targets):
|
||||
"""Method that builds model graph and returns trainer ops.
|
||||
|
||||
Args:
|
||||
features: `Tensor` or `dict` of `Tensor` objects.
|
||||
targets: `Tensor` or `dict` of `Tensor` objects.
|
||||
|
||||
Returns:
|
||||
Tuple of train `Operation` and loss `Tensor`.
|
||||
"""
|
||||
features, _, weights, spec = data_ops.ParseDataTensorOrDict(features)
|
||||
labels = data_ops.ParseLabelTensorOrDict(targets)
|
||||
features, labels = self._feature_engineering_fn(features, labels)
|
||||
_assert_float32(features)
|
||||
_assert_float32(labels)
|
||||
|
||||
if weights is not None:
|
||||
if 'input_weights' in self.training_args:
|
||||
logging.warning('Replacing input_weights in training_args.')
|
||||
self.training_args['input_weights'] = weights
|
||||
|
||||
graph_builder = self.graph_builder_class(
|
||||
self.params, device_assigner=self.device_assigner,
|
||||
**self.construction_args)
|
||||
|
||||
epoch = None
|
||||
if self.data_feeder:
|
||||
epoch = self.data_feeder.make_epoch_variable()
|
||||
|
||||
train = control_flow_ops.group(
|
||||
graph_builder.training_graph(
|
||||
features, labels, data_spec=spec, epoch=epoch,
|
||||
**self.training_args),
|
||||
state_ops.assign_add(contrib_framework.get_global_step(), 1))
|
||||
|
||||
self.training_loss = graph_builder.training_loss(features, targets)
|
||||
|
||||
return train, self.training_loss
|
||||
|
||||
def _get_predict_ops(self, features):
|
||||
graph_builder = self.graph_builder_class(
|
||||
self.params, device_assigner=self.device_assigner, training=False,
|
||||
**self.construction_args)
|
||||
features, keys, _, spec = data_ops.ParseDataTensorOrDict(features)
|
||||
features, _ = self._feature_engineering_fn(features, None)
|
||||
_assert_float32(features)
|
||||
output_dict = {
|
||||
'probabilities': graph_builder.inference_graph(features,
|
||||
data_spec=spec)}
|
||||
if keys is not None:
|
||||
output_dict['keys'] = keys
|
||||
return output_dict
|
||||
|
||||
def _get_eval_ops(self, features, targets, metrics):
|
||||
features, _, _, spec = data_ops.ParseDataTensorOrDict(features)
|
||||
labels = data_ops.ParseLabelTensorOrDict(targets)
|
||||
features, labels = self._feature_engineering_fn(features, labels)
|
||||
_assert_float32(features)
|
||||
_assert_float32(labels)
|
||||
|
||||
graph_builder = self.graph_builder_class(
|
||||
self.params, device_assigner=self.device_assigner, training=False,
|
||||
**self.construction_args)
|
||||
|
||||
probabilities = graph_builder.inference_graph(features, data_spec=spec)
|
||||
|
||||
# One-hot the labels.
|
||||
if not self.params.regression:
|
||||
labels = math_ops.to_int64(array_ops.one_hot(math_ops.to_int64(
|
||||
array_ops.squeeze(labels)), self.params.num_classes, 1, 0))
|
||||
|
||||
if metrics is None:
|
||||
metrics = {self.accuracy_metric:
|
||||
eval_metrics.get_metric(self.accuracy_metric)}
|
||||
|
||||
result = {}
|
||||
for name, metric in six.iteritems(metrics):
|
||||
result[name] = metric(probabilities, labels)
|
||||
return results[predict_name], results.get(KEYS_NAME, None)
|
||||
|
||||
def export(self,
|
||||
export_dir,
|
||||
input_fn,
|
||||
signature_fn=None,
|
||||
default_batch_size=1):
|
||||
"""See BaseEstimator.export."""
|
||||
# Reset model function with basic device assigner.
|
||||
# Servo doesn't support distributed inference
|
||||
# but it will try to respect device assignments if they're there.
|
||||
# pylint: disable=protected-access
|
||||
orig_model_fn = self._estimator._model_fn
|
||||
self._estimator._model_fn = get_model_fn(
|
||||
self.params, self.graph_builder_class,
|
||||
tensor_forest.RandomForestDeviceAssigner())
|
||||
result = self._estimator.export(
|
||||
export_dir=export_dir,
|
||||
use_deprecated_input_fn=True,
|
||||
signature_fn=(signature_fn or
|
||||
(export.regression_signature_fn
|
||||
if self.params.regression else
|
||||
export.classification_signature_fn_with_prob)),
|
||||
default_batch_size=default_batch_size,
|
||||
prediction_key=eval_metrics.INFERENCE_PROB_NAME)
|
||||
self._estimator._model_fn = orig_model_fn
|
||||
# pylint: enable=protected-access
|
||||
return result
|
||||
|
@ -28,14 +28,30 @@ class TensorForestTrainerTests(tf.test.TestCase):
def testClassification(self):
"""Tests multi-class classification using matrix data as input."""
hparams = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
num_trees=3, max_nodes=1000, num_classes=3, num_features=4)
classifier = tf.contrib.learn.TensorForestEstimator(hparams)
num_trees=3, max_nodes=1000, num_classes=3, num_features=4,
split_after_samples=20)
classifier = tf.contrib.learn.TensorForestEstimator(hparams.fill())

iris = tf.contrib.learn.datasets.load_iris()
data = iris.data.astype(np.float32)
target = iris.target.astype(np.float32)

monitors = [tf.contrib.learn.TensorForestLossMonitor(10, 10)]
classifier.fit(x=data, y=target, steps=100, batch_size=50)
classifier.evaluate(x=data, y=target, steps=10)

def testClassificationTrainingLoss(self):
"""Tests multi-class classification using matrix data as input."""
hparams = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
num_trees=3, max_nodes=1000, num_classes=3, num_features=4)
classifier = tf.contrib.learn.TensorForestEstimator(
hparams, graph_builder_class=(
tf.contrib.tensor_forest.python.tensor_forest.TrainingLossForest))

iris = tf.contrib.learn.datasets.load_iris()
data = iris.data.astype(np.float32)
target = iris.target.astype(np.float32)

monitors = [tf.contrib.learn.TensorForestLossHook(10)]
classifier.fit(x=data, y=target, steps=100, monitors=monitors)
classifier.evaluate(x=data, y=target, steps=10)

@ -44,16 +60,15 @@ class TensorForestTrainerTests(tf.test.TestCase):

hparams = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
num_trees=3, max_nodes=1000, num_classes=1, num_features=13,
regression=True)
regression=True, split_after_samples=20)

regressor = tf.contrib.learn.TensorForestEstimator(hparams)
regressor = tf.contrib.learn.TensorForestEstimator(hparams.fill())

boston = tf.contrib.learn.datasets.load_boston()
data = boston.data.astype(np.float32)
target = boston.target.astype(np.float32)

monitors = [tf.contrib.learn.TensorForestLossMonitor(10, 10)]
regressor.fit(x=data, y=target, steps=100, monitors=monitors)
regressor.fit(x=data, y=target, steps=100, batch_size=50)
regressor.evaluate(x=data, y=target, steps=10)
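For reference, a minimal construction sketch reflecting the test changes above, assuming the tf.contrib.learn API of this release (the hyperparameter values are illustrative, not prescriptive):

    import tensorflow as tf

    # The hunk above switches the tests to filled hparams and an explicit
    # split_after_samples; this mirrors that pattern.
    hparams = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
        num_trees=3, max_nodes=1000, num_classes=3, num_features=4,
        split_after_samples=20)
    classifier = tf.contrib.learn.TensorForestEstimator(hparams.fill())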
@ -627,7 +627,7 @@ def _eval_results_to_str(eval_results):

def _write_summary_results(output_dir, eval_results, current_global_step):
"""Writes eval results into summary file in given dir."""
logging.info('Saving evaluation summary for %d step: %s', current_global_step,
logging.info('Saving evaluation summary for step %d: %s', current_global_step,
_eval_results_to_str(eval_results))
summary_writer = get_summary_writer(output_dir)
summary = summary_pb2.Summary()
@ -253,6 +253,18 @@ def _get_shared_file_name_queue(file_names, shuffle, num_epochs, name):


def _get_file_names(file_pattern, randomize_input):
"""Parse list of file names from pattern, optionally shuffled.

Args:
file_pattern: File glob pattern, or list of strings.
randomize_input: Whether to shuffle the order of file names.

Returns:
List of file names matching `file_pattern`.

Raises:
ValueError: If `file_pattern` is empty, or pattern matches no files.
"""
if isinstance(file_pattern, list):
file_names = file_pattern
if not file_names:
@ -304,6 +316,36 @@ def _read_keyed_batch_examples_helper(file_pattern,
parse_fn=None,
setup_shared_queue=False,
name=None):
"""Adds operations to read, queue, batch `Example` protos.

Args:
file_pattern: List of files or pattern of file paths containing
`Example` records. See `tf.gfile.Glob` for pattern rules.
batch_size: An int or scalar `Tensor` specifying the batch size to use.
reader: A function or class that returns an object with
`read` method, (filename tensor) -> (example tensor).
randomize_input: Whether the input should be randomized.
num_epochs: Integer specifying the number of times to read through the
dataset. If `None`, cycles through the dataset forever.
NOTE - If specified, creates a variable that must be initialized, so call
`tf.initialize_all_variables()` as shown in the tests.
queue_capacity: Capacity for input queue.
num_threads: The number of threads enqueuing examples.
read_batch_size: An int or scalar `Tensor` specifying the number of
records to read at once
parse_fn: Parsing function, takes `Example` Tensor returns parsed
representation. If `None`, no parsing is done.
setup_shared_queue: Whether to set up a shared queue for file names.
name: Name of resulting op.

Returns:
Returns tuple of:
- `Tensor` of string keys.
- String `Tensor` of batched `Example` proto.

Raises:
ValueError: for invalid inputs.
"""
# Retrieve files to read.
file_names = _get_file_names(file_pattern, randomize_input)

@ -348,10 +390,10 @@ def _read_keyed_batch_examples_helper(file_pattern,

enqueue_many = read_batch_size > 1

if num_epochs is not None:
allow_smaller_final_batch = True
else:
if num_epochs is None:
allow_smaller_final_batch = False
else:
allow_smaller_final_batch = True

# Setup batching queue given list of read example tensors.
if randomize_input:
@ -505,7 +547,6 @@ def _read_keyed_batch_features_shared_queue(file_pattern,
Adding multiple queue runners for the parsed example queue helps maintain
a full queue when the subsequent computations overall are cheaper than
parsing.
parser_num_threads: (Deprecated) The number of threads to parse examples.
parse_fn: Parsing function, takes `Example` Tensor returns parsed
representation. If `None`, no parsing is done.
name: Name of resulting op.
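For reference, a minimal usage sketch of the behaviour documented above, assuming the tf.contrib.learn reader API of this release (the file pattern and batch size are illustrative): a finite num_epochs creates a local variable that must be initialized, and only then is a smaller final batch allowed.

    import tensorflow as tf

    # "examples-*" is a hypothetical TFRecord file pattern.
    examples = tf.contrib.learn.read_batch_examples(
        "examples-*", batch_size=32, reader=tf.TFRecordReader,
        randomize_input=False, num_epochs=1)
    with tf.Session() as sess:
      # Required because num_epochs is not None (see the docstring above).
      sess.run(tf.initialize_local_variables())
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess, coord=coord)
      # ... consume batches, then:
      coord.request_stop()
      coord.join(threads)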
@ -121,7 +121,8 @@ class GraphIOTest(tf.test.TestCase):
batch_size = 17
queue_capacity = 1234
name = "my_batch"
features = {"feature": tf.FixedLenFeature(shape=[0], dtype=tf.float32)}
shape = (0,)
features = {"feature": tf.FixedLenFeature(shape=shape, dtype=tf.float32)}

with tf.Graph().as_default() as g, self.test_session(graph=g) as sess:
features = tf.contrib.learn.io.read_batch_record_features(
@ -132,8 +133,11 @@ class GraphIOTest(tf.test.TestCase):
queue_capacity=queue_capacity,
reader_num_threads=2,
name=name)
self.assertEqual("%s/fifo_queue_1_Dequeue:0" % name,
features["feature"].name)
self.assertTrue(
"feature" in features, "'feature' missing from %s." % features.keys())
feature = features["feature"]
self.assertEqual("%s/fifo_queue_1_Dequeue:0" % name, feature.name)
self.assertAllEqual((batch_size,) + shape, feature.get_shape().as_list())
file_name_queue_name = "%s/file_name_queue" % name
file_names_name = "%s/input" % file_name_queue_name
example_queue_name = "%s/fifo_queue" % name
@ -161,6 +165,7 @@ class GraphIOTest(tf.test.TestCase):
reader=tf.TFRecordReader, randomize_input=True,
num_epochs=1,
queue_capacity=queue_capacity, name=name)
self.assertAllEqual((None,), inputs.get_shape().as_list())
self.assertEqual("%s:1" % name, inputs.name)
file_name_queue_name = "%s/file_name_queue" % name
file_name_queue_limit_name = (
@ -190,6 +195,7 @@ class GraphIOTest(tf.test.TestCase):
_VALID_FILE_PATTERN, batch_size,
reader=tf.TFRecordReader, randomize_input=True,
queue_capacity=queue_capacity, name=name)
self.assertAllEqual((batch_size,), inputs.get_shape().as_list())
self.assertEqual("%s:1" % name, inputs.name)
file_name_queue_name = "%s/file_name_queue" % name
file_names_name = "%s/input" % file_name_queue_name
@ -234,6 +240,7 @@ class GraphIOTest(tf.test.TestCase):
filename, batch_size, reader=tf.TextLineReader,
randomize_input=False, num_epochs=1, queue_capacity=queue_capacity,
name=name)
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run(tf.initialize_local_variables())

coord = tf.train.Coordinator()
@ -280,10 +287,13 @@ class GraphIOTest(tf.test.TestCase):
features = {"sequence": tf.FixedLenFeature([], tf.string)}

with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
_, result = tf.contrib.learn.read_keyed_batch_features(
keys, result = tf.contrib.learn.read_keyed_batch_features(
filename, batch_size, features, tf.TextLineReader,
randomize_input=False, num_epochs=1, queue_capacity=queue_capacity,
num_enqueue_threads=2, parse_fn=tf.decode_json_example, name=name)
self.assertAllEqual((None,), keys.get_shape().as_list())
self.assertEqual(1, len(result))
self.assertAllEqual((None,), result["sequence"].get_shape().as_list())
session.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(session, coord=coord)
@ -319,6 +329,7 @@ class GraphIOTest(tf.test.TestCase):
filenames, batch_size, reader=tf.TextLineReader,
randomize_input=False, num_epochs=1, queue_capacity=queue_capacity,
name=name)
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run(tf.initialize_local_variables())

coord = tf.train.Coordinator()
@ -354,7 +365,7 @@ class GraphIOTest(tf.test.TestCase):
name = "my_batch"

with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
_, inputs = _read_keyed_batch_examples_shared_queue(
keys, inputs = _read_keyed_batch_examples_shared_queue(
filenames,
batch_size,
reader=tf.TextLineReader,
@ -362,6 +373,8 @@ class GraphIOTest(tf.test.TestCase):
num_epochs=1,
queue_capacity=queue_capacity,
name=name)
self.assertAllEqual((None,), keys.get_shape().as_list())
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run(tf.initialize_local_variables())

coord = tf.train.Coordinator()
@ -418,7 +431,7 @@ class GraphIOTest(tf.test.TestCase):

with tf.Graph().as_default() as g1, tf.Session(
server.target, graph=g1) as session:
_, inputs = _read_keyed_batch_examples_shared_queue(
keys, inputs = _read_keyed_batch_examples_shared_queue(
filenames,
batch_size,
reader=tf.TextLineReader,
@ -426,6 +439,8 @@ class GraphIOTest(tf.test.TestCase):
num_epochs=1,
queue_capacity=queue_capacity,
name=name)
self.assertAllEqual((None,), keys.get_shape().as_list())
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run(tf.initialize_local_variables())

# Run the three queues once manually.
@ -443,7 +458,7 @@ class GraphIOTest(tf.test.TestCase):

with tf.Graph().as_default() as g2, tf.Session(
server.target, graph=g2) as session:
_, inputs = _read_keyed_batch_examples_shared_queue(
keys, inputs = _read_keyed_batch_examples_shared_queue(
filenames,
batch_size,
reader=tf.TextLineReader,
@ -451,6 +466,8 @@ class GraphIOTest(tf.test.TestCase):
num_epochs=1,
queue_capacity=queue_capacity,
name=name)
self.assertAllEqual((None,), keys.get_shape().as_list())
self.assertAllEqual((None,), inputs.get_shape().as_list())

# Run the worker and the example queue.
self._run_queue(worker_file_name_queue_name, session)
@ -473,6 +490,7 @@ class GraphIOTest(tf.test.TestCase):
[filename], batch_size, reader=tf.TextLineReader,
randomize_input=False, num_epochs=1, queue_capacity=queue_capacity,
read_batch_size=10, name=name)
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run(tf.initialize_local_variables())

coord = tf.train.Coordinator()
@ -499,6 +517,8 @@ class GraphIOTest(tf.test.TestCase):
filename, batch_size,
reader=tf.TextLineReader, randomize_input=False,
num_epochs=1, queue_capacity=queue_capacity, name=name)
self.assertAllEqual((None,), keys.get_shape().as_list())
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run(tf.initialize_local_variables())

coord = tf.train.Coordinator()
@ -537,6 +557,9 @@ class GraphIOTest(tf.test.TestCase):
reader=tf.TextLineReader, randomize_input=False,
num_epochs=1, queue_capacity=queue_capacity,
parse_fn=parse_fn, name=name)
self.assertAllEqual((None,), keys.get_shape().as_list())
self.assertEqual(1, len(inputs))
self.assertAllEqual((None, 1), inputs["age"].get_shape().as_list())
session.run(tf.initialize_local_variables())

coord = tf.train.Coordinator()
@ -24,6 +24,7 @@ from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.session_bundle import exporter
from tensorflow.contrib.session_bundle import gc
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -53,7 +54,7 @@ def _get_saver():
else:
saver = None
if saver is None and variables.all_variables():
saver = tf_saver.Saver()
saver = tf_saver.Saver(write_version=saver_pb2.SaverDef.V1)
ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
return saver
@ -21,6 +21,7 @@ tensorflow/core/kernels/strided_slice_op_inst_4.cc
tensorflow/core/kernels/strided_slice_op_inst_3.cc
tensorflow/core/kernels/strided_slice_op_inst_2.cc
tensorflow/core/kernels/strided_slice_op_inst_1.cc
tensorflow/core/kernels/strided_slice_op_inst_0.cc
tensorflow/core/kernels/strided_slice_op.cc
tensorflow/core/kernels/stack_ops.cc
tensorflow/core/kernels/split_op.cc
@ -142,6 +143,7 @@ tensorflow/core/kernels/avgpooling_op.cc
tensorflow/core/kernels/argmax_op.cc
tensorflow/core/kernels/aggregate_ops.cc
tensorflow/core/kernels/dequantize_op.cc
tensorflow/core/kernels/meta_support.cc
tensorflow/core/kernels/quantization_utils.cc
tensorflow/core/kernels/quantize_down_and_shrink_range.cc
tensorflow/core/kernels/quantize_op.cc
@ -153,6 +155,7 @@ tensorflow/core/kernels/quantized_conv_ops.cc
tensorflow/core/kernels/quantized_matmul_op.cc
tensorflow/core/kernels/quantized_pooling_ops.cc
tensorflow/core/kernels/quantized_reshape_op.cc
tensorflow/core/kernels/requantization_range_op.cc
tensorflow/core/kernels/requantize.cc
tensorflow/core/ops/training_ops.cc
tensorflow/core/ops/string_ops.cc
@ -95,11 +95,6 @@ Certain metrics, such as streaming_mean or streaming_accuracy, can be weighted
via a `weights` argument. The `weights` tensor must be the same size as the
labels and predictions tensors and results in a weighted average of the metric.

Other metrics, such as streaming_recall, streaming_precision, and streaming_auc,
are not well defined with regard to weighted samples. However, a binary
`ignore_mask` argument can be used to ignore certain values at graph executation
time.

## Metric `Ops`

@@streaming_accuracy
@ -23,7 +23,6 @@ from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_args
from tensorflow.contrib.framework import tensor_util
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.metrics.python.ops import confusion_matrix_ops
@ -41,40 +40,6 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables


IGNORE_MASK_DATE = '2016-10-19'
IGNORE_MASK_INSTRUCTIONS = (
'`ignore_mask` is being deprecated. Instead use `weights` with values 0.0 '
'and 1.0 to mask values. For example, `weights=tf.logical_not(mask)`.')


def _mask_weights(mask=None, weights=None):
"""Mask a given set of weights.

Elements are included when the corresponding `mask` element is `False`, and
excluded otherwise.

Args:
mask: An optional, `bool` `Tensor`.
weights: An optional `Tensor` whose shape matches `mask` if `mask` is not
`None`.

Returns:
Masked weights if `mask` and `weights` are not `None`, weights equivalent to
`mask` if `weights` is `None`, and otherwise `weights`.

Raises:
ValueError: If `weights` and `mask` are not `None` and have mismatched
shapes.
"""
if mask is not None:
check_ops.assert_type(mask, dtypes.bool)
if weights is None:
weights = array_ops.ones_like(mask, dtype=dtypes.float32)
weights = math_ops.cast(math_ops.logical_not(mask), weights.dtype) * weights

return weights


def _safe_div(numerator, denominator, name):
"""Divides two values, returning 0 if the denominator is <= 0.
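For reference, a minimal sketch of what callers can do now that the `_mask_weights` helper above is removed, assuming the same semantics it implemented (a `False` mask element means keep, `True` means ignore); the helper name here is hypothetical:

    import tensorflow as tf

    def mask_to_weights(mask, weights=None):
      # Fold a boolean ignore-mask into 0/1 weights, as _mask_weights did.
      keep = tf.cast(tf.logical_not(mask), tf.float32)
      return keep if weights is None else keep * weights

    # Example: precision, update_op = tf.contrib.metrics.streaming_precision(
    #     predictions, labels, weights=mask_to_weights(mask))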
@ -516,8 +481,7 @@ def streaming_accuracy(predictions, labels, weights=None,
updates_collections, name or 'accuracy')


@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_precision(predictions, labels, ignore_mask=None, weights=None,
def streaming_precision(predictions, labels, weights=None,
metrics_collections=None, updates_collections=None,
name=None):
"""Computes the precision of the predictions with respect to the labels.
@ -534,14 +498,11 @@ def streaming_precision(predictions, labels, ignore_mask=None, weights=None,
`weights`.

If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Alternatively, if `ignore_mask` is not `None`, then mask values where
`ignore_mask` is `True`.

Args:
predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
labels: The ground truth values, a `bool` `Tensor` whose dimensions must
match `predictions`.
ignore_mask: An optional, `bool` `Tensor` whose shape matches `predictions`.
weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
metrics_collections: An optional list of collections that `precision` should
be added to.
@ -558,9 +519,8 @@ def streaming_precision(predictions, labels, ignore_mask=None, weights=None,

Raises:
ValueError: If `predictions` and `labels` have mismatched shapes, or if
`ignore_mask` is not `None` and its shape doesn't match `predictions`, or
if `weights` is not `None` and its shape doesn't match `predictions`, or
if either `metrics_collections` or `updates_collections` are not a list or
`weights` is not `None` and its shape doesn't match `predictions`, or if
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
with variable_scope.variable_scope(
@ -570,7 +530,6 @@ def streaming_precision(predictions, labels, ignore_mask=None, weights=None,
predictions, labels, weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())

weights = _mask_weights(ignore_mask, weights)
true_positives, true_positives_update_op = _streaming_true_positives(
predictions, labels, weights, metrics_collections=None,
updates_collections=None, name=None)
@ -599,8 +558,7 @@ def streaming_precision(predictions, labels, ignore_mask=None, weights=None,
return precision, update_op


@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_recall(predictions, labels, ignore_mask=None, weights=None,
def streaming_recall(predictions, labels, weights=None,
metrics_collections=None, updates_collections=None,
name=None):
"""Computes the recall of the predictions with respect to the labels.
@ -615,14 +573,11 @@ def streaming_recall(predictions, labels, ignore_mask=None, weights=None,
weights each prediction by the corresponding value in `weights`.

If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Alternatively, if `ignore_mask` is not `None`, then mask values where
`ignore_mask` is `True`.

Args:
predictions: The predicted values, a `bool` `Tensor` of arbitrary shape.
labels: The ground truth values, a `bool` `Tensor` whose dimensions must
match `predictions`.
ignore_mask: An optional, `bool` `Tensor` whose shape matches `predictions`.
weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
metrics_collections: An optional list of collections that `recall` should
be added to.
@ -639,9 +594,8 @@ def streaming_recall(predictions, labels, ignore_mask=None, weights=None,

Raises:
ValueError: If `predictions` and `labels` have mismatched shapes, or if
`ignore_mask` is not `None` and its shape doesn't match `predictions`, or
if `weights` is not `None` and its shape doesn't match `predictions`, or
if either `metrics_collections` or `updates_collections` are not a list or
`weights` is not `None` and its shape doesn't match `predictions`, or if
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
with variable_scope.variable_scope(name, 'recall', [predictions, labels]):
@ -649,7 +603,6 @@ def streaming_recall(predictions, labels, ignore_mask=None, weights=None,
predictions, labels, weights)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())

weights = _mask_weights(ignore_mask, weights)
true_positives, true_positives_update_op = _streaming_true_positives(
predictions, labels, weights, metrics_collections=None,
updates_collections=None, name=None)
@ -1235,10 +1188,9 @@ def _at_k_name(name, k=None, class_id=None):

@deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, '
'and reshape labels from [batch_size] to [batch_size, 1].')
@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_recall_at_k(predictions, labels, k, ignore_mask=None,
weights=None, metrics_collections=None,
updates_collections=None, name=None):
def streaming_recall_at_k(predictions, labels, k, weights=None,
metrics_collections=None, updates_collections=None,
name=None):
"""Computes the recall@k of the predictions with respect to dense labels.

The `streaming_recall_at_k` function creates two local variables, `total` and
@ -1255,15 +1207,12 @@ def streaming_recall_at_k(predictions, labels, k, ignore_mask=None,
increments `count` with the reduced sum of `weights`.

If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Alternatively, if `ignore_mask` is not `None`, then mask values where
`ignore_mask` is `True`.

Args:
predictions: A floating point tensor of dimension [batch_size, num_classes]
labels: A tensor of dimension [batch_size] whose type is in `int32`,
`int64`.
k: The number of top elements to look at for computing recall.
ignore_mask: An optional, `bool` `Tensor` whose shape matches `predictions`.
weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
metrics_collections: An optional list of collections that `recall_at_k`
should be added to.
@ -1279,26 +1228,23 @@ def streaming_recall_at_k(predictions, labels, k, ignore_mask=None,

Raises:
ValueError: If `predictions` and `labels` have mismatched shapes, or if
`ignore_mask` is not `None` and its shape doesn't match `predictions`, or
if `weights` is not `None` and its shape doesn't match `predictions`, or
if either `metrics_collections` or `updates_collections` are not a list or
`weights` is not `None` and its shape doesn't match `predictions`, or if
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
in_top_k = math_ops.to_float(nn.in_top_k(predictions, labels, k))
return streaming_mean(in_top_k,
_mask_weights(ignore_mask, weights),
weights,
metrics_collections,
updates_collections,
name or _at_k_name('recall', k))


# TODO(ptucker): Validate range of values in labels?
@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_sparse_recall_at_k(predictions,
labels,
k,
class_id=None,
ignore_mask=None,
weights=None,
metrics_collections=None,
updates_collections=None,
@ -1328,8 +1274,6 @@ def streaming_sparse_recall_at_k(predictions,
`false_negative_at_<k>` using these values.

If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Alternatively, if `ignore_mask` is not `None`, then mask values where
`ignore_mask` is `True`.

Args:
predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
@ -1347,8 +1291,6 @@ def streaming_sparse_recall_at_k(predictions,
class_id: Integer class ID for which we want binary metrics. This should be
in range [0, num_classes), where num_classes is the last dimension of
`predictions`. If class_id is outside this range, the method returns NAN.
ignore_mask: An optional, `bool` `Tensor` whose shape is broadcastable to
the the first [D1, ... DN] dimensions of `predictions` and `labels`.
weights: An optional `Tensor` whose shape is broadcastable to the the first
[D1, ... DN] dimensions of `predictions` and `labels`.
metrics_collections: An optional list of collections that values should
@ -1365,16 +1307,14 @@ def streaming_sparse_recall_at_k(predictions,
`recall`.

Raises:
ValueError: If `ignore_mask` is not `None` and its shape doesn't match
`predictions`, or if `weights` is not `None` and its shape doesn't match
`predictions`, or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
ValueError: If `weights` is not `None` and its shape doesn't match
`predictions`, or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
"""
default_name = _at_k_name('recall', k, class_id=class_id)
with ops.name_scope(name, default_name, (predictions, labels)) as scope:
_, top_k_idx = nn.top_k(predictions, k)
top_k_idx = math_ops.to_int64(top_k_idx)
weights = _mask_weights(ignore_mask, weights)
tp, tp_update = _streaming_sparse_true_positive_at_k(
predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
weights=weights)
@ -1396,7 +1336,6 @@ def _streaming_sparse_precision_at_k(top_k_idx,
labels,
k=None,
class_id=None,
ignore_mask=None,
weights=None,
metrics_collections=None,
updates_collections=None,
@ -1423,8 +1362,6 @@ def _streaming_sparse_precision_at_k(top_k_idx,
in range [0, num_classes), where num_classes is the last dimension of
`predictions`. If `class_id` is outside this range, the method returns
NAN.
ignore_mask: An optional, `bool` `Tensor` whose shape is broadcastable to
the the first [D1, ... DN] dimensions of `predictions` and `labels`.
weights: An optional `Tensor` whose shape is broadcastable to the the first
[D1, ... DN] dimensions of `predictions` and `labels`.
metrics_collections: An optional list of collections that values should
@ -1441,13 +1378,11 @@ def _streaming_sparse_precision_at_k(top_k_idx,
`precision`.

Raises:
ValueError: If `ignore_mask` is not `None` and its shape doesn't match
`predictions`, or if `weights` is not `None` and its shape doesn't match
ValueError: If `weights` is not `None` and its shape doesn't match
`predictions`, or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
"""
top_k_idx = math_ops.to_int64(top_k_idx)
weights = _mask_weights(ignore_mask, weights)
tp, tp_update = _streaming_sparse_true_positive_at_k(
predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
weights=weights)
@ -1466,12 +1401,10 @@ def _streaming_sparse_precision_at_k(top_k_idx,


# TODO(ptucker): Validate range of values in labels?
@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_sparse_precision_at_k(predictions,
labels,
k,
class_id=None,
ignore_mask=None,
weights=None,
metrics_collections=None,
updates_collections=None,
@ -1502,8 +1435,6 @@ def streaming_sparse_precision_at_k(predictions,
`false_positive_at_<k>` using these values.

If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Alternatively, if `ignore_mask` is not `None`, then mask values where
`ignore_mask` is `True`.

Args:
predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
@ -1522,8 +1453,6 @@ def streaming_sparse_precision_at_k(predictions,
in range [0, num_classes], where num_classes is the last dimension of
`predictions`. If `class_id` is outside this range, the method returns
NAN.
ignore_mask: An optional, `bool` `Tensor` whose shape is broadcastable to
the the first [D1, ... DN] dimensions of `predictions` and `labels`.
weights: An optional `Tensor` whose shape is broadcastable to the the first
[D1, ... DN] dimensions of `predictions` and `labels`.
metrics_collections: An optional list of collections that values should
@ -1540,21 +1469,19 @@ def streaming_sparse_precision_at_k(predictions,
`precision`.

Raises:
ValueError: If `ignore_mask` is not `None` and its shape doesn't match
`predictions`, or if `weights` is not `None` and its shape doesn't match
ValueError: If `weights` is not `None` and its shape doesn't match
`predictions`, or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
"""
default_name = _at_k_name('precision', k, class_id=class_id)
with ops.name_scope(name, default_name,
(predictions, labels, ignore_mask, weights)) as scope:
(predictions, labels, weights)) as scope:
_, top_k_idx = nn.top_k(predictions, k)
return _streaming_sparse_precision_at_k(
top_k_idx=top_k_idx,
labels=labels,
k=k,
class_id=class_id,
ignore_mask=ignore_mask,
weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections,
@ -1562,11 +1489,9 @@ def streaming_sparse_precision_at_k(predictions,


# TODO(ptucker): Validate range of values in labels?
@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_sparse_precision_at_top_k(top_k_predictions,
labels,
class_id=None,
ignore_mask=None,
weights=None,
metrics_collections=None,
updates_collections=None,
@ -1595,8 +1520,6 @@ def streaming_sparse_precision_at_top_k(top_k_predictions,
`false_positive_at_k` using these values.

If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Alternatively, if `ignore_mask` is not `None`, then mask values where
`ignore_mask` is `True`.

Args:
top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
@ -1614,8 +1537,6 @@ def streaming_sparse_precision_at_top_k(top_k_predictions,
in range [0, num_classes), where num_classes is the last dimension of
`predictions`. If `class_id` is outside this range, the method returns
NAN.
ignore_mask: An optional, `bool` `Tensor` whose shape is broadcastable to
the the first [D1, ... DN] dimensions of `predictions` and `labels`.
weights: An optional `Tensor` whose shape is broadcastable to the the first
[D1, ... DN] dimensions of `predictions` and `labels`.
metrics_collections: An optional list of collections that values should
@ -1632,8 +1553,7 @@ def streaming_sparse_precision_at_top_k(top_k_predictions,
`precision`.

Raises:
ValueError: If `ignore_mask` is not `None` and its shape doesn't match
`predictions`, or if `weights` is not `None` and its shape doesn't match
ValueError: If `weights` is not `None` and its shape doesn't match
`predictions`, or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
ValueError: If `top_k_predictions` has rank < 2.
@ -1641,7 +1561,7 @@ def streaming_sparse_precision_at_top_k(top_k_predictions,
default_name = _at_k_name('precision', class_id=class_id)
with ops.name_scope(
name, default_name,
(top_k_predictions, labels, ignore_mask, weights)) as scope:
(top_k_predictions, labels, weights)) as scope:
rank = array_ops.rank(top_k_predictions)
check_rank_op = control_flow_ops.Assert(
math_ops.greater_equal(rank, 2),
@ -1651,7 +1571,6 @@ def streaming_sparse_precision_at_top_k(top_k_predictions,
top_k_idx=top_k_predictions,
labels=labels,
class_id=class_id,
ignore_mask=ignore_mask,
weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections,
@ -2760,8 +2679,7 @@ def streaming_mean_cosine_distance(predictions, labels, dim, weights=None,
return mean_distance, update_op


@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_percentage_less(values, threshold, ignore_mask=None, weights=None,
def streaming_percentage_less(values, threshold, weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@ -2778,13 +2696,10 @@ def streaming_percentage_less(values, threshold, ignore_mask=None, weights=None,
`percentage`.

If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Alternatively, if `ignore_mask` is not `None`, then mask values where
`ignore_mask` is `True`.

Args:
values: A numeric `Tensor` of arbitrary size.
threshold: A scalar threshold.
ignore_mask: An optional, `bool` `Tensor` whose shape matches `values`.
weights: An optional `Tensor` whose shape is broadcastable to `values`.
metrics_collections: An optional list of collections that the metric
value variable should be added to.
@ -2799,23 +2714,21 @@ def streaming_percentage_less(values, threshold, ignore_mask=None, weights=None,
appropriately.

Raises:
ValueError: If `ignore_mask` is not `None` and its shape doesn't match
`values`, or if `weights` is not `None` and its shape doesn't match
`values`, or if either `metrics_collections` or `updates_collections` are
not a list or tuple.
ValueError: If `weights` is not `None` and its shape doesn't match `values`,
or if either `metrics_collections` or `updates_collections` are not a list
or tuple.
"""
is_below_threshold = math_ops.to_float(math_ops.less(values, threshold))
return streaming_mean(is_below_threshold, _mask_weights(ignore_mask, weights),
return streaming_mean(is_below_threshold,
weights,
metrics_collections,
updates_collections,
name or 'percentage_below_threshold')


@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_mean_iou(predictions,
labels,
num_classes,
ignore_mask=None,
weights=None,
metrics_collections=None,
updates_collections=None,
@ -2834,8 +2747,6 @@ def streaming_mean_iou(predictions,
`update_op` operation that updates these variables and returns the `mean_iou`.

If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Alternatively, if `ignore_mask` is not `None`, then mask values where
`ignore_mask` is `True`.

Args:
predictions: A tensor of prediction results for semantic labels, whose
@ -2846,7 +2757,6 @@ def streaming_mean_iou(predictions,
num_classes: The possible number of labels the prediction task can
have. This value must be provided, since a confusion matrix of
dimension = [num_classes, num_classes] will be allocated.
ignore_mask: An optional, `bool` `Tensor` whose shape matches `predictions`.
weights: An optional `Tensor` whose shape is broadcastable to `predictions`.
metrics_collections: An optional list of collections that `mean_iou`
should be added to.
@ -2860,9 +2770,8 @@ def streaming_mean_iou(predictions,

Raises:
ValueError: If `predictions` and `labels` have mismatched shapes, or if
`ignore_mask` is not `None` and its shape doesn't match `predictions`, or
if `weights` is not `None` and its shape doesn't match `predictions`, or
if either `metrics_collections` or `updates_collections` are not a list or
`weights` is not `None` and its shape doesn't match `predictions`, or if
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
with variable_scope.variable_scope(name, 'mean_iou', [predictions, labels]):
@ -2888,7 +2797,6 @@ def streaming_mean_iou(predictions,
if labels_rank > 1:
labels = array_ops.reshape(labels, [-1])

weights = _mask_weights(ignore_mask, weights)
if weights is not None:
weights_rank = weights.get_shape().ndims
if weights_rank > 1:
@ -671,18 +671,6 @@ class StreamingPrecisionTest(tf.test.TestCase):
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, precision.eval())

def testMasked(self):
predictions = tf.constant([1, 0, 1, 0, 1], shape=(1, 5))
labels = tf.constant([0, 1, 1, 0, 1], shape=(1, 5))
mask = tf.constant([False, False, False, False, True], shape=(1, 5))
precision, update_op = metrics.streaming_precision(
predictions, labels, ignore_mask=mask)

with self.test_session() as sess:
sess.run(tf.initialize_local_variables())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, precision.eval())

def testWeighted1d(self):
predictions = tf.constant([[1, 0, 1, 0], [1, 0, 1, 0]])
labels = tf.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
@ -838,18 +826,6 @@ class StreamingRecallTest(tf.test.TestCase):
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, recall.eval())

def testMasked(self):
predictions = tf.constant([1, 0, 1, 0, 1], shape=(1, 5))
labels = tf.constant([0, 1, 1, 0, 1], shape=(1, 5))
mask = tf.constant([False, False, False, False, True], shape=(1, 5))
recall, update_op = metrics.streaming_recall(
predictions, labels, ignore_mask=mask)

with self.test_session() as sess:
sess.run(tf.initialize_local_variables())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, recall.eval())

def testWeighted1d(self):
predictions = tf.constant([[1, 0, 1, 0], [0, 1, 0, 1]])
labels = tf.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
@ -1737,15 +1713,13 @@ class StreamingRecallAtKTest(tf.test.TestCase):
dtype=tf.float32)
labels = tf.constant(
self._np_labels, shape=(self._batch_size,), dtype=tf.int64)
weights = tf.constant([0, 1, 1, 1], shape=(self._batch_size,),
weights = tf.constant([0, 1, 0, 1], shape=(self._batch_size,),
dtype=tf.float32)
mask = tf.constant([False, False, True, False], shape=(self._batch_size,),
dtype=tf.bool)
recall, update_op = metrics.streaming_recall_at_k(
predictions, labels, k=2, ignore_mask=mask, weights=weights)
predictions, labels, k=2, weights=weights)
sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k(
predictions, tf.reshape(labels, (self._batch_size, 1)), k=2,
ignore_mask=mask, weights=weights)
weights=weights)

with self.test_session() as sess:
sess.run(tf.initialize_local_variables())
@ -1763,16 +1737,13 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
k,
expected,
class_id=None,
ignore_mask=None,
weights=None):
with tf.Graph().as_default() as g, self.test_session(g):
if ignore_mask is not None:
ignore_mask = tf.constant(ignore_mask, tf.bool)
if weights is not None:
weights = tf.constant(weights, tf.float32)
metric, update = metrics.streaming_sparse_precision_at_k(
predictions=tf.constant(predictions, tf.float32), labels=labels,
k=k, class_id=class_id, ignore_mask=ignore_mask, weights=weights)
k=k, class_id=class_id, weights=weights)

# Fails without initialized vars.
self.assertRaises(tf.OpError, metric.eval)
@ -1792,17 +1763,13 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
labels,
expected,
class_id=None,
ignore_mask=None,
weights=None):
with tf.Graph().as_default() as g, self.test_session(g):
if ignore_mask is not None:
ignore_mask = tf.constant(ignore_mask, tf.bool)
if weights is not None:
weights = tf.constant(weights, tf.float32)
metric, update = metrics.streaming_sparse_precision_at_top_k(
top_k_predictions=tf.constant(top_k_predictions, tf.int32),
labels=labels, class_id=class_id, ignore_mask=ignore_mask,
weights=weights)
labels=labels, class_id=class_id, weights=weights)

# Fails without initialized vars.
self.assertRaises(tf.OpError, metric.eval)
@ -1821,11 +1788,8 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
predictions,
labels,
k,
expected,
ignore_mask=None):
expected):
with tf.Graph().as_default() as g, self.test_session(g):
if ignore_mask is not None:
ignore_mask = tf.constant(ignore_mask, tf.bool)
predictions = tf.constant(predictions, tf.float32)
metric = metric_ops.sparse_average_precision_at_k(
predictions, labels, k)
@ -2305,11 +2269,9 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
top_k_predictions, labels, expected=NAN, class_id=class_id,
weights=[[0, 0], [0, 0]])
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=NAN, ignore_mask=[[False], [True]],
weights=[[0], [1]])
predictions, labels, k=5, expected=NAN, weights=[[0], [0]])
self._test_streaming_sparse_precision_at_top_k(
top_k_predictions, labels, expected=NAN,
ignore_mask=[[False], [True]], weights=[[0], [1]])
top_k_predictions, labels, expected=NAN, weights=[[0], [0]])
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=NAN, weights=[[0, 0], [0, 0]])
self._test_streaming_sparse_precision_at_top_k(
@ -2342,34 +2304,34 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
# Class 2: 2 predictions, both correct.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=2.0 / 2.0, class_id=2,
ignore_mask=[[False], [False]], weights=[[1], [0]])
weights=[[1], [0]])
self._test_streaming_sparse_precision_at_top_k(
top_k_predictions, labels, expected=2.0 / 2.0, class_id=2,
ignore_mask=[[False], [False]], weights=[[1], [0]])
weights=[[1], [0]])

# Class 2: 2 predictions, both correct.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=2.0 / 2.0, class_id=2,
ignore_mask=[[False], [False]], weights=[[0], [1]])
weights=[[0], [1]])
self._test_streaming_sparse_precision_at_top_k(
top_k_predictions, labels, expected=2.0 / 2.0, class_id=2,
ignore_mask=[[False], [False]], weights=[[0], [1]])
weights=[[0], [1]])

# Class 7: 1 incorrect prediction.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=0.0 / 1.0, class_id=7,
ignore_mask=[[False], [True]], weights=[[1], [1]])
weights=[[1], [0]])
self._test_streaming_sparse_precision_at_top_k(
top_k_predictions, labels, expected=0.0 / 1.0, class_id=7,
ignore_mask=[[False], [True]], weights=[[1], [1]])
weights=[[1], [0]])

# Class 7: 1 correct prediction.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=1.0 / 1.0, class_id=7,
ignore_mask=[[True], [False]], weights=[[1], [1]])
weights=[[0], [1]])
self._test_streaming_sparse_precision_at_top_k(
top_k_predictions, labels, expected=1.0 / 1.0, class_id=7,
ignore_mask=[[True], [False]], weights=[[1], [1]])
weights=[[0], [1]])

# Class 7: no predictions.
self._test_streaming_sparse_precision_at_k(
@ -2409,17 +2371,13 @@ class StreamingSparseRecallTest(tf.test.TestCase):
k,
expected,
class_id=None,
ignore_mask=None,
weights=None):
with tf.Graph().as_default() as g, self.test_session(g):
if ignore_mask is not None:
ignore_mask = tf.constant(ignore_mask, tf.bool)
if weights is not None:
weights = tf.constant(weights, tf.float32)
metric, update = metrics.streaming_sparse_recall_at_k(
predictions=tf.constant(predictions, tf.float32),
labels=labels, k=k, class_id=class_id, ignore_mask=ignore_mask,
weights=weights)
labels=labels, k=k, class_id=class_id, weights=weights)

# Fails without initialized vars.
self.assertRaises(tf.OpError, metric.eval)
@ -2740,8 +2698,7 @@ class StreamingSparseRecallTest(tf.test.TestCase):
predictions, labels, k=5, expected=NAN, class_id=class_id,
weights=[[0, 0], [0, 0]])
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=NAN, ignore_mask=[[False], [True]],
weights=[[0], [1]])
predictions, labels, k=5, expected=NAN, weights=[[0], [0]])
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=NAN, weights=[[0, 0], [0, 0]])

@ -2764,22 +2721,22 @@ class StreamingSparseRecallTest(tf.test.TestCase):
# Class 2: 2 labels, both correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=2.0 / 2.0, class_id=2,
ignore_mask=[[False], [False]], weights=[[1], [0]])
weights=[[1], [0]])

# Class 2: 2 labels, both correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=2.0 / 2.0, class_id=2,
ignore_mask=[[False], [False]], weights=[[0], [1]])
weights=[[0], [1]])

# Class 7: 1 label, correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=1.0 / 1.0, class_id=7,
ignore_mask=[[True], [False]], weights=[[1], [1]])
weights=[[0], [1]])

# Class 7: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=0.0 / 1.0, class_id=7,
ignore_mask=[[False], [True]], weights=[[1], [1]])
weights=[[1], [0]])

# Class 7: 2 labels, 1 correct.
self._test_streaming_sparse_recall_at_k(
@ -3660,16 +3617,14 @@ class PcntBelowThreshTest(tf.test.TestCase):
def testSomePresentOneUpdate(self):
with self.test_session() as sess:
values = tf.constant([2, 4, 6, 8], shape=(1, 4), dtype=tf.float32)
mask = tf.constant([False, True, False, False], shape=(1, 4),
dtype=tf.bool)
weights = tf.constant([1, 1, 0, 1], shape=(1, 4), dtype=tf.float32)
weights = tf.constant([1, 0, 0, 1], shape=(1, 4), dtype=tf.float32)

pcnt0, update_op0 = metrics.streaming_percentage_less(
values, 100, ignore_mask=mask, weights=weights, name='high')
values, 100, weights=weights, name='high')
pcnt1, update_op1 = metrics.streaming_percentage_less(
values, 7, ignore_mask=mask, weights=weights, name='medium')
values, 7, weights=weights, name='medium')
pcnt2, update_op2 = metrics.streaming_percentage_less(
values, 1, ignore_mask=mask, weights=weights, name='low')
values, 1, weights=weights, name='low')

sess.run(tf.initialize_local_variables())
self.assertListEqual([1.0, 0.5, 0.0],
@ -3712,22 +3667,6 @@ class StreamingMeanIOUTest(tf.test.TestCase):
metrics.streaming_mean_iou(
predictions, labels, num_classes=2)

def testLabelsAndIgnoreMaskOfDifferentSizeRaisesValueError(self):
predictions = tf.ones([10])
labels = tf.ones([10])
ignore_mask = tf.cast(tf.ones([9]), tf.bool)
with self.assertRaises(ValueError):
metrics.streaming_mean_iou(
predictions, labels, num_classes=2, ignore_mask=ignore_mask)

def testIgnoreMaskIsNotBooleanRaisesTypeError(self):
predictions = tf.ones([10])
labels = tf.ones([10])
ignore_mask = tf.ones([10])
with self.assertRaises(TypeError):
metrics.streaming_mean_iou(
predictions, labels, num_classes=2, ignore_mask=ignore_mask)

def testLabelsAndWeightsOfDifferentSizeRaisesValueError(self):
predictions = tf.ones([10])
labels = tf.ones([10])
@ -3810,29 +3749,18 @@ class StreamingMeanIOUTest(tf.test.TestCase):
_enqueue_vector(sess, labels_queue, [1])
labels = labels_queue.dequeue()

# Create the queue that populates the ignore_masks.
ignore_masks_queue = tf.FIFOQueue(6, dtypes=tf.bool, shapes=(1, 1))
_enqueue_vector(sess, ignore_masks_queue, [False])
_enqueue_vector(sess, ignore_masks_queue, [False])
_enqueue_vector(sess, ignore_masks_queue, [False])
_enqueue_vector(sess, ignore_masks_queue, [True])
_enqueue_vector(sess, ignore_masks_queue, [False])
_enqueue_vector(sess, ignore_masks_queue, [False])
ignore_mask = ignore_masks_queue.dequeue()

# Create the queue that populates the weights.
weights_queue = tf.FIFOQueue(6, dtypes=tf.float32, shapes=(1, 1))
_enqueue_vector(sess, weights_queue, [1.0])
_enqueue_vector(sess, weights_queue, [1.0])
_enqueue_vector(sess, weights_queue, [1.0])
_enqueue_vector(sess, weights_queue, [1.0])
_enqueue_vector(sess, weights_queue, [0.0])
_enqueue_vector(sess, weights_queue, [1.0])
_enqueue_vector(sess, weights_queue, [0.0])
weights = weights_queue.dequeue()

miou, update_op = metrics.streaming_mean_iou(
predictions, labels, num_classes, ignore_mask=ignore_mask,
weights=weights)
predictions, labels, num_classes, weights=weights)

sess.run(tf.initialize_local_variables())
for _ in range(6):
@ -3920,13 +3848,12 @@ class StreamingMeanIOUTest(tf.test.TestCase):
labels = tf.concat(0, [tf.constant(0, shape=[3]),
tf.constant(1, shape=[7])])
num_classes = 2
mask = tf.concat(0, [tf.constant(False, shape=[9]),
tf.constant(True, shape=[1])])
weights = tf.concat(0, [tf.constant(0, shape=[1]),
tf.constant(1, shape=[9])])
tf.constant(1, shape=[8]),
tf.constant(0, shape=[1])])
with self.test_session() as sess:
miou, update_op = metrics.streaming_mean_iou(
predictions, labels, num_classes, ignore_mask=mask, weights=weights)
predictions, labels, num_classes, weights=weights)
sess.run(tf.initialize_local_variables())
self.assertAllEqual([[2, 2], [0, 4]], update_op.eval())
desired_miou = np.mean([2./4., 4./6.])
@ -100,7 +100,7 @@ class ExternalOptimizerInterface(object):
accumulated_dims[1:])]

def minimize(self, session=None, feed_dict=None, fetches=None,
step_callback=None, loss_callback=None, grad_callback=None):
step_callback=None, loss_callback=None):
"""Minimize a scalar `Tensor`.

Variables subject to optimization are updated in-place at the end of
@ -113,14 +113,13 @@ class ExternalOptimizerInterface(object):
Args:
session: A `Session` instance.
feed_dict: A feed dict to be passed to calls to `session.run`.
fetches: A list of `Tensor`s to fetch and supply to `loss_callback` and
`grad_callback` as positional arguments.
fetches: A list of `Tensor`s to fetch and supply to `loss_callback`
as positional arguments.
step_callback: A function to be called at each optimization step;
arguments are the current values of all optimization variables
flattened into a single vector.
loss_callback: A function to be called every time the loss and gradients
are computed, with evaluated fetches supplied as positional arguments.
grad_callback: Deprecated.
"""
session = session or ops.get_default_session()
feed_dict = feed_dict or {}
@ -128,9 +127,6 @@ class ExternalOptimizerInterface(object):

loss_callback = loss_callback or (lambda *fetches: None)
step_callback = step_callback or (lambda xk: None)
# TODO(chapelle): Remove grad_callback (b/30590858)
if grad_callback:
logging.warn('grad_callback is deprecated. Please use loss_callback.')

# Construct loss function and associated gradient.
loss_grad_func = self._make_eval_func(
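For reference, a minimal sketch of the remaining callback path after grad_callback is removed, assuming the tf.contrib.opt Scipy wrapper of this release (the variable, loss, and method choices are illustrative):

    import tensorflow as tf

    x = tf.Variable([3.0, 4.0])
    loss = tf.reduce_sum(tf.square(x))
    optimizer = tf.contrib.opt.ScipyOptimizerInterface(loss, method='L-BFGS-B')

    def report(loss_value):
      # fetches are forwarded to loss_callback as positional arguments.
      print(loss_value)

    with tf.Session() as sess:
      sess.run(tf.initialize_all_variables())
      optimizer.minimize(sess, fetches=[loss], loss_callback=report)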
@ -62,7 +62,8 @@ from tensorflow.python.training import saver
class MovingAverageOptimizer(optimizer.Optimizer):
"""Optimizer wrapper that maintains a moving average of parameters."""

def __init__(self, opt, average_decay=0.9999, sequential_update=True):
def __init__(self, opt, average_decay=0.9999, num_updates=None,
sequential_update=True):
"""Construct a new MovingAverageOptimizer.

Args:
@ -70,6 +71,8 @@ class MovingAverageOptimizer(optimizer.Optimizer):
average_decay: Float. Decay to use to maintain the moving averages
of trained variables.
See tf.train.ExponentialMovingAverage for details.
num_updates: Optional count of number of updates applied to variables.
See tf.train.ExponentialMovingAverage for details.
sequential_update: Bool. If False, will compute the moving average at the
same time as the model is updated, potentially doing
benign data races.
@ -77,7 +80,8 @@ class MovingAverageOptimizer(optimizer.Optimizer):
updates.
"""
self._optimizer = opt
self._ema = moving_averages.ExponentialMovingAverage(average_decay)
self._ema = moving_averages.ExponentialMovingAverage(
average_decay, num_updates=num_updates)
self._variable_map = None
self._sequential_update = sequential_update
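For reference, a minimal construction sketch for the new num_updates argument, which is simply forwarded to tf.train.ExponentialMovingAverage (the step counter and learning rate below are illustrative):

    import tensorflow as tf
    from tensorflow.contrib.opt import MovingAverageOptimizer

    global_step = tf.Variable(0, trainable=False, name='global_step')
    base_opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
    # With num_updates set, the EMA ramps its effective decay early in
    # training instead of always using average_decay.
    opt = MovingAverageOptimizer(base_opt, average_decay=0.9999,
                                 num_updates=global_step)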
@ -181,6 +181,24 @@ tf_gen_op_libs(
op_lib_names = ["lstm_ops"],
)

tf_kernel_library(
name = "gru_ops_kernels",
srcs = [
"kernels/blas_gemm.cc",
"kernels/blas_gemm.h",
],
gpu_srcs = [
"kernels/blas_gemm.h",
],
prefix = "kernels/gru_ops",
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:eigen_helpers",
"//third_party/eigen3",
],
)

tf_kernel_library(
name = "lstm_ops_kernels",
srcs = [
@ -37,7 +37,6 @@ perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
namespace functor {
template <typename T>
void TensorCuBlasGemm<T>::operator()(OpKernelContext* ctx,
perftools::gputools::Stream* stream,
bool transa, bool transb, uint64 m,
uint64 n, uint64 k, T alpha, const T* a,
int lda, const T* b, int ldb, T beta, T* c,
@ -52,7 +51,8 @@ void TensorCuBlasGemm<T>::operator()(OpKernelContext* ctx,
auto c_ptr = AsDeviceMemory(c);

bool blas_launch_status =
stream
ctx->op_device_context()
->stream()
->ThenBlasGemm(trans[transa], trans[transb], m, n, k, alpha, a_ptr,
lda, b_ptr, ldb, beta, &c_ptr, ldc)
.ok();
@ -21,22 +21,15 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/eigen_activations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace perftools {
|
||||
namespace gputools {
|
||||
class Stream;
|
||||
} // end namespace gputools
|
||||
} // end namespace perftools
|
||||
|
||||
namespace tensorflow {
|
||||
class OpKernelContext;
|
||||
namespace functor {
|
||||
|
||||
template <typename T>
|
||||
struct TensorCuBlasGemm {
|
||||
void operator()(OpKernelContext* ctx, perftools::gputools::Stream* stream,
|
||||
bool transa, bool transb, uint64 m, uint64 n, uint64 k,
|
||||
T alpha, const T* a, int lda, const T* b, int ldb, T beta,
|
||||
T* c, int ldc);
|
||||
void operator()(OpKernelContext* ctx, bool transa, bool transb, uint64 m,
|
||||
uint64 n, uint64 k, T alpha, const T* a, int lda, const T* b,
|
||||
int ldb, T beta, T* c, int ldc);
|
||||
};
|
||||
|
||||
template <typename Device, typename T, bool USE_CUBLAS>
|
||||
@ -44,16 +37,15 @@ struct TensorBlasGemm;
|
||||
|
||||
template <typename Device, typename T>
|
||||
struct TensorBlasGemm<Device, T, true /* USE_CUBLAS */> {
|
||||
static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
|
||||
const Device& d, bool transa, bool transb, T alpha,
|
||||
typename TTypes<T>::ConstMatrix a,
|
||||
static void compute(OpKernelContext* ctx, const Device& d, bool transa,
|
||||
bool transb, T alpha, typename TTypes<T>::ConstMatrix a,
|
||||
typename TTypes<T>::ConstMatrix b, T beta,
|
||||
typename TTypes<T>::Matrix c) {
|
||||
int64 m = c.dimensions()[0];
|
||||
int64 n = c.dimensions()[1];
|
||||
int64 k = transa ? a.dimensions()[0] : a.dimensions()[1];
|
||||
|
||||
TensorCuBlasGemm<T>()(ctx, stream, transb, transa, n, m, k, alpha, b.data(),
|
||||
TensorCuBlasGemm<T>()(ctx, transb, transa, n, m, k, alpha, b.data(),
|
||||
transb ? k : n, a.data(), transa ? m : k, beta,
|
||||
c.data(), n);
|
||||
}
|
||||
@ -61,9 +53,8 @@ struct TensorBlasGemm<Device, T, true /* USE_CUBLAS */> {
|
||||
|
||||
template <typename Device, typename T>
|
||||
struct TensorBlasGemm<Device, T, false /* USE_CUBLAS */> {
|
||||
static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
|
||||
const Device& d, bool transa, bool transb, T alpha,
|
||||
typename TTypes<T>::ConstMatrix a,
|
||||
static void compute(OpKernelContext* ctx, const Device& d, bool transa,
|
||||
bool transb, T alpha, typename TTypes<T>::ConstMatrix a,
|
||||
typename TTypes<T>::ConstMatrix b, T beta,
|
||||
typename TTypes<T>::Matrix c) {
|
||||
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
|
||||
|
@ -15,10 +15,6 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/contrib/rnn/kernels/gru_ops.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
@ -151,14 +147,9 @@ class GRUCellBlockOp : public OpKernel {
|
||||
|
||||
const Device& device = ctx->eigen_device<Device>();
|
||||
|
||||
perftools::gputools::Stream* stream =
|
||||
std::is_same<Device, GPUDevice>::value
|
||||
? ctx->op_device_context()->stream()
|
||||
: nullptr;
|
||||
|
||||
functor::GRUBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
|
||||
cell_size)(
|
||||
ctx, stream, device, x_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
|
||||
ctx, device, x_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
|
||||
w_ru_tensor->matrix<T>(), w_c_tensor->matrix<T>(),
|
||||
b_ru_tensor->vec<T>(), b_c_tensor->vec<T>(), r_u_bar_tensor.matrix<T>(),
|
||||
r_tensor->matrix<T>(), u_tensor->matrix<T>(), c_tensor->matrix<T>(),
|
||||
@ -362,14 +353,10 @@ class GRUBlockCellGradOp : public OpKernel {
|
||||
&d_x_component_2_h_prevr));
|
||||
|
||||
const Device& device = ctx->eigen_device<Device>();
|
||||
perftools::gputools::Stream* stream =
|
||||
std::is_same<Device, GPUDevice>::value
|
||||
? ctx->op_device_context()->stream()
|
||||
: nullptr;
|
||||
|
||||
functor::GRUBlockCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
|
||||
cell_size)(
|
||||
ctx, stream, device, x_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
|
||||
ctx, device, x_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
|
||||
w_ru_tensor->matrix<T>(), w_c_tensor->matrix<T>(),
|
||||
b_ru_tensor->vec<T>(), b_c_tensor->vec<T>(), r_tensor->matrix<T>(),
|
||||
u_tensor->matrix<T>(), c_tensor->matrix<T>(), d_h_tensor->matrix<T>(),
|
||||
@ -400,8 +387,8 @@ namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void GRUBlockCellFprop<GPUDevice, T, true>::operator()( \
|
||||
OpKernelContext* ctx, perftools::gputools::Stream* stream, \
|
||||
const GPUDevice& d, typename TTypes<T>::ConstMatrix x, \
|
||||
OpKernelContext* ctx, const GPUDevice& d, \
|
||||
typename TTypes<T>::ConstMatrix x, \
|
||||
typename TTypes<T>::ConstMatrix h_prev, \
|
||||
typename TTypes<T>::ConstMatrix w_ru, \
|
||||
typename TTypes<T>::ConstMatrix w_c, typename TTypes<T>::ConstVec b_ru, \
|
||||
@ -430,9 +417,9 @@ namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void GRUBlockCellBprop<GPUDevice, T, true>::operator()( \
|
||||
OpKernelContext* ctx, perftools::gputools::Stream* stream, \
|
||||
const GPUDevice& d, typename TTypes<T>::ConstMatrix x, \
|
||||
typename TTypes<T>::ConstMatrix h, typename TTypes<T>::ConstMatrix w_ru, \
|
||||
OpKernelContext* ctx, const GPUDevice& d, \
|
||||
typename TTypes<T>::ConstMatrix x, typename TTypes<T>::ConstMatrix h, \
|
||||
typename TTypes<T>::ConstMatrix w_ru, \
|
||||
typename TTypes<T>::ConstMatrix w_c, typename TTypes<T>::ConstVec b_ru, \
|
||||
typename TTypes<T>::ConstVec b_c, typename TTypes<T>::ConstMatrix r, \
|
||||
typename TTypes<T>::ConstMatrix u, typename TTypes<T>::ConstMatrix c, \
|
||||
|
@ -21,12 +21,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace perftools {
|
||||
namespace gputools {
|
||||
class Stream;
|
||||
} // end namespace gputools
|
||||
} // end namespace perftools
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class OpKernelContext;
|
||||
@ -77,18 +71,15 @@ struct GRUBlockCellFprop : public GRUCell {
|
||||
const int cell_size)
|
||||
: GRUCell(batch_size, input_size, cell_size) {}
|
||||
|
||||
void operator()(OpKernelContext* ctx, perftools::gputools::Stream* stream,
|
||||
const Device& d, typename TTypes<T>::ConstMatrix x,
|
||||
typename TTypes<T>::ConstMatrix h_prev,
|
||||
typename TTypes<T>::ConstMatrix w_ru,
|
||||
typename TTypes<T>::ConstMatrix w_c,
|
||||
typename TTypes<T>::ConstVec b_ru,
|
||||
typename TTypes<T>::ConstVec b_c,
|
||||
typename TTypes<T>::Matrix r_u_bar,
|
||||
typename TTypes<T>::Matrix r, typename TTypes<T>::Matrix u,
|
||||
typename TTypes<T>::Matrix c, typename TTypes<T>::Matrix h,
|
||||
typename TTypes<T>::Matrix x_h_prev,
|
||||
typename TTypes<T>::Matrix x_h_prevr) {
|
||||
void operator()(
|
||||
OpKernelContext* ctx, const Device& d, typename TTypes<T>::ConstMatrix x,
|
||||
typename TTypes<T>::ConstMatrix h_prev,
|
||||
typename TTypes<T>::ConstMatrix w_ru, typename TTypes<T>::ConstMatrix w_c,
|
||||
typename TTypes<T>::ConstVec b_ru, typename TTypes<T>::ConstVec b_c,
|
||||
typename TTypes<T>::Matrix r_u_bar, typename TTypes<T>::Matrix r,
|
||||
typename TTypes<T>::Matrix u, typename TTypes<T>::Matrix c,
|
||||
typename TTypes<T>::Matrix h, typename TTypes<T>::Matrix x_h_prev,
|
||||
typename TTypes<T>::Matrix x_h_prevr) {
|
||||
// Concat x_h_prev = [x, h_prev].
|
||||
x_h_prev.slice(x_offsets(), x_extends()).device(d) = x;
|
||||
x_h_prev.slice(h_offsets(), h_extends()).device(d) = h_prev;
|
||||
@ -96,9 +87,8 @@ struct GRUBlockCellFprop : public GRUCell {
|
||||
// r_u_bar = x_h_prev * w_ru + b_ru
|
||||
typename TTypes<T>::ConstMatrix const_x_h_prev(x_h_prev.data(),
|
||||
x_h_prev.dimensions());
|
||||
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(ctx, stream, d, false, false,
|
||||
T(1), const_x_h_prev, w_ru,
|
||||
T(0), r_u_bar);
|
||||
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
|
||||
ctx, d, false, false, T(1), const_x_h_prev, w_ru, T(0), r_u_bar);
|
||||
|
||||
// Creating a bias matrix for adding by broadcasting 'b_ru'
|
||||
Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({batch_size_, 1});
|
||||
@ -117,7 +107,7 @@ struct GRUBlockCellFprop : public GRUCell {
|
||||
typename TTypes<T>::ConstMatrix const_x_h_prevr(x_h_prevr.data(),
|
||||
x_h_prevr.dimensions());
|
||||
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
|
||||
ctx, stream, d, false, false, T(1), const_x_h_prevr, w_c, T(0), c);
|
||||
ctx, d, false, false, T(1), const_x_h_prevr, w_c, T(0), c);
|
||||
|
||||
Eigen::array<Eigen::DenseIndex, 2> b_c_shape({1, b_c.dimensions()[0]});
|
||||
c.device(d) += (b_c.reshape(b_c_shape).broadcast(broadcast_shape));
|
||||
@ -135,8 +125,7 @@ struct GRUBlockCellBprop : public GRUCell {
|
||||
: GRUCell(batch_size, input_size, cell_size) {}
|
||||
|
||||
void operator()(
|
||||
OpKernelContext* ctx, perftools::gputools::Stream* stream,
|
||||
const Device& d, typename TTypes<T>::ConstMatrix x,
|
||||
OpKernelContext* ctx, const Device& d, typename TTypes<T>::ConstMatrix x,
|
||||
typename TTypes<T>::ConstMatrix h_prev,
|
||||
typename TTypes<T>::ConstMatrix w_ru, typename TTypes<T>::ConstMatrix w_c,
|
||||
typename TTypes<T>::ConstVec b_ru, typename TTypes<T>::ConstVec b_c,
|
||||
@ -159,9 +148,9 @@ struct GRUBlockCellBprop : public GRUCell {
|
||||
// [2nd_component_of_d_x d_h_prevr] = d_c_bar X w_c^T
|
||||
typename TTypes<T>::ConstMatrix const_d_c_bar(d_c_bar.data(),
|
||||
d_c_bar.dimensions());
|
||||
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(ctx, stream, d, false, true,
|
||||
T(1), const_d_c_bar, w_c,
|
||||
T(0), d_x_comp2_and_h_prevr);
|
||||
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(ctx, d, false, true, T(1),
|
||||
const_d_c_bar, w_c, T(0),
|
||||
d_x_comp2_and_h_prevr);
|
||||
|
||||
d_hr.device(d) = d_x_comp2_and_h_prevr.slice(h_offsets(), h_extends());
|
||||
d_r_bar.device(d) = (d_hr * h_prev * r) * (r.constant(T(1)) - r);
|
||||
@ -175,7 +164,7 @@ struct GRUBlockCellBprop : public GRUCell {
|
||||
typename TTypes<T>::ConstMatrix const_d_r_bar_u_bar(
|
||||
d_r_bar_u_bar.data(), d_r_bar_u_bar.dimensions());
|
||||
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
|
||||
ctx, stream, d, false, true, T(1), const_d_r_bar_u_bar, w_ru, T(0),
|
||||
ctx, d, false, true, T(1), const_d_r_bar_u_bar, w_ru, T(0),
|
||||
d_x_comp1_and_h_prev_comp1);
|
||||
|
||||
// d_x = d_x_comp1 + d_x_comp2
|
||||
|
@ -34,10 +34,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
@ -164,14 +160,10 @@ class LSTMBlockCellOp : public OpKernel {
|
||||
&icfo_tensor));
|
||||
|
||||
const Device& device = ctx->eigen_device<Device>();
|
||||
perftools::gputools::Stream* stream =
|
||||
std::is_same<Device, GPUDevice>::value
|
||||
? ctx->op_device_context()->stream()
|
||||
: nullptr;
|
||||
|
||||
functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
|
||||
cell_size)(
|
||||
ctx, stream, device, forget_bias_, cell_clip_, use_peephole_,
|
||||
ctx, device, forget_bias_, cell_clip_, use_peephole_,
|
||||
x_tensor->matrix<T>(), cs_prev_tensor->matrix<T>(),
|
||||
h_prev_tensor->matrix<T>(), w_tensor->matrix<T>(), wci_tensor->vec<T>(),
|
||||
wcf_tensor->vec<T>(), wco_tensor->vec<T>(), b_tensor->vec<T>(),
|
||||
@ -196,22 +188,21 @@ REGISTER_KERNEL(float);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void LSTMBlockCellFprop<GPUDevice, T, true>::operator()( \
|
||||
OpKernelContext* ctx, perftools::gputools::Stream* stream, \
|
||||
const GPUDevice& d, const T forget_bias, const T cell_clip, \
|
||||
bool use_peephole, typename TTypes<T>::ConstMatrix x, \
|
||||
typename TTypes<T>::ConstMatrix cs_prev, \
|
||||
typename TTypes<T>::ConstMatrix h_prev, \
|
||||
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
|
||||
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
|
||||
typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
|
||||
typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
|
||||
typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
|
||||
typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
|
||||
typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h); \
|
||||
\
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void LSTMBlockCellFprop<GPUDevice, T, true>::operator()( \
|
||||
OpKernelContext* ctx, const GPUDevice& d, const T forget_bias, \
|
||||
const T cell_clip, bool use_peephole, typename TTypes<T>::ConstMatrix x, \
|
||||
typename TTypes<T>::ConstMatrix cs_prev, \
|
||||
typename TTypes<T>::ConstMatrix h_prev, \
|
||||
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
|
||||
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
|
||||
typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
|
||||
typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
|
||||
typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
|
||||
typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
|
||||
typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h); \
|
||||
\
|
||||
extern template struct LSTMBlockCellFprop<GPUDevice, T, true>;
|
||||
|
||||
DECLARE_GPU_SPEC(float);
|
||||
@ -445,10 +436,6 @@ class LSTMBlockCellGradOp : public OpKernel {
|
||||
&di_tensor));
|
||||
|
||||
const Device& device = ctx->eigen_device<Device>();
|
||||
perftools::gputools::Stream* stream =
|
||||
std::is_same<Device, GPUDevice>::value
|
||||
? ctx->op_device_context()->stream()
|
||||
: nullptr;
|
||||
|
||||
functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<float>());
|
||||
functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<float>());
|
||||
@ -456,7 +443,7 @@ class LSTMBlockCellGradOp : public OpKernel {
|
||||
|
||||
functor::LSTMBlockCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
|
||||
cell_size)(
|
||||
ctx, stream, device, use_peephole_, x_tensor->matrix<T>(),
|
||||
ctx, device, use_peephole_, x_tensor->matrix<T>(),
|
||||
cs_prev_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
|
||||
w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
|
||||
wco_tensor->vec<T>(), b_tensor->vec<T>(), i_tensor->matrix<T>(),
|
||||
@ -486,8 +473,7 @@ namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void LSTMBlockCellBprop<GPUDevice, T, true>::operator()( \
|
||||
OpKernelContext* ctx, perftools::gputools::Stream* stream, \
|
||||
const GPUDevice& d, bool use_peephole, \
|
||||
OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
|
||||
typename TTypes<T>::ConstMatrix x, \
|
||||
typename TTypes<T>::ConstMatrix cs_prev, \
|
||||
typename TTypes<T>::ConstMatrix h_prev, \
|
||||
@ -769,10 +755,6 @@ class BlockLSTMOp : public OpKernel {
|
||||
&icfo_tensor));
|
||||
|
||||
const Device& device = ctx->eigen_device<Device>();
|
||||
perftools::gputools::Stream* stream =
|
||||
std::is_same<Device, GPUDevice>::value
|
||||
? ctx->op_device_context()->stream()
|
||||
: nullptr;
|
||||
|
||||
const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
|
||||
SliceHelper<Device, T> slicer(ctx);
|
||||
@ -794,7 +776,7 @@ class BlockLSTMOp : public OpKernel {
|
||||
|
||||
functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
|
||||
cell_size)(
|
||||
ctx, stream, device, forget_bias_, cell_clip_, use_peephole_,
|
||||
ctx, device, forget_bias_, cell_clip_, use_peephole_,
|
||||
x_tensor.matrix<T>(), cs_prev_tensor2.matrix<T>(),
|
||||
h_prev_tensor2.matrix<T>(), w_tensor->matrix<T>(),
|
||||
wci_tensor->vec<T>(), wcf_tensor->vec<T>(), wco_tensor->vec<T>(),
|
||||
@ -1020,10 +1002,6 @@ class BlockLSTMGradOp : public OpKernel {
|
||||
|
||||
|
||||
const Device& device = ctx->eigen_device<Device>();
|
||||
perftools::gputools::Stream* stream =
|
||||
std::is_same<Device, GPUDevice>::value
|
||||
? ctx->op_device_context()->stream()
|
||||
: nullptr;
|
||||
|
||||
functor::TensorZero<Device, T>()(device, cs_grad_tensor.flat<float>());
|
||||
functor::TensorZero<Device, T>()(device,
|
||||
@ -1073,7 +1051,7 @@ class BlockLSTMGradOp : public OpKernel {
|
||||
Tensor x_grad_tensor = slicer.OutputSlice(x_grad, t, "x_grad");
|
||||
functor::BlockLSTMBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
|
||||
cell_size)(
|
||||
ctx, stream, device, use_peephole_, x_tensor.matrix<T>(),
|
||||
ctx, device, use_peephole_, x_tensor.matrix<T>(),
|
||||
cs_prev_tensor2.matrix<T>(), h_prev_tensor2.matrix<T>(),
|
||||
w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
|
||||
wco_tensor->vec<T>(), b_tensor->vec<T>(), xh_tensor.matrix<T>(),
|
||||
@ -1134,8 +1112,7 @@ namespace functor {
|
||||
\
|
||||
template <> \
|
||||
void BlockLSTMBprop<GPUDevice, T, true>::operator()( \
|
||||
OpKernelContext* ctx, perftools::gputools::Stream* stream, \
|
||||
const GPUDevice& d, bool use_peephole, \
|
||||
OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
|
||||
typename TTypes<T>::ConstMatrix x, \
|
||||
typename TTypes<T>::ConstMatrix cs_prev, \
|
||||
typename TTypes<T>::ConstMatrix h_prev, \
|
||||
|
@ -22,12 +22,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/eigen_activations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace perftools {
|
||||
namespace gputools {
|
||||
class Stream;
|
||||
} // end namespace gputools
|
||||
} // end namespace perftools
|
||||
|
||||
namespace tensorflow {
|
||||
class OpKernelContext;
|
||||
|
||||
@ -153,29 +147,26 @@ struct LSTMBlockCellFprop : public LSTMBlockCell {
|
||||
const int cell_size)
|
||||
: LSTMBlockCell(batch_size, input_size, cell_size) {}
|
||||
|
||||
void operator()(OpKernelContext* ctx, perftools::gputools::Stream* stream,
|
||||
const Device& d, const T forget_bias, const T cell_clip,
|
||||
bool use_peephole, typename TTypes<T>::ConstMatrix x,
|
||||
typename TTypes<T>::ConstMatrix cs_prev,
|
||||
typename TTypes<T>::ConstMatrix h_prev,
|
||||
typename TTypes<T>::ConstMatrix w,
|
||||
typename TTypes<T>::ConstVec wci,
|
||||
typename TTypes<T>::ConstVec wcf,
|
||||
typename TTypes<T>::ConstVec wco,
|
||||
typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,
|
||||
typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,
|
||||
typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,
|
||||
typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,
|
||||
typename TTypes<T>::Matrix icfo,
|
||||
typename TTypes<T>::Matrix h) {
|
||||
void operator()(
|
||||
OpKernelContext* ctx, const Device& d, const T forget_bias,
|
||||
const T cell_clip, bool use_peephole, typename TTypes<T>::ConstMatrix x,
|
||||
typename TTypes<T>::ConstMatrix cs_prev,
|
||||
typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
|
||||
typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
|
||||
typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
|
||||
typename TTypes<T>::Matrix xh, typename TTypes<T>::Matrix i,
|
||||
typename TTypes<T>::Matrix cs, typename TTypes<T>::Matrix f,
|
||||
typename TTypes<T>::Matrix o, typename TTypes<T>::Matrix ci,
|
||||
typename TTypes<T>::Matrix co, typename TTypes<T>::Matrix icfo,
|
||||
typename TTypes<T>::Matrix h) {
|
||||
// Concat xh = [x, h].
|
||||
xh.slice(xh_x_offsets(), xh_x_extents()).device(d) = x;
|
||||
xh.slice(xh_h_offsets(), xh_h_extents()).device(d) = h_prev;
|
||||
|
||||
// states1 = xh * w + b
|
||||
typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());
|
||||
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
|
||||
ctx, stream, d, false, false, T(1), const_xh, w, T(0), icfo);
|
||||
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(ctx, d, false, false, T(1),
|
||||
const_xh, w, T(0), icfo);
|
||||
Eigen::array<Eigen::DenseIndex, 2> b_shape({1, b.dimensions()[0]});
|
||||
Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({batch_size_, 1});
|
||||
icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape);
|
||||
@ -239,8 +230,8 @@ struct LSTMBlockCellBprop : public LSTMBlockCell {
|
||||
: LSTMBlockCell(batch_size, input_size, cell_size) {}
|
||||
|
||||
void operator()(
|
||||
OpKernelContext* ctx, perftools::gputools::Stream* stream,
|
||||
const Device& d, bool use_peephole, typename TTypes<T>::ConstMatrix x,
|
||||
OpKernelContext* ctx, const Device& d, bool use_peephole,
|
||||
typename TTypes<T>::ConstMatrix x,
|
||||
typename TTypes<T>::ConstMatrix cs_prev,
|
||||
typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
|
||||
typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
|
||||
@ -305,8 +296,8 @@ struct BlockLSTMBprop : public LSTMBlockCell {
|
||||
: LSTMBlockCell(batch_size, input_size, cell_size) {}
|
||||
|
||||
void operator()(
|
||||
OpKernelContext* ctx, perftools::gputools::Stream* stream,
|
||||
const Device& d, bool use_peephole, typename TTypes<T>::ConstMatrix x,
|
||||
OpKernelContext* ctx, const Device& d, bool use_peephole,
|
||||
typename TTypes<T>::ConstMatrix x,
|
||||
typename TTypes<T>::ConstMatrix cs_prev,
|
||||
typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
|
||||
typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
|
||||
@ -364,7 +355,7 @@ struct BlockLSTMBprop : public LSTMBlockCell {
|
||||
typename TTypes<T>::ConstMatrix const_dicfo(dicfo.data(),
|
||||
dicfo.dimensions());
|
||||
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
|
||||
ctx, stream, d, false, true, T(1), const_dicfo, w, T(0), xh_grad);
|
||||
ctx, d, false, true, T(1), const_dicfo, w, T(0), xh_grad);
|
||||
|
||||
// xh.
|
||||
xh.slice(xh_x_offsets(), xh_x_extents()).device(d) = x;
|
||||
@ -377,7 +368,7 @@ struct BlockLSTMBprop : public LSTMBlockCell {
|
||||
|
||||
// w_grad.
|
||||
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
|
||||
ctx, stream, d, true, false, T(1), const_xh, const_dicfo, T(1), w_grad);
|
||||
ctx, d, true, false, T(1), const_xh, const_dicfo, T(1), w_grad);
|
||||
|
||||
// b_grad.
|
||||
b_grad.device(d) += dicfo.sum(Eigen::array<int, 1>({0}));
|
||||
|
@ -1005,7 +1005,7 @@ _linear = rnn_cell._linear
|
||||
class AttentionCellWrapper(rnn_cell.RNNCell):
|
||||
"""Basic attention cell wrapper.
|
||||
|
||||
Implementation based on https://arxiv.org/pdf/1601.06733.pdf.
|
||||
Implementation based on https://arxiv.org/abs/1409.0473.
|
||||
"""
|
||||
|
||||
def __init__(self, cell, attn_length, attn_size=None, attn_vec_size=None,
|
||||
|
@ -51,7 +51,7 @@ from tensorflow.contrib.slim.python.slim.data import parallel_reader
|
||||
class DatasetDataProvider(data_provider.DataProvider):
|
||||
|
||||
def __init__(self, dataset, num_readers=1, shuffle=True, num_epochs=None,
|
||||
common_queue_capacity=256, common_queue_min=128):
|
||||
common_queue_capacity=256, common_queue_min=128, seed=None):
|
||||
"""Creates a DatasetDataProvider.
|
||||
|
||||
Args:
|
||||
@ -64,6 +64,7 @@ class DatasetDataProvider(data_provider.DataProvider):
|
||||
common_queue_capacity: The capacity of the common queue.
|
||||
common_queue_min: The minimum number of elements in the common queue after
|
||||
a dequeue.
|
||||
seed: The seed to use if shuffling.
|
||||
"""
|
||||
_, data = parallel_reader.parallel_read(
|
||||
dataset.data_sources,
|
||||
@ -72,7 +73,8 @@ class DatasetDataProvider(data_provider.DataProvider):
|
||||
num_readers=num_readers,
|
||||
shuffle=shuffle,
|
||||
capacity=common_queue_capacity,
|
||||
min_after_dequeue=common_queue_min)
|
||||
min_after_dequeue=common_queue_min,
|
||||
seed=seed)
|
||||
|
||||
items = dataset.decoder.list_items()
|
||||
tensors = dataset.decoder.decode(data, items)
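A minimal sketch of the new `seed` argument (the slim `Dataset` object, its items, and the import path are assumed): the seed is forwarded to the underlying `RandomShuffleQueue`, making the shuffle order reproducible.

```python
from tensorflow.contrib.slim.python.slim.data import dataset_data_provider

# `my_dataset` is a slim Dataset defined elsewhere.
provider = dataset_data_provider.DatasetDataProvider(
    my_dataset,
    num_readers=4,
    shuffle=True,
    seed=42)  # new: deterministic ordering for the common queue
image, label = provider.get(['image', 'label'])
```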
|
||||
|
@ -170,7 +170,8 @@ def parallel_read(data_sources,
|
||||
shuffle=True,
|
||||
dtypes=None,
|
||||
capacity=256,
|
||||
min_after_dequeue=128):
|
||||
min_after_dequeue=128,
|
||||
seed=None):
|
||||
"""Reads multiple records in parallel from data_sources using n readers.
|
||||
|
||||
It uses a ParallelReader to read from multiple files in parallel using
|
||||
@ -199,6 +200,7 @@ def parallel_read(data_sources,
|
||||
capacity: integer, capacity of the common_queue.
|
||||
min_after_dequeue: integer, minimum number of records in the common_queue
|
||||
after dequeue. Needed for a good shuffle.
|
||||
seed: A seed for RandomShuffleQueue.
|
||||
|
||||
Returns:
|
||||
key, value: a tuple of keys and values from the data_source.
|
||||
@ -212,7 +214,8 @@ def parallel_read(data_sources,
|
||||
common_queue = data_flow_ops.RandomShuffleQueue(
|
||||
capacity=capacity,
|
||||
min_after_dequeue=min_after_dequeue,
|
||||
dtypes=dtypes)
|
||||
dtypes=dtypes,
|
||||
seed=seed)
|
||||
else:
|
||||
common_queue = data_flow_ops.FIFOQueue(capacity=capacity, dtypes=dtypes)
|
||||
|
||||
|
@ -471,7 +471,14 @@ def create_train_op(
|
||||
'LossTensor is inf or nan')
|
||||
|
||||
# Ensure the train_tensor computes grad_updates.
|
||||
return control_flow_ops.with_dependencies([grad_updates], total_loss)
|
||||
train_op = control_flow_ops.with_dependencies([grad_updates], total_loss)
|
||||
|
||||
# Add the operation used for training to the 'train_op' collection
|
||||
train_ops = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
|
||||
if train_op not in train_ops:
|
||||
train_ops.append(train_op)
|
||||
|
||||
return train_op
|
||||
|
||||
|
||||
def _wait_for_step(sess, global_step, step):
|
||||
|
@ -301,6 +301,22 @@ class CreateTrainOpTest(tf.test.TestCase):
|
||||
self.assertAllClose(mean, [0] * 4)
|
||||
self.assertAllClose(variance, [1] * 4)
|
||||
|
||||
def testRecordTrainOpInCollection(self):
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
tf_predictions = LogisticClassifier(tf_inputs)
|
||||
slim.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = slim.losses.get_total_loss()
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
train_op = slim.learning.create_train_op(total_loss, optimizer)
|
||||
|
||||
# Make sure the training op was recorded in the proper collection
|
||||
self.assertTrue(train_op in tf.get_collection(tf.GraphKeys.TRAIN_OP))
|
||||
|
||||
|
||||
class TrainTest(tf.test.TestCase):
|
||||
|
||||
|
@ -23,43 +23,54 @@ from tensorflow.contrib.metrics.python.ops import metric_ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
|
||||
def _accuracy(probabilities, targets):
|
||||
predictions = math_ops.argmax(probabilities, 1)
|
||||
# undo one-hot
|
||||
labels = math_ops.argmax(targets, 1)
|
||||
return metric_ops.streaming_accuracy(predictions, labels)
|
||||
INFERENCE_PROB_NAME = 'inference'
|
||||
INFERENCE_PRED_NAME = 'predictions'
|
||||
|
||||
|
||||
def _r2(probabilities, targets):
|
||||
def _accuracy(predictions, targets, weights=None):
|
||||
return metric_ops.streaming_accuracy(predictions, targets, weights=weights)
|
||||
|
||||
|
||||
def _r2(probabilities, targets, weights=None):
|
||||
if targets.get_shape().ndims == 1:
|
||||
targets = array_ops.expand_dims(targets, -1)
|
||||
targets = math_ops.to_float(targets)
|
||||
y_mean = math_ops.reduce_mean(targets, 0)
|
||||
squares_total = math_ops.reduce_sum(math_ops.square(targets - y_mean), 0)
|
||||
squares_residuals = math_ops.reduce_sum(math_ops.square(
|
||||
targets - probabilities), 0)
|
||||
score = 1 - math_ops.reduce_sum(squares_residuals / squares_total)
|
||||
return metric_ops.streaming_mean(score)
|
||||
return metric_ops.streaming_mean(score, weights=weights)
|
||||
|
||||
|
||||
def _sigmoid_entropy(probabilities, targets):
|
||||
def _squeeze_and_onehot(targets, depth):
|
||||
targets = array_ops.squeeze(targets, squeeze_dims=[1])
|
||||
return array_ops.one_hot(math_ops.to_int32(targets), depth)
|
||||
|
||||
|
||||
def _sigmoid_entropy(probabilities, targets, weights=None):
|
||||
return metric_ops.streaming_mean(losses.sigmoid_cross_entropy(
|
||||
probabilities, targets))
|
||||
probabilities, _squeeze_and_onehot(targets,
|
||||
array_ops.shape(probabilities)[1])),
|
||||
weights=weights)
|
||||
|
||||
|
||||
def _softmax_entropy(probabilities, targets):
|
||||
return metric_ops.streaming_mean(losses.softmax_cross_entropy(
|
||||
probabilities, targets))
|
||||
def _softmax_entropy(probabilities, targets, weights=None):
|
||||
return metric_ops.streaming_mean(losses.sparse_softmax_cross_entropy(
|
||||
probabilities, math_ops.to_int32(targets)),
|
||||
weights=weights)
|
||||
|
||||
|
||||
def _predictions(probabilities, unused_targets):
|
||||
return math_ops.argmax(probabilities, 1)
|
||||
def _predictions(predictions, unused_targets, **unused_kwargs):
|
||||
return predictions
|
||||
|
||||
|
||||
def _log_loss(probabilities, targets):
|
||||
# targets doesn't have a shape coming in, log_loss isn't too happy about it.
|
||||
targets = array_ops.reshape(targets, array_ops.shape(probabilities))
|
||||
return metric_ops.streaming_mean(losses.log_loss(probabilities, targets))
|
||||
def _class_log_loss(probabilities, targets, weights=None):
|
||||
return metric_ops.streaming_mean(
|
||||
losses.log_loss(probabilities,
|
||||
_squeeze_and_onehot(targets,
|
||||
array_ops.shape(probabilities)[1])),
|
||||
weights=weights)
|
||||
|
||||
|
||||
_EVAL_METRICS = {'sigmoid_entropy': _sigmoid_entropy,
|
||||
@ -67,9 +78,21 @@ _EVAL_METRICS = {'sigmoid_entropy': _sigmoid_entropy,
|
||||
'accuracy': _accuracy,
|
||||
'r2': _r2,
|
||||
'predictions': _predictions,
|
||||
'log_loss': _log_loss}
|
||||
'classification_log_loss': _class_log_loss}
|
||||
|
||||
|
||||
_PREDICTION_KEYS = {'sigmoid_entropy': INFERENCE_PROB_NAME,
|
||||
'softmax_entropy': INFERENCE_PROB_NAME,
|
||||
'accuracy': INFERENCE_PRED_NAME,
|
||||
'r2': INFERENCE_PROB_NAME,
|
||||
'predictions': INFERENCE_PRED_NAME,
|
||||
'classification_log_loss': INFERENCE_PROB_NAME}
|
||||
|
||||
|
||||
def get_metric(metric_name):
|
||||
"""Given a metric name, return the corresponding metric function."""
|
||||
return _EVAL_METRICS[metric_name]
|
||||
|
||||
|
||||
def get_prediction_key(metric_name):
|
||||
return _PREDICTION_KEYS[metric_name]
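A hedged sketch of how these helpers fit together (module path, the inference dict, and the `targets`/`weights` tensors are assumed): the prediction key selects which inference output to feed the metric, and `weights` is the new optional per-example weighting.

```python
from tensorflow.contrib.tensor_forest.client import eval_metrics  # assumed path

name = 'accuracy'
metric_fn = eval_metrics.get_metric(name)          # -> _accuracy
pred_key = eval_metrics.get_prediction_key(name)   # -> INFERENCE_PRED_NAME

# `inference` maps INFERENCE_PROB_NAME / INFERENCE_PRED_NAME to model outputs.
value_op, update_op = metric_fn(inference[pred_key], targets, weights=weights)
```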
|
||||
|
@ -17,10 +17,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import threading
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.learn_io import graph_io
|
||||
from tensorflow.contrib.tensor_forest.python import constants
|
||||
|
||||
from tensorflow.python.framework import common_shapes
|
||||
@ -35,8 +33,6 @@ from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
DATA_OPS_FILE = '_data_ops.so'
|
||||
|
||||
EXAMPLE_WEIGHT_NAME = '__weight__'
|
||||
|
||||
_data_ops = None
|
||||
_ops_lock = threading.Lock()
|
||||
|
||||
@ -69,68 +65,28 @@ def Load():
|
||||
def _ParseSparse(data):
|
||||
"""Concat sparse tensors together.
|
||||
|
||||
A common use of sparse tensors is to treat strings as a sparse bit vector
|
||||
with a large number of features representing the presence of all possible
|
||||
values. Here we convert these strings to integer indices in a sparse bit
|
||||
tensor. In order to pack each incoming feature into a single sparse tensor,
|
||||
we add an offset to the converted indices to indicate that they came from
|
||||
different features in the source data.
|
||||
|
||||
Args:
|
||||
data: A dict of name -> Tensor.
|
||||
|
||||
Returns:
|
||||
A single sparse tensor with float values and a 1-D input spec Tensor.
|
||||
A single sparse tensor and a 1-D input spec Tensor.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Combining dense and sparse tensors is not yet
|
||||
NotImplementedError: Combining dense and sparse tensors is not
|
||||
supported.
|
||||
ValueError: If data contains non-string Tensors.
|
||||
"""
|
||||
convert_ops = Load()
|
||||
|
||||
# Sparse tensor indices have 63 bits to use for information. We use the
|
||||
# minimum number of these (MSBs) for the offset, and pack the rest with the
|
||||
# actual data.
|
||||
num_features = len(data)
|
||||
offset_bits = int(math.ceil(math.log(num_features, 2)))
|
||||
|
||||
# We condense data to 26 bits, see sparse_values_to_indices.cc
|
||||
offset_increment = int(math.pow(2, 26 - offset_bits))
|
||||
offset = 0
|
||||
|
||||
sparse_tensors = []
|
||||
keys = None
|
||||
weights = None
|
||||
for k in sorted(data.keys()):
|
||||
if k == graph_io.KEY_FEATURE_NAME:
|
||||
keys = data[k]
|
||||
elif k == EXAMPLE_WEIGHT_NAME:
|
||||
weights = data[k]
|
||||
elif isinstance(data[k], ops.SparseTensor):
|
||||
# TODO(gilberth): Support mixed string/float sparse tensors.
|
||||
# We currently only support string (categorical) data if we're using
|
||||
# sparse tensors.
|
||||
if data[k].dtype != dtypes.string:
|
||||
raise ValueError('Only sparse tensors of type string are supported.')
|
||||
sparse_indices = data[k].indices
|
||||
sparse_values = data[k].values
|
||||
new_shape = array_ops.concat(
|
||||
0, [array_ops.slice(data[k].shape, [0], [1]), [offset_increment]])
|
||||
if not isinstance(data[k], ops.SparseTensor):
|
||||
raise NotImplementedError(
|
||||
'Features should be either all sparse or all dense. Use a '
|
||||
'feature engineering function to convert some of them.')
|
||||
|
||||
new_indices, new_values = convert_ops.sparse_values_to_indices(
|
||||
sparse_indices,
|
||||
sparse_values,
|
||||
offset, offset_bits=offset_bits)
|
||||
sparse_tensors.append(ops.SparseTensor(indices=new_indices,
|
||||
values=new_values,
|
||||
shape=new_shape))
|
||||
else:
|
||||
# Convert dense to sparse.
|
||||
raise NotImplementedError('Dense to sparse conversion not implemented.')
|
||||
|
||||
return (sparse_ops.sparse_concat(1, sparse_tensors), keys, weights,
|
||||
[constants.DATA_CATEGORICAL])
|
||||
data_spec = [
|
||||
constants.DATA_CATEGORICAL if data[data.keys()[0]].dtype == dtypes.string
|
||||
else constants.DATA_FLOAT
|
||||
]
|
||||
return sparse_ops.sparse_concat(1, data.values()), data_spec
|
||||
|
||||
|
||||
def _ParseDense(data):
|
||||
@ -143,22 +99,20 @@ def _ParseDense(data):
|
||||
A tuple of (single dense float Tensor, keys tensor (if exists), data spec).
|
||||
"""
|
||||
convert_ops = Load()
|
||||
data_spec = [constants.DATA_CATEGORICAL if data[k].dtype == dtypes.string else
|
||||
constants.DATA_FLOAT for k in sorted(data.keys())]
|
||||
data_spec = [constants.DATA_CATEGORICAL if (data[k].dtype == dtypes.string or
|
||||
data[k].dtype == dtypes.int32 or
|
||||
data[k].dtype == dtypes.int64)
|
||||
else constants.DATA_FLOAT for k in sorted(data.keys())]
|
||||
data_spec = [constants.DATA_FLOAT] + data_spec
|
||||
keys = None
|
||||
weights = None
|
||||
features = []
|
||||
for k in sorted(data.keys()):
|
||||
if k == graph_io.KEY_FEATURE_NAME:
|
||||
keys = data[k]
|
||||
elif k == EXAMPLE_WEIGHT_NAME:
|
||||
weights = data[k]
|
||||
if data[k].dtype == dtypes.string:
|
||||
features.append(convert_ops.string_to_float(data[k]))
|
||||
elif data[k].dtype == dtypes.int64 or data[k].dtype == dtypes.int32:
|
||||
features.append(math_ops.to_float(data[k]))
|
||||
else:
|
||||
features.append(
|
||||
convert_ops.string_to_float(data[k]) if data[k].dtype == dtypes.string
|
||||
else data[k])
|
||||
return array_ops.concat(1, features), keys, weights, data_spec
|
||||
features.append(data[k])
|
||||
return array_ops.concat(1, features), data_spec
|
||||
|
||||
|
||||
def ParseDataTensorOrDict(data):
|
||||
@ -187,8 +141,7 @@ def ParseDataTensorOrDict(data):
|
||||
else:
|
||||
return _ParseDense(data)
|
||||
else:
|
||||
return (data, None, None,
|
||||
[constants.DATA_FLOAT] * data.get_shape().as_list()[1])
|
||||
return (data, [constants.DATA_FLOAT] * data.get_shape().as_list()[1])
|
||||
|
||||
|
||||
def ParseLabelTensorOrDict(labels):
|
||||
|
@ -19,7 +19,9 @@ from __future__ import print_function
|
||||
|
||||
import math
|
||||
import random
|
||||
import sys
|
||||
|
||||
from tensorflow.contrib.losses.python.losses import loss_ops
|
||||
from tensorflow.contrib.tensor_forest.python import constants
|
||||
from tensorflow.contrib.tensor_forest.python.ops import inference_ops
|
||||
from tensorflow.contrib.tensor_forest.python.ops import training_ops
|
||||
@ -429,8 +431,9 @@ class RandomForestGraphs(object):
|
||||
return math_ops.reduce_mean(math_ops.to_float(array_ops.pack(sizes)))
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def training_loss(self, features, labels):
|
||||
return math_ops.neg(self.average_size())
|
||||
def training_loss(self, features, labels, data_spec=None,
|
||||
name='training_loss'):
|
||||
return math_ops.neg(self.average_size(), name=name)
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def validation_loss(self, features, labels):
|
||||
@ -456,6 +459,63 @@ class RandomForestGraphs(object):
|
||||
return ForestStats(tree_stats, self.params)
|
||||
|
||||
|
||||
def one_hot_wrapper(num_classes, loss_fn):
|
||||
"""Some loss functions take one-hot labels."""
|
||||
def _loss(probs, targets):
|
||||
one_hot_labels = array_ops.one_hot(
|
||||
math_ops.to_int32(targets), num_classes,
|
||||
on_value=1., off_value=0., dtype=dtypes.float32)
|
||||
return loss_fn(probs, one_hot_labels)
|
||||
return _loss
|
||||
|
||||
|
||||
class TrainingLossForest(RandomForestGraphs):
|
||||
"""Random Forest that uses training loss as the termination criteria."""
|
||||
|
||||
def __init__(self, params, loss_fn=None, **kwargs):
|
||||
"""Initialize.
|
||||
|
||||
Args:
|
||||
params: Like RandomForestGraphs, a ForestHParams object.
|
||||
loss_fn: A function that takes probabilities and targets and returns
|
||||
a loss for each example.
|
||||
**kwargs: Keyword args to pass to superclass (RandomForestGraphs).
|
||||
"""
|
||||
self.loss_fn = loss_fn or one_hot_wrapper(params.num_classes,
|
||||
loss_ops.log_loss)
|
||||
self._loss = None
|
||||
super(TrainingLossForest, self).__init__(params, **kwargs)
|
||||
|
||||
def _get_loss(self, features, labels, data_spec=None):
|
||||
"""Constructs, caches, and returns the inference-based loss."""
|
||||
if self._loss is not None:
|
||||
return self._loss
|
||||
|
||||
def _average_loss():
|
||||
probs = self.inference_graph(features, data_spec=data_spec)
|
||||
return math_ops.reduce_sum(self.loss_fn(
|
||||
probs, labels)) / math_ops.to_float(
|
||||
array_ops.shape(features)[0])
|
||||
|
||||
self._loss = control_flow_ops.cond(
|
||||
self.average_size() > 0, _average_loss,
|
||||
lambda: constant_op.constant(sys.maxsize, dtype=dtypes.float32))
|
||||
|
||||
return self._loss
|
||||
|
||||
def training_graph(self, input_data, input_labels, data_spec=None,
|
||||
**kwargs):
|
||||
loss = self._get_loss(input_data, input_labels, data_spec=data_spec)
|
||||
with ops.control_dependencies([loss.op]):
|
||||
return super(TrainingLossForest, self).training_graph(
|
||||
input_data, input_labels, **kwargs)
|
||||
|
||||
def training_loss(self, features, labels, data_spec=None,
|
||||
name='training_loss'):
|
||||
return array_ops.identity(
|
||||
self._get_loss(features, labels, data_spec=data_spec), name=name)
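A usage sketch for the new class under stated assumptions (the `ForestHParams` fields shown are placeholders): by default the loss is one-hot log loss over `num_classes`; a custom `loss_fn(probs, targets)` can be passed instead.

```python
import tensorflow as tf
from tensorflow.contrib.tensor_forest.python import tensor_forest

params = tensor_forest.ForestHParams(
    num_classes=2, num_features=10, num_trees=10, max_nodes=1000).fill()

features = tf.placeholder(tf.float32, [None, 10])
labels = tf.placeholder(tf.float32, [None])

forest = tensor_forest.TrainingLossForest(params)   # or TrainingLossForest(params, loss_fn=...)
train_op = forest.training_graph(features, labels)
loss = forest.training_loss(features, labels)       # inference-based loss used for termination
```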
|
||||
|
||||
|
||||
class RandomTreeGraphs(object):
|
||||
"""Builds TF graphs for random tree training and inference."""
|
||||
|
||||
|
@ -12,6 +12,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
deps = [
|
||||
"//tensorflow/contrib/tfprof/python/tools/tfprof:model_analyzer",
|
||||
"//tensorflow/contrib/tfprof/python/tools/tfprof:tfprof_logger",
|
||||
],
|
||||
)
|
||||
|
@ -20,434 +20,9 @@ and measures system performance.
|
||||
4. Explore model based on name scope or graph structure.
|
||||
5. Selectively grouping/filtering/accounting/ordering ops.
|
||||
|
||||
### Interfaces
|
||||
tfprof can be used as a command-line interface (CLI) and a Python API.
The CLI is located in tensorflow/tools/tfprof.
The Python API is located in tensorflow/contrib/tfprof.
The tutorial is located in tensorflow/tools/tfprof/README.md.
|
||||
|
||||
[CLI Tutorials](#cli-tutorials):
|
||||
It supports interactive mode for exploration and single-shot mode for
|
||||
scripts. Outputs can be dumped to files or printed in the terminal.
|
||||
|
||||
Python API Tutorials: Python API is not released yet.
|
||||
|
||||
## CLI Tutorials
|
||||
|
||||
The tutorials are based on a 32-layer ResNet.
|
||||
TODO(xpan): Provide graph.pbtxt, model.ckpt, tfprof_log and run_meta download.
|
||||
|
||||
### Examples
|
||||
|
||||
1) Start `tfprof` command line tool
|
||||
|
||||
```shell
|
||||
# Build the tool.
|
||||
bazel build -c opt tensorflow/contrib/tfprof/...
|
||||
|
||||
# Help information, including detailed 'option' instructions.
|
||||
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof help
|
||||
#
|
||||
# The following commands will start tfprof interactive mode.
|
||||
#
|
||||
# Profile model shapes and parameters only.
|
||||
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof \
|
||||
--graph_path=/graph.pbtxt
|
||||
#
|
||||
# Additionally profile checkpoint statistics and values.
|
||||
# Use '-account_type_regexes _checkpoint_variables' to select
|
||||
# checkpoint tensors.
|
||||
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof \
|
||||
--graph_path=graph.pbtxt \
|
||||
--checkpoint_path=model.ckpt
|
||||
#
|
||||
# Additionally profile ops requested memory and timing.
|
||||
# See CLI Input Files section on generating run_meta file.
|
||||
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof \
|
||||
--graph_path=graph.pbtxt \
|
||||
--run_meta_path=run_meta \
|
||||
--checkpoint_path=model.ckpt
|
||||
#
|
||||
# tfprof_log is used to define customized op types and float ops.
|
||||
# Use tfprof_logger.write_op_log() to create tfprof_log.
|
||||
# See 11) in Examples section on generating tfprof_log file.
|
||||
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof \
|
||||
--graph_path=graph.pbtxt \
|
||||
--run_meta_path=run_meta \
|
||||
--op_log_path=tfprof_log \
|
||||
--checkpoint_path=model.ckpt
|
||||
```
|
||||
Note that `graph.pbtxt` is an ASCII text format.
|
||||
|
||||
2) Press enter to show the default options
|
||||
|
||||
```shell
|
||||
tfprof>
|
||||
tfprof>
|
||||
-max_depth 4
|
||||
-min_bytes 0
|
||||
-min_micros 0
|
||||
-min_params 0
|
||||
-min_float_ops 0
|
||||
-device_regexes .*
|
||||
-order_by name
|
||||
-account_type_regexes Variable
|
||||
-start_name_regexes .*
|
||||
-trim_name_regexes
|
||||
-show_name_regexes .*
|
||||
-hide_name_regexes IsVariableInitialized_[0-9]+,save\/.*,^zeros[0-9_]*
|
||||
-account_displayed_op_only false
|
||||
# supported select fields. Availability depends on --[run_meta|checkpoint|op_log]_path.
|
||||
# [bytes|micros|params|float_ops|num_hidden_ops|tensor_value|device|op_types]
|
||||
-select params
|
||||
-viz false
|
||||
-dump_to_file
|
||||
```
|
||||
|
||||
3) I want to see the `BatchNorm`'s gamma value in checkpoint.
|
||||
|
||||
```shell
|
||||
# Requires --graph_path, --checkpoint_path.
|
||||
tfprof> scope -show_name_regexes unit_1_0.*gamma -select tensor_value -max_depth 5
|
||||
_TFProfRoot ()
|
||||
unit_1_0/shared_activation/init_bn/gamma ()
|
||||
[1.80 2.10 2.06 1.91 2.26 1.86 1.81 1.37 1.78 1.85 1.96 1.54 2.04 2.34 2.22 1.99 ],
|
||||
unit_1_0/sub2/bn2/gamma ()
|
||||
[1.57 1.83 1.30 1.25 1.59 1.14 1.26 0.82 1.19 1.10 1.48 1.01 0.82 1.23 1.21 1.14 ],
|
||||
```
|
||||
|
||||
4) I want to see my checkpoint tensors’ shapes and numbers of parameters.
|
||||
|
||||
```shell
|
||||
# Requires --graph_path, --checkpoint_path.
|
||||
# Increase -max_depth to see all tensors.
|
||||
tfprof> scope -account_type_regexes _checkpoint_variables -select params -max_depth 4
|
||||
_TFProfRoot (--/930.58k params)
|
||||
global_step (0/0 params)
|
||||
init/init_conv/DW (3x3x3x16, 432/864 params)
|
||||
pool_logit/DW (64x10, 640/1.28k params)
|
||||
pool_logit/DW/Momentum (64x10, 640/640 params)
|
||||
pool_logit/biases (10, 10/20 params)
|
||||
pool_logit/biases/Momentum (10, 10/10 params)
|
||||
unit_last/final_bn/beta (64, 64/128 params)
|
||||
unit_last/final_bn/gamma (64, 64/128 params)
|
||||
unit_last/final_bn/moving_mean (64, 64/64 params)
|
||||
unit_last/final_bn/moving_variance (64, 64/64 params)
|
||||
```
|
||||
|
||||
5) I defined an op named ‘cost’ to calculate the loss. I want to know which of the
ops it depends on take a long time to run. Hint: use the ‘graph’ command to explore
graph dependencies.
|
||||
|
||||
```shell
|
||||
# Requires --graph_path, --run_meta_path.
|
||||
tfprof> graph -start_name_regexes cost.* -max_depth 100 -min_micros 10000 -select micros -account_type_regexes .*
|
||||
_TFProfRoot (0us/3.61sec)
|
||||
init/init_conv/Conv2D (11.75ms/3.10sec)
|
||||
random_shuffle_queue_DequeueMany (3.09sec/3.09sec)
|
||||
unit_1_0/sub2/conv2/Conv2D (74.14ms/3.19sec)
|
||||
unit_1_3/sub2/conv2/Conv2D (60.75ms/3.34sec)
|
||||
unit_2_4/sub2/conv2/Conv2D (73.58ms/3.54sec)
|
||||
unit_3_3/sub2/conv2/Conv2D (10.26ms/3.60sec)
|
||||
```
|
||||
|
||||
6) I want to know the expensive operations during the back propagation.
|
||||
Hint: TensorFlow prepends ‘gradients’ to your defined name scopes. Use the ‘scope’
|
||||
command to explore based on name scope hierarchies.
|
||||
|
||||
```shell
|
||||
# Requires --graph_path, --run_meta_path.
|
||||
tfprof> scope -start_name_regexes gradient.* -max_depth 100 -min_micros 20000 -select micros -account_type_regexes .*
|
||||
_TFProfRoot (0us/2.29sec)
|
||||
gradients/unit_1_0/sub1/conv1/Conv2D_grad/Conv2DBackpropFilter (54.96ms/54.96ms)
|
||||
gradients/unit_1_0/sub2/conv2/Conv2D_grad/Conv2DBackpropFilter (83.63ms/83.63ms)
|
||||
gradients/unit_1_1/sub1/conv1/Conv2D_grad/Conv2DBackpropFilter (99.25ms/99.25ms)
|
||||
gradients/unit_1_2/sub1/conv1/Conv2D_grad/Conv2DBackpropFilter (95.40ms/95.40ms)
|
||||
gradients/unit_1_2/sub2/conv2/Conv2D_grad/Conv2DBackpropFilter (99.83ms/99.83ms)
|
||||
gradients/unit_1_3/sub1/conv1/Conv2D_grad/Conv2DBackpropFilter (95.39ms/95.39ms)
|
||||
...
|
||||
```
|
||||
|
||||
7) Show the number of float operations in the model.
|
||||
Note: the float operations calculation depends on
1) op.RegisterStatistics. If an op doesn’t have RegisterStatistics defined,
its float operations cannot be counted; and
2) fully defined shapes, which are also necessary in order to calculate flops.
The float operations count is provided by a tensorflow::tfprof::OpLog logged from
the Python API.
|
||||
|
||||
```shell
|
||||
# Requires --graph_path, --op_log_path.
|
||||
tfprof> scope -min_float_ops 1 -max_depth 10 -select float_ops -account_type_regexes .*
|
||||
_TFProfRoot (0/17.63b flops)
|
||||
gradients/pool_logit/xw_plus_b/MatMul_grad/MatMul (163.84k/163.84k flops)
|
||||
gradients/pool_logit/xw_plus_b/MatMul_grad/MatMul_1 (163.84k/163.84k flops)
|
||||
init/init_conv/Conv2D (113.25m/113.25m flops)
|
||||
pool_logit/xw_plus_b (1.28k/165.12k flops)
|
||||
pool_logit/xw_plus_b/MatMul (163.84k/163.84k flops)
|
||||
unit_1_0/sub1/conv1/Conv2D (603.98m/603.98m flops)
|
||||
unit_1_0/sub2/conv2/Conv2D (603.98m/603.98m flops)
|
||||
unit_1_1/sub1/conv1/Conv2D (603.98m/603.98m flops)
|
||||
unit_1_1/sub2/conv2/Conv2D (603.98m/603.98m flops)
|
||||
...
|
||||
```
|
||||
|
||||
8) Show the number of parameters of all `tf.trainable_variables()` in the model.
|
||||
|
||||
```shell
|
||||
# Requires --graph_path --op_log_path.
|
||||
# store option for future commands.
|
||||
tfprof> set -account_type_regexes _trainable_variables
|
||||
tfprof> scope -max_depth 4 -select params
|
||||
_TFProfRoot (--/464.15k params)
|
||||
init/init_conv/DW (3x3x3x16, 432/432 params)
|
||||
pool_logit/DW (64x10, 640/640 params)
|
||||
pool_logit/biases (10, 10/10 params)
|
||||
unit_last/final_bn/beta (64, 64/64 params)
|
||||
unit_last/final_bn/gamma (64, 64/64 params)
|
||||
```
|
||||
|
||||
Where does “_trainable_variables” come from? It is from the OpLog file
generated by the write_op_log() Python API. write_op_log() helps users create some
common op types implicitly. Users can define their own op types and log them
through the write_op_log() API.
|
||||
|
||||
9) What if I’m lazy and don’t want to define op types? I have given my ops
well-defined names in my model’s code, and I want to use those names to select a
group of ops. Let’s try it!
|
||||
|
||||
```shell
|
||||
tfprof> set -account_type_regexes .*
|
||||
tfprof> scope -show_name_regexes unit_2_1.*DW -max_depth 100 -account_displayed_op_only
|
||||
_TFProfRoot (0/18.43k params)
|
||||
unit_2_1/sub1/conv1/DW (3x3x32x32, 9.22k/9.22k params)
|
||||
unit_2_1/sub2/conv2/DW (3x3x32x32, 9.22k/9.22k params)
|
||||
```
|
||||
|
||||
The above command allows you to filter ops that match specific names.
|
||||
`-account_displayed_op_only` asks tfprof to only account ops displayed
|
||||
in the terminal. Otherwise, tfprof accounts all ops matched by
|
||||
`-account_type_regexes` recursively even if they are hidden due to some
|
||||
options such as -max_depth.
|
||||
|
||||
10) TensorFlow has built-in op types. For example, built-in op type `Variable`
|
||||
seems to include the `Variable`s created by your model. However, be careful when
depending on it, because TensorFlow creates extra `Variable` ops implicitly and
the implicitly created ops can have the same prefix as the `Variable`s you
defined.
|
||||
|
||||
In the following example, extra `Variables` are created and “/Momentum” is
|
||||
appended to their names. This might cause your “model capacity” calculation
to come out wrong.
|
||||
|
||||
```shell
|
||||
tfprof> scope -account_type_regexes Variable -max_depth 4 -select params
|
||||
_TFProfRoot (--/930.58k params)
|
||||
global_step (1/1 params)
|
||||
init/init_conv/DW (3x3x3x16, 432/864 params)
|
||||
pool_logit/DW (64x10, 640/1.28k params)
|
||||
pool_logit/DW/Momentum (64x10, 640/640 params)
|
||||
pool_logit/biases (10, 10/20 params)
|
||||
pool_logit/biases/Momentum (10, 10/10 params)
|
||||
unit_last/final_bn/beta (64, 64/128 params)
|
||||
unit_last/final_bn/gamma (64, 64/128 params)
|
||||
unit_last/final_bn/moving_mean (64, 64/64 params)
|
||||
unit_last/final_bn/moving_variance (64, 64/64 params)
|
||||
```
|
||||
|
||||
|
||||
11) An example of defining an extra op type for ops using `OpLog`
|
||||
|
||||
First, in Python code, create an `OpLog` proto and add op type
|
||||
information to it:
|
||||
|
||||
```python
|
||||
|
||||
op_log = tfprof_log_pb2.OpLog()
|
||||
entry = op_log.log_entries.add()
|
||||
entry.name = 'pool_logit/DW'
|
||||
entry.types.append('pool_logit')
|
||||
entry = op_log.log_entries.add()
|
||||
entry.name = 'pool_logit/biases'
|
||||
# Alternatively:
|
||||
# var = tf.get_variable(xxx)
|
||||
# entry.name = var.op.name
|
||||
entry.types.append('pool_logit')
|
||||
```
|
||||
|
||||
Second, call write_op_log to write the OpLog proto.
|
||||
|
||||
```python
|
||||
tf.tfprof.tfprof_logger.write_op_log(sess.graph, /tmp/my_op_log_dir, op_log)
|
||||
```
|
||||
|
||||
Third, when starting the tfprof tool, specify
|
||||
"--op_log_path /tmp/my_op_log_dir/op_log"
|
||||
|
||||
```shell
|
||||
tfprof> scope -account_type_regexes pool_logit -max_depth 4 -select params
|
||||
_TFProfRoot (--/650 params)
|
||||
pool_logit/DW (64x10, 640/640 params)
|
||||
pool_logit/biases (10, 10/10 params)
|
||||
```
|
||||
|
||||
Note that when you call
|
||||
`tf.tfprof.tfprof_logger.write_op_log(...)`, the tool adds all `Variables`
|
||||
inside `tf.trainable_variables()` to `_trainable_variables`.
|
||||
|
||||
12) Run tfprof in one-shot mode and dump the result to a file.
|
||||
|
||||
```shell
|
||||
# Printed to stdout if --dump_to_file is not set.
|
||||
tfprof scope --graph_path /cns/ij-d/home/xpan/tfprof/graph.pbtxt \
|
||||
--max_depth 3 \
|
||||
--dump_to_file "/tmp/dump"
|
||||
Reading Files...
|
||||
Parsing GraphDef...
|
||||
Preparing Views...
|
||||
|
||||
cat /tmp/dump
|
||||
_TFProfRoot (--/930.58k params)
|
||||
global_step (0/0 params)
|
||||
pool_logit/DW (64x10, 640/1.28k params)
|
||||
pool_logit/biases (10, 10/20 params)
|
||||
```
|
||||
|
||||
13) Analyze how balanced the Variables are on parameter servers.
|
||||
|
||||
In this tutorial, I'm going to use a seq2seq model, which is split
across several GPUs on the workers and several parameter servers.
|
||||
|
||||
In tfprof, 'device' is an op_type. For example, if op1 and op2 are placed on
gpu0, they share an op_type called 'gpu0'.
|
||||
|
||||
```shell
|
||||
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof \
|
||||
--graph_path ~/tfprof/textsum/graph.pbtxt \
|
||||
--run_meta_path ~/tfprof/textsum/run_meta
|
||||
|
||||
# Looks like ps task 1 is holding about twice as many parameters as task 0.
|
||||
tfprof> scope -select device,params -account_type_regexes .*ps.*task:0.* -max_depth 1
|
||||
_TFProfRoot (--/25.81m params)
|
||||
tfprof> scope -select device,params -account_type_regexes .*ps.*task:1.* -max_depth 1
|
||||
_TFProfRoot (--/58.84m params)
|
||||
```
|
||||
|
||||
### CLI Input Files
|
||||
|
||||
The tfprof command line interface (CLI) loads dumped files from a TensorFlow model
and converts them into in-memory data structures. To use it, users need to specify
the locations of the dumped files. The following are the dumped files loaded
by tfprof:
|
||||
|
||||
<b>--graph_path:</b> GraphDef text file (required). Used to build in-memory
|
||||
representation of the model. For example, graph.pbtxt written by tf.Supervisor
|
||||
is a candidate. If you are not using tf.Supervisor, you can easily get GraphDef
|
||||
using tf.Graph.as_graph_def() or other API.
|
||||
|
||||
<b>--run_meta_path:</b> tensorflow::RunMetadata.
|
||||
Used to get the memory and time consumption of
|
||||
each op of the model. Users need to enable it. For example, the following code
|
||||
snippet writes a RunMetadata file:
|
||||
|
||||
```python
|
||||
run_options = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
# Once in a while, call run() with these options to get the RunMetadata.
|
||||
_ = self._sess.run(..., options=run_options, run_metadata=run_metadata)
|
||||
with gfile.Open(os.path.join(output_dir, "run_meta"), "w") as f:
|
||||
f.write(run_metadata.SerializeToString())
|
||||
```
|
||||
|
||||
<b>--op_log_path:</b>
|
||||
tensorflow::tfprof::OpLog. A proto used to provide extra op information
|
||||
for ops. By giving a group of ops a type name, users can easily aggregate the
|
||||
statistics for those ops without accidentally missing or including extra ops.
|
||||
tfprof exposes the following Python API to add op information and logging.
|
||||
|
||||
```python
|
||||
tf.contrib.tfprof.tfprof_logger.write_op_log(graph, log_dir, op_log=None)
|
||||
```
|
||||
|
||||
<b>--checkpoint_path:</b>
|
||||
TensorFlow checkpoint. It defines _checkpoint_variable op type. It also
|
||||
provides checkpointed tensors' values.
|
||||
|
||||
|
||||
## Design
|
||||
|
||||
|
||||
### In-memory representation
|
||||
|
||||
<b>Scope:</b> This representation organizes ops based on name scope hierarchy,
|
||||
similar to filesystem hierarchy. Hence, it is essentially a tree data structure.
|
||||
For example op1 with name “name1/name2” is a child of op2 with name “name1”.
|
||||
|
||||
<b>Graph:</b> The representation organizes ops based on op inputs. Hence it is
|
||||
a graph structure. The graph is a “directed acyclic graph” (hopefully), with
|
||||
direction from “output to input”. The direction is designed this way so that users
|
||||
can trace from “result” to its “sources”.
|
||||
|
||||
### Command line options
|
||||
|
||||
tfprof’s major goals are to measure system performance and quickly analyze
|
||||
model architectures. Hence, its commands and options should allow users to achieve
|
||||
these 2 goals easily.
|
||||
|
||||
<b>graph:</b> It is expected that users will mostly use the graph representation
to debug system performance. Hence, tfprof supports the graph command, which
uses the graph in-memory representation described above.

<b>scope:</b> It is expected that some users might want to explore their model
statistics using the name scope information they defined in their Python code.
Hence, tfprof supports the scope command, which uses the tree in-memory
representation.

<b>set:</b> It is used to store options so that the user doesn’t need to
re-type the same options again and again in follow-up commands. Note that
tfprof has a traditional terminal’s history and auto-complete support.

<b>help:</b> Print help information.

<b>Options:</b> Run “tfprof help” to get detailed explanations.

```python
"-max_depth",
"-min_bytes",
"-min_micros",
"-min_params",
"-min_float_ops",
"-order_by",
"-account_type_regexes",
"-start_name_regexes",
"-trim_name_regexes",
"-show_name_regexes",
"-hide_name_regexes",
"-account_displayed_op_only",
"-select",
"-viz", # Only supported for graph command.
"-dump_to_file",
```

A key design is that stats are aggregated from descendants up to ancestors.
`-account_type_regexes` is used to decide which ops’ stats are accounted. It
makes the decision based on op type. Usually set it to `.*` if no extra type
information is added to the ops using OpLog. Intuitively, only accounted ops are
displayed. `-min/max` and `-show/hide/trim/start` options are only used to
optionally display or hide ops based on their names and stats. However, they
don’t prevent tfprof from accounting the stats of hidden ops. Hence, the stats
of an op can be aggregated by its parent even if the op is hidden.
`-account_displayed_op_only` is an option to break this rule. When it is set,
only displayed ops are accounted.

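The same accounting options are exposed through the Python API added in this
change; a minimal sketch using `print_model_analysis` (the regex and option
values are illustrative):

```python
import tensorflow as tf

# Copy the predefined option set, then narrow what is accounted: only ops on
# parameter-server task 0, and only those that are actually displayed.
opts = dict(tf.contrib.tfprof.model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
opts['account_type_regexes'] = ['.*ps.*task:0.*']
opts['account_displayed_op_only'] = True
opts['select'] = ['device', 'params']

with tf.Session() as sess:
  # ... build or load the model here ...
  tf.contrib.tfprof.model_analyzer.print_model_analysis(
      sess.graph, tfprof_options=opts)
```

Running it prints the per-device parameter breakdown to stdout and returns a
TFProfNode proto, mirroring the CLI `scope` example above.
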
Regexes are all comma-separated, for example
`-show_name_regexes regex1.*,regex2.*`. It is designed this way because it is
convenient and commas are not expected to show up in op names.

`-order_by` is used to order displayed ops. Displayed ops at the same hierarchy
level (notice the printed indentation) are sorted according to `-order_by`.

## Future Work

* Load SummaryWriter event logs so that it can show the latest summary value.

* Better sorting and aggregation of outputs. Easier comprehension.

* Currently, shape information is based on `graph.pbtxt`. When the shape
  information is incomplete, tfprof ignores it. See if it can use `RunMetadata`
  and `Checkpoint` to complete shape information.

Enjoy!
|
@ -17,5 +17,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer
|
||||
from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_logger
|
||||
from tensorflow.python.util.all_util import make_all
|
||||
|
@ -3,14 +3,36 @@ licenses(["notice"]) # Apache 2.0
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
||||
|
||||
py_library(
|
||||
name = "model_analyzer",
|
||||
srcs = ["model_analyzer.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/tfprof/python/tools/tfprof:pywrap_tensorflow_print_model_analysis_lib",
|
||||
"//tensorflow/contrib/tfprof/python/tools/tfprof:tfprof_logger",
|
||||
"//tensorflow/tools/tfprof:protos_all_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "model_analyzer_test",
|
||||
srcs = ["model_analyzer_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":model_analyzer",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tfprof_logger",
|
||||
srcs = ["tfprof_logger.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/tfprof/tools/tfprof:protos_all_py",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/tools/tfprof:protos_all_py",
|
||||
],
|
||||
)
|
||||
|
||||
@ -20,7 +42,34 @@ tf_py_test(
|
||||
additional_deps = [
|
||||
":tfprof_logger",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/contrib/tfprof/tools/tfprof:protos_all_py",
|
||||
"//tensorflow/tools/tfprof:protos_all_py",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_wrap_cc(
|
||||
name = "pywrap_tensorflow_print_model_analysis_lib",
|
||||
srcs = ["pywrap_tensorflow_print_model_analysis.i"],
|
||||
swig_includes = [
|
||||
"//tensorflow/python:lib/core/strings.i",
|
||||
"//tensorflow/python:platform/base.i",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//tensorflow/tools/tfprof/internal:print_model_analysis_hdr",
|
||||
"//util/python:python_headers",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "print_model_analysis_test",
|
||||
srcs = ["print_model_analysis_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":pywrap_tensorflow_print_model_analysis_lib",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/tools/tfprof:protos_all_py",
|
||||
],
|
||||
)
|
||||
|
||||
|
188
tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
Normal file
188
tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
Normal file
@ -0,0 +1,188 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Model Analyzer.
|
||||
|
||||
Analyze model, including shape, params, time, memory, structure, etc.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.tfprof.python.tools.tfprof import pywrap_tensorflow_print_model_analysis_lib as print_mdl
|
||||
from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_logger
|
||||
from tensorflow.tools.tfprof import tfprof_options_pb2
|
||||
from tensorflow.tools.tfprof import tfprof_output_pb2
|
||||
|
||||
# pylint: disable=bad-whitespace
|
||||
# pylint: disable=bad-continuation
|
||||
# 2 example tfprof_options for print_model_analysis API.
|
||||
#
|
||||
# Show the parameter statistics of trainable variables.
|
||||
TRAINABLE_VARS_PARAMS_STAT_OPTIONS = {
|
||||
'max_depth': 10000,
|
||||
'min_bytes': 0,
|
||||
'min_micros': 0,
|
||||
'min_params': 0,
|
||||
'min_float_ops': 0,
|
||||
'device_regexes': ['.*'],
|
||||
'order_by': 'name',
|
||||
'account_type_regexes': [tfprof_logger.TRAINABLE_VARIABLES],
|
||||
'start_name_regexes': ['.*'],
|
||||
'trim_name_regexes': [],
|
||||
'show_name_regexes': ['.*'],
|
||||
'hide_name_regexes': [],
|
||||
'account_displayed_op_only': True,
|
||||
'select': ['params'],
|
||||
'viz': False,
|
||||
'dump_to_file': ''
|
||||
}
|
||||
|
||||
# Show the number float operations.
|
||||
FLOAT_OPS_OPTIONS = {
|
||||
'max_depth': 10000,
|
||||
'min_bytes': 0,
|
||||
'min_micros': 0,
|
||||
'min_params': 0,
|
||||
'min_float_ops': 1,
|
||||
'device_regexes': ['.*'],
|
||||
'order_by': 'float_ops',
|
||||
'account_type_regexes': ['.*'],
|
||||
'start_name_regexes': ['.*'],
|
||||
'trim_name_regexes': [],
|
||||
'show_name_regexes': ['.*'],
|
||||
'hide_name_regexes': [],
|
||||
'account_displayed_op_only': True,
|
||||
'select': ['float_ops'],
|
||||
'viz': False,
|
||||
'dump_to_file': ''
|
||||
}
|
||||
|
||||
# Show number of parameters on parameter server 0.
|
||||
# It is recommended to provide`run_meta` argument
|
||||
# to have complete device placement info.
|
||||
PRINT_PARAMS_ON_DEVICE = {
|
||||
'max_depth': 1,
|
||||
'min_bytes': 0,
|
||||
'min_micros': 0,
|
||||
'min_params': 0,
|
||||
'min_float_ops': 0,
|
||||
'device_regexes': ['.*'],
|
||||
'order_by': 'name',
|
||||
'account_type_regexes': ['.*ps.*task:0.*'],
|
||||
'start_name_regexes': ['.*'],
|
||||
'trim_name_regexes': [],
|
||||
'show_name_regexes': ['.*'],
|
||||
'hide_name_regexes': [],
|
||||
'account_displayed_op_only': False,
|
||||
'select': ['device', 'params'],
|
||||
'viz': False,
|
||||
'dump_to_file': ''
|
||||
}
|
||||
|
||||
# Show the timing stats and memory demands.
|
||||
PRINT_ALL_TIMING_MEMORY = {
|
||||
'max_depth': 10000,
|
||||
'min_bytes': 1, # Only >=1
|
||||
'min_micros': 1, # Only >=1
|
||||
'min_params': 0,
|
||||
'min_float_ops': 0,
|
||||
'device_regexes': ['.*'],
|
||||
'order_by': 'name',
|
||||
'account_type_regexes': ['.*'],
|
||||
'start_name_regexes': ['.*'],
|
||||
'trim_name_regexes': [],
|
||||
'show_name_regexes': ['.*'],
|
||||
'hide_name_regexes': [],
|
||||
'account_displayed_op_only': True,
|
||||
'select': ['micros', 'bytes'],
|
||||
'viz': False,
|
||||
'dump_to_file': ''
|
||||
}
|
||||
|
||||
# pylint: enable=bad-whitespace
|
||||
# pylint: enable=bad-continuation
|
||||
|
||||
|
||||
def print_model_analysis(graph,
|
||||
run_meta=None,
|
||||
op_log=None,
|
||||
tfprof_cmd='scope',
|
||||
tfprof_options=TRAINABLE_VARS_PARAMS_STAT_OPTIONS):
|
||||
"""Print model statistics.
|
||||
|
||||
Prints the model statistics to stdout. Also returns the results
|
||||
in a TFProfNode proto. See go/tfprof or run tfprof tool:
|
||||
'bazel run third_party/tensorflow/tools/tfprof help'
|
||||
|
||||
Examples:
|
||||
Show the parameter/shape statistics of tf.trainable_variables().
|
||||
print_model_analysis(sess.graph).
|
||||
|
||||
Show number of float ops. Only ops with RegisterStatistics defined
|
||||
are counted.
|
||||
show_float_op_opts = model_analyzer.FLOAT_OPS_OPTIONS
|
||||
print_model_analysis(sess.graph, tfprof_options=show_float_op_opts)
|
||||
|
||||
Args:
|
||||
graph: tf.Graph.
|
||||
run_meta: tensorflow::RunMetadata proto. When provided, also shows valid
|
||||
timing and memory information when 'select' option contains
|
||||
'micros' and 'bytes'.
|
||||
op_log: tensorflow::tfprof::OpLog proto. users can use this proto to
|
||||
group together ops and use a op_type to select the group.
|
||||
tfprof_cmd: string. Either 'scope' or 'graph'. 'scope' view organize
|
||||
ops using their name scopes. 'graph' view organize ops using
|
||||
their graph inputs.
|
||||
tfprof_options: See 'tfprof help' for details.
|
||||
Returns:
|
||||
TFProfNode proto. Side effect: a formatted output to stdout.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
op_log = tfprof_logger._merge_default_with_oplog(graph, op_log, run_meta)
|
||||
# pylint: enable=protected-access
|
||||
opts = tfprof_options_pb2.OptionsProto()
|
||||
opts.max_depth = tfprof_options['max_depth']
|
||||
opts.min_bytes = tfprof_options['min_bytes']
|
||||
opts.min_micros = tfprof_options['min_micros']
|
||||
opts.min_params = tfprof_options['min_params']
|
||||
opts.min_float_ops = tfprof_options['min_float_ops']
|
||||
for p in tfprof_options['device_regexes']:
|
||||
opts.device_regexes.append(p)
|
||||
opts.order_by = tfprof_options['order_by']
|
||||
for p in tfprof_options['account_type_regexes']:
|
||||
opts.account_type_regexes.append(p)
|
||||
for p in tfprof_options['start_name_regexes']:
|
||||
opts.start_name_regexes.append(p)
|
||||
for p in tfprof_options['trim_name_regexes']:
|
||||
opts.trim_name_regexes.append(p)
|
||||
for p in tfprof_options['show_name_regexes']:
|
||||
opts.show_name_regexes.append(p)
|
||||
for p in tfprof_options['hide_name_regexes']:
|
||||
opts.hide_name_regexes.append(p)
|
||||
opts.account_displayed_op_only = tfprof_options['account_displayed_op_only']
|
||||
for p in tfprof_options['select']:
|
||||
opts.select.append(p)
|
||||
opts.viz = tfprof_options['viz']
|
||||
opts.dump_to_file = tfprof_options['dump_to_file']
|
||||
|
||||
run_meta_str = run_meta.SerializeToString() if run_meta else b''
|
||||
op_log_str = op_log.SerializeToString() if op_log else b''
|
||||
|
||||
tfprof_node = tfprof_output_pb2.TFProfNode()
|
||||
tfprof_node.ParseFromString(
|
||||
print_mdl.PrintModelAnalysis(
|
||||
graph.as_graph_def().SerializeToString(), run_meta_str, op_log_str,
|
||||
tfprof_cmd.encode('utf-8'), opts.SerializeToString()))
|
||||
return tfprof_node
|
@ -0,0 +1,84 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class PrintModelAnalysisTest(tf.test.TestCase):
|
||||
|
||||
def _BuildSmallModel(self):
|
||||
image = tf.zeros([2, 6, 6, 3])
|
||||
kernel = tf.get_variable(
|
||||
'DW', [3, 3, 3, 6],
|
||||
tf.float32,
|
||||
initializer=tf.random_normal_initializer(stddev=0.001))
|
||||
x = tf.nn.conv2d(image, kernel, [1, 2, 2, 1], padding='SAME')
|
||||
kernel = tf.get_variable(
|
||||
'DW2', [2, 2, 6, 12],
|
||||
tf.float32,
|
||||
initializer=tf.random_normal_initializer(stddev=0.001))
|
||||
x = tf.nn.conv2d(x, kernel, [1, 2, 2, 1], padding='SAME')
|
||||
return x
|
||||
|
||||
def testDumpToFile(self):
|
||||
opts = tf.contrib.tfprof.model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS
|
||||
opts['dump_to_file'] = os.path.join(tf.test.get_temp_dir(), 'dump')
|
||||
|
||||
with tf.Session() as sess, tf.device('/cpu:0'):
|
||||
_ = self._BuildSmallModel()
|
||||
tf.contrib.tfprof.model_analyzer.print_model_analysis(
|
||||
sess.graph, tfprof_options=opts)
|
||||
|
||||
with tf.gfile.Open(opts['dump_to_file'], 'r') as f:
|
||||
self.assertEqual(u'_TFProfRoot (--/450 params)\n'
|
||||
' DW (3x3x3x6, 162/162 params)\n'
|
||||
' DW2 (2x2x6x12, 288/288 params)\n',
|
||||
f.read().decode('utf-8'))
|
||||
|
||||
def testSelectEverything(self):
|
||||
opts = tf.contrib.tfprof.model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS
|
||||
opts['dump_to_file'] = os.path.join(tf.test.get_temp_dir(), 'dump')
|
||||
opts['account_type_regexes'] = ['.*']
|
||||
opts['select'] = [
|
||||
'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device', 'op_types'
|
||||
]
|
||||
|
||||
with tf.Session() as sess, tf.device('/cpu:0'):
|
||||
x = self._BuildSmallModel()
|
||||
|
||||
sess.run(tf.initialize_all_variables())
|
||||
run_meta = tf.RunMetadata()
|
||||
_ = sess.run(x,
|
||||
options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
|
||||
run_metadata=run_meta)
|
||||
|
||||
tf.contrib.tfprof.model_analyzer.print_model_analysis(
|
||||
sess.graph, run_meta, tfprof_options=opts)
|
||||
|
||||
with tf.gfile.Open(opts['dump_to_file'], 'r') as f:
|
||||
# pylint: disable=line-too-long
|
||||
self.assertEqual(
|
||||
'_TFProfRoot (0/450 params, 0/10.44k flops, 0B/5.28KB, _kTFScopeParent)\n Conv2D (0/0 params, 5.83k/5.83k flops, 432B/432B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D)\n Conv2D_1 (0/0 params, 4.61k/4.61k flops, 384B/384B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D)\n DW (3x3x3x6, 162/162 params, 0/0 flops, 648B/1.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Variable|_trainable_variables)\n DW/Assign (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Assign)\n DW/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n DW/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Add)\n DW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|RandomStandardNormal)\n DW/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Const)\n DW/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Mul)\n DW/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Const)\n DW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Const)\n DW/read (0/0 params, 0/0 flops, 648B/648B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity)\n DW2 (2x2x6x12, 288/288 params, 0/0 flops, 1.15KB/2.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Variable|_trainable_variables)\n DW2/Assign (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Assign)\n DW2/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n DW2/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Add)\n DW2/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|RandomStandardNormal)\n DW2/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Const)\n DW2/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Mul)\n DW2/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Const)\n DW2/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Const)\n DW2/read (0/0 params, 0/0 flops, 1.15KB/1.15KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity)\n init (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|NoOp)\n zeros (0/0 params, 0/0 flops, 864B/864B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Const)\n',
|
||||
f.read().decode('utf-8'))
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
@ -0,0 +1,238 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""print_model_analysis test."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
from google.protobuf import text_format
|
||||
from tensorflow.contrib.tfprof.python.tools.tfprof import pywrap_tensorflow_print_model_analysis_lib as print_mdl
|
||||
from tensorflow.tools.tfprof import tfprof_options_pb2
|
||||
from tensorflow.tools.tfprof import tfprof_output_pb2
|
||||
|
||||
# pylint: disable=bad-whitespace
|
||||
# pylint: disable=bad-continuation
|
||||
TEST_OPTIONS = {
|
||||
'max_depth': 10000,
|
||||
'min_bytes': 0,
|
||||
'min_micros': 0,
|
||||
'min_params': 0,
|
||||
'min_float_ops': 0,
|
||||
'device_regexes': ['.*'],
|
||||
'order_by': 'name',
|
||||
'account_type_regexes': ['.*'],
|
||||
'start_name_regexes': ['.*'],
|
||||
'trim_name_regexes': [],
|
||||
'show_name_regexes': ['.*'],
|
||||
'hide_name_regexes': [],
|
||||
'account_displayed_op_only': True,
|
||||
'select': ['params'],
|
||||
'viz': False
|
||||
}
|
||||
|
||||
# pylint: enable=bad-whitespace
|
||||
# pylint: enable=bad-continuation
|
||||
|
||||
|
||||
class PrintModelAnalysisTest(tf.test.TestCase):
|
||||
|
||||
def _BuildSmallModel(self):
|
||||
image = tf.zeros([2, 6, 6, 3])
|
||||
kernel = tf.get_variable(
|
||||
'DW', [6, 6, 3, 6],
|
||||
tf.float32,
|
||||
initializer=tf.random_normal_initializer(stddev=0.001))
|
||||
x = tf.nn.conv2d(image, kernel, [1, 2, 2, 1], padding='SAME')
|
||||
return x
|
||||
|
||||
def testPrintModelAnalysis(self):
|
||||
opts = tfprof_options_pb2.OptionsProto()
|
||||
opts.max_depth = TEST_OPTIONS['max_depth']
|
||||
opts.min_bytes = TEST_OPTIONS['min_bytes']
|
||||
opts.min_micros = TEST_OPTIONS['min_micros']
|
||||
opts.min_params = TEST_OPTIONS['min_params']
|
||||
opts.min_float_ops = TEST_OPTIONS['min_float_ops']
|
||||
for p in TEST_OPTIONS['device_regexes']:
|
||||
opts.device_regexes.append(p)
|
||||
opts.order_by = TEST_OPTIONS['order_by']
|
||||
for p in TEST_OPTIONS['account_type_regexes']:
|
||||
opts.account_type_regexes.append(p)
|
||||
for p in TEST_OPTIONS['start_name_regexes']:
|
||||
opts.start_name_regexes.append(p)
|
||||
for p in TEST_OPTIONS['trim_name_regexes']:
|
||||
opts.trim_name_regexes.append(p)
|
||||
for p in TEST_OPTIONS['show_name_regexes']:
|
||||
opts.show_name_regexes.append(p)
|
||||
for p in TEST_OPTIONS['hide_name_regexes']:
|
||||
opts.hide_name_regexes.append(p)
|
||||
opts.account_displayed_op_only = TEST_OPTIONS['account_displayed_op_only']
|
||||
for p in TEST_OPTIONS['select']:
|
||||
opts.select.append(p)
|
||||
opts.viz = TEST_OPTIONS['viz']
|
||||
|
||||
with tf.Session() as sess, tf.device('/cpu:0'):
|
||||
_ = self._BuildSmallModel()
|
||||
tfprof_pb = tfprof_output_pb2.TFProfNode()
|
||||
tfprof_pb.ParseFromString(
|
||||
print_mdl.PrintModelAnalysis(sess.graph.as_graph_def(
|
||||
).SerializeToString(), b'', b'', b'scope', opts.SerializeToString()))
|
||||
|
||||
expected_pb = tfprof_output_pb2.TFProfNode()
|
||||
text_format.Merge(r"""name: "_TFProfRoot"
|
||||
exec_micros: 0
|
||||
requested_bytes: 0
|
||||
total_exec_micros: 0
|
||||
total_requested_bytes: 0
|
||||
total_parameters: 648
|
||||
children {
|
||||
name: "Conv2D"
|
||||
exec_micros: 0
|
||||
requested_bytes: 0
|
||||
total_exec_micros: 0
|
||||
total_requested_bytes: 0
|
||||
total_parameters: 0
|
||||
device: "/device:CPU:0"
|
||||
float_ops: 0
|
||||
total_float_ops: 0
|
||||
}
|
||||
children {
|
||||
name: "DW"
|
||||
exec_micros: 0
|
||||
requested_bytes: 0
|
||||
parameters: 648
|
||||
total_exec_micros: 0
|
||||
total_requested_bytes: 0
|
||||
total_parameters: 648
|
||||
device: "/device:CPU:0"
|
||||
children {
|
||||
name: "DW/Assign"
|
||||
exec_micros: 0
|
||||
requested_bytes: 0
|
||||
total_exec_micros: 0
|
||||
total_requested_bytes: 0
|
||||
total_parameters: 0
|
||||
device: "/device:CPU:0"
|
||||
float_ops: 0
|
||||
total_float_ops: 0
|
||||
}
|
||||
children {
|
||||
name: "DW/Initializer"
|
||||
exec_micros: 0
|
||||
requested_bytes: 0
|
||||
total_exec_micros: 0
|
||||
total_requested_bytes: 0
|
||||
total_parameters: 0
|
||||
children {
|
||||
name: "DW/Initializer/random_normal"
|
||||
exec_micros: 0
|
||||
requested_bytes: 0
|
||||
total_exec_micros: 0
|
||||
total_requested_bytes: 0
|
||||
total_parameters: 0
|
||||
device: "/device:CPU:0"
|
||||
children {
|
||||
name: "DW/Initializer/random_normal/RandomStandardNormal"
|
||||
exec_micros: 0
|
||||
requested_bytes: 0
|
||||
total_exec_micros: 0
|
||||
total_requested_bytes: 0
|
||||
total_parameters: 0
|
||||
device: "/device:CPU:0"
|
||||
float_ops: 0
|
||||
total_float_ops: 0
|
||||
}
|
||||
children {
|
||||
name: "DW/Initializer/random_normal/mean"
|
||||
exec_micros: 0
|
||||
requested_bytes: 0
|
||||
total_exec_micros: 0
|
||||
total_requested_bytes: 0
|
||||
total_parameters: 0
|
||||
device: "/device:CPU:0"
|
||||
float_ops: 0
|
||||
total_float_ops: 0
|
||||
}
|
||||
children {
|
||||
name: "DW/Initializer/random_normal/mul"
|
||||
exec_micros: 0
|
||||
requested_bytes: 0
|
||||
total_exec_micros: 0
|
||||
total_requested_bytes: 0
|
||||
total_parameters: 0
|
||||
device: "/device:CPU:0"
|
||||
float_ops: 0
|
||||
total_float_ops: 0
|
||||
}
|
||||
children {
|
||||
name: "DW/Initializer/random_normal/shape"
|
||||
exec_micros: 0
|
||||
requested_bytes: 0
|
||||
total_exec_micros: 0
|
||||
total_requested_bytes: 0
|
||||
total_parameters: 0
|
||||
device: "/device:CPU:0"
|
||||
float_ops: 0
|
||||
total_float_ops: 0
|
||||
}
|
||||
children {
|
||||
name: "DW/Initializer/random_normal/stddev"
|
||||
exec_micros: 0
|
||||
requested_bytes: 0
|
||||
total_exec_micros: 0
|
||||
total_requested_bytes: 0
|
||||
total_parameters: 0
|
||||
device: "/device:CPU:0"
|
||||
float_ops: 0
|
||||
total_float_ops: 0
|
||||
}
|
||||
float_ops: 0
|
||||
total_float_ops: 0
|
||||
}
|
||||
float_ops: 0
|
||||
total_float_ops: 0
|
||||
}
|
||||
children {
|
||||
name: "DW/read"
|
||||
exec_micros: 0
|
||||
requested_bytes: 0
|
||||
total_exec_micros: 0
|
||||
total_requested_bytes: 0
|
||||
total_parameters: 0
|
||||
device: "/device:CPU:0"
|
||||
float_ops: 0
|
||||
total_float_ops: 0
|
||||
}
|
||||
float_ops: 0
|
||||
total_float_ops: 0
|
||||
}
|
||||
children {
|
||||
name: "zeros"
|
||||
exec_micros: 0
|
||||
requested_bytes: 0
|
||||
total_exec_micros: 0
|
||||
total_requested_bytes: 0
|
||||
total_parameters: 0
|
||||
device: "/device:CPU:0"
|
||||
float_ops: 0
|
||||
total_float_ops: 0
|
||||
}
|
||||
float_ops: 0
|
||||
total_float_ops: 0""", expected_pb)
|
||||
self.assertEqual(expected_pb, tfprof_pb)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
@ -0,0 +1,43 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%include "tensorflow/python/lib/core/strings.i"
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
|
||||
%{
|
||||
#include "tensorflow/tools/tfprof/internal/print_model_analysis.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
%}
|
||||
|
||||
%typemap(typecheck) const string & = char *;
|
||||
%typemap(in) const string& (string temp) {
|
||||
if (!_PyObjAs<string>($input, &temp)) return NULL;
|
||||
$1 = &temp;
|
||||
}
|
||||
%typemap(out) const string& {
|
||||
$result = PyString_FromStringAndSize($1->data(), $1->size());
|
||||
}
|
||||
%apply const string & {string &};
|
||||
%apply const string & {string *};
|
||||
|
||||
%ignoreall
|
||||
|
||||
%unignore tensorflow;
|
||||
%unignore tensorflow::tfprof;
|
||||
%unignore tensorflow::tfprof::PrintModelAnalysis;
|
||||
|
||||
%include "tensorflow/tools/tfprof/internal/print_model_analysis.h"
|
||||
|
||||
%unignoreall
|
@ -24,8 +24,8 @@ import os
|
||||
import sys
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.tfprof.tools.tfprof import tfprof_log_pb2
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.tools.tfprof import tfprof_log_pb2
|
||||
|
||||
TRAINABLE_VARIABLES = '_trainable_variables'
|
||||
REGISTERED_FLOP_STATS = 'flops'
|
||||
@ -85,7 +85,7 @@ def _get_logged_ops(graph, run_meta=None):
|
||||
if node.name not in logged_ops:
|
||||
entry = tfprof_log_pb2.OpLogEntry()
|
||||
entry.name = node.name
|
||||
entry.float_ops = stats.value
|
||||
entry.float_ops = int(stats.value)
|
||||
logged_ops[entry.name] = entry
|
||||
|
||||
for v in graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
|
||||
|
@ -32,8 +32,9 @@ like to store state in the forward direction across segments of an example.
|
||||
To resample data with replacement on a per-example basis, use
|
||||
['rejection_sample'](#rejection_sample) or
|
||||
['resample_at_rate'](#resample_at_rate). For `rejection_sample`, provide
|
||||
a boolean Tensor describing whether to accept or reject. For `resample_at_rate`,
|
||||
providing the desired rate for each example. If you wish to specify relative
|
||||
a boolean Tensor describing whether to accept or reject. Resulting batch sizes
|
||||
are always the same. For `resample_at_rate`, provide the desired rate for each
|
||||
example. Resulting batch sizes may vary. If you wish to specify relative
|
||||
rates, rather than absolute ones, use ['weighted_resample'](#weighted_resample)
|
||||
(which also returns the actual resampling rate used for each output example).
|
||||
|
||||
|
@ -16,8 +16,10 @@ limitations under the License.
|
||||
|
||||
#include <unordered_set>
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/kernels/immutable_constant_op.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -45,13 +47,27 @@ class NodeConverter {
|
||||
const DataType tensor_data_type = tensor_proto.dtype();
|
||||
const TensorShapeProto tensor_shape = tensor_proto.tensor_shape();
|
||||
|
||||
// Check that the tensor type is POD, only these types are supported for
|
||||
// memmapping.
|
||||
// DataType enum is explicitly converted to int to avoid errors with passing
|
||||
// enum type are a parameter type to std::unordered_set.
|
||||
static std::unordered_set<int> supported_types{
|
||||
#define TYPE_FOR_SET(type) static_cast<int>(DataTypeToEnum<type>::value),
|
||||
TF_CALL_POD_TYPES(TYPE_FOR_SET)
|
||||
#undef ADD_TYPE
|
||||
};
|
||||
|
||||
if (supported_types.count(static_cast<int>(tensor_data_type)) == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Create Tensor from value and write it in memmapped format.
|
||||
Tensor parsed(tensor_proto.dtype());
|
||||
if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
|
||||
return errors::InvalidArgument("Cannot parse tensor from proto: ",
|
||||
tensor_proto.DebugString());
|
||||
}
|
||||
if (parsed.TotalBytes() < min_conversion_size_bytes) {
|
||||
if (parsed.TotalBytes() < static_cast<size_t>(min_conversion_size_bytes)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -26,6 +26,15 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
bool GraphHasImmutableConstNodes(const GraphDef& graph_def) {
|
||||
for (const auto& node : graph_def.node()) {
|
||||
if (node.op() == "ImmutableConst") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
TEST(ConvertGraphdefMemmappedFormatTest, ConvertModel) {
|
||||
const string dir = testing::TmpDir();
|
||||
const string filename_pb = io::JoinPath(dir, "graphdef.pb");
|
||||
@ -69,6 +78,7 @@ TEST(ConvertGraphdefMemmappedFormatTest, ConvertModel) {
|
||||
TF_ASSERT_OK(ReadBinaryProto(
|
||||
&memmapped_env, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
|
||||
&loaded_graph_def));
|
||||
ASSERT_TRUE(GraphHasImmutableConstNodes(loaded_graph_def));
|
||||
|
||||
TF_ASSERT_OK(session->Create(loaded_graph_def)) << "Can't create test graph";
|
||||
std::vector<Tensor> outputs;
|
||||
@ -79,5 +89,48 @@ TEST(ConvertGraphdefMemmappedFormatTest, ConvertModel) {
|
||||
EXPECT_EQ(outputs.front().flat<float>()(2), 2.0f * 3.0f * kTensorHeight);
|
||||
}
|
||||
|
||||
TEST(ConvertGraphdefMemmappedFormatTest, NotSupportedTypesConvert) {
|
||||
// Create a graph with strings.
|
||||
const string dir = testing::TmpDir();
|
||||
const string filename_pb = io::JoinPath(dir, "string_graphdef.pb");
|
||||
|
||||
constexpr int kTensorWidth = 4000;
|
||||
constexpr int kTensorHeight = 100;
|
||||
const TensorShape kTestTensorShape({kTensorWidth, kTensorHeight});
|
||||
Tensor test_tensor1(DT_STRING, kTestTensorShape);
|
||||
test::FillFn<string>(&test_tensor1, [](int) -> string { return "ABC"; });
|
||||
|
||||
Tensor test_tensor2(DT_STRING, kTestTensorShape);
|
||||
test::FillFn<string>(&test_tensor2, [](int) -> string { return "XYZ"; });
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
ops::Output m = ops::Add(root, test_tensor1, test_tensor2);
|
||||
const string result_name = m.node()->name();
|
||||
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
string graph_def_serialized;
|
||||
graph_def.SerializeToString(&graph_def_serialized);
|
||||
TF_ASSERT_OK(
|
||||
WriteStringToFile(Env::Default(), filename_pb, graph_def_serialized));
|
||||
|
||||
const string filename_mmap = io::JoinPath(dir, "string_graphdef.mmap");
|
||||
TF_ASSERT_OK(ConvertConstantsToImmutable(filename_pb, filename_mmap, 1000));
|
||||
|
||||
// Create and initialize MemmappedEnv from the converted file.
|
||||
MemmappedEnv memmapped_env(Env::Default());
|
||||
TF_ASSERT_OK(memmapped_env.InitializeFromFile(filename_mmap));
|
||||
|
||||
// Load the graph and run calculations.
|
||||
SessionOptions session_options;
|
||||
session_options.env = &memmapped_env;
|
||||
std::unique_ptr<Session> session(NewSession(session_options));
|
||||
ASSERT_TRUE(session != nullptr) << "Failed to create session";
|
||||
GraphDef loaded_graph_def;
|
||||
TF_ASSERT_OK(ReadBinaryProto(
|
||||
&memmapped_env, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
|
||||
&loaded_graph_def));
|
||||
ASSERT_FALSE(GraphHasImmutableConstNodes(loaded_graph_def));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -164,6 +164,8 @@ cc_library(
|
||||
"lib/core/threadpool.h",
|
||||
"lib/gtl/array_slice.h",
|
||||
"lib/gtl/cleanup.h",
|
||||
"lib/gtl/flatmap.h",
|
||||
"lib/gtl/flatset.h",
|
||||
"lib/gtl/inlined_vector.h",
|
||||
"lib/gtl/priority_queue_util.h",
|
||||
"lib/hash/crc32c.h",
|
||||
@ -178,7 +180,6 @@ cc_library(
|
||||
"lib/io/table.h",
|
||||
"lib/io/table_builder.h",
|
||||
"lib/io/table_options.h",
|
||||
"lib/jpeg/jpeg_mem.h",
|
||||
"lib/math/math_util.h",
|
||||
"lib/monitoring/collected_metrics.h",
|
||||
"lib/monitoring/collection_registry.h",
|
||||
@ -220,6 +221,13 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "jpeg",
|
||||
hdrs = ["lib/jpeg/jpeg_mem.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":jpeg_internal"],
|
||||
)
|
||||
|
||||
# Test support library needed for all tests
|
||||
# This is currently public, but may be made internal in the
|
||||
# future. Try to avoid depending on it.
|
||||
@ -521,6 +529,7 @@ cc_library(
|
||||
"//tensorflow/core/kernels:control_flow_ops",
|
||||
"//tensorflow/core/kernels:ctc_ops",
|
||||
"//tensorflow/core/kernels:data_flow",
|
||||
"//tensorflow/core/kernels:fake_quant_ops",
|
||||
"//tensorflow/core/kernels:function_ops",
|
||||
"//tensorflow/core/kernels:image",
|
||||
"//tensorflow/core/kernels:io",
|
||||
@ -970,6 +979,7 @@ cc_library(
|
||||
],
|
||||
exclude = [
|
||||
"**/*test*",
|
||||
"lib/jpeg/**/*",
|
||||
"platform/**/cuda.h",
|
||||
"platform/**/stream_executor.h",
|
||||
"platform/load_library.cc",
|
||||
@ -986,6 +996,7 @@ cc_library(
|
||||
],
|
||||
exclude = [
|
||||
"**/*test*",
|
||||
"lib/jpeg/**/*",
|
||||
"platform/**/cuda.h",
|
||||
"platform/**/stream_executor.h",
|
||||
],
|
||||
@ -1019,7 +1030,6 @@ cc_library(
|
||||
"lib/io/zlib_compression_options.h",
|
||||
"lib/io/zlib_inputstream.h",
|
||||
"lib/io/zlib_outputbuffer.h",
|
||||
"lib/jpeg/jpeg_handle.h",
|
||||
"lib/png/png_io.h",
|
||||
"lib/random/random.h",
|
||||
"lib/random/random_distributions.h",
|
||||
@ -1048,6 +1058,26 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "jpeg_internal",
|
||||
srcs = glob(
|
||||
[
|
||||
"lib/jpeg/*h",
|
||||
"lib/jpeg/*.cc",
|
||||
],
|
||||
exclude = [
|
||||
"**/*test*",
|
||||
],
|
||||
),
|
||||
hdrs = ["lib/jpeg/jpeg_handle.h"],
|
||||
copts = tf_copts(),
|
||||
linkopts = ["-ldl"],
|
||||
deps = [
|
||||
":lib",
|
||||
"//tensorflow/core/platform/default/build_config:jpeg",
|
||||
],
|
||||
)
|
||||
|
||||
proto_text_hdrs_and_srcs = tf_generate_proto_text_sources(
|
||||
name = "proto_text_srcs_all",
|
||||
srcs = tf_proto_text_protos_relative(),
|
||||
@ -1149,83 +1179,6 @@ cc_header_only_library(
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "framework_headers",
|
||||
srcs = [
|
||||
"framework/allocator.h",
|
||||
"framework/attr_value_util.h",
|
||||
"framework/bfloat16.h",
|
||||
"framework/cancellation.h",
|
||||
"framework/control_flow.h",
|
||||
"framework/device_base.h",
|
||||
"framework/function.h",
|
||||
"framework/kernel_def_builder.h",
|
||||
"framework/node_def_util.h",
|
||||
"framework/numeric_types.h",
|
||||
"framework/op.h",
|
||||
"framework/op_def_builder.h",
|
||||
"framework/op_def_util.h",
|
||||
"framework/op_kernel.h",
|
||||
"framework/partial_tensor_shape.h",
|
||||
"framework/register_types.h",
|
||||
"framework/rendezvous.h",
|
||||
"framework/selective_registration.h",
|
||||
"framework/session_state.h",
|
||||
"framework/shape_inference.h",
|
||||
"framework/tensor.h",
|
||||
"framework/tensor_reference.h",
|
||||
"framework/tensor_shape.h",
|
||||
"framework/tensor_types.h",
|
||||
"framework/tracking_allocator.h",
|
||||
"framework/type_traits.h",
|
||||
"framework/types.h",
|
||||
"framework/unique_tensor_references.h",
|
||||
"lib/core/errors.h",
|
||||
"lib/core/notification.h",
|
||||
"lib/core/refcount.h",
|
||||
"lib/core/status.h",
|
||||
"lib/core/stringpiece.h",
|
||||
"lib/core/threadpool.h",
|
||||
"lib/gtl/array_slice.h",
|
||||
"lib/gtl/array_slice_internal.h",
|
||||
"lib/gtl/inlined_vector.h",
|
||||
"lib/gtl/manual_constructor.h",
|
||||
"lib/hash/hash.h",
|
||||
"lib/strings/numbers.h",
|
||||
"lib/strings/str_util.h",
|
||||
"lib/strings/strcat.h",
|
||||
"platform/cpu_info.h",
|
||||
"platform/default/dynamic_annotations.h",
|
||||
"platform/default/integral_types.h",
|
||||
"platform/default/logging.h",
|
||||
"platform/default/mutex.h",
|
||||
"platform/default/notification.h",
|
||||
"platform/default/protobuf.h",
|
||||
"platform/default/thread_annotations.h",
|
||||
"platform/dynamic_annotations.h",
|
||||
"platform/env.h",
|
||||
"platform/file_statistics.h",
|
||||
"platform/file_system.h",
|
||||
"platform/fingerprint.h",
|
||||
"platform/logging.h",
|
||||
"platform/macros.h",
|
||||
"platform/mem.h",
|
||||
"platform/mutex.h",
|
||||
"platform/net.h",
|
||||
"platform/notification.h",
|
||||
"platform/platform.h",
|
||||
"platform/prefetch.h",
|
||||
"platform/protobuf.h",
|
||||
"platform/strong_hash.h",
|
||||
"platform/thread_annotations.h",
|
||||
"platform/types.h",
|
||||
"public/session.h",
|
||||
"public/session_options.h",
|
||||
"public/version.h",
|
||||
"util/device_name_utils.h",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "stream_executor",
|
||||
srcs = tf_additional_stream_executor_srcs(),
|
||||
@ -1316,7 +1269,7 @@ cc_library(
|
||||
"platform/regexp.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/contrib/tfprof:__subpackages__",
|
||||
"//tensorflow/tools/tfprof:__subpackages__",
|
||||
],
|
||||
deps = [":lib_internal"],
|
||||
)
|
||||
@ -1326,11 +1279,13 @@ tf_cuda_library(
|
||||
srcs = ["common_runtime/direct_session.cc"],
|
||||
hdrs = ["common_runtime/direct_session.h"],
|
||||
copts = tf_copts(),
|
||||
cuda_deps = [
|
||||
":gpu_tracer",
|
||||
],
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
":core_cpu_internal",
|
||||
":framework",
|
||||
":gpu_tracer",
|
||||
":lib",
|
||||
":lib_internal",
|
||||
":proto_text",
|
||||
@ -1496,6 +1451,8 @@ tf_cc_tests(
|
||||
"lib/gtl/array_slice_test.cc",
|
||||
"lib/gtl/cleanup_test.cc",
|
||||
"lib/gtl/edit_distance_test.cc",
|
||||
"lib/gtl/flatmap_test.cc",
|
||||
"lib/gtl/flatset_test.cc",
|
||||
"lib/gtl/inlined_vector_test.cc",
|
||||
"lib/gtl/int_type_test.cc",
|
||||
"lib/gtl/iterator_range_test.cc",
|
||||
@ -1582,6 +1539,8 @@ cc_test(
|
||||
srcs = ["lib/jpeg/jpeg_mem_unittest.cc"],
|
||||
data = glob(["lib/jpeg/testdata/*.jpg"]),
|
||||
deps = [
|
||||
":jpeg",
|
||||
":jpeg_internal",
|
||||
":lib",
|
||||
":lib_internal",
|
||||
":test",
|
||||
|
@ -23,7 +23,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/executor.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_tracer.h"
|
||||
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
||||
#include "tensorflow/core/common_runtime/memory_types.h"
|
||||
#include "tensorflow/core/common_runtime/simple_placer.h"
|
||||
@ -57,6 +56,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_tracer.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
@ -453,12 +456,14 @@ Status DirectSession::Run(const RunOptions& run_options,
|
||||
args.stats_collector = run_state.collector.get();
|
||||
}
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
std::unique_ptr<GPUTracer> tracer;
|
||||
if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
|
||||
tracer.reset(CreateGPUTracer());
|
||||
// tracer will be NULL on non-GPU platforms.
|
||||
if (tracer) tracer->Start();
|
||||
}
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
for (const auto& item : executors_and_keys->items) {
|
||||
item.executor->RunAsync(args, barrier->Get());
|
||||
@ -468,10 +473,12 @@ Status DirectSession::Run(const RunOptions& run_options,
|
||||
? run_options.timeout_in_ms()
|
||||
: operation_timeout_in_ms_);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
if (tracer) {
|
||||
tracer->Stop();
|
||||
tracer->Collect(args.stats_collector);
|
||||
}
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
{
|
||||
mutex_lock l(run_state.mu_);
|
||||
@ -840,10 +847,11 @@ Status DirectSession::GetOrCreateExecutors(
|
||||
std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
|
||||
std::sort(tn_sorted.begin(), tn_sorted.end());
|
||||
|
||||
const string key = strings::StrCat(str_util::Join(inputs_sorted, ","), "->",
|
||||
str_util::Join(outputs_sorted, ","), "/",
|
||||
str_util::Join(tn_sorted, ","), "/",
|
||||
run_state_args->is_partial_run);
|
||||
const string key = strings::StrCat(
|
||||
str_util::Join(inputs_sorted, ","), "->",
|
||||
str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
|
||||
"/", run_state_args->is_partial_run, "/",
|
||||
SummarizeDebugTensorWatches(run_state_args->debug_tensor_watches));
|
||||
|
||||
// Set the handle.
|
||||
run_state_args->handle =
|
||||
@ -938,7 +946,7 @@ Status DirectSession::GetOrCreateExecutors(
|
||||
partition_graph = iter->second.release();
|
||||
optimizer.Optimize(lib, options_.env, device, &partition_graph);
|
||||
|
||||
// EXPERIMENTAL: tfdb inserts debug nodes (i.e., probes) to the graph
|
||||
// EXPERIMENTAL: tfdbg inserts debug nodes (i.e., probes) to the graph
|
||||
if (!run_state_args->debug_tensor_watches.empty()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
DebugNodeInserter::InsertNodes(run_state_args->debug_tensor_watches,
|
||||
|
@ -291,7 +291,7 @@ class DirectSession : public Session {
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);
|
||||
|
||||
// EXPERIMENTAL: debugger (tfdb) related
|
||||
// EXPERIMENTAL: debugger (tfdbg) related
|
||||
friend class DebugGateway;
|
||||
};
|
||||
|
||||
|
@ -222,7 +222,7 @@ typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
|
||||
class ExecutorImpl : public Executor {
|
||||
public:
|
||||
ExecutorImpl(const LocalExecutorParams& p, const Graph* g)
|
||||
: params_(p), graph_(g), initial_pending_counts_(graph_->num_node_ids()) {
|
||||
: params_(p), graph_(g) {
|
||||
CHECK(p.create_kernel != nullptr);
|
||||
CHECK(p.delete_kernel != nullptr);
|
||||
}
|
||||
@ -231,6 +231,7 @@ class ExecutorImpl : public Executor {
|
||||
for (int i = 0; i < graph_->num_node_ids(); i++) {
|
||||
params_.delete_kernel(nodes_[i].kernel);
|
||||
}
|
||||
delete[] frame_local_ids_;
|
||||
delete[] nodes_;
|
||||
delete graph_;
|
||||
}
|
||||
@ -256,13 +257,39 @@ class ExecutorImpl : public Executor {
|
||||
private:
|
||||
friend class ExecutorState;
|
||||
|
||||
static void InitializePending(const Graph* graph, PendingCounts* counts);
|
||||
struct ControlFlowInfo {
|
||||
std::unordered_map<string, int> frame_name_to_size;
|
||||
std::vector<string> frame_names;
|
||||
};
|
||||
|
||||
struct FrameInfo {
|
||||
// The total number of inputs to a frame.
|
||||
int input_count;
|
||||
|
||||
// The total number of input tensors of a frame.
|
||||
// == sum(nodes[*].num_inputs()) where nodes are the nodes in the frame.
|
||||
int total_inputs;
|
||||
|
||||
// Each frame has its own PendingCounts only for the nodes in the frame.
|
||||
PendingCounts* pending_counts; // Owned
|
||||
|
||||
// The nodes in a frame. Used only for debugging.
|
||||
std::vector<const Node*>* nodes; // Owned
|
||||
|
||||
~FrameInfo() {
|
||||
delete pending_counts;
|
||||
delete nodes;
|
||||
}
|
||||
};
|
||||
|
||||
static Status BuildControlFlowInfo(const Graph* graph,
|
||||
ControlFlowInfo* cf_info);
|
||||
void InitializePending(const Graph* graph, const ControlFlowInfo& cf_info);
|
||||
|
||||
// Owned.
|
||||
LocalExecutorParams params_;
|
||||
const Graph* graph_;
|
||||
NodeItem* nodes_ = nullptr; // array of size "graph_.num_node_ids()"
|
||||
int total_input_tensors_ = 0; // == sum(nodes_[*].num_inputs())
|
||||
int total_output_tensors_ = 0; // == sum(nodes_[*].num_outputs())
|
||||
|
||||
// A cached value of params_
|
||||
@ -271,14 +298,17 @@ class ExecutorImpl : public Executor {
|
||||
// Root nodes (with no in edges) that should form the initial ready queue
|
||||
std::vector<const Node*> root_nodes_;
|
||||
|
||||
PendingCounts initial_pending_counts_;
|
||||
|
||||
// The number of inputs for each frame in this graph. This is static
|
||||
// information of the graph.
|
||||
std::unordered_map<string, int> frame_input_count_;
|
||||
|
||||
std::vector<AllocatorAttributes> output_attrs_;
|
||||
|
||||
// Mapping from frame name to static information about the frame.
|
||||
// TODO(yuanbyu): We could cache it along with the graph so to avoid
|
||||
// the overhead of constructing it for each executor instance.
|
||||
std::unordered_map<string, FrameInfo> frame_info_;
|
||||
|
||||
// Mapping from a node's id to its index in the PendingCounts of the
|
||||
// frame the node belongs to.
|
||||
int* frame_local_ids_ = nullptr; // Owned
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl);
|
||||
};
|
||||
|
||||
@ -287,23 +317,31 @@ Status ExecutorImpl::Initialize() {
|
||||
delete[] nodes_;
|
||||
nodes_ = new NodeItem[num_nodes];
|
||||
|
||||
Status s;
|
||||
total_input_tensors_ = 0;
|
||||
total_output_tensors_ = 0;
|
||||
|
||||
InitializePending(graph_, &initial_pending_counts_);
|
||||
// Build the information about frames in this subgraph.
|
||||
ControlFlowInfo cf_info;
|
||||
BuildControlFlowInfo(graph_, &cf_info);
|
||||
|
||||
// Cache this value so we make this virtual function call once, rather
|
||||
// that O(# steps * # nodes per step) times.
|
||||
device_record_tensor_accesses_ =
|
||||
params_.device->RequiresRecordingAccessedTensors();
|
||||
|
||||
for (auto& it : cf_info.frame_name_to_size) {
|
||||
frame_info_[it.first].nodes = new std::vector<const Node*>;
|
||||
}
|
||||
frame_local_ids_ = new int[num_nodes];
|
||||
std::unordered_map<string, int> frame_count;
|
||||
|
||||
// Preprocess every node in the graph to create an instance of op
|
||||
// kernel for each node;
|
||||
// kernel for each node.
|
||||
for (const Node* n : graph_->nodes()) {
|
||||
const int id = n->id();
|
||||
const string& frame_name = cf_info.frame_names[id];
|
||||
FrameInfo& frame_info = frame_info_[frame_name];
|
||||
|
||||
// See if this node is a root node, and if so, add to root_nodes_
|
||||
// See if this node is a root node, and if so, add to root_nodes_.
|
||||
const int num_in_edges = n->in_edges().size();
|
||||
if (num_in_edges == 0) {
|
||||
root_nodes_.push_back(n);
|
||||
@ -321,18 +359,18 @@ Status ExecutorImpl::Initialize() {
|
||||
item->inlined_output_type[i] = n->output_type(i);
|
||||
}
|
||||
|
||||
item->input_start = total_input_tensors_;
|
||||
total_input_tensors_ += n->num_inputs();
|
||||
item->input_start = frame_info.total_inputs;
|
||||
frame_info.total_inputs += n->num_inputs();
|
||||
|
||||
item->output_attr_start = total_output_tensors_;
|
||||
total_output_tensors_ += n->num_outputs();
|
||||
|
||||
s = params_.create_kernel(n->def(), &item->kernel);
|
||||
Status s = params_.create_kernel(n->def(), &item->kernel);
|
||||
if (!s.ok()) {
|
||||
item->kernel = nullptr;
|
||||
s = AttachDef(s, n->def());
|
||||
LOG(ERROR) << "Executor failed to create kernel. " << s;
|
||||
break;
|
||||
return s;
|
||||
}
|
||||
CHECK(item->kernel);
|
||||
item->kernel_is_expensive = item->kernel->IsExpensive();
|
||||
@ -340,14 +378,18 @@ Status ExecutorImpl::Initialize() {
|
||||
item->is_merge = IsMerge(n);
|
||||
|
||||
// Initialize static information about the frames in the graph.
|
||||
frame_local_ids_[id] = frame_count[frame_name]++;
|
||||
frame_info.nodes->push_back(n);
|
||||
if (IsEnter(n)) {
|
||||
string frame_name;
|
||||
s = GetNodeAttr(n->def(), "frame_name", &frame_name);
|
||||
if (!s.ok()) return s;
|
||||
++frame_input_count_[frame_name];
|
||||
string enter_name;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "frame_name", &enter_name));
|
||||
++frame_info_[enter_name].input_count;
|
||||
}
|
||||
}
|
||||
if (!s.ok()) return s;
|
||||
|
||||
// Initialize PendingCounts only after frame_local_ids_ is initialized.
|
||||
InitializePending(graph_, cf_info);
|
||||
|
||||
return SetAllocAttrs();
|
||||
}
|
||||
|
||||
@ -533,12 +575,13 @@ class ExecutorState {
|
||||
typedef gtl::InlinedVector<Entry, 4> EntryVector;
|
||||
|
||||
struct IterationState {
|
||||
explicit IterationState(const ExecutorImpl* impl)
|
||||
: input_tensors(new Entry[impl->total_input_tensors_]),
|
||||
explicit IterationState(const PendingCounts* pending_counts,
|
||||
int total_input_tensors)
|
||||
: input_tensors(new Entry[total_input_tensors]),
|
||||
outstanding_ops(0),
|
||||
outstanding_frame_count(0),
|
||||
counts_(impl->graph_->num_node_ids()) {
|
||||
counts_.InitializeFrom(impl->initial_pending_counts_);
|
||||
counts_(pending_counts->num_nodes()) {
|
||||
counts_.InitializeFrom(*pending_counts);
|
||||
}
|
||||
|
||||
// The state of an iteration.
|
||||
@ -668,9 +711,23 @@ class ExecutorState {
|
||||
// will only "execute" the dead exits of the final iteration.
|
||||
std::vector<const Node*> dead_exits GUARDED_BY(mu);
|
||||
|
||||
// Static information specific to this frame.
|
||||
PendingCounts* pending_counts = nullptr;
|
||||
int total_input_tensors = 0;
|
||||
std::vector<const Node*>* nodes = nullptr;
|
||||
|
||||
// Lock ordering: ExecutorState.mu_ < mu.
|
||||
mutex mu;
|
||||
|
||||
void InitializeFrameInfo(const string& enter_name) {
|
||||
auto it_frame_info = executor->frame_info_.find(enter_name);
|
||||
DCHECK(it_frame_info != executor->frame_info_.end());
|
||||
pending_counts = it_frame_info->second.pending_counts;
|
||||
total_input_tensors = it_frame_info->second.total_inputs;
|
||||
num_pending_inputs = it_frame_info->second.input_count;
|
||||
nodes = it_frame_info->second.nodes;
|
||||
}
|
||||
|
||||
inline IterationState* GetIteration(int64 iter)
EXCLUSIVE_LOCKS_REQUIRED(mu) {
int index = iter % iterations.size();
@ -889,13 +946,12 @@ class ExecutorState {
inline void MaybeMarkCompleted(FrameState* frame, int64 iter, int64 id);

// Provide debugging output about an outstanding node in the executor.
void DumpCompletedNodeState(const int node_id, const Entry* input_vector);
void DumpPendingNodeState(const int node_id, const Entry* input_vector,
bool show_nodes_with_no_ready_inputs);
void DumpActiveNodeState(const int node_id, const Entry* input_vector);

// Provide debugging output about an outstanding iteration in the executor.
void DumpIterationState(IterationState* iteration);
void DumpIterationState(const FrameState* frame, IterationState* iteration);

// Provide debugging output of the state of the executor.
void DumpState();
@ -932,16 +988,16 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
num_outstanding_ops_(0) {
// We start the entire execution in iteration 0 of the root frame
// so let us create the root frame and the state for iteration 0.
// Initialize the frame.
// We assume root_frame_->frame_name.empty().
root_frame_ = new FrameState(impl_, 1);
root_frame_->frame_name = "_root"; // assume to be unique
root_frame_->frame_id = 0; // must be 0
// Initialize the first iteration.
root_frame_->iterations.resize(root_frame_->max_parallel_iterations);
IterationState* iter_state = new IterationState(impl);
root_frame_->iterations[0] = iter_state;
root_frame_->InitializeFrameInfo(root_frame_->frame_name);

// Initialize iteration 0.
root_frame_->iterations.resize(root_frame_->max_parallel_iterations);
root_frame_->iterations[0] = new IterationState(
root_frame_->pending_counts, root_frame_->total_input_tensors);

if (vlog_) VLOG(2) << "Create frame: " << root_frame_->frame_name;
outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
}

@ -949,21 +1005,88 @@ ExecutorState::~ExecutorState() {
for (auto name_frame : outstanding_frames_) {
delete name_frame.second;
}

for (auto it : device_context_map_) {
it->Unref();
}

delete slice_reader_cache_;
}

Status ExecutorImpl::BuildControlFlowInfo(const Graph* g,
ControlFlowInfo* cf_info) {
const int num_nodes = g->num_node_ids();
cf_info->frame_names.resize(num_nodes);
std::vector<Node*> parent_nodes;
parent_nodes.resize(num_nodes);
std::vector<bool> visited;
visited.resize(num_nodes);

string frame_name;
std::deque<Node*> ready;

// Initialize with the root nodes.
for (Node* n : g->nodes()) {
if (n->in_edges().empty()) {
visited[n->id()] = true;
++cf_info->frame_name_to_size[frame_name];
ready.push_back(n);
}
}

while (!ready.empty()) {
Node* curr_node = ready.front();
int curr_id = curr_node->id();
ready.pop_front();

Node* parent = nullptr;
if (IsEnter(curr_node)) {
// Enter a child frame.
TF_RETURN_IF_ERROR(
GetNodeAttr(curr_node->def(), "frame_name", &frame_name));
parent = curr_node;
} else if (IsExit(curr_node)) {
// Exit to the parent frame.
parent = parent_nodes[curr_id];
frame_name = cf_info->frame_names[parent->id()];
parent = parent_nodes[parent->id()];
} else {
parent = parent_nodes[curr_id];
frame_name = cf_info->frame_names[curr_id];
}

for (const Edge* out_edge : curr_node->out_edges()) {
Node* out = out_edge->dst();
int out_id = out->id();

// Add to ready queue if not visited.
bool is_visited = visited[out_id];
if (!is_visited) {
ready.push_back(out);
visited[out_id] = true;

// Process the node 'out'.
cf_info->frame_names[out_id] = frame_name;
parent_nodes[out_id] = parent;
++cf_info->frame_name_to_size[frame_name];
}
}
}

return Status::OK();
}

void ExecutorImpl::InitializePending(const Graph* graph,
PendingCounts* counts) {
for (int id = 0; id < graph->num_node_ids(); id++) {
counts->set_initial_count(id, 0, 0); // Make sure everything is initialized
const ControlFlowInfo& cf_info) {
for (auto& it : cf_info.frame_name_to_size) {
PendingCounts* counts = new PendingCounts(it.second);
frame_info_[it.first].pending_counts = counts;
// Make sure everything is initialized
for (int id = 0; id < it.second; id++) {
counts->set_initial_count(id, 0, 0);
}
}
for (const Node* n : graph->nodes()) {
const int id = n->id();
const int pending_id = frame_local_ids_[id];
const int num_in_edges = n->in_edges().size();
int initial_count;
if (IsMerge(n)) {
@ -980,7 +1103,9 @@ void ExecutorImpl::InitializePending(const Graph* graph,
} else {
initial_count = num_in_edges;
}
counts->set_initial_count(id, initial_count, num_in_edges);
const string& name = cf_info.frame_names[id];
PendingCounts* counts = frame_info_[name].pending_counts;
counts->set_initial_count(pending_id, initial_count, num_in_edges);
}
}

@ -1104,8 +1229,9 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
// TODO(misard) Replace with a finer-grain enabling flag once we
// add better optional debugging support.
if (vlog_ && VLOG_IS_ON(1)) {
int pending_id = impl_->frame_local_ids_[id];
mutex_lock l(input_frame->mu);
input_frame->GetIteration(input_iter)->mark_started(id);
input_frame->GetIteration(input_iter)->mark_started(pending_id);
}

// Set the device_context for this node id, if it exists.
@ -1637,12 +1763,13 @@ void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
}

inline void ExecutorState::MaybeMarkCompleted(FrameState* frame, int64 iter,
int64 id) {
int64 node_id) {
// TODO(misard) Replace with a finer-grain enabling flag once we
// add better optional debugging support.
if (vlog_ && VLOG_IS_ON(1)) {
int pending_id = impl_->frame_local_ids_[node_id];
mutex_lock l(frame->mu);
frame->GetIteration(iter)->mark_completed(id);
frame->GetIteration(iter)->mark_completed(pending_id);
}
}

@ -1656,18 +1783,6 @@ const Tensor* ExecutorState::GetTensorValueForDump(const Entry& input) {
}
}

void ExecutorState::DumpCompletedNodeState(const int node_id,
const Entry* input_vector) {
const NodeItem& node_item = impl_->nodes_[node_id];
const Node& node = *node_item.node;
LOG(WARNING) << " Completed Node: " << node.DebugString();
const int input_base = node_item.input_start;
for (int i = 0; i < node.num_inputs(); ++i) {
const Entry& input = input_vector[input_base + i];
CHECK(!GetTensorValueForDump(input)->IsInitialized());
}
}

void ExecutorState::DumpPendingNodeState(
const int node_id, const Entry* input_vector,
const bool show_nodes_with_no_ready_inputs) {
@ -1723,23 +1838,30 @@ void ExecutorState::DumpActiveNodeState(const int node_id,
}
}

void ExecutorState::DumpIterationState(IterationState* iteration) {
void ExecutorState::DumpIterationState(const FrameState* frame,
IterationState* iteration) {
const std::vector<const Node*>* nodes = frame->nodes;
// Dump any waiting nodes that are holding on to tensors.
for (int i = 0; i < impl_->graph_->num_node_ids(); ++i) {
if (iteration->node_state(i) == PendingCounts::PENDING_NOTREADY ||
iteration->node_state(i) == PendingCounts::PENDING_READY) {
DumpPendingNodeState(i, iteration->input_tensors, false);
for (const Node* node : *nodes) {
int node_id = node->id();
int pending_id = impl_->frame_local_ids_[node_id];
if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
DumpPendingNodeState(node_id, iteration->input_tensors, false);
}
}
// Then the active nodes.
for (int i = 0; i < impl_->graph_->num_node_ids(); ++i) {
if (iteration->node_state(i) == PendingCounts::STARTED) {
DumpActiveNodeState(i, iteration->input_tensors);
for (const Node* node : *nodes) {
int node_id = node->id();
int pending_id = impl_->frame_local_ids_[node_id];
if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
DumpActiveNodeState(pending_id, iteration->input_tensors);
}
}
// Show all input tensors in use.
int total_input_tensors = frame->total_input_tensors;
size_t total_bytes = 0;
for (int i = 0; i < impl_->total_input_tensors_; ++i) {
for (int i = 0; i < total_input_tensors; ++i) {
const Entry& input = iteration->input_tensors[i];
const Tensor* tensor = GetTensorValueForDump(input);
if (tensor->IsInitialized()) {
@ -1764,7 +1886,7 @@ void ExecutorState::DumpState() {
mutex_lock frame_lock(frame_state->mu);
for (IterationState* iteration : frame_state->iterations) {
LOG(WARNING) << " Iteration:";
DumpIterationState(iteration);
DumpIterationState(frame_state, iteration);
}
}
dumped_on_error_ = true;
@ -1819,16 +1941,13 @@ void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
temp->frame_id = Hash64(child_name);
temp->parent_frame = frame;
temp->parent_iter = iter;
temp->InitializeFrameInfo(enter_name);

// 'iterations' is a fixed-length circular buffer.
temp->iterations.resize(temp->max_parallel_iterations + 1);
// Initialize the first iteration.
IterationState* iter_state = new IterationState(impl_);
temp->iterations[0] = iter_state;

auto frame_pending = impl_->frame_input_count_.find(enter_name);
DCHECK(frame_pending != impl_->frame_input_count_.end());
temp->num_pending_inputs = frame_pending->second;
// Initialize iteration 0.
temp->iterations[0] =
new IterationState(temp->pending_counts, temp->total_input_tensors);

{
mutex_lock executor_lock(mu_);
@ -1851,33 +1970,40 @@ void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
FrameState* parent_frame = frame->parent_frame;
int64 parent_iter = frame->parent_iter;
if (parent_frame != nullptr) {
const int* pending_ids = impl_->frame_local_ids_;
mutex_lock paranet_frame_lock(parent_frame->mu);
// Propagate all the dead exits to the parent frame.
for (const Node* node : frame->dead_exits) {
auto parent_iter_state = parent_frame->GetIteration(parent_iter);
for (const Edge* e : node->out_edges()) {
const Node* dst_node = e->dst();
const int dst_id = dst_node->id();
const int dst_pending_id = pending_ids[dst_node->id()];

// TODO(yuanbyu): We don't need this if we require the subgraph
// given to an executor not to contain a sink node.
if (dst_node->IsSink()) continue;

bool dst_dead = true;
bool dst_ready = false;
// We know this is a dead input to dst.
if (IsMerge(dst_node)) {
if (e->IsControlEdge()) {
parent_iter_state->decrement_pending(dst_id, 2);
int count = parent_iter_state->pending(dst_id);
dst_dead = (parent_iter_state->dead_count(dst_id) ==
dst_node->num_inputs());
parent_iter_state->decrement_pending(dst_pending_id, 2);
int count = parent_iter_state->pending(dst_pending_id);
int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
dst_dead = (dead_cnt == dst_node->num_inputs());
dst_ready = (count == 0) || ((count == 1) && dst_dead);
} else {
parent_iter_state->increment_dead_count(dst_id);
const int dead_cnt = parent_iter_state->dead_count(dst_id);
parent_iter_state->increment_dead_count(dst_pending_id);
const int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
dst_dead = (dead_cnt == dst_node->num_inputs());
dst_ready = (parent_iter_state->pending(dst_id) == 1) && dst_dead;
dst_ready =
(parent_iter_state->pending(dst_pending_id) == 1) && dst_dead;
}
} else {
parent_iter_state->increment_dead_count(dst_id);
dst_ready = (parent_iter_state->decrement_pending(dst_id, 1) == 0);
parent_iter_state->increment_dead_count(dst_pending_id);
dst_ready =
(parent_iter_state->decrement_pending(dst_pending_id, 1) == 0);
}
if (dst_ready) {
ready->push_back(
@ -1923,12 +2049,18 @@ void ExecutorState::FrameState::ActivateNodes(const Node* node,
const EntryVector& outputs,
TaggedNodeSeq* ready) {
const NodeItem* nodes = executor->nodes_;
const int* pending_ids = executor->frame_local_ids_;
IterationState* iter_state = GetIteration(iter);
for (const Edge* e : node->out_edges()) {
const Node* dst_node = e->dst();
const int dst_id = dst_node->id();
const int dst_pending_id = pending_ids[dst_id];
const int src_slot = e->src_output();

// TODO(yuanbyu): We don't need this if we require the subgraph
// given to an executor not to contain a sink node.
if (dst_node->IsSink()) continue;

bool dst_dead = false;
bool dst_ready = false;
// True iff this input for dst is needed. We only set this input for
@ -1940,15 +2072,16 @@ void ExecutorState::FrameState::ActivateNodes(const Node* node,
// a) a live data input becomes available or b) all data inputs are dead.
// For Merge, pending's LSB is set iff a live data input has arrived.
if (e->IsControlEdge()) {
iter_state->decrement_pending(dst_id, 2);
int count = iter_state->pending(dst_id);
dst_dead = (iter_state->dead_count(dst_id) == dst_node->num_inputs());
iter_state->decrement_pending(dst_pending_id, 2);
int count = iter_state->pending(dst_pending_id);
int dead_cnt = iter_state->dead_count(dst_pending_id);
dst_dead = (dead_cnt == dst_node->num_inputs());
dst_ready = (count == 0) || ((count == 1) && dst_dead);
} else {
if (outputs[src_slot].has_value) {
// This is a live data input.
int count = iter_state->pending(dst_id);
iter_state->mark_live(dst_id);
int count = iter_state->pending(dst_pending_id);
iter_state->mark_live(dst_pending_id);
// Only the first live edge sets the input and (potentially)
// triggers execution. The low bit of count is set if and
// only if no live input has been used yet (mark_live clears
@ -1962,10 +2095,10 @@ void ExecutorState::FrameState::ActivateNodes(const Node* node,
// a dead enter. We need this to handle properly a while loop on
// the untaken branch of a conditional.
// TODO(yuanbyu): This is a bit hacky, but a good solution for now.
iter_state->increment_dead_count(dst_id);
const int dead_cnt = iter_state->dead_count(dst_id);
iter_state->increment_dead_count(dst_pending_id);
const int dead_cnt = iter_state->dead_count(dst_pending_id);
dst_dead = (dead_cnt == dst_node->num_inputs()) || IsEnter(node);
dst_ready = (iter_state->pending(dst_id) == 1) && dst_dead;
dst_ready = (iter_state->pending(dst_pending_id) == 1) && dst_dead;
dst_need_input = false;
}
}
@ -1974,10 +2107,10 @@ void ExecutorState::FrameState::ActivateNodes(const Node* node,
// for all inputs to come in even if we know the node is dead. This
// ensures that all input tensors get cleaned up.
if (is_dead || (!e->IsControlEdge() && !outputs[src_slot].has_value)) {
iter_state->increment_dead_count(dst_id);
iter_state->increment_dead_count(dst_pending_id);
}
dst_dead = iter_state->dead_count(dst_id) > 0;
dst_ready = (iter_state->decrement_pending(dst_id, 1) == 0);
dst_dead = iter_state->dead_count(dst_pending_id) > 0;
dst_ready = (iter_state->decrement_pending(dst_pending_id, 1) == 0);
}

if (dst_need_input) {
@ -2052,7 +2185,8 @@ void ExecutorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) {
int64 next_iter = iteration_count;

// Initialize the next iteration.
IterationState* iter_state = new IterationState(executor);
IterationState* iter_state =
new IterationState(pending_counts, total_input_tensors);
SetIteration(next_iter, iter_state);
num_outstanding_iterations++;
dead_exits.clear();
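Illustration (not part of this commit; node, impl_ and iter_state are assumed to be in scope as in the surrounding executor code): with per-frame PendingCounts, every bookkeeping call is indexed by a frame-local pending id rather than the global node id.

// Sketch only: global node id -> frame-local pending id -> per-frame counts.
const int node_id = node->id();                           // graph-wide id
const int pending_id = impl_->frame_local_ids_[node_id];  // id within the node's frame
iter_state->decrement_pending(pending_id, 1);             // counts are kept per frame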
@ -44,11 +44,7 @@ static const char* const kRetOp = "_Retval";
static const char* const kGradientOp = "SymbolicGradient";
static const char* const kNodeLabel = "Func";
static const char* const kFuncAttr = "f";
// kNoinlineAttr must start with an "_" to avoid collisions with
// user-specified attrs.
static const char* const kNoinlineAttr = "_noinline";
// Old graphs use no "_".
static const char* const kOldNoinlineAttr = "noinline";
static const char* const kNoInlineAttr = "_noinline";

// Represents the index-th output of a node.
struct Endpoint {
@ -168,6 +164,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {

Device* device() override { return device_; }
Env* env() override { return env_; }
int graph_def_version() override { return graph_def_version_; }

string DebugString(Handle h) override;

@ -290,6 +287,34 @@ const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
return func_graphs_[h];
}

namespace {

struct CustomCreatorSingleton {
mutex mu;
CustomKernelCreator custom_creator = nullptr;

void Set(CustomKernelCreator cb) {
mutex_lock l(mu);
custom_creator = cb;
}

CustomKernelCreator Get() {
mutex_lock l(mu);
return custom_creator;
}
};

CustomCreatorSingleton* GetCustomCreatorSingleton() {
static CustomCreatorSingleton* ccs = new CustomCreatorSingleton;
return ccs;
}

} // end namespace

void RegisterCustomKernelCreator(CustomKernelCreator cb) {
GetCustomCreatorSingleton()->Set(cb);
}

Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
OpKernel** kernel) {
if (lib_def_->Find(ndef.op()) == nullptr) {
@ -318,8 +343,23 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
output_memory_types.push_back(t == DT_INT32 ? HOST_MEMORY : DEVICE_MEMORY);
}

// Constructs a CallOp kernel for running the instantiated function.
// If a custom kernel creator is given, try that.
CustomKernelCreator custom_creator = GetCustomCreatorSingleton()->Get();
Status s;
if (custom_creator) {
std::unique_ptr<OpKernel> ret;
s = custom_creator(this, ndef, &ret);
if (s.ok()) {
*kernel = ret.release();
return s;
} else {
VLOG(2) << "Custom creator error: " << s;
// Falls through.
s = Status::OK();
}
}

// Constructs a CallOp kernel for running the instantiated function.
auto device_type = DeviceType(device_->attributes().device_type());
OpKernelConstruction construction(
device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
@ -327,7 +367,7 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
fbody->ret_types, output_memory_types, graph_def_version_, &s);
*kernel = new CallOp(handle, &construction);
if (!s.ok()) {
delete kernel;
delete *kernel;
}
return s;
}
@ -887,15 +927,11 @@ static void InlineFunctionBody(Graph* g, Node* caller,
}

// Given a node's NodeDef, returns false iff the node explicitly
// specified _noinline. This gives ExpandInlineFunctions a heuristic to
// decide whether to inline the function.
// `old` is true for GraphDef versions older than 12, when the
// `noinline` attr was renamed to `_noinline` to avoid conflicts with
// user-specified attrs.
bool ShouldInline(const NodeDef& ndef, bool old) {
// specified _noinline. This gives ExpandInlineFunctions a heuristic
// to decide whether to inline the function.
bool ShouldInline(const NodeDef& ndef) {
bool noinline = false;
const char* const attr = old ? kOldNoinlineAttr : kNoinlineAttr;
if (GetNodeAttr(ndef, attr, &noinline).ok()) {
if (GetNodeAttr(ndef, kNoInlineAttr, &noinline).ok()) {
// If the node specifies attribute '_noinline', returns accordingly.
return !noinline;
}
@ -914,7 +950,8 @@ bool ShouldInline(const NodeDef& ndef, bool old) {
// continue and the runtime will error out.
return false;
}
s = GetNodeAttr(AttrSlice(&forward_func_attrs->attr()), attr, &noinline);
s = GetNodeAttr(AttrSlice(&forward_func_attrs->attr()), kNoInlineAttr,
&noinline);
if (!s.ok()) {
// The forward function doesn't specify '_noinline' attr, we should
// be free to decide.
@ -926,11 +963,9 @@ bool ShouldInline(const NodeDef& ndef, bool old) {

bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
std::vector<std::pair<Node*, const FunctionBody*>> candidates;
// Identify old graphs before the 'noinline' attr was renamed '_noinline'.
const bool old_inline_attr = graph->versions().producer() < 12;
for (Node* node : graph->nodes()) {
VLOG(3) << "Expanding " << node->DebugString();
if (!ShouldInline(node->def(), old_inline_attr)) {
if (!ShouldInline(node->def())) {
VLOG(3) << "noinline: " << node->DebugString();
continue;
}

@ -123,6 +123,18 @@ void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false);
// TODO(zhifengc): Asks math expert to say the comment again.
FunctionBody* SymbolicGradient(const FunctionBody& f);

// Registers a customizable kernel creator for a function call.
//
// If 'cb()' returns a non-OK, we still fall back to an executor-based
// interpreter op kernel to execute a function. If 'cb()' returns OK,
// takes ownership of the returned OpKernel.
//
// TODO(zhifengc/phawkins): b/32379046
typedef std::function<Status(FunctionLibraryRuntime*, const NodeDef&,
std::unique_ptr<OpKernel>*)>
CustomKernelCreator;
void RegisterCustomKernelCreator(CustomKernelCreator cb);

} // end namespace tensorflow

#endif // TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_
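Illustration (not part of this commit): the CustomKernelCreator hook declared above is registered once per process; returning a non-OK status makes CreateKernel fall back to the executor-based CallOp path shown in function.cc. A minimal, hypothetical registration sketch:

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
void RegisterDecliningKernelCreator() {
  RegisterCustomKernelCreator(
      [](FunctionLibraryRuntime* flr, const NodeDef& ndef,
         std::unique_ptr<OpKernel>* kernel) -> Status {
        // Decline every function call; the runtime then falls back to the
        // default interpreter kernel.
        return errors::Unimplemented("no custom kernel for ", ndef.op());
      });
}
}  // namespace tensorflow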
@ -71,6 +71,7 @@ class PendingCounts {
|
||||
}
|
||||
}
|
||||
|
||||
inline int num_nodes() const { return num_nodes_; }
|
||||
NodeState node_state(int id) {
|
||||
if (IsLarge(id)) {
|
||||
return NodeStateLarge(id);
|
||||
@ -185,12 +186,7 @@ class PendingCounts {
|
||||
// use one byte to hold both the pending and dead count for a node
|
||||
// where these together can fit in one byte, and we use a hash table
|
||||
// to handle the rare node ids that need larger counts than this.
|
||||
|
||||
// TODO(yuanbyu): We current use O(# of nodes in partition) space
|
||||
// even for nested iterations where only a small fraction of the
|
||||
// nodes are involved. This is not efficient if the subgraph for
|
||||
// the frame is only a small subset of the partition. We should make
|
||||
// the vector size to be only the size of the frame subgraph.
|
||||
// Each frame in this subgraph has its own PendingCounts.
|
||||
|
||||
// We use 3 bits each for dead_count and pending.
|
||||
static const int kMaxCountForPackedCounts = 7;
|
||||
|
@ -27,6 +27,10 @@ limitations under the License.

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

ShapeRefiner::ShapeRefiner(const OpRegistryInterface* ops)
: ops_registry_(ops) {}

@ -37,7 +41,7 @@ Status ShapeRefiner::AddNode(const Node* node) {
// from 'input's InferenceContext, and store into a vector
// indexed by 'node's input.
std::vector<Node*> input_nodes(node->num_inputs());
std::vector<shape_inference::ShapeHandle> input_shapes(node->num_inputs());
std::vector<ShapeHandle> input_shapes(node->num_inputs());
for (const Edge* e : node->in_edges()) {
if (e->IsControlEdge()) continue;

@ -49,7 +53,7 @@ Status ShapeRefiner::AddNode(const Node* node) {
node->name(), "' was not previously added to ShapeRefiner.");
}

shape_inference::InferenceContext* c = it->second;
InferenceContext* c = it->second;
DCHECK_GE(e->dst_input(), 0);
input_nodes[e->dst_input()] = input;
input_shapes[e->dst_input()] = c->output(e->src_output());
@ -68,11 +72,13 @@ Status ShapeRefiner::AddNode(const Node* node) {
std::vector<const Tensor*> input_tensors(node->num_inputs());
std::vector<Tensor> real_tensors(node->num_inputs());
std::vector<bool> attempted_materialization(node->num_inputs());
std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
std::vector<ShapeHandle> input_tensors_as_shapes;

// Create the inference context for this node with the existing input shapes.
std::unique_ptr<shape_inference::InferenceContext> c(
new shape_inference::InferenceContext(&node->def(), node->op_def(),
input_shapes, input_tensors));
std::unique_ptr<InferenceContext> c(
new InferenceContext(&node->def(), node->op_def(), input_shapes,
input_tensors, input_tensors_as_shapes));
if (!c->construction_status().ok()) {
return c->construction_status();
}
@ -101,63 +107,44 @@ Status ShapeRefiner::AddNode(const Node* node) {
// subgraph once.

for (int i = 0; i < c->num_inputs(); ++i) {
if (!c->requested_input_tensor(i)) {
continue;
}
// Check if we have not already filled in the requested input,
// and if not, try to materialize the tensors.
if (c->requested_input_tensor(i) && !attempted_materialization[i]) {
if (!attempted_materialization[i]) {
attempted_materialization[i] = true;

const Edge* input_edge;
TF_RETURN_IF_ERROR(node->input_edge(i, &input_edge));

bool is_constant_graph = false;
Graph subgraph(ops_registry_);

// We identify the possibly constant subgraph to evaluate by
// recursively iterating backwards through the inputs to 'node'
// until we either 1) find an already existing input to our subgraph
// (filled in `const_inputs`), 2) Discover our graph is not constant,
// or 3) Hit a root node.
std::vector<std::pair<string, Tensor>> const_inputs;
TF_RETURN_IF_ERROR(ExtractConstantSubgraph(
input_nodes[i], &subgraph, &is_constant_graph, &const_inputs));
if (is_constant_graph) {
const string output_tensor_name = strings::StrCat(
input_nodes[i]->name(), ":", input_edge->src_output());
std::vector<Tensor> outputs;
// NOTE; we should pass in a function library runtime if we want
// to support constant-expression evaluation on functions.
Status s = GraphRunner::Run(&subgraph, nullptr /* function_library */,
Env::Default(), const_inputs,
{output_tensor_name}, &outputs);

// If all kernels in the constant graph are not registered
// in the process, GraphRunner::Run may fail, in which case
// we cannot propagate constants, so this is best-effort.
if (s.ok()) {
real_tensors[i] = outputs[0];
input_tensors[i] = &real_tensors[i];

// We have more concrete information about a shape,
// so re-run shape inference.
rerun_shape_fn = true;

// We memoize (small) constants evaluated so far, so
// ExtractConstantSubgraph can avoid extracting the full
// subgraph. As we build up large graphs, this avoids
// repeated computation of the early parts of a constant
// graph.
if (outputs[0].TotalBytes() <= kMaxTensorSize) {
const_tensor_map_[output_tensor_name] = outputs[0];
}
}
Tensor result;
bool evaluated = false;
TF_RETURN_IF_ERROR(
EvaluateConstantTensorForEdge(node, i, &evaluated, &result));
if (evaluated) {
real_tensors[i] = result;
input_tensors[i] = &real_tensors[i];
// We have more concrete information about a shape,
// so re-run shape inference.
rerun_shape_fn = true;
}
}
if (c->requested_input_tensor_as_partial_shape(i) &&
!attempted_tensor_as_shape_conversion[i]) {
attempted_tensor_as_shape_conversion[i] = true;
if (i >= input_tensors_as_shapes.size()) {
input_tensors_as_shapes.resize(i + 1);
}
ShapeHandle s;
TF_RETURN_IF_ERROR(ConstantPartialShape(c.get(), node, i, &s));
input_tensors_as_shapes[i] = s;
rerun_shape_fn = true;
}
}

if (rerun_shape_fn) {
// We have more information about the shapes on this pass,
// so re-run shape inference.
c->set_input_tensors(input_tensors);
c->set_input_tensors_as_shapes(input_tensors_as_shapes);
TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(c.get()));
}
} while (rerun_shape_fn);
@ -169,7 +156,7 @@ Status ShapeRefiner::AddNode(const Node* node) {
}

Status ShapeRefiner::SetShape(const Node* node, int output_port,
shape_inference::ShapeHandle shape) {
ShapeHandle shape) {
auto c = GetContext(node);
if (c == nullptr) {
return errors::Internal("Could not find context for ", node->name());
@ -182,7 +169,7 @@ Status ShapeRefiner::SetShape(const Node* node, int output_port,
}

// Check compatibility, and merge the shapes.
shape_inference::ShapeHandle existing_shape = c->output(output_port);
ShapeHandle existing_shape = c->output(output_port);
TF_RETURN_IF_ERROR(c->Merge(existing_shape, shape, &shape));
c->set_output(output_port, shape);

@ -196,6 +183,55 @@ Status ShapeRefiner::SetShape(const Node* node, int output_port,
return Status::OK();
}

Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
int dst_idx, bool* evaluated,
Tensor* result) {
*evaluated = false;
const Edge* input_edge;
TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));

bool is_constant_graph = false;
Graph subgraph(ops_registry_);

// We identify the possibly constant subgraph to evaluate by
// recursively iterating backwards through the inputs to 'node'
// until we either 1) find an already existing input to our subgraph
// (filled in `const_inputs`), 2) Discover our graph is not constant,
// or 3) Hit a root node.
std::vector<std::pair<string, Tensor>> const_inputs;
TF_RETURN_IF_ERROR(ExtractConstantSubgraph(
input_edge->src(), &subgraph, &is_constant_graph, &const_inputs));
if (!is_constant_graph) {
return Status::OK();
}
const string output_tensor_name =
strings::StrCat(input_edge->src()->name(), ":", input_edge->src_output());
std::vector<Tensor> outputs;
// NOTE; we should pass in a function library runtime if we want
// to support constant-expression evaluation on functions.
Status s = GraphRunner::Run(&subgraph, nullptr /* function_library */,
Env::Default(), const_inputs,
{output_tensor_name}, &outputs);

// If all kernels in the constant graph are not registered
// in the process, GraphRunner::Run may fail, in which case
// we cannot propagate constants, so this is best-effort.
if (s.ok()) {
*result = outputs[0];
*evaluated = true;

// We memoize (small) constants evaluated so far, so
// ExtractConstantSubgraph can avoid extracting the full
// subgraph. As we build up large graphs, this avoids
// repeated computation of the early parts of a constant
// graph.
if (outputs[0].TotalBytes() <= kMaxTensorSize) {
const_tensor_map_[output_tensor_name] = outputs[0];
}
}
return Status::OK();
}

Status ShapeRefiner::ExtractConstantSubgraph(
Node* target_node, Graph* out_graph, bool* is_constant_graph,
std::vector<std::pair<string, Tensor>>* const_inputs) {
@ -308,4 +344,75 @@ Status ShapeRefiner::ExtractConstantSubgraph(
return Status::OK();
}

Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
const Node* node, int dst_idx,
ShapeHandle* result) {
const Edge* input_edge;
TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));

InferenceContext* src_context = GetContext(input_edge->src());
if (src_context == nullptr) return errors::Internal("Missing src context");
ShapeHandle src_shape = src_context->output(input_edge->src_output());
TF_RETURN_IF_ERROR(src_context->WithRank(src_shape, 1, &src_shape));

const string& src_op = input_edge->src()->type_string();
if (src_context->Value(src_context->Dim(src_shape, 0)) == 0) {
// Source tensor is a vector of length 0, so the shape it
// represents is as scalar.
*result = target_context->Scalar();
} else if (src_op == "Shape") {
*result = src_context->input(0);
} else if (src_op == "Pack") {
std::vector<DimensionHandle> dims;
// Pack is concatenating its input scalars to form the shape tensor vector.
for (int i = 0; i < src_context->num_inputs(); ++i) {
Tensor scalar;
bool evaluated = false;
TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(input_edge->src(), i,
&evaluated, &scalar));
if (evaluated) {
int64 size;
if (scalar.dtype() == DT_INT32) {
size = scalar.scalar<int32>()();
} else if (scalar.dtype() == DT_INT64) {
size = scalar.scalar<int64>()();
} else {
return errors::InvalidArgument("Pack input must be int32 or int64");
}
dims.push_back(size < 0 ? target_context->UnknownDim()
: target_context->MakeDim(size));
} else {
dims.push_back(target_context->UnknownDim());
}
}
*result = target_context->MakeShape(dims);
} else if (src_op == "Concat") {
*result = target_context->Scalar();
// Concat is concatenating its input shape vectors.
// input 0 is ignored as it is the concat dim and will always be 0.
for (int i = 1; i < src_context->num_inputs(); ++i) {
ShapeHandle sub_result;
TF_RETURN_IF_ERROR(ConstantPartialShape(target_context, input_edge->src(),
i, &sub_result));
if (!target_context->RankKnown(sub_result)) {
// Failed to evaluate. Treat the output as completely unknown.
// TODO(cwhipkey): we could rely on all inputs being the same size, so
// figure that size out and append the right number of unknown dims.
*result = target_context->UnknownShape();
return Status::OK();
}
TF_RETURN_IF_ERROR(
target_context->Concatenate(*result, sub_result, result));
}
} else {
Tensor t;
bool evaluated = false;
TF_RETURN_IF_ERROR(
EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor(
evaluated ? &t : nullptr, src_shape, result));
}
return Status::OK();
}

} // namespace tensorflow

@ -71,6 +71,34 @@ class ShapeRefiner {
Node* node, Graph* out_graph, bool* is_constant_graph,
std::vector<std::pair<string, Tensor>>* const_inputs) TF_MUST_USE_RESULT;

Status EvaluateConstantTensorForEdge(const Node* node, int dst_idx,
bool* evaluated, Tensor* result);

// This function tries to materialize as much information about the 'node''s
// dst_idx input as a statically computable shape, and the result may be
// partially known, depending on what is statically inferable.
//
// This is called when node.input[dst_idx] is a tensor that is used to define
// the shape of some other tensor (e.g., the second argument to Reshape is a
// <shape> tensor, where each element of the shape tensor is a dimension of
// the target tensor). It returns in <result> a shape for that input.
//
// Unlike simply resolving node.input[dst_idx] to a constant and then
// converting that to a shape, this function can return a partial shape. This
// is useful for cases where the shape tensor is only partially defined, such
// as with calls for: reshape(x, shape(y)) where shape(y) is partially
// defined.
//
// The implementation has op implementations for ops commonly called on shape
// tensors, and the implementations are specialized to shape tensors (namely,
// the output is a vector).
//
// <target_context> is used when creating new DimensionHandle and ShapeHandle
// objects.
Status ConstantPartialShape(shape_inference::InferenceContext* target_context,
const Node* node, int dst_idx,
shape_inference::ShapeHandle* result);

const OpRegistryInterface* ops_registry_ = nullptr;

// Stores a map from a node to its InferenceContext.
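Illustration (not part of this commit): concretely, for reshape(x, shape(y)) where y's inferred shape is [1,?,3], ConstantPartialShape lets the refiner report [1,?,3] for the reshape output even though shape(y) cannot be evaluated to a constant; the tests that follow exercise the Shape, Pack and Concat producers of such shape tensors.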
@ -398,5 +398,347 @@ TEST(ShapeRefinerTest, ConstantValueVisitNodeTwice) {
EXPECT_EQ("[1,4,7]", ctx->DebugString(ctx->output(0)));
}

namespace {

Status TensorAsShapeShapeFn(shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0 /* input_idx */, &out));
c->set_output(0, out);
return Status::OK();
}

// Register ops used by the ConstantValueAsShape* tests.

REGISTER_OP("TensorAsShapeInt32")
.Input("a: int32")
.Output("o: int32")
.SetShapeFn(TensorAsShapeShapeFn);

REGISTER_OP("TensorAsShapeInt64")
.Input("a: int64")
.Output("o: int64")
.SetShapeFn(TensorAsShapeShapeFn);

REGISTER_OP("NonConstScalarInt32")
.Output("o: int32")
.SetIsStateful()  // prevents constant folding
.SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("NonConstScalarInt64")
.Output("o: int64")
.SetIsStateful()  // prevents constant folding
.SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("WithEmptyVectorShape")
.Output("o: int32")
.SetIsStateful()  // prevents constant folding
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Vector(0));
return Status::OK();
});

REGISTER_OP("WithPartialShape")
.Output("o: int32")
.SetIsStateful()  // prevents constant folding
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(
0, c->MakeShape({1, shape_inference::InferenceContext::kUnknownDim, 3,
shape_inference::InferenceContext::kUnknownDim, 5}));
return Status::OK();
});

REGISTER_OP("WithPartialShape2")
.Output("o: int32")
.SetIsStateful()  // prevents constant folding
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(
0,
c->MakeShape({6, shape_inference::InferenceContext::kUnknownDim, 8}));
return Status::OK();
});

REGISTER_OP("WithUnknownShape")
.Output("o: int32")
.SetIsStateful()  // prevents constant folding
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->UnknownShape());
return Status::OK();
});

} // namespace

TEST(ShapeRefinerTest, ConstantValueAsShape_EmptyVector) {
Scope root = Scope::NewRootScope();
Node* input;
TF_ASSERT_OK(
NodeBuilder("in", "WithEmptyVectorShape").Finalize(root.graph(), &input));
Node* result;
TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
.Input(input)
.Finalize(root.graph(), &result));

ShapeRefiner m(OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(input));
TF_ASSERT_OK(m.AddNode(result));

shape_inference::InferenceContext* ctx = m.GetContext(result);
EXPECT_EQ("[]", ctx->DebugString(ctx->output(0)));
}

TEST(ShapeRefinerTest, ConstantValueAsShape_Shape) {
for (int pass = 0; pass < 2; ++pass) {
Scope root = Scope::NewRootScope();
Node* input;
TF_ASSERT_OK(
NodeBuilder("in", pass == 0 ? "WithPartialShape" : "WithUnknownShape")
.Finalize(root.graph(), &input));
auto shape = ops::Shape(root, ops::Output(input));
Node* result;
TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
.Input(shape.node())
.Finalize(root.graph(), &result));

ShapeRefiner m(OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(input));
TF_ASSERT_OK(m.AddNode(shape.node()));
TF_ASSERT_OK(m.AddNode(result));

shape_inference::InferenceContext* ctx = m.GetContext(result);
if (pass == 0) {
EXPECT_EQ("[1,?,3,?,5]", ctx->DebugString(ctx->output(0)));
} else {
EXPECT_EQ("?", ctx->DebugString(ctx->output(0)));
}
}
}

TEST(ShapeRefinerTest, ConstantValueAsShape_PackInt32) {
Scope root = Scope::NewRootScope();
Node* scalar_non_const;
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
.Finalize(root.graph(), &scalar_non_const));

ops::InputList inputs{
ops::Input(ops::Const<int32>(root, 10)),
ops::Input(ops::Const<int32>(root, 20)),
ops::Input(ops::Output(scalar_non_const)),
ops::Input(ops::Const<int32>(root, 40)),
};
auto pack = ops::Pack(root, inputs);
TF_ASSERT_OK(root.status());

Node* result;
TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
.Input(pack.node())
.Finalize(root.graph(), &result));

ShapeRefiner m(OpRegistry::Global());
for (auto input : inputs) {
TF_ASSERT_OK(m.AddNode(input.node()));
}
TF_ASSERT_OK(m.AddNode(pack.node()));
TF_ASSERT_OK(m.AddNode(result));

shape_inference::InferenceContext* ctx = m.GetContext(result);
EXPECT_EQ("[10,20,?,40]", ctx->DebugString(ctx->output(0)));
}

TEST(ShapeRefinerTest, ConstantValueAsShape_PackInt64) {
Scope root = Scope::NewRootScope();
Node* scalar_non_const;
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt64")
.Finalize(root.graph(), &scalar_non_const));

ops::InputList inputs{
ops::Input(ops::Const<int64>(root, 10LL)),
ops::Input(ops::Const<int64>(root, 20LL)),
ops::Input(ops::Output(scalar_non_const)),
ops::Input(ops::Const<int64>(root, 1LL << 40)),
};
auto pack = ops::Pack(root, inputs);
TF_ASSERT_OK(root.status());

Node* result;
TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt64")
.Input(pack.node())
.Finalize(root.graph(), &result));

ShapeRefiner m(OpRegistry::Global());
for (const auto& input : inputs) {
TF_ASSERT_OK(m.AddNode(input.node()));
}
TF_ASSERT_OK(m.AddNode(pack.node()));
TF_ASSERT_OK(m.AddNode(result));

shape_inference::InferenceContext* ctx = m.GetContext(result);
EXPECT_EQ("[10,20,?,1099511627776]", ctx->DebugString(ctx->output(0)));
}

TEST(ShapeRefinerTest, ConstantValueAsShape_PackUnknownDim) {
Scope root = Scope::NewRootScope();

ops::InputList inputs{
ops::Input(ops::Const<int64>(root, 10LL)),
ops::Input(ops::Const<int64>(root, -1LL)),
};
auto pack = ops::Pack(root, inputs);
TF_ASSERT_OK(root.status());

Node* result;
TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt64")
.Input(pack.node())
.Finalize(root.graph(), &result));

ShapeRefiner m(OpRegistry::Global());
for (const auto& input : inputs) {
TF_ASSERT_OK(m.AddNode(input.node()));
}
TF_ASSERT_OK(m.AddNode(pack.node()));
TF_ASSERT_OK(m.AddNode(result));

shape_inference::InferenceContext* ctx = m.GetContext(result);
EXPECT_EQ("[10,?]", ctx->DebugString(ctx->output(0)));
}

TEST(ShapeRefinerTest, ConstantValueAsShape_PackInvalidInput) {
Scope root = Scope::NewRootScope();

// Inputs are length 2 vectors instead of scalars.
ops::InputList inputs{
ops::Input(ops::Const<int64>(root, {10LL, 20LL})),
ops::Input(ops::Const<int64>(root, {10LL, 21LL})),
};
auto pack = ops::Pack(root, inputs);
TF_ASSERT_OK(root.status());

Node* result;
TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt64")
.Input(pack.node())
.Finalize(root.graph(), &result));

ShapeRefiner m(OpRegistry::Global());
for (const auto& input : inputs) {
TF_ASSERT_OK(m.AddNode(input.node()));
}
TF_ASSERT_OK(m.AddNode(pack.node()));
EXPECT_TRUE(
StringPiece(m.AddNode(result).error_message()).contains("but is rank 2"));
}

TEST(ShapeRefinerTest, ConstantValueAsShape_Concat) {
Scope root = Scope::NewRootScope();
Graph* g = root.graph();
Node* partial_1;
Node* partial_2;
TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape").Finalize(g, &partial_1));
TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape2").Finalize(g, &partial_2));
auto const_input = ops::Const(root, {9, 10, 11});
ops::OutputList concat_inputs{
ops::Shape(root, ops::Output(partial_1)),
ops::Shape(root, ops::Output(partial_2)), const_input,
};
auto concat_dim = ops::Const(root, 0);
auto concat = ops::Concat(root, concat_dim, concat_inputs);
TF_ASSERT_OK(root.status());

Node* result;
TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
.Input(concat.node())
.Finalize(g, &result));

ShapeRefiner m(OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(partial_1));
TF_ASSERT_OK(m.AddNode(partial_2));
for (const auto& o : concat_inputs) {
TF_ASSERT_OK(m.AddNode(o.node()));
}
TF_ASSERT_OK(m.AddNode(concat_dim.node()));
TF_ASSERT_OK(m.AddNode(concat.node()));
TF_ASSERT_OK(m.AddNode(result));

shape_inference::InferenceContext* ctx = m.GetContext(result);
EXPECT_EQ("[1,?,3,?,5,6,?,8,9,10,11]", ctx->DebugString(ctx->output(0)));
}

TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) {
Scope root = Scope::NewRootScope();
Graph* g = root.graph();
Node* scalar_non_const;
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
.Finalize(root.graph(), &scalar_non_const));

Node* partial_1;
Node* partial_2;
Node* unknown;
TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape").Finalize(g, &partial_1));
TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape2").Finalize(g, &partial_2));
TF_ASSERT_OK(NodeBuilder("in", "WithUnknownShape").Finalize(g, &unknown));
ops::OutputList concat_inputs{
ops::Shape(root, ops::Output(partial_1)),
ops::Shape(root, ops::Output(partial_2)),
ops::Shape(root, ops::Output(unknown)),
};
auto concat_dim = ops::Const(root, 0);
auto concat = ops::Concat(root, concat_dim, concat_inputs);
TF_ASSERT_OK(root.status());

Node* result;
TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
.Input(concat.node())
.Finalize(g, &result));

ShapeRefiner m(OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(partial_1));
TF_ASSERT_OK(m.AddNode(partial_2));
TF_ASSERT_OK(m.AddNode(unknown));
for (const auto& o : concat_inputs) {
TF_ASSERT_OK(m.AddNode(o.node()));
}
TF_ASSERT_OK(m.AddNode(concat_dim.node()));
TF_ASSERT_OK(m.AddNode(concat.node()));
TF_ASSERT_OK(m.AddNode(result));

shape_inference::InferenceContext* ctx = m.GetContext(result);
EXPECT_EQ("?", ctx->DebugString(ctx->output(0)));
}

TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) {
Scope root = Scope::NewRootScope();
Graph* g = root.graph();
Node* scalar_non_const;
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
.Finalize(root.graph(), &scalar_non_const));

Node* partial_1;
Node* partial_2;
TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape").Finalize(g, &partial_1));
TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape2").Finalize(g, &partial_2));
auto const_input = ops::Const(root, {9, -2, 11});
ops::OutputList concat_inputs{
ops::Shape(root, ops::Output(partial_1)),
ops::Shape(root, ops::Output(partial_2)), //
const_input,
};
auto concat_dim = ops::Const(root, 0);
auto concat = ops::Concat(root, concat_dim, concat_inputs);
TF_ASSERT_OK(root.status());

Node* result;
TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
.Input(concat.node())
.Finalize(g, &result));

ShapeRefiner m(OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(partial_1));
TF_ASSERT_OK(m.AddNode(partial_2));
for (const auto& o : concat_inputs) {
TF_ASSERT_OK(m.AddNode(o.node()));
}
TF_ASSERT_OK(m.AddNode(concat_dim.node()));
TF_ASSERT_OK(m.AddNode(concat.node()));
EXPECT_EQ("Invalid value in tensor used for shape: -2",
m.AddNode(result).error_message());
}

} // namespace
} // namespace tensorflow
@ -274,16 +274,6 @@ Status SimpleGraphExecutionState::InitBaseGraph(
return Status::OK();
}

void SimpleGraphExecutionState::UpdateCostsFromStats(const StepStats& ss) {
mutex_lock l(mu_);
costs_.MergeFromStats(node_name_to_cost_id_map_, ss);
}

void SimpleGraphExecutionState::MergeCostsFromGlobal(CostModel* costs) {
mutex_lock l(mu_);
costs->MergeFromGlobal(costs_);
}

Status SimpleGraphExecutionState::BuildGraph(
const BuildGraphOptions& options, std::unique_ptr<SimpleClientGraph>* out) {
VLOG(1) << "BuildGraph";

@ -133,22 +133,6 @@ class SimpleGraphExecutionState {
Status BuildGraph(const BuildGraphOptions& options,
std::unique_ptr<SimpleClientGraph>* out);

// Sums execution statistics in "ss" into the CostModel.
void UpdateCostsFromStats(const StepStats& ss);

Microseconds TimeEstimate(const Node* n) {
mutex_lock l(mu_); // could use reader lock
return costs_.TimeEstimate(n);
}

Bytes SizeEstimate(const Node* n, int output_slot) {
mutex_lock l(mu_); // could use reader lock
return costs_.SizeEstimate(n, output_slot);
}

// Merge the cost model maintained by this graph_execution_state to 'costs'.
void MergeCostsFromGlobal(CostModel* costs);

// The graph returned by BuildGraph may contain only the pruned
// graph, whereas some clients may want access to the full graph.
const Graph* full_graph() {
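Illustration (not part of this commit; the node name "a" is just an example): the concurrent-run test below gives each Run() call its own DebugTensorWatch, built with the following calls.

RunOptions run_opts;
run_opts.set_output_partition_graphs(true);
DebugTensorWatch* watch = run_opts.add_debug_tensor_watch_opts();
watch->set_node_name("a");              // each concurrent run watches a different node
watch->set_output_slot(0);              // slot 0 carries the debug signal
watch->add_debug_ops("DebugIdentity");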
@ -335,7 +335,9 @@ TEST_F(SessionDebugMinusAXTest, RunSimpleNetworkWithTwoDebugNodesInserted) {
|
||||
}
|
||||
|
||||
TEST_F(SessionDebugMinusAXTest,
|
||||
RunSimpleNetworkConcurrentlyWithDebugNodesInserted) {
|
||||
RunSimpleNetworkConcurrentlyWithDifferentDebugTensorWatches) {
|
||||
// Test concurrent Run() calls on a graph with different debug watches.
|
||||
|
||||
Initialize({3, 2, -1, 0});
|
||||
std::unique_ptr<DirectSession> session(CreateSession());
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
@ -351,33 +353,39 @@ TEST_F(SessionDebugMinusAXTest,
|
||||
|
||||
mutex mu;
|
||||
DebugGateway debug_gateway(session.get());
|
||||
std::vector<Tensor> debug_identity_tensor_vals;
|
||||
std::unordered_map<string, Tensor> debug_identity_tensor_vals;
|
||||
|
||||
const string debug_identity = "DebugIdentity";
|
||||
const string debug_identity_node_name = DebugNodeInserter::GetDebugNodeName(
|
||||
|
||||
const string a_debug_identity_node_name = DebugNodeInserter::GetDebugNodeName(
|
||||
strings::StrCat(a_, ":", 0), 0, debug_identity);
|
||||
const string x_debug_identity_node_name = DebugNodeInserter::GetDebugNodeName(
|
||||
strings::StrCat(x_, ":", 0), 0, debug_identity);
|
||||
const string y_debug_identity_node_name = DebugNodeInserter::GetDebugNodeName(
|
||||
strings::StrCat(y_, ":", 0), 0, debug_identity);
|
||||
|
||||
Notification callbacks_done;
|
||||
int comp_callback_count = 0;
|
||||
int val_callback_count = 0;
|
||||
debug_gateway.SetNodeCompletionCallback(
|
||||
[&mu, &callbacks_done, &comp_callback_count, &debug_identity_node_name](
|
||||
const string& node_name, const bool any_output) {
|
||||
mutex_lock l(mu);
|
||||
if (node_name == debug_identity_node_name) {
|
||||
comp_callback_count++;
|
||||
}
|
||||
});
|
||||
volatile int val_callback_count = 0;
|
||||
|
||||
debug_gateway.SetNodeValueCallback(
|
||||
[this, &mu, &val_callback_count, &debug_identity_node_name,
|
||||
[this, &mu, &val_callback_count, &a_debug_identity_node_name,
|
||||
&x_debug_identity_node_name, &y_debug_identity_node_name,
|
||||
&debug_identity_tensor_vals,
|
||||
&callbacks_done](const string& node_name, const int output_slot,
|
||||
const Tensor& tensor_value, const bool is_ref) {
|
||||
mutex_lock l(mu);
|
||||
if (node_name == debug_identity_node_name && output_slot == 0) {
|
||||
|
||||
if (node_name == a_debug_identity_node_name && output_slot == 0) {
|
||||
debug_identity_tensor_vals["a"] = tensor_value;
|
||||
val_callback_count++;
|
||||
} else if (node_name == x_debug_identity_node_name &&
|
||||
output_slot == 0) {
|
||||
// output_slot == 0 carries the debug signal.
|
||||
debug_identity_tensor_vals.push_back(tensor_value);
|
||||
debug_identity_tensor_vals["x"] = tensor_value;
|
||||
val_callback_count++;
|
||||
} else if (node_name == y_debug_identity_node_name &&
|
||||
output_slot == 0) {
|
||||
debug_identity_tensor_vals["y"] = tensor_value;
|
||||
val_callback_count++;
|
||||
}
|
||||
|
||||
@ -389,19 +397,41 @@ TEST_F(SessionDebugMinusAXTest,
|
||||
}
|
||||
});
|
||||
|
||||
int run_counter = 0;
|
||||
mutex run_lock;
|
||||
|
||||
// Function to be executed concurrently.
|
||||
auto fn = [this, &session, output_names, target_nodes, &debug_identity]() {
|
||||
// Create unique debug tensor watch options for each of the two concurrent
|
||||
auto fn = [this, &run_lock, &run_counter, &session, output_names,
|
||||
target_nodes, &debug_identity]() {
|
||||
// Create unique debug tensor watch options for each of the concurrent
|
||||
// run calls.
|
||||
RunOptions run_opts;
|
||||
run_opts.set_output_partition_graphs(true);
|
||||
|
||||
DebugTensorWatch* tensor_watch_opts =
|
||||
run_opts.add_debug_tensor_watch_opts();
|
||||
|
||||
tensor_watch_opts->set_node_name(y_);
|
||||
tensor_watch_opts->set_output_slot(0);
|
||||
tensor_watch_opts->add_debug_ops(debug_identity);
|
||||
|
||||
{
|
||||
// Let the concurrent runs watch different tensors.
|
||||
|
||||
mutex_lock l(run_lock);
|
||||
|
||||
if (run_counter == 0) {
|
||||
// Let the 1st concurrent run watch a.
|
||||
tensor_watch_opts->set_node_name(a_);
|
||||
} else if (run_counter == 1) {
|
||||
// Let the 2nd concurrent watch x.
|
||||
tensor_watch_opts->set_node_name(x_);
|
||||
} else if (run_counter == 2) {
|
||||
// Let the 3rd concurrent watch y.
|
||||
tensor_watch_opts->set_node_name(y_);
|
||||
}
|
||||
|
||||
run_counter++;
|
||||
}
|
||||
|
||||
// Run the graph.
|
||||
RunMetadata run_metadata;
|
||||
std::vector<std::pair<string, Tensor>> inputs;
|
||||
@ -436,15 +466,26 @@ TEST_F(SessionDebugMinusAXTest,
|
||||
|
||||
{
|
||||
mutex_lock l(mu);
|
||||
ASSERT_EQ(kConcurrentRuns, comp_callback_count);
|
||||
|
||||
ASSERT_EQ(kConcurrentRuns, val_callback_count);
|
||||
ASSERT_EQ(kConcurrentRuns, debug_identity_tensor_vals.size());
|
||||
for (int i = 0; i < kConcurrentRuns; ++i) {
|
||||
ASSERT_EQ(TensorShape({2, 1}), debug_identity_tensor_vals[i].shape());
|
||||
auto mat_identity = debug_identity_tensor_vals[i].matrix<float>();
|
||||
ASSERT_EQ(5.0, mat_identity(0, 0));
|
||||
ASSERT_EQ(-1.0, mat_identity(1, 0));
|
||||
}
|
||||
|
||||
ASSERT_EQ(TensorShape({2, 2}), debug_identity_tensor_vals["a"].shape());
|
||||
auto a_mat_identity = debug_identity_tensor_vals["a"].matrix<float>();
|
||||
ASSERT_EQ(3.0, a_mat_identity(0, 0));
|
||||
ASSERT_EQ(2.0, a_mat_identity(0, 1));
|
||||
ASSERT_EQ(-1.0, a_mat_identity(1, 0));
|
||||
ASSERT_EQ(0.0, a_mat_identity(1, 1));
|
||||
|
||||
ASSERT_EQ(TensorShape({2, 1}), debug_identity_tensor_vals["x"].shape());
|
||||
auto x_mat_identity = debug_identity_tensor_vals["x"].matrix<float>();
|
||||
ASSERT_EQ(1.0, x_mat_identity(0, 0));
|
||||
ASSERT_EQ(1.0, x_mat_identity(1, 0));
|
||||
|
||||
ASSERT_EQ(TensorShape({2, 1}), debug_identity_tensor_vals["y"].shape());
|
||||
auto y_mat_identity = debug_identity_tensor_vals["y"].matrix<float>();
|
||||
ASSERT_EQ(5.0, y_mat_identity(0, 0));
|
||||
ASSERT_EQ(-1.0, y_mat_identity(1, 0));
|
||||
}
|
||||
}
|
||||
|
||||
@ -499,25 +540,22 @@ TEST_F(SessionDebugOutputSlotWithoutOngoingEdgeTest,
|
||||
|
||||
Notification callbacks_done;
|
||||
|
||||
debug_gateway.SetNodeCompletionCallback(
|
||||
[&mu, &callbacks_done](const string& node_name, const bool any_output) {
|
||||
mutex_lock l(mu);
|
||||
if (node_name == "_SINK" && !callbacks_done.HasBeenNotified()) {
|
||||
callbacks_done.Notify();
|
||||
}
|
||||
});
|
||||
|
||||
std::vector<Tensor> debug_identity_tensor_vals;
|
||||
debug_gateway.SetNodeValueCallback(
|
||||
[this, &mu, &debug_identity_node_name, &debug_identity_tensor_vals](
|
||||
const string& node_name, const int output_slot,
|
||||
const Tensor& tensor_value, const bool is_ref) {
|
||||
mutex_lock l(mu);
|
||||
debug_gateway.SetNodeValueCallback([this, &mu, &callbacks_done,
|
||||
&debug_identity_node_name,
|
||||
&debug_identity_tensor_vals](
|
||||
const string& node_name, const int output_slot,
|
||||
const Tensor& tensor_value, const bool is_ref) {
|
||||
mutex_lock l(mu);
|
||||
|
||||
if (node_name == debug_identity_node_name && output_slot == 0) {
|
||||
debug_identity_tensor_vals.push_back(tensor_value);
|
||||
}
|
||||
});
|
||||
if (node_name == debug_identity_node_name && output_slot == 0) {
|
||||
debug_identity_tensor_vals.push_back(tensor_value);
|
||||
|
||||
if (!callbacks_done.HasBeenNotified()) {
|
||||
callbacks_done.Notify();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Add DebugIdentity watch on c:0, which does not have an outgoing edge.
|
||||
RunOptions run_opts;
|
||||
|
@ -24,6 +24,30 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {

const string SummarizeDebugTensorWatches(
    const protobuf::RepeatedPtrField<DebugTensorWatch>& watches) {
  std::ostringstream oss;

  for (const DebugTensorWatch& watch : watches) {
    string tensor_name =
        strings::StrCat(watch.node_name(), ":", watch.output_slot());
    oss << tensor_name << "|";

    for (const string& debug_op : watch.debug_ops()) {
      oss << debug_op << ",";
    }

    oss << "@";
    for (const string& debug_url : watch.debug_urls()) {
      oss << debug_url << ",";
    }

    oss << ";";
  }

  return oss.str();
}
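For readers scanning the diff, the string format produced by SummarizeDebugTensorWatches is easiest to see from a small example. The snippet below is an illustration only, not part of the commit; the node name, debug op, and URL are hypothetical values.

// Illustration (hypothetical watch values, not from this commit):
protobuf::RepeatedPtrField<DebugTensorWatch> watches;
DebugTensorWatch* watch = watches.Add();
watch->set_node_name("y");
watch->set_output_slot(0);
watch->add_debug_ops("DebugIdentity");
watch->add_debug_urls("file:///tmp/tfdbg_dump");
// SummarizeDebugTensorWatches(watches) would return
//   "y:0|DebugIdentity,@file:///tmp/tfdbg_dump,;"
// i.e. one "<node>:<slot>|<debug_ops,>@<debug_urls,>;" segment per watch.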
|
||||
|
||||
// static
|
||||
Status DebugNodeInserter::InsertNodes(
|
||||
const protobuf::RepeatedPtrField<DebugTensorWatch>& watches, Graph* graph,
|
||||
|
@ -27,6 +27,10 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Returns a summary string for RepeatedPtrFields of DebugTensorWatches.
|
||||
const string SummarizeDebugTensorWatches(
|
||||
const protobuf::RepeatedPtrField<DebugTensorWatch>& watches);
|
||||
|
||||
class DebugNodeInserter {
|
||||
public:
|
||||
// EXPERIMENTAL: Insert special debug ops (e.g., DebugIdentity) to graph for
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
||||
#include "tensorflow/core/common_runtime/memory_types.h"
|
||||
#include "tensorflow/core/common_runtime/process_util.h"
|
||||
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
||||
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
#include "tensorflow/core/framework/log_memory.h"
|
||||
@ -207,6 +208,11 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
|
||||
if (!s.ok()) {
|
||||
break;
|
||||
}
|
||||
unit->graph = subgraph;
|
||||
unit->build_cost_model = graph_options.build_cost_model();
|
||||
if (unit->build_cost_model > 0) {
|
||||
skip_cost_models_ = false;
|
||||
}
|
||||
}
|
||||
return s;
|
||||
}
|
||||
@ -319,6 +325,7 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
|
||||
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
|
||||
const ExecutorOpts& opts,
|
||||
StepStatsCollector* collector,
|
||||
CostGraphDef* cost_graph,
|
||||
CancellationManager* cancellation_manager,
|
||||
const NamedTensors& in, StatusCallback done) {
|
||||
// Lookup an item. Holds one ref while executing.
|
||||
@ -348,7 +355,7 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
|
||||
return;
|
||||
}
|
||||
|
||||
StartParallelExecutors(handle, item, rendezvous, collector,
|
||||
StartParallelExecutors(handle, item, rendezvous, collector, cost_graph,
|
||||
cancellation_manager,
|
||||
[this, item, rendezvous, done](const Status& s) {
|
||||
done(s);
|
||||
@ -360,6 +367,7 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
|
||||
void GraphMgr::StartParallelExecutors(const string& handle, Item* item,
|
||||
Rendezvous* rendezvous,
|
||||
StepStatsCollector* collector,
|
||||
CostGraphDef* cost_graph,
|
||||
CancellationManager* cancellation_manager,
|
||||
StatusCallback done) {
|
||||
const int num_units = item->units.size();
|
||||
@ -367,7 +375,9 @@ void GraphMgr::StartParallelExecutors(const string& handle, Item* item,
|
||||
ResourceMgr* step_resource_manager = new ResourceMgr;
|
||||
// NOTE: Transfer one ref of rendezvous and item.
|
||||
ExecutorBarrier* barrier = new ExecutorBarrier(
|
||||
num_units, rendezvous, [step_resource_manager, done](const Status& s) {
|
||||
num_units, rendezvous, [this, item, collector, cost_graph,
|
||||
step_resource_manager, done](const Status& s) {
|
||||
BuildCostModel(item, collector, cost_graph);
|
||||
done(s);
|
||||
delete step_resource_manager;
|
||||
});
|
||||
@ -393,4 +403,24 @@ void GraphMgr::StartParallelExecutors(const string& handle, Item* item,
|
||||
}
|
||||
}
|
||||
|
||||
void GraphMgr::BuildCostModel(Item* item, StepStatsCollector* collector,
                              CostGraphDef* cost_graph) {
  if (collector && !skip_cost_models_) {
    // Build the cost model
    std::unordered_map<string, const Graph*> device_to_graph;
    for (const auto& unit : item->units) {
      if (unit.build_cost_model > 0) {
        device_to_graph[unit.device->name()] = unit.graph;
      }
    }
    collector->BuildCostModel(&cost_model_manager_, device_to_graph);

    if (cost_graph != nullptr) {
      for (const auto& unit : item->units) {
        cost_model_manager_.AddToCostGraphDef(unit.graph, cost_graph);
      }
    }
  }
}

} // end namespace tensorflow
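Together with the MasterSession and GrpcWorkerService changes further down, this is the worker-side half of returning a per-node cost graph to the client. A minimal client-side sketch of how the feature would be exercised follows; it is an assumption-laden illustration rather than code from this commit: the grpc target, graph, and fetch name are invented, the GraphOptions values are arbitrary, and it presumes the usual GrpcSession plumbing copies the response metadata into the caller's RunMetadata.

// Hypothetical sketch: asking a distributed session to build cost models.
#include <memory>
#include <vector>
#include "tensorflow/core/public/session.h"

void RunWithCostGraph() {
  tensorflow::SessionOptions options;
  options.target = "grpc://localhost:2222";  // hypothetical tf server
  options.config.mutable_graph_options()->set_build_cost_model(10);       // every 10th step
  options.config.mutable_graph_options()->set_build_cost_model_after(5);  // after 5 warm-up steps

  std::unique_ptr<tensorflow::Session> session(tensorflow::NewSession(options));
  tensorflow::GraphDef graph_def;  // graph construction elided for brevity
  TF_CHECK_OK(session->Create(graph_def));

  tensorflow::RunOptions run_options;
  tensorflow::RunMetadata run_metadata;
  std::vector<tensorflow::Tensor> outputs;
  TF_CHECK_OK(
      session->Run(run_options, {}, {"y:0"}, {}, &outputs, &run_metadata));
  // On steps where costs were collected, run_metadata.cost_graph() carries the
  // estimates gathered by GraphMgr::BuildCostModel and merged by the master.
}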
|
||||
|
@ -19,9 +19,11 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/costmodel_manager.h"
|
||||
#include "tensorflow/core/common_runtime/executor.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
#include "tensorflow/core/framework/cost_graph.pb.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
@ -73,6 +75,7 @@ class GraphMgr {
|
||||
typedef std::function<void(const Status&)> StatusCallback;
|
||||
void ExecuteAsync(const string& handle, const int64 step_id,
|
||||
const ExecutorOpts& opts, StepStatsCollector* collector,
|
||||
CostGraphDef* cost_graph,
|
||||
CancellationManager* cancellation_manager,
|
||||
const NamedTensors& in, StatusCallback done);
|
||||
|
||||
@ -89,9 +92,12 @@ class GraphMgr {
|
||||
typedef GraphMgr ME;
|
||||
|
||||
struct ExecutionUnit {
|
||||
Graph* graph = nullptr;
|
||||
Device* device = nullptr;
|
||||
Executor* root = nullptr;
|
||||
FunctionLibraryRuntime* lib = nullptr;
|
||||
// Build the cost model if this value is strictly positive.
|
||||
int64 build_cost_model = 0;
|
||||
};
|
||||
|
||||
struct Item : public core::RefCounted {
|
||||
@ -117,6 +123,8 @@ class GraphMgr {
|
||||
// Not owned.
|
||||
const WorkerEnv* worker_env_;
|
||||
|
||||
CostModelManager cost_model_manager_;
|
||||
|
||||
// Owned.
|
||||
mutex mu_;
|
||||
int64 next_id_ GUARDED_BY(mu_) = 0;
|
||||
@ -131,9 +139,17 @@ class GraphMgr {
|
||||
void StartParallelExecutors(const string& handle, Item* item,
|
||||
Rendezvous* rendezvous,
|
||||
StepStatsCollector* collector,
|
||||
CostGraphDef* cost_graph,
|
||||
CancellationManager* cancellation_manager,
|
||||
StatusCallback done);
|
||||
|
||||
  // Don't attempt to process cost models unless explicitly requested for at
  // least one of the items.
|
||||
bool skip_cost_models_ = true;
|
||||
|
||||
void BuildCostModel(Item* item, StepStatsCollector* collector,
|
||||
CostGraphDef* cost_graph);
|
||||
|
||||
Status SendInputsToRendezvous(Rendezvous* rendezvous, const NamedTensors& in);
|
||||
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out);
|
||||
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/distributed_runtime/scheduler.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#include "tensorflow/core/framework/cost_graph.pb.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@ -58,6 +59,7 @@ struct PerStepState {
|
||||
Microseconds end_micros = Microseconds(0);
|
||||
std::vector<StepStats> step_stats; // per partition
|
||||
StepStats rpc_stats; // for RPC layer
|
||||
CostGraphDef cost_graph;
|
||||
};
|
||||
|
||||
// MasterSession wraps SimpleClientGraph in a reference counted object.
|
||||
@ -178,7 +180,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
|
||||
// Post-processing of any runtime statistics gathered during execution.
|
||||
void ProcessStats(const MasterEnv* env, int64 step_id, PerStepState* pss,
|
||||
SimpleGraphExecutionState* execution_state,
|
||||
ProfileHandler* ph, RunStepResponse* resp);
|
||||
ProfileHandler* ph, const RunStepRequest& req,
|
||||
RunStepResponse* resp);
|
||||
void ProcessDeviceStats(ProfileHandler* ph,
|
||||
const SimpleGraphExecutionState* execution_state,
|
||||
const DeviceStepStats& ds, bool is_rpc);
|
||||
@ -480,17 +483,6 @@ class RunManyGraphs {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
|
||||
};
|
||||
|
||||
int64 CostFrequency(int64 x) {
|
||||
if (x < 10) {
|
||||
return 1; // 100%
|
||||
} else if (x < 100) {
|
||||
return 10; // 10%
|
||||
} else if (x < 1000) {
|
||||
return 100; // 1%
|
||||
} else {
|
||||
return 1000; // 0.1%
|
||||
}
|
||||
}
|
||||
|
||||
Status MasterSession::ReffedClientGraph::RunPartitions(
|
||||
const MasterEnv* env, int64 step_id, int64 execution_count,
|
||||
@ -604,6 +596,12 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
|
||||
if (pss->collect_timeline && calls.get(i)->resp.has_step_stats()) {
|
||||
pss->step_stats[i].Swap(calls.get(i)->resp.mutable_step_stats());
|
||||
}
|
||||
if (pss->collect_costs && calls.get(i)->resp.has_cost_graph()) {
|
||||
for (int j = 0; j < calls.get(i)->resp.cost_graph().node_size(); ++j) {
|
||||
resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
|
||||
calls.get(i)->resp.mutable_cost_graph()->mutable_node(j));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return status;
|
||||
@ -679,7 +677,7 @@ void MasterSession::ReffedClientGraph::CleanupPartitionsAsync(
|
||||
void MasterSession::ReffedClientGraph::ProcessStats(
|
||||
const MasterEnv* env, int64 step_id, PerStepState* pss,
|
||||
SimpleGraphExecutionState* execution_state, ProfileHandler* ph,
|
||||
RunStepResponse* resp) {
|
||||
const RunStepRequest& req, RunStepResponse* resp) {
|
||||
if (!pss->collect_costs && !pss->collect_timeline) return;
|
||||
|
||||
// Out-of-band logging data is collected now, during post-processing.
|
||||
@ -689,9 +687,6 @@ void MasterSession::ReffedClientGraph::ProcessStats(
|
||||
}
|
||||
for (size_t i = 0; i < partitions_.size(); ++i) {
|
||||
const StepStats& ss = pss->step_stats[i];
|
||||
if (pss->collect_costs) {
|
||||
execution_state->UpdateCostsFromStats(ss);
|
||||
}
|
||||
if (ph) {
|
||||
for (const auto& ds : ss.dev_stats()) {
|
||||
ProcessDeviceStats(ph, execution_state, ds, false /*is_rpc*/);
|
||||
@ -717,7 +712,7 @@ void MasterSession::ReffedClientGraph::ProcessStats(
|
||||
stats_publisher_->PublishStatsProto(step_stats_proto);
|
||||
// Copy the stats back, but only for on-demand profiling to avoid slowing
|
||||
// down calls that trigger the automatic profiling.
|
||||
if (session_opts_.config.graph_options().timeline_step() <= 0) {
|
||||
if (req.options().trace_level() == RunOptions::FULL_TRACE) {
|
||||
resp->mutable_metadata()->mutable_step_stats()->Swap(&step_stats_proto);
|
||||
}
|
||||
}
|
||||
@ -1063,7 +1058,17 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
|
||||
|
||||
  std::unique_ptr<ProfileHandler> ph;
  pss.collect_timeline = req->options().trace_level() == RunOptions::FULL_TRACE;
  pss.collect_costs = (0 == (count % CostFrequency(count)));

  // Build the cost model every 'build_cost_model_every' steps after skipping an
  // initial 'build_cost_model_after' steps.
  const int64 build_cost_model_after =
      session_opts_.config.graph_options().build_cost_model_after();
  const int64 build_cost_model_every =
      session_opts_.config.graph_options().build_cost_model();
  pss.collect_costs =
      build_cost_model_every > 0 &&
      ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
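To make the step-selection arithmetic concrete, here is one hypothetical configuration (the option values are illustrative, not part of the change):

// Illustration: suppose build_cost_model_after = 10 and build_cost_model = 5.
//   pss.collect_costs == ((count + 1 - 10) % 5 == 0)
// which holds at count = 9, 14, 19, ... -- roughly every 5th step once the
// warm-up window has passed. (C++ integer remainder is also zero for negative
// multiples, so an early step such as count = 4 satisfies the test as well.)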
|
||||
|
||||
ph = rcg->GetProfileHandler(step_id, count, req->options());
|
||||
if (ph) {
|
||||
pss.collect_timeline = true;
|
||||
@ -1078,7 +1083,7 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
|
||||
|
||||
// Schedule post-processing and cleanup to be done asynchronously.
|
||||
rcg->Ref();
|
||||
rcg->ProcessStats(env_, step_id, &pss, execution_state_.get(), ph.get(),
|
||||
rcg->ProcessStats(env_, step_id, &pss, execution_state_.get(), ph.get(), *req,
|
||||
resp);
|
||||
rcg->CleanupPartitionsAsync(step_id, [rcg](const Status& s) {
|
||||
if (!s.ok()) {
|
||||
|
@ -329,7 +329,8 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
return;
|
||||
}
|
||||
StepStatsCollector* collector = nullptr;
|
||||
if (call->request.exec_opts().record_timeline()) {
|
||||
if (call->request.exec_opts().record_timeline() ||
|
||||
call->request.exec_opts().record_costs()) {
|
||||
collector = new StepStatsCollector(call->response.mutable_step_stats());
|
||||
// TODO(mrry,pbar): GPU tracing for distributed steps.
|
||||
}
|
||||
@ -345,9 +346,10 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
cancellation_manager_->RegisterCallback(token,
|
||||
[cm]() { cm->StartCancel(); });
|
||||
}
|
||||
CostGraphDef* cost_graph = call->response.mutable_cost_graph();
|
||||
env_->graph_mgr->ExecuteAsync(
|
||||
call->request.graph_handle(), step_id, call->request.exec_opts(),
|
||||
collector, cm, in,
|
||||
collector, cost_graph, cm, in,
|
||||
[this, step_id, call, cm, out, token, collector](Status s) {
|
||||
if (s.ok()) {
|
||||
env_->graph_mgr->RecvOutputs(step_id, out);
|
||||
|
@ -56,7 +56,7 @@ TEST(CommonShapeFnsTest, NoOutputShapeTest) {
|
||||
.Input({{"data", 0, DT_FLOAT}})
|
||||
.Finalize(&def));
|
||||
|
||||
InferenceContext c(&def, op_def, {S({}), S({10})}, {});
|
||||
InferenceContext c(&def, op_def, {S({}), S({10})}, {}, {});
|
||||
TF_EXPECT_OK(NoOutputs(&c));
|
||||
EXPECT_EQ(0, c.num_outputs());
|
||||
}
|
||||
@ -74,14 +74,14 @@ TEST(CommonShapeFnsTest, ScalarShapeTest) {
|
||||
NodeDefBuilder("test", "L2Loss").Input("t", 0, DT_FLOAT).Finalize(&def));
|
||||
|
||||
{
|
||||
InferenceContext c(&def, op_def, {S({})}, {});
|
||||
InferenceContext c(&def, op_def, {S({})}, {}, {});
|
||||
TF_EXPECT_OK(ScalarShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(0, c.Rank(output));
|
||||
}
|
||||
|
||||
{
|
||||
InferenceContext c(&def, op_def, {S({1, 23, 4, 4, 2})}, {});
|
||||
InferenceContext c(&def, op_def, {S({1, 23, 4, 4, 2})}, {}, {});
|
||||
TF_EXPECT_OK(ScalarShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(0, c.Rank(output));
|
||||
@ -108,7 +108,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
.Finalize(&def));
|
||||
|
||||
{
|
||||
InferenceContext c(&def, op_def, {S({2, 3}), S({3, 4})}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 3}), S({3, 4})}, {}, {});
|
||||
TF_EXPECT_OK(MatMulShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
||||
@ -117,7 +117,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
|
||||
{
|
||||
// Unknown inner dimension for one
|
||||
InferenceContext c(&def, op_def, {S({2, -1}), S({3, 4})}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, -1}), S({3, 4})}, {}, {});
|
||||
TF_EXPECT_OK(MatMulShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
||||
@ -126,7 +126,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
|
||||
{
|
||||
// Invalid rank.
|
||||
InferenceContext c(&def, op_def, {S({2}), S({3, 4})}, {});
|
||||
InferenceContext c(&def, op_def, {S({2}), S({3, 4})}, {}, {});
|
||||
auto s = MatMulShape(&c);
|
||||
EXPECT_FALSE(s.ok());
|
||||
EXPECT_TRUE(
|
||||
@ -136,7 +136,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
|
||||
{
|
||||
// Unknown outer dimension
|
||||
InferenceContext c(&def, op_def, {S({2, 3}), S({3, -1})}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 3}), S({3, -1})}, {}, {});
|
||||
TF_EXPECT_OK(MatMulShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
||||
@ -145,7 +145,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
|
||||
{
|
||||
// Inner shapes not compatible
|
||||
InferenceContext c(&def, op_def, {S({2, 5}), S({3, 4})}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 5}), S({3, 4})}, {}, {});
|
||||
auto s = MatMulShape(&c);
|
||||
EXPECT_FALSE(s.ok());
|
||||
EXPECT_TRUE(
|
||||
@ -156,7 +156,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
|
||||
{
|
||||
// Inner shapes not compatible
|
||||
InferenceContext c(&def, op_def, {S({2, 5, 3}), S({3, 5, 4})}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 5, 3}), S({3, 5, 4})}, {}, {});
|
||||
auto s = MatMulShape(&c);
|
||||
EXPECT_FALSE(s.ok());
|
||||
EXPECT_TRUE(
|
||||
@ -174,7 +174,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
.Attr("type", DT_FLOAT)
|
||||
.Finalize(&def));
|
||||
|
||||
InferenceContext c(&def, op_def, {S({3, 2}), S({3, 4})}, {});
|
||||
InferenceContext c(&def, op_def, {S({3, 2}), S({3, 4})}, {}, {});
|
||||
auto s = MatMulShape(&c);
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
||||
@ -191,7 +191,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
.Attr("type", DT_FLOAT)
|
||||
.Finalize(&def));
|
||||
|
||||
InferenceContext c(&def, op_def, {S({2, 3}), S({4, 3})}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 3}), S({4, 3})}, {}, {});
|
||||
auto s = MatMulShape(&c);
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
||||
@ -215,7 +215,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
.Finalize(&def));
|
||||
|
||||
{
|
||||
InferenceContext c(&def, op_def, {S({2, 10}), S({10})}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 10}), S({10})}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
||||
@ -224,7 +224,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
|
||||
{
|
||||
// Unknown ranks.
|
||||
InferenceContext c(&def, op_def, {Unknown(), Unknown()}, {});
|
||||
InferenceContext c(&def, op_def, {Unknown(), Unknown()}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_FALSE(c.RankKnown(output));
|
||||
@ -232,7 +232,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
|
||||
{
|
||||
// Rank > 2
|
||||
InferenceContext c(&def, op_def, {S({4, 3, 4, 2, 15}), S({15})}, {});
|
||||
InferenceContext c(&def, op_def, {S({4, 3, 4, 2, 15}), S({15})}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ("[4,3,4,2,15]", c.DebugString(output));
|
||||
@ -245,7 +245,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
.Input("b", 0, DT_FLOAT)
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(&def, op_def, {S({2, 3, 4, 5}), S({3})}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 3, 4, 5}), S({3})}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ("[2,3,4,5]", c.DebugString(output));
|
||||
@ -258,7 +258,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
.Input("b", 0, DT_FLOAT)
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {});
|
||||
InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {},
|
||||
{});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ("[8,6,4,2,3,4,5]", c.DebugString(output));
|
||||
@ -271,7 +272,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
.Input("b", 0, DT_FLOAT)
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(&def, op_def, {S({10, 11, 12}), S({10})}, {});
|
||||
InferenceContext c(&def, op_def, {S({10, 11, 12}), S({10})}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ("[10,11,12]", c.DebugString(output));
|
||||
@ -279,7 +280,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
|
||||
{
|
||||
// Input rank not high enough
|
||||
InferenceContext c(&def, op_def, {S({3}), S({3})}, {});
|
||||
InferenceContext c(&def, op_def, {S({3}), S({3})}, {}, {});
|
||||
EXPECT_FALSE(BiasAddShape(&c).ok());
|
||||
}
|
||||
|
||||
@ -291,7 +292,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
// NCHW format
|
||||
InferenceContext c(&def, op_def, {S({2, 3}), S({3})}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 3}), S({3})}, {}, {});
|
||||
EXPECT_FALSE(BiasAddShape(&c).ok());
|
||||
}
|
||||
}
|
||||
@ -310,7 +311,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
.Finalize(&def));
|
||||
|
||||
{
|
||||
InferenceContext c(&def, op_def, {S({2, 10})}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 10})}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
|
||||
@ -318,7 +319,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
|
||||
{
|
||||
// Rank > 2
|
||||
InferenceContext c(&def, op_def, {S({5, 7, 2, 10})}, {});
|
||||
InferenceContext c(&def, op_def, {S({5, 7, 2, 10})}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
|
||||
@ -330,7 +331,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
.Input("a", 0, DT_FLOAT)
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(&def, op_def, {S({2, 3, 4, 5})}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 3, 4, 5})}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
|
||||
@ -342,7 +343,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
.Input("a", 0, DT_FLOAT)
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5})}, {});
|
||||
InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5})}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
|
||||
@ -354,7 +355,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
.Input("a", 0, DT_FLOAT)
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(&def, op_def, {S({10, 11, 12})}, {});
|
||||
InferenceContext c(&def, op_def, {S({10, 11, 12})}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
|
||||
@ -362,7 +363,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
|
||||
{
|
||||
// Input rank not high enough
|
||||
InferenceContext c(&def, op_def, {S({3})}, {});
|
||||
InferenceContext c(&def, op_def, {S({3})}, {}, {});
|
||||
EXPECT_FALSE(BiasAddGradShape(&c).ok());
|
||||
}
|
||||
|
||||
@ -373,7 +374,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
// NCHW format
|
||||
InferenceContext c(&def, op_def, {S({2, 3})}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 3})}, {}, {});
|
||||
EXPECT_FALSE(BiasAddGradShape(&c).ok());
|
||||
}
|
||||
}
|
||||
|
@ -400,6 +400,9 @@ class FunctionLibraryRuntime {
|
||||
// Returns a debug string showing the definition of the function of
|
||||
// 'handle'.
|
||||
virtual string DebugString(Handle handle) = 0;
|
||||
|
||||
// Returns the graph version number.
|
||||
virtual int graph_def_version() = 0;
|
||||
};
|
||||
|
||||
// To register a gradient function for a builtin op, one should use
|
||||
|
@ -30,9 +30,10 @@ constexpr int64 InferenceContext::kUnknownDim;
|
||||
InferenceContext::InferenceContext(
|
||||
const NodeDef* node_def, const OpDef& op_def,
|
||||
const std::vector<TensorShapeProto>& input_shapes,
|
||||
const std::vector<const Tensor*>& input_tensors)
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<ShapeHandle>& input_tensors_as_shapes)
|
||||
: node_def_(*CHECK_NOTNULL(node_def)) {
|
||||
PreInputInit(op_def, input_tensors);
|
||||
PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
|
||||
if (!construction_status_.ok()) return;
|
||||
for (const TensorShapeProto& p : input_shapes) {
|
||||
ShapeHandle shape;
|
||||
@ -48,9 +49,10 @@ InferenceContext::InferenceContext(
|
||||
InferenceContext::InferenceContext(
|
||||
const NodeDef* node_def, const OpDef& op_def,
|
||||
const std::vector<ShapeHandle>& input_shapes,
|
||||
const std::vector<const Tensor*>& input_tensors)
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<ShapeHandle>& input_tensors_as_shapes)
|
||||
: node_def_(*CHECK_NOTNULL(node_def)) {
|
||||
PreInputInit(op_def, input_tensors);
|
||||
PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
|
||||
if (!construction_status_.ok()) return;
|
||||
inputs_ = input_shapes;
|
||||
PostInputInit();
|
||||
@ -106,8 +108,10 @@ Status InferenceContext::output(StringPiece output_name,
|
||||
}
|
||||
|
||||
void InferenceContext::PreInputInit(
|
||||
const OpDef& op_def, const std::vector<const Tensor*>& input_tensors) {
|
||||
const OpDef& op_def, const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<ShapeHandle>& input_tensors_as_shapes) {
|
||||
input_tensors_ = input_tensors;
|
||||
input_tensors_as_shapes_ = input_tensors_as_shapes;
|
||||
|
||||
construction_status_ =
|
||||
NameRangesForNode(node_def_, op_def, &input_name_map_, &output_name_map_);
|
||||
@ -139,6 +143,7 @@ void InferenceContext::PostInputInit() {
|
||||
CHECK_LE(input_tensors_.size(), inputs_.size());
|
||||
input_tensors_.resize(inputs_.size());
|
||||
requested_input_tensor_.resize(inputs_.size());
|
||||
requested_input_tensor_as_partial_shape_.resize(inputs_.size());
|
||||
}
|
||||
|
||||
bool InferenceContext::FullyDefined(ShapeHandle s) {
|
||||
@ -470,11 +475,24 @@ Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
|
||||
ShapeHandle input_shape;
|
||||
TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape));
|
||||
|
||||
const Tensor* t = input_tensor(input_idx);
|
||||
if (input_idx < input_tensors_as_shapes_.size() &&
|
||||
input_tensors_as_shapes_[input_idx].IsSet() &&
|
||||
RankKnown(input_tensors_as_shapes_[input_idx])) {
|
||||
*out = input_tensors_as_shapes_[input_idx];
|
||||
return Status::OK();
|
||||
}
|
||||
requested_input_tensor_as_partial_shape_[input_idx] = true;
|
||||
|
||||
return MakeShapeFromTensor(input_tensor(input_idx), input_shape, out);
|
||||
}
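The early return added above is what lets shape functions benefit from a partially known, non-constant shape input. A hypothetical shape function using it might look like the sketch below; the op and its input layout are invented for illustration, only the InferenceContext calls (MakeShapeFromShapeTensor, set_output) come from the API shown in this commit, and the function is assumed to live inside namespace tensorflow like the file above.

// Hypothetical shape function for an op whose input 1 is a shape tensor.
// If the caller supplied input_tensors_as_shapes[1], that partially known
// shape is used even when input_tensor(1) is nullptr (not a constant).
Status ReshapeLikeShapeFn(shape_inference::InferenceContext* c) {
  shape_inference::ShapeHandle out;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
  c->set_output(0, out);
  return Status::OK();
}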
|
||||
|
||||
Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
|
||||
ShapeHandle tensor_shape,
|
||||
ShapeHandle* out) {
|
||||
if (t == nullptr) {
|
||||
    // Shape tensor is not known, but if the shape of the shape tensor is
    // known, then the right number of unknown dims can be created.
|
||||
DimensionHandle shape_dim = Dim(input_shape, 0);
|
||||
DimensionHandle shape_dim = Dim(tensor_shape, 0);
|
||||
if (!ValueKnown(shape_dim)) {
|
||||
return ReturnUnknownShape(out);
|
||||
}
|
||||
@ -493,12 +511,24 @@ Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
|
||||
if (t->dtype() == DataType::DT_INT32) {
|
||||
auto flat_t = t->flat<int32>();
|
||||
for (int i = 0; i < flat_t.size(); ++i) {
|
||||
dims.push_back(MakeDim(flat_t(i)));
|
||||
const int32 val = flat_t(i);
|
||||
if (val < -1) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid value in tensor used for shape: ", val);
|
||||
}
|
||||
// -1 will become an unknown dim.
|
||||
dims.push_back(MakeDim(val));
|
||||
}
|
||||
} else if (t->dtype() == DataType::DT_INT64) {
|
||||
auto flat_t = t->flat<int64>();
|
||||
for (int i = 0; i < flat_t.size(); ++i) {
|
||||
dims.push_back(MakeDim(flat_t(i)));
|
||||
const int64 val = flat_t(i);
|
||||
if (val < -1) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid value in tensor used for shape: ", val);
|
||||
}
|
||||
// -1 will become an unknown dim.
|
||||
dims.push_back(MakeDim(val));
|
||||
}
|
||||
} else {
|
||||
*out = nullptr;
|
||||
@ -558,24 +588,27 @@ Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status InferenceContext::Divide(DimensionHandle dividend, int64 divisor,
|
||||
Status InferenceContext::Divide(DimensionHandle dividend,
|
||||
DimensionOrConstant divisor,
|
||||
bool evenly_divisible, DimensionHandle* out) {
|
||||
if (divisor == 1) {
|
||||
const int64 divisor_value = Value(divisor);
|
||||
if (divisor_value == 1) {
|
||||
*out = dividend;
|
||||
} else if (!ValueKnown(dividend)) {
|
||||
} else if (!ValueKnown(dividend) ||
|
||||
(divisor.dim.IsSet() && !ValueKnown(divisor.dim))) {
|
||||
*out = UnknownDim();
|
||||
} else {
|
||||
const int64 v = Value(dividend);
|
||||
if (divisor <= 0) {
|
||||
if (divisor_value <= 0) {
|
||||
return errors::InvalidArgument("Divisor must be positive but is ",
|
||||
divisor);
|
||||
divisor_value);
|
||||
}
|
||||
if (evenly_divisible && (v % divisor) != 0) {
|
||||
if (evenly_divisible && (v % divisor_value) != 0) {
|
||||
return errors::InvalidArgument(
|
||||
"Dimension size must be evenly divisible by ", divisor, " but is ",
|
||||
v);
|
||||
"Dimension size must be evenly divisible by ", divisor_value,
|
||||
" but is ", v);
|
||||
}
|
||||
*out = MakeDim(v / divisor);
|
||||
*out = MakeDim(v / divisor_value);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -136,17 +136,33 @@ class InferenceContext {
|
||||
|
||||
// <input_tensors> is NULL-padded to be the same size as <input_shapes>.
|
||||
//
|
||||
// Elements of <input_tensors_as_shapes> are used for when a shape function
|
||||
// makes a call to MakeShapeFromShapeTensor; in particular, when the
|
||||
// input_tensors[i] is nullptr but the shape represented by it is partially
|
||||
// known from analysis of the graph.
|
||||
// <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
|
||||
// Values of <input_tensors_as_shapes> do not need to outlive the context.
|
||||
//
|
||||
// REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
|
||||
InferenceContext(const NodeDef* node_def, const OpDef& op_def,
|
||||
const std::vector<ShapeHandle>& input_shapes,
|
||||
const std::vector<const Tensor*>& input_tensors);
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<ShapeHandle>& input_tensors_as_shapes);
|
||||
|
||||
// <input_tensors> is NULL-padded to be the same size as <input_shapes>.
|
||||
//
|
||||
// Elements of <input_tensors_as_shapes> are used for when a shape function
|
||||
// makes a call to MakeShapeFromShapeTensor; in particular, when the
|
||||
// input_tensors[i] is nullptr but the shape represented by it is partially
|
||||
// known from analysis of the graph.
|
||||
// <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
|
||||
// Values of <input_tensors_as_shapes> do not need to outlive the context.
|
||||
//
|
||||
// REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
|
||||
InferenceContext(const NodeDef* node_def, const OpDef& op_def,
|
||||
const std::vector<TensorShapeProto>& input_shapes,
|
||||
const std::vector<const Tensor*>& input_tensors);
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<ShapeHandle>& input_tensors_as_shapes);
|
||||
|
||||
~InferenceContext();
|
||||
|
||||
@ -180,10 +196,21 @@ class InferenceContext {
|
||||
return requested_input_tensor_[idx];
|
||||
}
|
||||
|
||||
// Returns true if MakeShapeFromInputTensor was called but the constant
|
||||
// input_tensor was not present.
|
||||
bool requested_input_tensor_as_partial_shape(int idx) const {
|
||||
return requested_input_tensor_as_partial_shape_[idx];
|
||||
}
|
||||
|
||||
void set_input_tensors(const std::vector<const Tensor*>& input_tensors) {
|
||||
input_tensors_ = input_tensors;
|
||||
}
|
||||
|
||||
void set_input_tensors_as_shapes(
|
||||
const std::vector<ShapeHandle>& input_tensors_as_shapes) {
|
||||
input_tensors_as_shapes_ = input_tensors_as_shapes;
|
||||
}
|
||||
|
||||
void set_output(int idx, ShapeHandle shape) { outputs_[idx] = shape; }
|
||||
Status set_output(StringPiece output_name,
|
||||
const std::vector<ShapeHandle>& shapes);
|
||||
@ -336,8 +363,8 @@ class InferenceContext {
|
||||
// Returns in <out> the result of dividing <dividend> by <divisor>.
|
||||
// Returns an error if <divisor> is not positive or if <evenly_divisible>
|
||||
// and <divisor> does not evenly divide <dividend>.
|
||||
Status Divide(DimensionHandle dividend, int64 divisor, bool evenly_divisible,
|
||||
DimensionHandle* out);
|
||||
Status Divide(DimensionHandle dividend, DimensionOrConstant divisor,
|
||||
bool evenly_divisible, DimensionHandle* out);
|
||||
|
||||
// Returns in <out> the sum of <first> and <second>.
|
||||
Status Add(DimensionHandle first, DimensionOrConstant second,
|
||||
@ -408,6 +435,15 @@ class InferenceContext {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Note that shape functions should usually call MakeShapeFromShapeTensor,
|
||||
// as it does more analysis to provide partial shapes.
|
||||
//
|
||||
// Returns in <out> a new shape whose dimension sizes come from tensor <t>.
|
||||
// The tensor must be a 1-dimensional int32 or int64 tensor. If <t> is NULL,
|
||||
// then an unknown shape is returned.
|
||||
Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape,
|
||||
ShapeHandle* out);
|
||||
|
||||
private:
|
||||
// Creates and stores shapes for use in InferenceContext.
|
||||
class ShapeManager {
|
||||
@ -443,7 +479,8 @@ class InferenceContext {
|
||||
// Shared initialization across the two constructors. Remove
|
||||
// once we get rid of one of them.
|
||||
void PreInputInit(const OpDef& op_def,
|
||||
const std::vector<const Tensor*>& input_tensors);
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<ShapeHandle>& input_tensors_as_shapes);
|
||||
void PostInputInit();
|
||||
|
||||
DimensionHandle GetDimension(const DimensionOrConstant& d);
|
||||
@ -463,11 +500,15 @@ class InferenceContext {
|
||||
|
||||
ShapeManager shape_manager_;
|
||||
|
||||
// inputs_ and outputs_ refer to values from `shape_manager_`.
|
||||
// inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
|
||||
// `shape_manager_`.
|
||||
std::vector<ShapeHandle> inputs_;
|
||||
std::vector<const Tensor*> input_tensors_;
|
||||
std::vector<bool> requested_input_tensor_;
|
||||
std::vector<ShapeHandle> outputs_;
|
||||
// Can have fewer elements than inputs_.
|
||||
std::vector<ShapeHandle> input_tensors_as_shapes_;
|
||||
std::vector<bool> requested_input_tensor_as_partial_shape_;
|
||||
|
||||
const NodeDef& node_def_;
|
||||
NameRangeMap input_name_map_;
|
||||
|
@ -71,7 +71,7 @@ TEST_F(ShapeInferenceTest, InputOutputByName) {
|
||||
.Attr("N", 3)
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Finalize(&def);
|
||||
InferenceContext c(&def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, {});
|
||||
InferenceContext c(&def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, {}, {});
|
||||
|
||||
EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0))));
|
||||
EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1))));
|
||||
@ -107,7 +107,7 @@ static OpDef MakeOpDef(int num_inputs, int num_outputs) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, DimensionOrConstant) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 1), {Unknown()}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 1), {Unknown()}, {}, {});
|
||||
EXPECT_EQ(InferenceContext::kUnknownDim,
|
||||
c.Value(InferenceContext::kUnknownDim));
|
||||
EXPECT_EQ(1, c.Value(1));
|
||||
@ -122,7 +122,7 @@ TEST_F(ShapeInferenceTest, Run) {
|
||||
NodeDef def;
|
||||
def.set_name("foo");
|
||||
def.set_op("foo_op");
|
||||
InferenceContext c(&def, MakeOpDef(3, 2), {S({1})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(3, 2), {S({1})}, {}, {});
|
||||
|
||||
{
|
||||
auto fn = [](InferenceContext* c) {
|
||||
@ -154,7 +154,7 @@ TEST_F(ShapeInferenceTest, Run) {
|
||||
TEST_F(ShapeInferenceTest, RankAndDimInspection) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 2), {Unknown(), S({1, -1, 3}), S({})},
|
||||
{});
|
||||
{}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(2, c.num_outputs());
|
||||
|
||||
@ -195,7 +195,7 @@ TEST_F(ShapeInferenceTest, RankAndDimInspection) {
|
||||
TEST_F(ShapeInferenceTest, NumElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 2),
|
||||
{Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {});
|
||||
{Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {});
|
||||
|
||||
EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(0))));
|
||||
EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(1))));
|
||||
@ -208,7 +208,7 @@ TEST_F(ShapeInferenceTest, NumElements) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, WithRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {});
|
||||
|
||||
auto in0 = c.input(0);
|
||||
auto in1 = c.input(1);
|
||||
@ -246,7 +246,7 @@ TEST_F(ShapeInferenceTest, WithRank) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, WithRankAtMost) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {});
|
||||
|
||||
auto in0 = c.input(0);
|
||||
auto in1 = c.input(1);
|
||||
@ -284,7 +284,7 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, WithRankAtLeast) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {});
|
||||
|
||||
auto in0 = c.input(0);
|
||||
auto in1 = c.input(1);
|
||||
@ -322,7 +322,7 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, WithValue) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, -1})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, -1})}, {}, {});
|
||||
|
||||
auto d0 = c.Dim(c.input(0), 0);
|
||||
auto d1 = c.Dim(c.input(0), 1);
|
||||
@ -363,7 +363,7 @@ TEST_F(ShapeInferenceTest, WithValue) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, MergeDim) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {}, {});
|
||||
|
||||
auto d2 = c.Dim(c.input(0), 0);
|
||||
auto d_unknown = c.Dim(c.input(0), 1);
|
||||
@ -412,7 +412,7 @@ TEST_F(ShapeInferenceTest, MergeShape) {
|
||||
InferenceContext c(&def, MakeOpDef(7, 2),
|
||||
{Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}),
|
||||
Unknown(), S({1})},
|
||||
{});
|
||||
{}, {});
|
||||
|
||||
auto s_unknown = c.input(0);
|
||||
auto s_1_2 = c.input(1);
|
||||
@ -483,7 +483,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
|
||||
{
|
||||
Unknown(), S({-1, 2}), S({1, -1, 3}), S({2, 4}),
|
||||
},
|
||||
{});
|
||||
{}, {});
|
||||
|
||||
auto s_unknown = c.input(0);
|
||||
auto s_u_2 = c.input(1);
|
||||
@ -536,7 +536,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
|
||||
TEST_F(ShapeInferenceTest, Subshape) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(2, 2), {S({1, 2, 3, -1, 5}), Unknown()},
|
||||
{});
|
||||
{}, {});
|
||||
|
||||
ShapeHandle unknown = c.input(1);
|
||||
ShapeHandle out;
|
||||
@ -611,7 +611,7 @@ TEST_F(ShapeInferenceTest, Subshape) {
|
||||
TEST_F(ShapeInferenceTest, Concatenate) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 2),
|
||||
{S({1, -1, 3}), S({4, 5}), Unknown()}, {});
|
||||
{S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {});
|
||||
|
||||
auto in0 = c.input(0);
|
||||
auto in1 = c.input(1);
|
||||
@ -637,7 +637,7 @@ TEST_F(ShapeInferenceTest, Concatenate) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, ReplaceDim) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, {});
|
||||
InferenceContext c(&def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, {}, {});
|
||||
|
||||
auto in = c.input(0);
|
||||
auto unknown = c.input(1);
|
||||
@ -668,7 +668,7 @@ TEST_F(ShapeInferenceTest, ReplaceDim) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, MakeShape) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {}, {});
|
||||
|
||||
std::vector<DimensionHandle> dims;
|
||||
auto in0 = c.input(0);
|
||||
@ -693,7 +693,7 @@ TEST_F(ShapeInferenceTest, MakeShape) {
|
||||
TEST_F(ShapeInferenceTest, UnknownShape) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {});
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
|
||||
|
||||
auto u0 = c.UnknownShape();
|
||||
auto u1 = c.UnknownShape();
|
||||
@ -705,7 +705,7 @@ TEST_F(ShapeInferenceTest, UnknownShape) {
|
||||
TEST_F(ShapeInferenceTest, Scalar) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {});
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
|
||||
|
||||
auto s0 = c.Scalar();
|
||||
EXPECT_EQ("[]", c.DebugString(s0));
|
||||
@ -716,7 +716,7 @@ TEST_F(ShapeInferenceTest, Scalar) {
|
||||
TEST_F(ShapeInferenceTest, Vector) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {});
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
|
||||
|
||||
auto s0 = c.Vector(1);
|
||||
EXPECT_EQ("[1]", c.DebugString(s0));
|
||||
@ -732,7 +732,7 @@ TEST_F(ShapeInferenceTest, Vector) {
|
||||
TEST_F(ShapeInferenceTest, Matrix) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {});
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
|
||||
|
||||
auto s0 = c.Matrix(1, 2);
|
||||
EXPECT_EQ("[1,2]", c.DebugString(s0));
|
||||
@ -754,7 +754,7 @@ TEST_F(ShapeInferenceTest, Matrix) {
|
||||
TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
|
||||
auto create = [&](Tensor* t) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 0), {Unknown()}, {t});
|
||||
InferenceContext c(&def, MakeOpDef(1, 0), {Unknown()}, {t}, {});
|
||||
ShapeHandle out;
|
||||
Status s = c.MakeShapeFromShapeTensor(0, &out);
|
||||
if (s.ok()) {
|
||||
@ -774,6 +774,9 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
|
||||
t = ::tensorflow::test::AsTensor<int64>({3, 2, 1});
|
||||
EXPECT_EQ("[3,2,1]", create(&t));
|
||||
|
||||
t = ::tensorflow::test::AsTensor<int64>({3, -1, 1});
|
||||
EXPECT_EQ("[3,?,1]", create(&t));
|
||||
|
||||
t = ::tensorflow::test::AsTensor<int64>({});
|
||||
EXPECT_EQ("[]", create(&t));
|
||||
|
||||
@ -790,10 +793,20 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
|
||||
EXPECT_TRUE(StringPiece(create(&t))
|
||||
.contains("Input tensor must be rank 1, but was rank 2"));
|
||||
|
||||
// Test negative values for the dims.
|
||||
t = ::tensorflow::test::AsTensor<int64>({3, -2, 1});
|
||||
EXPECT_TRUE(StringPiece(create(&t))
|
||||
.contains("Invalid value in tensor used for shape: -2"));
|
||||
|
||||
// Test negative values for the dims.
|
||||
t = ::tensorflow::test::AsTensor<int32>({3, -2, 1});
|
||||
EXPECT_TRUE(StringPiece(create(&t))
|
||||
.contains("Invalid value in tensor used for shape: -2"));
|
||||
|
||||
// Test when the input shape is wrong.
|
||||
{
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr});
|
||||
InferenceContext c(&def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr}, {});
|
||||
ShapeHandle out;
|
||||
EXPECT_EQ("Shape must be rank 1 but is rank 2",
|
||||
c.MakeShapeFromShapeTensor(0, &out).error_message());
|
||||
@ -803,7 +816,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
|
||||
TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {});
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
|
||||
TensorShapeProto proto;
|
||||
|
||||
// With a set unknown rank.
|
||||
@ -839,7 +852,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
|
||||
TEST_F(ShapeInferenceTest, MakeDim) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {});
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
|
||||
|
||||
auto d0 = c.MakeDim(1);
|
||||
auto d1 = c.MakeDim(1);
|
||||
@ -853,7 +866,7 @@ TEST_F(ShapeInferenceTest, MakeDim) {
|
||||
TEST_F(ShapeInferenceTest, UnknownDim) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {});
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
|
||||
|
||||
auto d0 = c.UnknownDim();
|
||||
auto d1 = c.UnknownDim();
|
||||
@ -865,7 +878,7 @@ TEST_F(ShapeInferenceTest, UnknownDim) {
|
||||
TEST_F(ShapeInferenceTest, UnknownShapeOfRank) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {});
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
|
||||
|
||||
auto unknown_shape_of_rank_3 = c.UnknownShapeOfRank(3);
|
||||
EXPECT_EQ("[?,?,?]", c.DebugString(unknown_shape_of_rank_3));
|
||||
@ -879,7 +892,7 @@ TEST_F(ShapeInferenceTest, InputTensors) {
|
||||
const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
|
||||
{&t1, &t2});
|
||||
{&t1, &t2}, {});
|
||||
|
||||
EXPECT_TRUE(c.input_tensor(0) == &t1);
|
||||
EXPECT_TRUE(c.input_tensor(1) == &t2);
|
||||
@ -890,7 +903,7 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
|
||||
Tensor t1 = tensorflow::test::AsScalar<int32>(20);
|
||||
Tensor t2 = tensorflow::test::AsScalar<int32>(-1);
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2});
|
||||
InferenceContext c(&def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2}, {});
|
||||
|
||||
DimensionHandle d;
|
||||
EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok());
|
||||
@ -921,7 +934,7 @@ TEST_F(ShapeInferenceTest, GetAttr) {
|
||||
.ok());
|
||||
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(&def, op_reg_data.op_def, empty, {});
|
||||
InferenceContext c(&def, op_reg_data.op_def, empty, {}, {});
|
||||
string value;
|
||||
EXPECT_TRUE(c.GetAttr("foo", &value).ok());
|
||||
EXPECT_EQ("bar", value);
|
||||
@ -929,11 +942,14 @@ TEST_F(ShapeInferenceTest, GetAttr) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, Divide) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {}, {});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_6 = c.Dim(s, 0);
|
||||
auto d_unknown = c.Dim(s, 1);
|
||||
auto d_1 = c.Dim(s, 2);
|
||||
auto d_2 = c.Dim(s, 3);
|
||||
auto d_0 = c.Dim(s, 4);
|
||||
bool evenly_divisible = true;
|
||||
|
||||
// Dividing unknown by non-1 gives new unknown.
|
||||
@ -947,9 +963,15 @@ TEST_F(ShapeInferenceTest, Divide) {
|
||||
EXPECT_TRUE(SameHandle(out, d_unknown));
|
||||
EXPECT_TRUE(c.Divide(d_6, 1, evenly_divisible, &out).ok());
|
||||
EXPECT_TRUE(SameHandle(out, d_6));
|
||||
EXPECT_TRUE(c.Divide(d_unknown, d_1, evenly_divisible, &out).ok());
|
||||
EXPECT_TRUE(SameHandle(out, d_unknown));
|
||||
EXPECT_TRUE(c.Divide(d_6, d_1, evenly_divisible, &out).ok());
|
||||
EXPECT_TRUE(SameHandle(out, d_6));
|
||||
|
||||
EXPECT_TRUE(c.Divide(d_6, 2, evenly_divisible, &out).ok());
|
||||
EXPECT_EQ("3", c.DebugString(out));
|
||||
EXPECT_TRUE(c.Divide(d_6, d_2, evenly_divisible, &out).ok());
|
||||
EXPECT_EQ("3", c.DebugString(out));
|
||||
|
||||
EXPECT_TRUE(
|
||||
StringPiece(c.Divide(d_6, 5, evenly_divisible, &out).error_message())
|
||||
@ -958,6 +980,9 @@ TEST_F(ShapeInferenceTest, Divide) {
|
||||
EXPECT_TRUE(
|
||||
StringPiece(c.Divide(d_6, 0, evenly_divisible, &out).error_message())
|
||||
.contains("Divisor must be positive but is 0"));
|
||||
EXPECT_TRUE(
|
||||
StringPiece(c.Divide(d_6, d_0, evenly_divisible, &out).error_message())
|
||||
.contains("Divisor must be positive but is 0"));
|
||||
|
||||
EXPECT_TRUE(
|
||||
StringPiece(c.Divide(d_6, -1, evenly_divisible, &out).error_message())
|
||||
@ -979,7 +1004,7 @@ TEST_F(ShapeInferenceTest, Divide) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, Add) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_6 = c.Dim(s, 0);
|
||||
@ -1030,7 +1055,7 @@ TEST_F(ShapeInferenceTest, Add) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, Subtract) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, {});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_6 = c.Dim(s, 0);
|
||||
@ -1079,7 +1104,7 @@ TEST_F(ShapeInferenceTest, Subtract) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, Multiply) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_6 = c.Dim(s, 0);
|
||||
@ -1132,7 +1157,7 @@ TEST_F(ShapeInferenceTest, Multiply) {
|
||||
TEST_F(ShapeInferenceTest, FullyDefined) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {});
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
|
||||
|
||||
// No rank or missing dimension information should return false.
|
||||
EXPECT_FALSE(c.FullyDefined(c.UnknownShape()));
|
||||
@ -1145,7 +1170,7 @@ TEST_F(ShapeInferenceTest, FullyDefined) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, Min) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, {});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_1 = c.Dim(s, 0);
|
||||
@ -1193,7 +1218,7 @@ TEST_F(ShapeInferenceTest, Min) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, Max) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_1 = c.Dim(s, 0);
|
||||
@ -1231,7 +1256,7 @@ TEST_F(ShapeInferenceTest, Max) {
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapes) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {Unknown(), Unknown(), Unknown()},
|
||||
{});
|
||||
{}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -1243,7 +1268,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapes) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownDims) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})},
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})}, {},
|
||||
{});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
@ -1256,7 +1281,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownDims) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidIndicesRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {},
|
||||
{});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -1269,7 +1295,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidIndicesRank) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidNumElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {},
|
||||
{});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -1282,7 +1309,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidNumElements) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {},
|
||||
{});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -1295,7 +1323,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidRank) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumIndexElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {},
|
||||
{});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -1307,7 +1336,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumIndexElements) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumValueElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {},
|
||||
{});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -1319,7 +1349,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumValueElements) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownIndexRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {},
|
||||
{});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -1331,7 +1362,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownIndexRank) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapeRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {},
|
||||
{});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -1343,7 +1375,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapeRank) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {});
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {},
|
||||
{});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
|
@ -44,7 +44,8 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
|
||||
}
|
||||
|
||||
shape_inference::InferenceContext c(&op.node_def, op_reg_data->op_def,
|
||||
in_shapes, op.input_tensors);
|
||||
in_shapes, op.input_tensors,
|
||||
{} /* input_tensors_as_shapes */);
|
||||
TF_RETURN_IF_ERROR(c.construction_status());
|
||||
if (op_reg_data->shape_inference_fn == nullptr) {
|
||||
return errors::InvalidArgument(
|
||||
|
@ -243,6 +243,11 @@ void CostModel::RecordMaxMemorySize(const Node* node, int output_slot,
|
||||
if (id < 0) return;
|
||||
Ensure(id);
|
||||
auto& current_max = max_mem_usage_[id].output_port_mem[output_slot];
|
||||
// If the memory allocator doesn't track memory usage, let's infer a lower
|
||||
// bound from the tensor shape and its data type.
|
||||
if (bytes.value() < 0) {
|
||||
bytes = MinTensorMemoryUsage(tensor_shape, dtype);
|
||||
}
|
||||
if (bytes.value() > current_max.value()) {
|
||||
current_max = bytes.value();
|
||||
max_mem_usage_[id].output_port_shape[output_slot] = tensor_shape;
|
||||
@ -476,4 +481,18 @@ void CostModel::WriteSummaryToLog() const {
|
||||
}
|
||||
}
|
||||
|
||||
Bytes CostModel::MinTensorMemoryUsage(const TensorShapeProto& tensor_shape,
                                      const DataType& dtype) {
  if (tensor_shape.unknown_rank()) {
    return Bytes(-1);
  }

  size_t num_coefficients = 1;
  for (const TensorShapeProto::Dim& dim : tensor_shape.dim()) {
    // If the dimension is unknown, it has to be at least 1
    num_coefficients *= std::max<size_t>(dim.size(), 1);
  }
  return Bytes(num_coefficients * DataTypeSize(dtype));
}

}  // namespace tensorflow
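As a quick sanity check of the lower bound above, with hypothetical inputs (not taken from the change):

// Illustration: a fully known DT_FLOAT tensor of shape [2, 3].
//   num_coefficients = 2 * 3 = 6
//   MinTensorMemoryUsage(shape, DT_FLOAT) == Bytes(6 * DataTypeSize(DT_FLOAT))
//                                         == Bytes(24)
// A shape with unknown rank returns Bytes(-1) instead, i.e. no usable bound.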