TensorFlow: upstream changes to git

Change:
	Clean up documentation for ReverseSequence
Change:
	Updated several tensorflow operations to use 32bit indices on GPU.
Change:
	Add attribute batch_dim to ReverseSequenceOp.
Change:
	Fix error in convert_to_records.py.  As reported in
	https://github.com/tensorflow/tensorflow/issues/370
	by AlexUnderMicrocontRoll.
Change:
	Update TensorBoard README.
Change:
	Fixes to boolean flags reported in
	https://github.com/tensorflow/tensorflow/issues/379.  Supports:

	--bool_flag=True  --> True
	--bool_flag=False  --> False
	--bool_flag=gibberish  --> False
	--bool_flag --> True
	--nobool_flag --> False

	Fixes 
Change:
	Update generated Op docs.
Change:
	Enable local development of TensorBoard using gulp
	Also make tf-tensorboard a regular component rather than special case

	This is mostly effected by creating tfserve.js, which is a small server
	with clever routing to load from bower_components/ and components/ using
	the paths that work within google3.

	Workflow: `gulp serve`
Change:
	Add a full working code example to the tensorboard and summaries tutorial
Change:
	Fix seq2seq_test when running on GPU.

	The "proj_w" and "proj_b" variables were being created before the
	`test_session()`'s device function took effect, which pushed the
	placement algorithm into making an incorrect decision.
Change:
	Add a sentence in TensorBoard README on how to serialize summary data to logs and provide link to the how-to tutorial on the TensorFlow website.
Change:
	Add error-catching code if string_input_producer is supplied a null input.
	Before this change, it would die with an opaque shape error from inside
	the queue.  This change catches (most) python null lists being
	passed directly in, and at runtime detects null tensors.

	Adds two tests for this to input_test.py
Change:
	Speed up for models that use the same variable multiple times in the case
	where variables must be copied across devices:
	- Have Variables wrap the Variable op in an Identity op when converted to Tensor.
	  This avoids multiple copies across devices if a variable is used multiple time
	  in a computation.
	- Add Variable.mutable() to return the non-wrapped Variable op for used when
	  assigning new values.
	- Add an as_ref parameter to convert_to_tensor() to allow code to specify
	  if they plan to assign a new value to the result of the conversion.  Make Variable
	  return the result of Variable.mutable() when as_ref is True.
	- Make all ops that assign values to variables pass as_ref=True when converting
	  their arguments.
Change:
	Change to reduce critical section times in gpu_event_mgr.h:
	(1) Call stream->ThenRecordEvent outside the EventMgr critical section
	(2) Do memory deallocation outside the critical section

	Speeds up one configuration of ptb_word_lm from 2924 words per
	second (wps) to 3278 wps on my desktop machine with a Titan X.
Change:
	Remove some colons that break the open source build

	::tensorflow::StringPiece breaks for @raingo, see
	https://github.com/tensorflow/tensorflow/issues/358.
	tensorflow::StringPiece (without the leading colons)
	seems to fix the problem.
Change:
	Added check that inputs to Operation is a list and make a defensive copy of the input. This is for cases where the input list is changed such as in _add_input.
Change:
	Use standard names for TensorFlow dtypes in the tutorial.
Change:
	Add tests for tensor inputs.
Change:
	Fix build after declaring more types for ops
Change:
	Switch to 32 bit indexing to speedup convolutions and concatenations.
Change:
	Add convert_image op to convert between types for images (similar to OpenCV's cvtScale).
Change:
	Make cast work between numeric types (bool, uint8, int16, int32, int64, float, double).
Change:

	Padding input data for odd number of paddings, so we can use cudnn anyway.
	+ Fix total padding computation when padding==VALID.
	+ This CL makes the Googlenet benchmark run 5x faster.

Change:
	Support IndexedSlices in ConcatGrad
Change:
	* sampled softmax op uses one embedding lookup for positive and negative samples
	* float64 support for sampled softmax
Change:
	Move RNN code out of models.rnn (without breaking existing code).  The API may still undergo minor changes, until full documentation as added.
Change:
	Changed to use per-step stacks for the accumulators used in while-loop gradient computation. This addresses the problem caused by using concat without sufficient static shape information. It should also improve performance as we avoided those expensive concats.
Change:
	Update generated Op docs.
Change:
	Improve error messages when the optimizer finds no variables to minimize or
	when none of the variables has gradients.
Change:
	Say that -1 isn't just for flattening in reshape docs

	Also add scalar reshape (reshape(t, [])) as an example.

	This fixes https://github.com/tensorflow/tensorflow/issues/281.
Change:
	This is a test.

Base CL: 109118714
This commit is contained in:
Vijay Vasudevan 2015-12-01 13:26:53 -08:00
parent 3972c791b9
commit 795f35da2d
162 changed files with 4658 additions and 3023 deletions
tensorflow
core
examples/label_image
g3doc
models
python

View File

@ -293,14 +293,13 @@ cc_library(
],
)
# TODO(opensource): Make it work externally
tf_proto_library(
name = "protos_all",
srcs = glob(["**/*.proto"]),
cc_api_version = 2,
go_api_version = 2,
java_api_version = 2,
py_api_version = 2, # TODO(irving): Handle 3
py_api_version = 2,
visibility = ["//visibility:public"],
)
@ -507,7 +506,6 @@ filegroup(
"kernels/maxpooling_op.h",
"kernels/pooling_ops_common.h",
"kernels/pooling_ops_common.cc",
"kernels/reference_gemm.h",
],
exclude = [
"**/*test.cc",
@ -571,7 +569,6 @@ filegroup(
"//tensorflow/core:kernels/no_op.cc",
"//tensorflow/core:kernels/no_op.h",
"//tensorflow/core:kernels/pack_op.cc",
"//tensorflow/core:kernels/reference_gemm.h",
"//tensorflow/core:kernels/reshape_op.cc",
"//tensorflow/core:kernels/reshape_op.h",
"//tensorflow/core:kernels/reverse_sequence_op.cc",
@ -628,6 +625,8 @@ filegroup(
"//tensorflow/core:kernels/relu_op.h",
"//tensorflow/core:kernels/softplus_op.cc",
"//tensorflow/core:kernels/softplus_op.h",
"//tensorflow/core:kernels/softsign_op.cc",
"//tensorflow/core:kernels/softsign_op.h",
"//tensorflow/core:kernels/stack_ops.cc",
"//tensorflow/core:kernels/transpose_op.cc",
"//tensorflow/core:kernels/transpose_op.h",

View File

@ -758,7 +758,11 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
// Ask the device to fill in the device context map.
Device* device = impl_->params_.device;
device->FillContextMap(graph, &device_context_map_);
Status fill_status = device->FillContextMap(graph, &device_context_map_);
if (!fill_status.ok()) {
done(fill_status);
return;
}
// Initialize the ready queue.
for (const Node* n : graph->nodes()) {
@ -1077,7 +1081,7 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
for (int i = 0; i < node->num_outputs(); ++i) {
TensorValue val = ctx->release_output(i);
// Only Switch and Recv nodes can generate new dead outputs
// Only Switch and Recv can generate new dead outputs.
if (*ctx->is_output_dead() || val.tensor == nullptr) {
DCHECK(IsSwitch(node) || IsRecv(node));
} else {

View File

@ -40,13 +40,13 @@ EventMgr::~EventMgr() {
delete e;
}
while (!used_events_.empty()) {
delete used_events_[0].event;
delete used_events_[0].mem;
if (used_events_[0].bufrec.buf) {
used_events_[0].bufrec.alloc->DeallocateRaw(used_events_[0].bufrec.buf);
InUse* ue = &used_events_[0];
delete ue->event;
delete ue->mem;
if (ue->bufrec.buf) {
ue->bufrec.alloc->DeallocateRaw(ue->bufrec.buf);
}
if (used_events_[0].func != nullptr)
threadpool_.Schedule(used_events_[0].func);
if (ue->func != nullptr) threadpool_.Schedule(ue->func);
used_events_.pop_front();
}
}
@ -60,15 +60,17 @@ EventMgr::~EventMgr() {
void EventMgr::PollLoop() {
while (!stop_polling_.HasBeenNotified()) {
Env::Default()->SleepForMicroseconds(1 * 1000);
ToFreeVector to_free;
{
mutex_lock l(mu_);
PollEvents(true);
PollEvents(true, &to_free);
}
FreeMemory(to_free);
}
polling_stopped_.Notify();
}
void EventMgr::QueueInUse(gpu::Stream* stream, InUse iu) {
void EventMgr::QueueInUse(gpu::Stream* stream, InUse iu, gpu::Event** e) {
VLOG(2) << "QueueInUse free_events_ " << free_events_.size()
<< " used_events_ " << used_events_.size();
// Events are created on demand, and repeatedly reused. There is no
@ -77,10 +79,9 @@ void EventMgr::QueueInUse(gpu::Stream* stream, InUse iu) {
free_events_.push_back(new gpu::Event(exec_));
free_events_.back()->Init();
}
gpu::Event* e = free_events_.back();
*e = free_events_.back();
free_events_.pop_back();
stream->ThenRecordEvent(e);
iu.event = e;
iu.event = *e;
used_events_.push_back(iu);
}
@ -103,7 +104,8 @@ void EventMgr::QueueInUse(gpu::Stream* stream, InUse iu) {
// GPU memory use to spike needlessly. An alternative strategy would
// be to throttle new Op execution until the pending event queue
// clears.
void EventMgr::PollEvents(bool is_dedicated_poller) {
void EventMgr::PollEvents(bool is_dedicated_poller,
gtl::InlinedVector<InUse, 4>* to_free) {
VLOG(2) << "PollEvents free_events_ " << free_events_.size()
<< " used_events_ " << used_events_.size();
// Sweep the remaining events in order. If this is the dedicated
@ -123,11 +125,9 @@ void EventMgr::PollEvents(bool is_dedicated_poller) {
if (!is_dedicated_poller) return; // quit processing queue
break;
case gpu::Event::Status::kComplete:
delete iu.mem;
if (iu.bufrec.buf) iu.bufrec.alloc->DeallocateRaw(iu.bufrec.buf);
// The function must be called in another thread, outside of
// the mutex held here.
if (iu.func != nullptr) threadpool_.Schedule(iu.func);
// Make a copy of the InUse record so we can free it after releasing
// the lock
to_free->push_back(iu);
free_events_.push_back(iu.event);
// Mark this InUse record as completed.
iu.event = nullptr;

View File

@ -18,8 +18,10 @@ limitations under the License.
#include <deque>
#include <vector>
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/tensor.h"
@ -47,9 +49,15 @@ class EventMgr {
// currently enqueued on *stream have completed.
inline void ThenDeleteTensors(perftools::gputools::Stream* stream,
std::vector<Tensor>* tensors) {
mutex_lock l(mu_);
QueueTensors(stream, tensors);
PollEvents(false);
ToFreeVector to_free;
::perftools::gputools::Event* e;
{
mutex_lock l(mu_);
QueueTensors(stream, tensors, &e);
PollEvents(false, &to_free);
}
stream->ThenRecordEvent(e);
FreeMemory(to_free);
}
struct BufRec {
@ -61,16 +69,28 @@ class EventMgr {
// on it as soon as all events currently enqueued on *stream have completed.
inline void ThenDeleteBuffer(perftools::gputools::Stream* stream,
BufRec bufrec) {
mutex_lock l(mu_);
QueueBuffer(stream, bufrec);
PollEvents(false);
ToFreeVector to_free;
::perftools::gputools::Event* e;
{
mutex_lock l(mu_);
QueueBuffer(stream, bufrec, &e);
PollEvents(false, &to_free);
}
stream->ThenRecordEvent(e);
FreeMemory(to_free);
}
inline void ThenExecute(perftools::gputools::Stream* stream,
std::function<void()> func) {
mutex_lock l(mu_);
QueueFunc(stream, func);
PollEvents(false);
ToFreeVector to_free;
::perftools::gputools::Event* e;
{
mutex_lock l(mu_);
QueueFunc(stream, func, &e);
PollEvents(false, &to_free);
}
stream->ThenRecordEvent(e);
FreeMemory(to_free);
}
private:
@ -85,32 +105,50 @@ class EventMgr {
std::function<void()> func;
};
typedef gtl::InlinedVector<InUse, 4> ToFreeVector;
void FreeMemory(const ToFreeVector& to_free) {
for (const auto& iu : to_free) {
delete iu.mem;
if (iu.bufrec.buf) iu.bufrec.alloc->DeallocateRaw(iu.bufrec.buf);
// The function must be called in another thread.
if (iu.func != nullptr) threadpool_.Schedule(iu.func);
}
}
// Stream-enqueue an unused Event and save with it a collection of
// Tensors and/or a BufRec to be deleted only after the Event
// records.
void QueueInUse(perftools::gputools::Stream* stream, InUse in_use)
void QueueInUse(perftools::gputools::Stream* stream, InUse in_use,
::perftools::gputools::Event** e)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
void QueueTensors(perftools::gputools::Stream* stream,
std::vector<Tensor>* tensors)
std::vector<Tensor>* tensors,
::perftools::gputools::Event** e)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
QueueInUse(stream, {nullptr, tensors, BufRec(), nullptr});
QueueInUse(stream, {nullptr, tensors, BufRec(), nullptr}, e);
}
void QueueBuffer(perftools::gputools::Stream* stream, BufRec bufrec)
void QueueBuffer(perftools::gputools::Stream* stream, BufRec bufrec,
::perftools::gputools::Event** e)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
QueueInUse(stream, {nullptr, nullptr, bufrec, nullptr});
QueueInUse(stream, {nullptr, nullptr, bufrec, nullptr}, e);
}
void QueueFunc(perftools::gputools::Stream* stream,
std::function<void()> func) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
QueueInUse(stream, {nullptr, nullptr, BufRec(), func});
std::function<void()> func, ::perftools::gputools::Event** e)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
QueueInUse(stream, {nullptr, nullptr, BufRec(), func}, e);
}
// This function should be called at roughly the same tempo as
// QueueTensors() to check whether pending events have recorded,
// and then retire them.
void PollEvents(bool is_dedicated_poller) EXCLUSIVE_LOCKS_REQUIRED(mu_);
// and then retire them. It appends InUse elements that need cleanup
// to "*to_free". The caller should call FreeMemory(to_free)
// when this returns.
void PollEvents(bool is_dedicated_poller, ToFreeVector* to_free)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
// An internal polling loop that runs at a low frequency to clear
// straggler Events.

View File

@ -42,13 +42,21 @@ class TEST_EventMgrHelper {
void QueueTensors(perftools::gputools::Stream* stream,
std::vector<Tensor>* tensors) {
mutex_lock l(em_->mu_);
em_->QueueTensors(stream, tensors);
::perftools::gputools::Event* e;
{
mutex_lock l(em_->mu_);
em_->QueueTensors(stream, tensors, &e);
}
stream->ThenRecordEvent(e);
}
void PollEvents(bool is_dedicated_poller) {
mutex_lock l(em_->mu_);
em_->PollEvents(is_dedicated_poller);
EventMgr::ToFreeVector to_free;
{
mutex_lock l(em_->mu_);
em_->PollEvents(is_dedicated_poller, &to_free);
}
em_->FreeMemory(to_free);
}
private:

View File

@ -119,9 +119,10 @@ class DeviceBase {
// "event_mgr" is used to delay deallocation of temporary GPU buffers.
// TODO(pbar) Work out how to move this out of DeviceBase.
struct GpuDeviceInfo {
perftools::gputools::Stream* stream;
DeviceContext* default_context;
EventMgr* event_mgr;
// Make sure all the defaults are NULL, so we can spot missing assignments.
perftools::gputools::Stream* stream = nullptr;
DeviceContext* default_context = nullptr;
EventMgr* event_mgr = nullptr;
};
// Does not take ownership.

View File

@ -55,6 +55,24 @@ struct CastFunctor<CPUDevice, O, I> {
} // namespace functor
#define CURRY_TYPES2(FN, arg0) \
FN(arg0, bool); \
FN(arg0, uint8); \
FN(arg0, int16); \
FN(arg0, int32); \
FN(arg0, int64); \
FN(arg0, float); \
FN(arg0, double)
#define CURRY_TYPES3(FN, arg0, arg1) \
FN(arg0, arg1, bool); \
FN(arg0, arg1, uint8); \
FN(arg0, arg1, int16); \
FN(arg0, arg1, int32); \
FN(arg0, arg1, int64); \
FN(arg0, arg1, float); \
FN(arg0, arg1, double)
#define CAST_CASE(DEVICE, IN, OUT) \
if (DataTypeToEnum<IN>::value == src_dtype_ && \
DataTypeToEnum<OUT>::value == dst_dtype_) { \
@ -110,27 +128,14 @@ class CpuCastOp : public CastOpBase {
work_ = nullptr; // Identity
return Status::OK();
}
CAST_CASE(CPUDevice, bool, float);
CAST_CASE(CPUDevice, bool, int32);
CAST_CASE(CPUDevice, bool, double);
CAST_CASE(CPUDevice, double, float);
CAST_CASE(CPUDevice, double, int32);
CAST_CASE(CPUDevice, double, int64);
CAST_CASE(CPUDevice, float, double);
CAST_CASE(CPUDevice, float, uint8);
CAST_CASE(CPUDevice, float, int32);
CAST_CASE(CPUDevice, float, int64);
CAST_CASE(CPUDevice, int32, double);
CAST_CASE(CPUDevice, int32, float);
CAST_CASE(CPUDevice, int32, uint8);
CAST_CASE(CPUDevice, int32, int64);
CAST_CASE(CPUDevice, int64, double);
CAST_CASE(CPUDevice, int64, float);
CAST_CASE(CPUDevice, int64, int32);
CAST_CASE(CPUDevice, uint8, float);
CAST_CASE(CPUDevice, uint8, int32);
CAST_CASE(CPUDevice, uint8, int64);
CAST_CASE(CPUDevice, uint8, double);
CURRY_TYPES3(CAST_CASE, CPUDevice, bool);
CURRY_TYPES3(CAST_CASE, CPUDevice, uint8);
CURRY_TYPES3(CAST_CASE, CPUDevice, int16);
CURRY_TYPES3(CAST_CASE, CPUDevice, int32);
CURRY_TYPES3(CAST_CASE, CPUDevice, int64);
CURRY_TYPES3(CAST_CASE, CPUDevice, float);
CURRY_TYPES3(CAST_CASE, CPUDevice, double);
if (src_dtype_ == DT_BFLOAT16 && dst_dtype_ == DT_FLOAT) {
work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) {
int64 N = out->NumElements();
@ -185,24 +190,15 @@ class GpuCastOp : public CastOpBase {
work_ = nullptr; // Identity
return Status::OK();
}
CAST_CASE(GPUDevice, bfloat16, float);
CAST_CASE(GPUDevice, bool, float);
CAST_CASE(GPUDevice, double, float);
CAST_CASE(GPUDevice, double, int64);
CURRY_TYPES3(CAST_CASE, GPUDevice, bool);
CURRY_TYPES3(CAST_CASE, GPUDevice, uint8);
CURRY_TYPES3(CAST_CASE, GPUDevice, int16);
CURRY_TYPES3(CAST_CASE, GPUDevice, int32);
CURRY_TYPES3(CAST_CASE, GPUDevice, int64);
CURRY_TYPES3(CAST_CASE, GPUDevice, float);
CURRY_TYPES3(CAST_CASE, GPUDevice, double);
CAST_CASE(GPUDevice, float, bfloat16);
CAST_CASE(GPUDevice, float, double);
CAST_CASE(GPUDevice, float, int64);
CAST_CASE(GPUDevice, int64, double);
CAST_CASE(GPUDevice, int64, float);
CAST_CASE(GPUDevice, uint8, float);
CAST_CASE(GPUDevice, float, uint8);
CAST_CASE(GPUDevice, bool, int32);
CAST_CASE(GPUDevice, double, int32);
CAST_CASE(GPUDevice, float, int32);
CAST_CASE(GPUDevice, int32, double);
CAST_CASE(GPUDevice, int32, float);
CAST_CASE(GPUDevice, int32, int64);
CAST_CASE(GPUDevice, int64, int32);
CAST_CASE(GPUDevice, bfloat16, float);
return Unimplemented();
}
};
@ -217,28 +213,24 @@ REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp);
.TypeConstraint<srctype>("SrcT") \
.TypeConstraint<dsttype>("DstT") \
.Device(DEVICE_GPU), \
GpuCastOp);
REGISTER_CAST_GPU(bfloat16, float);
REGISTER_CAST_GPU(bool, float);
REGISTER_CAST_GPU(double, float);
REGISTER_CAST_GPU(double, int64);
GpuCastOp)
CURRY_TYPES2(REGISTER_CAST_GPU, bool);
CURRY_TYPES2(REGISTER_CAST_GPU, uint8);
CURRY_TYPES2(REGISTER_CAST_GPU, int16);
CURRY_TYPES2(REGISTER_CAST_GPU, int32);
CURRY_TYPES2(REGISTER_CAST_GPU, int64);
CURRY_TYPES2(REGISTER_CAST_GPU, float);
CURRY_TYPES2(REGISTER_CAST_GPU, double);
REGISTER_CAST_GPU(float, bfloat16);
REGISTER_CAST_GPU(float, double);
REGISTER_CAST_GPU(float, int64);
REGISTER_CAST_GPU(int64, double);
REGISTER_CAST_GPU(int64, float);
REGISTER_CAST_GPU(uint8, float);
REGISTER_CAST_GPU(float, uint8);
REGISTER_CAST_GPU(bool, int32);
REGISTER_CAST_GPU(double, int32);
REGISTER_CAST_GPU(float, int32);
REGISTER_CAST_GPU(int32, double);
REGISTER_CAST_GPU(int32, float);
REGISTER_CAST_GPU(int32, int64);
REGISTER_CAST_GPU(int64, int32);
REGISTER_CAST_GPU(bfloat16, float);
#undef REGISTER_CAST_GPU
#endif // GOOGLE_CUDA
#undef CURRY_TYPES2
#undef CURRY_TYPES3
// HostCast differs from Cast in that its input and output are in host memory.
REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp);
REGISTER_KERNEL_BUILDER(

View File

@ -33,25 +33,27 @@ struct CastFunctor<GPUDevice, O, I> {
}
};
#define DEFINE(O, I) template struct CastFunctor<GPUDevice, O, I>;
DEFINE(float, double);
DEFINE(float, int32);
DEFINE(float, int64);
DEFINE(double, float);
DEFINE(double, int32);
DEFINE(double, int64);
DEFINE(int32, float);
DEFINE(int32, double);
DEFINE(int32, int64);
DEFINE(int64, float);
DEFINE(int64, double);
DEFINE(int64, int32);
DEFINE(int32, bool);
DEFINE(float, bool);
DEFINE(float, uint8);
DEFINE(uint8, float);
DEFINE(float, bfloat16);
#define DEFINE(O, I) template struct CastFunctor<GPUDevice, O, I>
#define DEFINE_ALL_FROM(in_type) \
DEFINE(in_type, bool); \
DEFINE(in_type, uint8); \
DEFINE(in_type, int16); \
DEFINE(in_type, int32); \
DEFINE(in_type, int64); \
DEFINE(in_type, float); \
DEFINE(in_type, double)
DEFINE_ALL_FROM(bool);
DEFINE_ALL_FROM(uint8);
DEFINE_ALL_FROM(int16);
DEFINE_ALL_FROM(int32);
DEFINE_ALL_FROM(int64);
DEFINE_ALL_FROM(float);
DEFINE_ALL_FROM(double);
DEFINE(bfloat16, float);
DEFINE(float, bfloat16);
#undef DEFINE_ALL_FROM
#undef DEFINE
} // end namespace functor

View File

@ -41,22 +41,48 @@ class CastOpTest : public OpsTestBase {
void MakeOp(DataType src, DataType dst) {
RequireDefaultOps();
EXPECT_OK(NodeDefBuilder("cast_op", "Cast")
.Input(FakeInput(DT_INT32))
.Input(FakeInput(src))
.Attr("SrcT", src)
.Attr("DstT", dst)
.Finalize(node_def()));
EXPECT_OK(InitOp());
}
template <typename IN, typename OUT>
void CheckCast() {
DataType in_type = DataTypeToEnum<IN>::v();
DataType out_type = DataTypeToEnum<OUT>::v();
MakeOp(in_type, out_type);
AddInputFromArray<IN>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), out_type, TensorShape({1, 2, 2, 1}));
test::FillValues<OUT>(&expected, {1, 2, 3, 4});
test::ExpectTensorEqual<OUT>(expected, *GetOutput(0));
}
};
TEST_F(CastOpTest, Int32ToUint8) {
MakeOp(DT_INT32, DT_UINT8);
AddInputFromArray<int32>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_UINT8, TensorShape({1, 2, 2, 1}));
test::FillValues<uint8>(&expected, {1, 2, 3, 4});
test::ExpectTensorEqual<uint8>(expected, *GetOutput(0));
}
#define TEST_CAST(in, out) \
TEST_F(CastOpTest, TestCast##_##in##_##out) { CheckCast<in, out>(); }
#define TEST_ALL_CASTS_FROM(in) \
TEST_CAST(in, uint8); \
TEST_CAST(in, int16); \
TEST_CAST(in, int32); \
TEST_CAST(in, int64); \
TEST_CAST(in, float); \
TEST_CAST(in, double)
TEST_ALL_CASTS_FROM(uint8)
TEST_ALL_CASTS_FROM(int16)
TEST_ALL_CASTS_FROM(int32)
TEST_ALL_CASTS_FROM(int64)
TEST_ALL_CASTS_FROM(float)
TEST_ALL_CASTS_FROM(double)
#undef TEST_ALL_CASTS_FROM
#undef TEST_CAST
// TODO(wicke): check conversions from/to bool, and bfloat16
static void BM_cpu_float_int64(int iters, int num) {
testing::ItemsProcessed(static_cast<int64>(iters) * num);

View File

@ -34,10 +34,12 @@ void ConcatGPU(const GPUDevice& d,
const std::vector<
std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs,
typename TTypes<T, 2>::Matrix* output) {
Eigen::array<Eigen::DenseIndex, 2> offset(0, 0);
Eigen::array<int32, 2> offset{0, 0};
for (int i = 0; i < inputs.size(); ++i) {
Eigen::array<Eigen::DenseIndex, 2> size = inputs[i]->dimensions();
output->slice(offset, size).device(d) = *inputs[i];
Eigen::array<int32_t, 2> size;
size[0] = inputs[i]->dimension(0);
size[1] = inputs[i]->dimension(1);
To32Bit(*output).slice(offset, size).device(d) = To32Bit(*inputs[i]);
offset[1] += size[1];
}
}

View File

@ -73,7 +73,7 @@ struct FillFunctor<GPUDevice, T> {
void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
typename TTypes<T>::ConstScalar in) {
Eigen::internal::scalar_const_op<T> f(in.data());
out.device(d) = out.nullaryExpr(f);
To32Bit(out).device(d) = To32Bit(out).nullaryExpr(f);
}
};
@ -91,7 +91,7 @@ DEFINE_FILL_GPU(int64);
template <typename T>
struct SetZeroFunctor<GPUDevice, T> {
void operator()(const GPUDevice& d, typename TTypes<T>::Flat out) {
out.device(d) = out.constant(0);
To32Bit(out).device(d) = To32Bit(out).constant(0);
}
};

View File

@ -242,13 +242,13 @@ typedef Eigen::GpuDevice GPUDevice;
const auto expanded_out_cols = (output_cols - 1) * stride + 1; \
const auto padded_out_rows = input_rows + filter_rows - 1; \
const auto padded_out_cols = input_cols + filter_cols - 1; \
const auto top_pad_rows = filter_rows - 1 - pad_rows; \
const auto left_pad_cols = filter_cols - 1 - pad_cols; \
const auto bottom_pad_rows = \
const int top_pad_rows = filter_rows - 1 - pad_rows; \
const int left_pad_cols = filter_cols - 1 - pad_cols; \
const int bottom_pad_rows = \
padded_out_rows - expanded_out_rows - top_pad_rows; \
const auto right_pad_cols = \
const int right_pad_cols = \
padded_out_cols - expanded_out_cols - left_pad_cols; \
Eigen::DSizes<Eigen::DenseIndex, 4> strides{1, stride, stride, 1}; \
Eigen::DSizes<int, 4> strides{1, stride, stride, 1}; \
VLOG(2) << "Conv2d: " << label \
<< ": expanded_out_rows = " << expanded_out_rows \
<< ", expanded_out_cols = " << expanded_out_cols \
@ -809,9 +809,11 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
context->allocate_output(0, input_shape, &in_backprop));
const int padding_rows =
(output_rows - 1) * stride + filter_rows - input_rows;
(padding_ == VALID) ? 0 : (output_rows - 1) * stride + filter_rows -
input_rows;
const int padding_cols =
(output_cols - 1) * stride + filter_cols - input_cols;
(padding_ == VALID) ? 0 : (output_cols - 1) * stride + filter_cols -
input_cols;
// TODO(keveman): cuDNN only supports equal padding on both sides, so only
// calling it when that is true. Remove this check when (if?) cuDNN starts
@ -954,16 +956,17 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
context->allocate_temp(DataTypeToEnum<T>::v(),
padded_out_shape, &padded_output));
Eigen::DSizes<Eigen::DenseIndex, 4> trivial_order{0, 1, 2, 3};
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 4> pad_dims{
Eigen::DSizes<int, 4> trivial_order{0, 1, 2, 3};
Eigen::array<Eigen::IndexPair<int>, 4> pad_dims{
{{0, 0},
{top_pad_rows, bottom_pad_rows},
{left_pad_cols, right_pad_cols},
{0, 0}}};
functor::InflatePadAndShuffle<Device, T, 4, Eigen::DenseIndex>()(
context->eigen_device<Device>(), out_backprop.tensor<T, 4>(), strides,
pad_dims, trivial_order, padded_output.tensor<T, 4>());
functor::InflatePadAndShuffle<Device, T, 4, int>()(
context->eigen_device<Device>(), To32Bit(out_backprop.tensor<T, 4>()),
strides, pad_dims, trivial_order,
To32Bit(padded_output.tensor<T, 4>()));
const Tensor& padded_output_cref = padded_output;
// We then need to fill a new "reverted" filter
@ -976,11 +979,11 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
context->allocate_temp(DataTypeToEnum<T>::v(),
r_filter_shape, &r_filter));
Eigen::DSizes<Eigen::DenseIndex, 4> filter_order{0, 1, 3, 2};
Eigen::DSizes<int, 4> filter_order{0, 1, 3, 2};
Eigen::array<bool, 4> filter_rev_dims{true, true, false, false};
functor::ShuffleAndReverse<Device, T, 4, Eigen::DenseIndex>()(
context->eigen_device<Device>(), filter.tensor<T, 4>(), filter_order,
filter_rev_dims, r_filter.tensor<T, 4>());
functor::ShuffleAndReverse<Device, T, 4, int>()(
context->eigen_device<Device>(), To32Bit(filter.tensor<T, 4>()),
filter_order, filter_rev_dims, To32Bit(r_filter.tensor<T, 4>()));
const Tensor& r_filter_cref = r_filter;
// Now we can call conv_2d directly.
@ -1039,20 +1042,22 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
context->allocate_output(0, filter_shape, &filter_backprop));
const int padding_rows =
(output_rows - 1) * stride + filter_rows - input_rows;
(padding_ == VALID) ? 0 : (output_rows - 1) * stride + filter_rows -
input_rows;
const int padding_cols =
(output_cols - 1) * stride + filter_cols - input_cols;
(padding_ == VALID) ? 0 : (output_cols - 1) * stride + filter_cols -
input_cols;
// TODO(zhengxq): cuDNN only supports equal padding on both sides, so only
// calling it when that is true. Remove this check when (if?) cuDNN starts
// supporting different padding.
bool padding_compatible =
(padding_rows % 2 == 0) && (padding_cols % 2 == 0);
bool rows_odd = (padding_rows % 2 != 0);
bool cols_odd = (padding_cols % 2 != 0);
auto* stream = context->op_device_context<GPUDeviceContext>()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
if (use_cudnn_ && padding_compatible) {
if (use_cudnn_) {
if (filter_rows == 1 && filter_cols == 1 && stride == 1) {
const uint64 m = in_depth;
const uint64 k = batch * input_rows * input_cols;
@ -1089,10 +1094,31 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
return;
}
Tensor compatible_input;
if (rows_odd || cols_odd) {
// If a padding dimension is odd, we have one more element on the right
// side or the bottom side. This is unsupported in cudnn. Therefore,
// we pad that extra element and make it compatible.
OP_REQUIRES_OK(
context,
context->allocate_temp(
DataTypeToEnum<T>::value,
TensorShape({input.dim_size(0), input.dim_size(1) + rows_odd,
input.dim_size(2) + cols_odd, input.dim_size(3)}),
&compatible_input));
functor::PadInput<GPUDevice, T, int>()(
context->template eigen_device<GPUDevice>(),
To32Bit(input.tensor<T, 4>()), 0, rows_odd, 0, cols_odd,
To32Bit(compatible_input.tensor<T, 4>()));
} else {
compatible_input = input;
}
perftools::gputools::dnn::BatchDescriptor input_desc;
input_desc.set_count(batch)
.set_height(input_rows)
.set_width(input_cols)
.set_height(compatible_input.dim_size(1))
.set_width(compatible_input.dim_size(2))
.set_feature_map_count(in_depth)
.set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
perftools::gputools::dnn::BatchDescriptor output_desc;
@ -1146,14 +1172,19 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
transformed_out_backprop.tensor<T, 4>());
Tensor transformed_input;
OP_REQUIRES_OK(context,
context->allocate_temp(
DataTypeToEnum<T>::value,
TensorShape({batch, in_depth, input_rows, input_cols}),
&transformed_input));
functor::NHWCToNCHW<Device, T>()(context->eigen_device<Device>(),
input.tensor<T, 4>(),
transformed_input.tensor<T, 4>());
OP_REQUIRES_OK(
context,
context->allocate_temp(
DataTypeToEnum<T>::value,
TensorShape({
compatible_input.dim_size(0), compatible_input.dim_size(3),
compatible_input.dim_size(1), compatible_input.dim_size(2),
}),
&transformed_input));
functor::NHWCToNCHW<Device, T>()(
context->eigen_device<Device>(),
const_cast<const Tensor&>(compatible_input).tensor<T, 4>(),
transformed_input.tensor<T, 4>());
auto out_backprop_ptr =
AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
@ -1193,7 +1224,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
// [batch, out_rows, out_cols, out_depth]
// And we need to change it to
// [out_depth, out_rows, out_cols, batch]
Eigen::DSizes<Eigen::DenseIndex, 4> out_order{3, 1, 2, 0};
Eigen::DSizes<int, 4> out_order{3, 1, 2, 0};
TensorShape padded_out_shape(
{out_depth, padded_out_rows, padded_out_cols, batch});
Tensor padded_output;
@ -1201,14 +1232,14 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
context->allocate_temp(DataTypeToEnum<T>::v(),
padded_out_shape, &padded_output));
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 4> pad_dims{
Eigen::array<Eigen::IndexPair<int>, 4> pad_dims{
{{0, 0},
{top_pad_rows, bottom_pad_rows},
{left_pad_cols, right_pad_cols},
{0, 0}}};
functor::InflatePadAndShuffle<Device, T, 4, Eigen::DenseIndex>()(
context->eigen_device<Device>(), out_backprop.tensor<T, 4>(), strides,
pad_dims, out_order, padded_output.tensor<T, 4>());
functor::InflatePadAndShuffle<Device, T, 4, int>()(
context->eigen_device<Device>(), To32Bit(out_backprop.tensor<T, 4>()),
strides, pad_dims, out_order, To32Bit(padded_output.tensor<T, 4>()));
const Tensor& padded_output_cref = padded_output;
// For the backprop of the filter, we need to transpose the input.
@ -1216,7 +1247,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
// [batch, in_rows, in_cols, in_depth]
// And we need to change it to
// [in_rows, in_cols, batch, in_depth]
Eigen::DSizes<Eigen::DenseIndex, 4> in_order{1, 2, 0, 3};
Eigen::DSizes<int, 4> in_order{1, 2, 0, 3};
TensorShape in_shuffle_shape({input_rows, input_cols, batch, in_depth});
Tensor in_shuffle;
OP_REQUIRES_OK(context,
@ -1225,9 +1256,9 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
// No need for reversing this time.
Eigen::array<bool, 4> trivial_dims{false, false, false, false};
functor::ShuffleAndReverse<Device, T, 4, Eigen::DenseIndex>()(
context->eigen_device<Device>(), input.tensor<T, 4>(), in_order,
trivial_dims, in_shuffle.tensor<T, 4>());
functor::ShuffleAndReverse<Device, T, 4, int>()(
context->eigen_device<Device>(), To32Bit(input.tensor<T, 4>()),
in_order, trivial_dims, To32Bit(in_shuffle.tensor<T, 4>()));
const Tensor& in_shuffle_cref = in_shuffle;
// The output of the conv_2d would be
@ -1250,12 +1281,13 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
BrainPadding2EigenPadding(VALID));
// Now copy the filter_backprop back to the destination.
Eigen::DSizes<Eigen::DenseIndex, 4> filter_order{1, 2, 3, 0};
Eigen::DSizes<int, 4> filter_order{1, 2, 3, 0};
Eigen::array<bool, 4> filter_rev_dims{true, true, false, false};
const Tensor& filter_shuffle_cref = filter_shuffle;
functor::ShuffleAndReverse<Device, T, 4, Eigen::DenseIndex>()(
context->eigen_device<Device>(), filter_shuffle_cref.tensor<T, 4>(),
filter_order, filter_rev_dims, filter_backprop->tensor<T, 4>());
functor::ShuffleAndReverse<Device, T, 4, int>()(
context->eigen_device<Device>(),
To32Bit(filter_shuffle_cref.tensor<T, 4>()), filter_order,
filter_rev_dims, To32Bit(filter_backprop->tensor<T, 4>()));
}
}
@ -1271,25 +1303,6 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void ShuffleAndReverse<GPUDevice, T, 4, Eigen::DenseIndex>::operator()( \
const GPUDevice& d, \
typename TTypes<T, 4, Eigen::DenseIndex>::ConstTensor input, \
const Eigen::DSizes<Eigen::DenseIndex, 4>& order, \
const Eigen::array<bool, 4>& reverse_dims, \
typename TTypes<T, 4, Eigen::DenseIndex>::Tensor output); \
extern template struct ShuffleAndReverse<GPUDevice, T, 4, \
Eigen::DenseIndex>; \
template <> \
void InflatePadAndShuffle<GPUDevice, T, 4, Eigen::DenseIndex>::operator()( \
const GPUDevice& d, \
typename TTypes<T, 4, Eigen::DenseIndex>::ConstTensor input, \
const Eigen::DSizes<Eigen::DenseIndex, 4>& strides, \
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 4>& pad_dims, \
const Eigen::DSizes<Eigen::DenseIndex, 4>& order, \
typename TTypes<T, 4, Eigen::DenseIndex>::Tensor output); \
extern template struct InflatePadAndShuffle<GPUDevice, T, 4, \
Eigen::DenseIndex>; \
template <> \
void ShuffleAndReverse<GPUDevice, T, 4, int>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
const Eigen::DSizes<int, 4>& order, \
@ -1328,7 +1341,13 @@ namespace functor {
typename TTypes<T, 4>::ConstTensor filter, \
typename TTypes<T, 4>::ConstTensor output_backprop, int input_rows, \
int input_cols, int stride); \
extern template struct SpatialConvolutionBackwardInput<GPUDevice, T>
extern template struct SpatialConvolutionBackwardInput<GPUDevice, T>; \
template <> \
void PadInput<GPUDevice, T, int>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
int padding_rows_left, int padding_rows_right, int padding_cols_left, \
int padding_cols_right, typename TTypes<T, 4, int>::Tensor out); \
extern template struct PadInput<GPUDevice, T, int>;
DECLARE_GPU_SPEC(float);
#undef DECLARE_GPU_SPEC

View File

@ -33,12 +33,8 @@ struct SpatialConvolution<GPUDevice, T> {
typename TTypes<T, 4>::ConstTensor input,
typename TTypes<T, 4>::ConstTensor filter, int stride,
const Eigen::PaddingType& padding) {
// TODO(keveman): nvcc 6.5 crashes when 32 bit indexing is turned on. Enable
// this when we move to cuda 7.0.
// SpatialConvolutionFunc(d, To32Bit(output), To32Bit(input),
// To32Bit(filter), stride, padding);
SpatialConvolutionFunc(d, output, input, filter, stride, padding);
SpatialConvolutionFunc(d, To32Bit(output), To32Bit(input), To32Bit(filter),
stride, padding);
}
};

View File

@ -16,21 +16,11 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
REGISTER5(BinaryOp, CPU, "Div", functor::div, float, double, int32, int64,
complex64);
REGISTER7(BinaryOp, CPU, "Div", functor::div, float, double, uint8, int16,
int32, int64, complex64);
#if GOOGLE_CUDA
REGISTER3(BinaryOp, GPU, "Div", functor::div, float, double, int64);
REGISTER6(BinaryOp, GPU, "Div", functor::div, float, double, uint8, int16,
int32, int64);
#endif
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Div")
.Device(DEVICE_GPU)
.HostMemory("x")
.HostMemory("y")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::div<int32>>);
} // namespace tensorflow

View File

@ -19,7 +19,7 @@ limitations under the License.
namespace tensorflow {
namespace functor {
DEFINE_BINARY3(div, float, double, int64);
DEFINE_BINARY6(div, float, double, uint8, int16, int32, int64);
} // namespace functor
} // namespace tensorflow

View File

@ -19,7 +19,7 @@ limitations under the License.
namespace tensorflow {
namespace functor {
DEFINE_BINARY3(mul, float, double, int64);
DEFINE_BINARY7(mul, float, double, uint8, int8, int16, int32, int64);
} // namespace functor
} // namespace tensorflow

View File

@ -16,21 +16,11 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
REGISTER7(BinaryOp, CPU, "Mul", functor::mul, float, double, int32, int64, int8,
int16, complex64);
REGISTER8(BinaryOp, CPU, "Mul", functor::mul, float, double, uint8, int8, int16,
int32, int64, complex64);
#if GOOGLE_CUDA
REGISTER3(BinaryOp, GPU, "Mul", functor::mul, float, double, int64);
REGISTER7(BinaryOp, GPU, "Mul", functor::mul, float, double, uint8, int8, int16,
int32, int64);
#endif
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Mul")
.Device(DEVICE_GPU)
.HostMemory("x")
.HostMemory("y")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::mul<int32>>);
} // namespace tensorflow

View File

@ -379,6 +379,8 @@ struct SelectFunctor<CPUDevice, T> {
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) REGISTER(OP, D, N, F, T0)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
REGISTER(OP, D, N, F, T0)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
REGISTER(OP, D, N, F, T0)
#else // !defined(__ANDROID__)
#define REGISTER2(OP, D, N, F, T0, T1) \
REGISTER(OP, D, N, F, T0) \
@ -398,6 +400,9 @@ struct SelectFunctor<CPUDevice, T> {
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
REGISTER3(OP, D, N, F, T4, T5, T6)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
REGISTER4(OP, D, N, F, T4, T5, T6, T7)
#endif // defined(__ANDROID__)
} // end namespace tensorflow

View File

@ -40,7 +40,7 @@ template <typename Functor>
struct UnaryFunctor<GPUDevice, Functor> {
void operator()(const GPUDevice& d, typename Functor::tout_type out,
typename Functor::tin_type in) {
out.device(d) = in.unaryExpr(typename Functor::func());
To32Bit(out).device(d) = To32Bit(in).unaryExpr(typename Functor::func());
}
};
@ -50,7 +50,8 @@ struct BinaryFunctor<GPUDevice, Functor, NDIMS> {
void operator()(const GPUDevice& d, typename Functor::tout_type out,
typename Functor::tin_type in0,
typename Functor::tin_type in1) {
out.device(d) = in0.binaryExpr(in1, typename Functor::func());
To32Bit(out).device(d) =
To32Bit(in0).binaryExpr(in1, typename Functor::func());
}
void Left(const GPUDevice& d, typename Functor::tout_type out,
@ -60,7 +61,7 @@ struct BinaryFunctor<GPUDevice, Functor, NDIMS> {
typedef typename Functor::in_type Tin;
typedef typename Functor::func Binary;
typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary;
out.device(d) = in.unaryExpr(Unary(scalar.data()));
To32Bit(out).device(d) = To32Bit(in).unaryExpr(Unary(scalar.data()));
}
void Right(const GPUDevice& d, typename Functor::tout_type out,
@ -70,7 +71,7 @@ struct BinaryFunctor<GPUDevice, Functor, NDIMS> {
typedef typename Functor::in_type Tin;
typedef typename Functor::func Binary;
typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
out.device(d) = in.unaryExpr(Unary(scalar.data()));
To32Bit(out).device(d) = To32Bit(in).unaryExpr(Unary(scalar.data()));
}
void BCast(const GPUDevice& d,
@ -86,16 +87,18 @@ struct BinaryFunctor<GPUDevice, Functor, NDIMS> {
const bool bcast0_all_one = AllOne<NDIMS>(bcast0);
const bool bcast1_all_one = AllOne<NDIMS>(bcast1);
if (bcast0_all_one && !bcast1_all_one) {
out.device(d) = in0.binaryExpr(in1.broadcast(bcast1), func);
To32Bit(out).device(d) =
To32Bit(in0).binaryExpr(To32Bit(in1).broadcast(bcast1), func);
return;
}
if (!bcast0_all_one && bcast1_all_one) {
out.device(d) = in0.broadcast(bcast0).binaryExpr(in1, func);
To32Bit(out).device(d) =
To32Bit(in0).broadcast(bcast0).binaryExpr(To32Bit(in1), func);
return;
}
}
out.device(d) =
in0.broadcast(bcast0).binaryExpr(in1.broadcast(bcast1), func);
To32Bit(out).device(d) = To32Bit(in0).broadcast(bcast0).binaryExpr(
To32Bit(in1).broadcast(bcast1), func);
}
};
@ -105,7 +108,8 @@ struct SelectFunctor<GPUDevice, T> {
typename TTypes<bool>::ConstFlat cond_flat,
typename TTypes<T>::ConstFlat then_flat,
typename TTypes<T>::ConstFlat else_flat) {
out.device(d) = cond_flat.select(then_flat, else_flat);
To32Bit(out).device(d) =
To32Bit(cond_flat).select(To32Bit(then_flat), To32Bit(else_flat));
}
};
@ -143,6 +147,12 @@ struct SelectFunctor<GPUDevice, T> {
#define DEFINE_BINARY5(F, T0, T1, T2, T3, T4) \
DEFINE_BINARY2(F, T0, T1); \
DEFINE_BINARY3(F, T2, T3, T4)
#define DEFINE_BINARY6(F, T0, T1, T2, T3, T4, T5) \
DEFINE_BINARY3(F, T0, T1, T2); \
DEFINE_BINARY3(F, T3, T4, T5)
#define DEFINE_BINARY7(F, T0, T1, T2, T3, T4, T5, T6) \
DEFINE_BINARY3(F, T0, T1, T2); \
DEFINE_BINARY4(F, T3, T4, T5, T6)
} // end namespace functor
} // end namespace tensorflow

View File

@ -30,10 +30,17 @@ limitations under the License.
namespace tensorflow {
namespace {
// When the depth is large and beta_ is 0.5 or 1.0, MognetLRN is faster than the
// main band matrix approach used below. Benchmarks suggest switching to
// MognetLRN when depth > 384.
const int kMognetLRNDepthCutoff = 384;
// Create a depth-by-depth band matrix with 1s along a swath of size (2 *
// depth_radius + 1) around the diagonal.
static void GetBandMatrix(int depth, int64 depth_radius,
Eigen::Tensor<float, 2, Eigen::RowMajor>* result) {
void GetBandMatrix(int depth, int64 depth_radius,
Eigen::Tensor<float, 2, Eigen::RowMajor>* result) {
result->setZero();
for (int row = 0; row < depth; ++row) {
const int begin = std::max<int>(0, row - depth_radius);
@ -44,6 +51,8 @@ static void GetBandMatrix(int depth, int64 depth_radius,
}
}
} // namespace
class LRNOp : public OpKernel {
public:
explicit LRNOp(OpKernelConstruction* context) : OpKernel(context) {
@ -69,6 +78,11 @@ class LRNOp : public OpKernel {
#if defined(__ANDROID__)
MognetLRN(in, batch, rows, cols, depth, output);
#else
if (depth > kMognetLRNDepthCutoff && (beta_ == 0.5f || beta_ == 1.0f)) {
MognetLRN(in, batch, rows, cols, depth, output);
return;
}
const int nodes = cols * rows;
auto in_shaped = in.shaped<float, 2>({nodes * batch, depth});
@ -79,13 +93,16 @@ class LRNOp : public OpKernel {
auto out_shaped = output->shaped<float, 2>({nodes * batch, depth});
Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
/// TODO(keveman): Optimize for beta in {0, 1, 0.5}
out_shaped.device(context->eigen_cpu_device()) =
in_shaped /
in_shaped.square()
.contract(multiplier, dims)
.unaryExpr([this](float x) { return bias_ + alpha_ * x; })
.pow(beta_);
auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_;
if (beta_ == 1.0f) {
out_shaped.device(context->eigen_cpu_device()) =
in_shaped * tmp.inverse();
} else if (beta_ == 0.5f) {
out_shaped.device(context->eigen_cpu_device()) = in_shaped * tmp.rsqrt();
} else {
out_shaped.device(context->eigen_cpu_device()) =
in_shaped * (tmp.log() * -beta_).exp();
}
#endif
}
@ -104,11 +121,11 @@ class LRNOp : public OpKernel {
Eigen::VectorXf padded_square(data_in.rows() + double_depth_radius);
padded_square.setZero();
for (int r = 0; r < data_in.cols(); ++r) {
// Do local response normalization for data_in(:, r)
// first, compute the square and store them in buffer for repeated use
// Do local response normalization for data_in(:, r). First, compute the
// square and store them in buffer for repeated use.
padded_square.block(depth_radius_, 0, data_out.rows(), 1) =
data_in.col(r).cwiseProduct(data_in.col(r)) * alpha_;
// Then, compute the scale and writes them to data_out
// Then, compute the scale and write it to data_out.
float accumulated_scale = 0;
for (int i = 0; i < double_depth_radius; ++i) {
accumulated_scale += padded_square(i);
@ -120,13 +137,13 @@ class LRNOp : public OpKernel {
}
}
// In a few cases, the pow computation could benefit from speedups.
if (beta_ == 1) {
data_out.array() = data_in.array() * data_out.array().inverse();
} else if (beta_ == 0.5) {
data_out.array() = data_in.array() * data_out.array().sqrt().inverse();
data_out.array() = data_in.array() * data_out.array().rsqrt();
} else {
data_out.array() = data_in.array() * data_out.array().pow(-beta_);
data_out.array() =
data_in.array() * (data_out.array().log() * -beta_).exp();
}
}

View File

@ -1,90 +0,0 @@
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_KERNELS_REFERENCE_GEMM_H_
#define TENSORFLOW_KERNELS_REFERENCE_GEMM_H_
// This is an unoptimized but debuggable implementation of the GEMM matrix
// multiply function, used to compare to faster but more opaque versions, or
// for bit depths or argument combinations that aren't supported by optimized
// code.
// It assumes the row-major convention used by TensorFlow, and implements
// C = A * B, like the standard BLAS GEMM interface. If the tranpose flags are
// true, then the relevant matrix is treated as stored in column-major order.
namespace tensorflow {
template <class T1, class T2, class T3>
void ReferenceGemm(bool transpose_a, bool transpose_b, bool transpose_c,
size_t m, size_t n, size_t k, const T1* a, T1 offset_a,
size_t lda, const T2* b, T2 offset_b, size_t ldb, T3* c,
int32 shift_c, int32 offset_c, int32 mult_c, size_t ldc) {
int a_i_stride;
int a_l_stride;
if (transpose_a) {
a_i_stride = 1;
a_l_stride = lda;
} else {
a_i_stride = lda;
a_l_stride = 1;
}
int b_j_stride;
int b_l_stride;
if (transpose_b) {
b_j_stride = ldb;
b_l_stride = 1;
} else {
b_j_stride = 1;
b_l_stride = ldb;
}
int c_i_stride;
int c_j_stride;
if (transpose_c) {
c_i_stride = 1;
c_j_stride = ldc;
} else {
c_i_stride = ldc;
c_j_stride = 1;
}
const int32 highest = static_cast<int32>(Eigen::NumTraits<T3>::highest());
const int32 lowest = static_cast<int32>(Eigen::NumTraits<T3>::lowest());
const int32 rounding = (shift_c < 1) ? 0 : (1 << (shift_c - 1));
int i, j, l;
for (j = 0; j < n; j++) {
for (i = 0; i < m; i++) {
int32 total = 0;
for (l = 0; l < k; l++) {
const size_t a_index = ((i * a_i_stride) + (l * a_l_stride));
const int32 a_value = a[a_index] - offset_a;
const size_t b_index = ((j * b_j_stride) + (l * b_l_stride));
const int32 b_value = b[b_index] - offset_b;
total += (a_value * b_value);
}
const size_t c_index = ((i * c_i_stride) + (j * c_j_stride));
int32_t output = ((((total + offset_c) * mult_c) + rounding) >> shift_c);
if (output > highest) {
output = highest;
}
if (output < lowest) {
output = lowest;
}
c[c_index] = static_cast<T3>(output);
}
}
}
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_REFERENCE_GEMM_H_

View File

@ -39,7 +39,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
template <typename Device>
void CheckErrors(OpKernelContext* context, int seq_dim) {
void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
const Tensor& input = context->input(0);
const Tensor& seq_lens = context->input(1);
@ -52,15 +52,18 @@ void CheckErrors(OpKernelContext* context, int seq_dim) {
seq_lens_vec.data(), seq_lens_t.data(),
sizeof(int64) * seq_lens_t.size());
OP_REQUIRES(context, 0 != seq_dim, errors::InvalidArgument("0 == seq_dim"));
OP_REQUIRES(context, batch_dim != seq_dim,
errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
OP_REQUIRES(context, seq_dim < input.dims(),
errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
seq_dim, " vs. ", input.dims(), ")"));
OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(0),
errors::InvalidArgument("len(seq_lens) != input.dims(", 0, "), ",
"(", seq_lens.NumElements(), " vs. ",
input.dim_size(seq_dim)));
OP_REQUIRES(context, batch_dim < input.dims(),
errors::InvalidArgument("batch_dim must be < input.dims()", "( ",
batch_dim, " vs. ", input.dims(), ")"));
OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
"), ", "(", seq_lens.NumElements(),
" vs. ", input.dim_size(batch_dim)));
for (int d = 0; d < seq_lens_vec.size(); ++d) {
OP_REQUIRES(context, seq_lens_vec[d] >= 0,
@ -72,19 +75,24 @@ void CheckErrors(OpKernelContext* context, int seq_dim) {
}
template <>
void CheckErrors<GPUDevice>(OpKernelContext* context, int seq_dim) {
void CheckErrors<GPUDevice>(OpKernelContext* context, int batch_dim,
int seq_dim) {
const Tensor& input = context->input(0);
const Tensor& seq_lens = context->input(1);
OP_REQUIRES(context, 0 != seq_dim, errors::InvalidArgument("0 == seq_dim"));
OP_REQUIRES(context, batch_dim != seq_dim,
errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
OP_REQUIRES(context, seq_dim < input.dims(),
errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
seq_dim, " vs. ", input.dims(), ")"));
OP_REQUIRES(context, batch_dim < input.dims(),
errors::InvalidArgument("batch_dim must be < input.dims()", "( ",
batch_dim, " vs. ", input.dims(), ")"));
OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(0),
errors::InvalidArgument("len(seq_lens) != input.dims(", 0, "), ",
"(", seq_lens.NumElements(), " vs. ",
input.dim_size(seq_dim)));
OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
"), ", "(", seq_lens.NumElements(),
" vs. ", input.dim_size(batch_dim)));
}
template <typename Device, typename T>
@ -92,6 +100,7 @@ class ReverseSequenceOp : public OpKernel {
public:
explicit ReverseSequenceOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_));
OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_));
}
@ -106,7 +115,7 @@ class ReverseSequenceOp : public OpKernel {
auto seq_lens_t = seq_lens.vec<int64>();
CheckErrors<Device>(context, seq_dim_);
CheckErrors<Device>(context, batch_dim_, seq_dim_);
const int input_dims = input.dims();
@ -114,11 +123,11 @@ class ReverseSequenceOp : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
#define HANDLE_DIM(NDIM) \
case NDIM: \
functor::ReverseSequence<Device, T, NDIM>::Compute( \
context->eigen_device<Device>(), input.tensor<T, NDIM>(), seq_dim_, \
seq_lens_t, output->tensor<T, NDIM>()); \
#define HANDLE_DIM(NDIM) \
case NDIM: \
functor::ReverseSequence<Device, T, NDIM>::Compute( \
context->eigen_device<Device>(), input.tensor<T, NDIM>(), batch_dim_, \
seq_dim_, seq_lens_t, output->tensor<T, NDIM>()); \
break;
switch (input_dims) {
@ -136,6 +145,7 @@ class ReverseSequenceOp : public OpKernel {
}
private:
int32 batch_dim_;
int32 seq_dim_;
TF_DISALLOW_COPY_AND_ASSIGN(ReverseSequenceOp);
@ -152,12 +162,12 @@ TF_CALL_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE);
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, Dims) \
template <> \
void ReverseSequence<GPUDevice, T, Dims>::Compute( \
const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \
int32 seq_dim, TTypes<int64>::ConstVec seq_lens, \
typename TTypes<T, Dims>::Tensor output); \
#define DECLARE_GPU_SPEC(T, Dims) \
template <> \
void ReverseSequence<GPUDevice, T, Dims>::Compute( \
const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \
int32 batch_dim, int32 seq_dim, TTypes<int64>::ConstVec seq_lens, \
typename TTypes<T, Dims>::Tensor output); \
extern template struct ReverseSequence<GPUDevice, T, Dims>;
#define DECLARE_GPU_SPECS(T) \

View File

@ -29,15 +29,19 @@ template <typename T, size_t Dims>
class ReverseGenerator {
public:
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
ReverseGenerator(typename TTypes<T, Dims>::ConstTensor input, int32 seq_dim,
TTypes<int64>::ConstVec seq_lengths)
: input_(input), seq_dim_(seq_dim), seq_lengths_(seq_lengths) {}
ReverseGenerator(typename TTypes<T, Dims>::ConstTensor input, int32 batch_dim,
int32 seq_dim, TTypes<int64>::ConstVec seq_lengths)
: input_(input),
batch_dim_(batch_dim),
seq_dim_(seq_dim),
seq_lengths_(seq_lengths) {}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
operator()(const Eigen::array<Eigen::DenseIndex, Dims>& coords) const {
Eigen::array<Eigen::DenseIndex, Dims> new_coords = coords;
if (coords[seq_dim_] < seq_lengths_(coords[0])) {
new_coords[seq_dim_] = seq_lengths_(coords[0]) - coords[seq_dim_] - 1;
if (coords[seq_dim_] < seq_lengths_(coords[batch_dim_])) {
new_coords[seq_dim_] =
seq_lengths_(coords[batch_dim_]) - coords[seq_dim_] - 1;
}
return input_(new_coords);
@ -45,6 +49,7 @@ class ReverseGenerator {
private:
typename TTypes<T, Dims>::ConstTensor input_;
int32 batch_dim_;
int32 seq_dim_;
TTypes<int64>::ConstVec seq_lengths_;
};
@ -57,9 +62,10 @@ template <typename Device, typename T, size_t Dims>
struct ReverseSequence {
EIGEN_ALWAYS_INLINE static void Compute(
const Device& d, typename TTypes<T, Dims>::ConstTensor input,
int32 seq_dim, TTypes<int64>::ConstVec seq_lengths,
int32 batch_dim, int32 seq_dim, TTypes<int64>::ConstVec seq_lengths,
typename TTypes<T, Dims>::Tensor output) {
generator::ReverseGenerator<T, Dims> generator(input, seq_dim, seq_lengths);
generator::ReverseGenerator<T, Dims> generator(input, batch_dim, seq_dim,
seq_lengths);
output.device(d) = input.generate(generator);
}
};

View File

@ -0,0 +1,112 @@
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/nn_ops.cc.
#define EIGEN_USE_THREADS
#include "tensorflow/core/framework/numeric_op.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/softsign_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/tensor.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
template <typename Device, typename T>
class SoftsignOp : public UnaryElementWiseOp<T, SoftsignOp<Device, T>> {
public:
using UnaryElementWiseOp<T, SoftsignOp<Device, T>>::UnaryElementWiseOp;
void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
functor::Softsign<Device, T> functor;
functor(context->eigen_device<Device>(), input.flat<T>(),
output->flat<T>());
}
};
template <typename Device, typename T>
class SoftsignGradOp
: public BinaryElementWiseOp<T, SoftsignGradOp<Device, T>> {
public:
using BinaryElementWiseOp<T, SoftsignGradOp<Device, T>>::BinaryElementWiseOp;
// INPUTS:
// g (gradients): backpropagated gradients
// a (inputs): inputs that were passed to SoftsignOp()
// OUTPUT:
// gradients to backprop
template <int NDIMS>
void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
Tensor* output) {
OP_REQUIRES(context, a.IsSameSize(g),
errors::InvalidArgument("g and a must be the same size"));
functor::SoftsignGrad<Device, T> functor;
functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
output->flat<T>());
}
};
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Softsign").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
SoftsignOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("SoftsignGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
SoftsignGradOp<CPUDevice, type>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void Softsign<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor activations); \
extern template struct Softsign<GPUDevice, T>; \
\
template <> \
void SoftsignGrad<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor backprops); \
extern template struct SoftsignGrad<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
} // namespace functor
// Registration of the GPU implementations.
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Softsign").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
SoftsignOp<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("SoftsignGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
SoftsignGradOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA
} // namespace tensorflow

View File

@ -0,0 +1,60 @@
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_KERNELS_SOFTSIGN_OP_H_
#define TENSORFLOW_KERNELS_SOFTSIGN_OP_H_
// Functor definition for SoftsignOp and SoftsignGradOp, must be compilable by
// nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
namespace tensorflow {
namespace functor {
// Functor used by SoftsignOp to do the computations.
template <typename Device, typename T>
struct Softsign {
// Computes Softsign activation.
//
// features: any shape.
// activations: same shape as "features".
void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
typename TTypes<T>::Tensor activations) {
activations.device(d) =
features / (features.abs() + features.constant(1.0f));
}
};
// Functor used by SoftsignGradOp to do the computations.
template <typename Device, typename T>
struct SoftsignGrad {
// Computes SoftsignGrad backprops.
//
// gradients: gradients backpropagated to the Softsign op.
// features: inputs that were passed to the Softsign op.
// backprops: gradients to backpropagate to the Softsign inputs.
void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
typename TTypes<T>::ConstTensor features,
typename TTypes<T>::Tensor backprops) {
backprops.device(d) =
gradients / (features.abs() + features.constant(1.0f)).square();
}
};
} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_SOFTSIGN_OP_H_

View File

@ -0,0 +1,40 @@
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include <stdio.h>
#include "tensorflow/core/kernels/softsign_op.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
// Definition of the GPU implementations declared in softsign_op.cc.
#define DEFINE_GPU_KERNELS(T) \
template struct functor::Softsign<GPUDevice, T>; \
template struct functor::SoftsignGrad<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
} // end namespace tensorflow
#endif // GOOGLE_CUDA

View File

@ -33,7 +33,7 @@ void Split<Device, T>::operator()(
typename TTypes<T, 3>::ConstTensor input,
const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes) {
output.device(d) = input.slice(slice_indices, slice_sizes);
To32Bit(output).device(d) = To32Bit(input).slice(slice_indices, slice_sizes);
}
#define DEFINE_GPU_KERNELS(T) template struct Split<Eigen::GpuDevice, T>;

View File

@ -1,3 +1,18 @@
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/data_flow_ops.cc.
#include <limits.h>

View File

@ -404,8 +404,9 @@ Reshapes a tensor.
Given `tensor`, this operation returns a tensor that has the same values
as `tensor` with shape `shape`.
If `shape` is the special value `[-1]`, then `tensor` is flattened and the
operation outputs a 1-D tensor with all elements of `tensor`.
If one component of `shape` is the special value -1, the size of that dimension
is computed so that the total size remains constant. In particular, a `shape`
of `[-1]` flattens into 1-D. At most one component of `shape` can be -1.
If `shape` is 1-D or higher, then the operation returns a tensor with shape
`shape` filled with the values of `tensor`. In this case, the number of elements
@ -435,6 +436,13 @@ reshape(t, [2, 4]) ==> [[1, 1, 2, 2]
# tensor 't' has shape [3, 2, 3]
# pass '[-1]' to flatten 't'
reshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6]
# -1 can also be used with higher dimensional shapes
reshape(t, [2, -1]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3],
[4, 4, 4, 5, 5, 5, 6, 6, 6]]
# tensor 't' is [7]
# shape `[]` reshapes to a scalar
reshape(t, []) ==> 7
```
shape: Defines the shape of the output tensor.
@ -535,25 +543,29 @@ REGISTER_OP("ReverseSequence")
.Input("seq_lengths: int64")
.Output("output: T")
.Attr("seq_dim: int")
.Attr("batch_dim: int = 0")
.Attr("T: type")
.Doc(R"doc(
Reverses variable length slices in dimension `seq_dim`.
Reverses variable length slices.
This op first slices `input` along the first dimension, and for each slice `i`,
reverses the first `seq_lengths[i]` elements along the dimension `seq_dim`.
This op first slices `input` along the dimension `batch_dim`, and for each
slice `i`, reverses the first `seq_lengths[i]` elements along
the dimension `seq_dim`.
The elements of `seq_lengths` must obey `seq_lengths[i] < input.dims[seq_dim]`,
and `seq_lengths` must be a vector of length `input.dims(0)`.
and `seq_lengths` must be a vector of length `input.dims[batch_dim]`.
The output slice `i` along dimension 0 is then given by input slice `i`, with
the first `seq_lengths[i]` slices along dimension `seq_dim` reversed.
The output slice `i` along dimension `batch_dim` is then given by input
slice `i`, with the first `seq_lengths[i]` slices along dimension
`seq_dim` reversed.
For example:
```prettyprint
# Given this:
batch_dim = 0
seq_dim = 1
input.dims = (4, ...)
input.dims = (4, 8, ...)
seq_lengths = [7, 2, 3, 5]
# then slices of input are reversed on seq_dim, but only up to seq_lengths:
@ -569,10 +581,32 @@ output[2, 3:, :, ...] = input[2, 3:, :, ...]
output[3, 2:, :, ...] = input[3, 2:, :, ...]
```
In contrast, if:
```prettyprint
# Given this:
batch_dim = 2
seq_dim = 0
input.dims = (8, ?, 4, ...)
seq_lengths = [7, 2, 3, 5]
# then slices of input are reversed on seq_dim, but only up to seq_lengths:
output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...]
output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...]
output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...]
output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...]
# while entries past seq_lens are copied through:
output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...]
output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...]
output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...]
output[2:, :, 3, :, ...] = input[2:, :, 3, :, ...]
```
input: The input to reverse.
seq_lengths: 1-D with length `input.dims(0)` and
`max(seq_lengths) < input.dims(seq_dim)`
seq_dim: The dimension which is partially reversed.
batch_dim: The dimension along which reversal is performed.
output: The partially reversed input. It has the same shape as `input`.
)doc");

View File

@ -264,7 +264,7 @@ Returns element-wise smallest integer in not less than x.
#define BINARY_MORE() \
Input("x: T").Input("y: T").Output("z: T").Attr( \
"T: {float, double, int8, int16, int32, complex64, int64}")
"T: {float, double, uint8, int8, int16, int32, int64, complex64}")
#define BINARY_FEWER() \
Input("x: T").Input("y: T").Output("z: T").Attr( \
@ -293,7 +293,7 @@ Returns x * y element-wise.
)doc");
REGISTER_OP("Div")
.BINARY_FEWER()
.BINARY_MORE()
.Doc(R"doc(
Returns x / y element-wise.
)doc");

View File

@ -466,6 +466,27 @@ features: The features passed as input to the corresponding softplus operation.
backprops: The gradients: `gradients / (1 + exp(-features))`.
)doc");
REGISTER_OP("Softsign")
.Input("features: T")
.Output("activations: T")
.Attr("T: realnumbertype")
.Doc(R"doc(
Computes softsign: `features / (abs(features) + 1)`.
)doc");
REGISTER_OP("SoftsignGrad")
.Input("gradients: T")
.Input("features: T")
.Output("backprops: T")
.Attr("T: realnumbertype")
.Doc(R"doc(
Computes softsign gradients for a softsign operation.
gradients: The backpropagated gradients to the corresponding softsign operation.
features: The features passed as input to the corresponding softsign operation.
backprops: The gradients: `gradients / (1 + abs(-features)) ** 2`.
)doc");
// --------------------------------------------------------------------------
REGISTER_OP("Softmax")

View File

@ -44,11 +44,12 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_UINT8
type: DT_INT8
type: DT_INT16
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
type: DT_COMPLEX64
}
}
}
@ -1973,9 +1974,12 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_UINT8
type: DT_INT8
type: DT_INT16
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
type: DT_COMPLEX64
}
}
}
@ -4251,11 +4255,12 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_UINT8
type: DT_INT8
type: DT_INT16
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
type: DT_COMPLEX64
}
}
}
@ -5532,7 +5537,7 @@ op {
type: "type"
}
summary: "Reshapes a tensor."
description: "Given `tensor`, this operation returns a tensor that has the same values\nas `tensor` with shape `shape`.\n\nIf `shape` is the special value `[-1]`, then `tensor` is flattened and the\noperation outputs a 1-D tensor with all elements of `tensor`.\n\nIf `shape` is 1-D or higher, then the operation returns a tensor with shape\n`shape` filled with the values of `tensor`. In this case, the number of elements\nimplied by `shape` must be the same as the number of elements in `tensor`.\n\nFor example:\n\n```prettyprint\n# tensor \'t\' is [1, 2, 3, 4, 5, 6, 7, 8, 9]\n# tensor \'t\' has shape [9]\nreshape(t, [3, 3]) ==> [[1, 2, 3]\n [4, 5, 6]\n [7, 8, 9]]\n\n# tensor \'t\' is [[[1, 1], [2, 2]]\n# [[3, 3], [4, 4]]]\n# tensor \'t\' has shape [2, 2, 2]\nreshape(t, [2, 4]) ==> [[1, 1, 2, 2]\n [3, 3, 4, 4]]\n\n# tensor \'t\' is [[[1, 1, 1],\n# [2, 2, 2]],\n# [[3, 3, 3],\n# [4, 4, 4]],\n# [[5, 5, 5],\n# [6, 6, 6]]]\n# tensor \'t\' has shape [3, 2, 3]\n# pass \'[-1]\' to flatten \'t\'\nreshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6]\n```"
description: "Given `tensor`, this operation returns a tensor that has the same values\nas `tensor` with shape `shape`.\n\nIf one component of `shape` is the special value -1, the size of that dimension\nis computed so that the total size remains constant. In particular, a `shape`\nof `[-1]` flattens into 1-D. At most one component of `shape` can be -1.\n\nIf `shape` is 1-D or higher, then the operation returns a tensor with shape\n`shape` filled with the values of `tensor`. In this case, the number of elements\nimplied by `shape` must be the same as the number of elements in `tensor`.\n\nFor example:\n\n```prettyprint\n# tensor \'t\' is [1, 2, 3, 4, 5, 6, 7, 8, 9]\n# tensor \'t\' has shape [9]\nreshape(t, [3, 3]) ==> [[1, 2, 3]\n [4, 5, 6]\n [7, 8, 9]]\n\n# tensor \'t\' is [[[1, 1], [2, 2]]\n# [[3, 3], [4, 4]]]\n# tensor \'t\' has shape [2, 2, 2]\nreshape(t, [2, 4]) ==> [[1, 1, 2, 2]\n [3, 3, 4, 4]]\n\n# tensor \'t\' is [[[1, 1, 1],\n# [2, 2, 2]],\n# [[3, 3, 3],\n# [4, 4, 4]],\n# [[5, 5, 5],\n# [6, 6, 6]]]\n# tensor \'t\' has shape [3, 2, 3]\n# pass \'[-1]\' to flatten \'t\'\nreshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6]\n# -1 can also be used with higher dimensional shapes\nreshape(t, [2, -1]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3],\n [4, 4, 4, 5, 5, 5, 6, 6, 6]]\n\n# tensor \'t\' is [7]\n# shape `[]` reshapes to a scalar\nreshape(t, []) ==> 7\n```"
}
op {
name: "ResizeArea"
@ -6770,6 +6775,67 @@ op {
}
summary: "Computes softplus gradients for a softplus operation."
}
op {
name: "Softsign"
input_arg {
name: "features"
type_attr: "T"
}
output_arg {
name: "activations"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
}
}
}
summary: "Computes softsign: `features / (abs(features) + 1)`."
}
op {
name: "SoftsignGrad"
input_arg {
name: "gradients"
description: "The backpropagated gradients to the corresponding softsign operation."
type_attr: "T"
}
input_arg {
name: "features"
description: "The features passed as input to the corresponding softsign operation."
type_attr: "T"
}
output_arg {
name: "backprops"
description: "The gradients: `gradients / (1 + abs(-features)) ** 2`."
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
}
}
}
summary: "Computes softsign gradients for a softsign operation."
}
op {
name: "SparseApplyAdagrad"
input_arg {

View File

@ -12,7 +12,7 @@ process.
First, bring in tensorflow python dependency
//third_party/tensorflow:tensorflow_py
//third_party/py/tensorflow
to get the python TensorFlow API.
@ -22,9 +22,9 @@ Then:
import tensorflow as tf
with tf.Session("local"):
input1 = tf.Constant(1.0, shape=[1, 1], name="input1")
input2 = tf.Constant(2.0, shape=[1, 1], name="input2")
output = tf.MatMul(input1, input2)
input1 = tf.constant(1.0, shape=[1, 1], name="input1")
input2 = tf.constant(2.0, shape=[1, 1], name="input2")
output = tf.matmul(input1, input2)
# Run graph and fetch the output
result = output.eval()

View File

@ -64,11 +64,13 @@ TF_DEFINE_string(image,
"tensorflow/examples/label_image/data/grace_hopper.jpg",
"The image to classify (JPEG or PNG).");
TF_DEFINE_string(graph,
"tensorflow/examples/label_image/data/googlenet_graph.pb",
"tensorflow/examples/label_image/data/"
"tensorflow_inception_graph.pb",
"The location of the GraphDef file containing the protobuf"
" definition of the network.");
TF_DEFINE_string(labels,
"tensorflow/examples/label_image/data/googlenet_labels.txt",
"tensorflow/examples/label_image/data/"
"imagenet_comp_graph_label_strings.txt",
"A text file containing the labels of all the categories, one"
" per line.");
TF_DEFINE_int32(input_width, 224, "Width of the image the network expects.");
@ -85,6 +87,10 @@ TF_DEFINE_string(root_dir, "", "The directory at the root of the data files.");
// of the result is a multiple of 16, because our model expects that.
Status ReadLabelsFile(string file_name, std::vector<string>* result) {
std::ifstream file(file_name);
if (!file) {
return tensorflow::errors::NotFound("Labels file ", file_name,
" not found.");
}
result->clear();
string line;
while (std::getline(file, line)) {

View File

@ -277,8 +277,9 @@ Reshapes a tensor.
Given `tensor`, this operation returns a tensor that has the same values
as `tensor` with shape `shape`.
If `shape` is the special value `[-1]`, then `tensor` is flattened and the
operation outputs a 1-D tensor with all elements of `tensor`.
If one component of `shape` is the special value -1, the size of that dimension
is computed so that the total size remains constant. In particular, a `shape`
of `[-1]` flattens into 1-D. At most one component of `shape` can be -1.
If `shape` is 1-D or higher, then the operation returns a tensor with shape
`shape` filled with the values of `tensor`. In this case, the number of elements
@ -308,6 +309,13 @@ reshape(t, [2, 4]) ==> [[1, 1, 2, 2]
# tensor 't' has shape [3, 2, 3]
# pass '[-1]' to flatten 't'
reshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6]
# -1 can also be used with higher dimensional shapes
reshape(t, [2, -1]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3],
[4, 4, 4, 5, 5, 5, 6, 6, 6]]
# tensor 't' is [7]
# shape `[]` reshapes to a scalar
reshape(t, []) ==> 7
```
##### Args:

View File

@ -1355,7 +1355,7 @@ for more details.
- - -
### `tf.convert_to_tensor(value, dtype=None, name=None)` {#convert_to_tensor}
### `tf.convert_to_tensor(value, dtype=None, name=None, as_ref=False)` {#convert_to_tensor}
Converts the given `value` to a `Tensor`.
@ -1390,6 +1390,7 @@ and scalars in addition to `Tensor` objects.
* <b>`dtype`</b>: Optional element type for the returned tensor. If missing, the
type is inferred from the type of `value`.
* <b>`name`</b>: Optional name to use if a new `Tensor` is created.
* <b>`as_ref`</b>: True if we want the result as a ref tensor.
##### Returns:

View File

@ -18,7 +18,8 @@ are all of variable size. If you need fixed size images, pass the output of
the decode Ops to one of the cropping and resizing Ops.
Note: The PNG encode and decode Ops support RGBA, but the conversions Ops
presently only support RGB, HSV, and GrayScale.
presently only support RGB, HSV, and GrayScale. Presently, the alpha channel has
to be stripped from the image and re-attached using slicing ops.
- - -
@ -204,10 +205,6 @@ image = tf.image.decode_jpeg(...)
resized_image = tf.image.resize_bilinear(image, [299, 299])
```
<i>Maybe refer to the Queue examples that show how to add images to a Queue
after resizing them to a fixed size, and how to dequeue batches of resized
images from the Queue.</i>
- - -
### `tf.image.resize_images(images, new_height, new_width, method=0)` {#resize_images}
@ -661,6 +658,43 @@ See also `transpose()`.
## Converting Between Colorspaces.
Internally, images are either stored in as one `float32` per channel per pixel
(implicitly, values are assumed to lie in `[0,1)`) or one `uint8` per channel
per pixel (values are assumed to lie in `[0,255]`).
- - -
### `tf.image.convert_image_dtype(image, dtype, name=None)` {#convert_image_dtype}
Convert `image` to `dtype`, scaling its values if needed.
Images that are represented using floating point values are expected to have
values in the range [0,1). Image data stored in integer data types are
expected to have values in the range `[0,MAX]`, wbere `MAX` is the largest
positive representable number for the data type.
This op converts between data types, scaling the values appropriately before
casting.
Note that for floating point inputs, this op expects values to lie in [0,1).
Conversion of an image containing values outside that range may lead to
overflow errors when converted to integer `Dtype`s.
##### Args:
* <b>`image`</b>: An image.
* <b>`dtype`</b>: A `DType` to convert `image` to.
* <b>`name`</b>: A name for this operation (optional).
##### Returns:
`image`, converted to `dtype`.
## Image Adjustments
TensorFlow provides functions to adjust images in various ways: brightness,

View File

@ -194,6 +194,7 @@
* **[Images](../../api_docs/python/image.md)**:
* [`adjust_brightness`](../../api_docs/python/image.md#adjust_brightness)
* [`adjust_contrast`](../../api_docs/python/image.md#adjust_contrast)
* [`convert_image_dtype`](../../api_docs/python/image.md#convert_image_dtype)
* [`crop_to_bounding_box`](../../api_docs/python/image.md#crop_to_bounding_box)
* [`decode_jpeg`](../../api_docs/python/image.md#decode_jpeg)
* [`decode_png`](../../api_docs/python/image.md#decode_png)
@ -283,6 +284,7 @@
* [`nce_loss`](../../api_docs/python/nn.md#nce_loss)
* [`relu`](../../api_docs/python/nn.md#relu)
* [`relu6`](../../api_docs/python/nn.md#relu6)
* [`rnn`](../../api_docs/python/nn.md#rnn)
* [`sampled_softmax_loss`](../../api_docs/python/nn.md#sampled_softmax_loss)
* [`separable_conv2d`](../../api_docs/python/nn.md#separable_conv2d)
* [`sigmoid`](../../api_docs/python/nn.md#sigmoid)
@ -290,6 +292,8 @@
* [`softmax`](../../api_docs/python/nn.md#softmax)
* [`softmax_cross_entropy_with_logits`](../../api_docs/python/nn.md#softmax_cross_entropy_with_logits)
* [`softplus`](../../api_docs/python/nn.md#softplus)
* [`softsign`](../../api_docs/python/nn.md#softsign)
* [`state_saving_rnn`](../../api_docs/python/nn.md#state_saving_rnn)
* [`tanh`](../../api_docs/python/nn.md#tanh)
* [`top_k`](../../api_docs/python/nn.md#top_k)
* [`uniform_candidate_sampler`](../../api_docs/python/nn.md#uniform_candidate_sampler)

View File

@ -1773,6 +1773,12 @@ Output strings (e.g. filenames) to a queue for an input pipeline.
A queue with the output strings. A `QueueRunner` for the Queue
is added to the current `Graph`'s `QUEUE_RUNNER` collection.
##### Raises:
* <b>`ValueError`</b>: If the string_tensor is a null Python list. At runtime,
will fail with an assertion if string_tensor becomes a null tensor.
### Batching at the end of an input pipeline

View File

@ -23,7 +23,7 @@ Returns x + y element-wise.
##### Args:
* <b>`x`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int8`, `int16`, `int32`, `complex64`, `int64`.
* <b>`x`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `complex64`.
* <b>`y`</b>: A `Tensor`. Must have the same type as `x`.
* <b>`name`</b>: A name for the operation (optional).
@ -59,7 +59,7 @@ Returns x * y element-wise.
##### Args:
* <b>`x`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int8`, `int16`, `int32`, `complex64`, `int64`.
* <b>`x`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `complex64`.
* <b>`y`</b>: A `Tensor`. Must have the same type as `x`.
* <b>`name`</b>: A name for the operation (optional).
@ -77,7 +77,7 @@ Returns x / y element-wise.
##### Args:
* <b>`x`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `complex64`, `int64`.
* <b>`x`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `complex64`.
* <b>`y`</b>: A `Tensor`. Must have the same type as `x`.
* <b>`name`</b>: A name for the operation (optional).

View File

@ -9,11 +9,10 @@ Note: Functions taking `Tensor` arguments can also take anything accepted by
## Activation Functions
The activation ops provide different types of nonlinearities for use in
neural networks. These include smooth nonlinearities (`sigmoid`,
`tanh`, and `softplus`), continuous but not everywhere differentiable
functions (`relu`, `relu6`, and `relu_x`), and random regularization
(`dropout`).
The activation ops provide different types of nonlinearities for use in neural
networks. These include smooth nonlinearities (`sigmoid`, `tanh`, `softplus`,
and `softsign`), continuous but not everywhere differentiable functions (`relu`,
`relu6`, and `relu_x`), and random regularization (`dropout`).
All activation ops apply componentwise, and produce a tensor of the same
shape as the input tensor.
@ -62,6 +61,23 @@ Computes softplus: `log(exp(features) + 1)`.
##### Args:
* <b>`features`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`, `uint8`, `int16`, `int8`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
A `Tensor`. Has the same type as `features`.
- - -
### `tf.nn.softsign(features, name=None)` {#softsign}
Computes softsign: `features / (abs(features) + 1)`.
##### Args:
* <b>`features`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`, `uint8`, `int16`, `int8`.
* <b>`name`</b>: A name for the operation (optional).
@ -1228,3 +1244,89 @@ target classes as noise classes for the same example.
Each value is `-FLOAT_MAX`.
## Other Functions and Classes
- - -
### `tf.nn.rnn(cell, inputs, initial_state=None, dtype=None, sequence_length=None, scope=None)` {#rnn}
Creates a recurrent neural network specified by RNNCell "cell".
##### The simplest form of RNN network generated is:
state = cell.zero_state(...)
outputs = []
states = []
for input_ in inputs:
output, state = cell(input_, state)
outputs.append(output)
states.append(state)
return (outputs, states)
However, a few other options are available:
An initial state can be provided.
If sequence_length is provided, dynamic calculation is performed.
Dynamic calculation returns, at time t:
(t >= max(sequence_length)
? (zeros(output_shape), zeros(state_shape))
: cell(input, state)
Thus saving computational time when unrolling past the max sequence length.
##### Args:
* <b>`cell`</b>: An instance of RNNCell.
* <b>`inputs`</b>: A length T list of inputs, each a vector with shape [batch_size].
* <b>`initial_state`</b>: (optional) An initial state for the RNN. This must be
a tensor of appropriate type and shape [batch_size x cell.state_size].
* <b>`dtype`</b>: (optional) The data type for the initial state. Required if
initial_state is not provided.
* <b>`sequence_length`</b>: An int64 vector (tensor) size [batch_size].
* <b>`scope`</b>: VariableScope for the created subgraph; defaults to "RNN".
##### Returns:
A pair (outputs, states) where:
outputs is a length T list of outputs (one for each input)
states is a length T list of states (one state following each input)
##### Raises:
* <b>`TypeError`</b>: If "cell" is not an instance of RNNCell.
* <b>`ValueError`</b>: If inputs is None or an empty list.
- - -
### `tf.nn.state_saving_rnn(cell, inputs, state_saver, state_name, sequence_length=None, scope=None)` {#state_saving_rnn}
RNN that accepts a state saver for time-truncated RNN calculation.
##### Args:
* <b>`cell`</b>: An instance of RNNCell.
* <b>`inputs`</b>: A length T list of inputs, each a vector with shape [batch_size].
* <b>`state_saver`</b>: A state saver object with methods `state` and `save_state`.
* <b>`state_name`</b>: The name to use with the state_saver.
* <b>`sequence_length`</b>: (optional) An int64 vector (tensor) size [batch_size].
See the documentation for rnn() for more details about sequence_length.
* <b>`scope`</b>: VariableScope for the created subgraph; defaults to "RNN".
##### Returns:
A pair (outputs, states) where:
outputs is a length T list of outputs (one for each input)
states is a length T list of states (one state following each input)
##### Raises:
* <b>`TypeError`</b>: If "cell" is not an instance of RNNCell.
* <b>`ValueError`</b>: If inputs is None or an empty list.

View File

@ -43,23 +43,23 @@ dense[tuple(indices[i])] = values[i]
```
By convention, `indices` should be sorted in row-major order (or equivalently
lexigraphic order on the tuples `indices[i]`). This is not enforced when
`SparseTensor` objects are constructed, but most Ops assume correct ordering.
lexicographic order on the tuples `indices[i]`). This is not enforced when
`SparseTensor` objects are constructed, but most ops assume correct ordering.
If the ordering is wrong, it can be fixed by calling `sparse_reorder` on the
misordered `SparseTensor`.
Example: The sparse tensor
```python
SparseTensor(values=[1, 2], indices=[[0, 0], [1, 2]], shape=[3, 4])
SparseTensor(values=[1, 2], indices=[[0, 0], [1, 2]], shape=[3, 4])
```
represents the dense tensor
```python
[[1, 0, 0, 0]
[0, 0, 2, 0]
[0, 0, 0, 0]]
[[1, 0, 0, 0]
[0, 0, 2, 0]
[0, 0, 0, 0]]
```
- - -
@ -73,7 +73,7 @@ Creates a `SparseTensor`.
* <b>`indices`</b>: A 2-D int64 tensor of shape `[N, ndims]`.
* <b>`values`</b>: A 1-D tensor of any type and shape `[N]`.
* <b>`dense_shape`</b>: A 1-D int64 tensor of shape `[ndims]`.
* <b>`shape`</b>: A 1-D int64 tensor of shape `[ndims]`.
##### Returns:

View File

@ -380,6 +380,51 @@ The `Operation` of this variable.
#### Other Methods
- - -
#### `tf.Variable.ref()` {#Variable.ref}
Returns a reference to this variable.
You usually do not need to call this method as all ops that need a reference
to the variable call it automatically.
Returns is a `Tensor` which holds a reference to the variable. You can
assign a new value to the variable by passing the tensor to an assign op.
See [`value()`](#Variable.value) if you want to get the value of the
variable.
##### Returns:
A `Tensor` that is a reference to the variable.
- - -
#### `tf.Variable.value()` {#Variable.value}
Returns the last snapshot of this variable.
You usually do not need to call this method as all ops that need the value
of the variable call it automatically through a `convert_to_tensor()` call.
Returns a `Tensor` which holds the value of the variable. You can not
assign a new value to this tensor as it is not a reference to the variable.
See [`ref()`](#Variable.ref) if you want to get a reference to the
variable.
To avoid copies, if the consumer of the returned value is on the same device
as the variable, this actually returns the live value of the variable, not
a copy. Updates to the variable are seen by the consumer. If the consumer
is on a different device it will get a copy of the variable.
##### Returns:
A `Tensor` containing the value of the variable.
## Variable helper functions

View File

@ -192,6 +192,7 @@ applies gradients.
* <b>`TypeError`</b>: if `grads_and_vars` is malformed.
* <b>`ValueError`</b>: if none of the variables have gradients.
@ -388,9 +389,9 @@ current good choice is 1.0 or 0.1.
* <b>`beta1`</b>: A float value or a constant float tensor.
The exponential decay rate for the 1st moment estimates.
* <b>`beta2`</b>: A float value or a constant float tensor.
The exponential decay rate for the 2st moment estimates.
The exponential decay rate for the 2nd moment estimates.
* <b>`epsilon`</b>: A small constant for numerical stability.
* <b>`use_locking`</b>: If True use locks for update operation.s
* <b>`use_locking`</b>: If True use locks for update operations.
* <b>`name`</b>: Optional name for the operations created when applying gradients.
Defaults to "Adam".

View File

@ -274,8 +274,8 @@ tf.placeholder() to create them:
```python
input1 = tf.placeholder(tf.types.float32)
input2 = tf.placeholder(tf.types.float32)
input1 = tf.placeholder(tf.float32)
input2 = tf.placeholder(tf.float32)
output = tf.mul(input1, input2)
with tf.Session() as sess:

View File

@ -22,7 +22,7 @@ to:
* Optionally, write a function to compute gradients for the Op.
* Optionally, write a function that describes the input and output shapes
for the Op. This allows shape inference to work with your Op.
* Test the Op, typically in Python.
* Test the Op, typically in Python. If you define gradients, you can verify them with the Python [`GradientChecker`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/python/kernel_tests/gradient_checker.py).
[TOC]

View File

@ -24,7 +24,6 @@ import tensorflow.python.platform
import tensorflow as tf
from tensorflow.g3doc.how_tos.adding_an_op import gen_zero_out_op_2
from tensorflow.g3doc.how_tos.adding_an_op import zero_out_grad_2
from tensorflow.python.kernel_tests import gradient_checker
class ZeroOut2Test(tf.test.TestCase):
@ -39,7 +38,7 @@ class ZeroOut2Test(tf.test.TestCase):
shape = (5,)
x = tf.constant([5, 4, 3, 2, 1], dtype=tf.float32)
y = gen_zero_out_op_2.zero_out(x)
err = gradient_checker.ComputeGradientError(x, shape, y, shape)
err = tf.test.compute_gradient_error(x, shape, y, shape)
self.assertLess(err, 1e-4)

View File

@ -53,7 +53,7 @@ def convert_to(images, labels, name):
num_examples = labels.shape[0]
if images.shape[0] != num_examples:
raise ValueError("Images size %d does not match label size %d." %
(dat.shape[0], num_examples))
(images.shape[0], num_examples))
rows = images.shape[1]
cols = images.shape[2]
depth = images.shape[3]

View File

@ -62,18 +62,66 @@ Now that you've modified your graph and have a `SummaryWriter`, you're ready to
start running your network! If you want, you could run the merged summary op
every single step, and record a ton of training data. That's likely to be more
data than you need, though. Instead, consider running the merged summary op
every hundred steps or so, as in the following code example.
every `n` steps.
The code example below is a modification of the [simple MNIST tutorial]
(http://tensorflow.org/tutorials/mnist/beginners/index.md), in which we have
added some summary ops, and run them every ten steps. If you run this and then
launch `tensorboard --logdir=/tmp/mnist_data`, you'll be able to visualize
statistics, such as how the weights or accuracy varied during training.
The code below is an exerpt; full source is [here](mnist_with_summaries.py).
```python
merged_summary_op = tf.merge_all_summaries()
summary_writer = tf.train.SummaryWriter('/tmp/mnist_logs', sess.graph_def)
total_step = 0
while training:
total_step += 1
session.run(training_op)
if total_step % 100 == 0:
summary_str = session.run(merged_summary_op)
summary_writer.add_summary(summary_str, total_step)
# Create the model
x = tf.placeholder("float", [None, 784], name="x-input")
W = tf.Variable(tf.zeros([784,10]), name="weights")
b = tf.Variable(tf.zeros([10], name="bias"))
# use a name scope to organize nodes in the graph visualizer
with tf.name_scope("Wx_b") as scope:
y = tf.nn.softmax(tf.matmul(x,W) + b)
# Add summary ops to collect data
w_hist = tf.histogram_summary("weights", W)
b_hist = tf.histogram_summary("biases", b)
y_hist = tf.histogram_summary("y", y)
# Define loss and optimizer
y_ = tf.placeholder("float", [None,10], name="y-input")
# More name scopes will clean up the graph representation
with tf.name_scope("xent") as scope:
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
ce_summ = tf.scalar_summary("cross entropy", cross_entropy)
with tf.name_scope("train") as scope:
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
with tf.name_scope("test") as scope:
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
accuracy_summary = tf.scalar_summary("accuracy", accuracy)
# Merge all the summaries and write them out to /tmp/mnist_logs
merged = tf.merge_all_summaries()
writer = tf.train.SummaryWriter("/tmp/mnist_logs", sess.graph_def)
tf.initialize_all_variables().run()
# Train the model, and feed in test data and record summaries every 10 steps
for i in range(1000):
if i % 10 == 0: # Record summary data, and the accuracy
feed = {x: mnist.test.images, y_: mnist.test.labels}
result = sess.run([merged, accuracy], feed_dict=feed)
summary_str = result[0]
acc = result[1]
writer.add_summary(summary_str, i)
print("Accuracy at step %s: %s" % (i, acc))
else:
batch_xs, batch_ys = mnist.train.next_batch(100)
feed = {x: batch_xs, y_: batch_ys}
sess.run(train_step, feed_dict=feed)
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))
```
You're now all set to visualize this data using TensorBoard.

View File

@ -0,0 +1,69 @@
"""A very simple MNIST classifer, modified to display data in TensorBoard
See extensive documentation for the original model at
http://tensorflow.org/tutorials/mnist/beginners/index.md
See documentaion on the TensorBoard specific pieces at
http://tensorflow.org/how_tos/summaries_and_tensorboard/index.md
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Import data
import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
import tensorflow as tf
sess = tf.InteractiveSession()
# Create the model
x = tf.placeholder("float", [None, 784], name="x-input")
W = tf.Variable(tf.zeros([784,10]), name="weights")
b = tf.Variable(tf.zeros([10], name="bias"))
# use a name scope to organize nodes in the graph visualizer
with tf.name_scope("Wx_b") as scope:
y = tf.nn.softmax(tf.matmul(x,W) + b)
# Add summary ops to collect data
w_hist = tf.histogram_summary("weights", W)
b_hist = tf.histogram_summary("biases", b)
y_hist = tf.histogram_summary("y", y)
# Define loss and optimizer
y_ = tf.placeholder("float", [None,10], name="y-input")
# More name scopes will clean up the graph representation
with tf.name_scope("xent") as scope:
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
ce_summ = tf.scalar_summary("cross entropy", cross_entropy)
with tf.name_scope("train") as scope:
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
with tf.name_scope("test") as scope:
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
accuracy_summary = tf.scalar_summary("accuracy", accuracy)
# Merge all the summaries and write them out to /tmp/mnist_logs
merged = tf.merge_all_summaries()
writer = tf.train.SummaryWriter("/tmp/mnist_logs", sess.graph_def)
tf.initialize_all_variables().run()
# Train the model, and feed in test data and record summaries every 10 steps
for i in range(1000):
if i % 10 == 0: # Record summary data, and the accuracy
feed = {x: mnist.test.images, y_: mnist.test.labels}
result = sess.run([merged, accuracy], feed_dict=feed)
summary_str = result[0]
acc = result[1]
writer.add_summary(summary_str, i)
print("Accuracy at step %s: %s" % (i, acc))
else:
batch_xs, batch_ys = mnist.train.next_batch(100)
feed = {x: batch_xs, y_: batch_ys}
sess.run(train_step, feed_dict=feed)
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

View File

@ -224,13 +224,13 @@ We describe these interacting operations by manipulating symbolic variables.
Let's create one:
```python
x = tf.placeholder("float", [None, 784])
x = tf.placeholder(tf.float32, [None, 784])
```
`x` isn't a specific value. It's a `placeholder`, a value that we'll input when
we ask TensorFlow to run a computation. We want to be able to input any number
of MNIST images, each flattened into a 784-dimensional vector. We represent
this as a 2d tensor of floating point numbers, with a shape `[None, 784]`.
this as a 2-D tensor of floating-point numbers, with a shape `[None, 784]`.
(Here `None` means that a dimension can be of any length.)
We also need the weights and biases for our model. We could imagine treating
@ -242,7 +242,7 @@ operations. It can be used and even modified by the computation. For machine
learning applications, one generally has the model parameters be `Variable`s.
```python
W = tf.Variable(tf.zeros([784,10]))
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
```
@ -259,10 +259,10 @@ to the output.
We can now implement our model. It only takes one line!
```python
y = tf.nn.softmax(tf.matmul(x,W) + b)
y = tf.nn.softmax(tf.matmul(x, W) + b)
```
First, we multiply `x` by `W` with the expression `tf.matmul(x,W)`. This is
First, we multiply `x` by `W` with the expression `tf.matmul(x, W)`. This is
flipped from when we multiplied them in our equation, where we had \\(Wx\\), as a
small trick
to deal with `x` being a 2D tensor with multiple inputs. We then add `b`, and
@ -301,7 +301,7 @@ To implement cross-entropy we need to first add a new placeholder to input
the correct answers:
```python
y_ = tf.placeholder("float", [None,10])
y_ = tf.placeholder(tf.float32, [None, 10])
```
Then we can implement the cross-entropy, \\(-\sum y'\log(y)\\):

View File

@ -38,6 +38,9 @@ py_test(
size = "small",
srcs = ["word2vec_test.py"],
srcs_version = "PY2AND3",
tags = [
"notsan", # b/25864127
],
deps = [
":word2vec",
"//tensorflow:tensorflow_py",

View File

@ -7,8 +7,6 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("/tensorflow/tensorflow", "cuda_py_tests")
py_library(
name = "linear",
srcs = [
@ -20,17 +18,6 @@ py_library(
],
)
py_test(
name = "linear_test",
size = "small",
srcs = ["linear_test.py"],
srcs_version = "PY2AND3",
deps = [
":linear",
"//tensorflow:tensorflow_py",
],
)
py_library(
name = "rnn_cell",
srcs = [
@ -43,17 +30,6 @@ py_library(
],
)
py_test(
name = "rnn_cell_test",
size = "small",
srcs = ["rnn_cell_test.py"],
srcs_version = "PY2AND3",
deps = [
":rnn_cell",
"//tensorflow:tensorflow_py",
],
)
py_library(
name = "package",
srcs = [
@ -79,16 +55,6 @@ py_library(
],
)
cuda_py_tests(
name = "rnn_tests",
srcs = [
"rnn_test.py",
],
additional_deps = [
":rnn",
],
)
py_library(
name = "seq2seq",
srcs = [
@ -101,18 +67,6 @@ py_library(
],
)
py_test(
name = "seq2seq_test",
srcs = [
"seq2seq_test.py",
],
srcs_version = "PY2AND3",
deps = [
":seq2seq",
"//tensorflow:tensorflow_py",
],
)
filegroup(
name = "all_files",
srcs = glob(

View File

@ -12,57 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Basic linear combinations that implicitly generate variables."""
"""Import linear python op for backward compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=g-bad-import-order,unused-import
import tensorflow.python.platform
import tensorflow as tf
def linear(args, output_size, bias, bias_start=0.0, scope=None):
"""Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
Args:
args: a 2D Tensor or a list of 2D, batch x n, Tensors.
output_size: int, second dimension of W[i].
bias: boolean, whether to add a bias term or not.
bias_start: starting value to initialize the bias; 0 by default.
scope: VariableScope for the created subgraph; defaults to "Linear".
Returns:
A 2D Tensor with shape [batch x output_size] equal to
sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
Raises:
ValueError: if some of the arguments has unspecified or wrong shape.
"""
assert args
if not isinstance(args, (list, tuple)):
args = [args]
# Calculate the total size of arguments on dimension 1.
total_arg_size = 0
shapes = [a.get_shape().as_list() for a in args]
for shape in shapes:
if len(shape) != 2:
raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes))
if not shape[1]:
raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes))
else:
total_arg_size += shape[1]
# Now the computation.
with tf.variable_scope(scope or "Linear"):
matrix = tf.get_variable("Matrix", [total_arg_size, output_size])
if len(args) == 1:
res = tf.matmul(args[0], matrix)
else:
res = tf.matmul(tf.concat(1, args), matrix)
if not bias:
return res
bias_term = tf.get_variable("Bias", [output_size],
initializer=tf.constant_initializer(bias_start))
return res + bias_term
linear = tf.nn.linear

View File

@ -12,137 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RNN helpers for TensorFlow models."""
"""Import rnn python ops for backward compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.models.rnn import rnn_cell
from tensorflow.python.ops import control_flow_ops
def rnn(cell, inputs, initial_state=None, dtype=None,
sequence_length=None, scope=None):
"""Creates a recurrent neural network specified by RNNCell "cell".
The simplest form of RNN network generated is:
state = cell.zero_state(...)
outputs = []
states = []
for input_ in inputs:
output, state = cell(input_, state)
outputs.append(output)
states.append(state)
return (outputs, states)
However, a few other options are available:
An initial state can be provided.
If sequence_length is provided, dynamic calculation is performed.
Dynamic calculation returns, at time t:
(t >= max(sequence_length)
? (zeros(output_shape), zeros(state_shape))
: cell(input, state)
Thus saving computational time when unrolling past the max sequence length.
Args:
cell: An instance of RNNCell.
inputs: A length T list of inputs, each a vector with shape [batch_size].
initial_state: (optional) An initial state for the RNN. This must be
a tensor of appropriate type and shape [batch_size x cell.state_size].
dtype: (optional) The data type for the initial state. Required if
initial_state is not provided.
sequence_length: An int64 vector (tensor) size [batch_size].
scope: VariableScope for the created subgraph; defaults to "RNN".
Returns:
A pair (outputs, states) where:
outputs is a length T list of outputs (one for each input)
states is a length T list of states (one state following each input)
Raises:
TypeError: If "cell" is not an instance of RNNCell.
ValueError: If inputs is None or an empty list.
"""
if not isinstance(cell, rnn_cell.RNNCell):
raise TypeError("cell must be an instance of RNNCell")
if not isinstance(inputs, list):
raise TypeError("inputs must be a list")
if not inputs:
raise ValueError("inputs must not be empty")
outputs = []
states = []
with tf.variable_scope(scope or "RNN"):
batch_size = tf.shape(inputs[0])[0]
if initial_state is not None:
state = initial_state
else:
if not dtype:
raise ValueError("If no initial_state is provided, dtype must be.")
state = cell.zero_state(batch_size, dtype)
if sequence_length: # Prepare variables
zero_output_state = (
tf.zeros(tf.pack([batch_size, cell.output_size]),
inputs[0].dtype),
tf.zeros(tf.pack([batch_size, cell.state_size]),
state.dtype))
max_sequence_length = tf.reduce_max(sequence_length)
for time, input_ in enumerate(inputs):
if time > 0: tf.get_variable_scope().reuse_variables()
# pylint: disable=cell-var-from-loop
def output_state():
return cell(input_, state)
# pylint: enable=cell-var-from-loop
if sequence_length:
(output, state) = control_flow_ops.cond(
time >= max_sequence_length,
lambda: zero_output_state, output_state)
else:
(output, state) = output_state()
outputs.append(output)
states.append(state)
return (outputs, states)
def state_saving_rnn(cell, inputs, state_saver, state_name,
sequence_length=None, scope=None):
"""RNN that accepts a state saver for time-truncated RNN calculation.
Args:
cell: An instance of RNNCell.
inputs: A length T list of inputs, each a vector with shape [batch_size].
state_saver: A state saver object with methods `state` and `save_state`.
state_name: The name to use with the state_saver.
sequence_length: (optional) An int64 vector (tensor) size [batch_size].
See the documentation for rnn() for more details about sequence_length.
scope: VariableScope for the created subgraph; defaults to "RNN".
Returns:
A pair (outputs, states) where:
outputs is a length T list of outputs (one for each input)
states is a length T list of states (one state following each input)
Raises:
TypeError: If "cell" is not an instance of RNNCell.
ValueError: If inputs is None or an empty list.
"""
initial_state = state_saver.state(state_name)
(outputs, states) = rnn(cell, inputs, initial_state=initial_state,
sequence_length=sequence_length, scope=scope)
save_state = state_saver.save_state(state_name, states[-1])
with tf.control_dependencies([save_state]):
outputs[-1] = tf.identity(outputs[-1])
return (outputs, states)
# pylint: disable=g-bad-import-order,wildcard-import,unused-import
import tensorflow.python.platform
from tensorflow.python.ops.rnn import *

View File

@ -12,614 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import rnn_cell python ops for backward compatibility."""
"""Module for constructing RNN Cells."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.models.rnn import linear
class RNNCell(object):
"""Abstract object representing an RNN cell.
An RNN cell, in the most abstract setting, is anything that has
a state -- a vector of floats of size self.state_size -- and performs some
operation that takes inputs of size self.input_size. This operation
results in an output of size self.output_size and a new state.
This module provides a number of basic commonly used RNN cells, such as
LSTM (Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number
of operators that allow add dropouts, projections, or embeddings for inputs.
Constructing multi-layer cells is supported by a super-class, MultiRNNCell,
defined later. Every RNNCell must have the properties below and and
implement __call__ with the following signature.
"""
def __call__(self, inputs, state, scope=None):
"""Run this RNN cell on inputs, starting from the given state.
Args:
inputs: 2D Tensor with shape [batch_size x self.input_size].
state: 2D Tensor with shape [batch_size x self.state_size].
scope: VariableScope for the created subgraph; defaults to class name.
Returns:
A pair containing:
- Output: A 2D Tensor with shape [batch_size x self.output_size]
- New state: A 2D Tensor with shape [batch_size x self.state_size].
"""
raise NotImplementedError("Abstract method")
@property
def input_size(self):
"""Integer: size of inputs accepted by this cell."""
raise NotImplementedError("Abstract method")
@property
def output_size(self):
"""Integer: size of outputs produced by this cell."""
raise NotImplementedError("Abstract method")
@property
def state_size(self):
"""Integer: size of state used by this cell."""
raise NotImplementedError("Abstract method")
def zero_state(self, batch_size, dtype):
"""Return state tensor (shape [batch_size x state_size]) filled with 0.
Args:
batch_size: int, float, or unit Tensor representing the batch size.
dtype: the data type to use for the state.
Returns:
A 2D Tensor of shape [batch_size x state_size] filled with zeros.
"""
zeros = tf.zeros(tf.pack([batch_size, self.state_size]), dtype=dtype)
zeros.set_shape([None, self.state_size])
return zeros
class BasicRNNCell(RNNCell):
"""The most basic RNN cell."""
def __init__(self, num_units):
self._num_units = num_units
@property
def input_size(self):
return self._num_units
@property
def output_size(self):
return self._num_units
@property
def state_size(self):
return self._num_units
def __call__(self, inputs, state, scope=None):
"""Most basic RNN: output = new_state = tanh(W * input + U * state + B)."""
with tf.variable_scope(scope or type(self).__name__): # "BasicRNNCell"
output = tf.tanh(linear.linear([inputs, state], self._num_units, True))
return output, output
class GRUCell(RNNCell):
"""Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078)."""
def __init__(self, num_units):
self._num_units = num_units
@property
def input_size(self):
return self._num_units
@property
def output_size(self):
return self._num_units
@property
def state_size(self):
return self._num_units
def __call__(self, inputs, state, scope=None):
"""Gated recurrent unit (GRU) with nunits cells."""
with tf.variable_scope(scope or type(self).__name__): # "GRUCell"
with tf.variable_scope("Gates"): # Reset gate and update gate.
# We start with bias of 1.0 to not reset and not udpate.
r, u = tf.split(1, 2, linear.linear([inputs, state],
2 * self._num_units, True, 1.0))
r, u = tf.sigmoid(r), tf.sigmoid(u)
with tf.variable_scope("Candidate"):
c = tf.tanh(linear.linear([inputs, r * state], self._num_units, True))
new_h = u * state + (1 - u) * c
return new_h, new_h
class BasicLSTMCell(RNNCell):
"""Basic LSTM recurrent network cell.
The implementation is based on: http://arxiv.org/pdf/1409.2329v5.pdf.
It does not allow cell clipping, a projection layer, and does not
use peep-hole connections: it is the basic baseline.
Biases of the forget gate are initialized by default to 1 in order to reduce
the scale of forgetting in the beginning of the training.
"""
def __init__(self, num_units, forget_bias=1.0):
self._num_units = num_units
self._forget_bias = forget_bias
@property
def input_size(self):
return self._num_units
@property
def output_size(self):
return self._num_units
@property
def state_size(self):
return 2 * self._num_units
def __call__(self, inputs, state, scope=None):
"""Long short-term memory cell (LSTM)."""
with tf.variable_scope(scope or type(self).__name__): # "BasicLSTMCell"
# Parameters of gates are concatenated into one multiply for efficiency.
c, h = tf.split(1, 2, state)
concat = linear.linear([inputs, h], 4 * self._num_units, True)
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = tf.split(1, 4, concat)
new_c = c * tf.sigmoid(f + self._forget_bias) + tf.sigmoid(i) * tf.tanh(j)
new_h = tf.tanh(new_c) * tf.sigmoid(o)
return new_h, tf.concat(1, [new_c, new_h])
class LSTMCell(RNNCell):
"""Long short-term memory unit (LSTM) recurrent network cell.
This implementation is based on:
https://research.google.com/pubs/archive/43905.pdf
Hasim Sak, Andrew Senior, and Francoise Beaufays.
"Long short-term memory recurrent neural network architectures for
large scale acoustic modeling." INTERSPEECH, 2014.
It uses peep-hole connections, optional cell clipping, and an optional
projection layer.
"""
def __init__(self, num_units, input_size,
use_peepholes=False, cell_clip=None,
initializer=None, num_proj=None,
num_unit_shards=1, num_proj_shards=1):
"""Initialize the parameters for an LSTM cell.
Args:
num_units: int, The number of units in the LSTM cell
input_size: int, The dimensionality of the inputs into the LSTM cell
use_peepholes: bool, set True to enable diagonal/peephole connections.
cell_clip: (optional) A float value, if provided the cell state is clipped
by this value prior to the cell output activation.
initializer: (optional) The initializer to use for the weight and
projection matrices.
num_proj: (optional) int, The output dimensionality for the projection
matrices. If None, no projection is performed.
num_unit_shards: How to split the weight matrix. If >1, the weight
matrix is stored across num_unit_shards.
Note that num_unit_shards must evenly divide num_units * 4.
num_proj_shards: How to split the projection matrix. If >1, the
projection matrix is stored across num_proj_shards.
Note that num_proj_shards must evenly divide num_proj
(if num_proj is not None).
Raises:
ValueError: if num_unit_shards doesn't divide 4 * num_units or
num_proj_shards doesn't divide num_proj
"""
self._num_units = num_units
self._input_size = input_size
self._use_peepholes = use_peepholes
self._cell_clip = cell_clip
self._initializer = initializer
self._num_proj = num_proj
self._num_unit_shards = num_unit_shards
self._num_proj_shards = num_proj_shards
if (num_units * 4) % num_unit_shards != 0:
raise ValueError("num_unit_shards must evently divide 4 * num_units")
if num_proj and num_proj % num_proj_shards != 0:
raise ValueError("num_proj_shards must evently divide num_proj")
if num_proj:
self._state_size = num_units + num_proj
self._output_size = num_proj
else:
self._state_size = 2 * num_units
self._output_size = num_units
@property
def input_size(self):
return self._input_size
@property
def output_size(self):
return self._output_size
@property
def state_size(self):
return self._state_size
def __call__(self, input_, state, scope=None):
"""Run one step of LSTM.
Args:
input_: input Tensor, 2D, batch x num_units.
state: state Tensor, 2D, batch x state_size.
scope: VariableScope for the created subgraph; defaults to "LSTMCell".
Returns:
A tuple containing:
- A 2D, batch x output_dim, Tensor representing the output of the LSTM
after reading "input_" when previous state was "state".
Here output_dim is:
num_proj if num_proj was set,
num_units otherwise.
- A 2D, batch x state_size, Tensor representing the new state of LSTM
after reading "input_" when previous state was "state".
"""
num_proj = self._num_units if self._num_proj is None else self._num_proj
c_prev = tf.slice(state, [0, 0], [-1, self._num_units])
m_prev = tf.slice(state, [0, self._num_units], [-1, num_proj])
dtype = input_.dtype
unit_shard_size = (4 * self._num_units) // self._num_unit_shards
with tf.variable_scope(scope or type(self).__name__): # "LSTMCell"
w = tf.concat(
1,
[tf.get_variable("W_%d" % i,
shape=[self.input_size + num_proj, unit_shard_size],
initializer=self._initializer,
dtype=dtype) for i in xrange(self._num_unit_shards)])
b = tf.get_variable(
"B", shape=[4 * self._num_units],
initializer=tf.zeros_initializer, dtype=dtype)
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
cell_inputs = tf.concat(1, [input_, m_prev])
i, j, f, o = tf.split(1, 4, tf.nn.bias_add(tf.matmul(cell_inputs, w), b))
# Diagonal connections
if self._use_peepholes:
w_f_diag = tf.get_variable(
"W_F_diag", shape=[self._num_units], dtype=dtype)
w_i_diag = tf.get_variable(
"W_I_diag", shape=[self._num_units], dtype=dtype)
w_o_diag = tf.get_variable(
"W_O_diag", shape=[self._num_units], dtype=dtype)
if self._use_peepholes:
c = (tf.sigmoid(f + 1 + w_f_diag * c_prev) * c_prev +
tf.sigmoid(i + w_i_diag * c_prev) * tf.tanh(j))
else:
c = (tf.sigmoid(f + 1) * c_prev + tf.sigmoid(i) * tf.tanh(j))
if self._cell_clip is not None:
c = tf.clip_by_value(c, -self._cell_clip, self._cell_clip)
if self._use_peepholes:
m = tf.sigmoid(o + w_o_diag * c) * tf.tanh(c)
else:
m = tf.sigmoid(o) * tf.tanh(c)
if self._num_proj is not None:
proj_shard_size = self._num_proj // self._num_proj_shards
w_proj = tf.concat(
1,
[tf.get_variable("W_P_%d" % i,
shape=[self._num_units, proj_shard_size],
initializer=self._initializer,
dtype=dtype)
for i in xrange(self._num_proj_shards)])
# TODO(ebrevdo), use matmulsum
m = tf.matmul(m, w_proj)
return m, tf.concat(1, [c, m])
class OutputProjectionWrapper(RNNCell):
"""Operator adding an output projection to the given cell.
Note: in many cases it may be more efficient to not use this wrapper,
but instead concatenate the whole sequence of your outputs in time,
do the projection on this batch-concated sequence, then split it
if needed or directly feed into a softmax.
"""
def __init__(self, cell, output_size):
"""Create a cell with output projection.
Args:
cell: an RNNCell, a projection to output_size is added to it.
output_size: integer, the size of the output after projection.
Raises:
TypeError: if cell is not an RNNCell.
ValueError: if output_size is not positive.
"""
if not isinstance(cell, RNNCell):
raise TypeError("The parameter cell is not RNNCell.")
if output_size < 1:
raise ValueError("Parameter output_size must be > 0: %d." % output_size)
self._cell = cell
self._output_size = output_size
@property
def input_size(self):
return self._cell.input_size
@property
def output_size(self):
return self._output_size
@property
def state_size(self):
return self._cell.state_size
def __call__(self, inputs, state, scope=None):
"""Run the cell and output projection on inputs, starting from state."""
output, res_state = self._cell(inputs, state)
# Default scope: "OutputProjectionWrapper"
with tf.variable_scope(scope or type(self).__name__):
projected = linear.linear(output, self._output_size, True)
return projected, res_state
class InputProjectionWrapper(RNNCell):
"""Operator adding an input projection to the given cell.
Note: in many cases it may be more efficient to not use this wrapper,
but instead concatenate the whole sequence of your inputs in time,
do the projection on this batch-concated sequence, then split it.
"""
def __init__(self, cell, input_size):
"""Create a cell with input projection.
Args:
cell: an RNNCell, a projection of inputs is added before it.
input_size: integer, the size of the inputs before projection.
Raises:
TypeError: if cell is not an RNNCell.
ValueError: if input_size is not positive.
"""
if not isinstance(cell, RNNCell):
raise TypeError("The parameter cell is not RNNCell.")
if input_size < 1:
raise ValueError("Parameter input_size must be > 0: %d." % input_size)
self._cell = cell
self._input_size = input_size
@property
def input_size(self):
return self._input_size
@property
def output_size(self):
return self._cell.output_size
@property
def state_size(self):
return self._cell.state_size
def __call__(self, inputs, state, scope=None):
"""Run the input projection and then the cell."""
# Default scope: "InputProjectionWrapper"
with tf.variable_scope(scope or type(self).__name__):
projected = linear.linear(inputs, self._cell.input_size, True)
return self._cell(projected, state)
class DropoutWrapper(RNNCell):
"""Operator adding dropout to inputs and outputs of the given cell."""
def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0,
seed=None):
"""Create a cell with added input and/or output dropout.
Dropout is never used on the state.
Args:
cell: an RNNCell, a projection to output_size is added to it.
input_keep_prob: unit Tensor or float between 0 and 1, input keep
probability; if it is float and 1, no input dropout will be added.
output_keep_prob: unit Tensor or float between 0 and 1, output keep
probability; if it is float and 1, no output dropout will be added.
seed: (optional) integer, the randomness seed.
Raises:
TypeError: if cell is not an RNNCell.
ValueError: if keep_prob is not between 0 and 1.
"""
if not isinstance(cell, RNNCell):
raise TypeError("The parameter cell is not a RNNCell.")
if (isinstance(input_keep_prob, float) and
not (input_keep_prob >= 0.0 and input_keep_prob <= 1.0)):
raise ValueError("Parameter input_keep_prob must be between 0 and 1: %d"
% input_keep_prob)
if (isinstance(output_keep_prob, float) and
not (output_keep_prob >= 0.0 and output_keep_prob <= 1.0)):
raise ValueError("Parameter input_keep_prob must be between 0 and 1: %d"
% output_keep_prob)
self._cell = cell
self._input_keep_prob = input_keep_prob
self._output_keep_prob = output_keep_prob
self._seed = seed
@property
def input_size(self):
return self._cell.input_size
@property
def output_size(self):
return self._cell.output_size
@property
def state_size(self):
return self._cell.state_size
def __call__(self, inputs, state):
"""Run the cell with the declared dropouts."""
if (not isinstance(self._input_keep_prob, float) or
self._input_keep_prob < 1):
inputs = tf.nn.dropout(inputs, self._input_keep_prob, seed=self._seed)
output, new_state = self._cell(inputs, state)
if (not isinstance(self._output_keep_prob, float) or
self._output_keep_prob < 1):
output = tf.nn.dropout(output, self._output_keep_prob, seed=self._seed)
return output, new_state
class EmbeddingWrapper(RNNCell):
"""Operator adding input embedding to the given cell.
Note: in many cases it may be more efficient to not use this wrapper,
but instead concatenate the whole sequence of your inputs in time,
do the embedding on this batch-concated sequence, then split it and
feed into your RNN.
"""
def __init__(self, cell, embedding_classes=0, embedding=None,
initializer=None):
"""Create a cell with an added input embedding.
Args:
cell: an RNNCell, an embedding will be put before its inputs.
embedding_classes: integer, how many symbols will be embedded.
embedding: Variable, the embedding to use; if None, a new embedding
will be created; if set, then embedding_classes is not required.
initializer: an initializer to use when creating the embedding;
if None, the initializer from variable scope or a default one is used.
Raises:
TypeError: if cell is not an RNNCell.
ValueError: if embedding_classes is not positive.
"""
if not isinstance(cell, RNNCell):
raise TypeError("The parameter cell is not RNNCell.")
if embedding_classes < 1 and embedding is None:
raise ValueError("Pass embedding or embedding_classes must be > 0: %d."
% embedding_classes)
if embedding_classes > 0 and embedding is not None:
if embedding.size[0] != embedding_classes:
raise ValueError("You declared embedding_classes=%d but passed an "
"embedding for %d classes." % (embedding.size[0],
embedding_classes))
if embedding.size[1] != cell.input_size:
raise ValueError("You passed embedding with output size %d and a cell"
" that accepts size %d." % (embedding.size[1],
cell.input_size))
self._cell = cell
self._embedding_classes = embedding_classes
self._embedding = embedding
self._initializer = initializer
@property
def input_size(self):
return 1
@property
def output_size(self):
return self._cell.output_size
@property
def state_size(self):
return self._cell.state_size
def __call__(self, inputs, state, scope=None):
"""Run the cell on embedded inputs."""
with tf.variable_scope(scope or type(self).__name__): # "EmbeddingWrapper"
with tf.device("/cpu:0"):
if self._embedding:
embedding = self._embedding
else:
if self._initializer:
initializer = self._initializer
elif tf.get_variable_scope().initializer:
initializer = tf.get_variable_scope().initializer
else:
# Default initializer for embeddings should have variance=1.
sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1.
initializer = tf.random_uniform_initializer(-sqrt3, sqrt3)
embedding = tf.get_variable("embedding", [self._embedding_classes,
self._cell.input_size],
initializer=initializer)
embedded = tf.nn.embedding_lookup(embedding, tf.reshape(inputs, [-1]))
return self._cell(embedded, state)
class MultiRNNCell(RNNCell):
"""RNN cell composed sequentially of multiple simple cells."""
def __init__(self, cells):
"""Create a RNN cell composed sequentially of a number of RNNCells.
Args:
cells: list of RNNCells that will be composed in this order.
Raises:
ValueError: if cells is empty (not allowed) or if their sizes don't match.
"""
if not cells:
raise ValueError("Must specify at least one cell for MultiRNNCell.")
for i in xrange(len(cells) - 1):
if cells[i + 1].input_size != cells[i].output_size:
raise ValueError("In MultiRNNCell, the input size of each next"
" cell must match the output size of the previous one."
" Mismatched output size in cell %d." % i)
self._cells = cells
@property
def input_size(self):
return self._cells[0].input_size
@property
def output_size(self):
return self._cells[-1].output_size
@property
def state_size(self):
return sum([cell.state_size for cell in self._cells])
def __call__(self, inputs, state, scope=None):
"""Run this multi-layer cell on inputs, starting from state."""
with tf.variable_scope(scope or type(self).__name__): # "MultiRNNCell"
cur_state_pos = 0
cur_inp = inputs
new_states = []
for i, cell in enumerate(self._cells):
with tf.variable_scope("Cell%d" % i):
cur_state = tf.slice(state, [0, cur_state_pos], [-1, cell.state_size])
cur_state_pos += cell.state_size
cur_inp, new_state = cell(cur_inp, cur_state)
new_states.append(new_state)
return cur_inp, tf.concat(1, new_states)
# pylint: disable=g-bad-import-order,wildcard-import,unused-import
import tensorflow.python.platform
from tensorflow.python.ops.rnn_cell import *

View File

@ -12,757 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import seq2seq python ops for backward compatibility."""
"""Library for creating sequence-to-sequence models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=g-bad-import-order,wildcard-import,unused-import
import tensorflow.python.platform
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.models.rnn import linear
from tensorflow.models.rnn import rnn
from tensorflow.models.rnn import rnn_cell
def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None,
scope=None):
"""RNN decoder for the sequence-to-sequence model.
Args:
decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
initial_state: 2D Tensor with shape [batch_size x cell.state_size].
cell: rnn_cell.RNNCell defining the cell function and size.
loop_function: if not None, this function will be applied to i-th output
in order to generate i+1-th input, and decoder_inputs will be ignored,
except for the first element ("GO" symbol). This can be used for decoding,
but also for training to emulate http://arxiv.org/pdf/1506.03099v2.pdf.
Signature -- loop_function(prev, i) = next
* prev is a 2D Tensor of shape [batch_size x cell.output_size],
* i is an integer, the step number (when advanced control is needed),
* next is a 2D Tensor of shape [batch_size x cell.input_size].
scope: VariableScope for the created subgraph; defaults to "rnn_decoder".
Returns:
outputs: A list of the same length as decoder_inputs of 2D Tensors with
shape [batch_size x cell.output_size] containing generated outputs.
states: The state of each cell in each time-step. This is a list with
length len(decoder_inputs) -- one item for each time-step.
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
(Note that in some cases, like basic RNN cell or GRU cell, outputs and
states can be the same. They are different for LSTM cells though.)
"""
with tf.variable_scope(scope or "rnn_decoder"):
states = [initial_state]
outputs = []
prev = None
for i in xrange(len(decoder_inputs)):
inp = decoder_inputs[i]
if loop_function is not None and prev is not None:
with tf.variable_scope("loop_function", reuse=True):
# We do not propagate gradients over the loop function.
inp = tf.stop_gradient(loop_function(prev, i))
if i > 0:
tf.get_variable_scope().reuse_variables()
output, new_state = cell(inp, states[-1])
outputs.append(output)
states.append(new_state)
if loop_function is not None:
prev = tf.stop_gradient(output)
return outputs, states
def basic_rnn_seq2seq(
encoder_inputs, decoder_inputs, cell, dtype=tf.float32, scope=None):
"""Basic RNN sequence-to-sequence model.
This model first runs an RNN to encode encoder_inputs into a state vector, and
then runs decoder, initialized with the last encoder state, on decoder_inputs.
Encoder and decoder use the same RNN cell type, but don't share parameters.
Args:
encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
cell: rnn_cell.RNNCell defining the cell function and size.
dtype: The dtype of the initial state of the RNN cell (default: tf.float32).
scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq".
Returns:
outputs: A list of the same length as decoder_inputs of 2D Tensors with
shape [batch_size x cell.output_size] containing the generated outputs.
states: The state of each decoder cell in each time-step. This is a list
with length len(decoder_inputs) -- one item for each time-step.
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
"""
with tf.variable_scope(scope or "basic_rnn_seq2seq"):
_, enc_states = rnn.rnn(cell, encoder_inputs, dtype=dtype)
return rnn_decoder(decoder_inputs, enc_states[-1], cell)
def tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
loop_function=None, dtype=tf.float32, scope=None):
"""RNN sequence-to-sequence model with tied encoder and decoder parameters.
This model first runs an RNN to encode encoder_inputs into a state vector, and
then runs decoder, initialized with the last encoder state, on decoder_inputs.
Encoder and decoder use the same RNN cell and share parameters.
Args:
encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
cell: rnn_cell.RNNCell defining the cell function and size.
loop_function: if not None, this function will be applied to i-th output
in order to generate i+1-th input, and decoder_inputs will be ignored,
except for the first element ("GO" symbol), see rnn_decoder for details.
dtype: The dtype of the initial state of the rnn cell (default: tf.float32).
scope: VariableScope for the created subgraph; default: "tied_rnn_seq2seq".
Returns:
outputs: A list of the same length as decoder_inputs of 2D Tensors with
shape [batch_size x cell.output_size] containing the generated outputs.
states: The state of each decoder cell in each time-step. This is a list
with length len(decoder_inputs) -- one item for each time-step.
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
"""
with tf.variable_scope("combined_tied_rnn_seq2seq"):
scope = scope or "tied_rnn_seq2seq"
_, enc_states = rnn.rnn(
cell, encoder_inputs, dtype=dtype, scope=scope)
tf.get_variable_scope().reuse_variables()
return rnn_decoder(decoder_inputs, enc_states[-1], cell,
loop_function=loop_function, scope=scope)
def embedding_rnn_decoder(decoder_inputs, initial_state, cell, num_symbols,
output_projection=None, feed_previous=False,
scope=None):
"""RNN decoder with embedding and a pure-decoding option.
Args:
decoder_inputs: a list of 1D batch-sized int32-Tensors (decoder inputs).
initial_state: 2D Tensor [batch_size x cell.state_size].
cell: rnn_cell.RNNCell defining the cell function.
num_symbols: integer, how many symbols come into the embedding.
output_projection: None or a pair (W, B) of output projection weights and
biases; W has shape [cell.output_size x num_symbols] and B has
shape [num_symbols]; if provided and feed_previous=True, each fed
previous output will first be multiplied by W and added B.
feed_previous: Boolean; if True, only the first of decoder_inputs will be
used (the "GO" symbol), and all other decoder inputs will be generated by:
next = embedding_lookup(embedding, argmax(previous_output)),
In effect, this implements a greedy decoder. It can also be used
during training to emulate http://arxiv.org/pdf/1506.03099v2.pdf.
If False, decoder_inputs are used as given (the standard decoder case).
scope: VariableScope for the created subgraph; defaults to
"embedding_rnn_decoder".
Returns:
outputs: A list of the same length as decoder_inputs of 2D Tensors with
shape [batch_size x cell.output_size] containing the generated outputs.
states: The state of each decoder cell in each time-step. This is a list
with length len(decoder_inputs) -- one item for each time-step.
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
Raises:
ValueError: when output_projection has the wrong shape.
"""
if output_projection is not None:
proj_weights = tf.convert_to_tensor(output_projection[0], dtype=tf.float32)
proj_weights.get_shape().assert_is_compatible_with([cell.output_size,
num_symbols])
proj_biases = tf.convert_to_tensor(output_projection[1], dtype=tf.float32)
proj_biases.get_shape().assert_is_compatible_with([num_symbols])
with tf.variable_scope(scope or "embedding_rnn_decoder"):
with tf.device("/cpu:0"):
embedding = tf.get_variable("embedding", [num_symbols, cell.input_size])
def extract_argmax_and_embed(prev, _):
"""Loop_function that extracts the symbol from prev and embeds it."""
if output_projection is not None:
prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1])
prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
return tf.nn.embedding_lookup(embedding, prev_symbol)
loop_function = None
if feed_previous:
loop_function = extract_argmax_and_embed
emb_inp = [tf.nn.embedding_lookup(embedding, i) for i in decoder_inputs]
return rnn_decoder(emb_inp, initial_state, cell,
loop_function=loop_function)
def embedding_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
num_encoder_symbols, num_decoder_symbols,
output_projection=None, feed_previous=False,
dtype=tf.float32, scope=None):
"""Embedding RNN sequence-to-sequence model.
This model first embeds encoder_inputs by a newly created embedding (of shape
[num_encoder_symbols x cell.input_size]). Then it runs an RNN to encode
embedded encoder_inputs into a state vector. Next, it embeds decoder_inputs
by another newly created embedding (of shape [num_decoder_symbols x
cell.input_size]). Then it runs RNN decoder, initialized with the last
encoder state, on embedded decoder_inputs.
Args:
encoder_inputs: a list of 1D int32-Tensors of shape [batch_size].
decoder_inputs: a list of 1D int32-Tensors of shape [batch_size].
cell: rnn_cell.RNNCell defining the cell function and size.
num_encoder_symbols: integer; number of symbols on the encoder side.
num_decoder_symbols: integer; number of symbols on the decoder side.
output_projection: None or a pair (W, B) of output projection weights and
biases; W has shape [cell.output_size x num_decoder_symbols] and B has
shape [num_decoder_symbols]; if provided and feed_previous=True, each
fed previous output will first be multiplied by W and added B.
feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
of decoder_inputs will be used (the "GO" symbol), and all other decoder
inputs will be taken from previous outputs (as in embedding_rnn_decoder).
If False, decoder_inputs are used as given (the standard decoder case).
dtype: The dtype of the initial state for both the encoder and encoder
rnn cells (default: tf.float32).
scope: VariableScope for the created subgraph; defaults to
"embedding_rnn_seq2seq"
Returns:
outputs: A list of the same length as decoder_inputs of 2D Tensors with
shape [batch_size x num_decoder_symbols] containing the generated outputs.
states: The state of each decoder cell in each time-step. This is a list
with length len(decoder_inputs) -- one item for each time-step.
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
"""
with tf.variable_scope(scope or "embedding_rnn_seq2seq"):
# Encoder.
encoder_cell = rnn_cell.EmbeddingWrapper(cell, num_encoder_symbols)
_, encoder_states = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype)
# Decoder.
if output_projection is None:
cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)
if isinstance(feed_previous, bool):
return embedding_rnn_decoder(decoder_inputs, encoder_states[-1], cell,
num_decoder_symbols, output_projection,
feed_previous)
else: # If feed_previous is a Tensor, we construct 2 graphs and use cond.
outputs1, states1 = embedding_rnn_decoder(
decoder_inputs, encoder_states[-1], cell, num_decoder_symbols,
output_projection, True)
tf.get_variable_scope().reuse_variables()
outputs2, states2 = embedding_rnn_decoder(
decoder_inputs, encoder_states[-1], cell, num_decoder_symbols,
output_projection, False)
outputs = tf.control_flow_ops.cond(feed_previous,
lambda: outputs1, lambda: outputs2)
states = tf.control_flow_ops.cond(feed_previous,
lambda: states1, lambda: states2)
return outputs, states
def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
num_symbols, output_projection=None,
feed_previous=False, dtype=tf.float32,
scope=None):
"""Embedding RNN sequence-to-sequence model with tied (shared) parameters.
This model first embeds encoder_inputs by a newly created embedding (of shape
[num_symbols x cell.input_size]). Then it runs an RNN to encode embedded
encoder_inputs into a state vector. Next, it embeds decoder_inputs using
the same embedding. Then it runs RNN decoder, initialized with the last
encoder state, on embedded decoder_inputs.
Args:
encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
cell: rnn_cell.RNNCell defining the cell function and size.
num_symbols: integer; number of symbols for both encoder and decoder.
output_projection: None or a pair (W, B) of output projection weights and
biases; W has shape [cell.output_size x num_symbols] and B has
shape [num_symbols]; if provided and feed_previous=True, each
fed previous output will first be multiplied by W and added B.
feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
of decoder_inputs will be used (the "GO" symbol), and all other decoder
inputs will be taken from previous outputs (as in embedding_rnn_decoder).
If False, decoder_inputs are used as given (the standard decoder case).
dtype: The dtype to use for the initial RNN states (default: tf.float32).
scope: VariableScope for the created subgraph; defaults to
"embedding_tied_rnn_seq2seq".
Returns:
outputs: A list of the same length as decoder_inputs of 2D Tensors with
shape [batch_size x num_decoder_symbols] containing the generated outputs.
states: The state of each decoder cell in each time-step. This is a list
with length len(decoder_inputs) -- one item for each time-step.
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
Raises:
ValueError: when output_projection has the wrong shape.
"""
if output_projection is not None:
proj_weights = tf.convert_to_tensor(output_projection[0], dtype=dtype)
proj_weights.get_shape().assert_is_compatible_with([cell.output_size,
num_symbols])
proj_biases = tf.convert_to_tensor(output_projection[1], dtype=dtype)
proj_biases.get_shape().assert_is_compatible_with([num_symbols])
with tf.variable_scope(scope or "embedding_tied_rnn_seq2seq"):
with tf.device("/cpu:0"):
embedding = tf.get_variable("embedding", [num_symbols, cell.input_size])
emb_encoder_inputs = [tf.nn.embedding_lookup(embedding, x)
for x in encoder_inputs]
emb_decoder_inputs = [tf.nn.embedding_lookup(embedding, x)
for x in decoder_inputs]
def extract_argmax_and_embed(prev, _):
"""Loop_function that extracts the symbol from prev and embeds it."""
if output_projection is not None:
prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1])
prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
return tf.nn.embedding_lookup(embedding, prev_symbol)
if output_projection is None:
cell = rnn_cell.OutputProjectionWrapper(cell, num_symbols)
if isinstance(feed_previous, bool):
loop_function = extract_argmax_and_embed if feed_previous else None
return tied_rnn_seq2seq(emb_encoder_inputs, emb_decoder_inputs, cell,
loop_function=loop_function, dtype=dtype)
else: # If feed_previous is a Tensor, we construct 2 graphs and use cond.
outputs1, states1 = tied_rnn_seq2seq(
emb_encoder_inputs, emb_decoder_inputs, cell,
loop_function=extract_argmax_and_embed, dtype=dtype)
tf.get_variable_scope().reuse_variables()
outputs2, states2 = tied_rnn_seq2seq(
emb_encoder_inputs, emb_decoder_inputs, cell, dtype=dtype)
outputs = tf.control_flow_ops.cond(feed_previous,
lambda: outputs1, lambda: outputs2)
states = tf.control_flow_ops.cond(feed_previous,
lambda: states1, lambda: states2)
return outputs, states
def attention_decoder(decoder_inputs, initial_state, attention_states, cell,
output_size=None, num_heads=1, loop_function=None,
dtype=tf.float32, scope=None):
"""RNN decoder with attention for the sequence-to-sequence model.
Args:
decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
initial_state: 2D Tensor [batch_size x cell.state_size].
attention_states: 3D Tensor [batch_size x attn_length x attn_size].
cell: rnn_cell.RNNCell defining the cell function and size.
output_size: size of the output vectors; if None, we use cell.output_size.
num_heads: number of attention heads that read from attention_states.
loop_function: if not None, this function will be applied to i-th output
in order to generate i+1-th input, and decoder_inputs will be ignored,
except for the first element ("GO" symbol). This can be used for decoding,
but also for training to emulate http://arxiv.org/pdf/1506.03099v2.pdf.
Signature -- loop_function(prev, i) = next
* prev is a 2D Tensor of shape [batch_size x cell.output_size],
* i is an integer, the step number (when advanced control is needed),
* next is a 2D Tensor of shape [batch_size x cell.input_size].
dtype: The dtype to use for the RNN initial state (default: tf.float32).
scope: VariableScope for the created subgraph; default: "attention_decoder".
Returns:
outputs: A list of the same length as decoder_inputs of 2D Tensors of shape
[batch_size x output_size]. These represent the generated outputs.
Output i is computed from input i (which is either i-th decoder_inputs or
loop_function(output {i-1}, i)) as follows. First, we run the cell
on a combination of the input and previous attention masks:
cell_output, new_state = cell(linear(input, prev_attn), prev_state).
Then, we calculate new attention masks:
new_attn = softmax(V^T * tanh(W * attention_states + U * new_state))
and then we calculate the output:
output = linear(cell_output, new_attn).
states: The state of each decoder cell in each time-step. This is a list
with length len(decoder_inputs) -- one item for each time-step.
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
Raises:
ValueError: when num_heads is not positive, there are no inputs, or shapes
of attention_states are not set.
"""
if not decoder_inputs:
raise ValueError("Must provide at least 1 input to attention decoder.")
if num_heads < 1:
raise ValueError("With less than 1 heads, use a non-attention decoder.")
if not attention_states.get_shape()[1:2].is_fully_defined():
raise ValueError("Shape[1] and [2] of attention_states must be known: %s"
% attention_states.get_shape())
if output_size is None:
output_size = cell.output_size
with tf.variable_scope(scope or "attention_decoder"):
batch_size = tf.shape(decoder_inputs[0])[0] # Needed for reshaping.
attn_length = attention_states.get_shape()[1].value
attn_size = attention_states.get_shape()[2].value
# To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before.
hidden = tf.reshape(attention_states, [-1, attn_length, 1, attn_size])
hidden_features = []
v = []
attention_vec_size = attn_size # Size of query vectors for attention.
for a in xrange(num_heads):
k = tf.get_variable("AttnW_%d" % a, [1, 1, attn_size, attention_vec_size])
hidden_features.append(tf.nn.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
v.append(tf.get_variable("AttnV_%d" % a, [attention_vec_size]))
states = [initial_state]
def attention(query):
"""Put attention masks on hidden using hidden_features and query."""
ds = [] # Results of attention reads will be stored here.
for a in xrange(num_heads):
with tf.variable_scope("Attention_%d" % a):
y = linear.linear(query, attention_vec_size, True)
y = tf.reshape(y, [-1, 1, 1, attention_vec_size])
# Attention mask is a softmax of v^T * tanh(...).
s = tf.reduce_sum(v[a] * tf.tanh(hidden_features[a] + y), [2, 3])
a = tf.nn.softmax(s)
# Now calculate the attention-weighted vector d.
d = tf.reduce_sum(tf.reshape(a, [-1, attn_length, 1, 1]) * hidden,
[1, 2])
ds.append(tf.reshape(d, [-1, attn_size]))
return ds
outputs = []
prev = None
batch_attn_size = tf.pack([batch_size, attn_size])
attns = [tf.zeros(batch_attn_size, dtype=dtype)
for _ in xrange(num_heads)]
for a in attns: # Ensure the second shape of attention vectors is set.
a.set_shape([None, attn_size])
for i in xrange(len(decoder_inputs)):
if i > 0:
tf.get_variable_scope().reuse_variables()
inp = decoder_inputs[i]
# If loop_function is set, we use it instead of decoder_inputs.
if loop_function is not None and prev is not None:
with tf.variable_scope("loop_function", reuse=True):
inp = tf.stop_gradient(loop_function(prev, i))
# Merge input and previous attentions into one vector of the right size.
x = linear.linear([inp] + attns, cell.input_size, True)
# Run the RNN.
cell_output, new_state = cell(x, states[-1])
states.append(new_state)
# Run the attention mechanism.
attns = attention(new_state)
with tf.variable_scope("AttnOutputProjection"):
output = linear.linear([cell_output] + attns, output_size, True)
if loop_function is not None:
# We do not propagate gradients over the loop function.
prev = tf.stop_gradient(output)
outputs.append(output)
return outputs, states
def embedding_attention_decoder(decoder_inputs, initial_state, attention_states,
cell, num_symbols, num_heads=1,
output_size=None, output_projection=None,
feed_previous=False, dtype=tf.float32,
scope=None):
"""RNN decoder with embedding and attention and a pure-decoding option.
Args:
decoder_inputs: a list of 1D batch-sized int32-Tensors (decoder inputs).
initial_state: 2D Tensor [batch_size x cell.state_size].
attention_states: 3D Tensor [batch_size x attn_length x attn_size].
cell: rnn_cell.RNNCell defining the cell function.
num_symbols: integer, how many symbols come into the embedding.
num_heads: number of attention heads that read from attention_states.
output_size: size of the output vectors; if None, use cell.output_size.
output_projection: None or a pair (W, B) of output projection weights and
biases; W has shape [output_size x num_symbols] and B has shape
[num_symbols]; if provided and feed_previous=True, each fed previous
output will first be multiplied by W and added B.
feed_previous: Boolean; if True, only the first of decoder_inputs will be
used (the "GO" symbol), and all other decoder inputs will be generated by:
next = embedding_lookup(embedding, argmax(previous_output)),
In effect, this implements a greedy decoder. It can also be used
during training to emulate http://arxiv.org/pdf/1506.03099v2.pdf.
If False, decoder_inputs are used as given (the standard decoder case).
dtype: The dtype to use for the RNN initial states (default: tf.float32).
scope: VariableScope for the created subgraph; defaults to
"embedding_attention_decoder".
Returns:
outputs: A list of the same length as decoder_inputs of 2D Tensors with
shape [batch_size x output_size] containing the generated outputs.
states: The state of each decoder cell in each time-step. This is a list
with length len(decoder_inputs) -- one item for each time-step.
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
Raises:
ValueError: when output_projection has the wrong shape.
"""
if output_size is None:
output_size = cell.output_size
if output_projection is not None:
proj_weights = tf.convert_to_tensor(output_projection[0], dtype=dtype)
proj_weights.get_shape().assert_is_compatible_with([cell.output_size,
num_symbols])
proj_biases = tf.convert_to_tensor(output_projection[1], dtype=dtype)
proj_biases.get_shape().assert_is_compatible_with([num_symbols])
with tf.variable_scope(scope or "embedding_attention_decoder"):
with tf.device("/cpu:0"):
embedding = tf.get_variable("embedding", [num_symbols, cell.input_size])
def extract_argmax_and_embed(prev, _):
"""Loop_function that extracts the symbol from prev and embeds it."""
if output_projection is not None:
prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1])
prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
emb_prev = tf.nn.embedding_lookup(embedding, prev_symbol)
return emb_prev
loop_function = None
if feed_previous:
loop_function = extract_argmax_and_embed
emb_inp = [tf.nn.embedding_lookup(embedding, i) for i in decoder_inputs]
return attention_decoder(
emb_inp, initial_state, attention_states, cell, output_size=output_size,
num_heads=num_heads, loop_function=loop_function)
def embedding_attention_seq2seq(encoder_inputs, decoder_inputs, cell,
num_encoder_symbols, num_decoder_symbols,
num_heads=1, output_projection=None,
feed_previous=False, dtype=tf.float32,
scope=None):
"""Embedding sequence-to-sequence model with attention.
This model first embeds encoder_inputs by a newly created embedding (of shape
[num_encoder_symbols x cell.input_size]). Then it runs an RNN to encode
embedded encoder_inputs into a state vector. It keeps the outputs of this
RNN at every step to use for attention later. Next, it embeds decoder_inputs
by another newly created embedding (of shape [num_decoder_symbols x
cell.input_size]). Then it runs attention decoder, initialized with the last
encoder state, on embedded decoder_inputs and attending to encoder outputs.
Args:
encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size].
cell: rnn_cell.RNNCell defining the cell function and size.
num_encoder_symbols: integer; number of symbols on the encoder side.
num_decoder_symbols: integer; number of symbols on the decoder side.
num_heads: number of attention heads that read from attention_states.
output_projection: None or a pair (W, B) of output projection weights and
biases; W has shape [cell.output_size x num_decoder_symbols] and B has
shape [num_decoder_symbols]; if provided and feed_previous=True, each
fed previous output will first be multiplied by W and added B.
feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
of decoder_inputs will be used (the "GO" symbol), and all other decoder
inputs will be taken from previous outputs (as in embedding_rnn_decoder).
If False, decoder_inputs are used as given (the standard decoder case).
dtype: The dtype of the initial RNN state (default: tf.float32).
scope: VariableScope for the created subgraph; defaults to
"embedding_attention_seq2seq".
Returns:
outputs: A list of the same length as decoder_inputs of 2D Tensors with
shape [batch_size x num_decoder_symbols] containing the generated outputs.
states: The state of each decoder cell in each time-step. This is a list
with length len(decoder_inputs) -- one item for each time-step.
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
"""
with tf.variable_scope(scope or "embedding_attention_seq2seq"):
# Encoder.
encoder_cell = rnn_cell.EmbeddingWrapper(cell, num_encoder_symbols)
encoder_outputs, encoder_states = rnn.rnn(
encoder_cell, encoder_inputs, dtype=dtype)
# First calculate a concatenation of encoder outputs to put attention on.
top_states = [tf.reshape(e, [-1, 1, cell.output_size])
for e in encoder_outputs]
attention_states = tf.concat(1, top_states)
# Decoder.
output_size = None
if output_projection is None:
cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)
output_size = num_decoder_symbols
if isinstance(feed_previous, bool):
return embedding_attention_decoder(
decoder_inputs, encoder_states[-1], attention_states, cell,
num_decoder_symbols, num_heads, output_size, output_projection,
feed_previous)
else: # If feed_previous is a Tensor, we construct 2 graphs and use cond.
outputs1, states1 = embedding_attention_decoder(
decoder_inputs, encoder_states[-1], attention_states, cell,
num_decoder_symbols, num_heads, output_size, output_projection, True)
tf.get_variable_scope().reuse_variables()
outputs2, states2 = embedding_attention_decoder(
decoder_inputs, encoder_states[-1], attention_states, cell,
num_decoder_symbols, num_heads, output_size, output_projection, False)
outputs = tf.control_flow_ops.cond(feed_previous,
lambda: outputs1, lambda: outputs2)
states = tf.control_flow_ops.cond(feed_previous,
lambda: states1, lambda: states2)
return outputs, states
def sequence_loss_by_example(logits, targets, weights, num_decoder_symbols,
average_across_timesteps=True,
softmax_loss_function=None, name=None):
"""Weighted cross-entropy loss for a sequence of logits (per example).
Args:
logits: list of 2D Tensors of shape [batch_size x num_decoder_symbols].
targets: list of 1D batch-sized int32-Tensors of the same length as logits.
weights: list of 1D batch-sized float-Tensors of the same length as logits.
num_decoder_symbols: integer, number of decoder symbols (output classes).
average_across_timesteps: If set, divide the returned cost by the total
label weight.
softmax_loss_function: function (inputs-batch, labels-batch) -> loss-batch
to be used instead of the standard softmax (the default if this is None).
name: optional name for this operation, default: "sequence_loss_by_example".
Returns:
1D batch-sized float Tensor: the log-perplexity for each sequence.
Raises:
ValueError: if len(logits) is different from len(targets) or len(weights).
"""
if len(targets) != len(logits) or len(weights) != len(logits):
raise ValueError("Lengths of logits, weights, and targets must be the same "
"%d, %d, %d." % (len(logits), len(weights), len(targets)))
with tf.op_scope(logits + targets + weights, name,
"sequence_loss_by_example"):
batch_size = tf.shape(targets[0])[0]
log_perp_list = []
length = batch_size * num_decoder_symbols
for i in xrange(len(logits)):
if softmax_loss_function is None:
# TODO(lukaszkaiser): There is no SparseCrossEntropy in TensorFlow, so
# we need to first cast targets into a dense representation, and as
# SparseToDense does not accept batched inputs, we need to do this by
# re-indexing and re-sizing. When TensorFlow adds SparseCrossEntropy,
# rewrite this method.
indices = targets[i] + num_decoder_symbols * tf.range(batch_size)
with tf.device("/cpu:0"): # Sparse-to-dense must happen on CPU for now.
dense = tf.sparse_to_dense(indices, tf.expand_dims(length, 0), 1.0,
0.0)
target = tf.reshape(dense, [-1, num_decoder_symbols])
crossent = tf.nn.softmax_cross_entropy_with_logits(
logits[i], target, name="SequenceLoss/CrossEntropy{0}".format(i))
else:
crossent = softmax_loss_function(logits[i], targets[i])
log_perp_list.append(crossent * weights[i])
log_perps = tf.add_n(log_perp_list)
if average_across_timesteps:
total_size = tf.add_n(weights)
total_size += 1e-12 # Just to avoid division by 0 for all-0 weights.
log_perps /= total_size
return log_perps
def sequence_loss(logits, targets, weights, num_decoder_symbols,
average_across_timesteps=True, average_across_batch=True,
softmax_loss_function=None, name=None):
"""Weighted cross-entropy loss for a sequence of logits, batch-collapsed.
Args:
logits: list of 2D Tensors os shape [batch_size x num_decoder_symbols].
targets: list of 1D batch-sized int32-Tensors of the same length as logits.
weights: list of 1D batch-sized float-Tensors of the same length as logits.
num_decoder_symbols: integer, number of decoder symbols (output classes).
average_across_timesteps: If set, divide the returned cost by the total
label weight.
average_across_batch: If set, divide the returned cost by the batch size.
softmax_loss_function: function (inputs-batch, labels-batch) -> loss-batch
to be used instead of the standard softmax (the default if this is None).
name: optional name for this operation, defaults to "sequence_loss".
Returns:
A scalar float Tensor: the average log-perplexity per symbol (weighted).
Raises:
ValueError: if len(logits) is different from len(targets) or len(weights).
"""
with tf.op_scope(logits + targets + weights, name, "sequence_loss"):
cost = tf.reduce_sum(sequence_loss_by_example(
logits, targets, weights, num_decoder_symbols,
average_across_timesteps=average_across_timesteps,
softmax_loss_function=softmax_loss_function))
if average_across_batch:
batch_size = tf.shape(targets[0])[0]
return cost / tf.cast(batch_size, tf.float32)
else:
return cost
def model_with_buckets(encoder_inputs, decoder_inputs, targets, weights,
buckets, num_decoder_symbols, seq2seq,
softmax_loss_function=None, name=None):
"""Create a sequence-to-sequence model with support for bucketing.
The seq2seq argument is a function that defines a sequence-to-sequence model,
e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(x, y, rnn_cell.GRUCell(24))
Args:
encoder_inputs: a list of Tensors to feed the encoder; first seq2seq input.
decoder_inputs: a list of Tensors to feed the decoder; second seq2seq input.
targets: a list of 1D batch-sized int32-Tensors (desired output sequence).
weights: list of 1D batch-sized float-Tensors to weight the targets.
buckets: a list of pairs of (input size, output size) for each bucket.
num_decoder_symbols: integer, number of decoder symbols (output classes).
seq2seq: a sequence-to-sequence model function; it takes 2 input that
agree with encoder_inputs and decoder_inputs, and returns a pair
consisting of outputs and states (as, e.g., basic_rnn_seq2seq).
softmax_loss_function: function (inputs-batch, labels-batch) -> loss-batch
to be used instead of the standard softmax (the default if this is None).
name: optional name for this operation, defaults to "model_with_buckets".
Returns:
outputs: The outputs for each bucket. Its j'th element consists of a list
of 2D Tensors of shape [batch_size x num_decoder_symbols] (j'th outputs).
losses: List of scalar Tensors, representing losses for each bucket.
Raises:
ValueError: if length of encoder_inputsut, targets, or weights is smaller
than the largest (last) bucket.
"""
if len(encoder_inputs) < buckets[-1][0]:
raise ValueError("Length of encoder_inputs (%d) must be at least that of la"
"st bucket (%d)." % (len(encoder_inputs), buckets[-1][0]))
if len(targets) < buckets[-1][1]:
raise ValueError("Length of targets (%d) must be at least that of last"
"bucket (%d)." % (len(targets), buckets[-1][1]))
if len(weights) < buckets[-1][1]:
raise ValueError("Length of weights (%d) must be at least that of last"
"bucket (%d)." % (len(weights), buckets[-1][1]))
all_inputs = encoder_inputs + decoder_inputs + targets + weights
losses = []
outputs = []
with tf.op_scope(all_inputs, name, "model_with_buckets"):
for j in xrange(len(buckets)):
if j > 0:
tf.get_variable_scope().reuse_variables()
bucket_encoder_inputs = [encoder_inputs[i]
for i in xrange(buckets[j][0])]
bucket_decoder_inputs = [decoder_inputs[i]
for i in xrange(buckets[j][1])]
bucket_outputs, _ = seq2seq(bucket_encoder_inputs,
bucket_decoder_inputs)
outputs.append(bucket_outputs)
bucket_targets = [targets[i] for i in xrange(buckets[j][1])]
bucket_weights = [weights[i] for i in xrange(buckets[j][1])]
losses.append(sequence_loss(
outputs[-1], bucket_targets, bucket_weights, num_decoder_symbols,
softmax_loss_function=softmax_loss_function))
return outputs, losses
from tensorflow.python.ops.seq2seq import *

View File

@ -34,6 +34,7 @@ py_library(
":client_testlib",
":framework",
":framework_test_lib",
":kernel_tests/gradient_checker",
":platform",
":platform_test",
":summary",
@ -467,6 +468,7 @@ tf_gen_op_wrapper_py(
"ReluGrad",
"Relu6Grad",
"SoftplusGrad",
"SoftsignGrad",
"BiasAdd",
"Relu6",
"AvgPool",
@ -588,6 +590,9 @@ py_library(
"ops/op_def_library.py",
"ops/parsing_ops.py",
"ops/random_ops.py",
"ops/rnn.py",
"ops/rnn_cell.py",
"ops/seq2seq.py",
"ops/sparse_grad.py",
"ops/sparse_ops.py",
"ops/standard_ops.py",

View File

@ -93,8 +93,8 @@ def all_libraries(module_to_name, members, documented):
"max_pool_grad", "max_pool_grad_with_argmax",
"batch_norm_with_global_normalization_grad",
"lrn_grad", "relu6_grad", "softplus_grad",
"xw_plus_b", "relu_layer", "lrn",
"batch_norm_with_global_normalization",
"softsign_grad", "xw_plus_b", "relu_layer",
"lrn", "batch_norm_with_global_normalization",
"batch_norm_with_global_normalization_grad",
"all_candidate_sampler",
"embedding_lookup_sparse"],

View File

@ -442,8 +442,8 @@ class Tensor(object):
return _eval_using_default_session(self, feed_dict, self.graph, session)
def _TensorTensorConversionFunction(t, dtype=None, name=None):
_ = name
def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False):
_ = name, as_ref
if dtype and not dtype.is_compatible_with(t.dtype):
raise ValueError(
"Tensor conversion requested dtype %s for Tensor with dtype %s: %r"
@ -455,7 +455,7 @@ _tensor_conversion_func_registry = {
0: [(Tensor, _TensorTensorConversionFunction)]}
def convert_to_tensor(value, dtype=None, name=None):
def convert_to_tensor(value, dtype=None, name=None, as_ref=False):
"""Converts the given `value` to a `Tensor`.
This function converts Python objects of various types to `Tensor`
@ -487,6 +487,7 @@ def convert_to_tensor(value, dtype=None, name=None):
dtype: Optional element type for the returned tensor. If missing, the
type is inferred from the type of `value`.
name: Optional name to use if a new `Tensor` is created.
as_ref: True if we want the result as a ref tensor.
Returns:
A `Tensor` based on `value`.
@ -502,7 +503,7 @@ def convert_to_tensor(value, dtype=None, name=None):
for _, funcs_at_priority in sorted(_tensor_conversion_func_registry.items()):
for base_type, conversion_func in funcs_at_priority:
if isinstance(value, base_type):
ret = conversion_func(value, dtype=dtype, name=name)
ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
if not isinstance(ret, Tensor):
raise RuntimeError(
"%sConversion function %r for type %s returned non-Tensor: %r"
@ -519,7 +520,8 @@ def convert_to_tensor(value, dtype=None, name=None):
% (error_prefix, value, type(value)))
def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None,
as_ref=False):
"""Converts the given object to a `Tensor` or an `IndexedSlices`.
If `value` is an `IndexedSlices` it is returned
@ -532,6 +534,7 @@ def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
dtype: (Optional.) The required `DType` of the returned `Tensor` or
`IndexedSlices`.
name: (Optional.) A name to use if a new `Tensor` is created.
as_ref: True if the caller wants the results as ref tensors.
Returns:
An `Tensor` or an `IndexedSlices` based on `value`.
@ -546,10 +549,11 @@ def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
% (dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
return value
else:
return convert_to_tensor(value, dtype, name)
return convert_to_tensor(value, dtype=dtype, name=name, as_ref=as_ref)
def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None):
def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None,
as_ref=False):
"""Converts `values` to a list of `Tensor` or `IndexedSlices` objects.
Args:
@ -557,10 +561,10 @@ def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None):
by `convert_to_tensor()`.
dtype: (Optional.) The required `DType` of the returned `Tensor`
`IndexedSlices`.
name: (Optional.) A name prefix to used when a new `Tensor` is
created, in which case element `i` will be given the name `name
+ '_' + i`.
as_ref: True if the caller wants the results as ref tensors.
Returns:
A list of `Tensor` and/or `IndexedSlices` objects.
@ -580,7 +584,8 @@ def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None):
else:
n = None if name is None else "%s_%d" % (name, i)
ret.append(
convert_to_tensor_or_indexed_slices(value, dtype=dtype, name=n))
convert_to_tensor_or_indexed_slices(value, dtype=dtype, name=n,
as_ref=as_ref))
return ret
@ -590,13 +595,16 @@ def register_tensor_conversion_function(base_type, conversion_func,
The conversion function must have the following signature:
def conversion_func(value, dtype=None, name=None):
def conversion_func(value, dtype=None, name=None, as_ref=False):
# ...
It must return a Tensor with the given dtype if specified. If the
conversion function creates a new Tensor, it should use the given
name if specified. All exceptions will be propagated to the caller.
If `as_ref` is true, the function must return a Tensor reference,
such as a VariableOp.
NOTE: The conversion functions will execute in order of priority,
followed by order of registration. To ensure that a conversion
function F runs before another conversion function G, ensure that
@ -762,23 +770,23 @@ class SparseTensor(object):
```
By convention, `indices` should be sorted in row-major order (or equivalently
lexigraphic order on the tuples `indices[i]`). This is not enforced when
`SparseTensor` objects are constructed, but most Ops assume correct ordering.
lexicographic order on the tuples `indices[i]`). This is not enforced when
`SparseTensor` objects are constructed, but most ops assume correct ordering.
If the ordering is wrong, it can be fixed by calling `sparse_reorder` on the
misordered `SparseTensor`.
Example: The sparse tensor
```python
SparseTensor(values=[1, 2], indices=[[0, 0], [1, 2]], shape=[3, 4])
SparseTensor(values=[1, 2], indices=[[0, 0], [1, 2]], shape=[3, 4])
```
represents the dense tensor
```python
[[1, 0, 0, 0]
[0, 0, 2, 0]
[0, 0, 0, 0]]
[[1, 0, 0, 0]
[0, 0, 2, 0]
[0, 0, 0, 0]]
```
@@__init__
@ -795,14 +803,18 @@ class SparseTensor(object):
Args:
indices: A 2-D int64 tensor of shape `[N, ndims]`.
values: A 1-D tensor of any type and shape `[N]`.
dense_shape: A 1-D int64 tensor of shape `[ndims]`.
shape: A 1-D int64 tensor of shape `[ndims]`.
Returns:
A `SparseTensor`
"""
with op_scope([indices, values, shape], None, "SparseTensor"):
indices = convert_to_tensor(indices, name="indices")
values = convert_to_tensor(values, name="values")
# Always pass as_ref=True because we want to be able to update
# values later if it is a VariableOp.
# TODO(touts): Consider adding mutable_values() when 'values'
# is a VariableOp and updating users of SparseTensor.
values = convert_to_tensor(values, name="values", as_ref=True)
shape = convert_to_tensor(shape, name="shape")
self._indices = indices
self._values = values
@ -987,7 +999,9 @@ class Operation(object):
self._graph = g
if inputs is None:
inputs = []
self._inputs = inputs
elif not isinstance(inputs, list):
raise TypeError("inputs needs to be a list of Tensors: %s" % inputs)
self._inputs = list(inputs) # Defensive copy.
for a in self._inputs:
if not isinstance(a, Tensor):
raise TypeError("input needs to be a Tensor: %s" % a)
@ -1391,6 +1405,7 @@ def get_gradient_function(op):
_shape_registry = registry.Registry("shape functions")
_default_shape_function_registry = registry.Registry("default shape functions")
class RegisterShape(object):
"""A decorator for registering the shape function for an op type.
@ -1924,6 +1939,7 @@ class Graph(object):
A list of Operations.
"""
return list(self._nodes_by_id.values())
def get_operation_by_name(self, name):
"""Returns the `Operation` with the given `name`.
@ -2045,7 +2061,7 @@ class Graph(object):
else:
c = []
for item in self._collections.get(name, list()):
if hasattr(item, 'name') and item.name.startswith(scope):
if hasattr(item, "name") and item.name.startswith(scope):
c.append(item)
return c

View File

@ -522,19 +522,21 @@ def ConstantValue(tensor):
elif tensor.op.type == "Shape":
input_shape = tensor.op.inputs[0].get_shape()
if input_shape.is_fully_defined():
return np.array([dim.value for dim in input_shape.dims])
return np.array([dim.value for dim in input_shape.dims],
dtype=tensor.dtype.as_numpy_dtype)
else:
return None
elif tensor.op.type == "Size":
input_shape = tensor.op.inputs[0].get_shape()
if input_shape.is_fully_defined():
return np.array([np.prod([dim.value for dim in input_shape.dims])])
return np.array([np.prod([dim.value for dim in input_shape.dims])],
dtype=tensor.dtype.as_numpy_dtype)
else:
return None
elif tensor.op.type == "Rank":
input_shape = tensor.op.inputs[0].get_shape()
if input_shape.ndims is not None:
return np.array([input_shape.ndims])
return np.array([input_shape.ndims], dtype=tensor.dtype.as_numpy_dtype)
else:
return None
elif tensor.op.type == "Range":

View File

@ -378,19 +378,25 @@ class ConstantValueTest(test_util.TensorFlowTestCase):
self.assertIs(None, tensor_util.ConstantValue(tf_val))
def testShape(self):
np_val = np.array([1, 2, 3])
np_val = np.array([1, 2, 3], dtype=np.int32)
tf_val = array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3]))
self.assertAllEqual(np_val, tensor_util.ConstantValue(tf_val))
c_val = tensor_util.ConstantValue(tf_val)
self.assertAllEqual(np_val, c_val)
self.assertEqual(np.int32, c_val.dtype)
def testSize(self):
np_val = np.array([6])
np_val = np.array([6], dtype=np.int32)
tf_val = array_ops.size(constant_op.constant(0.0, shape=[1, 2, 3]))
self.assertAllEqual(np_val, tensor_util.ConstantValue(tf_val))
c_val = tensor_util.ConstantValue(tf_val)
self.assertAllEqual(np_val, c_val)
self.assertEqual(np.int32, c_val.dtype)
def testRank(self):
np_val = np.array([3])
np_val = np.array([3], dtype=np.int32)
tf_val = array_ops.rank(constant_op.constant(0.0, shape=[1, 2, 3]))
self.assertAllEqual(np_val, tensor_util.ConstantValue(tf_val))
c_val = tensor_util.ConstantValue(tf_val)
self.assertAllEqual(np_val, c_val)
self.assertEqual(np.int32, c_val.dtype)
if __name__ == "__main__":

View File

@ -23,8 +23,6 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker as gc
class BatchMatmulOpTest(tf.test.TestCase):
@ -176,9 +174,14 @@ class BatchMatmulGradientTest(tf.test.TestCase):
z = tf.batch_matmul(inx, iny, adj_x, adj_y)
loss = tf.reduce_sum(z)
epsilon = 1e-2
((x_jacob_t, x_jacob_n), (y_jacob_t, y_jacob_n)) = gc.ComputeGradient(
[inx, iny], [x.shape, y.shape], loss, [1],
x_init_value=[x, y], delta=epsilon)
((x_jacob_t, x_jacob_n),
(y_jacob_t, y_jacob_n)) = tf.test.compute_gradient(
[inx, iny],
[x.shape, y.shape],
loss,
[1],
x_init_value=[x, y],
delta=epsilon)
tf.logging.info("x_jacob_t = %s", x_jacob_t.reshape(x.shape))
tf.logging.info("x_jacob_n = %s", x_jacob_n.reshape(x.shape))

View File

@ -23,8 +23,6 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker
class BiasAddTest(tf.test.TestCase):
@ -82,7 +80,7 @@ class BiasAddTest(tf.test.TestCase):
dtype=tf.float64)
b = tf.constant([1.3, 2.4], dtype=tf.float64)
bo = tf.nn.bias_add(t, b)
err = gradient_checker.ComputeGradientError(t, [3, 2], bo, [3, 2])
err = tf.test.compute_gradient_error(t, [3, 2], bo, [3, 2])
print("bias add tensor gradient err = ", err)
self.assertLess(err, 1e-10)
@ -92,7 +90,7 @@ class BiasAddTest(tf.test.TestCase):
dtype=tf.float64)
b = tf.constant([1.3, 2.4], dtype=tf.float64)
bo = tf.nn.bias_add(t, b)
err = gradient_checker.ComputeGradientError(b, [2], bo, [3, 2])
err = tf.test.compute_gradient_error(b, [2], bo, [3, 2])
print("bias add bias gradient err = ", err)
self.assertLess(err, 1e-10)
@ -103,7 +101,7 @@ class BiasAddTest(tf.test.TestCase):
t = tf.constant(x, shape=s, dtype=tf.float32)
b = tf.constant([1.3, 2.4], dtype=tf.float32)
bo = tf.nn.bias_add(t, b)
err = gradient_checker.ComputeGradientError(t, s, bo, s, x_init_value=x)
err = tf.test.compute_gradient_error(t, s, bo, s, x_init_value=x)
print("bias add tensor gradient err = ", err)
self.assertLess(err, 1e-3)

View File

@ -23,8 +23,6 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker as gc
class CastOpTest(tf.test.TestCase):
@ -160,7 +158,7 @@ class CastOpTest(tf.test.TestCase):
x = tf.constant(1.0, src_t)
z = tf.identity(x)
y = tf.cast(z, dst_t)
err = gc.ComputeGradientError(x, [1], y, [1])
err = tf.test.compute_gradient_error(x, [1], y, [1])
self.assertLess(err, 1e-3)

View File

@ -303,6 +303,63 @@ class ConcatOpTest(tf.test.TestCase):
dxs = sess.run(tf.gradients(c, xs, dc))
self.assertAllEqual(dc, np.concatenate(dxs, axis=axis))
def testTensorConcatDim0Grad(self):
x_shapes = [[20, 7, 3], [10, 7, 3], [14, 7, 3]]
output_shape = [44, 7, 3]
x_vals = [np.random.random_sample(x_shape).astype(
np.float64) for x_shape in x_shapes]
with self.test_session():
xs = [tf.constant(x_val) for x_val in x_vals]
output = tf.concat(0, xs)
err = tf.test.compute_gradient_error(xs, x_shapes, output, output_shape)
self.assertLess(err, 1e-11)
def testTensorConcatDim1Grad(self):
x_shapes = [[20, 7, 3], [20, 3, 3], [20, 1, 3]]
output_shape = [20, 11, 3]
x_vals = [np.random.random_sample(x_shape).astype(
np.float64) for x_shape in x_shapes]
with self.test_session():
xs = [tf.constant(x_val) for x_val in x_vals]
output = tf.concat(1, xs)
err = tf.test.compute_gradient_error(xs, x_shapes, output, output_shape)
self.assertLess(err, 1e-11)
def testIndexedSlicesConcatDim0Grad(self):
x_shapes = [[20, 7, 3], [10, 7, 3], [14, 7, 3]]
output_shape = [4, 7, 3]
x_vals = [np.random.random_sample(x_shape).astype(
np.float64) for x_shape in x_shapes]
with self.test_session():
xs = [tf.constant(x_val) for x_val in x_vals]
x_concat = tf.concat(0, xs)
output = tf.gather(x_concat, [1, 2, 0, 5])
err = tf.test.compute_gradient_error(xs, x_shapes, output, output_shape)
self.assertLess(err, 1e-11)
def testIndexedSlicesConcatDim1Grad(self):
x_shapes = [[20, 7, 3], [20, 3, 3], [20, 1, 3]]
output_shape = [4, 11, 3]
x_vals = [np.random.random_sample(x_shape).astype(
np.float64) for x_shape in x_shapes]
with self.test_session():
xs = [tf.constant(x_val) for x_val in x_vals]
x_concat = tf.concat(1, xs)
output = tf.gather(x_concat, [1, 2, 0, 5])
err = tf.test.compute_gradient_error(xs, x_shapes, output, output_shape)
self.assertLess(err, 1e-11)
def testIndexedSlicesConcatDim2Grad(self):
x_shapes = [[20, 7, 3], [20, 7, 1], [20, 7, 2]]
output_shape = [4, 7, 6]
x_vals = [np.random.random_sample(x_shape).astype(
np.float64) for x_shape in x_shapes]
with self.test_session():
xs = [tf.constant(x_val) for x_val in x_vals]
x_concat = tf.concat(2, xs)
output = tf.gather(x_concat, [1, 2, 0, 5])
err = tf.test.compute_gradient_error(xs, x_shapes, output, output_shape)
self.assertLess(err, 1e-11)
if __name__ == "__main__":
tf.test.main()

View File

@ -1091,9 +1091,10 @@ class ControlFlowTest(tf.test.TestCase):
# Use a control dependency to ensure init_variable is run
# while asking for c
real_v = control_flow_ops.with_dependencies(name="real_tensor",
output_tensor=v,
dependencies=[v.initializer])
real_v = control_flow_ops.with_dependencies(
name="real_tensor",
output_tensor=v.ref(),
dependencies=[v.initializer])
c_val, real_v_val = sess.run([c, real_v])
# Ensure the result of 'real_c' is the same as 'c'
@ -1259,12 +1260,12 @@ class TupleTest(tf.test.TestCase):
with self.test_session():
v1 = tf.Variable([1.0])
add1 = tf.add(
control_flow_ops.with_dependencies([v1.initializer], v1),
control_flow_ops.with_dependencies([v1.initializer], v1.ref()),
2.0)
v2 = tf.Variable([10.0])
add2 = tf.add(control_flow_ops.with_dependencies([v2.initializer],
v2),
20.0)
add2 = tf.add(
control_flow_ops.with_dependencies([v2.initializer], v2.ref()),
20.0)
t1, _, t2 = control_flow_ops.tuple([add1, None, add2])
# v1 is not initialized.
@ -1291,14 +1292,14 @@ class TupleTest(tf.test.TestCase):
np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(
np.float32))
v1_at_1 = tf.IndexedSlices(
control_flow_ops.with_dependencies([v1.initializer], v1),
control_flow_ops.with_dependencies([v1.initializer], v1.ref()),
tf.constant([1]))
v2 = tf.Variable(
np.array([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]]).astype(
np.float32))
v2_at_1 = tf.IndexedSlices(
control_flow_ops.with_dependencies([v2.initializer], v2),
control_flow_ops.with_dependencies([v2.initializer], v2.ref()),
tf.constant([1]))
st1, st2 = control_flow_ops.tuple([v1_at_1, v2_at_1])

View File

@ -23,8 +23,6 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker as gc
def GetInceptionShapes():
"""Iterator for the convolution shapes used in the Inception 2015 model.
@ -429,11 +427,11 @@ class Conv2DTest(tf.test.TestCase):
name="conv")
self.assertEqual(output_shape, conv.get_shape())
if test_input:
err = gc.ComputeGradientError(input_tensor, input_shape,
conv, output_shape)
err = tf.test.compute_gradient_error(input_tensor, input_shape, conv,
output_shape)
else:
err = gc.ComputeGradientError(filter_tensor, filter_shape,
conv, output_shape)
err = tf.test.compute_gradient_error(filter_tensor, filter_shape, conv,
output_shape)
print("conv_2d gradient error = ", err)
self.assertLess(err, tolerance)

View File

@ -24,7 +24,6 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker as gc
_ADD = lambda x, y: x + y
_SUB = lambda x, y: x - y
@ -58,11 +57,19 @@ class UnaryOpTest(tf.test.TestCase):
self.assertAllClose(np_ans, tf_cpu)
if x.dtype == np.float32:
s = list(np.shape(x))
jacob_t, jacob_n = gc.ComputeGradient(inx, s, y, s, x_init_value=x)
jacob_t, jacob_n = tf.test.compute_gradient(inx,
s,
y,
s,
x_init_value=x)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
elif x.dtype == np.float64:
s = list(np.shape(x))
jacob_t, jacob_n = gc.ComputeGradient(inx, s, y, s, x_init_value=x)
jacob_t, jacob_n = tf.test.compute_gradient(inx,
s,
y,
s,
x_init_value=x)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
def _compareGpu(self, x, np_func, tf_func):
@ -216,7 +223,11 @@ class BinaryOpTest(tf.test.TestCase):
iny = tf.convert_to_tensor(y)
out = tf_func(inx, iny)
xs = list(x.shape)
jacob_t, jacob_n = gc.ComputeGradient(inx, xs, out, zs, x_init_value=x)
jacob_t, jacob_n = tf.test.compute_gradient(inx,
xs,
out,
zs,
x_init_value=x)
if x.dtype == np.float32:
self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
elif x.dtype == np.float64:
@ -230,7 +241,11 @@ class BinaryOpTest(tf.test.TestCase):
iny = tf.convert_to_tensor(y)
out = tf_func(inx, iny)
ys = list(np.shape(y))
jacob_t, jacob_n = gc.ComputeGradient(iny, ys, out, zs, x_init_value=y)
jacob_t, jacob_n = tf.test.compute_gradient(iny,
ys,
out,
zs,
x_init_value=y)
if x.dtype == np.float32:
self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
elif x.dtype == np.float64:
@ -833,7 +848,11 @@ class SelectOpTest(tf.test.TestCase):
iny = tf.convert_to_tensor(y)
out = tf.select(c, inx, iny)
s = list(np.shape(c))
jacob_t, jacob_n = gc.ComputeGradient(inx, s, out, s, x_init_value=x)
jacob_t, jacob_n = tf.test.compute_gradient(inx,
s,
out,
s,
x_init_value=x)
if x.dtype == np.float32:
self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
elif x.dtype == np.float64:
@ -845,7 +864,11 @@ class SelectOpTest(tf.test.TestCase):
iny = tf.convert_to_tensor(y)
out = tf.select(c, inx, iny)
s = list(np.shape(c))
jacob_t, jacob_n = gc.ComputeGradient(iny, s, out, s, x_init_value=y)
jacob_t, jacob_n = tf.test.compute_gradient(iny,
s,
out,
s,
x_init_value=y)
if x.dtype == np.float32:
self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
elif x.dtype == np.float64:
@ -923,7 +946,11 @@ class MinMaxOpTest(tf.test.TestCase):
iny = tf.convert_to_tensor(y)
out = func(inx, iny)
s = list(np.shape(x))
jacob_t, jacob_n = gc.ComputeGradient(inx, s, out, s, x_init_value=x)
jacob_t, jacob_n = tf.test.compute_gradient(inx,
s,
out,
s,
x_init_value=x)
if x.dtype == np.float32:
self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
elif x.dtype == np.float64:
@ -935,7 +962,11 @@ class MinMaxOpTest(tf.test.TestCase):
iny = tf.convert_to_tensor(y)
out = func(inx, iny)
s = list(np.shape(x))
jacob_t, jacob_n = gc.ComputeGradient(iny, s, out, s, x_init_value=y)
jacob_t, jacob_n = tf.test.compute_gradient(iny,
s,
out,
s,
x_init_value=y)
if x.dtype == np.float32:
self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
elif x.dtype == np.float64:
@ -1159,8 +1190,12 @@ class ComplexMakeRealImagTest(tf.test.TestCase):
tf.square(tf.real(cplx))) + tf.reduce_sum(
tf.square(tf.imag(cplx)))
epsilon = 1e-3
jacob_t, jacob_n = gc.ComputeGradient(inx, list(x.shape), loss, [1],
x_init_value=x, delta=epsilon)
jacob_t, jacob_n = tf.test.compute_gradient(inx,
list(x.shape),
loss,
[1],
x_init_value=x,
delta=epsilon)
self.assertAllClose(jacob_t, jacob_n, rtol=epsilon, atol=epsilon)
def testGradient(self):
@ -1187,8 +1222,12 @@ class ComplexMakeRealImagTest(tf.test.TestCase):
# Defines the loss function as the sum of all coefficients of z.
loss = tf.reduce_sum(tf.real(z) + tf.imag(z))
epsilon = 0.005
jacob_t, jacob_n = gc.ComputeGradient(inp, list(data.shape), loss, [1],
x_init_value=data, delta=epsilon)
jacob_t, jacob_n = tf.test.compute_gradient(inp,
list(data.shape),
loss,
[1],
x_init_value=data,
delta=epsilon)
self.assertAllClose(jacob_t, jacob_n, rtol=epsilon, atol=epsilon)
def testMulGradient(self):

View File

@ -26,8 +26,6 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker as gc
def _AsLong(array):
"""Casts arrays elements to long type. Used to convert from numpy tf."""
@ -225,8 +223,11 @@ class EmbeddingLookupTest(tf.test.TestCase):
x_name = [_PName(i) for i in range(num_shards)]
x_init_value = [params[x_n + ":0"] for x_n in x_name]
x_shape = [i.shape for i in x_init_value]
err = gc.ComputeGradientError(x, x_shape, y, y_shape,
x_init_value=x_init_value)
err = tf.test.compute_gradient_error(x,
x_shape,
y,
y_shape,
x_init_value=x_init_value)
self.assertLess(err, 1e-4)
def testGradientsEmbeddingLookupWithComputedParams(self):
@ -246,8 +247,11 @@ class EmbeddingLookupTest(tf.test.TestCase):
x_name = [_PName(i) for i in range(num_shards)]
x_init_value = [params[x_n + ":0"] for x_n in x_name]
x_shape = [i.shape for i in x_init_value]
err = gc.ComputeGradientError(x, x_shape, y, y_shape,
x_init_value=x_init_value)
err = tf.test.compute_gradient_error(x,
x_shape,
y,
y_shape,
x_init_value=x_init_value)
self.assertLess(err, 1e-3)
def testConstructionNonSharded(self):
@ -381,8 +385,11 @@ class EmbeddingLookupSparseTest(tf.test.TestCase):
x_init_value = [params[x_n + ":0"] for x_n in x_name]
x_shape = [i.shape for i in x_init_value]
y_shape = [batch_size] + list(params[_PName(0) + ":0"].shape[1:])
err = gc.ComputeGradientError(x, x_shape, y, y_shape,
x_init_value=x_init_value)
err = tf.test.compute_gradient_error(x,
x_shape,
y,
y_shape,
x_init_value=x_init_value)
self.assertLess(err, 1e-5 if dtype == tf.float64 else 2e-3)

View File

@ -34,7 +34,7 @@ from tensorflow.python.ops import gradients
from tensorflow.python.platform import logging
def _Product(t):
def _product(t):
if isinstance(t, int):
return t
else:
@ -44,11 +44,11 @@ def _Product(t):
return y
def _ComputeTheoricalJacobian(x, x_shape, x_data, dy, dy_shape, dx):
def _compute_theoretical_jacobian(x, x_shape, x_data, dy, dy_shape, dx):
"""Computes the theoretical Jacobian for dy/dx.
Computes the theoretical Jacobian using the ops generated by
ComputeGradient().
compute_gradient().
Args:
x: the tensor "x".
@ -64,9 +64,9 @@ def _ComputeTheoricalJacobian(x, x_shape, x_data, dy, dy_shape, dx):
"dy_size" is the number of elements in dy.
"""
# To compute the jacobian, we treat x and y are one-dimensional vectors
x_size = _Product(x_shape)
x_val_size = _Product(x_shape[1:]) # This is used for sparse gradients
dy_size = _Product(dy_shape)
x_size = _product(x_shape)
x_val_size = _product(x_shape[1:]) # This is used for sparse gradients
dy_size = _product(dy_shape)
jacobian = np.zeros((x_size, dy_size), dtype=x_data.dtype)
# For each of the entry of dy, we set this to be 1 and
@ -92,7 +92,7 @@ def _ComputeTheoricalJacobian(x, x_shape, x_data, dy, dy_shape, dx):
return jacobian
def _ComputeNumericJacobian(x, x_shape, x_data, y, y_shape, delta):
def _compute_numeric_jacobian(x, x_shape, x_data, y, y_shape, delta):
"""Computes the numeric Jacobian for dy/dx.
Computes the numeric Jacobian by slightly perturbing the inputs and
@ -113,8 +113,8 @@ def _ComputeNumericJacobian(x, x_shape, x_data, y, y_shape, delta):
"""
# To compute the jacobian, we treat x and y are one-dimensional vectors
x_size = _Product(x_shape)
y_size = _Product(y_shape)
x_size = _product(x_shape)
y_size = _product(y_shape)
jacobian = np.zeros((x_size, y_size), dtype=x_data.dtype)
# For each of the entry of x, we slightly perturbs this by adding and
@ -134,7 +134,7 @@ def _ComputeNumericJacobian(x, x_shape, x_data, y, y_shape, delta):
return jacobian
def _ComputeDxAndDy(x, y, y_shape):
def _compute_dx_and_dy(x, y, y_shape):
"""Returns a node to compute gradient of x wrt y."""
# We make up a dy so that we can compute the gradients. We don't really use
# the value of dy -- we will always feed it. We need to add an identity node
@ -149,8 +149,14 @@ def _ComputeDxAndDy(x, y, y_shape):
return grads[0], dy_orig
def _ComputeGradient(x, x_shape, dx, y, y_shape, dy,
x_init_value=None, delta=1e-3):
def _compute_gradient(x,
x_shape,
dx,
y,
y_shape,
dy,
x_init_value=None,
delta=1e-3):
"""Computes the theoretical and numerical jacobian."""
t = dtypes.as_dtype(x.dtype)
allowed_types = [dtypes.float32, dtypes.float64]
@ -170,16 +176,21 @@ def _ComputeGradient(x, x_shape, dx, y, y_shape, dy,
dtype = np.float64
x_data = np.asfarray(np.random.random_sample(x_shape), dtype=dtype)
jacob_t = _ComputeTheoricalJacobian(x, x_shape, x_data, dy, y_shape, dx)
jacob_n = _ComputeNumericJacobian(x, x_shape, x_data, y, y_shape, delta)
jacob_t = _compute_theoretical_jacobian(x, x_shape, x_data, dy, y_shape, dx)
jacob_n = _compute_numeric_jacobian(x, x_shape, x_data, y, y_shape, delta)
return jacob_t, jacob_n
def _ComputeGradientList(
x, x_shape, y, y_shape, x_init_value=None, delta=1e-3, init_targets=None):
def _compute_gradient_list(x,
x_shape,
y,
y_shape,
x_init_value=None,
delta=1e-3,
init_targets=None):
"""Compute gradients for a list of x values."""
assert isinstance(x, list)
dx, dy = zip(*[_ComputeDxAndDy(xi, y, y_shape) for xi in x])
dx, dy = zip(*[_compute_dx_and_dy(xi, y, y_shape) for xi in x])
if init_targets is not None:
assert isinstance(init_targets, (list, tuple))
@ -187,15 +198,20 @@ def _ComputeGradientList(
init.run()
if x_init_value is None:
x_init_value = [None] * len(x)
ret = [_ComputeGradient(xi, x_shapei, dxi, y, y_shape, dyi,
x_init_valuei, delta)
for xi, x_shapei, dxi, dyi, x_init_valuei in
zip(x, x_shape, dx, dy, x_init_value)]
ret = [_compute_gradient(xi, x_shapei, dxi, y, y_shape, dyi, x_init_valuei,
delta)
for xi, x_shapei, dxi, dyi, x_init_valuei in zip(x, x_shape, dx, dy,
x_init_value)]
return ret
def ComputeGradient(
x, x_shape, y, y_shape, x_init_value=None, delta=1e-3, init_targets=None):
def compute_gradient(x,
x_shape,
y,
y_shape,
x_init_value=None,
delta=1e-3,
init_targets=None):
"""Computes and returns the theoretical and numerical Jacobian.
Args:
@ -219,20 +235,25 @@ def ComputeGradient(
number of elements in y. If x is a list, returns a list of two numpy arrays.
"""
if isinstance(x, list):
return _ComputeGradientList(x, x_shape, y, y_shape, x_init_value,
delta, init_targets)
return _compute_gradient_list(x, x_shape, y, y_shape, x_init_value, delta,
init_targets)
else:
if init_targets is not None:
assert isinstance(init_targets, (list, tuple))
for init in init_targets:
init.run()
dx, dy = _ComputeDxAndDy(x, y, y_shape)
ret = _ComputeGradient(x, x_shape, dx, y, y_shape, dy, x_init_value, delta)
dx, dy = _compute_dx_and_dy(x, y, y_shape)
ret = _compute_gradient(x, x_shape, dx, y, y_shape, dy, x_init_value, delta)
return ret
def ComputeGradientError(
x, x_shape, y, y_shape, x_init_value=None, delta=1e-3, init_targets=None):
def compute_gradient_error(x,
x_shape,
y,
y_shape,
x_init_value=None,
delta=1e-3,
init_targets=None):
"""Computes the gradient error.
Computes the maximum error for dy/dx between the computed Jacobian and the
@ -263,8 +284,8 @@ def ComputeGradientError(
Returns:
The maximum error in between the two Jacobians.
"""
grad = ComputeGradient(x, x_shape, y, y_shape, x_init_value,
delta, init_targets)
grad = compute_gradient(x, x_shape, y, y_shape, x_init_value, delta,
init_targets)
if isinstance(grad, tuple):
grad = [grad]
return max(np.fabs(j_t - j_n).max() for j_t, j_n in grad)

View File

@ -23,8 +23,6 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests.gradient_checker import ComputeGradientError
class GradientCheckerTest(tf.test.TestCase):
@ -37,7 +35,7 @@ class GradientCheckerTest(tf.test.TestCase):
y = tf.add(x1, x2, name="y")
# checking gradients for x1
error = ComputeGradientError(x1, size, y, size)
error = tf.test.compute_gradient_error(x1, size, y, size)
tf.logging.info("x1 error = %f", error)
assert error < 1e-4
@ -50,7 +48,7 @@ class GradientCheckerTest(tf.test.TestCase):
y = tf.add(x1, x2, name="y")
# checking gradients for x1
error = ComputeGradientError(x1, size, y, size)
error = tf.test.compute_gradient_error(x1, size, y, size)
tf.logging.info("x1 error = %f", error)
assert error < 1e-4
@ -66,8 +64,12 @@ class GradientCheckerTest(tf.test.TestCase):
# checkint gradients for x2 using a special init_value and delta
x_init_value = np.asarray(np.arange(6, dtype=np.float64).reshape(2, 3))
error = ComputeGradientError(x2, size, y, size, x_init_value=x_init_value,
delta=1e-2)
error = tf.test.compute_gradient_error(x2,
size,
y,
size,
x_init_value=x_init_value,
delta=1e-2)
tf.logging.info("x2 error = %f", error)
assert error < 1e-10
@ -82,7 +84,7 @@ class GradientCheckerTest(tf.test.TestCase):
indices = tf.constant(index_values, name="i")
y = tf.gather(params, indices, name="y")
error = ComputeGradientError(params, p_shape, y, y_shape)
error = tf.test.compute_gradient_error(params, p_shape, y, y_shape)
tf.logging.info("gather error = %f", error)
assert error < 1e-4
@ -101,7 +103,7 @@ class GradientCheckerTest(tf.test.TestCase):
indices2 = tf.constant(index_values2, name="i2")
y2 = tf.gather(y, indices2, name="y2")
error = ComputeGradientError(params, p_shape, y2, y2_shape)
error = tf.test.compute_gradient_error(params, p_shape, y2, y2_shape)
tf.logging.info("nested gather error = %f", error)
assert error < 1e-4
@ -166,9 +168,11 @@ def BuildAndTestMiniMNIST(param_index, tag):
cost = tf.nn.softmax_cross_entropy_with_logits(logits, labels, name="cost")
# Test the gradients.
err = ComputeGradientError(all_params[param_index],
param_sizes[param_index],
cost, [batch], delta=1e-5)
err = tf.test.compute_gradient_error(all_params[param_index],
param_sizes[param_index],
cost,
[batch],
delta=1e-5)
tf.logging.info("Mini MNIST: %s gradient error = %g", tag, err)
return err

View File

@ -23,8 +23,6 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker as gc
class MatrixInverseGradientTest(tf.test.TestCase):
pass # Filled in below
@ -49,11 +47,11 @@ def _GetMatrixInverseGradientTest(dtype_, shape_):
else:
ainv = tf.batch_matrix_inverse(a)
theoretical, numerical = gc.ComputeGradient(a,
shape_,
ainv,
shape_,
delta=delta)
theoretical, numerical = tf.test.compute_gradient(a,
shape_,
ainv,
shape_,
delta=delta)
self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)
return Test
@ -87,8 +85,11 @@ def _GetMatrixDeterminantGradientTest(dtype_, shape_):
c = tf.batch_matrix_determinant(a)
out_shape = shape_[:-2] # last two dimensions hold matrices
theoretical, numerical = gc.ComputeGradient(a, shape_, c, out_shape,
delta=delta)
theoretical, numerical = tf.test.compute_gradient(a,
shape_,
c,
out_shape,
delta=delta)
self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)

View File

@ -23,8 +23,6 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.models.rnn import linear
class LinearTest(tf.test.TestCase):
@ -32,21 +30,21 @@ class LinearTest(tf.test.TestCase):
with self.test_session() as sess:
with tf.variable_scope("root", initializer=tf.constant_initializer(1.0)):
x = tf.zeros([1, 2])
l = linear.linear([x], 2, False)
l = tf.nn.rnn_cell.linear([x], 2, False)
sess.run([tf.variables.initialize_all_variables()])
res = sess.run([l], {x.name: np.array([[1., 2.]])})
self.assertAllClose(res[0], [[3.0, 3.0]])
# Checks prevent you from accidentally creating a shared function.
with self.assertRaises(ValueError) as exc:
l1 = linear.linear([x], 2, False)
l1 = tf.nn.rnn_cell.linear([x], 2, False)
self.assertEqual(str(exc.exception)[:12], "Over-sharing")
# But you can create a new one in a new scope and share the variables.
with tf.variable_scope("l1") as new_scope:
l1 = linear.linear([x], 2, False)
l1 = tf.nn.rnn_cell.linear([x], 2, False)
with tf.variable_scope(new_scope, reuse=True):
linear.linear([l1], 2, False)
tf.nn.rnn_cell.linear([l1], 2, False)
self.assertEqual(len(tf.trainable_variables()), 2)

View File

@ -25,9 +25,6 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests.gradient_checker import ComputeGradientError
class LRNOpTest(tf.test.TestCase):
@ -107,7 +104,7 @@ class LRNOpTest(tf.test.TestCase):
lrn_op = tf.nn.local_response_normalization(
inp, name="lrn", depth_radius=lrn_depth_radius, bias=bias,
alpha=alpha, beta=beta)
err = ComputeGradientError(inp, shape, lrn_op, shape)
err = tf.test.compute_gradient_error(inp, shape, lrn_op, shape)
print("LRN Gradient error ", err)
self.assertLess(err, 1e-4)

View File

@ -23,8 +23,6 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker as gc
class MatMulTest(tf.test.TestCase):
@ -161,7 +159,7 @@ class MatMulGradientTest(tf.test.TestCase):
y = tf.constant([1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7],
shape=[2, 4], dtype=tf.float64, name="y")
m = tf.matmul(x, y, name="matmul")
err = gc.ComputeGradientError(x, [3, 2], m, [3, 4])
err = tf.test.compute_gradient_error(x, [3, 2], m, [3, 4])
print("matmul input0 gradient err = ", err)
self.assertLess(err, 1e-10)
@ -172,7 +170,7 @@ class MatMulGradientTest(tf.test.TestCase):
y = tf.constant([1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7],
shape=[2, 4], dtype=tf.float64, name="y")
m = tf.matmul(x, y, name="matmul")
err = gc.ComputeGradientError(y, [2, 4], m, [3, 4])
err = tf.test.compute_gradient_error(y, [2, 4], m, [3, 4])
print("matmul input1 gradient err = ", err)
self.assertLess(err, 1e-10)
@ -189,7 +187,7 @@ class MatMulGradientTest(tf.test.TestCase):
y = tf.constant([1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7],
shape=shape_y, dtype=tf.float64, name="y")
m = tf.matmul(x, y, transpose_a, transpose_b, name="matmul")
err = gc.ComputeGradientError(x, shape_x, m, [3, 4])
err = tf.test.compute_gradient_error(x, shape_x, m, [3, 4])
print("matmul input0 gradient err = ", err)
self.assertLess(err, 1e-10)
@ -211,7 +209,7 @@ class MatMulGradientTest(tf.test.TestCase):
y = tf.constant([1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7],
shape=shape_y, dtype=tf.float64, name="y")
m = tf.matmul(x, y, transpose_a, transpose_b, name="matmul")
err = gc.ComputeGradientError(y, shape_y, m, [3, 4])
err = tf.test.compute_gradient_error(y, shape_y, m, [3, 4])
print("matmul input1 gradient err = ", err)
self.assertLess(err, 1e-10)

View File

@ -23,8 +23,6 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker
class PackOpTest(tf.test.TestCase):
@ -51,7 +49,7 @@ class PackOpTest(tf.test.TestCase):
# TODO(irving): Remove list() once we handle maps correctly
xs = list(map(tf.constant, data))
c = tf.pack(xs)
err = gradient_checker.ComputeGradientError(xs, shapes, c, shape)
err = tf.test.compute_gradient_error(xs, shapes, c, shape)
self.assertLess(err, 1e-6)
def testZeroSize(self):

View File

@ -24,8 +24,6 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker as gc
class PadOpTest(tf.test.TestCase):
@ -58,7 +56,11 @@ class PadOpTest(tf.test.TestCase):
y = tf.pad(inx, ina)
# Expected y's shape to be:
ys = list(np.array(x.shape) + np.sum(np.array(a), axis=1))
jacob_t, jacob_n = gc.ComputeGradient(inx, xs, y, ys, x_init_value=x)
jacob_t, jacob_n = tf.test.compute_gradient(inx,
xs,
y,
ys,
x_init_value=x)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
def _testAll(self, np_inputs, paddings):

View File

@ -23,7 +23,6 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker as gc
from tensorflow.python.ops import gen_nn_ops
@ -436,9 +435,12 @@ class PoolingTest(tf.test.TestCase):
t = pool_func(input_tensor, ksize=[1, window_rows, window_rows, 1],
strides=[1, row_stride, col_stride, 1],
padding=padding, name=func_name)
err = gc.ComputeGradientError(
input_tensor, input_sizes, t, output_sizes,
x_init_value=x_init_value, delta=1e-2)
err = tf.test.compute_gradient_error(input_tensor,
input_sizes,
t,
output_sizes,
x_init_value=x_init_value,
delta=1e-2)
print("%s gradient error = " % func_name, err)
self.assertLess(err, err_margin)

View File

@ -24,7 +24,6 @@ import numpy as np
import tensorflow as tf
from tensorflow.python.framework import tensor_shape
from tensorflow.python.kernel_tests import gradient_checker
class SumReductionTest(tf.test.TestCase):
@ -150,13 +149,12 @@ class SumReductionTest(tf.test.TestCase):
with self.test_session():
t = tf.convert_to_tensor(x)
su = tf.reduce_sum(t, reduction_axes)
jacob_t, jacob_n = gradient_checker.ComputeGradient(
t,
shape,
su,
sum_shape,
x_init_value=x,
delta=1)
jacob_t, jacob_n = tf.test.compute_gradient(t,
shape,
su,
sum_shape,
x_init_value=x,
delta=1)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
def testGradient(self):
@ -211,18 +209,30 @@ class MeanReductionTest(tf.test.TestCase):
with self.test_session():
t = tf.convert_to_tensor(x)
su = tf.reduce_mean(t, [1, 2])
jacob_t, jacob_n = gradient_checker.ComputeGradient(
t, s, su, [2, 2], x_init_value=x, delta=1)
jacob_t, jacob_n = tf.test.compute_gradient(t,
s,
su,
[2, 2],
x_init_value=x,
delta=1)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
su = tf.reduce_mean(t, [0, 1, 2, 3])
jacob_t, jacob_n = gradient_checker.ComputeGradient(
t, s, su, [1], x_init_value=x, delta=1)
jacob_t, jacob_n = tf.test.compute_gradient(t,
s,
su,
[1],
x_init_value=x,
delta=1)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
su = tf.reduce_mean(t, [])
jacob_t, jacob_n = gradient_checker.ComputeGradient(
t, s, su, [2, 3, 4, 2], x_init_value=x, delta=1)
jacob_t, jacob_n = tf.test.compute_gradient(t,
s,
su,
[2, 3, 4, 2],
x_init_value=x,
delta=1)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
@ -269,18 +279,30 @@ class ProdReductionTest(tf.test.TestCase):
t = tf.convert_to_tensor(x)
su = tf.reduce_prod(t, [])
jacob_t, jacob_n = gradient_checker.ComputeGradient(
t, s, su, [2, 3, 4, 2], x_init_value=x, delta=1)
jacob_t, jacob_n = tf.test.compute_gradient(t,
s,
su,
[2, 3, 4, 2],
x_init_value=x,
delta=1)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
su = tf.reduce_prod(t, [1, 2])
jacob_t, jacob_n = gradient_checker.ComputeGradient(
t, s, su, [2, 2], x_init_value=x, delta=1)
jacob_t, jacob_n = tf.test.compute_gradient(t,
s,
su,
[2, 2],
x_init_value=x,
delta=1)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
su = tf.reduce_prod(t, [0, 1, 2, 3])
jacob_t, jacob_n = gradient_checker.ComputeGradient(
t, s, su, [1], x_init_value=x, delta=1)
jacob_t, jacob_n = tf.test.compute_gradient(t,
s,
su,
[1],
x_init_value=x,
delta=1)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
# NOTE(kearnes): the current gradient calculation gives NaNs for 0 inputs
@ -288,8 +310,12 @@ class ProdReductionTest(tf.test.TestCase):
with self.test_session():
t = tf.convert_to_tensor(x)
su = tf.reduce_prod(t, [])
jacob_t, _ = gradient_checker.ComputeGradient(
t, s, su, [2, 3, 4, 2], x_init_value=x, delta=1)
jacob_t, _ = tf.test.compute_gradient(t,
s,
su,
[2, 3, 4, 2],
x_init_value=x,
delta=1)
with self.assertRaisesOpError("Tensor had NaN values"):
tf.check_numerics(jacob_t, message="_ProdGrad NaN test").op.run()
@ -336,8 +362,12 @@ class MinReductionTest(tf.test.TestCase):
with self.test_session():
t = tf.convert_to_tensor(x)
su = tf.reduce_min(t, [1, 2])
jacob_t, jacob_n = gradient_checker.ComputeGradient(
t, s, su, [2, 2], x_init_value=x, delta=1)
jacob_t, jacob_n = tf.test.compute_gradient(t,
s,
su,
[2, 2],
x_init_value=x,
delta=1)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
def testGradient2(self):
@ -346,8 +376,12 @@ class MinReductionTest(tf.test.TestCase):
with self.test_session():
t = tf.convert_to_tensor(x)
su = tf.reduce_min(t, [1])
jacob_t, jacob_n = gradient_checker.ComputeGradient(
t, s, su, [2, 4, 2], x_init_value=x, delta=1)
jacob_t, jacob_n = tf.test.compute_gradient(t,
s,
su,
[2, 4, 2],
x_init_value=x,
delta=1)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
def testGradient3(self):
@ -356,8 +390,12 @@ class MinReductionTest(tf.test.TestCase):
with self.test_session():
t = tf.convert_to_tensor(x)
su = tf.reduce_min(t, [2])
jacob_t, jacob_n = gradient_checker.ComputeGradient(
t, s, su, [2, 3, 2], x_init_value=x, delta=1)
jacob_t, jacob_n = tf.test.compute_gradient(t,
s,
su,
[2, 3, 2],
x_init_value=x,
delta=1)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
def testGradient4(self):
@ -366,8 +404,12 @@ class MinReductionTest(tf.test.TestCase):
with self.test_session():
t = tf.convert_to_tensor(x)
su = tf.reduce_min(t)
jacob_t, jacob_n = gradient_checker.ComputeGradient(
t, s, su, [1], x_init_value=x, delta=1)
jacob_t, jacob_n = tf.test.compute_gradient(t,
s,
su,
[1],
x_init_value=x,
delta=1)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
@ -414,8 +456,12 @@ class MaxReductionTest(tf.test.TestCase):
with self.test_session():
t = tf.convert_to_tensor(x)
su = tf.reduce_max(t, [1, 2])
jacob_t, jacob_n = gradient_checker.ComputeGradient(
t, s, su, [2, 2], x_init_value=x, delta=1)
jacob_t, jacob_n = tf.test.compute_gradient(t,
s,
su,
[2, 2],
x_init_value=x,
delta=1)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
def testGradient2(self):
@ -424,8 +470,12 @@ class MaxReductionTest(tf.test.TestCase):
with self.test_session():
t = tf.convert_to_tensor(x)
su = tf.reduce_max(t, [1])
jacob_t, jacob_n = gradient_checker.ComputeGradient(
t, s, su, [2, 4, 2], x_init_value=x, delta=1)
jacob_t, jacob_n = tf.test.compute_gradient(t,
s,
su,
[2, 4, 2],
x_init_value=x,
delta=1)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
def testGradient3(self):
@ -434,8 +484,12 @@ class MaxReductionTest(tf.test.TestCase):
with self.test_session():
t = tf.convert_to_tensor(x)
su = tf.reduce_max(t, [2])
jacob_t, jacob_n = gradient_checker.ComputeGradient(
t, s, su, [2, 3, 2], x_init_value=x, delta=1)
jacob_t, jacob_n = tf.test.compute_gradient(t,
s,
su,
[2, 3, 2],
x_init_value=x,
delta=1)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
def testGradient4(self):
@ -444,8 +498,12 @@ class MaxReductionTest(tf.test.TestCase):
with self.test_session():
t = tf.convert_to_tensor(x)
su = tf.reduce_max(t)
jacob_t, jacob_n = gradient_checker.ComputeGradient(
t, s, su, [1], x_init_value=x, delta=1)
jacob_t, jacob_n = tf.test.compute_gradient(t,
s,
su,
[1],
x_init_value=x,
delta=1)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)

View File

@ -23,8 +23,6 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker as gc
class ReluTest(tf.test.TestCase):
@ -67,7 +65,11 @@ class ReluTest(tf.test.TestCase):
x_init = np.asarray(
[[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
dtype=np.float32, order="F")
err = gc.ComputeGradientError(x, [2, 5], y, [2, 5], x_init_value=x_init)
err = tf.test.compute_gradient_error(x,
[2, 5],
y,
[2, 5],
x_init_value=x_init)
print("relu (float) gradient err = ", err)
self.assertLess(err, 1e-4)
@ -98,7 +100,11 @@ class ReluTest(tf.test.TestCase):
x_init = np.asarray(
[[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
dtype=np.float64, order="F")
err = gc.ComputeGradientError(x, [2, 5], y, [2, 5], x_init_value=x_init)
err = tf.test.compute_gradient_error(x,
[2, 5],
y,
[2, 5],
x_init_value=x_init)
print("relu (double) gradient err = ", err)
self.assertLess(err, 1e-10)
@ -112,8 +118,11 @@ class ReluTest(tf.test.TestCase):
x_init = np.asarray(
[[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
dtype=np.float32, order="F")
err = gc.ComputeGradientError(x, [2, 5], z[0], [2, 5],
x_init_value=x_init)
err = tf.test.compute_gradient_error(x,
[2, 5],
z[0],
[2, 5],
x_init_value=x_init)
print("relu (float) gradient of gradient err = ", err)
self.assertLess(err, 1e-4)
@ -127,8 +136,11 @@ class ReluTest(tf.test.TestCase):
x_init = np.asarray(
[[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
dtype=np.float64, order="F")
err = gc.ComputeGradientError(x, [2, 5], z[0], [2, 5],
x_init_value=x_init)
err = tf.test.compute_gradient_error(x,
[2, 5],
z[0],
[2, 5],
x_init_value=x_init)
print("relu (double) gradient of gradient err = ", err)
self.assertLess(err, 1e-10)
@ -178,7 +190,11 @@ class Relu6Test(tf.test.TestCase):
x_init = np.asarray(
[[-0.9, -0.7, -0.5, -0.3, -0.1], [6.1, 6.3, 6.5, 6.7, 6.9]],
dtype=np.float32, order="F")
err = gc.ComputeGradientError(x, [2, 5], y, [2, 5], x_init_value=x_init)
err = tf.test.compute_gradient_error(x,
[2, 5],
y,
[2, 5],
x_init_value=x_init)
print("relu6 (float) gradient err = ", err)
self.assertLess(err, 1e-4)
@ -191,7 +207,11 @@ class Relu6Test(tf.test.TestCase):
x_init = np.asarray(
[[-0.9, -0.7, -0.5, -0.3, -0.1], [6.1, 6.3, 6.5, 6.7, 6.9]],
dtype=np.float64, order="F")
err = gc.ComputeGradientError(x, [2, 5], y, [2, 5], x_init_value=x_init)
err = tf.test.compute_gradient_error(x,
[2, 5],
y,
[2, 5],
x_init_value=x_init)
print("relu6 (double) gradient err = ", err)
self.assertLess(err, 1e-10)

View File

@ -23,8 +23,6 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker as gc
class ReshapeTest(tf.test.TestCase):
@ -81,8 +79,11 @@ class ReshapeTest(tf.test.TestCase):
with self.test_session():
input_tensor = tf.constant(x, shape=[2, 3, 4])
reshape_out = tf.reshape(input_tensor, [1, 8, 3])
err = gc.ComputeGradientError(input_tensor, s,
reshape_out, s, x_init_value=x)
err = tf.test.compute_gradient_error(input_tensor,
s,
reshape_out,
s,
x_init_value=x)
print("Reshape gradient error = " % err)
self.assertLess(err, 1e-3)

View File

@ -23,15 +23,14 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker as gc
class ReverseSequenceTest(tf.test.TestCase):
def _testReverseSequence(self, x, seq_dim, seq_lengths,
def _testReverseSequence(self, x, batch_dim, seq_dim, seq_lengths,
truth, use_gpu=False, expected_err_re=None):
with self.test_session(use_gpu=use_gpu):
ans = tf.reverse_sequence(x,
batch_dim=batch_dim,
seq_dim=seq_dim,
seq_lengths=seq_lengths)
if expected_err_re is None:
@ -42,11 +41,11 @@ class ReverseSequenceTest(tf.test.TestCase):
with self.assertRaisesOpError(expected_err_re):
ans.eval()
def _testBothReverseSequence(self, x, seq_dim, seq_lengths,
def _testBothReverseSequence(self, x, batch_dim, seq_dim, seq_lengths,
truth, expected_err_re=None):
self._testReverseSequence(x, seq_dim, seq_lengths,
self._testReverseSequence(x, batch_dim, seq_dim, seq_lengths,
truth, True, expected_err_re)
self._testReverseSequence(x, seq_dim, seq_lengths,
self._testReverseSequence(x, batch_dim, seq_dim, seq_lengths,
truth, False, expected_err_re)
def _testBasic(self, dtype):
@ -55,18 +54,22 @@ class ReverseSequenceTest(tf.test.TestCase):
[[9, 10, 11, 12], [13, 14, 15, 16]],
[[17, 18, 19, 20], [21, 22, 23, 24]]], dtype=dtype)
x = x.reshape(3, 2, 4, 1, 1)
x = x.transpose([2, 1, 0, 3, 4]) # permute axes 0 <=> 2
# reverse dim 2 up to (0:3, none, 0:4) along dim=0
seq_dim = 2
seq_lengths = np.asarray([3, 0, 4], dtype=np.int64)
truth = np.asarray(
truth_orig = np.asarray(
[[[3, 2, 1, 4], [7, 6, 5, 8]], # reverse 0:3
[[9, 10, 11, 12], [13, 14, 15, 16]], # reverse none
[[20, 19, 18, 17], [24, 23, 22, 21]]], # reverse 0:4 (all)
dtype=dtype)
truth = truth.reshape(3, 2, 4, 1, 1)
self._testBothReverseSequence(x, seq_dim, seq_lengths, truth)
truth_orig = truth_orig.reshape(3, 2, 4, 1, 1)
truth = truth_orig.transpose([2, 1, 0, 3, 4]) # permute axes 0 <=> 2
seq_dim = 0 # permute seq_dim and batch_dim (originally 2 and 0, resp.)
batch_dim = 2
self._testBothReverseSequence(x, batch_dim, seq_dim, seq_lengths, truth)
def testFloatBasic(self):
self._testBasic(np.float32)
@ -89,22 +92,25 @@ class ReverseSequenceTest(tf.test.TestCase):
[[9, 10, 11, 12], [13, 14, 15, 16]],
[[17, 18, 19, 20], [21, 22, 23, 24]]], dtype=np.float)
x = x.reshape(3, 2, 4, 1, 1)
x = x.transpose([2, 1, 0, 3, 4]) # transpose axes 0 <=> 2
# reverse dim 2 up to (0:3, none, 0:4) along dim=0
seq_dim = 2
# reverse dim 0 up to (0:3, none, 0:4) along dim=2
seq_dim = 0
batch_dim = 2
seq_lengths = np.asarray([3, 0, 4], dtype=np.int64)
with self.test_session():
input_t = tf.constant(x, shape=x.shape)
seq_lengths_t = tf.constant(seq_lengths, shape=seq_lengths.shape)
reverse_sequence_out = tf.reverse_sequence(input_t,
batch_dim=batch_dim,
seq_dim=seq_dim,
seq_lengths=seq_lengths_t)
err = gc.ComputeGradientError(input_t,
x.shape,
reverse_sequence_out,
x.shape,
x_init_value=x)
err = tf.test.compute_gradient_error(input_t,
x.shape,
reverse_sequence_out,
x.shape,
x_init_value=x)
print("ReverseSequence gradient error = %g" % err)
self.assertLess(err, 1e-8)
@ -123,6 +129,26 @@ class ReverseSequenceTest(tf.test.TestCase):
seq_lengths=tf.placeholder(tf.int64, shape=(32,)),
seq_dim=3)
# batch_dim out of bounds.
with self.assertRaisesRegexp(
ValueError, "batch_dim must be < input.dims()"):
tf.reverse_sequence(
tf.placeholder(tf.float32, shape=(32, 2, 3)),
seq_lengths=tf.placeholder(tf.int64, shape=(32,)),
seq_dim=0,
batch_dim=3)
with self.test_session():
inputs = tf.placeholder(tf.float32, shape=(32, 2, 3))
seq_lengths = tf.placeholder(tf.int64, shape=(32,))
output = tf.reverse_sequence(
inputs,
seq_lengths=seq_lengths,
seq_dim=0) # batch_dim default is 0
with self.assertRaisesOpError("batch_dim == seq_dim"):
output.eval(feed_dict={inputs: np.random.rand(32, 2, 3),
seq_lengths: xrange(32)})
if __name__ == "__main__":
tf.test.main()

View File

@ -26,7 +26,7 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.models.rnn import rnn_cell
from tensorflow.python.ops import rnn_cell
class RNNCellTest(tf.test.TestCase):
@ -96,9 +96,9 @@ class RNNCellTest(tf.test.TestCase):
# Different inputs so different outputs and states
for i in range(1, batch_size):
self.assertTrue(
float(np.linalg.norm((res[0][0,:] - res[0][i,:]))) > 1e-6)
float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) > 1e-6)
self.assertTrue(
float(np.linalg.norm((res[1][0,:] - res[1][i,:]))) > 1e-6)
float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6)
def testOutputProjectionWrapper(self):
with self.test_session() as sess:

View File

@ -25,11 +25,8 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.models.rnn import rnn
from tensorflow.models.rnn import rnn_cell
class Plus1RNNCell(rnn_cell.RNNCell):
class Plus1RNNCell(tf.nn.rnn_cell.RNNCell):
"""RNN Cell generating (output, new_state) = (input + 1, state + 1)."""
@property
@ -68,7 +65,7 @@ class RNNTest(tf.test.TestCase):
cell = Plus1RNNCell()
batch_size = 2
inputs = [tf.placeholder(tf.float32, shape=(batch_size, 5))] * 10
outputs, states = rnn.rnn(cell, inputs, dtype=tf.float32)
outputs, states = tf.nn.rnn(cell, inputs, dtype=tf.float32)
self.assertEqual(len(outputs), len(inputs))
for out, inp in zip(outputs, inputs):
self.assertEqual(out.get_shape(), inp.get_shape())
@ -89,14 +86,15 @@ class RNNTest(tf.test.TestCase):
def testDropout(self):
cell = Plus1RNNCell()
full_dropout_cell = rnn_cell.DropoutWrapper(
full_dropout_cell = tf.nn.rnn_cell.DropoutWrapper(
cell, input_keep_prob=1e-12, seed=0)
batch_size = 2
inputs = [tf.placeholder(tf.float32, shape=(batch_size, 5))] * 10
with tf.variable_scope("share_scope"):
outputs, states = rnn.rnn(cell, inputs, dtype=tf.float32)
outputs, states = tf.nn.rnn(cell, inputs, dtype=tf.float32)
with tf.variable_scope("drop_scope"):
dropped_outputs, _ = rnn.rnn(full_dropout_cell, inputs, dtype=tf.float32)
dropped_outputs, _ = tf.nn.rnn(
full_dropout_cell, inputs, dtype=tf.float32)
self.assertEqual(len(outputs), len(inputs))
for out, inp in zip(outputs, inputs):
self.assertEqual(out.get_shape().as_list(), inp.get_shape().as_list())
@ -120,7 +118,7 @@ class RNNTest(tf.test.TestCase):
batch_size = 2
inputs = [tf.placeholder(tf.float32, shape=(batch_size, 5))] * 10
with tf.variable_scope("drop_scope"):
dynamic_outputs, dynamic_states = rnn.rnn(
dynamic_outputs, dynamic_states = tf.nn.rnn(
cell, inputs, sequence_length=sequence_length, dtype=tf.float32)
self.assertEqual(len(dynamic_outputs), len(inputs))
self.assertEqual(len(dynamic_states), len(inputs))
@ -158,11 +156,11 @@ class LSTMTest(tf.test.TestCase):
batch_size = 2
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
cell = rnn_cell.LSTMCell(
cell = tf.nn.rnn_cell.LSTMCell(
num_units, input_size, initializer=initializer)
inputs = 10 * [
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
outputs, _ = rnn.rnn(cell, inputs, dtype=tf.float32)
outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
self.assertEqual(len(outputs), len(inputs))
for out in outputs:
self.assertEqual(out.get_shape().as_list(), [batch_size, num_units])
@ -177,12 +175,12 @@ class LSTMTest(tf.test.TestCase):
batch_size = 2
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
cell = rnn_cell.LSTMCell(
cell = tf.nn.rnn_cell.LSTMCell(
num_units, input_size, use_peepholes=True,
cell_clip=0.0, initializer=initializer)
inputs = 10 * [
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
outputs, _ = rnn.rnn(cell, inputs, dtype=tf.float32)
outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
self.assertEqual(len(outputs), len(inputs))
for out in outputs:
self.assertEqual(out.get_shape().as_list(), [batch_size, num_units])
@ -202,12 +200,12 @@ class LSTMTest(tf.test.TestCase):
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
state_saver = TestStateSaver(batch_size, 2*num_units)
cell = rnn_cell.LSTMCell(
cell = tf.nn.rnn_cell.LSTMCell(
num_units, input_size, use_peepholes=False, initializer=initializer)
inputs = 10 * [
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
with tf.variable_scope("share_scope"):
outputs, states = rnn.state_saving_rnn(
outputs, states = tf.nn.state_saving_rnn(
cell, inputs, state_saver=state_saver, state_name="save_lstm")
self.assertEqual(len(outputs), len(inputs))
for out in outputs:
@ -229,10 +227,10 @@ class LSTMTest(tf.test.TestCase):
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
inputs = 10 * [
tf.placeholder(tf.float32, shape=(None, input_size))]
cell = rnn_cell.LSTMCell(
cell = tf.nn.rnn_cell.LSTMCell(
num_units, input_size, use_peepholes=True,
num_proj=num_proj, initializer=initializer)
outputs, _ = rnn.rnn(cell, inputs, dtype=tf.float32)
outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
self.assertEqual(len(outputs), len(inputs))
tf.initialize_all_variables().run()
@ -252,7 +250,7 @@ class LSTMTest(tf.test.TestCase):
inputs = 10 * [
tf.placeholder(tf.float32, shape=(None, input_size))]
cell = rnn_cell.LSTMCell(
cell = tf.nn.rnn_cell.LSTMCell(
num_units,
input_size=input_size,
use_peepholes=True,
@ -261,7 +259,7 @@ class LSTMTest(tf.test.TestCase):
num_proj_shards=num_proj_shards,
initializer=initializer)
outputs, _ = rnn.rnn(cell, inputs, dtype=tf.float32)
outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
self.assertEqual(len(outputs), len(inputs))
@ -280,7 +278,7 @@ class LSTMTest(tf.test.TestCase):
initializer = tf.random_uniform_initializer(-1, 1, seed=self._seed)
inputs = 10 * [tf.placeholder(tf.float64)]
cell = rnn_cell.LSTMCell(
cell = tf.nn.rnn_cell.LSTMCell(
num_units,
input_size=input_size,
use_peepholes=True,
@ -289,7 +287,7 @@ class LSTMTest(tf.test.TestCase):
num_proj_shards=num_proj_shards,
initializer=initializer)
outputs, _ = rnn.rnn(
outputs, _ = tf.nn.rnn(
cell, inputs, initial_state=cell.zero_state(batch_size, tf.float64))
self.assertEqual(len(outputs), len(inputs))
@ -311,7 +309,7 @@ class LSTMTest(tf.test.TestCase):
inputs = 10 * [tf.placeholder(tf.float32)]
initializer = tf.constant_initializer(0.001)
cell_noshard = rnn_cell.LSTMCell(
cell_noshard = tf.nn.rnn_cell.LSTMCell(
num_units, input_size,
num_proj=num_proj,
use_peepholes=True,
@ -319,15 +317,15 @@ class LSTMTest(tf.test.TestCase):
num_unit_shards=num_unit_shards,
num_proj_shards=num_proj_shards)
cell_shard = rnn_cell.LSTMCell(
cell_shard = tf.nn.rnn_cell.LSTMCell(
num_units, input_size, use_peepholes=True,
initializer=initializer, num_proj=num_proj)
with tf.variable_scope("noshard_scope"):
outputs_noshard, states_noshard = rnn.rnn(
outputs_noshard, states_noshard = tf.nn.rnn(
cell_noshard, inputs, dtype=tf.float32)
with tf.variable_scope("shard_scope"):
outputs_shard, states_shard = rnn.rnn(
outputs_shard, states_shard = tf.nn.rnn(
cell_shard, inputs, dtype=tf.float32)
self.assertEqual(len(outputs_noshard), len(inputs))
@ -362,7 +360,7 @@ class LSTMTest(tf.test.TestCase):
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
inputs = 10 * [tf.placeholder(tf.float64)]
cell = rnn_cell.LSTMCell(
cell = tf.nn.rnn_cell.LSTMCell(
num_units,
input_size=input_size,
use_peepholes=True,
@ -370,9 +368,9 @@ class LSTMTest(tf.test.TestCase):
num_unit_shards=num_unit_shards,
num_proj_shards=num_proj_shards,
initializer=initializer)
dropout_cell = rnn_cell.DropoutWrapper(cell, 0.5, seed=0)
dropout_cell = tf.nn.rnn_cell.DropoutWrapper(cell, 0.5, seed=0)
outputs, states = rnn.rnn(
outputs, states = tf.nn.rnn(
dropout_cell, inputs, sequence_length=sequence_length,
initial_state=cell.zero_state(batch_size, tf.float64))
@ -398,16 +396,16 @@ class LSTMTest(tf.test.TestCase):
initializer = tf.random_uniform_initializer(-1, 1, seed=self._seed)
inputs = 10 * [
tf.placeholder(tf.float32, shape=(None, input_size))]
cell = rnn_cell.LSTMCell(
cell = tf.nn.rnn_cell.LSTMCell(
num_units, input_size, use_peepholes=True,
num_proj=num_proj, initializer=initializer)
with tf.variable_scope("share_scope"):
outputs0, _ = rnn.rnn(cell, inputs, dtype=tf.float32)
outputs0, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
with tf.variable_scope("share_scope", reuse=True):
outputs1, _ = rnn.rnn(cell, inputs, dtype=tf.float32)
outputs1, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
with tf.variable_scope("diff_scope"):
outputs2, _ = rnn.rnn(cell, inputs, dtype=tf.float32)
outputs2, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
tf.initialize_all_variables().run()
input_value = np.random.randn(batch_size, input_size)
@ -433,16 +431,16 @@ class LSTMTest(tf.test.TestCase):
initializer = tf.random_uniform_initializer(-1, 1, seed=self._seed)
inputs = 10 * [
tf.placeholder(tf.float32, shape=(None, input_size))]
cell = rnn_cell.LSTMCell(
cell = tf.nn.rnn_cell.LSTMCell(
num_units, input_size, use_peepholes=True,
num_proj=num_proj, initializer=initializer)
with tf.name_scope("scope0"):
with tf.variable_scope("share_scope"):
outputs0, _ = rnn.rnn(cell, inputs, dtype=tf.float32)
outputs0, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
with tf.name_scope("scope1"):
with tf.variable_scope("share_scope", reuse=True):
outputs1, _ = rnn.rnn(cell, inputs, dtype=tf.float32)
outputs1, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
tf.initialize_all_variables().run()
input_value = np.random.randn(batch_size, input_size)

View File

@ -23,8 +23,6 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker
class SegmentReductionHelper(tf.test.TestCase):
@ -127,8 +125,12 @@ class SegmentReductionOpTest(SegmentReductionHelper):
with self.test_session():
tf_x, np_x = self._input(shape, dtype=tf.float64)
s = tf_op(data=tf_x, segment_ids=indices)
jacob_t, jacob_n = gradient_checker.ComputeGradient(
tf_x, shape, s, [3, 4], x_init_value=np_x.astype(np.double),
jacob_t, jacob_n = tf.test.compute_gradient(
tf_x,
shape,
s,
[3, 4],
x_init_value=np_x.astype(np.double),
delta=1)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
@ -170,7 +172,7 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
s = tf.unsorted_segment_sum(data=tf_x,
segment_ids=indices,
num_segments=num_segments)
jacob_t, jacob_n = gradient_checker.ComputeGradient(
jacob_t, jacob_n = tf.test.compute_gradient(
tf_x,
shape,
s,
@ -196,14 +198,20 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
unsorted_s = tf.unsorted_segment_sum(data=tf_x,
segment_ids=indices,
num_segments=num_segments)
unsorted_jacob_t, unsorted_jacob_n = gradient_checker.ComputeGradient(
tf_x, shape, unsorted_s, [num_segments, num_cols],
(unsorted_jacob_t, unsorted_jacob_n) = tf.test.compute_gradient(
tf_x,
shape,
unsorted_s,
[num_segments, num_cols],
x_init_value=np_x.astype(np.double),
delta=1)
# Results from SegmentSum
sorted_s = tf.segment_sum(data=tf_x, segment_ids=indices)
sorted_jacob_t, sorted_jacob_n = gradient_checker.ComputeGradient(
tf_x, shape, sorted_s, [num_segments, num_cols],
sorted_jacob_t, sorted_jacob_n = tf.test.compute_gradient(
tf_x,
shape,
sorted_s,
[num_segments, num_cols],
x_init_value=np_x.astype(np.double),
delta=1)
self.assertAllClose(unsorted_jacob_t, sorted_jacob_t, rtol=1e-3, atol=1e-3)
@ -277,8 +285,12 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
tf_indices, _, tf_x, np_x = self._sparse_input(
shape, num_indices, dtype=tf.float64)
s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
jacob_t, jacob_n = gradient_checker.ComputeGradient(
tf_x, shape, s, [3, 4], x_init_value=np_x.astype(np.double),
jacob_t, jacob_n = tf.test.compute_gradient(
tf_x,
shape,
s,
[3, 4],
x_init_value=np_x.astype(np.double),
delta=1)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)

View File

@ -21,16 +21,13 @@ from __future__ import print_function
import math
import random
# pylint: disable=g-bad-import-order,unused-import
import tensorflow.python.platform
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.models.rnn import rnn
from tensorflow.models.rnn import rnn_cell
from tensorflow.models.rnn import seq2seq
class Seq2SeqTest(tf.test.TestCase):
@ -38,10 +35,12 @@ class Seq2SeqTest(tf.test.TestCase):
with self.test_session() as sess:
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
_, enc_states = rnn.rnn(rnn_cell.GRUCell(2), inp, dtype=tf.float32)
_, enc_states = tf.nn.rnn(
tf.nn.rnn_cell.GRUCell(2), inp, dtype=tf.float32)
dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
cell = rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
dec, mem = seq2seq.rnn_decoder(dec_inp, enc_states[-1], cell)
cell = tf.nn.rnn_cell.OutputProjectionWrapper(
tf.nn.rnn_cell.GRUCell(2), 4)
dec, mem = tf.nn.seq2seq.rnn_decoder(dec_inp, enc_states[-1], cell)
sess.run([tf.initialize_all_variables()])
res = sess.run(dec)
self.assertEqual(len(res), 3)
@ -56,8 +55,9 @@ class Seq2SeqTest(tf.test.TestCase):
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
cell = rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
dec, mem = seq2seq.basic_rnn_seq2seq(inp, dec_inp, cell)
cell = tf.nn.rnn_cell.OutputProjectionWrapper(
tf.nn.rnn_cell.GRUCell(2), 4)
dec, mem = tf.nn.seq2seq.basic_rnn_seq2seq(inp, dec_inp, cell)
sess.run([tf.initialize_all_variables()])
res = sess.run(dec)
self.assertEqual(len(res), 3)
@ -72,8 +72,9 @@ class Seq2SeqTest(tf.test.TestCase):
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
cell = rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
dec, mem = seq2seq.tied_rnn_seq2seq(inp, dec_inp, cell)
cell = tf.nn.rnn_cell.OutputProjectionWrapper(
tf.nn.rnn_cell.GRUCell(2), 4)
dec, mem = tf.nn.seq2seq.tied_rnn_seq2seq(inp, dec_inp, cell)
sess.run([tf.initialize_all_variables()])
res = sess.run(dec)
self.assertEqual(len(res), 3)
@ -87,11 +88,11 @@ class Seq2SeqTest(tf.test.TestCase):
with self.test_session() as sess:
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
cell = rnn_cell.BasicLSTMCell(2)
_, enc_states = rnn.rnn(cell, inp, dtype=tf.float32)
cell = tf.nn.rnn_cell.BasicLSTMCell(2)
_, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32)
dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)]
dec, mem = seq2seq.embedding_rnn_decoder(dec_inp, enc_states[-1],
cell, 4)
dec, mem = tf.nn.seq2seq.embedding_rnn_decoder(dec_inp, enc_states[-1],
cell, 4)
sess.run([tf.initialize_all_variables()])
res = sess.run(dec)
self.assertEqual(len(res), 3)
@ -106,8 +107,9 @@ class Seq2SeqTest(tf.test.TestCase):
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
enc_inp = [tf.constant(1, tf.int32, shape=[2]) for i in xrange(2)]
dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)]
cell = rnn_cell.BasicLSTMCell(2)
dec, mem = seq2seq.embedding_rnn_seq2seq(enc_inp, dec_inp, cell, 2, 5)
cell = tf.nn.rnn_cell.BasicLSTMCell(2)
dec, mem = tf.nn.seq2seq.embedding_rnn_seq2seq(
enc_inp, dec_inp, cell, 2, 5)
sess.run([tf.variables.initialize_all_variables()])
res = sess.run(dec)
self.assertEqual(len(res), 3)
@ -121,7 +123,7 @@ class Seq2SeqTest(tf.test.TestCase):
w = tf.get_variable("proj_w", [2, 5])
b = tf.get_variable("proj_b", [5])
with tf.variable_scope("proj_seq2seq"):
dec, _ = seq2seq.embedding_rnn_seq2seq(
dec, _ = tf.nn.seq2seq.embedding_rnn_seq2seq(
enc_inp, dec_inp, cell, 2, 5, output_projection=(w, b))
sess.run([tf.variables.initialize_all_variables()])
res = sess.run(dec)
@ -131,12 +133,15 @@ class Seq2SeqTest(tf.test.TestCase):
# Test that previous-feeding model ignores inputs after the first.
dec_inp2 = [tf.constant(0, tf.int32, shape=[2]) for _ in xrange(3)]
tf.get_variable_scope().reuse_variables()
d1, _ = seq2seq.embedding_rnn_seq2seq(enc_inp, dec_inp, cell, 2, 5,
feed_previous=True)
d2, _ = seq2seq.embedding_rnn_seq2seq(enc_inp, dec_inp2, cell, 2, 5,
feed_previous=True)
d3, _ = seq2seq.embedding_rnn_seq2seq(enc_inp, dec_inp2, cell, 2, 5,
feed_previous=tf.constant(True))
d1, _ = tf.nn.seq2seq.embedding_rnn_seq2seq(
enc_inp, dec_inp, cell, 2, 5,
feed_previous=True)
d2, _ = tf.nn.seq2seq.embedding_rnn_seq2seq(
enc_inp, dec_inp2, cell, 2, 5,
feed_previous=True)
d3, _ = tf.nn.seq2seq.embedding_rnn_seq2seq(
enc_inp, dec_inp2, cell, 2, 5,
feed_previous=tf.constant(True))
res1 = sess.run(d1)
res2 = sess.run(d2)
res3 = sess.run(d3)
@ -148,8 +153,9 @@ class Seq2SeqTest(tf.test.TestCase):
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
enc_inp = [tf.constant(1, tf.int32, shape=[2]) for i in xrange(2)]
dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)]
cell = rnn_cell.BasicLSTMCell(2)
dec, mem = seq2seq.embedding_tied_rnn_seq2seq(enc_inp, dec_inp, cell, 5)
cell = tf.nn.rnn_cell.BasicLSTMCell(2)
dec, mem = tf.nn.seq2seq.embedding_tied_rnn_seq2seq(
enc_inp, dec_inp, cell, 5)
sess.run([tf.variables.initialize_all_variables()])
res = sess.run(dec)
self.assertEqual(len(res), 3)
@ -163,7 +169,7 @@ class Seq2SeqTest(tf.test.TestCase):
w = tf.get_variable("proj_w", [2, 5])
b = tf.get_variable("proj_b", [5])
with tf.variable_scope("proj_seq2seq"):
dec, _ = seq2seq.embedding_tied_rnn_seq2seq(
dec, _ = tf.nn.seq2seq.embedding_tied_rnn_seq2seq(
enc_inp, dec_inp, cell, 5, output_projection=(w, b))
sess.run([tf.variables.initialize_all_variables()])
res = sess.run(dec)
@ -173,11 +179,13 @@ class Seq2SeqTest(tf.test.TestCase):
# Test that previous-feeding model ignores inputs after the first.
dec_inp2 = [tf.constant(0, tf.int32, shape=[2]) for _ in xrange(3)]
tf.get_variable_scope().reuse_variables()
d1, _ = seq2seq.embedding_tied_rnn_seq2seq(enc_inp, dec_inp, cell, 5,
feed_previous=True)
d2, _ = seq2seq.embedding_tied_rnn_seq2seq(enc_inp, dec_inp2, cell, 5,
feed_previous=True)
d3, _ = seq2seq.embedding_tied_rnn_seq2seq(
d1, _ = tf.nn.seq2seq.embedding_tied_rnn_seq2seq(
enc_inp, dec_inp, cell, 5,
feed_previous=True)
d2, _ = tf.nn.seq2seq.embedding_tied_rnn_seq2seq(
enc_inp, dec_inp2, cell, 5,
feed_previous=True)
d3, _ = tf.nn.seq2seq.embedding_tied_rnn_seq2seq(
enc_inp, dec_inp2, cell, 5, feed_previous=tf.constant(True))
res1 = sess.run(d1)
res2 = sess.run(d2)
@ -188,14 +196,15 @@ class Seq2SeqTest(tf.test.TestCase):
def testAttentionDecoder1(self):
with self.test_session() as sess:
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
cell = rnn_cell.GRUCell(2)
cell = tf.nn.rnn_cell.GRUCell(2)
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
enc_outputs, enc_states = rnn.rnn(cell, inp, dtype=tf.float32)
enc_outputs, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32)
attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size])
for e in enc_outputs])
dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
dec, mem = seq2seq.attention_decoder(dec_inp, enc_states[-1],
attn_states, cell, output_size=4)
dec, mem = tf.nn.seq2seq.attention_decoder(
dec_inp, enc_states[-1],
attn_states, cell, output_size=4)
sess.run([tf.initialize_all_variables()])
res = sess.run(dec)
self.assertEqual(len(res), 3)
@ -208,15 +217,16 @@ class Seq2SeqTest(tf.test.TestCase):
def testAttentionDecoder2(self):
with self.test_session() as sess:
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
cell = rnn_cell.GRUCell(2)
cell = tf.nn.rnn_cell.GRUCell(2)
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
enc_outputs, enc_states = rnn.rnn(cell, inp, dtype=tf.float32)
enc_outputs, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32)
attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size])
for e in enc_outputs])
dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
dec, mem = seq2seq.attention_decoder(dec_inp, enc_states[-1],
attn_states, cell, output_size=4,
num_heads=2)
dec, mem = tf.nn.seq2seq.attention_decoder(
dec_inp, enc_states[-1],
attn_states, cell, output_size=4,
num_heads=2)
sess.run([tf.initialize_all_variables()])
res = sess.run(dec)
self.assertEqual(len(res), 3)
@ -230,14 +240,15 @@ class Seq2SeqTest(tf.test.TestCase):
with self.test_session() as sess:
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
cell = rnn_cell.GRUCell(2)
enc_outputs, enc_states = rnn.rnn(cell, inp, dtype=tf.float32)
cell = tf.nn.rnn_cell.GRUCell(2)
enc_outputs, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32)
attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size])
for e in enc_outputs])
dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)]
dec, mem = seq2seq.embedding_attention_decoder(dec_inp, enc_states[-1],
attn_states, cell, 4,
output_size=3)
dec, mem = tf.nn.seq2seq.embedding_attention_decoder(
dec_inp, enc_states[-1],
attn_states, cell, 4,
output_size=3)
sess.run([tf.initialize_all_variables()])
res = sess.run(dec)
self.assertEqual(len(res), 3)
@ -252,8 +263,8 @@ class Seq2SeqTest(tf.test.TestCase):
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
enc_inp = [tf.constant(1, tf.int32, shape=[2]) for i in xrange(2)]
dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)]
cell = rnn_cell.BasicLSTMCell(2)
dec, mem = seq2seq.embedding_attention_seq2seq(
cell = tf.nn.rnn_cell.BasicLSTMCell(2)
dec, mem = tf.nn.seq2seq.embedding_attention_seq2seq(
enc_inp, dec_inp, cell, 2, 5)
sess.run([tf.initialize_all_variables()])
res = sess.run(dec)
@ -268,7 +279,7 @@ class Seq2SeqTest(tf.test.TestCase):
w = tf.get_variable("proj_w", [2, 5])
b = tf.get_variable("proj_b", [5])
with tf.variable_scope("proj_seq2seq"):
dec, _ = seq2seq.embedding_attention_seq2seq(
dec, _ = tf.nn.seq2seq.embedding_attention_seq2seq(
enc_inp, dec_inp, cell, 2, 5, output_projection=(w, b))
sess.run([tf.variables.initialize_all_variables()])
res = sess.run(dec)
@ -278,11 +289,11 @@ class Seq2SeqTest(tf.test.TestCase):
# Test that previous-feeding model ignores inputs after the first.
dec_inp2 = [tf.constant(0, tf.int32, shape=[2]) for _ in xrange(3)]
tf.get_variable_scope().reuse_variables()
d1, _ = seq2seq.embedding_attention_seq2seq(
d1, _ = tf.nn.seq2seq.embedding_attention_seq2seq(
enc_inp, dec_inp, cell, 2, 5, feed_previous=True)
d2, _ = seq2seq.embedding_attention_seq2seq(
d2, _ = tf.nn.seq2seq.embedding_attention_seq2seq(
enc_inp, dec_inp2, cell, 2, 5, feed_previous=True)
d3, _ = seq2seq.embedding_attention_seq2seq(
d3, _ = tf.nn.seq2seq.embedding_attention_seq2seq(
enc_inp, dec_inp2, cell, 2, 5, feed_previous=tf.constant(True))
res1 = sess.run(d1)
res2 = sess.run(d2)
@ -297,21 +308,21 @@ class Seq2SeqTest(tf.test.TestCase):
targets = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)]
weights = [tf.constant(1.0, shape=[2]) for i in xrange(3)]
average_loss_per_example = seq2seq.sequence_loss(
average_loss_per_example = tf.nn.seq2seq.sequence_loss(
logits, targets, weights, output_classes,
average_across_timesteps=True,
average_across_batch=True)
res = sess.run(average_loss_per_example)
self.assertAllClose(res, 1.60944)
average_loss_per_sequence = seq2seq.sequence_loss(
average_loss_per_sequence = tf.nn.seq2seq.sequence_loss(
logits, targets, weights, output_classes,
average_across_timesteps=False,
average_across_batch=True)
res = sess.run(average_loss_per_sequence)
self.assertAllClose(res, 4.828314)
total_loss = seq2seq.sequence_loss(
total_loss = tf.nn.seq2seq.sequence_loss(
logits, targets, weights, output_classes,
average_across_timesteps=False,
average_across_batch=False)
@ -326,13 +337,13 @@ class Seq2SeqTest(tf.test.TestCase):
targets = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)]
weights = [tf.constant(1.0, shape=[2]) for i in xrange(3)]
average_loss_per_example = seq2seq.sequence_loss_by_example(
average_loss_per_example = tf.nn.seq2seq.sequence_loss_by_example(
logits, targets, weights, output_classes,
average_across_timesteps=True)
res = sess.run(average_loss_per_example)
self.assertAllClose(res, np.asarray([1.609438, 1.609438]))
loss_per_sequence = seq2seq.sequence_loss_by_example(
loss_per_sequence = tf.nn.seq2seq.sequence_loss_by_example(
logits, targets, weights, output_classes,
average_across_timesteps=False)
res = sess.run(loss_per_sequence)
@ -343,26 +354,30 @@ class Seq2SeqTest(tf.test.TestCase):
# We learn to copy 10 symbols in 2 buckets: length 4 and length 8.
classes = 10
buckets = [(4, 4), (8, 8)]
# We use sampled softmax so we keep output projection separate.
w = tf.get_variable("proj_w", [24, classes])
w_t = tf.transpose(w)
b = tf.get_variable("proj_b", [classes])
# Here comes a sample Seq2Seq model using GRU cells.
def SampleGRUSeq2Seq(enc_inp, dec_inp, weights):
"""Example sequence-to-sequence model that uses GRU cells."""
def GRUSeq2Seq(enc_inp, dec_inp):
cell = rnn_cell.MultiRNNCell([rnn_cell.GRUCell(24)] * 2)
return seq2seq.embedding_attention_seq2seq(
enc_inp, dec_inp, cell, classes, classes, output_projection=(w, b))
targets = [dec_inp[i+1] for i in xrange(len(dec_inp) - 1)] + [0]
def SampledLoss(inputs, labels):
labels = tf.reshape(labels, [-1, 1])
return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, 8, classes)
return seq2seq.model_with_buckets(enc_inp, dec_inp, targets, weights,
buckets, classes, GRUSeq2Seq,
softmax_loss_function=SampledLoss)
# Now we construct the copy model.
with self.test_session() as sess:
# We use sampled softmax so we keep output projection separate.
w = tf.get_variable("proj_w", [24, classes])
w_t = tf.transpose(w)
b = tf.get_variable("proj_b", [classes])
# Here comes a sample Seq2Seq model using GRU cells.
def SampleGRUSeq2Seq(enc_inp, dec_inp, weights):
"""Example sequence-to-sequence model that uses GRU cells."""
def GRUSeq2Seq(enc_inp, dec_inp):
cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.GRUCell(24)] * 2)
return tf.nn.seq2seq.embedding_attention_seq2seq(
enc_inp, dec_inp, cell, classes, classes,
output_projection=(w, b))
targets = [dec_inp[i+1] for i in xrange(len(dec_inp) - 1)] + [0]
def SampledLoss(inputs, labels):
labels = tf.reshape(labels, [-1, 1])
return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, 8, classes)
return tf.nn.seq2seq.model_with_buckets(
enc_inp, dec_inp, targets, weights,
buckets, classes, GRUSeq2Seq,
softmax_loss_function=SampledLoss)
# Now we construct the copy model.
tf.set_random_seed(111)
batch_size = 32
inp = [tf.placeholder(tf.int32, shape=[None]) for _ in xrange(8)]

View File

@ -24,8 +24,6 @@ import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker as gc
class ShapeOpsTest(tf.test.TestCase):
@ -119,7 +117,7 @@ class ShapeOpsTest(tf.test.TestCase):
dtype=tf.float32)
squeezed = tf.expand_dims(inp, 1)
err = gc.ComputeGradientError(inp, [4, 2], squeezed, [4, 1, 2])
err = tf.test.compute_gradient_error(inp, [4, 2], squeezed, [4, 1, 2])
self.assertLess(err, 1e-3)
def testExpandDimsScalar(self):
@ -202,7 +200,7 @@ class ShapeOpsTest(tf.test.TestCase):
a = tf.reshape(inp, [4, 1, 2])
squeezed = tf.squeeze(a, [])
err = gc.ComputeGradientError(a, [4, 1, 2], squeezed, [4, 2])
err = tf.test.compute_gradient_error(a, [4, 1, 2], squeezed, [4, 2])
self.assertLess(err, 1e-3)
def testSqueezeGradientWithSqueezeDims(self):
@ -211,7 +209,7 @@ class ShapeOpsTest(tf.test.TestCase):
a = tf.reshape(inp, [4, 1, 2, 1])
squeezed = tf.squeeze(a, [1])
err = gc.ComputeGradientError(a, [4, 1, 2, 1], squeezed, [4, 2, 1])
err = tf.test.compute_gradient_error(a, [4, 1, 2, 1], squeezed, [4, 2, 1])
self.assertLess(err, 1e-3)
@ -366,8 +364,11 @@ class TileTest(tf.test.TestCase):
shape=input_shape, dtype=tf.float64)
tiled = tf.tile(a, multiples)
grad_shape = list(np.array(multiples) * np.array(inp.shape))
err = gc.ComputeGradientError(a, list(input_shape), tiled, grad_shape,
x_init_value=inp)
err = tf.test.compute_gradient_error(a,
list(input_shape),
tiled,
grad_shape,
x_init_value=inp)
print("tile(float) error = ", err)
self.assertLess(err, 1e-3)
@ -382,7 +383,7 @@ class TileTest(tf.test.TestCase):
a = tf.constant([float(x) for x in inp.flatten()],
shape=[4, 2], dtype=tf.float32)
tiled = tf.tile(a, [1, 2])
err = gc.ComputeGradientError(a, [4, 2], tiled, [4, 4])
err = tf.test.compute_gradient_error(a, [4, 2], tiled, [4, 4])
self.assertLess(err, 1e-3)
def testShapeFunctionEdgeCases(self):

View File

@ -23,8 +23,6 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker as gc
class SoftplusTest(tf.test.TestCase):
@ -57,7 +55,11 @@ class SoftplusTest(tf.test.TestCase):
x_init = np.asarray(
[[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
dtype=np.float32, order="F")
err = gc.ComputeGradientError(x, [2, 5], y, [2, 5], x_init_value=x_init)
err = tf.test.compute_gradient_error(x,
[2, 5],
y,
[2, 5],
x_init_value=x_init)
print("softplus (float) gradient err = ", err)
self.assertLess(err, 1e-4)

View File

@ -0,0 +1,68 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Softsign and SoftsignGrad."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.python.platform
import numpy as np
import tensorflow as tf
class SoftsignTest(tf.test.TestCase):
def _npSoftsign(self, np_features):
return np_features / (1 + np.abs(np_features))
def _testSoftsign(self, np_features, use_gpu=False):
np_softsign = self._npSoftsign(np_features)
with self.test_session(use_gpu=use_gpu):
softsign = tf.nn.softsign(np_features)
tf_softsign = softsign.eval()
self.assertAllClose(np_softsign, tf_softsign)
self.assertShapeEqual(np_softsign, softsign)
def testNumbers(self):
for t in [np.float, np.double]:
self._testSoftsign(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
use_gpu=False)
self._testSoftsign(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
use_gpu=True)
def testGradient(self):
with self.test_session():
x = tf.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5], name="x")
y = tf.nn.softsign(x, name="softsign")
x_init = np.asarray(
[[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
dtype=np.float32, order="F")
err = tf.test.compute_gradient_error(x,
[2, 5],
y,
[2, 5],
x_init_value=x_init)
print("softsign (float) gradient err = ", err)
self.assertLess(err, 1e-4)
if __name__ == "__main__":
tf.test.main()

View File

@ -23,8 +23,6 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker as gc
def RandMatrix(rows, cols, tr):
if tr:
@ -96,8 +94,10 @@ class MatMulGradientTest(tf.test.TestCase):
transpose_b=tr_b,
a_is_sparse=sp_a,
b_is_sparse=sp_b)
err = (gc.ComputeGradientError(a, [2, 3] if tr_a else [3, 2], m, [3, 4]) +
gc.ComputeGradientError(b, [4, 2] if tr_b else [2, 4], m, [3, 4]))
err = (tf.test.compute_gradient_error(a, [2, 3]
if tr_a else [3, 2], m, [3, 4]) +
tf.test.compute_gradient_error(b, [4, 2]
if tr_b else [2, 4], m, [3, 4]))
print("sparse_matmul gradient err = ", err)
self.assertLess(err, 1e-3)

View File

@ -24,8 +24,6 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests.gradient_checker import ComputeGradient
class TransposeTest(tf.test.TestCase):
@ -48,10 +46,10 @@ class TransposeTest(tf.test.TestCase):
xs = list(np.shape(x))
ys = list(np.shape(tf_ans))
if x.dtype == np.float32:
jacob_t, jacob_n = ComputeGradient(inx, xs, y, ys, x, 1e-2)
jacob_t, jacob_n = tf.test.compute_gradient(inx, xs, y, ys, x, 1e-2)
self.assertAllClose(jacob_t, jacob_n, 1e-3, 1e-3)
elif x.dtype == np.float64:
jacob_t, jacob_n = ComputeGradient(inx, xs, y, ys, x, 1e-2)
jacob_t, jacob_n = tf.test.compute_gradient(inx, xs, y, ys, x, 1e-2)
self.assertAllClose(jacob_t, jacob_n, 1e-6, 1e-6)
return tf_ans, jacob_t
@ -70,10 +68,10 @@ class TransposeTest(tf.test.TestCase):
xs = list(np.shape(x))
ys = list(np.shape(tf_ans))
if x.dtype == np.float32:
jacob_t, jacob_n = ComputeGradient(inx, xs, y, ys, x, 1e-2)
jacob_t, jacob_n = tf.test.compute_gradient(inx, xs, y, ys, x, 1e-2)
self.assertAllClose(jacob_t, jacob_n, 1e-3, 1e-3)
elif x.dtype == np.float64:
jacob_t, jacob_n = ComputeGradient(inx, xs, y, ys, x, 1e-2)
jacob_t, jacob_n = tf.test.compute_gradient(inx, xs, y, ys, x, 1e-2)
self.assertAllClose(jacob_t, jacob_n, 1e-6, 1e-6)
return tf_ans, jacob_t

View File

@ -24,8 +24,6 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker
class UnpackOpTest(tf.test.TestCase):
@ -53,8 +51,7 @@ class UnpackOpTest(tf.test.TestCase):
with self.test_session(use_gpu=use_gpu):
x = tf.constant(data)
cs = tf.unpack(x, num=shape[0])
err = gradient_checker.ComputeGradientError(x, shape, cs[i],
shapes[i])
err = tf.test.compute_gradient_error(x, shape, cs[i], shapes[i])
self.assertLess(err, 1e-6)
def testInferNum(self):

View File

@ -23,8 +23,6 @@ import tensorflow.python.platform
import numpy as np
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker as gc
class XentTest(tf.test.TestCase):
@ -120,7 +118,7 @@ class XentTest(tf.test.TestCase):
0.1, 0.8, 2.7, 6.4], shape=[3, 4],
dtype=tf.float64, name="f")
x = tf.nn.softmax_cross_entropy_with_logits(f, l, name="xent")
err = gc.ComputeGradientError(f, [3, 4], x, [3])
err = tf.test.compute_gradient_error(f, [3, 4], x, [3])
print("cross entropy gradient err = ", err)
self.assertLess(err, 5e-8)

View File

@ -42,7 +42,7 @@ PyRecordWriter::~PyRecordWriter() {
delete file_;
}
bool PyRecordWriter::WriteRecord(::tensorflow::StringPiece record) {
bool PyRecordWriter::WriteRecord(tensorflow::StringPiece record) {
if (writer_ == nullptr) return false;
Status s = writer_->WriteRecord(record);
return s.ok();

View File

@ -36,7 +36,7 @@ class PyRecordWriter {
static PyRecordWriter* New(const string& filename);
~PyRecordWriter();
bool WriteRecord(::tensorflow::StringPiece record);
bool WriteRecord(tensorflow::StringPiece record);
void Close();
private:

View File

@ -20,16 +20,17 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
@ops.RegisterGradient("Pack")
def _PackGrad(op, grad):
"""Gradient for pack op."""
return array_ops.unpack(grad, num=op.get_attr('N'))
return array_ops.unpack(grad, num=op.get_attr("N"))
@ops.RegisterGradient("Unpack")
@ -41,28 +42,82 @@ def _UnpackGrad(_, *grads):
@ops.RegisterGradient("Concat")
def _ConcatGrad(op, grad):
"""Gradient for concat op."""
assert isinstance(grad, ops.Tensor)
def _CreateDenseMaskAndBegin(sizes, concat_dim):
"""Create variables for iteratively slicing a dense gradients tensor."""
# Since shape is 1-D, shape_of_shape = [rank-of-inputs]
shape_of_shape = array_ops.shape(sizes[0])
# Make a vector of length equal to the input's dimensions,
# with 0's everywhere and 1 in the concat dim position.
# Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
mask = array_ops.concat(0,
[array_ops.fill(
array_ops.expand_dims(concat_dim, 0), 0),
[1],
array_ops.fill(
shape_of_shape - concat_dim - 1, 0)])
begin = array_ops.fill(shape_of_shape, 0)
return mask, begin
# Degenerate concatenation, just return grad.
if len(op.inputs) == 2:
return [None, grad]
# Get the inputs' tensor shapes
sizes = [array_ops.shape(x) for x in op.inputs[1:]]
concat_dim = op.inputs[0]
# Since shape is 1-D, shape_of_shape = [rank-of-inputs]
shape_of_shape = array_ops.shape(sizes[0])
# Make a vector of length equal to the input's dimensions,
# with 0's everywhere and 1 in the concat dim position.
# Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
mask = array_ops.concat(0,
[array_ops.fill(
array_ops.expand_dims(concat_dim, 0), 0), [1],
array_ops.fill(shape_of_shape - concat_dim - 1, 0)])
out_grads = []
begin = array_ops.fill(shape_of_shape, 0)
for i in range(len(sizes)):
out_grads.append(array_ops.slice(grad, begin, sizes[i]))
# Lint complains begin = begin + ...
begin = math_ops.add(begin, sizes[i] * mask)
if isinstance(grad, ops.Tensor):
# Get the inputs' tensor shapes
sizes = [array_ops.shape(x) for x in op.inputs[1:]]
mask, begin = _CreateDenseMaskAndBegin(sizes, concat_dim)
for size in sizes:
out_grads.append(array_ops.slice(grad, begin, size))
# Lint complains begin = begin + ...
begin = math_ops.add(begin, size * mask)
elif isinstance(grad, ops.IndexedSlices):
concat_dim_static = tensor_util.ConstantValue(concat_dim)
if concat_dim_static is None:
raise ValueError("Can only compute IndexedSlices gradient with "
"statically-known concat_dim")
# Get the inputs' tensor shapes
sizes = [array_ops.shape(x) for x in op.inputs[1:]]
if concat_dim_static > 0:
# IndexedSlices, concat_dim > 0. Each input gets IndexedSlices gradients
# with all the indices, but with grad.values sliced accordingly. This
# is like the Tensor case, except shape(grad.values)[0] is not equal to
# shape(sizes[i])[0], since only a subset of the dim-0 values are stored.
mask, begin = _CreateDenseMaskAndBegin(sizes, concat_dim)
for size in sizes:
new_values = array_ops.slice(
grad.values,
begin,
array_ops.concat(0, [[-1], array_ops.slice(size, [1], [-1])]))
out_grads.append(
ops.IndexedSlices(new_values, grad.indices, size))
# Lint complains begin = begin + ...
begin = math_ops.add(begin, size * mask)
else:
# IndexedSlices, concat_dim == 0. Each input gets IndexedSlices gradients
# only for the relevant indices.
start = constant_op.constant(0, dtype=grad.indices.dtype)
for size in sizes:
size_concat_dim = array_ops.gather(size, concat_dim)
if size_concat_dim.dtype != grad.indices.dtype:
size_concat_dim = math_ops.cast(size_concat_dim,
dtype=grad.indices.dtype)
end = start + size_concat_dim
# Compute the 1-D Tensor of indices relevant for this input.
indices_to_select = array_ops.squeeze(
array_ops.where(math_ops.logical_and(grad.indices >= start,
grad.indices < end)),
squeeze_dims=[1])
new_indices = array_ops.gather(grad.indices, indices_to_select) - start
new_values = array_ops.gather(grad.values, indices_to_select)
out_grads.append(
ops.IndexedSlices(new_values, new_indices, size))
start = end
else:
raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad))
return [None] + out_grads
@ -201,6 +256,7 @@ def _PadGrad(op, grad):
def _ReverseSequenceGrad(op, grad):
seq_lengths = op.inputs[1]
return [array_ops.reverse_sequence(grad,
seq_dim=op.get_attr("seq_dim"),
seq_lengths=seq_lengths),
batch_dim=op.get_attr("batch_dim"),
seq_dim=op.get_attr("seq_dim"),
seq_lengths=seq_lengths),
None]

View File

@ -990,17 +990,22 @@ def _ReverseSequenceShape(op):
A single-element list containing the shape of the output.
Raises:
ValueError: If the input shapes are incompatible.
ValueError: If the input shapes are incompatible or seq_dim == batch_dim.
"""
input_shape = op.inputs[0].get_shape()
seq_lens_shape = op.inputs[1].get_shape().with_rank(1)
batch_size = input_shape[0].merge_with(seq_lens_shape[0])
input_shape = tensor_shape.TensorShape([batch_size]).concatenate(
input_shape[1:])
seq_dim = op.get_attr("seq_dim")
batch_dim = op.get_attr("batch_dim")
if batch_dim >= input_shape.ndims:
raise ValueError("batch_dim must be < input.dims() (%d vs %d)" %
(batch_dim, input_shape.ndims))
if seq_dim >= input_shape.ndims:
raise ValueError("seq_dim must be < input.dims() (%d vs %d)" %
(seq_dim, input_shape.ndims))
batch_size = input_shape[batch_dim].merge_with(seq_lens_shape[0])
input_shape = tensor_shape.TensorShape([
value if ix != batch_dim else batch_size
for ix, value in enumerate(input_shape)])
return [input_shape]

View File

@ -172,12 +172,24 @@ def _ConstantShape(op):
[d.size for d in op.get_attr("value").tensor_shape.dim])]
ops.register_tensor_conversion_function((list, tuple), constant, 100)
ops.register_tensor_conversion_function(np.ndarray, constant, 100)
ops.register_tensor_conversion_function(np.generic, constant, 100)
ops.register_tensor_conversion_function(object, constant, 200)
def _constant_tensor_conversion_function(v, dtype=None, name=None,
as_ref=False):
_ = as_ref
return constant(v, dtype=dtype, name=name)
def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None):
ops.register_tensor_conversion_function(
(list, tuple), _constant_tensor_conversion_function, 100)
ops.register_tensor_conversion_function(
np.ndarray, _constant_tensor_conversion_function, 100)
ops.register_tensor_conversion_function(
np.generic, _constant_tensor_conversion_function, 100)
ops.register_tensor_conversion_function(
object, _constant_tensor_conversion_function, 200)
def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None,
as_ref=False):
_ = as_ref
if not s.is_fully_defined():
raise ValueError(
"Cannot convert a partially known TensorShape to a Tensor: %s" % s)
@ -193,7 +205,9 @@ def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None):
ops.register_tensor_conversion_function(
tensor_shape.TensorShape, _tensor_shape_tensor_conversion_function, 100)
def _dimension_tensor_conversion_function(d, dtype=None, name=None):
def _dimension_tensor_conversion_function(d, dtype=None, name=None,
as_ref=False):
_ = as_ref
if d.value is None:
raise ValueError("Cannot convert an unknown Dimension to a Tensor: %s" % d)
if dtype is not None:

View File

@ -33,7 +33,7 @@ def _SwitchGrad(op, *grad):
if isinstance(ctxt, WhileContext):
merge_op = ctxt.switch_map.get(op)
if merge_op:
merge_op._update_input(1, grad[1])
merge_op._update_input(1, next_iteration(grad[1]))
return None, None
else:
merge_op = merge(grad, name="b_switch")[0]
@ -70,7 +70,7 @@ def _MergeGrad(op, grad, _):
else:
num_inputs = len(op.inputs)
cond = [math_ops.equal(op.outputs[1], i) for i in xrange(num_inputs)]
return [Switch(grad, cond[i])[1] for i in xrange(num_inputs)]
return [switch(grad, cond[i])[1] for i in xrange(num_inputs)]
@ops.RegisterGradient("Exit")
@ -89,7 +89,7 @@ def _ExitGrad(op, grad):
@ops.RegisterGradient("NextIteration")
def _NextIterationGrad(_, grad):
return next_iteration(grad)
return grad
@ops.RegisterGradient("Enter")

Some files were not shown because too many files have changed in this diff Show More