TensorFlow: Upstream changes from the afternoon.

Changes:
- Ptrdiff -> DenseIndex change by @jiayq
- Fix to scoping the logging in logging.py by @dga
- Improvement to Conv2DBackpropFilter on CPU by Andy
- Remove lookup table wrappers for the time being (they weren't in our public API yet) by Yukata
- Add a check, similar to numpy's, to make sure the user isn't running from inside the tensorflow src directory, by @vrv
- More changes for Python 3 compatibility by @girving
- Make dropout preserve shape info from its input (@mrry)
- Significant speed improvements by @zheng-xq to the BFC allocator, bringing it on par (CPU-overhead-wise) with the region allocator. Make the BFC allocator the default now that it's working well for a variety of models.
- Fix a bunch of typos reported by users (@vrv)
- Enable concat for bfloat16 on GPU by Ashish.

Base CL: 107733123
commit d50565b35e (parent 4dffee7f62)
Changed paths:
  tensorflow/core/{common_runtime, kernels, ops}
  tensorflow/g3doc/{api_docs/python, get_started, how_tos/summaries_and_tensorboard, tutorials}
  tensorflow/models
  tensorflow/python/__init__.py
  tensorflow/python/framework
  tensorflow/python/kernel_tests
  tensorflow/python/ops/{clip_ops.py, data_flow_ops.py, embedding_ops.py, math_grad.py, math_ops.py, nn.py, nn_grad.py, nn_test.py, sparse_ops.py}
  tensorflow/python/platform/default
  tensorflow/python/training
  tensorflow/tensorboard
@@ -495,6 +495,11 @@ static void SimplifyGraph(Graph* g) {
 void OptimizeGraph(FunctionLibraryRuntime* lib, Graph** g) {
   DumpGraph("Initial", *g);
+
+  // Run SimplifyGraph at least once to rewrite away ops such as
+  // _ListToArray, _ArrayToList, etc.
+  SimplifyGraph(*g);
+
   const int kNumInlineRounds = 10;
   for (int i = 0; i < kNumInlineRounds; ++i) {
     if (!ExpandInlineFunctions(lib, *g)) break;
@@ -65,7 +65,7 @@ GPUBFCAllocator::GPUBFCAllocator(int device_id, size_t total_memory)
     ptr_to_chunk_map_.insert(std::make_pair(c->ptr, c));

     // Insert the chunk into the right bin.
-    ReassignChunkToBin(c);
+    InsertFreeChunkIntoBin(c);
   }

 GPUBFCAllocator::~GPUBFCAllocator() {

@@ -76,6 +76,7 @@ GPUBFCAllocator::~GPUBFCAllocator() {
   }

   gtl::STLDeleteValues(&bins_);
+  gtl::STLDeleteValues(&ptr_to_chunk_map_);
 }

 void* GPUBFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes) {

@@ -115,10 +116,12 @@ void* GPUBFCAllocator::AllocateRawInternal(size_t unused_alignment,
     // Start searching from the first bin for the smallest chunk that fits
     // rounded_bytes.
     Bin* b = it->second;
-    for (GPUBFCAllocator::Chunk* chunk : b->chunks) {
-      if (!chunk->in_use && chunk->size > rounded_bytes) {
-        // We found an existing chunk that fits us that wasn't in use.
-        chunk->in_use = true;
+    for (GPUBFCAllocator::Chunk* chunk : b->free_chunks) {
+      DCHECK(!chunk->in_use);
+      if (chunk->size >= rounded_bytes) {
+        // We found an existing chunk that fits us that wasn't in use, so remove
+        // it from the free bin structure prior to using.
+        RemoveFreeChunkFromBin(chunk);

         // If we can break the size of the chunk into two reasonably
         // large pieces, do so.

@@ -132,6 +135,7 @@ void* GPUBFCAllocator::AllocateRawInternal(size_t unused_alignment,
         // The requested size of the returned chunk is what the user
         // has allocated.
         chunk->requested_size = num_bytes;
+        chunk->in_use = true;

         VLOG(4) << "Returning: " << chunk->ptr;
         return chunk->ptr;

@@ -152,6 +156,8 @@ void* GPUBFCAllocator::AllocateRawInternal(size_t unused_alignment,
 }

 void GPUBFCAllocator::SplitChunk(GPUBFCAllocator::Chunk* c, size_t num_bytes) {
+  CHECK(!c->in_use && !c->bin);
+
   // Create a new chunk starting num_bytes after c
   GPUBFCAllocator::Chunk* new_chunk = new GPUBFCAllocator::Chunk();
   new_chunk->ptr = static_cast<void*>(static_cast<char*>(c->ptr) + num_bytes);

@@ -176,9 +182,8 @@ void GPUBFCAllocator::SplitChunk(GPUBFCAllocator::Chunk* c, size_t num_bytes) {
     c_neighbor->prev = new_chunk;
   }

-  // Maintain the bins
-  ReassignChunkToBin(new_chunk);
-  ReassignChunkToBin(c);
+  // Add the newly free chunk to the free bin.
+  InsertFreeChunkIntoBin(new_chunk);
 }

 void GPUBFCAllocator::DeallocateRaw(void* ptr) {

@@ -200,11 +205,9 @@ void GPUBFCAllocator::DeallocateRawInternal(void* ptr) {

   GPUBFCAllocator::Chunk* c = it->second;
   VLOG(6) << "Chunk at " << c->ptr << " no longer in use";
-  // Mark the chunk as no longer in use
-  c->in_use = false;
-
-  // Consider coalescing it.
-  MaybeCoalesce(c);
+  FreeAndMaybeCoalesce(c);
 }

 // Merges c1 and c2 when c1->next is c2 and c2->prev is c1.

@@ -212,7 +215,7 @@ void GPUBFCAllocator::DeallocateRawInternal(void* ptr) {
 void GPUBFCAllocator::Merge(GPUBFCAllocator::Chunk* c1,
                             GPUBFCAllocator::Chunk* c2) {
   // We can only merge chunks that are not in use.
-  DCHECK(!c1->in_use && !c2->in_use);
+  CHECK(!c1->in_use && !c2->in_use);

   // c1's prev doesn't change, still points to the same ptr, and is
   // still not in use.

@@ -231,62 +234,42 @@ void GPUBFCAllocator::Merge(GPUBFCAllocator::Chunk* c1,
   // Set the new size
   c1->size += c2->size;

-  // Delete c2 and cleanup all state
-  RemoveChunkFromBin(c2);
+  DeleteChunk(c2);
 }

-void GPUBFCAllocator::ReassignChunkToBin(GPUBFCAllocator::Chunk* c) {
+void GPUBFCAllocator::DeleteChunk(Chunk* c) {
+  // Delete c2 and cleanup all state
+  VLOG(4) << "Removing: " << c->ptr;
+  ptr_to_chunk_map_.erase(c->ptr);
+  delete c;
+}
+
+void GPUBFCAllocator::InsertFreeChunkIntoBin(GPUBFCAllocator::Chunk* c) {
+  CHECK(!c->in_use && !c->bin);
   auto it = bins_.lower_bound(c->size);
   CHECK(it != bins_.end()) << " Tried to reassign to non-existent bin for size "
                            << c->size;

   Bin* new_bin = it->second;
-
-  // If the bin has not changed, do nothing.
-  Bin* old_bin = c->bin;
-  if (old_bin != nullptr && new_bin == old_bin) {
-    return;
-  }
-
-  // The bin has changed.  Add the chunk to the new bin and remove
-  // the chunk from the old bin.
-  new_bin->chunks.insert(c);
   c->bin = new_bin;
-
-  if (old_bin == nullptr) {
-    return;
-  }
-
-  // Remove chunk from old bin
-  for (auto it = old_bin->chunks.begin(); it != old_bin->chunks.end(); ++it) {
-    if (*it == c) {
-      old_bin->chunks.erase(it);
-      return;
-    }
-  }
-  CHECK(false) << "Could not find chunk in old bin";
+  new_bin->free_chunks.insert(c);
 }

-void GPUBFCAllocator::RemoveChunkFromBin(GPUBFCAllocator::Chunk* c) {
-  Bin* b = c->bin;
-  for (auto it = b->chunks.begin(); it != b->chunks.end(); ++it) {
-    Chunk* other_c = *it;
-    if (other_c->ptr == c->ptr) {
-      b->chunks.erase(it);
-      VLOG(4) << "Removing: " << c->ptr;
-      ptr_to_chunk_map_.erase(c->ptr);
-      delete c;
-      return;
-    }
-  }
-
-  CHECK(false) << "Could not find chunk in bin";
+void GPUBFCAllocator::RemoveFreeChunkFromBin(GPUBFCAllocator::Chunk* c) {
+  CHECK(!c->in_use && c->bin);
+  int count = c->bin->free_chunks.erase(c);
+  CHECK(count > 0) << "Could not find chunk in bin";
+  c->bin = nullptr;
 }

-void GPUBFCAllocator::MaybeCoalesce(GPUBFCAllocator::Chunk* c) {
+void GPUBFCAllocator::FreeAndMaybeCoalesce(GPUBFCAllocator::Chunk* c) {
+  CHECK(c->in_use && !c->bin);
+
+  // Mark the chunk as no longer in use
+  c->in_use = false;
+
   // This chunk is no longer in-use, consider coalescing the chunk
   // with adjacent chunks.
-  Chunk* chunk_to_reassign = nullptr;
+  Chunk* chunk_to_reassign = c;

   // If the next chunk is free, coalesce the two, if the result would
   // fit in an existing bin.

@@ -296,6 +279,7 @@ void GPUBFCAllocator::MaybeCoalesce(GPUBFCAllocator::Chunk* c) {
     chunk_to_reassign = c;

     // Deletes c->next
+    RemoveFreeChunkFromBin(c->next);
     Merge(c, c->next);
   }

@@ -307,13 +291,11 @@ void GPUBFCAllocator::MaybeCoalesce(GPUBFCAllocator::Chunk* c) {
     chunk_to_reassign = c->prev;

     // Deletes c
+    RemoveFreeChunkFromBin(c->prev);
     Merge(c->prev, c);
   }

   // Reassign the final merged chunk into the right bin.
-  if (chunk_to_reassign) {
-    ReassignChunkToBin(chunk_to_reassign);
-  }
+  InsertFreeChunkIntoBin(chunk_to_reassign);
 }

 void GPUBFCAllocator::AddAllocVisitor(Visitor visitor) {

@@ -354,7 +336,7 @@ void GPUBFCAllocator::DumpMemoryLog(size_t num_bytes) {
     size_t total_requested_bytes_in_bin = 0;
     size_t total_chunks_in_use = 0;
     size_t total_chunks_in_bin = 0;
-    for (Chunk* c : b->chunks) {
+    for (Chunk* c : b->free_chunks) {
       total_bytes_in_bin += c->size;
       total_requested_bytes_in_bin += c->requested_size;
       ++total_chunks_in_bin;

@@ -388,7 +370,7 @@ void GPUBFCAllocator::DumpMemoryLog(size_t num_bytes) {
               << " was " << strings::HumanReadableNumBytes(b->bin_size)
               << ", Chunk State: ";

-    for (Chunk* c : b->chunks) {
+    for (Chunk* c : b->free_chunks) {
       LOG(INFO) << c->DebugString(true);
     }
   }
@@ -102,28 +102,33 @@ class GPUBFCAllocator : public VisitableAllocator {
   Chunk* AllocateNewChunk(size_t num_bytes);
   void SplitChunk(Chunk* c, size_t num_bytes);
   void Merge(Chunk* c1, Chunk* c2);
-  void MaybeCoalesce(Chunk* c);

-  void ReassignChunkToBin(Chunk* c);
-  void RemoveChunkFromBin(Chunk* c);
+  void FreeAndMaybeCoalesce(Chunk* c);
+  void InsertFreeChunkIntoBin(Chunk* c);
+  void RemoveFreeChunkFromBin(Chunk* c);
+  void DeleteChunk(Chunk* c);

   void DumpMemoryLog(size_t num_bytes);

-  // A Bin is a collection of similar-sized Chunks.
+  // A Bin is a collection of similar-sized free chunks.
   struct Bin {
     // All chunks in this bin have >= bin_size memory.
     size_t bin_size = 0;

     struct ChunkComparator {
-      bool operator()(Chunk* a, Chunk* b) { return a->size < b->size; }
+      // Sort first by size and then use pointer address as a tie breaker.
+      bool operator()(const Chunk* a, const Chunk* b) const {
+        if (a->size != b->size) {
+          return a->size < b->size;
+        }
+        return a->ptr < b->ptr;
+      }
     };

-    // List of chunks within the bin, sorted by chunk size.
-    std::multiset<Chunk*, ChunkComparator> chunks;
+    // List of free chunks within the bin, sorted by chunk size.
+    // Chunk * not owned.
+    std::set<Chunk*, ChunkComparator> free_chunks;

     explicit Bin(size_t bs) : bin_size(bs) {}
-
-    ~Bin() { gtl::STLDeleteElements(&chunks); }
   };

   GPUAllocatorRetry retry_helper_;

@@ -142,7 +147,7 @@ class GPUBFCAllocator : public VisitableAllocator {

   // Structures mutable after construction
   mutable mutex lock_;
-  // Not owned.
+  // Chunk * owned.
   std::unordered_map<void*, Chunk*> ptr_to_chunk_map_;

   // Called once on each region, ASAP.
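The reworked `Bin` keeps only free chunks, ordered by size with the pointer address as a tie breaker so that equal-sized chunks stay distinct in the set. As a rough illustration of the best-fit lookup this enables, here is a minimal pure-Python sketch (illustrative only, not TensorFlow code; the real allocator uses a `std::set` per bin):

```python
# Free chunks kept sorted by (size, address): a bin can find the smallest
# free chunk that fits without scanning chunks that are in use.
import bisect

class Bin(object):
    def __init__(self, bin_size):
        self.bin_size = bin_size   # all chunks here have size >= bin_size
        self.free_chunks = []      # sorted list of (size, address) pairs

    def insert_free_chunk(self, size, addr):
        bisect.insort(self.free_chunks, (size, addr))

    def best_fit(self, rounded_bytes):
        # Smallest free chunk with size >= rounded_bytes; the address
        # component only breaks ties between equal-sized chunks.
        i = bisect.bisect_left(self.free_chunks, (rounded_bytes, 0))
        if i == len(self.free_chunks):
            return None
        return self.free_chunks.pop(i)  # remove from the free structure

b = Bin(256)
b.insert_free_chunk(512, 0x2000)
b.insert_free_chunk(256, 0x1000)
print(b.best_fit(300))  # -> (512, 8192)
```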
@@ -20,7 +20,7 @@ DEFINE_bool(record_mem_types, false,
 DEFINE_bool(brain_mem_reg_cuda_dma, true,
             "If true, register CPU RAM used to copy to/from GPU RAM "
             "with the CUDA driver.");
-DEFINE_bool(brain_gpu_use_bfc_allocator, false,
+DEFINE_bool(brain_gpu_use_bfc_allocator, true,
             "If true, uses the Best-Fit GPU allocator.");
 DEFINE_bool(brain_gpu_region_allocator_debug, false,
             "If true, checks for memory overwrites by writing "

@@ -34,7 +34,7 @@ bool FLAGS_record_mem_types = false;
 bool FLAGS_brain_mem_reg_cuda_dma = true;
 bool FLAGS_brain_gpu_region_allocator_debug = false;
 bool FLAGS_brain_gpu_region_allocator_reset_to_nan = false;
-bool FLAGS_brain_gpu_use_bfc_allocator = false;
+bool FLAGS_brain_gpu_use_bfc_allocator = true;
 #endif

 namespace gpu = ::perftools::gputools;
@@ -135,6 +135,7 @@ REGISTER_CONCAT(bfloat16);
                           ConcatOp<GPUDevice, type>)

 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+REGISTER_GPU(bfloat16);
 #undef REGISTER_GPU

 // A special GPU kernel for int32.
@@ -6,6 +6,7 @@

 #include <memory>

+#include "tensorflow/core/framework/bfloat16.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor_types.h"

@@ -34,6 +35,7 @@ void ConcatGPU(const GPUDevice& d,
                typename TTypes<T, 2>::Matrix* output);

 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+REGISTER_GPU(bfloat16);
 #undef REGISTER_GPU

 }  // end namespace tensorflow
@@ -541,12 +541,36 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
     // The output image size is the spatial size of the output.
     const int output_image_size = out_rows * out_cols;

+    // Shard 'batch' images into 'shard_size' groups of images to be fed
+    // into the parallel matmul.  Calculate 'shard_size' by dividing the L3 cache
+    // size ('target_working_set_size') by the matmul size of an individual
+    // image ('work_unit_size').
+
+    // TODO(andydavis)
+    // *) Get L3 cache size from device at runtime (30MB is from ivybridge).
+    // *) Consider reducing 'target_working_set_size' if L3 is shared by
+    //    other concurrently running tensorflow ops.
+    const size_t target_working_set_size = (30LL << 20) / sizeof(T);
+
+    const size_t size_A = output_image_size * filter_total_size;
+
+    const size_t size_B = output_image_size * out_depth;
+
+    const size_t size_C = filter_total_size * out_depth;
+
+    const size_t work_unit_size = size_A + size_B + size_C;
+
+    const size_t shard_size =
+        (target_working_set_size + work_unit_size - 1) / work_unit_size;
+
     Tensor col_buffer;
-    OP_REQUIRES_OK(
-        context,
-        context->allocate_temp(
-            DataTypeToEnum<T>::value,
-            TensorShape({output_image_size, filter_total_size}), &col_buffer));
+    OP_REQUIRES_OK(context,
+                   context->allocate_temp(
+                       DataTypeToEnum<T>::value,
+                       TensorShape({static_cast<int64>(shard_size),
+                                    static_cast<int64>(output_image_size),
+                                    static_cast<int64>(filter_total_size)}),
+                       &col_buffer));

     // The input offset corresponding to a single input image.
     const int input_offset = input_rows * input_cols * in_depth;

@@ -571,21 +595,29 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
     contract_dims[0].first = 0;
     contract_dims[0].second = 0;

-    for (int image_id = 0; image_id < batch; ++image_id) {
-      // When we compute the gradient with respect to the filters, we need to do
-      // im2col to allow gemm-type computation.
-      Im2col<T>(input_data, in_depth, input_rows, input_cols, filter_rows,
-                filter_cols, pad_top, pad_left, pad_bottom, pad_right, stride,
-                stride, col_buffer_data);
+    for (int image_id = 0; image_id < batch; image_id += shard_size) {
+      const int shard_limit = std::min(static_cast<int>(shard_size),
+                                       static_cast<int>(batch) - image_id);
+      for (int shard_id = 0; shard_id < shard_limit; ++shard_id) {
+        // TODO(andydavis) Parallelize this loop.
+        // When we compute the gradient with respect to the filters, we need
+        // to do im2col to allow gemm-type computation.
+        Im2col<T>(input_data, in_depth, input_rows, input_cols, filter_rows,
+                  filter_cols, pad_top, pad_left, pad_bottom, pad_right, stride,
+                  stride, col_buffer_data + shard_id * size_A);

-      ConstTensorMap A(col_buffer_data, output_image_size, filter_total_size);
-      ConstTensorMap B(out_backprop_data + output_offset * image_id,
-                       output_image_size, out_depth);
+        input_data += input_offset;
+      }
+
+      ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
+                       filter_total_size);
+      ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
+                       out_depth);

       // Gradient with respect to filter.
       C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);

-      input_data += input_offset;
+      out_backprop_data += output_offset * shard_limit;
     }
   }
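The sharding logic above sizes each group of images so that one matmul's working set (`size_A + size_B + size_C` elements) fits in an assumed 30MB L3 cache. A quick sketch of that arithmetic with made-up layer dimensions:

```python
# Sketch of the shard-size formula above, with made-up dimensions
# (float32, so sizeof(T) == 4). Not TensorFlow code, just the arithmetic.
output_image_size = 28 * 28        # out_rows * out_cols
filter_total_size = 5 * 5 * 32     # filter_rows * filter_cols * in_depth
out_depth = 64

target_working_set_size = (30 << 20) // 4   # 30MB L3 cache, in elements

size_A = output_image_size * filter_total_size
size_B = output_image_size * out_depth
size_C = filter_total_size * out_depth
work_unit_size = size_A + size_B + size_C

# Ceiling division: how many images' matmuls fit in the L3-sized budget.
shard_size = (target_working_set_size + work_unit_size - 1) // work_unit_size
print(shard_size)   # 11 for these numbers
```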
@@ -46,8 +46,9 @@ template <typename Device, typename T>
 struct TileGrad<Device, T, 0> {
   void operator()(const Device& d, typename TTypes<T, 0>::Tensor out,
                   typename TTypes<T, 0>::ConstTensor in,
-                  const Eigen::DSizes<ptrdiff_t, 0>&,
-                  const Eigen::DSizes<ptrdiff_t, 0>&, bool first) const {
+                  const Eigen::DSizes<Eigen::DenseIndex, 0>&,
+                  const Eigen::DSizes<Eigen::DenseIndex, 0>&,
+                  bool first) const {
     if (first) {
       out.device(d) = in;
     } else {
@@ -407,7 +407,7 @@ reshape(t, [3, 3]) ==> [[1, 2, 3]

 # tensor 't' is [[[1, 1], [2, 2]]
 #                [[3, 3], [4, 4]]]
-# tensor 't' has shape [2, 2]
+# tensor 't' has shape [2, 2, 2]
 reshape(t, [2, 4]) ==> [[1, 1, 2, 2]
                         [3, 3, 4, 4]]

@@ -529,7 +529,7 @@ tf.shape(split0) ==> [5, 10]

 *  <b>`split_dim`</b>: A 0-D `int32` `Tensor`. The dimension along which to split.
    Must be in the range `[0, rank(value))`.
-*  <b>`num_split`</b>: A 0-D `int32` `Tensor`. The number of ways to split.
+*  <b>`num_split`</b>: A Python integer. The number of ways to split.
 *  <b>`value`</b>: The `Tensor` to split.
 *  <b>`name`</b>: A name for the operation (optional).
@@ -17,7 +17,7 @@ Note: Functions taking `Tensor` arguments can also take anything accepted by
 *  [`tf.constant(value, dtype=None, shape=None, name='Const')`](#constant)
 *  [Sequences](#AUTOGENERATED-sequences)
   *  [`tf.linspace(start, stop, num, name=None)`](#linspace)
-  *  [`tf.range(start, limit, delta=1, name='range')`](#range)
+  *  [`tf.range(start, limit=None, delta=1, name='range')`](#range)
 *  [Random Tensors](#AUTOGENERATED-random-tensors)
   *  [Examples:](#AUTOGENERATED-examples-)
   *  [`tf.random_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)`](#random_normal)

@@ -273,12 +273,15 @@ tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0]

 - - -

-### `tf.range(start, limit, delta=1, name='range')` <a class="md-anchor" id="range"></a>
+### `tf.range(start, limit=None, delta=1, name='range')` <a class="md-anchor" id="range"></a>

 Creates a sequence of integers.

-This operation creates a sequence of integers that begins at `start` and
-extends by increments of `delta` up to but not including `limit`.
+Creates a sequence of integers that begins at `start` and extends by
+increments of `delta` up to but not including `limit`.
+
+Like the Python builtin `range`, `start` defaults to 0, so that
+`range(n) = range(0, n)`.

 For example:

@@ -287,12 +290,16 @@ For example:
 # 'limit' is 18
 # 'delta' is 3
 tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]
+
+# 'limit' is 5
+tf.range(limit) ==> [0, 1, 2, 3, 4]
 ```

 ##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>


 *  <b>`start`</b>: A 0-D (scalar) of type `int32`. First entry in sequence.
+   Defaults to 0.
 *  <b>`limit`</b>: A 0-D (scalar) of type `int32`. Upper limit of sequence,
    exclusive.
 *  <b>`delta`</b>: A 0-D `Tensor` (scalar) of type `int32`. Optional. Default is 1.
@@ -143,7 +143,7 @@ This must be called by the constructors of subclasses.

 - - -

-#### `tf.train.Optimizer.minimize(loss, global_step=None, var_list=None, gate_gradients=1, name=None)` <a class="md-anchor" id="Optimizer.minimize"></a>
+#### `tf.train.Optimizer.minimize(loss, global_step=None, var_list=None, gate_gradients=1, aggregation_method=None, name=None)` <a class="md-anchor" id="Optimizer.minimize"></a>

 Add operations to minimize 'loss' by updating 'var_list'.

@@ -163,6 +163,8 @@ this function.
     under the key GraphKeys.TRAINABLE_VARIABLES.
 *  <b>`gate_gradients`</b>: How to gate the computation of gradients.  Can be
    GATE_NONE, GATE_OP, or GATE_GRAPH.
+*  <b>`aggregation_method`</b>: Specifies the method used to combine gradient terms.
+   Valid values are defined in the class `AggregationMethod`.
 *  <b>`name`</b>: Optional name for the returned operation.

 ##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>

@@ -178,7 +180,7 @@ this function.

 - - -

-#### `tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1)` <a class="md-anchor" id="Optimizer.compute_gradients"></a>
+#### `tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None)` <a class="md-anchor" id="Optimizer.compute_gradients"></a>

 Compute gradients of "loss" for the variables in "var_list".

@@ -197,6 +199,8 @@ given variable.
     under the key GraphKey.TRAINABLE_VARIABLES.
 *  <b>`gate_gradients`</b>: How to gate the computation of gradients.  Can be
    GATE_NONE, GATE_OP, or GATE_GRAPH.
+*  <b>`aggregation_method`</b>: Specifies the method used to combine gradient terms.
+   Valid values are defined in the class `AggregationMethod`.

 ##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
@@ -268,6 +268,34 @@ $ bazel-bin/tensorflow/cc/tutorials_example_trainer --use_gpu

 Note that "--config=cuda" is needed to enable the GPU support.

+##### Enabling Cuda 3.0. <a class="md-anchor" id="AUTOGENERATED-enabling-cuda-3.0."></a>
+TensorFlow officially supports Cuda devices with 3.5 and 5.2 compute
+capabilities. In order to enable earlier Cuda devices such as Grid K520, you
+need to target Cuda 3.0. This can be done through TensorFlow unofficial
+settings with "configure".
+
+```bash
+$ TF_UNOFFICIAL_SETTING=1 ./configure
+
+# Same as the official settings above
+
+WARNING: You are configuring unofficial settings in TensorFlow. Because some
+external libraries are not backward compatible, these settings are largely
+untested and unsupported.
+
+Please specify a list of comma-separated Cuda compute capabilities you want to
+build with. You can find the compute capability of your device at:
+https://developer.nvidia.com/cuda-gpus.
+Please note that each additional compute capability significantly increases
+your build time and binary size. [Default is: "3.5,5.2"]: 3.0
+
+Setting up Cuda include
+Setting up Cuda lib64
+Setting up Cuda bin
+Setting up Cuda nvvm
+Configuration finished
+```
+
 ##### Known issues <a class="md-anchor" id="AUTOGENERATED-known-issues"></a>

 *  Although it is possible to build both Cuda and non-Cuda configs under the same

@@ -360,14 +388,20 @@ Make sure you followed the the GPU installation [instructions](#install_cuda).

 #### Can't find setup.py <a class="md-anchor" id="AUTOGENERATED-can-t-find-setup.py"></a>

-If, during pip install, you encounter an error like:
+If, during `pip install`, you encounter an error like:

 ```bash
 ...
 IOError: [Errno 2] No such file or directory: '/tmp/pip-o6Tpui-build/setup.py'
 ```

-Solution: upgrade your version of pip.
+Solution: upgrade your version of `pip`:
+
+```bash
+pip install --upgrade pip
+```
+
+This may require `sudo`, depending on how `pip` is installed.

 #### SSLError: SSL_VERIFY_FAILED <a class="md-anchor" id="AUTOGENERATED-sslerror--ssl_verify_failed"></a>
@@ -66,7 +66,7 @@ every hundred steps or so, as in the following code example.

 ```python
 merged_summary_op = tf.merge_all_summaries()
-summary_writer = tf.train.SummaryWriter('/tmp/mnist_logs', sess.graph)
+summary_writer = tf.train.SummaryWriter('/tmp/mnist_logs', sess.graph_def)
 total_step = 0
 while training:
   total_step += 1
@@ -5,9 +5,9 @@ from __future__ import print_function

 import gzip
 import os
-import urllib

 import numpy
+from six.moves import urllib
+from six.moves import xrange  # pylint: disable=redefined-builtin

 SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'

@@ -19,7 +19,7 @@ def maybe_download(filename, work_directory):
     os.mkdir(work_directory)
   filepath = os.path.join(work_directory, filename)
   if not os.path.exists(filepath):
-    filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath)
+    filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
   statinfo = os.stat(filepath)
   print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
   return filepath
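The `six.moves` import used here (and in the other scripts below) is the standard way to get a `urlretrieve` that works under both Python 2 and Python 3. A minimal sketch of the pattern in isolation:

```python
# six.moves.urllib resolves to urllib on Python 2 and urllib.request on
# Python 3, so a single call site works under both interpreters.
from six.moves import urllib

filepath, _ = urllib.request.urlretrieve(
    'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
    '/tmp/train-images-idx3-ubyte.gz')
```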
@@ -91,7 +91,7 @@ def loss(logits, labels):
   # be a 1.0 in the entry corresponding to the label).
   batch_size = tf.size(labels)
   labels = tf.expand_dims(labels, 1)
-  indices = tf.expand_dims(tf.range(0, batch_size, 1), 1)
+  indices = tf.expand_dims(tf.range(batch_size), 1)
   concated = tf.concat(1, [indices, labels])
   onehot_labels = tf.sparse_to_dense(
       concated, tf.pack([batch_size, NUM_CLASSES]), 1.0, 0.0)
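Here `tf.range(batch_size)` supplies the row indices that get paired with the labels to scatter a dense one-hot matrix. A NumPy sketch of the same construction (illustrative only; the real code uses `tf.concat` and `tf.sparse_to_dense`):

```python
# Pair each row index with its label and scatter 1.0 into a dense
# [batch_size, NUM_CLASSES] matrix.
import numpy as np

NUM_CLASSES = 10
labels = np.array([3, 0, 7])             # one label per example
batch_size = labels.size

indices = np.arange(batch_size)          # what tf.range(batch_size) produces
onehot = np.zeros((batch_size, NUM_CLASSES))
onehot[indices, labels] = 1.0            # the sparse_to_dense analogue
print(onehot[0])  # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
```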
@@ -39,7 +39,7 @@ Tensorflow relies on a highly efficient C++ backend to do its computation. The
 connection to this backend is called a session. The common usage for TensorFlow
 programs is to first create a graph and then launch it in a session.

-Here we instead use the convenience `InteractiveSession` class, which
+Here we instead use the convenient `InteractiveSession` class, which
 makes TensorFlow more flexible about how you
 structure your code.
 It allows you to interleave operations which build a

@@ -232,7 +232,7 @@ print accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 ## Build a Multilayer Convolutional Network <a class="md-anchor" id="AUTOGENERATED-build-a-multilayer-convolutional-network"></a>

 Getting 91% accuracy on MNIST is bad. It's almost embarrassingly bad. In this
-section, we'll fix that, jumping from a very simple model to something moderatly
+section, we'll fix that, jumping from a very simple model to something moderately
 sophisticated: a small convolutional neural network. This will get us to around
 99.2% accuracy -- not state of the art, but respectable.
@@ -9,9 +9,9 @@ import math
 import numpy as np
 import os
 import random
+from six.moves import urllib
+from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
-import urllib
 import zipfile

 # Step 1: Download the data.

@@ -20,7 +20,7 @@ url = 'http://mattmahoney.net/dc/'
 def maybe_download(filename, expected_bytes):
   """Download a file if not present, and make sure it's the right size."""
   if not os.path.exists(filename):
-    filename, _ = urllib.urlretrieve(url + filename, filename)
+    filename, _ = urllib.request.urlretrieve(url + filename, filename)
   statinfo = os.stat(filename)
   if statinfo.st_size == expected_bytes:
     print('Found and verified', filename)
@@ -25,9 +25,9 @@ import os
 import re
 import sys
 import tarfile
-import urllib

 import tensorflow.python.platform
+from six.moves import urllib
+from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf

@@ -366,7 +366,7 @@ def loss(logits, labels):
   # Reshape the labels into a dense Tensor of
   # shape [batch_size, NUM_CLASSES].
   sparse_labels = tf.reshape(labels, [FLAGS.batch_size, 1])
-  indices = tf.reshape(tf.range(0, FLAGS.batch_size, 1), [FLAGS.batch_size, 1])
+  indices = tf.reshape(tf.range(FLAGS.batch_size), [FLAGS.batch_size, 1])
   concated = tf.concat(1, [indices, sparse_labels])
   dense_labels = tf.sparse_to_dense(concated,
                                     [FLAGS.batch_size, NUM_CLASSES],

@@ -478,7 +478,8 @@ def maybe_download_and_extract():
       sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
           float(count * block_size) / float(total_size) * 100.0))
       sys.stdout.flush()
-    filepath, _ = urllib.urlretrieve(DATA_URL, filepath, reporthook=_progress)
+    filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath,
+                                             reporthook=_progress)
     print()
     statinfo = os.stat(filepath)
     print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
@@ -11,11 +11,11 @@ from __future__ import print_function
 import gzip
 import os
 import sys
-import urllib

 import tensorflow.python.platform

 import numpy
+from six.moves import urllib
+from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf

@@ -41,7 +41,7 @@ def maybe_download(filename):
     os.mkdir(WORK_DIRECTORY)
   filepath = os.path.join(WORK_DIRECTORY, filename)
   if not os.path.exists(filepath):
-    filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath)
+    filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
   statinfo = os.stat(filepath)
   print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
   return filepath
@@ -118,7 +118,7 @@ class RNNCellTest(tf.test.TestCase):
     with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
       x = tf.zeros([1, 3])
       m = tf.zeros([1, 3])
-      keep = tf.zeros([1]) + 1
+      keep = tf.zeros([]) + 1
       g, new_m = rnn_cell.DropoutWrapper(rnn_cell.GRUCell(3),
                                          keep, keep)(x, m)
       sess.run([tf.variables.initialize_all_variables()])
@@ -636,7 +636,7 @@ def sequence_loss_by_example(logits, targets, weights, num_decoder_symbols,
       # SparseToDense does not accept batched inputs, we need to do this by
       # re-indexing and re-sizing. When TensorFlow adds SparseCrossEntropy,
       # rewrite this method.
-      indices = targets[i] + num_decoder_symbols * tf.range(0, batch_size)
+      indices = targets[i] + num_decoder_symbols * tf.range(batch_size)
       with tf.device("/cpu:0"):  # Sparse-to-dense must happen on CPU for now.
         dense = tf.sparse_to_dense(indices, tf.expand_dims(length, 0), 1.0,
                                    0.0)
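The re-indexing above maps entry `(b, targets[b])` of a `[batch_size, num_decoder_symbols]` tensor to a position in its flattened form. A NumPy sketch of the trick:

```python
# To address entry (b, targets[b]) of a flattened
# [batch_size, num_decoder_symbols] tensor, offset each target by its
# row's starting position.
import numpy as np

num_decoder_symbols = 5
targets = np.array([2, 0, 4])     # one symbol id per batch entry
batch_size = targets.size

indices = targets + num_decoder_symbols * np.arange(batch_size)
print(indices)   # [ 2  5 14] -> positions in the flattened tensor
```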
@@ -7,9 +7,9 @@ import gzip
 import os
 import re
 import tarfile
-import urllib

 from tensorflow.python.platform import gfile
+from six.moves import urllib

 # Special vocabulary symbols - we always put them at the start.
 _PAD = "_PAD"

@@ -40,7 +40,7 @@ def maybe_download(directory, filename, url):
   filepath = os.path.join(directory, filename)
   if not os.path.exists(filepath):
     print("Downloading %s to %s" % (url, filepath))
-    filepath, _ = urllib.urlretrieve(url, filepath)
+    filepath, _ = urllib.request.urlretrieve(url, filepath)
   statinfo = os.stat(filepath)
   print("Succesfully downloaded", filename, statinfo.st_size, "bytes")
   return filepath
@@ -13,8 +13,16 @@ import tensorflow as tf

 """

-import tensorflow.python.platform
-from tensorflow.core.framework.graph_pb2 import *
+try:
+  import tensorflow.python.platform
+  from tensorflow.core.framework.graph_pb2 import *
+except ImportError as e:
+  msg = """Error importing tensorflow: you should not try to import
+tensorflow from its source directory; please exit the tensorflow source tree,
+and relaunch your python interpreter from there.
+Original ImportError: %s""" % str(e)
+  raise ImportError(msg)
+
 from tensorflow.core.framework.summary_pb2 import *
 from tensorflow.core.framework.config_pb2 import *
 from tensorflow.core.util.event_pb2 import *
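The guarded import gives users who launch Python from inside the source tree an actionable message instead of a bare `ImportError`. The same pattern in isolation (the module name here is just illustrative):

```python
# If a package's compiled pieces can't be found -- typically because the
# interpreter was started inside the source tree -- fail with a message
# that tells the user what to do about it.
try:
    import tensorflow.python.platform  # noqa: F401
except ImportError as e:
    msg = ("Error importing tensorflow: you should not try to import "
           "tensorflow from its source directory; please exit the "
           "tensorflow source tree and relaunch your python interpreter "
           "from there.\nOriginal ImportError: %s" % str(e))
    raise ImportError(msg)
```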
@@ -1551,6 +1551,8 @@ class Graph(object):
     # True if the graph is considered "finalized".  In that case no
     # new operations can be added.
     self._finalized = False
+    # Functions defined in the graph
+    self._functions = []

   def _check_not_finalized(self):
     """Check if the graph is finalized.

@@ -1655,8 +1657,30 @@ class Graph(object):
       bytesize += op.node_def.ByteSize()
       if bytesize >= (1 << 31) or bytesize < 0:
         raise ValueError("GraphDef cannot be larger than 2GB.")
+    if self._functions:
+      for f in self._functions:
+        bytesize += f.ByteSize()
+        if bytesize >= (1 << 31) or bytesize < 0:
+          raise ValueError("GraphDef cannot be larger than 2GB.")
+      graph.library.function.extend(self._functions)
     return graph

+  def _add_function(self, function_def):
+    """Adds a function to the graph.
+
+    The function is specified as a [`FunctionDef`]
+    (https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
+    protocol buffer.
+
+    After the function has been added, you can call to the function by
+    passing the function name in place of an op name to
+    `Graph.create_op()`.
+
+    Args:
+      function_def: A `FunctionDef` protocol buffer.
+    """
+    self._functions.append(function_def)
+
   # Helper functions to create operations.
   def create_op(self, op_type, inputs, dtypes,
                 input_types=None, name=None, attrs=None, op_def=None,

@@ -1869,7 +1893,6 @@ class Graph(object):
       A list of Operations.
     """
     return list(self._nodes_by_id.values())

   def get_operation_by_name(self, name):
     """Returns the `Operation` with the given `name`.
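The new code extends the existing 2GB `GraphDef` guard to cover function definitions. A standalone sketch of the check, with `node_defs` and `functions` standing in for the graph's real collections:

```python
# Protobuf messages are capped at 2GB, so the graph builder accumulates
# ByteSize() over nodes and function definitions and refuses to build an
# oversized GraphDef. Any protobuf message object provides ByteSize().
def check_graphdef_size(node_defs, functions):
    bytesize = 0
    for node in node_defs:
        bytesize += node.ByteSize()
        if bytesize >= (1 << 31) or bytesize < 0:
            raise ValueError("GraphDef cannot be larger than 2GB.")
    for f in functions:
        bytesize += f.ByteSize()
        if bytesize >= (1 << 31) or bytesize < 0:
            raise ValueError("GraphDef cannot be larger than 2GB.")
    return bytesize
```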
@@ -75,18 +75,28 @@ class ConcatOpTest(tf.test.TestCase):
     # Random dim to concat on
     concat_dim = np.random.randint(5)
     params = {}
+    if dtype == tf.bfloat16:
+      dtype_feed = tf.float32
+    else:
+      dtype_feed = dtype
     with self.test_session(use_gpu=use_gpu):
       p = []
       for i in np.arange(num_tensors):
         input_shape = shape
         input_shape[concat_dim] = np.random.randint(1, 5)
-        placeholder = tf.placeholder(dtype, shape=input_shape)
+        placeholder = tf.placeholder(dtype_feed, shape=input_shape)
         p.append(placeholder)

-        t = dtype.as_numpy_dtype
+        t = dtype_feed.as_numpy_dtype
         params[placeholder] = np.random.rand(*input_shape).astype(t)

-      c = tf.concat(concat_dim, p)
+      if dtype != dtype_feed:
+        concat_inputs = [tf.cast(p_i, dtype) for p_i in p]
+      else:
+        concat_inputs = p
+      c = tf.concat(concat_dim, concat_inputs)
+      if dtype != dtype_feed:
+        c = tf.cast(c, dtype_feed)
       result = c.eval(feed_dict=params)

       self.assertEqual(result.shape, c.get_shape())

@@ -100,15 +110,17 @@ class ConcatOpTest(tf.test.TestCase):
         ind[concat_dim] = slice(cur_offset,
                                 cur_offset + params[p[i]].shape[concat_dim])
         cur_offset += params[p[i]].shape[concat_dim]
-        self.assertAllEqual(result[ind], params[p[i]])
+        if dtype == dtype_feed:
+          self.assertAllEqual(result[ind], params[p[i]])
+        else:
+          self.assertAllClose(result[ind], params[p[i]], 0.01)

   def testRandom(self):
     self._testRandom(tf.float32)
     self._testRandom(tf.int16)
     self._testRandom(tf.int32, use_gpu=True)
-    # Note that the following does not work since bfloat16 is not supported in
-    # numpy.
-    # self._testRandom(tf.bfloat16)
+    self._testRandom(tf.bfloat16)
+    self._testRandom(tf.bfloat16, use_gpu=True)

   def _testGradientsSimple(self, use_gpu):
     with self.test_session(use_gpu=use_gpu):
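The test feeds `float32` and casts because NumPy has no `bfloat16` dtype; `bfloat16` is effectively `float32` with the low 16 mantissa bits dropped. A NumPy sketch of the rounding that the round trip introduces, which is why the test switches to `assertAllClose` with a 0.01 tolerance:

```python
# Emulate a float32 -> bfloat16 -> float32 round trip by truncating each
# float32 to its top 16 bits (sign, exponent, 7 mantissa bits).
import numpy as np

def float32_to_bfloat16_bits(x):
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & 0xFFFF0000).view(np.float32)

x = np.random.rand(4).astype(np.float32)
err = np.max(np.abs(x - float32_to_bfloat16_bits(x)))
print(err)   # small -- within the 0.01 tolerance used above
```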
@@ -262,7 +262,7 @@ class EmbeddingLookupTest(tf.test.TestCase):
       self.assertAllEqual(simple, tf.gather(params, ids).eval())
       # Run a few random sharded versions
       for procs in 1, 2, 3:
-        stride = procs * tf.range(0, params.shape[0] // procs)
+        stride = procs * tf.range(params.shape[0] // procs)
         split_params = [tf.gather(params, stride + p)
                         for p in xrange(procs)]
         sharded = tf.nn.embedding_lookup(split_params, ids).eval()
@@ -190,6 +190,10 @@ class RangeTest(tf.test.TestCase):
         self._Range(100, 500, 100), np.array([100, 200, 300, 400])))
     self.assertEqual(tf.range(0, 5, 1).dtype, tf.int32)

+  def testLimitOnly(self):
+    with self.test_session():
+      self.assertAllEqual(np.arange(5), tf.range(5).eval())
+
   def testEmpty(self):
     for start in 0, 5:
       self.assertTrue(np.array_equal(self._Range(start, start, 1), []))
@@ -1,218 +0,0 @@
-"""Tests for lookup table ops from tf."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tensorflow.python.platform
-
-import numpy as np
-import tensorflow as tf
-
-
-class HashTableOpTest(tf.test.TestCase):
-
-  def testHashTable(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = -1
-      table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
-      # Initialize with keys and values tensors.
-      keys = tf.constant(['brain', 'salad', 'surgery'])
-      values = tf.constant([0, 1, 2], tf.int64)
-      init = table.initialize_from(keys, values)
-      init.run()
-      self.assertAllEqual(3, table.size().eval())
-
-      input_string = tf.constant(['brain', 'salad', 'tank'])
-      output = table.lookup(input_string)
-
-      result = output.eval()
-      self.assertAllEqual([0, 1, -1], result)
-
-  def testHashTableFindHighRank(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = -1
-      table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
-      # Initialize with keys and values tensors.
-      keys = tf.constant(['brain', 'salad', 'surgery'])
-      values = tf.constant([0, 1, 2], tf.int64)
-      init = table.initialize_from(keys, values)
-      init.run()
-      self.assertAllEqual(3, table.size().eval())
-
-      input_string = tf.constant([['brain', 'salad'], ['tank', 'tarkus']])
-      output = table.lookup(input_string)
-
-      result = output.eval()
-      self.assertAllEqual([[0, 1], [-1, -1]], result)
-
-  def testHashTableInitWithPythonArrays(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = -1
-      table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-      # Empty table.
-      self.assertAllEqual(0, table.size().eval())
-
-      # Initialize with keys and values tensors.
-      keys = ['brain', 'salad', 'surgery']
-      values = [0, 1, 2]
-      init = table.initialize_from(keys, values)
-      init.run()
-      self.assertAllEqual(3, table.size().eval())
-
-      input_string = tf.constant(['brain', 'salad', 'tank'])
-      output = table.lookup(input_string)
-
-      result = output.eval()
-      self.assertAllEqual([0, 1, -1], result)
-
-  def testHashTableInitWithNumPyArrays(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = -1
-      table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
-      # Initialize with keys and values tensors.
-      keys = np.array(['brain', 'salad', 'surgery'], dtype=np.str)
-      values = np.array([0, 1, 2], dtype=np.int64)
-      init = table.initialize_from(keys, values)
-      init.run()
-      self.assertAllEqual(3, table.size().eval())
-
-      input_string = tf.constant(['brain', 'salad', 'tank'])
-      output = table.lookup(input_string)
-
-      result = output.eval()
-      self.assertAllEqual([0, 1, -1], result)
-
-  def testMultipleHashTables(self):
-    with self.test_session() as sess:
-      shared_name = ''
-      default_val = -1
-      table1 = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-      table2 = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-      table3 = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
-      keys = tf.constant(['brain', 'salad', 'surgery'])
-      values = tf.constant([0, 1, 2], tf.int64)
-      table1.initialize_from(keys, values)
-      table2.initialize_from(keys, values)
-      table3.initialize_from(keys, values)
-
-      tf.initialize_all_tables().run()
-      self.assertAllEqual(3, table1.size().eval())
-      self.assertAllEqual(3, table2.size().eval())
-      self.assertAllEqual(3, table3.size().eval())
-
-      input_string = tf.constant(['brain', 'salad', 'tank'])
-      output1 = table1.lookup(input_string)
-      output2 = table2.lookup(input_string)
-      output3 = table3.lookup(input_string)
-
-      out1, out2, out3 = sess.run([output1, output2, output3])
-      self.assertAllEqual([0, 1, -1], out1)
-      self.assertAllEqual([0, 1, -1], out2)
-      self.assertAllEqual([0, 1, -1], out3)
-
-  def testHashTableWithTensorDefault(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = tf.constant(-1, tf.int64)
-      table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
-      # Initialize with keys and values tensors.
-      keys = tf.constant(['brain', 'salad', 'surgery'])
-      values = tf.constant([0, 1, 2], tf.int64)
-      init = table.initialize_from(keys, values)
-      init.run()
-
-      input_string = tf.constant(['brain', 'salad', 'tank'])
-      output = table.lookup(input_string)
-
-      result = output.eval()
-      self.assertAllEqual([0, 1, -1], result)
-
-  def testSignatureMismatch(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = -1
-      table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
-      # Initialize with keys and values tensors.
-      keys = tf.constant(['brain', 'salad', 'surgery'])
-      values = tf.constant([0, 1, 2], tf.int64)
-      init = table.initialize_from(keys, values)
-      init.run()
-
-      input_string = tf.constant([1, 2, 3], tf.int64)
-      with self.assertRaises(TypeError):
-        table.lookup(input_string)
-
-      with self.assertRaises(TypeError):
-        tf.HashTable(tf.string, tf.int64, 'UNK', shared_name)
-
-  def testDTypes(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = -1
-      with self.assertRaises(TypeError):
-        tf.HashTable([tf.string], tf.string, default_val, shared_name)
-
-  def testNotInitialized(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = -1
-      table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
-      input_string = tf.constant(['brain', 'salad', 'surgery'])
-      output = table.lookup(input_string)
-
-      with self.assertRaisesOpError('Table not initialized'):
-        output.eval()
-
-  def testInitializeTwice(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = -1
-      table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
-      # Initialize with keys and values tensors.
-      keys = tf.constant(['brain', 'salad', 'surgery'])
-      values = tf.constant([0, 1, 2], tf.int64)
-      init = table.initialize_from(keys, values)
-      init.run()
-
-      with self.assertRaisesOpError('Table already initialized'):
-        init.run()
-
-  def testInitializationWithInvalidDimensions(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = -1
-      table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
-      # Initialize with keys and values tensors.
-      keys = tf.constant(['brain', 'salad', 'surgery'])
-      values = tf.constant([0, 1, 2, 3, 4], tf.int64)
-      with self.assertRaises(ValueError):
-        table.initialize_from(keys, values)
-
-  def testInitializationWithInvalidDataTypes(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = -1
-      table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
-      # Initialize with keys and values tensors.
-      keys = [0, 1, 2]
-      values = ['brain', 'salad', 'surgery']
-      with self.assertRaises(TypeError):
-        table.initialize_from(keys, values)
-
-
-if __name__ == '__main__':
-  tf.test.main()
@@ -74,7 +74,7 @@ def clip_by_norm(t, clip_norm, name=None):

     # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
     l2norm_inv = math_ops.rsqrt(
-        math_ops.reduce_sum(t * t, math_ops.range(0, array_ops.rank(t))))
+        math_ops.reduce_sum(t * t, math_ops.range(array_ops.rank(t))))
     tclip = array_ops.identity(t * clip_norm * math_ops.minimum(
         l2norm_inv, constant_op.constant(1.0 / clip_norm)), name=name)

@@ -228,7 +228,7 @@ def clip_by_average_norm(t, clip_norm, name=None):
     # L2-norm per element
     n_element = math_ops.cast(array_ops.size(t), types.float32)
     l2norm_inv = math_ops.rsqrt(
-        math_ops.reduce_sum(t * t, math_ops.range(0, array_ops.rank(t))))
+        math_ops.reduce_sum(t * t, math_ops.range(array_ops.rank(t))))
     tclip = array_ops.identity(
         t * clip_norm * math_ops.minimum(
             l2norm_inv * n_element, constant_op.constant(1.0 / clip_norm)),
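Both clipping functions compute `t * clip_norm * min(1/||t||, 1/clip_norm)`, which leaves `t` alone when `||t|| <= clip_norm` and rescales it to norm `clip_norm` otherwise. A NumPy sketch of the arithmetic:

```python
# The clip_by_norm math in isolation.
import numpy as np

def clip_by_norm(t, clip_norm):
    l2norm_inv = 1.0 / np.sqrt(np.sum(t * t))
    return t * clip_norm * min(l2norm_inv, 1.0 / clip_norm)

t = np.array([3.0, 4.0])        # ||t|| == 5
print(clip_by_norm(t, 2.0))     # [1.2 1.6], rescaled to norm 2
print(clip_by_norm(t, 10.0))    # unchanged: [3. 4.]
```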
@@ -10,7 +10,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
 from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import types
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import common_shapes
 from tensorflow.python.ops import control_flow_ops

@@ -408,174 +407,12 @@ class FIFOQueue(QueueBase):
 # TODO(josh11b): class BatchQueue(QueueBase):


-# pylint: disable=protected-access
-class LookupTableBase(object):
-  """Represents a lookup table that persists across different steps."""
-
-  def __init__(self, key_dtype, value_dtype, default_value, table_ref):
-    """Construct a table object from a table reference.
-
-    Args:
-      key_dtype: The table key type.
-      value_dtype: The table value type.
-      default_value: The value to use if a key is missing in the table.
-      table_ref: The table reference, i.e. the output of the lookup table ops.
-    """
-    self._key_dtype = types.as_dtype(key_dtype)
-    self._value_dtype = types.as_dtype(value_dtype)
-    self._shapes = [tensor_shape.TensorShape([1])]
-    self._table_ref = table_ref
-    self._name = self._table_ref.op.name.split("/")[-1]
-    self._default_value = ops.convert_to_tensor(default_value,
-                                                dtype=self._value_dtype)
-    self._default_value.get_shape().merge_with(tensor_shape.scalar())
-
-  @property
-  def table_ref(self):
-    """Get the underlying table reference."""
-    return self._table_ref
-
-  @property
-  def key_dtype(self):
-    """The table key dtype."""
-    return self._key_dtype
-
-  @property
-  def value_dtype(self):
-    """The table value dtype."""
-    return self._value_dtype
-
-  @property
-  def name(self):
-    """The name of the table."""
-    return self._name
-
-  @property
-  def default_value(self):
-    """The default value of the table."""
-    return self._default_value
-
-  def size(self, name=None):
-    """Compute the number of elements in this table.
-
-    Args:
-      name: A name for the operation (optional).
-
-    Returns:
-      A scalar tensor containing the number of elements in this table.
-    """
-    if name is None:
-      name = "%s_Size" % self._name
-    return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name)
-
-  def lookup(self, keys, name=None):
-    """Looks up `keys` in a table, outputs the corresponding values.
-
-    The `default_value` is use for keys not present in the table.
-
-    Args:
-      keys: Keys to look up.
-      name: Optional name for the op.
-
-    Returns:
-      The operation that looks up the keys.
-
-    Raises:
-      TypeError: when `keys` or `default_value` doesn't match the table data
-        types.
-    """
-    if name is None:
-      name = "%s_lookup_table_find" % self._name
-
-    if keys.dtype != self._key_dtype:
-      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (
-          self._key_dtype, keys.dtype))
-
-    return gen_data_flow_ops._lookup_table_find(
-        self._table_ref, keys, self._default_value, name=name)
-
-  def initialize_from(self, keys, values, name=None):
-    """Initialize the table with the provided keys and values tensors.
-
-    Construct an initializer object from keys and value tensors.
-
-    Args:
-      keys: The tensor for the keys.
-      values: The tensor for the values.
-      name: Optional name for the op.
-
-    Returns:
-      The operation that initializes the table.
-
-    Raises:
-      TypeError: when the keys and values data types do not match the table
-        key and value data types.
-    """
-    if name is None:
-      name = "%s_initialize_table" % self.name
-    with ops.op_scope([keys, values], None, name):
-      keys = ops.convert_to_tensor(keys, dtype=self.key_dtype, name="keys")
-      values = ops.convert_to_tensor(values, dtype=self.value_dtype,
-                                     name="values")
-
-    init_op = gen_data_flow_ops._initialize_table(
-        self.table_ref, keys, values, name=name)
-    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
-    return init_op
-
-  def _check_table_dtypes(self, key_dtype, value_dtype):
-    """Check that the given key_dtype and value_dtype matches the table dtypes'.
-
-    Args:
-      key_dtype: The key data type to check.
-      value_dtype: The value data type to check.
-
-    Raises:
-      TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data
-        types.
-    """
-    if key_dtype != self.key_dtype:
-      raise TypeError("Invalid key dtype, expected %s but got %s." % (
-          self.key_dtype, key_dtype))
-    if value_dtype != self.value_dtype:
-      raise TypeError("Invalid value dtype, expected %s but got %s." % (
-          self.value_dtype, value_dtype))
-
-
-class HashTable(LookupTableBase):
-  """A generic hash table implementation."""
-
-  def __init__(self, key_dtype, value_dtype, default_value, shared_name=None,
-               name="hash_table"):
-    """Creates a non-initialized hash table.
-
-    This op creates a hash table, specifying the type of its keys and values.
-    Before using the table you will have to initialize it. After initialization
-    the table will be immutable.
-
-    Args:
-      key_dtype: Type of the table keys.
-      value_dtype: Type of the table values.
-      default_value: The scalar tensor to be used when a key is missing in the
-        table.
-      shared_name: Optional. If non-empty, this table will be shared under
-        the given name across multiple sessions.
-      name: Optional name for the hash table op.
-
-    Returns:
-      A `HashTable` object.
-    """
-    table_ref = gen_data_flow_ops._hash_table(
-        shared_name=shared_name, key_dtype=key_dtype,
-        value_dtype=value_dtype, name=name)
-
-    super(HashTable, self).__init__(key_dtype, value_dtype, default_value,
-                                    table_ref)
-
-
 def initialize_all_tables(name="init_all_tables"):
   """Returns an Op that initializes all tables of the default graph.

   Args:
     name: Optional name for the initialization op.

   Returns:
     An Op that initializes all tables.  Note that if there are
     not tables the returned Op is a NoOp.
@@ -51,7 +51,7 @@ def embedding_lookup(params, ids, name=None):
   else:
     ids = ops.convert_to_tensor(ids, name="ids")
     flat_ids = array_ops.reshape(ids, [-1])
-    original_indices = math_ops.range(0, array_ops.size(flat_ids))
+    original_indices = math_ops.range(array_ops.size(flat_ids))
     # Compute flat_ids % partitions for each id
     ids_mod_p = flat_ids % np
     if ids_mod_p.dtype != types.int32:
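`embedding_lookup` shards ids across partitions by `id % np` and later restores the original order via `original_indices`. A NumPy sketch of the bookkeeping:

```python
# Mod-sharding: assign each id to a partition, gather per partition,
# then stitch results back using the original positions.
import numpy as np

num_partitions = 3
ids = np.array([7, 2, 9, 4])

original_indices = np.arange(ids.size)   # math_ops.range(size(flat_ids))
ids_mod_p = ids % num_partitions         # which partition owns each id
ids_div_p = ids // num_partitions        # row within that partition

for p in range(num_partitions):
    mask = ids_mod_p == p
    # rows to gather from partition p, and where to put them back
    print(p, ids_div_p[mask], original_indices[mask])
```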
@@ -21,7 +21,7 @@ def _ReductionGradAssist(op):
   indices = op.inputs[1]  # [1, 2]
   indices_shape = array_ops.shape(indices)  # [2]
   new_output_shape = data_flow_ops.dynamic_stitch(  # [2, 1, 1, 7]
-      [math_ops.range(0, input_rank),  # [0, 1, 2, 3]
+      [math_ops.range(input_rank),  # [0, 1, 2, 3]
        indices],  # [1, 2]
       [input_shape,  # [2, 3, 5, 7]
        array_ops.fill(indices_shape, 1)])  # [1, 1]
@@ -536,11 +536,14 @@ ops.Tensor._override_operator("__gt__", greater)
 ops.Tensor._override_operator("__ge__", greater_equal)


-def range(start, limit, delta=1, name="range"):
+def range(start, limit=None, delta=1, name="range"):
   """Creates a sequence of integers.

-  This operation creates a sequence of integers that begins at `start` and
-  extends by increments of `delta` up to but not including `limit`.
+  Creates a sequence of integers that begins at `start` and extends by
+  increments of `delta` up to but not including `limit`.
+
+  Like the Python builtin `range`, `start` defaults to 0, so that
+  `range(n) = range(0, n)`.

   For example:

@@ -549,10 +552,14 @@ def range(start, limit, delta=1, name="range"):
   # 'limit' is 18
   # 'delta' is 3
   tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]
+
+  # 'limit' is 5
+  tf.range(limit) ==> [0, 1, 2, 3, 4]
   ```

   Args:
     start: A 0-D (scalar) of type `int32`. First entry in sequence.
+      Defaults to 0.
     limit: A 0-D (scalar) of type `int32`. Upper limit of sequence,
       exclusive.
     delta: A 0-D `Tensor` (scalar) of type `int32`. Optional. Default is 1.

@@ -562,6 +569,8 @@ def range(start, limit, delta=1, name="range"):
   Returns:
     An 1-D `int32` `Tensor`.
   """
+  if limit is None:
+    start, limit = 0, start
   return gen_math_ops._range(start, limit, delta, name=name)
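With the new default, `tf.range(n)` behaves like Python's builtin `range(n)`. A pure-Python sketch of the argument shuffle in isolation:

```python
# When only one positional value is given, treat it as the limit and let
# start default to 0, mirroring Python's builtin range.
def tf_range_args(start, limit=None, delta=1):
    if limit is None:
        start, limit = 0, start
    return start, limit, delta

print(tf_range_args(5))         # (0, 5, 1)  -> like tf.range(5)
print(tf_range_args(3, 18, 3))  # (3, 18, 3) -> like tf.range(3, 18, 3)
```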
@ -173,6 +173,7 @@ from __future__ import print_function
|
||||
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import types
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import candidate_sampling_ops
|
||||
@ -347,7 +348,8 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):

   Args:
     x: A tensor.
-    keep_prob: A Python float. The probability that each element is kept.
+    keep_prob: A scalar `Tensor` with the same type as x. The probability
+      that each element is kept.
     noise_shape: A 1-D `Tensor` of type `int32`, representing the
       shape for randomly generated keep/drop flags.
     seed: A Python integer. Used to create random seeds. See
@ -361,10 +363,15 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
   Raises:
     ValueError: If `keep_prob` is not in `(0, 1]`.
   """
-  if not (0 < keep_prob <= 1):
-    raise ValueError("Expected keep_prob in (0, 1], got %g" % keep_prob)
   with ops.op_scope([x], name, "dropout") as name:
     x = ops.convert_to_tensor(x, name="x")
+    if isinstance(keep_prob, float) and not(0 < keep_prob <= 1):
+      raise ValueError("keep_prob must be a scalar tensor or a float in the "
+                       "range (0, 1], got %g" % keep_prob)
+    keep_prob = ops.convert_to_tensor(
+        keep_prob, dtype=x.dtype, name="keep_prob")
+    keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())
+
     noise_shape = noise_shape or array_ops.shape(x)
     # uniform [keep_prob, 1.0 + keep_prob)
     random_tensor = keep_prob
@ -372,7 +379,9 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
         noise_shape, seed=seed, dtype=x.dtype)
     # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
     binary_tensor = math_ops.floor(random_tensor)
-    return x * (1.0 / keep_prob) * binary_tensor
+    ret = x * math_ops.inv(keep_prob) * binary_tensor
+    ret.set_shape(x.get_shape())
+    return ret


 def depthwise_conv2d(input, filter, strides, padding, name=None):
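
A small sketch of what the reworked function enables: `keep_prob` may now be a fed tensor rather than only a Python float, and the output keeps the input's static shape (era-appropriate graph-mode usage; illustrative, not from the commit):

```python
import tensorflow as tf

x = tf.ones([40, 30])
keep_prob = tf.placeholder(tf.float32)  # keep_prob can now be fed at run time
y = tf.nn.dropout(x, keep_prob)
print(y.get_shape())                    # (40, 30): static shape info is preserved

with tf.Session() as sess:
    out = sess.run(y, feed_dict={keep_prob: 0.5})
```
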
@ -85,7 +85,7 @@ def _BiasAddGrad(unused_bias_op, received_grad):
     Two tensors, the first one for the "tensor" input of the BiasOp,
     the second one for the "bias" input of the BiasOp.
   """
-  reduction_dim_tensor = math_ops.range(0, array_ops.rank(received_grad) - 1)
+  reduction_dim_tensor = math_ops.range(array_ops.rank(received_grad) - 1)
   return (received_grad, math_ops.reduce_sum(received_grad, reduction_dim_tensor))
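
For intuition: the bias gradient simply sums the incoming gradient over every axis except the last. A NumPy equivalent (illustrative shapes):

```python
import numpy as np

received_grad = np.random.randn(8, 32, 16)             # e.g. [batch, features, channels]
reduction_dims = tuple(range(received_grad.ndim - 1))  # every axis except the last
bias_grad = received_grad.sum(axis=reduction_dims)     # shape (16,): one entry per bias
```
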
@ -405,28 +405,82 @@ class DropoutTest(test_util.TensorFlowTestCase):
         sorted_value = np.unique(np.sort(value[i, :]))
         self.assertEqual(sorted_value.size, 1)

+  def testDropoutPlaceholderKeepProb(self):
+    # Runs dropout with a 0-1 tensor 10 times, sums the number of ones and
+    # validates that it is producing approximately the right number of ones
+    # over a large number of samples, based on the keep probability.
+    x_dim = 40
+    y_dim = 30
+    num_iter = 10
+    for keep_prob in [0.1, 0.5, 0.8]:
+      with self.test_session():
+        t = constant_op.constant(1.0,
+                                 shape=[x_dim, y_dim],
+                                 dtype=types.float32)
+        keep_prob_placeholder = array_ops.placeholder(types.float32)
+        dropout = nn.dropout(t, keep_prob_placeholder)
+        final_count = 0
+        self.assertEqual([x_dim, y_dim], dropout.get_shape())
+        for _ in xrange(0, num_iter):
+          value = dropout.eval(feed_dict={keep_prob_placeholder: keep_prob})
+          final_count += np.count_nonzero(value)
+          # Verifies that there are only two values: 0 and 1/keep_prob.
+          sorted_value = np.unique(np.sort(value))
+          self.assertEqual(0, sorted_value[0])
+          self.assertAllClose(1 / keep_prob, sorted_value[1])
+        # Check that we are in the 15% error range
+        expected_count = x_dim * y_dim * keep_prob * num_iter
+        rel_error = math.fabs(final_count - expected_count) / expected_count
+        print(rel_error)
+        self.assertTrue(rel_error < 0.15)
+
+  def testShapedDropoutUnknownShape(self):
+    x_dim = 40
+    y_dim = 30
+    keep_prob = 0.5
+    x = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=types.float32)
+    dropout_x = nn.dropout(
+        x, keep_prob, noise_shape=array_ops.placeholder(types.int32))
+    self.assertEqual(x.get_shape(), dropout_x.get_shape())
+
+  def testInvalidKeepProb(self):
+    x_dim = 40
+    y_dim = 30
+    t = constant_op.constant(1.0,
+                             shape=[x_dim, y_dim],
+                             dtype=types.float32)
+    with self.assertRaises(ValueError):
+      nn.dropout(t, -1.0)
+    with self.assertRaises(ValueError):
+      nn.dropout(t, 1.1)
+    with self.assertRaises(ValueError):
+      nn.dropout(t, [0.0, 1.0])
+    with self.assertRaises(ValueError):
+      nn.dropout(t, array_ops.placeholder(types.float64))
+    with self.assertRaises(ValueError):
+      nn.dropout(t, array_ops.placeholder(types.float32, shape=[2]))
+
   def testShapedDropoutShapeError(self):
     # Runs shaped dropout and verifies an error is thrown on misshapen noise.
     x_dim = 40
     y_dim = 30
     keep_prob = 0.5
-    with self.test_session():
-      t = constant_op.constant(1.0,
-                               shape=[x_dim, y_dim],
-                               dtype=types.float32)
-      with self.assertRaises(ValueError):
-        _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10])
-      with self.assertRaises(ValueError):
-        _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5])
-      with self.assertRaises(ValueError):
-        _ = nn.dropout(t, keep_prob, noise_shape=[x_dim + 3])
-      with self.assertRaises(ValueError):
-        _ = nn.dropout(t, keep_prob, noise_shape=[x_dim])
-      # test that broadcasting proceeds
-      _ = nn.dropout(t, keep_prob, noise_shape=[y_dim])
-      _ = nn.dropout(t, keep_prob, noise_shape=[1, y_dim])
-      _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
-      _ = nn.dropout(t, keep_prob, noise_shape=[1, 1])
+    t = constant_op.constant(1.0,
+                             shape=[x_dim, y_dim],
+                             dtype=types.float32)
+    with self.assertRaises(ValueError):
+      _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10])
+    with self.assertRaises(ValueError):
+      _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5])
+    with self.assertRaises(ValueError):
+      _ = nn.dropout(t, keep_prob, noise_shape=[x_dim + 3])
+    with self.assertRaises(ValueError):
+      _ = nn.dropout(t, keep_prob, noise_shape=[x_dim])
+    # test that broadcasting proceeds
+    _ = nn.dropout(t, keep_prob, noise_shape=[y_dim])
+    _ = nn.dropout(t, keep_prob, noise_shape=[1, y_dim])
+    _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
+    _ = nn.dropout(t, keep_prob, noise_shape=[1, 1])


 class BatchNormWithGlobalNormalizationTest(test_util.TensorFlowTestCase):
@ -437,8 +437,7 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None):
         default_value, dtype=sp_input.values.dtype)

     num_rows = math_ops.cast(sp_input.shape[0], types.int32)
-    all_row_indices = math_ops.cast(
-        math_ops.range(0, num_rows, 1), types.int64)
+    all_row_indices = math_ops.cast(math_ops.range(num_rows), types.int64)
     empty_row_indices, _ = array_ops.list_diff(
         all_row_indices, sp_input.indices[:, 0])
     empty_row_indicator = gen_sparse_ops.sparse_to_dense(
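
What `list_diff` computes here, in rough NumPy terms (illustrative; `setdiff1d` sorts its result, whereas `list_diff` preserves the order of its first argument):

```python
import numpy as np

num_rows = 5
present_rows = np.array([0, 0, 2, 4])  # sp_input.indices[:, 0]
all_row_indices = np.arange(num_rows)
empty_row_indices = np.setdiff1d(all_row_indices, present_rows)  # -> [1 3]
```
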
@ -6,18 +6,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import logging
 import os
 import sys
 import time
 import thread
-from logging import getLogger
-from logging import log
-from logging import debug
-from logging import error
-from logging import fatal
-from logging import info
-from logging import warn
-from logging import warning
 from logging import DEBUG
 from logging import ERROR
 from logging import FATAL
@ -25,13 +18,25 @@ from logging import INFO
 from logging import WARN

 # Controls which methods from pyglib.logging are available within the project
-# Do not add methods here without also adding to platform/default/_logging.py
+# Do not add methods here without also adding to platform/google/_logging.py
 __all__ = ['log', 'debug', 'error', 'fatal', 'info', 'warn', 'warning',
            'DEBUG', 'ERROR', 'FATAL', 'INFO', 'WARN',
            'flush', 'log_every_n', 'log_first_n', 'vlog',
            'TaskLevelStatusMessage', 'get_verbosity', 'set_verbosity']

-warning = warn
+# Scope the tensorflow logger to not conflict with users' loggers
+_logger = logging.getLogger('tensorflow')
+_handler = logging.StreamHandler()
+_handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT, None))
+_logger.addHandler(_handler)
+
+log = _logger.log
+debug = _logger.debug
+error = _logger.error
+fatal = _logger.fatal
+info = _logger.info
+warn = _logger.warn
+warning = _logger.warn

 _level_names = {
     FATAL: 'FATAL',
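
The pattern in play, reduced to its essence (illustrative names; `'mylib'` is a placeholder): configuring a named, library-scoped logger leaves the root logger, which user code typically sets up via `logging.basicConfig()`, untouched.

```python
import logging

lib_logger = logging.getLogger('mylib')  # placeholder library name
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT, None))
lib_logger.addHandler(handler)

lib_logger.warning('emitted through the scoped handler')
logging.getLogger().warning('root logger is unaffected')
```
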
@ -61,7 +66,7 @@ def flush():

 # Code below is taken from pyglib/logging
 def vlog(level, msg, *args, **kwargs):
-  log(level, msg, *args, **kwargs)
+  _logger.log(level, msg, *args, **kwargs)


 def _GetNextLogCountPerToken(token):
@ -169,12 +174,12 @@ def google2_log_prefix(level, timestamp=None, file_and_line=None):

 def get_verbosity():
   """Return how much logging output will be produced."""
-  return getLogger().getEffectiveLevel()
+  return _logger.getEffectiveLevel()


 def set_verbosity(verbosity):
   """Sets the threshold for what messages will be logged."""
-  getLogger().setLevel(verbosity)
+  _logger.setLevel(verbosity)


 def _get_thread_id():
@ -163,7 +163,7 @@ def range_input_producer(limit, num_epochs=None, shuffle=True, seed=None,
     is added to the current Graph's QUEUE_RUNNER collection.
   """
   with ops.op_scope([limit], name, "input_producer") as name:
-    range_tensor = math_ops.range(0, limit)
+    range_tensor = math_ops.range(limit)
     return _input_producer(
         range_tensor, types.int32, num_epochs, shuffle, seed, capacity, name,
         "fraction_of_%d_full" % capacity)
@ -33,9 +33,9 @@ def exponential_decay(learning_rate, global_step, decay_steps, decay_rate,
   ...
   global_step = tf.Variable(0, trainable=False)
   starter_learning_rate = 0.1
-  learning_rate = tf.exponential_decay(starter_learning_rate, global_step,
-                                       100000, 0.96, staircase=True)
-  optimizer = tf.GradientDescent(learning_rate)
+  learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
+                                             100000, 0.96, staircase=True)
+  optimizer = tf.GradientDescentOptimizer(learning_rate)
   # Passing global_step to minimize() will increment it at each step.
   optimizer.minimize(...my loss..., global_step=global_step)
   ```
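
The decay rule this docstring describes, worked as plain arithmetic (illustrative numbers only):

```python
# decayed_lr = learning_rate * decay_rate ** (global_step / decay_steps)
lr, decay_rate, decay_steps = 0.1, 0.96, 100000

print(lr * decay_rate ** (150000.0 / decay_steps))  # smooth decay: ~0.0941
print(lr * decay_rate ** (150000 // decay_steps))   # staircase=True truncates the
                                                    # exponent -> 0.096 until step 200000
```
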
@ -19,26 +19,27 @@
     "dagre": "~0.7.4",
     "es6-promise": "~3.0.2",
     "graphlib": "~1.0.7",
-    "iron-ajax": "PolymerElements/iron-ajax#~1.0.8",
-    "iron-collapse": "PolymerElements/iron-collapse#~1.0.4",
-    "iron-list": "PolymerElements/iron-list#~1.1.5",
-    "paper-button": "PolymerElements/paper-button#~1.0.7",
-    "paper-checkbox": "PolymerElements/paper-checkbox#~1.0.6",
-    "paper-dropdown-menu": "PolymerElements/paper-dropdown-menu#~1.0.4",
-    "paper-header-panel": "PolymerElements/paper-header-panel#~1.0.5",
-    "paper-icon-button": "PolymerElements/paper-icon-button#~1.0.3",
-    "paper-input": "PolymerElements/paper-input#~1.0.15",
-    "paper-item": "PolymerElements/paper-item#~1.0.3",
-    "paper-menu": "PolymerElements/paper-menu#~1.1.1",
-    "paper-progress": "PolymerElements/paper-progress#~1.0.7",
-    "paper-radio-button": "PolymerElements/paper-radio-button#~1.0.8",
-    "paper-radio-group": "PolymerElements/paper-radio-group#~1.0.4",
-    "paper-slider": "PolymerElements/paper-slider#~1.0.4",
-    "paper-styles": "PolymerElements/paper-styles#~1.0.11",
-    "paper-toggle-button": "PolymerElements/paper-toggle-button#~1.0.6",
-    "paper-toolbar": "PolymerElements/paper-toolbar#~1.0.4",
+    "iron-ajax": "PolymerElements/iron-ajax#1.0.7",
+    "iron-collapse": "PolymerElements/iron-collapse#1.0.4",
+    "iron-list": "PolymerElements/iron-list#1.1.5",
+    "iron-selector": "PolymerElements/iron-selector#1.0.7",
+    "paper-button": "PolymerElements/paper-button#1.0.8",
+    "paper-checkbox": "PolymerElements/paper-checkbox#1.0.13",
+    "paper-dropdown-menu": "PolymerElements/paper-dropdown-menu#1.0.5",
+    "paper-header-panel": "PolymerElements/paper-header-panel#1.0.5",
+    "paper-icon-button": "PolymerElements/paper-icon-button#1.0.5",
+    "paper-input": "PolymerElements/paper-input#1.0.16",
+    "paper-item": "PolymerElements/paper-item#1.0.5",
+    "paper-menu": "PolymerElements/paper-menu#1.1.1",
+    "paper-progress": "PolymerElements/paper-progress#1.0.7",
+    "paper-radio-button": "PolymerElements/paper-radio-button#1.0.10",
+    "paper-radio-group": "PolymerElements/paper-radio-group#1.0.6",
+    "paper-slider": "PolymerElements/paper-slider#1.0.7",
+    "paper-styles": "PolymerElements/paper-styles#1.0.12",
+    "paper-toggle-button": "PolymerElements/paper-toggle-button#1.0.11",
+    "paper-toolbar": "PolymerElements/paper-toolbar#1.0.4",
     "plottable": "~1.16.1",
-    "polymer": "~1.2.0"
+    "polymer": "1.1.5"
   },
   "devDependencies": {
     "iron-component-page": "PolymerElements/iron-component-page#^1.0.0",
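
One design note on this hunk: in bower's semver syntax, `~1.0.8` floats to any newer patch release (`>=1.0.8 <1.1.0`), so dropping the `~` pins each TensorBoard component to an exact version and keeps builds from drifting as upstream components publish patches.
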
@ -17,9 +17,9 @@ import json
 import mimetypes
 import os
 import StringIO
-import urllib
-import urlparse

+from six.moves import urllib
 from six.moves import xrange  # pylint: disable=redefined-builtin
 from google.protobuf import text_format
 import tensorflow.python.platform
@ -289,7 +289,7 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler):
       A string representation of a URL that will load the index-th
       sampled image in the given run with the given tag.
     """
-    query_string = urllib.urlencode({
+    query_string = urllib.parse.urlencode({
         'run': run,
         'tag': tag,
         'index': index
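
The `six.moves` indirection used above, in one self-contained snippet (illustrative values): `six.moves.urllib.parse` resolves to `urlparse`/`urllib` on Python 2 and to `urllib.parse` on Python 3, so a single spelling works under both interpreters.

```python
from six.moves import urllib

query_string = urllib.parse.urlencode({'run': 'run1', 'tag': 'images', 'index': 3})
print(query_string)  # e.g. run=run1&tag=images&index=3
```
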