diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 2b1a0412354..4667f096e00 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -495,6 +495,11 @@ static void SimplifyGraph(Graph* g) { void OptimizeGraph(FunctionLibraryRuntime* lib, Graph** g) { DumpGraph("Initial", *g); + + // Run SimplifyGraph at least once to rewrite away ops such as + // _ListToArray, _ArrayToList, etc. + SimplifyGraph(*g); + const int kNumInlineRounds = 10; for (int i = 0; i < kNumInlineRounds; ++i) { if (!ExpandInlineFunctions(lib, *g)) break; diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc index 3df833594f8..8979c94e3d9 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc @@ -65,7 +65,7 @@ GPUBFCAllocator::GPUBFCAllocator(int device_id, size_t total_memory) ptr_to_chunk_map_.insert(std::make_pair(c->ptr, c)); // Insert the chunk into the right bin. - ReassignChunkToBin(c); + InsertFreeChunkIntoBin(c); } GPUBFCAllocator::~GPUBFCAllocator() { @@ -76,6 +76,7 @@ GPUBFCAllocator::~GPUBFCAllocator() { } gtl::STLDeleteValues(&bins_); + gtl::STLDeleteValues(&ptr_to_chunk_map_); } void* GPUBFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes) { @@ -115,10 +116,12 @@ void* GPUBFCAllocator::AllocateRawInternal(size_t unused_alignment, // Start searching from the first bin for the smallest chunk that fits // rounded_bytes. Bin* b = it->second; - for (GPUBFCAllocator::Chunk* chunk : b->chunks) { - if (!chunk->in_use && chunk->size > rounded_bytes) { - // We found an existing chunk that fits us that wasn't in use. - chunk->in_use = true; + for (GPUBFCAllocator::Chunk* chunk : b->free_chunks) { + DCHECK(!chunk->in_use); + if (chunk->size >= rounded_bytes) { + // We found an existing chunk that fits us that wasn't in use, so remove + // it from the free bin structure prior to using. + RemoveFreeChunkFromBin(chunk); // If we can break the size of the chunk into two reasonably // large pieces, do so. @@ -132,6 +135,7 @@ void* GPUBFCAllocator::AllocateRawInternal(size_t unused_alignment, // The requested size of the returned chunk is what the user // has allocated. chunk->requested_size = num_bytes; + chunk->in_use = true; VLOG(4) << "Returning: " << chunk->ptr; return chunk->ptr; @@ -152,6 +156,8 @@ void* GPUBFCAllocator::AllocateRawInternal(size_t unused_alignment, } void GPUBFCAllocator::SplitChunk(GPUBFCAllocator::Chunk* c, size_t num_bytes) { + CHECK(!c->in_use && !c->bin); + // Create a new chunk starting num_bytes after c GPUBFCAllocator::Chunk* new_chunk = new GPUBFCAllocator::Chunk(); new_chunk->ptr = static_cast<void*>(static_cast<char*>(c->ptr) + num_bytes); @@ -176,9 +182,8 @@ void GPUBFCAllocator::SplitChunk(GPUBFCAllocator::Chunk* c, size_t num_bytes) { c_neighbor->prev = new_chunk; } - // Maintain the bins - ReassignChunkToBin(new_chunk); - ReassignChunkToBin(c); + // Add the newly free chunk to the free bin. + InsertFreeChunkIntoBin(new_chunk); } void GPUBFCAllocator::DeallocateRaw(void* ptr) { @@ -200,11 +205,9 @@ void GPUBFCAllocator::DeallocateRawInternal(void* ptr) { GPUBFCAllocator::Chunk* c = it->second; VLOG(6) << "Chunk at " << c->ptr << " no longer in use"; - // Mark the chunk as no longer in use - c->in_use = false; // Consider coalescing it. - MaybeCoalesce(c); + FreeAndMaybeCoalesce(c); } // Merges c1 and c2 when c1->next is c2 and c2->prev is c1. @@ -212,7 +215,7 @@ void GPUBFCAllocator::DeallocateRawInternal(void* ptr) { void GPUBFCAllocator::Merge(GPUBFCAllocator::Chunk* c1, GPUBFCAllocator::Chunk* c2) { // We can only merge chunks that are not in use. - DCHECK(!c1->in_use && !c2->in_use); + CHECK(!c1->in_use && !c2->in_use); // c1's prev doesn't change, still points to the same ptr, and is // still not in use. @@ -231,62 +234,42 @@ void GPUBFCAllocator::Merge(GPUBFCAllocator::Chunk* c1, // Set the new size c1->size += c2->size; - // Delete c2 and cleanup all state - RemoveChunkFromBin(c2); + DeleteChunk(c2); } -void GPUBFCAllocator::ReassignChunkToBin(GPUBFCAllocator::Chunk* c) { +void GPUBFCAllocator::DeleteChunk(Chunk* c) { + // Delete c2 and cleanup all state + VLOG(4) << "Removing: " << c->ptr; + ptr_to_chunk_map_.erase(c->ptr); + delete c; +} + +void GPUBFCAllocator::InsertFreeChunkIntoBin(GPUBFCAllocator::Chunk* c) { + CHECK(!c->in_use && !c->bin); auto it = bins_.lower_bound(c->size); CHECK(it != bins_.end()) << " Tried to reassign to non-existent bin for size " << c->size; - Bin* new_bin = it->second; - - // If the bin has not changed, do nothing. - Bin* old_bin = c->bin; - if (old_bin != nullptr && new_bin == old_bin) { - return; - } - - // The bin has changed. Add the chunk to the new bin and remove - // the chunk from the old bin. - new_bin->chunks.insert(c); c->bin = new_bin; - - if (old_bin == nullptr) { - return; - } - - // Remove chunk from old bin - for (auto it = old_bin->chunks.begin(); it != old_bin->chunks.end(); ++it) { - if (*it == c) { - old_bin->chunks.erase(it); - return; - } - } - CHECK(false) << "Could not find chunk in old bin"; + new_bin->free_chunks.insert(c); } -void GPUBFCAllocator::RemoveChunkFromBin(GPUBFCAllocator::Chunk* c) { - Bin* b = c->bin; - for (auto it = b->chunks.begin(); it != b->chunks.end(); ++it) { - Chunk* other_c = *it; - if (other_c->ptr == c->ptr) { - b->chunks.erase(it); - VLOG(4) << "Removing: " << c->ptr; - ptr_to_chunk_map_.erase(c->ptr); - delete c; - return; - } - } - - CHECK(false) << "Could not find chunk in bin"; +void GPUBFCAllocator::RemoveFreeChunkFromBin(GPUBFCAllocator::Chunk* c) { + CHECK(!c->in_use && c->bin); + int count = c->bin->free_chunks.erase(c); + CHECK(count > 0) << "Could not find chunk in bin"; + c->bin = nullptr; } -void GPUBFCAllocator::MaybeCoalesce(GPUBFCAllocator::Chunk* c) { +void GPUBFCAllocator::FreeAndMaybeCoalesce(GPUBFCAllocator::Chunk* c) { + CHECK(c->in_use && !c->bin); + + // Mark the chunk as no longer in use + c->in_use = false; + // This chunk is no longer in-use, consider coalescing the chunk // with adjacent chunks. - Chunk* chunk_to_reassign = nullptr; + Chunk* chunk_to_reassign = c; // If the next chunk is free, coalesce the two, if the result would // fit in an existing bin. @@ -296,6 +279,7 @@ void GPUBFCAllocator::MaybeCoalesce(GPUBFCAllocator::Chunk* c) { chunk_to_reassign = c; // Deletes c->next + RemoveFreeChunkFromBin(c->next); Merge(c, c->next); } @@ -307,13 +291,11 @@ void GPUBFCAllocator::MaybeCoalesce(GPUBFCAllocator::Chunk* c) { chunk_to_reassign = c->prev; // Deletes c + RemoveFreeChunkFromBin(c->prev); Merge(c->prev, c); } - // Reassign the final merged chunk into the right bin. - if (chunk_to_reassign) { - ReassignChunkToBin(chunk_to_reassign); - } + InsertFreeChunkIntoBin(chunk_to_reassign); } void GPUBFCAllocator::AddAllocVisitor(Visitor visitor) { @@ -354,7 +336,7 @@ void GPUBFCAllocator::DumpMemoryLog(size_t num_bytes) { size_t total_requested_bytes_in_bin = 0; size_t total_chunks_in_use = 0; size_t total_chunks_in_bin = 0; - for (Chunk* c : b->chunks) { + for (Chunk* c : b->free_chunks) { total_bytes_in_bin += c->size; total_requested_bytes_in_bin += c->requested_size; ++total_chunks_in_bin; @@ -388,7 +370,7 @@ void GPUBFCAllocator::DumpMemoryLog(size_t num_bytes) { << " was " << strings::HumanReadableNumBytes(b->bin_size) << ", Chunk State: "; - for (Chunk* c : b->chunks) { + for (Chunk* c : b->free_chunks) { LOG(INFO) << c->DebugString(true); } } diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h index 3d1601e132a..417df6f4136 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h @@ -102,28 +102,33 @@ class GPUBFCAllocator : public VisitableAllocator { Chunk* AllocateNewChunk(size_t num_bytes); void SplitChunk(Chunk* c, size_t num_bytes); void Merge(Chunk* c1, Chunk* c2); - void MaybeCoalesce(Chunk* c); - - void ReassignChunkToBin(Chunk* c); - void RemoveChunkFromBin(Chunk* c); + void FreeAndMaybeCoalesce(Chunk* c); + void InsertFreeChunkIntoBin(Chunk* c); + void RemoveFreeChunkFromBin(Chunk* c); + void DeleteChunk(Chunk* c); void DumpMemoryLog(size_t num_bytes); - // A Bin is a collection of similar-sized Chunks. + // A Bin is a collection of similar-sized free chunks. struct Bin { // All chunks in this bin have >= bin_size memory. size_t bin_size = 0; struct ChunkComparator { - bool operator()(Chunk* a, Chunk* b) { return a->size < b->size; } + // Sort first by size and then use pointer address as a tie breaker. + bool operator()(const Chunk* a, const Chunk* b) const { + if (a->size != b->size) { + return a->size < b->size; + } + return a->ptr < b->ptr; + } }; - // List of chunks within the bin, sorted by chunk size. - std::multiset<Chunk*, ChunkComparator> chunks; + // List of free chunks within the bin, sorted by chunk size. + // Chunk * not owned. + std::set<Chunk*, ChunkComparator> free_chunks; explicit Bin(size_t bs) : bin_size(bs) {} - - ~Bin() { gtl::STLDeleteElements(&chunks); } }; GPUAllocatorRetry retry_helper_; @@ -142,7 +147,7 @@ class GPUBFCAllocator : public VisitableAllocator { // Structures mutable after construction mutable mutex lock_; - // Not owned. + // Chunk * owned. std::unordered_map<void*, Chunk*> ptr_to_chunk_map_; // Called once on each region, ASAP. diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/process_state.cc index 474b988d2fa..fa9c0170f54 100644 --- a/tensorflow/core/common_runtime/gpu/process_state.cc +++ b/tensorflow/core/common_runtime/gpu/process_state.cc @@ -20,7 +20,7 @@ DEFINE_bool(record_mem_types, false, DEFINE_bool(brain_mem_reg_cuda_dma, true, "If true, register CPU RAM used to copy to/from GPU RAM " "with the CUDA driver."); -DEFINE_bool(brain_gpu_use_bfc_allocator, false, +DEFINE_bool(brain_gpu_use_bfc_allocator, true, "If true, uses the Best-Fit GPU allocator."); DEFINE_bool(brain_gpu_region_allocator_debug, false, "If true, checks for memory overwrites by writing " @@ -34,7 +34,7 @@ bool FLAGS_record_mem_types = false; bool FLAGS_brain_mem_reg_cuda_dma = true; bool FLAGS_brain_gpu_region_allocator_debug = false; bool FLAGS_brain_gpu_region_allocator_reset_to_nan = false; -bool FLAGS_brain_gpu_use_bfc_allocator = false; +bool FLAGS_brain_gpu_use_bfc_allocator = true; #endif namespace gpu = ::perftools::gputools; diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc index b68fcec5151..adc802cb452 100644 --- a/tensorflow/core/kernels/concat_op.cc +++ b/tensorflow/core/kernels/concat_op.cc @@ -135,6 +135,7 @@ REGISTER_CONCAT(bfloat16); ConcatOp<GPUDevice, type>) TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); +REGISTER_GPU(bfloat16); #undef REGISTER_GPU // A special GPU kernel for int32. diff --git a/tensorflow/core/kernels/concat_op_gpu.cu.cc b/tensorflow/core/kernels/concat_op_gpu.cu.cc index aed36dccef7..0f21d5f07c5 100644 --- a/tensorflow/core/kernels/concat_op_gpu.cu.cc +++ b/tensorflow/core/kernels/concat_op_gpu.cu.cc @@ -6,6 +6,7 @@ #include <memory> +#include "tensorflow/core/framework/bfloat16.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" @@ -34,6 +35,7 @@ void ConcatGPU(const GPUDevice& d, typename TTypes<T, 2>::Matrix* output); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); +REGISTER_GPU(bfloat16); #undef REGISTER_GPU } // end namespace tensorflow diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc index 7bfe8b095f6..eeb9fa4c388 100644 --- a/tensorflow/core/kernels/conv_grad_ops.cc +++ b/tensorflow/core/kernels/conv_grad_ops.cc @@ -541,12 +541,36 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { // The output image size is the spatial size of the output. const int output_image_size = out_rows * out_cols; + // Shard 'batch' images into 'shard_size' groups of images to be fed + // into the parallel matmul. Calculate 'shard_size' by dividing the L3 cache + // size ('target_working_set_size') by the matmul size of an individual + // image ('work_unit_size'). + + // TODO(andydavis) + // *) Get L3 cache size from device at runtime (30MB is from ivybridge). + // *) Consider reducing 'target_working_set_size' if L3 is shared by + // other concurrently running tensorflow ops. + const size_t target_working_set_size = (30LL << 20) / sizeof(T); + + const size_t size_A = output_image_size * filter_total_size; + + const size_t size_B = output_image_size * out_depth; + + const size_t size_C = filter_total_size * out_depth; + + const size_t work_unit_size = size_A + size_B + size_C; + + const size_t shard_size = + (target_working_set_size + work_unit_size - 1) / work_unit_size; + Tensor col_buffer; - OP_REQUIRES_OK( - context, - context->allocate_temp( - DataTypeToEnum<T>::value, - TensorShape({output_image_size, filter_total_size}), &col_buffer)); + OP_REQUIRES_OK(context, + context->allocate_temp( + DataTypeToEnum<T>::value, + TensorShape({static_cast<int64>(shard_size), + static_cast<int64>(output_image_size), + static_cast<int64>(filter_total_size)}), + &col_buffer)); // The input offset corresponding to a single input image. const int input_offset = input_rows * input_cols * in_depth; @@ -571,21 +595,29 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { contract_dims[0].first = 0; contract_dims[0].second = 0; - for (int image_id = 0; image_id < batch; ++image_id) { - // When we compute the gradient with respect to the filters, we need to do - // im2col to allow gemm-type computation. - Im2col<T>(input_data, in_depth, input_rows, input_cols, filter_rows, - filter_cols, pad_top, pad_left, pad_bottom, pad_right, stride, - stride, col_buffer_data); + for (int image_id = 0; image_id < batch; image_id += shard_size) { + const int shard_limit = std::min(static_cast<int>(shard_size), + static_cast<int>(batch) - image_id); + for (int shard_id = 0; shard_id < shard_limit; ++shard_id) { + // TODO(andydavis) Parallelize this loop. + // When we compute the gradient with respect to the filters, we need + // to do im2col to allow gemm-type computation. + Im2col<T>(input_data, in_depth, input_rows, input_cols, filter_rows, + filter_cols, pad_top, pad_left, pad_bottom, pad_right, stride, + stride, col_buffer_data + shard_id * size_A); - ConstTensorMap A(col_buffer_data, output_image_size, filter_total_size); - ConstTensorMap B(out_backprop_data + output_offset * image_id, - output_image_size, out_depth); + input_data += input_offset; + } + + ConstTensorMap A(col_buffer_data, output_image_size * shard_limit, + filter_total_size); + ConstTensorMap B(out_backprop_data, output_image_size * shard_limit, + out_depth); // Gradient with respect to filter. C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims); - input_data += input_offset; + out_backprop_data += output_offset * shard_limit; } } diff --git a/tensorflow/core/kernels/tile_ops.h b/tensorflow/core/kernels/tile_ops.h index 1a614fe4f18..99455adce2a 100644 --- a/tensorflow/core/kernels/tile_ops.h +++ b/tensorflow/core/kernels/tile_ops.h @@ -46,8 +46,9 @@ template <typename Device, typename T> struct TileGrad<Device, T, 0> { void operator()(const Device& d, typename TTypes<T, 0>::Tensor out, typename TTypes<T, 0>::ConstTensor in, - const Eigen::DSizes<ptrdiff_t, 0>&, - const Eigen::DSizes<ptrdiff_t, 0>&, bool first) const { + const Eigen::DSizes<Eigen::DenseIndex, 0>&, + const Eigen::DSizes<Eigen::DenseIndex, 0>&, + bool first) const { if (first) { out.device(d) = in; } else { diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 8c0571b50e9..321d9c02768 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -407,7 +407,7 @@ reshape(t, [3, 3]) ==> [[1, 2, 3] # tensor 't' is [[[1, 1], [2, 2]] # [[3, 3], [4, 4]]] -# tensor 't' has shape [2, 2] +# tensor 't' has shape [2, 2, 2] reshape(t, [2, 4]) ==> [[1, 1, 2, 2] [3, 3, 4, 4]] diff --git a/tensorflow/g3doc/api_docs/python/array_ops.md b/tensorflow/g3doc/api_docs/python/array_ops.md index f34e12b1dcb..10a6df41ff2 100644 --- a/tensorflow/g3doc/api_docs/python/array_ops.md +++ b/tensorflow/g3doc/api_docs/python/array_ops.md @@ -529,7 +529,7 @@ tf.shape(split0) ==> [5, 10] * <b>`split_dim`</b>: A 0-D `int32` `Tensor`. The dimension along which to split. Must be in the range `[0, rank(value))`. -* <b>`num_split`</b>: A 0-D `int32` `Tensor`. The number of ways to split. +* <b>`num_split`</b>: A Python integer. The number of ways to split. * <b>`value`</b>: The `Tensor` to split. * <b>`name`</b>: A name for the operation (optional). diff --git a/tensorflow/g3doc/api_docs/python/constant_op.md b/tensorflow/g3doc/api_docs/python/constant_op.md index edbb8101a49..0e3abf66766 100644 --- a/tensorflow/g3doc/api_docs/python/constant_op.md +++ b/tensorflow/g3doc/api_docs/python/constant_op.md @@ -17,7 +17,7 @@ Note: Functions taking `Tensor` arguments can also take anything accepted by * [`tf.constant(value, dtype=None, shape=None, name='Const')`](#constant) * [Sequences](#AUTOGENERATED-sequences) * [`tf.linspace(start, stop, num, name=None)`](#linspace) - * [`tf.range(start, limit, delta=1, name='range')`](#range) + * [`tf.range(start, limit=None, delta=1, name='range')`](#range) * [Random Tensors](#AUTOGENERATED-random-tensors) * [Examples:](#AUTOGENERATED-examples-) * [`tf.random_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)`](#random_normal) @@ -273,12 +273,15 @@ tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] - - - -### `tf.range(start, limit, delta=1, name='range')` <a class="md-anchor" id="range"></a> +### `tf.range(start, limit=None, delta=1, name='range')` <a class="md-anchor" id="range"></a> Creates a sequence of integers. -This operation creates a sequence of integers that begins at `start` and -extends by increments of `delta` up to but not including `limit`. +Creates a sequence of integers that begins at `start` and extends by +increments of `delta` up to but not including `limit`. + +Like the Python builtin `range`, `start` defaults to 0, so that +`range(n) = range(0, n)`. For example: @@ -287,12 +290,16 @@ For example: # 'limit' is 18 # 'delta' is 3 tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] + +# 'limit' is 5 +tf.range(limit) ==> [0, 1, 2, 3, 4] ``` ##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a> * <b>`start`</b>: A 0-D (scalar) of type `int32`. First entry in sequence. + Defaults to 0. * <b>`limit`</b>: A 0-D (scalar) of type `int32`. Upper limit of sequence, exclusive. * <b>`delta`</b>: A 0-D `Tensor` (scalar) of type `int32`. Optional. Default is 1. diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md index 90c70bec24f..85739e6d3bd 100644 --- a/tensorflow/g3doc/api_docs/python/train.md +++ b/tensorflow/g3doc/api_docs/python/train.md @@ -143,7 +143,7 @@ This must be called by the constructors of subclasses. - - - -#### `tf.train.Optimizer.minimize(loss, global_step=None, var_list=None, gate_gradients=1, name=None)` <a class="md-anchor" id="Optimizer.minimize"></a> +#### `tf.train.Optimizer.minimize(loss, global_step=None, var_list=None, gate_gradients=1, aggregation_method=None, name=None)` <a class="md-anchor" id="Optimizer.minimize"></a> Add operations to minimize 'loss' by updating 'var_list'. @@ -163,6 +163,8 @@ this function. under the key GraphKeys.TRAINABLE_VARIABLES. * <b>`gate_gradients`</b>: How to gate the computation of gradients. Can be GATE_NONE, GATE_OP, or GATE_GRAPH. +* <b>`aggregation_method`</b>: Specifies the method used to combine gradient terms. + Valid values are defined in the class `AggregationMethod`. * <b>`name`</b>: Optional name for the returned operation. ##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a> @@ -178,7 +180,7 @@ this function. - - - -#### `tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1)` <a class="md-anchor" id="Optimizer.compute_gradients"></a> +#### `tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None)` <a class="md-anchor" id="Optimizer.compute_gradients"></a> Compute gradients of "loss" for the variables in "var_list". @@ -197,6 +199,8 @@ given variable. under the key GraphKey.TRAINABLE_VARIABLES. * <b>`gate_gradients`</b>: How to gate the computation of gradients. Can be GATE_NONE, GATE_OP, or GATE_GRAPH. +* <b>`aggregation_method`</b>: Specifies the method used to combine gradient terms. + Valid values are defined in the class `AggregationMethod`. ##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a> diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md index aa0301028a1..e5cd0165443 100644 --- a/tensorflow/g3doc/get_started/os_setup.md +++ b/tensorflow/g3doc/get_started/os_setup.md @@ -268,6 +268,34 @@ $ bazel-bin/tensorflow/cc/tutorials_example_trainer --use_gpu Note that "--config=cuda" is needed to enable the GPU support. +##### Enabling Cuda 3.0. <a class="md-anchor" id="AUTOGENERATED-enabling-cuda-3.0."></a> +TensorFlow officially supports Cuda devices with 3.5 and 5.2 compute +capabilities. In order to enable earlier Cuda devices such as Grid K520, you +need to target Cuda 3.0. This can be done through TensorFlow unofficial +settings with "configure". + +```bash +$ TF_UNOFFICIAL_SETTING=1 ./configure + +# Same as the official settings above + +WARNING: You are configuring unofficial settings in TensorFlow. Because some +external libraries are not backward compatible, these settings are largely +untested and unsupported. + +Please specify a list of comma-separated Cuda compute capabilities you want to +build with. You can find the compute capability of your device at: +https://developer.nvidia.com/cuda-gpus. +Please note that each additional compute capability significantly increases +your build time and binary size. [Default is: "3.5,5.2"]: 3.0 + +Setting up Cuda include +Setting up Cuda lib64 +Setting up Cuda bin +Setting up Cuda nvvm +Configuration finished +``` + ##### Known issues <a class="md-anchor" id="AUTOGENERATED-known-issues"></a> * Although it is possible to build both Cuda and non-Cuda configs under the same @@ -360,14 +388,20 @@ Make sure you followed the the GPU installation [instructions](#install_cuda). #### Can't find setup.py <a class="md-anchor" id="AUTOGENERATED-can-t-find-setup.py"></a> -If, during pip install, you encounter an error like: +If, during `pip install`, you encounter an error like: ```bash ... IOError: [Errno 2] No such file or directory: '/tmp/pip-o6Tpui-build/setup.py' ``` -Solution: upgrade your version of pip. +Solution: upgrade your version of `pip`: + +```bash +pip install --upgrade pip +``` + +This may require `sudo`, depending on how `pip` is installed. #### SSLError: SSL_VERIFY_FAILED <a class="md-anchor" id="AUTOGENERATED-sslerror--ssl_verify_failed"></a> diff --git a/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md b/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md index 8de7b080ebd..e431f4c26d6 100644 --- a/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md +++ b/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md @@ -66,7 +66,7 @@ every hundred steps or so, as in the following code example. ```python merged_summary_op = tf.merge_all_summaries() -summary_writer = tf.train.SummaryWriter('/tmp/mnist_logs', sess.graph) +summary_writer = tf.train.SummaryWriter('/tmp/mnist_logs', sess.graph_def) total_step = 0 while training: total_step += 1 diff --git a/tensorflow/g3doc/tutorials/mnist/input_data.py b/tensorflow/g3doc/tutorials/mnist/input_data.py index 391d133ea1f..890a552010f 100644 --- a/tensorflow/g3doc/tutorials/mnist/input_data.py +++ b/tensorflow/g3doc/tutorials/mnist/input_data.py @@ -5,9 +5,9 @@ from __future__ import print_function import gzip import os -import urllib import numpy +from six.moves import urllib from six.moves import xrange # pylint: disable=redefined-builtin SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' @@ -19,7 +19,7 @@ def maybe_download(filename, work_directory): os.mkdir(work_directory) filepath = os.path.join(work_directory, filename) if not os.path.exists(filepath): - filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath) + filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath) statinfo = os.stat(filepath) print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.') return filepath diff --git a/tensorflow/g3doc/tutorials/mnist/mnist.py b/tensorflow/g3doc/tutorials/mnist/mnist.py index 64be52293a5..925debac6ef 100644 --- a/tensorflow/g3doc/tutorials/mnist/mnist.py +++ b/tensorflow/g3doc/tutorials/mnist/mnist.py @@ -91,7 +91,7 @@ def loss(logits, labels): # be a 1.0 in the entry corresponding to the label). batch_size = tf.size(labels) labels = tf.expand_dims(labels, 1) - indices = tf.expand_dims(tf.range(0, batch_size, 1), 1) + indices = tf.expand_dims(tf.range(batch_size), 1) concated = tf.concat(1, [indices, labels]) onehot_labels = tf.sparse_to_dense( concated, tf.pack([batch_size, NUM_CLASSES]), 1.0, 0.0) diff --git a/tensorflow/g3doc/tutorials/mnist/pros/index.md b/tensorflow/g3doc/tutorials/mnist/pros/index.md index a83d3eabd52..0a7d1aeae0f 100644 --- a/tensorflow/g3doc/tutorials/mnist/pros/index.md +++ b/tensorflow/g3doc/tutorials/mnist/pros/index.md @@ -39,7 +39,7 @@ Tensorflow relies on a highly efficient C++ backend to do its computation. The connection to this backend is called a session. The common usage for TensorFlow programs is to first create a graph and then launch it in a session. -Here we instead use the convenience `InteractiveSession` class, which +Here we instead use the convenient `InteractiveSession` class, which makes TensorFlow more flexible about how you structure your code. It allows you to interleave operations which build a @@ -232,7 +232,7 @@ print accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels}) ## Build a Multilayer Convolutional Network <a class="md-anchor" id="AUTOGENERATED-build-a-multilayer-convolutional-network"></a> Getting 91% accuracy on MNIST is bad. It's almost embarrassingly bad. In this -section, we'll fix that, jumping from a very simple model to something moderatly +section, we'll fix that, jumping from a very simple model to something moderately sophisticated: a small convolutional neural network. This will get us to around 99.2% accuracy -- not state of the art, but respectable. diff --git a/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py b/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py index a9e0f284369..d7acb3960ef 100644 --- a/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py +++ b/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py @@ -9,9 +9,9 @@ import math import numpy as np import os import random +from six.moves import urllib from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf -import urllib import zipfile # Step 1: Download the data. @@ -20,7 +20,7 @@ url = 'http://mattmahoney.net/dc/' def maybe_download(filename, expected_bytes): """Download a file if not present, and make sure it's the right size.""" if not os.path.exists(filename): - filename, _ = urllib.urlretrieve(url + filename, filename) + filename, _ = urllib.request.urlretrieve(url + filename, filename) statinfo = os.stat(filename) if statinfo.st_size == expected_bytes: print('Found and verified', filename) diff --git a/tensorflow/models/image/cifar10/cifar10.py b/tensorflow/models/image/cifar10/cifar10.py index 8fcd7901300..627cf01b6a2 100644 --- a/tensorflow/models/image/cifar10/cifar10.py +++ b/tensorflow/models/image/cifar10/cifar10.py @@ -25,9 +25,9 @@ import os import re import sys import tarfile -import urllib import tensorflow.python.platform +from six.moves import urllib from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf @@ -366,7 +366,7 @@ def loss(logits, labels): # Reshape the labels into a dense Tensor of # shape [batch_size, NUM_CLASSES]. sparse_labels = tf.reshape(labels, [FLAGS.batch_size, 1]) - indices = tf.reshape(tf.range(0, FLAGS.batch_size, 1), [FLAGS.batch_size, 1]) + indices = tf.reshape(tf.range(FLAGS.batch_size), [FLAGS.batch_size, 1]) concated = tf.concat(1, [indices, sparse_labels]) dense_labels = tf.sparse_to_dense(concated, [FLAGS.batch_size, NUM_CLASSES], @@ -478,7 +478,8 @@ def maybe_download_and_extract(): sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, float(count * block_size) / float(total_size) * 100.0)) sys.stdout.flush() - filepath, _ = urllib.urlretrieve(DATA_URL, filepath, reporthook=_progress) + filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, + reporthook=_progress) print() statinfo = os.stat(filepath) print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.') diff --git a/tensorflow/models/image/mnist/convolutional.py b/tensorflow/models/image/mnist/convolutional.py index e388b772fe9..8e9275ab6a1 100644 --- a/tensorflow/models/image/mnist/convolutional.py +++ b/tensorflow/models/image/mnist/convolutional.py @@ -11,11 +11,11 @@ from __future__ import print_function import gzip import os import sys -import urllib import tensorflow.python.platform import numpy +from six.moves import urllib from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf @@ -41,7 +41,7 @@ def maybe_download(filename): os.mkdir(WORK_DIRECTORY) filepath = os.path.join(WORK_DIRECTORY, filename) if not os.path.exists(filepath): - filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath) + filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath) statinfo = os.stat(filepath) print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.') return filepath diff --git a/tensorflow/models/rnn/rnn_cell_test.py b/tensorflow/models/rnn/rnn_cell_test.py index 937e1557bd6..447ddfebd43 100644 --- a/tensorflow/models/rnn/rnn_cell_test.py +++ b/tensorflow/models/rnn/rnn_cell_test.py @@ -118,7 +118,7 @@ class RNNCellTest(tf.test.TestCase): with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): x = tf.zeros([1, 3]) m = tf.zeros([1, 3]) - keep = tf.zeros([1]) + 1 + keep = tf.zeros([]) + 1 g, new_m = rnn_cell.DropoutWrapper(rnn_cell.GRUCell(3), keep, keep)(x, m) sess.run([tf.variables.initialize_all_variables()]) diff --git a/tensorflow/models/rnn/seq2seq.py b/tensorflow/models/rnn/seq2seq.py index 875bcb5e6ee..63e6181a263 100644 --- a/tensorflow/models/rnn/seq2seq.py +++ b/tensorflow/models/rnn/seq2seq.py @@ -636,7 +636,7 @@ def sequence_loss_by_example(logits, targets, weights, num_decoder_symbols, # SparseToDense does not accept batched inputs, we need to do this by # re-indexing and re-sizing. When TensorFlow adds SparseCrossEntropy, # rewrite this method. - indices = targets[i] + num_decoder_symbols * tf.range(0, batch_size) + indices = targets[i] + num_decoder_symbols * tf.range(batch_size) with tf.device("/cpu:0"): # Sparse-to-dense must happen on CPU for now. dense = tf.sparse_to_dense(indices, tf.expand_dims(length, 0), 1.0, 0.0) diff --git a/tensorflow/models/rnn/translate/data_utils.py b/tensorflow/models/rnn/translate/data_utils.py index b9d951ccd7a..00f77599af5 100644 --- a/tensorflow/models/rnn/translate/data_utils.py +++ b/tensorflow/models/rnn/translate/data_utils.py @@ -7,9 +7,9 @@ import gzip import os import re import tarfile -import urllib from tensorflow.python.platform import gfile +from six.moves import urllib # Special vocabulary symbols - we always put them at the start. _PAD = "_PAD" @@ -40,7 +40,7 @@ def maybe_download(directory, filename, url): filepath = os.path.join(directory, filename) if not os.path.exists(filepath): print("Downloading %s to %s" % (url, filepath)) - filepath, _ = urllib.urlretrieve(url, filepath) + filepath, _ = urllib.request.urlretrieve(url, filepath) statinfo = os.stat(filepath) print("Succesfully downloaded", filename, statinfo.st_size, "bytes") return filepath diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 2cbdf191c63..1b7e0eb9a9a 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -13,8 +13,16 @@ import tensorflow as tf """ -import tensorflow.python.platform -from tensorflow.core.framework.graph_pb2 import * +try: + import tensorflow.python.platform + from tensorflow.core.framework.graph_pb2 import * +except ImportError as e: + msg = """Error importing tensorflow: you should not try to import + tensorflow from its source directory; please exit the tensorflow source tree, + and relaunch your python interpreter from there. + Original ImportError: %s""" % str(e) + raise ImportError(msg) + from tensorflow.core.framework.summary_pb2 import * from tensorflow.core.framework.config_pb2 import * from tensorflow.core.util.event_pb2 import * diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 70321e76dcb..2801d588e89 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -1551,6 +1551,8 @@ class Graph(object): # True if the graph is considered "finalized". In that case no # new operations can be added. self._finalized = False + # Functions defined in the graph + self._functions = [] def _check_not_finalized(self): """Check if the graph is finalized. @@ -1655,8 +1657,30 @@ class Graph(object): bytesize += op.node_def.ByteSize() if bytesize >= (1 << 31) or bytesize < 0: raise ValueError("GraphDef cannot be larger than 2GB.") + if self._functions: + for f in self._functions: + bytesize += f.ByteSize() + if bytesize >= (1 << 31) or bytesize < 0: + raise ValueError("GraphDef cannot be larger than 2GB.") + graph.library.function.extend(self._functions) return graph + def _add_function(self, function_def): + """Adds a function to the graph. + + The function is specified as a [`FunctionDef`] + (https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto) + protocol buffer. + + After the function has been added, you can call to the function by + passing the function name in place of an op name to + `Graph.create_op()`. + + Args: + function_def: A `FunctionDef` protocol buffer. + """ + self._functions.append(function_def) + # Helper functions to create operations. def create_op(self, op_type, inputs, dtypes, input_types=None, name=None, attrs=None, op_def=None, @@ -1869,7 +1893,6 @@ class Graph(object): A list of Operations. """ return list(self._nodes_by_id.values()) - def get_operation_by_name(self, name): """Returns the `Operation` with the given `name`. diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py index b4959485543..16efddc6973 100644 --- a/tensorflow/python/kernel_tests/concat_op_test.py +++ b/tensorflow/python/kernel_tests/concat_op_test.py @@ -75,18 +75,28 @@ class ConcatOpTest(tf.test.TestCase): # Random dim to concat on concat_dim = np.random.randint(5) params = {} + if dtype == tf.bfloat16: + dtype_feed = tf.float32 + else: + dtype_feed = dtype with self.test_session(use_gpu=use_gpu): p = [] for i in np.arange(num_tensors): input_shape = shape input_shape[concat_dim] = np.random.randint(1, 5) - placeholder = tf.placeholder(dtype, shape=input_shape) + placeholder = tf.placeholder(dtype_feed, shape=input_shape) p.append(placeholder) - t = dtype.as_numpy_dtype + t = dtype_feed.as_numpy_dtype params[placeholder] = np.random.rand(*input_shape).astype(t) - c = tf.concat(concat_dim, p) + if dtype != dtype_feed: + concat_inputs = [tf.cast(p_i, dtype) for p_i in p] + else: + concat_inputs = p + c = tf.concat(concat_dim, concat_inputs) + if dtype != dtype_feed: + c = tf.cast(c, dtype_feed) result = c.eval(feed_dict=params) self.assertEqual(result.shape, c.get_shape()) @@ -100,15 +110,17 @@ class ConcatOpTest(tf.test.TestCase): ind[concat_dim] = slice(cur_offset, cur_offset + params[p[i]].shape[concat_dim]) cur_offset += params[p[i]].shape[concat_dim] - self.assertAllEqual(result[ind], params[p[i]]) + if dtype == dtype_feed: + self.assertAllEqual(result[ind], params[p[i]]) + else: + self.assertAllClose(result[ind], params[p[i]], 0.01) def testRandom(self): self._testRandom(tf.float32) self._testRandom(tf.int16) self._testRandom(tf.int32, use_gpu=True) - # Note that the following does not work since bfloat16 is not supported in - # numpy. - # self._testRandom(tf.bfloat16) + self._testRandom(tf.bfloat16) + self._testRandom(tf.bfloat16, use_gpu=True) def _testGradientsSimple(self, use_gpu): with self.test_session(use_gpu=use_gpu): diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py index 03844d61771..5a987c912c6 100644 --- a/tensorflow/python/kernel_tests/embedding_ops_test.py +++ b/tensorflow/python/kernel_tests/embedding_ops_test.py @@ -262,7 +262,7 @@ class EmbeddingLookupTest(tf.test.TestCase): self.assertAllEqual(simple, tf.gather(params, ids).eval()) # Run a few random sharded versions for procs in 1, 2, 3: - stride = procs * tf.range(0, params.shape[0] // procs) + stride = procs * tf.range(params.shape[0] // procs) split_params = [tf.gather(params, stride + p) for p in xrange(procs)] sharded = tf.nn.embedding_lookup(split_params, ids).eval() diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py index 1b9f1323e8b..8036989a0e6 100644 --- a/tensorflow/python/kernel_tests/init_ops_test.py +++ b/tensorflow/python/kernel_tests/init_ops_test.py @@ -190,6 +190,10 @@ class RangeTest(tf.test.TestCase): self._Range(100, 500, 100), np.array([100, 200, 300, 400]))) self.assertEqual(tf.range(0, 5, 1).dtype, tf.int32) + def testLimitOnly(self): + with self.test_session(): + self.assertAllEqual(np.arange(5), tf.range(5).eval()) + def testEmpty(self): for start in 0, 5: self.assertTrue(np.array_equal(self._Range(start, start, 1), [])) diff --git a/tensorflow/python/kernel_tests/lookup_table_op_test.py b/tensorflow/python/kernel_tests/lookup_table_op_test.py deleted file mode 100644 index 7b5942cacd4..00000000000 --- a/tensorflow/python/kernel_tests/lookup_table_op_test.py +++ /dev/null @@ -1,218 +0,0 @@ -"""Tests for lookup table ops from tf.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow.python.platform - -import numpy as np -import tensorflow as tf - - -class HashTableOpTest(tf.test.TestCase): - - def testHashTable(self): - with self.test_session(): - shared_name = '' - default_val = -1 - table = tf.HashTable(tf.string, tf.int64, default_val, shared_name) - - # Initialize with keys and values tensors. - keys = tf.constant(['brain', 'salad', 'surgery']) - values = tf.constant([0, 1, 2], tf.int64) - init = table.initialize_from(keys, values) - init.run() - self.assertAllEqual(3, table.size().eval()) - - input_string = tf.constant(['brain', 'salad', 'tank']) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllEqual([0, 1, -1], result) - - def testHashTableFindHighRank(self): - with self.test_session(): - shared_name = '' - default_val = -1 - table = tf.HashTable(tf.string, tf.int64, default_val, shared_name) - - # Initialize with keys and values tensors. - keys = tf.constant(['brain', 'salad', 'surgery']) - values = tf.constant([0, 1, 2], tf.int64) - init = table.initialize_from(keys, values) - init.run() - self.assertAllEqual(3, table.size().eval()) - - input_string = tf.constant([['brain', 'salad'], ['tank', 'tarkus']]) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllEqual([[0, 1], [-1, -1]], result) - - def testHashTableInitWithPythonArrays(self): - with self.test_session(): - shared_name = '' - default_val = -1 - table = tf.HashTable(tf.string, tf.int64, default_val, shared_name) - # Empty table. - self.assertAllEqual(0, table.size().eval()) - - # Initialize with keys and values tensors. - keys = ['brain', 'salad', 'surgery'] - values = [0, 1, 2] - init = table.initialize_from(keys, values) - init.run() - self.assertAllEqual(3, table.size().eval()) - - input_string = tf.constant(['brain', 'salad', 'tank']) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllEqual([0, 1, -1], result) - - def testHashTableInitWithNumPyArrays(self): - with self.test_session(): - shared_name = '' - default_val = -1 - table = tf.HashTable(tf.string, tf.int64, default_val, shared_name) - - # Initialize with keys and values tensors. - keys = np.array(['brain', 'salad', 'surgery'], dtype=np.str) - values = np.array([0, 1, 2], dtype=np.int64) - init = table.initialize_from(keys, values) - init.run() - self.assertAllEqual(3, table.size().eval()) - - input_string = tf.constant(['brain', 'salad', 'tank']) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllEqual([0, 1, -1], result) - - def testMultipleHashTables(self): - with self.test_session() as sess: - shared_name = '' - default_val = -1 - table1 = tf.HashTable(tf.string, tf.int64, default_val, shared_name) - table2 = tf.HashTable(tf.string, tf.int64, default_val, shared_name) - table3 = tf.HashTable(tf.string, tf.int64, default_val, shared_name) - - keys = tf.constant(['brain', 'salad', 'surgery']) - values = tf.constant([0, 1, 2], tf.int64) - table1.initialize_from(keys, values) - table2.initialize_from(keys, values) - table3.initialize_from(keys, values) - - tf.initialize_all_tables().run() - self.assertAllEqual(3, table1.size().eval()) - self.assertAllEqual(3, table2.size().eval()) - self.assertAllEqual(3, table3.size().eval()) - - input_string = tf.constant(['brain', 'salad', 'tank']) - output1 = table1.lookup(input_string) - output2 = table2.lookup(input_string) - output3 = table3.lookup(input_string) - - out1, out2, out3 = sess.run([output1, output2, output3]) - self.assertAllEqual([0, 1, -1], out1) - self.assertAllEqual([0, 1, -1], out2) - self.assertAllEqual([0, 1, -1], out3) - - def testHashTableWithTensorDefault(self): - with self.test_session(): - shared_name = '' - default_val = tf.constant(-1, tf.int64) - table = tf.HashTable(tf.string, tf.int64, default_val, shared_name) - - # Initialize with keys and values tensors. - keys = tf.constant(['brain', 'salad', 'surgery']) - values = tf.constant([0, 1, 2], tf.int64) - init = table.initialize_from(keys, values) - init.run() - - input_string = tf.constant(['brain', 'salad', 'tank']) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllEqual([0, 1, -1], result) - - def testSignatureMismatch(self): - with self.test_session(): - shared_name = '' - default_val = -1 - table = tf.HashTable(tf.string, tf.int64, default_val, shared_name) - - # Initialize with keys and values tensors. - keys = tf.constant(['brain', 'salad', 'surgery']) - values = tf.constant([0, 1, 2], tf.int64) - init = table.initialize_from(keys, values) - init.run() - - input_string = tf.constant([1, 2, 3], tf.int64) - with self.assertRaises(TypeError): - table.lookup(input_string) - - with self.assertRaises(TypeError): - tf.HashTable(tf.string, tf.int64, 'UNK', shared_name) - - def testDTypes(self): - with self.test_session(): - shared_name = '' - default_val = -1 - with self.assertRaises(TypeError): - tf.HashTable([tf.string], tf.string, default_val, shared_name) - - def testNotInitialized(self): - with self.test_session(): - shared_name = '' - default_val = -1 - table = tf.HashTable(tf.string, tf.int64, default_val, shared_name) - - input_string = tf.constant(['brain', 'salad', 'surgery']) - output = table.lookup(input_string) - - with self.assertRaisesOpError('Table not initialized'): - output.eval() - - def testInitializeTwice(self): - with self.test_session(): - shared_name = '' - default_val = -1 - table = tf.HashTable(tf.string, tf.int64, default_val, shared_name) - - # Initialize with keys and values tensors. - keys = tf.constant(['brain', 'salad', 'surgery']) - values = tf.constant([0, 1, 2], tf.int64) - init = table.initialize_from(keys, values) - init.run() - - with self.assertRaisesOpError('Table already initialized'): - init.run() - - def testInitializationWithInvalidDimensions(self): - with self.test_session(): - shared_name = '' - default_val = -1 - table = tf.HashTable(tf.string, tf.int64, default_val, shared_name) - - # Initialize with keys and values tensors. - keys = tf.constant(['brain', 'salad', 'surgery']) - values = tf.constant([0, 1, 2, 3, 4], tf.int64) - with self.assertRaises(ValueError): - table.initialize_from(keys, values) - - def testInitializationWithInvalidDataTypes(self): - with self.test_session(): - shared_name = '' - default_val = -1 - table = tf.HashTable(tf.string, tf.int64, default_val, shared_name) - - # Initialize with keys and values tensors. - keys = [0, 1, 2] - values = ['brain', 'salad', 'surgery'] - with self.assertRaises(TypeError): - table.initialize_from(keys, values) - - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py index 3dedd33cb9a..9893c9b824b 100644 --- a/tensorflow/python/ops/clip_ops.py +++ b/tensorflow/python/ops/clip_ops.py @@ -74,7 +74,7 @@ def clip_by_norm(t, clip_norm, name=None): # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm l2norm_inv = math_ops.rsqrt( - math_ops.reduce_sum(t * t, math_ops.range(0, array_ops.rank(t)))) + math_ops.reduce_sum(t * t, math_ops.range(array_ops.rank(t)))) tclip = array_ops.identity(t * clip_norm * math_ops.minimum( l2norm_inv, constant_op.constant(1.0 / clip_norm)), name=name) @@ -228,7 +228,7 @@ def clip_by_average_norm(t, clip_norm, name=None): # L2-norm per element n_element = math_ops.cast(array_ops.size(t), types.float32) l2norm_inv = math_ops.rsqrt( - math_ops.reduce_sum(t * t, math_ops.range(0, array_ops.rank(t)))) + math_ops.reduce_sum(t * t, math_ops.range(array_ops.rank(t)))) tclip = array_ops.identity( t * clip_norm * math_ops.minimum( l2norm_inv * n_element, constant_op.constant(1.0 / clip_norm)), diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index 178f716e48a..ed09ff3655e 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -10,7 +10,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import common_shapes from tensorflow.python.ops import control_flow_ops @@ -408,174 +407,12 @@ class FIFOQueue(QueueBase): # TODO(josh11b): class BatchQueue(QueueBase): -# pylint: disable=protected-access -class LookupTableBase(object): - """Represents a lookup table that persists across different steps.""" - - def __init__(self, key_dtype, value_dtype, default_value, table_ref): - """Construct a table object from a table reference. - - Args: - key_dtype: The table key type. - value_dtype: The table value type. - default_value: The value to use if a key is missing in the table. - table_ref: The table reference, i.e. the output of the lookup table ops. - """ - self._key_dtype = types.as_dtype(key_dtype) - self._value_dtype = types.as_dtype(value_dtype) - self._shapes = [tensor_shape.TensorShape([1])] - self._table_ref = table_ref - self._name = self._table_ref.op.name.split("/")[-1] - self._default_value = ops.convert_to_tensor(default_value, - dtype=self._value_dtype) - self._default_value.get_shape().merge_with(tensor_shape.scalar()) - - @property - def table_ref(self): - """Get the underlying table reference.""" - return self._table_ref - - @property - def key_dtype(self): - """The table key dtype.""" - return self._key_dtype - - @property - def value_dtype(self): - """The table value dtype.""" - return self._value_dtype - - @property - def name(self): - """The name of the table.""" - return self._name - - @property - def default_value(self): - """The default value of the table.""" - return self._default_value - - def size(self, name=None): - """Compute the number of elements in this table. - - Args: - name: A name for the operation (optional). - - Returns: - A scalar tensor containing the number of elements in this table. - """ - if name is None: - name = "%s_Size" % self._name - return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name) - - def lookup(self, keys, name=None): - """Looks up `keys` in a table, outputs the corresponding values. - - The `default_value` is use for keys not present in the table. - - Args: - keys: Keys to look up. - name: Optional name for the op. - - Returns: - The operation that looks up the keys. - - Raises: - TypeError: when `keys` or `default_value` doesn't match the table data - types. - """ - if name is None: - name = "%s_lookup_table_find" % self._name - - if keys.dtype != self._key_dtype: - raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % ( - self._key_dtype, keys.dtype)) - - return gen_data_flow_ops._lookup_table_find( - self._table_ref, keys, self._default_value, name=name) - - def initialize_from(self, keys, values, name=None): - """Initialize the table with the provided keys and values tensors. - - Construct an initializer object from keys and value tensors. - - Args: - keys: The tensor for the keys. - values: The tensor for the values. - name: Optional name for the op. - - Returns: - The operation that initializes the table. - - Raises: - TypeError: when the keys and values data types do not match the table - key and value data types. - """ - if name is None: - name = "%s_initialize_table" % self.name - with ops.op_scope([keys, values], None, name): - keys = ops.convert_to_tensor(keys, dtype=self.key_dtype, name="keys") - values = ops.convert_to_tensor(values, dtype=self.value_dtype, - name="values") - - init_op = gen_data_flow_ops._initialize_table( - self.table_ref, keys, values, name=name) - ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) - return init_op - - def _check_table_dtypes(self, key_dtype, value_dtype): - """Check that the given key_dtype and value_dtype matches the table dtypes'. - - Args: - key_dtype: The key data type to check. - value_dtype: The value data type to check. - - Raises: - TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data - types. - """ - if key_dtype != self.key_dtype: - raise TypeError("Invalid key dtype, expected %s but got %s." % ( - self.key_dtype, key_dtype)) - if value_dtype != self.value_dtype: - raise TypeError("Invalid value dtype, expected %s but got %s." % ( - self.value_dtype, value_dtype)) - - -class HashTable(LookupTableBase): - """A generic hash table implementation.""" - - def __init__(self, key_dtype, value_dtype, default_value, shared_name=None, - name="hash_table"): - """Creates a non-initialized hash table. - - This op creates a hash table, specifying the type of its keys and values. - Before using the table you will have to initialize it. After initialization - the table will be immutable. - - Args: - key_dtype: Type of the table keys. - value_dtype: Type of the table values. - default_value: The scalar tensor to be used when a key is missing in the - table. - shared_name: Optional. If non-empty, this table will be shared under - the given name across multiple sessions. - name: Optional name for the hash table op. - - Returns: - A `HashTable` object. - """ - table_ref = gen_data_flow_ops._hash_table( - shared_name=shared_name, key_dtype=key_dtype, - value_dtype=value_dtype, name=name) - - super(HashTable, self).__init__(key_dtype, value_dtype, default_value, - table_ref) - - def initialize_all_tables(name="init_all_tables"): """Returns an Op that initializes all tables of the default graph. + Args: + name: Optional name for the initialization op. + Returns: An Op that initializes all tables. Note that if there are not tables the returned Op is a NoOp. diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index 80bedd49846..b74f8f5426a 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -51,7 +51,7 @@ def embedding_lookup(params, ids, name=None): else: ids = ops.convert_to_tensor(ids, name="ids") flat_ids = array_ops.reshape(ids, [-1]) - original_indices = math_ops.range(0, array_ops.size(flat_ids)) + original_indices = math_ops.range(array_ops.size(flat_ids)) # Compute flat_ids % partitions for each id ids_mod_p = flat_ids % np if ids_mod_p.dtype != types.int32: diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index b404fbc7d70..b8b0741e07a 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -21,7 +21,7 @@ def _ReductionGradAssist(op): indices = op.inputs[1] # [1, 2] indices_shape = array_ops.shape(indices) # [2] new_output_shape = data_flow_ops.dynamic_stitch( # [2, 1, 1, 7] - [math_ops.range(0, input_rank), # [0, 1, 2, 3] + [math_ops.range(input_rank), # [0, 1, 2, 3] indices], # [1, 2] [input_shape, # [2, 3, 5, 7] array_ops.fill(indices_shape, 1)]) # [1, 1] diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index f7289ff2340..4a2473cae50 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -536,11 +536,14 @@ ops.Tensor._override_operator("__gt__", greater) ops.Tensor._override_operator("__ge__", greater_equal) -def range(start, limit, delta=1, name="range"): +def range(start, limit=None, delta=1, name="range"): """Creates a sequence of integers. - This operation creates a sequence of integers that begins at `start` and - extends by increments of `delta` up to but not including `limit`. + Creates a sequence of integers that begins at `start` and extends by + increments of `delta` up to but not including `limit`. + + Like the Python builtin `range`, `start` defaults to 0, so that + `range(n) = range(0, n)`. For example: @@ -549,10 +552,14 @@ def range(start, limit, delta=1, name="range"): # 'limit' is 18 # 'delta' is 3 tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] + + # 'limit' is 5 + tf.range(limit) ==> [0, 1, 2, 3, 4] ``` Args: start: A 0-D (scalar) of type `int32`. First entry in sequence. + Defaults to 0. limit: A 0-D (scalar) of type `int32`. Upper limit of sequence, exclusive. delta: A 0-D `Tensor` (scalar) of type `int32`. Optional. Default is 1. @@ -562,6 +569,8 @@ def range(start, limit, delta=1, name="range"): Returns: An 1-D `int32` `Tensor`. """ + if limit is None: + start, limit = 0, start return gen_math_ops._range(start, limit, delta, name=name) diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py index caf47b1431c..5a5c06f975e 100644 --- a/tensorflow/python/ops/nn.py +++ b/tensorflow/python/ops/nn.py @@ -173,6 +173,7 @@ from __future__ import print_function from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import candidate_sampling_ops @@ -347,7 +348,8 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): Args: x: A tensor. - keep_prob: A Python float. The probability that each element is kept. + keep_prob: A scalar `Tensor` with the same type as x. The probability + that each element is kept. noise_shape: A 1-D `Tensor` of type `int32`, representing the shape for randomly generated keep/drop flags. seed: A Python integer. Used to create random seeds. See @@ -361,10 +363,15 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): Raises: ValueError: If `keep_prob` is not in `(0, 1]`. """ - if not (0 < keep_prob <= 1): - raise ValueError("Expected keep_prob in (0, 1], got %g" % keep_prob) with ops.op_scope([x], name, "dropout") as name: x = ops.convert_to_tensor(x, name="x") + if isinstance(keep_prob, float) and not(0 < keep_prob <= 1): + raise ValueError("keep_prob must be a scalar tensor or a float in the " + "range (0, 1], got %g" % keep_prob) + keep_prob = ops.convert_to_tensor( + keep_prob, dtype=x.dtype, name="keep_prob") + keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar()) + noise_shape = noise_shape or array_ops.shape(x) # uniform [keep_prob, 1.0 + keep_prob) random_tensor = keep_prob @@ -372,7 +379,9 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): noise_shape, seed=seed, dtype=x.dtype) # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob) binary_tensor = math_ops.floor(random_tensor) - return x * (1.0 / keep_prob) * binary_tensor + ret = x * math_ops.inv(keep_prob) * binary_tensor + ret.set_shape(x.get_shape()) + return ret def depthwise_conv2d(input, filter, strides, padding, name=None): diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index 535d55f00f0..919ce784918 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -85,7 +85,7 @@ def _BiasAddGrad(unused_bias_op, received_grad): Two tensors, the first one for the "tensor" input of the BiasOp, the second one for the "bias" input of the BiasOp. """ - reduction_dim_tensor = math_ops.range(0, array_ops.rank(received_grad) - 1) + reduction_dim_tensor = math_ops.range(array_ops.rank(received_grad) - 1) return (received_grad, math_ops.reduce_sum(received_grad, reduction_dim_tensor)) diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 48f7a4c9876..c24bd0e3726 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -405,28 +405,82 @@ class DropoutTest(test_util.TensorFlowTestCase): sorted_value = np.unique(np.sort(value[i, :])) self.assertEqual(sorted_value.size, 1) + def testDropoutPlaceholderKeepProb(self): + # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate + # that it is producing approximately the right number of ones over a large + # number of samples, based on the keep probability. + x_dim = 40 + y_dim = 30 + num_iter = 10 + for keep_prob in [0.1, 0.5, 0.8]: + with self.test_session(): + t = constant_op.constant(1.0, + shape=[x_dim, y_dim], + dtype=types.float32) + keep_prob_placeholder = array_ops.placeholder(types.float32) + dropout = nn.dropout(t, keep_prob_placeholder) + final_count = 0 + self.assertEqual([x_dim, y_dim], dropout.get_shape()) + for _ in xrange(0, num_iter): + value = dropout.eval(feed_dict={keep_prob_placeholder: keep_prob}) + final_count += np.count_nonzero(value) + # Verifies that there are only two values: 0 and 1/keep_prob. + sorted_value = np.unique(np.sort(value)) + self.assertEqual(0, sorted_value[0]) + self.assertAllClose(1 / keep_prob, sorted_value[1]) + # Check that we are in the 15% error range + expected_count = x_dim * y_dim * keep_prob * num_iter + rel_error = math.fabs(final_count - expected_count) / expected_count + print(rel_error) + self.assertTrue(rel_error < 0.15) + + def testShapedDropoutUnknownShape(self): + x_dim = 40 + y_dim = 30 + keep_prob = 0.5 + x = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=types.float32) + dropout_x = nn.dropout( + x, keep_prob, noise_shape=array_ops.placeholder(types.int32)) + self.assertEqual(x.get_shape(), dropout_x.get_shape()) + + def testInvalidKeepProb(self): + x_dim = 40 + y_dim = 30 + t = constant_op.constant(1.0, + shape=[x_dim, y_dim], + dtype=types.float32) + with self.assertRaises(ValueError): + nn.dropout(t, -1.0) + with self.assertRaises(ValueError): + nn.dropout(t, 1.1) + with self.assertRaises(ValueError): + nn.dropout(t, [0.0, 1.0]) + with self.assertRaises(ValueError): + nn.dropout(t, array_ops.placeholder(types.float64)) + with self.assertRaises(ValueError): + nn.dropout(t, array_ops.placeholder(types.float32, shape=[2])) + def testShapedDropoutShapeError(self): # Runs shaped dropout and verifies an error is thrown on misshapen noise. x_dim = 40 y_dim = 30 keep_prob = 0.5 - with self.test_session(): - t = constant_op.constant(1.0, - shape=[x_dim, y_dim], - dtype=types.float32) - with self.assertRaises(ValueError): - _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10]) - with self.assertRaises(ValueError): - _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5]) - with self.assertRaises(ValueError): - _ = nn.dropout(t, keep_prob, noise_shape=[x_dim + 3]) - with self.assertRaises(ValueError): - _ = nn.dropout(t, keep_prob, noise_shape=[x_dim]) - # test that broadcasting proceeds - _ = nn.dropout(t, keep_prob, noise_shape=[y_dim]) - _ = nn.dropout(t, keep_prob, noise_shape=[1, y_dim]) - _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1]) - _ = nn.dropout(t, keep_prob, noise_shape=[1, 1]) + t = constant_op.constant(1.0, + shape=[x_dim, y_dim], + dtype=types.float32) + with self.assertRaises(ValueError): + _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10]) + with self.assertRaises(ValueError): + _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5]) + with self.assertRaises(ValueError): + _ = nn.dropout(t, keep_prob, noise_shape=[x_dim + 3]) + with self.assertRaises(ValueError): + _ = nn.dropout(t, keep_prob, noise_shape=[x_dim]) + # test that broadcasting proceeds + _ = nn.dropout(t, keep_prob, noise_shape=[y_dim]) + _ = nn.dropout(t, keep_prob, noise_shape=[1, y_dim]) + _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1]) + _ = nn.dropout(t, keep_prob, noise_shape=[1, 1]) class BatchNormWithGlobalNormalizationTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 1a7af78a33a..23d3a974cfb 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -437,8 +437,7 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None): default_value, dtype=sp_input.values.dtype) num_rows = math_ops.cast(sp_input.shape[0], types.int32) - all_row_indices = math_ops.cast( - math_ops.range(0, num_rows, 1), types.int64) + all_row_indices = math_ops.cast(math_ops.range(num_rows), types.int64) empty_row_indices, _ = array_ops.list_diff( all_row_indices, sp_input.indices[:, 0]) empty_row_indicator = gen_sparse_ops.sparse_to_dense( diff --git a/tensorflow/python/platform/default/_logging.py b/tensorflow/python/platform/default/_logging.py index 66bf2c08890..23318691ee8 100644 --- a/tensorflow/python/platform/default/_logging.py +++ b/tensorflow/python/platform/default/_logging.py @@ -6,18 +6,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import logging import os import sys import time import thread -from logging import getLogger -from logging import log -from logging import debug -from logging import error -from logging import fatal -from logging import info -from logging import warn -from logging import warning from logging import DEBUG from logging import ERROR from logging import FATAL @@ -25,13 +18,25 @@ from logging import INFO from logging import WARN # Controls which methods from pyglib.logging are available within the project -# Do not add methods here without also adding to platform/default/_logging.py +# Do not add methods here without also adding to platform/google/_logging.py __all__ = ['log', 'debug', 'error', 'fatal', 'info', 'warn', 'warning', 'DEBUG', 'ERROR', 'FATAL', 'INFO', 'WARN', 'flush', 'log_every_n', 'log_first_n', 'vlog', 'TaskLevelStatusMessage', 'get_verbosity', 'set_verbosity'] -warning = warn +# Scope the tensorflow logger to not conflict with users' loggers +_logger = logging.getLogger('tensorflow') +_handler = logging.StreamHandler() +_handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT, None)) +_logger.addHandler(_handler) + +log = _logger.log +debug = _logger.debug +error = _logger.error +fatal = _logger.fatal +info = _logger.info +warn = _logger.warn +warning = _logger.warn _level_names = { FATAL: 'FATAL', @@ -61,7 +66,7 @@ def flush(): # Code below is taken from pyglib/logging def vlog(level, msg, *args, **kwargs): - log(level, msg, *args, **kwargs) + _logger.log(level, msg, *args, **kwargs) def _GetNextLogCountPerToken(token): @@ -169,12 +174,12 @@ def google2_log_prefix(level, timestamp=None, file_and_line=None): def get_verbosity(): """Return how much logging output will be produced.""" - return getLogger().getEffectiveLevel() + return _logger.getEffectiveLevel() def set_verbosity(verbosity): """Sets the threshold for what messages will be logged.""" - getLogger().setLevel(verbosity) + _logger.setLevel(verbosity) def _get_thread_id(): diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index 6734690397c..77c496fb85c 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -163,7 +163,7 @@ def range_input_producer(limit, num_epochs=None, shuffle=True, seed=None, is added to the current Graph's QUEUE_RUNNER collection. """ with ops.op_scope([limit], name, "input_producer") as name: - range_tensor = math_ops.range(0, limit) + range_tensor = math_ops.range(limit) return _input_producer( range_tensor, types.int32, num_epochs, shuffle, seed, capacity, name, "fraction_of_%d_full" % capacity) diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py index 8450fae5cb6..c1f886d6644 100644 --- a/tensorflow/python/training/learning_rate_decay.py +++ b/tensorflow/python/training/learning_rate_decay.py @@ -33,9 +33,9 @@ def exponential_decay(learning_rate, global_step, decay_steps, decay_rate, ... global_step = tf.Variable(0, trainable=False) starter_learning_rate = 0.1 - learning_rate = tf.exponential_decay(starter_learning_rate, global_step, - 100000, 0.96, staircase=True) - optimizer = tf.GradientDescent(learning_rate) + learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step, + 100000, 0.96, staircase=True) + optimizer = tf.GradientDescentOptimizer(learning_rate) # Passing global_step to minimize() will increment it at each step. optimizer.minimize(...my loss..., global_step=global_step) ``` diff --git a/tensorflow/tensorboard/bower.json b/tensorflow/tensorboard/bower.json index bdd16d662aa..995ba30363f 100644 --- a/tensorflow/tensorboard/bower.json +++ b/tensorflow/tensorboard/bower.json @@ -19,26 +19,27 @@ "dagre": "~0.7.4", "es6-promise": "~3.0.2", "graphlib": "~1.0.7", - "iron-ajax": "PolymerElements/iron-ajax#~1.0.8", - "iron-collapse": "PolymerElements/iron-collapse#~1.0.4", - "iron-list": "PolymerElements/iron-list#~1.1.5", - "paper-button": "PolymerElements/paper-button#~1.0.7", - "paper-checkbox": "PolymerElements/paper-checkbox#~1.0.6", - "paper-dropdown-menu": "PolymerElements/paper-dropdown-menu#~1.0.4", - "paper-header-panel": "PolymerElements/paper-header-panel#~1.0.5", - "paper-icon-button": "PolymerElements/paper-icon-button#~1.0.3", - "paper-input": "PolymerElements/paper-input#~1.0.15", - "paper-item": "PolymerElements/paper-item#~1.0.3", - "paper-menu": "PolymerElements/paper-menu#~1.1.1", - "paper-progress": "PolymerElements/paper-progress#~1.0.7", - "paper-radio-button": "PolymerElements/paper-radio-button#~1.0.8", - "paper-radio-group": "PolymerElements/paper-radio-group#~1.0.4", - "paper-slider": "PolymerElements/paper-slider#~1.0.4", - "paper-styles": "PolymerElements/paper-styles#~1.0.11", - "paper-toggle-button": "PolymerElements/paper-toggle-button#~1.0.6", - "paper-toolbar": "PolymerElements/paper-toolbar#~1.0.4", + "iron-ajax": "PolymerElements/iron-ajax#1.0.7", + "iron-collapse": "PolymerElements/iron-collapse#1.0.4", + "iron-list": "PolymerElements/iron-list#1.1.5", + "iron-selector": "PolymerElements/iron-selector#1.0.7", + "paper-button": "PolymerElements/paper-button#1.0.8", + "paper-checkbox": "PolymerElements/paper-checkbox#1.0.13", + "paper-dropdown-menu": "PolymerElements/paper-dropdown-menu#1.0.5", + "paper-header-panel": "PolymerElements/paper-header-panel#1.0.5", + "paper-icon-button": "PolymerElements/paper-icon-button#1.0.5", + "paper-input": "PolymerElements/paper-input#1.0.16", + "paper-item": "PolymerElements/paper-item#1.0.5", + "paper-menu": "PolymerElements/paper-menu#1.1.1", + "paper-progress": "PolymerElements/paper-progress#1.0.7", + "paper-radio-button": "PolymerElements/paper-radio-button#1.0.10", + "paper-radio-group": "PolymerElements/paper-radio-group#1.0.6", + "paper-slider": "PolymerElements/paper-slider#1.0.7", + "paper-styles": "PolymerElements/paper-styles#1.0.12", + "paper-toggle-button": "PolymerElements/paper-toggle-button#1.0.11", + "paper-toolbar": "PolymerElements/paper-toolbar#1.0.4", "plottable": "~1.16.1", - "polymer": "~1.2.0" + "polymer": "1.1.5" }, "devDependencies": { "iron-component-page": "PolymerElements/iron-component-page#^1.0.0", diff --git a/tensorflow/tensorboard/tensorboard_handler.py b/tensorflow/tensorboard/tensorboard_handler.py index 0ea7f3a58d5..2cec3b8812b 100644 --- a/tensorflow/tensorboard/tensorboard_handler.py +++ b/tensorflow/tensorboard/tensorboard_handler.py @@ -17,9 +17,9 @@ import json import mimetypes import os import StringIO -import urllib import urlparse +from six.moves import urllib from six.moves import xrange # pylint: disable=redefined-builtin from google.protobuf import text_format import tensorflow.python.platform @@ -289,7 +289,7 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler): A string representation of a URL that will load the index-th sampled image in the given run with the given tag. """ - query_string = urllib.urlencode({ + query_string = urllib.parse.urlencode({ 'run': run, 'tag': tag, 'index': index