From 795f35da2d458cbae477ac2fe2bff80c1427a771 Mon Sep 17 00:00:00 2001 From: Vijay Vasudevan Date: Tue, 1 Dec 2015 13:26:53 -0800 Subject: [PATCH] TensorFlow: upstream changes to git Change: Clean up documentation for ReverseSequence Change: Updated several tensorflow operations to use 32bit indices on GPU. Change: Add attribute batch_dim to ReverseSequenceOp. Change: Fix error in convert_to_records.py. As reported in https://github.com/tensorflow/tensorflow/issues/370 by AlexUnderMicrocontRoll. Change: Update TensorBoard README. Change: Fixes to boolean flags reported in https://github.com/tensorflow/tensorflow/issues/379. Supports: --bool_flag=True --> True --bool_flag=False --> False --bool_flag=gibberish --> False --bool_flag --> True --nobool_flag --> False Fixes #379 Change: Update generated Op docs. Change: Enable local development of TensorBoard using gulp Also make tf-tensorboard a regular component rather than special case This is mostly effected by creating tfserve.js, which is a small server with clever routing to load from bower_components/ and components/ using the paths that work within google3. Workflow: `gulp serve` Change: Add a full working code example to the tensorboard and summaries tutorial Change: Fix seq2seq_test when running on GPU. The "proj_w" and "proj_b" variables were being created before the `test_session()`'s device function took effect, which pushed the placement algorithm into making an incorrect decision. Change: Add a sentence in TensorBoard README on how to serialize summary data to logs and provide link to the how-to tutorial on the TensorFlow website. Change: Add error-catching code if string_input_producer is supplied a null input. Before this change, it would die with an opaque shape error from inside the queue. This change catches (most) python null lists being passed directly in, and at runtime detects null tensors. Adds two tests for this to input_test.py Change: Speed up for models that use the same variable multiple times in the case where variables must be copied across devices: - Have Variables wrap the Variable op in an Identity op when converted to Tensor. This avoids multiple copies across devices if a variable is used multiple time in a computation. - Add Variable.mutable() to return the non-wrapped Variable op for used when assigning new values. - Add an as_ref parameter to convert_to_tensor() to allow code to specify if they plan to assign a new value to the result of the conversion. Make Variable return the result of Variable.mutable() when as_ref is True. - Make all ops that assign values to variables pass as_ref=True when converting their arguments. Change: Change to reduce critical section times in gpu_event_mgr.h: (1) Call stream->ThenRecordEvent outside the EventMgr critical section (2) Do memory deallocation outside the critical section Speeds up one configuration of ptb_word_lm from 2924 words per second (wps) to 3278 wps on my desktop machine with a Titan X. Change: Remove some colons that break the open source build ::tensorflow::StringPiece breaks for @raingo, see https://github.com/tensorflow/tensorflow/issues/358. tensorflow::StringPiece (without the leading colons) seems to fix the problem. Change: Added check that inputs to Operation is a list and make a defensive copy of the input. This is for cases where the input list is changed such as in _add_input. Change: Use standard names for TensorFlow dtypes in the tutorial. Change: Add tests for tensor inputs. Change: Fix build after declaring more types for ops Change: Switch to 32 bit indexing to speedup convolutions and concatenations. Change: Add convert_image op to convert between types for images (similar to OpenCV's cvtScale). Change: Make cast work between numeric types (bool, uint8, int16, int32, int64, float, double). Change: Padding input data for odd number of paddings, so we can use cudnn anyway. + Fix total padding computation when padding==VALID. + This CL makes the Googlenet benchmark run 5x faster. Change: Support IndexedSlices in ConcatGrad Change: * sampled softmax op uses one embedding lookup for positive and negative samples * float64 support for sampled softmax Change: Move RNN code out of models.rnn (without breaking existing code). The API may still undergo minor changes, until full documentation as added. Change: Changed to use per-step stacks for the accumulators used in while-loop gradient computation. This addresses the problem caused by using concat without sufficient static shape information. It should also improve performance as we avoided those expensive concats. Change: Update generated Op docs. Change: Improve error messages when the optimizer finds no variables to minimize or when none of the variables has gradients. Change: Say that -1 isn't just for flattening in reshape docs Also add scalar reshape (reshape(t, [])) as an example. This fixes https://github.com/tensorflow/tensorflow/issues/281. Change: This is a test. Base CL: 109118714 --- tensorflow/core/BUILD | 7 +- tensorflow/core/common_runtime/executor.cc | 8 +- .../core/common_runtime/gpu/gpu_event_mgr.cc | 34 +- .../core/common_runtime/gpu/gpu_event_mgr.h | 74 +- .../common_runtime/gpu/gpu_event_mgr_test.cc | 16 +- tensorflow/core/framework/device_base.h | 7 +- tensorflow/core/kernels/cast_op.cc | 104 ++- tensorflow/core/kernels/cast_op_gpu.cu.cc | 38 +- tensorflow/core/kernels/cast_op_test.cc | 44 +- tensorflow/core/kernels/concat_op_gpu.cu.cc | 8 +- tensorflow/core/kernels/constant_op_gpu.cu.cc | 4 +- tensorflow/core/kernels/conv_grad_ops.cc | 147 ++-- tensorflow/core/kernels/conv_ops_gpu.cu.cc | 8 +- tensorflow/core/kernels/cwise_op_div.cc | 18 +- .../core/kernels/cwise_op_gpu_div.cu.cc | 2 +- .../core/kernels/cwise_op_gpu_mul.cu.cc | 2 +- tensorflow/core/kernels/cwise_op_mul.cc | 18 +- tensorflow/core/kernels/cwise_ops_common.h | 5 + .../core/kernels/cwise_ops_gpu_common.cu.h | 28 +- tensorflow/core/kernels/lrn_op.cc | 47 +- tensorflow/core/kernels/reference_gemm.h | 90 -- .../core/kernels/reverse_sequence_op.cc | 60 +- tensorflow/core/kernels/reverse_sequence_op.h | 20 +- tensorflow/core/kernels/softsign_op.cc | 112 +++ tensorflow/core/kernels/softsign_op.h | 60 ++ tensorflow/core/kernels/softsign_op_gpu.cu.cc | 40 + tensorflow/core/kernels/split_op_gpu.cu.cc | 2 +- tensorflow/core/kernels/stack_ops.cc | 15 + tensorflow/core/ops/array_ops.cc | 52 +- tensorflow/core/ops/math_ops.cc | 4 +- tensorflow/core/ops/nn_ops.cc | 21 + tensorflow/core/ops/ops.pbtxt | 74 +- tensorflow/core/public/README.md | 8 +- tensorflow/examples/label_image/main.cc | 10 +- tensorflow/g3doc/api_docs/python/array_ops.md | 12 +- tensorflow/g3doc/api_docs/python/framework.md | 3 +- tensorflow/g3doc/api_docs/python/image.md | 44 +- tensorflow/g3doc/api_docs/python/index.md | 4 + tensorflow/g3doc/api_docs/python/io_ops.md | 6 + tensorflow/g3doc/api_docs/python/math_ops.md | 6 +- tensorflow/g3doc/api_docs/python/nn.md | 112 ++- .../g3doc/api_docs/python/sparse_ops.md | 14 +- tensorflow/g3doc/api_docs/python/state_ops.md | 45 + tensorflow/g3doc/api_docs/python/train.md | 5 +- tensorflow/g3doc/get_started/basic_usage.md | 4 +- .../g3doc/how_tos/adding_an_op/index.md | 2 +- .../how_tos/adding_an_op/zero_out_2_test.py | 3 +- .../reading_data/convert_to_records.py | 2 +- .../summaries_and_tensorboard/index.md | 68 +- .../mnist_with_summaries.py | 69 ++ .../g3doc/tutorials/mnist/beginners/index.md | 12 +- tensorflow/models/embedding/BUILD | 3 + tensorflow/models/rnn/BUILD | 46 - tensorflow/models/rnn/linear.py | 53 +- tensorflow/models/rnn/rnn.py | 133 +-- tensorflow/models/rnn/rnn_cell.py | 610 +------------- tensorflow/models/rnn/seq2seq.py | 751 +---------------- tensorflow/python/BUILD | 5 + .../python/framework/gen_docs_combined.py | 4 +- tensorflow/python/framework/ops.py | 56 +- tensorflow/python/framework/tensor_util.py | 8 +- .../python/framework/tensor_util_test.py | 18 +- .../kernel_tests/batch_matmul_op_test.py | 13 +- .../python/kernel_tests/bias_op_test.py | 8 +- .../python/kernel_tests/cast_op_test.py | 4 +- .../python/kernel_tests/concat_op_test.py | 57 ++ .../kernel_tests/control_flow_ops_py_test.py | 19 +- .../python/kernel_tests/conv_ops_test.py | 10 +- .../python/kernel_tests/cwise_ops_test.py | 65 +- .../python/kernel_tests/embedding_ops_test.py | 23 +- .../python/kernel_tests/gradient_checker.py | 83 +- .../kernel_tests/gradient_checker_test.py | 26 +- .../python/kernel_tests/linalg_grad_test.py | 19 +- .../kernel_tests}/linear_test.py | 10 +- tensorflow/python/kernel_tests/lrn_op_test.py | 5 +- .../python/kernel_tests/matmul_op_test.py | 10 +- .../python/kernel_tests/pack_op_test.py | 4 +- tensorflow/python/kernel_tests/pad_op_test.py | 8 +- .../python/kernel_tests/pooling_ops_test.py | 10 +- .../python/kernel_tests/reduction_ops_test.py | 134 ++- .../python/kernel_tests/relu_op_test.py | 40 +- .../python/kernel_tests/reshape_op_test.py | 9 +- .../kernel_tests/reverse_sequence_op_test.py | 60 +- .../kernel_tests}/rnn_cell_test.py | 6 +- .../rnn => python/kernel_tests}/rnn_test.py | 68 +- .../segment_reduction_ops_test.py | 34 +- .../kernel_tests}/seq2seq_test.py | 167 ++-- .../python/kernel_tests/shape_ops_test.py | 17 +- .../python/kernel_tests/softplus_op_test.py | 8 +- .../python/kernel_tests/softsign_op_test.py | 68 ++ .../kernel_tests/sparse_matmul_op_test.py | 8 +- .../python/kernel_tests/transpose_op_test.py | 10 +- .../python/kernel_tests/unpack_op_test.py | 5 +- .../python/kernel_tests/xent_op_test.py | 4 +- tensorflow/python/lib/io/py_record_writer.cc | 2 +- tensorflow/python/lib/io/py_record_writer.h | 2 +- tensorflow/python/ops/array_grad.py | 98 ++- tensorflow/python/ops/array_ops.py | 13 +- tensorflow/python/ops/constant_op.py | 26 +- tensorflow/python/ops/control_flow_grad.py | 6 +- tensorflow/python/ops/control_flow_ops.py | 119 +-- tensorflow/python/ops/gradients.py | 4 +- tensorflow/python/ops/image_grad_test.py | 21 +- tensorflow/python/ops/image_ops.py | 72 +- tensorflow/python/ops/image_ops_test.py | 42 + tensorflow/python/ops/nn.py | 122 +-- tensorflow/python/ops/nn_grad.py | 5 + tensorflow/python/ops/nn_ops.py | 83 ++ tensorflow/python/ops/nn_test.py | 330 ++++---- tensorflow/python/ops/op_def_library.py | 19 +- tensorflow/python/ops/rnn.py | 150 ++++ tensorflow/python/ops/rnn_cell.py | 685 +++++++++++++++ tensorflow/python/ops/seq2seq.py | 784 ++++++++++++++++++ tensorflow/python/ops/variables.py | 57 +- tensorflow/python/platform/default/_flags.py | 10 +- tensorflow/python/platform/default/_gfile.py | 22 + .../python/platform/default/flags_test.py | 43 +- tensorflow/python/platform/test.py | 5 + .../python/summary/event_multiplexer.py | 35 +- .../python/summary/event_multiplexer_test.py | 92 +- tensorflow/python/training/adagrad_test.py | 22 + tensorflow/python/training/adam.py | 8 +- tensorflow/python/training/adam_test.py | 36 + .../python/training/gradient_descent_test.py | 19 + tensorflow/python/training/input.py | 14 + tensorflow/python/training/input_test.py | 22 + tensorflow/python/training/momentum_test.py | 51 ++ tensorflow/python/training/optimizer.py | 9 +- tensorflow/python/training/optimizer_test.py | 20 + tensorflow/python/training/saver.py | 4 +- tensorflow/tensorboard/README.md | 47 +- tensorflow/tensorboard/app/index.html | 11 +- .../tensorboard/app/tf-tensorboard-demo.html | 72 -- tensorflow/tensorboard/bower.json | 2 + .../tensorboard/components/imports/README.md | 6 + .../tensorboard/components/imports/dagre.html | 2 + .../components/imports/graphlib.html | 1 + .../components/imports/local-imports/d3.html | 1 + .../imports/local-imports/dagre.html | 4 + .../imports/local-imports/graphlib.html | 1 + .../imports/local-imports/lodash.html | 1 + .../imports/local-imports/plottable.html | 3 + .../{ => components}/test/index.html | 7 +- .../tf-graph-common/test/index.html | 2 +- .../tf-graph-common/tf-graph-common.html | 8 +- .../tf-graph-loader/test/index.html | 2 +- .../tf-tensorboard}/demo/data/cos.json | 0 .../tf-tensorboard}/demo/data/cubic.json | 0 .../tf-tensorboard}/demo/data/linear.json | 0 .../demo/data/poly5-graph.pbtxt | 0 .../tf-tensorboard}/demo/data/poly5.json | 0 .../tf-tensorboard}/demo/data/runs.json | 0 .../tf-tensorboard}/demo/data/sin-graph.pbtxt | 0 .../tf-tensorboard}/demo/data/sin.json | 0 .../tf-tensorboard}/demo/data/sq.json | 0 .../tf-tensorboard}/demo/index.html | 4 +- .../tf-tensorboard}/tf-tensorboard.html | 18 +- tensorflow/tensorboard/gulpfile.js | 26 +- tensorflow/tensorboard/package.json | 4 +- tensorflow/tensorboard/scripts/tfserve.js | 79 ++ tensorflow/tensorboard/tensorboard.py | 24 +- tensorflow/tensorboard/tensorboard_handler.py | 7 +- 162 files changed, 4658 insertions(+), 3023 deletions(-) delete mode 100644 tensorflow/core/kernels/reference_gemm.h create mode 100644 tensorflow/core/kernels/softsign_op.cc create mode 100644 tensorflow/core/kernels/softsign_op.h create mode 100644 tensorflow/core/kernels/softsign_op_gpu.cu.cc create mode 100644 tensorflow/g3doc/how_tos/summaries_and_tensorboard/mnist_with_summaries.py rename tensorflow/{models/rnn => python/kernel_tests}/linear_test.py (89%) rename tensorflow/{models/rnn => python/kernel_tests}/rnn_cell_test.py (97%) rename tensorflow/{models/rnn => python/kernel_tests}/rnn_test.py (91%) rename tensorflow/{models/rnn => python/kernel_tests}/seq2seq_test.py (74%) create mode 100644 tensorflow/python/kernel_tests/softsign_op_test.py create mode 100644 tensorflow/python/ops/rnn.py create mode 100644 tensorflow/python/ops/rnn_cell.py create mode 100644 tensorflow/python/ops/seq2seq.py delete mode 100644 tensorflow/tensorboard/app/tf-tensorboard-demo.html create mode 100644 tensorflow/tensorboard/components/imports/README.md create mode 100644 tensorflow/tensorboard/components/imports/dagre.html create mode 100644 tensorflow/tensorboard/components/imports/graphlib.html create mode 100644 tensorflow/tensorboard/components/imports/local-imports/d3.html create mode 100644 tensorflow/tensorboard/components/imports/local-imports/dagre.html create mode 100644 tensorflow/tensorboard/components/imports/local-imports/graphlib.html create mode 100644 tensorflow/tensorboard/components/imports/local-imports/lodash.html create mode 100644 tensorflow/tensorboard/components/imports/local-imports/plottable.html rename tensorflow/tensorboard/{ => components}/test/index.html (56%) rename tensorflow/tensorboard/{app => components/tf-tensorboard}/demo/data/cos.json (100%) rename tensorflow/tensorboard/{app => components/tf-tensorboard}/demo/data/cubic.json (100%) rename tensorflow/tensorboard/{app => components/tf-tensorboard}/demo/data/linear.json (100%) rename tensorflow/tensorboard/{app => components/tf-tensorboard}/demo/data/poly5-graph.pbtxt (100%) rename tensorflow/tensorboard/{app => components/tf-tensorboard}/demo/data/poly5.json (100%) rename tensorflow/tensorboard/{app => components/tf-tensorboard}/demo/data/runs.json (100%) rename tensorflow/tensorboard/{app => components/tf-tensorboard}/demo/data/sin-graph.pbtxt (100%) rename tensorflow/tensorboard/{app => components/tf-tensorboard}/demo/data/sin.json (100%) rename tensorflow/tensorboard/{app => components/tf-tensorboard}/demo/data/sq.json (100%) rename tensorflow/tensorboard/{app => components/tf-tensorboard}/demo/index.html (79%) rename tensorflow/tensorboard/{app => components/tf-tensorboard}/tf-tensorboard.html (82%) create mode 100644 tensorflow/tensorboard/scripts/tfserve.js diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 5e321686531..0e512772849 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -293,14 +293,13 @@ cc_library( ], ) -# TODO(opensource): Make it work externally tf_proto_library( name = "protos_all", srcs = glob(["**/*.proto"]), cc_api_version = 2, go_api_version = 2, java_api_version = 2, - py_api_version = 2, # TODO(irving): Handle 3 + py_api_version = 2, visibility = ["//visibility:public"], ) @@ -507,7 +506,6 @@ filegroup( "kernels/maxpooling_op.h", "kernels/pooling_ops_common.h", "kernels/pooling_ops_common.cc", - "kernels/reference_gemm.h", ], exclude = [ "**/*test.cc", @@ -571,7 +569,6 @@ filegroup( "//tensorflow/core:kernels/no_op.cc", "//tensorflow/core:kernels/no_op.h", "//tensorflow/core:kernels/pack_op.cc", - "//tensorflow/core:kernels/reference_gemm.h", "//tensorflow/core:kernels/reshape_op.cc", "//tensorflow/core:kernels/reshape_op.h", "//tensorflow/core:kernels/reverse_sequence_op.cc", @@ -628,6 +625,8 @@ filegroup( "//tensorflow/core:kernels/relu_op.h", "//tensorflow/core:kernels/softplus_op.cc", "//tensorflow/core:kernels/softplus_op.h", + "//tensorflow/core:kernels/softsign_op.cc", + "//tensorflow/core:kernels/softsign_op.h", "//tensorflow/core:kernels/stack_ops.cc", "//tensorflow/core:kernels/transpose_op.cc", "//tensorflow/core:kernels/transpose_op.h", diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 3178e91f617..63d49b50108 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -758,7 +758,11 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) { // Ask the device to fill in the device context map. Device* device = impl_->params_.device; - device->FillContextMap(graph, &device_context_map_); + Status fill_status = device->FillContextMap(graph, &device_context_map_); + if (!fill_status.ok()) { + done(fill_status); + return; + } // Initialize the ready queue. for (const Node* n : graph->nodes()) { @@ -1077,7 +1081,7 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, for (int i = 0; i < node->num_outputs(); ++i) { TensorValue val = ctx->release_output(i); - // Only Switch and Recv nodes can generate new dead outputs + // Only Switch and Recv can generate new dead outputs. if (*ctx->is_output_dead() || val.tensor == nullptr) { DCHECK(IsSwitch(node) || IsRecv(node)); } else { diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc index 1821289f4b6..32109157aa5 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc @@ -40,13 +40,13 @@ EventMgr::~EventMgr() { delete e; } while (!used_events_.empty()) { - delete used_events_[0].event; - delete used_events_[0].mem; - if (used_events_[0].bufrec.buf) { - used_events_[0].bufrec.alloc->DeallocateRaw(used_events_[0].bufrec.buf); + InUse* ue = &used_events_[0]; + delete ue->event; + delete ue->mem; + if (ue->bufrec.buf) { + ue->bufrec.alloc->DeallocateRaw(ue->bufrec.buf); } - if (used_events_[0].func != nullptr) - threadpool_.Schedule(used_events_[0].func); + if (ue->func != nullptr) threadpool_.Schedule(ue->func); used_events_.pop_front(); } } @@ -60,15 +60,17 @@ EventMgr::~EventMgr() { void EventMgr::PollLoop() { while (!stop_polling_.HasBeenNotified()) { Env::Default()->SleepForMicroseconds(1 * 1000); + ToFreeVector to_free; { mutex_lock l(mu_); - PollEvents(true); + PollEvents(true, &to_free); } + FreeMemory(to_free); } polling_stopped_.Notify(); } -void EventMgr::QueueInUse(gpu::Stream* stream, InUse iu) { +void EventMgr::QueueInUse(gpu::Stream* stream, InUse iu, gpu::Event** e) { VLOG(2) << "QueueInUse free_events_ " << free_events_.size() << " used_events_ " << used_events_.size(); // Events are created on demand, and repeatedly reused. There is no @@ -77,10 +79,9 @@ void EventMgr::QueueInUse(gpu::Stream* stream, InUse iu) { free_events_.push_back(new gpu::Event(exec_)); free_events_.back()->Init(); } - gpu::Event* e = free_events_.back(); + *e = free_events_.back(); free_events_.pop_back(); - stream->ThenRecordEvent(e); - iu.event = e; + iu.event = *e; used_events_.push_back(iu); } @@ -103,7 +104,8 @@ void EventMgr::QueueInUse(gpu::Stream* stream, InUse iu) { // GPU memory use to spike needlessly. An alternative strategy would // be to throttle new Op execution until the pending event queue // clears. -void EventMgr::PollEvents(bool is_dedicated_poller) { +void EventMgr::PollEvents(bool is_dedicated_poller, + gtl::InlinedVector* to_free) { VLOG(2) << "PollEvents free_events_ " << free_events_.size() << " used_events_ " << used_events_.size(); // Sweep the remaining events in order. If this is the dedicated @@ -123,11 +125,9 @@ void EventMgr::PollEvents(bool is_dedicated_poller) { if (!is_dedicated_poller) return; // quit processing queue break; case gpu::Event::Status::kComplete: - delete iu.mem; - if (iu.bufrec.buf) iu.bufrec.alloc->DeallocateRaw(iu.bufrec.buf); - // The function must be called in another thread, outside of - // the mutex held here. - if (iu.func != nullptr) threadpool_.Schedule(iu.func); + // Make a copy of the InUse record so we can free it after releasing + // the lock + to_free->push_back(iu); free_events_.push_back(iu.event); // Mark this InUse record as completed. iu.event = nullptr; diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h index 5fe9fd782db..f2a1ea26031 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h +++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h @@ -18,8 +18,10 @@ limitations under the License. #include #include +#include "tensorflow/stream_executor/stream.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/platform/port.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/public/tensor.h" @@ -47,9 +49,15 @@ class EventMgr { // currently enqueued on *stream have completed. inline void ThenDeleteTensors(perftools::gputools::Stream* stream, std::vector* tensors) { - mutex_lock l(mu_); - QueueTensors(stream, tensors); - PollEvents(false); + ToFreeVector to_free; + ::perftools::gputools::Event* e; + { + mutex_lock l(mu_); + QueueTensors(stream, tensors, &e); + PollEvents(false, &to_free); + } + stream->ThenRecordEvent(e); + FreeMemory(to_free); } struct BufRec { @@ -61,16 +69,28 @@ class EventMgr { // on it as soon as all events currently enqueued on *stream have completed. inline void ThenDeleteBuffer(perftools::gputools::Stream* stream, BufRec bufrec) { - mutex_lock l(mu_); - QueueBuffer(stream, bufrec); - PollEvents(false); + ToFreeVector to_free; + ::perftools::gputools::Event* e; + { + mutex_lock l(mu_); + QueueBuffer(stream, bufrec, &e); + PollEvents(false, &to_free); + } + stream->ThenRecordEvent(e); + FreeMemory(to_free); } inline void ThenExecute(perftools::gputools::Stream* stream, std::function func) { - mutex_lock l(mu_); - QueueFunc(stream, func); - PollEvents(false); + ToFreeVector to_free; + ::perftools::gputools::Event* e; + { + mutex_lock l(mu_); + QueueFunc(stream, func, &e); + PollEvents(false, &to_free); + } + stream->ThenRecordEvent(e); + FreeMemory(to_free); } private: @@ -85,32 +105,50 @@ class EventMgr { std::function func; }; + typedef gtl::InlinedVector ToFreeVector; + + void FreeMemory(const ToFreeVector& to_free) { + for (const auto& iu : to_free) { + delete iu.mem; + if (iu.bufrec.buf) iu.bufrec.alloc->DeallocateRaw(iu.bufrec.buf); + // The function must be called in another thread. + if (iu.func != nullptr) threadpool_.Schedule(iu.func); + } + } + // Stream-enqueue an unused Event and save with it a collection of // Tensors and/or a BufRec to be deleted only after the Event // records. - void QueueInUse(perftools::gputools::Stream* stream, InUse in_use) + void QueueInUse(perftools::gputools::Stream* stream, InUse in_use, + ::perftools::gputools::Event** e) EXCLUSIVE_LOCKS_REQUIRED(mu_); void QueueTensors(perftools::gputools::Stream* stream, - std::vector* tensors) + std::vector* tensors, + ::perftools::gputools::Event** e) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - QueueInUse(stream, {nullptr, tensors, BufRec(), nullptr}); + QueueInUse(stream, {nullptr, tensors, BufRec(), nullptr}, e); } - void QueueBuffer(perftools::gputools::Stream* stream, BufRec bufrec) + void QueueBuffer(perftools::gputools::Stream* stream, BufRec bufrec, + ::perftools::gputools::Event** e) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - QueueInUse(stream, {nullptr, nullptr, bufrec, nullptr}); + QueueInUse(stream, {nullptr, nullptr, bufrec, nullptr}, e); } void QueueFunc(perftools::gputools::Stream* stream, - std::function func) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - QueueInUse(stream, {nullptr, nullptr, BufRec(), func}); + std::function func, ::perftools::gputools::Event** e) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + QueueInUse(stream, {nullptr, nullptr, BufRec(), func}, e); } // This function should be called at roughly the same tempo as // QueueTensors() to check whether pending events have recorded, - // and then retire them. - void PollEvents(bool is_dedicated_poller) EXCLUSIVE_LOCKS_REQUIRED(mu_); + // and then retire them. It appends InUse elements that need cleanup + // to "*to_free". The caller should call FreeMemory(to_free) + // when this returns. + void PollEvents(bool is_dedicated_poller, ToFreeVector* to_free) + EXCLUSIVE_LOCKS_REQUIRED(mu_); // An internal polling loop that runs at a low frequency to clear // straggler Events. diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc index 6956ead643e..c6893c91e7e 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc @@ -42,13 +42,21 @@ class TEST_EventMgrHelper { void QueueTensors(perftools::gputools::Stream* stream, std::vector* tensors) { - mutex_lock l(em_->mu_); - em_->QueueTensors(stream, tensors); + ::perftools::gputools::Event* e; + { + mutex_lock l(em_->mu_); + em_->QueueTensors(stream, tensors, &e); + } + stream->ThenRecordEvent(e); } void PollEvents(bool is_dedicated_poller) { - mutex_lock l(em_->mu_); - em_->PollEvents(is_dedicated_poller); + EventMgr::ToFreeVector to_free; + { + mutex_lock l(em_->mu_); + em_->PollEvents(is_dedicated_poller, &to_free); + } + em_->FreeMemory(to_free); } private: diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index 29f3cf483ac..66f181d4cf0 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -119,9 +119,10 @@ class DeviceBase { // "event_mgr" is used to delay deallocation of temporary GPU buffers. // TODO(pbar) Work out how to move this out of DeviceBase. struct GpuDeviceInfo { - perftools::gputools::Stream* stream; - DeviceContext* default_context; - EventMgr* event_mgr; + // Make sure all the defaults are NULL, so we can spot missing assignments. + perftools::gputools::Stream* stream = nullptr; + DeviceContext* default_context = nullptr; + EventMgr* event_mgr = nullptr; }; // Does not take ownership. diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc index 960c6535938..8d5ed3c2fe4 100644 --- a/tensorflow/core/kernels/cast_op.cc +++ b/tensorflow/core/kernels/cast_op.cc @@ -55,6 +55,24 @@ struct CastFunctor { } // namespace functor +#define CURRY_TYPES2(FN, arg0) \ + FN(arg0, bool); \ + FN(arg0, uint8); \ + FN(arg0, int16); \ + FN(arg0, int32); \ + FN(arg0, int64); \ + FN(arg0, float); \ + FN(arg0, double) + +#define CURRY_TYPES3(FN, arg0, arg1) \ + FN(arg0, arg1, bool); \ + FN(arg0, arg1, uint8); \ + FN(arg0, arg1, int16); \ + FN(arg0, arg1, int32); \ + FN(arg0, arg1, int64); \ + FN(arg0, arg1, float); \ + FN(arg0, arg1, double) + #define CAST_CASE(DEVICE, IN, OUT) \ if (DataTypeToEnum::value == src_dtype_ && \ DataTypeToEnum::value == dst_dtype_) { \ @@ -110,27 +128,14 @@ class CpuCastOp : public CastOpBase { work_ = nullptr; // Identity return Status::OK(); } - CAST_CASE(CPUDevice, bool, float); - CAST_CASE(CPUDevice, bool, int32); - CAST_CASE(CPUDevice, bool, double); - CAST_CASE(CPUDevice, double, float); - CAST_CASE(CPUDevice, double, int32); - CAST_CASE(CPUDevice, double, int64); - CAST_CASE(CPUDevice, float, double); - CAST_CASE(CPUDevice, float, uint8); - CAST_CASE(CPUDevice, float, int32); - CAST_CASE(CPUDevice, float, int64); - CAST_CASE(CPUDevice, int32, double); - CAST_CASE(CPUDevice, int32, float); - CAST_CASE(CPUDevice, int32, uint8); - CAST_CASE(CPUDevice, int32, int64); - CAST_CASE(CPUDevice, int64, double); - CAST_CASE(CPUDevice, int64, float); - CAST_CASE(CPUDevice, int64, int32); - CAST_CASE(CPUDevice, uint8, float); - CAST_CASE(CPUDevice, uint8, int32); - CAST_CASE(CPUDevice, uint8, int64); - CAST_CASE(CPUDevice, uint8, double); + CURRY_TYPES3(CAST_CASE, CPUDevice, bool); + CURRY_TYPES3(CAST_CASE, CPUDevice, uint8); + CURRY_TYPES3(CAST_CASE, CPUDevice, int16); + CURRY_TYPES3(CAST_CASE, CPUDevice, int32); + CURRY_TYPES3(CAST_CASE, CPUDevice, int64); + CURRY_TYPES3(CAST_CASE, CPUDevice, float); + CURRY_TYPES3(CAST_CASE, CPUDevice, double); + if (src_dtype_ == DT_BFLOAT16 && dst_dtype_ == DT_FLOAT) { work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { int64 N = out->NumElements(); @@ -185,24 +190,15 @@ class GpuCastOp : public CastOpBase { work_ = nullptr; // Identity return Status::OK(); } - CAST_CASE(GPUDevice, bfloat16, float); - CAST_CASE(GPUDevice, bool, float); - CAST_CASE(GPUDevice, double, float); - CAST_CASE(GPUDevice, double, int64); + CURRY_TYPES3(CAST_CASE, GPUDevice, bool); + CURRY_TYPES3(CAST_CASE, GPUDevice, uint8); + CURRY_TYPES3(CAST_CASE, GPUDevice, int16); + CURRY_TYPES3(CAST_CASE, GPUDevice, int32); + CURRY_TYPES3(CAST_CASE, GPUDevice, int64); + CURRY_TYPES3(CAST_CASE, GPUDevice, float); + CURRY_TYPES3(CAST_CASE, GPUDevice, double); CAST_CASE(GPUDevice, float, bfloat16); - CAST_CASE(GPUDevice, float, double); - CAST_CASE(GPUDevice, float, int64); - CAST_CASE(GPUDevice, int64, double); - CAST_CASE(GPUDevice, int64, float); - CAST_CASE(GPUDevice, uint8, float); - CAST_CASE(GPUDevice, float, uint8); - CAST_CASE(GPUDevice, bool, int32); - CAST_CASE(GPUDevice, double, int32); - CAST_CASE(GPUDevice, float, int32); - CAST_CASE(GPUDevice, int32, double); - CAST_CASE(GPUDevice, int32, float); - CAST_CASE(GPUDevice, int32, int64); - CAST_CASE(GPUDevice, int64, int32); + CAST_CASE(GPUDevice, bfloat16, float); return Unimplemented(); } }; @@ -217,28 +213,24 @@ REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp); .TypeConstraint("SrcT") \ .TypeConstraint("DstT") \ .Device(DEVICE_GPU), \ - GpuCastOp); -REGISTER_CAST_GPU(bfloat16, float); -REGISTER_CAST_GPU(bool, float); -REGISTER_CAST_GPU(double, float); -REGISTER_CAST_GPU(double, int64); + GpuCastOp) + +CURRY_TYPES2(REGISTER_CAST_GPU, bool); +CURRY_TYPES2(REGISTER_CAST_GPU, uint8); +CURRY_TYPES2(REGISTER_CAST_GPU, int16); +CURRY_TYPES2(REGISTER_CAST_GPU, int32); +CURRY_TYPES2(REGISTER_CAST_GPU, int64); +CURRY_TYPES2(REGISTER_CAST_GPU, float); +CURRY_TYPES2(REGISTER_CAST_GPU, double); REGISTER_CAST_GPU(float, bfloat16); -REGISTER_CAST_GPU(float, double); -REGISTER_CAST_GPU(float, int64); -REGISTER_CAST_GPU(int64, double); -REGISTER_CAST_GPU(int64, float); -REGISTER_CAST_GPU(uint8, float); -REGISTER_CAST_GPU(float, uint8); -REGISTER_CAST_GPU(bool, int32); -REGISTER_CAST_GPU(double, int32); -REGISTER_CAST_GPU(float, int32); -REGISTER_CAST_GPU(int32, double); -REGISTER_CAST_GPU(int32, float); -REGISTER_CAST_GPU(int32, int64); -REGISTER_CAST_GPU(int64, int32); +REGISTER_CAST_GPU(bfloat16, float); + #undef REGISTER_CAST_GPU #endif // GOOGLE_CUDA +#undef CURRY_TYPES2 +#undef CURRY_TYPES3 + // HostCast differs from Cast in that its input and output are in host memory. REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp); REGISTER_KERNEL_BUILDER( diff --git a/tensorflow/core/kernels/cast_op_gpu.cu.cc b/tensorflow/core/kernels/cast_op_gpu.cu.cc index 43f8cd90edc..57f08736211 100644 --- a/tensorflow/core/kernels/cast_op_gpu.cu.cc +++ b/tensorflow/core/kernels/cast_op_gpu.cu.cc @@ -33,25 +33,27 @@ struct CastFunctor { } }; -#define DEFINE(O, I) template struct CastFunctor; -DEFINE(float, double); -DEFINE(float, int32); -DEFINE(float, int64); -DEFINE(double, float); -DEFINE(double, int32); -DEFINE(double, int64); -DEFINE(int32, float); -DEFINE(int32, double); -DEFINE(int32, int64); -DEFINE(int64, float); -DEFINE(int64, double); -DEFINE(int64, int32); -DEFINE(int32, bool); -DEFINE(float, bool); -DEFINE(float, uint8); -DEFINE(uint8, float); -DEFINE(float, bfloat16); +#define DEFINE(O, I) template struct CastFunctor +#define DEFINE_ALL_FROM(in_type) \ + DEFINE(in_type, bool); \ + DEFINE(in_type, uint8); \ + DEFINE(in_type, int16); \ + DEFINE(in_type, int32); \ + DEFINE(in_type, int64); \ + DEFINE(in_type, float); \ + DEFINE(in_type, double) + +DEFINE_ALL_FROM(bool); +DEFINE_ALL_FROM(uint8); +DEFINE_ALL_FROM(int16); +DEFINE_ALL_FROM(int32); +DEFINE_ALL_FROM(int64); +DEFINE_ALL_FROM(float); +DEFINE_ALL_FROM(double); DEFINE(bfloat16, float); +DEFINE(float, bfloat16); + +#undef DEFINE_ALL_FROM #undef DEFINE } // end namespace functor diff --git a/tensorflow/core/kernels/cast_op_test.cc b/tensorflow/core/kernels/cast_op_test.cc index b93c0857db9..168914f5539 100644 --- a/tensorflow/core/kernels/cast_op_test.cc +++ b/tensorflow/core/kernels/cast_op_test.cc @@ -41,22 +41,48 @@ class CastOpTest : public OpsTestBase { void MakeOp(DataType src, DataType dst) { RequireDefaultOps(); EXPECT_OK(NodeDefBuilder("cast_op", "Cast") - .Input(FakeInput(DT_INT32)) + .Input(FakeInput(src)) .Attr("SrcT", src) .Attr("DstT", dst) .Finalize(node_def())); EXPECT_OK(InitOp()); } + + template + void CheckCast() { + DataType in_type = DataTypeToEnum::v(); + DataType out_type = DataTypeToEnum::v(); + MakeOp(in_type, out_type); + AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); + ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), out_type, TensorShape({1, 2, 2, 1})); + test::FillValues(&expected, {1, 2, 3, 4}); + test::ExpectTensorEqual(expected, *GetOutput(0)); + } }; -TEST_F(CastOpTest, Int32ToUint8) { - MakeOp(DT_INT32, DT_UINT8); - AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); - ASSERT_OK(RunOpKernel()); - Tensor expected(allocator(), DT_UINT8, TensorShape({1, 2, 2, 1})); - test::FillValues(&expected, {1, 2, 3, 4}); - test::ExpectTensorEqual(expected, *GetOutput(0)); -} +#define TEST_CAST(in, out) \ + TEST_F(CastOpTest, TestCast##_##in##_##out) { CheckCast(); } + +#define TEST_ALL_CASTS_FROM(in) \ + TEST_CAST(in, uint8); \ + TEST_CAST(in, int16); \ + TEST_CAST(in, int32); \ + TEST_CAST(in, int64); \ + TEST_CAST(in, float); \ + TEST_CAST(in, double) + +TEST_ALL_CASTS_FROM(uint8) +TEST_ALL_CASTS_FROM(int16) +TEST_ALL_CASTS_FROM(int32) +TEST_ALL_CASTS_FROM(int64) +TEST_ALL_CASTS_FROM(float) +TEST_ALL_CASTS_FROM(double) + +#undef TEST_ALL_CASTS_FROM +#undef TEST_CAST + +// TODO(wicke): check conversions from/to bool, and bfloat16 static void BM_cpu_float_int64(int iters, int num) { testing::ItemsProcessed(static_cast(iters) * num); diff --git a/tensorflow/core/kernels/concat_op_gpu.cu.cc b/tensorflow/core/kernels/concat_op_gpu.cu.cc index 581171c6bae..084ca9a7643 100644 --- a/tensorflow/core/kernels/concat_op_gpu.cu.cc +++ b/tensorflow/core/kernels/concat_op_gpu.cu.cc @@ -34,10 +34,12 @@ void ConcatGPU(const GPUDevice& d, const std::vector< std::unique_ptr::ConstMatrix>>& inputs, typename TTypes::Matrix* output) { - Eigen::array offset(0, 0); + Eigen::array offset{0, 0}; for (int i = 0; i < inputs.size(); ++i) { - Eigen::array size = inputs[i]->dimensions(); - output->slice(offset, size).device(d) = *inputs[i]; + Eigen::array size; + size[0] = inputs[i]->dimension(0); + size[1] = inputs[i]->dimension(1); + To32Bit(*output).slice(offset, size).device(d) = To32Bit(*inputs[i]); offset[1] += size[1]; } } diff --git a/tensorflow/core/kernels/constant_op_gpu.cu.cc b/tensorflow/core/kernels/constant_op_gpu.cu.cc index 5991391850a..bbb7a0ee284 100644 --- a/tensorflow/core/kernels/constant_op_gpu.cu.cc +++ b/tensorflow/core/kernels/constant_op_gpu.cu.cc @@ -73,7 +73,7 @@ struct FillFunctor { void operator()(const GPUDevice& d, typename TTypes::Flat out, typename TTypes::ConstScalar in) { Eigen::internal::scalar_const_op f(in.data()); - out.device(d) = out.nullaryExpr(f); + To32Bit(out).device(d) = To32Bit(out).nullaryExpr(f); } }; @@ -91,7 +91,7 @@ DEFINE_FILL_GPU(int64); template struct SetZeroFunctor { void operator()(const GPUDevice& d, typename TTypes::Flat out) { - out.device(d) = out.constant(0); + To32Bit(out).device(d) = To32Bit(out).constant(0); } }; diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc index dae06f4bfc7..8bd13b4be3d 100644 --- a/tensorflow/core/kernels/conv_grad_ops.cc +++ b/tensorflow/core/kernels/conv_grad_ops.cc @@ -242,13 +242,13 @@ typedef Eigen::GpuDevice GPUDevice; const auto expanded_out_cols = (output_cols - 1) * stride + 1; \ const auto padded_out_rows = input_rows + filter_rows - 1; \ const auto padded_out_cols = input_cols + filter_cols - 1; \ - const auto top_pad_rows = filter_rows - 1 - pad_rows; \ - const auto left_pad_cols = filter_cols - 1 - pad_cols; \ - const auto bottom_pad_rows = \ + const int top_pad_rows = filter_rows - 1 - pad_rows; \ + const int left_pad_cols = filter_cols - 1 - pad_cols; \ + const int bottom_pad_rows = \ padded_out_rows - expanded_out_rows - top_pad_rows; \ - const auto right_pad_cols = \ + const int right_pad_cols = \ padded_out_cols - expanded_out_cols - left_pad_cols; \ - Eigen::DSizes strides{1, stride, stride, 1}; \ + Eigen::DSizes strides{1, stride, stride, 1}; \ VLOG(2) << "Conv2d: " << label \ << ": expanded_out_rows = " << expanded_out_rows \ << ", expanded_out_cols = " << expanded_out_cols \ @@ -809,9 +809,11 @@ class Conv2DSlowBackpropInputOp : public OpKernel { context->allocate_output(0, input_shape, &in_backprop)); const int padding_rows = - (output_rows - 1) * stride + filter_rows - input_rows; + (padding_ == VALID) ? 0 : (output_rows - 1) * stride + filter_rows - + input_rows; const int padding_cols = - (output_cols - 1) * stride + filter_cols - input_cols; + (padding_ == VALID) ? 0 : (output_cols - 1) * stride + filter_cols - + input_cols; // TODO(keveman): cuDNN only supports equal padding on both sides, so only // calling it when that is true. Remove this check when (if?) cuDNN starts @@ -954,16 +956,17 @@ class Conv2DSlowBackpropInputOp : public OpKernel { context->allocate_temp(DataTypeToEnum::v(), padded_out_shape, &padded_output)); - Eigen::DSizes trivial_order{0, 1, 2, 3}; - Eigen::array, 4> pad_dims{ + Eigen::DSizes trivial_order{0, 1, 2, 3}; + Eigen::array, 4> pad_dims{ {{0, 0}, {top_pad_rows, bottom_pad_rows}, {left_pad_cols, right_pad_cols}, {0, 0}}}; - functor::InflatePadAndShuffle()( - context->eigen_device(), out_backprop.tensor(), strides, - pad_dims, trivial_order, padded_output.tensor()); + functor::InflatePadAndShuffle()( + context->eigen_device(), To32Bit(out_backprop.tensor()), + strides, pad_dims, trivial_order, + To32Bit(padded_output.tensor())); const Tensor& padded_output_cref = padded_output; // We then need to fill a new "reverted" filter @@ -976,11 +979,11 @@ class Conv2DSlowBackpropInputOp : public OpKernel { context->allocate_temp(DataTypeToEnum::v(), r_filter_shape, &r_filter)); - Eigen::DSizes filter_order{0, 1, 3, 2}; + Eigen::DSizes filter_order{0, 1, 3, 2}; Eigen::array filter_rev_dims{true, true, false, false}; - functor::ShuffleAndReverse()( - context->eigen_device(), filter.tensor(), filter_order, - filter_rev_dims, r_filter.tensor()); + functor::ShuffleAndReverse()( + context->eigen_device(), To32Bit(filter.tensor()), + filter_order, filter_rev_dims, To32Bit(r_filter.tensor())); const Tensor& r_filter_cref = r_filter; // Now we can call conv_2d directly. @@ -1039,20 +1042,22 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { context->allocate_output(0, filter_shape, &filter_backprop)); const int padding_rows = - (output_rows - 1) * stride + filter_rows - input_rows; + (padding_ == VALID) ? 0 : (output_rows - 1) * stride + filter_rows - + input_rows; const int padding_cols = - (output_cols - 1) * stride + filter_cols - input_cols; + (padding_ == VALID) ? 0 : (output_cols - 1) * stride + filter_cols - + input_cols; // TODO(zhengxq): cuDNN only supports equal padding on both sides, so only // calling it when that is true. Remove this check when (if?) cuDNN starts // supporting different padding. - bool padding_compatible = - (padding_rows % 2 == 0) && (padding_cols % 2 == 0); + bool rows_odd = (padding_rows % 2 != 0); + bool cols_odd = (padding_cols % 2 != 0); auto* stream = context->op_device_context()->stream(); OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - if (use_cudnn_ && padding_compatible) { + if (use_cudnn_) { if (filter_rows == 1 && filter_cols == 1 && stride == 1) { const uint64 m = in_depth; const uint64 k = batch * input_rows * input_cols; @@ -1089,10 +1094,31 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { return; } + Tensor compatible_input; + if (rows_odd || cols_odd) { + // If a padding dimension is odd, we have one more element on the right + // side or the bottom side. This is unsupported in cudnn. Therefore, + // we pad that extra element and make it compatible. + OP_REQUIRES_OK( + context, + context->allocate_temp( + DataTypeToEnum::value, + TensorShape({input.dim_size(0), input.dim_size(1) + rows_odd, + input.dim_size(2) + cols_odd, input.dim_size(3)}), + &compatible_input)); + + functor::PadInput()( + context->template eigen_device(), + To32Bit(input.tensor()), 0, rows_odd, 0, cols_odd, + To32Bit(compatible_input.tensor())); + } else { + compatible_input = input; + } + perftools::gputools::dnn::BatchDescriptor input_desc; input_desc.set_count(batch) - .set_height(input_rows) - .set_width(input_cols) + .set_height(compatible_input.dim_size(1)) + .set_width(compatible_input.dim_size(2)) .set_feature_map_count(in_depth) .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); perftools::gputools::dnn::BatchDescriptor output_desc; @@ -1146,14 +1172,19 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { transformed_out_backprop.tensor()); Tensor transformed_input; - OP_REQUIRES_OK(context, - context->allocate_temp( - DataTypeToEnum::value, - TensorShape({batch, in_depth, input_rows, input_cols}), - &transformed_input)); - functor::NHWCToNCHW()(context->eigen_device(), - input.tensor(), - transformed_input.tensor()); + OP_REQUIRES_OK( + context, + context->allocate_temp( + DataTypeToEnum::value, + TensorShape({ + compatible_input.dim_size(0), compatible_input.dim_size(3), + compatible_input.dim_size(1), compatible_input.dim_size(2), + }), + &transformed_input)); + functor::NHWCToNCHW()( + context->eigen_device(), + const_cast(compatible_input).tensor(), + transformed_input.tensor()); auto out_backprop_ptr = AsDeviceMemory(transformed_out_backprop.template flat().data(), @@ -1193,7 +1224,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { // [batch, out_rows, out_cols, out_depth] // And we need to change it to // [out_depth, out_rows, out_cols, batch] - Eigen::DSizes out_order{3, 1, 2, 0}; + Eigen::DSizes out_order{3, 1, 2, 0}; TensorShape padded_out_shape( {out_depth, padded_out_rows, padded_out_cols, batch}); Tensor padded_output; @@ -1201,14 +1232,14 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { context->allocate_temp(DataTypeToEnum::v(), padded_out_shape, &padded_output)); - Eigen::array, 4> pad_dims{ + Eigen::array, 4> pad_dims{ {{0, 0}, {top_pad_rows, bottom_pad_rows}, {left_pad_cols, right_pad_cols}, {0, 0}}}; - functor::InflatePadAndShuffle()( - context->eigen_device(), out_backprop.tensor(), strides, - pad_dims, out_order, padded_output.tensor()); + functor::InflatePadAndShuffle()( + context->eigen_device(), To32Bit(out_backprop.tensor()), + strides, pad_dims, out_order, To32Bit(padded_output.tensor())); const Tensor& padded_output_cref = padded_output; // For the backprop of the filter, we need to transpose the input. @@ -1216,7 +1247,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { // [batch, in_rows, in_cols, in_depth] // And we need to change it to // [in_rows, in_cols, batch, in_depth] - Eigen::DSizes in_order{1, 2, 0, 3}; + Eigen::DSizes in_order{1, 2, 0, 3}; TensorShape in_shuffle_shape({input_rows, input_cols, batch, in_depth}); Tensor in_shuffle; OP_REQUIRES_OK(context, @@ -1225,9 +1256,9 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { // No need for reversing this time. Eigen::array trivial_dims{false, false, false, false}; - functor::ShuffleAndReverse()( - context->eigen_device(), input.tensor(), in_order, - trivial_dims, in_shuffle.tensor()); + functor::ShuffleAndReverse()( + context->eigen_device(), To32Bit(input.tensor()), + in_order, trivial_dims, To32Bit(in_shuffle.tensor())); const Tensor& in_shuffle_cref = in_shuffle; // The output of the conv_2d would be @@ -1250,12 +1281,13 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { BrainPadding2EigenPadding(VALID)); // Now copy the filter_backprop back to the destination. - Eigen::DSizes filter_order{1, 2, 3, 0}; + Eigen::DSizes filter_order{1, 2, 3, 0}; Eigen::array filter_rev_dims{true, true, false, false}; const Tensor& filter_shuffle_cref = filter_shuffle; - functor::ShuffleAndReverse()( - context->eigen_device(), filter_shuffle_cref.tensor(), - filter_order, filter_rev_dims, filter_backprop->tensor()); + functor::ShuffleAndReverse()( + context->eigen_device(), + To32Bit(filter_shuffle_cref.tensor()), filter_order, + filter_rev_dims, To32Bit(filter_backprop->tensor())); } } @@ -1271,25 +1303,6 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { namespace functor { #define DECLARE_GPU_SPEC(T) \ template <> \ - void ShuffleAndReverse::operator()( \ - const GPUDevice& d, \ - typename TTypes::ConstTensor input, \ - const Eigen::DSizes& order, \ - const Eigen::array& reverse_dims, \ - typename TTypes::Tensor output); \ - extern template struct ShuffleAndReverse; \ - template <> \ - void InflatePadAndShuffle::operator()( \ - const GPUDevice& d, \ - typename TTypes::ConstTensor input, \ - const Eigen::DSizes& strides, \ - const Eigen::array, 4>& pad_dims, \ - const Eigen::DSizes& order, \ - typename TTypes::Tensor output); \ - extern template struct InflatePadAndShuffle; \ - template <> \ void ShuffleAndReverse::operator()( \ const GPUDevice& d, typename TTypes::ConstTensor input, \ const Eigen::DSizes& order, \ @@ -1328,7 +1341,13 @@ namespace functor { typename TTypes::ConstTensor filter, \ typename TTypes::ConstTensor output_backprop, int input_rows, \ int input_cols, int stride); \ - extern template struct SpatialConvolutionBackwardInput + extern template struct SpatialConvolutionBackwardInput; \ + template <> \ + void PadInput::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + int padding_rows_left, int padding_rows_right, int padding_cols_left, \ + int padding_cols_right, typename TTypes::Tensor out); \ + extern template struct PadInput; DECLARE_GPU_SPEC(float); #undef DECLARE_GPU_SPEC diff --git a/tensorflow/core/kernels/conv_ops_gpu.cu.cc b/tensorflow/core/kernels/conv_ops_gpu.cu.cc index 60ff6b00241..e4ee058406e 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu.cu.cc @@ -33,12 +33,8 @@ struct SpatialConvolution { typename TTypes::ConstTensor input, typename TTypes::ConstTensor filter, int stride, const Eigen::PaddingType& padding) { - // TODO(keveman): nvcc 6.5 crashes when 32 bit indexing is turned on. Enable - // this when we move to cuda 7.0. - // SpatialConvolutionFunc(d, To32Bit(output), To32Bit(input), - // To32Bit(filter), stride, padding); - - SpatialConvolutionFunc(d, output, input, filter, stride, padding); + SpatialConvolutionFunc(d, To32Bit(output), To32Bit(input), To32Bit(filter), + stride, padding); } }; diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc index bc2b62375f1..8fed594b258 100644 --- a/tensorflow/core/kernels/cwise_op_div.cc +++ b/tensorflow/core/kernels/cwise_op_div.cc @@ -16,21 +16,11 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER5(BinaryOp, CPU, "Div", functor::div, float, double, int32, int64, - complex64); +REGISTER7(BinaryOp, CPU, "Div", functor::div, float, double, uint8, int16, + int32, int64, complex64); #if GOOGLE_CUDA -REGISTER3(BinaryOp, GPU, "Div", functor::div, float, double, int64); +REGISTER6(BinaryOp, GPU, "Div", functor::div, float, double, uint8, int16, + int32, int64); #endif -// A special GPU kernel for int32. -// TODO(b/25387198): Also enable int32 in device memory. This kernel -// registration requires all int32 inputs and outputs to be in host memory. -REGISTER_KERNEL_BUILDER(Name("Div") - .Device(DEVICE_GPU) - .HostMemory("x") - .HostMemory("y") - .HostMemory("z") - .TypeConstraint("T"), - BinaryOp>); - } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc index 80a02da6512..a2809d54811 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc @@ -19,7 +19,7 @@ limitations under the License. namespace tensorflow { namespace functor { -DEFINE_BINARY3(div, float, double, int64); +DEFINE_BINARY6(div, float, double, uint8, int16, int32, int64); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc index a4ecaf185ac..068003b2945 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc @@ -19,7 +19,7 @@ limitations under the License. namespace tensorflow { namespace functor { -DEFINE_BINARY3(mul, float, double, int64); +DEFINE_BINARY7(mul, float, double, uint8, int8, int16, int32, int64); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_mul.cc b/tensorflow/core/kernels/cwise_op_mul.cc index a7b9859b193..42d50358e63 100644 --- a/tensorflow/core/kernels/cwise_op_mul.cc +++ b/tensorflow/core/kernels/cwise_op_mul.cc @@ -16,21 +16,11 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER7(BinaryOp, CPU, "Mul", functor::mul, float, double, int32, int64, int8, - int16, complex64); +REGISTER8(BinaryOp, CPU, "Mul", functor::mul, float, double, uint8, int8, int16, + int32, int64, complex64); #if GOOGLE_CUDA -REGISTER3(BinaryOp, GPU, "Mul", functor::mul, float, double, int64); +REGISTER7(BinaryOp, GPU, "Mul", functor::mul, float, double, uint8, int8, int16, + int32, int64); #endif -// A special GPU kernel for int32. -// TODO(b/25387198): Also enable int32 in device memory. This kernel -// registration requires all int32 inputs and outputs to be in host memory. -REGISTER_KERNEL_BUILDER(Name("Mul") - .Device(DEVICE_GPU) - .HostMemory("x") - .HostMemory("y") - .HostMemory("z") - .TypeConstraint("T"), - BinaryOp>); - } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h index 3296826d483..adf4203322a 100644 --- a/tensorflow/core/kernels/cwise_ops_common.h +++ b/tensorflow/core/kernels/cwise_ops_common.h @@ -379,6 +379,8 @@ struct SelectFunctor { #define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) REGISTER(OP, D, N, F, T0) #define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \ REGISTER(OP, D, N, F, T0) +#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \ + REGISTER(OP, D, N, F, T0) #else // !defined(__ANDROID__) #define REGISTER2(OP, D, N, F, T0, T1) \ REGISTER(OP, D, N, F, T0) \ @@ -398,6 +400,9 @@ struct SelectFunctor { #define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \ REGISTER4(OP, D, N, F, T0, T1, T2, T3) \ REGISTER3(OP, D, N, F, T4, T5, T6) +#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \ + REGISTER4(OP, D, N, F, T0, T1, T2, T3) \ + REGISTER4(OP, D, N, F, T4, T5, T6, T7) #endif // defined(__ANDROID__) } // end namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h index 966d3393b65..091c6717dc3 100644 --- a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h +++ b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h @@ -40,7 +40,7 @@ template struct UnaryFunctor { void operator()(const GPUDevice& d, typename Functor::tout_type out, typename Functor::tin_type in) { - out.device(d) = in.unaryExpr(typename Functor::func()); + To32Bit(out).device(d) = To32Bit(in).unaryExpr(typename Functor::func()); } }; @@ -50,7 +50,8 @@ struct BinaryFunctor { void operator()(const GPUDevice& d, typename Functor::tout_type out, typename Functor::tin_type in0, typename Functor::tin_type in1) { - out.device(d) = in0.binaryExpr(in1, typename Functor::func()); + To32Bit(out).device(d) = + To32Bit(in0).binaryExpr(in1, typename Functor::func()); } void Left(const GPUDevice& d, typename Functor::tout_type out, @@ -60,7 +61,7 @@ struct BinaryFunctor { typedef typename Functor::in_type Tin; typedef typename Functor::func Binary; typedef typename Eigen::internal::scalar_left Unary; - out.device(d) = in.unaryExpr(Unary(scalar.data())); + To32Bit(out).device(d) = To32Bit(in).unaryExpr(Unary(scalar.data())); } void Right(const GPUDevice& d, typename Functor::tout_type out, @@ -70,7 +71,7 @@ struct BinaryFunctor { typedef typename Functor::in_type Tin; typedef typename Functor::func Binary; typedef typename Eigen::internal::scalar_right Unary; - out.device(d) = in.unaryExpr(Unary(scalar.data())); + To32Bit(out).device(d) = To32Bit(in).unaryExpr(Unary(scalar.data())); } void BCast(const GPUDevice& d, @@ -86,16 +87,18 @@ struct BinaryFunctor { const bool bcast0_all_one = AllOne(bcast0); const bool bcast1_all_one = AllOne(bcast1); if (bcast0_all_one && !bcast1_all_one) { - out.device(d) = in0.binaryExpr(in1.broadcast(bcast1), func); + To32Bit(out).device(d) = + To32Bit(in0).binaryExpr(To32Bit(in1).broadcast(bcast1), func); return; } if (!bcast0_all_one && bcast1_all_one) { - out.device(d) = in0.broadcast(bcast0).binaryExpr(in1, func); + To32Bit(out).device(d) = + To32Bit(in0).broadcast(bcast0).binaryExpr(To32Bit(in1), func); return; } } - out.device(d) = - in0.broadcast(bcast0).binaryExpr(in1.broadcast(bcast1), func); + To32Bit(out).device(d) = To32Bit(in0).broadcast(bcast0).binaryExpr( + To32Bit(in1).broadcast(bcast1), func); } }; @@ -105,7 +108,8 @@ struct SelectFunctor { typename TTypes::ConstFlat cond_flat, typename TTypes::ConstFlat then_flat, typename TTypes::ConstFlat else_flat) { - out.device(d) = cond_flat.select(then_flat, else_flat); + To32Bit(out).device(d) = + To32Bit(cond_flat).select(To32Bit(then_flat), To32Bit(else_flat)); } }; @@ -143,6 +147,12 @@ struct SelectFunctor { #define DEFINE_BINARY5(F, T0, T1, T2, T3, T4) \ DEFINE_BINARY2(F, T0, T1); \ DEFINE_BINARY3(F, T2, T3, T4) +#define DEFINE_BINARY6(F, T0, T1, T2, T3, T4, T5) \ + DEFINE_BINARY3(F, T0, T1, T2); \ + DEFINE_BINARY3(F, T3, T4, T5) +#define DEFINE_BINARY7(F, T0, T1, T2, T3, T4, T5, T6) \ + DEFINE_BINARY3(F, T0, T1, T2); \ + DEFINE_BINARY4(F, T3, T4, T5, T6) } // end namespace functor } // end namespace tensorflow diff --git a/tensorflow/core/kernels/lrn_op.cc b/tensorflow/core/kernels/lrn_op.cc index fb779f24665..9ae2eedb30c 100644 --- a/tensorflow/core/kernels/lrn_op.cc +++ b/tensorflow/core/kernels/lrn_op.cc @@ -30,10 +30,17 @@ limitations under the License. namespace tensorflow { +namespace { + +// When the depth is large and beta_ is 0.5 or 1.0, MognetLRN is faster than the +// main band matrix approach used below. Benchmarks suggest switching to +// MognetLRN when depth > 384. +const int kMognetLRNDepthCutoff = 384; + // Create a depth-by-depth band matrix with 1s along a swath of size (2 * // depth_radius + 1) around the diagonal. -static void GetBandMatrix(int depth, int64 depth_radius, - Eigen::Tensor* result) { +void GetBandMatrix(int depth, int64 depth_radius, + Eigen::Tensor* result) { result->setZero(); for (int row = 0; row < depth; ++row) { const int begin = std::max(0, row - depth_radius); @@ -44,6 +51,8 @@ static void GetBandMatrix(int depth, int64 depth_radius, } } +} // namespace + class LRNOp : public OpKernel { public: explicit LRNOp(OpKernelConstruction* context) : OpKernel(context) { @@ -69,6 +78,11 @@ class LRNOp : public OpKernel { #if defined(__ANDROID__) MognetLRN(in, batch, rows, cols, depth, output); #else + if (depth > kMognetLRNDepthCutoff && (beta_ == 0.5f || beta_ == 1.0f)) { + MognetLRN(in, batch, rows, cols, depth, output); + return; + } + const int nodes = cols * rows; auto in_shaped = in.shaped({nodes * batch, depth}); @@ -79,13 +93,16 @@ class LRNOp : public OpKernel { auto out_shaped = output->shaped({nodes * batch, depth}); Eigen::array dims = {{DimPair(1, 0)}}; - /// TODO(keveman): Optimize for beta in {0, 1, 0.5} - out_shaped.device(context->eigen_cpu_device()) = - in_shaped / - in_shaped.square() - .contract(multiplier, dims) - .unaryExpr([this](float x) { return bias_ + alpha_ * x; }) - .pow(beta_); + auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_; + if (beta_ == 1.0f) { + out_shaped.device(context->eigen_cpu_device()) = + in_shaped * tmp.inverse(); + } else if (beta_ == 0.5f) { + out_shaped.device(context->eigen_cpu_device()) = in_shaped * tmp.rsqrt(); + } else { + out_shaped.device(context->eigen_cpu_device()) = + in_shaped * (tmp.log() * -beta_).exp(); + } #endif } @@ -104,11 +121,11 @@ class LRNOp : public OpKernel { Eigen::VectorXf padded_square(data_in.rows() + double_depth_radius); padded_square.setZero(); for (int r = 0; r < data_in.cols(); ++r) { - // Do local response normalization for data_in(:, r) - // first, compute the square and store them in buffer for repeated use + // Do local response normalization for data_in(:, r). First, compute the + // square and store them in buffer for repeated use. padded_square.block(depth_radius_, 0, data_out.rows(), 1) = data_in.col(r).cwiseProduct(data_in.col(r)) * alpha_; - // Then, compute the scale and writes them to data_out + // Then, compute the scale and write it to data_out. float accumulated_scale = 0; for (int i = 0; i < double_depth_radius; ++i) { accumulated_scale += padded_square(i); @@ -120,13 +137,13 @@ class LRNOp : public OpKernel { } } - // In a few cases, the pow computation could benefit from speedups. if (beta_ == 1) { data_out.array() = data_in.array() * data_out.array().inverse(); } else if (beta_ == 0.5) { - data_out.array() = data_in.array() * data_out.array().sqrt().inverse(); + data_out.array() = data_in.array() * data_out.array().rsqrt(); } else { - data_out.array() = data_in.array() * data_out.array().pow(-beta_); + data_out.array() = + data_in.array() * (data_out.array().log() * -beta_).exp(); } } diff --git a/tensorflow/core/kernels/reference_gemm.h b/tensorflow/core/kernels/reference_gemm.h deleted file mode 100644 index 16fa541238f..00000000000 --- a/tensorflow/core/kernels/reference_gemm.h +++ /dev/null @@ -1,90 +0,0 @@ -/* Copyright 2015 Google Inc. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_KERNELS_REFERENCE_GEMM_H_ -#define TENSORFLOW_KERNELS_REFERENCE_GEMM_H_ - -// This is an unoptimized but debuggable implementation of the GEMM matrix -// multiply function, used to compare to faster but more opaque versions, or -// for bit depths or argument combinations that aren't supported by optimized -// code. -// It assumes the row-major convention used by TensorFlow, and implements -// C = A * B, like the standard BLAS GEMM interface. If the tranpose flags are -// true, then the relevant matrix is treated as stored in column-major order. - -namespace tensorflow { -template -void ReferenceGemm(bool transpose_a, bool transpose_b, bool transpose_c, - size_t m, size_t n, size_t k, const T1* a, T1 offset_a, - size_t lda, const T2* b, T2 offset_b, size_t ldb, T3* c, - int32 shift_c, int32 offset_c, int32 mult_c, size_t ldc) { - int a_i_stride; - int a_l_stride; - if (transpose_a) { - a_i_stride = 1; - a_l_stride = lda; - } else { - a_i_stride = lda; - a_l_stride = 1; - } - int b_j_stride; - int b_l_stride; - if (transpose_b) { - b_j_stride = ldb; - b_l_stride = 1; - } else { - b_j_stride = 1; - b_l_stride = ldb; - } - int c_i_stride; - int c_j_stride; - if (transpose_c) { - c_i_stride = 1; - c_j_stride = ldc; - } else { - c_i_stride = ldc; - c_j_stride = 1; - } - - const int32 highest = static_cast(Eigen::NumTraits::highest()); - const int32 lowest = static_cast(Eigen::NumTraits::lowest()); - const int32 rounding = (shift_c < 1) ? 0 : (1 << (shift_c - 1)); - - int i, j, l; - for (j = 0; j < n; j++) { - for (i = 0; i < m; i++) { - int32 total = 0; - for (l = 0; l < k; l++) { - const size_t a_index = ((i * a_i_stride) + (l * a_l_stride)); - const int32 a_value = a[a_index] - offset_a; - const size_t b_index = ((j * b_j_stride) + (l * b_l_stride)); - const int32 b_value = b[b_index] - offset_b; - total += (a_value * b_value); - } - const size_t c_index = ((i * c_i_stride) + (j * c_j_stride)); - int32_t output = ((((total + offset_c) * mult_c) + rounding) >> shift_c); - if (output > highest) { - output = highest; - } - if (output < lowest) { - output = lowest; - } - c[c_index] = static_cast(output); - } - } -} -} // namespace tensorflow - -#endif // TENSORFLOW_KERNELS_REFERENCE_GEMM_H_ diff --git a/tensorflow/core/kernels/reverse_sequence_op.cc b/tensorflow/core/kernels/reverse_sequence_op.cc index 0671414c510..a25c68a15ac 100644 --- a/tensorflow/core/kernels/reverse_sequence_op.cc +++ b/tensorflow/core/kernels/reverse_sequence_op.cc @@ -39,7 +39,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; template -void CheckErrors(OpKernelContext* context, int seq_dim) { +void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) { const Tensor& input = context->input(0); const Tensor& seq_lens = context->input(1); @@ -52,15 +52,18 @@ void CheckErrors(OpKernelContext* context, int seq_dim) { seq_lens_vec.data(), seq_lens_t.data(), sizeof(int64) * seq_lens_t.size()); - OP_REQUIRES(context, 0 != seq_dim, errors::InvalidArgument("0 == seq_dim")); + OP_REQUIRES(context, batch_dim != seq_dim, + errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim)); OP_REQUIRES(context, seq_dim < input.dims(), errors::InvalidArgument("seq_dim must be < input.dims()", "( ", seq_dim, " vs. ", input.dims(), ")")); - - OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(0), - errors::InvalidArgument("len(seq_lens) != input.dims(", 0, "), ", - "(", seq_lens.NumElements(), " vs. ", - input.dim_size(seq_dim))); + OP_REQUIRES(context, batch_dim < input.dims(), + errors::InvalidArgument("batch_dim must be < input.dims()", "( ", + batch_dim, " vs. ", input.dims(), ")")); + OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim), + errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim, + "), ", "(", seq_lens.NumElements(), + " vs. ", input.dim_size(batch_dim))); for (int d = 0; d < seq_lens_vec.size(); ++d) { OP_REQUIRES(context, seq_lens_vec[d] >= 0, @@ -72,19 +75,24 @@ void CheckErrors(OpKernelContext* context, int seq_dim) { } template <> -void CheckErrors(OpKernelContext* context, int seq_dim) { +void CheckErrors(OpKernelContext* context, int batch_dim, + int seq_dim) { const Tensor& input = context->input(0); const Tensor& seq_lens = context->input(1); - OP_REQUIRES(context, 0 != seq_dim, errors::InvalidArgument("0 == seq_dim")); + OP_REQUIRES(context, batch_dim != seq_dim, + errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim)); OP_REQUIRES(context, seq_dim < input.dims(), errors::InvalidArgument("seq_dim must be < input.dims()", "( ", seq_dim, " vs. ", input.dims(), ")")); + OP_REQUIRES(context, batch_dim < input.dims(), + errors::InvalidArgument("batch_dim must be < input.dims()", "( ", + batch_dim, " vs. ", input.dims(), ")")); - OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(0), - errors::InvalidArgument("len(seq_lens) != input.dims(", 0, "), ", - "(", seq_lens.NumElements(), " vs. ", - input.dim_size(seq_dim))); + OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim), + errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim, + "), ", "(", seq_lens.NumElements(), + " vs. ", input.dim_size(batch_dim))); } template @@ -92,6 +100,7 @@ class ReverseSequenceOp : public OpKernel { public: explicit ReverseSequenceOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_)); OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_)); } @@ -106,7 +115,7 @@ class ReverseSequenceOp : public OpKernel { auto seq_lens_t = seq_lens.vec(); - CheckErrors(context, seq_dim_); + CheckErrors(context, batch_dim_, seq_dim_); const int input_dims = input.dims(); @@ -114,11 +123,11 @@ class ReverseSequenceOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &output)); -#define HANDLE_DIM(NDIM) \ - case NDIM: \ - functor::ReverseSequence::Compute( \ - context->eigen_device(), input.tensor(), seq_dim_, \ - seq_lens_t, output->tensor()); \ +#define HANDLE_DIM(NDIM) \ + case NDIM: \ + functor::ReverseSequence::Compute( \ + context->eigen_device(), input.tensor(), batch_dim_, \ + seq_dim_, seq_lens_t, output->tensor()); \ break; switch (input_dims) { @@ -136,6 +145,7 @@ class ReverseSequenceOp : public OpKernel { } private: + int32 batch_dim_; int32 seq_dim_; TF_DISALLOW_COPY_AND_ASSIGN(ReverseSequenceOp); @@ -152,12 +162,12 @@ TF_CALL_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE); // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPEC(T, Dims) \ - template <> \ - void ReverseSequence::Compute( \ - const GPUDevice& d, typename TTypes::ConstTensor input, \ - int32 seq_dim, TTypes::ConstVec seq_lens, \ - typename TTypes::Tensor output); \ +#define DECLARE_GPU_SPEC(T, Dims) \ + template <> \ + void ReverseSequence::Compute( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + int32 batch_dim, int32 seq_dim, TTypes::ConstVec seq_lens, \ + typename TTypes::Tensor output); \ extern template struct ReverseSequence; #define DECLARE_GPU_SPECS(T) \ diff --git a/tensorflow/core/kernels/reverse_sequence_op.h b/tensorflow/core/kernels/reverse_sequence_op.h index ceb1b0b8801..9dd1e4d01dd 100644 --- a/tensorflow/core/kernels/reverse_sequence_op.h +++ b/tensorflow/core/kernels/reverse_sequence_op.h @@ -29,15 +29,19 @@ template class ReverseGenerator { public: EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE - ReverseGenerator(typename TTypes::ConstTensor input, int32 seq_dim, - TTypes::ConstVec seq_lengths) - : input_(input), seq_dim_(seq_dim), seq_lengths_(seq_lengths) {} + ReverseGenerator(typename TTypes::ConstTensor input, int32 batch_dim, + int32 seq_dim, TTypes::ConstVec seq_lengths) + : input_(input), + batch_dim_(batch_dim), + seq_dim_(seq_dim), + seq_lengths_(seq_lengths) {} EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T operator()(const Eigen::array& coords) const { Eigen::array new_coords = coords; - if (coords[seq_dim_] < seq_lengths_(coords[0])) { - new_coords[seq_dim_] = seq_lengths_(coords[0]) - coords[seq_dim_] - 1; + if (coords[seq_dim_] < seq_lengths_(coords[batch_dim_])) { + new_coords[seq_dim_] = + seq_lengths_(coords[batch_dim_]) - coords[seq_dim_] - 1; } return input_(new_coords); @@ -45,6 +49,7 @@ class ReverseGenerator { private: typename TTypes::ConstTensor input_; + int32 batch_dim_; int32 seq_dim_; TTypes::ConstVec seq_lengths_; }; @@ -57,9 +62,10 @@ template struct ReverseSequence { EIGEN_ALWAYS_INLINE static void Compute( const Device& d, typename TTypes::ConstTensor input, - int32 seq_dim, TTypes::ConstVec seq_lengths, + int32 batch_dim, int32 seq_dim, TTypes::ConstVec seq_lengths, typename TTypes::Tensor output) { - generator::ReverseGenerator generator(input, seq_dim, seq_lengths); + generator::ReverseGenerator generator(input, batch_dim, seq_dim, + seq_lengths); output.device(d) = input.generate(generator); } }; diff --git a/tensorflow/core/kernels/softsign_op.cc b/tensorflow/core/kernels/softsign_op.cc new file mode 100644 index 00000000000..e3480e35947 --- /dev/null +++ b/tensorflow/core/kernels/softsign_op.cc @@ -0,0 +1,112 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/nn_ops.cc. + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/numeric_op.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/softsign_op.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +class SoftsignOp : public UnaryElementWiseOp> { + public: + using UnaryElementWiseOp>::UnaryElementWiseOp; + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + functor::Softsign functor; + functor(context->eigen_device(), input.flat(), + output->flat()); + } +}; + +template +class SoftsignGradOp + : public BinaryElementWiseOp> { + public: + using BinaryElementWiseOp>::BinaryElementWiseOp; + + // INPUTS: + // g (gradients): backpropagated gradients + // a (inputs): inputs that were passed to SoftsignOp() + // OUTPUT: + // gradients to backprop + template + void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, + Tensor* output) { + OP_REQUIRES(context, a.IsSameSize(g), + errors::InvalidArgument("g and a must be the same size")); + functor::SoftsignGrad functor; + functor(context->eigen_device(), g.flat(), a.flat(), + output->flat()); + } +}; + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Softsign").Device(DEVICE_CPU).TypeConstraint("T"), \ + SoftsignOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("SoftsignGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ + SoftsignGradOp); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void Softsign::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor features, \ + typename TTypes::Tensor activations); \ + extern template struct Softsign; \ + \ + template <> \ + void SoftsignGrad::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor gradients, \ + typename TTypes::ConstTensor features, \ + typename TTypes::Tensor backprops); \ + extern template struct SoftsignGrad; + +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); +} // namespace functor + +// Registration of the GPU implementations. +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Softsign").Device(DEVICE_GPU).TypeConstraint("T"), \ + SoftsignOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("SoftsignGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ + SoftsignGradOp); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +#undef REGISTER_GPU_KERNELS + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/softsign_op.h b/tensorflow/core/kernels/softsign_op.h new file mode 100644 index 00000000000..36790a5874c --- /dev/null +++ b/tensorflow/core/kernels/softsign_op.h @@ -0,0 +1,60 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_KERNELS_SOFTSIGN_OP_H_ +#define TENSORFLOW_KERNELS_SOFTSIGN_OP_H_ +// Functor definition for SoftsignOp and SoftsignGradOp, must be compilable by +// nvcc. + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +// Functor used by SoftsignOp to do the computations. +template +struct Softsign { + // Computes Softsign activation. + // + // features: any shape. + // activations: same shape as "features". + void operator()(const Device& d, typename TTypes::ConstTensor features, + typename TTypes::Tensor activations) { + activations.device(d) = + features / (features.abs() + features.constant(1.0f)); + } +}; + +// Functor used by SoftsignGradOp to do the computations. +template +struct SoftsignGrad { + // Computes SoftsignGrad backprops. + // + // gradients: gradients backpropagated to the Softsign op. + // features: inputs that were passed to the Softsign op. + // backprops: gradients to backpropagate to the Softsign inputs. + void operator()(const Device& d, typename TTypes::ConstTensor gradients, + typename TTypes::ConstTensor features, + typename TTypes::Tensor backprops) { + backprops.device(d) = + gradients / (features.abs() + features.constant(1.0f)).square(); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_SOFTSIGN_OP_H_ diff --git a/tensorflow/core/kernels/softsign_op_gpu.cu.cc b/tensorflow/core/kernels/softsign_op_gpu.cu.cc new file mode 100644 index 00000000000..4ae941c9f01 --- /dev/null +++ b/tensorflow/core/kernels/softsign_op_gpu.cu.cc @@ -0,0 +1,40 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include + +#include "tensorflow/core/kernels/softsign_op.h" + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +// Definition of the GPU implementations declared in softsign_op.cc. +#define DEFINE_GPU_KERNELS(T) \ + template struct functor::Softsign; \ + template struct functor::SoftsignGrad; + +TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); + +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/split_op_gpu.cu.cc b/tensorflow/core/kernels/split_op_gpu.cu.cc index c79410b68c0..13463b705b0 100644 --- a/tensorflow/core/kernels/split_op_gpu.cu.cc +++ b/tensorflow/core/kernels/split_op_gpu.cu.cc @@ -33,7 +33,7 @@ void Split::operator()( typename TTypes::ConstTensor input, const Eigen::DSizes& slice_indices, const Eigen::DSizes& slice_sizes) { - output.device(d) = input.slice(slice_indices, slice_sizes); + To32Bit(output).device(d) = To32Bit(input).slice(slice_indices, slice_sizes); } #define DEFINE_GPU_KERNELS(T) template struct Split; diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc index 055050cd34a..2c146b3d6c2 100644 --- a/tensorflow/core/kernels/stack_ops.cc +++ b/tensorflow/core/kernels/stack_ops.cc @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // See docs in ../ops/data_flow_ops.cc. #include diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 20806cff684..8287d758f0b 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -404,8 +404,9 @@ Reshapes a tensor. Given `tensor`, this operation returns a tensor that has the same values as `tensor` with shape `shape`. -If `shape` is the special value `[-1]`, then `tensor` is flattened and the -operation outputs a 1-D tensor with all elements of `tensor`. +If one component of `shape` is the special value -1, the size of that dimension +is computed so that the total size remains constant. In particular, a `shape` +of `[-1]` flattens into 1-D. At most one component of `shape` can be -1. If `shape` is 1-D or higher, then the operation returns a tensor with shape `shape` filled with the values of `tensor`. In this case, the number of elements @@ -435,6 +436,13 @@ reshape(t, [2, 4]) ==> [[1, 1, 2, 2] # tensor 't' has shape [3, 2, 3] # pass '[-1]' to flatten 't' reshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6] +# -1 can also be used with higher dimensional shapes +reshape(t, [2, -1]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3], + [4, 4, 4, 5, 5, 5, 6, 6, 6]] + +# tensor 't' is [7] +# shape `[]` reshapes to a scalar +reshape(t, []) ==> 7 ``` shape: Defines the shape of the output tensor. @@ -535,25 +543,29 @@ REGISTER_OP("ReverseSequence") .Input("seq_lengths: int64") .Output("output: T") .Attr("seq_dim: int") + .Attr("batch_dim: int = 0") .Attr("T: type") .Doc(R"doc( -Reverses variable length slices in dimension `seq_dim`. +Reverses variable length slices. -This op first slices `input` along the first dimension, and for each slice `i`, -reverses the first `seq_lengths[i]` elements along the dimension `seq_dim`. +This op first slices `input` along the dimension `batch_dim`, and for each +slice `i`, reverses the first `seq_lengths[i]` elements along +the dimension `seq_dim`. The elements of `seq_lengths` must obey `seq_lengths[i] < input.dims[seq_dim]`, -and `seq_lengths` must be a vector of length `input.dims(0)`. +and `seq_lengths` must be a vector of length `input.dims[batch_dim]`. -The output slice `i` along dimension 0 is then given by input slice `i`, with -the first `seq_lengths[i]` slices along dimension `seq_dim` reversed. +The output slice `i` along dimension `batch_dim` is then given by input +slice `i`, with the first `seq_lengths[i]` slices along dimension +`seq_dim` reversed. For example: ```prettyprint # Given this: +batch_dim = 0 seq_dim = 1 -input.dims = (4, ...) +input.dims = (4, 8, ...) seq_lengths = [7, 2, 3, 5] # then slices of input are reversed on seq_dim, but only up to seq_lengths: @@ -569,10 +581,32 @@ output[2, 3:, :, ...] = input[2, 3:, :, ...] output[3, 2:, :, ...] = input[3, 2:, :, ...] ``` +In contrast, if: +```prettyprint +# Given this: +batch_dim = 2 +seq_dim = 0 +input.dims = (8, ?, 4, ...) +seq_lengths = [7, 2, 3, 5] + +# then slices of input are reversed on seq_dim, but only up to seq_lengths: +output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...] +output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...] +output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...] +output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...] + +# while entries past seq_lens are copied through: +output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...] +output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...] +output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...] +output[2:, :, 3, :, ...] = input[2:, :, 3, :, ...] +``` + input: The input to reverse. seq_lengths: 1-D with length `input.dims(0)` and `max(seq_lengths) < input.dims(seq_dim)` seq_dim: The dimension which is partially reversed. +batch_dim: The dimension along which reversal is performed. output: The partially reversed input. It has the same shape as `input`. )doc"); diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 023c598aa61..a1f1db5f7f1 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -264,7 +264,7 @@ Returns element-wise smallest integer in not less than x. #define BINARY_MORE() \ Input("x: T").Input("y: T").Output("z: T").Attr( \ - "T: {float, double, int8, int16, int32, complex64, int64}") + "T: {float, double, uint8, int8, int16, int32, int64, complex64}") #define BINARY_FEWER() \ Input("x: T").Input("y: T").Output("z: T").Attr( \ @@ -293,7 +293,7 @@ Returns x * y element-wise. )doc"); REGISTER_OP("Div") - .BINARY_FEWER() + .BINARY_MORE() .Doc(R"doc( Returns x / y element-wise. )doc"); diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 593f986edb7..29a71730950 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -466,6 +466,27 @@ features: The features passed as input to the corresponding softplus operation. backprops: The gradients: `gradients / (1 + exp(-features))`. )doc"); +REGISTER_OP("Softsign") + .Input("features: T") + .Output("activations: T") + .Attr("T: realnumbertype") + .Doc(R"doc( +Computes softsign: `features / (abs(features) + 1)`. +)doc"); + +REGISTER_OP("SoftsignGrad") + .Input("gradients: T") + .Input("features: T") + .Output("backprops: T") + .Attr("T: realnumbertype") + .Doc(R"doc( +Computes softsign gradients for a softsign operation. + +gradients: The backpropagated gradients to the corresponding softsign operation. +features: The features passed as input to the corresponding softsign operation. +backprops: The gradients: `gradients / (1 + abs(-features)) ** 2`. +)doc"); + // -------------------------------------------------------------------------- REGISTER_OP("Softmax") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 33875345873..9f48da94e1d 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -44,11 +44,12 @@ op { list { type: DT_FLOAT type: DT_DOUBLE + type: DT_UINT8 type: DT_INT8 type: DT_INT16 type: DT_INT32 - type: DT_COMPLEX64 type: DT_INT64 + type: DT_COMPLEX64 } } } @@ -1973,9 +1974,12 @@ op { list { type: DT_FLOAT type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 type: DT_INT32 - type: DT_COMPLEX64 type: DT_INT64 + type: DT_COMPLEX64 } } } @@ -4251,11 +4255,12 @@ op { list { type: DT_FLOAT type: DT_DOUBLE + type: DT_UINT8 type: DT_INT8 type: DT_INT16 type: DT_INT32 - type: DT_COMPLEX64 type: DT_INT64 + type: DT_COMPLEX64 } } } @@ -5532,7 +5537,7 @@ op { type: "type" } summary: "Reshapes a tensor." - description: "Given `tensor`, this operation returns a tensor that has the same values\nas `tensor` with shape `shape`.\n\nIf `shape` is the special value `[-1]`, then `tensor` is flattened and the\noperation outputs a 1-D tensor with all elements of `tensor`.\n\nIf `shape` is 1-D or higher, then the operation returns a tensor with shape\n`shape` filled with the values of `tensor`. In this case, the number of elements\nimplied by `shape` must be the same as the number of elements in `tensor`.\n\nFor example:\n\n```prettyprint\n# tensor \'t\' is [1, 2, 3, 4, 5, 6, 7, 8, 9]\n# tensor \'t\' has shape [9]\nreshape(t, [3, 3]) ==> [[1, 2, 3]\n [4, 5, 6]\n [7, 8, 9]]\n\n# tensor \'t\' is [[[1, 1], [2, 2]]\n# [[3, 3], [4, 4]]]\n# tensor \'t\' has shape [2, 2, 2]\nreshape(t, [2, 4]) ==> [[1, 1, 2, 2]\n [3, 3, 4, 4]]\n\n# tensor \'t\' is [[[1, 1, 1],\n# [2, 2, 2]],\n# [[3, 3, 3],\n# [4, 4, 4]],\n# [[5, 5, 5],\n# [6, 6, 6]]]\n# tensor \'t\' has shape [3, 2, 3]\n# pass \'[-1]\' to flatten \'t\'\nreshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6]\n```" + description: "Given `tensor`, this operation returns a tensor that has the same values\nas `tensor` with shape `shape`.\n\nIf one component of `shape` is the special value -1, the size of that dimension\nis computed so that the total size remains constant. In particular, a `shape`\nof `[-1]` flattens into 1-D. At most one component of `shape` can be -1.\n\nIf `shape` is 1-D or higher, then the operation returns a tensor with shape\n`shape` filled with the values of `tensor`. In this case, the number of elements\nimplied by `shape` must be the same as the number of elements in `tensor`.\n\nFor example:\n\n```prettyprint\n# tensor \'t\' is [1, 2, 3, 4, 5, 6, 7, 8, 9]\n# tensor \'t\' has shape [9]\nreshape(t, [3, 3]) ==> [[1, 2, 3]\n [4, 5, 6]\n [7, 8, 9]]\n\n# tensor \'t\' is [[[1, 1], [2, 2]]\n# [[3, 3], [4, 4]]]\n# tensor \'t\' has shape [2, 2, 2]\nreshape(t, [2, 4]) ==> [[1, 1, 2, 2]\n [3, 3, 4, 4]]\n\n# tensor \'t\' is [[[1, 1, 1],\n# [2, 2, 2]],\n# [[3, 3, 3],\n# [4, 4, 4]],\n# [[5, 5, 5],\n# [6, 6, 6]]]\n# tensor \'t\' has shape [3, 2, 3]\n# pass \'[-1]\' to flatten \'t\'\nreshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6]\n# -1 can also be used with higher dimensional shapes\nreshape(t, [2, -1]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3],\n [4, 4, 4, 5, 5, 5, 6, 6, 6]]\n\n# tensor \'t\' is [7]\n# shape `[]` reshapes to a scalar\nreshape(t, []) ==> 7\n```" } op { name: "ResizeArea" @@ -6770,6 +6775,67 @@ op { } summary: "Computes softplus gradients for a softplus operation." } +op { + name: "Softsign" + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "activations" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + } + } + } + summary: "Computes softsign: `features / (abs(features) + 1)`." +} +op { + name: "SoftsignGrad" + input_arg { + name: "gradients" + description: "The backpropagated gradients to the corresponding softsign operation." + type_attr: "T" + } + input_arg { + name: "features" + description: "The features passed as input to the corresponding softsign operation." + type_attr: "T" + } + output_arg { + name: "backprops" + description: "The gradients: `gradients / (1 + abs(-features)) ** 2`." + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + } + } + } + summary: "Computes softsign gradients for a softsign operation." +} op { name: "SparseApplyAdagrad" input_arg { diff --git a/tensorflow/core/public/README.md b/tensorflow/core/public/README.md index d5051ae690f..f5e10bf79f3 100644 --- a/tensorflow/core/public/README.md +++ b/tensorflow/core/public/README.md @@ -12,7 +12,7 @@ process. First, bring in tensorflow python dependency -//third_party/tensorflow:tensorflow_py +//third_party/py/tensorflow to get the python TensorFlow API. @@ -22,9 +22,9 @@ Then: import tensorflow as tf with tf.Session("local"): - input1 = tf.Constant(1.0, shape=[1, 1], name="input1") - input2 = tf.Constant(2.0, shape=[1, 1], name="input2") - output = tf.MatMul(input1, input2) + input1 = tf.constant(1.0, shape=[1, 1], name="input1") + input2 = tf.constant(2.0, shape=[1, 1], name="input2") + output = tf.matmul(input1, input2) # Run graph and fetch the output result = output.eval() diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc index c78ee33e06d..9cff418c670 100644 --- a/tensorflow/examples/label_image/main.cc +++ b/tensorflow/examples/label_image/main.cc @@ -64,11 +64,13 @@ TF_DEFINE_string(image, "tensorflow/examples/label_image/data/grace_hopper.jpg", "The image to classify (JPEG or PNG)."); TF_DEFINE_string(graph, - "tensorflow/examples/label_image/data/googlenet_graph.pb", + "tensorflow/examples/label_image/data/" + "tensorflow_inception_graph.pb", "The location of the GraphDef file containing the protobuf" " definition of the network."); TF_DEFINE_string(labels, - "tensorflow/examples/label_image/data/googlenet_labels.txt", + "tensorflow/examples/label_image/data/" + "imagenet_comp_graph_label_strings.txt", "A text file containing the labels of all the categories, one" " per line."); TF_DEFINE_int32(input_width, 224, "Width of the image the network expects."); @@ -85,6 +87,10 @@ TF_DEFINE_string(root_dir, "", "The directory at the root of the data files."); // of the result is a multiple of 16, because our model expects that. Status ReadLabelsFile(string file_name, std::vector* result) { std::ifstream file(file_name); + if (!file) { + return tensorflow::errors::NotFound("Labels file ", file_name, + " not found."); + } result->clear(); string line; while (std::getline(file, line)) { diff --git a/tensorflow/g3doc/api_docs/python/array_ops.md b/tensorflow/g3doc/api_docs/python/array_ops.md index 14d1e396233..79abef17177 100644 --- a/tensorflow/g3doc/api_docs/python/array_ops.md +++ b/tensorflow/g3doc/api_docs/python/array_ops.md @@ -277,8 +277,9 @@ Reshapes a tensor. Given `tensor`, this operation returns a tensor that has the same values as `tensor` with shape `shape`. -If `shape` is the special value `[-1]`, then `tensor` is flattened and the -operation outputs a 1-D tensor with all elements of `tensor`. +If one component of `shape` is the special value -1, the size of that dimension +is computed so that the total size remains constant. In particular, a `shape` +of `[-1]` flattens into 1-D. At most one component of `shape` can be -1. If `shape` is 1-D or higher, then the operation returns a tensor with shape `shape` filled with the values of `tensor`. In this case, the number of elements @@ -308,6 +309,13 @@ reshape(t, [2, 4]) ==> [[1, 1, 2, 2] # tensor 't' has shape [3, 2, 3] # pass '[-1]' to flatten 't' reshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6] +# -1 can also be used with higher dimensional shapes +reshape(t, [2, -1]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3], + [4, 4, 4, 5, 5, 5, 6, 6, 6]] + +# tensor 't' is [7] +# shape `[]` reshapes to a scalar +reshape(t, []) ==> 7 ``` ##### Args: diff --git a/tensorflow/g3doc/api_docs/python/framework.md b/tensorflow/g3doc/api_docs/python/framework.md index eea786647b5..8217729e9cf 100644 --- a/tensorflow/g3doc/api_docs/python/framework.md +++ b/tensorflow/g3doc/api_docs/python/framework.md @@ -1355,7 +1355,7 @@ for more details. - - - -### `tf.convert_to_tensor(value, dtype=None, name=None)` {#convert_to_tensor} +### `tf.convert_to_tensor(value, dtype=None, name=None, as_ref=False)` {#convert_to_tensor} Converts the given `value` to a `Tensor`. @@ -1390,6 +1390,7 @@ and scalars in addition to `Tensor` objects. * `dtype`: Optional element type for the returned tensor. If missing, the type is inferred from the type of `value`. * `name`: Optional name to use if a new `Tensor` is created. +* `as_ref`: True if we want the result as a ref tensor. ##### Returns: diff --git a/tensorflow/g3doc/api_docs/python/image.md b/tensorflow/g3doc/api_docs/python/image.md index 5ca185edf4a..3d8c51d5204 100644 --- a/tensorflow/g3doc/api_docs/python/image.md +++ b/tensorflow/g3doc/api_docs/python/image.md @@ -18,7 +18,8 @@ are all of variable size. If you need fixed size images, pass the output of the decode Ops to one of the cropping and resizing Ops. Note: The PNG encode and decode Ops support RGBA, but the conversions Ops -presently only support RGB, HSV, and GrayScale. +presently only support RGB, HSV, and GrayScale. Presently, the alpha channel has +to be stripped from the image and re-attached using slicing ops. - - - @@ -204,10 +205,6 @@ image = tf.image.decode_jpeg(...) resized_image = tf.image.resize_bilinear(image, [299, 299]) ``` -Maybe refer to the Queue examples that show how to add images to a Queue -after resizing them to a fixed size, and how to dequeue batches of resized -images from the Queue. - - - - ### `tf.image.resize_images(images, new_height, new_width, method=0)` {#resize_images} @@ -661,6 +658,43 @@ See also `transpose()`. +## Converting Between Colorspaces. + +Internally, images are either stored in as one `float32` per channel per pixel +(implicitly, values are assumed to lie in `[0,1)`) or one `uint8` per channel +per pixel (values are assumed to lie in `[0,255]`). + +- - - + +### `tf.image.convert_image_dtype(image, dtype, name=None)` {#convert_image_dtype} + +Convert `image` to `dtype`, scaling its values if needed. + +Images that are represented using floating point values are expected to have +values in the range [0,1). Image data stored in integer data types are +expected to have values in the range `[0,MAX]`, wbere `MAX` is the largest +positive representable number for the data type. + +This op converts between data types, scaling the values appropriately before +casting. + +Note that for floating point inputs, this op expects values to lie in [0,1). +Conversion of an image containing values outside that range may lead to +overflow errors when converted to integer `Dtype`s. + +##### Args: + + +* `image`: An image. +* `dtype`: A `DType` to convert `image` to. +* `name`: A name for this operation (optional). + +##### Returns: + + `image`, converted to `dtype`. + + + ## Image Adjustments TensorFlow provides functions to adjust images in various ways: brightness, diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md index 1211233fae9..b04559287cd 100644 --- a/tensorflow/g3doc/api_docs/python/index.md +++ b/tensorflow/g3doc/api_docs/python/index.md @@ -194,6 +194,7 @@ * **[Images](../../api_docs/python/image.md)**: * [`adjust_brightness`](../../api_docs/python/image.md#adjust_brightness) * [`adjust_contrast`](../../api_docs/python/image.md#adjust_contrast) + * [`convert_image_dtype`](../../api_docs/python/image.md#convert_image_dtype) * [`crop_to_bounding_box`](../../api_docs/python/image.md#crop_to_bounding_box) * [`decode_jpeg`](../../api_docs/python/image.md#decode_jpeg) * [`decode_png`](../../api_docs/python/image.md#decode_png) @@ -283,6 +284,7 @@ * [`nce_loss`](../../api_docs/python/nn.md#nce_loss) * [`relu`](../../api_docs/python/nn.md#relu) * [`relu6`](../../api_docs/python/nn.md#relu6) + * [`rnn`](../../api_docs/python/nn.md#rnn) * [`sampled_softmax_loss`](../../api_docs/python/nn.md#sampled_softmax_loss) * [`separable_conv2d`](../../api_docs/python/nn.md#separable_conv2d) * [`sigmoid`](../../api_docs/python/nn.md#sigmoid) @@ -290,6 +292,8 @@ * [`softmax`](../../api_docs/python/nn.md#softmax) * [`softmax_cross_entropy_with_logits`](../../api_docs/python/nn.md#softmax_cross_entropy_with_logits) * [`softplus`](../../api_docs/python/nn.md#softplus) + * [`softsign`](../../api_docs/python/nn.md#softsign) + * [`state_saving_rnn`](../../api_docs/python/nn.md#state_saving_rnn) * [`tanh`](../../api_docs/python/nn.md#tanh) * [`top_k`](../../api_docs/python/nn.md#top_k) * [`uniform_candidate_sampler`](../../api_docs/python/nn.md#uniform_candidate_sampler) diff --git a/tensorflow/g3doc/api_docs/python/io_ops.md b/tensorflow/g3doc/api_docs/python/io_ops.md index 0d4d52eea55..8d84df05302 100644 --- a/tensorflow/g3doc/api_docs/python/io_ops.md +++ b/tensorflow/g3doc/api_docs/python/io_ops.md @@ -1773,6 +1773,12 @@ Output strings (e.g. filenames) to a queue for an input pipeline. A queue with the output strings. A `QueueRunner` for the Queue is added to the current `Graph`'s `QUEUE_RUNNER` collection. +##### Raises: + + +* `ValueError`: If the string_tensor is a null Python list. At runtime, + will fail with an assertion if string_tensor becomes a null tensor. + ### Batching at the end of an input pipeline diff --git a/tensorflow/g3doc/api_docs/python/math_ops.md b/tensorflow/g3doc/api_docs/python/math_ops.md index 43261de10bf..346c2fbf956 100644 --- a/tensorflow/g3doc/api_docs/python/math_ops.md +++ b/tensorflow/g3doc/api_docs/python/math_ops.md @@ -23,7 +23,7 @@ Returns x + y element-wise. ##### Args: -* `x`: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int8`, `int16`, `int32`, `complex64`, `int64`. +* `x`: A `Tensor`. Must be one of the following types: `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `complex64`. * `y`: A `Tensor`. Must have the same type as `x`. * `name`: A name for the operation (optional). @@ -59,7 +59,7 @@ Returns x * y element-wise. ##### Args: -* `x`: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int8`, `int16`, `int32`, `complex64`, `int64`. +* `x`: A `Tensor`. Must be one of the following types: `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `complex64`. * `y`: A `Tensor`. Must have the same type as `x`. * `name`: A name for the operation (optional). @@ -77,7 +77,7 @@ Returns x / y element-wise. ##### Args: -* `x`: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `complex64`, `int64`. +* `x`: A `Tensor`. Must be one of the following types: `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `complex64`. * `y`: A `Tensor`. Must have the same type as `x`. * `name`: A name for the operation (optional). diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md index 068e5f2ec47..67c315745dc 100644 --- a/tensorflow/g3doc/api_docs/python/nn.md +++ b/tensorflow/g3doc/api_docs/python/nn.md @@ -9,11 +9,10 @@ Note: Functions taking `Tensor` arguments can also take anything accepted by ## Activation Functions -The activation ops provide different types of nonlinearities for use in -neural networks. These include smooth nonlinearities (`sigmoid`, -`tanh`, and `softplus`), continuous but not everywhere differentiable -functions (`relu`, `relu6`, and `relu_x`), and random regularization -(`dropout`). +The activation ops provide different types of nonlinearities for use in neural +networks. These include smooth nonlinearities (`sigmoid`, `tanh`, `softplus`, +and `softsign`), continuous but not everywhere differentiable functions (`relu`, +`relu6`, and `relu_x`), and random regularization (`dropout`). All activation ops apply componentwise, and produce a tensor of the same shape as the input tensor. @@ -62,6 +61,23 @@ Computes softplus: `log(exp(features) + 1)`. ##### Args: +* `features`: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`, `uint8`, `int16`, `int8`. +* `name`: A name for the operation (optional). + +##### Returns: + + A `Tensor`. Has the same type as `features`. + + +- - - + +### `tf.nn.softsign(features, name=None)` {#softsign} + +Computes softsign: `features / (abs(features) + 1)`. + +##### Args: + + * `features`: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`, `uint8`, `int16`, `int8`. * `name`: A name for the operation (optional). @@ -1228,3 +1244,89 @@ target classes as noise classes for the same example. Each value is `-FLOAT_MAX`. + +## Other Functions and Classes +- - - + +### `tf.nn.rnn(cell, inputs, initial_state=None, dtype=None, sequence_length=None, scope=None)` {#rnn} + +Creates a recurrent neural network specified by RNNCell "cell". + +##### The simplest form of RNN network generated is: + + state = cell.zero_state(...) + outputs = [] + states = [] + for input_ in inputs: + output, state = cell(input_, state) + outputs.append(output) + states.append(state) + return (outputs, states) + +However, a few other options are available: + +An initial state can be provided. +If sequence_length is provided, dynamic calculation is performed. + +Dynamic calculation returns, at time t: + (t >= max(sequence_length) + ? (zeros(output_shape), zeros(state_shape)) + : cell(input, state) + +Thus saving computational time when unrolling past the max sequence length. + +##### Args: + + +* `cell`: An instance of RNNCell. +* `inputs`: A length T list of inputs, each a vector with shape [batch_size]. +* `initial_state`: (optional) An initial state for the RNN. This must be + a tensor of appropriate type and shape [batch_size x cell.state_size]. +* `dtype`: (optional) The data type for the initial state. Required if + initial_state is not provided. +* `sequence_length`: An int64 vector (tensor) size [batch_size]. +* `scope`: VariableScope for the created subgraph; defaults to "RNN". + +##### Returns: + + A pair (outputs, states) where: + outputs is a length T list of outputs (one for each input) + states is a length T list of states (one state following each input) + +##### Raises: + + +* `TypeError`: If "cell" is not an instance of RNNCell. +* `ValueError`: If inputs is None or an empty list. + + +- - - + +### `tf.nn.state_saving_rnn(cell, inputs, state_saver, state_name, sequence_length=None, scope=None)` {#state_saving_rnn} + +RNN that accepts a state saver for time-truncated RNN calculation. + +##### Args: + + +* `cell`: An instance of RNNCell. +* `inputs`: A length T list of inputs, each a vector with shape [batch_size]. +* `state_saver`: A state saver object with methods `state` and `save_state`. +* `state_name`: The name to use with the state_saver. +* `sequence_length`: (optional) An int64 vector (tensor) size [batch_size]. + See the documentation for rnn() for more details about sequence_length. +* `scope`: VariableScope for the created subgraph; defaults to "RNN". + +##### Returns: + + A pair (outputs, states) where: + outputs is a length T list of outputs (one for each input) + states is a length T list of states (one state following each input) + +##### Raises: + + +* `TypeError`: If "cell" is not an instance of RNNCell. +* `ValueError`: If inputs is None or an empty list. + + diff --git a/tensorflow/g3doc/api_docs/python/sparse_ops.md b/tensorflow/g3doc/api_docs/python/sparse_ops.md index 4c7db4b10f5..99a075f14d8 100644 --- a/tensorflow/g3doc/api_docs/python/sparse_ops.md +++ b/tensorflow/g3doc/api_docs/python/sparse_ops.md @@ -43,23 +43,23 @@ dense[tuple(indices[i])] = values[i] ``` By convention, `indices` should be sorted in row-major order (or equivalently -lexigraphic order on the tuples `indices[i]`). This is not enforced when -`SparseTensor` objects are constructed, but most Ops assume correct ordering. +lexicographic order on the tuples `indices[i]`). This is not enforced when +`SparseTensor` objects are constructed, but most ops assume correct ordering. If the ordering is wrong, it can be fixed by calling `sparse_reorder` on the misordered `SparseTensor`. Example: The sparse tensor ```python - SparseTensor(values=[1, 2], indices=[[0, 0], [1, 2]], shape=[3, 4]) +SparseTensor(values=[1, 2], indices=[[0, 0], [1, 2]], shape=[3, 4]) ``` represents the dense tensor ```python - [[1, 0, 0, 0] - [0, 0, 2, 0] - [0, 0, 0, 0]] +[[1, 0, 0, 0] + [0, 0, 2, 0] + [0, 0, 0, 0]] ``` - - - @@ -73,7 +73,7 @@ Creates a `SparseTensor`. * `indices`: A 2-D int64 tensor of shape `[N, ndims]`. * `values`: A 1-D tensor of any type and shape `[N]`. -* `dense_shape`: A 1-D int64 tensor of shape `[ndims]`. +* `shape`: A 1-D int64 tensor of shape `[ndims]`. ##### Returns: diff --git a/tensorflow/g3doc/api_docs/python/state_ops.md b/tensorflow/g3doc/api_docs/python/state_ops.md index cb9a090ebda..8b2e8b379f7 100644 --- a/tensorflow/g3doc/api_docs/python/state_ops.md +++ b/tensorflow/g3doc/api_docs/python/state_ops.md @@ -380,6 +380,51 @@ The `Operation` of this variable. +#### Other Methods +- - - + +#### `tf.Variable.ref()` {#Variable.ref} + +Returns a reference to this variable. + +You usually do not need to call this method as all ops that need a reference +to the variable call it automatically. + +Returns is a `Tensor` which holds a reference to the variable. You can +assign a new value to the variable by passing the tensor to an assign op. +See [`value()`](#Variable.value) if you want to get the value of the +variable. + +##### Returns: + + A `Tensor` that is a reference to the variable. + + +- - - + +#### `tf.Variable.value()` {#Variable.value} + +Returns the last snapshot of this variable. + +You usually do not need to call this method as all ops that need the value +of the variable call it automatically through a `convert_to_tensor()` call. + +Returns a `Tensor` which holds the value of the variable. You can not +assign a new value to this tensor as it is not a reference to the variable. +See [`ref()`](#Variable.ref) if you want to get a reference to the +variable. + +To avoid copies, if the consumer of the returned value is on the same device +as the variable, this actually returns the live value of the variable, not +a copy. Updates to the variable are seen by the consumer. If the consumer +is on a different device it will get a copy of the variable. + +##### Returns: + + A `Tensor` containing the value of the variable. + + + ## Variable helper functions diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md index 6b36d913565..b686968a8c1 100644 --- a/tensorflow/g3doc/api_docs/python/train.md +++ b/tensorflow/g3doc/api_docs/python/train.md @@ -192,6 +192,7 @@ applies gradients. * `TypeError`: if `grads_and_vars` is malformed. +* `ValueError`: if none of the variables have gradients. @@ -388,9 +389,9 @@ current good choice is 1.0 or 0.1. * `beta1`: A float value or a constant float tensor. The exponential decay rate for the 1st moment estimates. * `beta2`: A float value or a constant float tensor. - The exponential decay rate for the 2st moment estimates. + The exponential decay rate for the 2nd moment estimates. * `epsilon`: A small constant for numerical stability. -* `use_locking`: If True use locks for update operation.s +* `use_locking`: If True use locks for update operations. * `name`: Optional name for the operations created when applying gradients. Defaults to "Adam". diff --git a/tensorflow/g3doc/get_started/basic_usage.md b/tensorflow/g3doc/get_started/basic_usage.md index e09b574cc7a..cca15c1de46 100644 --- a/tensorflow/g3doc/get_started/basic_usage.md +++ b/tensorflow/g3doc/get_started/basic_usage.md @@ -274,8 +274,8 @@ tf.placeholder() to create them: ```python -input1 = tf.placeholder(tf.types.float32) -input2 = tf.placeholder(tf.types.float32) +input1 = tf.placeholder(tf.float32) +input2 = tf.placeholder(tf.float32) output = tf.mul(input1, input2) with tf.Session() as sess: diff --git a/tensorflow/g3doc/how_tos/adding_an_op/index.md b/tensorflow/g3doc/how_tos/adding_an_op/index.md index 150ad8d6e68..fe943fac6cf 100644 --- a/tensorflow/g3doc/how_tos/adding_an_op/index.md +++ b/tensorflow/g3doc/how_tos/adding_an_op/index.md @@ -22,7 +22,7 @@ to: * Optionally, write a function to compute gradients for the Op. * Optionally, write a function that describes the input and output shapes for the Op. This allows shape inference to work with your Op. -* Test the Op, typically in Python. +* Test the Op, typically in Python. If you define gradients, you can verify them with the Python [`GradientChecker`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/python/kernel_tests/gradient_checker.py). [TOC] diff --git a/tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py b/tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py index 49df02d2cbd..25229213b9a 100644 --- a/tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py +++ b/tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py @@ -24,7 +24,6 @@ import tensorflow.python.platform import tensorflow as tf from tensorflow.g3doc.how_tos.adding_an_op import gen_zero_out_op_2 from tensorflow.g3doc.how_tos.adding_an_op import zero_out_grad_2 -from tensorflow.python.kernel_tests import gradient_checker class ZeroOut2Test(tf.test.TestCase): @@ -39,7 +38,7 @@ class ZeroOut2Test(tf.test.TestCase): shape = (5,) x = tf.constant([5, 4, 3, 2, 1], dtype=tf.float32) y = gen_zero_out_op_2.zero_out(x) - err = gradient_checker.ComputeGradientError(x, shape, y, shape) + err = tf.test.compute_gradient_error(x, shape, y, shape) self.assertLess(err, 1e-4) diff --git a/tensorflow/g3doc/how_tos/reading_data/convert_to_records.py b/tensorflow/g3doc/how_tos/reading_data/convert_to_records.py index 00b351545c2..ce3b016798f 100644 --- a/tensorflow/g3doc/how_tos/reading_data/convert_to_records.py +++ b/tensorflow/g3doc/how_tos/reading_data/convert_to_records.py @@ -53,7 +53,7 @@ def convert_to(images, labels, name): num_examples = labels.shape[0] if images.shape[0] != num_examples: raise ValueError("Images size %d does not match label size %d." % - (dat.shape[0], num_examples)) + (images.shape[0], num_examples)) rows = images.shape[1] cols = images.shape[2] depth = images.shape[3] diff --git a/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md b/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md index fdec071aeec..f1b7bb8205b 100644 --- a/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md +++ b/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md @@ -62,18 +62,66 @@ Now that you've modified your graph and have a `SummaryWriter`, you're ready to start running your network! If you want, you could run the merged summary op every single step, and record a ton of training data. That's likely to be more data than you need, though. Instead, consider running the merged summary op -every hundred steps or so, as in the following code example. +every `n` steps. + +The code example below is a modification of the [simple MNIST tutorial] +(http://tensorflow.org/tutorials/mnist/beginners/index.md), in which we have +added some summary ops, and run them every ten steps. If you run this and then +launch `tensorboard --logdir=/tmp/mnist_data`, you'll be able to visualize +statistics, such as how the weights or accuracy varied during training. +The code below is an exerpt; full source is [here](mnist_with_summaries.py). ```python -merged_summary_op = tf.merge_all_summaries() -summary_writer = tf.train.SummaryWriter('/tmp/mnist_logs', sess.graph_def) -total_step = 0 -while training: - total_step += 1 - session.run(training_op) - if total_step % 100 == 0: - summary_str = session.run(merged_summary_op) - summary_writer.add_summary(summary_str, total_step) +# Create the model +x = tf.placeholder("float", [None, 784], name="x-input") +W = tf.Variable(tf.zeros([784,10]), name="weights") +b = tf.Variable(tf.zeros([10], name="bias")) + +# use a name scope to organize nodes in the graph visualizer +with tf.name_scope("Wx_b") as scope: + y = tf.nn.softmax(tf.matmul(x,W) + b) + +# Add summary ops to collect data +w_hist = tf.histogram_summary("weights", W) +b_hist = tf.histogram_summary("biases", b) +y_hist = tf.histogram_summary("y", y) + +# Define loss and optimizer +y_ = tf.placeholder("float", [None,10], name="y-input") +# More name scopes will clean up the graph representation +with tf.name_scope("xent") as scope: + cross_entropy = -tf.reduce_sum(y_*tf.log(y)) + ce_summ = tf.scalar_summary("cross entropy", cross_entropy) +with tf.name_scope("train") as scope: + train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) + +with tf.name_scope("test") as scope: + correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) + accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) + accuracy_summary = tf.scalar_summary("accuracy", accuracy) + +# Merge all the summaries and write them out to /tmp/mnist_logs +merged = tf.merge_all_summaries() +writer = tf.train.SummaryWriter("/tmp/mnist_logs", sess.graph_def) +tf.initialize_all_variables().run() + +# Train the model, and feed in test data and record summaries every 10 steps + +for i in range(1000): + if i % 10 == 0: # Record summary data, and the accuracy + feed = {x: mnist.test.images, y_: mnist.test.labels} + result = sess.run([merged, accuracy], feed_dict=feed) + summary_str = result[0] + acc = result[1] + writer.add_summary(summary_str, i) + print("Accuracy at step %s: %s" % (i, acc)) + else: + batch_xs, batch_ys = mnist.train.next_batch(100) + feed = {x: batch_xs, y_: batch_ys} + sess.run(train_step, feed_dict=feed) + +print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels})) + ``` You're now all set to visualize this data using TensorBoard. diff --git a/tensorflow/g3doc/how_tos/summaries_and_tensorboard/mnist_with_summaries.py b/tensorflow/g3doc/how_tos/summaries_and_tensorboard/mnist_with_summaries.py new file mode 100644 index 00000000000..cea82b137ee --- /dev/null +++ b/tensorflow/g3doc/how_tos/summaries_and_tensorboard/mnist_with_summaries.py @@ -0,0 +1,69 @@ +"""A very simple MNIST classifer, modified to display data in TensorBoard + +See extensive documentation for the original model at +http://tensorflow.org/tutorials/mnist/beginners/index.md + +See documentaion on the TensorBoard specific pieces at +http://tensorflow.org/how_tos/summaries_and_tensorboard/index.md + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Import data +import input_data +mnist = input_data.read_data_sets("/tmp/data/", one_hot=True) + +import tensorflow as tf +sess = tf.InteractiveSession() + +# Create the model +x = tf.placeholder("float", [None, 784], name="x-input") +W = tf.Variable(tf.zeros([784,10]), name="weights") +b = tf.Variable(tf.zeros([10], name="bias")) + +# use a name scope to organize nodes in the graph visualizer +with tf.name_scope("Wx_b") as scope: + y = tf.nn.softmax(tf.matmul(x,W) + b) + +# Add summary ops to collect data +w_hist = tf.histogram_summary("weights", W) +b_hist = tf.histogram_summary("biases", b) +y_hist = tf.histogram_summary("y", y) + +# Define loss and optimizer +y_ = tf.placeholder("float", [None,10], name="y-input") +# More name scopes will clean up the graph representation +with tf.name_scope("xent") as scope: + cross_entropy = -tf.reduce_sum(y_*tf.log(y)) + ce_summ = tf.scalar_summary("cross entropy", cross_entropy) +with tf.name_scope("train") as scope: + train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) + +with tf.name_scope("test") as scope: + correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) + accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) + accuracy_summary = tf.scalar_summary("accuracy", accuracy) + +# Merge all the summaries and write them out to /tmp/mnist_logs +merged = tf.merge_all_summaries() +writer = tf.train.SummaryWriter("/tmp/mnist_logs", sess.graph_def) +tf.initialize_all_variables().run() + +# Train the model, and feed in test data and record summaries every 10 steps + +for i in range(1000): + if i % 10 == 0: # Record summary data, and the accuracy + feed = {x: mnist.test.images, y_: mnist.test.labels} + result = sess.run([merged, accuracy], feed_dict=feed) + summary_str = result[0] + acc = result[1] + writer.add_summary(summary_str, i) + print("Accuracy at step %s: %s" % (i, acc)) + else: + batch_xs, batch_ys = mnist.train.next_batch(100) + feed = {x: batch_xs, y_: batch_ys} + sess.run(train_step, feed_dict=feed) + +print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels})) diff --git a/tensorflow/g3doc/tutorials/mnist/beginners/index.md b/tensorflow/g3doc/tutorials/mnist/beginners/index.md index fc29a47ceba..44efd432352 100644 --- a/tensorflow/g3doc/tutorials/mnist/beginners/index.md +++ b/tensorflow/g3doc/tutorials/mnist/beginners/index.md @@ -224,13 +224,13 @@ We describe these interacting operations by manipulating symbolic variables. Let's create one: ```python -x = tf.placeholder("float", [None, 784]) +x = tf.placeholder(tf.float32, [None, 784]) ``` `x` isn't a specific value. It's a `placeholder`, a value that we'll input when we ask TensorFlow to run a computation. We want to be able to input any number of MNIST images, each flattened into a 784-dimensional vector. We represent -this as a 2d tensor of floating point numbers, with a shape `[None, 784]`. +this as a 2-D tensor of floating-point numbers, with a shape `[None, 784]`. (Here `None` means that a dimension can be of any length.) We also need the weights and biases for our model. We could imagine treating @@ -242,7 +242,7 @@ operations. It can be used and even modified by the computation. For machine learning applications, one generally has the model parameters be `Variable`s. ```python -W = tf.Variable(tf.zeros([784,10])) +W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) ``` @@ -259,10 +259,10 @@ to the output. We can now implement our model. It only takes one line! ```python -y = tf.nn.softmax(tf.matmul(x,W) + b) +y = tf.nn.softmax(tf.matmul(x, W) + b) ``` -First, we multiply `x` by `W` with the expression `tf.matmul(x,W)`. This is +First, we multiply `x` by `W` with the expression `tf.matmul(x, W)`. This is flipped from when we multiplied them in our equation, where we had \\(Wx\\), as a small trick to deal with `x` being a 2D tensor with multiple inputs. We then add `b`, and @@ -301,7 +301,7 @@ To implement cross-entropy we need to first add a new placeholder to input the correct answers: ```python -y_ = tf.placeholder("float", [None,10]) +y_ = tf.placeholder(tf.float32, [None, 10]) ``` Then we can implement the cross-entropy, \\(-\sum y'\log(y)\\): diff --git a/tensorflow/models/embedding/BUILD b/tensorflow/models/embedding/BUILD index 9cd0d24b5b4..fbed1b0a380 100644 --- a/tensorflow/models/embedding/BUILD +++ b/tensorflow/models/embedding/BUILD @@ -38,6 +38,9 @@ py_test( size = "small", srcs = ["word2vec_test.py"], srcs_version = "PY2AND3", + tags = [ + "notsan", # b/25864127 + ], deps = [ ":word2vec", "//tensorflow:tensorflow_py", diff --git a/tensorflow/models/rnn/BUILD b/tensorflow/models/rnn/BUILD index 1a81ce2801e..118884fd28d 100644 --- a/tensorflow/models/rnn/BUILD +++ b/tensorflow/models/rnn/BUILD @@ -7,8 +7,6 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("/tensorflow/tensorflow", "cuda_py_tests") - py_library( name = "linear", srcs = [ @@ -20,17 +18,6 @@ py_library( ], ) -py_test( - name = "linear_test", - size = "small", - srcs = ["linear_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":linear", - "//tensorflow:tensorflow_py", - ], -) - py_library( name = "rnn_cell", srcs = [ @@ -43,17 +30,6 @@ py_library( ], ) -py_test( - name = "rnn_cell_test", - size = "small", - srcs = ["rnn_cell_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":rnn_cell", - "//tensorflow:tensorflow_py", - ], -) - py_library( name = "package", srcs = [ @@ -79,16 +55,6 @@ py_library( ], ) -cuda_py_tests( - name = "rnn_tests", - srcs = [ - "rnn_test.py", - ], - additional_deps = [ - ":rnn", - ], -) - py_library( name = "seq2seq", srcs = [ @@ -101,18 +67,6 @@ py_library( ], ) -py_test( - name = "seq2seq_test", - srcs = [ - "seq2seq_test.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":seq2seq", - "//tensorflow:tensorflow_py", - ], -) - filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/models/rnn/linear.py b/tensorflow/models/rnn/linear.py index 1c8eda67151..30b420087c8 100644 --- a/tensorflow/models/rnn/linear.py +++ b/tensorflow/models/rnn/linear.py @@ -12,57 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - -"""Basic linear combinations that implicitly generate variables.""" - +"""Import linear python op for backward compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +# pylint: disable=g-bad-import-order,unused-import +import tensorflow.python.platform + import tensorflow as tf - -def linear(args, output_size, bias, bias_start=0.0, scope=None): - """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. - - Args: - args: a 2D Tensor or a list of 2D, batch x n, Tensors. - output_size: int, second dimension of W[i]. - bias: boolean, whether to add a bias term or not. - bias_start: starting value to initialize the bias; 0 by default. - scope: VariableScope for the created subgraph; defaults to "Linear". - - Returns: - A 2D Tensor with shape [batch x output_size] equal to - sum_i(args[i] * W[i]), where W[i]s are newly created matrices. - - Raises: - ValueError: if some of the arguments has unspecified or wrong shape. - """ - assert args - if not isinstance(args, (list, tuple)): - args = [args] - - # Calculate the total size of arguments on dimension 1. - total_arg_size = 0 - shapes = [a.get_shape().as_list() for a in args] - for shape in shapes: - if len(shape) != 2: - raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes)) - if not shape[1]: - raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes)) - else: - total_arg_size += shape[1] - - # Now the computation. - with tf.variable_scope(scope or "Linear"): - matrix = tf.get_variable("Matrix", [total_arg_size, output_size]) - if len(args) == 1: - res = tf.matmul(args[0], matrix) - else: - res = tf.matmul(tf.concat(1, args), matrix) - if not bias: - return res - bias_term = tf.get_variable("Bias", [output_size], - initializer=tf.constant_initializer(bias_start)) - return res + bias_term +linear = tf.nn.linear diff --git a/tensorflow/models/rnn/rnn.py b/tensorflow/models/rnn/rnn.py index b95bf98f723..9bfc978db1f 100644 --- a/tensorflow/models/rnn/rnn.py +++ b/tensorflow/models/rnn/rnn.py @@ -12,137 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - -"""RNN helpers for TensorFlow models.""" +"""Import rnn python ops for backward compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import tensorflow as tf - -from tensorflow.models.rnn import rnn_cell -from tensorflow.python.ops import control_flow_ops - - -def rnn(cell, inputs, initial_state=None, dtype=None, - sequence_length=None, scope=None): - """Creates a recurrent neural network specified by RNNCell "cell". - - The simplest form of RNN network generated is: - state = cell.zero_state(...) - outputs = [] - states = [] - for input_ in inputs: - output, state = cell(input_, state) - outputs.append(output) - states.append(state) - return (outputs, states) - - However, a few other options are available: - - An initial state can be provided. - If sequence_length is provided, dynamic calculation is performed. - - Dynamic calculation returns, at time t: - (t >= max(sequence_length) - ? (zeros(output_shape), zeros(state_shape)) - : cell(input, state) - - Thus saving computational time when unrolling past the max sequence length. - - Args: - cell: An instance of RNNCell. - inputs: A length T list of inputs, each a vector with shape [batch_size]. - initial_state: (optional) An initial state for the RNN. This must be - a tensor of appropriate type and shape [batch_size x cell.state_size]. - dtype: (optional) The data type for the initial state. Required if - initial_state is not provided. - sequence_length: An int64 vector (tensor) size [batch_size]. - scope: VariableScope for the created subgraph; defaults to "RNN". - - Returns: - A pair (outputs, states) where: - outputs is a length T list of outputs (one for each input) - states is a length T list of states (one state following each input) - - Raises: - TypeError: If "cell" is not an instance of RNNCell. - ValueError: If inputs is None or an empty list. - """ - - if not isinstance(cell, rnn_cell.RNNCell): - raise TypeError("cell must be an instance of RNNCell") - if not isinstance(inputs, list): - raise TypeError("inputs must be a list") - if not inputs: - raise ValueError("inputs must not be empty") - - outputs = [] - states = [] - with tf.variable_scope(scope or "RNN"): - batch_size = tf.shape(inputs[0])[0] - if initial_state is not None: - state = initial_state - else: - if not dtype: - raise ValueError("If no initial_state is provided, dtype must be.") - state = cell.zero_state(batch_size, dtype) - - if sequence_length: # Prepare variables - zero_output_state = ( - tf.zeros(tf.pack([batch_size, cell.output_size]), - inputs[0].dtype), - tf.zeros(tf.pack([batch_size, cell.state_size]), - state.dtype)) - max_sequence_length = tf.reduce_max(sequence_length) - - for time, input_ in enumerate(inputs): - if time > 0: tf.get_variable_scope().reuse_variables() - # pylint: disable=cell-var-from-loop - def output_state(): - return cell(input_, state) - # pylint: enable=cell-var-from-loop - if sequence_length: - (output, state) = control_flow_ops.cond( - time >= max_sequence_length, - lambda: zero_output_state, output_state) - else: - (output, state) = output_state() - - outputs.append(output) - states.append(state) - - return (outputs, states) - - -def state_saving_rnn(cell, inputs, state_saver, state_name, - sequence_length=None, scope=None): - """RNN that accepts a state saver for time-truncated RNN calculation. - - Args: - cell: An instance of RNNCell. - inputs: A length T list of inputs, each a vector with shape [batch_size]. - state_saver: A state saver object with methods `state` and `save_state`. - state_name: The name to use with the state_saver. - sequence_length: (optional) An int64 vector (tensor) size [batch_size]. - See the documentation for rnn() for more details about sequence_length. - scope: VariableScope for the created subgraph; defaults to "RNN". - - Returns: - A pair (outputs, states) where: - outputs is a length T list of outputs (one for each input) - states is a length T list of states (one state following each input) - - Raises: - TypeError: If "cell" is not an instance of RNNCell. - ValueError: If inputs is None or an empty list. - """ - initial_state = state_saver.state(state_name) - (outputs, states) = rnn(cell, inputs, initial_state=initial_state, - sequence_length=sequence_length, scope=scope) - save_state = state_saver.save_state(state_name, states[-1]) - with tf.control_dependencies([save_state]): - outputs[-1] = tf.identity(outputs[-1]) - - return (outputs, states) +# pylint: disable=g-bad-import-order,wildcard-import,unused-import +import tensorflow.python.platform +from tensorflow.python.ops.rnn import * diff --git a/tensorflow/models/rnn/rnn_cell.py b/tensorflow/models/rnn/rnn_cell.py index bdedfebd7fc..6ff94e8026f 100644 --- a/tensorflow/models/rnn/rnn_cell.py +++ b/tensorflow/models/rnn/rnn_cell.py @@ -12,614 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Import rnn_cell python ops for backward compatibility.""" -"""Module for constructing RNN Cells.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math - -from six.moves import xrange # pylint: disable=redefined-builtin -import tensorflow as tf - -from tensorflow.models.rnn import linear - - -class RNNCell(object): - """Abstract object representing an RNN cell. - - An RNN cell, in the most abstract setting, is anything that has - a state -- a vector of floats of size self.state_size -- and performs some - operation that takes inputs of size self.input_size. This operation - results in an output of size self.output_size and a new state. - - This module provides a number of basic commonly used RNN cells, such as - LSTM (Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number - of operators that allow add dropouts, projections, or embeddings for inputs. - Constructing multi-layer cells is supported by a super-class, MultiRNNCell, - defined later. Every RNNCell must have the properties below and and - implement __call__ with the following signature. - """ - - def __call__(self, inputs, state, scope=None): - """Run this RNN cell on inputs, starting from the given state. - - Args: - inputs: 2D Tensor with shape [batch_size x self.input_size]. - state: 2D Tensor with shape [batch_size x self.state_size]. - scope: VariableScope for the created subgraph; defaults to class name. - - Returns: - A pair containing: - - Output: A 2D Tensor with shape [batch_size x self.output_size] - - New state: A 2D Tensor with shape [batch_size x self.state_size]. - """ - raise NotImplementedError("Abstract method") - - @property - def input_size(self): - """Integer: size of inputs accepted by this cell.""" - raise NotImplementedError("Abstract method") - - @property - def output_size(self): - """Integer: size of outputs produced by this cell.""" - raise NotImplementedError("Abstract method") - - @property - def state_size(self): - """Integer: size of state used by this cell.""" - raise NotImplementedError("Abstract method") - - def zero_state(self, batch_size, dtype): - """Return state tensor (shape [batch_size x state_size]) filled with 0. - - Args: - batch_size: int, float, or unit Tensor representing the batch size. - dtype: the data type to use for the state. - - Returns: - A 2D Tensor of shape [batch_size x state_size] filled with zeros. - """ - zeros = tf.zeros(tf.pack([batch_size, self.state_size]), dtype=dtype) - zeros.set_shape([None, self.state_size]) - return zeros - - -class BasicRNNCell(RNNCell): - """The most basic RNN cell.""" - - def __init__(self, num_units): - self._num_units = num_units - - @property - def input_size(self): - return self._num_units - - @property - def output_size(self): - return self._num_units - - @property - def state_size(self): - return self._num_units - - def __call__(self, inputs, state, scope=None): - """Most basic RNN: output = new_state = tanh(W * input + U * state + B).""" - with tf.variable_scope(scope or type(self).__name__): # "BasicRNNCell" - output = tf.tanh(linear.linear([inputs, state], self._num_units, True)) - return output, output - - -class GRUCell(RNNCell): - """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).""" - - def __init__(self, num_units): - self._num_units = num_units - - @property - def input_size(self): - return self._num_units - - @property - def output_size(self): - return self._num_units - - @property - def state_size(self): - return self._num_units - - def __call__(self, inputs, state, scope=None): - """Gated recurrent unit (GRU) with nunits cells.""" - with tf.variable_scope(scope or type(self).__name__): # "GRUCell" - with tf.variable_scope("Gates"): # Reset gate and update gate. - # We start with bias of 1.0 to not reset and not udpate. - r, u = tf.split(1, 2, linear.linear([inputs, state], - 2 * self._num_units, True, 1.0)) - r, u = tf.sigmoid(r), tf.sigmoid(u) - with tf.variable_scope("Candidate"): - c = tf.tanh(linear.linear([inputs, r * state], self._num_units, True)) - new_h = u * state + (1 - u) * c - return new_h, new_h - - -class BasicLSTMCell(RNNCell): - """Basic LSTM recurrent network cell. - - The implementation is based on: http://arxiv.org/pdf/1409.2329v5.pdf. - - It does not allow cell clipping, a projection layer, and does not - use peep-hole connections: it is the basic baseline. - - Biases of the forget gate are initialized by default to 1 in order to reduce - the scale of forgetting in the beginning of the training. - """ - - def __init__(self, num_units, forget_bias=1.0): - self._num_units = num_units - self._forget_bias = forget_bias - - @property - def input_size(self): - return self._num_units - - @property - def output_size(self): - return self._num_units - - @property - def state_size(self): - return 2 * self._num_units - - def __call__(self, inputs, state, scope=None): - """Long short-term memory cell (LSTM).""" - with tf.variable_scope(scope or type(self).__name__): # "BasicLSTMCell" - # Parameters of gates are concatenated into one multiply for efficiency. - c, h = tf.split(1, 2, state) - concat = linear.linear([inputs, h], 4 * self._num_units, True) - - # i = input_gate, j = new_input, f = forget_gate, o = output_gate - i, j, f, o = tf.split(1, 4, concat) - - new_c = c * tf.sigmoid(f + self._forget_bias) + tf.sigmoid(i) * tf.tanh(j) - new_h = tf.tanh(new_c) * tf.sigmoid(o) - - return new_h, tf.concat(1, [new_c, new_h]) - - -class LSTMCell(RNNCell): - """Long short-term memory unit (LSTM) recurrent network cell. - - This implementation is based on: - - https://research.google.com/pubs/archive/43905.pdf - - Hasim Sak, Andrew Senior, and Francoise Beaufays. - "Long short-term memory recurrent neural network architectures for - large scale acoustic modeling." INTERSPEECH, 2014. - - It uses peep-hole connections, optional cell clipping, and an optional - projection layer. - """ - - def __init__(self, num_units, input_size, - use_peepholes=False, cell_clip=None, - initializer=None, num_proj=None, - num_unit_shards=1, num_proj_shards=1): - """Initialize the parameters for an LSTM cell. - - Args: - num_units: int, The number of units in the LSTM cell - input_size: int, The dimensionality of the inputs into the LSTM cell - use_peepholes: bool, set True to enable diagonal/peephole connections. - cell_clip: (optional) A float value, if provided the cell state is clipped - by this value prior to the cell output activation. - initializer: (optional) The initializer to use for the weight and - projection matrices. - num_proj: (optional) int, The output dimensionality for the projection - matrices. If None, no projection is performed. - num_unit_shards: How to split the weight matrix. If >1, the weight - matrix is stored across num_unit_shards. - Note that num_unit_shards must evenly divide num_units * 4. - num_proj_shards: How to split the projection matrix. If >1, the - projection matrix is stored across num_proj_shards. - Note that num_proj_shards must evenly divide num_proj - (if num_proj is not None). - - Raises: - ValueError: if num_unit_shards doesn't divide 4 * num_units or - num_proj_shards doesn't divide num_proj - """ - self._num_units = num_units - self._input_size = input_size - self._use_peepholes = use_peepholes - self._cell_clip = cell_clip - self._initializer = initializer - self._num_proj = num_proj - self._num_unit_shards = num_unit_shards - self._num_proj_shards = num_proj_shards - - if (num_units * 4) % num_unit_shards != 0: - raise ValueError("num_unit_shards must evently divide 4 * num_units") - if num_proj and num_proj % num_proj_shards != 0: - raise ValueError("num_proj_shards must evently divide num_proj") - - if num_proj: - self._state_size = num_units + num_proj - self._output_size = num_proj - else: - self._state_size = 2 * num_units - self._output_size = num_units - - @property - def input_size(self): - return self._input_size - - @property - def output_size(self): - return self._output_size - - @property - def state_size(self): - return self._state_size - - def __call__(self, input_, state, scope=None): - """Run one step of LSTM. - - Args: - input_: input Tensor, 2D, batch x num_units. - state: state Tensor, 2D, batch x state_size. - scope: VariableScope for the created subgraph; defaults to "LSTMCell". - - Returns: - A tuple containing: - - A 2D, batch x output_dim, Tensor representing the output of the LSTM - after reading "input_" when previous state was "state". - Here output_dim is: - num_proj if num_proj was set, - num_units otherwise. - - A 2D, batch x state_size, Tensor representing the new state of LSTM - after reading "input_" when previous state was "state". - """ - num_proj = self._num_units if self._num_proj is None else self._num_proj - - c_prev = tf.slice(state, [0, 0], [-1, self._num_units]) - m_prev = tf.slice(state, [0, self._num_units], [-1, num_proj]) - - dtype = input_.dtype - - unit_shard_size = (4 * self._num_units) // self._num_unit_shards - - with tf.variable_scope(scope or type(self).__name__): # "LSTMCell" - w = tf.concat( - 1, - [tf.get_variable("W_%d" % i, - shape=[self.input_size + num_proj, unit_shard_size], - initializer=self._initializer, - dtype=dtype) for i in xrange(self._num_unit_shards)]) - - b = tf.get_variable( - "B", shape=[4 * self._num_units], - initializer=tf.zeros_initializer, dtype=dtype) - - # i = input_gate, j = new_input, f = forget_gate, o = output_gate - cell_inputs = tf.concat(1, [input_, m_prev]) - i, j, f, o = tf.split(1, 4, tf.nn.bias_add(tf.matmul(cell_inputs, w), b)) - - # Diagonal connections - if self._use_peepholes: - w_f_diag = tf.get_variable( - "W_F_diag", shape=[self._num_units], dtype=dtype) - w_i_diag = tf.get_variable( - "W_I_diag", shape=[self._num_units], dtype=dtype) - w_o_diag = tf.get_variable( - "W_O_diag", shape=[self._num_units], dtype=dtype) - - if self._use_peepholes: - c = (tf.sigmoid(f + 1 + w_f_diag * c_prev) * c_prev + - tf.sigmoid(i + w_i_diag * c_prev) * tf.tanh(j)) - else: - c = (tf.sigmoid(f + 1) * c_prev + tf.sigmoid(i) * tf.tanh(j)) - - if self._cell_clip is not None: - c = tf.clip_by_value(c, -self._cell_clip, self._cell_clip) - - if self._use_peepholes: - m = tf.sigmoid(o + w_o_diag * c) * tf.tanh(c) - else: - m = tf.sigmoid(o) * tf.tanh(c) - - if self._num_proj is not None: - proj_shard_size = self._num_proj // self._num_proj_shards - w_proj = tf.concat( - 1, - [tf.get_variable("W_P_%d" % i, - shape=[self._num_units, proj_shard_size], - initializer=self._initializer, - dtype=dtype) - for i in xrange(self._num_proj_shards)]) - # TODO(ebrevdo), use matmulsum - m = tf.matmul(m, w_proj) - - return m, tf.concat(1, [c, m]) - - -class OutputProjectionWrapper(RNNCell): - """Operator adding an output projection to the given cell. - - Note: in many cases it may be more efficient to not use this wrapper, - but instead concatenate the whole sequence of your outputs in time, - do the projection on this batch-concated sequence, then split it - if needed or directly feed into a softmax. - """ - - def __init__(self, cell, output_size): - """Create a cell with output projection. - - Args: - cell: an RNNCell, a projection to output_size is added to it. - output_size: integer, the size of the output after projection. - - Raises: - TypeError: if cell is not an RNNCell. - ValueError: if output_size is not positive. - """ - if not isinstance(cell, RNNCell): - raise TypeError("The parameter cell is not RNNCell.") - if output_size < 1: - raise ValueError("Parameter output_size must be > 0: %d." % output_size) - self._cell = cell - self._output_size = output_size - - @property - def input_size(self): - return self._cell.input_size - - @property - def output_size(self): - return self._output_size - - @property - def state_size(self): - return self._cell.state_size - - def __call__(self, inputs, state, scope=None): - """Run the cell and output projection on inputs, starting from state.""" - output, res_state = self._cell(inputs, state) - # Default scope: "OutputProjectionWrapper" - with tf.variable_scope(scope or type(self).__name__): - projected = linear.linear(output, self._output_size, True) - return projected, res_state - - -class InputProjectionWrapper(RNNCell): - """Operator adding an input projection to the given cell. - - Note: in many cases it may be more efficient to not use this wrapper, - but instead concatenate the whole sequence of your inputs in time, - do the projection on this batch-concated sequence, then split it. - """ - - def __init__(self, cell, input_size): - """Create a cell with input projection. - - Args: - cell: an RNNCell, a projection of inputs is added before it. - input_size: integer, the size of the inputs before projection. - - Raises: - TypeError: if cell is not an RNNCell. - ValueError: if input_size is not positive. - """ - if not isinstance(cell, RNNCell): - raise TypeError("The parameter cell is not RNNCell.") - if input_size < 1: - raise ValueError("Parameter input_size must be > 0: %d." % input_size) - self._cell = cell - self._input_size = input_size - - @property - def input_size(self): - return self._input_size - - @property - def output_size(self): - return self._cell.output_size - - @property - def state_size(self): - return self._cell.state_size - - def __call__(self, inputs, state, scope=None): - """Run the input projection and then the cell.""" - # Default scope: "InputProjectionWrapper" - with tf.variable_scope(scope or type(self).__name__): - projected = linear.linear(inputs, self._cell.input_size, True) - return self._cell(projected, state) - - -class DropoutWrapper(RNNCell): - """Operator adding dropout to inputs and outputs of the given cell.""" - - def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0, - seed=None): - """Create a cell with added input and/or output dropout. - - Dropout is never used on the state. - - Args: - cell: an RNNCell, a projection to output_size is added to it. - input_keep_prob: unit Tensor or float between 0 and 1, input keep - probability; if it is float and 1, no input dropout will be added. - output_keep_prob: unit Tensor or float between 0 and 1, output keep - probability; if it is float and 1, no output dropout will be added. - seed: (optional) integer, the randomness seed. - - Raises: - TypeError: if cell is not an RNNCell. - ValueError: if keep_prob is not between 0 and 1. - """ - if not isinstance(cell, RNNCell): - raise TypeError("The parameter cell is not a RNNCell.") - if (isinstance(input_keep_prob, float) and - not (input_keep_prob >= 0.0 and input_keep_prob <= 1.0)): - raise ValueError("Parameter input_keep_prob must be between 0 and 1: %d" - % input_keep_prob) - if (isinstance(output_keep_prob, float) and - not (output_keep_prob >= 0.0 and output_keep_prob <= 1.0)): - raise ValueError("Parameter input_keep_prob must be between 0 and 1: %d" - % output_keep_prob) - self._cell = cell - self._input_keep_prob = input_keep_prob - self._output_keep_prob = output_keep_prob - self._seed = seed - - @property - def input_size(self): - return self._cell.input_size - - @property - def output_size(self): - return self._cell.output_size - - @property - def state_size(self): - return self._cell.state_size - - def __call__(self, inputs, state): - """Run the cell with the declared dropouts.""" - if (not isinstance(self._input_keep_prob, float) or - self._input_keep_prob < 1): - inputs = tf.nn.dropout(inputs, self._input_keep_prob, seed=self._seed) - output, new_state = self._cell(inputs, state) - if (not isinstance(self._output_keep_prob, float) or - self._output_keep_prob < 1): - output = tf.nn.dropout(output, self._output_keep_prob, seed=self._seed) - return output, new_state - - -class EmbeddingWrapper(RNNCell): - """Operator adding input embedding to the given cell. - - Note: in many cases it may be more efficient to not use this wrapper, - but instead concatenate the whole sequence of your inputs in time, - do the embedding on this batch-concated sequence, then split it and - feed into your RNN. - """ - - def __init__(self, cell, embedding_classes=0, embedding=None, - initializer=None): - """Create a cell with an added input embedding. - - Args: - cell: an RNNCell, an embedding will be put before its inputs. - embedding_classes: integer, how many symbols will be embedded. - embedding: Variable, the embedding to use; if None, a new embedding - will be created; if set, then embedding_classes is not required. - initializer: an initializer to use when creating the embedding; - if None, the initializer from variable scope or a default one is used. - - Raises: - TypeError: if cell is not an RNNCell. - ValueError: if embedding_classes is not positive. - """ - if not isinstance(cell, RNNCell): - raise TypeError("The parameter cell is not RNNCell.") - if embedding_classes < 1 and embedding is None: - raise ValueError("Pass embedding or embedding_classes must be > 0: %d." - % embedding_classes) - if embedding_classes > 0 and embedding is not None: - if embedding.size[0] != embedding_classes: - raise ValueError("You declared embedding_classes=%d but passed an " - "embedding for %d classes." % (embedding.size[0], - embedding_classes)) - if embedding.size[1] != cell.input_size: - raise ValueError("You passed embedding with output size %d and a cell" - " that accepts size %d." % (embedding.size[1], - cell.input_size)) - self._cell = cell - self._embedding_classes = embedding_classes - self._embedding = embedding - self._initializer = initializer - - @property - def input_size(self): - return 1 - - @property - def output_size(self): - return self._cell.output_size - - @property - def state_size(self): - return self._cell.state_size - - def __call__(self, inputs, state, scope=None): - """Run the cell on embedded inputs.""" - with tf.variable_scope(scope or type(self).__name__): # "EmbeddingWrapper" - with tf.device("/cpu:0"): - if self._embedding: - embedding = self._embedding - else: - if self._initializer: - initializer = self._initializer - elif tf.get_variable_scope().initializer: - initializer = tf.get_variable_scope().initializer - else: - # Default initializer for embeddings should have variance=1. - sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1. - initializer = tf.random_uniform_initializer(-sqrt3, sqrt3) - embedding = tf.get_variable("embedding", [self._embedding_classes, - self._cell.input_size], - initializer=initializer) - embedded = tf.nn.embedding_lookup(embedding, tf.reshape(inputs, [-1])) - return self._cell(embedded, state) - - -class MultiRNNCell(RNNCell): - """RNN cell composed sequentially of multiple simple cells.""" - - def __init__(self, cells): - """Create a RNN cell composed sequentially of a number of RNNCells. - - Args: - cells: list of RNNCells that will be composed in this order. - - Raises: - ValueError: if cells is empty (not allowed) or if their sizes don't match. - """ - if not cells: - raise ValueError("Must specify at least one cell for MultiRNNCell.") - for i in xrange(len(cells) - 1): - if cells[i + 1].input_size != cells[i].output_size: - raise ValueError("In MultiRNNCell, the input size of each next" - " cell must match the output size of the previous one." - " Mismatched output size in cell %d." % i) - self._cells = cells - - @property - def input_size(self): - return self._cells[0].input_size - - @property - def output_size(self): - return self._cells[-1].output_size - - @property - def state_size(self): - return sum([cell.state_size for cell in self._cells]) - - def __call__(self, inputs, state, scope=None): - """Run this multi-layer cell on inputs, starting from state.""" - with tf.variable_scope(scope or type(self).__name__): # "MultiRNNCell" - cur_state_pos = 0 - cur_inp = inputs - new_states = [] - for i, cell in enumerate(self._cells): - with tf.variable_scope("Cell%d" % i): - cur_state = tf.slice(state, [0, cur_state_pos], [-1, cell.state_size]) - cur_state_pos += cell.state_size - cur_inp, new_state = cell(cur_inp, cur_state) - new_states.append(new_state) - return cur_inp, tf.concat(1, new_states) +# pylint: disable=g-bad-import-order,wildcard-import,unused-import +import tensorflow.python.platform +from tensorflow.python.ops.rnn_cell import * diff --git a/tensorflow/models/rnn/seq2seq.py b/tensorflow/models/rnn/seq2seq.py index 77782ee9347..3732a096922 100644 --- a/tensorflow/models/rnn/seq2seq.py +++ b/tensorflow/models/rnn/seq2seq.py @@ -12,757 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Import seq2seq python ops for backward compatibility.""" -"""Library for creating sequence-to-sequence models.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +# pylint: disable=g-bad-import-order,wildcard-import,unused-import import tensorflow.python.platform - -from six.moves import xrange # pylint: disable=redefined-builtin -import tensorflow as tf - -from tensorflow.models.rnn import linear -from tensorflow.models.rnn import rnn -from tensorflow.models.rnn import rnn_cell - - -def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None, - scope=None): - """RNN decoder for the sequence-to-sequence model. - - Args: - decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. - initial_state: 2D Tensor with shape [batch_size x cell.state_size]. - cell: rnn_cell.RNNCell defining the cell function and size. - loop_function: if not None, this function will be applied to i-th output - in order to generate i+1-th input, and decoder_inputs will be ignored, - except for the first element ("GO" symbol). This can be used for decoding, - but also for training to emulate http://arxiv.org/pdf/1506.03099v2.pdf. - Signature -- loop_function(prev, i) = next - * prev is a 2D Tensor of shape [batch_size x cell.output_size], - * i is an integer, the step number (when advanced control is needed), - * next is a 2D Tensor of shape [batch_size x cell.input_size]. - scope: VariableScope for the created subgraph; defaults to "rnn_decoder". - - Returns: - outputs: A list of the same length as decoder_inputs of 2D Tensors with - shape [batch_size x cell.output_size] containing generated outputs. - states: The state of each cell in each time-step. This is a list with - length len(decoder_inputs) -- one item for each time-step. - Each item is a 2D Tensor of shape [batch_size x cell.state_size]. - (Note that in some cases, like basic RNN cell or GRU cell, outputs and - states can be the same. They are different for LSTM cells though.) - """ - with tf.variable_scope(scope or "rnn_decoder"): - states = [initial_state] - outputs = [] - prev = None - for i in xrange(len(decoder_inputs)): - inp = decoder_inputs[i] - if loop_function is not None and prev is not None: - with tf.variable_scope("loop_function", reuse=True): - # We do not propagate gradients over the loop function. - inp = tf.stop_gradient(loop_function(prev, i)) - if i > 0: - tf.get_variable_scope().reuse_variables() - output, new_state = cell(inp, states[-1]) - outputs.append(output) - states.append(new_state) - if loop_function is not None: - prev = tf.stop_gradient(output) - return outputs, states - - -def basic_rnn_seq2seq( - encoder_inputs, decoder_inputs, cell, dtype=tf.float32, scope=None): - """Basic RNN sequence-to-sequence model. - - This model first runs an RNN to encode encoder_inputs into a state vector, and - then runs decoder, initialized with the last encoder state, on decoder_inputs. - Encoder and decoder use the same RNN cell type, but don't share parameters. - - Args: - encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. - decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. - cell: rnn_cell.RNNCell defining the cell function and size. - dtype: The dtype of the initial state of the RNN cell (default: tf.float32). - scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq". - - Returns: - outputs: A list of the same length as decoder_inputs of 2D Tensors with - shape [batch_size x cell.output_size] containing the generated outputs. - states: The state of each decoder cell in each time-step. This is a list - with length len(decoder_inputs) -- one item for each time-step. - Each item is a 2D Tensor of shape [batch_size x cell.state_size]. - """ - with tf.variable_scope(scope or "basic_rnn_seq2seq"): - _, enc_states = rnn.rnn(cell, encoder_inputs, dtype=dtype) - return rnn_decoder(decoder_inputs, enc_states[-1], cell) - - -def tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, - loop_function=None, dtype=tf.float32, scope=None): - """RNN sequence-to-sequence model with tied encoder and decoder parameters. - - This model first runs an RNN to encode encoder_inputs into a state vector, and - then runs decoder, initialized with the last encoder state, on decoder_inputs. - Encoder and decoder use the same RNN cell and share parameters. - - Args: - encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. - decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. - cell: rnn_cell.RNNCell defining the cell function and size. - loop_function: if not None, this function will be applied to i-th output - in order to generate i+1-th input, and decoder_inputs will be ignored, - except for the first element ("GO" symbol), see rnn_decoder for details. - dtype: The dtype of the initial state of the rnn cell (default: tf.float32). - scope: VariableScope for the created subgraph; default: "tied_rnn_seq2seq". - - Returns: - outputs: A list of the same length as decoder_inputs of 2D Tensors with - shape [batch_size x cell.output_size] containing the generated outputs. - states: The state of each decoder cell in each time-step. This is a list - with length len(decoder_inputs) -- one item for each time-step. - Each item is a 2D Tensor of shape [batch_size x cell.state_size]. - """ - with tf.variable_scope("combined_tied_rnn_seq2seq"): - scope = scope or "tied_rnn_seq2seq" - _, enc_states = rnn.rnn( - cell, encoder_inputs, dtype=dtype, scope=scope) - tf.get_variable_scope().reuse_variables() - return rnn_decoder(decoder_inputs, enc_states[-1], cell, - loop_function=loop_function, scope=scope) - - -def embedding_rnn_decoder(decoder_inputs, initial_state, cell, num_symbols, - output_projection=None, feed_previous=False, - scope=None): - """RNN decoder with embedding and a pure-decoding option. - - Args: - decoder_inputs: a list of 1D batch-sized int32-Tensors (decoder inputs). - initial_state: 2D Tensor [batch_size x cell.state_size]. - cell: rnn_cell.RNNCell defining the cell function. - num_symbols: integer, how many symbols come into the embedding. - output_projection: None or a pair (W, B) of output projection weights and - biases; W has shape [cell.output_size x num_symbols] and B has - shape [num_symbols]; if provided and feed_previous=True, each fed - previous output will first be multiplied by W and added B. - feed_previous: Boolean; if True, only the first of decoder_inputs will be - used (the "GO" symbol), and all other decoder inputs will be generated by: - next = embedding_lookup(embedding, argmax(previous_output)), - In effect, this implements a greedy decoder. It can also be used - during training to emulate http://arxiv.org/pdf/1506.03099v2.pdf. - If False, decoder_inputs are used as given (the standard decoder case). - scope: VariableScope for the created subgraph; defaults to - "embedding_rnn_decoder". - - Returns: - outputs: A list of the same length as decoder_inputs of 2D Tensors with - shape [batch_size x cell.output_size] containing the generated outputs. - states: The state of each decoder cell in each time-step. This is a list - with length len(decoder_inputs) -- one item for each time-step. - Each item is a 2D Tensor of shape [batch_size x cell.state_size]. - - Raises: - ValueError: when output_projection has the wrong shape. - """ - if output_projection is not None: - proj_weights = tf.convert_to_tensor(output_projection[0], dtype=tf.float32) - proj_weights.get_shape().assert_is_compatible_with([cell.output_size, - num_symbols]) - proj_biases = tf.convert_to_tensor(output_projection[1], dtype=tf.float32) - proj_biases.get_shape().assert_is_compatible_with([num_symbols]) - - with tf.variable_scope(scope or "embedding_rnn_decoder"): - with tf.device("/cpu:0"): - embedding = tf.get_variable("embedding", [num_symbols, cell.input_size]) - - def extract_argmax_and_embed(prev, _): - """Loop_function that extracts the symbol from prev and embeds it.""" - if output_projection is not None: - prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1]) - prev_symbol = tf.stop_gradient(tf.argmax(prev, 1)) - return tf.nn.embedding_lookup(embedding, prev_symbol) - - loop_function = None - if feed_previous: - loop_function = extract_argmax_and_embed - - emb_inp = [tf.nn.embedding_lookup(embedding, i) for i in decoder_inputs] - return rnn_decoder(emb_inp, initial_state, cell, - loop_function=loop_function) - - -def embedding_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, - num_encoder_symbols, num_decoder_symbols, - output_projection=None, feed_previous=False, - dtype=tf.float32, scope=None): - """Embedding RNN sequence-to-sequence model. - - This model first embeds encoder_inputs by a newly created embedding (of shape - [num_encoder_symbols x cell.input_size]). Then it runs an RNN to encode - embedded encoder_inputs into a state vector. Next, it embeds decoder_inputs - by another newly created embedding (of shape [num_decoder_symbols x - cell.input_size]). Then it runs RNN decoder, initialized with the last - encoder state, on embedded decoder_inputs. - - Args: - encoder_inputs: a list of 1D int32-Tensors of shape [batch_size]. - decoder_inputs: a list of 1D int32-Tensors of shape [batch_size]. - cell: rnn_cell.RNNCell defining the cell function and size. - num_encoder_symbols: integer; number of symbols on the encoder side. - num_decoder_symbols: integer; number of symbols on the decoder side. - output_projection: None or a pair (W, B) of output projection weights and - biases; W has shape [cell.output_size x num_decoder_symbols] and B has - shape [num_decoder_symbols]; if provided and feed_previous=True, each - fed previous output will first be multiplied by W and added B. - feed_previous: Boolean or scalar Boolean Tensor; if True, only the first - of decoder_inputs will be used (the "GO" symbol), and all other decoder - inputs will be taken from previous outputs (as in embedding_rnn_decoder). - If False, decoder_inputs are used as given (the standard decoder case). - dtype: The dtype of the initial state for both the encoder and encoder - rnn cells (default: tf.float32). - scope: VariableScope for the created subgraph; defaults to - "embedding_rnn_seq2seq" - - Returns: - outputs: A list of the same length as decoder_inputs of 2D Tensors with - shape [batch_size x num_decoder_symbols] containing the generated outputs. - states: The state of each decoder cell in each time-step. This is a list - with length len(decoder_inputs) -- one item for each time-step. - Each item is a 2D Tensor of shape [batch_size x cell.state_size]. - """ - with tf.variable_scope(scope or "embedding_rnn_seq2seq"): - # Encoder. - encoder_cell = rnn_cell.EmbeddingWrapper(cell, num_encoder_symbols) - _, encoder_states = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype) - - # Decoder. - if output_projection is None: - cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) - - if isinstance(feed_previous, bool): - return embedding_rnn_decoder(decoder_inputs, encoder_states[-1], cell, - num_decoder_symbols, output_projection, - feed_previous) - else: # If feed_previous is a Tensor, we construct 2 graphs and use cond. - outputs1, states1 = embedding_rnn_decoder( - decoder_inputs, encoder_states[-1], cell, num_decoder_symbols, - output_projection, True) - tf.get_variable_scope().reuse_variables() - outputs2, states2 = embedding_rnn_decoder( - decoder_inputs, encoder_states[-1], cell, num_decoder_symbols, - output_projection, False) - - outputs = tf.control_flow_ops.cond(feed_previous, - lambda: outputs1, lambda: outputs2) - states = tf.control_flow_ops.cond(feed_previous, - lambda: states1, lambda: states2) - return outputs, states - - -def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, - num_symbols, output_projection=None, - feed_previous=False, dtype=tf.float32, - scope=None): - """Embedding RNN sequence-to-sequence model with tied (shared) parameters. - - This model first embeds encoder_inputs by a newly created embedding (of shape - [num_symbols x cell.input_size]). Then it runs an RNN to encode embedded - encoder_inputs into a state vector. Next, it embeds decoder_inputs using - the same embedding. Then it runs RNN decoder, initialized with the last - encoder state, on embedded decoder_inputs. - - Args: - encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. - decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. - cell: rnn_cell.RNNCell defining the cell function and size. - num_symbols: integer; number of symbols for both encoder and decoder. - output_projection: None or a pair (W, B) of output projection weights and - biases; W has shape [cell.output_size x num_symbols] and B has - shape [num_symbols]; if provided and feed_previous=True, each - fed previous output will first be multiplied by W and added B. - feed_previous: Boolean or scalar Boolean Tensor; if True, only the first - of decoder_inputs will be used (the "GO" symbol), and all other decoder - inputs will be taken from previous outputs (as in embedding_rnn_decoder). - If False, decoder_inputs are used as given (the standard decoder case). - dtype: The dtype to use for the initial RNN states (default: tf.float32). - scope: VariableScope for the created subgraph; defaults to - "embedding_tied_rnn_seq2seq". - - Returns: - outputs: A list of the same length as decoder_inputs of 2D Tensors with - shape [batch_size x num_decoder_symbols] containing the generated outputs. - states: The state of each decoder cell in each time-step. This is a list - with length len(decoder_inputs) -- one item for each time-step. - Each item is a 2D Tensor of shape [batch_size x cell.state_size]. - - Raises: - ValueError: when output_projection has the wrong shape. - """ - if output_projection is not None: - proj_weights = tf.convert_to_tensor(output_projection[0], dtype=dtype) - proj_weights.get_shape().assert_is_compatible_with([cell.output_size, - num_symbols]) - proj_biases = tf.convert_to_tensor(output_projection[1], dtype=dtype) - proj_biases.get_shape().assert_is_compatible_with([num_symbols]) - - with tf.variable_scope(scope or "embedding_tied_rnn_seq2seq"): - with tf.device("/cpu:0"): - embedding = tf.get_variable("embedding", [num_symbols, cell.input_size]) - - emb_encoder_inputs = [tf.nn.embedding_lookup(embedding, x) - for x in encoder_inputs] - emb_decoder_inputs = [tf.nn.embedding_lookup(embedding, x) - for x in decoder_inputs] - - def extract_argmax_and_embed(prev, _): - """Loop_function that extracts the symbol from prev and embeds it.""" - if output_projection is not None: - prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1]) - prev_symbol = tf.stop_gradient(tf.argmax(prev, 1)) - return tf.nn.embedding_lookup(embedding, prev_symbol) - - if output_projection is None: - cell = rnn_cell.OutputProjectionWrapper(cell, num_symbols) - - if isinstance(feed_previous, bool): - loop_function = extract_argmax_and_embed if feed_previous else None - return tied_rnn_seq2seq(emb_encoder_inputs, emb_decoder_inputs, cell, - loop_function=loop_function, dtype=dtype) - else: # If feed_previous is a Tensor, we construct 2 graphs and use cond. - outputs1, states1 = tied_rnn_seq2seq( - emb_encoder_inputs, emb_decoder_inputs, cell, - loop_function=extract_argmax_and_embed, dtype=dtype) - tf.get_variable_scope().reuse_variables() - outputs2, states2 = tied_rnn_seq2seq( - emb_encoder_inputs, emb_decoder_inputs, cell, dtype=dtype) - - outputs = tf.control_flow_ops.cond(feed_previous, - lambda: outputs1, lambda: outputs2) - states = tf.control_flow_ops.cond(feed_previous, - lambda: states1, lambda: states2) - return outputs, states - - -def attention_decoder(decoder_inputs, initial_state, attention_states, cell, - output_size=None, num_heads=1, loop_function=None, - dtype=tf.float32, scope=None): - """RNN decoder with attention for the sequence-to-sequence model. - - Args: - decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. - initial_state: 2D Tensor [batch_size x cell.state_size]. - attention_states: 3D Tensor [batch_size x attn_length x attn_size]. - cell: rnn_cell.RNNCell defining the cell function and size. - output_size: size of the output vectors; if None, we use cell.output_size. - num_heads: number of attention heads that read from attention_states. - loop_function: if not None, this function will be applied to i-th output - in order to generate i+1-th input, and decoder_inputs will be ignored, - except for the first element ("GO" symbol). This can be used for decoding, - but also for training to emulate http://arxiv.org/pdf/1506.03099v2.pdf. - Signature -- loop_function(prev, i) = next - * prev is a 2D Tensor of shape [batch_size x cell.output_size], - * i is an integer, the step number (when advanced control is needed), - * next is a 2D Tensor of shape [batch_size x cell.input_size]. - dtype: The dtype to use for the RNN initial state (default: tf.float32). - scope: VariableScope for the created subgraph; default: "attention_decoder". - - Returns: - outputs: A list of the same length as decoder_inputs of 2D Tensors of shape - [batch_size x output_size]. These represent the generated outputs. - Output i is computed from input i (which is either i-th decoder_inputs or - loop_function(output {i-1}, i)) as follows. First, we run the cell - on a combination of the input and previous attention masks: - cell_output, new_state = cell(linear(input, prev_attn), prev_state). - Then, we calculate new attention masks: - new_attn = softmax(V^T * tanh(W * attention_states + U * new_state)) - and then we calculate the output: - output = linear(cell_output, new_attn). - states: The state of each decoder cell in each time-step. This is a list - with length len(decoder_inputs) -- one item for each time-step. - Each item is a 2D Tensor of shape [batch_size x cell.state_size]. - - Raises: - ValueError: when num_heads is not positive, there are no inputs, or shapes - of attention_states are not set. - """ - if not decoder_inputs: - raise ValueError("Must provide at least 1 input to attention decoder.") - if num_heads < 1: - raise ValueError("With less than 1 heads, use a non-attention decoder.") - if not attention_states.get_shape()[1:2].is_fully_defined(): - raise ValueError("Shape[1] and [2] of attention_states must be known: %s" - % attention_states.get_shape()) - if output_size is None: - output_size = cell.output_size - - with tf.variable_scope(scope or "attention_decoder"): - batch_size = tf.shape(decoder_inputs[0])[0] # Needed for reshaping. - attn_length = attention_states.get_shape()[1].value - attn_size = attention_states.get_shape()[2].value - - # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before. - hidden = tf.reshape(attention_states, [-1, attn_length, 1, attn_size]) - hidden_features = [] - v = [] - attention_vec_size = attn_size # Size of query vectors for attention. - for a in xrange(num_heads): - k = tf.get_variable("AttnW_%d" % a, [1, 1, attn_size, attention_vec_size]) - hidden_features.append(tf.nn.conv2d(hidden, k, [1, 1, 1, 1], "SAME")) - v.append(tf.get_variable("AttnV_%d" % a, [attention_vec_size])) - - states = [initial_state] - - def attention(query): - """Put attention masks on hidden using hidden_features and query.""" - ds = [] # Results of attention reads will be stored here. - for a in xrange(num_heads): - with tf.variable_scope("Attention_%d" % a): - y = linear.linear(query, attention_vec_size, True) - y = tf.reshape(y, [-1, 1, 1, attention_vec_size]) - # Attention mask is a softmax of v^T * tanh(...). - s = tf.reduce_sum(v[a] * tf.tanh(hidden_features[a] + y), [2, 3]) - a = tf.nn.softmax(s) - # Now calculate the attention-weighted vector d. - d = tf.reduce_sum(tf.reshape(a, [-1, attn_length, 1, 1]) * hidden, - [1, 2]) - ds.append(tf.reshape(d, [-1, attn_size])) - return ds - - outputs = [] - prev = None - batch_attn_size = tf.pack([batch_size, attn_size]) - attns = [tf.zeros(batch_attn_size, dtype=dtype) - for _ in xrange(num_heads)] - for a in attns: # Ensure the second shape of attention vectors is set. - a.set_shape([None, attn_size]) - for i in xrange(len(decoder_inputs)): - if i > 0: - tf.get_variable_scope().reuse_variables() - inp = decoder_inputs[i] - # If loop_function is set, we use it instead of decoder_inputs. - if loop_function is not None and prev is not None: - with tf.variable_scope("loop_function", reuse=True): - inp = tf.stop_gradient(loop_function(prev, i)) - # Merge input and previous attentions into one vector of the right size. - x = linear.linear([inp] + attns, cell.input_size, True) - # Run the RNN. - cell_output, new_state = cell(x, states[-1]) - states.append(new_state) - # Run the attention mechanism. - attns = attention(new_state) - with tf.variable_scope("AttnOutputProjection"): - output = linear.linear([cell_output] + attns, output_size, True) - if loop_function is not None: - # We do not propagate gradients over the loop function. - prev = tf.stop_gradient(output) - outputs.append(output) - - return outputs, states - - -def embedding_attention_decoder(decoder_inputs, initial_state, attention_states, - cell, num_symbols, num_heads=1, - output_size=None, output_projection=None, - feed_previous=False, dtype=tf.float32, - scope=None): - """RNN decoder with embedding and attention and a pure-decoding option. - - Args: - decoder_inputs: a list of 1D batch-sized int32-Tensors (decoder inputs). - initial_state: 2D Tensor [batch_size x cell.state_size]. - attention_states: 3D Tensor [batch_size x attn_length x attn_size]. - cell: rnn_cell.RNNCell defining the cell function. - num_symbols: integer, how many symbols come into the embedding. - num_heads: number of attention heads that read from attention_states. - output_size: size of the output vectors; if None, use cell.output_size. - output_projection: None or a pair (W, B) of output projection weights and - biases; W has shape [output_size x num_symbols] and B has shape - [num_symbols]; if provided and feed_previous=True, each fed previous - output will first be multiplied by W and added B. - feed_previous: Boolean; if True, only the first of decoder_inputs will be - used (the "GO" symbol), and all other decoder inputs will be generated by: - next = embedding_lookup(embedding, argmax(previous_output)), - In effect, this implements a greedy decoder. It can also be used - during training to emulate http://arxiv.org/pdf/1506.03099v2.pdf. - If False, decoder_inputs are used as given (the standard decoder case). - dtype: The dtype to use for the RNN initial states (default: tf.float32). - scope: VariableScope for the created subgraph; defaults to - "embedding_attention_decoder". - - Returns: - outputs: A list of the same length as decoder_inputs of 2D Tensors with - shape [batch_size x output_size] containing the generated outputs. - states: The state of each decoder cell in each time-step. This is a list - with length len(decoder_inputs) -- one item for each time-step. - Each item is a 2D Tensor of shape [batch_size x cell.state_size]. - - Raises: - ValueError: when output_projection has the wrong shape. - """ - if output_size is None: - output_size = cell.output_size - if output_projection is not None: - proj_weights = tf.convert_to_tensor(output_projection[0], dtype=dtype) - proj_weights.get_shape().assert_is_compatible_with([cell.output_size, - num_symbols]) - proj_biases = tf.convert_to_tensor(output_projection[1], dtype=dtype) - proj_biases.get_shape().assert_is_compatible_with([num_symbols]) - - with tf.variable_scope(scope or "embedding_attention_decoder"): - with tf.device("/cpu:0"): - embedding = tf.get_variable("embedding", [num_symbols, cell.input_size]) - - def extract_argmax_and_embed(prev, _): - """Loop_function that extracts the symbol from prev and embeds it.""" - if output_projection is not None: - prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1]) - prev_symbol = tf.stop_gradient(tf.argmax(prev, 1)) - emb_prev = tf.nn.embedding_lookup(embedding, prev_symbol) - return emb_prev - - loop_function = None - if feed_previous: - loop_function = extract_argmax_and_embed - - emb_inp = [tf.nn.embedding_lookup(embedding, i) for i in decoder_inputs] - return attention_decoder( - emb_inp, initial_state, attention_states, cell, output_size=output_size, - num_heads=num_heads, loop_function=loop_function) - - -def embedding_attention_seq2seq(encoder_inputs, decoder_inputs, cell, - num_encoder_symbols, num_decoder_symbols, - num_heads=1, output_projection=None, - feed_previous=False, dtype=tf.float32, - scope=None): - """Embedding sequence-to-sequence model with attention. - - This model first embeds encoder_inputs by a newly created embedding (of shape - [num_encoder_symbols x cell.input_size]). Then it runs an RNN to encode - embedded encoder_inputs into a state vector. It keeps the outputs of this - RNN at every step to use for attention later. Next, it embeds decoder_inputs - by another newly created embedding (of shape [num_decoder_symbols x - cell.input_size]). Then it runs attention decoder, initialized with the last - encoder state, on embedded decoder_inputs and attending to encoder outputs. - - Args: - encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. - decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. - cell: rnn_cell.RNNCell defining the cell function and size. - num_encoder_symbols: integer; number of symbols on the encoder side. - num_decoder_symbols: integer; number of symbols on the decoder side. - num_heads: number of attention heads that read from attention_states. - output_projection: None or a pair (W, B) of output projection weights and - biases; W has shape [cell.output_size x num_decoder_symbols] and B has - shape [num_decoder_symbols]; if provided and feed_previous=True, each - fed previous output will first be multiplied by W and added B. - feed_previous: Boolean or scalar Boolean Tensor; if True, only the first - of decoder_inputs will be used (the "GO" symbol), and all other decoder - inputs will be taken from previous outputs (as in embedding_rnn_decoder). - If False, decoder_inputs are used as given (the standard decoder case). - dtype: The dtype of the initial RNN state (default: tf.float32). - scope: VariableScope for the created subgraph; defaults to - "embedding_attention_seq2seq". - - Returns: - outputs: A list of the same length as decoder_inputs of 2D Tensors with - shape [batch_size x num_decoder_symbols] containing the generated outputs. - states: The state of each decoder cell in each time-step. This is a list - with length len(decoder_inputs) -- one item for each time-step. - Each item is a 2D Tensor of shape [batch_size x cell.state_size]. - """ - with tf.variable_scope(scope or "embedding_attention_seq2seq"): - # Encoder. - encoder_cell = rnn_cell.EmbeddingWrapper(cell, num_encoder_symbols) - encoder_outputs, encoder_states = rnn.rnn( - encoder_cell, encoder_inputs, dtype=dtype) - - # First calculate a concatenation of encoder outputs to put attention on. - top_states = [tf.reshape(e, [-1, 1, cell.output_size]) - for e in encoder_outputs] - attention_states = tf.concat(1, top_states) - - # Decoder. - output_size = None - if output_projection is None: - cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) - output_size = num_decoder_symbols - - if isinstance(feed_previous, bool): - return embedding_attention_decoder( - decoder_inputs, encoder_states[-1], attention_states, cell, - num_decoder_symbols, num_heads, output_size, output_projection, - feed_previous) - else: # If feed_previous is a Tensor, we construct 2 graphs and use cond. - outputs1, states1 = embedding_attention_decoder( - decoder_inputs, encoder_states[-1], attention_states, cell, - num_decoder_symbols, num_heads, output_size, output_projection, True) - tf.get_variable_scope().reuse_variables() - outputs2, states2 = embedding_attention_decoder( - decoder_inputs, encoder_states[-1], attention_states, cell, - num_decoder_symbols, num_heads, output_size, output_projection, False) - - outputs = tf.control_flow_ops.cond(feed_previous, - lambda: outputs1, lambda: outputs2) - states = tf.control_flow_ops.cond(feed_previous, - lambda: states1, lambda: states2) - return outputs, states - - -def sequence_loss_by_example(logits, targets, weights, num_decoder_symbols, - average_across_timesteps=True, - softmax_loss_function=None, name=None): - """Weighted cross-entropy loss for a sequence of logits (per example). - - Args: - logits: list of 2D Tensors of shape [batch_size x num_decoder_symbols]. - targets: list of 1D batch-sized int32-Tensors of the same length as logits. - weights: list of 1D batch-sized float-Tensors of the same length as logits. - num_decoder_symbols: integer, number of decoder symbols (output classes). - average_across_timesteps: If set, divide the returned cost by the total - label weight. - softmax_loss_function: function (inputs-batch, labels-batch) -> loss-batch - to be used instead of the standard softmax (the default if this is None). - name: optional name for this operation, default: "sequence_loss_by_example". - - Returns: - 1D batch-sized float Tensor: the log-perplexity for each sequence. - - Raises: - ValueError: if len(logits) is different from len(targets) or len(weights). - """ - if len(targets) != len(logits) or len(weights) != len(logits): - raise ValueError("Lengths of logits, weights, and targets must be the same " - "%d, %d, %d." % (len(logits), len(weights), len(targets))) - with tf.op_scope(logits + targets + weights, name, - "sequence_loss_by_example"): - batch_size = tf.shape(targets[0])[0] - log_perp_list = [] - length = batch_size * num_decoder_symbols - for i in xrange(len(logits)): - if softmax_loss_function is None: - # TODO(lukaszkaiser): There is no SparseCrossEntropy in TensorFlow, so - # we need to first cast targets into a dense representation, and as - # SparseToDense does not accept batched inputs, we need to do this by - # re-indexing and re-sizing. When TensorFlow adds SparseCrossEntropy, - # rewrite this method. - indices = targets[i] + num_decoder_symbols * tf.range(batch_size) - with tf.device("/cpu:0"): # Sparse-to-dense must happen on CPU for now. - dense = tf.sparse_to_dense(indices, tf.expand_dims(length, 0), 1.0, - 0.0) - target = tf.reshape(dense, [-1, num_decoder_symbols]) - crossent = tf.nn.softmax_cross_entropy_with_logits( - logits[i], target, name="SequenceLoss/CrossEntropy{0}".format(i)) - else: - crossent = softmax_loss_function(logits[i], targets[i]) - log_perp_list.append(crossent * weights[i]) - log_perps = tf.add_n(log_perp_list) - if average_across_timesteps: - total_size = tf.add_n(weights) - total_size += 1e-12 # Just to avoid division by 0 for all-0 weights. - log_perps /= total_size - return log_perps - - -def sequence_loss(logits, targets, weights, num_decoder_symbols, - average_across_timesteps=True, average_across_batch=True, - softmax_loss_function=None, name=None): - """Weighted cross-entropy loss for a sequence of logits, batch-collapsed. - - Args: - logits: list of 2D Tensors os shape [batch_size x num_decoder_symbols]. - targets: list of 1D batch-sized int32-Tensors of the same length as logits. - weights: list of 1D batch-sized float-Tensors of the same length as logits. - num_decoder_symbols: integer, number of decoder symbols (output classes). - average_across_timesteps: If set, divide the returned cost by the total - label weight. - average_across_batch: If set, divide the returned cost by the batch size. - softmax_loss_function: function (inputs-batch, labels-batch) -> loss-batch - to be used instead of the standard softmax (the default if this is None). - name: optional name for this operation, defaults to "sequence_loss". - - Returns: - A scalar float Tensor: the average log-perplexity per symbol (weighted). - - Raises: - ValueError: if len(logits) is different from len(targets) or len(weights). - """ - with tf.op_scope(logits + targets + weights, name, "sequence_loss"): - cost = tf.reduce_sum(sequence_loss_by_example( - logits, targets, weights, num_decoder_symbols, - average_across_timesteps=average_across_timesteps, - softmax_loss_function=softmax_loss_function)) - if average_across_batch: - batch_size = tf.shape(targets[0])[0] - return cost / tf.cast(batch_size, tf.float32) - else: - return cost - - -def model_with_buckets(encoder_inputs, decoder_inputs, targets, weights, - buckets, num_decoder_symbols, seq2seq, - softmax_loss_function=None, name=None): - """Create a sequence-to-sequence model with support for bucketing. - - The seq2seq argument is a function that defines a sequence-to-sequence model, - e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(x, y, rnn_cell.GRUCell(24)) - - Args: - encoder_inputs: a list of Tensors to feed the encoder; first seq2seq input. - decoder_inputs: a list of Tensors to feed the decoder; second seq2seq input. - targets: a list of 1D batch-sized int32-Tensors (desired output sequence). - weights: list of 1D batch-sized float-Tensors to weight the targets. - buckets: a list of pairs of (input size, output size) for each bucket. - num_decoder_symbols: integer, number of decoder symbols (output classes). - seq2seq: a sequence-to-sequence model function; it takes 2 input that - agree with encoder_inputs and decoder_inputs, and returns a pair - consisting of outputs and states (as, e.g., basic_rnn_seq2seq). - softmax_loss_function: function (inputs-batch, labels-batch) -> loss-batch - to be used instead of the standard softmax (the default if this is None). - name: optional name for this operation, defaults to "model_with_buckets". - - Returns: - outputs: The outputs for each bucket. Its j'th element consists of a list - of 2D Tensors of shape [batch_size x num_decoder_symbols] (j'th outputs). - losses: List of scalar Tensors, representing losses for each bucket. - Raises: - ValueError: if length of encoder_inputsut, targets, or weights is smaller - than the largest (last) bucket. - """ - if len(encoder_inputs) < buckets[-1][0]: - raise ValueError("Length of encoder_inputs (%d) must be at least that of la" - "st bucket (%d)." % (len(encoder_inputs), buckets[-1][0])) - if len(targets) < buckets[-1][1]: - raise ValueError("Length of targets (%d) must be at least that of last" - "bucket (%d)." % (len(targets), buckets[-1][1])) - if len(weights) < buckets[-1][1]: - raise ValueError("Length of weights (%d) must be at least that of last" - "bucket (%d)." % (len(weights), buckets[-1][1])) - - all_inputs = encoder_inputs + decoder_inputs + targets + weights - losses = [] - outputs = [] - with tf.op_scope(all_inputs, name, "model_with_buckets"): - for j in xrange(len(buckets)): - if j > 0: - tf.get_variable_scope().reuse_variables() - bucket_encoder_inputs = [encoder_inputs[i] - for i in xrange(buckets[j][0])] - bucket_decoder_inputs = [decoder_inputs[i] - for i in xrange(buckets[j][1])] - bucket_outputs, _ = seq2seq(bucket_encoder_inputs, - bucket_decoder_inputs) - outputs.append(bucket_outputs) - - bucket_targets = [targets[i] for i in xrange(buckets[j][1])] - bucket_weights = [weights[i] for i in xrange(buckets[j][1])] - losses.append(sequence_loss( - outputs[-1], bucket_targets, bucket_weights, num_decoder_symbols, - softmax_loss_function=softmax_loss_function)) - - return outputs, losses +from tensorflow.python.ops.seq2seq import * diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 76126b30435..49f42dd6f3b 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -34,6 +34,7 @@ py_library( ":client_testlib", ":framework", ":framework_test_lib", + ":kernel_tests/gradient_checker", ":platform", ":platform_test", ":summary", @@ -467,6 +468,7 @@ tf_gen_op_wrapper_py( "ReluGrad", "Relu6Grad", "SoftplusGrad", + "SoftsignGrad", "BiasAdd", "Relu6", "AvgPool", @@ -588,6 +590,9 @@ py_library( "ops/op_def_library.py", "ops/parsing_ops.py", "ops/random_ops.py", + "ops/rnn.py", + "ops/rnn_cell.py", + "ops/seq2seq.py", "ops/sparse_grad.py", "ops/sparse_ops.py", "ops/standard_ops.py", diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py index 9646ef6673d..7c4018332d8 100644 --- a/tensorflow/python/framework/gen_docs_combined.py +++ b/tensorflow/python/framework/gen_docs_combined.py @@ -93,8 +93,8 @@ def all_libraries(module_to_name, members, documented): "max_pool_grad", "max_pool_grad_with_argmax", "batch_norm_with_global_normalization_grad", "lrn_grad", "relu6_grad", "softplus_grad", - "xw_plus_b", "relu_layer", "lrn", - "batch_norm_with_global_normalization", + "softsign_grad", "xw_plus_b", "relu_layer", + "lrn", "batch_norm_with_global_normalization", "batch_norm_with_global_normalization_grad", "all_candidate_sampler", "embedding_lookup_sparse"], diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index d66e93300d9..352c73c0f77 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -442,8 +442,8 @@ class Tensor(object): return _eval_using_default_session(self, feed_dict, self.graph, session) -def _TensorTensorConversionFunction(t, dtype=None, name=None): - _ = name +def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False): + _ = name, as_ref if dtype and not dtype.is_compatible_with(t.dtype): raise ValueError( "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" @@ -455,7 +455,7 @@ _tensor_conversion_func_registry = { 0: [(Tensor, _TensorTensorConversionFunction)]} -def convert_to_tensor(value, dtype=None, name=None): +def convert_to_tensor(value, dtype=None, name=None, as_ref=False): """Converts the given `value` to a `Tensor`. This function converts Python objects of various types to `Tensor` @@ -487,6 +487,7 @@ def convert_to_tensor(value, dtype=None, name=None): dtype: Optional element type for the returned tensor. If missing, the type is inferred from the type of `value`. name: Optional name to use if a new `Tensor` is created. + as_ref: True if we want the result as a ref tensor. Returns: A `Tensor` based on `value`. @@ -502,7 +503,7 @@ def convert_to_tensor(value, dtype=None, name=None): for _, funcs_at_priority in sorted(_tensor_conversion_func_registry.items()): for base_type, conversion_func in funcs_at_priority: if isinstance(value, base_type): - ret = conversion_func(value, dtype=dtype, name=name) + ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref) if not isinstance(ret, Tensor): raise RuntimeError( "%sConversion function %r for type %s returned non-Tensor: %r" @@ -519,7 +520,8 @@ def convert_to_tensor(value, dtype=None, name=None): % (error_prefix, value, type(value))) -def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None): +def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None, + as_ref=False): """Converts the given object to a `Tensor` or an `IndexedSlices`. If `value` is an `IndexedSlices` it is returned @@ -532,6 +534,7 @@ def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None): dtype: (Optional.) The required `DType` of the returned `Tensor` or `IndexedSlices`. name: (Optional.) A name to use if a new `Tensor` is created. + as_ref: True if the caller wants the results as ref tensors. Returns: An `Tensor` or an `IndexedSlices` based on `value`. @@ -546,10 +549,11 @@ def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None): % (dtypes.as_dtype(dtype).name, value.dtype.name, str(value))) return value else: - return convert_to_tensor(value, dtype, name) + return convert_to_tensor(value, dtype=dtype, name=name, as_ref=as_ref) -def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None): +def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None, + as_ref=False): """Converts `values` to a list of `Tensor` or `IndexedSlices` objects. Args: @@ -557,10 +561,10 @@ def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None): by `convert_to_tensor()`. dtype: (Optional.) The required `DType` of the returned `Tensor` `IndexedSlices`. - name: (Optional.) A name prefix to used when a new `Tensor` is created, in which case element `i` will be given the name `name + '_' + i`. + as_ref: True if the caller wants the results as ref tensors. Returns: A list of `Tensor` and/or `IndexedSlices` objects. @@ -580,7 +584,8 @@ def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None): else: n = None if name is None else "%s_%d" % (name, i) ret.append( - convert_to_tensor_or_indexed_slices(value, dtype=dtype, name=n)) + convert_to_tensor_or_indexed_slices(value, dtype=dtype, name=n, + as_ref=as_ref)) return ret @@ -590,13 +595,16 @@ def register_tensor_conversion_function(base_type, conversion_func, The conversion function must have the following signature: - def conversion_func(value, dtype=None, name=None): + def conversion_func(value, dtype=None, name=None, as_ref=False): # ... It must return a Tensor with the given dtype if specified. If the conversion function creates a new Tensor, it should use the given name if specified. All exceptions will be propagated to the caller. + If `as_ref` is true, the function must return a Tensor reference, + such as a VariableOp. + NOTE: The conversion functions will execute in order of priority, followed by order of registration. To ensure that a conversion function F runs before another conversion function G, ensure that @@ -762,23 +770,23 @@ class SparseTensor(object): ``` By convention, `indices` should be sorted in row-major order (or equivalently - lexigraphic order on the tuples `indices[i]`). This is not enforced when - `SparseTensor` objects are constructed, but most Ops assume correct ordering. + lexicographic order on the tuples `indices[i]`). This is not enforced when + `SparseTensor` objects are constructed, but most ops assume correct ordering. If the ordering is wrong, it can be fixed by calling `sparse_reorder` on the misordered `SparseTensor`. Example: The sparse tensor ```python - SparseTensor(values=[1, 2], indices=[[0, 0], [1, 2]], shape=[3, 4]) + SparseTensor(values=[1, 2], indices=[[0, 0], [1, 2]], shape=[3, 4]) ``` represents the dense tensor ```python - [[1, 0, 0, 0] - [0, 0, 2, 0] - [0, 0, 0, 0]] + [[1, 0, 0, 0] + [0, 0, 2, 0] + [0, 0, 0, 0]] ``` @@__init__ @@ -795,14 +803,18 @@ class SparseTensor(object): Args: indices: A 2-D int64 tensor of shape `[N, ndims]`. values: A 1-D tensor of any type and shape `[N]`. - dense_shape: A 1-D int64 tensor of shape `[ndims]`. + shape: A 1-D int64 tensor of shape `[ndims]`. Returns: A `SparseTensor` """ with op_scope([indices, values, shape], None, "SparseTensor"): indices = convert_to_tensor(indices, name="indices") - values = convert_to_tensor(values, name="values") + # Always pass as_ref=True because we want to be able to update + # values later if it is a VariableOp. + # TODO(touts): Consider adding mutable_values() when 'values' + # is a VariableOp and updating users of SparseTensor. + values = convert_to_tensor(values, name="values", as_ref=True) shape = convert_to_tensor(shape, name="shape") self._indices = indices self._values = values @@ -987,7 +999,9 @@ class Operation(object): self._graph = g if inputs is None: inputs = [] - self._inputs = inputs + elif not isinstance(inputs, list): + raise TypeError("inputs needs to be a list of Tensors: %s" % inputs) + self._inputs = list(inputs) # Defensive copy. for a in self._inputs: if not isinstance(a, Tensor): raise TypeError("input needs to be a Tensor: %s" % a) @@ -1391,6 +1405,7 @@ def get_gradient_function(op): _shape_registry = registry.Registry("shape functions") _default_shape_function_registry = registry.Registry("default shape functions") + class RegisterShape(object): """A decorator for registering the shape function for an op type. @@ -1924,6 +1939,7 @@ class Graph(object): A list of Operations. """ return list(self._nodes_by_id.values()) + def get_operation_by_name(self, name): """Returns the `Operation` with the given `name`. @@ -2045,7 +2061,7 @@ class Graph(object): else: c = [] for item in self._collections.get(name, list()): - if hasattr(item, 'name') and item.name.startswith(scope): + if hasattr(item, "name") and item.name.startswith(scope): c.append(item) return c diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 7802db473e4..dd0c6e01b6d 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -522,19 +522,21 @@ def ConstantValue(tensor): elif tensor.op.type == "Shape": input_shape = tensor.op.inputs[0].get_shape() if input_shape.is_fully_defined(): - return np.array([dim.value for dim in input_shape.dims]) + return np.array([dim.value for dim in input_shape.dims], + dtype=tensor.dtype.as_numpy_dtype) else: return None elif tensor.op.type == "Size": input_shape = tensor.op.inputs[0].get_shape() if input_shape.is_fully_defined(): - return np.array([np.prod([dim.value for dim in input_shape.dims])]) + return np.array([np.prod([dim.value for dim in input_shape.dims])], + dtype=tensor.dtype.as_numpy_dtype) else: return None elif tensor.op.type == "Rank": input_shape = tensor.op.inputs[0].get_shape() if input_shape.ndims is not None: - return np.array([input_shape.ndims]) + return np.array([input_shape.ndims], dtype=tensor.dtype.as_numpy_dtype) else: return None elif tensor.op.type == "Range": diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index c7e672d460e..f2828475aef 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -378,19 +378,25 @@ class ConstantValueTest(test_util.TensorFlowTestCase): self.assertIs(None, tensor_util.ConstantValue(tf_val)) def testShape(self): - np_val = np.array([1, 2, 3]) + np_val = np.array([1, 2, 3], dtype=np.int32) tf_val = array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3])) - self.assertAllEqual(np_val, tensor_util.ConstantValue(tf_val)) + c_val = tensor_util.ConstantValue(tf_val) + self.assertAllEqual(np_val, c_val) + self.assertEqual(np.int32, c_val.dtype) def testSize(self): - np_val = np.array([6]) + np_val = np.array([6], dtype=np.int32) tf_val = array_ops.size(constant_op.constant(0.0, shape=[1, 2, 3])) - self.assertAllEqual(np_val, tensor_util.ConstantValue(tf_val)) + c_val = tensor_util.ConstantValue(tf_val) + self.assertAllEqual(np_val, c_val) + self.assertEqual(np.int32, c_val.dtype) def testRank(self): - np_val = np.array([3]) + np_val = np.array([3], dtype=np.int32) tf_val = array_ops.rank(constant_op.constant(0.0, shape=[1, 2, 3])) - self.assertAllEqual(np_val, tensor_util.ConstantValue(tf_val)) + c_val = tensor_util.ConstantValue(tf_val) + self.assertAllEqual(np_val, c_val) + self.assertEqual(np.int32, c_val.dtype) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/batch_matmul_op_test.py b/tensorflow/python/kernel_tests/batch_matmul_op_test.py index 2a09594ad43..809b23bd7d5 100644 --- a/tensorflow/python/kernel_tests/batch_matmul_op_test.py +++ b/tensorflow/python/kernel_tests/batch_matmul_op_test.py @@ -23,8 +23,6 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests import gradient_checker as gc - class BatchMatmulOpTest(tf.test.TestCase): @@ -176,9 +174,14 @@ class BatchMatmulGradientTest(tf.test.TestCase): z = tf.batch_matmul(inx, iny, adj_x, adj_y) loss = tf.reduce_sum(z) epsilon = 1e-2 - ((x_jacob_t, x_jacob_n), (y_jacob_t, y_jacob_n)) = gc.ComputeGradient( - [inx, iny], [x.shape, y.shape], loss, [1], - x_init_value=[x, y], delta=epsilon) + ((x_jacob_t, x_jacob_n), + (y_jacob_t, y_jacob_n)) = tf.test.compute_gradient( + [inx, iny], + [x.shape, y.shape], + loss, + [1], + x_init_value=[x, y], + delta=epsilon) tf.logging.info("x_jacob_t = %s", x_jacob_t.reshape(x.shape)) tf.logging.info("x_jacob_n = %s", x_jacob_n.reshape(x.shape)) diff --git a/tensorflow/python/kernel_tests/bias_op_test.py b/tensorflow/python/kernel_tests/bias_op_test.py index cffbfc97c4e..e79cb8fc022 100644 --- a/tensorflow/python/kernel_tests/bias_op_test.py +++ b/tensorflow/python/kernel_tests/bias_op_test.py @@ -23,8 +23,6 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests import gradient_checker - class BiasAddTest(tf.test.TestCase): @@ -82,7 +80,7 @@ class BiasAddTest(tf.test.TestCase): dtype=tf.float64) b = tf.constant([1.3, 2.4], dtype=tf.float64) bo = tf.nn.bias_add(t, b) - err = gradient_checker.ComputeGradientError(t, [3, 2], bo, [3, 2]) + err = tf.test.compute_gradient_error(t, [3, 2], bo, [3, 2]) print("bias add tensor gradient err = ", err) self.assertLess(err, 1e-10) @@ -92,7 +90,7 @@ class BiasAddTest(tf.test.TestCase): dtype=tf.float64) b = tf.constant([1.3, 2.4], dtype=tf.float64) bo = tf.nn.bias_add(t, b) - err = gradient_checker.ComputeGradientError(b, [2], bo, [3, 2]) + err = tf.test.compute_gradient_error(b, [2], bo, [3, 2]) print("bias add bias gradient err = ", err) self.assertLess(err, 1e-10) @@ -103,7 +101,7 @@ class BiasAddTest(tf.test.TestCase): t = tf.constant(x, shape=s, dtype=tf.float32) b = tf.constant([1.3, 2.4], dtype=tf.float32) bo = tf.nn.bias_add(t, b) - err = gradient_checker.ComputeGradientError(t, s, bo, s, x_init_value=x) + err = tf.test.compute_gradient_error(t, s, bo, s, x_init_value=x) print("bias add tensor gradient err = ", err) self.assertLess(err, 1e-3) diff --git a/tensorflow/python/kernel_tests/cast_op_test.py b/tensorflow/python/kernel_tests/cast_op_test.py index 4b3699c0611..cf2a8949cbe 100644 --- a/tensorflow/python/kernel_tests/cast_op_test.py +++ b/tensorflow/python/kernel_tests/cast_op_test.py @@ -23,8 +23,6 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests import gradient_checker as gc - class CastOpTest(tf.test.TestCase): @@ -160,7 +158,7 @@ class CastOpTest(tf.test.TestCase): x = tf.constant(1.0, src_t) z = tf.identity(x) y = tf.cast(z, dst_t) - err = gc.ComputeGradientError(x, [1], y, [1]) + err = tf.test.compute_gradient_error(x, [1], y, [1]) self.assertLess(err, 1e-3) diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py index f96750d4b04..9bd73f710a5 100644 --- a/tensorflow/python/kernel_tests/concat_op_test.py +++ b/tensorflow/python/kernel_tests/concat_op_test.py @@ -303,6 +303,63 @@ class ConcatOpTest(tf.test.TestCase): dxs = sess.run(tf.gradients(c, xs, dc)) self.assertAllEqual(dc, np.concatenate(dxs, axis=axis)) + def testTensorConcatDim0Grad(self): + x_shapes = [[20, 7, 3], [10, 7, 3], [14, 7, 3]] + output_shape = [44, 7, 3] + x_vals = [np.random.random_sample(x_shape).astype( + np.float64) for x_shape in x_shapes] + with self.test_session(): + xs = [tf.constant(x_val) for x_val in x_vals] + output = tf.concat(0, xs) + err = tf.test.compute_gradient_error(xs, x_shapes, output, output_shape) + self.assertLess(err, 1e-11) + + def testTensorConcatDim1Grad(self): + x_shapes = [[20, 7, 3], [20, 3, 3], [20, 1, 3]] + output_shape = [20, 11, 3] + x_vals = [np.random.random_sample(x_shape).astype( + np.float64) for x_shape in x_shapes] + with self.test_session(): + xs = [tf.constant(x_val) for x_val in x_vals] + output = tf.concat(1, xs) + err = tf.test.compute_gradient_error(xs, x_shapes, output, output_shape) + self.assertLess(err, 1e-11) + + def testIndexedSlicesConcatDim0Grad(self): + x_shapes = [[20, 7, 3], [10, 7, 3], [14, 7, 3]] + output_shape = [4, 7, 3] + x_vals = [np.random.random_sample(x_shape).astype( + np.float64) for x_shape in x_shapes] + with self.test_session(): + xs = [tf.constant(x_val) for x_val in x_vals] + x_concat = tf.concat(0, xs) + output = tf.gather(x_concat, [1, 2, 0, 5]) + err = tf.test.compute_gradient_error(xs, x_shapes, output, output_shape) + self.assertLess(err, 1e-11) + + def testIndexedSlicesConcatDim1Grad(self): + x_shapes = [[20, 7, 3], [20, 3, 3], [20, 1, 3]] + output_shape = [4, 11, 3] + x_vals = [np.random.random_sample(x_shape).astype( + np.float64) for x_shape in x_shapes] + with self.test_session(): + xs = [tf.constant(x_val) for x_val in x_vals] + x_concat = tf.concat(1, xs) + output = tf.gather(x_concat, [1, 2, 0, 5]) + err = tf.test.compute_gradient_error(xs, x_shapes, output, output_shape) + self.assertLess(err, 1e-11) + + def testIndexedSlicesConcatDim2Grad(self): + x_shapes = [[20, 7, 3], [20, 7, 1], [20, 7, 2]] + output_shape = [4, 7, 6] + x_vals = [np.random.random_sample(x_shape).astype( + np.float64) for x_shape in x_shapes] + with self.test_session(): + xs = [tf.constant(x_val) for x_val in x_vals] + x_concat = tf.concat(2, xs) + output = tf.gather(x_concat, [1, 2, 0, 5]) + err = tf.test.compute_gradient_error(xs, x_shapes, output, output_shape) + self.assertLess(err, 1e-11) if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index c9634562549..b70ec134aba 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -1091,9 +1091,10 @@ class ControlFlowTest(tf.test.TestCase): # Use a control dependency to ensure init_variable is run # while asking for c - real_v = control_flow_ops.with_dependencies(name="real_tensor", - output_tensor=v, - dependencies=[v.initializer]) + real_v = control_flow_ops.with_dependencies( + name="real_tensor", + output_tensor=v.ref(), + dependencies=[v.initializer]) c_val, real_v_val = sess.run([c, real_v]) # Ensure the result of 'real_c' is the same as 'c' @@ -1259,12 +1260,12 @@ class TupleTest(tf.test.TestCase): with self.test_session(): v1 = tf.Variable([1.0]) add1 = tf.add( - control_flow_ops.with_dependencies([v1.initializer], v1), + control_flow_ops.with_dependencies([v1.initializer], v1.ref()), 2.0) v2 = tf.Variable([10.0]) - add2 = tf.add(control_flow_ops.with_dependencies([v2.initializer], - v2), - 20.0) + add2 = tf.add( + control_flow_ops.with_dependencies([v2.initializer], v2.ref()), + 20.0) t1, _, t2 = control_flow_ops.tuple([add1, None, add2]) # v1 is not initialized. @@ -1291,14 +1292,14 @@ class TupleTest(tf.test.TestCase): np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype( np.float32)) v1_at_1 = tf.IndexedSlices( - control_flow_ops.with_dependencies([v1.initializer], v1), + control_flow_ops.with_dependencies([v1.initializer], v1.ref()), tf.constant([1])) v2 = tf.Variable( np.array([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]]).astype( np.float32)) v2_at_1 = tf.IndexedSlices( - control_flow_ops.with_dependencies([v2.initializer], v2), + control_flow_ops.with_dependencies([v2.initializer], v2.ref()), tf.constant([1])) st1, st2 = control_flow_ops.tuple([v1_at_1, v2_at_1]) diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index 1ce09c48a6a..5efe4855670 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -23,8 +23,6 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests import gradient_checker as gc - def GetInceptionShapes(): """Iterator for the convolution shapes used in the Inception 2015 model. @@ -429,11 +427,11 @@ class Conv2DTest(tf.test.TestCase): name="conv") self.assertEqual(output_shape, conv.get_shape()) if test_input: - err = gc.ComputeGradientError(input_tensor, input_shape, - conv, output_shape) + err = tf.test.compute_gradient_error(input_tensor, input_shape, conv, + output_shape) else: - err = gc.ComputeGradientError(filter_tensor, filter_shape, - conv, output_shape) + err = tf.test.compute_gradient_error(filter_tensor, filter_shape, conv, + output_shape) print("conv_2d gradient error = ", err) self.assertLess(err, tolerance) diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index 4fb2fcafbf4..a823250d512 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -24,7 +24,6 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests import gradient_checker as gc _ADD = lambda x, y: x + y _SUB = lambda x, y: x - y @@ -58,11 +57,19 @@ class UnaryOpTest(tf.test.TestCase): self.assertAllClose(np_ans, tf_cpu) if x.dtype == np.float32: s = list(np.shape(x)) - jacob_t, jacob_n = gc.ComputeGradient(inx, s, y, s, x_init_value=x) + jacob_t, jacob_n = tf.test.compute_gradient(inx, + s, + y, + s, + x_init_value=x) self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3) elif x.dtype == np.float64: s = list(np.shape(x)) - jacob_t, jacob_n = gc.ComputeGradient(inx, s, y, s, x_init_value=x) + jacob_t, jacob_n = tf.test.compute_gradient(inx, + s, + y, + s, + x_init_value=x) self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5) def _compareGpu(self, x, np_func, tf_func): @@ -216,7 +223,11 @@ class BinaryOpTest(tf.test.TestCase): iny = tf.convert_to_tensor(y) out = tf_func(inx, iny) xs = list(x.shape) - jacob_t, jacob_n = gc.ComputeGradient(inx, xs, out, zs, x_init_value=x) + jacob_t, jacob_n = tf.test.compute_gradient(inx, + xs, + out, + zs, + x_init_value=x) if x.dtype == np.float32: self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3) elif x.dtype == np.float64: @@ -230,7 +241,11 @@ class BinaryOpTest(tf.test.TestCase): iny = tf.convert_to_tensor(y) out = tf_func(inx, iny) ys = list(np.shape(y)) - jacob_t, jacob_n = gc.ComputeGradient(iny, ys, out, zs, x_init_value=y) + jacob_t, jacob_n = tf.test.compute_gradient(iny, + ys, + out, + zs, + x_init_value=y) if x.dtype == np.float32: self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3) elif x.dtype == np.float64: @@ -833,7 +848,11 @@ class SelectOpTest(tf.test.TestCase): iny = tf.convert_to_tensor(y) out = tf.select(c, inx, iny) s = list(np.shape(c)) - jacob_t, jacob_n = gc.ComputeGradient(inx, s, out, s, x_init_value=x) + jacob_t, jacob_n = tf.test.compute_gradient(inx, + s, + out, + s, + x_init_value=x) if x.dtype == np.float32: self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3) elif x.dtype == np.float64: @@ -845,7 +864,11 @@ class SelectOpTest(tf.test.TestCase): iny = tf.convert_to_tensor(y) out = tf.select(c, inx, iny) s = list(np.shape(c)) - jacob_t, jacob_n = gc.ComputeGradient(iny, s, out, s, x_init_value=y) + jacob_t, jacob_n = tf.test.compute_gradient(iny, + s, + out, + s, + x_init_value=y) if x.dtype == np.float32: self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3) elif x.dtype == np.float64: @@ -923,7 +946,11 @@ class MinMaxOpTest(tf.test.TestCase): iny = tf.convert_to_tensor(y) out = func(inx, iny) s = list(np.shape(x)) - jacob_t, jacob_n = gc.ComputeGradient(inx, s, out, s, x_init_value=x) + jacob_t, jacob_n = tf.test.compute_gradient(inx, + s, + out, + s, + x_init_value=x) if x.dtype == np.float32: self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3) elif x.dtype == np.float64: @@ -935,7 +962,11 @@ class MinMaxOpTest(tf.test.TestCase): iny = tf.convert_to_tensor(y) out = func(inx, iny) s = list(np.shape(x)) - jacob_t, jacob_n = gc.ComputeGradient(iny, s, out, s, x_init_value=y) + jacob_t, jacob_n = tf.test.compute_gradient(iny, + s, + out, + s, + x_init_value=y) if x.dtype == np.float32: self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3) elif x.dtype == np.float64: @@ -1159,8 +1190,12 @@ class ComplexMakeRealImagTest(tf.test.TestCase): tf.square(tf.real(cplx))) + tf.reduce_sum( tf.square(tf.imag(cplx))) epsilon = 1e-3 - jacob_t, jacob_n = gc.ComputeGradient(inx, list(x.shape), loss, [1], - x_init_value=x, delta=epsilon) + jacob_t, jacob_n = tf.test.compute_gradient(inx, + list(x.shape), + loss, + [1], + x_init_value=x, + delta=epsilon) self.assertAllClose(jacob_t, jacob_n, rtol=epsilon, atol=epsilon) def testGradient(self): @@ -1187,8 +1222,12 @@ class ComplexMakeRealImagTest(tf.test.TestCase): # Defines the loss function as the sum of all coefficients of z. loss = tf.reduce_sum(tf.real(z) + tf.imag(z)) epsilon = 0.005 - jacob_t, jacob_n = gc.ComputeGradient(inp, list(data.shape), loss, [1], - x_init_value=data, delta=epsilon) + jacob_t, jacob_n = tf.test.compute_gradient(inp, + list(data.shape), + loss, + [1], + x_init_value=data, + delta=epsilon) self.assertAllClose(jacob_t, jacob_n, rtol=epsilon, atol=epsilon) def testMulGradient(self): diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py index b17cdc0ed54..5f54f02bf06 100644 --- a/tensorflow/python/kernel_tests/embedding_ops_test.py +++ b/tensorflow/python/kernel_tests/embedding_ops_test.py @@ -26,8 +26,6 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf -from tensorflow.python.kernel_tests import gradient_checker as gc - def _AsLong(array): """Casts arrays elements to long type. Used to convert from numpy tf.""" @@ -225,8 +223,11 @@ class EmbeddingLookupTest(tf.test.TestCase): x_name = [_PName(i) for i in range(num_shards)] x_init_value = [params[x_n + ":0"] for x_n in x_name] x_shape = [i.shape for i in x_init_value] - err = gc.ComputeGradientError(x, x_shape, y, y_shape, - x_init_value=x_init_value) + err = tf.test.compute_gradient_error(x, + x_shape, + y, + y_shape, + x_init_value=x_init_value) self.assertLess(err, 1e-4) def testGradientsEmbeddingLookupWithComputedParams(self): @@ -246,8 +247,11 @@ class EmbeddingLookupTest(tf.test.TestCase): x_name = [_PName(i) for i in range(num_shards)] x_init_value = [params[x_n + ":0"] for x_n in x_name] x_shape = [i.shape for i in x_init_value] - err = gc.ComputeGradientError(x, x_shape, y, y_shape, - x_init_value=x_init_value) + err = tf.test.compute_gradient_error(x, + x_shape, + y, + y_shape, + x_init_value=x_init_value) self.assertLess(err, 1e-3) def testConstructionNonSharded(self): @@ -381,8 +385,11 @@ class EmbeddingLookupSparseTest(tf.test.TestCase): x_init_value = [params[x_n + ":0"] for x_n in x_name] x_shape = [i.shape for i in x_init_value] y_shape = [batch_size] + list(params[_PName(0) + ":0"].shape[1:]) - err = gc.ComputeGradientError(x, x_shape, y, y_shape, - x_init_value=x_init_value) + err = tf.test.compute_gradient_error(x, + x_shape, + y, + y_shape, + x_init_value=x_init_value) self.assertLess(err, 1e-5 if dtype == tf.float64 else 2e-3) diff --git a/tensorflow/python/kernel_tests/gradient_checker.py b/tensorflow/python/kernel_tests/gradient_checker.py index 69cc811a6ba..d0cdc3b3bcb 100644 --- a/tensorflow/python/kernel_tests/gradient_checker.py +++ b/tensorflow/python/kernel_tests/gradient_checker.py @@ -34,7 +34,7 @@ from tensorflow.python.ops import gradients from tensorflow.python.platform import logging -def _Product(t): +def _product(t): if isinstance(t, int): return t else: @@ -44,11 +44,11 @@ def _Product(t): return y -def _ComputeTheoricalJacobian(x, x_shape, x_data, dy, dy_shape, dx): +def _compute_theoretical_jacobian(x, x_shape, x_data, dy, dy_shape, dx): """Computes the theoretical Jacobian for dy/dx. Computes the theoretical Jacobian using the ops generated by - ComputeGradient(). + compute_gradient(). Args: x: the tensor "x". @@ -64,9 +64,9 @@ def _ComputeTheoricalJacobian(x, x_shape, x_data, dy, dy_shape, dx): "dy_size" is the number of elements in dy. """ # To compute the jacobian, we treat x and y are one-dimensional vectors - x_size = _Product(x_shape) - x_val_size = _Product(x_shape[1:]) # This is used for sparse gradients - dy_size = _Product(dy_shape) + x_size = _product(x_shape) + x_val_size = _product(x_shape[1:]) # This is used for sparse gradients + dy_size = _product(dy_shape) jacobian = np.zeros((x_size, dy_size), dtype=x_data.dtype) # For each of the entry of dy, we set this to be 1 and @@ -92,7 +92,7 @@ def _ComputeTheoricalJacobian(x, x_shape, x_data, dy, dy_shape, dx): return jacobian -def _ComputeNumericJacobian(x, x_shape, x_data, y, y_shape, delta): +def _compute_numeric_jacobian(x, x_shape, x_data, y, y_shape, delta): """Computes the numeric Jacobian for dy/dx. Computes the numeric Jacobian by slightly perturbing the inputs and @@ -113,8 +113,8 @@ def _ComputeNumericJacobian(x, x_shape, x_data, y, y_shape, delta): """ # To compute the jacobian, we treat x and y are one-dimensional vectors - x_size = _Product(x_shape) - y_size = _Product(y_shape) + x_size = _product(x_shape) + y_size = _product(y_shape) jacobian = np.zeros((x_size, y_size), dtype=x_data.dtype) # For each of the entry of x, we slightly perturbs this by adding and @@ -134,7 +134,7 @@ def _ComputeNumericJacobian(x, x_shape, x_data, y, y_shape, delta): return jacobian -def _ComputeDxAndDy(x, y, y_shape): +def _compute_dx_and_dy(x, y, y_shape): """Returns a node to compute gradient of x wrt y.""" # We make up a dy so that we can compute the gradients. We don't really use # the value of dy -- we will always feed it. We need to add an identity node @@ -149,8 +149,14 @@ def _ComputeDxAndDy(x, y, y_shape): return grads[0], dy_orig -def _ComputeGradient(x, x_shape, dx, y, y_shape, dy, - x_init_value=None, delta=1e-3): +def _compute_gradient(x, + x_shape, + dx, + y, + y_shape, + dy, + x_init_value=None, + delta=1e-3): """Computes the theoretical and numerical jacobian.""" t = dtypes.as_dtype(x.dtype) allowed_types = [dtypes.float32, dtypes.float64] @@ -170,16 +176,21 @@ def _ComputeGradient(x, x_shape, dx, y, y_shape, dy, dtype = np.float64 x_data = np.asfarray(np.random.random_sample(x_shape), dtype=dtype) - jacob_t = _ComputeTheoricalJacobian(x, x_shape, x_data, dy, y_shape, dx) - jacob_n = _ComputeNumericJacobian(x, x_shape, x_data, y, y_shape, delta) + jacob_t = _compute_theoretical_jacobian(x, x_shape, x_data, dy, y_shape, dx) + jacob_n = _compute_numeric_jacobian(x, x_shape, x_data, y, y_shape, delta) return jacob_t, jacob_n -def _ComputeGradientList( - x, x_shape, y, y_shape, x_init_value=None, delta=1e-3, init_targets=None): +def _compute_gradient_list(x, + x_shape, + y, + y_shape, + x_init_value=None, + delta=1e-3, + init_targets=None): """Compute gradients for a list of x values.""" assert isinstance(x, list) - dx, dy = zip(*[_ComputeDxAndDy(xi, y, y_shape) for xi in x]) + dx, dy = zip(*[_compute_dx_and_dy(xi, y, y_shape) for xi in x]) if init_targets is not None: assert isinstance(init_targets, (list, tuple)) @@ -187,15 +198,20 @@ def _ComputeGradientList( init.run() if x_init_value is None: x_init_value = [None] * len(x) - ret = [_ComputeGradient(xi, x_shapei, dxi, y, y_shape, dyi, - x_init_valuei, delta) - for xi, x_shapei, dxi, dyi, x_init_valuei in - zip(x, x_shape, dx, dy, x_init_value)] + ret = [_compute_gradient(xi, x_shapei, dxi, y, y_shape, dyi, x_init_valuei, + delta) + for xi, x_shapei, dxi, dyi, x_init_valuei in zip(x, x_shape, dx, dy, + x_init_value)] return ret -def ComputeGradient( - x, x_shape, y, y_shape, x_init_value=None, delta=1e-3, init_targets=None): +def compute_gradient(x, + x_shape, + y, + y_shape, + x_init_value=None, + delta=1e-3, + init_targets=None): """Computes and returns the theoretical and numerical Jacobian. Args: @@ -219,20 +235,25 @@ def ComputeGradient( number of elements in y. If x is a list, returns a list of two numpy arrays. """ if isinstance(x, list): - return _ComputeGradientList(x, x_shape, y, y_shape, x_init_value, - delta, init_targets) + return _compute_gradient_list(x, x_shape, y, y_shape, x_init_value, delta, + init_targets) else: if init_targets is not None: assert isinstance(init_targets, (list, tuple)) for init in init_targets: init.run() - dx, dy = _ComputeDxAndDy(x, y, y_shape) - ret = _ComputeGradient(x, x_shape, dx, y, y_shape, dy, x_init_value, delta) + dx, dy = _compute_dx_and_dy(x, y, y_shape) + ret = _compute_gradient(x, x_shape, dx, y, y_shape, dy, x_init_value, delta) return ret -def ComputeGradientError( - x, x_shape, y, y_shape, x_init_value=None, delta=1e-3, init_targets=None): +def compute_gradient_error(x, + x_shape, + y, + y_shape, + x_init_value=None, + delta=1e-3, + init_targets=None): """Computes the gradient error. Computes the maximum error for dy/dx between the computed Jacobian and the @@ -263,8 +284,8 @@ def ComputeGradientError( Returns: The maximum error in between the two Jacobians. """ - grad = ComputeGradient(x, x_shape, y, y_shape, x_init_value, - delta, init_targets) + grad = compute_gradient(x, x_shape, y, y_shape, x_init_value, delta, + init_targets) if isinstance(grad, tuple): grad = [grad] return max(np.fabs(j_t - j_n).max() for j_t, j_n in grad) diff --git a/tensorflow/python/kernel_tests/gradient_checker_test.py b/tensorflow/python/kernel_tests/gradient_checker_test.py index 6a835ff651c..2ded0375a87 100644 --- a/tensorflow/python/kernel_tests/gradient_checker_test.py +++ b/tensorflow/python/kernel_tests/gradient_checker_test.py @@ -23,8 +23,6 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests.gradient_checker import ComputeGradientError - class GradientCheckerTest(tf.test.TestCase): @@ -37,7 +35,7 @@ class GradientCheckerTest(tf.test.TestCase): y = tf.add(x1, x2, name="y") # checking gradients for x1 - error = ComputeGradientError(x1, size, y, size) + error = tf.test.compute_gradient_error(x1, size, y, size) tf.logging.info("x1 error = %f", error) assert error < 1e-4 @@ -50,7 +48,7 @@ class GradientCheckerTest(tf.test.TestCase): y = tf.add(x1, x2, name="y") # checking gradients for x1 - error = ComputeGradientError(x1, size, y, size) + error = tf.test.compute_gradient_error(x1, size, y, size) tf.logging.info("x1 error = %f", error) assert error < 1e-4 @@ -66,8 +64,12 @@ class GradientCheckerTest(tf.test.TestCase): # checkint gradients for x2 using a special init_value and delta x_init_value = np.asarray(np.arange(6, dtype=np.float64).reshape(2, 3)) - error = ComputeGradientError(x2, size, y, size, x_init_value=x_init_value, - delta=1e-2) + error = tf.test.compute_gradient_error(x2, + size, + y, + size, + x_init_value=x_init_value, + delta=1e-2) tf.logging.info("x2 error = %f", error) assert error < 1e-10 @@ -82,7 +84,7 @@ class GradientCheckerTest(tf.test.TestCase): indices = tf.constant(index_values, name="i") y = tf.gather(params, indices, name="y") - error = ComputeGradientError(params, p_shape, y, y_shape) + error = tf.test.compute_gradient_error(params, p_shape, y, y_shape) tf.logging.info("gather error = %f", error) assert error < 1e-4 @@ -101,7 +103,7 @@ class GradientCheckerTest(tf.test.TestCase): indices2 = tf.constant(index_values2, name="i2") y2 = tf.gather(y, indices2, name="y2") - error = ComputeGradientError(params, p_shape, y2, y2_shape) + error = tf.test.compute_gradient_error(params, p_shape, y2, y2_shape) tf.logging.info("nested gather error = %f", error) assert error < 1e-4 @@ -166,9 +168,11 @@ def BuildAndTestMiniMNIST(param_index, tag): cost = tf.nn.softmax_cross_entropy_with_logits(logits, labels, name="cost") # Test the gradients. - err = ComputeGradientError(all_params[param_index], - param_sizes[param_index], - cost, [batch], delta=1e-5) + err = tf.test.compute_gradient_error(all_params[param_index], + param_sizes[param_index], + cost, + [batch], + delta=1e-5) tf.logging.info("Mini MNIST: %s gradient error = %g", tag, err) return err diff --git a/tensorflow/python/kernel_tests/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg_grad_test.py index 40da6c3ad51..8c9c47ac622 100644 --- a/tensorflow/python/kernel_tests/linalg_grad_test.py +++ b/tensorflow/python/kernel_tests/linalg_grad_test.py @@ -23,8 +23,6 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests import gradient_checker as gc - class MatrixInverseGradientTest(tf.test.TestCase): pass # Filled in below @@ -49,11 +47,11 @@ def _GetMatrixInverseGradientTest(dtype_, shape_): else: ainv = tf.batch_matrix_inverse(a) - theoretical, numerical = gc.ComputeGradient(a, - shape_, - ainv, - shape_, - delta=delta) + theoretical, numerical = tf.test.compute_gradient(a, + shape_, + ainv, + shape_, + delta=delta) self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol) return Test @@ -87,8 +85,11 @@ def _GetMatrixDeterminantGradientTest(dtype_, shape_): c = tf.batch_matrix_determinant(a) out_shape = shape_[:-2] # last two dimensions hold matrices - theoretical, numerical = gc.ComputeGradient(a, shape_, c, out_shape, - delta=delta) + theoretical, numerical = tf.test.compute_gradient(a, + shape_, + c, + out_shape, + delta=delta) self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol) diff --git a/tensorflow/models/rnn/linear_test.py b/tensorflow/python/kernel_tests/linear_test.py similarity index 89% rename from tensorflow/models/rnn/linear_test.py rename to tensorflow/python/kernel_tests/linear_test.py index 22c38434133..fdb45411147 100644 --- a/tensorflow/models/rnn/linear_test.py +++ b/tensorflow/python/kernel_tests/linear_test.py @@ -23,8 +23,6 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.models.rnn import linear - class LinearTest(tf.test.TestCase): @@ -32,21 +30,21 @@ class LinearTest(tf.test.TestCase): with self.test_session() as sess: with tf.variable_scope("root", initializer=tf.constant_initializer(1.0)): x = tf.zeros([1, 2]) - l = linear.linear([x], 2, False) + l = tf.nn.rnn_cell.linear([x], 2, False) sess.run([tf.variables.initialize_all_variables()]) res = sess.run([l], {x.name: np.array([[1., 2.]])}) self.assertAllClose(res[0], [[3.0, 3.0]]) # Checks prevent you from accidentally creating a shared function. with self.assertRaises(ValueError) as exc: - l1 = linear.linear([x], 2, False) + l1 = tf.nn.rnn_cell.linear([x], 2, False) self.assertEqual(str(exc.exception)[:12], "Over-sharing") # But you can create a new one in a new scope and share the variables. with tf.variable_scope("l1") as new_scope: - l1 = linear.linear([x], 2, False) + l1 = tf.nn.rnn_cell.linear([x], 2, False) with tf.variable_scope(new_scope, reuse=True): - linear.linear([l1], 2, False) + tf.nn.rnn_cell.linear([l1], 2, False) self.assertEqual(len(tf.trainable_variables()), 2) diff --git a/tensorflow/python/kernel_tests/lrn_op_test.py b/tensorflow/python/kernel_tests/lrn_op_test.py index 4dd7372f4a6..2d7a082b863 100644 --- a/tensorflow/python/kernel_tests/lrn_op_test.py +++ b/tensorflow/python/kernel_tests/lrn_op_test.py @@ -25,9 +25,6 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests.gradient_checker import ComputeGradientError - - class LRNOpTest(tf.test.TestCase): @@ -107,7 +104,7 @@ class LRNOpTest(tf.test.TestCase): lrn_op = tf.nn.local_response_normalization( inp, name="lrn", depth_radius=lrn_depth_radius, bias=bias, alpha=alpha, beta=beta) - err = ComputeGradientError(inp, shape, lrn_op, shape) + err = tf.test.compute_gradient_error(inp, shape, lrn_op, shape) print("LRN Gradient error ", err) self.assertLess(err, 1e-4) diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py index 791951eb982..986aa9797ed 100644 --- a/tensorflow/python/kernel_tests/matmul_op_test.py +++ b/tensorflow/python/kernel_tests/matmul_op_test.py @@ -23,8 +23,6 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests import gradient_checker as gc - class MatMulTest(tf.test.TestCase): @@ -161,7 +159,7 @@ class MatMulGradientTest(tf.test.TestCase): y = tf.constant([1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7], shape=[2, 4], dtype=tf.float64, name="y") m = tf.matmul(x, y, name="matmul") - err = gc.ComputeGradientError(x, [3, 2], m, [3, 4]) + err = tf.test.compute_gradient_error(x, [3, 2], m, [3, 4]) print("matmul input0 gradient err = ", err) self.assertLess(err, 1e-10) @@ -172,7 +170,7 @@ class MatMulGradientTest(tf.test.TestCase): y = tf.constant([1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7], shape=[2, 4], dtype=tf.float64, name="y") m = tf.matmul(x, y, name="matmul") - err = gc.ComputeGradientError(y, [2, 4], m, [3, 4]) + err = tf.test.compute_gradient_error(y, [2, 4], m, [3, 4]) print("matmul input1 gradient err = ", err) self.assertLess(err, 1e-10) @@ -189,7 +187,7 @@ class MatMulGradientTest(tf.test.TestCase): y = tf.constant([1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7], shape=shape_y, dtype=tf.float64, name="y") m = tf.matmul(x, y, transpose_a, transpose_b, name="matmul") - err = gc.ComputeGradientError(x, shape_x, m, [3, 4]) + err = tf.test.compute_gradient_error(x, shape_x, m, [3, 4]) print("matmul input0 gradient err = ", err) self.assertLess(err, 1e-10) @@ -211,7 +209,7 @@ class MatMulGradientTest(tf.test.TestCase): y = tf.constant([1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7], shape=shape_y, dtype=tf.float64, name="y") m = tf.matmul(x, y, transpose_a, transpose_b, name="matmul") - err = gc.ComputeGradientError(y, shape_y, m, [3, 4]) + err = tf.test.compute_gradient_error(y, shape_y, m, [3, 4]) print("matmul input1 gradient err = ", err) self.assertLess(err, 1e-10) diff --git a/tensorflow/python/kernel_tests/pack_op_test.py b/tensorflow/python/kernel_tests/pack_op_test.py index f9bdadb82b7..03f580169f2 100644 --- a/tensorflow/python/kernel_tests/pack_op_test.py +++ b/tensorflow/python/kernel_tests/pack_op_test.py @@ -23,8 +23,6 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests import gradient_checker - class PackOpTest(tf.test.TestCase): @@ -51,7 +49,7 @@ class PackOpTest(tf.test.TestCase): # TODO(irving): Remove list() once we handle maps correctly xs = list(map(tf.constant, data)) c = tf.pack(xs) - err = gradient_checker.ComputeGradientError(xs, shapes, c, shape) + err = tf.test.compute_gradient_error(xs, shapes, c, shape) self.assertLess(err, 1e-6) def testZeroSize(self): diff --git a/tensorflow/python/kernel_tests/pad_op_test.py b/tensorflow/python/kernel_tests/pad_op_test.py index 46f1b4a3a1e..754642b204d 100644 --- a/tensorflow/python/kernel_tests/pad_op_test.py +++ b/tensorflow/python/kernel_tests/pad_op_test.py @@ -24,8 +24,6 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests import gradient_checker as gc - class PadOpTest(tf.test.TestCase): @@ -58,7 +56,11 @@ class PadOpTest(tf.test.TestCase): y = tf.pad(inx, ina) # Expected y's shape to be: ys = list(np.array(x.shape) + np.sum(np.array(a), axis=1)) - jacob_t, jacob_n = gc.ComputeGradient(inx, xs, y, ys, x_init_value=x) + jacob_t, jacob_n = tf.test.compute_gradient(inx, + xs, + y, + ys, + x_init_value=x) self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5) def _testAll(self, np_inputs, paddings): diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py index ab36ec5fde5..427d83b2106 100644 --- a/tensorflow/python/kernel_tests/pooling_ops_test.py +++ b/tensorflow/python/kernel_tests/pooling_ops_test.py @@ -23,7 +23,6 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests import gradient_checker as gc from tensorflow.python.ops import gen_nn_ops @@ -436,9 +435,12 @@ class PoolingTest(tf.test.TestCase): t = pool_func(input_tensor, ksize=[1, window_rows, window_rows, 1], strides=[1, row_stride, col_stride, 1], padding=padding, name=func_name) - err = gc.ComputeGradientError( - input_tensor, input_sizes, t, output_sizes, - x_init_value=x_init_value, delta=1e-2) + err = tf.test.compute_gradient_error(input_tensor, + input_sizes, + t, + output_sizes, + x_init_value=x_init_value, + delta=1e-2) print("%s gradient error = " % func_name, err) self.assertLess(err, err_margin) diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py index afb437ea3c2..3867034dc16 100644 --- a/tensorflow/python/kernel_tests/reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/reduction_ops_test.py @@ -24,7 +24,6 @@ import numpy as np import tensorflow as tf from tensorflow.python.framework import tensor_shape -from tensorflow.python.kernel_tests import gradient_checker class SumReductionTest(tf.test.TestCase): @@ -150,13 +149,12 @@ class SumReductionTest(tf.test.TestCase): with self.test_session(): t = tf.convert_to_tensor(x) su = tf.reduce_sum(t, reduction_axes) - jacob_t, jacob_n = gradient_checker.ComputeGradient( - t, - shape, - su, - sum_shape, - x_init_value=x, - delta=1) + jacob_t, jacob_n = tf.test.compute_gradient(t, + shape, + su, + sum_shape, + x_init_value=x, + delta=1) self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8) def testGradient(self): @@ -211,18 +209,30 @@ class MeanReductionTest(tf.test.TestCase): with self.test_session(): t = tf.convert_to_tensor(x) su = tf.reduce_mean(t, [1, 2]) - jacob_t, jacob_n = gradient_checker.ComputeGradient( - t, s, su, [2, 2], x_init_value=x, delta=1) + jacob_t, jacob_n = tf.test.compute_gradient(t, + s, + su, + [2, 2], + x_init_value=x, + delta=1) self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3) su = tf.reduce_mean(t, [0, 1, 2, 3]) - jacob_t, jacob_n = gradient_checker.ComputeGradient( - t, s, su, [1], x_init_value=x, delta=1) + jacob_t, jacob_n = tf.test.compute_gradient(t, + s, + su, + [1], + x_init_value=x, + delta=1) self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3) su = tf.reduce_mean(t, []) - jacob_t, jacob_n = gradient_checker.ComputeGradient( - t, s, su, [2, 3, 4, 2], x_init_value=x, delta=1) + jacob_t, jacob_n = tf.test.compute_gradient(t, + s, + su, + [2, 3, 4, 2], + x_init_value=x, + delta=1) self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3) @@ -269,18 +279,30 @@ class ProdReductionTest(tf.test.TestCase): t = tf.convert_to_tensor(x) su = tf.reduce_prod(t, []) - jacob_t, jacob_n = gradient_checker.ComputeGradient( - t, s, su, [2, 3, 4, 2], x_init_value=x, delta=1) + jacob_t, jacob_n = tf.test.compute_gradient(t, + s, + su, + [2, 3, 4, 2], + x_init_value=x, + delta=1) self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3) su = tf.reduce_prod(t, [1, 2]) - jacob_t, jacob_n = gradient_checker.ComputeGradient( - t, s, su, [2, 2], x_init_value=x, delta=1) + jacob_t, jacob_n = tf.test.compute_gradient(t, + s, + su, + [2, 2], + x_init_value=x, + delta=1) self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3) su = tf.reduce_prod(t, [0, 1, 2, 3]) - jacob_t, jacob_n = gradient_checker.ComputeGradient( - t, s, su, [1], x_init_value=x, delta=1) + jacob_t, jacob_n = tf.test.compute_gradient(t, + s, + su, + [1], + x_init_value=x, + delta=1) self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3) # NOTE(kearnes): the current gradient calculation gives NaNs for 0 inputs @@ -288,8 +310,12 @@ class ProdReductionTest(tf.test.TestCase): with self.test_session(): t = tf.convert_to_tensor(x) su = tf.reduce_prod(t, []) - jacob_t, _ = gradient_checker.ComputeGradient( - t, s, su, [2, 3, 4, 2], x_init_value=x, delta=1) + jacob_t, _ = tf.test.compute_gradient(t, + s, + su, + [2, 3, 4, 2], + x_init_value=x, + delta=1) with self.assertRaisesOpError("Tensor had NaN values"): tf.check_numerics(jacob_t, message="_ProdGrad NaN test").op.run() @@ -336,8 +362,12 @@ class MinReductionTest(tf.test.TestCase): with self.test_session(): t = tf.convert_to_tensor(x) su = tf.reduce_min(t, [1, 2]) - jacob_t, jacob_n = gradient_checker.ComputeGradient( - t, s, su, [2, 2], x_init_value=x, delta=1) + jacob_t, jacob_n = tf.test.compute_gradient(t, + s, + su, + [2, 2], + x_init_value=x, + delta=1) self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8) def testGradient2(self): @@ -346,8 +376,12 @@ class MinReductionTest(tf.test.TestCase): with self.test_session(): t = tf.convert_to_tensor(x) su = tf.reduce_min(t, [1]) - jacob_t, jacob_n = gradient_checker.ComputeGradient( - t, s, su, [2, 4, 2], x_init_value=x, delta=1) + jacob_t, jacob_n = tf.test.compute_gradient(t, + s, + su, + [2, 4, 2], + x_init_value=x, + delta=1) self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8) def testGradient3(self): @@ -356,8 +390,12 @@ class MinReductionTest(tf.test.TestCase): with self.test_session(): t = tf.convert_to_tensor(x) su = tf.reduce_min(t, [2]) - jacob_t, jacob_n = gradient_checker.ComputeGradient( - t, s, su, [2, 3, 2], x_init_value=x, delta=1) + jacob_t, jacob_n = tf.test.compute_gradient(t, + s, + su, + [2, 3, 2], + x_init_value=x, + delta=1) self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8) def testGradient4(self): @@ -366,8 +404,12 @@ class MinReductionTest(tf.test.TestCase): with self.test_session(): t = tf.convert_to_tensor(x) su = tf.reduce_min(t) - jacob_t, jacob_n = gradient_checker.ComputeGradient( - t, s, su, [1], x_init_value=x, delta=1) + jacob_t, jacob_n = tf.test.compute_gradient(t, + s, + su, + [1], + x_init_value=x, + delta=1) self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8) @@ -414,8 +456,12 @@ class MaxReductionTest(tf.test.TestCase): with self.test_session(): t = tf.convert_to_tensor(x) su = tf.reduce_max(t, [1, 2]) - jacob_t, jacob_n = gradient_checker.ComputeGradient( - t, s, su, [2, 2], x_init_value=x, delta=1) + jacob_t, jacob_n = tf.test.compute_gradient(t, + s, + su, + [2, 2], + x_init_value=x, + delta=1) self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8) def testGradient2(self): @@ -424,8 +470,12 @@ class MaxReductionTest(tf.test.TestCase): with self.test_session(): t = tf.convert_to_tensor(x) su = tf.reduce_max(t, [1]) - jacob_t, jacob_n = gradient_checker.ComputeGradient( - t, s, su, [2, 4, 2], x_init_value=x, delta=1) + jacob_t, jacob_n = tf.test.compute_gradient(t, + s, + su, + [2, 4, 2], + x_init_value=x, + delta=1) self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8) def testGradient3(self): @@ -434,8 +484,12 @@ class MaxReductionTest(tf.test.TestCase): with self.test_session(): t = tf.convert_to_tensor(x) su = tf.reduce_max(t, [2]) - jacob_t, jacob_n = gradient_checker.ComputeGradient( - t, s, su, [2, 3, 2], x_init_value=x, delta=1) + jacob_t, jacob_n = tf.test.compute_gradient(t, + s, + su, + [2, 3, 2], + x_init_value=x, + delta=1) self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8) def testGradient4(self): @@ -444,8 +498,12 @@ class MaxReductionTest(tf.test.TestCase): with self.test_session(): t = tf.convert_to_tensor(x) su = tf.reduce_max(t) - jacob_t, jacob_n = gradient_checker.ComputeGradient( - t, s, su, [1], x_init_value=x, delta=1) + jacob_t, jacob_n = tf.test.compute_gradient(t, + s, + su, + [1], + x_init_value=x, + delta=1) self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8) diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py index 0dbece5897d..38ab52b0c16 100644 --- a/tensorflow/python/kernel_tests/relu_op_test.py +++ b/tensorflow/python/kernel_tests/relu_op_test.py @@ -23,8 +23,6 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests import gradient_checker as gc - class ReluTest(tf.test.TestCase): @@ -67,7 +65,11 @@ class ReluTest(tf.test.TestCase): x_init = np.asarray( [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], dtype=np.float32, order="F") - err = gc.ComputeGradientError(x, [2, 5], y, [2, 5], x_init_value=x_init) + err = tf.test.compute_gradient_error(x, + [2, 5], + y, + [2, 5], + x_init_value=x_init) print("relu (float) gradient err = ", err) self.assertLess(err, 1e-4) @@ -98,7 +100,11 @@ class ReluTest(tf.test.TestCase): x_init = np.asarray( [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], dtype=np.float64, order="F") - err = gc.ComputeGradientError(x, [2, 5], y, [2, 5], x_init_value=x_init) + err = tf.test.compute_gradient_error(x, + [2, 5], + y, + [2, 5], + x_init_value=x_init) print("relu (double) gradient err = ", err) self.assertLess(err, 1e-10) @@ -112,8 +118,11 @@ class ReluTest(tf.test.TestCase): x_init = np.asarray( [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], dtype=np.float32, order="F") - err = gc.ComputeGradientError(x, [2, 5], z[0], [2, 5], - x_init_value=x_init) + err = tf.test.compute_gradient_error(x, + [2, 5], + z[0], + [2, 5], + x_init_value=x_init) print("relu (float) gradient of gradient err = ", err) self.assertLess(err, 1e-4) @@ -127,8 +136,11 @@ class ReluTest(tf.test.TestCase): x_init = np.asarray( [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], dtype=np.float64, order="F") - err = gc.ComputeGradientError(x, [2, 5], z[0], [2, 5], - x_init_value=x_init) + err = tf.test.compute_gradient_error(x, + [2, 5], + z[0], + [2, 5], + x_init_value=x_init) print("relu (double) gradient of gradient err = ", err) self.assertLess(err, 1e-10) @@ -178,7 +190,11 @@ class Relu6Test(tf.test.TestCase): x_init = np.asarray( [[-0.9, -0.7, -0.5, -0.3, -0.1], [6.1, 6.3, 6.5, 6.7, 6.9]], dtype=np.float32, order="F") - err = gc.ComputeGradientError(x, [2, 5], y, [2, 5], x_init_value=x_init) + err = tf.test.compute_gradient_error(x, + [2, 5], + y, + [2, 5], + x_init_value=x_init) print("relu6 (float) gradient err = ", err) self.assertLess(err, 1e-4) @@ -191,7 +207,11 @@ class Relu6Test(tf.test.TestCase): x_init = np.asarray( [[-0.9, -0.7, -0.5, -0.3, -0.1], [6.1, 6.3, 6.5, 6.7, 6.9]], dtype=np.float64, order="F") - err = gc.ComputeGradientError(x, [2, 5], y, [2, 5], x_init_value=x_init) + err = tf.test.compute_gradient_error(x, + [2, 5], + y, + [2, 5], + x_init_value=x_init) print("relu6 (double) gradient err = ", err) self.assertLess(err, 1e-10) diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py index fd1130c082d..f3fc9086d6c 100644 --- a/tensorflow/python/kernel_tests/reshape_op_test.py +++ b/tensorflow/python/kernel_tests/reshape_op_test.py @@ -23,8 +23,6 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests import gradient_checker as gc - class ReshapeTest(tf.test.TestCase): @@ -81,8 +79,11 @@ class ReshapeTest(tf.test.TestCase): with self.test_session(): input_tensor = tf.constant(x, shape=[2, 3, 4]) reshape_out = tf.reshape(input_tensor, [1, 8, 3]) - err = gc.ComputeGradientError(input_tensor, s, - reshape_out, s, x_init_value=x) + err = tf.test.compute_gradient_error(input_tensor, + s, + reshape_out, + s, + x_init_value=x) print("Reshape gradient error = " % err) self.assertLess(err, 1e-3) diff --git a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py index ba90a35b1b6..f2bc9641091 100644 --- a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py +++ b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py @@ -23,15 +23,14 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests import gradient_checker as gc - class ReverseSequenceTest(tf.test.TestCase): - def _testReverseSequence(self, x, seq_dim, seq_lengths, + def _testReverseSequence(self, x, batch_dim, seq_dim, seq_lengths, truth, use_gpu=False, expected_err_re=None): with self.test_session(use_gpu=use_gpu): ans = tf.reverse_sequence(x, + batch_dim=batch_dim, seq_dim=seq_dim, seq_lengths=seq_lengths) if expected_err_re is None: @@ -42,11 +41,11 @@ class ReverseSequenceTest(tf.test.TestCase): with self.assertRaisesOpError(expected_err_re): ans.eval() - def _testBothReverseSequence(self, x, seq_dim, seq_lengths, + def _testBothReverseSequence(self, x, batch_dim, seq_dim, seq_lengths, truth, expected_err_re=None): - self._testReverseSequence(x, seq_dim, seq_lengths, + self._testReverseSequence(x, batch_dim, seq_dim, seq_lengths, truth, True, expected_err_re) - self._testReverseSequence(x, seq_dim, seq_lengths, + self._testReverseSequence(x, batch_dim, seq_dim, seq_lengths, truth, False, expected_err_re) def _testBasic(self, dtype): @@ -55,18 +54,22 @@ class ReverseSequenceTest(tf.test.TestCase): [[9, 10, 11, 12], [13, 14, 15, 16]], [[17, 18, 19, 20], [21, 22, 23, 24]]], dtype=dtype) x = x.reshape(3, 2, 4, 1, 1) + x = x.transpose([2, 1, 0, 3, 4]) # permute axes 0 <=> 2 # reverse dim 2 up to (0:3, none, 0:4) along dim=0 - seq_dim = 2 seq_lengths = np.asarray([3, 0, 4], dtype=np.int64) - truth = np.asarray( + truth_orig = np.asarray( [[[3, 2, 1, 4], [7, 6, 5, 8]], # reverse 0:3 [[9, 10, 11, 12], [13, 14, 15, 16]], # reverse none [[20, 19, 18, 17], [24, 23, 22, 21]]], # reverse 0:4 (all) dtype=dtype) - truth = truth.reshape(3, 2, 4, 1, 1) - self._testBothReverseSequence(x, seq_dim, seq_lengths, truth) + truth_orig = truth_orig.reshape(3, 2, 4, 1, 1) + truth = truth_orig.transpose([2, 1, 0, 3, 4]) # permute axes 0 <=> 2 + + seq_dim = 0 # permute seq_dim and batch_dim (originally 2 and 0, resp.) + batch_dim = 2 + self._testBothReverseSequence(x, batch_dim, seq_dim, seq_lengths, truth) def testFloatBasic(self): self._testBasic(np.float32) @@ -89,22 +92,25 @@ class ReverseSequenceTest(tf.test.TestCase): [[9, 10, 11, 12], [13, 14, 15, 16]], [[17, 18, 19, 20], [21, 22, 23, 24]]], dtype=np.float) x = x.reshape(3, 2, 4, 1, 1) + x = x.transpose([2, 1, 0, 3, 4]) # transpose axes 0 <=> 2 - # reverse dim 2 up to (0:3, none, 0:4) along dim=0 - seq_dim = 2 + # reverse dim 0 up to (0:3, none, 0:4) along dim=2 + seq_dim = 0 + batch_dim = 2 seq_lengths = np.asarray([3, 0, 4], dtype=np.int64) with self.test_session(): input_t = tf.constant(x, shape=x.shape) seq_lengths_t = tf.constant(seq_lengths, shape=seq_lengths.shape) reverse_sequence_out = tf.reverse_sequence(input_t, + batch_dim=batch_dim, seq_dim=seq_dim, seq_lengths=seq_lengths_t) - err = gc.ComputeGradientError(input_t, - x.shape, - reverse_sequence_out, - x.shape, - x_init_value=x) + err = tf.test.compute_gradient_error(input_t, + x.shape, + reverse_sequence_out, + x.shape, + x_init_value=x) print("ReverseSequence gradient error = %g" % err) self.assertLess(err, 1e-8) @@ -123,6 +129,26 @@ class ReverseSequenceTest(tf.test.TestCase): seq_lengths=tf.placeholder(tf.int64, shape=(32,)), seq_dim=3) + # batch_dim out of bounds. + with self.assertRaisesRegexp( + ValueError, "batch_dim must be < input.dims()"): + tf.reverse_sequence( + tf.placeholder(tf.float32, shape=(32, 2, 3)), + seq_lengths=tf.placeholder(tf.int64, shape=(32,)), + seq_dim=0, + batch_dim=3) + + with self.test_session(): + inputs = tf.placeholder(tf.float32, shape=(32, 2, 3)) + seq_lengths = tf.placeholder(tf.int64, shape=(32,)) + output = tf.reverse_sequence( + inputs, + seq_lengths=seq_lengths, + seq_dim=0) # batch_dim default is 0 + with self.assertRaisesOpError("batch_dim == seq_dim"): + output.eval(feed_dict={inputs: np.random.rand(32, 2, 3), + seq_lengths: xrange(32)}) + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/models/rnn/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py similarity index 97% rename from tensorflow/models/rnn/rnn_cell_test.py rename to tensorflow/python/kernel_tests/rnn_cell_test.py index 53d7caf2b7f..fefe4b078dc 100644 --- a/tensorflow/models/rnn/rnn_cell_test.py +++ b/tensorflow/python/kernel_tests/rnn_cell_test.py @@ -26,7 +26,7 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf -from tensorflow.models.rnn import rnn_cell +from tensorflow.python.ops import rnn_cell class RNNCellTest(tf.test.TestCase): @@ -96,9 +96,9 @@ class RNNCellTest(tf.test.TestCase): # Different inputs so different outputs and states for i in range(1, batch_size): self.assertTrue( - float(np.linalg.norm((res[0][0,:] - res[0][i,:]))) > 1e-6) + float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) > 1e-6) self.assertTrue( - float(np.linalg.norm((res[1][0,:] - res[1][i,:]))) > 1e-6) + float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6) def testOutputProjectionWrapper(self): with self.test_session() as sess: diff --git a/tensorflow/models/rnn/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py similarity index 91% rename from tensorflow/models/rnn/rnn_test.py rename to tensorflow/python/kernel_tests/rnn_test.py index 108a615f9af..1ed53c0a1f3 100644 --- a/tensorflow/models/rnn/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -25,11 +25,8 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.models.rnn import rnn -from tensorflow.models.rnn import rnn_cell - -class Plus1RNNCell(rnn_cell.RNNCell): +class Plus1RNNCell(tf.nn.rnn_cell.RNNCell): """RNN Cell generating (output, new_state) = (input + 1, state + 1).""" @property @@ -68,7 +65,7 @@ class RNNTest(tf.test.TestCase): cell = Plus1RNNCell() batch_size = 2 inputs = [tf.placeholder(tf.float32, shape=(batch_size, 5))] * 10 - outputs, states = rnn.rnn(cell, inputs, dtype=tf.float32) + outputs, states = tf.nn.rnn(cell, inputs, dtype=tf.float32) self.assertEqual(len(outputs), len(inputs)) for out, inp in zip(outputs, inputs): self.assertEqual(out.get_shape(), inp.get_shape()) @@ -89,14 +86,15 @@ class RNNTest(tf.test.TestCase): def testDropout(self): cell = Plus1RNNCell() - full_dropout_cell = rnn_cell.DropoutWrapper( + full_dropout_cell = tf.nn.rnn_cell.DropoutWrapper( cell, input_keep_prob=1e-12, seed=0) batch_size = 2 inputs = [tf.placeholder(tf.float32, shape=(batch_size, 5))] * 10 with tf.variable_scope("share_scope"): - outputs, states = rnn.rnn(cell, inputs, dtype=tf.float32) + outputs, states = tf.nn.rnn(cell, inputs, dtype=tf.float32) with tf.variable_scope("drop_scope"): - dropped_outputs, _ = rnn.rnn(full_dropout_cell, inputs, dtype=tf.float32) + dropped_outputs, _ = tf.nn.rnn( + full_dropout_cell, inputs, dtype=tf.float32) self.assertEqual(len(outputs), len(inputs)) for out, inp in zip(outputs, inputs): self.assertEqual(out.get_shape().as_list(), inp.get_shape().as_list()) @@ -120,7 +118,7 @@ class RNNTest(tf.test.TestCase): batch_size = 2 inputs = [tf.placeholder(tf.float32, shape=(batch_size, 5))] * 10 with tf.variable_scope("drop_scope"): - dynamic_outputs, dynamic_states = rnn.rnn( + dynamic_outputs, dynamic_states = tf.nn.rnn( cell, inputs, sequence_length=sequence_length, dtype=tf.float32) self.assertEqual(len(dynamic_outputs), len(inputs)) self.assertEqual(len(dynamic_states), len(inputs)) @@ -158,11 +156,11 @@ class LSTMTest(tf.test.TestCase): batch_size = 2 with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess: initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed) - cell = rnn_cell.LSTMCell( + cell = tf.nn.rnn_cell.LSTMCell( num_units, input_size, initializer=initializer) inputs = 10 * [ tf.placeholder(tf.float32, shape=(batch_size, input_size))] - outputs, _ = rnn.rnn(cell, inputs, dtype=tf.float32) + outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32) self.assertEqual(len(outputs), len(inputs)) for out in outputs: self.assertEqual(out.get_shape().as_list(), [batch_size, num_units]) @@ -177,12 +175,12 @@ class LSTMTest(tf.test.TestCase): batch_size = 2 with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess: initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed) - cell = rnn_cell.LSTMCell( + cell = tf.nn.rnn_cell.LSTMCell( num_units, input_size, use_peepholes=True, cell_clip=0.0, initializer=initializer) inputs = 10 * [ tf.placeholder(tf.float32, shape=(batch_size, input_size))] - outputs, _ = rnn.rnn(cell, inputs, dtype=tf.float32) + outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32) self.assertEqual(len(outputs), len(inputs)) for out in outputs: self.assertEqual(out.get_shape().as_list(), [batch_size, num_units]) @@ -202,12 +200,12 @@ class LSTMTest(tf.test.TestCase): with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess: initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed) state_saver = TestStateSaver(batch_size, 2*num_units) - cell = rnn_cell.LSTMCell( + cell = tf.nn.rnn_cell.LSTMCell( num_units, input_size, use_peepholes=False, initializer=initializer) inputs = 10 * [ tf.placeholder(tf.float32, shape=(batch_size, input_size))] with tf.variable_scope("share_scope"): - outputs, states = rnn.state_saving_rnn( + outputs, states = tf.nn.state_saving_rnn( cell, inputs, state_saver=state_saver, state_name="save_lstm") self.assertEqual(len(outputs), len(inputs)) for out in outputs: @@ -229,10 +227,10 @@ class LSTMTest(tf.test.TestCase): initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed) inputs = 10 * [ tf.placeholder(tf.float32, shape=(None, input_size))] - cell = rnn_cell.LSTMCell( + cell = tf.nn.rnn_cell.LSTMCell( num_units, input_size, use_peepholes=True, num_proj=num_proj, initializer=initializer) - outputs, _ = rnn.rnn(cell, inputs, dtype=tf.float32) + outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32) self.assertEqual(len(outputs), len(inputs)) tf.initialize_all_variables().run() @@ -252,7 +250,7 @@ class LSTMTest(tf.test.TestCase): inputs = 10 * [ tf.placeholder(tf.float32, shape=(None, input_size))] - cell = rnn_cell.LSTMCell( + cell = tf.nn.rnn_cell.LSTMCell( num_units, input_size=input_size, use_peepholes=True, @@ -261,7 +259,7 @@ class LSTMTest(tf.test.TestCase): num_proj_shards=num_proj_shards, initializer=initializer) - outputs, _ = rnn.rnn(cell, inputs, dtype=tf.float32) + outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32) self.assertEqual(len(outputs), len(inputs)) @@ -280,7 +278,7 @@ class LSTMTest(tf.test.TestCase): initializer = tf.random_uniform_initializer(-1, 1, seed=self._seed) inputs = 10 * [tf.placeholder(tf.float64)] - cell = rnn_cell.LSTMCell( + cell = tf.nn.rnn_cell.LSTMCell( num_units, input_size=input_size, use_peepholes=True, @@ -289,7 +287,7 @@ class LSTMTest(tf.test.TestCase): num_proj_shards=num_proj_shards, initializer=initializer) - outputs, _ = rnn.rnn( + outputs, _ = tf.nn.rnn( cell, inputs, initial_state=cell.zero_state(batch_size, tf.float64)) self.assertEqual(len(outputs), len(inputs)) @@ -311,7 +309,7 @@ class LSTMTest(tf.test.TestCase): inputs = 10 * [tf.placeholder(tf.float32)] initializer = tf.constant_initializer(0.001) - cell_noshard = rnn_cell.LSTMCell( + cell_noshard = tf.nn.rnn_cell.LSTMCell( num_units, input_size, num_proj=num_proj, use_peepholes=True, @@ -319,15 +317,15 @@ class LSTMTest(tf.test.TestCase): num_unit_shards=num_unit_shards, num_proj_shards=num_proj_shards) - cell_shard = rnn_cell.LSTMCell( + cell_shard = tf.nn.rnn_cell.LSTMCell( num_units, input_size, use_peepholes=True, initializer=initializer, num_proj=num_proj) with tf.variable_scope("noshard_scope"): - outputs_noshard, states_noshard = rnn.rnn( + outputs_noshard, states_noshard = tf.nn.rnn( cell_noshard, inputs, dtype=tf.float32) with tf.variable_scope("shard_scope"): - outputs_shard, states_shard = rnn.rnn( + outputs_shard, states_shard = tf.nn.rnn( cell_shard, inputs, dtype=tf.float32) self.assertEqual(len(outputs_noshard), len(inputs)) @@ -362,7 +360,7 @@ class LSTMTest(tf.test.TestCase): initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed) inputs = 10 * [tf.placeholder(tf.float64)] - cell = rnn_cell.LSTMCell( + cell = tf.nn.rnn_cell.LSTMCell( num_units, input_size=input_size, use_peepholes=True, @@ -370,9 +368,9 @@ class LSTMTest(tf.test.TestCase): num_unit_shards=num_unit_shards, num_proj_shards=num_proj_shards, initializer=initializer) - dropout_cell = rnn_cell.DropoutWrapper(cell, 0.5, seed=0) + dropout_cell = tf.nn.rnn_cell.DropoutWrapper(cell, 0.5, seed=0) - outputs, states = rnn.rnn( + outputs, states = tf.nn.rnn( dropout_cell, inputs, sequence_length=sequence_length, initial_state=cell.zero_state(batch_size, tf.float64)) @@ -398,16 +396,16 @@ class LSTMTest(tf.test.TestCase): initializer = tf.random_uniform_initializer(-1, 1, seed=self._seed) inputs = 10 * [ tf.placeholder(tf.float32, shape=(None, input_size))] - cell = rnn_cell.LSTMCell( + cell = tf.nn.rnn_cell.LSTMCell( num_units, input_size, use_peepholes=True, num_proj=num_proj, initializer=initializer) with tf.variable_scope("share_scope"): - outputs0, _ = rnn.rnn(cell, inputs, dtype=tf.float32) + outputs0, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32) with tf.variable_scope("share_scope", reuse=True): - outputs1, _ = rnn.rnn(cell, inputs, dtype=tf.float32) + outputs1, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32) with tf.variable_scope("diff_scope"): - outputs2, _ = rnn.rnn(cell, inputs, dtype=tf.float32) + outputs2, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32) tf.initialize_all_variables().run() input_value = np.random.randn(batch_size, input_size) @@ -433,16 +431,16 @@ class LSTMTest(tf.test.TestCase): initializer = tf.random_uniform_initializer(-1, 1, seed=self._seed) inputs = 10 * [ tf.placeholder(tf.float32, shape=(None, input_size))] - cell = rnn_cell.LSTMCell( + cell = tf.nn.rnn_cell.LSTMCell( num_units, input_size, use_peepholes=True, num_proj=num_proj, initializer=initializer) with tf.name_scope("scope0"): with tf.variable_scope("share_scope"): - outputs0, _ = rnn.rnn(cell, inputs, dtype=tf.float32) + outputs0, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32) with tf.name_scope("scope1"): with tf.variable_scope("share_scope", reuse=True): - outputs1, _ = rnn.rnn(cell, inputs, dtype=tf.float32) + outputs1, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32) tf.initialize_all_variables().run() input_value = np.random.randn(batch_size, input_size) diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py index c84921f21e1..adfe42009a0 100644 --- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py @@ -23,8 +23,6 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests import gradient_checker - class SegmentReductionHelper(tf.test.TestCase): @@ -127,8 +125,12 @@ class SegmentReductionOpTest(SegmentReductionHelper): with self.test_session(): tf_x, np_x = self._input(shape, dtype=tf.float64) s = tf_op(data=tf_x, segment_ids=indices) - jacob_t, jacob_n = gradient_checker.ComputeGradient( - tf_x, shape, s, [3, 4], x_init_value=np_x.astype(np.double), + jacob_t, jacob_n = tf.test.compute_gradient( + tf_x, + shape, + s, + [3, 4], + x_init_value=np_x.astype(np.double), delta=1) self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3) @@ -170,7 +172,7 @@ class UnsortedSegmentSumTest(SegmentReductionHelper): s = tf.unsorted_segment_sum(data=tf_x, segment_ids=indices, num_segments=num_segments) - jacob_t, jacob_n = gradient_checker.ComputeGradient( + jacob_t, jacob_n = tf.test.compute_gradient( tf_x, shape, s, @@ -196,14 +198,20 @@ class UnsortedSegmentSumTest(SegmentReductionHelper): unsorted_s = tf.unsorted_segment_sum(data=tf_x, segment_ids=indices, num_segments=num_segments) - unsorted_jacob_t, unsorted_jacob_n = gradient_checker.ComputeGradient( - tf_x, shape, unsorted_s, [num_segments, num_cols], + (unsorted_jacob_t, unsorted_jacob_n) = tf.test.compute_gradient( + tf_x, + shape, + unsorted_s, + [num_segments, num_cols], x_init_value=np_x.astype(np.double), delta=1) # Results from SegmentSum sorted_s = tf.segment_sum(data=tf_x, segment_ids=indices) - sorted_jacob_t, sorted_jacob_n = gradient_checker.ComputeGradient( - tf_x, shape, sorted_s, [num_segments, num_cols], + sorted_jacob_t, sorted_jacob_n = tf.test.compute_gradient( + tf_x, + shape, + sorted_s, + [num_segments, num_cols], x_init_value=np_x.astype(np.double), delta=1) self.assertAllClose(unsorted_jacob_t, sorted_jacob_t, rtol=1e-3, atol=1e-3) @@ -277,8 +285,12 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper): tf_indices, _, tf_x, np_x = self._sparse_input( shape, num_indices, dtype=tf.float64) s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices) - jacob_t, jacob_n = gradient_checker.ComputeGradient( - tf_x, shape, s, [3, 4], x_init_value=np_x.astype(np.double), + jacob_t, jacob_n = tf.test.compute_gradient( + tf_x, + shape, + s, + [3, 4], + x_init_value=np_x.astype(np.double), delta=1) self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3) diff --git a/tensorflow/models/rnn/seq2seq_test.py b/tensorflow/python/kernel_tests/seq2seq_test.py similarity index 74% rename from tensorflow/models/rnn/seq2seq_test.py rename to tensorflow/python/kernel_tests/seq2seq_test.py index 12d22630f0f..5ee2845780d 100644 --- a/tensorflow/models/rnn/seq2seq_test.py +++ b/tensorflow/python/kernel_tests/seq2seq_test.py @@ -21,16 +21,13 @@ from __future__ import print_function import math import random +# pylint: disable=g-bad-import-order,unused-import import tensorflow.python.platform import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf -from tensorflow.models.rnn import rnn -from tensorflow.models.rnn import rnn_cell -from tensorflow.models.rnn import seq2seq - class Seq2SeqTest(tf.test.TestCase): @@ -38,10 +35,12 @@ class Seq2SeqTest(tf.test.TestCase): with self.test_session() as sess: with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)] - _, enc_states = rnn.rnn(rnn_cell.GRUCell(2), inp, dtype=tf.float32) + _, enc_states = tf.nn.rnn( + tf.nn.rnn_cell.GRUCell(2), inp, dtype=tf.float32) dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)] - cell = rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4) - dec, mem = seq2seq.rnn_decoder(dec_inp, enc_states[-1], cell) + cell = tf.nn.rnn_cell.OutputProjectionWrapper( + tf.nn.rnn_cell.GRUCell(2), 4) + dec, mem = tf.nn.seq2seq.rnn_decoder(dec_inp, enc_states[-1], cell) sess.run([tf.initialize_all_variables()]) res = sess.run(dec) self.assertEqual(len(res), 3) @@ -56,8 +55,9 @@ class Seq2SeqTest(tf.test.TestCase): with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)] dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)] - cell = rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4) - dec, mem = seq2seq.basic_rnn_seq2seq(inp, dec_inp, cell) + cell = tf.nn.rnn_cell.OutputProjectionWrapper( + tf.nn.rnn_cell.GRUCell(2), 4) + dec, mem = tf.nn.seq2seq.basic_rnn_seq2seq(inp, dec_inp, cell) sess.run([tf.initialize_all_variables()]) res = sess.run(dec) self.assertEqual(len(res), 3) @@ -72,8 +72,9 @@ class Seq2SeqTest(tf.test.TestCase): with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)] dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)] - cell = rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4) - dec, mem = seq2seq.tied_rnn_seq2seq(inp, dec_inp, cell) + cell = tf.nn.rnn_cell.OutputProjectionWrapper( + tf.nn.rnn_cell.GRUCell(2), 4) + dec, mem = tf.nn.seq2seq.tied_rnn_seq2seq(inp, dec_inp, cell) sess.run([tf.initialize_all_variables()]) res = sess.run(dec) self.assertEqual(len(res), 3) @@ -87,11 +88,11 @@ class Seq2SeqTest(tf.test.TestCase): with self.test_session() as sess: with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)] - cell = rnn_cell.BasicLSTMCell(2) - _, enc_states = rnn.rnn(cell, inp, dtype=tf.float32) + cell = tf.nn.rnn_cell.BasicLSTMCell(2) + _, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32) dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)] - dec, mem = seq2seq.embedding_rnn_decoder(dec_inp, enc_states[-1], - cell, 4) + dec, mem = tf.nn.seq2seq.embedding_rnn_decoder(dec_inp, enc_states[-1], + cell, 4) sess.run([tf.initialize_all_variables()]) res = sess.run(dec) self.assertEqual(len(res), 3) @@ -106,8 +107,9 @@ class Seq2SeqTest(tf.test.TestCase): with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): enc_inp = [tf.constant(1, tf.int32, shape=[2]) for i in xrange(2)] dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)] - cell = rnn_cell.BasicLSTMCell(2) - dec, mem = seq2seq.embedding_rnn_seq2seq(enc_inp, dec_inp, cell, 2, 5) + cell = tf.nn.rnn_cell.BasicLSTMCell(2) + dec, mem = tf.nn.seq2seq.embedding_rnn_seq2seq( + enc_inp, dec_inp, cell, 2, 5) sess.run([tf.variables.initialize_all_variables()]) res = sess.run(dec) self.assertEqual(len(res), 3) @@ -121,7 +123,7 @@ class Seq2SeqTest(tf.test.TestCase): w = tf.get_variable("proj_w", [2, 5]) b = tf.get_variable("proj_b", [5]) with tf.variable_scope("proj_seq2seq"): - dec, _ = seq2seq.embedding_rnn_seq2seq( + dec, _ = tf.nn.seq2seq.embedding_rnn_seq2seq( enc_inp, dec_inp, cell, 2, 5, output_projection=(w, b)) sess.run([tf.variables.initialize_all_variables()]) res = sess.run(dec) @@ -131,12 +133,15 @@ class Seq2SeqTest(tf.test.TestCase): # Test that previous-feeding model ignores inputs after the first. dec_inp2 = [tf.constant(0, tf.int32, shape=[2]) for _ in xrange(3)] tf.get_variable_scope().reuse_variables() - d1, _ = seq2seq.embedding_rnn_seq2seq(enc_inp, dec_inp, cell, 2, 5, - feed_previous=True) - d2, _ = seq2seq.embedding_rnn_seq2seq(enc_inp, dec_inp2, cell, 2, 5, - feed_previous=True) - d3, _ = seq2seq.embedding_rnn_seq2seq(enc_inp, dec_inp2, cell, 2, 5, - feed_previous=tf.constant(True)) + d1, _ = tf.nn.seq2seq.embedding_rnn_seq2seq( + enc_inp, dec_inp, cell, 2, 5, + feed_previous=True) + d2, _ = tf.nn.seq2seq.embedding_rnn_seq2seq( + enc_inp, dec_inp2, cell, 2, 5, + feed_previous=True) + d3, _ = tf.nn.seq2seq.embedding_rnn_seq2seq( + enc_inp, dec_inp2, cell, 2, 5, + feed_previous=tf.constant(True)) res1 = sess.run(d1) res2 = sess.run(d2) res3 = sess.run(d3) @@ -148,8 +153,9 @@ class Seq2SeqTest(tf.test.TestCase): with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): enc_inp = [tf.constant(1, tf.int32, shape=[2]) for i in xrange(2)] dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)] - cell = rnn_cell.BasicLSTMCell(2) - dec, mem = seq2seq.embedding_tied_rnn_seq2seq(enc_inp, dec_inp, cell, 5) + cell = tf.nn.rnn_cell.BasicLSTMCell(2) + dec, mem = tf.nn.seq2seq.embedding_tied_rnn_seq2seq( + enc_inp, dec_inp, cell, 5) sess.run([tf.variables.initialize_all_variables()]) res = sess.run(dec) self.assertEqual(len(res), 3) @@ -163,7 +169,7 @@ class Seq2SeqTest(tf.test.TestCase): w = tf.get_variable("proj_w", [2, 5]) b = tf.get_variable("proj_b", [5]) with tf.variable_scope("proj_seq2seq"): - dec, _ = seq2seq.embedding_tied_rnn_seq2seq( + dec, _ = tf.nn.seq2seq.embedding_tied_rnn_seq2seq( enc_inp, dec_inp, cell, 5, output_projection=(w, b)) sess.run([tf.variables.initialize_all_variables()]) res = sess.run(dec) @@ -173,11 +179,13 @@ class Seq2SeqTest(tf.test.TestCase): # Test that previous-feeding model ignores inputs after the first. dec_inp2 = [tf.constant(0, tf.int32, shape=[2]) for _ in xrange(3)] tf.get_variable_scope().reuse_variables() - d1, _ = seq2seq.embedding_tied_rnn_seq2seq(enc_inp, dec_inp, cell, 5, - feed_previous=True) - d2, _ = seq2seq.embedding_tied_rnn_seq2seq(enc_inp, dec_inp2, cell, 5, - feed_previous=True) - d3, _ = seq2seq.embedding_tied_rnn_seq2seq( + d1, _ = tf.nn.seq2seq.embedding_tied_rnn_seq2seq( + enc_inp, dec_inp, cell, 5, + feed_previous=True) + d2, _ = tf.nn.seq2seq.embedding_tied_rnn_seq2seq( + enc_inp, dec_inp2, cell, 5, + feed_previous=True) + d3, _ = tf.nn.seq2seq.embedding_tied_rnn_seq2seq( enc_inp, dec_inp2, cell, 5, feed_previous=tf.constant(True)) res1 = sess.run(d1) res2 = sess.run(d2) @@ -188,14 +196,15 @@ class Seq2SeqTest(tf.test.TestCase): def testAttentionDecoder1(self): with self.test_session() as sess: with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): - cell = rnn_cell.GRUCell(2) + cell = tf.nn.rnn_cell.GRUCell(2) inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)] - enc_outputs, enc_states = rnn.rnn(cell, inp, dtype=tf.float32) + enc_outputs, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32) attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs]) dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)] - dec, mem = seq2seq.attention_decoder(dec_inp, enc_states[-1], - attn_states, cell, output_size=4) + dec, mem = tf.nn.seq2seq.attention_decoder( + dec_inp, enc_states[-1], + attn_states, cell, output_size=4) sess.run([tf.initialize_all_variables()]) res = sess.run(dec) self.assertEqual(len(res), 3) @@ -208,15 +217,16 @@ class Seq2SeqTest(tf.test.TestCase): def testAttentionDecoder2(self): with self.test_session() as sess: with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): - cell = rnn_cell.GRUCell(2) + cell = tf.nn.rnn_cell.GRUCell(2) inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)] - enc_outputs, enc_states = rnn.rnn(cell, inp, dtype=tf.float32) + enc_outputs, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32) attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs]) dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)] - dec, mem = seq2seq.attention_decoder(dec_inp, enc_states[-1], - attn_states, cell, output_size=4, - num_heads=2) + dec, mem = tf.nn.seq2seq.attention_decoder( + dec_inp, enc_states[-1], + attn_states, cell, output_size=4, + num_heads=2) sess.run([tf.initialize_all_variables()]) res = sess.run(dec) self.assertEqual(len(res), 3) @@ -230,14 +240,15 @@ class Seq2SeqTest(tf.test.TestCase): with self.test_session() as sess: with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)] - cell = rnn_cell.GRUCell(2) - enc_outputs, enc_states = rnn.rnn(cell, inp, dtype=tf.float32) + cell = tf.nn.rnn_cell.GRUCell(2) + enc_outputs, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32) attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs]) dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)] - dec, mem = seq2seq.embedding_attention_decoder(dec_inp, enc_states[-1], - attn_states, cell, 4, - output_size=3) + dec, mem = tf.nn.seq2seq.embedding_attention_decoder( + dec_inp, enc_states[-1], + attn_states, cell, 4, + output_size=3) sess.run([tf.initialize_all_variables()]) res = sess.run(dec) self.assertEqual(len(res), 3) @@ -252,8 +263,8 @@ class Seq2SeqTest(tf.test.TestCase): with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): enc_inp = [tf.constant(1, tf.int32, shape=[2]) for i in xrange(2)] dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)] - cell = rnn_cell.BasicLSTMCell(2) - dec, mem = seq2seq.embedding_attention_seq2seq( + cell = tf.nn.rnn_cell.BasicLSTMCell(2) + dec, mem = tf.nn.seq2seq.embedding_attention_seq2seq( enc_inp, dec_inp, cell, 2, 5) sess.run([tf.initialize_all_variables()]) res = sess.run(dec) @@ -268,7 +279,7 @@ class Seq2SeqTest(tf.test.TestCase): w = tf.get_variable("proj_w", [2, 5]) b = tf.get_variable("proj_b", [5]) with tf.variable_scope("proj_seq2seq"): - dec, _ = seq2seq.embedding_attention_seq2seq( + dec, _ = tf.nn.seq2seq.embedding_attention_seq2seq( enc_inp, dec_inp, cell, 2, 5, output_projection=(w, b)) sess.run([tf.variables.initialize_all_variables()]) res = sess.run(dec) @@ -278,11 +289,11 @@ class Seq2SeqTest(tf.test.TestCase): # Test that previous-feeding model ignores inputs after the first. dec_inp2 = [tf.constant(0, tf.int32, shape=[2]) for _ in xrange(3)] tf.get_variable_scope().reuse_variables() - d1, _ = seq2seq.embedding_attention_seq2seq( + d1, _ = tf.nn.seq2seq.embedding_attention_seq2seq( enc_inp, dec_inp, cell, 2, 5, feed_previous=True) - d2, _ = seq2seq.embedding_attention_seq2seq( + d2, _ = tf.nn.seq2seq.embedding_attention_seq2seq( enc_inp, dec_inp2, cell, 2, 5, feed_previous=True) - d3, _ = seq2seq.embedding_attention_seq2seq( + d3, _ = tf.nn.seq2seq.embedding_attention_seq2seq( enc_inp, dec_inp2, cell, 2, 5, feed_previous=tf.constant(True)) res1 = sess.run(d1) res2 = sess.run(d2) @@ -297,21 +308,21 @@ class Seq2SeqTest(tf.test.TestCase): targets = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)] weights = [tf.constant(1.0, shape=[2]) for i in xrange(3)] - average_loss_per_example = seq2seq.sequence_loss( + average_loss_per_example = tf.nn.seq2seq.sequence_loss( logits, targets, weights, output_classes, average_across_timesteps=True, average_across_batch=True) res = sess.run(average_loss_per_example) self.assertAllClose(res, 1.60944) - average_loss_per_sequence = seq2seq.sequence_loss( + average_loss_per_sequence = tf.nn.seq2seq.sequence_loss( logits, targets, weights, output_classes, average_across_timesteps=False, average_across_batch=True) res = sess.run(average_loss_per_sequence) self.assertAllClose(res, 4.828314) - total_loss = seq2seq.sequence_loss( + total_loss = tf.nn.seq2seq.sequence_loss( logits, targets, weights, output_classes, average_across_timesteps=False, average_across_batch=False) @@ -326,13 +337,13 @@ class Seq2SeqTest(tf.test.TestCase): targets = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)] weights = [tf.constant(1.0, shape=[2]) for i in xrange(3)] - average_loss_per_example = seq2seq.sequence_loss_by_example( + average_loss_per_example = tf.nn.seq2seq.sequence_loss_by_example( logits, targets, weights, output_classes, average_across_timesteps=True) res = sess.run(average_loss_per_example) self.assertAllClose(res, np.asarray([1.609438, 1.609438])) - loss_per_sequence = seq2seq.sequence_loss_by_example( + loss_per_sequence = tf.nn.seq2seq.sequence_loss_by_example( logits, targets, weights, output_classes, average_across_timesteps=False) res = sess.run(loss_per_sequence) @@ -343,26 +354,30 @@ class Seq2SeqTest(tf.test.TestCase): # We learn to copy 10 symbols in 2 buckets: length 4 and length 8. classes = 10 buckets = [(4, 4), (8, 8)] - # We use sampled softmax so we keep output projection separate. - w = tf.get_variable("proj_w", [24, classes]) - w_t = tf.transpose(w) - b = tf.get_variable("proj_b", [classes]) - # Here comes a sample Seq2Seq model using GRU cells. - def SampleGRUSeq2Seq(enc_inp, dec_inp, weights): - """Example sequence-to-sequence model that uses GRU cells.""" - def GRUSeq2Seq(enc_inp, dec_inp): - cell = rnn_cell.MultiRNNCell([rnn_cell.GRUCell(24)] * 2) - return seq2seq.embedding_attention_seq2seq( - enc_inp, dec_inp, cell, classes, classes, output_projection=(w, b)) - targets = [dec_inp[i+1] for i in xrange(len(dec_inp) - 1)] + [0] - def SampledLoss(inputs, labels): - labels = tf.reshape(labels, [-1, 1]) - return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, 8, classes) - return seq2seq.model_with_buckets(enc_inp, dec_inp, targets, weights, - buckets, classes, GRUSeq2Seq, - softmax_loss_function=SampledLoss) - # Now we construct the copy model. + with self.test_session() as sess: + # We use sampled softmax so we keep output projection separate. + w = tf.get_variable("proj_w", [24, classes]) + w_t = tf.transpose(w) + b = tf.get_variable("proj_b", [classes]) + # Here comes a sample Seq2Seq model using GRU cells. + def SampleGRUSeq2Seq(enc_inp, dec_inp, weights): + """Example sequence-to-sequence model that uses GRU cells.""" + def GRUSeq2Seq(enc_inp, dec_inp): + cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.GRUCell(24)] * 2) + return tf.nn.seq2seq.embedding_attention_seq2seq( + enc_inp, dec_inp, cell, classes, classes, + output_projection=(w, b)) + targets = [dec_inp[i+1] for i in xrange(len(dec_inp) - 1)] + [0] + def SampledLoss(inputs, labels): + labels = tf.reshape(labels, [-1, 1]) + return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, 8, classes) + return tf.nn.seq2seq.model_with_buckets( + enc_inp, dec_inp, targets, weights, + buckets, classes, GRUSeq2Seq, + softmax_loss_function=SampledLoss) + + # Now we construct the copy model. tf.set_random_seed(111) batch_size = 32 inp = [tf.placeholder(tf.int32, shape=[None]) for _ in xrange(8)] diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py index b2ff0b92b43..2621ad9dec2 100644 --- a/tensorflow/python/kernel_tests/shape_ops_test.py +++ b/tensorflow/python/kernel_tests/shape_ops_test.py @@ -24,8 +24,6 @@ import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests import gradient_checker as gc - class ShapeOpsTest(tf.test.TestCase): @@ -119,7 +117,7 @@ class ShapeOpsTest(tf.test.TestCase): dtype=tf.float32) squeezed = tf.expand_dims(inp, 1) - err = gc.ComputeGradientError(inp, [4, 2], squeezed, [4, 1, 2]) + err = tf.test.compute_gradient_error(inp, [4, 2], squeezed, [4, 1, 2]) self.assertLess(err, 1e-3) def testExpandDimsScalar(self): @@ -202,7 +200,7 @@ class ShapeOpsTest(tf.test.TestCase): a = tf.reshape(inp, [4, 1, 2]) squeezed = tf.squeeze(a, []) - err = gc.ComputeGradientError(a, [4, 1, 2], squeezed, [4, 2]) + err = tf.test.compute_gradient_error(a, [4, 1, 2], squeezed, [4, 2]) self.assertLess(err, 1e-3) def testSqueezeGradientWithSqueezeDims(self): @@ -211,7 +209,7 @@ class ShapeOpsTest(tf.test.TestCase): a = tf.reshape(inp, [4, 1, 2, 1]) squeezed = tf.squeeze(a, [1]) - err = gc.ComputeGradientError(a, [4, 1, 2, 1], squeezed, [4, 2, 1]) + err = tf.test.compute_gradient_error(a, [4, 1, 2, 1], squeezed, [4, 2, 1]) self.assertLess(err, 1e-3) @@ -366,8 +364,11 @@ class TileTest(tf.test.TestCase): shape=input_shape, dtype=tf.float64) tiled = tf.tile(a, multiples) grad_shape = list(np.array(multiples) * np.array(inp.shape)) - err = gc.ComputeGradientError(a, list(input_shape), tiled, grad_shape, - x_init_value=inp) + err = tf.test.compute_gradient_error(a, + list(input_shape), + tiled, + grad_shape, + x_init_value=inp) print("tile(float) error = ", err) self.assertLess(err, 1e-3) @@ -382,7 +383,7 @@ class TileTest(tf.test.TestCase): a = tf.constant([float(x) for x in inp.flatten()], shape=[4, 2], dtype=tf.float32) tiled = tf.tile(a, [1, 2]) - err = gc.ComputeGradientError(a, [4, 2], tiled, [4, 4]) + err = tf.test.compute_gradient_error(a, [4, 2], tiled, [4, 4]) self.assertLess(err, 1e-3) def testShapeFunctionEdgeCases(self): diff --git a/tensorflow/python/kernel_tests/softplus_op_test.py b/tensorflow/python/kernel_tests/softplus_op_test.py index e79fc1ca20f..3575f7ab7c3 100644 --- a/tensorflow/python/kernel_tests/softplus_op_test.py +++ b/tensorflow/python/kernel_tests/softplus_op_test.py @@ -23,8 +23,6 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests import gradient_checker as gc - class SoftplusTest(tf.test.TestCase): @@ -57,7 +55,11 @@ class SoftplusTest(tf.test.TestCase): x_init = np.asarray( [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], dtype=np.float32, order="F") - err = gc.ComputeGradientError(x, [2, 5], y, [2, 5], x_init_value=x_init) + err = tf.test.compute_gradient_error(x, + [2, 5], + y, + [2, 5], + x_init_value=x_init) print("softplus (float) gradient err = ", err) self.assertLess(err, 1e-4) diff --git a/tensorflow/python/kernel_tests/softsign_op_test.py b/tensorflow/python/kernel_tests/softsign_op_test.py new file mode 100644 index 00000000000..fd8431c7c76 --- /dev/null +++ b/tensorflow/python/kernel_tests/softsign_op_test.py @@ -0,0 +1,68 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for Softsign and SoftsignGrad.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + + +class SoftsignTest(tf.test.TestCase): + + def _npSoftsign(self, np_features): + return np_features / (1 + np.abs(np_features)) + + def _testSoftsign(self, np_features, use_gpu=False): + np_softsign = self._npSoftsign(np_features) + with self.test_session(use_gpu=use_gpu): + softsign = tf.nn.softsign(np_features) + tf_softsign = softsign.eval() + self.assertAllClose(np_softsign, tf_softsign) + self.assertShapeEqual(np_softsign, softsign) + + def testNumbers(self): + for t in [np.float, np.double]: + self._testSoftsign( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), + use_gpu=False) + self._testSoftsign( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), + use_gpu=True) + + def testGradient(self): + with self.test_session(): + x = tf.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], name="x") + y = tf.nn.softsign(x, name="softsign") + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float32, order="F") + err = tf.test.compute_gradient_error(x, + [2, 5], + y, + [2, 5], + x_init_value=x_init) + print("softsign (float) gradient err = ", err) + self.assertLess(err, 1e-4) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py index 8f0c60c4553..89e9fda178b 100644 --- a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py @@ -23,8 +23,6 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests import gradient_checker as gc - def RandMatrix(rows, cols, tr): if tr: @@ -96,8 +94,10 @@ class MatMulGradientTest(tf.test.TestCase): transpose_b=tr_b, a_is_sparse=sp_a, b_is_sparse=sp_b) - err = (gc.ComputeGradientError(a, [2, 3] if tr_a else [3, 2], m, [3, 4]) + - gc.ComputeGradientError(b, [4, 2] if tr_b else [2, 4], m, [3, 4])) + err = (tf.test.compute_gradient_error(a, [2, 3] + if tr_a else [3, 2], m, [3, 4]) + + tf.test.compute_gradient_error(b, [4, 2] + if tr_b else [2, 4], m, [3, 4])) print("sparse_matmul gradient err = ", err) self.assertLess(err, 1e-3) diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py index 8c5ff7bd7e6..fa38152a865 100644 --- a/tensorflow/python/kernel_tests/transpose_op_test.py +++ b/tensorflow/python/kernel_tests/transpose_op_test.py @@ -24,8 +24,6 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests.gradient_checker import ComputeGradient - class TransposeTest(tf.test.TestCase): @@ -48,10 +46,10 @@ class TransposeTest(tf.test.TestCase): xs = list(np.shape(x)) ys = list(np.shape(tf_ans)) if x.dtype == np.float32: - jacob_t, jacob_n = ComputeGradient(inx, xs, y, ys, x, 1e-2) + jacob_t, jacob_n = tf.test.compute_gradient(inx, xs, y, ys, x, 1e-2) self.assertAllClose(jacob_t, jacob_n, 1e-3, 1e-3) elif x.dtype == np.float64: - jacob_t, jacob_n = ComputeGradient(inx, xs, y, ys, x, 1e-2) + jacob_t, jacob_n = tf.test.compute_gradient(inx, xs, y, ys, x, 1e-2) self.assertAllClose(jacob_t, jacob_n, 1e-6, 1e-6) return tf_ans, jacob_t @@ -70,10 +68,10 @@ class TransposeTest(tf.test.TestCase): xs = list(np.shape(x)) ys = list(np.shape(tf_ans)) if x.dtype == np.float32: - jacob_t, jacob_n = ComputeGradient(inx, xs, y, ys, x, 1e-2) + jacob_t, jacob_n = tf.test.compute_gradient(inx, xs, y, ys, x, 1e-2) self.assertAllClose(jacob_t, jacob_n, 1e-3, 1e-3) elif x.dtype == np.float64: - jacob_t, jacob_n = ComputeGradient(inx, xs, y, ys, x, 1e-2) + jacob_t, jacob_n = tf.test.compute_gradient(inx, xs, y, ys, x, 1e-2) self.assertAllClose(jacob_t, jacob_n, 1e-6, 1e-6) return tf_ans, jacob_t diff --git a/tensorflow/python/kernel_tests/unpack_op_test.py b/tensorflow/python/kernel_tests/unpack_op_test.py index 308b219f318..47ed9e617cb 100644 --- a/tensorflow/python/kernel_tests/unpack_op_test.py +++ b/tensorflow/python/kernel_tests/unpack_op_test.py @@ -24,8 +24,6 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf -from tensorflow.python.kernel_tests import gradient_checker - class UnpackOpTest(tf.test.TestCase): @@ -53,8 +51,7 @@ class UnpackOpTest(tf.test.TestCase): with self.test_session(use_gpu=use_gpu): x = tf.constant(data) cs = tf.unpack(x, num=shape[0]) - err = gradient_checker.ComputeGradientError(x, shape, cs[i], - shapes[i]) + err = tf.test.compute_gradient_error(x, shape, cs[i], shapes[i]) self.assertLess(err, 1e-6) def testInferNum(self): diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py index d3b01640529..39ec5f10a63 100644 --- a/tensorflow/python/kernel_tests/xent_op_test.py +++ b/tensorflow/python/kernel_tests/xent_op_test.py @@ -23,8 +23,6 @@ import tensorflow.python.platform import numpy as np import tensorflow as tf -from tensorflow.python.kernel_tests import gradient_checker as gc - class XentTest(tf.test.TestCase): @@ -120,7 +118,7 @@ class XentTest(tf.test.TestCase): 0.1, 0.8, 2.7, 6.4], shape=[3, 4], dtype=tf.float64, name="f") x = tf.nn.softmax_cross_entropy_with_logits(f, l, name="xent") - err = gc.ComputeGradientError(f, [3, 4], x, [3]) + err = tf.test.compute_gradient_error(f, [3, 4], x, [3]) print("cross entropy gradient err = ", err) self.assertLess(err, 5e-8) diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc index 63c1460ac07..956b8719221 100644 --- a/tensorflow/python/lib/io/py_record_writer.cc +++ b/tensorflow/python/lib/io/py_record_writer.cc @@ -42,7 +42,7 @@ PyRecordWriter::~PyRecordWriter() { delete file_; } -bool PyRecordWriter::WriteRecord(::tensorflow::StringPiece record) { +bool PyRecordWriter::WriteRecord(tensorflow::StringPiece record) { if (writer_ == nullptr) return false; Status s = writer_->WriteRecord(record); return s.ok(); diff --git a/tensorflow/python/lib/io/py_record_writer.h b/tensorflow/python/lib/io/py_record_writer.h index 99720f3b8ee..637ee1b8bb2 100644 --- a/tensorflow/python/lib/io/py_record_writer.h +++ b/tensorflow/python/lib/io/py_record_writer.h @@ -36,7 +36,7 @@ class PyRecordWriter { static PyRecordWriter* New(const string& filename); ~PyRecordWriter(); - bool WriteRecord(::tensorflow::StringPiece record); + bool WriteRecord(tensorflow::StringPiece record); void Close(); private: diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index ce171ed9db0..8d288525dc7 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -20,16 +20,17 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op -from tensorflow.python.ops import math_ops from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops @ops.RegisterGradient("Pack") def _PackGrad(op, grad): """Gradient for pack op.""" - return array_ops.unpack(grad, num=op.get_attr('N')) + return array_ops.unpack(grad, num=op.get_attr("N")) @ops.RegisterGradient("Unpack") @@ -41,28 +42,82 @@ def _UnpackGrad(_, *grads): @ops.RegisterGradient("Concat") def _ConcatGrad(op, grad): """Gradient for concat op.""" - assert isinstance(grad, ops.Tensor) + + def _CreateDenseMaskAndBegin(sizes, concat_dim): + """Create variables for iteratively slicing a dense gradients tensor.""" + # Since shape is 1-D, shape_of_shape = [rank-of-inputs] + shape_of_shape = array_ops.shape(sizes[0]) + # Make a vector of length equal to the input's dimensions, + # with 0's everywhere and 1 in the concat dim position. + # Note: Can't use sparse_to_dense since it isn't GPU-capable (for now) + mask = array_ops.concat(0, + [array_ops.fill( + array_ops.expand_dims(concat_dim, 0), 0), + [1], + array_ops.fill( + shape_of_shape - concat_dim - 1, 0)]) + begin = array_ops.fill(shape_of_shape, 0) + return mask, begin + # Degenerate concatenation, just return grad. if len(op.inputs) == 2: return [None, grad] - # Get the inputs' tensor shapes - sizes = [array_ops.shape(x) for x in op.inputs[1:]] + concat_dim = op.inputs[0] - # Since shape is 1-D, shape_of_shape = [rank-of-inputs] - shape_of_shape = array_ops.shape(sizes[0]) - # Make a vector of length equal to the input's dimensions, - # with 0's everywhere and 1 in the concat dim position. - # Note: Can't use sparse_to_dense since it isn't GPU-capable (for now) - mask = array_ops.concat(0, - [array_ops.fill( - array_ops.expand_dims(concat_dim, 0), 0), [1], - array_ops.fill(shape_of_shape - concat_dim - 1, 0)]) out_grads = [] - begin = array_ops.fill(shape_of_shape, 0) - for i in range(len(sizes)): - out_grads.append(array_ops.slice(grad, begin, sizes[i])) - # Lint complains begin = begin + ... - begin = math_ops.add(begin, sizes[i] * mask) + if isinstance(grad, ops.Tensor): + # Get the inputs' tensor shapes + sizes = [array_ops.shape(x) for x in op.inputs[1:]] + mask, begin = _CreateDenseMaskAndBegin(sizes, concat_dim) + for size in sizes: + out_grads.append(array_ops.slice(grad, begin, size)) + # Lint complains begin = begin + ... + begin = math_ops.add(begin, size * mask) + elif isinstance(grad, ops.IndexedSlices): + concat_dim_static = tensor_util.ConstantValue(concat_dim) + if concat_dim_static is None: + raise ValueError("Can only compute IndexedSlices gradient with " + "statically-known concat_dim") + # Get the inputs' tensor shapes + sizes = [array_ops.shape(x) for x in op.inputs[1:]] + if concat_dim_static > 0: + # IndexedSlices, concat_dim > 0. Each input gets IndexedSlices gradients + # with all the indices, but with grad.values sliced accordingly. This + # is like the Tensor case, except shape(grad.values)[0] is not equal to + # shape(sizes[i])[0], since only a subset of the dim-0 values are stored. + mask, begin = _CreateDenseMaskAndBegin(sizes, concat_dim) + for size in sizes: + new_values = array_ops.slice( + grad.values, + begin, + array_ops.concat(0, [[-1], array_ops.slice(size, [1], [-1])])) + out_grads.append( + ops.IndexedSlices(new_values, grad.indices, size)) + # Lint complains begin = begin + ... + begin = math_ops.add(begin, size * mask) + else: + # IndexedSlices, concat_dim == 0. Each input gets IndexedSlices gradients + # only for the relevant indices. + start = constant_op.constant(0, dtype=grad.indices.dtype) + for size in sizes: + size_concat_dim = array_ops.gather(size, concat_dim) + if size_concat_dim.dtype != grad.indices.dtype: + size_concat_dim = math_ops.cast(size_concat_dim, + dtype=grad.indices.dtype) + end = start + size_concat_dim + # Compute the 1-D Tensor of indices relevant for this input. + indices_to_select = array_ops.squeeze( + array_ops.where(math_ops.logical_and(grad.indices >= start, + grad.indices < end)), + squeeze_dims=[1]) + new_indices = array_ops.gather(grad.indices, indices_to_select) - start + new_values = array_ops.gather(grad.values, indices_to_select) + out_grads.append( + ops.IndexedSlices(new_values, new_indices, size)) + start = end + else: + raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad)) + return [None] + out_grads @@ -201,6 +256,7 @@ def _PadGrad(op, grad): def _ReverseSequenceGrad(op, grad): seq_lengths = op.inputs[1] return [array_ops.reverse_sequence(grad, - seq_dim=op.get_attr("seq_dim"), - seq_lengths=seq_lengths), + batch_dim=op.get_attr("batch_dim"), + seq_dim=op.get_attr("seq_dim"), + seq_lengths=seq_lengths), None] diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 50f3facf2e8..1e2950e74ff 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -990,17 +990,22 @@ def _ReverseSequenceShape(op): A single-element list containing the shape of the output. Raises: - ValueError: If the input shapes are incompatible. + ValueError: If the input shapes are incompatible or seq_dim == batch_dim. """ input_shape = op.inputs[0].get_shape() seq_lens_shape = op.inputs[1].get_shape().with_rank(1) - batch_size = input_shape[0].merge_with(seq_lens_shape[0]) - input_shape = tensor_shape.TensorShape([batch_size]).concatenate( - input_shape[1:]) seq_dim = op.get_attr("seq_dim") + batch_dim = op.get_attr("batch_dim") + if batch_dim >= input_shape.ndims: + raise ValueError("batch_dim must be < input.dims() (%d vs %d)" % + (batch_dim, input_shape.ndims)) if seq_dim >= input_shape.ndims: raise ValueError("seq_dim must be < input.dims() (%d vs %d)" % (seq_dim, input_shape.ndims)) + batch_size = input_shape[batch_dim].merge_with(seq_lens_shape[0]) + input_shape = tensor_shape.TensorShape([ + value if ix != batch_dim else batch_size + for ix, value in enumerate(input_shape)]) return [input_shape] diff --git a/tensorflow/python/ops/constant_op.py b/tensorflow/python/ops/constant_op.py index f2aaad37a99..5d8d8a88d08 100644 --- a/tensorflow/python/ops/constant_op.py +++ b/tensorflow/python/ops/constant_op.py @@ -172,12 +172,24 @@ def _ConstantShape(op): [d.size for d in op.get_attr("value").tensor_shape.dim])] -ops.register_tensor_conversion_function((list, tuple), constant, 100) -ops.register_tensor_conversion_function(np.ndarray, constant, 100) -ops.register_tensor_conversion_function(np.generic, constant, 100) -ops.register_tensor_conversion_function(object, constant, 200) +def _constant_tensor_conversion_function(v, dtype=None, name=None, + as_ref=False): + _ = as_ref + return constant(v, dtype=dtype, name=name) -def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None): + +ops.register_tensor_conversion_function( + (list, tuple), _constant_tensor_conversion_function, 100) +ops.register_tensor_conversion_function( + np.ndarray, _constant_tensor_conversion_function, 100) +ops.register_tensor_conversion_function( + np.generic, _constant_tensor_conversion_function, 100) +ops.register_tensor_conversion_function( + object, _constant_tensor_conversion_function, 200) + +def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None, + as_ref=False): + _ = as_ref if not s.is_fully_defined(): raise ValueError( "Cannot convert a partially known TensorShape to a Tensor: %s" % s) @@ -193,7 +205,9 @@ def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None): ops.register_tensor_conversion_function( tensor_shape.TensorShape, _tensor_shape_tensor_conversion_function, 100) -def _dimension_tensor_conversion_function(d, dtype=None, name=None): +def _dimension_tensor_conversion_function(d, dtype=None, name=None, + as_ref=False): + _ = as_ref if d.value is None: raise ValueError("Cannot convert an unknown Dimension to a Tensor: %s" % d) if dtype is not None: diff --git a/tensorflow/python/ops/control_flow_grad.py b/tensorflow/python/ops/control_flow_grad.py index 8803ea62344..53bb20776c0 100644 --- a/tensorflow/python/ops/control_flow_grad.py +++ b/tensorflow/python/ops/control_flow_grad.py @@ -33,7 +33,7 @@ def _SwitchGrad(op, *grad): if isinstance(ctxt, WhileContext): merge_op = ctxt.switch_map.get(op) if merge_op: - merge_op._update_input(1, grad[1]) + merge_op._update_input(1, next_iteration(grad[1])) return None, None else: merge_op = merge(grad, name="b_switch")[0] @@ -70,7 +70,7 @@ def _MergeGrad(op, grad, _): else: num_inputs = len(op.inputs) cond = [math_ops.equal(op.outputs[1], i) for i in xrange(num_inputs)] - return [Switch(grad, cond[i])[1] for i in xrange(num_inputs)] + return [switch(grad, cond[i])[1] for i in xrange(num_inputs)] @ops.RegisterGradient("Exit") @@ -89,7 +89,7 @@ def _ExitGrad(op, grad): @ops.RegisterGradient("NextIteration") def _NextIterationGrad(_, grad): - return next_iteration(grad) + return grad @ops.RegisterGradient("Enter") diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 8eb1bd79bff..b2660c210ad 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -75,8 +75,9 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import common_shapes from tensorflow.python.ops import constant_op -from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_control_flow_ops +from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops # pylint: disable=wildcard-import,undefined-variable @@ -248,7 +249,7 @@ def _SwitchRefOrTensor(data, pred, name="Switch"): Raises: TypeError: if data is not a Tensor or IndexedSlices """ - data = ops.convert_to_tensor_or_indexed_slices(data, name="data") + data = ops.convert_to_tensor_or_indexed_slices(data, name="data", as_ref=True) if isinstance(data, ops.Tensor): if not data.dtype.is_ref_dtype: return switch(data, pred, name=name) @@ -418,8 +419,9 @@ def _GetRealValue(value): Returns: The same tensor value from the saved history. """ - real_value = value + # pylint: disable=protected-access forward_ctxt = value.op._get_control_flow_context() + # pylint: enable=protected-access real_value = forward_ctxt.history_map.get(value.name) assert value.op.type != "Variable" if real_value is None: @@ -432,29 +434,11 @@ def _GetRealValue(value): # to deepcopy the constants for the grad while context. history_value = forward_ctxt.AddForwardAccumulateLoop(value) - # The shapes of the whole history and a single event element. - forward_ctxt.grad_context.Exit() - elem_rank = array_ops.rank(history_value) - 1 - elem_rank_vec = array_ops.expand_dims(elem_rank, 0) - elem_shape = array_ops.slice(array_ops.shape(history_value), [1], - elem_rank_vec) - slice_shape = array_ops.concat(0, [[1], elem_shape]) - forward_ctxt.grad_context.Enter() - - # The begin position of the slice at slice_index. - slice_index = forward_ctxt.grad_context.index - b1 = array_ops.zeros(elem_rank_vec, dtype=dtypes.int32) - b = array_ops.concat(0, [array_ops.expand_dims(slice_index, 0), b1]) - - # The slice at slice_index. - # TODO(irving): Replace with gather once that's GPU accelerated - real_value = array_ops.squeeze( - array_ops.slice(history_value, - b, - slice_shape, - name="real"), - squeeze_dims=[0]) - forward_ctxt.history_map[value.name] = real_value + # pylint: disable=protected-access + real_value = gen_data_flow_ops._stack_pop(history_value, + value.dtype.base_dtype) + # pylint: enable=protected-access + forward_ctxt.history_map[value.name] = real_value return real_value @@ -656,7 +640,7 @@ def cond(pred, fn1, fn2, name=None): context_f = CondContext(pred, pivot_2, 0) context_f.Enter() res_f = context_f.BuildCondBranch(fn2) - context_t.ExitResult(res_f) + context_f.ExitResult(res_f) context_f.Exit() # Add the final merge to the graph. @@ -693,8 +677,10 @@ class WhileContext(ControlFlowContext): # generation for gradient computation self._pivot = None - # The tensors for the counters added by AddForwardCounterLoop or - # AddBackPropCounterLoop + # The loop counter added either by AddForwardCounterLoop or + # AddBackPropCounterLoop. For forward, it is the value of the loop + # counter for the next iteration. For backprop, it is the value of + # the loop counter for the current iteration. self._index = None # Information needed by backprop @@ -703,10 +689,10 @@ class WhileContext(ControlFlowContext): self._history_map = {} self._switch_map = {} - # values considered to have been already seen in this context + # Values considered to have been already seen in this context self._values = set() - # values referenced by but external to this context + # Values referenced by but external to this context self._external_values = {} @property @@ -841,10 +827,9 @@ class WhileContext(ControlFlowContext): name="f_count") merge_n = merge([enter_n, enter_n])[0] switch_n = switch(merge_n, self._pivot) - self._index = switch_n[1] - add_n = math_ops.add(self._index, 1) - next_n = next_iteration(add_n) + self._index = math_ops.add(switch_n[1], 1) + next_n = next_iteration(self._index) merge_n.op._update_input(1, next_n) self._total_iterations = exit(switch_n[0], name="f_count") @@ -859,54 +844,39 @@ class WhileContext(ControlFlowContext): The pseudocode is: ``` - acc; + acc = stack(); while (_pivot) { - if (index == 0) [value] else Concat(acc, [value]); + acc = stack_push(acc, value); } ``` Args: - value: The tensor that is accumulated. + value: The tensor that is to be accumulated. Returns: - The accumulated history of value. + The stack that contains the accumulated history of value. Raises: ValueError: If the shape of "value" is not known statically. """ - if not value.get_shape().is_fully_defined(): - raise ValueError("Must have known shape: %s" % value) self._grad_context.Exit() - # TODO(irving): Now that acc starts out empty, most of the - # conditional logic can go away. - acc = constant_op.constant([], - value.dtype, - shape=[0] + value.get_shape().as_list(), - name="f_acc") + # pylint: disable=protected-access + acc = gen_data_flow_ops._stack(value.dtype.base_dtype, name="f_acc") + # pylint: enable=protected-access self.Enter() self.AddName(acc.name) - enter_acc = _Enter(acc, self._name, is_constant=False, + enter_acc = _Enter(acc, self._name, is_constant=True, parallel_iterations=self._parallel_iterations, name="f_acc") - merge_acc = merge([enter_acc, enter_acc])[0] - switch_acc = switch(merge_acc, self._pivot) - # If index = 0 then [value] else Concat(acc, [value]). - cond = math_ops.greater(self._index, 0) - switch_add_acc = switch(switch_acc[1], cond) - expand_value = array_ops.expand_dims(value, 0) - true_branch = array_ops.concat(0, [switch_add_acc[1], expand_value]) - false_branch = array_ops.identity(switch_add_acc[0]) - false_branch = with_dependencies([false_branch], expand_value) - add_acc = merge([false_branch, true_branch])[0] + # pylint: disable=protected-access + push_op = gen_data_flow_ops._stack_push(enter_acc, value) + self._index.op._add_control_input(push_op.op) + # pylint: enable=protected-access - next_acc = next_iteration(add_acc) - merge_acc.op._update_input(1, next_acc) - - exit_acc = exit(switch_acc[0], name="f_acc") self.Exit() self._grad_context.Enter() - return exit_acc + return acc def AddForwardAccumulateCondLoop(self, value): """Add an accumulation loop for each conditional switch. @@ -916,9 +886,9 @@ class WhileContext(ControlFlowContext): The pseudocode is: ``` - acc; + acc = [] while (_pivot) { - Concat(acc, value); + acc = concat([acc, value]); } ``` @@ -929,19 +899,19 @@ class WhileContext(ControlFlowContext): The accumulated history of value. """ self._grad_context.Exit() - acc = constant_op.constant(False, name="f_acc") + acc = constant_op.constant(False, name="f_cond") self.Enter() self.AddName(acc.name) enter_acc = _Enter(acc, self._name, is_constant=False, parallel_iterations=self._parallel_iterations, - name="f_acc") + name="f_cond") merge_acc = merge([enter_acc, enter_acc])[0] switch_acc = switch(merge_acc, self._pivot) acc = array_ops.concat(0, [switch_add_acc[1], value]) next_acc = next_iteration(acc) merge_acc.op._update_input(1, next_acc) - exit_acc = exit(switch_acc[0], name="f_acc") + exit_acc = exit(switch_acc[0], name="f_cond") self.Exit() self._grad_context.Enter() return exit_acc @@ -974,11 +944,10 @@ class WhileContext(ControlFlowContext): self._pivot = loop_cond(cond, name="b_count") switch_count = switch(merge_count, self._pivot) - # Add next_iteration right after Switch to match the gradient function. - next_count = next_iteration(switch_count[1]) - self._pivot_for_body = next_count - self._index = math_ops.sub(next_count, one) - merge_count.op._update_input(1, self._index) + self._index = math_ops.sub(switch_count[1], one) + self._pivot_for_body = self._index + next_count = next_iteration(self._index) + merge_count.op._update_input(1, next_count) exit_count = exit(switch_count[0], name="b_count") self.Exit() @@ -1015,9 +984,9 @@ class WhileContext(ControlFlowContext): merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0] switch_acc = switch(merge_acc, self._pivot) - next_acc = next_iteration(switch_acc[1]) - add_acc = math_ops.add(next_acc, value) - merge_acc.op._update_input(1, add_acc) + add_acc = math_ops.add(switch_acc[1], value) + next_acc = next_iteration(add_acc) + merge_acc.op._update_input(1, next_acc) exit_acc = exit(switch_acc[0], name="b_acc") return exit_acc diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py index b790d9af6c9..599875ecb4b 100644 --- a/tensorflow/python/ops/gradients.py +++ b/tensorflow/python/ops/gradients.py @@ -50,7 +50,7 @@ from tensorflow.python.platform import logging _LARGE_SPARSE_NUM_ELEMENTS = 100000000 -def _IndexedSlicesToTensor(value, dtype=None, name=None): +def _IndexedSlicesToTensor(value, dtype=None, name=None, as_ref=False): """Converts an IndexedSlices object `value` to a Tensor. NOTE(mrry): This function is potentially expensive. @@ -59,6 +59,7 @@ def _IndexedSlicesToTensor(value, dtype=None, name=None): value: An ops.IndexedSlices object. dtype: The dtype of the Tensor to be returned. name: Optional name to use for the returned Tensor. + as_ref: True if a ref is requested. Returns: A dense Tensor representing the values in the given IndexedSlices. @@ -66,6 +67,7 @@ def _IndexedSlicesToTensor(value, dtype=None, name=None): Raises: ValueError: If the IndexedSlices does not have the same dtype. """ + _ = as_ref if dtype and not dtype.is_compatible_with(value.dtype): raise ValueError( "Tensor conversion requested dtype %s for IndexedSlices with dtype %s" % diff --git a/tensorflow/python/ops/image_grad_test.py b/tensorflow/python/ops/image_grad_test.py index 488e741e5b4..0c42f637aa4 100644 --- a/tensorflow/python/ops/image_grad_test.py +++ b/tensorflow/python/ops/image_grad_test.py @@ -21,7 +21,6 @@ from __future__ import print_function # pylint: disable=g-bad-import-order, # pylint: disable=unused-import import tensorflow.python.platform -from tensorflow.python.kernel_tests import gradient_checker as gc import numpy as np import tensorflow as tf @@ -56,11 +55,11 @@ class ResizeNearestNeighborOpTest(tf.test.TestCase): input_tensor = tf.constant(x, shape=in_shape) resize_out = tf.image.resize_nearest_neighbor(input_tensor, out_shape[1:3]) - err = gc.ComputeGradientError(input_tensor, - in_shape, - resize_out, - out_shape, - x_init_value=x) + err = tf.test.compute_gradient_error(input_tensor, + in_shape, + resize_out, + out_shape, + x_init_value=x) self.assertLess(err, 1e-3) def testGradFromResizeToSmallerInBothDims(self): @@ -73,11 +72,11 @@ class ResizeNearestNeighborOpTest(tf.test.TestCase): input_tensor = tf.constant(x, shape=in_shape) resize_out = tf.image.resize_nearest_neighbor(input_tensor, out_shape[1:3]) - err = gc.ComputeGradientError(input_tensor, - in_shape, - resize_out, - out_shape, - x_init_value=x) + err = tf.test.compute_gradient_error(input_tensor, + in_shape, + resize_out, + out_shape, + x_init_value=x) self.assertLess(err, 1e-3) diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index 7be02b220f6..2392042d504 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -25,7 +25,8 @@ are all of variable size. If you need fixed size images, pass the output of the decode Ops to one of the cropping and resizing Ops. Note: The PNG encode and decode Ops support RGBA, but the conversions Ops -presently only support RGB, HSV, and GrayScale. +presently only support RGB, HSV, and GrayScale. Presently, the alpha channel has +to be stripped from the image and re-attached using slicing ops. @@decode_jpeg @@encode_jpeg @@ -82,6 +83,14 @@ resized_image = tf.image.resize_bilinear(image, [299, 299]) @@transpose_image +## Converting Between Colorspaces. + +Internally, images are either stored in as one `float32` per channel per pixel +(implicitly, values are assumed to lie in `[0,1)`) or one `uint8` per channel +per pixel (values are assumed to lie in `[0,255]`). + +@@convert_image_dtype + ## Image Adjustments TensorFlow provides functions to adjust images in various ways: brightness, @@ -805,3 +814,64 @@ def random_crop(image, size, seed=None, name=None): seed1, seed2 = random_seed.get_seed(seed) return gen_image_ops.random_crop(image, size, seed=seed1, seed2=seed2, name=name) + + +def convert_image_dtype(image, dtype, name=None): + """Convert `image` to `dtype`, scaling its values if needed. + + Images that are represented using floating point values are expected to have + values in the range [0,1). Image data stored in integer data types are + expected to have values in the range `[0,MAX]`, wbere `MAX` is the largest + positive representable number for the data type. + + This op converts between data types, scaling the values appropriately before + casting. + + Note that for floating point inputs, this op expects values to lie in [0,1). + Conversion of an image containing values outside that range may lead to + overflow errors when converted to integer `Dtype`s. + + Args: + image: An image. + dtype: A `DType` to convert `image` to. + name: A name for this operation (optional). + + Returns: + `image`, converted to `dtype`. + """ + + if dtype == image.dtype: + return image + + with ops.op_scope([image], name, 'convert_image') as name: + # Both integer: use integer multiplication in the larger range + if image.dtype.is_integer and dtype.is_integer: + scale_in = image.dtype.max + scale_out = dtype.max + if scale_in > scale_out: + # Scaling down, scale first, then cast. The scaling factor will + # cause in.max to be mapped to above out.max but below out.max+1, + # so that the output is safely in the supported range. + scale = (scale_in + 1) // (scale_out + 1) + scaled = math_ops.div(image, scale) + return math_ops.cast(scaled, dtype) + else: + # Scaling up, cast first, then scale. The scale will not map in.max to + # out.max, but converting back and forth should result in no change. + cast = math_ops.cast(image, dtype) + scale = (scale_out + 1) // (scale_in + 1) + return math_ops.mul(cast, scale) + elif image.dtype.is_floating and dtype.is_floating: + # Both float: Just cast, no possible overflows in the allowed ranges. + return math_ops.cast(image, dtype) + else: + if image.dtype.is_integer: + # Converting to float: first cast, then scale + cast = math_ops.cast(image, dtype) + scale = 1. / image.dtype.max + return math_ops.mul(cast, scale) + else: + # Converting from float: first scale, then cast + scale = dtype.max + 0.5 # avoid rounding problems in the cast + scaled = math_ops.mul(image, scale) + return math_ops.cast(scaled, dtype) diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 7d315487d7a..1b7292e1e79 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -26,6 +26,7 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import test_util +from tensorflow.python.framework import dtypes from tensorflow.python.ops import constant_op from tensorflow.python.ops import image_ops from tensorflow.python.ops import io_ops @@ -787,5 +788,46 @@ class PngTest(test_util.TensorFlowTestCase): [None, None, channels or None]) +class ConvertImageTest(test_util.TensorFlowTestCase): + + def _convert(self, original, original_dtype, output_dtype, expected): + x_np = np.array(original, dtype=original_dtype.as_numpy_dtype()) + y_np = np.array(expected, dtype=output_dtype.as_numpy_dtype()) + + with self.test_session(): + image = constant_op.constant(x_np) + y = image_ops.convert_image_dtype(image, output_dtype) + self.assertTrue(y.dtype == output_dtype) + self.assertAllClose(y.eval(), y_np, atol=1e-5) + + def testNoConvert(self): + # Make sure converting to the same data type creates no ops + with self.test_session(): + image = constant_op.constant([1], dtype=dtypes.uint8) + y = image_ops.convert_image_dtype(image, dtypes.uint8) + self.assertEquals(image, y) + + def testConvertBetweenInteger(self): + # Make sure converting to between integer types scales appropriately + with self.test_session(): + self._convert([0, 255], dtypes.uint8, dtypes.int16, [0, 255 * 128]) + self._convert([0, 32767], dtypes.int16, dtypes.uint8, [0, 255]) + + def testConvertBetweenFloat(self): + # Make sure converting to between float types does nothing interesting + with self.test_session(): + self._convert([-1.0, 0, 1.0, 200000], dtypes.float32, dtypes.float64, + [-1.0, 0, 1.0, 200000]) + self._convert([-1.0, 0, 1.0, 200000], dtypes.float64, dtypes.float32, + [-1.0, 0, 1.0, 200000]) + + def testConvertBetweenIntegerAndFloat(self): + # Make sure converting from and to a float type scales appropriately + with self.test_session(): + self._convert([0, 1, 255], dtypes.uint8, dtypes.float32, + [0, 1.0 / 255.0, 1]) + self._convert([0, 1.1 / 255.0, 1], dtypes.float32, dtypes.uint8, + [0, 1, 255]) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py index 17160f909e0..3bd0e875631 100644 --- a/tensorflow/python/ops/nn.py +++ b/tensorflow/python/ops/nn.py @@ -16,11 +16,10 @@ # pylint: disable=wildcard-import,unused-import,g-bad-import-order """## Activation Functions -The activation ops provide different types of nonlinearities for use in -neural networks. These include smooth nonlinearities (`sigmoid`, -`tanh`, and `softplus`), continuous but not everywhere differentiable -functions (`relu`, `relu6`, and `relu_x`), and random regularization -(`dropout`). +The activation ops provide different types of nonlinearities for use in neural +networks. These include smooth nonlinearities (`sigmoid`, `tanh`, `softplus`, +and `softsign`), continuous but not everywhere differentiable functions (`relu`, +`relu6`, and `relu_x`), and random regularization (`dropout`). All activation ops apply componentwise, and produce a tensor of the same shape as the input tensor. @@ -28,6 +27,7 @@ shape as the input tensor. @@relu @@relu6 @@softplus +@@softsign @@dropout @@bias_add @@sigmoid @@ -212,12 +212,16 @@ from tensorflow.python.ops import candidate_sampling_ops from tensorflow.python.ops import constant_op from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_grad from tensorflow.python.ops import nn_ops from tensorflow.python.ops import numerics from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import seq2seq from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops.math_ops import sigmoid from tensorflow.python.ops.math_ops import tanh @@ -225,6 +229,7 @@ from tensorflow.python.ops.math_ops import tanh from tensorflow.python.ops.nn_ops import * from tensorflow.python.ops.candidate_sampling_ops import * from tensorflow.python.ops.embedding_ops import * +from tensorflow.python.ops.rnn import * def sigmoid_cross_entropy_with_logits(logits, targets, name=None): @@ -268,28 +273,6 @@ def sigmoid_cross_entropy_with_logits(logits, targets, name=None): name=name) -def xw_plus_b(x, weights, biases, name=None): - """Computes matmul(x, weights) + biases. - - Args: - x: a 2D tensor. Dimensions typically: batch, in_units - weights: a 2D tensor. Dimensions typically: in_units, out_units - biases: a 1D tensor. Dimensions: out_units - name: A name for the operation (optional). If not specified - "wx_plus_b" is used. - - Returns: - A 2-D Tensor computing matmul(x, weights) + biases. - Dimensions typically: batch, out_units. - """ - with ops.op_scope([x, weights, biases], name, "xw_plus_b") as name: - x = ops.convert_to_tensor(x, name="x") - weights = ops.convert_to_tensor(weights, name="weights") - biases = ops.convert_to_tensor(biases, name="biases") - mm = math_ops.matmul(x, weights) - return nn_ops.bias_add(mm, biases, name=name) - - def relu_layer(x, weights, biases, name=None): """Computes Relu(x * weight + biases). @@ -363,59 +346,6 @@ def zero_fraction(value, name=None): dtypes.float32)) -def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): - """Computes dropout. - - With probability `keep_prob`, outputs the input element scaled up by - `1 / keep_prob`, otherwise outputs `0`. The scaling is so that the expected - sum is unchanged. - - By default, each element is kept or dropped independently. If `noise_shape` - is specified, it must be - [broadcastable](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - to the shape of `x`, and only dimensions with `noise_shape[i] == shape(x)[i]` - will make independent decisions. For example, if `shape(x) = [k, l, m, n]` - and `noise_shape = [k, 1, 1, n]`, each batch and channel component will be - kept independently and each row and column will be kept or not kept together. - - Args: - x: A tensor. - keep_prob: A scalar `Tensor` with the same type as x. The probability - that each element is kept. - noise_shape: A 1-D `Tensor` of type `int32`, representing the - shape for randomly generated keep/drop flags. - seed: A Python integer. Used to create random seeds. See - [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed) - for behavior. - name: A name for this operation (optional). - - Returns: - A Tensor of the same shape of `x`. - - Raises: - ValueError: If `keep_prob` is not in `(0, 1]`. - """ - with ops.op_scope([x], name, "dropout") as name: - x = ops.convert_to_tensor(x, name="x") - if isinstance(keep_prob, float) and not(0 < keep_prob <= 1): - raise ValueError("keep_prob must be a scalar tensor or a float in the " - "range (0, 1], got %g" % keep_prob) - keep_prob = ops.convert_to_tensor( - keep_prob, dtype=x.dtype, name="keep_prob") - keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar()) - - noise_shape = noise_shape or array_ops.shape(x) - # uniform [keep_prob, 1.0 + keep_prob) - random_tensor = keep_prob - random_tensor += random_ops.random_uniform( - noise_shape, seed=seed, dtype=x.dtype) - # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob) - binary_tensor = math_ops.floor(random_tensor) - ret = x * math_ops.inv(keep_prob) * binary_tensor - ret.set_shape(x.get_shape()) - return ret - - def depthwise_conv2d(input, filter, strides, padding, name=None): """Depthwise 2-D convolution. @@ -672,9 +602,9 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled, labels_flat = array_ops.reshape(labels, [-1]) # Sample the negative labels. - # sampled shape: num_sampled vector - # true_expected_count shape = [batch_size, 1] - # sampled_expected_count shape = num_sampled vector + # sampled shape: [num_sampled] tensor + # true_expected_count shape = [batch_size, 1] tensor + # sampled_expected_count shape = [num_sampled] tensor if sampled_values is None: sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler( true_classes=labels, @@ -687,12 +617,18 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled, sampled, true_expected_count, sampled_expected_count = sampled_values # pylint: enable=unpacking-non-sequence + # labels_flat is a [batch_size * num_true] tensor + # sampled is a [num_sampled] int tensor + all_ids = array_ops.concat(0, [labels_flat, sampled]) + # weights shape is [num_classes, dim] - # labels_flat is a [batch_size * num_true] vector + all_w = embedding_ops.embedding_lookup(weights, all_ids) + all_b = embedding_ops.embedding_lookup(biases, all_ids) # true_w shape is [batch_size * num_true, dim] - # true_b is a [batch_size * num_true] vector - true_w = embedding_ops.embedding_lookup(weights, labels_flat) - true_b = embedding_ops.embedding_lookup(biases, labels_flat) + # true_b is a [batch_size * num_true] tensor + true_w = array_ops.slice( + all_w, [0, 0], array_ops.pack([array_ops.shape(labels_flat)[0], -1])) + true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat)) # inputs shape is [batch_size, dim] # true_w shape is [batch_size * num_true, dim] @@ -711,11 +647,11 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled, true_logits += true_b # Lookup weights and biases for sampled labels. - # sampled is a num_sampled int vector # sampled_w shape is [num_sampled, dim] - # sampled_b is a num_sampled float vector - sampled_w = embedding_ops.embedding_lookup(weights, sampled) - sampled_b = embedding_ops.embedding_lookup(biases, sampled) + # sampled_b is a [num_sampled] float tensor + sampled_w = array_ops.slice( + all_w, array_ops.pack([array_ops.shape(labels_flat)[0], 0]), [-1, -1]) + sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1]) # inputs has shape [batch_size, dim] # sampled_w has shape [num_sampled, dim] @@ -740,6 +676,8 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled, sampled_logits_shape = array_ops.concat( 0, [array_ops.shape(labels)[:1], array_ops.expand_dims(num_sampled, 0)]) + if sampled_logits.dtype != acc_weights.dtype: + acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype) sampled_logits += sparse_ops.sparse_to_dense( sparse_indices, sampled_logits_shape, acc_weights, 0.0) @@ -879,5 +817,5 @@ def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled, remove_accidental_hits=remove_accidental_hits, name=name) sampled_losses = nn_ops.softmax_cross_entropy_with_logits(logits, labels) - # sampled_losses is a batch_size vector. + # sampled_losses is a [batch_size] tensor. return sampled_losses diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index 7c8ad8883ca..48f57b65279 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -137,6 +137,11 @@ def _SoftplusGrad(op, grad): return gen_nn_ops._softplus_grad(grad, op.inputs[0]) +@ops.RegisterGradient("Softsign") +def _SoftsignGrad(op, grad): + return gen_nn_ops._softsign_grad(grad, op.inputs[0]) + + @ops.RegisterGradient("ReluGrad") def _ReluGradGrad(op, grad): x = op.inputs[1] diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 1eb8ef4c693..604739a6b6e 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -26,8 +26,11 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import common_shapes from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops # pylint: disable=wildcard-import from tensorflow.python.ops.gen_nn_ops import * @@ -235,11 +238,13 @@ def max_pool(value, ksize, strides, padding, name=None): ops.RegisterShape("Relu")(common_shapes.unchanged_shape) ops.RegisterShape("Relu6")(common_shapes.unchanged_shape) ops.RegisterShape("Softplus")(common_shapes.unchanged_shape) +ops.RegisterShape("Softsign")(common_shapes.unchanged_shape) @ops.RegisterShape("ReluGrad") @ops.RegisterShape("Relu6Grad") @ops.RegisterShape("SoftplusGrad") +@ops.RegisterShape("SoftsignGrad") def _BinaryElementwiseShape(op): """Returns same shape as both inputs to op. @@ -383,3 +388,81 @@ def _MaxPoolGradShape(op): """Shape function for the MaxPoolGrad op.""" orig_input_shape = op.inputs[0].get_shape().with_rank(4) return [orig_input_shape] + + +def xw_plus_b(x, weights, biases, name=None): # pylint: disable=invalid-name + """Computes matmul(x, weights) + biases. + + Args: + x: a 2D tensor. Dimensions typically: batch, in_units + weights: a 2D tensor. Dimensions typically: in_units, out_units + biases: a 1D tensor. Dimensions: out_units + name: A name for the operation (optional). If not specified + "wx_plus_b" is used. + + Returns: + A 2-D Tensor computing matmul(x, weights) + biases. + Dimensions typically: batch, out_units. + """ + with ops.op_scope([x, weights, biases], name, "xw_plus_b") as name: + x = ops.convert_to_tensor(x, name="x") + weights = ops.convert_to_tensor(weights, name="weights") + biases = ops.convert_to_tensor(biases, name="biases") + mm = math_ops.matmul(x, weights) + return bias_add(mm, biases, name=name) + + +# pylint: disable=invalid-name +def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): + """Computes dropout. + + With probability `keep_prob`, outputs the input element scaled up by + `1 / keep_prob`, otherwise outputs `0`. The scaling is so that the expected + sum is unchanged. + + By default, each element is kept or dropped independently. If `noise_shape` + is specified, it must be + [broadcastable](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + to the shape of `x`, and only dimensions with `noise_shape[i] == shape(x)[i]` + will make independent decisions. For example, if `shape(x) = [k, l, m, n]` + and `noise_shape = [k, 1, 1, n]`, each batch and channel component will be + kept independently and each row and column will be kept or not kept together. + + Args: + x: A tensor. + keep_prob: A scalar `Tensor` with the same type as x. The probability + that each element is kept. + noise_shape: A 1-D `Tensor` of type `int32`, representing the + shape for randomly generated keep/drop flags. + seed: A Python integer. Used to create random seeds. See + [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed) + for behavior. + name: A name for this operation (optional). + + Returns: + A Tensor of the same shape of `x`. + + Raises: + ValueError: If `keep_prob` is not in `(0, 1]`. + """ + with ops.op_scope([x], name, "dropout") as name: + x = ops.convert_to_tensor(x, name="x") + if isinstance(keep_prob, float) and not 0 < keep_prob <= 1: + raise ValueError("keep_prob must be a scalar tensor or a float in the " + "range (0, 1], got %g" % keep_prob) + keep_prob = ops.convert_to_tensor( + keep_prob, dtype=x.dtype, name="keep_prob") + keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar()) + + noise_shape = noise_shape or array_ops.shape(x) + # uniform [keep_prob, 1.0 + keep_prob) + random_tensor = keep_prob + random_tensor += random_ops.random_uniform( + noise_shape, seed=seed, dtype=x.dtype) + # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob) + binary_tensor = math_ops.floor(random_tensor) + ret = x * math_ops.inv(keep_prob) * binary_tensor + ret.set_shape(x.get_shape()) + return ret + +# pylint: enable=invalid-name diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 65e28978baa..4146803c255 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -22,26 +22,17 @@ import math import tensorflow.python.platform +import tensorflow as tf import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import test_util -from tensorflow.python.kernel_tests import gradient_checker as gc -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import constant_op from tensorflow.python.ops import gen_nn_ops -from tensorflow.python.ops import gradients -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import nn_grad -from tensorflow.python.platform import googletest exp = math.exp log = math.log -class SigmoidCrossEntropyWithLogitsTest(test_util.TensorFlowTestCase): +class SigmoidCrossEntropyWithLogitsTest(tf.test.TestCase): def _SigmoidCrossEntropyWithLogits(self, logits, targets): assert len(logits) == len(targets) @@ -50,28 +41,29 @@ class SigmoidCrossEntropyWithLogitsTest(test_util.TensorFlowTestCase): pred = [min(max(p, eps), 1 - eps) for p in pred] return [-z * log(y) - (1 - z) * log(1 - y) for y, z in zip(pred, targets)] - def _Inputs(self, x=None, y=None, dtype=dtypes.float64, sizes=None): + def _Inputs(self, x=None, y=None, dtype=tf.float64, sizes=None): x = [-100, -2, -2, 0, 2, 2, 2, 100] if x is None else x y = [0, 0, 1, 0, 0, 1, 0.5, 1] if y is None else y assert len(x) == len(y) sizes = sizes if sizes else [len(x)] - logits = constant_op.constant(x, shape=sizes, dtype=dtype, name="logits") - targets = constant_op.constant(y, shape=sizes, dtype=dtype, name="targets") + logits = tf.constant(x, shape=sizes, dtype=dtype, name="logits") + targets = tf.constant(y, shape=sizes, dtype=dtype, name="targets") losses = np.array(self._SigmoidCrossEntropyWithLogits(x, y)).reshape(*sizes) return logits, targets, losses def testConstructionNamed(self): with self.test_session(): logits, targets, _ = self._Inputs() - loss = nn.sigmoid_cross_entropy_with_logits(logits, targets, - name="mylogistic") + loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, + targets, + name="mylogistic") self.assertEqual("mylogistic", loss.op.name) def testLogisticOutput(self): for use_gpu in [True, False]: with self.test_session(use_gpu=use_gpu): - logits, targets, losses = self._Inputs(dtype=dtypes.float32) - loss = nn.sigmoid_cross_entropy_with_logits(logits, targets) + logits, targets, losses = self._Inputs(dtype=tf.float32) + loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, targets) np_loss = np.array(losses).astype(np.float32) tf_loss = loss.eval() self.assertAllClose(np_loss, tf_loss, atol=0.001) @@ -79,9 +71,9 @@ class SigmoidCrossEntropyWithLogitsTest(test_util.TensorFlowTestCase): def testLogisticOutputMultiDim(self): for use_gpu in [True, False]: with self.test_session(use_gpu=use_gpu): - logits, targets, losses = self._Inputs(dtype=dtypes.float32, + logits, targets, losses = self._Inputs(dtype=tf.float32, sizes=[2, 2, 2]) - loss = nn.sigmoid_cross_entropy_with_logits(logits, targets) + loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, targets) np_loss = np.array(losses).astype(np.float32) tf_loss = loss.eval() self.assertAllClose(np_loss, tf_loss, atol=0.001) @@ -90,13 +82,13 @@ class SigmoidCrossEntropyWithLogitsTest(test_util.TensorFlowTestCase): sizes = [4, 2] with self.test_session(): logits, targets, _ = self._Inputs(sizes=sizes) - loss = nn.sigmoid_cross_entropy_with_logits(logits, targets) - err = gc.ComputeGradientError(logits, sizes, loss, sizes) + loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, targets) + err = tf.test.compute_gradient_error(logits, sizes, loss, sizes) print("logistic loss gradient err = ", err) self.assertLess(err, 1e-7) -class ZeroFractionTest(test_util.TensorFlowTestCase): +class ZeroFractionTest(tf.test.TestCase): def _ZeroFraction(self, x): assert x.shape @@ -109,9 +101,9 @@ class ZeroFractionTest(test_util.TensorFlowTestCase): x_np = np.random.randint(0, 2, size=x_shape).astype(np.float32) y_np = self._ZeroFraction(x_np) with self.test_session(): - x_tf = constant_op.constant(x_np) + x_tf = tf.constant(x_np) x_tf.set_shape(x_shape) - y_tf = nn.zero_fraction(x_tf) + y_tf = tf.nn.zero_fraction(x_tf) y_tf_np = y_tf.eval() eps = 1e-8 self.assertAllClose(y_tf_np, y_np, eps) @@ -119,11 +111,11 @@ class ZeroFractionTest(test_util.TensorFlowTestCase): def testZeroFractionEmpty(self): with self.test_session(): x = np.zeros(0) - y = nn.zero_fraction(x).eval() + y = tf.nn.zero_fraction(x).eval() self.assertTrue(np.isnan(y)) -class SoftmaxTest(test_util.TensorFlowTestCase): +class SoftmaxTest(tf.test.TestCase): def _softmax(self, x): assert len(x.shape) == 2 @@ -137,8 +129,8 @@ class SoftmaxTest(test_util.TensorFlowTestCase): x_np = np.random.randn(*x_shape).astype(np.float32) y_np = self._softmax(x_np) with self.test_session(): - x_tf = constant_op.constant(x_np) - y_tf = nn.softmax(x_tf) + x_tf = tf.constant(x_np) + y_tf = tf.nn.softmax(x_tf) y_tf_np = y_tf.eval() eps = 1e-3 self.assertAllClose(y_tf_np, y_np, eps) @@ -147,14 +139,14 @@ class SoftmaxTest(test_util.TensorFlowTestCase): x_shape = [5, 10] x_np = np.random.randn(*x_shape).astype(np.float64) with self.test_session(): - x_tf = constant_op.constant(x_np) - y_tf = nn.softmax(x_tf) - err = gc.ComputeGradientError(x_tf, x_shape, y_tf, x_shape) + x_tf = tf.constant(x_np) + y_tf = tf.nn.softmax(x_tf) + err = tf.test.compute_gradient_error(x_tf, x_shape, y_tf, x_shape) eps = 1e-8 self.assertLess(err, eps) -class DeConv2DTest(test_util.TensorFlowTestCase): +class DeConv2DTest(tf.test.TestCase): def testDeConv2DSingleStride(self): with self.test_session(): @@ -167,11 +159,9 @@ class DeConv2DTest(test_util.TensorFlowTestCase): # Filter: [kernel_height, kernel_width, output_depth, input_depth] f_shape = [3, 3, 2, 3] - x = constant_op.constant(1.0, shape=x_shape, name="x", - dtype=dtypes.float32) - f = constant_op.constant(1.0, shape=f_shape, name="filter", - dtype=dtypes.float32) - output = nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME") + x = tf.constant(1.0, shape=x_shape, name="x", dtype=tf.float32) + f = tf.constant(1.0, shape=f_shape, name="filter", dtype=tf.float32) + output = tf.nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME") value = output.eval() # We count the number of cells being added at the locations in the output. @@ -204,11 +194,9 @@ class DeConv2DTest(test_util.TensorFlowTestCase): # Filter: [kernel_height, kernel_width, output_depth, input_depth] f_shape = [3, 3, 2, 3] - x = constant_op.constant(1.0, shape=x_shape, name="x", - dtype=dtypes.float32) - f = constant_op.constant(1.0, shape=f_shape, name="filter", - dtype=dtypes.float32) - output = nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME") + x = tf.constant(1.0, shape=x_shape, name="x", dtype=tf.float32) + f = tf.constant(1.0, shape=f_shape, name="filter", dtype=tf.float32) + output = tf.nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME") value = output.eval() for n in xrange(x_shape[0]): @@ -236,11 +224,9 @@ class DeConv2DTest(test_util.TensorFlowTestCase): # Filter: [kernel_height, kernel_width, output_depth, input_depth] f_shape = [3, 3, 2, 3] - x = constant_op.constant(1.0, shape=x_shape, name="x", - dtype=dtypes.float32) - f = constant_op.constant(1.0, shape=f_shape, name="filter", - dtype=dtypes.float32) - output = nn.deconv2d(x, f, y_shape, strides=strides, padding="VALID") + x = tf.constant(1.0, shape=x_shape, name="x", dtype=tf.float32) + f = tf.constant(1.0, shape=f_shape, name="filter", dtype=tf.float32) + output = tf.nn.deconv2d(x, f, y_shape, strides=strides, padding="VALID") value = output.eval() cache_values = np.zeros(y_shape, dtype=np.float32) @@ -281,21 +267,22 @@ class DeConv2DTest(test_util.TensorFlowTestCase): x_val = np.random.random_sample(x_shape).astype(np.float64) f_val = np.random.random_sample(f_shape).astype(np.float64) with self.test_session(): - x = constant_op.constant(x_val, name="x", dtype=dtypes.float32) - f = constant_op.constant(f_val, name="f", dtype=dtypes.float32) - output = nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME") - err = gc.ComputeGradientError([x, f], [x_shape, f_shape], output, y_shape) + x = tf.constant(x_val, name="x", dtype=tf.float32) + f = tf.constant(f_val, name="f", dtype=tf.float32) + output = tf.nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME") + err = tf.test.compute_gradient_error( + [x, f], [x_shape, f_shape], output, y_shape) print("DeConv gradient err = %g " % err) err_tolerance = 0.0005 self.assertLess(err, err_tolerance) -class L2LossTest(test_util.TensorFlowTestCase): +class L2LossTest(tf.test.TestCase): def testL2Loss(self): with self.test_session(): - x = constant_op.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="x") - l2loss = nn.l2_loss(x) + x = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="x") + l2loss = tf.nn.l2_loss(x) value = l2loss.eval() self.assertAllClose(7.0, value) @@ -304,15 +291,15 @@ class L2LossTest(test_util.TensorFlowTestCase): np.random.seed(1) # Make it reproducible. x_val = np.random.random_sample(x_shape).astype(np.float64) with self.test_session(): - x = constant_op.constant(x_val, name="x") - output = nn.l2_loss(x) - err = gc.ComputeGradientError(x, x_shape, output, [1]) + x = tf.constant(x_val, name="x") + output = tf.nn.l2_loss(x) + err = tf.test.compute_gradient_error(x, x_shape, output, [1]) print("L2Loss gradient err = %g " % err) err_tolerance = 1e-11 self.assertLess(err, err_tolerance) -class L2NormalizeTest(test_util.TensorFlowTestCase): +class L2NormalizeTest(tf.test.TestCase): def _l2Normalize(self, x, dim): norm = np.apply_along_axis(np.linalg.norm, dim, x) @@ -325,8 +312,8 @@ class L2NormalizeTest(test_util.TensorFlowTestCase): for dim in range(len(x_shape)): y_np = self._l2Normalize(x_np, dim) with self.test_session(): - x_tf = constant_op.constant(x_np, name="x") - y_tf = nn.l2_normalize(x_tf, dim) + x_tf = tf.constant(x_np, name="x") + y_tf = tf.nn.l2_normalize(x_tf, dim) self.assertAllClose(y_np, y_tf.eval()) def testL2NormalizeGradient(self): @@ -335,14 +322,14 @@ class L2NormalizeTest(test_util.TensorFlowTestCase): x_np = np.random.random_sample(x_shape).astype(np.float64) for dim in range(len(x_shape)): with self.test_session(): - x_tf = constant_op.constant(x_np, name="x") - y_tf = nn.l2_normalize(x_tf, dim) - err = gc.ComputeGradientError(x_tf, x_shape, y_tf, x_shape) + x_tf = tf.constant(x_np, name="x") + y_tf = tf.nn.l2_normalize(x_tf, dim) + err = tf.test.compute_gradient_error(x_tf, x_shape, y_tf, x_shape) print("L2Normalize gradient err = %g " % err) self.assertLess(err, 1e-4) -class DropoutTest(test_util.TensorFlowTestCase): +class DropoutTest(tf.test.TestCase): def testDropout(self): # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate @@ -353,10 +340,8 @@ class DropoutTest(test_util.TensorFlowTestCase): num_iter = 10 for keep_prob in [0.1, 0.5, 0.8]: with self.test_session(): - t = constant_op.constant(1.0, - shape=[x_dim, y_dim], - dtype=dtypes.float32) - dropout = nn.dropout(t, keep_prob) + t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32) + dropout = tf.nn.dropout(t, keep_prob) final_count = 0 self.assertEqual([x_dim, y_dim], dropout.get_shape()) for _ in xrange(0, num_iter): @@ -382,10 +367,8 @@ class DropoutTest(test_util.TensorFlowTestCase): num_iter = 10 for keep_prob in [0.1, 0.5, 0.8]: with self.test_session(): - t = constant_op.constant(1.0, - shape=[x_dim, y_dim], - dtype=dtypes.float32) - dropout = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1]) + t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32) + dropout = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, 1]) self.assertEqual([x_dim, y_dim], dropout.get_shape()) final_count = 0 for _ in xrange(0, num_iter): @@ -408,10 +391,8 @@ class DropoutTest(test_util.TensorFlowTestCase): num_iter = 10 for keep_prob in [0.1, 0.5, 0.8]: with self.test_session(): - t = constant_op.constant(1.0, - shape=[x_dim, y_dim], - dtype=dtypes.float32) - dropout = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1]) + t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32) + dropout = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, 1]) self.assertEqual([x_dim, y_dim], dropout.get_shape()) for _ in xrange(0, num_iter): value = dropout.eval() @@ -429,11 +410,9 @@ class DropoutTest(test_util.TensorFlowTestCase): num_iter = 10 for keep_prob in [0.1, 0.5, 0.8]: with self.test_session(): - t = constant_op.constant(1.0, - shape=[x_dim, y_dim], - dtype=dtypes.float32) - keep_prob_placeholder = array_ops.placeholder(dtypes.float32) - dropout = nn.dropout(t, keep_prob_placeholder) + t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32) + keep_prob_placeholder = tf.placeholder(tf.float32) + dropout = tf.nn.dropout(t, keep_prob_placeholder) final_count = 0 self.assertEqual([x_dim, y_dim], dropout.get_shape()) for _ in xrange(0, num_iter): @@ -453,52 +432,49 @@ class DropoutTest(test_util.TensorFlowTestCase): x_dim = 40 y_dim = 30 keep_prob = 0.5 - x = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) - dropout_x = nn.dropout( - x, keep_prob, noise_shape=array_ops.placeholder(dtypes.int32)) + x = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32) + dropout_x = tf.nn.dropout(x, + keep_prob, + noise_shape=tf.placeholder(tf.int32)) self.assertEqual(x.get_shape(), dropout_x.get_shape()) def testInvalidKeepProb(self): x_dim = 40 y_dim = 30 - t = constant_op.constant(1.0, - shape=[x_dim, y_dim], - dtype=dtypes.float32) + t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32) with self.assertRaises(ValueError): - nn.dropout(t, -1.0) + tf.nn.dropout(t, -1.0) with self.assertRaises(ValueError): - nn.dropout(t, 1.1) + tf.nn.dropout(t, 1.1) with self.assertRaises(ValueError): - nn.dropout(t, [0.0, 1.0]) + tf.nn.dropout(t, [0.0, 1.0]) with self.assertRaises(ValueError): - nn.dropout(t, array_ops.placeholder(dtypes.float64)) + tf.nn.dropout(t, tf.placeholder(tf.float64)) with self.assertRaises(ValueError): - nn.dropout(t, array_ops.placeholder(dtypes.float32, shape=[2])) + tf.nn.dropout(t, tf.placeholder(tf.float32, shape=[2])) def testShapedDropoutShapeError(self): # Runs shaped dropout and verifies an error is thrown on misshapen noise. x_dim = 40 y_dim = 30 keep_prob = 0.5 - t = constant_op.constant(1.0, - shape=[x_dim, y_dim], - dtype=dtypes.float32) + t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32) with self.assertRaises(ValueError): - _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10]) + _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10]) with self.assertRaises(ValueError): - _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5]) + _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5]) with self.assertRaises(ValueError): - _ = nn.dropout(t, keep_prob, noise_shape=[x_dim + 3]) + _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim + 3]) with self.assertRaises(ValueError): - _ = nn.dropout(t, keep_prob, noise_shape=[x_dim]) + _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim]) # test that broadcasting proceeds - _ = nn.dropout(t, keep_prob, noise_shape=[y_dim]) - _ = nn.dropout(t, keep_prob, noise_shape=[1, y_dim]) - _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1]) - _ = nn.dropout(t, keep_prob, noise_shape=[1, 1]) + _ = tf.nn.dropout(t, keep_prob, noise_shape=[y_dim]) + _ = tf.nn.dropout(t, keep_prob, noise_shape=[1, y_dim]) + _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, 1]) + _ = tf.nn.dropout(t, keep_prob, noise_shape=[1, 1]) -class BatchNormWithGlobalNormalizationTest(test_util.TensorFlowTestCase): +class BatchNormWithGlobalNormalizationTest(tf.test.TestCase): def _npBatchNorm(self, x, m, v, beta, gamma, epsilon, scale_after_normalization): @@ -509,7 +485,7 @@ class BatchNormWithGlobalNormalizationTest(test_util.TensorFlowTestCase): def _opsBatchNorm(self, x, m, v, beta, gamma, epsilon, scale_after_normalization): - y = (x - m) * math_ops.rsqrt(v + epsilon) + y = (x - m) * tf.rsqrt(v + epsilon) if scale_after_normalization: y = gamma * y y += beta @@ -525,14 +501,14 @@ class BatchNormWithGlobalNormalizationTest(test_util.TensorFlowTestCase): gamma_val = np.random.random_sample(param_shape).astype(np.float32) for use_gpu in [True, False]: with self.test_session(use_gpu=use_gpu) as sess: - x = constant_op.constant(x_val, name="x") - m = constant_op.constant(m_val, name="m") - v = constant_op.constant(v_val, name="v") - beta = constant_op.constant(beta_val, name="beta") - gamma = constant_op.constant(gamma_val, name="gamma") + x = tf.constant(x_val, name="x") + m = tf.constant(m_val, name="m") + v = tf.constant(v_val, name="v") + beta = tf.constant(beta_val, name="beta") + gamma = tf.constant(gamma_val, name="gamma") epsilon = 0.001 for scale_after_normalization in [True, False]: - bn = nn.batch_norm_with_global_normalization( + bn = tf.nn.batch_norm_with_global_normalization( x, m, v, beta, gamma, epsilon, scale_after_normalization) on = self._opsBatchNorm( x, m, v, beta, gamma, epsilon, scale_after_normalization) @@ -555,20 +531,20 @@ class BatchNormWithGlobalNormalizationTest(test_util.TensorFlowTestCase): beta_val = np.random.random_sample(param_shape).astype(np.float64) gamma_val = np.random.random_sample(param_shape).astype(np.float64) with self.test_session(): - x = constant_op.constant(x_val, name="x") - m = constant_op.constant(m_val, name="m") - v = constant_op.constant(v_val, name="v") - beta = constant_op.constant(beta_val, name="beta") - gamma = constant_op.constant(gamma_val, name="gamma") + x = tf.constant(x_val, name="x") + m = tf.constant(m_val, name="m") + v = tf.constant(v_val, name="v") + beta = tf.constant(beta_val, name="beta") + gamma = tf.constant(gamma_val, name="gamma") epsilon = 0.001 # If scale_after_normalization is False, backprop for gamma # will be 0. gamma is unchanged. - output = nn.batch_norm_with_global_normalization( + output = tf.nn.batch_norm_with_global_normalization( x, m, v, beta, gamma, epsilon, scale_after_normalization) all_params = [x, m, v, beta, gamma] all_shapes = [x_shape, param_shape, param_shape, param_shape, param_shape] - err = gc.ComputeGradientError(all_params[param_index], - all_shapes[param_index], output, x_shape) + err = tf.test.compute_gradient_error( + all_params[param_index], all_shapes[param_index], output, x_shape) print("Batch normalization %s gradient %s scale err = " % (tag, "with" if scale_after_normalization else "without"), err) self.assertLess(err, err_tolerance) @@ -606,12 +582,12 @@ class BatchNormWithGlobalNormalizationTest(test_util.TensorFlowTestCase): backprop_val = np.random.random_sample(x_shape).astype(np.float32) for use_gpu in [False, True]: with self.test_session(use_gpu=use_gpu) as sess: - x = constant_op.constant(x_val, name="x") - m = constant_op.constant(m_val, name="m") - v = constant_op.constant(v_val, name="v") - beta = constant_op.constant(beta_val, name="beta") - gamma = constant_op.constant(gamma_val, name="gamma") - backprop = constant_op.constant(backprop_val, name="backprop") + x = tf.constant(x_val, name="x") + m = tf.constant(m_val, name="m") + v = tf.constant(v_val, name="v") + beta = tf.constant(beta_val, name="beta") + gamma = tf.constant(gamma_val, name="gamma") + backprop = tf.constant(backprop_val, name="backprop") epsilon = 0.001 for scale_after_normalization in [True, False]: dx, dm, dv, db, dg = ( @@ -619,7 +595,7 @@ class BatchNormWithGlobalNormalizationTest(test_util.TensorFlowTestCase): x, m, v, gamma, backprop, epsilon, scale_after_normalization)) on = self._opsBatchNorm( x, m, v, beta, gamma, epsilon, scale_after_normalization) - odx, odm, odv, odb, odg = gradients.gradients( + odx, odm, odv, odb, odg = tf.gradients( [on], [x, m, v, beta, gamma], [backprop]) if scale_after_normalization: all_grads = sess.run([dx, dm, dv, db, dg, odx, odm, odv, odb, odg]) @@ -633,7 +609,7 @@ class BatchNormWithGlobalNormalizationTest(test_util.TensorFlowTestCase): all_grads[i + len(to_check)], all_grads[i], atol=0.000001) -class MomentsTest(test_util.TensorFlowTestCase): +class MomentsTest(tf.test.TestCase): def RunMomentTestWithDynamicShape(self, shape, global_norm): with self.test_session(): @@ -641,10 +617,10 @@ class MomentsTest(test_util.TensorFlowTestCase): assert len(shape) == 4 x_numpy = np.random.normal(size=shape).astype(np.float32) - x = array_ops.placeholder(dtypes.float32, shape=[None] * len(shape)) + x = tf.placeholder(tf.float32, shape=[None] * len(shape)) axes = [0, 1, 2] if global_norm else [0] - mean, var = nn.moments(x, axes) + mean, var = tf.nn.moments(x, axes) num_elements = np.prod([shape[i] for i in axes]) @@ -665,10 +641,10 @@ class MomentsTest(test_util.TensorFlowTestCase): assert len(shape) == 4 x_numpy = np.random.normal(size=shape).astype(np.float32) - x = constant_op.constant(x_numpy) + x = tf.constant(x_numpy) axes = [0, 1, 2] if global_norm else [0] - mean, var = nn.moments(x, axes) + mean, var = tf.nn.moments(x, axes) num_elements = np.prod([shape[i] for i in axes]) @@ -695,17 +671,17 @@ class MomentsTest(test_util.TensorFlowTestCase): with self.test_session(): x_shape = [3, 5, 4, 2] x_val = np.random.random_sample(x_shape).astype(np.float64) - x = constant_op.constant(x_val) + x = tf.constant(x_val) x.set_shape(x_shape) axes = [0, 1, 2] y_shape = [2] # Depth of x - out_mean, out_var = nn.moments(x, axes) + out_mean, out_var = tf.nn.moments(x, axes) if from_y == "mean": y = out_mean elif from_y == "var": y = out_var - err = gc.ComputeGradientError(x, x_shape, y, y_shape) + err = tf.test.compute_gradient_error(x, x_shape, y, y_shape) print("Moments %s gradient err = %g" % (from_y, err)) self.assertLess(err, 1e-11) @@ -716,7 +692,7 @@ class MomentsTest(test_util.TensorFlowTestCase): self._testGlobalGradient(from_y="var") -class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): +class ComputeSampledLogitsTest(tf.test.TestCase): def setUp(self): self._num_classes = 5 @@ -768,18 +744,25 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): name="sampled_loss_TF"): # Should be called from within a `with test_session():` block if isinstance(weights, list): - weights_tf = [constant_op.constant(shard) for shard in weights] + weights_tf = [tf.constant(shard) for shard in weights] else: - weights_tf = constant_op.constant(weights) - biases_tf = constant_op.constant(biases) - hidden_acts_tf = constant_op.constant(hidden_acts, - shape=(self._batch_size, self._dim)) - labels_tf = constant_op.constant(labels, dtype=dtypes.int64, - shape=(self._batch_size, num_true)) + weights_tf = tf.constant(weights) + biases_tf = tf.constant(biases) + hidden_acts_tf = tf.constant(hidden_acts, + shape=(self._batch_size, self._dim)) + labels_tf = tf.constant(labels, + dtype=tf.int64, + shape=(self._batch_size, num_true)) - pred_logits_tf, pred_labels_tf = nn._compute_sampled_logits( - weights_tf, biases_tf, hidden_acts_tf, labels_tf, num_sampled, - num_classes, num_true, sampled_vals, + pred_logits_tf, pred_labels_tf = tf.nn._compute_sampled_logits( + weights_tf, + biases_tf, + hidden_acts_tf, + labels_tf, + num_sampled, + num_classes, + num_true, + sampled_vals, subtract_log_q=subtract_log_q, remove_accidental_hits=remove_accidental_hits, name=name) @@ -942,24 +925,28 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): nce_loss_np = np.sum( _SigmoidCrossEntropyWithLogits(logits_np, labels_np), 1) - labels_tf = constant_op.constant(labels, shape=(self._batch_size, 1)) - weights_tf = constant_op.constant(weights) - biases_tf = constant_op.constant(biases) - inputs_tf = constant_op.constant(hidden_acts) + labels_tf = tf.constant(labels, shape=(self._batch_size, 1)) + weights_tf = tf.constant(weights) + biases_tf = tf.constant(biases) + inputs_tf = tf.constant(hidden_acts) - nce_loss_tf = nn.nce_loss( - weights_tf, biases_tf, inputs_tf, labels_tf, - num_sampled=1, - num_classes=self._num_classes, - num_true=1, - sampled_values=test_sampled_vals) + nce_loss_tf = tf.nn.nce_loss(weights_tf, + biases_tf, + inputs_tf, + labels_tf, + num_sampled=1, + num_classes=self._num_classes, + num_true=1, + sampled_values=test_sampled_vals) self.assertAllClose(nce_loss_np, nce_loss_tf.eval(), 1e-4) # Test with sharded weights - nce_loss_tf = nn.nce_loss( - [constant_op.constant(shard) for shard in sharded_weights], - biases_tf, inputs_tf, labels_tf, + nce_loss_tf = tf.nn.nce_loss( + [tf.constant(shard) for shard in sharded_weights], + biases_tf, + inputs_tf, + labels_tf, num_sampled=1, num_classes=self._num_classes, num_true=1, @@ -996,13 +983,16 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): sampled_softmax_loss_np = _SoftmaxCrossEntropyWithLogits(logits_np, labels_np) - labels_tf = constant_op.constant(labels, shape=(self._batch_size, 1)) - weights_tf = constant_op.constant(weights) - biases_tf = constant_op.constant(biases) - inputs_tf = constant_op.constant(hidden_acts) + labels_tf = tf.constant(labels, shape=(self._batch_size, 1)) + weights_tf = tf.constant(weights) + biases_tf = tf.constant(biases) + inputs_tf = tf.constant(hidden_acts) - sampled_softmax_loss_tf = nn.sampled_softmax_loss( - weights_tf, biases_tf, inputs_tf, labels_tf, + sampled_softmax_loss_tf = tf.nn.sampled_softmax_loss( + weights_tf, + biases_tf, + inputs_tf, + labels_tf, num_sampled=1, num_classes=self._num_classes, num_true=1, @@ -1013,9 +1003,11 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): sampled_softmax_loss_np, sampled_softmax_loss_tf.eval(), 1e-4) # Test with sharded weights - sampled_softmax_loss_tf = nn.sampled_softmax_loss( - [constant_op.constant(shard) for shard in sharded_weights], - biases_tf, inputs_tf, labels_tf, + sampled_softmax_loss_tf = tf.nn.sampled_softmax_loss( + [tf.constant(shard) for shard in sharded_weights], + biases_tf, + inputs_tf, + labels_tf, num_sampled=1, num_classes=self._num_classes, num_true=1, @@ -1027,4 +1019,4 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): if __name__ == "__main__": - googletest.main() + tf.test.main() diff --git a/tensorflow/python/ops/op_def_library.py b/tensorflow/python/ops/op_def_library.py index ad0406d43dc..c2ad3bdb582 100644 --- a/tensorflow/python/ops/op_def_library.py +++ b/tensorflow/python/ops/op_def_library.py @@ -378,14 +378,18 @@ class OpDefLibrary(object): break try: + if not input_arg.is_ref and dtype: + dtype = dtypes.as_dtype(dtype).base_dtype values = ops.convert_n_to_tensor_or_indexed_slices( values, name=input_arg.name, - dtype=dtypes.as_dtype(dtype).base_dtype if dtype else None) + dtype=dtype if dtype else None, + as_ref=input_arg.is_ref) except (TypeError, ValueError): assert dtype is not None, "Should not fail if dtype is None" assert input_arg.number_attr, "Should be number_attr case" # What types does the conversion function think values have? - values = ops.convert_n_to_tensor_or_indexed_slices(values) + values = ops.convert_n_to_tensor_or_indexed_slices( + values, as_ref=input_arg.is_ref) observed = ", ".join(v.dtype.base_dtype.name for v in values) prefix = ( @@ -393,11 +397,11 @@ class OpDefLibrary(object): (input_name, op_type_name, observed)) if input_arg.type != types_pb2.DT_INVALID: raise TypeError("%s that do not match expected type %s." % - (prefix, dtypes.as_dtype(dtype).name)) + (prefix, dtype.name)) elif input_arg.type_attr in attrs: raise TypeError("%s that do not match type %s inferred from " "earlier arguments." % - (prefix, dtypes.as_dtype(dtype).name)) + (prefix, dtype.name)) else: raise TypeError("%s that don't all match." % prefix) @@ -411,13 +415,14 @@ class OpDefLibrary(object): dtype = input_arg.type elif input_arg.type_attr in attrs: dtype = attrs[input_arg.type_attr] - try: values = ops.convert_to_tensor( - values, name=input_arg.name, dtype=dtype) + values, name=input_arg.name, dtype=dtype, + as_ref=input_arg.is_ref) except ValueError: # What type does convert_to_tensor think it has? - observed = ops.convert_to_tensor(values).dtype.name + observed = ops.convert_to_tensor(values, + as_ref=input_arg.is_ref).dtype.name prefix = ("Input '%s' of '%s' Op has type %s that does not match" % (input_name, op_type_name, observed)) if input_arg.type != types_pb2.DT_INVALID: diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py new file mode 100644 index 00000000000..e7d70ea79e3 --- /dev/null +++ b/tensorflow/python/ops/rnn.py @@ -0,0 +1,150 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""RNN helpers for TensorFlow models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import variable_scope as vs + + +def rnn(cell, inputs, initial_state=None, dtype=None, + sequence_length=None, scope=None): + """Creates a recurrent neural network specified by RNNCell "cell". + + The simplest form of RNN network generated is: + state = cell.zero_state(...) + outputs = [] + states = [] + for input_ in inputs: + output, state = cell(input_, state) + outputs.append(output) + states.append(state) + return (outputs, states) + + However, a few other options are available: + + An initial state can be provided. + If sequence_length is provided, dynamic calculation is performed. + + Dynamic calculation returns, at time t: + (t >= max(sequence_length) + ? (zeros(output_shape), zeros(state_shape)) + : cell(input, state) + + Thus saving computational time when unrolling past the max sequence length. + + Args: + cell: An instance of RNNCell. + inputs: A length T list of inputs, each a vector with shape [batch_size]. + initial_state: (optional) An initial state for the RNN. This must be + a tensor of appropriate type and shape [batch_size x cell.state_size]. + dtype: (optional) The data type for the initial state. Required if + initial_state is not provided. + sequence_length: An int64 vector (tensor) size [batch_size]. + scope: VariableScope for the created subgraph; defaults to "RNN". + + Returns: + A pair (outputs, states) where: + outputs is a length T list of outputs (one for each input) + states is a length T list of states (one state following each input) + + Raises: + TypeError: If "cell" is not an instance of RNNCell. + ValueError: If inputs is None or an empty list. + """ + + if not isinstance(cell, rnn_cell.RNNCell): + raise TypeError("cell must be an instance of RNNCell") + if not isinstance(inputs, list): + raise TypeError("inputs must be a list") + if not inputs: + raise ValueError("inputs must not be empty") + + outputs = [] + states = [] + with vs.variable_scope(scope or "RNN"): + batch_size = array_ops.shape(inputs[0])[0] + if initial_state is not None: + state = initial_state + else: + if not dtype: + raise ValueError("If no initial_state is provided, dtype must be.") + state = cell.zero_state(batch_size, dtype) + + if sequence_length: # Prepare variables + zero_output_state = ( + array_ops.zeros(array_ops.pack([batch_size, cell.output_size]), + inputs[0].dtype), + array_ops.zeros(array_ops.pack([batch_size, cell.state_size]), + state.dtype)) + max_sequence_length = math_ops.reduce_max(sequence_length) + + for time, input_ in enumerate(inputs): + if time > 0: vs.get_variable_scope().reuse_variables() + # pylint: disable=cell-var-from-loop + def output_state(): + return cell(input_, state) + # pylint: enable=cell-var-from-loop + if sequence_length: + (output, state) = control_flow_ops.cond( + time >= max_sequence_length, + lambda: zero_output_state, output_state) + else: + (output, state) = output_state() + + outputs.append(output) + states.append(state) + + return (outputs, states) + + +def state_saving_rnn(cell, inputs, state_saver, state_name, + sequence_length=None, scope=None): + """RNN that accepts a state saver for time-truncated RNN calculation. + + Args: + cell: An instance of RNNCell. + inputs: A length T list of inputs, each a vector with shape [batch_size]. + state_saver: A state saver object with methods `state` and `save_state`. + state_name: The name to use with the state_saver. + sequence_length: (optional) An int64 vector (tensor) size [batch_size]. + See the documentation for rnn() for more details about sequence_length. + scope: VariableScope for the created subgraph; defaults to "RNN". + + Returns: + A pair (outputs, states) where: + outputs is a length T list of outputs (one for each input) + states is a length T list of states (one state following each input) + + Raises: + TypeError: If "cell" is not an instance of RNNCell. + ValueError: If inputs is None or an empty list. + """ + initial_state = state_saver.state(state_name) + (outputs, states) = rnn(cell, inputs, initial_state=initial_state, + sequence_length=sequence_length, scope=scope) + save_state = state_saver.save_state(state_name, states[-1]) + with ops.control_dependencies([save_state]): + outputs[-1] = array_ops.identity(outputs[-1]) + + return (outputs, states) diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py new file mode 100644 index 00000000000..584849236a9 --- /dev/null +++ b/tensorflow/python/ops/rnn_cell.py @@ -0,0 +1,685 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Module for constructing RNN Cells.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import variable_scope as vs + +from tensorflow.python.ops.math_ops import sigmoid +from tensorflow.python.ops.math_ops import tanh + + +class RNNCell(object): + """Abstract object representing an RNN cell. + + An RNN cell, in the most abstract setting, is anything that has + a state -- a vector of floats of size self.state_size -- and performs some + operation that takes inputs of size self.input_size. This operation + results in an output of size self.output_size and a new state. + + This module provides a number of basic commonly used RNN cells, such as + LSTM (Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number + of operators that allow add dropouts, projections, or embeddings for inputs. + Constructing multi-layer cells is supported by a super-class, MultiRNNCell, + defined later. Every RNNCell must have the properties below and and + implement __call__ with the following signature. + """ + + def __call__(self, inputs, state, scope=None): + """Run this RNN cell on inputs, starting from the given state. + + Args: + inputs: 2D Tensor with shape [batch_size x self.input_size]. + state: 2D Tensor with shape [batch_size x self.state_size]. + scope: VariableScope for the created subgraph; defaults to class name. + + Returns: + A pair containing: + - Output: A 2D Tensor with shape [batch_size x self.output_size] + - New state: A 2D Tensor with shape [batch_size x self.state_size]. + """ + raise NotImplementedError("Abstract method") + + @property + def input_size(self): + """Integer: size of inputs accepted by this cell.""" + raise NotImplementedError("Abstract method") + + @property + def output_size(self): + """Integer: size of outputs produced by this cell.""" + raise NotImplementedError("Abstract method") + + @property + def state_size(self): + """Integer: size of state used by this cell.""" + raise NotImplementedError("Abstract method") + + def zero_state(self, batch_size, dtype): + """Return state tensor (shape [batch_size x state_size]) filled with 0. + + Args: + batch_size: int, float, or unit Tensor representing the batch size. + dtype: the data type to use for the state. + + Returns: + A 2D Tensor of shape [batch_size x state_size] filled with zeros. + """ + zeros = array_ops.zeros( + array_ops.pack([batch_size, self.state_size]), dtype=dtype) + zeros.set_shape([None, self.state_size]) + return zeros + + +class BasicRNNCell(RNNCell): + """The most basic RNN cell.""" + + def __init__(self, num_units): + self._num_units = num_units + + @property + def input_size(self): + return self._num_units + + @property + def output_size(self): + return self._num_units + + @property + def state_size(self): + return self._num_units + + def __call__(self, inputs, state, scope=None): + """Most basic RNN: output = new_state = tanh(W * input + U * state + B).""" + with vs.variable_scope(scope or type(self).__name__): # "BasicRNNCell" + output = tanh(linear([inputs, state], self._num_units, True)) + return output, output + + +class GRUCell(RNNCell): + """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).""" + + def __init__(self, num_units): + self._num_units = num_units + + @property + def input_size(self): + return self._num_units + + @property + def output_size(self): + return self._num_units + + @property + def state_size(self): + return self._num_units + + def __call__(self, inputs, state, scope=None): + """Gated recurrent unit (GRU) with nunits cells.""" + with vs.variable_scope(scope or type(self).__name__): # "GRUCell" + with vs.variable_scope("Gates"): # Reset gate and update gate. + # We start with bias of 1.0 to not reset and not udpate. + r, u = array_ops.split(1, 2, linear([inputs, state], + 2 * self._num_units, True, 1.0)) + r, u = sigmoid(r), sigmoid(u) + with vs.variable_scope("Candidate"): + c = tanh(linear([inputs, r * state], self._num_units, True)) + new_h = u * state + (1 - u) * c + return new_h, new_h + + +class BasicLSTMCell(RNNCell): + """Basic LSTM recurrent network cell. + + The implementation is based on: http://arxiv.org/pdf/1409.2329v5.pdf. + + It does not allow cell clipping, a projection layer, and does not + use peep-hole connections: it is the basic baseline. + + Biases of the forget gate are initialized by default to 1 in order to reduce + the scale of forgetting in the beginning of the training. + """ + + def __init__(self, num_units, forget_bias=1.0): + self._num_units = num_units + self._forget_bias = forget_bias + + @property + def input_size(self): + return self._num_units + + @property + def output_size(self): + return self._num_units + + @property + def state_size(self): + return 2 * self._num_units + + def __call__(self, inputs, state, scope=None): + """Long short-term memory cell (LSTM).""" + with vs.variable_scope(scope or type(self).__name__): # "BasicLSTMCell" + # Parameters of gates are concatenated into one multiply for efficiency. + c, h = array_ops.split(1, 2, state) + concat = linear([inputs, h], 4 * self._num_units, True) + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + i, j, f, o = array_ops.split(1, 4, concat) + + new_c = c * sigmoid(f + self._forget_bias) + sigmoid(i) * tanh(j) + new_h = tanh(new_c) * sigmoid(o) + + return new_h, array_ops.concat(1, [new_c, new_h]) + + +class LSTMCell(RNNCell): + """Long short-term memory unit (LSTM) recurrent network cell. + + This implementation is based on: + + https://research.google.com/pubs/archive/43905.pdf + + Hasim Sak, Andrew Senior, and Francoise Beaufays. + "Long short-term memory recurrent neural network architectures for + large scale acoustic modeling." INTERSPEECH, 2014. + + It uses peep-hole connections, optional cell clipping, and an optional + projection layer. + """ + + def __init__(self, num_units, input_size, + use_peepholes=False, cell_clip=None, + initializer=None, num_proj=None, + num_unit_shards=1, num_proj_shards=1): + """Initialize the parameters for an LSTM cell. + + Args: + num_units: int, The number of units in the LSTM cell + input_size: int, The dimensionality of the inputs into the LSTM cell + use_peepholes: bool, set True to enable diagonal/peephole connections. + cell_clip: (optional) A float value, if provided the cell state is clipped + by this value prior to the cell output activation. + initializer: (optional) The initializer to use for the weight and + projection matrices. + num_proj: (optional) int, The output dimensionality for the projection + matrices. If None, no projection is performed. + num_unit_shards: How to split the weight matrix. If >1, the weight + matrix is stored across num_unit_shards. + Note that num_unit_shards must evenly divide num_units * 4. + num_proj_shards: How to split the projection matrix. If >1, the + projection matrix is stored across num_proj_shards. + Note that num_proj_shards must evenly divide num_proj + (if num_proj is not None). + + Raises: + ValueError: if num_unit_shards doesn't divide 4 * num_units or + num_proj_shards doesn't divide num_proj + """ + self._num_units = num_units + self._input_size = input_size + self._use_peepholes = use_peepholes + self._cell_clip = cell_clip + self._initializer = initializer + self._num_proj = num_proj + self._num_unit_shards = num_unit_shards + self._num_proj_shards = num_proj_shards + + if (num_units * 4) % num_unit_shards != 0: + raise ValueError("num_unit_shards must evently divide 4 * num_units") + if num_proj and num_proj % num_proj_shards != 0: + raise ValueError("num_proj_shards must evently divide num_proj") + + if num_proj: + self._state_size = num_units + num_proj + self._output_size = num_proj + else: + self._state_size = 2 * num_units + self._output_size = num_units + + @property + def input_size(self): + return self._input_size + + @property + def output_size(self): + return self._output_size + + @property + def state_size(self): + return self._state_size + + def __call__(self, input_, state, scope=None): + """Run one step of LSTM. + + Args: + input_: input Tensor, 2D, batch x num_units. + state: state Tensor, 2D, batch x state_size. + scope: VariableScope for the created subgraph; defaults to "LSTMCell". + + Returns: + A tuple containing: + - A 2D, batch x output_dim, Tensor representing the output of the LSTM + after reading "input_" when previous state was "state". + Here output_dim is: + num_proj if num_proj was set, + num_units otherwise. + - A 2D, batch x state_size, Tensor representing the new state of LSTM + after reading "input_" when previous state was "state". + """ + num_proj = self._num_units if self._num_proj is None else self._num_proj + + c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) + m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) + + dtype = input_.dtype + + unit_shard_size = (4 * self._num_units) // self._num_unit_shards + + with vs.variable_scope(scope or type(self).__name__): # "LSTMCell" + w = array_ops.concat( + 1, + [vs.get_variable("W_%d" % i, + shape=[self.input_size + num_proj, unit_shard_size], + initializer=self._initializer, + dtype=dtype) for i in xrange(self._num_unit_shards)]) + + b = vs.get_variable( + "B", shape=[4 * self._num_units], + initializer=array_ops.zeros_initializer, dtype=dtype) + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + cell_inputs = array_ops.concat(1, [input_, m_prev]) + i, j, f, o = array_ops.split( + 1, 4, nn_ops.bias_add(math_ops.matmul(cell_inputs, w), b)) + + # Diagonal connections + if self._use_peepholes: + w_f_diag = vs.get_variable( + "W_F_diag", shape=[self._num_units], dtype=dtype) + w_i_diag = vs.get_variable( + "W_I_diag", shape=[self._num_units], dtype=dtype) + w_o_diag = vs.get_variable( + "W_O_diag", shape=[self._num_units], dtype=dtype) + + if self._use_peepholes: + c = (sigmoid(f + 1 + w_f_diag * c_prev) * c_prev + + sigmoid(i + w_i_diag * c_prev) * tanh(j)) + else: + c = (sigmoid(f + 1) * c_prev + sigmoid(i) * tanh(j)) + + if self._cell_clip is not None: + c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) + + if self._use_peepholes: + m = sigmoid(o + w_o_diag * c) * tanh(c) + else: + m = sigmoid(o) * tanh(c) + + if self._num_proj is not None: + proj_shard_size = self._num_proj // self._num_proj_shards + w_proj = array_ops.concat( + 1, + [vs.get_variable("W_P_%d" % i, + shape=[self._num_units, proj_shard_size], + initializer=self._initializer, + dtype=dtype) + for i in xrange(self._num_proj_shards)]) + # TODO(ebrevdo), use matmulsum + m = math_ops.matmul(m, w_proj) + + return m, array_ops.concat(1, [c, m]) + + +class OutputProjectionWrapper(RNNCell): + """Operator adding an output projection to the given cell. + + Note: in many cases it may be more efficient to not use this wrapper, + but instead concatenate the whole sequence of your outputs in time, + do the projection on this batch-concated sequence, then split it + if needed or directly feed into a softmax. + """ + + def __init__(self, cell, output_size): + """Create a cell with output projection. + + Args: + cell: an RNNCell, a projection to output_size is added to it. + output_size: integer, the size of the output after projection. + + Raises: + TypeError: if cell is not an RNNCell. + ValueError: if output_size is not positive. + """ + if not isinstance(cell, RNNCell): + raise TypeError("The parameter cell is not RNNCell.") + if output_size < 1: + raise ValueError("Parameter output_size must be > 0: %d." % output_size) + self._cell = cell + self._output_size = output_size + + @property + def input_size(self): + return self._cell.input_size + + @property + def output_size(self): + return self._output_size + + @property + def state_size(self): + return self._cell.state_size + + def __call__(self, inputs, state, scope=None): + """Run the cell and output projection on inputs, starting from state.""" + output, res_state = self._cell(inputs, state) + # Default scope: "OutputProjectionWrapper" + with vs.variable_scope(scope or type(self).__name__): + projected = linear(output, self._output_size, True) + return projected, res_state + + +class InputProjectionWrapper(RNNCell): + """Operator adding an input projection to the given cell. + + Note: in many cases it may be more efficient to not use this wrapper, + but instead concatenate the whole sequence of your inputs in time, + do the projection on this batch-concated sequence, then split it. + """ + + def __init__(self, cell, input_size): + """Create a cell with input projection. + + Args: + cell: an RNNCell, a projection of inputs is added before it. + input_size: integer, the size of the inputs before projection. + + Raises: + TypeError: if cell is not an RNNCell. + ValueError: if input_size is not positive. + """ + if not isinstance(cell, RNNCell): + raise TypeError("The parameter cell is not RNNCell.") + if input_size < 1: + raise ValueError("Parameter input_size must be > 0: %d." % input_size) + self._cell = cell + self._input_size = input_size + + @property + def input_size(self): + return self._input_size + + @property + def output_size(self): + return self._cell.output_size + + @property + def state_size(self): + return self._cell.state_size + + def __call__(self, inputs, state, scope=None): + """Run the input projection and then the cell.""" + # Default scope: "InputProjectionWrapper" + with vs.variable_scope(scope or type(self).__name__): + projected = linear(inputs, self._cell.input_size, True) + return self._cell(projected, state) + + +class DropoutWrapper(RNNCell): + """Operator adding dropout to inputs and outputs of the given cell.""" + + def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0, + seed=None): + """Create a cell with added input and/or output dropout. + + Dropout is never used on the state. + + Args: + cell: an RNNCell, a projection to output_size is added to it. + input_keep_prob: unit Tensor or float between 0 and 1, input keep + probability; if it is float and 1, no input dropout will be added. + output_keep_prob: unit Tensor or float between 0 and 1, output keep + probability; if it is float and 1, no output dropout will be added. + seed: (optional) integer, the randomness seed. + + Raises: + TypeError: if cell is not an RNNCell. + ValueError: if keep_prob is not between 0 and 1. + """ + if not isinstance(cell, RNNCell): + raise TypeError("The parameter cell is not a RNNCell.") + if (isinstance(input_keep_prob, float) and + not (input_keep_prob >= 0.0 and input_keep_prob <= 1.0)): + raise ValueError("Parameter input_keep_prob must be between 0 and 1: %d" + % input_keep_prob) + if (isinstance(output_keep_prob, float) and + not (output_keep_prob >= 0.0 and output_keep_prob <= 1.0)): + raise ValueError("Parameter input_keep_prob must be between 0 and 1: %d" + % output_keep_prob) + self._cell = cell + self._input_keep_prob = input_keep_prob + self._output_keep_prob = output_keep_prob + self._seed = seed + + @property + def input_size(self): + return self._cell.input_size + + @property + def output_size(self): + return self._cell.output_size + + @property + def state_size(self): + return self._cell.state_size + + def __call__(self, inputs, state): + """Run the cell with the declared dropouts.""" + if (not isinstance(self._input_keep_prob, float) or + self._input_keep_prob < 1): + inputs = nn_ops.dropout(inputs, self._input_keep_prob, seed=self._seed) + output, new_state = self._cell(inputs, state) + if (not isinstance(self._output_keep_prob, float) or + self._output_keep_prob < 1): + output = nn_ops.dropout(output, self._output_keep_prob, seed=self._seed) + return output, new_state + + +class EmbeddingWrapper(RNNCell): + """Operator adding input embedding to the given cell. + + Note: in many cases it may be more efficient to not use this wrapper, + but instead concatenate the whole sequence of your inputs in time, + do the embedding on this batch-concated sequence, then split it and + feed into your RNN. + """ + + def __init__(self, cell, embedding_classes=0, embedding=None, + initializer=None): + """Create a cell with an added input embedding. + + Args: + cell: an RNNCell, an embedding will be put before its inputs. + embedding_classes: integer, how many symbols will be embedded. + embedding: Variable, the embedding to use; if None, a new embedding + will be created; if set, then embedding_classes is not required. + initializer: an initializer to use when creating the embedding; + if None, the initializer from variable scope or a default one is used. + + Raises: + TypeError: if cell is not an RNNCell. + ValueError: if embedding_classes is not positive. + """ + if not isinstance(cell, RNNCell): + raise TypeError("The parameter cell is not RNNCell.") + if embedding_classes < 1 and embedding is None: + raise ValueError("Pass embedding or embedding_classes must be > 0: %d." + % embedding_classes) + if embedding_classes > 0 and embedding is not None: + if embedding.size[0] != embedding_classes: + raise ValueError("You declared embedding_classes=%d but passed an " + "embedding for %d classes." % (embedding.size[0], + embedding_classes)) + if embedding.size[1] != cell.input_size: + raise ValueError("You passed embedding with output size %d and a cell" + " that accepts size %d." % (embedding.size[1], + cell.input_size)) + self._cell = cell + self._embedding_classes = embedding_classes + self._embedding = embedding + self._initializer = initializer + + @property + def input_size(self): + return 1 + + @property + def output_size(self): + return self._cell.output_size + + @property + def state_size(self): + return self._cell.state_size + + def __call__(self, inputs, state, scope=None): + """Run the cell on embedded inputs.""" + with vs.variable_scope(scope or type(self).__name__): # "EmbeddingWrapper" + with ops.device("/cpu:0"): + if self._embedding: + embedding = self._embedding + else: + if self._initializer: + initializer = self._initializer + elif vs.get_variable_scope().initializer: + initializer = vs.get_variable_scope().initializer + else: + # Default initializer for embeddings should have variance=1. + sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1. + initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3) + embedding = vs.get_variable("embedding", [self._embedding_classes, + self._cell.input_size], + initializer=initializer) + embedded = embedding_ops.embedding_lookup( + embedding, array_ops.reshape(inputs, [-1])) + return self._cell(embedded, state) + + +class MultiRNNCell(RNNCell): + """RNN cell composed sequentially of multiple simple cells.""" + + def __init__(self, cells): + """Create a RNN cell composed sequentially of a number of RNNCells. + + Args: + cells: list of RNNCells that will be composed in this order. + + Raises: + ValueError: if cells is empty (not allowed) or if their sizes don't match. + """ + if not cells: + raise ValueError("Must specify at least one cell for MultiRNNCell.") + for i in xrange(len(cells) - 1): + if cells[i + 1].input_size != cells[i].output_size: + raise ValueError("In MultiRNNCell, the input size of each next" + " cell must match the output size of the previous one." + " Mismatched output size in cell %d." % i) + self._cells = cells + + @property + def input_size(self): + return self._cells[0].input_size + + @property + def output_size(self): + return self._cells[-1].output_size + + @property + def state_size(self): + return sum([cell.state_size for cell in self._cells]) + + def __call__(self, inputs, state, scope=None): + """Run this multi-layer cell on inputs, starting from state.""" + with vs.variable_scope(scope or type(self).__name__): # "MultiRNNCell" + cur_state_pos = 0 + cur_inp = inputs + new_states = [] + for i, cell in enumerate(self._cells): + with vs.variable_scope("Cell%d" % i): + cur_state = array_ops.slice( + state, [0, cur_state_pos], [-1, cell.state_size]) + cur_state_pos += cell.state_size + cur_inp, new_state = cell(cur_inp, cur_state) + new_states.append(new_state) + return cur_inp, array_ops.concat(1, new_states) + + +def linear(args, output_size, bias, bias_start=0.0, scope=None): + """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. + + Args: + args: a 2D Tensor or a list of 2D, batch x n, Tensors. + output_size: int, second dimension of W[i]. + bias: boolean, whether to add a bias term or not. + bias_start: starting value to initialize the bias; 0 by default. + scope: VariableScope for the created subgraph; defaults to "Linear". + + Returns: + A 2D Tensor with shape [batch x output_size] equal to + sum_i(args[i] * W[i]), where W[i]s are newly created matrices. + + Raises: + ValueError: if some of the arguments has unspecified or wrong shape. + """ + assert args + if not isinstance(args, (list, tuple)): + args = [args] + + # Calculate the total size of arguments on dimension 1. + total_arg_size = 0 + shapes = [a.get_shape().as_list() for a in args] + for shape in shapes: + if len(shape) != 2: + raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes)) + if not shape[1]: + raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes)) + else: + total_arg_size += shape[1] + + # Now the computation. + with vs.variable_scope(scope or "Linear"): + matrix = vs.get_variable("Matrix", [total_arg_size, output_size]) + if len(args) == 1: + res = math_ops.matmul(args[0], matrix) + else: + res = math_ops.matmul(array_ops.concat(1, args), matrix) + if not bias: + return res + bias_term = vs.get_variable( + "Bias", [output_size], + initializer=init_ops.constant_initializer(bias_start)) + return res + bias_term diff --git a/tensorflow/python/ops/seq2seq.py b/tensorflow/python/ops/seq2seq.py new file mode 100644 index 00000000000..131524b77c5 --- /dev/null +++ b/tensorflow/python/ops/seq2seq.py @@ -0,0 +1,784 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Library for creating sequence-to-sequence models.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import variable_scope as vs + + +def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None, + scope=None): + """RNN decoder for the sequence-to-sequence model. + + Args: + decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + initial_state: 2D Tensor with shape [batch_size x cell.state_size]. + cell: rnn_cell.RNNCell defining the cell function and size. + loop_function: if not None, this function will be applied to i-th output + in order to generate i+1-th input, and decoder_inputs will be ignored, + except for the first element ("GO" symbol). This can be used for decoding, + but also for training to emulate http://arxiv.org/pdf/1506.03099v2.pdf. + Signature -- loop_function(prev, i) = next + * prev is a 2D Tensor of shape [batch_size x cell.output_size], + * i is an integer, the step number (when advanced control is needed), + * next is a 2D Tensor of shape [batch_size x cell.input_size]. + scope: VariableScope for the created subgraph; defaults to "rnn_decoder". + + Returns: + outputs: A list of the same length as decoder_inputs of 2D Tensors with + shape [batch_size x cell.output_size] containing generated outputs. + states: The state of each cell in each time-step. This is a list with + length len(decoder_inputs) -- one item for each time-step. + Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + (Note that in some cases, like basic RNN cell or GRU cell, outputs and + states can be the same. They are different for LSTM cells though.) + """ + with vs.variable_scope(scope or "rnn_decoder"): + states = [initial_state] + outputs = [] + prev = None + for i in xrange(len(decoder_inputs)): + inp = decoder_inputs[i] + if loop_function is not None and prev is not None: + with vs.variable_scope("loop_function", reuse=True): + # We do not propagate gradients over the loop function. + inp = array_ops.stop_gradient(loop_function(prev, i)) + if i > 0: + vs.get_variable_scope().reuse_variables() + output, new_state = cell(inp, states[-1]) + outputs.append(output) + states.append(new_state) + if loop_function is not None: + prev = array_ops.stop_gradient(output) + return outputs, states + + +def basic_rnn_seq2seq( + encoder_inputs, decoder_inputs, cell, dtype=dtypes.float32, scope=None): + """Basic RNN sequence-to-sequence model. + + This model first runs an RNN to encode encoder_inputs into a state vector, and + then runs decoder, initialized with the last encoder state, on decoder_inputs. + Encoder and decoder use the same RNN cell type, but don't share parameters. + + Args: + encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + cell: rnn_cell.RNNCell defining the cell function and size. + dtype: The dtype of the initial state of the RNN cell (default: tf.float32). + scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq". + + Returns: + outputs: A list of the same length as decoder_inputs of 2D Tensors with + shape [batch_size x cell.output_size] containing the generated outputs. + states: The state of each decoder cell in each time-step. This is a list + with length len(decoder_inputs) -- one item for each time-step. + Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + """ + with vs.variable_scope(scope or "basic_rnn_seq2seq"): + _, enc_states = rnn.rnn(cell, encoder_inputs, dtype=dtype) + return rnn_decoder(decoder_inputs, enc_states[-1], cell) + + +def tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, + loop_function=None, dtype=dtypes.float32, scope=None): + """RNN sequence-to-sequence model with tied encoder and decoder parameters. + + This model first runs an RNN to encode encoder_inputs into a state vector, and + then runs decoder, initialized with the last encoder state, on decoder_inputs. + Encoder and decoder use the same RNN cell and share parameters. + + Args: + encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + cell: rnn_cell.RNNCell defining the cell function and size. + loop_function: if not None, this function will be applied to i-th output + in order to generate i+1-th input, and decoder_inputs will be ignored, + except for the first element ("GO" symbol), see rnn_decoder for details. + dtype: The dtype of the initial state of the rnn cell (default: tf.float32). + scope: VariableScope for the created subgraph; default: "tied_rnn_seq2seq". + + Returns: + outputs: A list of the same length as decoder_inputs of 2D Tensors with + shape [batch_size x cell.output_size] containing the generated outputs. + states: The state of each decoder cell in each time-step. This is a list + with length len(decoder_inputs) -- one item for each time-step. + Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + """ + with vs.variable_scope("combined_tied_rnn_seq2seq"): + scope = scope or "tied_rnn_seq2seq" + _, enc_states = rnn.rnn( + cell, encoder_inputs, dtype=dtype, scope=scope) + vs.get_variable_scope().reuse_variables() + return rnn_decoder(decoder_inputs, enc_states[-1], cell, + loop_function=loop_function, scope=scope) + + +def embedding_rnn_decoder(decoder_inputs, initial_state, cell, num_symbols, + output_projection=None, feed_previous=False, + scope=None): + """RNN decoder with embedding and a pure-decoding option. + + Args: + decoder_inputs: a list of 1D batch-sized int32-Tensors (decoder inputs). + initial_state: 2D Tensor [batch_size x cell.state_size]. + cell: rnn_cell.RNNCell defining the cell function. + num_symbols: integer, how many symbols come into the embedding. + output_projection: None or a pair (W, B) of output projection weights and + biases; W has shape [cell.output_size x num_symbols] and B has + shape [num_symbols]; if provided and feed_previous=True, each fed + previous output will first be multiplied by W and added B. + feed_previous: Boolean; if True, only the first of decoder_inputs will be + used (the "GO" symbol), and all other decoder inputs will be generated by: + next = embedding_lookup(embedding, argmax(previous_output)), + In effect, this implements a greedy decoder. It can also be used + during training to emulate http://arxiv.org/pdf/1506.03099v2.pdf. + If False, decoder_inputs are used as given (the standard decoder case). + scope: VariableScope for the created subgraph; defaults to + "embedding_rnn_decoder". + + Returns: + outputs: A list of the same length as decoder_inputs of 2D Tensors with + shape [batch_size x cell.output_size] containing the generated outputs. + states: The state of each decoder cell in each time-step. This is a list + with length len(decoder_inputs) -- one item for each time-step. + Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + + Raises: + ValueError: when output_projection has the wrong shape. + """ + if output_projection is not None: + proj_weights = ops.convert_to_tensor( + output_projection[0], dtype=dtypes.float32) + proj_weights.get_shape().assert_is_compatible_with([cell.output_size, + num_symbols]) + proj_biases = ops.convert_to_tensor( + output_projection[1], dtype=dtypes.float32) + proj_biases.get_shape().assert_is_compatible_with([num_symbols]) + + with vs.variable_scope(scope or "embedding_rnn_decoder"): + with ops.device("/cpu:0"): + embedding = vs.get_variable("embedding", [num_symbols, cell.input_size]) + + def extract_argmax_and_embed(prev, _): + """Loop_function that extracts the symbol from prev and embeds it.""" + if output_projection is not None: + prev = nn_ops.xw_plus_b( + prev, output_projection[0], output_projection[1]) + prev_symbol = array_ops.stop_gradient(math_ops.argmax(prev, 1)) + return embedding_ops.embedding_lookup(embedding, prev_symbol) + + loop_function = None + if feed_previous: + loop_function = extract_argmax_and_embed + + emb_inp = [ + embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs] + return rnn_decoder(emb_inp, initial_state, cell, + loop_function=loop_function) + + +def embedding_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, + num_encoder_symbols, num_decoder_symbols, + output_projection=None, feed_previous=False, + dtype=dtypes.float32, scope=None): + """Embedding RNN sequence-to-sequence model. + + This model first embeds encoder_inputs by a newly created embedding (of shape + [num_encoder_symbols x cell.input_size]). Then it runs an RNN to encode + embedded encoder_inputs into a state vector. Next, it embeds decoder_inputs + by another newly created embedding (of shape [num_decoder_symbols x + cell.input_size]). Then it runs RNN decoder, initialized with the last + encoder state, on embedded decoder_inputs. + + Args: + encoder_inputs: a list of 1D int32-Tensors of shape [batch_size]. + decoder_inputs: a list of 1D int32-Tensors of shape [batch_size]. + cell: rnn_cell.RNNCell defining the cell function and size. + num_encoder_symbols: integer; number of symbols on the encoder side. + num_decoder_symbols: integer; number of symbols on the decoder side. + output_projection: None or a pair (W, B) of output projection weights and + biases; W has shape [cell.output_size x num_decoder_symbols] and B has + shape [num_decoder_symbols]; if provided and feed_previous=True, each + fed previous output will first be multiplied by W and added B. + feed_previous: Boolean or scalar Boolean Tensor; if True, only the first + of decoder_inputs will be used (the "GO" symbol), and all other decoder + inputs will be taken from previous outputs (as in embedding_rnn_decoder). + If False, decoder_inputs are used as given (the standard decoder case). + dtype: The dtype of the initial state for both the encoder and encoder + rnn cells (default: tf.float32). + scope: VariableScope for the created subgraph; defaults to + "embedding_rnn_seq2seq" + + Returns: + outputs: A list of the same length as decoder_inputs of 2D Tensors with + shape [batch_size x num_decoder_symbols] containing the generated outputs. + states: The state of each decoder cell in each time-step. This is a list + with length len(decoder_inputs) -- one item for each time-step. + Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + """ + with vs.variable_scope(scope or "embedding_rnn_seq2seq"): + # Encoder. + encoder_cell = rnn_cell.EmbeddingWrapper(cell, num_encoder_symbols) + _, encoder_states = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype) + + # Decoder. + if output_projection is None: + cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) + + if isinstance(feed_previous, bool): + return embedding_rnn_decoder(decoder_inputs, encoder_states[-1], cell, + num_decoder_symbols, output_projection, + feed_previous) + else: # If feed_previous is a Tensor, we construct 2 graphs and use cond. + outputs1, states1 = embedding_rnn_decoder( + decoder_inputs, encoder_states[-1], cell, num_decoder_symbols, + output_projection, True) + vs.get_variable_scope().reuse_variables() + outputs2, states2 = embedding_rnn_decoder( + decoder_inputs, encoder_states[-1], cell, num_decoder_symbols, + output_projection, False) + + outputs = control_flow_ops.cond(feed_previous, + lambda: outputs1, lambda: outputs2) + states = control_flow_ops.cond(feed_previous, + lambda: states1, lambda: states2) + return outputs, states + + +def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, + num_symbols, output_projection=None, + feed_previous=False, dtype=dtypes.float32, + scope=None): + """Embedding RNN sequence-to-sequence model with tied (shared) parameters. + + This model first embeds encoder_inputs by a newly created embedding (of shape + [num_symbols x cell.input_size]). Then it runs an RNN to encode embedded + encoder_inputs into a state vector. Next, it embeds decoder_inputs using + the same embedding. Then it runs RNN decoder, initialized with the last + encoder state, on embedded decoder_inputs. + + Args: + encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + cell: rnn_cell.RNNCell defining the cell function and size. + num_symbols: integer; number of symbols for both encoder and decoder. + output_projection: None or a pair (W, B) of output projection weights and + biases; W has shape [cell.output_size x num_symbols] and B has + shape [num_symbols]; if provided and feed_previous=True, each + fed previous output will first be multiplied by W and added B. + feed_previous: Boolean or scalar Boolean Tensor; if True, only the first + of decoder_inputs will be used (the "GO" symbol), and all other decoder + inputs will be taken from previous outputs (as in embedding_rnn_decoder). + If False, decoder_inputs are used as given (the standard decoder case). + dtype: The dtype to use for the initial RNN states (default: tf.float32). + scope: VariableScope for the created subgraph; defaults to + "embedding_tied_rnn_seq2seq". + + Returns: + outputs: A list of the same length as decoder_inputs of 2D Tensors with + shape [batch_size x num_decoder_symbols] containing the generated outputs. + states: The state of each decoder cell in each time-step. This is a list + with length len(decoder_inputs) -- one item for each time-step. + Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + + Raises: + ValueError: when output_projection has the wrong shape. + """ + if output_projection is not None: + proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype) + proj_weights.get_shape().assert_is_compatible_with([cell.output_size, + num_symbols]) + proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype) + proj_biases.get_shape().assert_is_compatible_with([num_symbols]) + + with vs.variable_scope(scope or "embedding_tied_rnn_seq2seq"): + with ops.device("/cpu:0"): + embedding = vs.get_variable("embedding", [num_symbols, cell.input_size]) + + emb_encoder_inputs = [embedding_ops.embedding_lookup(embedding, x) + for x in encoder_inputs] + emb_decoder_inputs = [embedding_ops.embedding_lookup(embedding, x) + for x in decoder_inputs] + + def extract_argmax_and_embed(prev, _): + """Loop_function that extracts the symbol from prev and embeds it.""" + if output_projection is not None: + prev = nn_ops.xw_plus_b( + prev, output_projection[0], output_projection[1]) + prev_symbol = array_ops.stop_gradient(math_ops.argmax(prev, 1)) + return embedding_ops.embedding_lookup(embedding, prev_symbol) + + if output_projection is None: + cell = rnn_cell.OutputProjectionWrapper(cell, num_symbols) + + if isinstance(feed_previous, bool): + loop_function = extract_argmax_and_embed if feed_previous else None + return tied_rnn_seq2seq(emb_encoder_inputs, emb_decoder_inputs, cell, + loop_function=loop_function, dtype=dtype) + else: # If feed_previous is a Tensor, we construct 2 graphs and use cond. + outputs1, states1 = tied_rnn_seq2seq( + emb_encoder_inputs, emb_decoder_inputs, cell, + loop_function=extract_argmax_and_embed, dtype=dtype) + vs.get_variable_scope().reuse_variables() + outputs2, states2 = tied_rnn_seq2seq( + emb_encoder_inputs, emb_decoder_inputs, cell, dtype=dtype) + + outputs = control_flow_ops.cond(feed_previous, + lambda: outputs1, lambda: outputs2) + states = control_flow_ops.cond(feed_previous, + lambda: states1, lambda: states2) + return outputs, states + + +def attention_decoder(decoder_inputs, initial_state, attention_states, cell, + output_size=None, num_heads=1, loop_function=None, + dtype=dtypes.float32, scope=None): + """RNN decoder with attention for the sequence-to-sequence model. + + Args: + decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + initial_state: 2D Tensor [batch_size x cell.state_size]. + attention_states: 3D Tensor [batch_size x attn_length x attn_size]. + cell: rnn_cell.RNNCell defining the cell function and size. + output_size: size of the output vectors; if None, we use cell.output_size. + num_heads: number of attention heads that read from attention_states. + loop_function: if not None, this function will be applied to i-th output + in order to generate i+1-th input, and decoder_inputs will be ignored, + except for the first element ("GO" symbol). This can be used for decoding, + but also for training to emulate http://arxiv.org/pdf/1506.03099v2.pdf. + Signature -- loop_function(prev, i) = next + * prev is a 2D Tensor of shape [batch_size x cell.output_size], + * i is an integer, the step number (when advanced control is needed), + * next is a 2D Tensor of shape [batch_size x cell.input_size]. + dtype: The dtype to use for the RNN initial state (default: tf.float32). + scope: VariableScope for the created subgraph; default: "attention_decoder". + + Returns: + outputs: A list of the same length as decoder_inputs of 2D Tensors of shape + [batch_size x output_size]. These represent the generated outputs. + Output i is computed from input i (which is either i-th decoder_inputs or + loop_function(output {i-1}, i)) as follows. First, we run the cell + on a combination of the input and previous attention masks: + cell_output, new_state = cell(linear(input, prev_attn), prev_state). + Then, we calculate new attention masks: + new_attn = softmax(V^T * tanh(W * attention_states + U * new_state)) + and then we calculate the output: + output = linear(cell_output, new_attn). + states: The state of each decoder cell in each time-step. This is a list + with length len(decoder_inputs) -- one item for each time-step. + Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + + Raises: + ValueError: when num_heads is not positive, there are no inputs, or shapes + of attention_states are not set. + """ + if not decoder_inputs: + raise ValueError("Must provide at least 1 input to attention decoder.") + if num_heads < 1: + raise ValueError("With less than 1 heads, use a non-attention decoder.") + if not attention_states.get_shape()[1:2].is_fully_defined(): + raise ValueError("Shape[1] and [2] of attention_states must be known: %s" + % attention_states.get_shape()) + if output_size is None: + output_size = cell.output_size + + with vs.variable_scope(scope or "attention_decoder"): + batch_size = array_ops.shape(decoder_inputs[0])[0] # Needed for reshaping. + attn_length = attention_states.get_shape()[1].value + attn_size = attention_states.get_shape()[2].value + + # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before. + hidden = array_ops.reshape( + attention_states, [-1, attn_length, 1, attn_size]) + hidden_features = [] + v = [] + attention_vec_size = attn_size # Size of query vectors for attention. + for a in xrange(num_heads): + k = vs.get_variable("AttnW_%d" % a, [1, 1, attn_size, attention_vec_size]) + hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME")) + v.append(vs.get_variable("AttnV_%d" % a, [attention_vec_size])) + + states = [initial_state] + + def attention(query): + """Put attention masks on hidden using hidden_features and query.""" + ds = [] # Results of attention reads will be stored here. + for a in xrange(num_heads): + with vs.variable_scope("Attention_%d" % a): + y = rnn_cell.linear(query, attention_vec_size, True) + y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size]) + # Attention mask is a softmax of v^T * tanh(...). + s = math_ops.reduce_sum( + v[a] * math_ops.tanh(hidden_features[a] + y), [2, 3]) + a = nn_ops.softmax(s) + # Now calculate the attention-weighted vector d. + d = math_ops.reduce_sum( + array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, + [1, 2]) + ds.append(array_ops.reshape(d, [-1, attn_size])) + return ds + + outputs = [] + prev = None + batch_attn_size = array_ops.pack([batch_size, attn_size]) + attns = [array_ops.zeros(batch_attn_size, dtype=dtype) + for _ in xrange(num_heads)] + for a in attns: # Ensure the second shape of attention vectors is set. + a.set_shape([None, attn_size]) + for i in xrange(len(decoder_inputs)): + if i > 0: + vs.get_variable_scope().reuse_variables() + inp = decoder_inputs[i] + # If loop_function is set, we use it instead of decoder_inputs. + if loop_function is not None and prev is not None: + with vs.variable_scope("loop_function", reuse=True): + inp = array_ops.stop_gradient(loop_function(prev, i)) + # Merge input and previous attentions into one vector of the right size. + x = rnn_cell.linear([inp] + attns, cell.input_size, True) + # Run the RNN. + cell_output, new_state = cell(x, states[-1]) + states.append(new_state) + # Run the attention mechanism. + attns = attention(new_state) + with vs.variable_scope("AttnOutputProjection"): + output = rnn_cell.linear([cell_output] + attns, output_size, True) + if loop_function is not None: + # We do not propagate gradients over the loop function. + prev = array_ops.stop_gradient(output) + outputs.append(output) + + return outputs, states + + +def embedding_attention_decoder(decoder_inputs, initial_state, attention_states, + cell, num_symbols, num_heads=1, + output_size=None, output_projection=None, + feed_previous=False, dtype=dtypes.float32, + scope=None): + """RNN decoder with embedding and attention and a pure-decoding option. + + Args: + decoder_inputs: a list of 1D batch-sized int32-Tensors (decoder inputs). + initial_state: 2D Tensor [batch_size x cell.state_size]. + attention_states: 3D Tensor [batch_size x attn_length x attn_size]. + cell: rnn_cell.RNNCell defining the cell function. + num_symbols: integer, how many symbols come into the embedding. + num_heads: number of attention heads that read from attention_states. + output_size: size of the output vectors; if None, use cell.output_size. + output_projection: None or a pair (W, B) of output projection weights and + biases; W has shape [output_size x num_symbols] and B has shape + [num_symbols]; if provided and feed_previous=True, each fed previous + output will first be multiplied by W and added B. + feed_previous: Boolean; if True, only the first of decoder_inputs will be + used (the "GO" symbol), and all other decoder inputs will be generated by: + next = embedding_lookup(embedding, argmax(previous_output)), + In effect, this implements a greedy decoder. It can also be used + during training to emulate http://arxiv.org/pdf/1506.03099v2.pdf. + If False, decoder_inputs are used as given (the standard decoder case). + dtype: The dtype to use for the RNN initial states (default: tf.float32). + scope: VariableScope for the created subgraph; defaults to + "embedding_attention_decoder". + + Returns: + outputs: A list of the same length as decoder_inputs of 2D Tensors with + shape [batch_size x output_size] containing the generated outputs. + states: The state of each decoder cell in each time-step. This is a list + with length len(decoder_inputs) -- one item for each time-step. + Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + + Raises: + ValueError: when output_projection has the wrong shape. + """ + if output_size is None: + output_size = cell.output_size + if output_projection is not None: + proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype) + proj_weights.get_shape().assert_is_compatible_with([cell.output_size, + num_symbols]) + proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype) + proj_biases.get_shape().assert_is_compatible_with([num_symbols]) + + with vs.variable_scope(scope or "embedding_attention_decoder"): + with ops.device("/cpu:0"): + embedding = vs.get_variable("embedding", [num_symbols, cell.input_size]) + + def extract_argmax_and_embed(prev, _): + """Loop_function that extracts the symbol from prev and embeds it.""" + if output_projection is not None: + prev = nn_ops.xw_plus_b( + prev, output_projection[0], output_projection[1]) + prev_symbol = array_ops.stop_gradient(math_ops.argmax(prev, 1)) + emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol) + return emb_prev + + loop_function = None + if feed_previous: + loop_function = extract_argmax_and_embed + + emb_inp = [ + embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs] + return attention_decoder( + emb_inp, initial_state, attention_states, cell, output_size=output_size, + num_heads=num_heads, loop_function=loop_function) + + +def embedding_attention_seq2seq(encoder_inputs, decoder_inputs, cell, + num_encoder_symbols, num_decoder_symbols, + num_heads=1, output_projection=None, + feed_previous=False, dtype=dtypes.float32, + scope=None): + """Embedding sequence-to-sequence model with attention. + + This model first embeds encoder_inputs by a newly created embedding (of shape + [num_encoder_symbols x cell.input_size]). Then it runs an RNN to encode + embedded encoder_inputs into a state vector. It keeps the outputs of this + RNN at every step to use for attention later. Next, it embeds decoder_inputs + by another newly created embedding (of shape [num_decoder_symbols x + cell.input_size]). Then it runs attention decoder, initialized with the last + encoder state, on embedded decoder_inputs and attending to encoder outputs. + + Args: + encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + cell: rnn_cell.RNNCell defining the cell function and size. + num_encoder_symbols: integer; number of symbols on the encoder side. + num_decoder_symbols: integer; number of symbols on the decoder side. + num_heads: number of attention heads that read from attention_states. + output_projection: None or a pair (W, B) of output projection weights and + biases; W has shape [cell.output_size x num_decoder_symbols] and B has + shape [num_decoder_symbols]; if provided and feed_previous=True, each + fed previous output will first be multiplied by W and added B. + feed_previous: Boolean or scalar Boolean Tensor; if True, only the first + of decoder_inputs will be used (the "GO" symbol), and all other decoder + inputs will be taken from previous outputs (as in embedding_rnn_decoder). + If False, decoder_inputs are used as given (the standard decoder case). + dtype: The dtype of the initial RNN state (default: tf.float32). + scope: VariableScope for the created subgraph; defaults to + "embedding_attention_seq2seq". + + Returns: + outputs: A list of the same length as decoder_inputs of 2D Tensors with + shape [batch_size x num_decoder_symbols] containing the generated outputs. + states: The state of each decoder cell in each time-step. This is a list + with length len(decoder_inputs) -- one item for each time-step. + Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + """ + with vs.variable_scope(scope or "embedding_attention_seq2seq"): + # Encoder. + encoder_cell = rnn_cell.EmbeddingWrapper(cell, num_encoder_symbols) + encoder_outputs, encoder_states = rnn.rnn( + encoder_cell, encoder_inputs, dtype=dtype) + + # First calculate a concatenation of encoder outputs to put attention on. + top_states = [array_ops.reshape(e, [-1, 1, cell.output_size]) + for e in encoder_outputs] + attention_states = array_ops.concat(1, top_states) + + # Decoder. + output_size = None + if output_projection is None: + cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) + output_size = num_decoder_symbols + + if isinstance(feed_previous, bool): + return embedding_attention_decoder( + decoder_inputs, encoder_states[-1], attention_states, cell, + num_decoder_symbols, num_heads, output_size, output_projection, + feed_previous) + else: # If feed_previous is a Tensor, we construct 2 graphs and use cond. + outputs1, states1 = embedding_attention_decoder( + decoder_inputs, encoder_states[-1], attention_states, cell, + num_decoder_symbols, num_heads, output_size, output_projection, True) + vs.get_variable_scope().reuse_variables() + outputs2, states2 = embedding_attention_decoder( + decoder_inputs, encoder_states[-1], attention_states, cell, + num_decoder_symbols, num_heads, output_size, output_projection, False) + + outputs = control_flow_ops.cond(feed_previous, + lambda: outputs1, lambda: outputs2) + states = control_flow_ops.cond(feed_previous, + lambda: states1, lambda: states2) + return outputs, states + + +def sequence_loss_by_example(logits, targets, weights, num_decoder_symbols, + average_across_timesteps=True, + softmax_loss_function=None, name=None): + """Weighted cross-entropy loss for a sequence of logits (per example). + + Args: + logits: list of 2D Tensors of shape [batch_size x num_decoder_symbols]. + targets: list of 1D batch-sized int32-Tensors of the same length as logits. + weights: list of 1D batch-sized float-Tensors of the same length as logits. + num_decoder_symbols: integer, number of decoder symbols (output classes). + average_across_timesteps: If set, divide the returned cost by the total + label weight. + softmax_loss_function: function (inputs-batch, labels-batch) -> loss-batch + to be used instead of the standard softmax (the default if this is None). + name: optional name for this operation, default: "sequence_loss_by_example". + + Returns: + 1D batch-sized float Tensor: the log-perplexity for each sequence. + + Raises: + ValueError: if len(logits) is different from len(targets) or len(weights). + """ + if len(targets) != len(logits) or len(weights) != len(logits): + raise ValueError("Lengths of logits, weights, and targets must be the same " + "%d, %d, %d." % (len(logits), len(weights), len(targets))) + with ops.op_scope(logits + targets + weights, name, + "sequence_loss_by_example"): + batch_size = array_ops.shape(targets[0])[0] + log_perp_list = [] + length = batch_size * num_decoder_symbols + for i in xrange(len(logits)): + if softmax_loss_function is None: + # TODO(lukaszkaiser): There is no SparseCrossEntropy in TensorFlow, so + # we need to first cast targets into a dense representation, and as + # SparseToDense does not accept batched inputs, we need to do this by + # re-indexing and re-sizing. When TensorFlow adds SparseCrossEntropy, + # rewrite this method. + indices = targets[i] + num_decoder_symbols * math_ops.range(batch_size) + with ops.device("/cpu:0"): # Sparse-to-dense must be on CPU for now. + dense = sparse_ops.sparse_to_dense( + indices, array_ops.expand_dims(length, 0), 1.0, + 0.0) + target = array_ops.reshape(dense, [-1, num_decoder_symbols]) + crossent = nn_ops.softmax_cross_entropy_with_logits( + logits[i], target, name="SequenceLoss/CrossEntropy{0}".format(i)) + else: + crossent = softmax_loss_function(logits[i], targets[i]) + log_perp_list.append(crossent * weights[i]) + log_perps = math_ops.add_n(log_perp_list) + if average_across_timesteps: + total_size = math_ops.add_n(weights) + total_size += 1e-12 # Just to avoid division by 0 for all-0 weights. + log_perps /= total_size + return log_perps + + +def sequence_loss(logits, targets, weights, num_decoder_symbols, + average_across_timesteps=True, average_across_batch=True, + softmax_loss_function=None, name=None): + """Weighted cross-entropy loss for a sequence of logits, batch-collapsed. + + Args: + logits: list of 2D Tensors os shape [batch_size x num_decoder_symbols]. + targets: list of 1D batch-sized int32-Tensors of the same length as logits. + weights: list of 1D batch-sized float-Tensors of the same length as logits. + num_decoder_symbols: integer, number of decoder symbols (output classes). + average_across_timesteps: If set, divide the returned cost by the total + label weight. + average_across_batch: If set, divide the returned cost by the batch size. + softmax_loss_function: function (inputs-batch, labels-batch) -> loss-batch + to be used instead of the standard softmax (the default if this is None). + name: optional name for this operation, defaults to "sequence_loss". + + Returns: + A scalar float Tensor: the average log-perplexity per symbol (weighted). + + Raises: + ValueError: if len(logits) is different from len(targets) or len(weights). + """ + with ops.op_scope(logits + targets + weights, name, "sequence_loss"): + cost = math_ops.reduce_sum(sequence_loss_by_example( + logits, targets, weights, num_decoder_symbols, + average_across_timesteps=average_across_timesteps, + softmax_loss_function=softmax_loss_function)) + if average_across_batch: + batch_size = array_ops.shape(targets[0])[0] + return cost / math_ops.cast(batch_size, dtypes.float32) + else: + return cost + + +def model_with_buckets(encoder_inputs, decoder_inputs, targets, weights, + buckets, num_decoder_symbols, seq2seq, + softmax_loss_function=None, name=None): + """Create a sequence-to-sequence model with support for bucketing. + + The seq2seq argument is a function that defines a sequence-to-sequence model, + e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(x, y, rnn_cell.GRUCell(24)) + + Args: + encoder_inputs: a list of Tensors to feed the encoder; first seq2seq input. + decoder_inputs: a list of Tensors to feed the decoder; second seq2seq input. + targets: a list of 1D batch-sized int32-Tensors (desired output sequence). + weights: list of 1D batch-sized float-Tensors to weight the targets. + buckets: a list of pairs of (input size, output size) for each bucket. + num_decoder_symbols: integer, number of decoder symbols (output classes). + seq2seq: a sequence-to-sequence model function; it takes 2 input that + agree with encoder_inputs and decoder_inputs, and returns a pair + consisting of outputs and states (as, e.g., basic_rnn_seq2seq). + softmax_loss_function: function (inputs-batch, labels-batch) -> loss-batch + to be used instead of the standard softmax (the default if this is None). + name: optional name for this operation, defaults to "model_with_buckets". + + Returns: + outputs: The outputs for each bucket. Its j'th element consists of a list + of 2D Tensors of shape [batch_size x num_decoder_symbols] (j'th outputs). + losses: List of scalar Tensors, representing losses for each bucket. + Raises: + ValueError: if length of encoder_inputsut, targets, or weights is smaller + than the largest (last) bucket. + """ + if len(encoder_inputs) < buckets[-1][0]: + raise ValueError("Length of encoder_inputs (%d) must be at least that of la" + "st bucket (%d)." % (len(encoder_inputs), buckets[-1][0])) + if len(targets) < buckets[-1][1]: + raise ValueError("Length of targets (%d) must be at least that of last" + "bucket (%d)." % (len(targets), buckets[-1][1])) + if len(weights) < buckets[-1][1]: + raise ValueError("Length of weights (%d) must be at least that of last" + "bucket (%d)." % (len(weights), buckets[-1][1])) + + all_inputs = encoder_inputs + decoder_inputs + targets + weights + losses = [] + outputs = [] + with ops.op_scope(all_inputs, name, "model_with_buckets"): + for j in xrange(len(buckets)): + if j > 0: + vs.get_variable_scope().reuse_variables() + bucket_encoder_inputs = [encoder_inputs[i] + for i in xrange(buckets[j][0])] + bucket_decoder_inputs = [decoder_inputs[i] + for i in xrange(buckets[j][1])] + bucket_outputs, _ = seq2seq(bucket_encoder_inputs, + bucket_decoder_inputs) + outputs.append(bucket_outputs) + + bucket_targets = [targets[i] for i in xrange(buckets[j][1])] + bucket_weights = [weights[i] for i in xrange(buckets[j][1])] + losses.append(sequence_loss( + outputs[-1], bucket_targets, bucket_weights, num_decoder_symbols, + softmax_loss_function=softmax_loss_function)) + + return outputs, losses diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index fc37ac5ceb2..3840971d76a 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -139,6 +139,8 @@ class Variable(object): @@graph @@op """ + # TODO(touts): Add @@value and @@ref in the docstring above once they are + # ready for consumption. def __init__(self, initial_value, trainable=True, collections=None, validate_shape=True, name=None): @@ -199,6 +201,7 @@ class Variable(object): with ops.device(self._variable.device): self._initializer_op = state_ops.assign( self._variable, self._initial_value, validate_shape=False).op + self._snapshot = array_ops.identity(self._variable, name="read") else: self._variable = state_ops.variable_op( self._initial_value.get_shape(), @@ -207,6 +210,7 @@ class Variable(object): with ops.device(self._variable.device): self._initializer_op = state_ops.assign( self._variable, self._initial_value).op + self._snapshot = array_ops.identity(self._variable, name="read") for key in collections: ops.add_to_collection(key, self) self._save_slice_info = None @@ -216,7 +220,50 @@ class Variable(object): return self._variable def _AsTensor(self): - """Conversion function for ops.convert_to_tensor().""" + """Converts this variable to a Tensor. + + See [`value()`](#Variable.value). + + Returns: + A `Tensor` containing the value of the variable. + """ + return self._snapshot + + def value(self): + """Returns the last snapshot of this variable. + + You usually do not need to call this method as all ops that need the value + of the variable call it automatically through a `convert_to_tensor()` call. + + Returns a `Tensor` which holds the value of the variable. You can not + assign a new value to this tensor as it is not a reference to the variable. + See [`ref()`](#Variable.ref) if you want to get a reference to the + variable. + + To avoid copies, if the consumer of the returned value is on the same device + as the variable, this actually returns the live value of the variable, not + a copy. Updates to the variable are seen by the consumer. If the consumer + is on a different device it will get a copy of the variable. + + Returns: + A `Tensor` containing the value of the variable. + """ + return self._snapshot + + def ref(self): + """Returns a reference to this variable. + + You usually do not need to call this method as all ops that need a reference + to the variable call it automatically. + + Returns is a `Tensor` which holds a reference to the variable. You can + assign a new value to the variable by passing the tensor to an assign op. + See [`value()`](#Variable.value) if you want to get the value of the + variable. + + Returns: + A `Tensor` that is a reference to the variable. + """ return self._variable def eval(self, session=None): @@ -366,15 +413,17 @@ class Variable(object): # Conversion to tensor. @staticmethod - def _TensorConversionFunction(v, dtype=None, name=None): + def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): """Utility function for converting a Variable to a Tensor.""" _ = name - ret = v._AsTensor() # pylint: disable=protected-access if dtype and not dtype.is_compatible_with(v.dtype): raise ValueError( "Incompatible type conversion requested to type '%s' for variable " "of type '%s'" % (dtype.name, v.dtype.name)) - return ret + if as_ref: + return v.ref() + else: + return v.value() # Operator overloading. # diff --git a/tensorflow/python/platform/default/_flags.py b/tensorflow/python/platform/default/_flags.py index fcf78fde989..d7ae189c21c 100644 --- a/tensorflow/python/platform/default/_flags.py +++ b/tensorflow/python/platform/default/_flags.py @@ -94,7 +94,15 @@ def DEFINE_boolean(flag_name, default_value, docstring): default_value: The default value the flag should take as a boolean. docstring: A helpful message explaining the use of the flag. """ - _define_helper(flag_name, default_value, docstring, bool) + # Register a custom function for 'bool' so --flag=True works. + def str2bool(v): + return v.lower() in ('true', 't', '1') + _global_parser.add_argument('--' + flag_name, + nargs='?', + const=True, + help=docstring, + default=default_value, + type=str2bool) _global_parser.add_argument('--no' + flag_name, action='store_false', dest=flag_name) diff --git a/tensorflow/python/platform/default/_gfile.py b/tensorflow/python/platform/default/_gfile.py index 4ee28ca0123..44a09f0d9c3 100644 --- a/tensorflow/python/platform/default/_gfile.py +++ b/tensorflow/python/platform/default/_gfile.py @@ -358,3 +358,25 @@ def ListDirectory(directory, return_dotfiles=False): # pylint: disable=invalid- if not return_dotfiles: files = [f for f in files if not f.startswith('.')] return files + + +def Walk(top, topdown=1, onerror=None): + """Recursive directory tree generator. + + Args: + top: string, a pathname. + topdown: bool, should traversal be pre-order (True) or post-order (False) + onerror: function, optional callback for errors. + + By default, errors that occur when listing a directory are ignored. + (This is the same semantics as Python's os.walk() generator.) If the + optional argument "onerror" is specified, it should be a function. It + will be called with one argument, an os.error instance. It can return + to continue with the walk, or reraise the exception to abort the walk. + + Yields: + # Each yield is a 3-tuple: the pathname of a directory, followed + # by lists of all its subdirectories and leaf files. + (dirname, [subdirname, subdirname, ...], [filename, filename, ...]) + """ + return os.walk(top, topdown=topdown, onerror=onerror) diff --git a/tensorflow/python/platform/default/flags_test.py b/tensorflow/python/platform/default/flags_test.py index be32bb63bd9..3868576c2f8 100644 --- a/tensorflow/python/platform/default/flags_test.py +++ b/tensorflow/python/platform/default/flags_test.py @@ -26,10 +26,16 @@ from tensorflow.python.platform.default import _flags as flags flags.DEFINE_string("string_foo", "default_val", "HelpString") -flags.DEFINE_boolean("bool_foo", True, "HelpString") flags.DEFINE_integer("int_foo", 42, "HelpString") flags.DEFINE_float("float_foo", 42.0, "HelpString") +flags.DEFINE_boolean("bool_foo", True, "HelpString") +flags.DEFINE_boolean("bool_negation", True, "HelpString") +flags.DEFINE_boolean("bool_a", False, "HelpString") +flags.DEFINE_boolean("bool_c", False, "HelpString") +flags.DEFINE_boolean("bool_d", True, "HelpString") +flags.DEFINE_boolean("bool_e", True, "HelpString") + FLAGS = flags.FLAGS class FlagsTest(googletest.TestCase): @@ -46,14 +52,23 @@ class FlagsTest(googletest.TestCase): FLAGS.bool_foo = False self.assertFalse(FLAGS.bool_foo) - def testNoBool(self): - FLAGS.bool_foo = True - try: - sys.argv.append("--nobool_foo") - FLAGS._parse_flags() - self.assertFalse(FLAGS.bool_foo) - finally: - sys.argv.pop() + def testBoolCommandLines(self): + # Specified on command line with no args, sets to True, + # even if default is False. + self.assertEqual(True, FLAGS.bool_a) + + # --no before the flag forces it to False, even if the + # default is True + self.assertEqual(False, FLAGS.bool_negation) + + # --bool_flag=True sets to True + self.assertEqual(True, FLAGS.bool_c) + + # --bool_flag=False sets to False + self.assertEqual(False, FLAGS.bool_d) + + # --bool_flag=gibberish sets to False + self.assertEqual(False, FLAGS.bool_e) def testInt(self): res = FLAGS.int_foo @@ -69,4 +84,12 @@ class FlagsTest(googletest.TestCase): if __name__ == "__main__": - googletest.main() + # Test command lines + sys.argv.extend(["--bool_a", "--nobool_negation", "--bool_c=True", + "--bool_d=False", "--bool_e=gibberish"]) + + # googletest.main() tries to interpret the above flags, so use the + # direct functions instead. + runner = googletest.TextTestRunner() + itersuite = googletest.TestLoader().loadTestsFromTestCase(FlagsTest) + runner.run(itersuite) diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py index f985092b66e..209f730c8db 100644 --- a/tensorflow/python/platform/test.py +++ b/tensorflow/python/platform/test.py @@ -17,9 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +# pylint: disable=unused-import from tensorflow.python.platform.googletest import GetTempDir from tensorflow.python.platform.googletest import main from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase from tensorflow.python.framework.test_util import IsGoogleCudaEnabled as IsBuiltWithCuda +from tensorflow.python.kernel_tests.gradient_checker import compute_gradient_error +from tensorflow.python.kernel_tests.gradient_checker import compute_gradient + get_temp_dir = GetTempDir +# pylint: enable=unused-import diff --git a/tensorflow/python/summary/event_multiplexer.py b/tensorflow/python/summary/event_multiplexer.py index a3ce42afb1e..46ee7523876 100644 --- a/tensorflow/python/summary/event_multiplexer.py +++ b/tensorflow/python/summary/event_multiplexer.py @@ -143,13 +143,15 @@ class EventMultiplexer(object): return self def AddRunsFromDirectory(self, path, name=None): - """Load runs from a directory, assuming each subdirectory is a run. + """Load runs from a directory; recursively walks subdirectories. If path doesn't exist, no-op. This ensures that it is safe to call `AddRunsFromDirectory` multiple times, even before the directory is made. - If the directory contains TensorFlow event files, it is itself treated as a - run. + If path is a directory, load event files in the directory (if any exist) and + recursively call AddRunsFromDirectory on any subdirectories. This mean you + can call AddRunsFromDirectory at the root of a tree of event logs and + TensorBoard will load them all. If the `EventMultiplexer` is already loaded or autoupdating, this will cause the newly created accumulators to also `Reload()` or `AutoUpdate()`. @@ -171,25 +173,16 @@ class EventMultiplexer(object): if not gfile.Exists(path): return # Maybe it hasn't been created yet, fail silently to retry later if not gfile.IsDirectory(path): - raise ValueError('Path exists and is not a directory, %s' % path) - paths = gfile.ListDirectory(path) - is_directory = lambda x: gfile.IsDirectory(os.path.join(path, x)) - subdirectories = filter(is_directory, paths) - for s in subdirectories: - if name: - subname = '/'.join([name, s]) - else: - subname = s - self.AddRun(os.path.join(path, s), subname) + raise ValueError('AddRunsFromDirectory: path exists and is not a ' + 'directory, %s' % path) + + for (subdir, _, files) in gfile.Walk(path): + if list(filter(event_accumulator.IsTensorFlowEventsFile, files)): + logging.info('Adding events from directory %s', subdir) + rpath = os.path.relpath(subdir, path) + subname = os.path.join(name, rpath) if name else rpath + self.AddRun(subdir, name=subname) - if list(filter(event_accumulator.IsTensorFlowEventsFile, paths)): - directory_name = os.path.split(path)[1] - logging.info('Directory %s has event files; loading', directory_name) - if name: - dname = name - else: - dname = directory_name - self.AddRun(path, dname) return self def Reload(self): diff --git a/tensorflow/python/summary/event_multiplexer_test.py b/tensorflow/python/summary/event_multiplexer_test.py index 01749a16f5d..e7cecba1447 100644 --- a/tensorflow/python/summary/event_multiplexer_test.py +++ b/tensorflow/python/summary/event_multiplexer_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import os +import os.path import tensorflow.python.platform @@ -28,6 +29,20 @@ from tensorflow.python.summary import event_accumulator from tensorflow.python.summary import event_multiplexer +def _AddEvents(path): + if not gfile.IsDirectory(path): + gfile.MakeDirs(path) + fpath = os.path.join(path, 'hypothetical.tfevents.out') + with gfile.GFile(fpath, 'w'): + return fpath + + +def _CreateCleanDirectory(path): + if gfile.IsDirectory(path): + gfile.DeleteRecursively(path) + gfile.MkDir(path) + + class _FakeAccumulator(object): def __init__(self, path): @@ -137,34 +152,33 @@ class EventMultiplexerTest(test_util.TensorFlowTestCase): x.AddRunsFromDirectory(fakedir) self.assertEqual(x.Runs(), {}, 'loading fakedir had no effect') - if gfile.IsDirectory(realdir): - gfile.DeleteRecursively(realdir) - gfile.MkDir(realdir) + _CreateCleanDirectory(realdir) x.AddRunsFromDirectory(realdir) self.assertEqual(x.Runs(), {}, 'loading empty directory had no effect') path1 = join(realdir, 'path1') gfile.MkDir(path1) x.AddRunsFromDirectory(realdir) - self.assertEqual(sorted(x.Runs().keys()), ['path1'], 'loaded run: path1') + self.assertEqual(x.Runs(), {}, 'creating empty subdirectory had no effect') + + _AddEvents(path1) + x.AddRunsFromDirectory(realdir) + self.assertItemsEqual(x.Runs(), ['path1'], 'loaded run: path1') loader1 = x._GetAccumulator('path1') self.assertEqual(loader1._path, path1, 'has the correct path') path2 = join(realdir, 'path2') - gfile.MkDir(path2) + _AddEvents(path2) x.AddRunsFromDirectory(realdir) - self.assertItemsEqual(sorted(x.Runs().keys()), ['path1', 'path2']) + self.assertItemsEqual(x.Runs(), ['path1', 'path2']) self.assertEqual(x._GetAccumulator('path1'), loader1, 'loader1 not regenerated') - loader2 = x._GetAccumulator('path2') path2_2 = join(path2, 'path2') - gfile.MkDir(path2_2) - x.AddRunsFromDirectory(path2) - self.assertItemsEqual(sorted(x.Runs().keys()), ['path1', 'path2']) - self.assertNotEqual(loader2, x._GetAccumulator('path2'), - 'loader2 regenerated') - self.assertEqual(x._GetAccumulator('path2')._path, path2_2, + _AddEvents(path2_2) + x.AddRunsFromDirectory(realdir) + self.assertItemsEqual(x.Runs(), ['path1', 'path2', 'path2/path2']) + self.assertEqual(x._GetAccumulator('path2/path2')._path, path2_2, 'loader2 path correct') def testAddRunsFromDirectoryThatContainsEvents(self): @@ -173,21 +187,18 @@ class EventMultiplexerTest(test_util.TensorFlowTestCase): join = os.path.join realdir = join(tmpdir, 'event_containing_directory') - if gfile.IsDirectory(realdir): - gfile.DeleteRecursively(realdir) - gfile.MkDir(realdir) + _CreateCleanDirectory(realdir) self.assertEqual(x.Runs(), {}) - with gfile.GFile(join(realdir, 'hypothetical.tfevents.out'), 'w'): - pass + _AddEvents(realdir) x.AddRunsFromDirectory(realdir) - self.assertItemsEqual(x.Runs(), ['event_containing_directory']) + self.assertItemsEqual(x.Runs(), ['.']) subdir = join(realdir, 'subdir') - gfile.MkDir(subdir) + _AddEvents(subdir) x.AddRunsFromDirectory(realdir) - self.assertItemsEqual(x.Runs(), ['event_containing_directory', 'subdir']) + self.assertItemsEqual(x.Runs(), ['.', 'subdir']) def testAddRunsFromDirectoryWithRunNames(self): x = event_multiplexer.EventMultiplexer() @@ -195,30 +206,45 @@ class EventMultiplexerTest(test_util.TensorFlowTestCase): join = os.path.join realdir = join(tmpdir, 'event_containing_directory') - if gfile.IsDirectory(realdir): - gfile.DeleteRecursively(realdir) - gfile.MkDir(realdir) + _CreateCleanDirectory(realdir) self.assertEqual(x.Runs(), {}) - with gfile.GFile(join(realdir, 'hypothetical.tfevents.out'), 'w'): - pass + _AddEvents(realdir) x.AddRunsFromDirectory(realdir, 'foo') - self.assertItemsEqual(x.Runs(), ['foo']) + self.assertItemsEqual(x.Runs(), ['foo/.']) subdir = join(realdir, 'subdir') - gfile.MkDir(subdir) + _AddEvents(subdir) x.AddRunsFromDirectory(realdir, 'foo') - self.assertItemsEqual(x.Runs(), ['foo', 'foo/subdir']) + self.assertItemsEqual(x.Runs(), ['foo/.', 'foo/subdir']) + + def testAddRunsFromDirectoryWalksTree(self): + x = event_multiplexer.EventMultiplexer() + tmpdir = self.get_temp_dir() + join = os.path.join + realdir = join(tmpdir, 'event_containing_directory') + + _CreateCleanDirectory(realdir) + _AddEvents(realdir) + sub = join(realdir, 'subdirectory') + sub1 = join(sub, '1') + sub2 = join(sub, '2') + sub1_1 = join(sub1, '1') + _AddEvents(sub1) + _AddEvents(sub2) + _AddEvents(sub1_1) + x.AddRunsFromDirectory(realdir) + + self.assertItemsEqual(x.Runs(), ['.', + 'subdirectory/1', 'subdirectory/2', + 'subdirectory/1/1']) def testAddRunsFromDirectoryThrowsException(self): x = event_multiplexer.EventMultiplexer() tmpdir = self.get_temp_dir() - filepath = os.path.join(tmpdir, 'bad_file') - with gfile.GFile(filepath, 'w'): - pass - + filepath = _AddEvents(tmpdir) with self.assertRaises(ValueError): x.AddRunsFromDirectory(filepath) diff --git a/tensorflow/python/training/adagrad_test.py b/tensorflow/python/training/adagrad_test.py index d226d672abd..1057ec947e7 100644 --- a/tensorflow/python/training/adagrad_test.py +++ b/tensorflow/python/training/adagrad_test.py @@ -47,6 +47,28 @@ class AdagradOptimizerTest(tf.test.TestCase): self.assertAllClose(np.array([2.715679168701172, 3.715679168701172]), var1.eval()) + def testTensorLearningRate(self): + with self.test_session(): + var0 = tf.Variable([1.0, 2.0]) + var1 = tf.Variable([3.0, 4.0]) + grads0 = tf.constant([0.1, 0.1]) + grads1 = tf.constant([0.01, 0.01]) + ada_opt = tf.train.AdagradOptimizer( + tf.constant(3.0), initial_accumulator_value=0.1) + ada_update = ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + tf.initialize_all_variables().run() + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + # Run 3 steps of adagrad + for _ in range(3): + ada_update.run() + # Validate updated params + self.assertAllClose(np.array([-1.6026098728179932, -0.6026098728179932]), + var0.eval()) + self.assertAllClose(np.array([2.715679168701172, 3.715679168701172]), + var1.eval()) + def testFloat64(self): with self.test_session(): opt = tf.train.AdagradOptimizer(3.0, initial_accumulator_value=0.1) diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py index 41fa64e6d71..6729394083f 100644 --- a/tensorflow/python/training/adam.py +++ b/tensorflow/python/training/adam.py @@ -69,9 +69,9 @@ class AdamOptimizer(optimizer.Optimizer): beta1: A float value or a constant float tensor. The exponential decay rate for the 1st moment estimates. beta2: A float value or a constant float tensor. - The exponential decay rate for the 2st moment estimates. + The exponential decay rate for the 2nd moment estimates. epsilon: A small constant for numerical stability. - use_locking: If True use locks for update operation.s + use_locking: If True use locks for update operations. name: Optional name for the operations created when applying gradients. Defaults to "Adam". """ @@ -143,8 +143,8 @@ class AdamOptimizer(optimizer.Optimizer): use_locking=self._use_locking) v_sqrt = math_ops.sqrt(v_t) var_update = state_ops.assign_sub(var, - lr * m_t / (v_sqrt + self._epsilon_t), - use_locking=self._use_locking) + lr * m_t / (v_sqrt + self._epsilon_t), + use_locking=self._use_locking) return control_flow_ops.group(*[var_update, m_t, v_t]) def _finish(self, update_ops, name_scope): diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py index f9ea6c22f55..d6e18146912 100644 --- a/tensorflow/python/training/adam_test.py +++ b/tensorflow/python/training/adam_test.py @@ -115,6 +115,42 @@ class AdamOptimizerTest(tf.test.TestCase): self.assertAllClose(var0_np, var0.eval()) self.assertAllClose(var1_np, var1.eval()) + def testTensorLearningRate(self): + with self.test_session(): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=np.float32) + grads0_np = np.array([0.1, 0.1], dtype=np.float32) + var1_np = np.array([3.0, 4.0], dtype=np.float32) + grads1_np = np.array([0.01, 0.01], dtype=np.float32) + + var0 = tf.Variable(var0_np) + var1 = tf.Variable(var1_np) + grads0 = tf.constant(grads0_np) + grads1 = tf.constant(grads1_np) + opt = tf.train.AdamOptimizer(tf.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + tf.initialize_all_variables().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllClose(0.9 ** t, beta1_power.eval()) + self.assertAllClose(0.999 ** t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllClose(var0_np, var0.eval()) + self.assertAllClose(var1_np, var1.eval()) + def testFloat64(self): with self.test_session(): opt = tf.train.AdamOptimizer() diff --git a/tensorflow/python/training/gradient_descent_test.py b/tensorflow/python/training/gradient_descent_test.py index 68378cef9ec..dd4e391196c 100644 --- a/tensorflow/python/training/gradient_descent_test.py +++ b/tensorflow/python/training/gradient_descent_test.py @@ -44,6 +44,25 @@ class GradientDescentOptimizerTest(tf.test.TestCase): self.assertAllClose([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1], var0.eval()) self.assertAllClose([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], var1.eval()) + def testTensorLearningRate(self): + with self.test_session(): + var0 = tf.Variable([1.0, 2.0]) + var1 = tf.Variable([3.0, 4.0]) + grads0 = tf.constant([0.1, 0.1]) + grads1 = tf.constant([0.01, 0.01]) + lrate = tf.constant(3.0) + sgd_op = tf.train.GradientDescentOptimizer(lrate).apply_gradients( + zip([grads0, grads1], [var0, var1])) + tf.initialize_all_variables().run() + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + self.assertAllClose([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1], var0.eval()) + self.assertAllClose([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], var1.eval()) + def testFloat64(self): with self.test_session(): opt = tf.train.GradientDescentOptimizer(3.0) diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index e580e80eb27..53bca00756c 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import constant_op from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import io_ops +from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import summary_ops @@ -114,8 +115,21 @@ def string_input_producer(string_tensor, num_epochs=None, shuffle=True, Returns: A queue with the output strings. A `QueueRunner` for the Queue is added to the current `Graph`'s `QUEUE_RUNNER` collection. + + Raises: + ValueError: If the string_tensor is a null Python list. At runtime, + will fail with an assertion if string_tensor becomes a null tensor. """ + not_null_err = "string_input_producer requires a non-null input tensor" + if not string_tensor: + raise ValueError(not_null_err) + with ops.op_scope([string_tensor], name, "input_producer") as name: + string_tensor = ops.convert_to_tensor(string_tensor, dtype=dtypes.string) + with ops.control_dependencies([ + logging_ops.Assert(math_ops.greater(array_ops.size(string_tensor), 0), + [not_null_err])]): + string_tensor = array_ops.identity(string_tensor) return _input_producer( string_tensor, dtypes.string, num_epochs, shuffle, seed, capacity, name, "fraction_of_%d_full" % capacity) diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py index 80961abdc38..ab17a6be495 100644 --- a/tensorflow/python/training/input_test.py +++ b/tensorflow/python/training/input_test.py @@ -132,6 +132,28 @@ class StringInputProducerTest(tf.test.TestCase): for thread in threads: thread.join() + def testNullStringPython(self): + # Graph-construction time check for empty string list: + with self.test_session(): + with self.assertRaises(ValueError): + _ = tf.train.string_input_producer([]) + + def testNullString(self): + # Runtime check for empty string list. This is slightly oblique: + # The queue runner should die with an assertion error on the null + # input tensor, causing the dequeue to fail with an OutOfRangeError. + with self.test_session(): + coord = tf.train.Coordinator() + queue = tf.train.string_input_producer(tf.constant([], dtype=tf.string)) + dequeue = queue.dequeue() + tf.initialize_all_variables().run() + threads = tf.train.start_queue_runners(coord=coord) + with self.assertRaises(tf.errors.OutOfRangeError): + dequeue.eval() + coord.request_stop() + for thread in threads: + thread.join() + class RangeInputProducerTest(tf.test.TestCase): diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py index eee6f0300da..f7e1e3095c2 100644 --- a/tensorflow/python/training/momentum_test.py +++ b/tensorflow/python/training/momentum_test.py @@ -77,6 +77,57 @@ class MomentumOptimizerTest(tf.test.TestCase): 3.98 - ((0.9 * 0.01 + 0.01) * 2.0)]), var1.eval()) + def testTensorLearningRateAndMomentum(self): + with self.test_session(): + var0 = tf.Variable([1.0, 2.0]) + var1 = tf.Variable([3.0, 4.0]) + grads0 = tf.constant([0.1, 0.1]) + grads1 = tf.constant([0.01, 0.01]) + mom_opt = tf.train.MomentumOptimizer( + learning_rate=tf.constant(2.0), momentum=tf.constant(0.9)) + mom_update = mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + tf.initialize_all_variables().run() + # Check we have slots + self.assertEqual(["momentum"], mom_opt.get_slot_names()) + slot0 = mom_opt.get_slot(var0, "momentum") + self.assertEquals(slot0.get_shape(), var0.get_shape()) + self.assertFalse(slot0 in tf.trainable_variables()) + slot1 = mom_opt.get_slot(var1, "momentum") + self.assertEquals(slot1.get_shape(), var1.get_shape()) + self.assertFalse(slot1 in tf.trainable_variables()) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + # Step 1: the momentum accumulators where 0. So we should see a normal + # update: v -= grad * learning_rate + mom_update.run() + # Check that the momentum accumulators have been updated. + self.assertAllClose(np.array([0.1, 0.1]), slot0.eval()) + self.assertAllClose(np.array([0.01, 0.01]), slot1.eval()) + # Check that the parameters have been updated. + self.assertAllClose(np.array([1.0 - (0.1 * 2.0), + 2.0 - (0.1 * 2.0)]), + var0.eval()) + self.assertAllClose(np.array([3.0 - (0.01 * 2.0), + 4.0 - (0.01 * 2.0)]), + var1.eval()) + # Step 2: the momentum accumulators contain the previous update. + mom_update.run() + # Check that the momentum accumulators have been updated. + self.assertAllClose(np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), + slot0.eval()) + self.assertAllClose(np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), + slot1.eval()) + # Check that the parameters have been updated. + self.assertAllClose( + np.array([1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0), + 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)]), + var0.eval()) + self.assertAllClose(np.array([2.98 - ((0.9 * 0.01 + 0.01) * 2.0), + 3.98 - ((0.9 * 0.01 + 0.01) * 2.0)]), + var1.eval()) + def testFloat64(self): with self.test_session(): opt = tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9) diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index dc2f700f816..d9b6062cb71 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -225,6 +225,8 @@ class Optimizer(object): for var in var_list: if not isinstance(var, variables.Variable): raise TypeError("Argument is not a variables.Variable: %s" % var) + if not var_list: + raise ValueError("No variables to optimize") grads = gradients.gradients( loss, var_list, gate_gradients=(gate_gradients == Optimizer.GATE_OP), aggregation_method=aggregation_method) @@ -254,6 +256,7 @@ class Optimizer(object): Raises: TypeError: if `grads_and_vars` is malformed. + ValueError: if none of the variables have gradients. """ # This is a default implementation of apply_gradients() that can be shared # by most optimizers. It relies on the subclass implementing the following @@ -268,7 +271,11 @@ class Optimizer(object): "Variable must be a variables.Variable: %s" % v) if g is not None: self._assert_valid_dtypes([g, v]) - self._create_slots([v for g, v in grads_and_vars if g is not None]) + var_list = [v for g, v in grads_and_vars if g is not None] + if not var_list: + raise ValueError("No gradients provided for any variable: %s" % + grads_and_vars) + self._create_slots(var_list) update_ops = [] with ops.op_scope([], name, self._name) as name: self._prepare() diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py index b9f2b5fdef5..204d9e7d3b4 100644 --- a/tensorflow/python/training/optimizer_test.py +++ b/tensorflow/python/training/optimizer_test.py @@ -64,6 +64,26 @@ class OptimizerTest(tf.test.TestCase): self.assertAllClose([-14., -13.], var0.eval()) self.assertAllClose([-6., -5.], var1.eval()) + def testNoVariables(self): + with self.test_session(): + var0 = tf.Variable([1.0, 2.0], trainable=False) + var1 = tf.Variable([3.0, 4.0], trainable=False) + cost = 5 * var0 + var1 + sgd_op = tf.train.GradientDescentOptimizer(3.0) + with self.assertRaisesRegexp(ValueError, 'No variables'): + sgd_op.minimize(cost) + + def testNoGradients(self): + with self.test_session(): + var0 = tf.Variable([1.0, 2.0]) + var1 = tf.Variable([3.0, 4.0]) + cost = 5 * var0 + global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step') + sgd_op = tf.train.GradientDescentOptimizer(3.0) + with self.assertRaisesRegexp(ValueError, 'No gradients'): + # var1 has no gradient + sgd_op.minimize(cost, global_step, [var1]) + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 4b86a08609d..08250dc750b 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -284,7 +284,7 @@ class BaseSaverBuilder(object): else: names_to_variables[name] = [var] else: - var = ops.convert_to_tensor(var) + var = ops.convert_to_tensor(var, as_ref=True) if not self._IsVariable(var): raise TypeError("Variable to save is not a Variable: %s" % var) name = var.op.name @@ -341,7 +341,7 @@ class BaseSaverBuilder(object): # pylint: enable=protected-access else: # A variable or tensor. - variable = ops.convert_to_tensor(v) + variable = ops.convert_to_tensor(v, as_ref=True) if not self._IsVariable(variable): raise TypeError("names_to_variables must be a dict mapping string " "names to Tensors/Variables. Not a variable: %s" % diff --git a/tensorflow/tensorboard/README.md b/tensorflow/tensorboard/README.md index eb85a1e4610..99279fa9ed5 100644 --- a/tensorflow/tensorboard/README.md +++ b/tensorflow/tensorboard/README.md @@ -1,13 +1,20 @@ # TensorBoard TensorBoard is a suite of web applications for inspecting and understanding your -TensorFlow runs and graphs. +TensorFlow runs and graphs. Before running TensorBoard, make sure you have +generated summary data in a log directory by creating a `SummaryWriter`: -Example Usage: +```python +# sess.graph_def is the graph definition. +summary_writer = tf.train.SummaryWriter('/path/to/logs', sess.graph_def) +``` + +For more details, see [this tutorial](http://www.tensorflow.org/how_tos/summaries_and_tensorboard/index.html#serializing-the-data). +Then run TensorBoard and provide the log directory: ``` python tensorflow/tensorboard/tensorboard.py --logdir=path/to/logs -# if installed via pip +# or if installed via pip, run: tensorboard --logdir=path/to/logs # if building from source @@ -26,7 +33,14 @@ includes a frontend (app/tf-tensorboard.html) that contains html and javascript for displaying this data in a UI. -## Building the TensorBoard frontend +## TensorBoard Development Instructions + +The following instructions are useful if you want to develop the TensorBoard +frontend in a lightweight frontend-only environment. It sets up gulp with +automatic recompiling and serves just the frontend assets without a connected +backend. + +If you just want to use TensorBoard, there is no need to read any further. ### Install Node, npm, gulp, bower, and tsd in your machine Get nodejs and npm through whatever package distribution system is appropriate @@ -43,24 +57,11 @@ run the following commands. bower install tsd install -### Run Gulp Vulcanize +### Run Gulp -Inside this directory, run `gulp vulcanize`. That will compile all of the -html/js/css dependencies for TensorBoard into a monolithic index.html file under -dist/. Once you've done this, you can locally run your own TensorBoard instance -and it will have a working frontend. +Inside this directory, run `gulp`. That will compile all of the +html/js/css dependencies for TensorBoard, and also spin up a server +(by default at port 8000). You can navigate to component-specific demo pages to +check out their behavior. -### Frontend General Dev Instructions - -To speed up the development process, we can run the frontend code independently -of the backend, and mock out the backend with static JSON files. This allows -testing the frontend's correctness without needing to find real data and spin -up a real server. Look at app/demo/index.html for an example. - -The following gulp commands are useful: - -* `gulp test` - build, test, and lint the code -* `gulp watch` - build, test, and rebuild on change -* `gulp server` - start a livereload server on localhost:8000 -* `gulp` - alias for `gulp watch` -* `gulp vulcanize` - +Running `gulp test` will run all unit tests, the linter, etc. diff --git a/tensorflow/tensorboard/app/index.html b/tensorflow/tensorboard/app/index.html index bf466c79243..c4031c14d30 100644 --- a/tensorflow/tensorboard/app/index.html +++ b/tensorflow/tensorboard/app/index.html @@ -2,9 +2,16 @@ + - - + TensorBoard diff --git a/tensorflow/tensorboard/app/tf-tensorboard-demo.html b/tensorflow/tensorboard/app/tf-tensorboard-demo.html deleted file mode 100644 index 5f0ef5b00c7..00000000000 --- a/tensorflow/tensorboard/app/tf-tensorboard-demo.html +++ /dev/null @@ -1,72 +0,0 @@ - - - - - - - diff --git a/tensorflow/tensorboard/bower.json b/tensorflow/tensorboard/bower.json index 7d9f033d4dc..06a31968ab0 100644 --- a/tensorflow/tensorboard/bower.json +++ b/tensorflow/tensorboard/bower.json @@ -20,9 +20,11 @@ "es6-promise": "3.0.2", "graphlib": "1.0.7", "iron-ajax": "PolymerElements/iron-ajax#1.0.7", + "iron-behaviors": "PolymerElements/iron-behaviors#1.0.10", "iron-collapse": "PolymerElements/iron-collapse#1.0.4", "iron-list": "PolymerElements/iron-list#1.1.5", "iron-selector": "PolymerElements/iron-selector#1.0.7", + "paper-behaviors": "PolymerElements/paper-behaviors#1.0.9", "paper-button": "PolymerElements/paper-button#1.0.8", "paper-checkbox": "PolymerElements/paper-checkbox#1.0.13", "paper-dropdown-menu": "PolymerElements/paper-dropdown-menu#1.0.5", diff --git a/tensorflow/tensorboard/components/imports/README.md b/tensorflow/tensorboard/components/imports/README.md new file mode 100644 index 00000000000..695698bf237 --- /dev/null +++ b/tensorflow/tensorboard/components/imports/README.md @@ -0,0 +1,6 @@ +This file acts as import routers for third party javascript libraries, +e.g. Plottable and D3. + +The "local-imports" folder contains alternate versions of the import routers +that load from `bower_components`; it exists to faciliate local development +with a gulp workflow. diff --git a/tensorflow/tensorboard/components/imports/dagre.html b/tensorflow/tensorboard/components/imports/dagre.html new file mode 100644 index 00000000000..b75f137cb28 --- /dev/null +++ b/tensorflow/tensorboard/components/imports/dagre.html @@ -0,0 +1,2 @@ + + diff --git a/tensorflow/tensorboard/components/imports/graphlib.html b/tensorflow/tensorboard/components/imports/graphlib.html new file mode 100644 index 00000000000..189eff17201 --- /dev/null +++ b/tensorflow/tensorboard/components/imports/graphlib.html @@ -0,0 +1 @@ + diff --git a/tensorflow/tensorboard/components/imports/local-imports/d3.html b/tensorflow/tensorboard/components/imports/local-imports/d3.html new file mode 100644 index 00000000000..e2797c0a1a9 --- /dev/null +++ b/tensorflow/tensorboard/components/imports/local-imports/d3.html @@ -0,0 +1 @@ + diff --git a/tensorflow/tensorboard/components/imports/local-imports/dagre.html b/tensorflow/tensorboard/components/imports/local-imports/dagre.html new file mode 100644 index 00000000000..b685aea6c93 --- /dev/null +++ b/tensorflow/tensorboard/components/imports/local-imports/dagre.html @@ -0,0 +1,4 @@ +// hackhack for some reason getting graphlib via an import reference results in +// out of order script evaluation + + diff --git a/tensorflow/tensorboard/components/imports/local-imports/graphlib.html b/tensorflow/tensorboard/components/imports/local-imports/graphlib.html new file mode 100644 index 00000000000..a1e98e9089d --- /dev/null +++ b/tensorflow/tensorboard/components/imports/local-imports/graphlib.html @@ -0,0 +1 @@ + diff --git a/tensorflow/tensorboard/components/imports/local-imports/lodash.html b/tensorflow/tensorboard/components/imports/local-imports/lodash.html new file mode 100644 index 00000000000..95f8375a1d4 --- /dev/null +++ b/tensorflow/tensorboard/components/imports/local-imports/lodash.html @@ -0,0 +1 @@ + diff --git a/tensorflow/tensorboard/components/imports/local-imports/plottable.html b/tensorflow/tensorboard/components/imports/local-imports/plottable.html new file mode 100644 index 00000000000..dfbe77c8c44 --- /dev/null +++ b/tensorflow/tensorboard/components/imports/local-imports/plottable.html @@ -0,0 +1,3 @@ + + + diff --git a/tensorflow/tensorboard/test/index.html b/tensorflow/tensorboard/components/test/index.html similarity index 56% rename from tensorflow/tensorboard/test/index.html rename to tensorflow/tensorboard/components/test/index.html index e02aafc668f..bef954701fe 100644 --- a/tensorflow/tensorboard/test/index.html +++ b/tensorflow/tensorboard/components/test/index.html @@ -2,16 +2,15 @@ - - \ No newline at end of file + diff --git a/tensorflow/tensorboard/components/tf-graph-common/test/index.html b/tensorflow/tensorboard/components/tf-graph-common/test/index.html index fddcb2fde4e..c7694e75149 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/test/index.html +++ b/tensorflow/tensorboard/components/tf-graph-common/test/index.html @@ -12,4 +12,4 @@ - \ No newline at end of file + diff --git a/tensorflow/tensorboard/components/tf-graph-common/tf-graph-common.html b/tensorflow/tensorboard/components/tf-graph-common/tf-graph-common.html index 107e3ab7a09..f42e7d27dd7 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/tf-graph-common.html +++ b/tensorflow/tensorboard/components/tf-graph-common/tf-graph-common.html @@ -1,7 +1,7 @@ - - - - + + + + diff --git a/tensorflow/tensorboard/components/tf-graph-loader/test/index.html b/tensorflow/tensorboard/components/tf-graph-loader/test/index.html index ba31e22c5b7..e484b43822f 100644 --- a/tensorflow/tensorboard/components/tf-graph-loader/test/index.html +++ b/tensorflow/tensorboard/components/tf-graph-loader/test/index.html @@ -10,4 +10,4 @@ - \ No newline at end of file + diff --git a/tensorflow/tensorboard/app/demo/data/cos.json b/tensorflow/tensorboard/components/tf-tensorboard/demo/data/cos.json similarity index 100% rename from tensorflow/tensorboard/app/demo/data/cos.json rename to tensorflow/tensorboard/components/tf-tensorboard/demo/data/cos.json diff --git a/tensorflow/tensorboard/app/demo/data/cubic.json b/tensorflow/tensorboard/components/tf-tensorboard/demo/data/cubic.json similarity index 100% rename from tensorflow/tensorboard/app/demo/data/cubic.json rename to tensorflow/tensorboard/components/tf-tensorboard/demo/data/cubic.json diff --git a/tensorflow/tensorboard/app/demo/data/linear.json b/tensorflow/tensorboard/components/tf-tensorboard/demo/data/linear.json similarity index 100% rename from tensorflow/tensorboard/app/demo/data/linear.json rename to tensorflow/tensorboard/components/tf-tensorboard/demo/data/linear.json diff --git a/tensorflow/tensorboard/app/demo/data/poly5-graph.pbtxt b/tensorflow/tensorboard/components/tf-tensorboard/demo/data/poly5-graph.pbtxt similarity index 100% rename from tensorflow/tensorboard/app/demo/data/poly5-graph.pbtxt rename to tensorflow/tensorboard/components/tf-tensorboard/demo/data/poly5-graph.pbtxt diff --git a/tensorflow/tensorboard/app/demo/data/poly5.json b/tensorflow/tensorboard/components/tf-tensorboard/demo/data/poly5.json similarity index 100% rename from tensorflow/tensorboard/app/demo/data/poly5.json rename to tensorflow/tensorboard/components/tf-tensorboard/demo/data/poly5.json diff --git a/tensorflow/tensorboard/app/demo/data/runs.json b/tensorflow/tensorboard/components/tf-tensorboard/demo/data/runs.json similarity index 100% rename from tensorflow/tensorboard/app/demo/data/runs.json rename to tensorflow/tensorboard/components/tf-tensorboard/demo/data/runs.json diff --git a/tensorflow/tensorboard/app/demo/data/sin-graph.pbtxt b/tensorflow/tensorboard/components/tf-tensorboard/demo/data/sin-graph.pbtxt similarity index 100% rename from tensorflow/tensorboard/app/demo/data/sin-graph.pbtxt rename to tensorflow/tensorboard/components/tf-tensorboard/demo/data/sin-graph.pbtxt diff --git a/tensorflow/tensorboard/app/demo/data/sin.json b/tensorflow/tensorboard/components/tf-tensorboard/demo/data/sin.json similarity index 100% rename from tensorflow/tensorboard/app/demo/data/sin.json rename to tensorflow/tensorboard/components/tf-tensorboard/demo/data/sin.json diff --git a/tensorflow/tensorboard/app/demo/data/sq.json b/tensorflow/tensorboard/components/tf-tensorboard/demo/data/sq.json similarity index 100% rename from tensorflow/tensorboard/app/demo/data/sq.json rename to tensorflow/tensorboard/components/tf-tensorboard/demo/data/sq.json diff --git a/tensorflow/tensorboard/app/demo/index.html b/tensorflow/tensorboard/components/tf-tensorboard/demo/index.html similarity index 79% rename from tensorflow/tensorboard/app/demo/index.html rename to tensorflow/tensorboard/components/tf-tensorboard/demo/index.html index a12b5abc261..369032bae15 100644 --- a/tensorflow/tensorboard/app/demo/index.html +++ b/tensorflow/tensorboard/components/tf-tensorboard/demo/index.html @@ -1,7 +1,7 @@ - + - + TensorBoard Demo diff --git a/tensorflow/tensorboard/app/tf-tensorboard.html b/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html similarity index 82% rename from tensorflow/tensorboard/app/tf-tensorboard.html rename to tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html index 0f5114143e1..9b9da223852 100644 --- a/tensorflow/tensorboard/app/tf-tensorboard.html +++ b/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html @@ -1,12 +1,12 @@ - - - - - - - - - + + + + + + + + +