diff --git a/configure b/configure index 030300e23a2..48ebebedf89 100755 --- a/configure +++ b/configure @@ -1,12 +1,38 @@ #!/bin/bash +## Set up python-related environment settings +while true; do + fromuser="" + if [ -z "$PYTHON_BIN_PATH" ]; then + default_python_bin_path=$(which python) + read -p "Please specify the location of python. [Default is $default_python_bin_path]: " PYTHON_BIN_PATH + fromuser="1" + if [ -z "$PYTHON_BIN_PATH" ]; then + PYTHON_BIN_PATH=$default_python_bin_path + fi + fi + if [ -e "$PYTHON_BIN_PATH" ]; then + break + fi + echo "Invalid python path. ${PYTHON_BIN_PATH} cannot be found" 1>&2 + if [ -z "$fromuser" ]; then + exit 1 + fi + PYTHON_BIN_PATH="" + # Retry +done + +# Invoke python_config and set up symlinks to python includes +(./util/python/python_config.sh --setup "$PYTHON_BIN_PATH";) || exit -1 + ## Set up Cuda-related environment settings while [ "$TF_NEED_CUDA" == "" ]; do - read -p "Do you wish to build TensorFlow with GPU support? [y/n] " INPUT + read -p "Do you wish to build TensorFlow with GPU support? [y/N] " INPUT case $INPUT in - [Yy]* ) echo -e "GPU support will be enabled for TensorFlow\n"; TF_NEED_CUDA=1;; - [Nn]* ) echo -e "No GPU support will be enabled for TensorFlow\n"; TF_NEED_CUDA=0;; + [Yy]* ) echo "GPU support will be enabled for TensorFlow"; TF_NEED_CUDA=1;; + [Nn]* ) echo "No GPU support will be enabled for TensorFlow"; TF_NEED_CUDA=0;; + "" ) echo "No GPU support will be enabled for TensorFlow"; TF_NEED_CUDA=0;; * ) echo "Invalid selection: " $INPUT;; esac done @@ -77,7 +103,7 @@ CUDNN_INSTALL_PATH="$CUDNN_INSTALL_PATH" EOF function UnofficialSetting() { - echo -e "\nWARNING: You are configuring unofficial settings in TensorFlow. Because some external libraries are not backward compatible, these settings are largely untested and unsupported. \n" + echo -e "\nWARNING: You are configuring unofficial settings in TensorFlow. Because some external libraries are not backward compatible, these settings are largely untested and unsupported. \n" 1>&2 # Configure the compute capabilities that TensorFlow builds for. # Since Cuda toolkit is not backward-compatible, this is not guaranteed to work. diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc index d254d3bd8fe..dd51d7eea9e 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc @@ -342,6 +342,7 @@ size_t GPUBFCAllocator::AllocatedSize(void* ptr) { void GPUBFCAllocator::DumpMemoryLog(size_t num_bytes) { // For each bin: tally up the total number of chunks and bytes. + // Note that bins hold only free chunks. for (auto bit : bins_) { Bin* b = bit.second; @@ -389,6 +390,24 @@ void GPUBFCAllocator::DumpMemoryLog(size_t num_bytes) { LOG(INFO) << c->DebugString(true); } } -} + // Next show the the chunks that are in use, and also summarize their + // number by size. + std::map in_use_by_size; + for (auto& it : ptr_to_chunk_map_) { + const Chunk& c = *it.second; + in_use_by_size[c.size]++; + LOG(INFO) << "Chunk at " << it.first << " of size " << c.size; + } + + LOG(INFO) << " Summary of in-use Chunks by size: "; + size_t total_bytes = 0; + for (auto& it : in_use_by_size) { + LOG(INFO) << it.second << " Chunks of size " << it.first << " totalling " + << strings::HumanReadableNumBytes(it.first * it.second); + total_bytes += (it.first * it.second); + } + LOG(INFO) << "Sum Total of in-use chunks: " + << strings::HumanReadableNumBytes(total_bytes); +} } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h index 9371ef33209..b9651c01d3f 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h @@ -115,14 +115,14 @@ class GPUBFCAllocator : public VisitableAllocator { }; Chunk* AllocateNewChunk(size_t num_bytes); - void SplitChunk(Chunk* c, size_t num_bytes); - void Merge(Chunk* c1, Chunk* c2); - void FreeAndMaybeCoalesce(Chunk* c); - void InsertFreeChunkIntoBin(Chunk* c); + void SplitChunk(Chunk* c, size_t num_bytes) EXCLUSIVE_LOCKS_REQUIRED(lock_); + void Merge(Chunk* c1, Chunk* c2) EXCLUSIVE_LOCKS_REQUIRED(lock_); + void FreeAndMaybeCoalesce(Chunk* c) EXCLUSIVE_LOCKS_REQUIRED(lock_); + void InsertFreeChunkIntoBin(Chunk* c) EXCLUSIVE_LOCKS_REQUIRED(lock_); void RemoveFreeChunkFromBin(Chunk* c); - void DeleteChunk(Chunk* c); + void DeleteChunk(Chunk* c) EXCLUSIVE_LOCKS_REQUIRED(lock_); - void DumpMemoryLog(size_t num_bytes); + void DumpMemoryLog(size_t num_bytes) EXCLUSIVE_LOCKS_REQUIRED(lock_); // A Bin is a collection of similar-sized free chunks. struct Bin { @@ -163,7 +163,7 @@ class GPUBFCAllocator : public VisitableAllocator { // Structures mutable after construction mutable mutex lock_; // Chunk * owned. - std::unordered_map ptr_to_chunk_map_; + std::unordered_map ptr_to_chunk_map_ GUARDED_BY(lock_); // Called once on each region, ASAP. std::vector region_visitors_; diff --git a/tensorflow/core/example/example.proto b/tensorflow/core/example/example.proto index 4acd8ccd72d..f4d946dcf05 100644 --- a/tensorflow/core/example/example.proto +++ b/tensorflow/core/example/example.proto @@ -7,7 +7,21 @@ import "tensorflow/core/example/feature.proto"; package tensorflow; -// Example for a movie recommendation application: +// An Example is a mostly-normalized data format for storing data for +// training and inference. It contains a key-value store (features); where +// each key (string) maps to a Feature message (which is oneof packed BytesList, +// FloatList, or Int64List). This flexible and compact format allows the +// storage of large amounts of typed data, but requires that the data shape +// and use be determined by the configuration files and parsers that are used to +// read and write this format. That is, the Example is mostly *not* a +// self-describing format. In TensorFlow, Examples are read in row-major +// format, so any configuration that describes data with rank-2 or above +// should keep this in mind. For example, to store an M x N matrix of Bytes, +// the BytesList must contain M*N bytes, with M rows of N contiguous values +// each. That is, the BytesList value must store the matrix as: +// .... row 0 .... .... row 1 .... // ........... // ... row M-1 .... +// +// An Example for a movie recommendation application: // features { // feature { // key: "age" @@ -58,7 +72,7 @@ package tensorflow; // } // } // -// A conformant data set obeys the following conventions: +// A conformant Example data set obeys the following conventions: // - If a Feature K exists in one example with data type T, it must be of // type T in all other examples when present. It may be omitted. // - The number of instances of Feature K list data may vary across examples, @@ -72,23 +86,182 @@ message Example { Features features = 1; }; -// Example representing a ranking instance. -message RankingExample { - Features context = 1; - repeated Features positive = 2; - repeated Features negative = 3; -}; +// A SequenceExample is an Example representing one or more sequences, and +// some context. The context contains features which apply to the entire +// example. The feature_lists contain a key, value map where each key is +// associated with a repeated set of Features (a FeatureList). +// +// A SequenceExample for a movie recommendation application: +// +// context: { +// feature: { +// key : "locale" +// value: { +// bytes_list: { +// value: [ "pt_BR" ] +// } +// } +// } +// feature: { +// key : "age" +// value: { +// float_list: { +// value: [ 19.0 ] +// } +// } +// } +// feature: { +// key : "favorites" +// value: { +// bytes_list: { +// value: [ "Majesty Rose", "Savannah Outen", "One Direction" ] +// } +// } +// } +// } +// feature_lists: { +// feature_list: { +// key : "movie_ratings" +// value: { +// feature: { +// float_list: { +// value: [ 4.5 ] +// } +// } +// feature: { +// float_list: { +// value: [ 5.0 ] +// } +// } +// } +// } +// feature_list: { +// key : "movie_names" +// value: { +// feature: { +// bytes_list: { +// value: [ "The Shawshank Redemption" ] +// } +// } +// feature: { +// bytes_list: { +// value: [ "Fight Club" ] +// } +// } +// } +// } +// } +// +// A conformant SequenceExample data set obeys the following conventions: +// +// Context: +// - All conformant context features K must obey the same conventions as +// a conformant Example's features (see above). +// Feature lists: +// - A FeatureList L may be missing in an example; it is up to the +// parser configuration to determine if this is allowed or considered +// an empty list (zero length). +// - If a FeatureList L exists, it may be empty (zero length). +// - If a FeatureList L is non-empty, all features within the FeatureList +// must have data type T, and all features within the FeatureList must +// have the same size. +// - If a FeatureList L exists in one example with data type T, +// it must be of type T in all other examples when present. +// - If a FeatureList L exists in one example having features' sizes all S, +// these sizes must be S in all other examples when present. +// +// Examples of conformant and non-conformant examples' FeatureLists: +// +// Conformant FeatureLists: +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { float_list: { value: [ 4.5 ] } } +// feature: { float_list: { value: [ 5.0 ] } } } +// } } +// +// Non-conformant FeatureLists (mismatched types): +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { float_list: { value: [ 4.5 ] } } +// feature: { int64_list: { value: [ 5 ] } } } +// } } +// +// Non-conformant FeatureLists (mismatched sizes): +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { float_list: { value: [ 4.5 ] } } +// feature: { float_list: { value: [ 5.0, 6.0 ] } } } +// } } +// +// Conformant pair of SequenceExample +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { float_list: { value: [ 4.5 ] } } +// feature: { float_list: { value: [ 5.0 ] } } } +// } } +// and: +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { float_list: { value: [ 4.5 ] } } +// feature: { float_list: { value: [ 5.0 ] } } +// feature: { float_list: { value: [ 2.0 ] } } } +// } } +// +// Conformant pair of SequenceExample +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { float_list: { value: [ 4.5 ] } } +// feature: { float_list: { value: [ 5.0 ] } } } +// } } +// and: +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { } +// } } +// +// Conditionally conformant pair of SequenceExample, the parser configuration +// determines if the second feature_lists is consistent (zero-length) or +// invalid (missing "movie_ratings"): +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { float_list: { value: [ 4.5 ] } } +// feature: { float_list: { value: [ 5.0 ] } } } +// } } +// and: +// feature_lists: { } +// +// Non-conformant pair of SequenceExample (mismatched types) +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { float_list: { value: [ 4.5 ] } } +// feature: { float_list: { value: [ 5.0 ] } } } +// } } +// and: +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { int64_list: { value: [ 4 ] } } +// feature: { int64_list: { value: [ 5 ] } } +// feature: { int64_list: { value: [ 2 ] } } } +// } } +// +// Non-conformant pair of SequenceExample (mismatched sizes) +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { float_list: { value: [ 4.5 ] } } +// feature: { float_list: { value: [ 5.0 ] } } } +// } } +// and: +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { float_list: { value: [ 4.0, 5.0 ] } } +// feature: { float_list: { value: [ 5.0, 3.0 ] } } +// } } -// Example representing a sequence. -// The context contains features which apply to the entire sequence. -// Each element in example represents an entry in the sequence. message SequenceExample { Features context = 1; - repeated Features features = 2; + FeatureLists feature_lists = 2; }; -// Example representing a list of feature maps. -// The context contains features which apply to all feature maps. message InferenceExample { Features context = 1; repeated Features features = 2; diff --git a/tensorflow/core/example/feature.proto b/tensorflow/core/example/feature.proto index 69f19233609..52d5fac4411 100644 --- a/tensorflow/core/example/feature.proto +++ b/tensorflow/core/example/feature.proto @@ -6,7 +6,8 @@ // - float // - int64 // -// Base features are contained in Lists which may hold zero or more values. +// A Feature contains Lists which may hold zero or more values. These +// lists are the base values BytesList, FloatList, Int64List. // // Features are organized into categories by name. The Features message // contains the mapping from name to Feature. @@ -50,12 +51,25 @@ // value: 9.99 // }} // } +// syntax = "proto3"; // option cc_enable_arenas = true; package tensorflow; +// Containers to hold repeated fundamental values. +message BytesList { + repeated bytes value = 1; +} +message FloatList { + repeated float value = 1 [packed = true]; +} +message Int64List { + repeated int64 value = 1 [packed = true]; +} + +// Containers for non-sequential data. message Feature { // Each feature can be exactly one kind. oneof kind { @@ -70,13 +84,19 @@ message Features { map feature = 1; }; -// Containers to hold repeated fundamental features. -message BytesList { - repeated bytes value = 1; -} -message FloatList { - repeated float value = 1 [packed = true]; -} -message Int64List { - repeated int64 value = 1 [packed = true]; -} +// Containers for sequential data. +// +// A FeatureList contains lists of Features. These may hold zero or more +// Feature values. +// +// FeatureLists are organized into categories by name. The FeatureLists message +// contains the mapping from name to FeatureList. +// +message FeatureList { + repeated Feature feature = 1; +}; + +message FeatureLists { + // Map from feature name to feature list. + map feature_list = 1; +}; diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h index 0b228c7b98b..94f029d5b3f 100644 --- a/tensorflow/core/kernels/cwise_ops.h +++ b/tensorflow/core/kernels/cwise_ops.h @@ -29,16 +29,6 @@ limitations under the License. namespace Eigen { namespace internal { -template -struct scalar_sign_op { - // TODO(zhifengc): this only works for real types. In theory, - // sign(x) = x / |x| works for both real and complex values. - EIGEN_EMPTY_STRUCT_CTOR(scalar_sign_op); - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { - return T(x > T(0)) - T(x < T(0)); - } -}; - // TODO(zhifengc): Eigen::internal::pow_impl does not have proper // EIGEN host/device decoration. We duplicate code here for now. template diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc b/tensorflow/core/kernels/random_shuffle_queue_op.cc index 513d7b8390c..cd1093dd448 100644 --- a/tensorflow/core/kernels/random_shuffle_queue_op.cc +++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc @@ -312,7 +312,7 @@ void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, int32 s = queues_[0].size(); if (closed_ && s < attempt->elements_requested) { attempt->context->SetStatus(errors::OutOfRange( - "RandomSuffleQueue '", name_, "' is closed and has ", + "RandomShuffleQueue '", name_, "' is closed and has ", "insufficient elements (requested ", attempt->elements_requested, ", current size ", s, ")")); return kComplete; diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc index 8b78ef5df25..c6e876b8b54 100644 --- a/tensorflow/core/kernels/relu_op.cc +++ b/tensorflow/core/kernels/relu_op.cc @@ -42,18 +42,6 @@ class ReluOp : public UnaryElementWiseOp> { } }; -template -class Relu6Op : public UnaryElementWiseOp> { - public: - using UnaryElementWiseOp>::UnaryElementWiseOp; - - void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { - functor::Relu6 functor; - functor(context->eigen_device(), input.flat(), - output->flat()); - } -}; - template class ReluGradOp : public BinaryElementWiseOp> { public: @@ -75,6 +63,18 @@ class ReluGradOp : public BinaryElementWiseOp> { } }; +template +class Relu6Op : public UnaryElementWiseOp> { + public: + using UnaryElementWiseOp>::UnaryElementWiseOp; + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + functor::Relu6 functor; + functor(context->eigen_device(), input.flat(), + output->flat()); + } +}; + template class Relu6GradOp : public BinaryElementWiseOp> { public: @@ -96,52 +96,110 @@ class Relu6GradOp : public BinaryElementWiseOp> { } }; -#define REGISTER_KERNELS(type) \ +template +class EluOp : public UnaryElementWiseOp> { + public: + using UnaryElementWiseOp>::UnaryElementWiseOp; + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + functor::Elu functor; + functor(context->eigen_device(), input.flat(), + output->flat()); + } +}; + +template +class EluGradOp : public BinaryElementWiseOp> { + public: + using BinaryElementWiseOp>::BinaryElementWiseOp; + + // INPUTS: + // g (gradients): backpropagated gradients + // a (outputs): outputs of the EluOp() + // OUTPUT: + // gradients to backprop + template + void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, + Tensor* output) { + OP_REQUIRES(context, a.IsSameSize(g), + errors::InvalidArgument("g and a must be the same size")); + functor::EluGrad functor; + functor(context->eigen_device(), g.flat(), a.flat(), + output->flat()); + } +}; + +#define REGISTER_RELU_KERNELS(type) \ REGISTER_KERNEL_BUILDER( \ Name("Relu").Device(DEVICE_CPU).TypeConstraint("T"), \ ReluOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6").Device(DEVICE_CPU).TypeConstraint("T"), \ - Relu6Op); \ REGISTER_KERNEL_BUILDER( \ Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ ReluGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6").Device(DEVICE_CPU).TypeConstraint("T"), \ + Relu6Op); \ REGISTER_KERNEL_BUILDER( \ Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint("T"), \ Relu6GradOp) -TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); -#undef REGISTER_KERNELS +TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS); +#undef REGISTER_RELU_KERNELS + +#define REGISTER_ELU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Elu").Device(DEVICE_CPU).TypeConstraint("T"), \ + EluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("EluGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ + EluGradOp) + +// Elu only makes sense with float or double. +TF_CALL_GPU_NUMBER_TYPES(REGISTER_ELU_KERNELS); +#undef REGISTER_ELU_KERNELS #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - void Relu::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor features, \ - typename TTypes::Tensor activations); \ - extern template struct Relu; \ - \ - template <> \ - void ReluGrad::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor gradients, \ - typename TTypes::ConstTensor features, \ - typename TTypes::Tensor backprops); \ - \ - extern template struct ReluGrad; \ - template <> \ - void Relu6::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor features, \ - typename TTypes::Tensor activations); \ - extern template struct Relu6; \ - \ - template <> \ - void Relu6Grad::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor gradients, \ - typename TTypes::ConstTensor features, \ - typename TTypes::Tensor backprops); \ - extern template struct Relu6Grad; +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void Relu::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor features, \ + typename TTypes::Tensor activations); \ + extern template struct Relu; \ + \ + template <> \ + void ReluGrad::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor gradients, \ + typename TTypes::ConstTensor features, \ + typename TTypes::Tensor backprops); \ + extern template struct ReluGrad; \ + \ + template <> \ + void Relu6::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor features, \ + typename TTypes::Tensor activations); \ + extern template struct Relu6; \ + \ + template <> \ + void Relu6Grad::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor gradients, \ + typename TTypes::ConstTensor features, \ + typename TTypes::Tensor backprops); \ + extern template struct Relu6Grad; \ + \ + template <> \ + void Elu::operator()(const GPUDevice& d, \ + typename TTypes::ConstTensor features, \ + typename TTypes::Tensor activations); \ + extern template struct Elu; \ + \ + template <> \ + void EluGrad::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor gradients, \ + typename TTypes::ConstTensor activations, \ + typename TTypes::Tensor backprops); \ + extern template struct EluGrad; TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); } // namespace functor @@ -151,15 +209,21 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); REGISTER_KERNEL_BUILDER( \ Name("Relu").Device(DEVICE_GPU).TypeConstraint("T"), \ ReluOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6").Device(DEVICE_GPU).TypeConstraint("T"), \ - Relu6Op); \ REGISTER_KERNEL_BUILDER( \ Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ ReluGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6").Device(DEVICE_GPU).TypeConstraint("T"), \ + Relu6Op); \ REGISTER_KERNEL_BUILDER( \ Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint("T"), \ - Relu6GradOp) + Relu6GradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Elu").Device(DEVICE_GPU).TypeConstraint("T"), \ + EluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("EluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ + EluGradOp) TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h index b5717a49c50..22e8bde0326 100644 --- a/tensorflow/core/kernels/relu_op.h +++ b/tensorflow/core/kernels/relu_op.h @@ -88,6 +88,40 @@ struct Relu6Grad { } }; +// Functor used by EluOp to do the computations. +template +struct Elu { + // Computes Relu activation. + // + // features: any shape. + // activations: same shape as "features". + void operator()(const Device& d, typename TTypes::ConstTensor features, + typename TTypes::Tensor activations) { + // features.constant(?) + activations.device(d) = + (features < static_cast(0)) + .select(features.exp() - features.constant(static_cast(1)), + features); + } +}; + +// Functor used by EluGradOp to do the computations. +template +struct EluGrad { + // Computes EluGrad backprops. + // + // gradients: gradients backpropagated to the Elu op. + // activations: outputs of the Elu op. + // backprops: gradients to backpropagate to the Elu inputs. + void operator()(const Device& d, typename TTypes::ConstTensor gradients, + typename TTypes::ConstTensor activations, + typename TTypes::Tensor backprops) { + backprops.device(d) = + (activations < static_cast(0)) + .select((activations + static_cast(1)) * gradients, gradients); + } +}; + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc index 65c17ad047a..6451619768b 100644 --- a/tensorflow/core/kernels/relu_op_gpu.cu.cc +++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc @@ -29,11 +29,13 @@ namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; // Definition of the GPU implementations declared in relu_op.cc. -#define DEFINE_GPU_KERNELS(T) \ - template struct functor::Relu; \ - template struct functor::ReluGrad; \ - template struct functor::Relu6; \ - template struct functor::Relu6Grad; +#define DEFINE_GPU_KERNELS(T) \ + template struct functor::Relu; \ + template struct functor::ReluGrad; \ + template struct functor::Relu6; \ + template struct functor::Relu6Grad; \ + template struct functor::Elu; \ + template struct functor::EluGrad TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 29a71730950..01fc92c072a 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -445,6 +445,31 @@ backprops: The gradients: `gradients * features * (features > 0) * (features < 6)`. )doc"); +REGISTER_OP("Elu") + .Input("features: T") + .Output("activations: T") + .Attr("T: {float, double}") + .Doc(R"doc( +Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise. + +See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs) +](http://arxiv.org/abs/1511.07289) +)doc"); + +REGISTER_OP("EluGrad") + .Input("gradients: T") + .Input("outputs: T") + .Output("backprops: T") + .Attr("T: {float, double}") + .Doc(R"doc( +Computes gradients for the exponential linear (Elu) operation. + +gradients: The backpropagated gradients to the corresponding Elu operation. +outputs: The outputs of the corresponding Elu operation. +backprops: The gradients: `gradients * (outputs + 1)` if outputs < 0, +`gradients` otherwise. +)doc"); + REGISTER_OP("Softplus") .Input("features: T") .Output("activations: T") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index bcff941d82d..b529ef9c76a 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -2096,6 +2096,58 @@ op { summary: "Computes the (possibly normalized) Levenshtein Edit Distance." description: "The inputs are variable-length sequences provided by SparseTensors\n (hypothesis_indices, hypothesis_values, hypothesis_shape)\nand\n (truth_indices, truth_values, truth_shape).\n\nThe inputs are:" } +op { + name: "Elu" + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "activations" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + summary: "Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise." + description: "See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)\n](http://arxiv.org/abs/1511.07289)" +} +op { + name: "EluGrad" + input_arg { + name: "gradients" + description: "The backpropagated gradients to the corresponding Elu operation." + type_attr: "T" + } + input_arg { + name: "outputs" + description: "The outputs of the corresponding Elu operation." + type_attr: "T" + } + output_arg { + name: "backprops" + description: "The gradients: `gradients * (outputs + 1)` if outputs < 0,\n`gradients` otherwise." + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + summary: "Computes gradients for the exponential linear (Elu) operation." +} op { name: "EncodeJpeg" input_arg { diff --git a/tensorflow/examples/android/jni/imageutils_jni.cc b/tensorflow/examples/android/jni/imageutils_jni.cc index 1139ecaa3b3..c3063c2d36c 100644 --- a/tensorflow/examples/android/jni/imageutils_jni.cc +++ b/tensorflow/examples/android/jni/imageutils_jni.cc @@ -38,10 +38,14 @@ IMAGEUTILS_METHOD(convertYUV420SPToARGB8888)( JNIEnv* env, jclass clazz, jbyteArray input, jintArray output, jint width, jint height, jboolean halfSize); -JNIEXPORT void JNICALL -IMAGEUTILS_METHOD(convertYUV420SPToRGB565)( - JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output, - jint width, jint height); +JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420ToARGB8888)( + JNIEnv* env, jclass clazz, jbyteArray y, jbyteArray u, jbyteArray v, + jintArray output, jint width, jint height, jint y_row_stride, + jint uv_row_stride, jint uv_pixel_stride, jboolean halfSize); + +JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420SPToRGB565)( + JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output, jint width, + jint height); JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertARGB8888ToYUV420SP)( @@ -82,10 +86,39 @@ IMAGEUTILS_METHOD(convertYUV420SPToARGB8888)( env->ReleaseIntArrayElements(output, o, 0); } -JNIEXPORT void JNICALL -IMAGEUTILS_METHOD(convertYUV420SPToRGB565)( - JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output, - jint width, jint height) { +JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420ToARGB8888)( + JNIEnv* env, jclass clazz, jbyteArray y, jbyteArray u, jbyteArray v, + jintArray output, jint width, jint height, jint y_row_stride, + jint uv_row_stride, jint uv_pixel_stride, jboolean halfSize) { + jboolean inputCopy = JNI_FALSE; + jbyte* const y_buff = env->GetByteArrayElements(y, &inputCopy); + jboolean outputCopy = JNI_FALSE; + jint* const o = env->GetIntArrayElements(output, &outputCopy); + + if (halfSize) { + ConvertYUV420SPToARGB8888HalfSize(reinterpret_cast(y_buff), + reinterpret_cast(o), width, + height); + } else { + jbyte* const u_buff = env->GetByteArrayElements(u, &inputCopy); + jbyte* const v_buff = env->GetByteArrayElements(v, &inputCopy); + + ConvertYUV420ToARGB8888( + reinterpret_cast(y_buff), reinterpret_cast(u_buff), + reinterpret_cast(v_buff), reinterpret_cast(o), width, + height, y_row_stride, uv_row_stride, uv_pixel_stride); + + env->ReleaseByteArrayElements(u, u_buff, JNI_ABORT); + env->ReleaseByteArrayElements(v, v_buff, JNI_ABORT); + } + + env->ReleaseByteArrayElements(y, y_buff, JNI_ABORT); + env->ReleaseIntArrayElements(output, o, 0); +} + +JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420SPToRGB565)( + JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output, jint width, + jint height) { jboolean inputCopy = JNI_FALSE; jbyte* const i = env->GetByteArrayElements(input, &inputCopy); diff --git a/tensorflow/examples/android/jni/yuv2rgb.cc b/tensorflow/examples/android/jni/yuv2rgb.cc index c50a2268061..0e73fb4340b 100644 --- a/tensorflow/examples/android/jni/yuv2rgb.cc +++ b/tensorflow/examples/android/jni/yuv2rgb.cc @@ -27,6 +27,58 @@ limitations under the License. // are normalized to eight bits. static const int kMaxChannelValue = 262143; +static inline uint32 YUV2RGB(int nY, int nU, int nV) { + nY -= 16; + nU -= 128; + nV -= 128; + if (nY < 0) nY = 0; + + // This is the floating point equivalent. We do the conversion in integer + // because some Android devices do not have floating point in hardware. + // nR = (int)(1.164 * nY + 2.018 * nU); + // nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU); + // nB = (int)(1.164 * nY + 1.596 * nV); + + int nR = (int)(1192 * nY + 1634 * nV); + int nG = (int)(1192 * nY - 833 * nV - 400 * nU); + int nB = (int)(1192 * nY + 2066 * nU); + + nR = MIN(kMaxChannelValue, MAX(0, nR)); + nG = MIN(kMaxChannelValue, MAX(0, nG)); + nB = MIN(kMaxChannelValue, MAX(0, nB)); + + nR = (nR >> 10) & 0xff; + nG = (nG >> 10) & 0xff; + nB = (nB >> 10) & 0xff; + + return 0xff000000 | (nR << 16) | (nG << 8) | nB; +} + +// Accepts a YUV 4:2:0 image with a plane of 8 bit Y samples followed by +// separate u and v planes with arbitrary row and column strides, +// containing 8 bit 2x2 subsampled chroma samples. +// Converts to a packed ARGB 32 bit output of the same pixel dimensions. +void ConvertYUV420ToARGB8888(const uint8* const yData, const uint8* const uData, + const uint8* const vData, uint32* const output, + const int width, const int height, + const int y_row_stride, const int uv_row_stride, + const int uv_pixel_stride) { + uint32* out = output; + + for (int y = 0; y < height; y++) { + const uint8* pY = yData + y_row_stride * y; + + const int uv_row_start = uv_row_stride * (y >> 1); + const uint8* pU = uData + uv_row_start; + const uint8* pV = vData + uv_row_start; + + for (int x = 0; x < width; x++) { + const int uv_offset = (x >> 1) * uv_pixel_stride; + *out++ = YUV2RGB(pY[x], pU[uv_offset], pV[uv_offset]); + } + } +} + // Accepts a YUV 4:2:0 image with a plane of 8 bit Y samples followed by an // interleaved U/V plane containing 8 bit 2x2 subsampled chroma samples, // except the interleave order of U and V is reversed. Converts to a packed @@ -51,29 +103,7 @@ void ConvertYUV420SPToARGB8888(const uint8* const yData, int nU = pUV[offset + 1]; #endif - nY -= 16; - nU -= 128; - nV -= 128; - if (nY < 0) nY = 0; - - // This is the floating point equivalent. We do the conversion in integer - // because some Android devices do not have floating point in hardware. - // nR = (int)(1.164 * nY + 2.018 * nU); - // nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU); - // nB = (int)(1.164 * nY + 1.596 * nV); - - int nR = (int)(1192 * nY + 1634 * nV); - int nG = (int)(1192 * nY - 833 * nV - 400 * nU); - int nB = (int)(1192 * nY + 2066 * nU); - - nR = MIN(kMaxChannelValue, MAX(0, nR)); - nG = MIN(kMaxChannelValue, MAX(0, nG)); - nB = MIN(kMaxChannelValue, MAX(0, nB)); - - nR = (nR >> 10) & 0xff; - nG = (nG >> 10) & 0xff; - nB = (nB >> 10) & 0xff; - *out++ = 0xff000000 | (nR << 16) | (nG << 8) | nB; + *out++ = YUV2RGB(nY, nU, nV); } } } @@ -101,23 +131,7 @@ void ConvertYUV420SPToARGB8888HalfSize(const uint8* const input, int nU = *pUV++; #endif - nY -= 16; - nU -= 128; - nV -= 128; - if (nY < 0) nY = 0; - - int nR = (int)(1192 * nY + 1634 * nV); - int nG = (int)(1192 * nY - 833 * nV - 400 * nU); - int nB = (int)(1192 * nY + 2066 * nU); - - nR = MIN(kMaxChannelValue, MAX(0, nR)); - nG = MIN(kMaxChannelValue, MAX(0, nG)); - nB = MIN(kMaxChannelValue, MAX(0, nB)); - - nR = (nR >> 10) & 0xff; - nG = (nG >> 10) & 0xff; - nB = (nB >> 10) & 0xff; - *out++ = 0xff000000 | (nR << 16) | (nG << 8) | nB; + *out++ = YUV2RGB(nY, nU, nV); } pY += stride; } diff --git a/tensorflow/examples/android/jni/yuv2rgb.h b/tensorflow/examples/android/jni/yuv2rgb.h index fda526f65a4..37a996f0b3b 100644 --- a/tensorflow/examples/android/jni/yuv2rgb.h +++ b/tensorflow/examples/android/jni/yuv2rgb.h @@ -27,6 +27,12 @@ using namespace tensorflow; extern "C" { #endif +void ConvertYUV420ToARGB8888(const uint8* const yData, const uint8* const uData, + const uint8* const vData, uint32* const output, + const int width, const int height, + const int y_row_stride, const int uv_row_stride, + const int uv_pixel_stride); + // Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width // and height. The input and output must already be allocated and non-null. // For efficiency, no error checking is performed. diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorflowImageListener.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorflowImageListener.java index 0deb7c151eb..16aa5813b15 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorflowImageListener.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorflowImageListener.java @@ -24,12 +24,12 @@ import android.media.Image; import android.media.Image.Plane; import android.media.ImageReader; import android.media.ImageReader.OnImageAvailableListener; + import junit.framework.Assert; import org.tensorflow.demo.env.ImageUtils; import org.tensorflow.demo.env.Logger; -import java.nio.ByteBuffer; import java.util.List; /** @@ -38,7 +38,7 @@ import java.util.List; public class TensorflowImageListener implements OnImageAvailableListener { private static final Logger LOGGER = new Logger(); - private static final boolean SAVE_PREVIEW_BITMAP = false; + private static final boolean SAVE_PREVIEW_BITMAP = true; private static final String MODEL_FILE = "file:///android_asset/tensorflow_inception_graph.pb"; private static final String LABEL_FILE = @@ -55,7 +55,7 @@ public class TensorflowImageListener implements OnImageAvailableListener { private int previewWidth = 0; private int previewHeight = 0; - private byte[] yuvBytes = null; + private byte[][] yuvBytes; private int[] rgbBytes = null; private Bitmap rgbFrameBitmap = null; private Bitmap croppedBitmap = null; @@ -68,44 +68,6 @@ public class TensorflowImageListener implements OnImageAvailableListener { this.scoreView = scoreView; } - private void readPlanesToYuvBuffer(final Plane[] planes, final byte[] yuvBytes) { - int position = 0; - - // Copy the bytes from the Image into a buffer for easier conversion to RGB. - // TODO(andrewharp): Modify native code to accept multiple buffers so that - // only one pass is necessary during conversion to RGB. - final Plane yPlane = planes[0]; - final ByteBuffer yBuffer = yPlane.getBuffer(); - final int yRowStride = yPlane.getRowStride(); - - // Read the y (luminance buffer). - for (int row = 0; row < previewHeight; ++row) { - yBuffer.position(yRowStride * row); - - // Pixel stride is guaranteed to be 1 so we can - // just do a copy operation. - yBuffer.get(yuvBytes, position, previewWidth); - position += previewWidth; - } - - // Interleave the u and v buffers. - final ByteBuffer uBuffer = planes[1].getBuffer(); - final ByteBuffer vBuffer = planes[2].getBuffer(); - final int uvPixelStride = planes[1].getPixelStride(); - final int uvWidth = previewWidth / 2; - final int uvHeight = previewHeight / 2; - Assert.assertEquals( - planes[1].getRowStride(), planes[2].getRowStride()); - for (int y = 0; y < uvHeight; ++y) { - int readPos = planes[1].getRowStride() * y; - for (int x = 0; x < uvWidth; ++x) { - yuvBytes[position++] = vBuffer.get(readPos); - yuvBytes[position++] = uBuffer.get(readPos); - readPos += uvPixelStride; - } - } - } - private void drawResizedBitmap(final Bitmap src, final Bitmap dst) { Assert.assertEquals(dst.getWidth(), dst.getHeight()); final float minDim = Math.min(src.getWidth(), src.getHeight()); @@ -141,6 +103,8 @@ public class TensorflowImageListener implements OnImageAvailableListener { return; } + final Plane[] planes = image.getPlanes(); + // Initialize the storage bitmaps once when the resolution is known. if (previewWidth != image.getWidth() || previewHeight != image.getHeight()) { previewWidth = image.getWidth(); @@ -148,16 +112,35 @@ public class TensorflowImageListener implements OnImageAvailableListener { LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight); rgbBytes = new int[previewWidth * previewHeight]; - yuvBytes = new byte[ImageUtils.getYUVByteSize(previewWidth, previewHeight)]; rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888); croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888); + + yuvBytes = new byte[planes.length][]; + for (int i = 0; i < planes.length; ++i) { + yuvBytes[i] = new byte[planes[i].getBuffer().capacity()]; + } } - readPlanesToYuvBuffer(image.getPlanes(), yuvBytes); + for (int i = 0; i < planes.length; ++i) { + planes[i].getBuffer().get(yuvBytes[i]); + } + + final int yRowStride = planes[0].getRowStride(); + final int uvRowStride = planes[1].getRowStride(); + final int uvPixelStride = planes[1].getPixelStride(); + ImageUtils.convertYUV420ToARGB8888( + yuvBytes[0], + yuvBytes[1], + yuvBytes[2], + rgbBytes, + previewWidth, + previewHeight, + yRowStride, + uvRowStride, + uvPixelStride, + false); image.close(); - - ImageUtils.convertYUV420SPToARGB8888(yuvBytes, rgbBytes, previewWidth, previewHeight, false); } catch (final Exception e) { if (image != null) { image.close(); diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java b/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java index f9e564908b0..9c87e50a764 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java @@ -87,6 +87,32 @@ public class ImageUtils { public static native void convertYUV420SPToARGB8888( byte[] input, int[] output, int width, int height, boolean halfSize); + /** + * Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width + * and height. The input and output must already be allocated and non-null. + * For efficiency, no error checking is performed. + * + * @param y + * @param u + * @param v + * @param uvPixelStride + * @param width The width of the input image. + * @param height The height of the input image. + * @param halfSize If true, downsample to 50% in each dimension, otherwise not. + * @param output A pre-allocated array for the ARGB 8:8:8:8 output data. + */ + public static native void convertYUV420ToARGB8888( + byte[] y, + byte[] u, + byte[] v, + int[] output, + int width, + int height, + int yRowStride, + int uvRowStride, + int uvPixelStride, + boolean halfSize); + /** * Converts YUV420 semi-planar data to RGB 565 data using the supplied width * and height. The input and output must already be allocated and non-null. diff --git a/tensorflow/g3doc/api_docs/python/constant_op.md b/tensorflow/g3doc/api_docs/python/constant_op.md index 13aa64aef7d..3aec8c11367 100644 --- a/tensorflow/g3doc/api_docs/python/constant_op.md +++ b/tensorflow/g3doc/api_docs/python/constant_op.md @@ -402,7 +402,7 @@ deviations from the mean are dropped and re-picked. - - - -### `tf.random_uniform(shape, minval=0.0, maxval=1.0, dtype=tf.float32, seed=None, name=None)` {#random_uniform} +### `tf.random_uniform(shape, minval=0, maxval=None, dtype=tf.float32, seed=None, name=None)` {#random_uniform} Outputs random values from a uniform distribution. @@ -410,15 +410,24 @@ The generated values follow a uniform distribution in the range `[minval, maxval)`. The lower bound `minval` is included in the range, while the upper bound `maxval` is excluded. +For floats, the default range is `[0, 1)`. For ints, at least `maxval` must +be specified explicitly. + +In the integer case, the random integers are slightly biased unless +`maxval - minval` is an exact power of two. The bias is small for values of +`maxval - minval` significantly smaller than the range of the output (either +`2**32` or `2**64`). + ##### Args: * `shape`: A 1-D integer Tensor or Python array. The shape of the output tensor. * `minval`: A 0-D Tensor or Python value of type `dtype`. The lower bound on the - range of random values to generate. + range of random values to generate. Defaults to 0. * `maxval`: A 0-D Tensor or Python value of type `dtype`. The upper bound on - the range of random values to generate. -* `dtype`: The type of the output. + the range of random values to generate. Defaults to 1 if `dtype` is + floating point. +* `dtype`: The type of the output: `float32`, `float64`, `int32`, or `int64`. * `seed`: A Python integer. Used to create a random seed for the distribution. See [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed) @@ -429,6 +438,11 @@ the upper bound `maxval` is excluded. A tensor of the specified shape filled with random uniform values. +##### Raises: + + +* `ValueError`: If `dtype` is integral and `maxval` is not specified. + - - - diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md index b04559287cd..b46ef9ffbc6 100644 --- a/tensorflow/g3doc/api_docs/python/index.md +++ b/tensorflow/g3doc/api_docs/python/index.md @@ -270,6 +270,7 @@ * [`conv2d`](../../api_docs/python/nn.md#conv2d) * [`depthwise_conv2d`](../../api_docs/python/nn.md#depthwise_conv2d) * [`dropout`](../../api_docs/python/nn.md#dropout) + * [`elu`](../../api_docs/python/nn.md#elu) * [`embedding_lookup`](../../api_docs/python/nn.md#embedding_lookup) * [`fixed_unigram_candidate_sampler`](../../api_docs/python/nn.md#fixed_unigram_candidate_sampler) * [`in_top_k`](../../api_docs/python/nn.md#in_top_k) @@ -332,6 +333,8 @@ * [`Coordinator`](../../api_docs/python/train.md#Coordinator) * [`exponential_decay`](../../api_docs/python/train.md#exponential_decay) * [`ExponentialMovingAverage`](../../api_docs/python/train.md#ExponentialMovingAverage) + * [`FeatureList`](../../api_docs/python/train.md#FeatureList) + * [`FeatureLists`](../../api_docs/python/train.md#FeatureLists) * [`FtrlOptimizer`](../../api_docs/python/train.md#FtrlOptimizer) * [`global_norm`](../../api_docs/python/train.md#global_norm) * [`global_step`](../../api_docs/python/train.md#global_step) diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md index 67c315745dc..1ce124a8350 100644 --- a/tensorflow/g3doc/api_docs/python/nn.md +++ b/tensorflow/g3doc/api_docs/python/nn.md @@ -10,9 +10,10 @@ Note: Functions taking `Tensor` arguments can also take anything accepted by ## Activation Functions The activation ops provide different types of nonlinearities for use in neural -networks. These include smooth nonlinearities (`sigmoid`, `tanh`, `softplus`, -and `softsign`), continuous but not everywhere differentiable functions (`relu`, -`relu6`, and `relu_x`), and random regularization (`dropout`). +networks. These include smooth nonlinearities (`sigmoid`, `tanh`, `elu`, +`softplus`, and `softsign`), continuous but not everywhere differentiable +functions (`relu`, `relu6`, and `relu_x`), and random regularization +(`dropout`). All activation ops apply componentwise, and produce a tensor of the same shape as the input tensor. @@ -52,6 +53,26 @@ Computes Rectified Linear 6: `min(max(features, 0), 6)`. A `Tensor` with the same type as `features`. +- - - + +### `tf.nn.elu(features, name=None)` {#elu} + +Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise. + +See [Fast and Aaccurate Deep Network Learning by Exponential Linear Units (ELUs) +](http://arxiv.org/abs/1511.07289) + +##### Args: + + +* `features`: A `Tensor`. Must be one of the following types: `float32`, `float64`. +* `name`: A name for the operation (optional). + +##### Returns: + + A `Tensor`. Has the same type as `features`. + + - - - ### `tf.nn.softplus(features, name=None)` {#softplus} diff --git a/tensorflow/g3doc/api_docs/python/state_ops.md b/tensorflow/g3doc/api_docs/python/state_ops.md index 8b2e8b379f7..82c2047e2c9 100644 --- a/tensorflow/g3doc/api_docs/python/state_ops.md +++ b/tensorflow/g3doc/api_docs/python/state_ops.md @@ -946,7 +946,7 @@ assert v1 == v Sharing a variable by capturing a scope and setting reuse: ```python -with tf.variable_scope("foo") as scope. +with tf.variable_scope("foo") as scope: v = tf.get_variable("v", [1]) scope.reuse_variables() v1 = tf.get_variable("v", [1]) @@ -957,7 +957,7 @@ To prevent accidental sharing of variables, we raise an exception when getting an existing variable in a non-reusing scope. ```python -with tf.variable_scope("foo") as scope. +with tf.variable_scope("foo"): v = tf.get_variable("v", [1]) v1 = tf.get_variable("v", [1]) # Raises ValueError("... v already exists ..."). diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md index b686968a8c1..6e7ecf8350c 100644 --- a/tensorflow/g3doc/api_docs/python/train.md +++ b/tensorflow/g3doc/api_docs/python/train.md @@ -1788,3 +1788,351 @@ tf.train.write_graph(sess.graph_def, '/tmp/my-model', 'train.pbtxt') * `as_text`: If `True`, writes the graph as an ASCII proto. + +## Other Functions and Classes +- - - + +### `class tf.train.FeatureList` {#FeatureList} + + +- - - + +#### `tf.train.FeatureList.ByteSize()` {#FeatureList.ByteSize} + + + + +- - - + +#### `tf.train.FeatureList.Clear()` {#FeatureList.Clear} + + + + +- - - + +#### `tf.train.FeatureList.ClearExtension(extension_handle)` {#FeatureList.ClearExtension} + + + + +- - - + +#### `tf.train.FeatureList.ClearField(field_name)` {#FeatureList.ClearField} + + + + +- - - + +#### `tf.train.FeatureList.CopyFrom(other_msg)` {#FeatureList.CopyFrom} + +Copies the content of the specified message into the current message. + +The method clears the current message and then merges the specified +message using MergeFrom. + +##### Args: + + +* `other_msg`: Message to copy into the current one. + + +- - - + +#### `tf.train.FeatureList.FindInitializationErrors()` {#FeatureList.FindInitializationErrors} + +Finds required fields which are not initialized. + +##### Returns: + + A list of strings. Each string is a path to an uninitialized field from + the top-level message, e.g. "foo.bar[5].baz". + + +- - - + +#### `tf.train.FeatureList.FromString(s)` {#FeatureList.FromString} + + + + +- - - + +#### `tf.train.FeatureList.HasExtension(extension_handle)` {#FeatureList.HasExtension} + + + + +- - - + +#### `tf.train.FeatureList.HasField(field_name)` {#FeatureList.HasField} + + + + +- - - + +#### `tf.train.FeatureList.IsInitialized(errors=None)` {#FeatureList.IsInitialized} + +Checks if all required fields of a message are set. + +##### Args: + + +* `errors`: A list which, if provided, will be populated with the field + paths of all missing required fields. + +##### Returns: + + True iff the specified message has all required fields set. + + +- - - + +#### `tf.train.FeatureList.ListFields()` {#FeatureList.ListFields} + + + + +- - - + +#### `tf.train.FeatureList.MergeFrom(msg)` {#FeatureList.MergeFrom} + + + + +- - - + +#### `tf.train.FeatureList.MergeFromString(serialized)` {#FeatureList.MergeFromString} + + + + +- - - + +#### `tf.train.FeatureList.ParseFromString(serialized)` {#FeatureList.ParseFromString} + +Parse serialized protocol buffer data into this message. + +Like MergeFromString(), except we clear the object first and +do not return the value that MergeFromString returns. + + +- - - + +#### `tf.train.FeatureList.RegisterExtension(extension_handle)` {#FeatureList.RegisterExtension} + + + + +- - - + +#### `tf.train.FeatureList.SerializePartialToString()` {#FeatureList.SerializePartialToString} + + + + +- - - + +#### `tf.train.FeatureList.SerializeToString()` {#FeatureList.SerializeToString} + + + + +- - - + +#### `tf.train.FeatureList.SetInParent()` {#FeatureList.SetInParent} + +Sets the _cached_byte_size_dirty bit to true, +and propagates this to our listener iff this was a state change. + + +- - - + +#### `tf.train.FeatureList.WhichOneof(oneof_name)` {#FeatureList.WhichOneof} + +Returns the name of the currently set field inside a oneof, or None. + + +- - - + +#### `tf.train.FeatureList.feature` {#FeatureList.feature} + +Magic attribute generated for "feature" proto field. + + + +- - - + +### `class tf.train.FeatureLists` {#FeatureLists} + + +- - - + +#### `tf.train.FeatureLists.ByteSize()` {#FeatureLists.ByteSize} + + + + +- - - + +#### `tf.train.FeatureLists.Clear()` {#FeatureLists.Clear} + + + + +- - - + +#### `tf.train.FeatureLists.ClearExtension(extension_handle)` {#FeatureLists.ClearExtension} + + + + +- - - + +#### `tf.train.FeatureLists.ClearField(field_name)` {#FeatureLists.ClearField} + + + + +- - - + +#### `tf.train.FeatureLists.CopyFrom(other_msg)` {#FeatureLists.CopyFrom} + +Copies the content of the specified message into the current message. + +The method clears the current message and then merges the specified +message using MergeFrom. + +##### Args: + + +* `other_msg`: Message to copy into the current one. + + +- - - + +#### `tf.train.FeatureLists.FindInitializationErrors()` {#FeatureLists.FindInitializationErrors} + +Finds required fields which are not initialized. + +##### Returns: + + A list of strings. Each string is a path to an uninitialized field from + the top-level message, e.g. "foo.bar[5].baz". + + +- - - + +#### `tf.train.FeatureLists.FromString(s)` {#FeatureLists.FromString} + + + + +- - - + +#### `tf.train.FeatureLists.HasExtension(extension_handle)` {#FeatureLists.HasExtension} + + + + +- - - + +#### `tf.train.FeatureLists.HasField(field_name)` {#FeatureLists.HasField} + + + + +- - - + +#### `tf.train.FeatureLists.IsInitialized(errors=None)` {#FeatureLists.IsInitialized} + +Checks if all required fields of a message are set. + +##### Args: + + +* `errors`: A list which, if provided, will be populated with the field + paths of all missing required fields. + +##### Returns: + + True iff the specified message has all required fields set. + + +- - - + +#### `tf.train.FeatureLists.ListFields()` {#FeatureLists.ListFields} + + + + +- - - + +#### `tf.train.FeatureLists.MergeFrom(msg)` {#FeatureLists.MergeFrom} + + + + +- - - + +#### `tf.train.FeatureLists.MergeFromString(serialized)` {#FeatureLists.MergeFromString} + + + + +- - - + +#### `tf.train.FeatureLists.ParseFromString(serialized)` {#FeatureLists.ParseFromString} + +Parse serialized protocol buffer data into this message. + +Like MergeFromString(), except we clear the object first and +do not return the value that MergeFromString returns. + + +- - - + +#### `tf.train.FeatureLists.RegisterExtension(extension_handle)` {#FeatureLists.RegisterExtension} + + + + +- - - + +#### `tf.train.FeatureLists.SerializePartialToString()` {#FeatureLists.SerializePartialToString} + + + + +- - - + +#### `tf.train.FeatureLists.SerializeToString()` {#FeatureLists.SerializeToString} + + + + +- - - + +#### `tf.train.FeatureLists.SetInParent()` {#FeatureLists.SetInParent} + +Sets the _cached_byte_size_dirty bit to true, +and propagates this to our listener iff this was a state change. + + +- - - + +#### `tf.train.FeatureLists.WhichOneof(oneof_name)` {#FeatureLists.WhichOneof} + +Returns the name of the currently set field inside a oneof, or None. + + +- - - + +#### `tf.train.FeatureLists.feature_list` {#FeatureLists.feature_list} + +Magic attribute generated for "feature_list" proto field. + + + diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md index 6f79d22d3f9..feda1572327 100644 --- a/tensorflow/g3doc/get_started/os_setup.md +++ b/tensorflow/g3doc/get_started/os_setup.md @@ -39,7 +39,7 @@ Python. The packages that will be installed or upgraded during the pip install are listed in the [REQUIRED_PACKAGES section of setup.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/pip_package/setup.py) -Install pip if not already installed: +Install pip if it is not already installed: ```bash # Ubuntu/Linux 64-bit @@ -118,8 +118,9 @@ $ source ~/tensorflow/bin/activate.csh # If using csh With the Virtualenv environment activated, you can now [test your installation](#test_install). +When you are done using TensorFlow, deactivate the environment. + ```bash -# When you are done using TensorFlow, deactivate the environment. (tensorflow)$ deactivate $ # Your prompt should change back @@ -152,16 +153,15 @@ We provide 2 Docker images: With Docker the installation is as follows: * Install Docker on your machine. +* Create a [Docker +group](http://docs.docker.com/engine/installation/ubuntulinux/#create-a-docker-group) +to allow launching containers without `sudo`. * Launch a Docker container with the TensorFlow image. The image gets downloaded automatically on first launch. See [installing Docker](http://docs.docker.com/engine/installation/) for instructions on installing Docker on your machine. -Also create a [Docker -group](http://docs.docker.com/engine/installation/ubuntulinux/#create-a-docker-group) -to allow launching containers without `sudo`. - After Docker is installed, launch a Docker container with the TensorFlow binary image as follows. @@ -169,7 +169,7 @@ image as follows. $ docker run -it b.gcr.io/tensorflow/tensorflow ``` -Within the Docker container, you can now [test your installation](#test_install). +You can now [test your installation](#test_install) within the Docker container. You can alternatively launch the TensorFlow source image, for example if you want to experiment directly with the source. @@ -196,7 +196,7 @@ export CUDA_HOME=/usr/local/cuda ### Run TensorFlow from the Command Line -See [common problems](#common_install_problems) if some error happens. +See [common problems](#common_install_problems) if an error happens. Open a terminal and type the following: @@ -275,10 +275,10 @@ $ chmod +x PATH_TO_INSTALL.SH $ ./PATH_TO_INSTALL.SH --user ``` -Remember to replace `PATH_TO_INSTALL.SH` to point to the location where you +Remember to replace `PATH_TO_INSTALL.SH` with the location where you downloaded the installer. -Finally, follow the instructions in that script to place bazel into your binary +Finally, follow the instructions in that script to place `bazel` into your binary path. #### Install other dependencies @@ -287,12 +287,26 @@ path. $ sudo apt-get install python-numpy swig python-dev ``` +#### Configure the installation {#configure} + +Run the `configure` script at the root of the tree. The configure script +asks you for the path to your python interpreter and allows (optional) +configuration of the CUDA libraries (see [below](#configure_cuda)). + +This step is used to locate the python and numpy header files. + +```bash +$ ./configure +Please specify the location of python. [Default is /usr/bin/python]: +``` + #### Optional: Install CUDA (GPUs on Linux) {#install_cuda} In order to build or run TensorFlow with GPU support, both Cuda Toolkit 7.0 and CUDNN 6.5 V2 from NVIDIA need to be installed. -TensorFlow GPU support requires having a GPU card with NVidia Compute Capability >= 3.5. Supported cards include but are not limited to: +TensorFlow GPU support requires having a GPU card with NVidia Compute Capability >= 3.5. +Supported cards include but are not limited to: * NVidia Titan * NVidia Titan X @@ -318,12 +332,14 @@ sudo cp cudnn-6.5-linux-x64-v2/cudnn.h /usr/local/cuda/include sudo cp cudnn-6.5-linux-x64-v2/libcudnn* /usr/local/cuda/lib64 ``` -##### Configure TensorFlow's canonical view of Cuda libraries -From the root of your source tree, run: +##### Configure TensorFlow's canonical view of Cuda libraries {#configure_cuda} +When running the `configure` script from the root of your source tree, select +the option `Y` when asked to build TensorFlow with GPU support. ``` bash $ ./configure -Do you wish to build TensorFlow with GPU support? [y/n] y +Please specify the location of python. [Default is /usr/bin/python]: +Do you wish to build TensorFlow with GPU support? [y/N] y GPU support will be enabled for TensorFlow Please specify the location where CUDA 7.0 toolkit is installed. Refer to @@ -400,9 +416,9 @@ given necessary bazel new feature support. ### Installation for Mac OS X -Mac needs the same set of dependencies as Linux, however installing those -dependencies is different. Here is a set of useful links to help with installing -the dependencies on Mac OS X : +Mac needs the same set of dependencies as Linux, but the installation +process for those dependencies is different. Here is a set of useful links +to help with installing the dependencies on Mac OS X : #### Bazel @@ -420,6 +436,18 @@ Notes : You need to install Follow installation instructions [here](http://docs.scipy.org/doc/numpy/user/install.html). +#### Configure the installation {#configure_osx} + +Run the `configure` script at the root of the tree. The configure script +asks you for the path to your python interpreter. + +This step is used to locate the python and numpy header files. + +```bash +$ ./configure +Please specify the location of python. [Default is /usr/bin/python]: +Do you wish to build TensorFlow with GPU support? [y/N] +``` ### Create the pip package and install {#create-pip} @@ -505,7 +533,7 @@ SSLError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed Solution: Download the wheel manually via curl or wget, and pip install locally. -### On Linux +### Linux issues If you encounter: @@ -533,7 +561,7 @@ Solution: TensorFlow depends on protobuf, which requires the Python package You can resolve the issue in one of the following ways: -* Upgrade the Python installation with the current version `six`: +* Upgrade the Python installation with the current version of `six`: ```bash $ sudo easy_install -U six diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 618d0e2ad3f..0c75bd8dc72 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -14,16 +14,6 @@ load("/tensorflow/tensorflow", "cuda_py_tests") load("/tensorflow/tensorflow", "tf_py_wrap_cc") load("/tensorflow/core/platform/default/build_config", "tf_proto_library_py") -config_setting( - name = "macosx", - values = {"cpu": "darwin"}, -) - -numpy_macosx_include_dir = select({ - ":macosx": ["-I/System/Library/Frameworks/Python.framework/Versions/2.7/Extras/lib/python/numpy/core/include"], - "//conditions:default": [], -}) - py_library( name = "python", srcs = ["__init__.py"], @@ -467,6 +457,7 @@ tf_gen_op_wrapper_py( "MaxPoolGradWithArgmax", "ReluGrad", "Relu6Grad", + "EluGrad", "SoftplusGrad", "SoftsignGrad", "BiasAdd", @@ -712,7 +703,6 @@ tf_cuda_library( name = "tf_session_helper", srcs = ["client/tf_session_helper.cc"], hdrs = ["client/tf_session_helper.h"], - copts = numpy_macosx_include_dir + ["-I/usr/include/python2.7"], deps = [ ":construction_fails_op", ":test_kernel_label_op_kernel", @@ -721,13 +711,14 @@ tf_cuda_library( "//tensorflow/core:kernels", "//tensorflow/core:lib", "//tensorflow/core:protos_cc", + "//third_party/py/numpy:headers", + "//util/python:python_headers", ], ) tf_py_wrap_cc( name = "client/pywraptensorflow_server_lib", srcs = ["client/tensorflow_server.i"], - copts = numpy_macosx_include_dir, swig_includes = [ "lib/core/status.i", "lib/core/strings.i", @@ -737,13 +728,13 @@ tf_py_wrap_cc( "//tensorflow/core", "//tensorflow/core:lib", "//tensorflow/core:protos_cc", + "//util/python:python_headers", ], ) tf_py_wrap_cc( name = "pywrap_tensorflow", srcs = ["tensorflow.i"], - copts = numpy_macosx_include_dir, swig_includes = [ "client/events_writer.i", "client/tf_session.i", @@ -760,6 +751,7 @@ tf_py_wrap_cc( ":py_record_reader_lib", ":py_record_writer_lib", ":tf_session_helper", + "//util/python:python_headers", ], ) diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py index 9bd73f710a5..627435be6d1 100644 --- a/tensorflow/python/kernel_tests/concat_op_test.py +++ b/tensorflow/python/kernel_tests/concat_op_test.py @@ -276,6 +276,10 @@ class ConcatOpTest(tf.test.TestCase): concat = tf.concat(dim, [p1, c1, p2, c2]) self.assertEqual(4, concat.get_shape().ndims) + # All dimensions unknown. + concat2 = tf.concat(dim, [p1, p2]) + self.assertEqual(None, concat2.get_shape()) + # Rank doesn't match. c3 = tf.constant(30.0, shape=[4, 4, 4]) with self.assertRaises(ValueError): diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py index c6223be886a..331c62edf26 100644 --- a/tensorflow/python/kernel_tests/parsing_ops_test.py +++ b/tensorflow/python/kernel_tests/parsing_ops_test.py @@ -33,6 +33,10 @@ features = lambda d: tf.train.Features(feature=d) bytes_feature = lambda v: feature(bytes_list=tf.train.BytesList(value=v)) int64_feature = lambda v: feature(int64_list=tf.train.Int64List(value=v)) float_feature = lambda v: feature(float_list=tf.train.FloatList(value=v)) +# Helpers for creating SequenceExample objects +feature_list = lambda l: tf.train.FeatureList(feature=l) +feature_lists = lambda d: tf.train.FeatureLists(feature_list=d) +sequence_example = tf.train.SequenceExample def flatten(list_of_lists): @@ -475,5 +479,24 @@ class ParseSingleExampleTest(tf.test.TestCase): }, expected_output) +class ParseSequenceExampleTest(tf.test.TestCase): + + def testCreateSequenceExample(self): + value = sequence_example( + context=features({ + "global_feature": float_feature([1, 2, 3]), + }), + feature_lists=feature_lists({ + "repeated_feature_2_frames": feature_list([ + bytes_feature(["a", "b", "c"]), + bytes_feature(["a", "d", "e"])]), + "repeated_feature_3_frames": feature_list([ + int64_feature([3, 4, 5, 6, 7]), + int64_feature([-1, 0, 0, 0, 0]), + int64_feature([1, 2, 3, 4, 5])]) + })) + value.SerializeToString() # Smoke test + + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py index 38ab52b0c16..201cf28f452 100644 --- a/tensorflow/python/kernel_tests/relu_op_test.py +++ b/tensorflow/python/kernel_tests/relu_op_test.py @@ -45,18 +45,18 @@ class ReluTest(tf.test.TestCase): self.assertShapeEqual(np_relu, relu) def testNumbers(self): - for t in [np.int32, np.int64, np.float, np.double]: + for t in [np.int32, np.int64, np.float32, np.float64]: self._testRelu( np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), use_gpu=False) - if t in [np.float, np.double]: + if t in [np.float32, np.float64]: self._testRelu( np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), use_gpu=True) # The gradient test for ReLU is a bit tricky as the derivative is not well # defined at around zero and we want to avoid that in terms of input values. - def testGradientFloat(self): + def testGradientFloat32(self): with self.test_session(): x = tf.constant( [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], @@ -70,7 +70,7 @@ class ReluTest(tf.test.TestCase): y, [2, 5], x_init_value=x_init) - print("relu (float) gradient err = ", err) + print("relu (float32) gradient err = ", err) self.assertLess(err, 1e-4) def testGradientNaN(self): @@ -91,7 +91,7 @@ class ReluTest(tf.test.TestCase): except Exception as e: # pylint: disable=broad-except assert "ReluGrad input is not finite." in str(e) - def testGradientDouble(self): + def testGradientFloat64(self): with self.test_session(): x = tf.constant( [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], @@ -105,10 +105,10 @@ class ReluTest(tf.test.TestCase): y, [2, 5], x_init_value=x_init) - print("relu (double) gradient err = ", err) + print("relu (float64) gradient err = ", err) self.assertLess(err, 1e-10) - def testGradGradFloat(self): + def testGradGradFloat32(self): with self.test_session(): x = tf.constant( [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], @@ -123,10 +123,10 @@ class ReluTest(tf.test.TestCase): z[0], [2, 5], x_init_value=x_init) - print("relu (float) gradient of gradient err = ", err) + print("relu (float32) gradient of gradient err = ", err) self.assertLess(err, 1e-4) - def testGradGradDouble(self): + def testGradGradFloat64(self): with self.test_session(): x = tf.constant( [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], @@ -141,7 +141,7 @@ class ReluTest(tf.test.TestCase): z[0], [2, 5], x_init_value=x_init) - print("relu (double) gradient of gradient err = ", err) + print("relu (float64) gradient of gradient err = ", err) self.assertLess(err, 1e-10) @@ -169,7 +169,7 @@ class Relu6Test(tf.test.TestCase): self.assertShapeEqual(np_relu6, relu6) def testNumbers(self): - for t in [np.int32, np.int64, np.float, np.double]: + for t in [np.int32, np.int64, np.float32, np.float64]: self._testRelu6( np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), use_gpu=False) @@ -181,7 +181,7 @@ class Relu6Test(tf.test.TestCase): # The gradient test for ReLU6 is a bit tricky as the derivative is # not well defined at around zero and six and we want to avoid that # in terms of input values. - def testGradientFloat(self): + def testGradientFloat32(self): with self.test_session(): x = tf.constant( [-0.9, -0.7, -0.5, -0.3, -0.1, 6.1, 6.3, 6.5, 6.7, 6.9], @@ -195,10 +195,10 @@ class Relu6Test(tf.test.TestCase): y, [2, 5], x_init_value=x_init) - print("relu6 (float) gradient err = ", err) + print("relu6 (float32) gradient err = ", err) self.assertLess(err, 1e-4) - def testGradientDouble(self): + def testGradientFloat64(self): with self.test_session(): x = tf.constant( [-0.9, -0.7, -0.5, -0.3, -0.1, 6.1, 6.3, 6.5, 6.7, 6.9], @@ -212,9 +212,67 @@ class Relu6Test(tf.test.TestCase): y, [2, 5], x_init_value=x_init) - print("relu6 (double) gradient err = ", err) + print("relu6 (float64) gradient err = ", err) self.assertLess(err, 1e-10) +class EluTest(tf.test.TestCase): + + def _npElu(self, np_features): + return np.where(np_features < 0, np.exp(np_features) - 1, np_features) + + def testNpElu(self): + self.assertAllClose( + np.array([[-0.59343034025, 0.7, -0.39346934028, 0.3, -0.09516258196], + [0.1, -0.25918177931, 0.5, -0.5034146962, 0.9]]), + self._npElu(np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, - + 0.7, 0.9]]))) + + def _testElu(self, np_features, use_gpu=False): + np_elu = self._npElu(np_features) + with self.test_session(use_gpu=use_gpu): + elu = tf.nn.elu(np_features) + tf_elu = elu.eval() + self.assertAllClose(np_elu, tf_elu) + self.assertShapeEqual(np_elu, elu) + + def testNumbers(self): + for t in [np.float32, np.float64]: + self._testElu( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), + use_gpu=False) + self._testElu( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), + use_gpu=True) + + def testGradientFloat32(self): + with self.test_session(): + x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]] + x = tf.constant(x_val, name="x") + y = tf.nn.elu(x, name="elu") + x_init = np.asarray(x_val, dtype=np.float32, order="F") + err = tf.test.compute_gradient_error(x, + [2, 5], + y, + [2, 5], + x_init_value=x_init) + print("elu (float32) gradient err = ", err) + self.assertLess(err, 1e-4) + + def testGradientFloat64(self): + with self.test_session(): + x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]] + x = tf.constant(x_val, dtype=tf.float64, name="x") + y = tf.nn.elu(x, name="elu") + x_init = np.asarray(x_val, dtype=np.float64, order="F") + err = tf.test.compute_gradient_error(x, + [2, 5], + y, + [2, 5], + x_init_value=x_init) + print("elu (float64) gradient err = ", err) + self.assertLess(err, 1e-6) + + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index e913fc1f501..ca8f3808ab4 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -106,7 +106,7 @@ class VariableStoreTest(tf.test.TestCase): with variable_scope.variable_scope(tower, reuse=True) as tower_shared: self.assertEqual(tower_shared.name, "tower") with tf.name_scope("scope") as sc: - self.assertEqual(sc, "foo_1/scope/") + self.assertEqual(sc, "foo_1/tower/scope/") def testVarScopeNameScope(self): with self.test_session(): @@ -124,7 +124,65 @@ class VariableStoreTest(tf.test.TestCase): self.assertEqual(sc2, "scope3/tower/scope2/") with variable_scope.variable_scope(tower): with tf.name_scope("scope2") as sc2: - self.assertEqual(sc2, "scope3/scope2/") + self.assertEqual(sc2, "scope3/tower_1/scope2/") + + root_var_scope = variable_scope.get_variable_scope() + with tf.name_scope("scope4"): + with variable_scope.variable_scope(root_var_scope): + with tf.name_scope("scope2") as sc2: + self.assertEqual(sc2, "scope4/scope2/") + + def testVarOpScope(self): + with self.test_session(): + with tf.name_scope("scope1"): + with variable_scope.variable_op_scope([], "tower", "default"): + self.assertEqual(variable_scope.get_variable("w", []).name, + "tower/w:0") + with tf.name_scope("scope2") as sc2: + self.assertEqual(sc2, "scope1/tower/scope2/") + with variable_scope.variable_op_scope([], "tower", "default"): + with self.assertRaises(ValueError): + variable_scope.get_variable("w", []) + with tf.name_scope("scope2") as sc2: + self.assertEqual(sc2, "scope1/tower_1/scope2/") + + with tf.name_scope("scope2"): + with variable_scope.variable_op_scope([], None, "default"): + self.assertEqual(variable_scope.get_variable("w", []).name, + "default/w:0") + with tf.name_scope("scope2") as sc2: + self.assertEqual(sc2, "scope2/default/scope2/") + with variable_scope.variable_op_scope([], None, "default"): + self.assertEqual(variable_scope.get_variable("w", []).name, + "default_1/w:0") + with tf.name_scope("scope2") as sc2: + self.assertEqual(sc2, "scope2/default_1/scope2/") + + def testVarOpScopeReuse(self): + with self.test_session(): + with tf.variable_scope("outer") as outer: + with variable_scope.variable_op_scope([], "tower", "default"): + self.assertEqual(variable_scope.get_variable("w", []).name, + "outer/tower/w:0") + with tf.name_scope("scope2") as sc2: + self.assertEqual(sc2, "outer/tower/scope2/") + with variable_scope.variable_op_scope([], None, "default"): + self.assertEqual(variable_scope.get_variable("w", []).name, + "outer/default/w:0") + with tf.name_scope("scope2") as sc2: + self.assertEqual(sc2, "outer/default/scope2/") + + with tf.variable_scope(outer, reuse=True) as outer: + with variable_scope.variable_op_scope([], "tower", "default"): + self.assertEqual(variable_scope.get_variable("w", []).name, + "outer/tower/w:0") + with tf.name_scope("scope2") as sc2: + self.assertEqual(sc2, "outer_1/tower/scope2/") + with variable_scope.variable_op_scope([], None, "default"): + self.assertEqual(variable_scope.get_variable("w", []).name, + "outer/default/w:0") + with tf.name_scope("scope2") as sc2: + self.assertEqual(sc2, "outer_1/default/scope2/") def testVarScopeGetVar(self): with self.test_session(): diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py index 39ec5f10a63..b9380655a0c 100644 --- a/tensorflow/python/kernel_tests/xent_op_test.py +++ b/tensorflow/python/kernel_tests/xent_op_test.py @@ -50,6 +50,29 @@ class XentTest(tf.test.TestCase): self._testXent(features, labels, use_gpu=False) self._testXent(features, labels, use_gpu=True) + def _testSingleClass(self, use_gpu=False): + with self.test_session(use_gpu=use_gpu) as sess: + loss = tf.nn.softmax_cross_entropy_with_logits( + np.array([[1.], [-1.], [0.]]).astype(np.float32), + np.array([[-1.], [0.], [1.]]).astype(np.float32)) + backprop = loss.op.outputs[1] + tf_loss, tf_backprop = sess.run([loss, backprop]) + self.assertAllClose([0.0, 0.0, 0.0], tf_loss) + self.assertAllClose([[2.0], [1.0], [0.0]], tf_backprop) + + def testSingleClass(self): + self._testSingleClass(True) + self._testSingleClass(False) + + def testRankTooLarge(self): + np_features = np.array( + [[[1., 1., 1., 1.]], [[1., 2., 3., 4.]]]).astype(np.float32) + np_labels = np.array( + [[[0., 0., 0., 1.]], [[0., .5, .5, 0.]]]).astype(np.float32) + self.assertRaisesRegexp( + ValueError, "must have the same rank", + tf.nn.softmax_cross_entropy_with_logits, np_features, np_labels) + def testNpXent(self): # We create 2 batches of logits for testing. # batch 0 is the boring uniform distribution: 1, 1, 1, 1, with target 3. diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 1e2950e74ff..5930b3486da 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -335,7 +335,10 @@ def _ConcatShape(op): value.get_shape().assert_has_rank(rank) else: rank = value.get_shape().ndims - return [tensor_shape.unknown_shape(ndims=max(rank, 1))] + # TODO(irving): Remove once !kAllowLegacyScalars. + if rank is not None: + rank = max(rank, 1) + return [tensor_shape.unknown_shape(ndims=rank)] else: # Merge all the non-concat dims, and sum the concat dim to make an diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index 2392042d504..b208b0cea28 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -43,7 +43,7 @@ The convenience function [`resize_images()`](#resize_images) supports both 4-D and 3-D tensors as input and output. 4-D tensors are for batches of images, 3-D tensors for individual images. -Other resizing Ops only support 3-D individual images as input: +Other resizing Ops only support 4-D batches of images as input: [`resize_area`](#resize_area), [`resize_bicubic`](#resize_bicubic), [`resize_bilinear`](#resize_bilinear), [`resize_nearest_neighbor`](#resize_nearest_neighbor). @@ -51,9 +51,9 @@ Other resizing Ops only support 3-D individual images as input: Example: ```python -# Decode a JPG image and resize it to 299 by 299. +# Decode a JPG image and resize it to 299 by 299 using default method. image = tf.image.decode_jpeg(...) -resized_image = tf.image.resize_bilinear(image, [299, 299]) +resized_image = tf.image.resize_images(image, 299, 299) ``` @@resize_images @@ -528,7 +528,7 @@ def resize_images(images, new_height, new_width, method=ResizeMethod.BILINEAR): raise ValueError('Resize method is not implemented.') if not is_batch: - images = array_ops.reshape(images, [new_height, new_width, depth]) + images = array_ops.squeeze(images, squeeze_dims=[0]) return images diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py index 3bd0e875631..065552c7ede 100644 --- a/tensorflow/python/ops/nn.py +++ b/tensorflow/python/ops/nn.py @@ -17,15 +17,17 @@ """## Activation Functions The activation ops provide different types of nonlinearities for use in neural -networks. These include smooth nonlinearities (`sigmoid`, `tanh`, `softplus`, -and `softsign`), continuous but not everywhere differentiable functions (`relu`, -`relu6`, and `relu_x`), and random regularization (`dropout`). +networks. These include smooth nonlinearities (`sigmoid`, `tanh`, `elu`, +`softplus`, and `softsign`), continuous but not everywhere differentiable +functions (`relu`, `relu6`, and `relu_x`), and random regularization +(`dropout`). All activation ops apply componentwise, and produce a tensor of the same shape as the input tensor. @@relu @@relu6 +@@elu @@softplus @@softsign @@dropout diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index 48f57b65279..b4b6b3b0c18 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -132,6 +132,11 @@ def _Relu6Grad(op, grad): return gen_nn_ops._relu6_grad(grad, op.inputs[0]) +@ops.RegisterGradient("Elu") +def _EluGrad(op, grad): + return gen_nn_ops._elu_grad(grad, op.outputs[0]) + + @ops.RegisterGradient("Softplus") def _SoftplusGrad(op, grad): return gen_nn_ops._softplus_grad(grad, op.inputs[0]) diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 604739a6b6e..39497f2bf86 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -237,12 +237,14 @@ def max_pool(value, ksize, strides, padding, name=None): ops.RegisterShape("Relu")(common_shapes.unchanged_shape) ops.RegisterShape("Relu6")(common_shapes.unchanged_shape) +ops.RegisterShape("Elu")(common_shapes.unchanged_shape) ops.RegisterShape("Softplus")(common_shapes.unchanged_shape) ops.RegisterShape("Softsign")(common_shapes.unchanged_shape) @ops.RegisterShape("ReluGrad") @ops.RegisterShape("Relu6Grad") +@ops.RegisterShape("EluGrad") @ops.RegisterShape("SoftplusGrad") @ops.RegisterShape("SoftsignGrad") def _BinaryElementwiseShape(op): diff --git a/tensorflow/python/ops/seq2seq.py b/tensorflow/python/ops/seq2seq.py index 131524b77c5..1ad16828aa5 100644 --- a/tensorflow/python/ops/seq2seq.py +++ b/tensorflow/python/ops/seq2seq.py @@ -147,7 +147,7 @@ def embedding_rnn_decoder(decoder_inputs, initial_state, cell, num_symbols, """RNN decoder with embedding and a pure-decoding option. Args: - decoder_inputs: a list of 1D batch-sized int32-Tensors (decoder inputs). + decoder_inputs: a list of 1D batch-sized int32 Tensors (decoder inputs). initial_state: 2D Tensor [batch_size x cell.state_size]. cell: rnn_cell.RNNCell defining the cell function. num_symbols: integer, how many symbols come into the embedding. @@ -219,8 +219,8 @@ def embedding_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, encoder state, on embedded decoder_inputs. Args: - encoder_inputs: a list of 1D int32-Tensors of shape [batch_size]. - decoder_inputs: a list of 1D int32-Tensors of shape [batch_size]. + encoder_inputs: a list of 1D int32 Tensors of shape [batch_size]. + decoder_inputs: a list of 1D int32 Tensors of shape [batch_size]. cell: rnn_cell.RNNCell defining the cell function and size. num_encoder_symbols: integer; number of symbols on the encoder side. num_decoder_symbols: integer; number of symbols on the decoder side. @@ -286,8 +286,8 @@ def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, encoder state, on embedded decoder_inputs. Args: - encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. - decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + encoder_inputs: a list of 1D int32 Tensors of shape [batch_size]. + decoder_inputs: a list of 1D int32 Tensors of shape [batch_size]. cell: rnn_cell.RNNCell defining the cell function and size. num_symbols: integer; number of symbols for both encoder and decoder. output_projection: None or a pair (W, B) of output projection weights and @@ -486,7 +486,7 @@ def embedding_attention_decoder(decoder_inputs, initial_state, attention_states, """RNN decoder with embedding and attention and a pure-decoding option. Args: - decoder_inputs: a list of 1D batch-sized int32-Tensors (decoder inputs). + decoder_inputs: a list of 1D batch-sized int32 Tensors (decoder inputs). initial_state: 2D Tensor [batch_size x cell.state_size]. attention_states: 3D Tensor [batch_size x attn_length x attn_size]. cell: rnn_cell.RNNCell defining the cell function. @@ -566,8 +566,8 @@ def embedding_attention_seq2seq(encoder_inputs, decoder_inputs, cell, encoder state, on embedded decoder_inputs and attending to encoder outputs. Args: - encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. - decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + encoder_inputs: a list of 1D int32 Tensors of shape [batch_size]. + decoder_inputs: a list of 1D int32 Tensors of shape [batch_size]. cell: rnn_cell.RNNCell defining the cell function and size. num_encoder_symbols: integer; number of symbols on the encoder side. num_decoder_symbols: integer; number of symbols on the decoder side. @@ -636,7 +636,7 @@ def sequence_loss_by_example(logits, targets, weights, num_decoder_symbols, Args: logits: list of 2D Tensors of shape [batch_size x num_decoder_symbols]. - targets: list of 1D batch-sized int32-Tensors of the same length as logits. + targets: list of 1D batch-sized int32 Tensors of the same length as logits. weights: list of 1D batch-sized float-Tensors of the same length as logits. num_decoder_symbols: integer, number of decoder symbols (output classes). average_across_timesteps: If set, divide the returned cost by the total @@ -692,7 +692,7 @@ def sequence_loss(logits, targets, weights, num_decoder_symbols, Args: logits: list of 2D Tensors os shape [batch_size x num_decoder_symbols]. - targets: list of 1D batch-sized int32-Tensors of the same length as logits. + targets: list of 1D batch-sized int32 Tensors of the same length as logits. weights: list of 1D batch-sized float-Tensors of the same length as logits. num_decoder_symbols: integer, number of decoder symbols (output classes). average_across_timesteps: If set, divide the returned cost by the total @@ -731,7 +731,7 @@ def model_with_buckets(encoder_inputs, decoder_inputs, targets, weights, Args: encoder_inputs: a list of Tensors to feed the encoder; first seq2seq input. decoder_inputs: a list of Tensors to feed the decoder; second seq2seq input. - targets: a list of 1D batch-sized int32-Tensors (desired output sequence). + targets: a list of 1D batch-sized int32 Tensors (desired output sequence). weights: list of 1D batch-sized float-Tensors to weight the targets. buckets: a list of pairs of (input size, output size) for each bucket. num_decoder_symbols: integer, number of decoder symbols (output classes). diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index f8176846b2a..25c679b80bb 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -45,6 +45,7 @@ create variables contingent on certain conditions. @@get_variable @@get_variable_scope +@@variable_op_scope @@variable_scope @@constant_initializer diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 2146405d3e2..5b2778717d4 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -131,12 +131,14 @@ class _VariableScope(object): name: name of the current scope, used as prefix in get_variable. initializer: default initializer passed to get_variable. reuse: Boolean or None, setting the reuse in get_variable. + name_scope: The name passed to tf.name_scope. """ - def __init__(self, reuse, name="", initializer=None): + def __init__(self, reuse, name="", initializer=None, name_scope=""): self._name = name self._initializer = initializer self._reuse = reuse + self._name_scope = name_scope @property def name(self): @@ -238,6 +240,60 @@ def get_variable(name, shape=None, dtype=dtypes.float32, initializer=None, trainable, collections) +@contextlib.contextmanager +def _pure_variable_scope(name_or_scope, reuse=None, initializer=None): + """Creates a context for the variable_scope, see `variable_scope` for docs. + + Note: this does not create a name scope. + + Args: + name_or_scope: `string` or `VariableScope`: the scope to open. + reuse: `True` or `None`; if `True`, we go into reuse mode for this scope as + well as all sub-scopes; if `None`, we just inherit the parent scope reuse. + initializer: default initializer for variables within this scope. + + Yields: + A scope that can be to captured and reused. + + Raises: + ValueError: when trying to reuse within a create scope, or create within + a reuse scope, or if reuse is not `None` or `True`. + TypeError: when the types of some arguments are not appropriate. + + """ + get_variable_scope() # Ensure that a default exists, then get a pointer. + default_varscope = ops.get_collection(_VARSCOPE_KEY) + try: + old = default_varscope[0] + reuse = reuse or old.reuse # Re-using is inherited by sub-scopes. + if isinstance(name_or_scope, _VariableScope): + name_scope = name_or_scope._name_scope # pylint: disable=protected-access + # Handler for the case when we jump to a shared scope. + # We create a new VariableScope (default_varscope[0]) that contains + # a copy of the provided shared scope, possibly with changed reuse + # and initializer, if the user requested this. + default_varscope[0] = _VariableScope(reuse, name_or_scope.name, + name_or_scope.initializer, + name_scope) + if initializer: + default_varscope[0].set_initializer(initializer) + yield default_varscope[0] + else: + # Handler for the case when we just prolong current variable scope. + # VariableScope with name extended by the provided one, and inherited + # reuse and initializer (except if the user provided values to set). + new_name = old.name + "/" + name_or_scope if old.name else name_or_scope + default_varscope[0] = _VariableScope(reuse, name=new_name, + initializer=old.initializer, + name_scope=name_or_scope) + if initializer: + default_varscope[0].set_initializer(initializer) + yield default_varscope[0] + finally: + default_varscope[0] = old + + +# pylint: disable=g-doc-return-or-yield @contextlib.contextmanager def variable_scope(name_or_scope, reuse=None, initializer=None): """Returns a context for variable scope. @@ -304,7 +360,7 @@ def variable_scope(name_or_scope, reuse=None, initializer=None): well as all sub-scopes; if `None`, we just inherit the parent scope reuse. initializer: default initializer for variables within this scope. - Yields: + Returns: A scope that can be to captured and reused. Raises: @@ -315,39 +371,76 @@ def variable_scope(name_or_scope, reuse=None, initializer=None): if not isinstance(name_or_scope, (_VariableScope,) + six.string_types): raise TypeError("VariableScope: name_scope must be a string or " "VariableScope.") - if reuse not in [None, True]: - raise ValueError("VariableScope reuse parameter must be True or None.") - if not reuse and isinstance(name_or_scope, (_VariableScope)): - logging.info("Passing VariableScope to a non-reusing scope, intended?") - if reuse and isinstance(name_or_scope, six.string_types): - logging.info("Re-using string-named scope, consider capturing as object.") - get_variable_scope() # Ensure that a default exists, then get a pointer. - default_varscope = ops.get_collection(_VARSCOPE_KEY) - try: - old = default_varscope[0] - reuse = reuse or old.reuse # Re-using is inherited by sub-scopes. - if isinstance(name_or_scope, _VariableScope): - # Handler for the case when we jump to a shared scope. - # In this case, we leave the current name_scope unchanged. - # We create a new VariableScope (default_varscope[0]) that contains - # a copy of the provided shared scope, possibly with changed reuse - # and initializer, if the user requested this. - default_varscope[0] = _VariableScope(reuse, name_or_scope.name, - name_or_scope.initializer) - if initializer: - default_varscope[0].set_initializer(initializer) - yield default_varscope[0] + if isinstance(name_or_scope, six.string_types): + name = name_or_scope + else: + name = name_or_scope._name_scope # pylint: disable=protected-access + if name: + with ops.name_scope(name), _pure_variable_scope( + name_or_scope, reuse, initializer) as vs: + yield vs + else: + # This can only happen if someone is entering the root variable scope. + with _pure_variable_scope(name_or_scope, reuse, initializer) as vs: + yield vs + + +# pylint: disable=g-doc-return-or-yield +@contextlib.contextmanager +def variable_op_scope(values, name, default_name, initializer=None): + """Returns a context manager for defining an op that creates variables. + + This context manager validates that the given `values` are from the + same graph, ensures that that graph is the default graph, and pushes a + name scope and a variable scope. + + If `name` is not None, it is used as is in the variable scope. If `name` + is None, then `default_name` is used. In that case, if the same name has been + previously used in the same scope, it will made unique be appending `_N` to + it. + + This is intended to be used when defining generic ops and so reuse is always + inherited. + + For example, to define a new Python op called `my_op_with_vars`: + + ```python + def my_op_with_vars(a, b, name=None): + with tf.variable_op_scope([a, b], name, "MyOp") as scope: + a = tf.convert_to_tensor(a, name="a") + b = tf.convert_to_tensor(b, name="b") + c = tf.get_variable('c') + # Define some computation that uses `a`, `b`, and `c`. + return foo_op(..., name=scope) + ``` + + Args: + values: The list of `Tensor` arguments that are passed to the op function. + name: The name argument that is passed to the op function, this name is not + uniquified in the variable scope. + default_name: The default name to use if the `name` argument is `None`, this + name will be uniquified. + initializer: A default initializer to pass to variable scope. + + Returns: + A context manager for use in defining a Python op. + + Raises: + ValueError: when trying to reuse within a create scope, or create within + a reuse scope, or if reuse is not `None` or `True`. + TypeError: when the types of some arguments are not appropriate. + """ + if default_name is None: + raise TypeError("default_name cannot be None") + g = ops._get_graph_from_inputs(values) # pylint: disable=protected-access + with g.as_default(): + if name: + with variable_scope(name, initializer=initializer) as vs: + yield vs else: - # Handler for the case when we just prolong current variable scope. - # In this case we prolong the current name_scope and create a new - # VariableScope with name extended by the provided one, and inherited - # reuse and initializer (except if the user provided values to set). - with ops.name_scope(name_or_scope): - new_name = old.name + "/" + name_or_scope if old.name else name_or_scope - default_varscope[0] = _VariableScope(reuse, name=new_name, - initializer=old.initializer) - if initializer: - default_varscope[0].set_initializer(initializer) - yield default_varscope[0] - finally: - default_varscope[0] = old + with ops.name_scope(default_name) as scope: + count = len(default_name.split("/")) + scoped_name = "/".join(scope.split("/")[-count - 1:-1]) + with _pure_variable_scope(scoped_name, + initializer=initializer) as vs: + yield vs diff --git a/tensorflow/tensorboard/.gitignore b/tensorflow/tensorboard/.gitignore index 9f4cfe9b129..33117e9d630 100644 --- a/tensorflow/tensorboard/.gitignore +++ b/tensorflow/tensorboard/.gitignore @@ -13,3 +13,4 @@ components/tf-graph-common/lib/scene/*.js components/tf-event-dashboard/*.js components/tf-categorizer/*.js components/tf-dashboard-common/*.js +components/**/test/*.js diff --git a/tensorflow/tensorboard/components/imports/local-imports/dagre.html b/tensorflow/tensorboard/components/imports/local-imports/dagre.html index b685aea6c93..29586e769de 100644 --- a/tensorflow/tensorboard/components/imports/local-imports/dagre.html +++ b/tensorflow/tensorboard/components/imports/local-imports/dagre.html @@ -1,4 +1,5 @@ // hackhack for some reason getting graphlib via an import reference results in // out of order script evaluation + diff --git a/tensorflow/tensorboard/components/imports/local-imports/graphlib.html b/tensorflow/tensorboard/components/imports/local-imports/graphlib.html index a1e98e9089d..4bf3528fcd3 100644 --- a/tensorflow/tensorboard/components/imports/local-imports/graphlib.html +++ b/tensorflow/tensorboard/components/imports/local-imports/graphlib.html @@ -1 +1,2 @@ + diff --git a/tensorflow/tensorboard/components/tf-graph-board/tf-graph-board.html b/tensorflow/tensorboard/components/tf-graph-board/tf-graph-board.html index 6ff365a6c1d..41bef4e0a0a 100644 --- a/tensorflow/tensorboard/components/tf-graph-board/tf-graph-board.html +++ b/tensorflow/tensorboard/components/tf-graph-board/tf-graph-board.html @@ -77,6 +77,32 @@ paper-progress { --paper-progress-height: 6px; --paper-progress-active-color: #f3913e; } + +.context-menu { + position: absolute; + display: none; + background-color: #e2e2e2; + border-radius: 2px; + font-size: 14px; + min-width: 150px; + border: 1px solid #d4d4d4; +} + +/deep/ .context-menu ul { + list-style-type: none; + margin: 0; + padding: 0; + cursor: default; +} + +/deep/ .context-menu ul li { + padding: 4px 16px; +} + +/deep/ .context-menu ul li:hover { + background-color: #f3913e; + color: white; +} @@ -137,9 +165,17 @@ Polymer({ }, // Private API: Data routing between child components. _selectedNode: String, + // The enum value of the include property of the selected node. + _selectedNodeInclude: Number, _highlightedNode: String, _renderHierarchy: Object, }, + listeners: { + 'node-toggle-extract': '_nodeToggleExtract' + }, + observers: [ + '_updateNodeInclude(_selectedNode)' + ], /** True if the progress is not complete yet (< 100 %). */ _isNotComplete: function(progress) { return progress.value < 100; @@ -153,6 +189,14 @@ Polymer({ result += ' loading'; } return result; + }, + _updateNodeInclude: function(nodeName) { + var node = this.graphHierarchy.node(nodeName); + this.set("_selectedNodeInclude", + node ? node.include : tf.graph.InclusionType.UNSPECIFIED); + }, + _nodeToggleExtract: function() { + this._updateNodeInclude(this._selectedNode); } }); diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts index 8d1faaceecc..41f00c54f42 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts @@ -29,6 +29,9 @@ export enum GraphType {FULL, EMBEDDED, META, SERIES, CORE, SHADOW, BRIDGE, EDGE}; export enum NodeType {META, OP, SERIES, BRIDGE, ELLIPSIS}; +/** Indicates if a node is to be included in the main graph when rendered. */ +export enum InclusionType {INCLUDE, EXCLUDE, UNSPECIFIED}; + /** * A BaseEdge is the label object (in the graphlib sense) for an edge in the * original, full graph produced after parsing. Subsequent graphs, like those @@ -98,6 +101,12 @@ export interface Node { parentNode: Node; /** Runtime execution stats for this node, if available */ stats: NodeStats; + /** If the node is to be included or excluded from the main graph when + * rendered. Defaults to UNSPECIFIED, which means that the rendering + * algorithm determines if it will be included or not. Then can be set to + * INCLUDE or EXCLUDE manually by the user. + */ + include: InclusionType; } export interface OpNode extends Node { @@ -258,6 +267,7 @@ export class EllipsisNodeImpl implements EllipsisNode { isGroupNode: boolean; cardinality: number; parentNode: Node; + include: InclusionType; /** * Constructs a new ellipsis annotation node. @@ -271,6 +281,7 @@ export class EllipsisNodeImpl implements EllipsisNode { this.parentNode = null; this.stats = null; this.setNumMoreNodes(numNodes); + this.include = InclusionType.UNSPECIFIED; } setNumMoreNodes(numNodes: number) { @@ -296,6 +307,7 @@ class OpNodeImpl implements OpNode { inEmbeddings: OpNode[]; outEmbeddings: OpNode[]; parentNode: Node; + include: InclusionType; /** * Constructs a new Op node. @@ -319,6 +331,7 @@ class OpNodeImpl implements OpNode { this.inEmbeddings = []; this.outEmbeddings = []; this.parentNode = null; + this.include = InclusionType.UNSPECIFIED; } }; @@ -419,6 +432,7 @@ class MetanodeImpl implements Metanode { deviceHistogram: {[op: string]: number}; parentNode: Node; hasNonControlEdges: boolean; + include: InclusionType; /** A label object for meta-nodes in the graph hierarchy */ constructor(name: string, opt = {}) { @@ -448,6 +462,7 @@ class MetanodeImpl implements Metanode { this.parentNode = null; this.stats = new NodeStats(0, 0, null); this.hasNonControlEdges = false; + this.include = InclusionType.UNSPECIFIED; } getFirstChild(): GroupNode|OpNode { @@ -599,6 +614,7 @@ class SeriesNodeImpl implements SeriesNode { parentNode: Node; deviceHistogram: {[op: string]: number}; hasNonControlEdges: boolean; + include: InclusionType; constructor(prefix: string, suffix: string, parent: string, clusterId: number, name: string) { @@ -619,6 +635,7 @@ class SeriesNodeImpl implements SeriesNode { this.deviceHistogram = {}; this.hasNonControlEdges = false; this.stats = new NodeStats(0, 0, null); + this.include = InclusionType.UNSPECIFIED; } } @@ -901,4 +918,15 @@ export function getHierarchicalPath(name: string, return path; }; +/** + * Returns the string for the node inclusion toggle button, dependant + * on the provided current InclusionType. + */ +export function getIncludeNodeButtonString(include: InclusionType) { + if (include === tf.graph.InclusionType.EXCLUDE) { + return "Add to main graph"; + } else { + return "Remove from main graph"; + } +}; } // close module tf.graph diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts index 81e1873137f..c99c61a8496 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts @@ -288,6 +288,24 @@ export class RenderGraphInformation { setGroupNodeDepth(this.root, +depth); } + /** + * Returns true if the renderNode is an isolated node within its parent node. + */ + isNodeAuxilliary(renderNode: RenderNodeInformation): boolean { + let parentNode = this.getRenderNodeByName( + renderNode.node.parentNode.name); + let found = _.find(parentNode.isolatedInExtract, node => { + return node.node.name === renderNode.node.name; + }); + if (found) { + return true; + } + found = _.find(parentNode.isolatedOutExtract, node => { + return node.node.name === renderNode.node.name; + }); + return !!found; + } + buildSubhierarchy(nodeName: string): void { // Terminate if the rendering hierarchy was already constructed // for this node. @@ -555,6 +573,7 @@ export class RenderGraphInformation { cardinality: 0, parentNode: null, stats: null, + include: InclusionType.UNSPECIFIED, // BridgeNode properties. inbound: inbound, }; @@ -573,6 +592,7 @@ export class RenderGraphInformation { cardinality: 1, parentNode: null, stats: null, + include: InclusionType.UNSPECIFIED, // BridgeNode properties. inbound: inbound, }; @@ -692,6 +712,7 @@ export class RenderGraphInformation { cardinality: 1, parentNode: null, stats: null, + include: InclusionType.UNSPECIFIED, // BridgeNode properties. inbound: inbound, }; @@ -1127,6 +1148,16 @@ function createShortcut(graph: graphlib.Graph, let sink = graph.node(w); let edge = graph.edge(v, w); + // If either of the nodes is explicitly included in the main graph and + // both nodes are in the main graph then do not create the shortcut + // and instead keep the real edge. + if ((src.node.include === InclusionType.INCLUDE || + sink.node.include === InclusionType.INCLUDE) && + src.node.include !== InclusionType.EXCLUDE && + sink.node.include !== InclusionType.EXCLUDE) { + return; + } + // Add each annotation. addOutAnnotation(src, sink.node, sink, edge, AnnotationType.SHORTCUT, params); addInAnnotation(sink, src.node, src, edge, AnnotationType.SHORTCUT, params); @@ -1139,27 +1170,29 @@ function createShortcut(graph: graphlib.Graph, * Remove edges from a node, and set its isOutExtract property to true, * and remove the node and move it to isolatedOutExtract. * - * If detachAllEdgesForHighDegree is true, extract all of its edges. - * Otherwise, only extract all in-edges. + * If detachAllEdgesForHighDegree or forceDetach is true, extract all of its + * edges. Otherwise, only extract all in-edges. */ function makeOutExtract(renderNode: RenderGroupNodeInformation, n: string, - params: RenderGraphParams) { + params: RenderGraphParams, forceDetach?: boolean) { let graph = renderNode.coreGraph; - - graph.node(n).isOutExtract = true; + let child = graph.node(n); + child.isOutExtract = true; _.each(graph.predecessors(n), (p, index) => { createShortcut(graph, p, n, params); }); - if (params.detachAllEdgesForHighDegree) { + if (params.detachAllEdgesForHighDegree || forceDetach) { _.each(graph.successors(n), (s, index) => { createShortcut(graph, n, s, params); }); } - if (params.detachAllEdgesForHighDegree || graph.neighbors(n).length === 0) { - renderNode.isolatedOutExtract.push(graph.node(n)); + // Remove the node from the core graph if it no longer has neighbors. + if (graph.neighbors(n).length === 0) { + child.node.include = InclusionType.EXCLUDE; + renderNode.isolatedOutExtract.push(child); graph.removeNode(n); } } @@ -1167,27 +1200,30 @@ function makeOutExtract(renderNode: RenderGroupNodeInformation, n: string, /** * Remove edges from a node, set its isInExtract property to true, * and remove the node and move it to isolatedInExtract. - * If detachAllEdgesForHighDegree is true, extract all of its edges. - * Otherwise, only remove all out-edges. + * + * If detachAllEdgesForHighDegree or forceDetach is true, extract all of its + * edges. Otherwise, only remove all out-edges. */ -function makeInExtract(renderNode: RenderGroupNodeInformation, n: string, - params: RenderGraphParams) { +export function makeInExtract(renderNode: RenderGroupNodeInformation, n: string, + params: RenderGraphParams, forceDetach?: boolean) { let graph = renderNode.coreGraph; - graph.node(n).isInExtract = true; + let child = graph.node(n); + child.isInExtract = true; _.each(graph.successors(n), (s, index) => { createShortcut(graph, n, s, params); }); - if (params.detachAllEdgesForHighDegree) { + if (params.detachAllEdgesForHighDegree || forceDetach) { _.each(graph.predecessors(n), (p, index) => { createShortcut(graph, p, n, params); }); } - // Remove the node from the core graph if conditions are met. - if (params.detachAllEdgesForHighDegree || graph.neighbors(n).length === 0) { - renderNode.isolatedInExtract.push(graph.node(n)); + // Remove the node from the core graph if it no longer has neighbors. + if (graph.neighbors(n).length === 0) { + child.node.include = InclusionType.EXCLUDE; + renderNode.isolatedInExtract.push(child); graph.removeNode(n); } } @@ -1214,12 +1250,32 @@ function hasTypeIn(node: Node, types: string[]): boolean { return false; } +/** Move nodes that are speficied to be excluded out of the core graph. */ +function extractSpeficiedNodes(renderNode: RenderGroupNodeInformation, + params: RenderGraphParams) { + let graph = renderNode.coreGraph; + _.each(graph.nodes(), n => { + let renderInfo = graph.node(n); + if (renderInfo.node.include === InclusionType.EXCLUDE) { + if (renderNode.coreGraph.outEdges(n).length > + renderNode.coreGraph.inEdges(n).length) { + makeOutExtract(renderNode, n, params, true); + } else { + makeInExtract(renderNode, n, params, true); + } + } + }); +} + /** Remove edges from pre-defined out-extract patterns */ function extractPredefinedSink(renderNode: RenderGroupNodeInformation, params: RenderGraphParams) { let graph = renderNode.coreGraph; _.each(graph.nodes(), n => { let renderInfo = graph.node(n); + if (renderInfo.node.include !== InclusionType.UNSPECIFIED) { + return; + } if (hasTypeIn(renderInfo.node, params.outExtractTypes)) { makeOutExtract(renderNode, n, params); } @@ -1233,6 +1289,9 @@ function extractPredefinedSource(renderNode: RenderGroupNodeInformation, _.each(graph.nodes(), n => { let renderInfo = graph.node(n); + if (renderInfo.node.include !== InclusionType.UNSPECIFIED) { + return; + } if (hasTypeIn(renderInfo.node, params.inExtractTypes)) { makeInExtract(renderNode, n, params); } @@ -1247,6 +1306,9 @@ function extractHighInDegree(renderNode: RenderGroupNodeInformation, // detect first so degrees don't get affected by other removal let highInDegreeNames = _.filter(graph.nodes(), n => { + if (graph.node(n).node.include !== InclusionType.UNSPECIFIED) { + return false; + } // Count the in-degree based on only regular edges, unless there are // no regular edges, in which case use the number of control edges. // This is done so that control edges don't effect if nodes are extracted @@ -1274,6 +1336,9 @@ function extractHighOutDegree(renderNode: RenderGroupNodeInformation, // detect first so degrees don't get affected by other removal let highOutDegreeNames = _.filter(graph.nodes(), n => { + if (graph.node(n).node.include !== InclusionType.UNSPECIFIED) { + return false; + } // Count the out-degree based on only regular edges, unless there are // no regular edges, in which case use the number of control edges. // This is done so that control edges don't effect if nodes are extracted @@ -1345,6 +1410,9 @@ export function mapIndexToHue(id: number): number { */ function extractHighDegrees(renderNode: RenderGroupNodeInformation, params: RenderGraphParams) { + + extractSpeficiedNodes(renderNode, params); + if (params.outExtractTypes) { extractPredefinedSink(renderNode, params); } @@ -1386,7 +1454,9 @@ function extractHighDegrees(renderNode: RenderGroupNodeInformation, _.each(graph.nodes(), n => { let child = graph.node(n); let degree = graph.neighbors(n).length; - + if (child.node.include !== InclusionType.UNSPECIFIED) { + return; + } if (degree === 0) { let hasOutAnnotations = child.outAnnotations.list.length > 0; let hasInAnnotations = child.inAnnotations.list.length > 0; @@ -1395,20 +1465,24 @@ function extractHighDegrees(renderNode: RenderGroupNodeInformation, // This case only happens if detachAllEdgesForHighDegree is false. // (Otherwise all source-like nodes are all isolated already.) renderNode.isolatedInExtract.push(child); + child.node.include = InclusionType.EXCLUDE; graph.removeNode(n); } else if (child.isOutExtract) { // Is sink-like. // This case only happens if detachAllEdgesForHighDegree is false. // // (Otherwise all sink-like nodes are all isolated already.) renderNode.isolatedOutExtract.push(child); + child.node.include = InclusionType.EXCLUDE; graph.removeNode(n); } else if (params.extractIsolatedNodesWithAnnotationsOnOneSide) { if (hasOutAnnotations && !hasInAnnotations) { child.isInExtract = true; // for ones with high out-annotations renderNode.isolatedInExtract.push(child); + child.node.include = InclusionType.EXCLUDE; graph.removeNode(n); } else if (hasInAnnotations && !hasOutAnnotations) { child.isOutExtract = true; // for ones with high in-annotations renderNode.isolatedOutExtract.push(child); + child.node.include = InclusionType.EXCLUDE; graph.removeNode(n); } else { // if a low degree node has both in- & out- annotations, do nothing diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts index 6823d8b48c7..d973f75fd3c 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts @@ -17,6 +17,7 @@ limitations under the License. /// /// /// +/// module tf.graph.scene.annotation { @@ -90,7 +91,7 @@ export function buildGroup(container, annotationData: render.AnnotationList, let aGroup = d3.select(this); update(aGroup, d, a, sceneBehavior); if (a.annotationType !== tf.graph.render.AnnotationType.ELLIPSIS) { - addInteraction(aGroup, d, sceneBehavior); + addInteraction(aGroup, d, a, sceneBehavior); } }); @@ -151,7 +152,7 @@ function addAnnotationLabel(aGroup, label, a, additionalClassNames, } function addInteraction(selection, d: render.RenderNodeInformation, - sceneBehavior) { + annotation: tf.graph.render.Annotation, sceneBehavior) { selection .on("mouseover", a => { sceneBehavior.fire("annotation-highlight", { @@ -174,6 +175,11 @@ function addInteraction(selection, d: render.RenderNodeInformation, hostName: d.node.name }); }); + if (annotation.annotationType !== tf.graph.render.AnnotationType.SUMMARY && + annotation.annotationType !== tf.graph.render.AnnotationType.CONSTANT) { + selection.on("contextmenu", tf.graph.scene.contextmenu.getMenu( + tf.graph.scene.node.getContextMenu(annotation.node, sceneBehavior))); + } }; /** diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/contextmenu.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/contextmenu.ts new file mode 100644 index 00000000000..648f3461fe1 --- /dev/null +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/contextmenu.ts @@ -0,0 +1,77 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +module tf.graph.scene.contextmenu { + +/** Function that converts data to a title string. */ +export interface TitleFunction { + (data: any): string; +} + +/** Function that takes action based on item clicked in the context menu. */ +export interface ActionFunction { + (elem: any, d: any, i: number): void; +} + +/** + * The interface for an item in the context menu + */ +export interface ContextMenuItem { + title: TitleFunction; + action: ActionFunction; +} + +/** + * Returns the event listener, which can be used as an argument for the d3 + * selection.on function. Renders the context menu that is to be displayed + * in response to the event. + */ +export function getMenu(menu: ContextMenuItem[]) { + let menuSelection = d3.select(".context-menu"); + // Close the menu when anything else is clicked. + d3.select("body").on("click.context", function() { + menuSelection.style("display", "none"); + }); + + // Function called to populate the context menu. + return function(data, index: number): void { + // Position and display the menu. + let event = d3.event; + menuSelection.style({ + "display": "block", + "left": (event.layerX + 1) + "px", + "top": (event.layerY + 1) + "px" + }); + + // Stop the event from propagating further. + event.preventDefault(); + event.stopPropagation(); + + // Add provided items to the context menu. + menuSelection.html(""); + let list = menuSelection.append("ul"); + list.selectAll("li").data(menu).enter() + .append("li") + .html(function(d) { + return d.title(data); + }) + .on("click", (d, i) => { + d.action(this, data, index); + menuSelection.style("display", "none"); + }); + }; +}; + +} // close module diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts index 57496ceffb6..bb1d1fdcdc1 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts @@ -16,6 +16,7 @@ limitations under the License. /// /// /// +/// module tf.graph.scene.node { @@ -229,31 +230,52 @@ function addInteraction(selection, d: render.RenderNodeInformation, selection.attr("pointer-events", "none"); return; } + + let contextMenuFunction = tf.graph.scene.contextmenu.getMenu( + getContextMenu(d.node, sceneBehavior)); selection.on("dblclick", d => { - sceneBehavior.fire("node-toggle-expand", { name: d.node.name }); + sceneBehavior.fire("node-toggle-expand", { name: d.node.name }); }) .on("mouseover", d => { // don't send mouseover over expanded group, // otherwise it is causing too much glitches - if (sceneBehavior.isNodeExpanded(d)) { return; } + if (sceneBehavior.isNodeExpanded(d)) { return; } sceneBehavior.fire("node-highlight", { name: d.node.name }); }) .on("mouseout", d => { // don't send mouseover over expanded group, // otherwise it is causing too much glitches - if (sceneBehavior.isNodeExpanded(d)) { return; } + if (sceneBehavior.isNodeExpanded(d)) { return; } - sceneBehavior.fire("node-unhighlight", { name: d.node.name }); + sceneBehavior.fire("node-unhighlight", { name: d.node.name }); }) .on("click", d => { // Stop this event's propagation so that it isn't also considered // a graph-select. (d3.event).stopPropagation(); sceneBehavior.fire("node-select", { name: d.node.name }); + }) + .on("contextmenu", (d, i) => { + sceneBehavior.fire("node-select", { name: d.node.name }); + contextMenuFunction.call(d, i); }); }; +/** + * Returns the d3 context menu specification for the provided node. + */ +export function getContextMenu(node: Node, sceneBehavior) { + return [{ + title: d => { + return tf.graph.getIncludeNodeButtonString(node.include); + }, + action: (elm, d, i) => { + sceneBehavior.fire("node-toggle-extract", { name: node.name }); + } + }]; +} + /** * Append svg text for label and assign data. * @param nodeGroup diff --git a/tensorflow/tensorboard/components/tf-graph-common/tf-graph-common.html b/tensorflow/tensorboard/components/tf-graph-common/tf-graph-common.html index f42e7d27dd7..87a7efe500f 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/tf-graph-common.html +++ b/tensorflow/tensorboard/components/tf-graph-common/tf-graph-common.html @@ -13,5 +13,6 @@ + diff --git a/tensorflow/tensorboard/components/tf-graph-dashboard/tf-graph-dashboard.html b/tensorflow/tensorboard/components/tf-graph-dashboard/tf-graph-dashboard.html index 779d2d3fe96..134f6f1ee91 100644 --- a/tensorflow/tensorboard/components/tf-graph-dashboard/tf-graph-dashboard.html +++ b/tensorflow/tensorboard/components/tf-graph-dashboard/tf-graph-dashboard.html @@ -83,6 +83,7 @@ by default. The user can select a different run from a dropdown menu. } .center { + position: relative; height: 100%; } diff --git a/tensorflow/tensorboard/components/tf-graph-info/tf-graph-info.html b/tensorflow/tensorboard/components/tf-graph-info/tf-graph-info.html index e47269861ba..e1e2998d0ab 100644 --- a/tensorflow/tensorboard/components/tf-graph-info/tf-graph-info.html +++ b/tensorflow/tensorboard/components/tf-graph-info/tf-graph-info.html @@ -22,6 +22,7 @@ h2 { render-hierarchy="[[renderHierarchy]]" flat-graph="[[graph]]" node-name="[[selectedNode]]" + node-include="[[selectedNodeInclude]]" highlighted-node="{{highlightedNode}}" color-by="[[colorBy]]"> @@ -47,6 +48,11 @@ h2 { highlightedNode: { type: String, notify: true + }, + // The enum value of the include property of the selected node. + selectedNodeInclude: { + type: Number, + notify: true } }, listeners: { diff --git a/tensorflow/tensorboard/components/tf-graph-info/tf-node-info.html b/tensorflow/tensorboard/components/tf-graph-info/tf-node-info.html index 056c851924d..5d54ea82d97 100644 --- a/tensorflow/tensorboard/components/tf-graph-info/tf-node-info.html +++ b/tensorflow/tensorboard/components/tf-graph-info/tf-node-info.html @@ -89,6 +89,23 @@ padding: 0; } + .toggle-include-group { + padding-top: 4px; + } + + .toggle-include { + margin: 5px 6px; + text-transform: none; + padding: 4px 6px; + font-size: 10pt; + background-color: #fafafa; + color: #666; + } + + .toggle-include:hover { + background-color: var(--google-yellow-100); + } + .non-control-list-item { padding-left: 10px; } @@ -248,6 +265,11 @@ +
+ + [[_auxButtonText]] + +
@@ -273,6 +295,11 @@ computed: '_getNode(nodeName, graphHierarchy)', observer: '_resetState' }, + // The enum value of the include property of the selected node. + nodeInclude: { + type: Number, + observer: '_nodeIncludeStateChanged' + }, _attributes: { type: Array, computed: '_getAttributes(_node)' @@ -313,6 +340,7 @@ type: Boolean, value: false }, + _auxButtonText: String }, expandNode: function() { this.fire('_node.expand', this.node); @@ -379,6 +407,16 @@ if (list) { list.fire('iron-resize'); } + }, + _toggleInclude: function() { + var graphElem = document.querySelector("#graph"); + graphElem.fire("node-toggle-extract", { name: this.nodeName }); + var graphBoardElem = document.querySelector("#graphboard"); + graphBoardElem.fire("node-toggle-extract"); + }, + _nodeIncludeStateChanged: function(include, oldInclude) { + this.set("_auxButtonText", + tf.graph.getIncludeNodeButtonString(include)); } }); })(); diff --git a/tensorflow/tensorboard/components/tf-graph/demo/tf-graph-demo.html b/tensorflow/tensorboard/components/tf-graph/demo/tf-graph-demo.html index d6e736d185f..4664c8334cf 100644 --- a/tensorflow/tensorboard/components/tf-graph/demo/tf-graph-demo.html +++ b/tensorflow/tensorboard/components/tf-graph/demo/tf-graph-demo.html @@ -1,4 +1,4 @@ - + diff --git a/tensorflow/tensorboard/components/tf-graph/tf-graph-params.html b/tensorflow/tensorboard/components/tf-graph/tf-graph-params.html index 576816ddd0f..96ce6da0264 100644 --- a/tensorflow/tensorboard/components/tf-graph/tf-graph-params.html +++ b/tensorflow/tensorboard/components/tf-graph/tf-graph-params.html @@ -68,7 +68,7 @@ Module for adjusting render graph building parameter. */ detachAllEdgesForHighDegree: { type: Boolean, - value: false + value: true }, /** diff --git a/tensorflow/tensorboard/components/tf-graph/tf-graph.html b/tensorflow/tensorboard/components/tf-graph/tf-graph.html index 0bcd4d5521a..cea58415751 100644 --- a/tensorflow/tensorboard/components/tf-graph/tf-graph.html +++ b/tensorflow/tensorboard/components/tf-graph/tf-graph.html @@ -82,7 +82,6 @@ Polymer({ type: Object, readOnly: true, notify: true, - computed: '_buildRenderHierarchy(graphHierarchy, _graphParams)' }, // internal properties _graphParams: { @@ -100,8 +99,11 @@ Polymer({ value: true } }, + observers: [ + '_buildRenderHierarchy(graphHierarchy, _graphParams)' + ], _buildRenderHierarchy: function(graphHierarchy, params) { - return tf.time('new tf.graph.render.Hierarchy', function() { + tf.time('new tf.graph.render.Hierarchy', function() { if (graphHierarchy.root.type !== tf.graph.NodeType.META) { // root must be metanode but sometimes Polymer's dom-if has not // remove tf-graph element yet in @@ -135,7 +137,7 @@ Polymer({ }; }) }); - return renderGraph; + this._setRenderHierarchy(renderGraph); }.bind(this)); }, _getVisible: function(name) { @@ -153,6 +155,7 @@ Polymer({ 'node-select': '_nodeSelected', 'node-highlight': '_nodeHighlighted', 'node-unhighlight': '_nodeUnhighlighted', + 'node-toggle-extract': '_nodeToggleExtract', // Annotations @@ -214,6 +217,23 @@ Polymer({ // Also select the expanded node. this._nodeSelected(event); }, + _nodeToggleExtract: function(event) { + // Toggle the include setting of the specified node appropriately. + var nodeName = event.detail.name; + var renderNode = this.renderHierarchy.getRenderNodeByName(nodeName); + if (renderNode.node.include == tf.graph.InclusionType.INCLUDE) { + renderNode.node.include = tf.graph.InclusionType.EXCLUDE; + } else if (renderNode.node.include == tf.graph.InclusionType.EXCLUDE) { + renderNode.node.include = tf.graph.InclusionType.INCLUDE; + } else { + renderNode.node.include = + this.renderHierarchy.isNodeAuxilliary(renderNode) + ? tf.graph.InclusionType.INCLUDE : tf.graph.InclusionType.EXCLUDE; + } + + // Rebuild the render hierarchy. + this._buildRenderHierarchy(this.graphHierarchy, this._graphParams); + }, not: function(x) { return !x; } diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 150e5321abc..d541762aa38 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -284,30 +284,33 @@ _py_wrap_cc = rule(attrs={ }, implementation=_py_wrap_cc_impl,) - def tf_extension_linkopts(): return [] # No extension link opts +def tf_extension_copts(): + return [] # No extension c opts + def tf_py_wrap_cc(name, srcs, swig_includes=[], deps=[], copts=[], **kwargs): module_name = name.split("/")[-1] # Convert a rule name such as foo/bar/baz to foo/bar/_baz.so # and use that as the name for the rule producing the .so file. cc_library_name = "/".join(name.split("/")[:-1] + ["_" + module_name + ".so"]) + extra_deps = [] _py_wrap_cc(name=name + "_py_wrap", srcs=srcs, swig_includes=swig_includes, - deps=deps, + deps=deps + extra_deps, module_name=module_name, py_module_name=name) native.cc_binary( name=cc_library_name, srcs=[module_name + ".cc"], - copts=copts + ["-Wno-self-assign", "-Wno-write-strings" - ] + ["-I/usr/include/python2.7"], + copts=(copts + ["-Wno-self-assign", "-Wno-write-strings"] + + tf_extension_copts()), linkopts=tf_extension_linkopts(), linkstatic=1, linkshared=1, - deps=deps) + deps=deps + extra_deps) native.py_library(name=name, srcs=[":" + name + ".py"], srcs_version="PY2AND3", diff --git a/third_party/py/numpy/BUILD b/third_party/py/numpy/BUILD new file mode 100644 index 00000000000..c025984cca7 --- /dev/null +++ b/third_party/py/numpy/BUILD @@ -0,0 +1,14 @@ +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "headers", + hdrs = glob([ + "numpy_include/**/*.h", + ]), + data = ["//util/python:python_checked"], + includes = [ + "numpy_include", + ], +) diff --git a/util/python/BUILD b/util/python/BUILD new file mode 100644 index 00000000000..861cd8d6ff1 --- /dev/null +++ b/util/python/BUILD @@ -0,0 +1,24 @@ +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "python_headers", + hdrs = glob([ + "python_include/**/*.h", + ]), + data = [":python_checked"], + includes = ["python_include"], +) + +genrule( + name = "python_check", + srcs = [ + "python_config.sh", + ], + outs = [ + "python_checked", + ], + cmd = "OUTPUTDIR=\"$(@D)/\"; ./util/python/python_config.sh --check && touch $$OUTPUTDIR/python_checked", + local = 1, +) diff --git a/util/python/python_config.sh b/util/python/python_config.sh new file mode 100755 index 00000000000..a52f6f70306 --- /dev/null +++ b/util/python/python_config.sh @@ -0,0 +1,79 @@ +#!/bin/bash + +set -e -o errexit + +EXPECTED_PATHS="util/python/python_include util/python/python_lib third_party/py/numpy/numpy_include" + +function main { + argument="$1" + shift + case $argument in + --check) + check_python + exit 0 + ;; + --setup) + setup_python "$1" + exit 0 + ;; + esac +} + +function setup_python { + PYTHON_BIN_PATH="$1"; + + if [ -z "$PYTHON_BIN_PATH" ]; then + echo "PYTHON_BIN_PATH was not provided. Did you run configure?" + exit 1 + fi + if [ ! -x "$PYTHON_BIN_PATH" ] || [ -d "$PYTHON_BIN_PATH" ]; then + echo "PYTHON_BIN_PATH is not executable. Is it the python binary?" + exit 1 + fi + + local python_include=$("${PYTHON_BIN_PATH}" -c 'from __future__ import print_function; from distutils import sysconfig; print(sysconfig.get_python_inc());') + if [ "$python_include" == "" ]; then + echo -e "\n\nERROR: Problem getting python include path. Is distutils installed?" + exit 1 + fi + local python_lib=$("${PYTHON_BIN_PATH}" -c 'from __future__ import print_function; from distutils import sysconfig; print(sysconfig.get_python_lib());') + if [ "$python_lib" == "" ]; then + echo -e "\n\nERROR: Problem getting python lib path. Is distutils installed?" + exit 1 + fi + local numpy_include=$("${PYTHON_BIN_PATH}" -c 'from __future__ import print_function; import numpy; print(numpy.get_include());') + if [ "$numpy_include" == "" ]; then + echo -e "\n\nERROR: Problem getting numpy include path. Is numpy installed?" + exit 1 + fi + + for x in $EXPECTED_PATHS; do + if [ -e "$x" ]; then + rm "$x" + fi + done + + ln -s "${python_include}" util/python/python_include + ln -s "${python_lib}" util/python/python_lib + ln -s "${numpy_include}" third_party/py/numpy/numpy_include +} + +function check_python { + for x in $EXPECTED_PATHS; do + if [ ! -e "$x" ]; then + echo -e "\n\nERROR: Cannot find '${x}'. Did you run configure?\n\n" 1>&2 + exit 1 + fi + if [ ! -L "${x}" ]; then + echo -e "\n\nERROR: '${x}' is not a symbolic link. Internal error.\n\n" 1>&2 + exit 1 + fi + true_path=$(readlink "${x}") + if [ ! -d "${true_path}" ]; then + echo -e "\n\nERROR: '${x}' does not refer to an existing directory: ${true_path}. Do you need to rerun configure?\n\n" 1>&2 + exit 1 + fi + done +} + +main "$@"