diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 6976a372983..729d84a07b0 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -104,6 +104,7 @@ filegroup( "//tensorflow/contrib/testing:all_files", "//tensorflow/contrib/util:all_files", "//tensorflow/core:all_files", + "//tensorflow/core/debug:all_files", "//tensorflow/core/distributed_runtime:all_files", "//tensorflow/core/distributed_runtime/rpc:all_files", "//tensorflow/core/kernels:all_files", diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc index 510426cc034..9db453f0dd2 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc @@ -38,7 +38,6 @@ namespace { const char kFfmpegExecutable[] = "ffmpeg"; const int32 kDefaultProbeSize = 5000000; // 5MB - std::vector<string> FfmpegCommandLine(const string& input_filename, const string& output_filename, const string& input_format_id, @@ -63,6 +62,39 @@ std::vector<string> FfmpegCommandLine(const string& input_filename, }; } +// Is a named binary installed and executable by the current process? +// Note that this is harder than it seems like it should be... +bool IsBinaryInstalled(const string& binary_name) { + string path = ::getenv("PATH"); + for (const string& dir : str_util::Split(path, ':')) { + const string binary_path = io::JoinPath(dir, binary_name); + char absolute_path[PATH_MAX + 1]; + ::realpath(binary_path.c_str(), absolute_path); + struct stat statinfo; + int result = ::stat(absolute_path, &statinfo); + if (result < 0) { + continue; + } + if (!S_ISREG(statinfo.st_mode)) { + continue; + } + + // Is the current user able to execute the file? + if (statinfo.st_uid == ::geteuid() && statinfo.st_mode & S_IXUSR) { + return true; + } + // Is the current group able to execute the file? + if (statinfo.st_uid == ::getegid() && statinfo.st_mode & S_IXGRP) { + return true; + } + // Is anyone able to execute the file? + if (statinfo.st_mode & S_IXOTH) { + return true; + } + } + return false; +} + [[noreturn]] int ExecuteFfmpeg(const std::vector<string>& args) { std::vector<char*> args_chars; std::transform(args.begin(), args.end(), std::back_inserter(args_chars), @@ -191,6 +223,14 @@ Status ReadAudioFile(const string& filename, FfmpegCommandLine(filename, output_filename, audio_format_id, samples_per_second, channel_count); + // Unfortunately, it's impossible to differentiate an exec failure due to the + // binary being missing and an error from the binary's execution. Therefore, + // check to see if the binary *should* be available. If not, return an error + // that will be converted into a helpful error message by the TensorFlow op. + if (!IsBinaryInstalled(kFfmpegExecutable)) { + return Status(error::Code::NOT_FOUND, StrCat("FFmpeg could not be found.")); + } + // Execute ffmpeg and report errors. 
pid_t child_pid = ::fork(); if (child_pid < 0) { @@ -202,7 +242,7 @@ Status ReadAudioFile(const string& filename, int status_code; ::waitpid(child_pid, &status_code, 0); if (status_code) { - return Status(error::Code::NOT_FOUND, + return Status(error::Code::UNKNOWN, StrCat("FFmpeg execution failed: ", status_code)); } *output_samples = ReadPcmFile(output_filename); diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 9c9cfe4c99b..4d849894051 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -818,7 +818,7 @@ class DropoutTest(tf.test.TestCase): with self.test_session(): images = np.random.uniform(size=(5, height, width, 3)) output = tf.contrib.layers.dropout(images) - self.assertEquals(output.op.name, 'Dropout/dropout/mul_1') + self.assertEquals(output.op.name, 'Dropout/dropout/mul') output.get_shape().assert_is_compatible_with( tf.convert_to_tensor(images).get_shape()) @@ -828,7 +828,7 @@ class DropoutTest(tf.test.TestCase): is_training = tf.constant(True) images = tf.random_uniform((5, height, width, 3), seed=1) output = tf.contrib.layers.dropout(images, is_training=is_training) - self.assertEquals(output.op.name, 'Dropout/dropout/mul_1') + self.assertEquals(output.op.name, 'Dropout/dropout/mul') output.get_shape().assert_is_compatible_with(images.get_shape()) def testCreateDropoutWithConstantFalse(self): diff --git a/tensorflow/contrib/layers/python/layers/target_column.py b/tensorflow/contrib/layers/python/layers/target_column.py index 9f321895025..08280446723 100644 --- a/tensorflow/contrib/layers/python/layers/target_column.py +++ b/tensorflow/contrib/layers/python/layers/target_column.py @@ -22,6 +22,7 @@ import inspect import six +from tensorflow.contrib import losses from tensorflow.contrib import metrics as metrics_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -29,7 +30,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn -from tensorflow.python.ops import nn_ops def regression_target(label_name=None, @@ -297,8 +297,17 @@ class _BinarySvmTargetColumn(_MultiClassTargetColumn): """_TargetColumn for binary classification using SVMs.""" def __init__(self, label_name, weight_column_name): + def loss_fn(logits, target): + check_shape_op = logging_ops.Assert( + math_ops.less_equal(array_ops.rank(target), 2), + ["target's shape should be either [batch_size, 1] or [batch_size]"]) + with ops.control_dependencies([check_shape_op]): + target = array_ops.reshape( + target, shape=[array_ops.shape(target)[0], 1]) + return losses.hinge_loss(logits, target) + super(_BinarySvmTargetColumn, self).__init__( - loss_fn=_binary_hinge_loss, + loss_fn=loss_fn, n_classes=2, label_name=label_name, weight_column_name=weight_column_name) @@ -331,22 +340,6 @@ def _log_loss_with_two_classes(logits, target): return loss_vec -# TODO(sibyl-vie3Poto): Move this to contrib/losses/python/losses/loss_ops.py. 
-def _binary_hinge_loss(logits, target): - """Method that returns the loss vector for binary hinge loss.""" - check_shape_op = logging_ops.Assert( - math_ops.less_equal( - array_ops.rank(target), 2), - ["target's shape should be either [batch_size, 1] or [batch_size]"]) - with ops.control_dependencies([check_shape_op]): - target = array_ops.reshape(target, shape=[array_ops.shape(target)[0], 1]) - # First need to convert binary labels to -1/1 labels (as floats). - all_ones = array_ops.ones_like(logits) - labels = math_ops.sub(2 * math_ops.to_float(target), all_ones) - loss_vec = nn_ops.relu(math_ops.sub(all_ones, math_ops.mul(labels, logits))) - return loss_vec - - def _softmax_cross_entropy_loss(logits, target): # sigmoid_cross_entropy_with_logits requires [batch_size, 1] target. # Check that we got int32/int64 for classification. diff --git a/tensorflow/contrib/learn/python/learn/estimators/svm.py b/tensorflow/contrib/learn/python/learn/estimators/svm.py index f646cdf477c..a39254e7b49 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/svm.py +++ b/tensorflow/contrib/learn/python/learn/estimators/svm.py @@ -61,13 +61,13 @@ class SVM(linear.LinearClassifier): whose `value` is a `SparseTensor`. - if `column` is a `RealValuedColumn, a feature with `key=column.name` whose `value` is a `Tensor`. - - if `feauture_columns` is None, then `input` must contains only real + - if `feature_columns` is None, then `input` must contains only real valued `Tensor`. Parameters: example_id_column: A string defining the feature column name representing - example ids. Used do initialize the underlying optimizer. + example ids. Used to initialize the underlying optimizer. feature_columns: An iterable containing all the feature columns used by the model. All items in the set should be instances of classes derived from `FeatureColumn`. @@ -75,10 +75,12 @@ class SVM(linear.LinearClassifier): weights. It is used to down weight or boost examples during training. It will be multiplied by the loss of the example. model_dir: Directory to save model parameters, graph and etc. This can also - be used to load checkpoints from the directory into a estimator to continue - training a previously saved model. - l1_regularization: L1-regularization parameter - l2_regularization: L2-regularization parameter + be used to load checkpoints from the directory into a estimator to + continue training a previously saved model. + l1_regularization: L1-regularization parameter. Refers to global L1 + regularization (across all examples). + l2_regularization: L2-regularization parameter. Refers to global L2 + regularization (across all examples). kernels: A list of kernels for the SVM. Currently, no kernels are supported. Reserved for future use for non-linear SVMs config: RunConfig object to configure the runtime settings. 
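
[Editor's note] The updated SVM docstring above spells out the constructor contract: `example_id_column` is used to initialize the underlying (SDCA) optimizer, `feature_columns` describe the inputs, and `l1_regularization` / `l2_regularization` are global (across all examples). The sketch below is illustrative only; the column names ('example_id', 'feature1') and the input_fn are hypothetical, while the constructor arguments follow the docstring in this diff.

    # Minimal usage sketch of the tf.contrib.learn.SVM estimator as documented above.
    import tensorflow as tf

    real_feature = tf.contrib.layers.real_valued_column('feature1')

    estimator = tf.contrib.learn.SVM(
        example_id_column='example_id',   # used to initialize the underlying optimizer
        feature_columns=[real_feature],
        l1_regularization=0.1,            # global L1 (across all examples)
        l2_regularization=1.0)            # global L2 (across all examples)

    def input_fn():
      # Features are keyed by column name; 'example_id' must be supplied per example.
      features = {
          'example_id': tf.constant(['1', '2', '3']),
          'feature1': tf.constant([[0.5], [1.0], [-1.0]]),
      }
      labels = tf.constant([1, 1, 0])
      return features, labels

    estimator.fit(input_fn=input_fn, steps=10)
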
@@ -100,12 +102,13 @@ class SVM(linear.LinearClassifier): symmetric_l1_regularization=l1_regularization, symmetric_l2_regularization=l2_regularization) - super(SVM, self).__init__(model_dir=model_dir, - n_classes=2, - weight_column_name=weight_column_name, - feature_columns=feature_columns, - optimizer=optimizer, - config=config) + super(SVM, self).__init__( + model_dir=model_dir, + n_classes=2, + weight_column_name=weight_column_name, + feature_columns=feature_columns, + optimizer=optimizer, + config=config) self._target_column = layers.binary_svm_target( weight_column_name=weight_column_name) diff --git a/tensorflow/contrib/losses/python/losses/__init__.py b/tensorflow/contrib/losses/python/losses/__init__.py index 081d47e4b55..d8181632bf8 100644 --- a/tensorflow/contrib/losses/python/losses/__init__.py +++ b/tensorflow/contrib/losses/python/losses/__init__.py @@ -106,6 +106,7 @@ weighted average over the individual prediction errors: @@absolute_difference @@add_loss +@@hinge_loss @@cosine_distance @@get_losses @@get_regularization_losses diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index 99aab8b44c2..597e6aeda93 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops __all__ = ["absolute_difference", @@ -33,6 +34,7 @@ __all__ = ["absolute_difference", "get_losses", "get_regularization_losses", "get_total_loss", + "hinge_loss", "log_loss", "sigmoid_cross_entropy", "softmax_cross_entropy", @@ -410,6 +412,31 @@ def log_loss(predictions, targets, weight=1.0, epsilon=1e-7, scope=None): return _compute_weighted_loss(losses, weight) +def hinge_loss(logits, target, scope=None): + """Method that returns the loss tensor for hinge loss. + + Args: + logits: The logits, a float tensor. + target: The ground truth output tensor. Its shape should match the shape of + logits. The values of the tensor are expected to be 0.0 or 1.0. + scope: The scope for the operations performed in computing the loss. + + Returns: + A `Tensor` of same shape as logits and target representing the loss values + across the batch. + + Raises: + ValueError: If the shapes of `logits` and `target` don't match. + """ + with ops.op_scope([logits, target], scope, "hinge_loss") as scope: + logits.get_shape().assert_is_compatible_with(target.get_shape()) + # We first need to convert binary labels to -1/1 labels (as floats). + target = math_ops.to_float(target) + all_ones = array_ops.ones_like(target) + labels = math_ops.sub(2 * target, all_ones) + return nn_ops.relu(math_ops.sub(all_ones, math_ops.mul(labels, logits))) + + def sum_of_squares(predictions, targets, weight=1.0, scope=None): """Adds a Sum-of-Squares loss to the training procedure. 
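
[Editor's note] As a quick cross-check of the `hinge_loss` semantics added above: targets in {0.0, 1.0} are mapped to {-1, +1} labels, and the elementwise loss is max(0, 1 - label * logit). The NumPy sketch below is illustrative only (the helper name is hypothetical) and reproduces the expected values used by the tests in the following file.

    # Illustrative only: elementwise hinge loss as implemented above,
    # loss = max(0, 1 - (2*target - 1) * logit), with targets in {0.0, 1.0}.
    import numpy as np

    def hinge_loss_np(logits, targets):
      labels = 2.0 * targets - 1.0          # map {0, 1} -> {-1, +1}
      return np.maximum(0.0, 1.0 - labels * logits)

    logits = np.array([-0.7, -1.4, 1.4, 0.6])
    targets = np.array([0.0, 0.0, 1.0, 1.0])
    print(hinge_loss_np(logits, targets))   # [0.3, 0.0, 0.0, 0.4]
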
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py index 49460ec2279..824c24451be 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py @@ -499,6 +499,42 @@ class LogLossTest(tf.test.TestCase): self.assertAlmostEqual(0.0, loss.eval(), 3) +class HingeLossTest(tf.test.TestCase): + + def testIncompatibleShapes(self): + with self.test_session(): + logits = tf.constant([[-1.0], [2.1]]) + target = tf.constant([0.0, 1.0]) + with self.assertRaises(ValueError): + _ = tf.contrib.losses.hinge_loss(logits, target).eval() + + def testAllOutsideMargin(self): + with self.test_session(): + logits = tf.constant([1.2, -1.4, -1.0, 2.1]) + target = tf.constant([1.0, 0.0, 0.0, 1.0]) + loss = tf.contrib.losses.hinge_loss(logits, target) + self.assertAllClose(loss.eval(), [0.0, 0.0, 0.0, 0.0], atol=1e-3) + + def testSomeInsideMargin(self): + with self.test_session(): + logits = tf.constant([[-0.7], [-1.4], [1.4], [0.6]]) + target = tf.constant([[0.0], [0.0], [1.0], [1.0]]) + loss = tf.contrib.losses.hinge_loss(logits, target) + # Examples 1 and 4 are on the correct side of the hyperplane but within + # the margin so they incur some (small) loss. + self.assertAllClose(loss.eval(), [[0.3], [0.0], [0.0], [0.4]], atol=1e-3) + + def testSomeMisclassified(self): + with self.test_session(): + logits = tf.constant([[[1.2], [0.4], [-1.0], [-1.1]]]) + target = tf.constant([[[1.0], [0.0], [0.0], [1.0]]]) + loss = tf.contrib.losses.hinge_loss(logits, target) + # Examples 2 and 4 are on the wrong side of the hyperplane so they incur + # some (fairly large) loss. + self.assertAllClose( + loss.eval(), [[[0.0], [1.4], [0.0], [2.1]]], atol=1e-3) + + class SumOfSquaresLossTest(tf.test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 9e819ba62fd..dffd139ec0d 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -9,10 +9,14 @@ exports_files(["LICENSE"]) package(default_visibility = ["//tensorflow:__subpackages__"]) load("//tensorflow:tensorflow.bzl", "cuda_py_tests") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") py_library( name = "rnn_py", srcs = ["__init__.py"] + glob(["python/ops/*.py"]), + data = [ + ":python/ops/_lstm_ops.so", + ], srcs_version = "PY2AND3", ) @@ -27,6 +31,33 @@ cuda_py_tests( ], ) +cuda_py_tests( + name = "lstm_ops_test", + size = "small", + srcs = ["python/kernel_tests/lstm_ops_test.py"], + additional_deps = [ + ":rnn_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +tf_custom_op_library( + name = "python/ops/_lstm_ops.so", + srcs = [ + "kernels/lstm_ops.cc", + "kernels/lstm_ops.h", + "ops/lstm_ops.cc", + ], + gpu_srcs = [ + "kernels/lstm_ops_gpu.cu.cc", + "kernels/lstm_ops.h", + ], + deps = [ + "//tensorflow/core/kernels:eigen_helpers", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py index 2193f644849..8ead5f00045 100644 --- a/tensorflow/contrib/rnn/__init__.py +++ b/tensorflow/contrib/rnn/__init__.py @@ -12,14 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Ops for representing statistical distributions. +"""Additional RNN operations and cells. 
-## This package provides classes for statistical distributions. +## This package provides additional contributed RNNCells. +### Fused RNNCells +@@LSTMFusedCell + +### LSTM-like cells +@@CoupledInputForgetGateLSTMCell +@@TimeFreqLSTMCell +@@GridLSTMCell + +### RNNCell wrappers +@@AttentionCellWrapper """ + from __future__ import absolute_import from __future__ import division from __future__ import print_function # pylint: disable=unused-import,wildcard-import, line-too-long +from tensorflow.contrib.rnn.python.ops.lstm_ops import * from tensorflow.contrib.rnn.python.ops.rnn_cell import * diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.cc b/tensorflow/contrib/rnn/kernels/lstm_ops.cc new file mode 100644 index 00000000000..74bede713c1 --- /dev/null +++ b/tensorflow/contrib/rnn/kernels/lstm_ops.cc @@ -0,0 +1,1053 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include "tensorflow/contrib/rnn/kernels/lstm_ops.h" + +#include <memory> +#include <vector> + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" + +#if GOOGLE_CUDA +#include "tensorflow/core/platform/stream_executor.h" +#endif // GOOGLE_CUDA + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +#if GOOGLE_CUDA + +namespace { +template <typename T> +perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) { + perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory)); + perftools::gputools::DeviceMemory<T> typed(wrapped); + return typed; +} +} // namespace + +#endif // GOOGLE_CUDA + +namespace functor { +template <typename T> +void TensorCuBlasGemm<T>::operator()(OpKernelContext* ctx, + perftools::gputools::Stream* stream, + bool transa, bool transb, uint64 m, + uint64 n, uint64 k, T alpha, const T* a, + int lda, const T* b, int ldb, T beta, T* c, + int ldc) { +#if GOOGLE_CUDA + perftools::gputools::blas::Transpose trans[] = { + perftools::gputools::blas::Transpose::kNoTranspose, + perftools::gputools::blas::Transpose::kTranspose}; + + auto a_ptr = AsDeviceMemory(a); + auto b_ptr = AsDeviceMemory(b); + auto c_ptr = AsDeviceMemory(c); + + bool blas_launch_status = + stream + ->ThenBlasGemm(trans[transa], trans[transb], m, n, k, alpha, a_ptr, + lda, b_ptr, ldb, beta, &c_ptr, ldc) + .ok(); + OP_REQUIRES(ctx, blas_launch_status, errors::Aborted("CuBlasGemm failed!")); +#else + ctx->SetStatus(errors::InvalidArgument("CuBlasGemm 
needs CUDA.")); +#endif +} + +template struct TensorCuBlasGemm<float>; +// template struct TensorCuBlasGemm<double>; +} // end namespace functor + +template <typename Device, typename T, bool USE_CUBLAS> +class LSTMFusedCellOp : public OpKernel { + public: + explicit LSTMFusedCellOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* x_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor)); + + const Tensor* cs_prev_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor)); + + const Tensor* h_prev_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor)); + + const Tensor* w_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor)); + + const Tensor* wci_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor)); + + const Tensor* wcf_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor)); + + const Tensor* wco_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor)); + + const Tensor* b_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor)); + + const int64 batch_size = x_tensor->dim_size(0); + const int64 input_size = x_tensor->dim_size(1); + const int64 cell_size = cs_prev_tensor->dim_size(1); + + // Sanity checks for our input shapes. + OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size, + errors::InvalidArgument("cs_prev.dims(0) != batch_size: ", + cs_prev_tensor->dim_size(0), " vs. ", + batch_size)); + OP_REQUIRES(ctx, cs_prev_tensor->dim_size(1) == cell_size, + errors::InvalidArgument("cs_prev.dims(1) != cell_size: ", + cs_prev_tensor->dim_size(1), " vs. ", + cell_size)); + + OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size, + errors::InvalidArgument("h_prev.dims(0) != batch_size: ", + h_prev_tensor->dim_size(0), " vs. ", + batch_size)); + OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size, + errors::InvalidArgument("h_prev.dims(1) != cell_size: ", + h_prev_tensor->dim_size(1), " vs. ", + cell_size)); + + OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size, + errors::InvalidArgument( + "w.dim_size(0) != input_size + cell_size: ", + w_tensor->dim_size(0), " vs. ", input_size + cell_size)); + OP_REQUIRES( + ctx, w_tensor->dim_size(1) == cell_size * 4, + errors::InvalidArgument("w.dim_size(1) != cell_size * 4: ", + w_tensor->dim_size(1), " vs. ", cell_size * 4)); + + OP_REQUIRES( + ctx, b_tensor->dim_size(0) == cell_size * 4, + errors::InvalidArgument("b.dim_size(0) != cell_size * 4: ", + b_tensor->dim_size(0), " vs. ", cell_size * 4)); + + // Allocate our output tensors. 
+ Tensor* i_tensor = nullptr; + OP_REQUIRES_OK( + ctx, ctx->allocate_output("i", TensorShape({batch_size, cell_size}), + &i_tensor)); + + Tensor* cs_tensor = nullptr; + OP_REQUIRES_OK( + ctx, ctx->allocate_output("cs", TensorShape({batch_size, cell_size}), + &cs_tensor)); + + Tensor* f_tensor = nullptr; + OP_REQUIRES_OK( + ctx, ctx->allocate_output("f", TensorShape({batch_size, cell_size}), + &f_tensor)); + + Tensor* o_tensor = nullptr; + OP_REQUIRES_OK( + ctx, ctx->allocate_output("o", TensorShape({batch_size, cell_size}), + &o_tensor)); + + Tensor* ci_tensor = nullptr; + OP_REQUIRES_OK( + ctx, ctx->allocate_output("ci", TensorShape({batch_size, cell_size}), + &ci_tensor)); + + Tensor* co_tensor = nullptr; + OP_REQUIRES_OK( + ctx, ctx->allocate_output("co", TensorShape({batch_size, cell_size}), + &co_tensor)); + + Tensor* h_tensor = nullptr; + OP_REQUIRES_OK( + ctx, ctx->allocate_output("h", TensorShape({batch_size, cell_size}), + &h_tensor)); + + // Allocate our temp tensors. + Tensor xh_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_temp( + DataTypeToEnum<T>::v(), + TensorShape({batch_size, input_size + cell_size}), + &xh_tensor)); + + Tensor icfo_tensor; + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DataTypeToEnum<T>::v(), + TensorShape({batch_size, cell_size * 4}), + &icfo_tensor)); + + const Device& device = ctx->eigen_device<Device>(); + perftools::gputools::Stream* stream = + std::is_same<Device, GPUDevice>::value + ? ctx->op_device_context()->stream() + : nullptr; + + functor::LSTMFusedCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size, + cell_size)( + ctx, stream, device, forget_bias_, cell_clip_, use_peephole_, + x_tensor->matrix<T>(), cs_prev_tensor->matrix<T>(), + h_prev_tensor->matrix<T>(), w_tensor->matrix<T>(), wci_tensor->vec<T>(), + wcf_tensor->vec<T>(), wco_tensor->vec<T>(), b_tensor->vec<T>(), + xh_tensor.matrix<T>(), i_tensor->matrix<T>(), cs_tensor->matrix<T>(), + f_tensor->matrix<T>(), o_tensor->matrix<T>(), ci_tensor->matrix<T>(), + co_tensor->matrix<T>(), icfo_tensor.matrix<T>(), h_tensor->matrix<T>()); + } + + private: + float forget_bias_; + float cell_clip_; + bool use_peephole_; +}; + +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("LSTMFusedCell").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ + LSTMFusedCellOp<CPUDevice, T, false>); +REGISTER_KERNEL(float); +// REGISTER_KERNEL(double); +#undef REGISTER_KERNEL + +#if GOOGLE_CUDA +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void LSTMFusedCellFprop<GPUDevice, T, true>::operator()( \ + OpKernelContext* ctx, perftools::gputools::Stream* stream, \ + const GPUDevice& d, const T forget_bias, const T cell_clip, \ + bool use_peephole, typename TTypes<T>::ConstMatrix x, \ + typename TTypes<T>::ConstMatrix cs_prev, \ + typename TTypes<T>::ConstMatrix h_prev, \ + typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \ + typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \ + typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \ + typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \ + typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \ + typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \ + typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h); \ + \ + extern template struct LSTMFusedCellFprop<GPUDevice, T, true>; + +DECLARE_GPU_SPEC(float); +// DECLARE_GPU_SPEC(double); +#undef DECLARE_GPU_SPEC +} // end namespace functor + +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + 
Name("LSTMFusedCell").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ + LSTMFusedCellOp<GPUDevice, T, true>); + +REGISTER_GPU_KERNEL(float); +// REGISTER_GPU_KERNEL(double); +#undef REGISTER_GPU_KERNEL +#endif // GOOGLE_CUDA + +template <typename Device, typename T, bool USE_CUBLAS> +class LSTMFusedCellGradOp : public OpKernel { + public: + explicit LSTMFusedCellGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* x_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor)); + + const Tensor* cs_prev_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor)); + + const Tensor* h_prev_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor)); + + const Tensor* w_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor)); + + const Tensor* wci_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor)); + + const Tensor* wcf_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor)); + + const Tensor* wco_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor)); + + const Tensor* b_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor)); + + const Tensor* i_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("i", &i_tensor)); + + const Tensor* cs_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("cs", &cs_tensor)); + + const Tensor* f_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("f", &f_tensor)); + + const Tensor* o_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("o", &o_tensor)); + + const Tensor* ci_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("ci", &ci_tensor)); + + const Tensor* co_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("co", &co_tensor)); + + const Tensor* cs_grad_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("cs_grad", &cs_grad_tensor)); + + const Tensor* h_grad_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("h_grad", &h_grad_tensor)); + + const int64 batch_size = x_tensor->dim_size(0); + const int64 input_size = x_tensor->dim_size(1); + const int64 cell_size = cs_prev_tensor->dim_size(1); + + // Sanity checks for our input shapes. + OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size, + errors::InvalidArgument("cs_prev.dims(0) != batch_size: ", + cs_prev_tensor->dim_size(0), " vs. ", + batch_size)); + OP_REQUIRES(ctx, cs_prev_tensor->dim_size(1) == cell_size, + errors::InvalidArgument("cs_prev.dims(1) != cell_size: ", + cs_prev_tensor->dim_size(1), " vs. ", + cell_size)); + + OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size, + errors::InvalidArgument("h_prev.dims(0) != batch_size: ", + h_prev_tensor->dim_size(0), " vs. ", + batch_size)); + OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size, + errors::InvalidArgument("h_prev.dims(1) != cell_size: ", + h_prev_tensor->dim_size(1), " vs. ", + cell_size)); + + OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size, + errors::InvalidArgument( + "w.dim_size(0) != input_size + cell_size: ", + w_tensor->dim_size(0), " vs. ", input_size + cell_size)); + OP_REQUIRES( + ctx, w_tensor->dim_size(1) == cell_size * 4, + errors::InvalidArgument("w.dim_size(1) != cell_size * 4: ", + w_tensor->dim_size(1), " vs. ", cell_size * 4)); + + OP_REQUIRES( + ctx, b_tensor->dim_size(0) == cell_size * 4, + errors::InvalidArgument("b.dim_size(0) != cell_size * 4: ", + b_tensor->dim_size(0), " vs. 
", cell_size * 4)); + + OP_REQUIRES( + ctx, i_tensor->dim_size(0) == batch_size, + errors::InvalidArgument("i.dim_size(0) != batch_size: ", + i_tensor->dim_size(0), " vs. ", batch_size)); + OP_REQUIRES( + ctx, i_tensor->dim_size(1) == cell_size, + errors::InvalidArgument("i.dim_size(1) != cell_size: ", + i_tensor->dim_size(1), " vs. ", cell_size)); + + OP_REQUIRES( + ctx, cs_tensor->dim_size(0) == batch_size, + errors::InvalidArgument("cs.dim_size(0) != batch_size: ", + cs_tensor->dim_size(0), " vs. ", batch_size)); + OP_REQUIRES( + ctx, cs_tensor->dim_size(1) == cell_size, + errors::InvalidArgument("cs.dim_size(1) != cell_size: ", + cs_tensor->dim_size(1), " vs. ", cell_size)); + + OP_REQUIRES( + ctx, f_tensor->dim_size(0) == batch_size, + errors::InvalidArgument("f.dim_size(0) != batch_size: ", + f_tensor->dim_size(0), " vs. ", batch_size)); + OP_REQUIRES( + ctx, f_tensor->dim_size(1) == cell_size, + errors::InvalidArgument("i.dim_size(1) != cell_size: ", + f_tensor->dim_size(1), " vs. ", cell_size)); + + OP_REQUIRES( + ctx, o_tensor->dim_size(0) == batch_size, + errors::InvalidArgument("o.dim_size(0) != batch_size: ", + o_tensor->dim_size(0), " vs. ", batch_size)); + OP_REQUIRES( + ctx, o_tensor->dim_size(1) == cell_size, + errors::InvalidArgument("o.dim_size(1) != cell_size: ", + o_tensor->dim_size(1), " vs. ", cell_size)); + + OP_REQUIRES( + ctx, ci_tensor->dim_size(0) == batch_size, + errors::InvalidArgument("ci.dim_size(0) != batch_size: ", + ci_tensor->dim_size(0), " vs. ", batch_size)); + OP_REQUIRES( + ctx, ci_tensor->dim_size(1) == cell_size, + errors::InvalidArgument("ci.dim_size(1) != cell_size: ", + ci_tensor->dim_size(1), " vs. ", cell_size)); + + OP_REQUIRES( + ctx, co_tensor->dim_size(0) == batch_size, + errors::InvalidArgument("co.dim_size(0) != batch_size: ", + co_tensor->dim_size(0), " vs. ", batch_size)); + OP_REQUIRES( + ctx, co_tensor->dim_size(1) == cell_size, + errors::InvalidArgument("co.dim_size(1) != cell_size: ", + co_tensor->dim_size(1), " vs. ", cell_size)); + + OP_REQUIRES(ctx, cs_grad_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "cs_grad_tensor.dims(0) != batch_size: ", + cs_grad_tensor->dim_size(0), " vs. ", batch_size)); + OP_REQUIRES(ctx, cs_grad_tensor->dim_size(1) == cell_size, + errors::InvalidArgument("cs_grad_tensor.dims(1) != cell_size: ", + cs_grad_tensor->dim_size(1), " vs. ", + cell_size)); + + OP_REQUIRES(ctx, h_grad_tensor->dim_size(0) == batch_size, + errors::InvalidArgument("h_grad_tensor.dims(0) != batch_size: ", + h_grad_tensor->dim_size(0), " vs. ", + batch_size)); + OP_REQUIRES(ctx, h_grad_tensor->dim_size(1) == cell_size, + errors::InvalidArgument("h_grad_tensor.dims(1) != cell_size: ", + h_grad_tensor->dim_size(1), " vs. ", + cell_size)); + + // Allocate our output tensors. 
+ Tensor* cs_prev_grad_tensor = nullptr; + OP_REQUIRES_OK(ctx, + ctx->allocate_output("cs_prev_grad", + TensorShape({batch_size, cell_size}), + &cs_prev_grad_tensor)); + + Tensor* dicfo_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output( + "dicfo", TensorShape({batch_size, cell_size * 4}), + &dicfo_tensor)); + + Tensor* wci_grad_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output("wci_grad", wci_tensor->shape(), + &wci_grad_tensor)); + + Tensor* wcf_grad_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output("wcf_grad", wcf_tensor->shape(), + &wcf_grad_tensor)); + + Tensor* wco_grad_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output("wco_grad", wco_tensor->shape(), + &wco_grad_tensor)); + + // Allocate our temp tensors. + Tensor do_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(), + TensorShape({batch_size, cell_size}), + &do_tensor)); + + Tensor dcs_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(), + TensorShape({batch_size, cell_size}), + &dcs_tensor)); + + Tensor dci_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(), + TensorShape({batch_size, cell_size}), + &dci_tensor)); + + Tensor df_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(), + TensorShape({batch_size, cell_size}), + &df_tensor)); + + Tensor di_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(), + TensorShape({batch_size, cell_size}), + &di_tensor)); + + const Device& device = ctx->eigen_device<Device>(); + perftools::gputools::Stream* stream = + std::is_same<Device, GPUDevice>::value + ? ctx->op_device_context()->stream() + : nullptr; + + functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<float>()); + functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<float>()); + functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<float>()); + + functor::LSTMFusedCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size, + cell_size)( + ctx, stream, device, use_peephole_, x_tensor->matrix<T>(), + cs_prev_tensor->matrix<T>(), h_prev_tensor->matrix<T>(), + w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(), + wco_tensor->vec<T>(), b_tensor->vec<T>(), i_tensor->matrix<T>(), + cs_tensor->matrix<T>(), f_tensor->matrix<T>(), o_tensor->matrix<T>(), + ci_tensor->matrix<T>(), co_tensor->matrix<T>(), + cs_grad_tensor->matrix<T>(), h_grad_tensor->matrix<T>(), + do_tensor.matrix<T>(), dcs_tensor.matrix<T>(), dci_tensor.matrix<T>(), + df_tensor.matrix<T>(), di_tensor.matrix<T>(), dicfo_tensor->matrix<T>(), + cs_prev_grad_tensor->matrix<T>(), wci_grad_tensor->vec<T>(), + wcf_grad_tensor->vec<T>(), wco_grad_tensor->vec<T>()); + } + + protected: + bool use_peephole_; +}; + +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("LSTMFusedCellGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ + LSTMFusedCellGradOp<CPUDevice, T, false>); +REGISTER_KERNEL(float); +// REGISTER_KERNEL(double); +#undef REGISTER_KERNEL + +#if GOOGLE_CUDA +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void LSTMFusedCellBprop<GPUDevice, T, true>::operator()( \ + OpKernelContext* ctx, perftools::gputools::Stream* stream, \ + const GPUDevice& d, bool use_peephole, \ + typename TTypes<T>::ConstMatrix x, \ + typename TTypes<T>::ConstMatrix cs_prev, \ + typename TTypes<T>::ConstMatrix h_prev, \ + typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \ + typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \ + typename 
TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i, \ + typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f, \ + typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci, \ + typename TTypes<T>::ConstMatrix co, \ + typename TTypes<T>::ConstMatrix cs_grad, \ + typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \ + typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, \ + typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, \ + typename TTypes<T>::Matrix dicfo, \ + typename TTypes<T>::Matrix cs_prev_grad, \ + typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, \ + typename TTypes<T>::Vec wco_grad); \ + \ + extern template struct LSTMFusedCellBprop<GPUDevice, T, true>; + +DECLARE_GPU_SPEC(float); +// DECLARE_GPU_SPEC(double); +#undef DECLARE_GPU_SPEC +} // namespace functor + +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("LSTMFusedCellGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ + LSTMFusedCellGradOp<GPUDevice, T, true>); + +REGISTER_GPU_KERNEL(float); +// REGISTER_GPU_KERNEL(double); +#undef REGISTER_GPU_KERNEL +#endif // GOOGLE_CUDA + +template <typename Device, typename T, bool USE_CUBLAS> +class FusedLSTMOp : public OpKernel { + public: + explicit FusedLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_len", &max_len_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* seq_len_max_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor)); + + OpInputList x_list; + OP_REQUIRES_OK(ctx, ctx->input_list("x", &x_list)); + const int64 batch_size = x_list[0].dim_size(0); + const int64 input_size = x_list[0].dim_size(1); + + const Tensor* cs_prev_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor)); + + const Tensor* h_prev_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor)); + + const Tensor* w_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor)); + + const Tensor* wci_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor)); + + const Tensor* wcf_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor)); + + const Tensor* wco_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor)); + + const Tensor* b_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor)); + const int64 cell_size = b_tensor->dim_size(0) / 4; + + OpOutputList i_list; + OP_REQUIRES_OK(ctx, ctx->output_list("i", &i_list)); + + OpOutputList cs_list; + OP_REQUIRES_OK(ctx, ctx->output_list("cs", &cs_list)); + + OpOutputList f_list; + OP_REQUIRES_OK(ctx, ctx->output_list("f", &f_list)); + + OpOutputList o_list; + OP_REQUIRES_OK(ctx, ctx->output_list("o", &o_list)); + + OpOutputList ci_list; + OP_REQUIRES_OK(ctx, ctx->output_list("ci", &ci_list)); + + OpOutputList co_list; + OP_REQUIRES_OK(ctx, ctx->output_list("co", &co_list)); + + OpOutputList h_list; + OP_REQUIRES_OK(ctx, ctx->output_list("h", &h_list)); + + TensorShape batch_cell_shape({batch_size, cell_size}); + for (int64 t = 0; t < max_len_; ++t) { + Tensor* i_tensor = nullptr; + OP_REQUIRES_OK(ctx, i_list.allocate(t, batch_cell_shape, &i_tensor)); + + Tensor* cs_tensor = nullptr; + OP_REQUIRES_OK(ctx, cs_list.allocate(t, batch_cell_shape, 
&cs_tensor)); + + Tensor* f_tensor = nullptr; + OP_REQUIRES_OK(ctx, f_list.allocate(t, batch_cell_shape, &f_tensor)); + + Tensor* o_tensor = nullptr; + OP_REQUIRES_OK(ctx, o_list.allocate(t, batch_cell_shape, &o_tensor)); + + Tensor* ci_tensor = nullptr; + OP_REQUIRES_OK(ctx, ci_list.allocate(t, batch_cell_shape, &ci_tensor)); + + Tensor* co_tensor = nullptr; + OP_REQUIRES_OK(ctx, co_list.allocate(t, batch_cell_shape, &co_tensor)); + + Tensor* h_tensor = nullptr; + OP_REQUIRES_OK(ctx, h_list.allocate(t, batch_cell_shape, &h_tensor)); + } + + Tensor xh_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_temp( + DataTypeToEnum<T>::v(), + TensorShape({batch_size, input_size + cell_size}), + &xh_tensor)); + + Tensor icfo_tensor; + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DataTypeToEnum<T>::v(), + TensorShape({batch_size, cell_size * 4}), + &icfo_tensor)); + + const Device& device = ctx->eigen_device<Device>(); + perftools::gputools::Stream* stream = + std::is_same<Device, GPUDevice>::value + ? ctx->op_device_context()->stream() + : nullptr; + + const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()(); + for (int64 t = 0; t < seq_len_max; ++t) { + const Tensor& x_tensor = x_list[t]; + const Tensor& cs_prev_tensor2 = + t == 0 ? *cs_prev_tensor : *cs_list[t - 1]; + const Tensor& h_prev_tensor2 = t == 0 ? *h_prev_tensor : *h_list[t - 1]; + + Tensor* i_tensor = i_list[t]; + Tensor* cs_tensor = cs_list[t]; + Tensor* f_tensor = f_list[t]; + Tensor* o_tensor = o_list[t]; + Tensor* ci_tensor = ci_list[t]; + Tensor* co_tensor = co_list[t]; + Tensor* h_tensor = h_list[t]; + + functor::LSTMFusedCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size, + cell_size)( + ctx, stream, device, forget_bias_, cell_clip_, use_peephole_, + x_tensor.matrix<T>(), cs_prev_tensor2.matrix<T>(), + h_prev_tensor2.matrix<T>(), w_tensor->matrix<T>(), + wci_tensor->vec<T>(), wcf_tensor->vec<T>(), wco_tensor->vec<T>(), + b_tensor->vec<T>(), xh_tensor.matrix<T>(), i_tensor->matrix<T>(), + cs_tensor->matrix<T>(), f_tensor->matrix<T>(), o_tensor->matrix<T>(), + ci_tensor->matrix<T>(), co_tensor->matrix<T>(), + icfo_tensor.matrix<T>(), h_tensor->matrix<T>()); + } + + for (int64 t = seq_len_max; t < max_len_; ++t) { + Tensor* cs_tensor = cs_list[t]; + Tensor* h_tensor = h_list[t]; + + functor::TensorZero<Device, T>()(device, cs_tensor->flat<float>()); + functor::TensorZero<Device, T>()(device, h_tensor->flat<float>()); + } + } + + private: + int64 max_len_; + float forget_bias_; + float cell_clip_; + bool use_peephole_; +}; + +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("FusedLSTM").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ + FusedLSTMOp<CPUDevice, T, false>); +REGISTER_KERNEL(float); +// REGISTER_KERNEL(double); +#undef REGISTER_KERNEL + +#if GOOGLE_CUDA +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void TensorZero<GPUDevice, T>::operator()(const GPUDevice& d, \ + typename TTypes<T>::Flat t); \ + \ + extern template struct TensorZero<GPUDevice, T>; + +DECLARE_GPU_SPEC(float); +// DECLARE_GPU_SPEC(double); +#undef DECLARE_GPU_SPEC +} // end namespace functor + +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("FusedLSTM") \ + .Device(DEVICE_GPU) \ + .HostMemory("seq_len_max") \ + .TypeConstraint<T>("T"), \ + FusedLSTMOp<GPUDevice, T, true>); + +REGISTER_GPU_KERNEL(float); +// REGISTER_GPU_KERNEL(double); +#undef REGISTER_GPU_KERNEL +#endif // GOOGLE_CUDA + +template <typename Device, typename T, bool USE_CUBLAS> +class FusedLSTMGradOp : public OpKernel { + public: + 
explicit FusedLSTMGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_len", &max_len_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* seq_len_max_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor)); + + OpInputList x_list; + OP_REQUIRES_OK(ctx, ctx->input_list("x", &x_list)); + const int64 batch_size = x_list[0].dim_size(0); + const int64 input_size = x_list[0].dim_size(1); + + const Tensor* cs_prev_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor)); + + const Tensor* h_prev_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor)); + + const Tensor* w_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor)); + const int64 cell_size = w_tensor->dim_size(1) / 4; + OP_REQUIRES(ctx, input_size + cell_size == w_tensor->dim_size(0), + errors::InvalidArgument("w matrix rows don't match: ", + input_size + cell_size, " vs. ", + w_tensor->dim_size(0))); + + const Tensor* wci_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor)); + + const Tensor* wcf_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor)); + + const Tensor* wco_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor)); + + const Tensor* b_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor)); + OP_REQUIRES( + ctx, cell_size == b_tensor->dim_size(0) / 4, + errors::InvalidArgument("w and b cell_size don't match: ", cell_size, + " vs. ", b_tensor->dim_size(0))); + + OpInputList i_list; + OP_REQUIRES_OK(ctx, ctx->input_list("i", &i_list)); + + OpInputList cs_list; + OP_REQUIRES_OK(ctx, ctx->input_list("cs", &cs_list)); + + OpInputList f_list; + OP_REQUIRES_OK(ctx, ctx->input_list("f", &f_list)); + + OpInputList o_list; + OP_REQUIRES_OK(ctx, ctx->input_list("o", &o_list)); + + OpInputList ci_list; + OP_REQUIRES_OK(ctx, ctx->input_list("ci", &ci_list)); + + OpInputList co_list; + OP_REQUIRES_OK(ctx, ctx->input_list("co", &co_list)); + + OpInputList h_list; + OP_REQUIRES_OK(ctx, ctx->input_list("h", &h_list)); + + OpInputList cs_grad_list; + OP_REQUIRES_OK(ctx, ctx->input_list("cs_grad", &cs_grad_list)); + + OpInputList h_grad_list; + OP_REQUIRES_OK(ctx, ctx->input_list("h_grad", &h_grad_list)); + + OpOutputList x_grad_list; + OP_REQUIRES_OK(ctx, ctx->output_list("x_grad", &x_grad_list)); + + Tensor* cs_prev_grad_tensor = nullptr; + OP_REQUIRES_OK(ctx, + ctx->allocate_output("cs_prev_grad", cs_prev_tensor->shape(), + &cs_prev_grad_tensor)); + + Tensor* h_prev_grad_tensor = nullptr; + OP_REQUIRES_OK(ctx, + ctx->allocate_output("h_prev_grad", h_prev_tensor->shape(), + &h_prev_grad_tensor)); + + Tensor* w_grad_tensor = nullptr; + OP_REQUIRES_OK( + ctx, ctx->allocate_output("w_grad", w_tensor->shape(), &w_grad_tensor)); + + Tensor* wci_grad_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output("wci_grad", wci_tensor->shape(), + &wci_grad_tensor)); + + Tensor* wcf_grad_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output("wcf_grad", wcf_tensor->shape(), + &wcf_grad_tensor)); + + Tensor* wco_grad_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output("wco_grad", wco_tensor->shape(), + &wco_grad_tensor)); + + Tensor* b_grad_tensor = nullptr; + OP_REQUIRES_OK( + ctx, ctx->allocate_output("b_grad", b_tensor->shape(), &b_grad_tensor)); + + TensorShape batch_input_shape({batch_size, input_size}); + TensorShape batch_cell_shape({batch_size, 
cell_size}); + for (int64 t = 0; t < max_len_; ++t) { + Tensor* x_grad_tensor = nullptr; + OP_REQUIRES_OK( + ctx, x_grad_list.allocate(t, batch_input_shape, &x_grad_tensor)); + } + + Tensor xh_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_temp( + DataTypeToEnum<T>::v(), + TensorShape({batch_size, input_size + cell_size}), + &xh_tensor)); + + Tensor xh_grad_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(), + xh_tensor.shape(), &xh_grad_tensor)); + + Tensor do_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(), + batch_cell_shape, &do_tensor)); + + Tensor dcs_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(), + batch_cell_shape, &dcs_tensor)); + + Tensor dci_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(), + batch_cell_shape, &dci_tensor)); + + Tensor df_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(), + batch_cell_shape, &df_tensor)); + + Tensor di_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(), + batch_cell_shape, &di_tensor)); + + Tensor dicfo_tensor; + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DataTypeToEnum<T>::v(), + TensorShape({batch_size, cell_size * 4}), + &dicfo_tensor)); + + Tensor cs_grad_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(), + batch_cell_shape, &cs_grad_tensor)); + + Tensor h_grad_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(), + batch_cell_shape, &h_grad_tensor)); + + + const Device& device = ctx->eigen_device<Device>(); + perftools::gputools::Stream* stream = + std::is_same<Device, GPUDevice>::value + ? ctx->op_device_context()->stream() + : nullptr; + + functor::TensorZero<Device, T>()(device, cs_grad_tensor.flat<float>()); + functor::TensorZero<Device, T>()(device, + cs_prev_grad_tensor->flat<float>()); + functor::TensorZero<Device, T>()(device, h_grad_tensor.flat<float>()); + functor::TensorZero<Device, T>()(device, h_prev_grad_tensor->flat<float>()); + functor::TensorZero<Device, T>()(device, w_grad_tensor->flat<float>()); + functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<float>()); + functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<float>()); + functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<float>()); + functor::TensorZero<Device, T>()(device, b_grad_tensor->flat<float>()); + + const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()(); + for (int64 t = seq_len_max - 1; t >= 0; --t) { + const Tensor& x_tensor = x_list[t]; + const Tensor& cs_prev_tensor2 = t == 0 ? *cs_prev_tensor : cs_list[t - 1]; + const Tensor& h_prev_tensor2 = t == 0 ? *h_prev_tensor : h_list[t - 1]; + const Tensor& i_tensor = i_list[t]; + const Tensor& cs_tensor = cs_list[t]; + const Tensor& f_tensor = f_list[t]; + const Tensor& o_tensor = o_list[t]; + const Tensor& ci_tensor = ci_list[t]; + const Tensor& co_tensor = co_list[t]; + + // Grab previous CS grad. + const Tensor& const_cs_prev_grad_tensor = *cs_prev_grad_tensor; + functor::TensorAdd<Device, T>()( + device, const_cs_prev_grad_tensor.flat<T>(), + cs_grad_list[t].flat<T>(), cs_grad_tensor.flat<T>()); + + // Combine previous h grad and h grad coming on top. 
+ const Tensor& const_h_prev_grad_tensor = *h_prev_grad_tensor; + functor::TensorAdd<Device, T>()( + device, const_h_prev_grad_tensor.flat<T>(), h_grad_list[t].flat<T>(), + h_grad_tensor.flat<T>()); + + const Tensor& const_cs_grad_tensor = cs_grad_tensor; + const Tensor& const_h_grad_tensor = h_grad_tensor; + + Tensor* x_grad_tensor = x_grad_list[t]; + functor::FusedLSTMBprop<Device, T, USE_CUBLAS>(batch_size, input_size, + cell_size)( + ctx, stream, device, use_peephole_, x_tensor.matrix<T>(), + cs_prev_tensor2.matrix<T>(), h_prev_tensor2.matrix<T>(), + w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(), + wco_tensor->vec<T>(), b_tensor->vec<T>(), xh_tensor.matrix<T>(), + i_tensor.matrix<T>(), cs_tensor.matrix<T>(), f_tensor.matrix<T>(), + o_tensor.matrix<T>(), ci_tensor.matrix<T>(), co_tensor.matrix<T>(), + const_cs_grad_tensor.matrix<T>(), const_h_grad_tensor.matrix<T>(), + do_tensor.matrix<T>(), dcs_tensor.matrix<T>(), dci_tensor.matrix<T>(), + df_tensor.matrix<T>(), di_tensor.matrix<T>(), + dicfo_tensor.matrix<T>(), cs_prev_grad_tensor->matrix<T>(), + h_prev_grad_tensor->matrix<T>(), xh_grad_tensor.matrix<T>(), + x_grad_tensor->matrix<T>(), w_grad_tensor->matrix<T>(), + wci_grad_tensor->vec<T>(), wcf_grad_tensor->vec<T>(), + wco_grad_tensor->vec<T>(), b_grad_tensor->vec<T>()); + } + + for (int64 t = seq_len_max; t < max_len_; ++t) { + Tensor* x_grad_tensor = x_grad_list[t]; + functor::TensorZero<Device, T>()(device, x_grad_tensor->flat<T>()); + } + } + + private: + int64 max_len_; + bool use_peephole_; +}; + +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("FusedLSTMGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ + FusedLSTMGradOp<CPUDevice, T, false>); +REGISTER_KERNEL(float); +// REGISTER_KERNEL(double); +#undef REGISTER_KERNEL + +#if GOOGLE_CUDA +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void TensorCopy<GPUDevice, T>::operator()(const GPUDevice& d, \ + typename TTypes<T>::ConstFlat src, \ + typename TTypes<T>::Flat dst); \ + \ + template <> \ + void TensorAdd<GPUDevice, T>::operator()( \ + const GPUDevice& d, typename TTypes<T>::ConstFlat a, \ + typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c); \ + \ + template <> \ + void FusedLSTMBprop<GPUDevice, T, true>::operator()( \ + OpKernelContext* ctx, perftools::gputools::Stream* stream, \ + const GPUDevice& d, bool use_peephole, \ + typename TTypes<T>::ConstMatrix x, \ + typename TTypes<T>::ConstMatrix cs_prev, \ + typename TTypes<T>::ConstMatrix h_prev, \ + typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \ + typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \ + typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \ + typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs, \ + typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o, \ + typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co, \ + typename TTypes<T>::ConstMatrix cs_grad, \ + typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \ + typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, \ + typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, \ + typename TTypes<T>::Matrix dicfo, \ + typename TTypes<T>::Matrix cs_prev_grad, \ + typename TTypes<T>::Matrix h_prev_grad, \ + typename TTypes<T>::Matrix xh_grad, typename TTypes<T>::Matrix x_grad, \ + typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad, \ + typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec 
wco_grad, \ + typename TTypes<T>::Vec b_grad); \ + \ + extern template struct TensorCopy<GPUDevice, T>; \ + extern template struct TensorAdd<GPUDevice, T>; \ + extern template struct FusedLSTMBprop<GPUDevice, T, true>; + +DECLARE_GPU_SPEC(float); +// DECLARE_GPU_SPEC(double); +#undef DECLARE_GPU_SPEC +} // end namespace functor + +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("FusedLSTMGrad") \ + .Device(DEVICE_GPU) \ + .HostMemory("seq_len_max") \ + .TypeConstraint<T>("T"), \ + FusedLSTMGradOp<GPUDevice, T, true>); + +REGISTER_GPU_KERNEL(float); +// REGISTER_GPU_KERNEL(double); +#undef REGISTER_GPU_KERNEL +#endif // GOOGLE_CUDA + +} // end namespace tensorflow diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.h b/tensorflow/contrib/rnn/kernels/lstm_ops.h new file mode 100644 index 00000000000..bcb7bfa1e6e --- /dev/null +++ b/tensorflow/contrib/rnn/kernels/lstm_ops.h @@ -0,0 +1,420 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/eigen_activations.h" +#include "tensorflow/core/platform/types.h" + +namespace perftools { +namespace gputools { +class Stream; +} // end namespace gputools +} // end namespace perftools + +namespace tensorflow { +class OpKernelContext; + +namespace functor { + +template <typename Device, typename T> +struct TensorZero { + void operator()(const Device& d, typename TTypes<T>::Flat t) { + t.device(d) = t.constant(T(0)); + } +}; + +template <typename Device, typename T> +struct TensorCopy { + void operator()(const Device& d, typename TTypes<T>::ConstFlat src, + typename TTypes<T>::Flat dst) { + dst.device(d) = src; + } +}; + +template <typename Device, typename T> +struct TensorAdd { + void operator()(const Device& d, typename TTypes<T>::ConstFlat a, + typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c) { + c.device(d) = a + b; + } +}; + +template <typename Device, typename T> +struct TensorZeroPadding { + void operator()(const Device& d, const int64 time_idx, + typename TTypes<int64>::ConstVec seq_len, + typename TTypes<float>::Vec mask, + typename TTypes<float>::Matrix m) { + // mask is shape [batch_size]. + mask.device(d) = seq_len.constant(time_idx) < seq_len; + + // m_shape is [batch_size, 1]. + Eigen::array<Eigen::DenseIndex, 2> m_shape({m.dimensions()[0], 1}); + // broadcast_shape is [1, units]. + Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({1, m.dimensions()[1]}); + + // m is shape [batch_size, units]. 
+ m.device(d) = m * mask.reshape(m_shape).broadcast(broadcast_shape); + } +}; + +template <typename T> +struct TensorCuBlasGemm { + void operator()(OpKernelContext* ctx, perftools::gputools::Stream* stream, + bool transa, bool transb, uint64 m, uint64 n, uint64 k, + T alpha, const T* a, int lda, const T* b, int ldb, T beta, + T* c, int ldc); +}; + +template <typename Device, typename T, bool USE_CUBLAS> +struct TensorBlasGemm; + +template <typename Device, typename T> +struct TensorBlasGemm<Device, T, true /* USE_CUBLAS */> { + static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream, + const Device& d, bool transa, bool transb, T alpha, + typename TTypes<T>::ConstMatrix a, + typename TTypes<T>::ConstMatrix b, T beta, + typename TTypes<T>::Matrix c) { + int64 m = c.dimensions()[0]; + int64 n = c.dimensions()[1]; + int64 k = transa ? a.dimensions()[0] : a.dimensions()[1]; + + TensorCuBlasGemm<T>()(ctx, stream, transb, transa, n, m, k, alpha, b.data(), + transb ? k : n, a.data(), transa ? m : k, beta, + c.data(), n); + } +}; + +template <typename Device, typename T> +struct TensorBlasGemm<Device, T, false /* USE_CUBLAS */> { + static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream, + const Device& d, bool transa, bool transb, T alpha, + typename TTypes<T>::ConstMatrix a, + typename TTypes<T>::ConstMatrix b, T beta, + typename TTypes<T>::Matrix c) { + Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs; + contract_pairs[0] = + Eigen::IndexPair<Eigen::DenseIndex>(transa == false, transb == true); + if (alpha == T(1) && beta == T(0)) { + c.device(d) = a.contract(b, contract_pairs); + } else if (alpha == T(1) && beta == T(1)) { + c.device(d) += a.contract(b, contract_pairs); + } else { + c.device(d) = c.constant(alpha) * a.contract(b, contract_pairs) + + c.constant(beta) * c; + } + } +}; + +struct LSTMFusedCell { + LSTMFusedCell(const int batch_size, const int input_size, const int cell_size) + : batch_size_(batch_size), + input_size_(input_size), + cell_size_(cell_size) {} + + inline Eigen::array<Eigen::DenseIndex, 2> icfo_i_offsets() const { + return {0, 0}; + } + + inline Eigen::array<Eigen::DenseIndex, 2> icfo_c_offsets() const { + return {0, cell_size_}; + } + + inline Eigen::array<Eigen::DenseIndex, 2> icfo_f_offsets() const { + return {0, cell_size_ * 2}; + } + + inline Eigen::array<Eigen::DenseIndex, 2> icfo_o_offsets() const { + return {0, cell_size_ * 3}; + } + + inline Eigen::array<Eigen::DenseIndex, 2> cell_extents() const { + return {batch_size_, cell_size_}; + } + + inline Eigen::array<Eigen::DenseIndex, 2> xh_x_offsets() const { + return {0, 0}; + } + + inline Eigen::array<Eigen::DenseIndex, 2> xh_x_extents() const { + return {batch_size_, input_size_}; + } + + inline Eigen::array<Eigen::DenseIndex, 2> xh_h_offsets() const { + return {0, input_size_}; + } + + inline Eigen::array<Eigen::DenseIndex, 2> xh_h_extents() const { + return {batch_size_, cell_size_}; + } + + protected: + const int batch_size_; + const int input_size_; + const int cell_size_; +}; + +template <typename Device, typename T, bool USE_CUBLAS> +struct LSTMFusedCellFprop : public LSTMFusedCell { + LSTMFusedCellFprop(const int batch_size, const int input_size, + const int cell_size) + : LSTMFusedCell(batch_size, input_size, cell_size) {} + + void operator()(OpKernelContext* ctx, perftools::gputools::Stream* stream, + const Device& d, const T forget_bias, const T cell_clip, + bool use_peephole, typename TTypes<T>::ConstMatrix x, + typename 
TTypes<T>::ConstMatrix cs_prev, + typename TTypes<T>::ConstMatrix h_prev, + typename TTypes<T>::ConstMatrix w, + typename TTypes<T>::ConstVec wci, + typename TTypes<T>::ConstVec wcf, + typename TTypes<T>::ConstVec wco, + typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, + typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, + typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, + typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, + typename TTypes<T>::Matrix icfo, + typename TTypes<T>::Matrix h) { + // Concat xh = [x, h]. + xh.slice(xh_x_offsets(), xh_x_extents()).device(d) = x; + xh.slice(xh_h_offsets(), xh_h_extents()).device(d) = h_prev; + + // states1 = xh * w + b + typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions()); + TensorBlasGemm<Device, T, USE_CUBLAS>::compute( + ctx, stream, d, false, false, T(1), const_xh, w, T(0), icfo); + Eigen::array<Eigen::DenseIndex, 2> b_shape({1, b.dimensions()[0]}); + Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({batch_size_, 1}); + icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape); + + Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell_size_}); + Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({batch_size_, 1}); + + // Input gate. + if (use_peephole) { + auto i_peep = cs_prev * wci.reshape(p_shape).broadcast(p_broadcast_shape); + i.device(d) = + (icfo.slice(icfo_i_offsets(), cell_extents()) + i_peep).sigmoid(); + } else { + i.device(d) = icfo.slice(icfo_i_offsets(), cell_extents()).sigmoid(); + } + + // Cell input. + ci.device(d) = icfo.slice(icfo_c_offsets(), cell_extents()).tanh(); + + // Forget gate (w/ bias). + if (use_peephole) { + auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape); + f.device(d) = (icfo.slice(icfo_f_offsets(), cell_extents()) + + f.constant(forget_bias) + f_peep) + .sigmoid(); + } else { + f.device(d) = (icfo.slice(icfo_f_offsets(), cell_extents()) + + f.constant(forget_bias)) + .sigmoid(); + } + + // cs = ci .* i + f .* cs_prev + cs.device(d) = i * ci + f * cs_prev; + + if (cell_clip > 0.0f) { + cs.device(d) = + cs.binaryExpr(cs.constant(cell_clip), Eigen::scalar_clip_op<T>()); + } + + // co = tanh(cs) + co.device(d) = cs.tanh(); + + // Output gate. 
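+    // With peepholes, o = sigmoid(icfo_o + cs .* wco); note that the output
+    // gate peeks at the *new* cell state cs, whereas the input and forget
+    // gates above peek at cs_prev.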
+ if (use_peephole) { + auto o_peep = cs * wco.reshape(p_shape).broadcast(p_broadcast_shape); + o.device(d) = + (icfo.slice(icfo_o_offsets(), cell_extents()) + o_peep).sigmoid(); + } else { + o.device(d) = icfo.slice(icfo_o_offsets(), cell_extents()).sigmoid(); + } + + // h = o .* co + h.device(d) = o * co; + } +}; + +template <typename Device, typename T, bool USE_CUBLAS> +struct LSTMFusedCellBprop : public LSTMFusedCell { + LSTMFusedCellBprop(const int batch_size, const int input_size, + const int cell_size) + : LSTMFusedCell(batch_size, input_size, cell_size) {} + + void operator()( + OpKernelContext* ctx, perftools::gputools::Stream* stream, + const Device& d, bool use_peephole, typename TTypes<T>::ConstMatrix x, + typename TTypes<T>::ConstMatrix cs_prev, + typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w, + typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf, + typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b, + typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs, + typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o, + typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co, + typename TTypes<T>::ConstMatrix cs_grad, + typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, + typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, + typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, + typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad, + typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, + typename TTypes<T>::Vec wco_grad) { + // do[t] = sigm'(o[t]) .* dh[t] .* co[t] + do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co; + + // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1] + dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad; + + Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell_size_}); + Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({batch_size_, 1}); + if (use_peephole) { + dcs.device(d) = + dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape); + } + + // dci[t] = tanh'(ci[t]) dcs[t] i[t] + dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i; + + // df[t] = sigm'(f[t]) dcs[t] cs[t - 1] + df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev; + + // di[t] = sigm'(i[t]) dcs[t] ci[t] + di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci; + + dicfo.slice(icfo_i_offsets(), cell_extents()).device(d) = di; + dicfo.slice(icfo_c_offsets(), cell_extents()).device(d) = dci; + dicfo.slice(icfo_f_offsets(), cell_extents()).device(d) = df; + dicfo.slice(icfo_o_offsets(), cell_extents()).device(d) = do_; + + cs_prev_grad.device(d) = dcs * f; + if (use_peephole) { + cs_prev_grad.device(d) = + cs_prev_grad + + di * wci.reshape(p_shape).broadcast(p_broadcast_shape) + + df * wcf.reshape(p_shape).broadcast(p_broadcast_shape); + } + + if (use_peephole) { + wci_grad.device(d) = (di * cs_prev).sum(Eigen::array<int, 1>({0})); + wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array<int, 1>({0})); + wco_grad.device(d) = (do_ * cs).sum(Eigen::array<int, 1>({0})); + } + } +}; + +template <typename Device, typename T, bool USE_CUBLAS> +struct FusedLSTMBprop : public LSTMFusedCell { + FusedLSTMBprop(const int batch_size, const int input_size, + const int cell_size) + : LSTMFusedCell(batch_size, input_size, cell_size) {} + + void operator()( + OpKernelContext* ctx, perftools::gputools::Stream* stream, + const Device& d, bool use_peephole, typename 
TTypes<T>::ConstMatrix x, + typename TTypes<T>::ConstMatrix cs_prev, + typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w, + typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf, + typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b, + typename TTypes<T>::Matrix xh, typename TTypes<T>::ConstMatrix i, + typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f, + typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci, + typename TTypes<T>::ConstMatrix co, + typename TTypes<T>::ConstMatrix cs_grad, + typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, + typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, + typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, + typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad, + typename TTypes<T>::Matrix h_prev_grad, + typename TTypes<T>::Matrix xh_grad, typename TTypes<T>::Matrix x_grad, + typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad, + typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad, + typename TTypes<T>::Vec b_grad) { + // do[t] = sigm'(o[t]) .* dh[t] .* co[t] + do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co; + + // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1] + dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad; + + Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell_size_}); + Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({batch_size_, 1}); + if (use_peephole) { + dcs.device(d) = + dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape); + } + + // dci[t] = tanh'(ci[t]) dcs[t] i[t] + dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i; + + // df[t] = sigm'(f[t]) dcs[t] cs[t - 1] + df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev; + + // di[t] = sigm'(i[t]) dcs[t] ci[t] + di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci; + + dicfo.slice(icfo_i_offsets(), cell_extents()).device(d) = di; + dicfo.slice(icfo_c_offsets(), cell_extents()).device(d) = dci; + dicfo.slice(icfo_f_offsets(), cell_extents()).device(d) = df; + dicfo.slice(icfo_o_offsets(), cell_extents()).device(d) = do_; + + cs_prev_grad.device(d) = dcs * f; + if (use_peephole) { + cs_prev_grad.device(d) = + cs_prev_grad + + di * wci.reshape(p_shape).broadcast(p_broadcast_shape) + + df * wcf.reshape(p_shape).broadcast(p_broadcast_shape); + } + + // xh_grad. + typename TTypes<T>::ConstMatrix const_dicfo(dicfo.data(), + dicfo.dimensions()); + TensorBlasGemm<Device, T, USE_CUBLAS>::compute( + ctx, stream, d, false, true, T(1), const_dicfo, w, T(0), xh_grad); + + // xh. + xh.slice(xh_x_offsets(), xh_x_extents()).device(d) = x; + xh.slice(xh_h_offsets(), xh_h_extents()).device(d) = h_prev; + typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions()); + + // x_grad. + x_grad.device(d) = xh_grad.slice(xh_x_offsets(), xh_x_extents()); + h_prev_grad.device(d) = xh_grad.slice(xh_h_offsets(), xh_h_extents()); + + // w_grad. + TensorBlasGemm<Device, T, USE_CUBLAS>::compute( + ctx, stream, d, true, false, T(1), const_xh, const_dicfo, T(1), w_grad); + + // b_grad. 
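+    // b_grad is the gate pre-activation gradient summed over the batch
+    // dimension (axis 0); the += accumulates that sum across time steps,
+    // as do the peephole gradients below.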
+ b_grad.device(d) += dicfo.sum(Eigen::array<int, 1>({0})); + + if (use_peephole) { + wci_grad.device(d) += (di * cs_prev).sum(Eigen::array<int, 1>({0})); + wcf_grad.device(d) += (df * cs_prev).sum(Eigen::array<int, 1>({0})); + wco_grad.device(d) += (do_ * cs).sum(Eigen::array<int, 1>({0})); + } + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_ diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc new file mode 100644 index 00000000000..2c5e500c289 --- /dev/null +++ b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc @@ -0,0 +1,41 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/contrib/rnn/kernels/lstm_ops.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::GpuDevice GPUDevice; + +#define DEFINE_GPU_SPECS(T) \ + template struct TensorZero<GPUDevice, T>; \ + template struct TensorCopy<GPUDevice, T>; \ + template struct TensorAdd<GPUDevice, T>; \ + template struct LSTMFusedCellFprop<GPUDevice, T, true>; \ + template struct LSTMFusedCellBprop<GPUDevice, T, true>; \ + template struct FusedLSTMBprop<GPUDevice, T, true>; + +DEFINE_GPU_SPECS(float); +// DEFINE_GPU_SPECS(double); +#undef DEFINE_GPU_SPECS + +} // end namespace functor +} // end namespace tensorflow +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/rnn/ops/lstm_ops.cc b/tensorflow/contrib/rnn/ops/lstm_ops.cc new file mode 100644 index 00000000000..a55c6232886 --- /dev/null +++ b/tensorflow/contrib/rnn/ops/lstm_ops.cc @@ -0,0 +1,180 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("LSTMFusedCell") + .Input("x: T") + .Input("cs_prev: T") + .Input("h_prev: T") + .Input("w: T") + .Input("wci: T") + .Input("wcf: T") + .Input("wco: T") + .Input("b: T") + .Output("i: T") + .Output("cs: T") + .Output("f: T") + .Output("o: T") + .Output("ci: T") + .Output("co: T") + .Output("h: T") + .Attr("forget_bias: float = 1.0") + .Attr("cell_clip: float = 3.0") + .Attr("use_peephole: bool = false") + .Attr("T: {float}") + .Doc(R"doc( +Computes the LSTM cell forward propagation for 1 time step. 
+
+This implementation uses 1 weight matrix and 1 bias vector, there is no
+diagonal peephole connection.
+
+This kernel op implements the following mathematical equations:
+
+```python
+xh = [x, h_prev]
+[i, f, ci, o] = xh * w + b
+f = f + forget_bias
+
+i = sigmoid(i)
+f = sigmoid(f)
+ci = tanh(ci)
+o = sigmoid(o)
+
+cs = ci .* i + cs_prev .* f
+co = tanh(cs)
+
+h = co .* o
+```
+
+forget_bias: The forget gate bias.
+x: The input to the LSTM cell.
+w: The weight matrix.
+b: The bias vector.
+i: The input gate.
+cs: The cell state before the tanh.
+f: The forget gate.
+o: The output gate.
+ci: The cell input.
+co: The cell after the tanh.
+h: The output h vector.
+)doc");
+
+REGISTER_OP("LSTMFusedCellGrad")
+    .Input("x: T")
+    .Input("cs_prev: T")
+    .Input("h_prev: T")
+    .Input("w: T")
+    .Input("wci: T")
+    .Input("wcf: T")
+    .Input("wco: T")
+    .Input("b: T")
+    .Input("i: T")
+    .Input("cs: T")
+    .Input("f: T")
+    .Input("o: T")
+    .Input("ci: T")
+    .Input("co: T")
+    .Input("cs_grad: T")
+    .Input("h_grad: T")
+    .Output("cs_prev_grad: T")
+    .Output("dicfo: T")
+    .Output("wci_grad: T")
+    .Output("wcf_grad: T")
+    .Output("wco_grad: T")
+    .Attr("use_peephole: bool")
+    .Attr("T: {float}")
+    .Doc(R"doc(
+Computes the LSTM cell backward propagation for 1 timestep.
+
+This implementation is to be used in conjunction with LSTMFusedCell.
+
+x: The input to the LSTM cell.
+cs_prev: The previous cell state.
+h_prev: The previous h state.
+w: The weight matrix.
+b: The bias vector.
+i: The input gate.
+cs: The cell state before the tanh.
+f: The forget gate.
+o: The output gate.
+ci: The cell input.
+co: The cell after the tanh.
+h_grad: The gradient of the h vector.
+cs_prev_grad: The gradient of cs.
+dicfo: The derivative with respect to [i, cs, f, o].
+)doc");
+
+REGISTER_OP("FusedLSTM")
+    .Input("seq_len_max: int64")
+    .Input("x: max_len * T")
+    .Input("cs_prev: T")
+    .Input("h_prev: T")
+    .Input("w: T")
+    .Input("wci: T")
+    .Input("wcf: T")
+    .Input("wco: T")
+    .Input("b: T")
+    .Output("i: max_len * T")
+    .Output("cs: max_len * T")
+    .Output("f: max_len * T")
+    .Output("o: max_len * T")
+    .Output("ci: max_len * T")
+    .Output("co: max_len * T")
+    .Output("h: max_len * T")
+    .Attr("max_len: int")
+    .Attr("forget_bias: float = 1.0")
+    .Attr("cell_clip: float = 3.0")
+    .Attr("use_peephole: bool = false")
+    .Attr("T: {float}")
+    .Doc(R"doc(
+)doc");
+
+REGISTER_OP("FusedLSTMGrad")
+    .Input("seq_len_max: int64")
+    .Input("x: max_len * T")
+    .Input("cs_prev: T")
+    .Input("h_prev: T")
+    .Input("w: T")
+    .Input("wci: T")
+    .Input("wcf: T")
+    .Input("wco: T")
+    .Input("b: T")
+    .Input("i: max_len * T")
+    .Input("cs: max_len * T")
+    .Input("f: max_len * T")
+    .Input("o: max_len * T")
+    .Input("ci: max_len * T")
+    .Input("co: max_len * T")
+    .Input("h: max_len * T")
+    .Input("cs_grad: max_len * T")
+    .Input("h_grad: max_len * T")
+    .Output("x_grad: max_len * T")
+    .Output("cs_prev_grad: T")
+    .Output("h_prev_grad: T")
+    .Output("w_grad: T")
+    .Output("wci_grad: T")
+    .Output("wcf_grad: T")
+    .Output("wco_grad: T")
+    .Output("b_grad: T")
+    .Attr("max_len: int")
+    .Attr("use_peephole: bool")
+    .Attr("T: {float}")
+    .Doc(R"doc(
+)doc");
+
+} // end namespace tensorflow
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py
new file mode 100644
index 00000000000..70aeb5ff559
--- /dev/null
+++ b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py
@@ -0,0 +1,290 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""LSTM Fused Cell ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +import tensorflow as tf + +from tensorflow.contrib.rnn.python.ops import lstm_ops + + +fused_lstm = lstm_ops._fused_lstm # pylint: disable=protected-access + + +class LSTMFusedCellTest(tf.test.TestCase): + _use_gpu = False + + def testNoneDimsWithDynamicRNN(self): + with self.test_session(use_gpu=self._use_gpu, graph=tf.Graph()) as sess: + batch_size = 4 + num_steps = 5 + input_dim = 6 + cell_size = 7 + + cell = tf.contrib.rnn.LSTMFusedCell(cell_size) + x = tf.placeholder(tf.float32, shape=(None, None, input_dim)) + + output, _ = tf.nn.dynamic_rnn(cell, x, time_major=True, dtype=tf.float32) + sess.run(tf.initialize_all_variables()) + feed = {} + feed[x] = np.random.randn(num_steps, batch_size, input_dim) + sess.run(output, feed) + + def testLSTMFusedCell(self): + with self.test_session(use_gpu=self._use_gpu, graph=tf.Graph()) as sess: + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + x = tf.zeros([1, 2]) + m0 = tf.zeros([1, 2]) + m1 = tf.zeros([1, 2]) + m2 = tf.zeros([1, 2]) + m3 = tf.zeros([1, 2]) + g, ((out_m0, out_m1), (out_m2, out_m3)) = tf.nn.rnn_cell.MultiRNNCell( + [tf.contrib.rnn.LSTMFusedCell(2)] * 2, + state_is_tuple=True)(x, ((m0, m1), (m2, m3))) + sess.run([tf.initialize_all_variables()]) + res = sess.run([g, out_m0, out_m1, out_m2, out_m3], + {x.name: np.array([[1., 1.]]), + m0.name: 0.1 * np.ones([1, 2]), + m1.name: 0.1 * np.ones([1, 2]), + m2.name: 0.1 * np.ones([1, 2]), + m3.name: 0.1 * np.ones([1, 2])}) + self.assertEqual(len(res), 5) + self.assertAllClose(res[0], [[0.24024698, 0.24024698]]) + # These numbers are from testBasicLSTMCell and only test c/h. 
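+        # res[1] and res[2] are (c, h) of the first cell; res[3] and res[4]
+        # are (c, h) of the second cell, so the output res[0] equals res[4].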
+ self.assertAllClose(res[1], [[0.68967271, 0.68967271]]) + self.assertAllClose(res[2], [[0.44848421, 0.44848421]]) + self.assertAllClose(res[3], [[0.39897051, 0.39897051]]) + self.assertAllClose(res[4], [[0.24024698, 0.24024698]]) + + def testLSTMBasicToBlockCell(self): + with self.test_session(use_gpu=self._use_gpu) as sess: + x = tf.zeros([1, 2]) + x_values = np.random.randn(1, 2) + + m0_val = 0.1 * np.ones([1, 2]) + m1_val = -0.1 * np.ones([1, 2]) + m2_val = -0.2 * np.ones([1, 2]) + m3_val = 0.2 * np.ones([1, 2]) + + initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=19890212) + with tf.variable_scope("basic", initializer=initializer): + m0 = tf.zeros([1, 2]) + m1 = tf.zeros([1, 2]) + m2 = tf.zeros([1, 2]) + m3 = tf.zeros([1, 2]) + g, ((out_m0, out_m1), (out_m2, out_m3)) = tf.nn.rnn_cell.MultiRNNCell( + [tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=True)] * 2, + state_is_tuple=True)(x, ((m0, m1), (m2, m3))) + sess.run([tf.initialize_all_variables()]) + basic_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], + {x.name: x_values, + m0.name: m0_val, + m1.name: m1_val, + m2.name: m2_val, + m3.name: m3_val}) + + with tf.variable_scope("block", initializer=initializer): + m0 = tf.zeros([1, 2]) + m1 = tf.zeros([1, 2]) + m2 = tf.zeros([1, 2]) + m3 = tf.zeros([1, 2]) + g, ((out_m0, out_m1), (out_m2, out_m3)) = tf.nn.rnn_cell.MultiRNNCell( + [tf.contrib.rnn.LSTMFusedCell(2)] * 2, + state_is_tuple=True)(x, ((m0, m1), (m2, m3))) + sess.run([tf.initialize_all_variables()]) + block_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], + {x.name: x_values, + m0.name: m0_val, + m1.name: m1_val, + m2.name: m2_val, + m3.name: m3_val}) + + self.assertEqual(len(basic_res), len(block_res)) + for basic, block in zip(basic_res, block_res): + self.assertAllClose(basic, block) + + def testLSTMBasicToBlockCellPeeping(self): + with self.test_session(use_gpu=self._use_gpu) as sess: + x = tf.zeros([1, 2]) + x_values = np.random.randn(1, 2) + + m0_val = 0.1 * np.ones([1, 2]) + m1_val = -0.1 * np.ones([1, 2]) + m2_val = -0.2 * np.ones([1, 2]) + m3_val = 0.2 * np.ones([1, 2]) + + initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=19890212) + with tf.variable_scope("basic", initializer=initializer): + m0 = tf.zeros([1, 2]) + m1 = tf.zeros([1, 2]) + m2 = tf.zeros([1, 2]) + m3 = tf.zeros([1, 2]) + g, ((out_m0, out_m1), (out_m2, out_m3)) = tf.nn.rnn_cell.MultiRNNCell( + [tf.nn.rnn_cell.LSTMCell(2, + use_peepholes=True, + state_is_tuple=True)] * 2, + state_is_tuple=True)(x, ((m0, m1), (m2, m3))) + sess.run([tf.initialize_all_variables()]) + basic_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], + {x.name: x_values, + m0.name: m0_val, + m1.name: m1_val, + m2.name: m2_val, + m3.name: m3_val}) + + with tf.variable_scope("block", initializer=initializer): + m0 = tf.zeros([1, 2]) + m1 = tf.zeros([1, 2]) + m2 = tf.zeros([1, 2]) + m3 = tf.zeros([1, 2]) + g, ((out_m0, out_m1), (out_m2, out_m3)) = tf.nn.rnn_cell.MultiRNNCell( + [tf.contrib.rnn.LSTMFusedCell(2, use_peephole=True)] * 2, + state_is_tuple=True)(x, ((m0, m1), (m2, m3))) + sess.run([tf.initialize_all_variables()]) + block_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], + {x.name: x_values, + m0.name: m0_val, + m1.name: m1_val, + m2.name: m2_val, + m3.name: m3_val}) + + self.assertEqual(len(basic_res), len(block_res)) + for basic, block in zip(basic_res, block_res): + self.assertAllClose(basic, block) + + def testLSTMBasicToBlock(self): + with self.test_session(use_gpu=self._use_gpu) as sess: + batch_size = 2 + input_size = 3 + cell_size = 
4 + sequence_length = 5 + + inputs = [] + for _ in range(sequence_length): + inp = tf.convert_to_tensor( + np.random.randn(batch_size, input_size), + dtype=tf.float32) + inputs.append(inp) + + initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=19890212) + with tf.variable_scope("basic", initializer=initializer): + cell = tf.nn.rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True) + outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32) + + sess.run([tf.initialize_all_variables()]) + basic_outputs = sess.run(outputs) + basic_grads = sess.run(tf.gradients(outputs, inputs)) + basic_wgrads = sess.run(tf.gradients(outputs, tf.trainable_variables())) + + with tf.variable_scope("block", initializer=initializer): + w = tf.get_variable("w", + shape=[input_size + cell_size, cell_size * 4], + dtype=tf.float32) + b = tf.get_variable("b", + shape=[cell_size * 4], + dtype=tf.float32, + initializer=tf.zeros_initializer) + + _, _, _, _, _, _, outputs = fused_lstm( + tf.convert_to_tensor(sequence_length, + dtype=tf.int64), + inputs, + w, + b, + cell_clip=0) + + sess.run([tf.initialize_all_variables()]) + block_outputs = sess.run(outputs) + block_grads = sess.run(tf.gradients(outputs, inputs)) + block_wgrads = sess.run(tf.gradients(outputs, [w, b])) + + self.assertAllClose(basic_outputs, block_outputs) + self.assertAllClose(basic_grads, block_grads) + for basic, block in zip(basic_wgrads, block_wgrads): + self.assertAllClose(basic, block, rtol=1e-2, atol=1e-2) + + def testLSTMBasicToBlockPeeping(self): + with self.test_session(use_gpu=self._use_gpu) as sess: + batch_size = 2 + input_size = 3 + cell_size = 4 + sequence_length = 5 + + inputs = [] + for _ in range(sequence_length): + inp = tf.convert_to_tensor( + np.random.randn(batch_size, input_size), + dtype=tf.float32) + inputs.append(inp) + + initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=19890212) + with tf.variable_scope("basic", initializer=initializer): + cell = tf.nn.rnn_cell.LSTMCell(cell_size, + use_peepholes=True, + state_is_tuple=True) + outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32) + + sess.run([tf.initialize_all_variables()]) + basic_outputs = sess.run(outputs) + basic_grads = sess.run(tf.gradients(outputs, inputs)) + basic_wgrads = sess.run(tf.gradients(outputs, tf.trainable_variables())) + + with tf.variable_scope("block", initializer=initializer): + w = tf.get_variable("w", + shape=[input_size + cell_size, cell_size * 4], + dtype=tf.float32) + b = tf.get_variable("b", + shape=[cell_size * 4], + dtype=tf.float32, + initializer=tf.zeros_initializer) + + wci = tf.get_variable("wci", shape=[cell_size], dtype=tf.float32) + wcf = tf.get_variable("wcf", shape=[cell_size], dtype=tf.float32) + wco = tf.get_variable("wco", shape=[cell_size], dtype=tf.float32) + + _, _, _, _, _, _, outputs = fused_lstm( + tf.convert_to_tensor(sequence_length, + dtype=tf.int64), + inputs, + w, + b, + wci=wci, + wcf=wcf, + wco=wco, + cell_clip=0, + use_peephole=True) + + sess.run([tf.initialize_all_variables()]) + block_outputs = sess.run(outputs) + block_grads = sess.run(tf.gradients(outputs, inputs)) + block_wgrads = sess.run(tf.gradients(outputs, [w, b, wci, wcf, wco])) + + self.assertAllClose(basic_outputs, block_outputs) + self.assertAllClose(basic_grads, block_grads) + for basic, block in zip(basic_wgrads, block_wgrads): + self.assertAllClose(basic, block, rtol=1e-2, atol=1e-2) + + +class LSTMFusedCellGpuTest(LSTMFusedCellTest): + _use_gpu = True + + +if __name__ == "__main__": + tf.test.main() diff --git 
a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py new file mode 100644 index 00000000000..2ecc415d351 --- /dev/null +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -0,0 +1,456 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""LSTM Fused Cell ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import load_library +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.platform import resource_loader + +_lstm_ops_so = load_library.load_op_library( + resource_loader.get_path_to_datafile("_lstm_ops.so")) +assert _lstm_ops_so, "Could not load _lstm_ops.so." + + +# pylint: disable=invalid-name +def _lstm_fused_cell(x, + cs_prev, + h_prev, + w, + b, + wci=None, + wcf=None, + wco=None, + forget_bias=None, + cell_clip=None, + use_peephole=None, + name=None): + r"""Computes the LSTM cell forward propagation for 1 time step. + + This implementation uses 1 weight matrix and 1 bias vector, there is no + diagonal peephole connection. + + This kernel op implements the following mathematical equations: + + ```python + xh = [x, h_prev] + [i, f, ci, o] = xh * w + b + f = f + forget_bias + + i = sigmoid(i) + f = sigmoid(f) + ci = tanh(ci) + o = sigmoid(o) + + cs = ci .* i + cs_prev .* f + co = tanh(cs) + + h = co .* o + ``` + + Args: + x: A `Tensor`. Must be one of the following types: `float32`, `float64`. + The input to the LSTM cell. + cs_prev: A `Tensor`. Must have the same type as `x`. + h_prev: A `Tensor`. Must have the same type as `x`. + w: A `Tensor`. Must have the same type as `x`. The weight matrix. + b: A `Tensor`. Must have the same type as `x`. The bias vector. + wci: A `Tensor`. Must have the same type as `x`. + wcf: A `Tensor`. Must have the same type as `x`. + wco: A `Tensor`. Must have the same type as `x`. + forget_bias: An optional `float`. Defaults to `1`. The forget gate bias. + cell_clip: An optional `float`. Defaults to `3`. + use_peephole: An optional `bool`. Defaults to `False`. + name: A name for the operation (optional). + + Returns: + A tuple of `Tensor` objects (i, cs, f, o, ci, co, h). + i: A `Tensor`. Has the same type as `x`. The input gate. + cs: A `Tensor`. Has the same type as `x`. The cell state before the tanh. + f: A `Tensor`. Has the same type as `x`. The forget gate. + o: A `Tensor`. Has the same type as `x`. The output gate. + ci: A `Tensor`. Has the same type as `x`. The cell input. 
+ co: A `Tensor`. Has the same type as `x`. The cell after the tanh. + h: A `Tensor`. Has the same type as `x`. The output h vector. + + Raises: + ValueError: If cell_size is None. + """ + if wci is None: + cell_size = cs_prev.get_shape().with_rank(2)[1].value + if cell_size is None: + raise ValueError("cell_size from `cs_prev` should not be None.") + wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size]) + wco = wci + wcf = wci + + # pylint: disable=protected-access + return _lstm_ops_so.lstm_fused_cell(x=x, + cs_prev=cs_prev, + h_prev=h_prev, + w=w, + wci=wci, + wco=wco, + wcf=wcf, + b=b, + forget_bias=forget_bias, + cell_clip=cell_clip, + use_peephole=use_peephole, + name=name) + # pylint: enable=protected-access + + +def _fused_lstm(seq_len_max, + x, + w, + b, + cs_prev=None, + h_prev=None, + wci=None, + wcf=None, + wco=None, + forget_bias=None, + cell_clip=None, + use_peephole=None, + name=None): + r"""TODO(williamchan): add doc. + + Args: + seq_len_max: A `Tensor` of type `int64`. + x: A list of at least 1 `Tensor` objects of the same type in: `float32`. + w: A `Tensor`. Must have the same type as `x`. + b: A `Tensor`. Must have the same type as `x`. + cs_prev: A `Tensor`. Must have the same type as `x`. + h_prev: A `Tensor`. Must have the same type as `x`. + wci: A `Tensor`. Must have the same type as `x`. + wcf: A `Tensor`. Must have the same type as `x`. + wco: A `Tensor`. Must have the same type as `x`. + forget_bias: An optional `float`. Defaults to `1`. + cell_clip: An optional `float`. Defaults to `3`. + use_peephole: An optional `bool`. Defaults to `False`. + name: A name for the operation (optional). + + Returns: + A tuple of `Tensor` objects (i, cs, f, o, ci, co, h). + i: A list with the same number of `Tensor` objects as `x` of `Tensor` + objects of the same type as x. + cs: A list with the same number of `Tensor` objects as `x` of `Tensor` + objects of the same type as x. + f: A list with the same number of `Tensor` objects as `x` of `Tensor` + objects of the same type as x. + o: A list with the same number of `Tensor` objects as `x` of `Tensor` + objects of the same type as x. + ci: A list with the same number of `Tensor` objects as `x` of `Tensor` + objects of the same type as x. + co: A list with the same number of `Tensor` objects as `x` of `Tensor` + objects of the same type as x. + h: A list with the same number of `Tensor` objects as `x` of `Tensor` + objects of the same type as x. + + Raises: + ValueError: If `b` does not have a valid shape. 
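+
+  Example (a minimal sketch mirroring the call in lstm_ops_test.py; assumes
+  `inputs` is a Python list of `[batch_size, input_size]` float32 tensors and
+  `w`, `b` were created by the caller with shapes
+  `[input_size + cell_size, cell_size * 4]` and `[cell_size * 4]`):
+
+  ```python
+  seq_len_max = tf.convert_to_tensor(len(inputs), dtype=tf.int64)
+  _, _, _, _, _, _, h = _fused_lstm(seq_len_max, inputs, w, b)
+  # h is a list with one [batch_size, cell_size] output tensor per time step.
+  ```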
+ """ + batch_size = x[0].get_shape().with_rank(2)[0].value + cell_size4 = b.get_shape().with_rank(1)[0].value + if cell_size4 is None: + raise ValueError("`b` shape must not be None.") + cell_size = cell_size4 / 4 + zero_state = None + if cs_prev is None or h_prev is None: + zero_state = array_ops.constant(0, + dtype=dtypes.float32, + shape=[batch_size, cell_size]) + if cs_prev is None: + cs_prev = zero_state + if h_prev is None: + h_prev = zero_state + if wci is None: + wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size]) + wco = wci + wcf = wci + + # pylint: disable=protected-access + return _lstm_ops_so.fused_lstm(seq_len_max=seq_len_max, + x=x, + cs_prev=cs_prev, + h_prev=h_prev, + w=w, + wci=wci, + wco=wco, + wcf=wcf, + b=b, + forget_bias=forget_bias, + cell_clip=cell_clip, + name=name, + use_peephole=use_peephole) + # pylint: enable=protected-access + # pylint: enable=invalid-name + + +ops.RegisterShape("LSTMFusedCell")(None) +_lstm_fused_cell_grad_outputs = ["cs_prev_grad", "dicfo"] + + +@ops.RegisterShape("LSTMFusedCell") +def _LSTMFusedCellShape(op): + batch_size = op.inputs[0].get_shape().with_rank(2)[0].value + cell_size = op.inputs[1].get_shape().with_rank(2)[1].value + + return (tensor_shape.TensorShape([batch_size, cell_size]), + tensor_shape.TensorShape([batch_size, cell_size]), + tensor_shape.TensorShape([batch_size, cell_size]), + tensor_shape.TensorShape([batch_size, cell_size]), + tensor_shape.TensorShape([batch_size, cell_size]), + tensor_shape.TensorShape([batch_size, cell_size]), + tensor_shape.TensorShape([batch_size, cell_size])) + + +@ops.RegisterGradient("LSTMFusedCell") +def _LSTMFusedCellGrad(op, *grad): + """Gradient for LSTMFusedCell.""" + (x, cs_prev, h_prev, w, wci, wco, wcf, b) = op.inputs + (i, cs, f, o, ci, co, _) = op.outputs + (_, cs_grad, _, _, _, _, h_grad) = grad + + batch_size = x.get_shape().with_rank(2)[0].value + if batch_size is None: + batch_size = -1 + input_size = x.get_shape().with_rank(2)[1].value + if input_size is None: + raise ValueError("input_size from `x` should not be None.") + cell_size = cs_prev.get_shape().with_rank(2)[1].value + if cell_size is None: + raise ValueError("cell_size from `cs_prev` should not be None.") + + (cs_prev_grad, dicfo, wci_grad, wcf_grad, + wco_grad) = _lstm_ops_so.lstm_fused_cell_grad( + x, + cs_prev, + h_prev, + w, + wci, + wcf, + wco, + b, + i, + cs, + f, + o, + ci, + co, + cs_grad, + h_grad, + use_peephole=op.get_attr("use_peephole")) + + # Backprop from dicfo to xh. + xh_grad = math_ops.matmul(dicfo, w, transpose_b=True) + + x_grad = array_ops.slice(xh_grad, (0, 0), (batch_size, input_size)) + x_grad.get_shape().merge_with(x.get_shape()) + + h_prev_grad = array_ops.slice(xh_grad, (0, input_size), + (batch_size, cell_size)) + h_prev_grad.get_shape().merge_with(h_prev.get_shape()) + + # Backprop from dicfo to w. + xh = array_ops.concat(1, [x, h_prev]) + w_grad = math_ops.matmul(xh, dicfo, transpose_a=True) + w_grad.get_shape().merge_with(w.get_shape()) + + # Backprop from dicfo to b. 
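+  # bias_add_grad reduces dicfo over the batch dimension, mirroring how b is
+  # broadcast across the batch in the forward pass.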
+  b_grad = nn_ops.bias_add_grad(dicfo)
+  b_grad.get_shape().merge_with(b.get_shape())
+
+  return (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad,
+          wco_grad, b_grad)
+
+
+@ops.RegisterShape("LSTMFusedCellGrad")
+def _LSTMFusedCellGradShape(op):
+  batch_size = op.inputs[0].get_shape().with_rank(2)[0].value
+  cell_size = op.inputs[1].get_shape().with_rank(2)[1].value
+
+  return [tensor_shape.TensorShape([batch_size, cell_size]),
+          tensor_shape.TensorShape([batch_size, cell_size * 4]),
+          tensor_shape.TensorShape([cell_size]),
+          tensor_shape.TensorShape([cell_size]),
+          tensor_shape.TensorShape([cell_size])]
+
+
+@ops.RegisterShape("FusedLSTM")
+def _FusedLSTMShape(op):
+  max_len = op.get_attr("max_len")
+
+  x = op.inputs[1]
+  b = op.inputs[-1]
+
+  batch_size = x.get_shape().with_rank(2)[0].value
+  cell_size = b.get_shape().with_rank(1)[0].value / 4
+
+  return [tensor_shape.TensorShape([batch_size, cell_size])] * max_len * 7
+
+
+@ops.RegisterGradient("FusedLSTM")
+def _FusedLSTMGrad(op, *grad):
+  """Gradient for FusedLSTM."""
+  max_len = op.get_attr("max_len")
+
+  seq_len_max = op.inputs[0]
+  x = op.inputs[1:1 + max_len]
+  cs_prev = op.inputs[-7]
+  h_prev = op.inputs[-6]
+  w = op.inputs[-5]
+  wci = op.inputs[-4]
+  wco = op.inputs[-3]
+  wcf = op.inputs[-2]
+  b = op.inputs[-1]
+
+  i = op.outputs[0 * max_len:1 * max_len]
+  cs = op.outputs[1 * max_len:2 * max_len]
+  f = op.outputs[2 * max_len:3 * max_len]
+  o = op.outputs[3 * max_len:4 * max_len]
+  ci = op.outputs[4 * max_len:5 * max_len]
+  co = op.outputs[5 * max_len:6 * max_len]
+  h = op.outputs[6 * max_len:7 * max_len]
+
+  cs_grad = grad[-max_len * 2:-max_len]
+  h_grad = grad[-max_len:]
+
+  (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wco_grad, wcf_grad,
+   b_grad) = _lstm_ops_so.fused_lstm_grad(
+       seq_len_max,
+       x,
+       cs_prev,
+       h_prev,
+       w,
+       wci,
+       wco,
+       wcf,
+       b,
+       i,
+       cs,
+       f,
+       o,
+       ci,
+       co,
+       h,
+       cs_grad,
+       h_grad,
+       use_peephole=op.get_attr("use_peephole"))
+
+  return [None] + x_grad + [cs_prev_grad, h_prev_grad, w_grad, wci_grad,
+                            wco_grad, wcf_grad, b_grad]
+
+
+@ops.RegisterShape("FusedLSTMGrad")
+def _FusedLSTMGradShape(op):
+  """Shape for FusedLSTMGrad."""
+  max_len = op.get_attr("max_len")
+
+  x = op.inputs[1]
+  cs_prev = op.inputs[1 + max_len]
+  h_prev = op.inputs[2 + max_len]
+  w = op.inputs[3 + max_len]
+  wci = op.inputs[4 + max_len]
+  wco = op.inputs[5 + max_len]
+  wcf = op.inputs[6 + max_len]
+  b = op.inputs[7 + max_len]
+
+  x_shape = x.get_shape().with_rank(2)
+  cs_prev_shape = cs_prev.get_shape().with_rank(2)
+  h_prev_shape = h_prev.get_shape().with_rank(2)
+  w_shape = w.get_shape().with_rank(2)
+  wci_shape = wci.get_shape().with_rank(1)
+  wco_shape = wco.get_shape().with_rank(1)
+  wcf_shape = wcf.get_shape().with_rank(1)
+  b_shape = b.get_shape().with_rank(1)
+
+  return [x_shape] * max_len + [cs_prev_shape, h_prev_shape, w_shape, wci_shape,
+                                wco_shape, wcf_shape, b_shape]
+
+
+class LSTMFusedCell(rnn_cell.RNNCell):
+  """Basic LSTM recurrent network cell.
+
+  The implementation is based on: http://arxiv.org/abs/1409.2329.
+
+  We add forget_bias (default: 1) to the biases of the forget gate in order to
+  reduce the scale of forgetting in the beginning of the training.
+
+  Unlike BasicLSTMCell, this is a monolithic op and should be much faster. The
+  weight and bias matrices should be compatible as long as the variable scope
+  matches.
+  """
+
+  def __init__(self, num_units, forget_bias=1.0, use_peephole=False):
+    """Initialize the basic LSTM cell.
+
+    Args:
+      num_units: int, The number of units in the LSTM cell.
+      forget_bias: float, The bias added to forget gates (see above).
+      use_peephole: Whether to use peephole connections or not.
+    """
+    self._num_units = num_units
+    self._forget_bias = forget_bias
+    self._use_peephole = use_peephole
+
+  @property
+  def state_size(self):
+    return (self._num_units,) * 2
+
+  @property
+  def output_size(self):
+    return self._num_units
+
+  def __call__(self, x, states_prev, scope=None):
+    """Long short-term memory cell (LSTM)."""
+    with vs.variable_scope(scope or type(self).__name__):
+      x_shape = x.get_shape().with_rank(2)
+      if not x_shape[1]:
+        raise ValueError("Expecting x_shape[1] to be set: %s" % str(x_shape))
+      if len(states_prev) != 2:
+        raise ValueError("Expecting states_prev to be a tuple with length 2.")
+      input_size = x_shape[1]
+      w = vs.get_variable("W", [input_size + self._num_units,
+                                self._num_units * 4])
+      b = vs.get_variable("b", [w.get_shape().with_rank(2)[1]],
+                          initializer=init_ops.constant_initializer(0.0))
+      wci = vs.get_variable("wci", [self._num_units])
+      wco = vs.get_variable("wco", [self._num_units])
+      wcf = vs.get_variable("wcf", [self._num_units])
+      (cs_prev, h_prev) = states_prev
+      (_, cs, _, _, _, _, h) = _lstm_fused_cell(x,
+                                                cs_prev,
+                                                h_prev,
+                                                w,
+                                                b,
+                                                wci=wci,
+                                                wco=wco,
+                                                wcf=wcf,
+                                                forget_bias=self._forget_bias,
+                                                use_peephole=self._use_peephole)
+
+      return (h, (cs, h))
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 7d00e73f90a..0ea41e10102 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -27,12 +27,6 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.ops.math_ops import reduce_sum
-from tensorflow.python.ops.math_ops import sigmoid
-from tensorflow.python.ops.math_ops import tanh
-from tensorflow.python.ops.nn_ops import conv2d
-from tensorflow.python.ops.nn_ops import softmax
-
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
 
@@ -104,7 +98,7 @@ class CoupledInputForgetGateLSTMCell(rnn_cell.RNNCell):
                initializer=None, num_proj=None, proj_clip=None,
                num_unit_shards=1, num_proj_shards=1,
                forget_bias=1.0, state_is_tuple=False,
-               activation=tanh):
+               activation=math_ops.tanh):
     """Initialize the parameters for an LSTM cell.
 
     Args:
@@ -188,6 +182,8 @@ class CoupledInputForgetGateLSTMCell(rnn_cell.RNNCell):
       ValueError: If input size cannot be inferred from inputs via
         static shape inference.
     """
+    sigmoid = math_ops.sigmoid
+
     num_proj = self._num_units if self._num_proj is None else self._num_proj
 
     if self._state_is_tuple:
@@ -322,6 +318,8 @@ class TimeFreqLSTMCell(rnn_cell.RNNCell):
       ValueError: if an input_size was specified and the provided inputs have
         a different dimension.
     """
+    sigmoid = math_ops.sigmoid
+    tanh = math_ops.tanh
     freq_inputs = self._make_tf_features(inputs)
     dtype = inputs.dtype
@@ -489,6 +487,8 @@ class GridLSTMCell(rnn_cell.RNNCell):
       ValueError: if an input_size was specified and the provided inputs have
         a different dimension.
""" + sigmoid = math_ops.sigmoid + tanh = math_ops.tanh freq_inputs = self._make_tf_features(inputs) dtype = inputs.dtype @@ -771,6 +771,11 @@ class AttentionCellWrapper(rnn_cell.RNNCell): return output, new_state def _attention(self, query, attn_states): + conv2d = nn_ops.conv2d + reduce_sum = math_ops.reduce_sum + softmax = nn_ops.softmax + tanh = math_ops.tanh + with vs.variable_scope("Attention"): k = vs.get_variable("AttnW", [1, 1, self._attn_size, self._attn_vec_size]) v = vs.get_variable("AttnV", [self._attn_vec_size]) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index da0e8ca4c95..abf4f0ee1f1 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -570,6 +570,7 @@ filegroup( name = "android_srcs", srcs = [ ":proto_text_srcs_all", + "//tensorflow/core/debug:android_srcs", "//tensorflow/core/kernels:android_srcs", "//tensorflow/core/platform/default/build_config:android_srcs", "//tensorflow/core/util/ctc:android_srcs", @@ -580,8 +581,6 @@ filegroup( "client/**/*.cc", "common_runtime/**/*.h", "common_runtime/**/*.cc", - "debug/**/*.h", - "debug/**/*.cc", "framework/**/*.h", "framework/**/*.cc", "graph/**/*.h", @@ -1103,49 +1102,13 @@ tf_cuda_library( linkstatic = 1, deps = [ ":core_cpu_internal", - ":debug_graph_utils", ":framework", ":gpu_tracer", ":lib", ":lib_internal", ":proto_text", ":protos_all_cc", - ], - alwayslink = 1, -) - -tf_cuda_library( - name = "debug_gateway_internal", - srcs = ["debug/debug_gateway.cc"], - hdrs = ["debug/debug_gateway.h"], - copts = tf_copts(), - linkstatic = 1, - deps = [ - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":gpu_tracer", - ":lib", - ":lib_internal", - ":proto_text", - ":protos_all_cc", - ], - alwayslink = 1, -) - -tf_cuda_library( - name = "debug_graph_utils", - srcs = ["debug/debug_graph_utils.cc"], - hdrs = ["debug/debug_graph_utils.h"], - copts = tf_copts(), - linkstatic = 1, - deps = [ - ":core_cpu_internal", - ":framework", - ":lib", - ":lib_internal", - ":proto_text", - ":protos_all_cc", + "//tensorflow/core/debug:debug_graph_utils", ], alwayslink = 1, ) @@ -1604,35 +1567,6 @@ tf_cc_test( ], ) -tf_cc_test_gpu( - name = "debug/debug_gateway_test", - size = "small", - args = ["--heap_check=local"], - linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags() + ["nomac"], - deps = [ - ":all_kernels", - ":core_cpu", - ":core_cpu_internal", - ":debug_gateway_internal", - ":debug_graph_utils", - ":direct_session", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":gpu_runtime", - ":lib", - ":lib_internal", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/core/kernels:debug_ops", - "//tensorflow/core/kernels:ops_util", - ], -) - tf_cc_test( name = "common_runtime/direct_session_with_tracking_alloc_test", size = "small", diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD new file mode 100644 index 00000000000..da4c45520e1 --- /dev/null +++ b/tensorflow/core/debug/BUILD @@ -0,0 +1,157 @@ +# Description: +# TensorFlow Debugger (tfdbg). +# +# Public Android targets: +# filegroup ":android_srcs" - Debugger source files for Android. 
+ +package( + default_visibility = ["//tensorflow:internal"], + features = ["-parse_headers"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_copts", + "tf_cc_test", + "tf_cuda_library", +) +load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") + +# For platform specific build config +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_kernel_tests_linkstatic", +) +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) + +tf_cuda_library( + name = "debug_gateway_internal", + srcs = ["debug_gateway.cc"], + hdrs = ["debug_gateway.h"], + copts = tf_copts(), + linkstatic = 1, + deps = [ + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:direct_session_internal", + "//tensorflow/core:framework", + "//tensorflow/core:gpu_tracer", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:proto_text", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, +) + +tf_cuda_library( + name = "debug_graph_utils", + srcs = ["debug_graph_utils.cc"], + hdrs = ["debug_graph_utils.h"], + copts = tf_copts(), + linkstatic = 1, + deps = [ + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:proto_text", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, +) + +tf_cuda_library( + name = "debug_io_utils", + srcs = ["debug_io_utils.cc"], + hdrs = ["debug_io_utils.h"], + copts = tf_copts(), + linkstatic = 1, + deps = [ + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:proto_text", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, +) + +tf_cc_test_gpu( + name = "debug_gateway_test", + size = "small", + args = ["--heap_check=local"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags() + ["nomac"], + deps = [ + ":debug_gateway_internal", + ":debug_graph_utils", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:all_kernels", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:direct_session", + "//tensorflow/core:direct_session_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:gpu_runtime", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:debug_ops", + "//tensorflow/core/kernels:ops_util", + ], +) + +tf_cc_test( + name = "debug_io_utils_test", + size = "small", + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":debug_io_utils", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +filegroup( + name = "android_srcs", + srcs = [ + "debug_graph_utils.cc", + "debug_graph_utils.h", + ], + visibility = ["//visibility:public"], +) + +# ----------------------------------------------------------------------------- +# Google-internal targets. These must be at the end for syncrepo. 
+ +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/core/debug/debug_graph_utils.cc b/tensorflow/core/debug/debug_graph_utils.cc index 8374b245cb6..118847686d3 100644 --- a/tensorflow/core/debug/debug_graph_utils.cc +++ b/tensorflow/core/debug/debug_graph_utils.cc @@ -36,6 +36,8 @@ Status DebugNodeInserter::InsertNodes( // A map from tensor name (e.g., "node_a:0") to list of debug op names // (e.g., {"DebugIdentity", "DebugNanCount"}) std::unordered_map<string, std::vector<string>> tensor_watches; + // A map from tensor name to debug_url. + std::unordered_map<string, std::vector<string>> tensor_watch_urls; // Cache the proto content for fast lookup later for (const DebugTensorWatch& watch : watches) { @@ -58,6 +60,12 @@ Status DebugNodeInserter::InsertNodes( } tensor_watches[tensor_name] = debug_ops; + + std::vector<string> urls; + for (const string& url : watch.debug_urls()) { + urls.push_back(url); + } + tensor_watch_urls[tensor_name] = urls; } if (tensor_watches.empty()) { @@ -150,9 +158,9 @@ Status DebugNodeInserter::InsertNodes( const string& debug_op_name = tensor_watches[tensor_name][i]; Node* debug_node; - Status debug_s = - CreateDebugNode(graph, device_type, copy_node->name(), src_dt, - tensor_name, i, debug_op_name, &debug_node); + Status debug_s = CreateDebugNode( + graph, device_type, copy_node->name(), src_dt, tensor_name, + tensor_watch_urls[tensor_name], i, debug_op_name, &debug_node); if (!debug_s.ok()) { return Status( error::FAILED_PRECONDITION, @@ -267,17 +275,17 @@ Status DebugNodeInserter::CreateCopyNode( Status DebugNodeInserter::CreateDebugNode( Graph* graph, const DeviceType device_type, const string& src_copy_node_name, const DataType src_dt, - const string& tensor_name, const int debug_op_num, - const string& debug_op_name, Node** debug_node) { + const string& tensor_name, const std::vector<string>& debug_urls, + const int debug_op_num, const string& debug_op_name, Node** debug_node) { NodeDef node_def; const KernelDef* kdef; const string debug_node_name = GetDebugNodeName(tensor_name, debug_op_num, debug_op_name); - // TODO(cais): Hook up with DebugTensorWatch proto auto builder = NodeDefBuilder(debug_node_name, debug_op_name) .Input(src_copy_node_name, 0, src_dt) - .Attr("tensor_name", tensor_name); + .Attr("tensor_name", tensor_name) + .Attr("debug_urls", debug_urls); if (!builder.Finalize(&node_def).ok()) { return Status( diff --git a/tensorflow/core/debug/debug_graph_utils.h b/tensorflow/core/debug/debug_graph_utils.h index 41789a30ffe..ea61dee4d08 100644 --- a/tensorflow/core/debug/debug_graph_utils.h +++ b/tensorflow/core/debug/debug_graph_utils.h @@ -94,6 +94,7 @@ class DebugNodeInserter { const string& src_copy_node_name, const DataType src_dt, const string& tensor_name, + const std::vector<string>& debug_urls, const int debug_op_num, const string& debug_op_name, Node** debug_node); }; diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc new file mode 100644 index 00000000000..474577a79c0 --- /dev/null +++ b/tensorflow/core/debug/debug_io_utils.cc @@ -0,0 +1,211 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/debug/debug_io_utils.h" + +#include <vector> + +#include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/util/event.pb.h" + +namespace tensorflow { + +namespace { + +// Encapsulate the tensor value inside a Summary proto, and then inside an +// Event proto. +Event WrapTensorAsEvent(const string& tensor_name, const string& debug_op, + const Tensor& tensor, const uint64 wall_time_us) { + Event event; + event.set_wall_time(static_cast<double>(wall_time_us)); + + Summary::Value* summ_val = event.mutable_summary()->add_value(); + + // Create the debug node_name in the Summary proto. + // For example, if tensor_name = "foo/node_a:0", and the debug_op is + // "DebugIdentity", the debug node_name in the Summary proto will be + // "foo/node_a:0:DebugIdentity". + const string debug_node_name = strings::StrCat(tensor_name, ":", debug_op); + summ_val->set_node_name(debug_node_name); + + if (tensor.dtype() == DT_STRING) { + // Treat DT_STRING specially, so that tensor_util.MakeNdarray can convert + // the TensorProto to string-type numpy array. MakeNdarray does not work + // with strings encoded by AsProtoTensorContent() in tensor_content. + tensor.AsProtoField(summ_val->mutable_tensor()); + } else { + tensor.AsProtoTensorContent(summ_val->mutable_tensor()); + } + + return event; +} + +} // namespace + +// static +const char* const DebugIO::kFileURLScheme = "file://"; +// static +const char* const DebugIO::kGrpcURLScheme = "grpc://"; + +Status DebugIO::PublishDebugTensor(const string& tensor_name, + const string& debug_op, const Tensor& tensor, + const uint64 wall_time_us, + const gtl::ArraySlice<string>& debug_urls) { + // Split the tensor_name into node name and output slot index. + std::vector<string> name_items = str_util::Split(tensor_name, ':'); + string node_name; + int32 output_slot = 0; + if (name_items.size() == 2) { + node_name = name_items[0]; + if (!strings::safe_strto32(name_items[1], &output_slot)) { + return Status(error::INVALID_ARGUMENT, + strings::StrCat("Invalid string value for output_slot: \"", + name_items[1], "\"")); + } + } else if (name_items.size() == 1) { + node_name = name_items[0]; + } else { + return Status( + error::INVALID_ARGUMENT, + strings::StrCat("Failed to parse tensor name: \"", tensor_name, "\"")); + } + + int num_failed_urls = 0; + for (const string& url : debug_urls) { + if (str_util::Lowercase(url).find(kFileURLScheme) == 0) { + const string dump_root_dir = url.substr(strlen(kFileURLScheme)); + + Status s = + DebugFileIO::DumpTensorToDir(node_name, output_slot, debug_op, tensor, + wall_time_us, dump_root_dir, nullptr); + if (!s.ok()) { + num_failed_urls++; + } + } else if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) { + // TODO(cais): Implement PublishTensor with grpc urls. 
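+      // Unlike file:// targets, whose failures are counted and reported in
+      // aggregate after the loop, an unsupported grpc:// target fails the
+      // whole publish call immediately.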
+      return Status(error::UNIMPLEMENTED,
+                    strings::StrCat("Publishing to GRPC debug target is not ",
+                                    "implemented yet"));
+    } else {
+      return Status(error::UNAVAILABLE,
+                    strings::StrCat("Invalid debug target URL: ", url));
+    }
+  }
+
+  if (num_failed_urls == 0) {
+    return Status::OK();
+  } else {
+    return Status(
+        error::INTERNAL,
+        strings::StrCat("Publishing to ", num_failed_urls, " of ",
+                        debug_urls.size(), " debug target URLs failed"));
+  }
+}
+
+// static
+Status DebugFileIO::DumpTensorToDir(
+    const string& node_name, const int32 output_slot, const string& debug_op,
+    const Tensor& tensor, const uint64 wall_time_us,
+    const string& dump_root_dir, string* dump_file_path) {
+  const string file_path = GetDumpFilePath(dump_root_dir, node_name,
+                                           output_slot, debug_op, wall_time_us);
+
+  if (dump_file_path != nullptr) {
+    *dump_file_path = file_path;
+  }
+
+  return DumpTensorToEventFile(node_name, output_slot, debug_op, tensor,
+                               wall_time_us, file_path);
+}
+
+// static
+string DebugFileIO::GetDumpFilePath(const string& dump_root_dir,
+                                    const string& node_name,
+                                    const int32 output_slot,
+                                    const string& debug_op,
+                                    const uint64 wall_time_us) {
+  return io::JoinPath(
+      dump_root_dir, strings::StrCat(node_name, "_", output_slot, "_", debug_op,
+                                     "_", wall_time_us));
+}
+
+// static
+Status DebugFileIO::DumpTensorToEventFile(
+    const string& node_name, const int32 output_slot, const string& debug_op,
+    const Tensor& tensor, const uint64 wall_time_us, const string& file_path) {
+  Env* env(Env::Default());
+
+  // Create the directory if necessary.
+  string file_dir = io::Dirname(file_path).ToString();
+  Status s = DebugFileIO::RecursiveCreateDir(env, file_dir);
+
+  if (!s.ok()) {
+    return Status(error::FAILED_PRECONDITION,
+                  strings::StrCat("Failed to create directory ", file_dir,
+                                  ", due to: ", s.error_message()));
+  }
+
+  const string tensor_name = strings::StrCat(node_name, ":", output_slot);
+  Event event = WrapTensorAsEvent(tensor_name, debug_op, tensor, wall_time_us);
+
+  string event_str;
+  event.SerializeToString(&event_str);
+
+  std::unique_ptr<WritableFile> f = nullptr;
+  TF_CHECK_OK(env->NewWritableFile(file_path, &f));
+  f->Append(event_str);
+  TF_CHECK_OK(f->Close());
+
+  return Status::OK();
+}
+
+// static
+Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
+  if (env->FileExists(dir) && env->IsDirectory(dir).ok()) {
+    // The path already exists as a directory. Return OK right away.
+    return Status::OK();
+  }
+
+  string parent_dir = io::Dirname(dir).ToString();
+  if (!env->FileExists(parent_dir)) {
+    // The parent path does not exist yet, create it first.
+    Status s = RecursiveCreateDir(env, parent_dir);  // Recursive call
+    if (!s.ok()) {
+      return Status(
+          error::FAILED_PRECONDITION,
+          strings::StrCat("Failed to create directory ", parent_dir));
+    }
+  } else if (env->FileExists(parent_dir) &&
+             !env->IsDirectory(parent_dir).ok()) {
+    // The path exists, but it is a file.
+    return Status(error::FAILED_PRECONDITION,
+                  strings::StrCat("Failed to create directory ", parent_dir,
+                                  " because the path exists as a file "));
+  }
+
+  env->CreateDir(dir);
+  // Guard against potential race in creating directories by doing a check
+  // after the CreateDir call.
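+  // (Another process may have created the directory between the checks above
+  // and the CreateDir call; as long as the directory exists now, the creation
+  // is treated as successful.)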
+// static
+Status DebugFileIO::DumpTensorToDir(
+    const string& node_name, const int32 output_slot, const string& debug_op,
+    const Tensor& tensor, const uint64 wall_time_us,
+    const string& dump_root_dir, string* dump_file_path) {
+  const string file_path = GetDumpFilePath(dump_root_dir, node_name,
+                                           output_slot, debug_op, wall_time_us);
+
+  if (dump_file_path != nullptr) {
+    *dump_file_path = file_path;
+  }
+
+  return DumpTensorToEventFile(node_name, output_slot, debug_op, tensor,
+                               wall_time_us, file_path);
+}
+
+// static
+string DebugFileIO::GetDumpFilePath(const string& dump_root_dir,
+                                    const string& node_name,
+                                    const int32 output_slot,
+                                    const string& debug_op,
+                                    const uint64 wall_time_us) {
+  return io::JoinPath(
+      dump_root_dir, strings::StrCat(node_name, "_", output_slot, "_",
+                                     debug_op, "_", wall_time_us));
+}
+
+// static
+Status DebugFileIO::DumpTensorToEventFile(
+    const string& node_name, const int32 output_slot, const string& debug_op,
+    const Tensor& tensor, const uint64 wall_time_us, const string& file_path) {
+  Env* env(Env::Default());
+
+  // Create the directory if necessary.
+  string file_dir = io::Dirname(file_path).ToString();
+  Status s = DebugFileIO::RecursiveCreateDir(env, file_dir);
+
+  if (!s.ok()) {
+    return Status(error::FAILED_PRECONDITION,
+                  strings::StrCat("Failed to create directory ", file_dir,
+                                  ", due to: ", s.error_message()));
+  }
+
+  const string tensor_name = strings::StrCat(node_name, ":", output_slot);
+  Event event = WrapTensorAsEvent(tensor_name, debug_op, tensor, wall_time_us);
+
+  string event_str;
+  event.SerializeToString(&event_str);
+
+  std::unique_ptr<WritableFile> f = nullptr;
+  TF_CHECK_OK(env->NewWritableFile(file_path, &f));
+  f->Append(event_str);
+  TF_CHECK_OK(f->Close());
+
+  return Status::OK();
+}
+
+// static
+Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
+  if (env->FileExists(dir) && env->IsDirectory(dir).ok()) {
+    // The path already exists as a directory. Return OK right away.
+    return Status::OK();
+  }
+
+  string parent_dir = io::Dirname(dir).ToString();
+  if (!env->FileExists(parent_dir)) {
+    // The parent path does not exist yet, create it first.
+    Status s = RecursiveCreateDir(env, parent_dir);  // Recursive call.
+    if (!s.ok()) {
+      return Status(
+          error::FAILED_PRECONDITION,
+          strings::StrCat("Failed to create directory ", parent_dir));
+    }
+  } else if (env->FileExists(parent_dir) &&
+             !env->IsDirectory(parent_dir).ok()) {
+    // The path exists, but it is a file.
+    return Status(error::FAILED_PRECONDITION,
+                  strings::StrCat("Failed to create directory ", parent_dir,
+                                  " because the path exists as a file"));
+  }
+
+  env->CreateDir(dir);
+  // Guard against potential race in creating directories by doing a check
+  // after the CreateDir call.
+  if (env->FileExists(dir) && env->IsDirectory(dir).ok()) {
+    return Status::OK();
+  } else {
+    return Status(error::ABORTED,
+                  strings::StrCat("Failed to create directory ", dir));
+  }
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/debug/debug_io_utils.h b/tensorflow/core/debug/debug_io_utils.h
new file mode 100644
index 00000000000..553ae9ab7d2
--- /dev/null
+++ b/tensorflow/core/debug/debug_io_utils.h
@@ -0,0 +1,107 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_DEBUG_IO_UTILS_H_
+#define TENSORFLOW_DEBUG_IO_UTILS_H_
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+
+class DebugIO {
+ public:
+  // Publish a tensor to a debug target URL.
+  //
+  // Args:
+  //   tensor_name: Name of the tensor being published: node_name followed by
+  //     a colon, followed by the output slot index. E.g., "node_a:0".
+  //   debug_op: Name of the debug op, e.g., "DebugIdentity".
+  //   tensor: The Tensor object being published.
+  //   wall_time_us: Time stamp for the Tensor. Unit: microseconds (us).
+  //   debug_urls: An array of debug target URLs, e.g.,
+  //     "file:///foo/tfdbg_dump", "grpc://localhost:11011"
+  static Status PublishDebugTensor(const string& tensor_name,
+                                   const string& debug_op, const Tensor& tensor,
+                                   const uint64 wall_time_us,
+                                   const gtl::ArraySlice<string>& debug_urls);
+
+ private:
+  static const char* const kFileURLScheme;
+  static const char* const kGrpcURLScheme;
+};
+
+// Helper class for debug ops.
+class DebugFileIO {
+ public:
+  // Encapsulate the Tensor in an Event protobuf and write it to a directory.
+  // The actual path of the dump file will be a concatenation of
+  // dump_root_dir, tensor_name, along with the wall_time.
+  //
+  // For example:
+  //   let dump_root_dir = "/tmp/tfdbg_dump",
+  //       node_name = "foo/bar",
+  //       output_slot = 0,
+  //       debug_op = DebugIdentity,
+  //       and wall_time_us = 1467891234512345,
+  //   the dump file will be generated at path:
+  //     /tmp/tfdbg_dump/foo/bar_0_DebugIdentity_1467891234512345.
+  //
+  // Args:
+  //   node_name: Name of the node from which the tensor is output.
+  //   output_slot: Output slot index.
+  //   debug_op: Name of the debug op, e.g., "DebugIdentity".
+  //   tensor: The Tensor object to be dumped to file.
+  //   wall_time_us: Wall time at which the Tensor is generated during graph
+  //     execution. Unit: microseconds (us).
+  //   dump_root_dir: Root directory for dumping the tensor.
+  //   dump_file_path: The actual dump file path (output pointer; may be nullptr).
+  static Status DumpTensorToDir(const string& node_name,
+                                const int32 output_slot, const string& debug_op,
+                                const Tensor& tensor, const uint64 wall_time_us,
+                                const string& dump_root_dir,
+                                string* dump_file_path);
+
+  // Get the full path to the dump file.
+ // + // Args: + // dump_root_dir: The dump root directory, e.g., /tmp/tfdbg_dump + // node_name: Name of the node from which the dumped tensor is generated, + // e.g., foo/bar/node_a + // output_slot: Output slot index of the said node, e.g., 0. + // debug_op: Name of the debug op, e.g., DebugIdentity. + // wall_time_us: Time stamp of the dumped tensor, in microseconds (us). + static string GetDumpFilePath(const string& dump_root_dir, + const string& node_name, + const int32 output_slot, const string& debug_op, + const uint64 wall_time_us); + + private: + // Encapsulate the Tensor in an Event protobuf and write it to file. + static Status DumpTensorToEventFile( + const string& node_name, const int32 output_slot, const string& debug_op, + const Tensor& tensor, const uint64 wall_time_us, const string& file_path); + + // Implemented ad hoc here for now. + // TODO(cais): Replace with shared implementation once http://b/30497715 is + // fixed. + static Status RecursiveCreateDir(Env* env, const string& dir); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_DEBUG_IO_UTILS_H_ diff --git a/tensorflow/core/debug/debug_io_utils_test.cc b/tensorflow/core/debug/debug_io_utils_test.cc new file mode 100644 index 00000000000..ecdda643c3a --- /dev/null +++ b/tensorflow/core/debug/debug_io_utils_test.cc @@ -0,0 +1,382 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/debug/debug_io_utils.h" + +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/util/event.pb.h" + +namespace tensorflow { +namespace { + +class DebugIOUtilsTest : public ::testing::Test { + public: + void Initialize() { + env_ = Env::Default(); + + tensor_a_.reset(new Tensor(DT_FLOAT, TensorShape({2, 2}))); + tensor_a_->flat<float>()(0) = 5.0; + tensor_a_->flat<float>()(1) = 3.0; + tensor_a_->flat<float>()(2) = -1.0; + tensor_a_->flat<float>()(3) = 0.0; + + tensor_b_.reset(new Tensor(DT_STRING, TensorShape{2})); + tensor_b_->flat<string>()(0) = "corge"; + tensor_b_->flat<string>()(1) = "garply"; + } + + Status ReadEventFromFile(const string& dump_file_path, Event* event) { + string content; + uint64 file_size = 0; + + Status s = env_->GetFileSize(dump_file_path, &file_size); + if (!s.ok()) { + return s; + } + + content.resize(file_size); + + std::unique_ptr<RandomAccessFile> file; + s = env_->NewRandomAccessFile(dump_file_path, &file); + if (!s.ok()) { + return s; + } + + StringPiece result; + s = file->Read(0, file_size, &result, &(content)[0]); + if (!s.ok()) { + return s; + } + + event->ParseFromString(content); + return Status::OK(); + } + + Env* env_; + std::unique_ptr<Tensor> tensor_a_; + std::unique_ptr<Tensor> tensor_b_; +}; + +TEST_F(DebugIOUtilsTest, DumpFloatTensorToFileSunnyDay) { + Initialize(); + + const string test_dir = testing::TmpDir(); + + // Append levels of nonexisting directories, to test that the function can + // create directories. + const string kNodeName = "foo/bar/qux/tensor_a"; + const string kDebugOpName = "DebugIdentity"; + const int32 output_slot = 0; + uint64 wall_time = env_->NowMicros(); + + string dump_file_path; + TF_ASSERT_OK(DebugFileIO::DumpTensorToDir(kNodeName, output_slot, + kDebugOpName, *tensor_a_, wall_time, + test_dir, &dump_file_path)); + + // Read the file into a Event proto. + Event event; + TF_ASSERT_OK(ReadEventFromFile(dump_file_path, &event)); + + ASSERT_GE(wall_time, event.wall_time()); + ASSERT_EQ(1, event.summary().value().size()); + ASSERT_EQ(strings::StrCat(kNodeName, ":", output_slot, ":", kDebugOpName), + event.summary().value(0).node_name()); + + Tensor a_prime(DT_FLOAT); + ASSERT_TRUE(a_prime.FromProto(event.summary().value(0).tensor())); + + // Verify tensor shape and value. + ASSERT_EQ(tensor_a_->shape(), a_prime.shape()); + for (int i = 0; i < a_prime.flat<float>().size(); ++i) { + ASSERT_EQ(tensor_a_->flat<float>()(i), a_prime.flat<float>()(i)); + } + + // Tear down temporary file and directories. 
+ int64 undeleted_files = 0; + int64 undeleted_dirs = 0; + ASSERT_TRUE( + env_->DeleteRecursively(test_dir, &undeleted_files, &undeleted_dirs) + .ok()); + ASSERT_EQ(0, undeleted_files); + ASSERT_EQ(0, undeleted_dirs); +} + +TEST_F(DebugIOUtilsTest, DumpStringTensorToFileSunnyDay) { + Initialize(); + + const string test_dir = testing::TmpDir(); + + const string kNodeName = "quux/grault/tensor_b"; + const string kDebugOpName = "DebugIdentity"; + const int32 output_slot = 1; + uint64 wall_time = env_->NowMicros(); + + string dump_file_name; + Status s = DebugFileIO::DumpTensorToDir(kNodeName, output_slot, kDebugOpName, + *tensor_b_, wall_time, test_dir, + &dump_file_name); + ASSERT_TRUE(s.ok()); + + // Read the file into a Event proto. + Event event; + TF_ASSERT_OK(ReadEventFromFile(dump_file_name, &event)); + + ASSERT_GE(wall_time, event.wall_time()); + ASSERT_EQ(1, event.summary().value().size()); + ASSERT_EQ(strings::StrCat(kNodeName, ":", output_slot, ":", kDebugOpName), + event.summary().value(0).node_name()); + + Tensor b_prime(DT_STRING); + ASSERT_TRUE(b_prime.FromProto(event.summary().value(0).tensor())); + + // Verify tensor shape and value. + ASSERT_EQ(tensor_b_->shape(), b_prime.shape()); + for (int i = 0; i < b_prime.flat<string>().size(); ++i) { + ASSERT_EQ(tensor_b_->flat<string>()(i), b_prime.flat<string>()(i)); + } + + // Tear down temporary file and directories. + int64 undeleted_files = 0; + int64 undeleted_dirs = 0; + ASSERT_TRUE( + env_->DeleteRecursively(test_dir, &undeleted_files, &undeleted_dirs) + .ok()); + ASSERT_EQ(0, undeleted_files); + ASSERT_EQ(0, undeleted_dirs); +} + +TEST_F(DebugIOUtilsTest, DumpTensorToFileCannotCreateDirectory) { + Initialize(); + + // First, create the file at the path. + const string test_dir = testing::TmpDir(); + const string txt_file_name = strings::StrCat(test_dir, "/baz"); + + if (!env_->FileExists(test_dir)) { + ASSERT_TRUE(env_->CreateDir(test_dir).ok()); + } + ASSERT_FALSE(env_->FileExists(txt_file_name)); + + std::unique_ptr<WritableFile> file; + ASSERT_TRUE(env_->NewWritableFile(txt_file_name, &file).ok()); + file->Append("text in baz"); + file->Flush(); + file->Close(); + + // Verify that the path exists and that it is a file, not a directory. + ASSERT_TRUE(env_->FileExists(txt_file_name)); + ASSERT_FALSE(env_->IsDirectory(txt_file_name).ok()); + + // Second, try to dump the tensor to a path that requires "baz" to be a + // directory, which should lead to an error. + const string kNodeName = "baz/tensor_a"; + const string kDebugOpName = "DebugIdentity"; + const int32 output_slot = 0; + uint64 wall_time = env_->NowMicros(); + + string dump_file_name; + Status s = DebugFileIO::DumpTensorToDir(kNodeName, output_slot, kDebugOpName, + *tensor_a_, wall_time, test_dir, + &dump_file_name); + ASSERT_FALSE(s.ok()); + + // Tear down temporary file and directories. 
+ int64 undeleted_files = 0; + int64 undeleted_dirs = 0; + ASSERT_TRUE( + env_->DeleteRecursively(test_dir, &undeleted_files, &undeleted_dirs) + .ok()); + ASSERT_EQ(0, undeleted_files); + ASSERT_EQ(0, undeleted_dirs); +} + +TEST_F(DebugIOUtilsTest, PublishTensorToMultipleFileURLs) { + Initialize(); + + const int kNumDumpRoots = 3; + const string kNodeName = "foo/bar/qux/tensor_a"; + const string kDebugOpName = "DebugIdentity"; + const int32 output_slot = 0; + + uint64 wall_time = env_->NowMicros(); + + std::vector<string> dump_roots; + std::vector<string> dump_file_paths; + std::vector<string> urls; + for (int i = 0; i < kNumDumpRoots; ++i) { + string dump_root = strings::StrCat(testing::TmpDir(), "/", i); + + dump_roots.push_back(dump_root); + dump_file_paths.push_back(DebugFileIO::GetDumpFilePath( + dump_root, kNodeName, output_slot, kDebugOpName, wall_time)); + urls.push_back(strings::StrCat("file://", dump_root)); + } + + for (int i = 1; i < kNumDumpRoots; ++i) { + ASSERT_NE(dump_roots[0], dump_roots[i]); + } + + const string tensor_name = strings::StrCat(kNodeName, ":", output_slot); + const string debug_node_name = + strings::StrCat(tensor_name, ":", kDebugOpName); + Status s = DebugIO::PublishDebugTensor(tensor_name, kDebugOpName, *tensor_a_, + wall_time, urls); + ASSERT_TRUE(s.ok()); + + // Try reading the file into a Event proto. + for (int i = 0; i < kNumDumpRoots; ++i) { + // Read the file into a Event proto. + Event event; + TF_ASSERT_OK(ReadEventFromFile(dump_file_paths[i], &event)); + + ASSERT_GE(wall_time, event.wall_time()); + ASSERT_EQ(1, event.summary().value().size()); + ASSERT_EQ(debug_node_name, event.summary().value(0).node_name()); + + Tensor a_prime(DT_FLOAT); + ASSERT_TRUE(a_prime.FromProto(event.summary().value(0).tensor())); + + // Verify tensor shape and value. + ASSERT_EQ(tensor_a_->shape(), a_prime.shape()); + for (int i = 0; i < a_prime.flat<float>().size(); ++i) { + ASSERT_EQ(tensor_a_->flat<float>()(i), a_prime.flat<float>()(i)); + } + } + + // Tear down temporary file and directories. + for (int i = 0; i < kNumDumpRoots; ++i) { + int64 undeleted_files = 0; + int64 undeleted_dirs = 0; + ASSERT_TRUE(env_->DeleteRecursively(dump_roots[i], &undeleted_files, + &undeleted_dirs) + .ok()); + ASSERT_EQ(0, undeleted_files); + ASSERT_EQ(0, undeleted_dirs); + } +} + +TEST_F(DebugIOUtilsTest, PublishTensorConcurrentlyToPartiallyOverlappingPaths) { + Initialize(); + + const int kConcurrentPubs = 3; + const string kNodeName = "tensor_a"; + const string kDebugOpName = "DebugIdentity"; + const int32 kOutputSlot = 0; + + thread::ThreadPool* tp = + new thread::ThreadPool(Env::Default(), "test", kConcurrentPubs); + uint64 wall_time = env_->NowMicros(); + + const string dump_root_base = testing::TmpDir(); + const string tensor_name = strings::StrCat(kNodeName, ":", kOutputSlot); + const string debug_node_name = + strings::StrCat(tensor_name, ":", kDebugOpName); + + mutex mu; + std::vector<string> dump_roots GUARDED_BY(mu); + std::vector<string> dump_file_paths GUARDED_BY(mu); + + int dump_count GUARDED_BY(mu) = 0; + int done_count GUARDED_BY(mu) = 0; + Notification all_done; + + auto fn = [this, &dump_count, &done_count, &mu, &dump_root_base, &dump_roots, + &dump_file_paths, &wall_time, &tensor_name, &debug_node_name, + &kNodeName, &kDebugOpName, &kConcurrentPubs, &all_done]() { + // "gumpy" is the shared directory part of the path. 
+ string dump_root; + string debug_url; + { + mutex_lock l(mu); + dump_root = + strings::StrCat(dump_root_base, "grumpy/", "dump_", dump_count++); + + dump_roots.push_back(dump_root); + dump_file_paths.push_back(DebugFileIO::GetDumpFilePath( + dump_root, kNodeName, kOutputSlot, kDebugOpName, wall_time)); + + debug_url = strings::StrCat("file://", dump_root); + } + + std::vector<string> urls; + urls.push_back(debug_url); + Status s = DebugIO::PublishDebugTensor(tensor_name, kDebugOpName, + *tensor_a_, wall_time, urls); + ASSERT_TRUE(s.ok()); + + { + mutex_lock l(mu); + + done_count++; + if (done_count == kConcurrentPubs) { + all_done.Notify(); + } + } + }; + + for (int i = 0; i < kConcurrentPubs; ++i) { + tp->Schedule(fn); + } + + // Wait for all dumping calls to finish. + all_done.WaitForNotification(); + delete tp; + + { + mutex_lock l(mu); + + for (int i = 1; i < kConcurrentPubs; ++i) { + ASSERT_NE(dump_roots[0], dump_roots[i]); + } + + // Try reading the file into a Event proto. + for (int i = 0; i < kConcurrentPubs; ++i) { + // Read the file into a Event proto. + Event event; + TF_ASSERT_OK(ReadEventFromFile(dump_file_paths[i], &event)); + + ASSERT_GE(wall_time, event.wall_time()); + ASSERT_EQ(1, event.summary().value().size()); + ASSERT_EQ(debug_node_name, event.summary().value(0).node_name()); + + Tensor a_prime(DT_FLOAT); + ASSERT_TRUE(a_prime.FromProto(event.summary().value(0).tensor())); + + // Verify tensor shape and value. + ASSERT_EQ(tensor_a_->shape(), a_prime.shape()); + for (int i = 0; i < a_prime.flat<float>().size(); ++i) { + ASSERT_EQ(tensor_a_->flat<float>()(i), a_prime.flat<float>()(i)); + } + } + + // Tear down temporary file and directories. + int64 undeleted_files = 0; + int64 undeleted_dirs = 0; + ASSERT_TRUE(env_->DeleteRecursively(dump_root_base, &undeleted_files, + &undeleted_dirs) + .ok()); + ASSERT_EQ(0, undeleted_files); + ASSERT_EQ(0, undeleted_dirs); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 8d9066ab52c..a078488dd18 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -423,6 +423,7 @@ tf_kernel_libraries( "//tensorflow/core:lib_internal", "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/debug:debug_io_utils", "//third_party/eigen3", ], ) diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h index 3e46970812f..8132cf1f6b0 100644 --- a/tensorflow/core/kernels/debug_ops.h +++ b/tensorflow/core/kernels/debug_ops.h @@ -17,6 +17,7 @@ limitations under the License. 
#define TENSORFLOW_KERNELS_DEBUG_OP_H_ #include "tensorflow/core/common_runtime/gpu/gpu_util.h" +#include "tensorflow/core/debug/debug_io_utils.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_util.h" @@ -73,10 +74,16 @@ class DebugIdentityOp : public OpKernel { public: explicit DebugIdentityOp(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("tensor_name", &tensor_name_)); - // TODO(cais): Add debug_url + OP_REQUIRES_OK(context, context->GetAttr("debug_urls", &debug_urls_)); } void Compute(OpKernelContext* context) override { + if (!debug_urls_.empty()) { + DebugIO::PublishDebugTensor(tensor_name_, "DebugIdentity", + context->input(0), + Env::Default()->NowMicros(), debug_urls_); + } + context->set_output(0, context->input(0)); } @@ -84,6 +91,7 @@ class DebugIdentityOp : public OpKernel { private: string tensor_name_; + std::vector<string> debug_urls_; }; // NaN-counter op for debugging. @@ -92,6 +100,7 @@ class DebugNanCountOp : public OpKernel { public: explicit DebugNanCountOp(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("tensor_name", &tensor_name_)); + OP_REQUIRES_OK(context, context->GetAttr("debug_urls", &debug_urls_)); } void Compute(OpKernelContext* context) override { @@ -120,6 +129,7 @@ class DebugNanCountOp : public OpKernel { private: string tensor_name_; + std::vector<string> debug_urls_; }; // TODO(cais): Add DebugInfinityCount diff --git a/tensorflow/core/kernels/debug_ops_test.cc b/tensorflow/core/kernels/debug_ops_test.cc index e584d43e22d..e526754d316 100644 --- a/tensorflow/core/kernels/debug_ops_test.cc +++ b/tensorflow/core/kernels/debug_ops_test.cc @@ -13,6 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <dirent.h> +#include <string.h> +#include <fstream> +#include <vector> + #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def_builder.h" @@ -22,20 +27,32 @@ limitations under the License. 
#include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/event.pb.h" namespace tensorflow { namespace { class DebugIdentityOpTest : public OpsTestBase { protected: - Status Init(DataType input_type) { + Status Init(DataType input_type, const std::vector<string> debug_urls) { + env_ = Env::Default(); + TF_CHECK_OK(NodeDefBuilder("op", "DebugIdentity") .Input(FakeInput(input_type)) .Attr("tensor_name", "FakeTensor:0") + .Attr("debug_urls", debug_urls) .Finalize(node_def())); return InitOp(); } + + Status Init(DataType input_type) { + std::vector<string> empty_debug_urls; + return Init(input_type, empty_debug_urls); + } + + Env* env_; }; TEST_F(DebugIdentityOpTest, Int32Success_6) { @@ -48,6 +65,80 @@ TEST_F(DebugIdentityOpTest, Int32Success_6) { test::ExpectTensorEqual<int32>(expected, *GetOutput(0)); } +TEST_F(DebugIdentityOpTest, Int32Success_6_FileURLs) { + const int kNumDumpDirs = 3; + + const string tmp_dir = testing::TmpDir(); + + std::vector<string> dump_roots; + std::vector<string> debug_urls; + for (int i = 0; i < kNumDumpDirs; ++i) { + const string dump_root = strings::StrCat(tmp_dir, "_", i); + dump_roots.push_back(dump_root); + + debug_urls.push_back(strings::StrCat("file://", dump_root)); + } + + uint64 wall_time = Env::Default()->NowMicros(); + + TF_ASSERT_OK(Init(DT_INT32, debug_urls)); + AddInputFromArray<int32>(TensorShape({6}), {1, 2, 3, 4, 5, 6}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_INT32, TensorShape({6})); + test::FillValues<int32>(&expected, {1, 2, 3, 4, 5, 6}); + // Verify the identity output + test::ExpectTensorEqual<int32>(expected, *GetOutput(0)); + + for (int i = 0; i < kNumDumpDirs; ++i) { + ASSERT_TRUE(env_->FileExists(dump_roots[i])); + ASSERT_TRUE(env_->IsDirectory(dump_roots[i]).ok()); + + DIR* dir = opendir(dump_roots[i].c_str()); + struct dirent* ent; + int dump_files_found = 0; + while ((ent = readdir(dir)) != NULL) { + if (strcmp(ent->d_name, ".") && strcmp(ent->d_name, "..")) { + dump_files_found++; + + // Try reading the file into a Event proto. + const string dump_file_path = + strings::StrCat(dump_roots[i], "/", ent->d_name); + std::fstream ifs(dump_file_path, std::ios::in | std::ios::binary); + Event event; + event.ParseFromIstream(&ifs); + ifs.close(); + + ASSERT_GE(event.wall_time(), wall_time); + ASSERT_EQ(1, event.summary().value().size()); + ASSERT_EQ(strings::StrCat("FakeTensor", ":", 0, ":", "DebugIdentity"), + event.summary().value(0).node_name()); + + Tensor tensor_prime(DT_INT32); + ASSERT_TRUE(tensor_prime.FromProto(event.summary().value(0).tensor())); + + // Verify tensor shape and value from the dump file. + ASSERT_EQ(TensorShape({6}), tensor_prime.shape()); + + for (int j = 0; j < 6; ++j) { + ASSERT_EQ(j + 1, tensor_prime.flat<int32>()(j)); + } + } + } + closedir(dir); + + ASSERT_EQ(1, dump_files_found); + + // Remove temporary dump directory and file. 
+ int64 undeleted_files = 0; + int64 undeleted_dirs = 0; + ASSERT_TRUE(env_->DeleteRecursively(dump_roots[i], &undeleted_files, + &undeleted_dirs) + .ok()); + ASSERT_EQ(0, undeleted_files); + ASSERT_EQ(0, undeleted_dirs); + } +} + TEST_F(DebugIdentityOpTest, Int32Success_2_3) { TF_ASSERT_OK(Init(DT_INT32)); AddInputFromArray<int32>(TensorShape({2, 3}), {1, 2, 3, 4, 5, 6}); @@ -66,8 +157,6 @@ TEST_F(DebugIdentityOpTest, StringSuccess) { test::ExpectTensorEqual<string>(expected, *GetOutput(0)); } -TEST_F(DebugIdentityOpTest, RefInputError) { TF_ASSERT_OK(Init(DT_INT32_REF)); } - // Tests for DebugNanCountOp class DebugNanCountOpTest : public OpsTestBase { protected: diff --git a/tensorflow/core/kernels/shape_ops.cc b/tensorflow/core/kernels/shape_ops.cc index 0861fa99821..63ad0059d45 100644 --- a/tensorflow/core/kernels/shape_ops.cc +++ b/tensorflow/core/kernels/shape_ops.cc @@ -253,6 +253,8 @@ class ExpandDimsOp : public OpKernel { " and output shape ", output_shape.DebugString())); } } + + bool IsExpensive() override { return false; } }; REGISTER_KERNEL_BUILDER(Name("ExpandDims").Device(DEVICE_CPU).HostMemory("dim"), ExpandDimsOp); @@ -342,6 +344,8 @@ class SqueezeOp : public OpKernel { } } + bool IsExpensive() override { return false; } + private: std::unordered_set<int32> squeeze_dims_; }; diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index fe3d7406961..5ba4e0cce69 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -2985,6 +2985,7 @@ REGISTER_OP("DebugIdentity") .Output("output: T") .Attr("T: type") .Attr("tensor_name: string = ''") + .Attr("debug_urls: list(string) = []") .Doc(R"doc( Debug Identity Op. @@ -2993,6 +2994,8 @@ Provides an identity mapping of the non-Ref type input tensor for debugging. input: Input tensor, non-Reference type. output: Output tensor that equals the input tensor. tensor_name: Name of the input tensor. +debug_urls: List of URLs to debug targets, e.g., + file:///foo/tfdbg_dump, grpc:://localhost:11011 )doc"); REGISTER_OP("DebugNanCount") @@ -3000,6 +3003,7 @@ REGISTER_OP("DebugNanCount") .Output("output: int64") // The debug signal (nan count) is int64 .Attr("T: type") .Attr("tensor_name: string = ''") + .Attr("debug_urls: list(string) = []") .Doc(R"doc( Debug NaN Value Counter Op @@ -3008,6 +3012,8 @@ Counts number of NaNs in the input tensor, for debugging. input: Input tensor, non-Reference type. output: An integer output tensor that is the number of NaNs in the input. tensor_name: Name of the input tensor. 
+debug_urls: List of URLs to debug targets, e.g., + file:///foo/tfdbg_dump, grpc:://localhost:11011 )doc"); } // namespace tensorflow diff --git a/tensorflow/core/ops/compat/ops_history.v0.pbtxt b/tensorflow/core/ops/compat/ops_history.v0.pbtxt index 8b3230c3fed..6c7556076a9 100644 --- a/tensorflow/core/ops/compat/ops_history.v0.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v0.pbtxt @@ -7876,6 +7876,36 @@ op { } } } +op { + name: "DebugIdentity" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "tensor_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "debug_urls" + type: "list(string)" + default_value { + list { + } + } + } +} op { name: "DebugNanCount" input_arg { @@ -7898,6 +7928,36 @@ op { } } } +op { + name: "DebugNanCount" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type: DT_INT64 + } + attr { + name: "T" + type: "type" + } + attr { + name: "tensor_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "debug_urls" + type: "list(string)" + default_value { + list { + } + } + } +} op { name: "DecodeCSV" input_arg { diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 5da2754fd47..01bb4bc82f8 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -4278,6 +4278,15 @@ op { } description: "Name of the input tensor." } + attr { + name: "debug_urls" + type: "list(string)" + default_value { + list { + } + } + description: "List of URLs to debug targets, e.g.,\nfile:///foo/tfdbg_dump, grpc:://localhost:11011" + } summary: "Debug Identity Op." description: "Provides an identity mapping of the non-Ref type input tensor for debugging." } @@ -4305,6 +4314,15 @@ op { } description: "Name of the input tensor." } + attr { + name: "debug_urls" + type: "list(string)" + default_value { + list { + } + } + description: "List of URLs to debug targets, e.g.,\nfile:///foo/tfdbg_dump, grpc:://localhost:11011" + } summary: "Debug NaN Value Counter Op" description: "Counts number of NaNs in the input tensor, for debugging." } diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc index 1ad9f7175fc..ac213385054 100644 --- a/tensorflow/core/ops/sparse_ops.cc +++ b/tensorflow/core/ops/sparse_ops.cc @@ -662,13 +662,19 @@ keep_dims: If true, retain reduced dimensions with length 1. output: `R-K`-D. The reduced Tensor. )doc"); -#define SPARSE_DENSE_CWISE_SIGNATURE() \ - Input("sp_indices: int64") \ - .Input("sp_values: T") \ - .Input("sp_shape: int64") \ - .Input("dense: T") \ - .Output("output: T") \ - .Attr("T: numbertype") +#define SPARSE_DENSE_CWISE_SIGNATURE() \ + Input("sp_indices: int64") \ + .Input("sp_values: T") \ + .Input("sp_shape: int64") \ + .Input("dense: T") \ + .Output("output: T") \ + .Attr("T: numbertype") \ + .SetShapeFn([](InferenceContext* c) { \ + const Shape* input; \ + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &input)); \ + c->set_output(0, c->Vector(c->Dim(input, 0))); \ + return Status::OK(); \ + }) REGISTER_OP("SparseDenseCwiseMul").SPARSE_DENSE_CWISE_SIGNATURE().Doc(R"doc( Component-wise multiplies a SparseTensor by a dense Tensor. @@ -722,6 +728,8 @@ dense: `R`-D. The dense Tensor operand. output: 1-D. The `N` values that are operated on. 
)doc"); +#undef SPARSE_DENSE_CWISE_SIGNATURE + REGISTER_OP("SparseSoftmax") .Input("sp_indices: int64") .Input("sp_values: T") diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index 33f38019d01..6f78f8cd8a9 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -180,7 +180,7 @@ message ConfigProto { int64 operation_timeout_in_ms = 11; }; -// EXPERIMENTAL. Option for watching a node +// EXPERIMENTAL. Option for watching a node. message DebugTensorWatch { // Name of the node to watch. string node_name = 1; @@ -196,6 +196,12 @@ message DebugTensorWatch { // One or more than one probes on a tensor. // e.g., {"DebugIdentity", "DebugNanCount"} repeated string debug_ops = 3; + + // URL(s) for debug targets(s). + // E.g., "file:///foo/tfdbg_dump", "grpc://localhost:11011" + // Each debug op listed in debug_ops will publish its output tensor (debug + // signal) to all URLs in debug_urls. + repeated string debug_urls = 4; } // EXPERIMENTAL. Options for a single Run() call. diff --git a/tensorflow/g3doc/api_docs/python/contrib.losses.md b/tensorflow/g3doc/api_docs/python/contrib.losses.md index 846718e196c..26d297b38f3 100644 --- a/tensorflow/g3doc/api_docs/python/contrib.losses.md +++ b/tensorflow/g3doc/api_docs/python/contrib.losses.md @@ -140,6 +140,31 @@ Notice that the function adds the given losses to the regularization losses. * <b>`ValueError`</b>: if `losses` is not iterable. +- - - + +### `tf.contrib.losses.hinge_loss(logits, target, scope=None)` {#hinge_loss} + +Method that returns the loss tensor for hinge loss. + +##### Args: + + +* <b>`logits`</b>: The logits, a float tensor. +* <b>`target`</b>: The ground truth output tensor. Its shape should match the shape of + logits. The values of the tensor are expected to be 0.0 or 1.0. +* <b>`scope`</b>: The scope for the operations performed in computing the loss. + +##### Returns: + + A `Tensor` of same shape as logits and target representing the loss values + across the batch. + +##### Raises: + + +* <b>`ValueError`</b>: If the shapes of `logits` and `target` don't match. + + - - - ### `tf.contrib.losses.log_loss(predictions, targets, weight=1.0, epsilon=1e-07, scope=None)` {#log_loss} diff --git a/tensorflow/g3doc/api_docs/python/contrib.rnn.md b/tensorflow/g3doc/api_docs/python/contrib.rnn.md new file mode 100644 index 00000000000..201e23c66d3 --- /dev/null +++ b/tensorflow/g3doc/api_docs/python/contrib.rnn.md @@ -0,0 +1,409 @@ +<!-- This file is machine generated: DO NOT EDIT! --> + +# RNN (contrib) +[TOC] + +Additional RNN operations and cells. + +## This package provides additional contributed RNNCells. + +### Fused RNNCells +- - - + +### `class tf.contrib.rnn.LSTMFusedCell` {#LSTMFusedCell} + +Basic LSTM recurrent network cell. + +The implementation is based on: http://arxiv.org/abs/1409.2329. + +We add forget_bias (default: 1) to the biases of the forget gate in order to +reduce the scale of forgetting in the beginning of the training. + +Unlike BasicLSTMCell, this is a monolithic op and should be much faster. The +weight and bias matrixes should be compatible as long as the variabel scope +matches. +- - - + +#### `tf.contrib.rnn.LSTMFusedCell.__init__(num_units, forget_bias=1.0, use_peephole=False)` {#LSTMFusedCell.__init__} + +Initialize the basic LSTM cell. + +##### Args: + + +* <b>`num_units`</b>: int, The number of units in the LSTM cell. +* <b>`forget_bias`</b>: float, The bias added to forget gates (see above). 
+* <b>`use_peephole`</b>: Whether to use peephole connectios or not. + + +- - - + +#### `tf.contrib.rnn.LSTMFusedCell.output_size` {#LSTMFusedCell.output_size} + + + + +- - - + +#### `tf.contrib.rnn.LSTMFusedCell.state_size` {#LSTMFusedCell.state_size} + + + + +- - - + +#### `tf.contrib.rnn.LSTMFusedCell.zero_state(batch_size, dtype)` {#LSTMFusedCell.zero_state} + +Return zero-filled state tensor(s). + +##### Args: + + +* <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size. +* <b>`dtype`</b>: the data type to use for the state. + +##### Returns: + + If `state_size` is an int or TensorShape, then the return value is a + `N-D` tensor of shape `[batch_size x state_size]` filled with zeros. + + If `state_size` is a nested list or tuple, then the return value is + a nested list or tuple (of the same structure) of `2-D` tensors with +the shapes `[batch_size x s]` for each s in `state_size`. + + + + +### LSTM-like cells +- - - + +### `class tf.contrib.rnn.CoupledInputForgetGateLSTMCell` {#CoupledInputForgetGateLSTMCell} + +Long short-term memory unit (LSTM) recurrent network cell. + +The default non-peephole implementation is based on: + + http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf + +S. Hochreiter and J. Schmidhuber. +"Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. + +The peephole implementation is based on: + + https://research.google.com/pubs/archive/43905.pdf + +Hasim Sak, Andrew Senior, and Francoise Beaufays. +"Long short-term memory recurrent neural network architectures for + large scale acoustic modeling." INTERSPEECH, 2014. + +The coupling of input and forget gate is based on: + + http://arxiv.org/pdf/1503.04069.pdf + +Greff et al. "LSTM: A Search Space Odyssey" + +The class uses optional peep-hole connections, and an optional projection +layer. +- - - + +#### `tf.contrib.rnn.CoupledInputForgetGateLSTMCell.__init__(num_units, use_peepholes=False, initializer=None, num_proj=None, proj_clip=None, num_unit_shards=1, num_proj_shards=1, forget_bias=1.0, state_is_tuple=False, activation=tanh)` {#CoupledInputForgetGateLSTMCell.__init__} + +Initialize the parameters for an LSTM cell. + +##### Args: + + +* <b>`num_units`</b>: int, The number of units in the LSTM cell +* <b>`use_peepholes`</b>: bool, set True to enable diagonal/peephole connections. +* <b>`initializer`</b>: (optional) The initializer to use for the weight and + projection matrices. +* <b>`num_proj`</b>: (optional) int, The output dimensionality for the projection + matrices. If None, no projection is performed. +* <b>`proj_clip`</b>: (optional) A float value. If `num_proj > 0` and `proj_clip` is + provided, then the projected values are clipped elementwise to within + `[-proj_clip, proj_clip]`. + +* <b>`num_unit_shards`</b>: How to split the weight matrix. If >1, the weight + matrix is stored across num_unit_shards. +* <b>`num_proj_shards`</b>: How to split the projection matrix. If >1, the + projection matrix is stored across num_proj_shards. +* <b>`forget_bias`</b>: Biases of the forget gate are initialized by default to 1 + in order to reduce the scale of forgetting at the beginning of + the training. +* <b>`state_is_tuple`</b>: If True, accepted and returned states are 2-tuples of + the `c_state` and `m_state`. By default (False), they are concatenated + along the column axis. This default behavior will soon be deprecated. +* <b>`activation`</b>: Activation function of the inner states. 
+ + +- - - + +#### `tf.contrib.rnn.CoupledInputForgetGateLSTMCell.output_size` {#CoupledInputForgetGateLSTMCell.output_size} + + + + +- - - + +#### `tf.contrib.rnn.CoupledInputForgetGateLSTMCell.state_size` {#CoupledInputForgetGateLSTMCell.state_size} + + + + +- - - + +#### `tf.contrib.rnn.CoupledInputForgetGateLSTMCell.zero_state(batch_size, dtype)` {#CoupledInputForgetGateLSTMCell.zero_state} + +Return zero-filled state tensor(s). + +##### Args: + + +* <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size. +* <b>`dtype`</b>: the data type to use for the state. + +##### Returns: + + If `state_size` is an int or TensorShape, then the return value is a + `N-D` tensor of shape `[batch_size x state_size]` filled with zeros. + + If `state_size` is a nested list or tuple, then the return value is + a nested list or tuple (of the same structure) of `2-D` tensors with +the shapes `[batch_size x s]` for each s in `state_size`. + + + +- - - + +### `class tf.contrib.rnn.TimeFreqLSTMCell` {#TimeFreqLSTMCell} + +Time-Frequency Long short-term memory unit (LSTM) recurrent network cell. + +This implementation is based on: + + Tara N. Sainath and Bo Li + "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures + for LVCSR Tasks." submitted to INTERSPEECH, 2016. + +It uses peep-hole connections and optional cell clipping. +- - - + +#### `tf.contrib.rnn.TimeFreqLSTMCell.__init__(num_units, use_peepholes=False, cell_clip=None, initializer=None, num_unit_shards=1, forget_bias=1.0, feature_size=None, frequency_skip=None)` {#TimeFreqLSTMCell.__init__} + +Initialize the parameters for an LSTM cell. + +##### Args: + + +* <b>`num_units`</b>: int, The number of units in the LSTM cell +* <b>`use_peepholes`</b>: bool, set True to enable diagonal/peephole connections. +* <b>`cell_clip`</b>: (optional) A float value, if provided the cell state is clipped + by this value prior to the cell output activation. +* <b>`initializer`</b>: (optional) The initializer to use for the weight and + projection matrices. +* <b>`num_unit_shards`</b>: int, How to split the weight matrix. If >1, the weight + matrix is stored across num_unit_shards. +* <b>`forget_bias`</b>: float, Biases of the forget gate are initialized by default + to 1 in order to reduce the scale of forgetting at the beginning + of the training. +* <b>`feature_size`</b>: int, The size of the input feature the LSTM spans over. +* <b>`frequency_skip`</b>: int, The amount the LSTM filter is shifted by in + frequency. + + +- - - + +#### `tf.contrib.rnn.TimeFreqLSTMCell.output_size` {#TimeFreqLSTMCell.output_size} + + + + +- - - + +#### `tf.contrib.rnn.TimeFreqLSTMCell.state_size` {#TimeFreqLSTMCell.state_size} + + + + +- - - + +#### `tf.contrib.rnn.TimeFreqLSTMCell.zero_state(batch_size, dtype)` {#TimeFreqLSTMCell.zero_state} + +Return zero-filled state tensor(s). + +##### Args: + + +* <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size. +* <b>`dtype`</b>: the data type to use for the state. + +##### Returns: + + If `state_size` is an int or TensorShape, then the return value is a + `N-D` tensor of shape `[batch_size x state_size]` filled with zeros. + + If `state_size` is a nested list or tuple, then the return value is + a nested list or tuple (of the same structure) of `2-D` tensors with +the shapes `[batch_size x s]` for each s in `state_size`. + + + +- - - + +### `class tf.contrib.rnn.GridLSTMCell` {#GridLSTMCell} + +Grid Long short-term memory unit (LSTM) recurrent network cell. 
+ +The default is based on: + Nal Kalchbrenner, Ivo Danihelka and Alex Graves + "Grid Long Short-Term Memory," Proc. ICLR 2016. + http://arxiv.org/abs/1507.01526 + +When peephole connections are used, the implementation is based on: + Tara N. Sainath and Bo Li + "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures + for LVCSR Tasks." submitted to INTERSPEECH, 2016. + +The code uses optional peephole connections, shared_weights and cell clipping. +- - - + +#### `tf.contrib.rnn.GridLSTMCell.__init__(num_units, use_peepholes=False, share_time_frequency_weights=False, cell_clip=None, initializer=None, num_unit_shards=1, forget_bias=1.0, feature_size=None, frequency_skip=None)` {#GridLSTMCell.__init__} + +Initialize the parameters for an LSTM cell. + +##### Args: + + +* <b>`num_units`</b>: int, The number of units in the LSTM cell +* <b>`use_peepholes`</b>: bool, default False. Set True to enable diagonal/peephole + connections. +* <b>`share_time_frequency_weights`</b>: bool, default False. Set True to enable + shared cell weights between time and frequency LSTMs. +* <b>`cell_clip`</b>: (optional) A float value, if provided the cell state is clipped + by this value prior to the cell output activation. +* <b>`initializer`</b>: (optional) The initializer to use for the weight and + projection matrices. +* <b>`num_unit_shards`</b>: int, How to split the weight matrix. If >1, the weight + matrix is stored across num_unit_shards. +* <b>`forget_bias`</b>: float, Biases of the forget gate are initialized by default + to 1 in order to reduce the scale of forgetting at the beginning + of the training. +* <b>`feature_size`</b>: int, The size of the input feature the LSTM spans over. +* <b>`frequency_skip`</b>: int, The amount the LSTM filter is shifted by in + frequency. + + +- - - + +#### `tf.contrib.rnn.GridLSTMCell.output_size` {#GridLSTMCell.output_size} + + + + +- - - + +#### `tf.contrib.rnn.GridLSTMCell.state_size` {#GridLSTMCell.state_size} + + + + +- - - + +#### `tf.contrib.rnn.GridLSTMCell.zero_state(batch_size, dtype)` {#GridLSTMCell.zero_state} + +Return zero-filled state tensor(s). + +##### Args: + + +* <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size. +* <b>`dtype`</b>: the data type to use for the state. + +##### Returns: + + If `state_size` is an int or TensorShape, then the return value is a + `N-D` tensor of shape `[batch_size x state_size]` filled with zeros. + + If `state_size` is a nested list or tuple, then the return value is + a nested list or tuple (of the same structure) of `2-D` tensors with +the shapes `[batch_size x s]` for each s in `state_size`. + + + + +### RNNCell wrappers +- - - + +### `class tf.contrib.rnn.AttentionCellWrapper` {#AttentionCellWrapper} + +Basic attention cell wrapper. + +Implementation based on https://arxiv.org/pdf/1601.06733.pdf. +- - - + +#### `tf.contrib.rnn.AttentionCellWrapper.__init__(cell, attn_length, attn_size=None, attn_vec_size=None, input_size=None, state_is_tuple=False)` {#AttentionCellWrapper.__init__} + +Create a cell with attention. + +##### Args: + + +* <b>`cell`</b>: an RNNCell, an attention is added to it. +* <b>`attn_length`</b>: integer, the size of an attention window. +* <b>`attn_size`</b>: integer, the size of an attention vector. Equal to + cell.output_size by default. +* <b>`attn_vec_size`</b>: integer, the number of convolutional features calculated + on attention state and a size of the hidden layer built from + base cell state. Equal attn_size to by default. 
+* <b>`input_size`</b>: integer, the size of a hidden linear layer, + built from inputs and attention. Derived from the input tensor + by default. +* <b>`state_is_tuple`</b>: If True, accepted and returned states are n-tuples, where + `n = len(cells)`. By default (False), the states are all + concatenated along the column axis. + +##### Raises: + + +* <b>`TypeError`</b>: if cell is not an RNNCell. +* <b>`ValueError`</b>: if cell returns a state tuple but the flag + `state_is_tuple` is `False` or if attn_length is zero or less. + + +- - - + +#### `tf.contrib.rnn.AttentionCellWrapper.output_size` {#AttentionCellWrapper.output_size} + + + + +- - - + +#### `tf.contrib.rnn.AttentionCellWrapper.state_size` {#AttentionCellWrapper.state_size} + + + + +- - - + +#### `tf.contrib.rnn.AttentionCellWrapper.zero_state(batch_size, dtype)` {#AttentionCellWrapper.zero_state} + +Return zero-filled state tensor(s). + +##### Args: + + +* <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size. +* <b>`dtype`</b>: the data type to use for the state. + +##### Returns: + + If `state_size` is an int or TensorShape, then the return value is a + `N-D` tensor of shape `[batch_size x state_size]` filled with zeros. + + If `state_size` is a nested list or tuple, then the return value is + a nested list or tuple (of the same structure) of `2-D` tensors with +the shapes `[batch_size x s]` for each s in `state_size`. + + + diff --git a/tensorflow/g3doc/api_docs/python/framework.md b/tensorflow/g3doc/api_docs/python/framework.md index 6108155a0ec..9bc9111fc39 100644 --- a/tensorflow/g3doc/api_docs/python/framework.md +++ b/tensorflow/g3doc/api_docs/python/framework.md @@ -1105,7 +1105,10 @@ DEPRECATED: Use outputs. ### `class tf.Tensor` {#Tensor} -Represents a value produced by an `Operation`. +Represents one of the outputs of an `Operation`. + +*Note:* the `Tensor` class will be replaced by `Output` in the future. +Currently these two are aliases for each other. A `Tensor` is a symbolic handle to one of the outputs of an `Operation`. It does not hold the values of that operation's output, diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.rnn.AttentionCellWrapper.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.rnn.AttentionCellWrapper.md new file mode 100644 index 00000000000..73f35490f75 --- /dev/null +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.rnn.AttentionCellWrapper.md @@ -0,0 +1,70 @@ +Basic attention cell wrapper. + +Implementation based on https://arxiv.org/pdf/1601.06733.pdf. +- - - + +#### `tf.contrib.rnn.AttentionCellWrapper.__init__(cell, attn_length, attn_size=None, attn_vec_size=None, input_size=None, state_is_tuple=False)` {#AttentionCellWrapper.__init__} + +Create a cell with attention. + +##### Args: + + +* <b>`cell`</b>: an RNNCell, an attention is added to it. +* <b>`attn_length`</b>: integer, the size of an attention window. +* <b>`attn_size`</b>: integer, the size of an attention vector. Equal to + cell.output_size by default. +* <b>`attn_vec_size`</b>: integer, the number of convolutional features calculated + on attention state and a size of the hidden layer built from + base cell state. Equal attn_size to by default. +* <b>`input_size`</b>: integer, the size of a hidden linear layer, + built from inputs and attention. Derived from the input tensor + by default. +* <b>`state_is_tuple`</b>: If True, accepted and returned states are n-tuples, where + `n = len(cells)`. 
By default (False), the states are all + concatenated along the column axis. + +##### Raises: + + +* <b>`TypeError`</b>: if cell is not an RNNCell. +* <b>`ValueError`</b>: if cell returns a state tuple but the flag + `state_is_tuple` is `False` or if attn_length is zero or less. + + +- - - + +#### `tf.contrib.rnn.AttentionCellWrapper.output_size` {#AttentionCellWrapper.output_size} + + + + +- - - + +#### `tf.contrib.rnn.AttentionCellWrapper.state_size` {#AttentionCellWrapper.state_size} + + + + +- - - + +#### `tf.contrib.rnn.AttentionCellWrapper.zero_state(batch_size, dtype)` {#AttentionCellWrapper.zero_state} + +Return zero-filled state tensor(s). + +##### Args: + + +* <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size. +* <b>`dtype`</b>: the data type to use for the state. + +##### Returns: + + If `state_size` is an int or TensorShape, then the return value is a + `N-D` tensor of shape `[batch_size x state_size]` filled with zeros. + + If `state_size` is a nested list or tuple, then the return value is + a nested list or tuple (of the same structure) of `2-D` tensors with +the shapes `[batch_size x s]` for each s in `state_size`. + + diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Tensor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Tensor.md index 73af134a7a5..6925d9d6d7c 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Tensor.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Tensor.md @@ -1,4 +1,7 @@ -Represents a value produced by an `Operation`. +Represents one of the outputs of an `Operation`. + +*Note:* the `Tensor` class will be replaced by `Output` in the future. +Currently these two are aliases for each other. A `Tensor` is a symbolic handle to one of the outputs of an `Operation`. It does not hold the values of that operation's output, diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.losses.hinge_loss.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.losses.hinge_loss.md new file mode 100644 index 00000000000..57758e07104 --- /dev/null +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.losses.hinge_loss.md @@ -0,0 +1,22 @@ +### `tf.contrib.losses.hinge_loss(logits, target, scope=None)` {#hinge_loss} + +Method that returns the loss tensor for hinge loss. + +##### Args: + + +* <b>`logits`</b>: The logits, a float tensor. +* <b>`target`</b>: The ground truth output tensor. Its shape should match the shape of + logits. The values of the tensor are expected to be 0.0 or 1.0. +* <b>`scope`</b>: The scope for the operations performed in computing the loss. + +##### Returns: + + A `Tensor` of same shape as logits and target representing the loss values + across the batch. + +##### Raises: + + +* <b>`ValueError`</b>: If the shapes of `logits` and `target` don't match. + diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.rnn.GridLSTMCell.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.rnn.GridLSTMCell.md new file mode 100644 index 00000000000..509f59748cd --- /dev/null +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.rnn.GridLSTMCell.md @@ -0,0 +1,77 @@ +Grid Long short-term memory unit (LSTM) recurrent network cell. + +The default is based on: + Nal Kalchbrenner, Ivo Danihelka and Alex Graves + "Grid Long Short-Term Memory," Proc. ICLR 2016. 
+ http://arxiv.org/abs/1507.01526 + +When peephole connections are used, the implementation is based on: + Tara N. Sainath and Bo Li + "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures + for LVCSR Tasks." submitted to INTERSPEECH, 2016. + +The code uses optional peephole connections, shared_weights and cell clipping. +- - - + +#### `tf.contrib.rnn.GridLSTMCell.__init__(num_units, use_peepholes=False, share_time_frequency_weights=False, cell_clip=None, initializer=None, num_unit_shards=1, forget_bias=1.0, feature_size=None, frequency_skip=None)` {#GridLSTMCell.__init__} + +Initialize the parameters for an LSTM cell. + +##### Args: + + +* <b>`num_units`</b>: int, The number of units in the LSTM cell +* <b>`use_peepholes`</b>: bool, default False. Set True to enable diagonal/peephole + connections. +* <b>`share_time_frequency_weights`</b>: bool, default False. Set True to enable + shared cell weights between time and frequency LSTMs. +* <b>`cell_clip`</b>: (optional) A float value, if provided the cell state is clipped + by this value prior to the cell output activation. +* <b>`initializer`</b>: (optional) The initializer to use for the weight and + projection matrices. +* <b>`num_unit_shards`</b>: int, How to split the weight matrix. If >1, the weight + matrix is stored across num_unit_shards. +* <b>`forget_bias`</b>: float, Biases of the forget gate are initialized by default + to 1 in order to reduce the scale of forgetting at the beginning + of the training. +* <b>`feature_size`</b>: int, The size of the input feature the LSTM spans over. +* <b>`frequency_skip`</b>: int, The amount the LSTM filter is shifted by in + frequency. + + +- - - + +#### `tf.contrib.rnn.GridLSTMCell.output_size` {#GridLSTMCell.output_size} + + + + +- - - + +#### `tf.contrib.rnn.GridLSTMCell.state_size` {#GridLSTMCell.state_size} + + + + +- - - + +#### `tf.contrib.rnn.GridLSTMCell.zero_state(batch_size, dtype)` {#GridLSTMCell.zero_state} + +Return zero-filled state tensor(s). + +##### Args: + + +* <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size. +* <b>`dtype`</b>: the data type to use for the state. + +##### Returns: + + If `state_size` is an int or TensorShape, then the return value is a + `N-D` tensor of shape `[batch_size x state_size]` filled with zeros. + + If `state_size` is a nested list or tuple, then the return value is + a nested list or tuple (of the same structure) of `2-D` tensors with +the shapes `[batch_size x s]` for each s in `state_size`. + + diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.rnn.CoupledInputForgetGateLSTMCell.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.rnn.CoupledInputForgetGateLSTMCell.md new file mode 100644 index 00000000000..0e36b224bc6 --- /dev/null +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.rnn.CoupledInputForgetGateLSTMCell.md @@ -0,0 +1,93 @@ +Long short-term memory unit (LSTM) recurrent network cell. + +The default non-peephole implementation is based on: + + http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf + +S. Hochreiter and J. Schmidhuber. +"Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. + +The peephole implementation is based on: + + https://research.google.com/pubs/archive/43905.pdf + +Hasim Sak, Andrew Senior, and Francoise Beaufays. +"Long short-term memory recurrent neural network architectures for + large scale acoustic modeling." INTERSPEECH, 2014. 
+ +The coupling of input and forget gate is based on: + + http://arxiv.org/pdf/1503.04069.pdf + +Greff et al. "LSTM: A Search Space Odyssey" + +The class uses optional peep-hole connections, and an optional projection +layer. +- - - + +#### `tf.contrib.rnn.CoupledInputForgetGateLSTMCell.__init__(num_units, use_peepholes=False, initializer=None, num_proj=None, proj_clip=None, num_unit_shards=1, num_proj_shards=1, forget_bias=1.0, state_is_tuple=False, activation=tanh)` {#CoupledInputForgetGateLSTMCell.__init__} + +Initialize the parameters for an LSTM cell. + +##### Args: + + +* <b>`num_units`</b>: int, The number of units in the LSTM cell +* <b>`use_peepholes`</b>: bool, set True to enable diagonal/peephole connections. +* <b>`initializer`</b>: (optional) The initializer to use for the weight and + projection matrices. +* <b>`num_proj`</b>: (optional) int, The output dimensionality for the projection + matrices. If None, no projection is performed. +* <b>`proj_clip`</b>: (optional) A float value. If `num_proj > 0` and `proj_clip` is + provided, then the projected values are clipped elementwise to within + `[-proj_clip, proj_clip]`. + +* <b>`num_unit_shards`</b>: How to split the weight matrix. If >1, the weight + matrix is stored across num_unit_shards. +* <b>`num_proj_shards`</b>: How to split the projection matrix. If >1, the + projection matrix is stored across num_proj_shards. +* <b>`forget_bias`</b>: Biases of the forget gate are initialized by default to 1 + in order to reduce the scale of forgetting at the beginning of + the training. +* <b>`state_is_tuple`</b>: If True, accepted and returned states are 2-tuples of + the `c_state` and `m_state`. By default (False), they are concatenated + along the column axis. This default behavior will soon be deprecated. +* <b>`activation`</b>: Activation function of the inner states. + + +- - - + +#### `tf.contrib.rnn.CoupledInputForgetGateLSTMCell.output_size` {#CoupledInputForgetGateLSTMCell.output_size} + + + + +- - - + +#### `tf.contrib.rnn.CoupledInputForgetGateLSTMCell.state_size` {#CoupledInputForgetGateLSTMCell.state_size} + + + + +- - - + +#### `tf.contrib.rnn.CoupledInputForgetGateLSTMCell.zero_state(batch_size, dtype)` {#CoupledInputForgetGateLSTMCell.zero_state} + +Return zero-filled state tensor(s). + +##### Args: + + +* <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size. +* <b>`dtype`</b>: the data type to use for the state. + +##### Returns: + + If `state_size` is an int or TensorShape, then the return value is a + `N-D` tensor of shape `[batch_size x state_size]` filled with zeros. + + If `state_size` is a nested list or tuple, then the return value is + a nested list or tuple (of the same structure) of `2-D` tensors with +the shapes `[batch_size x s]` for each s in `state_size`. + + diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.rnn.TimeFreqLSTMCell.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.rnn.TimeFreqLSTMCell.md new file mode 100644 index 00000000000..e870477b6ba --- /dev/null +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.rnn.TimeFreqLSTMCell.md @@ -0,0 +1,70 @@ +Time-Frequency Long short-term memory unit (LSTM) recurrent network cell. + +This implementation is based on: + + Tara N. Sainath and Bo Li + "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures + for LVCSR Tasks." submitted to INTERSPEECH, 2016. + +It uses peep-hole connections and optional cell clipping. 
+- - - + +#### `tf.contrib.rnn.TimeFreqLSTMCell.__init__(num_units, use_peepholes=False, cell_clip=None, initializer=None, num_unit_shards=1, forget_bias=1.0, feature_size=None, frequency_skip=None)` {#TimeFreqLSTMCell.__init__} + +Initialize the parameters for an LSTM cell. + +##### Args: + + +* <b>`num_units`</b>: int, The number of units in the LSTM cell +* <b>`use_peepholes`</b>: bool, set True to enable diagonal/peephole connections. +* <b>`cell_clip`</b>: (optional) A float value, if provided the cell state is clipped + by this value prior to the cell output activation. +* <b>`initializer`</b>: (optional) The initializer to use for the weight and + projection matrices. +* <b>`num_unit_shards`</b>: int, How to split the weight matrix. If >1, the weight + matrix is stored across num_unit_shards. +* <b>`forget_bias`</b>: float, Biases of the forget gate are initialized by default + to 1 in order to reduce the scale of forgetting at the beginning + of the training. +* <b>`feature_size`</b>: int, The size of the input feature the LSTM spans over. +* <b>`frequency_skip`</b>: int, The amount the LSTM filter is shifted by in + frequency. + + +- - - + +#### `tf.contrib.rnn.TimeFreqLSTMCell.output_size` {#TimeFreqLSTMCell.output_size} + + + + +- - - + +#### `tf.contrib.rnn.TimeFreqLSTMCell.state_size` {#TimeFreqLSTMCell.state_size} + + + + +- - - + +#### `tf.contrib.rnn.TimeFreqLSTMCell.zero_state(batch_size, dtype)` {#TimeFreqLSTMCell.zero_state} + +Return zero-filled state tensor(s). + +##### Args: + + +* <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size. +* <b>`dtype`</b>: the data type to use for the state. + +##### Returns: + + If `state_size` is an int or TensorShape, then the return value is a + `N-D` tensor of shape `[batch_size x state_size]` filled with zeros. + + If `state_size` is a nested list or tuple, then the return value is + a nested list or tuple (of the same structure) of `2-D` tensors with +the shapes `[batch_size x s]` for each s in `state_size`. + + diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.rnn.LSTMFusedCell.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.rnn.LSTMFusedCell.md new file mode 100644 index 00000000000..fec80caecf1 --- /dev/null +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.rnn.LSTMFusedCell.md @@ -0,0 +1,60 @@ +Basic LSTM recurrent network cell. + +The implementation is based on: http://arxiv.org/abs/1409.2329. + +We add forget_bias (default: 1) to the biases of the forget gate in order to +reduce the scale of forgetting in the beginning of the training. + +Unlike BasicLSTMCell, this is a monolithic op and should be much faster. The +weight and bias matrixes should be compatible as long as the variabel scope +matches. +- - - + +#### `tf.contrib.rnn.LSTMFusedCell.__init__(num_units, forget_bias=1.0, use_peephole=False)` {#LSTMFusedCell.__init__} + +Initialize the basic LSTM cell. + +##### Args: + + +* <b>`num_units`</b>: int, The number of units in the LSTM cell. +* <b>`forget_bias`</b>: float, The bias added to forget gates (see above). +* <b>`use_peephole`</b>: Whether to use peephole connectios or not. 
+ + +- - - + +#### `tf.contrib.rnn.LSTMFusedCell.output_size` {#LSTMFusedCell.output_size} + + + + +- - - + +#### `tf.contrib.rnn.LSTMFusedCell.state_size` {#LSTMFusedCell.state_size} + + + + +- - - + +#### `tf.contrib.rnn.LSTMFusedCell.zero_state(batch_size, dtype)` {#LSTMFusedCell.zero_state} + +Return zero-filled state tensor(s). + +##### Args: + + +* <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size. +* <b>`dtype`</b>: the data type to use for the state. + +##### Returns: + + If `state_size` is an int or TensorShape, then the return value is a + `N-D` tensor of shape `[batch_size x state_size]` filled with zeros. + + If `state_size` is a nested list or tuple, then the return value is + a nested list or tuple (of the same structure) of `2-D` tensors with +the shapes `[batch_size x s]` for each s in `state_size`. + + diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md index bad44886b63..448a32d72a5 100644 --- a/tensorflow/g3doc/api_docs/python/index.md +++ b/tensorflow/g3doc/api_docs/python/index.md @@ -745,12 +745,20 @@ * [`get_losses`](../../api_docs/python/contrib.losses.md#get_losses) * [`get_regularization_losses`](../../api_docs/python/contrib.losses.md#get_regularization_losses) * [`get_total_loss`](../../api_docs/python/contrib.losses.md#get_total_loss) + * [`hinge_loss`](../../api_docs/python/contrib.losses.md#hinge_loss) * [`log_loss`](../../api_docs/python/contrib.losses.md#log_loss) * [`sigmoid_cross_entropy`](../../api_docs/python/contrib.losses.md#sigmoid_cross_entropy) * [`softmax_cross_entropy`](../../api_docs/python/contrib.losses.md#softmax_cross_entropy) * [`sum_of_pairwise_squares`](../../api_docs/python/contrib.losses.md#sum_of_pairwise_squares) * [`sum_of_squares`](../../api_docs/python/contrib.losses.md#sum_of_squares) +* **[RNN (contrib)](../../api_docs/python/contrib.rnn.md)**: + * [`AttentionCellWrapper`](../../api_docs/python/contrib.rnn.md#AttentionCellWrapper) + * [`CoupledInputForgetGateLSTMCell`](../../api_docs/python/contrib.rnn.md#CoupledInputForgetGateLSTMCell) + * [`GridLSTMCell`](../../api_docs/python/contrib.rnn.md#GridLSTMCell) + * [`LSTMFusedCell`](../../api_docs/python/contrib.rnn.md#LSTMFusedCell) + * [`TimeFreqLSTMCell`](../../api_docs/python/contrib.rnn.md#TimeFreqLSTMCell) + * **[Metrics (contrib)](../../api_docs/python/contrib.metrics.md)**: * [`accuracy`](../../api_docs/python/contrib.metrics.md#accuracy) * [`aggregate_metric_map`](../../api_docs/python/contrib.metrics.md#aggregate_metric_map) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index c5c5573f211..fce04e50f7c 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1182,6 +1182,18 @@ py_test( ], ) +py_test( + name = "session_debug_test", + size = "small", + srcs = ["debug/session_debug_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework", + ":framework_test_lib", + ":session", + ], +) + cuda_py_test( name = "timeline_test", size = "small", diff --git a/tensorflow/python/debug/session_debug_test.py b/tensorflow/python/debug/session_debug_test.py new file mode 100644 index 00000000000..d9fdb240c9d --- /dev/null +++ b/tensorflow/python/debug/session_debug_test.py @@ -0,0 +1,298 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for debugger functionalities in tf.Session."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import glob
+import os
+import shutil
+import tempfile
+
+import numpy as np
+from six.moves import xrange  # pylint: disable=redefined-builtin
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.util import event_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+
+
+class SessionDebugTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    self.dump_root_ = tempfile.mkdtemp()
+
+  def tearDown(self):
+    # Tear down temporary dump directory.
+    shutil.rmtree(self.dump_root_)
+
+  def _addDebugTensorWatch(self,
+                           run_opts,
+                           node_name,
+                           output_slot,
+                           debug_op="DebugIdentity",
+                           debug_urls=None):
+    watch_opts = run_opts.debug_tensor_watch_opts
+
+    # Add a debug tensor watch for the given node name and output slot.
+    watch = watch_opts.add()
+    watch.node_name = node_name
+    watch.output_slot = output_slot
+    watch.debug_ops.append(debug_op)
+
+    if debug_urls:
+      for debug_url in debug_urls:
+        watch.debug_urls.append(debug_url)
+
+  def _verifyTensorDumpFile(self, dump_file, expected_tensor_name, debug_op,
+                            wall_time_lower_bound, expected_tensor_val):
+    """Helper method: Verify a Tensor debug dump file and its content.
+
+    Args:
+      dump_file: Path to the dump file.
+      expected_tensor_name: Expected name of the tensor, e.g., node_a:0.
+      debug_op: Name of the debug Op, e.g., DebugIdentity.
+      wall_time_lower_bound: Lower bound of the wall time.
+      expected_tensor_val: Expected tensor value, as a numpy array.
+    """
+    self.assertTrue(os.path.isfile(dump_file))
+
+    event = event_pb2.Event()
+    with open(dump_file, "rb") as f:
+      event.ParseFromString(f.read())
+
+    wall_time = event.wall_time
+    debug_node_name = event.summary.value[0].node_name
+
+    tensor_value = tensor_util.MakeNdarray(event.summary.value[0].tensor)
+
+    self.assertGreater(wall_time, wall_time_lower_bound)
+    self.assertEqual("%s:%s" % (expected_tensor_name, debug_op),
+                     debug_node_name)
+
+    if expected_tensor_val.dtype.type is np.string_:
+      self.assertEqual(str(expected_tensor_val), str(tensor_value))
+    else:
+      self.assertAllClose(expected_tensor_val, tensor_value)
+
+  def testDumpToFileOverlappingParentDir(self):
+    with session.Session() as sess:
+      u_init_val = np.array([[5.0, 3.0], [-1.0, 0.0]])
+      v_init_val = np.array([[2.0], [-1.0]])
+
+      # Use node names with overlapping namespace (i.e., parent directory) to
+      # test concurrent, non-racing directory creation.
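+      # With the "file://<dump_root>" debug URL used below, each watched
+      # tensor is expected to be dumped under "<dump_root>/<node name>/", in a
+      # file whose name starts with the node's last name component and the
+      # output slot (e.g., "read_0_"), which is what the assertions below
+      # check.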
+      u_name = "testDumpToFile/u"
+      v_name = "testDumpToFile/v"
+
+      u_init = constant_op.constant(u_init_val, shape=[2, 2])
+      u = variables.Variable(u_init, name=u_name)
+      v_init = constant_op.constant(v_init_val, shape=[2, 1])
+      v = variables.Variable(v_init, name=v_name)
+
+      w = math_ops.matmul(u, v, name="testDumpToFile/matmul")
+
+      u.initializer.run()
+      v.initializer.run()
+
+      run_options = config_pb2.RunOptions()
+      debug_url = "file://%s" % self.dump_root_
+
+      # Add debug tensor watch for u.
+      self._addDebugTensorWatch(
+          run_options, "%s/read" % u_name, 0, debug_urls=[debug_url])
+      # Add debug tensor watch for v.
+      self._addDebugTensorWatch(
+          run_options, "%s/read" % v_name, 0, debug_urls=[debug_url])
+
+      run_metadata = config_pb2.RunMetadata()
+
+      # Invoke Session.run().
+      sess.run(w, options=run_options, run_metadata=run_metadata)
+
+      # Verify the dump file for u.
+      dump_files = os.listdir(os.path.join(self.dump_root_, u_name))
+      self.assertEqual(1, len(dump_files))
+      self.assertTrue(dump_files[0].startswith("read_0_"))
+
+      dump_file = os.path.join(self.dump_root_, u_name, dump_files[0])
+      self._verifyTensorDumpFile(dump_file, "%s/read:0" % u_name,
+                                 "DebugIdentity", 0, u_init_val)
+
+      # Verify the dump file for v.
+      dump_files = os.listdir(os.path.join(self.dump_root_, v_name))
+      self.assertEqual(1, len(dump_files))
+      self.assertTrue(dump_files[0].startswith("read_0_"))
+
+      dump_file = os.path.join(self.dump_root_, v_name, dump_files[0])
+      self._verifyTensorDumpFile(dump_file, "%s/read:0" % v_name,
+                                 "DebugIdentity", 0, v_init_val)
+
+  def testDumpStringTensorsToFileSystem(self):
+    with session.Session() as sess:
+      str1_init_val = np.array(b"abc")
+      str2_init_val = np.array(b"def")
+
+      str1_init = constant_op.constant(str1_init_val)
+      str2_init = constant_op.constant(str2_init_val)
+
+      str1_name = "str1"
+      str2_name = "str2"
+      str1 = variables.Variable(str1_init, name=str1_name)
+      str2 = variables.Variable(str2_init, name=str2_name)
+      # Concatenate str1 and str2.
+      str_concat = math_ops.add(str1, str2, name="str_concat")
+
+      str1.initializer.run()
+      str2.initializer.run()
+
+      run_options = config_pb2.RunOptions()
+      debug_url = "file://%s" % self.dump_root_
+
+      # Add debug tensor watch for str1.
+      self._addDebugTensorWatch(
+          run_options, "%s/read" % str1_name, 0, debug_urls=[debug_url])
+      # Add debug tensor watch for str2.
+      self._addDebugTensorWatch(
+          run_options, "%s/read" % str2_name, 0, debug_urls=[debug_url])
+
+      run_metadata = config_pb2.RunMetadata()
+
+      # Invoke Session.run().
+      sess.run(str_concat, options=run_options, run_metadata=run_metadata)
+
+      # Verify the dump file for str1.
+      dump_files = os.listdir(os.path.join(self.dump_root_, str1_name))
+      self.assertEqual(1, len(dump_files))
+      self.assertTrue(dump_files[0].startswith("read_0_"))
+      dump_file = os.path.join(self.dump_root_, str1_name, dump_files[0])
+      self._verifyTensorDumpFile(dump_file, "%s/read:0" % str1_name,
+                                 "DebugIdentity", 0, str1_init_val)
+
+      # Verify the dump file for str2.
+      dump_files = os.listdir(os.path.join(self.dump_root_, str2_name))
+      self.assertEqual(1, len(dump_files))
+      self.assertTrue(dump_files[0].startswith("read_0_"))
+      dump_file = os.path.join(self.dump_root_, str2_name, dump_files[0])
+      self._verifyTensorDumpFile(dump_file, "%s/read:0" % str2_name,
+                                 "DebugIdentity", 0, str2_init_val)
+
+  def testDumpToFileWhileLoop(self):
+    with session.Session() as sess:
+      num_iter = 10
+
+      # "u" is the Variable being updated in the loop.
+ u_name = "testDumpToFileWhileLoop/u" + u_namespace = u_name.split("/")[0] + + u_init_val = np.array(11.0) + u_init = constant_op.constant(u_init_val) + u = variables.Variable(u_init, name=u_name) + + # "v" is the increment. + v_name = "testDumpToFileWhileLoop/v" + v_namespace = v_name.split("/")[0] + + v_init_val = np.array(2.0) + v_init = constant_op.constant(v_init_val) + v = variables.Variable(v_init, name=v_name) + + u.initializer.run() + v.initializer.run() + + i = constant_op.constant(0, name="testDumpToFileWhileLoop/i") + + def cond(i): + return math_ops.less(i, num_iter) + + def body(i): + new_u = state_ops.assign_add(u, v) + new_i = math_ops.add(i, 1) + op = control_flow_ops.group(new_u) + new_i = control_flow_ops.with_dependencies([op], new_i) + return [new_i] + + loop = control_flow_ops.while_loop(cond, body, [i], parallel_iterations=1) + + # Create RunOptions for debug-watching tensors + run_options = config_pb2.RunOptions() + debug_url = "file://%s" % self.dump_root_ + + # Add debug tensor watch for u. + self._addDebugTensorWatch(run_options, u_name, 0, debug_urls=[debug_url]) + # Add debug tensor watch for v. + self._addDebugTensorWatch( + run_options, "%s/read" % v_name, 0, debug_urls=[debug_url]) + # Add debug tensor watch for while/Identity. + self._addDebugTensorWatch( + run_options, "while/Identity", 0, debug_urls=[debug_url]) + + run_metadata = config_pb2.RunMetadata() + + r = sess.run(loop, options=run_options, run_metadata=run_metadata) + + self.assertEqual(num_iter, r) + + u_val_final = sess.run(u) + self.assertAllClose(u_init_val + num_iter * v_init_val, u_val_final) + + # Verify dump files + self.assertTrue(os.path.isdir(self.dump_root_)) + + self.assertTrue(os.path.isdir(os.path.join(self.dump_root_, u_namespace))) + self.assertTrue( + os.path.isdir(os.path.join(self.dump_root_, v_namespace, "v"))) + + # Verify the dump file for tensor "u". + dump_files = glob.glob( + os.path.join(self.dump_root_, u_namespace, "u_0_*")) + self.assertEqual(1, len(dump_files)) + dump_file = os.path.join(self.dump_root_, u_namespace, dump_files[0]) + self.assertTrue(os.path.isfile(dump_file)) + self._verifyTensorDumpFile(dump_file, "%s:0" % u_name, "DebugIdentity", 0, + u_init_val) + + # Verify the dump file for tensor "v". 
+      dump_files = os.listdir(os.path.join(self.dump_root_, v_name))
+      self.assertEqual(1, len(dump_files))
+      self.assertTrue(dump_files[0].startswith("read_0_"))
+
+      dump_file = os.path.join(self.dump_root_, v_name, dump_files[0])
+      self._verifyTensorDumpFile(dump_file, "%s/read:0" % v_name,
+                                 "DebugIdentity", 0, v_init_val)
+
+      # Verify the dump files for tensor while/Identity.
+      while_identity_dump_files = sorted(
+          os.listdir(os.path.join(self.dump_root_, "while")))
+      self.assertEqual(num_iter, len(while_identity_dump_files))
+
+      # Verify the content of the individual dump files.
+      for k in xrange(len(while_identity_dump_files)):
+        dump_file_path = os.path.join(self.dump_root_, "while",
+                                      while_identity_dump_files[k])
+        self._verifyTensorDumpFile(dump_file_path, "while/Identity:0",
+                                   "DebugIdentity", 0, np.array(k))
+
+
+if __name__ == "__main__":
+  googletest.main()
diff --git a/tensorflow/python/framework/framework_lib.py b/tensorflow/python/framework/framework_lib.py
index b06605cf592..3f77187a25c 100644
--- a/tensorflow/python/framework/framework_lib.py
+++ b/tensorflow/python/framework/framework_lib.py
@@ -72,6 +72,7 @@ from tensorflow.python.framework.device import DeviceSpec
 from tensorflow.python.framework.ops import Graph
 from tensorflow.python.framework.ops import Operation
 from tensorflow.python.framework.ops import Tensor
+from tensorflow.python.framework.ops import Output
 from tensorflow.python.framework.ops import SparseTensor
 from tensorflow.python.framework.ops import SparseTensorValue
 from tensorflow.python.framework.ops import IndexedSlices
diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py
index bfea7b6aca7..49d9cec7c19 100644
--- a/tensorflow/python/framework/gen_docs_combined.py
+++ b/tensorflow/python/framework/gen_docs_combined.py
@@ -65,6 +65,7 @@ def get_module_to_name():
       tf.contrib.learn.monitors: (
           "tf.contrib.learn.monitors"),
       tf.contrib.losses: "tf.contrib.losses",
+      tf.contrib.rnn: "tf.contrib.rnn",
       tf.contrib.metrics: "tf.contrib.metrics",
       tf.contrib.util: "tf.contrib.util",
   }
@@ -171,6 +172,7 @@ def all_libraries(module_to_name, members, documented):
       library("contrib.learn.monitors", "Monitors (contrib)",
               tf.contrib.learn.monitors),
       library("contrib.losses", "Losses (contrib)", tf.contrib.losses),
+      library("contrib.rnn", "RNN (contrib)", tf.contrib.rnn),
       library("contrib.metrics", "Metrics (contrib)", tf.contrib.metrics),
       library("contrib.util", "Utilities (contrib)", tf.contrib.util),
       library("contrib.copy_graph", "Copying Graph Elements (contrib)",
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index f89f3d46972..854d46b955e 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -185,7 +185,10 @@ def register_dense_tensor_like_type(tensor_type):
 class Tensor(object):
-  """Represents a value produced by an `Operation`.
+  """Represents one of the outputs of an `Operation`.
+
+  *Note:* the `Tensor` class will be replaced by `Output` in the future.
+  Currently these two are aliases for each other.
 
   A `Tensor` is a symbolic handle to one of the outputs of an
   `Operation`. It does not hold the values of that operation's output,
@@ -556,6 +559,10 @@ class Tensor(object):
     return _eval_using_default_session(self, feed_dict, self.graph, session)
 
 
+# TODO(josh11b): Switch everyone from "Tensor" to "Output" to match C++ API.
+Output = Tensor + + def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False): _ = name, as_ref if dtype and not dtype.is_compatible_with(t.dtype): diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index f63cf812474..abed6b5777a 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -846,8 +846,7 @@ def per_image_whitening(image): stddev = math_ops.sqrt(variance) # Apply a minimum normalization that protects us against uniform images. - min_stddev = math_ops.inv( - math_ops.sqrt(math_ops.cast(num_pixels, dtypes.float32))) + min_stddev = math_ops.rsqrt(math_ops.cast(num_pixels, dtypes.float32)) pixel_value_scale = math_ops.maximum(stddev, min_stddev) pixel_value_offset = image_mean diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 0620a3da2c4..2331f21d479 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -161,7 +161,7 @@ def _SegmentMeanGrad(op, grad): array_ops.fill(array_ops.expand_dims(input_rank - 1, 0), 1)]) ones = array_ops.fill(ones_shape, constant_op.constant(1, dtype=grad.dtype)) - scaled_grad = grad * math_ops.inv(math_ops.segment_sum(ones, op.inputs[1])) + scaled_grad = math_ops.div(grad, math_ops.segment_sum(ones, op.inputs[1])) return array_ops.gather(scaled_grad, op.inputs[1]), None diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 73e51aab7de..562c0408b94 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -1125,7 +1125,7 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): dtype=x.dtype) # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob) binary_tensor = math_ops.floor(random_tensor) - ret = x * math_ops.inv(keep_prob) * binary_tensor + ret = math_ops.div(x, keep_prob) * binary_tensor ret.set_shape(x.get_shape()) return ret
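
The nn_ops.py hunk above keeps the standard "inverted dropout" scaling: surviving
activations are rescaled so the expected value of the output matches the input;
only the arithmetic changes from multiplying by `math_ops.inv(keep_prob)` to
dividing by `keep_prob`. A standalone NumPy sketch of the same arithmetic, for
reference only (names and shapes are illustrative, not part of the change):

```python
import numpy as np

def inverted_dropout(x, keep_prob, seed=None):
    rng = np.random.RandomState(seed)
    # Uniform in [keep_prob, 1 + keep_prob); floor() yields 0. with
    # probability (1 - keep_prob) and 1. with probability keep_prob.
    random_tensor = keep_prob + rng.uniform(size=x.shape)
    binary_mask = np.floor(random_tensor)
    # Divide by keep_prob (rather than multiplying by its reciprocal) so the
    # expected value of the output equals the input.
    return x / keep_prob * binary_mask

x = np.ones((4, 4), dtype=np.float32)
print(inverted_dropout(x, keep_prob=0.5, seed=0))  # entries are 0.0 or 2.0
```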