TensorFlow: upstream changes to git
Change 109366961 TensorFlow BUILD: now that we have an ops library, set linkstatic to 1. This fixes a breakage in the would-be opensource build, and it *might* mean we can get rid of all of the RequireDefaultOps() calls in our code. The ops library is much smaller than the kernels library that was previously linked together. We set linkstatic=0 presumably since we didn't want to package a static copy of the kernels (very large) everywhere. But the op definitions are small, so this seems like a safe change to make. Time to build the various tests was not any longer after this change, and inspecting the example_trainer binary showed no large increase. Change 109363613 TensorFlow: new graph_def_builder_test needs to RequireDefaultOps. Change 109362569 Split ":ops" out of ":kernels" target in tensorflow/core. Change 109360666 Catch dtype and some shape errors sooner in `QueueBase`. Some avoidable errors were not being caught (e.g. the dtypes of the enqueue components were not checked against the queue's dtypes in Python), leading to cryptic messages at runtime. After this CL, they will be caught earlier. Change 109359569 TensorFlow: Expect g_ != nullptr in test Change 109350735 Add a version number to GraphDef We would like to be able to deprecate behavior in newly generated graphs without invalidating tensorflow's ability to read and evaluate old graphs. For this purpose, GraphDef now has a version field which can be checked inside op kernels to determine how backwards compatible to be. version.h defines TF_GRAPHDEF_VERSION_MIN and TF_GRAPHDEF_VERSION_MAX specifying the range of supported GraphDef versions in the current version of tensorflow. Also expose tf.__version__ and tf.__graph_def_version{,_min,_max}__ for Python interrogation purposes. Whenever we want to deprecate or change some GraphDef semantics, we will proceed as follows: 1. Bump TF_GRAPHDEF_VERSION_MAX, leaving TF_GRAPHDEF_VERSION_MIN unchanged. Describe the change in graph.proto, include the date introduced. 2. In each relevant kernel, implement the new behavior if the GraphDef version is new, but preserve the old behavior for previous GraphDef versions. 3. Wait six months or so (we need to formalize this somewhere). 4. Bump TF_GRAPHDEF_VERSION_MIN and remove the backwards compatibility. The GraphDef version is distinct from the open source version, but at least (4) and possibly (1) correspond to major version number bumps. The first GraphDef version bump is the upcoming scalar strictness change, which affects Google users only since open source is already scalar strict. This commit does not yet plumb the version number into OpKernelConstruction so that ops can access it. That will follow. Change 109350260 Made TensorShapeProto implicitly convertible to TensorShape. Base CL: 109366982
This commit is contained in:
parent
eb5e56e479
commit
54a644f33f
tensorflow
core
BUILD
common_runtime
framework
graph
equal_graph_def.ccequal_graph_def_test.ccgraph.ccgraph.hgraph_constructor.ccgraph_constructor_test.ccgraph_def_builder_test.ccgraph_partition.ccgraph_partition_test.cc
public
python
@ -187,10 +187,7 @@ cc_library(
|
||||
"graph/testlib.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = [
|
||||
":friends",
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":core_cpu",
|
||||
":tensorflow",
|
||||
@ -213,11 +210,9 @@ tf_cuda_library(
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "kernels",
|
||||
name = "ops",
|
||||
srcs = glob(
|
||||
[
|
||||
"kernels/**/*.h",
|
||||
"kernels/**/*.cc",
|
||||
"ops/**/*.h",
|
||||
"ops/**/*.cc",
|
||||
"user_ops/**/*.h",
|
||||
@ -226,14 +221,38 @@ tf_cuda_library(
|
||||
exclude = [
|
||||
"**/*test*",
|
||||
"**/*main.cc",
|
||||
"kernels/**/*.cu.cc",
|
||||
"user_ops/**/*.cu.cc",
|
||||
],
|
||||
),
|
||||
copts = tf_copts(),
|
||||
linkstatic = 1,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":core",
|
||||
":lib",
|
||||
":protos_cc",
|
||||
"//tensorflow/models/embedding:word2vec_ops",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "kernels",
|
||||
srcs = glob(
|
||||
[
|
||||
"kernels/**/*.h",
|
||||
"kernels/**/*.cc",
|
||||
],
|
||||
exclude = [
|
||||
"**/*test*",
|
||||
"**/*main.cc",
|
||||
"kernels/**/*.cu.cc",
|
||||
],
|
||||
),
|
||||
copts = tf_copts(),
|
||||
cuda_deps = [
|
||||
":gpu_kernels",
|
||||
":cuda",
|
||||
],
|
||||
linkstatic = 0,
|
||||
visibility = ["//visibility:public"],
|
||||
@ -241,10 +260,10 @@ tf_cuda_library(
|
||||
"@gemmlowp//:eight_bit_int_gemm",
|
||||
":core",
|
||||
":lib",
|
||||
":ops",
|
||||
":protos_cc",
|
||||
":stream_executor",
|
||||
"//tensorflow/models/embedding:word2vec_kernels",
|
||||
"//tensorflow/models/embedding:word2vec_ops",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
@ -262,6 +281,7 @@ tf_gpu_kernel_library(
|
||||
),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":cuda",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
@ -416,6 +436,7 @@ tf_cc_tests(
|
||||
":direct_session",
|
||||
":kernels",
|
||||
":lib",
|
||||
":ops",
|
||||
":strict_headers",
|
||||
":test_main",
|
||||
":testlib",
|
||||
|
@ -164,6 +164,11 @@ Status DirectSession::Extend(const GraphDef& graph) {
|
||||
}
|
||||
|
||||
Status DirectSession::ExtendLocked(const GraphDef& graph) {
|
||||
if (graph_created_ && graph_def_.version() != graph.version()) {
|
||||
return errors::InvalidArgument("Incompatible GraphDef versions in Extend: ",
|
||||
graph_def_.version(), " != ",
|
||||
graph.version());
|
||||
}
|
||||
graph_created_ = true; // In case this is first call
|
||||
graph_def_.MergeFrom(graph);
|
||||
return Status::OK();
|
||||
|
@ -980,6 +980,7 @@ static void ToGraphDef(const Graph* g, GraphDef* gdef) {
|
||||
}
|
||||
gtl::InlinedVector<const Edge*, 4> inputs;
|
||||
gdef->Clear();
|
||||
gdef->set_version(g->version());
|
||||
while (!ready.empty()) {
|
||||
const Node* n = ready.front();
|
||||
ready.pop_front();
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace test {
|
||||
@ -27,6 +28,7 @@ typedef FunctionDefHelper FDH;
|
||||
GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
|
||||
gtl::ArraySlice<FunctionDef> funcs) {
|
||||
GraphDef g;
|
||||
g.set_version(TF_GRAPH_DEF_VERSION);
|
||||
for (auto n : nodes) {
|
||||
*(g.add_node()) = n;
|
||||
}
|
||||
|
@ -15,6 +15,17 @@ import "tensorflow/core/framework/function.proto";
|
||||
message GraphDef {
|
||||
repeated NodeDef node = 1;
|
||||
|
||||
// Compatibility version of the graph. Newly created graphs use
|
||||
// the most recent version. Version history:
|
||||
//
|
||||
// 0. Graphs created before GraphDef versioning
|
||||
// 1. First real version (2dec2015)
|
||||
//
|
||||
// The GraphDef version is distinct from the TensorFlow version.
|
||||
// Each released version of TensorFlow will support a range of
|
||||
// GraphDef versions.
|
||||
int32 version = 3;
|
||||
|
||||
// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
|
||||
//
|
||||
// "library" provides user-defined functions.
|
||||
|
@ -24,6 +24,7 @@ namespace tensorflow {
|
||||
|
||||
string SummarizeGraphDef(const GraphDef& graph_def) {
|
||||
string ret;
|
||||
strings::StrAppend(&ret, "version = ", graph_def.version(), ";\n");
|
||||
for (const NodeDef& node : graph_def.node()) {
|
||||
strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n");
|
||||
}
|
||||
|
@ -26,6 +26,14 @@ namespace tensorflow {
|
||||
|
||||
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
|
||||
string* diff) {
|
||||
if (actual.version() != expected.version()) {
|
||||
if (diff != nullptr) {
|
||||
*diff = strings::StrCat("Expected version ", expected.version(),
|
||||
", got version ", actual.version());
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::unordered_map<string, const NodeDef*> actual_index;
|
||||
for (const NodeDef& node : actual.node()) {
|
||||
actual_index[node.name()] = &node;
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -88,10 +89,11 @@ TEST_F(EqualGraphDefTest, ExtraNode) {
|
||||
Input(a_.opts().WithName("A"));
|
||||
Input(a_.opts().WithName("B"));
|
||||
EXPECT_FALSE(Match());
|
||||
EXPECT_EQ(
|
||||
"Found unexpected node 'B = Input[]()' not in expected graph:\n"
|
||||
"A = Input[]();\n",
|
||||
diff_);
|
||||
EXPECT_EQ(strings::StrCat(
|
||||
"Found unexpected node 'B = Input[]()' not in expected graph:\n"
|
||||
"version = ",
|
||||
TF_GRAPH_DEF_VERSION, ";\nA = Input[]();\n"),
|
||||
diff_);
|
||||
}
|
||||
|
||||
TEST_F(EqualGraphDefTest, NodeOrder) {
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -105,7 +106,7 @@ Node::Properties::~Properties() {}
|
||||
// Graph
|
||||
|
||||
Graph::Graph(const OpRegistryInterface* ops)
|
||||
: ops_(ops), arena_(8 << 10 /* 8kB */) {
|
||||
: ops_(ops), version_(TF_GRAPH_DEF_VERSION), arena_(8 << 10 /* 8kB */) {
|
||||
// Source and sink have no endpoints, just control edges.
|
||||
NodeDef def;
|
||||
def.set_name("_SOURCE");
|
||||
@ -253,6 +254,7 @@ void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
|
||||
|
||||
void Graph::ToGraphDef(GraphDef* graph_def) const {
|
||||
graph_def->Clear();
|
||||
graph_def->set_version(version());
|
||||
std::vector<const Edge*>
|
||||
inputs; // Construct this outside the loop for speed.
|
||||
for (const Node* node : nodes()) {
|
||||
|
@ -187,11 +187,17 @@ class Graph {
|
||||
// single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
|
||||
//
|
||||
// The graph can hold ops found in registry.
|
||||
//
|
||||
// The version defaults to TF_GRAPH_DEF_VERSION.
|
||||
explicit Graph(const OpRegistryInterface* registry);
|
||||
~Graph();
|
||||
|
||||
static const int kControlSlot = -1;
|
||||
|
||||
// The GraphDef version of this graph (see graph.proto).
|
||||
int version() const { return version_; }
|
||||
void set_version(int version) { version_ = version; }
|
||||
|
||||
// Adds a new node to this graph, and returns it. Infers the Op and
|
||||
// input/output types for the node. *this owns the returned instance.
|
||||
// Returns nullptr and sets *status on error.
|
||||
@ -274,6 +280,9 @@ class Graph {
|
||||
// Registry of all known ops. Not owned.
|
||||
const OpRegistryInterface* const ops_;
|
||||
|
||||
// GraphDef version
|
||||
int version_;
|
||||
|
||||
// Allocator which will give us good locality.
|
||||
core::Arena arena_;
|
||||
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/regexp.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -45,6 +46,19 @@ class GraphConstructor {
|
||||
GraphConstructor(const GraphConstructorOptions& opts, const GraphDef* gdef,
|
||||
Graph* g, Status* status)
|
||||
: opts_(opts), gdef_(gdef), g_(g), status_(status) {
|
||||
const int version = gdef->version();
|
||||
if (!(TF_GRAPH_DEF_VERSION_MIN <= version &&
|
||||
version <= TF_GRAPH_DEF_VERSION_MAX)) {
|
||||
bool low = version < TF_GRAPH_DEF_VERSION_MAX;
|
||||
*status = errors::InvalidArgument(
|
||||
"GraphDef version ", version, " is ", low ? "no longer" : "not yet",
|
||||
" supported: TensorFlow ", TF_VERSION_STRING, " needs ",
|
||||
TF_GRAPH_DEF_VERSION_MAX, " <= version <= ", TF_GRAPH_DEF_VERSION_MIN,
|
||||
". ",
|
||||
low ? "Please regenerate your graph." : "Please upgrade TensorFlow.");
|
||||
return;
|
||||
}
|
||||
g->set_version(gdef->version());
|
||||
BuildNodeIndex();
|
||||
InitFromEdges();
|
||||
Convert();
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/regexp.h"
|
||||
#include "tensorflow/core/public/status.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
// TODO(josh11b): Test InitCostModel().
|
||||
// TODO(josh11b): Test setting the "device" field of a NodeDef.
|
||||
@ -58,6 +59,12 @@ class GraphConstructorTest : public ::testing::Test {
|
||||
TF_CHECK_OK(ConvertGraphDefToGraph(opts, gdef_, g_.get()));
|
||||
}
|
||||
|
||||
void ExpectVersion(int version) {
|
||||
EXPECT_NE(nullptr, g_);
|
||||
EXPECT_EQ(version, g_->version()) << "Expected version " << version
|
||||
<< ", got " << g_->version();
|
||||
}
|
||||
|
||||
Node* FindNode(const string& name) {
|
||||
for (Node* n : g_->nodes()) {
|
||||
if (n->name() == name) return n;
|
||||
@ -160,7 +167,30 @@ TEST_F(GraphConstructorTest, TypeMismatch) {
|
||||
"expected int32.");
|
||||
}
|
||||
|
||||
TEST_F(GraphConstructorTest, EmptyGraph) { ExpectOK(""); }
|
||||
TEST_F(GraphConstructorTest, EmptyGraph) {
|
||||
ExpectOK("");
|
||||
ExpectVersion(0); // The default GraphDef version is 0
|
||||
}
|
||||
|
||||
TEST_F(GraphConstructorTest, VersionGraph) {
|
||||
ASSERT_LT(0, TF_GRAPH_DEF_VERSION); // Verify the assertion is nontrivial
|
||||
ExpectOK(strings::StrCat("version: ", TF_GRAPH_DEF_VERSION));
|
||||
ExpectVersion(TF_GRAPH_DEF_VERSION);
|
||||
}
|
||||
|
||||
TEST_F(GraphConstructorTest, LowVersion) {
|
||||
ExpectError(strings::StrCat("version: ", -1),
|
||||
R"(^GraphDef version -1 is no longer supported: TensorFlow \S+ )"
|
||||
R"(needs \d+ <= version <= \d+\. )"
|
||||
R"(Please regenerate your graph\.$)");
|
||||
}
|
||||
|
||||
TEST_F(GraphConstructorTest, HighVersion) {
|
||||
ExpectError(strings::StrCat("version: ", TF_GRAPH_DEF_VERSION_MAX + 1),
|
||||
R"(^GraphDef version \d+ is not yet supported: TensorFlow \S+ )"
|
||||
R"(needs \d+ <= version <= \d+\. )"
|
||||
R"(Please upgrade TensorFlow\.$)");
|
||||
}
|
||||
|
||||
TEST_F(GraphConstructorTest, SimpleModel) {
|
||||
ExpectOK(
|
||||
|
48
tensorflow/core/graph/graph_def_builder_test.cc
Normal file
48
tensorflow/core/graph/graph_def_builder_test.cc
Normal file
@ -0,0 +1,48 @@
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(GraphDefBuilderTest, Version) {
|
||||
RequireDefaultOps();
|
||||
|
||||
// Verify that our assertions will be nontrivial
|
||||
ASSERT_LT(0, TF_GRAPH_DEF_VERSION);
|
||||
|
||||
// Newly built graphs should use the current version
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
|
||||
// Check version when we convert to a Graph
|
||||
Graph graph(OpRegistry::Global());
|
||||
EXPECT_OK(builder.ToGraph(&graph));
|
||||
ASSERT_EQ(graph.version(), TF_GRAPH_DEF_VERSION);
|
||||
|
||||
// Check version when we convert to a GraphDef
|
||||
GraphDef graph_def;
|
||||
EXPECT_OK(builder.ToGraphDef(&graph_def));
|
||||
ASSERT_EQ(graph_def.version(), TF_GRAPH_DEF_VERSION);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -1051,6 +1051,11 @@ Status Partition(const PartitionOptions& opts, Graph* g,
|
||||
}
|
||||
}
|
||||
|
||||
// Set versions
|
||||
for (auto& it : *partitions) {
|
||||
it.second.set_version(g->version());
|
||||
}
|
||||
|
||||
// Set the start times for recvs at the very end.
|
||||
if (opts.scheduling_for_recvs) {
|
||||
for (auto& it : dup_recv) {
|
||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -72,6 +73,12 @@ void Partition(const GraphDef& graph_def,
|
||||
popts.control_flow_added = false;
|
||||
Status s = Partition(popts, &g, partitions);
|
||||
CHECK(s.ok()) << s;
|
||||
|
||||
// Check versions
|
||||
EXPECT_EQ(graph_def.version(), TF_GRAPH_DEF_VERSION);
|
||||
for (auto& it : *partitions) {
|
||||
EXPECT_EQ(graph_def.version(), it.second.version());
|
||||
}
|
||||
}
|
||||
|
||||
void CheckLoopConstruction(const GraphDef& graph_def) {
|
||||
|
@ -36,4 +36,9 @@ limitations under the License.
|
||||
|
||||
// TODO(josh11b): Public API functions for exporting the above.
|
||||
|
||||
// Supported GraphDef versions (see graph.proto).
|
||||
#define TF_GRAPH_DEF_VERSION_MIN 0
|
||||
#define TF_GRAPH_DEF_VERSION_MAX 1
|
||||
#define TF_GRAPH_DEF_VERSION TF_GRAPH_DEF_VERSION_MAX
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_PUBLIC_VERSION_H_
|
||||
|
@ -138,6 +138,7 @@ py_library(
|
||||
"framework/tensor_shape.py",
|
||||
"framework/dtypes.py",
|
||||
"framework/tensor_util.py",
|
||||
"framework/versions.py",
|
||||
"ops/common_shapes.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
@ -195,6 +196,18 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "framework_versions_test",
|
||||
srcs = ["framework/versions_test.py"],
|
||||
main = "framework/versions_test.py",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":framework_test_lib",
|
||||
":platform_test",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "framework_importer_test",
|
||||
srcs = ["framework/importer_test.py"],
|
||||
|
@ -48,6 +48,7 @@ from tensorflow.core.util.event_pb2 import *
|
||||
|
||||
# Framework
|
||||
from tensorflow.python.framework.framework_lib import *
|
||||
from tensorflow.python.framework.versions import *
|
||||
from tensorflow.python.framework import errors
|
||||
|
||||
# Session
|
||||
@ -81,3 +82,4 @@ _whitelist = set([app, compat, errors, flags, image, logging, nn,
|
||||
_whitelist.update([ops, tensor_util]) # pylint: disable=undefined-variable
|
||||
__all__ = [name for name, x in locals().items() if not name.startswith('_') and
|
||||
(not inspect.ismodule(x) or x in _whitelist)]
|
||||
__all__.append('__version__')
|
||||
|
@ -34,6 +34,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import versions
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import constant_op
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -425,7 +426,8 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testGraphDef(self):
|
||||
with session.Session() as sess:
|
||||
self.assertProtoEquals('', sess.graph_def)
|
||||
self.assertProtoEquals('version: %d' % versions.GRAPH_DEF_VERSION,
|
||||
sess.graph_def)
|
||||
c = constant_op.constant(5.0, name='c')
|
||||
self.assertEquals(len(sess.graph_def.node), 1)
|
||||
d = constant_op.constant(6.0, name='d')
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/python/client/tf_session_helper.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/public/status.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
%}
|
||||
|
||||
@ -32,6 +33,12 @@ limitations under the License.
|
||||
tensorflow::ImportNumpy();
|
||||
%}
|
||||
|
||||
// TensorFlow version and GraphDef versions
|
||||
%constant const char* __version__ = TF_VERSION_STRING;
|
||||
%constant int GRAPH_DEF_VERSION = TF_GRAPH_DEF_VERSION;
|
||||
%constant int GRAPH_DEF_VERSION_MIN = TF_GRAPH_DEF_VERSION_MIN;
|
||||
%constant int GRAPH_DEF_VERSION_MAX = TF_GRAPH_DEF_VERSION_MAX;
|
||||
|
||||
// Release the Python GIL for the duration of most methods.
|
||||
%exception {
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
|
@ -215,6 +215,7 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
|
||||
|
||||
with ops.op_scope(input_map.values(), name, 'import'):
|
||||
g = ops.get_default_graph()
|
||||
g.graph_def_version = graph_def.version
|
||||
|
||||
with ops.name_scope('_inputs'):
|
||||
input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}
|
||||
|
@ -111,7 +111,8 @@ for op_def in _op_list.op:
|
||||
|
||||
class ImportGraphDefTest(tf.test.TestCase):
|
||||
|
||||
def _MakeGraphDef(self, text):
|
||||
def _MakeGraphDef(self, text, version=tf.GRAPH_DEF_VERSION):
|
||||
text = "version: %d\n%s" % (version, text)
|
||||
ret = tf.GraphDef()
|
||||
text_format.Merge(text, ret)
|
||||
return ret
|
||||
@ -610,6 +611,28 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
g = tf.identity(t)
|
||||
g.eval()
|
||||
|
||||
def testVersion(self):
|
||||
for version in tf.GRAPH_DEF_VERSION_MIN, tf.GRAPH_DEF_VERSION_MAX:
|
||||
with tf.Graph().as_default():
|
||||
a, = tf.import_graph_def(
|
||||
self._MakeGraphDef("node { name: 'A' op: 'Oii' }", version=version),
|
||||
return_elements=['A'])
|
||||
self.assertEqual(a.graph.graph_def_version, version)
|
||||
|
||||
def testVersionLow(self):
|
||||
with tf.Graph().as_default():
|
||||
pat = (r"^GraphDef version -1 is no longer supported: TensorFlow \S+ "
|
||||
r"needs \d+ <= version <= \d+. Please regenerate your graph.$")
|
||||
with self.assertRaisesRegexp(ValueError, pat):
|
||||
tf.import_graph_def(self._MakeGraphDef("", version=-1))
|
||||
|
||||
def testVersionHigh(self):
|
||||
with tf.Graph().as_default():
|
||||
pat = (r"^GraphDef version \d+ is not yet supported: TensorFlow \S+ "
|
||||
r"needs \d+ <= version <= \d+. Please upgrade TensorFlow.$")
|
||||
with self.assertRaisesRegexp(ValueError, pat):
|
||||
tf.import_graph_def(self._MakeGraphDef("", version=1 << 30))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
@ -37,6 +37,7 @@ from tensorflow.python.framework import device as pydev
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import registry
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import versions
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
@ -1545,6 +1546,7 @@ class Graph(object):
|
||||
@@seed
|
||||
@@unique_name
|
||||
@@version
|
||||
@@graph_def_version
|
||||
|
||||
@@create_op
|
||||
@@gradient_override_map
|
||||
@ -1585,6 +1587,8 @@ class Graph(object):
|
||||
self._finalized = False
|
||||
# Functions defined in the graph
|
||||
self._functions = collections.OrderedDict()
|
||||
# Default GraphDef version
|
||||
self._graph_def_version = versions.GRAPH_DEF_VERSION
|
||||
|
||||
def _check_not_finalized(self):
|
||||
"""Check if the graph is finalized.
|
||||
@ -1620,9 +1624,36 @@ class Graph(object):
|
||||
|
||||
@property
|
||||
def version(self):
|
||||
"""Returns a version number that increases as ops are added to the graph."""
|
||||
"""Returns a version number that increases as ops are added to the graph.
|
||||
|
||||
Note that this is unrelated to the
|
||||
[GraphDef version](#Graph.graph_def_version).
|
||||
"""
|
||||
return self._next_id_counter
|
||||
|
||||
@property
|
||||
def graph_def_version(self):
|
||||
"""The GraphDef version of this graph.
|
||||
|
||||
For details on the meaning of each version, see [`GraphDef`]
|
||||
(https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto).
|
||||
"""
|
||||
return self._graph_def_version
|
||||
|
||||
@graph_def_version.setter
|
||||
def graph_def_version(self, version):
|
||||
if not (versions.GRAPH_DEF_VERSION_MIN <= version <=
|
||||
versions.GRAPH_DEF_VERSION_MAX):
|
||||
low = version < versions.GRAPH_DEF_VERSION_MIN
|
||||
raise ValueError(
|
||||
"GraphDef version %d is %s supported: TensorFlow %s needs %d <= "
|
||||
"version <= %d. Please %s." %
|
||||
(version, "no longer" if low else "not yet",
|
||||
versions.__version__, versions.GRAPH_DEF_VERSION_MIN,
|
||||
versions.GRAPH_DEF_VERSION_MAX,
|
||||
"regenerate your graph" if low else "upgrade TensorFlow"))
|
||||
self._graph_def_version = version
|
||||
|
||||
@property
|
||||
def seed(self):
|
||||
return self._seed
|
||||
@ -1684,6 +1715,7 @@ class Graph(object):
|
||||
ValueError: If the `graph_def` would be too large.
|
||||
"""
|
||||
graph = graph_pb2.GraphDef()
|
||||
graph.version = self._graph_def_version
|
||||
bytesize = 0
|
||||
for op_id in sorted(self._nodes_by_id):
|
||||
op = self._nodes_by_id[op_id]
|
||||
|
@ -410,7 +410,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
|
||||
op = g.create_op("an_op", [], [dtypes.float32])
|
||||
self.assertEqual(None, op.device)
|
||||
gd = g.as_graph_def()
|
||||
self.assertProtoEquals("""
|
||||
self.assertProtoEqualsVersion("""
|
||||
node { name: "an_op" op: "an_op" }
|
||||
""", gd)
|
||||
|
||||
@ -419,7 +419,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
|
||||
with g.device("/job:worker/replica:2"):
|
||||
g.create_op("an_op", [], [dtypes.float32])
|
||||
gd = g.as_graph_def()
|
||||
self.assertProtoEquals("""
|
||||
self.assertProtoEqualsVersion("""
|
||||
node { name: "an_op" op: "an_op" device: "/job:worker/replica:2" }
|
||||
""", gd)
|
||||
|
||||
@ -430,7 +430,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
|
||||
device_index=3)):
|
||||
g.create_op("an_op", [], [dtypes.float32])
|
||||
gd = g.as_graph_def()
|
||||
self.assertProtoEquals("""
|
||||
self.assertProtoEqualsVersion("""
|
||||
node { name: "an_op" op: "an_op"
|
||||
device: "/job:worker/replica:2/task:0/device:CPU:3" }
|
||||
""", gd)
|
||||
@ -443,7 +443,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
|
||||
g.create_op("an_op", [], [dtypes.float32])
|
||||
g.create_op("an_op", [], [dtypes.float32])
|
||||
gd = g.as_graph_def()
|
||||
self.assertProtoEquals("""
|
||||
self.assertProtoEqualsVersion("""
|
||||
node { name: "an_op" op: "an_op"
|
||||
device: "/job:worker/replica:2" }
|
||||
node { name: "an_op_1" op: "an_op"
|
||||
@ -460,7 +460,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
|
||||
g.create_op("an_op", [], [dtypes.float32])
|
||||
g.create_op("an_op", [], [dtypes.float32])
|
||||
gd = g.as_graph_def()
|
||||
self.assertProtoEquals("""
|
||||
self.assertProtoEqualsVersion("""
|
||||
node { name: "an_op" op: "an_op"
|
||||
device: "/job:worker/replica:2" }
|
||||
node { name: "an_op_1" op: "an_op"
|
||||
@ -477,7 +477,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
|
||||
g.create_op("an_op", [], [dtypes.float32])
|
||||
g.create_op("an_op", [], [dtypes.float32])
|
||||
gd = g.as_graph_def()
|
||||
self.assertProtoEquals("""
|
||||
self.assertProtoEqualsVersion("""
|
||||
node { name: "an_op" op: "an_op"
|
||||
device: "/job:worker/replica:2/device:CPU:1" }
|
||||
node { name: "an_op_1" op: "an_op"
|
||||
@ -501,7 +501,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
|
||||
g.create_op("an_op", [], [dtypes.float32])
|
||||
|
||||
gd = g.as_graph_def()
|
||||
self.assertProtoEquals("""
|
||||
self.assertProtoEqualsVersion("""
|
||||
node { name: "an_op" op: "an_op"
|
||||
device: "/device:GPU:0" }
|
||||
node { name: "an_op_1" op: "an_op"
|
||||
@ -522,7 +522,7 @@ class DeviceTest(test_util.TensorFlowTestCase):
|
||||
g.create_op("an_op", [], [dtypes.float32])
|
||||
g.create_op("an_op", [], [dtypes.float32])
|
||||
gd = g.as_graph_def()
|
||||
self.assertProtoEquals("""
|
||||
self.assertProtoEqualsVersion("""
|
||||
node { name: "an_op" op: "an_op"
|
||||
device: "/job:worker/replica:2/device:CPU:1" }
|
||||
node { name: "an_op_1" op: "an_op" }
|
||||
|
@ -20,6 +20,8 @@ from __future__ import print_function
|
||||
|
||||
import tensorflow.python.platform
|
||||
|
||||
from tensorflow.core.framework import tensor_shape_pb2
|
||||
|
||||
|
||||
class Dimension(object):
|
||||
"""Represents the value of one dimension in a TensorShape."""
|
||||
@ -407,6 +409,8 @@ class TensorShape(object):
|
||||
# TODO(irving): Eliminate the single integer special case.
|
||||
if dims is None:
|
||||
self._dims = None
|
||||
elif isinstance(dims, tensor_shape_pb2.TensorShapeProto):
|
||||
self._dims = [as_dimension(dim.size) for dim in dims.dim]
|
||||
else:
|
||||
try:
|
||||
dims_iter = iter(dims)
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import tensorflow.python.platform
|
||||
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
@ -254,6 +255,19 @@ class ShapeTest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaisesRegexp(TypeError, r"unsupported operand type"):
|
||||
unknown / unknown # pylint: disable=pointless-statement
|
||||
|
||||
def testConvertFromProto(self):
|
||||
proto = tensor_util.MakeTensorShapeProto([])
|
||||
self.assertEqual(tensor_shape.TensorShape([]),
|
||||
tensor_shape.TensorShape(proto))
|
||||
self.assertEqual(tensor_shape.TensorShape([]),
|
||||
tensor_shape.as_shape(proto))
|
||||
|
||||
proto = tensor_util.MakeTensorShapeProto([1, 37, 42])
|
||||
self.assertEqual(tensor_shape.TensorShape([1, 37, 42]),
|
||||
tensor_shape.TensorShape(proto))
|
||||
self.assertEqual(tensor_shape.TensorShape([1, 37, 42]),
|
||||
tensor_shape.as_shape(proto))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
@ -38,6 +38,7 @@ from tensorflow.python.client import graph_util
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import versions
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.platform import logging
|
||||
from tensorflow.python.util.protobuf import compare
|
||||
@ -113,6 +114,11 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
type(expected_message_maybe_ascii) + " and " +
|
||||
type(message))
|
||||
|
||||
def assertProtoEqualsVersion(self, expected, actual,
|
||||
version=versions.GRAPH_DEF_VERSION):
|
||||
expected = "version: %d\n%s" % (version, expected)
|
||||
self.assertProtoEquals(expected, actual)
|
||||
|
||||
def assertStartsWith(self, actual, expected_start, msg=None):
|
||||
"""Assert that actual.startswith(expected_start) is True.
|
||||
|
||||
|
33
tensorflow/python/framework/versions.py
Normal file
33
tensorflow/python/framework/versions.py
Normal file
@ -0,0 +1,33 @@
|
||||
# Copyright 2015 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""TensorFlow versions."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow.python.platform
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
|
||||
__version__ = pywrap_tensorflow.__version__
|
||||
GRAPH_DEF_VERSION = pywrap_tensorflow.GRAPH_DEF_VERSION
|
||||
GRAPH_DEF_VERSION_MIN = pywrap_tensorflow.GRAPH_DEF_VERSION_MIN
|
||||
GRAPH_DEF_VERSION_MAX = pywrap_tensorflow.GRAPH_DEF_VERSION_MAX
|
||||
|
||||
# Make sure these symbols are exported even though one starts with _.
|
||||
__all__ = ["__version__", "GRAPH_DEF_VERSION", "GRAPH_DEF_VERSION_MIN",
|
||||
"GRAPH_DEF_VERSION_MAX"]
|
45
tensorflow/python/framework/versions_test.py
Normal file
45
tensorflow/python/framework/versions_test.py
Normal file
@ -0,0 +1,45 @@
|
||||
# Copyright 2015 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for exposed tensorflow versions."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow.python.platform
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class VersionTest(tf.test.TestCase):
|
||||
|
||||
def testVersion(self):
|
||||
self.assertEqual(type(tf.__version__), str)
|
||||
# This pattern will need to grow as we include alpha, builds, etc.
|
||||
self.assertRegexpMatches(tf.__version__, r'^\d+\.\d+\.\d+$')
|
||||
|
||||
def testGraphDefVersion(self):
|
||||
version = tf.GRAPH_DEF_VERSION
|
||||
min = tf.GRAPH_DEF_VERSION_MIN
|
||||
max = tf.GRAPH_DEF_VERSION_MAX
|
||||
for v in version, min, max:
|
||||
self.assertEqual(type(v), int)
|
||||
self.assertLessEqual(0, min)
|
||||
self.assertLessEqual(min, version)
|
||||
self.assertLessEqual(version, max)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
@ -344,6 +344,42 @@ class FIFOQueueTest(tf.test.TestCase):
|
||||
self.assertAllEqual(dequeued_t.eval(), elems)
|
||||
|
||||
def testEnqueueWrongShape(self):
|
||||
q = tf.FIFOQueue(10, (tf.int32, tf.int32), ((), (2)))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
q.enqueue(([1, 2], [2, 2]))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
q.enqueue_many((7, [[1, 2], [3, 4], [5, 6]]))
|
||||
|
||||
def testBatchSizeMismatch(self):
|
||||
q = tf.FIFOQueue(10, (tf.int32, tf.int32, tf.int32), ((), (), ()))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
q.enqueue_many(([1, 2, 3], [1, 2], [1, 2, 3]))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
q.enqueue_many(([1, 2, 3], [1, 2], tf.placeholder(tf.int32)))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
q.enqueue_many((tf.placeholder(tf.int32), [1, 2], [1, 2, 3]))
|
||||
|
||||
def testEnqueueManyEmptyTypeConversion(self):
|
||||
q = tf.FIFOQueue(10, (tf.int32, tf.float32), ((), ()))
|
||||
enq = q.enqueue_many(([], []))
|
||||
self.assertEqual(tf.int32, enq.inputs[1].dtype)
|
||||
self.assertEqual(tf.float32, enq.inputs[2].dtype)
|
||||
|
||||
def testEnqueueWrongType(self):
|
||||
q = tf.FIFOQueue(10, (tf.int32, tf.float32), ((), ()))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
q.enqueue((tf.placeholder(tf.int32), tf.placeholder(tf.int32)))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
q.enqueue_many((tf.placeholder(tf.int32), tf.placeholder(tf.int32)))
|
||||
|
||||
def testEnqueueWrongShapeAtRuntime(self):
|
||||
with self.test_session() as sess:
|
||||
q = tf.FIFOQueue(10, (tf.int32, tf.int32), ((2, 2), (3, 3)))
|
||||
elems_ok = np.array([1] * 4).reshape((2, 2)).astype(np.int32)
|
||||
@ -353,8 +389,6 @@ class FIFOQueueTest(tf.test.TestCase):
|
||||
tf.errors.InvalidArgumentError, r"Expected \[3,3\], got \[3,4\]"):
|
||||
sess.run([enqueue_op],
|
||||
feed_dict={elems_bad: np.array([1] * 12).reshape((3, 4))})
|
||||
sess.run([enqueue_op],
|
||||
feed_dict={elems_bad: np.array([1] * 12).reshape((3, 4))})
|
||||
|
||||
def testEnqueueDequeueManyWrongShape(self):
|
||||
with self.test_session() as sess:
|
||||
|
@ -485,5 +485,74 @@ class LSTMTest(tf.test.TestCase):
|
||||
self._testDoubleInputWithDropoutAndDynamicCalculation(True)
|
||||
|
||||
|
||||
class BidirectionalRNNTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._seed = 23489
|
||||
np.random.seed(self._seed)
|
||||
|
||||
def _testBidirectionalRNN(self, use_gpu):
|
||||
num_units = 3
|
||||
input_size = 5
|
||||
batch_size = 2
|
||||
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
|
||||
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
|
||||
sequence_length = tf.placeholder(tf.int64)
|
||||
cell_fw = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units, input_size, initializer=initializer)
|
||||
cell_bw = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units, input_size, initializer=initializer)
|
||||
inputs = 10 * [
|
||||
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
|
||||
outputs = tf.nn.bidirectional_rnn(
|
||||
cell_fw, cell_bw, inputs, dtype=tf.float32,
|
||||
sequence_length=sequence_length)
|
||||
|
||||
self.assertEqual(len(outputs), len(inputs))
|
||||
for out in outputs:
|
||||
self.assertEqual(out.get_shape().as_list(), [batch_size, 2 * num_units])
|
||||
|
||||
tf.initialize_all_variables().run()
|
||||
input_value = np.random.randn(batch_size, input_size)
|
||||
# Run with pre-specified sequence length of 2, 3
|
||||
out = sess.run(outputs, feed_dict={inputs[0]: input_value,
|
||||
sequence_length: [2, 3]})
|
||||
|
||||
# Since the forward and backward LSTM cells were initialized with the
|
||||
# same parameters, the forward and backward output has to be the same,
|
||||
# but reversed in time. The format is output[time][batch][depth], and
|
||||
# due to depth concatenation (as num_units=3 for both RNNs):
|
||||
# - forward output: out[][][depth] for 0 <= depth < 3
|
||||
# - backward output: out[][][depth] for 4 <= depth < 6
|
||||
#
|
||||
# First sequence in batch is length=2
|
||||
# Check that the time=0 forward output is equal to time=1 backward output
|
||||
self.assertEqual(out[0][0][0], out[1][0][3])
|
||||
self.assertEqual(out[0][0][1], out[1][0][4])
|
||||
self.assertEqual(out[0][0][2], out[1][0][5])
|
||||
# Check that the time=1 forward output is equal to time=0 backward output
|
||||
self.assertEqual(out[1][0][0], out[0][0][3])
|
||||
self.assertEqual(out[1][0][1], out[0][0][4])
|
||||
self.assertEqual(out[1][0][2], out[0][0][5])
|
||||
|
||||
# Second sequence in batch is length=3
|
||||
# Check that the time=0 forward output is equal to time=2 backward output
|
||||
self.assertEqual(out[0][1][0], out[2][1][3])
|
||||
self.assertEqual(out[0][1][1], out[2][1][4])
|
||||
self.assertEqual(out[0][1][2], out[2][1][5])
|
||||
# Check that the time=1 forward output is equal to time=1 backward output
|
||||
self.assertEqual(out[1][1][0], out[1][1][3])
|
||||
self.assertEqual(out[1][1][1], out[1][1][4])
|
||||
self.assertEqual(out[1][1][2], out[1][1][5])
|
||||
# Check that the time=2 forward output is equal to time=0 backward output
|
||||
self.assertEqual(out[2][1][0], out[0][1][3])
|
||||
self.assertEqual(out[2][1][1], out[0][1][4])
|
||||
self.assertEqual(out[2][1][2], out[0][1][5])
|
||||
|
||||
def testBidirectionalRNN(self):
|
||||
self._testBidirectionalRNN(use_gpu=False)
|
||||
self._testBidirectionalRNN(use_gpu=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
|
@ -157,6 +157,26 @@ class QueueBase(object):
|
||||
"""The list of dtypes for each component of a queue element."""
|
||||
return self._dtypes
|
||||
|
||||
def _check_enqueue_dtypes(self, vals):
|
||||
"""Returns `vals` as a list of `Tensor`s, having checked their dtypes.
|
||||
|
||||
Args:
|
||||
vals: A tensor or a list of tensors, corresponding to an
|
||||
enqueue(_many) tuple.
|
||||
|
||||
Returns:
|
||||
A list of `Tensor` objects.
|
||||
"""
|
||||
if not isinstance(vals, (list, tuple)):
|
||||
vals = [vals]
|
||||
|
||||
tensors = []
|
||||
for i, (val, dtype) in enumerate(zip(vals, self._dtypes)):
|
||||
tensors.append(ops.convert_to_tensor(val, dtype=dtype,
|
||||
name="component_%d" % i))
|
||||
|
||||
return tensors
|
||||
|
||||
def enqueue(self, vals, name=None):
|
||||
"""Enqueues one element to this queue.
|
||||
|
||||
@ -170,16 +190,18 @@ class QueueBase(object):
|
||||
Returns:
|
||||
The operation that enqueues a new tuple of tensors to the queue.
|
||||
"""
|
||||
if name is None:
|
||||
name = "%s_enqueue" % self._name
|
||||
ret = gen_data_flow_ops._queue_enqueue(self._queue_ref, vals, name=name)
|
||||
if not isinstance(vals, (list, tuple)):
|
||||
vals = [vals]
|
||||
|
||||
# NOTE(mrry): Not using a shape function because we need access to
|
||||
# the Queue object.
|
||||
for val, shape in zip(ret.inputs[1:], self._shapes):
|
||||
val.get_shape().assert_is_compatible_with(shape)
|
||||
with ops.op_scope(vals, name, "%s_enqueue" % self._name) as scope:
|
||||
vals = self._check_enqueue_dtypes(vals)
|
||||
|
||||
return ret
|
||||
# NOTE(mrry): Not using a shape function because we need access to
|
||||
# the `QueueBase` object.
|
||||
for val, shape in zip(vals, self._shapes):
|
||||
val.get_shape().assert_is_compatible_with(shape)
|
||||
|
||||
return gen_data_flow_ops._queue_enqueue(self._queue_ref, vals, name=scope)
|
||||
|
||||
def enqueue_many(self, vals, name=None):
|
||||
"""Enqueues zero or elements to this queue.
|
||||
@ -199,20 +221,22 @@ class QueueBase(object):
|
||||
Returns:
|
||||
The operation that enqueues a batch of tuples of tensors to the queue.
|
||||
"""
|
||||
if name is None:
|
||||
name = "%s_EnqueueMany" % self._name
|
||||
if not isinstance(vals, (list, tuple)):
|
||||
vals = [vals]
|
||||
|
||||
ret = gen_data_flow_ops._queue_enqueue_many(
|
||||
self._queue_ref, vals, name=name)
|
||||
with ops.op_scope(vals, name, "%s_EnqueueMany" % self._name) as scope:
|
||||
vals = self._check_enqueue_dtypes(vals)
|
||||
|
||||
# NOTE(mrry): Not using a shape function because we need access to
|
||||
# the `QueueBase` object.
|
||||
batch_dim = ret.inputs[1].get_shape()[0]
|
||||
for val, shape in zip(ret.inputs[1:], self._shapes):
|
||||
batch_dim.merge_with(val.get_shape()[0])
|
||||
val.get_shape()[1:].assert_is_compatible_with(shape)
|
||||
# NOTE(mrry): Not using a shape function because we need access to
|
||||
# the `QueueBase` object.
|
||||
batch_dim = vals[0].get_shape().with_rank_at_least(1)[0]
|
||||
for val, shape in zip(vals, self._shapes):
|
||||
batch_dim = batch_dim.merge_with(
|
||||
val.get_shape().with_rank_at_least(1)[0])
|
||||
val.get_shape()[1:].assert_is_compatible_with(shape)
|
||||
|
||||
return ret
|
||||
return gen_data_flow_ops._queue_enqueue_many(
|
||||
self._queue_ref, vals, name=scope)
|
||||
|
||||
def dequeue(self, name=None):
|
||||
"""Dequeues one element from this queue.
|
||||
|
@ -148,3 +148,92 @@ def state_saving_rnn(cell, inputs, state_saver, state_name,
|
||||
outputs[-1] = array_ops.identity(outputs[-1])
|
||||
|
||||
return (outputs, states)
|
||||
|
||||
|
||||
def _reverse_seq(input_seq, lengths):
|
||||
"""Reverse a list of Tensors up to specified lengths.
|
||||
|
||||
Args:
|
||||
input_seq: Sequence of seq_len tensors of dimension (batch_size, depth)
|
||||
lengths: A tensor of dimension batch_size, containing lengths for each
|
||||
sequence in the batch. If "None" is specified, simply reverses
|
||||
the list.
|
||||
|
||||
Returns:
|
||||
time-reversed sequence
|
||||
"""
|
||||
if lengths is None:
|
||||
return list(reversed(input_seq))
|
||||
|
||||
# Join into (time, batch_size, depth)
|
||||
s_joined = array_ops.pack(input_seq)
|
||||
# Reverse along dimension 0
|
||||
s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1)
|
||||
# Split again into list
|
||||
result = array_ops.unpack(s_reversed)
|
||||
return result
|
||||
|
||||
|
||||
def bidirectional_rnn(cell_fw, cell_bw, inputs,
|
||||
initial_state_fw=None, initial_state_bw=None,
|
||||
dtype=None, sequence_length=None, scope=None):
|
||||
"""Creates a bidirectional recurrent neural network.
|
||||
|
||||
Similar to the unidirectional case above (rnn) but takes input and builds
|
||||
independent forward and backward RNNs with the final forward and backward
|
||||
outputs depth-concatenated, such that the output will have the format
|
||||
[time][batch][cell_fw.output_size + cell_bw.output_size]. The initial state
|
||||
for both directions is zero by default (but can be set optionally) and no
|
||||
intermediate states are ever returned -- the network is fully unrolled for
|
||||
the given (passed in) length(s) of the sequence(s).
|
||||
|
||||
Args:
|
||||
cell_fw: An instance of RNNCell, to be used for forward direction.
|
||||
cell_bw: An instance of RNNCell, to be used for backward direction.
|
||||
inputs: A length T list of inputs, each a vector with shape [batch_size].
|
||||
initial_state_fw: (optional) An initial state for the forward RNN.
|
||||
This must be a tensor of appropriate type and shape
|
||||
[batch_size x cell.state_size].
|
||||
initial_state_bw: (optional) Same as for initial_state_fw.
|
||||
dtype: (optional) The data type for the initial state. Required if either
|
||||
of the initial states are not provided.
|
||||
sequence_length: An int64 vector (tensor) of size [batch_size], containing
|
||||
the actual lengths for each of the sequences.
|
||||
scope: VariableScope for the created subgraph; defaults to "BiRNN"
|
||||
|
||||
Returns:
|
||||
A set of output `Tensors` where:
|
||||
outputs is a length T list of outputs (one for each input), which
|
||||
are depth-concatenated forward and backward outputs
|
||||
|
||||
Raises:
|
||||
TypeError: If "cell_fw" or "cell_bw" is not an instance of RNNCell.
|
||||
ValueError: If inputs is None or an empty list.
|
||||
ValueError: If sequence_length is not defined.
|
||||
"""
|
||||
|
||||
if not isinstance(cell_fw, rnn_cell.RNNCell):
|
||||
raise TypeError("cell_fw must be an instance of RNNCell")
|
||||
if not isinstance(cell_bw, rnn_cell.RNNCell):
|
||||
raise TypeError("cell_bw must be an instance of RNNCell")
|
||||
if not isinstance(inputs, list):
|
||||
raise TypeError("inputs must be a list")
|
||||
if not sequence_length:
|
||||
raise ValueError("sequence_length has to be defined")
|
||||
if not inputs:
|
||||
raise ValueError("inputs must not be empty")
|
||||
|
||||
name = scope or "BiRNN"
|
||||
# Forward direction
|
||||
with vs.variable_scope(name + "_FW"):
|
||||
output_fw, _ = rnn(cell_fw, inputs, initial_state_fw, dtype)
|
||||
# Backward direction
|
||||
with vs.variable_scope(name + "_BW"):
|
||||
tmp, _ = rnn(
|
||||
cell_bw, _reverse_seq(inputs, sequence_length), initial_state_bw, dtype)
|
||||
output_bw = _reverse_seq(tmp, sequence_length)
|
||||
# Concat each of the forward/backward outputs
|
||||
outputs = [array_ops.concat(1, [fw, bw])
|
||||
for fw, bw in zip(output_fw, output_bw)]
|
||||
|
||||
return outputs
|
||||
|
Loading…
Reference in New Issue
Block a user